Commit ec767666 authored by Baber's avatar Baber
Browse files

overload Task methods if callable in yaml dict

parent 2009ec4b
from abc import ABC, abstractmethod
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Protocol, runtime_checkable
from lm_eval.api.instance import Instance
class Filter(ABC):
@runtime_checkable
class Filter(Protocol):
"""
Filter classes operate on a per-task level.
They take all model outputs (`instance.resps` for all `task.instances`)
......@@ -19,7 +20,6 @@ class Filter(ABC):
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
"""
@abstractmethod
def apply(
self, resps: Iterable[list[str]], docs: Iterable[dict]
) -> Iterable[list[str]]:
......
......@@ -7,6 +7,8 @@ import random
import re
from collections.abc import Callable
from copy import deepcopy
from functools import cached_property
from types import MethodType
from typing import TYPE_CHECKING, Any, Literal, overload
import datasets
......@@ -143,14 +145,17 @@ class Task(abc.ABC):
"""Returns the TaskConfig associated with this class."""
return self._config
@property
def has_training_docs(self) -> bool:
"""Whether the task has a training set"""
raise NotImplementedError
@property
def has_validation_docs(self) -> bool:
"""Whether the task has a validation set"""
raise NotImplementedError
@property
def has_test_docs(self) -> bool:
"""Whether the task has a test set"""
raise NotImplementedError
......@@ -181,9 +186,9 @@ class Task(abc.ABC):
:return: Iterable[obj]
A iterable of any object, that doc_to_text can handle
"""
if self.has_training_docs():
if self.has_training_docs:
return self.training_docs()
elif self.has_validation_docs():
elif self.has_validation_docs:
return self.validation_docs()
else:
if self.config.num_fewshot and self.config.num_fewshot > 0:
......@@ -211,7 +216,7 @@ class Task(abc.ABC):
"""
return self._instances
def fewshot_examples(self, k, rnd) -> Iterable[dict]:
def fewshot_examples(self, k: int, rnd) -> Iterable[dict]:
if self._training_docs is None:
self._training_docs = list(self.training_docs())
......@@ -449,13 +454,13 @@ class Task(abc.ABC):
labeled_examples = ""
else:
# for sets with no training docs, draw from other set *but ensure no overlap with current doc*
if self.has_training_docs():
if self.has_training_docs:
fewshotex = self.fewshot_examples(k=num_fewshot, rnd=rnd)
else:
if self._fewshot_docs is None:
self._fewshot_docs = list(
self.validation_docs()
if self.has_validation_docs()
if self.has_validation_docs
else self.test_docs()
)
......@@ -528,9 +533,9 @@ class Task(abc.ABC):
@property
def eval_docs(self) -> datasets.Dataset | Iterable[dict]:
if self.has_test_docs():
if self.has_test_docs:
return self.test_docs()
elif self.has_validation_docs():
elif self.has_validation_docs:
return self.validation_docs()
else:
raise ValueError(
......@@ -587,7 +592,7 @@ class ConfigurableTask(Task):
# Use new configurations if there was no preconfiguration
if self.config is None:
self._config = TaskConfig(**config)
self._config = TaskConfig.from_yaml(config)
# Overwrite configs
else:
if config is not None:
......@@ -730,17 +735,20 @@ class ConfigurableTask(Task):
**self.config.dataset_kwargs,
)
@cached_property
def has_training_docs(self) -> bool:
return self.config.training_split is not None
@cached_property
def has_validation_docs(self) -> bool:
return self.config.validation_split is not None
@cached_property
def has_test_docs(self) -> bool:
return self.config.test_split is not None
def training_docs(self) -> DataSet | None:
if self.has_training_docs():
if self.has_training_docs:
if self.config.process_docs is not None:
return self.config.process_docs(
self.dataset[self.config.training_split]
......@@ -748,7 +756,7 @@ class ConfigurableTask(Task):
return self.dataset[self.config.training_split]
def validation_docs(self) -> DataSet | None:
if self.has_validation_docs():
if self.has_validation_docs:
if self.config.process_docs is not None:
return self.config.process_docs(
self.dataset[self.config.validation_split]
......@@ -756,7 +764,7 @@ class ConfigurableTask(Task):
return self.dataset[self.config.validation_split]
def test_docs(self) -> DataSet | None:
if self.has_test_docs():
if self.has_test_docs:
if self.config.process_docs is not None:
return self.config.process_docs(self.dataset[self.config.test_split])
return self.dataset[self.config.test_split]
......@@ -1011,23 +1019,16 @@ class ConfigurableTask(Task):
# if self.prompt is not None:
# doc_to_text = self.prompt
doc_to_text = doc_to_text or self.config.doc_to_text
if isinstance(doc_to_text, int):
return doc_to_text
if doc_to_text in doc:
return doc[doc_to_text]
elif isinstance(doc_to_text, str):
if doc_to_text in self.features:
# if self.config.doc_to_choice is not None:
# return self.doc_to_choice(doc)[doc[doc_to_text]]
# else:
return doc[doc_to_text]
text_string = utils.apply_template(doc_to_text, doc)
if text_string.isdigit() and self.config.doc_to_choice is not None:
return ast.literal_eval(text_string)
else:
text_string = utils.apply_template(doc_to_text, doc)
if text_string.isdigit() and self.config.doc_to_choice is not None:
return ast.literal_eval(text_string)
else:
return text_string
elif callable(doc_to_text):
return doc_to_text(doc)
return text_string
elif isinstance(doc_to_text, int):
return doc_to_text
# Used when applying a Promptsource template
# elif hasattr(doc_to_text, "apply"):
# applied_prompt = doc_to_text.apply(doc)
......@@ -1062,38 +1063,31 @@ class ConfigurableTask(Task):
def doc_to_target(self, doc: dict, doc_to_target=None) -> int | str | list[int]:
# if self.prompt is not None:
# doc_to_target = self.prompt
if doc_to_target is not None:
doc_to_target = doc_to_target
else:
doc_to_target = self.config.doc_to_target
if isinstance(doc_to_target, int):
return doc_to_target
doc_to_target = doc_to_target or self.config.doc_to_target
if doc_to_target in doc:
return doc[doc_to_target]
elif isinstance(doc_to_target, str):
if doc_to_target in self.features:
# if self.config.doc_to_choice is not None:
# return self.doc_to_choice(doc)[doc[doc_to_target]]
# else:
return doc[doc_to_target]
target_string = utils.apply_template(doc_to_target, doc)
if target_string.isdigit() and self.config.doc_to_choice is not None:
return ast.literal_eval(target_string)
# elif (
# len(target_string) >= 2
# and (target_string[0] == "[")
# and (target_string[-1] == "]")
# ):
# try:
# return ast.literal_eval(target_string)
# except (SyntaxError, ValueError):
# return target_string
else:
target_string = utils.apply_template(doc_to_target, doc)
if target_string.isdigit() and self.config.doc_to_choice is not None:
return ast.literal_eval(target_string)
elif (
len(target_string) >= 2
and (target_string[0] == "[")
and (target_string[-1] == "]")
):
try:
return ast.literal_eval(target_string)
except (SyntaxError, ValueError):
return target_string
else:
return target_string
elif isinstance(doc_to_target, list):
return target_string
elif isinstance(doc_to_target, (int, list)):
return doc_to_target
elif callable(doc_to_target):
return doc_to_target(doc)
# elif isinstance(doc_to_target, list):
# return doc_to_target
# elif callable(doc_to_target):
# return doc_to_target(doc)
# # Used when applying a Promptsource template
# elif hasattr(doc_to_target, "apply"):
# applied_prompt = doc_to_target.apply(doc)
......@@ -1138,16 +1132,14 @@ class ConfigurableTask(Task):
doc_to_choice = self.config.doc_to_choice
if isinstance(doc_to_choice, str):
if doc_to_choice in self.features:
if doc_to_choice in doc:
return doc[doc_to_choice]
else:
return ast.literal_eval(utils.apply_template(doc_to_choice, doc))
elif isinstance(doc_to_choice, list):
return doc_to_choice
elif isinstance(doc_to_choice, dict):
return list(doc_to_choice.values())
elif callable(doc_to_choice):
return doc_to_choice(doc)
# elif isinstance(doc_to_choice, dict):
# return list(doc_to_choice.values())
# elif hasattr(doc_to_choice, "get_answer_choices_list"):
# return doc_to_choice.get_answer_choices_list(doc)
else:
......@@ -1225,7 +1217,7 @@ class ConfigurableTask(Task):
def doc_to_prefix(self, doc: dict) -> str | None:
if (gen_prefix := self.config.gen_prefix) is not None:
if gen_prefix in self.features:
if gen_prefix in doc:
return doc[gen_prefix]
else:
return utils.apply_template(gen_prefix, doc)
......@@ -1333,9 +1325,6 @@ class ConfigurableTask(Task):
)
def process_results(self, doc: dict, results: list) -> dict[str, Any]:
if callable(self.config.process_results):
return self.config.process_results(doc, results)
result_dict = {}
use_metric = list(m.metric_name for m in self.config._metric_list)
if self.OUTPUT_TYPE == "loglikelihood":
......
......@@ -10,7 +10,7 @@ import datasets
from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.instance import OutputType
from lm_eval.config.metric import MetricConfig
from lm_eval.config.utils import maybe_serialize
from lm_eval.config.utils import doc_to_closure, maybe_serialize
if TYPE_CHECKING:
......@@ -179,6 +179,7 @@ class TaskConfig:
_filter_list: list[FilterConfig] = field(default_factory=list)
# ds_cfg: DatasetConfig = field(init=False)
fewshot_cfg: FewshotConfig = field(init=False)
_fn: dict[str, Callable] = field(default_factory=dict)
def __post_init__(self) -> None:
### ---setup generation kwargs--- ###
......@@ -363,7 +364,8 @@ class TaskConfig:
@classmethod
def from_yaml(cls, data: dict[str, Any]) -> TaskConfig:
"""Create a TaskConfig instance from a YAML-like dictionary."""
return cls(**data)
fn = {k: doc_to_closure(v) for k, v in data.items() if callable(v)}
return cls(**data, _fn=fn)
@classmethod
def from_template(cls, template: TemplateConfig, **kwargs) -> TaskConfig:
......
from __future__ import annotations
from functools import wraps
from inspect import getsource
from typing import Any, Callable
from typing import Any, Callable, TypeVar
T = TypeVar("T")
def serialize_callable(
value: Callable[..., Any] | str, keep_callable=False
) -> Callable[..., Any] | str:
value: Callable[..., T] | str, keep_callable=False
) -> Callable[..., T] | str:
"""Serializes a given function or string.
If 'keep_callable' is True, the original callable is returned.
......@@ -22,7 +26,9 @@ def serialize_callable(
return str(value)
def maybe_serialize(val: Callable | Any, keep_callable=False) -> Callable | Any:
def maybe_serialize(
val: Callable[..., T] | Any, keep_callable=False
) -> Callable[..., T] | Any:
"""Conditionally serializes a value if it is callable."""
return (
......@@ -41,3 +47,13 @@ def create_mc_choices(choices: list[str], choice_delimiter: str | None = "\n") -
formatted_choices = [f"{chr(65 + i)}. {choice}" for i, choice in enumerate(choices)]
return choice_delimiter.join(formatted_choices)
def doc_to_closure(fn: Callable[..., T]) -> Callable[..., T]:
"""Closure that allows the function to be called with 'self'."""
@wraps(fn)
def closure(self: Any, *args, **kwargs):
return fn(*args, **kwargs)
return closure
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment