Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
lm-evaluation-harness
Commits
7fc43656
Unverified
Commit
7fc43656
authored
Jan 29, 2024
by
Baber Abbasi
Committed by
GitHub
Jan 29, 2024
Browse files
serialize callable functions in config (#1367)
parent
488759d2
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
28 additions
and
10 deletions
+28
-10
lm_eval/api/task.py
lm_eval/api/task.py
+28
-10
No files found.
lm_eval/api/task.py
View file @
7fc43656
...
@@ -5,6 +5,7 @@ import random
...
@@ -5,6 +5,7 @@ import random
import
re
import
re
from
collections.abc
import
Callable
from
collections.abc
import
Callable
from
dataclasses
import
asdict
,
dataclass
from
dataclasses
import
asdict
,
dataclass
from
inspect
import
getsource
from
typing
import
Any
,
List
,
Literal
,
Tuple
,
Union
from
typing
import
Any
,
List
,
Literal
,
Tuple
,
Union
import
datasets
import
datasets
...
@@ -37,7 +38,6 @@ ALL_OUTPUT_TYPES = [
...
@@ -37,7 +38,6 @@ ALL_OUTPUT_TYPES = [
"generate_until"
,
"generate_until"
,
]
]
eval_logger
=
logging
.
getLogger
(
"lm-eval"
)
eval_logger
=
logging
.
getLogger
(
"lm-eval"
)
...
@@ -110,15 +110,13 @@ class TaskConfig(dict):
...
@@ -110,15 +110,13 @@ class TaskConfig(dict):
"do_sample"
:
False
,
"do_sample"
:
False
,
}
}
# TODO: how to make TaskConfigs be de- and re-serializable, even when using the !function constructor?
def
__getitem__
(
self
,
item
):
def
__getitem__
(
self
,
item
):
return
getattr
(
self
,
item
)
return
getattr
(
self
,
item
)
def
__setitem__
(
self
,
item
,
value
):
def
__setitem__
(
self
,
item
,
value
):
return
setattr
(
self
,
item
,
value
)
return
setattr
(
self
,
item
,
value
)
def
to_dict
(
self
,
keep_callable
=
False
)
:
def
to_dict
(
self
,
keep_callable
:
bool
=
False
)
->
dict
:
"""dumps the current config as a dictionary object, as a printable format.
"""dumps the current config as a dictionary object, as a printable format.
null fields will not be printed.
null fields will not be printed.
Used for dumping results alongside full task configuration
Used for dumping results alongside full task configuration
...
@@ -133,14 +131,34 @@ class TaskConfig(dict):
...
@@ -133,14 +131,34 @@ class TaskConfig(dict):
for
k
,
v
in
list
(
cfg_dict
.
items
()):
for
k
,
v
in
list
(
cfg_dict
.
items
()):
if
v
is
None
:
if
v
is
None
:
cfg_dict
.
pop
(
k
)
cfg_dict
.
pop
(
k
)
elif
isinstance
(
v
,
Callable
):
elif
k
==
"metric_list"
:
if
keep_callable
:
for
metric_dict
in
v
:
cfg_dict
[
k
]
=
v
for
metric_key
,
metric_value
in
metric_dict
.
items
():
else
:
if
callable
(
metric_value
):
# TODO: this should handle Promptsource template objects as a separate case?
metric_dict
[
metric_key
]
=
self
.
serialize_function
(
cfg_dict
[
k
]
=
str
(
v
)
metric_value
,
keep_callable
=
keep_callable
)
cfg_dict
[
k
]
=
v
elif
callable
(
v
):
cfg_dict
[
k
]
=
self
.
serialize_function
(
v
,
keep_callable
=
keep_callable
)
return
cfg_dict
return
cfg_dict
def
serialize_function
(
self
,
value
:
Union
[
Callable
,
str
],
keep_callable
=
False
)
->
Union
[
Callable
,
str
]:
"""Serializes a given function or string.
If 'keep_callable' is True, the original callable is returned.
Otherwise, attempts to return the source code of the callable using 'getsource'.
"""
if
keep_callable
:
return
value
else
:
try
:
return
getsource
(
value
)
except
(
TypeError
,
OSError
):
return
str
(
value
)
class
Task
(
abc
.
ABC
):
class
Task
(
abc
.
ABC
):
"""A task represents an entire benchmark including its dataset, problems,
"""A task represents an entire benchmark including its dataset, problems,
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment