"vscode:/vscode.git/clone" did not exist on "1a8706c8b94918915bcaa44ddbc9e29a0cfea3b2"
Commit bbf79d44 authored by Baber's avatar Baber
Browse files

update type hints

parent 7f7872c1
@@ -24,36 +24,36 @@ def bypass_agg(arr):
 @register_aggregation("nanmean")
-def nanmean(arr):
+def nanmean(arr: list[float]) -> float:
     if len(arr) == 0 or all(np.isnan(arr)):
         return np.nan
     return np.nanmean(arr)

 @register_aggregation("mean")
-def mean(arr):
+def mean(arr: list[float]) -> float:
     return sum(arr) / len(arr)

 @register_aggregation("median")
-def median(arr):
+def median(arr: list[float]) -> float:
     return arr[len(arr) // 2]

 # Certain metrics must be calculated across all documents in a benchmark.
 # We use them as aggregation metrics, paired with no-op passthrough metric fns.
 @register_aggregation("perplexity")
-def perplexity(items):
+def perplexity(items: list[float]) -> float:
     return math.exp(-mean(items))

 @register_aggregation("weighted_perplexity")
-def weighted_perplexity(items):
+def weighted_perplexity(items: list[tuple[float, float]]) -> float:
     return math.exp(-weighted_mean(items))

 @register_aggregation("bits_per_byte")
-def bits_per_byte(items):
+def bits_per_byte(items: list[tuple[float, float]]) -> float:
     return -weighted_mean(items) / math.log(2)
@@ -416,7 +416,7 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
     return max(scores_for_ground_truths)

-def weighted_mean(items):
+def weighted_mean(items: List[tuple[float, float]]) -> float:
     a, b = zip(*items)
     return sum(a) / sum(b)
...
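For context on the aggregations annotated above: perplexity exponentiates a negated mean loglikelihood, and the weighted variants reduce per-document (loglikelihood, weight) pairs through weighted_mean. A minimal standalone sketch of that arithmetic, with weighted_mean copied from this diff and made-up sample numbers:

import math

def weighted_mean(items: list[tuple[float, float]]) -> float:
    a, b = zip(*items)
    return sum(a) / sum(b)

# Each item pairs a document's summed loglikelihood with its word (or byte) count.
items = [(-120.0, 100.0), (-250.0, 180.0)]

word_perplexity = math.exp(-weighted_mean(items))    # exp(NLL per word) ~= 3.75
bits_per_byte = -weighted_mean(items) / math.log(2)  # NLL per unit, in bits ~= 1.91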
@@ -63,11 +63,10 @@ class MetricConfig:
     aggregation_fn: Optional[Callable] = None
     higher_is_better: bool = True
     hf_evaluate: bool = False
-    sample_metric: bool = True
     is_elementwise: bool = True

     @cached_property
-    def metric_names(self) -> str:
+    def metric_name(self) -> str:
         return self.name

     @cached_property
@@ -82,6 +81,12 @@ class MetricConfig:
             return is_higher_better(self.name)
         return self.higher_is_better

+    def calculate_metric(self, *args, **kwargs) -> Any:
+        """Calculates the metric using the provided function and arguments."""
+        if self.fn is None:
+            raise ValueError(f"Metric function for {self.name} is not defined.")
+        return self.fn(*args, **{**self.kwargs, **kwargs})
+

 @dataclass
 class RepeatConfig:
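In the new calculate_metric, the expression {**self.kwargs, **kwargs} merges the kwargs stored on the config with the call-site kwargs, and the call site wins on key conflicts. A quick sketch of that merge semantics (the key names here are hypothetical):

config_kwargs = {"ignore_case": True, "ignore_punctuation": False}
call_kwargs = {"ignore_punctuation": True}
merged = {**config_kwargs, **call_kwargs}
assert merged == {"ignore_case": True, "ignore_punctuation": True}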
@@ -108,6 +113,16 @@ class FewshotConfig:
     process_docs: Optional[Callable] = None

+@dataclass
+class DatasetConfig:
+    """Encapsulates information about a dataset."""
+
+    dataset_path: Optional[str] = None
+    dataset_name: Optional[str] = None
+    dataset_kwargs: Optional[dict] = None
+    custom_dataset: Optional[Callable] = None
+

 @dataclass
 class TaskConfig(dict):
     # task naming/registry
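The new DatasetConfig groups the dataset fields that TaskConfig also carries (dataset_path, dataset_name, dataset_kwargs). A hypothetical instantiation, assuming the usual Hugging Face datasets conventions of path, subset name, and loader kwargs:

dataset = DatasetConfig(
    dataset_path="super_glue",                   # hypothetical HF dataset path
    dataset_name="boolq",                        # hypothetical subset name
    dataset_kwargs={"trust_remote_code": True},  # forwarded to the loader
)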
@@ -132,8 +147,8 @@ class TaskConfig(dict):
     process_docs: Optional[Callable] = None
     doc_to_text: Optional[Union[Callable, str]] = None
     doc_to_target: Optional[Union[Callable, str]] = None
-    doc_to_image: Union[Callable, str] = None
-    doc_to_audio: Union[Callable, str] = None
+    doc_to_image: Union[Callable, str, None] = None
+    doc_to_audio: Union[Callable, str, None] = None
     unsafe_code: bool = False
     doc_to_choice: Optional[Union[Callable, str, dict, list]] = None
     process_results: Optional[Union[Callable, str]] = None
@@ -466,17 +481,17 @@ class Task(abc.ABC):
         return self._config

     @abc.abstractmethod
-    def has_training_docs(self):
+    def has_training_docs(self) -> bool:
         """Whether the task has a training set"""
         pass

     @abc.abstractmethod
-    def has_validation_docs(self):
+    def has_validation_docs(self) -> bool:
         """Whether the task has a validation set"""
         pass

     @abc.abstractmethod
-    def has_test_docs(self):
+    def has_test_docs(self) -> bool:
         """Whether the task has a test set"""
         pass
@@ -536,7 +551,7 @@ class Task(abc.ABC):
         """
         return self._instances

-    def fewshot_examples(self, k, rnd):
+    def fewshot_examples(self, k, rnd) -> Iterable[dict]:
         if self._training_docs is None:
             self._training_docs = list(self.training_docs())
@@ -548,11 +563,11 @@ class Task(abc.ABC):
         )

     @abc.abstractmethod
-    def doc_to_text(self, doc):
+    def doc_to_text(self, doc) -> str:
         pass

     @abc.abstractmethod
-    def doc_to_target(self, doc):
+    def doc_to_target(self, doc) -> Union[str, int]:
         pass

     # not an abstractmethod because not every language-only task has to implement this
@@ -562,7 +577,7 @@ class Task(abc.ABC):
     def doc_to_audio(self, doc):
         raise NotImplementedError

-    def doc_to_prefix(self, doc):
+    def doc_to_prefix(self, doc) -> str:
         return ""

     def build_all_requests(
@@ -734,12 +749,12 @@ class Task(abc.ABC):
         return getattr(self._config, key, None)

     @classmethod
-    def count_bytes(cls, doc):
+    def count_bytes(cls, doc) -> int:
         """Used for byte-level perplexity metrics in rolling loglikelihood"""
         return len(doc.encode("utf-8"))

     @classmethod
-    def count_words(cls, doc):
+    def count_words(cls, doc) -> int:
         """Downstream loglikelihood_rolling perplexity tasks with custom word boundaries should override this!"""
         return len(re.split(r"\s+", doc))
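The two normalizers above feed byte- and word-level perplexity for rolling loglikelihood: count_bytes counts UTF-8 bytes, count_words counts whitespace-separated tokens. A small self-contained check of both:

import re

doc = "héllo world"
assert len(doc.encode("utf-8")) == 12   # 'é' takes two bytes in UTF-8
assert len(re.split(r"\s+", doc)) == 2  # naive whitespace word count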
@@ -853,7 +868,7 @@ class Task(abc.ABC):
         self.sampler.rnd = self.fewshot_rnd

     @property
-    def eval_docs(self) -> Union[datasets.Dataset, List[dict]]:
+    def eval_docs(self) -> Union[datasets.Dataset, Iterable[dict]]:
         if self.has_test_docs():
             return self.test_docs()
         elif self.has_validation_docs():
@@ -952,7 +967,7 @@ class ConfigurableTask(Task):
         if self.config.dataset_name is not None:
             self.DATASET_NAME = self.config.dataset_name

-        self.metric_list: list[MetricConfig] = self._config.get_metrics()
+        self.metric_list: list[MetricConfig] = self.config.get_metrics()

         self.download(self.config.dataset_kwargs)
         self._training_docs = None
@@ -1092,7 +1107,7 @@ class ConfigurableTask(Task):
         else:
             return False

-    def training_docs(self) -> datasets.Dataset:
+    def training_docs(self) -> Optional[datasets.Dataset]:
         if self.has_training_docs():
             if self.config.process_docs is not None:
                 return self.config.process_docs(
@@ -1100,7 +1115,7 @@ class ConfigurableTask(Task):
                 )
             return self.dataset[self.config.training_split]

-    def validation_docs(self) -> datasets.Dataset:
+    def validation_docs(self) -> Optional[datasets.Dataset]:
         if self.has_validation_docs():
             if self.config.process_docs is not None:
                 return self.config.process_docs(
@@ -1108,7 +1123,7 @@ class ConfigurableTask(Task):
                 )
             return self.dataset[self.config.validation_split]

-    def test_docs(self) -> datasets.Dataset:
+    def test_docs(self) -> Optional[datasets.Dataset]:
         if self.has_test_docs():
             if self.config.process_docs is not None:
                 return self.config.process_docs(self.dataset[self.config.test_split])
@@ -1174,7 +1189,7 @@ class ConfigurableTask(Task):
         fewshot_as_multiturn: bool = False,
         chat_template: Optional[Callable] = None,
         gen_prefix: Optional[str] = None,
-    ) -> Union[str, List[str]]:
+    ) -> Union[str, List[str], None]:
         """Returns a fewshot context string that is made up of a prepended description
         (if provided), the `num_fewshot` number of examples, and an appended prompt example.
@@ -1461,7 +1476,7 @@ class ConfigurableTask(Task):
         else:
             raise TypeError

-    def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list]:
+    def doc_to_image(self, doc: Any, doc_to_image=None) -> Union[int, str, list, None]:
         if doc_to_image is not None:
             doc_to_image = doc_to_image
         elif self.config.doc_to_image is not None:
@@ -1484,7 +1499,7 @@ class ConfigurableTask(Task):
         else:
             return None

-    def doc_to_audio(self, doc: Any, doc_to_audio=None) -> Union[int, str, list]:
+    def doc_to_audio(self, doc: Any, doc_to_audio=None) -> Union[int, str, list, None]:
         if doc_to_audio is not None:
             doc_to_audio = doc_to_audio
         elif self.config.doc_to_audio is not None:
@@ -1507,7 +1522,7 @@ class ConfigurableTask(Task):
         else:
             return None

-    def doc_to_prefix(self, doc):
+    def doc_to_prefix(self, doc) -> Optional[str]:
         if (gen_prefix := self.config.gen_prefix) is not None:
             if gen_prefix in self.features:
                 return doc[gen_prefix]
@@ -1554,7 +1569,7 @@ class ConfigurableTask(Task):
             arguments = [(ctx, f"{target_delimiter}{cont}") for cont in choices]

             # TODO: we should raise a warning telling users this will at most ~2x runtime.
-            if "acc_mutual_info" in [m.metric_names for m in self.metric_list]:
+            if "acc_mutual_info" in [m.metric_name for m in self.metric_list]:
                 # if we are calculating multiple choice accuracy
                 # using mutual information instead of raw loglikelihood as metric, need unconditional lls.
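On the acc_mutual_info branch above: scoring by mutual information compares each choice's conditional loglikelihood against its unconditional one, which is why the extra unconditional requests are needed (the roughly 2x runtime flagged in the TODO). A rough sketch of the scoring step, not the harness's exact code, with made-up numbers:

ll_cond = [-4.2, -3.1, -5.0]    # log P(choice | context)
ll_uncond = [-5.0, -2.9, -5.5]  # log P(choice), from the unconditional requests
scores = [c - u for c, u in zip(ll_cond, ll_uncond)]    # pointwise mutual information
best = max(range(len(scores)), key=scores.__getitem__)  # picks index 0 here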
@@ -1621,7 +1636,7 @@ class ConfigurableTask(Task):
             return self.config.process_results(doc, results)

         result_dict = {}
-        use_metric = list(m.metric_names for m in self.metric_list)
+        use_metric = list(m.metric_name for m in self.metric_list)
         if self.OUTPUT_TYPE == "loglikelihood":
             results = results[0]
             ll, is_greedy = results
@@ -1819,7 +1834,7 @@ class ConfigurableTask(Task):
         return getattr(self._config, key, None)

     @property
-    def task_name(self) -> Any:
+    def task_name(self) -> Optional[str]:
         return getattr(self.config, "task", None)

     def __repr__(self):
...
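Most of the return-type edits in this commit widen annotations to admit None paths that already existed; for instance, training_docs falls off the end without a return when has_training_docs() is false. Callers therefore need a guard. A sketch, where task stands for a hypothetical ConfigurableTask instance:

docs = task.training_docs()  # Optional[datasets.Dataset] after this commit
if docs is not None:
    print(docs[0])           # safe: only index when a training split exists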