Commit 04e74420 authored by Baber
Browse files

cleanup

parent b0173d57
......@@ -176,14 +176,14 @@ class LM(abc.ABC):
return cls(**arg_dict, **additional_config)
@property
def rank(self) -> int:
    """Return the value stored in ``self._rank``."""
    # used in the case of parallelism. Hardcoded to
    # ensure no errors arise using API models which do
    # not support multi-device parallelism nor expect it.
    current_rank = self._rank
    return current_rank
@property
def world_size(self):
def world_size(self) -> int:
# used in the case of parallelism. Hardcoded to
# ensure no errors arise using API models which do
# not support multi-device parallelism nor expect it.
......@@ -233,7 +233,7 @@ class CacheHook:
class CachingLM:
def __init__(self, lm: LM, cache_db: str) -> None:
def __init__(self, lm: "LM", cache_db: str) -> None:
"""LM wrapper that returns cached results if they exist, and uses the underlying LM if not.
:param lm: LM
......@@ -327,11 +327,11 @@ class TemplateLM(LM):
@property
@abc.abstractmethod
def eot_token_id(self) -> int:
    """Abstract property with no implementation here; declared to return an int."""
@property
def prefix_token_id(self) -> int:
    """Return ``self.eot_token_id``; the two ids are the same unless overridden."""
    # it is used as prefix for loglikelihood
    prefix_id = self.eot_token_id
    return prefix_id
......@@ -351,6 +351,11 @@ class TemplateLM(LM):
def _encode_pair(
self, context: str, continuation: str
) -> tuple[list[int], list[int]]:
"""Encodes a pair of context and continuation strings into token IDs.
Ensures that encode(context + continuation) == encode(context) + encode(continuation)
"""
import transformers
n_spaces = len(context) - len(context.rstrip())
......@@ -402,6 +407,7 @@ class TemplateLM(LM):
def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
"""
Assumes tokenizer has a chat_template attribute (self.tokenizer.chat_template: dict | str)
Set and get the appropriate chat template for the model.
This method sets the tokenizer's chat_template and returns the template string for reproducibility.
......
......@@ -8,6 +8,10 @@ if TYPE_CHECKING:
eval_logger = logging.getLogger(__name__)
MODEL_REGISTRY = {}
DEFAULTS = {
"model": {"max_length": 2048},
"tasks": {"generate_until": {"max_length": 2048}},
}
def register_model(*names):
......@@ -167,7 +171,7 @@ def get_metric_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callabl
eval_logger.warning(f"{name} metric is not assigned a default aggregation!")
def is_higher_better(metric_name) -> Optional[bool]:
def is_higher_better(metric_name: str) -> Optional[bool]:
try:
return HIGHER_IS_BETTER_REGISTRY[metric_name]
except KeyError:
......@@ -176,7 +180,7 @@ def is_higher_better(metric_name) -> Optional[bool]:
)
def register_filter(name):
def register_filter(name: str):
def decorate(cls):
if name in FILTER_REGISTRY:
eval_logger.info(
......
......@@ -9,6 +9,7 @@ from dataclasses import asdict, dataclass, field
from functools import cached_property
from inspect import getsource
from typing import (
TYPE_CHECKING,
Any,
Dict,
Iterable,
......@@ -50,6 +51,10 @@ ALL_OUTPUT_TYPES = [
"generate_until",
]
if TYPE_CHECKING:
from lm_eval.api.filter import FilterEnsemble
eval_logger = logging.getLogger(__name__)
......@@ -81,7 +86,7 @@ class MetricConfig:
return is_higher_better(self.name)
return self.higher_is_better
def calculate_metric(self, *args, **kwargs) -> Any:
def compute_metric(self, *args, **kwargs) -> Any:
"""Calculates the metric using the provided function and arguments."""
if self.fn is None:
raise ValueError(f"Metric function for {self.name} is not defined.")
......@@ -99,7 +104,7 @@ class RepeatConfig:
"""Encapsulates information about a single repeat."""
repeats: int = 1
# Either a metric name or a callable. `Optional[...]` accepts exactly one
# type argument, so the two alternatives are wrapped in a `Union`.
metric_fn: Optional[Union[str, Callable]] = "pass@N"
kwargs: Optional[dict] = None
......@@ -246,15 +251,15 @@ class TaskConfig(dict):
output_type: OutputType = "generate_until"
generation_kwargs: Optional[dict] = None
repeats: int = 1
filter_list: Optional[Union[str, list]] = None
filter_list: Optional[list[dict]] = None
should_decontaminate: bool = False
doc_to_decontamination_query: Optional[str] = None
gen_prefix: Optional[str] = None
metadata: Optional[dict] = (
None # by default, not used in the code. allows for users to pass arbitrary info to tasks
)
_metric_list = None
_filter_list = None
_metric_list: list[MetricConfig] = None
_filter_list: list[FilterConfig] = None
def __post_init__(self) -> None:
if self.generation_kwargs is not None:
......@@ -289,16 +294,13 @@ class TaskConfig(dict):
f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}"
)
if self.metric_list is not None:
for metric_config in self.metric_list:
if "metric" not in metric_config:
raise ValueError(
"'metric' key not provided for an entry in 'metric_list', must be specified!"
)
if self.metric_list and not all("metric" in cfg for cfg in self.metric_list):
raise ValueError("each entry in metric_list must include a 'metric' key")
def get_metrics(self) -> list["MetricConfig"]:
metrics = []
if self.metric_list is None:
# ---------- 1. If no metrics defined, use defaults for output type ----------
_metric_list = DEFAULT_METRIC_REGISTRY[self.output_type]
eval_logger.info(
f"No metrics defined in config, using default metrics for {self.output_type}={_metric_list}"
......@@ -313,11 +315,8 @@ class TaskConfig(dict):
for metric_name in _metric_list
)
else:
# ---------- 2. How will the samples be evaluated ----------
for metric_config in self.metric_list:
if "metric" not in metric_config:
raise ValueError(
"'metric' key not provided for an entry in 'metric_list', must be specified!"
)
metric_name = metric_config["metric"]
_metric_fn_kwargs = {
key: metric_config[key]
......@@ -379,34 +378,30 @@ class TaskConfig(dict):
)
return metrics
def get_filters(self) -> list["FilterEnsemble"]:
    """Build one filter ensemble per entry of ``self.filter_list``.

    When ``filter_list`` is unset or empty, a single ensemble named
    ``"none"`` with only the ``take_first`` step is returned. Otherwise each
    entry is read through its ``"name"`` and ``"filter"`` keys; every step in
    ``"filter"`` names its filter under ``"function"`` and the remaining keys
    of the step are passed along as that filter's keyword arguments.

    :return: list of filter ensembles, in the order the entries are configured
    :raises KeyError: if an entry lacks ``"name"``/``"filter"`` or a step
        lacks ``"function"``
    """
    if not self.filter_list:
        eval_logger.debug(
            "No custom filters defined; falling back to 'take_first' for handling repeats."
        )
        return [build_filter_ensemble("none", [["take_first", None]])]

    def _to_component(step: dict) -> list:
        # Emit the same [function_name, kwargs] pair shape that the default
        # branch above passes, instead of a bare kwargs dict.
        # NOTE(review): shape inferred from the default call; confirm against
        # build_filter_ensemble.
        step_kwargs = {k: v for k, v in step.items() if k != "function"}
        return [step["function"], step_kwargs]

    # A dict-valued filter_list contributes its values; a list is used as is.
    configs = (
        self.filter_list.values()
        if isinstance(self.filter_list, dict)
        else self.filter_list
    )
    return [
        build_filter_ensemble(
            filter_name=cfg["name"],
            components=[_to_component(step) for step in cfg["filter"]],
        )
        for cfg in configs
    ]
def __getitem__(self, item):
return getattr(self, item)
......@@ -415,31 +410,27 @@ class TaskConfig(dict):
return setattr(self, item, value)
def to_dict(self, keep_callable: bool = False) -> dict:
    """Return a printable dict with Nones stripped and callables serialised.

    :param keep_callable: forwarded unchanged to ``serialize_function``
    :return: dict
        A printable dictionary version of the TaskConfig object.

    # TODO: should any default value in the TaskConfig not be printed?
    """

    def _render(candidate):
        # Only callables go through serialize_function; other values pass through.
        if callable(candidate):
            return self.serialize_function(candidate, keep_callable=keep_callable)
        return candidate

    printable = {}
    for field_name, field_value in asdict(self).items():
        if field_value is None:
            continue
        if field_name == "metric_list":
            # Each metric entry is a dict whose values may themselves be callables.
            printable[field_name] = [
                {entry_key: _render(entry_val) for entry_key, entry_val in entry.items()}
                for entry in field_value
            ]
        else:
            printable[field_name] = _render(field_value)
    return printable
def serialize_function(
self, value: Union[Callable, str], keep_callable=False
......@@ -627,7 +618,7 @@ class Task(abc.ABC):
return doc
@property
def instances(self) -> List[Instance]:
def instances(self) -> list[Instance]:
"""After calling `task.build_all_requests()`, tasks
maintain a list of the dataset instances which will be evaluated.
"""
......@@ -639,27 +630,27 @@ class Task(abc.ABC):
return rnd.sample(self._training_docs, k)
def doc_to_decontamination_query(self, doc: dict):
    """Placeholder hook that always raises until it is overridden.

    :param doc: a single document (unused here)
    :raises NotImplementedError: unconditionally
    """
    message = "Override doc_to_decontamination_query with document specific decontamination query."
    raise NotImplementedError(message)
@abc.abstractmethod
def doc_to_text(self, doc: dict) -> str:
    """Abstract hook with no implementation here; declared to return a str for ``doc``."""
@abc.abstractmethod
def doc_to_target(self, doc: dict) -> Union[str, int]:
    """Abstract hook with no implementation here; declared to return a str or int for ``doc``."""
# not an abstractmethod because not every language-only task has to implement this
def doc_to_image(self, doc: dict):
    """Optional hook; raises unless a subclass overrides it.

    :raises NotImplementedError: unconditionally
    """
    raise NotImplementedError
def doc_to_audio(self, doc: dict):
    """Optional hook; raises unless a subclass overrides it.

    :raises NotImplementedError: unconditionally
    """
    raise NotImplementedError
def doc_to_prefix(self, doc: dict) -> str:
    """Return the empty string: this base implementation supplies no prefix."""
    no_prefix = ""
    return no_prefix
def build_all_requests(
......@@ -776,7 +767,7 @@ class Task(abc.ABC):
save_to_cache(file_name=cache_key, obj=instances)
@abc.abstractmethod
def construct_requests(self, doc, ctx, **kwargs):
def construct_requests(self, doc: dict, ctx: Union[list[dict], str], **kwargs):
"""Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM.
......@@ -797,7 +788,7 @@ class Task(abc.ABC):
pass
@abc.abstractmethod
def process_results(self, doc, results):
def process_results(self, doc: dict, results: list):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
......@@ -1450,7 +1441,7 @@ class ConfigurableTask(Task):
"""
return doc
def doc_to_text(self, doc, doc_to_text=None):
def doc_to_text(self, doc: dict, doc_to_text: Optional[int, str, Callable] = None):
if self.prompt is not None:
doc_to_text = self.prompt
elif doc_to_text is not None:
......@@ -1486,7 +1477,7 @@ class ConfigurableTask(Task):
print(type(doc_to_text))
raise TypeError
def doc_to_target(self, doc: Mapping, doc_to_target=None) -> Union[int, str, list]:
def doc_to_target(self, doc: dict, doc_to_target=None) -> Union[int, str, list]:
if self.prompt is not None:
doc_to_target = self.prompt
elif doc_to_target is not None:
......@@ -1532,7 +1523,9 @@ class ConfigurableTask(Task):
else:
raise TypeError
def doc_to_choice(self, doc: Any, doc_to_choice=None) -> List[str]:
def doc_to_choice(
self, doc: dict, doc_to_choice: Union[str, list, dict] = None
) -> List[str]:
if self.prompt is not None:
doc_to_choice = self.prompt
elif doc_to_choice is not None:
......@@ -1558,7 +1551,7 @@ class ConfigurableTask(Task):
else:
raise TypeError
def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list, None]:
def doc_to_image(self, doc: dict, doc_to_image=None) -> Union[int, str, list, None]:
if doc_to_image is not None:
doc_to_image = doc_to_image
elif self.config.doc_to_image is not None:
......@@ -1604,7 +1597,7 @@ class ConfigurableTask(Task):
else:
return None
def doc_to_prefix(self, doc) -> Optional[str]:
def doc_to_prefix(self, doc: dict) -> Optional[str]:
if (gen_prefix := self.config.gen_prefix) is not None:
if gen_prefix in self.features:
return doc[gen_prefix]
......@@ -1713,7 +1706,7 @@ class ConfigurableTask(Task):
**kwargs,
)
def process_results(self, doc, results):
def process_results(self, doc: dict, results: list) -> dict:
if callable(self.config.process_results):
return self.config.process_results(doc, results)
......
......@@ -10,12 +10,13 @@ class DecontaminationFilter(Filter):
name = "track_decontamination"
def __init__(self, path) -> None:
def __init__(self, path, **kwargs) -> None:
"""
TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path").
should further cache result on a given (task_name, doc_id)
"""
super().__init__(**kwargs)
self._decontam_results = None
def apply(self, resps, docs) -> None:
......
......@@ -20,11 +20,13 @@ class RegexFilter(Filter):
regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
group_select: int = 0,
fallback: str = "[invalid]",
**kwargs,
) -> None:
"""
pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located.
"""
super().__init__(**kwargs)
self.regex_pattern = regex_pattern
self.regex = re.compile(regex_pattern)
self.group_select = group_select
......@@ -66,11 +68,13 @@ class POSFilter(Filter):
regex_pattern: str = r"\['(.*?)'\]",
group_select=0,
fallback=None,
**kwargs,
) -> None:
"""
pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located.
"""
super().__init__(**kwargs)
if fallback is None:
fallback = ["invalid"]
self.regex_pattern = regex_pattern
......
......@@ -27,7 +27,6 @@ class TakeFirstFilter(Filter):
class TakeKFilter(Filter):
def __init__(self, **kwargs) -> None:
self.k = kwargs.pop("k")
super().__init__(**kwargs)
def apply(self, resps, docs):
......
......@@ -6,9 +6,6 @@ from lm_eval.api.registry import register_filter
@register_filter("lowercase")
class LowercaseFilter(Filter):
def __init__(self) -> None:
pass
def apply(self, resps, docs):
def filter_set(inst):
return [resp.lower() for resp in inst]
......@@ -18,9 +15,6 @@ class LowercaseFilter(Filter):
@register_filter("uppercase")
class UppercaseFilter(Filter):
def __init__(self) -> None:
pass
def apply(self, resps, docs):
def filter_set(inst):
return [resp.upper() for resp in inst]
......@@ -31,6 +25,7 @@ class UppercaseFilter(Filter):
@register_filter("map")
class MapFilter(Filter):
def __init__(self, mapping_dict: dict = None, default_value=None) -> None:
super().__init__()
"""
Initializes the MapFilter with a given mapping dictionary and default value.
......@@ -60,9 +55,6 @@ class MapFilter(Filter):
@register_filter("format_span")
class SPANFilter(Filter):
def __init__(self) -> None:
pass
def apply(self, resps, docs):
def format_ner_text(text):
label_dict = {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment