""" Write your own GNN module ========================= Sometimes, your model goes beyond simply stacking existing GNN modules. For example, you would like to invent a new way of aggregating neighbor information by considering node importance or edge weights. By the end of this tutorial you will be able to - Understand DGL’s message passing APIs. - Implement GraphSAGE convolution module by your own. This tutorial assumes that you already know :doc:`the basics of training a GNN for node classification <1_introduction>`. (Time estimate: 10 minutes) """ import os os.environ['DGLBACKEND'] = 'pytorch' import torch import torch.nn as nn import torch.nn.functional as F import dgl import dgl.function as fn ###################################################################### # Message passing and GNNs # ------------------------ # # DGL follows the *message passing paradigm* inspired by the Message # Passing Neural Network proposed by `Gilmer et # al. `__ Essentially, they found many # GNN models can fit into the following framework: # # .. math:: # # # m_{u\to v}^{(l)} = M^{(l)}\left(h_v^{(l-1)}, h_u^{(l-1)}, e_{u\to v}^{(l-1)}\right) # # .. math:: # # # m_{v}^{(l)} = \sum_{u\in\mathcal{N}(v)}m_{u\to v}^{(l)} # # .. math:: # # # h_v^{(l)} = U^{(l)}\left(h_v^{(l-1)}, m_v^{(l)}\right) # # where DGL calls :math:`M^{(l)}` the *message function*, :math:`\sum` the # *reduce function* and :math:`U^{(l)}` the *update function*. Note that # :math:`\sum` here can represent any function and is not necessarily a # summation. # ###################################################################### # For example, the `GraphSAGE convolution (Hamilton et al., # 2017) `__ # takes the following mathematical form: # # .. math:: # # # h_{\mathcal{N}(v)}^k\leftarrow \text{Average}\{h_u^{k-1},\forall u\in\mathcal{N}(v)\} # # .. math:: # # # h_v^k\leftarrow \text{ReLU}\left(W^k\cdot \text{CONCAT}(h_v^{k-1}, h_{\mathcal{N}(v)}^k) \right) # # You can see that message passing is directional: the message sent from # one node :math:`u` to other node :math:`v` is not necessarily the same # as the other message sent from node :math:`v` to node :math:`u` in the # opposite direction. # # Although DGL has builtin support of GraphSAGE via # :class:`dgl.nn.SAGEConv `, # here is how you can implement GraphSAGE convolution in DGL by your own. # class SAGEConv(nn.Module): """Graph convolution module used by the GraphSAGE model. Parameters ---------- in_feat : int Input feature size. out_feat : int Output feature size. """ def __init__(self, in_feat, out_feat): super(SAGEConv, self).__init__() # A linear submodule for projecting the input and neighbor feature to the output. self.linear = nn.Linear(in_feat * 2, out_feat) def forward(self, g, h): """Forward computation Parameters ---------- g : Graph The input graph. h : Tensor The input node feature. """ with g.local_scope(): g.ndata["h"] = h # update_all is a message passing API. g.update_all( message_func=fn.copy_u("h", "m"), reduce_func=fn.mean("m", "h_N"), ) h_N = g.ndata["h_N"] h_total = torch.cat([h, h_N], dim=1) return self.linear(h_total) ###################################################################### # The central piece in this code is the # :func:`g.update_all ` # function, which gathers and averages the neighbor features. There are # three concepts here: # # * Message function ``fn.copy_u('h', 'm')`` that # copies the node feature under name ``'h'`` as *messages* with name # ``'m'`` sent to neighbors. 
######################################################################
# Afterwards, you can stack your own GraphSAGE convolution layers to form
# a multi-layer GraphSAGE network.
#


class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = SAGEConv(in_feats, h_feats)
        self.conv2 = SAGEConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat)
        h = F.relu(h)
        h = self.conv2(g, h)
        return h


######################################################################
# Training loop
# ~~~~~~~~~~~~~
# The following code for data loading and the training loop is directly
# copied from the introduction tutorial.
#

import dgl.data

dataset = dgl.data.CoraGraphDataset()
g = dataset[0]


def train(g, model):
    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)
    all_logits = []
    best_val_acc = 0
    best_test_acc = 0

    features = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]
    for e in range(200):
        # Forward
        logits = model(g, features)

        # Compute prediction
        pred = logits.argmax(1)

        # Compute loss
        # Note that we should only compute the losses of the nodes in the
        # training set, i.e. those with train_mask being 1.
        loss = F.cross_entropy(logits[train_mask], labels[train_mask])

        # Compute accuracy on training/validation/test
        train_acc = (pred[train_mask] == labels[train_mask]).float().mean()
        val_acc = (pred[val_mask] == labels[val_mask]).float().mean()
        test_acc = (pred[test_mask] == labels[test_mask]).float().mean()

        # Save the best validation accuracy and the corresponding test accuracy.
        if best_val_acc < val_acc:
            best_val_acc = val_acc
            best_test_acc = test_acc

        # Backward
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()
        all_logits.append(logits.detach())

        if e % 5 == 0:
            print(
                "In epoch {}, loss: {:.3f}, val acc: {:.3f} (best {:.3f}), test acc: {:.3f} (best {:.3f})".format(
                    e, loss, val_acc, best_val_acc, test_acc, best_test_acc
                )
            )


model = Model(g.ndata["feat"].shape[1], 16, dataset.num_classes)
train(g, model)


######################################################################
# More customization
# ------------------
#
# In DGL, we provide many built-in message and reduce functions under the
# ``dgl.function`` package. You can find more details in :ref:`the API
# doc <apifunction>`.
#


######################################################################
# These APIs allow one to quickly implement new graph convolution modules.
# For example, the following implements a new ``SAGEConv`` that aggregates
# neighbor representations using a weighted average. Note that the ``edata``
# member can hold edge features, which can also take part in message
# passing.
#


class WeightedSAGEConv(nn.Module):
    """Graph convolution module used by the GraphSAGE model with edge weights.

    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    """

    def __init__(self, in_feat, out_feat):
        super(WeightedSAGEConv, self).__init__()
        # A linear submodule for projecting the input and neighbor feature to the output.
        self.linear = nn.Linear(in_feat * 2, out_feat)

    def forward(self, g, h, w):
        """Forward computation

        Parameters
        ----------
        g : Graph
            The input graph.
        h : Tensor
            The input node feature.
        w : Tensor
            The edge weight.
        """
        with g.local_scope():
            g.ndata["h"] = h
            g.edata["w"] = w
            g.update_all(
                message_func=fn.u_mul_e("h", "w", "m"),
                reduce_func=fn.mean("m", "h_N"),
            )
            h_N = g.ndata["h_N"]
            h_total = torch.cat([h, h_N], dim=1)
            return self.linear(h_total)


######################################################################
# Because the graph in this dataset does not have edge weights, we
# manually assign all edge weights to one in the ``forward()`` function of
# the model. You can replace them with your own edge weights.
#


class Model(nn.Module):
    def __init__(self, in_feats, h_feats, num_classes):
        super(Model, self).__init__()
        self.conv1 = WeightedSAGEConv(in_feats, h_feats)
        self.conv2 = WeightedSAGEConv(h_feats, num_classes)

    def forward(self, g, in_feat):
        h = self.conv1(g, in_feat, torch.ones(g.num_edges(), 1).to(g.device))
        h = F.relu(h)
        h = self.conv2(g, h, torch.ones(g.num_edges(), 1).to(g.device))
        return h


model = Model(g.ndata["feat"].shape[1], 16, dataset.num_classes)
train(g, model)


######################################################################
# Even more customization by user-defined function
# ------------------------------------------------
#
# DGL allows user-defined message and reduce functions for maximal
# expressiveness. Here is a user-defined message function that is
# equivalent to ``fn.u_mul_e('h', 'w', 'm')``.
#


def u_mul_e_udf(edges):
    return {"m": edges.src["h"] * edges.data["w"]}


######################################################################
# ``edges`` has three members: ``src``, ``data`` and ``dst``, representing
# the source node feature, edge feature, and destination node feature for
# all edges.
#


######################################################################
# You can also write your own reduce function. For example, the following
# is equivalent to the builtin ``fn.mean('m', 'h_N')`` function that averages
# the incoming messages:
#


def mean_udf(nodes):
    return {"h_N": nodes.mailbox["m"].mean(1)}


######################################################################
# In short, DGL will group the nodes by their in-degrees, and for each
# group DGL stacks the incoming messages along the second dimension. You
# can then perform a reduction along the second dimension to aggregate
# messages, as the sketch below shows.
#
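######################################################################
# As an illustration (an addition to this tutorial), here is a
# max-pooling reduce function written in the same style. The name
# ``max_udf`` is hypothetical; it mirrors the builtin
# ``fn.max('m', 'h_N')``.
#


def max_udf(nodes):
    # nodes.mailbox['m'] has shape (num_nodes_in_bucket, in_degree, feat_dim),
    # so taking the max over dimension 1 aggregates across incoming messages.
    return {"h_N": nodes.mailbox["m"].max(1)[0]}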
######################################################################
# For more details on customizing message and reduce functions with
# user-defined functions, please refer to the :ref:`API
# reference <apiudf>`.
#


######################################################################
# Best practice of writing custom GNN modules
# -------------------------------------------
#
# DGL recommends the following practice ranked by preference:
#
# - Use ``dgl.nn`` modules.
# - Use ``dgl.nn.functional`` functions, which contain lower-level complex
#   operations such as computing a softmax for each node over incoming
#   edges (see the sketch at the end of this tutorial).
# - Use ``update_all`` with builtin message and reduce functions.
# - Use user-defined message or reduce functions.
#


######################################################################
# What's next?
# ------------
#
# - :ref:`Writing Efficient Message Passing
#   Code <guide-message-passing-efficient>`.
#


# Thumbnail credits: Representation Learning on Networks, Jure Leskovec, WWW 2018
# sphinx_gallery_thumbnail_path = '_static/blitz_3_message_passing.png'
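######################################################################
# As a closing illustration of the second recommendation above, here is a
# minimal sketch (an addition to this tutorial) that uses
# ``dgl.nn.functional.edge_softmax`` to turn per-edge scores into
# normalized weights for the ``WeightedSAGEConv`` defined earlier. The
# helper name ``attention_weights`` is hypothetical. Note that
# ``WeightedSAGEConv`` averages the weighted messages; with a sum reducer
# this would be exactly an attention-style aggregation.
#

from dgl.nn.functional import edge_softmax


def attention_weights(g, h):
    # Score each edge by the dot product of its incident node features,
    # then normalize the scores over each node's incoming edges.
    with g.local_scope():
        g.ndata["h"] = h
        g.apply_edges(fn.u_dot_v("h", "h", "score"))
        return edge_softmax(g, g.edata["score"])


conv = WeightedSAGEConv(g.ndata["feat"].shape[1], 16)
w = attention_weights(g, g.ndata["feat"])  # shape: (num_edges, 1)
print(conv(g, g.ndata["feat"], w).shape)  # (num_nodes, 16)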