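{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Capsule Network\n",
    "================\n",
    "**Author**: `Jinjing Zhou`\n",
    "\n",
    "This tutorial explains how to use the DGL library and its message-passing API to implement the [capsule network](http://arxiv.org/abs/1710.09829) proposed by Geoffrey Hinton and his team. Capsule networks were proposed as an alternative to conventional convolutional neural networks. With DGL, the dynamic routing between two capsule layers can be written as message passing on a graph, which keeps the implementation close to the algorithm's description in the paper."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Overview\n",
    "\n",
    "### Introduction\n",
    "A capsule network is a neural network built from layers of capsules. Instead of passing scalar activations between layers, each lower-level capsule predicts the output of every higher-level capsule, and an iterative routing-by-agreement procedure (dynamic routing) decides how strongly each lower-level capsule is coupled to each higher-level one.\n",
    "\n",
    "### What's a capsule?\n",
    "> A capsule is a group of neurons whose activity vector represents the instantiation parameters of a specific type of entity such as an object or an object part.\n",
    "\n",
    "Generally speaking, a capsule encodes the information about a feature in vector form by replacing the scalar activations of a traditional neural network with vectors. The length (norm) of the vector takes over the role of the original scalar: it represents the probability that the entity exists, while the orientation of the vector encodes the entity's instantiation parameters.\n",
    "\n",
    "![figure_1](./capsule_f1.png)\n",
    "\n",
    "### Dynamic Routing Algorithm\n",
    "Let $u_i$ be the output of capsule $i$ in layer $l$. Each capsule $j$ in layer $l+1$ receives a prediction $\\hat{u}_{j|i} = W_{ij} u_i$ from every capsule $i$, where $W_{ij}$ is a learned transformation matrix. Dynamic routing (Procedure 1 in the paper) starts from routing logits $b_{ij} = 0$ (line 2) and repeats the following steps for $r$ iterations (line 3):\n",
    "\n",
    "- Line 4: $c_{ij} = \\frac{\\exp(b_{ij})}{\\sum_k \\exp(b_{ik})}$\n",
    "- Line 5: $s_j = \\sum_i c_{ij} \\hat{u}_{j|i}$\n",
    "- Line 6: $v_j = \\frac{\\|s_j\\|^2}{1 + \\|s_j\\|^2} \\frac{s_j}{\\|s_j\\|}$ (the *squash* nonlinearity)\n",
    "- Line 7: $b_{ij} \\leftarrow b_{ij} + \\hat{u}_{j|i} \\cdot v_j$\n",
    "\n",
    "The line numbers refer to Procedure 1 of the paper and are reused in the code comments of the implementation."
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Model Implementations\n",
    "\n",
    "### 1. Treat capsule routing as a graph\n",
    "\n",
    "Each capsule is a node in a graph. Every capsule $i$ in the input layer is connected by a directed edge to every capsule $j$ in the output layer, so the two layers form a complete bipartite graph. The quantities that belong to a pair of capsules ($b_{ij}$ and $\\hat{u}_{j|i}$) are stored on the edge $(i, j)$, and the capsule outputs are stored on the nodes."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import dgl\n",
    "\n",
    "def construct_graph(self):\n",
    "    # Nodes 0..in_channel-1 are the input capsules i;\n",
    "    # the next num_unit nodes are the output capsules j\n",
    "    g = dgl.DGLGraph()\n",
    "    g.add_nodes(self.in_channel + self.num_unit)\n",
    "    self.in_channel_nodes = list(range(self.in_channel))\n",
    "    self.capsule_nodes = list(range(self.in_channel, self.in_channel + self.num_unit))\n",
    "    # Connect every input capsule i to every output capsule j.\n",
    "    # Edges are added in i-major order, so edge (i, j) gets ID i * num_unit + (j - in_channel).\n",
    "    u, v = [], []\n",
    "    for i in self.in_channel_nodes:\n",
    "        for j in self.capsule_nodes:\n",
    "            u.append(i)\n",
    "            v.append(j)\n",
    "    g.add_edges(u, v)\n",
    "    return g"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 2. Pre-compute $\\hat{u}_{j|i}$, initialize $b_{ij}$, and store them as edge attributes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "\n",
    "# Move the input-capsule axis to dim 1, then copy the inputs once per output capsule j\n",
    "x = x.transpose(1, 2)\n",
    "x = torch.stack([x] * self.num_unit, dim=2).unsqueeze(4)\n",
    "# u_hat_{j|i} = W_ij u_i as (in_channel, num_unit, batch_size, unit_size); squeeze(-1) keeps size-1 batches\n",
    "W = self.weight.expand(self.batch_size, *self.weight.shape)\n",
    "u_hat = torch.matmul(W, x).permute(1, 2, 0, 3, 4).squeeze(-1).contiguous()\n",
    "# Procedure 1, line 2: initialize the routing logits b_ij to zero (one per edge, i-major order)\n",
    "edge_features = torch.zeros(self.in_channel, self.num_unit, device=u_hat.device)\n",
    "self.g.set_e_repr({'b_ij': edge_features.view(-1)})\n",
    "self.g.set_e_repr({'u_hat': u_hat.view(-1, self.batch_size, self.unit_size)})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 3. Initialize node features"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Every node gets a zero feature 'h' of shape (batch_size, unit_size)\n",
    "node_features = torch.zeros(self.in_channel + self.num_unit, self.batch_size, self.unit_size, device=self.weight.device)\n",
    "self.g.set_n_repr({'h': node_features})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 4. Write message passing functions"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch.nn.functional as F\n",
    "\n",
    "@staticmethod\n",
    "def capsule_msg(src, edge):\n",
    "    # Each edge (i, j) sends its routing logit b_ij and its prediction u_hat_{j|i} to node j\n",
    "    return {'b_ij': edge['b_ij'], 'u_hat': edge['u_hat']}\n",
    "\n",
    "@staticmethod\n",
    "def capsule_reduce(node, msg):\n",
    "    # With all capsule nodes j reduced in one batch, dim 0 indexes j and dim 1 indexes the inputs i\n",
    "    b_ij_c, u_hat = msg['b_ij'], msg['u_hat']\n",
    "    # Procedure 1, line 4: c_ij = softmax of b_ij over the output capsules j\n",
    "    c_ij = F.softmax(b_ij_c, dim=0)\n",
    "    # Procedure 1, line 5: s_j = sum over i of c_ij * u_hat_{j|i}\n",
    "    s_j = (c_ij.unsqueeze(2).unsqueeze(3) * u_hat).sum(dim=1)\n",
    "    return {'h': s_j}\n",
    "\n",
    "def capsule_update(self, node):\n",
    "    # Procedure 1, line 6: v_j = squash(s_j)\n",
    "    v_j = self.squash(node['h'])\n",
    "    return {'h': v_j}\n",
    "\n",
    "def update_edge(self, u, v, edge):\n",
    "    # Procedure 1, line 7: b_ij <- b_ij + u_hat_{j|i} . v_j\n",
    "    # (dot product over unit_size, averaged over the batch)\n",
    "    return {'b_ij': edge['b_ij'] + (v['h'] * edge['u_hat']).mean(dim=1).sum(dim=1)}"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### 5. Execute the routing algorithm"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Procedure 1, line 3: run num_routing iterations\n",
    "for _ in range(self.num_routing):\n",
    "    # Lines 4-6: compute c_ij, s_j and v_j on the capsule nodes\n",
    "    self.g.update_all(self.capsule_msg, self.capsule_reduce, self.capsule_update)\n",
    "    # Line 7: update the routing logits b_ij on every edge\n",
    "    self.g.update_edge(edge_func=self.update_edge)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.6",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}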