Commit c6a91582 authored by lintangsutawika's avatar lintangsutawika
Browse files

update

parent a808c661
import math
from collections.abc import Iterable
import abc
import numpy as np
import sacrebleu
import sklearn.metrics
......@@ -17,33 +18,21 @@ eval_logger = logging.getLogger("lm-eval")
class BaseMetric:
def __init__(
self,
aggregation=None,
) -> None:
self.aggregation = aggregation
def __call__(self, *items):
@abc.abstractmethod
def update(self, *items):
pass
sample_wise_score = self.sample_wise_compute(*items)
@abc.abstractmethod
def compute(self, *items):
pass
if self.aggregation is not None:
return self.aggregation(sample_wise_score)
else:
return self.set_wise_compute(sample_wise_score)
def sample_wise_compute(self, *items):
return items
def set_wise_compute(self, *items):
return items
# Register Aggregations First
@register_aggregation("mean")
def mean(arr):
return sum(arr) / len(arr)
@register_aggregation("median")
def median(arr):
return arr[len(arr) // 2]
......@@ -54,10 +43,10 @@ def median(arr):
output_type="loglikelihood",
)
class PerplexityMetric(BaseMetric):
def sample_wise_compute(self, ll, is_greedy):
def update(self, ll, is_greedy):
return ll
def set_wise_compute(self, items):
def compute(self, items):
return math.exp(-mean(items))
......@@ -65,12 +54,13 @@ class PerplexityMetric(BaseMetric):
metric="acc",
higher_is_better=True,
output_type="loglikelihood",
aggregation="mean",
)
class LoglikelihoodAccMetric(BaseMetric):
def __call__(self, ll, is_greedy):
def update(self, ll, is_greedy):
return int(is_greedy)
def compute(self, items):
return math.exp(-mean(items))
@register_aggregation("f1")
def f1_score(items):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment