Commit fae09c2c authored by baberabb's avatar baberabb
Browse files

updated test_tasks

parent 6862fa7d
import lm_eval.tasks as tasks
import pytest
from itertools import islice
@pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
def test_basic_interface(taskname, task_class):
print("Evaluating task", taskname)
task = task_class()
@pytest.fixture()
def task_class(task_name="arc_easy"):
return next(
(name, cls) for name, cls in tasks.TASK_REGISTRY.items() if name == task_name
)[1]
@pytest.fixture()
def limit(limit=10):
return limit
assert task.has_training_docs() in [True, False]
assert task.has_validation_docs() in [True, False]
assert task.has_test_docs() in [True, False]
def test_download(task_class):
task_class().download()
assert task_class.dataset is not None
assert isinstance(task.aggregation(), dict)
assert isinstance(task.higher_is_better(), dict)
assert task.aggregation().keys() == task.higher_is_better().keys()
for v in task.higher_is_better().values():
assert v in [True, False]
def test_has_training_docs(task_class):
assert task_class().has_training_docs() in [True, False]
assert isinstance(task.VERSION, int)
# test deterministic docs
# (don't test train because it's slow)
def test_check_training_docs(task_class):
assert task_class().has_training_docs()
task2 = task_class()
limit = None
def test_has_validation_docs(task_class):
assert task_class().has_training_docs() in [True, False]
if taskname in ["triviaqa"] or taskname.startswith("pile_"):
limit = 10000
if task.has_validation_docs():
arr = list(islice(task.validation_docs(), limit))
arr2 = list(islice(task2.validation_docs(), limit))
assert arr == arr2
def test_check_validation_docs(task_class):
assert task_class().has_training_docs()
reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
assert reqs == reqs2
def test_has_test_docs(task_class):
assert task_class.has_training_docs() in [True, False]
if task.has_test_docs():
arr = list(islice(task.test_docs(), limit))
arr2 = list(islice(task2.test_docs(), limit))
assert arr == arr2
def test_check_test_docs(task_class):
assert task_class.has_training_docs()
reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
assert reqs == reqs2
def test_should_decontaminate(task_class):
assert task_class.should_decontaminate() in [True, False]
if task.has_training_docs():
arr = list(islice(task.training_docs(), limit))
arr2 = list(islice(task2.training_docs(), limit))
assert arr == arr2
def test_doc_to_text(task_class, limit):
task = task_class()
arr = list(islice(task.test_docs(), limit)) if limit else list(task.test_docs())
_array = [task.doc_to_text(doc) for doc in arr]
# space convention; allow txt to have length 0 for perplexity-like tasks since the model tacks an <|endoftext|> on
assert all(
isinstance(x, str) and (x[-1] != " " if len(x) != 0 else True) for x in _array
)
def test_doc_to_target(task_class, limit):
task = task_class()
arr = list(islice(task.test_docs(), limit)) if limit else list(task.test_target())
_array_target = [task.doc_to_target(doc) for doc in arr]
assert all(isinstance(doc, str) for doc in _array_target)
# _array_text = [task.doc_to_text(doc) for doc in arr]
# Not working
# assert all(tgt[0] == " " or txt[-1] == "\n" if len(txt) != 0 else True for txt, tgt in zip(_array_text, _array_target))
reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
assert reqs == reqs2
def test_build_all_requests(task_class, limit):
task_class().build_all_requests(rank=1, limit=limit, world_size=1)
assert task_class.instances is not None
@pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
def test_documents_and_requests(taskname, task_class):
print("Evaluating task", taskname)
def test_construct_requests(task_class, limit):
task = task_class()
fns = []
if task.has_training_docs():
fns.append(task.training_docs)
if task.has_validation_docs():
fns.append(task.validation_docs)
# test doc might not have labels
# if task.has_test_docs(): fns.append(task.test_docs)
for fn in fns:
# print(list(islice(fn(), 10)))
for doc in islice(fn(), 10):
txt = task.doc_to_text(doc)
tgt = task.doc_to_target(doc)
assert isinstance(txt, str)
assert isinstance(tgt, str)
# space convention
# allow txt to have length 0 for perplexity-like tasks since the model tacks an <|endoftext|> on
if len(txt) != 0:
assert txt[-1] != " "
assert tgt[0] == " " or txt[-1] == "\n"
reqs = task.construct_requests(doc, txt)
# construct_requests can return just one request
if not isinstance(reqs, (list, tuple)):
reqs = [reqs]
# todo: mock lm after refactoring evaluator.py to not be a mess
# for req in reqs:
# assert isinstance(req, base.Request)
arr = list(islice(task.test_docs(), limit)) if limit else list(task.test_docs())
requests = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
assert all(isinstance(doc, list) for doc in requests)
assert len(requests) == limit if limit else True
def test_create_choices(task_class):
arr = list(islice(task_class().test_docs(), 1))
choices = task_class().create_choices(arr[0])
assert choices is not None
# checking if number of choices is correct
# @pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
# def test_basic_interface(taskname, task_class):
# print("Evaluating task", taskname)
# task = task_class()
#
# assert task.has_training_docs() in [True, False]
# assert task.has_validation_docs() in [True, False]
# assert task.has_test_docs() in [True, False]
#
# assert isinstance(task.aggregation(), dict)
# assert isinstance(task.higher_is_better(), dict)
# assert task.aggregation().keys() == task.higher_is_better().keys()
#
# for v in task.higher_is_better().values():
# assert v in [True, False]
#
# assert isinstance(task.VERSION, int)
#
# # test deterministic docs
# # (don't test train because it's slow)
#
# task2 = task_class()
#
# limit = None
#
# if taskname in ["triviaqa"] or taskname.startswith("pile_"):
# limit = 10000
# if task.has_validation_docs():
# arr = list(islice(task.validation_docs(), limit))
# arr2 = list(islice(task2.validation_docs(), limit))
#
# assert arr == arr2
#
# reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
# reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
#
# assert reqs == reqs2
#
# if task.has_test_docs():
# arr = list(islice(task.test_docs(), limit))
# arr2 = list(islice(task2.test_docs(), limit))
#
# assert arr == arr2
#
# reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
# reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
#
# assert reqs == reqs2
#
# if task.has_training_docs():
# arr = list(islice(task.training_docs(), limit))
# arr2 = list(islice(task2.training_docs(), limit))
#
# assert arr == arr2
#
# reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
# reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
#
# assert reqs == reqs2
#
#
# @pytest.mark.parametrize("taskname,task_class", tasks.TASK_REGISTRY.items())
# def test_documents_and_requests(taskname, task_class):
# print("Evaluating task", taskname)
# task = task_class()
# fns = []
# if task.has_training_docs():
# fns.append(task.training_docs)
# if task.has_validation_docs():
# fns.append(task.validation_docs)
# # test doc might not have labels
# # if task.has_test_docs(): fns.append(task.test_docs)
#
# for fn in fns:
# # print(list(islice(fn(), 10)))
# for doc in islice(fn(), 10):
#
# txt = task.doc_to_text(doc)
# tgt = task.doc_to_target(doc)
#
# assert isinstance(txt, str)
# assert isinstance(tgt, str)
#
# # space convention
# # allow txt to have length 0 for perplexity-like tasks since the model tacks an <|endoftext|> on
# if len(txt) != 0:
# assert txt[-1] != " "
# assert tgt[0] == " " or txt[-1] == "\n"
#
# reqs = task.construct_requests(doc, txt)
#
# # construct_requests can return just one request
# if not isinstance(reqs, (list, tuple)):
# reqs = [reqs]
#
# # todo: mock lm after refactoring evaluator.py to not be a mess
# # for req in reqs:
# # assert isinstance(req, base.Request)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment