Commit 7cb50072 authored by VoVAllen, committed by Minjie Wang

[Doc][Model] New Capsule Tutorial & Example (#143)

* new capsule tutorial

* capsule for new API

* fix deprecated API

* New tutorial and example

* investigate gc problem

* add viz code

* new capsule tutorial

* remove ipynb

* move u_hat

* add link

* add requirements.txt

* remove ani.save

* update ci to install requirements

* add graphviz
parent a95459e3
# install libraries for building c++ core on ubuntu
apt update && apt install -y --no-install-recommends --force-yes \
        apt-utils git build-essential make cmake wget unzip sudo \
        libz-dev libxml2-dev libopenblas-dev libopencv-dev \
        libgraphviz-dev ca-certificates
DGLDigitCapsule.py

import torch
from torch import nn
from DGLRoutingLayer import DGLRoutingLayer


class DGLDigitCapsuleLayer(nn.Module):
    def __init__(self, in_nodes_dim=8, in_nodes=1152, out_nodes=10, out_nodes_dim=16, device='cpu'):
        super(DGLDigitCapsuleLayer, self).__init__()
        self.device = device
        self.in_nodes_dim, self.out_nodes_dim = in_nodes_dim, out_nodes_dim
        self.in_nodes, self.out_nodes = in_nodes, out_nodes
        self.weight = nn.Parameter(torch.randn(in_nodes, out_nodes, out_nodes_dim, in_nodes_dim))

    def forward(self, x):
        self.batch_size = x.size(0)
        u_hat = self.compute_uhat(x)
        routing = DGLRoutingLayer(self.in_nodes, self.out_nodes, self.out_nodes_dim, batch_size=self.batch_size,
                                  device=self.device)
        routing(u_hat, routing_num=3)
        out_nodes_feature = routing.g.nodes[routing.out_indx].data['v']
        routing.end()
        # shape transformation is for further classification
        return out_nodes_feature.transpose(0, 1).unsqueeze(1).unsqueeze(4).squeeze(1)

    def compute_uhat(self, x):
        # x is the input vector with shape [batch_size, in_nodes_dim, in_nodes]
        # Transpose x to [batch_size, in_nodes, in_nodes_dim]
        x = x.transpose(1, 2)
        # Expand x to [batch_size, in_nodes, out_nodes, in_nodes_dim, 1]
        x = torch.stack([x] * self.out_nodes, dim=2).unsqueeze(4)
        # Expand W from [in_nodes, out_nodes, out_nodes_dim, in_nodes_dim]
        # to [batch_size, in_nodes, out_nodes, out_nodes_dim, in_nodes_dim]
        W = self.weight.expand(self.batch_size, *self.weight.size())
        # u_hat's shape is [in_nodes, out_nodes, batch_size, out_nodes_dim]
        u_hat = torch.matmul(W, x).permute(1, 2, 0, 3, 4).squeeze().contiguous()
        return u_hat.view(-1, self.batch_size, self.out_nodes_dim)
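A minimal usage sketch (not part of the commit itself), assuming a primary-capsule input of shape [batch_size, in_nodes_dim, in_nodes] as documented in compute_uhat; the sizes and values are illustrative only:

# hypothetical smoke test for DGLDigitCapsuleLayer
layer = DGLDigitCapsuleLayer()
x = torch.randn(2, 8, 1152)   # [batch_size, in_nodes_dim, in_nodes]
out = layer(x)
print(out.shape)              # expected: torch.Size([2, 10, 16, 1])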
DGLRoutingLayer.py

import torch.nn as nn
import torch as th
import torch.nn.functional as F
import dgl


class DGLRoutingLayer(nn.Module):
    def __init__(self, in_nodes, out_nodes, f_size, batch_size=0, device='cpu'):
        super(DGLRoutingLayer, self).__init__()
        self.batch_size = batch_size
        self.g = init_graph(in_nodes, out_nodes, f_size, device=device, batch_size=batch_size)
        self.in_nodes = in_nodes
        self.out_nodes = out_nodes
        self.in_indx = list(range(in_nodes))
        self.out_indx = list(range(in_nodes, in_nodes + out_nodes))
        self.device = device

    def forward(self, u_hat, routing_num=1):
        self.g.edata['u_hat'] = u_hat

        for r in range(routing_num):
            # step 1 (line 4): normalize over out edges
            in_edges = self.g.edata['b'].view(self.in_nodes, self.out_nodes)
            self.g.edata['c'] = F.softmax(in_edges, dim=1).view(-1, 1)

            def cap_message(edges):
                if self.batch_size:
                    return {'m': edges.data['c'].unsqueeze(1) * edges.data['u_hat']}
                else:
                    return {'m': edges.data['c'] * edges.data['u_hat']}

            self.g.register_message_func(cap_message)

            # step 2 (line 5)
            def cap_reduce(nodes):
                return {'s': th.sum(nodes.mailbox['m'], dim=1)}

            self.g.register_reduce_func(cap_reduce)

            # Execute step 1 & 2
            self.g.update_all()

            # step 3 (line 6)
            if self.batch_size:
                self.g.nodes[self.out_indx].data['v'] = squash(self.g.nodes[self.out_indx].data['s'], dim=2)
            else:
                self.g.nodes[self.out_indx].data['v'] = squash(self.g.nodes[self.out_indx].data['s'], dim=1)

            # step 4 (line 7)
            v = th.cat([self.g.nodes[self.out_indx].data['v']] * self.in_nodes, dim=0)
            if self.batch_size:
                self.g.edata['b'] = self.g.edata['b'] + (self.g.edata['u_hat'] * v).mean(dim=1).sum(dim=1, keepdim=True)
            else:
                self.g.edata['b'] = self.g.edata['b'] + (self.g.edata['u_hat'] * v).sum(dim=1, keepdim=True)

    def end(self):
        # release the routing graph and its cached features once routing is done
        del self.g


def squash(s, dim=1):
    sq = th.sum(s ** 2, dim=dim, keepdim=True)
    s_norm = th.sqrt(sq)
    s = (sq / (1.0 + sq)) * (s / s_norm)
    return s


def init_graph(in_nodes, out_nodes, f_size, device='cpu', batch_size=0):
    g = dgl.DGLGraph()
    all_nodes = in_nodes + out_nodes
    g.add_nodes(all_nodes)
    in_indx = list(range(in_nodes))
    out_indx = list(range(in_nodes, in_nodes + out_nodes))
    # add edges using edge broadcasting
    for u in in_indx:
        g.add_edges(u, out_indx)
    # init states
    if batch_size:
        g.ndata['v'] = th.zeros(all_nodes, batch_size, f_size).to(device)
    else:
        g.ndata['v'] = th.zeros(all_nodes, f_size).to(device)
    g.edata['b'] = th.zeros(in_nodes * out_nodes, 1).to(device)
    return g
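A minimal batched usage sketch (illustrative, not part of the commit), mirroring how DGLDigitCapsuleLayer drives this layer; the sizes below are assumptions chosen small for readability:

# hypothetical smoke test for DGLRoutingLayer in batched mode
layer = DGLRoutingLayer(in_nodes=20, out_nodes=10, f_size=4, batch_size=5)
u_hat = th.randn(20 * 10, 5, 4)   # one prediction vector per edge, per batch element
layer(u_hat, routing_num=3)
print(layer.g.nodes[layer.out_indx].data['v'].shape)  # expected: torch.Size([10, 5, 4])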
import torch
from torch import nn
from DGLDigitCapsule import DGLDigitCapsuleLayer
from DGLRoutingLayer import squash


class Net(nn.Module):
    ...
import dgl
import torch as th
import torch.nn as nn
from torch.nn import functional as F
from DGLRoutingLayer import DGLRoutingLayer

g = dgl.DGLGraph()
g.graph_data = {}
in_nodes = 20
out_nodes = 10
g.graph_data['in_nodes'] = in_nodes
g.graph_data['out_nodes'] = out_nodes
all_nodes = in_nodes + out_nodes
g.add_nodes(all_nodes)

in_indx = list(range(in_nodes))
out_indx = list(range(in_nodes, in_nodes + out_nodes))
g.graph_data['in_indx'] = in_indx
g.graph_data['out_indx'] = out_indx

# add edges using edge broadcasting
for u in out_indx:
    g.add_edges(in_indx, u)

# init states
f_size = 4
g.ndata['v'] = th.zeros(all_nodes, f_size)
g.edata['u_hat'] = th.randn(in_nodes * out_nodes, f_size)
g.edata['b'] = th.randn(in_nodes * out_nodes, 1)

routing_layer = DGLRoutingLayer(g)
entropy_list = []
for i in range(15):
    routing_layer()
    dist_matrix = g.edata['c'].view(in_nodes, out_nodes)
    entropy = (-dist_matrix * th.log(dist_matrix)).sum(dim=0)
    entropy_list.append(entropy.data.numpy())
    std = dist_matrix.std(dim=0)
#!/bin/bash
# The working directory for this script will be "tests/scripts"

TUTORIAL_ROOT="../../tutorials"

function fail {
    echo FAIL: $@
    exit -1
}

pushd ${TUTORIAL_ROOT} > /dev/null

# Install requirements
pip3 install -r requirements.txt || fail "installing requirements"

# Test
export MPLBACKEND=Agg
for f in $(find . -name "*.py")
do
    echo "Running tutorial ${f} ..."
    python3 $f || fail "run ${f}"
done

popd > /dev/null
""" """
.. _model-capsule: .. _model-capsule:
Capsule Network Capsule Network Tutorial
================ ===========================
**Author**: `Jinjing Zhou` **Author**: `Jinjing Zhou`, `Zheng Zhang`
This tutorial explains how to use DGL library and its language to implement the It is perhaps a little surprising that some of the more classical models can also be described in terms of graphs,
`capsule network <http://arxiv.org/abs/1710.09829>`__ proposed by Geoffrey offering a different perspective.
Hinton and his team. The algorithm aims to provide a better alternative to This tutorial describes how this is done for the `capsule network <http://arxiv.org/abs/1710.09829>`__.
current neural network structures. By using DGL library, users can implement
the algorithm in a more intuitive way.
""" """
#######################################################################################
# Key ideas of Capsule
# --------------------
#
# The Capsule model offers two key ideas.
#
# **Richer representations** In a classic convolutional network, a scalar
# value represents the activation of a given feature. Instead, a capsule
# outputs a vector, whose norm represents the probability of a feature,
# and whose orientation represents its properties.
#
# .. figure:: https://i.imgur.com/55Ovkdh.png
#    :alt:
#
# **Dynamic routing** To generalize max-pooling, the authors propose another
# interesting idea: a representationally more powerful way to construct
# higher-level features from lower-level ones. Consider a capsule
# :math:`u_i`. Integrating :math:`u_i` into the next level of capsules
# takes two steps:
#
# 1. :math:`u_i` projects differently to different higher-level capsules
#    via a linear transformation: :math:`\hat{u}_{j|i} = W_{ij}u_i`.
# 2. :math:`\hat{u}_{j|i}` routes to the higher-level capsules by
#    spreading itself with a weighted sum, and the weight is dynamically
#    determined by iteratively modifying it and checking the
#    "consistency" between :math:`\hat{u}_{j|i}` and :math:`v_j`, for any
#    :math:`v_j`. Note that this is similar in spirit to a k-means algorithm or
#    `competitive
#    learning <https://en.wikipedia.org/wiki/Competitive_learning>`__.
#    At the end of the iterations, :math:`v_j` integrates the
#    lower-level capsules.
#
# The full algorithm is the following: |image0|
#
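##############################################################################
# To make the two routing steps concrete, here is a minimal dense-tensor sketch
# of the inner loop (an illustrative reference only, not part of the original
# example; the sizes below are assumptions):
import torch as th

in_caps, out_caps, f_size = 20, 10, 4
u_hat_demo = th.randn(in_caps, out_caps, f_size)  # u_hat_demo[i, j]: prediction of capsule i for capsule j
b_demo = th.zeros(in_caps, out_caps)              # routing logits

for _ in range(3):
    c = th.softmax(b_demo, dim=1)                  # route weights per lower-level capsule
    s = (c.unsqueeze(-1) * u_hat_demo).sum(dim=0)  # weighted sum -> [out_caps, f_size]
    sq = (s ** 2).sum(dim=1, keepdim=True)
    v = (sq / (1 + sq)) * s / th.sqrt(sq)          # squash
    b_demo = b_demo + (u_hat_demo * v.unsqueeze(0)).sum(dim=-1)  # agreement update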
##############################################################################
# The dynamic routing step can be naturally expressed as a graph
# algorithm, and that is the focus of this tutorial. Our implementation is
# adapted from `Cedric
# Chee <https://github.com/cedrickchee/capsule-net-pytorch>`__, replacing
# only the routing layer, and achieving similar speed and accuracy.
#
# Model Implementation
# -----------------------------------
# Step 1: Setup and Graph Initialization
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The figure below shows the directed bipartite graph built for the capsule
# network. We denote :math:`b_{ij}` and :math:`\hat{u}_{j|i}` as edge
# features and :math:`v_j` as node features. |image1|
#
import torch.nn as nn
import torch as th
import torch.nn.functional as F
import numpy as np
import matplotlib.pyplot as plt
import dgl
def init_graph(in_nodes, out_nodes, f_size):
    g = dgl.DGLGraph()
    all_nodes = in_nodes + out_nodes
    g.add_nodes(all_nodes)

    in_indx = list(range(in_nodes))
    out_indx = list(range(in_nodes, in_nodes + out_nodes))
    # add edges using edge broadcasting
    for u in in_indx:
        g.add_edges(u, out_indx)

    # init states
    g.ndata['v'] = th.zeros(all_nodes, f_size)
    g.edata['b'] = th.zeros(in_nodes * out_nodes, 1)
    return g
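##############################################################################
# As a quick sanity check (an assumed example, not part of the original
# tutorial), a 3-by-2 capsule layer yields a complete bipartite graph with
# 5 nodes and 3 * 2 = 6 edges:
g_demo = init_graph(3, 2, 4)
print(g_demo.number_of_nodes(), g_demo.number_of_edges())  # 5 6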
#########################################################################################
# Step 2: Define message passing functions
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Recall the following steps; they are implemented in the class
# ``DGLRoutingLayer`` as follows:
#
# 1. Normalize over out edges
#
#    - Softmax over all out-edges of the in-capsules:
#      :math:`\textbf{c}_i = \text{softmax}(\textbf{b}_i)`.
#
# 2. Weighted sum over all in-capsules
#
#    - Each out-capsule equals the weighted sum of the in-capsules:
#      :math:`s_j=\sum_i c_{ij}\hat{u}_{j|i}`
#
# 3. Squash operation
#
#    - The squashing function ensures that short capsule vectors get
#      shrunk to almost zero length while long capsule vectors get
#      shrunk to a length slightly below 1, so that the norm can be
#      interpreted as a probability.
#    - :math:`v_j=\text{squash}(s_j)=\frac{||s_j||^2}{1+||s_j||^2}\frac{s_j}{||s_j||}`
#
# 4. Update weights by agreement
#
#    - :math:`\hat{u}_{j|i}\cdot v_j` can be considered as the agreement
#      between the current capsule and the updated capsule:
#      :math:`b_{ij}=b_{ij}+\hat{u}_{j|i}\cdot v_j`
class DGLRoutingLayer(nn.Module):
    def __init__(self, in_nodes, out_nodes, f_size):
        super(DGLRoutingLayer, self).__init__()
        self.g = init_graph(in_nodes, out_nodes, f_size)
        self.in_nodes = in_nodes
        self.out_nodes = out_nodes
        self.in_indx = list(range(in_nodes))
        self.out_indx = list(range(in_nodes, in_nodes + out_nodes))

    def forward(self, u_hat, routing_num=1):
        self.g.edata['u_hat'] = u_hat

        for r in range(routing_num):
            # step 1 (line 4): normalize over out edges
            in_edges = self.g.edata['b'].view(self.in_nodes, self.out_nodes)
            self.g.edata['c'] = F.softmax(in_edges, dim=1).view(-1, 1)

            def cap_message(edges):
                return {'m': edges.data['c'] * edges.data['u_hat']}

            self.g.register_message_func(cap_message)

            # step 2 (line 5)
            def cap_reduce(nodes):
                return {'s': th.sum(nodes.mailbox['m'], dim=1)}

            self.g.register_reduce_func(cap_reduce)

            # Execute step 1 & 2
            self.g.update_all()

            # step 3 (line 6)
            self.g.nodes[self.out_indx].data['v'] = self.squash(self.g.nodes[self.out_indx].data['s'], dim=1)

            # step 4 (line 7)
            v = th.cat([self.g.nodes[self.out_indx].data['v']] * self.in_nodes, dim=0)
            self.g.edata['b'] = self.g.edata['b'] + (self.g.edata['u_hat'] * v).sum(dim=1, keepdim=True)

    @staticmethod
    def squash(s, dim=1):
        sq = th.sum(s ** 2, dim=dim, keepdim=True)
        s_norm = th.sqrt(sq)
        s = (sq / (1.0 + sq)) * (s / s_norm)
        return s
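##############################################################################
# A quick numerical check of the squash non-linearity (an assumed example, not
# part of the original tutorial): it preserves direction while mapping any
# vector to a norm strictly between 0 and 1.
toy = th.tensor([[0.1, 0.0], [10.0, 0.0]])
print(DGLRoutingLayer.squash(toy, dim=1).norm(dim=1))  # ~0.0099 and ~0.9901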
############################################################################################################
# Step 3: Testing
# ~~~~~~~~~~~~~~~
#
# Let's make a simple 20x10 capsule layer:
in_nodes = 20
out_nodes = 10
f_size = 4
u_hat = th.randn(in_nodes * out_nodes, f_size)
routing = DGLRoutingLayer(in_nodes, out_nodes, f_size)
############################################################################################################
# We can visualize the behavior by monitoring the entropy of the outgoing
# weights; it should start high and then drop as the assignment
# gradually concentrates:
entropy_list = []
dist_list = []

for i in range(10):
    routing(u_hat)
    dist_matrix = routing.g.edata['c'].view(in_nodes, out_nodes)
    entropy = (-dist_matrix * th.log(dist_matrix)).sum(dim=1)
    entropy_list.append(entropy.data.numpy())
    dist_list.append(dist_matrix.data.numpy())

stds = np.std(entropy_list, axis=1)
means = np.mean(entropy_list, axis=1)
plt.errorbar(np.arange(len(entropy_list)), means, stds, marker='o')
plt.ylabel("Entropy of Weight Distribution")
plt.xlabel("Number of Routing")
plt.xticks(np.arange(len(entropy_list)))
plt.close()
############################################################################################################
#
# .. figure:: https://i.imgur.com/dMvu7p3.png
#    :alt:
import seaborn as sns
import matplotlib.animation as animation

fig = plt.figure(dpi=150)
fig.clf()
ax = fig.subplots()


def dist_animate(i):
    ax.cla()
    sns.distplot(dist_list[i].reshape(-1), kde=False, ax=ax)
    ax.set_xlabel("Weight Distribution Histogram")
    ax.set_title("Routing: %d" % (i))
ani = animation.FuncAnimation(fig, dist_animate, frames=len(entropy_list), interval=500)
plt.close()

############################################################################################################
# Alternatively, we can also watch the evolution of the histograms: |image2|
# Or monitor how the lower-level capsules gradually attach to one of the higher-level ones:
import networkx as nx
from networkx.algorithms import bipartite

g = routing.g.to_networkx()
X, Y = bipartite.sets(g)
height_in = 10
height_out = height_in * 0.8
height_in_y = np.linspace(0, height_in, in_nodes)
height_out_y = np.linspace((height_in - height_out) / 2, height_out, out_nodes)
pos = dict()

fig2 = plt.figure(figsize=(8, 3), dpi=150)
fig2.clf()
ax = fig2.subplots()
pos.update((n, (i, 1)) for i, n in zip(height_in_y, X))  # put nodes from X at x=1
pos.update((n, (i, 2)) for i, n in zip(height_out_y, Y))  # put nodes from Y at x=2


def weight_animate(i):
    ax.cla()
    ax.axis('off')
    ax.set_title("Routing: %d " % i)
    dm = dist_list[i]
    nx.draw_networkx_nodes(g, pos, nodelist=range(in_nodes), node_color='r', node_size=100, ax=ax)
    nx.draw_networkx_nodes(g, pos, nodelist=range(in_nodes, in_nodes + out_nodes), node_color='b', node_size=100, ax=ax)
    for edge in g.edges():
        nx.draw_networkx_edges(g, pos, edgelist=[edge], width=dm[edge[0], edge[1] - in_nodes] * 1.5, ax=ax)


ani2 = animation.FuncAnimation(fig2, weight_animate, frames=len(dist_list), interval=500)
plt.close()
############################################################################################################
# |image3|
#
# The full code of this visualization is provided at
# `link <https://github.com/jermainewang/dgl/blob/master/examples/pytorch/capsule/simple_routing.py>`__; the complete
# code that trains on MNIST is at `link <https://github.com/jermainewang/dgl/tree/tutorial/examples/pytorch/capsule>`__.
#
# .. |image0| image:: https://i.imgur.com/mv1W9Rv.png
# .. |image1| image:: https://i.imgur.com/9tc6GLl.png
# .. |image2| image:: https://github.com/VoVAllen/DGL_Capsule/raw/master/routing_dist.gif
# .. |image3| image:: https://github.com/VoVAllen/DGL_Capsule/raw/master/routing_vis.gif
#
requirements.txt

networkx
torch
numpy
seaborn
matplotlib
pygraphviz