Unverified Commit 269d3683 authored by Leo Gao's avatar Leo Gao Committed by GitHub
Browse files

Merge branch 'master' into webqs

parents 34eb121f a1a4a32e
import os import os
import json import json
import random import random
from lm_eval.base import Dataset, mean, rf from lm_eval.base import Task, mean, rf
from ..utils import sh from ..utils import sh
class TriviaQA(Dataset): class TriviaQA(Task):
def download(self): def download(self):
if not os.path.exists('data/triviaqa'): if not os.path.exists('data/triviaqa'):
sh(""" sh("""
...@@ -21,7 +21,7 @@ class TriviaQA(Dataset): ...@@ -21,7 +21,7 @@ class TriviaQA(Dataset):
return True return True
def has_test_docs(self): def has_test_docs(self):
return True return False
def training_docs(self): def training_docs(self):
return json.load(open('data/triviaqa/triviaqa-unfiltered/unfiltered-web-train.json'))['Data'] return json.load(open('data/triviaqa/triviaqa-unfiltered/unfiltered-web-train.json'))['Data']
...@@ -74,4 +74,4 @@ class TriviaQA(Dataset): ...@@ -74,4 +74,4 @@ class TriviaQA(Dataset):
def higher_is_better(self): def higher_is_better(self):
return { return {
"acc": True "acc": True
} }
\ No newline at end of file
...@@ -14,9 +14,11 @@ class WikiText103(NLP_TASK): ...@@ -14,9 +14,11 @@ class WikiText103(NLP_TASK):
def doc_to_text(self, doc): def doc_to_text(self, doc):
# TODO: implement # TODO: implement
pass
def doc_to_target(self, doc): def doc_to_target(self, doc):
# TODO: implement # TODO: implement
pass
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of """ Uses RequestFactory to construct Requests and returns an iterable of
...@@ -74,9 +76,11 @@ class WikiText2(NLP_TASK): ...@@ -74,9 +76,11 @@ class WikiText2(NLP_TASK):
def doc_to_text(self, doc): def doc_to_text(self, doc):
# TODO: implement # TODO: implement
pass
def doc_to_target(self, doc): def doc_to_target(self, doc):
# TODO: implement # TODO: implement
pass
def construct_requests(self, doc, ctx): def construct_requests(self, doc, ctx):
""" Uses RequestFactory to construct Requests and returns an iterable of """ Uses RequestFactory to construct Requests and returns an iterable of
...@@ -121,4 +125,4 @@ class WikiText2(NLP_TASK): ...@@ -121,4 +125,4 @@ class WikiText2(NLP_TASK):
whether a higher value of the submetric is better whether a higher value of the submetric is better
""" """
# TODO: implement evaluation. # TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented') raise NotImplementedError('Evaluation not implemented')
\ No newline at end of file
...@@ -90,4 +90,4 @@ class Winogrande(HFTask): ...@@ -90,4 +90,4 @@ class Winogrande(HFTask):
whether a higher value of the submetric is better whether a higher value of the submetric is better
""" """
# TODO: implement evaluation. # TODO: implement evaluation.
raise NotImplementedError('Evaluation not implemented') raise NotImplementedError('Evaluation not implemented')
\ No newline at end of file
import json import json
import random import random
import os import os
from lm_eval.base import Dataset from lm_eval.base import Task
from ..utils import sh from ..utils import sh
class WinogradSchemaChallenge273(Dataset): class WinogradSchemaChallenge273(Task):
def __init__(self): def __init__(self):
super().__init__() super().__init__()
......
...@@ -28,4 +28,4 @@ def simple_parse_args_string(args_string): ...@@ -28,4 +28,4 @@ def simple_parse_args_string(args_string):
def join_iters(iters): def join_iters(iters):
for iter in iters: for iter in iters:
yield from iter yield from iter
\ No newline at end of file
...@@ -5,11 +5,13 @@ from tqdm import tqdm ...@@ -5,11 +5,13 @@ from tqdm import tqdm
import json import json
class ExitCodeError(Exception): pass class ExitCodeError(Exception):
pass
def sh(x): def sh(x):
if os.system(x): raise ExitCodeError() if os.system(x):
raise ExitCodeError()
def ls(x): def ls(x):
return [x + '/' + fn for fn in os.listdir(x)] return [x + '/' + fn for fn in os.listdir(x)]
...@@ -64,7 +66,8 @@ class join: ...@@ -64,7 +66,8 @@ class join:
self.sep = sep self.sep = sep
def __rrshift__(self, other): def __rrshift__(self, other):
if other is None: return if other is None:
return
try: try:
return self.sep.join(other) return self.sep.join(other)
except: except:
...@@ -156,4 +159,4 @@ def comp(*fs): ...@@ -156,4 +159,4 @@ def comp(*fs):
return _f return _f
X = Reflective() X = Reflective()
\ No newline at end of file
import setuptools

# Use the project README verbatim as the PyPI long description.
with open("README.md", "r", encoding="utf-8") as readme:
    long_description = readme.read()

setuptools.setup(
    name="lm_eval_harness",
    version="0.0.1",
    author="Leo Gao",
    author_email="lg@eleuther.ai",
    description="A framework for evaluating autoregressive language models",
    long_description=long_description,
    long_description_content_type="text/markdown",
    url="https://github.com/EleutherAI/lm-evaluation-harness",
    # Automatically pick up every package under the repository root.
    packages=setuptools.find_packages(),
    classifiers=[
        "Programming Language :: Python :: 3",
        "License :: OSI Approved :: MIT License",
        "Operating System :: OS Independent",
    ],
    python_requires='>=3.6',
)
import lm_eval.tasks as tasks
import lm_eval.base as base
from unittest.mock import MagicMock
from itertools import islice
import pytest
@pytest.mark.parametrize("taskname,Task", tasks.TASK_REGISTRY.items())
def test_basic_interface(taskname, Task):
    """Smoke-test the metadata surface of every registered task.

    Instantiates the task with ``download`` mocked out (so no network or
    disk work happens) and checks that the doc-availability flags,
    ``aggregation`` and ``higher_is_better`` return sane, mutually
    consistent values.
    """
    print('Evaluating task', taskname)

    # Patch out download so construction is cheap. Restore the original
    # even if Task() raises — otherwise the MagicMock would leak into
    # every later test that touches this class.
    dl = Task.download
    Task.download = MagicMock()
    try:
        task = Task()
    finally:
        Task.download = dl

    # The availability flags must be real booleans, not truthy junk.
    assert task.has_training_docs() in [True, False]
    assert task.has_validation_docs() in [True, False]
    assert task.has_test_docs() in [True, False]

    assert isinstance(task.aggregation(), dict)
    assert isinstance(task.higher_is_better(), dict)
    # Both dicts must describe exactly the same set of metrics, and each
    # metric must declare an explicit direction.
    assert task.aggregation().keys() == task.higher_is_better().keys()
    for v in task.higher_is_better().values():
        assert v in [True, False]
@pytest.mark.parametrize("taskname,Task", tasks.TASK_REGISTRY.items())
def test_documents_and_requests(taskname, Task):
    """For a sample of docs in each available split, check that the doc
    renders to text/target strings and that construct_requests yields
    Request objects.
    """
    print('Evaluating task', taskname)
    task = Task()

    doc_sources = []
    if task.has_training_docs():
        doc_sources.append(task.training_docs)
    if task.has_validation_docs():
        doc_sources.append(task.validation_docs)
    # Test docs might not have labels, so they are not exercised here.
    # if task.has_test_docs(): doc_sources.append(task.test_docs)

    for get_docs in doc_sources:
        # Only sample the first few docs — this is an interface check,
        # not a full-dataset pass.
        for doc in islice(get_docs(), 10):
            text = task.doc_to_text(doc)
            target = task.doc_to_target(doc)

            assert isinstance(text, str)
            assert isinstance(target, str)

            requests = task.construct_requests(doc, text)
            # TODO: mock an LM by plugging in what's currently in main.py.
            for request in requests:
                assert isinstance(request, base.Request)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment