Commit 69d14fb3 authored by Baber

cleanup

parent 57b8c0b1
from abc import ABC, abstractmethod
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Iterable, List, Union
from lm_eval.api.instance import Instance
......@@ -20,7 +20,9 @@ class Filter(ABC):
"""
@abstractmethod
def apply(self, resps: Union[List, Iterable], docs: List[dict]) -> Iterable:
def apply(
self, resps: Iterable[list[str]], docs: Iterable[dict]
) -> Iterable[list[str]]:
"""
Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
......@@ -40,9 +42,9 @@ class FilterEnsemble:
"""
name: str
filters: List[type[Filter]]
filters: list[type[Filter]]
def apply(self, instances: List[Instance]) -> None:
def apply(self, instances: list[Instance]) -> None:
resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
resps, docs = list(resps), list(docs)
......
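A side note on the widened `Filter.apply` contract above: a custom filter now only has to accept, and yield, one list of response strings per document, in input order. A minimal sketch under that assumption (the `LowercaseFilter` name and behaviour are illustrative, not part of this commit):

```python
from collections.abc import Iterable

from lm_eval.api.filter import Filter


class LowercaseFilter(Filter):
    """Hypothetical filter matching the new apply() signature."""

    def apply(
        self, resps: Iterable[list[str]], docs: Iterable[dict]
    ) -> Iterable[list[str]]:
        # One inner list of model responses per document; order is preserved.
        return [[r.lower() for r in doc_resps] for doc_resps in resps]
```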
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Callable, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Callable
if TYPE_CHECKING:
......@@ -36,13 +38,14 @@ def register_model(*names):
return decorate
def get_model(model_name: str) -> type["LM"]:
def get_model(model_name: str) -> type[LM]:
try:
return MODEL_REGISTRY[model_name]
except KeyError:
raise ValueError(
f"Attempted to load model '{model_name}', but no model for this name found! Supported model names: {', '.join(MODEL_REGISTRY.keys())}"
)
except KeyError as err:
available_models = ", ".join(MODEL_REGISTRY.keys())
raise KeyError(
f"Model '{model_name}' not found. Available models: {available_models}"
) from err
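Worth flagging: `get_model` now raises `KeyError` (chained from the registry lookup) instead of `ValueError`, so callers that catch the old exception type need updating. A quick sketch of the new behaviour, using a deliberately unknown model name:

```python
from lm_eval.api.registry import get_model

try:
    get_model("definitely-not-registered")
except KeyError as err:
    # The message now lists the available models, and the original registry
    # KeyError is kept as err.__cause__ thanks to `raise ... from err`.
    print(err)
```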
TASK_REGISTRY = {}
......@@ -81,7 +84,7 @@ def register_group(name):
OUTPUT_TYPE_REGISTRY = {}
METRIC_REGISTRY = {}
METRIC_AGGREGATION_REGISTRY = {}
AGGREGATION_REGISTRY: Dict[str, Callable[[], Dict[str, Callable]]] = {}
AGGREGATION_REGISTRY: dict[str, Callable[[], dict[str, Callable]]] = {}
HIGHER_IS_BETTER_REGISTRY = {}
FILTER_REGISTRY = {}
......@@ -125,7 +128,7 @@ def register_metric(**args):
return decorate
def get_metric(name: str, hf_evaluate_metric=False) -> Optional[Callable]:
def get_metric(name: str, hf_evaluate_metric=False) -> Callable[..., Any] | None:
if not hf_evaluate_metric:
if name in METRIC_REGISTRY:
return METRIC_REGISTRY[name]
......@@ -157,21 +160,21 @@ def register_aggregation(name: str):
return decorate
def get_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callable]]]:
def get_aggregation(name: str) -> Callable[..., Any] | None:
try:
return AGGREGATION_REGISTRY[name]
except KeyError:
eval_logger.warning(f"{name} not a registered aggregation metric!")
def get_metric_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callable]]]:
def get_metric_aggregation(name: str) -> Callable[[], dict[str, Callable]] | None:
try:
return METRIC_AGGREGATION_REGISTRY[name]
except KeyError:
eval_logger.warning(f"{name} metric is not assigned a default aggregation!")
def is_higher_better(metric_name: str) -> Optional[bool]:
def is_higher_better(metric_name: str) -> bool | None:
try:
return HIGHER_IS_BETTER_REGISTRY[metric_name]
except KeyError:
......@@ -192,7 +195,7 @@ def register_filter(name: str):
return decorate
def get_filter(filter_name: Union[str, Callable]) -> Callable:
def get_filter(filter_name: str | Callable) -> Callable:
try:
return FILTER_REGISTRY[filter_name]
except KeyError as e:
......
from __future__ import annotations
import abc
import ast
import logging
......@@ -8,15 +10,7 @@ from copy import deepcopy
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
Iterator,
List,
Literal,
Mapping,
Optional,
Tuple,
Union,
)
import datasets
......@@ -57,23 +51,23 @@ class Task(abc.ABC):
{"question": ..., question, answer)
"""
VERSION: Optional[Union[int, str]] = None
VERSION: int | str | None = None
# The name of the `Task` benchmark as denoted in the HuggingFace datasets Hub
# or a path to a custom `datasets` loading script.
DATASET_PATH: Optional[str] = None
DATASET_PATH: str | None = None
# The name of a subset within `DATASET_PATH`.
DATASET_NAME: Optional[str] = None
DATASET_NAME: str | None = None
OUTPUT_TYPE: Optional[OutputType] = None
OUTPUT_TYPE: OutputType | None = None
def __init__(
self,
data_dir: Optional[str] = None,
cache_dir: Optional[str] = None,
download_mode: Optional[datasets.DownloadMode] = None,
config: Optional[Mapping] = None, # Union[dict, TaskConfig]
data_dir: str | None = None,
cache_dir: str | None = None,
download_mode: datasets.DownloadMode | None = None,
config: Mapping | None = None, # Union[dict, TaskConfig]
) -> None:
"""
:param data_dir: str
......@@ -97,21 +91,21 @@ class Task(abc.ABC):
Fresh download and fresh dataset.
"""
self.download(data_dir, cache_dir, download_mode)
self._training_docs: Optional[list] = None
self._fewshot_docs: Optional[list] = None
self._instances: Optional[List[Instance]] = None
self._training_docs: list | None = None
self._fewshot_docs: list | None = None
self._instances: list[Instance] | None = None
self._config: TaskConfig = TaskConfig.from_yaml({**config})
self._filters = [build_filter_ensemble("none", [("take_first", None)])]
self.fewshot_rnd: Optional[random.Random] = (
self.fewshot_rnd: random.Random | None = (
None # purposely induce errors in case of improper usage
)
def download(
self,
data_dir: Optional[str] = None,
cache_dir: Optional[str] = None,
data_dir: str | None = None,
cache_dir: str | None = None,
download_mode=None,
) -> None:
"""Downloads and returns the task dataset.
......@@ -238,7 +232,7 @@ class Task(abc.ABC):
pass
@abc.abstractmethod
def doc_to_target(self, doc: dict) -> Union[str, int]:
def doc_to_target(self, doc: dict) -> str | int:
pass
# not an abstractmethod because not every language-only task has to implement this
......@@ -254,16 +248,16 @@ class Task(abc.ABC):
def build_all_requests(
self,
*,
limit: Union[int, None] = None,
samples: Optional[List[int]] = None,
limit: int | None = None,
samples: list[int] | None = None,
rank: int = 0,
world_size: int = 1,
cache_requests: bool = False,
rewrite_requests_cache: bool = False,
system_instruction: Optional[str] = None,
system_instruction: str | None = None,
apply_chat_template: bool = False,
fewshot_as_multiturn: bool = False,
chat_template: Optional[Callable] = None,
chat_template: Callable | None = None,
tokenizer_name: str = "",
) -> None:
"""Build a set of Instances for a task, and store them in task.instances"""
......@@ -365,7 +359,7 @@ class Task(abc.ABC):
save_to_cache(file_name=cache_key, obj=instances)
@abc.abstractmethod
def construct_requests(self, doc: dict, ctx: Union[list[dict], str], **kwargs):
def construct_requests(self, doc: dict, ctx: list[dict] | str, **kwargs):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
......@@ -405,7 +399,7 @@ class Task(abc.ABC):
A dictionary where keys are the names of submetrics and values are
functions that aggregate a list of metric scores
"""
pass
return True
@deprecated("not used anymore")
def higher_is_better(self):
......@@ -414,7 +408,7 @@ class Task(abc.ABC):
A dictionary where keys are the names of submetrics and values are
whether a higher value of the submetric is better
"""
pass
return True
def get_config(self, key: str) -> Any:
return getattr(self._config, key, None)
......@@ -488,13 +482,15 @@ class Task(abc.ABC):
example = self.doc_to_text(doc)
return description + labeled_examples + example
def apply_filters(self) -> Optional[List[Instance]]:
def apply_filters(self) -> list[Instance] | None:
"""Iterates over FilterEnsembles and applies them to instances"""
if hasattr(self, "_filters"):
if hasattr(self, "_filters") and self._instances:
for f in self._filters:
f.apply(self._instances)
else:
eval_logger.warning("No filter defined, passing through instances")
eval_logger.warning(
"No filter defined or no instances, passing through instances"
)
return self._instances
def dump_config(self) -> dict:
......@@ -505,9 +501,6 @@ class Task(abc.ABC):
def set_config(self, key: str, value: Any, update: bool = False) -> None:
"""Set or update the configuration for a given key."""
if key is None:
raise ValueError("Key must be provided.")
if update:
current_value = getattr(self._config, key, {})
if not isinstance(current_value, dict):
......@@ -533,13 +526,13 @@ class Task(abc.ABC):
setattr(self._config, "metric_list", [MetricConfig(name=metric_name)])
setattr(self._config, "process_results", lambda *args: {"bypass": 0})
def set_fewshot_seed(self, seed: Optional[int] = None) -> None:
def set_fewshot_seed(self, seed: int | None = None) -> None:
self.fewshot_rnd = random.Random(seed)
if hasattr(self, "sampler"):
self.sampler.rnd = self.fewshot_rnd
@property
def eval_docs(self) -> Union[datasets.Dataset, Iterable[dict]]:
def eval_docs(self) -> datasets.Dataset | Iterable[dict]:
if self.has_test_docs():
return self.test_docs()
elif self.has_validation_docs():
......@@ -553,13 +546,13 @@ class Task(abc.ABC):
self,
*,
rank: int = 0,
limit: Union[int, None] = None,
limit: int | None = None,
world_size: int = 1,
samples: Optional[List[int]] = None,
) -> Iterator[Tuple[int, Any]]:
samples: list[int] | None = None,
) -> Iterator[tuple[int, Any]]:
if samples:
n = len(self.eval_docs)
assert all([e < n for e in samples]), (
assert all(e < n for e in samples), (
f"Elements of --samples should be in the interval [0,k-1] where k is the number of total examples. In this case, k={n}."
)
eval_logger.info(
......@@ -592,7 +585,7 @@ class ConfigurableTask(Task):
data_dir=None,
cache_dir=None,
download_mode=None,
config: Optional[dict] = None,
config: dict | None = None,
) -> None:
# Get pre-configured attributes
self._config = self.CONFIG
......@@ -610,9 +603,8 @@ class ConfigurableTask(Task):
"Must pass a config to ConfigurableTask, either in cls.CONFIG or `config` kwarg"
)
if isinstance(self.config.metadata, dict):
if "version" in self.config.metadata:
self.VERSION = self.config.metadata["version"]
if isinstance(self.config.metadata, dict) and "version" in self.config.metadata:
self.VERSION = self.config.metadata["version"]
if self.config.output_type is not None:
if self.config.output_type not in ALL_OUTPUT_TYPES:
......@@ -698,18 +690,13 @@ class ConfigurableTask(Task):
else:
test_target = str(test_target)
if test_choice is not None:
check_choices = test_choice
else:
check_choices = [test_target]
check_choices = test_choice if test_choice is not None else [test_target]
if self.config.doc_to_choice is not None:
for choice in check_choices:
choice_has_whitespace = True if choice[0].isspace() else False
choice_has_whitespace = choice[0].isspace()
delimiter_has_whitespace = (
True
if self.config.target_delimiter.rstrip()
self.config.target_delimiter.rstrip()
!= self.config.target_delimiter
else False
)
if delimiter_has_whitespace and choice_has_whitespace:
......@@ -722,7 +709,7 @@ class ConfigurableTask(Task):
)
def download(
self, dataset_kwargs: Optional[Dict[str, Any]] = None, **kwargs
self, dataset_kwargs: dict[str, Any] | None = None, **kwargs
) -> None:
from packaging.version import parse as vparse
......@@ -746,24 +733,15 @@ class ConfigurableTask(Task):
)
def has_training_docs(self) -> bool:
if self.config.training_split is not None:
return True
else:
return False
return self.config.training_split is not None
def has_validation_docs(self) -> bool:
if self.config.validation_split is not None:
return True
else:
return False
return self.config.validation_split is not None
def has_test_docs(self) -> bool:
if self.config.test_split is not None:
return True
else:
return False
return self.config.test_split is not None
def training_docs(self) -> Optional[datasets.Dataset]:
def training_docs(self) -> datasets.Dataset | None:
if self.has_training_docs():
if self.config.process_docs is not None:
return self.config.process_docs(
......@@ -771,7 +749,7 @@ class ConfigurableTask(Task):
)
return self.dataset[self.config.training_split]
def validation_docs(self) -> Optional[datasets.Dataset]:
def validation_docs(self) -> datasets.Dataset | None:
if self.has_validation_docs():
if self.config.process_docs is not None:
return self.config.process_docs(
......@@ -779,7 +757,7 @@ class ConfigurableTask(Task):
)
return self.dataset[self.config.validation_split]
def test_docs(self) -> Optional[datasets.Dataset]:
def test_docs(self) -> datasets.Dataset | None:
if self.has_test_docs():
if self.config.process_docs is not None:
return self.config.process_docs(self.dataset[self.config.test_split])
......@@ -792,22 +770,25 @@ class ConfigurableTask(Task):
return docs
# Fallback to parent implementation
if _num_fewshot := getattr(self.config, "num_fewshot"):
if isinstance(_num_fewshot, int) and _num_fewshot > 0:
eval_logger.warning(
f"[Task: {self.config.task}] "
"num_fewshot > 0 but no fewshot source configured. "
"Using preconfigured rule."
)
if (
(_num_fewshot := self.config.num_fewshot)
and isinstance(_num_fewshot, int)
and _num_fewshot > 0
):
eval_logger.warning(
f"[Task: {self.config.task}] "
"num_fewshot > 0 but no fewshot source configured. "
"Using preconfigured rule."
)
return super().fewshot_docs()
@staticmethod
def append_target_question(
labeled_examples: List[Dict[str, str]],
labeled_examples: list[dict[str, str]],
question: str,
fewshot_as_multiturn: bool = False,
gen_prefix: Optional[str] = None,
gen_prefix: str | None = None,
) -> None:
"""Adds a target question to the labeled examples list.
If fewshot_as_multiturn is True, or labeled_examples is empty, or the last entry is a system turn, appends the question as a new user entry.
......@@ -831,12 +812,12 @@ class ConfigurableTask(Task):
self,
doc: dict,
num_fewshot: int,
system_instruction: Optional[str] = None,
system_instruction: str | None = None,
apply_chat_template: bool = False,
fewshot_as_multiturn: bool = False,
chat_template: Optional[Callable] = None,
gen_prefix: Optional[str] = None,
) -> Union[str, List[str], None]:
chat_template: Callable | None = None,
gen_prefix: str | None = None,
) -> str | list[str] | None:
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
......@@ -857,10 +838,7 @@ class ConfigurableTask(Task):
:returns: str
The fewshot context.
"""
if apply_chat_template:
labeled_examples = []
else:
labeled_examples = ""
labeled_examples = [] if apply_chat_template else ""
# get task description
if description := self.config.description:
......@@ -930,7 +908,7 @@ class ConfigurableTask(Task):
labeled_examples_list.append(
chat_template(
chat,
add_generation_prompt=False if gen_prefix else True,
add_generation_prompt=not gen_prefix,
)
)
return labeled_examples_list
......@@ -954,7 +932,7 @@ class ConfigurableTask(Task):
# return lm.apply_chat_template(labeled_examples)
return chat_template(
labeled_examples,
add_generation_prompt=False if gen_prefix else True,
add_generation_prompt=not gen_prefix,
)
else:
prefix = (
......@@ -975,7 +953,7 @@ class ConfigurableTask(Task):
else:
return labeled_examples + str(example) + prefix
def apply_filters(self) -> Optional[List[Instance]]:
def apply_filters(self) -> list[Instance] | None:
"""Iterates over FilterEnsembles and applies them to instances"""
if hasattr(self, "_filters"):
for f in self._filters:
......@@ -1015,9 +993,7 @@ class ConfigurableTask(Task):
"""
return doc
def doc_to_text(
self, doc: dict, doc_to_text: Union[int, str, Callable, None] = None
):
def doc_to_text(self, doc: dict, doc_to_text: int | str | Callable | None = None):
# if self.prompt is not None:
# doc_to_text = self.prompt
if doc_to_text is not None:
......@@ -1053,9 +1029,7 @@ class ConfigurableTask(Task):
print(type(doc_to_text))
raise TypeError
def doc_to_target(
self, doc: dict, doc_to_target=None
) -> Union[int, str, list[int]]:
def doc_to_target(self, doc: dict, doc_to_target=None) -> int | str | list[int]:
# if self.prompt is not None:
# doc_to_target = self.prompt
if doc_to_target is not None:
......@@ -1104,8 +1078,8 @@ class ConfigurableTask(Task):
def doc_to_choice(
self,
doc: dict,
doc_to_choice: Union[str, list, dict, Callable[..., list[str]], None] = None,
) -> List[str]:
doc_to_choice: str | list | dict | Callable[..., list[str]] | None = None,
) -> list[str]:
# if self.prompt is not None:
# doc_to_choice = self.prompt
if doc_to_choice is not None:
......@@ -1132,7 +1106,7 @@ class ConfigurableTask(Task):
else:
raise TypeError
def doc_to_image(self, doc: dict, doc_to_image=None) -> Union[int, str, list, None]:
def doc_to_image(self, doc: dict, doc_to_image=None) -> int | str | list | None:
if doc_to_image is not None:
doc_to_image = doc_to_image
elif self.config.doc_to_image is not None:
......@@ -1155,7 +1129,7 @@ class ConfigurableTask(Task):
else:
return None
def doc_to_audio(self, doc: Any, doc_to_audio=None) -> Union[int, str, list, None]:
def doc_to_audio(self, doc: Any, doc_to_audio=None) -> int | str | list | None:
if doc_to_audio is not None:
doc_to_audio = doc_to_audio
elif self.config.doc_to_audio is not None:
......@@ -1178,7 +1152,7 @@ class ConfigurableTask(Task):
else:
return None
def doc_to_prefix(self, doc: dict) -> Optional[str]:
def doc_to_prefix(self, doc: dict) -> str | None:
if (gen_prefix := self.config.gen_prefix) is not None:
if gen_prefix in self.features:
return doc[gen_prefix]
......@@ -1188,7 +1162,7 @@ class ConfigurableTask(Task):
def construct_requests(
self, doc: dict, ctx: str, **kwargs
) -> Union[List[Instance], Instance]:
) -> list[Instance] | Instance:
apply_chat_template = kwargs.pop("apply_chat_template", False)
chat_template: Callable | None = kwargs.pop("chat_template", None)
......@@ -1324,7 +1298,7 @@ class ConfigurableTask(Task):
elif self.OUTPUT_TYPE == "multiple_choice":
lls, is_greedy = zip(*results)
# retrieve choices in List[str] form, to compute choice lengths, etc.
# retrieve choices in list[str] form, to compute choice lengths, etc.
choices = self.doc_to_choice(doc)
completion_len = np.array([float(len(i)) for i in choices])
......@@ -1371,7 +1345,7 @@ class ConfigurableTask(Task):
if self.multiple_target:
acc = 1.0 if pred in gold else 0.0
acc_norm = 1.0 if pred_norm in gold else 0.0
exact_match = int(any([is_greedy[i] if i != -100 else 0 for i in gold]))
exact_match = int(any(is_greedy[i] if i != -100 else 0 for i in gold))
else:
acc = 1.0 if pred == gold else 0.0
acc_norm = 1.0 if pred_norm == gold else 0.0
......@@ -1413,7 +1387,7 @@ class ConfigurableTask(Task):
# it assumes that doc_to_target returns a number.
choices = self.doc_to_choice(doc)
gold = choices[gold]
for metric in self._metric_fn_list.keys():
for metric in self._metric_fn_list:
try:
result_score = self._metric_fn_list[metric](
references=[gold] if not isinstance(gold, list) else gold,
......@@ -1447,7 +1421,7 @@ class ConfigurableTask(Task):
return getattr(self._config, key, None)
@property
def task_name(self) -> Optional[str]:
def task_name(self) -> str | None:
return getattr(self.config, "task", None)
def __repr__(self):
......@@ -1465,7 +1439,7 @@ class MultipleChoiceTask(Task):
def doc_to_target(self, doc: dict) -> str:
return " " + doc["choices"][doc["gold"]]
def construct_requests(self, doc: dict, ctx: str, **kwargs) -> List[Instance]:
def construct_requests(self, doc: dict, ctx: str, **kwargs) -> list[Instance]:
# TODO: add mutual info here?
return [
Instance(
......@@ -1478,7 +1452,7 @@ class MultipleChoiceTask(Task):
for i, choice in enumerate(doc["choices"])
]
def process_results(self, doc: dict, results: Iterable[Tuple[float, bool]]) -> dict:
def process_results(self, doc: dict, results: Iterable[tuple[float, bool]]) -> dict:
results = [
res[0] for res in results
] # only retain loglikelihoods, discard is_greedy TODO: do we need is_greedy anywhere?
......@@ -1512,7 +1486,7 @@ class PerplexityTask(Task):
def has_training_docs(self) -> bool:
return False
def fewshot_examples(self, k: int, rnd) -> List:
def fewshot_examples(self, k: int, rnd) -> list:
if k != 0:
raise ValueError(
"The number of fewshot examples must be 0 for perplexity tasks."
......@@ -1543,7 +1517,7 @@ class PerplexityTask(Task):
def doc_to_target(self, doc):
return doc
def construct_requests(self, doc: dict, ctx: Optional[str], **kwargs):
def construct_requests(self, doc: dict, ctx: str | None, **kwargs):
if bool(ctx):
raise ValueError
......@@ -1555,7 +1529,7 @@ class PerplexityTask(Task):
**kwargs,
)
def process_results(self, doc: dict, results: Tuple[float]) -> dict:
def process_results(self, doc: dict, results: tuple[float]) -> dict:
(loglikelihood,) = results
words = self.count_words(self.doc_to_target(doc))
bytes_ = self.count_bytes(self.doc_to_target(doc))
......
from __future__ import annotations
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from functools import cached_property
from typing import Any, Callable, List, Optional
from typing import Any
@dataclass
......@@ -8,9 +11,9 @@ class MetricConfig:
"""Encapsulates information about a single metric."""
name: str
fn: Optional[Callable] = None
kwargs: Optional[dict] = None
aggregation_fn: Optional[Callable] = None
fn: Callable | None = None
kwargs: Mapping[str, Any] | None = None
aggregation_fn: Callable | None = None
higher_is_better: bool = True
hf_evaluate: bool = False
is_elementwise: bool = True
......@@ -20,7 +23,7 @@ class MetricConfig:
return self.name
@cached_property
def aggregation(self) -> Callable:
def aggregation(self) -> Callable[..., Any] | None:
from lm_eval.api.registry import get_aggregation
if self.aggregation_fn is None:
......@@ -28,7 +31,7 @@ class MetricConfig:
return self.aggregation_fn
@cached_property
def _higher_is_better(self) -> bool:
def _higher_is_better(self) -> bool | None:
from lm_eval.api.registry import is_higher_better
if self.higher_is_better is None:
......@@ -39,10 +42,10 @@ class MetricConfig:
"""Calculates the metric using the provided function and arguments."""
if self.fn is None:
raise ValueError(f"Metric function for {self.name} is not defined.")
return self.fn(*args, **{**self.kwargs, **kwargs})
return self.fn(*args, **{**(self.kwargs or {}), **kwargs})
def compute_aggregation(self, values: List[Any]) -> Any:
def compute_aggregation(self, *args, **kwargs) -> Any:
"""Computes the aggregation of the metric values."""
if self.aggregation_fn is None:
raise ValueError(f"Aggregation function for {self.name} is not defined.")
return self.aggregation_fn(values)
return self.aggregation_fn(*args, **kwargs)
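Two small behavioural changes in `MetricConfig` worth a note: `compute()` no longer fails when `kwargs` is left as `None`, and `compute_aggregation()` now forwards whatever it is given instead of requiring a single `values` list. A rough usage sketch (constructor values are illustrative, and the `MetricConfig` import path is assumed):

```python
import statistics

# MetricConfig import path assumed; adjust to this dataclass's actual module.
cfg = MetricConfig(
    name="exact_match",
    fn=lambda references, predictions: float(references == predictions),
    kwargs=None,                    # safe now: compute() falls back to {}
    aggregation_fn=statistics.mean,
)

score = cfg.compute(references=["yes"], predictions=["yes"])  # -> 1.0
mean_score = cfg.compute_aggregation([1.0, 0.0, 1.0])         # args pass straight through
```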
from __future__ import annotations
import logging
from collections.abc import Iterable
from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union
from typing import TYPE_CHECKING, Callable
from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.instance import OutputType
......@@ -20,8 +23,8 @@ class RepeatConfig:
"""Encapsulates information about a single repeat."""
repeats: int = 1
metric_fn: Union[str, Callable] = "pass@N"
kwargs: Optional[dict] = field(default_factory=dict)
metric_fn: str | Callable = "pass@N"
kwargs: dict | None = field(default_factory=dict)
@dataclass
......@@ -38,11 +41,11 @@ class FewshotConfig:
# hack: this returns task.config.num_fewshot
# to keep in sync as it is runtime-modified
num_fewshot: Callable[[], int]
split: Optional[str] = None
sampler: Union[str, Callable] = "default"
samples: Union[Callable[[], list[dict]], list[dict], None] = None
process_docs: Optional[Callable[[list[dict]], Iterable[dict]]] = None
fewshot_indices: Optional[list[int]] = None
split: str | None = None
sampler: str | Callable = "default"
samples: Callable[[], list[dict]] | list[dict] | None = None
process_docs: Callable[[list[dict]], Iterable[dict]] | None = None
fewshot_indices: list[int] | None = None
rnd: int = field(init=False, default=False)
def __post_init__(self) -> None:
......@@ -65,22 +68,20 @@ class FewshotConfig:
def _get_raw_docs(
self, dataset
) -> Union[list[dict], Callable[[], Iterable[dict]], None]:
) -> list[dict] | Callable[[], Iterable[dict]] | None:
"""Get raw documents from configured source."""
if self.split is not None:
return dataset[self.split]
if self.samples is not None:
if isinstance(self.samples, list):
return self.samples
elif callable(self.samples):
if isinstance(self.samples, list) or callable(self.samples):
return self.samples
else:
raise TypeError(
"samples must be either a list of dicts or a callable returning a list"
)
def get_docs(self, dataset) -> Optional[Iterable[dict]]:
def get_docs(self, dataset) -> Iterable[dict] | None:
"""Get processed documents from configured source."""
raw_docs = self._get_raw_docs(dataset)
if raw_docs is None:
......@@ -100,8 +101,8 @@ class FewshotConfig:
return self.sampler
def init_sampler(
self, docs: list[dict], task: "Task", rnd=None, fewshot_indices=None
) -> "ContextSampler":
self, docs: list[dict], task: Task, rnd=None, fewshot_indices=None
) -> ContextSampler:
"""Initialize the sampler with the given documents and task."""
if rnd is None:
raise ValueError(
......@@ -120,49 +121,49 @@ class FewshotConfig:
@dataclass
class TaskConfig(dict):
# task naming/registry
task: Optional[str] = None
task_alias: Optional[str] = None
tag: Optional[Union[str, list]] = None
task: str | None = None
task_alias: str | None = None
tag: str | list | None = None
# HF dataset options.
# which dataset to use,
# and what splits for what purpose
custom_dataset: Optional[Callable] = None
dataset_path: Optional[str] = None
dataset_name: Optional[str] = None
dataset_kwargs: Optional[dict] = field(default_factory=dict)
training_split: Optional[str] = None
validation_split: Optional[str] = None
test_split: Optional[str] = None
fewshot_split: Optional[str] = (
custom_dataset: Callable | None = None
dataset_path: str | None = None
dataset_name: str | None = None
dataset_kwargs: dict | None = field(default_factory=dict)
training_split: str | None = None
validation_split: str | None = None
test_split: str | None = None
fewshot_split: str | None = (
None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?)
)
# formatting / prompting options.
# see docs/advanced_task_guide.md for more info
process_docs: Optional[Callable] = None
doc_to_text: Optional[Union[Callable, str]] = None
doc_to_target: Optional[Union[Callable, str]] = None
doc_to_image: Union[Callable, str, None] = None
doc_to_audio: Union[Callable, str, None] = None
process_docs: Callable | None = None
doc_to_text: Callable | str | None = None
doc_to_target: Callable | str | None = None
doc_to_image: Callable | str | None = None
doc_to_audio: Callable | str | None = None
unsafe_code: bool = False
doc_to_choice: Optional[Union[Callable, str, dict, list]] = None
process_results: Optional[Union[Callable, str]] = None
use_prompt: Optional[str] = None
doc_to_choice: Callable | str | dict | list | None = None
process_results: Callable | str | None = None
use_prompt: str | None = None
description: str = ""
target_delimiter: str = " "
fewshot_delimiter: str = "\n\n"
fewshot_config: Optional[dict] = None
fewshot_config: dict | None = None
# runtime configuration options
num_fewshot: Optional[int] = 0
generation_kwargs: Optional[dict] = None
num_fewshot: int | None = 0
generation_kwargs: dict | None = None
# scoring options
metric_list: Optional[list] = None
metric_list: list | None = None
output_type: OutputType = "generate_until"
repeats: int = 1
filter_list: Optional[list[dict]] = None
filter_list: list[dict] | None = None
should_decontaminate: bool = False
doc_to_decontamination_query: Optional[str] = None
gen_prefix: Optional[str] = None
metadata: Optional[dict] = field(
doc_to_decontamination_query: str | None = None
gen_prefix: str | None = None
metadata: dict | None = field(
default_factory=dict
) # by default, not used in the code. allows for users to pass arbitrary info to tasks
......@@ -215,9 +216,7 @@ class TaskConfig(dict):
fewshot_indices=_fewshot_cfg.get("fewshot_indices", None),
)
def _get_metric(
self, metric_list: Optional[list[dict]] = None
) -> list["MetricConfig"]:
def _get_metric(self, metric_list: list[dict] | None = None) -> list[MetricConfig]:
from lm_eval.api.registry import (
AGGREGATION_REGISTRY,
DEFAULT_METRIC_REGISTRY,
......@@ -314,7 +313,7 @@ class TaskConfig(dict):
return metrics
@property
def get_filters(self) -> list["FilterConfig"]:
def get_filters(self) -> list[FilterConfig]:
from lm_eval.filters import build_filter_ensemble
if not self.filter_list:
......@@ -354,7 +353,7 @@ class TaskConfig(dict):
return x
@classmethod
def from_yaml(cls, data: dict) -> "TaskConfig":
def from_yaml(cls, data: dict) -> TaskConfig:
"""Create a TaskConfig instance from a YAML-like dictionary."""
return cls(**data)
......
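Since `from_yaml` simply does `cls(**data)`, a task YAML parsed into a plain dict maps directly onto the dataclass fields above. A minimal sketch (field values are illustrative, and the `TaskConfig` import path is assumed):

```python
# TaskConfig import path assumed; adjust to this dataclass's actual module.
cfg = TaskConfig.from_yaml(
    {
        "task": "demo_boolq",
        "dataset_path": "super_glue",
        "dataset_name": "boolq",
        "test_split": "validation",
        "output_type": "multiple_choice",
        "doc_to_text": "{{passage}}\nQuestion: {{question}}?\nAnswer:",
        "doc_to_choice": ["no", "yes"],
        "doc_to_target": "label",
    }
)
print(cfg.task, cfg.num_fewshot)  # "demo_boolq" 0
```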
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Callable, Optional, Union
from typing import TYPE_CHECKING, Callable
if TYPE_CHECKING:
......@@ -11,19 +13,19 @@ class TemplateConfig:
"""Encapsulates information about a template."""
template: str
doc_to_text: Union[str, Callable[[dict], str]]
doc_to_choice: Union[str, list, Callable[[dict], list]]
doc_to_target: Union[int, Callable[[dict], int]]
doc_to_text: str | Callable[[dict], str]
doc_to_choice: str | list | Callable[[dict], list]
doc_to_target: int | Callable[[dict], int]
description: str
context_prefix: str
prefix_delimiter: str
context_delimiter: str
answer_suffix: str
target_delimiter: str
choice_format: Optional[str]
choice_delimiter: Optional[str]
choice_format: str | None
choice_delimiter: str | None
fewshot_delimiter: str
metric_list: Optional[Union[list[str], list["MetricConfig"]]] = field(
metric_list: list[str] | list[MetricConfig] | None = field(
default_factory=lambda: ["acc", "acc_norm"]
)
......@@ -40,19 +42,19 @@ class MCQTemplateConfig:
Answer:` doc_to_choice(doc)` for each choice.
"""
doc_to_text: Union[str, Callable[[dict], str]]
doc_to_choice: Union[str, list, Callable[[dict], list]]
doc_to_target: Union[int, Callable[[dict], int]]
doc_to_text: str | Callable[[dict], str]
doc_to_choice: str | list | Callable[[dict], list]
doc_to_target: int | Callable[[dict], int]
template = "mcq"
context_prefix: str = "Question:"
prefix_delimiter: str = " "
context_delimiter: str = "\n"
answer_suffix: str = "Answer:"
target_delimiter: str = "\n"
choice_format: Optional[str] = "letters"
choice_delimiter: Optional[str] = "\n"
choice_format: str | None = "letters"
choice_delimiter: str | None = "\n"
fewshot_delimiter: str = "\n\n"
metric_list: Optional[list["MetricConfig"]] = field(default_factory=lambda: ["acc"])
metric_list: list[MetricConfig] | None = field(default_factory=lambda: ["acc"])
@dataclass
......@@ -63,9 +65,9 @@ class ClozeTemplateConfig:
Answer:` <doc_to_target(doc)>`
"""
doc_to_text: Union[str, Callable[[dict], str]]
doc_to_choice: Union[str, list, Callable[[dict], list]]
doc_to_target: Union[int, Callable[[dict], int]]
doc_to_text: str | Callable[[dict], str]
doc_to_choice: str | list | Callable[[dict], list]
doc_to_target: int | Callable[[dict], int]
template: str = "cloze"
description: str = ""
context_prefix: str = "Question:"
......@@ -73,9 +75,9 @@ class ClozeTemplateConfig:
context_delimiter: str = "\n"
answer_suffix: str = "Answer:"
target_delimiter: str = " "
choice_format: Optional[str] = None
choice_delimiter: Optional[str] = None
choice_format: str | None = None
choice_delimiter: str | None = None
fewshot_delimiter: str = "\n\n"
metric_list: Optional[list["MetricConfig"]] = field(
metric_list: list[MetricConfig] | None = field(
default_factory=lambda: ["acc", "acc_norm"]
)
from __future__ import annotations
from inspect import getsource
from typing import Any, Callable, Union
from typing import Any, Callable
def serialize_callable(
value: Union[Callable[..., Any], str], keep_callable=False
) -> Union[Callable[..., Any], str]:
value: Callable[..., Any] | str, keep_callable=False
) -> Callable[..., Any] | str:
"""Serializes a given function or string.
If 'keep_callable' is True, the original callable is returned.
......@@ -20,9 +22,7 @@ def serialize_callable(
return str(value)
def maybe_serialize(
val: Union[Callable, Any], keep_callable=False
) -> Union[Callable, Any]:
def maybe_serialize(val: Callable | Any, keep_callable=False) -> Callable | Any:
"""Conditionally serializes a value if it is callable."""
return (
......
import re
import sys
import unicodedata
from collections.abc import Iterable
from lm_eval.api.filter import Filter
from lm_eval.api.registry import register_filter
......@@ -32,7 +33,9 @@ class RegexFilter(Filter):
self.group_select = group_select
self.fallback = fallback
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
def apply(
self, resps: Iterable[list[str]], docs: Iterable[dict]
) -> Iterable[list[str]]:
# here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
......@@ -59,59 +62,13 @@ class RegexFilter(Filter):
return filtered_resps
@register_filter("regex_pos")
class POSFilter(Filter):
""" """
def __init__(
self,
regex_pattern: str = r"\['(.*?)'\]",
group_select=0,
fallback=None,
**kwargs,
) -> None:
"""
pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located.
"""
super().__init__(**kwargs)
if fallback is None:
fallback = ["invalid"]
self.regex_pattern = regex_pattern
self.regex = re.compile(regex_pattern)
self.group_select = group_select
self.fallback = fallback
def apply(self, resps, docs):
def extract_tagged_tokens(text):
# Extract tagged tokens list from text input using regex
tokens = re.findall(r"\('([^']*)', '([^']*)'\)", text)
return [(token, pos) for token, pos in tokens]
def extract_pos_tags(result):
pos_tags = []
if isinstance(result, str):
result = extract_tagged_tokens(result)
pos_tags.extend(pos for _, pos in result)
return pos_tags if pos_tags else self.fallback
def filter_set(inst):
filtered = []
for resp in inst:
match = extract_pos_tags(resp)
filtered.append(match)
return filtered
filtered_resps = map(lambda x: filter_set(x), resps)
return filtered_resps
@register_filter("remove_whitespace")
class WhitespaceFilter(Filter):
"""Filters out leading whitespace from responses."""
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
def apply(
self, resps: Iterable[list[str]], docs: Iterable[dict]
) -> Iterable[list[str]]:
def filter_set(inst):
filtered_resp = []
for resp in inst:
......@@ -156,7 +113,9 @@ class MultiChoiceRegexFilter(RegexFilter):
self.ignore_punctuation = ignore_punctuation
self.regexes_to_ignore = regexes_to_ignore
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
def apply(
self, resps: Iterable[list[str]], docs: Iterable[dict]
) -> Iterable[list[str]]:
# here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
......