"examples/pytorch/graphsage/node_classification.py" did not exist on "3bd5a9b6d11a74df6035ecdbdf5f71088eb2e901"
Unverified commit 2cdc4d3c authored by Minjie Wang, committed by GitHub

[Doc] Patch tutorial (#1380)

* patched 1_first

* done 2_basics

* done 4_batch

* done 1_gcn, 9_gat, 2_capsule

* 4_rgcn.py

* revert

* more fix
parent 0f40c6e4
......@@ -45,32 +45,26 @@ At the end of this tutorial, we hope you get a brief feeling of how DGL works.
# Create the graph for Zachary's karate club as follows:
import dgl
import numpy as np
def build_karate_club_graph():
g = dgl.DGLGraph()
# add 34 nodes into the graph; nodes are labeled from 0~33
g.add_nodes(34)
# all 78 edges as a list of tuples
edge_list = [(1, 0), (2, 0), (2, 1), (3, 0), (3, 1), (3, 2),
(4, 0), (5, 0), (6, 0), (6, 4), (6, 5), (7, 0), (7, 1),
(7, 2), (7, 3), (8, 0), (8, 2), (9, 2), (10, 0), (10, 4),
(10, 5), (11, 0), (12, 0), (12, 3), (13, 0), (13, 1), (13, 2),
(13, 3), (16, 5), (16, 6), (17, 0), (17, 1), (19, 0), (19, 1),
(21, 0), (21, 1), (25, 23), (25, 24), (27, 2), (27, 23),
(27, 24), (28, 2), (29, 23), (29, 26), (30, 1), (30, 8),
(31, 0), (31, 24), (31, 25), (31, 28), (32, 2), (32, 8),
(32, 14), (32, 15), (32, 18), (32, 20), (32, 22), (32, 23),
(32, 29), (32, 30), (32, 31), (33, 8), (33, 9), (33, 13),
(33, 14), (33, 15), (33, 18), (33, 19), (33, 20), (33, 22),
(33, 23), (33, 26), (33, 27), (33, 28), (33, 29), (33, 30),
(33, 31), (33, 32)]
# add edges as two lists of nodes: src and dst
src, dst = tuple(zip(*edge_list))
g.add_edges(src, dst)
# edges are directional in DGL; make them bi-directional
g.add_edges(dst, src)
return g
# All 78 edges are stored in two numpy arrays: one for the source endpoints
# and the other for the destination endpoints.
src = np.array([1, 2, 2, 3, 3, 3, 4, 5, 6, 6, 6, 7, 7, 7, 7, 8, 8, 9, 10, 10,
10, 11, 12, 12, 13, 13, 13, 13, 16, 16, 17, 17, 19, 19, 21, 21,
25, 25, 27, 27, 27, 28, 29, 29, 30, 30, 31, 31, 31, 31, 32, 32,
32, 32, 32, 32, 32, 32, 32, 32, 32, 33, 33, 33, 33, 33, 33, 33,
33, 33, 33, 33, 33, 33, 33, 33, 33, 33])
dst = np.array([0, 0, 1, 0, 1, 2, 0, 0, 0, 4, 5, 0, 1, 2, 3, 0, 2, 2, 0, 4,
5, 0, 0, 3, 0, 1, 2, 3, 5, 6, 0, 1, 0, 1, 0, 1, 23, 24, 2, 23,
24, 2, 23, 26, 1, 8, 0, 24, 25, 28, 2, 8, 14, 15, 18, 20, 22, 23,
29, 30, 31, 8, 9, 13, 14, 15, 18, 19, 20, 22, 23, 26, 27, 28, 29, 30,
31, 32])
# Edges are directional in DGL; Make them bi-directional.
u = np.concatenate([src, dst])
v = np.concatenate([dst, src])
# Construct a DGLGraph
return dgl.DGLGraph((u, v))
###############################################################################
# Print out the number of nodes and edges in our newly constructed graph:
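# The code for this step is collapsed in this diff hunk. A minimal sketch of what
# it does, using standard DGLGraph queries (the print format here is illustrative):
G = build_karate_club_graph()
print('We have %d nodes.' % G.number_of_nodes())
print('We have %d edges.' % G.number_of_edges())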
......@@ -95,27 +89,28 @@ nx.draw(nx_G, pos, with_labels=True, node_color=[[.7, .7, .7]])
# Step 2: Assign features to nodes or edges
# --------------------------------------------
# Graph neural networks associate features with nodes and edges for training.
# For our classification example, we assign each node an input feature as a one-hot vector:
# node :math:`v_i`'s feature vector is :math:`[0,\ldots,1,\dots,0]`,
# where the :math:`i^{th}` position is one.
#
# For our classification example, since there is no input feature, we assign each node
# a learnable embedding vector.
# In DGL, you can add features for all nodes at once, using a feature tensor that
# batches node features along the first dimension. The code below adds the one-hot
# feature for all nodes:
# batches node features along the first dimension. The code below adds the learnable
# embeddings for all nodes:
import torch
import torch.nn as nn
import torch.nn.functional as F
G.ndata['feat'] = torch.eye(34)
embed = nn.Embedding(34, 5) # 34 nodes with embedding dim equal to 5
G.ndata['feat'] = embed.weight
###############################################################################
# Print out the node features to verify:
# print out node 2's input feature
print(G.nodes[2].data['feat'])
print(G.ndata['feat'][2])
# print out node 10 and 11's input features
print(G.nodes[[10, 11]].data['feat'])
print(G.ndata['feat'][[10, 11]])
###############################################################################
# Step 3: Define a Graph Convolutional Network (GCN)
......@@ -139,74 +134,41 @@ print(G.nodes[[10, 11]].data['feat'])
# :alt: mailbox
# :align: center
#
# Now, we show that the GCN layer can be easily implemented in DGL.
import torch.nn as nn
import torch.nn.functional as F
# Define the message and reduce function
# NOTE: We ignore the GCN's normalization constant c_ij for this tutorial.
def gcn_message(edges):
# The argument is a batch of edges.
# This computes a (batch of) message called 'msg' using the source node's feature 'h'.
return {'msg' : edges.src['h']}
# In DGL, we provide implementations of popular Graph Neural Network layers under
# the ``dgl.nn.<backend>`` subpackage. The :class:`~dgl.nn.pytorch.GraphConv` module
# implements one Graph Convolutional layer.
def gcn_reduce(nodes):
# The argument is a batch of nodes.
# This computes the new 'h' features by summing received 'msg' in each node's mailbox.
return {'h' : torch.sum(nodes.mailbox['msg'], dim=1)}
# Define the GCNLayer module
class GCNLayer(nn.Module):
def __init__(self, in_feats, out_feats):
super(GCNLayer, self).__init__()
self.linear = nn.Linear(in_feats, out_feats)
def forward(self, g, inputs):
# g is the graph and the inputs is the input node features
# first set the node features
g.ndata['h'] = inputs
# trigger message passing on all edges
g.send(g.edges(), gcn_message)
# trigger aggregation at all nodes
g.recv(g.nodes(), gcn_reduce)
# get the result node features
h = g.ndata.pop('h')
# perform linear transformation
return self.linear(h)
from dgl.nn.pytorch import GraphConv
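# As a quick, illustrative sanity check (not part of the original tutorial), a single
# ``GraphConv`` layer takes the graph and the node features and returns updated
# node features of the requested output size:
conv_demo = GraphConv(5, 5)
h_demo = conv_demo(G, G.ndata['feat'])
print(h_demo.shape)  # torch.Size([34, 5]) -- one updated feature vector per node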
###############################################################################
# In general, the nodes send information computed via the *message functions*,
# and aggregate incoming information with the *reduce functions*.
#
# Define a deeper GCN model that contains two GCN layers:
# Define a 2-layer GCN model
class GCN(nn.Module):
def __init__(self, in_feats, hidden_size, num_classes):
super(GCN, self).__init__()
self.gcn1 = GCNLayer(in_feats, hidden_size)
self.gcn2 = GCNLayer(hidden_size, num_classes)
self.conv1 = GraphConv(in_feats, hidden_size)
self.conv2 = GraphConv(hidden_size, num_classes)
def forward(self, g, inputs):
h = self.gcn1(g, inputs)
h = self.conv1(g, inputs)
h = torch.relu(h)
h = self.gcn2(g, h)
h = self.conv2(g, h)
return h
# The first layer transforms input features of size 34 to a hidden size of 5.
# The first layer transforms input features of size 5 to a hidden size of 5.
# The second layer transforms the hidden layer and produces output features of
# size 2, corresponding to the two groups of the karate club.
net = GCN(34, 5, 2)
net = GCN(5, 5, 2)
###############################################################################
# Step 4: Data preparation and initialization
# -------------------------------------------
#
# We use one-hot vectors to initialize the node features. Since this is a
# We use learnable embeddings to initialize the node features. Since this is a
# semi-supervised setting, only the instructor (node 0) and the club president
# (node 33) are assigned labels. The implementation is as follows.
inputs = torch.eye(34)
inputs = embed.weight
labeled_nodes = torch.tensor([0, 33]) # only the instructor and the president nodes are labeled
labels = torch.tensor([0, 1]) # their labels are different
......@@ -216,10 +178,11 @@ labels = torch.tensor([0, 1]) # their labels are different
# The training loop is exactly the same as other PyTorch models.
# We (1) create an optimizer, (2) feed the inputs to the model,
# (3) calculate the loss and (4) use autograd to optimize the model.
import itertools
optimizer = torch.optim.Adam(net.parameters(), lr=0.01)
optimizer = torch.optim.Adam(itertools.chain(net.parameters(), embed.parameters()), lr=0.01)
all_logits = []
for epoch in range(30):
for epoch in range(50):
logits = net(G, inputs)
# we save the logits for visualization later
all_logits.append(logits.detach())
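    # (sketch) The rest of the loop body is collapsed in this hunk. Steps (3) and (4)
    # described above compute a loss on the two labeled nodes and back-propagate;
    # the logging line is purely illustrative.
    logp = F.log_softmax(logits, 1)
    loss = F.nll_loss(logp[labeled_nodes], labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    print('Epoch %d | Loss: %.4f' % (epoch, loss.item()))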
......
......@@ -33,18 +33,50 @@ plt.show()
###############################################################################
# The examples here show the same graph, except that :class:`DGLGraph` is always directional.
# There are many ways to construct a :class:`DGLGraph`. Below are the allowed
# data types, ordered by our recommendation.
#
# You can also create a graph by calling the DGL interface.
# * A pair of arrays ``(u, v)`` storing the source and destination nodes respectively.
# They can be numpy arrays or tensor objects from the backend framework.
# * A ``scipy`` sparse matrix representing the adjacency matrix of the graph to be
#   constructed.
# * A ``networkx`` graph object.
# * A list of edges in the form of integer pairs.
#
# In the next example, you build a star graph. :class:`DGLGraph` nodes are a consecutive range of
# integers between 0 and :func:`number_of_nodes() <DGLGraph.number_of_nodes>`
# and can grow by calling :func:`add_nodes <DGLGraph.add_nodes>`.
# The examples below construct the same star graph via different methods.
#
# :class:`DGLGraph` nodes are a consecutive range of integers between 0 and
# :func:`number_of_nodes() <DGLGraph.number_of_nodes>`.
# :class:`DGLGraph` edges are in the order of their addition. Note that
# edges are accessed in much the same way as nodes, with one extra feature: *edge broadcasting*.
# edges are accessed in much the same way as nodes, with one extra feature:
# *edge broadcasting*.
import dgl
import torch as th
import numpy as np
import scipy.sparse as spp
# Create a star graph from a pair of arrays (using ``numpy.array`` works too).
u = th.tensor([0, 0, 0, 0, 0])
v = th.tensor([1, 2, 3, 4, 5])
star1 = dgl.DGLGraph((u, v))
# Create the same graph in one go! Essentially, if one of the arrays is a scalar,
# the value is automatically broadcasted to match the length of the other array
# -- a feature called *edge broadcasting*.
star2 = dgl.DGLGraph((0, v))
# Create the same graph from a scipy sparse matrix (using ``scipy.sparse.csr_matrix`` works too).
adj = spp.coo_matrix((np.ones(len(u)), (u.numpy(), v.numpy())))
star3 = dgl.DGLGraph(adj)
# Create the same graph from a list of integer pairs.
elist = [(0, 1), (0, 2), (0, 3), (0, 4), (0, 5)]
star4 = dgl.DGLGraph(elist)
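# A quick, illustrative check (not in the original tutorial) that these
# constructions all yield the same 6-node, 5-edge star:
for star in (star1, star2, star3, star4):
    print(star.number_of_nodes(), star.number_of_edges())  # 6 5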
###############################################################################
# You can also create a graph by progressively adding more nodes and edges.
# Although it is not as efficient as the above constructors, it is suitable
# for applications where the graph cannot be constructed in one shot.
g = dgl.DGLGraph()
g.add_nodes(10)
......@@ -63,12 +95,10 @@ g.clear(); g.add_nodes(10)
src = th.tensor(list(range(1, 10)))
g.add_edges(src, 0)
import networkx as nx
import matplotlib.pyplot as plt
# Visualize the graph.
nx.draw(g.to_networkx(), with_labels=True)
plt.show()
###############################################################################
# Assigning a feature
# -------------------
......@@ -89,19 +119,14 @@ import torch as th
x = th.randn(10, 3)
g.ndata['x'] = x
###############################################################################
# :func:`ndata <DGLGraph.ndata>` is syntactic sugar for accessing the states of
# all nodes. States are stored in a container ``data`` that hosts a user-defined
# dictionary.
print(g.ndata['x'] == g.nodes[:].data['x'])
# Access node set with integer, list, or integer tensor
g.nodes[0].data['x'] = th.zeros(1, 3)
g.nodes[[0, 1, 2]].data['x'] = th.zeros(3, 3)
g.nodes[th.tensor([0, 1, 2])].data['x'] = th.zeros(3, 3)
# :func:`ndata <DGLGraph.ndata>` is syntactic sugar for accessing the feature
# data of all nodes. To get the features of some particular nodes, slice out
# the corresponding rows.
g.ndata['x'][0] = th.zeros(1, 3)
g.ndata['x'][[0, 1, 2]] = th.zeros(3, 3)
g.ndata['x'][th.tensor([0, 1, 2])] = th.randn((3, 3))
###############################################################################
# Assigning edge features is similar to that of node features,
......@@ -110,14 +135,15 @@ g.nodes[th.tensor([0, 1, 2])].data['x'] = th.zeros(3, 3)
g.edata['w'] = th.randn(9, 2)
# Access edge set with IDs in integer, list, or integer tensor
g.edges[1].data['w'] = th.randn(1, 2)
g.edges[[0, 1, 2]].data['w'] = th.zeros(3, 2)
g.edges[th.tensor([0, 1, 2])].data['w'] = th.zeros(3, 2)
# You can also access the edges by giving endpoints
g.edges[1, 0].data['w'] = th.ones(1, 2) # edge 1 -> 0
g.edges[[1, 2, 3], [0, 0, 0]].data['w'] = th.ones(3, 2) # edges [1, 2, 3] -> 0
g.edata['w'][1] = th.randn(1, 2)
g.edata['w'][[0, 1, 2]] = th.zeros(3, 2)
g.edata['w'][th.tensor([0, 1, 2])] = th.zeros(3, 2)
# You can get the edge IDs by giving endpoints, which is useful for accessing the features.
g.edata['w'][g.edge_id(1, 0)] = th.ones(1, 2) # edge 1 -> 0
g.edata['w'][g.edge_ids([1, 2, 3], [0, 0, 0])] = th.ones(3, 2) # edges [1, 2, 3] -> 0
# Use edge broadcasting whenever applicable.
g.edata['w'][g.edge_ids([1, 2, 3], 0)] = th.ones(3, 2) # edges [1, 2, 3] -> 0
###############################################################################
# After assignments, each node or edge field will be associated with a scheme
......@@ -170,7 +196,6 @@ print(g_multi.edata['w'])
# * Updating a feature with a different scheme raises the risk of error on individual nodes (or
#   node subsets).
###############################################################################
# Next steps
# ----------
......
......@@ -72,6 +72,7 @@ plt.show()
# list of graph and label pairs.
import dgl
import torch
def collate(samples):
# The input `samples` is a list of pairs
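    # (sketch) The rest of this function is collapsed in the hunk. A typical
    # implementation batches the graphs with ``dgl.batch`` and stacks the labels;
    # the variable names below are illustrative.
    graphs, labels = map(list, zip(*samples))
    batched_graph = dgl.batch(graphs)
    return batched_graph, torch.tensor(labels)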
......@@ -99,57 +100,9 @@ def collate(samples):
# be called readout or aggregation. Finally, the graph
# representations are fed into a classifier :math:`g` to predict the graph labels.
#
# Graph convolution
# -----------------
# The graph convolution operation is basically the same as in the graph convolutional network (GCN); to learn more,
# see the GCN `tutorial <https://docs.dgl.ai/tutorials/models/1_gnn/1_gcn.html>`_. The only difference is
# that we replace :math:`h_{v}^{(l+1)} = \text{ReLU}\left(b^{(l)}+\sum_{u\in\mathcal{N}(v)}h_{u}^{(l)}W^{(l)}\right)`
# by
# :math:`h_{v}^{(l+1)} = \text{ReLU}\left(b^{(l)}+\frac{1}{|\mathcal{N}(v)|}\sum_{u\in\mathcal{N}(v)}h_{u}^{(l)}W^{(l)}\right)`
#
# Replacing the summation with an average balances nodes of different
# degrees, which gives better performance for this experiment.
#
# The self edges added during dataset initialization allow you to
# include the original node feature :math:`h_{v}^{(l)}` when taking the average.
import dgl.function as fn
import torch
import torch.nn as nn
# Sends a message of node feature h.
msg = fn.copy_src(src='h', out='m')
# Graph convolution layer can be found in the ``dgl.nn.<backend>`` submodule.
def reduce(nodes):
"""Take an average over all neighbor node features hu and use it to
overwrite the original node feature."""
accum = torch.mean(nodes.mailbox['m'], 1)
return {'h': accum}
class NodeApplyModule(nn.Module):
"""Update the node feature hv with ReLU(Whv+b)."""
def __init__(self, in_feats, out_feats, activation):
super(NodeApplyModule, self).__init__()
self.linear = nn.Linear(in_feats, out_feats)
self.activation = activation
def forward(self, node):
h = self.linear(node.data['h'])
h = self.activation(h)
return {'h' : h}
class GCN(nn.Module):
def __init__(self, in_feats, out_feats, activation):
super(GCN, self).__init__()
self.apply_mod = NodeApplyModule(in_feats, out_feats, activation)
def forward(self, g, feature):
# Initialize the node features with h.
g.ndata['h'] = feature
g.update_all(msg, reduce)
g.apply_nodes(func=self.apply_mod)
return g.ndata.pop('h')
from dgl.nn.pytorch import GraphConv
###############################################################################
# Readout and classification
......@@ -166,25 +119,25 @@ class GCN(nn.Module):
# graphs with variable size. You then feed the graph representations into a
# classifier with one linear layer to obtain pre-softmax logits.
import torch.nn as nn
import torch.nn.functional as F
class Classifier(nn.Module):
def __init__(self, in_dim, hidden_dim, n_classes):
super(Classifier, self).__init__()
self.layers = nn.ModuleList([
GCN(in_dim, hidden_dim, F.relu),
GCN(hidden_dim, hidden_dim, F.relu)])
self.conv1 = GraphConv(in_dim, hidden_dim)
self.conv2 = GraphConv(hidden_dim, hidden_dim)
self.classify = nn.Linear(hidden_dim, n_classes)
def forward(self, g):
# For undirected graphs, in_degree is the same as
# out_degree.
# Use node degree as the initial node feature. For undirected graphs, the in-degree
# is the same as the out-degree.
h = g.in_degrees().view(-1, 1).float()
for conv in self.layers:
h = conv(g, h)
# Perform graph convolution and activation function.
h = F.relu(self.conv1(g, h))
h = F.relu(self.conv2(g, h))
g.ndata['h'] = h
# Calculate graph representation by averaging all the node representations.
hg = dgl.mean_nodes(g, 'h')
return self.classify(hg)
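# As a quick, illustrative smoke test (not part of the original tutorial), the
# classifier can be run on a batch of tiny hand-made graphs; ``dgl.batch`` merges
# them into one graph and the model returns one row of logits per input graph.
tiny1 = dgl.DGLGraph((torch.tensor([0, 1, 1, 2]), torch.tensor([1, 0, 2, 1])))  # bidirectional 3-node path
tiny2 = dgl.DGLGraph((torch.tensor([0, 1]), torch.tensor([1, 0])))              # bidirectional 2-node graph
demo_model = Classifier(in_dim=1, hidden_dim=16, n_classes=2)
print(demo_model(dgl.batch([tiny1, tiny2])).shape)  # torch.Size([2, 2])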
......
......@@ -9,9 +9,14 @@ Yu Gai, Quan Gan, Zheng Zhang
This is a gentle introduction to using DGL to implement Graph Convolutional
Networks (Kipf & Welling et al., `Semi-Supervised Classification with Graph
Convolutional Networks <https://arxiv.org/pdf/1609.02907.pdf>`_). We build upon
the :doc:`earlier tutorial <../../basics/3_pagerank>` on DGLGraph and demonstrate
how DGL combines graphs with deep neural networks to learn structural representations.
Convolutional Networks <https://arxiv.org/pdf/1609.02907.pdf>`_). We explain
what is under the hood of the :class:`~dgl.nn.pytorch.GraphConv` module and
show how to define a new GNN layer using DGL's message passing APIs.
We build upon the :doc:`earlier tutorial <../../basics/3_pagerank>` on DGLGraph
and demonstrate how DGL combines graphs with deep neural networks to learn
structural representations.
"""
###############################################################################
......@@ -28,8 +33,8 @@ how DGL combines graph with deep neural network and learn structural representat
# representation :math:`\hat{h}_{u}` with a linear projection followed by a
# non-linearity: :math:`h_{u} = f(W_{u} \hat{h}_u)`.
#
# We will implement step 1 with DGL message passing, and step 2 with the
# ``apply_nodes`` method, whose node UDF will be a PyTorch ``nn.Module``.
# We will implement step 1 with DGL message passing, and step 2 by
# PyTorch ``nn.Module``.
#
# GCN implementation with DGL
# ``````````````````````````````````````````
......@@ -48,35 +53,23 @@ gcn_msg = fn.copy_src(src='h', out='m')
gcn_reduce = fn.sum(msg='m', out='h')
###############################################################################
# We then define the node UDF for ``apply_nodes``, which is a fully-connected layer:
# We then proceed to define the GCNLayer module. A GCNLayer essentially performs
# message passing on all the nodes and then applies a fully-connected layer.
class NodeApplyModule(nn.Module):
def __init__(self, in_feats, out_feats, activation):
super(NodeApplyModule, self).__init__()
class GCNLayer(nn.Module):
def __init__(self, in_feats, out_feats):
super(GCNLayer, self).__init__()
self.linear = nn.Linear(in_feats, out_feats)
self.activation = activation
def forward(self, node):
h = self.linear(node.data['h'])
if self.activation is not None:
h = self.activation(h)
return {'h' : h}
###############################################################################
# We then proceed to define the GCN module. A GCN layer essentially performs
# message passing on all the nodes and then applies the ``NodeApplyModule``. Note
# that we omit the dropout from the paper for simplicity.
class GCN(nn.Module):
def __init__(self, in_feats, out_feats, activation):
super(GCN, self).__init__()
self.apply_mod = NodeApplyModule(in_feats, out_feats, activation)
def forward(self, g, feature):
# Creating a local scope so that all the stored ndata and edata
# (such as the `'h'` ndata below) are automatically popped out
# when the scope exits.
with g.local_scope():
g.ndata['h'] = feature
g.update_all(gcn_msg, gcn_reduce)
g.apply_nodes(func=self.apply_mod)
return g.ndata.pop('h')
h = g.ndata['h']
return self.linear(h)
###############################################################################
# The forward function is essentially the same as any other commonly seen NNs
......@@ -84,17 +77,17 @@ class GCN(nn.Module):
# let's define a simple neural network consisting of two GCN layers. Suppose we
# are training the classifier for the Cora dataset (the input feature size is
# 1433 and the number of classes is 7). The last GCN layer computes node embeddings,
# so the last layer in general doesn't apply activation.
# so the last layer in general does not apply activation.
class Net(nn.Module):
def __init__(self):
super(Net, self).__init__()
self.gcn1 = GCN(1433, 16, F.relu)
self.gcn2 = GCN(16, 7, None)
self.layer1 = GCNLayer(1433, 16)
self.layer2 = GCNLayer(16, 7)
def forward(self, g, features):
x = self.gcn1(g, features)
x = self.gcn2(g, x)
x = F.relu(self.layer1(g, features))
x = self.layer2(g, x)
return x
net = Net()
print(net)
......@@ -110,11 +103,7 @@ def load_cora_data():
labels = th.LongTensor(data.labels)
train_mask = th.BoolTensor(data.train_mask)
test_mask = th.BoolTensor(data.test_mask)
g = data.graph
# add self loop
g.remove_edges_from(nx.selfloop_edges(g))
g = DGLGraph(g)
g.add_edges(g.nodes(), g.nodes())
g = DGLGraph(data.graph)
return g, features, labels, train_mask, test_mask
###############################################################################
......@@ -137,7 +126,7 @@ def evaluate(model, g, features, labels, mask):
import time
import numpy as np
g, features, labels, train_mask, test_mask = load_cora_data()
optimizer = th.optim.Adam(net.parameters(), lr=1e-3)
optimizer = th.optim.Adam(net.parameters(), lr=1e-2)
dur = []
for epoch in range(50):
if epoch >=3:
......
......@@ -298,9 +298,7 @@ lr = 0.01 # learning rate
l2norm = 0 # L2 norm coefficient
# create graph
g = DGLGraph()
g.add_nodes(num_nodes)
g.add_edges(data.edge_src, data.edge_dst)
g = DGLGraph((data.edge_src, data.edge_dst))
g.edata.update({'rel_type': edge_type, 'norm': edge_norm})
# create model
......
......@@ -94,6 +94,15 @@ structure-free normalization, in the style of attention.
# GAT in DGL
# ----------
#
# DGL provides an off-the-shelf implementation of the GAT layer under the ``dgl.nn.<backend>``
# subpackage. Simply import ``GATConv`` as follows.
from dgl.nn.pytorch import GATConv
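# A minimal, illustrative call (the tiny graph and sizes below are arbitrary):
# a ``GATConv`` with ``num_heads`` attention heads maps node features of size
# ``in_feats`` to an output of shape (num_nodes, num_heads, out_feats).
import dgl
import torch
g_demo = dgl.DGLGraph((torch.tensor([0, 1, 2, 0]), torch.tensor([1, 2, 0, 2])))
gat_demo = GATConv(in_feats=4, out_feats=3, num_heads=2)
print(gat_demo(g_demo, torch.randn(3, 4)).shape)  # torch.Size([3, 2, 3])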
###############################################################
# Readers can skip the following step-by-step explanation of the implementation and
# jump to `Put everything together`_ for the training and visualization results.
#
# To begin, you can get an overall impression of how a ``GATLayer`` module is
# implemented in DGL. In this section, the four equations above are broken down
# one at a time.
......@@ -277,11 +286,7 @@ def load_cora_data():
features = torch.FloatTensor(data.features)
labels = torch.LongTensor(data.labels)
mask = torch.BoolTensor(data.train_mask)
g = data.graph
# add self loop
g.remove_edges_from(nx.selfloop_edges(g))
g = DGLGraph(g)
g.add_edges(g.nodes(), g.nodes())
g = DGLGraph(data.graph)
return g, features, labels, mask
##############################################################################
......
......@@ -68,18 +68,11 @@ import dgl
def init_graph(in_nodes, out_nodes, f_size):
g = dgl.DGLGraph()
all_nodes = in_nodes + out_nodes
g.add_nodes(all_nodes)
in_indx = list(range(in_nodes))
out_indx = list(range(in_nodes, in_nodes + out_nodes))
# add edges using edge broadcasting
for u in in_indx:
g.add_edges(u, out_indx)
u = np.repeat(np.arange(in_nodes), out_nodes)
v = np.tile(np.arange(in_nodes, in_nodes + out_nodes), in_nodes)
g = dgl.DGLGraph((u, v))
# init states
g.ndata['v'] = th.zeros(all_nodes, f_size)
g.ndata['v'] = th.zeros(in_nodes + out_nodes, f_size)
g.edata['b'] = th.zeros(in_nodes * out_nodes, 1)
return g
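# An illustrative call (sizes are arbitrary): a routing graph that fully connects
# 20 input capsules to 10 output capsules, each carrying an 8-dimensional feature.
g_demo = init_graph(in_nodes=20, out_nodes=10, f_size=8)
print(g_demo.number_of_nodes(), g_demo.number_of_edges())  # 30 200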
......@@ -113,6 +106,8 @@ def init_graph(in_nodes, out_nodes, f_size):
# - The scalar product :math:`\hat{u}_{j|i}\cdot v_j` can be considered as how well capsule :math:`i` agrees with :math:`j`. It is used to update
# :math:`b_{ij}=b_{ij}+\hat{u}_{j|i}\cdot v_j`
import dgl.function as fn
class DGLRoutingLayer(nn.Module):
def __init__(self, in_nodes, out_nodes, f_size):
super(DGLRoutingLayer, self).__init__()
......@@ -125,24 +120,14 @@ class DGLRoutingLayer(nn.Module):
def forward(self, u_hat, routing_num=1):
self.g.edata['u_hat'] = u_hat
# step 2 (line 5)
def cap_message(edges):
return {'m': edges.data['c'] * edges.data['u_hat']}
self.g.register_message_func(cap_message)
def cap_reduce(nodes):
return {'s': th.sum(nodes.mailbox['m'], dim=1)}
self.g.register_reduce_func(cap_reduce)
for r in range(routing_num):
# step 1 (line 4): normalize over out edges
edges_b = self.g.edata['b'].view(self.in_nodes, self.out_nodes)
self.g.edata['c'] = F.softmax(edges_b, dim=1).view(-1, 1)
self.g.edata['c u_hat'] = self.g.edata['c'] * self.g.edata['u_hat']
# Execute step 1 & 2
self.g.update_all()
self.g.update_all(fn.copy_e('c u_hat', 'm'), fn.sum('m', 's'))
# step 3 (line 6)
self.g.nodes[self.out_indx].data['v'] = self.squash(self.g.nodes[self.out_indx].data['s'], dim=1)
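    # (sketch) ``self.squash`` is defined elsewhere in this file (elided in the hunk);
    # a version consistent with the capsule paper's non-linearity
    # v = (|s|^2 / (1 + |s|^2)) * s / |s| would look like:
    @staticmethod
    def squash(s, dim=1):
        sq = th.sum(s ** 2, dim=dim, keepdim=True)
        s_norm = th.sqrt(sq)
        return (sq / (1.0 + sq)) * s / (s_norm + 1e-9)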
......