Commit abd17276 authored by Baber
Browse files

Merge branch 'smolrefact' into tasklist

# Conflicts:
#	lm_eval/__main__.py
#	lm_eval/api/group.py
#	lm_eval/api/task.py
#	lm_eval/evaluator_utils.py
#	lm_eval/tasks/__init__.py
#	lm_eval/utils.py
#	pyproject.toml
parents 00afd536 70314843
from .evaluate_config import EvaluatorConfig
# Names exported by this module when a caller uses `from ... import *`.
__all__ = [
    "EvaluatorConfig",
]
This diff is collapsed.
from __future__ import annotations
from collections.abc import Callable, Mapping
from dataclasses import dataclass, field
from functools import cached_property
from typing import Any
@dataclass
class MetricConfig:
    """Encapsulates information about a single metric.

    Bundles the scoring callable, its default keyword arguments, and the
    aggregation callable used to combine per-item values.
    """

    # Metric name; also the key used for registry lookups when a field is None.
    name: str
    # Callable invoked by `compute`.
    fn: Callable
    # Default keyword arguments for `fn`; call-time kwargs take precedence.
    kwargs: Mapping[str, Any] = field(default_factory=dict)
    # Explicit aggregation callable; None defers to the registry (see `aggregation`).
    aggregation_fn: Callable | None = None
    # None defers to the registry (see `_higher_is_better`).
    higher_is_better: bool | None = True
    # NOTE(review): the two flags below are not read by any method of this
    # class; their meaning is defined by callers -- confirm there.
    hf_evaluate: bool = False
    is_elementwise: bool = True

    @cached_property
    def metric_name(self) -> str:
        """Return the metric's name (alias of ``name``)."""
        return self.name

    @cached_property
    def aggregation(self) -> Callable[..., Any] | None:
        """Return the aggregation callable, falling back to the registry.

        The result is cached on first access.
        """
        if self.aggregation_fn is not None:
            return self.aggregation_fn

        # Imported lazily, and only when the registry is actually consulted.
        from lm_eval.api.registry import get_aggregation

        return get_aggregation(self.name)

    @cached_property
    def _higher_is_better(self) -> bool | None:
        """Return ``higher_is_better``, falling back to the registry if None.

        The result is cached on first access.
        """
        if self.higher_is_better is not None:
            return self.higher_is_better

        # Imported lazily, and only when the registry is actually consulted.
        from lm_eval.api.registry import is_higher_better

        return is_higher_better(self.name)

    def compute(self, *args, **kwargs) -> Any:
        """Calculates the metric using the provided function and arguments.

        Keyword arguments given here override the defaults in ``self.kwargs``.

        Raises:
            ValueError: if no metric function is configured.
        """
        if self.fn is None:
            raise ValueError(f"Metric function for {self.name} is not defined.")
        merged_kwargs = {**(self.kwargs or {}), **kwargs}
        return self.fn(*args, **merged_kwargs)

    def compute_aggregation(self, *args, **kwargs) -> Any:
        """Computes the aggregation of the metric values.

        Uses ``aggregation_fn`` when set; otherwise uses whatever the
        ``aggregation`` property resolves to, so metrics that rely on the
        registry default are aggregated instead of rejected.

        Raises:
            ValueError: if no aggregation callable can be resolved.
        """
        aggregate = self.aggregation_fn
        if aggregate is None:
            aggregate = self.aggregation
        if aggregate is None:
            raise ValueError(f"Aggregation function for {self.name} is not defined.")
        return aggregate(*args, **kwargs)
This diff is collapsed.
This diff is collapsed.
from __future__ import annotations
from functools import wraps
from inspect import getsource
from typing import Any, Callable, TypeVar
T = TypeVar("T")
def serialize_callable(
    value: Callable[..., T] | str, keep_callable=False
) -> Callable[..., T] | str:
    """Turn a callable (or string) into a serializable form.

    With ``keep_callable`` set, ``value`` is handed back untouched.
    Otherwise the source text from ``getsource`` is returned; if that
    raises ``TypeError`` or ``OSError``, ``str(value)`` is returned instead.
    """
    if not keep_callable:
        try:
            source_text = getsource(value)
        except (TypeError, OSError):
            source_text = str(value)
        return source_text
    return value
def maybe_serialize(
    val: Callable[..., T] | Any, keep_callable=False
) -> Callable[..., T] | Any:
    """Serialize ``val`` only when it is callable; otherwise return it as-is."""
    if not callable(val):
        return val
    return serialize_callable(val, keep_callable=keep_callable)
def create_mc_choices(choices: list[str], choice_delimiter: str = "\n") -> str:
    """Render choices as lettered options ("A. ...", "B. ...") joined by a delimiter."""
    labelled = []
    for index, text in enumerate(choices):
        # Labels start at "A" and advance one code point per choice.
        letter = chr(ord("A") + index)
        labelled.append(f"{letter}. {text}")
    return choice_delimiter.join(labelled)
def create_cloze_choices(choices: list[str], choice_delimiter: str = "\n") -> str:
    """Creates a cloze-style question format from a list of choices.

    NOTE(review): only this docstring is present as the function body, so as
    written the function returns ``None`` rather than the annotated ``str``.
    Confirm whether the implementation is still pending.
    """
def doc_to_closure(fn: Callable[..., T]) -> Callable[..., T]:
    """Wrap ``fn`` so it accepts a leading ``self`` argument and ignores it."""

    def _drop_self(self: Any, *args, **kwargs):
        # `self` is accepted and discarded; everything else is forwarded.
        return fn(*args, **kwargs)

    # Copy fn's metadata (name, docstring, ...) onto the wrapper.
    return wraps(fn)(_drop_self)
# /// script
# requires-python = ">=3.8"
# dependencies = [
# "jsonlines",
# "mmap",
# "tqdm",
# "zstandard",
# ]
# ///
# ruff: noqa
import datetime import datetime
import io import io
import json import json
...@@ -111,7 +122,7 @@ class TextReader: ...@@ -111,7 +122,7 @@ class TextReader:
current_file_position = 0 current_file_position = 0
line_counter = 0 line_counter = 0
with ( with (
open(self.file_path, "r", encoding="utf-8") as fh, open(self.file_path, encoding="utf-8") as fh,
tqdm.tqdm( tqdm.tqdm(
total=os.path.getsize(self.file_path), total=os.path.getsize(self.file_path),
dynamic_ncols=True, dynamic_ncols=True,
...@@ -133,7 +144,7 @@ class TextReader: ...@@ -133,7 +144,7 @@ class TextReader:
def read_and_tell(self): def read_and_tell(self):
current_file_position = 0 current_file_position = 0
with open(self.file_path, "r", encoding="utf8") as fh: with open(self.file_path, encoding="utf8") as fh:
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj: with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
for line in iter(mmap_obj.readline, b""): for line in iter(mmap_obj.readline, b""):
line = line.decode("utf-8") line = line.decode("utf-8")
...@@ -143,14 +154,14 @@ class TextReader: ...@@ -143,14 +154,14 @@ class TextReader:
yield line[:-1], raw_bytes_read yield line[:-1], raw_bytes_read
def read(self): def read(self):
with open(self.file_path, "r", encoding="utf8") as fh: with open(self.file_path, encoding="utf8") as fh:
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj: with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
for line in iter(mmap_obj.readline, b""): for line in iter(mmap_obj.readline, b""):
line = line.decode("utf-8") line = line.decode("utf-8")
yield line[:-1] yield line[:-1]
def read_slow(self): def read_slow(self):
with open(self.file_path, "r", encoding="utf8") as fh: with open(self.file_path, encoding="utf8") as fh:
while True: while True:
line = fh.readline() line = fh.readline()
if line == -1 or line == "": if line == -1 or line == "":
......
...@@ -5,8 +5,9 @@ import traceback ...@@ -5,8 +5,9 @@ import traceback
from typing import Iterator, List, Sequence, Tuple, TypeVar from typing import Iterator, List, Sequence, Tuple, TypeVar
# This is a cpp module. Compile janitor_util.cpp with: # This is a cpp module.
# c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup # See scripts/clean_training_data/README.md for instructions to compile janitor_util.cpp
try: try:
import janitor_util import janitor_util
......
This diff is collapsed.
...@@ -11,6 +11,7 @@ from lm_eval.api.metrics import ( ...@@ -11,6 +11,7 @@ from lm_eval.api.metrics import (
pooled_sample_stderr, pooled_sample_stderr,
stderr_for_metric, stderr_for_metric,
) )
from lm_eval.api.task import ConfigurableTask, Task
from lm_eval.utils import positional_deprecated from lm_eval.utils import positional_deprecated
...@@ -56,7 +57,7 @@ class TaskOutput: ...@@ -56,7 +57,7 @@ class TaskOutput:
group_alias=None, group_alias=None,
is_group=None, is_group=None,
): ):
self.task = task self.task: Union[Task, ConfigurableTask] = task
self.task_config = task_config self.task_config = task_config
self.task_name = task_name self.task_name = task_name
self.group_name = group_name self.group_name = group_name
......
This diff is collapsed.
...@@ -10,12 +10,13 @@ class DecontaminationFilter(Filter): ...@@ -10,12 +10,13 @@ class DecontaminationFilter(Filter):
name = "track_decontamination" name = "track_decontamination"
def __init__(self, path) -> None: def __init__(self, path, **kwargs) -> None:
""" """
TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path"). TODO: make sure only ever run one time on the train set (should this be cached as a class var? keyed by value for "path").
should further cache result on a given (task_name, doc_id) should further cache result on a given (task_name, doc_id)
""" """
super().__init__(**kwargs)
self._decontam_results = None self._decontam_results = None
def apply(self, resps, docs) -> None: def apply(self, resps, docs) -> None:
......
This diff is collapsed.
...@@ -27,7 +27,6 @@ class TakeFirstFilter(Filter): ...@@ -27,7 +27,6 @@ class TakeFirstFilter(Filter):
class TakeKFilter(Filter): class TakeKFilter(Filter):
def __init__(self, **kwargs) -> None: def __init__(self, **kwargs) -> None:
self.k = kwargs.pop("k") self.k = kwargs.pop("k")
super().__init__(**kwargs) super().__init__(**kwargs)
def apply(self, resps, docs): def apply(self, resps, docs):
......
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
This diff is collapsed.
...@@ -3,7 +3,7 @@ import json ...@@ -3,7 +3,7 @@ import json
import logging import logging
import os import os
import warnings import warnings
from functools import lru_cache from functools import cache
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, cast from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, cast
from tqdm import tqdm from tqdm import tqdm
...@@ -69,7 +69,7 @@ def _verify_credentials(creds: dict) -> None: ...@@ -69,7 +69,7 @@ def _verify_credentials(creds: dict) -> None:
raise ValueError(error_msg) raise ValueError(error_msg)
@lru_cache(maxsize=None) @cache
def get_watsonx_credentials() -> Dict[str, str]: def get_watsonx_credentials() -> Dict[str, str]:
""" """
Retrieves Watsonx API credentials from environmental variables. Retrieves Watsonx API credentials from environmental variables.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment