Commit 312374bc authored by Baber

type hints

parent 90cf3b89
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Callable, Iterable, List, Union
from typing import Callable, Iterable, TypeVar
from lm_eval.api.instance import Instance
T = TypeVar("T")
class Filter(ABC):
"""
Filter classes operate on a per-task level.
@@ -20,7 +23,7 @@ class Filter(ABC):
"""
@abstractmethod
def apply(self, resps: Union[List, Iterable], docs: List[dict]) -> Iterable:
def apply(self, resps: Iterable[list[T]], docs: list[dict]) -> Iterable[list[T]]:
"""
Defines the operation to perform on a list of the `inst.resps` properties of `Instance` objects.
Should return the list of (filtered) response lists *in the same order as they were input*, e.g.
@@ -40,9 +43,9 @@ class FilterEnsemble:
"""
name: str
filters: List[Callable[[], Filter]]
filters: list[Callable[[], Filter]]
def apply(self, instances: List[Instance]) -> None:
def apply(self, instances: list[Instance]) -> None:
resps, docs = zip(*((inst.resps, inst.doc) for inst in instances))
# TODO: add backward
# unwrap responses from GenerateOutput as the filters expect strings
......
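The updated `apply` contract is generic over the response element type `T`: each inner list holds one document's responses, and filtered lists must be returned in the same order they came in. A minimal sketch of a conforming subclass; the `StripWhitespaceFilter` name and behavior are illustrative, not part of this commit:

# --- illustrative sketch, not part of this commit ---
from typing import Iterable, TypeVar

from lm_eval.api.filter import Filter

T = TypeVar("T")


class StripWhitespaceFilter(Filter):
    """Hypothetical filter: trims surrounding whitespace from string responses."""

    def apply(self, resps: Iterable[list[T]], docs: list[dict]) -> Iterable[list[T]]:
        # one inner list of responses per document; order must be preserved
        return [
            [r.strip() if isinstance(r, str) else r for r in doc_resps]
            for doc_resps in resps
        ]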
import abc
import ast
import itertools
import logging
import random
import re
@@ -109,24 +110,19 @@ class TaskConfig(dict):
)
def __post_init__(self) -> None:
if self.generation_kwargs is not None:
if self.output_type != "generate_until":
eval_logger.warning(
f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!"
)
if "temperature" in self.generation_kwargs:
self.generation_kwargs["temperature"] = float(
self.generation_kwargs["temperature"]
)
if self.output_type == "generate_until":
if self.generation_kwargs is not None:
if "temperature" in self.generation_kwargs:
self.generation_kwargs["temperature"] = float(
self.generation_kwargs["temperature"]
)
if "until" not in self.generation_kwargs:
eval_logger.warning(
f"{self.task}: No `until` specified in `generation_kwargs`! Defaulting to the fewshot_delimiter={repr(self.fewshot_delimiter)}"
)
self.generation_kwargs["until"] = [self.fewshot_delimiter]
else:
if self.output_type == "generate_until":
if "until" not in self.generation_kwargs:
eval_logger.warning(
f"{self.task}: No `until` specified in `generation_kwargs`! Defaulting to the fewshot_delimiter={repr(self.fewshot_delimiter)}"
)
self.generation_kwargs["until"] = [self.fewshot_delimiter]
else:
# ensure that we greedily generate in absence of explicit arguments otherwise
self.generation_kwargs = {
"until": (
@@ -140,6 +136,11 @@ class TaskConfig(dict):
eval_logger.warning(
f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}"
)
else:
if self.generation_kwargs is not None:
eval_logger.warning(
f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!"
)
def __getitem__(self, item):
return getattr(self, item)
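The restructured `__post_init__` branches on `output_type` first: for `generate_until` it coerces `temperature` to float and defaults `until` to the fewshot delimiter, otherwise it only warns if `generation_kwargs` was passed. A standalone sketch of that flow with plain arguments in place of the dataclass fields; the greedy default dict is an assumption here, since its exact contents are elided in the diff:

# --- illustrative sketch, not part of this commit ---
import logging
from typing import Optional

eval_logger = logging.getLogger(__name__)


def normalize_generation_kwargs(
    output_type: str,
    generation_kwargs: Optional[dict],
    fewshot_delimiter: str = "\n\n",
    task: str = "example_task",
) -> Optional[dict]:
    if output_type == "generate_until":
        if generation_kwargs is not None:
            if "temperature" in generation_kwargs:
                generation_kwargs["temperature"] = float(generation_kwargs["temperature"])
            if "until" not in generation_kwargs:
                eval_logger.warning(
                    f"{task}: No `until` specified in `generation_kwargs`! "
                    f"Defaulting to the fewshot_delimiter={fewshot_delimiter!r}"
                )
                generation_kwargs["until"] = [fewshot_delimiter]
        else:
            # exact defaults are elided in the diff above; greedy decoding until
            # the fewshot delimiter is an assumption
            generation_kwargs = {"until": [fewshot_delimiter], "do_sample": False}
            eval_logger.warning(
                f"{task}: No `generation_kwargs` specified in task config, "
                f"defaulting to {generation_kwargs}"
            )
    elif generation_kwargs is not None:
        eval_logger.warning(
            f"[{task}] passed `generation_kwargs`, but not using `output_type: generate_until`!"
        )
    return generation_kwargs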
@@ -1558,7 +1559,7 @@ class ConfigurableTask(Task):
**kwargs,
)
def process_results(self, doc, results):
def process_results(self, doc, results) -> dict:
if callable(self.config.process_results):
return self.config.process_results(doc, results)
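`process_results` is now annotated as returning a dict, i.e. a mapping from metric name to this sample's value. A hedged sketch of a user-supplied `process_results` function for a task config; the `exact_match` metric and the `answer` field are illustrative:

# --- illustrative sketch, not part of this commit ---
def my_process_results(doc: dict, results: list) -> dict:
    """Hypothetical per-sample scorer: exact match of the first response against doc['answer']."""
    prediction = results[0].strip() if results else ""
    return {"exact_match": float(prediction == doc.get("answer", "").strip())}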
@@ -1779,11 +1780,11 @@ class ConfigurableTask(Task):
def compute_sample_metrics(
self,
requests: list[Instance] = None,
filter_keys: list[str] = None,
indices: list[int] = None,
requests: Optional[list[Instance]] = None,
filter_keys: Optional[list[str]] = None,
indices: Optional[list[int]] = None,
rank: int = 1,
limit: int = None,
limit: Optional[int] = None,
world_size: int = 1,
log_samples: bool = False,
) -> tuple[
@@ -1807,6 +1808,9 @@ class ConfigurableTask(Task):
else:
requests = requests if requests else self.instances
all_metrics = defaultdict(list)
samples = [] if log_samples else None
### Collect values of metrics on all datapoints ###
# Pre-process task.instances to group by doc_id
instances_by_doc_id = defaultdict(list)
@@ -1815,8 +1819,6 @@ class ConfigurableTask(Task):
# Sort instances within each group
for instances in instances_by_doc_id.values():
instances.sort(key=lambda x: x.idx)
_all_metrics = defaultdict(list)
_samples = [] if log_samples else None
if filter_keys is None:
filter_keys = (
@@ -1840,9 +1842,16 @@ class ConfigurableTask(Task):
requests = instances_by_doc_id[_doc_id_true]
if self.OUTPUT_TYPE != "generate_until":
# if one doc has multiple instances then calculate metric together
metrics = self.process_results(
doc, [req.filtered_resps[filter_key] for req in requests]
)
metrics = [
self.process_results(
doc,
list(
itertools.chain.from_iterable(
[req.filtered_resps[filter_key] for req in requests]
)
),
)
]
else:
metrics = [
self.process_results(doc, response)
@@ -1857,20 +1866,21 @@ class ConfigurableTask(Task):
for k, v in metric.items():
_sample_metric[k].append(v)
if log_samples:
_samples.append(
samples.append(
create_sample_log(
doc=doc,
doc_id=_doc_id_true,
target=self.doc_to_target(doc),
requests=requests,
metric_names=metrics,
requests=tuple(requests),
metric_names=tuple(str(x) for x in metrics[0]),
filter_key=filter_key,
metrics=tuple(metrics),
)
)
for metric_name, _score in _sample_metric.items():
_all_metrics[(metric_name, filter_key)].append(_score)
self.metric_results = _all_metrics
return _all_metrics, _samples
all_metrics[(metric_name, filter_key)].append(_score)
self.metric_results = all_metrics
return all_metrics, samples
def compute_agg_metrics(
self,
......
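`compute_sample_metrics` groups instances by `doc_id`, restores per-document order via `idx`, scores each group per filter key, and accumulates values under `(metric_name, filter_key)`. A self-contained sketch of that grouping and accumulation pattern; `FakeInstance` and `score_fn` are stand-ins, not the real `Instance` API:

# --- illustrative sketch, not part of this commit ---
from collections import defaultdict
from dataclasses import dataclass, field


@dataclass
class FakeInstance:  # stand-in for lm_eval.api.instance.Instance
    doc_id: int
    idx: int
    filtered_resps: dict = field(default_factory=dict)


def group_and_score(instances: list, filter_key: str, score_fn) -> dict:
    # group instances by doc_id, then restore request order within each document
    by_doc = defaultdict(list)
    for inst in instances:
        by_doc[inst.doc_id].append(inst)
    for group in by_doc.values():
        group.sort(key=lambda x: x.idx)

    # accumulate per-sample metric values keyed by (metric_name, filter_key)
    all_metrics = defaultdict(list)
    for group in by_doc.values():
        metrics = score_fn([inst.filtered_resps[filter_key] for inst in group])
        for name, value in metrics.items():
            all_metrics[(name, filter_key)].append(value)
    return all_metrics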
@@ -352,8 +352,6 @@ def simple_evaluate(
verbosity=verbosity,
confirm_run_unsafe_code=confirm_run_unsafe_code,
)
if verbosity is not None:
setup_logging(verbosity=verbosity)
if lm.rank == 0:
if isinstance(model, str):
@@ -588,14 +586,13 @@ def evaluate(
### Collect values of metrics on all datapoints ###
# # unpack results and sort back in order and return control to Task
# TODO: make it possible to use a different metric per filter
_metrics, samples = task.compute_sample_metrics(
task_output.sample_metrics, samples = task.compute_sample_metrics(
indices=samples,
rank=RANK,
limit=limit,
world_size=WORLD_SIZE,
log_samples=log_samples,
)
task_output.sample_metrics = _metrics
if log_samples:
task_output.logged_samples = samples
@@ -606,6 +603,7 @@ def evaluate(
if log_samples:
# for task_name, task_samples in list(samples.items()):
full_samples = [None] * WORLD_SIZE if RANK == 0 else None
eval_logger.info(task_output.logged_samples)
torch.distributed.gather_object(
obj=task_output.logged_samples,
object_gather_list=full_samples,
@@ -620,6 +618,7 @@ def evaluate(
# then collect metrics across all ranks
for metrics in task_output.sample_metrics:
metric_list = [None] * WORLD_SIZE if RANK == 0 else None
eval_logger.info(task_output.sample_metrics[metrics])
torch.distributed.gather_object(
obj=task_output.sample_metrics[metrics],
object_gather_list=metric_list,
......
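With `WORLD_SIZE > 1`, logged samples and per-metric lists are gathered onto rank 0 with `torch.distributed.gather_object`. A minimal sketch of that collective pattern, assuming an already-initialized process group (e.g. launched via `torchrun`); the helper name is illustrative:

# --- illustrative sketch, not part of this commit ---
import torch.distributed as dist


def gather_to_rank0(local_items: list, rank: int, world_size: int) -> list:
    """Gather each rank's list onto rank 0 and flatten; other ranks get an empty list back."""
    gathered = [None] * world_size if rank == 0 else None
    dist.gather_object(obj=local_items, object_gather_list=gathered, dst=0)
    if rank == 0:
        return [item for per_rank in gathered for item in per_rank]
    return []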
@@ -10,7 +10,6 @@ class CustomFilter(Filter):
def __init__(self, **kwargs) -> None:
self.filter_fn = kwargs.pop("filter_fn")
super().__init__(**kwargs)
def apply(self, resps, docs):
......
@@ -20,11 +20,13 @@ class RegexFilter(Filter):
regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
group_select: int = 0,
fallback: str = "[invalid]",
**kwargs,
) -> None:
"""
pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located.
"""
super().__init__(**kwargs)
self.regex_pattern = regex_pattern
self.regex = re.compile(regex_pattern)
self.group_select = group_select
@@ -66,11 +68,13 @@ class POSFilter(Filter):
regex_pattern: str = r"\['(.*?)'\]",
group_select=0,
fallback=None,
**kwargs,
) -> None:
"""
pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located.
"""
super().__init__(**kwargs)
if fallback is None:
fallback = ["invalid"]
self.regex_pattern = regex_pattern
......
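`RegexFilter` and `POSFilter` now accept `**kwargs` and forward them to `Filter.__init__`, so extra keys in a filter config can pass through the constructor. A hedged construction example using the defaults shown above, assuming the usual `lm_eval.filters.extraction` module path:

# --- illustrative sketch, not part of this commit ---
from lm_eval.filters.extraction import RegexFilter

# defaults mirror the signature shown above; any extra keyword arguments are
# now forwarded to Filter.__init__ via **kwargs
answer_filter = RegexFilter(
    regex_pattern=r"#### (\-?[0-9\.\,]+)",
    group_select=0,
    fallback="[invalid]",
)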
from collections import Counter
from typing import Iterable, TypeVar
from lm_eval.api.filter import Filter
from lm_eval.api.registry import register_filter
T = TypeVar("T")
# TODO: implement "arg_max" filter. either it should take in an arbitrary "scoring"/reward function
# that takes an input and returns a scalar and then should select the max reward,
# or should implement different filters for different ways of handling a reward model's inference.
@@ -11,26 +13,20 @@ from lm_eval.api.registry import register_filter
@register_filter("take_first")
class TakeFirstFilter(Filter):
def __init__(self) -> None:
"""
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
"""
def apply(self, resps, docs):
def apply(self, resps: Iterable[list[T]], docs: list[dict]) -> Iterable[list[T]]:
"""
Assuming each entry of `resps` is a list of model responses, we discard all but the first response.
"""
return map(lambda r: r, resps)
return map(lambda r: [r[0]], resps)
@register_filter("take_first_k")
class TakeKFilter(Filter):
def __init__(self, **kwargs) -> None:
self.k = kwargs.pop("k")
super().__init__(**kwargs)
def apply(self, resps, docs):
def apply(self, resps: Iterable[list[T]], docs: list[dict]) -> Iterable[list[T]]:
# need resp to be subscriptable to check below
resps = list(resps)
# check we have at least k responses per doc, else we can't take the first k
@@ -42,12 +38,7 @@ class TakeKFilter(Filter):
@register_filter("majority_vote")
class MajorityVoteFilter(Filter):
def __init__(self) -> None:
"""
Can define custom behavior here, if an individual instantiation of a Filter class should have state.
"""
def apply(self, resps, docs):
def apply(self, resps: Iterable[list[T]], docs: list[dict]) -> Iterable[list[T]]:
"""
Each entry of `resps` is a list of model responses.
We select the response that occurs most frequently in each entry of `resps`.
......
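The selection filters are now typed `Iterable[list[T]] -> Iterable[list[T]]`, so each document's responses stay wrapped in a list after filtering (note `take_first` now returns `[r[0]]` instead of passing the list through unchanged). A small plain-Python sketch of the intended semantics, not using the filter classes themselves:

# --- illustrative sketch, not part of this commit ---
from collections import Counter

resps = [["4", "4", "5"], ["cat", "dog", "cat"]]  # one inner list of responses per document

take_first = [[r[0]] for r in resps]                                 # [['4'], ['cat']]
take_first_k = [r[:2] for r in resps]                                # first k=2 responses per doc
majority_vote = [[Counter(r).most_common(1)[0][0]] for r in resps]   # [['4'], ['cat']]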
@@ -6,9 +6,6 @@ from lm_eval.api.registry import register_filter
@register_filter("lowercase")
class LowercaseFilter(Filter):
def __init__(self) -> None:
pass
def apply(self, resps, docs):
def filter_set(inst):
return [resp.lower() for resp in inst]
@@ -18,9 +15,6 @@ class LowercaseFilter(Filter):
@register_filter("uppercase")
class UppercaseFilter(Filter):
def __init__(self) -> None:
pass
def apply(self, resps, docs):
def filter_set(inst):
return [resp.upper() for resp in inst]
@@ -30,7 +24,7 @@ class UppercaseFilter(Filter):
@register_filter("map")
class MapFilter(Filter):
def __init__(self, mapping_dict: dict = None, default_value=None) -> None:
def __init__(self, mapping_dict: dict = None, default_value=None, **kwargs) -> None:
"""
Initializes the MapFilter with a given mapping dictionary and default value.
@@ -43,6 +37,7 @@ class MapFilter(Filter):
Example:
mapper = MapFilter({'A': 1, 'B': 2}, default_value=0)
"""
super().__init__(**kwargs)
if mapping_dict is None:
mapping_dict = {}
assert isinstance(mapping_dict, dict), (
@@ -60,9 +55,6 @@ class MapFilter(Filter):
@register_filter("format_span")
class SPANFilter(Filter):
def __init__(self) -> None:
pass
def apply(self, resps, docs):
def format_ner_text(text):
label_dict = {
......
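`MapFilter` looks each response up in `mapping_dict` and falls back to `default_value` for unseen keys, per the docstring example above. A quick standalone illustration of that lookup:

# --- illustrative sketch, not part of this commit ---
mapping_dict = {"A": 1, "B": 2}
default_value = 0

resps = [["A", "C"], ["B"]]
mapped = [[mapping_dict.get(r, default_value) for r in doc_resps] for doc_resps in resps]
# mapped == [[1, 0], [2]]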
@@ -562,9 +562,10 @@ def create_sample_log(
doc: dict,
doc_id: int,
target: Any,
requests: list[Instance],
metric_names: [dict],
requests: tuple[Instance],
metric_names: tuple[str, ...],
filter_key: str,
metrics: tuple[dict, ...],
) -> dict:
return {
"doc_id": doc_id,
@@ -574,7 +575,8 @@ def create_sample_log(
"resps": [req.resps for req in requests],
"filtered_resps": [req.filtered_resps[filter_key] for req in requests],
"filter": filter_key,
"metrics": metric_names,
"metric_names": metric_names,
"metrics": metrics,
"doc_hash": hash_string(
json.dumps(
requests[0].doc,
......
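The per-sample log now separates `metric_names` (the names, as strings) from `metrics` (the per-sample metric dicts). A hedged sketch of the resulting record, with illustrative values and only the keys visible in this diff; the remaining fields are elided above:

# --- illustrative sketch, not part of this commit ---
sample_log = {
    "doc_id": 0,
    "resps": [["#### 42"]],        # raw model responses, one list per request
    "filtered_resps": ["42"],      # responses after the named filter was applied
    "filter": "strict-match",
    "metric_names": ("exact_match",),
    "metrics": ({"exact_match": 1.0},),
    # "doc_hash": ...              # hash of the canonical doc JSON; construction elided in the diff
}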