Commit 00048838 authored by Baber
Browse files

nit

parent 55be51ea
from __future__ import annotations
import logging import logging
import warnings import warnings
from collections.abc import Iterable, Sequence
from functools import partial from functools import partial
from typing import TYPE_CHECKING, Iterable, Optional, Sequence, Union from typing import TYPE_CHECKING, Any
import datasets import datasets
...@@ -18,9 +21,9 @@ class ContextSampler: ...@@ -18,9 +21,9 @@ class ContextSampler:
def __init__( def __init__(
self, self,
docs: list[dict], docs: list[dict],
task: Union["Task", "ConfigurableTask"], task: Task | ConfigurableTask,
fewshot_indices: Optional[Iterable] = None, fewshot_indices: Iterable | None = None,
rnd: Optional["Random"] = None, rnd: Random | None = None,
) -> None: ) -> None:
self.rnd = rnd self.rnd = rnd
if not self.rnd: if not self.rnd:
...@@ -75,7 +78,7 @@ class ContextSampler: ...@@ -75,7 +78,7 @@ class ContextSampler:
) )
self.docs = self.docs.select(fewshot_indices) self.docs = self.docs.select(fewshot_indices)
def get_context(self, doc: dict, num_fewshot: int, gen_prefix: str = None): def get_context(self, doc: dict, num_fewshot: int, gen_prefix: str | None = None):
# draw an extra fewshot sample if using same split as evaluating on # draw an extra fewshot sample if using same split as evaluating on
prefix = gen_prefix + " " if gen_prefix else "" prefix = gen_prefix + " " if gen_prefix else ""
n_samples = ( n_samples = (
...@@ -95,9 +98,12 @@ class ContextSampler: ...@@ -95,9 +98,12 @@ class ContextSampler:
for doc in selected_docs: for doc in selected_docs:
doc_content = self.doc_to_text(doc) doc_content = self.doc_to_text(doc)
doc_target = self.doc_to_target(doc) doc_target = self.doc_to_target(doc)
if self.config.doc_to_choice is None or isinstance(doc_content, str): if (
self.config.doc_to_choice is None and isinstance(doc_content, str)
) or isinstance(doc_content, str):
labeled_examples += doc_content labeled_examples += doc_content
else: else:
if isinstance(doc_content, int):
labeled_examples += self.doc_to_choice(doc)[doc_content] labeled_examples += self.doc_to_choice(doc)[doc_content]
if doc_target != "": if doc_target != "":
...@@ -126,7 +132,7 @@ class ContextSampler: ...@@ -126,7 +132,7 @@ class ContextSampler:
doc: dict, doc: dict,
num_fewshot: int, num_fewshot: int,
fewshot_as_multiturn: bool = False, fewshot_as_multiturn: bool = False,
gen_prefix: Optional[str] = None, gen_prefix: str | None = None,
): ):
# TODO: Do we need any other delimiter # TODO: Do we need any other delimiter
prefix = gen_prefix + " " if gen_prefix else "" prefix = gen_prefix + " " if gen_prefix else ""
...@@ -181,16 +187,22 @@ class ContextSampler: ...@@ -181,16 +187,22 @@ class ContextSampler:
return chat_history return chat_history
# @classmethod
# def from_fewshot_dfg(cls, cfg: FewshotConfig):
# if not
def sample(self, n: int) -> Sequence[dict]:
    """
    Draw `n` samples from our fewshot docs. This method should be overridden by subclasses.

    Args:
        n: Number of documents to draw from `self.docs`.

    Returns:
        The `n` documents chosen by `self.rnd.sample`.

    Raises:
        ValueError: If `self.rnd` has not been set.
    """
    # Raise explicitly instead of asserting: asserts are stripped when Python
    # runs with -O, which would turn this check into an opaque AttributeError.
    if self.rnd is None:
        raise ValueError(
            "Error: `rnd` must be set to a random.Random instance before sampling."
        )
    return self.rnd.sample(self.docs, n)
class FirstNSampler(ContextSampler): class FirstNSampler(ContextSampler):
def sample(self, n: int) -> Sequence[dict]: def sample(self, n: int) -> Sequence[dict[str, Any]]:
""" """
Draw the first `n` samples in order from the specified split. Draw the first `n` samples in order from the specified split.
Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU. Used for tasks with "canonical" ordered fewshot examples, such as MMLU and CMMLU.
...@@ -202,22 +214,22 @@ class FirstNSampler(ContextSampler): ...@@ -202,22 +214,22 @@ class FirstNSampler(ContextSampler):
class BalancedSampler(ContextSampler):
    """Placeholder sampler; `sample` is not implemented yet."""

    def sample(self, n: int):
        """
        TODO: this should return approximately class-balanced samples from our fewshot examples.
        TODO: what order should they be in? maybe random?

        Raises:
            NotImplementedError: Always, until the balanced strategy is written.
        """
        # Fail loudly rather than `pass`, which would hand callers None
        # where a sequence of docs is expected.
        raise NotImplementedError
class ManualSampler(ContextSampler):
    """Placeholder sampler; `sample` is not implemented yet."""

    def sample(self, n: int):
        """
        Not implemented.

        Raises:
            NotImplementedError: Always.
        """
        # Fail loudly rather than `pass`, which would hand callers None
        # where a sequence of docs is expected.
        raise NotImplementedError
# Maps a sampler name (as looked up by `get_sampler`) to the sampler class
# registered under that name.
SAMPLER_REGISTRY: dict[str, type[ContextSampler]] = {
    "default": ContextSampler,
    "first_n": FirstNSampler,
}
...@@ -226,7 +238,7 @@ SAMPLER_REGISTRY = { ...@@ -226,7 +238,7 @@ SAMPLER_REGISTRY = {
def get_sampler(name: str):
    """
    Look up a sampler class by its registered name.

    Args:
        name: Key into `SAMPLER_REGISTRY`.

    Returns:
        The sampler class registered under `name`.

    Raises:
        KeyError: If `name` is not registered. The message lists the
            registered sampler names.
    """
    try:
        return SAMPLER_REGISTRY[name]
    except KeyError as e:
        # The registry holds sampler names, so the message says so
        # (the earlier wording called them "model names").
        raise KeyError(
            f"Attempted to use contextsampler '{name}', but no sampling strategy for this name found! Supported sampler names: {', '.join(SAMPLER_REGISTRY)}"
        ) from e
...@@ -990,13 +990,12 @@ class ConfigurableTask(Task): ...@@ -990,13 +990,12 @@ class ConfigurableTask(Task):
""" """
return doc return doc
def doc_to_text(self, doc: dict, doc_to_text: int | str | Callable | None = None): def doc_to_text(
self, doc: dict, doc_to_text: int | str | Callable[..., str] | None = None
) -> str:
# if self.prompt is not None: # if self.prompt is not None:
# doc_to_text = self.prompt # doc_to_text = self.prompt
if doc_to_text is not None: doc_to_text = doc_to_text or self.config.doc_to_text
doc_to_text = doc_to_text
else:
doc_to_text = self.config.doc_to_text
if isinstance(doc_to_text, int): if isinstance(doc_to_text, int):
return doc_to_text return doc_to_text
......
...@@ -3,7 +3,7 @@ from __future__ import annotations ...@@ -3,7 +3,7 @@ from __future__ import annotations
import logging import logging
from collections.abc import Iterable from collections.abc import Iterable
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING, Callable from typing import TYPE_CHECKING, Any, Callable
from lm_eval.api.filter import FilterEnsemble from lm_eval.api.filter import FilterEnsemble
from lm_eval.api.instance import OutputType from lm_eval.api.instance import OutputType
...@@ -45,7 +45,9 @@ class FewshotConfig: ...@@ -45,7 +45,9 @@ class FewshotConfig:
split: str | None = None split: str | None = None
sampler: str | Callable = "default" sampler: str | Callable = "default"
samples: Callable[[], list[dict]] | list[dict] | None = None samples: Callable[[], list[dict]] | list[dict] | None = None
process_docs: Callable[[list[dict]], Iterable[dict]] | None = None process_docs: Callable[[list[dict[str, Any]]], Iterable[dict[str, Any]]] | None = (
None
)
fewshot_indices: list[int] | None = None fewshot_indices: list[int] | None = None
rnd: int = field(init=False, default=False) rnd: int = field(init=False, default=False)
...@@ -82,7 +84,7 @@ class FewshotConfig: ...@@ -82,7 +84,7 @@ class FewshotConfig:
"samples must be either a list of dicts or a callable returning a list" "samples must be either a list of dicts or a callable returning a list"
) )
def get_docs(self, dataset) -> Iterable[dict] | None: def get_docs(self, dataset) -> Iterable[dict[str, Any]] | None:
"""Get processed documents from configured source.""" """Get processed documents from configured source."""
raw_docs = self._get_raw_docs(dataset) raw_docs = self._get_raw_docs(dataset)
if raw_docs is None: if raw_docs is None:
...@@ -93,7 +95,7 @@ class FewshotConfig: ...@@ -93,7 +95,7 @@ class FewshotConfig:
return raw_docs return raw_docs
@property @property
def get_sampler(self): def get_sampler(self) -> Callable[..., Any] | None:
from lm_eval.api import samplers from lm_eval.api import samplers
if isinstance(self.sampler, str): if isinstance(self.sampler, str):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment