gaoqiong / lm-evaluation-harness · Commits

Commit d83f7eb0, authored Jul 21, 2025 by Baber
add type hints
Parent: a617e184

Showing 1 changed file with 29 additions and 23 deletions (+29, -23)
lm_eval/api/metrics.py (view file @ d83f7eb0)
+from __future__ import annotations
+
 import logging
 import math
 import os
 import random
 import re
 import string
-from collections.abc import Iterable
-from typing import Callable, List, Optional, Sequence, TypeVar
+from collections.abc import Iterable, Sequence
+from typing import Callable, Generic, TypeVar

 import numpy as np
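Editor's note: the import hunk above adds `from __future__ import annotations` and drops `List`/`Optional` in favour of built-in generics and `X | None` unions. A minimal sketch of why that combination also works on interpreters that predate PEP 585/604 at runtime; the names below are illustrative and not taken from metrics.py:

# Postponed evaluation (PEP 563) stores every annotation as a string, so
# `Sequence[float]` and `float | None` are legal in annotations even where
# they would fail if evaluated eagerly.
from __future__ import annotations

from collections.abc import Sequence


def first_or_none(xs: Sequence[float]) -> float | None:
    # The annotation is never evaluated at runtime, only kept as text.
    return xs[0] if xs else None


print(first_or_none((3.0, 1.0)))       # 3.0
print(first_or_none([]))               # None
print(first_or_none.__annotations__)   # values are plain strings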
@@ -31,7 +33,7 @@ def nanmean(arr: list[float]) -> float:
 @register_aggregation("mean")
-def mean(arr: list[float]) -> float:
+def mean(arr: Sequence[float]) -> float:
     return sum(arr) / len(arr)
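Editor's note: widening `mean` from `list[float]` to `Sequence[float]` matters because several aggregations in this file hand it tuples produced by `zip(*items)`. A standalone sketch (not the registered function itself):

from collections.abc import Sequence


def mean(arr: Sequence[float]) -> float:
    return sum(arr) / len(arr)


# A tuple coming out of zip(*...) is a Sequence, so it now satisfies a type
# checker as well as running correctly.
scores = list(zip(*[(1.0, 10), (0.0, 20), (1.0, 30)]))[0]  # (1.0, 0.0, 1.0)
print(mean(scores))  # 0.666...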
@@ -70,7 +72,7 @@ def f1_score(items):
 @register_aggregation("matthews_corrcoef")
-def matthews_corrcoef(items):
+def matthews_corrcoef(items: Iterable[tuple[int, int] | tuple[str, str]]) -> float:
     from sklearn.metrics import matthews_corrcoef

     unzipped_list = list(zip(*items))
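Editor's note: the new `Iterable[tuple[int, int] | tuple[str, str]]` annotation documents that each item is one pair of labels (gold and prediction, in whichever order the caller packs them); `zip(*items)` then splits the pairs into two parallel sequences for sklearn. A toy illustration of that unzip step, with the sklearn call omitted:

# zip(*items) transposes the per-example pairs into one tuple per column.
items = [(1, 1), (0, 1), (1, 0), (0, 0)]
first_column, second_column = zip(*items)
print(first_column)   # (1, 0, 1, 0)
print(second_column)  # (1, 1, 0, 0)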
@@ -80,7 +82,7 @@ def matthews_corrcoef(items):
 @register_aggregation("bleu")
-def bleu(items):
+def bleu(items: Iterable[tuple[str, str]]):
     """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
     for evaluating a generated sentence to a reference sentence. It counts matching
     n-grams in the candidate translation to n-grams in the reference text, where
@@ -117,7 +119,7 @@ def chrf(items):
 @register_aggregation("ter")
-def ter(items):
+def ter(items: Iterable[tuple[str, str]]):
     """Translation Error Rate is an error metric for machine translation that
     measures the number of edits required to change a system output into one
     of the references
@@ -135,7 +137,9 @@ def ter(items):
 @register_aggregation("brier_score")
-def brier_score(items):  # This is a passthrough function
+def brier_score(
+    items: Iterable[tuple[str, float]],
+):  # This is a passthrough function
     gold, predictions = list(zip(*items))
     bs, num_class = np.array(predictions).shape
@@ -266,7 +270,7 @@ def perplexity_fn(items):  # This is a passthrough function
     output_type="loglikelihood_rolling",
     aggregation="weighted_perplexity",
 )
-def word_perplexity_fn(items):  # This is a passthrough function
+def word_perplexity_fn(items: T) -> T:  # This is a passthrough function
     return items
@@ -276,7 +280,7 @@ def word_perplexity_fn(items):  # This is a passthrough function
     output_type="loglikelihood_rolling",
     aggregation="weighted_perplexity",
 )
-def byte_perplexity_fn(items):  # This is a passthrough function
+def byte_perplexity_fn(items: T) -> T:  # This is a passthrough function
     return items
@@ -286,7 +290,7 @@ def byte_perplexity_fn(items):  # This is a passthrough function
     output_type="loglikelihood_rolling",
     aggregation="bits_per_byte",
 )
-def bits_per_byte_fn(items):  # This is a passthrough function
+def bits_per_byte_fn(items: T) -> T:  # This is a passthrough function
     return items
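Editor's note: the three passthrough hunks above annotate `items: T` with return type `T`; the `TypeVar` named `T` is defined elsewhere in metrics.py and is not part of this diff. A self-contained sketch of what that buys a type checker:

from typing import TypeVar

T = TypeVar("T")


def passthrough(items: T) -> T:
    # Returning the argument unchanged; the TypeVar tells a checker that the
    # output type is exactly the input type rather than Any.
    return items


pairs: list[tuple[float, float]] = [(12.3, 4.0)]
same_pairs = passthrough(pairs)  # still inferred as list[tuple[float, float]]
print(same_pairs)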
@@ -295,7 +299,7 @@ def pop_stddev(arr):
     return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))


-def sample_stddev(arr: Sequence[T]) -> float:
+def sample_stddev(arr: Sequence[float]) -> float:
     mu = mean(arr)
     return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / (len(arr) - 1))
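Editor's note: `Sequence[T]` on `sample_stddev` was arguably a slip, since the body does arithmetic on the elements; `Sequence[float]` matches what it computes. A quick numeric check of the n - 1 (sample) formula shown in the context lines:

import math


def sample_stddev(arr):
    mu = sum(arr) / len(arr)
    return math.sqrt(sum((x - mu) ** 2 for x in arr) / (len(arr) - 1))


print(sample_stddev([2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]))  # ~2.138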
@@ -416,7 +420,7 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
     return max(scores_for_ground_truths)


-def weighted_mean(items: List[tuple[float, float]]) -> float:
+def weighted_mean(items: list[tuple[float, float]]) -> float:
     a, b = zip(*items)
     return sum(a) / sum(b)
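Editor's note: swapping `List` for `list` here is purely cosmetic, but the annotation is a useful reminder of the item shape: pairs of (value, weight) whose ratio of sums gives the weighted mean. A tiny worked example:

# Each item is a (value, weight) pair; the result is sum(values) / sum(weights).
items: list[tuple[float, float]] = [(6.0, 3.0), (2.0, 1.0)]
a, b = zip(*items)
print(sum(a) / sum(b))  # 8.0 / 4.0 = 2.0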
@@ -427,15 +431,15 @@ def is_non_str_iterable(obj):
 def _sacreformat(refs, preds):
     """Format refs and preds for sacrebleu corpus calculation. It is very particular"""
-    # Sacrebleu expects (List[str], List[List[str])
+    # Sacrebleu expects (list[str], list[list[str])
     # e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])
     # Note [ref1_stream] is the first reference for each pred.
     # So lists are size N and (M, N) for N preds and M possible refs for each pred
     # This is a different order of dimensions that I would expect
-    # We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
-    # Must become List[List[str]] with the inner list corresponding to preds
+    # We expect refs to be list[str] or list[list[str]], the outer list corresponding to preds
+    # Must become list[list[str]] with the inner list corresponding to preds
     if not is_non_str_iterable(refs):
         refs = list(refs)
     if not is_non_str_iterable(refs[0]):
@@ -443,7 +447,7 @@ def _sacreformat(refs, preds):
         refs = list(zip(*refs))
     # Note the number of refs in each ref list much match the number of preds
-    # We expect preds to be List[str] or List[List[str]]. Must become List[str]
+    # We expect preds to be list[str] or list[list[str]]. Must become list[str]
     if not is_non_str_iterable(preds):
         preds = list(preds)
     if is_non_str_iterable(preds[0]):
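Editor's note: the two `_sacreformat` hunks only recase `List` to `list` in comments, but the layout those comments describe is easy to get wrong, so here is a toy sketch of the transposition: refs arrive as one list of M references per prediction and must become M reference streams of length N for sacrebleu (the strings below are made up):

# N = 2 predictions, M = 2 references each (outer list corresponds to preds).
refs = [["the cat sat", "a cat sat"], ["the dog ran", "a dog ran"]]
# Transpose into M ref-streams of length N, the shape sacrebleu's corpus
# functions expect alongside the flat list of predictions.
ref_streams = list(zip(*refs))
print(ref_streams)  # [('the cat sat', 'the dog ran'), ('a cat sat', 'a dog ran')]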
@@ -456,7 +460,7 @@ def _sacreformat(refs, preds):
 # stderr stuff


-class _bootstrap_internal:
+class _bootstrap_internal(Generic[T]):
     """
     Pool worker: `(i, xs)` → `n` bootstrap replicates
     of `f(xs)`using a RNG seeded with `i`.
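Editor's note: making `_bootstrap_internal` inherit from `Generic[T]` lets the element type of the resampled data flow through to the metric callable. A self-contained sketch of the pattern using a hypothetical `Worker`, not the harness class itself:

from __future__ import annotations

import random
from collections.abc import Callable, Sequence
from typing import Generic, TypeVar

T = TypeVar("T")


class Worker(Generic[T]):
    """Toy bootstrap worker: resample xs with replacement and apply f."""

    def __init__(self, f: Callable[[Sequence[T]], float], n: int) -> None:
        self.f = f
        self.n = n

    def __call__(self, args: tuple[int, Sequence[T]]) -> list[float]:
        seed, xs = args
        rnd = random.Random(seed)  # seeded per task, as the docstring above describes
        return [self.f(rnd.choices(xs, k=len(xs))) for _ in range(self.n)]


w: Worker[float] = Worker(lambda xs: sum(xs) / len(xs), n=3)
print(w((1234, [0.0, 1.0, 1.0, 0.0, 1.0])))  # three bootstrap means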
@@ -539,7 +543,7 @@ def bootstrap_stderr(
 def stderr_for_metric(
     metric: Callable[[Sequence[T]], float], bootstrap_iters: int
-) -> Optional[Callable[[Sequence[T]], float]]:
+) -> Callable[[Sequence[T]], float] | None:
     """
     Return a function that estimates the standard error of `metric(xs)`.
@@ -569,10 +573,10 @@ def stderr_for_metric(
     stderr = {mean: mean_stderr, acc_all: acc_all_stderr}
-    return stderr.get(metric, None)
+    return stderr.get(metric)


-def pooled_sample_stderr(stderrs: List[float], sizes: List[int]):
+def pooled_sample_stderr(stderrs: list[float], sizes: list[int]):
     # Used to aggregate bootstrapped stderrs across subtasks in a group,
     # when we are weighting by the size of each subtask.
     #
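Editor's note: the two preceding hunks are companion changes: the return annotation becomes `Callable[[Sequence[T]], float] | None`, and `stderr.get(metric, None)` drops the redundant default since `dict.get` already returns None on a miss. A hypothetical standalone sketch of that lookup pattern (names are illustrative, not the harness API):

from __future__ import annotations

from collections.abc import Callable, Sequence


def mean(arr: Sequence[float]) -> float:
    return sum(arr) / len(arr)


def mean_stderr(arr: Sequence[float]) -> float:
    mu = mean(arr)
    var = sum((x - mu) ** 2 for x in arr) / (len(arr) - 1)
    return (var / len(arr)) ** 0.5


def stderr_for(metric: Callable[[Sequence[float]], float]) -> Callable[[Sequence[float]], float] | None:
    table = {mean: mean_stderr}
    return table.get(metric)  # .get already defaults to None


fn = stderr_for(mean)
if fn is not None:  # the `| None` half of the annotation still needs handling
    print(fn([0.0, 1.0, 1.0, 0.0]))  # ~0.2887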
@@ -590,7 +594,7 @@ def pooled_sample_stderr(stderrs: List[float], sizes: List[int]):
     return np.sqrt(pooled_sample_var / sum(sizes))


-def combined_sample_stderr(stderrs: List[float], sizes: List[int], metrics=None):
+def combined_sample_stderr(stderrs: list[float], sizes: list[int], metrics=None):
     assert metrics is not None, (
         "Need to pass a list of each subtask's metric for this stderr aggregation"
     )
@@ -622,7 +626,9 @@ def combined_sample_stderr(stderrs: List[float], sizes: List[int], metrics=None):
     return np.sqrt(variance)


-def aggregate_subtask_metrics(metrics, sizes, weight_by_size=True):
+def aggregate_subtask_metrics(
+    metrics: list[float], sizes: list[float], weight_by_size: bool = True
+):
     # A helper function that is used to aggregate
     # subtask scores cross-task.
     # TODO: does not hold for non-mean aggregations
@@ -631,4 +637,4 @@ def aggregate_subtask_metrics(metrics, sizes, weight_by_size=True):
     assert len(metrics) == len(sizes)
-    return sum([metric * size for metric, size in zip(metrics, sizes)]) / sum(sizes)
+    return sum(metric * size for metric, size in zip(metrics, sizes)) / sum(sizes)
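Editor's note: the final hunk only swaps a list comprehension for a generator expression inside `sum(...)`, which avoids materialising the intermediate list; the result is unchanged. A quick numeric check of the size-weighted aggregation:

metrics = [0.50, 0.80]   # per-subtask scores
sizes = [100, 300]       # per-subtask sample counts
print(sum(m * s for m, s in zip(metrics, sizes)) / sum(sizes))  # 0.725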