"git@developer.sourcefind.cn:sugon_wxj/megatron-lm.git" did not exist on "48269d8d81005e2a3997a50fb021c82180ac43a3"
Commit 86039e85 authored by lintangsutawika
Browse files

add configurable group

parent 3473e196
...@@ -53,10 +53,13 @@ eval_logger = logging.getLogger("lm-eval") ...@@ -53,10 +53,13 @@ eval_logger = logging.getLogger("lm-eval")
@dataclass @dataclass
class GroupConfig(dict): class GroupConfig(dict):
group: Optional[Union[str, list]] = None group: Optional[str] = None
group_alias: Optional[str] = None
task: Optional[Union[str, list]] = None
aggregate_metric: Optional[str] = False aggregate_metric: Optional[str] = False
aggregate_fn: Optional[str] = "mean" aggregate_fn: Optional[str] = "mean"
weight_by_size: Optional[str] = False weight_by_size: Optional[str] = False
metric_alias: Optional[str] = None
def __getitem__(self, item): def __getitem__(self, item):
return getattr(self, item) return getattr(self, item)
...@@ -100,14 +103,39 @@ class GroupConfig(dict): ...@@ -100,14 +103,39 @@ class GroupConfig(dict):
return str(value) return str(value)
class ConfigurableGroup(abc.ABC):
    """Thin wrapper around a ``GroupConfig`` built from a plain dict.

    The wrapped config is stored on ``self._config``; the properties below
    expose selected fields of it read-only.
    """

    def __init__(
        self,
        config: Optional[dict] = None,
    ) -> None:
        """Build the underlying ``GroupConfig`` from ``config``.

        Args:
            config: Keyword arguments forwarded to ``GroupConfig``. ``None``
                (the default) is treated as an empty mapping, so the group
                is created with ``GroupConfig``'s own defaults.
        """
        # `**None` raises TypeError, which made the declared default
        # unusable; fall back to an empty mapping instead.
        if config is None:
            config = {}
        self._config = GroupConfig(**config)

    @property
    def group(self):
        """Return the ``group`` field of the wrapped config."""
        return self._config.group

    @property
    def group_alias(self):
        """Return the ``group_alias`` field of the wrapped config."""
        return self._config.group_alias

    @property
    def config(self):
        """Return the result of calling ``to_dict()`` on the wrapped config."""
        return self._config.to_dict()

    def __repr__(self):
        return (
            f"ConfigurableGroup(group={self.group},"
            f"group_alias={self.group_alias})"
        )
@dataclass @dataclass
class TaskConfig(dict): class TaskConfig(dict):
# task naming/registry # task naming/registry
task: Optional[str] = None task: Optional[str] = None
task_alias: Optional[str] = None task_alias: Optional[str] = None
tags: Optional[Union[str, list]] = None
group: Optional[Union[str, list]] = None group: Optional[Union[str, list]] = None
group_alias: Optional[Union[str, list]] = None group_alias: Optional[Union[str, list]] = None
group_config: Optional[dict] = None
# HF dataset options. # HF dataset options.
# which dataset to use, # which dataset to use,
# and what splits for what purpose # and what splits for what purpose
...@@ -1418,7 +1446,6 @@ class ConfigurableTask(Task): ...@@ -1418,7 +1446,6 @@ class ConfigurableTask(Task):
def __repr__(self): def __repr__(self):
return ( return (
f"ConfigurableTask(task_name={getattr(self.config, 'task', None)}," f"ConfigurableTask(task_name={getattr(self.config, 'task', None)},"
f"group_name={getattr(self.config, 'group', None)},"
f"output_type={self.OUTPUT_TYPE}," f"output_type={self.OUTPUT_TYPE},"
f"num_fewshot={getattr(self.config, 'num_fewshot', None)}," f"num_fewshot={getattr(self.config, 'num_fewshot', None)},"
f"num_samples={len(self.eval_docs)})" f"num_samples={len(self.eval_docs)})"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment