Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
bc5495d2
Unverified
Commit
bc5495d2
authored
Feb 02, 2021
by
Leo Gao
Committed by
GitHub
Feb 02, 2021
Browse files
Merge branch 'master' into wsc273-evaluation
parents
b5e86d3f
e12d0078
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
38 additions
and
42 deletions
+38
-42
lm_eval/tasks/__init__.py
lm_eval/tasks/__init__.py
+2
-1
lm_eval/tasks/webqs.py
lm_eval/tasks/webqs.py
+27
-38
main.py
main.py
+9
-3
No files found.
lm_eval/tasks/__init__.py
View file @
bc5495d2
...
@@ -17,6 +17,7 @@ from . import lambada
...
@@ -17,6 +17,7 @@ from . import lambada
from
.
import
race
from
.
import
race
from
.
import
piqa
from
.
import
piqa
from
.
import
triviaqa
from
.
import
triviaqa
from
.
import
webqs
TASK_REGISTRY
=
{
TASK_REGISTRY
=
{
...
@@ -55,7 +56,7 @@ TASK_REGISTRY = {
...
@@ -55,7 +56,7 @@ TASK_REGISTRY = {
# "squad": squad.SQuAD, # not implemented yet
# "squad": squad.SQuAD, # not implemented yet
"race"
:
race
.
RACE
,
"race"
:
race
.
RACE
,
# "naturalqs": naturalqs.NaturalQs, # not implemented yet
# "naturalqs": naturalqs.NaturalQs, # not implemented yet
#
"webqs": webqs.WebQs,
# not implemented yet
"webqs"
:
webqs
.
WebQs
,
"wsc273"
:
wsc273
.
WinogradSchemaChallenge273
,
"wsc273"
:
wsc273
.
WinogradSchemaChallenge273
,
# "winogrande": winogrande.Winogrande, # not implemented yet
# "winogrande": winogrande.Winogrande, # not implemented yet
"anli_r1"
:
anli
.
ANLIRound1
,
"anli_r1"
:
anli
.
ANLIRound1
,
...
...
lm_eval/tasks/webqs.py
View file @
bc5495d2
from
.
common
import
HFTask
from
.
common
import
HFTask
from
lm_eval.base
import
mean
,
rf
class
WebQs
(
HFTask
):
class
WebQs
(
HFTask
):
DATASET_PATH
=
"web_questions"
DATASET_PATH
=
"web_questions"
...
@@ -18,7 +19,6 @@ class WebQs(HFTask):
...
@@ -18,7 +19,6 @@ class WebQs(HFTask):
return
""
return
""
def
doc_to_text
(
self
,
doc
):
def
doc_to_text
(
self
,
doc
):
print
(
doc
)
return
"Q: "
+
doc
[
'question'
]
+
'
\n
A:'
return
"Q: "
+
doc
[
'question'
]
+
'
\n
A:'
def
doc_to_target
(
self
,
doc
):
def
doc_to_target
(
self
,
doc
):
...
@@ -26,48 +26,37 @@ class WebQs(HFTask):
...
@@ -26,48 +26,37 @@ class WebQs(HFTask):
# multiple correct answers being possible.
# multiple correct answers being possible.
# TODO: make sure we're actually handling multi-answer correctly
# TODO: make sure we're actually handling multi-answer correctly
return
" "
+
doc
[
'answers'
][
0
]
return
" "
+
doc
[
'answers'
][
0
]
def
_remove_prefixes
(
self
,
aliases
):
# Optimization: Remove any alias that has a strict prefix elsewhere in the list
# we can do this because if the prefix is acceptable by isgreedy, we can stop looking
aliases
.
sort
()
ret
=
[
aliases
[
0
]]
for
alias
in
aliases
[
1
:]:
if
not
alias
.
startswith
(
ret
[
-
1
]):
ret
.
append
(
alias
)
return
ret
def
construct_requests
(
self
,
doc
,
ctx
):
def
construct_requests
(
self
,
doc
,
ctx
):
""" Uses RequestFactory to construct Requests and returns an iterable of
ret
=
[]
Requests which will be sent to the LM.
for
alias
in
self
.
_remove_prefixes
(
doc
[
'answers'
]):
_
,
is_prediction
=
rf
.
loglikelihood
(
ctx
,
" "
+
alias
)
ret
.
append
(
is_prediction
)
return
ret
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param ctx: str
The context string, generated by fewshot_context. This includes the natural
language description, as well as the few shot examples, and the question
part of the document for `doc`.
"""
# TODO: implement evaluation.
raise
NotImplementedError
(
'Evaluation not implemented'
)
def
process_results
(
self
,
doc
,
results
):
def
process_results
(
self
,
doc
,
results
):
"""Take a single document and the LM results and evaluates, returning a
return
{
dict where keys are the names of submetrics and values are the values of
"acc"
:
float
(
any
(
results
))
the metric for that one document
}
:param doc:
The document as returned from training_docs, validation_docs, or test_docs.
:param results:
The results of the requests created in construct_requests.
"""
# TODO: implement evaluation.
raise
NotImplementedError
(
'Evaluation not implemented'
)
def
aggregation
(
self
):
def
aggregation
(
self
):
"""
return
{
:returns: {str: [float] -> float}
"acc"
:
mean
,
A dictionary where keys are the names of submetrics and values are
}
functions that aggregate a list of metrics
"""
# TODO: implement evaluation.
raise
NotImplementedError
(
'Evaluation not implemented'
)
def
higher_is_better
(
self
):
def
higher_is_better
(
self
):
"""
return
{
:returns: {str: bool}
"acc"
:
True
A dictionary where keys are the names of submetrics and values are
}
whether a higher value of the submetric is better
\ No newline at end of file
"""
# TODO: implement evaluation.
raise
NotImplementedError
(
'Evaluation not implemented'
)
main.py
View file @
bc5495d2
...
@@ -32,8 +32,7 @@ def main():
...
@@ -32,8 +32,7 @@ def main():
task_names
=
args
.
tasks
.
split
(
","
)
task_names
=
args
.
tasks
.
split
(
","
)
task_dict
=
tasks
.
get_task_dict
(
task_names
)
task_dict
=
tasks
.
get_task_dict
(
task_names
)
# TODO: fall back to test docs
task_dict_items
=
[(
name
,
task
)
for
name
,
task
in
task_dict
.
items
()
if
(
task
.
has_validation_docs
()
or
task
.
has_test_docs
())]
task_dict_items
=
[(
name
,
task
)
for
name
,
task
in
task_dict
.
items
()
if
task
.
has_validation_docs
()]
results
=
collections
.
defaultdict
(
dict
)
results
=
collections
.
defaultdict
(
dict
)
...
@@ -50,7 +49,14 @@ def main():
...
@@ -50,7 +49,14 @@ def main():
# get lists of each type of requeste
# get lists of each type of requeste
for
task_name
,
task
in
task_dict_items
:
for
task_name
,
task
in
task_dict_items
:
for
doc_id
,
doc
in
enumerate
(
itertools
.
islice
(
task
.
validation_docs
(),
0
,
args
.
limit
)):
#default to validation doc, fall back to test doc if validation unavailable
# TODO: the val-fallback-to-test system isn't final, we should revisit it at some point
if
task
.
has_validation_docs
():
task_doc_func
=
task
.
validation_docs
elif
task
.
has_test_docs
():
task_doc_func
=
task
.
test_docs
for
doc_id
,
doc
in
enumerate
(
itertools
.
islice
(
task_doc_func
(),
0
,
args
.
limit
)):
docs
[(
task_name
,
doc_id
)]
=
doc
docs
[(
task_name
,
doc_id
)]
=
doc
ctx
=
task
.
fewshot_context
(
ctx
=
task
.
fewshot_context
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment