gaoqiong / lm-evaluation-harness · Commits

Commit 9484eecc
Authored Apr 25, 2022 by jon-tow

Fix coqa

Parent: 7d282b5f

Showing 4 changed files with 30 additions and 24 deletions (+30 −24)
    lm_eval/base.py              +5  −3
    lm_eval/evaluator.py         +21 −18
    lm_eval/tasks/arithmetic.py  +3  −2
    lm_eval/tasks/coqa.py        +1  −1
lm_eval/base.py

 import abc
 from typing import Iterable

 import promptsource
 import numpy as np
 import random
 import re
...
@@ -639,11 +640,12 @@ class PromptSourceTask(Task):
         self.prompt = prompt

     def doc_to_target(self, doc):
-        _, target = prompt.apply(doc)
+        _, target = self.prompt.apply(doc)
         return f"{target}"

     def doc_to_text(self, doc):
-        text, _ = prompt.apply(doc)
+        print(doc)
+        text, _ = self.prompt.apply(doc)
         return text

     def construct_requests(self, doc, ctx):
...
@@ -660,7 +662,7 @@ class PromptSourceTask(Task):
             _requests = []
             if self.prompt.metadata.choices_in_prompt:
-                for answer_choice in prompt.get_fixed_answer_choices_list():
+                for answer_choice in self.prompt.get_fixed_answer_choices_list():
                     ll_answer_choice, _ = rf.loglikelihood(ctx, f"{answer_choice}")
                     _requests.append(ll_answer_choice)
             else:
...
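The fix above makes PromptSourceTask call the template stored on the instance rather than an undefined module-level prompt. For context, a rough sketch of how a promptsource template is applied outside the harness; the dataset and template choice (super_glue/boolq) is only an illustrative assumption, not something this commit touches.

# Illustrative only: apply a promptsource template the way doc_to_text /
# doc_to_target do, assuming the super_glue/boolq templates are available.
from datasets import load_dataset
from promptsource.templates import DatasetTemplates

dataset = load_dataset("super_glue", "boolq", split="validation")
templates = DatasetTemplates("super_glue", "boolq")
prompt = templates[templates.all_template_names[0]]

doc = dataset[0]
# apply() renders the jinja template into an (input_text, target) pair
text, target = prompt.apply(doc)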
lm_eval/evaluator.py
...
@@ -169,8 +169,10 @@ def evaluate(
     docs = {}

     # get lists of each type of request
-    for task_name, task in task_dict_items:
-        versions[task_name] = task.VERSION
+    for task_prompt_name, task in task_dict_items:
+        print(f"TASK PROMPT NAME: {task_prompt_name}")
+        versions[task_prompt_name] = task.VERSION

         # default to test doc, fall back to val doc if validation unavailable
         # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
         if task.has_test_docs():
...
@@ -187,13 +189,13 @@ def evaluate(
         rnd.shuffle(task_docs)

         description = (
-            description_dict[task_name]
-            if description_dict and task_name in description_dict
+            description_dict[task_prompt_name]
+            if description_dict and task_prompt_name in description_dict
             else ""
         )

         for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
-            docs[(task_name, doc_id)] = doc
+            docs[(task_prompt_name, doc_id)] = doc
             ctx = task.fewshot_context(
                 doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
             )
...
@@ -204,7 +206,7 @@ def evaluate(
                 requests[req.request_type].append(req)
                 # i: index in requests for a single task instance
                 # doc_id: unique id that we can get back to a doc using `docs`
-                requests_origin[req.request_type].append((i, task_name, doc, doc_id))
+                requests_origin[req.request_type].append((i, task_prompt_name, doc, doc_id))

     # all responses for each (task, doc)
     process_res_queue = collections.defaultdict(list)
...
@@ -222,32 +224,33 @@ def evaluate(
             x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
         ]

-        for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
-            process_res_queue[(task_name, doc_id)].append((i, resp))
+        for resp, (i, task_prompt_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
+            process_res_queue[(task_prompt_name, doc_id)].append((i, resp))

     vals = collections.defaultdict(list)

     # unpack results and sort back in order and return control to Task
-    for (task_name, doc_id), requests in process_res_queue.items():
+    for (task_prompt_name, doc_id), requests in process_res_queue.items():
         requests.sort(key=lambda x: x[0])
         requests = [x[1] for x in requests]

-        task = task_dict[task_name]
-        doc = docs[(task_name, doc_id)]
+        task = task_dict[task_prompt_name]
+        doc = docs[(task_prompt_name, doc_id)]

         metrics = task.process_results(doc, requests)
         for metric, value in metrics.items():
-            vals[(task_name, metric)].append(value)
+            vals[(task_prompt_name, metric)].append(value)
-            task_name, prompt_name = task_name.split("+")
-            results[task_name]["task_name"] = task_name
-            results[task_name]["prompt_name"] = prompt_name

     # aggregate results
-    for (task_name, metric), items in vals.items():
+    for (task_prompt_name, metric), items in vals.items():
+        task_name, prompt_name = task_prompt_name.split("+")
+        results[task_prompt_name]["task_name"] = task_name
+        results[task_prompt_name]["prompt_name"] = prompt_name
         task = task_dict[task_name]
-        results[task_name][metric] = task.aggregation()[metric](items)
+        results[task_prompt_name][metric] = task.aggregation()[metric](items)
...
@@ -258,7 +261,7 @@ def evaluate(
             else bootstrap_iters,
         )

         if stderr is not None:
-            results[task_name][metric + "_stderr"] = stderr(items)
+            results[task_prompt_name][metric + "_stderr"] = stderr(items)

     return {"results": dict(results), "versions": dict(versions)}
...
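Most of this change renames the loop variable from task_name to task_prompt_name: with promptsource, each entry in task_dict is keyed by a combined "<task>+<prompt>" string, and the aggregation loop now splits that key back apart when filling in the results table. A small sketch of the convention, using a made-up key:

# Illustrative only: the "<task>+<prompt>" key convention used by the evaluator.
import collections

results = collections.defaultdict(dict)
task_prompt_name = "coqa+example_template"  # hypothetical key; real keys come from task_dict
task_name, prompt_name = task_prompt_name.split("+")
results[task_prompt_name]["task_name"] = task_name      # "coqa"
results[task_prompt_name]["prompt_name"] = prompt_name  # "example_template"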
lm_eval/tasks/arithmetic.py
...
@@ -58,10 +58,11 @@ class Arithmetic(Task):
     def construct_requests(self, doc, ctx):
         ll, is_prediction = rf.loglikelihood(ctx, doc["completion"])
-        return is_prediction
+        return ll, is_prediction

     def process_results(self, doc, results):
-        is_prediction, = results
+        print(results)
+        results = results
         return {"acc": is_prediction}
...
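For context (not part of the commit): rf.loglikelihood yields a pair of indexed requests, the log-likelihood of the continuation and a greedy-match flag, so returning both means process_results receives a two-element list instead of one. A rough sketch of the shape, assuming the usual request plumbing and with made-up numbers:

# Illustrative only: what `results` looks like inside process_results after this change.
results = [-2.37, True]      # hypothetical [log_likelihood, is_greedy] for doc["completion"]
ll, is_prediction = results  # the old code unpacked a single-element list instead
metrics = {"acc": is_prediction}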
lm_eval/tasks/coqa.py
...
@@ -12,7 +12,7 @@ Homepage: https://stanfordnlp.github.io/coqa/
 import inspect
 import transformers.data.metrics.squad_metrics as squad_metrics
 import lm_eval.datasets.coqa.coqa
-from lm_eval.base import PromptSourceTask, rf, mean
+from lm_eval.base import PromptSourceTask, Task, rf, mean
 from itertools import zip_longest
...