Unverified Commit 4ca706e1 authored by Minjie Wang, committed by GitHub

[Doc] Update doc of dgl.function (#2585)

* update doc of dgl.function

* small fix
parent 6ce2dc34
dgl.function
==================================
This subpackage hosts all the **built-in functions** provided by DGL. Built-in functions
are DGL's recommended way to express different types of :ref:`guide-message-passing`
computation (i.e., via :func:`~dgl.DGLGraph.update_all`) or to compute edge-wise features
from node-wise features (i.e., via :func:`~dgl.DGLGraph.apply_edges`). Built-in functions
describe the node-wise and edge-wise computation symbolically, without performing any
actual computation, so DGL can analyze them and map them to efficient low-level kernels.
Here are some examples:

.. code:: python

    import dgl.function as fn
    import torch as th

    g = ...  # create a DGLGraph
    g.ndata['h'] = th.randn((g.num_nodes(), 10))  # each node has feature size 10
    g.edata['w'] = th.randn((g.num_edges(), 1))   # each edge has feature size 1

    # collect features from source nodes and aggregate them in destination nodes
    g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_sum'))
    # multiply source node features with edge weights and aggregate them in destination nodes
    g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.max('m', 'h_max'))
    # compute edge features by multiplying source and destination node features
    g.apply_edges(fn.u_mul_v('h', 'h', 'w_new'))
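
The graph ``g`` above is a placeholder. To run the example end-to-end, one way to build
a concrete graph is sketched below (the three-node edge list is made up for illustration):

.. code:: python

    import dgl
    import dgl.function as fn
    import torch as th

    # a tiny graph with edges 0->1, 1->2 and 0->2
    g = dgl.graph((th.tensor([0, 1, 0]), th.tensor([1, 2, 2])))
    g.ndata['h'] = th.randn((g.num_nodes(), 10))
    g.edata['w'] = th.randn((g.num_edges(), 1))
    g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_sum'))
    g.ndata['h_sum'].shape  # torch.Size([3, 10]); node 0 has no in-edges, so its row is zero
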
``fn.copy_u``, ``fn.u_mul_e``, ``fn.u_mul_v`` are built-in message functions, while ``fn.sum``
and ``fn.max`` are built-in reduce functions. DGL's convention is to use ``u``, ``v``
and ``e`` to represent source nodes, destination nodes, and edges, respectively.
For example, ``copy_u`` tells DGL to copy the source node data as the messages;
``u_mul_e`` tells DGL to multiply source node features with edge features.
To define a unary message function (e.g. ``copy_u``), specify one input feature name
and one output message name. To define a binary message function (e.g. ``u_mul_e``),
specify two input feature names and one output message name. During the computation,
the message function will read the data under the given names, perform computation,
and return the output using the output name. For example, the above ``fn.u_mul_e('h', 'w', 'm')`` is
the same as the following user-defined function:
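
.. code:: python

    def udf_u_mul_e(edges):
        # multiply the source node feature 'h' with the edge feature 'w'
        # and emit the result as message 'm'
        return {'m' : edges.src['h'] * edges.data['w']}

To define a reduce function (e.g. ``max``), specify one input message name and one
output node feature name. For example, the above ``fn.max('m', 'h_max')`` is the same
as the following user-defined function:

.. code:: python
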
    def udf_max(nodes):
        return {'h_max' : th.max(nodes.mailbox['m'], 1)[0]}
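
Similarly, a unary built-in such as ``fn.copy_u('h', 'm')`` corresponds to a user-defined
message function along the following lines (a minimal sketch for illustration; the name
``udf_copy_u`` is hypothetical):

.. code:: python

    def udf_copy_u(edges):
        # emit the source node feature 'h' unchanged as message 'm'
        return {'m' : edges.src['h']}
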
All binary message functions support **broadcasting**, a mechanism for extending
element-wise operations to tensor inputs with different shapes. DGL generally follows
the standard broadcasting semantics of
`NumPy <https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html>`_ and
`PyTorch <https://pytorch.org/docs/stable/notes/broadcasting.html>`_. Below are some
examples:

.. code:: python

    import dgl
    import dgl.function as fn
    import torch as th

    g = ...  # create a DGLGraph

    # case 1
    g.ndata['h'] = th.randn((g.num_nodes(), 10))
    g.edata['w'] = th.randn((g.num_edges(), 1))
    # OK, valid broadcasting between feature shapes (10,) and (1,)
    g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h_new'))
    g.ndata['h_new']  # shape: (g.num_nodes(), 10)

    # case 2
    g.ndata['h'] = th.randn((g.num_nodes(), 5, 10))
    g.edata['w'] = th.randn((g.num_edges(), 10))
    # OK, valid broadcasting between feature shapes (5, 10) and (10,)
    g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h_new'))
    g.ndata['h_new']  # shape: (g.num_nodes(), 5, 10)

    # case 3
    g.ndata['h'] = th.randn((g.num_nodes(), 5, 10))
    g.edata['w'] = th.randn((g.num_edges(), 5))
    # NOT OK, invalid broadcasting between feature shapes (5, 10) and (5,)
    # because shapes are aligned from the right
    g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h_new'))

    # case 4
    g.ndata['h1'] = th.randn((g.num_nodes(), 1, 10))
    g.ndata['h2'] = th.randn((g.num_nodes(), 5, 1))
    # OK, valid broadcasting between feature shapes (1, 10) and (5, 1)
    g.apply_edges(fn.u_add_v('h1', 'h2', 'x'))  # apply_edges also supports broadcasting
    g.edata['x']  # shape: (g.num_edges(), 5, 10)

    # case 5
    g.ndata['h1'] = th.randn((g.num_nodes(), 1, 10, 128))
    g.ndata['h2'] = th.randn((g.num_nodes(), 5, 1, 128))
    # OK, u_dot_v supports broadcasting but requires the last dimension to match
    g.apply_edges(fn.u_dot_v('h1', 'h2', 'x'))
    g.edata['x']  # shape: (g.num_edges(), 5, 10, 1)
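
Since these shape rules mirror the frameworks' own broadcasting, the cases above can be
sanity-checked with plain PyTorch tensors (a quick sketch, independent of DGL; the dot
case is DGL-specific and not shown):

.. code:: python

    import torch as th

    (th.randn(5, 10) * th.randn(10)).shape    # torch.Size([5, 10]) -- matches case 2
    (th.randn(1, 10) + th.randn(5, 1)).shape  # torch.Size([5, 10]) -- matches case 4
    # th.randn(5, 10) * th.randn(5)           # RuntimeError -- matches case 3
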
.. _api-built-in: