"""
Stochastic Training of GNN for Link Prediction
==============================================

This tutorial will show how to train a multi-layer GraphSAGE for link
prediction on the Amazon Co-purchase Network provided by `Open Graph Benchmark
(OGB) <https://ogb.stanford.edu/>`__. The dataset
contains 2.4 million nodes and 61 million edges.

By the end of this tutorial, you will be able to

-  Train a GNN model for link prediction on a single GPU with DGL's
   neighbor sampling components.

This tutorial assumes that you have read the :doc:`Introduction of Neighbor
Sampling for GNN Training <L0_neighbor_sampling_overview>` and :doc:`Neighbor
Sampling for Node Classification <L1_large_node_classification>`.

"""


######################################################################
# Link Prediction Overview
# ------------------------
#
# Link prediction requires the model to predict the probability of
# existence of an edge. This tutorial does so by computing a dot product
# between the representations of both incident nodes.
#
# .. math::
#
#
#    \hat{y}_{u\sim v} = \sigma(h_u^T h_v)
#
# It then minimizes the following binary cross entropy loss.
#
# .. math::
#
#
#    \mathcal{L} = -\sum_{u\sim v\in \mathcal{D}}\left( y_{u\sim v}\log(\hat{y}_{u\sim v}) + (1-y_{u\sim v})\log(1-\hat{y}_{u\sim v}) \right)
#
# This is identical to the link prediction formulation in :doc:`the previous
# tutorial on link prediction <4_link_predict>`.
#


######################################################################
# Loading Dataset
# ---------------
#
# This tutorial loads the dataset from the ``ogb`` package as in the
# :doc:`previous tutorial <L1_large_node_classification>`.
#

import dgl
import torch
import numpy as np
from ogb.nodeproppred import DglNodePropPredDataset

dataset = DglNodePropPredDataset('ogbn-products')

# The dataset holds a single graph together with its node labels.
graph, node_labels = dataset[0]
print(graph)
print(node_labels)

node_features = graph.ndata['feat']
# Keep only the first label column so that labels become a 1-D tensor.
node_labels = node_labels[:, 0]
num_features = node_features.shape[1]
num_classes = (node_labels.max() + 1).item()
print('Number of classes:', num_classes)

idx_split = dataset.get_idx_split()
train_nids = idx_split['train']
valid_nids = idx_split['valid']
test_nids = idx_split['test']


######################################################################
# Defining Neighbor Sampler and Data Loader in DGL
# ------------------------------------------------
#
# Different from the :doc:`link prediction tutorial for full
# graph <4_link_predict>`, a common practice to train GNN on large graphs is
# to iterate over the edges
# in minibatches, since computing the probability of all edges is usually
# impossible. For each minibatch of edges, you compute the output
# representation of their incident nodes using neighbor sampling and GNN,
# in a similar fashion introduced in the :doc:`large-scale node classification
# tutorial <L1_large_node_classification>`.
#
# DGL provides ``dgl.dataloading.EdgeDataLoader`` to
# iterate over edges for edge classification or link prediction tasks.
#
# To perform link prediction, you need to specify a negative sampler. DGL
# provides builtin negative samplers such as
# ``dgl.dataloading.negative_sampler.Uniform``.  Here this tutorial uniformly
# draws 5 negative examples per positive example.
#

negative_sampler = dgl.dataloading.negative_sampler.Uniform(5)


######################################################################
# After defining the negative sampler, one can then define the edge data
# loader with neighbor sampling.  To create an ``EdgeDataLoader`` for
# link prediction, provide a neighbor sampler object as well as the negative
# sampler object created above.
#

# Sample 4 neighbors for each of the two GNN layers.
sampler = dgl.dataloading.MultiLayerNeighborSampler([4, 4])
train_dataloader = dgl.dataloading.EdgeDataLoader(
    # The following arguments are specific to EdgeDataLoader.
    graph,                                  # The graph
    torch.arange(graph.number_of_edges()),  # The edges to iterate over
    sampler,                                # The neighbor sampler
    negative_sampler=negative_sampler,      # The negative sampler
    device='cuda',                          # Put the bipartite graphs on GPU
    # The following arguments are inherited from PyTorch DataLoader.
    batch_size=1024,    # Batch size
    shuffle=True,       # Whether to shuffle the edges for every epoch
    drop_last=False,    # Whether to drop the last incomplete batch
    num_workers=0       # Number of sampler processes
)


######################################################################
# You can peek one minibatch from ``train_dataloader`` and see what it
# will give you.
#

input_nodes, pos_graph, neg_graph, bipartites = next(iter(train_dataloader))
print('Number of input nodes:', len(input_nodes))
print('Positive graph # nodes:', pos_graph.number_of_nodes(), '# edges:', pos_graph.number_of_edges())
print('Negative graph # nodes:', neg_graph.number_of_nodes(), '# edges:', neg_graph.number_of_edges())
print(bipartites)


######################################################################
# The example minibatch consists of four elements.
#
# The first element is an ID tensor for the input nodes, i.e., nodes
# whose input features are needed on the first GNN layer for this minibatch.
#
# The second element and the third element are the positive graph and the
# negative graph for this minibatch.
# The concept of positive and negative graphs has been introduced in the
# :doc:`full-graph link prediction tutorial <4_link_predict>`.  In minibatch
# training, the positive graph and the negative graph only contain nodes
# necessary for computing the pair-wise scores of positive and negative examples
# in the current minibatch.
#
# The last element is a list of bipartite graphs storing the computation
# dependencies for each GNN layer.
# The bipartite graphs are used to compute the GNN outputs of the nodes
# involved in positive/negative graph.
#


######################################################################
# Defining Model for Node Representation
# --------------------------------------
#
# The model is almost identical to the one in the :doc:`node classification
# tutorial <L1_large_node_classification>`. The only difference is
# that since you are doing link prediction, the output dimension will not
# be the number of classes in the dataset.
#

import torch.nn as nn
import torch.nn.functional as F
from dgl.nn import SAGEConv


class Model(nn.Module):
    """Two-layer GraphSAGE (mean aggregator) producing node representations.

    Parameters
    ----------
    in_feats : int
        Size of the input node features.
    h_feats : int
        Size of the hidden and output node representations.
    """

    def __init__(self, in_feats, h_feats):
        super(Model, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats, aggregator_type='mean')
        self.conv2 = SAGEConv(h_feats, h_feats, aggregator_type='mean')
        self.h_feats = h_feats

    def forward(self, bipartites, x):
        """Run both layers over the sampled bipartite graphs.

        Parameters
        ----------
        bipartites : list
            The two bipartite graphs of a minibatch, one per layer.
        x : Tensor
            Features of the source nodes of ``bipartites[0]``.

        Returns
        -------
        Tensor
            Output of the second layer, computed for the destination
            nodes of ``bipartites[1]``.
        """
        # The destination-node features of each layer are taken as the
        # first ``num_dst_nodes()`` rows of that layer's input.
        h_dst = x[:bipartites[0].num_dst_nodes()]
        h = self.conv1(bipartites[0], (x, h_dst))
        h = F.relu(h)
        h_dst = h[:bipartites[1].num_dst_nodes()]
        h = self.conv2(bipartites[1], (h, h_dst))
        return h


model = Model(num_features, 128).cuda()


######################################################################
# Defining the Score Predictor for Edges
# --------------------------------------
#
# After getting the node representation necessary for the minibatch, the
# last thing to do is to predict the score of the edges and non-existent
# edges in the sampled minibatch.
#
# The following score predictor, copied from the :doc:`link prediction
# tutorial <4_link_predict>`, takes a dot product between the
# incident nodes’ representations.
#

import dgl.function as fn


class DotPredictor(nn.Module):
    """Score every edge of a graph by the dot product of its end points."""

    def forward(self, g, h):
        """Return one score per edge of ``g`` given node representations ``h``."""
        # local_scope() keeps the temporary 'h' and 'score' fields from
        # leaking into ``g`` after the block exits.
        with g.local_scope():
            g.ndata['h'] = h
            # Compute a new edge feature named 'score' by a dot-product between the
            # source node feature 'h' and destination node feature 'h'.
            g.apply_edges(fn.u_dot_v('h', 'h', 'score'))
            # u_dot_v returns a 1-element vector for each edge so you need to squeeze it.
            return g.edata['score'][:, 0]


######################################################################
# Evaluating Performance (Optional)
# ---------------------------------
#
# There are various ways to evaluate the performance of link prediction.
# This tutorial follows the practice of `GraphSAGE
# paper <https://cs.stanford.edu/people/jure/pubs/graphsage-nips17.pdf>`__,
# where it treats the node embeddings learned by link prediction via
# training and evaluating a linear classifier on top of the learned node
# embeddings.
#


######################################################################
# To obtain the representations of all the nodes, this tutorial uses
# neighbor sampling as introduced in the :doc:`node classification
# tutorial <L1_large_node_classification>`.
#
# .. note::
#
#    If you would like to obtain node representations without
#    neighbor sampling during inference, please refer to this :ref:`user
#    guide <guide-minibatch-inference>`.
#

def inference(model, graph, node_features):
    """Compute the representation of every node in ``graph`` with ``model``.

    Nodes are visited in minibatches of 1024 using the same ``[4, 4]``
    neighbor sampling as in training, with gradients disabled.

    Parameters
    ----------
    model : nn.Module
        Called as ``model(bipartites, inputs)`` for each minibatch.
    graph : DGLGraph
        The graph whose nodes are embedded.
    node_features : Tensor
        Not read by this function; the input features are taken from the
        ``'feat'`` field of each minibatch's first bipartite graph.

    Returns
    -------
    Tensor
        The per-minibatch outputs concatenated in iteration order
        (``shuffle=False``, nodes passed as ``arange(number_of_nodes)``).
    """
    with torch.no_grad():
        nodes = torch.arange(graph.number_of_nodes())

        sampler = dgl.dataloading.MultiLayerNeighborSampler([4, 4])
        # Named ``dataloader`` (not ``train_dataloader``) so it is not
        # confused with the module-level edge loader used for training.
        dataloader = dgl.dataloading.NodeDataLoader(
            graph, nodes, sampler,
            batch_size=1024,
            shuffle=False,
            drop_last=False,
            num_workers=4,
            device='cuda')

        result = []
        for input_nodes, output_nodes, bipartites in dataloader:
            # feature copy from CPU to GPU takes place here
            inputs = bipartites[0].srcdata['feat']
            result.append(model(bipartites, inputs))

        return torch.cat(result)


import sklearn.metrics


def evaluate(emb, label, train_nids, valid_nids, test_nids):
    """Fit a linear classifier on ``emb`` and report its accuracy.

    The classifier is trained with L-BFGS on the training nodes until the
    loss changes by less than ``1e-4`` between iterations (at most 1000
    iterations).

    Parameters
    ----------
    emb : Tensor
        Node embeddings, one row per node.
    label : Tensor
        Integer class label of every node.
    train_nids, valid_nids, test_nids : Tensor
        Node IDs of the training, validation and test sets.

    Returns
    -------
    (float, float)
        Validation accuracy and test accuracy.
    """
    # Class ids run from 0 to label.max() inclusive, so the classifier
    # needs max + 1 outputs; otherwise the largest class id is an
    # out-of-range target for cross_entropy.
    num_outputs = int(label.max().item()) + 1
    classifier = nn.Linear(emb.shape[1], num_outputs).cuda()
    opt = torch.optim.LBFGS(classifier.parameters())

    def compute_loss():
        pred = classifier(emb[train_nids].cuda())
        loss = F.cross_entropy(pred, label[train_nids].cuda())
        return loss

    def closure():
        # L-BFGS re-evaluates the objective through this closure.
        loss = compute_loss()
        opt.zero_grad()
        loss.backward()
        return loss

    prev_loss = float('inf')
    for i in range(1000):
        opt.step(closure)
        with torch.no_grad():
            loss = compute_loss().item()
            if np.abs(loss - prev_loss) < 1e-4:
                print('Converges at iteration', i)
                break
            else:
                prev_loss = loss

    with torch.no_grad():
        pred = classifier(emb.cuda()).cpu()
        valid_acc = sklearn.metrics.accuracy_score(label[valid_nids].numpy(), pred[valid_nids].numpy().argmax(1))
        test_acc = sklearn.metrics.accuracy_score(label[test_nids].numpy(), pred[test_nids].numpy().argmax(1))
    return valid_acc, test_acc


######################################################################
# Defining Training Loop
# ----------------------
#
# The following initializes the model and defines the optimizer.
#

model = Model(node_features.shape[1], 128).cuda()
predictor = DotPredictor().cuda()
opt = torch.optim.Adam(list(model.parameters()) + list(predictor.parameters()))


######################################################################
# The following is the training loop for link prediction and
# evaluation, and also saves the model that performs the best on the
# validation set:
#

import tqdm
import sklearn.metrics

best_accuracy = 0
best_model_path = 'model.pt'
for epoch in range(1):
    with tqdm.tqdm(train_dataloader) as tq:
        for step, (input_nodes, pos_graph, neg_graph, bipartites) in enumerate(tq):
            # feature copy from CPU to GPU takes place here
            inputs = bipartites[0].srcdata['feat']

            outputs = model(bipartites, inputs)
            pos_score = predictor(pos_graph, outputs)
            neg_score = predictor(neg_graph, outputs)

            # Positive edges are labeled 1 and sampled negative edges 0.
            score = torch.cat([pos_score, neg_score])
            label = torch.cat([torch.ones_like(pos_score), torch.zeros_like(neg_score)])
            loss = F.binary_cross_entropy_with_logits(score, label)

            opt.zero_grad()
            loss.backward()
            opt.step()

            tq.set_postfix({'loss': '%.03f' % loss.item()}, refresh=False)

            # Evaluate once every 1000 minibatches.
            if step % 1000 == 999:
                model.eval()
                emb = inference(model, graph, node_features)
                valid_acc, test_acc = evaluate(emb, node_labels, train_nids, valid_nids, test_nids)
                print('Epoch {} Validation Accuracy {} Test Accuracy {}'.format(epoch, valid_acc, test_acc))
                if best_accuracy < valid_acc:
                    best_accuracy = valid_acc
                    torch.save(model.state_dict(), best_model_path)
                model.train()

                # Note that this tutorial does not train the whole model to the end.
                break


######################################################################
# Conclusion
# ----------
#
# In this tutorial, you have learned how to train a multi-layer GraphSAGE
# for link prediction with neighbor sampling.
#