Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
85f61d85
Commit
85f61d85
authored
Jul 11, 2025
by
Baber
Browse files
refactor: add type hints
parent
5454e95d
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
53 additions
and
39 deletions
+53
-39
lm_eval/tasks/__init__.py
lm_eval/tasks/__init__.py
+45
-34
lm_eval/utils.py
lm_eval/utils.py
+8
-5
No files found.
lm_eval/tasks/__init__.py
View file @
85f61d85
...
...
@@ -88,31 +88,34 @@ class TaskManager:
return
task_index
@
property
def
all_tasks
(
self
):
def
all_tasks
(
self
)
->
List
[
str
]
:
return
self
.
_all_tasks
@
property
def
all_groups
(
self
):
def
all_groups
(
self
)
->
List
[
str
]
:
return
self
.
_all_groups
@
property
def
all_subtasks
(
self
):
def
all_subtasks
(
self
)
->
List
[
str
]
:
return
self
.
_all_subtasks
@
property
def
all_tags
(
self
):
def
all_tags
(
self
)
->
List
[
str
]
:
return
self
.
_all_tags
@
property
def
task_index
(
self
):
def
task_index
(
self
)
->
Dict
[
str
,
Dict
[
str
,
Union
[
str
,
int
,
List
[
str
]]]]
:
return
self
.
_task_index
def
list_all_tasks
(
self
,
list_groups
=
True
,
list_tags
=
True
,
list_subtasks
=
True
self
,
list_groups
:
bool
=
True
,
list_tags
:
bool
=
True
,
list_subtasks
:
bool
=
True
,
)
->
str
:
from
pytablewriter
import
MarkdownTableWriter
def
sanitize_path
(
path
)
:
def
sanitize_path
(
path
:
str
)
->
str
:
# don't print full path if we are within the lm_eval/tasks dir !
# if we aren't though, provide the full path.
if
"lm_eval/tasks/"
in
path
:
...
...
@@ -210,12 +213,12 @@ class TaskManager:
def
_config_is_task_list
(
self
,
config
:
dict
)
->
bool
:
return
"task_list"
in
config
and
isinstance
(
config
[
"task_list"
],
list
)
def
_get_yaml_path
(
self
,
name
:
str
):
def
_get_yaml_path
(
self
,
name
:
str
)
->
Union
[
str
,
int
]
:
if
name
not
in
self
.
task_index
:
raise
ValueError
return
self
.
task_index
[
name
][
"yaml_path"
]
def
_get_config
(
self
,
name
)
:
def
_get_config
(
self
,
name
:
str
)
->
Dict
:
if
name
not
in
self
.
task_index
:
raise
ValueError
yaml_path
=
self
.
_get_yaml_path
(
name
)
...
...
@@ -224,7 +227,7 @@ class TaskManager:
else
:
return
utils
.
load_yaml_config
(
yaml_path
,
mode
=
"full"
)
def
_get_tasklist
(
self
,
name
)
:
def
_get_tasklist
(
self
,
name
:
str
)
->
Union
[
List
[
str
],
int
]
:
if
self
.
_name_is_task
(
name
):
raise
ValueError
return
self
.
task_index
[
name
][
"task"
]
...
...
@@ -234,10 +237,10 @@ class TaskManager:
task_name
:
str
,
task_type
:
str
,
yaml_path
:
str
,
tasks_and_groups
:
d
ict
,
config
:
d
ict
=
None
,
populate_tags_fn
=
None
,
):
tasks_and_groups
:
Dict
[
str
,
D
ict
]
,
config
:
Optional
[
D
ict
]
=
None
,
populate_tags_fn
:
Optional
[
callable
]
=
None
,
)
->
None
:
"""Helper method to register a task in the tasks_and_groups dict"""
tasks_and_groups
[
task_name
]
=
{
"type"
:
task_type
,
...
...
@@ -248,8 +251,8 @@ class TaskManager:
populate_tags_fn
(
config
,
task_name
,
tasks_and_groups
)
def
_merge_task_configs
(
self
,
base_config
:
d
ict
,
task_specific_config
:
d
ict
,
task_name
:
str
)
->
d
ict
:
self
,
base_config
:
D
ict
,
task_specific_config
:
D
ict
,
task_name
:
str
)
->
D
ict
:
"""Merge base config with task-specific overrides for task_list configs"""
if
task_specific_config
:
task_specific_config
=
task_specific_config
.
copy
()
...
...
@@ -257,7 +260,9 @@ class TaskManager:
return
{
**
base_config
,
**
task_specific_config
,
"task"
:
task_name
}
return
{
**
base_config
,
"task"
:
task_name
}
def
_process_tag_subtasks
(
self
,
tag_name
:
str
,
update_config
:
dict
=
None
):
def
_process_tag_subtasks
(
self
,
tag_name
:
str
,
update_config
:
Optional
[
Dict
]
=
None
)
->
Dict
:
"""Process subtasks for a tag and return loaded tasks"""
subtask_list
=
self
.
_get_tasklist
(
tag_name
)
fn
=
partial
(
...
...
@@ -266,7 +271,7 @@ class TaskManager:
)
return
dict
(
collections
.
ChainMap
(
*
map
(
fn
,
reversed
(
subtask_list
))))
def
_process_alias
(
self
,
config
,
group
=
None
)
:
def
_process_alias
(
self
,
config
:
Dict
,
group
:
Optional
[
str
]
=
None
)
->
Dict
:
# If the group is not the same as the original
# group which the group alias was intended for,
# Set the group_alias to None instead.
...
...
@@ -275,7 +280,7 @@ class TaskManager:
config
[
"group_alias"
]
=
None
return
config
def
_class_has_config_in_constructor
(
self
,
cls
):
def
_class_has_config_in_constructor
(
self
,
cls
)
->
bool
:
constructor
=
getattr
(
cls
,
"__init__"
,
None
)
return
(
"config"
in
inspect
.
signature
(
constructor
).
parameters
...
...
@@ -285,11 +290,13 @@ class TaskManager:
def
_load_individual_task_or_group
(
self
,
name_or_config
:
Optional
[
Union
[
str
,
d
ict
]]
=
None
,
name_or_config
:
Optional
[
Union
[
str
,
D
ict
]]
=
None
,
parent_name
:
Optional
[
str
]
=
None
,
update_config
:
Optional
[
d
ict
]
=
None
,
update_config
:
Optional
[
D
ict
]
=
None
,
)
->
Mapping
:
def
_load_task
(
config
,
task
,
yaml_path
=
None
):
def
_load_task
(
config
:
Dict
,
task
:
str
,
yaml_path
:
Optional
[
str
]
=
None
)
->
Dict
[
str
,
Union
[
ConfigurableTask
,
Task
]]:
if
"include"
in
config
:
# Store the task name to preserve it after include processing
original_task_name
=
config
.
get
(
"task"
,
task
)
...
...
@@ -325,8 +332,8 @@ class TaskManager:
return
{
task
:
task_object
}
def
_get_group_and_subtask_from_config
(
config
:
d
ict
,
)
->
tuple
[
ConfigurableGroup
,
l
ist
[
str
]]:
config
:
D
ict
,
)
->
tuple
[
ConfigurableGroup
,
L
ist
[
str
]]:
if
self
.
metadata
is
not
None
:
config
[
"metadata"
]
=
config
.
get
(
"metadata"
,
{})
|
self
.
metadata
group_name
=
ConfigurableGroup
(
config
=
config
)
...
...
@@ -339,8 +346,8 @@ class TaskManager:
return
group_name
,
subtask_list
def
_process_group_config
(
config
:
d
ict
,
update_config
:
d
ict
=
None
)
->
tuple
[
d
ict
,
d
ict
]:
config
:
D
ict
,
update_config
:
Optional
[
D
ict
]
=
None
)
->
tuple
[
D
ict
,
Optional
[
D
ict
]
]
:
if
update_config
is
not
None
:
config
=
{
**
config
,
**
update_config
}
_update_config
=
{
...
...
@@ -472,7 +479,9 @@ class TaskManager:
group_name
:
dict
(
collections
.
ChainMap
(
*
map
(
fn
,
reversed
(
subtask_list
))))
}
def
load_task_or_group
(
self
,
task_list
:
Optional
[
Union
[
str
,
list
]]
=
None
)
->
dict
:
def
load_task_or_group
(
self
,
task_list
:
Optional
[
Union
[
str
,
List
[
str
]]]
=
None
)
->
Dict
:
"""Loads a dictionary of task objects from a list
:param task_list: Union[str, list] = None
...
...
@@ -494,10 +503,10 @@ class TaskManager:
)
return
all_loaded_tasks
def
load_config
(
self
,
config
:
Dict
):
def
load_config
(
self
,
config
:
Dict
)
->
Mapping
:
return
self
.
_load_individual_task_or_group
(
config
)
def
_get_task_and_group
(
self
,
task_dir
:
Union
[
str
,
Path
]):
def
_get_task_and_group
(
self
,
task_dir
:
Union
[
str
,
Path
])
->
Dict
[
str
,
Dict
]
:
"""Creates a dictionary of tasks index with the following metadata,
- `type`, that can be either `task`, `python_task`, `group` or `tags`.
`task` refer to regular task configs, `python_task` are special
...
...
@@ -520,7 +529,9 @@ class TaskManager:
Dictionary of task names as key and task metadata
"""
def
_populate_tags_and_groups
(
config
,
task
,
tasks_and_groups
):
def
_populate_tags_and_groups
(
config
:
Dict
,
task
:
str
,
tasks_and_groups
:
Dict
[
str
,
Dict
]
)
->
None
:
# TODO: remove group in next release
if
"tag"
in
config
:
attr_list
=
config
[
"tag"
]
...
...
@@ -557,7 +568,7 @@ class TaskManager:
for
f
in
file_list
:
if
f
.
endswith
(
".yaml"
):
yaml_path
=
root_path
/
f
config
=
utils
.
load_yaml_config
(
str
(
yaml_path
)
,
mode
=
"simple"
)
config
=
utils
.
load_yaml_config
(
yaml_path
,
mode
=
"simple"
)
if
self
.
_config_is_python_task
(
config
):
# This is a python class config
task
=
config
[
"task"
]
...
...
@@ -629,7 +640,7 @@ def get_task_name_from_config(task_config: Dict[str, str]) -> str:
return
"{dataset_path}"
.
format
(
**
task_config
)
def
get_task_name_from_object
(
task_object
)
:
def
get_task_name_from_object
(
task_object
:
Union
[
ConfigurableTask
,
Task
])
->
str
:
if
hasattr
(
task_object
,
"config"
):
return
task_object
.
_config
[
"task"
]
...
...
@@ -642,7 +653,7 @@ def get_task_name_from_object(task_object):
)
def
_check_duplicates
(
task_dict
:
d
ict
)
->
None
:
def
_check_duplicates
(
task_dict
:
D
ict
[
str
,
List
[
str
]]
)
->
None
:
"""helper function solely used in validating get_task_dict output.
Takes the output of lm_eval.evaluator_utils.get_subtask_list and
returns a list of all leaf subtasks contained within, and errors if any such leaf subtasks are
...
...
@@ -672,7 +683,7 @@ def _check_duplicates(task_dict: dict) -> None:
def
get_task_dict
(
task_name_list
:
Union
[
str
,
List
[
Union
[
str
,
Dict
,
Task
]]],
task_manager
:
Optional
[
TaskManager
]
=
None
,
):
)
->
Dict
[
str
,
Union
[
ConfigurableTask
,
Task
]]
:
"""Creates a dictionary of task objects from either a name of task, config, or prepared Task object.
:param task_name_list: List[Union[str, Dict, Task]]
...
...
lm_eval/utils.py
View file @
85f61d85
...
...
@@ -11,7 +11,7 @@ import re
from
dataclasses
import
asdict
,
is_dataclass
from
itertools
import
islice
from
pathlib
import
Path
from
typing
import
Any
,
Callable
,
Generator
,
List
,
Optional
,
Tuple
from
typing
import
Any
,
Callable
,
Dict
,
Generator
,
List
,
Optional
,
Tuple
,
Union
import
numpy
as
np
import
yaml
...
...
@@ -441,11 +441,11 @@ def positional_deprecated(fn):
return
_wrapper
def
ignore_constructor
(
loader
,
node
)
:
def
ignore_constructor
(
loader
:
yaml
.
Loader
,
node
:
yaml
.
Node
)
->
yaml
.
Node
:
return
node
def
import_function
(
loader
:
yaml
.
Loader
,
node
,
yaml_path
:
Path
):
def
import_function
(
loader
:
yaml
.
Loader
,
node
:
yaml
.
Node
,
yaml_path
:
Path
)
->
Callable
:
function_name
=
loader
.
construct_scalar
(
node
)
*
module_name
,
function_name
=
function_name
.
split
(
"."
)
...
...
@@ -468,8 +468,11 @@ def import_function(loader: yaml.Loader, node, yaml_path: Path):
def
load_yaml_config
(
yaml_path
=
None
,
yaml_config
=
None
,
yaml_dir
=
None
,
mode
=
"full"
)
->
dict
:
yaml_path
:
Optional
[
Union
[
str
,
Path
]]
=
None
,
yaml_config
:
Optional
[
Dict
]
=
None
,
yaml_dir
:
Optional
[
Union
[
str
,
Path
]]
=
None
,
mode
:
str
=
"full"
,
)
->
Dict
:
# Convert yaml_path to Path object if it's a string
if
yaml_path
is
not
None
:
yaml_path
=
Path
(
yaml_path
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment