"megatron/legacy/model/bert_model.py" did not exist on "52a5f2f272e5ef242eb271227da712f7dfc55da3"
Unverified Commit 003e5852 authored by Baber Abbasi, committed by GitHub

Fewshot refactor (#3227)



* overhaul `ContextSampler`

* refactor masakhapos

* move multi_target to `exact_match`

* remove doc_to_choice from `boolq-seq2seq`

* remove doc_to_choice in generation process_results

* Remove unused `doc_to_choice` and fix superglue whitespaces

* require multiple_inputs and multiple_targets to be explicitly set in taskconfig

* fix copa; better logging in task init

* fix doc_to_target to return int rather than str (deprecated)

* fix processing regression; recursively parse lists from template

* remove redundant jinja parsing logic

* remove promptsource

* for multiple_inputs use `doc_to_text: list[str]`

* Refactor `ContextSampler` and `fewshot_context`

* fix multiple_input context

* fix `target_delimiter` with `gen_prefix`

* `doc_to_text` is list for multiple_inputs

* Refactor `count_bytes` and `count_words` methods to `@staticmethod`

* make has_*(train/test/validation) into properties

* remove `multi_target` from `generate_until`

* fix doc_to_target/multiple_targets handling; add tests

* rename `multi_target` to `multiple_targets`

* evaluate list when multiple targets

* allow doc_to_target to return list

* Remove gen_prefix space and add warning (#3239)

* Remove gen_prefix space and add warning

* fix null gen_prefix bug again

* use git tests

---------
Co-authored-by: Boaz Ben-Dov <bendboaz@gmail.com>
parent 79a22a11
......@@ -8,8 +8,6 @@ on:
branches:
- 'main'
pull_request:
branches:
- 'main'
workflow_dispatch:
# Jobs run concurrently and steps run sequentially within a job.
# jobs: linter and cpu_tests. Add more jobs/steps as required.
......
......@@ -53,6 +53,6 @@ class FilterEnsemble:
resps = f().apply(resps, docs)
# add the end results after filtering to filtered_requests of their respective source instances.
# has key `self.name`: each FilterEnsemble applied in a given run should use a different name.
# has a key ` self.name `: each FilterEnsemble applied in a given run should use a different name.
for inst, resp in zip(instances, resps):
inst.filtered_resps[self.name] = resp
from dataclasses import dataclass, field
from typing import Literal, Optional, Tuple
from typing import Any, Literal, Optional
OutputType = Literal[
......@@ -10,10 +10,10 @@ OutputType = Literal[
@dataclass
class Instance:
request_type: OutputType
doc: dict
doc: dict[str, Any]
arguments: tuple
idx: int
metadata: Tuple[Optional[str], Optional[int], Optional[int]] = field(
metadata: tuple[Optional[str], Optional[int], Optional[int]] = field(
default_factory=lambda: (None, None, None),
metadata=dict(
description="Metadata tuple containing task name, document ID, and number of repeats."
......
......@@ -213,7 +213,7 @@ def exact_match_hf_evaluate(
ignore_case: bool = False,
ignore_punctuation: bool = False,
ignore_numbers: bool = False,
multi_target: bool = False,
multiple_targets: bool = False,
):
"""
Compute exact match scores between predictions and references.
......@@ -245,8 +245,8 @@ def exact_match_hf_evaluate(
- "exact_match" (float): The mean exact match score or 1.0/0.0 if `multi_target` is True.
"""
predictions, references = list(predictions), list(references)
assert len(predictions) == len(references) if not multi_target else True, (
"predictions and references must have the same length unless `multi_target` is True"
assert len(predictions) == len(references) if not multiple_targets else True, (
"predictions and references must have the same length unless `multiple_targets` is True"
)
if regexes_to_ignore is not None:
......@@ -275,7 +275,7 @@ def exact_match_hf_evaluate(
return {
"exact_match": np.mean(score_list)
if not multi_target
if not multiple_targets
else float(np.any(score_list))
}
......
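A minimal sketch (illustrative only, not the library function itself) of the aggregation difference the renamed `multiple_targets` flag controls: scores are averaged pairwise by default, but collapse to a single 1.0/0.0 "any reference matched" result when one prediction has several gold targets.

```python
# Hypothetical helper name; mirrors the np.mean / np.any branch in the diff above.
import numpy as np

def aggregate_exact_match(score_list: list[int], multiple_targets: bool = False) -> float:
    # Default: mean over (prediction, reference) pairs.
    # multiple_targets: the single prediction counts as correct if any reference matched.
    return float(np.any(score_list)) if multiple_targets else float(np.mean(score_list))

print(aggregate_exact_match([0, 1, 0]))                         # ~0.333
print(aggregate_exact_match([0, 1, 0], multiple_targets=True))  # 1.0
```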
......@@ -220,8 +220,8 @@ class Registry(Generic[T]):
>>> model_registry.register("lazy-name", lazy="mymodule:MyModel")
Raises:
ValueError: If alias already registered with different target
TypeError: If object doesn't inherit from base_cls (when specified)
ValueError: If alias is already registered with a different target
TypeError: If an object doesn't inherit from base_cls (when specified)
"""
def _store(alias: str, target: T | Placeholder) -> None:
......@@ -229,21 +229,27 @@ class Registry(Generic[T]):
# collision handling ------------------------------------------
if current is not None and current != target:
# allow placeholder → real object upgrade
if isinstance(current, str) and isinstance(target, type):
# mod, _, cls = current.partition(":")
if current == f"{target.__module__}:{target.__name__}":
self._objs[alias] = target
return
# mod, _, cls = current.partition(":")
if (
isinstance(current, str)
and isinstance(target, type)
and current == f"{target.__module__}:{target.__name__}"
):
self._objs[alias] = target
return
raise ValueError(
f"{self._name!r} alias '{alias}' already registered ("
f"existing={current}, new={target})"
)
# type check for concrete classes ----------------------------------------------
if self._base_cls is not None and isinstance(target, type):
if not issubclass(target, self._base_cls): # type: ignore[arg-type]
raise TypeError(
f"{target} must inherit from {self._base_cls} to be a {self._name}"
)
if (
self._base_cls is not None
and isinstance(target, type)
and not issubclass(target, self._base_cls)
):
raise TypeError(
f"{target} must inherit from {self._base_cls} to be a {self._name}"
)
self._objs[alias] = target
def decorator(obj: T) -> T: # type: ignore[valid-type]
......@@ -409,9 +415,7 @@ class MetricSpec:
from lm_eval.api.model import LM # noqa: E402
model_registry: Registry[type[LM]] = cast(
Registry[type[LM]], Registry("model", base_cls=LM)
)
model_registry = cast(Registry[type[LM]], Registry("model", base_cls=LM))
task_registry: Registry[Callable[..., Any]] = Registry("task")
metric_registry: Registry[MetricSpec] = Registry("metric")
metric_agg_registry: Registry[Callable[[Iterable[Any]], float]] = Registry(
......@@ -457,7 +461,7 @@ def register_metric(**kw):
then registers it in the metric registry.
Args:
**kw: Keyword arguments including:
**kw: Keyword arguments including
- metric: Name to register the metric under (required)
- aggregation: Name of aggregation function in metric_agg_registry
- higher_is_better: Whether higher scores are better (default: True)
......@@ -512,7 +516,7 @@ def get_metric(name, hf_evaluate_metric=False):
The metric's compute function
Raises:
KeyError: If metric not found in registry or HF evaluate
KeyError: If a metric is not found in registry or HF evaluate
"""
try:
spec = metric_registry.get(name)
......@@ -529,7 +533,7 @@ def get_metric(name, hf_evaluate_metric=False):
return hf.load(name).compute # type: ignore[attr-defined]
except Exception:
raise KeyError(f"Metric '{name}' not found anywhere")
raise KeyError(f"Metric '{name}' not found anywhere") from None
register_metric_aggregation = metric_agg_registry.register
......
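The collision check above permits exactly one kind of overwrite: a lazy `"module:Class"` placeholder string may be upgraded to the class it names; any other collision raises. A standalone sketch of that rule, with hypothetical names rather than the real `Registry` class:

```python
# Hypothetical mini-registry illustrating the placeholder -> class upgrade path.
_objs: dict[str, object] = {}

def store(alias: str, target: object) -> None:
    current = _objs.get(alias)
    if current is not None and current != target:
        if (
            isinstance(current, str)
            and isinstance(target, type)
            and current == f"{target.__module__}:{target.__name__}"
        ):
            _objs[alias] = target  # lazy string placeholder upgraded to the real class
            return
        raise ValueError(f"alias {alias!r} already registered (existing={current}, new={target})")
    _objs[alias] = target

class MyModel:  # demo class only
    pass

store("lazy-name", f"{MyModel.__module__}:{MyModel.__name__}")  # register placeholder
store("lazy-name", MyModel)                                     # allowed: upgrade in place
```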
from __future__ import annotations
import logging
import warnings
from collections.abc import Iterable, Sequence
from functools import partial
from typing import TYPE_CHECKING, Any
import datasets
from random import Random
from typing import TYPE_CHECKING
if TYPE_CHECKING:
from random import Random
from collections.abc import Iterable, Sequence
from typing import Any, TypeVar
from lm_eval.api.task import ConfigurableTask, Task
_T = TypeVar("_T")
eval_logger = logging.getLogger("lm-eval")
eval_logger = logging.getLogger(__name__)
class ContextSampler:
def __init__(
self,
docs: list[dict],
task: Task | ConfigurableTask,
fewshot_indices: Iterable | None = None,
rnd: Random | None = None,
docs: Sequence[dict[str, Any]] | None = None,
*,
rnd: int | None = None,
fewshot_indices: list[int] | None = None,
**kwargs,
) -> None:
self.rnd = rnd
if not self.rnd:
raise ValueError(
"A `random.Random` generator argument must be provided to `rnd` of FewShotSampler!"
)
self.task = task
self.config = task._config
self.target_delimiter = self.config.target_delimiter
self.fewshot_delimiter = self.config.fewshot_delimiter
if (
self.config.fewshot_config is not None
and self.config.fewshot_config.get("doc_to_text", None) is not None
):
self.doc_to_text = partial(
self.task.doc_to_text,
doc_to_text=self.config.fewshot_config.get("doc_to_text", None),
)
else:
self.doc_to_text = self.task.doc_to_text
if (
self.config.fewshot_config is not None
and self.config.fewshot_config.get("doc_to_target", None) is not None
):
self.doc_to_target = partial(
self.task.doc_to_target,
doc_to_target=self.config.fewshot_config.get("doc_to_target", None),
)
else:
self.doc_to_target = self.task.doc_to_target
if (
self.config.fewshot_config is not None
and self.config.fewshot_config.get("doc_to_choice", None) is not None
):
self.doc_to_choice = partial(
self.task.doc_to_choice,
doc_to_choice=self.config.fewshot_config.get("doc_to_choice", None),
)
else:
self.doc_to_choice = self.task.doc_to_choice
self.docs = docs # HF dataset split, provided by task._fewshot_docs()
if fewshot_indices: # subset few-shot docs from
if not isinstance(self.docs, datasets.Dataset):
raise ValueError(
"Got `fewshot_indices` but fewshot_docs are not a HF dataset. Don't use both `fewshot_indices` and a user-defined few-shot sample list simultaneously"
)
self.docs = self.docs.select(fewshot_indices)
def get_context(self, doc: dict, num_fewshot: int, gen_prefix: str | None = None):
# draw an extra fewshot sample if using same split as evaluating on
prefix = gen_prefix + " " if gen_prefix else ""
n_samples = (
num_fewshot + 1
if self.config.fewshot_split == self.config.test_split
else num_fewshot
)
self.rnd = Random(rnd)
self.docs = docs or []
self.fewshot_indices = fewshot_indices
# draw `n_samples` docs from fewshot_docs
fewshotex = self.sample(n_samples)
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
# TODO: should we just stop people from using fewshot from same split as evaluating?
selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]
labeled_examples = ""
for doc in selected_docs:
doc_content = self.doc_to_text(doc)
doc_target = self.doc_to_target(doc)
if (
self.config.doc_to_choice is None and isinstance(doc_content, str)
) or isinstance(doc_content, str):
labeled_examples += doc_content
else:
if isinstance(doc_content, int):
labeled_examples += self.doc_to_choice(doc)[doc_content]
if doc_target != "":
if self.target_delimiter.isspace() and str(doc_target)[0].isspace():
# TODO: add logger warn once here.
warnings.warn(
"Both target_delimiter and target start with a space. This may cause issues.",
Warning,
stacklevel=2,
)
labeled_examples += self.target_delimiter
labeled_examples += prefix
labeled_examples += (
str(doc_target[0])
if isinstance(doc_target, list)
else doc_target
if self.config.doc_to_choice is None or isinstance(doc_target, str)
else str(self.doc_to_choice(doc)[doc_target])
)
labeled_examples += self.fewshot_delimiter
return labeled_examples
def get_chat_context(
self,
doc: dict,
num_fewshot: int,
fewshot_as_multiturn: bool = False,
gen_prefix: str | None = None,
):
# TODO: Do we need any other delimiter
prefix = gen_prefix + " " if gen_prefix else ""
chat_history = []
# draw an extra fewshot sample if using same split as evaluating on
n_samples = (
num_fewshot + 1
if self.config.fewshot_split == self.config.test_split
else num_fewshot
)
# draw `n_samples` docs from fewshot_docs
fewshotex = self.sample(n_samples)
# get rid of the doc that's the one we're evaluating, if it's in the fewshot
# TODO: should we just stop people from using fewshot from same split as evaluating?
selected_docs = [x for x in fewshotex if x != doc][:num_fewshot]
if fewshot_as_multiturn:
for doc in selected_docs:
doc_content = self.doc_to_text(doc)
doc_target = self.doc_to_target(doc)
chat_history.append(
{
"role": "user",
"content": doc_content
if self.config.doc_to_choice is None
or isinstance(doc_content, str)
else self.doc_to_choice(doc)[doc_content],
}
)
chat_history.append(
{
"role": "assistant",
"content": prefix + str(doc_target[0])
if isinstance(doc_target, list)
else prefix + doc_target
if self.config.doc_to_choice is None
or isinstance(doc_target, str)
else prefix + str(self.doc_to_choice(doc)[doc_target]),
}
)
else:
# get fewshot context as one user turn
chat_history.append(
{
"role": "user",
"content": self.get_context(
doc, num_fewshot, gen_prefix=gen_prefix
),
}
)
return chat_history
# @classmethod
# def from_fewshot_dfg(cls, cfg: FewshotConfig):
# if not
def sample(self, n: int) -> Sequence[dict]:
if self.fewshot_indices and self.docs:
self.docs = [self.docs[i] for i in self.fewshot_indices]
def sample(
self, n: int, doc: dict[str, Any] | None = None, **kwargs
) -> Sequence[dict]:
"""
Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
Sample n documents from the pool.
Args:
n: Number of documents to sample
doc: Optional document to exclude from sampling
Returns:
List of sampled documents
"""
assert self.rnd is not None, (
"Error: `rnd` must be set to a random.Random instance before sampling."
if n <= 0:
return []
return (
self.rnd.sample(self.docs, n)
if not doc
else self.remove_doc(doc, self.rnd.sample(self.docs, n + 1))
)
return self.rnd.sample(self.docs, n)
def set_rnd(self, rnd: int) -> None:
self.rnd = Random(rnd)
@staticmethod
def remove_doc(doc: _T, _iter: Iterable[_T]) -> list[_T]:
return [x for x in _iter if x != doc]
class FirstNSampler(ContextSampler):
def sample(self, n: int) -> Sequence[dict[str, Any]]:
def sample(self, n: int, doc=None, **kwargs):
"""
Draw the first `n` samples in order from the specified split.
Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
......@@ -214,7 +72,7 @@ class FirstNSampler(ContextSampler):
class BalancedSampler(ContextSampler):
def sample(self, n: int):
def sample(self, n: int, doc=None, **kwargs):
"""
TODO: this should return approximately class-balanced samples from our fewshot examples.
TODO: what order should they be in? maybe random?
......@@ -224,7 +82,7 @@ class BalancedSampler(ContextSampler):
class ManualSampler(ContextSampler):
def sample(self, n: int):
def sample(self, n: int, doc=None, **kwargs):
""" """
raise NotImplementedError
......
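The refactored `sample` draws one extra document when the doc under evaluation might be in the few-shot pool, then filters it out and trims to the requested count. A rough standalone sketch of that pattern (the function name and seeding are illustrative, not the class's exact behavior):

```python
from __future__ import annotations

from random import Random
from typing import Any

def sample_fewshot(
    docs: list[dict[str, Any]],
    n: int,
    seed: int = 1234,
    doc: dict[str, Any] | None = None,
) -> list[dict[str, Any]]:
    rnd = Random(seed)
    if n <= 0:
        return []
    if doc is None:
        return rnd.sample(docs, n)
    # Over-sample by one, drop the evaluated doc if it was drawn, trim to n.
    return [x for x in rnd.sample(docs, n + 1) if x != doc][:n]

pool = [{"q": f"q{i}", "a": f"a{i}"} for i in range(10)]
print(sample_fewshot(pool, 3, doc=pool[0]))
```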
......@@ -4,7 +4,7 @@ import textwrap
from argparse import Namespace
from dataclasses import asdict, dataclass, field
from pathlib import Path
from typing import TYPE_CHECKING, Any, Dict, Optional, Union
from typing import TYPE_CHECKING, Any, Optional, Union
import yaml
......@@ -214,7 +214,7 @@ class EvaluatorConfig:
# Parse string arguments that should be dictionaries
config = cls._parse_dict_args(config)
# Create instance and validate
# Create an instance and validate
instance = cls(**config)
if used_config:
print(textwrap.dedent(f"""{instance}"""))
......@@ -238,7 +238,7 @@ class EvaluatorConfig:
return instance
@staticmethod
def _parse_dict_args(config: Dict[str, Any]) -> Dict[str, Any]:
def _parse_dict_args(config: dict[str, Any]) -> dict[str, Any]:
"""Parse string arguments that should be dictionaries."""
for key in config:
if key in DICT_KEYS and isinstance(config[key], str):
......@@ -246,7 +246,7 @@ class EvaluatorConfig:
return config
@staticmethod
def load_yaml_config(config_path: Union[str, Path]) -> Dict[str, Any]:
def load_yaml_config(config_path: Union[str, Path]) -> dict[str, Any]:
"""Load and validate YAML config file."""
config_file = (
Path(config_path) if not isinstance(config_path, Path) else config_path
......@@ -257,9 +257,9 @@ class EvaluatorConfig:
try:
yaml_data = yaml.safe_load(config_file.read_text())
except yaml.YAMLError as e:
raise ValueError(f"Invalid YAML in {config_path}: {e}")
raise ValueError(f"Invalid YAML in {config_path}: {e}") from e
except (OSError, UnicodeDecodeError) as e:
raise ValueError(f"Could not read config file {config_path}: {e}")
raise ValueError(f"Could not read config file {config_path}: {e}") from e
if not isinstance(yaml_data, dict):
raise ValueError(
......@@ -307,7 +307,7 @@ class EvaluatorConfig:
raise ValueError("Need to specify task to evaluate.")
def _process_arguments(self) -> None:
"""Process samples argument - load from file if needed."""
"""Process samples argument - load from a file if needed."""
if self.samples:
if isinstance(self.samples, dict):
self.samples = self.samples
......@@ -328,7 +328,6 @@ class EvaluatorConfig:
def process_tasks(self, metadata: Optional[dict] = None) -> "TaskManager":
"""Process and validate tasks, return resolved task names."""
from lm_eval import utils
from lm_eval.tasks import TaskManager
# if metadata manually passed use that:
......@@ -365,7 +364,7 @@ class EvaluatorConfig:
return task_manager
def _set_trust_remote_code(self) -> None:
"""Apply trust_remote_code setting if enabled."""
"""Apply the trust_remote_code setting if enabled."""
if self.trust_remote_code:
# HACK: import datasets and override its HF_DATASETS_TRUST_REMOTE_CODE value internally,
# because it's already been determined based on the prior env var before launching our
......
......@@ -21,7 +21,7 @@ def serialize_callable(
return value
else:
try:
return getsource(value)
return getsource(value) # type: ignore
except (TypeError, OSError):
return str(value)
......
......@@ -60,10 +60,10 @@ def _load_module_with_cache(module_path: Path) -> Any:
module_parts = relative_path.replace(".py", "").replace("/", ".")
module_name = f"lm_eval.tasks.{module_parts}"
else:
# Fallback to full path if pattern not found
# Fallback to a full path if a pattern not found
module_name = str(module_path.with_suffix(""))
else:
# External module - use full path without extension
# External module - use a full path without extension
module_name = str(module_path.with_suffix(""))
# Check if we need to reload the module
......@@ -84,7 +84,7 @@ def _load_module_with_cache(module_path: Path) -> Any:
raise ImportError(f"Cannot load module from {module_path}") from None
module = importlib.util.module_from_spec(spec)
# Store mtime for future checks
module.__mtime__ = module_path.stat().st_mtime_ns
module.__mtime__ = module_path.stat().st_mtime_ns # type: ignore
spec.loader.exec_module(module) # type: ignore[arg-type]
sys.modules[module_name] = module
return module
......
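For context, a rough self-contained sketch of the mtime-based caching the loader performs (the standalone function and its name are illustrative): a module is re-executed only when the file on disk is newer than the cached copy stored in `sys.modules`.

```python
import importlib.util
import sys
from pathlib import Path

def load_module_with_cache(module_path: Path, module_name: str):
    cached = sys.modules.get(module_name)
    mtime = module_path.stat().st_mtime_ns
    if cached is not None and getattr(cached, "__mtime__", None) == mtime:
        return cached  # file unchanged on disk: reuse the cached module
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot load module from {module_path}")
    module = importlib.util.module_from_spec(spec)
    module.__mtime__ = mtime  # type: ignore  # remember timestamp for future checks
    spec.loader.exec_module(module)
    sys.modules[module_name] = module
    return module
```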
......@@ -32,9 +32,9 @@ class TaskFactory:
registry: Mapping[str, Entry],
):
"""
• entry.kind == TASK / PY_TASK ➜ returns instantiated task object
• entry.kind == GROUP ➜ returns (GroupConfig, mapping-of-subtasks)
• entry.kind == TAG ➜ returns mapping-of-tasks (tag expansion)
• entry.kind == TASK / PY_TASK ➜ returns instantiated task object
• entry.kind == GROUP ➜ returns (GroupConfig, mapping-of-subtasks)
• entry.kind == TAG ➜ returns mapping-of-tasks (tag expansion)
"""
if entry.kind is Kind.TAG:
return self._build_tag(entry, overrides, registry)
......@@ -121,4 +121,4 @@ class TaskFactory:
def _ctor_accepts_config(cls) -> bool:
init = getattr(cls, "__init__", None)
return init and "config" in inspect.signature(init).parameters
return bool(init and "config" in inspect.signature(init).parameters)
......@@ -61,8 +61,8 @@ class TaskManager:
def load_spec(self, spec: str | dict[str, Any]):
"""Spec can be:
• str task / group / tag name (registered)
• dict inline overrides {'task': 'hellaswag', 'num_fewshot': 5}
• str task / group / tag name (registered)
• dict inline overrides {'task': 'hellaswag', 'num_fewshot': 5}
"""
if isinstance(spec, str):
entry = self._entry(spec)
......
......@@ -8,8 +8,8 @@ Homepage: [google-research-datasets/natural-questions@master/nq_open](https://gi
Paper: [aclanthology.org/P19-1612](https://aclanthology.org/P19-1612/)
Derived from the Natural Questions dataset, introduced in https://storage.googleapis.com/gweb-research2023-media/pubtools/pdf/1f7b46b5378d757553d3e92ead36bda2e4254244.pdf .
Derived from the Natural Questions dataset, introduced
in https://storage.googleapis.com/gweb-research2023-media/pubtools/pdf/1f7b46b5378d757553d3e92ead36bda2e4254244.pdf .
### Citation
......@@ -26,4 +26,5 @@ journal = {Transactions of the Association of Computational Linguistics}}
* `nq_open`
### Changelog
* 2025-07-21: Added `multi_target` to `exact_match`. Scores should not change.
* 2025-07-21: Added `multiple_targets` to `exact_match`. Scores should not change.
......@@ -5,7 +5,7 @@ training_split: train
validation_split: validation
description: "Answer these questions:\n\n"
doc_to_text: "Q: {{question}}?\nA:"
doc_to_target: "{{answer}}"
doc_to_target: answer
fewshot_delimiter: "\n"
generation_kwargs:
until:
......@@ -27,7 +27,7 @@ metric_list:
ignore_case: true
ignore_punctuation: true
regexes_to_ignore:
- "\\b(?:The |the |An |A |The |a |an )"
multi_target: true
- "\\b(?:The |the |An |A |The |a |an )"
multiple_targets: true
metadata:
version: 4.0
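The `regexes_to_ignore` entry above strips leading articles from both predictions and references before comparison; a quick sketch of what that normalization does:

```python
import re

# Same pattern as the YAML entry (the double backslash is YAML escaping).
pattern = r"\b(?:The |the |An |A |The |a |an )"
print(re.sub(pattern, "", "The Eiffel Tower"))  # "Eiffel Tower"
print(re.sub(pattern, "", "an apple a day"))    # "apple day"
```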
......@@ -79,3 +79,6 @@ If other tasks on this dataset are already supported:
* [ ] Is the "Main" variant of this task clearly denoted?
* [ ] Have you provided a short sentence in a README on what each new variant adds / evaluates?
* [ ] Have you noted which, if any, published evaluation setups are matched by this variant?
### Changelog
- 2025-07-22: `record` and `multirc`: set target_delimiter to "" and trim doc_to_text respectively.
tag:
- super-glue-lm-eval-v1
task: boolq
dataset_path: super_glue
dataset_path: aps/super_glue
dataset_name: boolq
output_type: multiple_choice
training_split: train
......
tag:
- super-glue-lm-eval-v1-seq2seq
task: "boolq-seq2seq"
dataset_path: super_glue
dataset_path: aps/super_glue
dataset_name: boolq
output_type: generate_until
training_split: train
validation_split: validation
doc_to_text: "{{passage}}\nQuestion: {{question}}?\nAnswer:"
doc_to_target: label
doc_to_choice: [' no', ' yes']
doc_to_target: "{{ [' no', ' yes'][label|int] }}"
target_delimiter: ""
generation_kwargs:
until:
......
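The new `doc_to_target` folds the removed `doc_to_choice` list into a single Jinja expression that maps the integer label to its answer string. A minimal rendering check, assuming the `jinja2` package:

```python
from jinja2 import Template

template = Template("{{ [' no', ' yes'][label|int] }}")
print(repr(template.render(label=0)))  # "' no'"
print(repr(template.render(label=1)))  # "' yes'"
```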
tag:
- super-glue-t5-prompt
task: super_glue-boolq-t5-prompt
dataset_path: super_glue
dataset_path: aps/super_glue
dataset_name: boolq
training_split: train
validation_split: validation
output_type: generate_until
doc_to_text: "boolq passage: {{passage}} question: {{question}}"
doc_to_target: label
doc_to_choice: ['False', 'True']
doc_to_target: "{{['False', 'True'][label|int]}}"
generation_kwargs:
until:
- "</s>"
......
tag:
- super-glue-lm-eval-v1
task: cb
dataset_path: super_glue
dataset_path: aps/super_glue
dataset_name: cb
output_type: multiple_choice
training_split: train
validation_split: validation
doc_to_text: "{{premise}}\nQuestion: {{hypothesis}}. True, False, or Neither?\nAnswer:"
doc_to_target: label
doc_to_choice: ['True', 'False', 'Neither']
doc_to_choice: ["True", "False", "Neither"]
metric_list:
- metric: acc
- metric: f1
......
tag:
- super-glue-t5-prompt
task: super_glue-cb-t5-prompt
dataset_path: super_glue
dataset_path: aps/super_glue
dataset_name: cb
training_split: train
validation_split: validation
output_type: generate_until
doc_to_text: "cb hypothesis: {{hypothesis}} premise: {{premise}}"
doc_to_target: label
doc_to_choice: ['entailment', 'contradiction', 'neutral']
doc_to_target: "{{ ['entailment', 'contradiction', 'neutral'][label|int] }}"
generation_kwargs:
until:
- "</s>"
......