Commit 4d8dde31 authored by Stephen Hogg's avatar Stephen Hogg
Browse files

Add integrity check

parent 05590e11
import collections import collections
import itertools import itertools
import pathlib
import random import random
import lm_eval.metrics import lm_eval.metrics
import lm_eval.models import lm_eval.models
import lm_eval.tasks import lm_eval.tasks
import lm_eval.base import lm_eval.base
import numpy as np import numpy as np
from lm_eval.utils import positional_deprecated from lm_eval.utils import positional_deprecated, run_task_tests
@positional_deprecated @positional_deprecated
def simple_evaluate(model, model_args=None, tasks=[], def simple_evaluate(model, model_args=None, tasks=[],
num_fewshot=0, batch_size=None, device=None, num_fewshot=0, batch_size=None, device=None,
no_cache=False, limit=None, bootstrap_iters=100000, no_cache=False, limit=None, bootstrap_iters=100000,
description_dict=None): description_dict=None, check_integrity=False):
"""Instantiate and evaluate a model on a list of tasks. """Instantiate and evaluate a model on a list of tasks.
:param model: Union[str, LM] :param model: Union[str, LM]
...@@ -37,6 +38,8 @@ def simple_evaluate(model, model_args=None, tasks=[], ...@@ -37,6 +38,8 @@ def simple_evaluate(model, model_args=None, tasks=[],
Number of iterations for bootstrap statistics Number of iterations for bootstrap statistics
:param description_dict: dict[str, str] :param description_dict: dict[str, str]
Dictionary of custom task descriptions of the form: `task_name: description` Dictionary of custom task descriptions of the form: `task_name: description`
:param check_integrity: bool
Whether to run the relevant part of the test suite for the tasks
:return :return
Dictionary of results Dictionary of results
""" """
...@@ -61,6 +64,9 @@ def simple_evaluate(model, model_args=None, tasks=[], ...@@ -61,6 +64,9 @@ def simple_evaluate(model, model_args=None, tasks=[],
task_dict = lm_eval.tasks.get_task_dict(tasks) task_dict = lm_eval.tasks.get_task_dict(tasks)
if check_integrity:
run_task_tests(start_path=pathlib.Path(__file__), task_list=tasks)
results = evaluate( results = evaluate(
lm=lm, lm=lm,
task_dict=task_dict, task_dict=task_dict,
......
import os import os
import pathlib
import re import re
import collections import collections
import functools import functools
import inspect import inspect
import sys
import pytest
from typing import List
class ExitCodeError(Exception): class ExitCodeError(Exception):
...@@ -155,3 +159,32 @@ def positional_deprecated(fn): ...@@ -155,3 +159,32 @@ def positional_deprecated(fn):
"lm-evaluation-harness!") "lm-evaluation-harness!")
return fn(*args, **kwargs) return fn(*args, **kwargs)
return _wrapper return _wrapper
@positional_deprecated
def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
"""
Search upward in the directory tree to a maximum of three layers
to find and return the package root (containing the 'tests' folder)
"""
cur_path = start_path.resolve()
max_layers = 3
for _ in range(max_layers):
if (cur_path / 'tests' / 'test_version_stable.py').exists():
return cur_path
else:
cur_path = cur_path.parent.resolve()
raise FileNotFoundError(f"Unable to find package root within {max_layers} upwards" +\
f"of {start_path}")
@positional_deprecated
def run_task_tests(start_path: pathlib.Path, task_list: List[str]):
"""
Find the package root and run the tests for the given tasks
"""
package_root = find_test_root(start_path=start_path)
task_string = ' or '.join(task_list)
args = [f'{package_root}/tests/test_version_stable.py', f'--rootdir={package_root}', '-k', f'{task_string}']
sys.path.append(str(package_root))
pytest_return_val = pytest.main(args)
if pytest_return_val:
raise ValueError(f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}")
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment