t5_utils.py 1.5 KB
Newer Older
lintangsutawika's avatar
lintangsutawika committed
1
2
3
4
5
import collections

import numpy as np
import sklearn.metrics

lintangsutawika's avatar
lintangsutawika committed
6

lintangsutawika's avatar
lintangsutawika committed
7
8
9
10
def f1(predictions, references):  # This is a passthrough function
    """Map one prediction/reference pair to integer labels.

    Labels are encoded by their position in ``["False", "True"]``, i.e.
    ``"False"`` -> 0 and ``"True"`` -> 1.

    Args:
        predictions: Sequence whose first element is the predicted string.
        references: Sequence whose first element is the reference string.
            Only the text after the last ``"_"`` is used as the label, so
            any prefix is ignored here.

    Returns:
        A ``(prediction, reference)`` tuple of ints in ``{0, 1}``.

    Raises:
        IndexError: If ``predictions`` or ``references`` is empty.
        ValueError: If the reference label is neither ``"False"`` nor
            ``"True"``.
    """
    string_label = ["False", "True"]

    _prediction = predictions[0]
    _reference = references[0].split("_")[-1]
    reference = string_label.index(_reference)

    if _prediction in string_label:
        prediction = string_label.index(_prediction)
    else:
        # An unrecognised prediction is scored as the wrong answer. Return
        # the opposite label as an int (not a bool) so that every code path
        # yields the same label type.
        prediction = 1 - reference

    return (prediction, reference)

lintangsutawika's avatar
lintangsutawika committed
21

lintangsutawika's avatar
lintangsutawika committed
22
23
24
25
26
27
28
29
30
31
32
33
def agg_f1(items):
    """Compute the F1 score over a collection of label pairs.

    Args:
        items: Iterable of ``(prediction, reference)`` pairs.

    Returns:
        The value returned by ``sklearn.metrics.f1_score`` when called with
        the references as the first argument and the predictions as the
        second.
    """
    predicted, expected = zip(*items)
    y_true = np.asarray(expected)
    y_pred = np.asarray(predicted)
    return sklearn.metrics.f1_score(y_true, y_pred)


def em(predictions, references):  # This is a passthrough function
    """Map one prediction/reference pair to a group id and integer labels.

    The reference string has the form ``"<group>_<label>"``. Labels are
    encoded by their position in ``["False", "True"]``, i.e. ``"False"`` -> 0
    and ``"True"`` -> 1.

    Args:
        predictions: Sequence whose first element is the predicted string.
        references: Sequence whose first element is the reference string.

    Returns:
        A ``(group, prediction, reference)`` tuple, where ``group`` is the
        text before the last ``"_"`` and the two labels are ints in
        ``{0, 1}``.

    Raises:
        IndexError: If ``predictions`` or ``references`` is empty.
        ValueError: If the reference contains no ``"_"`` or its label is
            neither ``"False"`` nor ``"True"``.
    """
    string_label = ["False", "True"]

    _prediction = predictions[0]
    # Split on the last underscore only, so a group id that itself contains
    # underscores does not break the two-way unpacking.
    _group, _reference = references[0].rsplit("_", 1)
    reference = string_label.index(_reference)

    if _prediction in string_label:
        prediction = string_label.index(_prediction)
    else:
        # An unrecognised prediction is scored as the wrong answer. Return
        # the opposite label as an int (not a bool) so that every code path
        # yields the same label type.
        prediction = 1 - reference

    return (_group, prediction, reference)
lintangsutawika's avatar
lintangsutawika committed
43
44
45
46
47
48
49
50
51
52
53
54
55


def agg_em(items):
    """Compute the mean per-group exact-match score.

    A group scores 1.0 when every prediction in it equals the corresponding
    reference, and 0.0 otherwise.

    Args:
        items: Iterable of ``(group, prediction, reference)`` triples.

    Returns:
        The ``np.mean`` of the per-group scores.
    """
    grouped_values = collections.defaultdict(lambda: ([], []))
    for group, prediction, reference in items:
        targets, predictions = grouped_values[group]
        targets.append(reference)
        predictions.append(prediction)

    # Only the collected label lists matter for scoring, not the group keys.
    group_scores = [
        float(np.array_equal(targets, predictions))
        for targets, predictions in grouped_values.values()
    ]

    return np.mean(group_scores)