"""
.. _model-capsule:
Capsule Network
===========================
**Author**: Jinjing Zhou, `Jake Zhao `_, Zheng Zhang, Jinyang Li
In this tutorial, you learn how to describe one of the more classical models in terms of graphs. The approach
offers a different perspective. The tutorial describes how to implement a Capsule model for the
`capsule network <https://arxiv.org/abs/1710.09829>`__.
.. warning::
The tutorial aims at gaining insights into the paper, with code as a means
of explanation. The implementation thus is NOT optimized for running
efficiency. For recommended implementation, please refer to the `official
examples `_.
"""
#######################################################################################
# Key ideas of Capsule
# --------------------
#
# The Capsule model offers two key ideas: Richer representation and dynamic routing.
#
# **Richer representation** -- In classic convolutional networks, a scalar
# value represents the activation of a given feature. By contrast, a
# capsule outputs a vector. The vector's length represents the probability
# of a feature being present. The vector's orientation represents the
# various properties of the feature (such as pose, deformation, texture
# etc.).
#
# |image0|
#
# **Dynamic routing** -- The output of a capsule is sent to
# certain parents in the layer above based on how well the capsule's
# prediction agrees with that of a parent. Such dynamic
# routing-by-agreement generalizes the static routing of max-pooling.
#
# During training, routing is accomplished iteratively. Each iteration adjusts
# routing weights between capsules based on their observed agreements.
# This is similar in spirit to the k-means algorithm or `competitive
# learning `__.
#
# In this tutorial, you see how a capsule's dynamic routing algorithm can be
# naturally expressed as a graph algorithm. The implementation is adapted
# from `Cedric
# Chee `__, replacing
# only the routing layer. This version achieves similar speed and accuracy.
#
# Model implementation
# ----------------------
# Step 1: Setup and graph initialization
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The connectivity between two layers of capsules forms a directed,
# bipartite graph, as shown in the figure below.
#
# |image1|
#
# Each node :math:`j` is associated with feature :math:`v_j`,
# representing its capsule's output. Each edge is associated with
# features :math:`b_{ij}` and :math:`\hat{u}_{j|i}`. :math:`b_{ij}`
# determines routing weights, and :math:`\hat{u}_{j|i}` represents the
# prediction of capsule :math:`i` for :math:`j`.
#
# Here's how we set up the graph and initialize node and edge features.
import os
os.environ["DGLBACKEND"] = "pytorch"
import dgl
import matplotlib.pyplot as plt
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F
def init_graph(in_nodes, out_nodes, f_size):
    """Build the fully connected bipartite routing graph between two capsule layers.

    Node ids ``0 .. in_nodes - 1`` are the input capsules and node ids
    ``in_nodes .. in_nodes + out_nodes - 1`` are the output capsules. Every
    input capsule gets one edge to every output capsule; the edge list is
    generated input-major (all edges of input 0 first, then input 1, ...).

    Parameters
    ----------
    in_nodes : int
        Number of input (lower-level) capsules.
    out_nodes : int
        Number of output (higher-level) capsules.
    f_size : int
        Length of the per-node feature vector ``v``.

    Returns
    -------
    The DGL graph, with node feature ``"v"`` (zeros, one ``f_size`` vector
    per node) and edge feature ``"b"`` (zeros, one scalar per edge).
    """
    total_nodes = in_nodes + out_nodes
    out_ids = np.arange(in_nodes, total_nodes)

    # Source ids: 0,0,...,0,1,1,...  Destination ids: the output ids cycled.
    src = np.arange(in_nodes).repeat(out_nodes)
    dst = np.tile(out_ids, in_nodes)

    # NOTE(review): newer DGL releases recommend ``dgl.graph((src, dst))`` over
    # calling the ``DGLGraph`` constructor directly -- confirm against the
    # DGL version this tutorial targets before switching.
    graph = dgl.DGLGraph((src, dst))

    # Initial state: capsule outputs and routing logits all start at zero.
    graph.ndata["v"] = th.zeros(total_nodes, f_size)
    graph.edata["b"] = th.zeros(in_nodes * out_nodes, 1)
    return graph
#########################################################################################
# Step 2: Define message passing functions
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# This is the pseudocode for Capsule's routing algorithm.
#
# |image2|
# Implement pseudocode lines 4-7 in the class `DGLRoutingLayer` as the following steps:
#
# 1. Calculate coupling coefficients.
#
# - Coefficients are the softmax over all out-edges of in-capsules.
# :math:`\textbf{c}_{i,j} = \text{softmax}(\textbf{b}_{i,j})`.
#
# 2. Calculate weighted sum over all in-capsules.
#
# - Output of a capsule is equal to the weighted sum of its in-capsules
# :math:`s_j=\sum_i c_{ij}\hat{u}_{j|i}`
#
# 3. Squash outputs.
#
# - Squash the length of a Capsule's output vector to range (0,1), so it can represent the probability (of some feature being present).
# - :math:`v_j=\text{squash}(s_j)=\frac{||s_j||^2}{1+||s_j||^2}\frac{s_j}{||s_j||}`
#
# 4. Update weights by the amount of agreement.
#
# - The scalar product :math:`\hat{u}_{j|i}\cdot v_j` can be considered as how well capsule :math:`i` agrees with :math:`j`. It is used to update
# :math:`b_{ij}=b_{ij}+\hat{u}_{j|i}\cdot v_j`
import dgl.function as fn
class DGLRoutingLayer(nn.Module):
    """Routing-by-agreement between two capsule layers, run on a DGL graph.

    The layer owns the bipartite graph created by ``init_graph``. All routing
    state lives on that graph: edge feature ``"b"`` (routing logits), edge
    feature ``"c"`` (coupling coefficients), edge feature ``"u_hat"``
    (predictions) and node feature ``"v"`` (squashed capsule outputs).

    Parameters
    ----------
    in_nodes : int
        Number of input capsules (node ids ``0 .. in_nodes - 1``).
    out_nodes : int
        Number of output capsules (node ids following the input ones).
    f_size : int
        Length of each capsule vector.
    """

    def __init__(self, in_nodes, out_nodes, f_size):
        super().__init__()
        self.g = init_graph(in_nodes, out_nodes, f_size)
        self.in_nodes = in_nodes
        self.out_nodes = out_nodes
        self.in_indx = list(range(in_nodes))
        self.out_indx = list(range(in_nodes, in_nodes + out_nodes))

    def forward(self, u_hat, routing_num=1):
        """Run ``routing_num`` routing iterations; results are stored on ``self.g``.

        Nothing is returned. The routing logits ``"b"`` are read from and
        written back to the graph and are not reset here, so consecutive
        calls continue from the previous routing state.

        Parameters
        ----------
        u_hat : torch.Tensor
            One prediction vector per edge. The example in this file passes
            shape ``(in_nodes * out_nodes, f_size)``; the agreement update
            below sums over ``dim=1``, so the feature axis is assumed to be
            dimension 1.
        routing_num : int
            Number of routing iterations to perform in this call.
        """
        self.g.edata["u_hat"] = u_hat

        for _ in range(routing_num):
            # step 1 (line 4): softmax over the out-edges of each input capsule.
            # The reshape relies on edge ids being input-major, i.e. the order
            # in which init_graph lists the edges.
            edges_b = self.g.edata["b"].view(self.in_nodes, self.out_nodes)
            self.g.edata["c"] = F.softmax(edges_b, dim=1).view(-1, 1)

            # step 2 (line 5): weighted sum of predictions at each destination.
            self.g.edata["c u_hat"] = self.g.edata["c"] * self.g.edata["u_hat"]
            self.g.update_all(fn.copy_e("c u_hat", "m"), fn.sum("m", "s"))

            # step 3 (line 6): squash the summed inputs of the output capsules.
            self.g.nodes[self.out_indx].data["v"] = self.squash(
                self.g.nodes[self.out_indx].data["s"], dim=1
            )

            # step 4 (line 7): raise each logit by the agreement u_hat . v.
            # Output vectors are tiled once per input capsule so that row k
            # lines up with edge k (same input-major ordering as above).
            v = th.cat(
                [self.g.nodes[self.out_indx].data["v"]] * self.in_nodes, dim=0
            )
            self.g.edata["b"] = self.g.edata["b"] + (
                self.g.edata["u_hat"] * v
            ).sum(dim=1, keepdim=True)

    @staticmethod
    def squash(s, dim=1):
        """Shrink vectors along ``dim`` to length ``|s|^2 / (1 + |s|^2)``.

        The direction of each vector is preserved. A zero vector maps to a
        zero vector.

        Parameters
        ----------
        s : torch.Tensor
            Vectors to squash.
        dim : int
            Axis along which the vector norm is taken.

        Returns
        -------
        torch.Tensor
            Tensor of the same shape as ``s``.
        """
        sq = th.sum(s**2, dim=dim, keepdim=True)
        s_norm = th.sqrt(sq)
        # (sq / (1 + sq)) * (s / s_norm) equals s * s_norm / (1 + sq). The
        # second form never divides by the norm, so an all-zero vector yields
        # 0 instead of 0/0 = NaN. Only the forward value is addressed here.
        return s * (s_norm / (1.0 + sq))
############################################################################################################
# Step 3: Testing
# ~~~~~~~~~~~~~~~
#
# Make a simple 20x10 capsule layer.
# Layer dimensions: 20 input capsules, 10 output capsules, 4-d capsule vectors.
in_nodes, out_nodes, f_size = 20, 10, 4

# One random prediction vector per edge of the 20x10 bipartite graph.
u_hat = th.randn(in_nodes * out_nodes, f_size)

routing = DGLRoutingLayer(in_nodes, out_nodes, f_size)
############################################################################################################
# You can visualize a Capsule network's behavior by monitoring the entropy
# of coupling coefficients. The entropy should start high and then drop, as the
# weights gradually concentrate on fewer edges.
# Per-iteration records: entropy of each input capsule's coupling
# distribution, and the full (in_nodes, out_nodes) coupling matrix.
entropy_list = []
dist_list = []

for _ in range(10):
    # Each call performs one more routing iteration on the layer's graph.
    routing(u_hat)
    dist_matrix = routing.g.edata["c"].view(in_nodes, out_nodes)
    # xlogy(p, p) is p * log(p) but defined as 0 at p == 0, so a coefficient
    # that underflows to exactly zero no longer turns the entropy into NaN.
    entropy = -th.xlogy(dist_matrix, dist_matrix).sum(dim=1)
    entropy_list.append(entropy.detach().numpy())
    dist_list.append(dist_matrix.detach().numpy())

# Mean and spread of the entropy across input capsules, per iteration.
stds = np.std(entropy_list, axis=1)
means = np.mean(entropy_list, axis=1)

plt.errorbar(np.arange(len(entropy_list)), means, stds, marker="o")
plt.ylabel("Entropy of Weight Distribution")
plt.xlabel("Number of Routing")
plt.xticks(np.arange(len(entropy_list)))
plt.close()
############################################################################################################
# |image3|
#
# Alternatively, we can also watch the evolution of histograms.
import matplotlib.animation as animation
import seaborn as sns
fig = plt.figure(dpi=150)
fig.clf()
ax = fig.subplots()


def dist_animate(i, ax=ax):
    """Redraw the histogram of coupling coefficients recorded at step ``i``.

    Parameters
    ----------
    i : int
        Index into ``dist_list`` (the routing iteration to display).
    ax : matplotlib axes
        Axes to draw on. It defaults to the axes of ``fig`` captured when
        this function is defined: the module-level name ``ax`` is reassigned
        later in this file for the second figure, and the animation only
        calls this function when it is actually rendered.
    """
    ax.cla()
    # NOTE(review): ``distplot`` is deprecated in recent seaborn releases in
    # favour of ``histplot`` -- confirm the targeted seaborn version before
    # switching.
    sns.distplot(dist_list[i].reshape(-1), kde=False, ax=ax)
    ax.set_xlabel("Weight Distribution Histogram")
    ax.set_title("Routing: %d" % (i))


ani = animation.FuncAnimation(
    fig, dist_animate, frames=len(entropy_list), interval=500
)
plt.close()
############################################################################################################
# |image4|
#
# You can monitor how lower-level capsules gradually attach to one of the
# higher-level ones.
import networkx as nx
from networkx.algorithms import bipartite
g = routing.g.to_networkx()
# Split the nodes into the two sides of the bipartite graph.
# presumably X is the input-capsule side and Y the output side -- the edge
# width lookup in weight_animate assumes node ids below in_nodes are inputs.
X, Y = bipartite.sets(g)

# Despite the "height" names, these values are used as the first (x)
# coordinate below: each layer is laid out as a horizontal row.
height_in = 10
height_out = height_in * 0.8
height_in_y = np.linspace(0, height_in, in_nodes)
# Centre the shorter output row under the input row: it starts at the
# half-difference offset and spans exactly ``height_out``.
height_out_y = np.linspace(
    (height_in - height_out) / 2, (height_in + height_out) / 2, out_nodes
)

pos = dict()
fig2 = plt.figure(figsize=(8, 3), dpi=150)
fig2.clf()
ax = fig2.subplots()

# Sets have no guaranteed iteration order; sort so that node ids map to
# positions deterministically.
pos.update(
    (n, (i, 1)) for i, n in zip(height_in_y, sorted(X))
)  # put nodes from X on the row y=1
pos.update(
    (n, (i, 2)) for i, n in zip(height_out_y, sorted(Y))
)  # put nodes from Y on the row y=2
def weight_animate(i):
    """Draw the bipartite graph with edge widths taken from routing step ``i``.

    Input capsules are drawn in red, output capsules in blue, and every edge
    is drawn with a line width of 1.5 times its coupling coefficient in
    ``dist_list[i]``.

    Parameters
    ----------
    i : int
        Index into ``dist_list`` (the routing iteration to display).
    """
    ax.cla()
    ax.axis("off")
    ax.set_title("Routing: %d " % i)
    dm = dist_list[i]

    # (node ids, colour) for the input layer, then the output layer.
    node_groups = (
        (range(in_nodes), "r"),
        (range(in_nodes, in_nodes + out_nodes), "b"),
    )
    for members, colour in node_groups:
        nx.draw_networkx_nodes(
            g,
            pos,
            nodelist=members,
            node_color=colour,
            node_size=100,
            ax=ax,
        )

    # Edges are drawn one at a time so each can have its own width.
    for edge in g.edges():
        src, dst = edge[0], edge[1]
        # Output node ids are offset by in_nodes; shift back to a column index.
        coupling = dm[src, dst - in_nodes]
        nx.draw_networkx_edges(
            g,
            pos,
            edgelist=[edge],
            width=coupling * 1.5,
            ax=ax,
        )


ani2 = animation.FuncAnimation(
    fig2, weight_animate, frames=len(dist_list), interval=500
)
plt.close()
############################################################################################################
# |image5|
#
# The full code of this visualization is provided on
# `GitHub `__. The complete
# code that trains on MNIST is also on `GitHub `__.
#
# .. |image0| image:: https://i.imgur.com/55Ovkdh.png
# .. |image1| image:: https://i.imgur.com/9tc6GLl.png
# .. |image2| image:: https://i.imgur.com/mv1W9Rv.png
# .. |image3| image:: https://i.imgur.com/dMvu7p3.png
# .. |image4| image:: https://github.com/VoVAllen/DGL_Capsule/raw/master/routing_dist.gif
# .. |image5| image:: https://github.com/VoVAllen/DGL_Capsule/raw/master/routing_vis.gif