Commit 4ad6cd9f authored by Baber

remove deps; types

parent 689e0c91
@@ -33,7 +33,7 @@ repos:
     hooks:
       # Run the linter.
       - id: ruff-check
-        args: [ --fix ]
+        args: [ --fix]
       # Run the formatter.
       - id: ruff-format
   - repo: https://github.com/codespell-project/codespell
...
+from __future__ import annotations
+
 import abc
 import hashlib
 import json
 import logging
 import os
-from typing import TYPE_CHECKING, Any, Iterable, Optional, Type, TypeVar, Union
+from collections.abc import Iterable
+from typing import TYPE_CHECKING, Any, TypeVar

 from tqdm import tqdm
@@ -31,7 +34,7 @@ class LM(abc.ABC):
         # set rank and world size to a single process, by default.
         self._rank = 0
         self._world_size = 1
-        self.cache_hook: "CacheHook" = CacheHook(None)
+        self.cache_hook: CacheHook = CacheHook(None)

     @abc.abstractmethod
     def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
@@ -101,7 +104,7 @@ class LM(abc.ABC):
     # TODO: Add an optional max length
     @abc.abstractmethod
-    def generate_until(self, requests: list["Instance"]) -> list[str]:
+    def generate_until(self, requests: list[Instance]) -> list[str]:
         """Generate greedily until a stopping sequence

         :param requests: list[Instance]
@@ -137,7 +140,7 @@ class LM(abc.ABC):
     @classmethod
     def create_from_arg_string(
-        cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
+        cls: type[T], arg_string: str, additional_config: dict | None = None
     ) -> T:
         """
         Creates an instance of the LM class using the given argument string and additional config.
@@ -156,7 +159,7 @@ class LM(abc.ABC):
     @classmethod
     def create_from_arg_obj(
-        cls: Type[T], arg_dict: dict, additional_config: Optional[dict] = None
+        cls: type[T], arg_dict: dict, additional_config: dict | None = None
     ) -> T:
         """
         Creates an instance of the LM class using the given arg_obj
@@ -201,7 +204,7 @@ class LM(abc.ABC):
             "To use this model with chat templates, please implement the 'tokenizer_name' property."
         )

-    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
+    def chat_template(self, chat_template: bool | str = False) -> str | None:
         """Returns the chat template structure for user/assistant messages if a template is provided.
         This method is intended to be overridden in a subclass to define a specific chat template format.
         For models that do not support chat templates, this method returns None by default.
@@ -209,7 +212,7 @@ class LM(abc.ABC):
         return ""

-    def set_cache_hook(self, cache_hook: "CacheHook") -> None:
+    def set_cache_hook(self, cache_hook: CacheHook) -> None:
         """Sets the cache hook for the LM, which is used to cache responses from the LM."""
         self.cache_hook = cache_hook
@@ -221,10 +224,10 @@ def hash_args(attr: str, args: Iterable[Any]) -> str:
 class CacheHook:
-    def __init__(self, cachinglm: Optional["CachingLM"]) -> None:
+    def __init__(self, cachinglm: CachingLM | None) -> None:
         """CacheHook is used to cache responses from the LM."""
         if cachinglm is None:
-            self.dbdict: Optional["SqliteDict"] = None
+            self.dbdict: SqliteDict | None = None
             return

         self.dbdict = cachinglm.dbdict
@@ -238,7 +241,7 @@ class CacheHook:
 class CachingLM:
-    def __init__(self, lm: "LM", cache_db: str) -> None:
+    def __init__(self, lm: LM, cache_db: str) -> None:
         """LM wrapper that returns cached results if they exist, and uses the underlying LM if not.

         :param lm: LM
@@ -263,7 +266,7 @@ class CachingLM:
             eval_logger.debug(f"Passing through attribute '{attr}' to underlying LM")
             return lm_attr

-        def _fn(requests: list["Instance"]) -> list["Instance"]:
+        def _fn(requests: list[Instance]) -> list[Instance]:
             res = []
             remaining_reqs = []
             warned = False
@@ -295,11 +298,8 @@ class CachingLM:
             eval_logger.info(
                 f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}"
             )
-            if remaining_reqs:
-                # actually run the LM on the requests that do not have cached results
-                rem_res = getattr(self.lm, attr)(remaining_reqs)
-            else:
-                rem_res = []
+            # actually run the LM on the requests that do not have cached results
+            rem_res = getattr(self.lm, attr)(remaining_reqs) if remaining_reqs else []

             # stick the new ones back into the list and also cache any of the new ones
             resptr = 0
@@ -318,7 +318,7 @@ class CachingLM:
         return _fn

-    def get_cache_hook(self) -> "CacheHook":
+    def get_cache_hook(self) -> CacheHook:
         return CacheHook(self)
@@ -399,7 +399,7 @@ class TemplateLM(LM):
         return context_enc, continuation_enc

     def loglikelihood(
-        self, requests: list["Instance"], disable_tqdm: bool = False
+        self, requests: list[Instance], disable_tqdm: bool = False
     ) -> list[tuple[float, bool]]:
         """Compute log-likelihood of generating a continuation from a context.
@@ -432,7 +432,7 @@ class TemplateLM(LM):
     @abc.abstractmethod
     def generate_until(
-        self, requests: list["Instance"], disable_tqdm: bool = False
+        self, requests: list[Instance], disable_tqdm: bool = False
     ) -> list[str]:
         """Generate until a stopping sequence.
@@ -453,7 +453,7 @@ class TemplateLM(LM):
         """
         pass

-    def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
+    def chat_template(self, chat_template: bool | str = False) -> str | None:
         """
         Assumes tokenizer has a chat_template attribute (self.tokenizer.chat_template: dict | str)
         Set and get the appropriate chat template for the model.
...
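A note on the typing hunks above: the `Optional`/`Union`/`Type` rewrites are only safe on the project's Python 3.9 floor because of the `from __future__ import annotations` line added at the top of the file. A minimal standalone sketch of why (the `Example` class is hypothetical, not from this repo):

```python
from __future__ import annotations


class Example:  # hypothetical, for illustration only
    def configure(self, options: dict | None = None) -> str | None:
        # With PEP 563 postponed evaluation, these annotations are stored as
        # strings and never evaluated, so the PEP 604 `X | None` syntax is
        # legal here even on 3.9, where `dict | None` evaluated at runtime
        # would raise TypeError (unions on types arrived in 3.10).
        return None if options is None else str(options)


print(Example.configure.__annotations__)  # raw strings: nothing was evaluated
```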
@@ -7,11 +7,7 @@ import random
 import re
 from collections.abc import Callable
 from copy import deepcopy
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Literal,
-)
+from typing import TYPE_CHECKING, Any, Literal, overload

 import datasets
 import numpy as np
@@ -24,7 +20,7 @@ from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
 from lm_eval.api.utils import check_gold_index_error
 from lm_eval.caching.cache import load_from_cache, save_to_cache
 from lm_eval.config.metric import MetricConfig
-from lm_eval.config.task import TaskConfig
+from lm_eval.config.task import DataSet, TaskConfig
 from lm_eval.filters import build_filter_ensemble
@@ -133,6 +129,7 @@ class Task(abc.ABC):
             - `datasets.DownloadMode.FORCE_REDOWNLOAD`
                 Fresh download and fresh dataset.
         """
+        assert self.DATASET_PATH is not None, "DATASET_PATH must be set in Task class"
         self.dataset = datasets.load_dataset(
             path=self.DATASET_PATH,
             name=self.DATASET_NAME,
@@ -146,43 +143,40 @@ class Task(abc.ABC):
         """Returns the TaskConfig associated with this class."""
         return self._config

-    @abc.abstractmethod
     def has_training_docs(self) -> bool:
         """Whether the task has a training set"""
-        pass
+        raise NotImplementedError

-    @abc.abstractmethod
     def has_validation_docs(self) -> bool:
         """Whether the task has a validation set"""
-        pass
+        raise NotImplementedError

-    @abc.abstractmethod
     def has_test_docs(self) -> bool:
         """Whether the task has a test set"""
-        pass
+        raise NotImplementedError

-    def training_docs(self) -> Iterable:
+    def training_docs(self) -> DataSet | None:
         """
         :return: Iterable[obj]
             A iterable of any object, that doc_to_text can handle
         """
         return []

-    def validation_docs(self) -> Iterable:
+    def validation_docs(self) -> DataSet | None:
         """
         :return: Iterable[obj]
             A iterable of any object, that doc_to_text can handle
         """
         return []

-    def test_docs(self) -> Iterable:
+    def test_docs(self) -> DataSet | None:
         """
         :return: Iterable[obj]
             A iterable of any object, that doc_to_text can handle
         """
         return []

-    def fewshot_docs(self) -> Iterable:
+    def fewshot_docs(self) -> DataSet | None:
         """
         :return: Iterable[obj]
             A iterable of any object, that doc_to_text can handle
@@ -192,7 +186,7 @@ class Task(abc.ABC):
         elif self.has_validation_docs():
             return self.validation_docs()
         else:
-            if self.config.get("num_fewshot", 0) > 0:
+            if self.config.num_fewshot and self.config.num_fewshot > 0:
                 eval_logger.warning(
                     f"[Task: {self.config.task}] has_training_docs and has_validation_docs are False"
                     ", using test_docs as fewshot_docs but this is not recommended."
@@ -331,7 +325,7 @@ class Task(abc.ABC):
             inst = self.construct_requests(
                 doc=doc,
                 ctx=fewshot_ctx,
-                metadata=(self.config["task"], doc_id, self.config.repeats),
+                metadata=(self.config.task, doc_id, self.config.repeats),
                 apply_chat_template=apply_chat_template,
                 chat_template=chat_template,
             )
@@ -586,7 +580,7 @@ class ConfigurableTask(Task):
         data_dir=None,
         cache_dir=None,
         download_mode=None,
-        config: dict | None = None,
+        config: Mapping[str, Any] | None = None,
     ) -> None:
         # Get pre-configured attributes
         self._config = self.CONFIG
@@ -727,6 +721,9 @@ class ConfigurableTask(Task):
             )
             self.dataset = df(**(self.config.dataset_kwargs | self.config.metadata))
         else:
+            assert self.config.dataset_path is not None, (
+                "dataset_path must be set in TaskConfig"
+            )
             self.dataset = datasets.load_dataset(
                 path=self.config.dataset_path,
                 name=self.config.dataset_name,
@@ -742,7 +739,7 @@ class ConfigurableTask(Task):
     def has_test_docs(self) -> bool:
         return self.config.test_split is not None

-    def training_docs(self) -> datasets.Dataset | None:
+    def training_docs(self) -> DataSet | None:
         if self.has_training_docs():
             if self.config.process_docs is not None:
                 return self.config.process_docs(
@@ -750,7 +747,7 @@ class ConfigurableTask(Task):
                 )
             return self.dataset[self.config.training_split]

-    def validation_docs(self) -> datasets.Dataset | None:
+    def validation_docs(self) -> DataSet | None:
         if self.has_validation_docs():
             if self.config.process_docs is not None:
                 return self.config.process_docs(
@@ -758,7 +755,7 @@ class ConfigurableTask(Task):
                 )
             return self.dataset[self.config.validation_split]

-    def test_docs(self) -> datasets.Dataset | None:
+    def test_docs(self) -> DataSet | None:
         if self.has_test_docs():
             if self.config.process_docs is not None:
                 return self.config.process_docs(self.dataset[self.config.test_split])
@@ -996,9 +993,21 @@ class ConfigurableTask(Task):
         """
         return doc

+    @overload
+    def doc_to_text(self, doc: dict, doc_to_text: None = None) -> str | int: ...
+    @overload
+    def doc_to_text(self, doc: dict, doc_to_text: int) -> int: ...
+    @overload
+    def doc_to_text(self, doc: dict, doc_to_text: str) -> str: ...
+    @overload
+    def doc_to_text(self, doc: dict, doc_to_text: Callable[..., str]) -> str: ...
     def doc_to_text(
         self, doc: dict, doc_to_text: int | str | Callable[..., str] | None = None
-    ) -> str:
+    ) -> str | int:
         # if self.prompt is not None:
         #     doc_to_text = self.prompt
         doc_to_text = doc_to_text or self.config.doc_to_text
@@ -1031,6 +1040,25 @@ class ConfigurableTask(Task):
             print(type(doc_to_text))
             raise TypeError

+    @overload
+    def doc_to_target(
+        self, doc: dict, doc_to_target: None = None
+    ) -> int | str | list[int]: ...
+    @overload
+    def doc_to_target(self, doc: dict, doc_to_target: int) -> int: ...
+    @overload
+    def doc_to_target(self, doc: dict, doc_to_target: str) -> int | str | list[int]: ...
+    @overload
+    def doc_to_target(self, doc: dict, doc_to_target: list) -> list[int]: ...
+    @overload
+    def doc_to_target(
+        self, doc: dict, doc_to_target: Callable[..., int | str | list[int]]
+    ) -> int | str | list[int]: ...
     def doc_to_target(self, doc: dict, doc_to_target=None) -> int | str | list[int]:
         # if self.prompt is not None:
         #     doc_to_target = self.prompt
@@ -1077,6 +1105,23 @@ class ConfigurableTask(Task):
         else:
             raise TypeError

+    @overload
+    def doc_to_choice(self, doc: dict, doc_to_choice: None = None) -> list[str]: ...
+    @overload
+    def doc_to_choice(self, doc: dict, doc_to_choice: str) -> list[str]: ...
+    @overload
+    def doc_to_choice(self, doc: dict, doc_to_choice: list) -> list[str]: ...
+    @overload
+    def doc_to_choice(self, doc: dict, doc_to_choice: dict) -> list[str]: ...
+    @overload
+    def doc_to_choice(
+        self, doc: dict, doc_to_choice: Callable[..., list[str]]
+    ) -> list[str]: ...
     def doc_to_choice(
         self,
         doc: dict,
@@ -1108,6 +1153,18 @@ class ConfigurableTask(Task):
         else:
             raise TypeError

+    @overload
+    def doc_to_image(self, doc: dict, doc_to_image: None = None) -> None: ...
+    @overload
+    def doc_to_image(self, doc: dict, doc_to_image: list) -> list: ...
+    @overload
+    def doc_to_image(self, doc: dict, doc_to_image: str) -> int | str | None: ...
+    @overload
+    def doc_to_image(self, doc: dict, doc_to_image: Callable[..., Any]) -> Any: ...
     def doc_to_image(self, doc: dict, doc_to_image=None) -> int | str | list | None:
         if doc_to_image is not None:
             doc_to_image = doc_to_image
@@ -1131,6 +1188,18 @@ class ConfigurableTask(Task):
         else:
             return None

+    @overload
+    def doc_to_audio(self, doc: Any, doc_to_audio: None = None) -> None: ...
+    @overload
+    def doc_to_audio(self, doc: Any, doc_to_audio: list) -> list: ...
+    @overload
+    def doc_to_audio(self, doc: Any, doc_to_audio: str) -> int | str | None: ...
+    @overload
+    def doc_to_audio(self, doc: Any, doc_to_audio: Callable[..., Any]) -> Any: ...
     def doc_to_audio(self, doc: Any, doc_to_audio=None) -> int | str | list | None:
         if doc_to_audio is not None:
             doc_to_audio = doc_to_audio
@@ -1375,15 +1444,15 @@ class ConfigurableTask(Task):
         elif self.OUTPUT_TYPE == "generate_until":
             gold = self.doc_to_target(doc)
             result = results[0]
-            for metric in self._metric_fn_list:
+            for metric in self.config._metric_list:
                 try:
-                    result_score = self._metric_fn_list[metric](
+                    result_score = metric.fn(
                         references=[gold] if not isinstance(gold, list) else gold,
                         predictions=[result],
-                        **self._metric_fn_kwargs[metric],
+                        **metric.kwargs,
                     )
                 except TypeError:  # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
-                    result_score = self._metric_fn_list[metric]([gold, result])
+                    result_score = metric.fn([gold, result])
                 if isinstance(result_score, dict):
                     # TODO: this handles the case where HF evaluate returns a dict.
                     # This allows for multiple metrics to be returned from the same function
...
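The `@overload` stubs added above exist purely for type checkers; at runtime only the final definition survives. A self-contained sketch of the pattern with a toy `render` function (hypothetical, not the real `doc_to_text`), showing how the stubs narrow the union return type per argument:

```python
from __future__ import annotations

from collections.abc import Callable
from typing import overload


@overload
def render(doc: dict, spec: int) -> int: ...
@overload
def render(doc: dict, spec: str) -> str: ...
@overload
def render(doc: dict, spec: Callable[[dict], str]) -> str: ...
def render(doc: dict, spec: int | str | Callable[[dict], str]) -> int | str:
    # Single runtime implementation; the stub bodies above never execute.
    if isinstance(spec, str):
        return spec.format(**doc)  # template-style field lookup
    if callable(spec):
        return spec(doc)  # user-supplied formatter
    return spec  # column index passed through unchanged


print(render({"q": "2+2?"}, "Q: {q}"))  # a checker infers str here
print(render({"q": "2+2?"}, 1))         # and int here
```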
 from __future__ import annotations

 from collections.abc import Callable, Mapping
-from dataclasses import dataclass
+from dataclasses import dataclass, field
 from functools import cached_property
 from typing import Any
@@ -11,8 +11,8 @@ class MetricConfig:
     """Encapsulates information about a single metric."""

     name: str
-    fn: Callable | None = None
-    kwargs: Mapping[str, Any] | None = None
+    fn: Callable
+    kwargs: Mapping[str, Any] = field(default_factory=dict)
     aggregation_fn: Callable | None = None
     higher_is_better: bool = True
     hf_evaluate: bool = False
...
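On the `MetricConfig` change above: `fn` loses its `None` default and `kwargs` moves to `field(default_factory=dict)`, which is required because dataclasses reject raw mutable defaults, and a single shared dict would alias state across instances. A toy sketch (the `Metric` class is a hypothetical stand-in):

```python
from collections.abc import Callable
from dataclasses import dataclass, field
from typing import Any


@dataclass
class Metric:  # hypothetical stand-in for MetricConfig
    name: str
    fn: Callable  # now required: no silent None default
    kwargs: dict[str, Any] = field(default_factory=dict)  # fresh dict per instance


a = Metric("exact_match", fn=lambda golds, preds: golds == preds)
b = Metric("bleu", fn=lambda golds, preds: 0.0)
a.kwargs["ignore_case"] = True
print(b.kwargs)  # {} -- not aliased to a.kwargs
# Note: writing `kwargs: dict = {}` instead would raise ValueError at class
# creation time, which is why default_factory is the idiomatic form.
```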
@@ -3,7 +3,9 @@ from __future__ import annotations
 import logging
 from collections.abc import Iterable
 from dataclasses import asdict, dataclass, field
-from typing import TYPE_CHECKING, Any, Callable
+from typing import TYPE_CHECKING, Any, Callable, Union
+
+import datasets

 from lm_eval.api.filter import FilterEnsemble
 from lm_eval.api.instance import OutputType
@@ -18,6 +20,9 @@ if TYPE_CHECKING:
 eval_logger = logging.getLogger(__name__)

+DataSet = Union[datasets.Dataset, Iterable[dict[str, Any]]]
+DSplits = dict[str, DataSet]

 @dataclass
 class RepeatConfig:
@@ -30,7 +35,7 @@ class RepeatConfig:
 @dataclass
 class FilterConfig:
-    """Encapsulates information about a single filter."""
+    """Encapsulates information about a single filter pipeline."""

     name: str
     ensemble: FilterEnsemble
@@ -44,10 +49,8 @@ class FewshotConfig:
     num_fewshot: Callable[[], int]
     split: str | None = None
     sampler: str | Callable = "default"
-    samples: Callable[[], list[dict]] | list[dict] | None = None
-    process_docs: Callable[[list[dict[str, Any]]], Iterable[dict[str, Any]]] | None = (
-        None
-    )
+    samples: Callable[[], DataSet] | DataSet | None = None
+    process_docs: Callable[[DataSet], DataSet] | None = None
     fewshot_indices: list[int] | None = None
     rnd: int = field(init=False, default=False)
@@ -69,22 +72,23 @@ class FewshotConfig:
         """Check if any fewshot source is configured."""
         return self.split is not None or self.samples is not None

-    def _get_raw_docs(
-        self, dataset
-    ) -> list[dict] | Callable[[], Iterable[dict]] | None:
+    def _get_raw_docs(self, dataset: DSplits) -> DataSet | None:
         """Get raw documents from configured source."""
         if self.split is not None:
             return dataset[self.split]
         if self.samples is not None:
-            if isinstance(self.samples, list) or callable(self.samples):
+            if isinstance(self.samples, list):
                 return self.samples
+            elif callable(self.samples):
+                # If samples is a callable, it should return a list of dicts
+                return self.samples()
             else:
                 raise TypeError(
                     "samples must be either a list of dicts or a callable returning a list"
                 )

-    def get_docs(self, dataset) -> Iterable[dict[str, Any]] | None:
+    def get_docs(self, dataset) -> DataSet | None:
         """Get processed documents from configured source."""
         raw_docs = self._get_raw_docs(dataset)
         if raw_docs is None:
@@ -130,34 +134,34 @@ class TaskConfig:
     # HF dataset options.
     # which dataset to use,
     # and what splits for what purpose
-    custom_dataset: Callable | None = None
+    custom_dataset: Callable[..., DataSet] | None = None
     dataset_path: str | None = None
     dataset_name: str | None = None
     dataset_kwargs: dict | None = field(default_factory=dict)
     training_split: str | None = None
     validation_split: str | None = None
     test_split: str | None = None
-    fewshot_split: str | None = (
-        None  # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?)
-    )
+    fewshot_split: str | None = None
     # formatting / prompting options.
     # see docs/advanced_task_guide.md for more info
-    process_docs: Callable | None = None
-    doc_to_text: Callable | str | None = None
-    doc_to_target: Callable | str | None = None
-    doc_to_image: Callable | str | None = None
-    doc_to_audio: Callable | str | None = None
+    process_docs: Callable[[DataSet], DataSet] | None = None
+    doc_to_text: Callable[[dict[str, Any]], Any] | str | None = None
+    doc_to_target: Callable[[dict[str, Any]], Any] | str | None = None
+    doc_to_image: Callable[[dict[str, Any]], Any] | str | None = None
+    doc_to_audio: Callable[[dict[str, Any]], Any] | str | None = None
     unsafe_code: bool = False
-    doc_to_choice: Callable | str | dict | list | None = None
-    process_results: Callable | str | None = None
+    doc_to_choice: Callable[[dict[str, Any]], Any] | str | dict | list | None = None
+    process_results: (
+        Callable[[dict[str, Any], list[Any]], dict[str, Any]] | str | None
+    ) = None
     use_prompt: str | None = None
     description: str = ""
     target_delimiter: str = " "
     fewshot_delimiter: str = "\n\n"
-    fewshot_config: dict | None = None
+    fewshot_config: dict[str, Any] | None = None
     # runtime configuration options
-    num_fewshot: int | None = 0
-    generation_kwargs: dict | None = None
+    num_fewshot: int | None = None
+    generation_kwargs: dict[str, Any] | None = None
     # scoring options
     metric_list: list | None = None
     output_type: OutputType = "generate_until"
@@ -357,7 +361,7 @@ class TaskConfig:
         return x

     @classmethod
-    def from_yaml(cls, data: dict) -> TaskConfig:
+    def from_yaml(cls, data: dict[str, Any]) -> TaskConfig:
         """Create a TaskConfig instance from a YAML-like dictionary."""
         return cls(**data)
@@ -425,12 +429,6 @@ class TaskConfig:
         # Create and return TaskConfig instance
         return cls(**config_dict)

-    def __getitem__(self, item):
-        return getattr(self, item)
-
-    def __setitem__(self, item, value):
-        return setattr(self, item, value)

     def to_dict(self, keep_callable: bool = False) -> dict:
         def _ser(x):
             if isinstance(x, dict):
...
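The new `DataSet` alias lets downstream code accept either a Hugging Face `datasets.Dataset` or a plain iterable of dicts through one annotation. A small sketch of a consumer under that alias (assumes the `datasets` package is installed; `first_texts` is hypothetical):

```python
from __future__ import annotations

from collections.abc import Iterable
from typing import Any, Union

import datasets

# Same alias the commit introduces in lm_eval.config.task:
DataSet = Union[datasets.Dataset, Iterable[dict[str, Any]]]


def first_texts(docs: DataSet, n: int = 2) -> list[str]:
    # Both datasets.Dataset and a list of dicts yield dict rows on iteration.
    out: list[str] = []
    for row in docs:
        out.append(row["text"])
        if len(out) >= n:
            break
    return out


rows = [{"text": "alpha"}, {"text": "beta"}, {"text": "gamma"}]
print(first_texts(rows))                              # plain iterable of dicts
print(first_texts(datasets.Dataset.from_list(rows)))  # HF Dataset, same call
```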
+# /// script
+# requires-python = ">=3.8"
+# dependencies = [
+#     "jsonlines",
+#     "mmap",
+#     "tqdm",
+#     "zstandard",
+# ]
+# ///
+# ruff: noqa
 import datetime
 import io
 import json
@@ -111,7 +122,7 @@ class TextReader:
         current_file_position = 0
         line_counter = 0
         with (
-            open(self.file_path, "r", encoding="utf-8") as fh,
+            open(self.file_path, encoding="utf-8") as fh,
             tqdm.tqdm(
                 total=os.path.getsize(self.file_path),
                 dynamic_ncols=True,
@@ -133,7 +144,7 @@ class TextReader:
     def read_and_tell(self):
         current_file_position = 0
-        with open(self.file_path, "r", encoding="utf8") as fh:
+        with open(self.file_path, encoding="utf8") as fh:
             with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                 for line in iter(mmap_obj.readline, b""):
                     line = line.decode("utf-8")
@@ -143,14 +154,14 @@ class TextReader:
                     yield line[:-1], raw_bytes_read

     def read(self):
-        with open(self.file_path, "r", encoding="utf8") as fh:
+        with open(self.file_path, encoding="utf8") as fh:
             with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
                 for line in iter(mmap_obj.readline, b""):
                     line = line.decode("utf-8")
                     yield line[:-1]

     def read_slow(self):
-        with open(self.file_path, "r", encoding="utf8") as fh:
+        with open(self.file_path, encoding="utf8") as fh:
             while True:
                 line = fh.readline()
                 if line == -1 or line == "":
...
 import collections
 import fnmatch
-import functools
 import hashlib
 import importlib.util
 import inspect
@@ -8,10 +7,12 @@ import json
 import logging
 import os
 import re
+from collections.abc import Generator
 from dataclasses import asdict, is_dataclass
+from functools import lru_cache, partial, wraps
 from itertools import islice
 from pathlib import Path
-from typing import Any, Callable, Generator, List, Optional, Tuple
+from typing import Any, Callable, Optional

 import numpy as np
 import yaml
@@ -108,7 +109,7 @@ def escaped_split(text, sep_char, maxsplit=-1):
         return text
     maxsplit = max(0, maxsplit)

-    return re.split(r"(?<!\\)" + sep_char, text, maxsplit)
+    return re.split(r"(?<!\\)" + sep_char, text, maxsplit=maxsplit)


 def handle_arg_string(arg):
@@ -125,7 +126,7 @@ def handle_arg_string(arg):
 def handle_non_serializable(o):
-    if isinstance(o, np.int64) or isinstance(o, np.int32):
+    if isinstance(o, np.integer):
         return int(o)
     elif isinstance(o, set):
         return list(o)
@@ -235,21 +236,21 @@ def sanitize_task_name(task_name: str) -> str:
     return re.sub(r"\W", "_", task_name)


-def get_latest_filename(filenames: List[str]) -> str:
+def get_latest_filename(filenames: list[str]) -> str:
     """
     Given a list of filenames, returns the filename with the latest datetime.
     """
     return max(filenames, key=lambda f: get_file_datetime(f))


-def get_results_filenames(filenames: List[str]) -> List[str]:
+def get_results_filenames(filenames: list[str]) -> list[str]:
     """
     Extracts filenames that correspond to aggregated results.
     """
     return [f for f in filenames if "/results_" in f and ".json" in f]


-def get_sample_results_filenames(filenames: List[str]) -> List[str]:
+def get_sample_results_filenames(filenames: list[str]) -> list[str]:
     """
     Extracts filenames that correspond to sample results.
     """
@@ -257,8 +258,8 @@ def get_sample_results_filenames(filenames: list[str]) -> list[str]:

 def get_rolling_token_windows(
-    token_list: List[int], prefix_token: int, max_seq_len: int, context_len: int
-) -> Generator[Tuple[List[int], List[int]], None, None]:
+    token_list: list[int], prefix_token: int, max_seq_len: int, context_len: int
+) -> Generator[tuple[list[int], list[int]], None, None]:
     """
     - context_len allows for a rolling window context, allowing each prediction window to potentially
       condition on some context
@@ -300,8 +301,8 @@ def get_rolling_token_windows(

 def make_disjoint_window(
-    pair: Tuple[List[int], List[int]],
-) -> Tuple[List[int], List[int]]:
+    pair: tuple[list[int], list[int]],
+) -> tuple[list[int], list[int]]:
     """Takes output from get_rolling_token_windows and makes the context not overlap with the continuation"""
     a, b = pair
     return a[: len(a) - (len(b) - 1)], b
@@ -320,7 +321,7 @@ class EnhancedJSONEncoder(json.JSONEncoder):

 class Reorderer:
-    def __init__(self, arr: List[Any], fn: Callable) -> None:
+    def __init__(self, arr: list[Any], fn: Callable) -> None:
         """Reorder an array according to some function

         Args:
@@ -423,11 +424,11 @@ def make_table(result_dict, column: str = "results", sort_results: bool = False)
                 # TODO: fix
                 hib = "↑"

-                v = "%.4f" % v if isinstance(v, float) else v
+                v = f"{v:.4f}" if isinstance(v, float) else v

                 if m + "_stderr" + "," + f in dic:
                     se = dic[m + "_stderr" + "," + f]
-                    se = " N/A" if se == "N/A" else "%.4f" % se
+                    se = " N/A" if se == "N/A" else f"{se:.4f}"
                     values.append([k, version, f, n, m, hib, v, "±", se])
                 else:
                     values.append([k, version, f, n, m, hib, v, "", ""])
@@ -448,7 +449,8 @@ def positional_deprecated(fn):
         wrapped function, `fn`.
     """

-    @functools.wraps(fn)
+    @wraps(fn)
     def _wrapper(*args, **kwargs):
         if len(args) != 1 if inspect.ismethod(fn) else 0:
             print(
@@ -494,7 +496,7 @@ def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full
         if yaml_path is None:
             raise ValueError("yaml_path must be provided if mode is 'full'.")
         # Attach yaml_path to the import function so that it can be used later
-        constructor_fn = functools.partial(import_function, yaml_path=Path(yaml_path))
+        constructor_fn = partial(import_function, yaml_path=Path(yaml_path))

         loader = yaml.CLoader if yaml.__with_libyaml__ else yaml.FullLoader
         # Add the import_function constructor to the YAML loader
@@ -543,13 +545,18 @@ def regex_replace(string, pattern, repl, count: int = 0):

 env = Environment(
-    loader=BaseLoader, undefined=StrictUndefined, keep_trailing_newline=True
+    loader=BaseLoader(), undefined=StrictUndefined, keep_trailing_newline=True
 )
 env.filters["regex_replace"] = regex_replace


+@lru_cache(maxsize=128)
+def _compile(raw: str):
+    return env.from_string(raw)


 def apply_template(template: str, doc: dict) -> str:
-    rtemplate = env.from_string(template)
+    rtemplate = _compile(template)
     return rtemplate.render(**doc)
...
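The `lru_cache` wrapper added above memoizes Jinja template compilation, which dominates rendering cost when the same `doc_to_text` template is applied to thousands of docs. A standalone sketch of the same technique (assumes `jinja2` is installed):

```python
from functools import lru_cache

from jinja2 import BaseLoader, Environment, StrictUndefined

env = Environment(
    loader=BaseLoader(), undefined=StrictUndefined, keep_trailing_newline=True
)


@lru_cache(maxsize=128)
def _compile(raw: str):
    # Safe to memoize: keyed on the raw template source, and the compiled
    # Template object is not mutated by rendering.
    return env.from_string(raw)


def apply_template(template: str, doc: dict) -> str:
    return _compile(template).render(**doc)


print(apply_template("Q: {{ question }}", {"question": "2+2?"}))
print(apply_template("Q: {{ question }}", {"question": "3+3?"}))
print(_compile.cache_info())  # the second call is a cache hit
```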
@@ -11,34 +11,28 @@ authors = [
 description = "A framework for evaluating language models"
 readme = "README.md"
 classifiers = [
     "Development Status :: 3 - Alpha",
     "Programming Language :: Python :: 3",
     "License :: OSI Approved :: MIT License",
-    "Operating System :: OS Independent",
+    "Operating System :: OS Independent"
 ]
 requires-python = ">=3.9"
 license = { "text" = "MIT" }
 dependencies = [
     "accelerate>=0.26.0",
-    "evaluate",
     "datasets>=2.16.0,<4.0",
     "evaluate>=0.4.0",
-    "jsonlines",
-    "numexpr",
     "peft>=0.2.0",
-    "pybind11>=2.6.2",
     "pytablewriter",
     "rouge-score>=0.0.4",
     "sacrebleu>=1.5.0",
     "scikit-learn>=0.24.1",
     "sqlitedict",
     "torch>=1.8",
-    "tqdm-multiprocess",
     "transformers>=4.1",
-    "zstandard",
     "dill",
     "word2number",
-    "more_itertools",
+    "more_itertools"
 ]

 [tool.setuptools.packages.find]
@@ -68,7 +62,7 @@ ibm_watsonx_ai = ["ibm_watsonx_ai>=1.1.22", "python-dotenv"]
 ifeval = ["langdetect", "immutabledict", "nltk>=3.9.1"]
 ipex = ["optimum"]
 japanese_leaderboard = ["emoji==2.14.0", "neologdn==0.5.3", "fugashi[unidic-lite]", "rouge_score>=0.1.2"]
-longbench=["jieba", "fuzzywuzzy", "rouge"]
+longbench = ["jieba", "fuzzywuzzy", "rouge"]
 libra=["pymorphy2"]
 mamba = ["mamba_ssm", "causal-conv1d==1.0.2", "torch"]
 math = ["sympy>=1.12", "antlr4-python3-runtime==4.11", "math_verify[antlr4_11_0]"]
@@ -87,17 +81,30 @@ vllm = ["vllm>=0.4.2"]
 wandb = ["wandb>=0.16.3", "pandas", "numpy"]
 zeno = ["pandas", "zeno-client"]
 tasks = [
     "lm_eval[acpbench]",
     "lm_eval[discrim_eval]",
     "lm_eval[ifeval]",
     "lm_eval[japanese_leaderboard]",
     "lm_eval[longbench]",
     "lm_eval[libra]",
     "lm_eval[mamba]",
     "lm_eval[math]",
     "lm_eval[multilingual]",
-    "lm_eval[ruler]",
+    "lm_eval[ruler]"
 ]
+testing = ["pytest", "pytest-cov", "pytest-xdist"]
+unitxt = ["unitxt==1.22.0"]
+vllm = ["vllm>=0.4.2"]
+wandb = ["wandb>=0.16.3", "pandas", "numpy"]
+zeno = ["pandas", "zeno-client"]
+
+[project.scripts]
+lm-eval = "lm_eval.__main__:cli_evaluate"
+lm_eval = "lm_eval.__main__:cli_evaluate"
+
+[project.urls]
+Homepage = "https://github.com/EleutherAI/lm-evaluation-harness"
+Repository = "https://github.com/EleutherAI/lm-evaluation-harness"

 [tool.pymarkdown]
 plugins.md013.enabled = false # line-length
@@ -107,21 +114,23 @@ plugins.md028.enabled = false # no-blanks-blockquote
 plugins.md029.allow_extended_start_values = true # ol-prefix
 plugins.md034.enabled = false # no-bare-urls

 [tool.ruff]
 target-version = "py39"
 lint.extend-select = ["I", "UP", "E", "C419", "F", "B", "SIM"]
-lint.ignore = ["E402", "E731", "E501", "E111", "E114", "E117"]
+lint.fixable = ["I001", "F401", "UP"]
+lint.ignore = ["E402", "E731", "E501", "E111", "E114", "E117", "E741"]
+
+[tool.ruff.lint.extend-per-file-ignores]
+"__init__.py" = ["F401", "F402", "F403"]

 [tool.ruff.lint.isort]
 combine-as-imports = true
-lines-after-imports = 2
 known-first-party = ["lm_eval"]
+lines-after-imports = 2

-[tool.ruff.lint.extend-per-file-ignores]
-"__init__.py" = ["F401","F402","F403"]
+# required to include yaml files in pip installation
+[tool.setuptools.package-data]
+lm_eval = ["**/*.yaml", "tasks/**/*"]

-[dependency-groups]
-dev = [
-    "api","dev","sentencepiece"
-]
+[tool.setuptools.packages.find]
+include = ["lm_eval*"]