"""
Semi-Supervised Classification with Graph Convolutional Networks
Paper: https://arxiv.org/abs/1609.02907
Code: https://github.com/tkipf/gcn
GCN with SPMV specialization.
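Example usage (a sketch; assumes register_data_args exposes a --dataset flag):
    python gcn.py --dataset cora --gpu 0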
"""
import argparse, time, math
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph
from dgl.data import register_data_args, load_data


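# Message and reduce functions for one round of graph convolution.
# Every edge sends h_u * deg_u^{-1/2}; every node sums its mailbox and scales
# the result by deg_v^{-1/2}, so update_all computes the symmetrically
# normalized propagation D^{-1/2} (A + I) D^{-1/2} H (self-loops are added
# in main()).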
def gcn_msg(edge):
    msg = edge.src['h'] * edge.src['norm']
    return {'m': msg}


def gcn_reduce(node):
    accum = torch.sum(node.mailbox['m'], 1) * node.data['norm']
    return {'h': accum}


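# Per-node update applied after aggregation: adds the optional bias and
# activation to the aggregated feature 'h'.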
class NodeApplyModule(nn.Module):
    def __init__(self, out_feats, activation=None, bias=True):
        super(NodeApplyModule, self).__init__()
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_feats))
        else:
            self.bias = None
        self.activation = activation
        self.reset_parameters()

    def reset_parameters(self):
        if self.bias is not None:
            stdv = 1. / math.sqrt(self.bias.size(0))
            self.bias.data.uniform_(-stdv, stdv)

    def forward(self, nodes):
        h = nodes.data['h']
        if self.bias is not None:
            h = h + self.bias
        if self.activation:
            h = self.activation(h)
        return {'h': h}


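# A single GCN layer: optional dropout, linear projection (h @ W), then
# propagation over the graph via update_all (which DGL can specialize into
# SPMV), followed by the per-node bias/activation update.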
class GCNLayer(nn.Module):
    def __init__(self,
                 g,
                 in_feats,
                 out_feats,
                 activation,
                 dropout,
                 bias=True):
        super(GCNLayer, self).__init__()
        self.g = g
        self.weight = nn.Parameter(torch.Tensor(in_feats, out_feats))
        if dropout:
            self.dropout = nn.Dropout(p=dropout)
        else:
            self.dropout = 0.
        self.node_update = NodeApplyModule(out_feats, activation, bias)
        self.reset_parameters()

    def reset_parameters(self):
        stdv = 1. / math.sqrt(self.weight.size(1))
        self.weight.data.uniform_(-stdv, stdv)

    def forward(self, h):
        if self.dropout:
            h = self.dropout(h)
        self.g.ndata['h'] = torch.mm(h, self.weight)
        self.g.update_all(gcn_msg, gcn_reduce, self.node_update)
        h = self.g.ndata.pop('h')
        return h

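# Full model: an input layer, (n_layers - 1) hidden layers, and an output
# layer without activation; the raw logits are fed to CrossEntropyLoss.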
class GCN(nn.Module):
    def __init__(self,
                 g,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout):
        super(GCN, self).__init__()
        self.layers = nn.ModuleList()
        # input layer
        self.layers.append(GCNLayer(g, in_feats, n_hidden, activation, dropout))
        # hidden layers
        for i in range(n_layers - 1):
            self.layers.append(GCNLayer(g, n_hidden, n_hidden, activation, dropout))
        # output layer
        self.layers.append(GCNLayer(g, n_hidden, n_classes, None, dropout))

    def forward(self, features):
        h = features
        for layer in self.layers:
            h = layer(h)
        return h

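# Accuracy of the model on the nodes selected by `mask`, computed without
# tracking gradients.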
def evaluate(model, features, labels, mask):
    model.eval()
    with torch.no_grad():
        logits = model(features)
        logits = logits[mask]
        labels = labels[mask]
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)

def main(args):
    # load and preprocess dataset
    data = load_data(args)
    features = torch.FloatTensor(data.features)
    labels = torch.LongTensor(data.labels)
    train_mask = torch.ByteTensor(data.train_mask)
    val_mask = torch.ByteTensor(data.val_mask)
    test_mask = torch.ByteTensor(data.test_mask)
    in_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = data.graph.number_of_edges()
    print("""----Data statistics------'
      #Edges %d
      #Classes %d
      #Train samples %d
      #Val samples %d
      #Test samples %d""" %
          (n_edges, n_classes,
              train_mask.sum().item(),
              val_mask.sum().item(),
              test_mask.sum().item()))

    if args.gpu < 0:
        cuda = False
    else:
        cuda = True
        torch.cuda.set_device(args.gpu)
        features = features.cuda()
        labels = labels.cuda()
        train_mask = train_mask.cuda()
        val_mask = val_mask.cuda()
        test_mask = test_mask.cuda()

    # graph preprocess and calculate normalization factor
    g = DGLGraph(data.graph)
    n_edges = g.number_of_edges()
    # add self loop
    g.add_edges(g.nodes(), g.nodes())
    # normalization
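    # each node stores deg^{-1/2}; gcn_msg/gcn_reduce apply it on both the
    # source and the destination side of every edge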
    degs = g.in_degrees().float()
    norm = torch.pow(degs, -0.5)
    norm[torch.isinf(norm)] = 0
    if cuda:
        norm = norm.cuda()
    g.ndata['norm'] = norm.unsqueeze(1)

    # create GCN model
    model = GCN(g,
                in_feats,
                args.n_hidden,
                n_classes,
                args.n_layers,
                F.relu,
                args.dropout)

    if cuda:
        model.cuda()
    loss_fcn = torch.nn.CrossEntropyLoss()

    # use optimizer
    optimizer = torch.optim.Adam(model.parameters(),
                                 lr=args.lr,
                                 weight_decay=args.weight_decay)

    # training loop
    dur = []
    for epoch in range(args.n_epochs):
        model.train()
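        # skip the first few epochs when timing to exclude warm-up overhead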
        if epoch >= 3:
            t0 = time.time()
        # forward
        logits = model(features)
        loss = loss_fcn(logits[train_mask], labels[train_mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch >= 3:
            dur.append(time.time() - t0)

        acc = evaluate(model, features, labels, val_mask)
        print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
              "ETputs(KTEPS) {:.2f}". format(epoch, np.mean(dur), loss.item(),
                                             acc, n_edges / np.mean(dur) / 1000))

    print()
    acc = evaluate(model, features, labels, test_mask)
    print("Test Accuracy {:.4f}".format(acc))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GCN')
    register_data_args(parser)
    parser.add_argument("--dropout", type=float, default=0.5,
            help="dropout probability")
    parser.add_argument("--gpu", type=int, default=-1,
            help="gpu")
    parser.add_argument("--lr", type=float, default=1e-2,
            help="learning rate")
    parser.add_argument("--n-epochs", type=int, default=200,
            help="number of training epochs")
    parser.add_argument("--n-hidden", type=int, default=16,
            help="number of hidden gcn units")
    parser.add_argument("--n-layers", type=int, default=1,
            help="number of hidden gcn layers")
    parser.add_argument("--weight-decay", type=float, default=5e-4,
            help="Weight for L2 loss")
    args = parser.parse_args()
    print(args)

    main(args)