gaoqiong / lm-evaluation-harness

Commit d4c00093, authored Apr 26, 2022 by cjlovering
Added default behavior for BLEU to the PromptSourceTask class
Parent: f39c27c2

1 changed file with 51 additions and 9 deletions:

lm_eval/base.py  (+51 / -9)
@@ -14,6 +14,7 @@ from tqdm import tqdm
 import torch
 import torch.nn.functional as F
+from lm_eval import metrics
 from lm_eval.metrics import mean, weighted_perplexity, weighted_mean, bits_per_byte
 from lm_eval import utils
 from abc import abstractmethod
@@ -637,6 +638,16 @@ class Task(abc.ABC):
 class PromptSourceTask(Task):
+    """These are the metrics from promptsource that we have
+    added default behavior for. If you want to add default behavior for a new metric,
+    update the functions below. If you want to use one of the following metrics,
+    *and* add additional custom processing, override `process_results`, `higher_is_better`, and `aggregation`.
+    WARNING: ROUGE is WIP.
+    """
+
+    CONFIGURED_PS_METRICS = set(["Accuracy", "BLEU", "ROUGE"])
+
     def __init__(self, data_dir=None, cache_dir=None, download_mode=None, prompt=None):
         super().__init__(data_dir, cache_dir, download_mode)
         self.prompt = prompt
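The docstring above says that a task which uses one of the configured metrics but needs additional custom processing should override `process_results`, `higher_is_better`, and `aggregation`. The following is a minimal sketch of that override pattern; the subclass name `MyGenerationTask` and the "exact_match" key are hypothetical and are not part of this commit.

# Hypothetical example only: MyGenerationTask and the "exact_match" key
# illustrate the override pattern described in the docstring; they are not
# code from this commit.
from lm_eval.base import PromptSourceTask
from lm_eval.metrics import mean


class MyGenerationTask(PromptSourceTask):
    def process_results(self, doc, results):
        # Start from the default per-example metrics provided by PromptSourceTask...
        out = super().process_results(doc, results)
        # ...then layer on task-specific processing. doc_to_target is the
        # standard Task helper for recovering the gold answer.
        target = self.doc_to_target(doc).strip()
        pred = results[0].strip()
        out["exact_match"] = float(pred == target)
        return out

    def higher_is_better(self):
        out = super().higher_is_better()
        out["exact_match"] = True
        return out

    def aggregation(self):
        out = super().aggregation()
        out["exact_match"] = mean
        return out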
@@ -737,29 +748,60 @@ class PromptSourceTask(Task):
             ), f"We expect this to be a ranked choice task; double check please."
             pred = answer_choices_list[np.argmax(results)]
             out = {}
-            if "Accuracy" in self.prompt.metadata.metrics:
-                out["acc"] = pred == target
+            for metric in self.prompt.metadata.metrics:
+                assert (
+                    metric in self.CONFIGURED_PS_METRICS
+                ), "Unexpected metric. Add it, or use a task-specific solution."
+                if metric == "Accuracy":
+                    out["acc"] = pred == target
             # TODO: Add metrics here.
             return out
         else:
-            raise NotImplementedError("Generation is not implemented yet.")
+            # NOTE: In the future, target may be a list, not a string.
+            pred = results[0].strip()
+            out = {}
+            for metric in self.prompt.metadata.metrics:
+                assert (
+                    metric in self.CONFIGURED_PS_METRICS
+                ), "Unexpected metric. Add it, or use a task-specific solution."
+                if metric == "BLEU":
+                    out["bleu"] = (target, pred)
+                if metric == "ROUGE":
+                    print("WARNING: Skipping Rouge.")
+            return out
         # Map metric name to HF metric.
         # TODO(Albert): What is Other?
         # metric_names = prompt.metadata.metrics

     def higher_is_better(self):
         out = {}
-        if "Accuracy" in self.prompt.metadata.metrics:
-            out["acc"] = True
+        for metric in self.prompt.metadata.metrics:
+            assert (
+                metric in self.CONFIGURED_PS_METRICS
+            ), "Unexpected metric. Add it, or use a task-specific solution."
+            if metric == "Accuracy":
+                out["acc"] = True
+            if metric == "BLEU":
+                out["bleu"] = True
+            if metric == "ROUGE":
+                print("WARNING: Skipping Rouge.")
         return out

     def aggregation(self):
         out = {}
-        if "Accuracy" in self.prompt.metadata.metrics:
-            out["acc"] = mean
+        for metric in self.prompt.metadata.metrics:
+            assert (
+                metric in self.CONFIGURED_PS_METRICS
+            ), "Unexpected metric. Add it, or use a task-specific solution."
+            if metric == "Accuracy":
+                out["acc"] = mean
+            if metric == "BLEU":
+                out["bleu"] = metrics.bleu
+            if metric == "ROUGE":
+                print("WARNING: Skipping Rouge.")
         return out
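For BLEU, `process_results` stores a `(target, pred)` pair per example, and `aggregation` maps the "bleu" key to `metrics.bleu`, so the corpus-level score is computed once over all collected pairs. A rough sketch of such an aggregator is below, assuming sacrebleu and a hypothetical helper name; the actual `lm_eval.metrics.bleu` may differ in details such as tokenization.

# Sketch of a corpus-level BLEU aggregator over the per-example (target, pred)
# tuples emitted under the "bleu" key. corpus_bleu_over_pairs is a hypothetical
# name that approximates what lm_eval.metrics.bleu is expected to do; it is not
# the actual implementation.
import sacrebleu


def corpus_bleu_over_pairs(items):
    # items: list of (target, pred) tuples collected across all documents.
    targets, preds = zip(*items)
    # sacrebleu takes the hypotheses plus a list of reference streams.
    return sacrebleu.corpus_bleu(list(preds), [list(targets)]).score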