Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
fedaf262
Commit
fedaf262
authored
Jul 08, 2025
by
Baber
Browse files
refactor: improve dataset and metric handling in TaskConfig
parent
863ff340
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
50 additions
and
46 deletions
+50
-46
lm_eval/api/task.py
lm_eval/api/task.py
+15
-10
lm_eval/config/task.py
lm_eval/config/task.py
+35
-36
No files found.
lm_eval/api/task.py
View file @
fedaf262
...
...
@@ -639,7 +639,7 @@ class ConfigurableTask(Task):
if
self
.
config
.
dataset_name
is
not
None
:
self
.
DATASET_NAME
=
self
.
config
.
dataset_name
self
.
metric_list
:
list
[
MetricConfig
]
=
self
.
config
.
get_metrics
#
self.metric_list: list[MetricConfig] = self.config.get_metrics
self
.
download
(
self
.
config
.
dataset_kwargs
)
self
.
_training_docs
=
None
...
...
@@ -655,7 +655,10 @@ class ConfigurableTask(Task):
else
:
self
.
prompt
=
None
if
self
.
config
.
fewshot_cfg
.
num
()
>
0
and
self
.
fewshot_docs
()
is
not
None
:
if
(
self
.
config
.
fewshot_cfg
.
num_fewshot
()
>
0
and
self
.
fewshot_docs
()
is
not
None
):
self
.
fewshot_rnd
=
random
.
Random
()
self
.
sampler
=
self
.
config
.
fewshot_cfg
.
init_sampler
(
list
(
self
.
fewshot_docs
()),
self
,
rnd
=
self
.
fewshot_rnd
...
...
@@ -722,19 +725,21 @@ class ConfigurableTask(Task):
def
download
(
self
,
dataset_kwargs
:
Optional
[
Dict
[
str
,
Any
]]
=
None
,
**
kwargs
)
->
None
:
if
isinstance
(
df
:
=
self
.
config
.
ds_cfg
.
custom
,
Callable
):
self
.
config
.
dataset_kwargs
,
self
.
config
.
metadata
=
(
self
.
config
.
dataset_kwargs
or
{},
self
.
config
.
metadata
or
{},
)
if
isinstance
(
df
:
=
self
.
config
.
custom_dataset
,
Callable
):
eval_logger
.
warning
(
f
"
{
self
.
config
.
task
}
: Custom kwargs can be passed to `--metadata` in console (as json string) or to the TaskManager."
+
"
\n
For example --metadata='{
\"
max_seq_lengths
\"
:[4096, 8192]}'. For details see task Readme."
)
self
.
dataset
=
df
(
**
(
self
.
config
.
ds_cfg
.
kwargs
|
self
.
config
.
ds_cfg
.
metadata
)
)
self
.
dataset
=
df
(
**
(
self
.
config
.
dataset_kwargs
|
self
.
config
.
metadata
))
else
:
self
.
dataset
=
datasets
.
load_dataset
(
path
=
self
.
config
.
d
s_cfg
.
path
,
name
=
self
.
config
.
d
s_cfg
.
name
,
**
self
.
config
.
d
s_cfg
.
kwargs
if
self
.
config
.
ds_cfg
.
kwargs
else
{}
,
path
=
self
.
config
.
d
ataset_
path
,
name
=
self
.
config
.
d
ataset_
name
,
**
self
.
config
.
d
ataset_kwargs
,
)
def
has_training_docs
(
self
)
->
bool
:
...
...
@@ -971,7 +976,7 @@ class ConfigurableTask(Task):
"""Iterates over FilterEnsembles and applies them to instances"""
if
hasattr
(
self
,
"_filters"
):
for
f
in
self
.
_filters
:
f
.
apply
(
self
.
_instances
)
f
.
ensemble
.
apply
(
self
.
_instances
)
else
:
eval_logger
.
warning
(
"No filter defined, passing through instances"
)
return
self
.
_instances
...
...
lm_eval/config/task.py
View file @
fedaf262
...
...
@@ -2,6 +2,7 @@ import logging
from
dataclasses
import
asdict
,
dataclass
,
field
from
typing
import
TYPE_CHECKING
,
Callable
,
Iterable
,
Optional
,
Union
from
lm_eval.api.filter
import
FilterEnsemble
from
lm_eval.api.instance
import
OutputType
from
lm_eval.config.metric
import
MetricConfig
from
lm_eval.config.utils
import
maybe_serialize
...
...
@@ -10,7 +11,6 @@ from lm_eval.config.utils import maybe_serialize
if
TYPE_CHECKING
:
from
lm_eval.api.samplers
import
ContextSampler
from
lm_eval.api.task
import
Task
from
lm_eval.filters
import
FilterEnsemble
eval_logger
=
logging
.
getLogger
(
__name__
)
...
...
@@ -29,8 +29,8 @@ class FilterConfig:
"""Encapsulates information about a single filter."""
name
:
str
fn
:
Optional
[
Callable
]
=
Non
e
kwargs
:
Optional
[
dict
]
=
field
(
default_factory
=
dict
)
ensemble
:
FilterEnsembl
e
metric_list
:
list
[
MetricConfig
]
@
dataclass
...
...
@@ -117,17 +117,6 @@ class FewshotConfig:
)
@
dataclass
class
DatasetConfig
:
"""Encapsulates information about a dataset."""
path
:
Optional
[
str
]
=
None
name
:
Optional
[
str
]
=
None
kwargs
:
Optional
[
dict
]
=
field
(
default_factory
=
dict
)
custom
:
Optional
[
Callable
]
=
None
metadata
:
Optional
[
dict
]
=
field
(
default_factory
=
dict
)
@
dataclass
class
TaskConfig
(
dict
):
# task naming/registry
...
...
@@ -140,7 +129,7 @@ class TaskConfig(dict):
custom_dataset
:
Optional
[
Callable
]
=
None
dataset_path
:
Optional
[
str
]
=
None
dataset_name
:
Optional
[
str
]
=
None
dataset_kwargs
:
Optional
[
dict
]
=
None
dataset_kwargs
:
Optional
[
dict
]
=
field
(
default_factory
=
dict
)
training_split
:
Optional
[
str
]
=
None
validation_split
:
Optional
[
str
]
=
None
test_split
:
Optional
[
str
]
=
None
...
...
@@ -177,9 +166,9 @@ class TaskConfig(dict):
default_factory
=
dict
)
# by default, not used in the code. allows for users to pass arbitrary info to tasks
_metric_list
:
list
[
MetricConfig
]
=
None
_metric_list
:
list
[
MetricConfig
]
=
field
(
default_factory
=
list
)
_filter_list
:
list
[
FilterConfig
]
=
None
ds_cfg
:
DatasetConfig
=
field
(
init
=
False
)
#
ds_cfg: DatasetConfig = field(init=False)
fewshot_cfg
:
FewshotConfig
=
field
(
init
=
False
)
def
__post_init__
(
self
)
->
None
:
...
...
@@ -215,14 +204,6 @@ class TaskConfig(dict):
eval_logger
.
warning
(
f
"
{
self
.
task
}
: No `generation_kwargs` specified in task config, defaulting to
{
self
.
generation_kwargs
}
"
)
# ---setup dataset config--- #
self
.
ds_cfg
=
DatasetConfig
(
path
=
self
.
dataset_path
,
name
=
self
.
dataset_name
,
kwargs
=
self
.
dataset_kwargs
,
custom
=
self
.
custom_dataset
,
metadata
=
self
.
metadata
or
{},
)
# ---setup fewshot config--- #
_fewshot_cfg
=
self
.
fewshot_config
if
self
.
fewshot_config
is
not
None
else
{}
self
.
fewshot_cfg
=
FewshotConfig
(
...
...
@@ -234,8 +215,9 @@ class TaskConfig(dict):
fewshot_indices
=
_fewshot_cfg
.
get
(
"fewshot_indices"
,
None
),
)
@
property
def
get_metrics
(
self
)
->
list
[
"MetricConfig"
]:
def
get_metric
(
self
,
metric_list
:
Optional
[
list
[
dict
]]
=
None
)
->
list
[
"MetricConfig"
]:
from
lm_eval.api.registry
import
(
AGGREGATION_REGISTRY
,
DEFAULT_METRIC_REGISTRY
,
...
...
@@ -245,8 +227,9 @@ class TaskConfig(dict):
is_higher_better
,
)
metric_list
=
metric_list
or
self
.
metric_list
metrics
=
[]
if
self
.
metric_list
is
None
:
if
not
metric_list
:
# ---------- 1. If no metrics defined, use defaults for output type ----------
_metric_list
=
DEFAULT_METRIC_REGISTRY
[
self
.
output_type
]
eval_logger
.
info
(
...
...
@@ -263,7 +246,7 @@ class TaskConfig(dict):
)
else
:
# ---------- 2. Process user-defined metrics from config ----------
for
metric_config
in
self
.
metric_list
:
for
metric_config
in
metric_list
:
metric_name
=
metric_config
[
"metric"
]
_metric_fn_kwargs
=
{
key
:
metric_config
[
key
]
...
...
@@ -324,34 +307,50 @@ class TaskConfig(dict):
hf_evaluate
=
_hf_evaluate_metric
,
)
)
for
m
in
metrics
:
if
m
not
in
self
.
_metric_list
:
self
.
_metric_list
.
extend
(
m
)
return
metrics
@
property
def
get_filters
(
self
)
->
list
[
"Filter
Ensemble
"
]:
def
get_filters
(
self
)
->
list
[
"Filter
Config
"
]:
from
lm_eval.filters
import
build_filter_ensemble
if
not
self
.
filter_list
:
eval_logger
.
debug
(
"No custom filters defined; falling back to 'take_first' for handling repeats."
)
return
[
build_filter_ensemble
(
"none"
,
[(
"take_first"
,
None
)])]
return
[
FilterConfig
(
name
=
"none"
,
ensemble
=
build_filter_ensemble
(
"none"
,
[(
"take_first"
,
None
)]),
metric_list
=
self
.
get_metric
(
metric_list
=
None
),
)
]
else
:
def
_strip_fn
(
d
:
dict
)
->
tuple
[
str
,
dict
]:
return
d
[
"function"
],
{
k
:
v
for
k
,
v
in
d
.
items
()
if
k
!=
"function"
}
return
d
[
"function"
],
{
k
:
v
for
k
,
v
in
d
.
items
()
if
k
not
in
[
"function"
,
"metric_list"
]
}
configs
=
(
self
.
filter_list
.
values
()
if
isinstance
(
self
.
filter_list
,
dict
)
else
self
.
filter_list
)
return
[
build_filter_ensemble
(
filter_name
=
cfg
[
"name"
],
components
=
[
_strip_fn
(
f
)
for
f
in
cfg
[
"filter"
]],
x
=
[
FilterConfig
(
name
=
cfg
[
"name"
],
ensemble
=
build_filter_ensemble
(
filter_name
=
cfg
[
"name"
],
components
=
[
_strip_fn
(
f
)
for
f
in
cfg
[
"filter"
]],
),
metric_list
=
self
.
get_metric
(
metric_list
=
cfg
.
get
(
"metric_list"
)),
)
for
cfg
in
configs
]
return
x
@
classmethod
def
from_yaml
(
cls
,
data
:
dict
)
->
"TaskConfig"
:
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment