Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
97174266
Commit
97174266
authored
Feb 13, 2021
by
Jonathan Tow
Browse files
Refactor `HeadQA` as a `MultipleChoiceTask`
parent
d194d65c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
28 additions
and
44 deletions
+28
-44
lm_eval/tasks/headqa.py
lm_eval/tasks/headqa.py
+28
-44
No files found.
lm_eval/tasks/headqa.py
View file @
97174266
from
.
common
import
HFTask
from
lm_eval.base
import
mean
,
rf
from
lm_eval.base
import
MultipleChoiceTask
class
HeadQA
(
HFTask
):
class
HeadQA
(
HFTask
,
MultipleChoiceTask
):
DATASET_PATH
=
"head_qa"
DATASET_NAME
=
None
...
...
@@ -14,51 +15,34 @@ class HeadQA(HFTask):
def
has_test_docs
(
self
):
return
True
def
fewshot_description
(
self
):
# TODO: figure out description
return
""
def
doc_to_text
(
self
,
doc
):
return
"Question: "
+
doc
[
'qtext'
]
+
'
\n
Answer:'
def
doc_to_target
(
self
,
doc
):
# this picks one answer to be the "correct" one, despite sometimes
# multiple correct answers being possible.
# TODO: make sure we're actually handling multi-answer correctly
return
" "
+
doc
[
'answers'
][
0
][
'atext'
]
def
_remove_prefixes
(
self
,
aliases
):
# Optimization: Remove any alias that has a strict prefix elsewhere in the list
# we can do this because if the prefix is acceptable by isgreedy, we can stop looking
aliases
.
sort
()
ret
=
[
aliases
[
0
]]
for
alias
in
aliases
[
1
:]:
if
not
alias
.
startswith
(
ret
[
-
1
]):
ret
.
append
(
alias
)
def
_convert_standard
(
self
,
doc
):
out_doc
=
{
"id"
:
doc
[
"qid"
],
"query"
:
"Question: "
+
doc
[
"qtext"
]
+
"
\n
Answer:"
,
"choices"
:
[
answer
[
"atext"
]
for
answer
in
doc
[
"answers"
]],
"gold"
:
int
(
doc
[
"ra"
])
-
1
,
}
return
out_doc
return
ret
def
_load_docs
(
self
,
docs
):
for
doc
in
docs
:
yield
self
.
_convert_standard
(
doc
)
def
construct_requests
(
self
,
doc
,
ctx
):
def
training_docs
(
self
):
docs
=
super
().
training_docs
()
return
self
.
_load_docs
(
docs
)
ret
=
[]
atexts
=
[
x
[
'atext'
]
for
x
in
doc
[
'answers'
]]
for
alias
in
self
.
_remove_prefixes
(
atexts
):
_
,
is_prediction
=
rf
.
loglikelihood
(
ctx
,
" "
+
alias
)
ret
.
append
(
is_prediction
)
return
ret
def
validation_docs
(
self
):
docs
=
super
().
validation_docs
()
return
self
.
_load_docs
(
docs
)
def
process_results
(
self
,
doc
,
results
):
return
{
"acc"
:
float
(
any
(
results
))
}
def
test_docs
(
self
):
docs
=
super
().
test_docs
()
return
self
.
_load_docs
(
docs
)
def
aggregation
(
self
):
return
{
"acc"
:
mean
,
}
def
fewshot_description
(
self
):
# TODO: figure out description
return
""
def
higher_is_better
(
self
):
return
{
"acc"
:
True
}
\ No newline at end of file
def
doc_to_text
(
self
,
doc
):
return
doc
[
"query"
]
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment