Commit e7cd7d68 authored by lintangsutawika

sample metrics that have both sample-wise and set-wise operations

parent 08fcf1fe
@@ -13,6 +13,30 @@ import logging

 eval_logger = logging.getLogger("lm-eval")

+
+class BaseMetric:
+    def __init__(
+        self,
+        aggregation=None,
+    ) -> None:
+        self.aggregation = aggregation
+
+    def __call__(self, *items):
+        sample_wise_score = self.sample_wise_compute(*items)
+        if self.aggregation is not None:
+            return self.aggregation(sample_wise_score)
+        else:
+            return self.set_wise_compute(sample_wise_score)
+
+    def sample_wise_compute(self, *items):
+        return items
+
+    def set_wise_compute(self, *items):
+        return items
+
+
 # Register Aggregations First
 @register_aggregation("mean")
 def mean(arr):
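For orientation, a minimal sketch (not part of the commit) of how the two hooks split the work: sample_wise_compute maps one sample's model outputs to a per-sample value, and set_wise_compute, or a registered aggregation, reduces the collected values to one benchmark score. The _SketchPerplexity class, the driver loop, and the numbers below are hypothetical; the harness's real call sites may differ.

    import math
    from statistics import mean as _mean

    # Hypothetical subclass following the pattern above: the per-sample hook keeps
    # only the log-likelihood; the set-wise hook folds the collected values into
    # a corpus-level perplexity.
    class _SketchPerplexity(BaseMetric):
        def sample_wise_compute(self, loglikelihood, is_greedy):
            return loglikelihood

        def set_wise_compute(self, scores):
            return math.exp(-_mean(scores))

    samples = [(-1.2, True), (-0.7, False), (-2.3, True)]  # made-up (loglikelihood, is_greedy) pairs
    metric = _SketchPerplexity()
    per_sample = [metric.sample_wise_compute(ll, g) for ll, g in samples]
    print(metric.set_wise_compute(per_sample))  # exp(-mean(ll)) over the whole set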
@@ -24,21 +48,28 @@ def median(arr):
     return arr[len(arr) // 2]

-# Certain metrics must be calculated across all documents in a benchmark.
-# We use them as aggregation metrics, paired with no-op passthrough metric fns.
-@register_aggregation("perplexity")
-def perplexity(items):
-    return math.exp(-mean(items))
-
-
-@register_aggregation("weighted_perplexity")
-def weighted_perplexity(items):
-    return math.exp(-weighted_mean(items))
-
-
-@register_aggregation("bits_per_byte")
-def bits_per_byte(items):
-    return -weighted_mean(items) / math.log(2)
+@register_metric(
+    metric="perplexity",
+    higher_is_better=False,
+    output_type="loglikelihood",
+)
+class PerplexityMetric(BaseMetric):
+    def sample_wise_compute(self, ll, is_greedy):
+        return ll
+
+    def set_wise_compute(self, items):
+        return math.exp(-mean(items))
+
+
+@register_metric(
+    metric="acc",
+    higher_is_better=True,
+    output_type="loglikelihood",
+    aggregation="mean",
+)
+class LoglikelihoodAccMetric(BaseMetric):
+    def __call__(self, ll, is_greedy):
+        return int(is_greedy)

 @register_aggregation("f1")
@@ -109,87 +140,86 @@ def ter(items):
     return sacrebleu.corpus_ter(preds, refs).score

-@register_metric(
-    metric="acc",
-    higher_is_better=True,
-    output_type=["loglikelihood", "multiple_choice"],
-    aggregation="mean",
-)
-def acc_fn(items):  # This is a passthrough function
-    return items
+# @register_metric(
+#     metric="acc",
+#     higher_is_better=True,
+#     output_type=["loglikelihood", "multiple_choice"],
+#     aggregation="mean",
+# )
+# def acc_fn(items):  # This is a passthrough function
+#     return items


-@register_metric(
-    metric="acc_norm",
-    higher_is_better=True,
-    output_type=["loglikelihood", "multiple_choice"],
-    aggregation="mean",
-)
-def acc_norm_fn(items):  # This is a passthrough function
-    return items
+# @register_metric(
+#     metric="acc_norm",
+#     higher_is_better=True,
+#     output_type=["loglikelihood", "multiple_choice"],
+#     aggregation="mean",
+# )
+# def acc_norm_fn(items):  # This is a passthrough function
+#     return items


-@register_metric(
-    metric="acc_mutual_info",
-    higher_is_better=True,
-    output_type="multiple_choice",
-    aggregation="mean",
-)
-def acc_mutual_info_fn(items):  # This is a passthrough function
-    return items
+# @register_metric(
+#     metric="acc_mutual_info",
+#     higher_is_better=True,
+#     output_type="multiple_choice",
+#     aggregation="mean",
+# )
+# def acc_mutual_info_fn(items):  # This is a passthrough function
+#     return items


 exact_match = evaluate.load("exact_match")


-@register_metric(
-    metric="exact_match",
-    higher_is_better=True,
-    output_type="generate_until",
-    aggregation="mean",
-)
-def exact_match_fn(**kwargs):
-    return exact_match.compute(**kwargs)
+# @register_metric(
+#     metric="exact_match",
+#     higher_is_better=True,
+#     output_type="generate_until",
+#     aggregation="mean",
+# )
+# def exact_match_fn(**kwargs):
+#     return exact_match.compute(**kwargs)


-@register_metric(
-    metric="perplexity",
-    higher_is_better=False,
-    output_type="loglikelihood",
-    aggregation="perplexity",
-)
-def perplexity_fn(items):  # This is a passthrough function
-    return items


 @register_metric(
     metric="word_perplexity",
     higher_is_better=False,
     output_type="loglikelihood_rolling",
-    aggregation="weighted_perplexity",
 )
-def word_perplexity_fn(items):  # This is a passthrough function
-    return items
+class WordPerplexityMetric(BaseMetric):
+    def sample_wise_compute(self, loglikelihood, _words, _bytes):
+        return loglikelihood, _words
+
+    def set_wise_compute(self, items):
+        return math.exp(-weighted_mean(items))


 @register_metric(
     metric="byte_perplexity",
     higher_is_better=False,
     output_type="loglikelihood_rolling",
-    aggregation="weighted_perplexity",
 )
-def byte_perplexity_fn(items):  # This is a passthrough function
-    return items
+class BytePerplexityMetric(BaseMetric):
+    def sample_wise_compute(self, loglikelihood, _words, _bytes):
+        return loglikelihood, _bytes
+
+    def set_wise_compute(self, items):
+        return math.exp(-weighted_mean(items))


 @register_metric(
     metric="bits_per_byte",
     higher_is_better=False,
     output_type="loglikelihood_rolling",
-    aggregation="bits_per_byte",
 )
-def bits_per_byte_fn(items):  # This is a passthrough function
-    return items
+class BitsPerByteMetric(BaseMetric):
+    def sample_wise_compute(self, loglikelihood, _words, _bytes):
+        return loglikelihood, _bytes
+
+    def set_wise_compute(self, items):
+        return -weighted_mean(items) / math.log(2)


 def pop_stddev(arr):
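A small sketch (not part of the commit) of how the two rolling-loglikelihood scores relate. It assumes weighted_mean divides the summed log-likelihoods by the summed weights (word or byte counts); the pair values are made up.

    import math

    pairs = [(-120.0, 200), (-80.0, 150)]   # hypothetical (total loglikelihood, byte count) per document
    wm = sum(ll for ll, _ in pairs) / sum(n for _, n in pairs)   # assumed weighted_mean behaviour

    byte_ppl = math.exp(-wm)                # BytePerplexityMetric.set_wise_compute
    bpb = -wm / math.log(2)                 # BitsPerByteMetric.set_wise_compute
    assert abs(bpb - math.log2(byte_ppl)) < 1e-9   # bits_per_byte is log2 of byte_perplexity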
@@ -206,79 +236,79 @@ def mean_stderr(arr):
     return sample_stddev(arr) / math.sqrt(len(arr))

-@register_metric(
-    metric="mcc",
-    higher_is_better=True,
-    output_type="multiple_choice",
-    aggregation="matthews_corrcoef",
-)
-def mcc_fn(items):  # This is a passthrough function
-    return items
+# @register_metric(
+#     metric="mcc",
+#     higher_is_better=True,
+#     output_type="multiple_choice",
+#     aggregation="matthews_corrcoef",
+# )
+# def mcc_fn(items):  # This is a passthrough function
+#     return items


-@register_metric(
-    metric="f1",
-    higher_is_better=True,
-    output_type="multiple_choice",
-    aggregation="f1",
-)
-def f1_fn(items):  # This is a passthrough function
-    return items
+# @register_metric(
+#     metric="f1",
+#     higher_is_better=True,
+#     output_type="multiple_choice",
+#     aggregation="f1",
+# )
+# def f1_fn(items):  # This is a passthrough function
+#     return items


-@register_metric(
-    metric="bleu",
-    higher_is_better=True,
-    output_type="generate_until",
-    aggregation="bleu",
-)
-def bleu_fn(items):  # This is a passthrough function
-    return items
+# @register_metric(
+#     metric="bleu",
+#     higher_is_better=True,
+#     output_type="generate_until",
+#     aggregation="bleu",
+# )
+# def bleu_fn(items):  # This is a passthrough function
+#     return items


-@register_metric(
-    metric="chrf",
-    higher_is_better=True,
-    output_type="generate_until",
-    aggregation="chrf",
-)
-def chrf_fn(items):  # This is a passthrough function
-    return items
+# @register_metric(
+#     metric="chrf",
+#     higher_is_better=True,
+#     output_type="generate_until",
+#     aggregation="chrf",
+# )
+# def chrf_fn(items):  # This is a passthrough function
+#     return items


-@register_metric(
-    metric="ter",
-    higher_is_better=True,
-    output_type="generate_until",
-    aggregation="ter",
-)
-def ter_fn(items):  # This is a passthrough function
-    return items
+# @register_metric(
+#     metric="ter",
+#     higher_is_better=True,
+#     output_type="generate_until",
+#     aggregation="ter",
+# )
+# def ter_fn(items):  # This is a passthrough function
+#     return items


-@register_metric(
-    metric="acc_all",
-    higher_is_better=True,
-    output_type="loglikelihood",
-    aggregation="mean",
-)
-def acc_all(items):
-    # Only count as correct if all answers are labeled correctly for each question
-    question_scoring_dict = {}
-    preds = list(zip(*items))[0]
-    docs = list(zip(*items))[1]
-    for doc, pred in zip(docs, preds):
-        paragraph_id = doc["idx"]["paragraph"]
-        question_id = doc["idx"]["question"]
-        if (paragraph_id, question_id) not in question_scoring_dict:
-            question_scoring_dict[(paragraph_id, question_id)] = []
-        gold_label = doc["label"] == 1
-        question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)
-    acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
-    return acc
+# @register_metric(
+#     metric="acc_all",
+#     higher_is_better=True,
+#     output_type="loglikelihood",
+#     aggregation="mean",
+# )
+# def acc_all(items):
+#     # Only count as correct if all answers are labeled correctly for each question
+#     question_scoring_dict = {}
+#     preds = list(zip(*items))[0]
+#     docs = list(zip(*items))[1]
+#     for doc, pred in zip(docs, preds):
+#         paragraph_id = doc["idx"]["paragraph"]
+#         question_id = doc["idx"]["question"]
+#         if (paragraph_id, question_id) not in question_scoring_dict:
+#             question_scoring_dict[(paragraph_id, question_id)] = []
+#         gold_label = doc["label"] == 1
+#         question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)
+#     acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
+#     return acc


 def acc_all_stderr(items):
...