Commit fedaf262 authored by Baber

refactor: improve dataset and metric handling in TaskConfig

parent 863ff340
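Before the diff: this commit deletes the `DatasetConfig` indirection, moving dataset fields back onto `TaskConfig` directly, and switches the dict-valued defaults from `None` to `field(default_factory=dict)`. A minimal sketch of why `default_factory` matters here (`_Example` is a hypothetical stand-in, not code from this repo):

```python
from dataclasses import dataclass, field
from typing import Optional


@dataclass
class _Example:  # hypothetical stand-in for the TaskConfig fields touched below
    dataset_path: Optional[str] = None
    dataset_name: Optional[str] = None
    # Mutable defaults must go through default_factory; a bare `= {}` would be
    # shared state and is rejected by @dataclass at class-creation time.
    dataset_kwargs: Optional[dict] = field(default_factory=dict)


e = _Example()
assert e.dataset_kwargs == {}  # no None-check needed before `**`-unpacking or `|`-merging
```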
@@ -639,7 +639,7 @@ class ConfigurableTask(Task):
         if self.config.dataset_name is not None:
             self.DATASET_NAME = self.config.dataset_name
-        self.metric_list: list[MetricConfig] = self.config.get_metrics
+        # self.metric_list: list[MetricConfig] = self.config.get_metrics
         self.download(self.config.dataset_kwargs)
         self._training_docs = None
@@ -655,7 +655,10 @@ class ConfigurableTask(Task):
         else:
             self.prompt = None

-        if self.config.fewshot_cfg.num() > 0 and self.fewshot_docs() is not None:
+        if (
+            self.config.fewshot_cfg.num_fewshot() > 0
+            and self.fewshot_docs() is not None
+        ):
             self.fewshot_rnd = random.Random()
             self.sampler = self.config.fewshot_cfg.init_sampler(
                 list(self.fewshot_docs()), self, rnd=self.fewshot_rnd
@@ -722,19 +725,21 @@ class ConfigurableTask(Task):
     def download(
         self, dataset_kwargs: Optional[Dict[str, Any]] = None, **kwargs
     ) -> None:
-        if isinstance(df := self.config.ds_cfg.custom, Callable):
+        self.config.dataset_kwargs, self.config.metadata = (
+            self.config.dataset_kwargs or {},
+            self.config.metadata or {},
+        )
+        if isinstance(df := self.config.custom_dataset, Callable):
             eval_logger.warning(
                 f"{self.config.task}: Custom kwargs can be passed to `--metadata` in console (as json string) or to the TaskManager."
                 + "\nFor example --metadata='{\"max_seq_lengths\":[4096, 8192]}'. For details see task Readme."
             )
-            self.dataset = df(
-                **(self.config.ds_cfg.kwargs | self.config.ds_cfg.metadata)
-            )
+            self.dataset = df(**(self.config.dataset_kwargs | self.config.metadata))
         else:
             self.dataset = datasets.load_dataset(
-                path=self.config.ds_cfg.path,
-                name=self.config.ds_cfg.name,
-                **self.config.ds_cfg.kwargs if self.config.ds_cfg.kwargs else {},
+                path=self.config.dataset_path,
+                name=self.config.dataset_name,
+                **self.config.dataset_kwargs,
             )

     def has_training_docs(self) -> bool:
@@ -971,7 +976,7 @@ class ConfigurableTask(Task):
         """Iterates over FilterEnsembles and applies them to instances"""
         if hasattr(self, "_filters"):
             for f in self._filters:
-                f.apply(self._instances)
+                f.ensemble.apply(self._instances)
         else:
             eval_logger.warning("No filter defined, passing through instances")
         return self._instances
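A quick illustration of the dict-union (`|`) semantics the rewritten `download()` relies on when calling a custom dataset callable; the values below are made up for the example (the `max_seq_lengths` key is borrowed from the warning message in the diff):

```python
# PEP 584 dict union (Python 3.9+): builds a new dict; on duplicate keys
# the right-hand operand wins, so metadata entries override dataset_kwargs.
dataset_kwargs = {"split": "test", "max_seq_lengths": [2048]}
metadata = {"max_seq_lengths": [4096, 8192]}

merged = dataset_kwargs | metadata
assert merged == {"split": "test", "max_seq_lengths": [4096, 8192]}
# The custom loader then receives the merged mapping: self.dataset = df(**merged)
```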
...
@@ -2,6 +2,7 @@ import logging
 from dataclasses import asdict, dataclass, field
 from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union

+from lm_eval.api.filter import FilterEnsemble
 from lm_eval.api.instance import OutputType
 from lm_eval.config.metric import MetricConfig
 from lm_eval.config.utils import maybe_serialize
@@ -10,7 +11,6 @@ from lm_eval.config.utils import maybe_serialize
 if TYPE_CHECKING:
     from lm_eval.api.samplers import ContextSampler
     from lm_eval.api.task import Task
-    from lm_eval.filters import FilterEnsemble

 eval_logger = logging.getLogger(__name__)
@@ -29,8 +29,8 @@ class FilterConfig:
     """Encapsulates information about a single filter."""

     name: str
-    fn: Optional[Callable] = None
-    kwargs: Optional[dict] = field(default_factory=dict)
+    ensemble: FilterEnsemble
+    metric_list: list[MetricConfig]


 @dataclass
@@ -117,17 +117,6 @@ class FewshotConfig:
     )

-
-@dataclass
-class DatasetConfig:
-    """Encapsulates information about a dataset."""
-
-    path: Optional[str] = None
-    name: Optional[str] = None
-    kwargs: Optional[dict] = field(default_factory=dict)
-    custom: Optional[Callable] = None
-    metadata: Optional[dict] = field(default_factory=dict)

 @dataclass
 class TaskConfig(dict):
     # task naming/registry
@@ -140,7 +129,7 @@ class TaskConfig(dict):
     custom_dataset: Optional[Callable] = None
     dataset_path: Optional[str] = None
     dataset_name: Optional[str] = None
-    dataset_kwargs: Optional[dict] = None
+    dataset_kwargs: Optional[dict] = field(default_factory=dict)
     training_split: Optional[str] = None
     validation_split: Optional[str] = None
     test_split: Optional[str] = None
@@ -177,9 +166,9 @@ class TaskConfig(dict):
         default_factory=dict
     )  # by default, not used in the code. allows for users to pass arbitrary info to tasks
-    _metric_list: list[MetricConfig] = None
+    _metric_list: list[MetricConfig] = field(default_factory=list)
     _filter_list: list[FilterConfig] = None
-    ds_cfg: DatasetConfig = field(init=False)
+    # ds_cfg: DatasetConfig = field(init=False)
     fewshot_cfg: FewshotConfig = field(init=False)

     def __post_init__(self) -> None:
@@ -215,14 +204,6 @@ class TaskConfig(dict):
             eval_logger.warning(
                 f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}"
             )
-        # ---setup dataset config--- #
-        self.ds_cfg = DatasetConfig(
-            path=self.dataset_path,
-            name=self.dataset_name,
-            kwargs=self.dataset_kwargs,
-            custom=self.custom_dataset,
-            metadata=self.metadata or {},
-        )
         # ---setup fewshot config--- #
         _fewshot_cfg = self.fewshot_config if self.fewshot_config is not None else {}
         self.fewshot_cfg = FewshotConfig(
@@ -234,8 +215,9 @@ class TaskConfig(dict):
             fewshot_indices=_fewshot_cfg.get("fewshot_indices", None),
         )

-    @property
-    def get_metrics(self) -> list["MetricConfig"]:
+    def get_metric(
+        self, metric_list: Optional[list[dict]] = None
+    ) -> list["MetricConfig"]:
         from lm_eval.api.registry import (
             AGGREGATION_REGISTRY,
             DEFAULT_METRIC_REGISTRY,
@@ -245,8 +227,9 @@ class TaskConfig(dict):
             is_higher_better,
         )

+        metric_list = metric_list or self.metric_list
         metrics = []
-        if self.metric_list is None:
+        if not metric_list:
             # ---------- 1. If no metrics defined, use defaults for output type ----------
             _metric_list = DEFAULT_METRIC_REGISTRY[self.output_type]
             eval_logger.info(
@@ -263,7 +246,7 @@ class TaskConfig(dict):
             )
         else:
             # ---------- 2. Process user-defined metrics from config ----------
-            for metric_config in self.metric_list:
+            for metric_config in metric_list:
                 metric_name = metric_config["metric"]
                 _metric_fn_kwargs = {
                     key: metric_config[key]
@@ -324,34 +307,50 @@ class TaskConfig(dict):
                     hf_evaluate=_hf_evaluate_metric,
                 )
             )
+        for m in metrics:
+            if m not in self._metric_list:
+                self._metric_list.extend(m)
         return metrics

     @property
-    def get_filters(self) -> list["FilterEnsemble"]:
+    def get_filters(self) -> list["FilterConfig"]:
         from lm_eval.filters import build_filter_ensemble

         if not self.filter_list:
             eval_logger.debug(
                 "No custom filters defined; falling back to 'take_first' for handling repeats."
             )
-            return [build_filter_ensemble("none", [("take_first", None)])]
+            return [
+                FilterConfig(
+                    name="none",
+                    ensemble=build_filter_ensemble("none", [("take_first", None)]),
+                    metric_list=self.get_metric(metric_list=None),
+                )
+            ]
         else:

             def _strip_fn(d: dict) -> tuple[str, dict]:
-                return d["function"], {k: v for k, v in d.items() if k != "function"}
+                return d["function"], {
+                    k: v for k, v in d.items() if k not in ["function", "metric_list"]
+                }

             configs = (
                 self.filter_list.values()
                 if isinstance(self.filter_list, dict)
                 else self.filter_list
             )
-            return [
-                build_filter_ensemble(
-                    filter_name=cfg["name"],
-                    components=[_strip_fn(f) for f in cfg["filter"]],
+            x = [
+                FilterConfig(
+                    name=cfg["name"],
+                    ensemble=build_filter_ensemble(
+                        filter_name=cfg["name"],
+                        components=[_strip_fn(f) for f in cfg["filter"]],
+                    ),
+                    metric_list=self.get_metric(metric_list=cfg.get("metric_list")),
                 )
                 for cfg in configs
             ]
+            return x

     @classmethod
     def from_yaml(cls, data: dict) -> "TaskConfig":
...
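Taken together, the consumer-facing shape changes like this; a hedged sketch, assuming a task object whose `_instances` are already populated (the helper name is hypothetical):

```python
# Sketch only: get_filters (still a property) now yields FilterConfig objects
# instead of bare FilterEnsembles, which is why the apply-filters loop in the
# first file dereferences `.ensemble`, and why each filter can carry its own
# metric_list resolved via get_metric(...).
def run_filters(task):  # hypothetical helper, not part of this commit
    for fc in task.config.get_filters:  # list[FilterConfig]
        fc.ensemble.apply(task._instances)  # run the filter pipeline
        for metric in fc.metric_list:  # MetricConfig entries for this filter
            print(fc.name, metric)  # e.g. dispatch scoring per filter key
```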