Commit c6a91582 authored by lintangsutawika's avatar lintangsutawika
Browse files

update

parent a808c661
import math import math
from collections.abc import Iterable from collections.abc import Iterable
import abc
import numpy as np import numpy as np
import sacrebleu import sacrebleu
import sklearn.metrics import sklearn.metrics
...@@ -17,33 +18,21 @@ eval_logger = logging.getLogger("lm-eval") ...@@ -17,33 +18,21 @@ eval_logger = logging.getLogger("lm-eval")
class BaseMetric: class BaseMetric:
def __init__( def __init__(
self, self,
aggregation=None,
) -> None: ) -> None:
self.aggregation = aggregation
def __call__(self, *items): @abc.abstractmethod
def update(self, *items):
pass
sample_wise_score = self.sample_wise_compute(*items) @abc.abstractmethod
def compute(self, *items):
pass
if self.aggregation is not None:
return self.aggregation(sample_wise_score)
else:
return self.set_wise_compute(sample_wise_score)
def sample_wise_compute(self, *items):
return items
def set_wise_compute(self, *items):
return items
# Register Aggregations First
@register_aggregation("mean")
def mean(arr): def mean(arr):
return sum(arr) / len(arr) return sum(arr) / len(arr)
@register_aggregation("median")
def median(arr): def median(arr):
return arr[len(arr) // 2] return arr[len(arr) // 2]
...@@ -54,10 +43,10 @@ def median(arr): ...@@ -54,10 +43,10 @@ def median(arr):
output_type="loglikelihood", output_type="loglikelihood",
) )
class PerplexityMetric(BaseMetric): class PerplexityMetric(BaseMetric):
def sample_wise_compute(self, ll, is_greedy): def update(self, ll, is_greedy):
return ll return ll
def set_wise_compute(self, items): def compute(self, items):
return math.exp(-mean(items)) return math.exp(-mean(items))
...@@ -65,12 +54,13 @@ class PerplexityMetric(BaseMetric): ...@@ -65,12 +54,13 @@ class PerplexityMetric(BaseMetric):
metric="acc", metric="acc",
higher_is_better=True, higher_is_better=True,
output_type="loglikelihood", output_type="loglikelihood",
aggregation="mean",
) )
class LoglikelihoodAccMetric(BaseMetric): class LoglikelihoodAccMetric(BaseMetric):
def __call__(self, ll, is_greedy): def update(self, ll, is_greedy):
return int(is_greedy) return int(is_greedy)
def compute(self, items):
return math.exp(-mean(items))
@register_aggregation("f1") @register_aggregation("f1")
def f1_score(items): def f1_score(items):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment