Commit d19bd889 authored by Baber's avatar Baber
Browse files

improve metric aggregation default and higher-better checks; add `TaskConfig.from_template`

parent 69d14fb3
......@@ -167,20 +167,24 @@ def get_aggregation(name: str) -> Callable[..., Any] | None:
eval_logger.warning(f"{name} not a registered aggregation metric!")
def get_metric_aggregation(name: str) -> Callable[[], dict[str, Callable]] | None:
def get_metric_aggregation(name: str) -> Callable[[], dict[str, Callable[..., Any]]]:
try:
return METRIC_AGGREGATION_REGISTRY[name]
except KeyError:
eval_logger.warning(f"{name} metric is not assigned a default aggregation!")
eval_logger.warning(
f"{name} metric is not assigned a default aggregation!. Using default aggregation mean"
)
return AGGREGATION_REGISTRY["mean"]
def is_higher_better(metric_name: str) -> bool | None:
def is_higher_better(metric_name: str) -> bool:
try:
return HIGHER_IS_BETTER_REGISTRY[metric_name]
except KeyError:
eval_logger.warning(
f"higher_is_better not specified for metric '{metric_name}'!"
f"higher_is_better not specified for metric '{metric_name}'!. Will default to True."
)
return True
def register_filter(name: str):
......
......@@ -14,6 +14,7 @@ from lm_eval.config.utils import maybe_serialize
if TYPE_CHECKING:
from lm_eval.api.samplers import ContextSampler
from lm_eval.api.task import Task
from lm_eval.config.template import TemplateConfig
eval_logger = logging.getLogger(__name__)
......@@ -119,7 +120,7 @@ class FewshotConfig:
@dataclass
class TaskConfig(dict):
class TaskConfig:
# task naming/registry
task: str | None = None
task_alias: str | None = None
......@@ -240,7 +241,7 @@ class TaskConfig(dict):
name=metric_name,
fn=get_metric(metric_name),
aggregation_fn=get_metric_aggregation(metric_name),
higher_is_better=is_higher_better(metric_name),
higher_is_better=is_higher_better(metric_name) or True,
)
for metric_name in _metric_list
)
......@@ -357,6 +358,70 @@ class TaskConfig(dict):
"""Create a TaskConfig instance from a YAML-like dictionary."""
return cls(**data)
@classmethod
def from_template(cls, template: TemplateConfig, **kwargs) -> TaskConfig:
"""Create a TaskConfig instance from a template.
Args:
template: TemplateConfig instance (MCQTemplateConfig or ClozeTemplateConfig)
**kwargs: Additional arguments to override template defaults
Returns:
TaskConfig instance configured from the template
"""
from lm_eval.config.template import (
ClozeTemplateConfig,
MCQTemplateConfig,
)
# Extract base configuration from template
config_dict = {
"task": template.task,
"doc_to_text": template.doc_to_text,
"doc_to_choice": template.doc_to_choice,
"doc_to_target": template.doc_to_target,
"description": template.description,
"target_delimiter": template.target_delimiter,
"fewshot_delimiter": template.fewshot_delimiter,
"metric_list": template.metric_list,
}
# Add common template attributes if they exist
if hasattr(template, "answer_suffix"):
config_dict["target_delimiter"] = (
template.answer_suffix + template.target_delimiter
)
# Handle template-specific configurations
if isinstance(template, MCQTemplateConfig):
# For MCQ templates, set up multiple choice specific config
config_dict["output_type"] = "multiple_choice"
# MCQ templates typically use accuracy metrics
if template.metric_list is None:
config_dict["metric_list"] = [{"metric": "acc"}]
elif isinstance(template, ClozeTemplateConfig):
# For Cloze templates, set up generation config
config_dict["output_type"] = "generate_until"
# Cloze templates typically use accuracy and normalized accuracy
if template.metric_list is None:
config_dict["metric_list"] = [{"metric": "acc"}, {"metric": "acc_norm"}]
else:
# Generic template - try to infer output type
if hasattr(template, "template"):
if template.template == "mcq":
config_dict["output_type"] = "multiple_choice"
elif template.template == "cloze":
config_dict["output_type"] = "generate_until"
# Override with any user-provided kwargs
config_dict.update(kwargs)
# Create and return TaskConfig instance
return cls(**config_dict)
def __getitem__(self, item):
return getattr(self, item)
......
from __future__ import annotations
from abc import ABC, abstractmethod
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Callable
from lm_eval.config.utils import create_mc_choices
if TYPE_CHECKING:
from lm_eval.config.metric import MetricConfig
@dataclass
class TemplateConfig:
class TemplateConfig(ABC):
"""Encapsulates information about a template."""
#
template: str
task: str
doc_to_text: str | Callable[[dict], str]
doc_to_choice: str | list | Callable[[dict], list]
doc_to_target: int | Callable[[dict], int]
......@@ -29,9 +34,22 @@ class TemplateConfig:
default_factory=lambda: ["acc", "acc_norm"]
)
@abstractmethod
def _doc_to_text(self, doc: dict) -> str:
"""Convert a document to text."""
raise NotImplementedError
def _doc_to_choice(self, doc: dict) -> str:
"""Convert a document to choices."""
raise NotImplementedError
def _doc_to_target(self, doc: dict) -> int | str:
"""Convert a document to target."""
raise NotImplementedError
@dataclass
class MCQTemplateConfig:
class MCQTemplateConfig(TemplateConfig):
"""Encapsulates information about a template.
Would return a sample with the following format:
Question: <doc_to_text(doc)>
......@@ -56,9 +74,36 @@ class MCQTemplateConfig:
fewshot_delimiter: str = "\n\n"
metric_list: list[MetricConfig] | None = field(default_factory=lambda: ["acc"])
def _doc_to_text(self, doc: dict) -> str:
"""Convert a document to text."""
doc_to_text = (
self.doc_to_text
if isinstance(self.doc_to_text, str)
else self.doc_to_text(doc)
)
return self.context_prefix + doc_to_text
def _doc_to_choice(self, doc: dict) -> str:
if callable(self.doc_to_choice):
doc_to_choice = self.doc_to_choice(doc)
elif isinstance(self.doc_to_choice, str):
doc_to_choice = doc[self.doc_to_choice]
else:
doc_to_choice = self.doc_to_choice
return create_mc_choices(doc_to_choice, choice_delimiter=self.choice_delimiter)
def _doc_to_target(self, doc: dict) -> int:
"""Convert a document to target."""
if callable(self.doc_to_target):
return self.doc_to_target(doc)
elif isinstance(self.doc_to_target, str):
return doc[self.doc_to_target]
else:
return self.doc_to_target
@dataclass
class ClozeTemplateConfig:
class ClozeTemplateConfig(TemplateConfig):
"""Encapsulates information about a template.
Would return a sample with the following format:
Question: <doc_to_text(doc)>
......
......@@ -28,3 +28,16 @@ def maybe_serialize(val: Callable | Any, keep_callable=False) -> Callable | Any:
return (
serialize_callable(val, keep_callable=keep_callable) if callable(val) else val
)
def create_mc_choices(choices: list[str], choice_delimiter: str | None = "\n") -> str:
"""Creates a multiple-choice question format from a list of choices."""
if len(choices) < 2:
raise ValueError(
"At least two choices are required for a multiple-choice question."
)
if choice_delimiter is None:
choice_delimiter = "\n"
formatted_choices = [f"{chr(65 + i)}. {choice}" for i, choice in enumerate(choices)]
return choice_delimiter.join(formatted_choices)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment