Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
0aca6958
Commit
0aca6958
authored
Jul 13, 2025
by
Baber
Browse files
refactor: replace ConfigurableGroup with GroupConfig
parent
7fcfb4ac
Changes
3
Show whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
46 additions
and
57 deletions
+46
-57
lm_eval/api/group.py
lm_eval/api/group.py
+21
-32
lm_eval/evaluator_utils.py
lm_eval/evaluator_utils.py
+19
-19
lm_eval/tasks/__init__.py
lm_eval/tasks/__init__.py
+6
-6
No files found.
lm_eval/api/group.py
View file @
0aca6958
from
dataclasses
import
asdict
,
dataclass
from
dataclasses
import
asdict
,
dataclass
from
inspect
import
getsource
from
inspect
import
getsource
from
typing
import
Any
,
Callable
,
List
,
Optional
,
Union
from
typing
import
Callable
,
List
,
Optional
,
Union
@
dataclass
@
dataclass
...
@@ -22,7 +22,7 @@ class AggMetricConfig(dict):
...
@@ -22,7 +22,7 @@ class AggMetricConfig(dict):
@
dataclass
@
dataclass
class
GroupConfig
(
dict
)
:
class
GroupConfig
:
group
:
Optional
[
str
]
=
None
group
:
Optional
[
str
]
=
None
group_alias
:
Optional
[
str
]
=
None
group_alias
:
Optional
[
str
]
=
None
task
:
Optional
[
Union
[
str
,
list
]]
=
None
task
:
Optional
[
Union
[
str
,
list
]]
=
None
...
@@ -39,6 +39,24 @@ class GroupConfig(dict):
...
@@ -39,6 +39,24 @@ class GroupConfig(dict):
def
__setitem__
(
self
,
item
,
value
):
def
__setitem__
(
self
,
item
,
value
):
return
setattr
(
self
,
item
,
value
)
return
setattr
(
self
,
item
,
value
)
def
__contains__
(
self
,
item
):
"""Support 'in' operator for dict-like behavior."""
return
hasattr
(
self
,
item
)
def
get
(
self
,
key
,
default
=
None
):
"""Dict-like get method."""
return
getattr
(
self
,
key
,
default
)
def
__hash__
(
self
):
"""Make GroupConfig hashable based on group name."""
return
hash
(
self
.
group
)
def
__eq__
(
self
,
other
):
"""Equality comparison based on group name."""
if
not
isinstance
(
other
,
GroupConfig
):
return
False
return
self
.
group
==
other
.
group
def
__post_init__
(
self
):
def
__post_init__
(
self
):
if
self
.
aggregate_metric_list
is
not
None
:
if
self
.
aggregate_metric_list
is
not
None
:
if
isinstance
(
self
.
aggregate_metric_list
,
dict
):
if
isinstance
(
self
.
aggregate_metric_list
,
dict
):
...
@@ -87,34 +105,5 @@ class GroupConfig(dict):
...
@@ -87,34 +105,5 @@ class GroupConfig(dict):
"""Returns the version of the group configuration."""
"""Returns the version of the group configuration."""
return
self
.
metadata
.
get
(
"version"
,
"1.0"
)
return
self
.
metadata
.
get
(
"version"
,
"1.0"
)
@
dataclass
class
ConfigurableGroup
:
def
__init__
(
self
,
config
:
Optional
[
dict
]
=
None
,
)
->
None
:
self
.
_config
=
GroupConfig
(
**
config
)
@
property
def
group
(
self
):
return
self
.
_config
.
group
@
property
def
group_alias
(
self
):
return
self
.
_config
.
group_alias
@
property
def
version
(
self
):
return
self
.
_config
.
version
@
property
def
config
(
self
):
return
self
.
_config
.
to_dict
()
@
property
def
group_name
(
self
)
->
Any
:
return
self
.
_config
.
group
def
__repr__
(
self
):
def
__repr__
(
self
):
return
f
"Config
urableGroup
(group=
{
self
.
group
}
,group_alias=
{
self
.
group_alias
}
)"
return
f
"
Group
Config(group=
{
self
.
group
}
,group_alias=
{
self
.
group_alias
}
)"
lm_eval/evaluator_utils.py
View file @
0aca6958
...
@@ -151,14 +151,14 @@ def get_task_list(task_dict: dict) -> List[TaskOutput]:
...
@@ -151,14 +151,14 @@ def get_task_list(task_dict: dict) -> List[TaskOutput]:
def
get_subtask_list
(
task_dict
,
task_root
=
None
,
depth
=
0
):
def
get_subtask_list
(
task_dict
,
task_root
=
None
,
depth
=
0
):
from
lm_eval.api.group
import
Config
urableGroup
from
lm_eval.api.group
import
Group
Config
from
lm_eval.api.task
import
Task
from
lm_eval.api.task
import
Task
subtask_list
=
{}
subtask_list
=
{}
for
group_obj
,
task_obj
in
task_dict
.
items
():
for
group_obj
,
task_obj
in
task_dict
.
items
():
if
isinstance
(
group_obj
,
Config
urableGroup
):
if
isinstance
(
group_obj
,
Group
Config
):
# group_name = group_obj.group
_name
# group_name = group_obj.group
group_name
=
group_obj
.
group
_name
group_name
=
group_obj
.
group
else
:
else
:
group_name
=
group_obj
group_name
=
group_obj
if
isinstance
(
task_obj
,
dict
):
if
isinstance
(
task_obj
,
dict
):
...
@@ -176,9 +176,9 @@ def get_subtask_list(task_dict, task_root=None, depth=0):
...
@@ -176,9 +176,9 @@ def get_subtask_list(task_dict, task_root=None, depth=0):
subtask_list
=
{
**
subtask_list
,
**
_subtask_list
}
subtask_list
=
{
**
subtask_list
,
**
_subtask_list
}
else
:
else
:
if
isinstance
(
task_obj
,
Config
urableGroup
):
if
isinstance
(
task_obj
,
Group
Config
):
# group_or_task_name = task_obj.group
_name
# group_or_task_name = task_obj.group
group_or_task_name
=
task_obj
.
group
_name
group_or_task_name
=
task_obj
.
group
elif
isinstance
(
task_obj
,
Task
):
elif
isinstance
(
task_obj
,
Task
):
# group_or_task_name = task_obj.task_name
# group_or_task_name = task_obj.task_name
group_or_task_name
=
task_obj
.
task_name
group_or_task_name
=
task_obj
.
task_name
...
@@ -241,7 +241,7 @@ def prepare_print_tasks(
...
@@ -241,7 +241,7 @@ def prepare_print_tasks(
Prepares the task hierarchy and aggregates the results for each task and group recursively for printing.
Prepares the task hierarchy and aggregates the results for each task and group recursively for printing.
"""
"""
from
lm_eval.api.group
import
Config
urableGroup
from
lm_eval.api.group
import
Group
Config
def
_sort_task_dict
(
task_dict
):
def
_sort_task_dict
(
task_dict
):
"""
"""
...
@@ -252,8 +252,8 @@ def prepare_print_tasks(
...
@@ -252,8 +252,8 @@ def prepare_print_tasks(
return
dict
(
return
dict
(
sorted
(
sorted
(
task_dict
.
items
(),
task_dict
.
items
(),
key
=
lambda
item
:
item
[
0
].
group
_name
key
=
lambda
item
:
item
[
0
].
group
if
isinstance
(
item
[
0
],
Config
urableGroup
)
if
isinstance
(
item
[
0
],
Group
Config
)
else
item
[
0
],
else
item
[
0
],
)
)
)
)
...
@@ -263,9 +263,9 @@ def prepare_print_tasks(
...
@@ -263,9 +263,9 @@ def prepare_print_tasks(
task_dict
=
_sort_task_dict
(
task_dict
)
task_dict
=
_sort_task_dict
(
task_dict
)
for
task_or_group_name
,
task_or_group_obj
in
task_dict
.
items
():
for
task_or_group_name
,
task_or_group_obj
in
task_dict
.
items
():
tab_string
=
" "
*
task_depth
+
"- "
if
task_depth
>
0
else
""
tab_string
=
" "
*
task_depth
+
"- "
if
task_depth
>
0
else
""
if
isinstance
(
task_or_group_name
,
Config
urableGroup
):
if
isinstance
(
task_or_group_name
,
Group
Config
):
# string_name = task_or_group_name.group
_name
# string_name = task_or_group_name.group
name
=
task_or_group_name
.
group
_name
name
=
task_or_group_name
.
group
from_configurable_group
=
True
from_configurable_group
=
True
task_or_group_obj
=
_sort_task_dict
(
task_or_group_obj
)
task_or_group_obj
=
_sort_task_dict
(
task_or_group_obj
)
elif
isinstance
(
task_or_group_name
,
str
):
elif
isinstance
(
task_or_group_name
,
str
):
...
@@ -399,7 +399,7 @@ def consolidate_group_results(
...
@@ -399,7 +399,7 @@ def consolidate_group_results(
The method then returns the updated results, versions, show_group_table, and task_aggregation_list as a tuple.
The method then returns the updated results, versions, show_group_table, and task_aggregation_list as a tuple.
In the top-level invocation of this function, task_aggregation_list is ignored.
In the top-level invocation of this function, task_aggregation_list is ignored.
"""
"""
from
lm_eval.api.group
import
Config
urableGroup
from
lm_eval.api.group
import
Group
Config
from
lm_eval.api.task
import
Task
from
lm_eval.api.task
import
Task
if
task_root
is
None
:
if
task_root
is
None
:
...
@@ -410,9 +410,9 @@ def consolidate_group_results(
...
@@ -410,9 +410,9 @@ def consolidate_group_results(
for
group_or_task
,
group_or_task_info
in
task_dict
.
items
():
for
group_or_task
,
group_or_task_info
in
task_dict
.
items
():
# Convert to string
# Convert to string
if
isinstance
(
group_or_task
,
Config
urableGroup
):
if
isinstance
(
group_or_task
,
Group
Config
):
group_config
=
group_or_task
.
config
group_config
=
group_or_task
.
to_dict
()
group_or_task
=
group_or_task
.
group
_name
group_or_task
=
group_or_task
.
group
else
:
else
:
group_config
=
None
group_config
=
None
...
@@ -441,7 +441,7 @@ def consolidate_group_results(
...
@@ -441,7 +441,7 @@ def consolidate_group_results(
)
)
if
(
group_config
is
None
)
or
(
if
(
group_config
is
None
)
or
(
group_config
[
"aggregate_metric_list"
]
is
None
group_config
.
get
(
"aggregate_metric_list"
)
is
None
):
):
results
[
group_or_task
][
" "
]
=
" "
results
[
group_or_task
][
" "
]
=
" "
continue
continue
...
@@ -450,7 +450,7 @@ def consolidate_group_results(
...
@@ -450,7 +450,7 @@ def consolidate_group_results(
agg_metric_list
=
group_config
[
"aggregate_metric_list"
]
agg_metric_list
=
group_config
[
"aggregate_metric_list"
]
show_group_table
=
show_group_table
|
bool
(
show_group_table
=
show_group_table
|
bool
(
group_config
[
"aggregate_metric_list"
]
group_config
.
get
(
"aggregate_metric_list"
)
)
)
task_list
=
_task_aggregation_list
[
group_or_task
]
task_list
=
_task_aggregation_list
[
group_or_task
]
...
...
lm_eval/tasks/__init__.py
View file @
0aca6958
...
@@ -49,7 +49,7 @@ from typing import (
...
@@ -49,7 +49,7 @@ from typing import (
import
yaml
import
yaml
from
yaml
import
YAMLError
from
yaml
import
YAMLError
from
lm_eval.api.group
import
ConfigurableGroup
,
GroupConfig
from
lm_eval.api.group
import
GroupConfig
from
lm_eval.evaluator_utils
import
get_subtask_list
from
lm_eval.evaluator_utils
import
get_subtask_list
from
lm_eval.utils
import
pattern_match
,
setup_logging
from
lm_eval.utils
import
pattern_match
,
setup_logging
...
@@ -767,17 +767,17 @@ class TaskManager:
...
@@ -767,17 +767,17 @@ class TaskManager:
self
,
self
,
cfg
:
dict
,
cfg
:
dict
,
parent_name
:
str
|
None
=
None
,
parent_name
:
str
|
None
=
None
,
)
->
tuple
[
Config
urableGroup
,
list
[
Union
[
str
,
dict
]]]:
)
->
tuple
[
Group
Config
,
list
[
Union
[
str
,
dict
]]]:
"""
"""
Build Config
urableGroup
and return (group_obj, subtask_names).
Build
Group
Config and return (group_obj, subtask_names).
Resolves tag expansion.
Resolves tag expansion.
"""
"""
if
self
.
metadata
is
not
None
:
if
self
.
metadata
is
not
None
:
cfg
[
"metadata"
]
=
cfg
.
get
(
"metadata"
,
{})
|
self
.
metadata
cfg
[
"metadata"
]
=
cfg
.
get
(
"metadata"
,
{})
|
self
.
metadata
grp
=
Configurable
Group
(
c
onfig
=
cfg
)
grp
=
Group
C
onfig
(
**
cfg
)
subtasks
:
list
[
Union
[
str
,
dict
]]
=
[]
subtasks
:
list
[
Union
[
str
,
dict
]]
=
[]
for
t
in
grp
.
config
[
"
task
"
]
:
for
t
in
grp
.
task
:
if
isinstance
(
t
,
str
)
and
self
.
_name_is_tag
(
t
):
if
isinstance
(
t
,
str
)
and
self
.
_name_is_tag
(
t
):
subtasks
.
extend
(
self
.
_get_tasklist
(
t
))
subtasks
.
extend
(
self
.
_get_tasklist
(
t
))
else
:
else
:
...
@@ -787,7 +787,7 @@ class TaskManager:
...
@@ -787,7 +787,7 @@ class TaskManager:
def
_load_subtasks
(
def
_load_subtasks
(
self
,
self
,
subtasks
:
list
[
Union
[
str
,
dict
]],
subtasks
:
list
[
Union
[
str
,
dict
]],
parent_name
:
Union
[
str
,
Config
urableGroup
,
None
],
parent_name
:
Union
[
str
,
Group
Config
,
None
],
update_config
:
dict
|
None
,
update_config
:
dict
|
None
,
)
->
Mapping
:
)
->
Mapping
:
"""Return merged mapping of all subtasks, handling duplicates."""
"""Return merged mapping of all subtasks, handling duplicates."""
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment