gaoqiong / lm-evaluation-harness · Commits · 9484eecc

Commit 9484eecc authored Apr 25, 2022 by jon-tow

Fix coqa

Parent: 7d282b5f

Showing 4 changed files with 30 additions and 24 deletions (+30, -24)
    lm_eval/base.py              +5   -3
    lm_eval/evaluator.py         +21  -18
    lm_eval/tasks/arithmetic.py  +3   -2
    lm_eval/tasks/coqa.py        +1   -1
lm_eval/base.py  View file @ 9484eecc

 import abc
 from typing import Iterable
+import promptsource
 import numpy as np
 import random
 import re
...
@@ -639,11 +640,12 @@ class PromptSourceTask(Task):
         self.prompt = prompt

     def doc_to_target(self, doc):
-        _, target = prompt.apply(doc)
+        _, target = self.prompt.apply(doc)
         return f"{target}"

     def doc_to_text(self, doc):
-        text, _ = prompt.apply(doc)
+        print(doc)
+        text, _ = self.prompt.apply(doc)
         return text

     def construct_requests(self, doc, ctx):
...
@@ -660,7 +662,7 @@ class PromptSourceTask(Task):
         _requests = []
         if self.prompt.metadata.choices_in_prompt:
-            for answer_choice in prompt.get_fixed_answer_choices_list():
+            for answer_choice in self.prompt.get_fixed_answer_choices_list():
                 ll_answer_choice, _ = rf.loglikelihood(ctx, f"{answer_choice}")
                 _requests.append(ll_answer_choice)
         else:
...
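The change in this file makes PromptSourceTask read the promptsource template from the instance attribute set in __init__ (self.prompt) instead of a bare prompt name, which is not defined inside these methods. A minimal sketch of the pattern, with a toy stand-in for the template class (nothing below beyond the self.prompt usage comes from this diff):

class ToyTemplate:
    # Stand-in for a promptsource-style template: apply() returns (text, target) for a doc.
    def apply(self, doc):
        return [f"Question: {doc['question']}\nAnswer:", f" {doc['answer']}"]


class ToyPromptedTask:
    def __init__(self, prompt):
        self.prompt = prompt  # the template is only reachable through the instance

    def doc_to_text(self, doc):
        # A bare `prompt.apply(doc)` here would raise NameError; use the attribute.
        text, _ = self.prompt.apply(doc)
        return text

    def doc_to_target(self, doc):
        _, target = self.prompt.apply(doc)
        return f"{target}"


task = ToyPromptedTask(ToyTemplate())
doc = {"question": "What is 2 + 2?", "answer": "4"}
print(task.doc_to_text(doc))    # "Question: What is 2 + 2?\nAnswer:"
print(task.doc_to_target(doc))  # " 4"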
lm_eval/evaluator.py  View file @ 9484eecc

...
@@ -169,8 +169,10 @@ def evaluate(
     docs = {}

     # get lists of each type of request
-    for task_name, task in task_dict_items:
-        versions[task_name] = task.VERSION
+    for task_prompt_name, task in task_dict_items:
+        print(f"TASK PROMPT NAME: {task_prompt_name}")
+        versions[task_prompt_name] = task.VERSION

         # default to test doc, fall back to val doc if validation unavailable
         # TODO: the test-fallback-to-val system isn't final, we should revisit it at some point
         if task.has_test_docs():
...
@@ -187,13 +189,13 @@ def evaluate(
         rnd.shuffle(task_docs)

         description = (
-            description_dict[task_name]
-            if description_dict and task_name in description_dict
+            description_dict[task_prompt_name]
+            if description_dict and task_prompt_name in description_dict
             else ""
         )

         for doc_id, doc in enumerate(itertools.islice(task_docs, 0, limit)):
-            docs[(task_name, doc_id)] = doc
+            docs[(task_prompt_name, doc_id)] = doc
             ctx = task.fewshot_context(
                 doc=doc, num_fewshot=num_fewshot, rnd=rnd, description=description
             )
...
@@ -204,7 +206,7 @@ def evaluate(
                 requests[req.request_type].append(req)
                 # i: index in requests for a single task instance
                 # doc_id: unique id that we can get back to a doc using `docs`
-                requests_origin[req.request_type].append((i, task_name, doc, doc_id))
+                requests_origin[req.request_type].append((i, task_prompt_name, doc, doc_id))

     # all responses for each (task, doc)
     process_res_queue = collections.defaultdict(list)
...
@@ -222,32 +224,33 @@ def evaluate(
             x if req.index is None else x[req.index] for x, req in zip(resps, reqs)
         ]

-        for resp, (i, task_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
-            process_res_queue[(task_name, doc_id)].append((i, resp))
+        for resp, (i, task_prompt_name, doc, doc_id) in zip(resps, requests_origin[reqtype]):
+            process_res_queue[(task_prompt_name, doc_id)].append((i, resp))

     vals = collections.defaultdict(list)

     # unpack results and sort back in order and return control to Task
-    for (task_name, doc_id), requests in process_res_queue.items():
+    for (task_prompt_name, doc_id), requests in process_res_queue.items():
         requests.sort(key=lambda x: x[0])
         requests = [x[1] for x in requests]

-        task = task_dict[task_name]
-        doc = docs[(task_name, doc_id)]
+        task = task_dict[task_prompt_name]
+        doc = docs[(task_prompt_name, doc_id)]

         metrics = task.process_results(doc, requests)
         for metric, value in metrics.items():
-            vals[(task_name, metric)].append(value)
-            task_name, prompt_name = task_name.split("+")
-            results[task_name]["task_name"] = task_name
-            results[task_name]["prompt_name"] = prompt_name
+            vals[(task_prompt_name, metric)].append(value)

     # aggregate results
-    for (task_name, metric), items in vals.items():
+    for (task_prompt_name, metric), items in vals.items():
+        task_name, prompt_name = task_prompt_name.split("+")
+        results[task_prompt_name]["task_name"] = task_name
+        results[task_prompt_name]["prompt_name"] = prompt_name
         task = task_dict[task_name]

-        results[task_name][metric] = task.aggregation()[metric](items)
+        results[task_prompt_name][metric] = task.aggregation()[metric](items)

         # hotfix: bleu, chrf, ter seem to be really expensive to bootstrap
         # so we run them less iterations. still looking for a cleaner way to do this
...
@@ -258,7 +261,7 @@ def evaluate(
             else bootstrap_iters,
         )

         if stderr is not None:
-            results[task_name][metric + "_stderr"] = stderr(items)
+            results[task_prompt_name][metric + "_stderr"] = stderr(items)

     return {"results": dict(results), "versions": dict(versions)}
...
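For context: after this change, task_dict and the bookkeeping above are keyed by a combined string of the form "<task>+<prompt>", which the aggregation loop splits back into its two parts. A small illustration of that split, using a made-up key, prompt name, and metric (none of these values come from the harness):

import collections

results = collections.defaultdict(dict)

# Hypothetical combined key following the "<task>+<prompt>" convention in the diff.
vals = {("coqa+generate_answer", "f1"): [0.61, 0.58, 0.64]}  # toy per-document scores

for (task_prompt_name, metric), items in vals.items():
    task_name, prompt_name = task_prompt_name.split("+")
    results[task_prompt_name]["task_name"] = task_name        # "coqa"
    results[task_prompt_name]["prompt_name"] = prompt_name    # "generate_answer"
    results[task_prompt_name][metric] = sum(items) / len(items)  # stand-in for task.aggregation()[metric]

print(dict(results))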
lm_eval/tasks/arithmetic.py  View file @ 9484eecc

...
@@ -58,10 +58,11 @@ class Arithmetic(Task):
     def construct_requests(self, doc, ctx):
         ll, is_prediction = rf.loglikelihood(ctx, doc["completion"])
-        return is_prediction
+        return ll, is_prediction

     def process_results(self, doc, results):
-        is_prediction, = results
+        print(results)
+        results = results
         return {
             "acc": is_prediction
         }
...
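Here construct_requests now hands back both values produced by rf.loglikelihood, so process_results receives a (log-likelihood, is-greedy) pair rather than just the greedy-match flag (the work-in-progress body above still references is_prediction without unpacking it). A toy sketch of consuming such a pair; the numbers and the unpacking shown are illustrative, not part of this commit:

def process_results_sketch(doc, results):
    # Hypothetical unpacking of a resolved loglikelihood request: (ll, is_greedy).
    ll, is_prediction = results
    return {"acc": is_prediction}

fake_result = (-1.37, True)  # invented (log-likelihood, greedy-match) pair
print(process_results_sketch({"context": "2 + 2 =", "completion": " 4"}, fake_result))  # {'acc': True}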
lm_eval/tasks/coqa.py  View file @ 9484eecc

...
@@ -12,7 +12,7 @@ Homepage: https://stanfordnlp.github.io/coqa/
 import inspect
 import transformers.data.metrics.squad_metrics as squad_metrics

 import lm_eval.datasets.coqa.coqa
-from lm_eval.base import PromptSourceTask, rf, mean
+from lm_eval.base import PromptSourceTask, Task, rf, mean
 from itertools import zip_longest
...