"""
.. currentmodule:: dgl

Graph Classification Tutorial
=============================

**Author**: `Mufei Li <https://github.com/mufeili>`_,
`Minjie Wang <https://jermainewang.github.io/>`_,
`Zheng Zhang <https://shanghai.nyu.edu/academics/faculty/directory/zheng-zhang>`_.

In this tutorial, you learn how to use DGL to batch multiple graphs of variable size and shape. The 
tutorial also demonstrates training a graph neural network for a simple graph classification task.

Graph classification is an important problem
with applications across many fields, such as bioinformatics, chemoinformatics, social
network analysis, urban computing, and cybersecurity. Applying graph neural
networks to this problem has been a popular approach recently, as can be seen in the
following research references:
`Ying et al., 2018 <https://arxiv.org/abs/1806.08804>`_,
`Cangea et al., 2018 <https://arxiv.org/abs/1811.01287>`_,
`Knyazev et al., 2018 <https://arxiv.org/abs/1811.09595>`_,
`Bianchi et al., 2019 <https://arxiv.org/abs/1901.01343>`_,
`Liao et al., 2019 <https://arxiv.org/abs/1901.01484>`_,
`Gao et al., 2019 <https://openreview.net/forum?id=HJePRoAct7>`_.

"""

###############################################################################
# Simple graph classification task
# --------------------------------
# In this tutorial, you learn how to perform batched graph classification
# with DGL. The example task is to classify the eight graph topologies shown here.
#
# .. image:: https://data.dgl.ai/tutorial/batch/dataset_overview.png
#     :align: center
#
# The synthetic dataset :class:`data.MiniGCDataset` is implemented in DGL. The dataset has eight
# different types of graphs, and each class has the same number of graph samples.

import dgl
import torch
from dgl.data import MiniGCDataset
import matplotlib.pyplot as plt
import networkx as nx
# A dataset with 80 samples, where each graph has
# 10 to 20 nodes
dataset = MiniGCDataset(80, 10, 20)
graph, label = dataset[0]
fig, ax = plt.subplots()
nx.draw(graph.to_networkx(), ax=ax)
ax.set_title('Class: {:d}'.format(label))
plt.show()

###############################################################################
# Form a graph mini-batch
# -----------------------
# To train neural networks efficiently, a common practice is to batch
# multiple samples together to form a mini-batch. Batching fixed-shaped tensor
# inputs is straightforward. For example, batching two images of size 28 x 28
# gives a tensor of shape 2 x 28 x 28. By contrast, batching graph inputs
# has two challenges:
#
# * Graphs are sparse.
# * Graphs can vary in size. For example, they can have different numbers of nodes and edges.
#
# To address this, DGL provides a :func:`dgl.batch` API. It leverages the idea that
# a batch of graphs can be viewed as a large graph that has many disjoint
# connected components. Below is a visualization that gives the general idea.
#
# .. image:: https://data.dgl.ai/tutorial/batch/batch.png
#     :width: 400pt
#     :align: center
#
# The return type of :func:`dgl.batch` is still a graph. In the same way, 
# a batch of tensors is still a tensor. This means that any code that works
# for one graph immediately works for a batch of graphs. More importantly,
# because DGL processes messages on all nodes and edges in parallel, this greatly
# improves efficiency.
#
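# As a quick illustration, here is a minimal sketch of :func:`dgl.batch` on two
# toy graphs (using the :func:`dgl.graph` constructor available in recent DGL).
# The batched result is a single graph whose connected components are the
# original graphs.

g1 = dgl.graph(([0, 1], [1, 2]))        # a path graph with 3 nodes
g2 = dgl.graph(([0, 1, 2], [1, 2, 0]))  # a cycle graph with 3 nodes
bg_demo = dgl.batch([g1, g2])
print(bg_demo.num_nodes())        # 6, the total number of nodes
print(bg_demo.batch_num_nodes())  # tensor([3, 3]), the per-graph node counts

###############################################################################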
# Graph classifier
# ----------------
# Graph classification proceeds as follows.
#
# .. image:: https://data.dgl.ai/tutorial/batch/graph_classifier.png
#
# From a batch of graphs, perform message passing and graph convolution
# so that nodes can communicate with their neighbors. After message passing, compute a
# tensor of graph representations from the node (and edge) attributes. This step might
# be called readout or aggregation. Finally, the graph
# representations are fed into a classifier :math:`g` to predict the graph labels.
#
# The graph convolution layer can be found in the ``dgl.nn.<backend>`` submodule.

from dgl.nn.pytorch import GraphConv

###############################################################################
# Readout and classification
# --------------------------
# For this demonstration, consider the initial node features to be the node degrees.
# After two rounds of graph convolution, perform a graph readout by averaging
# over all node features for each graph in the batch.
#
# .. math::
#
#    h_g=\frac{1}{|\mathcal{V}|}\sum_{v\in\mathcal{V}}h_{v}
#
# In DGL, :func:`dgl.mean_nodes` handles this task for a batch of
# graphs with variable size. You then feed the graph representations into a
# classifier with one linear layer to obtain pre-softmax logits.
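# As an illustrative sanity check (a minimal sketch on toy graphs),
# :func:`dgl.mean_nodes` averages the node feature ``'h'`` within each graph of
# a batch and returns one row per graph.

g_demo = dgl.batch([dgl.graph(([0], [1])), dgl.graph(([0, 1], [1, 2]))])
g_demo.ndata['h'] = torch.arange(5, dtype=torch.float32).view(-1, 1)
# First graph has 2 nodes with features [0, 1]; second has 3 nodes with [2, 3, 4].
print(dgl.mean_nodes(g_demo, 'h'))  # tensor([[0.5000], [3.0000]])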

import torch.nn as nn
import torch.nn.functional as F

class Classifier(nn.Module):
    def __init__(self, in_dim, hidden_dim, n_classes):
        super(Classifier, self).__init__()
        self.conv1 = GraphConv(in_dim, hidden_dim)
        self.conv2 = GraphConv(hidden_dim, hidden_dim)
        self.classify = nn.Linear(hidden_dim, n_classes)

    def forward(self, g):
        # Use the node degree as the initial node feature. For undirected graphs,
        # the in-degree is the same as the out-degree.
        h = g.in_degrees().view(-1, 1).float()
        # Apply graph convolution and the activation function.
        h = F.relu(self.conv1(g, h))
        h = F.relu(self.conv2(g, h))
        g.ndata['h'] = h
        # Calculate graph representation by averaging all the node representations.
        hg = dgl.mean_nodes(g, 'h')
        return self.classify(hg)
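
# As a quick shape check (illustrative), an untrained classifier applied to a
# small batched graph returns one row of logits per input graph.
demo_model = Classifier(1, 16, 8)
demo_bg = dgl.batch([dataset[i][0] for i in range(4)])
print(demo_model(demo_bg).shape)  # torch.Size([4, 8])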

###############################################################################
# Setup and training
# ------------------
# Create a synthetic dataset of :math:`400` graphs, each with :math:`10` to
# :math:`20` nodes. :math:`320` graphs constitute a training set and
# :math:`80` graphs constitute a test set.

import torch.optim as optim
from dgl.dataloading import GraphDataLoader

# Create training and test sets.
trainset = MiniGCDataset(320, 10, 20)
testset = MiniGCDataset(80, 10, 20)
# Use DGL's GraphDataLoader, which by default handles the
# graph batching operation for every mini-batch.
data_loader = GraphDataLoader(trainset, batch_size=32, shuffle=True)
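
# Peek at one mini-batch (illustrative): each element yielded by the loader is
# a batched graph paired with a tensor of graph labels.
batched_graph, batch_labels = next(iter(data_loader))
print('Graphs in batch:', batched_graph.batch_size)
print('Labels:', batch_labels)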

# Create model
model = Classifier(1, 256, trainset.num_classes)
loss_func = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
model.train()

epoch_losses = []
for epoch in range(80):
    epoch_loss = 0
    for batch_idx, (bg, label) in enumerate(data_loader):
        prediction = model(bg)
        loss = loss_func(prediction, label)
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        epoch_loss += loss.detach().item()
    epoch_loss /= (batch_idx + 1)
    print('Epoch {}, loss {:.4f}'.format(epoch, epoch_loss))
    epoch_losses.append(epoch_loss)

###############################################################################
# The learning curve of a run is presented below.

plt.title('cross entropy averaged over minibatches')
plt.plot(epoch_losses)
plt.show()

###############################################################################
# The trained model is evaluated on the test set created above. Note that the
# running time of this tutorial is restricted. With longer training, you are
# likely to get a higher accuracy (:math:`80` % ~ :math:`90` %) than the ones
# printed below.

model.eval()
# Convert a list of tuples to two lists
test_X, test_Y = map(list, zip(*testset))
test_bg = dgl.batch(test_X)
test_Y = torch.tensor(test_Y).float().view(-1, 1)
probs_Y = torch.softmax(model(test_bg), 1)
sampled_Y = torch.multinomial(probs_Y, 1)
argmax_Y = torch.max(probs_Y, 1)[1].view(-1, 1)
print('Accuracy of sampled predictions on the test set: {:.4f}%'.format(
    (test_Y == sampled_Y.float()).sum().item() / len(test_Y) * 100))
print('Accuracy of argmax predictions on the test set: {:.4f}%'.format(
    (test_Y == argmax_Y.float()).sum().item() / len(test_Y) * 100))

###############################################################################
# The animation here plots the probability that a trained model predicts the correct graph type.
#
# .. image:: https://data.dgl.ai/tutorial/batch/test_eval4.gif
#
# To understand the node and graph representations that a trained model learned,
# we use `t-SNE <https://lvdmaaten.github.io/tsne/>`_ for dimensionality reduction
# and visualization.
#
# .. image:: https://data.dgl.ai/tutorial/batch/tsne_node2.png
#     :align: center
#
# .. image:: https://data.dgl.ai/tutorial/batch/tsne_graph2.png
#     :align: center
#
# The two small figures on the top separately visualize the node representations after one and two
# layers of graph convolution. The figure on the bottom visualizes
# the pre-softmax logits of each graph, that is, the graph representations.
#
# While the visualization does suggest some clustering effects of the node features,
# you would not expect a perfect result, since node degrees are deterministic for
# these node features. The graph features are better separated.
#
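# Below is a minimal sketch of how such a visualization could be produced for
# the graph representations (assuming scikit-learn is installed; the exact
# figures above were generated separately).

from sklearn.manifold import TSNE

with torch.no_grad():
    # Graph representations (pre-softmax logits) for the batched test set.
    graph_repr = model(test_bg)
# Reduce to two dimensions for plotting, colored by the true class labels.
emb_2d = TSNE(n_components=2).fit_transform(graph_repr.numpy())
plt.scatter(emb_2d[:, 0], emb_2d[:, 1], c=test_Y.view(-1).numpy(), s=10)
plt.title('t-SNE of graph representations')
plt.show()

###############################################################################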
# What's next?
# ------------
# Graph classification with graph neural networks is still a young field,
# waiting for more exciting discoveries. The work requires
# mapping different graphs to different embeddings, while preserving
# their structural similarity in the embedding space. To learn more about it, see
# `How Powerful Are Graph Neural Networks? <https://arxiv.org/abs/1810.00826>`_, a research paper
# published at the International Conference on Learning Representations 2019.
#
# For more examples about batched graph processing, see the following:
#
# * Tutorials for `Tree LSTM <https://docs.dgl.ai/tutorials/models/2_small_graph/3_tree-lstm.html>`_ and `Deep Generative Models of Graphs <https://docs.dgl.ai/tutorials/models/3_generative_model/5_dgmg.html>`_
# * An example implementation of `Junction Tree VAE <https://github.com/dmlc/dgl/tree/master/examples/pytorch/jtnn>`_