Unverified Commit 4ca706e1 authored by Minjie Wang, committed by GitHub

[Doc] Update doc of dgl.function (#2585)

* update doc of dgl.function

* small fix
parent 6ce2dc34
@@ -5,14 +5,13 @@
dgl.function
==================================
In DGL, message passing is mainly expressed by ``update_all(message_func, reduce_func)``.
This API computes messages on all edges and sends them to the destination nodes; the nodes
that receive messages perform aggregation and update their own node data.
Internally, DGL fuses the message generation and aggregation into one kernel so no
explicit messages are generated and stored. To achieve this, we recommend using our **built-in
message and reduce functions** so that DGL can analyze and map them to dedicated fused kernels. Here
are some examples (in PyTorch syntax).
This subpackage hosts all the **built-in functions** provided by DGL. Built-in functions
are DGL's recommended way to express different types of :ref:`guide-message-passing`
computation (i.e., via :func:`~dgl.DGLGraph.update_all`) or to compute edge-wise features
from node-wise features (i.e., via :func:`~dgl.DGLGraph.apply_edges`). Built-in functions
describe the node-wise and edge-wise computation symbolically, without performing any
actual computation, so DGL can analyze and map them to efficient low-level kernels.
Here are some examples:
.. code:: python
@@ -20,8 +19,8 @@ are some examples (in PyTorch syntax).
import dgl.function as fn
import torch as th
g = ... # create a DGLGraph
g.ndata['h'] = th.randn((g.number_of_nodes(), 10)) # each node has feature size 10
g.edata['w'] = th.randn((g.number_of_edges(), 1)) # each edge has feature size 1
g.ndata['h'] = th.randn((g.num_nodes(), 10)) # each node has feature size 10
g.edata['w'] = th.randn((g.num_edges(), 1)) # each edge has feature size 1
# collect features from source nodes and aggregate them in destination nodes
g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_sum'))
# multiply source node features with edge weights and aggregate them in destination nodes
@@ -30,12 +29,13 @@ are some examples (in PyTorch syntax).
g.apply_edges(fn.u_mul_v('h', 'h', 'w_new'))
``fn.copy_u``, ``fn.u_mul_e``, ``fn.u_mul_v`` are built-in message functions, while ``fn.sum``
and ``fn.max`` are built-in reduce functions. We use ``u``, ``v`` and ``e`` to represent
source nodes, destination nodes, and the edges among them, respectively. For example, ``copy_u``
copies the source node data as the messages, and ``u_mul_e`` multiplies source node features with edge features.
and ``fn.max`` are built-in reduce functions. DGL's convention is to use ``u``, ``v``
and ``e`` to represent source nodes, destination nodes, and edges, respectively.
For example, ``copy_u`` tells DGL to copy the source node data as the messages;
``u_mul_e`` tells DGL to multiply source node features with edge features.
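The same naming pattern extends to the other operand and operator combinations; below are a few illustrative calls (the full list of generated built-ins is in the API reference):

.. code:: python

import dgl.function as fn

fn.copy_u('h', 'm')        # message = source node feature 'h'
fn.copy_e('w', 'm')        # message = edge feature 'w'
fn.u_add_v('h', 'h', 'm')  # message = source 'h' + destination 'h'
fn.u_dot_v('h', 'h', 'm')  # message = dot product of source and destination 'h'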
To define a unary message function (e.g. ``copy_u``) specify one input feature name and one output
message name. To define a binary message function (e.g. ``u_mul_e``) specify
To define a unary message function (e.g. ``copy_u``), specify one input feature name and one output
message name. To define a binary message function (e.g. ``u_mul_e``), specify
two input feature names and one output message name. During the computation,
the message function will read the data under the given names, perform computation, and return
the output using the output name. For example, the above ``fn.u_mul_e('h', 'w', 'm')`` is
@@ -55,18 +55,54 @@ following user-defined function:
def udf_max(nodes):
    return {'h_max': th.max(nodes.mailbox['m'], 1)[0]}
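Likewise, a minimal sketch of a message UDF equivalent to ``fn.u_mul_e('h', 'w', 'm')`` is shown below; pairing the two UDFs reproduces the built-in result, though UDFs cannot be fused into a single kernel, so the built-in pair is typically much faster:

.. code:: python

def udf_u_mul_e(edges):
    return {'m': edges.src['h'] * edges.data['w']}

# both calls compute the same 'h_max'; the built-in pair runs as one fused kernel
g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.max('m', 'h_max'))
g.update_all(udf_u_mul_e, udf_max)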
Broadcasting is supported for binary message functions, which means the tensor arguments
can be automatically expanded to be of equal sizes. The supported broadcasting semantics
are standard and match `NumPy <https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html>`_
and `PyTorch <https://pytorch.org/docs/stable/notes/broadcasting.html>`_. If you are not familiar
with broadcasting, see the linked topics to learn more. In the
above example, ``fn.u_mul_e`` will perform broadcasted multiplication automatically because
the node feature ``'h'`` and the edge feature ``'w'`` are of different but broadcastable shapes.
All DGL's built-in functions support both CPU and GPU execution as well as backward computation,
so they can be used in any ``autograd`` system. Also, built-in functions can be used not only in ``update_all``
or ``apply_edges`` as shown in the example, but wherever message and reduce functions are
required (e.g. ``pull``, ``push``, ``send_and_recv``).
All binary message functions support **broadcasting**, a mechanism for extending element-wise
operations to tensor inputs of different shapes. DGL generally follows the standard
broadcasting semantics of `NumPy <https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html>`_
and `PyTorch <https://pytorch.org/docs/stable/notes/broadcasting.html>`_. Below are some
examples:
.. code:: python
import dgl
import dgl.function as fn
import torch as th
g = ... # create a DGLGraph
# case 1
g.ndata['h'] = th.randn((g.num_nodes(), 10))
g.edata['w'] = th.randn((g.num_edges(), 1))
# OK, valid broadcasting between feature shapes (10,) and (1,)
g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h_new'))
g.ndata['h_new'] # shape: (g.num_nodes(), 10)
# case 2
g.ndata['h'] = th.randn((g.num_nodes(), 5, 10))
g.edata['w'] = th.randn((g.num_edges(), 10))
# OK, valid broadcasting between feature shapes (5, 10) and (10,)
g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h_new'))
g.ndata['h_new'] # shape: (g.num_nodes(), 5, 10)
# case 3
g.ndata['h'] = th.randn((g.num_nodes(), 5, 10))
g.edata['w'] = th.randn((g.num_edges(), 5))
# NOT OK, invalid broadcasting between feature shapes (5, 10) and (5,)
# because shapes are aligned from the right
g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h_new'))  # raises an error
# case 4
g.ndata['h1'] = th.randn((g.num_nodes(), 1, 10))
g.ndata['h2'] = th.randn((g.num_nodes(), 5, 1))
# OK, valid broadcasting between feature shapes (1, 10) and (5, 1)
g.apply_edges(fn.u_add_v('h1', 'h2', 'x')) # apply_edges also supports broadcasting
g.edata['x'] # shape: (g.num_edges(), 5, 10)
# case 5
g.ndata['h1'] = th.randn((g.num_nodes(), 1, 10, 128))
g.ndata['h2'] = th.randn((g.num_nodes(), 5, 1, 128))
# OK, u_dot_v supports broadcasting but requires the last dimension to match
g.apply_edges(fn.u_dot_v('h1', 'h2', 'x'))
g.edata['x'] # shape: (g.num_edges(), 5, 10, 1)
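Since built-in functions support backward computation, they can be used directly under an autograd system. A minimal sketch in PyTorch, reusing ``g`` from above:

.. code:: python

# gradients flow through the fused message-passing kernel
g.ndata['h'] = th.randn((g.num_nodes(), 10), requires_grad=True)
g.edata['w'] = th.randn((g.num_edges(), 1), requires_grad=True)
g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h_new'))
g.ndata['h_new'].sum().backward()  # populates .grad on the 'h' and 'w' tensors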
.. _api-built-in: