Commit a2b8d8e4 authored by John Andrilla, committed by Minjie Wang

[Doc] Giant graph tutorial edit for grammar and style (#1020)

* Edit for readability

* giant_graph_readme edit for grammar

* NodeFlow and Sampling edit pass for grammar

* Update tutorials/models/1_gnn/9_gat.py

* Update tutorials/models/1_gnn/9_gat.py
parent dd09f15f
""" """
.. _model-gat: .. _model-gat:
Understand Graph Attention Network Graph attention network
================================== ==================================
**Authors:** `Hao Zhang <https://github.com/sufeidechabei/>`_, `Mufei Li **Authors:** `Hao Zhang <https://github.com/sufeidechabei/>`_, `Mufei Li
...@@ -9,32 +9,29 @@ Understand Graph Attention Network ...@@ -9,32 +9,29 @@ Understand Graph Attention Network
<https://jermainewang.github.io/>`_ `Zheng Zhang <https://jermainewang.github.io/>`_ `Zheng Zhang
<https://shanghai.nyu.edu/academics/faculty/directory/zheng-zhang>`_ <https://shanghai.nyu.edu/academics/faculty/directory/zheng-zhang>`_
From `Graph Convolutional Network (GCN) <https://arxiv.org/abs/1609.02907>`_, In this tutorial, you learn about a graph attention network (GAT) and how it can be
we learned that combining local graph structure and node-level features yields implemented in PyTorch. You can also learn to visualize and understand what the attention
good performance on node classification task. However, the way GCN aggregates mechanism has learned.
is structure-dependent, which may hurt its generalizability.
One workaround is to simply average over all neighbor node features as in The research described in the paper `Graph Convolutional Network (GCN) <https://arxiv.org/abs/1609.02907>`_,
`GraphSAGE indicates that combining local graph structure and node-level features yields
<https://www-cs-faculty.stanford.edu/people/jure/pubs/graphsage-nips17.pdf>`_. good performance on node classification tasks. However, the way GCN aggregates
`Graph Attention Network <https://arxiv.org/abs/1710.10903>`_ proposes an is structure-dependent, which can hurt its generalizability.
alternative way by weighting neighbor features with feature dependent and
structure free normalization, in the style of attention.
The goal of this tutorial:
* Explain what is Graph Attention Network. One workaround is to simply average over all neighbor node features as described in
* Demonstrate how it can be implemented in DGL. the research paper `GraphSAGE
* Understand the attentions learnt. <https://www-cs-faculty.stanford.edu/people/jure/pubs/graphsage-nips17.pdf>`_.
* Introduce to inductive learning. However, `Graph Attention Network <https://arxiv.org/abs/1710.10903>`_ proposes a
different type of aggregation. GAN uses weighting neighbor features with feature dependent and
structure-free normalization, in the style of attention.
""" """
###############################################################
# Introducing attention to GCN
# ----------------------------
#
# The key difference between GAT and GCN is how the information from the
# one-hop neighborhood is aggregated.
#
# For GCN, a graph convolution operation produces the normalized sum of the
# node features of neighbors.
#
# .. math::
#
#    h_i^{(l+1)}=\sigma\left(b^{(l)}+\sum_{j\in\mathcal{N}(i)}\frac{1}{c_{ij}}h_j^{(l)}W^{(l)}\right)
#
# where :math:`\mathcal{N}(i)` is the set of one-hop neighbors of node
# :math:`i` and :math:`c_{ij}` is a normalization constant based on graph
# structure.
# GAT introduces the attention mechanism as a substitute for the statically
# normalized convolution operation. Below are the equations to compute the node
# embedding :math:`h_i^{(l+1)}` of layer :math:`l+1` from the embeddings of
# layer :math:`l`.
#
# .. image:: https://s3.us-east-2.amazonaws.com/dgl.ai/tutorial/gat/gat.png
#   :width: 450px
#   :align: center
#
# .. math::
#
#    \begin{align}
#    z_i^{(l)}&=W^{(l)}h_i^{(l)},&(1) \\
#    e_{ij}^{(l)}&=\text{LeakyReLU}\left(\vec a^{(l)^T}(z_i^{(l)}||z_j^{(l)})\right),&(2)\\
#    \alpha_{ij}^{(l)}&=\frac{\exp(e_{ij}^{(l)})}{\sum_{k\in \mathcal{N}(i)}\exp(e_{ik}^{(l)})},&(3)\\
#    h_i^{(l+1)}&=\sigma\left(\sum_{j\in \mathcal{N}(i)}\alpha_{ij}^{(l)}z_j^{(l)}\right),&(4)
#    \end{align}
#
# * Equation (1) is a linear transformation of the lower layer embedding :math:`h_i^{(l)}`,
#   and :math:`W^{(l)}` is its learnable weight matrix.
# * Equation (2) computes a pair-wise *unnormalized* attention score between two neighbors.
#   Here, it first concatenates the :math:`z` embeddings of the two nodes, where :math:`||`
#   denotes concatenation, then takes a dot product of it with a learnable weight vector
#   :math:`\vec a^{(l)}`, and applies a LeakyReLU at the end. This form of attention is
#   usually called *additive attention*, in contrast to the dot-product attention used in the
#   Transformer model.
# * Equation (3) applies a softmax to normalize the attention scores on each node's
#   incoming edges.
# * Equation (4) is similar to GCN. The embeddings from neighbors are aggregated together,
#   scaled by the attention scores.
#
# There are other details from the paper, such as dropout and skip connections.
# For simplicity, those details are left out of this tutorial. To see them,
# download the `full example <https://github.com/dmlc/dgl/blob/master/examples/pytorch/gat/gat.py>`_.
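Before diving into the DGL implementation, it can help to see equations (1)-(4) spelled out numerically. The following NumPy sketch is not part of the tutorial's code; the graph, weights, and helper names are arbitrary illustrative choices for a single attention head on a tiny dense graph.

```python
import numpy as np

def leaky_relu(x, slope=0.2):
    return np.where(x > 0, x, slope * x)

def gat_layer_numpy(h, adj, W, a):
    """One single-head GAT layer, equations (1)-(4), naive dense version.

    h:   (N, d_in) node features
    adj: (N, N) 0/1 adjacency, adj[i, j] = 1 if j is a neighbor of i
    W:   (d_in, d_out) weight matrix for equation (1)
    a:   (2 * d_out,) attention vector for equation (2)
    """
    z = h @ W                                   # equation (1)
    N = h.shape[0]
    h_out = np.zeros((N, W.shape[1]))
    for i in range(N):
        nbrs = np.nonzero(adj[i])[0]
        # equation (2): unnormalized attention scores over i's neighbors
        e = np.array([leaky_relu(a @ np.concatenate([z[i], z[j]]))
                      for j in nbrs])
        # equation (3): softmax over incoming edges
        alpha = np.exp(e - e.max())
        alpha = alpha / alpha.sum()
        # equation (4): attention-weighted sum (identity activation here)
        h_out[i] = (alpha[:, None] * z[nbrs]).sum(axis=0)
    return h_out

# toy graph: 3 nodes, everyone attends to everyone (including self)
rng = np.random.default_rng(0)
h = rng.normal(size=(3, 4))
adj = np.ones((3, 3))
W = rng.normal(size=(4, 2))
a = rng.normal(size=(4,))
out = gat_layer_numpy(h, adj, W, a)
print(out.shape)  # (3, 2)
```

Each row of the output is a convex combination of the transformed neighbor embeddings, because the per-node `alpha` vectors sum to one.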
#
# In essence, GAT is just a different aggregation function with attention
# over features of neighbors, instead of a simple mean aggregation.
#
# GAT in DGL
# ----------
#
# To begin, you can get an overall impression of how a ``GATLayer`` module is
# implemented in DGL. In this section, the four equations above are broken down
# one at a time.

import torch
import torch.nn as nn
import torch.nn.functional as F
class GATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim):
        super(GATLayer, self).__init__()
        self.g = g
        # equation (1)
        self.fc = nn.Linear(in_dim, out_dim, bias=False)
        # equation (2)
        self.attn_fc = nn.Linear(2 * out_dim, 1, bias=False)

    def edge_attention(self, edges):
        # edge UDF for equation (2)
        z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
        a = self.attn_fc(z2)
        return {'e': F.leaky_relu(a)}

    def message_func(self, edges):
        # message UDF for equations (3) and (4)
        return {'z': edges.src['z'], 'e': edges.data['e']}

    def reduce_func(self, nodes):
        # reduce UDF for equations (3) and (4)
        # equation (3)
        alpha = F.softmax(nodes.mailbox['e'], dim=1)
        # equation (4)
        h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
        return {'h': h}

    def forward(self, h):
        # equation (1)
        z = self.fc(h)
        self.g.ndata['z'] = z
        # equation (2)
        self.g.apply_edges(self.edge_attention)
        # equations (3) and (4)
        self.g.update_all(self.message_func, self.reduce_func)
        return self.g.ndata.pop('h')

###############################################################
# Equation (1)
# ^^^^^^^^^^^^
#
# .. math::
#
#    z_i^{(l)}=W^{(l)}h_i^{(l)},(1)
#
# The first equation is a linear transformation. It's common and can be
# easily implemented in PyTorch using ``torch.nn.Linear``.
#
# Equation (2)
# ^^^^^^^^^^^^
#
# .. math::
#
#    e_{ij}^{(l)}=\text{LeakyReLU}\left(\vec a^{(l)^T}(z_i^{(l)}||z_j^{(l)})\right),(2)
#
# The unnormalized attention score :math:`e_{ij}` is calculated using the
# embeddings of adjacent nodes :math:`i` and :math:`j`. This suggests that the
# attention scores can be viewed as edge data, which can be calculated by the
# ``apply_edges`` API. The argument to ``apply_edges`` is an **Edge UDF**,
# which is defined below:
def edge_attention(self, edges):
    # edge UDF for equation (2)
    z2 = torch.cat([edges.src['z'], edges.dst['z']], dim=1)
    a = self.attn_fc(z2)
    return {'e': F.leaky_relu(a)}
###########################################################################
# Here, the dot product with the learnable weight vector :math:`\vec{a^{(l)}}`
# is implemented again using PyTorch's linear transformation ``attn_fc``. Note
# that ``apply_edges`` will **batch** all the edge data in one tensor, so the
# ``cat`` and ``attn_fc`` here are applied on all the edges in parallel.
#
# Equation (3) & (4)
# ^^^^^^^^^^^^^^^^^^
#
# .. math::
#
#    \begin{align}
#    \alpha_{ij}^{(l)}&=\frac{\exp(e_{ij}^{(l)})}{\sum_{k\in \mathcal{N}(i)}\exp(e_{ik}^{(l)})},&(3)\\
#    h_i^{(l+1)}&=\sigma\left(\sum_{j\in \mathcal{N}(i)}\alpha_{ij}^{(l)}z_j^{(l)}\right),&(4)
#    \end{align}
#
# Similar to GCN, the ``update_all`` API is used to trigger message passing on all
# the nodes. The message function sends out two tensors: the transformed ``z``
# embedding of the source node and the unnormalized attention score ``e`` on
# each edge. The reduce function then performs two tasks:
#
# 1. Normalize the attention scores using softmax (equation (3)).
# 2. Aggregate neighbor embeddings weighted by the attention scores (equation (4)).
def message_func(self, edges):
    # message UDF, which sends out the source embedding z and the score e
    return {'z': edges.src['z'], 'e': edges.data['e']}

def reduce_func(self, nodes):
    # equation (3): softmax over messages on incoming edges
    alpha = F.softmax(nodes.mailbox['e'], dim=1)
    # equation (4): weighted sum of neighbor embeddings
    h = torch.sum(alpha * nodes.mailbox['z'], dim=1)
    return {'h': h}
#####################################################################
# Multi-head attention
# ^^^^^^^^^^^^^^^^^^^^
#
# Analogous to multiple channels in a ConvNet, GAT introduces **multi-head
# attention** to enrich the model capacity and to stabilize the learning
# process. Each attention head has its own parameters, and their outputs can be
# merged in two ways:
#
# .. math:: \text{concatenation}: h^{(l+1)}_{i}=||_{k=1}^{K}\sigma\left(\sum_{j\in \mathcal{N}(i)}\alpha_{ij}^{k}W^{k}h^{(l)}_{j}\right)
#
# or
#
# .. math:: \text{average}: h_{i}^{(l+1)}=\sigma\left(\frac{1}{K}\sum_{k=1}^{K}\sum_{j\in\mathcal{N}(i)}\alpha_{ij}^{k}W^{k}h^{(l)}_{j}\right)
#
# where :math:`K` is the number of heads. The authors suggest using
# concatenation for intermediary layers and average for the final layer.
#
# Use the above-defined single-head ``GATLayer`` as the building block
# for the ``MultiHeadGATLayer`` below:
class MultiHeadGATLayer(nn.Module):
    def __init__(self, g, in_dim, out_dim, num_heads, merge='cat'):
        super(MultiHeadGATLayer, self).__init__()
        self.heads = nn.ModuleList()
        for i in range(num_heads):
            self.heads.append(GATLayer(g, in_dim, out_dim))
        self.merge = merge

    def forward(self, h):
        head_outs = [attn_head(h) for attn_head in self.heads]
        if self.merge == 'cat':
            # concatenate along the feature dimension (for intermediary layers)
            return torch.cat(head_outs, dim=1)
        else:
            # merge using average (for the final layer)
            return torch.mean(torch.stack(head_outs))

#####################################################################
# Put everything together
# ^^^^^^^^^^^^^^^^^^^^^^^
#
# Now, you can define a two-layer GAT model.

class GAT(nn.Module):
    def __init__(self, g, in_dim, hidden_dim, out_dim, num_heads):
        super(GAT, self).__init__()
        self.layer1 = MultiHeadGATLayer(g, in_dim, hidden_dim, num_heads)
        # Note that the input dimension is hidden_dim * num_heads because
        # multiple head outputs are concatenated together. Also, only
        # one attention head is used in the output layer.
        self.layer2 = MultiHeadGATLayer(g, hidden_dim * num_heads, out_dim, 1)

    def forward(self, h):
        h = self.layer1(h)
        h = F.elu(h)
        h = self.layer2(h)
        return h
#############################################################################
# We then load the Cora dataset using DGL's built-in data module.

from dgl import DGLGraph
from dgl.data import citation_graph as citegrh
import time
import numpy as np

def load_cora_data():
    data = citegrh.load_cora()
    features = torch.FloatTensor(data.features)
    labels = torch.LongTensor(data.labels)
    mask = torch.BoolTensor(data.train_mask)
    g = DGLGraph(data.graph)
    return g, features, labels, mask

g, features, labels, mask = load_cora_data()

# create the model, with two heads, each head having a hidden size of 8
net = GAT(g,
          in_dim=features.size()[1],
          hidden_dim=8,
          out_dim=7,
          num_heads=2)

# create optimizer
optimizer = torch.optim.Adam(net.parameters(), lr=1e-3)

# main loop
dur = []
for epoch in range(30):
    if epoch >= 3:
        t0 = time.time()

    logits = net(features)
    logp = F.log_softmax(logits, 1)
    loss = F.nll_loss(logp[mask], labels[mask])

    optimizer.zero_grad()
    loss.backward()
    optimizer.step()

    if epoch >= 3:
        dur.append(time.time() - t0)

    print("Epoch {:05d} | Loss {:.4f} | Time(s) {:.4f}".format(
        epoch, loss.item(), np.mean(dur)))
#########################################################################
# Visualizing and understanding attention learned
# -----------------------------------------------
#
# Cora
# ^^^^
#
# The following table summarizes the model performance on Cora that is
# reported in `the GAT paper <https://arxiv.org/pdf/1710.10903.pdf>`_ and
# obtained with DGL implementations.
#
# .. list-table::
#    :header-rows: 1
#
#    * - Model
#      - Accuracy
#    * - GAT (paper)
#      - :math:`83.0\pm 0.7%`
#    * - GAT (dgl)
#      - :math:`83.69\pm 0.529%`
#
# *What kind of attention distribution has our model learned?*
#
# Because the attention weight :math:`a_{ij}` is associated with edges, you can
# visualize it by coloring edges. Below, a subgraph of Cora is picked and the
# attention weights of the last ``GATLayer`` are plotted. The nodes are colored
# according to their labels, whereas the edges are colored according to the
# magnitude of the attention weights, which can be read off the colorbar on the right.
#
# You can see that the model seems to learn different attention weights. To
# understand the distribution more thoroughly, measure the `entropy
# <https://en.wikipedia.org/wiki/Entropy_(information_theory)>`_ of the
# attention distribution. For any node :math:`i`,
# :math:`\{\alpha_{ij}\}_{j\in\mathcal{N}(i)}` forms a discrete probability
# distribution over all its neighbors, with entropy given by
#
# .. math:: H({\alpha_{ij}}_{j\in\mathcal{N}(i)})=-\sum_{j\in\mathcal{N}(i)} \alpha_{ij}\log\alpha_{ij}
#
# A low entropy means a high degree of concentration, and vice
# versa. An entropy of 0 means all attention is on one source node. The uniform
# distribution has the highest entropy of :math:`\log(|\mathcal{N}(i)|)`.
# Ideally, you want to see the model learn a distribution of lower entropy
# (that is, one or two neighbors are much more important than the others).
#
# Note that because nodes can have different degrees, the maximum entropy will
# also be different. Therefore, plot the aggregated histogram of entropy
# values of all nodes in the entire graph. Below are the attention histograms
# learned by each attention head.
#
# |image2|
#
# One can see that **the attention values learned are quite similar to a uniform
# distribution** (in other words, all neighbors are equally important). This partially
# explains why the performance of GAT is close to that of GCN on Cora
# (according to the `authors' reported results
# <https://arxiv.org/pdf/1710.10903.pdf>`_, the accuracy difference averaged
# over 100 runs is less than 2 percent). Attention does not matter much here,
# since it does not differentiate neighbors.
#
# *Does that mean the attention mechanism is not useful?* No! A different
# dataset exhibits an entirely different pattern, as you can see next.
#
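To make the entropy measure above concrete, here is a small helper. It is an illustration, not code from the tutorial, and the function name is a hypothetical choice.

```python
import numpy as np

def attention_entropy(alpha):
    """Entropy H of a discrete attention distribution (natural log),
    per the formula above; alpha must be nonnegative and sum to 1."""
    alpha = np.asarray(alpha, dtype=float)
    nz = alpha[alpha > 0]          # 0 * log(0) is taken as 0
    return float(-(nz * np.log(nz)).sum())

# uniform attention over 4 neighbors: maximal entropy, log(4)
print(attention_entropy([0.25, 0.25, 0.25, 0.25]))  # about 1.386

# all attention on one neighbor: entropy 0
print(attention_entropy([1.0, 0.0, 0.0, 0.0]))      # 0.0
```

A histogram of this value over all nodes is exactly what the plots in this section show.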
# Protein-protein interaction (PPI) networks
# ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
#
# The PPI dataset used here consists of :math:`24` graphs corresponding to
# different human tissues. Nodes can have up to :math:`121` kinds of labels, so
# the label of a node is represented as a binary tensor of size :math:`121`. The
# task is to predict the node labels.
#
# Use :math:`20` graphs for training, :math:`2` for validation, and :math:`2`
# for testing. The average number of nodes per graph is :math:`2372`. Each node
# has :math:`50` features that are composed of positional gene sets, motif gene
# sets, and immunological signatures. Critically, test graphs remain completely
# unobserved during training, a setting called "inductive learning".
#
# Compare the performance of GAT and GCN for :math:`10` random runs on this
# task and use hyperparameter search on the validation set to find the best
# model.
#
# .. list-table::
#    :header-rows: 1
#
#    * - Model
#      - F1 score (micro)
#    * - GAT
#      - :math:`0.975 \pm 0.006`
#    * - GCN
#      - :math:`0.509 \pm 0.025`
#    * - Paper
#      - :math:`0.973 \pm 0.002`
#
# The table above is the result of this experiment, where you use the micro `F1
# score <https://en.wikipedia.org/wiki/F1_score>`_ to evaluate the model
# performance.
#
# The micro F1 score is calculated as follows:
#
# .. math::
#
#    precision=\frac{\sum_{t=1}^{n}TP_{t}}{\sum_{t=1}^{n}(TP_{t}+FP_{t})}
#
#    recall=\frac{\sum_{t=1}^{n}TP_{t}}{\sum_{t=1}^{n}(TP_{t}+FN_{t})}
#
#    F1_{micro}=2\frac{precision*recall}{precision+recall}
#
# * :math:`TP_{t}` is the number of nodes that both have and are predicted to have label :math:`t`.
# * :math:`FP_{t}` is the number of nodes that do not have but are predicted to have label :math:`t`.
# * :math:`FN_{t}` is the number of nodes that have label :math:`t` but are predicted as other labels.
# * :math:`n` is the number of labels, that is, :math:`121` in our case.
#
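The definitions above translate directly into code. This NumPy sketch (a hypothetical helper, not part of the tutorial) computes the micro F1 score for multi-label 0/1 predictions:

```python
import numpy as np

def micro_f1(y_true, y_pred):
    """Micro-averaged F1 for multi-label 0/1 arrays of shape (num_nodes, n),
    following the TP/FP/FN definitions above. Counts are pooled over all
    labels before computing precision and recall."""
    y_true = np.asarray(y_true)
    y_pred = np.asarray(y_pred)
    tp = np.sum((y_true == 1) & (y_pred == 1))
    fp = np.sum((y_true == 0) & (y_pred == 1))
    fn = np.sum((y_true == 1) & (y_pred == 0))
    precision = tp / (tp + fp)
    recall = tp / (tp + fn)
    return 2 * precision * recall / (precision + recall)

# tiny example: 2 nodes, 3 labels
y_true = [[1, 0, 1], [0, 1, 1]]
y_pred = [[1, 0, 0], [0, 1, 1]]
print(micro_f1(y_true, y_pred))  # tp=3, fp=0, fn=1 -> about 0.857
```

Pooling counts over labels is what distinguishes the micro average from the macro average, which would compute F1 per label and then average.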
# During training, use ``BCEWithLogitsLoss`` as the loss function. The
# learning curves of GAT and GCN are presented below; what is evident is the
# dramatic performance advantage of GAT over GCN.
#
#
# As before, you can gain a statistical understanding of the attentions learned
# by showing the histogram plot for the node-wise attention entropy. Below are
# the attention histograms learned by different attention layers.
#
# *Attention learned in layer 1:*
#
# |image5|
#
# *Attention learned in layer 2:*
#
# |image6|
#
# *Attention learned in final layer:*
#
# |image7|
#
#
# Clearly, **GAT does learn sharp attention weights**! There is a clear pattern
# over the layers as well: **the attention gets sharper in higher layers**.
#
# Unlike the Cora dataset, where GAT's gain is minimal at best, for PPI there
# is a significant performance gap between GAT and other GNN variants compared
# in `the GAT paper <https://arxiv.org/pdf/1710.10903.pdf>`_ (at least 20 percent),
# and the attention distributions between the two clearly differ. While this
# deserves further research, one immediate conclusion is that GAT's advantage
# lies perhaps more in its ability to handle a graph with a more complex
# neighborhood structure.
#
# What's next?
# ------------
#
# So far, you have seen how to use DGL to implement GAT. There are some
# missing details, such as dropout, skip connections, and hyper-parameter tuning,
# which are common practices and do not involve DGL-related concepts. For more
# information, check out the full example.
#
# * See the optimized `full example <https://github.com/dmlc/dgl/blob/master/examples/pytorch/gat/gat.py>`_.
# * The next tutorial describes how to speed up GAT models by parallelizing multiple attention heads and by SPMV optimization.
#
# .. |image2| image:: https://s3.us-east-2.amazonaws.com/dgl.ai/tutorial/gat/cora-attention-hist.png
# .. |image5| image:: https://s3.us-east-2.amazonaws.com/dgl.ai/tutorial/gat/ppi-first-layer-hist.png
# .. |image6| image:: https://s3.us-east-2.amazonaws.com/dgl.ai/tutorial/gat/ppi-second-layer-hist.png
# .. |image7| image:: https://s3.us-east-2.amazonaws.com/dgl.ai/tutorial/gat/ppi-final-layer-hist.png
"""
NodeFlow and Sampling
=====================
"""
################################################################################################
#
# Graph convolutional network
# ~~~~~~~~~~~~~~~~~~~~~~~~~~~
#
# In an :math:`L`-layer graph convolutional network (GCN), given a graph
# :math:`G = (\mathcal{V}, \mathcal{E})` with adjacency matrix :math:`A` and
# input node features :math:`H^{(0)} = X`, the hidden features of layer
# :math:`l+1` are computed as
#
# .. math::
#
#    H^{(l+1)} = \sigma\left(\tilde{A} H^{(l)} W^{(l)}\right)
#
# where :math:`\tilde{A}` is the normalized adjacency matrix, :math:`\sigma` is an activation
# function, and :math:`W^{(l)}` is a trainable parameter of the
# :math:`l`-th layer.
#
# In the node classification task, you minimize the following loss:
#
# .. math::
#
#    \frac{1}{|\mathcal{V}_\mathcal{L}|}\sum_{v \in \mathcal{V}_\mathcal{L}} f(y_v, z_v^{(L)})
#
# where :math:`\mathcal{V}_\mathcal{L}` is the set of labeled nodes,
# :math:`y_v` is the label of node :math:`v`, :math:`z_v^{(L)}` is its final
# output, and :math:`f` is a loss function such as cross entropy.
#
# Computing the loss requires each node to gather the
# features of its neighbors to compute its hidden feature in the next
# layer.
#
# In this tutorial, you run GCN on the Reddit dataset constructed by `Hamilton et
# al. <https://arxiv.org/abs/1706.02216>`__, wherein the nodes are posts and an
# edge is established between two posts if the same user commented on both. The
# task is to predict the category that a post belongs to. This graph has
# 233,000 nodes, 114.6 million edges, and 41 categories. First, load the Reddit graph.
#
import numpy as np
import dgl
import mxnet as mx
from mxnet import gluon
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import RedditDataset

# load the Reddit dataset and wrap it in a read-only graph
data = RedditDataset(self_loop=True)
features = mx.nd.array(data.features)

g = DGLGraph(data.graph, readonly=True)
g.ndata['features'] = features
################################################################################################
# Here you define the node UDF, which has a fully connected layer:

class NodeUpdate(gluon.Block):
    def __init__(self, in_feats, out_feats, activation=None, **kwargs):
        super(NodeUpdate, self).__init__(**kwargs)
        with self.name_scope():
            self.dense = gluon.nn.Dense(out_feats, in_units=in_feats)
        self.activation = activation

    def forward(self, node):
        h = node.data['h']
        h = self.dense(h)
        if self.activation:
            h = self.activation(h)
        return {'activation': h}

################################################################################################
# In DGL, you implement GCN on the full graph with ``update_all`` in ``DGLGraph``.
# The following code performs a two-layer GCN on the Reddit graph.
#
# number of GCN layers
L = 2
# number of hidden units
n_hidden = 64

layers = [NodeUpdate(features.shape[1], n_hidden, mx.nd.relu),
          NodeUpdate(n_hidden, n_hidden, mx.nd.relu)]
for layer in layers:
    layer.initialize()

h = g.ndata['features']
for i in range(L):
    g.ndata['h'] = h
    # aggregate neighbor features with sum, then apply the node UDF
    g.update_all(fn.copy_src(src='h', out='m'),
                 fn.sum(msg='m', out='h'),
                 layers[i])
    h = g.ndata.pop('activation')
# As the graph scales up to billions of nodes or edges, training on the
# full graph would no longer be efficient or even feasible.
#
# Mini-batch training allows you to control the computation and memory
# usage within some budget. The training loss for each iteration is
#
# .. math::
#
#    \frac{1}{|\tilde{\mathcal{V}}_\mathcal{L}|}\sum_{v \in \tilde{\mathcal{V}}_\mathcal{L}} f(y_v, z_v^{(L)})
#
# where :math:`\tilde{\mathcal{V}}_\mathcal{L}` is a subset of the labeled
# nodes sampled for the mini-batch.
#
# Stemming from the labeled nodes :math:`\tilde{\mathcal{V}}_\mathcal{L}`
# in a mini-batch and tracing back to the input forms a computational
# dependency graph (a directed acyclic graph [DAG]), which
# captures the computation flow of :math:`Z^{(L)}`.
#
# In the example below, a mini-batch to compute the hidden features of
# a set of seed nodes pulls in their neighbors layer by layer, all the way
# down to the input layer.
#
# |image0|
#
# For that purpose, you define ``NodeFlow`` to represent this computation
# flow.
#
# ``NodeFlow`` is a type of layered graph, where nodes are organized in
# :math:`L + 1` sequential *layers*, and edges only exist between adjacent
# layers, forming *blocks*. You construct a ``NodeFlow`` backward, starting
# from the last layer with all the nodes whose hidden features are
# requested. The set of nodes the next layer depends on forms the previous
# layer. An edge connects a node in the previous layer to another in the
# next layer if and only if the latter depends on the former. This process
# repeats until all :math:`L + 1` layers are constructed. The features of
# nodes in each layer, and those of edges in each block, are stored as
# separate tensors.
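The backward, layer-by-layer construction described above can be sketched in plain Python. This is an illustration of the idea only, not DGL's actual ``NodeFlow`` implementation; the adjacency representation and function name are assumptions.

```python
def build_layers(adj, seed_nodes, num_layers):
    """Construct NodeFlow-style layers backward.

    adj: dict mapping a node to the list of its in-neighbors
         (the nodes whose features it depends on).
    Returns layers[0] (input layer) ... layers[num_layers] (seed layer).
    """
    layers = [set(seed_nodes)]
    for _ in range(num_layers):
        prev = set()
        for v in layers[0]:
            prev.update(adj.get(v, []))   # everything v depends on
        # the dependency set becomes the previous (lower) layer
        layers.insert(0, prev)
    return [sorted(layer) for layer in layers]

# toy dependency structure: node 0 depends on {1, 2}; node 1 on {2, 3}
adj = {0: [1, 2], 1: [2, 3]}
print(build_layers(adj, seed_nodes=[0], num_layers=2))
# [[2, 3], [1, 2], [0]]
```

With sampling (next section), `prev` would hold only a random subset of each node's neighbors instead of all of them.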
# #
############################################################################## ##############################################################################
# Neighbor Sampling # Neighbor sampling
# ~~~~~~~~~~~~~~~~~ # ~~~~~~~~~~~~~~~~~
# #
# Real-world graphs often have nodes with large degree, meaning that a # Real-world graphs often have nodes with large degree, meaning that a
# moderately deep (e.g. 3 layers) GCN would often depend on input features # moderately deep (e.g., three layers) GCN would often depend on input features
# of the entire graph, even if the computation only depends on outputs of # of the entire graph, even if the computation only depends on outputs of
# a few nodes, hence its cost-ineffectiveness. # a few nodes, hence its cost-ineffectiveness.
# #
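To see why this matters, the following standalone sketch counts how many input nodes a single output node transitively depends on as the depth grows. The random graph here is synthetic and used only for demonstration; it is not tied to any DGL API.

```python
# Illustration of why a deep GCN is costly on a large graph: the set of
# input nodes one output depends on grows roughly by the average degree
# with every added layer.
import random

random.seed(0)
n, avg_deg = 10000, 10
# Each node depends on avg_deg random in-neighbors.
preds = {v: random.sample(range(n), avg_deg) for v in range(n)}

sizes = []
frontier = {0}              # we want the output of just one node
for num_layers in range(1, 4):
    frontier |= {u for v in frontier for u in preds[v]}
    sizes.append(len(frontier))
    print("%d layers -> depends on %d input nodes" % (num_layers, len(frontier)))
```

With an average degree of 10, three layers already pull in on the order of a thousand input nodes for a single output, which is the cost-ineffectiveness neighbor sampling addresses.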
...@@ -195,7 +195,7 @@ for i in range(L):
#
##############################################################################
# You then implement *neighbor sampling* with ``NodeFlow``:
#
class GCNSampling(gluon.Block):
...@@ -229,7 +229,7 @@ class GCNSampling(gluon.Block):
            nf.layers[i].data['h'] = h
            # block_compute() computes the feature of layer i given layer
            # i-1, with the given message, reduce, and apply functions.
            # Here, you essentially aggregate the neighbor node features in
            # the previous layer and update them with the `layer` function.
            nf.block_compute(i,
                             fn.copy_src(src='h', out='m'),
...@@ -244,8 +244,8 @@ class GCNSampling(gluon.Block):
# ``NeighborSampler``
# returns an iterator that generates a ``NodeFlow`` each time. This function
# has many options that let users customize the behavior of the neighbor
# sampler, such as the number of neighbors and the number of hops to sample.
# See `its API
# documentation <https://doc.dgl.ai/api/python/sampler.html>`__ for more
# details.
#
...@@ -296,12 +296,12 @@ for epoch in range(num_epochs):
        trainer.step(batch_size=1)
        print("Epoch[{}]: loss {}".format(epoch, loss.asscalar()))
        i += 1
        # Train the model with only 32 mini-batches, just for demonstration.
        if i >= 32:
            break
##############################################################################
# Control variate
# ~~~~~~~~~~~~~~~
#
# The unbiased estimator :math:`\hat{Z}^{(\cdot)}` used in *neighbor
...@@ -312,8 +312,8 @@ for epoch in range(num_epochs):
# standard variance reduction technique widely used in Monte Carlo
# methods, two neighbors per node seem sufficient.
#
# The *control variate* method works as follows: To estimate the
# expectation :math:`\mathbb{E} [X] = \theta` of a random variable
# :math:`X`, find another random variable :math:`Y` that is highly
# correlated with :math:`X` and whose expectation
# :math:`\mathbb{E} [Y]` can be easily computed. The *control
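As a standalone illustration of the idea, independent of DGL and GCN training, the following toy sketch estimates the expectation of `exp(U)` for uniform `U`, with and without a control variate. All names here are illustrative.

```python
# Toy demonstration of variance reduction with a control variate.
# X = exp(U) for U ~ Uniform(0, 1); the control variate is Y = U,
# whose expectation E[Y] = 0.5 is known exactly.
import math
import random
import statistics

random.seed(42)

def sample_means(n_trials, n, use_cv):
    """Return n_trials independent estimates of E[exp(U)], each from n draws."""
    means = []
    for _ in range(n_trials):
        u = [random.random() for _ in range(n)]
        x = [math.exp(ui) for ui in u]
        if use_cv:
            # Estimate the optimal coefficient c = Cov(X, Y) / Var(Y).
            mu, mx = sum(u) / n, sum(x) / n
            cov = sum((ui - mu) * (xi - mx) for ui, xi in zip(u, x))
            var = sum((ui - mu) ** 2 for ui in u)
            c = cov / var
            # X - c * (Y - E[Y]) keeps the same expectation, lower variance.
            vals = [xi - c * (ui - 0.5) for xi, ui in zip(x, u)]
        else:
            vals = x
        means.append(sum(vals) / n)
    return means

plain = sample_means(200, 100, use_cv=False)
cv = sample_means(200, 100, use_cv=True)
print("plain std: %.4f, control-variate std: %.4f"
      % (statistics.pstdev(plain), statistics.pstdev(cv)))
```

Because `U` and `exp(U)` are highly correlated, the control-variate estimator has a much smaller spread around the true value `e - 1`, which is exactly why sampling very few neighbors suffices when a good control variate (the historical activations) is available.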
...@@ -338,7 +338,7 @@ for epoch in range(num_epochs):
#    \hat{h}_v^{(l+1)} = \sigma ( \hat{z}_v^{(l+1)} W^{(l)} )
#
# This method can also be *conceptually* implemented in DGL, as shown
# here.
#
have_large_memory = False
...@@ -348,7 +348,7 @@ if have_large_memory:
    g.ndata['h_0'] = features
    for i in range(L):
        g.ndata['h_{}'.format(i+1)] = mx.nd.zeros((features.shape[0], n_hidden))
# With control-variate sampling, you only need to sample two neighbors to train the GCN.
for nf in dgl.contrib.sampling.NeighborSampler(g, batch_size, expand_factor=2,
                                               neighbor_type='in', num_hops=L,
                                               seed_nodes=train_nid):
...@@ -387,7 +387,7 @@ if have_large_memory:
# The following shows the performance of a graph convolution network and
# GraphSAGE with neighbor sampling and control variate sampling on the Reddit
# dataset. GraphSAGE with control variate sampling achieves over 96 percent
# test accuracy when sampling only one neighbor. |image1|
#
# More APIs
# ~~~~~~~~~
...
...@@ -8,11 +8,11 @@ Training on giant graphs
  <5_giant_graph/1_sampling_mx.html>`__ `[MXNet code]
  <https://github.com/dmlc/dgl/tree/master/examples/mxnet/sampling>`__ `[Pytorch code]
  <https://github.com/dmlc/dgl/tree/master/examples/pytorch/sampling>`__:
  You can perform neighbor sampling and control-variate sampling to train a
  graph convolution network and its variants on a giant graph.
* **Scale to giant graphs** `[tutorial] <5_giant_graph/2_giant.html>`__
  `[MXNet code] <https://github.com/dmlc/dgl/tree/master/examples/mxnet/sampling>`__
  `[Pytorch code]
  <https://github.com/dmlc/dgl/tree/master/examples/pytorch/sampling>`__:
  DGL provides two components (a graph store and a distributed sampler) to scale
  to graphs with hundreds of millions of nodes.