{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Capsule Network\n",
"================\n",
"**Author**: `Jinjing Zhou`\n",
"\n",
"This tutorial explains how to use DGL library and its language to implement the [capsule network](http://arxiv.org/abs/1710.09829) proposed by Geoffrey Hinton and his team. The algorithm aims to provide a better alternative to current neural network structures. By using DGL library, users can implement the algorithm in a more intuitive way."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model Overview\n",
"\n",
"### Introduction\n",
"Capsule Network were first introduced in 2011 by Geoffrey Hinton, et al., in a paper called [Transforming Autoencoders](https://www.cs.toronto.edu/~fritz/absps/transauto6.pdf), but it was only a few months ago, in November 2017, that Sara Sabour, Nicholas Frosst, and Geoffrey Hinton published a paper called Dynamic Routing between Capsules, where they introduced a CapsNet architecture that reached state-of-the-art performance on MNIST.\n",
"\n",
"### What's a capsule?\n",
"> A capsule is a group of neurons whose activity vector represents the instantiation parameters of a specific type of entity such as an object or an object part. \n",
"\n",
"Generally Speaking, the idea of capsule is to encode all the information about the features into a vector form, by substituting scalars in traditional neural network with vectors. And use the norm of the vector to represents the meaning of original scalars. \n",
"![figure_1](./capsule_f1.png)\n",
"\n",
"### Dynamic Routing Algorithm\n",
"Due to the different structure of network, capsules network has different operations to calculate results. This figure shows the comparison, drawn by [Max Pechyonkin](https://medium.com/ai%C2%B3-theory-practice-business/understanding-hintons-capsule-networks-part-ii-how-capsules-work-153b6ade9f66O). \n",
"<img src=\"./capsule_f2.png\" style=\"height:250px;\"/><br/>\n",
"\n",
"The key idea is that the output of each capsule is the sum of weighted input vectors. We will go into details in the later section with code implementations.\n"
]
},
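{
"cell_type": "markdown",
"metadata": {},
"source": [
"Before mapping the algorithm onto a graph, here is a standalone dense-tensor sketch of a single routing iteration with toy sizes (the names below are illustrative, not attributes of the model implemented later), just to make the weighted-sum idea concrete:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# A minimal sketch of one dynamic-routing iteration, assuming hypothetical sizes.\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"input_num, output_num, out_dim = 6, 10, 16\n",
"u_hat = torch.randn(input_num, output_num, out_dim)  # predictions u_hat_{j|i}\n",
"b = torch.zeros(input_num, output_num)               # routing logits b_ij\n",
"\n",
"c = F.softmax(b, dim=1)                              # line 4: coupling coefficients over j\n",
"s = (c.unsqueeze(-1) * u_hat).sum(dim=0)             # line 5: weighted sum over i\n",
"sq = (s ** 2).sum(dim=-1, keepdim=True)\n",
"v = (sq / (1.0 + sq)) * s / sq.sqrt()                # line 6: squash\n",
"b = b + (u_hat * v.unsqueeze(0)).sum(dim=-1)         # line 7: agreement update\n",
"print(v.shape)  # torch.Size([10, 16])"
]
},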
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model Implementations\n",
"\n",
"### 1. Consider capsule routing as a graph structure\n",
"\n",
"We can consider each capsule as a node in a graph, and connect all the nodes between layers.\n",
"<img src=\"./capsule_f3.png\" style=\"height:200px;\"/>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def construct_graph(self):\n",
" g = dgl.DGLGraph()\n",
" g.add_nodes(self.input_capsule_num + self.output_capsule_num)\n",
" input_nodes = list(range(self.input_capsule_num))\n",
" output_nodes = list(range(self.input_capsule_num, self.input_capsule_num + self.output_capsule_num))\n",
" u, v = [], []\n",
" for i in input_nodes:\n",
" for j in output_nodes:\n",
" u.append(i)\n",
" v.append(j)\n",
" g.add_edges(u, v)\n",
" return g, input_nodes, output_nodes"
]
},
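{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, calling `construct_graph` on a hypothetical object with 6 input and 10 output capsules should give a complete bipartite graph with $6 \\times 10 = 60$ edges (this sketch assumes the legacy DGLGraph API used throughout this tutorial):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Hypothetical usage of construct_graph with toy capsule counts.\n",
"import dgl\n",
"from types import SimpleNamespace\n",
"\n",
"toy = SimpleNamespace(input_capsule_num=6, output_capsule_num=10)\n",
"g, in_nodes, out_nodes = construct_graph(toy)\n",
"print(g.number_of_nodes(), g.number_of_edges())  # 16 60"
]
},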
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2. Initialization & Affine Transformation\n",
"- Pre-compute $\\hat{u}_{j|i}$, initialize $b_{ij}$ and store them as edge attribute\n",
"- Initialize node features as zero\n",
"<img src=\"./capsule_f4.png\" style=\"height:200px;\"/>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# x is the input vextor with shape [batch_size, input_capsule_dim, input_num]\n",
"# Transpose x to [batch_size, input_num, input_capsule_dim] \n",
"x = x.transpose(1, 2)\n",
"# Expand x to [batch_size, input_num, output_num, input_capsule_dim, 1]\n",
"x = torch.stack([x] * self.output_capsule_num, dim=2).unsqueeze(4)\n",
"# Expand W from [input_num, output_num, input_capsule_dim, output_capsule_dim] \n",
"# to [batch_size, input_num, output_num, output_capsule_dim, input_capsule_dim] \n",
"W = self.weight.expand(self.batch_size, *self.weight.size())\n",
"# u_hat's shape is [input_num, output_num, batch_size, output_capsule_dim]\n",
"u_hat = torch.matmul(W, x).permute(1, 2, 0, 3, 4).squeeze().contiguous()\n",
"\n",
"b_ij = torch.zeros(self.input_capsule_num, self.output_capsule_num).to(self.device)\n",
"\n",
"self.g.set_e_repr({'b_ij': b_ij.view(-1)})\n",
"self.g.set_e_repr({'u_hat': u_hat.view(-1, self.batch_size, self.unit_size)})\n",
"\n",
"# Initialize all node features as zero\n",
"node_features = torch.zeros(self.input_capsule_num + self.output_capsule_num, self.batch_size,\n",
" self.output_capsule_dim).to(self.device)\n",
"self.g.set_n_repr({'h': node_features})"
]
},
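{
"cell_type": "markdown",
"metadata": {},
"source": [
"To make the shapes concrete, here is a standalone sketch of the affine step with toy sizes (the variable names are hypothetical, not the model's attributes):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Shape-tracing sketch of the affine transformation, assuming toy sizes.\n",
"import torch\n",
"\n",
"batch_size, input_num, output_num = 4, 6, 10\n",
"in_dim, out_dim = 8, 16\n",
"\n",
"x = torch.randn(batch_size, in_dim, input_num)            # raw input\n",
"W = torch.randn(input_num, output_num, out_dim, in_dim)   # per-pair transforms\n",
"\n",
"x = x.transpose(1, 2)                                     # [4, 6, 8]\n",
"x = torch.stack([x] * output_num, dim=2).unsqueeze(4)     # [4, 6, 10, 8, 1]\n",
"Wb = W.expand(batch_size, *W.size())                      # [4, 6, 10, 16, 8]\n",
"u_hat = torch.matmul(Wb, x).permute(1, 2, 0, 3, 4).squeeze()\n",
"print(u_hat.shape)  # torch.Size([6, 10, 4, 16])"
]
},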
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3. Write Message Passing functions and Squash function"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 3.1 Squash function\n",
"Squashing function is to ensure that short vectors get shrunk to almost zero length and long vectors get shrunk to a length slightly below 1.\n",
"![squash](./squash.png)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [],
"source": [
"def squash(s):\n",
" msg_sq = torch.sum(s ** 2, dim=2, keepdim=True)\n",
" msg = torch.sqrt(msg_sq)\n",
" s = (msg_sq / (1.0 + msg_sq)) * (s / msg)\n",
" return s"
]
},
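{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick check that `squash` behaves as described: output norms stay strictly below 1, and short vectors are shrunk far more aggressively than long ones:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sanity check for squash on capsule vectors of mixed magnitudes.\n",
"import torch\n",
"\n",
"scales = torch.tensor([0.1, 1.0, 10.0]).view(1, 3, 1)\n",
"s = torch.randn(2, 3, 16) * scales\n",
"v = squash(s)\n",
"print(s.norm(dim=2))  # mixed input lengths\n",
"print(v.norm(dim=2))  # all strictly below 1"
]
},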
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 3.2 Message Functions\n",
"At first stage, we need to define a message function to get all the attributes we need in the further computations."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def capsule_msg(src, edge):\n",
" return {'b_ij': edge['b_ij'], 'h': src['h'], 'u_hat': edge['u_hat']}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 3.3 Reduce Functions\n",
"At this stage, we need to define a reduce function to aggregate all the information we get from message function into node features.\n",
"This step implements the line 4 and line 5 in routing algorithms, which softmax over $b_{ij}$ and calculate weighted sum of input features. Note that softmax operation is over dimension $j$ instead of $i$. \n",
"<img src=\"./capsule_f5.png\" style=\"height:300px\">"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [],
"source": [
"def capsule_reduce(node, msg):\n",
" b_ij_c, u_hat = msg['b_ij'], msg['u_hat']\n",
" # line 4\n",
" c_i = F.softmax(b_ij_c, dim=0)\n",
" # line 5\n",
" s_j = (c_i.unsqueeze(2).unsqueeze(3) * u_hat).sum(dim=1)\n",
" return {'h': s_j}"
]
},
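{
"cell_type": "markdown",
"metadata": {},
"source": [
"The same reduce step, illustrated with standalone toy tensors laid out as they appear inside the reduce function (the first dimension indexes the receiving output capsules $j$, the second the sending input capsules $i$; the sizes are hypothetical):"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Toy illustration of lines 4-5: 10 output capsules each aggregating\n",
"# messages from 6 input capsules, batch of 4, capsule dim 16.\n",
"import torch\n",
"import torch.nn.functional as F\n",
"\n",
"b_ij = torch.zeros(10, 6)          # [j, i]\n",
"u_hat = torch.randn(10, 6, 4, 16)  # [j, i, batch, dim]\n",
"c = F.softmax(b_ij, dim=0)         # softmax over j, as noted above\n",
"s_j = (c.unsqueeze(2).unsqueeze(3) * u_hat).sum(dim=1)  # weighted sum over i\n",
"print(s_j.shape)  # torch.Size([10, 4, 16])"
]
},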
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 3.4 Node Update Functions\n",
"Squash the intermidiate representations into node features $v_j$\n",
"![step6](./step6.png)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def capsule_update(msg):\n",
" # line 6\n",
" v_j = squash(msg['h'])\n",
" return {'h': v_j}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"#### 3.5 Edge Update Functions\n",
"Update the routing parameters\n",
"![step7](./step7.png)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def update_edge(self, u, v, edge):\n",
" return {'b_ij': edge['b_ij'] + (v['h'] * edge['u_hat']).mean(dim=1).sum(dim=1)}"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4. Executing algorithm\n",
"Call `update_all` and `update_edge` functions to execute the algorithms"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for i in range(self.num_routing):\n",
" self.g.update_all(self.capsule_msg, self.capsule_reduce, self.capsule_update)\n",
" self.g.update_edge(edge_func=self.update_edge)"
]
},
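{
"cell_type": "markdown",
"metadata": {},
"source": [
"The paper uses 3 routing iterations (`num_routing = 3`). After the loop finishes, the routed capsule outputs $v_j$ live on the output nodes. A minimal sketch of reading them back, assuming a `get_n_repr` accessor that mirrors the `set_n_repr` call used earlier:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sketch (not part of the original model): read the routed output features\n",
"# v_j back from the graph, assuming get_n_repr mirrors set_n_repr above.\n",
"h = self.g.get_n_repr()['h']      # [num_nodes, batch_size, output_capsule_dim]\n",
"v_j = h[self.input_capsule_num:]  # features of the output capsules only"
]
}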
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.6",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
# Changelog
All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).
## [0.4.0] - 2018-01-30
### Added
- Support for the CIFAR10 dataset.
### Changed
- Upgrade to PyTorch 0.3.0.
- Supports CUDA 9.
- Drop our custom softmax function and switch to PyTorch's softmax function.
- Modify the save_image utils function to handle 3-channel (RGB) images.
### Fixed
- Compatibilities with PyTorch 0.3.0.
## [0.3.0] - 2017-11-27
### Added
- Decoder network PyTorch module.
- Reconstruct image with Decoder network during testing.
- Save the original and reconstructed images to the file system.
- Log the original and reconstructed images using TensorBoard.
### Changed
- Refactor reconstruction loss function and decoder network.
- Remove image reconstruction from training.
## [0.2.0] - 2017-11-26
### Added
- New dependencies for TensorBoard and tqdm.
- Logging losses and accuracies with TensorBoard.
- New utils functions for:
  - computing accuracy
  - converting the values of model parameters to numpy.array
  - parsing boolean values with argparse
- Softmax function that takes a dimension.
- More detailed code comments.
- Show margin loss and reconstruction loss in logs.
- Show accuracy in train logs.
### Changed
- Refactor loss functions.
- Clean codes.
### Fixed
- Runtime error during pip install requirements.txt
- Bug in routing algorithm.
## [0.1.0] - 2017-11-12
### Added
- Implemented reconstruction loss.
- Saving reconstructed image as file.
- Improve training speed by using PyTorch DataParallel to wrap our model.
  - PyTorch will parallelize the model and data over multiple GPUs.
- Supports training:
  - on CPU (tested with macOS Sierra)
  - on one GPU (tested with 1 Tesla K80 GPU)
  - on multiple GPUs (tested with 8 GPUs)
  - with or without CUDA (tested with CUDA version 8.0.61)
  - with cuDNN 5 (tested with cuDNN 5.1.3)
### Changed
- More intuitive variable naming.
### Fixed
- Resolve Pylint warnings and reformat code.
- Missing square in equation 4 for margin (class) loss.
## 0.0.1 - 2017-11-04
### Added
- Initial release. The first beta version. The API is stable and the code runs, so I think it's safe to use for development, but it is not ready for general production usage.
[Unreleased]: https://github.com/cedrickchee/capsule-net-pytorch/compare/v1.0.0...HEAD
[0.1.0]: https://github.com/cedrickchee/capsule-net-pytorch/compare/v0.0.1...v0.1.0
[0.2.0]: https://github.com/cedrickchee/capsule-net-pytorch/compare/v0.1.0...v0.2.0
[0.3.0]: https://github.com/cedrickchee/capsule-net-pytorch/compare/v0.2.0...v0.3.0
[0.4.0]: https://github.com/cedrickchee/capsule-net-pytorch/compare/v0.3.0...v0.4.0
COPYRIGHT
All contributions by Cedric Chee:
Copyright (c) 2017, Cedric Chee.
All rights reserved.
All other contributions:
Copyright (c) 2017, the respective contributors.
All rights reserved.
Each contributor holds copyright over their respective contributions.
The project versioning (Git) records all such contribution source information.
LICENSE
The MIT License (MIT)
Permission is hereby granted, free of charge, to any person obtaining a copy
of this software and associated documentation files (the "Software"), to deal
in the Software without restriction, including without limitation the rights
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
copies of the Software, and to permit persons to whom the Software is
furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all
copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
SOFTWARE.