Unverified Commit 93cf95d0 authored by Leo Gao's avatar Leo Gao Committed by GitHub
Browse files

Merge pull request #290 from StephenHogg/check_integrity

Add integrity check
parents 3c37ea9c 40c1e05c
import collections
import itertools
import pathlib
import random
import lm_eval.metrics
import lm_eval.models
import lm_eval.tasks
import lm_eval.base
import numpy as np
from lm_eval.utils import positional_deprecated
from lm_eval.utils import positional_deprecated, run_task_tests
@positional_deprecated
def simple_evaluate(model, model_args=None, tasks=[],
num_fewshot=0, batch_size=None, device=None,
no_cache=False, limit=None, bootstrap_iters=100000,
description_dict=None):
description_dict=None, check_integrity=False):
"""Instantiate and evaluate a model on a list of tasks.
:param model: Union[str, LM]
......@@ -37,6 +38,8 @@ def simple_evaluate(model, model_args=None, tasks=[],
Number of iterations for bootstrap statistics
:param description_dict: dict[str, str]
Dictionary of custom task descriptions of the form: `task_name: description`
:param check_integrity: bool
Whether to run the relevant part of the test suite for the tasks
:return
Dictionary of results
"""
......@@ -61,6 +64,9 @@ def simple_evaluate(model, model_args=None, tasks=[],
task_dict = lm_eval.tasks.get_task_dict(tasks)
if check_integrity:
run_task_tests(task_list=tasks)
results = evaluate(
lm=lm,
task_dict=task_dict,
......
import os
import pathlib
import re
import collections
import functools
import inspect
import sys
import pytest
from typing import List
class ExitCodeError(Exception):
......@@ -155,3 +159,32 @@ def positional_deprecated(fn):
"lm-evaluation-harness!")
return fn(*args, **kwargs)
return _wrapper
@positional_deprecated
def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
"""
Search upward in the directory tree to a maximum of three layers
to find and return the package root (containing the 'tests' folder)
"""
cur_path = start_path.resolve()
max_layers = 3
for _ in range(max_layers):
if (cur_path / 'tests' / 'test_version_stable.py').exists():
return cur_path
else:
cur_path = cur_path.parent.resolve()
raise FileNotFoundError(f"Unable to find package root within {max_layers} upwards" +\
f"of {start_path}")
@positional_deprecated
def run_task_tests(task_list: List[str]):
"""
Find the package root and run the tests for the given tasks
"""
package_root = find_test_root(start_path=pathlib.Path(__file__))
task_string = ' or '.join(task_list)
args = [f'{package_root}/tests/test_version_stable.py', f'--rootdir={package_root}', '-k', f'{task_string}']
sys.path.append(str(package_root))
pytest_return_val = pytest.main(args)
if pytest_return_val:
raise ValueError(f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}")
\ No newline at end of file
......@@ -20,6 +20,7 @@ def parse_args():
parser.add_argument('--limit', type=int, default=None)
parser.add_argument('--no_cache', action="store_true")
parser.add_argument('--description_dict_path', default=None)
parser.add_argument('--check_integrity', action="store_true")
return parser.parse_args()
......@@ -49,7 +50,8 @@ def main():
device=args.device,
no_cache=args.no_cache,
limit=args.limit,
description_dict=description_dict
description_dict=description_dict,
check_integrity=args.check_integrity
)
dumped = json.dumps(results, indent=2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment