"tools/git@developer.sourcefind.cn:OpenDAS/dlib.git" did not exist on "4a8e882f35e512de62c715cebf777b8475e30c61"
Commit 8e2e1f98 authored by VoVAllen's avatar VoVAllen
Browse files

[WIP]update tutorial

parent d87fea82
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Capsule Network\n",
"================\n",
"**Author**: `Jinjing Zhou`\n",
"\n",
"This tutorial explains how to use DGL library and its language to implement the [capsule network](http://arxiv.org/abs/1710.09829) proposed by Geoffrey Hinton and his team. The algorithm aims to provide a better alternative to current neural network structures. By using DGL library, users can implement the algorithm in a more intuitive way."
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model Overview\n",
"\n",
"### Introduction\n",
"Capsule Network is \n",
"\n",
"### What's a capsule?\n",
"> A capsule is a group of neurons whose activity vector represents the instantiation parameters of a specific type of entity such as an object or an object part. \n",
"\n",
"Generally Speaking, the idea of capsule is to encode all the information about the features in a vector form, by substituting scalars in traditional neural network with vectors. And use the norm of the vector to represents the meaning of original scalars. \n",
"![figure_1](./capsule_f1.png)\n",
"\n",
"### Dynamic Routing Algorithm\n",
"<img src=\"./capsule_f2.png\" style=\"height:300px;\"/>"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Model Implementations\n",
"\n",
"### 1. Consider capsule routing as a graph structure\n",
"\n",
"We can consider each capsule as a node in a graph, and connect the nodes between layers.\n",
"<img src=\"./capsule_f3.png\" style=\"height:200px;\"/>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def construct_graph(self):\n",
" g = dgl.DGLGraph()\n",
" g.add_nodes(self.in_channel + self.num_unit)\n",
" self.in_channel_nodes = list(range(self.in_channel))\n",
" self.capsule_nodes = list(range(self.in_channel, self.in_channel + self.num_unit))\n",
" u, v = [], []\n",
" for i in self.in_channel_nodes:\n",
" for j in self.capsule_nodes:\n",
" u.append(i)\n",
" v.append(j)\n",
" g.add_edges(u, v)\n",
" return g"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 2. Pre-compute $\\hat{u}_{j|i}$, initialize $b_{ij}$ and store them as edge attribute\n",
"<img src=\"./capsule_f4.png\" style=\"height:200px;\"/>"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"x = x.transpose(1, 2)\n",
"x = torch.stack([x] * self.num_unit, dim=2).unsqueeze(4)\n",
"W = self.weight.expand(self.batch_size, *self.weight.shape)\n",
"u_hat = torch.matmul(W, x).permute(1, 2, 0, 3, 4).squeeze().contiguous()\n",
"self.g.set_e_repr({'b_ij': edge_features.view(-1)})\n",
"self.g.set_e_repr({'u_hat': u_hat.view(-1, self.batch_size, self.unit_size)})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 3. Initialize node features"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"node_features = torch.zeros(self.in_channel + self.num_unit, self.batch_size, self.unit_size).to(device)\n",
"self.g.set_n_repr({'h': node_features})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4. Write message passing functions"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"@staticmethod\n",
"def capsule_msg(src, edge):\n",
" return {'b_ij': edge['b_ij'], 'h': src['h'], 'u_hat': edge['u_hat']}\n",
"\n",
"@staticmethod\n",
"def capsule_reduce(node, msg):\n",
" b_ij_c, u_hat = msg['b_ij'], msg['u_hat']\n",
" # line 4\n",
" c_i = F.softmax(b_ij_c, dim=0)\n",
" # line 5\n",
" s_j = (c_i.unsqueeze(2).unsqueeze(3) * u_hat).sum(dim=1)\n",
" return {'h': s_j}\n",
"\n",
"def capsule_update(self, msg):\n",
" # line 6\n",
" v_j = self.squash(msg['h'])\n",
" return {'h': v_j}\n",
"\n",
"def update_edge(self, u, v, edge):\n",
" # line 7\n",
" return {'b_ij': edge['b_ij'] + (v['h'] * edge['u_hat']).mean(dim=1).sum(dim=1)}\n"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### 4. Executing algorithm"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"for i in range(self.num_routing):\n",
" self.g.update_all(self.capsule_msg, self.capsule_reduce, self.capsule_update)\n",
" self.g.update_edge(edge_func=self.update_edge)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.6",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Capsule Network\n",
"================\n",
"**Author**: `Jinjing Zhou`\n",
"\n",
"This Tutorial is for blablabla"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Introduction\n",
"\n",
"Capsule Network is "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### What's a capsule?\n",
"> A capsule is a group of neurons whose activity vector represents the instantiation parameters of a specific type of entity such as an object or an object part. \n",
"-- <cite>Geoffrey E. Hinton</cite>\n",
"\n",
"Generally Speaking, "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Dynamic Routing Algorithm\n",
"Dynamic routing is the algorithm calculates the capsules.\n",
"\n",
"$\\textbf{for}$ $r$ $\\text{iterations}$ $\\textbf{do}$ \n",
"for all capsule $i$ in layer $l$$\\hat{u}$"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.6",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment