Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
1b35f6b9
Commit
1b35f6b9
authored
Feb 02, 2021
by
jeffhsu3
Browse files
pubmedqa
parent
5b6182d5
Changes
4
Hide whitespace changes
Inline
Side-by-side
Showing
4 changed files
with
87 additions
and
8 deletions
+87
-8
lm_eval/tasks/__init__.py
lm_eval/tasks/__init__.py
+3
-0
lm_eval/tasks/pubmedqa.py
lm_eval/tasks/pubmedqa.py
+63
-0
main.py
main.py
+18
-5
write_out.py
write_out.py
+3
-3
No files found.
lm_eval/tasks/__init__.py
View file @
1b35f6b9
...
...
@@ -17,6 +17,7 @@ from . import lambada
from
.
import
race
from
.
import
piqa
from
.
import
triviaqa
from
.
import
pubmedqa
TASK_REGISTRY
=
{
...
...
@@ -45,6 +46,8 @@ TASK_REGISTRY = {
"lambada"
:
lambada
.
LAMBADA
,
"piqa"
:
piqa
.
PiQA
,
"pubmedqa"
:
pubmedqa
.
Pubmed_QA
,
#"triviaqa": triviaqa.TriviaQA,
# "arc_easy": arc.ARCEasy, # not implemented yet
# "arc_challenge": arc.ARCChallenge, # not implemented yet
...
...
lm_eval/tasks/pubmedqa.py
0 → 100644
View file @
1b35f6b9
"""
"""
import
numpy
as
np
from
..utils
import
sh
from
.
common
import
HFTask
,
yesno
from
lm_eval.base
import
Dataset
,
rf
,
mean
class
Pubmed_QA
(
HFTask
):
DATASET_PATH
=
"pubmed_qa"
DATASET_NAME
=
"pqa_labeled"
def
has_training_docs
(
self
):
return
True
def
has_test_docs
(
self
):
return
False
def
has_validation_docs
(
self
):
return
False
def
fewshot_description
(
self
):
# Average ctx length in labelled dataset is 238.9
return
""
def
doc_to_text
(
self
,
doc
):
ctxs
=
"
\n
"
.
join
(
doc
[
'context'
][
'contexts'
])
return
"abstract: {}
\n
question: {}
\n
answer:"
.
format
(
ctxs
,
doc
[
'question'
],
doc
[
'final_decision'
]
)
def
doc_to_target
(
self
,
doc
):
return
" {}"
.
format
(
doc
[
'final_decision'
])
def
construct_requests
(
self
,
doc
,
ctx
):
""" Uses RequestFactory to construct Requests and returns
an iterable of Requests which will be sent to the LM.
"""
ll_yes
,
_
=
rf
.
loglikelihood
(
ctx
,
" yes"
)
ll_no
,
_
=
rf
.
loglikelihood
(
ctx
,
" no"
)
ll_maybe
,
_
=
rf
.
loglikelihood
(
ctx
,
" maybe"
)
return
ll_yes
,
ll_no
,
ll_maybe
def
process_results
(
self
,
doc
,
results
):
gold
=
doc
[
'final_decision'
]
ll_yes
,
ll_no
,
ll_maybe
=
results
pred
=
np
.
argmax
(
results
)
return
{
'acc'
:
[
"yes"
,
"no"
,
"maybe"
][
pred
]
==
gold
,
}
def
aggregation
(
self
):
return
{
'acc'
:
mean
}
def
higher_is_better
(
self
):
return
{
'acc'
:
True
}
main.py
View file @
1b35f6b9
...
...
@@ -32,9 +32,15 @@ def main():
task_names
=
args
.
tasks
.
split
(
","
)
task_dict
=
tasks
.
get_task_dict
(
task_names
)
# TODO: fall back to test docs
task_dict_items
=
[(
name
,
task
)
for
name
,
task
in
task_dict
.
items
()
if
task
.
has_validation_docs
()]
task_dict_items
=
[]
for
name
,
task
in
task_dict
.
items
():
if
task
.
has_validation_docs
():
task_dict_items
.
append
((
name
,
task
,
'validation'
))
elif
task
.
has_test_docs
():
task_dict_items
.
append
((
name
,
task
,
'test'
))
elif
task
.
has_training_docs
():
task_dict_items
.
append
((
name
,
task
,
'training'
))
results
=
collections
.
defaultdict
(
dict
)
requests
=
collections
.
defaultdict
(
list
)
...
...
@@ -49,8 +55,15 @@ def main():
docs
=
{}
# get lists of each type of requeste
for
task_name
,
task
in
task_dict_items
:
for
doc_id
,
doc
in
enumerate
(
itertools
.
islice
(
task
.
validation_docs
(),
0
,
args
.
limit
)):
for
task_name
,
task
,
dset
in
task_dict_items
:
if
dset
==
'training'
:
temp
=
task
.
training_docs
()
elif
dset
==
'test'
:
temp
=
task
.
test_docs
()
else
:
temp
=
task
.
validation_docs
()
for
doc_id
,
doc
in
enumerate
(
itertools
.
islice
(
temp
,
0
,
args
.
limit
)):
docs
[(
task_name
,
doc_id
)]
=
doc
ctx
=
task
.
fewshot_context
(
...
...
write_out.py
View file @
1b35f6b9
...
...
@@ -37,14 +37,14 @@ def main():
iters
=
[]
for
set
in
args
.
sets
.
split
(
","
):
if
set
==
'train'
and
task
.
has_train_docs
():
docs
=
task
.
train_docs
()
if
set
==
'train'
and
task
.
has_train
ing
_docs
():
docs
=
task
.
train
ing
_docs
()
if
set
==
'val'
and
task
.
has_validation_docs
():
docs
=
task
.
validation_docs
()
if
set
==
'test'
and
task
.
has_test_docs
():
docs
=
task
.
test_docs
()
iters
.
append
(
docs
)
docs
=
join_iters
(
iters
)
with
open
(
os
.
path
.
join
(
args
.
output_base_path
,
task_name
),
"w"
)
as
f
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment