Commit 69d14fb3 authored by Baber

cleanup

parent 57b8c0b1
from abc import ABC, abstractmethod
from collections.abc import Iterable
from dataclasses import dataclass
from typing import Iterable, List, Union
from lm_eval.api.instance import Instance
@@ -20,7 +20,9 @@ class Filter(ABC):
"""
@abstractmethod
def apply(self, resps: Union[List, Iterable], docs: List[dict]) -> Iterable:
def apply(
self, resps: Iterable[list[str]], docs: Iterable[dict]
) -> Iterable[list[str]]:
"""
Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
@@ -40,9 +42,9 @@ class FilterEnsemble:
"""
name: str
filters: List[type[Filter]]
filters: list[type[Filter]]
def apply(self, instances: List[Instance]) -> None:
def apply(self, instances: list[Instance]) -> None:
resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
resps, docs = list(resps), list(docs)
......
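An aside, not part of the commit: a minimal sketch of a `Filter` subclass written against the widened `apply` signature above. The `LowercaseFilter` name and behavior are hypothetical.

from collections.abc import Iterable

from lm_eval.api.filter import Filter


class LowercaseFilter(Filter):
    """Hypothetical example: lower-case every response in each response set."""

    def apply(
        self, resps: Iterable[list[str]], docs: Iterable[dict]
    ) -> Iterable[list[str]]:
        # One inner list of model responses per document; the filtered
        # lists are returned in the same order they came in.
        return [[r.lower() for r in resp_set] for resp_set in resps]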
from __future__ import annotations
import logging
from typing import TYPE_CHECKING, Callable, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Callable
if TYPE_CHECKING:
@@ -36,13 +38,14 @@ def register_model(*names):
return decorate
def get_model(model_name: str) -> type["LM"]:
def get_model(model_name: str) -> type[LM]:
try:
return MODEL_REGISTRY[model_name]
except KeyError:
raise ValueError(
f"Attempted to load model '{model_name}', but no model for this name found! Supported model names: {', '.join(MODEL_REGISTRY.keys())}"
)
except KeyError as err:
available_models = ", ".join(MODEL_REGISTRY.keys())
raise KeyError(
f"Model '{model_name}' not found. Available models: {available_models}"
) from err
TASK_REGISTRY = {}
@@ -81,7 +84,7 @@ def register_group(name):
OUTPUT_TYPE_REGISTRY = {}
METRIC_REGISTRY = {}
METRIC_AGGREGATION_REGISTRY = {}
AGGREGATION_REGISTRY: Dict[str, Callable[[], Dict[str, Callable]]] = {}
AGGREGATION_REGISTRY: dict[str, Callable[[], dict[str, Callable]]] = {}
HIGHER_IS_BETTER_REGISTRY = {}
FILTER_REGISTRY = {}
@@ -125,7 +128,7 @@ def register_metric(**args):
return decorate
def get_metric(name: str, hf_evaluate_metric=False) -> Optional[Callable]:
def get_metric(name: str, hf_evaluate_metric=False) -> Callable[..., Any] | None:
if not hf_evaluate_metric:
if name in METRIC_REGISTRY:
return METRIC_REGISTRY[name]
@@ -157,21 +160,21 @@ def register_aggregation(name: str):
return decorate
def get_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callable]]]:
def get_aggregation(name: str) -> Callable[..., Any] | None:
try:
return AGGREGATION_REGISTRY[name]
except KeyError:
eval_logger.warning(f"{name} not a registered aggregation metric!")
def get_metric_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callable]]]:
def get_metric_aggregation(name: str) -> Callable[[], dict[str, Callable]] | None:
try:
return METRIC_AGGREGATION_REGISTRY[name]
except KeyError:
eval_logger.warning(f"{name} metric is not assigned a default aggregation!")
def is_higher_better(metric_name: str) -> Optional[bool]:
def is_higher_better(metric_name: str) -> bool | None:
try:
return HIGHER_IS_BETTER_REGISTRY[metric_name]
except KeyError:
@@ -192,7 +195,7 @@ def register_filter(name: str):
return decorate
def get_filter(filter_name: Union[str, Callable]) -> Callable:
def get_filter(filter_name: str | Callable) -> Callable:
try:
return FILTER_REGISTRY[filter_name]
except KeyError as e:
......
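A note on the error-handling change above: `raise ... from err` chains the original lookup failure into the friendlier exception instead of discarding it. A self-contained sketch of the pattern, with a toy registry standing in for `MODEL_REGISTRY`:

REGISTRY: dict[str, type] = {"dummy": object}


def get(name: str) -> type:
    try:
        return REGISTRY[name]
    except KeyError as err:
        # "from err" sets __cause__, so the traceback shows both the raw
        # KeyError and this message, rather than swallowing the former.
        available = ", ".join(REGISTRY.keys())
        raise KeyError(f"'{name}' not found. Available: {available}") from err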
from __future__ import annotations
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from functools import cached_property
from typing import Any, Callable, List, Optional
from typing import Any
@dataclass
@@ -8,9 +11,9 @@ class MetricConfig:
"""Encapsulates information about a single metric."""
name: str
fn: Optional[Callable] = None
kwargs: Optional[dict] = None
aggregation_fn: Optional[Callable] = None
fn: Callable | None = None
kwargs: Mapping[str, Any] | None = None
aggregation_fn: Callable | None = None
higher_is_better: bool = True
hf_evaluate: bool = False
is_elementwise: bool = True
@@ -20,7 +23,7 @@ class MetricConfig:
return self.name
@cached_property
def aggregation(self) -> Callable:
def aggregation(self) -> Callable[..., Any] | None:
from lm_eval.api.registry import get_aggregation
if self.aggregation_fn is None:
@@ -28,7 +31,7 @@ class MetricConfig:
return self.aggregation_fn
@cached_property
def _higher_is_better(self) -> bool:
def _higher_is_better(self) -> bool | None:
from lm_eval.api.registry import is_higher_better
if self.higher_is_better is None:
@@ -39,10 +42,10 @@ class MetricConfig:
"""Calculates the metric using the provided function and arguments."""
if self.fn is None:
raise ValueError(f"Metric function for {self.name} is not defined.")
return self.fn(*args, **{**self.kwargs, **kwargs})
return self.fn(*args, **{**(self.kwargs or {}), **kwargs})
def compute_aggregation(self, values: List[Any]) -> Any:
def compute_aggregation(self, *args, **kwargs) -> Any:
"""Computes the aggregation of the metric values."""
if self.aggregation_fn is None:
raise ValueError(f"Aggregation function for {self.name} is not defined.")
return self.aggregation_fn(values)
return self.aggregation_fn(*args, **kwargs)
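The `(self.kwargs or {})` guard in `compute` above matters because `kwargs` defaults to None; the merge unpacks stored defaults first, then call-time overrides. In isolation, with toy values:

stored = None                             # MetricConfig.kwargs may be None
call_time = {"ignore_case": True}

merged = {**(stored or {}), **call_time}  # call-time kwargs win on key clashes
assert merged == {"ignore_case": True}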
from __future__ import annotations
import logging
from collections.abc import Iterable
from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union
from typing import TYPE_CHECKING, Callable
from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.instance import OutputType
@@ -20,8 +23,8 @@ class RepeatConfig:
"""Encapsulates information about a single repeat."""
repeats: int = 1
metric_fn: Union[str, Callable] = "pass@N"
kwargs: Optional[dict] = field(default_factory=dict)
metric_fn: str | Callable = "pass@N"
kwargs: dict | None = field(default_factory=dict)
@dataclass
@@ -38,11 +41,11 @@ class FewshotConfig:
# hack: this returns task.config.num_fewshot
# to keep in sync as it is runtime-modified
num_fewshot: Callable[[], int]
split: Optional[str] = None
sampler: Union[str, Callable] = "default"
samples: Union[Callable[[], list[dict]], list[dict], None] = None
process_docs: Optional[Callable[[list[dict]], Iterable[dict]]] = None
fewshot_indices: Optional[list[int]] = None
split: str | None = None
sampler: str | Callable = "default"
samples: Callable[[], list[dict]] | list[dict] | None = None
process_docs: Callable[[list[dict]], Iterable[dict]] | None = None
fewshot_indices: list[int] | None = None
rnd: int = field(init=False, default=False)
def __post_init__(self) -> None:
@@ -65,22 +68,20 @@ class FewshotConfig:
def _get_raw_docs(
self, dataset
) -> Union[list[dict], Callable[[], Iterable[dict]], None]:
) -> list[dict] | Callable[[], Iterable[dict]] | None:
"""Get raw documents from configured source."""
if self.split is not None:
return dataset[self.split]
if self.samples is not None:
if isinstance(self.samples, list):
return self.samples
elif callable(self.samples):
if isinstance(self.samples, list) or callable(self.samples):
return self.samples
else:
raise TypeError(
"samples must be either a list of dicts or a callable returning a list"
)
def get_docs(self, dataset) -> Optional[Iterable[dict]]:
def get_docs(self, dataset) -> Iterable[dict] | None:
"""Get processed documents from configured source."""
raw_docs = self._get_raw_docs(dataset)
if raw_docs is None:
@@ -100,8 +101,8 @@ class FewshotConfig:
return self.sampler
def init_sampler(
self, docs: list[dict], task: "Task", rnd=None, fewshot_indices=None
) -> "ContextSampler":
self, docs: list[dict], task: Task, rnd=None, fewshot_indices=None
) -> ContextSampler:
"""Initialize the sampler with the given documents and task."""
if rnd is None:
raise ValueError(
@@ -120,49 +121,49 @@ class FewshotConfig:
@dataclass
class TaskConfig(dict):
# task naming/registry
task: Optional[str] = None
task_alias: Optional[str] = None
tag: Optional[Union[str, list]] = None
task: str | None = None
task_alias: str | None = None
tag: str | list | None = None
# HF dataset options.
# which dataset to use,
# and what splits for what purpose
custom_dataset: Optional[Callable] = None
dataset_path: Optional[str] = None
dataset_name: Optional[str] = None
dataset_kwargs: Optional[dict] = field(default_factory=dict)
training_split: Optional[str] = None
validation_split: Optional[str] = None
test_split: Optional[str] = None
fewshot_split: Optional[str] = (
custom_dataset: Callable | None = None
dataset_path: str | None = None
dataset_name: str | None = None
dataset_kwargs: dict | None = field(default_factory=dict)
training_split: str | None = None
validation_split: str | None = None
test_split: str | None = None
fewshot_split: str | None = (
None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?)
)
# formatting / prompting options.
# see docs/advanced_task_guide.md for more info
process_docs: Optional[Callable] = None
doc_to_text: Optional[Union[Callable, str]] = None
doc_to_target: Optional[Union[Callable, str]] = None
doc_to_image: Union[Callable, str, None] = None
doc_to_audio: Union[Callable, str, None] = None
process_docs: Callable | None = None
doc_to_text: Callable | str | None = None
doc_to_target: Callable | str | None = None
doc_to_image: Callable | str | None = None
doc_to_audio: Callable | str | None = None
unsafe_code: bool = False
doc_to_choice: Optional[Union[Callable, str, dict, list]] = None
process_results: Optional[Union[Callable, str]] = None
use_prompt: Optional[str] = None
doc_to_choice: Callable | str | dict | list | None = None
process_results: Callable | str | None = None
use_prompt: str | None = None
description: str = ""
target_delimiter: str = " "
fewshot_delimiter: str = "\n\n"
fewshot_config: Optional[dict] = None
fewshot_config: dict | None = None
# runtime configuration options
num_fewshot: Optional[int] = 0
generation_kwargs: Optional[dict] = None
num_fewshot: int | None = 0
generation_kwargs: dict | None = None
# scoring options
metric_list: Optional[list] = None
metric_list: list | None = None
output_type: OutputType = "generate_until"
repeats: int = 1
filter_list: Optional[list[dict]] = None
filter_list: list[dict] | None = None
should_decontaminate: bool = False
doc_to_decontamination_query: Optional[str] = None
gen_prefix: Optional[str] = None
metadata: Optional[dict] = field(
doc_to_decontamination_query: str | None = None
gen_prefix: str | None = None
metadata: dict | None = field(
default_factory=dict
) # by default, not used in the code. allows for users to pass arbitrary info to tasks
@@ -215,9 +216,7 @@ class TaskConfig(dict):
fewshot_indices=_fewshot_cfg.get("fewshot_indices", None),
)
def _get_metric(
self, metric_list: Optional[list[dict]] = None
) -> list["MetricConfig"]:
def _get_metric(self, metric_list: list[dict] | None = None) -> list[MetricConfig]:
from lm_eval.api.registry import (
AGGREGATION_REGISTRY,
DEFAULT_METRIC_REGISTRY,
@@ -314,7 +313,7 @@ class TaskConfig(dict):
return metrics
@property
def get_filters(self) -> list["FilterConfig"]:
def get_filters(self) -> list[FilterConfig]:
from lm_eval.filters import build_filter_ensemble
if not self.filter_list:
@@ -354,7 +353,7 @@ class TaskConfig(dict):
return x
@classmethod
def from_yaml(cls, data: dict) -> "TaskConfig":
def from_yaml(cls, data: dict) -> TaskConfig:
"""Create a TaskConfig instance from a YAML-like dictionary."""
return cls(**data)
......
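The `X | None` and builtin-generic annotations introduced throughout these files lean on the `from __future__ import annotations` line added at the top of each module: PEP 563 defers annotation evaluation, so PEP 604 `|` unions parse even on interpreters older than 3.10. A standalone illustration:

from __future__ import annotations  # annotations stored as strings, not evaluated


def lookup(name: str | None = None) -> dict[str, str] | None:
    # "str | None" is never evaluated at runtime, so this module
    # still imports cleanly on Python 3.8/3.9.
    return {"name": name} if name else None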
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Callable, Optional, Union
from typing import TYPE_CHECKING, Callable
if TYPE_CHECKING:
@@ -11,19 +13,19 @@ class TemplateConfig:
"""Encapsulates information about a template."""
template: str
doc_to_text: Union[str, Callable[[dict], str]]
doc_to_choice: Union[str, list, Callable[[dict], list]]
doc_to_target: Union[int, Callable[[dict], int]]
doc_to_text: str | Callable[[dict], str]
doc_to_choice: str | list | Callable[[dict], list]
doc_to_target: int | Callable[[dict], int]
description: str
context_prefix: str
prefix_delimiter: str
context_delimiter: str
answer_suffix: str
target_delimiter: str
choice_format: Optional[str]
choice_delimiter: Optional[str]
choice_format: str | None
choice_delimiter: str | None
fewshot_delimiter: str
metric_list: Optional[Union[list[str], list["MetricConfig"]]] = field(
metric_list: list[str] | list[MetricConfig] | None = field(
default_factory=lambda: ["acc", "acc_norm"]
)
@@ -40,19 +42,19 @@ class MCQTemplateConfig:
Answer:` doc_to_choice(doc)` for each choice.
"""
doc_to_text: Union[str, Callable[[dict], str]]
doc_to_choice: Union[str, list, Callable[[dict], list]]
doc_to_target: Union[int, Callable[[dict], int]]
doc_to_text: str | Callable[[dict], str]
doc_to_choice: str | list | Callable[[dict], list]
doc_to_target: int | Callable[[dict], int]
template = "mcq"
context_prefix: str = "Question:"
prefix_delimiter: str = " "
context_delimiter: str = "\n"
answer_suffix: str = "Answer:"
target_delimiter: str = "\n"
choice_format: Optional[str] = "letters"
choice_delimiter: Optional[str] = "\n"
choice_format: str | None = "letters"
choice_delimiter: str | None = "\n"
fewshot_delimiter: str = "\n\n"
metric_list: Optional[list["MetricConfig"]] = field(default_factory=lambda: ["acc"])
metric_list: list[MetricConfig] | None = field(default_factory=lambda: ["acc"])
@dataclass
@@ -63,9 +65,9 @@ class ClozeTemplateConfig:
Answer:` <doc_to_target(doc)>`
"""
doc_to_text: Union[str, Callable[[dict], str]]
doc_to_choice: Union[str, list, Callable[[dict], list]]
doc_to_target: Union[int, Callable[[dict], int]]
doc_to_text: str | Callable[[dict], str]
doc_to_choice: str | list | Callable[[dict], list]
doc_to_target: int | Callable[[dict], int]
template: str = "cloze"
description: str = ""
context_prefix: str = "Question:"
@@ -73,9 +75,9 @@ class ClozeTemplateConfig:
context_delimiter: str = "\n"
answer_suffix: str = "Answer:"
target_delimiter: str = " "
choice_format: Optional[str] = None
choice_delimiter: Optional[str] = None
choice_format: str | None = None
choice_delimiter: str | None = None
fewshot_delimiter: str = "\n\n"
metric_list: Optional[list["MetricConfig"]] = field(
metric_list: list[MetricConfig] | None = field(
default_factory=lambda: ["acc", "acc_norm"]
)
from __future__ import annotations
from inspect import getsource
from typing import Any, Callable, Union
from typing import Any, Callable
def serialize_callable(
value: Union[Callable[..., Any], str], keep_callable=False
) -> Union[Callable[..., Any], str]:
value: Callable[..., Any] | str, keep_callable=False
) -> Callable[..., Any] | str:
"""Serializes a given function or string.
If 'keep_callable' is True, the original callable is returned.
@@ -20,9 +22,7 @@ def serialize_callable(
return str(value)
def maybe_serialize(
val: Union[Callable, Any], keep_callable=False
) -> Union[Callable, Any]:
def maybe_serialize(val: Callable | Any, keep_callable=False) -> Callable | Any:
"""Conditionally serializes a value if it is callable."""
return (
......
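For illustration only, since the function body is elided in this diff: a simplified, self-contained sketch that matches the `serialize_callable` signature and docstring above. The real implementation may differ.

from __future__ import annotations

from inspect import getsource
from typing import Any, Callable


def serialize_callable_sketch(
    value: Callable[..., Any] | str, keep_callable: bool = False
) -> Callable[..., Any] | str:
    if keep_callable:
        return value  # hand the original callable back untouched
    # Otherwise reduce a callable to its source text; strings pass through.
    # (Note: getsource raises OSError for builtins and REPL-defined functions.)
    return getsource(value) if callable(value) else str(value)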
import re
import sys
import unicodedata
from collections.abc import Iterable
from lm_eval.api.filter import Filter
from lm_eval.api.registry import register_filter
@@ -32,7 +33,9 @@ class RegexFilter(Filter):
self.group_select = group_select
self.fallback = fallback
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
def apply(
self, resps: Iterable[list[str]], docs: Iterable[dict]
) -> Iterable[list[str]]:
# here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
@@ -59,59 +62,13 @@
return filtered_resps
@register_filter("regex_pos")
class POSFilter(Filter):
""" """
def __init__(
self,
regex_pattern: str = r"\['(.*?)'\]",
group_select=0,
fallback=None,
**kwargs,
) -> None:
"""
pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located.
"""
super().__init__(**kwargs)
if fallback is None:
fallback = ["invalid"]
self.regex_pattern = regex_pattern
self.regex = re.compile(regex_pattern)
self.group_select = group_select
self.fallback = fallback
def apply(self, resps, docs):
def extract_tagged_tokens(text):
# Extract tagged tokens list from text input using regex
tokens = re.findall(r"\('([^']*)', '([^']*)'\)", text)
return [(token, pos) for token, pos in tokens]
def extract_pos_tags(result):
pos_tags = []
if isinstance(result, str):
result = extract_tagged_tokens(result)
pos_tags.extend(pos for _, pos in result)
return pos_tags if pos_tags else self.fallback
def filter_set(inst):
filtered = []
for resp in inst:
match = extract_pos_tags(resp)
filtered.append(match)
return filtered
filtered_resps = map(lambda x: filter_set(x), resps)
return filtered_resps
@register_filter("remove_whitespace")
class WhitespaceFilter(Filter):
"""Filters out leading whitespace from responses."""
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
def apply(
self, resps: Iterable[list[str]], docs: Iterable[dict]
) -> Iterable[list[str]]:
def filter_set(inst):
filtered_resp = []
for resp in inst:
@@ -156,7 +113,9 @@ class MultiChoiceRegexFilter(RegexFilter):
self.ignore_punctuation = ignore_punctuation
self.regexes_to_ignore = regexes_to_ignore
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
def apply(
self, resps: Iterable[list[str]], docs: Iterable[dict]
) -> Iterable[list[str]]:
# here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
......
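One last note on the filter signatures: `Iterable[list[str]]` admits lazy generators as well as concrete lists, which the old `list[list[str]]` annotation did not. A toy demonstration with an illustrative function name:

from collections.abc import Iterable


def strip_responses(resps: Iterable[list[str]]) -> Iterable[list[str]]:
    # A generator expression satisfies Iterable[list[str]] but would not
    # have satisfied the old list[list[str]] annotation.
    return ([r.strip() for r in resp_set] for resp_set in resps)


for resp_set in strip_responses([[" a "], ["b ", " c"]]):
    print(resp_set)  # ['a'] then ['b', 'c']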