Unverified Commit 05a43379 authored by Zihao Ye's avatar Zihao Ye Committed by GitHub

[Doc] Introduce the relation of message passing APIs and operators in docstring. (#1878)

* upd

* fix typo

* upd
parent 2fa2b453
......@@ -229,3 +229,46 @@ The following is an example showing how GSDDMM works:
copy_v
Like GSpMM, GSDDMM operators support both homogeneous and bipartite graphs.
Relation with Message Passing APIs
----------------------------------
``dgl.update_all`` and ``dgl.apply_edges`` calls with built-in message/reduce functions
are dispatched to calls of the operators defined in ``dgl.ops``:
>>> import dgl
>>> import torch as th
>>> import dgl.ops as F
>>> import dgl.function as fn
>>> g = dgl.rand_graph(100, 1000) # create a DGLGraph with 100 nodes and 1000 edges.
>>> x = th.rand(100, 20) # node features.
>>> e = th.rand(1000, 20)
>>>
>>> # dgl.update_all + builtin functions
>>> g.srcdata['x'] = x # srcdata is the same as ndata for graphs with one node type.
>>> g.edata['e'] = e
>>> g.update_all(fn.u_mul_e('x', 'e', 'm'), fn.sum('m', 'y'))
>>> y = g.dstdata['y'] # dstdata is the same as ndata for graphs with one node type.
>>>
>>> # use GSpMM operators defined in dgl.ops directly
>>> y = F.u_mul_e_sum(g, x, e)
>>>
>>> # dgl.apply_edges + builtin functions
>>> g.srcdata['x'] = x
>>> g.dstdata['y'] = y
>>> g.apply_edges(fn.u_dot_v('x', 'y', 'z'))
>>> z = g.edata['z']
>>>
>>> # use GSDDMM operators defined in dgl.ops directly
>>> z = F.u_dot_v(g, x, y)
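To see what the dispatched ``u_mul_e_sum`` call computes, here is a minimal NumPy sketch of the same math on a hypothetical toy graph (the ``src``/``dst`` arrays and sizes are assumptions, not DGL internals):

```python
import numpy as np

# Hypothetical toy graph: 4 nodes, 5 edges stored as (src, dst) pairs.
src = np.array([0, 1, 1, 2, 3])
dst = np.array([1, 2, 3, 3, 0])

x = np.random.rand(4, 20)   # node features
e = np.random.rand(5, 20)   # edge features

# GSpMM with op = mul, reduce = sum:
#   y[v] = sum over edges (u, v) of x[u] * e[(u, v)]
m = x[src] * e              # per-edge message (the u_mul_e step)
y = np.zeros((4, 20))
np.add.at(y, dst, m)        # sum-reduce messages by destination node
```

This is exactly the semantics of the ``update_all`` call above, just written without the fused kernel.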
It is up to the user to decide whether to use message-passing APIs or GSpMM/GSDDMM operators;
both have the same efficiency. Programs written with message-passing APIs look more DGL-style,
but in some cases calling GSpMM/GSDDMM operators is more concise (e.g. the `edge_softmax
<https://github.com/dmlc/dgl/blob/master/python/dgl/nn/pytorch/softmax.py/>`_ function
provided by DGL).
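``edge_softmax`` normalizes edge scores over the incoming edges of each destination node; expressed with message-passing primitives it is a max/subtract/exp/sum/divide pipeline. A minimal NumPy sketch of that math (the toy graph arrays are assumptions for illustration):

```python
import numpy as np

src = np.array([0, 1, 1, 2, 3])
dst = np.array([1, 2, 3, 3, 0])
s = np.random.rand(5)            # one score per edge

# Numerically stable softmax over incoming edges of each destination node.
smax = np.full(4, -np.inf)
np.maximum.at(smax, dst, s)      # per-destination max      (like fn.max)
exp = np.exp(s - smax[dst])      # shifted exponentiation   (like fn.e_sub_v + exp)
z = np.zeros(4)
np.add.at(z, dst, exp)           # per-destination sum      (like fn.sum)
a = exp / z[dst]                 # normalize each edge      (like fn.e_div_v)
```

Each step maps to one GSpMM or GSDDMM operator, which is why the single-call operator form is more concise here.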
Note that on PyTorch all operators defined in ``dgl.ops`` support higher-order gradients, and
so do the message passing APIs, because they are built entirely on these operators.
......@@ -16,7 +16,8 @@ def gsddmm(g, op, lhs_data, rhs_data, lhs_target='u', rhs_target='v'):
where :math:`x_{e}` is the returned feature on edges and :math:`x_u`,
:math:`x_v` refer to :attr:`u`, :attr:`v` respectively. :math:`\phi`
is the binary operator :attr:`op`, and :math:`\mathcal{G}` is the graph
we apply gsddmm on: :attr:`g`. :math:`lhs` and :math:`rhs` are one of
:math:`u`, :math:`v` and :math:`e`.
Parameters
----------
......
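The formula in the docstring above says gsddmm produces one value per edge by gathering the lhs/rhs features along the graph and combining them with :attr:`op`. A minimal NumPy sketch for the case ``lhs_target='u'``, ``rhs_target='v'``, ``op='dot'`` (the toy graph arrays are assumptions):

```python
import numpy as np

src = np.array([0, 1, 1, 2, 3])
dst = np.array([1, 2, 3, 3, 0])
x_u = np.random.rand(4, 20)      # features on source nodes
x_v = np.random.rand(4, 20)      # features on destination nodes

# x_e = phi(x_u, x_v) with phi = dot: one scalar per edge.
x_e = (x_u[src] * x_v[dst]).sum(axis=1)
```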
......@@ -8,13 +8,14 @@ __all__ = ['gspmm']
def gspmm(g, op, reduce_op, lhs_data, rhs_data):
r""" Generalized Sparse Matrix Multiplication interface.
It fuses two steps into one kernel.
1. Computes messages by applying the binary operator :attr:`op` to source node and edge features.
2. Aggregates the messages by the reduce operator :attr:`reduce_op` as the features on destination nodes.
.. math::
x_v = \psi_{(u, v, e)\in \mathcal{G}}(\rho(x_u, x_e))
where :math:`x_v` is the returned feature on destination nodes, and :math:`x_u`,
:math:`x_e` refer to :attr:`u`, :attr:`e` respectively. :math:`\rho` is the binary
operator :attr:`op` and :math:`\psi` is the reduce operator :attr:`reduce_op`,
:math:`\mathcal{G}` is the graph we apply gspmm on: :attr:`g`.
......
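The two fused steps named in the gspmm docstring can be sketched in plain NumPy; here with :math:`\rho` = ``mul`` and :math:`\psi` = ``max`` to contrast with the sum-reduce used earlier (the toy graph arrays are assumptions):

```python
import numpy as np

src = np.array([0, 1, 1, 2, 3])
dst = np.array([1, 2, 3, 3, 0])
x = np.random.rand(4, 8)         # source node features
e = np.random.rand(5, 8)         # edge features

# Step 1: messages rho(x_u, x_e) with rho = mul.
m = x[src] * e
# Step 2: psi-reduce by destination with psi = max.
y = np.full((4, 8), -np.inf)
np.maximum.at(y, dst, m)
y[np.isinf(y)] = 0.0             # nodes with no in-edges get zeros
```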