"""
.. _model-capsule:

Capsule Network
===========================

**Author**: Jinjing Zhou, `Jake Zhao <https://cs.nyu.edu/~jakezhao/>`_, Zheng Zhang, Jinyang Li

In this tutorial, you learn how to describe one of the more classical models in terms of graphs,
an approach that offers a different perspective on the model. The tutorial shows how to implement a
`capsule network <http://arxiv.org/abs/1710.09829>`__ model with DGL.

.. warning::

    The tutorial aims at gaining insights into the paper, with code as a means
    of explanation. The implementation is therefore NOT optimized for running
    efficiency. For a recommended implementation, refer to the `official
    examples <https://github.com/dmlc/dgl/tree/master/examples>`_.

"""
#######################################################################################
# Key ideas of Capsule
# --------------------
#
# The Capsule model offers two key ideas: richer representation and dynamic routing.
#
# **Richer representation** -- In classic convolutional networks, a scalar
# value represents the activation of a given feature. By contrast, a
# capsule outputs a vector. The vector's length represents the probability
# of a feature being present. The vector's orientation represents the
# various properties of the feature (such as pose, deformation, texture
# etc.).
#
# |image0|
#
# **Dynamic routing** -- The output of a capsule is sent to
# certain parents in the layer above based on how well the capsule's
# prediction agrees with that of a parent. Such dynamic
# routing-by-agreement generalizes the static routing of max-pooling.
#
# Routing is accomplished iteratively within each forward pass. Each iteration
# adjusts routing weights between capsules based on their observed agreements,
# in a manner similar to the k-means algorithm or `competitive
# learning <https://en.wikipedia.org/wiki/Competitive_learning>`__.
#
# In this tutorial, you see how a capsule's dynamic routing algorithm can be
# naturally expressed as a graph algorithm. The implementation is adapted
# from `Cedric Chee <https://github.com/cedrickchee/capsule-net-pytorch>`__,
# replacing only the routing layer. This version achieves speed and accuracy
# similar to those of the original implementation.
#
# Model implementation
# ----------------------
# Step 1: Setup and graph initialization
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The connectivity between two layers of capsules forms a directed,
# bipartite graph, as shown in the figure below.
#
# |image1|
#
# Each node :math:`j` is associated with feature :math:`v_j`,
# representing its capsule’s output. Each edge is associated with
# features :math:`b_{ij}` and :math:`\hat{u}_{j|i}`. :math:`b_{ij}`
# determines routing weights, and :math:`\hat{u}_{j|i}` represents the
# prediction of capsule :math:`i` for :math:`j`.
#
# Here's how we set up the graph and initialize node and edge features.

import os
os.environ['DGLBACKEND'] = 'pytorch'
import matplotlib.pyplot as plt
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F

import dgl


def init_graph(in_nodes, out_nodes, f_size):
    # Edge i * out_nodes + j connects in-capsule i to out-capsule in_nodes + j
    u = np.repeat(np.arange(in_nodes), out_nodes)
    v = np.tile(np.arange(in_nodes, in_nodes + out_nodes), in_nodes)
    g = dgl.graph((th.tensor(u), th.tensor(v)))
    # init states
    g.ndata["v"] = th.zeros(in_nodes + out_nodes, f_size)
    g.edata["b"] = th.zeros(in_nodes * out_nodes, 1)
    return g
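
# The edge ordering produced by ``init_graph`` matters later, because the
# routing layer reshapes per-edge tensors to ``(in_nodes, out_nodes)``.
# The following NumPy-only sketch shows that edges are grouped by source
# capsule. ``bipartite_edges`` is a hypothetical helper that mirrors the
# ``u``/``v`` construction above, without DGL.

```python
import numpy as np


def bipartite_edges(in_nodes, out_nodes):
    """Return (src, dst) arrays for a fully connected capsule layer.

    In-capsules get ids 0..in_nodes-1 and out-capsules get ids
    in_nodes..in_nodes+out_nodes-1, mirroring init_graph.
    """
    src = np.repeat(np.arange(in_nodes), out_nodes)
    dst = np.tile(np.arange(in_nodes, in_nodes + out_nodes), in_nodes)
    return src, dst


src, dst = bipartite_edges(2, 3)
print(list(zip(src.tolist(), dst.tolist())))
# [(0, 2), (0, 3), (0, 4), (1, 2), (1, 3), (1, 4)]

# Edges are grouped by source, so edge ids reshaped to (in_nodes, out_nodes)
# put all out-edges of in-capsule i in row i, and column j always points
# at out-capsule j. This is the layout the routing softmax relies on.
edge_grid = np.arange(len(src)).reshape(2, 3)
print(src[edge_grid])      # row i contains only source id i
print(dst[edge_grid] - 2)  # column j contains out-capsule index j
```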


#########################################################################################
# Step 2: Define message passing functions
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# This is the pseudocode for Capsule's routing algorithm.
#
# |image2|
# The class ``DGLRoutingLayer`` implements pseudocode lines 4-7 in the following steps:
#
# 1. Calculate coupling coefficients.
#
#    -  For each in-capsule :math:`i`, the coefficients are the softmax of
#       :math:`b_{ij}` over all of its out-edges:
#       :math:`c_{ij} = \frac{\exp(b_{ij})}{\sum_k \exp(b_{ik})}`.
#
# 2. Calculate weighted sum over all in-capsules.
#
#    -  The total input of out-capsule :math:`j` is the weighted sum of the
#       predictions from its in-capsules:
#       :math:`s_j=\sum_i c_{ij}\hat{u}_{j|i}`
#
# 3. Squash outputs.
#
#    -  Squash the length of a capsule's output vector into the range (0, 1), so that it can represent the probability of a feature being present.
#    -  :math:`v_j=\text{squash}(s_j)=\frac{||s_j||^2}{1+||s_j||^2}\frac{s_j}{||s_j||}`
#
# 4. Update weights by the amount of agreement.
#
#    -  The scalar product :math:`\hat{u}_{j|i}\cdot v_j` can be considered as how well capsule :math:`i` agrees with :math:`j`. It is used to update
#       :math:`b_{ij}=b_{ij}+\hat{u}_{j|i}\cdot v_j`
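#
# To cross-check the four steps, here is a minimal dense-array sketch of
# routing-by-agreement in NumPy rather than DGL. It assumes the predictions
# are stored as an ``(in_caps, out_caps, dim)`` array. The names
# ``routing_iteration`` and ``squash_np`` are illustrative only; they are
# not part of DGL or of the original tutorial.

```python
import numpy as np


def squash_np(s, axis=-1):
    # v = (|s|^2 / (1 + |s|^2)) * s / |s|, as in step 3
    sq = np.sum(s ** 2, axis=axis, keepdims=True)
    return (sq / (1.0 + sq)) * s / np.sqrt(sq)


def routing_iteration(b, u_hat):
    """Run steps 1-4 once on dense arrays.

    b:     (in_caps, out_caps) routing logits b_ij
    u_hat: (in_caps, out_caps, dim) predictions u_hat_{j|i}
    """
    # Step 1: c_ij = softmax of b_ij over the out-capsules j of each i
    e = np.exp(b - b.max(axis=1, keepdims=True))
    c = e / e.sum(axis=1, keepdims=True)
    # Step 2: s_j = sum_i c_ij * u_hat_{j|i}
    s = np.einsum("ij,ijd->jd", c, u_hat)
    # Step 3: v_j = squash(s_j)
    v = squash_np(s)
    # Step 4: b_ij <- b_ij + u_hat_{j|i} . v_j
    b = b + np.einsum("ijd,jd->ij", u_hat, v)
    return b, c, v


rng = np.random.default_rng(0)
demo_u_hat = rng.normal(size=(3, 2, 4))  # 3 in-capsules, 2 out-capsules, 4-d
demo_b = np.zeros((3, 2))                # routing logits start at zero
for _ in range(3):
    demo_b, demo_c, demo_v = routing_iteration(demo_b, demo_u_hat)
print(demo_c.sum(axis=1))                # each row of c sums to 1
print(np.linalg.norm(demo_v, axis=1))    # squashed lengths lie in (0, 1)
```

# Because every logit starts at zero, the first iteration assigns each
# in-capsule's prediction uniformly to all out-capsules. Agreement then
# shifts the weights, much as a k-means step reassigns points to centroids.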

import dgl.function as fn


class DGLRoutingLayer(nn.Module):
    def __init__(self, in_nodes, out_nodes, f_size):
        super(DGLRoutingLayer, self).__init__()
        self.g = init_graph(in_nodes, out_nodes, f_size)
        self.in_nodes = in_nodes
        self.out_nodes = out_nodes
        self.in_indx = list(range(in_nodes))
        self.out_indx = list(range(in_nodes, in_nodes + out_nodes))

    def forward(self, u_hat, routing_num=1):
        self.g.edata["u_hat"] = u_hat

        for r in range(routing_num):
            # step 1 (line 4): normalize over out edges
            edges_b = self.g.edata["b"].view(self.in_nodes, self.out_nodes)
            self.g.edata["c"] = F.softmax(edges_b, dim=1).view(-1, 1)
            self.g.edata["c u_hat"] = self.g.edata["c"] * self.g.edata["u_hat"]

            # step 2 (line 5): each out-capsule sums the weighted predictions
            # c_ij * u_hat_{j|i} arriving on its in-edges
            self.g.update_all(fn.copy_e("c u_hat", "m"), fn.sum("m", "s"))

            # step 3 (line 6)
            self.g.nodes[self.out_indx].data["v"] = self.squash(
                self.g.nodes[self.out_indx].data["s"], dim=1
            )

            # step 4 (line 7)
            v = th.cat(
                [self.g.nodes[self.out_indx].data["v"]] * self.in_nodes, dim=0
            )
            self.g.edata["b"] = self.g.edata["b"] + (
                self.g.edata["u_hat"] * v
            ).sum(dim=1, keepdim=True)

    @staticmethod
    def squash(s, dim=1):
        sq = th.sum(s**2, dim=dim, keepdim=True)
        s_norm = th.sqrt(sq)
        s = (sq / (1.0 + sq)) * (s / s_norm)
        return s
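

# The ``update_all`` call with ``fn.copy_e`` and ``fn.sum`` amounts to a
# scatter-add of edge features onto destination nodes. The NumPy sketch
# below reproduces that reduction for the edge order used by ``init_graph``
# and checks it against a plain reshape-and-sum. ``copy_e_sum`` is a
# hypothetical stand-in, not a DGL function.

```python
import numpy as np


def copy_e_sum(edge_feat, dst, num_nodes):
    """NumPy stand-in for update_all(fn.copy_e(...), fn.sum(...)).

    Each edge sends its own feature as the message, and each node sums the
    messages on its in-edges. Nodes without in-edges keep zeros.
    """
    out = np.zeros((num_nodes,) + edge_feat.shape[1:])
    np.add.at(out, dst, edge_feat)  # unbuffered scatter-add handles repeated ids
    return out


in_caps, out_caps, dim = 3, 2, 4
dst = np.tile(np.arange(in_caps, in_caps + out_caps), in_caps)
rng = np.random.default_rng(0)
weighted = rng.normal(size=(in_caps * out_caps, dim))  # plays the role of "c u_hat"

s_msg = copy_e_sum(weighted, dst, in_caps + out_caps)
# With init_graph's edge order, the same sums come from reshaping to
# (in_caps, out_caps, dim) and summing over the in-capsule axis.
s_dense = weighted.reshape(in_caps, out_caps, dim).sum(axis=0)
print(np.allclose(s_msg[in_caps:], s_dense))  # True
print(np.allclose(s_msg[:in_caps], 0.0))      # True: in-capsules get no messages
```

# In-capsule nodes have no in-edges, so their rows of ``s`` stay zero. This is
# consistent with ``forward`` squashing only the out-capsule nodes:
# ``squash`` divides by the vector norm and would return NaN on an all-zero row.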


############################################################################################################
# Step 3: Testing
# ~~~~~~~~~~~~~~~
#
# Make a simple 20x10 capsule layer.
in_nodes = 20
out_nodes = 10
f_size = 4
u_hat = th.randn(in_nodes * out_nodes, f_size)
routing = DGLRoutingLayer(in_nodes, out_nodes, f_size)

############################################################################################################
# You can visualize a capsule network's behavior by monitoring the entropy
# of the coupling coefficients. The entropy should start high and then drop,
# as the weights gradually concentrate on fewer edges.
entropy_list = []
dist_list = []

for i in range(10):
    routing(u_hat)
    dist_matrix = routing.g.edata["c"].view(in_nodes, out_nodes)
    entropy = (-dist_matrix * th.log(dist_matrix)).sum(dim=1)
    entropy_list.append(entropy.data.numpy())
    dist_list.append(dist_matrix.data.numpy())

stds = np.std(entropy_list, axis=1)
means = np.mean(entropy_list, axis=1)
plt.errorbar(np.arange(len(entropy_list)), means, stds, marker="o")
plt.ylabel("Entropy of Weight Distribution")
plt.xlabel("Number of Routing")
plt.xticks(np.arange(len(entropy_list)))
plt.close()
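
# Why does the curve start where it does? On the first call, all routing
# logits ``b`` are zero, so the softmax gives every out-edge the same weight
# 1/10, and the entropy of each row equals ln(10), about 2.303. The NumPy
# sketch below compares that uniform case with a peaked one.
# ``coupling_entropy`` is a hypothetical helper that applies the same formula
# as the loop above.

```python
import numpy as np


def coupling_entropy(c):
    """Per-row entropy (natural log) of a coupling-coefficient matrix."""
    return -(c * np.log(c)).sum(axis=1)


n_out = 10
uniform = np.full((1, n_out), 1.0 / n_out)  # first call: all b_ij are 0
peaked = np.full((1, n_out), 0.01)
peaked[0, 0] = 1.0 - 0.01 * (n_out - 1)     # 0.91 of the mass on one edge

print(coupling_entropy(uniform))  # ln(10), about 2.303
print(coupling_entropy(peaked))   # roughly 0.50, far lower
```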
############################################################################################################
# |image3|
#
# Alternatively, you can watch how the histogram of coupling coefficients
# evolves over routing iterations.

import matplotlib.animation as animation
import seaborn as sns

fig = plt.figure(dpi=150)
fig.clf()
ax = fig.subplots()


def dist_animate(i):
    ax.cla()
    sns.histplot(dist_list[i].reshape(-1), ax=ax)
    ax.set_xlabel("Weight Distribution Histogram")
    ax.set_title("Routing: %d" % (i))


ani = animation.FuncAnimation(
    fig, dist_animate, frames=len(entropy_list), interval=500
)
plt.close()

############################################################################################################
# |image4|
#
# You can also watch how lower-level capsules gradually attach to one of the
# higher-level capsules.
import networkx as nx
from networkx.algorithms import bipartite

g = routing.g.to_networkx()
X, Y = bipartite.sets(g)
height_in = 10
height_out = height_in * 0.8
height_in_y = np.linspace(0, height_in, in_nodes)
height_out_y = np.linspace((height_in - height_out) / 2, height_out, out_nodes)
pos = dict()

fig2 = plt.figure(figsize=(8, 3), dpi=150)
fig2.clf()
ax = fig2.subplots()
pos.update(
    (n, (i, 1)) for i, n in zip(height_in_y, X)
)  # put nodes from X at x=1
pos.update(
    (n, (i, 2)) for i, n in zip(height_out_y, Y)
)  # put nodes from Y at x=2


def weight_animate(i):
    ax.cla()
    ax.axis("off")
    ax.set_title("Routing: %d  " % i)
    dm = dist_list[i]
    nx.draw_networkx_nodes(
        g, pos, nodelist=range(in_nodes), node_color="r", node_size=100, ax=ax
    )
    nx.draw_networkx_nodes(
        g,
        pos,
        nodelist=range(in_nodes, in_nodes + out_nodes),
        node_color="b",
        node_size=100,
        ax=ax,
    )
    for edge in g.edges():
        nx.draw_networkx_edges(
            g,
            pos,
            edgelist=[edge],
            width=dm[edge[0], edge[1] - in_nodes] * 1.5,
            ax=ax,
        )


ani2 = animation.FuncAnimation(
    fig2, weight_animate, frames=len(dist_list), interval=500
)
plt.close()

############################################################################################################
# |image5|
#
# The full code of this visualization is provided on
# `GitHub <https://github.com/dmlc/dgl/blob/master/examples/pytorch/capsule/simple_routing.py>`__. The complete
# code that trains on MNIST is also on `GitHub <https://github.com/dmlc/dgl/tree/tutorial/examples/pytorch/capsule>`__.
#
# .. |image0| image:: https://i.imgur.com/55Ovkdh.png
# .. |image1| image:: https://i.imgur.com/9tc6GLl.png
# .. |image2| image:: https://i.imgur.com/mv1W9Rv.png
# .. |image3| image:: https://i.imgur.com/dMvu7p3.png
# .. |image4| image:: https://github.com/VoVAllen/DGL_Capsule/raw/master/routing_dist.gif
# .. |image5| image:: https://github.com/VoVAllen/DGL_Capsule/raw/master/routing_vis.gif