Commit 0a39d055 authored by lintangsutawika
Browse files

format

parent 2b7d8c2d
......@@ -111,9 +111,9 @@ def ter(items):
def brier_score(items):
    """Return the mean multi-class Brier score over ``items``.

    Args:
        items: Iterable of ``(gold, prediction)`` pairs. ``gold`` is used as
            an integer row index into an identity matrix (i.e. a class
            index); ``prediction`` is a per-class score vector, and all
            vectors must have the same length.

    Returns:
        The mean, over all items, of the summed squared difference between
        each prediction vector and the one-hot encoding of its gold label.
    """
    gold, predictions = zip(*items)
    predictions = np.asarray(predictions)
    # Size the one-hot encoding from the prediction width rather than from
    # max(gold): a batch that happens not to contain the highest class would
    # otherwise produce a narrower matrix and fail to broadcast.
    num_classes = predictions.shape[1]
    gold_one_hot = np.eye(num_classes)[list(gold)]
    return np.mean(np.sum((predictions - gold_one_hot) ** 2, axis=1))
@register_metric(
......
......@@ -1073,7 +1073,11 @@ class ConfigurableTask(Task):
**({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
**({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
**({"exact_match": exact_match} if "exact_match" in use_metric else {}),
**({"brier_score": (gold, prob_norm)} if "brier_score" in use_metric else {}),
**(
{"brier_score": (gold, prob_norm)}
if "brier_score" in use_metric
else {}
),
}
if "acc_mutual_info" in use_metric:
......
......@@ -2,9 +2,13 @@ from textdistance import levenshtein
from transformers import AutoTokenizer
# Module-level tokenizer used by token_edit_distance below.
# Change this tokenizer to fit with the model you are using.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-2.8b")
# NOTE(review): this second assignment rebinds `tokenizer`, which makes the
# load above redundant. `max_new_tokens` looks like a generation setting rather
# than a tokenizer option — confirm it has any effect here.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-2.8b", max_new_tokens=128)
def token_edit_distance(references, predictions, **kwargs):
    """Return the token-level Levenshtein distance between two texts.

    Only the first element of ``references`` and the first element of
    ``predictions`` are compared; any further elements are ignored.

    Args:
        references: Sequence whose first element is the reference text.
        predictions: Sequence whose first element is the predicted text.
        **kwargs: Accepted for call-site compatibility; unused.

    Returns:
        The Levenshtein distance between the token-id sequences that the
        module-level ``tokenizer`` produces for the two texts.
    """
    ref_tokens = tokenizer.encode(references[0])
    pred_tokens = tokenizer.encode(predictions[0])
    return levenshtein.distance(ref_tokens, pred_tokens)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment