Commit 930b4253 authored by Baber

Merge branch 'smolrefact' into lazy_reg

# Conflicts:
#	lm_eval/__init__.py
#	lm_eval/api/metrics.py
#	lm_eval/api/registry.py
#	lm_eval/api/task.py
#	lm_eval/filters/__init__.py
#	pyproject.toml
parents d547b663 73202a2e
import json
import logging
import textwrap
from argparse import Namespace
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
import yaml
from lm_eval.utils import simple_parse_args_string
if TYPE_CHECKING:
from lm_eval.tasks import TaskManager
eval_logger = logging.getLogger(__name__)
DICT_KEYS = [
"wandb_args",
"wandb_config_args",
"hf_hub_log_args",
"metadata",
"model_args",
"gen_kwargs",
]
@dataclass
class EvaluatorConfig:
"""Configuration for language model evaluation runs.
This dataclass contains all parameters for configuring model evaluations via
`simple_evaluate()` or the CLI. It supports initialization from:
- CLI arguments (via `from_cli()`)
- YAML configuration files (via `from_config()`)
- Direct instantiation with keyword arguments
    The configuration handles argument parsing, validation, and preprocessing
    to ensure the resulting settings are properly structured and validated.
Example:
# From CLI arguments
config = EvaluatorConfig.from_cli(args)
# From YAML file
config = EvaluatorConfig.from_config("eval_config.yaml")
# Direct instantiation
config = EvaluatorConfig(
model="hf",
model_args={"pretrained": "gpt2"},
tasks=["hellaswag", "arc_easy"],
num_fewshot=5
)
See individual field documentation for detailed parameter descriptions.
"""
# Core evaluation parameters
config: Optional[str] = field(
default=None, metadata={"help": "Path to YAML config file"}
)
model: str = field(default="hf", metadata={"help": "Name of model e.g. 'hf'"})
model_args: dict = field(
default_factory=dict, metadata={"help": "Arguments for model initialization"}
)
tasks: Union[str, list[str]] = field(
default_factory=list,
metadata={"help": "Comma-separated list of task names to evaluate"},
)
# Few-shot and batching
num_fewshot: Optional[int] = field(
default=None, metadata={"help": "Number of examples in few-shot context"}
)
batch_size: int = field(default=1, metadata={"help": "Batch size for evaluation"})
max_batch_size: Optional[int] = field(
default=None, metadata={"help": "Maximum batch size for auto batching"}
)
# Device
device: Optional[str] = field(
default="cuda:0", metadata={"help": "Device to use (e.g. cuda, cuda:0, cpu)"}
)
# Data sampling and limiting
limit: Optional[float] = field(
default=None, metadata={"help": "Limit number of examples per task"}
)
samples: Union[str, dict, None] = field(
default=None,
metadata={"help": "dict, JSON string or path to JSON file with doc indices"},
)
# Caching
use_cache: Optional[str] = field(
default=None,
metadata={"help": "Path to sqlite db file for caching model outputs"},
)
cache_requests: dict = field(
default_factory=dict,
metadata={"help": "Cache dataset requests: true/refresh/delete"},
)
# Output and logging flags
check_integrity: bool = field(
default=False, metadata={"help": "Run test suite for tasks"}
)
write_out: bool = field(
default=False, metadata={"help": "Print prompts for first few documents"}
)
log_samples: bool = field(
default=False, metadata={"help": "Save model outputs and inputs"}
)
output_path: Optional[str] = field(
default=None, metadata={"help": "Dir path where result metrics will be saved"}
)
predict_only: bool = field(
default=False,
metadata={
"help": "Only save model outputs, don't evaluate metrics. Use with log_samples."
},
)
# Chat and instruction handling
system_instruction: Optional[str] = field(
default=None, metadata={"help": "Custom System instruction to add"}
)
apply_chat_template: Union[bool, str] = field(
default=False,
metadata={
"help": "Apply chat template to prompt. Either True, or a string identifying the tokenizer template."
},
)
fewshot_as_multiturn: bool = field(
default=False,
metadata={
"help": "Use fewshot as multi-turn conversation. Requires apply_chat_template=True."
},
)
# Configuration display
show_config: bool = field(
default=False, metadata={"help": "Show full config at end of evaluation"}
)
# External tasks and generation
include_path: Optional[str] = field(
default=None, metadata={"help": "Additional dir path for external tasks"}
)
gen_kwargs: Optional[dict] = field(
default=None, metadata={"help": "Arguments for model generation"}
)
# Logging and verbosity
verbosity: Optional[str] = field(
default=None, metadata={"help": "Logging verbosity level"}
)
# External integrations
wandb_args: dict = field(
default_factory=dict, metadata={"help": "Arguments for wandb.init"}
)
wandb_config_args: dict = field(
default_factory=dict, metadata={"help": "Arguments for wandb.config.update"}
)
hf_hub_log_args: dict = field(
default_factory=dict, metadata={"help": "Arguments for HF Hub logging"}
)
# Reproducibility
seed: list = field(
default_factory=lambda: [0, 1234, 1234, 1234],
metadata={"help": "Seeds for random, numpy, torch, fewshot (random)"},
)
# Security
trust_remote_code: bool = field(
default=False, metadata={"help": "Trust remote code for HF datasets"}
)
confirm_run_unsafe_code: bool = field(
default=False,
metadata={
"help": "Confirm understanding of unsafe code risks (for code tasks that executes arbitrary Python)"
},
)
# Internal metadata
metadata: dict = field(
default_factory=dict,
metadata={"help": "Additional metadata for tasks that require it"},
)
@classmethod
def from_cli(cls, namespace: Namespace) -> "EvaluatorConfig":
"""
        Build an EvaluatorConfig by merging with simple precedence:
CLI args > YAML config > built-in defaults
"""
# Start with built-in defaults
config = asdict(cls())
# Load and merge YAML config if provided
if used_config := hasattr(namespace, "config") and namespace.config:
config.update(cls.load_yaml_config(namespace.config))
# Override with CLI args (only truthy values, exclude non-config args)
excluded_args = {"command", "func"} # argparse internal args
cli_args = {
k: v for k, v in vars(namespace).items() if v and k not in excluded_args
}
config.update(cli_args)
# Parse string arguments that should be dictionaries
config = cls._parse_dict_args(config)
# Create instance and validate
instance = cls(**config)
if used_config:
print(textwrap.dedent(f"""{instance}"""))
instance.configure()
return instance
@classmethod
def from_config(cls, config_path: Union[str, Path]) -> "EvaluatorConfig":
"""
        Build an EvaluatorConfig from a YAML config file.
Merges with built-in defaults and validates.
"""
# Load YAML config
yaml_config = cls.load_yaml_config(config_path)
# Parse string arguments that should be dictionaries
yaml_config = cls._parse_dict_args(yaml_config)
instance = cls(**yaml_config)
instance.configure()
return instance
@staticmethod
def _parse_dict_args(config: Dict[str, Any]) -> Dict[str, Any]:
"""Parse string arguments that should be dictionaries."""
for key in config:
if key in DICT_KEYS and isinstance(config[key], str):
config[key] = simple_parse_args_string(config[key])
return config
@staticmethod
def load_yaml_config(config_path: Union[str, Path]) -> Dict[str, Any]:
"""Load and validate YAML config file."""
config_file = (
Path(config_path) if not isinstance(config_path, Path) else config_path
)
if not config_file.is_file():
raise FileNotFoundError(f"Config file not found: {config_path}")
try:
yaml_data = yaml.safe_load(config_file.read_text())
except yaml.YAMLError as e:
raise ValueError(f"Invalid YAML in {config_path}: {e}")
except (OSError, UnicodeDecodeError) as e:
raise ValueError(f"Could not read config file {config_path}: {e}")
if not isinstance(yaml_data, dict):
raise ValueError(
f"YAML root must be a mapping, got {type(yaml_data).__name__}"
)
return yaml_data
def configure(self) -> None:
"""Validate configuration and preprocess fields after creation."""
self._validate_arguments()
self._process_arguments()
self._set_trust_remote_code()
def _validate_arguments(self) -> None:
"""Validate configuration arguments and cross-field constraints."""
if self.limit:
eval_logger.warning(
"--limit SHOULD ONLY BE USED FOR TESTING. "
"REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
)
# predict_only implies log_samples
if self.predict_only:
self.log_samples = True
# log_samples or predict_only requires output_path
if (self.log_samples or self.predict_only) and not self.output_path:
raise ValueError(
"Specify --output_path if providing --log_samples or --predict_only"
)
# fewshot_as_multiturn requires apply_chat_template
if self.fewshot_as_multiturn and self.apply_chat_template is False:
raise ValueError(
"When `fewshot_as_multiturn` is selected, `apply_chat_template` must be set."
)
# samples and limit are mutually exclusive
if self.samples and self.limit is not None:
raise ValueError("If --samples is not None, then --limit must be None.")
# tasks is required
        if not self.tasks:
            raise ValueError("Need to specify a task to evaluate.")
def _process_arguments(self) -> None:
"""Process samples argument - load from file if needed."""
if self.samples:
            if isinstance(self.samples, str):
try:
self.samples = json.loads(self.samples)
except json.JSONDecodeError:
if (samples_path := Path(self.samples)).is_file():
self.samples = json.loads(samples_path.read_text())
# Set up metadata by merging model_args and metadata.
if self.model_args is None:
self.model_args = {}
if self.metadata is None:
self.metadata = {}
self.metadata = self.model_args | self.metadata
def process_tasks(self, metadata: Optional[dict] = None) -> "TaskManager":
"""Process and validate tasks, return resolved task names."""
from lm_eval import utils
from lm_eval.tasks import TaskManager
# if metadata manually passed use that:
self.metadata = metadata if metadata else self.metadata
# Create task manager with metadata
task_manager = TaskManager(
include_path=self.include_path,
metadata=self.metadata if self.metadata else {},
)
task_names = task_manager.match_tasks(self.tasks)
# Check for any individual task files in the list
for task in [task for task in self.tasks if task not in task_names]:
task_path = Path(task)
if task_path.is_file():
config = utils.load_yaml_config(str(task_path))
task_names.append(config)
# Check for missing tasks
task_missing = [
task for task in self.tasks if task not in task_names and "*" not in task
]
if task_missing:
missing = ", ".join(task_missing)
raise ValueError(f"Tasks not found: {missing}")
# Update tasks with resolved names
self.tasks = task_names
return task_manager
def _set_trust_remote_code(self) -> None:
"""Apply trust_remote_code setting if enabled."""
if self.trust_remote_code:
# HACK: import datasets and override its HF_DATASETS_TRUST_REMOTE_CODE value internally,
# because it's already been determined based on the prior env var before launching our
# script--`datasets` gets imported by lm_eval internally before these lines can update the env.
import datasets
datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
# Add to model_args for the actual model initialization
if self.model_args is None:
self.model_args = {}
self.model_args["trust_remote_code"] = True
from __future__ import annotations
from collections.abc import Callable, Mapping
from dataclasses import dataclass, field
from functools import cached_property
from typing import Any
@dataclass
class MetricConfig:
"""Encapsulates information about a single metric."""
name: str
fn: Callable
kwargs: Mapping[str, Any] = field(default_factory=dict)
aggregation_fn: Callable | None = None
    higher_is_better: bool | None = True
hf_evaluate: bool = False
is_elementwise: bool = True
@cached_property
def metric_name(self) -> str:
return self.name
@cached_property
def aggregation(self) -> Callable[..., Any] | None:
from lm_eval.api.registry import get_aggregation
if self.aggregation_fn is None:
return get_aggregation(self.name)
return self.aggregation_fn
@cached_property
def _higher_is_better(self) -> bool | None:
from lm_eval.api.registry import is_higher_better
if self.higher_is_better is None:
return is_higher_better(self.name)
return self.higher_is_better
def compute(self, *args, **kwargs) -> Any:
"""Calculates the metric using the provided function and arguments."""
if self.fn is None:
raise ValueError(f"Metric function for {self.name} is not defined.")
return self.fn(*args, **{**(self.kwargs or {}), **kwargs})
def compute_aggregation(self, *args, **kwargs) -> Any:
"""Computes the aggregation of the metric values."""
if self.aggregation_fn is None:
raise ValueError(f"Aggregation function for {self.name} is not defined.")
return self.aggregation_fn(*args, **kwargs)
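
# --- Illustrative sketch (not part of this commit): wiring a MetricConfig by hand. ---
# The exact-match function and mean aggregation below are assumptions, used only to show
# how compute() and compute_aggregation() are meant to be called.
def _exact_match(pred: str, target: str) -> float:
    return float(pred.strip() == target.strip())

_em = MetricConfig(
    name="exact_match",
    fn=_exact_match,
    aggregation_fn=lambda scores: sum(scores) / len(scores),  # simple mean
)
_scores = [_em.compute("Paris", " Paris"), _em.compute("Rome", "Paris")]
assert _em.compute_aggregation(_scores) == 0.5
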
from __future__ import annotations
import logging
from collections.abc import Iterable
from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Union
import datasets
from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.instance import OutputType
from lm_eval.config.metric import MetricConfig
from lm_eval.config.utils import maybe_serialize
if TYPE_CHECKING:
from lm_eval.api.samplers import ContextSampler
from lm_eval.api.task import Task
from lm_eval.config.template import TemplateConfig
eval_logger = logging.getLogger(__name__)
DataSet = Union[datasets.Dataset, Iterable[dict[str, Any]]]
DSplits = dict[str, DataSet]
@dataclass
class RepeatConfig:
"""Encapsulates information about a single repeat."""
repeats: int = 1
metric_fn: str | Callable = "pass@N"
kwargs: dict | None = field(default_factory=dict)
@dataclass
class FilterConfig:
"""Encapsulates information about a single filter pipeline."""
name: str
ensemble: FilterEnsemble
metric_list: list[MetricConfig]
@dataclass
class FewshotConfig:
# hack: this returns task.config.num_fewshot
# to keep in sync as it is runtime-modified
num_fewshot: Callable[[], int]
split: str | None = None
sampler: str | Callable = "default"
samples: Callable[[], DataSet] | DataSet | None = None
process_docs: Callable[[DataSet], DataSet] | None = None
fewshot_indices: list[int] | None = None
rnd: int = field(init=False, default=False)
def __post_init__(self) -> None:
if self.samples is not None and not (
isinstance(self.samples, list) or callable(self.samples)
):
raise TypeError(
"samples must be either list[dict] or callable returning list[dict]"
)
if self.split is not None and self.samples is not None:
eval_logger.warning(
"Both split and samples are configured; split will take precedence"
)
@property
def has_source(self) -> bool:
"""Check if any fewshot source is configured."""
return self.split is not None or self.samples is not None
def _get_raw_docs(self, dataset: DSplits) -> DataSet | None:
"""Get raw documents from configured source."""
if self.split is not None:
return dataset[self.split]
if self.samples is not None:
if isinstance(self.samples, list):
return self.samples
elif callable(self.samples):
# If samples is a callable, it should return a list of dicts
return self.samples()
else:
raise TypeError(
"samples must be either a list of dicts or a callable returning a list"
)
def get_docs(self, dataset) -> DataSet | None:
"""Get processed documents from configured source."""
raw_docs = self._get_raw_docs(dataset)
if raw_docs is None:
return None
if self.process_docs is not None:
return self.process_docs(raw_docs)
return raw_docs
@property
def get_sampler(self) -> Callable[..., Any] | None:
from lm_eval.api import samplers
if isinstance(self.sampler, str):
return samplers.get_sampler(self.sampler)
elif callable(self.sampler):
return self.sampler
def init_sampler(
self, docs: list[dict], task: Task, rnd=None, fewshot_indices=None
) -> ContextSampler:
"""Initialize the sampler with the given documents and task."""
if rnd is None:
raise ValueError(
"A `random.Random` generator argument must be provided to `rnd` of FewShotSampler!"
)
return self.get_sampler(
docs,
task,
rnd=rnd,
fewshot_indices=fewshot_indices
if fewshot_indices
else self.fewshot_indices,
)
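
# --- Illustrative sketch (not part of this commit): how FewshotConfig resolves docs. ---
# The toy dataset below is an assumption; `split` takes precedence over `samples`, and
# `process_docs` (if set) post-processes whatever source was chosen.
def _demo_fewshot_docs() -> list[dict]:
    cfg = FewshotConfig(
        num_fewshot=lambda: 2,
        split="train",
        process_docs=lambda docs: [d for d in docs if d["answer"] is not None],
    )
    dataset = {"train": [{"q": "1+1?", "answer": "2"}, {"q": "2+2?", "answer": None}]}
    assert cfg.has_source
    return cfg.get_docs(dataset)  # -> [{'q': '1+1?', 'answer': '2'}]
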
@dataclass
class TaskConfig:
# task naming/registry
task: str | None = None
task_alias: str | None = None
tag: str | list | None = None
# HF dataset options.
# which dataset to use,
# and what splits for what purpose
custom_dataset: Callable[..., DataSet] | None = None
dataset_path: str | None = None
dataset_name: str | None = None
dataset_kwargs: dict | None = field(default_factory=dict)
training_split: str | None = None
validation_split: str | None = None
test_split: str | None = None
fewshot_split: str | None = None
# formatting / prompting options.
# see docs/advanced_task_guide.md for more info
process_docs: Callable[[DataSet], DataSet] | None = None
doc_to_text: Callable[[dict[str, Any]], Any] | str | None = None
doc_to_target: Callable[[dict[str, Any]], Any] | str | None = None
doc_to_image: Callable[[dict[str, Any]], Any] | str | None = None
doc_to_audio: Callable[[dict[str, Any]], Any] | str | None = None
unsafe_code: bool = False
doc_to_choice: Callable[[dict[str, Any]], Any] | str | dict | list | None = None
process_results: (
Callable[[dict[str, Any], list[Any]], dict[str, Any]] | str | None
) = None
use_prompt: str | None = None
description: str = ""
target_delimiter: str = " "
fewshot_delimiter: str = "\n\n"
fewshot_config: dict[str, Any] | None = None
# runtime configuration options
num_fewshot: int | None = None
generation_kwargs: dict[str, Any] | None = None
# scoring options
metric_list: list | None = None
output_type: OutputType = "generate_until"
repeats: int = 1
filter_list: list[dict] | None = None
should_decontaminate: bool = False
doc_to_decontamination_query: str | None = None
gen_prefix: str | None = None
multiple_input: bool = False
metadata: dict = field(
default_factory=dict
) # by default, not used in the code. allows for users to pass arbitrary info to tasks
_metric_list: list[MetricConfig] = field(default_factory=list)
_filter_list: list[FilterConfig] = field(default_factory=list)
# ds_cfg: DatasetConfig = field(init=False)
fewshot_cfg: FewshotConfig = field(init=False)
_fn: dict[str, Callable] = field(default_factory=dict)
def __post_init__(self) -> None:
### ---setup generation kwargs--- ###
if self.generation_kwargs is not None:
if self.output_type != "generate_until":
eval_logger.warning(
f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!"
)
if "temperature" in self.generation_kwargs:
self.generation_kwargs["temperature"] = float(
self.generation_kwargs["temperature"]
)
if "until" not in self.generation_kwargs:
eval_logger.warning(
f"{self.task}: No `until` specified in `generation_kwargs`! Defaulting to the fewshot_delimiter={repr(self.fewshot_delimiter)}"
)
self.generation_kwargs["until"] = [self.fewshot_delimiter]
else:
if self.output_type == "generate_until":
# ensure that we greedily generate in absence of explicit arguments otherwise
self.generation_kwargs = {
"until": (
None
if self.fewshot_delimiter is None
else [self.fewshot_delimiter]
),
"do_sample": False,
"temperature": 0,
}
eval_logger.warning(
f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}"
)
# ---setup fewshot config--- #
_fewshot_cfg = self.fewshot_config if self.fewshot_config is not None else {}
self.fewshot_cfg = FewshotConfig(
num_fewshot=lambda: self.num_fewshot or _fewshot_cfg.get("num_fewshot", 0),
split=self.fewshot_split,
sampler=_fewshot_cfg.get("sampler", "default"),
samples=_fewshot_cfg.get("samples", None),
process_docs=_fewshot_cfg.get("process_docs", None),
fewshot_indices=_fewshot_cfg.get("fewshot_indices", None),
)
def _get_metric(self, metric_list: list[dict] | None = None) -> list[MetricConfig]:
from lm_eval.api.registry import (
AGGREGATION_REGISTRY,
DEFAULT_METRIC_REGISTRY,
get_aggregation,
get_metric,
get_metric_aggregation,
is_higher_better,
)
# if metric_list defined inside a filter, use that; otherwise use the task's metric_list
metric_list = metric_list or self.metric_list
metrics = []
if not metric_list:
# ---------- 1. If no metrics defined, use defaults for output type ----------
_metric_list = DEFAULT_METRIC_REGISTRY[self.output_type]
eval_logger.info(
f"No metrics defined in config, using default metrics for {self.output_type}={_metric_list}"
)
metrics.extend(
MetricConfig(
name=metric_name,
fn=get_metric(metric_name),
aggregation_fn=get_metric_aggregation(metric_name),
                    higher_is_better=(
                        is_higher_better(metric_name)
                        if is_higher_better(metric_name) is not None
                        else True
                    ),
)
for metric_name in _metric_list
)
else:
# ---------- 2. Process user-defined metrics from config ----------
for metric_config in metric_list:
metric_name = metric_config["metric"]
_metric_fn_kwargs = {
key: metric_config[key]
for key in metric_config
if key
not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"]
}
_hf_evaluate_metric: bool = metric_config.get("hf_evaluate", False)
_metric_fn = None
_aggregation = None
if self.process_results is not None:
# User will compute metrics inside `process_results()`
_metric_name = None
_metric_fn_kwargs = {}
elif callable(metric_name):
# User passed a function object
_metric_name = metric_name.__name__
_metric_fn = metric_name.__call__
else:
# Normal: look up by name
_metric_name = metric_name
_metric_fn = get_metric(metric_name, _hf_evaluate_metric)
# ---------- 3. Decide how to aggregate examples ----------
if "aggregation" in metric_config:
if isinstance(_agg_name := metric_config["aggregation"], str):
_aggregation = get_aggregation(_agg_name)
elif callable(_agg_name): # noqa: E721
_aggregation = metric_config["aggregation"]
else:
INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
_aggregation = get_metric_aggregation(metric_name)
eval_logger.warning(
f"[Task: {self.task}] metric {metric_name} is defined, but aggregation is not. "
f"using default "
f"aggregation={INV_AGG_REGISTRY[_aggregation]}"
)
# ---------- 4. Determine “higher-is-better” semantics ----------
if "higher_is_better" in metric_config:
_higher_is_better = metric_config["higher_is_better"]
else:
eval_logger.warning(
f"[Task: {self.task}] metric {metric_name} is defined, but higher_is_better is not. "
f"using default "
f"higher_is_better={is_higher_better(metric_name)}"
)
_higher_is_better = is_higher_better(metric_name)
metrics.append(
MetricConfig(
name=_metric_name,
fn=_metric_fn,
kwargs=_metric_fn_kwargs,
aggregation_fn=_aggregation,
higher_is_better=_higher_is_better,
hf_evaluate=_hf_evaluate_metric,
)
)
for m in metrics:
if m not in self._metric_list:
self._metric_list.append(m)
return metrics
@property
def get_filters(self) -> list[FilterConfig]:
from lm_eval.filters import build_filter_ensemble
if not self.filter_list:
eval_logger.debug(
"No custom filters defined; falling back to 'take_first' for handling repeats."
)
return [
FilterConfig(
name="none",
ensemble=build_filter_ensemble("none", [("take_first", None)]),
metric_list=self._get_metric(metric_list=None),
)
]
else:
def _strip_fn(d: dict) -> tuple[str, dict]:
return d["function"], {
k: v for k, v in d.items() if k not in ["function", "metric_list"]
}
configs = (
self.filter_list.values()
if isinstance(self.filter_list, dict)
else self.filter_list
)
x = [
FilterConfig(
name=cfg["name"],
ensemble=build_filter_ensemble(
filter_name=cfg["name"],
components=[_strip_fn(f) for f in cfg["filter"]],
),
metric_list=self._get_metric(metric_list=cfg.get("metric_list")),
)
for cfg in configs
]
return x
@classmethod
def from_yaml(cls, data: dict[str, Any]) -> TaskConfig:
"""Create a TaskConfig instance from a YAML-like dictionary."""
fn = {k: v for k, v in data.items() if callable(v)}
return cls(**data, _fn=fn)
@classmethod
def from_template(cls, template: TemplateConfig, **kwargs) -> TaskConfig:
"""Create a TaskConfig instance from a template.
Args:
template: TemplateConfig instance (MCQTemplateConfig or ClozeTemplateConfig)
**kwargs: Additional arguments to override template defaults
Returns:
TaskConfig instance configured from the template
"""
from lm_eval.config.template import (
ClozeTemplateConfig,
MCQTemplateConfig,
)
# Extract base configuration from template
config_dict = {
"task": template.task,
"doc_to_text": template.doc_to_text,
"doc_to_choice": template.doc_to_choice,
"doc_to_target": template.doc_to_target,
"description": template.description,
"target_delimiter": template.target_delimiter,
"fewshot_delimiter": template.fewshot_delimiter,
"metric_list": template.metric_list,
}
# Add common template attributes if they exist
if hasattr(template, "answer_suffix"):
config_dict["target_delimiter"] = (
template.answer_suffix + template.target_delimiter
)
# Handle template-specific configurations
if isinstance(template, MCQTemplateConfig):
# For MCQ templates, set up multiple choice specific config
config_dict["output_type"] = "multiple_choice"
# MCQ templates typically use accuracy metrics
if template.metric_list is None:
config_dict["metric_list"] = [{"metric": "acc"}]
elif isinstance(template, ClozeTemplateConfig):
# For Cloze templates, set up generation config
config_dict["output_type"] = "generate_until"
# Cloze templates typically use accuracy and normalized accuracy
if template.metric_list is None:
config_dict["metric_list"] = [{"metric": "acc"}, {"metric": "acc_norm"}]
else:
# Generic template - try to infer output type
if hasattr(template, "template"):
if template.template == "mcq":
config_dict["output_type"] = "multiple_choice"
elif template.template == "cloze":
config_dict["output_type"] = "generate_until"
# Override with any user-provided kwargs
config_dict.update(kwargs)
# Create and return TaskConfig instance
return cls(**config_dict)
def to_dict(self, keep_callable: bool = False) -> dict:
def _ser(x):
if isinstance(x, dict):
return {k: _ser(v) for k, v in x.items()}
if isinstance(x, (list, tuple, set)):
return type(x)(_ser(i) for i in x)
return maybe_serialize(x, keep_callable)
return {k: _ser(v) for k, v in asdict(self).items() if v is not None}
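
# --- Illustrative sketch (not part of this commit): a minimal TaskConfig from a dict. ---
# The task name, dataset id, and Jinja-style templates below are placeholders, not a real task;
# they only show from_yaml() plus the generation-kwargs defaulting done in __post_init__.
def _demo_task_config() -> TaskConfig:
    cfg = TaskConfig.from_yaml(
        {
            "task": "toy_qa",
            "dataset_path": "org/toy-dataset",  # placeholder dataset id
            "test_split": "test",
            "doc_to_text": "Q: {{question}}\nA:",
            "doc_to_target": "{{answer}}",
            "output_type": "generate_until",
            "metric_list": [{"metric": "exact_match"}],
        }
    )
    # With no generation_kwargs supplied, __post_init__ falls back to greedy decoding
    # with `until` set to the fewshot delimiter.
    assert cfg.generation_kwargs["do_sample"] is False
    return cfg
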
from __future__ import annotations
from functools import wraps
from inspect import getsource
from typing import Any, Callable, TypeVar
T = TypeVar("T")
def serialize_callable(
value: Callable[..., T] | str, keep_callable=False
) -> Callable[..., T] | str:
"""Serializes a given function or string.
If 'keep_callable' is True, the original callable is returned.
Otherwise, attempts to return the source code of the callable using 'getsource'.
If serialization fails, returns the string representation.
"""
if keep_callable:
return value
else:
try:
return getsource(value)
except (TypeError, OSError):
return str(value)
def maybe_serialize(
val: Callable[..., T] | Any, keep_callable=False
) -> Callable[..., T] | Any:
"""Conditionally serializes a value if it is callable."""
return (
serialize_callable(val, keep_callable=keep_callable) if callable(val) else val
)
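
# Quick sketch (illustrative): plain values pass through maybe_serialize untouched, while
# callables are converted to their source text (or repr, if the source cannot be retrieved).
assert maybe_serialize("exact_match") == "exact_match"
print(maybe_serialize(serialize_callable))  # prints the source of serialize_callable itself
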
def create_mc_choices(choices: list[str], choice_delimiter: str = "\n") -> str:
"""Creates a multiple-choice question format from a list of choices."""
formatted_choices = [f"{chr(65 + i)}. {choice}" for i, choice in enumerate(choices)]
return choice_delimiter.join(formatted_choices)
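# For example (illustrative): three options become a lettered block.
assert create_mc_choices(["Paris", "Rome", "Berlin"]) == "A. Paris\nB. Rome\nC. Berlin"
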
def create_cloze_choices(choices: list[str], choice_delimiter: str = "\n") -> str:
"""Creates a cloze-style question format from a list of choices."""
def doc_to_closure(fn: Callable[..., T]) -> Callable[..., T]:
"""Closure that allows the function to be called with 'self'."""
@wraps(fn)
def closure(self: Any, *args, **kwargs):
return fn(*args, **kwargs)
return closure
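
# --- Illustrative sketch (not part of this commit): using doc_to_closure. ---
# `_format_doc` and `_ToyTask` are made-up names; the point is that the wrapped function can
# be attached to a class and called as a method while the bound `self` is ignored.
def _format_doc(doc: dict) -> str:
    return f"Q: {doc['question']}\nA:"

class _ToyTask:
    doc_to_text = doc_to_closure(_format_doc)

assert _ToyTask().doc_to_text({"question": "1+1?"}) == "Q: 1+1?\nA:"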
@@ -12,7 +12,7 @@ from lm_eval.api.metrics import (
     pooled_sample_stderr,
     stderr_for_metric,
 )
-from lm_eval.api.task import Task
+from lm_eval.api.task import ConfigurableTask, Task
 from lm_eval.utils import positional_deprecated
@@ -58,7 +58,7 @@ class TaskOutput:
         group_alias=None,
         is_group=None,
     ):
-        self.task = task
+        self.task: Union[Task, ConfigurableTask] = task
         self.task_config = task_config
         self.task_name = task_name
         self.group_name = group_name
...
@@ -27,7 +27,6 @@ class TakeFirstFilter(Filter):
 class TakeKFilter(Filter):
     def __init__(self, **kwargs) -> None:
         self.k = kwargs.pop("k")
-
         super().__init__(**kwargs)

     def apply(self, resps, docs):
...
@@ -73,3 +73,5 @@ HomePage: https://github.com/masakhane-io/masakhane-pos
 abstract = "In this paper, we present AfricaPOS, the largest part-of-speech (POS) dataset for 20 typologically diverse African languages. We discuss the challenges in annotating POS for these languages using the universal dependencies (UD) guidelines. We conducted extensive POS baseline experiments using both conditional random field and several multilingual pre-trained language models. We applied various cross-lingual transfer models trained with data available in the UD. Evaluating on the AfricaPOS dataset, we show that choosing the best transfer language(s) in both single-source and multi-source setups greatly improves the POS tagging performance of the target languages, in particular when combined with parameter-fine-tuning methods. Crucially, transferring knowledge from a language that matches the language family and morphosyntactic properties seems to be more effective for POS tagging in unseen languages."
 }
 ```
+## Changelog
+- 2025-07-21: Refactored. Scores should not be affected.