Commit 442f843c authored by Leo Gao's avatar Leo Gao
Browse files

Add tests for {val,test}_docs determinism

parent 77d4b087
...@@ -22,6 +22,33 @@ def test_basic_interface(taskname, Task): ...@@ -22,6 +22,33 @@ def test_basic_interface(taskname, Task):
for v in task.higher_is_better().values(): assert v in [True, False] for v in task.higher_is_better().values(): assert v in [True, False]
# test deterministic docs
# (don't test train because it's slow)
task2 = Task()
if task.has_validation_docs():
arr = list(islice(task.validation_docs(), 100))
arr2 = list(islice(task2.validation_docs(), 100))
assert arr == arr2
reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
assert reqs == reqs2
if task.has_test_docs():
arr = list(islice(task.test_docs(), 100))
arr2 = list(islice(task2.test_docs(), 100))
assert arr == arr2
reqs = [task.construct_requests(doc, task.doc_to_text(doc)) for doc in arr]
reqs2 = [task2.construct_requests(doc, task2.doc_to_text(doc)) for doc in arr2]
assert reqs == reqs2
@pytest.mark.parametrize("taskname,Task", tasks.TASK_REGISTRY.items()) @pytest.mark.parametrize("taskname,Task", tasks.TASK_REGISTRY.items())
def test_documents_and_requests(taskname, Task): def test_documents_and_requests(taskname, Task):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment