Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
4254c7bd
Commit
4254c7bd
authored
Jul 26, 2025
by
Baber
Browse files
add task factory
parent
eec9de3e
Changes
10
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
601 additions
and
1288 deletions
+601
-1288
.pre-commit-config.yaml
.pre-commit-config.yaml
+1
-1
lm_eval/api/group.py
lm_eval/api/group.py
+5
-3
lm_eval/tasks/__init__.py
lm_eval/tasks/__init__.py
+2
-1218
lm_eval/tasks/_config_loader.py
lm_eval/tasks/_config_loader.py
+71
-13
lm_eval/tasks/factory.py
lm_eval/tasks/factory.py
+126
-0
lm_eval/tasks/index.py
lm_eval/tasks/index.py
+171
-0
lm_eval/tasks/manager.py
lm_eval/tasks/manager.py
+79
-0
pyproject.toml
pyproject.toml
+2
-2
tests/test_config_loader.py
tests/test_config_loader.py
+109
-18
tests/test_task_index.py
tests/test_task_index.py
+35
-33
No files found.
.pre-commit-config.yaml
View file @
4254c7bd
...
...
@@ -29,7 +29,7 @@ repos:
-
id
:
mixed-line-ending
args
:
[
--fix=lf
]
-
repo
:
https://github.com/astral-sh/ruff-pre-commit
rev
:
v0.12.
2
rev
:
v0.12.
5
hooks
:
# Run the linter.
-
id
:
ruff
...
...
lm_eval/api/group.py
View file @
4254c7bd
from
dataclasses
import
asdict
,
dataclass
from
inspect
import
getsource
from
typing
import
Callable
,
List
,
Optional
,
Union
from
typing
import
Callable
,
Optional
,
Union
from
datasets.features.pdf
import
field
@
dataclass
...
...
@@ -25,9 +27,9 @@ class AggMetricConfig(dict):
class
GroupConfig
:
group
:
Optional
[
str
]
=
None
group_alias
:
Optional
[
str
]
=
None
task
:
Optional
[
Union
[
str
,
list
]
]
=
None
task
:
Union
[
str
,
list
]
=
field
(
default_factory
=
list
)
aggregate_metric_list
:
Optional
[
Union
[
L
ist
[
AggMetricConfig
],
AggMetricConfig
,
dict
]
Union
[
l
ist
[
AggMetricConfig
],
AggMetricConfig
,
dict
]
]
=
None
metadata
:
Optional
[
dict
]
=
(
None
# by default, not used in the code. allows for users to pass arbitrary info to tasks
...
...
lm_eval/tasks/__init__.py
View file @
4254c7bd
This diff is collapsed.
Click to expand it.
lm_eval/tasks/_config_loader.py
View file @
4254c7bd
...
...
@@ -3,29 +3,31 @@ from __future__ import annotations
import
functools
import
importlib.util
import
sys
from
collections.abc
import
Callable
from
pathlib
import
Path
from
typing
import
Any
import
yaml
_Base
=
yaml
.
CLoader
if
getattr
(
yaml
,
"__with_libyaml__"
,
False
)
else
yaml
.
FullLoader
_Base
=
(
yaml
.
CSafeLoader
if
getattr
(
yaml
,
"__with_libyaml__"
,
False
)
else
yaml
.
FullLoader
)
_IGNORE_DIRS
=
{
"__pycache__"
,
".ipynb_checkpoints"
}
# --------------------------------------------------------------------------- helpers
@
functools
.
lru_cache
(
128
)
def
_mk_function_ctor
(
base_dir
:
Path
,
resolve
:
bool
):
def
ctor
(
loader
:
yaml
.
Loader
,
node
:
yaml
.
Node
):
spec
=
loader
.
construct_scalar
(
node
)
# type: ignore[arg-type]
if
not
resolve
:
return
lambda
*
_
,
**
__
:
None
return
_import_func
tion
(
spec
,
base_dir
)
return
str
(
base_dir
.
expanduser
()
/
spec
)
return
_import_func
_in_yml
(
spec
,
base_dir
)
return
ctor
@
functools
.
lru_cache
(
maxsize
=
1024
)
@
functools
.
lru_cache
(
maxsize
=
512
)
def
_make_loader
(
base_dir
:
Path
,
*
,
resolve_funcs
:
bool
)
->
type
[
yaml
.
Loader
]:
class
Loader
(
_Base
):
...
# type: ignore[no-redef]
...
...
@@ -37,8 +39,14 @@ def _make_loader(base_dir: Path, *, resolve_funcs: bool) -> type[yaml.Loader]:
return
Loader
@
functools
.
lru_cache
(
maxsize
=
4096
)
def
_import_function
(
qual
:
str
,
base_dir
:
Path
):
@
functools
.
lru_cache
(
maxsize
=
128
)
def
_import_func_in_yml
(
qual
:
str
,
base_dir
:
Path
):
"""Import function from qual: utils.process_doc, checking local files first then standard imports.
Args:
qual: Qualified function name (e.g., 'utils.process_doc')
base_dir: Directory to search for local modules
"""
mod_path
,
_
,
fn_name
=
qual
.
rpartition
(
"."
)
# 1) relative “utils.py” next to YAML
rel
=
(
base_dir
/
f
"
{
mod_path
.
replace
(
'.'
,
'/'
)
}
.py"
).
resolve
()
...
...
@@ -47,26 +55,74 @@ def _import_function(qual: str, base_dir: Path):
key
=
f
"
{
rel
}
:
{
mtime
}
"
# one module per mtime
if
key
not
in
sys
.
modules
:
spec
=
importlib
.
util
.
spec_from_file_location
(
key
,
rel
)
if
spec
is
None
or
spec
.
loader
is
None
:
raise
ImportError
(
f
"Cannot load module from
{
rel
}
"
)
from
None
mod
=
importlib
.
util
.
module_from_spec
(
spec
)
spec
.
loader
.
exec_module
(
mod
)
# type: ignore[arg-type]
sys
.
modules
[
key
]
=
mod
return
getattr
(
sys
.
modules
[
key
],
fn_name
)
# 2) already
‑
importable module
# 2) already
-
importable module
module
=
__import__
(
mod_path
,
fromlist
=
[
fn_name
])
return
getattr
(
module
,
fn_name
)
# --------------------------------------------------------------------- public API
@functools.lru_cache(maxsize=128)
def _import_fun_from_str(path_str: str) -> Any:
    """Import an attribute from a string of the form '/path/to/module.function_name'.

    The text after the rightmost dot is the attribute name; everything before it
    is the module file path, with or without a trailing ``.py``.

    Args:
        path_str: File path plus attribute name, e.g. '/abs/path/utils.process_doc'
            or '/abs/path/utils.py.process_doc'.

    Returns:
        The attribute looked up on the loaded module.

    Raises:
        ValueError: If ``path_str`` contains no dot to split on.
        ImportError: If the module file does not exist or cannot be loaded.
        AttributeError: If the module has no attribute with the requested name.

    Results are memoized by ``lru_cache`` on ``path_str``, so a later change to
    the file is only seen after ``cache_clear()``.
    """
    try:
        # Split off the function name from the rightmost dot
        module_path_str, function_name = path_str.rsplit(".", 1)
    except ValueError as e:
        raise ValueError(
            f"Invalid path format: {path_str}. Expected format: /path/to/module.function_name"
        ) from e

    # Convert to Path and normalise the extension
    module_path = Path(module_path_str)
    if not module_path.suffix:
        module_path = module_path.with_suffix(".py")
    elif module_path.suffix != ".py":
        # The last path component contains a dot that is not a ".py" extension
        # (e.g. "/path/to/pkg.sub" from "/path/to/pkg.sub.func"). Fall back to
        # "<path without that suffix>.py" when such a file exists; a spec that
        # already ends in ".py" never reaches this branch.
        candidate = module_path.with_suffix("").with_suffix(".py")
        if candidate.exists():
            module_path = candidate

    if not module_path.exists():
        raise ImportError(f"Module file not found: {module_path}")

    # Same scheme as _import_func_in_yml: one sys.modules entry per (file, mtime)
    mtime = module_path.stat().st_mtime_ns
    cache_key = f"{module_path}:{mtime}"

    if cache_key not in sys.modules:
        spec = importlib.util.spec_from_file_location(cache_key, module_path)
        if spec is None or spec.loader is None:
            raise ImportError(f"Cannot load module from {module_path}") from None
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        # Register only after a successful exec so a failed import is not cached.
        sys.modules[cache_key] = module

    module = sys.modules[cache_key]
    if not hasattr(module, function_name):
        raise AttributeError(
            f"Function '{function_name}' not found in module {module_path}"
        )
    return getattr(module, function_name)


# Public name: the test-suite imports ``import_fun_from_str`` from this module.
import_fun_from_str = _import_fun_from_str
def
load_yaml
(
path
:
str
|
Path
,
*
,
resolve_functions
:
bool
=
True
,
resolve_includes
:
bool
=
True
,
_seen
:
set
[
Path
]
|
None
=
None
,
)
->
dict
[
str
,
str
|
Callable
[...,
Any
]
]
:
"""Pure data
‑
loading helper.
Returns a dict ready for higher
‑
level interpretation.
)
->
dict
[
str
,
Any
]:
"""Pure data
-
loading helper.
Returns a dict ready for higher
-
level interpretation.
•No task/group/tag semantics here.
"""
path
=
Path
(
path
).
expanduser
().
resolve
()
...
...
@@ -82,9 +138,11 @@ def load_yaml(
if
not
resolve_includes
or
"include"
not
in
cfg
:
return
cfg
else
:
includes
=
cfg
.
pop
(
"include"
)
merged
=
{}
for
inc
in
cfg
.
pop
(
"include"
)
:
for
inc
in
includes
if
isinstance
(
includes
,
list
)
else
[
includes
]
:
inc_path
=
(
path
.
parent
/
inc
)
if
not
Path
(
inc
).
is_absolute
()
else
Path
(
inc
)
merged
.
update
(
load_yaml
(
...
...
lm_eval/tasks/factory.py
0 → 100644
View file @
4254c7bd
from
__future__
import
annotations
import
inspect
from
collections.abc
import
Mapping
from
copy
import
deepcopy
from
functools
import
lru_cache
from
typing
import
Any
from
lm_eval.api.group
import
GroupConfig
from
lm_eval.api.task
import
ConfigurableTask
,
Task
# noqa: F401 (typing)
from
lm_eval.tasks._config_loader
import
load_yaml
as
load_cfg
from
lm_eval.tasks.index
import
Entry
,
Kind
# Memoized variant of the YAML loader (up to 512 distinct argument sets).
# lru_cache hands back the same dict object on every hit, so callers must copy
# the result before mutating it.
load_cfg_cached = lru_cache(maxsize=512)(load_cfg)  # type: ignore[no-redef]
class TaskFactory:
    """Turn an *Entry* (plus optional overrides) into runnable objects.

    Depending on the entry kind the result is a task object, a mapping of
    task objects (tag expansion), or a ``{GroupConfig: children}`` mapping.
    """

    def __init__(self, *, meta: dict[str, Any] | None = None):
        # Extra metadata merged into the ``metadata`` of every config built here.
        self._meta = meta or {}

    # ---------------------------------------------------------------- public API
    def build(
        self,
        entry: Entry,
        *,
        overrides: dict[str, Any] | None = None,
        registry: Mapping[str, Entry],
    ):
        """Build the object(s) described by ``entry``.

        • entry.kind == TAG    ➜ returns mapping task-name -> task object
        • entry.kind == GROUP  ➜ returns {GroupConfig: mapping-of-children}
        • any other kind       ➜ returns the instantiated task object

        Args:
            entry: Index entry to build.
            overrides: Config keys layered on top of the loaded config.
            registry: Name -> Entry mapping used to resolve group/tag members.
        """
        if entry.kind is Kind.TAG:
            return self._build_tag(entry, overrides, registry)
        if entry.kind is Kind.GROUP:
            return self._build_group(entry, overrides, registry)
        return self._build_task(entry, overrides)

    def _build_task(self, entry: Entry, overrides: dict[str, Any] | None):
        """Instantiate a single task from its (overridden) config."""
        cfg = self._load_full_config(entry, overrides)
        if "class" in cfg:  # PY_TASK route
            # NOTE(review): only reachable when the loaded config carries
            # "class"; entries without a yaml_path get a stub config that has
            # no "class" key — confirm PY_TASK entries arrive with a yaml_path.
            cls = cfg["class"]
            obj = cls(config=cfg) if _ctor_accepts_config(cls) else cls()
            if isinstance(obj, ConfigurableTask):
                obj.config.task = entry.name
            return obj
        # YAML task
        return ConfigurableTask(config=cfg)  # type: ignore[arg-type]

    def _build_group(
        self,
        entry: Entry,
        overrides: dict[str, Any] | None,
        registry: Mapping[str, Entry],
    ):
        """Build a group and, recursively, all of its members.

        Returns:
            ``{GroupConfig: children}`` where ``children`` maps member names to
            task objects and nested-group keys to their own children.

        Raises:
            TypeError: If a member is neither a string nor a dict.
            KeyError: If a member name is not present in ``registry``.
        """
        # Overrides are not applied to the group's own config, only to members.
        raw_cfg = self._load_full_config(entry, None)
        grp_cfg = {
            k: v for k, v in raw_cfg.items() if k in GroupConfig.__annotations__
        }
        grp_cfg["metadata"] = grp_cfg.get("metadata", {}) | self._meta
        group_obj = GroupConfig(**grp_cfg)

        members = group_obj.task
        if isinstance(members, str):
            # A bare string is a single member; iterating it would yield characters.
            members = [members]

        children: dict[str, Any] = {}
        for item in members:
            if isinstance(item, str):  # task: hellaswag
                child_name = item
                child = self.build(
                    registry[item],
                    overrides=overrides,  # group-level overrides propagate
                    registry=registry,
                )
            elif isinstance(item, dict):  # task: {task: hellaswag, num_fewshot: 5}
                child_name = item["task"]
                child = self.build(
                    registry[child_name],
                    overrides=item,  # per-item override
                    registry=registry,
                )
            else:
                raise TypeError(
                    f"Unsupported sub-entry {item!r} in group '{entry.name}'"
                )

            if isinstance(child, Mapping):
                # Tag expansion or nested group: already a mapping, merge it in.
                children.update(child)
            else:
                # Plain task: build() returned the bare object, so key it by name.
                children[child_name] = child

        # NOTE(review): using group_obj as a dict key requires GroupConfig
        # instances to be hashable — confirm against its definition.
        return {group_obj: children}

    def _build_tag(
        self,
        entry: Entry,
        overrides: dict[str, Any] | None,
        registry: Mapping[str, Entry],
    ):
        """Expand a tag into ``{task_name: task_object}`` for every tagged task."""
        # entry.tags is a set; sort so the resulting mapping has a stable order.
        return {
            name: self._build_task(registry[name], overrides)
            for name in sorted(entry.tags)
        }

    def _load_full_config(
        self, entry: Entry, overrides: dict[str, Any] | None
    ) -> dict[str, Any]:
        """Return the config dict for ``entry`` with overrides and metadata applied.

        The YAML (when the entry has one) is loaded through the shared cache and
        deep-copied so the cached dict is never mutated. ``overrides`` replace
        top-level keys; ``self._meta`` is merged into ``metadata``; ``task``
        defaults to the entry name.
        """
        if entry.yaml_path:
            cfg = deepcopy(load_cfg_cached(entry.yaml_path, resolve_functions=True))
        else:
            cfg = {"metadata": {"config": "unknown"}}  # python task without YAML

        if overrides:
            cfg = {**cfg, **overrides}

        meta = cfg.get("metadata", {})
        if not isinstance(meta, dict):
            # Keep a non-dict metadata value instead of dropping it.
            meta = {"_metadata": meta}
        cfg["metadata"] = meta | self._meta

        cfg.setdefault("task", entry.name)
        return cfg
def
_ctor_accepts_config
(
cls
)
->
bool
:
init
=
getattr
(
cls
,
"__init__"
,
None
)
return
init
and
"config"
in
inspect
.
signature
(
init
).
parameters
lm_eval/tasks/
_task_
index.py
→
lm_eval/tasks/index.py
View file @
4254c7bd
# lm_eval/tasks/index.py
from
__future__
import
annotations
import
logging
from
dataclasses
import
dataclass
,
field
from
enum
import
Enum
,
auto
from
typing
import
TYPE_CHECKING
from
typing
import
TYPE_CHECKING
,
Any
from
lm_eval.tasks._config_loader
import
load_yaml
as
load_cfg
...
...
@@ -14,137 +13,159 @@ if TYPE_CHECKING:
from
pathlib
import
Path
class
Task
Kind
(
Enum
):
class
Kind
(
Enum
):
TASK
=
auto
()
# YAML task, or task_list entry
PY_TASK
=
auto
()
# Python
‑
defined, via "class"
PY_TASK
=
auto
()
# Python
-
defined, via "class"
GROUP
=
auto
()
TAG
=
auto
()
TASK_LIST
=
auto
()
@
dataclass
class
Task
Entry
:
class
Entry
:
name
:
str
kind
:
TaskKind
yaml_path
:
Path
|
None
# None for generated / py‑only entries
kind
:
Kind
yaml_path
:
Path
|
None
# None for generated / py-only entries
cfg
:
dict
[
str
,
str
]
|
None
=
None
tags
:
set
[
str
]
=
field
(
default_factory
=
set
)
task_list_path
:
Path
|
None
=
None
# only for GROUP / TAG when lazy‑loaded
task_list_path
:
Path
|
None
=
None
log
=
logging
.
getLogger
(
__name__
)
_IGNORE_DIRS
=
{
"__pycache__"
,
".ipynb_checkpoints"
}
class
TaskIndex
Builder
:
class
TaskIndex
:
"""Walks one or more directories, parses YAML quickly (functions unresolved),
and produces a mapping {task_name:
Task
Entry}.
and produces a mapping {task_name: Entry}.
"""
def
__init__
(
self
,
*
,
meta
data
:
dict
|
None
=
None
)
->
None
:
self
.
_metadata
=
meta
data
or
{}
def
__init__
(
self
,
*
,
meta
:
dict
[
str
,
str
]
|
None
=
None
)
->
None
:
self
.
_metadata
=
meta
or
{}
# ------------- public API --------------------------------------------------
def
build
(
self
,
paths
:
Iterable
[
Path
],
# include_defaults: bool = True,
)
->
dict
[
str
,
TaskEntry
]:
index
:
dict
[
str
,
TaskEntry
]
=
{}
*
,
resolve_includes
=
False
,
)
->
dict
[
str
,
Entry
]:
index
:
dict
[
str
,
Entry
]
=
{}
log
.
debug
(
"Building task index from %s"
,
paths
)
for
root
in
paths
:
for
yaml_path
in
self
.
_iter_yaml_files
(
root
):
try
:
cfg
=
load_cfg
(
yaml_path
,
resolve_functions
=
False
,
resolve_includes
=
False
,
resolve_includes
=
resolve_includes
,
)
self
.
process_cfg
(
cfg
,
yaml_path
,
index
)
except
Exception
as
err
:
log
.
debug
(
"Skip %s (%s)"
,
yaml_path
,
err
)
continue
self
.
_process_cfg
(
cfg
,
yaml_path
,
index
)
# self._process_cfg(cfg, yaml_path, index)
log
.
debug
(
"Built task index with %d entries"
,
len
(
index
))
return
index
# ------------- helpers -----------------------------------------------------
def
_iter_yaml_files
(
self
,
root
:
Path
):
@
staticmethod
def
_iter_yaml_files
(
root
:
Path
):
yield
from
(
p
for
p
in
root
.
glob
(
"**/*.yaml"
)
if
not
any
(
part
in
_IGNORE_DIRS
for
part
in
p
.
parts
)
)
# ---------------------------------------------------------------------------
def
_process_cfg
(
self
,
cfg
:
dict
,
@
staticmethod
def
process_cfg
(
cfg
:
dict
[
str
,
Any
],
path
:
Path
,
index
:
dict
[
str
,
Task
Entry
],
index
:
dict
[
str
,
Entry
],
)
->
None
:
kind
=
self
.
_kind_of
(
cfg
)
if
kind
is
Task
Kind
.
GROUP
:
kind
=
TaskIndex
.
_kind_of
(
cfg
)
if
kind
is
Kind
.
GROUP
:
grp_name
=
cfg
[
"group"
]
index
[
grp_name
]
=
Task
Entry
(
index
[
grp_name
]
=
Entry
(
name
=
grp_name
,
kind
=
Task
Kind
.
GROUP
,
kind
=
Kind
.
GROUP
,
yaml_path
=
path
,
tags
=
set
(
cfg
.
get
(
"tag"
,
[])),
tags
=
TaskIndex
.
_str_to_set
(
cfg
.
get
(
"tag"
)),
cfg
=
cfg
,
)
return
if
kind
is
Task
Kind
.
PY_TASK
:
if
kind
is
Kind
.
PY_TASK
:
name
=
cfg
[
"task"
]
index
[
name
]
=
Task
Entry
(
index
[
name
]
=
Entry
(
name
=
name
,
kind
=
Task
Kind
.
PY_TASK
,
kind
=
Kind
.
PY_TASK
,
yaml_path
=
None
,
tags
=
set
(
cfg
.
get
(
"tag"
,
[])),
tags
=
TaskIndex
.
_str_to_set
(
cfg
.
get
(
"tag"
)),
cfg
=
cfg
,
)
self
.
_register_tags
(
name
,
cfg
.
get
(
"tag"
,
[]
),
index
)
TaskIndex
.
_register_tags
(
name
,
cfg
.
get
(
"tag"
),
index
)
return
if
kind
is
Task
Kind
.
TASK
:
if
kind
is
Kind
.
TASK
:
name
=
cfg
[
"task"
]
index
[
name
]
=
Task
Entry
(
index
[
name
]
=
Entry
(
name
=
name
,
kind
=
Task
Kind
.
TASK
,
kind
=
Kind
.
TASK
,
yaml_path
=
path
,
tags
=
set
(
cfg
.
get
(
"tag"
,
[])),
tags
=
TaskIndex
.
_str_to_set
(
cfg
.
get
(
"tag"
)),
cfg
=
cfg
,
)
self
.
_register_tags
(
name
,
cfg
.
get
(
"tag"
,
[]
),
index
)
TaskIndex
.
_register_tags
(
name
,
cfg
.
get
(
"tag"
),
index
)
return
if
kind
is
Task
Kind
.
TASK_LIST
:
if
kind
is
Kind
.
TASK_LIST
:
for
entry
in
cfg
[
"task_list"
]:
task_name
=
entry
[
"task"
]
if
isinstance
(
entry
,
dict
)
else
entry
index
[
task_name
]
=
Task
Entry
(
index
[
task_name
]
=
Entry
(
name
=
task_name
,
kind
=
Task
Kind
.
TASK
,
kind
=
Kind
.
TASK
,
yaml_path
=
path
,
tags
=
set
(
entry
.
get
(
"tag"
,
[]))
if
isinstance
(
entry
,
dict
)
else
set
(),
tags
=
TaskIndex
.
_str_to_set
(
cfg
.
get
(
"tag"
)),
cfg
=
cfg
,
)
self
.
_register_tags
(
task_name
,
entry
.
get
(
"tag"
,
[]
),
index
)
TaskIndex
.
_register_tags
(
task_name
,
entry
.
get
(
"tag"
),
index
)
return
# ---------------------------------------------------------------------------
def
_register_tags
(
self
,
task
:
str
,
tags
,
index
)
->
None
:
@
staticmethod
def
_register_tags
(
task
:
str
,
tags
:
str
|
list
[
str
]
|
None
,
index
:
dict
[
str
,
Entry
],
)
->
None
:
if
not
tags
:
return
for
tag
in
tags
if
isinstance
(
tags
,
list
)
else
[
tags
]:
if
not
tag
:
continue
entry
=
index
.
setdefault
(
tag
,
Task
Entry
(
name
=
tag
,
kind
=
Task
Kind
.
TAG
,
yaml_path
=
None
,
tags
=
set
()),
Entry
(
name
=
tag
,
kind
=
Kind
.
TAG
,
yaml_path
=
None
,
tags
=
set
()),
)
entry
.
tags
.
add
(
task
)
# mutate ok; dataclass not frozen for TAG
entry
.
tags
.
add
(
task
)
@
staticmethod
def
_kind_of
(
cfg
:
dict
)
->
Task
Kind
:
def
_kind_of
(
cfg
:
dict
)
->
Kind
:
if
"class"
in
cfg
:
return
TaskKind
.
PY_TASK
return
Kind
.
PY_TASK
if
"group"
in
cfg
:
return
Kind
.
GROUP
if
"task_list"
in
cfg
:
return
Task
Kind
.
TASK_LIST
return
Kind
.
TASK_LIST
if
"task"
in
cfg
:
return
Task
Kind
.
GROUP
if
isinstance
(
cfg
[
"task"
],
list
)
else
Task
Kind
.
TASK
return
Kind
.
GROUP
if
isinstance
(
cfg
[
"task"
],
list
)
else
Kind
.
TASK
msg
=
"Unknown config shape"
raise
ValueError
(
msg
)
raise
ValueError
(
msg
)
from
None
@
staticmethod
def
_str_to_set
(
tags
:
str
|
list
[
str
]
|
None
=
None
)
->
set
[
str
]:
"""Convert a string or list of strings to a set of strings."""
return
(
set
(
tags
)
if
isinstance
(
tags
,
list
)
else
{
tags
}
if
isinstance
(
tags
,
str
)
else
set
()
)
lm_eval/tasks/manager.py
0 → 100644
View file @
4254c7bd
from
__future__
import
annotations
from
collections
import
defaultdict
from
itertools
import
chain
from
pathlib
import
Path
from
typing
import
Any
from
lm_eval.tasks.factory
import
TaskFactory
from
lm_eval.tasks.index
import
Entry
,
Kind
,
TaskIndex
from
lm_eval.utils
import
setup_logging
class TaskManager:
    """Discover task configs on disk and build tasks, groups and tags by name."""

    def __init__(
        self,
        verbosity: str | None = None,
        include_path: str | Path | list[str | Path] | None = None,
        include_defaults: bool = True,
        metadata: dict[str, dict[str, Any]] | None = None,
    ) -> None:
        """Index the configured directories and bucket the entries by kind.

        Args:
            verbosity: When truthy, passed to ``setup_logging``.
            include_path: Extra directory (or list/tuple of directories) to scan.
            include_defaults: Also scan the directory containing this module.
            metadata: Forwarded to ``TaskFactory`` as ``meta``.
        """
        if verbosity:
            setup_logging(verbosity)

        indexer = TaskIndex()
        self._factory = TaskFactory(meta=metadata)

        # Directories to scan: built-in tasks first, then user-supplied ones.
        search_roots: list[Path] = []
        if include_defaults:
            search_roots.append(Path(__file__).parent)
        if include_path:
            if isinstance(include_path, (list, tuple)):
                extra_roots = include_path
            else:
                extra_roots = [include_path]
            search_roots += [Path(root) for root in extra_roots]

        self._index = indexer.build(search_roots)

        # Group entry names by their kind for the sorted listings below.
        names_by_kind = defaultdict(list)
        for entry_name, entry in self._index.items():
            names_by_kind[entry.kind].append(entry_name)

        task_kinds = {Kind.TASK, Kind.PY_TASK}
        self._all_tasks = sorted(
            chain.from_iterable(names_by_kind[kind] for kind in task_kinds)
        )
        self._all_groups = sorted(names_by_kind[Kind.GROUP])
        self._all_tags = sorted(names_by_kind[Kind.TAG])

    def _entry(self, name: str) -> Entry:
        """Look up ``name`` in the index, raising ``KeyError`` if it is unknown."""
        if name in self._index:
            return self._index[name]
        raise KeyError(f"Unknown task/group/tag: {name}")

    def load_spec(self, spec: str | dict[str, Any]):
        """Build whatever ``spec`` refers to.

        ``spec`` is either the registered name of a task / group / tag, or a
        dict of inline overrides whose ``'task'`` key names the base entry,
        e.g. ``{'task': 'hellaswag', 'num_fewshot': 5}``.

        Raises:
            TypeError: If ``spec`` is neither a ``str`` nor a ``dict``.
            KeyError: If the referenced name is not in the index.
        """
        if isinstance(spec, str):
            target, inline_overrides = spec, None
        elif isinstance(spec, dict):
            # inline dict => find base entry, then pass the dict as overrides
            target, inline_overrides = spec["task"], spec
        else:
            raise TypeError("spec must be str or dict")

        resolved = self._entry(target)
        return self._factory.build(
            resolved, overrides=inline_overrides, registry=self._index
        )

    def load_task_or_group(self, task_list: str | list[str]):
        """Build every spec in ``task_list``; a non-list is treated as one spec."""
        specs = task_list if isinstance(task_list, list) else [task_list]
        return [self.load_spec(one_spec) for one_spec in specs]
pyproject.toml
View file @
4254c7bd
...
...
@@ -103,7 +103,8 @@ plugins.md029.allow_extended_start_values = true # ol-prefix
plugins.md034.enabled
=
false
# no-bare-urls
[tool.ruff.lint]
extend-select
=
["I"]
select
=
[
"ASYNC"
,
"B"
,
"C4"
,
"E"
,
"F"
,
"I"
,
"LOG"
,
"PIE"
,
"PTH"
,
"SIM"
,
"UP"
,
"PERF"
,
"ISC001"
,
"ISC002"
,
"ICN001"
,
"C901"
,
"FURB"
,
"RUF"
]
ignore
=
[
"E501"
,
"E111"
,
"E114"
,
"E117"
,
"E501"
,
"PERF203"
,
"B011"
]
[tool.ruff.lint.isort]
lines-after-imports
=
2
...
...
@@ -111,7 +112,6 @@ known-first-party = ["lm_eval"]
[tool.ruff.lint.extend-per-file-ignores]
"__init__.py"
=
["F401","F402","F403"]
"utils.py"
=
["F401"]
[dependency-groups]
dev
=
[
...
...
tests/test_config_loader.py
View file @
4254c7bd
...
...
@@ -20,7 +20,7 @@ Test coverage:
- load():
- test_load_simple_yaml: basic YAML parsing
- test_load_with_function_resolved: !function tags resolved to callables
- test_load_with_function_not_resolved: !function tags become
no-op lambda
s
- test_load_with_function_not_resolved: !function tags become
string
s
- test_load_with_includes: include files merged, main values win
- test_load_with_absolute_include: absolute path includes
- test_load_without_includes_resolution: includes preserved when disabled
...
...
@@ -38,9 +38,10 @@ import pytest
from
lm_eval.tasks._config_loader
import
(
_Base
,
_import_func
tion
,
_import_func
_in_yml
,
_make_loader
,
_mk_function_ctor
,
import_fun_from_str
,
load_yaml
,
)
...
...
@@ -75,7 +76,7 @@ class TestMkFunctionCtor:
"""Tests for the YAML !function constructor factory."""
def
test_mk_function_ctor_with_resolve_false
(
self
,
temp_dir
):
"""When resolve=False, should return a
no-op lambda
."""
"""When resolve=False, should return a
string
."""
ctor
=
_mk_function_ctor
(
temp_dir
,
resolve
=
False
)
loader
=
MagicMock
()
...
...
@@ -84,8 +85,7 @@ class TestMkFunctionCtor:
result
=
ctor
(
loader
,
node
)
assert
callable
(
result
)
assert
result
(
"arg1"
,
kwarg
=
"value"
)
is
None
assert
isinstance
(
result
,
str
)
def
test_mk_function_ctor_with_resolve_true
(
self
,
temp_dir
,
python_module
):
"""When resolve=True, should import and return the actual function."""
...
...
@@ -136,7 +136,7 @@ class TestImportFunction:
# Create a local module
python_module
(
"def local_func(x, y):
\n
return x + y
\n
"
)
func
=
_import_func
tion
(
"utils.local_func"
,
temp_dir
)
func
=
_import_func
_in_yml
(
"utils.local_func"
,
temp_dir
)
assert
callable
(
func
)
assert
func
(
2
,
3
)
==
5
...
...
@@ -149,7 +149,7 @@ class TestImportFunction:
"def nested_func():
\n
return 'nested'
\n
"
)
func
=
_import_func
tion
(
"sub.module.nested_func"
,
temp_dir
)
func
=
_import_func
_in_yml
(
"sub.module.nested_func"
,
temp_dir
)
assert
callable
(
func
)
assert
func
()
==
"nested"
...
...
@@ -157,19 +157,19 @@ class TestImportFunction:
def
test_import_standard_module
(
self
,
temp_dir
):
"""Falls back to standard import for non-local modules."""
# Import from standard library
func
=
_import_func
tion
(
"os.path.join"
,
temp_dir
)
func
=
_import_func
_in_yml
(
"os.path.join"
,
temp_dir
)
assert
callable
(
func
)
assert
func
(
"a"
,
"b"
)
in
(
"a/b"
,
"a
\\
b"
)
# Unix or Windows
def
test_import_caching
(
self
,
temp_dir
,
python_module
):
# Clear cache first
_import_func
tion
.
cache_clear
()
_import_func
_in_yml
.
cache_clear
()
python_module
(
"def cached_func():
\n
return 42
\n
"
)
func1
=
_import_func
tion
(
"utils.cached_func"
,
temp_dir
)
func2
=
_import_func
tion
(
"utils.cached_func"
,
temp_dir
)
func1
=
_import_func
_in_yml
(
"utils.cached_func"
,
temp_dir
)
func2
=
_import_func
_in_yml
(
"utils.cached_func"
,
temp_dir
)
assert
func1
is
func2
# Cached
...
...
@@ -177,7 +177,7 @@ class TestImportFunction:
"""Verifies LRU cache behavior - file changes require cache clear."""
# Clear the LRU cache
_import_func
tion
.
cache_clear
()
_import_func
_in_yml
.
cache_clear
()
# Create a module
module_path
=
temp_dir
/
"test_mtime.py"
...
...
@@ -185,17 +185,102 @@ class TestImportFunction:
# Import it
import_key
=
"test_mtime.value"
value1
=
_import_func
tion
(
import_key
,
temp_dir
)
value1
=
_import_func
_in_yml
(
import_key
,
temp_dir
)
assert
value1
==
1
value2
=
_import_func
tion
(
import_key
,
temp_dir
)
value2
=
_import_func
_in_yml
(
import_key
,
temp_dir
)
assert
value2
==
1
# From cache
_import_func
tion
.
cache_clear
()
value3
=
_import_func
tion
(
import_key
,
temp_dir
)
_import_func
_in_yml
.
cache_clear
()
value3
=
_import_func
_in_yml
(
import_key
,
temp_dir
)
assert
value3
==
1
# Re-imported
class
TestImportFunFromStr
:
"""Tests for import_fun_from_str function."""
def
test_import_from_absolute_path
(
self
,
temp_dir
):
"""Test importing function from absolute path."""
# Create a test module
module_path
=
temp_dir
/
"test_module.py"
module_path
.
write_text
(
"def test_func(x):
\n
return x * 2
\n
"
)
# Import using absolute path
func
=
import_fun_from_str
(
f
"
{
module_path
.
with_suffix
(
''
)
}
.test_func"
)
assert
callable
(
func
)
assert
func
(
5
)
==
10
def
test_import_with_py_extension
(
self
,
temp_dir
):
"""Test importing when .py is included in the path."""
# Create a test module
module_path
=
temp_dir
/
"test_module.py"
module_path
.
write_text
(
"def test_func(x):
\n
return x + 10
\n
"
)
# Import with .py in the path
func
=
import_fun_from_str
(
f
"
{
module_path
}
.test_func"
)
assert
callable
(
func
)
assert
func
(
5
)
==
15
def
test_import_nested_function
(
self
,
temp_dir
):
"""Test importing from nested module structure."""
# Create nested directory
(
temp_dir
/
"subdir"
).
mkdir
()
module_path
=
temp_dir
/
"subdir"
/
"nested.py"
module_path
.
write_text
(
"def nested_func():
\n
return 'nested'
\n
"
)
# Import from nested path
func
=
import_fun_from_str
(
f
"
{
module_path
.
with_suffix
(
''
)
}
.nested_func"
)
assert
callable
(
func
)
assert
func
()
==
"nested"
def
test_import_missing_module
(
self
,
temp_dir
):
"""Test error when module doesn't exist."""
with
pytest
.
raises
(
ImportError
,
match
=
"Module file not found"
):
import_fun_from_str
(
f
"
{
temp_dir
}
/nonexistent.test_func"
)
def
test_import_missing_function
(
self
,
temp_dir
):
"""Test error when function doesn't exist in module."""
module_path
=
temp_dir
/
"test_module.py"
module_path
.
write_text
(
"def other_func():
\n
pass
\n
"
)
with
pytest
.
raises
(
AttributeError
,
match
=
"Function 'missing_func' not found"
):
import_fun_from_str
(
f
"
{
module_path
.
with_suffix
(
''
)
}
.missing_func"
)
def
test_import_invalid_format
(
self
):
"""Test error with invalid path format."""
with
pytest
.
raises
(
ValueError
,
match
=
"Invalid path format"
):
import_fun_from_str
(
"/path/without/function"
)
def
test_import_caching
(
self
,
temp_dir
):
"""Test that modules are cached by mtime."""
# Clear any existing cache
import
sys
keys_to_remove
=
[
k
for
k
in
sys
.
modules
if
str
(
temp_dir
)
in
k
]
for
k
in
keys_to_remove
:
del
sys
.
modules
[
k
]
module_path
=
temp_dir
/
"cached_module.py"
module_path
.
write_text
(
"call_count = 0
\n
def func():
\n
global call_count
\n
call_count += 1
\n
return call_count
\n
"
)
# First import
func1
=
import_fun_from_str
(
f
"
{
module_path
.
with_suffix
(
''
)
}
.func"
)
_result1
=
func1
()
# Second import should use cached module
func2
=
import_fun_from_str
(
f
"
{
module_path
.
with_suffix
(
''
)
}
.func"
)
result2
=
func2
()
# Both should refer to the same module instance
assert
func1
is
func2
assert
result2
==
2
# call_count incremented
class
TestLoad
:
"""Tests for the main YAML loading function with includes and function resolution."""
...
...
@@ -237,8 +322,10 @@ doc_to_text: !function utils.process_doc
result
=
load_yaml
(
file_path
,
resolve_functions
=
False
)
assert
callable
(
result
[
"doc_to_text"
])
assert
result
[
"doc_to_text"
](
"hello"
)
is
None
# No-op lambda
assert
isinstance
(
result
[
"doc_to_text"
],
str
)
# When resolve_functions=False, it returns the full path + function spec
assert
result
[
"doc_to_text"
].
endswith
(
"utils.process_doc"
)
assert
result
[
"doc_to_text"
]
==
str
(
file_path
.
parent
/
"utils.process_doc"
)
def
test_load_with_includes
(
self
,
temp_dir
,
yaml_file
):
"""Include files are merged with local values taking precedence."""
...
...
@@ -388,3 +475,7 @@ shared_key: from_main
mock_expand
.
assert_called_once
()
assert
result
[
"test"
]
==
"value"
if
__name__
==
"__main__"
:
pytest
.
main
([
__file__
,
"-v"
,
"--tb=short"
])
tests/test_task_index.py
View file @
4254c7bd
"""
Tests for the task index builder that discovers YAML task configurations.
"""Tests for the task index builder that discovers YAML task configurations.
Test coverage:
- TaskIndex._kind_of: identifies task/group/tag/task_list/py_task
...
...
@@ -14,7 +13,7 @@ from pathlib import Path
import
pytest
from
lm_eval.tasks._task_index
import
TaskIndex
Builder
,
TaskKind
from
lm_eval.tasks._task_index
import
TaskIndex
,
TaskKind
@
pytest
.
fixture
...
...
@@ -40,28 +39,28 @@ class TestTaskKindOf:
def
test_kind_of_task
(
self
):
"""Single task with string name."""
cfg
=
{
"task"
:
"my_task"
,
"dataset_path"
:
"data"
}
assert
TaskIndex
Builder
.
_kind_of
(
cfg
)
==
TaskKind
.
TASK
assert
TaskIndex
.
_kind_of
(
cfg
)
==
TaskKind
.
TASK
def
test_kind_of_group
(
self
):
"""Group has task as list."""
cfg
=
{
"task"
:
[
"task1"
,
"task2"
],
"group"
:
"my_group"
}
assert
TaskIndex
Builder
.
_kind_of
(
cfg
)
==
TaskKind
.
GROUP
assert
TaskIndex
.
_kind_of
(
cfg
)
==
TaskKind
.
GROUP
def
test_kind_of_py_task
(
self
):
"""Python task has class field."""
cfg
=
{
"task"
:
"my_task"
,
"class"
:
"tasks.MyTask"
}
assert
TaskIndex
Builder
.
_kind_of
(
cfg
)
==
TaskKind
.
PY_TASK
assert
TaskIndex
.
_kind_of
(
cfg
)
==
TaskKind
.
PY_TASK
def
test_kind_of_task_list
(
self
):
"""Task list has task_list field."""
cfg
=
{
"task_list"
:
[
"task1"
,
"task2"
]}
assert
TaskIndex
Builder
.
_kind_of
(
cfg
)
==
TaskKind
.
TASK_LIST
assert
TaskIndex
.
_kind_of
(
cfg
)
==
TaskKind
.
TASK_LIST
def
test_kind_of_unknown
(
self
):
"""Unknown config raises ValueError."""
cfg
=
{
"unknown"
:
"field"
}
with
pytest
.
raises
(
ValueError
,
match
=
"Unknown config shape"
):
TaskIndex
Builder
.
_kind_of
(
cfg
)
TaskIndex
.
_kind_of
(
cfg
)
class
TestIterYamlFiles
:
...
...
@@ -75,8 +74,8 @@ class TestIterYamlFiles:
(
temp_dir
/
"subdir"
/
"task2.yaml"
).
touch
()
(
temp_dir
/
"other.txt"
).
touch
()
builder
=
TaskIndex
Builder
()
yaml_files
=
list
(
builder
.
_iter_yaml_files
(
temp_dir
))
builder
=
TaskIndex
()
yaml_files
=
list
(
builder
.
_iter_yaml_files
())
assert
len
(
yaml_files
)
==
2
names
=
{
f
.
name
for
f
in
yaml_files
}
...
...
@@ -90,8 +89,8 @@ class TestIterYamlFiles:
(
temp_dir
/
".ipynb_checkpoints"
).
mkdir
()
(
temp_dir
/
".ipynb_checkpoints"
/
"also_ignored.yaml"
).
touch
()
builder
=
TaskIndex
Builder
()
yaml_files
=
list
(
builder
.
_iter_yaml_files
(
temp_dir
))
builder
=
TaskIndex
()
yaml_files
=
list
(
builder
.
_iter_yaml_files
())
assert
len
(
yaml_files
)
==
1
assert
yaml_files
[
0
].
name
==
"task.yaml"
...
...
@@ -106,8 +105,8 @@ class TestProcessCfg:
path
=
temp_dir
/
"task.yaml"
index
=
{}
builder
=
TaskIndex
Builder
()
builder
.
_
process_cfg
(
cfg
,
path
,
index
)
builder
=
TaskIndex
()
builder
.
process_cfg
(
cfg
,
path
,
index
)
assert
"my_task"
in
index
entry
=
index
[
"my_task"
]
...
...
@@ -122,8 +121,8 @@ class TestProcessCfg:
path
=
temp_dir
/
"group.yaml"
index
=
{}
builder
=
TaskIndex
Builder
()
builder
.
_
process_cfg
(
cfg
,
path
,
index
)
builder
=
TaskIndex
()
builder
.
process_cfg
(
cfg
,
path
,
index
)
assert
"my_group"
in
index
entry
=
index
[
"my_group"
]
...
...
@@ -138,8 +137,8 @@ class TestProcessCfg:
path
=
temp_dir
/
"py_task.yaml"
index
=
{}
builder
=
TaskIndex
Builder
()
builder
.
_
process_cfg
(
cfg
,
path
,
index
)
builder
=
TaskIndex
()
builder
.
process_cfg
(
cfg
,
path
,
index
)
assert
"py_task"
in
index
entry
=
index
[
"py_task"
]
...
...
@@ -154,27 +153,30 @@ class TestProcessCfg:
"task_list"
:
[
"simple_task"
,
{
"task"
:
"complex_task"
,
"tag"
:
[
"tag1"
,
"tag2"
]},
]
]
,
}
path
=
temp_dir
/
"list.yaml"
index
=
{}
builder
=
TaskIndex
Builder
()
builder
=
TaskIndex
()
# The implementation has a bug - it calls entry.get() on string entries
# This test documents the current behavior which will fail
with
pytest
.
raises
(
AttributeError
,
match
=
"'str' object has no attribute 'get'"
):
builder
.
_
process_cfg
(
cfg
,
path
,
index
)
builder
.
process_cfg
(
cfg
,
path
,
index
)
def
test_process_task_list_dict_entries
(
self
,
temp_dir
):
"""Task list with only dict entries works."""
cfg
=
{
"task_list"
:
[{
"task"
:
"task1"
},
{
"task"
:
"task2"
,
"tag"
:
[
"tag1"
,
"tag2"
]}]
"task_list"
:
[
{
"task"
:
"task1"
},
{
"task"
:
"task2"
,
"tag"
:
[
"tag1"
,
"tag2"
]},
],
}
path
=
temp_dir
/
"list.yaml"
index
=
{}
builder
=
TaskIndex
Builder
()
builder
.
_
process_cfg
(
cfg
,
path
,
index
)
builder
=
TaskIndex
()
builder
.
process_cfg
(
cfg
,
path
,
index
)
# Task without tags
assert
"task1"
in
index
...
...
@@ -197,7 +199,7 @@ class TestRegisterTags:
def
test_register_single_tag
(
self
):
"""Single tag creates TAG entry."""
index
=
{}
builder
=
TaskIndex
Builder
()
builder
=
TaskIndex
()
builder
.
_register_tags
(
"task1"
,
"my_tag"
,
index
)
...
...
@@ -210,7 +212,7 @@ class TestRegisterTags:
def
test_register_multiple_tags
(
self
):
"""Multiple tags create multiple TAG entries."""
index
=
{}
builder
=
TaskIndex
Builder
()
builder
=
TaskIndex
()
builder
.
_register_tags
(
"task1"
,
[
"tag1"
,
"tag2"
],
index
)
...
...
@@ -222,7 +224,7 @@ class TestRegisterTags:
def
test_register_tags_accumulates
(
self
):
"""Multiple tasks can have same tag."""
index
=
{}
builder
=
TaskIndex
Builder
()
builder
=
TaskIndex
()
builder
.
_register_tags
(
"task1"
,
"shared_tag"
,
index
)
builder
.
_register_tags
(
"task2"
,
"shared_tag"
,
index
)
...
...
@@ -237,7 +239,7 @@ class TestBuild:
def
test_build_empty_directory
(
self
,
temp_dir
):
"""Empty directory returns empty index."""
builder
=
TaskIndex
Builder
()
builder
=
TaskIndex
()
index
=
builder
.
build
([
temp_dir
])
assert
index
==
{}
...
...
@@ -245,7 +247,7 @@ class TestBuild:
"""Single task file is discovered."""
yaml_file
(
"task: my_task
\n
dataset_path: data
\n
"
)
builder
=
TaskIndex
Builder
()
builder
=
TaskIndex
()
index
=
builder
.
build
([
temp_dir
])
assert
len
(
index
)
==
1
...
...
@@ -269,7 +271,7 @@ class TestBuild:
# Python task
yaml_file
(
"task: py_task
\n
class: MyClass
\n
"
,
"python.yaml"
)
builder
=
TaskIndex
Builder
()
builder
=
TaskIndex
()
index
=
builder
.
build
([
temp_dir
])
# Check all entries exist
...
...
@@ -297,7 +299,7 @@ class TestBuild:
yaml_file
(
"task: sub_task
\n
"
,
"subdir/sub.yaml"
)
yaml_file
(
"task: deep_task
\n
"
,
"subdir/deeper/deep.yaml"
)
builder
=
TaskIndex
Builder
()
builder
=
TaskIndex
()
index
=
builder
.
build
([
temp_dir
])
assert
len
(
index
)
==
3
...
...
@@ -308,7 +310,7 @@ class TestBuild:
yaml_file
(
"task: valid_task
\n
"
,
"valid.yaml"
)
yaml_file
(
"invalid: [
\n
"
,
"invalid.yaml"
)
# Invalid YAML
builder
=
TaskIndex
Builder
()
builder
=
TaskIndex
()
index
=
builder
.
build
([
temp_dir
])
assert
len
(
index
)
==
1
...
...
@@ -325,7 +327,7 @@ class TestBuild:
(
dir1
/
"task1.yaml"
).
write_text
(
"task: task1
\n
"
)
(
dir2
/
"task2.yaml"
).
write_text
(
"task: task2
\n
"
)
builder
=
TaskIndex
Builder
()
builder
=
TaskIndex
()
index
=
builder
.
build
([
dir1
,
dir2
])
assert
len
(
index
)
==
2
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment