Commit e17c41c0 authored by VoVAllen's avatar VoVAllen Committed by Minjie Wang
Browse files

[Model][Tutorial] Fix capsule memory leak (#185)

* fix memory leak & Remove unnecessary initializer

* change confused name

* fix name

* Move func outside loop

* fix name inconsistency
parent 79fe09d3
...@@ -22,7 +22,6 @@ class DGLDigitCapsuleLayer(nn.Module): ...@@ -22,7 +22,6 @@ class DGLDigitCapsuleLayer(nn.Module):
device=self.device) device=self.device)
routing(u_hat, routing_num=3) routing(u_hat, routing_num=3)
out_nodes_feature = routing.g.nodes[routing.out_indx].data['v'] out_nodes_feature = routing.g.nodes[routing.out_indx].data['v']
routing.end()
# shape transformation is for further classification # shape transformation is for further classification
return out_nodes_feature.transpose(0, 1).unsqueeze(1).unsqueeze(4).squeeze(1) return out_nodes_feature.transpose(0, 1).unsqueeze(1).unsqueeze(4).squeeze(1)
......
...@@ -8,7 +8,7 @@ class DGLRoutingLayer(nn.Module): ...@@ -8,7 +8,7 @@ class DGLRoutingLayer(nn.Module):
def __init__(self, in_nodes, out_nodes, f_size, batch_size=0, device='cpu'): def __init__(self, in_nodes, out_nodes, f_size, batch_size=0, device='cpu'):
super(DGLRoutingLayer, self).__init__() super(DGLRoutingLayer, self).__init__()
self.batch_size = batch_size self.batch_size = batch_size
self.g = init_graph(in_nodes, out_nodes, f_size, device=device, batch_size=batch_size) self.g = init_graph(in_nodes, out_nodes, f_size, device=device)
self.in_nodes = in_nodes self.in_nodes = in_nodes
self.out_nodes = out_nodes self.out_nodes = out_nodes
self.in_indx = list(range(in_nodes)) self.in_indx = list(range(in_nodes))
...@@ -17,22 +17,26 @@ class DGLRoutingLayer(nn.Module): ...@@ -17,22 +17,26 @@ class DGLRoutingLayer(nn.Module):
def forward(self, u_hat, routing_num=1): def forward(self, u_hat, routing_num=1):
self.g.edata['u_hat'] = u_hat self.g.edata['u_hat'] = u_hat
for r in range(routing_num): batch_size = self.batch_size
# step 1 (line 4): normalize over out edges
in_edges = self.g.edata['b'].view(self.in_nodes, self.out_nodes) # step 2 (line 5)
self.g.edata['c'] = F.softmax(in_edges, dim=1).view(-1, 1) def cap_message(edges):
if batch_size:
return {'m': edges.data['c'].unsqueeze(1) * edges.data['u_hat']}
else:
return {'m': edges.data['c'] * edges.data['u_hat']}
self.g.register_message_func(cap_message)
def cap_message(edges): def cap_reduce(nodes):
if self.batch_size: return {'s': th.sum(nodes.mailbox['m'], dim=1)}
return {'m': edges.data['c'].unsqueeze(1) * edges.data['u_hat']}
else:
return {'m': edges.data['c'] * edges.data['u_hat']}
self.g.register_message_func(cap_message)
# step 2 (line 5) self.g.register_reduce_func(cap_reduce)
def cap_reduce(nodes):
return {'s': th.sum(nodes.mailbox['m'], dim=1)} for r in range(routing_num):
self.g.register_reduce_func(cap_reduce) # step 1 (line 4): normalize over out edges
edges_b = self.g.edata['b'].view(self.in_nodes, self.out_nodes)
self.g.edata['c'] = F.softmax(edges_b, dim=1).view(-1, 1)
# Execute step 1 & 2 # Execute step 1 & 2
self.g.update_all() self.g.update_all()
...@@ -50,13 +54,6 @@ class DGLRoutingLayer(nn.Module): ...@@ -50,13 +54,6 @@ class DGLRoutingLayer(nn.Module):
else: else:
self.g.edata['b'] = self.g.edata['b'] + (self.g.edata['u_hat'] * v).sum(dim=1, keepdim=True) self.g.edata['b'] = self.g.edata['b'] + (self.g.edata['u_hat'] * v).sum(dim=1, keepdim=True)
def end(self):
del self.g
# del self.g.edata['u_hat']
# del self.g.ndata['v']
# del self.g.ndata['s']
# del self.g.edata['b']
def squash(s, dim=1): def squash(s, dim=1):
sq = th.sum(s ** 2, dim=dim, keepdim=True) sq = th.sum(s ** 2, dim=dim, keepdim=True)
...@@ -65,8 +62,9 @@ def squash(s, dim=1): ...@@ -65,8 +62,9 @@ def squash(s, dim=1):
return s return s
def init_graph(in_nodes, out_nodes, f_size, device='cpu', batch_size=0): def init_graph(in_nodes, out_nodes, f_size, device='cpu'):
g = dgl.DGLGraph() g = dgl.DGLGraph()
g.set_n_initializer(dgl.frame.zero_initializer)
all_nodes = in_nodes + out_nodes all_nodes = in_nodes + out_nodes
g.add_nodes(all_nodes) g.add_nodes(all_nodes)
in_indx = list(range(in_nodes)) in_indx = list(range(in_nodes))
...@@ -75,10 +73,5 @@ def init_graph(in_nodes, out_nodes, f_size, device='cpu', batch_size=0): ...@@ -75,10 +73,5 @@ def init_graph(in_nodes, out_nodes, f_size, device='cpu', batch_size=0):
for u in in_indx: for u in in_indx:
g.add_edges(u, out_indx) g.add_edges(u, out_indx)
# init states
if batch_size:
g.ndata['v'] = th.zeros(all_nodes, batch_size, f_size).to(device)
else:
g.ndata['v'] = th.zeros(all_nodes, f_size).to(device)
g.edata['b'] = th.zeros(in_nodes * out_nodes, 1).to(device) g.edata['b'] = th.zeros(in_nodes * out_nodes, 1).to(device)
return g return g
""" """
.. _model-capsule: .. _model-capsule:
Capsule Network Tutorial Capsule Network Tutorial
=========================== ===========================
**Author**: Jinjing Zhou, `Jake **Author**: Jinjing Zhou, `Jake
Zhao <https://cs.nyu.edu/~jakezhao/>`_, Zheng Zhang Zhao <https://cs.nyu.edu/~jakezhao/>`_, Zheng Zhang
It is perhaps a little surprising that some of the more classical models can It is perhaps a little surprising that some of the more classical models can
also be described in terms of graphs, offering a different perspective. also be described in terms of graphs, offering a different perspective.
This tutorial describes how this is done for the `capsule network <http://arxiv.org/abs/1710.09829>`__. This tutorial describes how this is done for the `capsule network <http://arxiv.org/abs/1710.09829>`__.
""" """
####################################################################################### #######################################################################################
# Key ideas of Capsule # Key ideas of Capsule
# -------------------- # --------------------
# #
# There are two key ideas that the Capsule model offers. # There are two key ideas that the Capsule model offers.
# #
# **Richer representations** In classic convolutional network, a scalar # **Richer representations** In classic convolutional network, a scalar
# value represents the activation of a given feature. Instead, a capsule # value represents the activation of a given feature. Instead, a capsule
# outputs a vector, whose norm represents the probability of a feature, # outputs a vector, whose norm represents the probability of a feature,
# and the orientation its properties. # and the orientation its properties.
# #
# .. figure:: https://i.imgur.com/55Ovkdh.png # .. figure:: https://i.imgur.com/55Ovkdh.png
# :alt: # :alt:
# #
# **Dynamic routing** To generalize max-pooling, there is another # **Dynamic routing** To generalize max-pooling, there is another
# interesting idea proposed by the authors, as a representationally more powerful # interesting idea proposed by the authors, as a representationally more powerful
# way to construct higher level feature from its low levels. Consider a # way to construct higher level feature from its low levels. Consider a
# capsule :math:`u_i`. The way :math:`u_i` is integrated to the next level # capsule :math:`u_i`. The way :math:`u_i` is integrated to the next level
# capsules take two steps: # capsules take two steps:
# #
# 1. :math:`u_i` projects differently to different higher level capsules # 1. :math:`u_i` projects differently to different higher level capsules
# via a linear transformation: :math:`\hat{u}_{j|i} = W_{ij}u_i`. # via a linear transformation: :math:`\hat{u}_{j|i} = W_{ij}u_i`.
# 2. :math:`\hat{u}_{j|i}` routes to the higher level capsules by # 2. :math:`\hat{u}_{j|i}` routes to the higher level capsules by
# spreading itself with a weighted sum, and the weight is dynamically # spreading itself with a weighted sum, and the weight is dynamically
# determined by iteratively modifying and checking against the # determined by iteratively modifying and checking against the
# "consistency" between :math:`\hat{u}_{j|i}` and :math:`v_j`, for any # "consistency" between :math:`\hat{u}_{j|i}` and :math:`v_j`, for any
# :math:`v_j`. Note that this is similar to a k-means algorithm or # :math:`v_j`. Note that this is similar to a k-means algorithm or
# `competitive # `competitive
# learning <https://en.wikipedia.org/wiki/Competitive_learning>`__ in # learning <https://en.wikipedia.org/wiki/Competitive_learning>`__ in
# spirit. At the end of iterations, :math:`v_j` now integrates the # spirit. At the end of iterations, :math:`v_j` now integrates the
# lower level capsules. # lower level capsules.
# #
# The full algorithm is the following: |image0| # The full algorithm is the following: |image0|
# #
# The dynamic routing step can be naturally expressed as a graph # The dynamic routing step can be naturally expressed as a graph
# algorithm. This is the focus of this tutorial. Our implementation is # algorithm. This is the focus of this tutorial. Our implementation is
# adapted from `Cedric # adapted from `Cedric
# Chee <https://github.com/cedrickchee/capsule-net-pytorch>`__, replacing # Chee <https://github.com/cedrickchee/capsule-net-pytorch>`__, replacing
# only the routing layer, and achieving similar speed and accuracy. # only the routing layer, and achieving similar speed and accuracy.
# #
# Model Implementation # Model Implementation
# ----------------------------------- # -----------------------------------
# Step 1: Setup and Graph Initialization # Step 1: Setup and Graph Initialization
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# #
# The below figure shows the directed bipartite graph built for capsules # network. We denote :math:`b_{ij}`, :math:`\hat{u}_{j|i}` as edge
# network. We denote :math:`b_{ij}`, :math:`\hat{u}_{j|i}` as edge # network. We denote :math:`b_{ij}`, :math:`\hat{u}_{j|i}` as edge
# features and :math:`v_j` as node features. |image1| # features and :math:`v_j` as node features. |image1|
# #
import torch.nn as nn import torch.nn as nn
import torch as th import torch as th
import torch.nn.functional as F import torch.nn.functional as F
import numpy as np import numpy as np
import matplotlib.pyplot as plt import matplotlib.pyplot as plt
import dgl import dgl
def init_graph(in_nodes, out_nodes, f_size):
    """Build the bipartite routing graph between two capsule layers.

    Nodes 0..in_nodes-1 are the lower-level capsules, the remaining
    out_nodes nodes are the higher-level ones; every lower capsule is
    connected to every higher capsule. Node feature ``'v'`` (capsule
    output vectors) and edge feature ``'b'`` (routing logits) are
    zero-initialized; ``f_size`` is the capsule feature dimension.
    """
    g = dgl.DGLGraph()
    total = in_nodes + out_nodes
    g.add_nodes(total)
    lower = list(range(in_nodes))
    upper = list(range(in_nodes, in_nodes + out_nodes))
    # Edge broadcasting: one call per lower capsule fans out to all uppers.
    for src in lower:
        g.add_edges(src, upper)
    # Zeroed initial states.
    g.ndata['v'] = th.zeros(total, f_size)
    g.edata['b'] = th.zeros(in_nodes * out_nodes, 1)
    return g
######################################################################################### #########################################################################################
# Step 2: Define message passing functions # Step 2: Define message passing functions
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
# Recall the following steps, and they are implemented in the class # Recall the following steps, and they are implemented in the class
# ``DGLRoutingLayer`` as the followings: # ``DGLRoutingLayer`` as the followings:
# #
# 1. Normalize over out edges # 1. Normalize over out edges
# #
# - Softmax over all out-edge of in-capsules # - Softmax over all out-edge of in-capsules
# :math:`\textbf{c}_i = \text{softmax}(\textbf{b}_i)`. # :math:`\textbf{c}_i = \text{softmax}(\textbf{b}_i)`.
# #
# 2. Weighted sum over all in-capsules # 2. Weighted sum over all in-capsules
# #
# - Out-capsules equals weighted sum of in-capsules # - Out-capsules equals weighted sum of in-capsules
# :math:`s_j=\sum_i c_{ij}\hat{u}_{j|i}` # :math:`s_j=\sum_i c_{ij}\hat{u}_{j|i}`
# #
# 3. Squash Operation # 3. Squash Operation
# #
# - Squashing function is to ensure that short capsule vectors get # - Squashing function is to ensure that short capsule vectors get
# shrunk to almost zero length while the long capsule vectors get # shrunk to almost zero length while the long capsule vectors get
# shrunk to a length slightly below 1. Its norm is expected to # shrunk to a length slightly below 1. Its norm is expected to
# represents probabilities at some levels. # represents probabilities at some levels.
# - :math:`v_j=\text{squash}(s_j)=\frac{||s_j||^2}{1+||s_j||^2}\frac{s_j}{||s_j||}` # - :math:`v_j=\text{squash}(s_j)=\frac{||s_j||^2}{1+||s_j||^2}\frac{s_j}{||s_j||}`
# #
# 4. Update weights by agreement # 4. Update weights by agreement
# #
# - :math:`\hat{u}_{j|i}\cdot v_j` can be considered as agreement # - :math:`\hat{u}_{j|i}\cdot v_j` can be considered as agreement
# between current capsule and updated capsule, # between current capsule and updated capsule,
# :math:`b_{ij}=b_{ij}+\hat{u}_{j|i}\cdot v_j` # :math:`b_{ij}=b_{ij}+\hat{u}_{j|i}\cdot v_j`
class DGLRoutingLayer(nn.Module):
    """Dynamic routing between two capsule layers, expressed on a graph.

    Edge data: 'b' routing logits, 'c' softmax-normalized coupling
    coefficients, 'u_hat' projected predictions from lower capsules.
    Node data: 's' weighted sum of messages, 'v' squashed capsule output.
    """

    def __init__(self, in_nodes, out_nodes, f_size):
        super(DGLRoutingLayer, self).__init__()
        self.g = init_graph(in_nodes, out_nodes, f_size)
        self.in_nodes = in_nodes
        self.out_nodes = out_nodes
        self.in_indx = list(range(in_nodes))
        self.out_indx = list(range(in_nodes, in_nodes + out_nodes))

    def forward(self, u_hat, routing_num=1):
        """Run `routing_num` iterations of dynamic routing over `u_hat`.

        u_hat: (in_nodes * out_nodes, f_size) edge predictions.
        Updates self.g in place; read results from node feature 'v'.
        """
        self.g.edata['u_hat'] = u_hat

        # step 2 (line 5): weighted message + sum reduce.
        # Fix: define and register these ONCE, outside the routing loop —
        # the old code rebuilt and re-registered them every iteration.
        def cap_message(edges):
            return {'m': edges.data['c'] * edges.data['u_hat']}

        self.g.register_message_func(cap_message)

        def cap_reduce(nodes):
            return {'s': th.sum(nodes.mailbox['m'], dim=1)}

        self.g.register_reduce_func(cap_reduce)

        for r in range(routing_num):
            # step 1 (line 4): normalize logits over each in-capsule's out edges
            edges_b = self.g.edata['b'].view(self.in_nodes, self.out_nodes)
            self.g.edata['c'] = F.softmax(edges_b, dim=1).view(-1, 1)

            # Execute step 1 & 2
            self.g.update_all()

            # step 3 (line 6): squash the aggregated vectors
            self.g.nodes[self.out_indx].data['v'] = self.squash(self.g.nodes[self.out_indx].data['s'], dim=1)

            # step 4 (line 7): agreement u_hat . v updates the logits
            v = th.cat([self.g.nodes[self.out_indx].data['v']] * self.in_nodes, dim=0)
            self.g.edata['b'] = self.g.edata['b'] + (self.g.edata['u_hat'] * v).sum(dim=1, keepdim=True)

    @staticmethod
    def squash(s, dim=1):
        """Shrink short vectors toward zero and cap long ones just below norm 1."""
        sq = th.sum(s ** 2, dim=dim, keepdim=True)
        s_norm = th.sqrt(sq)
        s = (sq / (1.0 + sq)) * (s / s_norm)
        return s
# Step 3: Testing
# ~~~~~~~~~~~~~~~
# ############################################################################################################
# Let's make a simple 20x10 capsule layer: # Step 3: Testing
in_nodes = 20 # ~~~~~~~~~~~~~~~
out_nodes = 10 #
f_size = 4 # Let's make a simple 20x10 capsule layer:
u_hat = th.randn(in_nodes * out_nodes, f_size) in_nodes = 20
routing = DGLRoutingLayer(in_nodes, out_nodes, f_size) out_nodes = 10
f_size = 4
############################################################################################################ u_hat = th.randn(in_nodes * out_nodes, f_size)
# We can visualize the behavior by monitoring the entropy of outgoing routing = DGLRoutingLayer(in_nodes, out_nodes, f_size)
# weights, they should start high and then drop, as the assignment
# gradually concentrate: ############################################################################################################
entropy_list = [] # We can visualize the behavior by monitoring the entropy of outgoing
dist_list = [] # weights, they should start high and then drop, as the assignment
# gradually concentrate:
for i in range(10): entropy_list = []
routing(u_hat) dist_list = []
dist_matrix = routing.g.edata['c'].view(in_nodes, out_nodes)
entropy = (-dist_matrix * th.log(dist_matrix)).sum(dim=1) for i in range(10):
entropy_list.append(entropy.data.numpy()) routing(u_hat)
dist_list.append(dist_matrix.data.numpy()) dist_matrix = routing.g.edata['c'].view(in_nodes, out_nodes)
entropy = (-dist_matrix * th.log(dist_matrix)).sum(dim=1)
stds = np.std(entropy_list, axis=1) entropy_list.append(entropy.data.numpy())
means = np.mean(entropy_list, axis=1) dist_list.append(dist_matrix.data.numpy())
plt.errorbar(np.arange(len(entropy_list)), means, stds, marker='o')
plt.ylabel("Entropy of Weight Distribution") stds = np.std(entropy_list, axis=1)
plt.xlabel("Number of Routing") means = np.mean(entropy_list, axis=1)
plt.xticks(np.arange(len(entropy_list))) plt.errorbar(np.arange(len(entropy_list)), means, stds, marker='o')
plt.close() plt.ylabel("Entropy of Weight Distribution")
############################################################################################################ plt.xlabel("Number of Routing")
# plt.xticks(np.arange(len(entropy_list)))
# .. figure:: https://i.imgur.com/dMvu7p3.png plt.close()
# :alt: ############################################################################################################
#
import seaborn as sns # .. figure:: https://i.imgur.com/dMvu7p3.png
import matplotlib.animation as animation # :alt:
fig = plt.figure(dpi=150) import seaborn as sns
fig.clf() import matplotlib.animation as animation
ax = fig.subplots()
fig = plt.figure(dpi=150)
fig.clf()
def dist_animate(i):
    """Animation callback: redraw the routing-weight histogram for frame i."""
    ax.cla()
    sns.distplot(dist_list[i].reshape(-1), kde=False, ax=ax)
    ax.set_xlabel("Weight Distribution Histogram")
    ax.set_title("Routing: %d" % (i))
ani = animation.FuncAnimation(fig, dist_animate, frames=len(entropy_list), interval=500) ax.set_title("Routing: %d" % (i))
plt.close()
############################################################################################################ ani = animation.FuncAnimation(fig, dist_animate, frames=len(entropy_list), interval=500)
# Alternatively, we can also watch the evolution of histograms: |image2| plt.close()
# Or monitor how the lower level capsules gradually attach to one of the higher level ones:
import networkx as nx ############################################################################################################
from networkx.algorithms import bipartite # Alternatively, we can also watch the evolution of histograms: |image2|
# Or monitor how the lower level capsules gradually attach to one of the higher level ones:
g = routing.g.to_networkx() import networkx as nx
X, Y = bipartite.sets(g) from networkx.algorithms import bipartite
height_in = 10
height_out = height_in * 0.8 g = routing.g.to_networkx()
height_in_y = np.linspace(0, height_in, in_nodes) X, Y = bipartite.sets(g)
height_out_y = np.linspace((height_in - height_out) / 2, height_out, out_nodes) height_in = 10
pos = dict() height_out = height_in * 0.8
height_in_y = np.linspace(0, height_in, in_nodes)
fig2 = plt.figure(figsize=(8, 3), dpi=150) height_out_y = np.linspace((height_in - height_out) / 2, height_out, out_nodes)
fig2.clf() pos = dict()
ax = fig2.subplots()
pos.update((n, (i, 1)) for i, n in zip(height_in_y, X)) # put nodes from X at x=1 fig2 = plt.figure(figsize=(8, 3), dpi=150)
pos.update((n, (i, 2)) for i, n in zip(height_out_y, Y)) # put nodes from Y at x=2 fig2.clf()
ax = fig2.subplots()
pos.update((n, (i, 1)) for i, n in zip(height_in_y, X)) # put nodes from X at x=1
def weight_animate(i):
    """Animation callback: draw the bipartite graph, edge width ~ coupling c_ij."""
    ax.cla()
    ax.axis('off')
    ax.set_title("Routing: %d " % i)
    dm = dist_list[i]
    nx.draw_networkx_nodes(g, pos, nodelist=range(in_nodes), node_color='r', node_size=100, ax=ax)
    nx.draw_networkx_nodes(g, pos, nodelist=range(in_nodes, in_nodes + out_nodes), node_color='b', node_size=100, ax=ax)
    for edge in g.edges():
        nx.draw_networkx_edges(g, pos, edgelist=[edge], width=dm[edge[0], edge[1] - in_nodes] * 1.5, ax=ax)
ani2 = animation.FuncAnimation(fig2, weight_animate, frames=len(dist_list), interval=500) nx.draw_networkx_edges(g, pos, edgelist=[edge], width=dm[edge[0], edge[1] - in_nodes] * 1.5, ax=ax)
plt.close()
############################################################################################################ ani2 = animation.FuncAnimation(fig2, weight_animate, frames=len(dist_list), interval=500)
# |image3| plt.close()
#
# The full code of this visualization is provided at ############################################################################################################
# `link <https://github.com/jermainewang/dgl/blob/master/examples/pytorch/capsule/simple_routing.py>`__; the complete # |image3|
# code that trains on MNIST is at `link <https://github.com/jermainewang/dgl/tree/tutorial/examples/pytorch/capsule>`__. #
# # The full code of this visualization is provided at
# .. |image0| image:: https://i.imgur.com/mv1W9Rv.png # `link <https://github.com/jermainewang/dgl/blob/master/examples/pytorch/capsule/simple_routing.py>`__; the complete
# .. |image1| image:: https://i.imgur.com/9tc6GLl.png # code that trains on MNIST is at `link <https://github.com/jermainewang/dgl/tree/tutorial/examples/pytorch/capsule>`__.
# .. |image2| image:: https://github.com/VoVAllen/DGL_Capsule/raw/master/routing_dist.gif #
# .. |image3| image:: https://github.com/VoVAllen/DGL_Capsule/raw/master/routing_vis.gif # .. |image0| image:: https://i.imgur.com/mv1W9Rv.png
# # .. |image1| image:: https://i.imgur.com/9tc6GLl.png
# .. |image2| image:: https://github.com/VoVAllen/DGL_Capsule/raw/master/routing_dist.gif
# .. |image3| image:: https://github.com/VoVAllen/DGL_Capsule/raw/master/routing_vis.gif
#
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment