Unverified Commit 7fc43656 authored by Baber Abbasi's avatar Baber Abbasi Committed by GitHub
Browse files

serialize callable functions in config (#1367)

parent 488759d2
...@@ -5,6 +5,7 @@ import random ...@@ -5,6 +5,7 @@ import random
import re import re
from collections.abc import Callable from collections.abc import Callable
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass
from inspect import getsource
from typing import Any, List, Literal, Tuple, Union from typing import Any, List, Literal, Tuple, Union
import datasets import datasets
...@@ -37,7 +38,6 @@ ALL_OUTPUT_TYPES = [ ...@@ -37,7 +38,6 @@ ALL_OUTPUT_TYPES = [
"generate_until", "generate_until",
] ]
eval_logger = logging.getLogger("lm-eval") eval_logger = logging.getLogger("lm-eval")
...@@ -110,15 +110,13 @@ class TaskConfig(dict): ...@@ -110,15 +110,13 @@ class TaskConfig(dict):
"do_sample": False, "do_sample": False,
} }
# TODO: how to make TaskConfigs be de- and re-serializable, even when using the !function constructor?
def __getitem__(self, item): def __getitem__(self, item):
return getattr(self, item) return getattr(self, item)
def __setitem__(self, item, value): def __setitem__(self, item, value):
return setattr(self, item, value) return setattr(self, item, value)
def to_dict(self, keep_callable=False): def to_dict(self, keep_callable: bool = False) -> dict:
"""dumps the current config as a dictionary object, as a printable format. """dumps the current config as a dictionary object, as a printable format.
null fields will not be printed. null fields will not be printed.
Used for dumping results alongside full task configuration Used for dumping results alongside full task configuration
...@@ -133,14 +131,34 @@ class TaskConfig(dict): ...@@ -133,14 +131,34 @@ class TaskConfig(dict):
for k, v in list(cfg_dict.items()): for k, v in list(cfg_dict.items()):
if v is None: if v is None:
cfg_dict.pop(k) cfg_dict.pop(k)
elif isinstance(v, Callable): elif k == "metric_list":
if keep_callable: for metric_dict in v:
for metric_key, metric_value in metric_dict.items():
if callable(metric_value):
metric_dict[metric_key] = self.serialize_function(
metric_value, keep_callable=keep_callable
)
cfg_dict[k] = v cfg_dict[k] = v
else: elif callable(v):
# TODO: this should handle Promptsource template objects as a separate case? cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable)
cfg_dict[k] = str(v)
return cfg_dict return cfg_dict
def serialize_function(
self, value: Union[Callable, str], keep_callable=False
) -> Union[Callable, str]:
"""Serializes a given function or string.
If 'keep_callable' is True, the original callable is returned.
Otherwise, attempts to return the source code of the callable using 'getsource'.
"""
if keep_callable:
return value
else:
try:
return getsource(value)
except (TypeError, OSError):
return str(value)
class Task(abc.ABC): class Task(abc.ABC):
"""A task represents an entire benchmark including its dataset, problems, """A task represents an entire benchmark including its dataset, problems,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment