Commit 0a39d055 authored by lintangsutawika's avatar lintangsutawika
Browse files

format

parent 2b7d8c2d
......@@ -50,7 +50,7 @@ dataset_kwargs: null # any extra keyword arguments that should be passed to the
```
dataset_path: json
dataset_name: null
dataset_kwargs:
dataset_kwargs:
data_files: /path/to/my/json
```
-------------------------------
......
......@@ -111,9 +111,9 @@ def ter(items):
def brier_score(items):  # This is a passthrough function
    """Compute the mean Brier score over (gold, prediction) pairs.

    Args:
        items: iterable of ``(gold_label, predicted_probabilities)`` tuples,
            where ``gold_label`` is a non-negative int class index and
            ``predicted_probabilities`` is an array-like probability vector.

    Returns:
        float: mean over items of ``sum((p - onehot(gold)) ** 2)`` —
        lower is better.
    """
    gold, predictions = list(zip(*items))
    gold = list(gold)
    # One-hot encode the gold labels; matrix width is inferred from the
    # largest label seen in this batch.
    gold_one_hot = np.eye(np.max(gold) + 1)[gold]
    # NOTE(review): the original re-derived `predictions = list(zip(*items))[1]`,
    # which is byte-identical to the tuple already unpacked above; dropped as
    # redundant.
    return np.mean(np.sum((predictions - gold_one_hot) ** 2, axis=1))
@register_metric(
......
......@@ -1066,14 +1066,18 @@ class ConfigurableTask(Task):
prob_norm = utils.softmax(lls)
# TODO use keyword arguments to the metric?
# gold, pred, norm stuff, the original lls,
# gold, pred, norm stuff, the original lls,
result_dict = {
**({"acc": acc} if "acc" in use_metric else {}),
**({"f1": (gold, pred)} if "f1" in use_metric else {}),
**({"mcc": (gold, pred)} if "mcc" in use_metric else {}),
**({"acc_norm": acc_norm} if "acc_norm" in use_metric else {}),
**({"exact_match": exact_match} if "exact_match" in use_metric else {}),
**({"brier_score": (gold, prob_norm)} if "brier_score" in use_metric else {}),
**(
{"brier_score": (gold, prob_norm)}
if "brier_score" in use_metric
else {}
),
}
if "acc_mutual_info" in use_metric:
......
......@@ -2,9 +2,13 @@ from textdistance import levenshtein
from transformers import AutoTokenizer
# Change this tokenizer to fit with the model you are using.
# NOTE(review): `max_new_tokens` is a text-*generation* argument, not a
# tokenizer parameter — AutoTokenizer.from_pretrained merely stashes unknown
# kwargs, so it has no effect here; confirm it can be dropped.
tokenizer = AutoTokenizer.from_pretrained("EleutherAI/pythia-2.8b", max_new_tokens=128)
def token_edit_distance(references, predictions, **kwargs):
    """Levenshtein edit distance between the token ids of the first
    reference and the first prediction.

    Args:
        references: sequence of reference strings (only the first is used).
        predictions: sequence of predicted strings (only the first is used).
        **kwargs: accepted for metric-interface compatibility; ignored.

    Returns:
        int: token-level edit distance (lower is better).
    """
    # Debug prints removed: they wrote every reference/prediction pair to
    # stdout on each metric invocation.
    ref_tokens = tokenizer.encode(references[0])
    pred_tokens = tokenizer.encode(predictions[0])
    return levenshtein.distance(ref_tokens, pred_tokens)
......@@ -16,4 +16,4 @@ metric_list:
ignore_punctuation: true
- metric: !function aux_metric.token_edit_distance # pip install textdistance
aggregation: mean
higher_is_better: false
\ No newline at end of file
higher_is_better: false
......@@ -17,4 +17,4 @@ metric_list:
higher_is_better: true
- metric: brier_score
aggregation: mean
higher_is_better: false
\ No newline at end of file
higher_is_better: false
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment