Commit 9b192374 authored by Baber
Browse files

update type hints

parent cb8dfe63
......@@ -63,11 +63,10 @@ class MetricConfig:
aggregation_fn: Optional[Callable] = None
higher_is_better: bool = True
hf_evaluate: bool = False
sample_metric: bool = True
is_elementwise: bool = True
@cached_property
def metric_names(self) -> str:
def metric_name(self) -> str:
return self.name
@cached_property
......@@ -82,6 +81,12 @@ class MetricConfig:
return is_higher_better(self.name)
return self.higher_is_better
def calculate_metric(self, *args, **kwargs) -> Any:
    """Calculate the metric by calling the configured metric function.

    Args:
        *args: Positional arguments forwarded unchanged to ``self.fn``.
        **kwargs: Per-call keyword arguments. They are merged over the
            configured ``self.kwargs``, so a per-call value wins when the
            same key appears in both.

    Returns:
        Whatever ``self.fn`` returns.

    Raises:
        ValueError: If no metric function is configured (``self.fn`` is None).
    """
    if self.fn is None:
        raise ValueError(f"Metric function for {self.name} is not defined.")
    # `self.kwargs` may be unset (None); `**None` would raise TypeError, so
    # fall back to an empty mapping before merging.
    configured_kwargs = self.kwargs or {}
    return self.fn(*args, **{**configured_kwargs, **kwargs})
@dataclass
class RepeatConfig:
......@@ -108,6 +113,16 @@ class FewshotConfig:
process_docs: Optional[Callable] = None
@dataclass
class DatasetConfig:
    """Encapsulates information about a dataset.

    Every field is optional and defaults to ``None``.
    """

    # NOTE(review): the field meanings below are inferred from their names
    # only; confirm against the code that consumes this config.

    # Presumably the location or identifier of the dataset to load.
    dataset_path: Optional[str] = None
    # Presumably the name of a specific dataset configuration/subset.
    dataset_name: Optional[str] = None
    # Presumably extra keyword arguments passed on when loading the dataset.
    dataset_kwargs: Optional[dict] = None
    # Presumably a callable that supplies the dataset instead of a path/name.
    custom_dataset: Optional[Callable] = None
@dataclass
class TaskConfig(dict):
# task naming/registry
......@@ -132,8 +147,8 @@ class TaskConfig(dict):
process_docs: Optional[Callable] = None
doc_to_text: Optional[Union[Callable, str]] = None
doc_to_target: Optional[Union[Callable, str]] = None
doc_to_image: Union[Callable, str] = None
doc_to_audio: Union[Callable, str] = None
doc_to_image: Union[Callable, str, None] = None
doc_to_audio: Union[Callable, str, None] = None
unsafe_code: bool = False
doc_to_choice: Optional[Union[Callable, str, dict, list]] = None
process_results: Optional[Union[Callable, str]] = None
......@@ -466,17 +481,17 @@ class Task(abc.ABC):
return self._config
@abc.abstractmethod
def has_training_docs(self):
def has_training_docs(self) -> bool:
"""Whether the task has a training set"""
pass
@abc.abstractmethod
def has_validation_docs(self):
def has_validation_docs(self) -> bool:
"""Whether the task has a validation set"""
pass
@abc.abstractmethod
def has_test_docs(self):
def has_test_docs(self) -> bool:
"""Whether the task has a test set"""
pass
......@@ -536,7 +551,7 @@ class Task(abc.ABC):
"""
return self._instances
def fewshot_examples(self, k, rnd):
def fewshot_examples(self, k, rnd) -> Iterable[dict]:
if self._training_docs is None:
self._training_docs = list(self.training_docs())
......@@ -548,11 +563,11 @@ class Task(abc.ABC):
)
@abc.abstractmethod
def doc_to_text(self, doc):
def doc_to_text(self, doc) -> str:
pass
@abc.abstractmethod
def doc_to_target(self, doc):
def doc_to_target(self, doc) -> Union[str, int]:
pass
# not an abstractmethod because not every language-only task has to implement this
......@@ -562,7 +577,7 @@ class Task(abc.ABC):
def doc_to_audio(self, doc):
raise NotImplementedError
def doc_to_prefix(self, doc):
def doc_to_prefix(self, doc) -> str:
return ""
def build_all_requests(
......@@ -734,12 +749,12 @@ class Task(abc.ABC):
return getattr(self._config, key, None)
@classmethod
def count_bytes(cls, doc):
def count_bytes(cls, doc) -> int:
"""Used for byte-level perplexity metrics in rolling loglikelihood"""
return len(doc.encode("utf-8"))
@classmethod
def count_words(cls, doc):
def count_words(cls, doc) -> int:
"""Downstream loglikelihood_rolling perplexity tasks with custom word boundaries should override this!"""
return len(re.split(r"\s+", doc))
......@@ -853,7 +868,7 @@ class Task(abc.ABC):
self.sampler.rnd = self.fewshot_rnd
@property
def eval_docs(self) -> Union[datasets.Dataset, List[dict]]:
def eval_docs(self) -> Union[datasets.Dataset, Iterable[dict]]:
if self.has_test_docs():
return self.test_docs()
elif self.has_validation_docs():
......@@ -952,7 +967,7 @@ class ConfigurableTask(Task):
if self.config.dataset_name is not None:
self.DATASET_NAME = self.config.dataset_name
self.metric_list: list[MetricConfig] = self._config.get_metrics()
self.metric_list: list[MetricConfig] = self.config.get_metrics()
self.download(self.config.dataset_kwargs)
self._training_docs = None
......@@ -1088,7 +1103,7 @@ class ConfigurableTask(Task):
else:
return False
def training_docs(self) -> datasets.Dataset:
def training_docs(self) -> Optional[datasets.Dataset]:
if self.has_training_docs():
if self.config.process_docs is not None:
return self.config.process_docs(
......@@ -1096,7 +1111,7 @@ class ConfigurableTask(Task):
)
return self.dataset[self.config.training_split]
def validation_docs(self) -> datasets.Dataset:
def validation_docs(self) -> Optional[datasets.Dataset]:
if self.has_validation_docs():
if self.config.process_docs is not None:
return self.config.process_docs(
......@@ -1104,7 +1119,7 @@ class ConfigurableTask(Task):
)
return self.dataset[self.config.validation_split]
def test_docs(self) -> datasets.Dataset:
def test_docs(self) -> Optional[datasets.Dataset]:
if self.has_test_docs():
if self.config.process_docs is not None:
return self.config.process_docs(self.dataset[self.config.test_split])
......@@ -1170,7 +1185,7 @@ class ConfigurableTask(Task):
fewshot_as_multiturn: bool = False,
chat_template: Optional[Callable] = None,
gen_prefix: Optional[str] = None,
) -> Union[str, List[str]]:
) -> Union[str, List[str], None]:
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
......@@ -1457,7 +1472,7 @@ class ConfigurableTask(Task):
else:
raise TypeError
def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list]:
def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list, None]:
if doc_to_image is not None:
doc_to_image = doc_to_image
elif self.config.doc_to_image is not None:
......@@ -1480,7 +1495,7 @@ class ConfigurableTask(Task):
else:
return None
def doc_to_audio(self, doc: Any, doc_to_audio=None) -> Union[int, str, list]:
def doc_to_audio(self, doc: Any, doc_to_audio=None) -> Union[int, str, list, None]:
if doc_to_audio is not None:
doc_to_audio = doc_to_audio
elif self.config.doc_to_audio is not None:
......@@ -1503,7 +1518,7 @@ class ConfigurableTask(Task):
else:
return None
def doc_to_prefix(self, doc):
def doc_to_prefix(self, doc) -> Optional[str]:
if (gen_prefix := self.config.gen_prefix) is not None:
if gen_prefix in self.features:
return doc[gen_prefix]
......@@ -1550,7 +1565,7 @@ class ConfigurableTask(Task):
arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]
# TODO: we should raise a warning telling users this will at most ~2x runtime.
if "acc_mutual_info" in [m.metric_names for m in self.metric_list]:
if "acc_mutual_info" in [m.metric_name for m in self.metric_list]:
# if we are calculating multiple choice accuracy
# using mutual information instead of raw loglikelihood as metric, need unconditional lls.
......@@ -1617,7 +1632,7 @@ class ConfigurableTask(Task):
return self.config.process_results(doc, results)
result_dict = {}
use_metric = list(m.metric_names for m in self.metric_list)
use_metric = list(m.metric_name for m in self.metric_list)
if self.OUTPUT_TYPE == "loglikelihood":
results = results[0]
ll, is_greedy = results
......@@ -1815,7 +1830,7 @@ class ConfigurableTask(Task):
return getattr(self._config, key, None)
@property
def task_name(self) -> Any:
def task_name(self) -> Optional[str]:
return getattr(self.config, "task", None)
def __repr__(self):
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment