Commit fedaf262 authored by Baber

refactor: improve dataset and metric handling in TaskConfig

parent 863ff340
@@ -639,7 +639,7 @@ class ConfigurableTask(Task):
         if self.config.dataset_name is not None:
             self.DATASET_NAME = self.config.dataset_name
-        self.metric_list: list[MetricConfig] = self.config.get_metrics
+        # self.metric_list: list[MetricConfig] = self.config.get_metrics
         self.download(self.config.dataset_kwargs)
         self._training_docs = None
@@ -655,7 +655,10 @@ class ConfigurableTask(Task):
         else:
             self.prompt = None
-        if self.config.fewshot_cfg.num() > 0 and self.fewshot_docs() is not None:
+        if (
+            self.config.fewshot_cfg.num_fewshot() > 0
+            and self.fewshot_docs() is not None
+        ):
             self.fewshot_rnd = random.Random()
             self.sampler = self.config.fewshot_cfg.init_sampler(
                 list(self.fewshot_docs()), self, rnd=self.fewshot_rnd
@@ -722,19 +725,21 @@ class ConfigurableTask(Task):
     def download(
         self, dataset_kwargs: Optional[Dict[str, Any]] = None, **kwargs
     ) -> None:
-        if isinstance(df := self.config.ds_cfg.custom, Callable):
+        self.config.dataset_kwargs, self.config.metadata = (
+            self.config.dataset_kwargs or {},
+            self.config.metadata or {},
+        )
+        if isinstance(df := self.config.custom_dataset, Callable):
             eval_logger.warning(
                 f"{self.config.task}: Custom kwargs can be passed to `--metadata` in console (as json string) or to the TaskManager."
                 + "\nFor example --metadata='{\"max_seq_lengths\":[4096, 8192]}'. For details see task Readme."
             )
-            self.dataset = df(
-                **(self.config.ds_cfg.kwargs | self.config.ds_cfg.metadata)
-            )
+            self.dataset = df(**(self.config.dataset_kwargs | self.config.metadata))
         else:
             self.dataset = datasets.load_dataset(
-                path=self.config.ds_cfg.path,
-                name=self.config.ds_cfg.name,
-                **self.config.ds_cfg.kwargs if self.config.ds_cfg.kwargs else {},
+                path=self.config.dataset_path,
+                name=self.config.dataset_name,
+                **self.config.dataset_kwargs,
             )
     def has_training_docs(self) -> bool:
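With this change, `download` reads `dataset_kwargs` and `metadata` directly from the `TaskConfig` and merges them before invoking a custom dataset callable. A minimal sketch of what such a callable could look like under the new calling convention; the function name and the `max_seq_lengths` parameter are illustrative (borrowed from the warning's example), not part of this commit:

```python
import datasets


# Sketch only: after this commit a custom_dataset callable is invoked as
# df(**(config.dataset_kwargs | config.metadata)), so values supplied via
# --metadata='{"max_seq_lengths": [4096, 8192]}' arrive as plain keyword
# arguments. Dict `|` also means metadata keys override dataset_kwargs keys
# of the same name.
def build_dataset(path: str = "json", max_seq_lengths=None, **kwargs):
    ds = datasets.load_dataset(path, **kwargs)
    # task-specific preprocessing keyed on max_seq_lengths would go here
    return ds
```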
@@ -971,7 +976,7 @@ class ConfigurableTask(Task):
         """Iterates over FilterEnsembles and applies them to instances"""
         if hasattr(self, "_filters"):
             for f in self._filters:
-                f.apply(self._instances)
+                f.ensemble.apply(self._instances)
         else:
             eval_logger.warning("No filter defined, passing through instances")
         return self._instances
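Since `get_filters` now yields `FilterConfig` objects rather than bare `FilterEnsemble`s, applying filters goes through the `ensemble` attribute. A rough sketch of the consuming loop, assuming `self._filters` is populated from `config.get_filters` (that wiring is not shown in this diff):

```python
def apply_filters(filters, instances):
    """Sketch: `filters` is a list[FilterConfig] as returned by TaskConfig.get_filters."""
    for f in filters:
        # The FilterEnsemble does the actual post-processing of model responses.
        f.ensemble.apply(instances)
        # f.metric_list records which metrics should be scored against
        # this filter's output (new in this commit).
```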
@@ -2,6 +2,7 @@ import logging
 from dataclasses import asdict, dataclass, field
 from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union
+from lm_eval.api.filter import FilterEnsemble
 from lm_eval.api.instance import OutputType
 from lm_eval.config.metric import MetricConfig
 from lm_eval.config.utils import maybe_serialize
@@ -10,7 +11,6 @@ from lm_eval.config.utils import maybe_serialize
 if TYPE_CHECKING:
     from lm_eval.api.samplers import ContextSampler
     from lm_eval.api.task import Task
-    from lm_eval.filters import FilterEnsemble
 eval_logger = logging.getLogger(__name__)
@@ -29,8 +29,8 @@ class FilterConfig:
     """Encapsulates information about a single filter."""
     name: str
-    fn: Optional[Callable] = None
-    kwargs: Optional[dict] = field(default_factory=dict)
+    ensemble: FilterEnsemble
+    metric_list: list[MetricConfig]
 @dataclass
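A `FilterConfig` now pairs a built `FilterEnsemble` with the metrics to compute on that filter's output, instead of carrying a raw callable plus kwargs. A hedged construction example mirroring the default branch of `get_filters` further down (`build_filter_ensemble` and the `take_first` filter exist in lm_eval; the empty `metric_list` is just a placeholder):

```python
from lm_eval.filters import build_filter_ensemble

# FilterConfig is the dataclass defined just above.
default_filter = FilterConfig(
    name="none",
    ensemble=build_filter_ensemble("none", [("take_first", None)]),
    metric_list=[],  # normally filled by TaskConfig.get_metric(...)
)
```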
@@ -117,17 +117,6 @@ class FewshotConfig:
     )
 @dataclass
-class DatasetConfig:
-    """Encapsulates information about a dataset."""
-    path: Optional[str] = None
-    name: Optional[str] = None
-    kwargs: Optional[dict] = field(default_factory=dict)
-    custom: Optional[Callable] = None
-    metadata: Optional[dict] = field(default_factory=dict)
-@dataclass
 class TaskConfig(dict):
     # task naming/registry
@@ -140,7 +129,7 @@ class TaskConfig(dict):
     custom_dataset: Optional[Callable] = None
     dataset_path: Optional[str] = None
     dataset_name: Optional[str] = None
-    dataset_kwargs: Optional[dict] = None
+    dataset_kwargs: Optional[dict] = field(default_factory=dict)
     training_split: Optional[str] = None
     validation_split: Optional[str] = None
     test_split: Optional[str] = None
@@ -177,9 +166,9 @@
         default_factory=dict
     )  # by default, not used in the code. allows for users to pass arbitrary info to tasks
-    _metric_list: list[MetricConfig] = None
+    _metric_list: list[MetricConfig] = field(default_factory=list)
     _filter_list: list[FilterConfig] = None
-    ds_cfg: DatasetConfig = field(init=False)
+    # ds_cfg: DatasetConfig = field(init=False)
     fewshot_cfg: FewshotConfig = field(init=False)
     def __post_init__(self) -> None:
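Moving `dataset_kwargs` and `_metric_list` from `None` to `field(default_factory=...)` gives every `TaskConfig` instance its own empty container, which is what lets `download` and `get_metric` drop some of their `None` handling. A small, generic illustration of the dataclass rule involved (not lm_eval code):

```python
from dataclasses import dataclass, field


@dataclass
class Example:
    # items: list = []  # rejected: dataclasses raise ValueError for mutable defaults
    items: list = field(default_factory=list)  # a fresh list per instance


a, b = Example(), Example()
a.items.append(1)
assert b.items == []  # b gets its own list, unaffected by a
```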
@@ -215,14 +204,6 @@
             eval_logger.warning(
                 f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}"
             )
-        # ---setup dataset config--- #
-        self.ds_cfg = DatasetConfig(
-            path=self.dataset_path,
-            name=self.dataset_name,
-            kwargs=self.dataset_kwargs,
-            custom=self.custom_dataset,
-            metadata=self.metadata or {},
-        )
         # ---setup fewshot config--- #
         _fewshot_cfg = self.fewshot_config if self.fewshot_config is not None else {}
         self.fewshot_cfg = FewshotConfig(
@@ -234,8 +215,9 @@
             fewshot_indices=_fewshot_cfg.get("fewshot_indices", None),
         )
-    @property
-    def get_metrics(self) -> list["MetricConfig"]:
+    def get_metric(
+        self, metric_list: Optional[list[dict]] = None
+    ) -> list["MetricConfig"]:
         from lm_eval.api.registry import (
             AGGREGATION_REGISTRY,
             DEFAULT_METRIC_REGISTRY,
@@ -245,8 +227,9 @@
             is_higher_better,
         )
+        metric_list = metric_list or self.metric_list
         metrics = []
-        if self.metric_list is None:
+        if not metric_list:
             # ---------- 1. If no metrics defined, use defaults for output type ----------
             _metric_list = DEFAULT_METRIC_REGISTRY[self.output_type]
             eval_logger.info(
@@ -263,7 +246,7 @@
             )
         else:
             # ---------- 2. Process user-defined metrics from config ----------
-            for metric_config in self.metric_list:
+            for metric_config in metric_list:
                 metric_name = metric_config["metric"]
                 _metric_fn_kwargs = {
                     key: metric_config[key]
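`get_metric` now resolves whichever `metric_list` it is handed: the task-level list by default, or a per-filter list passed in by `get_filters`. The entries it iterates are plain dicts parsed from the task YAML; an illustrative example (values made up, and in existing lm_eval behavior keys other than `metric`/`aggregation`/`higher_is_better` are typically forwarded to the metric function as kwargs):

```python
# Illustrative metric_list as get_metric would see it after YAML parsing.
metric_list = [
    {"metric": "exact_match", "aggregation": "mean", "higher_is_better": True,
     "ignore_case": True},  # extra keys become metric-function kwargs
    {"metric": "acc"},
]
```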
@@ -324,34 +307,50 @@
                         hf_evaluate=_hf_evaluate_metric,
                     )
                 )
+        for m in metrics:
+            if m not in self._metric_list:
+                self._metric_list.extend(m)
         return metrics
     @property
-    def get_filters(self) -> list["FilterEnsemble"]:
+    def get_filters(self) -> list["FilterConfig"]:
         from lm_eval.filters import build_filter_ensemble
         if not self.filter_list:
             eval_logger.debug(
                 "No custom filters defined; falling back to 'take_first' for handling repeats."
             )
-            return [build_filter_ensemble("none", [("take_first", None)])]
+            return [
+                FilterConfig(
+                    name="none",
+                    ensemble=build_filter_ensemble("none", [("take_first", None)]),
+                    metric_list=self.get_metric(metric_list=None),
+                )
+            ]
         else:
             def _strip_fn(d: dict) -> tuple[str, dict]:
-                return d["function"], {k: v for k, v in d.items() if k != "function"}
+                return d["function"], {
+                    k: v for k, v in d.items() if k not in ["function", "metric_list"]
+                }
             configs = (
                 self.filter_list.values()
                 if isinstance(self.filter_list, dict)
                 else self.filter_list
             )
-            return [
-                build_filter_ensemble(
-                    filter_name=cfg["name"],
-                    components=[_strip_fn(f) for f in cfg["filter"]],
+            x = [
+                FilterConfig(
+                    name=cfg["name"],
+                    ensemble=build_filter_ensemble(
+                        filter_name=cfg["name"],
+                        components=[_strip_fn(f) for f in cfg["filter"]],
+                    ),
+                    metric_list=self.get_metric(metric_list=cfg.get("metric_list")),
                 )
                 for cfg in configs
             ]
+            return x
     @classmethod
     def from_yaml(cls, data: dict) -> "TaskConfig":
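The net effect of the new `get_filters` is that each entry in a task's `filter_list` may carry its own `metric_list`, which is resolved through `get_metric` and attached to the resulting `FilterConfig`; entries without one fall back to the task-level metrics. A hedged sketch of a parsed `filter_list` (filter and function names follow existing lm_eval tasks; the per-filter `metric_list` is the new part, and the regex pattern is illustrative):

```python
# Illustrative parsed filter_list; each "filter" entry is a pipeline of
# components handed to build_filter_ensemble, and "metric_list" (optional,
# new with this commit) scores only that filter's post-processed output.
filter_list = [
    {
        "name": "strict-match",
        "filter": [{"function": "take_first"}],
    },
    {
        "name": "flexible-extract",
        "filter": [
            {"function": "regex", "regex_pattern": r"(-?[0-9.,]+)"},
            {"function": "take_first"},
        ],
        "metric_list": [{"metric": "exact_match"}],
    },
]
```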