Commit abd17276 authored by Baber
Browse files

Merge branch 'smolrefact' into tasklist

# Conflicts:
#	lm_eval/__main__.py
#	lm_eval/api/group.py
#	lm_eval/api/task.py
#	lm_eval/evaluator_utils.py
#	lm_eval/tasks/__init__.py
#	lm_eval/utils.py
#	pyproject.toml
parents 00afd536 70314843
from .evaluate_config import EvaluatorConfig
__all__ = [
"EvaluatorConfig",
]
This diff is collapsed.
from __future__ import annotations
from collections.abc import Callable, Mapping
from dataclasses import dataclass, field
from functools import cached_property
from typing import Any
@dataclass
class MetricConfig:
    """Encapsulates information about a single metric.

    Bundles the metric callable, its default keyword arguments, and the
    aggregation / direction settings. When ``aggregation_fn`` or
    ``higher_is_better`` is ``None`` the value is resolved lazily from
    ``lm_eval.api.registry`` using ``name`` as the lookup key.
    """

    # Metric identifier; also the key used for registry fallbacks.
    name: str
    # Callable invoked by `compute`.
    fn: Callable
    # Default keyword arguments merged into every `compute` call.
    kwargs: Mapping[str, Any] = field(default_factory=dict)
    # Explicit aggregation callable; None defers to the registry.
    aggregation_fn: Callable | None = None
    # Metric direction; None defers to the registry.
    higher_is_better: bool | None = True
    # NOTE(review): presumably flags metrics backed by HF `evaluate` - confirm.
    hf_evaluate: bool = False
    # NOTE(review): presumably flags per-sample metrics - confirm.
    is_elementwise: bool = True

    @cached_property
    def metric_name(self) -> str:
        """Return the metric's name."""
        return self.name

    @cached_property
    def aggregation(self) -> Callable[..., Any] | None:
        """Return ``aggregation_fn``, or the registry entry for ``name`` if unset."""
        if self.aggregation_fn is not None:
            return self.aggregation_fn
        # Imported lazily, and only when the fallback is actually needed.
        from lm_eval.api.registry import get_aggregation

        return get_aggregation(self.name)

    @cached_property
    def _higher_is_better(self) -> bool | None:
        """Return ``higher_is_better``, or the registry value for ``name`` if unset."""
        if self.higher_is_better is not None:
            return self.higher_is_better
        # Imported lazily, and only when the fallback is actually needed.
        from lm_eval.api.registry import is_higher_better

        return is_higher_better(self.name)

    def compute(self, *args, **kwargs) -> Any:
        """Calculates the metric using the provided function and arguments.

        Keyword arguments given at call time override the defaults stored
        in ``self.kwargs``.

        Raises:
            ValueError: If no metric function is configured.
        """
        if self.fn is None:
            raise ValueError(f"Metric function for {self.name} is not defined.")
        return self.fn(*args, **{**(self.kwargs or {}), **kwargs})

    def compute_aggregation(self, *args, **kwargs) -> Any:
        """Computes the aggregation of the metric values.

        Uses the resolved ``aggregation`` property so the registry fallback
        applies when ``aggregation_fn`` was not set explicitly.

        Raises:
            ValueError: If no aggregation function can be resolved.
        """
        aggregate = self.aggregation
        if aggregate is None:
            raise ValueError(f"Aggregation function for {self.name} is not defined.")
        return aggregate(*args, **kwargs)
This diff is collapsed.
This diff is collapsed.
from __future__ import annotations
from functools import wraps
from inspect import getsource
from typing import Any, Callable, TypeVar
T = TypeVar("T")
def serialize_callable(
    value: Callable[..., T] | str, keep_callable=False
) -> Callable[..., T] | str:
    """Turn a callable (or string) into a serializable representation.

    With ``keep_callable`` set, ``value`` is handed back untouched.
    Otherwise the source code is fetched via ``getsource``; when that
    raises ``TypeError`` or ``OSError`` the result is ``str(value)``.
    """
    if not keep_callable:
        try:
            text = getsource(value)
        except (TypeError, OSError):
            # getsource rejected the value or could not retrieve its source.
            text = str(value)
        return text
    return value
def maybe_serialize(
    val: Callable[..., T] | Any, keep_callable=False
) -> Callable[..., T] | Any:
    """Serialize ``val`` when it is callable; otherwise return it unchanged."""
    if not callable(val):
        return val
    return serialize_callable(val, keep_callable=keep_callable)
def create_mc_choices(choices: list[str], choice_delimiter: str = "\n") -> str:
    """Format choices as lettered options ("A. ...", "B. ...") joined by the delimiter."""
    # Enumerate directly over code points so the first label is "A".
    return choice_delimiter.join(
        f"{chr(code_point)}. {text}"
        for code_point, text in enumerate(choices, start=65)
    )
def create_cloze_choices(choices: list[str], choice_delimiter: str = "\n") -> str:
    """Creates a cloze-style question format from a list of choices.

    NOTE(review): the body visible here consists solely of this docstring,
    so the function currently returns ``None`` despite the ``-> str``
    annotation. Confirm whether the implementation is still pending.
    """
def doc_to_closure(fn: Callable[..., T]) -> Callable[..., T]:
    """Wrap ``fn`` so it can be called with an extra leading ``self``.

    The wrapper drops its first positional argument and forwards every
    remaining positional and keyword argument to ``fn``. ``functools.wraps``
    is applied so the wrapper carries ``fn``'s metadata.
    """

    def closure(self: Any, *args, **kwargs):
        # `self` is accepted only to absorb the first argument; it is never used.
        return fn(*args, **kwargs)

    return wraps(fn)(closure)
# /// script
# requires-python = ">=3.8"
# dependencies = [
# "jsonlines",
# "mmap",
# "tqdm",
# "zstandard",
# ]
# ///
# ruff: noqa
import datetime
import io
import json
......@@ -111,7 +122,7 @@ class TextReader:
current_file_position = 0
line_counter = 0
with (
open(self.file_path, "r", encoding="utf-8") as fh,
open(self.file_path, encoding="utf-8") as fh,
tqdm.tqdm(
total=os.path.getsize(self.file_path),
dynamic_ncols=True,
......@@ -133,7 +144,7 @@ class TextReader:
def read_and_tell(self):
current_file_position = 0
with open(self.file_path, "r", encoding="utf8") as fh:
with open(self.file_path, encoding="utf8") as fh:
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
for line in iter(mmap_obj.readline, b""):
line = line.decode("utf-8")
......@@ -143,14 +154,14 @@ class TextReader:
yield line[:-1], raw_bytes_read
def read(self):
with open(self.file_path, "r", encoding="utf8") as fh:
with open(self.file_path, encoding="utf8") as fh:
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
for line in iter(mmap_obj.readline, b""):
line = line.decode("utf-8")
yield line[:-1]
def read_slow(self):
with open(self.file_path, "r", encoding="utf8") as fh:
with open(self.file_path, encoding="utf8") as fh:
while True:
line = fh.readline()
if line == -1 or line == "":
......
......@@ -5,8 +5,9 @@ import traceback
from typing import Iterator, List, Sequence, Tuple, TypeVar
# This is a cpp module. Compile janitor_util.cpp with:
# c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup
# This is a cpp module.
# See scripts/clean_training_data/README.md for instructions to compile janitor_util.cpp
try:
import janitor_util
......
This diff is collapsed.
......@@ -11,6 +11,7 @@ from lm_eval.api.metrics import (
pooled_sample_stderr,
stderr_for_metric,
)
from lm_eval.api.task import ConfigurableTask, Task
from lm_eval.utils import positional_deprecated
......@@ -56,7 +57,7 @@ class TaskOutput:
group_alias=None,
is_group=None,
):
self.task = task
self.task: Union[Task, ConfigurableTask] = task
self.task_config = task_config
self.task_name = task_name
self.group_name = group_name
......
This diff is collapsed.
......@@ -10,12 +10,13 @@ class DecontaminationFilter(Filter):
name = "track_decontamination"
def __init__(self, path) -> None:
def __init__(self, path, **kwargs) -> None:
"""
TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path").
should further cache result on a given (task_name, doc_id)
"""
super().__init__(**kwargs)
self._decontam_results = None
def apply(self, resps, docs) -> None:
......
This diff is collapsed.
......@@ -27,7 +27,6 @@ class TakeFirstFilter(Filter):
class TakeKFilter(Filter):
def __init__(self, **kwargs) -> None:
self.k = kwargs.pop("k")
super().__init__(**kwargs)
def apply(self, resps, docs):
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
......@@ -3,7 +3,7 @@ import json
import logging
import os
import warnings
from functools import lru_cache
from functools import cache
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, cast
from tqdm import tqdm
......@@ -69,7 +69,7 @@ def _verify_credentials(creds: dict) -> None:
raise ValueError(error_msg)
@lru_cache(maxsize=None)
@cache
def get_watsonx_credentials() -> Dict[str, str]:
"""
Retrieves Watsonx API credentials from environmental variables.
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment