Unverified commit 967eb4fa authored by Lintang Sutawika, committed by GitHub

[Refactor] Continuous Metrics (#969)

* add brier_score

* process brier_score

* brier score is working for N-sized class

* fixed brier score

* add TED to BigBench and Brier score to MMLU

* format

* Update metrics.py

* Update task.py

* Update generate_until_template_yaml

* Delete lm_eval/tasks/bigbench/aux_metric.py

* Update generate_until_template_yaml

* Update _default_template_yaml
parent c9bbec6e
metrics.py
@@ -109,6 +109,25 @@ def ter(items):
    return sacrebleu.corpus_ter(preds, refs).score

@register_aggregation("brier_score")
def brier_score(items):
    # Unzip the (gold_index, prob_vector) pairs collected per document.
    gold, predictions = list(zip(*items))
    gold = list(gold)
    # One-hot encode the gold indices so they can be compared against the
    # predicted probability distributions.
    gold_one_hot = np.eye(np.max(gold) + 1)[gold]
    return np.mean(np.sum((np.array(predictions) - gold_one_hot) ** 2, axis=1))
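
For reference, a quick check of what this aggregation computes, using the
(gold, prob_norm) item layout produced in task.py below (illustrative values
only):

    items = [(0, [0.7, 0.2, 0.1]), (2, [0.1, 0.3, 0.6])]
    # one-hot golds are [1, 0, 0] and [0, 0, 1]; the per-item squared
    # errors sum to 0.14 and 0.26, so the mean Brier score is 0.2
    brier_score(items)  # -> 0.2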

@register_metric(
    metric="brier_score",
    higher_is_better=False,
    output_type=["multiple_choice"],
    aggregation="brier_score",
)
def brier_score_fn(items):  # This is a passthrough function
    return items
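
Note on the pattern (a reading of the harness convention, not part of the
diff): the per-instance metric function is a true passthrough that returns
each (gold, prob_norm) tuple unchanged, and the registered aggregation
receives the full list and does the actual computation, since a Brier score
is only meaningful over a collection of instances. Roughly:

    per_doc = [brier_score_fn(pair) for pair in pairs]  # identity
    score = brier_score(per_doc)                        # the real work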

@register_metric(
    metric="acc",
    higher_is_better=True,
...
task.py
@@ -1095,12 +1095,21 @@ class ConfigurableTask(Task):
        # TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
        exact_match = int(is_greedy[gold]) if gold != -100 else 0

        prob_norm = utils.softmax(lls)

        # TODO use keyword arguments to the metric?
        # gold, pred, norm stuff, the original lls,
        result_dict = {
            **({"acc": acc} if "acc" in use_metric else {}),
            **({"f1": (gold, pred)} if "f1" in use_metric else {}),
            **({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
            **({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
            **({"exact_match": exact_match} if "exact_match" in use_metric else {}),
            **(
                {"brier_score": (gold, prob_norm)}
                if "brier_score" in use_metric
                else {}
            ),
        }

        if "acc_mutual_info" in use_metric:
...
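
Data-flow note (a paraphrase, not part of the diff): lls holds the model's
log-likelihood for each answer choice, so utils.softmax(lls) turns them into
a normalized probability distribution; each document then contributes one
(gold, prob_norm) pair, which is exactly the item shape the brier_score
aggregation above consumes. Illustrative values only:

    lls = [-1.2, -0.3, -2.0]        # per-choice log-likelihoods
    prob_norm = utils.softmax(lls)  # approx. [0.26, 0.63, 0.11]
    item = (gold, prob_norm)        # one entry of `items` in brier_score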

utils.py
@@ -15,6 +15,7 @@ from typing import Iterator, List, Literal, Union, Any, Callable
import gc

import torch
import transformers
import numpy as np
from jinja2 import BaseLoader, Environment, StrictUndefined
from itertools import islice

@@ -167,6 +168,12 @@ def pattern_match(patterns, source_list):
    return sorted(list(task_names))

def softmax(x):
    """Compute softmax values for each set of scores in x."""
    e_x = np.exp(x - np.max(x))
    return e_x / e_x.sum()
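
Worth noting (a gloss, not part of the diff): subtracting np.max(x) before
exponentiating leaves the output unchanged, because softmax is invariant to
shifting all inputs by a constant, but it keeps np.exp from overflowing on
large logits:

    softmax([1000.0, 1001.0])  # stable: approx. [0.269, 0.731]
    # a naive np.exp([1000.0, 1001.0]) would overflow to inf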

def general_detokenize(string):
    string = string.replace(" n't", "n't")
    string = string.replace(" )", ")")
...