Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
17223113
Commit
17223113
authored
Jul 21, 2025
by
Baber
Browse files
type hints
parent
24b7e2d6
Changes
5
Hide whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
79 additions
and
78 deletions
+79
-78
lm_eval/api/registry.py
lm_eval/api/registry.py
+1
-1
lm_eval/config/metric.py
lm_eval/config/metric.py
+5
-5
lm_eval/config/task.py
lm_eval/config/task.py
+46
-47
lm_eval/config/template.py
lm_eval/config/template.py
+21
-19
lm_eval/config/utils.py
lm_eval/config/utils.py
+6
-6
No files found.
lm_eval/api/registry.py
View file @
17223113
...
@@ -160,7 +160,7 @@ def register_aggregation(name: str):
...
@@ -160,7 +160,7 @@ def register_aggregation(name: str):
return
decorate
return
decorate
def
get_aggregation
(
name
:
str
)
->
Callable
[
[],
dict
[
str
,
Callable
]
]
|
None
:
def
get_aggregation
(
name
:
str
)
->
Callable
[
...,
Any
]
|
None
:
try
:
try
:
return
AGGREGATION_REGISTRY
[
name
]
return
AGGREGATION_REGISTRY
[
name
]
except
KeyError
:
except
KeyError
:
...
...
lm_eval/config/metric.py
View file @
17223113
from
__future__
import
annotations
from
__future__
import
annotations
from
collections.abc
import
Callable
from
collections.abc
import
Callable
,
Mapping
from
dataclasses
import
dataclass
from
dataclasses
import
dataclass
from
functools
import
cached_property
from
functools
import
cached_property
from
typing
import
Any
from
typing
import
Any
...
@@ -12,7 +12,7 @@ class MetricConfig:
...
@@ -12,7 +12,7 @@ class MetricConfig:
name
:
str
name
:
str
fn
:
Callable
|
None
=
None
fn
:
Callable
|
None
=
None
kwargs
:
dict
|
None
=
None
kwargs
:
Mapping
[
str
,
Any
]
|
None
=
None
aggregation_fn
:
Callable
|
None
=
None
aggregation_fn
:
Callable
|
None
=
None
higher_is_better
:
bool
=
True
higher_is_better
:
bool
=
True
hf_evaluate
:
bool
=
False
hf_evaluate
:
bool
=
False
...
@@ -23,7 +23,7 @@ class MetricConfig:
...
@@ -23,7 +23,7 @@ class MetricConfig:
return
self
.
name
return
self
.
name
@
cached_property
@
cached_property
def
aggregation
(
self
)
->
Callable
:
def
aggregation
(
self
)
->
Callable
[...,
Any
]
|
None
:
from
lm_eval.api.registry
import
get_aggregation
from
lm_eval.api.registry
import
get_aggregation
if
self
.
aggregation_fn
is
None
:
if
self
.
aggregation_fn
is
None
:
...
@@ -31,7 +31,7 @@ class MetricConfig:
...
@@ -31,7 +31,7 @@ class MetricConfig:
return
self
.
aggregation_fn
return
self
.
aggregation_fn
@
cached_property
@
cached_property
def
_higher_is_better
(
self
)
->
bool
:
def
_higher_is_better
(
self
)
->
bool
|
None
:
from
lm_eval.api.registry
import
is_higher_better
from
lm_eval.api.registry
import
is_higher_better
if
self
.
higher_is_better
is
None
:
if
self
.
higher_is_better
is
None
:
...
@@ -42,7 +42,7 @@ class MetricConfig:
...
@@ -42,7 +42,7 @@ class MetricConfig:
"""Calculates the metric using the provided function and arguments."""
"""Calculates the metric using the provided function and arguments."""
if
self
.
fn
is
None
:
if
self
.
fn
is
None
:
raise
ValueError
(
f
"Metric function for
{
self
.
name
}
is not defined."
)
raise
ValueError
(
f
"Metric function for
{
self
.
name
}
is not defined."
)
return
self
.
fn
(
*
args
,
**
{
**
self
.
kwargs
,
**
kwargs
})
return
self
.
fn
(
*
args
,
**
{
**
(
self
.
kwargs
or
{})
,
**
kwargs
})
def
compute_aggregation
(
self
,
values
:
list
[
Any
])
->
Any
:
def
compute_aggregation
(
self
,
values
:
list
[
Any
])
->
Any
:
"""Computes the aggregation of the metric values."""
"""Computes the aggregation of the metric values."""
...
...
lm_eval/config/task.py
View file @
17223113
from
__future__
import
annotations
import
logging
import
logging
from
collections.abc
import
Iterable
from
dataclasses
import
asdict
,
dataclass
,
field
from
dataclasses
import
asdict
,
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Callable
,
Iterable
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Callable
from
lm_eval.api.filter
import
FilterEnsemble
from
lm_eval.api.filter
import
FilterEnsemble
from
lm_eval.api.instance
import
OutputType
from
lm_eval.api.instance
import
OutputType
...
@@ -20,8 +23,8 @@ class RepeatConfig:
...
@@ -20,8 +23,8 @@ class RepeatConfig:
"""Encapsulates information about a single repeat."""
"""Encapsulates information about a single repeat."""
repeats
:
int
=
1
repeats
:
int
=
1
metric_fn
:
Union
[
str
,
Callable
]
=
"pass@N"
metric_fn
:
str
|
Callable
=
"pass@N"
kwargs
:
Optional
[
dict
]
=
field
(
default_factory
=
dict
)
kwargs
:
dict
|
None
=
field
(
default_factory
=
dict
)
@
dataclass
@
dataclass
...
@@ -38,11 +41,11 @@ class FewshotConfig:
...
@@ -38,11 +41,11 @@ class FewshotConfig:
# hack: this returns task.config.num_fewshot
# hack: this returns task.config.num_fewshot
# to keep in sync as it is runtime-modified
# to keep in sync as it is runtime-modified
num_fewshot
:
Callable
[[],
int
]
num_fewshot
:
Callable
[[],
int
]
split
:
Optional
[
str
]
=
None
split
:
str
|
None
=
None
sampler
:
Union
[
str
,
Callable
]
=
"default"
sampler
:
str
|
Callable
=
"default"
samples
:
Union
[
Callable
[[],
list
[
dict
]]
,
list
[
dict
]
,
None
]
=
None
samples
:
Callable
[[],
list
[
dict
]]
|
list
[
dict
]
|
None
=
None
process_docs
:
Optional
[
Callable
[[
list
[
dict
]],
Iterable
[
dict
]]
]
=
None
process_docs
:
Callable
[[
list
[
dict
]],
Iterable
[
dict
]]
|
None
=
None
fewshot_indices
:
Optional
[
list
[
int
]
]
=
None
fewshot_indices
:
list
[
int
]
|
None
=
None
rnd
:
int
=
field
(
init
=
False
,
default
=
False
)
rnd
:
int
=
field
(
init
=
False
,
default
=
False
)
def
__post_init__
(
self
)
->
None
:
def
__post_init__
(
self
)
->
None
:
...
@@ -65,22 +68,20 @@ class FewshotConfig:
...
@@ -65,22 +68,20 @@ class FewshotConfig:
def
_get_raw_docs
(
def
_get_raw_docs
(
self
,
dataset
self
,
dataset
)
->
Union
[
list
[
dict
]
,
Callable
[[],
Iterable
[
dict
]]
,
None
]
:
)
->
list
[
dict
]
|
Callable
[[],
Iterable
[
dict
]]
|
None
:
"""Get raw documents from configured source."""
"""Get raw documents from configured source."""
if
self
.
split
is
not
None
:
if
self
.
split
is
not
None
:
return
dataset
[
self
.
split
]
return
dataset
[
self
.
split
]
if
self
.
samples
is
not
None
:
if
self
.
samples
is
not
None
:
if
isinstance
(
self
.
samples
,
list
):
if
isinstance
(
self
.
samples
,
list
)
or
callable
(
self
.
samples
):
return
self
.
samples
elif
callable
(
self
.
samples
):
return
self
.
samples
return
self
.
samples
else
:
else
:
raise
TypeError
(
raise
TypeError
(
"samples must be either a list of dicts or a callable returning a list"
"samples must be either a list of dicts or a callable returning a list"
)
)
def
get_docs
(
self
,
dataset
)
->
Optional
[
Iterable
[
dict
]
]
:
def
get_docs
(
self
,
dataset
)
->
Iterable
[
dict
]
|
None
:
"""Get processed documents from configured source."""
"""Get processed documents from configured source."""
raw_docs
=
self
.
_get_raw_docs
(
dataset
)
raw_docs
=
self
.
_get_raw_docs
(
dataset
)
if
raw_docs
is
None
:
if
raw_docs
is
None
:
...
@@ -100,8 +101,8 @@ class FewshotConfig:
...
@@ -100,8 +101,8 @@ class FewshotConfig:
return
self
.
sampler
return
self
.
sampler
def
init_sampler
(
def
init_sampler
(
self
,
docs
:
list
[
dict
],
task
:
"
Task
"
,
rnd
=
None
,
fewshot_indices
=
None
self
,
docs
:
list
[
dict
],
task
:
Task
,
rnd
=
None
,
fewshot_indices
=
None
)
->
"
ContextSampler
"
:
)
->
ContextSampler
:
"""Initialize the sampler with the given documents and task."""
"""Initialize the sampler with the given documents and task."""
if
rnd
is
None
:
if
rnd
is
None
:
raise
ValueError
(
raise
ValueError
(
...
@@ -120,49 +121,49 @@ class FewshotConfig:
...
@@ -120,49 +121,49 @@ class FewshotConfig:
@
dataclass
@
dataclass
class
TaskConfig
(
dict
):
class
TaskConfig
(
dict
):
# task naming/registry
# task naming/registry
task
:
Optional
[
str
]
=
None
task
:
str
|
None
=
None
task_alias
:
Optional
[
str
]
=
None
task_alias
:
str
|
None
=
None
tag
:
Optional
[
Union
[
str
,
list
]]
=
None
tag
:
str
|
list
|
None
=
None
# HF dataset options.
# HF dataset options.
# which dataset to use,
# which dataset to use,
# and what splits for what purpose
# and what splits for what purpose
custom_dataset
:
Optional
[
Callable
]
=
None
custom_dataset
:
Callable
|
None
=
None
dataset_path
:
Optional
[
str
]
=
None
dataset_path
:
str
|
None
=
None
dataset_name
:
Optional
[
str
]
=
None
dataset_name
:
str
|
None
=
None
dataset_kwargs
:
Optional
[
dict
]
=
field
(
default_factory
=
dict
)
dataset_kwargs
:
dict
|
None
=
field
(
default_factory
=
dict
)
training_split
:
Optional
[
str
]
=
None
training_split
:
str
|
None
=
None
validation_split
:
Optional
[
str
]
=
None
validation_split
:
str
|
None
=
None
test_split
:
Optional
[
str
]
=
None
test_split
:
str
|
None
=
None
fewshot_split
:
Optional
[
str
]
=
(
fewshot_split
:
str
|
None
=
(
None
# TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?)
None
# TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?)
)
)
# formatting / prompting options.
# formatting / prompting options.
# see docs/advanced_task_guide.md for more info
# see docs/advanced_task_guide.md for more info
process_docs
:
Optional
[
Callable
]
=
None
process_docs
:
Callable
|
None
=
None
doc_to_text
:
Optional
[
Union
[
Callable
,
str
]]
=
None
doc_to_text
:
Callable
|
str
|
None
=
None
doc_to_target
:
Optional
[
Union
[
Callable
,
str
]]
=
None
doc_to_target
:
Callable
|
str
|
None
=
None
doc_to_image
:
Union
[
Callable
,
str
,
None
]
=
None
doc_to_image
:
Callable
|
str
|
None
=
None
doc_to_audio
:
Union
[
Callable
,
str
,
None
]
=
None
doc_to_audio
:
Callable
|
str
|
None
=
None
unsafe_code
:
bool
=
False
unsafe_code
:
bool
=
False
doc_to_choice
:
Optional
[
Union
[
Callable
,
str
,
dict
,
list
]]
=
None
doc_to_choice
:
Callable
|
str
|
dict
|
list
|
None
=
None
process_results
:
Optional
[
Union
[
Callable
,
str
]]
=
None
process_results
:
Callable
|
str
|
None
=
None
use_prompt
:
Optional
[
str
]
=
None
use_prompt
:
str
|
None
=
None
description
:
str
=
""
description
:
str
=
""
target_delimiter
:
str
=
" "
target_delimiter
:
str
=
" "
fewshot_delimiter
:
str
=
"
\n\n
"
fewshot_delimiter
:
str
=
"
\n\n
"
fewshot_config
:
Optional
[
dict
]
=
None
fewshot_config
:
dict
|
None
=
None
# runtime configuration options
# runtime configuration options
num_fewshot
:
Optional
[
int
]
=
0
num_fewshot
:
int
|
None
=
0
generation_kwargs
:
Optional
[
dict
]
=
None
generation_kwargs
:
dict
|
None
=
None
# scoring options
# scoring options
metric_list
:
Optional
[
list
]
=
None
metric_list
:
list
|
None
=
None
output_type
:
OutputType
=
"generate_until"
output_type
:
OutputType
=
"generate_until"
repeats
:
int
=
1
repeats
:
int
=
1
filter_list
:
Optional
[
list
[
dict
]
]
=
None
filter_list
:
list
[
dict
]
|
None
=
None
should_decontaminate
:
bool
=
False
should_decontaminate
:
bool
=
False
doc_to_decontamination_query
:
Optional
[
str
]
=
None
doc_to_decontamination_query
:
str
|
None
=
None
gen_prefix
:
Optional
[
str
]
=
None
gen_prefix
:
str
|
None
=
None
metadata
:
Optional
[
dict
]
=
field
(
metadata
:
dict
|
None
=
field
(
default_factory
=
dict
default_factory
=
dict
)
# by default, not used in the code. allows for users to pass arbitrary info to tasks
)
# by default, not used in the code. allows for users to pass arbitrary info to tasks
...
@@ -215,9 +216,7 @@ class TaskConfig(dict):
...
@@ -215,9 +216,7 @@ class TaskConfig(dict):
fewshot_indices
=
_fewshot_cfg
.
get
(
"fewshot_indices"
,
None
),
fewshot_indices
=
_fewshot_cfg
.
get
(
"fewshot_indices"
,
None
),
)
)
def
_get_metric
(
def
_get_metric
(
self
,
metric_list
:
list
[
dict
]
|
None
=
None
)
->
list
[
MetricConfig
]:
self
,
metric_list
:
Optional
[
list
[
dict
]]
=
None
)
->
list
[
"MetricConfig"
]:
from
lm_eval.api.registry
import
(
from
lm_eval.api.registry
import
(
AGGREGATION_REGISTRY
,
AGGREGATION_REGISTRY
,
DEFAULT_METRIC_REGISTRY
,
DEFAULT_METRIC_REGISTRY
,
...
@@ -314,7 +313,7 @@ class TaskConfig(dict):
...
@@ -314,7 +313,7 @@ class TaskConfig(dict):
return
metrics
return
metrics
@
property
@
property
def
get_filters
(
self
)
->
list
[
"
FilterConfig
"
]:
def
get_filters
(
self
)
->
list
[
FilterConfig
]:
from
lm_eval.filters
import
build_filter_ensemble
from
lm_eval.filters
import
build_filter_ensemble
if
not
self
.
filter_list
:
if
not
self
.
filter_list
:
...
@@ -354,7 +353,7 @@ class TaskConfig(dict):
...
@@ -354,7 +353,7 @@ class TaskConfig(dict):
return
x
return
x
@
classmethod
@
classmethod
def
from_yaml
(
cls
,
data
:
dict
)
->
"
TaskConfig
"
:
def
from_yaml
(
cls
,
data
:
dict
)
->
TaskConfig
:
"""Create a TaskConfig instance from a YAML-like dictionary."""
"""Create a TaskConfig instance from a YAML-like dictionary."""
return
cls
(
**
data
)
return
cls
(
**
data
)
...
...
lm_eval/config/template.py
View file @
17223113
from
__future__
import
annotations
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Callable
,
Optional
,
Union
from
typing
import
TYPE_CHECKING
,
Callable
if
TYPE_CHECKING
:
if
TYPE_CHECKING
:
...
@@ -11,19 +13,19 @@ class TemplateConfig:
...
@@ -11,19 +13,19 @@ class TemplateConfig:
"""Encapsulates information about a template."""
"""Encapsulates information about a template."""
template
:
str
template
:
str
doc_to_text
:
Union
[
str
,
Callable
[[
dict
],
str
]
]
doc_to_text
:
str
|
Callable
[[
dict
],
str
]
doc_to_choice
:
Union
[
str
,
list
,
Callable
[[
dict
],
list
]
]
doc_to_choice
:
str
|
list
|
Callable
[[
dict
],
list
]
doc_to_target
:
Union
[
int
,
Callable
[[
dict
],
int
]
]
doc_to_target
:
int
|
Callable
[[
dict
],
int
]
description
:
str
description
:
str
context_prefix
:
str
context_prefix
:
str
prefix_delimiter
:
str
prefix_delimiter
:
str
context_delimiter
:
str
context_delimiter
:
str
answer_suffix
:
str
answer_suffix
:
str
target_delimiter
:
str
target_delimiter
:
str
choice_format
:
Optional
[
str
]
choice_format
:
str
|
None
choice_delimiter
:
Optional
[
str
]
choice_delimiter
:
str
|
None
fewshot_delimiter
:
str
fewshot_delimiter
:
str
metric_list
:
Optional
[
Union
[
list
[
str
]
,
list
[
"
MetricConfig
"
]]]
=
field
(
metric_list
:
list
[
str
]
|
list
[
MetricConfig
]
|
None
=
field
(
default_factory
=
lambda
:
[
"acc"
,
"acc_norm"
]
default_factory
=
lambda
:
[
"acc"
,
"acc_norm"
]
)
)
...
@@ -40,19 +42,19 @@ class MCQTemplateConfig:
...
@@ -40,19 +42,19 @@ class MCQTemplateConfig:
Answer:` doc_to_choice(doc)` for each choice.
Answer:` doc_to_choice(doc)` for each choice.
"""
"""
doc_to_text
:
Union
[
str
,
Callable
[[
dict
],
str
]
]
doc_to_text
:
str
|
Callable
[[
dict
],
str
]
doc_to_choice
:
Union
[
str
,
list
,
Callable
[[
dict
],
list
]
]
doc_to_choice
:
str
|
list
|
Callable
[[
dict
],
list
]
doc_to_target
:
Union
[
int
,
Callable
[[
dict
],
int
]
]
doc_to_target
:
int
|
Callable
[[
dict
],
int
]
template
=
"mcq"
template
=
"mcq"
context_prefix
:
str
=
"Question:"
context_prefix
:
str
=
"Question:"
prefix_delimiter
:
str
=
" "
prefix_delimiter
:
str
=
" "
context_delimiter
:
str
=
"
\n
"
context_delimiter
:
str
=
"
\n
"
answer_suffix
:
str
=
"Answer:"
answer_suffix
:
str
=
"Answer:"
target_delimiter
:
str
=
"
\n
"
target_delimiter
:
str
=
"
\n
"
choice_format
:
Optional
[
str
]
=
"letters"
choice_format
:
str
|
None
=
"letters"
choice_delimiter
:
Optional
[
str
]
=
"
\n
"
choice_delimiter
:
str
|
None
=
"
\n
"
fewshot_delimiter
:
str
=
"
\n\n
"
fewshot_delimiter
:
str
=
"
\n\n
"
metric_list
:
Optional
[
list
[
"
MetricConfig
"
]]
=
field
(
default_factory
=
lambda
:
[
"acc"
])
metric_list
:
list
[
MetricConfig
]
|
None
=
field
(
default_factory
=
lambda
:
[
"acc"
])
@
dataclass
@
dataclass
...
@@ -63,9 +65,9 @@ class ClozeTemplateConfig:
...
@@ -63,9 +65,9 @@ class ClozeTemplateConfig:
Answer:` <doc_to_target(doc)>`
Answer:` <doc_to_target(doc)>`
"""
"""
doc_to_text
:
Union
[
str
,
Callable
[[
dict
],
str
]
]
doc_to_text
:
str
|
Callable
[[
dict
],
str
]
doc_to_choice
:
Union
[
str
,
list
,
Callable
[[
dict
],
list
]
]
doc_to_choice
:
str
|
list
|
Callable
[[
dict
],
list
]
doc_to_target
:
Union
[
int
,
Callable
[[
dict
],
int
]
]
doc_to_target
:
int
|
Callable
[[
dict
],
int
]
template
:
str
=
"cloze"
template
:
str
=
"cloze"
description
:
str
=
""
description
:
str
=
""
context_prefix
:
str
=
"Question:"
context_prefix
:
str
=
"Question:"
...
@@ -73,9 +75,9 @@ class ClozeTemplateConfig:
...
@@ -73,9 +75,9 @@ class ClozeTemplateConfig:
context_delimiter
:
str
=
"
\n
"
context_delimiter
:
str
=
"
\n
"
answer_suffix
:
str
=
"Answer:"
answer_suffix
:
str
=
"Answer:"
target_delimiter
:
str
=
" "
target_delimiter
:
str
=
" "
choice_format
:
Optional
[
str
]
=
None
choice_format
:
str
|
None
=
None
choice_delimiter
:
Optional
[
str
]
=
None
choice_delimiter
:
str
|
None
=
None
fewshot_delimiter
:
str
=
"
\n\n
"
fewshot_delimiter
:
str
=
"
\n\n
"
metric_list
:
Optional
[
list
[
"
MetricConfig
"
]]
=
field
(
metric_list
:
list
[
MetricConfig
]
|
None
=
field
(
default_factory
=
lambda
:
[
"acc"
,
"acc_norm"
]
default_factory
=
lambda
:
[
"acc"
,
"acc_norm"
]
)
)
lm_eval/config/utils.py
View file @
17223113
from
__future__
import
annotations
from
inspect
import
getsource
from
inspect
import
getsource
from
typing
import
Any
,
Callable
,
Union
from
typing
import
Any
,
Callable
def
serialize_callable
(
def
serialize_callable
(
value
:
Union
[
Callable
[...,
Any
]
,
str
]
,
keep_callable
=
False
value
:
Callable
[...,
Any
]
|
str
,
keep_callable
=
False
)
->
Union
[
Callable
[...,
Any
]
,
str
]
:
)
->
Callable
[...,
Any
]
|
str
:
"""Serializes a given function or string.
"""Serializes a given function or string.
If 'keep_callable' is True, the original callable is returned.
If 'keep_callable' is True, the original callable is returned.
...
@@ -20,9 +22,7 @@ def serialize_callable(
...
@@ -20,9 +22,7 @@ def serialize_callable(
return
str
(
value
)
return
str
(
value
)
def
maybe_serialize
(
def
maybe_serialize
(
val
:
Callable
|
Any
,
keep_callable
=
False
)
->
Callable
|
Any
:
val
:
Union
[
Callable
,
Any
],
keep_callable
=
False
)
->
Union
[
Callable
,
Any
]:
"""Conditionally serializes a value if it is callable."""
"""Conditionally serializes a value if it is callable."""
return
(
return
(
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment