Commit 3ba4e897 authored by Baber's avatar Baber
Browse files

update type hints

parent 9b192374
......@@ -21,36 +21,36 @@ def bypass_agg(arr):
@register_aggregation("nanmean")
def nanmean(arr):
def nanmean(arr: list[float]) -> float:
if len(arr) == 0 or all(np.isnan(arr)):
return np.nan
return np.nanmean(arr)
@register_aggregation("mean")
def mean(arr):
def mean(arr: list[float]) -> float:
return sum(arr) / len(arr)
@register_aggregation("median")
def median(arr):
def median(arr: list[float]) -> float:
return arr[len(arr) // 2]
# Certain metrics must be calculated across all documents in a benchmark.
# We use them as aggregation metrics, paired with no-op passthrough metric fns.
@register_aggregation("perplexity")
def perplexity(items):
def perplexity(items: list[float]) -> float:
return math.exp(-mean(items))
@register_aggregation("weighted_perplexity")
def weighted_perplexity(items):
def weighted_perplexity(items: list[tuple[float, float]]) -> float:
return math.exp(-weighted_mean(items))
@register_aggregation("bits_per_byte")
def bits_per_byte(items):
def bits_per_byte(items: list[tuple[float, float]]) -> float:
return -weighted_mean(items) / math.log(2)
......@@ -413,7 +413,7 @@ def metric_max_over_ground_truths(metric_fn, prediction, ground_truths):
return max(scores_for_ground_truths)
def weighted_mean(items):
def weighted_mean(items: List[tuple[float, float]]) -> float:
a, b = zip(*items)
return sum(a) / sum(b)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment