"""
.. _model-gcn:

Graph Convolutional Network
====================================

**Author:** `Qi Huang <https://github.com/HQ01>`_, `Minjie Wang  <https://jermainewang.github.io/>`_,
Yu Gai, Quan Gan, Zheng Zhang

This is a gentle introduction to using DGL to implement Graph Convolutional
Networks (Kipf & Welling, `Semi-Supervised Classification with Graph
Convolutional Networks <https://arxiv.org/pdf/1609.02907.pdf>`_). We build upon
the :doc:`earlier tutorial <../../basics/3_pagerank>` on DGLGraph and demonstrate
how DGL combines graphs with deep neural networks to learn structural representations.
"""

###############################################################################
# Model Overview
# ------------------------------------------
# GCN from the perspective of message passing
# ```````````````````````````````````````````````
# We describe a layer of the graph convolutional network from a message
# passing perspective; the math can be found `here <math_>`_.
# It boils down to the following steps for each node :math:`u`:
#
# 1) Aggregate neighbors' representations :math:`h_{v}` to produce an
#    intermediate representation :math:`\hat{h}_u`.
# 2) Transform the aggregated representation :math:`\hat{h}_{u}` with a
#    linear projection followed by a non-linearity:
#    :math:`h_{u} = f(W_{u} \hat{h}_u)`.
# 
# We will implement step 1 with DGL message passing, and step 2 with the
# ``apply_nodes`` method, whose node UDF will be a PyTorch ``nn.Module``.
# 
# GCN implementation with DGL
# ``````````````````````````````````````````
# We first define the message and reduce functions as usual.  Since the
# aggregation on a node :math:`u` only involves summing over the neighbors'
# representations :math:`h_v`, we can simply use builtin functions:

import dgl
import dgl.function as fn
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph

gcn_msg = fn.copy_src(src='h', out='m')
gcn_reduce = fn.sum(msg='m', out='h')
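
###############################################################################
# For intuition, these builtins are roughly equivalent to the hand-written
# UDFs below (a sketch for illustration; ``gcn_msg_udf`` and ``gcn_reduce_udf``
# are hypothetical names, and the fused builtins above should be preferred in
# practice for speed):

def gcn_msg_udf(edges):
    # like fn.copy_src: send the source node feature 'h' as message 'm'
    return {'m': edges.src['h']}

def gcn_reduce_udf(nodes):
    # like fn.sum: sum the incoming messages 'm' stored in each node's mailbox
    return {'h': th.sum(nodes.mailbox['m'], dim=1)}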

###############################################################################
# We then define the node UDF for ``apply_nodes``, which is a fully-connected layer:

class NodeApplyModule(nn.Module):
    def __init__(self, in_feats, out_feats, activation):
        super(NodeApplyModule, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)
        self.activation = activation

    def forward(self, node):
        h = self.linear(node.data['h'])
        if self.activation is not None:
            h = self.activation(h)
        return {'h' : h}

###############################################################################
# We then proceed to define the GCN module. A GCN layer essentially performs
# message passing on all the nodes, then applies the ``NodeApplyModule``. Note
# that we omit the dropout described in the paper for simplicity.

class GCN(nn.Module):
    def __init__(self, in_feats, out_feats, activation):
        super(GCN, self).__init__()
        self.apply_mod = NodeApplyModule(in_feats, out_feats, activation)

    def forward(self, g, feature):
        g.ndata['h'] = feature
        g.update_all(gcn_msg, gcn_reduce)
        g.apply_nodes(func=self.apply_mod)
        return g.ndata.pop('h')

###############################################################################
# The forward function is essentially the same as that of any other commonly
# seen NN model in PyTorch.  We can initialize GCN like any ``nn.Module``. For
# example, let's define a simple neural network consisting of two GCN layers.
# Suppose we are training the classifier for the cora dataset (the input
# feature size is 1433 and the number of classes is 7). The last GCN layer
# computes node embeddings, so it generally does not apply an activation.

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.gcn1 = GCN(1433, 16, F.relu)
        self.gcn2 = GCN(16, 7, None)
    
    def forward(self, g, features):
        x = self.gcn1(g, features)
        x = self.gcn2(g, x)
        return x

net = Net()
print(net)

###############################################################################
# We load the cora dataset using DGL's built-in data module.

from dgl.data import citation_graph as citegrh
import networkx as nx
def load_cora_data():
    data = citegrh.load_cora()
    features = th.FloatTensor(data.features)
    labels = th.LongTensor(data.labels)
    train_mask = th.ByteTensor(data.train_mask)
    test_mask = th.ByteTensor(data.test_mask)
    g = data.graph
    # add self loops: remove any existing ones first, then add one per node
    g.remove_edges_from(nx.selfloop_edges(g))
    g = DGLGraph(g)
    g.add_edges(g.nodes(), g.nodes())
    return g, features, labels, train_mask, test_mask

###############################################################################
# Once the model is trained, we can use the following method to evaluate
# its performance on the test dataset:

def evaluate(model, g, features, labels, mask):
    model.eval()
    with th.no_grad():
        logits = model(g, features)
        logits = logits[mask]
        labels = labels[mask]
        _, indices = th.max(logits, dim=1)
        correct = th.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)

###############################################################################
# We then train the network as follows:

import time
import numpy as np
g, features, labels, train_mask, test_mask = load_cora_data()
optimizer = th.optim.Adam(net.parameters(), lr=1e-3)
dur = []
for epoch in range(50):
    if epoch >= 3:
        t0 = time.time()

    net.train()
    logits = net(g, features)
    logp = F.log_softmax(logits, 1)
    loss = F.nll_loss(logp[train_mask], labels[train_mask])
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if epoch >= 3:
        dur.append(time.time() - t0)
    
    acc = evaluate(net, g, features, labels, test_mask)
    print("Epoch {:05d} | Loss {:.4f} | Test Acc {:.4f} | Time(s) {:.4f}".format(
            epoch, loss.item(), acc, np.mean(dur)))

###############################################################################
# .. _math:
#
# GCN in one formula
# ------------------
# Mathematically, the GCN model follows this formula:
# 
# :math:`H^{(l+1)} = \sigma(\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}H^{(l)}W^{(l)})`
# 
# Here, :math:`H^{(l)}` denotes the node representations at the :math:`l^{th}`
# layer of the network, :math:`\sigma` is the non-linearity, and :math:`W` is
# the weight matrix for this layer. :math:`D` and :math:`A`, as commonly seen,
# represent the degree matrix and adjacency matrix, respectively. The tilde is
# a renormalization trick in which we add a self-connection to each node of
# the graph, and build the corresponding degree and adjacency matrices.  The
# shape of the input :math:`H^{(0)}` is :math:`N \times D`, where :math:`N` is
# the number of nodes and :math:`D` is the number of input features. We can
# chain up multiple layers as such to produce a node-level representation
# output with shape :math:`N \times F`, where :math:`F` is the dimension of
# the output node feature vector.
# 
# The equation can be efficiently implemented using sparse matrix
# multiplication kernels (as in Kipf's
# `pygcn <https://github.com/tkipf/pygcn>`_ code). The DGL implementation
# above in fact already uses this trick thanks to its use of builtin functions. To
# understand what is under the hood, please read our tutorial on :doc:`PageRank <../../basics/3_pagerank>`.
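
###############################################################################
# For reference, here is a minimal PyTorch sketch of the formula above using a
# dense adjacency matrix (``gcn_layer_dense`` is a hypothetical helper, not
# part of DGL; real implementations such as pygcn use sparse kernels instead):

def gcn_layer_dense(A, H, W, activation=th.relu):
    # A~ = A + I: add a self-connection to each node
    A_tilde = A + th.eye(A.shape[0])
    # D~^{-1/2}: inverse square root of the degree matrix of A~
    D_inv_sqrt = th.diag(A_tilde.sum(dim=1).pow(-0.5))
    # H^{(l+1)} = sigma(D~^{-1/2} A~ D~^{-1/2} H^{(l)} W^{(l)})
    return activation(D_inv_sqrt @ A_tilde @ D_inv_sqrt @ H @ W)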