Commit 5cc65a79 authored by lintangsutawika's avatar lintangsutawika
Browse files

fixed brier score

parent d49636a3
...@@ -159,18 +159,6 @@ def acc_mutual_info_fn(items): # This is a passthrough function ...@@ -159,18 +159,6 @@ def acc_mutual_info_fn(items): # This is a passthrough function
exact_match = evaluate.load("exact_match") exact_match = evaluate.load("exact_match")
# @register_metric(
# metric="token_edit_distance",
# higher_is_better=False,
# output_type=["generate_until"],
# aggregation="mean",
# )
# def ted_fn(items): # This is a passthrough function
# references, predictions = items
# return distance(references, predictions)
@register_metric( @register_metric(
metric="exact_match", metric="exact_match",
higher_is_better=True, higher_is_better=True,
......
...@@ -1063,8 +1063,10 @@ class ConfigurableTask(Task): ...@@ -1063,8 +1063,10 @@ class ConfigurableTask(Task):
# TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly # TODO: this gets score of 0 on arc_challenge for pythia-70m. need to test that this works properly
exact_match = int(is_greedy[gold]) if gold != -100 else 0 exact_match = int(is_greedy[gold]) if gold != -100 else 0
prob_norm = [float(i)/sum(lls) for i in lls] prob_norm = utils.softmax(lls)
# TODO use keyword arguments to the metric?
# gold, pred, norm stuff, the original lls,
result_dict = { result_dict = {
**({"acc": acc} if "acc" in use_metric else {}), **({"acc": acc} if "acc" in use_metric else {}),
**({"f1": (gold, pred)} if "f1" in use_metric else {}), **({"f1": (gold, pred)} if "f1" in use_metric else {}),
......
...@@ -15,6 +15,7 @@ from typing import Iterator, List, Literal, Union ...@@ -15,6 +15,7 @@ from typing import Iterator, List, Literal, Union
import gc import gc
import torch import torch
import transformers import transformers
import numpy as np
from jinja2 import BaseLoader, Environment, StrictUndefined from jinja2 import BaseLoader, Environment, StrictUndefined
from itertools import islice from itertools import islice
...@@ -127,6 +128,12 @@ def pattern_match(patterns, source_list): ...@@ -127,6 +128,12 @@ def pattern_match(patterns, source_list):
return sorted(list(task_names)) return sorted(list(task_names))
def softmax(x, axis=None):
    """Return the softmax of the scores in ``x``.

    Args:
        x: Array-like of scores (anything ``np.asarray`` accepts).
        axis: Axis along which to normalise. With the default ``None`` the
            whole array is treated as a single set of scores, so all entries
            of the result sum to 1. Pass an axis (e.g. ``axis=-1``) to
            normalise each slice along that axis independently.

    Returns:
        A NumPy array with the same shape as ``x`` whose entries along
        ``axis`` (or over the whole array when ``axis`` is ``None``) sum to 1.

    Raises:
        ValueError: If ``x`` is empty (raised by ``np.max``).
    """
    scores = np.asarray(x)
    # Subtracting the maximum does not change the result mathematically but
    # keeps np.exp from overflowing on large scores.
    shifted = scores - np.max(scores, axis=axis, keepdims=True)
    e_x = np.exp(shifted)
    # keepdims=True lets the normaliser broadcast back against e_x.
    return e_x / e_x.sum(axis=axis, keepdims=True)
def general_detokenize(string): def general_detokenize(string):
string = string.replace(" n't", "n't") string = string.replace(" n't", "n't")
string = string.replace(" )", ")") string = string.replace(" )", ")")
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment