gaoqiong / lm-evaluation-harness · Commits

Commit b3591562, authored Jun 06, 2023 by lintangsutawika

metrics are now in a special folder so that registry can work better

Parent: 48344fcb
Showing 4 changed files, with 320 additions and 264 deletions:

    lm_eval/api/metrics.py           +2    -264
    lm_eval/metrics/__init__.py      +82   -0
    lm_eval/metrics/aggregation.py   +32   -0
    lm_eval/metrics/metric.py        +204  -0
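To make concrete what "registry can work better" is aiming at, here is a loose, hypothetical sketch of how a downstream metric could be registered under the new layout. The decorator names are taken from the `lm_eval.api.register` imports that appear in the new files below; their exact behavior, the metric name `my_exact_match`, and the `(gold, pred)` item format are assumptions based on the metrics defined in this commit, not part of the diff itself.

from lm_eval.api.register import (
    register_metric,
    register_higher_is_better,
    register_default_aggregation,
)


# Hypothetical downstream metric; the name "my_exact_match" is made up.
@register_default_aggregation("mean")
@register_higher_is_better(True)
@register_metric("my_exact_match")
def my_exact_match(items):
    # items are assumed to be (gold, pred) pairs, as in matthews_corrcoef below
    golds, preds = zip(*items)
    return sum(g == p for g, p in zip(golds, preds)) / len(golds)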
lm_eval/api/metrics.py
import math
from collections.abc import Iterable

import numpy as np
import sacrebleu
import sklearn.metrics
import random

import evaluate


AGGREGATION_REGISTRY = {}

METRIC_REGISTRY = {
    "acc": None,
    "acc_norm": None,
    "acc_mutual_info": None,
    "word_perplexity": None,
    "byte_perplexity": None,
}

HIGHER_IS_BETTER_REGISTRY = {
    "matthews_corrcoef": True,
    "f1_score": True,
    "perplexity": False,
    "bleu": True,
    "chrf": True,
    "ter": False,
    "acc": True,
    "acc_norm": True,
    "acc_mutual_info": True,
    "word_perplexity": False,
    "byte_perplexity": False,
    "bits_per_byte": False,
}


def register_metric(name):
    # TODO: do we want to enforce a certain interface to registered metrics?
    def decorate(fn):
        assert (
            name not in METRIC_REGISTRY
        ), f"metric named '{name}' conflicts with existing registered metric!"

        METRIC_REGISTRY[name] = fn
        return fn

    return decorate


def get_metric(name):
    try:
        return METRIC_REGISTRY[name]
    except KeyError:
        # TODO: change this print to logging?
        print(
            f"Could not find registered metric '{name}' in lm-eval, \
searching in HF Evaluate library..."
        )
        try:
            metric_object = evaluate.load(name)
            return metric_object.compute
        except Exception:
            raise Warning(
                "{} not found in the evaluate library!".format(name),
                "Please check https://huggingface.co/evaluate-metric",
            )
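# Illustrative sketch, not part of this diff: exercising the HF Evaluate fallback
# above. It assumes the optional `evaluate` package is installed and that the
# requested metric ("exact_match") exists on https://huggingface.co/evaluate-metric;
# loading it may download files.
example_compute = get_metric("exact_match")  # falls through to evaluate.load(...)
example_result = example_compute(predictions=["the cat sat"], references=["the cat sat"])
# example_result -> {'exact_match': 1.0}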
def register_aggregation(name):
    # TODO: should we enforce a specific interface to aggregation metrics?
    def decorate(fn):
        assert (
            name not in AGGREGATION_REGISTRY
        ), f"aggregation named '{name}' conflicts with existing registered aggregation!"

        AGGREGATION_REGISTRY[name] = fn
        return fn

    return decorate


def get_aggregation(name):
    try:
        return AGGREGATION_REGISTRY[name]
    except KeyError:
        raise Warning(
            "{} not a registered aggregation metric!".format(name),
        )


@register_aggregation("mean")
def mean(arr):
    return sum(arr) / len(arr)
def pop_stddev(arr):
    # was: mu = mean(arr); mean is moved out of this module in this commit
    mu = sum(arr) / len(arr)
    return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))


def sample_stddev(arr):
    # was: mu = mean(arr)
    mu = sum(arr) / len(arr)
    return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / (len(arr) - 1))
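# Illustrative check, not part of this diff: for arr = [1, 2, 3, 4] the squared
# deviations from mu = 2.5 sum to 5.0, so
#   pop_stddev    -> sqrt(5.0 / 4) ≈ 1.118   (divides by n)
#   sample_stddev -> sqrt(5.0 / 3) ≈ 1.291   (divides by n - 1, Bessel's correction)
assert abs(pop_stddev([1, 2, 3, 4]) - 1.1180) < 1e-3
assert abs(sample_stddev([1, 2, 3, 4]) - 1.2910) < 1e-3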
...

@@ -110,48 +16,6 @@ def mean_stderr(arr):
    return sample_stddev(arr) / math.sqrt(len(arr))
@register_aggregation("median")
def median(arr):
    return arr[len(arr) // 2]


@register_metric("matthews_corrcoef")
def matthews_corrcoef(items):
    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
    preds = unzipped_list[1]
    return sklearn.metrics.matthews_corrcoef(golds, preds)


@register_metric("f1_score")
def f1_score(items):
    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
    preds = unzipped_list[1]
    fscore = sklearn.metrics.f1_score(golds, preds)

    return np.max(fscore)


def acc_all(items):
    # Only count as correct if all answers are labeled correctly for each question
    question_scoring_dict = {}
    preds = list(zip(*items))[0]
    docs = list(zip(*items))[1]

    for doc, pred in zip(docs, preds):
        paragraph_id = doc["idx"]["paragraph"]
        question_id = doc["idx"]["question"]
        if (paragraph_id, question_id) not in question_scoring_dict:
            question_scoring_dict[(paragraph_id, question_id)] = []

        gold_label = doc["label"] == 1

        question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)
    acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
    return acc
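# Illustrative sketch, not part of this diff: the item layout acc_all expects,
# inferred from the field accesses above (MultiRC-style docs); values are made up.
example_items = [
    (True,  {"idx": {"paragraph": 0, "question": 0}, "label": 1}),
    (True,  {"idx": {"paragraph": 0, "question": 0}, "label": 0}),  # mismatch -> question (0, 0) scores 0
    (False, {"idx": {"paragraph": 0, "question": 1}, "label": 0}),
]
# Question (0, 0) has one wrong answer, question (0, 1) is fully correct:
assert acc_all(example_items) == 0.5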
def acc_all_stderr(items):
    # Only count as correct if all answers are labeled correctly for each question
    question_scoring_dict = {}

...

@@ -179,113 +43,6 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    return max(scores_for_ground_truths)
@register_metric("perplexity")
@register_aggregation("perplexity")
def perplexity(items):
    return math.exp(-mean(items))


def weighted_mean(items):
    a, b = zip(*items)
    return sum(a) / sum(b)


@register_metric("weighted_perplexity")
@register_aggregation("weighted_perplexity")
def weighted_perplexity(items):
    return math.exp(-weighted_mean(items))


@register_metric("bits_per_byte")
@register_aggregation("bits_per_byte")
def bits_per_byte(items):
    return -weighted_mean(items) / math.log(2)


@register_metric("bleu")
def bleu(items):
    """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
    for evaluating a generated sentence against a reference sentence. It counts matching
    n-grams in the candidate translation and in the reference text, where a
    1-gram or unigram is a single token and a bigram comparison covers each
    word pair. The comparison is made regardless of word order.
    Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
    Paper: https://www.aclweb.org/anthology/P02-1040/

    Higher is better
    """
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_bleu(preds, refs).score


@register_metric("chrf")
def chrf(items):
    """chrF++ is a tool for automatic evaluation of machine translation output
    based on character n-gram precision and recall enhanced with word n-grams.
    Source: https://github.com/m-popovic/chrF
    Paper: https://www.aclweb.org/anthology/W15-3049.pdf

    Higher is better  # TODO I think
    """
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_chrf(preds, refs).score


@register_metric("ter")
def ter(items):
    """Translation Error Rate is an error metric for machine translation that
    measures the number of edits required to change a system output into one
    of the references.
    Source: http://www.cs.umd.edu/~snover/tercom/
    Paper: http://mt-archive.info/AMTA-2006-Snover.pdf

    Lower is better
    """
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_ter(preds, refs).score


def is_non_str_iterable(obj):
    return isinstance(obj, Iterable) and not isinstance(obj, str)


def _sacreformat(refs, preds):
    """Format refs and preds for sacrebleu corpus calculation. It is very particular"""
    # Sacrebleu expects (List[str], List[List[str]])
    #   e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])

    # Note [ref1_stream] is the first reference for each pred.
    # So lists are size N and (M, N) for N preds and M possible refs for each pred
    # This is a different order of dimensions than I would expect

    # We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
    # Must become List[List[str]] with the inner list corresponding to preds
    if not is_non_str_iterable(refs):
        refs = list(refs)
    if not is_non_str_iterable(refs[0]):
        refs = [[ref] for ref in refs]
    refs = list(zip(*refs))
    # Note the number of refs in each ref list must match the number of preds

    # We expect preds to be List[str] or List[List[str]]. Must become List[str]
    if not is_non_str_iterable(preds):
        preds = list(preds)
    if is_non_str_iterable(preds[0]):
        assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}"
        preds = [pred[0] for pred in preds]

    return refs, preds
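# Illustrative sketch, not part of this diff: the reshaping _sacreformat performs
# on made-up strings, and how its output feeds sacrebleu.
example_preds = ["the cat sat", "on the mat"]
example_refs = ["the cat sat", "on the mat"]   # one reference per prediction
example_refs, example_preds = _sacreformat(example_refs, example_preds)
# example_preds -> ["the cat sat", "on the mat"]         List[str] of length N
# example_refs  -> [("the cat sat", "on the mat")]       M reference streams of length N (here M = 1)
example_score = sacrebleu.corpus_bleu(example_preds, example_refs).score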
# stderr stuff


class _bootstrap_internal:
    def __init__(self, f, n):
        self.f = f

...

@@ -330,25 +87,6 @@ def bootstrap_stderr(f, xs, iters):
    return sample_stddev(res)


def stderr_for_metric(metric, bootstrap_iters):
    bootstrappable = [
        median,
        matthews_corrcoef,
        f1_score,
        perplexity,
        bleu,
        chrf,
        ter,
    ]

    if metric in bootstrappable:
        return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters)

    stderr = {mean: mean_stderr, acc_all: acc_all_stderr}

    return stderr.get(metric, None)


def yesno(x):
    if x:
        return "yes"

...
lm_eval/metrics/__init__.py  (new file, 0 → 100644)
from .aggregation import *
from .metric import *

from lm_eval.api.metrics import bootstrap_stderr, mean_stderr, acc_all_stderr
from lm_eval.api.register import (
    metric_registry,
    aggregation_registry,
    higher_is_better_registry,
    output_type_registry,
    default_aggregation_registry,
)

METRIC_REGISTRY = metric_registry
OUTPUT_TYPE_REGISTRY = output_type_registry
AGGREGATION_REGISTRY = aggregation_registry
DEFAULT_AGGREGATION_REGISTRY = default_aggregation_registry
HIGHER_IS_BETTER_REGISTRY = higher_is_better_registry

DEFAULT_METRIC_REGISTRY = {
    "loglikelihood": [
        "perplexity",
        "acc",
    ],
    "loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
    "multiple_choice": [
        "acc",
    ],
    "greedy_until": ["exact_match"],
}


def get_metric(name):
    try:
        return METRIC_REGISTRY[name]
    except KeyError:
        # TODO: change this print to logging?
        print(
            f"Could not find registered metric '{name}' in lm-eval, \
searching in HF Evaluate library..."
        )
        try:
            import evaluate

            metric_object = evaluate.load(name)
            return metric_object.compute
        except Exception:
            raise Warning(
                "{} not found in the evaluate library!".format(name),
                "Please check https://huggingface.co/evaluate-metric",
            )


def get_aggregation(name):
    try:
        return AGGREGATION_REGISTRY[name]
    except KeyError:
        raise Warning(
            "{} not a registered aggregation metric!".format(name),
        )


def stderr_for_metric(metric, bootstrap_iters):
    bootstrappable = [
        "median",
        "matthews_corrcoef",
        "f1_score",
        "perplexity",
        "bleu",
        "chrf",
        "ter",
    ]

    if metric in bootstrappable:
        return lambda x: bootstrap_stderr(METRIC_REGISTRY[metric], x, iters=bootstrap_iters)

    stderr = {"mean": mean_stderr, "acc_all": acc_all_stderr}

    return stderr.get(metric, None)
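A brief usage sketch of the name-keyed lookups this `__init__` exposes; it is illustrative only and assumes the registries get populated when `lm_eval.metrics` is imported:

import lm_eval.metrics as metrics

# Default metrics for a task output type are now looked up by name.
metrics.DEFAULT_METRIC_REGISTRY["multiple_choice"]      # ['acc']

# Standard-error helpers are keyed by metric name rather than by function object.
stderr_fn = metrics.stderr_for_metric("bleu", bootstrap_iters=1000)
# stderr_fn, if not None, can then be applied to the list of per-example scores.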
lm_eval/metrics/aggregation.py  (new file, 0 → 100644)
import math

from lm_eval.api.register import register_aggregation


def weighted_mean(items):
    a, b = zip(*items)
    return sum(a) / sum(b)


@register_aggregation("mean")
def mean(arr):
    return sum(arr) / len(arr)


@register_aggregation("median")
def median(arr):
    return arr[len(arr) // 2]


@register_aggregation("perplexity")
def perplexity(items):
    return math.exp(-mean(items))


@register_aggregation("weighted_perplexity")
def weighted_perplexity(items):
    return math.exp(-weighted_mean(items))


@register_aggregation("bits_per_byte")
def bits_per_byte(items):
    return -weighted_mean(items) / math.log(2)
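A short worked example of the aggregations above. The plain values are per-example log-likelihoods (natural log); for the weighted variants, each item is assumed to be a (log-likelihood, count) pair, matching how word/byte perplexity are computed elsewhere in the harness:

lls = [-1.0, -2.0, -3.0]
perplexity(lls)             # exp(-mean) = exp(2.0) ≈ 7.39

pairs = [(-10.0, 4), (-20.0, 6)]   # total log-likelihood -30 over 10 units
weighted_perplexity(pairs)  # exp(30 / 10) = exp(3.0) ≈ 20.09
bits_per_byte(pairs)        # -(-3.0) / ln(2) ≈ 4.33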
lm_eval/metrics/metric.py  (new file, 0 → 100644)
import math
from collections.abc import Iterable

import numpy as np
import sacrebleu
import sklearn.metrics
import random

from lm_eval.api.register import (
    register_metric,
    register_higher_is_better,
    register_output_type,
    register_default_aggregation,
)


@register_default_aggregation("mean")
@register_output_type("loglikelihood")
@register_output_type("multiple_choice")
@register_higher_is_better(True)
@register_metric("acc")
def acc_fn(items):
    # This is a passthrough function
    return items


@register_default_aggregation("mean")
@register_output_type("multiple_choice")
@register_higher_is_better(True)
@register_metric("acc_norm")
def acc_norm_fn(items):
    # This is a passthrough function
    return items


@register_default_aggregation("mean")
@register_output_type("multiple_choice")
@register_higher_is_better(True)
@register_metric("acc_mutual_info")
def acc_mutual_info_fn(items):
    # This is a passthrough function
    return items


@register_default_aggregation("perplexity")
@register_output_type("loglikelihood")
@register_higher_is_better(False)
@register_metric("perplexity")
def perplexity_fn(items):
    # This is a passthrough function
    return items


@register_default_aggregation("weighted_perplexity")
@register_output_type("loglikelihood_rolling")
@register_higher_is_better(False)
@register_metric("word_perplexity")
def word_perplexity_fn(items):
    # This is a passthrough function
    return items


@register_default_aggregation("weighted_perplexity")
@register_output_type("loglikelihood_rolling")
@register_higher_is_better(False)
@register_metric("byte_perplexity")
def byte_perplexity_fn(items):
    # This is a passthrough function
    return items


@register_default_aggregation("bits_per_byte")
@register_output_type("loglikelihood_rolling")
@register_higher_is_better(False)
@register_metric("bits_per_byte")
def bits_per_byte_fn(items):
    # This is a passthrough function
    return items


@register_default_aggregation("mean")
@register_output_type("loglikelihood")
@register_higher_is_better(True)
@register_metric("acc_all")
def acc_all(items):
    # Only count as correct if all answers are labeled correctly for each question
    question_scoring_dict = {}
    preds = list(zip(*items))[0]
    docs = list(zip(*items))[1]

    for doc, pred in zip(docs, preds):
        paragraph_id = doc["idx"]["paragraph"]
        question_id = doc["idx"]["question"]
        if (paragraph_id, question_id) not in question_scoring_dict:
            question_scoring_dict[(paragraph_id, question_id)] = []

        gold_label = doc["label"] == 1

        question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)
    acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
    return acc


@register_default_aggregation("mean")
@register_higher_is_better(True)
@register_metric("matthews_corrcoef")
def matthews_corrcoef(items):
    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
    preds = unzipped_list[1]
    return sklearn.metrics.matthews_corrcoef(golds, preds)


@register_default_aggregation("mean")
@register_higher_is_better(True)
@register_metric("f1")
def f1_score(items):
    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
    preds = unzipped_list[1]
    fscore = sklearn.metrics.f1_score(golds, preds)

    return np.max(fscore)


def is_non_str_iterable(obj):
    return isinstance(obj, Iterable) and not isinstance(obj, str)


def _sacreformat(refs, preds):
    """Format refs and preds for sacrebleu corpus calculation. It is very particular"""
    # Sacrebleu expects (List[str], List[List[str]])
    #   e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])

    # Note [ref1_stream] is the first reference for each pred.
    # So lists are size N and (M, N) for N preds and M possible refs for each pred
    # This is a different order of dimensions than I would expect

    # We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
    # Must become List[List[str]] with the inner list corresponding to preds
    if not is_non_str_iterable(refs):
        refs = list(refs)
    if not is_non_str_iterable(refs[0]):
        refs = [[ref] for ref in refs]
    refs = list(zip(*refs))
    # Note the number of refs in each ref list must match the number of preds

    # We expect preds to be List[str] or List[List[str]]. Must become List[str]
    if not is_non_str_iterable(preds):
        preds = list(preds)
    if is_non_str_iterable(preds[0]):
        assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}"
        preds = [pred[0] for pred in preds]

    return refs, preds


@register_default_aggregation("mean")
@register_higher_is_better(True)
@register_metric("bleu")
def bleu(items):
    """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
    for evaluating a generated sentence against a reference sentence. It counts matching
    n-grams in the candidate translation and in the reference text, where a
    1-gram or unigram is a single token and a bigram comparison covers each
    word pair. The comparison is made regardless of word order.
    Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
    Paper: https://www.aclweb.org/anthology/P02-1040/

    Higher is better
    """
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_bleu(preds, refs).score


@register_default_aggregation("mean")
@register_higher_is_better(True)
@register_metric("chrf")
def chrf(items):
    """chrF++ is a tool for automatic evaluation of machine translation output
    based on character n-gram precision and recall enhanced with word n-grams.
    Source: https://github.com/m-popovic/chrF
    Paper: https://www.aclweb.org/anthology/W15-3049.pdf

    Higher is better  # TODO I think
    """
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_chrf(preds, refs).score


@register_default_aggregation("mean")
@register_higher_is_better(False)
@register_metric("ter")
def ter(items):
    """Translation Error Rate is an error metric for machine translation that
    measures the number of edits required to change a system output into one
    of the references.
    Source: http://www.cs.umd.edu/~snover/tercom/
    Paper: http://mt-archive.info/AMTA-2006-Snover.pdf

    Lower is better
    """
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_ter(preds, refs).score
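The decorators used throughout lm_eval/metrics/metric.py are imported from `lm_eval.api.register`, which is not included in this diff. As a purely hypothetical sketch of how such name-keyed registration could work (the decorator and registry names come from the imports above; everything else is an assumption, and the real module may differ):

# Hypothetical sketch only -- lm_eval/api/register.py is not shown in this commit.
metric_registry = {}
aggregation_registry = {}
higher_is_better_registry = {}
output_type_registry = {}
default_aggregation_registry = {}


def register_metric(name):
    def decorate(fn):
        assert name not in metric_registry, f"metric '{name}' already registered!"
        metric_registry[name] = fn
        # Applied innermost, so the outer property decorators can key off the name.
        fn._metric_name = name
        return fn

    return decorate


def register_higher_is_better(value):
    def decorate(fn):
        higher_is_better_registry[fn._metric_name] = value
        return fn

    return decorate

Because stacked decorators apply bottom-up, `@register_metric(...)` runs first and records the metric name, which the outer decorators (default aggregation, output type, higher-is-better) can then reuse to populate their own name-keyed registries.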