Commit abd17276 authored by Baber

Merge branch 'smolrefact' into tasklist

# Conflicts:
#	lm_eval/__main__.py
#	lm_eval/api/group.py
#	lm_eval/api/task.py
#	lm_eval/evaluator_utils.py
#	lm_eval/tasks/__init__.py
#	lm_eval/utils.py
#	pyproject.toml
parents 00afd536 70314843
from .evaluate_config import EvaluatorConfig
__all__ = [
"EvaluatorConfig",
]
import json
import logging
import textwrap
from argparse import Namespace
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
import yaml
from lm_eval.utils import simple_parse_args_string
if TYPE_CHECKING:
from lm_eval.tasks import TaskManager
eval_logger = logging.getLogger(__name__)
DICT_KEYS = [
"wandb_args",
"wandb_config_args",
"hf_hub_log_args",
"metadata",
"model_args",
"gen_kwargs",
]
@dataclass
class EvaluatorConfig:
"""Configuration for language model evaluation runs.
This dataclass contains all parameters for configuring model evaluations via
`simple_evaluate()` or the CLI. It supports initialization from:
- CLI arguments (via `from_cli()`)
- YAML configuration files (via `from_config()`)
- Direct instantiation with keyword arguments
The configuration handles argument parsing, validation, and preprocessing to ensure arguments are properly structured and validated.
Example:
# From CLI arguments
config = EvaluatorConfig.from_cli(args)
# From YAML file
config = EvaluatorConfig.from_config("eval_config.yaml")
# Direct instantiation
config = EvaluatorConfig(
model="hf",
model_args={"pretrained": "gpt2"},
tasks=["hellaswag", "arc_easy"],
num_fewshot=5
)
See individual field documentation for detailed parameter descriptions.
"""
# Core evaluation parameters
config: Optional[str] = field(
default=None, metadata={"help": "Path to YAML config file"}
)
model: str = field(default="hf", metadata={"help": "Name of model e.g. 'hf'"})
model_args: dict = field(
default_factory=dict, metadata={"help": "Arguments for model initialization"}
)
tasks: Union[str, list[str]] = field(
default_factory=list,
metadata={"help": "Comma-separated list of task names to evaluate"},
)
# Few-shot and batching
num_fewshot: Optional[int] = field(
default=None, metadata={"help": "Number of examples in few-shot context"}
)
batch_size: int = field(default=1, metadata={"help": "Batch size for evaluation"})
max_batch_size: Optional[int] = field(
default=None, metadata={"help": "Maximum batch size for auto batching"}
)
# Device
device: Optional[str] = field(
default="cuda:0", metadata={"help": "Device to use (e.g. cuda, cuda:0, cpu)"}
)
# Data sampling and limiting
limit: Optional[float] = field(
default=None, metadata={"help": "Limit number of examples per task"}
)
samples: Union[str, dict, None] = field(
default=None,
metadata={"help": "dict, JSON string or path to JSON file with doc indices"},
)
# Caching
use_cache: Optional[str] = field(
default=None,
metadata={"help": "Path to sqlite db file for caching model outputs"},
)
cache_requests: dict = field(
default_factory=dict,
metadata={"help": "Cache dataset requests: true/refresh/delete"},
)
# Output and logging flags
check_integrity: bool = field(
default=False, metadata={"help": "Run test suite for tasks"}
)
write_out: bool = field(
default=False, metadata={"help": "Print prompts for first few documents"}
)
log_samples: bool = field(
default=False, metadata={"help": "Save model outputs and inputs"}
)
output_path: Optional[str] = field(
default=None, metadata={"help": "Dir path where result metrics will be saved"}
)
predict_only: bool = field(
default=False,
metadata={
"help": "Only save model outputs, don't evaluate metrics. Use with log_samples."
},
)
# Chat and instruction handling
system_instruction: Optional[str] = field(
default=None, metadata={"help": "Custom System instruction to add"}
)
apply_chat_template: Union[bool, str] = field(
default=False,
metadata={
"help": "Apply chat template to prompt. Either True, or a string identifying the tokenizer template."
},
)
fewshot_as_multiturn: bool = field(
default=False,
metadata={
"help": "Use fewshot as multi-turn conversation. Requires apply_chat_template=True."
},
)
# Configuration display
show_config: bool = field(
default=False, metadata={"help": "Show full config at end of evaluation"}
)
# External tasks and generation
include_path: Optional[str] = field(
default=None, metadata={"help": "Additional dir path for external tasks"}
)
gen_kwargs: Optional[dict] = field(
default=None, metadata={"help": "Arguments for model generation"}
)
# Logging and verbosity
verbosity: Optional[str] = field(
default=None, metadata={"help": "Logging verbosity level"}
)
# External integrations
wandb_args: dict = field(
default_factory=dict, metadata={"help": "Arguments for wandb.init"}
)
wandb_config_args: dict = field(
default_factory=dict, metadata={"help": "Arguments for wandb.config.update"}
)
hf_hub_log_args: dict = field(
default_factory=dict, metadata={"help": "Arguments for HF Hub logging"}
)
# Reproducibility
seed: list = field(
default_factory=lambda: [0, 1234, 1234, 1234],
metadata={"help": "Seeds for random, numpy, torch, fewshot (random)"},
)
# Security
trust_remote_code: bool = field(
default=False, metadata={"help": "Trust remote code for HF datasets"}
)
confirm_run_unsafe_code: bool = field(
default=False,
metadata={
"help": "Confirm understanding of unsafe code risks (for code tasks that executes arbitrary Python)"
},
)
# Internal metadata
metadata: dict = field(
default_factory=dict,
metadata={"help": "Additional metadata for tasks that require it"},
)
@classmethod
def from_cli(cls, namespace: Namespace) -> "EvaluatorConfig":
"""
Build an EvaluatorConfig by merging sources with simple precedence:
CLI args > YAML config > built-in defaults
"""
# Start with built-in defaults
config = asdict(cls())
# Load and merge YAML config if provided
if used_config := hasattr(namespace, "config") and namespace.config:
config.update(cls.load_yaml_config(namespace.config))
# Override with CLI args (only truthy values, exclude non-config args)
excluded_args = {"command", "func"} # argparse internal args
cli_args = {
k: v for k, v in vars(namespace).items() if v and k not in excluded_args
}
config.update(cli_args)
# Parse string arguments that should be dictionaries
config = cls._parse_dict_args(config)
# Create instance and validate
instance = cls(**config)
if used_config:
print(textwrap.dedent(f"""{instance}"""))
instance.configure()
return instance
@classmethod
def from_config(cls, config_path: Union[str, Path]) -> "EvaluatorConfig":
"""
Build an EvaluatorConfig from a YAML config file.
Missing fields fall back to the built-in defaults; the result is validated.
"""
# Load YAML config
yaml_config = cls.load_yaml_config(config_path)
# Parse string arguments that should be dictionaries
yaml_config = cls._parse_dict_args(yaml_config)
instance = cls(**yaml_config)
instance.configure()
return instance
@staticmethod
def _parse_dict_args(config: Dict[str, Any]) -> Dict[str, Any]:
"""Parse string arguments that should be dictionaries."""
for key in config:
if key in DICT_KEYS and isinstance(config[key], str):
config[key] = simple_parse_args_string(config[key])
return config
@staticmethod
def load_yaml_config(config_path: Union[str, Path]) -> Dict[str, Any]:
"""Load and validate YAML config file."""
config_file = (
Path(config_path) if not isinstance(config_path, Path) else config_path
)
if not config_file.is_file():
raise FileNotFoundError(f"Config file not found: {config_path}")
try:
yaml_data = yaml.safe_load(config_file.read_text())
except yaml.YAMLError as e:
raise ValueError(f"Invalid YAML in {config_path}: {e}")
except (OSError, UnicodeDecodeError) as e:
raise ValueError(f"Could not read config file {config_path}: {e}")
if not isinstance(yaml_data, dict):
raise ValueError(
f"YAML root must be a mapping, got {type(yaml_data).__name__}"
)
return yaml_data
def configure(self) -> None:
"""Validate configuration and preprocess fields after creation."""
self._validate_arguments()
self._process_arguments()
self._set_trust_remote_code()
def _validate_arguments(self) -> None:
"""Validate configuration arguments and cross-field constraints."""
if self.limit:
eval_logger.warning(
"--limit SHOULD ONLY BE USED FOR TESTING. "
"REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
)
# predict_only implies log_samples
if self.predict_only:
self.log_samples = True
# log_samples or predict_only requires output_path
if (self.log_samples or self.predict_only) and not self.output_path:
raise ValueError(
"Specify --output_path if providing --log_samples or --predict_only"
)
# fewshot_as_multiturn requires apply_chat_template
if self.fewshot_as_multiturn and self.apply_chat_template is False:
raise ValueError(
"When `fewshot_as_multiturn` is selected, `apply_chat_template` must be set."
)
# samples and limit are mutually exclusive
if self.samples and self.limit is not None:
raise ValueError("If --samples is not None, then --limit must be None.")
# tasks is required
if not self.tasks:
raise ValueError("Need to specify at least one task to evaluate.")
def _process_arguments(self) -> None:
"""Process samples argument - load from file if needed."""
if self.samples:
if isinstance(self.samples, str):
try:
self.samples = json.loads(self.samples)
except json.JSONDecodeError:
if (samples_path := Path(self.samples)).is_file():
self.samples = json.loads(samples_path.read_text())
# Set up metadata by merging model_args and metadata.
if self.model_args is None:
self.model_args = {}
if self.metadata is None:
self.metadata = {}
self.metadata = self.model_args | self.metadata
def process_tasks(self, metadata: Optional[dict] = None) -> "TaskManager":
"""Process and validate tasks, return resolved task names."""
from lm_eval import utils
from lm_eval.tasks import TaskManager
# if metadata is passed explicitly, use it:
self.metadata = metadata if metadata else self.metadata
# Create task manager with metadata
task_manager = TaskManager(
include_path=self.include_path,
metadata=self.metadata if self.metadata else {},
)
task_names = task_manager.match_tasks(self.tasks)
# Check for any individual task files in the list
for task in [task for task in self.tasks if task not in task_names]:
task_path = Path(task)
if task_path.is_file():
config = utils.load_yaml_config(str(task_path))
task_names.append(config)
# Check for missing tasks
task_missing = [
task for task in self.tasks if task not in task_names and "*" not in task
]
if task_missing:
missing = ", ".join(task_missing)
raise ValueError(f"Tasks not found: {missing}")
# Update tasks with resolved names
self.tasks = task_names
return task_manager
def _set_trust_remote_code(self) -> None:
"""Apply trust_remote_code setting if enabled."""
if self.trust_remote_code:
# HACK: import datasets and override its HF_DATASETS_TRUST_REMOTE_CODE value internally,
# because it's already been determined based on the prior env var before launching our
# script--`datasets` gets imported by lm_eval internally before these lines can update the env.
import datasets
datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
# Add to model_args for the actual model initialization
if self.model_args is None:
self.model_args = {}
self.model_args["trust_remote_code"] = True
from __future__ import annotations
from collections.abc import Callable, Mapping
from dataclasses import dataclass, field
from functools import cached_property
from typing import Any
@dataclass
class MetricConfig:
"""Encapsulates information about a single metric."""
name: str
fn: Callable
kwargs: Mapping[str, Any] = field(default_factory=dict)
aggregation_fn: Callable | None = None
higher_is_better: bool = True
hf_evaluate: bool = False
is_elementwise: bool = True
@cached_property
def metric_name(self) -> str:
return self.name
@cached_property
def aggregation(self) -> Callable[..., Any] | None:
from lm_eval.api.registry import get_aggregation
if self.aggregation_fn is None:
return get_aggregation(self.name)
return self.aggregation_fn
@cached_property
def _higher_is_better(self) -> bool | None:
from lm_eval.api.registry import is_higher_better
if self.higher_is_better is None:
return is_higher_better(self.name)
return self.higher_is_better
def compute(self, *args, **kwargs) -> Any:
"""Calculates the metric using the provided function and arguments."""
if self.fn is None:
raise ValueError(f"Metric function for {self.name} is not defined.")
return self.fn(*args, **{**(self.kwargs or {}), **kwargs})
def compute_aggregation(self, *args, **kwargs) -> Any:
"""Computes the aggregation of the metric values."""
if self.aggregation_fn is None:
raise ValueError(f"Aggregation function for {self.name} is not defined.")
return self.aggregation_fn(*args, **kwargs)
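# --- Usage sketch (illustrative): a custom metric wired through MetricConfig ---
# Both callables are supplied directly, so no registry lookup is involved.
def _exact_match_example(pred: str, target: str) -> float:
    """Toy exact-match metric used only for this sketch."""
    return float(pred.strip() == target.strip())


_example_metric = MetricConfig(
    name="exact_match_example",
    fn=_exact_match_example,
    aggregation_fn=lambda scores: sum(scores) / len(scores),
)
# _example_metric.compute("42", " 42 ")           -> 1.0
# _example_metric.compute_aggregation([1.0, 0.0]) -> 0.5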
from __future__ import annotations
import logging
from collections.abc import Iterable
from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Union
import datasets
from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.instance import OutputType
from lm_eval.config.metric import MetricConfig
from lm_eval.config.utils import maybe_serialize
if TYPE_CHECKING:
from lm_eval.api.samplers import ContextSampler
from lm_eval.api.task import Task
from lm_eval.config.template import TemplateConfig
eval_logger = logging.getLogger(__name__)
DataSet = Union[datasets.Dataset, Iterable[dict[str, Any]]]
DSplits = dict[str, DataSet]
@dataclass
class RepeatConfig:
"""Encapsulates information about a single repeat."""
repeats: int = 1
metric_fn: str | Callable = "pass@N"
kwargs: dict | None = field(default_factory=dict)
@dataclass
class FilterConfig:
"""Encapsulates information about a single filter pipeline."""
name: str
ensemble: FilterEnsemble
metric_list: list[MetricConfig]
@dataclass
class FewshotConfig:
# hack: this callable returns task.config.num_fewshot
# so it stays in sync when num_fewshot is modified at runtime
num_fewshot: Callable[[], int]
split: str | None = None
sampler: str | Callable = "default"
samples: Callable[[], DataSet] | DataSet | None = None
process_docs: Callable[[DataSet], DataSet] | None = None
fewshot_indices: list[int] | None = None
rnd: int = field(init=False, default=False)
def __post_init__(self) -> None:
if self.samples is not None and not (
isinstance(self.samples, list) or callable(self.samples)
):
raise TypeError(
"samples must be either list[dict] or callable returning list[dict]"
)
if self.split is not None and self.samples is not None:
eval_logger.warning(
"Both split and samples are configured; split will take precedence"
)
@property
def has_source(self) -> bool:
"""Check if any fewshot source is configured."""
return self.split is not None or self.samples is not None
def _get_raw_docs(self, dataset: DSplits) -> DataSet | None:
"""Get raw documents from configured source."""
if self.split is not None:
return dataset[self.split]
if self.samples is not None:
if isinstance(self.samples, list):
return self.samples
elif callable(self.samples):
# If samples is a callable, it should return a list of dicts
return self.samples()
else:
raise TypeError(
"samples must be either a list of dicts or a callable returning a list"
)
def get_docs(self, dataset) -> DataSet | None:
"""Get processed documents from configured source."""
raw_docs = self._get_raw_docs(dataset)
if raw_docs is None:
return None
if self.process_docs is not None:
return self.process_docs(raw_docs)
return raw_docs
@property
def get_sampler(self) -> Callable[..., Any] | None:
from lm_eval.api import samplers
if isinstance(self.sampler, str):
return samplers.get_sampler(self.sampler)
elif callable(self.sampler):
return self.sampler
def init_sampler(
self, docs: list[dict], task: Task, rnd=None, fewshot_indices=None
) -> ContextSampler:
"""Initialize the sampler with the given documents and task."""
if rnd is None:
raise ValueError(
"A `random.Random` generator argument must be provided to `rnd` of FewShotSampler!"
)
return self.get_sampler(
docs,
task,
rnd=rnd,
fewshot_indices=fewshot_indices
if fewshot_indices
else self.fewshot_indices,
)
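# --- Illustrative sketch: FewshotConfig fed from an inline list of docs ---
# No dataset splits are needed when `samples` is given directly.
_example_fewshot = FewshotConfig(
    num_fewshot=lambda: 2,
    samples=[
        {"question": "2+2?", "answer": "4"},
        {"question": "3+3?", "answer": "6"},
    ],
)
# _example_fewshot.has_source   -> True
# _example_fewshot.get_docs({}) -> the two docs above (process_docs is None)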
@dataclass
class TaskConfig:
# task naming/registry
task: str | None = None
task_alias: str | None = None
tag: str | list | None = None
# HF dataset options.
# which dataset to use,
# and what splits for what purpose
custom_dataset: Callable[..., DataSet] | None = None
dataset_path: str | None = None
dataset_name: str | None = None
dataset_kwargs: dict | None = field(default_factory=dict)
training_split: str | None = None
validation_split: str | None = None
test_split: str | None = None
fewshot_split: str | None = None
# formatting / prompting options.
# see docs/advanced_task_guide.md for more info
process_docs: Callable[[DataSet], DataSet] | None = None
doc_to_text: Callable[[dict[str, Any]], Any] | str | None = None
doc_to_target: Callable[[dict[str, Any]], Any] | str | None = None
doc_to_image: Callable[[dict[str, Any]], Any] | str | None = None
doc_to_audio: Callable[[dict[str, Any]], Any] | str | None = None
unsafe_code: bool = False
doc_to_choice: Callable[[dict[str, Any]], Any] | str | dict | list | None = None
process_results: (
Callable[[dict[str, Any], list[Any]], dict[str, Any]] | str | None
) = None
use_prompt: str | None = None
description: str = ""
target_delimiter: str = " "
fewshot_delimiter: str = "\n\n"
fewshot_config: dict[str, Any] | None = None
# runtime configuration options
num_fewshot: int | None = None
generation_kwargs: dict[str, Any] | None = None
# scoring options
metric_list: list | None = None
output_type: OutputType = "generate_until"
repeats: int = 1
filter_list: list[dict] | None = None
should_decontaminate: bool = False
doc_to_decontamination_query: str | None = None
gen_prefix: str | None = None
multiple_input: bool = False
metadata: dict = field(
default_factory=dict
) # by default, not used in the code. allows for users to pass arbitrary info to tasks
_metric_list: list[MetricConfig] = field(default_factory=list)
_filter_list: list[FilterConfig] = field(default_factory=list)
# ds_cfg: DatasetConfig = field(init=False)
fewshot_cfg: FewshotConfig = field(init=False)
_fn: dict[str, Callable] = field(default_factory=dict)
def __post_init__(self) -> None:
### ---setup generation kwargs--- ###
if self.generation_kwargs is not None:
if self.output_type != "generate_until":
eval_logger.warning(
f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!"
)
if "temperature" in self.generation_kwargs:
self.generation_kwargs["temperature"] = float(
self.generation_kwargs["temperature"]
)
if "until" not in self.generation_kwargs:
eval_logger.warning(
f"{self.task}: No `until` specified in `generation_kwargs`! Defaulting to the fewshot_delimiter={repr(self.fewshot_delimiter)}"
)
self.generation_kwargs["until"] = [self.fewshot_delimiter]
else:
if self.output_type == "generate_until":
# otherwise, ensure that we generate greedily in the absence of explicit arguments
self.generation_kwargs = {
"until": (
None
if self.fewshot_delimiter is None
else [self.fewshot_delimiter]
),
"do_sample": False,
"temperature": 0,
}
eval_logger.warning(
f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}"
)
# ---setup fewshot config--- #
_fewshot_cfg = self.fewshot_config if self.fewshot_config is not None else {}
self.fewshot_cfg = FewshotConfig(
num_fewshot=lambda: self.num_fewshot or _fewshot_cfg.get("num_fewshot", 0),
split=self.fewshot_split,
sampler=_fewshot_cfg.get("sampler", "default"),
samples=_fewshot_cfg.get("samples", None),
process_docs=_fewshot_cfg.get("process_docs", None),
fewshot_indices=_fewshot_cfg.get("fewshot_indices", None),
)
def _get_metric(self, metric_list: list[dict] | None = None) -> list[MetricConfig]:
from lm_eval.api.registry import (
AGGREGATION_REGISTRY,
DEFAULT_METRIC_REGISTRY,
get_aggregation,
get_metric,
get_metric_aggregation,
is_higher_better,
)
# if metric_list defined inside a filter, use that; otherwise use the task's metric_list
metric_list = metric_list or self.metric_list
metrics = []
if not metric_list:
# ---------- 1. If no metrics defined, use defaults for output type ----------
_metric_list = DEFAULT_METRIC_REGISTRY[self.output_type]
eval_logger.info(
f"No metrics defined in config, using default metrics for {self.output_type}={_metric_list}"
)
metrics.extend(
MetricConfig(
name=metric_name,
fn=get_metric(metric_name),
aggregation_fn=get_metric_aggregation(metric_name),
higher_is_better=is_higher_better(metric_name) is not False,  # default to True only when the registry has no entry
)
for metric_name in _metric_list
)
else:
# ---------- 2. Process user-defined metrics from config ----------
for metric_config in metric_list:
metric_name = metric_config["metric"]
_metric_fn_kwargs = {
key: metric_config[key]
for key in metric_config
if key
not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"]
}
_hf_evaluate_metric: bool = metric_config.get("hf_evaluate", False)
_metric_fn = None
_aggregation = None
if self.process_results is not None:
# User will compute metrics inside `process_results()`
_metric_name = None
_metric_fn_kwargs = {}
elif callable(metric_name):
# User passed a function object
_metric_name = metric_name.__name__
_metric_fn = metric_name.__call__
else:
# Normal: look up by name
_metric_name = metric_name
_metric_fn = get_metric(metric_name, _hf_evaluate_metric)
# ---------- 3. Decide how to aggregate examples ----------
if "aggregation" in metric_config:
if isinstance(_agg_name := metric_config["aggregation"], str):
_aggregation = get_aggregation(_agg_name)
elif callable(_agg_name): # noqa: E721
_aggregation = metric_config["aggregation"]
else:
INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
_aggregation = get_metric_aggregation(metric_name)
eval_logger.warning(
f"[Task: {self.task}] metric {metric_name} is defined, but aggregation is not. "
f"using default "
f"aggregation={INV_AGG_REGISTRY[_aggregation]}"
)
# ---------- 4. Determine “higher-is-better” semantics ----------
if "higher_is_better" in metric_config:
_higher_is_better = metric_config["higher_is_better"]
else:
eval_logger.warning(
f"[Task: {self.task}] metric {metric_name} is defined, but higher_is_better is not. "
f"using default "
f"higher_is_better={is_higher_better(metric_name)}"
)
_higher_is_better = is_higher_better(metric_name)
metrics.append(
MetricConfig(
name=_metric_name,
fn=_metric_fn,
kwargs=_metric_fn_kwargs,
aggregation_fn=_aggregation,
higher_is_better=_higher_is_better,
hf_evaluate=_hf_evaluate_metric,
)
)
for m in metrics:
if m not in self._metric_list:
self._metric_list.append(m)
return metrics
@property
def get_filters(self) -> list[FilterConfig]:
from lm_eval.filters import build_filter_ensemble
if not self.filter_list:
eval_logger.debug(
"No custom filters defined; falling back to 'take_first' for handling repeats."
)
return [
FilterConfig(
name="none",
ensemble=build_filter_ensemble("none", [("take_first", None)]),
metric_list=self._get_metric(metric_list=None),
)
]
else:
def _strip_fn(d: dict) -> tuple[str, dict]:
return d["function"], {
k: v for k, v in d.items() if k not in ["function", "metric_list"]
}
configs = (
self.filter_list.values()
if isinstance(self.filter_list, dict)
else self.filter_list
)
x = [
FilterConfig(
name=cfg["name"],
ensemble=build_filter_ensemble(
filter_name=cfg["name"],
components=[_strip_fn(f) for f in cfg["filter"]],
),
metric_list=self._get_metric(metric_list=cfg.get("metric_list")),
)
for cfg in configs
]
return x
@classmethod
def from_yaml(cls, data: dict[str, Any]) -> TaskConfig:
"""Create a TaskConfig instance from a YAML-like dictionary."""
fn = {k: v for k, v in data.items() if callable(v)}
return cls(**data, _fn=fn)
@classmethod
def from_template(cls, template: TemplateConfig, **kwargs) -> TaskConfig:
"""Create a TaskConfig instance from a template.
Args:
template: TemplateConfig instance (MCQTemplateConfig or ClozeTemplateConfig)
**kwargs: Additional arguments to override template defaults
Returns:
TaskConfig instance configured from the template
"""
from lm_eval.config.template import (
ClozeTemplateConfig,
MCQTemplateConfig,
)
# Extract base configuration from template
config_dict = {
"task": template.task,
"doc_to_text": template.doc_to_text,
"doc_to_choice": template.doc_to_choice,
"doc_to_target": template.doc_to_target,
"description": template.description,
"target_delimiter": template.target_delimiter,
"fewshot_delimiter": template.fewshot_delimiter,
"metric_list": template.metric_list,
}
# Add common template attributes if they exist
if hasattr(template, "answer_suffix"):
config_dict["target_delimiter"] = (
template.answer_suffix + template.target_delimiter
)
# Handle template-specific configurations
if isinstance(template, MCQTemplateConfig):
# For MCQ templates, set up multiple choice specific config
config_dict["output_type"] = "multiple_choice"
# MCQ templates typically use accuracy metrics
if template.metric_list is None:
config_dict["metric_list"] = [{"metric": "acc"}]
elif isinstance(template, ClozeTemplateConfig):
# For Cloze templates, set up generation config
config_dict["output_type"] = "generate_until"
# Cloze templates typically use accuracy and normalized accuracy
if template.metric_list is None:
config_dict["metric_list"] = [{"metric": "acc"}, {"metric": "acc_norm"}]
else:
# Generic template - try to infer output type
if hasattr(template, "template"):
if template.template == "mcq":
config_dict["output_type"] = "multiple_choice"
elif template.template == "cloze":
config_dict["output_type"] = "generate_until"
# Override with any user-provided kwargs
config_dict.update(kwargs)
# Create and return TaskConfig instance
return cls(**config_dict)
def to_dict(self, keep_callable: bool = False) -> dict:
def _ser(x):
if isinstance(x, dict):
return {k: _ser(v) for k, v in x.items()}
if isinstance(x, (list, tuple, set)):
return type(x)(_ser(i) for i in x)
return maybe_serialize(x, keep_callable)
return {k: _ser(v) for k, v in asdict(self).items() if v is not None}
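# --- Illustrative sketch: a minimal generate_until task built from a YAML-like dict ---
# This mirrors what the YAML loader would hand to `TaskConfig.from_yaml()`; metric and
# filter resolution (`get_filters`) stays deferred until the registries are consulted.
_example_task_cfg = TaskConfig.from_yaml(
    {
        "task": "toy_qa",
        "dataset_path": "json",  # hypothetical dataset identifier, for illustration only
        "test_split": "test",
        "doc_to_text": "Question: {{question}}\nAnswer:",
        "doc_to_target": "{{answer}}",
        "output_type": "generate_until",
        "metric_list": [{"metric": "exact_match"}],  # resolved lazily via the metric registry
        "generation_kwargs": {"until": ["\n\n"], "temperature": 0},
    }
)
# _example_task_cfg.fewshot_cfg.num_fewshot() -> 0 (no num_fewshot configured)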
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Callable
from lm_eval.config.utils import create_mc_choices
if TYPE_CHECKING:
from lm_eval.config.metric import MetricConfig
@dataclass
class TemplateConfig(ABC):
"""Encapsulates information about a template."""
#
template: str
task: str
doc_to_text: str | Callable[[dict], str] | list[str]
doc_to_choice: str | list | Callable[[dict], list]
doc_to_target: int | Callable[[dict], int]
description: str
context_prefix: str
prefix_delimiter: str
context_delimiter: str
answer_suffix: str
target_delimiter: str
choice_format: str | None
choice_delimiter: str | None
fewshot_delimiter: str
metric_list: list[str] | list[MetricConfig] | None = field(
default_factory=lambda: ["acc", "acc_norm"]
)
@abstractmethod
def _doc_to_text(self, doc: dict) -> str:
"""Convert a document to text."""
raise NotImplementedError
def _doc_to_choice(self, doc: dict) -> str:
"""Convert a document to choices."""
raise NotImplementedError
def _doc_to_target(self, doc: dict) -> int | str:
"""Convert a document to target."""
raise NotImplementedError
@dataclass
class MCQTemplateConfig:
"""Encapsulates information about a template.
Would return a sample with the following format:
Question: <doc_to_text(doc)>
A. <doc_to_choice(doc)[0]>
B. <doc_to_choice(doc)[1]>
C. <doc_to_choice(doc)[2]>
D. <doc_to_choice(doc)[3]>
Answer: <doc_to_choice(doc)[i]> for each choice.
"""
doc_to_text: str | Callable[[dict], str]
doc_to_choice: list[str]
doc_to_target: int | Callable[[dict], int]
template = "mcq"
context_prefix: str = "Question:"
prefix_delimiter: str = " "
context_delimiter: str = "\n"
answer_suffix: str = "Answer:"
target_delimiter: str = "\n"
choice_format: str | None = "letters"
choice_delimiter: str = "\n"
fewshot_delimiter: str = "\n\n"
metric_list: list[MetricConfig] | None = field(default_factory=lambda: ["acc"])
def _doc_to_text(self, doc: dict) -> str:
"""Convert a document to text."""
doc_to_text: str = (
self.doc_to_text
if isinstance(self.doc_to_text, str)
else self.doc_to_text(doc)
)
return (
self.context_prefix
+ self.prefix_delimiter
+ doc_to_text
+ self.context_delimiter
+ create_mc_choices(
self.doc_to_choice, choice_delimiter=self.choice_delimiter
)
+ self.answer_suffix
)
def _doc_to_choice(self, doc: dict) -> str:
if callable(self.doc_to_choice):
doc_to_choice = self.doc_to_choice(doc)
elif isinstance(self.doc_to_choice, str):
doc_to_choice = doc[self.doc_to_choice]
else:
doc_to_choice = self.doc_to_choice
return create_mc_choices(doc_to_choice, choice_delimiter=self.choice_delimiter)
def _doc_to_target(self, doc: dict) -> int:
"""Convert a document to target."""
if callable(self.doc_to_target):
return self.doc_to_target(doc)
elif isinstance(self.doc_to_target, str):
return doc[self.doc_to_target]
else:
return self.doc_to_target
@dataclass
class ClozeTemplateConfig(TemplateConfig):
"""Encapsulates information about a template.
Would return a sample with the following format:
Question: <doc_to_text(doc)>
Answer: <doc_to_target(doc)>
"""
doc_to_text: str | Callable[[dict], str]
doc_to_choice: list[str]
doc_to_target: int | Callable[[dict], int]
template: str = "cloze"
description: str = ""
context_prefix: str = "Question:"
prefix_delimiter: str = " "
context_delimiter: str = "\n"
answer_suffix: str = "Answer:"
target_delimiter: str = " "
choice_format: str | None = None
choice_delimiter: str = ""
fewshot_delimiter: str = "\n\n"
metric_list: list[MetricConfig] | None = field(
default_factory=lambda: ["acc", "acc_norm"]
)
def _doc_to_text(self, doc: dict) -> str:
"""Convert a document to text."""
doc_to_text: str = (
self.doc_to_text
if isinstance(self.doc_to_text, str)
else self.doc_to_text(doc)
)
return (
self.context_prefix
+ self.prefix_delimiter
+ doc_to_text
+ self.context_delimiter
+ self.answer_suffix
)
def _doc_to_choice(self, doc: dict) -> str:
if callable(self.doc_to_choice):
doc_to_choice = self.doc_to_choice(doc)
elif isinstance(self.doc_to_choice, str):
doc_to_choice = doc[self.doc_to_choice]
else:
doc_to_choice = self.doc_to_choice
return create_mc_choices(doc_to_choice, choice_delimiter=self.choice_delimiter)
def _doc_to_target(self, doc: dict) -> int:
"""Convert a document to target."""
if callable(self.doc_to_target):
return self.doc_to_target(doc)
elif isinstance(self.doc_to_target, str):
return doc[self.doc_to_target]
else:
return self.doc_to_target
from __future__ import annotations
from functools import wraps
from inspect import getsource
from typing import Any, Callable, TypeVar
T = TypeVar("T")
def serialize_callable(
value: Callable[..., T] | str, keep_callable=False
) -> Callable[..., T] | str:
"""Serializes a given function or string.
If 'keep_callable' is True, the original callable is returned.
Otherwise, attempts to return the source code of the callable using 'getsource'.
If serialization fails, returns the string representation.
"""
if keep_callable:
return value
else:
try:
return getsource(value)
except (TypeError, OSError):
return str(value)
def maybe_serialize(
val: Callable[..., T] | Any, keep_callable=False
) -> Callable[..., T] | Any:
"""Conditionally serializes a value if it is callable."""
return (
serialize_callable(val, keep_callable=keep_callable) if callable(val) else val
)
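# e.g. maybe_serialize(len)   -> "<built-in function len>"  (getsource() raises for builtins)
#      maybe_serialize("doc_to_text template") -> the string, unchanged
#      maybe_serialize(some_module_level_fn)   -> its source code as text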
def create_mc_choices(choices: list[str], choice_delimiter: str = "\n") -> str:
"""Creates a multiple-choice question format from a list of choices."""
formatted_choices = [f"{chr(65 + i)}. {choice}" for i, choice in enumerate(choices)]
return choice_delimiter.join(formatted_choices)
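# e.g. create_mc_choices(["Paris", "London", "Rome"])
#      -> "A. Paris\nB. London\nC. Rome"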
def create_cloze_choices(choices: list[str], choice_delimiter: str = "\n") -> str:
"""Creates a cloze-style question format from a list of choices."""
def doc_to_closure(fn: Callable[..., T]) -> Callable[..., T]:
"""Closure that allows the function to be called with 'self'."""
@wraps(fn)
def closure(self: Any, *args, **kwargs):
return fn(*args, **kwargs)
return closure
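# e.g. attaching a module-level `doc_to_text(doc)` helper to a task class (hypothetical names):
#      SomeTask.doc_to_text = doc_to_closure(doc_to_text)
#      instance.doc_to_text(doc) then calls the original function, discarding `self`.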
# /// script
# requires-python = ">=3.8"
# dependencies = [
# "jsonlines",
# "mmap",
# "tqdm",
# "zstandard",
# ]
# ///
# ruff: noqa
import datetime
import io
import json
......@@ -111,7 +122,7 @@ class TextReader:
current_file_position = 0
line_counter = 0
with (
open(self.file_path, "r", encoding="utf-8") as fh,
open(self.file_path, encoding="utf-8") as fh,
tqdm.tqdm(
total=os.path.getsize(self.file_path),
dynamic_ncols=True,
......@@ -133,7 +144,7 @@ class TextReader:
def read_and_tell(self):
current_file_position = 0
with open(self.file_path, "r", encoding="utf8") as fh:
with open(self.file_path, encoding="utf8") as fh:
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
for line in iter(mmap_obj.readline, b""):
line = line.decode("utf-8")
......@@ -143,14 +154,14 @@ class TextReader:
yield line[:-1], raw_bytes_read
def read(self):
with open(self.file_path, "r", encoding="utf8") as fh:
with open(self.file_path, encoding="utf8") as fh:
with mmap.mmap(fh.fileno(), length=0, access=mmap.ACCESS_READ) as mmap_obj:
for line in iter(mmap_obj.readline, b""):
line = line.decode("utf-8")
yield line[:-1]
def read_slow(self):
with open(self.file_path, "r", encoding="utf8") as fh:
with open(self.file_path, encoding="utf8") as fh:
while True:
line = fh.readline()
if line == -1 or line == "":
......
......@@ -5,8 +5,9 @@ import traceback
from typing import Iterator, List, Sequence, Tuple, TypeVar
# This is a cpp module. Compile janitor_util.cpp with:
# c++ -O3 -Wall -shared -std=c++11 -fPIC $(python3 -m pybind11 --includes) janitor_util.cpp -o janitor_util$(python3-config --extension-suffix) -undefined dynamic_lookup
# This is a cpp module.
# See scripts/clean_training_data/README.md for instructions to compile janitor_util.cpp
try:
import janitor_util
......
from __future__ import annotations
import itertools
import json
import logging
......@@ -5,7 +7,7 @@ import os
import random
import time
from collections import defaultdict
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, Any, List, Optional, Union
import numpy as np
import torch
......@@ -29,11 +31,11 @@ from lm_eval.loggers import EvaluationTracker
from lm_eval.loggers.utils import add_env_info, add_tokenizer_info, get_git_commit_hash
from lm_eval.tasks import TaskManager, get_task_dict
from lm_eval.utils import (
get_logger,
handle_non_serializable,
hash_dict_images,
hash_string,
positional_deprecated,
setup_logging,
simple_parse_args_string,
wrap_text,
)
......@@ -49,7 +51,7 @@ eval_logger = logging.getLogger(__name__)
@positional_deprecated
def simple_evaluate(
model,
model_args: Optional[Union[str, dict]] = None,
model_args: Optional[Union[str, dict[str, Any]]] = None,
tasks: Optional[List[Union[str, dict, object]]] = None,
num_fewshot: Optional[int] = None,
batch_size: Optional[Union[int, str]] = None,
......@@ -147,7 +149,7 @@ def simple_evaluate(
Dictionary of results
"""
if verbosity is not None:
setup_logging(verbosity=verbosity)
get_logger(verbosity)
start_date = time.time()
if limit is not None and samples is not None:
......@@ -287,7 +289,7 @@ def simple_evaluate(
# helper function to recursively apply config overrides to leaf subtasks, skipping their constituent groups.
# (setting of num_fewshot ; bypassing metric calculation ; setting fewshot seed)
def _adjust_config(task_dict):
def _adjust_config(task_dict: dict[str, "Task"]) -> dict[str, "Task"]:
adjusted_task_dict = {}
for task_name, task_obj in task_dict.items():
if isinstance(task_obj, dict):
......@@ -370,8 +372,6 @@ def simple_evaluate(
verbosity=verbosity,
confirm_run_unsafe_code=confirm_run_unsafe_code,
)
if verbosity is not None:
setup_logging(verbosity=verbosity)
if lm.rank == 0:
if isinstance(model, str):
......@@ -420,7 +420,7 @@ def simple_evaluate(
def evaluate(
lm: "LM",
task_dict,
limit: Optional[int] = None,
limit: int | float | None = None,
samples: Optional[dict] = None,
cache_requests: bool = False,
rewrite_requests_cache: bool = False,
......@@ -475,7 +475,9 @@ def evaluate(
"Either 'limit' or 'samples' must be None, but both are not None."
)
if samples is not None:
eval_logger.info(f"Evaluating examples for tasks {list(samples.keys())}")
eval_logger.info(
f"Evaluating examples for tasks {[x for x in list(samples.keys()) if x in task_dict.keys()]}"
)
if apply_chat_template:
eval_logger.warning(
"Chat template formatting change affects loglikelihood and multiple-choice tasks. See docs/chat-template-readme.md for details."
......@@ -775,13 +777,3 @@ def evaluate(
else:
return None
def request_caching_arg_to_dict(cache_requests: str) -> dict:
request_caching_args = {
"cache_requests": cache_requests in {"true", "refresh"},
"rewrite_requests_cache": cache_requests == "refresh",
"delete_requests_cache": cache_requests == "delete",
}
return request_caching_args
......@@ -11,6 +11,7 @@ from lm_eval.api.metrics import (
pooled_sample_stderr,
stderr_for_metric,
)
from lm_eval.api.task import ConfigurableTask, Task
from lm_eval.utils import positional_deprecated
......@@ -56,7 +57,7 @@ class TaskOutput:
group_alias=None,
is_group=None,
):
self.task = task
self.task: Union[Task, ConfigurableTask] = task
self.task_config = task_config
self.task_name = task_name
self.group_name = group_name
......
from __future__ import annotations
from functools import partial
from typing import List
from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.registry import get_filter
from lm_eval.api.registry import filter_registry, get_filter
from . import custom, extraction, selection, transformation
def build_filter_ensemble(
filter_name: str, components: List[List[str]]
filter_name: str,
components: list[tuple[str, dict[str, str | int | float] | None]],
) -> FilterEnsemble:
"""
Create a filtering pipeline.
"""
filters = []
for function, kwargs in components:
if kwargs is None:
kwargs = {}
# create a filter given its name in the registry
f = partial(get_filter(function), **kwargs)
# add the filter as a pipeline step
filters.append(f)
# create filters given its name in the registry, and add each as a pipeline step
return FilterEnsemble(
name=filter_name,
filters=[
partial(get_filter(func), **(kwargs or {})) for func, kwargs in components
],
)
return FilterEnsemble(name=filter_name, filters=filters)
__all__ = [
"custom",
"extraction",
"selection",
"transformation",
"build_filter_ensemble",
]
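# --- Illustrative component formats for build_filter_ensemble() ---
# Each component is (registered_filter_name, kwargs-or-None), e.g. the default
# pipeline used when a task defines no filters:
#   build_filter_ensemble("none", [("take_first", None)])
# or, assuming a regex-extraction filter registered as "regex":
#   build_filter_ensemble(
#       "get-answer",
#       [("regex", {"regex_pattern": r"#### (\-?[0-9\.\,]+)"}), ("take_first", None)],
#   )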
......@@ -10,12 +10,13 @@ class DecontaminationFilter(Filter):
name = "track_decontamination"
def __init__(self, path) -> None:
def __init__(self, path, **kwargs) -> None:
"""
TODO: make sure this is only ever run once on the train set (should the result be cached as a class var, keyed by `path`?).
Should further cache results for a given (task_name, doc_id).
"""
super().__init__(**kwargs)
self._decontam_results = None
def apply(self, resps, docs) -> None:
......
import re
import sys
import unicodedata
from collections.abc import Iterable
from lm_eval.api.filter import Filter
from lm_eval.api.registry import register_filter
......@@ -20,17 +21,21 @@ class RegexFilter(Filter):
regex_pattern: str = r"#### (\-?[0-9\.\,]+)",
group_select: int = 0,
fallback: str = "[invalid]",
**kwargs,
) -> None:
"""
pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located.
"""
super().__init__(**kwargs)
self.regex_pattern = regex_pattern
self.regex = re.compile(regex_pattern)
self.group_select = group_select
self.fallback = fallback
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
def apply(
self, resps: Iterable[list[str]], docs: Iterable[dict]
) -> Iterable[list[str]]:
# here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
......@@ -57,57 +62,13 @@ class RegexFilter(Filter):
return filtered_resps
@register_filter("regex_pos")
class POSFilter(Filter):
""" """
def __init__(
self,
regex_pattern: str = r"\['(.*?)'\]",
group_select=0,
fallback=None,
) -> None:
"""
pass a string `regex` to run `re.compile(r"regex")` on.
`fallback` defines the output returned if no matches for the regex are located.
"""
if fallback is None:
fallback = ["invalid"]
self.regex_pattern = regex_pattern
self.regex = re.compile(regex_pattern)
self.group_select = group_select
self.fallback = fallback
def apply(self, resps, docs):
def extract_tagged_tokens(text):
# Extract tagged tokens list from text input using regex
tokens = re.findall(r"\('([^']*)', '([^']*)'\)", text)
return [(token, pos) for token, pos in tokens]
def extract_pos_tags(result):
pos_tags = []
if isinstance(result, str):
result = extract_tagged_tokens(result)
pos_tags.extend(pos for _, pos in result)
return pos_tags if pos_tags else self.fallback
def filter_set(inst):
filtered = []
for resp in inst:
match = extract_pos_tags(resp)
filtered.append(match)
return filtered
filtered_resps = map(lambda x: filter_set(x), resps)
return filtered_resps
@register_filter("remove_whitespace")
class WhitespaceFilter(Filter):
"""Filters out leading whitespace from responses."""
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
def apply(
self, resps: Iterable[list[str]], docs: Iterable[dict]
) -> Iterable[list[str]]:
def filter_set(inst):
filtered_resp = []
for resp in inst:
......@@ -152,7 +113,9 @@ class MultiChoiceRegexFilter(RegexFilter):
self.ignore_punctuation = ignore_punctuation
self.regexes_to_ignore = regexes_to_ignore
def apply(self, resps: list[list[str]], docs: list[dict]) -> list[list[str]]:
def apply(
self, resps: Iterable[list[str]], docs: Iterable[dict]
) -> Iterable[list[str]]:
# here, we assume we have a list, in which each element is
# a list of model responses for some particular input/target pair.
# so we process each of these (same input/target response sets)
......
......@@ -27,7 +27,6 @@ class TakeFirstFilter(Filter):
class TakeKFilter(Filter):
def __init__(self, **kwargs) -> None:
self.k = kwargs.pop("k")
super().__init__(**kwargs)
def apply(self, resps, docs):
......
......@@ -6,9 +6,6 @@ from lm_eval.api.registry import register_filter
@register_filter("lowercase")
class LowercaseFilter(Filter):
def __init__(self) -> None:
pass
def apply(self, resps, docs):
def filter_set(inst):
return [resp.lower() for resp in inst]
......@@ -18,9 +15,6 @@ class LowercaseFilter(Filter):
@register_filter("uppercase")
class UppercaseFilter(Filter):
def __init__(self) -> None:
pass
def apply(self, resps, docs):
def filter_set(inst):
return [resp.upper() for resp in inst]
......@@ -31,6 +25,7 @@ class UppercaseFilter(Filter):
@register_filter("map")
class MapFilter(Filter):
def __init__(self, mapping_dict: dict = None, default_value=None) -> None:
super().__init__()
"""
Initializes the MapFilter with a given mapping dictionary and default value.
......@@ -60,9 +55,6 @@ class MapFilter(Filter):
@register_filter("format_span")
class SPANFilter(Filter):
def __init__(self) -> None:
pass
def apply(self, resps, docs):
def format_ner_text(text):
label_dict = {
......
from . import (
anthropic_llms,
api_models,
dummy,
gguf,
hf_audiolm,
hf_steered,
hf_vlms,
huggingface,
ibm_watsonx_ai,
mamba_lm,
nemo_lm,
neuron_optimum,
openai_completions,
optimum_ipex,
optimum_lm,
sglang_causallms,
sglang_generate_API,
textsynth,
vllm_causallms,
vllm_vlms,
)
# TODO: implement __all__
# Models are now lazily loaded via the registry system
# No need to import them all at once
# Define model mappings for lazy registration
MODEL_MAPPING = {
"anthropic-completions": "lm_eval.models.anthropic_llms:AnthropicLM",
"anthropic-chat": "lm_eval.models.anthropic_llms:AnthropicChatLM",
"anthropic-chat-completions": "lm_eval.models.anthropic_llms:AnthropicCompletionsLM",
"local-completions": "lm_eval.models.openai_completions:LocalCompletionsAPI",
"local-chat-completions": "lm_eval.models.openai_completions:LocalChatCompletion",
"openai-completions": "lm_eval.models.openai_completions:OpenAICompletionsAPI",
"openai-chat-completions": "lm_eval.models.openai_completions:OpenAIChatCompletion",
"dummy": "lm_eval.models.dummy:DummyLM",
"gguf": "lm_eval.models.gguf:GGUFLM",
"ggml": "lm_eval.models.gguf:GGUFLM",
"hf-audiolm-qwen": "lm_eval.models.hf_audiolm:HFAudioLM",
"steered": "lm_eval.models.hf_steered:SteeredHF",
"hf-multimodal": "lm_eval.models.hf_vlms:HFMultimodalLM",
"hf-auto": "lm_eval.models.huggingface:HFLM",
"hf": "lm_eval.models.huggingface:HFLM",
"huggingface": "lm_eval.models.huggingface:HFLM",
"watsonx_llm": "lm_eval.models.ibm_watsonx_ai:IBMWatsonxAI",
"mamba_ssm": "lm_eval.models.mamba_lm:MambaLMWrapper",
"nemo_lm": "lm_eval.models.nemo_lm:NeMoLM",
"neuronx": "lm_eval.models.neuron_optimum:NeuronModelForCausalLM",
"ipex": "lm_eval.models.optimum_ipex:IPEXForCausalLM",
"openvino": "lm_eval.models.optimum_lm:OptimumLM",
"sglang": "lm_eval.models.sglang_causallms:SGLANG",
"sglang-generate": "lm_eval.models.sglang_generate_API:SGAPI",
"textsynth": "lm_eval.models.textsynth:TextSynthLM",
"vllm": "lm_eval.models.vllm_causallms:VLLM",
"vllm-vlm": "lm_eval.models.vllm_vlms:VLLM_VLM",
}
# Register all models lazily
def _register_all_models():
"""Register all known models lazily in the registry."""
from lm_eval.api.registry import model_registry
for name, path in MODEL_MAPPING.items():
# Only register if not already present (avoids conflicts when modules are imported)
if name not in model_registry:
# Register the lazy placeholder using lazy parameter
model_registry.register(name, lazy=path)
# Call registration on module import
_register_all_models()
__all__ = ["MODEL_MAPPING"]
try:
......
from __future__ import annotations
import abc
import asyncio
import copy
......@@ -8,16 +10,9 @@ from functools import cached_property
from typing import (
TYPE_CHECKING,
Any,
Awaitable,
Callable,
Dict,
Iterable,
List,
Literal,
NamedTuple,
Optional,
Tuple,
Union,
)
......@@ -36,18 +31,21 @@ from importlib.util import find_spec
from io import BytesIO
from lm_eval import utils
from lm_eval.api.instance import Instance
from lm_eval.api.model import TemplateLM
from lm_eval.models.utils import Collator, chunks, configure_pad_token
if TYPE_CHECKING:
from collections.abc import Awaitable, Iterable
from PIL import Image
from lm_eval.api.instance import Instance
eval_logger = logging.getLogger(__name__)
LogLikelihoodInputs = Tuple[Tuple[str, str], List[int], List[int]]
LogLikelihoodInputs = tuple[tuple[str, str], list[int], list[int]]
# utility class to keep track of json encoded chats
......@@ -58,9 +56,7 @@ class JsonChatStr(NamedTuple):
return self.prompt.encode(encoding)
def create_image_prompt(
imgs: list["Image.Image"], chat: dict, fmt: str = "PNG"
) -> dict:
def create_image_prompt(imgs: list[Image.Image], chat: dict, fmt: str = "PNG") -> dict:
"""
Parameters
......@@ -109,33 +105,32 @@ class TemplateAPI(TemplateLM):
model: str = None,
pretrained: str = None, # `model` takes precedence over `pretrained` when passed.
base_url: str = None,
tokenizer: Optional[str] = None,
tokenizer: str | None = None,
# Loglikelihood tasks require a tokenizer to calculate context lengths;
# however, requests can be sent as strings if the API doesn't support token inputs:
# use tokenized_requests=False.
tokenizer_backend: Optional[
Literal["tiktoken", "huggingface", "None", "none"]
] = "huggingface",
tokenizer_backend: Literal["tiktoken", "huggingface", "None", "none"]
| None = "huggingface",
truncate: bool = False,
# number of concurrent requests. More useful if not batching
num_concurrent: int = 1,
max_retries: int = 3,
max_gen_toks: int = 256,
batch_size: Union[str, int] = 1,
batch_size: str | int = 1,
seed: int = 1234,
max_length: Optional[int] = 2048,
max_length: int | None = 2048,
add_bos_token: bool = False,
custom_prefix_token_id: int = None,
# send the requests as tokens or strings
tokenized_requests: bool = True,
trust_remote_code: bool = False,
revision: Optional[str] = "main",
revision: str | None = "main",
use_fast_tokenizer: bool = True,
verify_certificate: bool = True,
eos_string: str = None,
# timeout in seconds
timeout: int = 300,
header: Optional[Dict[str, str]] = None,
header: dict[str, str] | None = None,
max_images: int = 1,
**kwargs,
) -> None:
......@@ -232,12 +227,12 @@ class TemplateAPI(TemplateLM):
@abc.abstractmethod
def _create_payload(
self,
messages: Union[List[List[int]], List[dict], List[str], str],
messages: list[list[int]] | list[dict] | list[str] | str,
*,
generate: bool = True,
gen_kwargs: Optional[dict] = None,
gen_kwargs: dict | None = None,
seed: int = 1234,
eos: str = None,
eos: str | None = None,
**kwargs,
) -> dict:
"""This method is responsible for creating the json payload that will be sent to the API."""
......@@ -245,9 +240,9 @@ class TemplateAPI(TemplateLM):
def create_message(
self,
messages: Union[List[List[int]], List[str], List[JsonChatStr]],
messages: list[list[int]] | list[str] | list[JsonChatStr],
generate=False,
) -> Union[List[List[int]], List[dict], List[str], str]:
) -> list[list[int]] | list[dict] | list[str] | str:
"""Helper method to transform the prompt into the expected API input format. messages consist of batched requests"""
if isinstance(messages[0], JsonChatStr):
# for chat completions we need to decode the json string to list[dict,...]
......@@ -276,17 +271,17 @@ class TemplateAPI(TemplateLM):
@staticmethod
@abc.abstractmethod
def parse_logprobs(
outputs: Union[Any, List[Any]],
tokens: List[List[int]] = None,
ctxlen: List[int] = None,
outputs: Any | list[Any],
tokens: list[list[int]] | None = None,
ctxlen: list[int] | None = None,
**kwargs,
) -> List[Tuple[float, bool]]:
) -> list[tuple[float, bool]]:
"""Method used to parse the logprobs from the (batched) API response. This method should return a list of tuples"""
raise NotImplementedError
@staticmethod
@abc.abstractmethod
def parse_generations(outputs: Union[Any, List[Any]], **kwargs) -> List[str]:
def parse_generations(outputs: Any | list[Any], **kwargs) -> list[str]:
"""Method used to parse the generations from the (batched) API response. This method should return a list of str"""
raise NotImplementedError
......@@ -303,14 +298,15 @@ class TemplateAPI(TemplateLM):
@property
def tokenizer_name(self) -> str:
"""Must be defined for LM subclasses which implement Chat Templating.
Should return the name of the tokenizer or chat template used.
Used only to properly fingerprint caches when requests are being cached with `--cache_requests`, otherwise not used.
"""
return ""
def apply_chat_template(
self, chat_history: List[Dict[str, str]], add_generation_prompt: bool = True
) -> Union[str, JsonChatStr]:
self, chat_history: list[dict[str, str]], add_generation_prompt: bool = True
) -> str | JsonChatStr:
"""Applies a chat template to a list of chat history between user and model."""
if self.tokenizer_backend == "huggingface" and self.tokenized_requests:
return self.tokenizer.apply_chat_template(
......@@ -319,33 +315,32 @@ class TemplateAPI(TemplateLM):
add_generation_prompt=add_generation_prompt,
continue_final_message=not add_generation_prompt,
)
else:
# bit of a hack. We'll load back before sending to the API
return JsonChatStr(
json.dumps(
[{**item, "type": "text"} for item in chat_history],
ensure_ascii=False,
)
# bit of a hack. We'll load back before sending to the API
return JsonChatStr(
json.dumps(
[{**item, "type": "text"} for item in chat_history],
ensure_ascii=False,
)
)
@cached_property
def eot_token_id(self) -> Optional[int]:
def eot_token_id(self) -> int | None:
if self.tokenizer is None:
return None
else:
if self.tokenizer_backend == "huggingface":
return self.tokenizer.eos_token_id
elif self.tokenizer_backend == "tiktoken":
if self.tokenizer_backend == "tiktoken":
return self.tokenizer.eot_token
@cached_property
def eos_string(self) -> Optional[str]:
def eos_string(self) -> str | None:
if self._eos_string:
return self._eos_string
elif self.tokenizer is not None:
if self.tokenizer is not None:
if self.tokenizer_backend == "huggingface":
return self.tokenizer.eos_token
elif self.tokenizer_backend == "tiktoken":
if self.tokenizer_backend == "tiktoken":
return self.tokenizer.decode([self.tokenizer.eot_token])
else:
eval_logger.warning(
......@@ -354,7 +349,7 @@ class TemplateAPI(TemplateLM):
return None
@cached_property
def prefix_token_id(self) -> Optional[int]:
def prefix_token_id(self) -> int | None:
if self.tokenizer is None:
return None
else:
......@@ -364,24 +359,24 @@ class TemplateAPI(TemplateLM):
if self.tokenizer.bos_token_id is not None:
return self.tokenizer.bos_token_id
return self.tokenizer.eos_token_id
else:
return self.tokenizer.eot_token
return self.tokenizer.eot_token
def tok_encode(
self,
string: str,
left_truncate_len: int = None,
left_truncate_len: int | None = None,
add_special_tokens: bool = False,
truncation: bool = False,
**kwargs,
) -> Union[List[List[int]], List[int], List[str]]:
) -> list[list[int]] | list[int] | list[str]:
if self.tokenizer_backend is None:
return [string]
elif self.tokenizer_backend == "huggingface":
if self.tokenizer_backend == "huggingface":
# by default for CausalLM - false or self.add_bos_token is set
if not add_special_tokens:
add_special_tokens = False or self.add_bos_token
encoding: Union[List[List[int]], List[int]] = self.tokenizer(
encoding: list[list[int]] | list[int] = self.tokenizer(
string,
add_special_tokens=add_special_tokens,
truncation=truncation,
......@@ -404,20 +399,20 @@ class TemplateAPI(TemplateLM):
encoding = self.tokenizer.encode_batch(string)
return encoding
def decode_batch(self, tokens: List[List[int]]) -> List[str]:
def decode_batch(self, tokens: list[list[int]]) -> list[str] | None:
if self.tokenizer_backend == "huggingface":
return self.tokenizer.batch_decode(tokens)
elif self.tokenizer_backend == "tiktoken":
if self.tokenizer_backend == "tiktoken":
return self.tokenizer.decode_batch(tokens)
def model_call(
self,
messages: Union[List[List[int]], List[str], List[JsonChatStr]],
messages: list[list[int]] | list[str] | list[JsonChatStr],
*,
generate: bool = True,
gen_kwargs: Optional[Dict] = None,
gen_kwargs: dict | None = None,
**kwargs,
) -> Optional[dict]:
) -> dict | None:
# !!! Copy: shared dict for each request, need new object !!!
gen_kwargs = copy.deepcopy(gen_kwargs)
try:
......@@ -441,7 +436,7 @@ class TemplateAPI(TemplateLM):
response.raise_for_status()
return response.json()
except RetryError:
eval_logger.error(
eval_logger.exception(
"API request failed after multiple retries. Please check the API status."
)
return None
......@@ -450,14 +445,14 @@ class TemplateAPI(TemplateLM):
self,
session: ClientSession,
sem: asyncio.Semaphore,
messages: Union[List[List[int]], List[str], List[JsonChatStr]],
messages: list[list[int]] | list[str] | list[JsonChatStr],
*,
generate: bool = True,
cache_keys: list = None,
ctxlens: Optional[List[int]] = None,
gen_kwargs: Optional[Dict] = None,
cache_keys: list | None = None,
ctxlens: list[int] | None = None,
gen_kwargs: dict | None = None,
**kwargs,
) -> Union[List[str], List[Tuple[float, bool]], None]:
) -> list[str] | list[tuple[float, bool]] | None:
# !!! Copy: shared dict for each request, need new object !!!
gen_kwargs = copy.deepcopy(gen_kwargs)
payload = self._create_payload(
......@@ -508,8 +503,8 @@ class TemplateAPI(TemplateLM):
sem.release()
def batch_loglikelihood_requests(
self, chunks: Iterable[List[LogLikelihoodInputs]]
) -> Tuple[List[List[int]], List[int], List[Tuple[str, str]]]:
self, chunks: Iterable[list[LogLikelihoodInputs]]
) -> tuple[list[list[int]], list[int], list[tuple[str, str]]]:
inputs = []
ctxlens = []
cache_keys = []
......@@ -536,9 +531,9 @@ class TemplateAPI(TemplateLM):
cache_keys: list,
*,
generate: bool = True,
ctxlens: List[int] = None,
ctxlens: list[int] | None = None,
**kwargs,
) -> Union[List[List[str]], List[List[Tuple[float, bool]]]]:
) -> list[list[str]] | list[list[tuple[float, bool]]]:
ctxlens = ctxlens if ctxlens else [None] * len(requests)
conn = TCPConnector(limit=self._concurrent, ssl=self.verify_certificate)
sem = asyncio.Semaphore(self._concurrent)
......@@ -575,14 +570,14 @@ class TemplateAPI(TemplateLM):
return await tqdm_asyncio.gather(*tasks, desc="Requesting API")
def _loglikelihood_tokens(self, requests, **kwargs) -> List[Tuple[float, bool]]:
def _loglikelihood_tokens(self, requests, **kwargs) -> list[tuple[float, bool]]:
assert self.tokenizer is not None, (
"Tokenizer is required for loglikelihood tasks to compute context lengths."
)
res = []
def _collate(req: LogLikelihoodInputs):
"""Defines the key for the sorted method"""
"""Defines the key for the sorted method."""
# the negative sign on len(toks) sorts descending - this has a few advantages:
# - time estimates will always be over not underestimates, which is more useful for planning
# - to know the size of a batch when going through the list, you know the first one is always the batch
......@@ -639,8 +634,8 @@ class TemplateAPI(TemplateLM):
return re_ord.get_original(res)
def generate_until(
self, requests: List[Instance], disable_tqdm: bool = False
) -> List[str]:
self, requests: list[Instance], disable_tqdm: bool = False
) -> list[str]:
res = []
def _collate_gen(_requests):
......@@ -773,8 +768,8 @@ class TemplateAPI(TemplateLM):
return re_ord.get_original(res)
def loglikelihood_rolling(
self, requests: List[Instance], disable_tqdm: bool = False
) -> List[float]:
self, requests: list[Instance], disable_tqdm: bool = False
) -> list[float]:
loglikelihoods = []
for (string,) in tqdm([req.args for req in requests], disable=disable_tqdm):
......
from collections.abc import Generator
from contextlib import contextmanager
from functools import partial
from pathlib import Path
from typing import Any, Callable, Generator, Optional, Union
from typing import Any, Callable, Optional, Union
import torch
from peft.peft_model import PeftModel
......@@ -71,13 +72,6 @@ class SteeredModel(HFLM):
"""
HFLM with a steered forward pass.
To derive steering vectors from a sparse model loadable with sparsify or sae_lens,
provide the path to a CSV file with the following columns (example rows are provided below):
loader,action,sparse_model,hookpoint,feature_index,steering_coefficient,sae_id,description,
sparsify,add,EleutherAI/sae-pythia-70m-32k,layers.3,30,10.0,,,
sae_lens,add,gemma-scope-2b-pt-res-canonical,layers.20,12082,240.0,layer_20/width_16k/canonical,increase dogs,
To load steering vectors directly, provide the path to a pytorch (.pt) file with content in the following format:
{
......@@ -86,9 +80,17 @@ class SteeredModel(HFLM):
"steering_coefficient": <float>,
"action": <Literal["add", "clamp"]>,
"bias": <torch.Tensor | None>,
"head_index": <int | None>,
},
...
}
To derive steering vectors from a sparse model loadable with sparsify or sae_lens,
provide the path to a CSV file with the following columns (example rows are provided below):
loader,action,sparse_model,hookpoint,feature_index,steering_coefficient,head_index,sae_id,description,
sparsify,add,EleutherAI/sae-pythia-70m-32k,layers.3,30,10.0,,,,
sae_lens,add,gemma-scope-2b-pt-res-canonical,layers.20,12082,240.0,,layer_20/width_16k/canonical,increase dogs,
"""
super().__init__(pretrained=pretrained, device=device, **kwargs)
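For reference, a hedged sketch of producing a steering file in the documented .pt layout (the hookpoint name, feature width, and file name are made up for illustration):
import torch

steer_config = {
    "layers.3": {
        "steering_vector": torch.randn(512),  # [features]
        "steering_coefficient": 10.0,
        "action": "add",                      # or "clamp"
        "bias": None,
        "head_index": None,
    },
}
torch.save(steer_config, "steer_config.pt")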
......@@ -105,27 +107,31 @@ class SteeredModel(HFLM):
hook_to_steer = {}
for hookpoint, steer_info in steer_config.items():
action = steer_info["action"]
steering_coefficient = steer_info["steering_coefficient"]
steering_vector = (
steer_info["steering_vector"].to(self.device).to(self.model.dtype)
)
bias = (
steer_info["bias"].to(self.device).to(self.model.dtype)
if steer_info["bias"] is not None
else None
)
steering_coefficient = float(steer_info.get("steering_coefficient", 1.0))
head_index = steer_info.get("head_index", None)
bias = steer_info.get("bias", None)
if bias is not None:
bias = bias.to(self.device).to(self.model.dtype)
if action == "add":
# Steers the model by adding some multiple of a steering vector to all sequence positions.
hook_to_steer[hookpoint] = (
lambda acts: acts + steering_coefficient * steering_vector
# Steer the model by adding a multiple of a steering vector to all sequence positions.
assert bias is None, "Bias is not supported for the `add` action."
hook_to_steer[hookpoint] = partial(
self.add,
vector=steering_vector * steering_coefficient,
head_index=head_index,
)
elif action == "clamp":
# Steer the model by clamping the activations to a value in the direction of the steering vector.
hook_to_steer[hookpoint] = partial(
self.clamp,
steering_vector=steering_vector,
direction=steering_vector / torch.norm(steering_vector),
value=steering_coefficient,
bias=bias,
head_index=head_index,
)
else:
raise ValueError(f"Unknown hook type: {action}")
......@@ -195,34 +201,62 @@ class SteeredModel(HFLM):
return steer_data
@classmethod
def add(
cls,
acts: Tensor,
vector: Tensor,
head_index: Optional[int],
):
"""Adds the given vector to the activations.
Args:
acts (Tensor): The activations tensor to edit of shape [batch, pos, ..., features]
vector (Tensor): A vector to add of shape [features]
head_index (int | None): Optional attention head index to add the vector to
"""
if head_index is not None:
acts[:, :, head_index, :] = acts[:, :, head_index, :] + vector
else:
acts = acts + vector
return acts
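A quick sanity check of both branches above, assuming the class is importable as SteeredModel (shapes are arbitrary):
import torch

acts = torch.zeros(2, 4, 8, 16)   # [batch, pos, heads, features]
vector = torch.ones(16)

# Without a head index the vector is broadcast over every position and head.
out_all = SteeredModel.add(acts.clone(), vector, head_index=None)
assert torch.all(out_all == 1.0)

# With a head index only that head's activations change.
out_head = SteeredModel.add(acts.clone(), vector, head_index=3)
assert torch.all(out_head[:, :, 3, :] == 1.0)
assert torch.all(out_head[:, :, :3, :] == 0.0)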
@classmethod
def clamp(
cls,
acts: Tensor,
steering_vector: Tensor,
direction: Tensor,
value: float,
head_index: Optional[int],
bias: Optional[Tensor] = None,
):
"""Clamps a direction of the activations to be the steering vector * the value.
"""Clamps the activations to a given value in a specified direction. The direction
must be a unit vector.
Args:
acts (Tensor): The activations tensor to edit of shape [batch, pos, features]
steering_vector (Tensor): A direction to clamp of shape [features]
acts (Tensor): The activations tensor to edit of shape [batch, pos, ..., features]
direction (Tensor): A direction to clamp of shape [features]
value (float): Value to clamp the direction to
head_index (int | None): Optional attention head index to clamp
bias (Tensor | None): Optional bias to add to the activations
Returns:
Tensor: The modified activations with the specified direction clamped
"""
if bias is not None:
acts = acts - bias
direction = steering_vector / torch.norm(steering_vector)
proj_magnitude = torch.sum(acts * direction, dim=-1, keepdim=True)
orthogonal_component = acts - proj_magnitude * direction
clamped = orthogonal_component + direction * value
if head_index is not None:
    x = acts[:, :, head_index, :]
    proj = (x * direction).sum(dim=-1, keepdim=True)
    assert torch.allclose(proj.squeeze(-1), x @ direction)
    clamped = acts.clone()
    clamped[:, :, head_index, :] = x + direction * (value - proj)
else:
    proj = torch.sum(acts * direction, dim=-1, keepdim=True)
    clamped = acts + direction * (value - proj)
if bias is not None:
return clamped + bias
......
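A small numeric check of the clamp projection above, done inline rather than through the class (values are arbitrary):
import torch

direction = torch.tensor([1.0, 0.0])    # unit vector
acts = torch.tensor([[[3.0, 5.0]]])     # [batch=1, pos=1, features=2]
value = 2.0

proj = torch.sum(acts * direction, dim=-1, keepdim=True)  # component along direction: 3.0
clamped = acts + direction * (value - proj)               # replace that component with `value`
assert torch.allclose(clamped, torch.tensor([[[2.0, 5.0]]]))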
......@@ -124,14 +124,22 @@ class HFLM(TemplateLM):
assert isinstance(pretrained, str)
assert isinstance(batch_size, (int, str))
gpus = torch.cuda.device_count()
accelerator_kwargs = InitProcessGroupKwargs(timeout=timedelta(weeks=52))
accelerator = Accelerator(kwargs_handlers=[accelerator_kwargs])
if accelerator.num_processes > 1:
self.accelerator = accelerator
if "npu" in accelerator.device.type:
# Detect device count based on accelerator device type
device_type = accelerator.device.type
if "cuda" in device_type:
gpus = torch.cuda.device_count()
elif "npu" in device_type:
gpus = torch.npu.device_count()
elif "xpu" in device_type:
gpus = torch.xpu.device_count()
else:
# Fallback to CUDA count for compatibility
gpus = torch.cuda.device_count()
# using one process with no model parallelism
if not (parallelize or accelerator.num_processes > 1):
......@@ -141,6 +149,7 @@ class HFLM(TemplateLM):
+ [f"cuda:{i}" for i in range(gpus)]
+ ["mps", "mps:0"]
+ [f"npu:{i}" for i in range(gpus)]
+ [f"xpu:{i}" for i in range(gpus)]
)
if device and device in device_list:
self._device = torch.device(device)
......@@ -673,17 +682,25 @@ class HFLM(TemplateLM):
)
if peft:
from peft import PeftModel
from peft import __version__ as PEFT_VERSION
from peft import PeftModel, __version__ as PEFT_VERSION
if model_kwargs.get("load_in_4bit") and vparse(PEFT_VERSION) < vparse(
"0.4.0"
):
raise AssertionError("load_in_4bit requires peft >= 0.4.0")
if self._model.config.vocab_size != len(self.tokenizer):
# Handle multimodal configs (e.g. Gemma3) that nest vocab_size under text_config, as well as plain text-only models
if hasattr(self._model.config, "text_config") and hasattr(
self._model.config.text_config, "vocab_size"
):
vocab_size = self._model.config.text_config.vocab_size
else:
vocab_size = self._model.config.vocab_size
if vocab_size != len(self.tokenizer):
# resize model for LoRAs with added tokens
eval_logger.info(
f"Model config indicates vocab_size='{self._model.config.vocab_size}', but found tokenizer with vocab size '{len(self.tokenizer)}'. Resizing model embedding layer..."
f"Model config indicates vocab_size='{vocab_size}', but found tokenizer with vocab size '{len(self.tokenizer)}'. Resizing model embedding layer..."
)
self._model.resize_token_embeddings(len(self.tokenizer))
self._model = PeftModel.from_pretrained(
......
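The vocab-size lookup introduced above could equally be factored into a tiny helper; a hedged sketch (the helper name is ours, not part of the harness):
def resolve_vocab_size(config) -> int:
    # Multimodal configs such as Gemma3 nest the text vocab under `text_config`;
    # plain language-model configs expose `vocab_size` at the top level.
    text_config = getattr(config, "text_config", None)
    if text_config is not None and hasattr(text_config, "vocab_size"):
        return text_config.vocab_size
    return config.vocab_size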
......@@ -3,7 +3,7 @@ import json
import logging
import os
import warnings
from functools import lru_cache
from functools import cache
from typing import Any, Dict, List, NamedTuple, Optional, Tuple, Type, cast
from tqdm import tqdm
......@@ -69,7 +69,7 @@ def _verify_credentials(creds: dict) -> None:
raise ValueError(error_msg)
@lru_cache(maxsize=None)
@cache
def get_watsonx_credentials() -> Dict[str, str]:
"""
Retrieves Watsonx API credentials from environment variables.
......
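`functools.cache` (Python 3.9+) is an unbounded memoizer equivalent to `lru_cache(maxsize=None)`, so the credentials are read from the environment only once per process; a minimal illustration with a placeholder function:
from functools import cache

@cache
def load_credentials() -> dict:
    print("reading environment")   # runs only on the first call
    return {"url": "placeholder", "apikey": "placeholder"}

load_credentials()  # prints
load_credentials()  # served from the cache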