gaoqiong / lm-evaluation-harness · Commits

Commit 9e75a84d, authored Jun 07, 2023 by lintangsutawika

    moved METRIC_REGISTRY to registry.py

Parent: 9b39df9b

1 changed file with 251 additions and 2 deletions

lm_eval/api/metrics.py (+251, -2), view file @ 9e75a84d
import math
from collections.abc import Iterable

import numpy as np
import sacrebleu
import sklearn.metrics
import random

from lm_eval.api.registry import register_metric, register_aggregation
# Register Aggregations First
@register_aggregation("mean")
def mean(arr):
    return sum(arr) / len(arr)


@register_aggregation("median")
def median(arr):
    return arr[len(arr) // 2]


@register_aggregation("perplexity")
def perplexity(items):
    return math.exp(-mean(items))


@register_aggregation("weighted_perplexity")
def weighted_perplexity(items):
    return math.exp(-weighted_mean(items))


@register_aggregation("bits_per_byte")
def bits_per_byte(items):
    return -weighted_mean(items) / math.log(2)
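The perplexity-style aggregations above assume items are natural-log likelihoods: perplexity is the exponential of the negated mean, and bits_per_byte rescales the (weighted) mean by log 2. A quick numeric sanity check with made-up values:

    import math

    loglikelihoods = [-1.2, -0.8, -1.0]              # hypothetical per-document values
    ppl = math.exp(-sum(loglikelihoods) / len(loglikelihoods))
    print(round(ppl, 3))                             # 2.718, i.e. exp(1.0)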
@register_metric(
    metric="acc",
    higher_is_better=True,
    output_type=["loglikelihood", "multiple_choice"],
    aggregation="mean",
)
def acc_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="acc_mutual_info",
    higher_is_better=True,
    output_type="multiple_choice",
    aggregation="mean",
)
def acc_mutual_info_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="perplexity",
    higher_is_better=False,
    output_type="loglikelihood",
    aggregation="perplexity",
)
def perplexity_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="word_perplexity",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
    aggregation="weighted_perplexity",
)
def word_perplexity_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="byte_perplexity",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
    aggregation="weighted_perplexity",
)
def byte_perplexity_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="bits_per_byte",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
    aggregation="bits_per_byte",
)
def bits_per_byte_fn(items):  # This is a passthrough function
    return items
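The passthrough metrics above only tag per-instance values with a name, an output type, and an aggregation; the lookup machinery lives in the registry that this commit moves into lm_eval/api/registry.py. A minimal sketch of how such a decorator-based registry can work, with names and structure assumed for illustration rather than taken from the harness:

    # Illustrative toy registry -- not the actual lm_eval.api.registry code.
    AGGREGATION_REGISTRY = {}

    def register_aggregation(name):
        def decorate(fn):
            # assumption: the real registry likely also guards against duplicate names
            AGGREGATION_REGISTRY[name] = fn
            return fn
        return decorate

    @register_aggregation("mean")
    def mean(arr):
        return sum(arr) / len(arr)

    assert AGGREGATION_REGISTRY["mean"]([1, 2, 3]) == 2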
def pop_stddev(arr):
-   mu = sum(arr) / len(arr)
+   mu = mean(arr)
    return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))


def sample_stddev(arr):
-   mu = sum(arr) / len(arr)
+   mu = mean(arr)
    return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / (len(arr) - 1))
...
@@ -16,6 +109,54 @@ def mean_stderr(arr):
    return sample_stddev(arr) / math.sqrt(len(arr))
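mean_stderr (whose tail is shown above) is the usual analytic standard error: the sample standard deviation divided by sqrt(n). A small worked example with made-up per-example accuracies:

    import math

    arr = [0.0, 1.0, 1.0, 1.0]                       # invented accuracies
    mu = sum(arr) / len(arr)                         # 0.75
    sample_sd = math.sqrt(sum((x - mu) ** 2 for x in arr) / (len(arr) - 1))  # 0.5
    stderr = sample_sd / math.sqrt(len(arr))         # 0.25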
@register_metric(metric="matthews_corrcoef", higher_is_better=True, aggregation="mean")
def matthews_corrcoef(items):
    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
    preds = unzipped_list[1]
    return sklearn.metrics.matthews_corrcoef(golds, preds)
@register_metric(
    metric="f1_score",
    higher_is_better=True,
    output_type="multiple_choice",
    aggregation="mean",
)
def f1_score(items):
    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
    preds = unzipped_list[1]
    fscore = sklearn.metrics.f1_score(golds, preds)

    return np.max(fscore)
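Both corpus-level metrics expect items to be (gold, pred) pairs accumulated per document, which they unzip into parallel label lists before handing off to scikit-learn. A hedged usage sketch with made-up labels:

    import sklearn.metrics

    items = [(1, 1), (0, 1), (1, 1), (0, 0)]         # (gold, pred) pairs, invented
    golds, preds = zip(*items)
    print(sklearn.metrics.matthews_corrcoef(golds, preds))
    print(sklearn.metrics.f1_score(golds, preds))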
@register_metric(
    metric="acc_all",
    higher_is_better=True,
    output_type="loglikelihood",
    aggregation="mean",
)
def acc_all(items):
    # Only count as correct if all answers are labeled correctly for each question
    question_scoring_dict = {}
    preds = list(zip(*items))[0]
    docs = list(zip(*items))[1]

    for doc, pred in zip(docs, preds):
        paragraph_id = doc["idx"]["paragraph"]
        question_id = doc["idx"]["question"]
        if (paragraph_id, question_id) not in question_scoring_dict:
            question_scoring_dict[(paragraph_id, question_id)] = []

        gold_label = doc["label"] == 1

        question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)
    acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
    return acc
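acc_all groups answers by (paragraph, question) and scores a question as correct only when every one of its answers is classified correctly (MultiRC-style). A sketch of the expected item shape, with doc fields mirroring the lookups above and the data invented:

    docs = [
        {"idx": {"paragraph": 0, "question": 0}, "label": 1},
        {"idx": {"paragraph": 0, "question": 0}, "label": 0},
        {"idx": {"paragraph": 0, "question": 1}, "label": 1},
    ]
    preds = [True, False, False]                 # per-answer model predictions
    items = list(zip(preds, docs))               # (pred, doc) pairs, as unpacked above
    # question (0, 0): both answers correct -> 1; question (0, 1): wrong -> 0
    # acc_all(items) == 0.5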
def acc_all_stderr(items):
    # Only count as correct if all answers are labeled correctly for each question
    question_scoring_dict = {}
...
@@ -43,6 +184,95 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
    return max(scores_for_ground_truths)
def weighted_mean(items):
    a, b = zip(*items)
    return sum(a) / sum(b)
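weighted_mean pairs each document's summed log-likelihood with its length (word or byte count), which is what the loglikelihood_rolling metrics feed into the weighted_perplexity and bits_per_byte aggregations. A made-up example:

    import math

    items = [(-10.0, 5), (-20.0, 15)]            # (loglikelihood, num_words), invented
    a, b = zip(*items)
    wm = sum(a) / sum(b)                         # -30 / 20 = -1.5
    word_ppl = math.exp(-wm)                     # exp(1.5) ~= 4.48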
@register_metric(metric="bleu", higher_is_better=True, aggregation="mean")
def bleu(items):
    """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
    for evaluating a generated sentence to a reference sentence. It counts matching
    n-grams in the candidate translation to n-grams in the reference text, where
    1-gram or unigram would be each token and a bigram comparison would be each
    word pair. The comparison is made regardless of word order
    Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
    Paper: https://www.aclweb.org/anthology/P02-1040/

    Higher is better
    """
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_bleu(preds, refs).score
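sacrebleu.corpus_bleu takes hypotheses as List[str] and references as List[List[str]], where each inner list is one reference stream covering every hypothesis; _sacreformat (further down) reshapes the accumulated items accordingly. A minimal, hedged usage sketch with invented sentences:

    import sacrebleu

    preds = ["the cat sat on the mat", "hello world"]
    refs = [["the cat sat on the mat", "hello there world"]]   # one reference stream
    print(sacrebleu.corpus_bleu(preds, refs).score)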
@register_metric(metric="chrf", higher_is_better=True, aggregation="mean")
def chrf(items):
    """chrF++ is a tool for automatic evaluation of machine translation output
    based on character n-gram precision and recall enhanced with word n-grams.
    Source: https://github.com/m-popovic/chrF
    Paper: https://www.aclweb.org/anthology/W15-3049.pdf

    Higher is better  # TODO I think
    """
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_chrf(preds, refs).score
@register_metric(metric="ter", higher_is_better=True, aggregation="mean")
def ter(items):
    """Translation Error Rate is an error metric for machine translation that
    measures the number of edits required to change a system output into one
    of the references
    Source: http://www.cs.umd.edu/~snover/tercom/
    Paper: http://mt-archive.info/AMTA-2006-Snover.pdf

    Lower is better
    """
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_ter(preds, refs).score
def is_non_str_iterable(obj):
    return isinstance(obj, Iterable) and not isinstance(obj, str)
def _sacreformat(refs, preds):
    """Format refs and preds for sacrebleu corpus calculation. It is very particular"""
    # Sacrebleu expects (List[str], List[List[str]])
    #   e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])

    # Note [ref1_stream] is the first reference for each pred.
    # So lists are size N and (M, N) for N preds and M possible refs for each pred
    # This is a different order of dimensions than I would expect

    # We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
    # Must become List[List[str]] with the inner list corresponding to preds
    if not is_non_str_iterable(refs):
        refs = list(refs)
    if not is_non_str_iterable(refs[0]):
        refs = [[ref] for ref in refs]
    refs = list(zip(*refs))
    # Note the number of refs in each ref list must match the number of preds

    # We expect preds to be List[str] or List[List[str]]. Must become List[str]
    if not is_non_str_iterable(preds):
        preds = list(preds)
    if is_non_str_iterable(preds[0]):
        assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}"
        preds = [pred[0] for pred in preds]

    return refs, preds
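Put differently: when each of N predictions comes with M references, _sacreformat transposes refs from N lists of M into M streams of N, and flattens singleton prediction lists into plain strings. A small illustration of the intended shapes (invented data, expected results shown as comments):

    refs = [["ref a1", "ref a2"], ["ref b1", "ref b2"]]   # N=2 preds, M=2 refs each
    preds = [["pred a"], ["pred b"]]
    # after _sacreformat(refs, preds):
    #   refs  -> [("ref a1", "ref b1"), ("ref a2", "ref b2")]   # M streams of length N
    #   preds -> ["pred a", "pred b"]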
# stderr stuff


class _bootstrap_internal:
    def __init__(self, f, n):
        self.f = f
...
@@ -87,6 +317,25 @@ def bootstrap_stderr(f, xs, iters):
    return sample_stddev(res)
def stderr_for_metric(metric, bootstrap_iters):
    bootstrappable = [
        median,
        matthews_corrcoef,
        f1_score,
        perplexity,
        bleu,
        chrf,
        ter,
    ]

    if metric in bootstrappable:
        return lambda x: bootstrap_stderr(metric, x, iters=bootstrap_iters)

    stderr = {mean: mean_stderr, acc_all: acc_all_stderr}

    return stderr.get(metric, None)
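stderr_for_metric therefore returns a callable: a bootstrap estimator for the listed corpus metrics, the analytic mean_stderr for mean (acc_all_stderr for acc_all), and None for anything else. A hedged usage sketch built on the functions defined in this file:

    stderr_fn = stderr_for_metric(mean, bootstrap_iters=1000)
    print(stderr_fn([0.0, 1.0, 1.0, 1.0]))       # mean_stderr -> 0.25

    stderr_fn = stderr_for_metric(bleu, bootstrap_iters=1000)
    # -> lambda that calls bootstrap_stderr(bleu, items, iters=1000)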
def yesno(x):
    if x:
        return "yes"
...