gaoqiong / lm-evaluation-harness

Commit 227f1a74, authored Jul 08, 2025 by Baber
refactor: improve dataset and metric handling in TaskConfig
parent 3b4d0af1
Showing 4 changed files with 113 additions and 55 deletions:
  lm_eval/api/group.py    +6  -0
  lm_eval/api/task.py     +20 -15
  lm_eval/config/task.py  +38 -38
  tests/test_tasks.py     +49 -2
lm_eval/api/group.py

@@ -29,6 +29,7 @@ class GroupConfig(dict):
     aggregate_metric_list: Optional[Union[List[AggMetricConfig], AggMetricConfig, dict]] = None
+    version: Optional[str] = None
     metadata: Optional[dict] = (
         None  # by default, not used in the code. allows for users to pass arbitrary info to tasks
     )
@@ -48,6 +49,11 @@ class GroupConfig(dict):
                 AggMetricConfig(**item) if isinstance(item, dict) else item
                 for item in self.aggregate_metric_list
             ]
+        self.version = (
+            self.version or self.metadata.get("version", "1.0")
+            if self.metadata
+            else "1.0"
+        )
 
     def to_dict(self, keep_callable: bool = False) -> dict:
         """dumps the current config as a dictionary object, as a printable format.
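A minimal standalone sketch of the version-resolution rule added above (a reproduction, not the library class): an explicit `version` wins, otherwise `metadata["version"]`, otherwise `"1.0"`. Note that the conditional expression binds more loosely than `or`, so with the expression as reconstructed here an empty or missing metadata dict yields `"1.0"` regardless of the other values.

    from typing import Optional

    def resolve_version(version: Optional[str], metadata: Optional[dict]) -> str:
        # Mirrors the expression in GroupConfig.__post_init__ shown above:
        # (version or metadata.get("version", "1.0")) if metadata else "1.0"
        return version or metadata.get("version", "1.0") if metadata else "1.0"

    assert resolve_version("2.0", {"version": "3.0"}) == "2.0"  # explicit version wins
    assert resolve_version(None, {"version": "3.0"}) == "3.0"   # falls back to metadata
    assert resolve_version(None, {}) == "1.0"                   # no metadata: default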
lm_eval/api/task.py

@@ -639,7 +639,7 @@ class ConfigurableTask(Task):
         if self.config.dataset_name is not None:
             self.DATASET_NAME = self.config.dataset_name
 
-        self.metric_list: list[MetricConfig] = self.config.get_metrics
+        # self.metric_list: list[MetricConfig] = self.config.get_metrics
         self.download(self.config.dataset_kwargs)
         self._training_docs = None
@@ -655,7 +655,10 @@ class ConfigurableTask(Task):
         else:
             self.prompt = None
 
-        if self.config.fewshot_cfg.num() > 0 and self.fewshot_docs() is not None:
+        if (
+            self.config.fewshot_cfg.num_fewshot() > 0
+            and self.fewshot_docs() is not None
+        ):
             self.fewshot_rnd = random.Random()
             self.sampler = self.config.fewshot_cfg.init_sampler(
                 list(self.fewshot_docs()), self, rnd=self.fewshot_rnd
@@ -724,21 +727,23 @@ class ConfigurableTask(Task):
     ) -> None:
         from packaging.version import parse as vparse
 
+        self.config.dataset_kwargs, self.config.metadata = (
+            self.config.dataset_kwargs or {},
+            self.config.metadata or {},
+        )
         if dataset_kwargs and vparse(datasets.__version__) >= vparse("4.0.0"):
             dataset_kwargs.pop("trust_remote_code", None)
-        if isinstance(self.config.custom_dataset, Callable):
+        if isinstance(df := self.config.custom_dataset, Callable):
             eval_logger.warning(
                 f"{self.config.task}: Custom kwargs can be passed to `--metadata` in console (as json string) or to the TaskManager."
                 + "\nFor example --metadata='{\"max_seq_lengths\":[4096, 8192]}'. For details see task Readme."
             )
-            self.dataset = self.config.custom_dataset(
-                **(self.config.metadata or {}), **(self.config.dataset_kwargs or {})
-            )
+            self.dataset = df(**(self.config.dataset_kwargs | self.config.metadata))
         else:
             self.dataset = datasets.load_dataset(
-                path=self.DATASET_PATH,
-                name=self.DATASET_NAME,
-                **dataset_kwargs if dataset_kwargs is not None else {},
+                path=self.config.dataset_path,
+                name=self.config.dataset_name,
+                **self.config.dataset_kwargs,
             )
 
     def has_training_docs(self) -> bool:
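In the custom-dataset branch above, the two keyword sources are now merged with the dict-union operator before the loader is called. A short sketch of that merge semantics (the loader and key names below are hypothetical, not part of the harness): with `a | b`, keys from the right operand win, so in the reconstructed call values passed via `--metadata` take precedence over the task's own `dataset_kwargs`.

    # Hypothetical stand-in for config.custom_dataset; a real loader would
    # return a datasets.DatasetDict, but a plain dict keeps the sketch runnable.
    def my_custom_dataset(**kwargs) -> dict:
        return {"received": kwargs}

    dataset_kwargs = {"split": "test", "max_seq_lengths": [2048]}
    metadata = {"max_seq_lengths": [4096, 8192]}  # e.g. --metadata='{"max_seq_lengths": [4096, 8192]}'

    # Right operand wins on key collisions in a dict union (Python 3.9+).
    dataset = my_custom_dataset(**(dataset_kwargs | metadata))
    assert dataset["received"]["max_seq_lengths"] == [4096, 8192]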
@@ -975,7 +980,7 @@ class ConfigurableTask(Task):
         """Iterates over FilterEnsembles and applies them to instances"""
         if hasattr(self, "_filters"):
             for f in self._filters:
-                f.apply(self._instances)
+                f.ensemble.apply(self._instances)
         else:
             eval_logger.warning("No filter defined, passing through instances")
             return self._instances
@@ -1214,7 +1219,7 @@ class ConfigurableTask(Task):
             arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]
 
         # TODO: we should raise a warning telling users this will at most ~2x runtime.
-        if "acc_mutual_info" in [m.metric_name for m in self.metric_list]:
+        if "acc_mutual_info" in [m.metric_name for m in self.config._metric_list]:
             # if we are calculating multiple choice accuracy
             # using mutual information instead of raw loglikelihood as metric, need unconditional lls.
@@ -1281,7 +1286,7 @@ class ConfigurableTask(Task):
             return self.config.process_results(doc, results)
 
         result_dict = {}
-        use_metric = list(m.metric_name for m in self.metric_list)
+        use_metric = list(m.metric_name for m in self.config._metric_list)
         if self.OUTPUT_TYPE == "loglikelihood":
             results = results[0]
             ll, is_greedy = results
@@ -1407,7 +1412,7 @@ class ConfigurableTask(Task):
                     # cast gold to the same type as result
                     gold = type(result)(gold)
 
-            for metric in self.metric_list:
+            for metric in self.config._metric_list:
                 if self.multiple_target:
                     # in the case where we have multiple targets,
                     # return true if any are true
@@ -1470,10 +1475,10 @@ class ConfigurableTask(Task):
         return result_dict
 
     def aggregation(self) -> dict:
-        return {k.name: k.aggregation_fn for k in self.metric_list}
+        return {k.name: k.aggregation_fn for k in self.config._metric_list}
 
     def higher_is_better(self) -> dict:
-        return {k.name: k.higher_is_better for k in self.metric_list}
+        return {k.name: k.higher_is_better for k in self.config._metric_list}
 
     def get_config(self, key: str) -> Any:
         return getattr(self._config, key, None)
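The common thread across these hunks is that per-metric lookups in ConfigurableTask now read from `self.config._metric_list` instead of a task-level `metric_list` attribute. A minimal sketch of how the `aggregation()` and `higher_is_better()` dictionaries are derived (using a stand-in dataclass, not the real MetricConfig):

    from dataclasses import dataclass
    from statistics import mean
    from typing import Callable

    @dataclass
    class MetricCfg:  # stand-in for lm_eval.config.metric.MetricConfig
        name: str
        aggregation_fn: Callable
        higher_is_better: bool

    _metric_list = [
        MetricCfg("acc", mean, True),
        MetricCfg("perplexity", mean, False),
    ]

    # Mirrors ConfigurableTask.aggregation() / higher_is_better() after this commit.
    aggregation = {m.name: m.aggregation_fn for m in _metric_list}
    higher_is_better = {m.name: m.higher_is_better for m in _metric_list}

    assert higher_is_better == {"acc": True, "perplexity": False}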
lm_eval/config/task.py

@@ -2,6 +2,7 @@ import logging
 from dataclasses import asdict, dataclass, field
 from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union
 
+from lm_eval.api.filter import FilterEnsemble
 from lm_eval.api.instance import OutputType
 from lm_eval.config.metric import MetricConfig
 from lm_eval.config.utils import maybe_serialize
@@ -10,7 +11,6 @@ from lm_eval.config.utils import maybe_serialize
 if TYPE_CHECKING:
     from lm_eval.api.samplers import ContextSampler
     from lm_eval.api.task import Task
-    from lm_eval.filters import FilterEnsemble
 
 
 eval_logger = logging.getLogger(__name__)
@@ -29,8 +29,8 @@ class FilterConfig:
     """Encapsulates information about a single filter."""
 
     name: str
-    fn: Optional[Callable] = None
-    kwargs: Optional[dict] = field(default_factory=dict)
+    ensemble: FilterEnsemble
+    metric_list: list[MetricConfig]
 
 
 @dataclass
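A FilterConfig now pairs a named FilterEnsemble with the metrics to compute on that filter's output, rather than carrying a raw `fn`/`kwargs` pair. A sketch of constructing the default entry, mirroring the `get_filters` fallback shown further down (an illustration, not library code):

    from lm_eval.config.task import FilterConfig
    from lm_eval.filters import build_filter_ensemble

    default_filter = FilterConfig(
        name="none",
        ensemble=build_filter_ensemble("none", [("take_first", None)]),
        metric_list=[],  # get_filters fills this via TaskConfig._get_metric()
    )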
@@ -117,21 +117,10 @@ class FewshotConfig:
     )
 
 
-@dataclass
-class DatasetConfig:
-    """Encapsulates information about a dataset."""
-
-    path: Optional[str] = None
-    name: Optional[str] = None
-    kwargs: Optional[dict] = field(default_factory=dict)
-    custom: Optional[Callable] = None
-    metadata: Optional[dict] = field(default_factory=dict)
-
-
 @dataclass
 class TaskConfig(dict):
     # task naming/registry
-    task: str
+    task: Optional[str] = None
     task_alias: Optional[str] = None
     tag: Optional[Union[str, list]] = None
     # HF dataset options.
@@ -140,7 +129,7 @@ class TaskConfig(dict):
     custom_dataset: Optional[Callable] = None
     dataset_path: Optional[str] = None
     dataset_name: Optional[str] = None
-    dataset_kwargs: Optional[dict] = None
+    dataset_kwargs: Optional[dict] = field(default_factory=dict)
     training_split: Optional[str] = None
     validation_split: Optional[str] = None
     test_split: Optional[str] = None
@@ -177,9 +166,9 @@ class TaskConfig(dict):
         default_factory=dict
     )  # by default, not used in the code. allows for users to pass arbitrary info to tasks
-    _metric_list: list[MetricConfig] = None
+    _metric_list: list[MetricConfig] = field(default_factory=list)
     _filter_list: list[FilterConfig] = None
-    ds_cfg: DatasetConfig = field(init=False)
+    # ds_cfg: DatasetConfig = field(init=False)
     fewshot_cfg: FewshotConfig = field(init=False)
 
     def __post_init__(self) -> None:
@@ -215,18 +204,10 @@ class TaskConfig(dict):
             eval_logger.warning(
                 f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}"
             )
 
-        # ---setup dataset config--- #
-        self.ds_cfg = DatasetConfig(
-            path=self.dataset_path,
-            name=self.dataset_name,
-            kwargs=self.dataset_kwargs,
-            custom=self.custom_dataset,
-            metadata=self.metadata or {},
-        )
         # ---setup fewshot config--- #
         _fewshot_cfg = self.fewshot_config if self.fewshot_config is not None else {}
         self.fewshot_cfg = FewshotConfig(
-            num_fewshot=lambda: self.num_fewshot or _fewshot_cfg["num_fewshot"],
+            num_fewshot=lambda: self.num_fewshot or _fewshot_cfg.get("num_fewshot", 0),
             split=self.fewshot_split,
             sampler=_fewshot_cfg.get("sampler", "default"),
             samples=_fewshot_cfg.get("samples", None),
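The few-shot count is now read with `.get("num_fewshot", 0)` instead of indexing, so a task config without that key resolves to zero shots rather than raising KeyError when the lambda is evaluated. A standalone sketch of the resolution order (names here are illustrative):

    from typing import Optional

    def resolve_num_fewshot(num_fewshot: Optional[int], fewshot_config: Optional[dict]) -> int:
        _fewshot_cfg = fewshot_config if fewshot_config is not None else {}
        # mirrors: lambda: self.num_fewshot or _fewshot_cfg.get("num_fewshot", 0)
        return num_fewshot or _fewshot_cfg.get("num_fewshot", 0)

    assert resolve_num_fewshot(5, None) == 5                   # explicit num_fewshot wins
    assert resolve_num_fewshot(None, {"num_fewshot": 3}) == 3  # falls back to fewshot_config
    assert resolve_num_fewshot(None, None) == 0                # previously _fewshot_cfg["num_fewshot"] raised KeyError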
@@ -234,8 +215,9 @@ class TaskConfig(dict):
             fewshot_indices=_fewshot_cfg.get("fewshot_indices", None),
         )
 
-    @property
-    def get_metrics(self) -> list["MetricConfig"]:
+    def _get_metric(
+        self, metric_list: Optional[list[dict]] = None
+    ) -> list["MetricConfig"]:
         from lm_eval.api.registry import (
             AGGREGATION_REGISTRY,
             DEFAULT_METRIC_REGISTRY,
@@ -245,8 +227,10 @@ class TaskConfig(dict):
             is_higher_better,
         )
 
+        # if metric_list defined inside a filter, use that; otherwise use the task's metric_list
+        metric_list = metric_list or self.metric_list
         metrics = []
-        if self.metric_list is None:
+        if not metric_list:
             # ---------- 1. If no metrics defined, use defaults for output type ----------
             _metric_list = DEFAULT_METRIC_REGISTRY[self.output_type]
             eval_logger.info(
@@ -263,7 +247,7 @@ class TaskConfig(dict):
             )
         else:
             # ---------- 2. Process user-defined metrics from config ----------
-            for metric_config in self.metric_list:
+            for metric_config in metric_list:
                 metric_name = metric_config["metric"]
                 _metric_fn_kwargs = {
                     key: metric_config[key]
@@ -324,34 +308,50 @@ class TaskConfig(dict):
                         hf_evaluate=_hf_evaluate_metric,
                     )
                 )
 
+        for m in metrics:
+            if m not in self._metric_list:
+                self._metric_list.append(m)
         return metrics
 
-    @property
-    def get_filters(self) -> list["FilterEnsemble"]:
+    def get_filters(self) -> list["FilterConfig"]:
         from lm_eval.filters import build_filter_ensemble
 
         if not self.filter_list:
             eval_logger.debug(
                 "No custom filters defined; falling back to 'take_first' for handling repeats."
             )
-            return [build_filter_ensemble("none", [("take_first", None)])]
+            return [
+                FilterConfig(
+                    name="none",
+                    ensemble=build_filter_ensemble("none", [("take_first", None)]),
+                    metric_list=self._get_metric(metric_list=None),
+                )
+            ]
         else:
 
             def _strip_fn(d: dict) -> tuple[str, dict]:
-                return d["function"], {k: v for k, v in d.items() if k != "function"}
+                return d["function"], {
+                    k: v for k, v in d.items() if k not in ["function", "metric_list"]
+                }
 
             configs = (
                 self.filter_list.values()
                 if isinstance(self.filter_list, dict)
                 else self.filter_list
            )
-            return [
-                build_filter_ensemble(
-                    filter_name=cfg["name"],
-                    components=[_strip_fn(f) for f in cfg["filter"]],
-                )
-                for cfg in configs
-            ]
+            x = [
+                FilterConfig(
+                    name=cfg["name"],
+                    ensemble=build_filter_ensemble(
+                        filter_name=cfg["name"],
+                        components=[_strip_fn(f) for f in cfg["filter"]],
+                    ),
+                    metric_list=self._get_metric(metric_list=cfg.get("metric_list")),
+                )
+                for cfg in configs
+            ]
+            return x
 
     @classmethod
     def from_yaml(cls, data: dict) -> "TaskConfig":
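With `get_filters` returning FilterConfig objects, a filter entry in a task config can carry its own `metric_list`, and `_strip_fn` now drops both "function" and "metric_list" before forwarding the remaining keys as filter kwargs. A small reproduction of that key-splitting (the filter entry below is illustrative, not taken from a shipped task):

    def _strip_fn(d: dict) -> tuple[str, dict]:
        # same shape as the helper inside TaskConfig.get_filters
        return d["function"], {
            k: v for k, v in d.items() if k not in ["function", "metric_list"]
        }

    filter_entry = {
        "function": "regex",
        "regex_pattern": r"(\d+)",
        "metric_list": [{"metric": "exact_match"}],  # consumed by _get_metric, not by the filter itself
    }

    name, kwargs = _strip_fn(filter_entry)
    assert name == "regex"
    assert kwargs == {"regex_pattern": r"(\d+)"}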
tests/test_tasks.py

@@ -46,7 +46,12 @@ def limit() -> int:
     return 10
 
 
-class BaseTasks:
+@pytest.mark.parametrize(
+    "task_class",
+    task_class(get_new_tasks_else_default()),
+    ids=lambda x: f"{x.config.task}",
+)
+class TestBaseTasks:
     """
     Base class for testing tasks
     """
@@ -160,8 +165,50 @@ class BaseTasks:
     task_class(get_new_tasks_else_default()),
     ids=lambda x: f"{x.config.task}",
 )
-class TestNewTasksElseDefault(BaseTasks):
+class TestNewTasksElseDefault(TestBaseTasks):
     """
     Test class parameterized with a list of new/modified tasks
     (or a set of default tasks if none have been modified)
     """
+
+
+@pytest.mark.parametrize(
+    "task_class",
+    task_class(
+        ["arc_easy_unitxt"], tasks.TaskManager(include_path="./tests/testconfigs")
+    ),
+    ids=lambda x: f"{x.config.task}",
+)
+class TestUnitxtTasks(TestBaseTasks):
+    """
+    Test class for Unitxt tasks parameterized with a small custom
+    task as described here:
+    https://www.unitxt.ai/en/latest/docs/lm_eval.html
+    """
+
+    def test_check_training_docs(self, task_class: ConfigurableTask):
+        if task_class.has_training_docs():
+            assert task_class.dataset["train"] is not None
+
+    def test_check_validation_docs(self, task_class):
+        if task_class.has_validation_docs():
+            assert task_class.dataset["validation"] is not None
+
+    def test_check_test_docs(self, task_class):
+        task = task_class
+        if task.has_test_docs():
+            assert task.dataset["test"] is not None
+
+    def test_doc_to_text(self, task_class, limit: int):
+        task = task_class
+        arr = (
+            list(islice(task.test_docs(), limit))
+            if task.has_test_docs()
+            else list(islice(task.validation_docs(), limit))
+        )
+        _array = [task.doc_to_text(doc) for doc in arr]
+        if not task.multiple_input:
+            for x in _array:
+                assert isinstance(x, str)
+        else:
+            pass
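The new TestUnitxtTasks class reuses the shared checks from TestBaseTasks but parameterizes them with a single Unitxt task loaded from ./tests/testconfigs. Assuming a standard pytest setup, something like `pytest tests/test_tasks.py -k TestUnitxtTasks` should select just these checks via pytest's keyword filter.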