"""
.. _model-capsule:

Capsule Network
===========================

**Author**: Jinjing Zhou, `Jake Zhao <https://cs.nyu.edu/~jakezhao/>`_, Zheng Zhang, Jinyang Li

In this tutorial, you learn how to describe one of the more classical models in terms of graphs, an approach
that offers a different perspective on the model. The tutorial describes how to implement a
`capsule network <http://arxiv.org/abs/1710.09829>`__.

.. warning::

    The tutorial aims at gaining insights into the paper, with code as a means
    of explanation. The implementation is thus NOT optimized for running
    efficiency. For a recommended implementation, please refer to the `official
    examples <https://github.com/dmlc/dgl/tree/master/examples>`_.

"""
#######################################################################################
# Key ideas of Capsule
# --------------------
#
# The Capsule model offers two key ideas: richer representation and dynamic routing.
#
# **Richer representation** -- In classic convolutional networks, a scalar
# value represents the activation of a given feature. By contrast, a
# capsule outputs a vector. The vector's length represents the probability
# of a feature being present. The vector's orientation represents the
# various properties of the feature (such as pose, deformation, and
# texture).
#
# |image0|
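###############################################################################
# As a small illustration with made-up numbers, the output of one capsule can be
# split into its length, read as the presence probability, and its unit
# direction, which carries the feature's properties. Rescaling the vector
# changes the probability but leaves the direction, and so the encoded
# properties, unchanged. The variable names are illustrative only.

import numpy as np

# A hypothetical 4-D capsule output whose length is below 1.
capsule_output = np.array([0.3, -0.4, 0.2, 0.1])

presence_prob = np.linalg.norm(capsule_output)  # length: probability that the feature is present
orientation = capsule_output / presence_prob    # unit vector: properties such as pose or texture

# Halving the vector halves the probability but keeps the orientation.
weaker = 0.5 * capsule_output
print("probability:", presence_prob, "->", np.linalg.norm(weaker))
print("same orientation:", np.allclose(orientation, weaker / np.linalg.norm(weaker)))

###############################################################################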
#
# **Dynamic routing** -- The output of a capsule is sent to
# certain parents in the layer above based on how well the capsule's
# prediction agrees with that of a parent. Such dynamic
# routing-by-agreement generalizes the static routing of max-pooling.
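###############################################################################
# To make "agreement" concrete, here is a toy sketch with invented numbers: one
# child capsule makes a prediction for each of three parents, and the dot
# product between a prediction and that parent's current output serves as the
# agreement score. Starting from equal routing logits, adding the scores and
# taking a softmax over the parents shifts most of the child's output toward the
# parent it agrees with best. This mirrors a single routing step of the paper
# under these assumed values; it is not the DGL code used later.

import numpy as np


def softmax(x):
    e = np.exp(x - np.max(x))  # subtract the max for numerical stability
    return e / e.sum()


# Hypothetical 2-D predictions u_hat[j] of one child for parents j = 0, 1, 2.
child_predictions = np.array([[1.0, 0.0], [0.0, 1.0], [0.7, 0.7]])
# Hypothetical current outputs v[j] of the three parents.
parent_outputs = np.array([[0.9, 0.1], [-0.5, 0.2], [0.1, 0.1]])

routing_logits = np.zeros(3)  # equal logits: uniform routing
print("before:", softmax(routing_logits))
agreement = np.sum(child_predictions * parent_outputs, axis=1)  # one dot product per parent
routing_logits = routing_logits + agreement
coupling = softmax(routing_logits)
print("agreement:", agreement, "after:", coupling)

###############################################################################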
#
# Routing is performed iteratively in every forward pass. Each iteration
# adjusts the routing weights between capsules based on their observed
# agreement, in a manner similar to the k-means algorithm or `competitive
# learning <https://en.wikipedia.org/wiki/Competitive_learning>`__.
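###############################################################################
# Before any graph machinery, the whole iterative procedure can be sketched with
# dense NumPy arrays. This is a minimal sketch, assuming the predictions are
# stored as an array of shape ``(n_in, n_out, f)`` and the routing logits start
# at zero; ``dense_routing`` is a hypothetical helper, not part of DGL. As in a
# k-means step, each pass recomputes the "centers" (the parent outputs) from soft
# assignments and then re-scores the assignments against those centers.

import numpy as np


def dense_routing(u_hat, num_iters=3, eps=1e-9):
    """Routing-by-agreement on dense arrays (num_iters must be at least 1).

    u_hat: (n_in, n_out, f) prediction of in-capsule i for out-capsule j.
    Returns the coupling coefficients (n_in, n_out) used in the last iteration
    and the out-capsule outputs (n_out, f).
    """
    n_in, n_out, _ = u_hat.shape
    b = np.zeros((n_in, n_out))  # routing logits, uniform at the start
    for _ in range(num_iters):
        # Soft assignment: softmax over out-capsules, separately per in-capsule.
        e = np.exp(b - b.max(axis=1, keepdims=True))
        c = e / e.sum(axis=1, keepdims=True)
        # "Centers": s_j = sum_i c_ij * u_hat_j|i, then squash to a length below 1.
        s = np.einsum('ij,ijf->jf', c, u_hat)
        sq = np.sum(s ** 2, axis=1, keepdims=True)
        v = (sq / (1.0 + sq)) * s / np.sqrt(sq + eps)
        # Re-score: raise the logit of every pair whose prediction agrees with v_j.
        b = b + np.einsum('ijf,jf->ij', u_hat, v)
    return c, v


rng = np.random.default_rng(0)
dense_c, dense_v = dense_routing(rng.normal(size=(20, 10, 4)), num_iters=5)
print(dense_c.shape, dense_v.shape)

###############################################################################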
#
# In this tutorial, you see how a capsule's dynamic routing algorithm can be
# naturally expressed as a graph algorithm. The implementation is adapted
# from `Cedric Chee's implementation <https://github.com/cedrickchee/capsule-net-pytorch>`__,
# replacing only the routing layer. This version achieves speed and accuracy
# similar to those of the original implementation.
#
# Model implementation
# ----------------------
#
# Step 1: Setup and graph initialization
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The connectivity between two layers of capsules forms a directed,
# bipartite graph, as shown in the figure below.
#
# |image1|
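###############################################################################
# The edge list of this complete bipartite graph can be built with NumPy using
# the same ``repeat``/``tile`` pattern as ``init_graph`` below. In-capsules get
# ids ``0 .. n_in - 1`` and out-capsules the ids after them, so edge ``k`` runs
# from in-capsule ``k // n_out`` to node ``n_in + k % n_out``. The small sizes
# here are arbitrary.

import numpy as np

n_in_demo, n_out_demo = 3, 2  # arbitrary layer sizes for the illustration

src = np.repeat(np.arange(n_in_demo), n_out_demo)                       # 0 0 1 1 2 2
dst = np.tile(np.arange(n_in_demo, n_in_demo + n_out_demo), n_in_demo)  # 3 4 3 4 3 4

# Every in-capsule is connected to every out-capsule exactly once.
edge_set = set(zip(src.tolist(), dst.tolist()))
print(len(edge_set), "edges:", sorted(edge_set))

# Edge k starts at k // n_out and ends at n_in + k % n_out.
k = np.arange(len(src))
print(np.array_equal(src, k // n_out_demo), np.array_equal(dst, n_in_demo + k % n_out_demo))

###############################################################################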
#
# Each node :math:`j` is associated with the feature :math:`v_j`,
# representing its capsule's output. Each edge is associated with
# the features :math:`b_{ij}` and :math:`\hat{u}_{j|i}`: :math:`b_{ij}`
# determines the routing weight, and :math:`\hat{u}_{j|i}` represents the
# prediction of capsule :math:`i` for capsule :math:`j`.
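###############################################################################
# Because edges are ordered by in-capsule first, any per-edge feature can be
# viewed as an ``(n_in, n_out, ...)`` array with a plain reshape, which is what
# ``DGLRoutingLayer`` relies on when it normalizes over each in-capsule's
# out-edges. A NumPy sketch of the storage with made-up sizes; ``edge_b`` is
# filled with the edge index only to make the layout visible.

import numpy as np

n_in_demo, n_out_demo, f_demo = 3, 2, 4
n_edges = n_in_demo * n_out_demo

node_v = np.zeros((n_in_demo + n_out_demo, f_demo))      # one output v_j per node
edge_u_hat = np.zeros((n_edges, f_demo))                 # one prediction u_hat_j|i per edge
edge_b = np.arange(n_edges, dtype=float).reshape(-1, 1)  # one logit b_ij per edge

# Row i of the reshaped view holds the out-edges of in-capsule i.
b_matrix = edge_b.reshape(n_in_demo, n_out_demo)
row, col = 2, 1
print(b_matrix)
print(b_matrix[row, col] == edge_b[row * n_out_demo + col, 0])

###############################################################################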
#
# Here's how to set up the graph and initialize node and edge features.

import torch.nn as nn
import torch as th
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import dgl


def init_graph(in_nodes, out_nodes, f_size):
    # Edge k connects in-capsule k // out_nodes to out-capsule
    # in_nodes + k % out_nodes, giving a complete bipartite graph.
    u = np.repeat(np.arange(in_nodes), out_nodes)
    v = np.tile(np.arange(in_nodes, in_nodes + out_nodes), in_nodes)
    g = dgl.DGLGraph((u, v))
    # init states: node outputs v_j and routing logits b_ij start at zero
    g.ndata['v'] = th.zeros(in_nodes + out_nodes, f_size)
    g.edata['b'] = th.zeros(in_nodes * out_nodes, 1)
    return g


#########################################################################################
# Step 2: Define message passing functions
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# This is the pseudocode for Capsule's routing algorithm.
#
# |image2|
#
# Implement pseudocode lines 4-7 in the class ``DGLRoutingLayer`` as the following steps:
#
# 1. Calculate coupling coefficients.
#
#    -  The coefficients are the softmax over all out-edges of each in-capsule:
#       :math:`c_{ij} = \text{softmax}_j(b_{ij}) = \frac{\exp(b_{ij})}{\sum_k \exp(b_{ik})}`.
#
# 2. Calculate the weighted sum over all in-capsules.
#
#    -  The output of a capsule is equal to the weighted sum of the predictions of its in-capsules:
#       :math:`s_j=\sum_i c_{ij}\hat{u}_{j|i}`.
#
# 3. Squash outputs.
#
#    -  Squash the length of a capsule's output vector to the range (0, 1), so that it can represent the probability of some feature being present.
#    -  :math:`v_j=\text{squash}(s_j)=\frac{||s_j||^2}{1+||s_j||^2}\frac{s_j}{||s_j||}`
#
# 4. Update weights by the amount of agreement.
#
#    -  The scalar product :math:`\hat{u}_{j|i}\cdot v_j` measures how well capsule :math:`i` agrees with :math:`j`. It is used to update
#       :math:`b_{ij} \leftarrow b_{ij}+\hat{u}_{j|i}\cdot v_j`
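###############################################################################
# Before the DGL version, the four steps can be sketched on a flat edge list with
# NumPy. This is an illustrative sketch, not DGL code: ``np.add.at`` stands in
# for the ``copy_e``/``sum`` message passing of ``update_all`` by adding each
# edge's weighted prediction into its destination out-capsule. Edges are
# assumed to follow the ``init_graph`` layout above (edge ``k`` goes from
# in-capsule ``k // n_out`` to out-capsule ``k % n_out``, counted among the
# out-capsules), and ``edge_routing_step`` is a hypothetical helper.

import numpy as np


def edge_routing_step(b, u_hat, n_in, n_out, eps=1e-9):
    """One routing iteration; b is (n_in * n_out, 1), u_hat is (n_in * n_out, f)."""
    dst = np.tile(np.arange(n_out), n_in)  # destination out-capsule of every edge
    # Step 1 (line 4): softmax over the out-edges of each in-capsule.
    logits = b.reshape(n_in, n_out)
    e = np.exp(logits - logits.max(axis=1, keepdims=True))
    c = (e / e.sum(axis=1, keepdims=True)).reshape(-1, 1)
    # Step 2 (line 5): message c_ij * u_hat_j|i per edge, summed at its destination.
    s = np.zeros((n_out, u_hat.shape[1]))
    np.add.at(s, dst, c * u_hat)
    # Step 3 (line 6): squash every out-capsule's sum.
    sq = np.sum(s ** 2, axis=1, keepdims=True)
    v = (sq / (1.0 + sq)) * s / np.sqrt(sq + eps)
    # Step 4 (line 7): add the agreement u_hat_j|i . v_j to b_ij.
    b = b + np.sum(u_hat * v[dst], axis=1, keepdims=True)
    return b, c, v


rng = np.random.default_rng(1)
n_in_demo, n_out_demo, f_demo = 20, 10, 4
routing_b = np.zeros((n_in_demo * n_out_demo, 1))
routing_u_hat = rng.normal(size=(n_in_demo * n_out_demo, f_demo))
for _ in range(3):
    routing_b, routing_c, routing_v = edge_routing_step(routing_b, routing_u_hat, n_in_demo, n_out_demo)
print(routing_c.reshape(n_in_demo, n_out_demo).sum(axis=1))  # every row sums to 1

###############################################################################
# The ``DGLRoutingLayer`` below implements the same four steps with DGL message
# passing.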

import dgl.function as fn

class DGLRoutingLayer(nn.Module):
    def __init__(self, in_nodes, out_nodes, f_size):
        super(DGLRoutingLayer, self).__init__()
        self.g = init_graph(in_nodes, out_nodes, f_size)
        self.in_nodes = in_nodes
        self.out_nodes = out_nodes
        self.in_indx = list(range(in_nodes))
        self.out_indx = list(range(in_nodes, in_nodes + out_nodes))

    def forward(self, u_hat, routing_num=1):
        self.g.edata['u_hat'] = u_hat

        for _ in range(routing_num):
            # step 1 (line 4): softmax over the out-edges of each in-capsule
            edges_b = self.g.edata['b'].view(self.in_nodes, self.out_nodes)
            self.g.edata['c'] = F.softmax(edges_b, dim=1).view(-1, 1)
            self.g.edata['c u_hat'] = self.g.edata['c'] * self.g.edata['u_hat']

            # step 2 (line 5): sum the weighted predictions c_ij * u_hat_j|i
            # arriving at each node into its 's' feature
            self.g.update_all(fn.copy_e('c u_hat', 'm'), fn.sum('m', 's'))

            # step 3 (line 6): squash the out-capsules' sums into outputs v_j
            self.g.nodes[self.out_indx].data['v'] = self.squash(self.g.nodes[self.out_indx].data['s'], dim=1)

            # step 4 (line 7): tile the out-capsule outputs so that row k matches
            # edge k (edges are ordered by in-capsule), then add the agreement
            v = th.cat([self.g.nodes[self.out_indx].data['v']] * self.in_nodes, dim=0)
            self.g.edata['b'] = self.g.edata['b'] + (self.g.edata['u_hat'] * v).sum(dim=1, keepdim=True)

    @staticmethod
    def squash(s, dim=1):
        sq = th.sum(s ** 2, dim=dim, keepdim=True)
        s_norm = th.sqrt(sq)
        s = (sq / (1.0 + sq)) * (s / s_norm)
        return s
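###############################################################################
# A note on ``squash``: it divides by ``s_norm``, so an all-zero input would give
# ``0/0 = NaN``. With the random predictions used here the weighted sums are
# practically never exactly zero, but adding a small ``eps`` inside the square
# root is a common safeguard. The NumPy sketch below uses an assumed
# ``eps = 1e-9`` (not part of the original code) and also checks the properties
# from step 3: output lengths stay below 1 and directions are preserved.

import numpy as np


def squash_safe(s, axis=1, eps=1e-9):
    """Squash with an eps guard: length ||s||^2 / (1 + ||s||^2), direction of s."""
    sq = np.sum(s ** 2, axis=axis, keepdims=True)
    return (sq / (1.0 + sq)) * s / np.sqrt(sq + eps)


s_demo = np.array([[3.0, 4.0],   # norm 5   -> length close to 25/26
                   [0.0, 0.0],   # zero vector -> stays zero instead of NaN
                   [0.1, 0.0]])  # norm 0.1 -> length close to 0.01/1.01
v_demo = squash_safe(s_demo)
print(np.linalg.norm(v_demo, axis=1))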


############################################################################################################
# Step 3: Testing
# ~~~~~~~~~~~~~~~
#
# Make a simple 20x10 capsule layer: 20 in-capsules, 10 out-capsules, and
# 4-dimensional random predictions ``u_hat``.
in_nodes = 20
out_nodes = 10
f_size = 4
u_hat = th.randn(in_nodes * out_nodes, f_size)
routing = DGLRoutingLayer(in_nodes, out_nodes, f_size)

############################################################################################################
# You can visualize a Capsule network's behavior by monitoring the entropy
# of the coupling coefficients. The entropy should start high and then drop,
# as the weights gradually concentrate on fewer edges.
entropy_list = []
dist_list = []

for i in range(10):
    routing(u_hat)
    dist_matrix = routing.g.edata['c'].view(in_nodes, out_nodes)
    entropy = (-dist_matrix * th.log(dist_matrix)).sum(dim=1)
    entropy_list.append(entropy.data.numpy())
    dist_list.append(dist_matrix.data.numpy())

stds = np.std(entropy_list, axis=1)
means = np.mean(entropy_list, axis=1)
plt.errorbar(np.arange(len(entropy_list)), means, stds, marker='o')
plt.ylabel("Entropy of Weight Distribution")
plt.xlabel("Number of Routing")
plt.xticks(np.arange(len(entropy_list)))
plt.close()
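###############################################################################
# To read the entropy plot, it helps to know its range. For a row of ``n``
# coupling coefficients the entropy is largest, ``log(n)`` (about 2.30 for the
# 10 out-capsules here), when the row is uniform, as it is before the first
# update because all logits start at zero; it approaches 0 as the weight
# concentrates on a single out-capsule. A NumPy check with invented rows
# (``row_entropy`` is a hypothetical helper; its ``eps`` avoids ``log(0)`` in
# case a coefficient ever underflows to exactly 0):

import numpy as np


def row_entropy(p, eps=1e-12):
    """Entropy in nats of every row of a matrix whose rows sum to 1."""
    return -(p * np.log(p + eps)).sum(axis=1)


n_parents = 10
uniform_rows = np.full((1, n_parents), 1.0 / n_parents)
peaked_rows = np.full((1, n_parents), 0.01 / (n_parents - 1))
peaked_rows[0, 0] = 0.99

print(row_entropy(uniform_rows), np.log(n_parents))  # essentially equal: the maximum value
print(row_entropy(peaked_rows))                      # much closer to 0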
############################################################################################################
# |image3|
#
# Alternatively, you can watch how the histogram of the coupling coefficients evolves.

import seaborn as sns
import matplotlib.animation as animation

fig = plt.figure(dpi=150)
fig.clf()
ax = fig.subplots()


def dist_animate(i):
    ax.cla()
    sns.histplot(dist_list[i].reshape(-1), ax=ax)
    ax.set_xlabel("Weight Distribution Histogram")
    ax.set_title("Routing: %d" % (i))


ani = animation.FuncAnimation(fig, dist_animate, frames=len(entropy_list), interval=500)
plt.close()

############################################################################################################
# |image4|
#
# You can also monitor how lower-level capsules gradually attach to one of the
# higher-level ones.

import networkx as nx
from networkx.algorithms import bipartite

g = routing.g.to_networkx()
X, Y = bipartite.sets(g)
height_in = 10
height_out = height_in * 0.8
height_in_y = np.linspace(0, height_in, in_nodes)
height_out_y = np.linspace((height_in - height_out) / 2, height_out, out_nodes)
pos = dict()

fig2 = plt.figure(figsize=(8, 3), dpi=150)
fig2.clf()
ax = fig2.subplots()
pos.update((n, (i, 1)) for i, n in zip(height_in_y, X))  # nodes of X along the line y=1
pos.update((n, (i, 2)) for i, n in zip(height_out_y, Y))  # nodes of Y along the line y=2


def weight_animate(i):
    ax.cla()
    ax.axis('off')
    ax.set_title("Routing: %d  " % i)
    dm = dist_list[i]
    nx.draw_networkx_nodes(g, pos, nodelist=range(in_nodes), node_color='r', node_size=100, ax=ax)
    nx.draw_networkx_nodes(g, pos, nodelist=range(in_nodes, in_nodes + out_nodes), node_color='b', node_size=100, ax=ax)
    for edge in g.edges():
        nx.draw_networkx_edges(g, pos, edgelist=[edge], width=dm[edge[0], edge[1] - in_nodes] * 1.5, ax=ax)


ani2 = animation.FuncAnimation(fig2, weight_animate, frames=len(dist_list), interval=500)
plt.close()
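###############################################################################
# The same attachment can also be read off numerically. A minimal sketch,
# assuming ``coupling`` is an ``(in_nodes, out_nodes)`` matrix of coupling
# coefficients such as one entry of ``dist_list``: each in-capsule's parent is
# the argmax of its row, and the row maximum says how firmly it is attached.
# ``summarize_attachment`` is a hypothetical helper written for this illustration.

import numpy as np


def summarize_attachment(coupling):
    """Return, per in-capsule, its strongest out-capsule and the corresponding weight."""
    parent = np.argmax(coupling, axis=1)
    strength = coupling[np.arange(coupling.shape[0]), parent]
    return parent, strength


demo_coupling = np.array([[0.70, 0.20, 0.10],
                          [0.10, 0.10, 0.80],
                          [0.34, 0.33, 0.33]])  # last row: not yet clearly attached
parents, strengths = summarize_attachment(demo_coupling)
print(parents, strengths)

# With the tutorial's own results, summarize_attachment(dist_list[-1]) would
# report the state after the last recorded routing iteration.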

############################################################################################################
# |image5|
#
# The full code of this visualization is provided on
# `GitHub <https://github.com/dmlc/dgl/blob/master/examples/pytorch/capsule/simple_routing.py>`__. The complete
# code that trains on MNIST is also on `GitHub <https://github.com/dmlc/dgl/tree/tutorial/examples/pytorch/capsule>`__.
#
# .. |image0| image:: https://i.imgur.com/55Ovkdh.png
# .. |image1| image:: https://i.imgur.com/9tc6GLl.png
# .. |image2| image:: https://i.imgur.com/mv1W9Rv.png
# .. |image3| image:: https://i.imgur.com/dMvu7p3.png
# .. |image4| image:: https://github.com/VoVAllen/DGL_Capsule/raw/master/routing_dist.gif
# .. |image5| image:: https://github.com/VoVAllen/DGL_Capsule/raw/master/routing_vis.gif