Commit 4ad6cd9f authored by Baber

remove deps; types

parent 689e0c91
......@@ -33,7 +33,7 @@ repos:
hooks:
# Run the linter.
- id: ruff-check
args: [ --fix ]
args: [ --fix]
# Run the formatter.
- id: ruff-format
- repo: https://github.com/codespell-project/codespell
......
from __future__ import annotations
import abc
import hashlib
import json
import logging
import os
from typing import TYPE_CHECKING, Any, Iterable, Optional, Type, TypeVar, Union
from collections.abc import Iterable
from typing import TYPE_CHECKING, Any, TypeVar
from tqdm import tqdm
......@@ -31,7 +34,7 @@ class LM(abc.ABC):
# set rank and world size to a single process, by default.
self._rank = 0
self._world_size = 1
self.cache_hook: "CacheHook" = CacheHook(None)
self.cache_hook: CacheHook = CacheHook(None)
@abc.abstractmethod
def loglikelihood(self, requests: list[Instance]) -> list[tuple[float, bool]]:
......@@ -101,7 +104,7 @@ class LM(abc.ABC):
# TODO: Add an optional max length
@abc.abstractmethod
def generate_until(self, requests: list["Instance"]) -> list[str]:
def generate_until(self, requests: list[Instance]) -> list[str]:
"""Generate greedily until a stopping sequence
:param requests: list[Instance]
......@@ -137,7 +140,7 @@ class LM(abc.ABC):
@classmethod
def create_from_arg_string(
cls: Type[T], arg_string: str, additional_config: Optional[dict] = None
cls: type[T], arg_string: str, additional_config: dict | None = None
) -> T:
"""
Creates an instance of the LM class using the given argument string and additional config.
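For orientation, the `arg_string` here is a comma-separated `key=value` list (e.g. `pretrained=gpt2,dtype=float32`). A minimal, illustrative parser, not the harness's actual helper:
# illustrative sketch only; the real parsing helper lives elsewhere in lm_eval
def parse_arg_string(arg_string: str) -> dict[str, str]:
    """Turn 'pretrained=gpt2,dtype=float32' into {'pretrained': 'gpt2', 'dtype': 'float32'}."""
    if not arg_string:
        return {}
    return dict(kv.split("=", 1) for kv in arg_string.split(","))
Any `additional_config` entries are then presumably layered over the parsed kwargs before `cls(**kwargs)` is called.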
......@@ -156,7 +159,7 @@ class LM(abc.ABC):
@classmethod
def create_from_arg_obj(
cls: Type[T], arg_dict: dict, additional_config: Optional[dict] = None
cls: type[T], arg_dict: dict, additional_config: dict | None = None
) -> T:
"""
Creates an instance of the LM class using the given arg_obj
......@@ -201,7 +204,7 @@ class LM(abc.ABC):
"To use this model with chat templates, please implement the 'tokenizer_name' property."
)
def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
def chat_template(self, chat_template: bool | str = False) -> str | None:
"""Returns the chat template structure for user/assistant messages if a template is provided.
This method is intended to be overridden in a subclass to define a specific chat template format.
For models that do not support chat templates, this method returns None by default.
......@@ -209,7 +212,7 @@ class LM(abc.ABC):
return ""
def set_cache_hook(self, cache_hook: "CacheHook") -> None:
def set_cache_hook(self, cache_hook: CacheHook) -> None:
"""Sets the cache hook for the LM, which is used to cache responses from the LM."""
self.cache_hook = cache_hook
......@@ -221,10 +224,10 @@ def hash_args(attr: str, args: Iterable[Any]) -> str:
class CacheHook:
def __init__(self, cachinglm: Optional["CachingLM"]) -> None:
def __init__(self, cachinglm: CachingLM | None) -> None:
"""CacheHook is used to cache responses from the LM."""
if cachinglm is None:
self.dbdict: Optional["SqliteDict"] = None
self.dbdict: SqliteDict | None = None
return
self.dbdict = cachinglm.dbdict
......@@ -238,7 +241,7 @@ class CacheHook:
class CachingLM:
def __init__(self, lm: "LM", cache_db: str) -> None:
def __init__(self, lm: LM, cache_db: str) -> None:
"""LM wrapper that returns cached results if they exist, and uses the underlying LM if not.
:param lm: LM
......@@ -263,7 +266,7 @@ class CachingLM:
eval_logger.debug(f"Passing through attribute '{attr}' to underlying LM")
return lm_attr
def _fn(requests: list["Instance"]) -> list["Instance"]:
def _fn(requests: list[Instance]) -> list[Instance]:
res = []
remaining_reqs = []
warned = False
......@@ -295,11 +298,8 @@ class CachingLM:
eval_logger.info(
f"Cached requests: {len(requests) - len(remaining_reqs)}, Requests remaining: {len(remaining_reqs)}"
)
if remaining_reqs:
# actually run the LM on the requests that do not have cached results
rem_res = getattr(self.lm, attr)(remaining_reqs)
else:
rem_res = []
rem_res = getattr(self.lm, attr)(remaining_reqs) if remaining_reqs else []
# stick the new ones back into the list and also cache any of the new ones
resptr = 0
......@@ -318,7 +318,7 @@ class CachingLM:
return _fn
def get_cache_hook(self) -> "CacheHook":
def get_cache_hook(self) -> CacheHook:
return CacheHook(self)
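A minimal sketch of the pattern `CachingLM` implements (names below are illustrative, not the harness API): intercept attribute access, hash each request, serve hits from a dict-like store such as a SqliteDict, and forward only the misses to the wrapped LM.
# --- illustrative sketch, not part of the diff ---
import hashlib
import json


class MiniCachingLM:
    def __init__(self, lm, store) -> None:
        self.lm = lm        # underlying model exposing e.g. generate_until(requests)
        self.store = store  # any dict-like store (SqliteDict in the real code)

    def __getattr__(self, attr):
        fn = getattr(self.lm, attr)

        def cached(requests):
            # assumes requests are JSON-serializable for hashing purposes
            keys = [hashlib.sha256(json.dumps([attr, r]).encode("utf-8")).hexdigest()
                    for r in requests]
            misses = [r for k, r in zip(keys, requests) if k not in self.store]
            fresh = iter(fn(misses)) if misses else iter([])
            out = []
            for k in keys:
                if k not in self.store:
                    self.store[k] = next(fresh)  # cache the newly computed result
                out.append(self.store[k])
            return out

        return cached
Wrapping with `MiniCachingLM(lm, SqliteDict("cache.db", autocommit=True))` would then behave like `lm` itself while skipping recomputation on repeated requests.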
......@@ -399,7 +399,7 @@ class TemplateLM(LM):
return context_enc, continuation_enc
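For context, `context_enc, continuation_enc` pairs of this kind are commonly produced by tokenizing the concatenation once and splitting at the context boundary; a rough sketch under the assumption of an `encode: str -> list[int]` callable (not the exact harness logic):
# rough sketch; `encode` is assumed to behave like a HF tokenizer's encode()
def encode_pair(encode, context: str, continuation: str) -> tuple[list[int], list[int]]:
    whole_enc = encode(context + continuation)   # tokenize the joined string once
    context_enc = encode(context)
    n = len(context_enc)
    # everything past the context's own token count is attributed to the continuation
    return whole_enc[:n], whole_enc[n:]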
def loglikelihood(
self, requests: list["Instance"], disable_tqdm: bool = False
self, requests: list[Instance], disable_tqdm: bool = False
) -> list[tuple[float, bool]]:
"""Compute log-likelihood of generating a continuation from a context.
......@@ -432,7 +432,7 @@ class TemplateLM(LM):
@abc.abstractmethod
def generate_until(
self, requests: list["Instance"], disable_tqdm: bool = False
self, requests: list[Instance], disable_tqdm: bool = False
) -> list[str]:
"""Generate until a stopping sequence.
......@@ -453,7 +453,7 @@ class TemplateLM(LM):
"""
pass
def chat_template(self, chat_template: Union[bool, str] = False) -> Optional[str]:
def chat_template(self, chat_template: bool | str = False) -> str | None:
"""
Assumes tokenizer has a chat_template attribute (self.tokenizer.chat_template: dict | str)
Set and get the appropriate chat template for the model.
......
......@@ -7,11 +7,7 @@ import random
import re
from collections.abc import Callable
from copy import deepcopy
from typing import (
TYPE_CHECKING,
Any,
Literal,
)
from typing import TYPE_CHECKING, Any, Literal, overload
import datasets
import numpy as np
......@@ -24,7 +20,7 @@ from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
from lm_eval.api.utils import check_gold_index_error
from lm_eval.caching.cache import load_from_cache, save_to_cache
from lm_eval.config.metric import MetricConfig
from lm_eval.config.task import TaskConfig
from lm_eval.config.task import DataSet, TaskConfig
from lm_eval.filters import build_filter_ensemble
......@@ -133,6 +129,7 @@ class Task(abc.ABC):
- `datasets.DownloadMode.FORCE_REDOWNLOAD`
Fresh download and fresh dataset.
"""
assert self.DATASET_PATH is not None, "DATASET_PATH must be set in Task class"
self.dataset = datasets.load_dataset(
path=self.DATASET_PATH,
name=self.DATASET_NAME,
......@@ -146,43 +143,40 @@ class Task(abc.ABC):
"""Returns the TaskConfig associated with this class."""
return self._config
@abc.abstractmethod
def has_training_docs(self) -> bool:
"""Whether the task has a training set"""
pass
raise NotImplementedError
@abc.abstractmethod
def has_validation_docs(self) -> bool:
"""Whether the task has a validation set"""
pass
raise NotImplementedError
@abc.abstractmethod
def has_test_docs(self) -> bool:
"""Whether the task has a test set"""
pass
raise NotImplementedError
def training_docs(self) -> Iterable:
def training_docs(self) -> DataSet | None:
"""
:return: Iterable[obj]
An iterable of any object that doc_to_text can handle
"""
return []
def validation_docs(self) -> Iterable:
def validation_docs(self) -> DataSet | None:
"""
:return: Iterable[obj]
An iterable of any object that doc_to_text can handle
"""
return []
def test_docs(self) -> Iterable:
def test_docs(self) -> DataSet | None:
"""
:return: Iterable[obj]
An iterable of any object that doc_to_text can handle
"""
return []
def fewshot_docs(self) -> Iterable:
def fewshot_docs(self) -> DataSet | None:
"""
:return: Iterable[obj]
An iterable of any object that doc_to_text can handle
......@@ -192,7 +186,7 @@ class Task(abc.ABC):
elif self.has_validation_docs():
return self.validation_docs()
else:
if self.config.get("num_fewshot", 0) > 0:
if self.config.num_fewshot and self.config.num_fewshot > 0:
eval_logger.warning(
f"[Task: {self.config.task}] has_training_docs and has_validation_docs are False"
", using test_docs as fewshot_docs but this is not recommended."
......@@ -331,7 +325,7 @@ class Task(abc.ABC):
inst = self.construct_requests(
doc=doc,
ctx=fewshot_ctx,
metadata=(self.config["task"], doc_id, self.config.repeats),
metadata=(self.config.task, doc_id, self.config.repeats),
apply_chat_template=apply_chat_template,
chat_template=chat_template,
)
......@@ -586,7 +580,7 @@ class ConfigurableTask(Task):
data_dir=None,
cache_dir=None,
download_mode=None,
config: dict | None = None,
config: Mapping[str, Any] | None = None,
) -> None:
# Get pre-configured attributes
self._config = self.CONFIG
......@@ -727,6 +721,9 @@ class ConfigurableTask(Task):
)
self.dataset = df(**(self.config.dataset_kwargs | self.config.metadata))
else:
assert self.config.dataset_path is not None, (
"dataset_path must be set in TaskConfig"
)
self.dataset = datasets.load_dataset(
path=self.config.dataset_path,
name=self.config.dataset_name,
......@@ -742,7 +739,7 @@ class ConfigurableTask(Task):
def has_test_docs(self) -> bool:
return self.config.test_split is not None
def training_docs(self) -> datasets.Dataset | None:
def training_docs(self) -> DataSet | None:
if self.has_training_docs():
if self.config.process_docs is not None:
return self.config.process_docs(
......@@ -750,7 +747,7 @@ class ConfigurableTask(Task):
)
return self.dataset[self.config.training_split]
def validation_docs(self) -> datasets.Dataset | None:
def validation_docs(self) -> DataSet | None:
if self.has_validation_docs():
if self.config.process_docs is not None:
return self.config.process_docs(
......@@ -758,7 +755,7 @@ class ConfigurableTask(Task):
)
return self.dataset[self.config.validation_split]
def test_docs(self) -> datasets.Dataset | None:
def test_docs(self) -> DataSet | None:
if self.has_test_docs():
if self.config.process_docs is not None:
return self.config.process_docs(self.dataset[self.config.test_split])
......@@ -996,9 +993,21 @@ class ConfigurableTask(Task):
"""
return doc
@overload
def doc_to_text(self, doc: dict, doc_to_text: None = None) -> str | int: ...
@overload
def doc_to_text(self, doc: dict, doc_to_text: int) -> int: ...
@overload
def doc_to_text(self, doc: dict, doc_to_text: str) -> str: ...
@overload
def doc_to_text(self, doc: dict, doc_to_text: Callable[..., str]) -> str: ...
def doc_to_text(
self, doc: dict, doc_to_text: int | str | Callable[..., str] | None = None
) -> str:
) -> str | int:
# if self.prompt is not None:
# doc_to_text = self.prompt
doc_to_text = doc_to_text or self.config.doc_to_text
......@@ -1031,6 +1040,25 @@ class ConfigurableTask(Task):
print(type(doc_to_text))
raise TypeError
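The `@overload` stubs added above exist purely for static type checkers; a single untyped implementation still handles every case at runtime. A self-contained illustration of the same pattern with made-up names:
# illustrative only; `render` is a made-up helper, not part of lm_eval
from collections.abc import Callable
from typing import overload


@overload
def render(spec: int) -> int: ...
@overload
def render(spec: str) -> str: ...
@overload
def render(spec: Callable[[], str]) -> str: ...
def render(spec):
    # one runtime body; the overloads let a checker narrow the return type per call site
    return spec() if callable(spec) else spec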
@overload
def doc_to_target(
self, doc: dict, doc_to_target: None = None
) -> int | str | list[int]: ...
@overload
def doc_to_target(self, doc: dict, doc_to_target: int) -> int: ...
@overload
def doc_to_target(self, doc: dict, doc_to_target: str) -> int | str | list[int]: ...
@overload
def doc_to_target(self, doc: dict, doc_to_target: list) -> list[int]: ...
@overload
def doc_to_target(
self, doc: dict, doc_to_target: Callable[..., int | str | list[int]]
) -> int | str | list[int]: ...
def doc_to_target(self, doc: dict, doc_to_target=None) -> int | str | list[int]:
# if self.prompt is not None:
# doc_to_target = self.prompt
......@@ -1077,6 +1105,23 @@ class ConfigurableTask(Task):
else:
raise TypeError
@overload
def doc_to_choice(self, doc: dict, doc_to_choice: None = None) -> list[str]: ...
@overload
def doc_to_choice(self, doc: dict, doc_to_choice: str) -> list[str]: ...
@overload
def doc_to_choice(self, doc: dict, doc_to_choice: list) -> list[str]: ...
@overload
def doc_to_choice(self, doc: dict, doc_to_choice: dict) -> list[str]: ...
@overload
def doc_to_choice(
self, doc: dict, doc_to_choice: Callable[..., list[str]]
) -> list[str]: ...
def doc_to_choice(
self,
doc: dict,
......@@ -1108,6 +1153,18 @@ class ConfigurableTask(Task):
else:
raise TypeError
@overload
def doc_to_image(self, doc: dict, doc_to_image: None = None) -> None: ...
@overload
def doc_to_image(self, doc: dict, doc_to_image: list) -> list: ...
@overload
def doc_to_image(self, doc: dict, doc_to_image: str) -> int | str | None: ...
@overload
def doc_to_image(self, doc: dict, doc_to_image: Callable[..., Any]) -> Any: ...
def doc_to_image(self, doc: dict, doc_to_image=None) -> int | str | list | None:
if doc_to_image is not None:
doc_to_image = doc_to_image
......@@ -1131,6 +1188,18 @@ class ConfigurableTask(Task):
else:
return None
@overload
def doc_to_audio(self, doc: Any, doc_to_audio: None = None) -> None: ...
@overload
def doc_to_audio(self, doc: Any, doc_to_audio: list) -> list: ...
@overload
def doc_to_audio(self, doc: Any, doc_to_audio: str) -> int | str | None: ...
@overload
def doc_to_audio(self, doc: Any, doc_to_audio: Callable[..., Any]) -> Any: ...
def doc_to_audio(self, doc: Any, doc_to_audio=None) -> int | str | list | None:
if doc_to_audio is not None:
doc_to_audio = doc_to_audio
......@@ -1375,15 +1444,15 @@ class ConfigurableTask(Task):
elif self.OUTPUT_TYPE == "generate_until":
gold = self.doc_to_target(doc)
result = results[0]
for metric in self._metric_fn_list:
for metric in self.config._metric_list:
try:
result_score = self._metric_fn_list[metric](
result_score = metric.fn(
references=[gold] if not isinstance(gold, list) else gold,
predictions=[result],
**self._metric_fn_kwargs[metric],
**metric.kwargs,
)
except TypeError: # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
result_score = self._metric_fn_list[metric]([gold, result])
result_score = metric.fn([gold, result])
if isinstance(result_score, dict):
# TODO: this handles the case where HF evaluate returns a dict.
# This allows for multiple metrics to be returned from the same function
......
from __future__ import annotations
from collections.abc import Callable, Mapping
from dataclasses import dataclass
from dataclasses import dataclass, field
from functools import cached_property
from typing import Any
......@@ -11,8 +11,8 @@ class MetricConfig:
"""Encapsulates information about a single metric."""
name: str
fn: Callable | None = None
kwargs: Mapping[str, Any] | None = None
fn: Callable
kwargs: Mapping[str, Any] = field(default_factory=dict)
aggregation_fn: Callable | None = None
higher_is_better: bool = True
hf_evaluate: bool = False
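With `fn` now mandatory and `kwargs` defaulting to an empty mapping (rather than `None`), a metric entry can be called without guarding for missing pieces; a hedged example with a stand-in metric function:
# stand-in metric; the MetricConfig import path is the one shown in this diff
from lm_eval.config.metric import MetricConfig


def exact_match(references, predictions):
    """Illustrative metric: fraction of exact matches."""
    return sum(r == p for r, p in zip(references, predictions)) / len(references)


em = MetricConfig(name="exact_match", fn=exact_match)
em.fn(references=["4"], predictions=["4"], **em.kwargs)  # -> 1.0, no None check needed
This mirrors the scoring loop above, which now calls `metric.fn(..., **metric.kwargs)` directly instead of indexing `_metric_fn_list` and `_metric_fn_kwargs`.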
......
......@@ -3,7 +3,9 @@ from __future__ import annotations
import logging
from collections.abc import Iterable
from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Any, Callable
from typing import TYPE_CHECKING, Any, Callable, Union
import datasets
from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.instance import OutputType
......@@ -18,6 +20,9 @@ if TYPE_CHECKING:
eval_logger = logging.getLogger(__name__)
DataSet = Union[datasets.Dataset, Iterable[dict[str, Any]]]
DSplits = dict[str, DataSet]
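The `DataSet` alias lets the doc-returning methods hand back either an actual `datasets.Dataset` or any iterable of dict rows; downstream code that only iterates and indexes by key works with both. A small sketch:
# sketch of code that is agnostic to which DataSet variant it receives
from __future__ import annotations

from collections.abc import Iterable
from typing import Any, Union

import datasets

DataSet = Union[datasets.Dataset, Iterable[dict[str, Any]]]


def first_text(docs: DataSet, key: str = "text") -> str | None:
    for doc in docs:          # both variants yield dict-like rows
        return doc[key]
    return None


first_text([{"text": "hello"}])  # plain list of dicts -> "hello"
# a datasets.Dataset with a "text" column would work the same way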
@dataclass
class RepeatConfig:
......@@ -30,7 +35,7 @@ class RepeatConfig:
@dataclass
class FilterConfig:
"""Encapsulates information about a single filter."""
"""Encapsulates information about a single filter pipeline."""
name: str
ensemble: FilterEnsemble
......@@ -44,10 +49,8 @@ class FewshotConfig:
num_fewshot: Callable[[], int]
split: str | None = None
sampler: str | Callable = "default"
samples: Callable[[], list[dict]] | list[dict] | None = None
process_docs: Callable[[list[dict[str, Any]]], Iterable[dict[str, Any]]] | None = (
None
)
samples: Callable[[], DataSet] | DataSet | None = None
process_docs: Callable[[DataSet], DataSet] | None = None
fewshot_indices: list[int] | None = None
rnd: int = field(init=False, default=False)
......@@ -69,22 +72,23 @@ class FewshotConfig:
"""Check if any fewshot source is configured."""
return self.split is not None or self.samples is not None
def _get_raw_docs(
self, dataset
) -> list[dict] | Callable[[], Iterable[dict]] | None:
def _get_raw_docs(self, dataset: DSplits) -> DataSet | None:
"""Get raw documents from configured source."""
if self.split is not None:
return dataset[self.split]
if self.samples is not None:
if isinstance(self.samples, list) or callable(self.samples):
if isinstance(self.samples, list):
return self.samples
elif callable(self.samples):
# If samples is a callable, it should return a list of dicts
return self.samples()
else:
raise TypeError(
"samples must be either a list of dicts or a callable returning a list"
)
def get_docs(self, dataset) -> Iterable[dict[str, Any]] | None:
def get_docs(self, dataset) -> DataSet | None:
"""Get processed documents from configured source."""
raw_docs = self._get_raw_docs(dataset)
if raw_docs is None:
......@@ -130,34 +134,34 @@ class TaskConfig:
# HF dataset options.
# which dataset to use,
# and what splits for what purpose
custom_dataset: Callable | None = None
custom_dataset: Callable[..., DataSet] | None = None
dataset_path: str | None = None
dataset_name: str | None = None
dataset_kwargs: dict | None = field(default_factory=dict)
training_split: str | None = None
validation_split: str | None = None
test_split: str | None = None
fewshot_split: str | None = (
None # TODO: assert that this not None if num_fewshot > 0. (?) assert if this is same split as one evaluating (?)
)
fewshot_split: str | None = None
# formatting / prompting options.
# see docs/advanced_task_guide.md for more info
process_docs: Callable | None = None
doc_to_text: Callable | str | None = None
doc_to_target: Callable | str | None = None
doc_to_image: Callable | str | None = None
doc_to_audio: Callable | str | None = None
process_docs: Callable[[DataSet], DataSet] | None = None
doc_to_text: Callable[[dict[str, Any]], Any] | str | None = None
doc_to_target: Callable[[dict[str, Any]], Any] | str | None = None
doc_to_image: Callable[[dict[str, Any]], Any] | str | None = None
doc_to_audio: Callable[[dict[str, Any]], Any] | str | None = None
unsafe_code: bool = False
doc_to_choice: Callable | str | dict | list | None = None
process_results: Callable | str | None = None
doc_to_choice: Callable[[dict[str, Any]], Any] | str | dict | list | None = None
process_results: (
Callable[[dict[str, Any], list[Any]], dict[str, Any]] | str | None
) = None
use_prompt: str | None = None
description: str = ""
target_delimiter: str = " "
fewshot_delimiter: str = "\n\n"
fewshot_config: dict | None = None
fewshot_config: dict[str, Any] | None = None
# runtime configuration options
num_fewshot: int | None = 0
generation_kwargs: dict | None = None
num_fewshot: int | None = None
generation_kwargs: dict[str, Any] | None = None
# scoring options
metric_list: list | None = None
output_type: OutputType = "generate_until"
......@@ -357,7 +361,7 @@ class TaskConfig:
return x
@classmethod
def from_yaml(cls, data: dict) -> TaskConfig:
def from_yaml(cls, data: dict[str, Any]) -> TaskConfig:
"""Create a TaskConfig instance from a YAML-like dictionary."""
return cls(**data)
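`from_yaml` is just `cls(**data)`, so the dictionary's keys must match TaskConfig field names; a hedged example using only fields that appear in this diff:
# sketch; assumes these field values are acceptable to TaskConfig as-is
import yaml

from lm_eval.config.task import TaskConfig

cfg = TaskConfig.from_yaml(yaml.safe_load("""
task: demo_qa
dataset_path: json
test_split: test
doc_to_text: "Q: {{question}}\\nA:"
doc_to_target: "{{answer}}"
output_type: generate_until
"""))
cfg.task          # -> 'demo_qa'
cfg.num_fewshot   # -> None (new default; previously 0)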
......@@ -425,12 +429,6 @@ class TaskConfig:
# Create and return TaskConfig instance
return cls(**config_dict)
def __getitem__(self, item):
return getattr(self, item)
def __setitem__(self, item, value):
return setattr(self, item, value)
def to_dict(self, keep_callable: bool = False) -> dict:
def _ser(x):
if isinstance(x, dict):
......
# /// script
# requires-python = ">=3.8"
# dependencies = [
# "jsonlines",
# "mmap",
# "tqdm",
# "zstandard",
# ]
# ///
# ruff: noqa
import datetime
import io
import json
......@@ -111,7 +122,7 @@ class TextReader:
current_file_position = 0
line_counter = 0
with (
open(self.file_path, "r", encoding="utf-8") as fh,
open(self.file_path, encoding="utf-8") as fh,
tqdm.tqdm(
total=os.path.getsize(self.file_path),
dynamic_ncols=True,
......@@ -133,7 +144,7 @@ class TextReader:
def read_and_tell(self):
current_file_position = 0
with open(self.file_path, "r", encoding="utf8") as fh:
with open(self.file_path, encoding="utf8") as fh:
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
for line in iter(mmap_obj.readline, b""):
line = line.decode("utf-8")
......@@ -143,14 +154,14 @@ class TextReader:
yield line[:-1], raw_bytes_read
def read(self):
with open(self.file_path, "r", encoding="utf8") as fh:
with open(self.file_path, encoding="utf8") as fh:
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
for line in iter(mmap_obj.readline, b""):
line = line.decode("utf-8")
yield line[:-1]
def read_slow(self):
with open(self.file_path, "r", encoding="utf8") as fh:
with open(self.file_path, encoding="utf8") as fh:
while True:
line = fh.readline()
if line == -1 or line == "":
......
import collections
import fnmatch
import functools
import hashlib
import importlib.util
import inspect
......@@ -8,10 +7,12 @@ import json
import logging
import os
import re
from collections.abc import Generator
from dataclasses import asdict, is_dataclass
from functools import lru_cache, partial, wraps
from itertools import islice
from pathlib import Path
from typing import Any, Callable, Generator, List, Optional, Tuple
from typing import Any, Callable, Optional
import numpy as np
import yaml
......@@ -108,7 +109,7 @@ def escaped_split(text, sep_char, maxsplit=-1):
return text
maxsplit = max(0, maxsplit)
return re.split(r"(?<!\\)" + sep_char, text, maxsplit)
return re.split(r"(?<!\\)" + sep_char, text, maxsplit=maxsplit)
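The lookbehind means only unescaped separators split, and passing `maxsplit` by keyword avoids relying on `re.split`'s deprecated positional form. A worked call:
# worked example for the escaped_split above
escaped_split(r"a,b\,c,d", ",")
# -> ['a', 'b\\,c', 'd']   (the escaped comma in 'b\,c' is left untouched)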
def handle_arg_string(arg):
......@@ -125,7 +126,7 @@ def handle_arg_string(arg):
def handle_non_serializable(o):
if isinstance(o, np.int64) or isinstance(o, np.int32):
if isinstance(o, np.integer):
return int(o)
elif isinstance(o, set):
return list(o)
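`np.integer` covers all NumPy integer widths (int32, int64, ...) in one check. The helper is the kind of function typically handed to `json.dumps` as the `default` hook; an illustrative call:
# illustrative usage; assumes handle_non_serializable falls back sensibly for other types
import json

import numpy as np

json.dumps({"n": np.int64(7), "tags": {"a"}}, default=handle_non_serializable)
# -> '{"n": 7, "tags": ["a"]}'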
......@@ -235,21 +236,21 @@ def sanitize_task_name(task_name: str) -> str:
return re.sub(r"\W", "_", task_name)
def get_latest_filename(filenames: List[str]) -> str:
def get_latest_filename(filenames: list[str]) -> str:
"""
Given a list of filenames, returns the filename with the latest datetime.
"""
return max(filenames, key=lambda f: get_file_datetime(f))
def get_results_filenames(filenames: List[str]) -> List[str]:
def get_results_filenames(filenames: list[str]) -> list[str]:
"""
Extracts filenames that correspond to aggregated results.
"""
return [f for f in filenames if "/results_" in f and ".json" in f]
def get_sample_results_filenames(filenames: List[str]) -> List[str]:
def get_sample_results_filenames(filenames: list[str]) -> list[str]:
"""
Extracts filenames that correspond to sample results.
"""
......@@ -257,8 +258,8 @@ def get_sample_results_filenames(filenames: List[str]) -> List[str]:
def get_rolling_token_windows(
token_list: List[int], prefix_token: int, max_seq_len: int, context_len: int
) -> Generator[Tuple[List[int], List[int]], None, None]:
token_list: list[int], prefix_token: int, max_seq_len: int, context_len: int
) -> Generator[tuple[list[int], list[int]], None, None]:
"""
- context_len allows for a rolling window context, allowing each prediction window to potentially
condition on some context
......@@ -300,8 +301,8 @@ def get_rolling_token_windows(
def make_disjoint_window(
pair: Tuple[List[int], List[int]],
) -> Tuple[List[int], List[int]]:
pair: tuple[list[int], list[int]],
) -> tuple[list[int], list[int]]:
"""Takes output from get_rolling_token_windows and makes the context not overlap with the continuation"""
a, b = pair
return a[: len(a) - (len(b) - 1)], b
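A worked example with illustrative token ids, showing how the trailing overlap is trimmed from the context so each target token is scored once:
# illustrative token ids
a, b = [10, 11, 12, 13, 14], [13, 14, 15]
make_disjoint_window((a, b))
# -> ([10, 11, 12], [13, 14, 15]); context and continuation now tile 10..15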
......@@ -320,7 +321,7 @@ class EnhancedJSONEncoder(json.JSONEncoder):
class Reorderer:
def __init__(self, arr: List[Any], fn: Callable) -> None:
def __init__(self, arr: list[Any], fn: Callable) -> None:
"""Reorder an array according to some function
Args:
......@@ -423,11 +424,11 @@ def make_table(result_dict, column: str = "results", sort_results: bool = False)
# TODO: fix
hib = "↑"
v = "%.4f" % v if isinstance(v, float) else v
v = f"{v:.4f}" if isinstance(v, float) else v
if m + "_stderr" + "," + f in dic:
se = dic[m + "_stderr" + "," + f]
se = " N/A" if se == "N/A" else "%.4f" % se
se = " N/A" if se == "N/A" else f"{se:.4f}"
values.append([k, version, f, n, m, hib, v, "±", se])
else:
values.append([k, version, f, n, m, hib, v, "", ""])
......@@ -448,7 +449,8 @@ def positional_deprecated(fn):
wrapped function, `fn`.
"""
@functools.wraps(fn)
@wraps(fn)
def _wrapper(*args, **kwargs):
if len(args) != 1 if inspect.ismethod(fn) else 0:
print(
......@@ -494,7 +496,7 @@ def load_yaml_config(yaml_path=None, yaml_config=None, yaml_dir=None, mode="full
if yaml_path is None:
raise ValueError("yaml_path must be provided if mode is 'full'.")
# Attach yaml_path to the import function so that it can be used later
constructor_fn = functools.partial(import_function, yaml_path=Path(yaml_path))
constructor_fn = partial(import_function, yaml_path=Path(yaml_path))
loader = yaml.CLoader if yaml.__with_libyaml__ else yaml.FullLoader
# Add the import_function constructor to the YAML loader
......@@ -543,13 +545,18 @@ def regex_replace(string, pattern, repl, count: int = 0):
env = Environment(
loader=BaseLoader, undefined=StrictUndefined, keep_trailing_newline=True
loader=BaseLoader(), undefined=StrictUndefined, keep_trailing_newline=True
)
env.filters["regex_replace"] = regex_replace
@lru_cache(maxsize=128)
def _compile(raw: str):
return env.from_string(raw)
def apply_template(template: str, doc: dict) -> str:
rtemplate = env.from_string(template)
rtemplate = _compile(template)
return rtemplate.render(**doc)
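Compiling templates through an `lru_cache`d helper means a prompt template shared by thousands of docs is parsed by Jinja once, while rendering stays per-doc. For example:
apply_template("Q: {{question}}\nA:", {"question": "What is 2+2?"})
# -> 'Q: What is 2+2?\nA:'
# a second call with the same template string reuses the compiled Template
# from _compile's cache instead of re-parsing it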
......
......@@ -11,34 +11,28 @@ authors = [
description = "A framework for evaluating language models"
readme = "README.md"
classifiers = [
"Development Status :: 3 - Alpha",
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent",
"Development Status :: 3 - Alpha",
"Programming Language :: Python :: 3",
"License :: OSI Approved :: MIT License",
"Operating System :: OS Independent"
]
requires-python = ">=3.9"
license = { "text" = "MIT" }
dependencies = [
"accelerate>=0.26.0",
"evaluate",
"datasets>=2.16.0,<4.0",
"evaluate>=0.4.0",
"jsonlines",
"numexpr",
"peft>=0.2.0",
"pybind11>=2.6.2",
"pytablewriter",
"rouge-score>=0.0.4",
"sacrebleu>=1.5.0",
"scikit-learn>=0.24.1",
"sqlitedict",
"torch>=1.8",
"tqdm-multiprocess",
"transformers>=4.1",
"zstandard",
"dill",
"word2number",
"more_itertools",
"accelerate>=0.26.0",
"datasets>=2.16.0,<4.0",
"evaluate>=0.4.0",
"peft>=0.2.0",
"pytablewriter",
"rouge-score>=0.0.4",
"sacrebleu>=1.5.0",
"scikit-learn>=0.24.1",
"sqlitedict",
"torch>=1.8",
"transformers>=4.1",
"dill",
"word2number",
"more_itertools"
]
[tool.setuptools.packages.find]
......@@ -68,7 +62,7 @@ ibm_watsonx_ai = ["ibm_watsonx_ai>=1.1.22", "python-dotenv"]
ifeval = ["langdetect", "immutabledict", "nltk>=3.9.1"]
ipex = ["optimum"]
japanese_leaderboard = ["emoji==2.14.0", "neologdn==0.5.3", "fugashi[unidic-lite]", "rouge_score>=0.1.2"]
longbench=["jieba", "fuzzywuzzy", "rouge"]
longbench = ["jieba", "fuzzywuzzy", "rouge"]
libra=["pymorphy2"]
mamba = ["mamba_ssm", "causal-conv1d==1.0.2", "torch"]
math = ["sympy>=1.12", "antlr4-python3-runtime==4.11", "math_verify[antlr4_11_0]"]
......@@ -87,17 +81,30 @@ vllm = ["vllm>=0.4.2"]
wandb = ["wandb>=0.16.3", "pandas", "numpy"]
zeno = ["pandas", "zeno-client"]
tasks = [
"lm_eval[acpbench]",
"lm_eval[discrim_eval]",
"lm_eval[acpbench]",
"lm_eval[discrim_eval]",
"lm_eval[ifeval]",
"lm_eval[japanese_leaderboard]",
"lm_eval[longbench]",
"lm_eval[japanese_leaderboard]",
"lm_eval[longbench]",
"lm_eval[libra]",
"lm_eval[mamba]",
"lm_eval[math]",
"lm_eval[multilingual]",
"lm_eval[ruler]",
"lm_eval[math]",
"lm_eval[multilingual]",
"lm_eval[ruler]"
]
testing = ["pytest", "pytest-cov", "pytest-xdist"]
unitxt = ["unitxt==1.22.0"]
vllm = ["vllm>=0.4.2"]
wandb = ["wandb>=0.16.3", "pandas", "numpy"]
zeno = ["pandas", "zeno-client"]
[project.scripts]
lm-eval = "lm_eval.__main__:cli_evaluate"
lm_eval = "lm_eval.__main__:cli_evaluate"
[project.urls]
Homepage = "https://github.com/EleutherAI/lm-evaluation-harness"
Repository = "https://github.com/EleutherAI/lm-evaluation-harness"
[tool.pymarkdown]
plugins.md013.enabled = false # line-length
......@@ -107,21 +114,23 @@ plugins.md028.enabled = false # no-blanks-blockquote
plugins.md029.allow_extended_start_values = true # ol-prefix
plugins.md034.enabled = false # no-bare-urls
[tool.ruff]
target-version = "py39"
lint.extend-select = ["I", "UP", "E", "C419", "F", "B", "SIM"]
lint.ignore = ["E402", "E731", "E501", "E111", "E114", "E117"]
lint.fixable = ["I001", "F401", "UP"]
lint.ignore = ["E402", "E731", "E501", "E111", "E114", "E117", "E741"]
[tool.ruff.lint.extend-per-file-ignores]
"__init__.py" = ["F401", "F402", "F403"]
[tool.ruff.lint.isort]
combine-as-imports = true
lines-after-imports = 2
known-first-party = ["lm_eval"]
lines-after-imports = 2
[tool.ruff.lint.extend-per-file-ignores]
"__init__.py" = ["F401","F402","F403"]
# required to include yaml files in pip installation
[tool.setuptools.package-data]
lm_eval = ["**/*.yaml", "tasks/**/*"]
[dependency-groups]
dev = [
"api","dev","sentencepiece"
]
[tool.setuptools.packages.find]
include = ["lm_eval*"]