Commit 5fc7eb6d authored by John Andrilla, committed by Minjie Wang

[Doc] Generative models, edit pass (#1035)



* [Doc] Generative models, edit pass

Edit for grammar and style. Can you clarify "Instead of and/or"?

* Update tutorials/models/3_generative_model/5_dgmg.py

* Update tutorials/models/3_generative_model/5_dgmg.py
Co-Authored-By: Aaron Markham <markhama@amazon.com>

* Update tutorials/models/3_generative_model/5_dgmg.py
Co-Authored-By: Aaron Markham <markhama@amazon.com>

* Update tutorials/models/3_generative_model/5_dgmg.py
Co-Authored-By: Aaron Markham <markhama@amazon.com>

* Update tutorials/models/3_generative_model/5_dgmg.py
Co-Authored-By: Aaron Markham <markhama@amazon.com>
parent f45178c3
""" """
.. _model-dgmg: .. _model-dgmg:
Tutorial for Generative Models of Graphs Tutorial: Generative models of graphs
=========================================== ===========================================
**Author**: `Mufei Li <https://github.com/mufeili>`_, **Author**: `Mufei Li <https://github.com/mufeili>`_,
##############################################################################
#
# In this tutorial, you learn how to train and generate one graph at
# a time. You also explore parallelism within the graph embedding operation,
# which is an essential building block. The tutorial ends with a simple
# optimization that delivers double the speed by batching across graphs.
#
# Earlier tutorials showed how embedding a graph or
# a node enables you to work on tasks such as `semi-supervised classification for nodes
# <http://docs.dgl.ai/tutorials/models/1_gcn.html#sphx-glr-tutorials-models-1-gcn-py>`__
# or `sentiment analysis
# <http://docs.dgl.ai/tutorials/models/3_tree-lstm.html#sphx-glr-tutorials-models-3-tree-lstm-py>`__.
# Wouldn't it be interesting to predict the future evolution of the graph and
# perform the analysis iteratively?
#
# To address the evolution of graphs, you need to generate a variety of graph
# samples. In other words, you need **generative models** of graphs. In
# addition to learning node and edge features, you need to model the
# distribution of arbitrary graphs. While general generative models can model
# the density function explicitly or implicitly, and generate samples at once
# or sequentially, this tutorial focuses only on explicit generative models
# for sequential generation. Typical applications include drug or materials
# discovery, chemical processes, and proteomics.
#
# Introduction
# --------------------
# The primitive actions of mutating a graph in Deep Graph Library (DGL) are
# nothing more than ``add_nodes`` and ``add_edges``. That is, if you were to
# draw a circle of three nodes,
#
# .. figure:: https://user-images.githubusercontent.com/19576924/48313438-78baf000-e5f7-11e8-931e-cd00ab34fa50.gif
#    :alt:
#
# you can write the code as follows.
#
import dgl

g = dgl.DGLGraph()
g.add_nodes(1)  # Add node 0
g.add_nodes(1)  # Add node 1
# Edges in DGLGraph are directed by default.
# For undirected edges, add edges for both directions.
g.add_edges([1, 0], [0, 1])  # Add edges (1, 0), (0, 1)
g.add_nodes(1)  # Add node 2
g.add_edges([2, 1], [1, 2])  # Add edges (2, 1), (1, 2)
g.add_edges([2, 0], [0, 2])  # Add edges (2, 0), (0, 2)
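# As a quick, hedged sanity check, the graph should now have three nodes and
# six directed edges (two per undirected edge).
print('Nodes: {}, edges: {}'.format(g.number_of_nodes(), g.number_of_edges()))

##############################################################################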
# Real-world graphs are much more complex. There are many families of graphs,
# with different sizes, topologies, node types, edge types, and the possibility
# of multigraphs. Besides, the same graph can be generated in many different
# orders. Regardless, the generative process entails a few steps.
#
# - Encode a changing graph.
# - Perform actions stochastically.
# - If you are training, collect error signals and optimize the model parameters.
#
# When it comes to implementation, another important aspect is speed. How do you
# parallelize the computation, given that generating a graph is fundamentally a
# sequential process?
#
# .. note::
#
#    To be sure, this is not necessarily a hard constraint. Subgraphs can be
#    built in parallel and then assembled. This tutorial restricts itself
#    to sequential processes, however.
#
# DGMG: The main flow
# --------------------
# For this tutorial, you use
# `Deep Generative Models of Graphs <https://arxiv.org/abs/1803.03324>`__
# (DGMG) to implement a graph generative model using DGL. Its algorithmic
# framework is general but also challenging to parallelize.
#
# .. note::
#
#    While it's possible for DGMG to handle complex graphs with typed nodes,
#    typed edges, and multigraphs, here you use a simplified version of it
#    for generating graph topologies.
#
# DGMG generates a graph by following a state machine, which is basically a
# two-level loop. Generate one node at a time and connect it to a subset of
# the existing nodes, one at a time. This is similar to language modeling. The
# generative process is an iterative one that emits one word, character, or
# sentence at a time, conditioned on the sequence generated so far.
#
# At each time step, you either:
#
# - Add a new node to the graph, or
# - Select two existing nodes and add an edge between them.
#
# .. figure:: https://user-images.githubusercontent.com/19576924/48605003-7f11e900-e9b6-11e8-8880-87362348e154.png
#    :alt:
#
# The Python code will look as follows. In fact, this is *exactly* how inference
# with DGMG is implemented in DGL.
#
def forward_inference(self):
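    # The body elided by this diff is the two-level loop described above.
    # A hedged sketch, assuming the skeleton's decision helpers
    # ``add_node_and_update``, ``add_edge_or_not``, and
    # ``choose_dest_and_update`` (the full DGL example also caps the number
    # of nodes and edge trials):
    stop = self.add_node_and_update()
    while not stop:
        to_add_edge = self.add_edge_or_not()
        while to_add_edge:
            self.choose_dest_and_update()
            to_add_edge = self.add_edge_or_not()
        stop = self.add_node_and_update()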
    return self.g
#######################################################################################
# Assume you have a pre-trained model for generating cycles of 10-20 nodes.
# How does it generate a cycle on the fly during inference? Use the code below
# to create an animation with your own model.
#
# ::
#
# .. figure:: https://user-images.githubusercontent.com/19576924/48928548-2644d200-ef1b-11e8-8591-da93345382ad.gif
#    :alt:
#
# DGMG: Optimization objective
# ------------------------------
# Similar to language modeling, DGMG trains the model with *behavior cloning*,
# or *teacher forcing*. Assume that for each graph there exists a sequence of
# *oracle actions* :math:`a_{1},\cdots,a_{T}` that generates it. The model
# follows these actions, computes the joint probabilities of such
# action sequences, and maximizes them.
#######################################################################################
# The key difference between ``forward_train`` and ``forward_inference`` is
# that the training process takes oracle actions as input and returns log
# probabilities for evaluating the loss.
#
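# As a hedged sketch (not the exact example code), ``forward_train`` mirrors
# the inference loop above, except that each decision helper consumes the
# next oracle action instead of sampling one, and the accumulated log
# probabilities are returned at the end. The helper names and the
# ``get_log_prob`` accessor are the same kind of assumption as in the
# inference sketch.
#
# ::
#
#     def forward_train(self, actions):
#         stop = self.add_node_and_update(a=actions.pop(0))
#         while not stop:
#             to_add_edge = self.add_edge_or_not(a=actions.pop(0))
#             while to_add_edge:
#                 self.choose_dest_and_update(a=actions.pop(0))
#                 to_add_edge = self.add_edge_or_not(a=actions.pop(0))
#             stop = self.add_node_and_update(a=actions.pop(0))
#         return self.get_log_prob()
#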
# DGMG: The implementation
# --------------------------
# The ``DGMG`` class
# ``````````````````````````
# Below you can find the skeleton code for the model. You will gradually
# fill in the details for each function.
#
        raise NotImplementedError
    def forward(self, actions=None):
        # The graph you will work on
        self.g = dgl.DGLGraph()
        # If there are some features for nodes and edges,
# Encoding a dynamic graph
# ``````````````````````````
# All the actions generating a graph are sampled from probability
# distributions. To do that, you project the structured data,
# namely the graph, onto a Euclidean space. The challenge is that such a
# process, called *embedding*, needs to be repeated as the graphs mutate.
#
# Graph embedding
# ''''''''''''''''''''''''''
# Let :math:`G=(V,E)` be an arbitrary graph. Each node :math:`v` has an
# embedding vector :math:`\textbf{h}_{v} \in \mathbb{R}^{n}`. Similarly,
# .. math::
#
#    \textbf{h}_{G} =\sum_{v\in V}\text{Sigmoid}(g_m(\textbf{h}_{v}))f_{m}(\textbf{h}_{v})
#
# The first term, :math:`\text{Sigmoid}(g_m(\textbf{h}_{v}))`, computes a
# gating function and can be thought of as how much the overall graph embedding
# attends to each node. The second term, :math:`f_{m}:\mathbb{R}^{n}\rightarrow\mathbb{R}^{k}`,
# maps the node embeddings to the space of graph embeddings.
#
# Implement graph embedding as a ``GraphEmbed`` class.
#
import torch
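# The class body is elided in this diff view. Below is a hedged sketch of
# what such a module can look like, following the gated-sum equation above.
# The name ``GraphEmbedSketch`` and the choice of making graph embeddings
# twice the size of node embeddings are illustrative assumptions; see the
# full DGL example for the exact class.
import torch.nn as nn


class GraphEmbedSketch(nn.Module):
    def __init__(self, node_hidden_size):
        super(GraphEmbedSketch, self).__init__()
        self.graph_hidden_size = 2 * node_hidden_size
        # g_m: a gate deciding how much each node contributes
        self.node_gating = nn.Sequential(
            nn.Linear(node_hidden_size, 1),
            nn.Sigmoid()
        )
        # f_m: maps node embeddings to the graph embedding space
        self.node_to_graph = nn.Linear(node_hidden_size,
                                       self.graph_hidden_size)

    def forward(self, g):
        if g.number_of_nodes() == 0:
            # Use an all-zero embedding for an empty graph.
            return torch.zeros(1, self.graph_hidden_size)
        hvs = g.ndata['hv']
        # Gated sum over all node embeddings
        return (self.node_gating(hvs) * self.node_to_graph(hvs)).sum(0, keepdim=True)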
#######################################################################################
# Update node embeddings via graph propagation
# ''''''''''''''''''''''''''''''''''''''''''''
#
# The mechanism of updating node embeddings in DGMG is similar to that for
# graph convolutional networks. For a node :math:`v` in the graph, its
#
# Performing all the operations above once for all nodes synchronously is
# called one round of graph propagation. The more rounds of graph propagation
# you perform, the farther messages travel throughout the graph.
#
# With DGL, you implement graph propagation with ``g.update_all``.
# The message notation here can be a bit confusing. While the authors refer
# to :math:`\textbf{m}_{u\rightarrow v}` as messages, the message function
# below only passes :math:`\text{concat}([\textbf{h}_{u}, \textbf{x}_{u, v}])`.
# The operation :math:`\textbf{W}_{m}\text{concat}([\textbf{h}_{v}, \textbf{h}_{u}, \textbf{x}_{u, v}]) + \textbf{b}_{m}`
# is then performed across all edges at once, for efficiency considerations.
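# The ``GraphProp`` class itself is elided in this diff view. Below is a
# hedged, single-round sketch that follows the equations above; the class in
# the full DGL example keeps separate parameters for each of several
# propagation rounds. The field names ``'hv'``, ``'he'``, ``'m'``, and
# ``'a'`` match the fragments shown elsewhere in this tutorial.


class GraphPropSketch(nn.Module):
    def __init__(self, node_hidden_size):
        super(GraphPropSketch, self).__init__()
        # W_m, b_m: map concat([h_v, h_u, x_uv]) to a message; the +1
        # accounts for the one-dimensional edge feature x_uv.
        self.message_linear = nn.Linear(2 * node_hidden_size + 1,
                                        2 * node_hidden_size)
        # GRU cell updating h_v from the summed node activation a_v
        self.node_update = nn.GRUCell(2 * node_hidden_size, node_hidden_size)

    def dgmg_msg(self, edges):
        # Only pass concat([h_u, x_uv]); h_v is concatenated in the reduce
        # step, so W_m can be applied across all edges at once.
        return {'m': torch.cat([edges.src['hv'], edges.data['he']], dim=1)}

    def dgmg_reduce(self, nodes):
        hv = nodes.data['hv']
        m = nodes.mailbox['m']
        message = torch.cat(
            [hv.unsqueeze(1).expand(-1, m.size(1), -1), m], dim=2)
        # Sum the transformed incoming messages into the activation a_v
        return {'a': self.message_linear(message).sum(1)}

    def forward(self, g):
        if g.number_of_edges() > 0:
            g.update_all(message_func=self.dgmg_msg,
                         reduce_func=self.dgmg_reduce)
            g.ndata['hv'] = self.node_update(g.ndata['a'], g.ndata['hv'])
        return g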
#######################################################################################
# Actions
# ``````````````````````````
# All actions are sampled from distributions parameterized using neural
# networks, which are introduced in turn below.
#
# Action 1: Add nodes
# ''''''''''''''''''''''''''
#
# Given the graph embedding vector :math:`\textbf{h}_{G}`, evaluate
#
# .. math::
#
# which is then used to parametrize a Bernoulli distribution for deciding whether
# to add a new node.
#
# If a new node is to be added, initialize its feature with
#
# .. math::
#
        return stop
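# The full ``AddNode`` module is elided by this diff. As a hedged sketch of
# its core decision, with ``add_node_linear`` a hypothetical module mapping
# :math:`\textbf{h}_{G}` to a single logit:
from torch.distributions import Bernoulli

def decide_add_node(graph_embed, add_node_linear):
    # p(add a new node | h_G) parameterizes a Bernoulli distribution
    prob = torch.sigmoid(add_node_linear(graph_embed))
    return Bernoulli(prob).sample().item() == 1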
#######################################################################################
# Action 2: Add edges
# ''''''''''''''''''''''''''
#
# Given the graph embedding vector :math:`\textbf{h}_{G}` and the node
# embedding vector :math:`\textbf{h}_{v}` for the latest node :math:`v`,
# you evaluate
#
# .. math::
#
        return to_add_edge
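# Analogously, a hedged sketch of the core decision in ``AddEdge``, with
# ``add_edge_linear`` a hypothetical module scoring
# :math:`\text{concat}([\textbf{h}_{G}, \textbf{h}_{v}])`, reusing the
# ``Bernoulli`` import above:
def decide_add_edge(graph_embed, node_embed, add_edge_linear):
    prob = torch.sigmoid(add_edge_linear(
        torch.cat([graph_embed, node_embed], dim=1)))
    return Bernoulli(prob).sample().item() == 1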
#######################################################################################
# Action 3: Choose a destination
# '''''''''''''''''''''''''''''''''
#
# When action 2 returns ``True``, choose a destination for the
# latest node :math:`v`.
#
# For each possible destination :math:`u\in\{0, \cdots, v-1\}`, the
        self.choose_dest = nn.Linear(2 * node_hidden_size, 1)

    def _initialize_edge_repr(self, g, src_list, dest_list):
        # For untyped edges, use 1 to indicate their existence.
        # For multiple edge types, use a one-hot representation
        # or an embedding module.
        edge_repr = torch.ones(len(src_list), 1)
        g.edges[src_list, dest_list].data['he'] = edge_repr
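        # ``dests_probs`` below comes from code elided by this diff. A hedged
        # sketch of how it can be computed, assuming node embeddings live in
        # ``g.ndata['hv']`` and ``F`` is ``torch.nn.functional``:
        #
        #     possible_dests_embed = g.nodes[list(range(src))].data['hv']
        #     src_embed = g.nodes[src].data['hv'].expand(src, -1)
        #     dests_scores = self.choose_dest(torch.cat(
        #         [possible_dests_embed, src_embed], dim=1)).view(1, -1)
        #     dests_probs = F.softmax(dests_scores, dim=1)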
        dest = Categorical(dests_probs).sample().item()
        if not g.has_edge_between(src, dest):
            # For undirected graphs, add edges for both directions
            # so that you can perform graph propagation.
            src_list = [src, dest]
            dest_list = [dest, src]
#######################################################################################
# Putting it together
# ``````````````````````````
#
# You are now ready to assemble the complete implementation of the model class.
#
class DGMG(DGMGSkeleton):
#######################################################################################
# Below is an animation where a graph is generated on the fly
# after every 10 batches of training for the first 400 batches. You
# can see how the model improves over time and begins generating cycles.
#
# .. figure:: https://user-images.githubusercontent.com/19576924/48929291-60fe3880-ef22-11e8-832a-fbe56656559a.gif
#    :alt:
#
# For generative models, you can evaluate performance by checking the percentage
# of valid graphs among the graphs it generates on the fly.
import torch.utils.model_zoo as model_zoo
del model
print('Among 100 graphs generated, {}% are valid.'.format(num_valid))
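# A hedged sketch of a rough validity check for this setting, where a valid
# graph is a cycle of 10-20 nodes. With the bidirectional-edge convention
# used above, every node in a cycle has in-degree 2. (This necessary
# condition alone does not rule out disjoint unions of cycles; the checker
# in the full DGL example may differ.)
def is_valid_cycle(g, v_min=10, v_max=20):
    n = g.number_of_nodes()
    if not v_min <= n <= v_max:
        return False
    return all(g.in_degree(v) == 2 for v in range(n))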
#######################################################################################
# For the complete implementation, see the `DGL DGMG example
# <https://github.com/dmlc/dgl/tree/master/examples/pytorch/dgmg>`__.
#
# Batched graph generation
# ---------------------------
#
# Speeding up DGMG is hard because each graph can be generated with a
# unique sequence of actions. One way to explore parallelism is to adopt
# asynchronous gradient descent with multiple processes. Each of them
# works on one graph at a time, and the processes are loosely coordinated
# by a parameter server.
#
# DGL explores parallelism in the message-passing framework, on top of
# the framework-provided tensor operations. The earlier tutorial already
#     for g in g_list:
#         self.graph_prop(g)
#
# Modify the code to work on a batch of graphs at once by replacing
# these lines with the following. On the CPU of a Mac, this instantly
# gives a six- to seven-fold speedup for the graph propagation part.
# ::
#
#     bg = dgl.batch(g_list)
#     self.graph_prop(bg)
#     g_list = dgl.unbatch(bg)
#
# You have already used this trick of calling ``dgl.batch`` in the
# `Tree-LSTM tutorial
# <http://docs.dgl.ai/tutorials/models/3_tree-lstm.html#sphx-glr-tutorials-models-3-tree-lstm-py>`__,
# and it is worth explaining one more time why it works.
# graph (``BatchedDGLGraph``) over which ``update_all`` propels message passing
# on all the edges and nodes.
#
# With ``dgl.batch``, you merge ``g_{1}, ..., g_{N}`` into a single giant
# graph consisting of :math:`N` isolated small graphs. For example, if you
# have two graphs with adjacency matrices
#
# In DGL, the message function is defined on the edges, so batching scales
# the processing of edge user-defined functions (UDFs) linearly.
#
# The reduce UDF, ``dgmg_reduce``, works on nodes, and each node may
# have a different number of incoming edges. Using degree bucketing, DGL
# internally groups nodes with the same in-degree and calls the reduce UDF once
# for each group. Thus, batching also reduces the number of calls to these UDFs.
# ``g_list = dgl.unbatch(bg)``.
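#
# As a quick, hedged sanity check with the same API used above, batching two
# small graphs merges their nodes and edges into one graph, relabeling the
# second graph's nodes after the first's:
#
# ::
#
#     g1 = dgl.DGLGraph()
#     g1.add_nodes(2)
#     g1.add_edges([0, 1], [1, 0])
#     g2 = dgl.DGLGraph()
#     g2.add_nodes(3)
#     g2.add_edges([0, 1, 2], [1, 2, 0])
#     bg = dgl.batch([g1, g2])
#     print(bg.number_of_nodes(), bg.number_of_edges())  # 5 5
#     g1, g2 = dgl.unbatch(bg)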
#
# The complete code for the batched version can also be found in the example.
# On a testbed, you get roughly double the speed compared to the previous implementation.