2_capsule.py 9.96 KB
Newer Older
1
2
3
"""
.. _model-capsule:

Capsule network tutorial
===========================

**Author**: Jinjing Zhou, `Jake Zhao <https://cs.nyu.edu/~jakezhao/>`_, Zheng Zhang, Jinyang Li

In this tutorial, you learn how to describe one of the more classical models in terms of graphs. The approach
offers a different perspective. The tutorial describes how to implement a Capsule model for the
`capsule network <http://arxiv.org/abs/1710.09829>`__.
"""
#######################################################################################
# Key ideas of Capsule
# --------------------
#
17
# The Capsule model offers two key ideas: Richer representation and dynamic routing.
18
#
19
# **Richer representation** -- In classic convolutional networks, a scalar
20
21
22
23
24
# value represents the activation of a given feature. By contrast, a
# capsule outputs a vector. The vector's length represents the probability
# of a feature being present. The vector's orientation represents the
# various properties of the feature (such as pose, deformation, texture
# etc.).
25
#
26
# |image0|
27
#
28
# **Dynamic routing** -- The output of a capsule is sent to
29
30
# certain parents in the layer above based on how well the capsule's
# prediction agrees with that of a parent. Such dynamic
31
# routing-by-agreement generalizes the static routing of max-pooling.
32
#
33
34
35
# During training, routing is accomplished iteratively. Each iteration adjusts
# routing weights between capsules based on their observed agreements.
# This works in a manner similar to the k-means algorithm or `competitive
36
# learning <https://en.wikipedia.org/wiki/Competitive_learning>`__.
37
#
38
39
# In this tutorial, you see how a capsule's dynamic routing algorithm can be
# naturally expressed as a graph algorithm. The implementation is adapted
40
# from `Cedric
41
# Chee <https://github.com/cedrickchee/capsule-net-pytorch>`__, replacing
42
# only the routing layer. This version achieves similar speed and accuracy.
43
#
44
# Model implementation
45
# ----------------------
46
# Step 1: Setup and graph initialization
47
48
49
50
51
52
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The connectivity between two layers of capsules forms a directed,
# bipartite graph, as shown in the figure below.
#
# |image1|
53
#
54
55
56
57
58
# Each node :math:`j` is associated with feature :math:`v_j`,
# representing its capsule’s output. Each edge is associated with
# features :math:`b_{ij}` and :math:`\hat{u}_{j|i}`. :math:`b_{ij}`
# determines routing weights, and :math:`\hat{u}_{j|i}` represents the
# prediction of capsule :math:`i` for :math:`j`.
59
#
60
61
# Here's how we set up the graph and initialize node and edge features.

62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import torch.nn as nn
import torch as th
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import dgl


def init_graph(in_nodes, out_nodes, f_size):
    """Build the fully connected, directed bipartite graph between two capsule layers.

    Node IDs ``0 .. in_nodes - 1`` are the input capsules and node IDs
    ``in_nodes .. in_nodes + out_nodes - 1`` are the output capsules. Every
    input capsule gets one edge to every output capsule; edges are added one
    input capsule at a time.

    Parameters
    ----------
    in_nodes : int
        Number of input capsules.
    out_nodes : int
        Number of output capsules.
    f_size : int
        Length of the feature vector ``v`` stored on each node.

    Returns
    -------
    dgl.DGLGraph
        Graph whose node feature ``'v'`` is a zero tensor of shape
        ``(in_nodes + out_nodes, f_size)`` and whose edge feature ``'b'`` is a
        zero tensor of shape ``(in_nodes * out_nodes, 1)``.
    """
    total_nodes = in_nodes + out_nodes
    graph = dgl.DGLGraph()
    graph.add_nodes(total_nodes)

    # The single source ID is broadcast against the list of destination IDs,
    # so each call adds all out-edges of one input capsule.
    destinations = list(range(in_nodes, total_nodes))
    for source in range(in_nodes):
        graph.add_edges(source, destinations)

    # Initial state: zero capsule outputs and zero routing logits.
    graph.ndata['v'] = th.zeros(total_nodes, f_size)
    graph.edata['b'] = th.zeros(in_nodes * out_nodes, 1)
    return graph


#########################################################################################
# Step 2: Define message passing functions
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
91
# This is the pseudocode for Capsule's routing algorithm.
92
93
#
# |image2|
94
# Implement pseudocode lines 4-7 in the class `DGLRoutingLayer` as the following steps:
95
#
96
# 1. Calculate coupling coefficients.
97
#
98
#    -  Coefficients are the softmax over all out-edges of in-capsules.
99
#       :math:`\textbf{c}_{i,j} = \text{softmax}(\textbf{b}_{i,j})`.
100
#
101
# 2. Calculate weighted sum over all in-capsules.
102
#
103
#    -  Output of a capsule is equal to the weighted sum of its in-capsules
104
105
#       :math:`s_j=\sum_i c_{ij}\hat{u}_{j|i}`
#
106
# 3. Squash outputs.
107
#
108
#    -  Squash the length of a Capsule's output vector to range (0,1), so it can represent the probability (of some feature being present).
109
110
#    -  :math:`v_j=\text{squash}(s_j)=\frac{||s_j||^2}{1+||s_j||^2}\frac{s_j}{||s_j||}`
#
111
# 4. Update weights by the amount of agreement.
112
#
113
#    -  The scalar product :math:`\hat{u}_{j|i}\cdot v_j` can be considered as how well capsule :math:`i` agrees with :math:`j`. It is used to update
114
#       :math:`b_{ij}=b_{ij}+\hat{u}_{j|i}\cdot v_j`
115

116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
class DGLRoutingLayer(nn.Module):
    """Dynamic routing between two capsule layers, expressed as message passing.

    The layer owns a bipartite graph built by ``init_graph``. Routing logits
    ``b``, coupling coefficients ``c`` and predictions ``u_hat`` are stored as
    edge features; the output capsule vectors ``v`` are stored as node
    features. ``forward`` never resets ``b``, so successive calls continue
    routing from the state left by the previous call.

    Parameters
    ----------
    in_nodes : int
        Number of input capsules.
    out_nodes : int
        Number of output capsules.
    f_size : int
        Length of each capsule vector.
    """

    def __init__(self, in_nodes, out_nodes, f_size):
        super(DGLRoutingLayer, self).__init__()
        self.g = init_graph(in_nodes, out_nodes, f_size)
        self.in_nodes = in_nodes
        self.out_nodes = out_nodes
        # Node IDs of the input capsules and of the output capsules.
        self.in_indx = list(range(in_nodes))
        self.out_indx = list(range(in_nodes, in_nodes + out_nodes))

    @staticmethod
    def _cap_message(edges):
        """Message sent along an edge: the prediction scaled by its coupling coefficient."""
        return {'m': edges.data['c'] * edges.data['u_hat']}

    @staticmethod
    def _cap_reduce(nodes):
        """Sum the incoming messages of each node into ``s``."""
        return {'s': th.sum(nodes.mailbox['m'], dim=1)}

    def forward(self, u_hat, routing_num=1):
        """Run ``routing_num`` routing iterations on the predictions ``u_hat``.

        Parameters
        ----------
        u_hat : torch.Tensor
            One prediction vector per edge; it is stored as the edge feature
            ``'u_hat'``, so its first dimension must equal the number of edges
            (``in_nodes * out_nodes``).
        routing_num : int, optional
            Number of routing iterations to perform in this call.

        Returns
        -------
        None
            Results are left on ``self.g``: edge features ``'c'`` and ``'b'``
            and the ``'v'`` feature of the output nodes.
        """
        self.g.edata['u_hat'] = u_hat

        for _ in range(routing_num):
            # step 1 (line 4): normalize over out edges.
            # init_graph adds edges one input capsule at a time, so row i of
            # this view holds the logits of capsule i's out-edges.
            edges_b = self.g.edata['b'].view(self.in_nodes, self.out_nodes)
            self.g.edata['c'] = F.softmax(edges_b, dim=1).view(-1, 1)

            # step 2 (line 5): weighted sum of predictions at every node.
            # The functions are handed to update_all directly, so the layer
            # does not depend on per-graph function registration.
            self.g.update_all(self._cap_message, self._cap_reduce)

            # step 3 (line 6)
            self.g.nodes[self.out_indx].data['v'] = self.squash(self.g.nodes[self.out_indx].data['s'], dim=1)

            # step 4 (line 7): repeat the out-capsule vectors once per input
            # capsule so that they line up with the edge ordering of u_hat.
            v = th.cat([self.g.nodes[self.out_indx].data['v']] * self.in_nodes, dim=0)
            self.g.edata['b'] = self.g.edata['b'] + (self.g.edata['u_hat'] * v).sum(dim=1, keepdim=True)

    @staticmethod
    def squash(s, dim=1, eps=1e-8):
        """Shrink vectors along ``dim`` so their length falls in [0, 1).

        Computes ``(|s|^2 / (1 + |s|^2)) * (s / |s|)``.

        Parameters
        ----------
        s : torch.Tensor
            Vectors to squash.
        dim : int, optional
            Dimension along which the vector norm is taken.
        eps : float, optional
            Small constant added under the square root so that an all-zero
            vector is mapped to zero instead of producing ``0 / 0 = NaN``.

        Returns
        -------
        torch.Tensor
            Tensor with the same shape as ``s``.
        """
        sq = th.sum(s ** 2, dim=dim, keepdim=True)
        s_norm = th.sqrt(sq + eps)
        return (sq / (1.0 + sq)) * (s / s_norm)


############################################################################################################
# Step 3: Testing
# ~~~~~~~~~~~~~~~
#
166
# Make a simple 20x10 capsule layer.
167
168
169
170
171
172
173
# Layer dimensions: 20 input capsules, 10 output capsules, capsule vectors of length 4.
in_nodes, out_nodes, f_size = 20, 10, 4
# One random prediction vector per edge of the fully connected bipartite graph.
u_hat = th.randn(in_nodes * out_nodes, f_size)
routing = DGLRoutingLayer(in_nodes, out_nodes, f_size)

############################################################################################################
174
# You can visualize a Capsule network's behavior by monitoring the entropy
175
# of coupling coefficients. The entropy should start high and then drop, as the
176
# weights gradually concentrate on fewer edges.
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
# Per-pass records: the entropy of every input capsule's coupling distribution,
# and the full (in_nodes, out_nodes) matrix of coupling coefficients.
entropy_list = []
dist_list = []

for i in range(10):
    # Each call performs one routing iteration (routing_num defaults to 1).
    # The layer keeps its routing logits on the graph between calls, so these
    # ten calls amount to ten consecutive routing iterations.
    routing(u_hat)
    dist_matrix = routing.g.edata['c'].view(in_nodes, out_nodes)
    # Clamp only inside the log: a coefficient that underflowed to exactly 0
    # then contributes 0 * log(1e-12) = 0 instead of 0 * -inf = NaN.
    entropy = (-dist_matrix * th.log(dist_matrix.clamp(min=1e-12))).sum(dim=1)
    entropy_list.append(entropy.data.numpy())
    dist_list.append(dist_matrix.data.numpy())

# Mean and spread of the entropy across input capsules, one value per routing pass.
stds = np.std(entropy_list, axis=1)
means = np.mean(entropy_list, axis=1)
plt.errorbar(np.arange(len(entropy_list)), means, stds, marker='o')
plt.ylabel("Entropy of Weight Distribution")
plt.xlabel("Number of Routing")
plt.xticks(np.arange(len(entropy_list)))
# The figure is closed right away; nothing in this script shows or saves it.
plt.close()
############################################################################################################
195
# |image3|
196
#
197
# Alternatively, we can also watch the evolution of histograms.
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217

import seaborn as sns
import matplotlib.animation as animation

fig = plt.figure(dpi=150)
fig.clf()
ax = fig.subplots()


def dist_animate(i, ax=ax):
    """Draw the histogram of all coupling coefficients recorded at routing pass ``i``.

    Parameters
    ----------
    i : int
        Frame number, used as an index into ``dist_list``.
    ax : matplotlib axes, optional
        Axes to draw on. It is bound as a default argument on purpose: the
        module-level name ``ax`` is reassigned later in this file for another
        figure, and the animation calls this function only when frames are
        actually rendered, so a global lookup could pick up the wrong axes.
    """
    ax.cla()
    # NOTE(review): distplot is deprecated in newer seaborn releases; confirm
    # the seaborn version this tutorial is pinned to before upgrading.
    sns.distplot(dist_list[i].reshape(-1), kde=False, ax=ax)
    ax.set_xlabel("Weight Distribution Histogram")
    ax.set_title("Routing: %d" % (i))


# One frame per recorded routing pass, 500 ms between frames.
ani = animation.FuncAnimation(fig, dist_animate, frames=len(entropy_list), interval=500)
plt.close()

############################################################################################################
218
219
# |image4|
#
220
221
# You can monitor how the lower-level Capsules gradually attach to one of the
# higher-level ones.
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
import networkx as nx
from networkx.algorithms import bipartite

g = routing.g.to_networkx()
X, Y = bipartite.sets(g)

# Layout extents. Despite the "height" names, these values are used as the
# first (x) coordinate of each node position below, so each capsule layer is
# drawn as a horizontal row.
height_in = 10
height_out = height_in * 0.8
height_in_y = np.linspace(0, height_in, in_nodes)
# Center the shorter output row over the input row: leave the same margin on
# both sides instead of stopping at height_out.
out_margin = (height_in - height_out) / 2
height_out_y = np.linspace(out_margin, out_margin + height_out, out_nodes)
pos = dict()

fig2 = plt.figure(figsize=(8, 3), dpi=150)
fig2.clf()
ax = fig2.subplots()
pos.update((n, (i, 1)) for i, n in zip(height_in_y, X))  # put nodes from X on the row y=1
pos.update((n, (i, 2)) for i, n in zip(height_out_y, Y))  # put nodes from Y on the row y=2


def weight_animate(i, ax=ax):
    """Draw the bipartite graph with edge widths set by the coupling coefficients of pass ``i``.

    Parameters
    ----------
    i : int
        Frame number, used as an index into ``dist_list``.
    ax : matplotlib axes, optional
        Axes to draw on; bound at definition time so the callback keeps
        drawing on this figure even if the module-level ``ax`` is rebound.
    """
    ax.cla()
    ax.axis('off')
    ax.set_title("Routing: %d  " % i)
    dm = dist_list[i]
    nx.draw_networkx_nodes(g, pos, nodelist=range(in_nodes), node_color='r', node_size=100, ax=ax)
    nx.draw_networkx_nodes(g, pos, nodelist=range(in_nodes, in_nodes + out_nodes), node_color='b', node_size=100, ax=ax)
    # Draw all edges in one call, with one width per edge. Output node IDs
    # start at in_nodes, hence the column offset into the coefficient matrix.
    edgelist = list(g.edges())
    widths = [dm[u, v - in_nodes] * 1.5 for u, v in edgelist]
    nx.draw_networkx_edges(g, pos, edgelist=edgelist, width=widths, ax=ax)


# One frame per recorded routing pass, 500 ms between frames.
ani2 = animation.FuncAnimation(fig2, weight_animate, frames=len(dist_list), interval=500)
plt.close()

############################################################################################################
255
# |image5|
256
#
257
258
259
# The full code of this visualization is provided on
# `GitHub <https://github.com/dmlc/dgl/blob/master/examples/pytorch/capsule/simple_routing.py>`__. The complete
# code that trains on MNIST is also on `GitHub <https://github.com/dmlc/dgl/tree/tutorial/examples/pytorch/capsule>`__.
260
#
261
# .. |image0| image:: https://i.imgur.com/55Ovkdh.png
262
# .. |image1| image:: https://i.imgur.com/9tc6GLl.png
263
264
265
266
267
# .. |image2| image:: https://i.imgur.com/mv1W9Rv.png
# .. |image3| image:: https://i.imgur.com/dMvu7p3.png
# .. |image4| image:: https://github.com/VoVAllen/DGL_Capsule/raw/master/routing_dist.gif
# .. |image5| image:: https://github.com/VoVAllen/DGL_Capsule/raw/master/routing_vis.gif