"""
Semi-Supervised Classification with Graph Convolutional Networks
Paper: https://arxiv.org/abs/1609.02907
Code: https://github.com/tkipf/gcn

GCN with batch processing
"""
import argparse
import numpy as np
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl
from dgl import DGLGraph
from dgl.data import register_data_args, load_data

def gcn_msg(src, edge):
    """Message step: send the source node's ``'h'`` feature, unchanged, as ``'m'``.

    ``edge`` is accepted for the message-function signature but not used.
    """
    message = src['h']
    return dict(m=message)


def gcn_reduce(node, msgs):
    """Reduce step: add up the received messages ``msgs['m']`` over dimension 1.

    The sum becomes the node's new ``'h'`` feature. ``node`` is accepted for
    the reduce-function signature but not used.
    """
    # NOTE(review): dimension 1 is presumably the per-node message axis of the
    # mailbox (nodes x messages x features) — confirm against the DGL version.
    mailbox = msgs['m']
    return dict(h=torch.sum(mailbox, dim=1))


class NodeApplyModule(nn.Module):
    """Node update applied after message aggregation.

    Projects the node feature ``'h'`` with a linear layer and, when an
    activation is given, passes the projection through it.

    Args:
        in_feats: size of the incoming ``'h'`` feature.
        out_feats: size of the produced ``'h'`` feature.
        activation: optional callable; skipped when falsy (e.g. ``None``).
    """

    def __init__(self, in_feats, out_feats, activation=None):
        super(NodeApplyModule, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation

    def forward(self, node):
        """Return a dict whose ``'h'`` entry is the transformed ``node['h']``."""
        projected = self.linear(node['h'])
        if not self.activation:
            return {'h': projected}
        return {'h': self.activation(projected)}


class GCN(nn.Module):
    """Stack of message-passing layers over a fixed graph.

    Layout: one input layer (``in_feats -> n_hidden``, with ``activation``),
    ``n_layers - 1`` hidden layers (``n_hidden -> n_hidden``, with
    ``activation``), and an output layer (``n_hidden -> n_classes``, no
    activation). In each layer, nodes receive ``gcn_msg`` messages, sum them
    with ``gcn_reduce``, and transform the sum with a ``NodeApplyModule``.

    Args:
        g: graph the model runs on; node features are stored on it as ``'h'``.
        in_feats: size of the input node features.
        n_hidden: size of the hidden representations.
        n_classes: size of the returned per-node outputs.
        n_layers: number of layers producing ``n_hidden`` features.
        activation: callable applied after every layer except the last.
        dropout: dropout probability applied to ``'h'`` before each layer;
            a falsy value (e.g. 0) disables it.
    """

    def __init__(self,
                 g,
                 in_feats,
                 n_hidden,
                 n_classes,
                 n_layers,
                 activation,
                 dropout):
        super(GCN, self).__init__()
        self.g = g
        self.dropout = dropout
        # input layer
        self.layers = nn.ModuleList([NodeApplyModule(in_feats, n_hidden, activation)])
        # hidden layers
        for _ in range(n_layers - 1):
            self.layers.append(NodeApplyModule(n_hidden, n_hidden, activation))
        # output layer
        self.layers.append(NodeApplyModule(n_hidden, n_classes))

    def _apply_dropout(self, node):
        """Node-apply function: return ``{'h': dropout(node['h'])}``."""
        # Key the result by 'h' (as NodeApplyModule.forward does) so the
        # dropped-out features replace 'h'. Pass training=self.training because
        # F.dropout defaults to training=True and would otherwise also drop
        # units in eval mode.
        return {'h': F.dropout(node['h'], p=self.dropout, training=self.training)}

    def forward(self, features):
        """Run every layer on ``features`` and return the final node features.

        The input is stored on ``self.g`` as ``'h'``. After the last layer it
        is retrieved with ``pop_n_repr('h')``.
        """
        self.g.set_n_repr({'h': features})
        for layer in self.layers:
            if self.dropout:
                # The graph lives on the module; there is no ``g`` in scope here.
                self.g.apply_nodes(apply_node_func=self._apply_dropout)
            self.g.update_all(gcn_msg, gcn_reduce, layer)
        return self.g.pop_n_repr('h')


def main(args):
    """Train a GCN on the dataset selected by ``args`` and print per-epoch stats.

    Reads ``args.gpu`` (negative means CPU), ``args.n_hidden``,
    ``args.n_layers``, ``args.dropout``, ``args.lr`` and ``args.n_epochs``,
    plus whatever ``load_data`` consumes. Each epoch prints the training
    loss, the mean epoch time, and the edge throughput in thousands of edges
    per second. Epochs with index < 3 are not timed (presumably warm-up).
    """
    # load and preprocess dataset
    data = load_data(args)

    features = torch.FloatTensor(data.features)
    labels = torch.LongTensor(data.labels)
    # NOTE(review): newer PyTorch prefers bool masks for indexing; uint8 is
    # kept here for compatibility with the PyTorch version this script targets.
    mask = torch.ByteTensor(data.train_mask)
    in_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = data.graph.number_of_edges()

    cuda = args.gpu >= 0
    if cuda:
        torch.cuda.set_device(args.gpu)
        features = features.cuda()
        labels = labels.cuda()
        mask = mask.cuda()

    # create GCN model
    g = DGLGraph(data.graph)
    model = GCN(g,
                in_feats,
                args.n_hidden,
                n_classes,
                args.n_layers,
                F.relu,
                args.dropout)

    if cuda:
        model.cuda()

    # use optimizer
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr)

    # wall-clock time of each timed epoch
    dur = []
    for epoch in range(args.n_epochs):
        if epoch >= 3:
            t0 = time.time()
        # forward
        logits = model(features)
        logp = F.log_softmax(logits, 1)
        loss = F.nll_loss(logp[mask], labels[mask])

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        if epoch >= 3:
            dur.append(time.time() - t0)

        # Before any epoch is timed, report NaN directly instead of calling
        # np.mean([]), which emits "Mean of empty slice" RuntimeWarnings.
        mean_dur = np.mean(dur) if dur else float('nan')
        print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f} | ETputs(KTEPS) {:.2f}".format(
            epoch, loss.item(), mean_dur, n_edges / mean_dur / 1000))


if __name__ == '__main__':
    parser = argparse.ArgumentParser(description='GCN')
    register_data_args(parser)
    # Model/training options as (flag, type, default, help), registered in
    # this order so the generated --help lists them the same way.
    option_specs = (
        ("--dropout", float, 0, "dropout probability"),
        ("--gpu", int, -1, "gpu"),
        ("--lr", float, 1e-3, "learning rate"),
        ("--n-epochs", int, 20, "number of training epochs"),
        ("--n-hidden", int, 16, "number of hidden gcn units"),
        ("--n-layers", int, 1, "number of hidden gcn layers"),
    )
    for flag, arg_type, default_value, help_text in option_specs:
        parser.add_argument(flag, type=arg_type, default=default_value,
                            help=help_text)
    args = parser.parse_args()
    print(args)

    main(args)