Commit 86039e85 authored by lintangsutawika's avatar lintangsutawika
Browse files

add configurable group

parent 3473e196
......@@ -53,10 +53,13 @@ eval_logger = logging.getLogger("lm-eval")
@dataclass
class GroupConfig(dict):
group: Optional[Union[str, list]] = None
group: Optional[str] = None
group_alias: Optional[str] = None
task: Optional[Union[str, list]] = None
aggregate_metric: Optional[str] = False
aggregate_fn: Optional[str] = "mean"
weight_by_size: Optional[str] = False
metric_alias: Optional[str] = None
def __getitem__(self, item):
return getattr(self, item)
......@@ -100,14 +103,39 @@ class GroupConfig(dict):
return str(value)
class ConfigurableGroup(abc.ABC):
def __init__(
self,
config: Optional[dict] = None,
) -> None:
self._config = GroupConfig(**config)
@property
def group(self):
return self._config.group
@property
def group_alias(self):
return self._config.group_alias
@property
def config(self):
return self._config.to_dict()
def __repr__(self):
return (
f"ConfigurableGroup(group={self.group},"
f"group_alias={self.group_alias})"
)
@dataclass
class TaskConfig(dict):
# task naming/registry
task: Optional[str] = None
task_alias: Optional[str] = None
tags: Optional[Union[str, list]] = None
group: Optional[Union[str, list]] = None
group_alias: Optional[Union[str, list]] = None
group_config: Optional[dict] = None
# HF dataset options.
# which dataset to use,
# and what splits for what purpose
......@@ -1418,7 +1446,6 @@ class ConfigurableTask(Task):
def __repr__(self):
return (
f"ConfigurableTask(task_name={getattr(self.config, 'task', None)},"
f"group_name={getattr(self.config, 'group', None)},"
f"output_type={self.OUTPUT_TYPE},"
f"num_fewshot={getattr(self.config, 'num_fewshot', None)},"
f"num_samples={len(self.eval_docs)})"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment