Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
bbf79d44
"vscode:/vscode.git/clone" did not exist on "1a8706c8b94918915bcaa44ddbc9e29a0cfea3b2"
Commit
bbf79d44
authored
Jun 30, 2025
by
Baber
Browse files
update type hints
parent
7f7872c1
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
47 additions
and
32 deletions
+47
-32
lm_eval/api/metrics.py
lm_eval/api/metrics.py
+7
-7
lm_eval/api/task.py
lm_eval/api/task.py
+40
-25
No files found.
lm_eval/api/metrics.py
View file @
bbf79d44
...
@@ -24,36 +24,36 @@ def bypass_agg(arr):
...
@@ -24,36 +24,36 @@ def bypass_agg(arr):
@
register_aggregation
(
"nanmean"
)
@
register_aggregation
(
"nanmean"
)
def
nanmean
(
arr
)
:
def
nanmean
(
arr
:
list
[
float
])
->
float
:
if
len
(
arr
)
==
0
or
all
(
np
.
isnan
(
arr
)):
if
len
(
arr
)
==
0
or
all
(
np
.
isnan
(
arr
)):
return
np
.
nan
return
np
.
nan
return
np
.
nanmean
(
arr
)
return
np
.
nanmean
(
arr
)
@
register_aggregation
(
"mean"
)
@
register_aggregation
(
"mean"
)
def
mean
(
arr
)
:
def
mean
(
arr
:
list
[
float
])
->
float
:
return
sum
(
arr
)
/
len
(
arr
)
return
sum
(
arr
)
/
len
(
arr
)
@
register_aggregation
(
"median"
)
@
register_aggregation
(
"median"
)
def
median
(
arr
)
:
def
median
(
arr
:
list
[
float
])
->
float
:
return
arr
[
len
(
arr
)
//
2
]
return
arr
[
len
(
arr
)
//
2
]
# Certain metrics must be calculated across all documents in a benchmark.
# Certain metrics must be calculated across all documents in a benchmark.
# We use them as aggregation metrics, paired with no-op passthrough metric fns.
# We use them as aggregation metrics, paired with no-op passthrough metric fns.
@
register_aggregation
(
"perplexity"
)
@
register_aggregation
(
"perplexity"
)
def
perplexity
(
items
)
:
def
perplexity
(
items
:
list
[
float
])
->
float
:
return
math
.
exp
(
-
mean
(
items
))
return
math
.
exp
(
-
mean
(
items
))
@
register_aggregation
(
"weighted_perplexity"
)
@
register_aggregation
(
"weighted_perplexity"
)
def
weighted_perplexity
(
items
)
:
def
weighted_perplexity
(
items
:
list
[
tuple
[
float
,
float
]])
->
float
:
return
math
.
exp
(
-
weighted_mean
(
items
))
return
math
.
exp
(
-
weighted_mean
(
items
))
@
register_aggregation
(
"bits_per_byte"
)
@
register_aggregation
(
"bits_per_byte"
)
def
bits_per_byte
(
items
)
:
def
bits_per_byte
(
items
:
list
[
tuple
[
float
,
float
]])
->
float
:
return
-
weighted_mean
(
items
)
/
math
.
log
(
2
)
return
-
weighted_mean
(
items
)
/
math
.
log
(
2
)
...
@@ -416,7 +416,7 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
...
@@ -416,7 +416,7 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
return
max
(
scores_for_ground_truths
)
return
max
(
scores_for_ground_truths
)
def
weighted_mean
(
items
)
:
def
weighted_mean
(
items
:
List
[
tuple
[
float
,
float
]])
->
float
:
a
,
b
=
zip
(
*
items
)
a
,
b
=
zip
(
*
items
)
return
sum
(
a
)
/
sum
(
b
)
return
sum
(
a
)
/
sum
(
b
)
...
...
lm_eval/api/task.py
View file @
bbf79d44
...
@@ -63,11 +63,10 @@ class MetricConfig:
...
@@ -63,11 +63,10 @@ class MetricConfig:
aggregation_fn
:
Optional
[
Callable
]
=
None
aggregation_fn
:
Optional
[
Callable
]
=
None
higher_is_better
:
bool
=
True
higher_is_better
:
bool
=
True
hf_evaluate
:
bool
=
False
hf_evaluate
:
bool
=
False
sample_metric
:
bool
=
True
is_elementwise
:
bool
=
True
is_elementwise
:
bool
=
True
@
cached_property
@
cached_property
def
metric_name
s
(
self
)
->
str
:
def
metric_name
(
self
)
->
str
:
return
self
.
name
return
self
.
name
@
cached_property
@
cached_property
...
@@ -82,6 +81,12 @@ class MetricConfig:
...
@@ -82,6 +81,12 @@ class MetricConfig:
return
is_higher_better
(
self
.
name
)
return
is_higher_better
(
self
.
name
)
return
self
.
higher_is_better
return
self
.
higher_is_better
def
calculate_metric
(
self
,
*
args
,
**
kwargs
)
->
Any
:
"""Calculates the metric using the provided function and arguments."""
if
self
.
fn
is
None
:
raise
ValueError
(
f
"Metric function for
{
self
.
name
}
is not defined."
)
return
self
.
fn
(
*
args
,
**
{
**
self
.
kwargs
,
**
kwargs
})
@
dataclass
@
dataclass
class
RepeatConfig
:
class
RepeatConfig
:
...
@@ -108,6 +113,16 @@ class FewshotConfig:
...
@@ -108,6 +113,16 @@ class FewshotConfig:
process_docs
:
Optional
[
Callable
]
=
None
process_docs
:
Optional
[
Callable
]
=
None
@
dataclass
class
DatasetConfig
:
"""Encapsulates information about a dataset."""
dataset_path
:
Optional
[
str
]
=
None
dataset_name
:
Optional
[
str
]
=
None
dataset_kwargs
:
Optional
[
dict
]
=
None
custom_dataset
:
Optional
[
Callable
]
=
None
@
dataclass
@
dataclass
class
TaskConfig
(
dict
):
class
TaskConfig
(
dict
):
# task naming/registry
# task naming/registry
...
@@ -132,8 +147,8 @@ class TaskConfig(dict):
...
@@ -132,8 +147,8 @@ class TaskConfig(dict):
process_docs
:
Optional
[
Callable
]
=
None
process_docs
:
Optional
[
Callable
]
=
None
doc_to_text
:
Optional
[
Union
[
Callable
,
str
]]
=
None
doc_to_text
:
Optional
[
Union
[
Callable
,
str
]]
=
None
doc_to_target
:
Optional
[
Union
[
Callable
,
str
]]
=
None
doc_to_target
:
Optional
[
Union
[
Callable
,
str
]]
=
None
doc_to_image
:
Union
[
Callable
,
str
]
=
None
doc_to_image
:
Union
[
Callable
,
str
,
None
]
=
None
doc_to_audio
:
Union
[
Callable
,
str
]
=
None
doc_to_audio
:
Union
[
Callable
,
str
,
None
]
=
None
unsafe_code
:
bool
=
False
unsafe_code
:
bool
=
False
doc_to_choice
:
Optional
[
Union
[
Callable
,
str
,
dict
,
list
]]
=
None
doc_to_choice
:
Optional
[
Union
[
Callable
,
str
,
dict
,
list
]]
=
None
process_results
:
Optional
[
Union
[
Callable
,
str
]]
=
None
process_results
:
Optional
[
Union
[
Callable
,
str
]]
=
None
...
@@ -466,17 +481,17 @@ class Task(abc.ABC):
...
@@ -466,17 +481,17 @@ class Task(abc.ABC):
return
self
.
_config
return
self
.
_config
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
has_training_docs
(
self
):
def
has_training_docs
(
self
)
->
bool
:
"""Whether the task has a training set"""
"""Whether the task has a training set"""
pass
pass
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
has_validation_docs
(
self
):
def
has_validation_docs
(
self
)
->
bool
:
"""Whether the task has a validation set"""
"""Whether the task has a validation set"""
pass
pass
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
has_test_docs
(
self
):
def
has_test_docs
(
self
)
->
bool
:
"""Whether the task has a test set"""
"""Whether the task has a test set"""
pass
pass
...
@@ -536,7 +551,7 @@ class Task(abc.ABC):
...
@@ -536,7 +551,7 @@ class Task(abc.ABC):
"""
"""
return
self
.
_instances
return
self
.
_instances
def
fewshot_examples
(
self
,
k
,
rnd
):
def
fewshot_examples
(
self
,
k
,
rnd
)
->
Iterable
[
dict
]
:
if
self
.
_training_docs
is
None
:
if
self
.
_training_docs
is
None
:
self
.
_training_docs
=
list
(
self
.
training_docs
())
self
.
_training_docs
=
list
(
self
.
training_docs
())
...
@@ -548,11 +563,11 @@ class Task(abc.ABC):
...
@@ -548,11 +563,11 @@ class Task(abc.ABC):
)
)
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
doc_to_text
(
self
,
doc
):
def
doc_to_text
(
self
,
doc
)
->
str
:
pass
pass
@
abc
.
abstractmethod
@
abc
.
abstractmethod
def
doc_to_target
(
self
,
doc
):
def
doc_to_target
(
self
,
doc
)
->
Union
[
str
,
int
]
:
pass
pass
# not an abstractmethod because not every language-only task has to implement this
# not an abstractmethod because not every language-only task has to implement this
...
@@ -562,7 +577,7 @@ class Task(abc.ABC):
...
@@ -562,7 +577,7 @@ class Task(abc.ABC):
def
doc_to_audio
(
self
,
doc
):
def
doc_to_audio
(
self
,
doc
):
raise
NotImplementedError
raise
NotImplementedError
def
doc_to_prefix
(
self
,
doc
):
def
doc_to_prefix
(
self
,
doc
)
->
str
:
return
""
return
""
def
build_all_requests
(
def
build_all_requests
(
...
@@ -734,12 +749,12 @@ class Task(abc.ABC):
...
@@ -734,12 +749,12 @@ class Task(abc.ABC):
return
getattr
(
self
.
_config
,
key
,
None
)
return
getattr
(
self
.
_config
,
key
,
None
)
@
classmethod
@
classmethod
def
count_bytes
(
cls
,
doc
):
def
count_bytes
(
cls
,
doc
)
->
int
:
"""Used for byte-level perplexity metrics in rolling loglikelihood"""
"""Used for byte-level perplexity metrics in rolling loglikelihood"""
return
len
(
doc
.
encode
(
"utf-8"
))
return
len
(
doc
.
encode
(
"utf-8"
))
@
classmethod
@
classmethod
def
count_words
(
cls
,
doc
):
def
count_words
(
cls
,
doc
)
->
int
:
"""Downstream loglikelihood_rolling perplexity tasks with custom word boundaries should override this!"""
"""Downstream loglikelihood_rolling perplexity tasks with custom word boundaries should override this!"""
return
len
(
re
.
split
(
r
"\s+"
,
doc
))
return
len
(
re
.
split
(
r
"\s+"
,
doc
))
...
@@ -853,7 +868,7 @@ class Task(abc.ABC):
...
@@ -853,7 +868,7 @@ class Task(abc.ABC):
self
.
sampler
.
rnd
=
self
.
fewshot_rnd
self
.
sampler
.
rnd
=
self
.
fewshot_rnd
@
property
@
property
def
eval_docs
(
self
)
->
Union
[
datasets
.
Dataset
,
List
[
dict
]]:
def
eval_docs
(
self
)
->
Union
[
datasets
.
Dataset
,
Iterable
[
dict
]]:
if
self
.
has_test_docs
():
if
self
.
has_test_docs
():
return
self
.
test_docs
()
return
self
.
test_docs
()
elif
self
.
has_validation_docs
():
elif
self
.
has_validation_docs
():
...
@@ -952,7 +967,7 @@ class ConfigurableTask(Task):
...
@@ -952,7 +967,7 @@ class ConfigurableTask(Task):
if
self
.
config
.
dataset_name
is
not
None
:
if
self
.
config
.
dataset_name
is
not
None
:
self
.
DATASET_NAME
=
self
.
config
.
dataset_name
self
.
DATASET_NAME
=
self
.
config
.
dataset_name
self
.
metric_list
:
list
[
MetricConfig
]
=
self
.
_
config
.
get_metrics
()
self
.
metric_list
:
list
[
MetricConfig
]
=
self
.
config
.
get_metrics
()
self
.
download
(
self
.
config
.
dataset_kwargs
)
self
.
download
(
self
.
config
.
dataset_kwargs
)
self
.
_training_docs
=
None
self
.
_training_docs
=
None
...
@@ -1092,7 +1107,7 @@ class ConfigurableTask(Task):
...
@@ -1092,7 +1107,7 @@ class ConfigurableTask(Task):
else
:
else
:
return
False
return
False
def
training_docs
(
self
)
->
datasets
.
Dataset
:
def
training_docs
(
self
)
->
Optional
[
datasets
.
Dataset
]
:
if
self
.
has_training_docs
():
if
self
.
has_training_docs
():
if
self
.
config
.
process_docs
is
not
None
:
if
self
.
config
.
process_docs
is
not
None
:
return
self
.
config
.
process_docs
(
return
self
.
config
.
process_docs
(
...
@@ -1100,7 +1115,7 @@ class ConfigurableTask(Task):
...
@@ -1100,7 +1115,7 @@ class ConfigurableTask(Task):
)
)
return
self
.
dataset
[
self
.
config
.
training_split
]
return
self
.
dataset
[
self
.
config
.
training_split
]
def
validation_docs
(
self
)
->
datasets
.
Dataset
:
def
validation_docs
(
self
)
->
Optional
[
datasets
.
Dataset
]
:
if
self
.
has_validation_docs
():
if
self
.
has_validation_docs
():
if
self
.
config
.
process_docs
is
not
None
:
if
self
.
config
.
process_docs
is
not
None
:
return
self
.
config
.
process_docs
(
return
self
.
config
.
process_docs
(
...
@@ -1108,7 +1123,7 @@ class ConfigurableTask(Task):
...
@@ -1108,7 +1123,7 @@ class ConfigurableTask(Task):
)
)
return
self
.
dataset
[
self
.
config
.
validation_split
]
return
self
.
dataset
[
self
.
config
.
validation_split
]
def
test_docs
(
self
)
->
datasets
.
Dataset
:
def
test_docs
(
self
)
->
Optional
[
datasets
.
Dataset
]
:
if
self
.
has_test_docs
():
if
self
.
has_test_docs
():
if
self
.
config
.
process_docs
is
not
None
:
if
self
.
config
.
process_docs
is
not
None
:
return
self
.
config
.
process_docs
(
self
.
dataset
[
self
.
config
.
test_split
])
return
self
.
config
.
process_docs
(
self
.
dataset
[
self
.
config
.
test_split
])
...
@@ -1174,7 +1189,7 @@ class ConfigurableTask(Task):
...
@@ -1174,7 +1189,7 @@ class ConfigurableTask(Task):
fewshot_as_multiturn
:
bool
=
False
,
fewshot_as_multiturn
:
bool
=
False
,
chat_template
:
Optional
[
Callable
]
=
None
,
chat_template
:
Optional
[
Callable
]
=
None
,
gen_prefix
:
Optional
[
str
]
=
None
,
gen_prefix
:
Optional
[
str
]
=
None
,
)
->
Union
[
str
,
List
[
str
]]:
)
->
Union
[
str
,
List
[
str
]
,
None
]:
"""Returns a fewshot context string that is made up of a prepended description
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
...
@@ -1461,7 +1476,7 @@ class ConfigurableTask(Task):
...
@@ -1461,7 +1476,7 @@ class ConfigurableTask(Task):
else
:
else
:
raise
TypeError
raise
TypeError
def
doc_to_image
(
self
,
doc
:
Any
,
doc_to_image
=
None
)
->
Union
[
int
,
str
,
list
]:
def
doc_to_image
(
self
,
doc
:
Any
,
doc_to_image
=
None
)
->
Union
[
int
,
str
,
list
,
None
]:
if
doc_to_image
is
not
None
:
if
doc_to_image
is
not
None
:
doc_to_image
=
doc_to_image
doc_to_image
=
doc_to_image
elif
self
.
config
.
doc_to_image
is
not
None
:
elif
self
.
config
.
doc_to_image
is
not
None
:
...
@@ -1484,7 +1499,7 @@ class ConfigurableTask(Task):
...
@@ -1484,7 +1499,7 @@ class ConfigurableTask(Task):
else
:
else
:
return
None
return
None
def
doc_to_audio
(
self
,
doc
:
Any
,
doc_to_audio
=
None
)
->
Union
[
int
,
str
,
list
]:
def
doc_to_audio
(
self
,
doc
:
Any
,
doc_to_audio
=
None
)
->
Union
[
int
,
str
,
list
,
None
]:
if
doc_to_audio
is
not
None
:
if
doc_to_audio
is
not
None
:
doc_to_audio
=
doc_to_audio
doc_to_audio
=
doc_to_audio
elif
self
.
config
.
doc_to_audio
is
not
None
:
elif
self
.
config
.
doc_to_audio
is
not
None
:
...
@@ -1507,7 +1522,7 @@ class ConfigurableTask(Task):
...
@@ -1507,7 +1522,7 @@ class ConfigurableTask(Task):
else
:
else
:
return
None
return
None
def
doc_to_prefix
(
self
,
doc
):
def
doc_to_prefix
(
self
,
doc
)
->
Optional
[
str
]
:
if
(
gen_prefix
:
=
self
.
config
.
gen_prefix
)
is
not
None
:
if
(
gen_prefix
:
=
self
.
config
.
gen_prefix
)
is
not
None
:
if
gen_prefix
in
self
.
features
:
if
gen_prefix
in
self
.
features
:
return
doc
[
gen_prefix
]
return
doc
[
gen_prefix
]
...
@@ -1554,7 +1569,7 @@ class ConfigurableTask(Task):
...
@@ -1554,7 +1569,7 @@ class ConfigurableTask(Task):
arguments
=
[(
ctx
,
f
"
{
target_delimiter
}{
cont
}
"
)
for
cont
in
choices
]
arguments
=
[(
ctx
,
f
"
{
target_delimiter
}{
cont
}
"
)
for
cont
in
choices
]
# TODO: we should raise a warning telling users this will at most ~2x runtime.
# TODO: we should raise a warning telling users this will at most ~2x runtime.
if
"acc_mutual_info"
in
[
m
.
metric_name
s
for
m
in
self
.
metric_list
]:
if
"acc_mutual_info"
in
[
m
.
metric_name
for
m
in
self
.
metric_list
]:
# if we are calculating multiple choice accuracy
# if we are calculating multiple choice accuracy
# using mutual information instead of raw loglikelihood as metric, need unconditional lls.
# using mutual information instead of raw loglikelihood as metric, need unconditional lls.
...
@@ -1621,7 +1636,7 @@ class ConfigurableTask(Task):
...
@@ -1621,7 +1636,7 @@ class ConfigurableTask(Task):
return
self
.
config
.
process_results
(
doc
,
results
)
return
self
.
config
.
process_results
(
doc
,
results
)
result_dict
=
{}
result_dict
=
{}
use_metric
=
list
(
m
.
metric_name
s
for
m
in
self
.
metric_list
)
use_metric
=
list
(
m
.
metric_name
for
m
in
self
.
metric_list
)
if
self
.
OUTPUT_TYPE
==
"loglikelihood"
:
if
self
.
OUTPUT_TYPE
==
"loglikelihood"
:
results
=
results
[
0
]
results
=
results
[
0
]
ll
,
is_greedy
=
results
ll
,
is_greedy
=
results
...
@@ -1819,7 +1834,7 @@ class ConfigurableTask(Task):
...
@@ -1819,7 +1834,7 @@ class ConfigurableTask(Task):
return
getattr
(
self
.
_config
,
key
,
None
)
return
getattr
(
self
.
_config
,
key
,
None
)
@
property
@
property
def
task_name
(
self
)
->
Any
:
def
task_name
(
self
)
->
Optional
[
str
]
:
return
getattr
(
self
.
config
,
"task"
,
None
)
return
getattr
(
self
.
config
,
"task"
,
None
)
def
__repr__
(
self
):
def
__repr__
(
self
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment