"tests/implicitron/vscode:/vscode.git/clone" did not exist on "69c6d06ed880ff83419a960aaa20de0e5753f9a6"
Commit d9876b22 authored by Baber's avatar Baber
Browse files

Add `check_gold_index_error` utility; fix `process_results`; remove `generate_until` multiple-choice handling

parent d19bd889
from __future__ import annotations
import logging import logging
import warnings import warnings
from collections.abc import Iterable, Sequence
from functools import partial from functools import partial
from typing import TYPE_CHECKING, Iterable, Optional, Sequence, Union from typing import TYPE_CHECKING, Any
import datasets import datasets
...@@ -18,9 +21,9 @@ class ContextSampler: ...@@ -18,9 +21,9 @@ class ContextSampler:
def __init__( def __init__(
self, self,
docs: list[dict], docs: list[dict],
task: Union["Task", "ConfigurableTask"], task: Task | ConfigurableTask,
fewshot_indices: Optional[Iterable] = None, fewshot_indices: Iterable | None = None,
rnd: Optional["Random"] = None, rnd: Random | None = None,
) -> None: ) -> None:
self.rnd = rnd self.rnd = rnd
if not self.rnd: if not self.rnd:
...@@ -75,7 +78,7 @@ class ContextSampler: ...@@ -75,7 +78,7 @@ class ContextSampler:
) )
self.docs = self.docs.select(fewshot_indices) self.docs = self.docs.select(fewshot_indices)
def get_context(self, doc: dict, num_fewshot: int, gen_prefix: str = None): def get_context(self, doc: dict, num_fewshot: int, gen_prefix: str | None = None):
# draw an extra fewshot sample if using same split as evaluating on # draw an extra fewshot sample if using same split as evaluating on
prefix = gen_prefix + " " if gen_prefix else "" prefix = gen_prefix + " " if gen_prefix else ""
n_samples = ( n_samples = (
...@@ -95,10 +98,13 @@ class ContextSampler: ...@@ -95,10 +98,13 @@ class ContextSampler:
for doc in selected_docs: for doc in selected_docs:
doc_content = self.doc_to_text(doc) doc_content = self.doc_to_text(doc)
doc_target = self.doc_to_target(doc) doc_target = self.doc_to_target(doc)
if self.config.doc_to_choice is None or isinstance(doc_content, str): if (
self.config.doc_to_choice is None and isinstance(doc_content, str)
) or isinstance(doc_content, str):
labeled_examples += doc_content labeled_examples += doc_content
else: else:
labeled_examples += self.doc_to_choice(doc)[doc_content] if isinstance(doc_content, int):
labeled_examples += self.doc_to_choice(doc)[doc_content]
if doc_target != "": if doc_target != "":
if self.target_delimiter.isspace() and str(doc_target)[0].isspace(): if self.target_delimiter.isspace() and str(doc_target)[0].isspace():
...@@ -126,7 +132,7 @@ class ContextSampler: ...@@ -126,7 +132,7 @@ class ContextSampler:
doc: dict, doc: dict,
num_fewshot: int, num_fewshot: int,
fewshot_as_multiturn: bool = False, fewshot_as_multiturn: bool = False,
gen_prefix: Optional[str] = None, gen_prefix: str | None = None,
): ):
# TODO: Do we need any other delimiter # TODO: Do we need any other delimiter
prefix = gen_prefix + " " if gen_prefix else "" prefix = gen_prefix + " " if gen_prefix else ""
...@@ -181,16 +187,22 @@ class ContextSampler: ...@@ -181,16 +187,22 @@ class ContextSampler:
return chat_history return chat_history
# @classmethod
# def from_fewshot_dfg(cls, cfg: FewshotConfig):
# if not
def sample(self, n: int) -> Sequence[dict]: def sample(self, n: int) -> Sequence[dict]:
""" """
Draw `n` samples from our fewshot docs. This method should be overridden by subclasses. Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.
""" """
assert self.rnd is not None, (
"Error: `rnd` must be set to a random.Random instance before sampling."
)
return self.rnd.sample(self.docs, n) return self.rnd.sample(self.docs, n)
class FirstNSampler(ContextSampler): class FirstNSampler(ContextSampler):
def sample(self, n: int) -> Sequence[dict]: def sample(self, n: int) -> Sequence[dict[str, Any]]:
""" """
Draw the first `n` samples in order from the specified split. Draw the first `n` samples in order from the specified split.
Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU. Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
...@@ -202,22 +214,22 @@ class FirstNSampler(ContextSampler): ...@@ -202,22 +214,22 @@ class FirstNSampler(ContextSampler):
class BalancedSampler(ContextSampler): class BalancedSampler(ContextSampler):
def sample(self, n: int) -> None: def sample(self, n: int):
""" """
TODO: this should return approximately class-balanced samples from our fewshot examples. TODO: this should return approximately class-balanced samples from our fewshot examples.
TODO: what order should they be in? maybe random? TODO: what order should they be in? maybe random?
""" """
pass raise NotImplementedError
class ManualSampler(ContextSampler): class ManualSampler(ContextSampler):
def sample(self, n: int) -> None: def sample(self, n: int):
""" """ """ """
pass raise NotImplementedError
SAMPLER_REGISTRY = { SAMPLER_REGISTRY: dict[str, type[ContextSampler]] = {
"default": ContextSampler, "default": ContextSampler,
"first_n": FirstNSampler, "first_n": FirstNSampler,
} }
...@@ -226,7 +238,7 @@ SAMPLER_REGISTRY = { ...@@ -226,7 +238,7 @@ SAMPLER_REGISTRY = {
def get_sampler(name: str): def get_sampler(name: str):
try: try:
return SAMPLER_REGISTRY[name] return SAMPLER_REGISTRY[name]
except KeyError: except KeyError as e:
raise ValueError( raise KeyError(
f"Attempted to use contextsampler '{name}', but no sampling strategy for this name found! Supported model names: {', '.join(SAMPLER_REGISTRY.keys())}" f"Attempted to use contextsampler '{name}', but no sampling strategy for this name found! Supported model names: {', '.join(SAMPLER_REGISTRY.keys())}"
) ) from e
...@@ -21,6 +21,7 @@ from typing_extensions import deprecated ...@@ -21,6 +21,7 @@ from typing_extensions import deprecated
from lm_eval import utils from lm_eval import utils
from lm_eval.api.instance import Instance, OutputType from lm_eval.api.instance import Instance, OutputType
from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
from lm_eval.api.utils import check_gold_index_error
from lm_eval.caching.cache import load_from_cache, save_to_cache from lm_eval.caching.cache import load_from_cache, save_to_cache
from lm_eval.config.metric import MetricConfig from lm_eval.config.metric import MetricConfig
from lm_eval.config.task import TaskConfig from lm_eval.config.task import TaskConfig
...@@ -380,7 +381,7 @@ class Task(abc.ABC): ...@@ -380,7 +381,7 @@ class Task(abc.ABC):
pass pass
@abc.abstractmethod @abc.abstractmethod
def process_results(self, doc: dict, results: list): def process_results(self, doc: dict, results: list) -> dict[str, Any]:
"""Take a single document and the LM results and evaluates, returning a """Take a single document and the LM results and evaluates, returning a
dict where keys are the names of submetrics and values are the values of dict where keys are the names of submetrics and values are the values of
the metric for that one document the metric for that one document
...@@ -390,7 +391,7 @@ class Task(abc.ABC): ...@@ -390,7 +391,7 @@ class Task(abc.ABC):
:param results: :param results:
The results of the requests created in construct_requests. The results of the requests created in construct_requests.
""" """
pass raise NotImplementedError
@deprecated("not used anymore") @deprecated("not used anymore")
def aggregation(self): def aggregation(self):
...@@ -955,11 +956,13 @@ class ConfigurableTask(Task): ...@@ -955,11 +956,13 @@ class ConfigurableTask(Task):
def apply_filters(self) -> list[Instance] | None: def apply_filters(self) -> list[Instance] | None:
"""Iterates over FilterEnsembles and applies them to instances""" """Iterates over FilterEnsembles and applies them to instances"""
if hasattr(self, "_filters"): if hasattr(self, "_filters") and self._instances:
for f in self._filters: for f in self._filters:
f.ensemble.apply(self._instances) f.ensemble.apply(self._instances)
else: else:
eval_logger.warning("No filter defined, passing through instances") eval_logger.warning(
"No filter defined or instances found. Passing through instances"
)
return self._instances return self._instances
def should_decontaminate(self): def should_decontaminate(self):
...@@ -993,13 +996,12 @@ class ConfigurableTask(Task): ...@@ -993,13 +996,12 @@ class ConfigurableTask(Task):
""" """
return doc return doc
def doc_to_text(self, doc: dict, doc_to_text: int | str | Callable | None = None): def doc_to_text(
self, doc: dict, doc_to_text: int | str | Callable[..., str] | None = None
) -> str:
# if self.prompt is not None: # if self.prompt is not None:
# doc_to_text = self.prompt # doc_to_text = self.prompt
if doc_to_text is not None: doc_to_text = doc_to_text or self.config.doc_to_text
doc_to_text = doc_to_text
else:
doc_to_text = self.config.doc_to_text
if isinstance(doc_to_text, int): if isinstance(doc_to_text, int):
return doc_to_text return doc_to_text
...@@ -1261,7 +1263,7 @@ class ConfigurableTask(Task): ...@@ -1261,7 +1263,7 @@ class ConfigurableTask(Task):
**kwargs, **kwargs,
) )
def process_results(self, doc: dict, results: list) -> dict: def process_results(self, doc: dict, results: list) -> dict[str, Any]:
if callable(self.config.process_results): if callable(self.config.process_results):
return self.config.process_results(doc, results) return self.config.process_results(doc, results)
...@@ -1275,9 +1277,12 @@ class ConfigurableTask(Task): ...@@ -1275,9 +1277,12 @@ class ConfigurableTask(Task):
**({"acc": int(is_greedy)} if "acc" in use_metric else {}), **({"acc": int(is_greedy)} if "acc" in use_metric else {}),
} }
elif self.OUTPUT_TYPE == "loglikelihood_rolling": elif self.OUTPUT_TYPE == "loglikelihood_rolling":
(loglikelihood,) = results (loglikelihood, *_) = results
_words = self.count_words(self.doc_to_target(doc)) assert isinstance(_target := self.doc_to_target(doc), str), (
_bytes = self.count_bytes(self.doc_to_target(doc)) "Require target to be a string for loglikelihood_rolling"
)
_words = self.count_words(_target)
_bytes = self.count_bytes(_target)
return { return {
**( **(
{"word_perplexity": (loglikelihood, _words)} {"word_perplexity": (loglikelihood, _words)}
...@@ -1322,19 +1327,7 @@ class ConfigurableTask(Task): ...@@ -1322,19 +1327,7 @@ class ConfigurableTask(Task):
else: else:
gold = self.doc_to_target(doc) gold = self.doc_to_target(doc)
gold_index_error = False gold, gold_index_error = check_gold_index_error(choices, gold)
if isinstance(gold, list):
gold = [i if i < len(choices) else -100 for i in gold]
if -100 in gold:
gold_index_error = True
else:
if isinstance(gold, int):
gold = gold if gold < len(choices) else -100
elif isinstance(gold, str):
gold = choices.index(gold) if gold in choices else -100
if gold == -100:
gold_index_error = True
if gold_index_error: if gold_index_error:
eval_logger.warning( eval_logger.warning(
...@@ -1382,11 +1375,6 @@ class ConfigurableTask(Task): ...@@ -1382,11 +1375,6 @@ class ConfigurableTask(Task):
elif self.OUTPUT_TYPE == "generate_until": elif self.OUTPUT_TYPE == "generate_until":
gold = self.doc_to_target(doc) gold = self.doc_to_target(doc)
result = results[0] result = results[0]
if self.config.doc_to_choice is not None:
# If you set doc_to_choice,
# it assumes that doc_to_target returns a number.
choices = self.doc_to_choice(doc)
gold = choices[gold]
for metric in self._metric_fn_list: for metric in self._metric_fn_list:
try: try:
result_score = self._metric_fn_list[metric]( result_score = self._metric_fn_list[metric](
......
from __future__ import annotations
def check_gold_index_error(
    choices: list[int] | list[str], gold: list[int] | int | str
) -> tuple[int | list[int], bool]:
    """Resolve `gold` to valid indices into `choices`, flagging unresolvable targets.

    Any gold entry that cannot be mapped to a position within `choices`
    (an integer index >= len(choices), or a string not present in
    `choices`) is replaced with the sentinel value -100.

    Returns the (possibly rewritten) gold value together with a flag that
    is True when at least one entry was replaced by the -100 sentinel.
    """
    n_choices = len(choices)

    if isinstance(gold, list):
        # Clamp every out-of-range index to the -100 sentinel in one pass.
        mapped = [idx if idx < n_choices else -100 for idx in gold]
        return mapped, -100 in mapped

    if isinstance(gold, int):
        gold = -100 if gold >= n_choices else gold
    elif isinstance(gold, str):
        # Resolve a string target to its position in choices, if present.
        gold = choices.index(gold) if gold in choices else -100

    return gold, gold == -100
...@@ -3,7 +3,7 @@ from __future__ import annotations ...@@ -3,7 +3,7 @@ from __future__ import annotations
import logging import logging
from collections.abc import Iterable from collections.abc import Iterable
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Callable from typing import TYPE_CHECKING, Any, Callable
from lm_eval.api.filter import FilterEnsemble from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.instance import OutputType from lm_eval.api.instance import OutputType
...@@ -45,7 +45,9 @@ class FewshotConfig: ...@@ -45,7 +45,9 @@ class FewshotConfig:
split: str | None = None split: str | None = None
sampler: str | Callable = "default" sampler: str | Callable = "default"
samples: Callable[[], list[dict]] | list[dict] | None = None samples: Callable[[], list[dict]] | list[dict] | None = None
process_docs: Callable[[list[dict]], Iterable[dict]] | None = None process_docs: Callable[[list[dict[str, Any]]], Iterable[dict[str, Any]]] | None = (
None
)
fewshot_indices: list[int] | None = None fewshot_indices: list[int] | None = None
rnd: int = field(init=False, default=False) rnd: int = field(init=False, default=False)
...@@ -82,7 +84,7 @@ class FewshotConfig: ...@@ -82,7 +84,7 @@ class FewshotConfig:
"samples must be either a list of dicts or a callable returning a list" "samples must be either a list of dicts or a callable returning a list"
) )
def get_docs(self, dataset) -> Iterable[dict] | None: def get_docs(self, dataset) -> Iterable[dict[str, Any]] | None:
"""Get processed documents from configured source.""" """Get processed documents from configured source."""
raw_docs = self._get_raw_docs(dataset) raw_docs = self._get_raw_docs(dataset)
if raw_docs is None: if raw_docs is None:
...@@ -93,7 +95,7 @@ class FewshotConfig: ...@@ -93,7 +95,7 @@ class FewshotConfig:
return raw_docs return raw_docs
@property @property
def get_sampler(self): def get_sampler(self) -> Callable[..., Any] | None:
from lm_eval.api import samplers from lm_eval.api import samplers
if isinstance(self.sampler, str): if isinstance(self.sampler, str):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment