Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
269d3683
Unverified
Commit
269d3683
authored
Feb 02, 2021
by
Leo Gao
Committed by
GitHub
Feb 02, 2021
Browse files
Merge branch 'master' into webqs
parents
34eb121f
a1a4a32e
Changes
28
Hide whitespace changes
Inline
Side-by-side
Showing
8 changed files
with
93 additions
and
13 deletions
+93
-13
lm_eval/tasks/triviaqa.py
lm_eval/tasks/triviaqa.py
+4
-4
lm_eval/tasks/wikitext.py
lm_eval/tasks/wikitext.py
+5
-1
lm_eval/tasks/winogrande.py
lm_eval/tasks/winogrande.py
+1
-1
lm_eval/tasks/wsc273.py
lm_eval/tasks/wsc273.py
+2
-2
lm_eval/utils.py
lm_eval/utils.py
+1
-1
lm_eval/utils_stream.py
lm_eval/utils_stream.py
+7
-4
setup.py
setup.py
+22
-0
tests/test_all_sanitycheck.py
tests/test_all_sanitycheck.py
+51
-0
No files found.
lm_eval/tasks/triviaqa.py
View file @
269d3683
import
os
import
json
import
random
from
lm_eval.base
import
Dataset
,
mean
,
rf
from
lm_eval.base
import
Task
,
mean
,
rf
from
..utils
import
sh
class
TriviaQA
(
Dataset
):
class
TriviaQA
(
Task
):
def
download
(
self
):
if
not
os
.
path
.
exists
(
'data/triviaqa'
):
sh
(
"""
...
...
@@ -21,7 +21,7 @@ class TriviaQA(Dataset):
return
True
def
has_test_docs
(
self
):
return
Tru
e
return
Fals
e
def
training_docs
(
self
):
return
json
.
load
(
open
(
'data/triviaqa/triviaqa-unfiltered/unfiltered-web-train.json'
))[
'Data'
]
...
...
@@ -74,4 +74,4 @@ class TriviaQA(Dataset):
def
higher_is_better
(
self
):
return
{
"acc"
:
True
}
\ No newline at end of file
}
lm_eval/tasks/wikitext.py
View file @
269d3683
...
...
@@ -14,9 +14,11 @@ class WikiText103(NLP_TASK):
def
doc_to_text
(
self
,
doc
):
# TODO: implement
pass
def
doc_to_target
(
self
,
doc
):
# TODO: implement
pass
def
construct_requests
(
self
,
doc
,
ctx
):
""" Uses RequestFactory to construct Requests and returns an iterable of
...
...
@@ -74,9 +76,11 @@ class WikiText2(NLP_TASK):
def
doc_to_text
(
self
,
doc
):
# TODO: implement
pass
def
doc_to_target
(
self
,
doc
):
# TODO: implement
pass
def
construct_requests
(
self
,
doc
,
ctx
):
""" Uses RequestFactory to construct Requests and returns an iterable of
...
...
@@ -121,4 +125,4 @@ class WikiText2(NLP_TASK):
whether a higher value of the submetric is better
"""
# TODO: implement evaluation.
raise
NotImplementedError
(
'Evaluation not implemented'
)
\ No newline at end of file
raise
NotImplementedError
(
'Evaluation not implemented'
)
lm_eval/tasks/winogrande.py
View file @
269d3683
...
...
@@ -90,4 +90,4 @@ class Winogrande(HFTask):
whether a higher value of the submetric is better
"""
# TODO: implement evaluation.
raise
NotImplementedError
(
'Evaluation not implemented'
)
\ No newline at end of file
raise
NotImplementedError
(
'Evaluation not implemented'
)
lm_eval/tasks/wsc273.py
View file @
269d3683
import
json
import
random
import
os
from
lm_eval.base
import
Dataset
from
lm_eval.base
import
Task
from
..utils
import
sh
class
WinogradSchemaChallenge273
(
Dataset
):
class
WinogradSchemaChallenge273
(
Task
):
def
__init__
(
self
):
super
().
__init__
()
...
...
lm_eval/utils.py
View file @
269d3683
...
...
@@ -28,4 +28,4 @@ def simple_parse_args_string(args_string):
def join_iters(iters):
    """Yield every item from each iterable in ``iters``, in order.

    Args:
        iters: An iterable whose elements are themselves iterables.

    Yields:
        The items of the first inner iterable, then those of the second,
        and so on. Nothing is yielded when ``iters`` is empty.
    """
    # The loop variable is deliberately not called ``iter`` so the builtin
    # of that name is not shadowed inside this function.
    for it in iters:
        yield from it
lm_eval/utils_stream.py
View file @
269d3683
...
...
@@ -5,11 +5,13 @@ from tqdm import tqdm
import
json
class ExitCodeError(Exception):
    """Raised by ``sh`` when a shell command reports a non-zero status."""


def sh(x):
    """Run the command string ``x`` with ``os.system``.

    Args:
        x: The command line to execute.

    Raises:
        ExitCodeError: If ``os.system`` returns a non-zero status. The
            message carries the status and the command so the failure can
            be diagnosed from the traceback alone.
    """
    # NOTE(review): os.system hands ``x`` to the system shell, so ``x`` must
    # never be built from untrusted input.
    status = os.system(x)
    if status:
        raise ExitCodeError(
            "command exited with status {}: {}".format(status, x)
        )
def
ls
(
x
):
return
[
x
+
'/'
+
fn
for
fn
in
os
.
listdir
(
x
)]
...
...
@@ -64,7 +66,8 @@ class join:
self
.
sep
=
sep
def
__rrshift__
(
self
,
other
):
if
other
is
None
:
return
if
other
is
None
:
return
try
:
return
self
.
sep
.
join
(
other
)
except
:
...
...
@@ -156,4 +159,4 @@ def comp(*fs):
return
_f
X
=
Reflective
()
\ No newline at end of file
X
=
Reflective
()
setup.py
0 → 100644
View file @
269d3683
import setuptools

# The README doubles as the long description of the package.
with open("README.md", "r", encoding="utf-8") as readme_file:
    readme_text = readme_file.read()

# Trove classifiers attached to the package metadata.
trove_classifiers = [
    "Programming Language :: Python :: 3",
    "License :: OSI Approved :: MIT License",
    "Operating System :: OS Independent",
]

setuptools.setup(
    name="lm_eval_harness",
    version="0.0.1",
    author="Leo Gao",
    author_email="lg@eleuther.ai",
    description="A framework for evaluating autoregressive language models",
    long_description=readme_text,
    long_description_content_type="text/markdown",
    url="https://github.com/EleutherAI/lm-evaluation-harness",
    packages=setuptools.find_packages(),
    classifiers=trove_classifiers,
    python_requires='>=3.6',
)
tests/test_all_sanitycheck.py
0 → 100644
View file @
269d3683
import
lm_eval.tasks
as
tasks
import
lm_eval.base
as
base
from
unittest.mock
import
MagicMock
from
itertools
import
islice
import
pytest
@pytest.mark.parametrize("taskname,Task", tasks.TASK_REGISTRY.items())
def test_basic_interface(taskname, Task):
    """Check the declarative part of a task's interface.

    Asserts that the ``has_*_docs`` methods return values equal to True or
    False, that ``aggregation()`` and ``higher_is_better()`` both return
    dicts with the same keys, and that every ``higher_is_better`` value is
    equal to True or False.

    Args:
        taskname: Registry key of the task, used only for the log line.
        Task: The task class registered under ``taskname``.
    """
    print('Evaluating task', taskname)

    # Swap ``download`` for a mock while the task is constructed, presumably
    # so that instantiation does not fetch the real data (confirm against
    # the Task constructor). The original attribute is put back in
    # ``finally`` so that a constructor failure cannot leave the class
    # mocked for the remaining parametrized cases.
    dl = Task.download
    Task.download = MagicMock()
    try:
        task = Task()
    finally:
        Task.download = dl

    assert task.has_training_docs() in [True, False]
    assert task.has_validation_docs() in [True, False]
    assert task.has_test_docs() in [True, False]

    assert isinstance(task.aggregation(), dict)
    assert isinstance(task.higher_is_better(), dict)
    # Every aggregated metric needs a direction, and vice versa.
    assert task.aggregation().keys() == task.higher_is_better().keys()

    for v in task.higher_is_better().values():
        assert v in [True, False]
@pytest.mark.parametrize("taskname,Task", tasks.TASK_REGISTRY.items())
def test_documents_and_requests(taskname, Task):
    """Check documents, prompts, targets and requests of a task.

    For the first 10 training and validation documents (when the task
    reports having them), asserts that ``doc_to_text`` and
    ``doc_to_target`` return strings and that every item produced by
    ``construct_requests`` is a ``base.Request``.

    Args:
        taskname: Registry key of the task, used only for the log line.
        Task: The task class registered under ``taskname``.
    """
    print('Evaluating task', taskname)
    task = Task()

    # Test docs are left out on purpose: they might not have labels.
    doc_getters = []
    if task.has_training_docs():
        doc_getters.append(task.training_docs)
    if task.has_validation_docs():
        doc_getters.append(task.validation_docs)

    for get_docs in doc_getters:
        # Only a small prefix of each split is inspected.
        for doc in islice(get_docs(), 10):
            txt = task.doc_to_text(doc)
            tgt = task.doc_to_target(doc)

            assert isinstance(txt, str)
            assert isinstance(tgt, str)

            # TODO: mock the LM by plugging in what is currently in main.py.
            for req in task.construct_requests(doc, txt):
                assert isinstance(req, base.Request)
\ No newline at end of file
Prev
1
2
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment