Unverified Commit 003e5852 authored by Baber Abbasi, committed by GitHub

Fewshot refactor (#3227)



* overhaul `ContextSampler`

* refactor masakhapos

* move multi_target to `exact_match`

* remove doc_to_choice from `boolq-seq2seq`

* remove doc_to_choice in generation process_results

* Remove unused `doc_to_choice` and fix superglue whitespaces

* require multiple_inputs and multiple_targets to be explicitly set in taskconfig

* fix copa; better logging in task init

* fix doc_to_target to return int rather than str (deprecated)

* fix processing regression; recursively parse lists from template

* remove redundant jinja parsing logic

* remove promptsource

* for multiple_inputs use `doc_to_text: list[str]`

* Refactor `ContextSampler` `fewshot_context`

* fix multiple_input context

* fix `target_delimiter` with `gen_prefix`

* `doc_to_text` is list for multiple_inputs

* Refactor `count_bytes` and `count_words` methods to `@staticmethod`

* make has_*(train/test/validation) into properties

* remove `multi_target` from `generate_until`

* fix doc_to_target/multiple_targets handling; add tests

* rename `multi_target` to `multiple_targets`

* evaluate list when multiple targets

* allow doc_to_target to return list

* Remove gen_prefix space and add warning (#3239)

* Remove gen_prefix space and add warning

* fix null gen_prefix bug again

* use git tests

---------
Co-authored-by: Boaz Ben-Dov <bendboaz@gmail.com>
parent 79a22a11
...@@ -8,8 +8,6 @@ on: ...@@ -8,8 +8,6 @@ on:
branches: branches:
- 'main' - 'main'
pull_request: pull_request:
branches:
- 'main'
workflow_dispatch: workflow_dispatch:
# Jobs run concurrently and steps run sequentially within a job. # Jobs run concurrently and steps run sequentially within a job.
# jobs: linter and cpu_tests. Add more jobs/steps as required. # jobs: linter and cpu_tests. Add more jobs/steps as required.
......
...@@ -53,6 +53,6 @@ class FilterEnsemble: ...@@ -53,6 +53,6 @@ class FilterEnsemble:
resps = f().apply(resps, docs) resps = f().apply(resps, docs)
# add the end results after filtering to filtered_requests of their respective source instances. # add the end results after filtering to filtered_requests of their respective source instances.
# has key `self.name`: each FilterEnsemble applied in a given run should use a different name. # has a key ` self.name `: each FilterEnsemble applied in a given run should use a different name.
for inst, resp in zip(instances, resps): for inst, resp in zip(instances, resps):
inst.filtered_resps[self.name] = resp inst.filtered_resps[self.name] = resp
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Literal, Optional, Tuple from typing import Any, Literal, Optional
OutputType = Literal[ OutputType = Literal[
...@@ -10,10 +10,10 @@ OutputType = Literal[ ...@@ -10,10 +10,10 @@ OutputType = Literal[
@dataclass @dataclass
class Instance: class Instance:
request_type: OutputType request_type: OutputType
doc: dict doc: dict[str, Any]
arguments: tuple arguments: tuple
idx: int idx: int
metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field( metadata: tuple[Optional[str], Optional[int], Optional[int]] = field(
default_factory=lambda: (None, None, None), default_factory=lambda: (None, None, None),
metadata=dict( metadata=dict(
description="Metadata tuple containing task name, document ID, and number of repeats." description="Metadata tuple containing task name, document ID, and number of repeats."
......
...@@ -213,7 +213,7 @@ def exact_match_hf_evaluate( ...@@ -213,7 +213,7 @@ def exact_match_hf_evaluate(
ignore_case: bool = False, ignore_case: bool = False,
ignore_punctuation: bool = False, ignore_punctuation: bool = False,
ignore_numbers: bool = False, ignore_numbers: bool = False,
multi_target: bool = False, multiple_targets: bool = False,
): ):
""" """
Compute exact match scores between predictions and references. Compute exact match scores between predictions and references.
...@@ -245,8 +245,8 @@ def exact_match_hf_evaluate( ...@@ -245,8 +245,8 @@ def exact_match_hf_evaluate(
- "exact_match" (float): The mean exact match score or 1.0/0.0 if `multi_target` is True. - "exact_match" (float): The mean exact match score or 1.0/0.0 if `multi_target` is True.
""" """
predictions, references = list(predictions), list(references) predictions, references = list(predictions), list(references)
assert len(predictions) == len(references) if not multi_target else True, ( assert len(predictions) == len(references) if not multiple_targets else True, (
"predictions and references must have the same length unless `multi_target` is True" "predictions and references must have the same length unless `multiple_targets` is True"
) )
if regexes_to_ignore is not None: if regexes_to_ignore is not None:
...@@ -275,7 +275,7 @@ def exact_match_hf_evaluate( ...@@ -275,7 +275,7 @@ def exact_match_hf_evaluate(
return { return {
"exact_match": np.mean(score_list) "exact_match": np.mean(score_list)
if not multi_target if not multiple_targets
else float(np.any(score_list)) else float(np.any(score_list))
} }
......
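Since the diff only shows the renamed flag, here is a minimal sketch of what the two aggregation branches compute. This is standalone illustration, not the library function; the score list is made up:

```python
import numpy as np

def aggregate_exact_match(score_list, multiple_targets: bool = False) -> dict:
    # Mirrors the branch above: mean over prediction/reference pairs normally,
    # 1.0/0.0 ("did any reference match?") when `multiple_targets` is set.
    return {
        "exact_match": float(np.mean(score_list))
        if not multiple_targets
        else float(np.any(score_list))
    }

# One prediction scored against three gold answers; exactly one matched.
print(aggregate_exact_match([0, 1, 0], multiple_targets=True))   # {'exact_match': 1.0}
print(aggregate_exact_match([0, 1, 0], multiple_targets=False))  # {'exact_match': 0.33...}
```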
...@@ -220,8 +220,8 @@ class Registry(Generic[T]): ...@@ -220,8 +220,8 @@ class Registry(Generic[T]):
>>> model_registry.register("lazy-name", lazy="mymodule:MyModel") >>> model_registry.register("lazy-name", lazy="mymodule:MyModel")
Raises: Raises:
ValueError: If alias already registered with different target ValueError: If alias is already registered with a different target
TypeError: If object doesn't inherit from base_cls (when specified) TypeError: If an object doesn't inherit from base_cls (when specified)
""" """
def _store(alias: str, target: T | Placeholder) -> None: def _store(alias: str, target: T | Placeholder) -> None:
...@@ -229,21 +229,27 @@ class Registry(Generic[T]): ...@@ -229,21 +229,27 @@ class Registry(Generic[T]):
# collision handling ------------------------------------------ # collision handling ------------------------------------------
if current is not None and current != target: if current is not None and current != target:
# allow placeholder → real object upgrade # allow placeholder → real object upgrade
if isinstance(current, str) and isinstance(target, type): # mod, _, cls = current.partition(":")
# mod, _, cls = current.partition(":") if (
if current == f"{target.__module__}:{target.__name__}": isinstance(current, str)
self._objs[alias] = target and isinstance(target, type)
return and current == f"{target.__module__}:{target.__name__}"
):
self._objs[alias] = target
return
raise ValueError( raise ValueError(
f"{self._name!r} alias '{alias}' already registered (" f"{self._name!r} alias '{alias}' already registered ("
f"existing={current}, new={target})" f"existing={current}, new={target})"
) )
# type check for concrete classes ---------------------------------------------- # type check for concrete classes ----------------------------------------------
if self._base_cls is not None and isinstance(target, type): if (
if not issubclass(target, self._base_cls): # type: ignore[arg-type] self._base_cls is not None
raise TypeError( and isinstance(target, type)
f"{target} must inherit from {self._base_cls} to be a {self._name}" and not issubclass(target, self._base_cls)
) ):
raise TypeError(
f"{target} must inherit from {self._base_cls} to be a {self._name}"
)
self._objs[alias] = target self._objs[alias] = target
def decorator(obj: T) -> T: # type: ignore[valid-type] def decorator(obj: T) -> T: # type: ignore[valid-type]
...@@ -409,9 +415,7 @@ class MetricSpec: ...@@ -409,9 +415,7 @@ class MetricSpec:
from lm_eval.api.model import LM # noqa: E402 from lm_eval.api.model import LM # noqa: E402
model_registry: Registry[type[LM]] = cast( model_registry = cast(Registry[type[LM]], Registry("model", base_cls=LM))
Registry[type[LM]], Registry("model", base_cls=LM)
)
task_registry: Registry[Callable[..., Any]] = Registry("task") task_registry: Registry[Callable[..., Any]] = Registry("task")
metric_registry: Registry[MetricSpec] = Registry("metric") metric_registry: Registry[MetricSpec] = Registry("metric")
metric_agg_registry: Registry[Callable[[Iterable[Any]], float]] = Registry( metric_agg_registry: Registry[Callable[[Iterable[Any]], float]] = Registry(
...@@ -457,7 +461,7 @@ def register_metric(**kw): ...@@ -457,7 +461,7 @@ def register_metric(**kw):
then registers it in the metric registry. then registers it in the metric registry.
Args: Args:
**kw: Keyword arguments including: **kw: Keyword arguments including
- metric: Name to register the metric under (required) - metric: Name to register the metric under (required)
- aggregation: Name of aggregation function in metric_agg_registry - aggregation: Name of aggregation function in metric_agg_registry
- higher_is_better: Whether higher scores are better (default: True) - higher_is_better: Whether higher scores are better (default: True)
...@@ -512,7 +516,7 @@ def get_metric(name, hf_evaluate_metric=False): ...@@ -512,7 +516,7 @@ def get_metric(name, hf_evaluate_metric=False):
The metric's compute function The metric's compute function
Raises: Raises:
KeyError: If metric not found in registry or HF evaluate KeyError: If a metric is not found in registry or HF evaluate
""" """
try: try:
spec = metric_registry.get(name) spec = metric_registry.get(name)
...@@ -529,7 +533,7 @@ def get_metric(name, hf_evaluate_metric=False): ...@@ -529,7 +533,7 @@ def get_metric(name, hf_evaluate_metric=False):
return hf.load(name).compute # type: ignore[attr-defined] return hf.load(name).compute # type: ignore[attr-defined]
except Exception: except Exception:
raise KeyError(f"Metric '{name}' not found anywhere") raise KeyError(f"Metric '{name}' not found anywhere") from None
register_metric_aggregation = metric_agg_registry.register register_metric_aggregation = metric_agg_registry.register
......
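The restructured `_store` keeps the same collision rule as before, now as a single condition: a lazy `"module:Class"` placeholder may be silently upgraded to the real class it points at, while any other clash raises. A self-contained sketch of just that rule (a module-level dict stands in for the registry; names are illustrative):

```python
_objs: dict[str, object] = {}

def store(alias: str, target: object) -> None:
    current = _objs.get(alias)
    if current is not None and current != target:
        if (
            isinstance(current, str)
            and isinstance(target, type)
            and current == f"{target.__module__}:{target.__name__}"
        ):
            _objs[alias] = target  # placeholder -> real object upgrade
            return
        raise ValueError(
            f"alias {alias!r} already registered (existing={current}, new={target})"
        )
    _objs[alias] = target

class MyModel:  # stands in for a real LM subclass
    pass

store("my-model", f"{MyModel.__module__}:{MyModel.__name__}")  # lazy placeholder
store("my-model", MyModel)                                     # upgrade, no error
print(_objs["my-model"] is MyModel)                            # True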
from __future__ import annotations from __future__ import annotations
import logging import logging
import warnings from random import Random
from collections.abc import Iterable, Sequence from typing import TYPE_CHECKING
from functools import partial
from typing import TYPE_CHECKING, Any
import datasets
if TYPE_CHECKING: if TYPE_CHECKING:
from random import Random from collections.abc import Iterable, Sequence
from typing import Any, TypeVar
from lm_eval.api.task import ConfigurableTask, Task _T = TypeVar("_T")
eval_logger = logging.getLogger("lm-eval") eval_logger = logging.getLogger(__name__)
class ContextSampler: class ContextSampler:
def __init__( def __init__(
self, self,
docs: list[dict], docs: Sequence[dict[str, Any]] | None = None,
task: Task | ConfigurableTask, *,
fewshot_indices: Iterable | None = None, rnd: int | None = None,
rnd: Random | None = None, fewshot_indices: list[int] | None = None,
**kwargs,
) -> None: ) -> None:
self.rnd = rnd self.rnd = Random(rnd)
if not self.rnd: self.docs = docs or []
raise ValueError( self.fewshot_indices = fewshot_indices
"A `random.Random` generator argument must be provided to `rnd` of FewShotSampler!"
)
self.task = task
self.config = task._config
self.target_delimiter = self.config.target_delimiter
self.fewshot_delimiter = self.config.fewshot_delimiter
if (
self.config.fewshot_config is not None
and self.config.fewshot_config.get("doc_to_text", None) is not None
):
self.doc_to_text = partial(
self.task.doc_to_text,
doc_to_text=self.config.fewshot_config.get("doc_to_text", None),
)
else:
self.doc_to_text = self.task.doc_to_text
if (
self.config.fewshot_config is not None
and self.config.fewshot_config.get("doc_to_target", None) is not None
):
self.doc_to_target = partial(
self.task.doc_to_target,
doc_to_target=self.config.fewshot_config.get("doc_to_target", None),
)
else:
self.doc_to_target = self.task.doc_to_target
if (
self.config.fewshot_config is not None
and self.config.fewshot_config.get("doc_to_choice", None) is not None
):
self.doc_to_choice = partial(
self.task.doc_to_choice,
doc_to_choice=self.config.fewshot_config.get("doc_to_choice", None),
)
else:
self.doc_to_choice = self.task.doc_to_choice
self.docs = docs # HF dataset split, provided by task._fewshot_docs()
if fewshot_indices: # subset few-shot docs from
if not isinstance(self.docs, datasets.Dataset):
raise ValueError(
"Got `fewshot_indices` but fewshot_docs are not a HF dataset. Don't use both `fewshot_indices` and a user-defined few-shot sample list simultaneously"
)
self.docs = self.docs.select(fewshot_indices)
def get_context(self, doc: dict, num_fewshot: int, gen_prefix: str | None = None):
# draw an extra fewshot sample if using same split as evaluating on
prefix = gen_prefix + " " if gen_prefix else ""
n_samples = (
num_fewshot + 1
if self.config.fewshot_split == self.config.test_split
else num_fewshot
)
# draw `n_samples` docs from fewshot_docs if self.fewshot_indices and self.docs:
fewshotex = self.sample(n_samples) self.docs = [self.docs[i] for i in self.fewshot_indices]
# get rid of the doc that's the one we're evaluating, if it's in the fewshot def sample(
# TODO: should we just stop people from using fewshot from same split as evaluating? self, n: int, doc: dict[str, Any] | None = None, **kwargs
selected_docs = [x for x in fewshotex if x != doc][:num_fewshot] ) -> Sequence[dict]:
labeled_examples = ""
for doc in selected_docs:
doc_content = self.doc_to_text(doc)
doc_target = self.doc_to_target(doc)
if (
self.config.doc_to_choice is None and isinstance(doc_content, str)
) or isinstance(doc_content, str):
labeled_examples += doc_content
else:
if isinstance(doc_content, int):
labeled_examples += self.doc_to_choice(doc)[doc_content]
if doc_target != "":
if self.target_delimiter.isspace() and str(doc_target)[0].isspace():
# TODO: add logger warn once here.
warnings.warn(
"Both target_delimiter and target start with a space. This may cause issues.",
Warning,
stacklevel=2,
)
labeled_examples += self.target_delimiter
labeled_examples += prefix
labeled_examples += (
str(doc_target[0])
if isinstance(doc_target, list)
else doc_target
if self.config.doc_to_choice is None or isinstance(doc_target, str)
else str(self.doc_to_choice(doc)[doc_target])
)
labeled_examples += self.fewshot_delimiter
return labeled_examples
def get_chat_context(
self,
doc: dict,
num_fewshot: int,
fewshot_as_multiturn: bool = False,
gen_prefix: str | None = None,
):
# TODO: Do we need any other delimiter
prefix = gen_prefix + " " if gen_prefix else ""
chat_history = []
# draw an extra fewshot sample if using same split as evaluating on
n_samples = (
num_fewshot + 1
if self.config.fewshot_split == self.config.test_split
else num_fewshot
)
# draw `n_samples` docs from fewshot_docs
fewshotex = self.sample(n_samples)
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
# TODO: should we just stop people from using fewshot from same split as evaluating?
selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]
if fewshot_as_multiturn:
for doc in selected_docs:
doc_content = self.doc_to_text(doc)
doc_target = self.doc_to_target(doc)
chat_history.append(
{
"role": "user",
"content": doc_content
if self.config.doc_to_choice is None
or isinstance(doc_content, str)
else self.doc_to_choice(doc)[doc_content],
}
)
chat_history.append(
{
"role": "assistant",
"content": prefix + str(doc_target[0])
if isinstance(doc_target, list)
else prefix + doc_target
if self.config.doc_to_choice is None
or isinstance(doc_target, str)
else prefix + str(self.doc_to_choice(doc)[doc_target]),
}
)
else:
# get fewshot context as one user turn
chat_history.append(
{
"role": "user",
"content": self.get_context(
doc, num_fewshot, gen_prefix=gen_prefix
),
}
)
return chat_history
# @classmethod
# def from_fewshot_dfg(cls, cfg: FewshotConfig):
# if not
def sample(self, n: int) -> Sequence[dict]:
""" """
Draw `n` samples from our fewshot docs. This method should be overridden by subclasses. Sample n documents from the pool.
Args:
n: Number of documents to sample
doc: Optional document to exclude from sampling
Returns:
List of sampled documents
""" """
assert self.rnd is not None, ( if n <= 0:
"Error: `rnd` must be set to a random.Random instance before sampling." return []
return (
self.rnd.sample(self.docs, n)
if not doc
else self.remove_doc(doc, self.rnd.sample(self.docs, n + 1))
) )
return self.rnd.sample(self.docs, n)
def set_rnd(self, rnd: int) -> None:
self.rnd = Random(rnd)
@staticmethod
def remove_doc(doc: _T, _iter: Iterable[_T]) -> list[_T]:
return [x for x in _iter if x != doc]
class FirstNSampler(ContextSampler): class FirstNSampler(ContextSampler):
def sample(self, n: int) -> Sequence[dict[str, Any]]: def sample(self, n: int, doc=None, **kwargs):
""" """
Draw the first `n` samples in order from the specified split. Draw the first `n` samples in order from the specified split.
Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU. Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
...@@ -214,7 +72,7 @@ class FirstNSampler(ContextSampler): ...@@ -214,7 +72,7 @@ class FirstNSampler(ContextSampler):
class BalancedSampler(ContextSampler): class BalancedSampler(ContextSampler):
def sample(self, n: int): def sample(self, n: int, doc=None, **kwargs):
""" """
TODO: this should return approximately class-balanced samples from our fewshot examples. TODO: this should return approximately class-balanced samples from our fewshot examples.
TODO: what order should they be in? maybe random? TODO: what order should they be in? maybe random?
...@@ -224,7 +82,7 @@ class BalancedSampler(ContextSampler): ...@@ -224,7 +82,7 @@ class BalancedSampler(ContextSampler):
class ManualSampler(ContextSampler): class ManualSampler(ContextSampler):
def sample(self, n: int): def sample(self, n: int, doc=None, **kwargs):
""" """ """ """
raise NotImplementedError raise NotImplementedError
......
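For reference, a usage sketch of the refactored sampler under the constructor and `sample` signatures shown above (the docs and seed are invented). The inline logic afterwards shows what drawing with the evaluation doc excluded amounts to:

```python
from random import Random

# Toy pool; in the harness this comes from the task's few-shot split.
docs = [{"q": f"question {i}", "a": f"answer {i}"} for i in range(5)]

# Assuming the signatures in the diff above:
#   sampler = ContextSampler(docs, rnd=1234, fewshot_indices=[0, 2, 4])
#   sampler.sample(2)               # two few-shot docs
#   sampler.sample(2, doc=docs[2])  # draws n + 1 and drops the doc under evaluation

# Equivalent inline logic, for illustration only:
rnd = Random(1234)
pool = [docs[i] for i in [0, 2, 4]]
eval_doc = docs[2]
drawn = rnd.sample(pool, 2 + 1)                    # one extra draw
fewshot = [d for d in drawn if d != eval_doc][:2]  # remove_doc + truncate
print(fewshot)
```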
...@@ -7,8 +7,15 @@ import random ...@@ -7,8 +7,15 @@ import random
import re import re
from collections.abc import Callable, Iterable, Iterator, Mapping from collections.abc import Callable, Iterable, Iterator, Mapping
from copy import deepcopy from copy import deepcopy
from functools import cached_property from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Literal, overload from functools import cached_property, partial
from typing import (
TYPE_CHECKING,
Any,
Literal,
cast,
overload,
)
import datasets import datasets
import numpy as np import numpy as np
...@@ -17,12 +24,12 @@ from typing_extensions import deprecated ...@@ -17,12 +24,12 @@ from typing_extensions import deprecated
from lm_eval import utils from lm_eval import utils
from lm_eval.api.instance import Instance, OutputType from lm_eval.api.instance import Instance, OutputType
from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity from lm_eval.api.samplers import ContextSampler
from lm_eval.api.utils import check_gold_index_error
from lm_eval.caching.cache import load_from_cache, save_to_cache from lm_eval.caching.cache import load_from_cache, save_to_cache
from lm_eval.config.metric import MetricConfig from lm_eval.config.metric import MetricConfig
from lm_eval.config.task import DataSet, TaskConfig from lm_eval.config.task import DataSet, TaskConfig
from lm_eval.filters import build_filter_ensemble from lm_eval.filters import build_filter_ensemble
from lm_eval.utils import validate_index
ALL_OUTPUT_TYPES = [ ALL_OUTPUT_TYPES = [
...@@ -39,6 +46,16 @@ if TYPE_CHECKING: ...@@ -39,6 +46,16 @@ if TYPE_CHECKING:
eval_logger = logging.getLogger(__name__) eval_logger = logging.getLogger(__name__)
@dataclass
class Message:
role: str # "system" | "user" | "assistant"
content: str
def format_turn(content: str, role: str):
return {"role": role, "content": content}
class Task(abc.ABC): class Task(abc.ABC):
"""A task represents an entire benchmark including its dataset, problems, """A task represents an entire benchmark including its dataset, problems,
answers, and evaluation methods. See BoolQ for a simple example implementation answers, and evaluation methods. See BoolQ for a simple example implementation
...@@ -99,6 +116,8 @@ class Task(abc.ABC): ...@@ -99,6 +116,8 @@ class Task(abc.ABC):
self.fewshot_rnd: random.Random | None = ( self.fewshot_rnd: random.Random | None = (
None # purposely induce errors in case of improper usage None # purposely induce errors in case of improper usage
) )
self.sampler = ContextSampler(list(self.fewshot_docs))
self.multiple_input = False
def download( def download(
self, self,
...@@ -241,8 +260,10 @@ class Task(abc.ABC): ...@@ -241,8 +260,10 @@ class Task(abc.ABC):
def doc_to_audio(self, doc: dict): def doc_to_audio(self, doc: dict):
raise NotImplementedError raise NotImplementedError
def doc_to_prefix(self, doc: dict) -> str: @staticmethod
return "" def resolve_field(doc: dict[str, str], field: str | None = None):
if field is not None:
return doc[field] if field in doc else utils.apply_template(field, doc)
def build_all_requests( def build_all_requests(
self, self,
...@@ -322,7 +343,7 @@ class Task(abc.ABC): ...@@ -322,7 +343,7 @@ class Task(abc.ABC):
apply_chat_template=apply_chat_template, apply_chat_template=apply_chat_template,
fewshot_as_multiturn=fewshot_as_multiturn, fewshot_as_multiturn=fewshot_as_multiturn,
chat_template=chat_template, chat_template=chat_template,
gen_prefix=self.doc_to_prefix(doc), gen_prefix=self.resolve_field(doc, self.config.gen_prefix),
) )
# TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute # TODO: we should override self.config.repeats if doing greedy gen so users don't waste time+compute
...@@ -411,13 +432,13 @@ class Task(abc.ABC): ...@@ -411,13 +432,13 @@ class Task(abc.ABC):
def get_config(self, key: str) -> Any: def get_config(self, key: str) -> Any:
return getattr(self._config, key, None) return getattr(self._config, key, None)
@classmethod @staticmethod
def count_bytes(cls, doc: str) -> int: def count_bytes(doc: str) -> int:
"""Used for byte-level perplexity metrics in rolling loglikelihood""" """Used for byte-level perplexity metrics in rolling loglikelihood"""
return len(doc.encode("utf-8")) return len(doc.encode("utf-8"))
@classmethod @staticmethod
def count_words(cls, doc: str) -> int: def count_words(doc: str) -> int:
"""Downstream loglikelihood_rolling perplexity tasks with custom word boundaries should override this!""" """Downstream loglikelihood_rolling perplexity tasks with custom word boundaries should override this!"""
return len(re.split(r"\s+", doc)) return len(re.split(r"\s+", doc))
...@@ -525,9 +546,8 @@ class Task(abc.ABC): ...@@ -525,9 +546,8 @@ class Task(abc.ABC):
self._config.process_results = lambda *args: {"bypass": 0} self._config.process_results = lambda *args: {"bypass": 0}
def set_fewshot_seed(self, seed: int | None = None) -> None: def set_fewshot_seed(self, seed: int | None = None) -> None:
self.fewshot_rnd = random.Random(seed)
if hasattr(self, "sampler"): if hasattr(self, "sampler"):
self.sampler.rnd = self.fewshot_rnd self.sampler.set_rnd(seed)
@property @property
def eval_docs(self) -> datasets.Dataset | Iterable[dict]: def eval_docs(self) -> datasets.Dataset | Iterable[dict]:
...@@ -587,6 +607,7 @@ class ConfigurableTask(Task): ...@@ -587,6 +607,7 @@ class ConfigurableTask(Task):
) -> None: ) -> None:
# Get pre-configured attributes # Get pre-configured attributes
self._config = self.CONFIG self._config = self.CONFIG
self.fewshot_rnd = 1234
# Use new configurations if there was no preconfiguration # Use new configurations if there was no preconfiguration
if self.config is None: if self.config is None:
...@@ -611,6 +632,12 @@ class ConfigurableTask(Task): ...@@ -611,6 +632,12 @@ class ConfigurableTask(Task):
) )
self.OUTPUT_TYPE = self.config.output_type self.OUTPUT_TYPE = self.config.output_type
self.multiple_targets = self.config.multiple_targets
self.multiple_inputs = self.config.multiple_inputs
assert not (self.multiple_targets and self.multiple_inputs), (
"Cannot have both multiple_targets and multiple_inputs"
)
if self.config.doc_to_image is not None: if self.config.doc_to_image is not None:
# mark the task as requiring multimodality. # mark the task as requiring multimodality.
self.MULTIMODAL = True self.MULTIMODAL = True
...@@ -650,7 +677,7 @@ class ConfigurableTask(Task): ...@@ -650,7 +677,7 @@ class ConfigurableTask(Task):
): ):
self.fewshot_rnd = random.Random() self.fewshot_rnd = random.Random()
self.sampler = self.config.fewshot_cfg.init_sampler( self.sampler = self.config.fewshot_cfg.init_sampler(
list(self.fewshot_docs()), self, rnd=self.fewshot_rnd list(self.fewshot_docs()), rnd=self.fewshot_rnd
) )
self.task_docs = self.eval_docs self.task_docs = self.eval_docs
...@@ -667,9 +694,7 @@ class ConfigurableTask(Task): ...@@ -667,9 +694,7 @@ class ConfigurableTask(Task):
self.runtime_checks(self.task_docs[0]) self.runtime_checks(self.task_docs[0])
def download( def download(self, dataset_kwargs: dict[str, Any] | None = None, **kwargs) -> None:
self, dataset_kwargs:dict[str, Any] | None = None, **kwargs
) -> None:
from packaging.version import parse as vparse from packaging.version import parse as vparse
self.config.dataset_kwargs, self.config.metadata = ( self.config.dataset_kwargs, self.config.metadata = (
...@@ -748,176 +773,6 @@ class ConfigurableTask(Task): ...@@ -748,176 +773,6 @@ class ConfigurableTask(Task):
return super().fewshot_docs() return super().fewshot_docs()
@staticmethod
def append_target_question(
labeled_examples: list[dict[str, str]],
question: str,
fewshot_as_multiturn: bool = False,
gen_prefix: str | None = None,
) -> None:
"""Adds a target question to the labeled examples list.
If fewshot_as_multiturn is True, or labeled_examples is empty, or the last entry is a system turn, appends the question as a new user entry.
Otherwise, it is appended to the last user entry, ensuring that the conversation alternates between the user and the assistant.
"""
if not fewshot_as_multiturn:
# if no messages or last message is system, append as new user entry
if len(labeled_examples) == 0 or labeled_examples[-1]["role"] == "system":
labeled_examples.append({"role": "user", "content": question})
# if last message is user, append to it to avoid two user messages in a row
else:
labeled_examples[-1]["content"] += question
else:
# if fewshot_as_multiturn is True, append as next user entry (last is always assistant)
labeled_examples.append({"role": "user", "content": question})
if gen_prefix:
labeled_examples.append({"role": "assistant", "content": gen_prefix})
@utils.positional_deprecated
def fewshot_context(
self,
doc: dict,
num_fewshot: int,
system_instruction: str | None = None,
apply_chat_template: bool = False,
fewshot_as_multiturn: bool = False,
chat_template: Callable | None = None,
gen_prefix: str | None = None,
) -> str | list[str] | None:
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
:param doc: str
The document as returned from training_docs, validation_docs, or test_docs.
:param num_fewshot: int
The number of fewshot examples to provide in the returned context string.
:param system_instruction: str
System instruction to be applied to the prompt.
:param apply_chat_template: bool
Whether to apply the chat template to the fewshot context.
:param fewshot_as_multiturn: bool
Whether to provide the fewshot examples as a multiturn conversation or a single user turn.
:param chat_template:
callable (from lm.apply_chat_template) that takes in a list[Dict] chat transcript and renders it into a string.
:param gen_prefix:
String to append after the <|assistant|> token.
:returns: str
The fewshot context.
"""
labeled_examples = [] if apply_chat_template else ""
# get task description
if description := self.config.description:
description = utils.apply_template(self.config.description, doc)
# create system prompt based on the provided system instruction and description
if system_instruction is not None and description:
system_prompt = (
f"{system_instruction}{self.sampler.fewshot_delimiter}{description}"
)
elif system_instruction is not None:
system_prompt = system_instruction
elif description:
system_prompt = description
else:
system_prompt = ""
# add system prompt if specified
if system_prompt:
if apply_chat_template:
labeled_examples.append({"role": "system", "content": system_prompt})
else:
labeled_examples = system_prompt
# if few-shot - append examples after the system prompt
if num_fewshot > 0:
if apply_chat_template:
labeled_examples.extend(
self.sampler.get_chat_context(
doc,
num_fewshot,
fewshot_as_multiturn,
gen_prefix=gen_prefix,
)
)
else:
labeled_examples += self.sampler.get_context(
doc, num_fewshot, gen_prefix=gen_prefix
)
example = self.doc_to_text(doc)
if apply_chat_template:
if self.multiple_input:
# TODO: append prefill?
if not labeled_examples:
return ""
return chat_template(labeled_examples)
if isinstance(example, str):
self.append_target_question(
labeled_examples,
example,
fewshot_as_multiturn,
gen_prefix=gen_prefix,
)
# for loglikelihood create a list of questions with appended choices
elif isinstance(example, list):
labeled_examples_list = []
# copy chat history for each example and append the answer
for ex in example:
chat = deepcopy(labeled_examples)
self.append_target_question(
chat,
ex,
fewshot_as_multiturn,
gen_prefix=gen_prefix,
)
# TODO: append prefill?
labeled_examples_list.append(
chat_template(
chat,
add_generation_prompt=not gen_prefix,
)
)
return labeled_examples_list
# if example is an integer, append the choice or convert to string
elif isinstance(example, int):
if self.config.doc_to_choice is not None:
choices = self.doc_to_choice(doc)
self.append_target_question(
labeled_examples,
choices[example],
fewshot_as_multiturn,
gen_prefix=gen_prefix,
)
else:
self.append_target_question(
labeled_examples,
str(example),
fewshot_as_multiturn,
gen_prefix=gen_prefix,
)
# return lm.apply_chat_template(labeled_examples)
return chat_template(
labeled_examples,
add_generation_prompt=not gen_prefix,
)
else:
prefix = (
self.config.target_delimiter + gen_prefix
if gen_prefix is not None
else ""
)
if self.multiple_input:
return labeled_examples
if isinstance(example, str):
return labeled_examples + example + prefix
elif isinstance(example, list):
return [labeled_examples + ex + prefix for ex in example]
elif isinstance(example, int):
if self.config.doc_to_choice is not None:
choices = self.doc_to_choice(doc)
return labeled_examples + choices[example] + prefix
else:
return labeled_examples + str(example) + prefix
def apply_filters(self) -> list[Instance] | None: def apply_filters(self) -> list[Instance] | None:
"""Iterates over FilterEnsembles and applies them to instances""" """Iterates over FilterEnsembles and applies them to instances"""
if hasattr(self, "_filters") and self._instances: if hasattr(self, "_filters") and self._instances:
...@@ -927,7 +782,7 @@ class ConfigurableTask(Task): ...@@ -927,7 +782,7 @@ class ConfigurableTask(Task):
eval_logger.warning( eval_logger.warning(
"No filter defined or instances found. Passing through instances" "No filter defined or instances found. Passing through instances"
) )
return self._instances return self._instances
def should_decontaminate(self): def should_decontaminate(self):
return self.config.should_decontaminate return self.config.should_decontaminate
...@@ -1178,19 +1033,169 @@ class ConfigurableTask(Task): ...@@ -1178,19 +1033,169 @@ class ConfigurableTask(Task):
else: else:
return None return None
def doc_to_prefix(self, doc: dict) -> str | None: def _doc_to_qa_pair(
if (gen_prefix := self.config.gen_prefix) is not None: self,
if gen_prefix in doc: doc: dict[str, Any],
return doc[gen_prefix] gen_prefix: str | None,
*,
q: str | None = None,
a: str | None = None,
include_answer: bool = True,
) -> list[Message]:
"""Return `[user, assistant?]` for a single doc."""
q = q or self.doc_to_text(doc)
a = a or self.doc_to_target(doc)
# Handle multiple-choice indirection
if isinstance(q, list) and self.config.doc_to_choice:
q = q[cast(int, self.doc_to_target(doc))]
if isinstance(a, int) and self.config.doc_to_choice:
a = (
self.doc_to_choice(doc)[a]
if not self.multiple_inputs
else self.doc_to_choice(doc)[0]
)
assert isinstance(q, str), "Context is not a string!"
msgs = [Message("user", q)]
if include_answer:
if gen_prefix and not gen_prefix[-1].isspace():
prefix = gen_prefix + " "
elif gen_prefix:
prefix = gen_prefix
else: else:
return utils.apply_template(gen_prefix, doc) prefix = ""
return None answer_txt = prefix + (a if not isinstance(a, list) else a[0])
msgs.append(Message("assistant", answer_txt))
else:
msgs.append(Message("assistant", gen_prefix)) if gen_prefix else None
return msgs
@staticmethod
def _render_chat_template(
messages: list[Message],
chat_template: Callable[[list[dict[str, str]]], str],
*,
tgt_delim: str = " ",
few_delim: str = "\n\n",
multiturn=True,
) -> str:
if multiturn:
return chat_template([m.__dict__ for m in messages])
else:
has_prefix = messages[-1].role == "assistant"
if not has_prefix:
context = [
format_turn(
ConfigurableTask._message_to_text(
messages, tgt_delim=tgt_delim, few_delim=few_delim
),
role="user",
)
]
else:
context = [
format_turn(
ConfigurableTask._message_to_text(
messages[:-1], tgt_delim=tgt_delim, few_delim=few_delim
),
role="user",
)
]
context += [format_turn(**messages[-1].__dict__)]
return chat_template(context)
def fewshot_context(
self,
doc: dict[str, str],
num_fewshot: int,
system_instruction: str | None = None,
apply_chat_template: bool = False,
fewshot_as_multiturn: bool = False,
chat_template: Callable[..., str] | None = None,
gen_prefix: str | None = None,
) -> str | list[str]:
messages = []
tgt_delim, few_delim = (
self.config.target_delimiter,
self.config.fewshot_delimiter,
)
chat_template = (
partial(chat_template, add_generation_prompt=not gen_prefix)
if chat_template
else None
)
description = self.resolve_field(doc, self.config.description) or ""
system_prompt = few_delim.join(filter(None, [system_instruction, description]))
if system_prompt:
messages.append(Message("system", system_prompt))
for fs_doc in self.sampler.sample(
n=num_fewshot,
doc=doc if self.config.fewshot_split == self.config.test_split else None,
):
messages += self._doc_to_qa_pair(fs_doc, gen_prefix)
if self.multiple_inputs:
# if multiple inputs, then doc_to_text: list[str]
messages = [
messages
+ self._doc_to_qa_pair(
doc,
gen_prefix,
q=q,
include_answer=False,
)
for q in cast(list[str], self.doc_to_text(doc))
]
else:
# otherwise, doc_to_text: str for all other cases
messages += self._doc_to_qa_pair(doc, gen_prefix, include_answer=False)
messages = [messages]
if apply_chat_template and chat_template:
res = [
self._render_chat_template(
m,
chat_template,
tgt_delim=tgt_delim,
few_delim=few_delim,
multiturn=fewshot_as_multiturn,
)
for m in messages
]
else:
res = [
self._message_to_text(m, tgt_delim=tgt_delim, few_delim=few_delim)
for m in messages
]
return res[0] if not self.multiple_inputs else res
@staticmethod
def _message_to_text(
messages: list[Message],
*,
tgt_delim=" ",
few_delim="\n\n",
) -> str:
buff = []
for i, m in enumerate(messages):
if m.role == "system" or m.role == "user":
buff.append(m.content)
elif m.role == "assistant":
buff.append(tgt_delim + m.content)
if i != len(messages) - 1:
# then this is not assis prefill
buff.append(few_delim)
return "".join(buff)
def construct_requests( def construct_requests(
self, doc: dict, ctx: str, **kwargs self, doc: dict[str, str], ctx: str | list[str], **kwargs
) -> list[Instance] | Instance: ) -> list[Instance] | Instance:
apply_chat_template = kwargs.pop("apply_chat_template", False) apply_chat_template = kwargs.pop("apply_chat_template", False)
chat_template: Callable | None = kwargs.pop("chat_template", None) chat_template: Callable | None = kwargs.pop("chat_template", None) # noqa: F841
aux_arguments = None aux_arguments = None
...@@ -1200,31 +1205,18 @@ class ConfigurableTask(Task): ...@@ -1200,31 +1205,18 @@ class ConfigurableTask(Task):
arguments = (self.doc_to_target(doc),) arguments = (self.doc_to_target(doc),)
elif self.OUTPUT_TYPE == "multiple_choice": elif self.OUTPUT_TYPE == "multiple_choice":
choices = self.doc_to_choice(doc) choices = self.doc_to_choice(doc)
target_delimiter = self.config.target_delimiter target_delimiter = (
if apply_chat_template: ""
target_delimiter = "" if (apply_chat_template and not self.config.gen_prefix)
if self.multiple_input: else self.config.target_delimiter
# If there are multiple inputs, choices are placed in the ctx )
# apply chat_template to choices if apply_chat_template if self.multiple_inputs:
cont = self.doc_to_target(doc) # If there are multiple inputs, assume only one choice
arguments = [(_ctx, f"{target_delimiter}{choices[0]}") for _ctx in ctx]
arguments = [
(
ctx
+ (
chat_template([{"role": "user", "content": choice}])
if apply_chat_template
else choice
),
f"{target_delimiter}{cont}",
)
for choice in choices
]
else: else:
# Otherwise they are placed in the continuation # Otherwise they are placed in the continuation
arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices] arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]
# TODO: we should raise a warning telling users this will at most ~2x runtime.
if "acc_mutual_info" in [m.metric_name for m in self.config._metric_list]: if "acc_mutual_info" in [m.metric_name for m in self.config._metric_list]:
# if we are calculating multiple choice accuracy # if we are calculating multiple choice accuracy
# using mutual information instead of raw loglikelihood as metric, need unconditional lls. # using mutual information instead of raw loglikelihood as metric, need unconditional lls.
...@@ -1232,7 +1224,6 @@ class ConfigurableTask(Task): ...@@ -1232,7 +1224,6 @@ class ConfigurableTask(Task):
# here mutual info refers to calculating # here mutual info refers to calculating
# log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice)) # log(P(choice|ctx) / P(choice)) = log(P(choice|ctx)) - log(P(choice))
# in other words normalizing by subtracting the unconditional logprob of each choice. # in other words normalizing by subtracting the unconditional logprob of each choice.
# TODO: should these be strided? will have to modify the processing in process_results if so
aux_arguments = [ aux_arguments = [
("", f"{target_delimiter}{choice}") for choice in choices ("", f"{target_delimiter}{choice}") for choice in choices
] ]
...@@ -1327,7 +1318,11 @@ class ConfigurableTask(Task): ...@@ -1327,7 +1318,11 @@ class ConfigurableTask(Task):
lls, is_greedy = zip(*results) lls, is_greedy = zip(*results)
# retrieve choices in list[str] form, to compute choice lengths, etc. # retrieve choices in list[str] form, to compute choice lengths, etc.
choices = self.doc_to_choice(doc) choices = (
self.doc_to_choice(doc)
if not self.multiple_inputs
else cast(list[str], self.doc_to_text(doc))
)
completion_len = np.array([float(len(i)) for i in choices]) completion_len = np.array([float(len(i)) for i in choices])
if 2 * len(choices) == len(lls) and "acc_mutual_info" in use_metric: if 2 * len(choices) == len(lls) and "acc_mutual_info" in use_metric:
...@@ -1345,20 +1340,26 @@ class ConfigurableTask(Task): ...@@ -1345,20 +1340,26 @@ class ConfigurableTask(Task):
pred = np.argmax(lls) pred = np.argmax(lls)
pred_norm = np.argmax(lls / completion_len) pred_norm = np.argmax(lls / completion_len)
if self.multiple_input: gold = backup = self.doc_to_target(doc)
gold = self.doc_to_text(doc)
if isinstance(gold, list):
gold = [validate_index(g, len(choices)) for g in gold]
gold_index_error = -100 in gold
else: else:
gold = self.doc_to_target(doc) if isinstance(gold, int):
gold = validate_index(gold, len(choices))
elif isinstance(gold, str):
gold = choices.index(gold) if gold in choices else -100
gold, gold_index_error = check_gold_index_error(choices, gold) gold_index_error = gold == -100
if gold_index_error: if gold_index_error:
eval_logger.warning( eval_logger.warning(
f"Label index was not in within range of available choices," f"Label [{backup}] index was not in within range of available choices {choices},"
f"Sample:\n\n{doc}\n\n" f"Sample:\n\n{doc}\n\n"
) )
if self.multiple_target: if self.multiple_targets:
acc = 1.0 if pred in gold else 0.0 acc = 1.0 if pred in gold else 0.0
acc_norm = 1.0 if pred_norm in gold else 0.0 acc_norm = 1.0 if pred_norm in gold else 0.0
exact_match = int(any(is_greedy[i] if i != -100 else 0 for i in gold)) exact_match = int(any(is_greedy[i] if i != -100 else 0 for i in gold))
...@@ -1408,7 +1409,6 @@ class ConfigurableTask(Task): ...@@ -1408,7 +1409,6 @@ class ConfigurableTask(Task):
except TypeError: # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics except TypeError: # needed for now in order to use a different interface between our own metrics and HF Evaluate metrics
result_score = metric.fn([gold, result]) result_score = metric.fn([gold, result])
if isinstance(result_score, dict): if isinstance(result_score, dict):
# TODO: this handles the case where HF evaluate returns a dict.
# This allows for multiple metrics to be returned from the same function # This allows for multiple metrics to be returned from the same function
for k, v in result_score.items(): for k, v in result_score.items():
result_dict[k] = v result_dict[k] = v
......
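The refactor routes both plain-text and chat-template prompts through a list of `Message` objects. A standalone sketch of the plain-text rendering path, mirroring the `_message_to_text` helper added above (the example doc content is invented):

```python
from dataclasses import dataclass

@dataclass
class Message:
    role: str      # "system" | "user" | "assistant"
    content: str

def message_to_text(messages, *, tgt_delim=" ", few_delim="\n\n") -> str:
    """Plain-text rendering mirroring `_message_to_text`: system/user turns are
    emitted verbatim, assistant turns get the target delimiter, and the fewshot
    delimiter separates completed examples (a trailing assistant turn is the prefill)."""
    buff = []
    for i, m in enumerate(messages):
        if m.role in ("system", "user"):
            buff.append(m.content)
        elif m.role == "assistant":
            buff.append(tgt_delim + m.content)
            if i != len(messages) - 1:
                buff.append(few_delim)
    return "".join(buff)

msgs = [
    Message("user", "Q: 2+2?\nA:"),
    Message("assistant", "4"),
    Message("user", "Q: 3+3?\nA:"),
]
print(message_to_text(msgs))
# Q: 2+2?
# A: 4
#
# Q: 3+3?
# A:
```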
...@@ -4,7 +4,7 @@ import textwrap ...@@ -4,7 +4,7 @@ import textwrap
from argparse import Namespace from argparse import Namespace
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional, Union from typing import TYPE_CHECKING, Any, Optional, Union
import yaml import yaml
...@@ -214,7 +214,7 @@ class EvaluatorConfig: ...@@ -214,7 +214,7 @@ class EvaluatorConfig:
# Parse string arguments that should be dictionaries # Parse string arguments that should be dictionaries
config = cls._parse_dict_args(config) config = cls._parse_dict_args(config)
# Create instance and validate # Create an instance and validate
instance = cls(**config) instance = cls(**config)
if used_config: if used_config:
print(textwrap.dedent(f"""{instance}""")) print(textwrap.dedent(f"""{instance}"""))
...@@ -238,7 +238,7 @@ class EvaluatorConfig: ...@@ -238,7 +238,7 @@ class EvaluatorConfig:
return instance return instance
@staticmethod @staticmethod
def _parse_dict_args(config: Dict[str, Any]) -> Dict[str, Any]: def _parse_dict_args(config: dict[str, Any]) -> dict[str, Any]:
"""Parse string arguments that should be dictionaries.""" """Parse string arguments that should be dictionaries."""
for key in config: for key in config:
if key in DICT_KEYS and isinstance(config[key], str): if key in DICT_KEYS and isinstance(config[key], str):
...@@ -246,7 +246,7 @@ class EvaluatorConfig: ...@@ -246,7 +246,7 @@ class EvaluatorConfig:
return config return config
@staticmethod @staticmethod
def load_yaml_config(config_path: Union[str, Path]) -> Dict[str, Any]: def load_yaml_config(config_path: Union[str, Path]) -> dict[str, Any]:
"""Load and validate YAML config file.""" """Load and validate YAML config file."""
config_file = ( config_file = (
Path(config_path) if not isinstance(config_path, Path) else config_path Path(config_path) if not isinstance(config_path, Path) else config_path
...@@ -257,9 +257,9 @@ class EvaluatorConfig: ...@@ -257,9 +257,9 @@ class EvaluatorConfig:
try: try:
yaml_data = yaml.safe_load(config_file.read_text()) yaml_data = yaml.safe_load(config_file.read_text())
except yaml.YAMLError as e: except yaml.YAMLError as e:
raise ValueError(f"Invalid YAML in {config_path}: {e}") raise ValueError(f"Invalid YAML in {config_path}: {e}") from e
except (OSError, UnicodeDecodeError) as e: except (OSError, UnicodeDecodeError) as e:
raise ValueError(f"Could not read config file {config_path}: {e}") raise ValueError(f"Could not read config file {config_path}: {e}") from e
if not isinstance(yaml_data, dict): if not isinstance(yaml_data, dict):
raise ValueError( raise ValueError(
...@@ -307,7 +307,7 @@ class EvaluatorConfig: ...@@ -307,7 +307,7 @@ class EvaluatorConfig:
raise ValueError("Need to specify task to evaluate.") raise ValueError("Need to specify task to evaluate.")
def _process_arguments(self) -> None: def _process_arguments(self) -> None:
"""Process samples argument - load from file if needed.""" """Process samples argument - load from a file if needed."""
if self.samples: if self.samples:
if isinstance(self.samples, dict): if isinstance(self.samples, dict):
self.samples = self.samples self.samples = self.samples
...@@ -328,7 +328,6 @@ class EvaluatorConfig: ...@@ -328,7 +328,6 @@ class EvaluatorConfig:
def process_tasks(self, metadata: Optional[dict] = None) -> "TaskManager": def process_tasks(self, metadata: Optional[dict] = None) -> "TaskManager":
"""Process and validate tasks, return resolved task names.""" """Process and validate tasks, return resolved task names."""
from lm_eval import utils
from lm_eval.tasks import TaskManager from lm_eval.tasks import TaskManager
# if metadata manually passed use that: # if metadata manually passed use that:
...@@ -365,7 +364,7 @@ class EvaluatorConfig: ...@@ -365,7 +364,7 @@ class EvaluatorConfig:
return task_manager return task_manager
def _set_trust_remote_code(self) -> None: def _set_trust_remote_code(self) -> None:
"""Apply trust_remote_code setting if enabled.""" """Apply the trust_remote_code setting if enabled."""
if self.trust_remote_code: if self.trust_remote_code:
# HACK: import datasets and override its HF_DATASETS_TRUST_REMOTE_CODE value internally, # HACK: import datasets and override its HF_DATASETS_TRUST_REMOTE_CODE value internally,
# because it's already been determined based on the prior env var before launching our # because it's already been determined based on the prior env var before launching our
......
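The `raise ... from e` changes keep the original parse error attached as `__cause__`, so tracebacks show both exceptions. A minimal sketch of the same pattern outside the config class (PyYAML assumed, as in the module above):

```python
import yaml  # PyYAML

def load_yaml_sketch(text: str) -> dict:
    # Re-raising with `from e`, as the diff does, preserves the YAML error
    # as the cause of the ValueError instead of hiding it.
    try:
        data = yaml.safe_load(text)
    except yaml.YAMLError as e:
        raise ValueError(f"Invalid YAML: {e}") from e
    if not isinstance(data, dict):
        raise ValueError("Top-level YAML value must be a mapping")
    return data

print(load_yaml_sketch("task: boolq\nnum_fewshot: 5"))
# {'task': 'boolq', 'num_fewshot': 5}
```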
...@@ -21,7 +21,7 @@ def serialize_callable( ...@@ -21,7 +21,7 @@ def serialize_callable(
return value return value
else: else:
try: try:
return getsource(value) return getsource(value) # type: ignore
except (TypeError, OSError): except (TypeError, OSError):
return str(value) return str(value)
......
...@@ -60,10 +60,10 @@ def _load_module_with_cache(module_path: Path) -> Any: ...@@ -60,10 +60,10 @@ def _load_module_with_cache(module_path: Path) -> Any:
module_parts = relative_path.replace(".py", "").replace("/", ".") module_parts = relative_path.replace(".py", "").replace("/", ".")
module_name = f"lm_eval.tasks.{module_parts}" module_name = f"lm_eval.tasks.{module_parts}"
else: else:
# Fallback to full path if pattern not found # Fallback to a full path if a pattern not found
module_name = str(module_path.with_suffix("")) module_name = str(module_path.with_suffix(""))
else: else:
# External module - use full path without extension # External module - use a full path without extension
module_name = str(module_path.with_suffix("")) module_name = str(module_path.with_suffix(""))
# Check if we need to reload the module # Check if we need to reload the module
...@@ -84,7 +84,7 @@ def _load_module_with_cache(module_path: Path) -> Any: ...@@ -84,7 +84,7 @@ def _load_module_with_cache(module_path: Path) -> Any:
raise ImportError(f"Cannot load module from {module_path}") from None raise ImportError(f"Cannot load module from {module_path}") from None
module = importlib.util.module_from_spec(spec) module = importlib.util.module_from_spec(spec)
# Store mtime for future checks # Store mtime for future checks
module.__mtime__ = module_path.stat().st_mtime_ns module.__mtime__ = module_path.stat().st_mtime_ns # type: ignore
spec.loader.exec_module(module) # type: ignore[arg-type] spec.loader.exec_module(module) # type: ignore[arg-type]
sys.modules[module_name] = module sys.modules[module_name] = module
return module return module
......
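For context, the caching pattern in `_load_module_with_cache` boils down to keying on the file's mtime and re-executing the module only when it changed. A simplified standalone sketch under that assumption (the module-name derivation from the tasks directory layout is omitted):

```python
import importlib.util
import sys
from pathlib import Path

_loaded: dict[str, object] = {}

def load_module_with_cache(module_path: Path, module_name: str):
    """Reload a task's utils module only if the file on disk changed."""
    mtime = module_path.stat().st_mtime_ns
    cached = _loaded.get(module_name)
    if cached is not None and getattr(cached, "__mtime__", None) == mtime:
        return cached
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module from {module_path}")
    module = importlib.util.module_from_spec(spec)
    module.__mtime__ = mtime          # remember which file version was read
    spec.loader.exec_module(module)
    sys.modules[module_name] = module
    _loaded[module_name] = module
    return module
```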
...@@ -32,9 +32,9 @@ class TaskFactory: ...@@ -32,9 +32,9 @@ class TaskFactory:
registry: Mapping[str, Entry], registry: Mapping[str, Entry],
): ):
""" """
• entry.kind == TASK / PY_TASK ➜ returns instantiated task object • entry.kind == TASK / PY_TASK ➜ returns instantiated task object
• entry.kind == GROUP ➜ returns (GroupConfig, mapping-of-subtasks) • entry.kind == GROUP ➜ returns (GroupConfig, mapping-of-subtasks)
• entry.kind == TAG ➜ returns mapping-of-tasks (tag expansion) • entry.kind == TAG ➜ returns mapping-of-tasks (tag expansion)
""" """
if entry.kind is Kind.TAG: if entry.kind is Kind.TAG:
return self._build_tag(entry, overrides, registry) return self._build_tag(entry, overrides, registry)
...@@ -121,4 +121,4 @@ class TaskFactory: ...@@ -121,4 +121,4 @@ class TaskFactory:
def _ctor_accepts_config(cls) -> bool: def _ctor_accepts_config(cls) -> bool:
init = getattr(cls, "__init__", None) init = getattr(cls, "__init__", None)
return init and "config" in inspect.signature(init).parameters return bool(init and "config" in inspect.signature(init).parameters)
...@@ -61,8 +61,8 @@ class TaskManager: ...@@ -61,8 +61,8 @@ class TaskManager:
def load_spec(self, spec: str | dict[str, Any]): def load_spec(self, spec: str | dict[str, Any]):
"""Spec can be: """Spec can be:
• str task / group / tag name (registered) • str task / group / tag name (registered)
• dict inline overrides {'task': 'hellaswag', 'num_fewshot': 5} • dict inline overrides {'task': 'hellaswag', 'num_fewshot': 5}
""" """
if isinstance(spec, str): if isinstance(spec, str):
entry = self._entry(spec) entry = self._entry(spec)
......
...@@ -8,8 +8,8 @@ Homepage: [google-research-datasets/natural-questions@master/nq_open](https://gi ...@@ -8,8 +8,8 @@ Homepage: [google-research-datasets/natural-questions@master/nq_open](https://gi
Paper: [aclanthology.org/P19-1612](https://aclanthology.org/P19-1612/) Paper: [aclanthology.org/P19-1612](https://aclanthology.org/P19-1612/)
Derived from the Natural Questions dataset, introduced in https://storage.googleapis.com/gweb-research2023-media/pubtools/pdf/1f7b46b5378d757553d3e92ead36bda2e4254244.pdf . Derived from the Natural Questions dataset, introduced
in https://storage.googleapis.com/gweb-research2023-media/pubtools/pdf/1f7b46b5378d757553d3e92ead36bda2e4254244.pdf .
### Citation ### Citation
...@@ -26,4 +26,5 @@ journal = {Transactions of the Association of Computational Linguistics}} ...@@ -26,4 +26,5 @@ journal = {Transactions of the Association of Computational Linguistics}}
* `nq_open` * `nq_open`
### Changelog ### Changelog
* 2025-07-21: Added `multi_target` to `exact_match`. Scores should not change.
* 2025-07-21: Added `multiple_targets` to `exact_match`. Scores should not change.
...@@ -5,7 +5,7 @@ training_split: train ...@@ -5,7 +5,7 @@ training_split: train
validation_split: validation validation_split: validation
description: "Answer these questions:\n\n" description: "Answer these questions:\n\n"
doc_to_text: "Q: {{question}}?\nA:" doc_to_text: "Q: {{question}}?\nA:"
doc_to_target: "{{answer}}" doc_to_target: answer
fewshot_delimiter: "\n" fewshot_delimiter: "\n"
generation_kwargs: generation_kwargs:
until: until:
...@@ -27,7 +27,7 @@ metric_list: ...@@ -27,7 +27,7 @@ metric_list:
ignore_case: true ignore_case: true
ignore_punctuation: true ignore_punctuation: true
regexes_to_ignore: regexes_to_ignore:
- "\\b(?:The |the |An |A |The |a |an )" - "\\b(?:The |the |An |A |The |a |an )"
multi_target: true multiple_targets: true
metadata: metadata:
version: 4.0 version: 4.0
...@@ -79,3 +79,6 @@ If other tasks on this dataset are already supported: ...@@ -79,3 +79,6 @@ If other tasks on this dataset are already supported:
* [ ] Is the "Main" variant of this task clearly denoted? * [ ] Is the "Main" variant of this task clearly denoted?
* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates? * [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant? * [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
### Changelog
- 2025-07-22: `record` and `multirc`: set target_delimiter to "" and trim doc_to_text respectively.
tag: tag:
- super-glue-lm-eval-v1 - super-glue-lm-eval-v1
task: boolq task: boolq
dataset_path: super_glue dataset_path: aps/super_glue
dataset_name: boolq dataset_name: boolq
output_type: multiple_choice output_type: multiple_choice
training_split: train training_split: train
......
tag: tag:
- super-glue-lm-eval-v1-seq2seq - super-glue-lm-eval-v1-seq2seq
task: "boolq-seq2seq" task: "boolq-seq2seq"
dataset_path: super_glue dataset_path: aps/super_glue
dataset_name: boolq dataset_name: boolq
output_type: generate_until output_type: generate_until
training_split: train training_split: train
validation_split: validation validation_split: validation
doc_to_text: "{{passage}}\nQuestion: {{question}}?\nAnswer:" doc_to_text: "{{passage}}\nQuestion: {{question}}?\nAnswer:"
doc_to_target: label doc_to_target: "{{ [' no', ' yes'][label|int] }}"
doc_to_choice: [' no', ' yes']
target_delimiter: "" target_delimiter: ""
generation_kwargs: generation_kwargs:
until: until:
......
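The seq2seq variant now folds the label lookup directly into the `doc_to_target` template instead of going through `doc_to_choice`. A quick check of how that Jinja expression renders (requires Jinja2; the doc is invented):

```python
from jinja2 import Template

doc = {"passage": "Water boils at 100 C.", "question": "does water boil at 100 c", "label": 1}

prompt = Template("{{passage}}\nQuestion: {{question}}?\nAnswer:").render(**doc)
target = Template("{{ [' no', ' yes'][label|int] }}").render(**doc)

print(repr(target))     # ' yes'
print(prompt + target)  # target_delimiter is "" for this task, so plain concatenation
```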
tag: tag:
- super-glue-t5-prompt - super-glue-t5-prompt
task: super_glue-boolq-t5-prompt task: super_glue-boolq-t5-prompt
dataset_path: super_glue dataset_path: aps/super_glue
dataset_name: boolq dataset_name: boolq
training_split: train training_split: train
validation_split: validation validation_split: validation
output_type: generate_until output_type: generate_until
doc_to_text: "boolq passage: {{passage}} question: {{question}}" doc_to_text: "boolq passage: {{passage}} question: {{question}}"
doc_to_target: label doc_to_target: "{{['False', 'True'][label|int]}}"
doc_to_choice: ['False', 'True']
generation_kwargs: generation_kwargs:
until: until:
- "</s>" - "</s>"
......
tag: tag:
- super-glue-lm-eval-v1 - super-glue-lm-eval-v1
task: cb task: cb
dataset_path: super_glue dataset_path: aps/super_glue
dataset_name: cb dataset_name: cb
output_type: multiple_choice output_type: multiple_choice
training_split: train training_split: train
validation_split: validation validation_split: validation
doc_to_text: "{{premise}}\nQuestion: {{hypothesis}}. True, False, or Neither?\nAnswer:" doc_to_text: "{{premise}}\nQuestion: {{hypothesis}}. True, False, or Neither?\nAnswer:"
doc_to_target: label doc_to_target: label
doc_to_choice: ['True', 'False', 'Neither'] doc_to_choice: ["True", "False", "Neither"]
metric_list: metric_list:
- metric: acc - metric: acc
- metric: f1 - metric: f1
......
tag: tag:
- super-glue-t5-prompt - super-glue-t5-prompt
task: super_glue-cb-t5-prompt task: super_glue-cb-t5-prompt
dataset_path: super_glue dataset_path: aps/super_glue
dataset_name: cb dataset_name: cb
training_split: train training_split: train
validation_split: validation validation_split: validation
output_type: generate_until output_type: generate_until
doc_to_text: "cb hypothesis: {{hypothesis}} premise: {{premise}}" doc_to_text: "cb hypothesis: {{hypothesis}} premise: {{premise}}"
doc_to_target: label doc_to_target: "{{ ['entailment', 'contradiction', 'neutral'][label|int] }}"
doc_to_choice: ['entailment', 'contradiction', 'neutral']
generation_kwargs: generation_kwargs:
until: until:
- "</s>" - "</s>"
......