Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
94673d40
Commit
94673d40
authored
Jul 03, 2024
by
haileyschoelkopf
Browse files
move group api to separate file
parent
c6839d72
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
5 additions
and
116 deletions
+5
-116
lm_eval/api/task.py
lm_eval/api/task.py
+0
-113
lm_eval/evaluator_utils.py
lm_eval/evaluator_utils.py
+2
-1
lm_eval/tasks/__init__.py
lm_eval/tasks/__init__.py
+3
-2
No files found.
lm_eval/api/task.py
View file @
94673d40
...
...
@@ -51,119 +51,6 @@ ALL_OUTPUT_TYPES = [
eval_logger
=
logging
.
getLogger
(
"lm-eval"
)
@dataclass
class AggMetricConfig(dict):
    """Settings describing how one metric is aggregated across a group's subtasks.

    The values live in dataclass attributes; the inherited ``dict`` storage
    is not populated by the generated ``__init__``.

    Raises:
        ValueError: if ``aggregation`` is anything other than ``"mean"``.
    """

    metric: Optional[str] = None
    aggregation: Optional[str] = "mean"
    # Annotated as a flag: the default is the boolean ``False``, not a string.
    weight_by_size: Optional[bool] = False
    # list of filter names which should be incorporated into the aggregated metric.
    filter_list: Optional[Union[str, list]] = "none"

    def __post_init__(self):
        """Validate ``aggregation`` and normalise ``filter_list`` to a list."""
        if self.aggregation != "mean":
            raise ValueError(
                f"Currently, only 'mean' is supported for automatically aggregating scores across groups' subtasks. Got '{self.aggregation}'."
            )

        # A single filter name is accepted as shorthand for a one-element list.
        if isinstance(self.filter_list, str):
            self.filter_list = [self.filter_list]
@dataclass
class GroupConfig(dict):
    """Configuration record for a group of tasks.

    Fields are stored as dataclass attributes; ``__getitem__`` and
    ``__setitem__`` forward to those attributes so ``cfg["group"]`` and
    ``cfg.group`` are interchangeable.
    """

    group: Optional[str] = None
    group_alias: Optional[str] = None
    task: Optional[Union[str, list]] = None
    aggregate_metric_list: Optional[
        Union[List[AggMetricConfig], AggMetricConfig, dict]
    ] = None
    # by default, not used in the code. allows for users to pass arbitrary info to tasks
    metadata: Optional[dict] = None

    def __getitem__(self, item):
        """Return the attribute named ``item`` (subscript-style read access)."""
        return getattr(self, item)

    def __setitem__(self, item, value):
        """Set the attribute named ``item`` to ``value`` (subscript-style write)."""
        return setattr(self, item, value)

    def __post_init__(self):
        """Normalise ``aggregate_metric_list`` to a list of ``AggMetricConfig``."""
        entries = self.aggregate_metric_list
        if entries is None:
            return

        # A lone mapping (plain dict or AggMetricConfig) becomes a one-item list.
        if isinstance(entries, dict):
            entries = [entries]

        # AggMetricConfig subclasses dict but keeps its values in attributes,
        # so unpacking an existing instance with ** would yield no keys and
        # silently reset it to defaults. Only plain mappings are converted.
        self.aggregate_metric_list = [
            AggMetricConfig(**entry)
            if isinstance(entry, dict) and not isinstance(entry, AggMetricConfig)
            else entry
            for entry in entries
        ]

    def to_dict(self, keep_callable: bool = False) -> dict:
        """Dump the current config as a plain dictionary in a printable format.

        Fields whose value is ``None`` are kept in the output. Top-level
        callable values are passed through ``serialize_function``.

        :param keep_callable: when True, callables are left as-is instead of
            being replaced by their source text.
        :return: dict
            A printable dictionary version of this GroupConfig object.
        """
        cfg_dict = asdict(self)
        # Only values of existing keys are replaced, so the dict can be
        # iterated directly without taking a snapshot first.
        for key, value in cfg_dict.items():
            if callable(value):
                cfg_dict[key] = self.serialize_function(
                    value, keep_callable=keep_callable
                )
        return cfg_dict

    def serialize_function(
        self, value: Union[Callable, str], keep_callable=False
    ) -> Union[Callable, str]:
        """Serialize a given function or string.

        If 'keep_callable' is True, the original callable is returned.
        Otherwise, attempts to return the source code of the callable using
        'getsource', falling back to ``str(value)`` when the source cannot be
        retrieved (``TypeError`` or ``OSError``).
        """
        if keep_callable:
            return value
        try:
            return getsource(value)
        except (TypeError, OSError):
            return str(value)
class ConfigurableGroup(abc.ABC):
    """Wrapper exposing a ``GroupConfig`` through read-only properties."""

    def __init__(
        self,
        config: Optional[dict] = None,
    ) -> None:
        """Build the underlying ``GroupConfig`` from ``config``.

        :param config: keyword arguments for ``GroupConfig``. ``None`` (the
            default) is treated as an empty mapping, because ``**None``
            would raise ``TypeError``.
        """
        self._config = GroupConfig(**(config or {}))

    @property
    def group(self):
        """The ``group`` field of the underlying config."""
        return self._config.group

    @property
    def group_alias(self):
        """The ``group_alias`` field of the underlying config."""
        return self._config.group_alias

    @property
    def version(self):
        """The config's ``version`` attribute, or ``None`` when it has none."""
        # The config object may not declare a ``version`` field; fall back to
        # None rather than raising AttributeError from inside a property.
        return getattr(self._config, "version", None)

    @property
    def config(self):
        """The underlying config rendered via its ``to_dict()`` method."""
        return self._config.to_dict()

    @property
    def group_name(self) -> Any:
        """Same value as ``group``."""
        return self._config.group

    def __repr__(self):
        return (
            f"ConfigurableGroup(group={self.group},"
            f"group_alias={self.group_alias})"
        )
@
dataclass
class
TaskConfig
(
dict
):
# task naming/registry
...
...
lm_eval/evaluator_utils.py
View file @
94673d40
...
...
@@ -4,12 +4,13 @@ import pathlib
import
sys
from
typing
import
List
,
Optional
,
Tuple
,
Union
from
lm_eval.api.group
import
ConfigurableGroup
from
lm_eval.api.metrics
import
(
aggregate_subtask_metrics
,
pooled_sample_stderr
,
stderr_for_metric
,
)
from
lm_eval.api.task
import
ConfigurableGroup
,
Task
from
lm_eval.api.task
import
Task
from
lm_eval.utils
import
eval_logger
,
positional_deprecated
...
...
lm_eval/tasks/__init__.py
View file @
94673d40
...
...
@@ -5,7 +5,8 @@ from functools import partial
from
typing
import
Dict
,
List
,
Mapping
,
Optional
,
Union
from
lm_eval
import
utils
from
lm_eval.api.task
import
ConfigurableGroup
,
ConfigurableTask
,
GroupConfig
,
Task
from
lm_eval.api.group
import
ConfigurableGroup
,
GroupConfig
from
lm_eval.api.task
import
ConfigurableTask
,
Task
from
lm_eval.evaluator_utils
import
get_subtask_list
...
...
@@ -153,7 +154,7 @@ class TaskManager:
if
self
.
_config_is_python_task
(
config
):
task_object
=
(
config
[
"class"
](
config
=
config
)
if
is
instance
(
config
[
"class"
],
ConfigurableTask
)
if
is
subclass
(
config
[
"class"
],
ConfigurableTask
)
else
config
[
"class"
]()
)
# very scuffed: set task name here. TODO: fixme?
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment