""" Training GNN with Neighbor Sampling for Node Classification =========================================================== This tutorial shows how to train a multi-layer GraphSAGE for node classification on Amazon Co-purchase Network provided by `Open Graph Benchmark (OGB) `__. The dataset contains 2.4 million nodes and 61 million edges. By the end of this tutorial, you will be able to - Train a GNN model for node classification on a single GPU with DGL's neighbor sampling components. This tutorial assumes that you have read the :doc:`Introduction of Neighbor Sampling for GNN Training `. """ ###################################################################### # Loading Dataset # --------------- # # OGB already prepared the data as DGL graph. # import dgl import torch import numpy as np from ogb.nodeproppred import DglNodePropPredDataset dataset = DglNodePropPredDataset('ogbn-products') ###################################################################### # OGB dataset is a collection of graphs and their labels. The Amazon # Co-purchase Network dataset only contains a single graph. So you can # simply get the graph and its node labels like this: # graph, node_labels = dataset[0] graph.ndata['label'] = node_labels[:, 0] print(graph) print(node_labels) node_features = graph.ndata['feat'] num_features = node_features.shape[1] num_classes = (node_labels.max() + 1).item() print('Number of classes:', num_classes) ###################################################################### # You can get the training-validation-test split of the nodes with # ``get_split_idx`` method. # idx_split = dataset.get_idx_split() train_nids = idx_split['train'] valid_nids = idx_split['valid'] test_nids = idx_split['test'] ###################################################################### # How DGL Handles Computation Dependency # -------------------------------------- # # In the :doc:`previous tutorial `, you # have seen that the computation dependency for message passing of a # single node can be described as a series of bipartite graphs. # # |image1| # # .. |image1| image:: https://data.dgl.ai/tutorial/img/bipartite.gif # ###################################################################### # Defining Neighbor Sampler and Data Loader in DGL # ------------------------------------------------ # # DGL provides tools to iterate over the dataset in minibatches # while generating the computation dependencies to compute their outputs # with the bipartite graphs above. For node classification, you can use # ``dgl.dataloading.NodeDataLoader`` for iterating over the dataset. # It accepts a sampler object to control how to generate the computation # dependencies in the form of bipartite graphs. DGL provides # implementations of common sampling algorithms such as # ``dgl.dataloading.MultiLayerNeighborSampler`` which randomly picks # a fixed number of neighbors for each node. # # .. note:: # # To write your own neighbor sampler, please refer to :ref:`this user # guide section `. # # The syntax of ``dgl.dataloading.NodeDataLoader`` is mostly similar to a # PyTorch ``DataLoader``, with the addition that it needs a graph to # generate computation dependency from, a set of node IDs to iterate on, # and the neighbor sampler you defined. # # Let’s say that each node will gather messages from 4 neighbors on each # layer. The code defining the data loader and neighbor sampler will look # like the following. 
#
# The code defining the data loader and neighbor sampler will look like
# the following.
#

sampler = dgl.dataloading.MultiLayerNeighborSampler([4, 4])
train_dataloader = dgl.dataloading.NodeDataLoader(
    # The following arguments are specific to NodeDataLoader.
    graph,              # The graph
    train_nids,         # The node IDs to iterate over in minibatches
    sampler,            # The neighbor sampler
    device='cuda',      # Put the sampled bipartite graphs on the GPU
    # The following arguments are inherited from PyTorch DataLoader.
    batch_size=1024,    # Batch size
    shuffle=True,       # Whether to shuffle the nodes for every epoch
    drop_last=False,    # Whether to drop the last incomplete batch
    num_workers=0       # Number of sampler processes
)

######################################################################
# You can iterate over the data loader and see what it yields.
#

input_nodes, output_nodes, bipartites = example_minibatch = next(iter(train_dataloader))
print(example_minibatch)
print("To compute {} nodes' outputs, we need {} nodes' input features".format(
    len(output_nodes), len(input_nodes)))

######################################################################
# ``NodeDataLoader`` gives us three items per iteration.
#
# - An ID tensor for the input nodes, i.e., the nodes whose input features
#   are needed on the first GNN layer for this minibatch.
# - An ID tensor for the output nodes, i.e., the nodes whose representations
#   are to be computed.
# - A list of bipartite graphs storing the computation dependencies
#   for each GNN layer.
#

######################################################################
# You can get the input and output node IDs of the bipartite graphs
# and verify that the first few input nodes are always the same as the
# output nodes. As described in the :doc:`overview `, the output nodes'
# own features from the previous layer may also be necessary in the
# computation of the new features.
#

bipartite_0_src = bipartites[0].srcdata[dgl.NID]
bipartite_0_dst = bipartites[0].dstdata[dgl.NID]
print(bipartite_0_src)
print(bipartite_0_dst)
print(torch.equal(bipartite_0_src[:bipartites[0].num_dst_nodes()], bipartite_0_dst))

######################################################################
# Defining Model
# --------------
#
# Let’s consider training a 2-layer GraphSAGE with neighbor sampling. The
# model can be written as follows:
#

import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import SAGEConv

class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type='mean')
        self.conv2 = SAGEConv(h_feats, num_classes, aggregator_type='mean')
        self.h_feats = h_feats

    def forward(self, bipartites, x):
        # Lines that are changed are marked with an arrow: "<---"
        h_dst = x[:bipartites[0].num_dst_nodes()]      # <---
        h = self.conv1(bipartites[0], (x, h_dst))      # <---
        h = F.relu(h)
        h_dst = h[:bipartites[1].num_dst_nodes()]      # <---
        h = self.conv2(bipartites[1], (h, h_dst))      # <---
        return h

model = Model(num_features, 128, num_classes).cuda()
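
######################################################################
# As a quick sanity check, you can run this model on the example minibatch
# fetched above and confirm that the output has one row per output node
# and one column per class. A minimal sketch (assuming ``bipartites`` and
# ``model`` from the cells above, both already on the GPU) could look like:
#
# .. code:: python
#
#    with torch.no_grad():
#        # input features of the first layer's input nodes
#        x = bipartites[0].srcdata['feat']
#        out = model(bipartites, x)
#        # expected shape: (number of output nodes, num_classes)
#        print(out.shape)
#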

######################################################################
# If you compare against the code in the
# :doc:`introduction <1_introduction>`, you will notice several
# differences:
#
# - **DGL GNN layers on bipartite graphs**. Instead of computing on the
#   full graph:
#
#   .. code:: python
#
#      h = self.conv1(g, x)
#
#   you only compute on the sampled bipartite graph:
#
#   .. code:: python
#
#      h = self.conv1(bipartites[0], (x, h_dst))
#
#   All DGL’s GNN modules support message passing on bipartite graphs,
#   where you supply a pair of features, one for the input nodes and
#   another for the output nodes.
#
# - **Feature slicing for self-dependency**. There are statements that
#   perform slicing to obtain the previous-layer representations of the
#   output nodes:
#
#   .. code:: python
#
#      h_dst = x[:bipartites[0].num_dst_nodes()]
#
#   The ``num_dst_nodes`` method works on bipartite graphs, where it
#   returns the number of output nodes.
#
#   Since the first few input nodes of the yielded bipartite graph are
#   always the same as the output nodes, these statements obtain the
#   representations of the output nodes on the previous layer. They are
#   then combined with neighbor aggregation in the ``dgl.nn.SAGEConv``
#   layer.
#
# .. note::
#
#    See the :doc:`custom message passing
#    tutorial ` for more details on how to
#    manipulate bipartite graphs produced in this way, such as the usage
#    of ``num_dst_nodes``.
#

######################################################################
# Defining Training Loop
# ----------------------
#
# The following initializes the model and defines the optimizer.
#

opt = torch.optim.Adam(model.parameters())

######################################################################
# When computing the validation score for model selection, you can usually
# use neighbor sampling as well. To do that, you need to define another
# data loader.
#

valid_dataloader = dgl.dataloading.NodeDataLoader(
    graph, valid_nids, sampler,
    batch_size=1024,
    shuffle=False,
    drop_last=False,
    num_workers=0
)

######################################################################
# The following is a training loop that performs validation every epoch.
# It also saves the model with the best validation accuracy into a file.
#

import tqdm
import sklearn.metrics

best_accuracy = 0
best_model_path = 'model.pt'
for epoch in range(10):
    model.train()

    with tqdm.tqdm(train_dataloader) as tq:
        for step, (input_nodes, output_nodes, bipartites) in enumerate(tq):
            # The feature copy from CPU to GPU takes place here.
            inputs = bipartites[0].srcdata['feat']
            labels = bipartites[-1].dstdata['label']

            predictions = model(bipartites, inputs)

            loss = F.cross_entropy(predictions, labels)
            opt.zero_grad()
            loss.backward()
            opt.step()

            accuracy = sklearn.metrics.accuracy_score(
                labels.cpu().numpy(),
                predictions.argmax(1).detach().cpu().numpy())

            tq.set_postfix({'loss': '%.03f' % loss.item(), 'acc': '%.03f' % accuracy},
                           refresh=False)

    model.eval()

    predictions = []
    labels = []
    with tqdm.tqdm(valid_dataloader) as tq, torch.no_grad():
        for input_nodes, output_nodes, bipartites in tq:
            bipartites = [b.to(torch.device('cuda')) for b in bipartites]
            inputs = node_features[input_nodes].cuda()
            labels.append(node_labels[output_nodes].numpy())
            predictions.append(model(bipartites, inputs).argmax(1).cpu().numpy())
        predictions = np.concatenate(predictions)
        labels = np.concatenate(labels)
        accuracy = sklearn.metrics.accuracy_score(labels, predictions)
        print('Epoch {} Validation Accuracy {}'.format(epoch, accuracy))
        if best_accuracy < accuracy:
            best_accuracy = accuracy
            torch.save(model.state_dict(), best_model_path)

    # Note that this tutorial does not train the whole model to the end.
    break

######################################################################
# Conclusion
# ----------
#
# In this tutorial, you have learned how to train a multi-layer GraphSAGE
# with neighbor sampling.
#
# What’s next?
# ------------
#
# - :doc:`Stochastic training of GNN for link
#   prediction `.
# - :doc:`Adapting your custom GNN module for stochastic
#   training `.
# - During inference you may wish to disable neighbor sampling. If so,
#   please refer to the :ref:`user guide on exact offline
#   inference ` (a rough sketch of the idea follows below).
#
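
######################################################################
# For reference, a minimal sketch of such offline, layer-wise inference,
# assuming the 2-layer ``Model`` defined above, could compute one GNN layer
# for all nodes before moving on to the next, with every neighbor included.
# This is an illustrative sketch under those assumptions, not the exact
# code from the user guide:
#
# .. code:: python
#
#    def inference(model, graph, node_features, batch_size=1024, device='cuda'):
#        # Include all neighbors, one layer at a time.
#        sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
#        h = node_features
#        for l, layer in enumerate([model.conv1, model.conv2]):
#            out_size = model.h_feats if l == 0 else num_classes
#            y = torch.zeros(graph.num_nodes(), out_size)
#            dataloader = dgl.dataloading.NodeDataLoader(
#                graph, torch.arange(graph.num_nodes()), sampler,
#                batch_size=batch_size, shuffle=False, drop_last=False,
#                num_workers=0)
#            with torch.no_grad():
#                for input_nodes, output_nodes, bipartites in dataloader:
#                    bipartite = bipartites[0].to(torch.device(device))
#                    x = h[input_nodes].to(device)
#                    x_dst = x[:bipartite.num_dst_nodes()]
#                    out = layer(bipartite, (x, x_dst))
#                    if l == 0:               # no ReLU after the last layer
#                        out = F.relu(out)
#                    y[output_nodes] = out.cpu()
#            h = y
#        return h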