gaoqiong / lm-evaluation-harness · Commits

Commit 2d61b3ce (unverified), authored Jan 29, 2021 by Stella Biderman, committed by GitHub Jan 29, 2021

Merge pull request #103 from EleutherAI/piqa

Implement PiQA

Parents: a2f5b74b, 63854c10
Showing 3 changed files with 35 additions and 57 deletions (+35 −57):
lm_eval/base.py            +0  −1
lm_eval/tasks/__init__.py  +2  −0
lm_eval/tasks/piqa.py      +33 −56
lm_eval/base.py @ 2d61b3ce

@@ -59,7 +59,6 @@ class LM(abc.ABC):
 class Dataset(abc.ABC):
-    @abc.abstractmethod
     def __init__(self):
         self.download()
         self._traindocs = None
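The change above removes the @abc.abstractmethod decorator from Dataset.__init__, so task classes now inherit a concrete constructor that calls download() for them; this is what lets piqa.py below drop its own __init__. A minimal sketch of the effect (MyTask is a hypothetical subclass, for illustration only):

    import abc

    class Dataset(abc.ABC):
        def __init__(self):            # no longer abstract after this commit
            self.download()
            self._traindocs = None

    class MyTask(Dataset):             # hypothetical task
        def download(self):
            print("fetching data")     # stand-in for the real download logic

    task = MyTask()                    # prints "fetching data"; no __init__ override needed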
lm_eval/tasks/__init__.py @ 2d61b3ce

@@ -14,6 +14,7 @@ from . import naturalqs
 from . import sat
 from . import arithmetic
 from . import lambada
+from . import piqa

 TASK_REGISTRY = {
     # GLUE
@@ -39,6 +40,7 @@ TASK_REGISTRY = {
     # Order by benchmark/genre?
     "lambada": lambada.LAMBADA,
+    "piqa": piqa.PiQA,
     # "arc_easy": arc.ARCEasy, # not implemented yet
     # "arc_challenge": arc.ARCChallenge, # not implemented yet
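With these two added lines the new task becomes discoverable by name. A minimal usage sketch, assuming the package is importable and data/ is writable (instantiating the task runs download() through the inherited Dataset.__init__):

    from lm_eval.tasks import TASK_REGISTRY

    task_class = TASK_REGISTRY["piqa"]   # resolves to piqa.PiQA per the hunk above
    task = task_class()                  # fetches data/piqa/* on first use
    print(task.has_validation_docs())    # True in this commit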
lm_eval/tasks/piqa.py @ 2d61b3ce

 import json
 import random
-from lm_eval.base import Dataset
+from lm_eval.base import Dataset, rf, mean
 from ..utils import sh
+import os

 class PiQA(Dataset):
-    def __init__(self):
-        self.download()
-
     def download(self):
-        #pass
-        #TODO: don't download if files already there
+        if not os.path.exists('data/piqa'):
+            #TODO: use best_download
         sh("""
         mkdir -p data/piqa
         wget https://yonatanbisk.com/piqa/data/train.jsonl -O data/piqa/piqa-train.jsonl

@@ -25,11 +24,11 @@ class PiQA(Dataset):
         return True

     def has_test_docs(self):
-        return True
+        return False

     def load_docs(self, textfilename, labelfilename):
         if labelfilename != None:
-            return zip([json.loads(entry) for entry in list(open(textfilename, 'r'))], list(open(labelfilename, 'r')))
+            return zip([json.loads(entry) for entry in list(open(textfilename, 'r'))], list(map(lambda x: x.strip(), open(labelfilename, 'r'))))
         else:
             return [json.loads(entry) for entry in list(open(textfilename, 'r'))]

@@ -39,62 +38,40 @@ class PiQA(Dataset):
     def validation_docs(self):
         return self.load_docs('data/piqa/piqa-valid.jsonl', 'data/piqa/piqa-valid-labels.lst')

-    def test_docs(self):
-        return self.load_docs('data/piqa/piqa-test.jsonl', None)
+    # def test_docs(self):
+    #     return self.load_docs('data/piqa/piqa-test.jsonl', None)

     def fewshot_description(self):
         # TODO: figure out fewshot description
         return ""

     def doc_to_text(self, doc):
-        #TODO: check if oa uses newline
-        return doc[0]['goal']
+        return doc['goal'] + ' '

     def doc_to_target(self, doc):
-        rightanswer = int(doc[1][0]) + 1
-        #TODO: check if oa uses newline
-        return ''.join([doc[0]['goal'], ' ', doc[0]['sol' + str(rightanswer)]])
+        rightanswer = int(doc[1]) + 1
+        return '\n' + ''.join([doc[0]['goal'], ' ', doc[0]['sol' + str(rightanswer)]])

     def construct_requests(self, doc, ctx):
-        """ Uses RequestFactory to construct Requests and returns an iterable of
-        Requests which will be sent to the LM.
-
-        :param doc:
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param ctx: str
-            The context string, generated by fewshot_context. This includes the natural
-            language description, as well as the few shot examples, and the question
-            part of the document for `doc`.
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        ll_1, _ = rf.loglikelihood(ctx, doc[0]['sol1'])
+        ll_2, _ = rf.loglikelihood(ctx, doc[0]['sol2'])
+        return ll_1, ll_2

     def process_results(self, doc, results):
-        """Take a single document and the LM results and evaluates, returning a
-        dict where keys are the names of submetrics and values are the values of
-        the metric for that one document
-
-        :param doc:
-            The document as returned from training_docs, validation_docs, or test_docs.
-        :param results:
-            The results of the requests created in construct_requests.
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        ll_1, ll_2 = results
+        return {
+            'acc': (ll_1 > ll_2) == (int(doc[1]) == 0)
+        }

     def aggregation(self):
-        """
-        :returns: {str: [float] -> float}
-            A dictionary where keys are the names of submetrics and values are
-            functions that aggregate a list of metrics
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
+        return {
+            'acc': mean
+        }

     def higher_is_better(self):
-        """
-        :returns: {str: bool}
-            A dictionary where keys are the names of submetrics and values are
-            whether a higher value of the submetric is better
-        """
-        # TODO: implement evaluation.
-        raise NotImplementedError('Evaluation not implemented')
\ No newline at end of file
+        return {
+            'acc': True
+        }
\ No newline at end of file
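Taken together, the new methods score PIQA as a two-way log-likelihood comparison: construct_requests asks the model for the log-likelihood of each candidate solution given the context, and process_results marks the document correct when the higher-scoring solution matches the label. A self-contained sketch of that logic, using a made-up document shaped like load_docs output and a stand-in score() in place of the model-backed rf.loglikelihood:

    import json

    def score(context, continuation):
        # stand-in for rf.loglikelihood(ctx, continuation)[0];
        # any float works for demonstrating the comparison
        return -len(continuation)

    # hypothetical doc: (parsed jsonl entry, stripped label from the .lst file)
    doc = (
        json.loads('{"goal": "how do you open a jar?", '
                   '"sol1": "twist the lid.", "sol2": "smash it with a hammer."}'),
        "0",
    )

    ctx = doc[0]["goal"]                  # context, as doc_to_text would build it
    ll_1 = score(ctx, doc[0]["sol1"])     # construct_requests, request 1
    ll_2 = score(ctx, doc[0]["sol2"])     # construct_requests, request 2

    # process_results: correct iff the likelier solution is the labeled one
    print({"acc": (ll_1 > ll_2) == (int(doc[1]) == 0)})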