"tests/vscode:/vscode.git/clone" did not exist on "1d2291943e2b0f3862f202c1b123dbb15e0c76f3"
test_eval.py 4.41 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import numpy as np
import torch

from dgllife.utils.eval import *

def test_Meter():
    """Test Meter for accumulating predictions and computing metrics.

    Exercises pearson r2, mae, rmse and roc auc score, each with and
    without a binary mask over (datapoint, task) entries, via both the
    metric-specific methods and the generic ``compute_metric`` entry
    point with 'mean'/'sum' reductions.
    """
    label = torch.tensor([[0., 1.],
                          [0., 1.],
                          [1., 0.]])
    pred = torch.tensor([[0.5, 0.5],
                         [0., 1.],
                         [1., 0.]])
    # 1 marks an (entry, task) pair that participates in metric computation.
    mask = torch.tensor([[1., 0.],
                         [0., 1.],
                         [1., 1.]])
    # Per-task statistics for Meter's optional label (de)normalization.
    label_mean, label_std = label.mean(dim=0), label.std(dim=0)

    # pearson r2
    meter = Meter(label_mean, label_std)
    meter.update(pred, label)
    true_scores = [0.7500000774286983, 0.7500000516191412]
    assert meter.pearson_r2() == true_scores
    assert meter.pearson_r2('mean') == np.mean(true_scores)
    assert meter.pearson_r2('sum') == np.sum(true_scores)
    assert meter.compute_metric('r2') == true_scores
    assert meter.compute_metric('r2', 'mean') == np.mean(true_scores)
    assert meter.compute_metric('r2', 'sum') == np.sum(true_scores)

    meter = Meter(label_mean, label_std)
    meter.update(pred, label, mask)
    true_scores = [1.0, 1.0]
    assert meter.pearson_r2() == true_scores
    assert meter.pearson_r2('mean') == np.mean(true_scores)
    assert meter.pearson_r2('sum') == np.sum(true_scores)
    assert meter.compute_metric('r2') == true_scores
    assert meter.compute_metric('r2', 'mean') == np.mean(true_scores)
    assert meter.compute_metric('r2', 'sum') == np.sum(true_scores)

    # mae
    meter = Meter()
    meter.update(pred, label)
    true_scores = [0.1666666716337204, 0.1666666716337204]
    assert meter.mae() == true_scores
    assert meter.mae('mean') == np.mean(true_scores)
    assert meter.mae('sum') == np.sum(true_scores)
    assert meter.compute_metric('mae') == true_scores
    assert meter.compute_metric('mae', 'mean') == np.mean(true_scores)
    assert meter.compute_metric('mae', 'sum') == np.sum(true_scores)

    meter = Meter()
    meter.update(pred, label, mask)
    true_scores = [0.25, 0.0]
    assert meter.mae() == true_scores
    assert meter.mae('mean') == np.mean(true_scores)
    assert meter.mae('sum') == np.sum(true_scores)
    assert meter.compute_metric('mae') == true_scores
    assert meter.compute_metric('mae', 'mean') == np.mean(true_scores)
    assert meter.compute_metric('mae', 'sum') == np.sum(true_scores)

    # rmse — compared with allclose since the scores are float-sensitive
    meter = Meter(label_mean, label_std)
    meter.update(pred, label)
    true_scores = [0.41068359261794546, 0.4106836107598449]
    assert torch.allclose(torch.tensor(meter.rmse()), torch.tensor(true_scores))
    assert torch.allclose(torch.tensor(meter.compute_metric('rmse')), torch.tensor(true_scores))

    meter = Meter(label_mean, label_std)
    meter.update(pred, label, mask)
    true_scores = [0.44433766459035057, 0.5019903799993205]
    assert torch.allclose(torch.tensor(meter.rmse()), torch.tensor(true_scores))
    assert torch.allclose(torch.tensor(meter.compute_metric('rmse')), torch.tensor(true_scores))

    # roc auc score
    meter = Meter()
    meter.update(pred, label)
    true_scores = [1.0, 1.0]
    assert meter.roc_auc_score() == true_scores
    assert meter.roc_auc_score('mean') == np.mean(true_scores)
    assert meter.roc_auc_score('sum') == np.sum(true_scores)
    assert meter.compute_metric('roc_auc_score') == true_scores
    assert meter.compute_metric('roc_auc_score', 'mean') == np.mean(true_scores)
    assert meter.compute_metric('roc_auc_score', 'sum') == np.sum(true_scores)

    meter = Meter()
    meter.update(pred, label, mask)
    true_scores = [1.0, 1.0]
    assert meter.roc_auc_score() == true_scores
    assert meter.roc_auc_score('mean') == np.mean(true_scores)
    assert meter.roc_auc_score('sum') == np.sum(true_scores)
    assert meter.compute_metric('roc_auc_score') == true_scores
    assert meter.compute_metric('roc_auc_score', 'mean') == np.mean(true_scores)
    assert meter.compute_metric('roc_auc_score', 'sum') == np.sum(true_scores)

def test_cases_with_undefined_scores():
    """Test metric computation when some per-task scores are undefined.

    The second task here has only positive labels, so its ROC AUC is
    undefined and the reported score list covers the first task only.
    """
    label = torch.tensor([[0., 1.],
                          [0., 1.],
                          [1., 1.]])
    pred = torch.tensor([[0.5, 0.5],
                         [0., 1.],
                         [1., 0.]])

    meter = Meter()
    meter.update(pred, label)

    expected = [1.0]
    assert meter.roc_auc_score() == expected
    # The reduced variants aggregate only over the defined tasks.
    for reduction, aggregate in (('mean', np.mean), ('sum', np.sum)):
        assert meter.roc_auc_score(reduction) == aggregate(expected)

# Allow running the test module directly without pytest.
if __name__ == '__main__':
    test_Meter()
    test_cases_with_undefined_scores()