Commit db5dff9c authored by Baber

type hints

parent 023bfe0d
@@ -3,7 +3,7 @@ import hashlib
 import json
 import logging
 import os
-from typing import Dict, List, Optional, Tuple, Type, TypeVar, Union
+from typing import TYPE_CHECKING, Optional, Type, TypeVar, Union
 import transformers
 from sqlitedict import SqliteDict
@@ -12,6 +12,10 @@ from tqdm import tqdm
 from lm_eval import utils
+if TYPE_CHECKING:
+    from lm_eval.api.instance import Instance
+
 eval_logger = logging.getLogger(__name__)
 T = TypeVar("T", bound="LM")
@@ -30,7 +34,7 @@ class LM(abc.ABC):
         self.cache_hook = CacheHook(None)
     @abc.abstractmethod
-    def loglikelihood(self, requests) -> List[Tuple[float, bool]]:
+    def loglikelihood(self, requests) -> list[tuple[float, bool]]:
         """Compute log-likelihood of generating a continuation from a context.
         Downstream tasks should attempt to use loglikelihood instead of other
         LM calls whenever possible.
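Each returned `(float, bool)` pair is the continuation's log-likelihood and a greedy-match flag. A hedged sketch of how a downstream multiple-choice task typically consumes these results (the `lm` instance and `requests` list are assumed, not defined in this hunk):

```python
# Each Instance's args hold a (context, continuation) pair; the result at the
# same index is (log_likelihood, is_greedy).
results = lm.loglikelihood(requests)
log_likelihoods = [ll for ll, _is_greedy in results]
best_choice_idx = max(range(len(results)), key=lambda i: log_likelihoods[i])
```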
@@ -55,7 +59,7 @@ class LM(abc.ABC):
         pass
     @abc.abstractmethod
-    def loglikelihood_rolling(self, requests) -> List[float]:
+    def loglikelihood_rolling(self, requests) -> list[float]:
         """Compute full log-likelihood of a string, with no truncation, for perplexity computation
         - We will use the full max context length of the model.
         - For inputs that exceed the max context length, we divide the tokenized string into chunks of up to
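The chunking the docstring describes can be pictured with a deliberately simplified sketch; the real windowing keeps extra left context per chunk, this only shows the splitting idea:

```python
# Simplified, non-overlapping illustration: split a long token sequence into
# windows of at most max_length so every token is scored exactly once.
def make_windows(token_ids: list[int], max_length: int) -> list[list[int]]:
    return [token_ids[i : i + max_length] for i in range(0, len(token_ids), max_length)]

# loglikelihood_rolling then sums the per-window log-likelihoods, e.g.:
# total_ll = sum(score_window(window) for window in make_windows(ids, max_len))
```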
@@ -97,7 +101,7 @@ class LM(abc.ABC):
     # TODO: Add an optional max length
     @abc.abstractmethod
-    def generate_until(self, requests) -> List[str]:
+    def generate_until(self, requests) -> list[str]:
         """Generate greedily until a stopping sequence
         :param requests: list[Instance]
@@ -114,7 +118,7 @@ class LM(abc.ABC):
         pass
     def apply_chat_template(
-        self, chat_history: List[Dict[str, str]], add_generation_prompt=True
+        self, chat_history: list[dict[str, str]], add_generation_prompt=True
     ) -> str:
         """
         Defines how to transform few-shot examples provided as chat history into a format that can be used as input to the LM.
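As a reminder of the expected shape of `chat_history` (standard role/content messages; the concrete strings below are illustrative, not from this commit):

```python
chat_history: list[dict[str, str]] = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Question: 2 + 2 = ?"},
    {"role": "assistant", "content": "4"},
]
# apply_chat_template(chat_history, add_generation_prompt=True) returns a single
# formatted prompt string ready to be fed to the model.
```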
@@ -173,14 +177,14 @@ class LM(abc.ABC):
         return cls(**arg_dict, **additional_config)
     @property
-    def rank(self):
+    def rank(self) -> int:
         # used in the case of parallelism. Hardcoded to
         # ensure no errors arise using API models which do
         # not support multi-device parallelism nor expect it.
         return self._rank
     @property
-    def world_size(self):
+    def world_size(self) -> int:
         # used in the case of parallelism. Hardcoded to
         # ensure no errors arise using API models which do
         # not support multi-device parallelism nor expect it.
@@ -230,7 +234,7 @@ class CacheHook:
 class CachingLM:
-    def __init__(self, lm, cache_db) -> None:
+    def __init__(self, lm: "LM", cache_db: str) -> None:
         """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.
         :param lm: LM
@@ -253,7 +257,7 @@ class CachingLM:
             eval_logger.debug(f"Passing through attribute '{attr}' to underlying LM")
             return lm_attr
-        def fn(requests):
+        def fn(requests: list[Instance]) -> list[Instance]:
             res = []
             remaining_reqs = []
             warned = False
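For reference, a hedged sketch of how this wrapper is typically applied; the model instance and cache path are hypothetical, only `CachingLM(lm, cache_db)` itself comes from this file:

```python
from lm_eval.api.model import CachingLM

# `my_lm` is assumed to be any constructed LM subclass instance.
cached_lm = CachingLM(my_lm, "lm_cache/my_model_rank0.db")

# Attribute access falls through to the wrapped LM; request-type methods such as
# loglikelihood / generate_until first consult the SqliteDict cache and only send
# cache misses to the underlying model.
```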
@@ -322,28 +326,35 @@ class TemplateLM(LM):
     @property
     @abc.abstractmethod
-    def eot_token_id(self):
+    def eot_token_id(self) -> int:
         pass
     @property
-    def prefix_token_id(self):
+    def prefix_token_id(self) -> int:
         # it is used as prefix for loglikelihood
         return self.eot_token_id
     @abc.abstractmethod
-    def tok_encode(self, string: str, **kwargs) -> List[int]:
+    def tok_encode(self, string: str, **kwargs) -> list[int]:
         """
         Tokenize a string using the model's tokenizer and return a list of token IDs.
         """
         pass
     @abc.abstractmethod
-    def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
+    def _loglikelihood_tokens(
+        self, requests: list[Instance], **kwargs
+    ) -> list[tuple[float, bool]]:
         pass
     def _encode_pair(
         self, context: str, continuation: str
-    ) -> Tuple[List[int], List[int]]:
+    ) -> tuple[list[int], list[int]]:
+        """Encodes a pair of context and continuation strings into token IDs.
+
+        Ensures that encode(context + continuation) == encode(context) + encode(continuation)
+        """
         n_spaces = len(context) - len(context.rstrip())
         if n_spaces > 0:
             continuation = context[-n_spaces:] + continuation
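A small illustration of why the trailing-whitespace shuffle matters for the invariant stated in the new docstring (the example strings are made up; the reassignment logic is the code being diffed):

```python
context, continuation = "Question: 2 + 2 = ", "4"
# Many BPE tokenizers merge a trailing space into the following token, so
# encode("... = ") + encode("4") can differ from encode("... = 4").
# _encode_pair therefore moves trailing whitespace from the context onto the
# continuation before encoding each half separately:
n_spaces = len(context) - len(context.rstrip())
if n_spaces > 0:
    continuation = context[-n_spaces:] + continuation  # " 4"
    context = context[:-n_spaces]                       # "Question: 2 + 2 ="
```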
@@ -364,8 +375,8 @@ class TemplateLM(LM):
         return context_enc, continuation_enc
     def loglikelihood(
-        self, requests, disable_tqdm: bool = False
-    ) -> List[Tuple[float, bool]]:
+        self, requests: list[Instance], disable_tqdm: bool = False
+    ) -> list[tuple[float, bool]]:
         new_reqs = []
         for context, continuation in [req.args for req in requests]:
             if context == "":
@@ -384,15 +395,16 @@ class TemplateLM(LM):
     @abc.abstractmethod
     def loglikelihood_rolling(
         self, requests, disable_tqdm: bool = False
-    ) -> List[float]:
+    ) -> list[float]:
         pass
     @abc.abstractmethod
-    def generate_until(self, requests, disable_tqdm: bool = False) -> List[str]:
+    def generate_until(self, requests, disable_tqdm: bool = False) -> list[str]:
         pass
     def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
         """
+        Assumes tokenizer has a chat_template attribute (self.tokenizer.chat_template: dict | str)
         Set and get the appropriate chat template for the model.
         This method sets the tokenizer's chat_template and returns the template string for reproducibility.
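The added docstring line spells out the assumption behind the `Union[bool, str]` argument; roughly, the resolution it implies looks like this (hedged sketch only, not the method body):

```python
# self.tokenizer.chat_template may be either a single Jinja template string or a
# dict mapping template names to Jinja strings (newer transformers tokenizers).
template = tokenizer.chat_template
if isinstance(template, dict) and isinstance(chat_template, str):
    template = template.get(chat_template)  # a string argument picks a named template
```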
@@ -8,6 +8,10 @@ if TYPE_CHECKING:
 eval_logger = logging.getLogger(__name__)
 MODEL_REGISTRY = {}
+DEFAULTS = {
+    "model": {"max_length": 2048},
+    "tasks": {"generate_until": {"max_length": 2048}},
+}
 def register_model(*names):
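A quick, hedged illustration of how these registry-level defaults could be consumed; the helper function is hypothetical, only the `DEFAULTS` dict itself is added by this commit:

```python
def resolve_max_length(model_kwargs: dict) -> int:
    # Fall back to the registry default when the model config does not set one.
    return model_kwargs.get("max_length", DEFAULTS["model"]["max_length"])

resolve_max_length({})                    # -> 2048
resolve_max_length({"max_length": 4096})  # -> 4096
```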
@@ -167,7 +171,7 @@ def get_metric_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callabl
         eval_logger.warning(f"{name} metric is not assigned a default aggregation!")
-def is_higher_better(metric_name) -> Optional[bool]:
+def is_higher_better(metric_name: str) -> Optional[bool]:
     try:
         return HIGHER_IS_BETTER_REGISTRY[metric_name]
     except KeyError:
@@ -176,7 +180,7 @@ def is_higher_better(metric_name) -> Optional[bool]:
         )
-def register_filter(name):
+def register_filter(name: str):
     def decorate(cls):
         if name in FILTER_REGISTRY:
             eval_logger.info(
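For orientation, a hedged sketch of how the decorator is used; the `LowercaseFilter` class and its behaviour are illustrative, not part of this commit:

```python
from lm_eval.api.filter import Filter
from lm_eval.api.registry import register_filter


@register_filter("lowercase")
class LowercaseFilter(Filter):
    """Toy filter: lower-cases every model response before scoring."""

    def apply(self, resps, docs):
        # resps is an iterable of per-document response lists.
        return [[r.lower() for r in resp] for resp in resps]
```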
@@ -9,6 +9,7 @@ from dataclasses import asdict, dataclass, field
 from functools import cached_property
 from inspect import getsource
 from typing import (
+    TYPE_CHECKING,
     Any,
     Dict,
     Iterable,
@@ -50,6 +51,10 @@ ALL_OUTPUT_TYPES = [
     "generate_until",
 ]
+if TYPE_CHECKING:
+    from lm_eval.api.filter import FilterEnsemble
+
 eval_logger = logging.getLogger(__name__)
@@ -81,7 +86,7 @@ class MetricConfig:
             return is_higher_better(self.name)
         return self.higher_is_better
-    def calculate_metric(self, *args, **kwargs) -> Any:
+    def compute_metric(self, *args, **kwargs) -> Any:
         """Calculates the metric using the provided function and arguments."""
         if self.fn is None:
             raise ValueError(f"Metric function for {self.name} is not defined.")
@@ -99,7 +104,7 @@ class RepeatConfig:
     """Encapsulates information about a single repeat."""
     repeats: int = 1
-    metric_fn: Optional[Callable] = None
+    metric_fn: Optional[Callable] = "pass@N"
     kwargs: Optional[dict] = None
@@ -246,15 +251,15 @@ class TaskConfig(dict):
     output_type: OutputType = "generate_until"
     generation_kwargs: Optional[dict] = None
     repeats: int = 1
-    filter_list: Optional[Union[str, list]] = None
+    filter_list: Optional[list[dict]] = None
     should_decontaminate: bool = False
     doc_to_decontamination_query: Optional[str] = None
     gen_prefix: Optional[str] = None
     metadata: Optional[dict] = (
         None  # by default, not used in the code. allows for users to pass arbitrary info to tasks
     )
-    _metric_list = None
-    _filter_list = None
+    _metric_list: list[MetricConfig] = None
+    _filter_list: list[FilterConfig] = None
     def __post_init__(self) -> None:
         if self.generation_kwargs is not None:
@@ -289,16 +294,13 @@ class TaskConfig(dict):
                 f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}"
             )
-        if self.metric_list is not None:
-            for metric_config in self.metric_list:
-                if "metric" not in metric_config:
-                    raise ValueError(
-                        "'metric' key not provided for an entry in 'metric_list', must be specified!"
-                    )
+        if self.metric_list and not all("metric" in cfg for cfg in self.metric_list):
+            raise ValueError("each entry in metric_list must include a 'metric' key")
     def get_metrics(self) -> list["MetricConfig"]:
         metrics = []
         if self.metric_list is None:
+            # ---------- 1. If no metrics defined, use defaults for output type ----------
            _metric_list = DEFAULT_METRIC_REGISTRY[self.output_type]
            eval_logger.info(
                f"No metrics defined in config, using default metrics for {self.output_type}={_metric_list}"
@@ -313,11 +315,8 @@ class TaskConfig(dict):
                 for metric_name in _metric_list
             )
         else:
+            # ---------- 2. How will the samples be evaluated ----------
             for metric_config in self.metric_list:
-                if "metric" not in metric_config:
-                    raise ValueError(
-                        "'metric' key not provided for an entry in 'metric_list', must be specified!"
-                    )
                 metric_name = metric_config["metric"]
                 _metric_fn_kwargs = {
                     key: metric_config[key]
@@ -379,34 +378,30 @@ class TaskConfig(dict):
             )
         return metrics
-    def get_filters(self):
-        if self.filter_list is not None:
-            _filter_list = []
-            if isinstance(self.filter_list, dict):
-                for filter_config in self.filter_list:
-                    _filter_list.append(
-                        build_filter_ensemble(
-                            filter_name=filter_config["name"],
-                            components=[
-                                [
-                                    {
-                                        key: function[key]
-                                        for key in function
-                                        if key != "function"
-                                    }
-                                ]
-                                for function in filter_config["filter"]
-                            ],
-                        )
-                    )
-        else:
-            # TODO: handle repeats in a more general way rather than just discarding
-            eval_logger.debug(
-                "No custom filters defined. Using default 'take_first' filter for handling repeats."
-            )
-            _filter_list = [build_filter_ensemble("none", [["take_first", None]])]
-        return _filter_list
+    def get_filters(self) -> list["FilterEnsemble"]:
+        if not self.filter_list:
+            eval_logger.debug(
+                "No custom filters defined; falling back to 'take_first' for handling repeats."
+            )
+            return [build_filter_ensemble("none", [["take_first", None]])]
+        else:
+            def _strip_fn(d: dict) -> dict:
+                return {k: v for k, v in d.items() if k != "function"}
+
+            configs = (
+                self.filter_list.values()
+                if isinstance(self.filter_list, dict)
+                else self.filter_list
+            )
+            return [
+                build_filter_ensemble(
+                    filter_name=cfg["name"],
+                    components=[[_strip_fn(f) for f in cfg["filter"]]],
+                )
+                for cfg in configs
+            ]
     def __getitem__(self, item):
         return getattr(self, item)
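A hedged example of the `filter_list` shape the rewritten `get_filters` consumes (filter names and kwargs are illustrative; both the list form and the dict-of-configs form handled in the code are accepted):

```python
filter_list = [
    {
        "name": "strict-match",
        "filter": [
            # each entry names a filter via the "function" key; the remaining
            # keys are that filter's kwargs (the "function" key itself is
            # stripped by _strip_fn before the components are assembled)
            {"function": "regex", "regex_pattern": r"(-?\d+)"},
            {"function": "take_first"},
        ],
    }
]
# get_filters() builds one FilterEnsemble named "strict-match" from these entries.
```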
@@ -415,31 +410,27 @@ class TaskConfig(dict):
         return setattr(self, item, value)
     def to_dict(self, keep_callable: bool = False) -> dict:
-        """dumps the current config as a dictionary object, as a printable format.
-        null fields will not be printed.
-        Used for dumping results alongside full task configuration
+        """Return a printable dict with Nones stripped and callables serialised.
         :return: dict
             A printable dictionary version of the TaskConfig object.
-
-        # TODO: should any default value in the TaskConfig not be printed?
         """
-        cfg_dict = asdict(self)
-        # remove values that are `None`
-        for k, v in list(cfg_dict.items()):
-            if v is None:
-                cfg_dict.pop(k)
-            elif k == "metric_list":
-                for metric_dict in v:
-                    for metric_key, metric_value in metric_dict.items():
-                        if callable(metric_value):
-                            metric_dict[metric_key] = self.serialize_function(
-                                metric_value, keep_callable=keep_callable
-                            )
-                cfg_dict[k] = v
-            elif callable(v):
-                cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable)
-        return cfg_dict
+
+        def _maybe_serialize(val):
+            return (
+                self.serialize_function(val, keep_callable=keep_callable)
+                if callable(val)
+                else val
+            )
+
+        cfg = asdict(self)
+        return {
+            k: [{mk: _maybe_serialize(mv) for mk, mv in md.items()} for md in v]
+            if k == "metric_list"
+            else _maybe_serialize(v)
+            for k, v in cfg.items()
+            if v is not None
+        }
     def serialize_function(
         self, value: Union[Callable, str], keep_callable=False
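For intuition, a hedged before/after of what the rewritten `to_dict` produces; the field values are invented for the example:

```python
# Suppose a TaskConfig has: task="demo", doc_to_text=<function>, metadata=None.
cfg.to_dict()
# -> {"task": "demo", "doc_to_text": "def doc_to_text(doc): ...", ...}
#    * metadata is dropped because it is None
#    * the callable doc_to_text is serialised via serialize_function (its source
#      text by default, or kept as a callable when keep_callable=True)
```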
@@ -627,7 +618,7 @@ class Task(abc.ABC):
         return doc
     @property
-    def instances(self) -> List[Instance]:
+    def instances(self) -> list[Instance]:
         """After calling `task.build_all_requests()`, tasks
         maintain a list of the dataset instances which will be evaluated.
         """
@@ -639,27 +630,27 @@ class Task(abc.ABC):
         return rnd.sample(self._training_docs, k)
-    def doc_to_decontamination_query(self, doc):
+    def doc_to_decontamination_query(self, doc: dict):
         raise NotImplementedError(
             "Override doc_to_decontamination_query with document specific decontamination query."
         )
     @abc.abstractmethod
-    def doc_to_text(self, doc) -> str:
+    def doc_to_text(self, doc: dict) -> str:
         pass
     @abc.abstractmethod
-    def doc_to_target(self, doc) -> Union[str, int]:
+    def doc_to_target(self, doc: dict) -> Union[str, int]:
         pass
     # not an abstractmethod because not every language-only task has to implement this
-    def doc_to_image(self, doc):
+    def doc_to_image(self, doc: dict):
         raise NotImplementedError
-    def doc_to_audio(self, doc):
+    def doc_to_audio(self, doc: dict):
         raise NotImplementedError
-    def doc_to_prefix(self, doc) -> str:
+    def doc_to_prefix(self, doc: dict) -> str:
         return ""
     def build_all_requests(
@@ -776,7 +767,7 @@ class Task(abc.ABC):
             save_to_cache(file_name=cache_key, obj=instances)
     @abc.abstractmethod
-    def construct_requests(self, doc, ctx, **kwargs):
+    def construct_requests(self, doc: dict, ctx: Union[list[dict], str], **kwargs):
         """Uses RequestFactory to construct Requests and returns an iterable of
         Requests which will be sent to the LM.
@@ -797,7 +788,7 @@ class Task(abc.ABC):
         pass
     @abc.abstractmethod
-    def process_results(self, doc, results):
+    def process_results(self, doc: dict, results: list):
         """Take a single document and the LM results and evaluates, returning a
         dict where keys are the names of submetrics and values are the values of
         the metric for that one document
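As the docstring says, the return value maps sub-metric names to per-document scores. A hedged sketch of a typical override; the metric name, the `doc["answer"]` field, and the scoring rule are illustrative, not from this commit:

```python
def process_results(self, doc: dict, results: list) -> dict:
    # `results` holds the filtered LM outputs for this document, in request order.
    prediction = results[0].strip()
    return {
        "exact_match": float(prediction == doc["answer"]),
    }
```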
@@ -1446,7 +1437,7 @@ class ConfigurableTask(Task):
         """
         return doc
-    def doc_to_text(self, doc, doc_to_text=None):
+    def doc_to_text(self, doc: dict, doc_to_text: Optional[int, str, Callable] = None):
         if self.prompt is not None:
             doc_to_text = self.prompt
         elif doc_to_text is not None:
@@ -1482,7 +1473,7 @@ class ConfigurableTask(Task):
             print(type(doc_to_text))
             raise TypeError
-    def doc_to_target(self, doc: Mapping, doc_to_target=None) -> Union[int, str, list]:
+    def doc_to_target(self, doc: dict, doc_to_target=None) -> Union[int, str, list]:
         if self.prompt is not None:
             doc_to_target = self.prompt
         elif doc_to_target is not None:
@@ -1528,7 +1519,9 @@ class ConfigurableTask(Task):
         else:
             raise TypeError
-    def doc_to_choice(self, doc: Any, doc_to_choice=None) -> List[str]:
+    def doc_to_choice(
+        self, doc: dict, doc_to_choice: Union[str, list, dict] = None
+    ) -> List[str]:
         if self.prompt is not None:
             doc_to_choice = self.prompt
         elif doc_to_choice is not None:
@@ -1554,7 +1547,7 @@ class ConfigurableTask(Task):
         else:
             raise TypeError
-    def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list, None]:
+    def doc_to_image(self, doc: dict, doc_to_image=None) -> Union[int, str, list, None]:
         if doc_to_image is not None:
             doc_to_image = doc_to_image
         elif self.config.doc_to_image is not None:
@@ -1600,7 +1593,7 @@ class ConfigurableTask(Task):
         else:
             return None
-    def doc_to_prefix(self, doc) -> Optional[str]:
+    def doc_to_prefix(self, doc: dict) -> Optional[str]:
         if (gen_prefix := self.config.gen_prefix) is not None:
             if gen_prefix in self.features:
                 return doc[gen_prefix]
@@ -1709,7 +1702,7 @@ class ConfigurableTask(Task):
             **kwargs,
         )
-    def process_results(self, doc, results):
+    def process_results(self, doc: dict, results: list) -> dict:
         if callable(self.config.process_results):
             return self.config.process_results(doc, results)