Commit 227f1a74 authored by Baber

refactor: improve dataset and metric handling in TaskConfig

parent 3b4d0af1
@@ -29,6 +29,7 @@ class GroupConfig(dict):
     aggregate_metric_list: Optional[
         Union[List[AggMetricConfig], AggMetricConfig, dict]
     ] = None
+    version: Optional[str] = None
     metadata: Optional[dict] = (
         None  # by default, not used in the code. allows for users to pass arbitrary info to tasks
     )
@@ -48,6 +49,11 @@ class GroupConfig(dict):
                 AggMetricConfig(**item) if isinstance(item, dict) else item
                 for item in self.aggregate_metric_list
             ]
+        self.version = (
+            self.version or self.metadata.get("version", "1.0")
+            if self.metadata
+            else "1.0"
+        )

     def to_dict(self, keep_callable: bool = False) -> dict:
         """dumps the current config as a dictionary object, as a printable format.
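Note on the version fallback added above: an explicitly set version wins, then metadata["version"], then the "1.0" default. A minimal standalone sketch of that resolution order (resolve_version is an illustrative helper, not part of lm_eval):

from typing import Optional


def resolve_version(version: Optional[str], metadata: Optional[dict]) -> str:
    # Mirrors the expression in GroupConfig.__post_init__: an explicit version
    # takes precedence, then metadata["version"], then the "1.0" default.
    if metadata:
        return version or metadata.get("version", "1.0")
    return "1.0"


assert resolve_version(None, None) == "1.0"
assert resolve_version(None, {"version": "2.0"}) == "2.0"
assert resolve_version("3.0", {"version": "2.0"}) == "3.0"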
@@ -639,7 +639,7 @@ class ConfigurableTask(Task):
         if self.config.dataset_name is not None:
             self.DATASET_NAME = self.config.dataset_name

-        self.metric_list: list[MetricConfig] = self.config.get_metrics
+        # self.metric_list: list[MetricConfig] = self.config.get_metrics
         self.download(self.config.dataset_kwargs)
         self._training_docs = None
@@ -655,7 +655,10 @@ class ConfigurableTask(Task):
         else:
             self.prompt = None

-        if self.config.fewshot_cfg.num() > 0 and self.fewshot_docs() is not None:
+        if (
+            self.config.fewshot_cfg.num_fewshot() > 0
+            and self.fewshot_docs() is not None
+        ):
             self.fewshot_rnd = random.Random()
             self.sampler = self.config.fewshot_cfg.init_sampler(
                 list(self.fewshot_docs()), self, rnd=self.fewshot_rnd
@@ -724,21 +727,23 @@ class ConfigurableTask(Task):
     ) -> None:
         from packaging.version import parse as vparse

+        self.config.dataset_kwargs, self.config.metadata = (
+            self.config.dataset_kwargs or {},
+            self.config.metadata or {},
+        )
         if dataset_kwargs and vparse(datasets.__version__) >= vparse("4.0.0"):
             dataset_kwargs.pop("trust_remote_code", None)
-        if isinstance(self.config.custom_dataset, Callable):
+        if isinstance(df := self.config.custom_dataset, Callable):
             eval_logger.warning(
                 f"{self.config.task}: Custom kwargs can be passed to `--metadata` in console (as json string) or to the TaskManager."
                 + "\nFor example --metadata='{\"max_seq_lengths\":[4096, 8192]}'. For details see task Readme."
             )
-            self.dataset = self.config.custom_dataset(
-                **(self.config.metadata or {}), **(self.config.dataset_kwargs or {})
-            )
+            self.dataset = df(**(self.config.dataset_kwargs | self.config.metadata))
         else:
             self.dataset = datasets.load_dataset(
-                path=self.DATASET_PATH,
-                name=self.DATASET_NAME,
-                **dataset_kwargs if dataset_kwargs is not None else {},
+                path=self.config.dataset_path,
+                name=self.config.dataset_name,
+                **self.config.dataset_kwargs,
             )

     def has_training_docs(self) -> bool:
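The custom-dataset branch now merges the two dicts with the | operator before unpacking, so on a key collision the metadata value (right-hand operand) wins; the old double-unpack form raised TypeError on duplicate keyword names instead. A small sketch of that difference using plain Python dicts (the example keys are illustrative):

def show(**kwargs):
    return kwargs


dataset_kwargs = {"split": "test", "max_seq_lengths": [2048]}
metadata = {"max_seq_lengths": [4096, 8192]}

# dict union (Python 3.9+): the right operand wins on key collisions
merged = dataset_kwargs | metadata
assert merged == {"split": "test", "max_seq_lengths": [4096, 8192]}
assert show(**merged) == merged

# the previous double-unpack call raises instead of overriding
try:
    show(**metadata, **dataset_kwargs)
except TypeError as exc:
    print(exc)  # got multiple values for keyword argument 'max_seq_lengths'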
@@ -975,7 +980,7 @@ class ConfigurableTask(Task):
         """Iterates over FilterEnsembles and applies them to instances"""
         if hasattr(self, "_filters"):
             for f in self._filters:
-                f.apply(self._instances)
+                f.ensemble.apply(self._instances)
         else:
             eval_logger.warning("No filter defined, passing through instances")
             return self._instances
@@ -1214,7 +1219,7 @@ class ConfigurableTask(Task):
             arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]

             # TODO: we should raise a warning telling users this will at most ~2x runtime.
-            if "acc_mutual_info" in [m.metric_name for m in self.metric_list]:
+            if "acc_mutual_info" in [m.metric_name for m in self.config._metric_list]:
                 # if we are calculating multiple choice accuracy
                 # using mutual information instead of raw loglikelihood as metric, need unconditional lls.
@@ -1281,7 +1286,7 @@ class ConfigurableTask(Task):
             return self.config.process_results(doc, results)

         result_dict = {}
-        use_metric = list(m.metric_name for m in self.metric_list)
+        use_metric = list(m.metric_name for m in self.config._metric_list)
         if self.OUTPUT_TYPE == "loglikelihood":
             results = results[0]
             ll, is_greedy = results
@@ -1407,7 +1412,7 @@ class ConfigurableTask(Task):
                     # cast gold to the same type as result
                     gold = type(result)(gold)

-            for metric in self.metric_list:
+            for metric in self.config._metric_list:
                 if self.multiple_target:
                     # in the case where we have multiple targets,
                     # return true if any are true
@@ -1470,10 +1475,10 @@ class ConfigurableTask(Task):
         return result_dict

     def aggregation(self) -> dict:
-        return {k.name: k.aggregation_fn for k in self.metric_list}
+        return {k.name: k.aggregation_fn for k in self.config._metric_list}

     def higher_is_better(self) -> dict:
-        return {k.name: k.higher_is_better for k in self.metric_list}
+        return {k.name: k.higher_is_better for k in self.config._metric_list}

     def get_config(self, key: str) -> Any:
         return getattr(self._config, key, None)
@@ -2,6 +2,7 @@ import logging
 from dataclasses import asdict, dataclass, field
 from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union

+from lm_eval.api.filter import FilterEnsemble
 from lm_eval.api.instance import OutputType
 from lm_eval.config.metric import MetricConfig
 from lm_eval.config.utils import maybe_serialize
@@ -10,7 +11,6 @@ from lm_eval.config.utils import maybe_serialize

 if TYPE_CHECKING:
     from lm_eval.api.samplers import ContextSampler
     from lm_eval.api.task import Task
-    from lm_eval.filters import FilterEnsemble

 eval_logger = logging.getLogger(__name__)
@@ -29,8 +29,8 @@ class FilterConfig:
     """Encapsulates information about a single filter."""

     name: str
-    fn: Optional[Callable] = None
-    kwargs: Optional[dict] = field(default_factory=dict)
+    ensemble: FilterEnsemble
+    metric_list: list[MetricConfig]


 @dataclass
@@ -117,21 +117,10 @@ class FewshotConfig:
     )


-@dataclass
-class DatasetConfig:
-    """Encapsulates information about a dataset."""
-
-    path: Optional[str] = None
-    name: Optional[str] = None
-    kwargs: Optional[dict] = field(default_factory=dict)
-    custom: Optional[Callable] = None
-    metadata: Optional[dict] = field(default_factory=dict)
-
-
 @dataclass
 class TaskConfig(dict):
     # task naming/registry
-    task: str
+    task: Optional[str] = None
     task_alias: Optional[str] = None
     tag: Optional[Union[str, list]] = None
     # HF dataset options.
@@ -140,7 +129,7 @@ class TaskConfig(dict):
     custom_dataset: Optional[Callable] = None
     dataset_path: Optional[str] = None
     dataset_name: Optional[str] = None
-    dataset_kwargs: Optional[dict] = None
+    dataset_kwargs: Optional[dict] = field(default_factory=dict)
     training_split: Optional[str] = None
     validation_split: Optional[str] = None
     test_split: Optional[str] = None
@@ -177,9 +166,9 @@ class TaskConfig(dict):
         default_factory=dict
     )  # by default, not used in the code. allows for users to pass arbitrary info to tasks

-    _metric_list: list[MetricConfig] = None
+    _metric_list: list[MetricConfig] = field(default_factory=list)
     _filter_list: list[FilterConfig] = None
-    ds_cfg: DatasetConfig = field(init=False)
+    # ds_cfg: DatasetConfig = field(init=False)
     fewshot_cfg: FewshotConfig = field(init=False)

     def __post_init__(self) -> None:
@@ -215,18 +204,10 @@ class TaskConfig(dict):
             eval_logger.warning(
                 f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}"
             )

-        # ---setup dataset config--- #
-        self.ds_cfg = DatasetConfig(
-            path=self.dataset_path,
-            name=self.dataset_name,
-            kwargs=self.dataset_kwargs,
-            custom=self.custom_dataset,
-            metadata=self.metadata or {},
-        )
         # ---setup fewshot config--- #
         _fewshot_cfg = self.fewshot_config if self.fewshot_config is not None else {}
         self.fewshot_cfg = FewshotConfig(
-            num_fewshot=lambda: self.num_fewshot or _fewshot_cfg["num_fewshot"],
+            num_fewshot=lambda: self.num_fewshot or _fewshot_cfg.get("num_fewshot", 0),
             split=self.fewshot_split,
             sampler=_fewshot_cfg.get("sampler", "default"),
             samples=_fewshot_cfg.get("samples", None),
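The num_fewshot callable now resolves gracefully: an explicit num_fewshot on the task config wins, otherwise the fewshot_config entry, otherwise 0 (the old dict lookup raised KeyError when the key was missing). A standalone sketch of the same precedence (make_num_fewshot is illustrative, not lm_eval code):

from typing import Callable, Optional


def make_num_fewshot(num_fewshot: Optional[int], fewshot_cfg: dict) -> Callable[[], int]:
    # Same shape as the lambda passed to FewshotConfig above: explicit value
    # first, then the fewshot_config entry, then 0.
    # Note: num_fewshot=0 is falsy, so it falls through to the config entry.
    return lambda: num_fewshot or fewshot_cfg.get("num_fewshot", 0)


assert make_num_fewshot(5, {})() == 5
assert make_num_fewshot(None, {"num_fewshot": 3})() == 3
assert make_num_fewshot(None, {})() == 0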
@@ -234,8 +215,9 @@ class TaskConfig(dict):
             fewshot_indices=_fewshot_cfg.get("fewshot_indices", None),
         )

-    @property
-    def get_metrics(self) -> list["MetricConfig"]:
+    def _get_metric(
+        self, metric_list: Optional[list[dict]] = None
+    ) -> list["MetricConfig"]:
         from lm_eval.api.registry import (
             AGGREGATION_REGISTRY,
             DEFAULT_METRIC_REGISTRY,
@@ -245,8 +227,10 @@ class TaskConfig(dict):
             is_higher_better,
         )

+        # if metric_list defined inside a filter, use that; otherwise use the task's metric_list
+        metric_list = metric_list or self.metric_list
         metrics = []
-        if self.metric_list is None:
+        if not metric_list:
             # ---------- 1. If no metrics defined, use defaults for output type ----------
             _metric_list = DEFAULT_METRIC_REGISTRY[self.output_type]
             eval_logger.info(
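_get_metric now resolves its metric source in three steps: the metric_list passed in (e.g. one defined on a filter), else the task-level metric_list, else the defaults registered for the output type. A hedged, standalone sketch of that precedence chain (DEFAULTS and pick_metric_names below are stand-ins, not the real DEFAULT_METRIC_REGISTRY):

from typing import Optional

# stand-in for DEFAULT_METRIC_REGISTRY, keyed by output type
DEFAULTS = {"loglikelihood": ["perplexity", "acc"], "generate_until": ["exact_match"]}


def pick_metric_names(
    output_type: str,
    task_metric_list: Optional[list[dict]],
    filter_metric_list: Optional[list[dict]] = None,
) -> list[str]:
    # 1. the filter-level metric_list wins, 2. then the task-level one,
    # 3. otherwise fall back to the defaults for the output type.
    metric_list = filter_metric_list or task_metric_list
    if not metric_list:
        return list(DEFAULTS[output_type])
    return [m["metric"] for m in metric_list]


assert pick_metric_names("generate_until", None) == ["exact_match"]
assert pick_metric_names("generate_until", [{"metric": "bleu"}]) == ["bleu"]
assert pick_metric_names(
    "generate_until", [{"metric": "bleu"}], filter_metric_list=[{"metric": "acc"}]
) == ["acc"]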
@@ -263,7 +247,7 @@ class TaskConfig(dict):
             )
         else:
             # ---------- 2. Process user-defined metrics from config ----------
-            for metric_config in self.metric_list:
+            for metric_config in metric_list:
                 metric_name = metric_config["metric"]
                 _metric_fn_kwargs = {
                     key: metric_config[key]
@@ -324,34 +308,50 @@ class TaskConfig(dict):
                         hf_evaluate=_hf_evaluate_metric,
                     )
                 )
+        for m in metrics:
+            if m not in self._metric_list:
+                self._metric_list.append(m)
         return metrics

     @property
-    def get_filters(self) -> list["FilterEnsemble"]:
+    def get_filters(self) -> list["FilterConfig"]:
         from lm_eval.filters import build_filter_ensemble

         if not self.filter_list:
             eval_logger.debug(
                 "No custom filters defined; falling back to 'take_first' for handling repeats."
             )
-            return [build_filter_ensemble("none", [("take_first", None)])]
+            return [
+                FilterConfig(
+                    name="none",
+                    ensemble=build_filter_ensemble("none", [("take_first", None)]),
+                    metric_list=self._get_metric(metric_list=None),
+                )
+            ]
         else:

             def _strip_fn(d: dict) -> tuple[str, dict]:
-                return d["function"], {k: v for k, v in d.items() if k != "function"}
+                return d["function"], {
+                    k: v for k, v in d.items() if k not in ["function", "metric_list"]
+                }

             configs = (
                 self.filter_list.values()
                 if isinstance(self.filter_list, dict)
                 else self.filter_list
             )
-            return [
-                build_filter_ensemble(
-                    filter_name=cfg["name"],
-                    components=[_strip_fn(f) for f in cfg["filter"]],
+            x = [
+                FilterConfig(
+                    name=cfg["name"],
+                    ensemble=build_filter_ensemble(
+                        filter_name=cfg["name"],
+                        components=[_strip_fn(f) for f in cfg["filter"]],
+                    ),
+                    metric_list=self._get_metric(metric_list=cfg.get("metric_list")),
                 )
                 for cfg in configs
             ]
+            return x

     @classmethod
     def from_yaml(cls, data: dict) -> "TaskConfig":
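get_filters now returns FilterConfig objects, and each filter entry may carry its own metric_list. A sketch of what a parsed filter_list could look like once loaded from YAML, and how the _strip_fn-style split treats it (key names follow the accesses above: "name", "filter", "function", "metric_list"; the concrete filter and metric names are only examples):

# Parsed filter_list as _get_metric/get_filters would see it (illustrative values).
filter_list = [
    {
        "name": "strict-match",
        "filter": [
            {"function": "regex", "regex_pattern": r"#### (\-?[0-9\.\,]+)"},
            {"function": "take_first"},
        ],
        "metric_list": [{"metric": "exact_match"}],
    },
    {
        "name": "flexible-extract",
        "filter": [{"function": "take_first"}],
        # no metric_list here: _get_metric falls back to the task-level metrics
    },
]

# "function" selects the filter component; everything else except "metric_list"
# is passed through as kwargs to that component.
for cfg in filter_list:
    components = [
        (f["function"], {k: v for k, v in f.items() if k not in ["function", "metric_list"]})
        for f in cfg["filter"]
    ]
    print(cfg["name"], components)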
@@ -46,7 +46,12 @@ def limit() -> int:
     return 10


-class BaseTasks:
+@pytest.mark.parametrize(
+    "task_class",
+    task_class(get_new_tasks_else_default()),
+    ids=lambda x: f"{x.config.task}",
+)
+class TestBaseTasks:
     """
     Base class for testing tasks
     """
@@ -160,8 +165,50 @@ class BaseTasks:
     task_class(get_new_tasks_else_default()),
     ids=lambda x: f"{x.config.task}",
 )
-class TestNewTasksElseDefault(BaseTasks):
+class TestNewTasksElseDefault(TestBaseTasks):
     """
     Test class parameterized with a list of new/modified tasks
     (or a set of default tasks if none have been modified)
     """
+
+
+@pytest.mark.parametrize(
+    "task_class",
+    task_class(
+        ["arc_easy_unitxt"], tasks.TaskManager(include_path="./tests/testconfigs")
+    ),
+    ids=lambda x: f"{x.config.task}",
+)
+class TestUnitxtTasks(TestBaseTasks):
+    """
+    Test class for Unitxt tasks parameterized with a small custom
+    task as described here:
+    https://www.unitxt.ai/en/latest/docs/lm_eval.html
+    """
+
+    def test_check_training_docs(self, task_class: ConfigurableTask):
+        if task_class.has_training_docs():
+            assert task_class.dataset["train"] is not None
+
+    def test_check_validation_docs(self, task_class):
+        if task_class.has_validation_docs():
+            assert task_class.dataset["validation"] is not None
+
+    def test_check_test_docs(self, task_class):
+        task = task_class
+        if task.has_test_docs():
+            assert task.dataset["test"] is not None
+
+    def test_doc_to_text(self, task_class, limit: int):
+        task = task_class
+        arr = (
+            list(islice(task.test_docs(), limit))
+            if task.has_test_docs()
+            else list(islice(task.validation_docs(), limit))
+        )
+        _array = [task.doc_to_text(doc) for doc in arr]
+        if not task.multiple_input:
+            for x in _array:
+                assert isinstance(x, str)
+        else:
+            pass
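Since the parametrize decorator now sits on the test class itself, every test method defined in it runs once per selected task. A minimal self-contained sketch of class-level parametrize in pytest (Thing is a toy stand-in for a task object):

import pytest


class Thing:
    """Toy stand-in for a ConfigurableTask instance."""

    def __init__(self, name: str) -> None:
        self.name = name


# Class-level parametrize: every test method in the class runs once per
# parameter, and ids= gives each generated test a readable name.
@pytest.mark.parametrize("thing", [Thing("alpha"), Thing("beta")], ids=lambda t: t.name)
class TestThings:
    def test_has_name(self, thing: Thing) -> None:
        assert isinstance(thing.name, str)

    def test_name_not_empty(self, thing: Thing) -> None:
        assert thing.name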