Commit 69d14fb3 authored by Baber's avatar Baber
Browse files

cleanup

parent 57b8c0b1
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections.abc import Iterable
from dataclasses import dataclass from dataclasses import dataclass
from typing import Iterable, List, Union
from lm_eval.api.instance import Instance from lm_eval.api.instance import Instance
...@@ -20,7 +20,9 @@ class Filter(ABC): ...@@ -20,7 +20,9 @@ class Filter(ABC):
""" """
@abstractmethod @abstractmethod
def apply(self, resps: Union[List, Iterable], docs: List[dict]) -> Iterable: def apply(
self, resps: Iterable[list[str]], docs: Iterable[dict]
) -> Iterable[list[str]]:
""" """
Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects. Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
Should return the list of (filtered) response lists *in the same order as they were input*, e.g. Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
...@@ -40,9 +42,9 @@ class FilterEnsemble: ...@@ -40,9 +42,9 @@ class FilterEnsemble:
""" """
name: str name: str
filters: List[type[Filter]] filters: list[type[Filter]]
def apply(self, instances: List[Instance]) -> None: def apply(self, instances: list[Instance]) -> None:
resps, docs = zip(*((inst.resps, inst.doc) for inst in instances)) resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
resps, docs = list(resps), list(docs) resps, docs = list(resps), list(docs)
......
from __future__ import annotations
import logging import logging
from typing import TYPE_CHECKING, Callable, Dict, Optional, Union from typing import TYPE_CHECKING, Any, Callable
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -36,13 +38,14 @@ def register_model(*names): ...@@ -36,13 +38,14 @@ def register_model(*names):
return decorate return decorate
def get_model(model_name: str) -> type["LM"]: def get_model(model_name: str) -> type[LM]:
try: try:
return MODEL_REGISTRY[model_name] return MODEL_REGISTRY[model_name]
except KeyError: except KeyError as err:
raise ValueError( available_models = ", ".join(MODEL_REGISTRY.keys())
f"Attempted to load model '{model_name}', but no model for this name found! Supported model names: {', '.join(MODEL_REGISTRY.keys())}" raise KeyError(
) f"Model '{model_name}' not found. Available models: {available_models}"
) from err
TASK_REGISTRY = {} TASK_REGISTRY = {}
...@@ -81,7 +84,7 @@ def register_group(name): ...@@ -81,7 +84,7 @@ def register_group(name):
OUTPUT_TYPE_REGISTRY = {} OUTPUT_TYPE_REGISTRY = {}
METRIC_REGISTRY = {} METRIC_REGISTRY = {}
METRIC_AGGREGATION_REGISTRY = {} METRIC_AGGREGATION_REGISTRY = {}
AGGREGATION_REGISTRY: Dict[str, Callable[[], Dict[str, Callable]]] = {} AGGREGATION_REGISTRY: dict[str, Callable[[], dict[str, Callable]]] = {}
HIGHER_IS_BETTER_REGISTRY = {} HIGHER_IS_BETTER_REGISTRY = {}
FILTER_REGISTRY = {} FILTER_REGISTRY = {}
...@@ -125,7 +128,7 @@ def register_metric(**args): ...@@ -125,7 +128,7 @@ def register_metric(**args):
return decorate return decorate
def get_metric(name: str, hf_evaluate_metric=False) -> Optional[Callable]: def get_metric(name: str, hf_evaluate_metric=False) -> Callable[..., Any] | None:
if not hf_evaluate_metric: if not hf_evaluate_metric:
if name in METRIC_REGISTRY: if name in METRIC_REGISTRY:
return METRIC_REGISTRY[name] return METRIC_REGISTRY[name]
...@@ -157,21 +160,21 @@ def register_aggregation(name: str): ...@@ -157,21 +160,21 @@ def register_aggregation(name: str):
return decorate return decorate
def get_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callable]]]: def get_aggregation(name: str) -> Callable[..., Any] | None:
try: try:
return AGGREGATION_REGISTRY[name] return AGGREGATION_REGISTRY[name]
except KeyError: except KeyError:
eval_logger.warning(f"{name} not a registered aggregation metric!") eval_logger.warning(f"{name} not a registered aggregation metric!")
def get_metric_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callable]]]: def get_metric_aggregation(name: str) -> Callable[[], dict[str, Callable]] | None:
try: try:
return METRIC_AGGREGATION_REGISTRY[name] return METRIC_AGGREGATION_REGISTRY[name]
except KeyError: except KeyError:
eval_logger.warning(f"{name} metric is not assigned a default aggregation!") eval_logger.warning(f"{name} metric is not assigned a default aggregation!")
def is_higher_better(metric_name: str) -> Optional[bool]: def is_higher_better(metric_name: str) -> bool | None:
try: try:
return HIGHER_IS_BETTER_REGISTRY[metric_name] return HIGHER_IS_BETTER_REGISTRY[metric_name]
except KeyError: except KeyError:
...@@ -192,7 +195,7 @@ def register_filter(name: str): ...@@ -192,7 +195,7 @@ def register_filter(name: str):
return decorate return decorate
def get_filter(filter_name: Union[str, Callable]) -> Callable: def get_filter(filter_name: str | Callable) -> Callable:
try: try:
return FILTER_REGISTRY[filter_name] return FILTER_REGISTRY[filter_name]
except KeyError as e: except KeyError as e:
......
This diff is collapsed.
from __future__ import annotations
from collections.abc import Callable, Mapping
from dataclasses import dataclass from dataclasses import dataclass
from functools import cached_property from functools import cached_property
from typing import Any, Callable, List, Optional from typing import Any
@dataclass @dataclass
...@@ -8,9 +11,9 @@ class MetricConfig: ...@@ -8,9 +11,9 @@ class MetricConfig:
"""Encapsulates information about a single metric.""" """Encapsulates information about a single metric."""
name: str name: str
fn: Optional[Callable] = None fn: Callable | None = None
kwargs: Optional[dict] = None kwargs: Mapping[str, Any] | None = None
aggregation_fn: Optional[Callable] = None aggregation_fn: Callable | None = None
higher_is_better: bool = True higher_is_better: bool = True
hf_evaluate: bool = False hf_evaluate: bool = False
is_elementwise: bool = True is_elementwise: bool = True
...@@ -20,7 +23,7 @@ class MetricConfig: ...@@ -20,7 +23,7 @@ class MetricConfig:
return self.name return self.name
@cached_property @cached_property
def aggregation(self) -> Callable: def aggregation(self) -> Callable[..., Any] | None:
from lm_eval.api.registry import get_aggregation from lm_eval.api.registry import get_aggregation
if self.aggregation_fn is None: if self.aggregation_fn is None:
...@@ -28,7 +31,7 @@ class MetricConfig: ...@@ -28,7 +31,7 @@ class MetricConfig:
return self.aggregation_fn return self.aggregation_fn
@cached_property @cached_property
def _higher_is_better(self) -> bool: def _higher_is_better(self) -> bool | None:
from lm_eval.api.registry import is_higher_better from lm_eval.api.registry import is_higher_better
if self.higher_is_better is None: if self.higher_is_better is None:
...@@ -39,10 +42,10 @@ class MetricConfig: ...@@ -39,10 +42,10 @@ class MetricConfig:
"""Calculates the metric using the provided function and arguments.""" """Calculates the metric using the provided function and arguments."""
if self.fn is None: if self.fn is None:
raise ValueError(f"Metric function for {self.name} is not defined.") raise ValueError(f"Metric function for {self.name} is not defined.")
return self.fn(*args, **{**self.kwargs, **kwargs}) return self.fn(*args, **{**(self.kwargs or {}), **kwargs})
def compute_aggregation(self, values: List[Any]) -> Any: def compute_aggregation(self, *args, **kwargs) -> Any:
"""Computes the aggregation of the metric values.""" """Computes the aggregation of the metric values."""
if self.aggregation_fn is None: if self.aggregation_fn is None:
raise ValueError(f"Aggregation function for {self.name} is not defined.") raise ValueError(f"Aggregation function for {self.name} is not defined.")
return self.aggregation_fn(values) return self.aggregation_fn(*args, **kwargs)
from __future__ import annotations
import logging import logging
from collections.abc import Iterable
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Callable, Iterable, Optional, Union from typing import TYPE_CHECKING, Callable
from lm_eval.api.filter import FilterEnsemble from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.instance import OutputType from lm_eval.api.instance import OutputType
...@@ -20,8 +23,8 @@ class RepeatConfig: ...@@ -20,8 +23,8 @@ class RepeatConfig:
"""Encapsulates information about a single repeat.""" """Encapsulates information about a single repeat."""
repeats: int = 1 repeats: int = 1
metric_fn: Union[str, Callable] = "pass@N" metric_fn: str | Callable = "pass@N"
kwargs: Optional[dict] = field(default_factory=dict) kwargs: dict | None = field(default_factory=dict)
@dataclass @dataclass
...@@ -38,11 +41,11 @@ class FewshotConfig: ...@@ -38,11 +41,11 @@ class FewshotConfig:
# hack: this returns task.config.num_fewshot # hack: this returns task.config.num_fewshot
# to keep in sync as it is runtime-modified # to keep in sync as it is runtime-modified
num_fewshot: Callable[[], int] num_fewshot: Callable[[], int]
split: Optional[str] = None split: str | None = None
sampler: Union[str, Callable] = "default" sampler: str | Callable = "default"
samples: Union[Callable[[], list[dict]], list[dict], None] = None samples: Callable[[], list[dict]] | list[dict] | None = None
process_docs: Optional[Callable[[list[dict]], Iterable[dict]]] = None process_docs: Callable[[list[dict]], Iterable[dict]] | None = None
fewshot_indices: Optional[list[int]] = None fewshot_indices: list[int] | None = None
rnd: int = field(init=False, default=False) rnd: int = field(init=False, default=False)
def __post_init__(self) -> None: def __post_init__(self) -> None:
...@@ -65,22 +68,20 @@ class FewshotConfig: ...@@ -65,22 +68,20 @@ class FewshotConfig:
def _get_raw_docs( def _get_raw_docs(
self, dataset self, dataset
) -> Union[list[dict], Callable[[], Iterable[dict]], None]: ) -> list[dict] | Callable[[], Iterable[dict]] | None:
"""Get raw documents from configured source.""" """Get raw documents from configured source."""
if self.split is not None: if self.split is not None:
return dataset[self.split] return dataset[self.split]
if self.samples is not None: if self.samples is not None:
if isinstance(self.samples, list): if isinstance(self.samples, list) or callable(self.samples):
return self.samples
elif callable(self.samples):
return self.samples return self.samples
else: else:
raise TypeError( raise TypeError(
"samples must be either a list of dicts or a callable returning a list" "samples must be either a list of dicts or a callable returning a list"
) )
def get_docs(self, dataset) -> Optional[Iterable[dict]]: def get_docs(self, dataset) -> Iterable[dict] | None:
"""Get processed documents from configured source.""" """Get processed documents from configured source."""
raw_docs = self._get_raw_docs(dataset) raw_docs = self._get_raw_docs(dataset)
if raw_docs is None: if raw_docs is None:
...@@ -100,8 +101,8 @@ class FewshotConfig: ...@@ -100,8 +101,8 @@ class FewshotConfig:
return self.sampler return self.sampler
def init_sampler( def init_sampler(
self, docs: list[dict], task: "Task", rnd=None, fewshot_indices=None self, docs: list[dict], task: Task, rnd=None, fewshot_indices=None
) -> "ContextSampler": ) -> ContextSampler:
"""Initialize the sampler with the given documents and task.""" """Initialize the sampler with the given documents and task."""
if rnd is None: if rnd is None:
raise ValueError( raise ValueError(
...@@ -120,49 +121,49 @@ class FewshotConfig: ...@@ -120,49 +121,49 @@ class FewshotConfig:
@dataclass @dataclass
class TaskConfig(dict): class TaskConfig(dict):
# task naming/registry # task naming/registry
task: Optional[str] = None task: str | None = None
task_alias: Optional[str] = None task_alias: str | None = None
tag: Optional[Union[str, list]] = None tag: str | list | None = None
# HF dataset options. # HF dataset options.
# which dataset to use, # which dataset to use,
# and what splits for what purpose # and what splits for what purpose
custom_dataset: Optional[Callable] = None custom_dataset: Callable | None = None
dataset_path: Optional[str] = None dataset_path: str | None = None
dataset_name: Optional[str] = None dataset_name: str | None = None
dataset_kwargs: Optional[dict] = field(default_factory=dict) dataset_kwargs: dict | None = field(default_factory=dict)
training_split: Optional[str] = None training_split: str | None = None
validation_split: Optional[str] = None validation_split: str | None = None
test_split: Optional[str] = None test_split: str | None = None
fewshot_split: Optional[str] = ( fewshot_split: str | None = (
None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?) None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?)
) )
# formatting / prompting options. # formatting / prompting options.
# see docs/advanced_task_guide.md for more info # see docs/advanced_task_guide.md for more info
process_docs: Optional[Callable] = None process_docs: Callable | None = None
doc_to_text: Optional[Union[Callable, str]] = None doc_to_text: Callable | str | None = None
doc_to_target: Optional[Union[Callable, str]] = None doc_to_target: Callable | str | None = None
doc_to_image: Union[Callable, str, None] = None doc_to_image: Callable | str | None = None
doc_to_audio: Union[Callable, str, None] = None doc_to_audio: Callable | str | None = None
unsafe_code: bool = False unsafe_code: bool = False
doc_to_choice: Optional[Union[Callable, str, dict, list]] = None doc_to_choice: Callable | str | dict | list | None = None
process_results: Optional[Union[Callable, str]] = None process_results: Callable | str | None = None
use_prompt: Optional[str] = None use_prompt: str | None = None
description: str = "" description: str = ""
target_delimiter: str = " " target_delimiter: str = " "
fewshot_delimiter: str = "\n\n" fewshot_delimiter: str = "\n\n"
fewshot_config: Optional[dict] = None fewshot_config: dict | None = None
# runtime configuration options # runtime configuration options
num_fewshot: Optional[int] = 0 num_fewshot: int | None = 0
generation_kwargs: Optional[dict] = None generation_kwargs: dict | None = None
# scoring options # scoring options
metric_list: Optional[list] = None metric_list: list | None = None
output_type: OutputType = "generate_until" output_type: OutputType = "generate_until"
repeats: int = 1 repeats: int = 1
filter_list: Optional[list[dict]] = None filter_list: list[dict] | None = None
should_decontaminate: bool = False should_decontaminate: bool = False
doc_to_decontamination_query: Optional[str] = None doc_to_decontamination_query: str | None = None
gen_prefix: Optional[str] = None gen_prefix: str | None = None
metadata: Optional[dict] = field( metadata: dict | None = field(
default_factory=dict default_factory=dict
) # by default, not used in the code. allows for users to pass arbitrary info to tasks ) # by default, not used in the code. allows for users to pass arbitrary info to tasks
...@@ -215,9 +216,7 @@ class TaskConfig(dict): ...@@ -215,9 +216,7 @@ class TaskConfig(dict):
fewshot_indices=_fewshot_cfg.get("fewshot_indices", None), fewshot_indices=_fewshot_cfg.get("fewshot_indices", None),
) )
def _get_metric( def _get_metric(self, metric_list: list[dict] | None = None) -> list[MetricConfig]:
self, metric_list: Optional[list[dict]] = None
) -> list["MetricConfig"]:
from lm_eval.api.registry import ( from lm_eval.api.registry import (
AGGREGATION_REGISTRY, AGGREGATION_REGISTRY,
DEFAULT_METRIC_REGISTRY, DEFAULT_METRIC_REGISTRY,
...@@ -314,7 +313,7 @@ class TaskConfig(dict): ...@@ -314,7 +313,7 @@ class TaskConfig(dict):
return metrics return metrics
@property @property
def get_filters(self) -> list["FilterConfig"]: def get_filters(self) -> list[FilterConfig]:
from lm_eval.filters import build_filter_ensemble from lm_eval.filters import build_filter_ensemble
if not self.filter_list: if not self.filter_list:
...@@ -354,7 +353,7 @@ class TaskConfig(dict): ...@@ -354,7 +353,7 @@ class TaskConfig(dict):
return x return x
@classmethod @classmethod
def from_yaml(cls, data: dict) -> "TaskConfig": def from_yaml(cls, data: dict) -> TaskConfig:
"""Create a TaskConfig instance from a YAML-like dictionary.""" """Create a TaskConfig instance from a YAML-like dictionary."""
return cls(**data) return cls(**data)
......
from __future__ import annotations
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Callable, Optional, Union from typing import TYPE_CHECKING, Callable
if TYPE_CHECKING: if TYPE_CHECKING:
...@@ -11,19 +13,19 @@ class TemplateConfig: ...@@ -11,19 +13,19 @@ class TemplateConfig:
"""Encapsulates information about a template.""" """Encapsulates information about a template."""
template: str template: str
doc_to_text: Union[str, Callable[[dict], str]] doc_to_text: str | Callable[[dict], str]
doc_to_choice: Union[str, list, Callable[[dict], list]] doc_to_choice: str | list | Callable[[dict], list]
doc_to_target: Union[int, Callable[[dict], int]] doc_to_target: int | Callable[[dict], int]
description: str description: str
context_prefix: str context_prefix: str
prefix_delimiter: str prefix_delimiter: str
context_delimiter: str context_delimiter: str
answer_suffix: str answer_suffix: str
target_delimiter: str target_delimiter: str
choice_format: Optional[str] choice_format: str | None
choice_delimiter: Optional[str] choice_delimiter: str | None
fewshot_delimiter: str fewshot_delimiter: str
metric_list: Optional[Union[list[str], list["MetricConfig"]]] = field( metric_list: list[str] | list[MetricConfig] | None = field(
default_factory=lambda: ["acc", "acc_norm"] default_factory=lambda: ["acc", "acc_norm"]
) )
...@@ -40,19 +42,19 @@ class MCQTemplateConfig: ...@@ -40,19 +42,19 @@ class MCQTemplateConfig:
Answer:` doc_to_choice(doc)` for each choice. Answer:` doc_to_choice(doc)` for each choice.
""" """
doc_to_text: Union[str, Callable[[dict], str]] doc_to_text: str | Callable[[dict], str]
doc_to_choice: Union[str, list, Callable[[dict], list]] doc_to_choice: str | list | Callable[[dict], list]
doc_to_target: Union[int, Callable[[dict], int]] doc_to_target: int | Callable[[dict], int]
template = "mcq" template = "mcq"
context_prefix: str = "Question:" context_prefix: str = "Question:"
prefix_delimiter: str = " " prefix_delimiter: str = " "
context_delimiter: str = "\n" context_delimiter: str = "\n"
answer_suffix: str = "Answer:" answer_suffix: str = "Answer:"
target_delimiter: str = "\n" target_delimiter: str = "\n"
choice_format: Optional[str] = "letters" choice_format: str | None = "letters"
choice_delimiter: Optional[str] = "\n" choice_delimiter: str | None = "\n"
fewshot_delimiter: str = "\n\n" fewshot_delimiter: str = "\n\n"
metric_list: Optional[list["MetricConfig"]] = field(default_factory=lambda: ["acc"]) metric_list: list[MetricConfig] | None = field(default_factory=lambda: ["acc"])
@dataclass @dataclass
...@@ -63,9 +65,9 @@ class ClozeTemplateConfig: ...@@ -63,9 +65,9 @@ class ClozeTemplateConfig:
Answer:` <doc_to_target(doc)>` Answer:` <doc_to_target(doc)>`
""" """
doc_to_text: Union[str, Callable[[dict], str]] doc_to_text: str | Callable[[dict], str]
doc_to_choice: Union[str, list, Callable[[dict], list]] doc_to_choice: str | list | Callable[[dict], list]
doc_to_target: Union[int, Callable[[dict], int]] doc_to_target: int | Callable[[dict], int]
template: str = "cloze" template: str = "cloze"
description: str = "" description: str = ""
context_prefix: str = "Question:" context_prefix: str = "Question:"
...@@ -73,9 +75,9 @@ class ClozeTemplateConfig: ...@@ -73,9 +75,9 @@ class ClozeTemplateConfig:
context_delimiter: str = "\n" context_delimiter: str = "\n"
answer_suffix: str = "Answer:" answer_suffix: str = "Answer:"
target_delimiter: str = " " target_delimiter: str = " "
choice_format: Optional[str] = None choice_format: str | None = None
choice_delimiter: Optional[str] = None choice_delimiter: str | None = None
fewshot_delimiter: str = "\n\n" fewshot_delimiter: str = "\n\n"
metric_list: Optional[list["MetricConfig"]] = field( metric_list: list[MetricConfig] | None = field(
default_factory=lambda: ["acc", "acc_norm"] default_factory=lambda: ["acc", "acc_norm"]
) )
from __future__ import annotations
from inspect import getsource from inspect import getsource
from typing import Any, Callable, Union from typing import Any, Callable
def serialize_callable( def serialize_callable(
value: Union[Callable[..., Any], str], keep_callable=False value: Callable[..., Any] | str, keep_callable=False
) -> Union[Callable[..., Any], str]: ) -> Callable[..., Any] | str:
"""Serializes a given function or string. """Serializes a given function or string.
If 'keep_callable' is True, the original callable is returned. If 'keep_callable' is True, the original callable is returned.
...@@ -20,9 +22,7 @@ def serialize_callable( ...@@ -20,9 +22,7 @@ def serialize_callable(
return str(value) return str(value)
def maybe_serialize( def maybe_serialize(val: Callable | Any, keep_callable=False) -> Callable | Any:
val: Union[Callable, Any], keep_callable=False
) -> Union[Callable, Any]:
"""Conditionally serializes a value if it is callable.""" """Conditionally serializes a value if it is callable."""
return ( return (
......
import re import re
import sys import sys
import unicodedata import unicodedata
from collections.abc import Iterable
from lm_eval.api.filter import Filter from lm_eval.api.filter import Filter
from lm_eval.api.registry import register_filter from lm_eval.api.registry import register_filter
...@@ -32,7 +33,9 @@ class RegexFilter(Filter): ...@@ -32,7 +33,9 @@ class RegexFilter(Filter):
self.group_select = group_select self.group_select = group_select
self.fallback = fallback self.fallback = fallback
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]: def apply(
self, resps: Iterable[list[str]], docs: Iterable[dict]
) -> Iterable[list[str]]:
# here, we assume we have a list, in which each element is # here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair. # a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets) # so we process each of these (same input/target response sets)
...@@ -59,59 +62,13 @@ class RegexFilter(Filter): ...@@ -59,59 +62,13 @@ class RegexFilter(Filter):
return filtered_resps return filtered_resps
@register_filter("regex_pos")
class POSFilter(Filter):
""" """
def __init__(
self,
regex_pattern: str = r"\['(.*?)'\]",
group_select=0,
fallback=None,
**kwargs,
) -> None:
"""
pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located.
"""
super().__init__(**kwargs)
if fallback is None:
fallback = ["invalid"]
self.regex_pattern = regex_pattern
self.regex = re.compile(regex_pattern)
self.group_select = group_select
self.fallback = fallback
def apply(self, resps, docs):
def extract_tagged_tokens(text):
# Extract tagged tokens list from text input using regex
tokens = re.findall(r"\('([^']*)', '([^']*)'\)", text)
return [(token, pos) for token, pos in tokens]
def extract_pos_tags(result):
pos_tags = []
if isinstance(result, str):
result = extract_tagged_tokens(result)
pos_tags.extend(pos for _, pos in result)
return pos_tags if pos_tags else self.fallback
def filter_set(inst):
filtered = []
for resp in inst:
match = extract_pos_tags(resp)
filtered.append(match)
return filtered
filtered_resps = map(lambda x: filter_set(x), resps)
return filtered_resps
@register_filter("remove_whitespace") @register_filter("remove_whitespace")
class WhitespaceFilter(Filter): class WhitespaceFilter(Filter):
"""Filters out leading whitespace from responses.""" """Filters out leading whitespace from responses."""
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]: def apply(
self, resps: Iterable[list[str]], docs: Iterable[dict]
) -> Iterable[list[str]]:
def filter_set(inst): def filter_set(inst):
filtered_resp = [] filtered_resp = []
for resp in inst: for resp in inst:
...@@ -156,7 +113,9 @@ class MultiChoiceRegexFilter(RegexFilter): ...@@ -156,7 +113,9 @@ class MultiChoiceRegexFilter(RegexFilter):
self.ignore_punctuation = ignore_punctuation self.ignore_punctuation = ignore_punctuation
self.regexes_to_ignore = regexes_to_ignore self.regexes_to_ignore = regexes_to_ignore
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]: def apply(
self, resps: Iterable[list[str]], docs: Iterable[dict]
) -> Iterable[list[str]]:
# here, we assume we have a list, in which each element is # here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair. # a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets) # so we process each of these (same input/target response sets)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment