Commit d9876b22 authored by Baber

Add `check_gold_index_error` util; fix `process_results`; remove `generate_until` multiple-choice handling

parent d19bd889
from __future__ import annotations
import logging
import warnings
from collections.abc import Iterable, Sequence
from functools import partial
from typing import TYPE_CHECKING, Iterable, Optional, Sequence, Union
from typing import TYPE_CHECKING, Any
import datasets
......@@ -18,9 +21,9 @@ class ContextSampler:
def __init__(
self,
docs: list[dict],
task: Union["Task", "ConfigurableTask"],
fewshot_indices: Optional[Iterable] = None,
rnd: Optional["Random"] = None,
task: Task | ConfigurableTask,
fewshot_indices: Iterable | None = None,
rnd: Random | None = None,
) -> None:
self.rnd = rnd
if not self.rnd:
......@@ -75,7 +78,7 @@ class ContextSampler:
)
self.docs = self.docs.select(fewshot_indices)
def get_context(self, doc: dict, num_fewshot: int, gen_prefix: str = None):
def get_context(self, doc: dict, num_fewshot: int, gen_prefix: str | None = None):
# draw an extra fewshot sample if using same split as evaluating on
prefix = gen_prefix + " " if gen_prefix else ""
n_samples = (
......@@ -95,10 +98,13 @@ class ContextSampler:
for doc in selected_docs:
doc_content = self.doc_to_text(doc)
doc_target = self.doc_to_target(doc)
if self.config.doc_to_choice is None or isinstance(doc_content, str):
if (
self.config.doc_to_choice is None and isinstance(doc_content, str)
) or isinstance(doc_content, str):
labeled_examples += doc_content
else:
labeled_examples += self.doc_to_choice(doc)[doc_content]
if isinstance(doc_content, int):
labeled_examples += self.doc_to_choice(doc)[doc_content]
if doc_target != "":
if self.target_delimiter.isspace() and str(doc_target)[0].isspace():
......@@ -126,7 +132,7 @@ class ContextSampler:
doc: dict,
num_fewshot: int,
fewshot_as_multiturn: bool = False,
gen_prefix: Optional[str] = None,
gen_prefix: str | None = None,
):
# TODO: Do we need any other delimiter
prefix = gen_prefix + " " if gen_prefix else ""
......@@ -181,16 +187,22 @@ class ContextSampler:
return chat_history
# @classmethod
# def from_fewshot_cfg(cls, cfg: FewshotConfig):
# if not
def sample(self, n: int) -> Sequence[dict]:
"""
Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
"""
assert self.rnd is not None, (
"Error: `rnd` must be set to a random.Random instance before sampling."
)
return self.rnd.sample(self.docs, n)
class FirstNSampler(ContextSampler):
def sample(self, n: int) -> Sequence[dict]:
def sample(self, n: int) -> Sequence[dict[str, Any]]:
"""
Draw the first `n` samples in order from the specified split.
Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
......@@ -202,22 +214,22 @@ class FirstNSampler(ContextSampler):
class BalancedSampler(ContextSampler):
def sample(self, n: int) -> None:
def sample(self, n: int):
"""
TODO: this should return approximately class-balanced samples from our fewshot examples.
TODO: what order should they be in? maybe random?
"""
pass
raise NotImplementedError
class ManualSampler(ContextSampler):
def sample(self, n: int) -> None:
def sample(self, n: int):
""" """
pass
raise NotImplementedError
SAMPLER_REGISTRY = {
SAMPLER_REGISTRY: dict[str, type[ContextSampler]] = {
"default": ContextSampler,
"first_n": FirstNSampler,
}
......@@ -226,7 +238,7 @@ SAMPLER_REGISTRY = {
def get_sampler(name: str):
try:
return SAMPLER_REGISTRY[name]
except KeyError:
raise ValueError(
except KeyError as e:
raise KeyError(
f"Attempted to use contextsampler '{name}', but no sampling strategy for this name found! Supported model names: {', '.join(SAMPLER_REGISTRY.keys())}"
)
) from e
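# Illustrative usage (not part of this commit): resolve a sampler class by name and
# build few-shot context. `my_docs` and `my_task` are hypothetical placeholders.
#
#     import random
#     sampler_cls = get_sampler("first_n")              # -> FirstNSampler
#     sampler = sampler_cls(docs=my_docs, task=my_task, rnd=random.Random(1234))
#     fewshot_docs = sampler.sample(5)                   # first 5 docs, in order
#
# Note that an unknown name now raises KeyError (chained from the registry lookup)
# rather than ValueError.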
......@@ -21,6 +21,7 @@ from typing_extensions import deprecated
from lm_eval import utils
from lm_eval.api.instance import Instance, OutputType
from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
from lm_eval.api.utils import check_gold_index_error
from lm_eval.caching.cache import load_from_cache, save_to_cache
from lm_eval.config.metric import MetricConfig
from lm_eval.config.task import TaskConfig
......@@ -380,7 +381,7 @@ class Task(abc.ABC):
pass
@abc.abstractmethod
def process_results(self, doc: dict, results: list):
def process_results(self, doc: dict, results: list) -> dict[str, Any]:
"""Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of
the metric for that one document
......@@ -390,7 +391,7 @@ class Task(abc.ABC):
:param results:
The results of the requests created in construct_requests.
"""
pass
raise NotImplementedError
@deprecated("not used anymore")
def aggregation(self):
......@@ -955,11 +956,13 @@ class ConfigurableTask(Task):
def apply_filters(self) -> list[Instance] | None:
"""Iterates over FilterEnsembles and applies them to instances"""
if hasattr(self, "_filters"):
if hasattr(self, "_filters") and self._instances:
for f in self._filters:
f.ensemble.apply(self._instances)
else:
eval_logger.warning("No filter defined, passing through instances")
eval_logger.warning(
"No filters defined or no instances found; passing instances through unchanged"
)
return self._instances
def should_decontaminate(self):
......@@ -993,13 +996,12 @@ class ConfigurableTask(Task):
"""
return doc
def doc_to_text(self, doc: dict, doc_to_text: int | str | Callable | None = None):
def doc_to_text(
self, doc: dict, doc_to_text: int | str | Callable[..., str] | None = None
) -> str:
# if self.prompt is not None:
# doc_to_text = self.prompt
if doc_to_text is not None:
doc_to_text = doc_to_text
else:
doc_to_text = self.config.doc_to_text
doc_to_text = doc_to_text if doc_to_text is not None else self.config.doc_to_text
if isinstance(doc_to_text, int):
return doc_to_text
......@@ -1261,7 +1263,7 @@ class ConfigurableTask(Task):
**kwargs,
)
def process_results(self, doc: dict, results: list) -> dict:
def process_results(self, doc: dict, results: list) -> dict[str, Any]:
if callable(self.config.process_results):
return self.config.process_results(doc, results)
......@@ -1275,9 +1277,12 @@ class ConfigurableTask(Task):
**({"acc": int(is_greedy)} if "acc" in use_metric else {}),
}
elif self.OUTPUT_TYPE == "loglikelihood_rolling":
(loglikelihood,) = results
_words = self.count_words(self.doc_to_target(doc))
_bytes = self.count_bytes(self.doc_to_target(doc))
(loglikelihood, *_) = results
assert isinstance(_target := self.doc_to_target(doc), str), (
"Require target to be a string for loglikelihood_rolling"
)
_words = self.count_words(_target)
_bytes = self.count_bytes(_target)
return {
**(
{"word_perplexity": (loglikelihood, _words)}
......@@ -1322,19 +1327,7 @@ class ConfigurableTask(Task):
else:
gold = self.doc_to_target(doc)
gold_index_error = False
if isinstance(gold, list):
gold = [i if i < len(choices) else -100 for i in gold]
if -100 in gold:
gold_index_error = True
else:
if isinstance(gold, int):
gold = gold if gold < len(choices) else -100
elif isinstance(gold, str):
gold = choices.index(gold) if gold in choices else -100
if gold == -100:
gold_index_error = True
gold, gold_index_error = check_gold_index_error(choices, gold)
if gold_index_error:
eval_logger.warning(
......@@ -1382,11 +1375,6 @@ class ConfigurableTask(Task):
elif self.OUTPUT_TYPE == "generate_until":
gold = self.doc_to_target(doc)
result = results[0]
if self.config.doc_to_choice is not None:
# If you set doc_to_choice,
# it assumes that doc_to_target returns a number.
choices = self.doc_to_choice(doc)
gold = choices[gold]
for metric in self._metric_fn_list:
try:
result_score = self._metric_fn_list[metric](
......
from __future__ import annotations
def check_gold_index_error(
choices: list[int] | list[str], gold: list[int] | int | str
) -> tuple[int | list[int], bool]:
"""Map a gold target onto `choices`, flagging out-of-range or missing targets.
Integer targets >= len(choices) and string targets not present in `choices` are
replaced with -100; the returned bool is True if any target was invalid.
"""
gold_index_error = False
if isinstance(gold, list):
gold = [i if i < len(choices) else -100 for i in gold]
if -100 in gold:
gold_index_error = True
return gold, gold_index_error
else:
if isinstance(gold, int):
gold = gold if gold < len(choices) else -100
elif isinstance(gold, str):
gold = choices.index(gold) if gold in choices else -100
if gold == -100:
gold_index_error = True
return gold, gold_index_error
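# Illustrative behaviour (not part of this commit), mirroring how `process_results`
# calls the helper for multiple_choice targets; the values below are made up.
#
#     choices = ["yes", "no", "maybe"]
#     check_gold_index_error(choices, 1)         # -> (1, False)
#     check_gold_index_error(choices, "no")      # -> (1, False)
#     check_gold_index_error(choices, 7)         # -> (-100, True)   out-of-range index
#     check_gold_index_error(choices, [0, 5])    # -> ([0, -100], True)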
......@@ -3,7 +3,7 @@ from __future__ import annotations
import logging
from collections.abc import Iterable
from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Callable
from typing import TYPE_CHECKING, Any, Callable
from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.instance import OutputType
......@@ -45,7 +45,9 @@ class FewshotConfig:
split: str | None = None
sampler: str | Callable = "default"
samples: Callable[[], list[dict]] | list[dict] | None = None
process_docs: Callable[[list[dict]], Iterable[dict]] | None = None
process_docs: Callable[[list[dict[str, Any]]], Iterable[dict[str, Any]]] | None = (
None
)
fewshot_indices: list[int] | None = None
rnd: int = field(init=False, default=False)
......@@ -82,7 +84,7 @@ class FewshotConfig:
"samples must be either a list of dicts or a callable returning a list"
)
def get_docs(self, dataset) -> Iterable[dict] | None:
def get_docs(self, dataset) -> Iterable[dict[str, Any]] | None:
"""Get processed documents from configured source."""
raw_docs = self._get_raw_docs(dataset)
if raw_docs is None:
......@@ -93,7 +95,7 @@ class FewshotConfig:
return raw_docs
@property
def get_sampler(self):
def get_sampler(self) -> Callable[..., Any] | None:
from lm_eval.api import samplers
if isinstance(self.sampler, str):
......