t5_utils.py 1.5 KB
Newer Older
lintangsutawika's avatar
lintangsutawika committed
1
2
3
4
import collections

import numpy as np

lintangsutawika's avatar
lintangsutawika committed
5

lintangsutawika's avatar
lintangsutawika committed
6
7
8
def f1(predictions, references):
    """Convert one example's text prediction and reference into 0/1 labels.

    Both inputs are parsed (this is not a passthrough). The returned pair has
    the ``(prediction, reference)`` shape that ``agg_f1`` unpacks.

    Args:
        predictions: Sequence whose first element is the model's generated
            text; a valid prediction is exactly ``"False"`` or ``"True"``.
        references: Sequence whose first element is the gold target string.
            Only the text after the last ``"_"`` is used as the label, and it
            must be ``"False"`` or ``"True"``.

    Returns:
        A ``(prediction, reference)`` tuple of ints, where 0 stands for
        ``"False"`` and 1 for ``"True"``. A prediction that is not one of the
        two labels is mapped to the opposite of the reference, so it always
        counts as an error.

    Raises:
        ValueError: If the reference label is not ``"False"`` or ``"True"``.
    """
    string_label = ["False", "True"]
    _prediction = predictions[0]
    _reference = references[0].split("_")[-1]

    reference = string_label.index(_reference)
    # An unparseable prediction is scored as the wrong class. ``1 - reference``
    # (rather than ``not bool(reference)``) keeps the result an int in both
    # branches instead of sometimes returning a bool.
    prediction = (
        string_label.index(_prediction)
        if _prediction in string_label
        else 1 - reference
    )

    return (prediction, reference)

lintangsutawika's avatar
lintangsutawika committed
19

lintangsutawika's avatar
lintangsutawika committed
20
def agg_f1(items):
    """Compute an F1 score over all collected ``(prediction, reference)`` pairs.

    Args:
        items: Iterable of ``(prediction, reference)`` pairs, typically the
            tuples produced by ``f1``.

    Returns:
        ``sklearn.metrics.f1_score`` with the references as ``y_true`` and the
        predictions as ``y_pred`` (sklearn's default ``average="binary"``).
    """
    # The import is function-local, so sklearn is only needed when this
    # aggregator actually runs.
    from sklearn.metrics import f1_score

    # Transpose the list of pairs into two parallel columns.
    pred_column, gold_column = zip(*items)
    y_true = np.asarray(gold_column)
    y_pred = np.asarray(pred_column)
    return f1_score(y_true, y_pred)
lintangsutawika's avatar
lintangsutawika committed
27
28
29
30
31


def em(predictions, references):
    """Parse one example into ``(group, prediction, reference)`` for exact match.

    Args:
        predictions: Sequence whose first element is the model's generated
            text; a valid prediction is exactly ``"False"`` or ``"True"``.
        references: Sequence whose first element has the form
            ``"<group>_<label>"``, where label is ``"False"`` or ``"True"``.
            The group id may itself contain underscores; only the last
            ``"_"`` separates it from the label.

    Returns:
        A ``(group, prediction, reference)`` tuple. ``group`` is the string
        before the last ``"_"``. ``prediction`` and ``reference`` are ints,
        where 0 stands for ``"False"`` and 1 for ``"True"``. A prediction that
        is not one of the two labels is mapped to the opposite of the
        reference, so it always counts as an error.

    Raises:
        ValueError: If the reference contains no ``"_"`` or its label is not
            ``"False"`` or ``"True"``.
    """
    string_label = ["False", "True"]
    _prediction = predictions[0]
    # Split only on the last "_": a plain split("_") fails to unpack when the
    # group id contains an underscore. Taking the last segment also matches
    # how f1 extracts its label.
    _group, _reference = references[0].rsplit("_", 1)

    reference = string_label.index(_reference)
    # An unparseable prediction is scored as the wrong class. ``1 - reference``
    # keeps the result an int in both branches instead of sometimes a bool.
    prediction = (
        string_label.index(_prediction)
        if _prediction in string_label
        else 1 - reference
    )

    return (_group, prediction, reference)
lintangsutawika's avatar
lintangsutawika committed
41
42
43
44
45
46
47
48
49
50
51
52
53


def agg_em(items):
    """Score exact match per group, then average the group scores.

    A group scores 1.0 only if its sequence of predictions equals its sequence
    of references element by element (in the order the items arrive);
    otherwise it scores 0.0.

    Args:
        items: Iterable of ``(group, prediction, reference)`` triples, in the
            shape returned by ``em``.

    Returns:
        The ``np.mean`` of the per-group scores.
    """
    per_group = collections.defaultdict(lambda: ([], []))
    for group_id, pred, gold in items:
        golds, preds = per_group[group_id]
        golds.append(gold)
        preds.append(pred)

    scores = [
        float(np.array_equal(golds, preds)) for golds, preds in per_group.values()
    ]
    return np.mean(scores)