utils.py
from collections import Counter
from string import punctuation

import numpy as np


def normalize(text):
    """Lowercase, strip surrounding whitespace, and remove all punctuation."""
    exclude = set(punctuation)
    return "".join(ch for ch in text if ch not in exclude).lower().strip()


def f1(prediction, completion):
    """Token-level F1 between a prediction and a single gold completion (SQuAD-style)."""
    gold_toks = completion.split()
    pred_toks = prediction.split()
    common = Counter(gold_toks) & Counter(pred_toks)
    num_same = sum(common.values())
    if len(gold_toks) == 0 or len(pred_toks) == 0:
        # If either side is empty, F1 is 1 only when both are empty.
        return float(gold_toks == pred_toks)
    if num_same == 0:
        return 0.0
    precision = num_same / len(pred_toks)
    recall = num_same / len(gold_toks)
    return (2 * precision * recall) / (precision + recall)


def process_results(doc, results):
    """Score the first generated result against every accepted completion.

    Returns the best exact match ("em") and the best token F1 ("fscore")
    over all accepted references.
    """
    prediction = normalize(results[0])
    completions = [normalize(completion) for completion in doc["accepted_completions"]]
    exact_match = np.nanmax(
        [int(prediction == completion) for completion in completions]
    )
    fscore = np.nanmax(
        [f1(prediction=prediction, completion=completion) for completion in completions]
    )
    return {"em": exact_match, "fscore": fscore}


def filter_dataset_nb(dataset):
    """Keep only Norwegian Bokmål examples (ISO 639-3 code "nob")."""
    return dataset.filter(lambda example: example["language"] == "nob")


def filter_dataset_nn(dataset):
    """Keep only Norwegian Nynorsk examples (ISO 639-3 code "nno")."""
    return dataset.filter(lambda example: example["language"] == "nno")
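

# Minimal usage sketch (not part of the original module): the doc/results shapes
# below are assumptions for illustration; process_results is typically invoked by
# an evaluation harness with one dataset example and the model's generations.
if __name__ == "__main__":
    example_doc = {"accepted_completions": ["Oslo er hovedstaden i Norge.", "Oslo"]}
    model_results = ["Oslo!"]
    scores = process_results(example_doc, model_results)
    print(scores)  # best exact match is 1, best token F1 is 1.0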