Commit 0a39d055 authored by lintangsutawika's avatar lintangsutawika
Browse files

format

parent 2b7d8c2d
...@@ -111,9 +111,9 @@ def ter(items): ...@@ -111,9 +111,9 @@ def ter(items):
def brier_score(items): # This is a passthrough function def brier_score(items): # This is a passthrough function
gold, predictions = list(zip(*items)) gold, predictions = list(zip(*items))
gold = list(gold) gold = list(gold)
gold_one_hot = np.eye(np.max(gold)+1)[gold] gold_one_hot = np.eye(np.max(gold) + 1)[gold]
predictions = list(zip(*items))[1] predictions = list(zip(*items))[1]
return np.mean(np.sum((predictions - gold_one_hot)**2, axis=1)) return np.mean(np.sum((predictions - gold_one_hot) ** 2, axis=1))
@register_metric( @register_metric(
......
...@@ -1073,7 +1073,11 @@ class ConfigurableTask(Task): ...@@ -1073,7 +1073,11 @@ class ConfigurableTask(Task):
**({"mcc": (gold, pred)} if "mcc" in use_metric else {}), **({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
**({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}), **({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
**({"exact_match": exact_match} if "exact_match" in use_metric else {}), **({"exact_match": exact_match} if "exact_match" in use_metric else {}),
**({"brier_score": (gold, prob_norm)} if "brier_score" in use_metric else {}), **(
{"brier_score": (gold, prob_norm)}
if "brier_score" in use_metric
else {}
),
} }
if "acc_mutual_info" in use_metric: if "acc_mutual_info" in use_metric:
......
from textdistance import levenshtein
from transformers import AutoTokenizer

# Change this tokenizer to fit with the model you are using.
# NOTE(review): `max_new_tokens` is a text-*generation* argument; passing it
# to `AutoTokenizer.from_pretrained` has no effect on tokenization (it is
# merely stored in the tokenizer's init kwargs), so it is not passed here.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-2.8b")
def token_edit_distance(references, predictions, **kwargs):
    """Token-level Levenshtein edit distance between reference and prediction.

    Encodes the first reference and the first prediction with the
    module-level `tokenizer` and returns the Levenshtein distance between
    the two token-id sequences (lower is better, 0 means identical
    tokenizations).

    :param references: sequence of reference strings; only element 0 is used.
    :param predictions: sequence of predicted strings; only element 0 is used.
    :param kwargs: ignored, accepted for metric-interface compatibility.
    :return: int, the edit distance between the two token sequences.
    """
    # Debug print statements removed — they spammed stdout on every call.
    ref_tokens = tokenizer.encode(references[0])
    pred_tokens = tokenizer.encode(predictions[0])
    return levenshtein.distance(ref_tokens, pred_tokens)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment