Unverified Commit 003e5852 authored by Baber Abbasi, committed by GitHub

Fewshot refactor (#3227)



* overhaul `ContextSampler`

* refactor masakhapos

* move multi_target to `exact_match`

* remove doc_to_choice from `boolq-seq2seq`

* remove doc_to_choice in generation process_results

* Remove unused `doc_to_choice` and fix superglue whitespaces

* require multiple_inputs and multiple_targets to be explicitly set in taskconfig

* fix copa; better logging in task init

* fix doc_to_target to return int rather than str (deprecated)

* fix processing regression; recursively parse lists from template

* remove redundant jinja parsing logic

* remove promptsource

* for multiple_inputs use `doc_to_text: list[str]`

* Refactor `ContextSampler` `fewshot_context`

* fix multiple_input context

* fix `target_delimiter` with `gen_prefix`

* `doc_to_text` is list for multiple_inputs

* Refactor `count_bytes` and `count_words` methods to `@staticmethod`

* make has_*(train/test/validation) into properties

* remove `multi_target` from `generate_until`

* fix doc_to_target/multiple_targets handling; add tests

* rename `multi_target` to `multiple_targets`

* evaluate list when multiple targets

* allow doc_to_target to return list

* Remove gen_prefix space and add warning (#3239)

* Remove gen_prefix space and add warning

* fix null gen_prefix bug again

* use git tests

---------
Co-authored-by: Boaz Ben-Dov <bendboaz@gmail.com>
parent 79a22a11
...@@ -8,8 +8,6 @@ on: ...@@ -8,8 +8,6 @@ on:
branches: branches:
- 'main' - 'main'
pull_request: pull_request:
branches:
- 'main'
workflow_dispatch: workflow_dispatch:
# Jobs run concurrently and steps run sequentially within a job. # Jobs run concurrently and steps run sequentially within a job.
# jobs: linter and cpu_tests. Add more jobs/steps as required. # jobs: linter and cpu_tests. Add more jobs/steps as required.
......
...@@ -53,6 +53,6 @@ class FilterEnsemble: ...@@ -53,6 +53,6 @@ class FilterEnsemble:
resps = f().apply(resps, docs) resps = f().apply(resps, docs)
# add the end results after filtering to filtered_requests of their respective source instances. # add the end results after filtering to filtered_requests of their respective source instances.
# has key `self.name`: each FilterEnsemble applied in a given run should use a different name. # has a key ` self.name `: each FilterEnsemble applied in a given run should use a different name.
for inst, resp in zip(instances, resps): for inst, resp in zip(instances, resps):
inst.filtered_resps[self.name] = resp inst.filtered_resps[self.name] = resp
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Literal, Optional, Tuple from typing import Any, Literal, Optional
OutputType = Literal[ OutputType = Literal[
...@@ -10,10 +10,10 @@ OutputType = Literal[ ...@@ -10,10 +10,10 @@ OutputType = Literal[
@dataclass @dataclass
class Instance: class Instance:
request_type: OutputType request_type: OutputType
doc: dict doc: dict[str, Any]
arguments: tuple arguments: tuple
idx: int idx: int
metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field( metadata: tuple[Optional[str], Optional[int], Optional[int]] = field(
default_factory=lambda: (None, None, None), default_factory=lambda: (None, None, None),
metadata=dict( metadata=dict(
description="Metadata tuple containing task name, document ID, and number of repeats." description="Metadata tuple containing task name, document ID, and number of repeats."
......
...@@ -213,7 +213,7 @@ def exact_match_hf_evaluate( ...@@ -213,7 +213,7 @@ def exact_match_hf_evaluate(
ignore_case: bool = False, ignore_case: bool = False,
ignore_punctuation: bool = False, ignore_punctuation: bool = False,
ignore_numbers: bool = False, ignore_numbers: bool = False,
multi_target: bool = False, multiple_targets: bool = False,
): ):
""" """
Compute exact match scores between predictions and references. Compute exact match scores between predictions and references.
...@@ -245,8 +245,8 @@ def exact_match_hf_evaluate( ...@@ -245,8 +245,8 @@ def exact_match_hf_evaluate(
- "exact_match" (float): The mean exact match score or 1.0/0.0 if `multi_target` is True. - "exact_match" (float): The mean exact match score or 1.0/0.0 if `multi_target` is True.
""" """
predictions, references = list(predictions), list(references) predictions, references = list(predictions), list(references)
assert len(predictions) == len(references) if not multi_target else True, ( assert len(predictions) == len(references) if not multiple_targets else True, (
"predictions and references must have the same length unless `multi_target` is True" "predictions and references must have the same length unless `multiple_targets` is True"
) )
if regexes_to_ignore is not None: if regexes_to_ignore is not None:
...@@ -275,7 +275,7 @@ def exact_match_hf_evaluate( ...@@ -275,7 +275,7 @@ def exact_match_hf_evaluate(
return { return {
"exact_match": np.mean(score_list) "exact_match": np.mean(score_list)
if not multi_target if not multiple_targets
else float(np.any(score_list)) else float(np.any(score_list))
} }
......
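
The rename above keeps the aggregation rule intact: with `multiple_targets` enabled, a single prediction is scored against every reference and counts as correct if any of them matches; otherwise the mean over prediction/reference pairs is returned. A minimal standalone sketch of that behaviour (illustrative only, not the lm-eval implementation; `toy_exact_match` is an invented name and the normalization options are omitted):

```python
# Illustrative only: mirrors the aggregation semantics of the diff above,
# not the real exact_match_hf_evaluate (which also normalizes strings).
import numpy as np


def toy_exact_match(predictions, references, multiple_targets=False):
    if multiple_targets:
        # one prediction vs. many gold answers -> 1.0 if any matches
        assert len(predictions) == 1
        score_list = [predictions[0] == ref for ref in references]
        return float(np.any(score_list))
    # paired predictions/references -> mean over pairs
    assert len(predictions) == len(references)
    score_list = [p == r for p, r in zip(predictions, references)]
    return float(np.mean(score_list))


print(toy_exact_match(["Paris"], ["paris", "Paris", "City of Light"],
                      multiple_targets=True))  # 1.0
print(toy_exact_match(["a", "b"], ["a", "c"]))  # 0.5
```
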
...@@ -220,8 +220,8 @@ class Registry(Generic[T]): ...@@ -220,8 +220,8 @@ class Registry(Generic[T]):
>>> model_registry.register("lazy-name", lazy="mymodule:MyModel") >>> model_registry.register("lazy-name", lazy="mymodule:MyModel")
Raises: Raises:
ValueError: If alias already registered with different target ValueError: If alias is already registered with a different target
TypeError: If object doesn't inherit from base_cls (when specified) TypeError: If an object doesn't inherit from base_cls (when specified)
""" """
def _store(alias: str, target: T | Placeholder) -> None: def _store(alias: str, target: T | Placeholder) -> None:
...@@ -229,21 +229,27 @@ class Registry(Generic[T]): ...@@ -229,21 +229,27 @@ class Registry(Generic[T]):
# collision handling ------------------------------------------ # collision handling ------------------------------------------
if current is not None and current != target: if current is not None and current != target:
# allow placeholder → real object upgrade # allow placeholder → real object upgrade
if isinstance(current, str) and isinstance(target, type): # mod, _, cls = current.partition(":")
# mod, _, cls = current.partition(":") if (
if current == f"{target.__module__}:{target.__name__}": isinstance(current, str)
self._objs[alias] = target and isinstance(target, type)
return and current == f"{target.__module__}:{target.__name__}"
):
self._objs[alias] = target
return
raise ValueError( raise ValueError(
f"{self._name!r} alias '{alias}' already registered (" f"{self._name!r} alias '{alias}' already registered ("
f"existing={current}, new={target})" f"existing={current}, new={target})"
) )
# type check for concrete classes ---------------------------------------------- # type check for concrete classes ----------------------------------------------
if self._base_cls is not None and isinstance(target, type): if (
if not issubclass(target, self._base_cls): # type: ignore[arg-type] self._base_cls is not None
raise TypeError( and isinstance(target, type)
f"{target} must inherit from {self._base_cls} to be a {self._name}" and not issubclass(target, self._base_cls)
) ):
raise TypeError(
f"{target} must inherit from {self._base_cls} to be a {self._name}"
)
self._objs[alias] = target self._objs[alias] = target
def decorator(obj: T) -> T: # type: ignore[valid-type] def decorator(obj: T) -> T: # type: ignore[valid-type]
...@@ -409,9 +415,7 @@ class MetricSpec: ...@@ -409,9 +415,7 @@ class MetricSpec:
from lm_eval.api.model import LM # noqa: E402 from lm_eval.api.model import LM # noqa: E402
model_registry: Registry[type[LM]] = cast( model_registry = cast(Registry[type[LM]], Registry("model", base_cls=LM))
Registry[type[LM]], Registry("model", base_cls=LM)
)
task_registry: Registry[Callable[..., Any]] = Registry("task") task_registry: Registry[Callable[..., Any]] = Registry("task")
metric_registry: Registry[MetricSpec] = Registry("metric") metric_registry: Registry[MetricSpec] = Registry("metric")
metric_agg_registry: Registry[Callable[[Iterable[Any]], float]] = Registry( metric_agg_registry: Registry[Callable[[Iterable[Any]], float]] = Registry(
...@@ -457,7 +461,7 @@ def register_metric(**kw): ...@@ -457,7 +461,7 @@ def register_metric(**kw):
then registers it in the metric registry. then registers it in the metric registry.
Args: Args:
**kw: Keyword arguments including: **kw: Keyword arguments including
- metric: Name to register the metric under (required) - metric: Name to register the metric under (required)
- aggregation: Name of aggregation function in metric_agg_registry - aggregation: Name of aggregation function in metric_agg_registry
- higher_is_better: Whether higher scores are better (default: True) - higher_is_better: Whether higher scores are better (default: True)
...@@ -512,7 +516,7 @@ def get_metric(name, hf_evaluate_metric=False): ...@@ -512,7 +516,7 @@ def get_metric(name, hf_evaluate_metric=False):
The metric's compute function The metric's compute function
Raises: Raises:
KeyError: If metric not found in registry or HF evaluate KeyError: If a metric is not found in registry or HF evaluate
""" """
try: try:
spec = metric_registry.get(name) spec = metric_registry.get(name)
...@@ -529,7 +533,7 @@ def get_metric(name, hf_evaluate_metric=False): ...@@ -529,7 +533,7 @@ def get_metric(name, hf_evaluate_metric=False):
return hf.load(name).compute # type: ignore[attr-defined] return hf.load(name).compute # type: ignore[attr-defined]
except Exception: except Exception:
raise KeyError(f"Metric '{name}' not found anywhere") raise KeyError(f"Metric '{name}' not found anywhere") from None
register_metric_aggregation = metric_agg_registry.register register_metric_aggregation = metric_agg_registry.register
......
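
The collision handling above lets a lazily registered alias, stored as a `module:ClassName` string, be upgraded to the real class once it is imported, while any other collision raises. A simplified standalone sketch of that rule, assuming an invented `ToyRegistry` rather than lm-eval's `Registry`:

```python
# Simplified sketch of the placeholder-upgrade rule shown above
# (standalone re-implementation, not lm-eval's Registry class).
class ToyRegistry:
    def __init__(self):
        self._objs = {}

    def store(self, alias, target):
        current = self._objs.get(alias)
        if current is not None and current != target:
            if (
                isinstance(current, str)
                and isinstance(target, type)
                and current == f"{target.__module__}:{target.__name__}"
            ):
                self._objs[alias] = target  # placeholder -> real class
                return
            raise ValueError(f"alias {alias!r} already registered")
        self._objs[alias] = target


class MyModel:
    pass


reg = ToyRegistry()
reg.store("my-model", f"{MyModel.__module__}:{MyModel.__name__}")  # lazy string
reg.store("my-model", MyModel)  # upgrade from placeholder succeeds
```
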
from __future__ import annotations from __future__ import annotations
import logging import logging
import warnings from random import Random
from collections.abc import Iterable, Sequence from typing import TYPE_CHECKING
from functools import partial
from typing import TYPE_CHECKING, Any
import datasets
if TYPE_CHECKING: if TYPE_CHECKING:
from random import Random from collections.abc import Iterable, Sequence
from typing import Any, TypeVar
from lm_eval.api.task import ConfigurableTask, Task _T = TypeVar("_T")
eval_logger = logging.getLogger("lm-eval") eval_logger = logging.getLogger(__name__)
class ContextSampler: class ContextSampler:
def __init__( def __init__(
self, self,
docs: list[dict], docs: Sequence[dict[str, Any]] | None = None,
task: Task | ConfigurableTask, *,
fewshot_indices: Iterable | None = None, rnd: int | None = None,
rnd: Random | None = None, fewshot_indices: list[int] | None = None,
**kwargs,
) -> None: ) -> None:
self.rnd = rnd self.rnd = Random(rnd)
if not self.rnd: self.docs = docs or []
raise ValueError( self.fewshot_indices = fewshot_indices
"A `random.Random` generator argument must be provided to `rnd` of FewShotSampler!"
)
self.task = task
self.config = task._config
self.target_delimiter = self.config.target_delimiter
self.fewshot_delimiter = self.config.fewshot_delimiter
if (
self.config.fewshot_config is not None
and self.config.fewshot_config.get("doc_to_text", None) is not None
):
self.doc_to_text = partial(
self.task.doc_to_text,
doc_to_text=self.config.fewshot_config.get("doc_to_text", None),
)
else:
self.doc_to_text = self.task.doc_to_text
if (
self.config.fewshot_config is not None
and self.config.fewshot_config.get("doc_to_target", None) is not None
):
self.doc_to_target = partial(
self.task.doc_to_target,
doc_to_target=self.config.fewshot_config.get("doc_to_target", None),
)
else:
self.doc_to_target = self.task.doc_to_target
if (
self.config.fewshot_config is not None
and self.config.fewshot_config.get("doc_to_choice", None) is not None
):
self.doc_to_choice = partial(
self.task.doc_to_choice,
doc_to_choice=self.config.fewshot_config.get("doc_to_choice", None),
)
else:
self.doc_to_choice = self.task.doc_to_choice
self.docs = docs # HF dataset split, provided by task._fewshot_docs()
if fewshot_indices: # subset few-shot docs from
if not isinstance(self.docs, datasets.Dataset):
raise ValueError(
"Got `fewshot_indices` but fewshot_docs are not a HF dataset. Don't use both `fewshot_indices` and a user-defined few-shot sample list simultaneously"
)
self.docs = self.docs.select(fewshot_indices)
def get_context(self, doc: dict, num_fewshot: int, gen_prefix: str | None = None):
# draw an extra fewshot sample if using same split as evaluating on
prefix = gen_prefix + " " if gen_prefix else ""
n_samples = (
num_fewshot + 1
if self.config.fewshot_split == self.config.test_split
else num_fewshot
)
# draw `n_samples` docs from fewshot_docs if self.fewshot_indices and self.docs:
fewshotex = self.sample(n_samples) self.docs = [self.docs[i] for i in self.fewshot_indices]
# get rid of the doc that's the one we're evaluating, if it's in the fewshot def sample(
# TODO: should we just stop people from using fewshot from same split as evaluating? self, n: int, doc: dict[str, Any] | None = None, **kwargs
selected_docs = [x for x in fewshotex if x != doc][:num_fewshot] ) -> Sequence[dict]:
labeled_examples = ""
for doc in selected_docs:
doc_content = self.doc_to_text(doc)
doc_target = self.doc_to_target(doc)
if (
self.config.doc_to_choice is None and isinstance(doc_content, str)
) or isinstance(doc_content, str):
labeled_examples += doc_content
else:
if isinstance(doc_content, int):
labeled_examples += self.doc_to_choice(doc)[doc_content]
if doc_target != "":
if self.target_delimiter.isspace() and str(doc_target)[0].isspace():
# TODO: add logger warn once here.
warnings.warn(
"Both target_delimiter and target start with a space. This may cause issues.",
Warning,
stacklevel=2,
)
labeled_examples += self.target_delimiter
labeled_examples += prefix
labeled_examples += (
str(doc_target[0])
if isinstance(doc_target, list)
else doc_target
if self.config.doc_to_choice is None or isinstance(doc_target, str)
else str(self.doc_to_choice(doc)[doc_target])
)
labeled_examples += self.fewshot_delimiter
return labeled_examples
def get_chat_context(
self,
doc: dict,
num_fewshot: int,
fewshot_as_multiturn: bool = False,
gen_prefix: str | None = None,
):
# TODO: Do we need any other delimiter
prefix = gen_prefix + " " if gen_prefix else ""
chat_history = []
# draw an extra fewshot sample if using same split as evaluating on
n_samples = (
num_fewshot + 1
if self.config.fewshot_split == self.config.test_split
else num_fewshot
)
# draw `n_samples` docs from fewshot_docs
fewshotex = self.sample(n_samples)
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
# TODO: should we just stop people from using fewshot from same split as evaluating?
selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]
if fewshot_as_multiturn:
for doc in selected_docs:
doc_content = self.doc_to_text(doc)
doc_target = self.doc_to_target(doc)
chat_history.append(
{
"role": "user",
"content": doc_content
if self.config.doc_to_choice is None
or isinstance(doc_content, str)
else self.doc_to_choice(doc)[doc_content],
}
)
chat_history.append(
{
"role": "assistant",
"content": prefix + str(doc_target[0])
if isinstance(doc_target, list)
else prefix + doc_target
if self.config.doc_to_choice is None
or isinstance(doc_target, str)
else prefix + str(self.doc_to_choice(doc)[doc_target]),
}
)
else:
# get fewshot context as one user turn
chat_history.append(
{
"role": "user",
"content": self.get_context(
doc, num_fewshot, gen_prefix=gen_prefix
),
}
)
return chat_history
# @classmethod
# def from_fewshot_dfg(cls, cfg: FewshotConfig):
# if not
def sample(self, n: int) -> Sequence[dict]:
""" """
Draw `n` samples from our fewshot docs. This method should be overridden by subclasses. Sample n documents from the pool.
Args:
n: Number of documents to sample
doc: Optional document to exclude from sampling
Returns:
List of sampled documents
""" """
assert self.rnd is not None, ( if n <= 0:
"Error: `rnd` must be set to a random.Random instance before sampling." return []
return (
self.rnd.sample(self.docs, n)
if not doc
else self.remove_doc(doc, self.rnd.sample(self.docs, n + 1))
) )
return self.rnd.sample(self.docs, n)
def set_rnd(self, rnd: int) -> None:
self.rnd = Random(rnd)
@staticmethod
def remove_doc(doc: _T, _iter: Iterable[_T]) -> list[_T]:
return [x for x in _iter if x != doc]
class FirstNSampler(ContextSampler): class FirstNSampler(ContextSampler):
def sample(self, n: int) -> Sequence[dict[str, Any]]: def sample(self, n: int, doc=None, **kwargs):
""" """
Draw the first `n` samples in order from the specified split. Draw the first `n` samples in order from the specified split.
Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU. Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
...@@ -214,7 +72,7 @@ class FirstNSampler(ContextSampler): ...@@ -214,7 +72,7 @@ class FirstNSampler(ContextSampler):
class BalancedSampler(ContextSampler): class BalancedSampler(ContextSampler):
def sample(self, n: int): def sample(self, n: int, doc=None, **kwargs):
""" """
TODO: this should return approximately class-balanced samples from our fewshot examples. TODO: this should return approximately class-balanced samples from our fewshot examples.
TODO: what order should they be in? maybe random? TODO: what order should they be in? maybe random?
...@@ -224,7 +82,7 @@ class BalancedSampler(ContextSampler): ...@@ -224,7 +82,7 @@ class BalancedSampler(ContextSampler):
class ManualSampler(ContextSampler): class ManualSampler(ContextSampler):
def sample(self, n: int): def sample(self, n: int, doc=None, **kwargs):
""" """ """ """
raise NotImplementedError raise NotImplementedError
......
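
Under the refactor, the sampler is seeded with a plain integer (`Random(rnd)`), and `sample` may receive the document currently being evaluated, in which case one extra example is drawn and the evaluated document is filtered back out. A rough standalone sketch of that flow (assumed simplification; `fewshot_indices` handling and subclass orderings are omitted, and truncation to the requested number of shots is assumed to happen downstream):

```python
# Standalone sketch of the new sampling flow shown above.
from random import Random

docs = [{"q": f"question {i}", "a": f"answer {i}"} for i in range(10)]
rnd = Random(1234)


def remove_doc(doc, items):
    return [x for x in items if x != doc]


def sample(n, doc=None):
    if n <= 0:
        return []
    if doc is None:
        return rnd.sample(docs, n)
    # over-draw by one so n examples survive removing `doc` (if it was drawn)
    return remove_doc(doc, rnd.sample(docs, n + 1))


fewshot = sample(3, doc=docs[0])
assert docs[0] not in fewshot and len(fewshot) in (3, 4)
```
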

...@@ -4,7 +4,7 @@ import textwrap ...@@ -4,7 +4,7 @@ import textwrap
from argparse import Namespace from argparse import Namespace
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional, Union from typing import TYPE_CHECKING, Any, Optional, Union
import yaml import yaml
...@@ -214,7 +214,7 @@ class EvaluatorConfig: ...@@ -214,7 +214,7 @@ class EvaluatorConfig:
# Parse string arguments that should be dictionaries # Parse string arguments that should be dictionaries
config = cls._parse_dict_args(config) config = cls._parse_dict_args(config)
# Create instance and validate # Create an instance and validate
instance = cls(**config) instance = cls(**config)
if used_config: if used_config:
print(textwrap.dedent(f"""{instance}""")) print(textwrap.dedent(f"""{instance}"""))
...@@ -238,7 +238,7 @@ class EvaluatorConfig: ...@@ -238,7 +238,7 @@ class EvaluatorConfig:
return instance return instance
@staticmethod @staticmethod
def _parse_dict_args(config: Dict[str, Any]) -> Dict[str, Any]: def _parse_dict_args(config: dict[str, Any]) -> dict[str, Any]:
"""Parse string arguments that should be dictionaries.""" """Parse string arguments that should be dictionaries."""
for key in config: for key in config:
if key in DICT_KEYS and isinstance(config[key], str): if key in DICT_KEYS and isinstance(config[key], str):
...@@ -246,7 +246,7 @@ class EvaluatorConfig: ...@@ -246,7 +246,7 @@ class EvaluatorConfig:
return config return config
@staticmethod @staticmethod
def load_yaml_config(config_path: Union[str, Path]) -> Dict[str, Any]: def load_yaml_config(config_path: Union[str, Path]) -> dict[str, Any]:
"""Load and validate YAML config file.""" """Load and validate YAML config file."""
config_file = ( config_file = (
Path(config_path) if not isinstance(config_path, Path) else config_path Path(config_path) if not isinstance(config_path, Path) else config_path
...@@ -257,9 +257,9 @@ class EvaluatorConfig: ...@@ -257,9 +257,9 @@ class EvaluatorConfig:
try: try:
yaml_data = yaml.safe_load(config_file.read_text()) yaml_data = yaml.safe_load(config_file.read_text())
except yaml.YAMLError as e: except yaml.YAMLError as e:
raise ValueError(f"Invalid YAML in {config_path}: {e}") raise ValueError(f"Invalid YAML in {config_path}: {e}") from e
except (OSError, UnicodeDecodeError) as e: except (OSError, UnicodeDecodeError) as e:
raise ValueError(f"Could not read config file {config_path}: {e}") raise ValueError(f"Could not read config file {config_path}: {e}") from e
if not isinstance(yaml_data, dict): if not isinstance(yaml_data, dict):
raise ValueError( raise ValueError(
...@@ -307,7 +307,7 @@ class EvaluatorConfig: ...@@ -307,7 +307,7 @@ class EvaluatorConfig:
raise ValueError("Need to specify task to evaluate.") raise ValueError("Need to specify task to evaluate.")
def _process_arguments(self) -> None: def _process_arguments(self) -> None:
"""Process samples argument - load from file if needed.""" """Process samples argument - load from a file if needed."""
if self.samples: if self.samples:
if isinstance(self.samples, dict): if isinstance(self.samples, dict):
self.samples = self.samples self.samples = self.samples
...@@ -328,7 +328,6 @@ class EvaluatorConfig: ...@@ -328,7 +328,6 @@ class EvaluatorConfig:
def process_tasks(self, metadata: Optional[dict] = None) -> "TaskManager": def process_tasks(self, metadata: Optional[dict] = None) -> "TaskManager":
"""Process and validate tasks, return resolved task names.""" """Process and validate tasks, return resolved task names."""
from lm_eval import utils
from lm_eval.tasks import TaskManager from lm_eval.tasks import TaskManager
# if metadata manually passed use that: # if metadata manually passed use that:
...@@ -365,7 +364,7 @@ class EvaluatorConfig: ...@@ -365,7 +364,7 @@ class EvaluatorConfig:
return task_manager return task_manager
def _set_trust_remote_code(self) -> None: def _set_trust_remote_code(self) -> None:
"""Apply trust_remote_code setting if enabled.""" """Apply the trust_remote_code setting if enabled."""
if self.trust_remote_code: if self.trust_remote_code:
# HACK: import datasets and override its HF_DATASETS_TRUST_REMOTE_CODE value internally, # HACK: import datasets and override its HF_DATASETS_TRUST_REMOTE_CODE value internally,
# because it's already been determined based on the prior env var before launching our # because it's already been determined based on the prior env var before launching our
......
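
The error handling added above chains the original exception with `from e`, so YAML parse failures and I/O failures both surface as a single `ValueError`. A minimal standalone sketch of that loading pattern, assuming a simplified free-standing `load_yaml_config` rather than the full `EvaluatorConfig` method:

```python
# Minimal sketch of the YAML-loading pattern above (assumed simplification).
from pathlib import Path

import yaml


def load_yaml_config(config_path):
    config_file = Path(config_path)
    try:
        yaml_data = yaml.safe_load(config_file.read_text())
    except yaml.YAMLError as e:
        raise ValueError(f"Invalid YAML in {config_path}: {e}") from e
    except (OSError, UnicodeDecodeError) as e:
        raise ValueError(f"Could not read config file {config_path}: {e}") from e
    if not isinstance(yaml_data, dict):
        raise ValueError(f"Top-level YAML in {config_path} must be a mapping")
    return yaml_data


Path("demo_config.yaml").write_text("tasks: [hellaswag]\nnum_fewshot: 5\n")
print(load_yaml_config("demo_config.yaml"))  # {'tasks': ['hellaswag'], 'num_fewshot': 5}
```
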
...@@ -21,7 +21,7 @@ def serialize_callable( ...@@ -21,7 +21,7 @@ def serialize_callable(
return value return value
else: else:
try: try:
return getsource(value) return getsource(value) # type: ignore
except (TypeError, OSError): except (TypeError, OSError):
return str(value) return str(value)
......
...@@ -60,10 +60,10 @@ def _load_module_with_cache(module_path: Path) -> Any: ...@@ -60,10 +60,10 @@ def _load_module_with_cache(module_path: Path) -> Any:
module_parts = relative_path.replace(".py", "").replace("/", ".") module_parts = relative_path.replace(".py", "").replace("/", ".")
module_name = f"lm_eval.tasks.{module_parts}" module_name = f"lm_eval.tasks.{module_parts}"
else: else:
# Fallback to full path if pattern not found # Fallback to a full path if a pattern not found
module_name = str(module_path.with_suffix("")) module_name = str(module_path.with_suffix(""))
else: else:
# External module - use full path without extension # External module - use a full path without extension
module_name = str(module_path.with_suffix("")) module_name = str(module_path.with_suffix(""))
# Check if we need to reload the module # Check if we need to reload the module
...@@ -84,7 +84,7 @@ def _load_module_with_cache(module_path: Path) -> Any: ...@@ -84,7 +84,7 @@ def _load_module_with_cache(module_path: Path) -> Any:
raise ImportError(f"Cannot load module from {module_path}") from None raise ImportError(f"Cannot load module from {module_path}") from None
module = importlib.util.module_from_spec(spec) module = importlib.util.module_from_spec(spec)
# Store mtime for future checks # Store mtime for future checks
module.__mtime__ = module_path.stat().st_mtime_ns module.__mtime__ = module_path.stat().st_mtime_ns # type: ignore
spec.loader.exec_module(module) # type: ignore[arg-type] spec.loader.exec_module(module) # type: ignore[arg-type]
sys.modules[module_name] = module sys.modules[module_name] = module
return module return module
......
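
For context, the surrounding helper caches loaded task modules by storing the file's mtime on the module object and reloading only when the file changes on disk. A rough standalone sketch of that caching pattern (assumed simplification of `_load_module_with_cache`):

```python
# Rough sketch of the mtime-based module cache described above (assumed
# simplification; module-name derivation from the path is omitted).
import importlib.util
import sys
from pathlib import Path


def load_module_with_cache(module_path: Path, module_name: str):
    cached = sys.modules.get(module_name)
    mtime = module_path.stat().st_mtime_ns
    if cached is not None and getattr(cached, "__mtime__", None) == mtime:
        return cached  # file unchanged, reuse the cached module
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module from {module_path}")
    module = importlib.util.module_from_spec(spec)
    module.__mtime__ = mtime  # store mtime for future checks
    spec.loader.exec_module(module)
    sys.modules[module_name] = module
    return module
```
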
...@@ -32,9 +32,9 @@ class TaskFactory: ...@@ -32,9 +32,9 @@ class TaskFactory:
registry: Mapping[str, Entry], registry: Mapping[str, Entry],
): ):
""" """
• entry.kind == TASK / PY_TASK ➜ returns instantiated task object • entry.kind == TASK / PY_TASK ➜ returns instantiated task object
• entry.kind == GROUP ➜ returns (GroupConfig, mapping-of-subtasks) • entry.kind == GROUP ➜ returns (GroupConfig, mapping-of-subtasks)
• entry.kind == TAG ➜ returns mapping-of-tasks (tag expansion) • entry.kind == TAG ➜ returns mapping-of-tasks (tag expansion)
""" """
if entry.kind is Kind.TAG: if entry.kind is Kind.TAG:
return self._build_tag(entry, overrides, registry) return self._build_tag(entry, overrides, registry)
...@@ -121,4 +121,4 @@ class TaskFactory: ...@@ -121,4 +121,4 @@ class TaskFactory:
def _ctor_accepts_config(cls) -> bool: def _ctor_accepts_config(cls) -> bool:
init = getattr(cls, "__init__", None) init = getattr(cls, "__init__", None)
return init and "config" in inspect.signature(init).parameters return bool(init and "config" in inspect.signature(init).parameters)
...@@ -61,8 +61,8 @@ class TaskManager: ...@@ -61,8 +61,8 @@ class TaskManager:
def load_spec(self, spec: str | dict[str, Any]): def load_spec(self, spec: str | dict[str, Any]):
"""Spec can be: """Spec can be:
• str task / group / tag name (registered) • str task / group / tag name (registered)
• dict inline overrides {'task': 'hellaswag', 'num_fewshot': 5} • dict inline overrides {'task': 'hellaswag', 'num_fewshot': 5}
""" """
if isinstance(spec, str): if isinstance(spec, str):
entry = self._entry(spec) entry = self._entry(spec)
......
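
A hedged usage sketch for the two spec forms described in the `load_spec` docstring above; the `TaskManager` constructor defaults and the return value of `load_spec` are assumptions, not shown in this diff:

```python
# Hedged usage sketch of the two accepted spec forms.
from lm_eval.tasks import TaskManager

tm = TaskManager()

# a registered task / group / tag name
hellaswag = tm.load_spec("hellaswag")

# inline overrides layered on top of a registered task
hellaswag_5shot = tm.load_spec({"task": "hellaswag", "num_fewshot": 5})
```
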
...@@ -8,8 +8,8 @@ Homepage: [google-research-datasets/natural-questions@master/nq_open](https://gi ...@@ -8,8 +8,8 @@ Homepage: [google-research-datasets/natural-questions@master/nq_open](https://gi
Paper: [aclanthology.org/P19-1612](https://aclanthology.org/P19-1612/) Paper: [aclanthology.org/P19-1612](https://aclanthology.org/P19-1612/)
Derived from the Natural Questions dataset, introduced in https://storage.googleapis.com/gweb-research2023-media/pubtools/pdf/1f7b46b5378d757553d3e92ead36bda2e4254244.pdf . Derived from the Natural Questions dataset, introduced
in https://storage.googleapis.com/gweb-research2023-media/pubtools/pdf/1f7b46b5378d757553d3e92ead36bda2e4254244.pdf .
### Citation ### Citation
...@@ -26,4 +26,5 @@ journal = {Transactions of the Association of Computational Linguistics}} ...@@ -26,4 +26,5 @@ journal = {Transactions of the Association of Computational Linguistics}}
* `nq_open` * `nq_open`
### Changelog ### Changelog
* 2025-07-21: Added `multi_target` to `exact_match`. Scores should not change.
* 2025-07-21: Added `multiple_targets` to `exact_match`. Scores should not change.
...@@ -5,7 +5,7 @@ training_split: train ...@@ -5,7 +5,7 @@ training_split: train
validation_split: validation validation_split: validation
description: "Answer these questions:\n\n" description: "Answer these questions:\n\n"
doc_to_text: "Q: {{question}}?\nA:" doc_to_text: "Q: {{question}}?\nA:"
doc_to_target: "{{answer}}" doc_to_target: answer
fewshot_delimiter: "\n" fewshot_delimiter: "\n"
generation_kwargs: generation_kwargs:
until: until:
...@@ -27,7 +27,7 @@ metric_list: ...@@ -27,7 +27,7 @@ metric_list:
ignore_case: true ignore_case: true
ignore_punctuation: true ignore_punctuation: true
regexes_to_ignore: regexes_to_ignore:
- "\\b(?:The |the |An |A |The |a |an )" - "\\b(?:The |the |An |A |The |a |an )"
multi_target: true multiple_targets: true
metadata: metadata:
version: 4.0 version: 4.0
...@@ -79,3 +79,6 @@ If other tasks on this dataset are already supported: ...@@ -79,3 +79,6 @@ If other tasks on this dataset are already supported:
* [ ] Is the "Main" variant of this task clearly denoted? * [ ] Is the "Main" variant of this task clearly denoted?
* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates? * [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant? * [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
### Changelog
- 2025-07-22: `record` and `multirc`: set target_delimiter to "" and trim doc_to_text respectively.
tag: tag:
- super-glue-lm-eval-v1 - super-glue-lm-eval-v1
task: boolq task: boolq
dataset_path: super_glue dataset_path: aps/super_glue
dataset_name: boolq dataset_name: boolq
output_type: multiple_choice output_type: multiple_choice
training_split: train training_split: train
......
tag: tag:
- super-glue-lm-eval-v1-seq2seq - super-glue-lm-eval-v1-seq2seq
task: "boolq-seq2seq" task: "boolq-seq2seq"
dataset_path: super_glue dataset_path: aps/super_glue
dataset_name: boolq dataset_name: boolq
output_type: generate_until output_type: generate_until
training_split: train training_split: train
validation_split: validation validation_split: validation
doc_to_text: "{{passage}}\nQuestion: {{question}}?\nAnswer:" doc_to_text: "{{passage}}\nQuestion: {{question}}?\nAnswer:"
doc_to_target: label doc_to_target: "{{ [' no', ' yes'][label|int] }}"
doc_to_choice: [' no', ' yes']
target_delimiter: "" target_delimiter: ""
generation_kwargs: generation_kwargs:
until: until:
......
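
The change above folds the choice lookup into the Jinja template itself: the integer `label` indexes a literal list, so `doc_to_choice` is no longer needed. A hedged illustration using the `jinja2` package directly (lm-eval's own template handling may differ in details); the same pattern applies to the other seq2seq/t5-prompt variants in this diff:

```python
# Rendering the new doc_to_target template directly with jinja2.
from jinja2 import Template

template = Template("{{ [' no', ' yes'][label|int] }}")
print(repr(template.render(label=0)))  # ' no'
print(repr(template.render(label=1)))  # ' yes'
# The leading space stands in for the delimiter, since target_delimiter is "".
```
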
tag: tag:
- super-glue-t5-prompt - super-glue-t5-prompt
task: super_glue-boolq-t5-prompt task: super_glue-boolq-t5-prompt
dataset_path: super_glue dataset_path: aps/super_glue
dataset_name: boolq dataset_name: boolq
training_split: train training_split: train
validation_split: validation validation_split: validation
output_type: generate_until output_type: generate_until
doc_to_text: "boolq passage: {{passage}} question: {{question}}" doc_to_text: "boolq passage: {{passage}} question: {{question}}"
doc_to_target: label doc_to_target: "{{['False', 'True'][label|int]}}"
doc_to_choice: ['False', 'True']
generation_kwargs: generation_kwargs:
until: until:
- "</s>" - "</s>"
......
tag: tag:
- super-glue-lm-eval-v1 - super-glue-lm-eval-v1
task: cb task: cb
dataset_path: super_glue dataset_path: aps/super_glue
dataset_name: cb dataset_name: cb
output_type: multiple_choice output_type: multiple_choice
training_split: train training_split: train
validation_split: validation validation_split: validation
doc_to_text: "{{premise}}\nQuestion: {{hypothesis}}. True, False, or Neither?\nAnswer:" doc_to_text: "{{premise}}\nQuestion: {{hypothesis}}. True, False, or Neither?\nAnswer:"
doc_to_target: label doc_to_target: label
doc_to_choice: ['True', 'False', 'Neither'] doc_to_choice: ["True", "False", "Neither"]
metric_list: metric_list:
- metric: acc - metric: acc
- metric: f1 - metric: f1
......
tag: tag:
- super-glue-t5-prompt - super-glue-t5-prompt
task: super_glue-cb-t5-prompt task: super_glue-cb-t5-prompt
dataset_path: super_glue dataset_path: aps/super_glue
dataset_name: cb dataset_name: cb
training_split: train training_split: train
validation_split: validation validation_split: validation
output_type: generate_until output_type: generate_until
doc_to_text: "cb hypothesis: {{hypothesis}} premise: {{premise}}" doc_to_text: "cb hypothesis: {{hypothesis}} premise: {{premise}}"
doc_to_target: label doc_to_target: "{{ ['entailment', 'contradiction', 'neutral'][label|int] }}"
doc_to_choice: ['entailment', 'contradiction', 'neutral']
generation_kwargs: generation_kwargs:
until: until:
- "</s>" - "</s>"
......