Commit 87445e95 authored by Baber
Browse files

types

parent 5fdeb436
......@@ -7,12 +7,7 @@ import random
import re
from collections.abc import Callable, Iterable, Iterator, Mapping
from copy import deepcopy
from typing import (
TYPE_CHECKING,
Any,
Literal,
overload,
)
from typing import TYPE_CHECKING, Any, Literal, overload
import datasets
import numpy as np
......@@ -25,7 +20,7 @@ from lm_eval.api.metrics import bits_per_byte, mean, weighted_perplexity
from lm_eval.api.utils import check_gold_index_error
from lm_eval.caching.cache import load_from_cache, save_to_cache
from lm_eval.config.metric import MetricConfig
from lm_eval.config.task import TaskConfig
from lm_eval.config.task import DataSet, TaskConfig
from lm_eval.filters import build_filter_ensemble
......@@ -134,6 +129,7 @@ class Task(abc.ABC):
- `datasets.DownloadMode.FORCE_REDOWNLOAD`
Fresh download and fresh dataset.
"""
assert self.DATASET_PATH is not None, "DATASET_PATH must be set in Task class"
self.dataset = datasets.load_dataset(
path=self.DATASET_PATH,
name=self.DATASET_NAME,
......@@ -147,43 +143,40 @@ class Task(abc.ABC):
"""Returns the TaskConfig associated with this class."""
return self._config
@abc.abstractmethod
def has_training_docs(self) -> bool:
"""Whether the task has a training set"""
pass
raise NotImplementedError
@abc.abstractmethod
def has_validation_docs(self) -> bool:
"""Whether the task has a validation set"""
pass
raise NotImplementedError
@abc.abstractmethod
def has_test_docs(self) -> bool:
"""Whether the task has a test set"""
pass
raise NotImplementedError
def training_docs(self) -> Iterable:
def training_docs(self) -> DataSet | None:
"""
:return: Iterable[obj]
A iterable of any object, that doc_to_text can handle
"""
return []
def validation_docs(self) -> Iterable:
def validation_docs(self) -> DataSet | None:
"""
:return: Iterable[obj]
A iterable of any object, that doc_to_text can handle
"""
return []
def test_docs(self) -> Iterable:
def test_docs(self) -> DataSet | None:
"""
:return: Iterable[obj]
A iterable of any object, that doc_to_text can handle
"""
return []
def fewshot_docs(self) -> Iterable:
def fewshot_docs(self) -> DataSet | None:
"""
:return: Iterable[obj]
A iterable of any object, that doc_to_text can handle
......@@ -587,7 +580,7 @@ class ConfigurableTask(Task):
data_dir=None,
cache_dir=None,
download_mode=None,
config: dict | None = None,
config: Mapping[str, Any] | None = None,
) -> None:
# Get pre-configured attributes
self._config = self.CONFIG
......@@ -722,6 +715,9 @@ class ConfigurableTask(Task):
)
self.dataset = df(**(self.config.dataset_kwargs | self.config.metadata))
else:
assert self.config.dataset_path is not None, (
"dataset_path must be set in TaskConfig"
)
self.dataset = datasets.load_dataset(
path=self.config.dataset_path,
name=self.config.dataset_name,
......@@ -737,7 +733,7 @@ class ConfigurableTask(Task):
def has_test_docs(self) -> bool:
return self.config.test_split is not None
def training_docs(self) -> datasets.Dataset | None:
def training_docs(self) -> DataSet | None:
if self.has_training_docs():
if self.config.process_docs is not None:
return self.config.process_docs(
......@@ -745,7 +741,7 @@ class ConfigurableTask(Task):
)
return self.dataset[self.config.training_split]
def validation_docs(self) -> datasets.Dataset | None:
def validation_docs(self) -> DataSet | None:
if self.has_validation_docs():
if self.config.process_docs is not None:
return self.config.process_docs(
......@@ -753,7 +749,7 @@ class ConfigurableTask(Task):
)
return self.dataset[self.config.validation_split]
def test_docs(self) -> datasets.Dataset | None:
def test_docs(self) -> DataSet | None:
if self.has_test_docs():
if self.config.process_docs is not None:
return self.config.process_docs(self.dataset[self.config.test_split])
......
......@@ -21,6 +21,7 @@ if TYPE_CHECKING:
eval_logger = logging.getLogger(__name__)
DataSet = Union[datasets.Dataset, Iterable[dict[str, Any]]]
DSplits = dict[str, DataSet]
@dataclass
......@@ -34,7 +35,7 @@ class RepeatConfig:
@dataclass
class FilterConfig:
"""Encapsulates information about a single filter."""
"""Encapsulates information about a single filter pipeline."""
name: str
ensemble: FilterEnsemble
......@@ -71,16 +72,17 @@ class FewshotConfig:
"""Check if any fewshot source is configured."""
return self.split is not None or self.samples is not None
def _get_raw_docs(
self, dataset
) -> list[dict] | Callable[[], Iterable[dict[str, Any]]] | None:
def _get_raw_docs(self, dataset: DSplits) -> DataSet | None:
"""Get raw documents from configured source."""
if self.split is not None:
return dataset[self.split]
if self.samples is not None:
if isinstance(self.samples, list) or callable(self.samples):
if isinstance(self.samples, list):
return self.samples
elif callable(self.samples):
# If samples is a callable, it should return a list of dicts
return self.samples()
else:
raise TypeError(
"samples must be either a list of dicts or a callable returning a list"
......@@ -158,7 +160,7 @@ class TaskConfig:
fewshot_delimiter: str = "\n\n"
fewshot_config: dict[str, Any] | None = None
# runtime configuration options
num_fewshot: int | None = 0
num_fewshot: int | None = None
generation_kwargs: dict[str, Any] | None = None
# scoring options
metric_list: list | None = None
......@@ -359,7 +361,7 @@ class TaskConfig:
return x
@classmethod
def from_yaml(cls, data: dict) -> TaskConfig:
def from_yaml(cls, data: dict[str, Any]) -> TaskConfig:
"""Create a TaskConfig instance from a YAML-like dictionary."""
return cls(**data)
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment