"docs/source/vscode:/vscode.git/clone" did not exist on "b52f7756fbcf6669dbe92e97e11415c4084cf881"
Commit e7cd7d68 authored by lintangsutawika's avatar lintangsutawika
Browse files

sample metrics that have both sample-wise and set-wise operations

parent 08fcf1fe
@@ -13,6 +13,30 @@ import logging

eval_logger = logging.getLogger("lm-eval")


class BaseMetric:
    def __init__(
        self,
        aggregation=None,
    ) -> None:
        self.aggregation = aggregation

    def __call__(self, *items):
        sample_wise_score = self.sample_wise_compute(*items)
        if self.aggregation is not None:
            return self.aggregation(sample_wise_score)
        else:
            return self.set_wise_compute(sample_wise_score)

    def sample_wise_compute(self, *items):
        return items

    def set_wise_compute(self, *items):
        return items
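For orientation on the new class shape, here is a minimal sketch of how the __call__ dispatch behaves. It reuses BaseMetric exactly as defined above; the registry decorators are omitted and the MeanMetric subclass is hypothetical, used only to show the two code paths.

# Minimal sketch (outside the harness): BaseMetric as defined above, plus a
# hypothetical MeanMetric subclass to exercise both code paths.
class MeanMetric(BaseMetric):
    def sample_wise_compute(self, scores):
        return scores

    def set_wise_compute(self, scores):
        return sum(scores) / len(scores)


print(MeanMetric()([1, 0, 1, 1]))                 # set-wise path -> 0.75
print(MeanMetric(aggregation=max)([1, 0, 1, 1]))  # aggregation path -> 1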
# Register Aggregations First
@register_aggregation("mean")
def mean(arr):
@@ -24,21 +48,28 @@ def median(arr):
    return arr[len(arr) // 2]


# Certain metrics must be calculated across all documents in a benchmark.
# We use them as aggregation metrics, paired with no-op passthrough metric fns.
@register_aggregation("perplexity")
def perplexity(items):
    return math.exp(-mean(items))
@register_metric(
    metric="perplexity",
    higher_is_better=False,
    output_type="loglikelihood",
)
class PerplexityMetric(BaseMetric):
    def sample_wise_compute(self, ll, is_greedy):
        return ll

    def set_wise_compute(self, items):
        return math.exp(-mean(items))


@register_aggregation("weighted_perplexity")
def weighted_perplexity(items):
    return math.exp(-weighted_mean(items))


@register_aggregation("bits_per_byte")
def bits_per_byte(items):
    return -weighted_mean(items) / math.log(2)
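As a quick sanity check on the perplexity path above (a standalone sketch, not harness code): sample_wise_compute collects each document's log-likelihood and set_wise_compute reduces them with exp(-mean(...)).

import math

# Standalone check of the set-wise perplexity reduction; the log-likelihood
# values are made up for illustration.
lls = [-2.0, -3.0, -4.0]       # per-document total log-likelihoods
mean_ll = sum(lls) / len(lls)  # -3.0
print(math.exp(-mean_ll))      # perplexity ≈ 20.09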
@register_metric(
    metric="acc",
    higher_is_better=True,
    output_type="loglikelihood",
    aggregation="mean",
)
class LoglikelihoodAccMetric(BaseMetric):
    def __call__(self, ll, is_greedy):
        return int(is_greedy)
@register_aggregation("f1")
......@@ -109,87 +140,86 @@ def ter(items):
return sacrebleu.corpus_ter(preds, refs).score
@register_metric(
    metric="acc",
    higher_is_better=True,
    output_type=["loglikelihood", "multiple_choice"],
    aggregation="mean",
)
def acc_fn(items):  # This is a passthrough function
    return items


# @register_metric(
#     metric="acc",
#     higher_is_better=True,
#     output_type=["loglikelihood", "multiple_choice"],
#     aggregation="mean",
# )
# def acc_fn(items):  # This is a passthrough function
#     return items


@register_metric(
    metric="acc_norm",
    higher_is_better=True,
    output_type=["loglikelihood", "multiple_choice"],
    aggregation="mean",
)
def acc_norm_fn(items):  # This is a passthrough function
    return items


# @register_metric(
#     metric="acc_norm",
#     higher_is_better=True,
#     output_type=["loglikelihood", "multiple_choice"],
#     aggregation="mean",
# )
# def acc_norm_fn(items):  # This is a passthrough function
#     return items


@register_metric(
    metric="acc_mutual_info",
    higher_is_better=True,
    output_type="multiple_choice",
    aggregation="mean",
)
def acc_mutual_info_fn(items):  # This is a passthrough function
    return items


# @register_metric(
#     metric="acc_mutual_info",
#     higher_is_better=True,
#     output_type="multiple_choice",
#     aggregation="mean",
# )
# def acc_mutual_info_fn(items):  # This is a passthrough function
#     return items
exact_match = evaluate.load("exact_match")


@register_metric(
    metric="exact_match",
    higher_is_better=True,
    output_type="generate_until",
    aggregation="mean",
)
def exact_match_fn(**kwargs):
    return exact_match.compute(**kwargs)
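exact_match here is the metric loaded from the Hugging Face evaluate package; a minimal standalone use, assuming evaluate is installed and using its predictions/references keyword interface, looks like this.

import evaluate

# Standalone use of the evaluate package's exact_match metric; the example
# strings are made up for illustration.
exact_match = evaluate.load("exact_match")
result = exact_match.compute(
    predictions=["Paris", "42", "blue"],
    references=["Paris", "43", "blue"],
)
print(result["exact_match"])  # 2 of 3 strings match -> ~0.667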
@register_metric(
    metric="perplexity",
    higher_is_better=False,
    output_type="loglikelihood",
    aggregation="perplexity",
)
def perplexity_fn(items):  # This is a passthrough function
    return items


# @register_metric(
#     metric="exact_match",
#     higher_is_better=True,
#     output_type="generate_until",
#     aggregation="mean",
# )
# def exact_match_fn(**kwargs):
#     return exact_match.compute(**kwargs)
@register_metric(
    metric="word_perplexity",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
    aggregation="weighted_perplexity",
)
def word_perplexity_fn(items):  # This is a passthrough function
    return items


class WordPerplexityMetric(BaseMetric):
    def sample_wise_compute(self, loglikelihood, _words, _bytes):
        return loglikelihood, _words

    def set_wise_compute(self, items):
        return math.exp(-weighted_mean(items))


@register_metric(
    metric="byte_perplexity",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
    aggregation="weighted_perplexity",
)
def byte_perplexity_fn(items):  # This is a passthrough function
    return items


class BytePerplexityMetric(BaseMetric):
    def sample_wise_compute(self, loglikelihood, _words, _bytes):
        return loglikelihood, _bytes

    def set_wise_compute(self, items):
        return math.exp(-weighted_mean(items))
@register_metric(
    metric="bits_per_byte",
    higher_is_better=False,
    output_type="loglikelihood_rolling",
    aggregation="bits_per_byte",
)
def bits_per_byte_fn(items):  # This is a passthrough function
    return items


class BitsPerByteMetric(BaseMetric):
    def sample_wise_compute(self, loglikelihood, _words, _bytes):
        return loglikelihood, _bytes

    def set_wise_compute(self, items):
        return -weighted_mean(items) / math.log(2)
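For the two weighted reductions above, here is a standalone sketch of the arithmetic, assuming weighted_mean divides the summed log-likelihoods by the summed weights (word or byte counts), which is how the (loglikelihood, weight) pairs returned by sample_wise_compute are used.

import math

# Standalone sketch of the weighted reductions; weighted_mean is assumed to
# be sum of log-likelihoods over sum of weights, and the numbers are made up.
def weighted_mean(items):
    lls, weights = zip(*items)
    return sum(lls) / sum(weights)

items = [(-120.0, 400), (-80.0, 250)]       # (loglikelihood, byte count) pairs
print(math.exp(-weighted_mean(items)))      # byte perplexity ≈ 1.36
print(-weighted_mean(items) / math.log(2))  # bits per byte ≈ 0.44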
def pop_stddev(arr):
@@ -206,79 +236,79 @@ def mean_stderr(arr):
    return sample_stddev(arr) / math.sqrt(len(arr))
@register_metric(
    metric="mcc",
    higher_is_better=True,
    output_type="multiple_choice",
    aggregation="matthews_corrcoef",
)
def mcc_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="f1",
    higher_is_better=True,
    output_type="multiple_choice",
    aggregation="f1",
)
def f1_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="bleu",
    higher_is_better=True,
    output_type="generate_until",
    aggregation="bleu",
)
def bleu_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="chrf",
    higher_is_better=True,
    output_type="generate_until",
    aggregation="chrf",
)
def chrf_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="ter",
    higher_is_better=True,
    output_type="generate_until",
    aggregation="ter",
)
def ter_fn(items):  # This is a passthrough function
    return items


@register_metric(
    metric="acc_all",
    higher_is_better=True,
    output_type="loglikelihood",
    aggregation="mean",
)
def acc_all(items):
    # Only count as correct if all answers are labeled correctly for each question
    question_scoring_dict = {}
    preds = list(zip(*items))[0]
    docs = list(zip(*items))[1]
    for doc, pred in zip(docs, preds):
        paragraph_id = doc["idx"]["paragraph"]
        question_id = doc["idx"]["question"]
        if (paragraph_id, question_id) not in question_scoring_dict:
            question_scoring_dict[(paragraph_id, question_id)] = []
        gold_label = doc["label"] == 1
        question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)
    acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
    return acc
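The grouping in acc_all only credits a question when every one of its candidate answers is judged correctly; a hypothetical standalone illustration of that rule (item layout mirrors the code above, values invented for the example):

import numpy as np

# Items are (pred, doc) pairs; a question scores 1 only if all of its
# answers are classified correctly. Data below is invented for illustration.
items = [
    (True,  {"idx": {"paragraph": 0, "question": 0}, "label": 1}),
    (False, {"idx": {"paragraph": 0, "question": 0}, "label": 0}),
    (True,  {"idx": {"paragraph": 0, "question": 1}, "label": 0}),  # one miss
]

question_scores = {}
for pred, doc in items:
    key = (doc["idx"]["paragraph"], doc["idx"]["question"])
    question_scores.setdefault(key, []).append((doc["label"] == 1) == pred)

print(np.mean([int(all(v)) for v in question_scores.values()]))  # 0.5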
# @register_metric(
#     metric="mcc",
#     higher_is_better=True,
#     output_type="multiple_choice",
#     aggregation="matthews_corrcoef",
# )
# def mcc_fn(items):  # This is a passthrough function
#     return items


# @register_metric(
#     metric="f1",
#     higher_is_better=True,
#     output_type="multiple_choice",
#     aggregation="f1",
# )
# def f1_fn(items):  # This is a passthrough function
#     return items


# @register_metric(
#     metric="bleu",
#     higher_is_better=True,
#     output_type="generate_until",
#     aggregation="bleu",
# )
# def bleu_fn(items):  # This is a passthrough function
#     return items


# @register_metric(
#     metric="chrf",
#     higher_is_better=True,
#     output_type="generate_until",
#     aggregation="chrf",
# )
# def chrf_fn(items):  # This is a passthrough function
#     return items


# @register_metric(
#     metric="ter",
#     higher_is_better=True,
#     output_type="generate_until",
#     aggregation="ter",
# )
# def ter_fn(items):  # This is a passthrough function
#     return items


# @register_metric(
#     metric="acc_all",
#     higher_is_better=True,
#     output_type="loglikelihood",
#     aggregation="mean",
# )
# def acc_all(items):
#     # Only count as correct if all answers are labeled correctly for each question
#     question_scoring_dict = {}
#     preds = list(zip(*items))[0]
#     docs = list(zip(*items))[1]
#     for doc, pred in zip(docs, preds):
#         paragraph_id = doc["idx"]["paragraph"]
#         question_id = doc["idx"]["question"]
#         if (paragraph_id, question_id) not in question_scoring_dict:
#             question_scoring_dict[(paragraph_id, question_id)] = []
#         gold_label = doc["label"] == 1
#         question_scoring_dict[(paragraph_id, question_id)].append(gold_label == pred)
#     acc = np.mean([int(all(x)) for x in question_scoring_dict.values()])
#     return acc
def acc_all_stderr(items):