gaoqiong / lm-evaluation-harness · Commits

Commit 7b0b42c4
authored Jun 07, 2023 by lintangsutawika

removed

parent a22d8ffa
Showing 3 changed files with 0 additions and 318 deletions:

- lm_eval/metrics/__init__.py     +0 -82
- lm_eval/metrics/aggregation.py  +0 -32
- lm_eval/metrics/metric.py       +0 -204
lm_eval/metrics/__init__.py  deleted 100644 → 0
from .aggregation import *
from .metric import *

from lm_eval.api.metrics import bootstrap_stderr, mean_stderr, acc_all_stderr
from lm_eval.api.register import (
    metric_registry,
    aggregation_registry,
    higher_is_better_registry,
    output_type_registry,
    default_aggregation_registry,
)

METRIC_REGISTRY = metric_registry
OUTPUT_TYPE_REGISTRY = output_type_registry
AGGREGATION_REGISTRY = aggregation_registry
DEFAULT_AGGREGATION_REGISTRY = default_aggregation_registry
HIGHER_IS_BETTER_REGISTRY = higher_is_better_registry

DEFAULT_METRIC_REGISTRY = {
    "loglikelihood": [
        "perplexity",
        "acc",
    ],
    "loglikelihood_rolling": ["word_perplexity", "byte_perplexity", "bits_per_byte"],
    "multiple_choice": [
        "acc",
    ],
    "greedy_until": ["exact_match"],
}


def get_metric(name):
    try:
        return METRIC_REGISTRY[name]
    except KeyError:
        # TODO: change this print to logging?
        print(
            f"Could not find registered metric '{name}' in lm-eval, \
searching in HF Evaluate library..."
        )
        try:
            import evaluate

            metric_object = evaluate.load(name)
            return metric_object.compute
        except Exception:
            raise Warning(
                "{} not found in the evaluate library!".format(name),
                "Please check https://huggingface.co/evaluate-metric",
            )


def get_aggregation(name):
    try:
        return AGGREGATION_REGISTRY[name]
    except KeyError:
        raise Warning(
            "{} not a registered aggregation metric!".format(name),
        )


def stderr_for_metric(metric, bootstrap_iters):
    bootstrappable = [
        "median",
        "matthews_corrcoef",
        "f1_score",
        "perplexity",
        "bleu",
        "chrf",
        "ter",
    ]

    if metric in bootstrappable:
        return lambda x: bootstrap_stderr(
            METRIC_REGISTRY[metric], x, iters=bootstrap_iters
        )

    stderr = {"mean": mean_stderr, "acc_all": acc_all_stderr}

    return stderr.get(metric, None)
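For context, a minimal usage sketch of the lookup helpers deleted above. It is not part of this commit's diff; it assumes the "acc" metric and "mean" aggregation registered by metric.py and aggregation.py below, and the per-sample scores are made up for illustration.

# Hypothetical sketch, not from this diff: how evaluator code would have wired
# get_metric / get_aggregation / stderr_for_metric together before the removal.
from lm_eval.metrics import get_metric, get_aggregation, stderr_for_metric

acc_metric = get_metric("acc")        # passthrough metric: returns items unchanged
mean_agg = get_aggregation("mean")    # sum(arr) / len(arr)

per_sample_scores = [1.0, 0.0, 1.0, 1.0]   # illustrative 0/1 accuracy values
aggregated = mean_agg(acc_metric(per_sample_scores))

# "mean" is not in the bootstrappable list, so this resolves to mean_stderr
stderr_fn = stderr_for_metric("mean", bootstrap_iters=1000)
stderr = stderr_fn(per_sample_scores) if stderr_fn else None

print(aggregated, stderr)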
lm_eval/metrics/aggregation.py  deleted 100644 → 0
import math

from lm_eval.api.register import register_aggregation


def weighted_mean(items):
    a, b = zip(*items)
    return sum(a) / sum(b)


@register_aggregation("mean")
def mean(arr):
    return sum(arr) / len(arr)


@register_aggregation("median")
def median(arr):
    return arr[len(arr) // 2]


@register_aggregation("perplexity")
def perplexity(items):
    return math.exp(-mean(items))


@register_aggregation("weighted_perplexity")
def weighted_perplexity(items):
    return math.exp(-weighted_mean(items))


@register_aggregation("bits_per_byte")
def bits_per_byte(items):
    return -weighted_mean(items) / math.log(2)
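For reference, an illustrative check (not part of the diff) of how the aggregations above compose for rolling-loglikelihood results; the (log-likelihood, unit count) pairs are made up.

# Hypothetical sketch, not from this diff: each item is assumed to be a
# (total_loglikelihood, unit_count) pair, as produced by loglikelihood_rolling-style
# requests, where unit_count is e.g. a word or byte count.
import math

items = [(-35.2, 12), (-18.9, 7)]  # (sum of token log-likelihoods, unit count)

log_avg = sum(ll for ll, _ in items) / sum(n for _, n in items)  # weighted_mean
word_ppl = math.exp(-log_avg)                                    # weighted_perplexity
bits = -log_avg / math.log(2)                                    # bits_per_byte

print(word_ppl, bits)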
lm_eval/metrics/metric.py  deleted 100644 → 0
import math
from collections.abc import Iterable

import numpy as np
import sacrebleu
import sklearn.metrics
import random

from lm_eval.api.register import (
    register_metric,
    register_higher_is_better,
    register_output_type,
    register_default_aggregation,
)


@register_default_aggregation("mean")
@register_output_type("loglikelihood")
@register_output_type("multiple_choice")
@register_higher_is_better(True)
@register_metric("acc")
def acc_fn(items):
    # This is a passthrough function
    return items


@register_default_aggregation("mean")
@register_output_type("multiple_choice")
@register_higher_is_better(True)
@register_metric("acc_norm")
def acc_norm_fn(items):
    # This is a passthrough function
    return items


@register_default_aggregation("mean")
@register_output_type("multiple_choice")
@register_higher_is_better(True)
@register_metric("acc_mutual_info")
def acc_mutual_info_fn(items):
    # This is a passthrough function
    return items


@register_default_aggregation("perplexity")
@register_output_type("loglikelihood")
@register_higher_is_better(False)
@register_metric("perplexity")
def perplexity_fn(items):
    # This is a passthrough function
    return items


@register_default_aggregation("weighted_perplexity")
@register_output_type("loglikelihood_rolling")
@register_higher_is_better(False)
@register_metric("word_perplexity")
def word_perplexity_fn(items):
    # This is a passthrough function
    return items


@register_default_aggregation("weighted_perplexity")
@register_output_type("loglikelihood_rolling")
@register_higher_is_better(False)
@register_metric("byte_perplexity")
def byte_perplexity_fn(items):
    # This is a passthrough function
    return items


@register_default_aggregation("bits_per_byte")
@register_output_type("loglikelihood_rolling")
@register_higher_is_better(False)
@register_metric("bits_per_byte")
def bits_per_byte_fn(items):
    # This is a passthrough function
    return items


@register_default_aggregation("mean")
@register_output_type("loglikelihood")
@register_higher_is_better(True)
@register_metric("acc_all")
def acc_all(items):
    # Only count as correct if all answers are labeled correctly for each question
    question_scoring_dict = {}
    preds = list(zip(*items))[0]
    docs = list(zip(*items))[1]

    for doc, pred in zip(docs, preds):
        paragraph_id = doc["idx"]["paragraph"]
        question_id = doc["idx"]["question"]
        if (paragraph_id, question_id) not in question_scoring_dict:
            question_scoring_dict[(paragraph_id, question_id)] = []

        gold_label = doc["label"] == 1
        question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)

    acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
    return acc


@register_default_aggregation("mean")
@register_higher_is_better(True)
@register_metric("matthews_corrcoef")
def matthews_corrcoef(items):
    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
    preds = unzipped_list[1]
    return sklearn.metrics.matthews_corrcoef(golds, preds)


@register_default_aggregation("mean")
@register_higher_is_better(True)
@register_metric("f1")
def f1_score(items):
    unzipped_list = list(zip(*items))
    golds = unzipped_list[0]
    preds = unzipped_list[1]
    fscore = sklearn.metrics.f1_score(golds, preds)

    return np.max(fscore)


def is_non_str_iterable(obj):
    return isinstance(obj, Iterable) and not isinstance(obj, str)


def _sacreformat(refs, preds):
    """Format refs and preds for sacrebleu corpus calculation. It is very particular"""
    # Sacrebleu expects (List[str], List[List[str]])
    # e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])

    # Note [ref1_stream] is the first reference for each pred.
    # So lists are size N and (M, N) for N preds and M possible refs for each pred
    # This is a different order of dimensions than I would expect

    # We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
    # Must become List[List[str]] with the inner list corresponding to preds
    if not is_non_str_iterable(refs):
        refs = list(refs)
    if not is_non_str_iterable(refs[0]):
        refs = [[ref] for ref in refs]
    refs = list(zip(*refs))
    # Note the number of refs in each ref list must match the number of preds

    # We expect preds to be List[str] or List[List[str]]. Must become List[str]
    if not is_non_str_iterable(preds):
        preds = list(preds)
    if is_non_str_iterable(preds[0]):
        assert len(preds[0]) == 1, f"Pred must be a str, was {preds[0]}"
        preds = [pred[0] for pred in preds]

    return refs, preds


@register_default_aggregation("mean")
@register_higher_is_better(True)
@register_metric("bleu")
def bleu(items):
    """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
    for evaluating a generated sentence to a reference sentence. It counts matching
    n-grams in the candidate translation to n-grams in the reference text, where
    1-gram or unigram would be each token and a bigram comparison would be each
    word pair. The comparison is made regardless of word order
    Source: https://machinelearningmastery.com/calculate-bleu-score-for-text-python/
    Paper: https://www.aclweb.org/anthology/P02-1040/

    Higher is better
    """
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_bleu(preds, refs).score


@register_default_aggregation("mean")
@register_higher_is_better(True)
@register_metric("chrf")
def chrf(items):
    """chrF++ is a tool for automatic evaluation of machine translation output
    based on character n-gram precision and recall enhanced with word n-grams.
    Source: https://github.com/m-popovic/chrF
    Paper: https://www.aclweb.org/anthology/W15-3049.pdf

    Higher is better  # TODO I think
    """
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_chrf(preds, refs).score


@register_default_aggregation("mean")
@register_higher_is_better(False)
@register_metric("ter")
def ter(items):
    """Translation Error Rate is an error metric for machine translation that
    measures the number of edits required to change a system output into one
    of the references
    Source: http://www.cs.umd.edu/~snover/tercom/
    Paper: http://mt-archive.info/AMTA-2006-Snover.pdf

    Lower is better
    """
    refs = list(zip(*items))[0]
    preds = list(zip(*items))[1]
    refs, preds = _sacreformat(refs, preds)
    return sacrebleu.corpus_ter(preds, refs).score
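For reference, an illustrative sketch (not part of the diff) of the call shape the sacrebleu-based metrics above end up making; the reference/prediction strings are made up and a single reference per prediction is assumed.

# Hypothetical sketch, not from this diff: items are (reference, prediction) pairs,
# matching how bleu/chrf/ter unzip them before calling _sacreformat.
import sacrebleu

items = [
    ("the cat sat on the mat", "the cat sat on a mat"),
    ("hello world", "hello there world"),
]

refs = [ref for ref, _ in items]
preds = [pred for _, pred in items]

# _sacreformat produces one reference stream per possible reference, each stream
# holding one entry per prediction; with a single reference each, that is [refs].
ref_streams = [refs]

print(sacrebleu.corpus_bleu(preds, ref_streams).score)
print(sacrebleu.corpus_chrf(preds, ref_streams).score)
print(sacrebleu.corpus_ter(preds, ref_streams).score)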