Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
822fcc6f
Commit
822fcc6f
authored
Jan 28, 2021
by
Leo Gao
Browse files
Implement LAMBADA
parent
60a6fd8c
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
36 additions
and
53 deletions
+36
-53
lm_eval/tasks/__init__.py
lm_eval/tasks/__init__.py
+3
-0
lm_eval/tasks/lambada.py
lm_eval/tasks/lambada.py
+33
-53
No files found.
lm_eval/tasks/__init__.py
View file @
822fcc6f
...
...
@@ -13,6 +13,7 @@ from . import squad
from
.
import
naturalqs
from
.
import
sat
from
.
import
arithmetic
from
.
import
lambada
TASK_REGISTRY
=
{
# GLUE
...
...
@@ -37,6 +38,8 @@ TASK_REGISTRY = {
# Order by benchmark/genre?
"lambada"
:
lambada
.
LAMBADA
,
# "arc_easy": arc.ARCEasy, # not implemented yet
# "arc_challenge": arc.ARCChallenge, # not implemented yet
# "quac": quac.QuAC, # not implemented yet
...
...
lm_eval/tasks/lambada.py
View file @
822fcc6f
from
lm_eval.base
import
Dataset
from
lm_eval.base
import
Dataset
,
rf
,
mean
from
lm_eval.utils
import
sh
import
json
import
requests
import
ftfy
import
math
from
best_download
import
download_file
class Lambada(Dataset):
    """LAMBADA word-prediction task wrapper.

    NOTE(review): only the constructor is visible here; ``download`` is
    defined elsewhere in the original file.
    """

    def __init__(self):
        # Ensure the dataset files are present before the task is used.
        self.download()
class LAMBADA(Dataset):
    """LAMBADA last-word-prediction task (test split only)."""

    def download(self):
        """Fetch the LAMBADA test split into ``data/lambada/``.

        Downloads the JSON Lines file from the public GPT-2 data bucket and
        verifies it against the expected checksum via ``best_download``
        (presumably SHA-256 given its length — confirm against best_download docs).
        """
        # Create the target directory first (no-op if it already exists).
        sh("mkdir -p data/lambada")
        download_file(
            "https://storage.googleapis.com/gpt-2/data/lambada_test.jsonl",
            "data/lambada/lambada_test.jsonl",
            "4aa8d02cd17c719165fc8a7887fddd641f43fcafa4b1c806ca8abc31fabdb226",
        )
def has_training_docs(self):
    """LAMBADA as used here is evaluation-only: no training split."""
    return False
...
...
@@ -32,61 +31,42 @@ class Lambada(Dataset):
def validation_docs(self):
    """No validation split is provided; returns None."""
    return None
def load_doc(self, myjson):
    """Return the parsed JSON payload as a plain list of documents."""
    # list(...) over an iterable is equivalent to the identity comprehension.
    return list(myjson)
def test_docs(self):
    """Yield one parsed JSON document per line of the downloaded test split."""
    with open("data/lambada/lambada_test.jsonl") as fh:
        for raw_line in fh:
            yield json.loads(raw_line)
def doc_to_text(self, doc):
    """Return the passage minus its final whitespace-separated word."""
    pieces = doc['text'].rsplit(' ', 1)
    return pieces[0]
def doc_to_target(self, doc):
    """Return the final word of the passage, prefixed with a single space."""
    pieces = doc['text'].rsplit(' ', 1)
    return " " + pieces[1]
def fewshot_description(self):
    """Natural-language task description used for few-shot prompting."""
    # TODO: figure out description
    return ""
def construct_requests(self, doc, ctx):
    """Uses RequestFactory to construct Requests and returns an iterable of
    Requests which will be sent to the LM.

    :param doc:
        The document as returned from training_docs, validation_docs, or test_docs.
    :param ctx: str
        The context string, generated by fewshot_context. This includes the natural
        language description, as well as the few shot examples, and the question
        part of the document for `doc`.
    """
    # BUG FIX: score the target continuation against the prepared context
    # string `ctx`, not the raw document dict `doc`. `ctx` was previously
    # unused, and loglikelihood expects (context, continuation) strings.
    ll, is_greedy = rf.loglikelihood(ctx, self.doc_to_target(doc))
    return ll, is_greedy
def process_results(self, doc, results):
    """Take a single document and the LM results and evaluates, returning a
    dict where keys are the names of submetrics and values are the values of
    the metric for that one document

    :param doc:
        The document as returned from training_docs, validation_docs, or test_docs.
    :param results:
        The results of the requests created in construct_requests.
    """
    log_likelihood, greedy_match = results
    scores = {}
    # More negative log-likelihood -> higher (worse) perplexity.
    scores['perplexity'] = math.exp(-log_likelihood)
    scores['accuracy'] = int(greedy_match)
    return scores
def aggregation(self):
    """
    :returns: {str: [float] -> float}
        A dictionary where keys are the names of submetrics and values are
        functions that aggregate a list of metrics
    """
    # Both submetrics are simply averaged across documents.
    return dict(perplexity=mean, accuracy=mean)
def higher_is_better(self):
    """
    :returns: {str: bool}
        A dictionary where keys are the names of submetrics and values are
        whether a higher value of the submetric is better
    """
    # Perplexity should decrease; accuracy should increase.
    return dict(perplexity=False, accuracy=True)
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment