Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
742b5df2
Commit
742b5df2
authored
Feb 11, 2021
by
Anthony DiPofi
Browse files
add headqa and mathqa datasets
parent
f7992789
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
141 additions
and
1 deletion
+141
-1
lm_eval/tasks/__init__.py
lm_eval/tasks/__init__.py
+4
-1
lm_eval/tasks/headqa.py
lm_eval/tasks/headqa.py
+64
-0
lm_eval/tasks/mathqa.py
lm_eval/tasks/mathqa.py
+73
-0
No files found.
lm_eval/tasks/__init__.py
View file @
742b5df2
...
@@ -21,7 +21,8 @@ from . import pubmedqa
...
@@ -21,7 +21,8 @@ from . import pubmedqa
from
.
import
sciq
from
.
import
sciq
from
.
import
webqs
from
.
import
webqs
from
.
import
qa4mre
from
.
import
qa4mre
from
.
import
headqa
from
.
import
mathqa
TASK_REGISTRY
=
{
TASK_REGISTRY
=
{
# GLUE
# GLUE
...
@@ -67,6 +68,8 @@ TASK_REGISTRY = {
...
@@ -67,6 +68,8 @@ TASK_REGISTRY = {
# "squad": squad.SQuAD, # not implemented yet
# "squad": squad.SQuAD, # not implemented yet
"race"
:
race
.
RACE
,
"race"
:
race
.
RACE
,
# "naturalqs": naturalqs.NaturalQs, # not implemented yet
# "naturalqs": naturalqs.NaturalQs, # not implemented yet
"headqa"
:
headqa
.
HeadQA
,
"mathqa"
:
mathqa
.
MathQA
,
"webqs"
:
webqs
.
WebQs
,
"webqs"
:
webqs
.
WebQs
,
"wsc273"
:
wsc273
.
WinogradSchemaChallenge273
,
"wsc273"
:
wsc273
.
WinogradSchemaChallenge273
,
"winogrande"
:
winogrande
.
Winogrande
,
"winogrande"
:
winogrande
.
Winogrande
,
...
...
lm_eval/tasks/headqa.py
0 → 100644
View file @
742b5df2
from
.
common
import
HFTask
from
lm_eval.base
import
mean
,
rf
class HeadQA(HFTask):
    """Multiple-choice QA task over the HuggingFace `head_qa` dataset.

    One loglikelihood request is issued per (prefix-deduplicated) answer
    option; the question is scored correct if the model greedily produces
    any of the options.
    """

    DATASET_PATH = "head_qa"
    DATASET_NAME = None

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return True

    def fewshot_description(self):
        # TODO: figure out description
        return ""

    def doc_to_text(self, doc):
        return "Q: " + doc['qtext'] + '\nA:'

    def doc_to_target(self, doc):
        # this picks one answer to be the "correct" one, despite sometimes
        # multiple correct answers being possible.
        # TODO: make sure we're actually handling multi-answer correctly
        return " " + doc['answers'][0]['atext']

    def _remove_prefixes(self, aliases):
        """Return *aliases* with any entry dropped that has a strict prefix
        elsewhere in the list.

        Optimization: if the prefix is accepted by is_greedy we can stop
        looking, so longer strings sharing that prefix are redundant.
        Fix: use sorted() rather than list.sort() so the caller's list is
        not mutated in place.
        """
        ordered = sorted(aliases)
        ret = [ordered[0]]
        for alias in ordered[1:]:
            if not alias.startswith(ret[-1]):
                ret.append(alias)
        return ret

    def construct_requests(self, doc, ctx):
        ret = []
        atexts = [x['atext'] for x in doc['answers']]
        for alias in self._remove_prefixes(atexts):
            # Only the is_greedy flag is kept; the loglikelihood value is
            # discarded.
            _, is_prediction = rf.loglikelihood(ctx, " " + alias)
            ret.append(is_prediction)
        return ret

    def process_results(self, doc, results):
        # Correct if the model greedily completed any acceptable answer.
        return {"acc": float(any(results))}

    def aggregation(self):
        return {"acc": mean}

    def higher_is_better(self):
        return {"acc": True}
\ No newline at end of file
lm_eval/tasks/mathqa.py
0 → 100644
View file @
742b5df2
from
.
common
import
HFTask
from
lm_eval.base
import
mean
,
rf
class MathQA(HFTask):
    """Multiple-choice QA task over the HuggingFace `math_qa` dataset.

    One loglikelihood request is issued per option letter; the option with
    the highest loglikelihood is compared against the document's 'correct'
    field.
    """

    DATASET_PATH = "math_qa"
    DATASET_NAME = None

    # The dataset's 'correct' field holds one of these option letters.
    # Fix: class-level constant instead of being assigned on `self` inside
    # construct_requests(), so process_results() no longer raises
    # AttributeError when it runs on an instance that never executed
    # construct_requests() (e.g. in a separate worker process). Still
    # reachable as self.answer_options, so behavior is backward-compatible.
    answer_options = ['a', 'b', 'c', 'd', 'e']

    def has_training_docs(self):
        return True

    def has_validation_docs(self):
        return True

    def has_test_docs(self):
        return True

    def fewshot_description(self):
        # TODO: figure out description
        return ""

    def doc_to_text(self, doc):
        return "Q: " + doc['Problem'] + '\nA:'

    def doc_to_target(self, doc):
        # this picks one answer to be the "correct" one, despite sometimes
        # multiple correct answers being possible.
        # TODO: make sure we're actually handling multi-answer correctly
        return " " + doc['correct']

    def _remove_prefixes(self, aliases):
        """Return *aliases* with any entry dropped that has a strict prefix
        elsewhere in the list.

        NOTE(review): appears unused in this class (copied from HeadQA) —
        kept for compatibility. Fix: use sorted() rather than list.sort()
        so the caller's list is not mutated in place.
        """
        ordered = sorted(aliases)
        ret = [ordered[0]]
        for alias in ordered[1:]:
            if not alias.startswith(ret[-1]):
                ret.append(alias)
        return ret

    def construct_requests(self, doc, ctx):
        ret = []
        for option in self.answer_options:
            # Only the loglikelihood value is kept; is_greedy is discarded.
            ll, _ = rf.loglikelihood(ctx, ' ' + option)
            ret.append(ll)
        return ret

    def process_results(self, doc, results):
        # Index of the option with the highest loglikelihood.
        max_result_idx = max(enumerate(results), key=lambda x: x[1])[0]
        result = 1.0 if doc['correct'] == self.answer_options[max_result_idx] else 0.0
        return {"acc": result}

    def aggregation(self):
        return {"acc": mean}

    def higher_is_better(self):
        return {"acc": True}
\ No newline at end of file
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment