Unverified Commit 003e5852 authored by Baber Abbasi, committed by GitHub

Fewshot refactor (#3227)



* overhaul `ContextSampler`

* refactor masakhapos

* move multi_target to `exact_match`

* remove doc_to_choice from `boolq-seq2seq`

* remove doc_to_choice in generation process_results

* Remove unused `doc_to_choice` and fix superglue whitespaces

* require multiple_inputs and multiple_targets to be explicitly set in taskconfig

* fix copa; better logging in task init

* fix doc_to_target to return int rather than str (deprecated)

* fix processing regression; recursively parse lists from template

* remove redundant jinja parsing logic

* remove promptsource

* for multiple_inputs use `doc_to_text: list[str]`

* Refactor `ContextSampler` `fewshot_context`

* fix multiple_input context

* fix `target_delimiter` with `gen_prefix`

* `doc_to_text` is list for multiple_inputs

* Refactor `count_bytes` and `count_words` methods to `@staticmethod`

* make has_*(train/test/validation) into properties

* remove `multi_target` `generate_until`

* fix doc_to_target/multiple_targets handling; add tests

* rename `multi_target` to `multiple_targets`

* evaluate list when multiple targets

* allow doc_to_target to return list

* Remove gen_prefix space and add warning (#3239)

* Remove gen_prefix space and add warning

* fix null gen_prefix bug again

* use git tests

---------
Co-authored-by: Boaz Ben-Dov <bendboaz@gmail.com>
parent 79a22a11
......@@ -8,8 +8,6 @@ on:
branches:
- 'main'
pull_request:
branches:
- 'main'
workflow_dispatch:
# Jobs run concurrently and steps run sequentially within a job.
# jobs: linter and cpu_tests. Add more jobs/steps as required.
......
......@@ -53,6 +53,6 @@ class FilterEnsemble:
resps = f().apply(resps, docs)
# add the end results after filtering to filtered_requests of their respective source instances.
# has key `self.name`: each FilterEnsemble applied in a given run should use a different name.
# has a key ` self.name `: each FilterEnsemble applied in a given run should use a different name.
for inst, resp in zip(instances, resps):
inst.filtered_resps[self.name] = resp
from dataclasses import dataclass, field
from typing import Literal, Optional, Tuple
from typing import Any, Literal, Optional
OutputType = Literal[
......@@ -10,10 +10,10 @@ OutputType = Literal[
@dataclass
class Instance:
request_type: OutputType
doc: dict
doc: dict[str, Any]
arguments: tuple
idx: int
metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
metadata: tuple[Optional[str], Optional[int], Optional[int]] = field(
default_factory=lambda: (None, None, None),
metadata=dict(
description="Metadata tuple containing task name, document ID, and number of repeats."
......
......@@ -213,7 +213,7 @@ def exact_match_hf_evaluate(
ignore_case: bool = False,
ignore_punctuation: bool = False,
ignore_numbers: bool = False,
multi_target: bool = False,
multiple_targets: bool = False,
):
"""
Compute exact match scores between predictions and references.
......@@ -245,8 +245,8 @@ def exact_match_hf_evaluate(
- "exact_match" (float): The mean exact match score or 1.0/0.0 if `multi_target` is True.
"""
predictions, references = list(predictions), list(references)
assert len(predictions) == len(references) if not multi_target else True, (
"predictions and references must have the same length unless `multi_target` is True"
assert len(predictions) == len(references) if not multiple_targets else True, (
"predictions and references must have the same length unless `multiple_targets` is True"
)
if regexes_to_ignore is not None:
......@@ -275,7 +275,7 @@ def exact_match_hf_evaluate(
return {
"exact_match": np.mean(score_list)
if not multi_target
if not multiple_targets
else float(np.any(score_list))
}
......
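For readers tracking the rename, a minimal self-contained sketch of the aggregation switch above (the normalization flags and the harness's actual pairing of predictions with references are omitted; only the `np.mean` vs `np.any` branch mirrors the hunk):

```python
import numpy as np

def exact_match_sketch(predictions, references, multiple_targets=False):
    # Score each (prediction, reference) pair; the real metric also applies
    # ignore_case / ignore_punctuation / regexes_to_ignore before comparing.
    score_list = [p == r for p, r in zip(predictions, references)]
    # multiple_targets: any single match counts as a hit (np.any); otherwise
    # the mean over pairs is reported, exactly as in the hunk above.
    return {
        "exact_match": float(np.any(score_list))
        if multiple_targets
        else float(np.mean(score_list))
    }

# One generated answer checked against several acceptable references:
print(exact_match_sketch(["Paris"] * 3,
                         ["paris", "Paris", "City of Light"],
                         multiple_targets=True))  # {'exact_match': 1.0}
```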
......@@ -220,8 +220,8 @@ class Registry(Generic[T]):
>>> model_registry.register("lazy-name", lazy="mymodule:MyModel")
Raises:
ValueError: If alias already registered with different target
TypeError: If object doesn't inherit from base_cls (when specified)
ValueError: If alias is already registered with a different target
TypeError: If an object doesn't inherit from base_cls (when specified)
"""
def _store(alias: str, target: T | Placeholder) -> None:
......@@ -229,21 +229,27 @@ class Registry(Generic[T]):
# collision handling ------------------------------------------
if current is not None and current != target:
# allow placeholder → real object upgrade
if isinstance(current, str) and isinstance(target, type):
# mod, _, cls = current.partition(":")
if current == f"{target.__module__}:{target.__name__}":
self._objs[alias] = target
return
# mod, _, cls = current.partition(":")
if (
isinstance(current, str)
and isinstance(target, type)
and current == f"{target.__module__}:{target.__name__}"
):
self._objs[alias] = target
return
raise ValueError(
f"{self._name!r} alias '{alias}' already registered ("
f"existing={current}, new={target})"
)
# type check for concrete classes ----------------------------------------------
if self._base_cls is not None and isinstance(target, type):
if not issubclass(target, self._base_cls): # type: ignore[arg-type]
raise TypeError(
f"{target} must inherit from {self._base_cls} to be a {self._name}"
)
if (
self._base_cls is not None
and isinstance(target, type)
and not issubclass(target, self._base_cls)
):
raise TypeError(
f"{target} must inherit from {self._base_cls} to be a {self._name}"
)
self._objs[alias] = target
def decorator(obj: T) -> T: # type: ignore[valid-type]
......@@ -409,9 +415,7 @@ class MetricSpec:
from lm_eval.api.model import LM # noqa: E402
model_registry: Registry[type[LM]] = cast(
Registry[type[LM]], Registry("model", base_cls=LM)
)
model_registry = cast(Registry[type[LM]], Registry("model", base_cls=LM))
task_registry: Registry[Callable[..., Any]] = Registry("task")
metric_registry: Registry[MetricSpec] = Registry("metric")
metric_agg_registry: Registry[Callable[[Iterable[Any]], float]] = Registry(
......@@ -457,7 +461,7 @@ def register_metric(**kw):
then registers it in the metric registry.
Args:
**kw: Keyword arguments including:
**kw: Keyword arguments including
- metric: Name to register the metric under (required)
- aggregation: Name of aggregation function in metric_agg_registry
- higher_is_better: Whether higher scores are better (default: True)
......@@ -512,7 +516,7 @@ def get_metric(name, hf_evaluate_metric=False):
The metric's compute function
Raises:
KeyError: If metric not found in registry or HF evaluate
KeyError: If a metric is not found in registry or HF evaluate
"""
try:
spec = metric_registry.get(name)
......@@ -529,7 +533,7 @@ def get_metric(name, hf_evaluate_metric=False):
return hf.load(name).compute # type: ignore[attr-defined]
except Exception:
raise KeyError(f"Metric '{name}' not found anywhere")
raise KeyError(f"Metric '{name}' not found anywhere") from None
register_metric_aggregation = metric_agg_registry.register
......
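The collision handling consolidated above is easier to see in isolation. A self-contained sketch of the placeholder-upgrade rule (`MiniRegistry` and `MyModel` are illustrative stand-ins, not lm-eval classes):

```python
class MiniRegistry:
    """Mimics only the alias-collision logic shown in the hunk above."""

    def __init__(self, name: str):
        self._name, self._objs = name, {}

    def register(self, alias, target):
        current = self._objs.get(alias)
        if current is not None and current != target:
            # allow "module:Class" placeholder -> real class upgrade
            if (
                isinstance(current, str)
                and isinstance(target, type)
                and current == f"{target.__module__}:{target.__name__}"
            ):
                self._objs[alias] = target
                return
            raise ValueError(
                f"{self._name!r} alias '{alias}' already registered "
                f"(existing={current}, new={target})"
            )
        self._objs[alias] = target


class MyModel:  # stand-in for a lazily registered model class
    pass


reg = MiniRegistry("model")
reg.register("lazy-name", f"{MyModel.__module__}:{MyModel.__name__}")  # placeholder
reg.register("lazy-name", MyModel)  # upgraded in place, no ValueError
```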
from __future__ import annotations
import logging
import warnings
from collections.abc import Iterable, Sequence
from functools import partial
from typing import TYPE_CHECKING, Any
import datasets
from random import Random
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from random import Random
from collections.abc import Iterable, Sequence
from typing import Any, TypeVar
from lm_eval.api.task import ConfigurableTask, Task
_T = TypeVar("_T")
eval_logger = logging.getLogger("lm-eval")
eval_logger = logging.getLogger(__name__)
class ContextSampler:
def __init__(
self,
docs: list[dict],
task: Task | ConfigurableTask,
fewshot_indices: Iterable | None = None,
rnd: Random | None = None,
docs: Sequence[dict[str, Any]] | None = None,
*,
rnd: int | None = None,
fewshot_indices: list[int] | None = None,
**kwargs,
) -> None:
self.rnd = rnd
if not self.rnd:
raise ValueError(
"A `random.Random` generator argument must be provided to `rnd` of FewShotSampler!"
)
self.task = task
self.config = task._config
self.target_delimiter = self.config.target_delimiter
self.fewshot_delimiter = self.config.fewshot_delimiter
if (
self.config.fewshot_config is not None
and self.config.fewshot_config.get("doc_to_text", None) is not None
):
self.doc_to_text = partial(
self.task.doc_to_text,
doc_to_text=self.config.fewshot_config.get("doc_to_text", None),
)
else:
self.doc_to_text = self.task.doc_to_text
if (
self.config.fewshot_config is not None
and self.config.fewshot_config.get("doc_to_target", None) is not None
):
self.doc_to_target = partial(
self.task.doc_to_target,
doc_to_target=self.config.fewshot_config.get("doc_to_target", None),
)
else:
self.doc_to_target = self.task.doc_to_target
if (
self.config.fewshot_config is not None
and self.config.fewshot_config.get("doc_to_choice", None) is not None
):
self.doc_to_choice = partial(
self.task.doc_to_choice,
doc_to_choice=self.config.fewshot_config.get("doc_to_choice", None),
)
else:
self.doc_to_choice = self.task.doc_to_choice
self.docs = docs # HF dataset split, provided by task._fewshot_docs()
if fewshot_indices: # subset few-shot docs from
if not isinstance(self.docs, datasets.Dataset):
raise ValueError(
"Got `fewshot_indices` but fewshot_docs are not a HF dataset. Don't use both `fewshot_indices` and a user-defined few-shot sample list simultaneously"
)
self.docs = self.docs.select(fewshot_indices)
def get_context(self, doc: dict, num_fewshot: int, gen_prefix: str | None = None):
# draw an extra fewshot sample if using same split as evaluating on
prefix = gen_prefix + " " if gen_prefix else ""
n_samples = (
num_fewshot + 1
if self.config.fewshot_split == self.config.test_split
else num_fewshot
)
self.rnd = Random(rnd)
self.docs = docs or []
self.fewshot_indices = fewshot_indices
# draw `n_samples` docs from fewshot_docs
fewshotex = self.sample(n_samples)
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
# TODO: should we just stop people from using fewshot from same split as evaluating?
selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]
labeled_examples = ""
for doc in selected_docs:
doc_content = self.doc_to_text(doc)
doc_target = self.doc_to_target(doc)
if (
self.config.doc_to_choice is None and isinstance(doc_content, str)
) or isinstance(doc_content, str):
labeled_examples += doc_content
else:
if isinstance(doc_content, int):
labeled_examples += self.doc_to_choice(doc)[doc_content]
if doc_target != "":
if self.target_delimiter.isspace() and str(doc_target)[0].isspace():
# TODO: add logger warn once here.
warnings.warn(
"Both target_delimiter and target start with a space. This may cause issues.",
Warning,
stacklevel=2,
)
labeled_examples += self.target_delimiter
labeled_examples += prefix
labeled_examples += (
str(doc_target[0])
if isinstance(doc_target, list)
else doc_target
if self.config.doc_to_choice is None or isinstance(doc_target, str)
else str(self.doc_to_choice(doc)[doc_target])
)
labeled_examples += self.fewshot_delimiter
return labeled_examples
def get_chat_context(
self,
doc: dict,
num_fewshot: int,
fewshot_as_multiturn: bool = False,
gen_prefix: str | None = None,
):
# TODO: Do we need any other delimiter
prefix = gen_prefix + " " if gen_prefix else ""
chat_history = []
# draw an extra fewshot sample if using same split as evaluating on
n_samples = (
num_fewshot + 1
if self.config.fewshot_split == self.config.test_split
else num_fewshot
)
# draw `n_samples` docs from fewshot_docs
fewshotex = self.sample(n_samples)
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
# TODO: should we just stop people from using fewshot from same split as evaluating?
selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]
if fewshot_as_multiturn:
for doc in selected_docs:
doc_content = self.doc_to_text(doc)
doc_target = self.doc_to_target(doc)
chat_history.append(
{
"role": "user",
"content": doc_content
if self.config.doc_to_choice is None
or isinstance(doc_content, str)
else self.doc_to_choice(doc)[doc_content],
}
)
chat_history.append(
{
"role": "assistant",
"content": prefix + str(doc_target[0])
if isinstance(doc_target, list)
else prefix + doc_target
if self.config.doc_to_choice is None
or isinstance(doc_target, str)
else prefix + str(self.doc_to_choice(doc)[doc_target]),
}
)
else:
# get fewshot context as one user turn
chat_history.append(
{
"role": "user",
"content": self.get_context(
doc, num_fewshot, gen_prefix=gen_prefix
),
}
)
return chat_history
# @classmethod
# def from_fewshot_dfg(cls, cfg: FewshotConfig):
# if not
def sample(self, n: int) -> Sequence[dict]:
if self.fewshot_indices and self.docs:
self.docs = [self.docs[i] for i in self.fewshot_indices]
def sample(
self, n: int, doc: dict[str, Any] | None = None, **kwargs
) -> Sequence[dict]:
"""
Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
Sample n documents from the pool.
Args:
n: Number of documents to sample
doc: Optional document to exclude from sampling
Returns:
List of sampled documents
"""
assert self.rnd is not None, (
"Error: `rnd` must be set to a random.Random instance before sampling."
if n <= 0:
return []
return (
self.rnd.sample(self.docs, n)
if not doc
else self.remove_doc(doc, self.rnd.sample(self.docs, n + 1))
)
return self.rnd.sample(self.docs, n)
def set_rnd(self, rnd: int) -> None:
self.rnd = Random(rnd)
@staticmethod
def remove_doc(doc: _T, _iter: Iterable[_T]) -> list[_T]:
return [x for x in _iter if x != doc]
class FirstNSampler(ContextSampler):
def sample(self, n: int) -> Sequence[dict[str, Any]]:
def sample(self, n: int, doc=None, **kwargs):
"""
Draw the first `n` samples in order from the specified split.
Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
......@@ -214,7 +72,7 @@ class FirstNSampler(ContextSampler):
class BalancedSampler(ContextSampler):
def sample(self, n: int):
def sample(self, n: int, doc=None, **kwargs):
"""
TODO: this should return approximately class-balanced samples from our fewshot examples.
TODO: what order should they be in? maybe random?
......@@ -224,7 +82,7 @@ class BalancedSampler(ContextSampler):
class ManualSampler(ContextSampler):
def sample(self, n: int):
def sample(self, n: int, doc=None, **kwargs):
""" """
raise NotImplementedError
......
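The reshaped sampler API above moves exclusion of the evaluated doc into `sample` itself. A stand-alone sketch of that contract (the function name and toy pool are illustrative):

```python
import random

def sample(docs, n, rnd, doc=None):
    # When the fewshot split doubles as the eval split, the current doc is
    # passed in, one extra example is drawn, and the doc is filtered back out
    # (mirroring ContextSampler.remove_doc in the hunk above).
    if n <= 0:
        return []
    if doc is None:
        return rnd.sample(docs, n)
    return [x for x in rnd.sample(docs, n + 1) if x != doc]

pool = [{"q": f"q{i}", "a": f"a{i}"} for i in range(10)]
rnd = random.Random(1234)  # same pattern as the new set_rnd helper
print(sample(pool, 3, rnd, doc=pool[0]))
```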
......@@ -7,8 +7,15 @@ import random
import re
from collections.abc import Callable, Iterable, Iterator, Mapping
from copy import deepcopy
from functools import cached_property
from typing import TYPE_CHECKING, Any, Literal, overload
from dataclasses import dataclass
from functools import cached_property, partial
from typing import (
TYPE_CHECKING,
Any,
Literal,
cast,
overload,
)
import datasets
import numpy as np
......@@ -17,12 +24,12 @@ from typing_extensions import deprecated
from lm_eval import utils
from lm_eval.api.instance import Instance, OutputType
from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
from lm_eval.api.utils import check_gold_index_error
from lm_eval.api.samplers import ContextSampler
from lm_eval.caching.cache import load_from_cache, save_to_cache
from lm_eval.config.metric import MetricConfig
from lm_eval.config.task import DataSet, TaskConfig
from lm_eval.filters import build_filter_ensemble
from lm_eval.utils import validate_index
ALL_OUTPUT_TYPES = [
......@@ -39,6 +46,16 @@ if TYPE_CHECKING:
eval_logger = logging.getLogger(__name__)
@dataclass
class Message:
role: str # "system" | "user" | "assistant"
content: str
def format_turn(content: str, role: str):
return {"role": role, "content": content}
class Task(abc.ABC):
"""A task represents an entire benchmark including its dataset, problems,
answers, and evaluation methods. See BoolQ for a simple example implementation
......@@ -99,6 +116,8 @@ class Task(abc.ABC):
self.fewshot_rnd: random.Random | None = (
None # purposely induce errors in case of improper usage
)
self.sampler = ContextSampler(list(self.fewshot_docs))
self.multiple_input = False
def download(
self,
......@@ -241,8 +260,10 @@ class Task(abc.ABC):
def doc_to_audio(self, doc: dict):
raise NotImplementedError
def doc_to_prefix(self, doc: dict) -> str:
return ""
@staticmethod
def resolve_field(doc: dict[str, str], field: str | None = None):
if field is not None:
return doc[field] if field in doc else utils.apply_template(field, doc)
def build_all_requests(
self,
......@@ -322,7 +343,7 @@ class Task(abc.ABC):
apply_chat_template=apply_chat_template,
fewshot_as_multiturn=fewshot_as_multiturn,
chat_template=chat_template,
gen_prefix=self.doc_to_prefix(doc),
gen_prefix=self.resolve_field(doc, self.config.gen_prefix),
)
# TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
......@@ -411,13 +432,13 @@ class Task(abc.ABC):
def get_config(self, key: str) -> Any:
return getattr(self._config, key, None)
@classmethod
def count_bytes(cls, doc: str) -> int:
@staticmethod
def count_bytes(doc: str) -> int:
"""Used for byte-level perplexity metrics in rolling loglikelihood"""
return len(doc.encode("utf-8"))
@classmethod
def count_words(cls, doc: str) -> int:
@staticmethod
def count_words(doc: str) -> int:
"""Downstream loglikelihood_rolling perplexity tasks with custom word boundaries should override this!"""
return len(re.split(r"\s+", doc))
......@@ -525,9 +546,8 @@ class Task(abc.ABC):
self._config.process_results = lambda *args: {"bypass": 0}
def set_fewshot_seed(self, seed: int | None = None) -> None:
self.fewshot_rnd = random.Random(seed)
if hasattr(self, "sampler"):
self.sampler.rnd = self.fewshot_rnd
self.sampler.set_rnd(seed)
@property
def eval_docs(self) -> datasets.Dataset | Iterable[dict]:
......@@ -587,6 +607,7 @@ class ConfigurableTask(Task):
) -> None:
# Get pre-configured attributes
self._config = self.CONFIG
self.fewshot_rnd = 1234
# Use new configurations if there was no preconfiguration
if self.config is None:
......@@ -611,6 +632,12 @@ class ConfigurableTask(Task):
)
self.OUTPUT_TYPE = self.config.output_type
self.multiple_targets = self.config.multiple_targets
self.multiple_inputs = self.config.multiple_inputs
assert not (self.multiple_targets and self.multiple_inputs), (
"Cannot have both multiple_targets and multiple_inputs"
)
if self.config.doc_to_image is not None:
# mark the task as requiring multimodality.
self.MULTIMODAL = True
......@@ -650,7 +677,7 @@ class ConfigurableTask(Task):
):
self.fewshot_rnd = random.Random()
self.sampler = self.config.fewshot_cfg.init_sampler(
list(self.fewshot_docs()), self, rnd=self.fewshot_rnd
list(self.fewshot_docs()), rnd=self.fewshot_rnd
)
self.task_docs = self.eval_docs
......@@ -667,9 +694,7 @@ class ConfigurableTask(Task):
self.runtime_checks(self.task_docs[0])
def download(
self, dataset_kwargs:dict[str, Any] | None = None, **kwargs
) -> None:
def download(self, dataset_kwargs: dict[str, Any] | None = None, **kwargs) -> None:
from packaging.version import parse as vparse
self.config.dataset_kwargs, self.config.metadata = (
......@@ -748,176 +773,6 @@ class ConfigurableTask(Task):
return super().fewshot_docs()
@staticmethod
def append_target_question(
labeled_examples: list[dict[str, str]],
question: str,
fewshot_as_multiturn: bool = False,
gen_prefix: str | None = None,
) -> None:
"""Adds a target question to the labeled examples list.
If fewshot_as_multiturn is True, or labeled_examples is empty, or the last entry is a system turn, appends the question as a new user entry.
Otherwise, it is appended to the last user entry, ensuring that the conversation alternates between the user and the assistant.
"""
if not fewshot_as_multiturn:
# if no messages or last message is system, append as new user entry
if len(labeled_examples) == 0 or labeled_examples[-1]["role"] == "system":
labeled_examples.append({"role": "user", "content": question})
# if last message is user, append to it to avoid two user messages in a row
else:
labeled_examples[-1]["content"] += question
else:
# if fewshot_as_multiturn is True, append as next user entry (last is always assistant)
labeled_examples.append({"role": "user", "content": question})
if gen_prefix:
labeled_examples.append({"role": "assistant", "content": gen_prefix})
@utils.positional_deprecated
def fewshot_context(
self,
doc: dict,
num_fewshot: int,
system_instruction: str | None = None,
apply_chat_template: bool = False,
fewshot_as_multiturn: bool = False,
chat_template: Callable | None = None,
gen_prefix: str | None = None,
) -> str | list[str] | None:
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
:param doc: str
The document as returned from training_docs, validation_docs, or test_docs.
:param num_fewshot: int
The number of fewshot examples to provide in the returned context string.
:param system_instruction: str
System instruction to be applied to the prompt.
:param apply_chat_template: bool
Whether to apply the chat template to the fewshot context.
:param fewshot_as_multiturn: bool
Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
:param chat_template:
callable (from lm.apply_chat_template) that takes in a list[Dict] chat transcript and renders it into a string.
:param gen_prefix:
String to append after the <|assistant|> token.
:returns: str
The fewshot context.
"""
labeled_examples = [] if apply_chat_template else ""
# get task description
if description := self.config.description:
description = utils.apply_template(self.config.description, doc)
# create system prompt based on the provided system instruction and description
if system_instruction is not None and description:
system_prompt = (
f"{system_instruction}{self.sampler.fewshot_delimiter}{description}"
)
elif system_instruction is not None:
system_prompt = system_instruction
elif description:
system_prompt = description
else:
system_prompt = ""
# add system prompt if specified
if system_prompt:
if apply_chat_template:
labeled_examples.append({"role": "system", "content": system_prompt})
else:
labeled_examples = system_prompt
# if few-shot - append examples after the system prompt
if num_fewshot > 0:
if apply_chat_template:
labeled_examples.extend(
self.sampler.get_chat_context(
doc,
num_fewshot,
fewshot_as_multiturn,
gen_prefix=gen_prefix,
)
)
else:
labeled_examples += self.sampler.get_context(
doc, num_fewshot, gen_prefix=gen_prefix
)
example = self.doc_to_text(doc)
if apply_chat_template:
if self.multiple_input:
# TODO: append prefill?
if not labeled_examples:
return ""
return chat_template(labeled_examples)
if isinstance(example, str):
self.append_target_question(
labeled_examples,
example,
fewshot_as_multiturn,
gen_prefix=gen_prefix,
)
# for loglikelihood create a list of questions with appended choices
elif isinstance(example, list):
labeled_examples_list = []
# copy chat history for each example and append the answer
for ex in example:
chat = deepcopy(labeled_examples)
self.append_target_question(
chat,
ex,
fewshot_as_multiturn,
gen_prefix=gen_prefix,
)
# TODO: append prefill?
labeled_examples_list.append(
chat_template(
chat,
add_generation_prompt=not gen_prefix,
)
)
return labeled_examples_list
# if example is an integer, append the choice or convert to string
elif isinstance(example, int):
if self.config.doc_to_choice is not None:
choices = self.doc_to_choice(doc)
self.append_target_question(
labeled_examples,
choices[example],
fewshot_as_multiturn,
gen_prefix=gen_prefix,
)
else:
self.append_target_question(
labeled_examples,
str(example),
fewshot_as_multiturn,
gen_prefix=gen_prefix,
)
# return lm.apply_chat_template(labeled_examples)
return chat_template(
labeled_examples,
add_generation_prompt=not gen_prefix,
)
else:
prefix = (
self.config.target_delimiter + gen_prefix
if gen_prefix is not None
else ""
)
if self.multiple_input:
return labeled_examples
if isinstance(example, str):
return labeled_examples + example + prefix
elif isinstance(example, list):
return [labeled_examples + ex + prefix for ex in example]
elif isinstance(example, int):
if self.config.doc_to_choice is not None:
choices = self.doc_to_choice(doc)
return labeled_examples + choices[example] + prefix
else:
return labeled_examples + str(example) + prefix
def apply_filters(self) -> list[Instance] | None:
"""Iterates over FilterEnsembles and applies them to instances"""
if hasattr(self, "_filters") and self._instances:
......@@ -927,7 +782,7 @@ class ConfigurableTask(Task):
eval_logger.warning(
"No filter defined or instances found. Passing through instances"
)
return self._instances
return self._instances
def should_decontaminate(self):
return self.config.should_decontaminate
......@@ -1178,19 +1033,169 @@ class ConfigurableTask(Task):
else:
return None
def doc_to_prefix(self, doc: dict) -> str | None:
if (gen_prefix := self.config.gen_prefix) is not None:
if gen_prefix in doc:
return doc[gen_prefix]
def _doc_to_qa_pair(
self,
doc: dict[str, Any],
gen_prefix: str | None,
*,
q: str | None = None,
a: str | None = None,
include_answer: bool = True,
) -> list[Message]:
"""Return `[user, assistant?]` for a single doc."""
q = q or self.doc_to_text(doc)
a = a or self.doc_to_target(doc)
# Handle multiple-choice indirection
if isinstance(q, list) and self.config.doc_to_choice:
q = q[cast(int, self.doc_to_target(doc))]
if isinstance(a, int) and self.config.doc_to_choice:
a = (
self.doc_to_choice(doc)[a]
if not self.multiple_inputs
else self.doc_to_choice(doc)[0]
)
assert isinstance(q, str), "Context is not a string!"
msgs = [Message("user", q)]
if include_answer:
if gen_prefix and not gen_prefix[-1].isspace():
prefix = gen_prefix + " "
elif gen_prefix:
prefix = gen_prefix
else:
return utils.apply_template(gen_prefix, doc)
return None
prefix = ""
answer_txt = prefix + (a if not isinstance(a, list) else a[0])
msgs.append(Message("assistant", answer_txt))
else:
msgs.append(Message("assistant", gen_prefix)) if gen_prefix else None
return msgs
@staticmethod
def _render_chat_template(
messages: list[Message],
chat_template: Callable[[list[dict[str, str]]], str],
*,
tgt_delim: str = " ",
few_delim: str = "\n\n",
multiturn=True,
) -> str:
if multiturn:
return chat_template([m.__dict__ for m in messages])
else:
has_prefix = messages[-1].role == "assistant"
if not has_prefix:
context = [
format_turn(
ConfigurableTask._message_to_text(
messages, tgt_delim=tgt_delim, few_delim=few_delim
),
role="user",
)
]
else:
context = [
format_turn(
ConfigurableTask._message_to_text(
messages[:-1], tgt_delim=tgt_delim, few_delim=few_delim
),
role="user",
)
]
context += [format_turn(**messages[-1].__dict__)]
return chat_template(context)
def fewshot_context(
self,
doc: dict[str, str],
num_fewshot: int,
system_instruction: str | None = None,
apply_chat_template: bool = False,
fewshot_as_multiturn: bool = False,
chat_template: Callable[..., str] | None = None,
gen_prefix: str | None = None,
) -> str | list[str]:
messages = []
tgt_delim, few_delim = (
self.config.target_delimiter,
self.config.fewshot_delimiter,
)
chat_template = (
partial(chat_template, add_generation_prompt=not gen_prefix)
if chat_template
else None
)
description = self.resolve_field(doc, self.config.description) or ""
system_prompt = few_delim.join(filter(None, [system_instruction, description]))
if system_prompt:
messages.append(Message("system", system_prompt))
for fs_doc in self.sampler.sample(
n=num_fewshot,
doc=doc if self.config.fewshot_split == self.config.test_split else None,
):
messages += self._doc_to_qa_pair(fs_doc, gen_prefix)
if self.multiple_inputs:
# if multiple inputs, then doc_to_text: list[str]
messages = [
messages
+ self._doc_to_qa_pair(
doc,
gen_prefix,
q=q,
include_answer=False,
)
for q in cast(list[str], self.doc_to_text(doc))
]
else:
# otherwise, doc_to_text: str for all other cases
messages += self._doc_to_qa_pair(doc, gen_prefix, include_answer=False)
messages = [messages]
if apply_chat_template and chat_template:
res = [
self._render_chat_template(
m,
chat_template,
tgt_delim=tgt_delim,
few_delim=few_delim,
multiturn=fewshot_as_multiturn,
)
for m in messages
]
else:
res = [
self._message_to_text(m, tgt_delim=tgt_delim, few_delim=few_delim)
for m in messages
]
return res[0] if not self.multiple_inputs else res
@staticmethod
def _message_to_text(
messages: list[Message],
*,
tgt_delim=" ",
few_delim="\n\n",
) -> str:
buff = []
for i, m in enumerate(messages):
if m.role == "system" or m.role == "user":
buff.append(m.content)
elif m.role == "assistant":
buff.append(tgt_delim + m.content)
if i != len(messages) - 1:
# then this is not an assistant prefill
buff.append(few_delim)
return "".join(buff)
def construct_requests(
self, doc: dict, ctx: str, **kwargs
self, doc: dict[str, str], ctx: str | list[str], **kwargs
) -> list[Instance] | Instance:
apply_chat_template = kwargs.pop("apply_chat_template", False)
chat_template: Callable | None = kwargs.pop("chat_template", None)
chat_template: Callable | None = kwargs.pop("chat_template", None) # noqa: F841
aux_arguments = None
......@@ -1200,31 +1205,18 @@ class ConfigurableTask(Task):
arguments = (self.doc_to_target(doc),)
elif self.OUTPUT_TYPE == "multiple_choice":
choices = self.doc_to_choice(doc)
target_delimiter = self.config.target_delimiter
if apply_chat_template:
target_delimiter = ""
if self.multiple_input:
# If there are multiple inputs, choices are placed in the ctx
# apply chat_template to choices if apply_chat_template
cont = self.doc_to_target(doc)
arguments = [
(
ctx
+ (
chat_template([{"role": "user", "content": choice}])
if apply_chat_template
else choice
),
f"{target_delimiter}{cont}",
)
for choice in choices
]
target_delimiter = (
""
if (apply_chat_template and not self.config.gen_prefix)
else self.config.target_delimiter
)
if self.multiple_inputs:
# If there are multiple inputs, assume only one choice
arguments = [(_ctx, f"{target_delimiter}{choices[0]}") for _ctx in ctx]
else:
# Otherwise they are placed in the continuation
arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]
# TODO: we should raise a warning telling users this will at most ~2x runtime.
if "acc_mutual_info" in [m.metric_name for m in self.config._metric_list]:
# if we are calculating multiple choice accuracy
# using mutual information instead of raw loglikelihood as metric, need unconditional lls.
......@@ -1232,7 +1224,6 @@ class ConfigurableTask(Task):
# here mutual info refers to calculating
# log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
# in other words normalizing by subtracting the unconditional logprob of each choice.
# TODO: should these be strided? will have to modify the processing in process_results if so
aux_arguments = [
("", f"{target_delimiter}{choice}") for choice in choices
]
......@@ -1327,7 +1318,11 @@ class ConfigurableTask(Task):
lls, is_greedy = zip(*results)
# retrieve choices in list[str] form, to compute choice lengths, etc.
choices = self.doc_to_choice(doc)
choices = (
self.doc_to_choice(doc)
if not self.multiple_inputs
else cast(list[str], self.doc_to_text(doc))
)
completion_len = np.array([float(len(i)) for i in choices])
if 2 * len(choices) == len(lls) and "acc_mutual_info" in use_metric:
......@@ -1345,20 +1340,26 @@ class ConfigurableTask(Task):
pred = np.argmax(lls)
pred_norm = np.argmax(lls / completion_len)
if self.multiple_input:
gold = self.doc_to_text(doc)
gold = backup = self.doc_to_target(doc)
if isinstance(gold, list):
gold = [validate_index(g, len(choices)) for g in gold]
gold_index_error = -100 in gold
else:
gold = self.doc_to_target(doc)
if isinstance(gold, int):
gold = validate_index(gold, len(choices))
elif isinstance(gold, str):
gold = choices.index(gold) if gold in choices else -100
gold, gold_index_error = check_gold_index_error(choices, gold)
gold_index_error = gold == -100
if gold_index_error:
eval_logger.warning(
f"Label index was not in within range of available choices,"
f"Label [{backup}] index was not in within range of available choices {choices},"
f"Sample:\n\n{doc}\n\n"
)
if self.multiple_target:
if self.multiple_targets:
acc = 1.0 if pred in gold else 0.0
acc_norm = 1.0 if pred_norm in gold else 0.0
exact_match = int(any(is_greedy[i] if i != -100 else 0 for i in gold))
......@@ -1408,7 +1409,6 @@ class ConfigurableTask(Task):
except TypeError: # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
result_score = metric.fn([gold, result])
if isinstance(result_score, dict):
# TODO: this handles the case where HF evaluate returns a dict.
# This allows for multiple metrics to be returned from the same function
for k, v in result_score.items():
result_dict[k] = v
......
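To make the new plain-text rendering path concrete, here is a self-contained copy of the `Message` / `_message_to_text` pair from the hunk above (the standalone function name and the sample questions are illustrative):

```python
from dataclasses import dataclass

@dataclass
class Message:  # mirrors the dataclass added in task.py
    role: str
    content: str

def message_to_text(messages, tgt_delim=" ", few_delim="\n\n"):
    # system/user content is emitted as-is; assistant content is prefixed with
    # the target delimiter; the fewshot delimiter only follows completed
    # examples, so a trailing assistant message acts as a generation prefill.
    buff = []
    for i, m in enumerate(messages):
        if m.role in ("system", "user"):
            buff.append(m.content)
        elif m.role == "assistant":
            buff.append(tgt_delim + m.content)
            if i != len(messages) - 1:
                buff.append(few_delim)
    return "".join(buff)

msgs = [
    Message("user", "Q: Where is the Eiffel Tower?\nA:"),
    Message("assistant", "Paris"),
    Message("user", "Q: Where is Big Ben?\nA:"),
]
print(message_to_text(msgs))
```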
......@@ -4,7 +4,7 @@ import textwrap
from argparse import Namespace
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union
import yaml
......@@ -214,7 +214,7 @@ class EvaluatorConfig:
# Parse string arguments that should be dictionaries
config = cls._parse_dict_args(config)
# Create instance and validate
# Create an instance and validate
instance = cls(**config)
if used_config:
print(textwrap.dedent(f"""{instance}"""))
......@@ -238,7 +238,7 @@ class EvaluatorConfig:
return instance
@staticmethod
def _parse_dict_args(config: Dict[str, Any]) -> Dict[str, Any]:
def _parse_dict_args(config: dict[str, Any]) -> dict[str, Any]:
"""Parse string arguments that should be dictionaries."""
for key in config:
if key in DICT_KEYS and isinstance(config[key], str):
......@@ -246,7 +246,7 @@ class EvaluatorConfig:
return config
@staticmethod
def load_yaml_config(config_path: Union[str, Path]) -> Dict[str, Any]:
def load_yaml_config(config_path: Union[str, Path]) -> dict[str, Any]:
"""Load and validate YAML config file."""
config_file = (
Path(config_path) if not isinstance(config_path, Path) else config_path
......@@ -257,9 +257,9 @@ class EvaluatorConfig:
try:
yaml_data = yaml.safe_load(config_file.read_text())
except yaml.YAMLError as e:
raise ValueError(f"Invalid YAML in {config_path}: {e}")
raise ValueError(f"Invalid YAML in {config_path}: {e}") from e
except (OSError, UnicodeDecodeError) as e:
raise ValueError(f"Could not read config file {config_path}: {e}")
raise ValueError(f"Could not read config file {config_path}: {e}") from e
if not isinstance(yaml_data, dict):
raise ValueError(
......@@ -307,7 +307,7 @@ class EvaluatorConfig:
raise ValueError("Need to specify task to evaluate.")
def _process_arguments(self) -> None:
"""Process samples argument - load from file if needed."""
"""Process samples argument - load from a file if needed."""
if self.samples:
if isinstance(self.samples, dict):
self.samples = self.samples
......@@ -328,7 +328,6 @@ class EvaluatorConfig:
def process_tasks(self, metadata: Optional[dict] = None) -> "TaskManager":
"""Process and validate tasks, return resolved task names."""
from lm_eval import utils
from lm_eval.tasks import TaskManager
# if metadata manually passed use that:
......@@ -365,7 +364,7 @@ class EvaluatorConfig:
return task_manager
def _set_trust_remote_code(self) -> None:
"""Apply trust_remote_code setting if enabled."""
"""Apply the trust_remote_code setting if enabled."""
if self.trust_remote_code:
# HACK: import datasets and override its HF_DATASETS_TRUST_REMOTE_CODE value internally,
# because it's already been determined based on the prior env var before launching our
......
......@@ -21,7 +21,7 @@ def serialize_callable(
return value
else:
try:
return getsource(value)
return getsource(value) # type: ignore
except (TypeError, OSError):
return str(value)
......
......@@ -60,10 +60,10 @@ def _load_module_with_cache(module_path: Path) -> Any:
module_parts = relative_path.replace(".py", "").replace("/", ".")
module_name = f"lm_eval.tasks.{module_parts}"
else:
# Fallback to full path if pattern not found
# Fallback to a full path if the pattern is not found
module_name = str(module_path.with_suffix(""))
else:
# External module - use full path without extension
# External module - use a full path without extension
module_name = str(module_path.with_suffix(""))
# Check if we need to reload the module
......@@ -84,7 +84,7 @@ def _load_module_with_cache(module_path: Path) -> Any:
raise ImportError(f"Cannot load module from {module_path}") from None
module = importlib.util.module_from_spec(spec)
# Store mtime for future checks
module.__mtime__ = module_path.stat().st_mtime_ns
module.__mtime__ = module_path.stat().st_mtime_ns # type: ignore
spec.loader.exec_module(module) # type: ignore[arg-type]
sys.modules[module_name] = module
return module
......
......@@ -32,9 +32,9 @@ class TaskFactory:
registry: Mapping[str, Entry],
):
"""
• entry.kind == TASK / PY_TASK ➜ returns instantiated task object
• entry.kind == GROUP ➜ returns (GroupConfig, mapping-of-subtasks)
• entry.kind == TAG ➜ returns mapping-of-tasks (tag expansion)
• entry.kind == TASK / PY_TASK ➜ returns instantiated task object
• entry.kind == GROUP ➜ returns (GroupConfig, mapping-of-subtasks)
• entry.kind == TAG ➜ returns mapping-of-tasks (tag expansion)
"""
if entry.kind is Kind.TAG:
return self._build_tag(entry, overrides, registry)
......@@ -121,4 +121,4 @@ class TaskFactory:
def _ctor_accepts_config(cls) -> bool:
init = getattr(cls, "__init__", None)
return init and "config" in inspect.signature(init).parameters
return bool(init and "config" in inspect.signature(init).parameters)
......@@ -61,8 +61,8 @@ class TaskManager:
def load_spec(self, spec: str | dict[str, Any]):
"""Spec can be:
• str task / group / tag name (registered)
• dict inline overrides {'task': 'hellaswag', 'num_fewshot': 5}
• str task / group / tag name (registered)
• dict inline overrides {'task': 'hellaswag', 'num_fewshot': 5}
"""
if isinstance(spec, str):
entry = self._entry(spec)
......
......@@ -8,8 +8,8 @@ Homepage: [google-research-datasets/natural-questions@master/nq_open](https://gi
Paper: [aclanthology.org/P19-1612](https://aclanthology.org/P19-1612/)
Derived from the Natural Questions dataset, introduced in https://storage.googleapis.com/gweb-research2023-media/pubtools/pdf/1f7b46b5378d757553d3e92ead36bda2e4254244.pdf .
Derived from the Natural Questions dataset, introduced
in https://storage.googleapis.com/gweb-research2023-media/pubtools/pdf/1f7b46b5378d757553d3e92ead36bda2e4254244.pdf .
### Citation
......@@ -26,4 +26,5 @@ journal = {Transactions of the Association of Computational Linguistics}}
* `nq_open`
### Changelog
* 2025-07-21: Added `multi_target` to `exact_match`. Scores should not change.
* 2025-07-21: Added `multiple_targets` to `exact_match`. Scores should not change.
......@@ -5,7 +5,7 @@ training_split: train
validation_split: validation
description: "Answer these questions:\n\n"
doc_to_text: "Q: {{question}}?\nA:"
doc_to_target: "{{answer}}"
doc_to_target: answer
fewshot_delimiter: "\n"
generation_kwargs:
until:
......@@ -27,7 +27,7 @@ metric_list:
ignore_case: true
ignore_punctuation: true
regexes_to_ignore:
- "\\b(?:The |the |An |A |The |a |an )"
multi_target: true
- "\\b(?:The |the |An |A |The |a |an )"
multiple_targets: true
metadata:
version: 4.0
......@@ -79,3 +79,6 @@ If other tasks on this dataset are already supported:
* [ ] Is the "Main" variant of this task clearly denoted?
* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
### Changelog
- 2025-07-22: `record` and `multirc`: set target_delimiter to "" and trim doc_to_text respectively.
tag:
- super-glue-lm-eval-v1
task: boolq
dataset_path: super_glue
dataset_path: aps/super_glue
dataset_name: boolq
output_type: multiple_choice
training_split: train
......
tag:
- super-glue-lm-eval-v1-seq2seq
task: "boolq-seq2seq"
dataset_path: super_glue
dataset_path: aps/super_glue
dataset_name: boolq
output_type: generate_until
training_split: train
validation_split: validation
doc_to_text: "{{passage}}\nQuestion: {{question}}?\nAnswer:"
doc_to_target: label
doc_to_choice: [' no', ' yes']
doc_to_target: "{{ [' no', ' yes'][label|int] }}"
target_delimiter: ""
generation_kwargs:
until:
......
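The seq2seq variants above replace `doc_to_choice` plus an integer label with a Jinja template that resolves the label straight to its answer string. A quick illustration of how such a template renders (lm-eval applies it through its own template helper; the standalone jinja2 call here is only for demonstration):

```python
from jinja2 import Template

doc = {"question": "is the sky blue", "label": 1}
target = Template("{{ [' no', ' yes'][label|int] }}").render(**doc)
print(repr(target))  # "' yes'" -- the leading space pairs with target_delimiter: ""
```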
tag:
- super-glue-t5-prompt
task: super_glue-boolq-t5-prompt
dataset_path: super_glue
dataset_path: aps/super_glue
dataset_name: boolq
training_split: train
validation_split: validation
output_type: generate_until
doc_to_text: "boolq passage: {{passage}} question: {{question}}"
doc_to_target: label
doc_to_choice: ['False', 'True']
doc_to_target: "{{['False', 'True'][label|int]}}"
generation_kwargs:
until:
- "</s>"
......
tag:
- super-glue-lm-eval-v1
task: cb
dataset_path: super_glue
dataset_path: aps/super_glue
dataset_name: cb
output_type: multiple_choice
training_split: train
validation_split: validation
doc_to_text: "{{premise}}\nQuestion: {{hypothesis}}. True, False, or Neither?\nAnswer:"
doc_to_target: label
doc_to_choice: ['True', 'False', 'Neither']
doc_to_choice: ["True", "False", "Neither"]
metric_list:
- metric: acc
- metric: f1
......
tag:
- super-glue-t5-prompt
task: super_glue-cb-t5-prompt
dataset_path: super_glue
dataset_path: aps/super_glue
dataset_name: cb
training_split: train
validation_split: validation
output_type: generate_until
doc_to_text: "cb hypothesis: {{hypothesis}} premise: {{premise}}"
doc_to_target: label
doc_to_choice: ['entailment', 'contradiction', 'neutral']
doc_to_target: "{{ ['entailment', 'contradiction', 'neutral'][label|int] }}"
generation_kwargs:
until:
- "</s>"
......