Unverified Commit 98ac391e authored by Quan (Andy) Gan, committed by GitHub

[Doc] Large graph training tutorials (#2595)

* large graph training tutorials

* add full graph link prediction update

* fix

* fix

* fix

* addressed comments
parent 0346b0aa
@@ -81,7 +81,30 @@ Getting Started
install/index
install/backend
.. toctree::
:maxdepth: 2
:caption: Basic Tutorials
:hidden:
:glob:
new-tutorial/1_introduction
new-tutorial/2_dglgraph
new-tutorial/3_message_passing
new-tutorial/4_link_predict
new-tutorial/5_graph_classification
new-tutorial/6_load_data
.. toctree::
:maxdepth: 2
:caption: Stochastic GNN Training Tutorials
:hidden:
:glob:
new-tutorial/L0_neighbor_sampling_overview
new-tutorial/L1_large_node_classification
new-tutorial/L2_large_link_prediction
new-tutorial/L4_message_passing
.. toctree::
:maxdepth: 2
...
@@ -2,18 +2,19 @@
Link Prediction using Graph Neural Networks
===========================================
In the :doc:`introduction <1_introduction>`, you have already learned
the basic workflow of using GNNs for node classification,
i.e. predicting the category of a node in a graph. This tutorial will
teach you how to train a GNN for link prediction, i.e. predicting the
existence of an edge between two arbitrary nodes in a graph.
By the end of this tutorial you will be able to
- Build a GNN-based link prediction model.
- Train and evaluate the model on a small DGL-provided dataset.
(Time estimate: 28 minutes)
"""
import dgl
@@ -28,19 +29,19 @@ import scipy.sparse as sp
######################################################################
# Overview of Link Prediction with GNN
# ------------------------------------
#
# Many applications such as social recommendation, item recommendation,
# knowledge graph completion, etc., can be formulated as link prediction,
# which predicts whether an edge exists between two particular nodes. This
# tutorial shows an example of predicting whether a citation relationship,
# either citing or being cited, between two papers exists in a citation
# network.
#
# This tutorial follows a relatively simple practice from
# `SEAL <https://papers.nips.cc/paper/2018/file/53f0d7c537d99b3824f0f99d62ea2428-Paper.pdf>`__.
# It formulates the link prediction problem as a binary classification
# problem as follows:
#
# - Treat the edges in the graph as *positive examples*.
# - Sample a number of non-existent edges (i.e. node pairs with no edges
# between them) as *negative* examples.
@@ -48,19 +49,19 @@ import scipy.sparse as sp
# set and a test set.
# - Evaluate the model with any binary classification metric such as Area
# Under Curve (AUC).
#
# In some domains such as large-scale recommender systems or information
# retrieval, you may favor metrics that emphasize good performance of
# top-K predictions. In these cases you may want to consider other metrics
# such as mean average precision, and use other negative sampling methods,
# which are beyond the scope of this tutorial.
#
# Loading graph and features
# --------------------------
#
# Following the :doc:`introduction <1_introduction>`, this tutorial
# first loads the Cora dataset.
#
import dgl.data
@@ -69,13 +70,13 @@ g = dataset[0]
######################################################################
# Prepare training and testing sets
# ---------------------------------
#
# This tutorial randomly picks 10% of the edges for positive examples in
# the test set, and leaves the rest for the training set. It then samples
# the same number of edges for negative examples in both sets.
#
# Split edge set for training and testing
u, v = g.edges()
@@ -96,16 +97,6 @@ neg_eids = np.random.choice(len(neg_u), g.number_of_edges() // 2)
test_neg_u, test_neg_v = neg_u[neg_eids[:test_size]], neg_v[neg_eids[:test_size]]
train_neg_u, train_neg_v = neg_u[neg_eids[test_size:]], neg_v[neg_eids[test_size:]]
# Create training set.
train_u = torch.cat([torch.as_tensor(train_pos_u), torch.as_tensor(train_neg_u)])
train_v = torch.cat([torch.as_tensor(train_pos_v), torch.as_tensor(train_neg_v)])
train_label = torch.cat([torch.zeros(len(train_pos_u)), torch.ones(len(train_neg_u))])
# Create testing set.
test_u = torch.cat([torch.as_tensor(test_pos_u), torch.as_tensor(test_neg_u)])
test_v = torch.cat([torch.as_tensor(test_pos_v), torch.as_tensor(test_neg_v)])
test_label = torch.cat([torch.zeros(len(test_pos_u)), torch.ones(len(test_neg_u))])
######################################################################
# When training, you will need to remove the edges in the test set from
@@ -113,24 +104,24 @@ test_label = torch.cat([torch.zeros(len(test_pos_u)), torch.ones(len(test_neg_u)
#
# .. note::
#
# ``dgl.remove_edges`` works by creating a subgraph from the
# original graph, resulting in a copy and therefore could be slow for
# large graphs. If so, you could save the training and test graph to
# disk, as you would do for preprocessing.
#
train_g = dgl.remove_edges(g, eids[:test_size])
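######################################################################
# One option for large graphs (a sketch; the file name below is
# illustrative) is to save the preprocessed graph once with
# ``dgl.save_graphs`` and reload it in later runs with ``dgl.load_graphs``:
#
# .. code:: python
#
#    dgl.save_graphs('train_graph.bin', [train_g])
#    train_g = dgl.load_graphs('train_graph.bin')[0][0]
#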
######################################################################
# Define a GraphSAGE model
# ------------------------
#
# This tutorial builds a model consisting of two
# `GraphSAGE <https://arxiv.org/abs/1706.02216>`__ layers, each of which computes
# new node representations by averaging neighbor information. DGL provides
# ``dgl.nn.SAGEConv`` that conveniently creates a GraphSAGE layer.
#
from dgl.nn import SAGEConv
@@ -147,46 +138,187 @@ class GraphSAGE(nn.Module):
h = F.relu(h)
h = self.conv2(g, h)
return h
model = GraphSAGE(train_g.ndata['feat'].shape[1], 16)
######################################################################
# The model then predicts the probability of existence of an edge by
# computing a score between the representations of both incident nodes
# with a function (e.g. an MLP or a dot product), which you will see in
# the next section.
#
# .. math::
#
#
# \hat{y}_{u\sim v} = f(h_u, h_v)
#
######################################################################
# Positive graph, negative graph, and ``apply_edges``
# ---------------------------------------------------
#
# In previous tutorials you have learned how to compute node
# representations with a GNN. However, link prediction requires you to
# compute representation of *pairs of nodes*.
#
# DGL recommends you treat the pairs of nodes as another graph, since
# you can describe a pair of nodes with an edge. In link prediction, you
# will have a *positive graph* consisting of all the positive examples as
# edges, and a *negative graph* consisting of all the negative examples.
# The *positive graph* and the *negative graph* will contain the same set
# of nodes as the original graph. This makes it easier to pass node
# features among multiple graphs for computation. As you will see later,
# you can directly feed the node representations computed on the entire
# graph to the positive and the negative graphs for computing pair-wise
# scores.
#
# The following code constructs the positive graph and the negative graph
# for the training set and the test set respectively.
#
train_pos_g = dgl.graph((train_pos_u, train_pos_v), num_nodes=g.number_of_nodes())
train_neg_g = dgl.graph((train_neg_u, train_neg_v), num_nodes=g.number_of_nodes())
test_pos_g = dgl.graph((test_pos_u, test_pos_v), num_nodes=g.number_of_nodes())
test_neg_g = dgl.graph((test_neg_u, test_neg_v), num_nodes=g.number_of_nodes())
######################################################################
# The benefit of treating the pairs of nodes as a graph is that you can
# use the ``DGLGraph.apply_edges`` method, which conveniently computes new
# edge features based on the incident nodes’ features and the original
# edge features (if applicable).
#
# DGL provides a set of optimized builtin functions to compute new
# edge features based on the original node/edge features. For example,
# ``dgl.function.u_dot_v`` computes a dot product of the incident nodes’
# representations for each edge.
#
import dgl.function as fn
class DotPredictor(nn.Module):
def forward(self, g, h):
with g.local_scope():
g.ndata['h'] = h
# Compute a new edge feature named 'score' by a dot-product between the
# source node feature 'h' and destination node feature 'h'.
g.apply_edges(fn.u_dot_v('h', 'h', 'score'))
# u_dot_v returns a 1-element vector for each edge so you need to squeeze it.
return g.edata['score'][:, 0]
######################################################################
# You can also write your own function if it is complex.
# For instance, the following module produces a scalar score on each edge
# by concatenating the incident nodes’ features and passing it to an MLP.
#
class MLPPredictor(nn.Module):
def __init__(self, h_feats):
super().__init__()
self.W1 = nn.Linear(h_feats * 2, h_feats)
self.W2 = nn.Linear(h_feats, 1)
def apply_edges(self, edges):
"""
Computes a scalar score for each edge of the given graph.
Parameters
----------
edges :
Has three members ``src``, ``dst`` and ``data``, each of
which is a dictionary representing the features of the
source nodes, the destination nodes, and the edges
themselves.
Returns
-------
dict
A dictionary of new edge features.
"""
h = torch.cat([edges.src['h'], edges.dst['h']], 1)
return {'score': self.W2(F.relu(self.W1(h))).squeeze(1)}
def forward(self, g, h):
with g.local_scope():
g.ndata['h'] = h
g.apply_edges(self.apply_edges)
return g.edata['score']
######################################################################
# .. note::
#
# The builtin functions are optimized for both speed and memory.
# We recommend using builtin functions whenever possible.
#
# .. note::
#
# If you have read the :doc:`message passing
# tutorial <3_message_passing>`, you will notice that the
# argument that ``apply_edges`` takes has exactly the same form as a message
# function in ``update_all``.
#
######################################################################
# Training loop
# -------------
#
# After you have defined the node representation computation and the edge score
# computation, you can go ahead and define the overall model, loss
# function, and evaluation metric.
#
# The loss function is simply binary cross entropy loss.
#
# .. math::
#
#
# \mathcal{L} = -\sum_{u\sim v\in \mathcal{D}}\left( y_{u\sim v}\log(\hat{y}_{u\sim v}) + (1-y_{u\sim v})\log(1-\hat{y}_{u\sim v}) \right)
#
# The evaluation metric in this tutorial is AUC.
#
model = GraphSAGE(train_g.ndata['feat'].shape[1], 16)
# You can replace DotPredictor with MLPPredictor.
#pred = MLPPredictor(16)
pred = DotPredictor()
def compute_loss(pos_score, neg_score):
scores = torch.cat([pos_score, neg_score])
labels = torch.cat([torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])])
return F.binary_cross_entropy_with_logits(scores, labels)
def compute_auc(pos_score, neg_score):
scores = torch.cat([pos_score, neg_score]).numpy()
labels = torch.cat(
[torch.ones(pos_score.shape[0]), torch.zeros(neg_score.shape[0])]).numpy()
return roc_auc_score(labels, scores)
######################################################################
# The training loop goes as follows:
#
# .. note::
#
# This tutorial does not include evaluation on a validation
# set. In practice you should save and evaluate the best model based on
# performance on the validation set.
#
# ----------- 3. set up loss and optimizer -------------- #
# in this case, the loss is computed inside the training loop
optimizer = torch.optim.Adam(itertools.chain(model.parameters(), pred.parameters()), lr=0.01)
# ----------- 4. training -------------------------------- #
all_logits = []
for e in range(100):
# forward
h = model(train_g, train_g.ndata['feat'])
pos_score = pred(train_pos_g, h)
neg_score = pred(train_neg_g, h)
loss = compute_loss(pos_score, neg_score)
loss = F.binary_cross_entropy(pred, train_label)
# backward
optimizer.zero_grad()
@@ -199,9 +331,7 @@ for e in range(100):
# ----------- 5. check results ------------------------ #
from sklearn.metrics import roc_auc_score
with torch.no_grad():
pos_score = pred(test_pos_g, h)
neg_score = pred(test_neg_g, h)
print('AUC', compute_auc(pos_score, neg_score))
print('AUC', roc_auc_score(label, pred))
"""
Introduction of Neighbor Sampling for GNN Training
==================================================
In :doc:`previous tutorials <1_introduction>` you have learned how to
train GNNs by computing the representations of all nodes on a graph.
However, sometimes your graph is too large to fit the computation of all
nodes in a single GPU.
By the end of this tutorial, you will be able to
- Understand the pipeline of stochastic GNN training.
- Understand what neighbor sampling is and why it yields a bipartite
graph for each GNN layer.
"""
######################################################################
# Message Passing Review
# ----------------------
#
# Recall that in `Gilmer et al. <https://arxiv.org/abs/1704.01212>`__
# (also in :doc:`message passing tutorial <3_message_passing>`), the
# message passing formulation is as follows:
#
# .. math::
#
#
# m_{u\to v}^{(l)} = M^{(l)}\left(h_v^{(l-1)}, h_u^{(l-1)}, e_{u\to v}^{(l-1)}\right)
#
# .. math::
#
#
# m_{v}^{(l)} = \sum_{u\in\mathcal{N}(v)}m_{u\to v}^{(l)}
#
# .. math::
#
#
# h_v^{(l)} = U^{(l)}\left(h_v^{(l-1)}, m_v^{(l)}\right)
#
# where DGL calls :math:`M^{(l)}` the *message function*, :math:`\sum` the
# *reduce function* and :math:`U^{(l)}` the *update function*. Note that
# :math:`\sum` here can represent any function and is not necessarily a
# summation.
#
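# For reference, here is a minimal sketch of this formulation expressed with
# DGL's builtin functions (the toy graph, feature sizes and the linear
# update function are illustrative assumptions):
#
# .. code:: python
#
#    import torch
#    import dgl
#    import dgl.function as fn
#
#    g = dgl.rand_graph(10, 30)                  # toy graph
#    g.ndata['h'] = torch.randn(10, 5)
#    update = torch.nn.Linear(10, 5)             # plays the role of U
#
#    # M: copy the source feature; sum: aggregate the incoming messages
#    g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'm_v'))
#    g.ndata['h'] = update(torch.cat([g.ndata['h'], g.ndata['m_v']], dim=1))
#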
# Essentially, the :math:`l`-th layer representation of a single node
# depends on the :math:`(l-1)`-th layer representation of the same node,
# as well as the :math:`(l-1)`-th layer representation of the neighboring
# nodes. Those :math:`(l-1)`-th layer representations then depend on the
# :math:`(l-2)`-th layer representation of those nodes, as well as their
# neighbors.
#
# The following animation shows how a 2-layer GNN is supposed to compute
# the output of node 5:
#
# |image1|
#
# You can see that to compute the second-layer output of node 5, you will
# need its direct neighbors’ first-layer representations (colored in yellow),
# which in turn need their direct neighbors’ (i.e. node 5’s second-hop
# neighbors’) representations (colored in green).
#
# .. |image1| image:: https://data.dgl.ai/tutorial/img/sampling.gif
#
######################################################################
# Neighbor Sampling Overview
# --------------------------
#
# You can also see from the previous example that computing the representations
# for a small number of nodes often requires input features of a
# significantly larger number of nodes. Taking all neighbors for message
# aggregation is often too costly since the nodes needed for input
# features would easily cover a large portion of the graph, especially for
# real-world graphs which are often
# `scale-free <https://en.wikipedia.org/wiki/Scale-free_network>`__.
#
# Neighbor sampling addresses this issue by selecting a subset of the
# neighbors to perform aggregation. For instance, to compute
# :math:`\boldsymbol{h}_8^{(2)}`, you can choose two of the neighbors
# instead of all of them to aggregate, as in the following animation:
#
# |image2|
#
# You can see that this method greatly reduces the number of nodes needed
# for message passing in a single minibatch.
#
# .. |image2| image:: https://data.dgl.ai/tutorial/img/bipartite.gif
#
######################################################################
# You can also notice that the computation dependencies in the animation
# above can be described as a series of *bipartite graphs*.
# The output nodes are on one side and all the nodes necessary for inputs
# are on the other side. The arrows indicate how the sampled neighbors
# propagate messages to the nodes.
#
# Note that some GNN modules, such as ``SAGEConv``, need to use the output
# nodes' features on the previous layer to compute the outputs. Without
# loss of generality, DGL always includes the output nodes themselves
# in the input nodes.
#
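# For the curious, the sampled dependency of a single layer can be
# materialized with DGL's sampling utilities (the toy graph, seed node and
# fanout below are illustrative assumptions):
#
# .. code:: python
#
#    import torch
#    import dgl
#
#    g = dgl.rand_graph(100, 1000)                  # toy graph
#    seed_nodes = torch.tensor([8])                 # output node(s)
#    frontier = dgl.sampling.sample_neighbors(g, seed_nodes, 2)  # pick 2 neighbors
#    block = dgl.to_block(frontier, seed_nodes)     # bipartite graph for this layer
#    print(block.num_src_nodes(), block.num_dst_nodes())
#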
######################################################################
# What’s next?
# ------------
#
# :doc:`Stochastic GNN Training for Node Classification in
# DGL <L1_large_node_classification>`
#
"""
Training GNN with Neighbor Sampling for Node Classification
===========================================================
This tutorial shows how to train a multi-layer GraphSAGE for node
classification on the Amazon Co-purchase Network provided by `Open Graph
Benchmark (OGB) <https://ogb.stanford.edu/>`__. The dataset contains 2.4
million nodes and 61 million edges.
By the end of this tutorial, you will be able to
- Train a GNN model for node classification on a single GPU with DGL's
neighbor sampling components.
This tutorial assumes that you have read the :doc:`Introduction of Neighbor
Sampling for GNN Training <L0_neighbor_sampling_overview>`.
"""
######################################################################
# Loading Dataset
# ---------------
#
# OGB has already prepared the data as a DGL graph.
#
import dgl
import torch
import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset
dataset = DglNodePropPredDataset('ogbn-products')
######################################################################
# An OGB dataset is a collection of graphs and their labels. The Amazon
# Co-purchase Network dataset only contains a single graph. So you can
# simply get the graph and its node labels like this:
#
graph, node_labels = dataset[0]
graph.ndata['label'] = node_labels[:, 0]
print(graph)
print(node_labels)
node_features = graph.ndata['feat']
num_features = node_features.shape[1]
num_classes = (node_labels.max() + 1).item()
print('Number of classes:', num_classes)
######################################################################
# You can get the training-validation-test split of the nodes with the
# ``get_idx_split`` method.
#
idx_split = dataset.get_idx_split()
train_nids = idx_split['train']
valid_nids = idx_split['valid']
test_nids = idx_split['test']
######################################################################
# How DGL Handles Computation Dependency
# --------------------------------------
#
# In the :doc:`previous tutorial <L0_neighbor_sampling_overview>`, you
# have seen that the computation dependency for message passing of a
# single node can be described as a series of bipartite graphs.
#
# |image1|
#
# .. |image1| image:: https://data.dgl.ai/tutorial/img/bipartite.gif
#
######################################################################
# Defining Neighbor Sampler and Data Loader in DGL
# ------------------------------------------------
#
# DGL provides tools to iterate over the dataset in minibatches
# while generating the computation dependencies to compute their outputs
# with the bipartite graphs above. For node classification, you can use
# ``dgl.dataloading.NodeDataLoader`` for iterating over the dataset.
# It accepts a sampler object to control how to generate the computation
# dependencies in the form of bipartite graphs. DGL provides
# implementations of common sampling algorithms such as
# ``dgl.dataloading.MultiLayerNeighborSampler`` which randomly picks
# a fixed number of neighbors for each node.
#
# .. note::
#
# To write your own neighbor sampler, please refer to :ref:`this user
# guide section <guide-minibatch-customizing-neighborhood-sampler>`.
#
# The syntax of ``dgl.dataloading.NodeDataLoader`` is mostly similar to a
# PyTorch ``DataLoader``, with the addition that it needs a graph to
# generate computation dependency from, a set of node IDs to iterate on,
# and the neighbor sampler you defined.
#
# Let’s say that each node will gather messages from 4 neighbors on each
# layer. The code defining the data loader and neighbor sampler will look
# like the following.
#
sampler = dgl.dataloading.MultiLayerNeighborSampler([4, 4])
train_dataloader = dgl.dataloading.NodeDataLoader(
# The following arguments are specific to NodeDataLoader.
graph, # The graph
train_nids, # The node IDs to iterate over in minibatches
sampler, # The neighbor sampler
device='cuda', # Put the sampled bipartite graphs to GPU
# The following arguments are inherited from PyTorch DataLoader.
batch_size=1024, # Batch size
shuffle=True, # Whether to shuffle the nodes for every epoch
drop_last=False, # Whether to drop the last incomplete batch
num_workers=0 # Number of sampler processes
)
######################################################################
# You can iterate over the data loader and see what it yields.
#
input_nodes, output_nodes, bipartites = example_minibatch = next(iter(train_dataloader))
print(example_minibatch)
print("To compute {} nodes' outputs, we need {} nodes' input features".format(len(output_nodes), len(input_nodes)))
######################################################################
# ``NodeDataLoader`` gives us three items per iteration.
#
# - An ID tensor for the input nodes, i.e., nodes whose input features
# are needed on the first GNN layer for this minibatch.
# - An ID tensor for the output nodes, i.e. nodes whose representations
# are to be computed.
# - A list of bipartite graphs storing the computation dependencies
# for each GNN layer.
#
######################################################################
# You can get the input and output node IDs of the bipartite graphs
# and verify that the first few input nodes are always the same as the output
# nodes. As we described in the :doc:`overview <L0_neighbor_sampling_overview>`,
# output nodes' own features from the previous layer may also be necessary in
# the computation of the new features.
#
bipartite_0_src = bipartites[0].srcdata[dgl.NID]
bipartite_0_dst = bipartites[0].dstdata[dgl.NID]
print(bipartite_0_src)
print(bipartite_0_dst)
print(torch.equal(bipartite_0_src[:bipartites[0].num_dst_nodes()], bipartite_0_dst))
######################################################################
# Defining Model
# --------------
#
# Let’s consider training a 2-layer GraphSAGE with neighbor sampling. The
# model can be written as follows:
#
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import SAGEConv
class Model(nn.Module):
def __init__(self, in_feats, h_feats, num_classes):
super(Model, self).__init__()
self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type='mean')
self.conv2 = SAGEConv(h_feats, num_classes, aggregator_type='mean')
self.h_feats = h_feats
def forward(self, bipartites, x):
# Lines that are changed are marked with an arrow: "<---"
h_dst = x[:bipartites[0].num_dst_nodes()] # <---
h = self.conv1(bipartites[0], (x, h_dst)) # <---
h = F.relu(h)
h_dst = h[:bipartites[1].num_dst_nodes()] # <---
h = self.conv2(bipartites[1], (h, h_dst)) # <---
return h
model = Model(num_features, 128, num_classes).cuda()
######################################################################
# If you compare against the code in the
# :doc:`introduction <1_introduction>`, you will notice several
# differences:
#
# - **DGL GNN layers on bipartite graphs**. Instead of computing on the
# full graph:
#
# .. code:: python
#
# h = self.conv1(g, x)
#
# you only compute on the sampled bipartite graph:
#
# .. code:: python
#
# h = self.conv1(bipartites[0], (x, h_dst))
#
# All DGL’s GNN modules support message passing on bipartite graphs,
# where you supply a pair of features, one for input nodes and another
# for output nodes.
#
# - **Feature slicing for self-dependency**. There are statements that
# perform slicing to obtain the previous-layer representation of the
# output nodes:
#
# .. code:: python
#
# h_dst = x[:bipartites[0].num_dst_nodes()]
#
# The ``num_dst_nodes`` method works with bipartite graphs; it returns
# the number of output nodes.
#
# Since the first few input nodes of the yielded bipartite graph are
# always the same as the output nodes, these statements obtain the
# representations of the output nodes on the previous layer. They are
# then combined with neighbor aggregation in ``dgl.nn.SAGEConv`` layer.
#
# .. note::
#
# See the :doc:`custom message passing
# tutorial <L4_message_passing>` for more details on how to
# manipulate bipartite graphs produced in this way, such as the usage
# of ``num_dst_nodes``.
#
######################################################################
# Defining Training Loop
# ----------------------
#
# The following initializes the model and defines the optimizer.
#
opt = torch.optim.Adam(model.parameters())
######################################################################
# When computing the validation score for model selection, you can usually
# also use neighbor sampling. To do that, you need to define another data
# loader.
#
valid_dataloader = dgl.dataloading.NodeDataLoader(
graph, valid_nids, sampler,
batch_size=1024,
shuffle=False,
drop_last=False,
num_workers=0
)
######################################################################
# The following is a training loop that performs validation every epoch.
# It also saves the model with the best validation accuracy into a file.
#
import tqdm
import sklearn.metrics
best_accuracy = 0
best_model_path = 'model.pt'
for epoch in range(10):
model.train()
with tqdm.tqdm(train_dataloader) as tq:
for step, (input_nodes, output_nodes, bipartites) in enumerate(tq):
# feature copy from CPU to GPU takes place here
inputs = bipartites[0].srcdata['feat']
labels = bipartites[-1].dstdata['label']
predictions = model(bipartites, inputs)
loss = F.cross_entropy(predictions, labels)
opt.zero_grad()
loss.backward()
opt.step()
accuracy = sklearn.metrics.accuracy_score(labels.cpu().numpy(), predictions.argmax(1).detach().cpu().numpy())
tq.set_postfix({'loss': '%.03f' % loss.item(), 'acc': '%.03f' % accuracy}, refresh=False)
model.eval()
predictions = []
labels = []
with tqdm.tqdm(valid_dataloader) as tq, torch.no_grad():
for input_nodes, output_nodes, bipartites in tq:
bipartites = [b.to(torch.device('cuda')) for b in bipartites]
inputs = node_features[input_nodes].cuda()
labels.append(node_labels[output_nodes].numpy())
predictions.append(model(bipartites, inputs).argmax(1).cpu().numpy())
predictions = np.concatenate(predictions)
labels = np.concatenate(labels)
accuracy = sklearn.metrics.accuracy_score(labels, predictions)
print('Epoch {} Validation Accuracy {}'.format(epoch, accuracy))
if best_accuracy < accuracy:
best_accuracy = accuracy
torch.save(model.state_dict(), best_model_path)
# Note that this tutorial does not train the whole model to the end.
break
######################################################################
# Conclusion
# ----------
#
# In this tutorial, you have learned how to train a multi-layer GraphSAGE
# with neighbor sampling.
#
# What’s next?
# ------------
#
# - :doc:`Stochastic training of GNN for link
# prediction <L2_large_link_prediction>`.
# - :doc:`Adapting your custom GNN module for stochastic
# training <L4_message_passing>`.
# - During inference you may wish to disable neighbor sampling. If so,
# please refer to the :ref:`user guide on exact offline
# inference <guide-minibatch-inference>`. A rough sketch of that idea follows below.
#
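# For reference, a rough sketch of that exact (layer-wise) offline inference
# idea, assuming the 2-layer ``Model`` defined in this tutorial (the helper
# name and batch size are illustrative, not a DGL API):
#
# .. code:: python
#
#    def exact_inference(model, graph, features, batch_size=1024):
#        # One layer at a time, over all nodes, using all neighbors.
#        sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
#        dataloader = dgl.dataloading.NodeDataLoader(
#            graph, torch.arange(graph.number_of_nodes()), sampler,
#            batch_size=batch_size, shuffle=False, drop_last=False,
#            num_workers=0)
#        h = features
#        for layer, activation in [(model.conv1, F.relu), (model.conv2, None)]:
#            outputs = []
#            for input_nodes, output_nodes, bipartites in dataloader:
#                bipartite = bipartites[0].to(torch.device('cuda'))
#                x = h[input_nodes].cuda()
#                y = layer(bipartite, (x, x[:bipartite.num_dst_nodes()]))
#                outputs.append((activation(y) if activation else y).cpu())
#            h = torch.cat(outputs)
#        return h
#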
"""
Stochastic Training of GNN for Link Prediction
==============================================
This tutorial will show how to train a multi-layer GraphSAGE for link
prediction on the Amazon Co-purchase Network provided by `Open Graph Benchmark
(OGB) <https://ogb.stanford.edu/>`__. The dataset
contains 2.4 million nodes and 61 million edges.
By the end of this tutorial, you will be able to
- Train a GNN model for link prediction on a single GPU with DGL's
neighbor sampling components.
This tutorial assumes that you have read the :doc:`Introduction of Neighbor
Sampling for GNN Training <L0_neighbor_sampling_overview>` and :doc:`Neighbor
Sampling for Node Classification <L1_large_node_classification>`.
"""
######################################################################
# Link Prediction Overview
# ------------------------
#
# Link prediction requires the model to predict the probability of
# existence of an edge. This tutorial does so by computing a dot product
# between the representations of both incident nodes.
#
# .. math::
#
#
# \hat{y}_{u\sim v} = \sigma(h_u^T h_v)
#
# It then minimizes the following binary cross entropy loss.
#
# .. math::
#
#
# \mathcal{L} = -\sum_{u\sim v\in \mathcal{D}}\left( y_{u\sim v}\log(\hat{y}_{u\sim v}) + (1-y_{u\sim v})\log(1-\hat{y}_{u\sim v}) \right)
#
# This is identical to the link prediction formulation in :doc:`the previous
# tutorial on link prediction <4_link_predict>`.
#
######################################################################
# Loading Dataset
# ---------------
#
# This tutorial loads the dataset from the ``ogb`` package as in the
# :doc:`previous tutorial <L1_large_node_classification>`.
#
import dgl
import torch
import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset
dataset = DglNodePropPredDataset('ogbn-products')
graph, node_labels = dataset[0]
print(graph)
print(node_labels)
node_features = graph.ndata['feat']
node_labels = node_labels[:, 0]
num_features = node_features.shape[1]
num_classes = (node_labels.max() + 1).item()
print('Number of classes:', num_classes)
idx_split = dataset.get_idx_split()
train_nids = idx_split['train']
valid_nids = idx_split['valid']
test_nids = idx_split['test']
######################################################################
# Defining Neighbor Sampler and Data Loader in DGL
# ------------------------------------------------
#
# Different from the :doc:`link prediction tutorial for the full
# graph <4_link_predict>`, a common practice for training GNNs on large
# graphs is to iterate over the edges
# in minibatches, since computing the probabilities of all edges is usually
# infeasible. For each minibatch of edges, you compute the output
# representations of their incident nodes using neighbor sampling and a GNN,
# in a fashion similar to that introduced in the :doc:`large-scale node
# classification tutorial <L1_large_node_classification>`.
#
# DGL provides ``dgl.dataloading.EdgeDataLoader`` to
# iterate over edges for edge classification or link prediction tasks.
#
# To perform link prediction, you need to specify a negative sampler. DGL
# provides builtin negative samplers such as
# ``dgl.dataloading.negative_sampler.Uniform``. Here this tutorial uniformly
# draws 5 negative examples per positive example.
#
negative_sampler = dgl.dataloading.negative_sampler.Uniform(5)
######################################################################
# After defining the negative sampler, one can then define the edge data
# loader with neighbor sampling. To create an ``EdgeDataLoader`` for
# link prediction, provide a neighbor sampler object as well as the negative
# sampler object created above.
#
sampler = dgl.dataloading.MultiLayerNeighborSampler([4, 4])
train_dataloader = dgl.dataloading.EdgeDataLoader(
# The following arguments are specific to EdgeDataLoader.
graph, # The graph
torch.arange(graph.number_of_edges()), # The edges to iterate over
sampler, # The neighbor sampler
negative_sampler=negative_sampler, # The negative sampler
device='cuda', # Put the bipartite graphs on GPU
# The following arguments are inherited from PyTorch DataLoader.
batch_size=1024, # Batch size
shuffle=True, # Whether to shuffle the nodes for every epoch
drop_last=False, # Whether to drop the last incomplete batch
num_workers=0 # Number of sampler processes
)
######################################################################
# You can peek one minibatch from ``train_dataloader`` and see what it
# will give you.
#
input_nodes, pos_graph, neg_graph, bipartites = next(iter(train_dataloader))
print('Number of input nodes:', len(input_nodes))
print('Positive graph # nodes:', pos_graph.number_of_nodes(), '# edges:', pos_graph.number_of_edges())
print('Negative graph # nodes:', neg_graph.number_of_nodes(), '# edges:', neg_graph.number_of_edges())
print(bipartites)
######################################################################
# The example minibatch consists of four elements.
#
# The first element is an ID tensor for the input nodes, i.e., nodes
# whose input features are needed on the first GNN layer for this minibatch.
#
# The second element and the third element are the positive graph and the
# negative graph for this minibatch.
# The concept of positive and negative graphs has been introduced in the
# :doc:`full-graph link prediction tutorial <4_link_predict>`. In minibatch
# training, the positive graph and the negative graph only contain nodes
# necessary for computing the pair-wise scores of positive and negative examples
# in the current minibatch.
#
# The last element is a list of bipartite graphs storing the computation
# dependencies for each GNN layer.
# The bipartite graphs are used to compute the GNN outputs of the nodes
# involved in the positive and negative graphs.
#
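# If you need to relate the compacted nodes back to the original graph, the
# positive and negative graphs carry the original node IDs under
# ``ndata[dgl.NID]`` (this relies on the compaction step storing that field,
# which is an assumption about the data loader internals):
#
# .. code:: python
#
#    print(pos_graph.ndata[dgl.NID])
#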
######################################################################
# Defining Model for Node Representation
# --------------------------------------
#
# The model is almost identical to the one in the :doc:`node classification
# tutorial <L1_large_node_classification>`. The only difference is
# that since you are doing link prediction, the output dimension will not
# be the number of classes in the dataset.
#
import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import SAGEConv
class Model(nn.Module):
def __init__(self, in_feats, h_feats):
super(Model, self).__init__()
self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type='mean')
self.conv2 = SAGEConv(h_feats, h_feats, aggregator_type='mean')
self.h_feats = h_feats
def forward(self, bipartites, x):
h_dst = x[:bipartites[0].num_dst_nodes()]
h = self.conv1(bipartites[0], (x, h_dst))
h = F.relu(h)
h_dst = h[:bipartites[1].num_dst_nodes()]
h = self.conv2(bipartites[1], (h, h_dst))
return h
model = Model(num_features, 128).cuda()
######################################################################
# Defining the Score Predictor for Edges
# --------------------------------------
#
# After getting the node representations necessary for the minibatch, the
# last thing to do is to predict the scores of the edges and non-existent
# edges in the sampled minibatch.
#
# The following score predictor, copied from the :doc:`link prediction
# tutorial <4_link_predict>`, takes a dot product between the
# incident nodes’ representations.
#
import dgl.function as fn
class DotPredictor(nn.Module):
def forward(self, g, h):
with g.local_scope():
g.ndata['h'] = h
# Compute a new edge feature named 'score' by a dot-product between the
# source node feature 'h' and destination node feature 'h'.
g.apply_edges(fn.u_dot_v('h', 'h', 'score'))
# u_dot_v returns a 1-element vector for each edge so you need to squeeze it.
return g.edata['score'][:, 0]
######################################################################
# Evaluating Performance (Optional)
# ---------------------------------
#
# There are various ways to evaluate the performance of link prediction.
# This tutorial follows the practice of the `GraphSAGE
# paper <https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf>`__,
# which evaluates the node embeddings learned by link prediction by
# training and evaluating a linear classifier on top of the learned node
# embeddings.
#
######################################################################
# To obtain the representations of all the nodes, this tutorial uses
# neighbor sampling as introduced in the :doc:`node classification
# tutorial <L1_large_node_classification>`.
#
# .. note::
#
# If you would like to obtain node representations without
# neighbor sampling during inference, please refer to this :ref:`user
# guide <guide-minibatch-inference>`.
#
def inference(model, graph, node_features):
with torch.no_grad():
nodes = torch.arange(graph.number_of_nodes())
sampler = dgl.dataloading.MultiLayerNeighborSampler([4, 4])
train_dataloader = dgl.dataloading.NodeDataLoader(
graph, torch.arange(graph.number_of_nodes()), sampler,
batch_size=1024,
shuffle=False,
drop_last=False,
num_workers=4,
device='cuda')
result = []
for input_nodes, output_nodes, bipartites in train_dataloader:
# feature copy from CPU to GPU takes place here
inputs = bipartites[0].srcdata['feat']
result.append(model(bipartites, inputs))
return torch.cat(result)
import sklearn.metrics
def evaluate(emb, label, train_nids, valid_nids, test_nids):
classifier = nn.Linear(emb.shape[1], label.max().item() + 1).cuda()  # number of classes = max label + 1
opt = torch.optim.LBFGS(classifier.parameters())
def compute_loss():
pred = classifier(emb[train_nids].cuda())
loss = F.cross_entropy(pred, label[train_nids].cuda())
return loss
def closure():
loss = compute_loss()
opt.zero_grad()
loss.backward()
return loss
prev_loss = float('inf')
for i in range(1000):
opt.step(closure)
with torch.no_grad():
loss = compute_loss().item()
if np.abs(loss - prev_loss) < 1e-4:
print('Converges at iteration', i)
break
else:
prev_loss = loss
with torch.no_grad():
pred = classifier(emb.cuda()).cpu()
label = label
valid_acc = sklearn.metrics.accuracy_score(label[valid_nids].numpy(), pred[valid_nids].numpy().argmax(1))
test_acc = sklearn.metrics.accuracy_score(label[test_nids].numpy(), pred[test_nids].numpy().argmax(1))
return valid_acc, test_acc
######################################################################
# Defining Training Loop
# ----------------------
#
# The following initializes the model and defines the optimizer.
#
model = Model(node_features.shape[1], 128).cuda()
predictor = DotPredictor().cuda()
opt = torch.optim.Adam(list(model.parameters()) + list(predictor.parameters()))
######################################################################
# The following is the training loop for link prediction and
# evaluation, and also saves the model that performs the best on the
# validation set:
#
import tqdm
import sklearn.metrics
best_accuracy = 0
best_model_path = 'model.pt'
for epoch in range(1):
with tqdm.tqdm(train_dataloader) as tq:
for step, (input_nodes, pos_graph, neg_graph, bipartites) in enumerate(tq):
# feature copy from CPU to GPU takes place here
inputs = bipartites[0].srcdata['feat']
outputs = model(bipartites, inputs)
pos_score = predictor(pos_graph, outputs)
neg_score = predictor(neg_graph, outputs)
score = torch.cat([pos_score, neg_score])
label = torch.cat([torch.ones_like(pos_score), torch.zeros_like(neg_score)])
loss = F.binary_cross_entropy_with_logits(score, label)
opt.zero_grad()
loss.backward()
opt.step()
tq.set_postfix({'loss': '%.03f' % loss.item()}, refresh=False)
if step % 1000 == 999:
model.eval()
emb = inference(model, graph, node_features)
valid_acc, test_acc = evaluate(emb, node_labels, train_nids, valid_nids, test_nids)
print('Epoch {} Validation Accuracy {} Test Accuracy {}'.format(epoch, valid_acc, test_acc))
if best_accuracy < valid_acc:
best_accuracy = valid_acc
torch.save(model.state_dict(), best_model_path)
model.train()
# Note that this tutorial does not train the whole model to the end.
break
######################################################################
# Conclusion
# ----------
#
# In this tutorial, you have learned how to train a multi-layer GraphSAGE
# for link prediction with neighbor sampling.
#
"""
Writing GNN Modules for Stochastic GNN Training
===============================================
All GNN modules DGL provides support stochastic GNN training. This
tutorial teaches you how to write your own graph neural network module
for stochastic GNN training. It assumes that
1. You know :doc:`how to write GNN modules for full graph
training <3_message_passing>`.
2. You know :doc:`how stochastic GNN training pipeline
works <L1_large_node_classification>`.
"""
import dgl
import torch
import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset
dataset = DglNodePropPredDataset('ogbn-products')
graph, node_labels = dataset[0]
idx_split = dataset.get_idx_split()
train_nids = idx_split['train']
node_features = graph.ndata['feat']
sampler = dgl.dataloading.MultiLayerNeighborSampler([4, 4])
train_dataloader = dgl.dataloading.NodeDataLoader(
graph, train_nids, sampler,
batch_size=1024,
shuffle=True,
drop_last=False,
num_workers=0
)
input_nodes, output_nodes, bipartites = next(iter(train_dataloader))
######################################################################
# DGL Bipartite Graph Introduction
# --------------------------------
#
# In the previous tutorials, you have seen the concept of a *bipartite
# graph*, where nodes are divided into two parts.
# This section introduces how you can manipulate (directed) bipartite
# graphs.
#
# You can access the input node features and output node features via
# the ``srcdata`` and ``dstdata`` attributes:
#
bipartite = bipartites[0]
print(bipartite.srcdata)
print(bipartite.dstdata)
######################################################################
# It also has ``num_src_nodes`` and ``num_dst_nodes`` methods to query
# how many input nodes and output nodes exist in the bipartite graph:
#
print(bipartite.num_src_nodes(), bipartite.num_dst_nodes())
######################################################################
# You can assign features to ``srcdata`` and ``dstdata`` just as you
# would with ``ndata`` on the graphs you have seen earlier:
#
bipartite.srcdata['x'] = torch.zeros(bipartite.num_src_nodes(), bipartite.num_dst_nodes())
dst_feat = bipartite.dstdata['feat']
######################################################################
# Also, since the bipartite graphs are constructed by DGL, you can
# retrieve the input node IDs (i.e. those that are required to compute the
# output) and output node IDs (i.e. those whose representations the
# current GNN layer should compute) as follows.
#
bipartite.srcdata[dgl.NID], bipartite.dstdata[dgl.NID]
######################################################################
# Writing GNN Modules for Bipartite Graphs for Stochastic Training
# ----------------------------------------------------------------
#
######################################################################
# Recall that the bipartite graphs yielded by the ``NodeDataLoader`` and
# ``EdgeDataLoader`` have the property that the first few input nodes are
# always identical to the output nodes:
#
# |image1|
#
# .. |image1| image:: https://data.dgl.ai/tutorial/img/bipartite.gif
#
print(torch.equal(bipartite.srcdata[dgl.NID][:bipartite.num_dst_nodes()], bipartite.dstdata[dgl.NID]))
######################################################################
# Suppose you have obtained the input node representations
# :math:`h_u^{(l-1)}`:
#
bipartite.srcdata['h'] = torch.randn(bipartite.num_src_nodes(), 10)
######################################################################
# Recall that DGL provides the ``update_all`` interface for expressing how
# to compute messages and how to aggregate them on the nodes that receive
# them. This concept naturally applies to bipartite graphs -- message
# computation happens on the edges between the source and destination
# nodes, and message aggregation happens on the destination nodes.
#
# For example, suppose the message function copies the source feature
# (i.e. :math:`M^{(l)}\left(h_v^{(l-1)}, h_u^{(l-1)}, e_{u\to v}^{(l-1)}\right) = h_u^{(l-1)}`),
# and the reduce function averages the received messages. Performing
# such message passing computation on a bipartite graph is no different than
# on a full graph:
#
import dgl.function as fn
bipartite.update_all(message_func=fn.copy_u('h', 'm'), reduce_func=fn.mean('m', 'h'))
m_v = bipartite.dstdata['h']
m_v
######################################################################
# Putting them together, you can implement a GraphSAGE convolution for
# training with neighbor sampling as follows (the differences to the :doc:`full graph
# counterpart <3_message_passing>` are highlighted with arrows ``<---``)
#
import torch.nn as nn
import torch.nn.functional as F
import tqdm
class SAGEConv(nn.Module):
"""Graph convolution module used by the GraphSAGE model.
Parameters
----------
in_feat : int
Input feature size.
out_feat : int
Output feature size.
"""
def __init__(self, in_feat, out_feat):
super(SAGEConv, self).__init__()
# A linear submodule for projecting the input and neighbor feature to the output.
self.linear = nn.Linear(in_feat * 2, out_feat)
def forward(self, g, h):
"""Forward computation
Parameters
----------
g : Graph
The input bipartite graph.
h : (Tensor, Tensor)
The feature of input nodes and output nodes as a pair of Tensors.
"""
with g.local_scope():
h_src, h_dst = h
g.srcdata['h'] = h_src # <---
g.dstdata['h'] = h_dst # <---
# update_all is a message passing API.
g.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'h_N'))
h_N = g.dstdata['h_N']
h_total = torch.cat([h_dst, h_N], dim=1) # <---
return self.linear(h_total)
class Model(nn.Module):
def __init__(self, in_feats, h_feats, num_classes):
super(Model, self).__init__()
self.conv1 = SAGEConv(in_feats, h_feats)
self.conv2 = SAGEConv(h_feats, num_classes)
def forward(self, bipartites, x):
h_dst = x[:bipartites[0].num_dst_nodes()]
h = self.conv1(bipartites[0], (x, h_dst))
h = F.relu(h)
h_dst = h[:bipartites[1].num_dst_nodes()]
h = self.conv2(bipartites[1], (h, h_dst))
return h
sampler = dgl.dataloading.MultiLayerNeighborSampler([4, 4])
train_dataloader = dgl.dataloading.NodeDataLoader(
graph, train_nids, sampler,
batch_size=1024,
shuffle=True,
drop_last=False,
num_workers=0
)
model = Model(graph.ndata['feat'].shape[1], 128, dataset.num_classes).cuda()
with tqdm.tqdm(train_dataloader) as tq:
for step, (input_nodes, output_nodes, bipartites) in enumerate(tq):
bipartites = [b.to(torch.device('cuda')) for b in bipartites]
inputs = node_features[input_nodes].cuda()
labels = node_labels[output_nodes].cuda()
predictions = model(bipartites, inputs)
######################################################################
# Both ``update_all`` and the functions in ``nn.functional`` namespace
# support bipartite graphs, so you can migrate the code working for small
# graphs to large graph training with minimal changes introduced above.
#
######################################################################
# Writing GNN Modules for Both Full-graph Training and Stochastic Training
# ------------------------------------------------------------------------
#
# Here is a step-by-step tutorial for writing a GNN module for both
# :doc:`full-graph training <1_introduction>` *and* :doc:`stochastic
# training <L1_large_node_classification>`.
#
# Say you start with a GNN module that works for full-graph training only:
#
class SAGEConv(nn.Module):
"""Graph convolution module used by the GraphSAGE model.
Parameters
----------
in_feat : int
Input feature size.
out_feat : int
Output feature size.
"""
def __init__(self, in_feat, out_feat):
super().__init__()
# A linear submodule for projecting the input and neighbor feature to the output.
self.linear = nn.Linear(in_feat * 2, out_feat)
def forward(self, g, h):
"""Forward computation
Parameters
----------
g : Graph
The input graph.
h : Tensor
The input node feature.
"""
with g.local_scope():
g.ndata['h'] = h
# update_all is a message passing API.
g.update_all(message_func=fn.copy_u('h', 'm'), reduce_func=fn.mean('m', 'h_N'))
h_N = g.ndata['h_N']
h_total = torch.cat([h, h_N], dim=1)
return self.linear(h_total)
######################################################################
# **First step**: Check whether the input feature is a single tensor or a
# pair of tensors:
#
# .. code:: python
#
# if isinstance(h, tuple):
# h_src, h_dst = h
# else:
# h_src = h_dst = h
#
# **Second step**: Replace node features ``h`` with ``h_src`` or
# ``h_dst``, and assign the node features to ``srcdata`` or ``dstdata``,
# instead of ``ndata``.
#
# Whether to assign to ``srcdata`` or ``dstdata`` depends on whether the
# said feature acts as the features on source nodes or destination nodes
# of the edges in the message functions (in ``update_all`` or
# ``apply_edges``).
#
# *Example 1*: For the following ``update_all`` statement:
#
# .. code:: python
#
# g.ndata['h'] = h
# g.update_all(message_func=fn.copy_u('h', 'm'), reduce_func=fn.mean('m', 'h_N'))
#
# The node feature ``h`` acts as the source node feature because ``'h'``
# appears in ``fn.copy_u`` as the source node feature. So you will need to
# replace ``h`` with the source feature ``h_src`` and assign it to
# ``srcdata`` for the version that works with both cases:
#
# .. code:: python
#
# g.srcdata['h'] = h_src
# g.update_all(message_func=fn.copy_u('h', 'm'), reduce_func=fn.mean('m', 'h_N'))
#
# *Example 2*: For the following ``apply_edges`` statement:
#
# .. code:: python
#
# g.ndata['h'] = h
# g.apply_edges(fn.u_dot_v('h', 'h', 'score'))
#
# The node feature ``h`` acts as both source node feature and destination
# node feature. So you will assign ``h_src`` to ``srcdata`` and ``h_dst``
# to ``dstdata``:
#
# .. code:: python
#
# g.srcdata['h'] = h_src
# g.dstdata['h'] = h_dst
# # The first 'h' corresponds to source feature (u) while the second 'h' corresponds to destination feature (v).
# g.apply_edges(fn.u_dot_v('h', 'h', 'score'))
#
# .. note::
#
# For homogeneous graphs (i.e. graphs with only one node type
# and one edge type), ``srcdata`` and ``dstdata`` are aliases of
# ``ndata``. So you can safely replace ``ndata`` with ``srcdata`` and
# ``dstdata`` even for full-graph training.
#
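# A quick check of this aliasing on a small homogeneous toy graph (the
# graph and feature below are illustrative):
#
# .. code:: python
#
#    toy_g = dgl.rand_graph(5, 10)
#    toy_g.ndata['x'] = torch.ones(5, 2)
#    print(torch.equal(toy_g.srcdata['x'], toy_g.ndata['x']))  # True
#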
# **Third step**: Replace the ``ndata`` for outputs with ``dstdata``.
#
# For example, the following code
#
# .. code:: python
#
# # Assume that update_all() function has been called with output node features in `h_N`.
# h_N = g.ndata['h_N']
# h_total = torch.cat([h, h_N], dim=1)
#
# will change to
#
# .. code:: python
#
# h_N = g.dstdata['h_N']
# h_total = torch.cat([h_dst, h_N], dim=1)
#
######################################################################
# Putting it all together, you will change the ``SAGEConv`` module above
# into the ``SAGEConvForBoth`` module below:
#
class SAGEConvForBoth(nn.Module):
"""Graph convolution module used by the GraphSAGE model.
Parameters
----------
in_feat : int
Input feature size.
out_feat : int
Output feature size.
"""
def __init__(self, in_feat, out_feat):
super().__init__()
# A linear submodule for projecting the input and neighbor feature to the output.
self.linear = nn.Linear(in_feat * 2, out_feat)
def forward(self, g, h):
"""Forward computation
Parameters
----------
g : Graph
The input graph.
h : Tensor or tuple[Tensor, Tensor]
The input node feature.
"""
with g.local_scope():
if isinstance(h, tuple):
h_src, h_dst = h
else:
h_src = h_dst = h
g.srcdata['h'] = h_src
# update_all is a message passing API.
g.update_all(message_func=fn.copy_u('h', 'm'), reduce_func=fn.mean('m', 'h_N'))
h_N = g.dstdata['h_N']
h_total = torch.cat([h_dst, h_N], dim=1)
return self.linear(h_total)
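######################################################################
# As a quick sanity check (a sketch; the toy graph and hidden size below are
# illustrative assumptions), the same module accepts either a single feature
# tensor on a full graph or a ``(src, dst)`` feature pair on a sampled
# bipartite graph:
#
conv = SAGEConvForBoth(node_features.shape[1], 16)
# Stochastic mode: reuse a sampled bipartite graph from the loop above.
block = bipartites[0].to(torch.device('cpu'))
x_src = node_features[input_nodes]
print(conv(block, (x_src, x_src[:block.num_dst_nodes()])).shape)
# Full-graph mode: a small random graph with a matching feature size.
toy_g = dgl.rand_graph(100, 500)
print(conv(toy_g, torch.randn(100, node_features.shape[1])).shape)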
@@ -244,6 +244,11 @@ class NodeDataLoader:
collate_fn=self.collator.collate,
**dataloader_kwargs)
self.is_distributed = False
# Precompute the CSR and CSC representations so each subprocess does not
# duplicate.
if dataloader_kwargs.get('num_workers', 0) > 0:
g.create_formats_()
self.device = device
def __iter__(self):
@@ -438,6 +443,11 @@ class EdgeDataLoader:
self.collator.dataset, collate_fn=self.collator.collate, **dataloader_kwargs)
self.device = device
# Precompute the CSR and CSC representations so each subprocess does not
# duplicate.
if dataloader_kwargs.get('num_workers', 0) > 0:
g.create_formats_()
def __iter__(self):
"""Return the iterator of the data loader."""
return _EdgeDataLoaderIter(self)
...