gaoqiong / lm-evaluation-harness
Commit 7fcfb4ac authored Jul 12, 2025 by Baber

refactor: simplify docstrings and improve task name matching logic

Parent: 5e632643

Showing 2 changed files with 215 additions and 337 deletions (+215 -337)
lm_eval/api/group.py       +7   -2
lm_eval/tasks/__init__.py  +208 -335
lm_eval/api/group.py @ 7fcfb4ac
-import abc
 from dataclasses import asdict, dataclass
 from inspect import getsource
 from typing import Any, Callable, List, Optional, Union
...
@@ -83,8 +82,14 @@ class GroupConfig(dict):
         except (TypeError, OSError):
             return str(value)

+    @property
+    def version(self) -> str:
+        """Returns the version of the group configuration."""
+        return self.metadata.get("version", "1.0")
+

-class ConfigurableGroup(abc.ABC):
+@dataclass
+class ConfigurableGroup:
     def __init__(
         self,
         config: Optional[dict] = None,
...
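The two changes in this file go together: ConfigurableGroup drops its abc.ABC base for a plain @dataclass, and GroupConfig gains a metadata-backed version property with a "1.0" fallback. A minimal standalone sketch of that property pattern (DemoGroupConfig is illustrative, not the real class):

from dataclasses import dataclass, field


@dataclass
class DemoGroupConfig:
    # Illustrative stand-in for GroupConfig: metadata is a plain dict,
    # and version falls back to "1.0" when the key is absent.
    metadata: dict = field(default_factory=dict)

    @property
    def version(self) -> str:
        """Returns the version of the group configuration."""
        return self.metadata.get("version", "1.0")


print(DemoGroupConfig().version)                             # -> "1.0"
print(DemoGroupConfig(metadata={"version": "2.3"}).version)  # -> "2.3"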
lm_eval/tasks/__init__.py @ 7fcfb4ac
...
@@ -73,19 +73,7 @@ _IGNORE_DIRS = (
 def ignore_constructor(loader: yaml.Loader, node: yaml.Node) -> None:
-    """
-    YAML constructor that ignores !function tags during simple parsing.
-
-    This is used when mode="simple" to skip function resolution for
-    faster indexing operations.
-
-    Args:
-        loader: YAML loader instance
-        node: YAML node being processed
-
-    Returns:
-        None
-    """
+    """YAML constructor that ignores !function tags during simple parsing."""
     return None
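The constructor above only takes effect once it is registered with a YAML loader. A minimal, self-contained sketch of how such a registration typically looks with PyYAML; the SimpleLoader class and the utils.process_docs target are illustrative, not the harness's exact wiring:

import yaml


def ignore_constructor(loader: yaml.Loader, node: yaml.Node) -> None:
    """YAML constructor that ignores !function tags during simple parsing."""
    return None


class SimpleLoader(yaml.SafeLoader):
    # Illustrative "simple mode" loader: any '!function ...' node parses to None,
    # so a fast indexing pass never imports the referenced Python module.
    pass


SimpleLoader.add_constructor("!function", ignore_constructor)

doc = yaml.load("process_docs: !function utils.process_docs", Loader=SimpleLoader)
print(doc)  # {'process_docs': None}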
...
@@ -129,8 +117,7 @@ def _import_function(qualname: str, *, base_path: Path) -> Callable:
     Dynamically import a function from a Python module relative to base_path.

     This function enables YAML files to reference Python functions using
-    the !function tag. It supports dot notation for nested modules and
-    caches imported modules for performance.
+    the !function tag. Supports dot notation for nested modules.

     Args:
         qualname: Qualified function name like "my_module.my_function"
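_import_function itself is outside this hunk; the following standalone sketch shows roughly what resolving a dotted qualname such as "my_module.my_function" relative to a base path involves with importlib. Function and variable names are illustrative, not the harness implementation:

import importlib.util
from pathlib import Path
from typing import Callable


def import_function_sketch(qualname: str, *, base_path: Path) -> Callable:
    # Assumes a dotted name: "my_module.my_function" -> module file
    # "my_module.py" next to the YAML, attribute "my_function" on that module.
    module_name, _, func_name = qualname.rpartition(".")
    module_path = base_path / (module_name.replace(".", "/") + ".py")
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return getattr(module, func_name)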
...
@@ -180,7 +167,7 @@ def _parse_yaml_file(path: Path, mode: str) -> dict:
 @functools.lru_cache(maxsize=4096)
 def _get_cached_config(yaml_path: Path, mode: str) -> dict:
-    """Load and cache resolved YAML configs
-    with LRU eviction.
-    """
+    """Load and cache resolved YAML configs"""
     # Parse the YAML file
     yaml_config = _parse_yaml_file(yaml_path, mode)
     yaml_dir = yaml_path.parent
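Because _get_cached_config sits behind functools.lru_cache(maxsize=4096), repeated lookups of the same (yaml_path, mode) pair never re-parse the file. A toy stand-in showing that standard-library behavior; load_config_cached is illustrative, not the harness function:

import functools
from pathlib import Path


@functools.lru_cache(maxsize=4096)
def load_config_cached(yaml_path: Path, mode: str) -> dict:
    # Stand-in for _get_cached_config: pretend parsing is expensive.
    print(f"parsing {yaml_path} ({mode})")
    return {"task": yaml_path.stem, "mode": mode}


load_config_cached(Path("tasks/hellaswag.yaml"), "full")  # parses, prints once
load_config_cached(Path("tasks/hellaswag.yaml"), "full")  # cache hit, no print
print(load_config_cached.cache_info())  # CacheInfo(hits=1, misses=1, maxsize=4096, currsize=1)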
...
@@ -288,7 +275,7 @@ def load_yaml_config(
     return final_cfg


-def iter_yaml_files(root: Path) -> Generator[Path, Any, None]:
+def iter_yaml_files(root: Path, ignore=_IGNORE_DIRS) -> Generator[Path, Any, None]:
     """
     Recursively iterate over all YAML files in a directory tree.
...
@@ -306,7 +293,7 @@ def iter_yaml_files(root: Path) -> Generator[Path, Any, None]:
     """
     for p in iglob("**/*.yaml", root_dir=root, recursive=True):
         # ignore check
-        if Path(p).parts[0] in _IGNORE_DIRS:
+        if Path(p).parts[0] in ignore:
             continue
         yield root / p
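The new ignore parameter lets callers swap out the default _IGNORE_DIRS filter. A hedged usage sketch, assuming iter_yaml_files is importable from lm_eval.tasks and noting that glob.iglob(..., root_dir=...) requires Python 3.10+; the directory paths are only examples:

from pathlib import Path

from lm_eval.tasks import iter_yaml_files

# Default behaviour: skip anything under the _IGNORE_DIRS prefixes.
for yaml_file in iter_yaml_files(Path("lm_eval/tasks")):
    print(yaml_file)

# Custom filter: e.g. only skip __pycache__ while indexing a private task dir.
for yaml_file in iter_yaml_files(Path("/custom/tasks"), ignore=("__pycache__",)):
    print(yaml_file)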
...
@@ -352,7 +339,7 @@ class TaskManager:
         verbosity: Optional[str] = None,
         include_path: Optional[Union[str, Path, list[Union[str, Path]]]] = None,
         include_defaults: bool = True,
-        metadata: Optional[dict] = None,
+        metadata: Optional[dict[str, dict[str, Any]]] = None,
     ) -> None:
         """
         Initialize the TaskManager.
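The tightened metadata type hints at how this mapping is used later in the diff: it is merged into each task's config with the dict-union operator (cfg.get("metadata", {}) | self.metadata), where keys from the right-hand operand win, so TaskManager-level metadata overrides per-task metadata on conflicts. A standalone illustration of that Python 3.9+ semantics; the keys are made up for the example:

task_metadata = {"version": 1.0, "source": "task yaml"}
manager_metadata = {"source": "task manager", "run_id": "exp-42"}  # illustrative keys

# With the | operator the right-hand dict wins on duplicate keys.
merged = task_metadata | manager_metadata
print(merged)
# {'version': 1.0, 'source': 'task manager', 'run_id': 'exp-42'}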
...
@@ -548,21 +535,7 @@ class TaskManager:
         return "".join(parts)

     def match_tasks(self, task_list: list[str]) -> list[str]:
-        """
-        Match task names using pattern matching.
-
-        Supports glob-style patterns and returns all matching task names.
-
-        Args:
-            task_list: List of task name patterns to match
-
-        Returns:
-            List of matching task names
-
-        Example:
-            >>> tm.match_tasks(["hella*", "arc_*"])
-            ['hellaswag', 'arc_easy', 'arc_challenge']
-        """
+        """Match task names using glob-style pattern matching."""
         return pattern_match(task_list, self.all_tasks)

     def _name_is_registered(self, name: str) -> bool:
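pattern_match is defined elsewhere in the harness; as a rough standard-library approximation of the glob-style matching this docstring describes (task names taken from the removed example, helper name illustrative):

from fnmatch import fnmatch


def match_tasks_sketch(patterns: list[str], all_tasks: list[str]) -> list[str]:
    # Return every known task whose name matches any of the glob patterns.
    return sorted({t for t in all_tasks for p in patterns if fnmatch(t, p)})


known = ["hellaswag", "arc_easy", "arc_challenge", "mmlu"]
print(match_tasks_sketch(["hella*", "arc_*"], known))
# ['arc_challenge', 'arc_easy', 'hellaswag']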
...
@@ -738,276 +711,195 @@ class TaskManager:
             else False
         )

-    def _load_individual_task_or_group(
-        self,
-        name_or_config: Optional[Union[str, dict]] = None,
-        parent_name: Optional[str] = None,
-        update_config: Optional[dict] = None,
-    ) -> Mapping:
-        """
-        Load a single task or group with all its configurations and dependencies.
-
-        This is the core method for instantiating task objects from either task names
-        or configuration dictionaries. It handles complex scenarios including:
-        - Individual tasks and Python class-based tasks
-        - Groups and their constituent subtasks
-        - Tags and their associated tasks
-        - Configuration merging and inheritance
-        - Duplicate detection and name resolution
-        - Include processing and YAML inheritance
-
-        Args:
-            name_or_config: Either a task name (str) or configuration dict.
-                If str, looks up the task in the index.
-                If dict, processes as inline configuration.
-            parent_name: Name of parent group (for duplicate detection)
-            update_config: Additional configuration to merge into task configs
-
-        Returns:
-            Mapping of task/group names to instantiated task objects.
-            For individual tasks: {task_name: ConfigurableTask}
-            For groups: {group_name: {subtask1: Task1, subtask2: Task2, ...}}
-
-        Example:
-            Load individual task::
-                task_dict = tm._load_individual_task_or_group("hellaswag")
-                # Returns: {"hellaswag": ConfigurableTask(...)}
-            Load with config override::
-                task_dict = tm._load_individual_task_or_group(
-                    {"task": "hellaswag", "num_fewshot": 5}
-                )
-            Load a group::
-                group_dict = tm._load_individual_task_or_group("arc_group")
-                # Returns: {"arc_group": {"arc_easy": Task1, "arc_challenge": Task2}}
-        """
-        from lm_eval.api.task import ConfigurableTask, Task
-
-        def _load_task(
-            config: dict,
-            task: str,
-            yaml_path: Optional[str] = None,
-        ) -> dict[str, Union["ConfigurableTask", "Task"]]:
-            """
-            Create a single task object from configuration.
-
-            Handles include processing, Python class instantiation, and metadata injection.
-
-            Args:
-                config: Task configuration dictionary
-                task: Task name
-                yaml_path: Path to source YAML file (for include resolution)
-
-            Returns:
-                Dictionary mapping task name to instantiated task object
-            """
-            if "include" in config:
-                # Store the task name to preserve it after include processing
-                original_task_name = config.get("task", task)
-                config = {
-                    **load_yaml_config(
-                        yaml_path=Path(yaml_path),
-                        yaml_config={"include": config.pop("include")},
-                        mode="full" if yaml_path else "simple",
-                    ),
-                    **config,
-                    "task": original_task_name,
-                }
-                # Ensure the task name from the group config is preserved
-                # This prevents tasks with the same include from being treated as duplicates
-            if self._config_is_python_task(config):
-                if self._class_has_config_in_constructor(config["class"]):
-                    task_object = config["class"](config=config)
-                else:
-                    task_object = config["class"]()
-                if isinstance(task_object, ConfigurableTask):
-                    # very scuffed: set task name here. TODO: fixme?
-                    task_object.config.task = task
-            else:
-                if self.metadata is not None:
-                    config["metadata"] = config.get("metadata", {}) | self.metadata
-                else:
-                    config["metadata"] = config.get("metadata", {})
-                task_object = ConfigurableTask(config=config)
-            return {task: task_object}
-
-        def _get_group_and_subtask_from_config(
-            config: dict,
-        ) -> tuple[ConfigurableGroup, list[str]]:
-            """
-            Extract group object and subtask list from group configuration.
-
-            Expands any tags in the task list to their constituent tasks.
-
-            Args:
-                config: Group configuration dictionary
-
-            Returns:
-                Tuple of (ConfigurableGroup, list of subtask names)
-            """
-            if self.metadata is not None:
-                config["metadata"] = config.get("metadata", {}) | self.metadata
-            group_name = ConfigurableGroup(config=config)
-            subtask_list = []
-            for task in group_name.config["task"]:
-                if isinstance(task, str) and self._name_is_tag(task):
-                    subtask_list.extend(self._get_tasklist(task))
-                else:
-                    subtask_list.append(task)
-            return group_name, subtask_list
-
-        def _process_group_config(
-            config: dict,
-            update_config: Optional[dict] = None,
-        ) -> tuple[dict, Optional[dict]]:
-            """
-            Separate group-specific config from task-level config overrides.
-
-            Group-only keys (like 'group', 'aggregate') stay with the group,
-            while other keys become config overrides for constituent tasks.
-
-            Args:
-                config: Full configuration dictionary
-                update_config: Additional config to merge
-
-            Returns:
-                Tuple of (group_config, task_update_config)
-            """
-            if update_config:
-                config = {**config, **update_config}
-            _update_config = {
-                k: v for k, v in config.items() if k not in GROUP_ONLY_KEYS
-            }
-            if not bool(_update_config):
-                _update_config = None
-            group_config = {k: v for k, v in config.items() if k in GROUP_ONLY_KEYS}
-            return group_config, _update_config
-
-        if isinstance(name_or_config, str):
-            if update_config is not None:
-                # Process name_or_config as a dict instead
-                name_or_config = {"task": name_or_config, **update_config}
-            elif self._name_is_task(name_or_config) or self._name_is_python_task(
-                name_or_config
-            ):
-                # Get the yaml_path for this task
-                yaml_path = self._get_yaml_path(name_or_config)
-                task_config = self._get_config(name_or_config)
-                # Handle task_list configs
-                if "task_list" in task_config:
-                    # Find the specific task entry
-                    task_specific_config = None
-                    for task_entry in task_config["task_list"]:
-                        if (
-                            isinstance(task_entry, dict)
-                            and task_entry.get("task") == name_or_config
-                        ):
-                            task_specific_config = task_entry
-                            break
-                    if task_specific_config:
-                        # Create base config without task_list
-                        base_config = {
-                            k: v for k, v in task_config.items() if k != "task_list"
-                        }
-                        # Merge using helper method
-                        task_config = self._merge_task_configs(
-                            base_config, task_specific_config, name_or_config
-                        )
-                    else:
-                        # Task not found in task_list, shouldn't happen if indexing worked correctly
-                        eval_logger.warning(
-                            f"Task {name_or_config} not found in task_list"
-                        )
-                        task_config = {"task": name_or_config}
-                return _load_task(task_config, task=name_or_config, yaml_path=yaml_path)
-            else:
-                subtask_list = self._get_tasklist(name_or_config)
-                if subtask_list == -1:
-                    group_config = self._get_config(name_or_config)
-                    group_config, update_config = _process_group_config(group_config)
-                    group_name, subtask_list = _get_group_and_subtask_from_config(
-                        group_config
-                    )
-                else:
-                    if self._name_is_tag(name_or_config):
-                        return self._process_tag_subtasks(
-                            name_or_config,
-                            name_or_config if isinstance(name_or_config, dict) else None,
-                        )
-                    else:
-                        group_name = ConfigurableGroup(
-                            config={"group": name_or_config, "task": subtask_list}
-                        )
-
-        if isinstance(name_or_config, dict):
-            if self._config_is_task(name_or_config):
-                name = name_or_config.pop("task")
-                if update_config is not None:
-                    name_or_config = {**name_or_config, **update_config}
-                # If the name is registered as a group
-                if self._name_is_group(name):
-                    group_config = self._get_config(name)
-                    group_config, update_config = _process_group_config(
-                        group_config, name_or_config
-                    )
-                    group_name, subtask_list = _get_group_and_subtask_from_config(
-                        group_config
-                    )
-                elif self._name_is_tag(name):
-                    return self._process_tag_subtasks(name, name_or_config)
-                else:
-                    yaml_path = None
-                    if self._name_is_registered(name):
-                        yaml_path = self._get_yaml_path(name)
-                        base_task_config = self._get_config(name)
-                        # Check if this is a duplicate.
-                        if parent_name is not None:
-                            num_duplicate = len(
-                                list(
-                                    filter(
-                                        lambda x: x.startswith(name),
-                                        self.task_group_map[parent_name],
-                                    )
-                                )
-                            )
-                            if num_duplicate > 0:
-                                name = f"{name}-{num_duplicate}"
-                            self.task_group_map[parent_name].append(name)
-                        task_config = {
-                            **base_task_config,
-                            **name_or_config,
-                        }
-                    else:
-                        task_config = name_or_config
-                    return _load_task(task_config, task=name, yaml_path=yaml_path)
-            else:
-                group_config, update_config = _process_group_config(name_or_config)
-                group_name, subtask_list = _get_group_and_subtask_from_config(
-                    group_config
-                )
-
-        fn = partial(
-            self._load_individual_task_or_group,
-            parent_name=group_name,
-            update_config=update_config,
-        )
-        return {
-            group_name: dict(collections.ChainMap(*map(fn, reversed(subtask_list))))
-        }
+    ###############################################################################
+    # NEW: Refactored _load_individual_task_or_group and helper methods          #
+    ###############################################################################
+    def _create_task_object(
+        self,
+        cfg: dict,
+        task_name: str,
+        yaml_path: str | None,
+    ) -> dict:
+        """
+        Instantiate a single task (ConfigurableTask **or** python-task) from *cfg*.
+        Returns {task_name: task_object}.
+        """
+        from lm_eval.api.task import ConfigurableTask, Task  # local import avoids cycle
+
+        # ---- include handling ---------------------------------------------------
+        if "include" in cfg:
+            # keep original name so include merging cannot clobber it
+            orig_name = cfg.get("task", task_name)
+            cfg = {
+                **load_yaml_config(  # recurse once, cached
+                    yaml_path=Path(yaml_path) if yaml_path else None,
+                    yaml_config={"include": cfg.pop("include")},
+                    mode="full" if yaml_path else "simple",
+                ),
+                **cfg,
+                "task": orig_name,
+            }
+
+        # ---- metadata merge -----------------------------------------------------
+        if self.metadata is not None:
+            cfg["metadata"] = cfg.get("metadata", {}) | self.metadata
+        else:
+            cfg["metadata"] = cfg.get("metadata", {})
+
+        # ---- python-task vs YAML-task -------------------------------------------
+        if self._config_is_python_task(cfg):
+            cls = cfg["class"]
+            task_obj: Task
+            if self._class_has_config_in_constructor(cls):
+                task_obj = cls(config=cfg)
+            else:
+                task_obj = cls()
+            # make sure name propagates when the class inherits ConfigurableTask
+            if isinstance(task_obj, ConfigurableTask):
+                task_obj.config.task = task_name  # type: ignore
+        else:
+            task_obj = ConfigurableTask(config=cfg)  # type: ignore
+
+        return {task_name: task_obj}
+
+    def _create_group_object(
+        self,
+        cfg: dict,
+        parent_name: str | None = None,
+    ) -> tuple[ConfigurableGroup, list[Union[str, dict]]]:
+        """
+        Build ConfigurableGroup and return (group_obj, subtask_names).
+        Resolves tag expansion.
+        """
+        if self.metadata is not None:
+            cfg["metadata"] = cfg.get("metadata", {}) | self.metadata
+
+        grp = ConfigurableGroup(config=cfg)
+        subtasks: list[Union[str, dict]] = []
+        for t in grp.config["task"]:
+            if isinstance(t, str) and self._name_is_tag(t):
+                subtasks.extend(self._get_tasklist(t))
+            else:
+                subtasks.append(t)
+        return grp, subtasks
+
+    def _load_subtasks(
+        self,
+        subtasks: list[Union[str, dict]],
+        parent_name: Union[str, ConfigurableGroup, None],
+        update_config: dict | None,
+    ) -> Mapping:
+        """Return merged mapping of all subtasks, handling duplicates."""
+        fn = functools.partial(
+            self._load_individual_task_or_group,
+            parent_name=parent_name,
+            update_config=update_config,
+        )
+        return dict(collections.ChainMap(*map(fn, reversed(subtasks))))
+
+    def _load_individual_task_or_group(
+        self,
+        payload: str | dict,
+        *,
+        parent_name: str | None = None,
+        update_config: dict | None = None,
+    ) -> Mapping:
+        """
+        Public helper that turns *payload* (str task/group/tag **or** dict config)
+        into a nested Mapping of {name_or_group_obj: task_obj | sub_mapping}.
+        """
+        # ------------------------------------------------------------------ STRING
+        if isinstance(payload, str):
+            # If caller supplied extra overrides, treat as dict immediately
+            if update_config is not None:
+                return self._load_individual_task_or_group(
+                    {"task": payload, **update_config},
+                    parent_name=parent_name,
+                )
+
+            # ------------ registered TASK (YAML or python) -----------------
+            if self._name_is_task(payload) or self._name_is_python_task(payload):
+                yaml_path = self._get_yaml_path(payload)
+                cfg = self._get_config(payload)
+
+                # task_list configs: extract the per-task override ------------
+                if "task_list" in cfg:
+                    override = next(
+                        (
+                            entry
+                            for entry in cfg["task_list"]
+                            if isinstance(entry, dict) and entry.get("task") == payload
+                        ),
+                        None,
+                    )
+                    base = {k: v for k, v in cfg.items() if k != "task_list"}
+                    if override:
+                        cfg = {**base, **override, "task": payload}
+                return self._create_task_object(cfg, payload, yaml_path)
+
+            # ------------ registered GROUP ----------------------------------
+            if self._name_is_group(payload):
+                group_cfg = self._get_config(payload)
+                grp_only = {k: v for k, v in group_cfg.items() if k in GROUP_ONLY_KEYS}
+                grp_obj, subtasks = self._create_group_object(grp_only, parent_name)
+                return {
+                    grp_obj: self._load_subtasks(subtasks, grp_obj, update_config=None)
+                }

+            # ------------ registered TAG ------------------------------------
+            if self._name_is_tag(payload):
+                return self._process_tag_subtasks(payload, update_config=None)
+
+            raise ValueError(f"Unknown task / group / tag name: {payload!r}")
+
+        # ------------------------------------------------------------------- DICT
+        if isinstance(payload, dict):
+            # ------------------ simple 'task: name' dict --------------------
+            if self._config_is_task(payload):
+                name = payload["task"]
+                # override existing registered YAML if exists
+                if self._name_is_registered(name):
+                    base_cfg = self._get_config(name)
+                    yaml_path = self._get_yaml_path(name)
+                    merged = {**base_cfg, **payload}
+                else:
+                    merged = payload
+                    yaml_path = None
+
+                # duplicate-naming guard when inside a group
+                if parent_name is not None:
+                    count = len(
+                        [
+                            n
+                            for n in self.task_group_map[parent_name]
+                            if n.startswith(name)
+                        ]
+                    )
+                    if count:
+                        name = f"{name}-{count}"
+                    self.task_group_map[parent_name].append(name)
+
+                return self._create_task_object(merged, name, yaml_path)
+
+            # ----------------- literal group dict (task: [...]) -------------
+            if self._config_is_group(payload):
+                grp_cfg = {k: v for k, v in payload.items() if k in GROUP_ONLY_KEYS}
+                sub_override = {
+                    k: v for k, v in payload.items() if k not in GROUP_ONLY_KEYS
+                } or None
+                grp_obj, subtasks = self._create_group_object(grp_cfg, parent_name)
+                return {grp_obj: self._load_subtasks(subtasks, grp_obj, sub_override)}
+
+            # ----------------- python-task dict ('class': …) ----------------
+            if self._config_is_python_task(payload):
+                name = payload["task"]
+                return self._create_task_object(payload, name, yaml_path=None)
+
+        raise TypeError(
+            f"_load_individual_task_or_group expected str | dict, got {type(payload)}"
+        )

     def load_task_or_group(
         self, task_list: Optional[Union[str, list[str]]] = None
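The new _load_subtasks helper merges the per-subtask mappings with dict(collections.ChainMap(*map(fn, reversed(subtasks)))). In a ChainMap the first mapping wins on key collisions, so reversing the list means a mapping produced from a later subtask entry is consulted first, and dict(...) flattens the result. A standalone illustration with stand-in dicts in place of real task mappings:

import collections

# Each per-subtask loader returns a small {name: task_object} mapping; stand-ins:
sub_a = {"arc_easy": "TaskA", "shared": "from A"}
sub_b = {"arc_challenge": "TaskB", "shared": "from B"}

# ChainMap searches its maps left to right, so after reversed([...]) the mapping
# from the later entry (sub_b) is searched first and wins on duplicate keys.
merged = dict(collections.ChainMap(*map(dict, reversed([sub_a, sub_b]))))
print(merged["shared"])  # 'from B'
print(sorted(merged))    # ['arc_challenge', 'arc_easy', 'shared']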
...
@@ -1363,64 +1255,45 @@ def get_task_dict(
             tm = TaskManager(include_path="/custom/tasks")
             tasks = get_task_dict(["custom_task"], task_manager=tm)
     """
-    from lm_eval.api.task import ConfigurableTask, Task
+    from lm_eval.api.task import Task

-    task_name_from_string_dict = {}
-    task_name_from_config_dict = {}
-    task_name_from_object_dict = {}
-
+    # Normalize input to list
     if isinstance(task_name_list, str):
         task_name_list = [task_name_list]
-    elif isinstance(task_name_list, list):
-        if not all([isinstance(task, (str, dict, Task)) for task in task_name_list]):
-            raise TypeError(
-                "Expected all list items to be of types 'str', 'dict', or 'Task', but at least one entry did not match."
-            )
-    else:
+    elif not isinstance(task_name_list, list):
         raise TypeError(
             f"Expected a 'str' or 'list' but received {type(task_name_list)}."
         )

-    string_task_name_list = [task for task in task_name_list if isinstance(task, str)]
-    others_task_name_list = [
-        task for task in task_name_list if not isinstance(task, str)
-    ]
-    if len(string_task_name_list) > 0:
-        if task_manager is None:
-            task_manager = TaskManager()
-        task_name_from_string_dict = task_manager.load_task_or_group(
-            string_task_name_list
-        )
-
-    for task_element in others_task_name_list:
-        if isinstance(task_element, dict):
-            task_name_from_config_dict = {
-                **task_name_from_config_dict,
-                **task_manager.load_config(config=task_element),
-            }
-        elif isinstance(task_element, Task):
-            task_name_from_object_dict = {
-                **task_name_from_object_dict,
-                get_task_name_from_object(task_element): task_element,
-            }
-
-    if not set(task_name_from_string_dict.keys()).isdisjoint(
-        set(task_name_from_object_dict.keys())
-    ):
-        raise ValueError
-
-    final_task_dict = {
-        **task_name_from_string_dict,
-        **task_name_from_config_dict,
-        **task_name_from_object_dict,
-    }
-
-    # behavior can get odd if one tries to invoke several groups that "compete" for the same task.
-    # (notably, because one could request several num_fewshot values at once in GroupConfig overrides for the subtask
-    # and we'd be unsure which to use and report.)
-    # we explicitly check and error in this case.
+    # Validate list items
+    if not all(isinstance(task, (str, dict, Task)) for task in task_name_list):
+        raise TypeError(
+            "Expected all list items to be of types 'str', 'dict', or 'Task', but at least one entry did not match."
+        )
+
+    # Ensure we have a task manager
+    if task_manager is None:
+        task_manager = TaskManager()
+
+    # Process all items
+    final_task_dict = {}
+    for task_spec in task_name_list:
+        if isinstance(task_spec, Task):
+            # Pre-instantiated task object
+            task_name = get_task_name_from_object(task_spec)
+            if task_name in final_task_dict:
+                raise ValueError(f"Duplicate task name: {task_name}")
+            final_task_dict[task_name] = task_spec
+        else:
+            # String or dict - use load_task_or_group
+            result = task_manager.load_task_or_group(task_spec)
+            # Check for duplicate names
+            for name in result:
+                if name in final_task_dict:
+                    raise ValueError(f"Duplicate task name: {name}")
+            final_task_dict.update(result)
+
+    # Check for conflicting group memberships
     _check_duplicates(get_subtask_list(final_task_dict))
     return final_task_dict
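A hedged usage sketch of the rewritten get_task_dict loop, which now walks the input once, accepts strings, config dicts, and pre-instantiated Task objects, and raises on duplicate names. The include path and task names are examples drawn from the surrounding docstrings, not required values:

from lm_eval.tasks import TaskManager, get_task_dict

tm = TaskManager(include_path="/custom/tasks")

# Strings and config dicts both go through tm.load_task_or_group; a repeated
# name would now raise ValueError("Duplicate task name: ...") instead of being
# silently merged.
tasks = get_task_dict(
    ["hellaswag", {"task": "arc_easy", "num_fewshot": 5}],
    task_manager=tm,
)
print(sorted(tasks))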