"""
.. _tutorial-mp:

Message passing on graphs
=========================

**Author**: `Minjie Wang `_

Many graph-based deep neural networks are based on *"message passing"* --
nodes compute messages that are sent to their neighbors, and node features are
updated using the received messages. In this tutorial, we introduce the basic
mechanism of message passing in DGL.
"""
###############################################################################
# Let us start by importing DGL and creating the example graph used throughout
# this tutorial. The graph has 10 nodes, with node#0 as the source and node#9
# as the sink. The source node (node#0) connects to all other nodes except the
# sink node. Similarly, the sink node receives edges from all other nodes
# except the source node. We also initialize the feature vector of the source
# node to all ones, while the other nodes have all-zero features.
# The code to create such a graph is as follows (using PyTorch syntax):
import dgl
import torch as th
g = dgl.DGLGraph()
g.add_nodes(10)
g.add_edges(0, list(range(1, 9)))
g.add_edges(list(range(1, 9)), 9)
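# Node#0 fans out to nodes#1-8, and nodes#1-8 all feed into node#9,
# so the graph has 8 + 8 = 16 edges in total.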
# TODO(minjie): plot the graph here.
N = g.number_of_nodes()
M = g.number_of_edges()
print('#Nodes:', N)
print('#Edges:', M)
# initialize the node features
D = 1  # feature size
g.set_n_repr({'feat' : th.zeros((N, D))})    # all nodes start with zeros
g.set_n_repr({'feat' : th.ones((1, D))}, 0)  # the source node#0 gets ones
print(g.get_n_repr()['feat'])
###############################################################################
# User-defined functions and high-level APIs
# ------------------------------------------
#
# There are two core components in DGL's message passing programming model:
#
# * **User-defined functions (UDFs)** that specify how messages are computed and consumed.
# * **High-level APIs** that specify which nodes send messages to which, and which nodes are updated.
#
# For example, one simple user-defined message function can be as follows:
def send_source(src, edge):
    return {'msg' : src['feat']}
###############################################################################
# The above function computes the messages over **a batch of edges**.
# It has two arguments: ``src`` for the source node features and
# ``edge`` for the edge features, and it returns the computed messages. Both
# the arguments and the return value are dictionaries mapping feature/message
# names to tensor values.
# We can trigger this function using the ``send`` API:
g.send(0, 1, message_func=send_source)
###############################################################################
# Here, the message is computed using the feature of node#0. The result message
# (on 0->1) is not returned but directly saved in ``DGLGraph`` for the later
# receive phase.
#
# You can send multiple messages at once using the
# :ref:`multi-edge semantics `.
# In that case, the source node and edge features are batched along the first
# dimension. You can simply print out the shape of the feature tensor in your
# message function:
def send_source_print(src, edge):
    print('src feat shape:', src['feat'].shape)
    return {'msg' : src['feat']}
g.send(0, [4, 5, 6], message_func=send_source_print)
###############################################################################
# To receive and aggregate in-coming messages, users can define a reduce
# function that operates on **a batch of nodes**.
def simple_reduce(node, msgs):
    return {'feat' : th.sum(msgs['msg'], dim=1)}
###############################################################################
# The reduce function has two arguments: ``node`` for the node features and
# ``msgs`` for the in-coming messages. It returns the updated node features.
# The function can be triggered using the ``recv`` API. Again, DGL supports
# receiving messages for multiple nodes at the same time. In that case, the
# node features are batched along the first dimension. Because each node can
# receive a different number of in-coming messages, we divide the receiving
# nodes into buckets based on the number of messages they receive. As a
# result, the message tensor has at least three dimensions (B, n, D), where B
# is the number of nodes in a bucket, n is the number of messages each of them
# receives, and D is the message feature size; the second dimension stacks all
# the messages of each node together. This also means the reduce UDF is called
# once per bucket. You can simply print out the shape of the message tensor as
# follows:
def simple_reduce_print(node, msgs):
    print('msg shape:', msgs['msg'].shape)
    return {'feat' : th.sum(msgs['msg'], dim=1)}
g.recv([1, 4, 5, 6], reduce_func=simple_reduce_print)
print(g.get_n_repr()['feat'])
###############################################################################
# You can see that, after the send and recv, the value of node#0 has been
# propagated to nodes 1, 4, 5 and 6.
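# As a quick sanity check (a small sketch), we can assert this directly:
assert all(g.get_n_repr()['feat'][i].item() == 1.0 for i in (1, 4, 5, 6))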
###############################################################################
# DGL message passing APIs
# ------------------------
#
# TODO(minjie): enable backreference for all the mentioned APIs below.
#
# In DGL, we categorize the message passing APIs into three levels. All of them
# can be configured using UDFs such as the message and reduce functions.
#
# **Level-1 routines:** APIs that trigger computation on either a batch of nodes
# or a batch of edges. This includes:
#
# * ``send(u, v)`` and ``recv(v)``
# * ``update_edge(u, v)``: This updates the edge features using the current edge
#   features and the source and destination node features.
# * ``apply_nodes(v)``: This transforms the node features using the current node
#   features (see the sketch after this list).
# * ``apply_edges(u, v)``: This transforms the edge features using the current edge
#   features.
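#
# As a quick illustration of ``apply_nodes``, here is a minimal sketch of an
# apply UDF. Like the reduce UDF, it operates on a batch of node features. The
# exact calling convention of ``apply_nodes`` is an assumption here (made by
# analogy with ``send`` and ``recv``), so the call itself is left commented out.
def double_feat(node):
    # scale the features of a batch of nodes by two
    return {'feat' : node['feat'] * 2}
# hypothetical call, with the keyword name assumed:
# g.apply_nodes([0], apply_node_func=double_feat)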
###############################################################################
# **Level-2 routines:** APIs that combine several level-1 routines.
#
# * ``send_and_recv(u, v)``: This first computes messages over u->v, then reduces
#   them on v. An optional node apply function can be provided.
# * ``pull(v)``: This computes the messages over all the in-edges of v, then reduces
#   them on v. An optional node apply function can be provided.
# * ``push(v)``: This computes the messages over all the out-edges of v, then
#   reduces them on the successors. An optional node apply function can be provided.
# * ``update_all()``: Send out and reduce messages on every node. An optional node
#   apply function can be provided.
#
# The following example uses ``send_and_recv`` to continue propagating the signal
# to the sink node#9:
g.send_and_recv([1, 4, 5, 6], 9, message_func=send_source, reduce_func=simple_reduce)
print(g.get_n_repr()['feat'])
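###############################################################################
# The same update could also be written with ``pull``, which aggregates over
# *all* the in-edges of node#9. The sketch below assumes ``pull`` accepts the
# same ``message_func``/``reduce_func`` keywords as ``send_and_recv``. Here it
# produces the same value, because among node#9's predecessors only nodes
# 1, 4, 5 and 6 carry nonzero features.
g.pull(9, message_func=send_source, reduce_func=simple_reduce)
print(g.get_n_repr()['feat'])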
###############################################################################
# **Level-3 routines:** APIs that call multiple level-2 routines.
#
# * ``propagate()``: TBD after Yu's traversal PR.
###############################################################################
# Builtin functions
# -----------------
#
# Since many message and reduce UDFs are very common (such as sending the
# source node features as the message and aggregating messages by summation),
# DGL provides builtin functions that can be used directly:
import dgl.function as fn
g.send_and_recv(0, [2, 3], fn.copy_src(src='feat', out='msg'), fn.sum(msg='msg', out='feat'))
print(g.get_n_repr()['feat'])
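###############################################################################
# The builtin pair above is equivalent to the UDFs we wrote by hand:
# ``fn.copy_src`` plays the role of ``send_source`` and ``fn.sum`` the role of
# ``simple_reduce``. As a sketch of this equivalence, re-running the same
# update with the handwritten UDFs leaves the features unchanged:
g.send_and_recv(0, [2, 3], message_func=send_source, reduce_func=simple_reduce)
print(g.get_n_repr()['feat'])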
###############################################################################
# TODO(minjie): document on multiple builtin function syntax after Lingfan
# finished his change.
###############################################################################
# Using builtin functions not only saves you time writing code, but also
# allows DGL to pick a more efficient implementation automatically. To see
# this in action, continue to our tutorial on Graph Convolutional Networks.
# TODO(minjie): need a hyperref to the GCN tutorial here.