"docs/en/FAQ.md" did not exist on "85e363584fcf17b2415e81220a1ed56ea8559cb5"
Commit db5dff9c authored by Baber

type hints

parent 023bfe0d
@@ -3,7 +3,7 @@ import hashlib
 import json
 import logging
 import os
-from typing import Dict, List, Optional, Tuple, Type, TypeVar, Union
+from typing import TYPE_CHECKING, Optional, Type, TypeVar, Union
 import transformers
 from sqlitedict import SqliteDict
@@ -12,6 +12,10 @@ from tqdm import tqdm
 from lm_eval import utils
+if TYPE_CHECKING:
+    from lm_eval.api.instance import Instance
 eval_logger = logging.getLogger(__name__)
 T = TypeVar("T", bound="LM")
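The `TYPE_CHECKING` guard added here lets annotations reference `Instance` without importing it at runtime, the usual way to dodge circular or costly imports. A minimal, self-contained sketch of the pattern (the `Widget` module is hypothetical):

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen only by static type checkers (mypy, pyright); never executed,
    # so a circular or heavyweight import here costs nothing at runtime.
    from mypackage.widgets import Widget  # hypothetical module


def describe(w: "Widget") -> str:
    # The quoted annotation is resolved lazily, so Widget does not need to
    # be importable when this module is loaded.
    return f"widget: {w!r}"
```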
@@ -30,7 +34,7 @@ class LM(abc.ABC):
         self.cache_hook = CacheHook(None)
     @abc.abstractmethod
-    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
+    def loglikelihood(self, requests) -> list[tuple[float, bool]]:
         """Compute log-likelihood of generating a continuation from a context.
         Downstream tasks should attempt to use loglikelihood instead of other
         LM calls whenever possible.
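To make the contract concrete: each request carries a `(context, continuation)` pair in its `args`, and the method returns one `(log_prob, is_greedy)` tuple per request, in order. A sketch of a downstream consumer, assuming a concrete `LM` implementation:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:
    from lm_eval.api.instance import Instance
    from lm_eval.api.model import LM


def report_loglikelihoods(lm: "LM", reqs: "list[Instance]") -> None:
    # One (log_prob, is_greedy) pair comes back per request, in request order.
    for (log_prob, is_greedy), req in zip(lm.loglikelihood(reqs), reqs):
        context, continuation = req.args  # (context, continuation) strings
        print(f"{continuation!r}: logP = {log_prob:.3f}, greedy = {is_greedy}")
```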
@@ -55,7 +59,7 @@
         pass
     @abc.abstractmethod
-    def loglikelihood_rolling(self, requests) -> List[float]:
+    def loglikelihood_rolling(self, requests) -> list[float]:
         """Compute full log-likelihood of a string, with no truncation, for perplexity computation
         - We will use the full max context length of the model.
         - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
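The chunking the docstring describes can be pictured with a toy sequence. The sketch below is not the harness's own windowing utility (that lives in `lm_eval.utils` and additionally carries left-context between windows); it only illustrates splitting a token list into pieces of at most the max context length:

```python
def split_into_windows(tokens: list[int], max_len: int) -> list[list[int]]:
    # Naive illustration: consecutive, non-overlapping chunks of <= max_len.
    return [tokens[i : i + max_len] for i in range(0, len(tokens), max_len)]


assert split_into_windows(list(range(10)), 4) == [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```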
@@ -97,7 +101,7 @@ class LM(abc.ABC):
     # TODO: Add an optional max length
     @abc.abstractmethod
-    def generate_until(self, requests) -> List[str]:
+    def generate_until(self, requests) -> list[str]:
         """Generate greedily until a stopping sequence
         :param requests: list[Instance]
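For orientation, a `generate_until` request's `args` bundle a context string with generation kwargs, where an `until` key lists the stopping sequences the docstring refers to (key names here are illustrative of common task configs, not a schema). Applying the stops to raw output is a simple earliest-cut:

```python
def truncate_at_stops(text: str, until: list[str]) -> str:
    # Keep only the generation up to the first occurrence of any stop string.
    for stop in until:
        text = text.split(stop)[0]
    return text


# Hypothetical request payload: ("Q: ...\nA:", {"until": ["\n"]})
assert truncate_at_stops("Paris\nQ: next question", ["\n", "Q:"]) == "Paris"
```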
@@ -114,7 +118,7 @@
         pass
     def apply_chat_template(
-        self, chat_history: List[Dict[str, str]], add_generation_prompt=True
+        self, chat_history: list[dict[str, str]], add_generation_prompt=True
     ) -> str:
         """
         Defines how to transform few-shot examples provided as chat history into a format that can be used as input to the LM.
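The expected `chat_history` shape is a list of role/content dicts, matching the new annotation. A hedged sketch of what a concrete subclass might do, assuming a Hugging Face-style tokenizer is available:

```python
chat_history = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "What is 2 + 2?"},
]

# A HF-backed subclass would typically delegate to the tokenizer, e.g.:
# prompt: str = tokenizer.apply_chat_template(
#     chat_history, add_generation_prompt=True, tokenize=False
# )
```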
@@ -173,14 +177,14 @@ class LM(abc.ABC):
         return cls(**arg_dict, **additional_config)
     @property
-    def rank(self):
+    def rank(self) -> int:
         # used in the case of parallelism. Hardcoded to
         # ensure no errors arise using API models which do
         # not support multi-device parallelism nor expect it.
         return self._rank
     @property
-    def world_size(self):
+    def world_size(self) -> int:
         # used in the case of parallelism. Hardcoded to
         # ensure no errors arise using API models which do
         # not support multi-device parallelism nor expect it.
@@ -230,7 +234,7 @@ class CacheHook:
 class CachingLM:
-    def __init__(self, lm, cache_db) -> None:
+    def __init__(self, lm: "LM", cache_db: str) -> None:
         """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.
         :param lm: LM
@@ -253,7 +257,7 @@ class CachingLM:
             eval_logger.debug(f"Passing through attribute '{attr}' to underlying LM")
             return lm_attr
-        def fn(requests):
+        def fn(requests: list[Instance]) -> list[Instance]:
             res = []
             remaining_reqs = []
             warned = False
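The wrapper's strategy is the classic cache-or-compute split: hash each request, serve hits from the SQLite store, and forward only the misses to the wrapped LM. A condensed sketch of that pattern (not the exact `CachingLM` internals), reusing the `hashlib`/`json`/`SqliteDict` imports the module already has:

```python
import hashlib
import json

from sqlitedict import SqliteDict


def cached_call(db: SqliteDict, attr: str, args: tuple, compute):
    # Key each request by method name + arguments; reuse stored results and
    # fall back to the underlying computation only on a cache miss.
    key = hashlib.sha256(json.dumps([attr, args], sort_keys=True).encode()).hexdigest()
    if key in db:
        return db[key]
    result = compute(*args)
    db[key] = result
    db.commit()  # persist the new entry to the sqlite file
    return result
```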
@@ -322,28 +326,35 @@ class TemplateLM(LM):
     @property
     @abc.abstractmethod
-    def eot_token_id(self):
+    def eot_token_id(self) -> int:
         pass
     @property
-    def prefix_token_id(self):
+    def prefix_token_id(self) -> int:
         # it is used as prefix for loglikelihood
         return self.eot_token_id
     @abc.abstractmethod
-    def tok_encode(self, string: str, **kwargs) -> List[int]:
+    def tok_encode(self, string: str, **kwargs) -> list[int]:
         """
         Tokenize a string using the model's tokenizer and return a list of token IDs.
         """
         pass
     @abc.abstractmethod
-    def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
+    def _loglikelihood_tokens(
+        self, requests: list[Instance], **kwargs
+    ) -> list[tuple[float, bool]]:
         pass
     def _encode_pair(
         self, context: str, continuation: str
-    ) -> Tuple[List[int], List[int]]:
+    ) -> tuple[list[int], list[int]]:
+        """Encodes a pair of context and continuation strings into token IDs.
+        Ensures that encode(context + continuation) == encode(context) + encode(continuation)
+        """
         n_spaces = len(context) - len(context.rstrip())
         if n_spaces > 0:
             continuation = context[-n_spaces:] + continuation
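The whitespace shuffling above exists because BPE-style tokenizers fold a leading space into the following token, so `encode(context) + encode(continuation)` can disagree with `encode(context + continuation)` when the split lands after a space. A toy illustration (the token IDs are hypothetical):

```python
# Suppose enc("Hello ") == [15496, 220], enc("Hello") == [15496] and
# enc(" world") == [995]. Splitting as ("Hello ", "world") tokenizes the
# trailing space on its own; moving it onto the continuation restores the
# tokens of the joined string, enc("Hello world") == [15496, 995].
context, continuation = "Hello ", "world"
n_spaces = len(context) - len(context.rstrip())
if n_spaces > 0:
    continuation = context[-n_spaces:] + continuation
    context = context[:-n_spaces]
assert (context, continuation) == ("Hello", " world")
```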
@@ -364,8 +375,8 @@ class TemplateLM(LM):
         return context_enc, continuation_enc
     def loglikelihood(
-        self, requests, disable_tqdm: bool = False
-    ) -> List[Tuple[float, bool]]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[tuple[float, bool]]:
         new_reqs = []
         for context, continuation in [req.args for req in requests]:
             if context == "":
@@ -384,15 +395,16 @@
     @abc.abstractmethod
     def loglikelihood_rolling(
         self, requests, disable_tqdm: bool = False
-    ) -> List[float]:
+    ) -> list[float]:
         pass
     @abc.abstractmethod
-    def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
+    def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
         pass
     def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
         """
+        Assumes tokenizer has a chat_template attribute (self.tokenizer.chat_template: dict | str)
         Set and get the appropriate chat template for the model.
         This method sets the tokenizer's chat_template and returns the template string for reproducibility.
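The `Union[bool, str]` selector maps onto the `dict | str` tokenizer attribute mentioned in the docstring. A sketch of plausible dispatch logic under that assumption (not necessarily the method's exact behaviour): `False` disables templating, `True` picks a default, and a string selects a named template:

```python
from typing import Optional, Union


def resolve_chat_template(
    templates: Union[str, dict], selector: Union[bool, str] = False
) -> Optional[str]:
    if selector is False:
        return None  # chat templating disabled
    if isinstance(templates, str):
        return templates  # only one template available
    if selector is True:
        return templates.get("default")  # assumed default key
    return templates.get(selector)  # look up a named template
```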
......
@@ -8,6 +8,10 @@ if TYPE_CHECKING:
 eval_logger = logging.getLogger(__name__)
 MODEL_REGISTRY = {}
+DEFAULTS = {
+    "model": {"max_length": 2048},
+    "tasks": {"generate_until": {"max_length": 2048}},
+}
 def register_model(*names):
@@ -167,7 +171,7 @@ def get_metric_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callabl
     eval_logger.warning(f"{name} metric is not assigned a default aggregation!")
-def is_higher_better(metric_name) -> Optional[bool]:
+def is_higher_better(metric_name: str) -> Optional[bool]:
     try:
         return HIGHER_IS_BETTER_REGISTRY[metric_name]
     except KeyError:
@@ -176,7 +180,7 @@ def is_higher_better(metric_name) -> Optional[bool]:
         )
-def register_filter(name):
+def register_filter(name: str):
     def decorate(cls):
         if name in FILTER_REGISTRY:
             eval_logger.info(
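For context, a registered filter is later referenced by name from task configs. A hypothetical registration using the decorator above (the `apply(resps, docs)` hook is an assumption about the filter interface):

```python
from lm_eval.api.registry import register_filter


@register_filter("lowercase")  # name a task config could reference
class LowercaseFilter:
    def apply(self, resps, docs):
        # Post-process every model response for every document.
        return [[r.lower() for r in resp] for resp in resps]
```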
......
@@ -9,6 +9,7 @@ from dataclasses import asdict, dataclass, field
 from functools import cached_property
 from inspect import getsource
 from typing import (
+    TYPE_CHECKING,
     Any,
     Dict,
     Iterable,
@@ -50,6 +51,10 @@ ALL_OUTPUT_TYPES = [
     "generate_until",
 ]
+if TYPE_CHECKING:
+    from lm_eval.api.filter import FilterEnsemble
 eval_logger = logging.getLogger(__name__)
@@ -81,7 +86,7 @@ class MetricConfig:
             return is_higher_better(self.name)
         return self.higher_is_better
-    def calculate_metric(self, *args, **kwargs) -> Any:
+    def compute_metric(self, *args, **kwargs) -> Any:
         """Calculates the metric using the provided function and arguments."""
         if self.fn is None:
             raise ValueError(f"Metric function for {self.name} is not defined.")
@@ -99,7 +104,7 @@ class RepeatConfig:
     """Encapsulates information about a single repeat."""
     repeats: int = 1
-    metric_fn: Optional[Callable] = None
+    metric_fn: Optional[Union[Callable, str]] = "pass@N"
     kwargs: Optional[dict] = None
@@ -246,15 +251,15 @@ class TaskConfig(dict):
     output_type: OutputType = "generate_until"
     generation_kwargs: Optional[dict] = None
     repeats: int = 1
-    filter_list: Optional[Union[str, list]] = None
+    filter_list: Optional[list[dict]] = None
     should_decontaminate: bool = False
     doc_to_decontamination_query: Optional[str] = None
     gen_prefix: Optional[str] = None
     metadata: Optional[dict] = (
         None  # by default, not used in the code. allows for users to pass arbitrary info to tasks
     )
-    _metric_list = None
-    _filter_list = None
+    _metric_list: Optional[list[MetricConfig]] = None
+    _filter_list: Optional[list[FilterConfig]] = None
     def __post_init__(self) -> None:
         if self.generation_kwargs is not None:
@@ -289,16 +294,13 @@
                 f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}"
             )
-        if self.metric_list is not None:
-            for metric_config in self.metric_list:
-                if "metric" not in metric_config:
-                    raise ValueError(
-                        "'metric' key not provided for an entry in 'metric_list', must be specified!"
-                    )
+        if self.metric_list and not all("metric" in cfg for cfg in self.metric_list):
+            raise ValueError("each entry in metric_list must include a 'metric' key")
     def get_metrics(self) -> list["MetricConfig"]:
         metrics = []
         if self.metric_list is None:
+            # ---------- 1. If no metrics defined, use defaults for output type ----------
            _metric_list = DEFAULT_METRIC_REGISTRY[self.output_type]
             eval_logger.info(
                 f"No metrics defined in config, using default metrics for {self.output_type}={_metric_list}"
@@ -313,11 +315,8 @@
                 for metric_name in _metric_list
             )
         else:
+            # ---------- 2. How will the samples be evaluated ----------
             for metric_config in self.metric_list:
-                if "metric" not in metric_config:
-                    raise ValueError(
-                        "'metric' key not provided for an entry in 'metric_list', must be specified!"
-                    )
                 metric_name = metric_config["metric"]
                 _metric_fn_kwargs = {
                     key: metric_config[key]
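The `__post_init__` check above pins down the one hard requirement on `metric_list` entries: a `metric` key; other keys ride along as kwargs to the metric function. An illustrative entry (the companion keys shown are common in task configs, not an exhaustive schema):

```python
metric_list = [
    {"metric": "exact_match", "aggregation": "mean", "higher_is_better": True},
]
# Mirrors the validation newly placed in __post_init__:
assert all("metric" in cfg for cfg in metric_list)
```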
@@ -379,34 +378,30 @@
             )
         return metrics
-    def get_filters(self):
-        if self.filter_list is not None:
-            _filter_list = []
-            if isinstance(self.filter_list, dict):
-                for filter_config in self.filter_list:
-                    _filter_list.append(
-                        build_filter_ensemble(
-                            filter_name=filter_config["name"],
-                            components=[
-                                [
-                                    {
-                                        key: function[key]
-                                        for key in function
-                                        if key != "function"
-                                    }
-                                ]
-                                for function in filter_config["filter"]
-                            ],
-                        )
-                    )
-        else:
-            # TODO: handle repeats in a more general way rather than just discarding
-            eval_logger.debug(
-                "No custom filters defined. Using default 'take_first' filter for handling repeats."
-            )
-            _filter_list = [build_filter_ensemble("none", [["take_first", None]])]
-        return _filter_list
+    def get_filters(self) -> list["FilterEnsemble"]:
+        if not self.filter_list:
+            eval_logger.debug(
+                "No custom filters defined; falling back to 'take_first' for handling repeats."
+            )
+            return [build_filter_ensemble("none", [["take_first", None]])]
+        def _strip_fn(d: dict) -> dict:
+            return {k: v for k, v in d.items() if k != "function"}
+        configs = (
+            self.filter_list.values()
+            if isinstance(self.filter_list, dict)
+            else self.filter_list
+        )
+        return [
+            build_filter_ensemble(
+                filter_name=cfg["name"],
+                components=[[_strip_fn(f) for f in cfg["filter"]]],
+            )
+            for cfg in configs
+        ]
     def __getitem__(self, item):
         return getattr(self, item)
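To see what the rewritten `get_filters` consumes: each entry names an ensemble and lists its steps, and `_strip_fn` drops the `function` key so the remaining keys travel as kwargs. A hypothetical entry in that shape:

```python
filter_list = [
    {
        "name": "strict-match",
        "filter": [
            # "function" selects a registered filter; other keys become kwargs.
            {"function": "regex", "regex_pattern": r"answer is (\-?\d+)"},
            {"function": "take_first"},
        ],
    }
]
```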
@@ -415,31 +410,27 @@
         return setattr(self, item, value)
     def to_dict(self, keep_callable: bool = False) -> dict:
-        """dumps the current config as a dictionary object, as a printable format.
-        null fields will not be printed.
-        Used for dumping results alongside full task configuration
-        :return: dict
-            A printable dictionary version of the TaskConfig object.
-        # TODO: should any default value in the TaskConfig not be printed?
+        """Return a printable dict with Nones stripped and callables serialised.
         """
-        cfg_dict = asdict(self)
-        # remove values that are `None`
-        for k, v in list(cfg_dict.items()):
-            if v is None:
-                cfg_dict.pop(k)
-            elif k == "metric_list":
-                for metric_dict in v:
-                    for metric_key, metric_value in metric_dict.items():
-                        if callable(metric_value):
-                            metric_dict[metric_key] = self.serialize_function(
-                                metric_value, keep_callable=keep_callable
-                            )
-                cfg_dict[k] = v
-            elif callable(v):
-                cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable)
-        return cfg_dict
+        def _maybe_serialize(val):
+            return (
+                self.serialize_function(val, keep_callable=keep_callable)
+                if callable(val)
+                else val
+            )
+        cfg = asdict(self)
+        return {
+            k: [{mk: _maybe_serialize(mv) for mk, mv in md.items()} for md in v]
+            if k == "metric_list"
+            else _maybe_serialize(v)
+            for k, v in cfg.items()
+            if v is not None
+        }
     def serialize_function(
         self, value: Union[Callable, str], keep_callable=False
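The rewritten `to_dict` comprehension can be modelled in miniature: `None`-valued fields disappear and callables get serialised (here just by `__name__`, purely as an illustration of the shape):

```python
def _maybe_serialize(val):
    return val.__name__ if callable(val) else val  # stand-in serialiser


cfg = {"task": "toy", "doc_to_text": len, "gen_prefix": None}
dumped = {k: _maybe_serialize(v) for k, v in cfg.items() if v is not None}
assert dumped == {"task": "toy", "doc_to_text": "len"}
```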
@@ -627,7 +618,7 @@ class Task(abc.ABC):
         return doc
     @property
-    def instances(self) -> List[Instance]:
+    def instances(self) -> list[Instance]:
         """After calling `task.build_all_requests()`, tasks
         maintain a list of the dataset instances which will be evaluated.
         """
@@ -639,27 +630,27 @@
         return rnd.sample(self._training_docs, k)
-    def doc_to_decontamination_query(self, doc):
+    def doc_to_decontamination_query(self, doc: dict):
         raise NotImplementedError(
             "Override doc_to_decontamination_query with document specific decontamination query."
         )
     @abc.abstractmethod
-    def doc_to_text(self, doc) -> str:
+    def doc_to_text(self, doc: dict) -> str:
         pass
     @abc.abstractmethod
-    def doc_to_target(self, doc) -> Union[str, int]:
+    def doc_to_target(self, doc: dict) -> Union[str, int]:
         pass
     # not an abstractmethod because not every language-only task has to implement this
-    def doc_to_image(self, doc):
+    def doc_to_image(self, doc: dict):
         raise NotImplementedError
-    def doc_to_audio(self, doc):
+    def doc_to_audio(self, doc: dict):
         raise NotImplementedError
-    def doc_to_prefix(self, doc) -> str:
+    def doc_to_prefix(self, doc: dict) -> str:
         return ""
     def build_all_requests(
@@ -776,7 +767,7 @@
         save_to_cache(file_name=cache_key, obj=instances)
     @abc.abstractmethod
-    def construct_requests(self, doc, ctx, **kwargs):
+    def construct_requests(self, doc: dict, ctx: Union[list[dict], str], **kwargs):
         """Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.
pass
@abc.abstractmethod
def process_results(self, doc, results):
def process_results(self, doc: dict, results: list):
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
@@ -1446,7 +1437,7 @@ class ConfigurableTask(Task):
         """
         return doc
-    def doc_to_text(self, doc, doc_to_text=None):
+    def doc_to_text(self, doc: dict, doc_to_text: Optional[Union[int, str, Callable]] = None):
         if self.prompt is not None:
             doc_to_text = self.prompt
         elif doc_to_text is not None:
@@ -1482,7 +1473,7 @@
             print(type(doc_to_text))
             raise TypeError
-    def doc_to_target(self, doc: Mapping, doc_to_target=None) -> Union[int, str, list]:
+    def doc_to_target(self, doc: dict, doc_to_target=None) -> Union[int, str, list]:
         if self.prompt is not None:
             doc_to_target = self.prompt
         elif doc_to_target is not None:
@@ -1528,7 +1519,9 @@
         else:
             raise TypeError
-    def doc_to_choice(self, doc: Any, doc_to_choice=None) -> List[str]:
+    def doc_to_choice(
+        self, doc: dict, doc_to_choice: Optional[Union[str, list, dict]] = None
+    ) -> list[str]:
         if self.prompt is not None:
             doc_to_choice = self.prompt
         elif doc_to_choice is not None:
@@ -1554,7 +1547,7 @@
         else:
             raise TypeError
-    def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list, None]:
+    def doc_to_image(self, doc: dict, doc_to_image=None) -> Union[int, str, list, None]:
         if doc_to_image is not None:
             doc_to_image = doc_to_image
         elif self.config.doc_to_image is not None:
@@ -1600,7 +1593,7 @@
         else:
             return None
-    def doc_to_prefix(self, doc) -> Optional[str]:
+    def doc_to_prefix(self, doc: dict) -> Optional[str]:
         if (gen_prefix := self.config.gen_prefix) is not None:
             if gen_prefix in self.features:
                 return doc[gen_prefix]
@@ -1709,7 +1702,7 @@
             **kwargs,
         )
-    def process_results(self, doc, results):
+    def process_results(self, doc: dict, results: list) -> dict:
         if callable(self.config.process_results):
             return self.config.process_results(doc, results)
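When a task config supplies a callable, it short-circuits the built-in scoring, as the branch above shows. A hypothetical custom hook of the expected shape, scoring one generation against a gold answer:

```python
def process_results(doc: dict, results: list) -> dict:
    pred = results[0].strip()  # first (only) generation for this doc
    gold = doc["answer"]  # illustrative field name
    return {"exact_match": float(pred == gold)}  # named submetric -> value
```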
......