Commit ec767666 authored by Baber
Browse files

overload Task methods if callable in yaml dict

parent 2009ec4b
from abc import ABC, abstractmethod
from collections.abc import Iterable from collections.abc import Iterable
from dataclasses import dataclass from dataclasses import dataclass
from typing import Protocol, runtime_checkable
from lm_eval.api.instance import Instance from lm_eval.api.instance import Instance
class Filter(ABC): @runtime_checkable
class Filter(Protocol):
""" """
Filter classes operate on a per-task level. Filter classes operate on a per-task level.
They take all model outputs (`instance.resps` for all `task.instances`) They take all model outputs (`instance.resps` for all `task.instances`)
...@@ -19,7 +20,6 @@ class Filter(ABC): ...@@ -19,7 +20,6 @@ class Filter(ABC):
Can define custom behavior here, if an individual instantiation of a Filter class should have state. Can define custom behavior here, if an individual instantiation of a Filter class should have state.
""" """
@abstractmethod
def apply( def apply(
self, resps: Iterable[list[str]], docs: Iterable[dict] self, resps: Iterable[list[str]], docs: Iterable[dict]
) -> Iterable[list[str]]: ) -> Iterable[list[str]]:
......
...@@ -7,6 +7,8 @@ import random ...@@ -7,6 +7,8 @@ import random
import re import re
from collections.abc import Callable from collections.abc import Callable
from copy import deepcopy from copy import deepcopy
from functools import cached_property
from types import MethodType
from typing import TYPE_CHECKING, Any, Literal, overload from typing import TYPE_CHECKING, Any, Literal, overload
import datasets import datasets
...@@ -143,14 +145,17 @@ class Task(abc.ABC): ...@@ -143,14 +145,17 @@ class Task(abc.ABC):
"""Returns the TaskConfig associated with this class.""" """Returns the TaskConfig associated with this class."""
return self._config return self._config
@property
def has_training_docs(self) -> bool: def has_training_docs(self) -> bool:
"""Whether the task has a training set""" """Whether the task has a training set"""
raise NotImplementedError raise NotImplementedError
@property
def has_validation_docs(self) -> bool: def has_validation_docs(self) -> bool:
"""Whether the task has a validation set""" """Whether the task has a validation set"""
raise NotImplementedError raise NotImplementedError
@property
def has_test_docs(self) -> bool: def has_test_docs(self) -> bool:
"""Whether the task has a test set""" """Whether the task has a test set"""
raise NotImplementedError raise NotImplementedError
...@@ -181,9 +186,9 @@ class Task(abc.ABC): ...@@ -181,9 +186,9 @@ class Task(abc.ABC):
:return: Iterable[obj] :return: Iterable[obj]
A iterable of any object, that doc_to_text can handle A iterable of any object, that doc_to_text can handle
""" """
if self.has_training_docs(): if self.has_training_docs:
return self.training_docs() return self.training_docs()
elif self.has_validation_docs(): elif self.has_validation_docs:
return self.validation_docs() return self.validation_docs()
else: else:
if self.config.num_fewshot and self.config.num_fewshot > 0: if self.config.num_fewshot and self.config.num_fewshot > 0:
...@@ -211,7 +216,7 @@ class Task(abc.ABC): ...@@ -211,7 +216,7 @@ class Task(abc.ABC):
""" """
return self._instances return self._instances
def fewshot_examples(self, k, rnd) -> Iterable[dict]: def fewshot_examples(self, k: int, rnd) -> Iterable[dict]:
if self._training_docs is None: if self._training_docs is None:
self._training_docs = list(self.training_docs()) self._training_docs = list(self.training_docs())
...@@ -449,13 +454,13 @@ class Task(abc.ABC): ...@@ -449,13 +454,13 @@ class Task(abc.ABC):
labeled_examples = "" labeled_examples = ""
else: else:
# for sets with no training docs, draw from other set *but ensure no overlap with current doc* # for sets with no training docs, draw from other set *but ensure no overlap with current doc*
if self.has_training_docs(): if self.has_training_docs:
fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd) fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
else: else:
if self._fewshot_docs is None: if self._fewshot_docs is None:
self._fewshot_docs = list( self._fewshot_docs = list(
self.validation_docs() self.validation_docs()
if self.has_validation_docs() if self.has_validation_docs
else self.test_docs() else self.test_docs()
) )
...@@ -528,9 +533,9 @@ class Task(abc.ABC): ...@@ -528,9 +533,9 @@ class Task(abc.ABC):
@property @property
def eval_docs(self) -> datasets.Dataset | Iterable[dict]: def eval_docs(self) -> datasets.Dataset | Iterable[dict]:
if self.has_test_docs(): if self.has_test_docs:
return self.test_docs() return self.test_docs()
elif self.has_validation_docs(): elif self.has_validation_docs:
return self.validation_docs() return self.validation_docs()
else: else:
raise ValueError( raise ValueError(
...@@ -587,7 +592,7 @@ class ConfigurableTask(Task): ...@@ -587,7 +592,7 @@ class ConfigurableTask(Task):
# Use new configurations if there was no preconfiguration # Use new configurations if there was no preconfiguration
if self.config is None: if self.config is None:
self._config = TaskConfig(**config) self._config = TaskConfig.from_yaml(config)
# Overwrite configs # Overwrite configs
else: else:
if config is not None: if config is not None:
...@@ -730,17 +735,20 @@ class ConfigurableTask(Task): ...@@ -730,17 +735,20 @@ class ConfigurableTask(Task):
**self.config.dataset_kwargs, **self.config.dataset_kwargs,
) )
@cached_property
def has_training_docs(self) -> bool: def has_training_docs(self) -> bool:
return self.config.training_split is not None return self.config.training_split is not None
@cached_property
def has_validation_docs(self) -> bool: def has_validation_docs(self) -> bool:
return self.config.validation_split is not None return self.config.validation_split is not None
@cached_property
def has_test_docs(self) -> bool: def has_test_docs(self) -> bool:
return self.config.test_split is not None return self.config.test_split is not None
def training_docs(self) -> DataSet | None: def training_docs(self) -> DataSet | None:
if self.has_training_docs(): if self.has_training_docs:
if self.config.process_docs is not None: if self.config.process_docs is not None:
return self.config.process_docs( return self.config.process_docs(
self.dataset[self.config.training_split] self.dataset[self.config.training_split]
...@@ -748,7 +756,7 @@ class ConfigurableTask(Task): ...@@ -748,7 +756,7 @@ class ConfigurableTask(Task):
return self.dataset[self.config.training_split] return self.dataset[self.config.training_split]
def validation_docs(self) -> DataSet | None: def validation_docs(self) -> DataSet | None:
if self.has_validation_docs(): if self.has_validation_docs:
if self.config.process_docs is not None: if self.config.process_docs is not None:
return self.config.process_docs( return self.config.process_docs(
self.dataset[self.config.validation_split] self.dataset[self.config.validation_split]
...@@ -756,7 +764,7 @@ class ConfigurableTask(Task): ...@@ -756,7 +764,7 @@ class ConfigurableTask(Task):
return self.dataset[self.config.validation_split] return self.dataset[self.config.validation_split]
def test_docs(self) -> DataSet | None: def test_docs(self) -> DataSet | None:
if self.has_test_docs(): if self.has_test_docs:
if self.config.process_docs is not None: if self.config.process_docs is not None:
return self.config.process_docs(self.dataset[self.config.test_split]) return self.config.process_docs(self.dataset[self.config.test_split])
return self.dataset[self.config.test_split] return self.dataset[self.config.test_split]
...@@ -1011,23 +1019,16 @@ class ConfigurableTask(Task): ...@@ -1011,23 +1019,16 @@ class ConfigurableTask(Task):
# if self.prompt is not None: # if self.prompt is not None:
# doc_to_text = self.prompt # doc_to_text = self.prompt
doc_to_text = doc_to_text or self.config.doc_to_text doc_to_text = doc_to_text or self.config.doc_to_text
if doc_to_text in doc:
if isinstance(doc_to_text, int): return doc[doc_to_text]
return doc_to_text
elif isinstance(doc_to_text, str): elif isinstance(doc_to_text, str):
if doc_to_text in self.features: text_string = utils.apply_template(doc_to_text, doc)
# if self.config.doc_to_choice is not None: if text_string.isdigit() and self.config.doc_to_choice is not None:
# return self.doc_to_choice(doc)[doc[doc_to_text]] return ast.literal_eval(text_string)
# else:
return doc[doc_to_text]
else: else:
text_string = utils.apply_template(doc_to_text, doc) return text_string
if text_string.isdigit() and self.config.doc_to_choice is not None: elif isinstance(doc_to_text, int):
return ast.literal_eval(text_string) return doc_to_text
else:
return text_string
elif callable(doc_to_text):
return doc_to_text(doc)
# Used when applying a Promptsource template # Used when applying a Promptsource template
# elif hasattr(doc_to_text, "apply"): # elif hasattr(doc_to_text, "apply"):
# applied_prompt = doc_to_text.apply(doc) # applied_prompt = doc_to_text.apply(doc)
...@@ -1062,38 +1063,31 @@ class ConfigurableTask(Task): ...@@ -1062,38 +1063,31 @@ class ConfigurableTask(Task):
def doc_to_target(self, doc: dict, doc_to_target=None) -> int | str | list[int]: def doc_to_target(self, doc: dict, doc_to_target=None) -> int | str | list[int]:
# if self.prompt is not None: # if self.prompt is not None:
# doc_to_target = self.prompt # doc_to_target = self.prompt
if doc_to_target is not None: doc_to_target = doc_to_target or self.config.doc_to_target
doc_to_target = doc_to_target if doc_to_target in doc:
else: return doc[doc_to_target]
doc_to_target = self.config.doc_to_target
if isinstance(doc_to_target, int):
return doc_to_target
elif isinstance(doc_to_target, str): elif isinstance(doc_to_target, str):
if doc_to_target in self.features: target_string = utils.apply_template(doc_to_target, doc)
# if self.config.doc_to_choice is not None: if target_string.isdigit() and self.config.doc_to_choice is not None:
# return self.doc_to_choice(doc)[doc[doc_to_target]] return ast.literal_eval(target_string)
# else: # elif (
return doc[doc_to_target] # len(target_string) >= 2
# and (target_string[0] == "[")
# and (target_string[-1] == "]")
# ):
# try:
# return ast.literal_eval(target_string)
# except (SyntaxError, ValueError):
# return target_string
else: else:
target_string = utils.apply_template(doc_to_target, doc) return target_string
if target_string.isdigit() and self.config.doc_to_choice is not None:
return ast.literal_eval(target_string) elif isinstance(doc_to_target, (int, list)):
elif (
len(target_string) >= 2
and (target_string[0] == "[")
and (target_string[-1] == "]")
):
try:
return ast.literal_eval(target_string)
except (SyntaxError, ValueError):
return target_string
else:
return target_string
elif isinstance(doc_to_target, list):
return doc_to_target return doc_to_target
elif callable(doc_to_target): # elif isinstance(doc_to_target, list):
return doc_to_target(doc) # return doc_to_target
# elif callable(doc_to_target):
# return doc_to_target(doc)
# # Used when applying a Promptsource template # # Used when applying a Promptsource template
# elif hasattr(doc_to_target, "apply"): # elif hasattr(doc_to_target, "apply"):
# applied_prompt = doc_to_target.apply(doc) # applied_prompt = doc_to_target.apply(doc)
...@@ -1138,16 +1132,14 @@ class ConfigurableTask(Task): ...@@ -1138,16 +1132,14 @@ class ConfigurableTask(Task):
doc_to_choice = self.config.doc_to_choice doc_to_choice = self.config.doc_to_choice
if isinstance(doc_to_choice, str): if isinstance(doc_to_choice, str):
if doc_to_choice in self.features: if doc_to_choice in doc:
return doc[doc_to_choice] return doc[doc_to_choice]
else: else:
return ast.literal_eval(utils.apply_template(doc_to_choice, doc)) return ast.literal_eval(utils.apply_template(doc_to_choice, doc))
elif isinstance(doc_to_choice, list): elif isinstance(doc_to_choice, list):
return doc_to_choice return doc_to_choice
elif isinstance(doc_to_choice, dict): # elif isinstance(doc_to_choice, dict):
return list(doc_to_choice.values()) # return list(doc_to_choice.values())
elif callable(doc_to_choice):
return doc_to_choice(doc)
# elif hasattr(doc_to_choice, "get_answer_choices_list"): # elif hasattr(doc_to_choice, "get_answer_choices_list"):
# return doc_to_choice.get_answer_choices_list(doc) # return doc_to_choice.get_answer_choices_list(doc)
else: else:
...@@ -1225,7 +1217,7 @@ class ConfigurableTask(Task): ...@@ -1225,7 +1217,7 @@ class ConfigurableTask(Task):
def doc_to_prefix(self, doc: dict) -> str | None: def doc_to_prefix(self, doc: dict) -> str | None:
if (gen_prefix := self.config.gen_prefix) is not None: if (gen_prefix := self.config.gen_prefix) is not None:
if gen_prefix in self.features: if gen_prefix in doc:
return doc[gen_prefix] return doc[gen_prefix]
else: else:
return utils.apply_template(gen_prefix, doc) return utils.apply_template(gen_prefix, doc)
...@@ -1333,9 +1325,6 @@ class ConfigurableTask(Task): ...@@ -1333,9 +1325,6 @@ class ConfigurableTask(Task):
) )
def process_results(self, doc: dict, results: list) -> dict[str, Any]: def process_results(self, doc: dict, results: list) -> dict[str, Any]:
if callable(self.config.process_results):
return self.config.process_results(doc, results)
result_dict = {} result_dict = {}
use_metric = list(m.metric_name for m in self.config._metric_list) use_metric = list(m.metric_name for m in self.config._metric_list)
if self.OUTPUT_TYPE == "loglikelihood": if self.OUTPUT_TYPE == "loglikelihood":
......
...@@ -10,7 +10,7 @@ import datasets ...@@ -10,7 +10,7 @@ import datasets
from lm_eval.api.filter import FilterEnsemble from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.instance import OutputType from lm_eval.api.instance import OutputType
from lm_eval.config.metric import MetricConfig from lm_eval.config.metric import MetricConfig
from lm_eval.config.utils import maybe_serialize from lm_eval.config.utils import doc_to_closure, maybe_serialize
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -179,6 +179,7 @@ class TaskConfig: ...@@ -179,6 +179,7 @@ class TaskConfig:
_filter_list: list[FilterConfig] = field(default_factory=list) _filter_list: list[FilterConfig] = field(default_factory=list)
# ds_cfg: DatasetConfig = field(init=False) # ds_cfg: DatasetConfig = field(init=False)
fewshot_cfg: FewshotConfig = field(init=False) fewshot_cfg: FewshotConfig = field(init=False)
_fn: dict[str, Callable] = field(default_factory=dict)
def __post_init__(self) -> None: def __post_init__(self) -> None:
### ---setup generation kwargs--- ### ### ---setup generation kwargs--- ###
...@@ -363,7 +364,8 @@ class TaskConfig: ...@@ -363,7 +364,8 @@ class TaskConfig:
@classmethod @classmethod
def from_yaml(cls, data: dict[str, Any]) -> TaskConfig: def from_yaml(cls, data: dict[str, Any]) -> TaskConfig:
"""Create a TaskConfig instance from a YAML-like dictionary.""" """Create a TaskConfig instance from a YAML-like dictionary."""
return cls(**data) fn = {k: doc_to_closure(v) for k, v in data.items() if callable(v)}
return cls(**data, _fn=fn)
@classmethod @classmethod
def from_template(cls, template: TemplateConfig, **kwargs) -> TaskConfig: def from_template(cls, template: TemplateConfig, **kwargs) -> TaskConfig:
......
from __future__ import annotations from __future__ import annotations
from functools import wraps
from inspect import getsource from inspect import getsource
from typing import Any, Callable from typing import Any, Callable, TypeVar
T = TypeVar("T")
def serialize_callable( def serialize_callable(
value: Callable[..., Any] | str, keep_callable=False value: Callable[..., T] | str, keep_callable=False
) -> Callable[..., Any] | str: ) -> Callable[..., T] | str:
"""Serializes a given function or string. """Serializes a given function or string.
If 'keep_callable' is True, the original callable is returned. If 'keep_callable' is True, the original callable is returned.
...@@ -22,7 +26,9 @@ def serialize_callable( ...@@ -22,7 +26,9 @@ def serialize_callable(
return str(value) return str(value)
def maybe_serialize(val: Callable | Any, keep_callable=False) -> Callable | Any: def maybe_serialize(
val: Callable[..., T] | Any, keep_callable=False
) -> Callable[..., T] | Any:
"""Conditionally serializes a value if it is callable.""" """Conditionally serializes a value if it is callable."""
return ( return (
...@@ -41,3 +47,13 @@ def create_mc_choices(choices: list[str], choice_delimiter: str | None = "\n") - ...@@ -41,3 +47,13 @@ def create_mc_choices(choices: list[str], choice_delimiter: str | None = "\n") -
formatted_choices = [f"{chr(65 + i)}. {choice}" for i, choice in enumerate(choices)] formatted_choices = [f"{chr(65 + i)}. {choice}" for i, choice in enumerate(choices)]
return choice_delimiter.join(formatted_choices) return choice_delimiter.join(formatted_choices)
def doc_to_closure(fn: Callable[..., T]) -> Callable[..., T]:
    """Adapt ``fn`` so it can be invoked with a leading ``self`` argument.

    The returned wrapper takes an instance as its first argument, ignores it,
    and forwards every remaining positional and keyword argument to ``fn``
    unchanged. The metadata of ``fn`` is copied onto the wrapper through
    ``functools.wraps``.
    """

    def _drop_self(self: Any, *args, **kwargs):
        # ``self`` is accepted only so the wrapper fits a method call; it is
        # never passed on to ``fn``.
        return fn(*args, **kwargs)

    return wraps(fn)(_drop_self)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment