Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
0b8cb8b9
Commit
0b8cb8b9
authored
Apr 26, 2022
by
Tian Yun
Browse files
Merge with master branch
parents
96ea7ddc
b2838b8d
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
88 additions
and
22 deletions
+88
-22
lm_eval/base.py
lm_eval/base.py
+79
-20
lm_eval/evaluator.py
lm_eval/evaluator.py
+4
-0
lm_eval/tasks/coqa.py
lm_eval/tasks/coqa.py
+1
-1
lm_eval/tasks/drop.py
lm_eval/tasks/drop.py
+1
-1
lm_eval/tasks/glue.py
lm_eval/tasks/glue.py
+3
-0
No files found.
lm_eval/base.py
View file @
0b8cb8b9
...
...
@@ -14,6 +14,7 @@ from tqdm import tqdm
import
torch
import
torch.nn.functional
as
F
from
lm_eval
import
metrics
from
lm_eval.metrics
import
mean
,
weighted_perplexity
,
weighted_mean
,
bits_per_byte
from
lm_eval
import
utils
from
abc
import
abstractmethod
...
...
@@ -637,12 +638,28 @@ class Task(abc.ABC):
class
PromptSourceTask
(
Task
):
"""These are the metrics from promptsource that we have
added default behavior for. If you want to add default behavior for a new metric,
update the functions below. If you want to use one of the following metrics,
*and* add additional custom processing, override `process_results`, `higher_is_better`, and `aggregation`.
WARNING: ROUGE is WIP.
"""
CONFIGURED_PS_METRICS
=
set
([
"Accuracy"
,
"BLEU"
,
"ROUGE"
])
def
__init__
(
self
,
data_dir
=
None
,
cache_dir
=
None
,
download_mode
=
None
,
prompt
=
None
):
super
().
__init__
(
data_dir
,
cache_dir
,
download_mode
)
self
.
prompt
=
prompt
def
eos_token
(
self
):
raise
NotImplementedError
()
def
stopping_criteria
(
self
):
"""Denote where the generation should end.
For example, for coqa, this is '
\n
Q:' and for drop '.'.
By default, its None, meaning to generate up to max or EOT, whichever comes first.
"""
return
None
def
is_generation_task
(
self
):
return
(
...
...
@@ -650,11 +667,28 @@ class PromptSourceTask(Task):
or
"ROUGE"
in
self
.
prompt
.
metadata
.
metrics
)
def
doc_to_target
(
self
,
doc
):
def
invalid_doc_for_prompt
(
self
,
doc
)
->
bool
:
"""Some prompts may not work for some documents."""
if
(
# generate_paraphrase for mrpc
# This generation prompt assumes a positive example. We filter out the negative examples.
# https://github.com/bigscience-workshop/promptsource/blob/ba8c9eccbe82f2409208c655896f1dd131171ece/promptsource/templates/glue/mrpc/templates.yaml#L7
# https://github.com/bigscience-workshop/promptsource/blob/ba8c9eccbe82f2409208c655896f1dd131171ece/promptsource/templates/glue/mrpc/templates.yaml#L88
(
self
.
prompt
.
id
==
"3b88d2c4-0aeb-4c6d-9ccc-653a388250a5"
or
self
.
prompt
.
id
==
"d830d7a5-abc0-4275-ac62-974e0088876f"
)
and
doc
[
"label"
]
==
0
):
return
True
return
False
def
doc_to_target
(
self
,
doc
)
->
str
:
"""NOTE: In the future, this may return Union[str, List[str]]."""
_
,
target
=
self
.
prompt
.
apply
(
doc
)
return
f
"
{
target
}
"
def
doc_to_text
(
self
,
doc
):
def
doc_to_text
(
self
,
doc
)
->
str
:
text
,
_
=
self
.
prompt
.
apply
(
doc
)
return
text
...
...
@@ -684,7 +718,7 @@ class PromptSourceTask(Task):
_requests
.
append
(
ll_answer_choice
)
else
:
# TODO(Albert): What is the stop symbol? Is it model specific?
cont_request
=
rf
.
greedy_until
(
ctx
,
[
self
.
eos_token
()])
cont_request
=
rf
.
greedy_until
(
ctx
,
[
self
.
stopping_criteria
()])
_requests
.
append
(
cont_request
)
return
_requests
...
...
@@ -699,9 +733,6 @@ class PromptSourceTask(Task):
:param results:
The results of the requests created in construct_requests.
"""
# raise NotImplementedError(
# "Implement process results using the `prompt.metadata.metrics`. See below."
# )
target
=
self
.
doc_to_target
(
doc
).
strip
()
answer_choices_list
=
self
.
prompt
.
get_answer_choices_list
(
doc
)
if
answer_choices_list
:
...
...
@@ -710,29 +741,57 @@ class PromptSourceTask(Task):
),
f
"We expect this to be a ranked choice task; double check please."
pred
=
answer_choices_list
[
np
.
argmax
(
results
)]
out
=
{}
if
"Accuracy"
in
self
.
prompt
.
metadata
.
metrics
:
out
[
"acc"
]
=
pred
==
target
for
metric
in
self
.
prompt
.
metadata
.
metrics
:
assert
(
metric
in
self
.
CONFIGURED_PS_METRICS
),
"Unexpected metric. Add it, or use a task-specific solution."
if
metric
==
"Accuracy"
:
out
[
"acc"
]
=
pred
==
target
# TODO: Add metrics here.
return
out
else
:
raise
NotImplementedError
(
"Generation is not implemented yet."
)
# NOTE: In the future, target may be a list, not a string.
pred
=
results
[
0
].
strip
()
out
=
{}
for
metric
in
self
.
prompt
.
metadata
.
metrics
:
assert
(
metric
in
self
.
CONFIGURED_PS_METRICS
),
"Unexpected metric. Add it, or use a task-specific solution."
if
metric
==
"BLEU"
:
out
[
"bleu"
]
=
(
target
,
pred
)
if
metric
==
"ROUGE"
:
print
(
"WARNING: Skipping Rouge."
)
# Map metric name to HF metric.
# TODO(Albert): What is Other?
# metric_names = prompt.metadata.metrics
return
out
def
higher_is_better
(
self
):
out
=
{}
if
"Accuracy"
in
self
.
prompt
.
metadata
.
metrics
:
out
[
"acc"
]
=
True
for
metric
in
self
.
prompt
.
metadata
.
metrics
:
assert
(
metric
in
self
.
CONFIGURED_PS_METRICS
),
"Unexpected metric. Add it, or use a task-specific solution."
if
metric
==
"Accuracy"
:
out
[
"acc"
]
=
True
if
metric
==
"BLEU"
:
out
[
"bleu"
]
=
True
if
metric
==
"ROUGE"
:
print
(
"WARNING: Skipping Rouge."
)
return
out
def
aggregation
(
self
):
out
=
{}
if
"Accuracy"
in
self
.
prompt
.
metadata
.
metrics
:
out
[
"acc"
]
=
mean
for
metric
in
self
.
prompt
.
metadata
.
metrics
:
assert
(
metric
in
self
.
CONFIGURED_PS_METRICS
),
"Unexpected metric. Add it, or use a task-specific solution."
if
metric
==
"Accuracy"
:
out
[
"acc"
]
=
mean
if
metric
==
"BLEU"
:
out
[
"bleu"
]
=
metrics
.
bleu
if
metric
==
"ROUGE"
:
print
(
"WARNING: Skipping Rouge."
)
return
out
...
...
lm_eval/evaluator.py
View file @
0b8cb8b9
...
...
@@ -2,6 +2,7 @@ import collections
import
itertools
import
pathlib
import
random
import
lm_eval.metrics
import
lm_eval.models
import
lm_eval.tasks
...
...
@@ -199,6 +200,9 @@ def evaluate(
)
for
doc_id
,
doc
in
enumerate
(
itertools
.
islice
(
task_docs
,
0
,
limit
)):
if
task
.
invalid_doc_for_prompt
(
doc
):
continue
docs
[(
task_prompt_name
,
doc_id
)]
=
doc
ctx
=
task
.
fewshot_context
(
doc
=
doc
,
num_fewshot
=
num_fewshot
,
rnd
=
rnd
,
description
=
description
...
...
lm_eval/tasks/coqa.py
View file @
0b8cb8b9
...
...
@@ -90,7 +90,7 @@ class CoQA(PromptSourceTask):
"f1"
:
f1_sum
/
max
(
1
,
len
(
gold_list
)),
}
def
eos_token
(
self
):
def
stopping_criteria
(
self
):
return
"
\n
Q:"
# def construct_requests(self, doc, ctx):
...
...
lm_eval/tasks/drop.py
View file @
0b8cb8b9
...
...
@@ -92,7 +92,7 @@ class DROP(PromptSourceTask):
# """
# conts = [rf.greedy_until(ctx, ["."])]
# return conts
def
eos_token
(
self
):
def
stopping_criteria
(
self
):
return
"."
def
process_results
(
self
,
doc
,
results
):
...
...
lm_eval/tasks/glue.py
View file @
0b8cb8b9
...
...
@@ -236,6 +236,9 @@ class MRPC(PromptSourceTask):
def
has_test_docs
(
self
):
return
False
def
stopping_criteria
(
self
):
return
"
\n
"
def
training_docs
(
self
):
if
self
.
_training_docs
is
None
:
self
.
_training_docs
=
list
(
self
.
dataset
[
"train"
])
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment