capsule.py 8.83 KB
Newer Older
Minjie Wang's avatar
Minjie Wang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
"""
Capsule Network
================

**Author**: `Jinjing Zhou`
 
This tutorial explains how to use the DGL library and its message-passing interface to implement the
`capsule network <http://arxiv.org/abs/1710.09829>`__ proposed by Geoffrey Hinton and his team.
The algorithm aims to provide a better alternative to current neural network architectures.
By expressing capsule routing as message passing on a graph, DGL lets users implement the
algorithm in a more intuitive way.
"""

##############################################################################
# Model Overview
# ---------------
# Introduction
# ```````````````````
# Capsule networks were first introduced in 2011 by Geoffrey Hinton et al.
# in the paper `Transforming Autoencoders <https://www.cs.toronto.edu/~fritz/absps/transauto6.pdf>`__,
# but it was not until November 2017 that Sara Sabour, Nicholas Frosst,
# and Geoffrey Hinton published the paper *Dynamic Routing Between Capsules*, in which they
# introduced a CapsNet architecture that reached state-of-the-art performance on MNIST.
#  
# What's a capsule?
# ```````````````````
# In the paper, the authors state that "A capsule is a group of neurons whose activity vector
# represents the instantiation parameters of a specific type of entity such as an object
# or an object part."
# Generally speaking, the idea of a capsule is to encode all the information about a
# feature in vector form, replacing the scalars of a traditional neural network with vectors,
# and to use the norm (length) of each vector to represent what the original scalar meant.
# 
# .. image:: https://raw.githubusercontent.com/dmlc/web-data/master/dgl/tutorials/capsule/capsule_f1.png
# 
# Dynamic Routing Algorithm
# `````````````````````````````
# Because capsule networks are structured differently, they compute their outputs with
# different operations than traditional neurons do. The figure below, drawn by
# `Max Pechyonkin <https://medium.com/ai%C2%B3-theory-practice-business/understanding-hintons-capsule-networks-part-ii-how-capsules-work-153b6ade9f66>`__,
# compares the two.
# 
# .. image:: https://raw.githubusercontent.com/dmlc/web-data/master/dgl/tutorials/capsule/capsule_f2.png
#    :height: 250px
# 
# The key idea is that the output of each capsule is computed from a weighted sum of its
# (transformed) input vectors. We go into the details in a later section, together with the code.
# 
# Model Implementations
# -------------------------
# Setup
# ```````````````````````````

import dgl
import torch
import torch.nn.functional as F
from torch import nn

class DGLBatchCapsuleLayer(nn.Module):
    """Capsule layer whose dynamic routing is expressed as message passing on a DGL graph.

    The layer owns one transformation matrix per (input capsule, output capsule)
    pair and a routing graph ``self.g`` produced by ``construct_graph``, which
    is attached to the class after this definition.

    Args:
        input_capsule_dim: length of each input capsule vector.
        input_capsule_num: number of input capsules.
        output_capsule_num: number of output capsules.
        output_capsule_dim: length of each output capsule vector.
        num_routing: number of routing iterations performed per forward pass.
        cuda_enabled: when truthy, ``self.device`` is ``"cuda"``; otherwise ``"cpu"``.
    """

    def __init__(self, input_capsule_dim, input_capsule_num, output_capsule_num, output_capsule_dim, num_routing,
                 cuda_enabled):
        super().__init__()
        # Layer hyper-parameters.
        self.input_capsule_num = input_capsule_num
        self.input_capsule_dim = input_capsule_dim
        self.output_capsule_num = output_capsule_num
        self.output_capsule_dim = output_capsule_dim
        self.num_routing = num_routing
        # Device string for the tensors that forward() allocates itself.
        self.device = "cuda" if cuda_enabled else "cpu"
        # W[i, j] maps input capsule i (input_capsule_dim) to its prediction
        # for output capsule j (output_capsule_dim) via a matmul in forward().
        weight_shape = (input_capsule_num, output_capsule_num, output_capsule_dim, input_capsule_dim)
        self.weight = nn.Parameter(torch.randn(*weight_shape))
        # Routing graph plus the node-id lists of the two capsule layers.
        self.g, self.input_nodes, self.output_nodes = self.construct_graph()

##############################################################################
# Consider capsule routing as a graph structure
# ````````````````````````````````````````````````````````````````````````````
# We can treat each capsule as a node in a graph and connect every capsule in one layer
# to every capsule in the next layer.
# 
# .. image:: https://raw.githubusercontent.com/dmlc/web-data/master/dgl/tutorials/capsule/capsule_f3.png
#    :height: 200px
# 
def construct_graph(self):
    """Build the complete bipartite routing graph for this layer.

    Node ids ``0 .. input_capsule_num - 1`` are the input capsules; the next
    ``output_capsule_num`` ids are the output capsules. One edge is added from
    every input node to every output node, input-major: first all edges leaving
    input 0, then all edges leaving input 1, and so on.

    Returns:
        (g, input_nodes, output_nodes): the graph and the two lists of node ids.
    """
    num_in, num_out = self.input_capsule_num, self.output_capsule_num
    input_nodes = list(range(num_in))
    output_nodes = list(range(num_in, num_in + num_out))

    g = dgl.DGLGraph()
    g.add_nodes(num_in + num_out)
    # Input-major edge order; forward() flattens its [input, output] routing
    # logits with view(-1), i.e. in this same order.
    src = [i for i in input_nodes for _ in output_nodes]
    dst = [j for _ in input_nodes for j in output_nodes]
    g.add_edges(src, dst)
    return g, input_nodes, output_nodes
DGLBatchCapsuleLayer.construct_graph = construct_graph  # Attach as a method so the class can be defined across tutorial cells.

##############################################################################
# Initialization & Affine Transformation
# ````````````````````````````````````````````````````````````````````````````
# - Pre-compute :math:`\hat{u}_{j|i}`, initialize :math:`b_{ij}`, and store both as edge attributes
# - Initialize node features to zero
# 
# .. image:: https://raw.githubusercontent.com/dmlc/web-data/master/dgl/tutorials/capsule/capsule_f4.png
# 
def forward(self, x):
    """Affine-transform the input capsules, then run dynamic routing on ``self.g``.

    Computes the prediction vectors u_hat_{j|i} = W_ij u_i for every
    (input capsule i, output capsule j) pair and stores them, together with
    zero routing logits b_ij, as edge features. It then zero-initialises the
    node feature ``'h'`` on every node and finally calls ``self.routing()``.

    Args:
        x: input capsules with shape [batch_size, input_capsule_dim, input_capsule_num].

    NOTE(review): nothing is returned; routing writes its results to the node
    feature ``'h'`` of ``self.g``.
    """
    self.batch_size = x.size(0)
    # x: [batch_size, input_capsule_dim, input_num]
    # -> [batch_size, input_num, input_capsule_dim]
    x = x.transpose(1, 2)
    # Replicate x once per output capsule:
    # [batch_size, input_num, output_num, input_capsule_dim, 1]
    x = torch.stack([x] * self.output_capsule_num, dim=2).unsqueeze(4)
    # self.weight is [input_num, output_num, output_capsule_dim, input_capsule_dim];
    # expand (a view, no copy) to
    # [batch_size, input_num, output_num, output_capsule_dim, input_capsule_dim]
    W = self.weight.expand(self.batch_size, *self.weight.size())
    # matmul -> [batch_size, input_num, output_num, output_capsule_dim, 1].
    # Move batch after the capsule axes and drop only the trailing singleton axis:
    # u_hat is [input_num, output_num, batch_size, output_capsule_dim]
    u_hat = torch.matmul(W, x).permute(1, 2, 0, 3, 4).squeeze(-1).contiguous()

    # Routing logits start at zero (uniform coupling coefficients).
    b_ij = torch.zeros(self.input_capsule_num, self.output_capsule_num, device=self.device)

    # construct_graph adds edges input-major, matching the row-major flattening
    # below (assuming DGL numbers edges in insertion order).
    self.g.set_e_repr({'b_ij': b_ij.view(-1)})
    self.g.set_e_repr({'u_hat': u_hat.view(-1, self.batch_size, self.output_capsule_dim)})

    # Initialize all node features as zero BEFORE routing, so that 'h' exists
    # on every node when the first routing iteration starts. Doing this after
    # routing would overwrite the routed capsule outputs with zeros.
    node_features = torch.zeros(self.input_capsule_num + self.output_capsule_num, self.batch_size,
                                self.output_capsule_dim, device=self.device)
    self.g.set_n_repr({'h': node_features})

    self.routing()
DGLBatchCapsuleLayer.forward = forward

##############################################################################
# Write Message Passing functions and Squash function
# ````````````````````````````````````````````````````````````````````````````
# Squash function
# ..................
# The squashing function ensures that short vectors get shrunk to almost zero length and
# long vectors get shrunk to a length slightly below 1.
# 
# .. image:: https://raw.githubusercontent.com/dmlc/web-data/master/dgl/tutorials/capsule/squash.png
#    :height: 100px
# 
def squash(s, dim=2, eps=1e-8):
    """Squash vectors along ``dim`` so that their lengths lie in [0, 1).

    Implements v = ||s||^2 / (1 + ||s||^2) * s / ||s||: short vectors shrink
    towards zero length and long vectors towards (but below) unit length,
    while their direction is unchanged.

    Args:
        s: tensor of vectors to squash.
        dim: axis holding the vector components (default 2, the axis the
            original implementation used).
        eps: small constant added under the square root, so that a zero vector
            maps to zero instead of 0/0 = NaN (and its gradient stays finite).

    Returns:
        Tensor with the same shape as ``s``.
    """
    mag_sq = torch.sum(s ** 2, dim=dim, keepdim=True)
    mag = torch.sqrt(mag_sq + eps)
    return (mag_sq / (1.0 + mag_sq)) * (s / mag)


##############################################################################
# Message Functions
# ..................
# First, we define a message function that sends, along each edge, the attributes
# needed by the subsequent computations.
def capsule_msg(src, edge):
    """Message function: send each edge's routing logit and prediction vector.

    Args:
        src: features of the source (input-capsule) node. These are not needed
            downstream and are therefore not sent.
        edge: edge features; must contain ``'b_ij'`` (routing logit) and
            ``'u_hat'`` (prediction vector u_hat_{j|i}).

    Returns:
        dict with keys ``'b_ij'`` and ``'u_hat'``.
    """
    # The source feature 'h' is deliberately not forwarded: capsule_reduce only
    # reads 'b_ij' and 'u_hat', so sending 'h' would just cost memory.
    return {'b_ij': edge['b_ij'], 'u_hat': edge['u_hat']}

##############################################################################
# Reduce Functions
# ..................
# Next, we define a reduce function that aggregates the messages each node receives
# into its node features.
# This step implements lines 4 and 5 of the routing algorithm: a softmax over
# :math:`b_{ij}` followed by a weighted sum of the prediction vectors :math:`\hat{u}_{j|i}`.
# 
# .. note::
#    The softmax operation is over dimension :math:`j` instead of :math:`i`. 
# 
# .. image:: https://raw.githubusercontent.com/dmlc/web-data/master/dgl/tutorials/capsule/capsule_f5.png
# 
def capsule_reduce(node, msg):
    """Reduce function: coupling coefficients and weighted sum (routing lines 4-5).

    Args:
        node: current node features (unused).
        msg: batched messages. Presumably ``msg['b_ij']`` is
            [num_nodes, in_degree] and ``msg['u_hat']`` is
            [num_nodes, in_degree, batch_size, capsule_dim] (TODO confirm
            against DGL's message batching).

    Returns:
        dict mapping ``'h'`` to the weighted sums s_j (incoming-edge axis 1 summed out).
    """
    logits = msg['b_ij']
    predictions = msg['u_hat']
    # line 4: softmax over axis 0, i.e. across the nodes in this reduce batch
    # (the output capsules j), not across the incoming edges.
    coupling = F.softmax(logits, dim=0)
    # line 5: broadcast each coefficient over the trailing axes of its
    # prediction, then sum over the incoming edges (axis 1).
    weighted = coupling[:, :, None, None] * predictions
    return {'h': weighted.sum(dim=1)}

##############################################################################
# Node Update Functions
# ...........................
# Squash the intermediate representations into node features :math:`v_j`
# 
# .. image:: https://raw.githubusercontent.com/dmlc/web-data/master/dgl/tutorials/capsule/step6.png
# 
def capsule_update(msg):
    """Node apply function: turn the aggregated vector s_j into the capsule output v_j.

    Args:
        msg: node data produced by the reduce step; ``msg['h']`` holds s_j.

    Returns:
        dict mapping ``'h'`` to ``squash(s_j)``.
    """
    s_j = msg['h']
    return {'h': squash(s_j)}

##############################################################################
# Edge Update Functions
# ..........................
# Update the routing logits :math:`b_{ij}` by adding the agreement between :math:`v_j` and :math:`\hat{u}_{j|i}`
# 
# .. image:: https://raw.githubusercontent.com/dmlc/web-data/master/dgl/tutorials/capsule/step7.png
# 
def update_edge(u, v, edge):
    """Edge update: increase each routing logit by the prediction/output agreement.

    Args:
        u: source node features (unused).
        v: destination node features; ``v['h']`` holds the output capsule v_j.
        edge: edge features with ``'b_ij'`` (routing logit) and ``'u_hat'``
            (prediction vector). forward() stores ``'u_hat'`` per edge as
            [batch_size, output_capsule_dim].

    Returns:
        dict mapping ``'b_ij'`` to the updated logits.
    """
    agreement = v['h'] * edge['u_hat']
    # Average over axis 1 (the batch), then sum over the remaining vector
    # components: a batch-averaged dot product u_hat . v_j per edge.
    agreement = agreement.mean(dim=1).sum(dim=1)
    return {'b_ij': edge['b_ij'] + agreement}

##############################################################################
# Executing algorithm
# .....................
# Call ``update_all`` and ``update_edge`` ``num_routing`` times to execute the routing algorithm.
def routing(self):
    """Run ``self.num_routing`` iterations of dynamic routing on ``self.g``.

    Each iteration first recomputes the node features through
    ``capsule_msg`` -> ``capsule_reduce`` -> ``capsule_update``. It then
    refreshes the per-edge routing logits with ``update_edge``.
    """
    message_fn, reduce_fn, apply_fn = capsule_msg, capsule_reduce, capsule_update
    for _ in range(self.num_routing):
        self.g.update_all(message_fn, reduce_fn, apply_fn)
        self.g.update_edge(edge_func=update_edge)
DGLBatchCapsuleLayer.routing = routing