Commit 87445e95 authored by Baber

types

parent 5fdeb436
@@ -7,12 +7,7 @@ import random
 import re
 from collections.abc import Callable, Iterable, Iterator, Mapping
 from copy import deepcopy
-from typing import (
-    TYPE_CHECKING,
-    Any,
-    Literal,
-    overload,
-)
+from typing import TYPE_CHECKING, Any, Literal, overload
 
 import datasets
 import numpy as np
@@ -25,7 +20,7 @@ from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
 from lm_eval.api.utils import check_gold_index_error
 from lm_eval.caching.cache import load_from_cache, save_to_cache
 from lm_eval.config.metric import MetricConfig
-from lm_eval.config.task import TaskConfig
+from lm_eval.config.task import DataSet, TaskConfig
 from lm_eval.filters import build_filter_ensemble
@@ -134,6 +129,7 @@ class Task(abc.ABC):
             - `datasets.DownloadMode.FORCE_REDOWNLOAD`
                 Fresh download and fresh dataset.
         """
+        assert self.DATASET_PATH is not None, "DATASET_PATH must be set in Task class"
         self.dataset = datasets.load_dataset(
             path=self.DATASET_PATH,
             name=self.DATASET_NAME,
@@ -147,43 +143,40 @@ class Task(abc.ABC):
         """Returns the TaskConfig associated with this class."""
         return self._config
 
+    @abc.abstractmethod
     def has_training_docs(self) -> bool:
         """Whether the task has a training set"""
-        pass
+        raise NotImplementedError
 
+    @abc.abstractmethod
     def has_validation_docs(self) -> bool:
         """Whether the task has a validation set"""
-        pass
+        raise NotImplementedError
 
+    @abc.abstractmethod
     def has_test_docs(self) -> bool:
         """Whether the task has a test set"""
-        pass
+        raise NotImplementedError
 
-    def training_docs(self) -> Iterable:
+    def training_docs(self) -> DataSet | None:
         """
         :return: Iterable[obj]
             A iterable of any object, that doc_to_text can handle
         """
         return []
 
-    def validation_docs(self) -> Iterable:
+    def validation_docs(self) -> DataSet | None:
         """
         :return: Iterable[obj]
             A iterable of any object, that doc_to_text can handle
         """
         return []
 
-    def test_docs(self) -> Iterable:
+    def test_docs(self) -> DataSet | None:
        """
         :return: Iterable[obj]
             A iterable of any object, that doc_to_text can handle
         """
         return []
 
-    def fewshot_docs(self) -> Iterable:
+    def fewshot_docs(self) -> DataSet | None:
         """
         :return: Iterable[obj]
             A iterable of any object, that doc_to_text can handle
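
Note: these hunks make the split-presence checks abstract and relax the docs accessors to DataSet | None. A minimal sketch of a concrete subclass under this change (the class name, dataset identifier, and returned split are hypothetical, and the other abstract members of Task are omitted):

```python
class DemoTask(Task):  # hypothetical subclass, for illustration only
    DATASET_PATH = "some-org/some-dataset"  # placeholder; the new assert requires a non-None value
    DATASET_NAME = None

    def has_training_docs(self) -> bool:
        return False

    def has_validation_docs(self) -> bool:
        return False

    def has_test_docs(self) -> bool:
        return True

    def test_docs(self) -> DataSet | None:
        # DataSet covers both datasets.Dataset and plain iterables of dicts
        return self.dataset["test"]
```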
@@ -587,7 +580,7 @@ class ConfigurableTask(Task):
         data_dir=None,
         cache_dir=None,
         download_mode=None,
-        config: dict | None = None,
+        config: Mapping[str, Any] | None = None,
     ) -> None:
         # Get pre-configured attributes
         self._config = self.CONFIG
@@ -722,6 +715,9 @@ class ConfigurableTask(Task):
             )
             self.dataset = df(**(self.config.dataset_kwargs | self.config.metadata))
         else:
+            assert self.config.dataset_path is not None, (
+                "dataset_path must be set in TaskConfig"
+            )
             self.dataset = datasets.load_dataset(
                 path=self.config.dataset_path,
                 name=self.config.dataset_name,
@@ -737,7 +733,7 @@ class ConfigurableTask(Task):
     def has_test_docs(self) -> bool:
         return self.config.test_split is not None
 
-    def training_docs(self) -> datasets.Dataset | None:
+    def training_docs(self) -> DataSet | None:
         if self.has_training_docs():
             if self.config.process_docs is not None:
                 return self.config.process_docs(
@@ -745,7 +741,7 @@ class ConfigurableTask(Task):
             )
             return self.dataset[self.config.training_split]
 
-    def validation_docs(self) -> datasets.Dataset | None:
+    def validation_docs(self) -> DataSet | None:
         if self.has_validation_docs():
             if self.config.process_docs is not None:
                 return self.config.process_docs(
@@ -753,7 +749,7 @@ class ConfigurableTask(Task):
             )
             return self.dataset[self.config.validation_split]
 
-    def test_docs(self) -> datasets.Dataset | None:
+    def test_docs(self) -> DataSet | None:
         if self.has_test_docs():
             if self.config.process_docs is not None:
                 return self.config.process_docs(self.dataset[self.config.test_split])
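
Note: since the accessors now advertise DataSet | None rather than datasets.Dataset | None, downstream code should not assume a Hugging Face Dataset object or a non-None value. A hedged usage sketch (task and handle_doc are stand-ins, not names from this codebase):

```python
docs = task.test_docs()  # task: some ConfigurableTask instance (stand-in)
if docs is None:
    raise ValueError("task has no test split configured")
for doc in docs:  # iteration works for datasets.Dataset and plain iterables of dicts alike
    handle_doc(doc)  # stand-in for whatever consumes a document dict
```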
...
@@ -21,6 +21,7 @@ if TYPE_CHECKING:
 eval_logger = logging.getLogger(__name__)
 
 DataSet = Union[datasets.Dataset, Iterable[dict[str, Any]]]
+DSplits = dict[str, DataSet]
 
 
 @dataclass
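
Note: the new DSplits alias is simply a mapping from split name to DataSet, so dataset objects and plain lists of dicts type-check the same way. A small illustrative sketch (the split contents are invented):

```python
import datasets

splits: DSplits = {
    "train": [{"question": "2 + 2 = ?", "answer": "4"}],  # plain iterable of dicts
    "test": datasets.Dataset.from_dict({"question": ["3 + 3 = ?"], "answer": ["6"]}),
}
```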
@@ -34,7 +35,7 @@ class RepeatConfig:
 
 @dataclass
 class FilterConfig:
-    """Encapsulates information about a single filter."""
+    """Encapsulates information about a single filter pipeline."""
 
     name: str
     ensemble: FilterEnsemble
@@ -71,16 +72,17 @@ class FewshotConfig:
         """Check if any fewshot source is configured."""
         return self.split is not None or self.samples is not None
 
-    def _get_raw_docs(
-        self, dataset
-    ) -> list[dict] | Callable[[], Iterable[dict[str, Any]]] | None:
+    def _get_raw_docs(self, dataset: DSplits) -> DataSet | None:
         """Get raw documents from configured source."""
         if self.split is not None:
             return dataset[self.split]
         if self.samples is not None:
-            if isinstance(self.samples, list) or callable(self.samples):
+            if isinstance(self.samples, list):
                 return self.samples
+            elif callable(self.samples):
+                # If samples is a callable, it should return a list of dicts
+                return self.samples()
             else:
                 raise TypeError(
                     "samples must be either a list of dicts or a callable returning a list"
@@ -158,7 +160,7 @@ class TaskConfig:
     fewshot_delimiter: str = "\n\n"
     fewshot_config: dict[str, Any] | None = None
     # runtime configuration options
-    num_fewshot: int | None = 0
+    num_fewshot: int | None = None
     generation_kwargs: dict[str, Any] | None = None
     # scoring options
     metric_list: list | None = None
@@ -359,7 +361,7 @@ class TaskConfig:
         return x
 
     @classmethod
-    def from_yaml(cls, data: dict) -> TaskConfig:
+    def from_yaml(cls, data: dict[str, Any]) -> TaskConfig:
         """Create a TaskConfig instance from a YAML-like dictionary."""
         return cls(**data)
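
Note: from_yaml stays a thin wrapper around the dataclass constructor, so a mapping parsed from YAML feeds straight into TaskConfig(**data). A hedged usage sketch, restricted to fields visible in this diff and with invented values (the dict would typically come from yaml.safe_load):

```python
cfg = TaskConfig.from_yaml(
    {
        "dataset_path": "some-org/some-dataset",  # placeholder dataset identifier
        "test_split": "test",
        "num_fewshot": 5,  # overrides the new default of None
    }
)
assert cfg.num_fewshot == 5
```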
...