"""
.. _model-capsule:

Capsule Network Tutorial
===========================

**Author**: Jinjing Zhou, `Jake Zhao <https://cs.nyu.edu/~jakezhao/>`_, Zheng Zhang, Jinyang Li

It is perhaps a little surprising that some of the more classical models
can also be described in terms of graphs, offering a different
perspective. This tutorial describes how this can be done for the
`capsule network <http://arxiv.org/abs/1710.09829>`__.
"""
#######################################################################################
# Key ideas of Capsule
# --------------------
#
# The Capsule model offers two key ideas.
#
# **Richer representation** In classic convolutional networks, a scalar
# value represents the activation of a given feature. By contrast, a
# capsule outputs a vector. The vector's length represents the probability
# of a feature being present. The vector's orientation represents the
# various properties of the feature (such as pose, deformation, and
# texture).
#
# |image0|
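#
# As a rough illustration of the paragraph above (not part of the original
# implementation), the sketch below splits a capsule output into its length,
# read as a presence probability, and its unit direction, read as the
# feature's properties. The helper name ``describe_capsule`` and the sample
# vector are hypothetical.

```python
import math


def describe_capsule(v):
    """Split a capsule output vector into (length, unit direction)."""
    length = math.sqrt(sum(x * x for x in v))
    if length == 0.0:
        # A zero vector carries no orientation; return it unchanged.
        return 0.0, list(v)
    return length, [x / length for x in v]


# Hypothetical capsule output whose length is about 0.5.
presence, pose = describe_capsule([0.3, 0.4, 0.0])
print(presence)  # presence probability (vector length)
print(pose)      # unit vector encoding the feature's properties
```

###############################################################################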
#
# **Dynamic routing** The output of a capsule is preferentially sent to
# certain parents in the layer above based on how well the capsule's
# prediction agrees with that of a parent. Such dynamic
# "routing-by-agreement" generalizes the static routing of max-pooling.
#
# During training, routing is done iteratively; each iteration adjusts
# "routing weights" between capsules based on their observed agreements,
# in a manner similar to a k-means algorithm or `competitive
# learning <https://en.wikipedia.org/wiki/Competitive_learning>`__.
#
# In this tutorial, we show how a capsule's dynamic routing algorithm can be
# naturally expressed as a graph algorithm. Our implementation is adapted
# from `Cedric Chee <https://github.com/cedrickchee/capsule-net-pytorch>`__,
# replacing only the routing layer. Our version achieves similar speed and
# accuracy.
#
# Model Implementation
# ----------------------
# Step 1: Setup and Graph Initialization
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The connectivity between two layers of capsules forms a directed,
# bipartite graph, as shown in the figure below.
#
# |image1|
#
# Each node :math:`j` is associated with feature :math:`v_j`,
# representing its capsule’s output. Each edge is associated with
# features :math:`b_{ij}` and :math:`\hat{u}_{j|i}`. :math:`b_{ij}`
# determines routing weights, and :math:`\hat{u}_{j|i}` represents the
# prediction of capsule :math:`i` for :math:`j`.
#
# Here's how we set up the graph and initialize node and edge features.

import torch.nn as nn
import torch as th
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import dgl


def init_graph(in_nodes, out_nodes, f_size):
    g = dgl.DGLGraph()
    all_nodes = in_nodes + out_nodes
    g.add_nodes(all_nodes)

    in_indx = list(range(in_nodes))
    out_indx = list(range(in_nodes, in_nodes + out_nodes))
    # add edges using edge broadcasting: one call connects u to every out-node
    for u in in_indx:
        g.add_edges(u, out_indx)

    # init states
    g.ndata['v'] = th.zeros(all_nodes, f_size)
    g.edata['b'] = th.zeros(in_nodes * out_nodes, 1)
    return g
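
###############################################################################
# The reshapes used later in this tutorial rely on the order in which edges
# are created. Assuming edge IDs are assigned in insertion order (an
# assumption about DGL's behavior that this file does not verify), the loop
# in ``init_graph`` places the edge from in-capsule ``i`` to the ``j``-th
# out-capsule at position ``i * out_nodes + j``. The plain-Python sketch
# below mirrors that loop without DGL; ``edge_list`` and ``edge_id`` are
# hypothetical helper names.

```python
def edge_list(in_nodes, out_nodes):
    """Enumerate (src, dst) pairs in the order the init_graph loop adds them."""
    out_indx = list(range(in_nodes, in_nodes + out_nodes))
    return [(u, v) for u in range(in_nodes) for v in out_indx]


def edge_id(i, j, out_nodes):
    """Row-major position of the edge from in-capsule i to the j-th out-capsule."""
    return i * out_nodes + j


edges = edge_list(in_nodes=3, out_nodes=2)
print(edges)
# Edge from in-capsule 2 to the second out-capsule (node ID 3 + 1 = 4).
print(edges[edge_id(2, 1, out_nodes=2)])  # → (2, 4)
```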


#########################################################################################
# Step 2: Define message passing functions
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# This is the pseudocode for Capsule's routing algorithm as given in the
# paper:
#
# |image2|
#
# We implement pseudocode lines 4-7 in the class ``DGLRoutingLayer`` as the
# following steps:
#
# 1. Calculate coupling coefficients:
#
#    -  The coefficients are the softmax over all out-edges of each in-capsule:
#       :math:`c_{ij}=\text{softmax}(b_{ij})=\frac{\exp(b_{ij})}{\sum_k\exp(b_{ik})}`.
#
# 2. Calculate weighted sum over all in-capsules:
#
#    -  The total input to a capsule is the weighted sum of the predictions
#       from its in-capsules:
#       :math:`s_j=\sum_i c_{ij}\hat{u}_{j|i}`.
#
# 3. Squash outputs:
#
#    -  Squash the length of a capsule's output vector to the range (0,1), so
#       that it can represent the probability of some feature being present.
#    -  :math:`v_j=\text{squash}(s_j)=\frac{||s_j||^2}{1+||s_j||^2}\frac{s_j}{||s_j||}`
#
# 4. Update weights by the amount of agreement:
#
#    -  The scalar product :math:`\hat{u}_{j|i}\cdot v_j` measures how well
#       capsule :math:`i` agrees with capsule :math:`j`. It is used to update
#       the routing logits:
#       :math:`b_{ij}=b_{ij}+\hat{u}_{j|i}\cdot v_j`
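#
# The four steps above can be sketched in plain Python for a tiny layer,
# before turning to the DGL version. This is an illustrative sketch only:
# the function names (``softmax``, ``squash``, ``route``) and the nested-list
# layout ``u_hat[i][j]`` (the prediction of in-capsule ``i`` for out-capsule
# ``j``) are assumptions made for this example, not part of the original code.

```python
import math


def softmax(xs):
    """Numerically stable softmax over a list of floats."""
    m = max(xs)
    exps = [math.exp(x - m) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def squash(s):
    """Scale s to length |s|^2 / (1 + |s|^2), keeping its direction."""
    sq = sum(x * x for x in s)
    if sq == 0.0:
        return list(s)
    scale = sq / (1.0 + sq) / math.sqrt(sq)
    return [scale * x for x in s]


def route(u_hat, num_iters=3):
    """Run num_iters >= 1 rounds of routing-by-agreement."""
    n_in, n_out = len(u_hat), len(u_hat[0])
    dim = len(u_hat[0][0])
    b = [[0.0] * n_out for _ in range(n_in)]
    for _ in range(num_iters):
        # Step 1: softmax over each in-capsule's out-edges.
        c = [softmax(row) for row in b]
        # Step 2: weighted sum of the predictions arriving at each out-capsule.
        s = [[sum(c[i][j] * u_hat[i][j][k] for i in range(n_in))
              for k in range(dim)]
             for j in range(n_out)]
        # Step 3: squash each out-capsule's vector.
        v = [squash(s_j) for s_j in s]
        # Step 4: add the agreement (dot product) to the routing logits.
        for i in range(n_in):
            for j in range(n_out):
                b[i][j] += sum(a * w for a, w in zip(u_hat[i][j], v[j]))
    return c, v


# Two in-capsules, two out-capsules: both predict strongly for out-capsule 0.
u_hat = [[[1.0, 0.0], [0.1, -0.1]],
         [[0.9, 0.1], [-0.1, 0.1]]]
c, v = route(u_hat)
print(c)  # coupling coefficients from the last softmax
```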

class DGLRoutingLayer(nn.Module):
    def __init__(self, in_nodes, out_nodes, f_size):
        super(DGLRoutingLayer, self).__init__()
        self.g = init_graph(in_nodes, out_nodes, f_size)
        self.in_nodes = in_nodes
        self.out_nodes = out_nodes
        self.in_indx = list(range(in_nodes))
        self.out_indx = list(range(in_nodes, in_nodes + out_nodes))

    def forward(self, u_hat, routing_num=1):
        self.g.edata['u_hat'] = u_hat

        # step 2 (line 5)
        def cap_message(edges):
            return {'m': edges.data['c'] * edges.data['u_hat']}

        self.g.register_message_func(cap_message)

        def cap_reduce(nodes):
            return {'s': th.sum(nodes.mailbox['m'], dim=1)}

        self.g.register_reduce_func(cap_reduce)

        for r in range(routing_num):
            # step 1 (line 4): normalize over out edges
            edges_b = self.g.edata['b'].view(self.in_nodes, self.out_nodes)
            self.g.edata['c'] = F.softmax(edges_b, dim=1).view(-1, 1)

            # run step 2: send c * u_hat along every edge and sum at each node
            self.g.update_all()

            # step 3 (line 6)
            self.g.nodes[self.out_indx].data['v'] = self.squash(self.g.nodes[self.out_indx].data['s'], dim=1)

            # step 4 (line 7)
            v = th.cat([self.g.nodes[self.out_indx].data['v']] * self.in_nodes, dim=0)
            self.g.edata['b'] = self.g.edata['b'] + (self.g.edata['u_hat'] * v).sum(dim=1, keepdim=True)

    @staticmethod
    def squash(s, dim=1):
        sq = th.sum(s ** 2, dim=dim, keepdim=True)
        s_norm = th.sqrt(sq)
        s = (sq / (1.0 + sq)) * (s / s_norm)
        return s
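
###############################################################################
# The ``squash`` method above divides by the vector norm, so an all-zero
# input would produce ``nan``. One common guard, shown here as a plain-Python
# sketch rather than a change to the tutorial's code, adds a small constant
# to the norm. The name ``squash_safe`` and the ``eps`` default are
# assumptions chosen for this example.

```python
import math


def squash_safe(s, eps=1e-8):
    """Squash a vector to length in [0, 1) without dividing by zero."""
    sq = sum(x * x for x in s)
    norm = math.sqrt(sq)
    # eps keeps the denominator positive when the input is all zeros.
    scale = sq / (1.0 + sq) / (norm + eps)
    return [scale * x for x in s]


print(squash_safe([0.0, 0.0]))  # stays at zero instead of becoming nan
print(squash_safe([3.0, 4.0]))  # length close to 25 / 26
```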


############################################################################################################
# Step 3: Testing
# ~~~~~~~~~~~~~~~
#
# Let's make a simple 20x10 capsule layer:
in_nodes = 20
out_nodes = 10
f_size = 4
u_hat = th.randn(in_nodes * out_nodes, f_size)
routing = DGLRoutingLayer(in_nodes, out_nodes, f_size)

############################################################################################################
# We can visualize a capsule network's behavior by monitoring the entropy
# of the coupling coefficients. It should start high and then drop, as the
# weights gradually concentrate on fewer edges:
entropy_list = []
dist_list = []

for i in range(10):
    routing(u_hat)
    dist_matrix = routing.g.edata['c'].view(in_nodes, out_nodes)
    entropy = (-dist_matrix * th.log(dist_matrix)).sum(dim=1)
    entropy_list.append(entropy.data.numpy())
    dist_list.append(dist_matrix.data.numpy())

stds = np.std(entropy_list, axis=1)
means = np.mean(entropy_list, axis=1)
plt.errorbar(np.arange(len(entropy_list)), means, stds, marker='o')
plt.ylabel("Entropy of Weight Distribution")
plt.xlabel("Number of Routing")
plt.xticks(np.arange(len(entropy_list)))
plt.close()
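
###############################################################################
# As a reference point for reading the plot: the entropy of a uniform
# distribution over ``n`` outcomes is ``log(n)``, so with 10 out-capsules and
# zero-initialized routing logits the curve starts at ``log(10)``, roughly
# 2.30, and it would reach 0 if every in-capsule routed to a single parent.
# The plain-Python sketch below checks both ends; ``entropy`` is a
# hypothetical helper that treats ``0 * log(0)`` as 0 by convention.

```python
import math


def entropy(probs):
    """Shannon entropy in nats, skipping zero-probability outcomes."""
    return sum(-p * math.log(p) for p in probs if p > 0.0)


uniform = [0.1] * 10
one_hot = [1.0] + [0.0] * 9
print(entropy(uniform))  # close to math.log(10)
print(entropy(one_hot))  # zero: all weight on a single out-capsule
```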
############################################################################################################
# |image3|
#
# Alternatively, we can also watch the evolution of histograms:

import seaborn as sns
import matplotlib.animation as animation

fig = plt.figure(dpi=150)
fig.clf()
ax = fig.subplots()


def dist_animate(i):
    ax.cla()
    sns.distplot(dist_list[i].reshape(-1), kde=False, ax=ax)
    ax.set_xlabel("Weight Distribution Histogram")
    ax.set_title("Routing: %d" % (i))


ani = animation.FuncAnimation(fig, dist_animate, frames=len(entropy_list), interval=500)
plt.close()

############################################################################################################
# |image4|
#
# Or monitor how lower-level capsules gradually attach to one of the
# higher-level ones:
import networkx as nx
from networkx.algorithms import bipartite

g = routing.g.to_networkx()
X, Y = bipartite.sets(g)
height_in = 10
height_out = height_in * 0.8
height_in_y = np.linspace(0, height_in, in_nodes)
height_out_y = np.linspace((height_in - height_out) / 2, height_out, out_nodes)
pos = dict()

fig2 = plt.figure(figsize=(8, 3), dpi=150)
fig2.clf()
ax = fig2.subplots()
pos.update((n, (i, 1)) for i, n in zip(height_in_y, X))  # place nodes from X along the row y=1
pos.update((n, (i, 2)) for i, n in zip(height_out_y, Y))  # place nodes from Y along the row y=2


def weight_animate(i):
    ax.cla()
    ax.axis('off')
    ax.set_title("Routing: %d  " % i)
    dm = dist_list[i]
    nx.draw_networkx_nodes(g, pos, nodelist=range(in_nodes), node_color='r', node_size=100, ax=ax)
    nx.draw_networkx_nodes(g, pos, nodelist=range(in_nodes, in_nodes + out_nodes), node_color='b', node_size=100, ax=ax)
    for edge in g.edges():
        nx.draw_networkx_edges(g, pos, edgelist=[edge], width=dm[edge[0], edge[1] - in_nodes] * 1.5, ax=ax)


ani2 = animation.FuncAnimation(fig2, weight_animate, frames=len(dist_list), interval=500)
plt.close()

############################################################################################################
# |image5|
#
# The full code of this visualization is provided in
# `simple_routing.py <https://github.com/jermainewang/dgl/blob/master/examples/pytorch/capsule/simple_routing.py>`__;
# the complete code that trains on MNIST is in the
# `capsule example directory <https://github.com/jermainewang/dgl/tree/tutorial/examples/pytorch/capsule>`__.
#
# .. |image0| image:: https://i.imgur.com/55Ovkdh.png
# .. |image1| image:: https://i.imgur.com/9tc6GLl.png
# .. |image2| image:: https://i.imgur.com/mv1W9Rv.png
# .. |image3| image:: https://i.imgur.com/dMvu7p3.png
# .. |image4| image:: https://github.com/VoVAllen/DGL_Capsule/raw/master/routing_dist.gif
# .. |image5| image:: https://github.com/VoVAllen/DGL_Capsule/raw/master/routing_vis.gif