Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
34eb121f
"backend/vscode:/vscode.git/clone" did not exist on "58bead039892136ac16e601d37e0dd87a3a75bf3"
Commit
34eb121f
authored
Jan 31, 2021
by
Anthony DiPofi
Browse files
add webqs evaluation and fallback to test set when validation is unavailable
parent
6598967b
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
37 additions
and
42 deletions
+37
-42
lm_eval/tasks/__init__.py
lm_eval/tasks/__init__.py
+2
-1
lm_eval/tasks/webqs.py
lm_eval/tasks/webqs.py
+27
-38
main.py
main.py
+8
-3
No files found.
lm_eval/tasks/__init__.py
View file @
34eb121f
...
@@ -17,6 +17,7 @@ from . import lambada
...
@@ -17,6 +17,7 @@ from . import lambada
from
.
import
race
from
.
import
race
from
.
import
piqa
from
.
import
piqa
from
.
import
triviaqa
from
.
import
triviaqa
from
.
import
webqs
TASK_REGISTRY
=
{
TASK_REGISTRY
=
{
...
@@ -55,7 +56,7 @@ TASK_REGISTRY = {
...
@@ -55,7 +56,7 @@ TASK_REGISTRY = {
# "squad": squad.SQuAD, # not implemented yet
# "squad": squad.SQuAD, # not implemented yet
"race"
:
race
.
RACE
,
"race"
:
race
.
RACE
,
# "naturalqs": naturalqs.NaturalQs, # not implemented yet
# "naturalqs": naturalqs.NaturalQs, # not implemented yet
#
"webqs": webqs.WebQs,
# not implemented yet
"webqs"
:
webqs
.
WebQs
,
# "wsc273": wsc273.WinogradSchemaChallenge273, # not implemented yet
# "wsc273": wsc273.WinogradSchemaChallenge273, # not implemented yet
# "winogrande": winogrande.Winogrande, # not implemented yet
# "winogrande": winogrande.Winogrande, # not implemented yet
"anli_r1"
:
anli
.
ANLIRound1
,
"anli_r1"
:
anli
.
ANLIRound1
,
...
...
lm_eval/tasks/webqs.py
View file @
34eb121f
from
.
common
import
HFTask
from
.
common
import
HFTask
from
lm_eval.base
import
mean
,
rf
class
WebQs
(
HFTask
):
class
WebQs
(
HFTask
):
DATASET_PATH
=
"web_questions"
DATASET_PATH
=
"web_questions"
...
@@ -18,7 +19,6 @@ class WebQs(HFTask):
...
@@ -18,7 +19,6 @@ class WebQs(HFTask):
return
""
return
""
def doc_to_text(self, doc):
    """Format a WebQuestions example as a Q/A prompt.

    :param doc: dict with at least a 'question' key (a plain string).
    :returns: prompt string of the form "Q: <question>\nA:" that the
        model is expected to complete with the answer.
    """
    # NOTE: a stray debug `print(doc)` was removed here — it spammed
    # stdout once per document during evaluation.
    return "Q: " + doc['question'] + '\nA:'
def doc_to_target(self, doc):
    """Return the completion target: the first gold answer, space-prefixed."""
    # The dataset can list several correct answers per question; only the
    # first one is used as the target string here.
    # TODO: make sure we're actually handling multi-answer correctly
    first_answer = doc['answers'][0]
    return " " + first_answer
def
_remove_prefixes
(
self
,
aliases
):
# Optimization: Remove any alias that has a strict prefix elsewhere in the list
# we can do this because if the prefix is acceptable by isgreedy, we can stop looking
aliases
.
sort
()
ret
=
[
aliases
[
0
]]
for
alias
in
aliases
[
1
:]:
if
not
alias
.
startswith
(
ret
[
-
1
]):
ret
.
append
(
alias
)
return
ret
def construct_requests(self, doc, ctx):
    """Build one loglikelihood request per deduplicated answer alias.

    :param doc: document as returned from training/validation/test docs;
        its 'answers' list supplies the candidate completions.
    :param ctx: context string generated by fewshot_context, ending with
        the question part of `doc`.
    :returns: list of is-greedy request results, one per surviving alias.
    """
    requests = []
    for candidate in self._remove_prefixes(doc['answers']):
        # Only the is-greedy component matters for scoring; the raw
        # loglikelihood value is discarded.
        _, is_greedy = rf.loglikelihood(ctx, " " + candidate)
        requests.append(is_greedy)
    return requests
def process_results(self, doc, results):
    """Evaluate a single document against the LM results.

    :param doc: document as returned from training/validation/test docs.
    :param results: iterable of is-greedy flags, one per answer alias,
        from the requests created in construct_requests.
    :returns: dict mapping submetric name to value; the document counts
        as correct iff any alias was the greedy completion.
    """
    got_any_alias = any(results)
    return {"acc": float(got_any_alias)}
def aggregation(self):
    """Map each submetric name to the function that aggregates its
    per-document values into a single score.

    :returns: {str: [float] -> float}
    """
    return {"acc": mean}
def higher_is_better(self):
    """State the direction of each submetric.

    :returns: {str: bool} — True means larger values are better.
    """
    directions = {"acc": True}
    return directions
main.py
View file @
34eb121f
...
@@ -32,8 +32,7 @@ def main():
...
@@ -32,8 +32,7 @@ def main():
task_names
=
args
.
tasks
.
split
(
","
)
task_names
=
args
.
tasks
.
split
(
","
)
task_dict
=
tasks
.
get_task_dict
(
task_names
)
task_dict
=
tasks
.
get_task_dict
(
task_names
)
# TODO: fall back to test docs
task_dict_items
=
[(
name
,
task
)
for
name
,
task
in
task_dict
.
items
()
if
(
task
.
has_validation_docs
()
or
task
.
has_test_docs
())]
task_dict_items
=
[(
name
,
task
)
for
name
,
task
in
task_dict
.
items
()
if
task
.
has_validation_docs
()]
results
=
collections
.
defaultdict
(
dict
)
results
=
collections
.
defaultdict
(
dict
)
...
@@ -50,7 +49,13 @@ def main():
...
@@ -50,7 +49,13 @@ def main():
# get lists of each type of request
# get lists of each type of request
for
task_name
,
task
in
task_dict_items
:
for
task_name
,
task
in
task_dict_items
:
for
doc_id
,
doc
in
enumerate
(
itertools
.
islice
(
task
.
validation_docs
(),
0
,
args
.
limit
)):
#default to validation doc, fall back to test doc if validation unavailable
if
task
.
has_validation_docs
():
task_doc_func
=
task
.
validation_docs
elif
task
.
has_test_docs
():
task_doc_func
=
task
.
test_docs
for
doc_id
,
doc
in
enumerate
(
itertools
.
islice
(
task_doc_func
(),
0
,
args
.
limit
)):
docs
[(
task_name
,
doc_id
)]
=
doc
docs
[(
task_name
,
doc_id
)]
=
doc
ctx
=
task
.
fewshot_context
(
ctx
=
task
.
fewshot_context
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment