"next_docs/vscode:/vscode.git/clone" did not exist on "fd2f3c58da3f61a80bdd7548721dc8125608940b"
Commit bbf79d44 authored by Baber

update type hints

parent 7f7872c1
@@ -24,36 +24,36 @@ def bypass_agg(arr):
@register_aggregation("nanmean")
def nanmean(arr):
def nanmean(arr: list[float]) -> float:
if len(arr) == 0 or all(np.isnan(arr)):
return np.nan
return np.nanmean(arr)
@register_aggregation("mean")
def mean(arr):
def mean(arr: list[float]) -> float:
return sum(arr) / len(arr)
@register_aggregation("median")
def median(arr):
def median(arr: list[float]) -> float:
return arr[len(arr) // 2]
# Certain metrics must be calculated across all documents in a benchmark.
# We use them as aggregation metrics, paired with no-op passthrough metric fns.
@register_aggregation("perplexity")
def perplexity(items):
def perplexity(items: list[float]) -> float:
return math.exp(-mean(items))
@register_aggregation("weighted_perplexity")
def weighted_perplexity(items):
def weighted_perplexity(items: list[tuple[float, float]]) -> float:
return math.exp(-weighted_mean(items))
@register_aggregation("bits_per_byte")
def bits_per_byte(items):
def bits_per_byte(items: list[tuple[float, float]]) -> float:
return -weighted_mean(items) / math.log(2)
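The aggregations above all reduce the list of per-document values the harness collects for a metric. As a rough illustration of the math (values below are made up; the register_aggregation decorator is defined elsewhere in the codebase), perplexity is just the exponentiated negative mean of the loglikelihoods passed in:

import math

loglikelihoods = [-1.2, -0.7, -2.3]  # hypothetical per-document loglikelihoods
ppl = math.exp(-(sum(loglikelihoods) / len(loglikelihoods)))  # exp(1.4) ~= 4.06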
@@ -416,7 +416,7 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
return max(scores_for_ground_truths)
def weighted_mean(items):
def weighted_mean(items: List[tuple[float, float]]) -> float:
a, b = zip(*items)
return sum(a) / sum(b)
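weighted_mean treats its items as (value, weight) pairs and divides the summed values by the summed weights, which is what weighted_perplexity and bits_per_byte above rely on. A minimal sketch with made-up numbers, assuming each item is a (total loglikelihood, token or byte count) pair:

import math

items = [(-120.0, 100), (-45.0, 40)]  # hypothetical (loglikelihood, weight) pairs
wmean = sum(ll for ll, _ in items) / sum(w for _, w in items)  # -165 / 140 ~= -1.18
weighted_ppl = math.exp(-wmean)  # ~= 3.25
bpb = -wmean / math.log(2)       # ~= 1.70 bits per byte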
......
@@ -63,11 +63,10 @@ class MetricConfig:
aggregation_fn: Optional[Callable] = None
higher_is_better: bool = True
hf_evaluate: bool = False
sample_metric: bool = True
is_elementwise: bool = True
@cached_property
def metric_names(self) -> str:
def metric_name(self) -> str:
return self.name
@cached_property
@@ -82,6 +81,12 @@ class MetricConfig:
return is_higher_better(self.name)
return self.higher_is_better
def calculate_metric(self, *args, **kwargs) -> Any:
"""Calculates the metric using the provided function and arguments."""
if self.fn is None:
raise ValueError(f"Metric function for {self.name} is not defined.")
return self.fn(*args, **{**self.kwargs, **kwargs})
@dataclass
class RepeatConfig:
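The calculate_metric helper added above forwards positional arguments to the configured metric function and merges keyword arguments so that call-time values override the ones stored on the config. A minimal sketch of that precedence, assuming MetricConfig also carries name, fn, and kwargs fields as the method body implies (the metric function here is purely illustrative):

def scaled_diff(a: float, b: float, scale: float = 1.0) -> float:
    return scale * abs(a - b)

cfg = MetricConfig(name="scaled_diff", fn=scaled_diff, kwargs={"scale": 2.0})
cfg.calculate_metric(3.0, 1.0)              # scaled_diff(3.0, 1.0, scale=2.0) -> 4.0
cfg.calculate_metric(3.0, 1.0, scale=10.0)  # call-time kwarg wins             -> 20.0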
@@ -108,6 +113,16 @@ class FewshotConfig:
process_docs: Optional[Callable] = None
@dataclass
class DatasetConfig:
"""Encapsulates information about a dataset."""
dataset_path: Optional[str] = None
dataset_name: Optional[str] = None
dataset_kwargs: Optional[dict] = None
custom_dataset: Optional[Callable] = None
@dataclass
class TaskConfig(dict):
# task naming/registry
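The DatasetConfig dataclass added above groups everything needed to locate a dataset: either a Hub path plus optional config name and loader kwargs, or a custom loader callable. A hypothetical instantiation (all values illustrative):

ds_cfg = DatasetConfig(
    dataset_path="super_glue",
    dataset_name="boolq",
    dataset_kwargs={"trust_remote_code": True},
)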
@@ -132,8 +147,8 @@ class TaskConfig(dict):
process_docs: Optional[Callable] = None
doc_to_text: Optional[Union[Callable, str]] = None
doc_to_target: Optional[Union[Callable, str]] = None
doc_to_image: Union[Callable, str] = None
doc_to_audio: Union[Callable, str] = None
doc_to_image: Union[Callable, str, None] = None
doc_to_audio: Union[Callable, str, None] = None
unsafe_code: bool = False
doc_to_choice: Optional[Union[Callable, str, dict, list]] = None
process_results: Optional[Union[Callable, str]] = None
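These doc_to_* fields accept either a string (a field name or template resolved against each document) or a callable over the document dict; the change above simply spells out the None default in the image/audio hints. A hedged sketch of the two forms (field names and the template are illustrative, not taken from any real task):

cfg = TaskConfig(
    doc_to_text="Question: {{question}}\nAnswer:",
    doc_to_target=lambda doc: doc["answer"],
    doc_to_choice=["no", "yes"],
)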
@@ -466,17 +481,17 @@ class Task(abc.ABC):
return self._config
@abc.abstractmethod
def has_training_docs(self):
def has_training_docs(self) -> bool:
"""Whether the task has a training set"""
pass
@abc.abstractmethod
def has_validation_docs(self):
def has_validation_docs(self) -> bool:
"""Whether the task has a validation set"""
pass
@abc.abstractmethod
def has_test_docs(self):
def has_test_docs(self) -> bool:
"""Whether the task has a test set"""
pass
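A concrete task answers these three predicates by checking which splits it actually has, and downstream code (such as eval_docs further below) uses them to pick a split. A standalone sketch, not the harness's real subclass:

class SplitChecksSketch:
    # Illustrative only: reports split availability from optional split names.
    def __init__(self, training_split=None, validation_split=None, test_split=None):
        self.training_split = training_split
        self.validation_split = validation_split
        self.test_split = test_split

    def has_training_docs(self) -> bool:
        return self.training_split is not None

    def has_validation_docs(self) -> bool:
        return self.validation_split is not None

    def has_test_docs(self) -> bool:
        return self.test_split is not None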
@@ -536,7 +551,7 @@ class Task(abc.ABC):
"""
return self._instances
def fewshot_examples(self, k, rnd):
def fewshot_examples(self, k, rnd) -> Iterable[dict]:
if self._training_docs is None:
self._training_docs = list(self.training_docs())
@@ -548,11 +563,11 @@ class Task(abc.ABC):
)
@abc.abstractmethod
def doc_to_text(self, doc):
def doc_to_text(self, doc) -> str:
pass
@abc.abstractmethod
def doc_to_target(self, doc):
def doc_to_target(self, doc) -> Union[str, int]:
pass
# not an abstractmethod because not every language-only task has to implement this
@@ -562,7 +577,7 @@ class Task(abc.ABC):
def doc_to_audio(self, doc):
raise NotImplementedError
def doc_to_prefix(self, doc):
def doc_to_prefix(self, doc) -> str:
return ""
def build_all_requests(
@@ -734,12 +749,12 @@ class Task(abc.ABC):
return getattr(self._config, key, None)
@classmethod
def count_bytes(cls, doc):
def count_bytes(cls, doc) -> int:
"""Used for byte-level perplexity metrics in rolling loglikelihood"""
return len(doc.encode("utf-8"))
@classmethod
def count_words(cls, doc):
def count_words(cls, doc) -> int:
"""Downstream loglikelihood_rolling perplexity tasks with custom word boundaries should override this!"""
return len(re.split(r"\s+", doc))
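A worked example of the two counters, since they are the normalizers for byte- and word-level perplexity: count_bytes uses the UTF-8 encoding length, while count_words splits on whitespace.

import re

doc = "Héllo wörld"
len(doc.encode("utf-8"))    # 13: the two accented characters take 2 bytes each
len(re.split(r"\s+", doc))  # 2 whitespace-delimited words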
@@ -853,7 +868,7 @@ class Task(abc.ABC):
self.sampler.rnd = self.fewshot_rnd
@property
def eval_docs(self) -> Union[datasets.Dataset, List[dict]]:
def eval_docs(self) -> Union[datasets.Dataset, Iterable[dict]]:
if self.has_test_docs():
return self.test_docs()
elif self.has_validation_docs():
@@ -952,7 +967,7 @@ class ConfigurableTask(Task):
if self.config.dataset_name is not None:
self.DATASET_NAME = self.config.dataset_name
self.metric_list: list[MetricConfig] = self._config.get_metrics()
self.metric_list: list[MetricConfig] = self.config.get_metrics()
self.download(self.config.dataset_kwargs)
self._training_docs = None
@@ -1092,7 +1107,7 @@ class ConfigurableTask(Task):
else:
return False
def training_docs(self) -> datasets.Dataset:
def training_docs(self) -> Optional[datasets.Dataset]:
if self.has_training_docs():
if self.config.process_docs is not None:
return self.config.process_docs(
@@ -1100,7 +1115,7 @@ class ConfigurableTask(Task):
)
return self.dataset[self.config.training_split]
def validation_docs(self) -> datasets.Dataset:
def validation_docs(self) -> Optional[datasets.Dataset]:
if self.has_validation_docs():
if self.config.process_docs is not None:
return self.config.process_docs(
@@ -1108,7 +1123,7 @@ class ConfigurableTask(Task):
)
return self.dataset[self.config.validation_split]
def test_docs(self) -> datasets.Dataset:
def test_docs(self) -> Optional[datasets.Dataset]:
if self.has_test_docs():
if self.config.process_docs is not None:
return self.config.process_docs(self.dataset[self.config.test_split])
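All three *_docs methods route the raw split through config.process_docs when it is set. That hook is simply a callable from the raw split to a (possibly filtered or remapped) datasets.Dataset; a hedged sketch with hypothetical column names:

import datasets

def process_docs(dataset: datasets.Dataset) -> datasets.Dataset:
    # Illustrative hook: drop empty questions, then normalize whitespace.
    def _normalize(doc):
        doc["question"] = doc["question"].strip()
        return doc
    return dataset.filter(lambda doc: doc["question"].strip() != "").map(_normalize)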
@@ -1174,7 +1189,7 @@ class ConfigurableTask(Task):
fewshot_as_multiturn: bool = False,
chat_template: Optional[Callable] = None,
gen_prefix: Optional[str] = None,
) -> Union[str, List[str]]:
) -> Union[str, List[str], None]:
"""Returns a fewshot context string that is made up of a prepended description
(if provided), the `num_fewshot` number of examples, and an appended prompt example.
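For intuition, the string this method returns is roughly a description, followed by num_fewshot solved examples, followed by the unanswered prompt for the current document; the exact delimiters, chat-template handling, and gen_prefix logic live in the rest of the method, which this hunk does not show. A rough, hand-assembled sketch with made-up content:

description = "Answer the question.\n\n"
fewshot = "Q: 2 + 2?\nA: 4\n\nQ: 3 + 3?\nA: 6\n\n"
prompt = "Q: 5 + 5?\nA:"
context = description + fewshot + prompt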
@@ -1461,7 +1476,7 @@ class ConfigurableTask(Task):
else:
raise TypeError
def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list]:
def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list, None]:
if doc_to_image is not None:
doc_to_image = doc_to_image
elif self.config.doc_to_image is not None:
@@ -1484,7 +1499,7 @@ class ConfigurableTask(Task):
else:
return None
def doc_to_audio(self, doc: Any, doc_to_audio=None) -> Union[int, str, list]:
def doc_to_audio(self, doc: Any, doc_to_audio=None) -> Union[int, str, list, None]:
if doc_to_audio is not None:
doc_to_audio = doc_to_audio
elif self.config.doc_to_audio is not None:
@@ -1507,7 +1522,7 @@ class ConfigurableTask(Task):
else:
return None
def doc_to_prefix(self, doc):
def doc_to_prefix(self, doc) -> Optional[str]:
if (gen_prefix := self.config.gen_prefix) is not None:
if gen_prefix in self.features:
return doc[gen_prefix]
@@ -1554,7 +1569,7 @@ class ConfigurableTask(Task):
arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]
# TODO: we should raise a warning telling users this will at most ~2x runtime.
if "acc_mutual_info" in [m.metric_names for m in self.metric_list]:
if "acc_mutual_info" in [m.metric_name for m in self.metric_list]:
# if we are calculating multiple choice accuracy
# using mutual information instead of raw loglikelihood as metric, need unconditional lls.
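The acc_mutual_info branch requests the extra unconditional loglikelihoods because the prediction is then based on log P(choice | context) - log P(choice) rather than on the conditional loglikelihood alone. An illustrative two-choice example (numbers made up) showing how the two rules can disagree:

lls_conditional = [-4.1, -3.8]    # log P(choice | context)
lls_unconditional = [-2.0, -0.5]  # log P(choice) with no task context
argmax_acc = max(range(2), key=lambda i: lls_conditional[i])                        # -> 1
argmax_mi = max(range(2), key=lambda i: lls_conditional[i] - lls_unconditional[i])  # -> 0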
@@ -1621,7 +1636,7 @@ class ConfigurableTask(Task):
return self.config.process_results(doc, results)
result_dict = {}
use_metric = list(m.metric_names for m in self.metric_list)
use_metric = list(m.metric_name for m in self.metric_list)
if self.OUTPUT_TYPE == "loglikelihood":
results = results[0]
ll, is_greedy = results
@@ -1819,7 +1834,7 @@ class ConfigurableTask(Task):
return getattr(self._config, key, None)
@property
def task_name(self) -> Any:
def task_name(self) -> Optional[str]:
return getattr(self.config, "task", None)
def __repr__(self):
......