# gcnii.py (last committed by paoxiaode)
"""
[Simple and Deep Graph Convolutional Networks]
(https://arxiv.org/abs/2007.02133)
"""

import math

import dgl.sparse as dglsp
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl.data import CoraGraphDataset
from torch.optim import Adam


class GCNIIConvolution(nn.Module):
    """A single GCNII propagation layer.

    Given the layer input ``H`` and the initial representation ``H0``, it
    returns ``(1 - beta) * S + beta * W(S)``, where
    ``S = (1 - alpha) * (A_norm @ H) + alpha * H0`` (initial residual) and
    ``beta = log(lamda / l + 1)`` (identity-mapping strength for layer ``l``).

    Args:
        in_size: Input feature dimension of the linear map ``W``.
        out_size: Output feature dimension of the linear map ``W``.
    """

    def __init__(self, in_size, out_size):
        super().__init__()
        self.out_size = out_size
        # Bias-free linear transform; this is W in the formula above.
        self.weight = nn.Linear(in_size, out_size, bias=False)

    ############################################################################
    # (HIGHLIGHT) Take the advantage of DGL sparse APIs to implement the GCNII
    # forward process.
    ############################################################################
    def forward(self, A_norm, H, H0, lamda, alpha, l):
        """Apply the layer.

        Args:
            A_norm: Normalized adjacency matrix, multiplied on the left of ``H``.
            H: Input node representations of this layer.
            H0: Initial node representations, mixed in with weight ``alpha``.
            lamda: Hyper-parameter controlling ``beta``.
            alpha: Weight of the initial residual connection.
            l: 1-based layer index; ``beta`` shrinks as ``l`` grows.

        Returns:
            The updated node representations.
        """
        beta = math.log(lamda / l + 1)

        # Sparse @ dense neighbourhood aggregation, then the initial residual.
        aggregated = A_norm @ H
        support = (1 - alpha) * aggregated + alpha * H0

        # Identity mapping: interpolate between support and its transform.
        transformed = self.weight(support)
        return (1 - beta) * support + beta * transformed


class GCNII(nn.Module):
    """GCNII network: input projection, a stack of GCNII layers, classifier.

    Dropout is applied before the input projection, before every GCNII layer
    and before the output layer; ReLU follows the input projection and every
    GCNII layer.

    Args:
        in_size: Dimension of the input node features.
        out_size: Number of output logits per node.
        hidden_size: Width of every hidden representation.
        n_layers: Number of ``GCNIIConvolution`` layers.
        lamda: Passed to each GCNII layer to derive its ``beta``.
        alpha: Passed to each GCNII layer as the initial-residual weight.
        dropout: Dropout probability.
    """

    def __init__(
        self,
        in_size,
        out_size,
        hidden_size,
        n_layers,
        lamda,
        alpha,
        dropout=0.5,
    ):
        super().__init__()
        self.hidden_size = hidden_size
        self.n_layers = n_layers
        self.lamda = lamda
        self.alpha = alpha

        # layers[0] projects the input, layers[1:-1] are the GCNII layers and
        # layers[-1] emits the logits. The list operands are evaluated left to
        # right, so modules are created in input -> convs -> output order.
        self.layers = nn.ModuleList(
            [nn.Linear(in_size, hidden_size)]
            + [
                GCNIIConvolution(hidden_size, hidden_size)
                for _ in range(n_layers)
            ]
            + [nn.Linear(hidden_size, out_size)]
        )

        self.activation = nn.ReLU()
        self.dropout = dropout

    def forward(self, A_norm, feature):
        """Return per-node logits; ``A_norm`` is handed to every GCNII layer."""
        drop_p, training = self.dropout, self.training
        first, *convs, last = self.layers

        # Initial representation H0, reused as the residual by every layer.
        H0 = self.activation(
            first(F.dropout(feature, drop_p, training=training))
        )

        H = H0
        for depth, conv in enumerate(convs, start=1):
            H = F.dropout(H, drop_p, training=training)
            H = self.activation(
                conv(A_norm, H, H0, self.lamda, self.alpha, depth)
            )

        return last(F.dropout(H, drop_p, training=training))


@torch.no_grad()
def evaluate(model, A_norm, H, label, val_mask, test_mask):
    """Return the validation and test accuracy of ``model``.

    Switches ``model`` to eval mode (disabling dropout) and runs one forward
    pass with autograd disabled. Accuracy needs no gradients, and skipping
    graph construction avoids keeping every layer's activations alive.

    Args:
        model: Module called as ``model(A_norm, H)`` returning per-node logits.
        A_norm: Normalized adjacency matrix passed through to ``model``.
        H: Node features passed through to ``model``.
        label: Per-node integer class labels.
        val_mask: Boolean mask selecting validation nodes.
        test_mask: Boolean mask selecting test nodes.

    Returns:
        ``(val_acc, test_acc)`` as 0-dim float tensors: the fraction of
        correctly classified nodes under each mask.
    """
    model.eval()
    pred = model(A_norm, H).argmax(dim=1)

    # Compare once, then read the per-split accuracy from the same mask.
    correct = pred == label
    val_acc = correct[val_mask].float().mean()
    test_acc = correct[test_mask].float().mean()
    return val_acc, test_acc


def train(
    model,
    g,
    A_norm,
    H,
    num_epochs=100,
    lr=0.01,
    weight_decay=5e-4,
    log_every=5,
):
    """Train ``model`` full-batch for node classification on ``g``.

    Reads ``label``, ``train_mask``, ``val_mask`` and ``test_mask`` from
    ``g.ndata``. It minimizes cross-entropy on the training nodes with Adam.
    Every ``log_every`` epochs it prints the loss together with the
    validation and test accuracy.

    Args:
        model: Module called as ``model(A_norm, H)`` returning per-node logits.
        g: Graph whose ``ndata`` holds the labels and split masks.
        A_norm: Normalized adjacency matrix passed through to ``model``.
        H: Node features passed through to ``model``.
        num_epochs: Number of full-batch optimization steps.
        lr: Adam learning rate.
        weight_decay: Adam weight decay.
        log_every: Evaluate and print every this many epochs (must be > 0).
    """
    label = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]
    optimizer = Adam(model.parameters(), lr=lr, weight_decay=weight_decay)

    loss_fcn = nn.CrossEntropyLoss()

    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()

        # Forward over the whole graph; the loss only uses training nodes.
        logits = model(A_norm, H)
        loss = loss_fcn(logits[train_mask], label[train_mask])

        # Backward.
        loss.backward()
        optimizer.step()

        # Evaluation is a full extra forward pass, so only run it when the
        # result is reported. For the GCNII model in this file, eval-mode
        # dropout is a no-op that draws no random numbers, so skipping it
        # does not change training.
        if epoch % log_every == 0:
            val_acc, test_acc = evaluate(
                model, A_norm, H, label, val_mask, test_mask
            )
            print(
                f"In epoch {epoch}, loss: {loss.item():.3f}, "
                f"val acc: {val_acc:.3f}, test acc: {test_acc:.3f}"
            )

    # Per-epoch evaluation used to leave the model in eval mode; keep that.
    model.eval()


if __name__ == "__main__":
    # If CUDA is available, use GPU to accelerate the training, use CPU
    # otherwise.
    dev = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

    # Load graph from the existing dataset.
    dataset = CoraGraphDataset()
    g = dataset[0].to(dev)
    num_classes = dataset.num_classes
    H = g.ndata["feat"]

    # Create the adjacency matrix of graph, with destination nodes as row
    # indices and source nodes as column indices.
    # NOTE(review): newer DGL releases name this constructor dglsp.from_coo /
    # dglsp.spmatrix; confirm against the installed DGL version.
    src, dst = g.edges()
    N = g.num_nodes()
    A = dglsp.create_from_coo(dst, src, shape=(N, N))

    ############################################################################
    # (HIGHLIGHT) Compute the symmetrically normalized adjacency matrix with
    # Sparse Matrix API
    ############################################################################
    # A_norm = D^-1/2 (A + I) D^-1/2, where D is the diagonal matrix of the
    # row sums of A + I (self-loops added so no row sum is zero).
    I = dglsp.identity(A.shape, device=dev)
    A_hat = A + I
    D_hat = dglsp.diag(A_hat.sum(1)) ** -0.5
    A_norm = D_hat @ A_hat @ D_hat

    # Create model.
    in_size = H.shape[1]
    out_size = num_classes
    model = GCNII(
        in_size,
        out_size,
        hidden_size=64,
        n_layers=64,
        lamda=0.5,
        alpha=0.2,
        dropout=0.5,
    ).to(dev)

    # Kick off training.
    train(model, g, A_norm, H)