"examples/deprecated_api/imagenet/main.py" did not exist on "bc62f325384a62663901f48a09001ca9477646e9"
Unverified Commit 7fc43656 authored by Baber Abbasi's avatar Baber Abbasi Committed by GitHub
Browse files

serialize callable functions in config (#1367)

parent 488759d2
......@@ -5,6 +5,7 @@ import random
import re
from collections.abc import Callable
from dataclasses import asdict, dataclass
from inspect import getsource
from typing import Any, List, Literal, Tuple, Union
import datasets
......@@ -37,7 +38,6 @@ ALL_OUTPUT_TYPES = [
"generate_until",
]
eval_logger = logging.getLogger("lm-eval")
......@@ -110,15 +110,13 @@ class TaskConfig(dict):
"do_sample": False,
}
# TODO: how to make TaskConfigs be de- and re-serializable, even when using the !function constructor?
def __getitem__(self, item):
return getattr(self, item)
def __setitem__(self, item, value):
return setattr(self, item, value)
def to_dict(self, keep_callable=False):
def to_dict(self, keep_callable: bool = False) -> dict:
"""dumps the current config as a dictionary object, as a printable format.
null fields will not be printed.
Used for dumping results alongside full task configuration
......@@ -133,14 +131,34 @@ class TaskConfig(dict):
for k, v in list(cfg_dict.items()):
if v is None:
cfg_dict.pop(k)
elif isinstance(v, Callable):
if keep_callable:
cfg_dict[k] = v
else:
# TODO: this should handle Promptsource template objects as a separate case?
cfg_dict[k] = str(v)
elif k == "metric_list":
for metric_dict in v:
for metric_key, metric_value in metric_dict.items():
if callable(metric_value):
metric_dict[metric_key] = self.serialize_function(
metric_value, keep_callable=keep_callable
)
cfg_dict[k] = v
elif callable(v):
cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable)
return cfg_dict
def serialize_function(
self, value: Union[Callable, str], keep_callable=False
) -> Union[Callable, str]:
"""Serializes a given function or string.
If 'keep_callable' is True, the original callable is returned.
Otherwise, attempts to return the source code of the callable using 'getsource'.
"""
if keep_callable:
return value
else:
try:
return getsource(value)
except (TypeError, OSError):
return str(value)
class Task(abc.ABC):
"""A task represents an entire benchmark including its dataset, problems,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment