Commit d9b547b7 authored by baberabb

fix test_evaluator.py

parent 7d4e92fa
@@ -6,7 +6,7 @@ from lm_eval.api.registry import register_model
 @register_model("dummy")
 class DummyLM(LM):
     def __init__(self):
-        pass
+        super().__init__()

     @classmethod
     def create_from_arg_string(cls, arg_string, additional_config=None):
...
@@ -14,10 +14,11 @@ import pytest
 # TODO: more fine grained unit tests rather than this big honking integration
 # test once we break evaluator into smaller, more manageable pieces
-@pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
-def test_evaluator(taskname, task_class):
-    task_dict = tasks.get_task_dict([taskname])
+# @pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
+def test_evaluator():
+    TASK = ["arc_easy"]
+    LIMIT = 10
+    # task_dict = tasks.get_task_dict(task)
     # TODO: re-add cachingLM
     # os.system("rm test_cache.db")
@@ -25,7 +26,7 @@ def test_evaluator(taskname, task_class):
     lm = registry.get_model("dummy")()
     def ll_fn(reqs):
-        for ctx, cont in reqs:
+        for ctx, cont in [req.args for req in reqs]:
             if len(ctx) == 0:
                 continue
             # space convention
@@ -54,19 +55,16 @@ def test_evaluator(taskname, task_class):
     lm.loglikelihood = ll_fn
     lm.loglikelihood_rolling = ll_perp_fn
-    limit = 10
-    e1 = evaluator.evaluate(
-        lm=lm,
-        task_dict=task_dict,
-        num_fewshot=0,
-        limit=limit,
+    e1 = evaluator.simple_evaluate(
+        model="dummy",
+        tasks=TASK,
+        limit=LIMIT,
         bootstrap_iters=10,
     )
-    e2 = evaluator.evaluate(
-        lm=lm,
-        task_dict=task_dict,
-        num_fewshot=0,
-        limit=limit,
+    e2 = evaluator.simple_evaluate(
+        model="dummy",
+        tasks=TASK,
+        limit=LIMIT,
         bootstrap_iters=10,
     )
...
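The behavioural change worth calling out is in ll_fn: after this commit the dummy model's loglikelihood stub receives request objects and reads each (context, continuation) pair off an .args attribute instead of unpacking the request directly. A minimal self-contained sketch of that pattern follows; FakeRequest is a hypothetical stand-in for the harness's real request class, and the returned (logprob, is_greedy) pairs are dummy values, as in the test.

from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class FakeRequest:
    # Hypothetical stand-in: the harness wraps each request's
    # (context, continuation) pair in an .args attribute.
    args: Tuple[str, str]


def ll_fn(reqs: List[FakeRequest]) -> List[Tuple[float, bool]]:
    res = []
    for ctx, cont in [req.args for req in reqs]:
        if len(ctx) == 0:
            res.append((-1.0, False))
            continue
        # space-convention checks from the real test would go here
        res.append((-1.0, False))  # dummy (logprob, is_greedy) pair
    return res


print(ll_fn([FakeRequest(args=("The answer is", " 42"))]))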