Commit d83f7eb0 authored by Baber

add type hints

parent a617e184
+from __future__ import annotations
 import logging
 import math
 import os
 import random
 import re
 import string
-from collections.abc import Iterable
-from typing import Callable, List, Optional, Sequence, TypeVar
+from collections.abc import Iterable, Sequence
+from typing import Callable, Generic, TypeVar
 import numpy as np
@@ -31,7 +33,7 @@ def nanmean(arr: list[float]) -> float:
 @register_aggregation("mean")
-def mean(arr: list[float]) -> float:
+def mean(arr: Sequence[float]) -> float:
     return sum(arr) / len(arr)
@@ -70,7 +72,7 @@ def f1_score(items):
 @register_aggregation("matthews_corrcoef")
-def matthews_corrcoef(items):
+def matthews_corrcoef(items: Iterable[tuple[int, int] | tuple[str, str]]) -> float:
     from sklearn.metrics import matthews_corrcoef
     unzipped_list = list(zip(*items))
@@ -80,7 +82,7 @@ def matthews_corrcoef(items):
 @register_aggregation("bleu")
-def bleu(items):
+def bleu(items: Iterable[tuple[str, str]]):
     """The Bilingual Evaluation Understudy Score, or BLEU for short, is a metric
     for evaluating a generated sentence to a reference sentence. It counts matching
     n-grams in the candidate translation to n-grams in the reference text, where
@@ -117,7 +119,7 @@ def chrf(items):
 @register_aggregation("ter")
-def ter(items):
+def ter(items: Iterable[tuple[str, str]]):
     """Translation Error Rate is an error metric for machine translation that
     measures the number of edits required to change a system output into one
     of the references
@@ -135,7 +137,9 @@ def ter(items):
 @register_aggregation("brier_score")
-def brier_score(items):  # This is a passthrough function
+def brier_score(
+    items: Iterable[tuple[str, float]],
+):  # This is a passthrough function
     gold, predictions = list(zip(*items))
     bs, num_class = np.array(predictions).shape
@@ -266,7 +270,7 @@ def perplexity_fn(items):  # This is a passthrough function
     output_type="loglikelihood_rolling",
     aggregation="weighted_perplexity",
 )
-def word_perplexity_fn(items):  # This is a passthrough function
+def word_perplexity_fn(items: T) -> T:  # This is a passthrough function
     return items
@@ -276,7 +280,7 @@ def word_perplexity_fn(items):  # This is a passthrough function
     output_type="loglikelihood_rolling",
     aggregation="weighted_perplexity",
 )
-def byte_perplexity_fn(items):  # This is a passthrough function
+def byte_perplexity_fn(items: T) -> T:  # This is a passthrough function
     return items
@@ -286,7 +290,7 @@ def byte_perplexity_fn(items):  # This is a passthrough function
     output_type="loglikelihood_rolling",
     aggregation="bits_per_byte",
 )
-def bits_per_byte_fn(items):  # This is a passthrough function
+def bits_per_byte_fn(items: T) -> T:  # This is a passthrough function
     return items
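Side note (not part of the commit): the `items: T -> T` annotations on these passthrough functions rely on a module-level type variable. The diff imports `TypeVar`, so the declaration presumably looks like the sketch below; the exact name and placement are assumptions.

from typing import TypeVar

T = TypeVar("T")  # assumption: declared once near the top of the module

def passthrough(items: T) -> T:  # illustrative name, not a function from this commit
    # Annotating input and output with the same TypeVar tells a type checker
    # that whatever type goes in comes back out unchanged.
    return items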
@@ -295,7 +299,7 @@ def pop_stddev(arr):
     return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / len(arr))

-def sample_stddev(arr: Sequence[T]) -> float:
+def sample_stddev(arr: Sequence[float]) -> float:
     mu = mean(arr)
     return math.sqrt(sum([(x - mu) ** 2 for x in arr]) / (len(arr) - 1))
@@ -416,7 +420,7 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
     return max(scores_for_ground_truths)

-def weighted_mean(items: List[tuple[float, float]]) -> float:
+def weighted_mean(items: list[tuple[float, float]]) -> float:
     a, b = zip(*items)
     return sum(a) / sum(b)
@@ -427,15 +431,15 @@ def is_non_str_iterable(obj):
 def _sacreformat(refs, preds):
     """Format refs and preds for sacrebleu corpus calculation. It is very particular"""
-    # Sacrebleu expects (List[str], List[List[str])
+    # Sacrebleu expects (list[str], list[list[str])
     # e.g. sacrebleu.corpus_bleu([pred_t], [[ref1_stream], [ref2_stream], ...])
     # Note [ref1_stream] is the first reference for each pred.
     # So lists are size N and (M, N) for N preds and M possible refs for each pred
     # This is a different order of dimensions that I would expect
-    # We expect refs to be List[str] or List[List[str]], the outer list corresponding to preds
-    # Must become List[List[str]] with the inner list corresponding to preds
+    # We expect refs to be list[str] or list[list[str]], the outer list corresponding to preds
+    # Must become list[list[str]] with the inner list corresponding to preds
     if not is_non_str_iterable(refs):
         refs = list(refs)
     if not is_non_str_iterable(refs[0]):
@@ -443,7 +447,7 @@ def _sacreformat(refs, preds):
         refs = list(zip(*refs))
     # Note the number of refs in each ref list much match the number of preds
-    # We expect preds to be List[str] or List[List[str]]. Must become List[str]
+    # We expect preds to be list[str] or list[list[str]]. Must become list[str]
     if not is_non_str_iterable(preds):
         preds = list(preds)
     if is_non_str_iterable(preds[0]):
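Side note (not part of the commit): the comments in this hunk describe the shape sacrebleu wants. A minimal sketch of that layout, assuming sacrebleu is installed and using made-up sentences:

import sacrebleu

preds = ["the cat sat on the mat", "hello world"]  # list[str], one entry per prediction
refs = [["the cat is on the mat", "hi world"]]     # list[list[str]]: one reference stream, aligned with preds

result = sacrebleu.corpus_bleu(preds, refs)
print(result.score)  # corpus-level BLEU as a float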
@@ -456,7 +460,7 @@ def _sacreformat(refs, preds):
 # stderr stuff
-class _bootstrap_internal:
+class _bootstrap_internal(Generic[T]):
     """
     Pool worker: `(i, xs)` → `n` bootstrap replicates
     of `f(xs)` using a RNG seeded with `i`.
@@ -539,7 +543,7 @@ def bootstrap_stderr(
 def stderr_for_metric(
     metric: Callable[[Sequence[T]], float], bootstrap_iters: int
-) -> Optional[Callable[[Sequence[T]], float]]:
+) -> Callable[[Sequence[T]], float] | None:
     """
     Return a function that estimates the standard error of `metric(xs)`.
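Side note (not part of the commit): given the signature above, a caller would use the returned callable roughly like this (illustrative values; `mean` is the aggregation defined earlier in this file):

scores = [0.0, 1.0, 1.0, 0.0, 1.0]  # made-up per-sample metric values

stderr_fn = stderr_for_metric(mean, bootstrap_iters=1000)
if stderr_fn is not None:  # per the new annotation, None means no stderr is available
    print(stderr_fn(scores))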
@@ -569,10 +573,10 @@ def stderr_for_metric(
     stderr = {mean: mean_stderr, acc_all: acc_all_stderr}
-    return stderr.get(metric, None)
+    return stderr.get(metric)

-def pooled_sample_stderr(stderrs: List[float], sizes: List[int]):
+def pooled_sample_stderr(stderrs: list[float], sizes: list[int]):
     # Used to aggregate bootstrapped stderrs across subtasks in a group,
     # when we are weighting by the size of each subtask.
     #
@@ -590,7 +594,7 @@ def pooled_sample_stderr(stderrs: List[float], sizes: List[int]):
     return np.sqrt(pooled_sample_var / sum(sizes))

-def combined_sample_stderr(stderrs: List[float], sizes: List[int], metrics=None):
+def combined_sample_stderr(stderrs: list[float], sizes: list[int], metrics=None):
     assert metrics is not None, (
         "Need to pass a list of each subtask's metric for this stderr aggregation"
     )
@@ -622,7 +626,9 @@ def combined_sample_stderr(stderrs: List[float], sizes: List[int], metrics=None):
     return np.sqrt(variance)

-def aggregate_subtask_metrics(metrics, sizes, weight_by_size=True):
+def aggregate_subtask_metrics(
+    metrics: list[float], sizes: list[float], weight_by_size: bool = True
+):
     # A helper function that is used to aggregate
     # subtask scores cross-task.
     # TODO: does not hold for non-mean aggregations
@@ -631,4 +637,4 @@ def aggregate_subtask_metrics(metrics, sizes, weight_by_size=True):
     assert len(metrics) == len(sizes)
-    return sum([metric * size for metric, size in zip(metrics, sizes)]) / sum(sizes)
+    return sum(metric * size for metric, size in zip(metrics, sizes)) / sum(sizes)
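Side note (not part of the commit): the last hunk computes a size-weighted mean, so a quick sanity check with made-up numbers:

metrics = [0.5, 0.7]
sizes = [100, 300]

# (0.5 * 100 + 0.7 * 300) / 400 = 0.65
print(sum(m * n for m, n in zip(metrics, sizes)) / sum(sizes))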