6_line_graph.py 25.5 KB
Newer Older
1
2
3
"""
.. _model-line-graph:

4
Line Graph Neural Network
5
6
7
8
=========================

**Author**: `Qi Huang <https://github.com/HQ01>`_, Yu Gai,
`Minjie Wang <https://jermainewang.github.io/>`_, Zheng Zhang
9
10
11
12
13
14
15
16

.. warning::

    The tutorial aims at gaining insights into the paper, with code as a mean
    of explanation. The implementation thus is NOT optimized for running
    efficiency. For recommended implementation, please refer to the `official
    examples <https://github.com/dmlc/dgl/tree/master/examples>`_.

17
18
19
20
"""

###########################################################################################
# 
21
22
23
# In this tutorial, you learn how to solve community detection tasks by implementing a line
# graph neural network (LGNN). Community detection, or graph clustering, consists of partitioning
# the vertices in a graph into clusters in which nodes are more similar to
24
# one another.
25
26
27
28
29
30
31
# 
# In the :doc:`Graph convolutinal network tutorial <1_gcn>`, you learned how to classify the nodes of an input
# graph in a semi-supervised setting. You used a graph convolutional neural network (GCN)
# as an embedding mechanism for graph features.
# 
# To generalize a graph neural network (GNN) into supervised community detection, a line-graph based 
# variation of GNN is introduced in the research paper 
32
# `Supervised Community Detection with Line Graph Neural Networks <https://arxiv.org/abs/1705.08415>`__. 
33
34
35
# One of the highlights of the model is
# to augment the straightforward GNN architecture so that it operates on
# a line graph of edge adjacencies, defined with a non-backtracking operator.
36
#
37
38
# A line graph neural network (LGNN) shows how DGL can implement an advanced graph algorithm by 
# mixing basic tensor operations, sparse-matrix multiplication, and message-
39
40
# passing APIs.
#
41
42
# In the following sections, you learn about community detection, line
# graphs, LGNN, and its implementation.
43
#
44
# Supervised community detection task with the Cora dataset
45
# --------------------------------------------
46
# Community detection
47
# ~~~~~~~~~~~~~~~~~~~~
48
49
50
# In a community detection task, you cluster similar nodes instead of
# labeling them. The node similarity is typically described as having higher inner
# density within each cluster.
51
52
53
54
#
# What's the difference between community detection and node classification?
# Comparing to node classification, community detection focuses on retrieving
# cluster information in the graph, rather than assigning a specific label to
brett koonce's avatar
brett koonce committed
55
# a node. For example, as long as a node is clustered with its community
56
57
58
59
60
61
62
63
64
65
# members, it doesn't matter whether the node is assigned as "community A",
# or "community B", while assigning all "great movies" to label "bad movies"
# will be a disaster in a movie network classification task.
#
# What's the difference then, between a community detection algorithm and
# other clustering algorithm such as k-means? Community detection algorithm operates on
# graph-structured data. Comparing to k-means, community detection leverages
# graph structure, instead of simply clustering nodes based on their
# features.
#
66
# Cora dataset
67
# ~~~~~
68
69
70
71
72
# To be consistent with the GCN tutorial, 
# you use the `Cora dataset <https://linqs.soe.ucsc.edu/data>`__ 
# to illustrate a simple community detection task. Cora is a scientific publication dataset, 
# with 2708 papers belonging to seven  
# different machine learning fields. Here, you formulate Cora as a 
73
74
# directed graph, with each node being a paper, and each edge being a 
# citation link (A->B means A cites B). Here is a visualization of the whole 
75
# Cora dataset.
76
77
78
79
80
81
82
#
# .. figure:: https://i.imgur.com/X404Byc.png
#    :alt: cora
#    :height: 400px
#    :width: 500px
#    :align: center
#
83
84
# Cora naturally contains seven classes, and statistics below show that each
# class does satisfy our assumption of community, i.e. nodes of same class
85
86
# class have higher connection probability among them than with nodes of different class.
# The following code snippet verifies that there are more intra-class edges
87
# than inter-class.
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102

import torch
import torch as th
import torch.nn as nn
import torch.nn.functional as F

import dgl
from dgl.data import citation_graph as citegrh

data = citegrh.load_cora()

G = dgl.DGLGraph(data.graph)
labels = th.tensor(data.labels)

# find all the nodes labeled with class 0
103
label0_nodes = th.nonzero(labels == 0, as_tuple=False).squeeze()
104
105
106
107
# find all the edges pointing to class 0 nodes
src, _ = G.in_edges(label0_nodes)
src_labels = labels[src]
# find all the edges whose both endpoints are in class 0
108
intra_src = th.nonzero(src_labels == 0, as_tuple=False)
109
110
111
print('Intra-class edges percent: %.4f' % (len(intra_src) / len(src_labels)))

###########################################################################################
112
# Binary community subgraph from Cora with a test dataset
113
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
114
# Without loss of generality, in this tutorial you limit the scope of the
115
116
117
118
# task to binary community detection.
# 
# .. note::
#
119
120
#    To create a practice binary-community dataset from Cora, first extract
#    all two-class pairs from the original Cora seven classes. For each pair, you
121
#    treat each class as one community, and find the largest subgraph that
122
123
#    at least contains one cross-community edge as the training example. As
#    a result, there are a total of 21 training samples in this small dataset.
124
#
125
# With the following code, you can visualize one of the training samples and its community structure.
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143

import networkx as nx
import matplotlib.pyplot as plt

train_set = dgl.data.CoraBinary()
G1, pmpd1, label1 = train_set[1]
nx_G1 = G1.to_networkx()

def visualize(labels, g):
    pos = nx.spring_layout(g, seed=1)
    plt.figure(figsize=(8, 8))
    plt.axis('off')
    nx.draw_networkx(g, pos=pos, node_size=50, cmap=plt.get_cmap('coolwarm'),
                     node_color=labels, edge_color='k',
                     arrows=False, width=0.5, style='dotted', with_labels=False)
visualize(label1, nx_G1)

###########################################################################################
144
145
# To learn more, go the original research paper to see how to generalize
# to multiple communities case.
146
#
147
# Community detection in a supervised setting
148
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
149
150
151
# The community detection problem could be tackled with both supervised and
# unsupervised approaches. You can formulate
# community detection in a supervised setting as follows:
152
153
154
155
156
157
158
159
160
161
162
163
#
# - Each training example consists of :math:`(G, L)`, where :math:`G` is a
#   directed graph :math:`(V, E)`. For each node :math:`v` in :math:`V`, we
#   assign a ground truth community label :math:`z_v \in \{0,1\}`.
# - The parameterized model :math:`f(G, \theta)` predicts a label set
#   :math:`\tilde{Z} = f(G)` for nodes :math:`V`.
# - For each example :math:`(G,L)`, the model learns to minimize a specially
#   designed loss function (equivariant loss) :math:`L_{equivariant} =
#   (\tilde{Z},Z)`
#
# .. note::
#
164
#    In this supervised setting, the model naturally predicts a label for
165
#    each community. However, community assignment should be equivariant to
brett koonce's avatar
brett koonce committed
166
167
#    label permutations. To achieve this, in each forward process, we take
#    the minimum among losses calculated from all possible permutations of
168
169
170
171
172
173
174
175
#    labels.
#
#    Mathematically, this means
#    :math:`L_{equivariant} = \underset{\pi \in S_c} {min}-\log(\hat{\pi}, \pi)`,
#    where :math:`S_c` is the set of all permutations of labels, and
#    :math:`\hat{\pi}` is the set of predicted labels,
#    :math:`- \log(\hat{\pi},\pi)` denotes negative log likelihood.
#
176
#    For instance, for a sample graph with node :math:`\{1,2,3,4\}` and
177
178
179
180
#    community assignment :math:`\{A, A, A, B\}`, with each node's label
#    :math:`l \in \{0,1\}`,The group of all possible permutations
#    :math:`S_c = \{\{0,0,0,1\}, \{1,1,1,0\}\}`.
# 
181
# Line graph neural network key ideas
182
# ------------------------------------
183
# An key innovation in this topic is the use of a line graph.
184
# Unlike models in previous tutorials, message passing happens not only on the
185
186
# original graph, e.g. the binary community subgraph from Cora, but also on the
# line graph associated with the original graph.
187
#
188
# What is a line-graph?
189
190
# ~~~~~~~~~~~~~~~~~~~~~
# In graph theory, line graph is a graph representation that encodes the
brett koonce's avatar
brett koonce committed
191
# edge adjacency structure in the original graph.
192
193
194
#
# Specifically, a line-graph :math:`L(G)` turns an edge of the original graph `G`
# into a node. This is illustrated with the graph below (taken from the
195
# research paper).
196
197
198
199
200
201
202
203
204
205
# 
# .. figure:: https://i.imgur.com/4WO5jEm.png
#    :alt: lg
#    :align: center
#
# Here, :math:`e_{A}:= (i\rightarrow j)` and :math:`e_{B}:= (j\rightarrow k)`
# are two edges in the original graph :math:`G`. In line graph :math:`G_L`,
# they correspond to nodes :math:`v^{l}_{A}, v^{l}_{B}`.
#
# The next natural question is, how to connect nodes in line-graph? How to
206
# connect two edges? Here, we use the following connection rule:
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
#
# Two nodes :math:`v^{l}_{A}`, :math:`v^{l}_{B}` in `lg` are connected if
# the corresponding two edges :math:`e_{A}, e_{B}` in `g` share one and only 
# one node:
# :math:`e_{A}`'s destination node is :math:`e_{B}`'s source node
# (:math:`j`).
# 
# .. note::
#
#    Mathematically, this definition corresponds to a notion called non-backtracking
#    operator:
#    :math:`B_{(i \rightarrow j), (\hat{i} \rightarrow \hat{j})}`
#    :math:`= \begin{cases}
#    1 \text{ if } j = \hat{i}, \hat{j} \neq i\\
#    0 \text{ otherwise} \end{cases}`
#    where an edge is formed if :math:`B_{node1, node2} = 1`.
#
#
225
# One layer in LGNN, algorithm structure
226
227
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
228
229
230
# LGNN chains together a series of line graph neural network layers. The graph
# representation :math:`x` and its line graph companion :math:`y` evolve with
# the dataflow as follows.
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
# 
# .. figure:: https://i.imgur.com/bZGGIGp.png
#    :alt: alg
#    :align: center
#
# At the :math:`k`-th layer, the :math:`i`-th neuron of the :math:`l`-th
# channel updates its embedding :math:`x^{(k+1)}_{i,l}` with:
#
# .. math::
#    \begin{split}
#    x^{(k+1)}_{i,l} ={}&\rho[x^{(k)}_{i}\theta^{(k)}_{1,l}
#    +(Dx^{(k)})_{i}\theta^{(k)}_{2,l} \\
#    &+\sum^{J-1}_{j=0}(A^{2^{j}}x^{k})_{i}\theta^{(k)}_{3+j,l}\\
#    &+[\{\text{Pm},\text{Pd}\}y^{(k)}]_{i}\theta^{(k)}_{3+J,l}] \\
#    &+\text{skip-connection}
#    \qquad i \in V, l = 1,2,3, ... b_{k+1}/2
#    \end{split}
#
# Then, the line-graph representation :math:`y^{(k+1)}_{i,l}` with,
#
# .. math::
#
#    \begin{split}
#    y^{(k+1)}_{i',l^{'}} = {}&\rho[y^{(k)}_{i^{'}}\gamma^{(k)}_{1,l^{'}}+
#    (D_{L(G)}y^{(k)})_{i^{'}}\gamma^{(k)}_{2,l^{'}}\\
#    &+\sum^{J-1}_{j=0}(A_{L(G)}^{2^{j}}y^{k})_{i}\gamma^{(k)}_{3+j,l^{'}}\\
#    &+[\{\text{Pm},\text{Pd}\}^{T}x^{(k+1)}]_{i^{'}}\gamma^{(k)}_{3+J,l^{'}}]\\
#    &+\text{skip-connection}
#    \qquad i^{'} \in V_{l}, l^{'} = 1,2,3, ... b^{'}_{k+1}/2
#    \end{split}
#
# Where :math:`\text{skip-connection}` refers to performing the same operation without the non-linearity
# :math:`\rho`, and with linear projection :math:`\theta_\{\frac{b_{k+1}}{2} + 1, ..., b_{k+1}-1, b_{k+1}\}`
# and :math:`\gamma_\{\frac{b_{k+1}}{2} + 1, ..., b_{k+1}-1, b_{k+1}\}`.
#
# Implement LGNN in DGL
# ---------------------
268
269
270
271
272
273
274
275
276
277
278
279
280
281
# Even though the equations in the previous section might seem intimidating, 
# it helps to understand the following information before you implement the LGNN.
# 
# The two equations are symmetric and can be implemented as two instances
# of the same class with different parameters.
# The first equation operates on graph representation :math:`x`,
# whereas the second operates on line-graph
# representation :math:`y`. Let us denote this abstraction as :math:`f`. Then
# the first is :math:`f(x,y; \theta_x)`, and the second
# is :math:`f(y,x, \theta_y)`. That is, they are parameterized to compute
# representations of the original graph and its
# companion line graph, respectively.
#
# Each equation consists of four terms. Take the first one as an example, which follows.
282
283
284
285
286
287
288
289
290
291
#
#   - :math:`x^{(k)}\theta^{(k)}_{1,l}`, a linear projection of previous
#     layer's output :math:`x^{(k)}`, denote as :math:`\text{prev}(x)`.
#   - :math:`(Dx^{(k)})\theta^{(k)}_{2,l}`, a linear projection of degree
#     operator on :math:`x^{(k)}`, denote as :math:`\text{deg}(x)`.
#   - :math:`\sum^{J-1}_{j=0}(A^{2^{j}}x^{(k)})\theta^{(k)}_{3+j,l}`,
#     a summation of :math:`2^{j}` adjacency operator on :math:`x^{(k)}`,
#     denote as :math:`\text{radius}(x)`
#   - :math:`[\{Pm,Pd\}y^{(k)}]\theta^{(k)}_{3+J,l}`, fusing another
#     graph's embedding information using incidence matrix
brett koonce's avatar
brett koonce committed
292
#     :math:`\{Pm, Pd\}`, followed with a linear projection,
293
294
#     denote as :math:`\text{fuse}(y)`.
#
295
296
297
# Each of the terms are performed again with different
# parameters, and without the nonlinearity after the sum.
# Therefore, :math:`f` could be written as:
298
299
300
301
302
303
304
305
# 
#   .. math::
#      \begin{split}
#      f(x^{(k)},y^{(k)}) = {}\rho[&\text{prev}(x^{(k-1)}) + \text{deg}(x^{(k-1)}) +\text{radius}(x^{k-1})
#      +\text{fuse}(y^{(k)})]\\
#      +&\text{prev}(x^{(k-1)}) + \text{deg}(x^{(k-1)}) +\text{radius}(x^{k-1}) +\text{fuse}(y^{(k)})
#      \end{split}
#
306
# Two equations are chained-up in the following order:
307
308
309
310
311
312
313
# 
#   .. math::
#      \begin{split}
#      x^{(k+1)} = {}& f(x^{(k)}, y^{(k)})\\
#      y^{(k+1)} = {}& f(y^{(k)}, x^{(k+1)})
#      \end{split}
# 
314
315
# Keep in mind the listed observations in this overview and proceed to implementation.
# An important point is that you use different strategies for the noted terms.
316
317
# 
# .. note::
318
319
320
321
322
323
324
#    You can understand :math:`\{Pm, Pd\}` more thoroughly with this explanation. 
#    Roughly speaking, there is a relationship between how :math:`g` and
#    :math:`lg` (the line graph) work together with loopy brief propagation.
#    Here, you implement :math:`\{Pm, Pd\}` as a SciPy COO sparse matrix in the dataset,
#    and stack them as tensors when batching. Another batching solution is to
#    treat :math:`\{Pm, Pd\}` as the adjacency matrix of a bipartite graph, which maps
#    line graph's feature to graph's, and vice versa.
325
326
327
#
# Implementing :math:`\text{prev}` and :math:`\text{deg}` as tensor operation
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
328
329
# Linear projection and degree operation are both simply matrix
# multiplication. Write them as PyTorch tensor operations.
330
#
331
# In ``__init__``, you define the projection variables.
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
# 
# ::
# 
#    self.linear_prev = nn.Linear(in_feats, out_feats)
#    self.linear_deg = nn.Linear(in_feats, out_feats)
# 
#
# In ``forward()``, :math:`\text{prev}` and :math:`\text{deg}` are the same
# as any other PyTorch tensor operations.
# 
# ::
# 
#    prev_proj = self.linear_prev(feat_a)
#    deg_proj = self.linear_deg(deg * feat_a)
# 
# Implementing :math:`\text{radius}` as message passing in DGL
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
349
350
# As discussed in GCN tutorial, you can formulate one adjacency operator as
# doing one-step message passing. As a generalization, :math:`2^j` adjacency
351
352
# operations can be formulated as performing :math:`2^j` step of message
# passing. Therefore, the summation is equivalent to summing nodes'
brett koonce's avatar
brett koonce committed
353
# representation of :math:`2^j, j=0, 1, 2..` step message passing, i.e.
354
# gathering information in :math:`2^{j}` neighborhood of each node.
355
#
356
357
# In ``__init__``, define the projection variables used in each
# :math:`2^j` steps of message passing.
358
359
360
361
362
363
# 
# ::
# 
#   self.linear_radius = nn.ModuleList(
#           [nn.Linear(in_feats, out_feats) for i in range(radius)])
#
364
365
366
# In ``__forward__``, use following function ``aggregate_radius()`` to
# gather data from multiple hops. This can be seen in the following code. 
# Note that the ``update_all`` is called multiple times.
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387

# Return a list containing features gathered from multiple radius.
import dgl.function as fn
def aggregate_radius(radius, g, z):
    # initializing list to collect message passing result
    z_list = []
    g.ndata['z'] = z
    # pulling message from 1-hop neighbourhood
    g.update_all(fn.copy_src(src='z', out='m'), fn.sum(msg='m', out='z'))
    z_list.append(g.ndata['z'])
    for i in range(radius - 1):
        for j in range(2 ** i):
            #pulling message from 2^j neighborhood
            g.update_all(fn.copy_src(src='z', out='m'), fn.sum(msg='m', out='z'))
        z_list.append(g.ndata['z'])
    return z_list

#########################################################################
# Implementing :math:`\text{fuse}` as sparse matrix multiplication
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# :math:`\{Pm, Pd\}` is a sparse matrix with only two non-zero entries on
388
# each column. Therefore, you construct it as a sparse matrix in the dataset,
389
390
391
392
393
394
395
396
397
398
# and implement :math:`\text{fuse}` as a sparse matrix multiplication.
#
# in ``__forward__``:
# 
# ::
# 
#   fuse = self.linear_fuse(th.mm(pm_pd, feat_b))
#
# Completing :math:`f(x, y)`
# ~~~~~~~~~~~~~~~~~~~~~~~~~~
399
400
# Finally, the following shows how to sum up all the terms together, pass it to skip connection, and
# batch norm.
401
402
403
404
405
# 
# ::
#
#   result = prev_proj + deg_proj + radius_proj + fuse
# 
406
# Pass result to skip connection. 
407
408
409
410
411
# 
# ::
# 
#   result = th.cat([result[:, :n], F.relu(result[:, n:])], 1)
# 
412
# Then pass the result to batch norm.
413
414
415
416
417
418
# 
# ::
# 
#   result = self.bn(result) #Batch Normalization.
# 
#
419
# Here is the complete code for one LGNN layer's abstraction :math:`f(x,y)`
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
class LGNNCore(nn.Module):
    def __init__(self, in_feats, out_feats, radius):
        super(LGNNCore, self).__init__()
        self.out_feats = out_feats
        self.radius = radius

        self.linear_prev = nn.Linear(in_feats, out_feats)
        self.linear_deg = nn.Linear(in_feats, out_feats)
        self.linear_radius = nn.ModuleList(
                [nn.Linear(in_feats, out_feats) for i in range(radius)])
        self.linear_fuse = nn.Linear(in_feats, out_feats)
        self.bn = nn.BatchNorm1d(out_feats)

    def forward(self, g, feat_a, feat_b, deg, pm_pd):
        # term "prev"
        prev_proj = self.linear_prev(feat_a)
        # term "deg"
        deg_proj = self.linear_deg(deg * feat_a)

        # term "radius"
        # aggregate 2^j-hop features
        hop2j_list = aggregate_radius(self.radius, g, feat_a)
        # apply linear transformation
        hop2j_list = [linear(x) for linear, x in zip(self.linear_radius, hop2j_list)]
        radius_proj = sum(hop2j_list)

        # term "fuse"
        fuse = self.linear_fuse(th.mm(pm_pd, feat_b))

        # sum them together
        result = prev_proj + deg_proj + radius_proj + fuse

        # skip connection and batch norm
        n = self.out_feats // 2
        result = th.cat([result[:, :n], F.relu(result[:, n:])], 1)
        result = self.bn(result)

        return result

##############################################################################################################
460
# Chain-up LGNN abstractions as an LGNN layer
461
462
463
464
465
466
467
468
469
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# To implement:
# 
# .. math::
#    \begin{split}
#    x^{(k+1)} = {}& f(x^{(k)}, y^{(k)})\\
#    y^{(k+1)} = {}& f(y^{(k)}, x^{(k+1)})
#    \end{split}
#
470
# Chain-up two ``LGNNCore`` instances, as in the example code, with different parameters in the forward pass.
471
472
473
474
475
476
477
478
479
480
481
482
483
class LGNNLayer(nn.Module):
    def __init__(self, in_feats, out_feats, radius):
        super(LGNNLayer, self).__init__()
        self.g_layer = LGNNCore(in_feats, out_feats, radius)
        self.lg_layer = LGNNCore(in_feats, out_feats, radius)

    def forward(self, g, lg, x, lg_x, deg_g, deg_lg, pm_pd):
        next_x = self.g_layer(g, x, lg_x, deg_g, pm_pd)
        pm_pd_y = th.transpose(pm_pd, 0, 1)
        next_lg_x = self.lg_layer(lg, lg_x, x, deg_lg, pm_pd_y)
        return next_x, next_lg_x

########################################################################################
484
# Chain-up LGNN layers
485
# ~~~~~~~~~~~~~~~~~~~~
486
# Define an LGNN with three hidden layers, as in the following example.
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
class LGNN(nn.Module):
    def __init__(self, radius):
        super(LGNN, self).__init__()
        self.layer1 = LGNNLayer(1, 16, radius)  # input is scalar feature
        self.layer2 = LGNNLayer(16, 16, radius)  # hidden size is 16
        self.layer3 = LGNNLayer(16, 16, radius)
        self.linear = nn.Linear(16, 2)  # predice two classes

    def forward(self, g, lg, pm_pd):
        # compute the degrees
        deg_g = g.in_degrees().float().unsqueeze(1)
        deg_lg = lg.in_degrees().float().unsqueeze(1)
        # use degree as the input feature
        x, lg_x = deg_g, deg_lg
        x, lg_x = self.layer1(g, lg, x, lg_x, deg_g, deg_lg, pm_pd)
        x, lg_x = self.layer2(g, lg, x, lg_x, deg_g, deg_lg, pm_pd)
        x, lg_x = self.layer3(g, lg, x, lg_x, deg_g, deg_lg, pm_pd)
        return self.linear(x)
#########################################################################################
506
# Training and inference
507
# -----------------------
508
# First load the data.
509
510
511
512
513
514
515
from torch.utils.data import DataLoader
training_loader = DataLoader(train_set,
                             batch_size=1,
                             collate_fn=train_set.collate_fn,
                             drop_last=True)

#######################################################################################
516
517
518
# Next, define the main training loop. Note that each training sample contains
# three objects: A :class:`~dgl.DGLGraph`, a SciPy sparse matrix ``pmpd``, and a label
# array in ``numpy.ndarray``. Generate the line graph by using this command:
519
520
521
522
523
524
#
# ::
# 
#   lg = g.line_graph(backtracking=False)
#
# Note that ``backtracking=False`` is required to correctly simulate non-backtracking
525
# operation. We also define a utility function to convert the SciPy sparse matrix to
526
527
# torch sparse tensor.

528
# Create the model
529
530
531
532
model = LGNN(radius=3)
# define the optimizer
optimizer = th.optim.Adam(model.parameters(), lr=1e-2)

533
# A utility function to convert a scipy.coo_matrix to torch.SparseFloat
534
535
536
537
538
539
def sparse2th(mat):
    value = mat.data
    indices = th.LongTensor([mat.row, mat.col])
    tensor = th.sparse.FloatTensor(indices, th.from_numpy(value).float(), mat.shape)
    return tensor

540
# Train for 20 epochs
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
for i in range(20):
    all_loss = []
    all_acc = []
    for [g, pmpd, label] in training_loader:
        # Generate the line graph.
        lg = g.line_graph(backtracking=False)
        # Create torch tensors
        pmpd = sparse2th(pmpd)
        label = th.from_numpy(label)
        
        # Forward
        z = model(g, lg, pmpd)

        # Calculate loss:
        # Since there are only two communities, there are only two permutations
        #  of the community labels.
        loss_perm1 = F.cross_entropy(z, label)
        loss_perm2 = F.cross_entropy(z, 1 - label)
        loss = th.min(loss_perm1, loss_perm2)

        # Calculate accuracy:
        _, pred = th.max(z, 1)
        acc_perm1 = (pred == label).float().mean()
        acc_perm2 = (pred == 1 - label).float().mean()
        acc = th.max(acc_perm1, acc_perm2)
        all_loss.append(loss.item())
        all_acc.append(acc.item())

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    niters = len(all_loss)
    print("Epoch %d | loss %.4f | accuracy %.4f" % (i,
        sum(all_loss) / niters, sum(all_acc) / niters))

#######################################################################################
# Visualize training progress
# -----------------------------
580
581
# You can visualize the network's community prediction on one training example,
# together with the ground truth. Start this with the following code example.
582
583
584
585
586
587
588
589
590

pmpd1 = sparse2th(pmpd1)
LG1 = G1.line_graph(backtracking=False)
z = model(G1, LG1, pmpd1)
_, pred = th.max(z, 1)
visualize(pred, nx_G1)

#######################################################################################
# Compared with the ground truth. Note that the color might be reversed for the
591
# two communities because the model is for correctly predicting the partitioning.
592
593
594
595
596
597
598
599
visualize(label1, nx_G1)

#########################################
# Here is an animation to better understand the process. (40 epochs)
#
# .. figure:: https://i.imgur.com/KDUyE1S.gif 
#    :alt: lgnn-anim
#
600
601
# Batching graphs for parallelism
# --------------------------------
602
603
#
# LGNN takes a collection of different graphs.
604
# You might consider whether batching can be used for parallelism.
605
#
606
607
608
# Batching has been into the data loader itself.
# In the ``collate_fn`` for PyTorch data loader, graphs are batched using DGL's
# batched_graph API. DGL batches graphs by merging them
609
# into a large graph, with each smaller graph's adjacency matrix being a block
610
# along the diagonal of the large graph's adjacency matrix.  Concatenate
brett koonce's avatar
brett koonce committed
611
# :math`\{Pm,Pd\}` as block diagonal matrix in correspondence to DGL batched
612
613
614
615
616
617
618
619
620
621
# graph API.

def collate_fn(batch):
    graphs, pmpds, labels = zip(*batch)
    batched_graphs = dgl.batch(graphs)
    batched_pmpds = sp.block_diag(pmpds)
    batched_labels = np.concatenate(labels, axis=0)
    return batched_graphs, batched_pmpds, batched_labels

######################################################################################
622
623
# You can find the complete code on Github at 
# `Community Detection with Graph Neural Networks (CDGNN) <https://github.com/dmlc/dgl/tree/master/examples/pytorch/line_graph>`_.