Commit 3b4d0af1 authored by Baber
Browse files

refactor: update type hints and improve filter ensemble construction

parent c81c03ee
......@@ -101,7 +101,7 @@ class LM(abc.ABC):
# TODO: Add an optional max length
@abc.abstractmethod
def generate_until(self, requests: list[Instance]) -> list[str]:
def generate_until(self, requests: list["Instance"]) -> list[str]:
"""Generate greedily until a stopping sequence
:param requests: list[Instance]
......@@ -432,7 +432,7 @@ class TemplateLM(LM):
@abc.abstractmethod
def generate_until(
self, requests, disable_tqdm: bool = False
self, requests: list["Instance"], disable_tqdm: bool = False
) -> list[str]:
"""Generate until a stopping sequence.
......
......@@ -102,9 +102,9 @@ class Task(abc.ABC):
self._fewshot_docs: Optional[list] = None
self._instances: Optional[List[Instance]] = None
self._config: TaskConfig = TaskConfig({**config}) if config else TaskConfig()
self._config: TaskConfig = TaskConfig.from_yaml({**config})
self._filters = [build_filter_ensemble("none", [["take_first", None]])]
self._filters = [build_filter_ensemble("none", [("take_first", None)])]
self.fewshot_rnd: Optional[random.Random] = (
None # purposely induce errors in case of improper usage
)
......@@ -655,7 +655,7 @@ class ConfigurableTask(Task):
else:
self.prompt = None
if self.config.fewshot_cfg.num > 0 and self.fewshot_docs() is not None:
if self.config.fewshot_cfg.num() > 0 and self.fewshot_docs() is not None:
self.fewshot_rnd = random.Random()
self.sampler = self.config.fewshot_cfg.init_sampler(
list(self.fewshot_docs()), self, rnd=self.fewshot_rnd
......
......@@ -2,7 +2,6 @@ import logging
from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union
from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.instance import OutputType
from lm_eval.config.metric import MetricConfig
from lm_eval.config.utils import maybe_serialize
......@@ -10,7 +9,8 @@ from lm_eval.config.utils import maybe_serialize
if TYPE_CHECKING:
from lm_eval.api.samplers import ContextSampler
from lm_eval.api.task import Task, eval_logger
from lm_eval.api.task import Task
from lm_eval.filters import FilterEnsemble
eval_logger = logging.getLogger(__name__)
......@@ -35,7 +35,9 @@ class FilterConfig:
@dataclass
class FewshotConfig:
num: int = 0
# hack: this returns task.config.num_fewshot
# to keep in sync as it is runtime-modified
num_fewshot: Callable[[], int]
split: Optional[str] = None
sampler: Union[str, Callable] = "default"
samples: Union[Callable[[], list[dict]], list[dict], None] = None
......@@ -162,10 +164,10 @@ class TaskConfig(dict):
fewshot_config: Optional[dict] = None
# runtime configuration options
num_fewshot: Optional[int] = 0
generation_kwargs: Optional[dict] = None
# scoring options
metric_list: Optional[list] = None
output_type: OutputType = "generate_until"
generation_kwargs: Optional[dict] = None
repeats: int = 1
filter_list: Optional[list[dict]] = None
should_decontaminate: bool = False
......@@ -224,6 +226,7 @@ class TaskConfig(dict):
# ---setup fewshot config--- #
_fewshot_cfg = self.fewshot_config if self.fewshot_config is not None else {}
self.fewshot_cfg = FewshotConfig(
num_fewshot=lambda: self.num_fewshot or _fewshot_cfg["num_fewshot"],
split=self.fewshot_split,
sampler=_fewshot_cfg.get("sampler", "default"),
samples=_fewshot_cfg.get("samples", None),
......@@ -331,26 +334,30 @@ class TaskConfig(dict):
eval_logger.debug(
"No custom filters defined; falling back to 'take_first' for handling repeats."
)
return [build_filter_ensemble("none", [["take_first", None]])]
return [build_filter_ensemble("none", [("take_first", None)])]
else:
def _strip_fn(d: dict) -> dict:
return {k: v for k, v in d.items() if k != "function"}
def _strip_fn(d: dict) -> tuple[str, dict]:
return d["function"], {k: v for k, v in d.items() if k != "function"}
configs = (
self.filter_list.values()
if isinstance(self.filter_list, dict)
else self.filter_list
)
return [
build_filter_ensemble(
filter_name=cfg["name"],
components=[[_strip_fn(f) for f in cfg["filter"]]],
components=[_strip_fn(f) for f in cfg["filter"]],
)
for cfg in configs
]
@classmethod
def from_yaml(cls, data: dict) -> "TaskConfig":
"""Create a TaskConfig instance from a YAML-like dictionary.

Every key/value pair in ``data`` is forwarded to the constructor as a
keyword argument, so the keys must be strings.

:param data: dict
    Mapping of configuration field names to their values.
:return: the object produced by calling ``cls`` with those keyword arguments.
"""
# NOTE(review): ``data`` is unpacked directly, so passing ``None`` raises
# TypeError here — confirm callers always supply a dict.
return cls(**data)
def __getitem__(self, item):
"""Support subscript access by delegating to attribute lookup.

``obj[item]`` returns ``getattr(obj, item)``.

NOTE(review): no default is passed to ``getattr``, so a missing name
presumably surfaces as ``AttributeError`` rather than ``KeyError`` —
confirm that callers using subscript syntax expect this.
"""
return getattr(self, item)
......
from functools import partial
from typing import List, Union
from typing import Iterable, List, Optional, Union
from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.registry import get_filter
......@@ -8,7 +8,7 @@ from . import custom, extraction, selection, transformation
def build_filter_ensemble(
filter_name: str, components: list[Union[list[dict], list[str]]]
filter_name: str, components: list[tuple[str, Optional[dict]]]
) -> FilterEnsemble:
"""
Create a filtering pipeline.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment