gaoqiong / lm-evaluation-harness

Commit d83f7eb0, authored Jul 21, 2025 by Baber
add type hints
Parent: a617e184

Showing 1 changed file with 29 additions and 23 deletions.

lm_eval/api/metrics.py  (+29, -23)  ·  view file @ d83f7eb0
 from __future__ import annotations

 import logging
 import math
 import os
 import random
 import re
 import string
-from collections.abc import Iterable
-from typing import Callable, List, Optional, Sequence, TypeVar
+from collections.abc import Iterable, Sequence
+from typing import Callable, Generic, TypeVar

 import numpy as np
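The import swap above pairs with `from __future__ import annotations`: because annotations are no longer evaluated at runtime, the file can drop `typing.List`/`Optional`, take `Sequence`/`Iterable` from `collections.abc`, and write built-in generics and `X | None` unions directly. A minimal sketch of the same style, separate from the commit (the helper name is made up):

    from __future__ import annotations

    from collections.abc import Sequence


    # hypothetical helper, not from metrics.py; the built-in generic and the
    # `| None` union are fine even on older interpreters because the
    # __future__ import defers annotation evaluation
    def first_or_none(xs: Sequence[int] | None) -> int | None:
        return xs[0] if xs else None


    print(first_or_none([3, 1, 2]))  # 3
    print(first_or_none(None))       # None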
@@ -31,7 +33,7 @@ def nanmean(arr: list[float]) -> float:

 @register_aggregation("mean")
-def mean(arr: list[float]) -> float:
+def mean(arr: Sequence[float]) -> float:
     return sum(arr) / len(arr)
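Widening `mean` from `list[float]` to `Sequence[float]` changes only what a type checker accepts; the runtime body is untouched. A tiny illustration (toy values, not from the harness) of why the wider hint matters:

    from collections.abc import Sequence


    def mean(arr: Sequence[float]) -> float:  # same one-line body as in metrics.py
        return sum(arr) / len(arr)


    print(mean([1.0, 2.0, 3.0]))  # a list satisfies Sequence[float]
    print(mean((1.0, 2.0, 3.0)))  # so does a tuple, which list[float] would reject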
@@ -70,7 +72,7 @@ def f1_score(items):

 @register_aggregation("matthews_corrcoef")
-def matthews_corrcoef(items):
+def matthews_corrcoef(items: Iterable[tuple[int, int] | tuple[str, str]]) -> float:
     from sklearn.metrics import matthews_corrcoef

     unzipped_list = list(zip(*items))
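The new annotation documents the shape of `items`: pairs of gold/predicted labels, either both ints or both strings, which `zip(*items)` splits into two aligned columns before they reach scikit-learn. A toy sketch of that unzip pattern (requires scikit-learn; the pairs are invented):

    from sklearn.metrics import matthews_corrcoef

    items = [(1, 1), (0, 1), (1, 0), (0, 0)]  # hypothetical (gold, pred) pairs
    golds, preds = zip(*items)                # unzip into two aligned tuples
    print(matthews_corrcoef(golds, preds))    # 0.0 for this balanced toy case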
@@ -80,7 +82,7 @@ def matthews_corrcoef(items):

 @register_aggregation("bleu")
-def bleu(items):
+def bleu(items: Iterable[tuple[str, str]]):
     """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
     for evaluating a generated sentence to a reference sentence. It counts matching
     n-grams in the candidate translation to n-grams in the reference text, where
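The docstring (truncated in this hunk) describes BLEU as n-gram overlap between a candidate and its references, and the new hint pins `items` down to pairs of strings. A hedged sketch of the underlying sacrebleu call (the `sacrebleu` package is required, the sentences are invented, and the (reference, prediction) ordering is an assumption here):

    import sacrebleu

    items = [("the cat sat on the mat", "the cat sat on a mat")]  # assumed (ref, pred)
    refs, preds = zip(*items)
    # corpus_bleu takes the predictions plus one or more reference streams
    print(sacrebleu.corpus_bleu(list(preds), [list(refs)]).score)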
@@ -117,7 +119,7 @@ def chrf(items):

 @register_aggregation("ter")
-def ter(items):
+def ter(items: Iterable[tuple[str, str]]):
     """Translation Error Rate is an error metric for machine translation that
     measures the number of edits required to change a system output into one
     of the references
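TER, as the docstring says, counts the edits needed to turn the system output into a reference, normalised by reference length. The items have the same pair-of-strings shape as for BLEU; a short hedged sketch using sacrebleu's TER scorer (same assumptions as above):

    import sacrebleu

    items = [("the cat sat on the mat", "the cat sat on a mat")]  # assumed (ref, pred)
    refs, preds = zip(*items)
    print(sacrebleu.corpus_ter(list(preds), [list(refs)]).score)  # edit rate, reported as a percentage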
@@ -135,7 +137,9 @@ def ter(items):

 @register_aggregation("brier_score")
-def brier_score(items):  # This is a passthrough function
+def brier_score(
+    items: Iterable[tuple[str, float]],
+):  # This is a passthrough function
     gold, predictions = list(zip(*items))
     bs, num_class = np.array(predictions).shape
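From the visible body, each item unpacks into a gold label and a per-class probability vector (`np.array(predictions)` has shape `(batch, num_class)`). For reference, a small self-contained sketch of the multi-class Brier score under that reading, using the textbook definition rather than the harness code:

    import numpy as np

    golds = np.array([0, 2])                                 # toy gold class indices
    probs = np.array([[0.7, 0.2, 0.1],                       # toy predicted class probabilities
                      [0.1, 0.2, 0.7]])
    one_hot = np.eye(probs.shape[1])[golds]                  # one-hot encode the gold labels
    brier = np.mean(np.sum((probs - one_hot) ** 2, axis=1))  # mean squared gap per example
    print(brier)                                             # 0.14 for this toy batch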
@@ -266,7 +270,7 @@ def perplexity_fn(items):  # This is a passthrough function
     output_type="loglikelihood_rolling",
     aggregation="weighted_perplexity",
 )
-def word_perplexity_fn(items):  # This is a passthrough function
+def word_perplexity_fn(items: T) -> T:  # This is a passthrough function
     return items
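This and the two passthrough metrics below now use the `TypeVar` `T` to say "returns exactly what it was given", which tells a type checker more than an unannotated identity function. A minimal sketch of that typing pattern (the function name is a stand-in, not the harness'):

    from __future__ import annotations

    from typing import TypeVar

    T = TypeVar("T")


    def passthrough(items: T) -> T:  # hypothetical stand-in for the *_fn helpers
        # the return type is tied to the argument type, so callers keep whatever
        # concrete type they passed in (list, tuple, ndarray, ...)
        return items


    xs: list[float] = [0.1, 0.2]
    ys = passthrough(xs)  # a type checker infers ys: list[float]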
@@ -276,7 +280,7 @@ def word_perplexity_fn(items):  # This is a passthrough function
     output_type="loglikelihood_rolling",
     aggregation="weighted_perplexity",
 )
-def byte_perplexity_fn(items):  # This is a passthrough function
+def byte_perplexity_fn(items: T) -> T:  # This is a passthrough function
     return items
@@ -286,7 +290,7 @@ def byte_perplexity_fn(items):  # This is a passthrough function
     output_type="loglikelihood_rolling",
     aggregation="bits_per_byte",
 )
-def bits_per_byte_fn(items):  # This is a passthrough function
+def bits_per_byte_fn(items: T) -> T:  # This is a passthrough function
     return items
@@ -295,7 +299,7 @@ def pop_stddev(arr):
     return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))


-def sample_stddev(arr: Sequence[T]) -> float:
+def sample_stddev(arr: Sequence[float]) -> float:
     mu = mean(arr)
     return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / (len(arr) - 1))
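Here the element type tightens from the type variable `T` to plain `float`; the body is still the Bessel-corrected sample standard deviation (divide by `len(arr) - 1`, where `pop_stddev` above divides by `len(arr)`). A quick numeric check, independent of the diff:

    import statistics

    arr = [2.0, 4.0, 4.0, 4.0, 5.0, 5.0, 7.0, 9.0]
    mu = sum(arr) / len(arr)
    sample = (sum((x - mu) ** 2 for x in arr) / (len(arr) - 1)) ** 0.5
    print(sample)                 # ~2.138
    print(statistics.stdev(arr))  # stdlib sample stddev agrees; pstdev gives the population value 2.0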
@@ -416,7 +420,7 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
     return max(scores_for_ground_truths)


-def weighted_mean(items: List[tuple[float, float]]) -> float:
+def weighted_mean(items: list[tuple[float, float]]) -> float:
     a, b = zip(*items)
     return sum(a) / sum(b)
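Only the spelling of the annotation changes (`List` → built-in `list`). The function itself returns the ratio of the two column sums of its (value, weight)-style pairs; a one-line check with toy numbers:

    items = [(3.0, 1.0), (8.0, 2.0)]  # toy pairs; first components summed over second components
    a, b = zip(*items)
    print(sum(a) / sum(b))            # (3.0 + 8.0) / (1.0 + 2.0) ≈ 3.667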
@@ -427,15 +431,15 @@ def is_non_str_iterable(obj):

 def _sacreformat(refs, preds):
     """Format refs and preds for sacrebleu corpus calculation. It is very particular"""
-    # Sacrebleu expects (List[str], List[List[str])
+    # Sacrebleu expects (list[str], list[list[str])
     # e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])
     # Note [ref1_stream] is the first reference for each pred.
     # So lists are size N and (M, N) for N preds and M possible refs for each pred
     # This is a different order of dimensions that I would expect

-    # We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
-    # Must become List[List[str]] with the inner list corresponding to preds
+    # We expect refs to be list[str] or list[list[str]], the outer list corresponding to preds
+    # Must become list[list[str]] with the inner list corresponding to preds
     if not is_non_str_iterable(refs):
         refs = list(refs)
     if not is_non_str_iterable(refs[0]):
@@ -443,7 +447,7 @@ def _sacreformat(refs, preds):
     refs = list(zip(*refs))
     # Note the number of refs in each ref list much match the number of preds

-    # We expect preds to be List[str] or List[List[str]]. Must become List[str]
+    # We expect preds to be list[str] or list[list[str]]. Must become list[str]
     if not is_non_str_iterable(preds):
         preds = list(preds)
     if is_non_str_iterable(preds[0]):
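The comment edits above are cosmetic (`List` → `list` in prose), but the layout they describe is worth a concrete picture: for N predictions with M references each, sacrebleu wants a length-N list of hypotheses plus M reference streams, each of length N. A hedged sketch of that reshaping with invented strings (this is not the harness' `_sacreformat`):

    import sacrebleu

    preds = ["the cat sat", "hello world"]          # N = 2 hypotheses
    refs_per_pred = [["a cat sat", "the cat sat"],  # M = 2 references for pred 0
                     ["hello world", "hi world"]]   # M = 2 references for pred 1
    # transpose to M reference streams, each aligned with the N predictions
    ref_streams = [list(stream) for stream in zip(*refs_per_pred)]
    print(sacrebleu.corpus_bleu(preds, ref_streams).score)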
@@ -456,7 +460,7 @@ def _sacreformat(refs, preds):
 # stderr stuff


-class _bootstrap_internal:
+class _bootstrap_internal(Generic[T]):
     """
     Pool worker: `(i, xs)` → `n` bootstrap replicates
     of `f(xs)` using a RNG seeded with `i`.
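Parametrising `_bootstrap_internal` with `Generic[T]` lets the element type of `xs` flow through the worker's call signature. As a standalone illustration of the pattern the docstring describes, here is a hedged sketch of a generic pool worker that returns `n` bootstrap replicates of `f(xs)` seeded with `i` (names and details are illustrative, not the harness' implementation):

    from __future__ import annotations

    import random
    from collections.abc import Callable, Sequence
    from typing import Generic, TypeVar

    T = TypeVar("T")


    class ResampleWorker(Generic[T]):  # hypothetical name
        """Compute `n` bootstrap replicates of `f(xs)`, seeding the RNG with `i`."""

        def __init__(self, f: Callable[[Sequence[T]], float], n: int) -> None:
            self.f = f
            self.n = n

        def __call__(self, v: tuple[int, Sequence[T]]) -> list[float]:
            i, xs = v
            rnd = random.Random(i)  # per-chunk seed keeps workers reproducible
            return [self.f(rnd.choices(xs, k=len(xs))) for _ in range(self.n)]


    worker = ResampleWorker(lambda xs: sum(xs) / len(xs), n=3)
    print(worker((0, [1.0, 2.0, 3.0, 4.0])))  # three resampled means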
@@ -539,7 +543,7 @@ def bootstrap_stderr(

 def stderr_for_metric(
     metric: Callable[[Sequence[T]], float], bootstrap_iters: int
-) -> Optional[Callable[[Sequence[T]], float]]:
+) -> Callable[[Sequence[T]], float] | None:
     """
     Return a function that estimates the standard error of `metric(xs)`.
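`Optional[...]` becomes the equivalent `... | None` spelling; the function still returns a standard-error estimator for metrics it knows about and `None` otherwise. For intuition about what such an estimator computes, a hedged, self-contained bootstrap standard-error sketch (this is the general technique, not the harness' `bootstrap_stderr`):

    import random
    import statistics


    def bootstrap_stderr_sketch(metric, xs, iters=1000, seed=0):
        """Std. deviation of `metric` over `iters` resamples drawn with replacement."""
        rnd = random.Random(seed)
        replicates = [metric(rnd.choices(xs, k=len(xs))) for _ in range(iters)]
        return statistics.stdev(replicates)


    def acc(xs):
        return sum(xs) / len(xs)


    data = [0, 1, 1, 0, 1, 1, 1, 0, 0, 1]
    print(acc(data), bootstrap_stderr_sketch(acc, data))  # accuracy and its bootstrap stderr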
@@ -569,10 +573,10 @@ def stderr_for_metric(
     stderr = {mean: mean_stderr, acc_all: acc_all_stderr}

-    return stderr.get(metric, None)
+    return stderr.get(metric)


-def pooled_sample_stderr(stderrs: List[float], sizes: List[int]):
+def pooled_sample_stderr(stderrs: list[float], sizes: list[int]):
     # Used to aggregate bootstrapped stderrs across subtasks in a group,
     # when we are weighting by the size of each subtask.
     #
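The visible tail of this function, `np.sqrt(pooled_sample_var / sum(sizes))`, is the standard pooled-variance recipe: recover each subtask's sample variance from its stderr, pool the variances weighted by degrees of freedom, then turn the pooled variance into a standard error for the combined sample size. A hedged sketch of that textbook formula (the harness' exact intermediate lines are not shown in this hunk, so treat this as an approximation of intent):

    import numpy as np


    def pooled_stderr_sketch(stderrs, sizes):
        # stderr_i = s_i / sqrt(n_i), so the per-subtask data variance is stderr_i**2 * n_i
        variances = [s**2 * n for s, n in zip(stderrs, sizes)]
        # pool the variances, weighting each subtask by its degrees of freedom (n_i - 1)
        pooled_var = sum((n - 1) * v for v, n in zip(variances, sizes)) / (sum(sizes) - len(sizes))
        # standard error of the mean over all sum(sizes) samples
        return np.sqrt(pooled_var / sum(sizes))


    print(pooled_stderr_sketch([0.02, 0.03], [400, 250]))  # ≈ 0.017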
@@ -590,7 +594,7 @@ def pooled_sample_stderr(stderrs: List[float], sizes: List[int]):
     return np.sqrt(pooled_sample_var / sum(sizes))


-def combined_sample_stderr(stderrs: List[float], sizes: List[int], metrics=None):
+def combined_sample_stderr(stderrs: list[float], sizes: list[int], metrics=None):
     assert metrics is not None, (
         "Need to pass a list of each subtask's metric for this stderr aggregation"
     )
@@ -622,7 +626,9 @@ def combined_sample_stderr(stderrs: List[float], sizes: List[int], metrics=None)
     return np.sqrt(variance)


-def aggregate_subtask_metrics(metrics, sizes, weight_by_size=True):
+def aggregate_subtask_metrics(
+    metrics: list[float], sizes: list[float], weight_by_size: bool = True
+):
     # A helper function that is used to aggregate
     # subtask scores cross-task.
     # TODO: does not hold for non-mean aggregations
@@ -631,4 +637,4 @@ def aggregate_subtask_metrics(metrics, sizes, weight_by_size=True):
     assert len(metrics) == len(sizes)
-    return sum([metric * size for metric, size in zip(metrics, sizes)]) / sum(sizes)
+    return sum(metric * size for metric, size in zip(metrics, sizes)) / sum(sizes)
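The last change swaps a list comprehension inside `sum(...)` for a generator expression, which avoids building an intermediate list but computes the same size-weighted average of subtask scores. A quick worked check with toy numbers:

    metrics = [0.50, 0.70]  # toy subtask scores
    sizes = [100, 300]      # samples per subtask
    print(sum(m * n for m, n in zip(metrics, sizes)) / sum(sizes))  # (50 + 210) / 400 = 0.65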