Commit 227f1a74 authored by Baber

refactor: improve dataset and metric handling in TaskConfig

parent 3b4d0af1
@@ -29,6 +29,7 @@ class GroupConfig(dict):
aggregate_metric_list: Optional[
Union[List[AggMetricConfig], AggMetricConfig, dict]
] = None
version: Optional[str] = None
metadata: Optional[dict] = (
None # by default, not used in the code. allows for users to pass arbitrary info to tasks
)
@@ -48,6 +49,11 @@ class GroupConfig(dict):
AggMetricConfig(**item) if isinstance(item, dict) else item
for item in self.aggregate_metric_list
]
self.version = (
self.version or self.metadata.get("version", "1.0")
if self.metadata
else "1.0"
)
def to_dict(self, keep_callable: bool = False) -> dict:
"""dumps the current config as a dictionary object, as a printable format.
......
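Note on the version fallback added to GroupConfig above: Python's conditional expression binds more loosely than `or`, so the new assignment resolves as in this equivalent, more explicit sketch (not part of the commit):

# Equivalent form of the GroupConfig.version fallback (sketch only):
if self.metadata:
    # prefer an explicit version, then metadata["version"], then the "1.0" default
    self.version = self.version or self.metadata.get("version", "1.0")
else:
    self.version = "1.0"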
@@ -639,7 +639,7 @@ class ConfigurableTask(Task):
if self.config.dataset_name is not None:
self.DATASET_NAME = self.config.dataset_name
self.metric_list: list[MetricConfig] = self.config.get_metrics
# self.metric_list: list[MetricConfig] = self.config.get_metrics
self.download(self.config.dataset_kwargs)
self._training_docs = None
@@ -655,7 +655,10 @@ class ConfigurableTask(Task):
else:
self.prompt = None
if self.config.fewshot_cfg.num() > 0 and self.fewshot_docs() is not None:
if (
self.config.fewshot_cfg.num_fewshot() > 0
and self.fewshot_docs() is not None
):
self.fewshot_rnd = random.Random()
self.sampler = self.config.fewshot_cfg.init_sampler(
list(self.fewshot_docs()), self, rnd=self.fewshot_rnd
@@ -724,21 +727,23 @@ class ConfigurableTask(Task):
) -> None:
from packaging.version import parse as vparse
self.config.dataset_kwargs, self.config.metadata = (
self.config.dataset_kwargs or {},
self.config.metadata or {},
)
if dataset_kwargs and vparse(datasets.__version__) >= vparse("4.0.0"):
dataset_kwargs.pop("trust_remote_code", None)
if isinstance(self.config.custom_dataset, Callable):
if isinstance(df := self.config.custom_dataset, Callable):
eval_logger.warning(
f"{self.config.task}: Custom kwargs can be passed to `--metadata` in console (as json string) or to the TaskManager."
+ "\nFor example --metadata='{\"max_seq_lengths\":[4096, 8192]}'. For details see task Readme."
)
self.dataset = self.config.custom_dataset(
**(self.config.metadata or {}), **(self.config.dataset_kwargs or {})
)
self.dataset = df(**(self.config.dataset_kwargs | self.config.metadata))
else:
self.dataset = datasets.load_dataset(
path=self.DATASET_PATH,
name=self.DATASET_NAME,
**dataset_kwargs if dataset_kwargs is not None else {},
path=self.config.dataset_path,
name=self.config.dataset_name,
**self.config.dataset_kwargs,
)
def has_training_docs(self) -> bool:
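The custom-dataset branch above now merges the two keyword dicts with the `|` operator (Python 3.9+), where keys from the right-hand dict win on conflict. A standalone sketch of the merge semantics (values are made up for illustration):

# Hypothetical values, purely to show the dict-merge semantics used above:
dataset_kwargs = {"split": "test", "max_seq_length": 2048}
metadata = {"max_seq_length": 4096}
merged = dataset_kwargs | metadata          # requires Python 3.9+
assert merged == {"split": "test", "max_seq_length": 4096}  # right-hand dict wins on clashes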
@@ -975,7 +980,7 @@ class ConfigurableTask(Task):
"""Iterates over FilterEnsembles and applies them to instances"""
if hasattr(self, "_filters"):
for f in self._filters:
f.apply(self._instances)
f.ensemble.apply(self._instances)
else:
eval_logger.warning("No filter defined, passing through instances")
return self._instances
@@ -1214,7 +1219,7 @@ class ConfigurableTask(Task):
arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]
# TODO: we should raise a warning telling users this will at most ~2x runtime.
if "acc_mutual_info" in [m.metric_name for m in self.metric_list]:
if "acc_mutual_info" in [m.metric_name for m in self.config._metric_list]:
# if we are calculating multiple choice accuracy
# using mutual information instead of raw loglikelihood as metric, need unconditional lls.
@@ -1281,7 +1286,7 @@ class ConfigurableTask(Task):
return self.config.process_results(doc, results)
result_dict = {}
use_metric = list(m.metric_name for m in self.metric_list)
use_metric = list(m.metric_name for m in self.config._metric_list)
if self.OUTPUT_TYPE == "loglikelihood":
results = results[0]
ll, is_greedy = results
@@ -1407,7 +1412,7 @@ class ConfigurableTask(Task):
# cast gold to the same type as result
gold = type(result)(gold)
for metric in self.metric_list:
for metric in self.config._metric_list:
if self.multiple_target:
# in the case where we have multiple targets,
# return true if any are true
@@ -1470,10 +1475,10 @@ class ConfigurableTask(Task):
return result_dict
def aggregation(self) -> dict:
return {k.name: k.aggregation_fn for k in self.metric_list}
return {k.name: k.aggregation_fn for k in self.config._metric_list}
def higher_is_better(self) -> dict:
return {k.name: k.higher_is_better for k in self.metric_list}
return {k.name: k.higher_is_better for k in self.config._metric_list}
def get_config(self, key: str) -> Any:
return getattr(self._config, key, None)
......
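With metrics now read from config._metric_list, aggregation() and higher_is_better() stay simple dict comprehensions keyed by metric name. A self-contained toy illustration of the shapes they return (the _Metric dataclass and metric names are stand-ins, not the real MetricConfig):

from dataclasses import dataclass
from statistics import mean
from typing import Callable

@dataclass
class _Metric:  # illustrative stand-in for MetricConfig
    name: str
    aggregation_fn: Callable
    higher_is_better: bool

_metric_list = [_Metric("acc", mean, True), _Metric("exact_match", mean, True)]

aggregation = {m.name: m.aggregation_fn for m in _metric_list}        # {"acc": mean, "exact_match": mean}
higher_is_better = {m.name: m.higher_is_better for m in _metric_list}  # {"acc": True, "exact_match": True}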
@@ -2,6 +2,7 @@ import logging
from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union
from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.instance import OutputType
from lm_eval.config.metric import MetricConfig
from lm_eval.config.utils import maybe_serialize
@@ -10,7 +11,6 @@ from lm_eval.config.utils import maybe_serialize
if TYPE_CHECKING:
from lm_eval.api.samplers import ContextSampler
from lm_eval.api.task import Task
from lm_eval.filters import FilterEnsemble
eval_logger = logging.getLogger(__name__)
@@ -29,8 +29,8 @@ class FilterConfig:
"""Encapsulates information about a single filter."""
name: str
fn: Optional[Callable] = None
kwargs: Optional[dict] = field(default_factory=dict)
ensemble: FilterEnsemble
metric_list: list[MetricConfig]
@dataclass
@@ -117,21 +117,10 @@ class FewshotConfig:
)
@dataclass
class DatasetConfig:
"""Encapsulates information about a dataset."""
path: Optional[str] = None
name: Optional[str] = None
kwargs: Optional[dict] = field(default_factory=dict)
custom: Optional[Callable] = None
metadata: Optional[dict] = field(default_factory=dict)
@dataclass
class TaskConfig(dict):
# task naming/registry
task: str
task: Optional[str] = None
task_alias: Optional[str] = None
tag: Optional[Union[str, list]] = None
# HF dataset options.
@@ -140,7 +129,7 @@ class TaskConfig(dict):
custom_dataset: Optional[Callable] = None
dataset_path: Optional[str] = None
dataset_name: Optional[str] = None
dataset_kwargs: Optional[dict] = None
dataset_kwargs: Optional[dict] = field(default_factory=dict)
training_split: Optional[str] = None
validation_split: Optional[str] = None
test_split: Optional[str] = None
@@ -177,9 +166,9 @@ class TaskConfig(dict):
default_factory=dict
) # by default, not used in the code. allows for users to pass arbitrary info to tasks
_metric_list: list[MetricConfig] = None
_metric_list: list[MetricConfig] = field(default_factory=list)
_filter_list: list[FilterConfig] = None
ds_cfg: DatasetConfig = field(init=False)
# ds_cfg: DatasetConfig = field(init=False)
fewshot_cfg: FewshotConfig = field(init=False)
def __post_init__(self) -> None:
@@ -215,18 +204,10 @@ class TaskConfig(dict):
eval_logger.warning(
f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}"
)
# ---setup dataset config--- #
self.ds_cfg = DatasetConfig(
path=self.dataset_path,
name=self.dataset_name,
kwargs=self.dataset_kwargs,
custom=self.custom_dataset,
metadata=self.metadata or {},
)
# ---setup fewshot config--- #
_fewshot_cfg = self.fewshot_config if self.fewshot_config is not None else {}
self.fewshot_cfg = FewshotConfig(
num_fewshot=lambda: self.num_fewshot or _fewshot_cfg["num_fewshot"],
num_fewshot=lambda: self.num_fewshot or _fewshot_cfg.get("num_fewshot", 0),
split=self.fewshot_split,
sampler=_fewshot_cfg.get("sampler", "default"),
samples=_fewshot_cfg.get("samples", None),
@@ -234,8 +215,9 @@
fewshot_indices=_fewshot_cfg.get("fewshot_indices", None),
)
@property
def get_metrics(self) -> list["MetricConfig"]:
def _get_metric(
self, metric_list: Optional[list[dict]] = None
) -> list["MetricConfig"]:
from lm_eval.api.registry import (
AGGREGATION_REGISTRY,
DEFAULT_METRIC_REGISTRY,
@@ -245,8 +227,10 @@
is_higher_better,
)
# if metric_list defined inside a filter, use that; otherwise use the task's metric_list
metric_list = metric_list or self.metric_list
metrics = []
if self.metric_list is None:
if not metric_list:
# ---------- 1. If no metrics defined, use defaults for output type ----------
_metric_list = DEFAULT_METRIC_REGISTRY[self.output_type]
eval_logger.info(
@@ -263,7 +247,7 @@
)
else:
# ---------- 2. Process user-defined metrics from config ----------
for metric_config in self.metric_list:
for metric_config in metric_list:
metric_name = metric_config["metric"]
_metric_fn_kwargs = {
key: metric_config[key]
@@ -324,34 +308,50 @@
hf_evaluate=_hf_evaluate_metric,
)
)
for m in metrics:
if m not in self._metric_list:
self._metric_list.append(m)
return metrics
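The get_metrics property has become _get_metric, which optionally takes a filter-level metric_list and de-duplicates resolved metrics into self._metric_list. A hedged usage sketch (call sites are hypothetical; `cfg` stands for a TaskConfig instance):

# Task-level resolution: uses the task's metric_list, or DEFAULT_METRIC_REGISTRY[output_type]
# when no metrics are defined anywhere.
task_metrics = cfg._get_metric()

# Filter-level resolution: a filter_list entry's own "metric_list" takes precedence
# over the task-wide metric_list.
filter_metrics = cfg._get_metric(metric_list=[{"metric": "exact_match"}])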
@property
def get_filters(self) -> list["FilterEnsemble"]:
def get_filters(self) -> list["FilterConfig"]:
from lm_eval.filters import build_filter_ensemble
if not self.filter_list:
eval_logger.debug(
"No custom filters defined; falling back to 'take_first' for handling repeats."
)
return [build_filter_ensemble("none", [("take_first", None)])]
return [
FilterConfig(
name="none",
ensemble=build_filter_ensemble("none", [("take_first", None)]),
metric_list=self._get_metric(metric_list=None),
)
]
else:
def _strip_fn(d: dict) -> tuple[str, dict]:
return d["function"], {k: v for k, v in d.items() if k != "function"}
return d["function"], {
k: v for k, v in d.items() if k not in ["function", "metric_list"]
}
configs = (
self.filter_list.values()
if isinstance(self.filter_list, dict)
else self.filter_list
)
return [
build_filter_ensemble(
filter_name=cfg["name"],
components=[_strip_fn(f) for f in cfg["filter"]],
x = [
FilterConfig(
name=cfg["name"],
ensemble=build_filter_ensemble(
filter_name=cfg["name"],
components=[_strip_fn(f) for f in cfg["filter"]],
),
metric_list=self._get_metric(metric_list=cfg.get("metric_list")),
)
for cfg in configs
]
return x
@classmethod
def from_yaml(cls, data: dict) -> "TaskConfig":
......
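get_filters now wraps each entry in a FilterConfig carrying both the built FilterEnsemble and its resolved metrics, and a filter entry may declare its own metric_list. A hypothetical filter_list entry (Python dict form of the YAML) of the shape the reworked code consumes; names, pattern, and metric are illustrative, not from this commit:

filter_list = [
    {
        "name": "strict-match",
        "filter": [
            {"function": "regex", "regex_pattern": r"answer is (.*)"},
            {"function": "take_first"},
        ],
        # optional: overrides the task-level metric_list for this filter only
        "metric_list": [{"metric": "exact_match"}],
    }
]

The updated _strip_fn then drops both "function" and "metric_list" before forwarding the remaining keys as filter kwargs.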
@@ -46,7 +46,12 @@ def limit() -> int:
return 10
class BaseTasks:
@pytest.mark.parametrize(
"task_class",
task_class(get_new_tasks_else_default()),
ids=lambda x: f"{x.config.task}",
)
class TestBaseTasks:
"""
Base class for testing tasks
"""
@@ -160,8 +165,50 @@ class BaseTasks:
task_class(get_new_tasks_else_default()),
ids=lambda x: f"{x.config.task}",
)
class TestNewTasksElseDefault(BaseTasks):
class TestNewTasksElseDefault(TestBaseTasks):
"""
Test class parameterized with a list of new/modified tasks
(or a set of default tasks if none have been modified)
"""
@pytest.mark.parametrize(
"task_class",
task_class(
["arc_easy_unitxt"], tasks.TaskManager(include_path="./tests/testconfigs")
),
ids=lambda x: f"{x.config.task}",
)
class TestUnitxtTasks(TestBaseTasks):
"""
Test class for Unitxt tasks parameterized with a small custom
task as described here:
https://www.unitxt.ai/en/latest/docs/lm_eval.html
"""
def test_check_training_docs(self, task_class: ConfigurableTask):
if task_class.has_training_docs():
assert task_class.dataset["train"] is not None
def test_check_validation_docs(self, task_class):
if task_class.has_validation_docs():
assert task_class.dataset["validation"] is not None
def test_check_test_docs(self, task_class):
task = task_class
if task.has_test_docs():
assert task.dataset["test"] is not None
def test_doc_to_text(self, task_class, limit: int):
task = task_class
arr = (
list(islice(task.test_docs(), limit))
if task.has_test_docs()
else list(islice(task.validation_docs(), limit))
)
_array = [task.doc_to_text(doc) for doc in arr]
if not task.multiple_input:
for x in _array:
assert isinstance(x, str)
else:
pass
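test_doc_to_text above caps the number of documents it inspects with itertools.islice, driven by the limit fixture. A standalone sketch of that pattern with toy data (not part of the test suite):

from itertools import islice

docs = ({"question": f"q{i}"} for i in range(1_000_000))  # lazily generated docs (toy data)
first_ten = list(islice(docs, 10))                        # consume only the first 10 items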