Commit fae09c2c authored by baberabb's avatar baberabb
Browse files

updated test_tasks

parent 6862fa7d
import lm_eval.tasks as tasks
import pytest
from itertools import islice
@pytest.fixture()
def task_class(task_name="arc_easy"):
    """Return the task class registered in TASK_REGISTRY under ``task_name``."""
    matches = (cls for name, cls in tasks.TASK_REGISTRY.items() if name == task_name)
    # next() raises StopIteration if the name is not registered, matching
    # the original lookup's failure mode.
    return next(matches)
@pytest.fixture()
def limit(limit=10):
    """Cap on how many docs the iteration-based tests consume (default 10)."""
    return limit
def test_download(task_class):
    """Downloading should populate the task instance's ``dataset``."""
    task = task_class()
    task.download()
    # Fix: the original called download() on one instance but asserted
    # `dataset` on the class itself, so the downloaded instance was never
    # inspected.
    assert task.dataset is not None
def test_has_training_docs(task_class):
    """has_training_docs() must return a strict boolean."""
    task = task_class()
    assert task.has_training_docs() in [True, False]
def test_check_training_docs(task_class):
    """The default task (arc_easy) is expected to ship training docs."""
    task = task_class()
    assert task.has_training_docs()
def test_has_validation_docs(task_class):
    """has_validation_docs() must return a strict boolean.

    Fixes a copy-paste bug: the original asserted ``has_training_docs()``
    despite the test's name.
    """
    assert task_class().has_validation_docs() in [True, False]
def test_check_validation_docs(task_class):
    """The default task (arc_easy) is expected to ship validation docs.

    Fixes a copy-paste bug: the original asserted ``has_training_docs()``
    despite the test's name.
    """
    assert task_class().has_validation_docs()
def test_has_test_docs(task_class):
    """has_test_docs() must return a strict boolean.

    Fixes two copy-paste bugs: the original asserted ``has_training_docs()``
    and called the method on the un-instantiated class (no ``self``).
    """
    assert task_class().has_test_docs() in [True, False]
def test_check_test_docs(task_class):
    """The default task (arc_easy) is expected to ship test docs.

    Fixes two copy-paste bugs: the original asserted ``has_training_docs()``
    and called the method on the un-instantiated class (no ``self``).
    """
    assert task_class().has_test_docs()
def test_should_decontaminate(task_class):
    """should_decontaminate() must return a strict boolean.

    Fix: instantiate the task first — the original called the method on the
    class itself, which fails for a plain instance method (missing ``self``).
    """
    assert task_class().should_decontaminate() in [True, False]
def test_doc_to_text(task_class, limit):
    """doc_to_text must yield strings that never end in a trailing space."""
    task = task_class()
    if limit:
        docs = list(islice(task.test_docs(), limit))
    else:
        docs = list(task.test_docs())
    texts = [task.doc_to_text(doc) for doc in docs]
    # Space convention: an empty text is allowed for perplexity-like tasks,
    # since the model tacks an <|endoftext|> on.
    for text in texts:
        assert isinstance(text, str)
        if len(text) != 0:
            assert text[-1] != " "
def test_doc_to_target(task_class, limit):
    """doc_to_target must yield a string for every test doc."""
    task = task_class()
    # Fix: the unlimited branch called task.test_target(), which is not a
    # doc iterator (the sibling tests iterate test_docs()); both branches
    # should draw from test_docs().
    arr = list(islice(task.test_docs(), limit)) if limit else list(task.test_docs())
    _array_target = [task.doc_to_target(doc) for doc in arr]
    assert all(isinstance(target, str) for target in _array_target)
    # NOTE(review): the space-convention cross-check against doc_to_text
    # (target starts with " " or text ends with "\n") is disabled upstream
    # as "not working"; intentionally left out here too.
def test_build_all_requests(task_class, limit):
    """build_all_requests should populate the task's ``instances``."""
    task = task_class()
    # NOTE(review): rank=1 with world_size=1 looks off (ranks are usually
    # 0-indexed, so rank must be < world_size) — kept to preserve the
    # original call; confirm against the Task API.
    task.build_all_requests(rank=1, limit=limit, world_size=1)
    # Fix: the original built requests on one instance but asserted
    # `instances` on the class itself, so the built instance was never
    # inspected.
    assert task.instances is not None
def test_construct_requests(task_class, limit):
    """construct_requests should return a list of requests for each doc."""
    task = task_class()
    arr = list(islice(task.test_docs(), limit)) if limit else list(task.test_docs())
    requests = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
    assert all(isinstance(req, list) for req in requests)
    # Fix: the original `assert len(requests) == limit if limit else True`
    # depended on conditional-expression precedence and asserted a bare True
    # in the unlimited case; make the intent explicit.
    if limit:
        assert len(requests) == limit
def test_create_choices(task_class):
    """create_choices should produce a non-None result for a sample doc."""
    sample = list(islice(task_class().test_docs(), 1))
    choices = task_class().create_choices(sample[0])
    # checking if number of choices is correct
    assert choices is not None
tgt = task.doc_to_target(doc) # @pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
# def test_basic_interface(taskname, task_class):
assert isinstance(txt, str) # print("Evaluating task", taskname)
assert isinstance(tgt, str) # task = task_class()
#
# space convention # assert task.has_training_docs() in [True, False]
# allow txt to have length 0 for perplexity-like tasks since the model tacks an <|endoftext|> on # assert task.has_validation_docs() in [True, False]
if len(txt) != 0: # assert task.has_test_docs() in [True, False]
assert txt[-1] != " " #
assert tgt[0] == " " or txt[-1] == "\n" # assert isinstance(task.aggregation(), dict)
# assert isinstance(task.higher_is_better(), dict)
reqs = task.construct_requests(doc, txt) # assert task.aggregation().keys() == task.higher_is_better().keys()
#
# construct_requests can return just one request # for v in task.higher_is_better().values():
if not isinstance(reqs, (list, tuple)): # assert v in [True, False]
reqs = [reqs] #
# assert isinstance(task.VERSION, int)
# todo: mock lm after refactoring evaluator.py to not be a mess #
# for req in reqs: # # test deterministic docs
# assert isinstance(req, base.Request) # # (don't test train because it's slow)
#
# task2 = task_class()
#
# limit = None
#
# if taskname in ["triviaqa"] or taskname.startswith("pile_"):
# limit = 10000
# if task.has_validation_docs():
# arr = list(islice(task.validation_docs(), limit))
# arr2 = list(islice(task2.validation_docs(), limit))
#
# assert arr == arr2
#
# reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
# reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
#
# assert reqs == reqs2
#
# if task.has_test_docs():
# arr = list(islice(task.test_docs(), limit))
# arr2 = list(islice(task2.test_docs(), limit))
#
# assert arr == arr2
#
# reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
# reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
#
# assert reqs == reqs2
#
# if task.has_training_docs():
# arr = list(islice(task.training_docs(), limit))
# arr2 = list(islice(task2.training_docs(), limit))
#
# assert arr == arr2
#
# reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
# reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
#
# assert reqs == reqs2
#
#
# @pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
# def test_documents_and_requests(taskname, task_class):
# print("Evaluating task", taskname)
# task = task_class()
# fns = []
# if task.has_training_docs():
# fns.append(task.training_docs)
# if task.has_validation_docs():
# fns.append(task.validation_docs)
# # test doc might not have labels
# # if task.has_test_docs(): fns.append(task.test_docs)
#
# for fn in fns:
# # print(list(islice(fn(), 10)))
# for doc in islice(fn(), 10):
#
# txt = task.doc_to_text(doc)
# tgt = task.doc_to_target(doc)
#
# assert isinstance(txt, str)
# assert isinstance(tgt, str)
#
# # space convention
# # allow txt to have length 0 for perplexity-like tasks since the model tacks an <|endoftext|> on
# if len(txt) != 0:
# assert txt[-1] != " "
# assert tgt[0] == " " or txt[-1] == "\n"
#
# reqs = task.construct_requests(doc, txt)
#
# # construct_requests can return just one request
# if not isinstance(reqs, (list, tuple)):
# reqs = [reqs]
#
# # todo: mock lm after refactoring evaluator.py to not be a mess
# # for req in reqs:
# # assert isinstance(req, base.Request)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment