Commit 69d14fb3 authored by Baber
Browse files

cleanup

parent 57b8c0b1
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Iterable
from dataclasses import dataclass from dataclasses import dataclass
from typing import Iterable, List, Union
from lm_eval.api.instance import Instance from lm_eval.api.instance import Instance
...@@ -20,7 +20,9 @@ class Filter(ABC): ...@@ -20,7 +20,9 @@ class Filter(ABC):
""" """
@abstractmethod @abstractmethod
def apply(self, resps: Union[List, Iterable], docs: List[dict]) -> Iterable: def apply(
self, resps: Iterable[list[str]], docs: Iterable[dict]
) -> Iterable[list[str]]:
""" """
Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects. Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
Should return the list of (filtered) response lists *in the same order as they were input*, e.g. Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
...@@ -40,9 +42,9 @@ class FilterEnsemble: ...@@ -40,9 +42,9 @@ class FilterEnsemble:
""" """
name: str name: str
filters: List[type[Filter]] filters: list[type[Filter]]
def apply(self, instances: List[Instance]) -> None: def apply(self, instances: list[Instance]) -> None:
resps, docs = zip(*((inst.resps, inst.doc) for inst in instances)) resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
resps, docs = list(resps), list(docs) resps, docs = list(resps), list(docs)
......
from __future__ import annotations
import logging import logging
from typing import TYPE_CHECKING, Callable, Dict, Optional, Union from typing import TYPE_CHECKING, Any, Callable
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -36,13 +38,14 @@ def register_model(*names): ...@@ -36,13 +38,14 @@ def register_model(*names):
return decorate return decorate
def get_model(model_name: str) -> type["LM"]: def get_model(model_name: str) -> type[LM]:
try: try:
return MODEL_REGISTRY[model_name] return MODEL_REGISTRY[model_name]
except KeyError: except KeyError as err:
raise ValueError( available_models = ", ".join(MODEL_REGISTRY.keys())
f"Attempted to load model '{model_name}', but no model for this name found! Supported model names: {', '.join(MODEL_REGISTRY.keys())}" raise KeyError(
) f"Model '{model_name}' not found. Available models: {available_models}"
) from err
TASK_REGISTRY = {} TASK_REGISTRY = {}
...@@ -81,7 +84,7 @@ def register_group(name): ...@@ -81,7 +84,7 @@ def register_group(name):
OUTPUT_TYPE_REGISTRY = {} OUTPUT_TYPE_REGISTRY = {}
METRIC_REGISTRY = {} METRIC_REGISTRY = {}
METRIC_AGGREGATION_REGISTRY = {} METRIC_AGGREGATION_REGISTRY = {}
AGGREGATION_REGISTRY: Dict[str, Callable[[], Dict[str, Callable]]] = {} AGGREGATION_REGISTRY: dict[str, Callable[[], dict[str, Callable]]] = {}
HIGHER_IS_BETTER_REGISTRY = {} HIGHER_IS_BETTER_REGISTRY = {}
FILTER_REGISTRY = {} FILTER_REGISTRY = {}
...@@ -125,7 +128,7 @@ def register_metric(**args): ...@@ -125,7 +128,7 @@ def register_metric(**args):
return decorate return decorate
def get_metric(name: str, hf_evaluate_metric=False) -> Optional[Callable]: def get_metric(name: str, hf_evaluate_metric=False) -> Callable[..., Any] | None:
if not hf_evaluate_metric: if not hf_evaluate_metric:
if name in METRIC_REGISTRY: if name in METRIC_REGISTRY:
return METRIC_REGISTRY[name] return METRIC_REGISTRY[name]
...@@ -157,21 +160,21 @@ def register_aggregation(name: str): ...@@ -157,21 +160,21 @@ def register_aggregation(name: str):
return decorate return decorate
def get_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callable]]]: def get_aggregation(name: str) -> Callable[..., Any] | None:
try: try:
return AGGREGATION_REGISTRY[name] return AGGREGATION_REGISTRY[name]
except KeyError: except KeyError:
eval_logger.warning(f"{name} not a registered aggregation metric!") eval_logger.warning(f"{name} not a registered aggregation metric!")
def get_metric_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callable]]]: def get_metric_aggregation(name: str) -> Callable[[], dict[str, Callable]] | None:
try: try:
return METRIC_AGGREGATION_REGISTRY[name] return METRIC_AGGREGATION_REGISTRY[name]
except KeyError: except KeyError:
eval_logger.warning(f"{name} metric is not assigned a default aggregation!") eval_logger.warning(f"{name} metric is not assigned a default aggregation!")
def is_higher_better(metric_name: str) -> Optional[bool]: def is_higher_better(metric_name: str) -> bool | None:
try: try:
return HIGHER_IS_BETTER_REGISTRY[metric_name] return HIGHER_IS_BETTER_REGISTRY[metric_name]
except KeyError: except KeyError:
...@@ -192,7 +195,7 @@ def register_filter(name: str): ...@@ -192,7 +195,7 @@ def register_filter(name: str):
return decorate return decorate
def get_filter(filter_name: Union[str, Callable]) -> Callable: def get_filter(filter_name: str | Callable) -> Callable:
try: try:
return FILTER_REGISTRY[filter_name] return FILTER_REGISTRY[filter_name]
except KeyError as e: except KeyError as e:
......
from __future__ import annotations
import abc import abc
import ast import ast
import logging import logging
...@@ -8,15 +10,7 @@ from copy import deepcopy ...@@ -8,15 +10,7 @@ from copy import deepcopy
from typing import ( from typing import (
TYPE_CHECKING, TYPE_CHECKING,
Any, Any,
Dict,
Iterable,
Iterator,
List,
Literal, Literal,
Mapping,
Optional,
Tuple,
Union,
) )
import datasets import datasets
...@@ -57,23 +51,23 @@ class Task(abc.ABC): ...@@ -57,23 +51,23 @@ class Task(abc.ABC):
{"question": ..., question, answer) {"question": ..., question, answer)
""" """
VERSION: Optional[Union[int, str]] = None VERSION: int | str | None = None
# The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub # The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
# or a path to a custom `datasets` loading script. # or a path to a custom `datasets` loading script.
DATASET_PATH: Optional[str] = None DATASET_PATH: str | None = None
# The name of a subset within `DATASET_PATH`. # The name of a subset within `DATASET_PATH`.
DATASET_NAME: Optional[str] = None DATASET_NAME: str | None = None
OUTPUT_TYPE: Optional[OutputType] = None OUTPUT_TYPE: OutputType | None = None
def __init__( def __init__(
self, self,
data_dir: Optional[str] = None, data_dir: str | None = None,
cache_dir: Optional[str] = None, cache_dir: str | None = None,
download_mode: Optional[datasets.DownloadMode] = None, download_mode: datasets.DownloadMode | None = None,
config: Optional[Mapping] = None, # Union[dict, TaskConfig] config: Mapping | None = None, # Union[dict, TaskConfig]
) -> None: ) -> None:
""" """
:param data_dir: str :param data_dir: str
...@@ -97,21 +91,21 @@ class Task(abc.ABC): ...@@ -97,21 +91,21 @@ class Task(abc.ABC):
Fresh download and fresh dataset. Fresh download and fresh dataset.
""" """
self.download(data_dir, cache_dir, download_mode) self.download(data_dir, cache_dir, download_mode)
self._training_docs: Optional[list] = None self._training_docs: list | None = None
self._fewshot_docs: Optional[list] = None self._fewshot_docs: list | None = None
self._instances: Optional[List[Instance]] = None self._instances: list[Instance] | None = None
self._config: TaskConfig = TaskConfig.from_yaml({**config}) self._config: TaskConfig = TaskConfig.from_yaml({**config})
self._filters = [build_filter_ensemble("none", [("take_first", None)])] self._filters = [build_filter_ensemble("none", [("take_first", None)])]
self.fewshot_rnd: Optional[random.Random] = ( self.fewshot_rnd: random.Random | None = (
None # purposely induce errors in case of improper usage None # purposely induce errors in case of improper usage
) )
def download( def download(
self, self,
data_dir: Optional[str] = None, data_dir: str | None = None,
cache_dir: Optional[str] = None, cache_dir: str | None = None,
download_mode=None, download_mode=None,
) -> None: ) -> None:
"""Downloads and returns the task dataset. """Downloads and returns the task dataset.
...@@ -238,7 +232,7 @@ class Task(abc.ABC): ...@@ -238,7 +232,7 @@ class Task(abc.ABC):
pass pass
@abc.abstractmethod @abc.abstractmethod
def doc_to_target(self, doc: dict) -> Union[str, int]: def doc_to_target(self, doc: dict) -> str | int:
pass pass
# not an abstractmethod because not every language-only task has to implement this # not an abstractmethod because not every language-only task has to implement this
...@@ -254,16 +248,16 @@ class Task(abc.ABC): ...@@ -254,16 +248,16 @@ class Task(abc.ABC):
def build_all_requests( def build_all_requests(
self, self,
*, *,
limit: Union[int, None] = None, limit: int | None = None,
samples: Optional[List[int]] = None, samples: list[int] | None = None,
rank: int = 0, rank: int = 0,
world_size: int = 1, world_size: int = 1,
cache_requests: bool = False, cache_requests: bool = False,
rewrite_requests_cache: bool = False, rewrite_requests_cache: bool = False,
system_instruction: Optional[str] = None, system_instruction: str | None = None,
apply_chat_template: bool = False, apply_chat_template: bool = False,
fewshot_as_multiturn: bool = False, fewshot_as_multiturn: bool = False,
chat_template: Optional[Callable] = None, chat_template: Callable | None = None,
tokenizer_name: str = "", tokenizer_name: str = "",
) -> None: ) -> None:
"""Build a set of Instances for a task, and store them in task.instances""" """Build a set of Instances for a task, and store them in task.instances"""
...@@ -365,7 +359,7 @@ class Task(abc.ABC): ...@@ -365,7 +359,7 @@ class Task(abc.ABC):
save_to_cache(file_name=cache_key, obj=instances) save_to_cache(file_name=cache_key, obj=instances)
@abc.abstractmethod @abc.abstractmethod
def construct_requests(self, doc: dict, ctx: Union[list[dict], str], **kwargs): def construct_requests(self, doc: dict, ctx: list[dict] | str, **kwargs):
"""Uses RequestFactory to construct Requests and returns an iterable of """Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM. Requests which will be sent to the LM.
...@@ -405,7 +399,7 @@ class Task(abc.ABC): ...@@ -405,7 +399,7 @@ class Task(abc.ABC):
A dictionary where keys are the names of submetrics and values are A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metric scores functions that aggregate a list of metric scores
""" """
pass return True
@deprecated("not used anymore") @deprecated("not used anymore")
def higher_is_better(self): def higher_is_better(self):
...@@ -414,7 +408,7 @@ class Task(abc.ABC): ...@@ -414,7 +408,7 @@ class Task(abc.ABC):
A dictionary where keys are the names of submetrics and values are A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better whether a higher value of the submetric is better
""" """
pass return True
def get_config(self, key: str) -> Any: def get_config(self, key: str) -> Any:
return getattr(self._config, key, None) return getattr(self._config, key, None)
...@@ -488,13 +482,15 @@ class Task(abc.ABC): ...@@ -488,13 +482,15 @@ class Task(abc.ABC):
example = self.doc_to_text(doc) example = self.doc_to_text(doc)
return description + labeled_examples + example return description + labeled_examples + example
def apply_filters(self) -> Optional[List[Instance]]: def apply_filters(self) -> list[Instance] | None:
"""Iterates over FilterEnsembles and applies them to instances""" """Iterates over FilterEnsembles and applies them to instances"""
if hasattr(self, "_filters"): if hasattr(self, "_filters") and self._instances:
for f in self._filters: for f in self._filters:
f.apply(self._instances) f.apply(self._instances)
else: else:
eval_logger.warning("No filter defined, passing through instances") eval_logger.warning(
"No filter defined or no instances, passing through instances"
)
return self._instances return self._instances
def dump_config(self) -> dict: def dump_config(self) -> dict:
...@@ -505,9 +501,6 @@ class Task(abc.ABC): ...@@ -505,9 +501,6 @@ class Task(abc.ABC):
def set_config(self, key: str, value: Any, update: bool = False) -> None: def set_config(self, key: str, value: Any, update: bool = False) -> None:
"""Set or update the configuration for a given key.""" """Set or update the configuration for a given key."""
if key is None:
raise ValueError("Key must be provided.")
if update: if update:
current_value = getattr(self._config, key, {}) current_value = getattr(self._config, key, {})
if not isinstance(current_value, dict): if not isinstance(current_value, dict):
...@@ -533,13 +526,13 @@ class Task(abc.ABC): ...@@ -533,13 +526,13 @@ class Task(abc.ABC):
setattr(self._config, "metric_list", [MetricConfig(name=metric_name)]) setattr(self._config, "metric_list", [MetricConfig(name=metric_name)])
setattr(self._config, "process_results", lambda *args: {"bypass": 0}) setattr(self._config, "process_results", lambda *args: {"bypass": 0})
def set_fewshot_seed(self, seed: Optional[int] = None) -> None: def set_fewshot_seed(self, seed: int | None = None) -> None:
self.fewshot_rnd = random.Random(seed) self.fewshot_rnd = random.Random(seed)
if hasattr(self, "sampler"): if hasattr(self, "sampler"):
self.sampler.rnd = self.fewshot_rnd self.sampler.rnd = self.fewshot_rnd
@property @property
def eval_docs(self) -> Union[datasets.Dataset, Iterable[dict]]: def eval_docs(self) -> datasets.Dataset | Iterable[dict]:
if self.has_test_docs(): if self.has_test_docs():
return self.test_docs() return self.test_docs()
elif self.has_validation_docs(): elif self.has_validation_docs():
...@@ -553,13 +546,13 @@ class Task(abc.ABC): ...@@ -553,13 +546,13 @@ class Task(abc.ABC):
self, self,
*, *,
rank: int = 0, rank: int = 0,
limit: Union[int, None] = None, limit: int | None = None,
world_size: int = 1, world_size: int = 1,
samples: Optional[List[int]] = None, samples: list[int] | None = None,
) -> Iterator[Tuple[int, Any]]: ) -> Iterator[tuple[int, Any]]:
if samples: if samples:
n = len(self.eval_docs) n = len(self.eval_docs)
assert all([e < n for e in samples]), ( assert all(e < n for e in samples), (
f"Elements of --samples should be in the interval [0,k-1] where k is the number of total examples. In this case, k={n}." f"Elements of --samples should be in the interval [0,k-1] where k is the number of total examples. In this case, k={n}."
) )
eval_logger.info( eval_logger.info(
...@@ -592,7 +585,7 @@ class ConfigurableTask(Task): ...@@ -592,7 +585,7 @@ class ConfigurableTask(Task):
data_dir=None, data_dir=None,
cache_dir=None, cache_dir=None,
download_mode=None, download_mode=None,
config: Optional[dict] = None, config: dict | None = None,
) -> None: ) -> None:
# Get pre-configured attributes # Get pre-configured attributes
self._config = self.CONFIG self._config = self.CONFIG
...@@ -610,9 +603,8 @@ class ConfigurableTask(Task): ...@@ -610,9 +603,8 @@ class ConfigurableTask(Task):
"Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg" "Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg"
) )
if isinstance(self.config.metadata, dict): if isinstance(self.config.metadata, dict) and "version" in self.config.metadata:
if "version" in self.config.metadata: self.VERSION = self.config.metadata["version"]
self.VERSION = self.config.metadata["version"]
if self.config.output_type is not None: if self.config.output_type is not None:
if self.config.output_type not in ALL_OUTPUT_TYPES: if self.config.output_type not in ALL_OUTPUT_TYPES:
...@@ -698,18 +690,13 @@ class ConfigurableTask(Task): ...@@ -698,18 +690,13 @@ class ConfigurableTask(Task):
else: else:
test_target = str(test_target) test_target = str(test_target)
if test_choice is not None: check_choices = test_choice if test_choice is not None else [test_target]
check_choices = test_choice
else:
check_choices = [test_target]
if self.config.doc_to_choice is not None: if self.config.doc_to_choice is not None:
for choice in check_choices: for choice in check_choices:
choice_has_whitespace = True if choice[0].isspace() else False choice_has_whitespace = choice[0].isspace()
delimiter_has_whitespace = ( delimiter_has_whitespace = (
True self.config.target_delimiter.rstrip()
if self.config.target_delimiter.rstrip()
!= self.config.target_delimiter != self.config.target_delimiter
else False
) )
if delimiter_has_whitespace and choice_has_whitespace: if delimiter_has_whitespace and choice_has_whitespace:
...@@ -722,7 +709,7 @@ class ConfigurableTask(Task): ...@@ -722,7 +709,7 @@ class ConfigurableTask(Task):
) )
def download( def download(
self, dataset_kwargs: Optional[Dict[str, Any]] = None, **kwargs self, dataset_kwargs:dict[str, Any] | None = None, **kwargs
) -> None: ) -> None:
from packaging.version import parse as vparse from packaging.version import parse as vparse
...@@ -746,24 +733,15 @@ class ConfigurableTask(Task): ...@@ -746,24 +733,15 @@ class ConfigurableTask(Task):
) )
def has_training_docs(self) -> bool: def has_training_docs(self) -> bool:
if self.config.training_split is not None: return self.config.training_split is not None
return True
else:
return False
def has_validation_docs(self) -> bool: def has_validation_docs(self) -> bool:
if self.config.validation_split is not None: return self.config.validation_split is not None
return True
else:
return False
def has_test_docs(self) -> bool: def has_test_docs(self) -> bool:
if self.config.test_split is not None: return self.config.test_split is not None
return True
else:
return False
def training_docs(self) -> Optional[datasets.Dataset]: def training_docs(self) -> datasets.Dataset | None:
if self.has_training_docs(): if self.has_training_docs():
if self.config.process_docs is not None: if self.config.process_docs is not None:
return self.config.process_docs( return self.config.process_docs(
...@@ -771,7 +749,7 @@ class ConfigurableTask(Task): ...@@ -771,7 +749,7 @@ class ConfigurableTask(Task):
) )
return self.dataset[self.config.training_split] return self.dataset[self.config.training_split]
def validation_docs(self) -> Optional[datasets.Dataset]: def validation_docs(self) -> datasets.Dataset | None:
if self.has_validation_docs(): if self.has_validation_docs():
if self.config.process_docs is not None: if self.config.process_docs is not None:
return self.config.process_docs( return self.config.process_docs(
...@@ -779,7 +757,7 @@ class ConfigurableTask(Task): ...@@ -779,7 +757,7 @@ class ConfigurableTask(Task):
) )
return self.dataset[self.config.validation_split] return self.dataset[self.config.validation_split]
def test_docs(self) -> Optional[datasets.Dataset]: def test_docs(self) -> datasets.Dataset | None:
if self.has_test_docs(): if self.has_test_docs():
if self.config.process_docs is not None: if self.config.process_docs is not None:
return self.config.process_docs(self.dataset[self.config.test_split]) return self.config.process_docs(self.dataset[self.config.test_split])
...@@ -792,22 +770,25 @@ class ConfigurableTask(Task): ...@@ -792,22 +770,25 @@ class ConfigurableTask(Task):
return docs return docs
# Fallback to parent implementation # Fallback to parent implementation
if _num_fewshot := getattr(self.config, "num_fewshot"): if (
if isinstance(_num_fewshot, int) and _num_fewshot > 0: (_num_fewshot := self.config.num_fewshot)
eval_logger.warning( and isinstance(_num_fewshot, int)
f"[Task: {self.config.task}] " and _num_fewshot > 0
"num_fewshot > 0 but no fewshot source configured. " ):
"Using preconfigured rule." eval_logger.warning(
) f"[Task: {self.config.task}] "
"num_fewshot > 0 but no fewshot source configured. "
"Using preconfigured rule."
)
return super().fewshot_docs() return super().fewshot_docs()
@staticmethod @staticmethod
def append_target_question( def append_target_question(
labeled_examples: List[Dict[str, str]], labeled_examples: list[dict[str, str]],
question: str, question: str,
fewshot_as_multiturn: bool = False, fewshot_as_multiturn: bool = False,
gen_prefix: Optional[str] = None, gen_prefix: str | None = None,
) -> None: ) -> None:
"""Adds a target question to the labeled examples list. """Adds a target question to the labeled examples list.
If fewshot_as_multiturn is True, or labeled_examples is empty, or the last entry is a system turn, appends the question as a new user entry. If fewshot_as_multiturn is True, or labeled_examples is empty, or the last entry is a system turn, appends the question as a new user entry.
...@@ -831,12 +812,12 @@ class ConfigurableTask(Task): ...@@ -831,12 +812,12 @@ class ConfigurableTask(Task):
self, self,
doc: dict, doc: dict,
num_fewshot: int, num_fewshot: int,
system_instruction: Optional[str] = None, system_instruction: str | None = None,
apply_chat_template: bool = False, apply_chat_template: bool = False,
fewshot_as_multiturn: bool = False, fewshot_as_multiturn: bool = False,
chat_template: Optional[Callable] = None, chat_template: Callable | None = None,
gen_prefix: Optional[str] = None, gen_prefix: str | None = None,
) -> Union[str, List[str], None]: ) -> str | list[str] | None:
"""Returns a fewshot context string that is made up of a prepended description """Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example. (if provided), the `num_fewshot` number of examples, and an appended prompt example.
...@@ -857,10 +838,7 @@ class ConfigurableTask(Task): ...@@ -857,10 +838,7 @@ class ConfigurableTask(Task):
:returns: str :returns: str
The fewshot context. The fewshot context.
""" """
if apply_chat_template: labeled_examples = [] if apply_chat_template else ""
labeled_examples = []
else:
labeled_examples = ""
# get task description # get task description
if description := self.config.description: if description := self.config.description:
...@@ -930,7 +908,7 @@ class ConfigurableTask(Task): ...@@ -930,7 +908,7 @@ class ConfigurableTask(Task):
labeled_examples_list.append( labeled_examples_list.append(
chat_template( chat_template(
chat, chat,
add_generation_prompt=False if gen_prefix else True, add_generation_prompt=not gen_prefix,
) )
) )
return labeled_examples_list return labeled_examples_list
...@@ -954,7 +932,7 @@ class ConfigurableTask(Task): ...@@ -954,7 +932,7 @@ class ConfigurableTask(Task):
# return lm.apply_chat_template(labeled_examples) # return lm.apply_chat_template(labeled_examples)
return chat_template( return chat_template(
labeled_examples, labeled_examples,
add_generation_prompt=False if gen_prefix else True, add_generation_prompt=not gen_prefix,
) )
else: else:
prefix = ( prefix = (
...@@ -975,7 +953,7 @@ class ConfigurableTask(Task): ...@@ -975,7 +953,7 @@ class ConfigurableTask(Task):
else: else:
return labeled_examples + str(example) + prefix return labeled_examples + str(example) + prefix
def apply_filters(self) -> Optional[List[Instance]]: def apply_filters(self) -> list[Instance] | None:
"""Iterates over FilterEnsembles and applies them to instances""" """Iterates over FilterEnsembles and applies them to instances"""
if hasattr(self, "_filters"): if hasattr(self, "_filters"):
for f in self._filters: for f in self._filters:
...@@ -1015,9 +993,7 @@ class ConfigurableTask(Task): ...@@ -1015,9 +993,7 @@ class ConfigurableTask(Task):
""" """
return doc return doc
def doc_to_text( def doc_to_text(self, doc: dict, doc_to_text: int | str | Callable | None = None):
self, doc: dict, doc_to_text: Union[int, str, Callable, None] = None
):
# if self.prompt is not None: # if self.prompt is not None:
# doc_to_text = self.prompt # doc_to_text = self.prompt
if doc_to_text is not None: if doc_to_text is not None:
...@@ -1053,9 +1029,7 @@ class ConfigurableTask(Task): ...@@ -1053,9 +1029,7 @@ class ConfigurableTask(Task):
print(type(doc_to_text)) print(type(doc_to_text))
raise TypeError raise TypeError
def doc_to_target( def doc_to_target(self, doc: dict, doc_to_target=None) -> int | str | list[int]:
self, doc: dict, doc_to_target=None
) -> Union[int, str, list[int]]:
# if self.prompt is not None: # if self.prompt is not None:
# doc_to_target = self.prompt # doc_to_target = self.prompt
if doc_to_target is not None: if doc_to_target is not None:
...@@ -1104,8 +1078,8 @@ class ConfigurableTask(Task): ...@@ -1104,8 +1078,8 @@ class ConfigurableTask(Task):
def doc_to_choice( def doc_to_choice(
self, self,
doc: dict, doc: dict,
doc_to_choice: Union[str, list, dict, Callable[..., list[str]], None] = None, doc_to_choice: str | list | dict | Callable[..., list[str]] | None = None,
) -> List[str]: ) -> list[str]:
# if self.prompt is not None: # if self.prompt is not None:
# doc_to_choice = self.prompt # doc_to_choice = self.prompt
if doc_to_choice is not None: if doc_to_choice is not None:
...@@ -1132,7 +1106,7 @@ class ConfigurableTask(Task): ...@@ -1132,7 +1106,7 @@ class ConfigurableTask(Task):
else: else:
raise TypeError raise TypeError
def doc_to_image(self, doc: dict, doc_to_image=None) -> Union[int, str, list, None]: def doc_to_image(self, doc: dict, doc_to_image=None) -> int | str | list | None:
if doc_to_image is not None: if doc_to_image is not None:
doc_to_image = doc_to_image doc_to_image = doc_to_image
elif self.config.doc_to_image is not None: elif self.config.doc_to_image is not None:
...@@ -1155,7 +1129,7 @@ class ConfigurableTask(Task): ...@@ -1155,7 +1129,7 @@ class ConfigurableTask(Task):
else: else:
return None return None
def doc_to_audio(self, doc: Any, doc_to_audio=None) -> Union[int, str, list, None]: def doc_to_audio(self, doc: Any, doc_to_audio=None) -> int | str | list | None:
if doc_to_audio is not None: if doc_to_audio is not None:
doc_to_audio = doc_to_audio doc_to_audio = doc_to_audio
elif self.config.doc_to_audio is not None: elif self.config.doc_to_audio is not None:
...@@ -1178,7 +1152,7 @@ class ConfigurableTask(Task): ...@@ -1178,7 +1152,7 @@ class ConfigurableTask(Task):
else: else:
return None return None
def doc_to_prefix(self, doc: dict) -> Optional[str]: def doc_to_prefix(self, doc: dict) -> str | None:
if (gen_prefix := self.config.gen_prefix) is not None: if (gen_prefix := self.config.gen_prefix) is not None:
if gen_prefix in self.features: if gen_prefix in self.features:
return doc[gen_prefix] return doc[gen_prefix]
...@@ -1188,7 +1162,7 @@ class ConfigurableTask(Task): ...@@ -1188,7 +1162,7 @@ class ConfigurableTask(Task):
def construct_requests( def construct_requests(
self, doc: dict, ctx: str, **kwargs self, doc: dict, ctx: str, **kwargs
) -> Union[List[Instance], Instance]: ) -> list[Instance] | Instance:
apply_chat_template = kwargs.pop("apply_chat_template", False) apply_chat_template = kwargs.pop("apply_chat_template", False)
chat_template: Callable | None = kwargs.pop("chat_template", None) chat_template: Callable | None = kwargs.pop("chat_template", None)
...@@ -1324,7 +1298,7 @@ class ConfigurableTask(Task): ...@@ -1324,7 +1298,7 @@ class ConfigurableTask(Task):
elif self.OUTPUT_TYPE == "multiple_choice": elif self.OUTPUT_TYPE == "multiple_choice":
lls, is_greedy = zip(*results) lls, is_greedy = zip(*results)
# retrieve choices in List[str] form, to compute choice lengths, etc. # retrieve choices in list[str] form, to compute choice lengths, etc.
choices = self.doc_to_choice(doc) choices = self.doc_to_choice(doc)
completion_len = np.array([float(len(i)) for i in choices]) completion_len = np.array([float(len(i)) for i in choices])
...@@ -1371,7 +1345,7 @@ class ConfigurableTask(Task): ...@@ -1371,7 +1345,7 @@ class ConfigurableTask(Task):
if self.multiple_target: if self.multiple_target:
acc = 1.0 if pred in gold else 0.0 acc = 1.0 if pred in gold else 0.0
acc_norm = 1.0 if pred_norm in gold else 0.0 acc_norm = 1.0 if pred_norm in gold else 0.0
exact_match = int(any([is_greedy[i] if i != -100 else 0 for i in gold])) exact_match = int(any(is_greedy[i] if i != -100 else 0 for i in gold))
else: else:
acc = 1.0 if pred == gold else 0.0 acc = 1.0 if pred == gold else 0.0
acc_norm = 1.0 if pred_norm == gold else 0.0 acc_norm = 1.0 if pred_norm == gold else 0.0
...@@ -1413,7 +1387,7 @@ class ConfigurableTask(Task): ...@@ -1413,7 +1387,7 @@ class ConfigurableTask(Task):
# it assumes that doc_to_target returns a number. # it assumes that doc_to_target returns a number.
choices = self.doc_to_choice(doc) choices = self.doc_to_choice(doc)
gold = choices[gold] gold = choices[gold]
for metric in self._metric_fn_list.keys(): for metric in self._metric_fn_list:
try: try:
result_score = self._metric_fn_list[metric]( result_score = self._metric_fn_list[metric](
references=[gold] if not isinstance(gold, list) else gold, references=[gold] if not isinstance(gold, list) else gold,
...@@ -1447,7 +1421,7 @@ class ConfigurableTask(Task): ...@@ -1447,7 +1421,7 @@ class ConfigurableTask(Task):
return getattr(self._config, key, None) return getattr(self._config, key, None)
@property @property
def task_name(self) -> Optional[str]: def task_name(self) -> str | None:
return getattr(self.config, "task", None) return getattr(self.config, "task", None)
def __repr__(self): def __repr__(self):
...@@ -1465,7 +1439,7 @@ class MultipleChoiceTask(Task): ...@@ -1465,7 +1439,7 @@ class MultipleChoiceTask(Task):
def doc_to_target(self, doc: dict) -> str: def doc_to_target(self, doc: dict) -> str:
return " " + doc["choices"][doc["gold"]] return " " + doc["choices"][doc["gold"]]
def construct_requests(self, doc: dict, ctx: str, **kwargs) -> List[Instance]: def construct_requests(self, doc: dict, ctx: str, **kwargs) -> list[Instance]:
# TODO: add mutual info here? # TODO: add mutual info here?
return [ return [
Instance( Instance(
...@@ -1478,7 +1452,7 @@ class MultipleChoiceTask(Task): ...@@ -1478,7 +1452,7 @@ class MultipleChoiceTask(Task):
for i, choice in enumerate(doc["choices"]) for i, choice in enumerate(doc["choices"])
] ]
def process_results(self, doc: dict, results: Iterable[Tuple[float, bool]]) -> dict: def process_results(self, doc: dict, results: Iterable[tuple[float, bool]]) -> dict:
results = [ results = [
res[0] for res in results res[0] for res in results
] # only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere? ] # only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere?
...@@ -1512,7 +1486,7 @@ class PerplexityTask(Task): ...@@ -1512,7 +1486,7 @@ class PerplexityTask(Task):
def has_training_docs(self) -> bool: def has_training_docs(self) -> bool:
return False return False
def fewshot_examples(self, k: int, rnd) -> List: def fewshot_examples(self, k: int, rnd) -> list:
if k != 0: if k != 0:
raise ValueError( raise ValueError(
"The number of fewshot examples must be 0 for perplexity tasks." "The number of fewshot examples must be 0 for perplexity tasks."
...@@ -1543,7 +1517,7 @@ class PerplexityTask(Task): ...@@ -1543,7 +1517,7 @@ class PerplexityTask(Task):
def doc_to_target(self, doc): def doc_to_target(self, doc):
return doc return doc
def construct_requests(self, doc: dict, ctx: Optional[str], **kwargs): def construct_requests(self, doc: dict, ctx: str | None, **kwargs):
if bool(ctx): if bool(ctx):
raise ValueError raise ValueError
...@@ -1555,7 +1529,7 @@ class PerplexityTask(Task): ...@@ -1555,7 +1529,7 @@ class PerplexityTask(Task):
**kwargs, **kwargs,
) )
def process_results(self, doc: dict, results: Tuple[float]) -> dict: def process_results(self, doc: dict, results: tuple[float]) -> dict:
(loglikelihood,) = results (loglikelihood,) = results
words = self.count_words(self.doc_to_target(doc)) words = self.count_words(self.doc_to_target(doc))
bytes_ = self.count_bytes(self.doc_to_target(doc)) bytes_ = self.count_bytes(self.doc_to_target(doc))
......
from __future__ import annotations
from collections.abc import Callable, Mapping
from dataclasses import dataclass from dataclasses import dataclass
from functools import cached_property from functools import cached_property
from typing import Any, Callable, List, Optional from typing import Any
@dataclass @dataclass
...@@ -8,9 +11,9 @@ class MetricConfig: ...@@ -8,9 +11,9 @@ class MetricConfig:
"""Encapsulates information about a single metric.""" """Encapsulates information about a single metric."""
name: str name: str
fn: Optional[Callable] = None fn: Callable | None = None
kwargs: Optional[dict] = None kwargs: Mapping[str, Any] | None = None
aggregation_fn: Optional[Callable] = None aggregation_fn: Callable | None = None
higher_is_better: bool = True higher_is_better: bool = True
hf_evaluate: bool = False hf_evaluate: bool = False
is_elementwise: bool = True is_elementwise: bool = True
...@@ -20,7 +23,7 @@ class MetricConfig: ...@@ -20,7 +23,7 @@ class MetricConfig:
return self.name return self.name
@cached_property @cached_property
def aggregation(self) -> Callable: def aggregation(self) -> Callable[..., Any] | None:
from lm_eval.api.registry import get_aggregation from lm_eval.api.registry import get_aggregation
if self.aggregation_fn is None: if self.aggregation_fn is None:
...@@ -28,7 +31,7 @@ class MetricConfig: ...@@ -28,7 +31,7 @@ class MetricConfig:
return self.aggregation_fn return self.aggregation_fn
@cached_property @cached_property
def _higher_is_better(self) -> bool: def _higher_is_better(self) -> bool | None:
from lm_eval.api.registry import is_higher_better from lm_eval.api.registry import is_higher_better
if self.higher_is_better is None: if self.higher_is_better is None:
...@@ -39,10 +42,10 @@ class MetricConfig: ...@@ -39,10 +42,10 @@ class MetricConfig:
"""Calculates the metric using the provided function and arguments.""" """Calculates the metric using the provided function and arguments."""
if self.fn is None: if self.fn is None:
raise ValueError(f"Metric function for {self.name} is not defined.") raise ValueError(f"Metric function for {self.name} is not defined.")
return self.fn(*args, **{**self.kwargs, **kwargs}) return self.fn(*args, **{**(self.kwargs or {}), **kwargs})
def compute_aggregation(self, values: List[Any]) -> Any: def compute_aggregation(self, *args, **kwargs) -> Any:
"""Computes the aggregation of the metric values.""" """Computes the aggregation of the metric values."""
if self.aggregation_fn is None: if self.aggregation_fn is None:
raise ValueError(f"Aggregation function for {self.name} is not defined.") raise ValueError(f"Aggregation function for {self.name} is not defined.")
return self.aggregation_fn(values) return self.aggregation_fn(*args, **kwargs)
from __future__ import annotations
import logging import logging
from collections.abc import Iterable
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union from typing import TYPE_CHECKING, Callable
from lm_eval.api.filter import FilterEnsemble from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.instance import OutputType from lm_eval.api.instance import OutputType
...@@ -20,8 +23,8 @@ class RepeatConfig: ...@@ -20,8 +23,8 @@ class RepeatConfig:
"""Encapsulates information about a single repeat.""" """Encapsulates information about a single repeat."""
repeats: int = 1 repeats: int = 1
metric_fn: Union[str, Callable] = "pass@N" metric_fn: str | Callable = "pass@N"
kwargs: Optional[dict] = field(default_factory=dict) kwargs: dict | None = field(default_factory=dict)
@dataclass @dataclass
...@@ -38,11 +41,11 @@ class FewshotConfig: ...@@ -38,11 +41,11 @@ class FewshotConfig:
# hack: this returns task.config.num_fewshot # hack: this returns task.config.num_fewshot
# to keep in sync as it is runtime-modified # to keep in sync as it is runtime-modified
num_fewshot: Callable[[], int] num_fewshot: Callable[[], int]
split: Optional[str] = None split: str | None = None
sampler: Union[str, Callable] = "default" sampler: str | Callable = "default"
samples: Union[Callable[[], list[dict]], list[dict], None] = None samples: Callable[[], list[dict]] | list[dict] | None = None
process_docs: Optional[Callable[[list[dict]], Iterable[dict]]] = None process_docs: Callable[[list[dict]], Iterable[dict]] | None = None
fewshot_indices: Optional[list[int]] = None fewshot_indices: list[int] | None = None
rnd: int = field(init=False, default=False) rnd: int = field(init=False, default=False)
def __post_init__(self) -> None: def __post_init__(self) -> None:
...@@ -65,22 +68,20 @@ class FewshotConfig: ...@@ -65,22 +68,20 @@ class FewshotConfig:
def _get_raw_docs( def _get_raw_docs(
self, dataset self, dataset
) -> Union[list[dict], Callable[[], Iterable[dict]], None]: ) -> list[dict] | Callable[[], Iterable[dict]] | None:
"""Get raw documents from configured source.""" """Get raw documents from configured source."""
if self.split is not None: if self.split is not None:
return dataset[self.split] return dataset[self.split]
if self.samples is not None: if self.samples is not None:
if isinstance(self.samples, list): if isinstance(self.samples, list) or callable(self.samples):
return self.samples
elif callable(self.samples):
return self.samples return self.samples
else: else:
raise TypeError( raise TypeError(
"samples must be either a list of dicts or a callable returning a list" "samples must be either a list of dicts or a callable returning a list"
) )
def get_docs(self, dataset) -> Optional[Iterable[dict]]: def get_docs(self, dataset) -> Iterable[dict] | None:
"""Get processed documents from configured source.""" """Get processed documents from configured source."""
raw_docs = self._get_raw_docs(dataset) raw_docs = self._get_raw_docs(dataset)
if raw_docs is None: if raw_docs is None:
...@@ -100,8 +101,8 @@ class FewshotConfig: ...@@ -100,8 +101,8 @@ class FewshotConfig:
return self.sampler return self.sampler
def init_sampler( def init_sampler(
self, docs: list[dict], task: "Task", rnd=None, fewshot_indices=None self, docs: list[dict], task: Task, rnd=None, fewshot_indices=None
) -> "ContextSampler": ) -> ContextSampler:
"""Initialize the sampler with the given documents and task.""" """Initialize the sampler with the given documents and task."""
if rnd is None: if rnd is None:
raise ValueError( raise ValueError(
...@@ -120,49 +121,49 @@ class FewshotConfig: ...@@ -120,49 +121,49 @@ class FewshotConfig:
@dataclass @dataclass
class TaskConfig(dict): class TaskConfig(dict):
# task naming/registry # task naming/registry
task: Optional[str] = None task: str | None = None
task_alias: Optional[str] = None task_alias: str | None = None
tag: Optional[Union[str, list]] = None tag: str | list | None = None
# HF dataset options. # HF dataset options.
# which dataset to use, # which dataset to use,
# and what splits for what purpose # and what splits for what purpose
custom_dataset: Optional[Callable] = None custom_dataset: Callable | None = None
dataset_path: Optional[str] = None dataset_path: str | None = None
dataset_name: Optional[str] = None dataset_name: str | None = None
dataset_kwargs: Optional[dict] = field(default_factory=dict) dataset_kwargs: dict | None = field(default_factory=dict)
training_split: Optional[str] = None training_split: str | None = None
validation_split: Optional[str] = None validation_split: str | None = None
test_split: Optional[str] = None test_split: str | None = None
fewshot_split: Optional[str] = ( fewshot_split: str | None = (
None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?) None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?)
) )
# formatting / prompting options. # formatting / prompting options.
# see docs/advanced_task_guide.md for more info # see docs/advanced_task_guide.md for more info
process_docs: Optional[Callable] = None process_docs: Callable | None = None
doc_to_text: Optional[Union[Callable, str]] = None doc_to_text: Callable | str | None = None
doc_to_target: Optional[Union[Callable, str]] = None doc_to_target: Callable | str | None = None
doc_to_image: Union[Callable, str, None] = None doc_to_image: Callable | str | None = None
doc_to_audio: Union[Callable, str, None] = None doc_to_audio: Callable | str | None = None
unsafe_code: bool = False unsafe_code: bool = False
doc_to_choice: Optional[Union[Callable, str, dict, list]] = None doc_to_choice: Callable | str | dict | list | None = None
process_results: Optional[Union[Callable, str]] = None process_results: Callable | str | None = None
use_prompt: Optional[str] = None use_prompt: str | None = None
description: str = "" description: str = ""
target_delimiter: str = " " target_delimiter: str = " "
fewshot_delimiter: str = "\n\n" fewshot_delimiter: str = "\n\n"
fewshot_config: Optional[dict] = None fewshot_config: dict | None = None
# runtime configuration options # runtime configuration options
num_fewshot: Optional[int] = 0 num_fewshot: int | None = 0
generation_kwargs: Optional[dict] = None generation_kwargs: dict | None = None
# scoring options # scoring options
metric_list: Optional[list] = None metric_list: list | None = None
output_type: OutputType = "generate_until" output_type: OutputType = "generate_until"
repeats: int = 1 repeats: int = 1
filter_list: Optional[list[dict]] = None filter_list: list[dict] | None = None
should_decontaminate: bool = False should_decontaminate: bool = False
doc_to_decontamination_query: Optional[str] = None doc_to_decontamination_query: str | None = None
gen_prefix: Optional[str] = None gen_prefix: str | None = None
metadata: Optional[dict] = field( metadata: dict | None = field(
default_factory=dict default_factory=dict
) # by default, not used in the code. allows for users to pass arbitrary info to tasks ) # by default, not used in the code. allows for users to pass arbitrary info to tasks
...@@ -215,9 +216,7 @@ class TaskConfig(dict): ...@@ -215,9 +216,7 @@ class TaskConfig(dict):
fewshot_indices=_fewshot_cfg.get("fewshot_indices", None), fewshot_indices=_fewshot_cfg.get("fewshot_indices", None),
) )
def _get_metric( def _get_metric(self, metric_list: list[dict] | None = None) -> list[MetricConfig]:
self, metric_list: Optional[list[dict]] = None
) -> list["MetricConfig"]:
from lm_eval.api.registry import ( from lm_eval.api.registry import (
AGGREGATION_REGISTRY, AGGREGATION_REGISTRY,
DEFAULT_METRIC_REGISTRY, DEFAULT_METRIC_REGISTRY,
...@@ -314,7 +313,7 @@ class TaskConfig(dict): ...@@ -314,7 +313,7 @@ class TaskConfig(dict):
return metrics return metrics
@property @property
def get_filters(self) -> list["FilterConfig"]: def get_filters(self) -> list[FilterConfig]:
from lm_eval.filters import build_filter_ensemble from lm_eval.filters import build_filter_ensemble
if not self.filter_list: if not self.filter_list:
...@@ -354,7 +353,7 @@ class TaskConfig(dict): ...@@ -354,7 +353,7 @@ class TaskConfig(dict):
return x return x
@classmethod @classmethod
def from_yaml(cls, data: dict) -> "TaskConfig": def from_yaml(cls, data: dict) -> TaskConfig:
"""Create a TaskConfig instance from a YAML-like dictionary.""" """Create a TaskConfig instance from a YAML-like dictionary."""
return cls(**data) return cls(**data)
......
from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Callable, Optional, Union from typing import TYPE_CHECKING, Callable
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -11,19 +13,19 @@ class TemplateConfig: ...@@ -11,19 +13,19 @@ class TemplateConfig:
"""Encapsulates information about a template.""" """Encapsulates information about a template."""
template: str template: str
doc_to_text: Union[str, Callable[[dict], str]] doc_to_text: str | Callable[[dict], str]
doc_to_choice: Union[str, list, Callable[[dict], list]] doc_to_choice: str | list | Callable[[dict], list]
doc_to_target: Union[int, Callable[[dict], int]] doc_to_target: int | Callable[[dict], int]
description: str description: str
context_prefix: str context_prefix: str
prefix_delimiter: str prefix_delimiter: str
context_delimiter: str context_delimiter: str
answer_suffix: str answer_suffix: str
target_delimiter: str target_delimiter: str
choice_format: Optional[str] choice_format: str | None
choice_delimiter: Optional[str] choice_delimiter: str | None
fewshot_delimiter: str fewshot_delimiter: str
metric_list: Optional[Union[list[str], list["MetricConfig"]]] = field( metric_list: list[str] | list[MetricConfig] | None = field(
default_factory=lambda: ["acc", "acc_norm"] default_factory=lambda: ["acc", "acc_norm"]
) )
...@@ -40,19 +42,19 @@ class MCQTemplateConfig: ...@@ -40,19 +42,19 @@ class MCQTemplateConfig:
Answer:` doc_to_choice(doc)` for each choice. Answer:` doc_to_choice(doc)` for each choice.
""" """
doc_to_text: Union[str, Callable[[dict], str]] doc_to_text: str | Callable[[dict], str]
doc_to_choice: Union[str, list, Callable[[dict], list]] doc_to_choice: str | list | Callable[[dict], list]
doc_to_target: Union[int, Callable[[dict], int]] doc_to_target: int | Callable[[dict], int]
template = "mcq" template = "mcq"
context_prefix: str = "Question:" context_prefix: str = "Question:"
prefix_delimiter: str = " " prefix_delimiter: str = " "
context_delimiter: str = "\n" context_delimiter: str = "\n"
answer_suffix: str = "Answer:" answer_suffix: str = "Answer:"
target_delimiter: str = "\n" target_delimiter: str = "\n"
choice_format: Optional[str] = "letters" choice_format: str | None = "letters"
choice_delimiter: Optional[str] = "\n" choice_delimiter: str | None = "\n"
fewshot_delimiter: str = "\n\n" fewshot_delimiter: str = "\n\n"
metric_list: Optional[list["MetricConfig"]] = field(default_factory=lambda: ["acc"]) metric_list: list[MetricConfig] | None = field(default_factory=lambda: ["acc"])
@dataclass @dataclass
...@@ -63,9 +65,9 @@ class ClozeTemplateConfig: ...@@ -63,9 +65,9 @@ class ClozeTemplateConfig:
Answer:` <doc_to_target(doc)>` Answer:` <doc_to_target(doc)>`
""" """
doc_to_text: Union[str, Callable[[dict], str]] doc_to_text: str | Callable[[dict], str]
doc_to_choice: Union[str, list, Callable[[dict], list]] doc_to_choice: str | list | Callable[[dict], list]
doc_to_target: Union[int, Callable[[dict], int]] doc_to_target: int | Callable[[dict], int]
template: str = "cloze" template: str = "cloze"
description: str = "" description: str = ""
context_prefix: str = "Question:" context_prefix: str = "Question:"
...@@ -73,9 +75,9 @@ class ClozeTemplateConfig: ...@@ -73,9 +75,9 @@ class ClozeTemplateConfig:
context_delimiter: str = "\n" context_delimiter: str = "\n"
answer_suffix: str = "Answer:" answer_suffix: str = "Answer:"
target_delimiter: str = " " target_delimiter: str = " "
choice_format: Optional[str] = None choice_format: str | None = None
choice_delimiter: Optional[str] = None choice_delimiter: str | None = None
fewshot_delimiter: str = "\n\n" fewshot_delimiter: str = "\n\n"
metric_list: Optional[list["MetricConfig"]] = field( metric_list: list[MetricConfig] | None = field(
default_factory=lambda: ["acc", "acc_norm"] default_factory=lambda: ["acc", "acc_norm"]
) )
from __future__ import annotations
from inspect import getsource from inspect import getsource
from typing import Any, Callable, Union from typing import Any, Callable
def serialize_callable( def serialize_callable(
value: Union[Callable[..., Any], str], keep_callable=False value: Callable[..., Any] | str, keep_callable=False
) -> Union[Callable[..., Any], str]: ) -> Callable[..., Any] | str:
"""Serializes a given function or string. """Serializes a given function or string.
If 'keep_callable' is True, the original callable is returned. If 'keep_callable' is True, the original callable is returned.
...@@ -20,9 +22,7 @@ def serialize_callable( ...@@ -20,9 +22,7 @@ def serialize_callable(
return str(value) return str(value)
def maybe_serialize( def maybe_serialize(val: Callable | Any, keep_callable=False) -> Callable | Any:
val: Union[Callable, Any], keep_callable=False
) -> Union[Callable, Any]:
"""Conditionally serializes a value if it is callable.""" """Conditionally serializes a value if it is callable."""
return ( return (
......
import re import re
import sys import sys
import unicodedata import unicodedata
from collections.abc import Iterable
from lm_eval.api.filter import Filter from lm_eval.api.filter import Filter
from lm_eval.api.registry import register_filter from lm_eval.api.registry import register_filter
...@@ -32,7 +33,9 @@ class RegexFilter(Filter): ...@@ -32,7 +33,9 @@ class RegexFilter(Filter):
self.group_select = group_select self.group_select = group_select
self.fallback = fallback self.fallback = fallback
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]: def apply(
self, resps: Iterable[list[str]], docs: Iterable[dict]
) -> Iterable[list[str]]:
# here, we assume we have a list, in which each element is # here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair. # a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets) # so we process each of these (same input/target response sets)
...@@ -59,59 +62,13 @@ class RegexFilter(Filter): ...@@ -59,59 +62,13 @@ class RegexFilter(Filter):
return filtered_resps return filtered_resps
@register_filter("regex_pos")
class POSFilter(Filter):
""" """
def __init__(
self,
regex_pattern: str = r"\['(.*?)'\]",
group_select=0,
fallback=None,
**kwargs,
) -> None:
"""
pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located.
"""
super().__init__(**kwargs)
if fallback is None:
fallback = ["invalid"]
self.regex_pattern = regex_pattern
self.regex = re.compile(regex_pattern)
self.group_select = group_select
self.fallback = fallback
def apply(self, resps, docs):
def extract_tagged_tokens(text):
# Extract tagged tokens list from text input using regex
tokens = re.findall(r"\('([^']*)', '([^']*)'\)", text)
return [(token, pos) for token, pos in tokens]
def extract_pos_tags(result):
pos_tags = []
if isinstance(result, str):
result = extract_tagged_tokens(result)
pos_tags.extend(pos for _, pos in result)
return pos_tags if pos_tags else self.fallback
def filter_set(inst):
filtered = []
for resp in inst:
match = extract_pos_tags(resp)
filtered.append(match)
return filtered
filtered_resps = map(lambda x: filter_set(x), resps)
return filtered_resps
@register_filter("remove_whitespace") @register_filter("remove_whitespace")
class WhitespaceFilter(Filter): class WhitespaceFilter(Filter):
"""Filters out leading whitespace from responses.""" """Filters out leading whitespace from responses."""
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]: def apply(
self, resps: Iterable[list[str]], docs: Iterable[dict]
) -> Iterable[list[str]]:
def filter_set(inst): def filter_set(inst):
filtered_resp = [] filtered_resp = []
for resp in inst: for resp in inst:
...@@ -156,7 +113,9 @@ class MultiChoiceRegexFilter(RegexFilter): ...@@ -156,7 +113,9 @@ class MultiChoiceRegexFilter(RegexFilter):
self.ignore_punctuation = ignore_punctuation self.ignore_punctuation = ignore_punctuation
self.regexes_to_ignore = regexes_to_ignore self.regexes_to_ignore = regexes_to_ignore
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]: def apply(
self, resps: Iterable[list[str]], docs: Iterable[dict]
) -> Iterable[list[str]]:
# here, we assume we have a list, in which each element is # here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair. # a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets) # so we process each of these (same input/target response sets)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment