test_readout.py 3.77 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
import dgl
import torch
import torch.nn.functional as F

from dgl import DGLGraph
from dgllife.model.readout import *

def test_graph1():
    """Build a single graph together with per-node features.

    Returns
    -------
    tuple
        ``(g, feats)`` where ``g`` is a DGLGraph constructed from the edges
        (0, 1), (0, 2), (1, 2) and ``feats`` is a float tensor with one row
        per node and a single column holding the node index.
    """
    edges = [(0, 1), (0, 2), (1, 2)]
    graph = DGLGraph(edges)
    num_nodes = graph.number_of_nodes()
    # Column vector 0, 1, ..., num_nodes - 1 used as scalar node features.
    feats = torch.arange(num_nodes).float().reshape(-1, 1)
    return graph, feats

def test_graph2():
    """Build a batch of two graphs together with per-node features.

    Returns
    -------
    tuple
        ``(bg, feats)`` where ``bg`` is the result of ``dgl.batch`` on two
        DGLGraphs and ``feats`` is a float tensor with one row per node of
        the batched graph and a single column holding the node index.
    """
    first = DGLGraph([(0, 1), (0, 2), (1, 2)])
    second = DGLGraph([(0, 1), (1, 2), (1, 3), (1, 4)])
    batched = dgl.batch([first, second])
    total_nodes = batched.number_of_nodes()
    # Column vector 0, 1, ..., total_nodes - 1 used as scalar node features.
    feats = torch.arange(total_nodes).float().reshape(-1, 1)
    return batched, feats

def test_weighted_sum_and_max():
    """Check the output shape of WeightedSumAndMax on single and batched graphs."""
    # Use the first GPU when CUDA is available, otherwise fall back to the CPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    single_graph, single_feats = test_graph1()
    single_graph = single_graph.to(device)
    single_feats = single_feats.to(device)
    batched_graph, batched_feats = test_graph2()
    batched_graph = batched_graph.to(device)
    batched_feats = batched_feats.to(device)

    readout = WeightedSumAndMax(in_feats=1).to(device)

    # Expected: one output row per graph and two columns for in_feats=1.
    cases = (
        (single_graph, single_feats, torch.Size([1, 2])),
        (batched_graph, batched_feats, torch.Size([2, 2])),
    )
    for graph, feats, expected_shape in cases:
        assert readout(graph, feats).shape == expected_shape

def test_attentive_fp_readout():
    """Check the output shape of AttentiveFPReadout on single and batched graphs."""
    # Use the first GPU when CUDA is available, otherwise fall back to the CPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    single_graph, single_feats = test_graph1()
    single_graph = single_graph.to(device)
    single_feats = single_feats.to(device)
    batched_graph, batched_feats = test_graph2()
    batched_graph = batched_graph.to(device)
    batched_feats = batched_feats.to(device)

    readout = AttentiveFPReadout(feat_size=1, num_timesteps=1).to(device)

    # Expected: one output row per graph and feat_size (= 1) columns.
    cases = (
        (single_graph, single_feats, torch.Size([1, 1])),
        (batched_graph, batched_feats, torch.Size([2, 1])),
    )
    for graph, feats, expected_shape in cases:
        assert readout(graph, feats).shape == expected_shape

def test_mlp_readout():
    """Check the output shape of MLPNodeReadout for the sum, max and mean modes."""
    # Use the first GPU when CUDA is available, otherwise fall back to the CPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    single_graph, single_feats = test_graph1()
    single_graph = single_graph.to(device)
    single_feats = single_feats.to(device)
    batched_graph, batched_feats = test_graph2()
    batched_graph = batched_graph.to(device)
    batched_feats = batched_feats.to(device)

    # Only the 'sum' configuration passes an explicit activation; the other
    # two rely on the constructor's default for that argument.
    configurations = (
        ('sum', {'activation': F.relu}),
        ('max', {}),
        ('mean', {}),
    )
    for mode, extra_kwargs in configurations:
        readout = MLPNodeReadout(node_feats=1,
                                 hidden_feats=2,
                                 graph_feats=3,
                                 mode=mode,
                                 **extra_kwargs).to(device)
        # Expected: one output row per graph and graph_feats (= 3) columns.
        assert readout(single_graph, single_feats).shape == torch.Size([1, 3])
        assert readout(batched_graph, batched_feats).shape == torch.Size([2, 3])

82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
def test_weave_readout():
    """Check the output shape of WeaveGather with and without gaussian_expand=False."""
    # Use the first GPU when CUDA is available, otherwise fall back to the CPU.
    device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

    single_graph, single_feats = test_graph1()
    single_graph = single_graph.to(device)
    single_feats = single_feats.to(device)
    batched_graph, batched_feats = test_graph2()
    batched_graph = batched_graph.to(device)
    batched_feats = batched_feats.to(device)

    # First the constructor defaults, then gaussian_expand explicitly disabled.
    for extra_kwargs in ({}, {'gaussian_expand': False}):
        readout = WeaveGather(node_in_feats=1, **extra_kwargs).to(device)
        # Expected: one output row per graph and a single column.
        assert readout(single_graph, single_feats).shape == torch.Size([1, 1])
        assert readout(batched_graph, batched_feats).shape == torch.Size([2, 1])

101
102
103
104
# Allow running the readout tests directly as a script, without pytest.
if __name__ == '__main__':
    test_weighted_sum_and_max()
    test_attentive_fp_readout()
    test_mlp_readout()
    test_weave_readout()