Commit 930b4253 authored by Baber

Merge branch 'smolrefact' into lazy_reg

# Conflicts:
#	lm_eval/__init__.py
#	lm_eval/api/metrics.py
#	lm_eval/api/registry.py
#	lm_eval/api/task.py
#	lm_eval/filters/__init__.py
#	pyproject.toml
parents d547b663 73202a2e
import json
import logging
import textwrap
from argparse import Namespace
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
import yaml
from lm_eval.utils import simple_parse_args_string
if TYPE_CHECKING:
from lm_eval.tasks import TaskManager
eval_logger = logging.getLogger(__name__)
DICT_KEYS = [
"wandb_args",
"wandb_config_args",
"hf_hub_log_args",
"metadata",
"model_args",
"gen_kwargs",
]
@dataclass
class EvaluatorConfig:
"""Configuration for language model evaluation runs.
This dataclass contains all parameters for configuring model evaluations via
`simple_evaluate()` or the CLI. It supports initialization from:
- CLI arguments (via `from_cli()`)
- YAML configuration files (via `from_config()`)
- Direct instantiation with keyword arguments
The configuration handles argument parsing, validation, and preprocessing
to ensure arguments are properly structured and validated.
Example:
# From CLI arguments
config = EvaluatorConfig.from_cli(args)
# From YAML file
config = EvaluatorConfig.from_config("eval_config.yaml")
# Direct instantiation
config = EvaluatorConfig(
model="hf",
model_args={"pretrained": "gpt2"},
tasks=["hellaswag", "arc_easy"],
num_fewshot=5
)
See individual field documentation for detailed parameter descriptions.
"""
# Core evaluation parameters
config: Optional[str] = field(
default=None, metadata={"help": "Path to YAML config file"}
)
model: str = field(default="hf", metadata={"help": "Name of model e.g. 'hf'"})
model_args: dict = field(
default_factory=dict, metadata={"help": "Arguments for model initialization"}
)
tasks: Union[str, list[str]] = field(
default_factory=list,
metadata={"help": "Comma-separated list of task names to evaluate"},
)
# Few-shot and batching
num_fewshot: Optional[int] = field(
default=None, metadata={"help": "Number of examples in few-shot context"}
)
batch_size: int = field(default=1, metadata={"help": "Batch size for evaluation"})
max_batch_size: Optional[int] = field(
default=None, metadata={"help": "Maximum batch size for auto batching"}
)
# Device
device: Optional[str] = field(
default="cuda:0", metadata={"help": "Device to use (e.g. cuda, cuda:0, cpu)"}
)
# Data sampling and limiting
limit: Optional[float] = field(
default=None, metadata={"help": "Limit number of examples per task"}
)
samples: Union[str, dict, None] = field(
default=None,
metadata={"help": "dict, JSON string or path to JSON file with doc indices"},
)
# Caching
use_cache: Optional[str] = field(
default=None,
metadata={"help": "Path to sqlite db file for caching model outputs"},
)
cache_requests: dict = field(
default_factory=dict,
metadata={"help": "Cache dataset requests: true/refresh/delete"},
)
# Output and logging flags
check_integrity: bool = field(
default=False, metadata={"help": "Run test suite for tasks"}
)
write_out: bool = field(
default=False, metadata={"help": "Print prompts for first few documents"}
)
log_samples: bool = field(
default=False, metadata={"help": "Save model outputs and inputs"}
)
output_path: Optional[str] = field(
default=None, metadata={"help": "Dir path where result metrics will be saved"}
)
predict_only: bool = field(
default=False,
metadata={
"help": "Only save model outputs, don't evaluate metrics. Use with log_samples."
},
)
# Chat and instruction handling
system_instruction: Optional[str] = field(
default=None, metadata={"help": "Custom System instruction to add"}
)
apply_chat_template: Union[bool, str] = field(
default=False,
metadata={
"help": "Apply chat template to prompt. Either True, or a string identifying the tokenizer template."
},
)
fewshot_as_multiturn: bool = field(
default=False,
metadata={
"help": "Use fewshot as multi-turn conversation. Requires apply_chat_template=True."
},
)
# Configuration display
show_config: bool = field(
default=False, metadata={"help": "Show full config at end of evaluation"}
)
# External tasks and generation
include_path: Optional[str] = field(
default=None, metadata={"help": "Additional dir path for external tasks"}
)
gen_kwargs: Optional[dict] = field(
default=None, metadata={"help": "Arguments for model generation"}
)
# Logging and verbosity
verbosity: Optional[str] = field(
default=None, metadata={"help": "Logging verbosity level"}
)
# External integrations
wandb_args: dict = field(
default_factory=dict, metadata={"help": "Arguments for wandb.init"}
)
wandb_config_args: dict = field(
default_factory=dict, metadata={"help": "Arguments for wandb.config.update"}
)
hf_hub_log_args: dict = field(
default_factory=dict, metadata={"help": "Arguments for HF Hub logging"}
)
# Reproducibility
seed: list = field(
default_factory=lambda: [0, 1234, 1234, 1234],
metadata={"help": "Seeds for random, numpy, torch, fewshot (random)"},
)
# Security
trust_remote_code: bool = field(
default=False, metadata={"help": "Trust remote code for HF datasets"}
)
confirm_run_unsafe_code: bool = field(
default=False,
metadata={
"help": "Confirm understanding of unsafe code risks (for code tasks that executes arbitrary Python)"
},
)
# Internal metadata
metadata: dict = field(
default_factory=dict,
metadata={"help": "Additional metadata for tasks that require it"},
)
@classmethod
def from_cli(cls, namespace: Namespace) -> "EvaluatorConfig":
"""
Build an EvaluatorConfig by merging with simple precedence:
CLI args > YAML config > built-in defaults
"""
# Start with built-in defaults
config = asdict(cls())
# Load and merge YAML config if provided
if used_config := getattr(namespace, "config", None):
config.update(cls.load_yaml_config(used_config))
# Override with CLI args (only truthy values, exclude non-config args)
excluded_args = {"command", "func"} # argparse internal args
cli_args = {
k: v for k, v in vars(namespace).items() if v and k not in excluded_args
}
config.update(cli_args)
# Parse string arguments that should be dictionaries
config = cls._parse_dict_args(config)
# Create instance and validate
instance = cls(**config)
if used_config:
print(instance)
instance.configure()
return instance
@classmethod
def from_config(cls, config_path: Union[str, Path]) -> "EvaluatorConfig":
"""
Build an EvaluatorConfig from a YAML config file.
Merges with built-in defaults and validates.
"""
# Load YAML config
yaml_config = cls.load_yaml_config(config_path)
# Parse string arguments that should be dictionaries
yaml_config = cls._parse_dict_args(yaml_config)
instance = cls(**yaml_config)
instance.configure()
return instance
@staticmethod
def _parse_dict_args(config: Dict[str, Any]) -> Dict[str, Any]:
"""Parse string arguments that should be dictionaries."""
for key in config:
if key in DICT_KEYS and isinstance(config[key], str):
config[key] = simple_parse_args_string(config[key])
return config
@staticmethod
def load_yaml_config(config_path: Union[str, Path]) -> Dict[str, Any]:
"""Load and validate YAML config file."""
config_file = (
Path(config_path) if not isinstance(config_path, Path) else config_path
)
if not config_file.is_file():
raise FileNotFoundError(f"Config file not found: {config_path}")
try:
yaml_data = yaml.safe_load(config_file.read_text())
except yaml.YAMLError as e:
raise ValueError(f"Invalid YAML in {config_path}: {e}") from e
except (OSError, UnicodeDecodeError) as e:
raise ValueError(f"Could not read config file {config_path}: {e}") from e
if not isinstance(yaml_data, dict):
raise ValueError(
f"YAML root must be a mapping, got {type(yaml_data).__name__}"
)
return yaml_data
def configure(self) -> None:
"""Validate configuration and preprocess fields after creation."""
self._validate_arguments()
self._process_arguments()
self._set_trust_remote_code()
def _validate_arguments(self) -> None:
"""Validate configuration arguments and cross-field constraints."""
if self.limit:
eval_logger.warning(
"--limit SHOULD ONLY BE USED FOR TESTING. "
"REAL METRICS SHOULD NOT BE COMPUTED USING LIMIT."
)
# predict_only implies log_samples
if self.predict_only:
self.log_samples = True
# log_samples or predict_only requires output_path
if (self.log_samples or self.predict_only) and not self.output_path:
raise ValueError(
"Specify --output_path if providing --log_samples or --predict_only"
)
# fewshot_as_multiturn requires apply_chat_template
if self.fewshot_as_multiturn and self.apply_chat_template is False:
raise ValueError(
"When `fewshot_as_multiturn` is selected, `apply_chat_template` must be set."
)
# samples and limit are mutually exclusive
if self.samples and self.limit is not None:
raise ValueError("If --samples is not None, then --limit must be None.")
# tasks is required
if self.tasks is None:
raise ValueError("Need to specify task to evaluate.")
def _process_arguments(self) -> None:
"""Process samples argument - load from file if needed."""
if self.samples:
if isinstance(self.samples, dict):
pass  # already a dict of doc indices
elif isinstance(self.samples, str):
try:
self.samples = json.loads(self.samples)
except json.JSONDecodeError:
if (samples_path := Path(self.samples)).is_file():
self.samples = json.loads(samples_path.read_text())
# Set up metadata by merging model_args and metadata.
if self.model_args is None:
self.model_args = {}
if self.metadata is None:
self.metadata = {}
self.metadata = self.model_args | self.metadata
def process_tasks(self, metadata: Optional[dict] = None) -> "TaskManager":
"""Process and validate tasks, return resolved task names."""
from lm_eval import utils
from lm_eval.tasks import TaskManager
# if metadata manually passed use that:
self.metadata = metadata if metadata else self.metadata
# Create task manager with metadata
task_manager = TaskManager(
include_path=self.include_path,
metadata=self.metadata if self.metadata else {},
)
task_names = task_manager.match_tasks(self.tasks)
# Check for any individual task files in the list
for task in [task for task in self.tasks if task not in task_names]:
task_path = Path(task)
if task_path.is_file():
config = utils.load_yaml_config(str(task_path))
task_names.append(config)
# Check for missing tasks
task_missing = [
task for task in self.tasks if task not in task_names and "*" not in task
]
if task_missing:
missing = ", ".join(task_missing)
raise ValueError(f"Tasks not found: {missing}")
# Update tasks with resolved names
self.tasks = task_names
return task_manager
def _set_trust_remote_code(self) -> None:
"""Apply trust_remote_code setting if enabled."""
if self.trust_remote_code:
# HACK: import datasets and override its HF_DATASETS_TRUST_REMOTE_CODE value internally,
# because it's already been determined based on the prior env var before launching our
# script--`datasets` gets imported by lm_eval internally before these lines can update the env.
import datasets
datasets.config.HF_DATASETS_TRUST_REMOTE_CODE = True
# Add to model_args for the actual model initialization
if self.model_args is None:
self.model_args = {}
self.model_args["trust_remote_code"] = True
from __future__ import annotations
from collections.abc import Callable, Mapping
from dataclasses import dataclass, field
from functools import cached_property
from typing import Any
@dataclass
class MetricConfig:
"""Encapsulates information about a single metric."""
name: str
fn: Callable
kwargs: Mapping[str, Any] = field(default_factory=dict)
aggregation_fn: Callable | None = None
higher_is_better: bool = True
hf_evaluate: bool = False
is_elementwise: bool = True
@cached_property
def metric_name(self) -> str:
return self.name
@cached_property
def aggregation(self) -> Callable[..., Any] | None:
from lm_eval.api.registry import get_aggregation
if self.aggregation_fn is None:
return get_aggregation(self.name)
return self.aggregation_fn
@cached_property
def _higher_is_better(self) -> bool | None:
from lm_eval.api.registry import is_higher_better
if self.higher_is_better is None:
return is_higher_better(self.name)
return self.higher_is_better
def compute(self, *args, **kwargs) -> Any:
"""Calculates the metric using the provided function and arguments."""
if self.fn is None:
raise ValueError(f"Metric function for {self.name} is not defined.")
return self.fn(*args, **{**(self.kwargs or {}), **kwargs})
def compute_aggregation(self, *args, **kwargs) -> Any:
"""Computes the aggregation of the metric values."""
if self.aggregation_fn is None:
raise ValueError(f"Aggregation function for {self.name} is not defined.")
return self.aggregation_fn(*args, **kwargs)
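A small sketch of driving MetricConfig directly; an explicit aggregation_fn is passed so no registry lookup is needed, and the import path is assumed to be lm_eval.config.metric:

```python
# Sketch: wrap a plain scoring function in MetricConfig and aggregate results.
from statistics import mean
from lm_eval.config.metric import MetricConfig  # assumed module path

def exact_match(prediction: str, reference: str) -> float:
    """1.0 if prediction equals reference after stripping whitespace, else 0.0."""
    return float(prediction.strip() == reference.strip())

em = MetricConfig(
    name="exact_match",
    fn=exact_match,
    aggregation_fn=mean,   # explicit, so compute_aggregation never raises
    higher_is_better=True,
)

scores = [em.compute("Paris", "Paris"), em.compute("Lyon", "Paris")]
print(em.compute_aggregation(scores))  # 0.5
```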
from __future__ import annotations
import logging
from collections.abc import Iterable
from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Any, Callable, Union
import datasets
from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.instance import OutputType
from lm_eval.config.metric import MetricConfig
from lm_eval.config.utils import maybe_serialize
if TYPE_CHECKING:
from lm_eval.api.samplers import ContextSampler
from lm_eval.api.task import Task
from lm_eval.config.template import TemplateConfig
eval_logger = logging.getLogger(__name__)
DataSet = Union[datasets.Dataset, Iterable[dict[str, Any]]]
DSplits = dict[str, DataSet]
@dataclass
class RepeatConfig:
"""Encapsulates information about a single repeat."""
repeats: int = 1
metric_fn: str | Callable = "pass@N"
kwargs: dict | None = field(default_factory=dict)
@dataclass
class FilterConfig:
"""Encapsulates information about a single filter pipeline."""
name: str
ensemble: FilterEnsemble
metric_list: list[MetricConfig]
@dataclass
class FewshotConfig:
# hack: this returns task.config.num_fewshot
# to keep in sync as it is runtime-modified
num_fewshot: Callable[[], int]
split: str | None = None
sampler: str | Callable = "default"
samples: Callable[[], DataSet] | DataSet | None = None
process_docs: Callable[[DataSet], DataSet] | None = None
fewshot_indices: list[int] | None = None
rnd: int = field(init=False, default=False)
def __post_init__(self) -> None:
if self.samples is not None and not (
isinstance(self.samples, list) or callable(self.samples)
):
raise TypeError(
"samples must be either list[dict] or callable returning list[dict]"
)
if self.split is not None and self.samples is not None:
eval_logger.warning(
"Both split and samples are configured; split will take precedence"
)
@property
def has_source(self) -> bool:
"""Check if any fewshot source is configured."""
return self.split is not None or self.samples is not None
def _get_raw_docs(self, dataset: DSplits) -> DataSet | None:
"""Get raw documents from configured source."""
if self.split is not None:
return dataset[self.split]
if self.samples is not None:
if isinstance(self.samples, list):
return self.samples
elif callable(self.samples):
# If samples is a callable, it should return a list of dicts
return self.samples()
else:
raise TypeError(
"samples must be either a list of dicts or a callable returning a list"
)
def get_docs(self, dataset) -> DataSet | None:
"""Get processed documents from configured source."""
raw_docs = self._get_raw_docs(dataset)
if raw_docs is None:
return None
if self.process_docs is not None:
return self.process_docs(raw_docs)
return raw_docs
@property
def get_sampler(self) -> Callable[..., Any] | None:
from lm_eval.api import samplers
if isinstance(self.sampler, str):
return samplers.get_sampler(self.sampler)
elif callable(self.sampler):
return self.sampler
def init_sampler(
self, docs: list[dict], task: Task, rnd=None, fewshot_indices=None
) -> ContextSampler:
"""Initialize the sampler with the given documents and task."""
if rnd is None:
raise ValueError(
"A `random.Random` generator argument must be provided to `rnd` of FewShotSampler!"
)
return self.get_sampler(
docs,
task,
rnd=rnd,
fewshot_indices=fewshot_indices
if fewshot_indices
else self.fewshot_indices,
)
@dataclass
class TaskConfig:
# task naming/registry
task: str | None = None
task_alias: str | None = None
tag: str | list | None = None
# HF dataset options.
# which dataset to use,
# and what splits for what purpose
custom_dataset: Callable[..., DataSet] | None = None
dataset_path: str | None = None
dataset_name: str | None = None
dataset_kwargs: dict | None = field(default_factory=dict)
training_split: str | None = None
validation_split: str | None = None
test_split: str | None = None
fewshot_split: str | None = None
# formatting / prompting options.
# see docs/advanced_task_guide.md for more info
process_docs: Callable[[DataSet], DataSet] | None = None
doc_to_text: Callable[[dict[str, Any]], Any] | str | None = None
doc_to_target: Callable[[dict[str, Any]], Any] | str | None = None
doc_to_image: Callable[[dict[str, Any]], Any] | str | None = None
doc_to_audio: Callable[[dict[str, Any]], Any] | str | None = None
unsafe_code: bool = False
doc_to_choice: Callable[[dict[str, Any]], Any] | str | dict | list | None = None
process_results: (
Callable[[dict[str, Any], list[Any]], dict[str, Any]] | str | None
) = None
use_prompt: str | None = None
description: str = ""
target_delimiter: str = " "
fewshot_delimiter: str = "\n\n"
fewshot_config: dict[str, Any] | None = None
# runtime configuration options
num_fewshot: int | None = None
generation_kwargs: dict[str, Any] | None = None
# scoring options
metric_list: list | None = None
output_type: OutputType = "generate_until"
repeats: int = 1
filter_list: list[dict] | None = None
should_decontaminate: bool = False
doc_to_decontamination_query: str | None = None
gen_prefix: str | None = None
multiple_input: bool = False
metadata: dict = field(
default_factory=dict
) # by default, not used in the code. allows for users to pass arbitrary info to tasks
_metric_list: list[MetricConfig] = field(default_factory=list)
_filter_list: list[FilterConfig] = field(default_factory=list)
# ds_cfg: DatasetConfig = field(init=False)
fewshot_cfg: FewshotConfig = field(init=False)
_fn: dict[str, Callable] = field(default_factory=dict)
def __post_init__(self) -> None:
### ---setup generation kwargs--- ###
if self.generation_kwargs is not None:
if self.output_type != "generate_until":
eval_logger.warning(
f"[{self.task}] passed `generation_kwargs`, but not using `output_type: generate_until`!"
)
if "temperature" in self.generation_kwargs:
self.generation_kwargs["temperature"] = float(
self.generation_kwargs["temperature"]
)
if "until" not in self.generation_kwargs:
eval_logger.warning(
f"{self.task}: No `until` specified in `generation_kwargs`! Defaulting to the fewshot_delimiter={repr(self.fewshot_delimiter)}"
)
self.generation_kwargs["until"] = [self.fewshot_delimiter]
else:
if self.output_type == "generate_until":
# ensure that we greedily generate in absence of explicit arguments otherwise
self.generation_kwargs = {
"until": (
None
if self.fewshot_delimiter is None
else [self.fewshot_delimiter]
),
"do_sample": False,
"temperature": 0,
}
eval_logger.warning(
f"{self.task}: No `generation_kwargs` specified in task config, defaulting to {self.generation_kwargs}"
)
# ---setup fewshot config--- #
_fewshot_cfg = self.fewshot_config if self.fewshot_config is not None else {}
self.fewshot_cfg = FewshotConfig(
num_fewshot=lambda: self.num_fewshot or _fewshot_cfg.get("num_fewshot", 0),
split=self.fewshot_split,
sampler=_fewshot_cfg.get("sampler", "default"),
samples=_fewshot_cfg.get("samples", None),
process_docs=_fewshot_cfg.get("process_docs", None),
fewshot_indices=_fewshot_cfg.get("fewshot_indices", None),
)
def _get_metric(self, metric_list: list[dict] | None = None) -> list[MetricConfig]:
from lm_eval.api.registry import (
AGGREGATION_REGISTRY,
DEFAULT_METRIC_REGISTRY,
get_aggregation,
get_metric,
get_metric_aggregation,
is_higher_better,
)
# if metric_list defined inside a filter, use that; otherwise use the task's metric_list
metric_list = metric_list or self.metric_list
metrics = []
if not metric_list:
# ---------- 1. If no metrics defined, use defaults for output type ----------
_metric_list = DEFAULT_METRIC_REGISTRY[self.output_type]
eval_logger.info(
f"No metrics defined in config, using default metrics for {self.output_type}={_metric_list}"
)
metrics.extend(
MetricConfig(
name=metric_name,
fn=get_metric(metric_name),
aggregation_fn=get_metric_aggregation(metric_name),
higher_is_better=True if is_higher_better(metric_name) is None else is_higher_better(metric_name),  # default to True only when unregistered
)
for metric_name in _metric_list
)
else:
# ---------- 2. Process user-defined metrics from config ----------
for metric_config in metric_list:
metric_name = metric_config["metric"]
_metric_fn_kwargs = {
key: metric_config[key]
for key in metric_config
if key
not in ["metric", "aggregation", "higher_is_better", "hf_evaluate"]
}
_hf_evaluate_metric: bool = metric_config.get("hf_evaluate", False)
_metric_fn = None
_aggregation = None
if self.process_results is not None:
# User will compute metrics inside `process_results()`
_metric_name = None
_metric_fn_kwargs = {}
elif callable(metric_name):
# User passed a function object
_metric_name = metric_name.__name__
_metric_fn = metric_name.__call__
else:
# Normal: look up by name
_metric_name = metric_name
_metric_fn = get_metric(metric_name, _hf_evaluate_metric)
# ---------- 3. Decide how to aggregate examples ----------
if "aggregation" in metric_config:
if isinstance(_agg_name := metric_config["aggregation"], str):
_aggregation = get_aggregation(_agg_name)
elif callable(_agg_name): # noqa: E721
_aggregation = metric_config["aggregation"]
else:
INV_AGG_REGISTRY = {v: k for k, v in AGGREGATION_REGISTRY.items()}
_aggregation = get_metric_aggregation(metric_name)
eval_logger.warning(
f"[Task: {self.task}] metric {metric_name} is defined, but aggregation is not. "
f"using default "
f"aggregation={INV_AGG_REGISTRY[_aggregation]}"
)
# ---------- 4. Determine “higher-is-better” semantics ----------
if "higher_is_better" in metric_config:
_higher_is_better = metric_config["higher_is_better"]
else:
eval_logger.warning(
f"[Task: {self.task}] metric {metric_name} is defined, but higher_is_better is not. "
f"using default "
f"higher_is_better={is_higher_better(metric_name)}"
)
_higher_is_better = is_higher_better(metric_name)
metrics.append(
MetricConfig(
name=_metric_name,
fn=_metric_fn,
kwargs=_metric_fn_kwargs,
aggregation_fn=_aggregation,
higher_is_better=_higher_is_better,
hf_evaluate=_hf_evaluate_metric,
)
)
for m in metrics:
if m not in self._metric_list:
self._metric_list.append(m)
return metrics
@property
def get_filters(self) -> list[FilterConfig]:
from lm_eval.filters import build_filter_ensemble
if not self.filter_list:
eval_logger.debug(
"No custom filters defined; falling back to 'take_first' for handling repeats."
)
return [
FilterConfig(
name="none",
ensemble=build_filter_ensemble("none", [("take_first", None)]),
metric_list=self._get_metric(metric_list=None),
)
]
else:
def _strip_fn(d: dict) -> tuple[str, dict]:
return d["function"], {
k: v for k, v in d.items() if k not in ["function", "metric_list"]
}
configs = (
self.filter_list.values()
if isinstance(self.filter_list, dict)
else self.filter_list
)
x = [
FilterConfig(
name=cfg["name"],
ensemble=build_filter_ensemble(
filter_name=cfg["name"],
components=[_strip_fn(f) for f in cfg["filter"]],
),
metric_list=self._get_metric(metric_list=cfg.get("metric_list")),
)
for cfg in configs
]
return x
@classmethod
def from_yaml(cls, data: dict[str, Any]) -> TaskConfig:
"""Create a TaskConfig instance from a YAML-like dictionary."""
fn = {k: v for k, v in data.items() if callable(v)}
return cls(**data, _fn=fn)
@classmethod
def from_template(cls, template: TemplateConfig, **kwargs) -> TaskConfig:
"""Create a TaskConfig instance from a template.
Args:
template: TemplateConfig instance (MCQTemplateConfig or ClozeTemplateConfig)
**kwargs: Additional arguments to override template defaults
Returns:
TaskConfig instance configured from the template
"""
from lm_eval.config.template import (
ClozeTemplateConfig,
MCQTemplateConfig,
)
# Extract base configuration from template
config_dict = {
"task": template.task,
"doc_to_text": template.doc_to_text,
"doc_to_choice": template.doc_to_choice,
"doc_to_target": template.doc_to_target,
"description": template.description,
"target_delimiter": template.target_delimiter,
"fewshot_delimiter": template.fewshot_delimiter,
"metric_list": template.metric_list,
}
# Add common template attributes if they exist
if hasattr(template, "answer_suffix"):
config_dict["target_delimiter"] = (
template.answer_suffix + template.target_delimiter
)
# Handle template-specific configurations
if isinstance(template, MCQTemplateConfig):
# For MCQ templates, set up multiple choice specific config
config_dict["output_type"] = "multiple_choice"
# MCQ templates typically use accuracy metrics
if template.metric_list is None:
config_dict["metric_list"] = [{"metric": "acc"}]
elif isinstance(template, ClozeTemplateConfig):
# For Cloze templates, set up generation config
config_dict["output_type"] = "generate_until"
# Cloze templates typically use accuracy and normalized accuracy
if template.metric_list is None:
config_dict["metric_list"] = [{"metric": "acc"}, {"metric": "acc_norm"}]
else:
# Generic template - try to infer output type
if hasattr(template, "template"):
if template.template == "mcq":
config_dict["output_type"] = "multiple_choice"
elif template.template == "cloze":
config_dict["output_type"] = "generate_until"
# Override with any user-provided kwargs
config_dict.update(kwargs)
# Create and return TaskConfig instance
return cls(**config_dict)
def to_dict(self, keep_callable: bool = False) -> dict:
def _ser(x):
if isinstance(x, dict):
return {k: _ser(v) for k, v in x.items()}
if isinstance(x, (list, tuple, set)):
return type(x)(_ser(i) for i in x)
return maybe_serialize(x, keep_callable)
return {k: _ser(v) for k, v in asdict(self).items() if v is not None}
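A hedged sketch of building a TaskConfig from a YAML-style dictionary; the module path lm_eval.config.task is inferred from the sibling imports above, and the task/dataset values are hypothetical:

```python
# Sketch: TaskConfig.from_yaml with a minimal generate_until task.
from lm_eval.config.task import TaskConfig  # assumed module path

cfg = TaskConfig.from_yaml(
    {
        "task": "demo_qa",          # hypothetical task name
        "dataset_path": "json",     # hypothetical dataset
        "test_split": "test",
        "doc_to_text": "Question: {{question}}\nAnswer:",
        "doc_to_target": "{{answer}}",
        "output_type": "generate_until",
        "metric_list": [{"metric": "exact_match"}],
    }
)
# __post_init__ injects greedy generation defaults because no
# generation_kwargs were supplied for a generate_until task.
print(cfg.generation_kwargs)
# {'until': ['\n\n'], 'do_sample': False, 'temperature': 0}
```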
from __future__ import annotations
from functools import wraps
from inspect import getsource
from typing import Any, Callable, TypeVar
T = TypeVar("T")
def serialize_callable(
value: Callable[..., T] | str, keep_callable=False
) -> Callable[..., T] | str:
"""Serializes a given function or string.
If 'keep_callable' is True, the original callable is returned.
Otherwise, attempts to return the source code of the callable using 'getsource'.
If serialization fails, returns the string representation.
"""
if keep_callable:
return value
else:
try:
return getsource(value)
except (TypeError, OSError):
return str(value)
def maybe_serialize(
val: Callable[..., T] | Any, keep_callable=False
) -> Callable[..., T] | Any:
"""Conditionally serializes a value if it is callable."""
return (
serialize_callable(val, keep_callable=keep_callable) if callable(val) else val
)
def create_mc_choices(choices: list[str], choice_delimiter: str = "\n") -> str:
"""Creates a multiple-choice question format from a list of choices."""
formatted_choices = [f"{chr(65 + i)}. {choice}" for i, choice in enumerate(choices)]
return choice_delimiter.join(formatted_choices)
def create_cloze_choices(choices: list[str], choice_delimiter: str = "\n") -> str:
"""Creates a cloze-style question format from a list of choices."""
def doc_to_closure(fn: Callable[..., T]) -> Callable[..., T]:
"""Closure that allows the function to be called with 'self'."""
@wraps(fn)
def closure(self: Any, *args, **kwargs):
return fn(*args, **kwargs)
return closure
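A short sketch exercising the helpers above (create_mc_choices and doc_to_closure), assuming they are importable as lm_eval.config.utils:

```python
# Sketch: format MCQ choices and attach a doc-processing function as a method.
from lm_eval.config.utils import create_mc_choices, doc_to_closure  # assumed path

print(create_mc_choices(["Paris", "London", "Berlin"]))
# A. Paris
# B. London
# C. Berlin

def doc_to_text(doc: dict) -> str:
    return f"{doc['question']}\nAnswer:"

class DemoTask:
    # doc_to_closure drops the bound `self` before delegating to the function,
    # so a plain function can be reused as an instance method.
    doc_to_text = doc_to_closure(doc_to_text)

print(DemoTask().doc_to_text({"question": "Capital of France?"}))
# Capital of France?
# Answer:
```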
@@ -27,7 +27,6 @@ class TakeFirstFilter(Filter):
class TakeKFilter(Filter):
    def __init__(self, **kwargs) -> None:
        self.k = kwargs.pop("k")
        super().__init__(**kwargs)

    def apply(self, resps, docs):
...