Commit 3b4d0af1 authored by Baber
Browse files

refactor: update type hints and improve filter ensemble construction

parent c81c03ee
...@@ -101,7 +101,7 @@ class LM(abc.ABC): ...@@ -101,7 +101,7 @@ class LM(abc.ABC):
# TODO: Add an optional max length # TODO: Add an optional max length
@abc.abstractmethod @abc.abstractmethod
def generate_until(self, requests: list[Instance]) -> list[str]: def generate_until(self, requests: list["Instance"]) -> list[str]:
"""Generate greedily until a stopping sequence """Generate greedily until a stopping sequence
:param requests: list[Instance] :param requests: list[Instance]
...@@ -432,7 +432,7 @@ class TemplateLM(LM): ...@@ -432,7 +432,7 @@ class TemplateLM(LM):
@abc.abstractmethod @abc.abstractmethod
def generate_until( def generate_until(
self, requests, disable_tqdm: bool = False self, requests: list["Instance"], disable_tqdm: bool = False
) -> list[str]: ) -> list[str]:
"""Generate until a stopping sequence. """Generate until a stopping sequence.
......
...@@ -102,9 +102,9 @@ class Task(abc.ABC): ...@@ -102,9 +102,9 @@ class Task(abc.ABC):
self._fewshot_docs: Optional[list] = None self._fewshot_docs: Optional[list] = None
self._instances: Optional[List[Instance]] = None self._instances: Optional[List[Instance]] = None
self._config: TaskConfig = TaskConfig({**config}) if config else TaskConfig() self._config: TaskConfig = TaskConfig.from_yaml({**config})
self._filters = [build_filter_ensemble("none", [["take_first", None]])] self._filters = [build_filter_ensemble("none", [("take_first", None)])]
self.fewshot_rnd: Optional[random.Random] = ( self.fewshot_rnd: Optional[random.Random] = (
None # purposely induce errors in case of improper usage None # purposely induce errors in case of improper usage
) )
...@@ -655,7 +655,7 @@ class ConfigurableTask(Task): ...@@ -655,7 +655,7 @@ class ConfigurableTask(Task):
else: else:
self.prompt = None self.prompt = None
if self.config.fewshot_cfg.num > 0 and self.fewshot_docs() is not None: if self.config.fewshot_cfg.num() > 0 and self.fewshot_docs() is not None:
self.fewshot_rnd = random.Random() self.fewshot_rnd = random.Random()
self.sampler = self.config.fewshot_cfg.init_sampler( self.sampler = self.config.fewshot_cfg.init_sampler(
list(self.fewshot_docs()), self, rnd=self.fewshot_rnd list(self.fewshot_docs()), self, rnd=self.fewshot_rnd
......
...@@ -2,7 +2,6 @@ import logging ...@@ -2,7 +2,6 @@ import logging
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union
from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.instance import OutputType from lm_eval.api.instance import OutputType
from lm_eval.config.metric import MetricConfig from lm_eval.config.metric import MetricConfig
from lm_eval.config.utils import maybe_serialize from lm_eval.config.utils import maybe_serialize
...@@ -10,7 +9,8 @@ from lm_eval.config.utils import maybe_serialize ...@@ -10,7 +9,8 @@ from lm_eval.config.utils import maybe_serialize
if TYPE_CHECKING: if TYPE_CHECKING:
from lm_eval.api.samplers import ContextSampler from lm_eval.api.samplers import ContextSampler
from lm_eval.api.task import Task, eval_logger from lm_eval.api.task import Task
from lm_eval.filters import FilterEnsemble
eval_logger = logging.getLogger(__name__) eval_logger = logging.getLogger(__name__)
...@@ -35,7 +35,9 @@ class FilterConfig: ...@@ -35,7 +35,9 @@ class FilterConfig:
@dataclass @dataclass
class FewshotConfig: class FewshotConfig:
num: int = 0 # hack: this returns task.config.num_fewshot
# to keep in sync as it is runtime-modified
num_fewshot: Callable[[], int]
split: Optional[str] = None split: Optional[str] = None
sampler: Union[str, Callable] = "default" sampler: Union[str, Callable] = "default"
samples: Union[Callable[[], list[dict]], list[dict], None] = None samples: Union[Callable[[], list[dict]], list[dict], None] = None
...@@ -162,10 +164,10 @@ class TaskConfig(dict): ...@@ -162,10 +164,10 @@ class TaskConfig(dict):
fewshot_config: Optional[dict] = None fewshot_config: Optional[dict] = None
# runtime configuration options # runtime configuration options
num_fewshot: Optional[int] = 0 num_fewshot: Optional[int] = 0
generation_kwargs: Optional[dict] = None
# scoring options # scoring options
metric_list: Optional[list] = None metric_list: Optional[list] = None
output_type: OutputType = "generate_until" output_type: OutputType = "generate_until"
generation_kwargs: Optional[dict] = None
repeats: int = 1 repeats: int = 1
filter_list: Optional[list[dict]] = None filter_list: Optional[list[dict]] = None
should_decontaminate: bool = False should_decontaminate: bool = False
...@@ -224,6 +226,7 @@ class TaskConfig(dict): ...@@ -224,6 +226,7 @@ class TaskConfig(dict):
# ---setup fewshot config--- # # ---setup fewshot config--- #
_fewshot_cfg = self.fewshot_config if self.fewshot_config is not None else {} _fewshot_cfg = self.fewshot_config if self.fewshot_config is not None else {}
self.fewshot_cfg = FewshotConfig( self.fewshot_cfg = FewshotConfig(
num_fewshot=lambda: self.num_fewshot or _fewshot_cfg["num_fewshot"],
split=self.fewshot_split, split=self.fewshot_split,
sampler=_fewshot_cfg.get("sampler", "default"), sampler=_fewshot_cfg.get("sampler", "default"),
samples=_fewshot_cfg.get("samples", None), samples=_fewshot_cfg.get("samples", None),
...@@ -331,26 +334,30 @@ class TaskConfig(dict): ...@@ -331,26 +334,30 @@ class TaskConfig(dict):
eval_logger.debug( eval_logger.debug(
"No custom filters defined; falling back to 'take_first' for handling repeats." "No custom filters defined; falling back to 'take_first' for handling repeats."
) )
return [build_filter_ensemble("none", [["take_first", None]])] return [build_filter_ensemble("none", [("take_first", None)])]
else: else:
def _strip_fn(d: dict) -> dict: def _strip_fn(d: dict) -> tuple[str, dict]:
return {k: v for k, v in d.items() if k != "function"} return d["function"], {k: v for k, v in d.items() if k != "function"}
configs = ( configs = (
self.filter_list.values() self.filter_list.values()
if isinstance(self.filter_list, dict) if isinstance(self.filter_list, dict)
else self.filter_list else self.filter_list
) )
return [ return [
build_filter_ensemble( build_filter_ensemble(
filter_name=cfg["name"], filter_name=cfg["name"],
components=[[_strip_fn(f) for f in cfg["filter"]]], components=[_strip_fn(f) for f in cfg["filter"]],
) )
for cfg in configs for cfg in configs
] ]
@classmethod
def from_yaml(cls, data: dict) -> "TaskConfig":
"""Create a TaskConfig instance from a YAML-like dictionary."""
return cls(**data)
def __getitem__(self, item): def __getitem__(self, item):
return getattr(self, item) return getattr(self, item)
......
from functools import partial from functools import partial
from typing import List, Union from typing import Iterable, List, Optional, Union
from lm_eval.api.filter import FilterEnsemble from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.registry import get_filter from lm_eval.api.registry import get_filter
...@@ -8,7 +8,7 @@ from . import custom, extraction, selection, transformation ...@@ -8,7 +8,7 @@ from . import custom, extraction, selection, transformation
def build_filter_ensemble( def build_filter_ensemble(
filter_name: str, components: list[Union[list[dict], list[str]]] filter_name: str, components: list[tuple[str, Optional[dict]]]
) -> FilterEnsemble: ) -> FilterEnsemble:
""" """
Create a filtering pipeline. Create a filtering pipeline.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment