Commit 026d35c5 authored by VoVAllen

update capsule tutorials

parent d0313326
"## Model Overview\n",
"\n",
"### Introduction\n",
"Capsule Network is \n",
"Capsule Network were first introduced in 2011 by Geoffrey Hinton, et al., in a paper called [Transforming Autoencoders](https://www.cs.toronto.edu/~fritz/absps/transauto6.pdf), but it was only a few months ago, in November 2017, that Sara Sabour, Nicholas Frosst, and Geoffrey Hinton published a paper called Dynamic Routing between Capsules, where they introduced a CapsNet architecture that reached state-of-the-art performance on MNIST.\n",
"\n",
"### What's a capsule?\n",
"> A capsule is a group of neurons whose activity vector represents the instantiation parameters of a specific type of entity such as an object or an object part. \n",
"\n",
"Generally Speaking, the idea of capsule is to encode all the information about the features in a vector form, by substituting scalars in traditional neural network with vectors. And use the norm of the vector to represents the meaning of original scalars. \n",
"Generally Speaking, the idea of capsule is to encode all the information about the features into a vector form, by substituting scalars in traditional neural network with vectors. And use the norm of the vector to represents the meaning of original scalars. \n",
"![figure_1](./capsule_f1.png)\n",
"\n",
"### Dynamic Routing Algorithm\n",
"<img src=\"./capsule_f2.png\" style=\"height:300px;\"/>"
"Due to the different structure of network, capsules network has different operations to calculate results. This figure shows the comparison, drawn by [Max Pechyonkin](https://medium.com/ai%C2%B3-theory-practice-business/understanding-hintons-capsule-networks-part-ii-how-capsules-work-153b6ade9f66O). \n",
"<img src=\"./capsule_f2.png\" style=\"height:250px;\"/><br/>\n",
"\n",
"The key idea is that the output of each capsule is the sum of weighted input vectors. We will go into details in the later section with code implementations.\n"
]
},
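{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the vector-in-place-of-scalar idea concrete, here is a minimal sketch (our illustration, not part of the routing code) with a few hypothetical capsule activity vectors; the norm of each vector plays the role the scalar activation used to play:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Three hypothetical capsules, each a 4-dimensional activity vector\n",
"capsules = torch.tensor([[0.0, 0.1, 0.0, 0.0],\n",
"                         [0.5, 0.5, 0.5, 0.5],\n",
"                         [0.9, 0.2, 0.1, 0.3]])\n",
"# The vector norms stand in for the original scalar activations\n",
"print(capsules.norm(dim=1))  # tensor([0.1000, 1.0000, 0.9747])"
]
},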
{
"\n",
"### 1. Consider capsule routing as a graph structure\n",
"\n",
"We can consider each capsule as a node in a graph, and connect the nodes between layers.\n",
"We can consider each capsule as a node in a graph, and connect all the nodes between layers.\n",
"<img src=\"./capsule_f3.png\" style=\"height:200px;\"/>"
]
},
"source": [
"def construct_graph(self):\n",
" g = dgl.DGLGraph()\n",
" g.add_nodes(self.in_channel + self.num_unit)\n",
" self.in_channel_nodes = list(range(self.in_channel))\n",
" self.capsule_nodes = list(range(self.in_channel, self.in_channel + self.num_unit))\n",
" g.add_nodes(self.input_capsule_num + self.output_capsule_num)\n",
" input_nodes = list(range(self.input_capsule_num))\n",
" output_nodes = list(range(self.input_capsule_num, self.input_capsule_num + self.output_capsule_num))\n",
" u, v = [], []\n",
" for i in self.in_channel_nodes:\n",
" for j in self.capsule_nodes:\n",
" for i in input_nodes:\n",
" for j in output_nodes:\n",
" u.append(i)\n",
" v.append(j)\n",
" g.add_edges(u, v)\n",
" return g"
" return g, input_nodes, output_nodes"
]
},
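{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a sanity check, here is a small sketch (with hypothetical capsule counts chosen for illustration) that mirrors `construct_graph` and confirms the wiring is a complete bipartite graph with `input_capsule_num * output_capsule_num` edges:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import itertools\n",
"import dgl\n",
"\n",
"# Hypothetical sizes: 8 input capsules routed to 10 output capsules\n",
"input_capsule_num, output_capsule_num = 8, 10\n",
"g = dgl.DGLGraph()\n",
"g.add_nodes(input_capsule_num + output_capsule_num)\n",
"input_nodes = list(range(input_capsule_num))\n",
"output_nodes = list(range(input_capsule_num, input_capsule_num + output_capsule_num))\n",
"# Same complete bipartite wiring as construct_graph above\n",
"u, v = zip(*itertools.product(input_nodes, output_nodes))\n",
"g.add_edges(list(u), list(v))\n",
"assert g.number_of_edges() == input_capsule_num * output_capsule_num  # 80 edges"
]
},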
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2. Pre-compute $\\hat{u}_{j|i}$, initialize $b_{ij}$ and store them as edge attribute\n",
"### 2. Initialization & Affine Transformation\n",
"- Pre-compute $\\hat{u}_{j|i}$, initialize $b_{ij}$ and store them as edge attribute\n",
"- Initialize node features as zero\n",
"<img src=\"./capsule_f4.png\" style=\"height:200px;\"/>"
]
},
"metadata": {},
"outputs": [],
"source": [
"# x is the input vextor with shape [batch_size, input_capsule_dim, input_num]\n",
"# Transpose x to [batch_size, input_num, input_capsule_dim] \n",
"x = x.transpose(1, 2)\n",
"x = torch.stack([x] * self.num_unit, dim=2).unsqueeze(4)\n",
"W = self.weight.expand(self.batch_size, *self.weight.shape)\n",
"# Expand x to [batch_size, input_num, output_num, input_capsule_dim, 1]\n",
"x = torch.stack([x] * self.output_capsule_num, dim=2).unsqueeze(4)\n",
"# Expand W from [input_num, output_num, input_capsule_dim, output_capsule_dim] \n",
"# to [batch_size, input_num, output_num, output_capsule_dim, input_capsule_dim] \n",
"W = self.weight.expand(self.batch_size, *self.weight.size())\n",
"# u_hat's shape is [input_num, output_num, batch_size, output_capsule_dim]\n",
"u_hat = torch.matmul(W, x).permute(1, 2, 0, 3, 4).squeeze().contiguous()\n",
"self.g.set_e_repr({'b_ij': edge_features.view(-1)})\n",
"self.g.set_e_repr({'u_hat': u_hat.view(-1, self.batch_size, self.unit_size)})"
"\n",
"b_ij = torch.zeros(self.input_capsule_num, self.output_capsule_num).to(self.device)\n",
"\n",
"self.g.set_e_repr({'b_ij': b_ij.view(-1)})\n",
"self.g.set_e_repr({'u_hat': u_hat.view(-1, self.batch_size, self.unit_size)})\n",
"\n",
"# Initialize all node features as zero\n",
"node_features = torch.zeros(self.input_capsule_num + self.output_capsule_num, self.batch_size,\n",
" self.output_capsule_dim).to(self.device)\n",
"self.g.set_n_repr({'h': node_features})"
]
},
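{
"cell_type": "markdown",
"metadata": {},
"source": [
"The chain of reshapes above is easiest to follow on dummy tensors. The sketch below (hypothetical sizes, standalone tensors instead of module attributes) replays the transformation and prints the resulting shape of `u_hat`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Hypothetical sizes for illustration\n",
"batch_size, input_num, output_num = 2, 6, 10\n",
"input_capsule_dim, output_capsule_dim = 8, 16\n",
"\n",
"x = torch.randn(batch_size, input_capsule_dim, input_num)\n",
"x = x.transpose(1, 2)                                  # [2, 6, 8]\n",
"x = torch.stack([x] * output_num, dim=2).unsqueeze(4)  # [2, 6, 10, 8, 1]\n",
"weight = torch.randn(input_num, output_num, output_capsule_dim, input_capsule_dim)\n",
"W = weight.expand(batch_size, *weight.size())          # [2, 6, 10, 16, 8]\n",
"u_hat = torch.matmul(W, x).permute(1, 2, 0, 3, 4).squeeze().contiguous()\n",
"print(u_hat.shape)  # torch.Size([6, 10, 2, 16])"
]
},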
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3. Initialize node features"
"### 3. Write Message Passing functions and Squash function"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 3.1 Squash function\n",
"Squashing function is to ensure that short vectors get shrunk to almost zero length and long vectors get shrunk to a length slightly below 1.\n",
"![squash](./squash.png)"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"node_features = torch.zeros(self.in_channel + self.num_unit, self.batch_size, self.unit_size).to(device)\n",
"self.g.set_n_repr({'h': node_features})"
"def squash(s):\n",
" msg_sq = torch.sum(s ** 2, dim=2, keepdim=True)\n",
" msg = torch.sqrt(msg_sq)\n",
" s = (msg_sq / (1.0 + msg_sq)) * (s / msg)\n",
" return s"
]
},
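{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick numerical check (our own toy input) shows this behavior: a short vector keeps a near-zero length, while a long one ends up just below length 1:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Two hypothetical capsule outputs, shaped so the norm runs over dim=2\n",
"s = torch.tensor([[[0.1, 0.0]],\n",
"                  [[10.0, 0.0]]])  # shape [2, 1, 2]\n",
"v = squash(s)\n",
"print(v.norm(dim=2))  # approx tensor([[0.0099], [0.9901]])"
]
},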
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4. Write message passing functions"
"#### 3.2 Message Functions\n",
"At first stage, we need to define a message function to get all the attributes we need in the further computations."
]
},
{
"metadata": {},
"outputs": [],
"source": [
"@staticmethod\n",
"def capsule_msg(src, edge):\n",
" return {'b_ij': edge['b_ij'], 'h': src['h'], 'u_hat': edge['u_hat']}\n",
"\n",
"@staticmethod\n",
" return {'b_ij': edge['b_ij'], 'h': src['h'], 'u_hat': edge['u_hat']}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 3.3 Reduce Functions\n",
"At this stage, we need to define a reduce function to aggregate all the information we get from message function into node features.\n",
"This step implements the line 4 and line 5 in routing algorithms, which softmax over $b_{ij}$ and calculate weighted sum of input features. Note that softmax operation is over dimension $j$ instead of $i$. \n",
"<img src=\"./capsule_f5.png\" style=\"height:300px\">"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def capsule_reduce(node, msg):\n",
" b_ij_c, u_hat = msg['b_ij'], msg['u_hat']\n",
" # line 4\n",
" c_i = F.softmax(b_ij_c, dim=0)\n",
" # line 5\n",
" s_j = (c_i.unsqueeze(2).unsqueeze(3) * u_hat).sum(dim=1)\n",
" return {'h': s_j}\n",
"\n",
"def capsule_update(self, msg):\n",
" return {'h': s_j}"
]
},
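{
"cell_type": "markdown",
"metadata": {},
"source": [
"Because the softmax direction is easy to get wrong, here is a standalone sketch with a hypothetical `[input, output]` matrix of logits (a plain 2-D layout for illustration, not the batched layout the reduce function above receives): the softmax over the output dimension $j$ makes each input capsule's coupling coefficients sum to 1."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"# Hypothetical logits for 3 input and 4 output capsules, laid out as [i, j]\n",
"b_ij = torch.zeros(3, 4)\n",
"c_ij = F.softmax(b_ij, dim=1)  # softmax over j, not i\n",
"print(c_ij[0])                 # tensor([0.2500, 0.2500, 0.2500, 0.2500])\n",
"print(c_ij.sum(dim=1))         # each row sums to 1"
]
},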
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 3.4 Node Update Functions\n",
"Squash the intermidiate representations into node features $v_j$\n",
"![step6](./step6.png)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def capsule_update(msg):\n",
" # line 6\n",
" v_j = self.squash(msg['h'])\n",
" return {'h': v_j}\n",
"\n",
" v_j = squash(msg['h'])\n",
" return {'h': v_j}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 3.5 Edge Update Functions\n",
"Update the routing parameters\n",
"![step7](./step7.png)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def update_edge(self, u, v, edge):\n",
" # line 7\n",
" return {'b_ij': edge['b_ij'] + (v['h'] * edge['u_hat']).mean(dim=1).sum(dim=1)}\n"
" return {'b_ij': edge['b_ij'] + (v['h'] * edge['u_hat']).mean(dim=1).sum(dim=1)}"
]
},
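{
"cell_type": "markdown",
"metadata": {},
"source": [
"The update term is just the dot product between $v_j$ and $\\hat{u}_{j|i}$, averaged over the batch, giving one scalar per edge. A sketch with hypothetical shapes:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"\n",
"# Hypothetical: 5 edges, batch of 4, 16-dimensional capsules\n",
"v_h = torch.randn(5, 4, 16)    # v_j broadcast onto each incoming edge\n",
"u_hat = torch.randn(5, 4, 16)  # prediction vectors stored on the edges\n",
"agreement = (v_h * u_hat).mean(dim=1).sum(dim=1)  # batch-averaged dot product\n",
"print(agreement.shape)  # torch.Size([5]) -- one scalar per edge"
]
},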
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4. Executing algorithm"
"### 4. Executing algorithm\n",
"Call `update_all` and `update_edge` functions to execute the algorithms"
]
},
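{
"cell_type": "markdown",
"metadata": {},
"source": [
"A driver loop might look like the sketch below. This is our illustration only: `num_routing` is a hypothetical parameter (the paper uses 3 iterations), and the exact `update_all`/`update_edge` signatures vary between early DGL versions."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def routing(self, num_routing=3):\n",
"    # Hypothetical driver: repeat the message passing for a fixed number of rounds\n",
"    for _ in range(num_routing):\n",
"        # lines 4-6: softmax, weighted sum and squash, all inside update_all\n",
"        self.g.update_all(capsule_msg, capsule_reduce, capsule_update)\n",
"        # line 7: refresh the routing logits on every edge\n",
"        self.g.update_edge(edge_func=self.update_edge)"
]
},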
{
tutorial/capsule/capsule_f4.png: image updated (14.7 KB → 23 KB)