"""
.. _model-gcn:

Graph Convolutional Network
====================================

**Authors:** `Qi Huang <https://github.com/HQ01>`_, `Minjie Wang <https://jermainewang.github.io/>`_,
Yu Gai, Quan Gan, Zheng Zhang

This is a gentle introduction to using DGL to implement Graph Convolutional
Networks (Kipf & Welling, `Semi-Supervised Classification with Graph
Convolutional Networks <https://arxiv.org/pdf/1609.02907.pdf>`_). We explain
what is under the hood of the :class:`~dgl.nn.pytorch.GraphConv` module.
Readers will learn how to define a new GNN layer using DGL's
message passing APIs.

We build upon the :doc:`earlier tutorial <../../basics/3_pagerank>` on DGLGraph
and demonstrate how DGL combines graphs with deep neural networks to learn
structural representations.
"""

###############################################################################
# Model Overview
# ------------------------------------------
# GCN from the perspective of message passing
# ```````````````````````````````````````````````
# We describe a layer of a graph convolutional neural network from a message
# passing perspective; the math can be found `here <math_>`_.
# It boils down to the following steps, for each node :math:`u`:
#
# 1) Aggregate neighbors' representations :math:`h_{v}` to produce an
#    intermediate representation :math:`\hat{h}_u`.
# 2) Transform the aggregated representation :math:`\hat{h}_{u}` with a
#    linear projection followed by a non-linearity:
#    :math:`h_{u} = f(W_{u} \hat{h}_u)`.
# 
# We will implement step 1 with DGL message passing, and step 2 with a
# PyTorch ``nn.Module``.
# 
# GCN implementation with DGL
# ``````````````````````````````````````````
# We first define the message and reduce functions as usual.  Since the
# aggregation on a node :math:`u` only involves summing over the neighbors'
# representations :math:`h_v`, we can simply use builtin functions:

import dgl
import dgl.function as fn
import torch as th
import torch.nn as nn
import torch.nn.functional as F
from dgl import DGLGraph

gcn_msg = fn.copy_u(u='h', out='m')
gcn_reduce = fn.sum(msg='m', out='h')
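
###############################################################################
# These builtins are backed by efficient fused kernels and are the recommended
# way to write this. For intuition, an equivalent (but slower) pair of
# user-defined functions would look roughly like the following sketch (the
# names ``gcn_msg_udf`` and ``gcn_reduce_udf`` are ours, for illustration):

def gcn_msg_udf(edges):
    # Send the source node's feature along each edge.
    return {'m': edges.src['h']}

def gcn_reduce_udf(nodes):
    # Sum the messages accumulated in each node's mailbox.
    return {'h': th.sum(nodes.mailbox['m'], dim=1)}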

###############################################################################
# We then proceed to define the GCNLayer module. A GCNLayer essentially performs
# message passing on all the nodes and then applies a fully-connected layer.
#
# .. note::
#
#    This is showing how to implement a GCN from scratch.  DGL provides a more
#    efficient :class:`builtin GCN layer module <dgl.nn.pytorch.conv.GraphConv>`.
#
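#
# For comparison, the builtin layer can be constructed in one line. This is a
# minimal sketch for illustration; note that, unlike the scratch version below,
# ``GraphConv`` by default also applies the symmetric normalization discussed
# at the end of this tutorial.

from dgl.nn.pytorch import GraphConv
builtin_conv = GraphConv(in_feats=1433, out_feats=16)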

class GCNLayer(nn.Module):
    def __init__(self, in_feats, out_feats):
        super(GCNLayer, self).__init__()
        self.linear = nn.Linear(in_feats, out_feats)

    def forward(self, g, feature):
        # Creating a local scope so that all the stored ndata and edata
        # (such as the `'h'` ndata below) are automatically popped out
        # when the scope exits.
        with g.local_scope():
            g.ndata['h'] = feature
            g.update_all(gcn_msg, gcn_reduce)
            h = g.ndata['h']
            return self.linear(h)
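
###############################################################################
# As a quick smoke test (a minimal sketch; the 3-node cycle graph and feature
# sizes here are made up for illustration), the layer maps 5-dimensional input
# features to 2-dimensional outputs, one row per node:

demo_g = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 0])))
demo_out = GCNLayer(5, 2)(demo_g, th.randn(3, 5))
print(demo_out.shape)  # torch.Size([3, 2])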

###############################################################################
# The forward function is essentially the same as that of any other commonly
# seen NN model in PyTorch.  We can initialize GCN like any ``nn.Module``. For
# example, let's define a simple neural network consisting of two GCN layers.
# Suppose we are training the classifier for the Cora dataset (the input
# feature size is 1433 and the number of classes is 7). The last GCN layer
# computes node embeddings, so the last layer in general does not apply an
# activation.

class Net(nn.Module):
    def __init__(self):
        super(Net, self).__init__()
        self.layer1 = GCNLayer(1433, 16)
        self.layer2 = GCNLayer(16, 7)
    
    def forward(self, g, features):
        x = F.relu(self.layer1(g, features))
        x = self.layer2(g, x)
        return x

net = Net()
print(net)
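
###############################################################################
# This should print something like:
#
# .. code:: text
#
#    Net(
#      (layer1): GCNLayer(
#        (linear): Linear(in_features=1433, out_features=16, bias=True)
#      )
#      (layer2): GCNLayer(
#        (linear): Linear(in_features=16, out_features=7, bias=True)
#      )
#    )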

###############################################################################
# We load the Cora dataset using DGL's built-in data module.

from dgl.data import CoraGraphDataset
def load_cora_data():
    dataset = CoraGraphDataset()
    g = dataset[0]
    features = g.ndata['feat']
    labels = g.ndata['label']
    train_mask = g.ndata['train_mask']
    test_mask = g.ndata['test_mask']
    return g, features, labels, train_mask, test_mask
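
###############################################################################
# Cora is a citation network of 2,708 papers grouped into 7 classes, where each
# node carries a 1,433-dimensional bag-of-words feature vector; the boolean
# masks select the standard train/test splits shipped with the dataset.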

###############################################################################
# After the model is trained, we can use the following method to evaluate
# the performance of the model on the test dataset:

def evaluate(model, g, features, labels, mask):
    model.eval()
    with th.no_grad():
        logits = model(g, features)
        logits = logits[mask]
        labels = labels[mask]
        _, indices = th.max(logits, dim=1)
        correct = th.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)

###############################################################################
# We then train the network as follows:

import time
import numpy as np
g, features, labels, train_mask, test_mask = load_cora_data()
# Add edges between each node and itself to preserve old node representations
g.add_edges(g.nodes(), g.nodes())
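# (A builtin alternative is ``g = dgl.add_self_loop(g)``, which returns a new
# graph rather than mutating ``g`` in place.)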
optimizer = th.optim.Adam(net.parameters(), lr=1e-2)
dur = []
for epoch in range(50):
    if epoch >= 3:
        t0 = time.time()

    net.train()
    logits = net(g, features)
    logp = F.log_softmax(logits, 1)
    loss = F.nll_loss(logp[train_mask], labels[train_mask])
    
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    
    if epoch >= 3:
        dur.append(time.time() - t0)
    
    acc = evaluate(net, g, features, labels, test_mask)
    # ``dur`` is empty for the first three epochs; guard against np.mean([]).
    print("Epoch {:05d} | Loss {:.4f} | Test Acc {:.4f} | Time(s) {:.4f}".format(
            epoch, loss.item(), acc, np.mean(dur) if dur else 0))

###############################################################################
# .. _math:
#
# GCN in one formula
# ------------------
# Mathematically, the GCN model follows this formula:
# 
# :math:`H^{(l+1)} = \sigma(\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}H^{(l)}W^{(l)})`
# 
# Here, :math:`H^{(l)}` denotes the matrix of node representations at the
# :math:`l^{th}` layer of the network, :math:`\sigma` is the non-linearity, and
# :math:`W^{(l)}` is the weight matrix for this layer. :math:`\tilde{D}` and
# :math:`\tilde{A}` are respectively the degree and adjacency matrices of the
# graph. The tilde denotes the variant where we add an edge between each node
# and itself, so that a node's old representation is preserved in graph
# convolutions. The shape of the input
# :math:`H^{(0)}` is :math:`N \times D`, where :math:`N` is the number of nodes
# and :math:`D` is the number of input features. We can chain up multiple
# layers as such to produce a node-level representation output with shape
# :math:`N \times F`, where :math:`F` is the dimension of the output node
# feature vector.
#
# The equation can be efficiently implemented using sparse matrix
# multiplication kernels (as in Kipf's
# `pygcn <https://github.com/tkipf/pygcn>`_ code). The DGL implementation
# above in fact already uses this trick, thanks to its use of builtin
# functions. To understand what is under the hood, please read our tutorial
# on :doc:`PageRank <../../basics/3_pagerank>`.
#
# Note that the tutorial code implements a simplified version of GCN where we
# replace :math:`\tilde{D}^{-\frac{1}{2}}\tilde{A}\tilde{D}^{-\frac{1}{2}}` with
# :math:`\tilde{A}`. For a full implementation, see our example
# `here  <https://github.com/dmlc/dgl/tree/master/examples/pytorch/gcn>`_.
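
###############################################################################
# For intuition, here is a minimal dense-matrix sketch of the full normalized
# formula above (the tiny graph, feature sizes, and names are made up for
# illustration; real implementations use sparse kernels instead):

A = th.tensor([[0., 1., 0., 1.],
               [1., 0., 1., 0.],
               [0., 1., 0., 0.],
               [1., 0., 0., 0.]])               # a made-up symmetric adjacency
H = th.randn(4, 3)                              # input node features (N x D)
W = th.randn(3, 2)                              # layer weights (D x F)
A_tilde = A + th.eye(4)                         # add self-loops
D_inv_sqrt = th.diag(A_tilde.sum(1).pow(-0.5))  # inverse sqrt degree matrix
H_next = th.relu(D_inv_sqrt @ A_tilde @ D_inv_sqrt @ H @ W)
print(H_next.shape)  # torch.Size([4, 2])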