Unverified Commit dce89919 authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Auto-reformat multiple python folders. (#5325)



* auto-reformat

* lintrunner

---------
Co-authored-by: Ubuntu <ubuntu@ip-172-31-28-63.ap-northeast-1.compute.internal>
parent ab812179
@@ -15,376 +15,408 @@ Tree-LSTM in DGL
efficiency. For recommended implementation, please refer to the `official
examples <https://github.com/dmlc/dgl/tree/master/examples>`_.
"""
import os

##############################################################################
#
# In this tutorial, you learn to use Tree-LSTM networks for sentiment analysis.
# The Tree-LSTM is a generalization of long short-term memory (LSTM) networks to tree-structured network topologies.
#
# The Tree-LSTM structure was first introduced by Tai et al. in an ACL 2015
# paper: `Improved Semantic Representations From Tree-Structured Long
# Short-Term Memory Networks <https://arxiv.org/pdf/1503.00075.pdf>`__.
# The core idea is to introduce syntactic information for language tasks by
# extending the chain-structured LSTM to a tree-structured LSTM. The dependency
# tree and constituency tree techniques are leveraged to obtain a "latent tree".
#
# The challenge in training Tree-LSTMs is batching --- a standard
# technique in machine learning to accelerate optimization. However, since trees
# generally have different shapes by nature, parallelization is non-trivial.
# DGL offers an alternative. Pool all the trees into one single graph, then
# induce message passing over them, guided by the structure of each tree.
#
# The task and the dataset
# ------------------------
#
# The steps here use the
# `Stanford Sentiment Treebank <https://nlp.stanford.edu/sentiment/>`__ in
# ``dgl.data``. The dataset provides a fine-grained, tree-level sentiment
# annotation. There are five classes: very negative, negative, neutral, positive, and
# very positive, which indicate the sentiment of the current subtree. Non-leaf
# nodes in a constituency tree do not contain words, so use a special
# ``PAD_WORD`` token to denote them. During training and inference,
# their embeddings are masked to all-zero.
#
# .. figure:: https://i.loli.net/2018/11/08/5be3d4bfe031b.png
#    :alt:
#
# The figure displays one sample of the SST dataset, which is a
# constituency parse tree with its nodes labeled with sentiment. To
# speed things up, build a tiny set with five sentences and take a look
# at the first one.
#
from collections import namedtuple

os.environ["DGLBACKEND"] = "pytorch"
import dgl
from dgl.data.tree import SSTDataset

SSTBatch = namedtuple("SSTBatch", ["graph", "mask", "wordid", "label"])

# Each sample in the dataset is a constituency tree. The leaf nodes
# represent words. The word is an int value stored in the "x" field.
# The non-leaf nodes have a special word PAD_WORD. The sentiment
# label is stored in the "y" feature field.
trainset = SSTDataset(mode="tiny")  # the "tiny" set has only five trees
tiny_sst = [tr for tr in trainset]
num_vocabs = trainset.vocab_size
num_classes = trainset.num_classes

vocab = trainset.vocab  # vocabulary dict: key -> id
inv_vocab = {
    v: k for k, v in vocab.items()
}  # inverted vocabulary dict: id -> word

a_tree = tiny_sst[0]
for token in a_tree.ndata["x"].tolist():
    if token != trainset.PAD_WORD:
        print(inv_vocab[token], end=" ")
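##############################################################################
# As a quick, illustrative check of the padding scheme described above (using
# only the objects already built here): the tree has more nodes than printed
# words, because every internal node of the constituency tree carries the
# ``PAD_WORD`` token rather than a real word.

num_leaf_tokens = int((a_tree.ndata["x"] != trainset.PAD_WORD).sum())
print(a_tree.number_of_nodes(), num_leaf_tokens)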
import matplotlib.pyplot as plt

##############################################################################
# Step 1: Batching
# ----------------
#
# Add all the trees to one graph, using
# the :func:`~dgl.batched_graph.batch` API.
#

import networkx as nx

graph = dgl.batch(tiny_sst)


def plot_tree(g):
    # this plot requires pygraphviz package
    pos = nx.nx_agraph.graphviz_layout(g, prog="dot")
    nx.draw(
        g,
        pos,
        with_labels=False,
        node_size=10,
        node_color=[[0.5, 0.5, 0.5]],
        arrowsize=4,
    )
    plt.show()


plot_tree(graph.to_networkx())
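##############################################################################
# A minimal, illustrative check of the batched graph (this assumes DGL's
# standard batched-graph helpers ``batch_num_nodes`` and :func:`~dgl.unbatch`):
# the batched graph records how many nodes came from each tree and can be
# split back into the original five trees.

print(graph.batch_num_nodes())  # number of nodes of each of the five trees
print(graph.number_of_nodes())  # their sum: total nodes in the batched graph
print(len(dgl.unbatch(graph)))  # 5, one DGLGraph per tree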

#################################################################################
# You can read more about the definition of :func:`~dgl.batch`, or
# skip ahead to the next step:
#
# .. note::
#
#    **Definition**: :func:`~dgl.batch` unions a list of :math:`B`
#    :class:`~dgl.DGLGraph`\ s and returns a :class:`~dgl.DGLGraph` of batch
#    size :math:`B`.
#
#    - The union includes all the nodes,
#      edges, and their features. The order of nodes, edges, and features is
#      preserved.
#
#    - Given that you have :math:`V_i` nodes for graph
#      :math:`\mathcal{G}_i`, the node ID :math:`j` in graph
#      :math:`\mathcal{G}_i` corresponds to node ID
#      :math:`j + \sum_{k=1}^{i-1} V_k` in the batched graph.
#
#    - Therefore, performing feature transformation and message passing on
#      the batched graph is equivalent to doing those
#      on all ``DGLGraph`` constituents in parallel.
#
#    - Duplicate references to the same graph are
#      treated as deep copies; the nodes, edges, and features are duplicated,
#      and mutation on one reference does not affect the other.
#
#    - The batched graph keeps track of the meta
#      information of the constituents so it can be
#      :func:`~dgl.batched_graph.unbatch`\ ed into a list of ``DGLGraph``\ s.
#
# Step 2: Tree-LSTM cell with message-passing APIs
# ------------------------------------------------
#
# Researchers have proposed two types of Tree-LSTMs: Child-Sum
# Tree-LSTMs and :math:`N`-ary Tree-LSTMs. In this tutorial you focus
# on applying a *binary* Tree-LSTM to binarized constituency trees. This
# application is also known as *Constituency Tree-LSTM*. Use PyTorch
# as a backend framework to set up the network.
#
# In the `N`-ary Tree-LSTM, each unit at node :math:`j` maintains a hidden
# representation :math:`h_j` and a memory cell :math:`c_j`. The unit
# :math:`j` takes the input vector :math:`x_j` and the hidden
# representations of the child units, :math:`h_{jl}, 1\leq l\leq N`, as
# input, then updates its hidden representation :math:`h_j` and memory
# cell :math:`c_j` by:
#
# .. math::
#
#    i_j & = & \sigma\left(W^{(i)}x_j + \sum_{l=1}^{N}U^{(i)}_l h_{jl} + b^{(i)}\right),  & (1)\\
#    f_{jk} & = & \sigma\left(W^{(f)}x_j + \sum_{l=1}^{N}U_{kl}^{(f)} h_{jl} + b^{(f)} \right), &  (2)\\
#    o_j & = & \sigma\left(W^{(o)}x_j + \sum_{l=1}^{N}U_{l}^{(o)} h_{jl} + b^{(o)} \right), & (3) \\
#    u_j & = & \textrm{tanh}\left(W^{(u)}x_j + \sum_{l=1}^{N} U_l^{(u)}h_{jl} + b^{(u)} \right), & (4)\\
#    c_j & = & i_j \odot u_j + \sum_{l=1}^{N} f_{jl} \odot c_{jl}, & (5) \\
#    h_j & = & o_j \cdot \textrm{tanh}(c_j), & (6) \\
#
# It can be decomposed into three phases: ``message_func``,
# ``reduce_func`` and ``apply_node_func``.
#
# .. note::
#    ``apply_node_func`` is a new node UDF that has not been introduced before. In
#    ``apply_node_func``, a user specifies what to do with node features,
#    without considering edge features and messages. In the Tree-LSTM case,
#    ``apply_node_func`` is a must, since there exist (leaf) nodes with
#    :math:`0` incoming edges, which would not be updated by
#    ``reduce_func``.
#

import torch as th
import torch.nn as nn


class TreeLSTMCell(nn.Module):
    def __init__(self, x_size, h_size):
        super(TreeLSTMCell, self).__init__()
        self.W_iou = nn.Linear(x_size, 3 * h_size, bias=False)
        self.U_iou = nn.Linear(2 * h_size, 3 * h_size, bias=False)
        self.b_iou = nn.Parameter(th.zeros(1, 3 * h_size))
        self.U_f = nn.Linear(2 * h_size, 2 * h_size)

    def message_func(self, edges):
        return {"h": edges.src["h"], "c": edges.src["c"]}

    def reduce_func(self, nodes):
        # concatenate h_jl for equation (1), (2), (3), (4)
        h_cat = nodes.mailbox["h"].view(nodes.mailbox["h"].size(0), -1)
        # equation (2)
        f = th.sigmoid(self.U_f(h_cat)).view(*nodes.mailbox["h"].size())
        # second term of equation (5)
        c = th.sum(f * nodes.mailbox["c"], 1)
        return {"iou": self.U_iou(h_cat), "c": c}

    def apply_node_func(self, nodes):
        # equation (1), (3), (4)
        iou = nodes.data["iou"] + self.b_iou
        i, o, u = th.chunk(iou, 3, 1)
        i, o, u = th.sigmoid(i), th.sigmoid(o), th.tanh(u)
        # equation (5)
        c = i * u + nodes.data["c"]
        # equation (6)
        h = o * th.tanh(c)
        return {"h": h, "c": c}
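##############################################################################
# As a quick, illustrative shape check (the sizes below are arbitrary toy
# values, not the ones used later): ``W_iou`` packs the input, output, and
# update projections of equations (1), (3), and (4) into a single linear
# layer, while ``U_f`` produces one forget gate per child as in equation (2).

_demo_cell = TreeLSTMCell(x_size=4, h_size=5)
print(_demo_cell.W_iou.weight.shape)  # torch.Size([15, 4]): 3 * h_size rows
print(_demo_cell.U_f.weight.shape)  # torch.Size([10, 10]): 2 children * h_size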

##############################################################################
# Step 3: Define traversal
# ------------------------
#
# After you define the message-passing functions, induce the
# right order to trigger them. This is a significant departure from models
# such as GCN, where all nodes pull messages from upstream ones
# *simultaneously*.
#
# In the case of Tree-LSTM, messages start from the leaves of the tree and
# are propagated and processed upwards until they reach the roots. A
# visualization is as follows:
#
# .. figure:: https://i.loli.net/2018/11/09/5be4b5d2df54d.gif
#    :alt:
#
# DGL defines a generator to perform the topological sort; each item is a
# tensor recording the nodes from the bottom level to the roots. One can
# appreciate the degree of parallelism by inspecting the difference between
# the following:
#

# to heterogeneous graph
trv_a_tree = dgl.graph(a_tree.edges())
print("Traversing one tree:")
print(dgl.topological_nodes_generator(trv_a_tree))

# to heterogeneous graph
trv_graph = dgl.graph(graph.edges())
print("Traversing many trees at the same time:")
print(dgl.topological_nodes_generator(trv_graph))
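##############################################################################
# A small, illustrative check of the traversal: the number of frontiers equals
# the number of levels of the deepest tree in the batch, and the first
# frontier contains every node with no incoming edges, that is, the leaves
# (edges in these trees point from children to parents), which is why all of
# them can be processed in parallel.

frontiers = list(dgl.topological_nodes_generator(trv_graph))
print(len(frontiers))  # number of levels of the deepest tree in the batch
print(len(frontiers[0]))  # number of leaf nodes processed in the first step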

##############################################################################
# Call :meth:`~dgl.DGLGraph.prop_nodes` to trigger the message passing:

import dgl.function as fn
import torch as th

trv_graph.ndata["a"] = th.ones(graph.number_of_nodes(), 1)
traversal_order = dgl.topological_nodes_generator(trv_graph)
trv_graph.prop_nodes(
    traversal_order,
    message_func=fn.copy_u("a", "a"),
    reduce_func=fn.sum("a", "a"),
)

# the following is syntactic sugar that does the same
# dgl.prop_nodes_topo(graph)

##############################################################################
# .. note::
#
#    Before you call :meth:`~dgl.DGLGraph.prop_nodes`, specify a
#    `message_func` and a `reduce_func` in advance. In the example, the
#    built-in copy-from-source function serves as the message function and
#    the built-in sum function as the reduce function, for demonstration.
#
# Putting it together
# -------------------
#
# Here is the complete code that specifies the ``Tree-LSTM`` class.
#

class TreeLSTM(nn.Module):
    def __init__(
        self,
        num_vocabs,
        x_size,
        h_size,
        num_classes,
        dropout,
        pretrained_emb=None,
    ):
        super(TreeLSTM, self).__init__()
        self.x_size = x_size
        self.embedding = nn.Embedding(num_vocabs, x_size)
        if pretrained_emb is not None:
            print("Using glove")
            self.embedding.weight.data.copy_(pretrained_emb)
            self.embedding.weight.requires_grad = True
        self.dropout = nn.Dropout(dropout)
        self.linear = nn.Linear(h_size, num_classes)
        self.cell = TreeLSTMCell(x_size, h_size)

    def forward(self, batch, h, c):
        """Compute tree-lstm prediction given a batch.

        Parameters
        ----------
        batch : dgl.data.SSTBatch
            The data batch.
        h : Tensor
            Initial hidden state.
        c : Tensor
            Initial cell state.

        Returns
        -------
        logits : Tensor
            The prediction of each node.
        """
        g = batch.graph
        # to heterogeneous graph
        g = dgl.graph(g.edges())
        # feed embedding
        embeds = self.embedding(batch.wordid * batch.mask)
        g.ndata["iou"] = self.cell.W_iou(
            self.dropout(embeds)
        ) * batch.mask.float().unsqueeze(-1)
        g.ndata["h"] = h
        g.ndata["c"] = c
        # propagate
        dgl.prop_nodes_topo(
            g,
            message_func=self.cell.message_func,
            reduce_func=self.cell.reduce_func,
            apply_node_func=self.cell.apply_node_func,
        )
        # compute logits
        h = self.dropout(g.ndata.pop("h"))
        logits = self.linear(h)
        return logits

import torch.nn.functional as F

##############################################################################
# Main Loop
# ---------
#
# Finally, you could write a training paradigm in PyTorch.
#

from torch.utils.data import DataLoader

device = th.device("cpu")
# hyperparameters
x_size = 256
h_size = 256
dropout = 0.5
lr = 0.05
weight_decay = 1e-4
epochs = 10

# create the model
model = TreeLSTM(
    trainset.vocab_size, x_size, h_size, trainset.num_classes, dropout
)
print(model)

# create the optimizer
optimizer = th.optim.Adagrad(
    model.parameters(), lr=lr, weight_decay=weight_decay
)


def batcher(dev):
    def batcher_dev(batch):
        batch_trees = dgl.batch(batch)
        return SSTBatch(
            graph=batch_trees,
            mask=batch_trees.ndata["mask"].to(device),
            wordid=batch_trees.ndata["x"].to(device),
            label=batch_trees.ndata["y"].to(device),
        )

    return batcher_dev

train_loader = DataLoader(
    dataset=tiny_sst,
    batch_size=5,
    collate_fn=batcher(device),
    shuffle=False,
    num_workers=0,
)

# training loop
for epoch in range(epochs):
    for step, batch in enumerate(train_loader):
        g = batch.graph
        n = g.number_of_nodes()
        h = th.zeros((n, h_size))
        c = th.zeros((n, h_size))
        logits = model(batch, h, c)
        logp = F.log_softmax(logits, 1)
        loss = F.nll_loss(logp, batch.label, reduction="sum")
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        pred = th.argmax(logits, 1)
        acc = float(th.sum(th.eq(batch.label, pred))) / len(batch.label)
        print(
            "Epoch {:05d} | Step {:05d} | Loss {:.4f} | Acc {:.4f} |".format(
                epoch, step, loss.item(), acc
            )
        )

##############################################################################
# To train the model on a full dataset with different settings (such as CPU or GPU),
# refer to the `PyTorch example <https://github.com/dmlc/dgl/tree/master/examples/pytorch/tree_lstm>`__.
# There is also an implementation of the Child-Sum Tree-LSTM.
@@ -14,764 +14,780 @@ Generative Models of Graphs
efficiency. For recommended implementation, please refer to the `official
examples <https://github.com/dmlc/dgl/tree/master/examples>`_.
"""
##############################################################################
#
# In this tutorial, you learn how to train and generate one graph at
# a time. You also explore parallelism within the graph embedding operation, which is an
# essential building block. The tutorial ends with a simple optimization that
# delivers double the speed by batching across graphs.
#
# Earlier tutorials showed how embedding a graph or
# a node enables you to work on tasks such as `semi-supervised classification for nodes
# <http://docs.dgl.ai/tutorials/models/1_gcn.html#sphx-glr-tutorials-models-1-gcn-py>`__
# or `sentiment analysis
# <http://docs.dgl.ai/tutorials/models/3_tree-lstm.html#sphx-glr-tutorials-models-3-tree-lstm-py>`__.
# Wouldn't it be interesting to predict the future evolution of the graph and
# perform the analysis iteratively?
#
# To address the evolution of the graphs, you generate a variety of graph samples. In other words, you need
# **generative models** of graphs. In addition to learning
# node and edge features, you would need to model the distribution of arbitrary graphs.
# While general generative models can model the density function explicitly and
# implicitly and generate samples at once or sequentially, you only focus
# on explicit generative models for sequential generation here. Typical applications
# include drug or materials discovery, chemical processes, or proteomics.
#
# Introduction
# --------------------
# The primitive actions of mutating a graph in Deep Graph Library (DGL) are nothing more than ``add_nodes``
# and ``add_edges``. That is, if you were to draw a circle of three nodes,
#
# .. figure:: https://user-images.githubusercontent.com/19576924/48313438-78baf000-e5f7-11e8-931e-cd00ab34fa50.gif
#    :alt:
#
# you can write the code as follows.
#

import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl

g = dgl.DGLGraph()
g.add_nodes(1)  # Add node 0
g.add_nodes(1)  # Add node 1

# Edges in DGLGraph are directed by default.
# For undirected edges, add edges for both directions.
g.add_edges([1, 0], [0, 1])  # Add edges (1, 0), (0, 1)
g.add_nodes(1)  # Add node 2
g.add_edges([2, 1], [1, 2])  # Add edges (2, 1), (1, 2)
g.add_edges([2, 0], [0, 2])  # Add edges (2, 0), (0, 2)
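#######################################################################################
# A quick, illustrative sanity check of the triangle built above: three nodes
# and six directed edges, since each undirected edge is stored as a pair of
# directed edges.

print(g.number_of_nodes(), g.number_of_edges())  # 3 6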

#######################################################################################
# Real-world graphs are much more complex. There are many families of graphs,
# with different sizes, topologies, node types, edge types, and the possibility
# of multigraphs. Besides, the same graph can be generated in many different
# orders. Regardless, the generative process entails a few steps.
#
# - Encode a changing graph.
# - Perform actions stochastically.
# - If you are training, collect error signals and optimize the model parameters.
#
# When it comes to implementation, another important aspect is speed. How do you
# parallelize the computation, given that generating a graph is fundamentally a
# sequential process?
#
# .. note::
#
#    To be sure, this is not necessarily a hard constraint. Subgraphs can be
#    built in parallel and then get assembled. But we
#    will restrict ourselves to the sequential processes for this tutorial.
#
#
# DGMG: The main flow
# --------------------
# For this tutorial, you use
# `Deep Generative Models of Graphs <https://arxiv.org/abs/1803.03324>`__
# (DGMG) to implement a graph generative model using DGL. Its algorithmic
# framework is general but also challenging to parallelize.
#
# .. note::
#
#    While it's possible for DGMG to handle complex graphs with typed nodes,
#    typed edges, and multigraphs, here you use a simplified version of it
#    for generating graph topologies.
#
# DGMG generates a graph by following a state machine, which is basically a
# two-level loop. Generate one node at a time and connect it to a subset of
# the existing nodes, one at a time. This is similar to language modeling. The
# generative process is an iterative one that emits one word, character, or sentence
# at a time, conditioned on the sequence generated so far.
#
# At each time step, you either:
#
#    - Add a new node to the graph
#    - Select two existing nodes and add an edge between them
#
# .. figure:: https://user-images.githubusercontent.com/19576924/48605003-7f11e900-e9b6-11e8-8880-87362348e154.png
#    :alt:
#
# The Python code will look as follows. In fact, this is *exactly* how inference
# with DGMG is implemented in DGL.
#


def forward_inference(self):
    stop = self.add_node_and_update()
    while (not stop) and (self.g.number_of_nodes() < self.v_max + 1):
        num_trials = 0
        to_add_edge = self.add_edge_or_not()
        while to_add_edge and (num_trials < self.g.number_of_nodes() - 1):
            self.choose_dest_and_update()
            num_trials += 1
            to_add_edge = self.add_edge_or_not()
        stop = self.add_node_and_update()
    return self.g

#######################################################################################
# Assume you have a pre-trained model for generating cycles with 10 to 20 nodes.
# How does it generate a cycle on-the-fly during inference? Use the code below
# to create an animation with your own model.
#
# ::
#
#     import torch
#     import matplotlib.animation as animation
#     import matplotlib.pyplot as plt
#     import networkx as nx
#     from copy import deepcopy
#
#     if __name__ == '__main__':
#         # pre-trained model saved with path ./model.pth
#         model = torch.load('./model.pth')
#         model.eval()
#         g = model()
#
#         src_list = g.edges()[1]
#         dest_list = g.edges()[0]
#
#         evolution = []
#
#         nx_g = nx.Graph()
#         evolution.append(deepcopy(nx_g))
#
#         for i in range(0, len(src_list), 2):
#             src = src_list[i].item()
#             dest = dest_list[i].item()
#             if src not in nx_g.nodes():
#                 nx_g.add_node(src)
#                 evolution.append(deepcopy(nx_g))
#             if dest not in nx_g.nodes():
#                 nx_g.add_node(dest)
#                 evolution.append(deepcopy(nx_g))
#             nx_g.add_edges_from([(src, dest), (dest, src)])
#             evolution.append(deepcopy(nx_g))
#
#         def animate(i):
#             ax.cla()
#             g_t = evolution[i]
#             nx.draw_circular(g_t, with_labels=True, ax=ax,
#                              node_color=['#FEBD69'] * g_t.number_of_nodes())
#
#         fig, ax = plt.subplots()
#         ani = animation.FuncAnimation(fig, animate,
#                                       frames=len(evolution),
#                                       interval=600)
#
# .. figure:: https://user-images.githubusercontent.com/19576924/48928548-2644d200-ef1b-11e8-8591-da93345382ad.gif
#    :alt:
#
# DGMG: Optimization objective
# ------------------------------
# Similar to language modeling, DGMG trains the model with *behavior cloning*,
# or *teacher forcing*. Assume that for each graph there exists a sequence of
# *oracle actions* :math:`a_{1},\cdots,a_{T}` that generates it. What the model
# does is to follow these actions, compute the joint probabilities of such
# action sequences, and maximize them.
#
# By the chain rule, the probability of taking :math:`a_{1},\cdots,a_{T}` is:
#
# .. math::
#
#    p(a_{1},\cdots, a_{T}) = p(a_{1})p(a_{2}|a_{1})\cdots p(a_{T}|a_{1},\cdots,a_{T-1}).\\
#
# The optimization objective is then simply the typical MLE loss:
#
# .. math::
#
#    -\log p(a_{1},\cdots,a_{T})=-\sum_{t=1}^{T}\log p(a_{t}|a_{1},\cdots, a_{t-1}).\\
#
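#######################################################################################
# As a small worked example (with made-up per-step probabilities, purely for
# illustration): if a three-step oracle action sequence is assigned
# probabilities 0.9, 0.5, and 0.8 by the model, its loss is
# :math:`-(\log 0.9 + \log 0.5 + \log 0.8)`.

import math

print(-(math.log(0.9) + math.log(0.5) + math.log(0.8)))  # ~1.02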

def forward_train(self, actions):
    """
    - actions: list
        - Contains a_1, ..., a_T described above
    - self.prepare_for_train()
        - Initializes self.action_step to be 0, which will get
          incremented by 1 every time it is called.
        - Initializes objects recording log p(a_t|a_1,...a_{t-1})

    Returns
    -------
    - self.get_log_prob(): log p(a_1, ..., a_T)
    """
    self.prepare_for_train()

    stop = self.add_node_and_update(a=actions[self.action_step])
    while not stop:
        to_add_edge = self.add_edge_or_not(a=actions[self.action_step])
        while to_add_edge:
            self.choose_dest_and_update(a=actions[self.action_step])
            to_add_edge = self.add_edge_or_not(a=actions[self.action_step])
        stop = self.add_node_and_update(a=actions[self.action_step])
    return self.get_log_prob()

#######################################################################################
# The key difference between ``forward_train`` and ``forward_inference`` is
# that the training process takes oracle actions as input and returns log
# probabilities for evaluating the loss.
#
# DGMG: The implementation
# --------------------------
# The ``DGMG`` class
# ``````````````````````````
# Below you can find the skeleton code for the model. You gradually
# fill in the details for each function.
#

import torch.nn as nn


class DGMGSkeleton(nn.Module):
    def __init__(self, v_max):
        """
        Parameters
        ----------
        v_max: int
            Max number of nodes considered
        """
        super(DGMGSkeleton, self).__init__()

        # Graph configuration
        self.v_max = v_max

    def add_node_and_update(self, a=None):
        """Decide whether to add a new node.
        If a new node should be added, update the graph."""
        raise NotImplementedError

    def add_edge_or_not(self, a=None):
        """Decide whether a new edge should be added."""
        raise NotImplementedError

    def choose_dest_and_update(self, a=None):
        """Choose a destination and connect it to the latest node.
        Add edges for both directions and update the graph."""
        raise NotImplementedError

    def forward_train(self, actions):
        """Forward at training time. It records the probability
        of generating a ground-truth graph following the actions."""
        raise NotImplementedError

    def forward_inference(self):
        """Forward at inference time.
        It generates graphs on the fly."""
        raise NotImplementedError

    def forward(self, actions=None):
        # The graph you will work on
        self.g = dgl.DGLGraph()

        # If there are some features for nodes and edges,
        # zero tensors will be set for those of new nodes and edges.
        self.g.set_n_initializer(dgl.frame.zero_initializer)
        self.g.set_e_initializer(dgl.frame.zero_initializer)

        if self.training:
            return self.forward_train(actions=actions)
        else:
            return self.forward_inference()
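#######################################################################################
# Note how ``forward`` dispatches on ``self.training``, the standard
# ``nn.Module`` flag toggled by ``train()`` and ``eval()``. A small
# illustrative check (``v_max=15`` is an arbitrary value):

skeleton = DGMGSkeleton(v_max=15)
print(skeleton.training)  # True: forward() would run forward_train
skeleton.eval()
print(skeleton.training)  # False: forward() would run forward_inference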

#######################################################################################
# Encoding a dynamic graph
# ``````````````````````````
# All the actions generating a graph are sampled from probability
# distributions. In order to do that, you project the structured data,
# namely the graph, onto a Euclidean space. The challenge is that such a
# process, called *embedding*, needs to be repeated as the graphs mutate.
#
# Graph embedding
# ''''''''''''''''''''''''''
# Let :math:`G=(V,E)` be an arbitrary graph. Each node :math:`v` has an
# embedding vector :math:`\textbf{h}_{v} \in \mathbb{R}^{n}`. Similarly,
# the graph has an embedding vector :math:`\textbf{h}_{G} \in \mathbb{R}^{k}`.
# Typically, :math:`k > n` since a graph contains more information than
# an individual node.
#
# The graph embedding is a weighted sum of node embeddings under a linear
# transformation:
#
# .. math::
#
#    \textbf{h}_{G} =\sum_{v\in V}\text{Sigmoid}(g_m(\textbf{h}_{v}))f_{m}(\textbf{h}_{v}),\\
#
# The first term, :math:`\text{Sigmoid}(g_m(\textbf{h}_{v}))`, computes a
# gating function and can be thought of as how much the overall graph embedding
# attends on each node. The second term, :math:`f_{m}:\mathbb{R}^{n}\rightarrow\mathbb{R}^{k}`,
# maps the node embeddings to the space of graph embeddings.
#
# Implement graph embedding as a ``GraphEmbed`` class.
#

import torch


class GraphEmbed(nn.Module):
    def __init__(self, node_hidden_size):
        super(GraphEmbed, self).__init__()

        # Setting from the paper
        self.graph_hidden_size = 2 * node_hidden_size

        # Embed graphs
        self.node_gating = nn.Sequential(
            nn.Linear(node_hidden_size, 1), nn.Sigmoid()
        )
        self.node_to_graph = nn.Linear(node_hidden_size, self.graph_hidden_size)

    def forward(self, g):
        if g.number_of_nodes() == 0:
            return torch.zeros(1, self.graph_hidden_size)
        else:
            # Node features are stored as hv in ndata.
            hvs = g.ndata["hv"]
            return (self.node_gating(hvs) * self.node_to_graph(hvs)).sum(
                0, keepdim=True
            )
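#######################################################################################
# A quick, illustrative shape check (the hidden size here is an arbitrary toy
# value): with ``node_hidden_size=16`` the graph embedding lives in
# :math:`\mathbb{R}^{32}`, and an empty graph maps to a zero vector.

demo_embed = GraphEmbed(node_hidden_size=16)
print(demo_embed(dgl.DGLGraph()).shape)  # torch.Size([1, 32])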

#######################################################################################
# Update node embeddings via graph propagation
# '''''''''''''''''''''''''''''''''''''''''''''
#
# The mechanism of updating node embeddings in DGMG is similar to that for
# graph convolutional networks. For a node :math:`v` in the graph, its
# neighbor :math:`u` sends a message to it with
#
# .. math::
#
#    \textbf{m}_{u\rightarrow v}=\textbf{W}_{m}\text{concat}([\textbf{h}_{v}, \textbf{h}_{u}, \textbf{x}_{u, v}]) + \textbf{b}_{m},\\
#
# where :math:`\textbf{x}_{u,v}` is the embedding of the edge between
# :math:`u` and :math:`v`.
#
# After receiving messages from all its neighbors, :math:`v` summarizes them
# with a node activation vector
#
# .. math::
#
#    \textbf{a}_{v} = \sum_{u: (u, v)\in E}\textbf{m}_{u\rightarrow v}\\
#
# and uses this information to update its own feature:
#
# .. math::
#
#    \textbf{h}'_{v} = \textbf{GRU}(\textbf{h}_{v}, \textbf{a}_{v}).\\
#
# Performing all the operations above once for all nodes synchronously is
# called one round of graph propagation. The more rounds of graph propagation
# you perform, the longer the distance messages travel throughout the graph.
#
# With DGL, you implement graph propagation with ``g.update_all``.
# The message notation here can be a bit confusing. Researchers can refer
# to :math:`\textbf{m}_{u\rightarrow v}` as messages; however, the message function
# below only passes :math:`\text{concat}([\textbf{h}_{u}, \textbf{x}_{u, v}])`.
# The operation :math:`\textbf{W}_{m}\text{concat}([\textbf{h}_{v}, \textbf{h}_{u}, \textbf{x}_{u, v}]) + \textbf{b}_{m}`
# is then performed across all edges at once for efficiency considerations.
#

from functools import partial


class GraphProp(nn.Module):
    def __init__(self, num_prop_rounds, node_hidden_size):
        super(GraphProp, self).__init__()

        self.num_prop_rounds = num_prop_rounds

        # Setting from the paper
        self.node_activation_hidden_size = 2 * node_hidden_size

        message_funcs = []
        node_update_funcs = []
        self.reduce_funcs = []

        for t in range(num_prop_rounds):
            # input being [hv, hu, xuv]
            message_funcs.append(
                nn.Linear(
                    2 * node_hidden_size + 1, self.node_activation_hidden_size
                )
            )

            self.reduce_funcs.append(partial(self.dgmg_reduce, round=t))
            node_update_funcs.append(
                nn.GRUCell(self.node_activation_hidden_size, node_hidden_size)
            )

        self.message_funcs = nn.ModuleList(message_funcs)
        self.node_update_funcs = nn.ModuleList(node_update_funcs)

    def dgmg_msg(self, edges):
        """For an edge u->v, return concat([h_u, x_uv])"""
        return {"m": torch.cat([edges.src["hv"], edges.data["he"]], dim=1)}

    def dgmg_reduce(self, nodes, round):
        hv_old = nodes.data["hv"]
        m = nodes.mailbox["m"]
        message = torch.cat(
            [hv_old.unsqueeze(1).expand(-1, m.size(1), -1), m], dim=2
        )
        node_activation = (self.message_funcs[round](message)).sum(1)

        return {"a": node_activation}

    def forward(self, g):
        if g.number_of_edges() > 0:
            for t in range(self.num_prop_rounds):
                g.update_all(
                    message_func=self.dgmg_msg, reduce_func=self.reduce_funcs[t]
                )
                g.ndata["hv"] = self.node_update_funcs[t](
                    g.ndata["a"], g.ndata["hv"]
                )

#######################################################################################
# Actions
# ``````````````````````````
# All actions are sampled from distributions parameterized using neural networks,
# and here they are described in turn.
#
# Action 1: Add nodes
# ''''''''''''''''''''''''''
#
# Given the graph embedding vector :math:`\textbf{h}_{G}`, evaluate
#
# .. math::
#
#    \text{Sigmoid}(\textbf{W}_{\text{add node}}\textbf{h}_{G}+b_{\text{add node}}),\\
#
# which is then used to parametrize a Bernoulli distribution for deciding whether
# to add a new node.
#
# If a new node is to be added, initialize its feature with
#
# .. math::
#
#    \textbf{W}_{\text{init}}\text{concat}([\textbf{h}_{\text{init}} , \textbf{h}_{G}])+\textbf{b}_{\text{init}},\\
#
# where :math:`\textbf{h}_{\text{init}}` is a learnable embedding module for
# untyped nodes.
#

import torch.nn.functional as F
from torch.distributions import Bernoulli


def bernoulli_action_log_prob(logit, action):
    """Calculate the log p of an action with respect to a Bernoulli
    distribution. Use logit rather than prob for numerical stability."""
    if action == 0:
        return F.logsigmoid(-logit)
    else:
        return F.logsigmoid(logit)
class AddNode(nn.Module):
def __init__(self, graph_embed_func, node_hidden_size):
super(AddNode, self).__init__() def bernoulli_action_log_prob(logit, action):
"""Calculate the log p of an action with respect to a Bernoulli
self.graph_op = {'embed': graph_embed_func} distribution. Use logit rather than prob for numerical stability."""
if action == 0:
self.stop = 1 return F.logsigmoid(-logit)
self.add_node = nn.Linear(graph_embed_func.graph_hidden_size, 1) else:
return F.logsigmoid(logit)
# If to add a node, initialize its hv
self.node_type_embed = nn.Embedding(1, node_hidden_size)
self.initialize_hv = nn.Linear(node_hidden_size + \ class AddNode(nn.Module):
graph_embed_func.graph_hidden_size, def __init__(self, graph_embed_func, node_hidden_size):
node_hidden_size) super(AddNode, self).__init__()
self.init_node_activation = torch.zeros(1, 2 * node_hidden_size) self.graph_op = {"embed": graph_embed_func}
def _initialize_node_repr(self, g, node_type, graph_embed): self.stop = 1
"""Whenver a node is added, initialize its representation.""" self.add_node = nn.Linear(graph_embed_func.graph_hidden_size, 1)
num_nodes = g.number_of_nodes()
hv_init = self.initialize_hv( # If to add a node, initialize its hv
torch.cat([ self.node_type_embed = nn.Embedding(1, node_hidden_size)
self.node_type_embed(torch.LongTensor([node_type])), self.initialize_hv = nn.Linear(
graph_embed], dim=1)) node_hidden_size + graph_embed_func.graph_hidden_size,
g.nodes[num_nodes - 1].data['hv'] = hv_init node_hidden_size,
g.nodes[num_nodes - 1].data['a'] = self.init_node_activation )
def prepare_training(self): self.init_node_activation = torch.zeros(1, 2 * node_hidden_size)
self.log_prob = []
def _initialize_node_repr(self, g, node_type, graph_embed):
def forward(self, g, action=None): """Whenver a node is added, initialize its representation."""
graph_embed = self.graph_op['embed'](g) num_nodes = g.number_of_nodes()
hv_init = self.initialize_hv(
logit = self.add_node(graph_embed) torch.cat(
prob = torch.sigmoid(logit) [
self.node_type_embed(torch.LongTensor([node_type])),
if not self.training: graph_embed,
action = Bernoulli(prob).sample().item() ],
stop = bool(action == self.stop) dim=1,
)
if not stop: )
g.add_nodes(1) g.nodes[num_nodes - 1].data["hv"] = hv_init
self._initialize_node_repr(g, action, graph_embed) g.nodes[num_nodes - 1].data["a"] = self.init_node_activation
if self.training: def prepare_training(self):
sample_log_prob = bernoulli_action_log_prob(logit, action) self.log_prob = []
self.log_prob.append(sample_log_prob) def forward(self, g, action=None):
graph_embed = self.graph_op["embed"](g)
return stop
logit = self.add_node(graph_embed)
#######################################################################################
# Action 2: Add edges
# ''''''''''''''''''''''''''
#
# Given the graph embedding vector :math:`\textbf{h}_{G}` and the node
# embedding vector :math:`\textbf{h}_{v}` for the latest node :math:`v`,
# you evaluate
#
# .. math::
#
#    \text{Sigmoid}(\textbf{W}_{\text{add edge}}\text{concat}([\textbf{h}_{G}, \textbf{h}_{v}])+b_{\text{add edge}}),\\
#
# which is then used to parametrize a Bernoulli distribution for deciding
# whether to add a new edge starting from :math:`v`.
#


class AddEdge(nn.Module):
    def __init__(self, graph_embed_func, node_hidden_size):
        super(AddEdge, self).__init__()

        self.graph_op = {"embed": graph_embed_func}
        self.add_edge = nn.Linear(
            graph_embed_func.graph_hidden_size + node_hidden_size, 1
        )

    def prepare_training(self):
        self.log_prob = []

    def forward(self, g, action=None):
        graph_embed = self.graph_op["embed"](g)
        src_embed = g.nodes[g.number_of_nodes() - 1].data["hv"]

        logit = self.add_edge(torch.cat([graph_embed, src_embed], dim=1))
        prob = torch.sigmoid(logit)

        if self.training:
            sample_log_prob = bernoulli_action_log_prob(logit, action)
            self.log_prob.append(sample_log_prob)
        else:
            action = Bernoulli(prob).sample().item()

        to_add_edge = bool(action == 0)
        return to_add_edge
#######################################################################################
# Action 3: Choose a destination
# '''''''''''''''''''''''''''''''''
#
# When action 2 returns `True`, choose a destination for the
# latest node :math:`v`.
#
# For each possible destination :math:`u\in\{0, \cdots, v-1\}`, the
# probability of choosing it is given by
#
# .. math::
#
#    \frac{\text{exp}(\textbf{W}_{\text{dest}}\text{concat}([\textbf{h}_{u}, \textbf{h}_{v}])+\textbf{b}_{\text{dest}})}{\sum_{i=0}^{v-1}\text{exp}(\textbf{W}_{\text{dest}}\text{concat}([\textbf{h}_{i}, \textbf{h}_{v}])+\textbf{b}_{\text{dest}})}\\
#

from torch.distributions import Categorical


class ChooseDestAndUpdate(nn.Module):
    def __init__(self, graph_prop_func, node_hidden_size):
        super(ChooseDestAndUpdate, self).__init__()

        self.graph_op = {"prop": graph_prop_func}
        self.choose_dest = nn.Linear(2 * node_hidden_size, 1)

    def _initialize_edge_repr(self, g, src_list, dest_list):
        # For untyped edges, only add 1 to indicate its existence.
        # For multiple edge types, use a one-hot representation
        # or an embedding module.
        edge_repr = torch.ones(len(src_list), 1)
        g.edges[src_list, dest_list].data["he"] = edge_repr

    def prepare_training(self):
        self.log_prob = []

    def forward(self, g, dest):
        src = g.number_of_nodes() - 1
        possible_dests = range(src)

        src_embed_expand = g.nodes[src].data["hv"].expand(src, -1)
        possible_dests_embed = g.nodes[possible_dests].data["hv"]

        dests_scores = self.choose_dest(
            torch.cat([possible_dests_embed, src_embed_expand], dim=1)
        ).view(1, -1)
        dests_probs = F.softmax(dests_scores, dim=1)

        if not self.training:
            dest = Categorical(dests_probs).sample().item()

        if not g.has_edges_between(src, dest):
            # For undirected graphs, add edges for both directions
            # so that you can perform graph propagation.
            src_list = [src, dest]
            dest_list = [dest, src]

            g.add_edges(src_list, dest_list)
            self._initialize_edge_repr(g, src_list, dest_list)

            self.graph_op["prop"](g)

        if self.training:
            if dests_probs.nelement() > 1:
                self.log_prob.append(
                    F.log_softmax(dests_scores, dim=1)[:, dest : dest + 1]
                )
#######################################################################################
# Putting it together
# ``````````````````````````
#
# You are now ready to write a complete implementation of the model class.
#


class DGMG(DGMGSkeleton):
    def __init__(self, v_max, node_hidden_size, num_prop_rounds):
        super(DGMG, self).__init__(v_max)

        # Graph embedding module
        self.graph_embed = GraphEmbed(node_hidden_size)

        # Graph propagation module
        self.graph_prop = GraphProp(num_prop_rounds, node_hidden_size)

        # Actions
        self.add_node_agent = AddNode(self.graph_embed, node_hidden_size)
        self.add_edge_agent = AddEdge(self.graph_embed, node_hidden_size)
        self.choose_dest_agent = ChooseDestAndUpdate(
            self.graph_prop, node_hidden_size
        )

        # Forward functions
        self.forward_train = partial(forward_train, self=self)
        self.forward_inference = partial(forward_inference, self=self)

    @property
    def action_step(self):
        old_step_count = self.step_count
        self.step_count += 1

        return old_step_count

    def prepare_for_train(self):
        self.step_count = 0

        self.add_node_agent.prepare_training()
        self.add_edge_agent.prepare_training()
        self.choose_dest_agent.prepare_training()

    def add_node_and_update(self, a=None):
        """Decide whether to add a new node.
        If a new node should be added, update the graph."""

        return self.add_node_agent(self.g, a)

    def add_edge_or_not(self, a=None):
        """Decide whether a new edge should be added."""

        return self.add_edge_agent(self.g, a)

    def choose_dest_and_update(self, a=None):
        """Choose a destination and connect it to the latest node.
        Add edges for both directions and update the graph."""

        self.choose_dest_agent(self.g, a)

    def get_log_prob(self):
        add_node_log_p = torch.cat(self.add_node_agent.log_prob).sum()
        add_edge_log_p = torch.cat(self.add_edge_agent.log_prob).sum()
        choose_dest_log_p = torch.cat(self.choose_dest_agent.log_prob).sum()
        return add_node_log_p + add_edge_log_p + choose_dest_log_p
#######################################################################################
# Below is an animation where a graph is generated on the fly
# after every 10 batches of training for the first 400 batches. You
# can see how the model improves over time and begins generating cycles.
#
# .. figure:: https://user-images.githubusercontent.com/19576924/48929291-60fe3880-ef22-11e8-832a-fbe56656559a.gif
#    :alt:
#
# For generative models, you can evaluate performance by checking the percentage
# of valid graphs among the graphs it generates on the fly.

import torch.utils.model_zoo as model_zoo

# Download a pre-trained model state dict for generating cycles with 10-20 nodes.
state_dict = model_zoo.load_url(
    "https://data.dgl.ai/model/dgmg_cycles-5a0c40be.pth"
)
model = DGMG(v_max=20, node_hidden_size=16, num_prop_rounds=2)
model.load_state_dict(state_dict)
model.eval()


def is_valid(g):
    # Check if g is a cycle having 10-20 nodes.
    def _get_previous(i, v_max):
        if i == 0:
            return v_max
        else:
            return i - 1

    def _get_next(i, v_max):
        if i == v_max:
            return 0
        else:
            return i + 1

    size = g.number_of_nodes()

    if size < 10 or size > 20:
        return False

    for node in range(size):
        neighbors = g.successors(node)

        if len(neighbors) != 2:
            return False

        if _get_previous(node, size - 1) not in neighbors:
            return False

        if _get_next(node, size - 1) not in neighbors:
            return False

    return True


num_valid = 0
for i in range(100):
    g = model()
    num_valid += is_valid(g)

del model
print("Among 100 graphs generated, {}% are valid.".format(num_valid))

#######################################################################################
# For the complete implementation, see the `DGL DGMG example
# <https://github.com/dmlc/dgl/tree/master/examples/pytorch/dgmg>`__.
#
...@@ -17,276 +17,275 @@ offers a different perspective. The tutorial describes how to implement a Capsul
efficiency. For recommended implementation, please refer to the `official
examples <https://github.com/dmlc/dgl/tree/master/examples>`_.
"""
#######################################################################################
# Key ideas of Capsule
# --------------------
#
# The Capsule model offers two key ideas: richer representation and dynamic routing.
#
# **Richer representation** -- In classic convolutional networks, a scalar
# value represents the activation of a given feature. By contrast, a
# capsule outputs a vector. The vector's length represents the probability
# of a feature being present. The vector's orientation represents the
# various properties of the feature (such as pose, deformation, texture
# etc.).
#
# |image0|
#
# **Dynamic routing** -- The output of a capsule is sent to
# certain parents in the layer above based on how well the capsule's
# prediction agrees with that of a parent. Such dynamic
# routing-by-agreement generalizes the static routing of max-pooling.
#
# During training, routing is accomplished iteratively. Each iteration adjusts
# routing weights between capsules based on their observed agreements,
# in a manner similar to a k-means algorithm or `competitive
# learning <https://en.wikipedia.org/wiki/Competitive_learning>`__.
#
# In this tutorial, you see how a capsule's dynamic routing algorithm can be
# naturally expressed as a graph algorithm. The implementation is adapted
# from `Cedric
# Chee <https://github.com/cedrickchee/capsule-net-pytorch>`__, replacing
# only the routing layer. This version achieves similar speed and accuracy.
#
# Model implementation
# ----------------------
# Step 1: Setup and graph initialization
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# The connectivity between two layers of capsules forms a directed,
# bipartite graph, as shown in the Figure below.
#
# |image1|
#
# Each node :math:`j` is associated with feature :math:`v_j`,
# representing its capsule’s output. Each edge is associated with
# features :math:`b_{ij}` and :math:`\hat{u}_{j|i}`. :math:`b_{ij}`
# determines routing weights, and :math:`\hat{u}_{j|i}` represents the
# prediction of capsule :math:`i` for :math:`j`.
#
# Here's how we set up the graph and initialize node and edge features.
import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import matplotlib.pyplot as plt
import numpy as np
import torch as th
import torch.nn as nn
import torch.nn.functional as F


def init_graph(in_nodes, out_nodes, f_size):
    u = np.repeat(np.arange(in_nodes), out_nodes)
    v = np.tile(np.arange(in_nodes, in_nodes + out_nodes), in_nodes)
    g = dgl.DGLGraph((u, v))
    # init states
    g.ndata["v"] = th.zeros(in_nodes + out_nodes, f_size)
    g.edata["b"] = th.zeros(in_nodes * out_nodes, 1)
    return g
#########################################################################################
# Step 2: Define message passing functions
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# This is the pseudocode for Capsule's routing algorithm.
#
# |image2|
# Implement pseudocode lines 4-7 in the class `DGLRoutingLayer` as the following steps:
#
# 1. Calculate coupling coefficients.
#
#    - Coefficients are the softmax over all out-edges of in-capsules.
#      :math:`\textbf{c}_{i,j} = \text{softmax}(\textbf{b}_{i,j})`.
#
# 2. Calculate the weighted sum over all in-capsules.
#
#    - The output of a capsule is equal to the weighted sum of its in-capsules
#      :math:`s_j=\sum_i c_{ij}\hat{u}_{j|i}`
#
# 3. Squash outputs.
#
#    - Squash the length of a capsule's output vector into the range (0, 1), so it can represent the probability (of some feature being present).
#    - :math:`v_j=\text{squash}(s_j)=\frac{||s_j||^2}{1+||s_j||^2}\frac{s_j}{||s_j||}`
#
# 4. Update weights by the amount of agreement.
#
#    - The scalar product :math:`\hat{u}_{j|i}\cdot v_j` can be considered as how well capsule :math:`i` agrees with :math:`j`. It is used to update
#      :math:`b_{ij}=b_{ij}+\hat{u}_{j|i}\cdot v_j`
import dgl.function as fn


class DGLRoutingLayer(nn.Module):
    def __init__(self, in_nodes, out_nodes, f_size):
        super(DGLRoutingLayer, self).__init__()
        self.g = init_graph(in_nodes, out_nodes, f_size)
        self.in_nodes = in_nodes
        self.out_nodes = out_nodes
        self.in_indx = list(range(in_nodes))
        self.out_indx = list(range(in_nodes, in_nodes + out_nodes))

    def forward(self, u_hat, routing_num=1):
        self.g.edata["u_hat"] = u_hat
        for r in range(routing_num):
            # step 1 (line 4): normalize over out edges
            edges_b = self.g.edata["b"].view(self.in_nodes, self.out_nodes)
            self.g.edata["c"] = F.softmax(edges_b, dim=1).view(-1, 1)
            self.g.edata["c u_hat"] = self.g.edata["c"] * self.g.edata["u_hat"]

            # Execute step 1 & 2
            self.g.update_all(fn.copy_e("c u_hat", "m"), fn.sum("m", "s"))

            # step 3 (line 6)
            self.g.nodes[self.out_indx].data["v"] = self.squash(
                self.g.nodes[self.out_indx].data["s"], dim=1
            )

            # step 4 (line 7)
            v = th.cat(
                [self.g.nodes[self.out_indx].data["v"]] * self.in_nodes, dim=0
            )
            self.g.edata["b"] = self.g.edata["b"] + (
                self.g.edata["u_hat"] * v
            ).sum(dim=1, keepdim=True)

    @staticmethod
    def squash(s, dim=1):
        sq = th.sum(s**2, dim=dim, keepdim=True)
        s_norm = th.sqrt(sq)
        s = (sq / (1.0 + sq)) * (s / s_norm)
        return s
############################################################################################################
# Step 3: Testing
# ~~~~~~~~~~~~~~~
#
# Make a simple 20x10 capsule layer.
in_nodes = 20
out_nodes = 10
f_size = 4
u_hat = th.randn(in_nodes * out_nodes, f_size)
routing = DGLRoutingLayer(in_nodes, out_nodes, f_size)

############################################################################################################
# You can visualize a Capsule network's behavior by monitoring the entropy
# of coupling coefficients. They should start high and then drop, as the
# weights gradually concentrate on fewer edges.
entropy_list = []
dist_list = []

for i in range(10):
    routing(u_hat)
    dist_matrix = routing.g.edata["c"].view(in_nodes, out_nodes)
    entropy = (-dist_matrix * th.log(dist_matrix)).sum(dim=1)
    entropy_list.append(entropy.data.numpy())
    dist_list.append(dist_matrix.data.numpy())

stds = np.std(entropy_list, axis=1)
means = np.mean(entropy_list, axis=1)
plt.errorbar(np.arange(len(entropy_list)), means, stds, marker="o")
plt.ylabel("Entropy of Weight Distribution")
plt.xlabel("Number of Routing")
plt.xticks(np.arange(len(entropy_list)))
plt.close()

############################################################################################################
# |image3|
#
# Alternatively, we can also watch the evolution of histograms.

import matplotlib.animation as animation
import seaborn as sns

fig = plt.figure(dpi=150)
fig.clf()
ax = fig.subplots()


def dist_animate(i):
    ax.cla()
    sns.distplot(dist_list[i].reshape(-1), kde=False, ax=ax)
    ax.set_xlabel("Weight Distribution Histogram")
    ax.set_title("Routing: %d" % (i))


ani = animation.FuncAnimation(
    fig, dist_animate, frames=len(entropy_list), interval=500
)
plt.close()

############################################################################################################
# |image4|
#
# You can monitor how lower-level capsules gradually attach to one of the
# higher-level ones.
import networkx as nx
from networkx.algorithms import bipartite

g = routing.g.to_networkx()
X, Y = bipartite.sets(g)
height_in = 10
height_out = height_in * 0.8
height_in_y = np.linspace(0, height_in, in_nodes)
height_out_y = np.linspace((height_in - height_out) / 2, height_out, out_nodes)
pos = dict()

fig2 = plt.figure(figsize=(8, 3), dpi=150)
fig2.clf()
ax = fig2.subplots()
pos.update(
    (n, (i, 1)) for i, n in zip(height_in_y, X)
)  # put nodes from X at x=1
pos.update(
    (n, (i, 2)) for i, n in zip(height_out_y, Y)
)  # put nodes from Y at x=2


def weight_animate(i):
    ax.cla()
    ax.axis("off")
    ax.set_title("Routing: %d " % i)
    dm = dist_list[i]
    nx.draw_networkx_nodes(
        g, pos, nodelist=range(in_nodes), node_color="r", node_size=100, ax=ax
    )
    nx.draw_networkx_nodes(
        g,
        pos,
        nodelist=range(in_nodes, in_nodes + out_nodes),
        node_color="b",
        node_size=100,
        ax=ax,
    )
    for edge in g.edges():
        nx.draw_networkx_edges(
            g,
            pos,
            edgelist=[edge],
            width=dm[edge[0], edge[1] - in_nodes] * 1.5,
            ax=ax,
        )


ani2 = animation.FuncAnimation(
    fig2, weight_animate, frames=len(dist_list), interval=500
)
plt.close()

############################################################################################################
# |image5|
#
# The full code of this visualization is provided on
# `GitHub <https://github.com/dmlc/dgl/blob/master/examples/pytorch/capsule/simple_routing.py>`__. The complete
# code that trains on MNIST is also on `GitHub <https://github.com/dmlc/dgl/tree/tutorial/examples/pytorch/capsule>`__.
#
# .. |image0| image:: https://i.imgur.com/55Ovkdh.png
# .. |image1| image:: https://i.imgur.com/9tc6GLl.png
# .. |image2| image:: https://i.imgur.com/mv1W9Rv.png
# .. |image3| image:: https://i.imgur.com/dMvu7p3.png
# .. |image4| image:: https://github.com/VoVAllen/DGL_Capsule/raw/master/routing_dist.gif
# .. |image5| image:: https://github.com/VoVAllen/DGL_Capsule/raw/master/routing_vis.gif
...@@ -104,7 +104,7 @@ Transformer as a Graph Neural Network
# - ``get_o`` maps the updated value after attention to the output
#   :math:`o` for post-processing.
#
# .. code::
#
#    class MultiHeadAttention(nn.Module):
#        "Multi-Head Attention"
...@@ -146,14 +146,14 @@ Transformer as a Graph Neural Network
#
# Construct the graph by mapping tokens of the source and target
# sentence to nodes. The complete Transformer graph is made up of three
# subgraphs:
#
# **Source language graph**. This is a complete graph; each
# token :math:`s_i` can attend to any other token :math:`s_j` (including
# self-loops). |image0|
# **Target language graph**. The graph is
# half-complete, in that :math:`t_i` attends only to :math:`t_j` if
# :math:`i > j` (an output token cannot depend on future words). |image1|
# **Cross-language graph**. This is a bipartite graph, where there is
# an edge from every source token :math:`s_i` to every target token
# :math:`t_j`, meaning every target token can attend to source tokens.
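#######################################################################################
# To make the construction above concrete, here is a minimal sketch (not part of
# the diff, and not the tutorial's own builder) of how the three subgraphs could
# be assembled into a single ``dgl.graph``. The helper name
# ``build_transformer_graph`` and the toy sentence lengths are illustrative
# assumptions.

import itertools

import dgl
import torch as th


def build_transformer_graph(src_len, tgt_len):
    # Nodes 0..src_len-1 are source tokens; src_len..src_len+tgt_len-1 are target tokens.
    src_ids = list(range(src_len))
    tgt_ids = list(range(src_len, src_len + tgt_len))

    # Source language graph: complete, including self-loops.
    src_edges = [(i, j) for i in src_ids for j in src_ids]
    # Target language graph: an edge t_j -> t_i only when j < i, so an output
    # token never attends to future words.
    tgt_edges = [(j, i) for i in tgt_ids for j in tgt_ids if j < i]
    # Cross-language graph: an edge from every source token to every target token.
    cross_edges = [(i, j) for i in src_ids for j in tgt_ids]

    u, v = zip(*itertools.chain(src_edges, tgt_edges, cross_edges))
    return dgl.graph((th.tensor(u), th.tensor(v)), num_nodes=src_len + tgt_len)


g_attn = build_transformer_graph(src_len=4, tgt_len=3)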
...@@ -191,7 +191,7 @@ Transformer as a Graph Neural Network
#
# Compute ``score`` and send source node’s ``v`` to destination’s mailbox
#
# .. code::
#
#    def message_func(edges):
#        return {'score': ((edges.src['k'] * edges.dst['q'])
...@@ -203,7 +203,7 @@ Transformer as a Graph Neural Network
#
# Normalize over all in-edges and take the weighted sum to get the output
#
# .. code::
#
#    import torch as th
#    import torch.nn.functional as F
...@@ -216,7 +216,7 @@ Transformer as a Graph Neural Network
# Execute on specific edges
# '''''''''''''''''''''''''
#
# .. code::
#
#    from functools import partial
#    def naive_propagate_attention(self, g, eids):
...@@ -269,7 +269,7 @@ Transformer as a Graph Neural Network
#
# The normalization of :math:`\textrm{wv}` is left to post-processing.
#
# .. code::
#
#    def src_dot_dst(src_field, dst_field, out_field):
#        def func(edges):
...@@ -338,7 +338,7 @@ Transformer as a Graph Neural Network
#
# where :math:`\textrm{FFN}` refers to the feed-forward function.
#
# .. code::
#
#    class Encoder(nn.Module):
#        def __init__(self, layer, N):
...@@ -501,7 +501,7 @@ Transformer as a Graph Neural Network
# Task and the dataset
# ~~~~~~~~~~~~~~~~~~~~
#
# The Transformer is a general framework for a variety of NLP tasks. This tutorial focuses
# on sequence-to-sequence learning: a typical case that illustrates how it works.
#
# As for the dataset, there are two example tasks: copy and sort, together
...@@ -729,7 +729,7 @@ Transformer as a Graph Neural Network
# with these nodes. The following code shows the Universal Transformer
# class in DGL:
#
# .. code::
#
#    class UTransformer(nn.Module):
#        "Universal Transformer(https://arxiv.org/pdf/1807.03819.pdf) with ACT(https://arxiv.org/pdf/1603.08983.pdf)."
...@@ -849,10 +849,10 @@ Transformer as a Graph Neural Network
# that are still active:
#
# .. note::
#
#    - :func:`~dgl.DGLGraph.filter_nodes` takes a predicate and a node
#      ID list/tensor as input, then returns a tensor of node IDs that satisfy
#      the given predicate.
#    - :func:`~dgl.DGLGraph.filter_edges` takes a predicate
#      and an edge ID list/tensor as input, then returns a tensor of edge IDs
#      that satisfy the given predicate (see the sketch below).
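#######################################################################################
# A minimal sketch (not from the diff above) of the predicate pattern that
# ``filter_nodes`` expects. The toy graph, the ``halt`` flag, and the
# ``still_active`` helper are illustrative assumptions, not the tutorial's
# actual ACT bookkeeping.

import dgl
import torch as th

# Toy graph with a per-node halting flag.
g_act = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 3])))
g_act.ndata["halt"] = th.tensor([False, False, True, True])


def still_active(nodes):
    # Keep the nodes whose halting flag has not been set yet.
    return ~nodes.data["halt"]


active_nodes = g_act.filter_nodes(still_active)  # tensor([0, 1])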
...@@ -883,6 +883,6 @@ Transformer as a Graph Neural Network
#
# .. note::
#    The notebook itself is not executable due to many dependencies.
#    Download `7_transformer.py <https://data.dgl.ai/tutorial/7_transformer.py>`__,
#    copy the python script to the directory ``examples/pytorch/transformer``,
#    and then run ``python 7_transformer.py`` to see how it works.
...@@ -71,7 +71,8 @@ process ID, which should be an integer from `0` to `world_size - 1`.
"""
import os

os.environ["DGLBACKEND"] = "pytorch"
import torch.distributed as dist
...
...@@ -26,63 +26,65 @@ models with multi-GPU with ``DistributedDataParallel``.
######################################################################
# Loading Dataset
# ---------------
#
# OGB already prepared the data as a ``DGLGraph`` object. The following code is
# copy-pasted from the :doc:`Training GNN with Neighbor Sampling for Node
# Classification <../large/L1_large_node_classification>`
# tutorial.
#
import os

os.environ["DGLBACKEND"] = "pytorch"
import dgl
import numpy as np
import sklearn.metrics
import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from dgl.nn import SAGEConv
from ogb.nodeproppred import DglNodePropPredDataset

dataset = DglNodePropPredDataset("ogbn-arxiv")

graph, node_labels = dataset[0]
# Add reverse edges since ogbn-arxiv is unidirectional.
graph = dgl.add_reverse_edges(graph)
graph.ndata["label"] = node_labels[:, 0]

node_features = graph.ndata["feat"]
num_features = node_features.shape[1]
num_classes = (node_labels.max() + 1).item()

idx_split = dataset.get_idx_split()
train_nids = idx_split["train"]
valid_nids = idx_split["valid"]
test_nids = idx_split["test"]  # Test node IDs, not used in the tutorial though.
######################################################################
# Defining Model
# --------------
#
# The model will again be identical to the one in the :doc:`Training GNN with Neighbor
# Sampling for Node Classification <../large/L1_large_node_classification>`
# tutorial.
#


class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type="mean")
        self.conv2 = SAGEConv(h_feats, num_classes, aggregator_type="mean")
        self.h_feats = h_feats

    def forward(self, mfgs, x):
        h_dst = x[: mfgs[0].num_dst_nodes()]
        h = self.conv1(mfgs[0], (x, h_dst))
        h = F.relu(h)
        h_dst = h[: mfgs[1].num_dst_nodes()]
        h = self.conv2(mfgs[1], (h, h_dst))
        return h
...@@ -90,7 +92,7 @@ class Model(nn.Module):
######################################################################
# Defining Training Procedure
# ---------------------------
#
# The training procedure will be slightly different from what you saw
# previously, in the sense that you will need to
#
...@@ -98,45 +100,58 @@ class Model(nn.Module):
# * Wrap your model with ``torch.nn.parallel.DistributedDataParallel``.
# * Add a ``use_ddp=True`` argument to the DGL dataloader you wish to run
#   together with DDP.
#
# You will also need to wrap the training loop inside a function so that
# you can spawn subprocesses to run it.
#


def run(proc_id, devices):
    # Initialize distributed training context.
    dev_id = devices[proc_id]
    dist_init_method = "tcp://{master_ip}:{master_port}".format(
        master_ip="127.0.0.1", master_port="12345"
    )
    if torch.cuda.device_count() < 1:
        device = torch.device("cpu")
        torch.distributed.init_process_group(
            backend="gloo",
            init_method=dist_init_method,
            world_size=len(devices),
            rank=proc_id,
        )
    else:
        torch.cuda.set_device(dev_id)
        device = torch.device("cuda:" + str(dev_id))
        torch.distributed.init_process_group(
            backend="nccl",
            init_method=dist_init_method,
            world_size=len(devices),
            rank=proc_id,
        )

    # Define training and validation dataloader, copied from the previous tutorial
    # but with one line of difference: use_ddp to enable distributed data parallel
    # data loading.
    sampler = dgl.dataloading.NeighborSampler([4, 4])
    train_dataloader = dgl.dataloading.DataLoader(
        # The following arguments are specific to DataLoader.
        graph,  # The graph
        train_nids,  # The node IDs to iterate over in minibatches
        sampler,  # The neighbor sampler
        device=device,  # Put the sampled MFGs on CPU or GPU
        use_ddp=True,  # Make it work with distributed data parallel
        # The following arguments are inherited from PyTorch DataLoader.
        batch_size=1024,  # Per-device batch size.
        # The effective batch size is this number times the number of GPUs.
        shuffle=True,  # Whether to shuffle the nodes for every epoch
        drop_last=False,  # Whether to drop the last incomplete batch
        num_workers=0,  # Number of sampler processes
    )
    valid_dataloader = dgl.dataloading.DataLoader(
        graph,
        valid_nids,
        sampler,
        device=device,
        use_ddp=False,
        batch_size=1024,
...@@ -144,20 +159,24 @@ def run(proc_id, devices):
        drop_last=False,
        num_workers=0,
    )

    model = Model(num_features, 128, num_classes).to(device)
    # Wrap the model with distributed data parallel module.
    if device == torch.device("cpu"):
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=None, output_device=None
        )
    else:
        model = torch.nn.parallel.DistributedDataParallel(
            model, device_ids=[device], output_device=device
        )

    # Define optimizer
    opt = torch.optim.Adam(model.parameters())

    best_accuracy = 0
    best_model_path = "./model.pt"

    # Copied from previous tutorial with changes highlighted.
    for epoch in range(10):
        model.train()
...@@ -165,8 +184,8 @@ def run(proc_id, devices):
        with tqdm.tqdm(train_dataloader) as tq:
            for step, (input_nodes, output_nodes, mfgs) in enumerate(tq):
                # feature copy from CPU to GPU takes place here
                inputs = mfgs[0].srcdata["feat"]
                labels = mfgs[-1].dstdata["label"]

                predictions = model(mfgs, inputs)
...@@ -175,9 +194,15 @@ def run(proc_id, devices):
                loss.backward()
                opt.step()

                accuracy = sklearn.metrics.accuracy_score(
                    labels.cpu().numpy(),
                    predictions.argmax(1).detach().cpu().numpy(),
                )

                tq.set_postfix(
                    {"loss": "%.03f" % loss.item(), "acc": "%.03f" % accuracy},
                    refresh=False,
                )

        model.eval()
...@@ -187,13 +212,15 @@ def run(proc_id, devices):
        labels = []
        with tqdm.tqdm(valid_dataloader) as tq, torch.no_grad():
            for input_nodes, output_nodes, mfgs in tq:
                inputs = mfgs[0].srcdata["feat"]
                labels.append(mfgs[-1].dstdata["label"].cpu().numpy())
                predictions.append(
                    model(mfgs, inputs).argmax(1).cpu().numpy()
                )
            predictions = np.concatenate(predictions)
            labels = np.concatenate(labels)
            accuracy = sklearn.metrics.accuracy_score(labels, predictions)
            print("Epoch {} Validation Accuracy {}".format(epoch, accuracy))
            if best_accuracy < accuracy:
                best_accuracy = accuracy
                torch.save(model.state_dict(), best_model_path)
...@@ -205,7 +232,7 @@ def run(proc_id, devices):
######################################################################
# Spawning Trainer Processes
# --------------------------
#
# A typical scenario for multi-GPU training with DDP is to replicate the
# model once per GPU, and spawn one trainer process per GPU.
#
...@@ -219,15 +246,15 @@ def run(proc_id, devices):
# or ``out_degrees`` is called. To avoid this, you need to create
# all sparse matrix representations beforehand using the ``create_formats_``
# method:
#
graph.create_formats_()

######################################################################
# Then you can spawn the subprocesses to train with multiple GPUs.
#
#
# .. code:: python
#
#    # Say you have four GPUs.
...
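######################################################################
# The spawning code itself is collapsed in the diff above. As a rough,
# illustrative sketch only (the actual tutorial code may differ), launching
# one ``run`` process per GPU with ``torch.multiprocessing`` could look like
# this, assuming four GPUs with IDs 0-3:

import torch.multiprocessing as mp

if __name__ == "__main__":
    num_gpus = 4  # adjust to the number of GPUs actually available
    # mp.spawn calls run(proc_id, devices) once per process,
    # with proc_id ranging over 0..num_gpus-1.
    mp.spawn(run, args=(list(range(num_gpus)),), nprocs=num_gpus)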