.. _apifunction:

.. currentmodule:: dgl.function

dgl.function
==================================

In DGL, message passing is mainly expressed through ``update_all(message_func, reduce_func)``.
This API computes a message on every edge and sends it to the edge's destination node; each node
that receives messages aggregates them and updates its own node data.
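
As a point of reference, the semantics of ``update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_sum'))``
can be sketched in plain Python over an edge list. This is a minimal illustration under stated
assumptions, not DGL's implementation: the graph is given as two hypothetical ``src``/``dst`` lists,
node features are plain lists of floats, and both helper names are made up. The first helper builds
one message per edge before aggregating; the second accumulates straight into the destination nodes
and never stores per-edge messages, which is the idea behind the fused kernels discussed below.

.. code:: python

   from typing import List

   Feature = List[float]


   def update_all_copy_u_sum_unfused(src: List[int], dst: List[int],
                                     h: List[Feature]) -> List[Feature]:
       """Reference semantics: build every message first, then reduce per node."""
       num_nodes, dim = len(h), len(h[0])
       # message phase: one message per edge (copy_u copies the source feature)
       messages = [list(h[u]) for u in src]
       # reduce phase: sum the messages arriving at each destination node;
       # in this sketch, nodes with no incoming edges keep a zero vector
       out = [[0.0] * dim for _ in range(num_nodes)]
       for v, m in zip(dst, messages):
           for k in range(dim):
               out[v][k] += m[k]
       return out


   def update_all_copy_u_sum_fused(src: List[int], dst: List[int],
                                   h: List[Feature]) -> List[Feature]:
       """Same result, but no per-edge message buffer is ever stored."""
       num_nodes, dim = len(h), len(h[0])
       out = [[0.0] * dim for _ in range(num_nodes)]
       for u, v in zip(src, dst):
           for k in range(dim):
               out[v][k] += h[u][k]  # generate and aggregate in one step
       return out


   # a 3-node graph with edges 0->2, 1->2 and 2->0
   src, dst = [0, 1, 2], [2, 2, 0]
   h = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
   print(update_all_copy_u_sum_unfused(src, dst, h))
   print(update_all_copy_u_sum_fused(src, dst, h))

Both calls print ``[[5.0, 6.0], [0.0, 0.0], [4.0, 6.0]]``: node 0 receives the feature of node 2,
node 1 receives nothing, and node 2 sums the features of nodes 0 and 1.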

Internally, DGL can fuse message generation and aggregation into one kernel so that no
explicit messages are generated and stored. To benefit from this, we recommend using our **built-in
message and reduce functions**, which DGL can analyze and map to dedicated fused kernels. Here
are some examples (in PyTorch syntax).

.. code:: python

   import dgl
   import dgl.function as fn
   import torch as th

   # a small example graph (a 4-node directed cycle); any DGLGraph works here
   g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 0]))
   g.ndata['h'] = th.randn((g.number_of_nodes(), 10))  # each node has a feature of size 10
   g.edata['w'] = th.randn((g.number_of_edges(), 1))   # each edge has a feature of size 1
   # collect features from source nodes and sum them up in destination nodes
   g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_sum'))
   # multiply source node features by edge weights and take the max in destination nodes
   g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.max('m', 'h_max'))
   # compute an edge feature by multiplying source and destination node features element-wise
   g.apply_edges(fn.u_mul_v('h', 'h', 'w_new'))

``fn.copy_u``, ``fn.u_mul_e`` and ``fn.u_mul_v`` are built-in message functions, while ``fn.sum``
and ``fn.max`` are built-in reduce functions. The letters ``u``, ``v`` and ``e`` denote
source nodes, destination nodes, and the edges between them, respectively. For example, ``copy_u``
copies source node data as the messages, and ``u_mul_e`` multiplies source node features by edge features.
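
To make the naming convention concrete, the hypothetical helper below (it is not part of DGL's API)
splits a binary built-in name into its two operand targets and its operator. It is a plain-Python
sketch that only encodes the convention described above, assuming names of the form
``<target>_<op>_<target>`` with two different targets and one of the operators ``add``, ``sub``,
``mul``, ``div`` and ``dot``.

.. code:: python

   from typing import Tuple

   # what each target letter refers to in a built-in name
   TARGETS = {"u": "source node", "v": "destination node", "e": "edge"}
   OPS = {"add", "sub", "mul", "div", "dot"}


   def parse_binary_name(name: str) -> Tuple[str, str, str]:
       """Split a name such as 'u_mul_e' into (lhs target, operator, rhs target)."""
       parts = name.split("_")
       if len(parts) != 3:
           raise ValueError(f"not a binary built-in name: {name!r}")
       lhs, op, rhs = parts
       # both targets must be known letters, they must differ, and the operator must be known
       if lhs not in TARGETS or rhs not in TARGETS or lhs == rhs or op not in OPS:
           raise ValueError(f"not a binary built-in name: {name!r}")
       return TARGETS[lhs], op, TARGETS[rhs]


   print(parse_binary_name("u_mul_e"))  # ('source node', 'mul', 'edge')
   print(parse_binary_name("v_dot_u"))  # ('destination node', 'dot', 'source node')

Unary names such as ``copy_u`` and legacy aliases such as ``src_mul_edge`` are rejected by this
sketch, since they do not follow the three-part pattern.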

To define a unary message function (e.g. ``copy_u``), specify one input feature name and one output
message name. To define a binary message function (e.g. ``u_mul_e``), specify
two input feature names and one output message name. During the computation,
the message function reads the data under the given input names, performs the computation, and returns
the result under the output name. For example, the above ``fn.u_mul_e('h', 'w', 'm')`` is
equivalent to the following user-defined function:

.. code:: python

   def udf_u_mul_e(edges):
      return {'m' : edges.src['h'] * edges.data['w']}

To define a reduce function, specify one input message name and one output node feature name.
For example, the above ``fn.max('m', 'h_max')`` is equivalent to the
following user-defined function:

.. code:: python

   def udf_max(nodes):
      return {'h_max' : th.max(nodes.mailbox['m'], 1)[0]}
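
In the UDF, ``nodes.mailbox['m']`` stacks the incoming messages of a batch of nodes along
dimension 1, which is why ``th.max`` reduces over that axis. Because such a stack needs every node
in the batch to have the same number of messages, DGL groups nodes by in-degree before calling a
UDF reduce function (degree bucketing). The pure-Python sketch below illustrates that layout under
simplified assumptions: messages are plain lists of floats aligned with a hypothetical ``dst`` list,
the helper name is made up, and nodes without incoming messages are simply omitted. A tensor
implementation would hand each bucket's mailbox to the UDF in one call; here the max is taken node
by node for readability.

.. code:: python

   from collections import defaultdict
   from typing import Dict, List

   Feature = List[float]


   def reduce_max_by_degree(dst: List[int], msgs: List[Feature]) -> Dict[int, Feature]:
       """Element-wise max over each node's incoming messages, grouped by in-degree."""
       # gather the messages arriving at each destination node (its mailbox)
       inbox: Dict[int, List[Feature]] = defaultdict(list)
       for v, m in zip(dst, msgs):
           inbox[v].append(m)
       # group nodes by in-degree so that every bucket forms a regular
       # (num_nodes_in_bucket, degree, feat_dim) block, like nodes.mailbox['m']
       buckets: Dict[int, List[int]] = defaultdict(list)
       for v, box in inbox.items():
           buckets[len(box)].append(v)
       out: Dict[int, Feature] = {}
       for nodes in buckets.values():
           mailbox = [inbox[v] for v in nodes]
           for v, box in zip(nodes, mailbox):
               # reduce over the message axis (dimension 1 of the bucket's mailbox)
               out[v] = [max(column) for column in zip(*box)]
       return out


   # edges 0->2, 1->2 and 2->0, with one 2-dimensional message per edge
   dst = [2, 2, 0]
   msgs = [[1.0, 4.0], [3.0, 2.0], [5.0, 6.0]]
   print(reduce_max_by_degree(dst, msgs))  # {2: [3.0, 4.0], 0: [5.0, 6.0]}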

Broadcasting is supported for binary message functions, which means that tensor arguments of
different shapes can be automatically expanded to equal sizes. The supported broadcasting semantics
are standard and match `NumPy <https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html>`_
and `PyTorch <https://pytorch.org/docs/stable/notes/broadcasting.html>`_. If you are not familiar
with broadcasting, see the linked topics to learn more. In the
above example, ``fn.u_mul_e`` performs a broadcasted multiplication automatically because
the node feature ``'h'`` (size 10 per node) and the edge feature ``'w'`` (size 1 per edge) have
different but broadcastable shapes.
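
The broadcasting rule itself is small enough to sketch. The helper below is hypothetical (NumPy
provides ``numpy.broadcast_shapes`` and PyTorch provides ``torch.broadcast_shapes`` for the same
purpose); it applies the NumPy rule to per-node and per-edge feature shapes, that is, shapes
without the leading node or edge dimension. Read that way, ``'h'`` in the example above has
feature shape ``(10,)`` and ``'w'`` has feature shape ``(1,)``.

.. code:: python

   from itertools import zip_longest
   from typing import Tuple


   def broadcast_shape(a: Tuple[int, ...], b: Tuple[int, ...]) -> Tuple[int, ...]:
       """Result shape of broadcasting shapes a and b under NumPy's rules."""
       result = []
       # compare dimensions from the trailing end; missing dimensions count as 1
       for x, y in zip_longest(reversed(a), reversed(b), fillvalue=1):
           if x == y or y == 1:
               result.append(x)
           elif x == 1:
               result.append(y)
           else:
               raise ValueError(f"shapes {a} and {b} cannot be broadcast")
       return tuple(reversed(result))


   print(broadcast_shape((10,), (1,)))        # (10,): node feature 'h' times edge feature 'w'
   print(broadcast_shape((4, 1, 3), (5, 1)))  # (4, 5, 3)

Shapes such as ``(2, 3)`` and ``(4, 3)`` raise an error, because their leading dimensions differ
and neither is 1.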

All DGL built-in functions support both CPU and GPU as well as backward computation, so they
can be used with the autograd of the backend framework. Besides ``update_all`` and ``apply_edges``
as shown in the example, built-in functions can be used wherever message and reduce functions are
required (e.g. ``pull``, ``push``, ``send_and_recv``).

.. _api-built-in:

DGL Built-in Function
-------------------------

Here is a cheatsheet of all the DGL built-in functions.

+-------------------------+-----------------------------------------------------------------+-----------------------+
| Category                | Functions                                                       | Memo                  |
+=========================+=================================================================+=======================+
| Unary message function  | ``copy_u``                                                      |                       |
|                         +-----------------------------------------------------------------+-----------------------+
|                         | ``copy_e``                                                      |                       |
|                         +-----------------------------------------------------------------+-----------------------+
|                         | ``copy_src``                                                    |  alias of ``copy_u``  |
|                         +-----------------------------------------------------------------+-----------------------+
|                         | ``copy_edge``                                                   |  alias of ``copy_e``  |
+-------------------------+-----------------------------------------------------------------+-----------------------+
| Binary message function | ``u_add_v``, ``u_sub_v``, ``u_mul_v``, ``u_div_v``, ``u_dot_v`` |                       |
|                         +-----------------------------------------------------------------+-----------------------+
|                         | ``u_add_e``, ``u_sub_e``, ``u_mul_e``, ``u_div_e``, ``u_dot_e`` |                       |
|                         +-----------------------------------------------------------------+-----------------------+
|                         | ``v_add_u``, ``v_sub_u``, ``v_mul_u``, ``v_div_u``, ``v_dot_u`` |                       |
|                         +-----------------------------------------------------------------+-----------------------+
|                         | ``v_add_e``, ``v_sub_e``, ``v_mul_e``, ``v_div_e``, ``v_dot_e`` |                       |
|                         +-----------------------------------------------------------------+-----------------------+
|                         | ``e_add_u``, ``e_sub_u``, ``e_mul_u``, ``e_div_u``, ``e_dot_u`` |                       |
|                         +-----------------------------------------------------------------+-----------------------+
|                         | ``e_add_v``, ``e_sub_v``, ``e_mul_v``, ``e_div_v``, ``e_dot_v`` |                       |
|                         +-----------------------------------------------------------------+-----------------------+
|                         | ``src_mul_edge``                                                |  alias of ``u_mul_e`` |
+-------------------------+-----------------------------------------------------------------+-----------------------+
| Reduce function         | ``max``                                                         |                       |
|                         +-----------------------------------------------------------------+-----------------------+
|                         | ``min``                                                         |                       |
|                         +-----------------------------------------------------------------+-----------------------+
|                         | ``sum``                                                         |                       |
|                         +-----------------------------------------------------------------+-----------------------+
|                         | ``mean``                                                        |                       |
+-------------------------+-----------------------------------------------------------------+-----------------------+
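
The binary rows of the table follow a regular pattern: every ordered pair of two different targets
among ``u``, ``v`` and ``e``, combined with one of the five operators, which gives 6 × 5 = 30
functions (aliases such as ``src_mul_edge`` aside). The short sketch below regenerates those names
without DGL, purely as a check of that pattern.

.. code:: python

   from itertools import permutations

   OPS = ["add", "sub", "mul", "div", "dot"]

   # every ordered pair of two different targets, combined with every operator
   binary_names = [f"{lhs}_{op}_{rhs}"
                   for lhs, rhs in permutations("uve", 2)
                   for op in OPS]
   print(len(binary_names))  # 30
   print(binary_names[:5])   # ['u_add_v', 'u_sub_v', 'u_mul_v', 'u_div_v', 'u_dot_v']

``itertools.permutations`` yields the pairs in the order ``(u, v)``, ``(u, e)``, ``(v, u)``,
``(v, e)``, ``(e, u)``, ``(e, v)``, which is also the row order of the table.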

Message functions
-----------------

.. autosummary::
    :toctree: ../../generated/

    copy_src
    copy_edge
    src_mul_edge
    copy_u
    copy_e
    u_add_v
    u_sub_v
    u_mul_v
    u_div_v
    u_add_e
    u_sub_e
    u_mul_e
    u_div_e
    v_add_u
    v_sub_u
    v_mul_u
    v_div_u
    v_add_e
    v_sub_e
    v_mul_e
    v_div_e
    e_add_u
    e_sub_u
    e_mul_u
    e_div_u
    e_add_v
    e_sub_v
    e_mul_v
    e_div_v
    u_dot_v
    u_dot_e
    v_dot_e
    v_dot_u
    e_dot_u
    e_dot_v

Reduce functions
----------------

.. autosummary::
    :toctree: ../../generated/

    sum
    max
    min
    mean