Commit 04e74420 authored by Baber
Browse files

cleanup

parent b0173d57
...@@ -176,14 +176,14 @@ class LM(abc.ABC): ...@@ -176,14 +176,14 @@ class LM(abc.ABC):
return cls(**arg_dict, **additional_config) return cls(**arg_dict, **additional_config)
@property @property
def rank(self): def rank(self) -> int:
# used in the case of parallelism. Hardcoded to # used in the case of parallelism. Hardcoded to
# ensure no errors arise using API models which do # ensure no errors arise using API models which do
# not support multi-device parallelism nor expect it. # not support multi-device parallelism nor expect it.
return self._rank return self._rank
@property @property
def world_size(self): def world_size(self) -> int:
# used in the case of parallelism. Hardcoded to # used in the case of parallelism. Hardcoded to
# ensure no errors arise using API models which do # ensure no errors arise using API models which do
# not support multi-device parallelism nor expect it. # not support multi-device parallelism nor expect it.
...@@ -233,7 +233,7 @@ class CacheHook: ...@@ -233,7 +233,7 @@ class CacheHook:
class CachingLM: class CachingLM:
def __init__(self, lm: LM, cache_db: str) -> None: def __init__(self, lm: "LM", cache_db: str) -> None:
"""LM wrapper that returns cached results if they exist, and uses the underlying LM if not. """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.
:param lm: LM :param lm: LM
...@@ -327,11 +327,11 @@ class TemplateLM(LM): ...@@ -327,11 +327,11 @@ class TemplateLM(LM):
@property @property
@abc.abstractmethod @abc.abstractmethod
def eot_token_id(self): def eot_token_id(self) -> int:
pass pass
@property @property
def prefix_token_id(self): def prefix_token_id(self) -> int:
# it is used as prefix for loglikelihood # it is used as prefix for loglikelihood
return self.eot_token_id return self.eot_token_id
...@@ -351,6 +351,11 @@ class TemplateLM(LM): ...@@ -351,6 +351,11 @@ class TemplateLM(LM):
def _encode_pair( def _encode_pair(
self, context: str, continuation: str self, context: str, continuation: str
) -> tuple[list[int], list[int]]: ) -> tuple[list[int], list[int]]:
"""Encodes a pair of context and continuation strings into token IDs.
Ensures that encode(context + continuation) == encode(context) + encode(continuation)
"""
import transformers import transformers
n_spaces = len(context) - len(context.rstrip()) n_spaces = len(context) - len(context.rstrip())
...@@ -402,6 +407,7 @@ class TemplateLM(LM): ...@@ -402,6 +407,7 @@ class TemplateLM(LM):
def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]: def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
""" """
Assumes tokenizer has a chat_template attribute (self.tokenizer.chat_template: dict | str)
Set and get the appropriate chat template for the model. Set and get the appropriate chat template for the model.
This method sets the tokenizer's chat_template and returns the template string for reproducibility. This method sets the tokenizer's chat_template and returns the template string for reproducibility.
......
...@@ -8,6 +8,10 @@ if TYPE_CHECKING: ...@@ -8,6 +8,10 @@ if TYPE_CHECKING:
eval_logger = logging.getLogger(__name__) eval_logger = logging.getLogger(__name__)
MODEL_REGISTRY = {} MODEL_REGISTRY = {}
DEFAULTS = {
"model": {"max_length": 2048},
"tasks": {"generate_until": {"max_length": 2048}},
}
def register_model(*names): def register_model(*names):
...@@ -167,7 +171,7 @@ def get_metric_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callabl ...@@ -167,7 +171,7 @@ def get_metric_aggregation(name: str) -> Optional[Callable[[], Dict[str, Callabl
eval_logger.warning(f"{name} metric is not assigned a default aggregation!") eval_logger.warning(f"{name} metric is not assigned a default aggregation!")
def is_higher_better(metric_name) -> Optional[bool]: def is_higher_better(metric_name: str) -> Optional[bool]:
try: try:
return HIGHER_IS_BETTER_REGISTRY[metric_name] return HIGHER_IS_BETTER_REGISTRY[metric_name]
except KeyError: except KeyError:
...@@ -176,7 +180,7 @@ def is_higher_better(metric_name) -> Optional[bool]: ...@@ -176,7 +180,7 @@ def is_higher_better(metric_name) -> Optional[bool]:
) )
def register_filter(name): def register_filter(name: str):
def decorate(cls): def decorate(cls):
if name in FILTER_REGISTRY: if name in FILTER_REGISTRY:
eval_logger.info( eval_logger.info(
......
...@@ -9,6 +9,7 @@ from dataclasses import asdict, dataclass, field ...@@ -9,6 +9,7 @@ from dataclasses import asdict, dataclass, field
from functools import cached_property from functools import cached_property
from inspect import getsource from inspect import getsource
from typing import ( from typing import (
TYPE_CHECKING,
Any, Any,
Dict, Dict,
Iterable, Iterable,
...@@ -50,6 +51,10 @@ ALL_OUTPUT_TYPES = [ ...@@ -50,6 +51,10 @@ ALL_OUTPUT_TYPES = [
"generate_until", "generate_until",
] ]
if TYPE_CHECKING:
from lm_eval.api.filter import FilterEnsemble
eval_logger = logging.getLogger(__name__) eval_logger = logging.getLogger(__name__)
...@@ -81,7 +86,7 @@ class MetricConfig: ...@@ -81,7 +86,7 @@ class MetricConfig:
return is_higher_better(self.name) return is_higher_better(self.name)
return self.higher_is_better return self.higher_is_better
def calculate_metric(self, *args, **kwargs) -> Any: def compute_metric(self, *args, **kwargs) -> Any:
"""Calculates the metric using the provided function and arguments.""" """Calculates the metric using the provided function and arguments."""
if self.fn is None: if self.fn is None:
raise ValueError(f"Metric function for {self.name} is not defined.") raise ValueError(f"Metric function for {self.name} is not defined.")
...@@ -99,7 +104,7 @@ class RepeatConfig: ...@@ -99,7 +104,7 @@ class RepeatConfig:
"""Encapsulates information about a single repeat.""" """Encapsulates information about a single repeat."""
repeats: int = 1 repeats: int = 1
metric_fn: Optional[Callable] = None metric_fn: Optional[str, Callable] = "pass@N"
kwargs: Optional[dict] = None kwargs: Optional[dict] = None
...@@ -246,15 +251,15 @@ class TaskConfig(dict): ...@@ -246,15 +251,15 @@ class TaskConfig(dict):
output_type: OutputType = "generate_until" output_type: OutputType = "generate_until"
generation_kwargs: Optional[dict] = None generation_kwargs: Optional[dict] = None
repeats: int = 1 repeats: int = 1
filter_list: Optional[Union[str, list]] = None filter_list: Optional[list[dict]] = None
should_decontaminate: bool = False should_decontaminate: bool = False
doc_to_decontamination_query: Optional[str] = None doc_to_decontamination_query: Optional[str] = None
gen_prefix: Optional[str] = None gen_prefix: Optional[str] = None
metadata: Optional[dict] = ( metadata: Optional[dict] = (
None # by default, not used in the code. allows for users to pass arbitrary info to tasks None # by default, not used in the code. allows for users to pass arbitrary info to tasks
) )
_metric_list = None _metric_list: list[MetricConfig] = None
_filter_list = None _filter_list: list[FilterConfig] = None
def __post_init__(self) -> None: def __post_init__(self) -> None:
if self.generation_kwargs is not None: if self.generation_kwargs is not None:
...@@ -289,16 +294,13 @@ class TaskConfig(dict): ...@@ -289,16 +294,13 @@ class TaskConfig(dict):
f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}" f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}"
) )
if self.metric_list is not None: if self.metric_list and not all("metric" in cfg for cfg in self.metric_list):
for metric_config in self.metric_list: raise ValueError("each entry in metric_list must include a 'metric' key")
if "metric" not in metric_config:
raise ValueError(
"'metric' key not provided for an entry in 'metric_list', must be specified!"
)
def get_metrics(self) -> list["MetricConfig"]: def get_metrics(self) -> list["MetricConfig"]:
metrics = [] metrics = []
if self.metric_list is None: if self.metric_list is None:
# ---------- 1. If no metrics defined, use defaults for output type ----------
_metric_list = DEFAULT_METRIC_REGISTRY[self.output_type] _metric_list = DEFAULT_METRIC_REGISTRY[self.output_type]
eval_logger.info( eval_logger.info(
f"No metrics defined in config, using default metrics for {self.output_type}={_metric_list}" f"No metrics defined in config, using default metrics for {self.output_type}={_metric_list}"
...@@ -313,11 +315,8 @@ class TaskConfig(dict): ...@@ -313,11 +315,8 @@ class TaskConfig(dict):
for metric_name in _metric_list for metric_name in _metric_list
) )
else: else:
# ---------- 2. How will the samples be evaluated ----------
for metric_config in self.metric_list: for metric_config in self.metric_list:
if "metric" not in metric_config:
raise ValueError(
"'metric' key not provided for an entry in 'metric_list', must be specified!"
)
metric_name = metric_config["metric"] metric_name = metric_config["metric"]
_metric_fn_kwargs = { _metric_fn_kwargs = {
key: metric_config[key] key: metric_config[key]
...@@ -379,34 +378,30 @@ class TaskConfig(dict): ...@@ -379,34 +378,30 @@ class TaskConfig(dict):
) )
return metrics return metrics
def get_filters(self): def get_filters(self) -> list["FilterEnsemble"]:
if self.filter_list is not None: if not self.filter_list:
_filter_list = []
if isinstance(self.filter_list, dict):
for filter_config in self.filter_list:
_filter_list.append(
build_filter_ensemble(
filter_name=filter_config["name"],
components=[
[
{
key: function[key]
for key in function
if key != "function"
}
]
for function in filter_config["filter"]
],
)
)
else:
# TODO: handle repeats in a more general way rather than just discarding
eval_logger.debug( eval_logger.debug(
"No custom filters defined. Using default 'take_first' filter for handling repeats." "No custom filters defined; falling back to 'take_first' for handling repeats."
) )
_filter_list = [build_filter_ensemble("none", [["take_first", None]])] return [build_filter_ensemble("none", [["take_first", None]])]
else:
return _filter_list def _strip_fn(d: dict) -> dict:
return {k: v for k, v in d.items() if k != "function"}
configs = (
self.filter_list.values()
if isinstance(self.filter_list, dict)
else self.filter_list
)
return [
build_filter_ensemble(
filter_name=cfg["name"],
components=[[_strip_fn(f) for f in cfg["filter"]]],
)
for cfg in configs
]
def __getitem__(self, item): def __getitem__(self, item):
return getattr(self, item) return getattr(self, item)
...@@ -415,31 +410,27 @@ class TaskConfig(dict): ...@@ -415,31 +410,27 @@ class TaskConfig(dict):
return setattr(self, item, value) return setattr(self, item, value)
def to_dict(self, keep_callable: bool = False) -> dict: def to_dict(self, keep_callable: bool = False) -> dict:
"""dumps the current config as a dictionary object, as a printable format. """Return a printable dict with Nones stripped and callables serialised.
null fields will not be printed.
Used for dumping results alongside full task configuration
:return: dict :return: dict
A printable dictionary version of the TaskConfig object. A printable dictionary version of the TaskConfig object.
# TODO: should any default value in the TaskConfig not be printed?
""" """
cfg_dict = asdict(self)
# remove values that are `None` def _maybe_serialize(val):
for k, v in list(cfg_dict.items()): return (
if v is None: self.serialize_function(val, keep_callable=keep_callable)
cfg_dict.pop(k) if callable(val)
elif k == "metric_list": else val
for metric_dict in v: )
for metric_key, metric_value in metric_dict.items():
if callable(metric_value): cfg = asdict(self)
metric_dict[metric_key] = self.serialize_function( return {
metric_value, keep_callable=keep_callable k: [{mk: _maybe_serialize(mv) for mk, mv in md.items()} for md in v]
) if k == "metric_list"
cfg_dict[k] = v else _maybe_serialize(v)
elif callable(v): for k, v in cfg.items()
cfg_dict[k] = self.serialize_function(v, keep_callable=keep_callable) if v is not None
return cfg_dict }
def serialize_function( def serialize_function(
self, value: Union[Callable, str], keep_callable=False self, value: Union[Callable, str], keep_callable=False
...@@ -627,7 +618,7 @@ class Task(abc.ABC): ...@@ -627,7 +618,7 @@ class Task(abc.ABC):
return doc return doc
@property @property
def instances(self) -> List[Instance]: def instances(self) -> list[Instance]:
"""After calling `task.build_all_requests()`, tasks """After calling `task.build_all_requests()`, tasks
maintain a list of the dataset instances which will be evaluated. maintain a list of the dataset instances which will be evaluated.
""" """
...@@ -639,27 +630,27 @@ class Task(abc.ABC): ...@@ -639,27 +630,27 @@ class Task(abc.ABC):
return rnd.sample(self._training_docs, k) return rnd.sample(self._training_docs, k)
def doc_to_decontamination_query(self, doc): def doc_to_decontamination_query(self, doc: dict):
raise NotImplementedError( raise NotImplementedError(
"Override doc_to_decontamination_query with document specific decontamination query." "Override doc_to_decontamination_query with document specific decontamination query."
) )
@abc.abstractmethod @abc.abstractmethod
def doc_to_text(self, doc) -> str: def doc_to_text(self, doc: dict) -> str:
pass pass
@abc.abstractmethod @abc.abstractmethod
def doc_to_target(self, doc) -> Union[str, int]: def doc_to_target(self, doc: dict) -> Union[str, int]:
pass pass
# not an abstractmethod because not every language-only task has to implement this # not an abstractmethod because not every language-only task has to implement this
def doc_to_image(self, doc): def doc_to_image(self, doc: dict):
raise NotImplementedError raise NotImplementedError
def doc_to_audio(self, doc): def doc_to_audio(self, doc: dict):
raise NotImplementedError raise NotImplementedError
def doc_to_prefix(self, doc) -> str: def doc_to_prefix(self, doc: dict) -> str:
return "" return ""
def build_all_requests( def build_all_requests(
...@@ -776,7 +767,7 @@ class Task(abc.ABC): ...@@ -776,7 +767,7 @@ class Task(abc.ABC):
save_to_cache(file_name=cache_key, obj=instances) save_to_cache(file_name=cache_key, obj=instances)
@abc.abstractmethod @abc.abstractmethod
def construct_requests(self, doc, ctx, **kwargs): def construct_requests(self, doc: dict, ctx: Union[list[dict], str], **kwargs):
"""Uses RequestFactory to construct Requests and returns an iterable of """Uses RequestFactory to construct Requests and returns an iterable of
Requests which will be sent to the LM. Requests which will be sent to the LM.
...@@ -797,7 +788,7 @@ class Task(abc.ABC): ...@@ -797,7 +788,7 @@ class Task(abc.ABC):
pass pass
@abc.abstractmethod @abc.abstractmethod
def process_results(self, doc, results): def process_results(self, doc: dict, results: list):
"""Take a single document and the LM results and evaluates, returning a """Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of dict where keys are the names of submetrics and values are the values of
the metric for that one document the metric for that one document
...@@ -1450,7 +1441,7 @@ class ConfigurableTask(Task): ...@@ -1450,7 +1441,7 @@ class ConfigurableTask(Task):
""" """
return doc return doc
def doc_to_text(self, doc, doc_to_text=None): def doc_to_text(self, doc: dict, doc_to_text: Optional[int, str, Callable] = None):
if self.prompt is not None: if self.prompt is not None:
doc_to_text = self.prompt doc_to_text = self.prompt
elif doc_to_text is not None: elif doc_to_text is not None:
...@@ -1486,7 +1477,7 @@ class ConfigurableTask(Task): ...@@ -1486,7 +1477,7 @@ class ConfigurableTask(Task):
print(type(doc_to_text)) print(type(doc_to_text))
raise TypeError raise TypeError
def doc_to_target(self, doc: Mapping, doc_to_target=None) -> Union[int, str, list]: def doc_to_target(self, doc: dict, doc_to_target=None) -> Union[int, str, list]:
if self.prompt is not None: if self.prompt is not None:
doc_to_target = self.prompt doc_to_target = self.prompt
elif doc_to_target is not None: elif doc_to_target is not None:
...@@ -1532,7 +1523,9 @@ class ConfigurableTask(Task): ...@@ -1532,7 +1523,9 @@ class ConfigurableTask(Task):
else: else:
raise TypeError raise TypeError
def doc_to_choice(self, doc: Any, doc_to_choice=None) -> List[str]: def doc_to_choice(
self, doc: dict, doc_to_choice: Union[str, list, dict] = None
) -> List[str]:
if self.prompt is not None: if self.prompt is not None:
doc_to_choice = self.prompt doc_to_choice = self.prompt
elif doc_to_choice is not None: elif doc_to_choice is not None:
...@@ -1558,7 +1551,7 @@ class ConfigurableTask(Task): ...@@ -1558,7 +1551,7 @@ class ConfigurableTask(Task):
else: else:
raise TypeError raise TypeError
def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list, None]: def doc_to_image(self, doc: dict, doc_to_image=None) -> Union[int, str, list, None]:
if doc_to_image is not None: if doc_to_image is not None:
doc_to_image = doc_to_image doc_to_image = doc_to_image
elif self.config.doc_to_image is not None: elif self.config.doc_to_image is not None:
...@@ -1604,7 +1597,7 @@ class ConfigurableTask(Task): ...@@ -1604,7 +1597,7 @@ class ConfigurableTask(Task):
else: else:
return None return None
def doc_to_prefix(self, doc) -> Optional[str]: def doc_to_prefix(self, doc: dict) -> Optional[str]:
if (gen_prefix := self.config.gen_prefix) is not None: if (gen_prefix := self.config.gen_prefix) is not None:
if gen_prefix in self.features: if gen_prefix in self.features:
return doc[gen_prefix] return doc[gen_prefix]
...@@ -1713,7 +1706,7 @@ class ConfigurableTask(Task): ...@@ -1713,7 +1706,7 @@ class ConfigurableTask(Task):
**kwargs, **kwargs,
) )
def process_results(self, doc, results): def process_results(self, doc: dict, results: list) -> dict:
if callable(self.config.process_results): if callable(self.config.process_results):
return self.config.process_results(doc, results) return self.config.process_results(doc, results)
......
...@@ -10,12 +10,13 @@ class DecontaminationFilter(Filter): ...@@ -10,12 +10,13 @@ class DecontaminationFilter(Filter):
name = "track_decontamination" name = "track_decontamination"
def __init__(self, path) -> None: def __init__(self, path, **kwargs) -> None:
""" """
TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path"). TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path").
should further cache result on a given (task_name, doc_id) should further cache result on a given (task_name, doc_id)
""" """
super().__init__(**kwargs)
self._decontam_results = None self._decontam_results = None
def apply(self, resps, docs) -> None: def apply(self, resps, docs) -> None:
......
...@@ -20,11 +20,13 @@ class RegexFilter(Filter): ...@@ -20,11 +20,13 @@ class RegexFilter(Filter):
regex_pattern: str = r"#### (\-?[0-9\.\,]+)", regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
group_select: int = 0, group_select: int = 0,
fallback: str = "[invalid]", fallback: str = "[invalid]",
**kwargs,
) -> None: ) -> None:
""" """
pass a string `regex` to run `re.compile(r"regex")` on. pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located. `fallback` defines the output returned if no matches for the regex are located.
""" """
super().__init__(**kwargs)
self.regex_pattern = regex_pattern self.regex_pattern = regex_pattern
self.regex = re.compile(regex_pattern) self.regex = re.compile(regex_pattern)
self.group_select = group_select self.group_select = group_select
...@@ -66,11 +68,13 @@ class POSFilter(Filter): ...@@ -66,11 +68,13 @@ class POSFilter(Filter):
regex_pattern: str = r"\['(.*?)'\]", regex_pattern: str = r"\['(.*?)'\]",
group_select=0, group_select=0,
fallback=None, fallback=None,
**kwargs,
) -> None: ) -> None:
""" """
pass a string `regex` to run `re.compile(r"regex")` on. pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located. `fallback` defines the output returned if no matches for the regex are located.
""" """
super().__init__(**kwargs)
if fallback is None: if fallback is None:
fallback = ["invalid"] fallback = ["invalid"]
self.regex_pattern = regex_pattern self.regex_pattern = regex_pattern
......
...@@ -27,7 +27,6 @@ class TakeFirstFilter(Filter): ...@@ -27,7 +27,6 @@ class TakeFirstFilter(Filter):
class TakeKFilter(Filter): class TakeKFilter(Filter):
def __init__(self, **kwargs) -> None: def __init__(self, **kwargs) -> None:
self.k = kwargs.pop("k") self.k = kwargs.pop("k")
super().__init__(**kwargs) super().__init__(**kwargs)
def apply(self, resps, docs): def apply(self, resps, docs):
......
...@@ -6,9 +6,6 @@ from lm_eval.api.registry import register_filter ...@@ -6,9 +6,6 @@ from lm_eval.api.registry import register_filter
@register_filter("lowercase") @register_filter("lowercase")
class LowercaseFilter(Filter): class LowercaseFilter(Filter):
def __init__(self) -> None:
pass
def apply(self, resps, docs): def apply(self, resps, docs):
def filter_set(inst): def filter_set(inst):
return [resp.lower() for resp in inst] return [resp.lower() for resp in inst]
...@@ -18,9 +15,6 @@ class LowercaseFilter(Filter): ...@@ -18,9 +15,6 @@ class LowercaseFilter(Filter):
@register_filter("uppercase") @register_filter("uppercase")
class UppercaseFilter(Filter): class UppercaseFilter(Filter):
def __init__(self) -> None:
pass
def apply(self, resps, docs): def apply(self, resps, docs):
def filter_set(inst): def filter_set(inst):
return [resp.upper() for resp in inst] return [resp.upper() for resp in inst]
...@@ -31,6 +25,7 @@ class UppercaseFilter(Filter): ...@@ -31,6 +25,7 @@ class UppercaseFilter(Filter):
@register_filter("map") @register_filter("map")
class MapFilter(Filter): class MapFilter(Filter):
def __init__(self, mapping_dict: dict = None, default_value=None) -> None: def __init__(self, mapping_dict: dict = None, default_value=None) -> None:
super().__init__()
""" """
Initializes the MapFilter with a given mapping dictionary and default value. Initializes the MapFilter with a given mapping dictionary and default value.
...@@ -60,9 +55,6 @@ class MapFilter(Filter): ...@@ -60,9 +55,6 @@ class MapFilter(Filter):
@register_filter("format_span") @register_filter("format_span")
class SPANFilter(Filter): class SPANFilter(Filter):
def __init__(self) -> None:
pass
def apply(self, resps, docs): def apply(self, resps, docs):
def format_ner_text(text): def format_ner_text(text):
label_dict = { label_dict = {
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment