Commit eb9c6a57 authored by lintangsutawika's avatar lintangsutawika
Browse files

add group config

parent cd0ad1b4
......@@ -51,6 +51,55 @@ ALL_OUTPUT_TYPES = [
eval_logger = logging.getLogger("lm-eval")
@dataclass
class GroupConfig(dict):
group: Optional[Union[str, list]] = None
aggregate_metric: Optional[str] = False
aggregate_fn: Optional[str] = "mean"
weight_by_size: Optional[str] = False
def __getitem__(self, item):
return getattr(self, item)
def __setitem__(self, item, value):
return setattr(self, item, value)
def to_dict(self, keep_callable: bool = False) -> dict:
"""dumps the current config as a dictionary object, as a printable format.
null fields will not be printed.
Used for dumping results alongside full task configuration
:return: dict
A printable dictionary version of the TaskConfig object.
# TODO: should any default value in the TaskConfig not be printed?
"""
cfg_dict = asdict(self)
# remove values that are `None`
for k, v in list(cfg_dict.items()):
if v is None:
cfg_dict.pop(k)
elif callable(v):
cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable)
return cfg_dict
def serialize_function(
self, value: Union[Callable, str], keep_callable=False
) -> Union[Callable, str]:
"""Serializes a given function or string.
If 'keep_callable' is True, the original callable is returned.
Otherwise, attempts to return the source code of the callable using 'getsource'.
"""
if keep_callable:
return value
else:
try:
return getsource(value)
except (TypeError, OSError):
return str(value)
@dataclass
class TaskConfig(dict):
# task naming/registry
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment