Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
4254c7bd
Commit
4254c7bd
authored
Jul 26, 2025
by
Baber
Browse files
add task factory
parent
eec9de3e
Changes
10
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
10 changed files
with
601 additions
and
1288 deletions
+601
-1288
.pre-commit-config.yaml
.pre-commit-config.yaml
+1
-1
lm_eval/api/group.py
lm_eval/api/group.py
+5
-3
lm_eval/tasks/__init__.py
lm_eval/tasks/__init__.py
+2
-1218
lm_eval/tasks/_config_loader.py
lm_eval/tasks/_config_loader.py
+71
-13
lm_eval/tasks/factory.py
lm_eval/tasks/factory.py
+126
-0
lm_eval/tasks/index.py
lm_eval/tasks/index.py
+171
-0
lm_eval/tasks/manager.py
lm_eval/tasks/manager.py
+79
-0
pyproject.toml
pyproject.toml
+2
-2
tests/test_config_loader.py
tests/test_config_loader.py
+109
-18
tests/test_task_index.py
tests/test_task_index.py
+35
-33
No files found.
.pre-commit-config.yaml
View file @
4254c7bd
...
@@ -29,7 +29,7 @@ repos:
...
@@ -29,7 +29,7 @@ repos:
-
id
:
mixed-line-ending
-
id
:
mixed-line-ending
args
:
[
--fix=lf
]
args
:
[
--fix=lf
]
-
repo
:
https://github.com/astral-sh/ruff-pre-commit
-
repo
:
https://github.com/astral-sh/ruff-pre-commit
rev
:
v0.12.
2
rev
:
v0.12.
5
hooks
:
hooks
:
# Run the linter.
# Run the linter.
-
id
:
ruff
-
id
:
ruff
...
...
lm_eval/api/group.py
View file @
4254c7bd
from
dataclasses
import
asdict
,
dataclass
from
dataclasses
import
asdict
,
dataclass
from
inspect
import
getsource
from
inspect
import
getsource
from
typing
import
Callable
,
List
,
Optional
,
Union
from
typing
import
Callable
,
Optional
,
Union
from
datasets.features.pdf
import
field
@
dataclass
@
dataclass
...
@@ -25,9 +27,9 @@ class AggMetricConfig(dict):
...
@@ -25,9 +27,9 @@ class AggMetricConfig(dict):
class
GroupConfig
:
class
GroupConfig
:
group
:
Optional
[
str
]
=
None
group
:
Optional
[
str
]
=
None
group_alias
:
Optional
[
str
]
=
None
group_alias
:
Optional
[
str
]
=
None
task
:
Optional
[
Union
[
str
,
list
]
]
=
None
task
:
Union
[
str
,
list
]
=
field
(
default_factory
=
list
)
aggregate_metric_list
:
Optional
[
aggregate_metric_list
:
Optional
[
Union
[
L
ist
[
AggMetricConfig
],
AggMetricConfig
,
dict
]
Union
[
l
ist
[
AggMetricConfig
],
AggMetricConfig
,
dict
]
]
=
None
]
=
None
metadata
:
Optional
[
dict
]
=
(
metadata
:
Optional
[
dict
]
=
(
None
# by default, not used in the code. allows for users to pass arbitrary info to tasks
None
# by default, not used in the code. allows for users to pass arbitrary info to tasks
...
...
lm_eval/tasks/__init__.py
View file @
4254c7bd
This diff is collapsed.
Click to expand it.
lm_eval/tasks/_config_loader.py
View file @
4254c7bd
...
@@ -3,29 +3,31 @@ from __future__ import annotations
...
@@ -3,29 +3,31 @@ from __future__ import annotations
import
functools
import
functools
import
importlib.util
import
importlib.util
import
sys
import
sys
from
collections.abc
import
Callable
from
pathlib
import
Path
from
pathlib
import
Path
from
typing
import
Any
from
typing
import
Any
import
yaml
import
yaml
_Base
=
yaml
.
CLoader
if
getattr
(
yaml
,
"__with_libyaml__"
,
False
)
else
yaml
.
FullLoader
_Base
=
(
yaml
.
CSafeLoader
if
getattr
(
yaml
,
"__with_libyaml__"
,
False
)
else
yaml
.
FullLoader
)
_IGNORE_DIRS
=
{
"__pycache__"
,
".ipynb_checkpoints"
}
_IGNORE_DIRS
=
{
"__pycache__"
,
".ipynb_checkpoints"
}
# --------------------------------------------------------------------------- helpers
# --------------------------------------------------------------------------- helpers
@
functools
.
lru_cache
(
128
)
def
_mk_function_ctor
(
base_dir
:
Path
,
resolve
:
bool
):
def
_mk_function_ctor
(
base_dir
:
Path
,
resolve
:
bool
):
def
ctor
(
loader
:
yaml
.
Loader
,
node
:
yaml
.
Node
):
def
ctor
(
loader
:
yaml
.
Loader
,
node
:
yaml
.
Node
):
spec
=
loader
.
construct_scalar
(
node
)
# type: ignore[arg-type]
spec
=
loader
.
construct_scalar
(
node
)
# type: ignore[arg-type]
if
not
resolve
:
if
not
resolve
:
return
lambda
*
_
,
**
__
:
None
return
str
(
base_dir
.
expanduser
()
/
spec
)
return
_import_func
tion
(
spec
,
base_dir
)
return
_import_func
_in_yml
(
spec
,
base_dir
)
return
ctor
return
ctor
@
functools
.
lru_cache
(
maxsize
=
1024
)
@
functools
.
lru_cache
(
maxsize
=
512
)
def
_make_loader
(
base_dir
:
Path
,
*
,
resolve_funcs
:
bool
)
->
type
[
yaml
.
Loader
]:
def
_make_loader
(
base_dir
:
Path
,
*
,
resolve_funcs
:
bool
)
->
type
[
yaml
.
Loader
]:
class
Loader
(
_Base
):
...
# type: ignore[no-redef]
class
Loader
(
_Base
):
...
# type: ignore[no-redef]
...
@@ -37,8 +39,14 @@ def _make_loader(base_dir: Path, *, resolve_funcs: bool) -> type[yaml.Loader]:
...
@@ -37,8 +39,14 @@ def _make_loader(base_dir: Path, *, resolve_funcs: bool) -> type[yaml.Loader]:
return
Loader
return
Loader
@
functools
.
lru_cache
(
maxsize
=
4096
)
@
functools
.
lru_cache
(
maxsize
=
128
)
def
_import_function
(
qual
:
str
,
base_dir
:
Path
):
def
_import_func_in_yml
(
qual
:
str
,
base_dir
:
Path
):
"""Import function from qual: utils.process_doc, checking local files first then standard imports.
Args:
qual: Qualified function name (e.g., 'utils.process_doc')
base_dir: Directory to search for local modules
"""
mod_path
,
_
,
fn_name
=
qual
.
rpartition
(
"."
)
mod_path
,
_
,
fn_name
=
qual
.
rpartition
(
"."
)
# 1) relative “utils.py” next to YAML
# 1) relative “utils.py” next to YAML
rel
=
(
base_dir
/
f
"
{
mod_path
.
replace
(
'.'
,
'/'
)
}
.py"
).
resolve
()
rel
=
(
base_dir
/
f
"
{
mod_path
.
replace
(
'.'
,
'/'
)
}
.py"
).
resolve
()
...
@@ -47,26 +55,74 @@ def _import_function(qual: str, base_dir: Path):
...
@@ -47,26 +55,74 @@ def _import_function(qual: str, base_dir: Path):
key
=
f
"
{
rel
}
:
{
mtime
}
"
# one module per mtime
key
=
f
"
{
rel
}
:
{
mtime
}
"
# one module per mtime
if
key
not
in
sys
.
modules
:
if
key
not
in
sys
.
modules
:
spec
=
importlib
.
util
.
spec_from_file_location
(
key
,
rel
)
spec
=
importlib
.
util
.
spec_from_file_location
(
key
,
rel
)
if
spec
is
None
or
spec
.
loader
is
None
:
raise
ImportError
(
f
"Cannot load module from
{
rel
}
"
)
from
None
mod
=
importlib
.
util
.
module_from_spec
(
spec
)
mod
=
importlib
.
util
.
module_from_spec
(
spec
)
spec
.
loader
.
exec_module
(
mod
)
# type: ignore[arg-type]
spec
.
loader
.
exec_module
(
mod
)
# type: ignore[arg-type]
sys
.
modules
[
key
]
=
mod
sys
.
modules
[
key
]
=
mod
return
getattr
(
sys
.
modules
[
key
],
fn_name
)
return
getattr
(
sys
.
modules
[
key
],
fn_name
)
# 2) already
‑
importable module
# 2) already
-
importable module
module
=
__import__
(
mod_path
,
fromlist
=
[
fn_name
])
module
=
__import__
(
mod_path
,
fromlist
=
[
fn_name
])
return
getattr
(
module
,
fn_name
)
return
getattr
(
module
,
fn_name
)
# --------------------------------------------------------------------- public API
@functools.lru_cache(maxsize=128)
def _import_fun_from_str(path_str: str) -> Any:
    """Import an attribute from a string of the form '/absolute/path/to/module.function_name'.

    The text after the right-most dot is the attribute name; everything before
    it is the module file, with or without a trailing ``.py``.

    Args:
        path_str: File path and attribute name joined by a dot.

    Returns:
        The attribute called ``function_name`` found in the loaded module.

    Raises:
        ValueError: ``path_str`` contains no dot to split on.
        ImportError: the module file does not exist or cannot be loaded.
        AttributeError: the module has no attribute with the requested name.

    Results are memoised per ``path_str`` by ``lru_cache``, so an edited file
    is only re-read after ``_import_fun_from_str.cache_clear()``.
    """
    try:
        # Split off the function name from the rightmost dot
        module_path_str, function_name = path_str.rsplit(".", 1)
    except ValueError as e:
        raise ValueError(
            f"Invalid path format: {path_str}. Expected format: /path/to/module.function_name"
        ) from e

    # Convert to Path and handle .py extension
    module_path = Path(module_path_str)
    if not module_path.suffix:
        module_path = module_path.with_suffix(".py")
    elif module_path.suffix != ".py":
        # The last path component contains a dot that is not a ".py" suffix,
        # e.g. "/path/to/pkg.mod".  First keep the historical behaviour:
        # drop that trailing part and look for "<stem>.py".
        base_path = module_path.with_suffix("")
        stem_candidate = base_path.with_suffix(".py")
        # Otherwise the dot may belong to the file name itself
        # ("/path/to/my.module" -> "/path/to/my.module.py").
        dotted_candidate = module_path.with_name(module_path.name + ".py")
        if stem_candidate.exists():
            module_path = stem_candidate
        elif dotted_candidate.exists():
            module_path = dotted_candidate

    if not module_path.exists():
        raise ImportError(f"Module file not found: {module_path}")

    # Use similar approach to _import_func_in_yml for consistency:
    # the sys.modules key embeds the mtime, giving one module object per file version.
    mtime = module_path.stat().st_mtime_ns
    cache_key = f"{module_path}:{mtime}"
    if cache_key not in sys.modules:
        spec = importlib.util.spec_from_file_location(cache_key, module_path)
        if spec is None or spec.loader is None:
            raise ImportError(f"Cannot load module from {module_path}") from None
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        # Registered only after a successful exec, so a failing module is retried next time.
        sys.modules[cache_key] = module

    module = sys.modules[cache_key]
    if not hasattr(module, function_name):
        raise AttributeError(
            f"Function '{function_name}' not found in module {module_path}"
        )
    return getattr(module, function_name)
def
load_yaml
(
def
load_yaml
(
path
:
str
|
Path
,
path
:
str
|
Path
,
*
,
*
,
resolve_functions
:
bool
=
True
,
resolve_functions
:
bool
=
True
,
resolve_includes
:
bool
=
True
,
resolve_includes
:
bool
=
True
,
_seen
:
set
[
Path
]
|
None
=
None
,
_seen
:
set
[
Path
]
|
None
=
None
,
)
->
dict
[
str
,
str
|
Callable
[...,
Any
]
]
:
)
->
dict
[
str
,
Any
]:
"""Pure data
‑
loading helper.
"""Pure data
-
loading helper.
Returns a dict ready for higher
‑
level interpretation.
Returns a dict ready for higher
-
level interpretation.
•No task/group/tag semantics here.
•No task/group/tag semantics here.
"""
"""
path
=
Path
(
path
).
expanduser
().
resolve
()
path
=
Path
(
path
).
expanduser
().
resolve
()
...
@@ -82,9 +138,11 @@ def load_yaml(
...
@@ -82,9 +138,11 @@ def load_yaml(
if
not
resolve_includes
or
"include"
not
in
cfg
:
if
not
resolve_includes
or
"include"
not
in
cfg
:
return
cfg
return
cfg
else
:
includes
=
cfg
.
pop
(
"include"
)
merged
=
{}
merged
=
{}
for
inc
in
cfg
.
pop
(
"include"
)
:
for
inc
in
includes
if
isinstance
(
includes
,
list
)
else
[
includes
]
:
inc_path
=
(
path
.
parent
/
inc
)
if
not
Path
(
inc
).
is_absolute
()
else
Path
(
inc
)
inc_path
=
(
path
.
parent
/
inc
)
if
not
Path
(
inc
).
is_absolute
()
else
Path
(
inc
)
merged
.
update
(
merged
.
update
(
load_yaml
(
load_yaml
(
...
...
lm_eval/tasks/factory.py
0 → 100644
View file @
4254c7bd
from
__future__
import
annotations
import
inspect
from
collections.abc
import
Mapping
from
copy
import
deepcopy
from
functools
import
lru_cache
from
typing
import
Any
from
lm_eval.api.group
import
GroupConfig
from
lm_eval.api.task
import
ConfigurableTask
,
Task
# noqa: F401 (typing)
from
lm_eval.tasks._config_loader
import
load_yaml
as
load_cfg
from
lm_eval.tasks.index
import
Entry
,
Kind
# Memoised wrapper around ``load_cfg`` (the ``load_yaml`` import above) that
# keeps up to 512 results.  The cached dict is shared between calls, which is
# why ``TaskFactory._load_full_config`` deep-copies it before applying overrides.
load_cfg_cached = lru_cache(maxsize=512)(load_cfg)  # type: ignore[no-redef]
class TaskFactory:
    """
    Turns a *Entry* (plus optional overrides) into a
    *Task* | *ConfigurableTask* | *GroupConfig* hierarchy.
    """

    def __init__(self, *, meta: dict[str, Any] | None = None):
        """Store ``meta``; it is merged into the ``metadata`` of every config built."""
        self._meta = meta or {}

    # ---------------------------------------------------------------- public API
    def build(
        self,
        entry: Entry,
        *,
        overrides: dict[str, Any] | None = None,
        registry: Mapping[str, Entry],
    ):
        """
        • entry.kind == TASK / PY_TASK  ➜  returns instantiated task object
        • entry.kind == GROUP           ➜  returns {GroupConfig: mapping-of-subtasks}
        • entry.kind == TAG             ➜  returns mapping-of-tasks (tag expansion)

        ``registry`` maps names to entries and is used to resolve the children
        of groups and tags; ``overrides`` are merged on top of the loaded config.
        """
        if entry.kind is Kind.TAG:
            return self._build_tag(entry, overrides, registry)
        if entry.kind is Kind.GROUP:
            return self._build_group(entry, overrides, registry)
        return self._build_task(entry, overrides)

    def _build_task(self, entry: Entry, overrides: dict[str, Any] | None):
        """Instantiate a single task from its merged config."""
        cfg = self._load_full_config(entry, overrides)
        if "class" in cfg:  # PY_TASK route
            cls = cfg["class"]
            # Only hand the config over when the constructor declares a `config` parameter.
            obj = cls(config=cfg) if _ctor_accepts_config(cls) else cls()
            if isinstance(obj, ConfigurableTask):
                obj.config.task = entry.name
            return obj
        # YAML task
        return ConfigurableTask(config=cfg)  # type: ignore[arg-type]

    def _build_group(
        self,
        entry: Entry,
        overrides: dict[str, Any] | None,
        registry: Mapping[str, Entry],
    ):
        """Build a group and, recursively, every sub-entry listed under its ``task`` key.

        Returns a one-item dict ``{GroupConfig: children}``.

        Raises:
            TypeError: a sub-entry is neither a name nor a dict.
            KeyError: a sub-entry names something missing from ``registry``,
                or a dict sub-entry has no ``"task"`` key.
        """
        raw_cfg = self._load_full_config(entry, None)
        # Keep only the keys GroupConfig declares as fields.
        grp_cfg = {k: v for k, v in raw_cfg.items() if k in GroupConfig.__annotations__}
        grp_cfg["metadata"] = grp_cfg.get("metadata", {}) | self._meta
        group_obj = GroupConfig(**grp_cfg)

        # A lone string must be treated as one name, not iterated character by character.
        subtasks = group_obj.task
        if isinstance(subtasks, str):
            subtasks = [subtasks]

        children: dict[str, Any] = {}
        for item in subtasks:
            if isinstance(item, str):  # task: hellaswag
                child_name = item
                child = self.build(
                    registry[item],
                    overrides=overrides,  # group-level overrides propagate
                    registry=registry,
                )
            elif isinstance(item, dict):  # task: {task: hellaswag, num_fewshot: 5}
                child_name = item["task"]
                child = self.build(
                    registry[child_name],
                    overrides=item,  # per-item override
                    registry=registry,
                )
            else:
                raise TypeError(
                    f"Unsupported sub-entry {item!r} in group '{entry.name}'"
                )

            # Groups and tags come back as mappings and are merged in; a plain
            # task comes back as a bare object and is stored under its name.
            if isinstance(child, Mapping):
                children.update(child)
            else:
                children[child_name] = child

        # NOTE(review): using group_obj as a dict key requires GroupConfig to be
        # hashable -- confirm against the dataclass definition.
        return {group_obj: children}

    def _build_tag(
        self,
        entry: Entry,
        overrides: dict[str, Any] | None,
        registry: Mapping[str, Entry],
    ):
        """Expand a tag into ``{task_name: task}`` for every name in ``entry.tags``."""
        return {
            name: self._build_task(registry[name], overrides) for name in entry.tags
        }

    def _load_full_config(
        self, entry: Entry, overrides: dict[str, Any] | None
    ) -> dict[str, Any]:
        """Return the config for ``entry`` with overrides and factory metadata merged in.

        The YAML result is deep-copied because it comes from a shared cache.
        A non-dict ``metadata`` value is preserved under the ``"_metadata"`` key,
        and ``task`` defaults to ``entry.name``.
        """
        if entry.yaml_path:
            cfg = deepcopy(load_cfg_cached(entry.yaml_path, resolve_functions=True))
        else:
            cfg = {"metadata": {"config": "unknown"}}  # python task without YAML

        if overrides:
            cfg = {**cfg, **overrides}

        meta = cfg.get("metadata", {})
        if not isinstance(meta, dict):
            meta = {"_metadata": meta}
        cfg["metadata"] = meta | self._meta
        cfg.setdefault("task", entry.name)
        return cfg
def _ctor_accepts_config(cls) -> bool:
    """Return True when ``cls.__init__`` declares a parameter named ``config``.

    Always returns a real ``bool``.  A missing ``__init__`` or one that
    ``inspect.signature`` cannot introspect is reported as ``False``, so the
    caller falls back to the no-argument constructor.
    """
    init = getattr(cls, "__init__", None)
    if init is None:
        return False
    try:
        params = inspect.signature(init).parameters
    except (TypeError, ValueError):
        # signature() raises TypeError for unsupported objects and ValueError
        # when no signature can be provided.
        return False
    return "config" in params
lm_eval/tasks/
_task_
index.py
→
lm_eval/tasks/index.py
View file @
4254c7bd
# lm_eval/task_index.py (continued)
from
__future__
import
annotations
from
__future__
import
annotations
import
logging
import
logging
from
dataclasses
import
dataclass
,
field
from
dataclasses
import
dataclass
,
field
from
enum
import
Enum
,
auto
from
enum
import
Enum
,
auto
from
typing
import
TYPE_CHECKING
from
typing
import
TYPE_CHECKING
,
Any
from
lm_eval.tasks._config_loader
import
load_yaml
as
load_cfg
from
lm_eval.tasks._config_loader
import
load_yaml
as
load_cfg
...
@@ -14,137 +13,159 @@ if TYPE_CHECKING:
...
@@ -14,137 +13,159 @@ if TYPE_CHECKING:
from
pathlib
import
Path
from
pathlib
import
Path
class
Task
Kind
(
Enum
):
class
Kind
(
Enum
):
TASK
=
auto
()
# YAML task, or task_list entry
TASK
=
auto
()
# YAML task, or task_list entry
PY_TASK
=
auto
()
# Python
‑
defined, via "class"
PY_TASK
=
auto
()
# Python
-
defined, via "class"
GROUP
=
auto
()
GROUP
=
auto
()
TAG
=
auto
()
TAG
=
auto
()
TASK_LIST
=
auto
()
TASK_LIST
=
auto
()
@
dataclass
@
dataclass
class
Task
Entry
:
class
Entry
:
name
:
str
name
:
str
kind
:
TaskKind
kind
:
Kind
yaml_path
:
Path
|
None
# None for generated / py‑only entries
yaml_path
:
Path
|
None
# None for generated / py-only entries
cfg
:
dict
[
str
,
str
]
|
None
=
None
tags
:
set
[
str
]
=
field
(
default_factory
=
set
)
tags
:
set
[
str
]
=
field
(
default_factory
=
set
)
task_list_path
:
Path
|
None
=
None
# only for GROUP / TAG when lazy‑loaded
task_list_path
:
Path
|
None
=
None
log
=
logging
.
getLogger
(
__name__
)
log
=
logging
.
getLogger
(
__name__
)
_IGNORE_DIRS
=
{
"__pycache__"
,
".ipynb_checkpoints"
}
_IGNORE_DIRS
=
{
"__pycache__"
,
".ipynb_checkpoints"
}
class
TaskIndex
Builder
:
class
TaskIndex
:
"""Walks one or more directories, parses YAML quickly (functions unresolved),
"""Walks one or more directories, parses YAML quickly (functions unresolved),
and produces a mapping {task_name:
Task
Entry}.
and produces a mapping {task_name: Entry}.
"""
"""
def
__init__
(
self
,
*
,
meta
data
:
dict
|
None
=
None
)
->
None
:
def
__init__
(
self
,
*
,
meta
:
dict
[
str
,
str
]
|
None
=
None
)
->
None
:
self
.
_metadata
=
meta
data
or
{}
self
.
_metadata
=
meta
or
{}
# ------------- public API --------------------------------------------------
def
build
(
def
build
(
self
,
self
,
paths
:
Iterable
[
Path
],
paths
:
Iterable
[
Path
],
# include_defaults: bool = True,
*
,
)
->
dict
[
str
,
TaskEntry
]:
resolve_includes
=
False
,
index
:
dict
[
str
,
TaskEntry
]
=
{}
)
->
dict
[
str
,
Entry
]:
index
:
dict
[
str
,
Entry
]
=
{}
log
.
debug
(
"Building task index from %s"
,
paths
)
for
root
in
paths
:
for
root
in
paths
:
for
yaml_path
in
self
.
_iter_yaml_files
(
root
):
for
yaml_path
in
self
.
_iter_yaml_files
(
root
):
try
:
try
:
cfg
=
load_cfg
(
cfg
=
load_cfg
(
yaml_path
,
yaml_path
,
resolve_functions
=
False
,
resolve_functions
=
False
,
resolve_includes
=
False
,
resolve_includes
=
resolve_includes
,
)
)
self
.
process_cfg
(
cfg
,
yaml_path
,
index
)
except
Exception
as
err
:
except
Exception
as
err
:
log
.
debug
(
"Skip %s (%s)"
,
yaml_path
,
err
)
log
.
debug
(
"Skip %s (%s)"
,
yaml_path
,
err
)
continue
continue
self
.
_process_cfg
(
cfg
,
yaml_path
,
index
)
# self._process_cfg(cfg, yaml_path, index)
log
.
debug
(
"Built task index with %d entries"
,
len
(
index
))
return
index
return
index
# ------------- helpers -----------------------------------------------------
@
staticmethod
def
_iter_yaml_files
(
self
,
root
:
Path
):
def
_iter_yaml_files
(
root
:
Path
):
yield
from
(
yield
from
(
p
p
for
p
in
root
.
glob
(
"**/*.yaml"
)
for
p
in
root
.
glob
(
"**/*.yaml"
)
if
not
any
(
part
in
_IGNORE_DIRS
for
part
in
p
.
parts
)
if
not
any
(
part
in
_IGNORE_DIRS
for
part
in
p
.
parts
)
)
)
# ---------------------------------------------------------------------------
@
staticmethod
def
_process_cfg
(
def
process_cfg
(
self
,
cfg
:
dict
[
str
,
Any
],
cfg
:
dict
,
path
:
Path
,
path
:
Path
,
index
:
dict
[
str
,
Task
Entry
],
index
:
dict
[
str
,
Entry
],
)
->
None
:
)
->
None
:
kind
=
self
.
_kind_of
(
cfg
)
kind
=
TaskIndex
.
_kind_of
(
cfg
)
if
kind
is
Task
Kind
.
GROUP
:
if
kind
is
Kind
.
GROUP
:
grp_name
=
cfg
[
"group"
]
grp_name
=
cfg
[
"group"
]
index
[
grp_name
]
=
Task
Entry
(
index
[
grp_name
]
=
Entry
(
name
=
grp_name
,
name
=
grp_name
,
kind
=
Task
Kind
.
GROUP
,
kind
=
Kind
.
GROUP
,
yaml_path
=
path
,
yaml_path
=
path
,
tags
=
set
(
cfg
.
get
(
"tag"
,
[])),
tags
=
TaskIndex
.
_str_to_set
(
cfg
.
get
(
"tag"
)),
cfg
=
cfg
,
)
)
return
return
if
kind
is
Task
Kind
.
PY_TASK
:
if
kind
is
Kind
.
PY_TASK
:
name
=
cfg
[
"task"
]
name
=
cfg
[
"task"
]
index
[
name
]
=
Task
Entry
(
index
[
name
]
=
Entry
(
name
=
name
,
name
=
name
,
kind
=
Task
Kind
.
PY_TASK
,
kind
=
Kind
.
PY_TASK
,
yaml_path
=
None
,
yaml_path
=
None
,
tags
=
set
(
cfg
.
get
(
"tag"
,
[])),
tags
=
TaskIndex
.
_str_to_set
(
cfg
.
get
(
"tag"
)),
cfg
=
cfg
,
)
)
self
.
_register_tags
(
name
,
cfg
.
get
(
"tag"
,
[]
),
index
)
TaskIndex
.
_register_tags
(
name
,
cfg
.
get
(
"tag"
),
index
)
return
return
if
kind
is
Task
Kind
.
TASK
:
if
kind
is
Kind
.
TASK
:
name
=
cfg
[
"task"
]
name
=
cfg
[
"task"
]
index
[
name
]
=
Task
Entry
(
index
[
name
]
=
Entry
(
name
=
name
,
name
=
name
,
kind
=
Task
Kind
.
TASK
,
kind
=
Kind
.
TASK
,
yaml_path
=
path
,
yaml_path
=
path
,
tags
=
set
(
cfg
.
get
(
"tag"
,
[])),
tags
=
TaskIndex
.
_str_to_set
(
cfg
.
get
(
"tag"
)),
cfg
=
cfg
,
)
)
self
.
_register_tags
(
name
,
cfg
.
get
(
"tag"
,
[]
),
index
)
TaskIndex
.
_register_tags
(
name
,
cfg
.
get
(
"tag"
),
index
)
return
return
if
kind
is
Task
Kind
.
TASK_LIST
:
if
kind
is
Kind
.
TASK_LIST
:
for
entry
in
cfg
[
"task_list"
]:
for
entry
in
cfg
[
"task_list"
]:
task_name
=
entry
[
"task"
]
if
isinstance
(
entry
,
dict
)
else
entry
task_name
=
entry
[
"task"
]
if
isinstance
(
entry
,
dict
)
else
entry
index
[
task_name
]
=
Task
Entry
(
index
[
task_name
]
=
Entry
(
name
=
task_name
,
name
=
task_name
,
kind
=
Task
Kind
.
TASK
,
kind
=
Kind
.
TASK
,
yaml_path
=
path
,
yaml_path
=
path
,
tags
=
set
(
entry
.
get
(
"tag"
,
[]))
tags
=
TaskIndex
.
_str_to_set
(
cfg
.
get
(
"tag"
)),
if
isinstance
(
entry
,
dict
)
cfg
=
cfg
,
else
set
(),
)
)
self
.
_register_tags
(
task_name
,
entry
.
get
(
"tag"
,
[]
),
index
)
TaskIndex
.
_register_tags
(
task_name
,
entry
.
get
(
"tag"
),
index
)
return
return
# ---------------------------------------------------------------------------
@
staticmethod
def
_register_tags
(
self
,
task
:
str
,
tags
,
index
)
->
None
:
def
_register_tags
(
task
:
str
,
tags
:
str
|
list
[
str
]
|
None
,
index
:
dict
[
str
,
Entry
],
)
->
None
:
if
not
tags
:
return
for
tag
in
tags
if
isinstance
(
tags
,
list
)
else
[
tags
]:
for
tag
in
tags
if
isinstance
(
tags
,
list
)
else
[
tags
]:
if
not
tag
:
continue
entry
=
index
.
setdefault
(
entry
=
index
.
setdefault
(
tag
,
tag
,
Task
Entry
(
name
=
tag
,
kind
=
Task
Kind
.
TAG
,
yaml_path
=
None
,
tags
=
set
()),
Entry
(
name
=
tag
,
kind
=
Kind
.
TAG
,
yaml_path
=
None
,
tags
=
set
()),
)
)
entry
.
tags
.
add
(
task
)
# mutate ok; dataclass not frozen for TAG
entry
.
tags
.
add
(
task
)
@
staticmethod
@
staticmethod
def
_kind_of
(
cfg
:
dict
)
->
Task
Kind
:
def
_kind_of
(
cfg
:
dict
)
->
Kind
:
if
"class"
in
cfg
:
if
"class"
in
cfg
:
return
TaskKind
.
PY_TASK
return
Kind
.
PY_TASK
if
"group"
in
cfg
:
return
Kind
.
GROUP
if
"task_list"
in
cfg
:
if
"task_list"
in
cfg
:
return
Task
Kind
.
TASK_LIST
return
Kind
.
TASK_LIST
if
"task"
in
cfg
:
if
"task"
in
cfg
:
return
Task
Kind
.
GROUP
if
isinstance
(
cfg
[
"task"
],
list
)
else
Task
Kind
.
TASK
return
Kind
.
GROUP
if
isinstance
(
cfg
[
"task"
],
list
)
else
Kind
.
TASK
msg
=
"Unknown config shape"
msg
=
"Unknown config shape"
raise
ValueError
(
msg
)
raise
ValueError
(
msg
)
from
None
@
staticmethod
def
_str_to_set
(
tags
:
str
|
list
[
str
]
|
None
=
None
)
->
set
[
str
]:
"""Convert a string or list of strings to a set of strings."""
return
(
set
(
tags
)
if
isinstance
(
tags
,
list
)
else
{
tags
}
if
isinstance
(
tags
,
str
)
else
set
()
)
lm_eval/tasks/manager.py
0 → 100644
View file @
4254c7bd
from
__future__
import
annotations
from
collections
import
defaultdict
from
itertools
import
chain
from
pathlib
import
Path
from
typing
import
Any
from
lm_eval.tasks.factory
import
TaskFactory
from
lm_eval.tasks.index
import
Entry
,
Kind
,
TaskIndex
from
lm_eval.utils
import
setup_logging
class TaskManager:
    """Indexes task YAML files and builds tasks, groups and tags on request.

    Construction scans the configured directories with ``TaskIndex`` and keeps
    the resulting name -> ``Entry`` mapping; ``load_spec`` and
    ``load_task_or_group`` hand entries to a ``TaskFactory``.
    """

    def __init__(
        self,
        verbosity: str | None = None,
        include_path: str | Path | list[str | Path] | None = None,
        include_defaults: bool = True,
        metadata: dict[str, dict[str, Any]] | None = None,
    ) -> None:
        """Scan the task directories and sort the discovered names by kind.

        Args:
            verbosity: when truthy, passed to ``setup_logging``.
            include_path: extra directory, or list/tuple of directories, to scan.
            include_defaults: also scan the directory containing this module.
            metadata: forwarded to ``TaskFactory`` as ``meta``.
        """
        if verbosity:
            setup_logging(verbosity)

        indexer = TaskIndex()
        self._factory = TaskFactory(meta=metadata)

        # Built-in tasks first, then any user-supplied locations.
        search_roots: list[Path] = []
        if include_defaults:
            search_roots.append(Path(__file__).parent)
        if include_path:
            if isinstance(include_path, (list, tuple)):
                extra_roots = include_path
            else:
                extra_roots = [include_path]
            search_roots += [Path(p) for p in extra_roots]

        self._index = indexer.build(search_roots)

        # Group the registered names by the kind of entry they refer to.
        names_by_kind = defaultdict(list)
        for name, entry in self._index.items():
            names_by_kind[entry.kind].append(name)

        self._all_tasks = sorted(
            names_by_kind[Kind.TASK] + names_by_kind[Kind.PY_TASK]
        )
        self._all_groups = sorted(names_by_kind[Kind.GROUP])
        self._all_tags = sorted(names_by_kind[Kind.TAG])

    def _entry(self, name: str) -> Entry:
        """Return the index entry for ``name``; raise ``KeyError`` if it is not registered."""
        if name in self._index:
            return self._index[name]
        raise KeyError(f"Unknown task/group/tag: {name}")

    def load_spec(self, spec: str | dict[str, Any]):
        """Spec can be:
        • str    task / group / tag name (registered)
        • dict   inline overrides {'task': 'hellaswag', 'num_fewshot': 5}

        Raises ``TypeError`` for any other type.
        """
        if isinstance(spec, str):
            name, overrides = spec, None
        elif isinstance(spec, dict):
            # inline dict => find base entry, then pass the whole dict as overrides
            name, overrides = spec["task"], spec
        else:
            raise TypeError("spec must be str or dict")

        entry = self._entry(name)
        return self._factory.build(entry, overrides=overrides, registry=self._index)

    def load_task_or_group(self, task_list: str | list[str]):
        """Load every spec in ``task_list`` (a lone spec is treated as a one-item list)."""
        specs = task_list if isinstance(task_list, list) else [task_list]
        return [self.load_spec(s) for s in specs]
pyproject.toml
View file @
4254c7bd
...
@@ -103,7 +103,8 @@ plugins.md029.allow_extended_start_values = true # ol-prefix
...
@@ -103,7 +103,8 @@ plugins.md029.allow_extended_start_values = true # ol-prefix
plugins.md034.enabled
=
false
# no-bare-urls
plugins.md034.enabled
=
false
# no-bare-urls
[tool.ruff.lint]
[tool.ruff.lint]
extend-select
=
["I"]
select
=
[
"ASYNC"
,
"B"
,
"C4"
,
"E"
,
"F"
,
"I"
,
"LOG"
,
"PIE"
,
"PTH"
,
"SIM"
,
"UP"
,
"PERF"
,
"ISC001"
,
"ISC002"
,
"ICN001"
,
"C901"
,
"FURB"
,
"RUF"
]
ignore
=
[
"E501"
,
"E111"
,
"E114"
,
"E117"
,
"E501"
,
"PERF203"
,
"B011"
]
[tool.ruff.lint.isort]
[tool.ruff.lint.isort]
lines-after-imports
=
2
lines-after-imports
=
2
...
@@ -111,7 +112,6 @@ known-first-party = ["lm_eval"]
...
@@ -111,7 +112,6 @@ known-first-party = ["lm_eval"]
[tool.ruff.lint.extend-per-file-ignores]
[tool.ruff.lint.extend-per-file-ignores]
"__init__.py"
=
["F401","F402","F403"]
"__init__.py"
=
["F401","F402","F403"]
"utils.py"
=
["F401"]
[dependency-groups]
[dependency-groups]
dev
=
[
dev
=
[
...
...
tests/test_config_loader.py
View file @
4254c7bd
...
@@ -20,7 +20,7 @@ Test coverage:
...
@@ -20,7 +20,7 @@ Test coverage:
- load():
- load():
- test_load_simple_yaml: basic YAML parsing
- test_load_simple_yaml: basic YAML parsing
- test_load_with_function_resolved: !function tags resolved to callables
- test_load_with_function_resolved: !function tags resolved to callables
- test_load_with_function_not_resolved: !function tags become
no-op lambda
s
- test_load_with_function_not_resolved: !function tags become
string
s
- test_load_with_includes: include files merged, main values win
- test_load_with_includes: include files merged, main values win
- test_load_with_absolute_include: absolute path includes
- test_load_with_absolute_include: absolute path includes
- test_load_without_includes_resolution: includes preserved when disabled
- test_load_without_includes_resolution: includes preserved when disabled
...
@@ -38,9 +38,10 @@ import pytest
...
@@ -38,9 +38,10 @@ import pytest
from
lm_eval.tasks._config_loader
import
(
from
lm_eval.tasks._config_loader
import
(
_Base
,
_Base
,
_import_func
tion
,
_import_func
_in_yml
,
_make_loader
,
_make_loader
,
_mk_function_ctor
,
_mk_function_ctor
,
import_fun_from_str
,
load_yaml
,
load_yaml
,
)
)
...
@@ -75,7 +76,7 @@ class TestMkFunctionCtor:
...
@@ -75,7 +76,7 @@ class TestMkFunctionCtor:
"""Tests for the YAML !function constructor factory."""
"""Tests for the YAML !function constructor factory."""
def
test_mk_function_ctor_with_resolve_false
(
self
,
temp_dir
):
def
test_mk_function_ctor_with_resolve_false
(
self
,
temp_dir
):
"""When resolve=False, should return a
no-op lambda
."""
"""When resolve=False, should return a
string
."""
ctor
=
_mk_function_ctor
(
temp_dir
,
resolve
=
False
)
ctor
=
_mk_function_ctor
(
temp_dir
,
resolve
=
False
)
loader
=
MagicMock
()
loader
=
MagicMock
()
...
@@ -84,8 +85,7 @@ class TestMkFunctionCtor:
...
@@ -84,8 +85,7 @@ class TestMkFunctionCtor:
result
=
ctor
(
loader
,
node
)
result
=
ctor
(
loader
,
node
)
assert
callable
(
result
)
assert
isinstance
(
result
,
str
)
assert
result
(
"arg1"
,
kwarg
=
"value"
)
is
None
def
test_mk_function_ctor_with_resolve_true
(
self
,
temp_dir
,
python_module
):
def
test_mk_function_ctor_with_resolve_true
(
self
,
temp_dir
,
python_module
):
"""When resolve=True, should import and return the actual function."""
"""When resolve=True, should import and return the actual function."""
...
@@ -136,7 +136,7 @@ class TestImportFunction:
...
@@ -136,7 +136,7 @@ class TestImportFunction:
# Create a local module
# Create a local module
python_module
(
"def local_func(x, y):
\n
return x + y
\n
"
)
python_module
(
"def local_func(x, y):
\n
return x + y
\n
"
)
func
=
_import_func
tion
(
"utils.local_func"
,
temp_dir
)
func
=
_import_func
_in_yml
(
"utils.local_func"
,
temp_dir
)
assert
callable
(
func
)
assert
callable
(
func
)
assert
func
(
2
,
3
)
==
5
assert
func
(
2
,
3
)
==
5
...
@@ -149,7 +149,7 @@ class TestImportFunction:
...
@@ -149,7 +149,7 @@ class TestImportFunction:
"def nested_func():
\n
return 'nested'
\n
"
"def nested_func():
\n
return 'nested'
\n
"
)
)
func
=
_import_func
tion
(
"sub.module.nested_func"
,
temp_dir
)
func
=
_import_func
_in_yml
(
"sub.module.nested_func"
,
temp_dir
)
assert
callable
(
func
)
assert
callable
(
func
)
assert
func
()
==
"nested"
assert
func
()
==
"nested"
...
@@ -157,19 +157,19 @@ class TestImportFunction:
...
@@ -157,19 +157,19 @@ class TestImportFunction:
def
test_import_standard_module
(
self
,
temp_dir
):
def
test_import_standard_module
(
self
,
temp_dir
):
"""Falls back to standard import for non-local modules."""
"""Falls back to standard import for non-local modules."""
# Import from standard library
# Import from standard library
func
=
_import_func
tion
(
"os.path.join"
,
temp_dir
)
func
=
_import_func
_in_yml
(
"os.path.join"
,
temp_dir
)
assert
callable
(
func
)
assert
callable
(
func
)
assert
func
(
"a"
,
"b"
)
in
(
"a/b"
,
"a
\\
b"
)
# Unix or Windows
assert
func
(
"a"
,
"b"
)
in
(
"a/b"
,
"a
\\
b"
)
# Unix or Windows
def
test_import_caching
(
self
,
temp_dir
,
python_module
):
def
test_import_caching
(
self
,
temp_dir
,
python_module
):
# Clear cache first
# Clear cache first
_import_func
tion
.
cache_clear
()
_import_func
_in_yml
.
cache_clear
()
python_module
(
"def cached_func():
\n
return 42
\n
"
)
python_module
(
"def cached_func():
\n
return 42
\n
"
)
func1
=
_import_func
tion
(
"utils.cached_func"
,
temp_dir
)
func1
=
_import_func
_in_yml
(
"utils.cached_func"
,
temp_dir
)
func2
=
_import_func
tion
(
"utils.cached_func"
,
temp_dir
)
func2
=
_import_func
_in_yml
(
"utils.cached_func"
,
temp_dir
)
assert
func1
is
func2
# Cached
assert
func1
is
func2
# Cached
...
@@ -177,7 +177,7 @@ class TestImportFunction:
...
@@ -177,7 +177,7 @@ class TestImportFunction:
"""Verifies LRU cache behavior - file changes require cache clear."""
"""Verifies LRU cache behavior - file changes require cache clear."""
# Clear the LRU cache
# Clear the LRU cache
_import_func
tion
.
cache_clear
()
_import_func
_in_yml
.
cache_clear
()
# Create a module
# Create a module
module_path
=
temp_dir
/
"test_mtime.py"
module_path
=
temp_dir
/
"test_mtime.py"
...
@@ -185,17 +185,102 @@ class TestImportFunction:
...
@@ -185,17 +185,102 @@ class TestImportFunction:
# Import it
# Import it
import_key
=
"test_mtime.value"
import_key
=
"test_mtime.value"
value1
=
_import_func
tion
(
import_key
,
temp_dir
)
value1
=
_import_func
_in_yml
(
import_key
,
temp_dir
)
assert
value1
==
1
assert
value1
==
1
value2
=
_import_func
tion
(
import_key
,
temp_dir
)
value2
=
_import_func
_in_yml
(
import_key
,
temp_dir
)
assert
value2
==
1
# From cache
assert
value2
==
1
# From cache
_import_func
tion
.
cache_clear
()
_import_func
_in_yml
.
cache_clear
()
value3
=
_import_func
tion
(
import_key
,
temp_dir
)
value3
=
_import_func
_in_yml
(
import_key
,
temp_dir
)
assert
value3
==
1
# Re-imported
assert
value3
==
1
# Re-imported
class TestImportFunFromStr:
    """Tests for ``import_fun_from_str``.

    Each test passes a single string of the form ``<module path>.<name>``
    and checks either the object that comes back or the exception raised.
    ``temp_dir`` is used as a path-like scratch directory (it supports ``/``,
    ``write_text`` and ``mkdir``); presumably it is a pytest fixture defined
    elsewhere in this file.
    """

    def test_import_from_absolute_path(self, temp_dir):
        """A function is importable via ``<absolute path without .py>.<name>``."""
        # Create a test module
        module_path = temp_dir / "test_module.py"
        module_path.write_text("def test_func(x):\n    return x * 2\n")

        # Import using the absolute path with the ``.py`` suffix stripped
        func = import_fun_from_str(f"{module_path.with_suffix('')}.test_func")

        assert callable(func)
        assert func(5) == 10

    def test_import_with_py_extension(self, temp_dir):
        """The spec may also keep the ``.py`` suffix in the module path."""
        # Create a test module
        module_path = temp_dir / "test_module.py"
        module_path.write_text("def test_func(x):\n    return x + 10\n")

        # Import with ``.py`` left in the path
        func = import_fun_from_str(f"{module_path}.test_func")

        assert callable(func)
        assert func(5) == 15

    def test_import_nested_function(self, temp_dir):
        """A module that lives in a sub-directory is importable."""
        # Create nested directory
        (temp_dir / "subdir").mkdir()
        module_path = temp_dir / "subdir" / "nested.py"
        module_path.write_text("def nested_func():\n    return 'nested'\n")

        # Import from the nested path
        func = import_fun_from_str(f"{module_path.with_suffix('')}.nested_func")

        assert callable(func)
        assert func() == "nested"

    def test_import_missing_module(self, temp_dir):
        """A path to a non-existent module file raises ``ImportError``."""
        with pytest.raises(ImportError, match="Module file not found"):
            import_fun_from_str(f"{temp_dir}/nonexistent.test_func")

    def test_import_missing_function(self, temp_dir):
        """A name absent from an existing module raises ``AttributeError``."""
        module_path = temp_dir / "test_module.py"
        module_path.write_text("def other_func():\n    pass\n")

        with pytest.raises(AttributeError, match="Function 'missing_func' not found"):
            import_fun_from_str(f"{module_path.with_suffix('')}.missing_func")

    def test_import_invalid_format(self):
        """A spec with no ``.<name>`` part raises ``ValueError``."""
        with pytest.raises(ValueError, match="Invalid path format"):
            import_fun_from_str("/path/without/function")

    def test_import_caching(self, temp_dir):
        """Importing the same spec twice yields the same function object.

        The module keeps a global call counter, so a shared module instance
        is observable through the counter continuing from 1 to 2. The file
        is not modified between the two imports, so mtime-based invalidation
        is not exercised here.
        """
        import sys

        # Drop any ``sys.modules`` entries whose key contains this temp dir,
        # so the first import below starts from a clean state.
        dir_marker = str(temp_dir)
        stale_keys = [key for key in sys.modules if dir_marker in key]
        for key in stale_keys:
            del sys.modules[key]

        module_path = temp_dir / "cached_module.py"
        module_path.write_text(
            "call_count = 0\n"
            "def func():\n"
            "    global call_count\n"
            "    call_count += 1\n"
            "    return call_count\n"
        )
        spec = f"{module_path.with_suffix('')}.func"

        # First import
        func1 = import_fun_from_str(spec)
        result1 = func1()

        # Second import should use the cached module
        func2 = import_fun_from_str(spec)
        result2 = func2()

        # Both should refer to the same module instance
        assert func1 is func2
        assert result1 == 1  # fresh module: counter started at 0
        assert result2 == 2  # call_count incremented on the shared module
class
TestLoad
:
class
TestLoad
:
"""Tests for the main YAML loading function with includes and function resolution."""
"""Tests for the main YAML loading function with includes and function resolution."""
...
@@ -237,8 +322,10 @@ doc_to_text: !function utils.process_doc
...
@@ -237,8 +322,10 @@ doc_to_text: !function utils.process_doc
result
=
load_yaml
(
file_path
,
resolve_functions
=
False
)
result
=
load_yaml
(
file_path
,
resolve_functions
=
False
)
assert
callable
(
result
[
"doc_to_text"
])
assert
isinstance
(
result
[
"doc_to_text"
],
str
)
assert
result
[
"doc_to_text"
](
"hello"
)
is
None
# No-op lambda
# When resolve_functions=False, it returns the full path + function spec
assert
result
[
"doc_to_text"
].
endswith
(
"utils.process_doc"
)
assert
result
[
"doc_to_text"
]
==
str
(
file_path
.
parent
/
"utils.process_doc"
)
def
test_load_with_includes
(
self
,
temp_dir
,
yaml_file
):
def
test_load_with_includes
(
self
,
temp_dir
,
yaml_file
):
"""Include files are merged with local values taking precedence."""
"""Include files are merged with local values taking precedence."""
...
@@ -388,3 +475,7 @@ shared_key: from_main
...
@@ -388,3 +475,7 @@ shared_key: from_main
mock_expand
.
assert_called_once
()
mock_expand
.
assert_called_once
()
assert
result
[
"test"
]
==
"value"
assert
result
[
"test"
]
==
"value"
if __name__ == "__main__":
    # Allow running this test file directly as a script. Hand pytest's return
    # value to SystemExit so that failing tests produce a non-zero exit
    # status instead of being silently discarded.
    raise SystemExit(pytest.main([__file__, "-v", "--tb=short"]))
tests/test_task_index.py
View file @
4254c7bd
"""
"""Tests for the task index builder that discovers YAML task configurations.
Tests for the task index builder that discovers YAML task configurations.
Test coverage:
Test coverage:
- TaskIndexBuilder._kind_of: identifies task/group/tag/task_list/py_task
- TaskIndexBuilder._kind_of: identifies task/group/tag/task_list/py_task
...
@@ -14,7 +13,7 @@ from pathlib import Path
...
@@ -14,7 +13,7 @@ from pathlib import Path
import
pytest
import
pytest
from
lm_eval.tasks._task_index
import
TaskIndex
Builder
,
TaskKind
from
lm_eval.tasks._task_index
import
TaskIndex
,
TaskKind
@
pytest
.
fixture
@
pytest
.
fixture
...
@@ -40,28 +39,28 @@ class TestTaskKindOf:
...
@@ -40,28 +39,28 @@ class TestTaskKindOf:
def
test_kind_of_task
(
self
):
def
test_kind_of_task
(
self
):
"""Single task with string name."""
"""Single task with string name."""
cfg
=
{
"task"
:
"my_task"
,
"dataset_path"
:
"data"
}
cfg
=
{
"task"
:
"my_task"
,
"dataset_path"
:
"data"
}
assert
TaskIndex
Builder
.
_kind_of
(
cfg
)
==
TaskKind
.
TASK
assert
TaskIndex
.
_kind_of
(
cfg
)
==
TaskKind
.
TASK
def
test_kind_of_group
(
self
):
def
test_kind_of_group
(
self
):
"""Group has task as list."""
"""Group has task as list."""
cfg
=
{
"task"
:
[
"task1"
,
"task2"
],
"group"
:
"my_group"
}
cfg
=
{
"task"
:
[
"task1"
,
"task2"
],
"group"
:
"my_group"
}
assert
TaskIndex
Builder
.
_kind_of
(
cfg
)
==
TaskKind
.
GROUP
assert
TaskIndex
.
_kind_of
(
cfg
)
==
TaskKind
.
GROUP
def
test_kind_of_py_task
(
self
):
def
test_kind_of_py_task
(
self
):
"""Python task has class field."""
"""Python task has class field."""
cfg
=
{
"task"
:
"my_task"
,
"class"
:
"tasks.MyTask"
}
cfg
=
{
"task"
:
"my_task"
,
"class"
:
"tasks.MyTask"
}
assert
TaskIndex
Builder
.
_kind_of
(
cfg
)
==
TaskKind
.
PY_TASK
assert
TaskIndex
.
_kind_of
(
cfg
)
==
TaskKind
.
PY_TASK
def
test_kind_of_task_list
(
self
):
def
test_kind_of_task_list
(
self
):
"""Task list has task_list field."""
"""Task list has task_list field."""
cfg
=
{
"task_list"
:
[
"task1"
,
"task2"
]}
cfg
=
{
"task_list"
:
[
"task1"
,
"task2"
]}
assert
TaskIndex
Builder
.
_kind_of
(
cfg
)
==
TaskKind
.
TASK_LIST
assert
TaskIndex
.
_kind_of
(
cfg
)
==
TaskKind
.
TASK_LIST
def
test_kind_of_unknown
(
self
):
def
test_kind_of_unknown
(
self
):
"""Unknown config raises ValueError."""
"""Unknown config raises ValueError."""
cfg
=
{
"unknown"
:
"field"
}
cfg
=
{
"unknown"
:
"field"
}
with
pytest
.
raises
(
ValueError
,
match
=
"Unknown config shape"
):
with
pytest
.
raises
(
ValueError
,
match
=
"Unknown config shape"
):
TaskIndex
Builder
.
_kind_of
(
cfg
)
TaskIndex
.
_kind_of
(
cfg
)
class
TestIterYamlFiles
:
class
TestIterYamlFiles
:
...
@@ -75,8 +74,8 @@ class TestIterYamlFiles:
...
@@ -75,8 +74,8 @@ class TestIterYamlFiles:
(
temp_dir
/
"subdir"
/
"task2.yaml"
).
touch
()
(
temp_dir
/
"subdir"
/
"task2.yaml"
).
touch
()
(
temp_dir
/
"other.txt"
).
touch
()
(
temp_dir
/
"other.txt"
).
touch
()
builder
=
TaskIndex
Builder
()
builder
=
TaskIndex
()
yaml_files
=
list
(
builder
.
_iter_yaml_files
(
temp_dir
))
yaml_files
=
list
(
builder
.
_iter_yaml_files
())
assert
len
(
yaml_files
)
==
2
assert
len
(
yaml_files
)
==
2
names
=
{
f
.
name
for
f
in
yaml_files
}
names
=
{
f
.
name
for
f
in
yaml_files
}
...
@@ -90,8 +89,8 @@ class TestIterYamlFiles:
...
@@ -90,8 +89,8 @@ class TestIterYamlFiles:
(
temp_dir
/
".ipynb_checkpoints"
).
mkdir
()
(
temp_dir
/
".ipynb_checkpoints"
).
mkdir
()
(
temp_dir
/
".ipynb_checkpoints"
/
"also_ignored.yaml"
).
touch
()
(
temp_dir
/
".ipynb_checkpoints"
/
"also_ignored.yaml"
).
touch
()
builder
=
TaskIndex
Builder
()
builder
=
TaskIndex
()
yaml_files
=
list
(
builder
.
_iter_yaml_files
(
temp_dir
))
yaml_files
=
list
(
builder
.
_iter_yaml_files
())
assert
len
(
yaml_files
)
==
1
assert
len
(
yaml_files
)
==
1
assert
yaml_files
[
0
].
name
==
"task.yaml"
assert
yaml_files
[
0
].
name
==
"task.yaml"
...
@@ -106,8 +105,8 @@ class TestProcessCfg:
...
@@ -106,8 +105,8 @@ class TestProcessCfg:
path
=
temp_dir
/
"task.yaml"
path
=
temp_dir
/
"task.yaml"
index
=
{}
index
=
{}
builder
=
TaskIndex
Builder
()
builder
=
TaskIndex
()
builder
.
_
process_cfg
(
cfg
,
path
,
index
)
builder
.
process_cfg
(
cfg
,
path
,
index
)
assert
"my_task"
in
index
assert
"my_task"
in
index
entry
=
index
[
"my_task"
]
entry
=
index
[
"my_task"
]
...
@@ -122,8 +121,8 @@ class TestProcessCfg:
...
@@ -122,8 +121,8 @@ class TestProcessCfg:
path
=
temp_dir
/
"group.yaml"
path
=
temp_dir
/
"group.yaml"
index
=
{}
index
=
{}
builder
=
TaskIndex
Builder
()
builder
=
TaskIndex
()
builder
.
_
process_cfg
(
cfg
,
path
,
index
)
builder
.
process_cfg
(
cfg
,
path
,
index
)
assert
"my_group"
in
index
assert
"my_group"
in
index
entry
=
index
[
"my_group"
]
entry
=
index
[
"my_group"
]
...
@@ -138,8 +137,8 @@ class TestProcessCfg:
...
@@ -138,8 +137,8 @@ class TestProcessCfg:
path
=
temp_dir
/
"py_task.yaml"
path
=
temp_dir
/
"py_task.yaml"
index
=
{}
index
=
{}
builder
=
TaskIndex
Builder
()
builder
=
TaskIndex
()
builder
.
_
process_cfg
(
cfg
,
path
,
index
)
builder
.
process_cfg
(
cfg
,
path
,
index
)
assert
"py_task"
in
index
assert
"py_task"
in
index
entry
=
index
[
"py_task"
]
entry
=
index
[
"py_task"
]
...
@@ -154,27 +153,30 @@ class TestProcessCfg:
...
@@ -154,27 +153,30 @@ class TestProcessCfg:
"task_list"
:
[
"task_list"
:
[
"simple_task"
,
"simple_task"
,
{
"task"
:
"complex_task"
,
"tag"
:
[
"tag1"
,
"tag2"
]},
{
"task"
:
"complex_task"
,
"tag"
:
[
"tag1"
,
"tag2"
]},
]
]
,
}
}
path
=
temp_dir
/
"list.yaml"
path
=
temp_dir
/
"list.yaml"
index
=
{}
index
=
{}
builder
=
TaskIndex
Builder
()
builder
=
TaskIndex
()
# The implementation has a bug - it calls entry.get() on string entries
# The implementation has a bug - it calls entry.get() on string entries
# This test documents the current behavior which will fail
# This test documents the current behavior which will fail
with
pytest
.
raises
(
AttributeError
,
match
=
"'str' object has no attribute 'get'"
):
with
pytest
.
raises
(
AttributeError
,
match
=
"'str' object has no attribute 'get'"
):
builder
.
_
process_cfg
(
cfg
,
path
,
index
)
builder
.
process_cfg
(
cfg
,
path
,
index
)
def
test_process_task_list_dict_entries
(
self
,
temp_dir
):
def
test_process_task_list_dict_entries
(
self
,
temp_dir
):
"""Task list with only dict entries works."""
"""Task list with only dict entries works."""
cfg
=
{
cfg
=
{
"task_list"
:
[{
"task"
:
"task1"
},
{
"task"
:
"task2"
,
"tag"
:
[
"tag1"
,
"tag2"
]}]
"task_list"
:
[
{
"task"
:
"task1"
},
{
"task"
:
"task2"
,
"tag"
:
[
"tag1"
,
"tag2"
]},
],
}
}
path
=
temp_dir
/
"list.yaml"
path
=
temp_dir
/
"list.yaml"
index
=
{}
index
=
{}
builder
=
TaskIndex
Builder
()
builder
=
TaskIndex
()
builder
.
_
process_cfg
(
cfg
,
path
,
index
)
builder
.
process_cfg
(
cfg
,
path
,
index
)
# Task without tags
# Task without tags
assert
"task1"
in
index
assert
"task1"
in
index
...
@@ -197,7 +199,7 @@ class TestRegisterTags:
...
@@ -197,7 +199,7 @@ class TestRegisterTags:
def
test_register_single_tag
(
self
):
def
test_register_single_tag
(
self
):
"""Single tag creates TAG entry."""
"""Single tag creates TAG entry."""
index
=
{}
index
=
{}
builder
=
TaskIndex
Builder
()
builder
=
TaskIndex
()
builder
.
_register_tags
(
"task1"
,
"my_tag"
,
index
)
builder
.
_register_tags
(
"task1"
,
"my_tag"
,
index
)
...
@@ -210,7 +212,7 @@ class TestRegisterTags:
...
@@ -210,7 +212,7 @@ class TestRegisterTags:
def
test_register_multiple_tags
(
self
):
def
test_register_multiple_tags
(
self
):
"""Multiple tags create multiple TAG entries."""
"""Multiple tags create multiple TAG entries."""
index
=
{}
index
=
{}
builder
=
TaskIndex
Builder
()
builder
=
TaskIndex
()
builder
.
_register_tags
(
"task1"
,
[
"tag1"
,
"tag2"
],
index
)
builder
.
_register_tags
(
"task1"
,
[
"tag1"
,
"tag2"
],
index
)
...
@@ -222,7 +224,7 @@ class TestRegisterTags:
...
@@ -222,7 +224,7 @@ class TestRegisterTags:
def
test_register_tags_accumulates
(
self
):
def
test_register_tags_accumulates
(
self
):
"""Multiple tasks can have same tag."""
"""Multiple tasks can have same tag."""
index
=
{}
index
=
{}
builder
=
TaskIndex
Builder
()
builder
=
TaskIndex
()
builder
.
_register_tags
(
"task1"
,
"shared_tag"
,
index
)
builder
.
_register_tags
(
"task1"
,
"shared_tag"
,
index
)
builder
.
_register_tags
(
"task2"
,
"shared_tag"
,
index
)
builder
.
_register_tags
(
"task2"
,
"shared_tag"
,
index
)
...
@@ -237,7 +239,7 @@ class TestBuild:
...
@@ -237,7 +239,7 @@ class TestBuild:
def
test_build_empty_directory
(
self
,
temp_dir
):
def
test_build_empty_directory
(
self
,
temp_dir
):
"""Empty directory returns empty index."""
"""Empty directory returns empty index."""
builder
=
TaskIndex
Builder
()
builder
=
TaskIndex
()
index
=
builder
.
build
([
temp_dir
])
index
=
builder
.
build
([
temp_dir
])
assert
index
==
{}
assert
index
==
{}
...
@@ -245,7 +247,7 @@ class TestBuild:
...
@@ -245,7 +247,7 @@ class TestBuild:
"""Single task file is discovered."""
"""Single task file is discovered."""
yaml_file
(
"task: my_task
\n
dataset_path: data
\n
"
)
yaml_file
(
"task: my_task
\n
dataset_path: data
\n
"
)
builder
=
TaskIndex
Builder
()
builder
=
TaskIndex
()
index
=
builder
.
build
([
temp_dir
])
index
=
builder
.
build
([
temp_dir
])
assert
len
(
index
)
==
1
assert
len
(
index
)
==
1
...
@@ -269,7 +271,7 @@ class TestBuild:
...
@@ -269,7 +271,7 @@ class TestBuild:
# Python task
# Python task
yaml_file
(
"task: py_task
\n
class: MyClass
\n
"
,
"python.yaml"
)
yaml_file
(
"task: py_task
\n
class: MyClass
\n
"
,
"python.yaml"
)
builder
=
TaskIndex
Builder
()
builder
=
TaskIndex
()
index
=
builder
.
build
([
temp_dir
])
index
=
builder
.
build
([
temp_dir
])
# Check all entries exist
# Check all entries exist
...
@@ -297,7 +299,7 @@ class TestBuild:
...
@@ -297,7 +299,7 @@ class TestBuild:
yaml_file
(
"task: sub_task
\n
"
,
"subdir/sub.yaml"
)
yaml_file
(
"task: sub_task
\n
"
,
"subdir/sub.yaml"
)
yaml_file
(
"task: deep_task
\n
"
,
"subdir/deeper/deep.yaml"
)
yaml_file
(
"task: deep_task
\n
"
,
"subdir/deeper/deep.yaml"
)
builder
=
TaskIndex
Builder
()
builder
=
TaskIndex
()
index
=
builder
.
build
([
temp_dir
])
index
=
builder
.
build
([
temp_dir
])
assert
len
(
index
)
==
3
assert
len
(
index
)
==
3
...
@@ -308,7 +310,7 @@ class TestBuild:
...
@@ -308,7 +310,7 @@ class TestBuild:
yaml_file
(
"task: valid_task
\n
"
,
"valid.yaml"
)
yaml_file
(
"task: valid_task
\n
"
,
"valid.yaml"
)
yaml_file
(
"invalid: [
\n
"
,
"invalid.yaml"
)
# Invalid YAML
yaml_file
(
"invalid: [
\n
"
,
"invalid.yaml"
)
# Invalid YAML
builder
=
TaskIndex
Builder
()
builder
=
TaskIndex
()
index
=
builder
.
build
([
temp_dir
])
index
=
builder
.
build
([
temp_dir
])
assert
len
(
index
)
==
1
assert
len
(
index
)
==
1
...
@@ -325,7 +327,7 @@ class TestBuild:
...
@@ -325,7 +327,7 @@ class TestBuild:
(
dir1
/
"task1.yaml"
).
write_text
(
"task: task1
\n
"
)
(
dir1
/
"task1.yaml"
).
write_text
(
"task: task1
\n
"
)
(
dir2
/
"task2.yaml"
).
write_text
(
"task: task2
\n
"
)
(
dir2
/
"task2.yaml"
).
write_text
(
"task: task2
\n
"
)
builder
=
TaskIndex
Builder
()
builder
=
TaskIndex
()
index
=
builder
.
build
([
dir1
,
dir2
])
index
=
builder
.
build
([
dir1
,
dir2
])
assert
len
(
index
)
==
2
assert
len
(
index
)
==
2
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment