Commit 4d8dde31 authored by Stephen Hogg's avatar Stephen Hogg
Browse files

Add integrity check

parent 05590e11
import collections
import itertools
import pathlib
import random
import lm_eval.metrics
import lm_eval.models
import lm_eval.tasks
import lm_eval.base
import numpy as np
from lm_eval.utils import positional_deprecated
from lm_eval.utils import positional_deprecated, run_task_tests
@positional_deprecated
def simple_evaluate(model, model_args=None, tasks=[],
num_fewshot=0, batch_size=None, device=None,
no_cache=False, limit=None, bootstrap_iters=100000,
description_dict=None):
description_dict=None, check_integrity=False):
"""Instantiate and evaluate a model on a list of tasks.
:param model: Union[str, LM]
......@@ -37,6 +38,8 @@ def simple_evaluate(model, model_args=None, tasks=[],
Number of iterations for bootstrap statistics
:param description_dict: dict[str, str]
Dictionary of custom task descriptions of the form: `task_name: description`
:param check_integrity: bool
Whether to run the relevant part of the test suite for the tasks
:return
Dictionary of results
"""
......@@ -61,6 +64,9 @@ def simple_evaluate(model, model_args=None, tasks=[],
task_dict = lm_eval.tasks.get_task_dict(tasks)
if check_integrity:
run_task_tests(start_path=pathlib.Path(__file__), task_list=tasks)
results = evaluate(
lm=lm,
task_dict=task_dict,
......
import os
import pathlib
import re
import collections
import functools
import inspect
import sys
import pytest
from typing import List
class ExitCodeError(Exception):
......@@ -155,3 +159,32 @@ def positional_deprecated(fn):
"lm-evaluation-harness!")
return fn(*args, **kwargs)
return _wrapper
@positional_deprecated
def find_test_root(start_path: pathlib.Path) -> pathlib.Path:
"""
Search upward in the directory tree to a maximum of three layers
to find and return the package root (containing the 'tests' folder)
"""
cur_path = start_path.resolve()
max_layers = 3
for _ in range(max_layers):
if (cur_path / 'tests' / 'test_version_stable.py').exists():
return cur_path
else:
cur_path = cur_path.parent.resolve()
raise FileNotFoundError(f"Unable to find package root within {max_layers} upwards" +\
f"of {start_path}")
@positional_deprecated
def run_task_tests(start_path: pathlib.Path, task_list: List[str]):
"""
Find the package root and run the tests for the given tasks
"""
package_root = find_test_root(start_path=start_path)
task_string = ' or '.join(task_list)
args = [f'{package_root}/tests/test_version_stable.py', f'--rootdir={package_root}', '-k', f'{task_string}']
sys.path.append(str(package_root))
pytest_return_val = pytest.main(args)
if pytest_return_val:
raise ValueError(f"Not all tests for the specified tasks ({task_list}) ran successfully! Error code: {pytest_return_val}")
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment