utils.py 2.64 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
def calculate_measure(tp, fn, fp):
    """Compute precision, recall and F1 from confusion counts.

    Args:
        tp: number of true positives (int or 0-d tensor).
        fn: number of false negatives.
        fp: number of false positives.

    Returns:
        Tuple ``(precision, recall, f1)``. When ``tp`` is 0, all three are
        ``0.0``. This is the same float type the normal path produces, not
        the integer zeros returned previously.
    """
    # avoid nan: with tp == 0 the ratios below could be 0/0 (no predicted or
    # no gold positives), so report every metric as 0.0 by convention.
    if tp == 0:
        return 0.0, 0.0, 0.0

    # `* 1.0` forces true division / float results for integer counts.
    p = tp * 1.0 / (tp + fp)
    r = tp * 1.0 / (tp + fn)
    # Defensive guard: with non-negative counts and tp > 0 this is always
    # true, but it keeps malformed inputs from raising ZeroDivisionError.
    f1 = 2.0 * (p * r) / (p + r) if (p + r) > 0 else 0.0
    return p, r, f1


class Measure(object):
    """Accumulate per-class confusion counts over batches and derive P/R/F1.

    Each call to ``append_measures`` appends one TP/FN/FP entry per class, so
    the lists can be summarized per entry (``get_each_timestamp_measure``) or
    summed (``get_total_measure``).
    """

    def __init__(self, num_classes, target_class):
        """

        Args:
            num_classes: number of classes.
            target_class: target class we focus on, used to print info and do early stopping.
        """
        self.num_classes = num_classes
        self.target_class = target_class
        # Best F1 recorded for target_class and the epoch it happened in.
        # Maintained by update_best_f1; reset_info does not clear these.
        self.target_best_f1 = 0.0
        self.target_best_f1_epoch = 0
        # Creates true_positives / false_positives / false_negatives.
        self.reset_info()

    def reset_info(self):
        """Clear the accumulated counts (reset info after each epoch).

        Each dict maps a class index in ``range(num_classes)`` to a list that
        receives one count per ``append_measures`` call.
        """
        self.true_positives = {cl: [] for cl in range(self.num_classes)}
        self.false_positives = {cl: [] for cl in range(self.num_classes)}
        self.false_negatives = {cl: [] for cl in range(self.num_classes)}

    def append_measures(self, predictions, labels):
        """Record TP/FN/FP counts for every class from one batch.

        Args:
            predictions: tensor of class scores; the predicted class is the
                argmax over dim 1 (presumably shape (batch, num_classes) --
                confirm against callers).
            labels: tensor of gold class indices, same shape as the argmax
                of ``predictions``.
        """
        predicted_classes = predictions.argmax(dim=1)
        for cl in range(self.num_classes):
            cl_indices = (labels == cl)  # mask: gold label is cl
            pos = (predicted_classes == cl)  # mask: predicted as cl
            hits = (predicted_classes[cl_indices] == labels[cl_indices])

            tp = hits.sum()
            fn = hits.size(0) - tp  # gold-cl samples not predicted as cl
            fp = pos.sum() - tp  # predicted-cl samples whose gold is not cl

            # Counts are stored via .cpu() so later summing happens on CPU.
            self.true_positives[cl].append(tp.cpu())
            self.false_negatives[cl].append(fn.cpu())
            self.false_positives[cl].append(fp.cpu())

    def get_each_timestamp_measure(self, cl=None):
        """Return per-entry precision, recall and F1 lists for one class.

        Args:
            cl: class index to report; defaults to ``self.target_class``.

        Returns:
            ``(precisions, recalls, f1s)``, one element per recorded entry.
        """
        if cl is None:
            cl = self.target_class
        precisions = []
        recalls = []
        f1s = []
        # The three lists are always appended together, so they have equal
        # length and zip walks them in lockstep.
        for tp, fn, fp in zip(self.true_positives[cl],
                              self.false_negatives[cl],
                              self.false_positives[cl]):
            p, r, f1 = calculate_measure(tp, fn, fp)
            precisions.append(p)
            recalls.append(r)
            f1s.append(f1)
        return precisions, recalls, f1s

    def get_total_measure(self, cl=None):
        """Return precision, recall and F1 over all recorded entries.

        Args:
            cl: class index to report; defaults to ``self.target_class``.

        Returns:
            ``(precision, recall, f1)`` computed from the summed counts.
        """
        if cl is None:
            cl = self.target_class
        tp = sum(self.true_positives[cl])
        fn = sum(self.false_negatives[cl])
        fp = sum(self.false_positives[cl])

        return calculate_measure(tp, fn, fp)

    def update_best_f1(self, cur_f1, cur_epoch):
        """Remember ``cur_f1``/``cur_epoch`` if it strictly beats the best F1."""
        if cur_f1 > self.target_best_f1:
            self.target_best_f1 = cur_f1
            self.target_best_f1_epoch = cur_epoch