.. _apifunction:

.. currentmodule:: dgl.function

dgl.function
==================================

This subpackage hosts all the **built-in functions** provided by DGL. Built-in functions
are DGL's recommended way to express different types of :ref:`guide-message-passing` computation
(i.e., via :func:`~dgl.DGLGraph.update_all`) and to compute edge-wise features from
node-wise features (i.e., via :func:`~dgl.DGLGraph.apply_edges`). Built-in functions
describe node-wise and edge-wise computation symbolically, without performing any
actual computation, so DGL can analyze them and map them to efficient low-level kernels.
Here are some examples:

.. code:: python

   import dgl
   import dgl.function as fn
   import torch as th

   # create a small DGLGraph with three edges: 0->1, 1->2, 2->0
   g = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 0])))
   g.ndata['h'] = th.randn((g.num_nodes(), 10))  # each node has feature size 10
   g.edata['w'] = th.randn((g.num_edges(), 1))   # each edge has feature size 1

   # collect features from source nodes and aggregate them in destination nodes
   g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'h_sum'))
   # multiply source node features with edge weights and aggregate them in destination nodes
   g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.max('m', 'h_max'))
   # compute an edge feature by element-wise multiplying source and destination node features
   g.apply_edges(fn.u_mul_v('h', 'h', 'w_new'))

``fn.copy_u``, ``fn.u_mul_e``, and ``fn.u_mul_v`` are built-in message functions, while ``fn.sum``
and ``fn.max`` are built-in reduce functions. DGL's convention is to use ``u``, ``v``,
and ``e`` to represent source nodes, destination nodes, and edges, respectively.
For example, ``copy_u`` tells DGL to copy source node features as the messages;
``u_mul_e`` tells DGL to multiply source node features with edge features.

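The naming convention is regular enough to be parsed mechanically. The pure-Python sketch below uses a hypothetical helper, ``parse_builtin_name`` (not part of DGL), to split a built-in name into its operator and operands. It only covers the ``u``/``v``/``e`` style names and rejects legacy aliases such as ``copy_src``.

.. code:: python

   # Hypothetical helper for illustration only; DGL does not expose this function.
   TARGETS = {'u': 'source node', 'v': 'destination node', 'e': 'edge'}
   BINARY_OPS = {'add', 'sub', 'mul', 'div', 'dot'}

   def parse_builtin_name(name):
       """Split a u/v/e-style built-in name into (operator, operands)."""
       parts = name.split('_')
       # unary form: copy_<target>
       if len(parts) == 2 and parts[0] == 'copy' and parts[1] in TARGETS:
           return 'copy', (TARGETS[parts[1]],)
       # binary form: <lhs>_<op>_<rhs> with two distinct targets
       if len(parts) == 3:
           lhs, op, rhs = parts
           if lhs in TARGETS and rhs in TARGETS and lhs != rhs and op in BINARY_OPS:
               return op, (TARGETS[lhs], TARGETS[rhs])
       raise ValueError(f'not a u/v/e-style built-in name: {name!r}')

   print(parse_builtin_name('copy_u'))   # ('copy', ('source node',))
   print(parse_builtin_name('u_mul_e'))  # ('mul', ('source node', 'edge'))

Names such as ``copy_src`` or ``src_mul_edge`` raise ``ValueError`` here because they are aliases that predate the ``u``/``v``/``e`` convention.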
To define a unary message function (e.g., ``copy_u``), specify one input feature name and one output
message name. To define a binary message function (e.g., ``u_mul_e``), specify
two input feature names and one output message name. During computation,
the message function reads the data stored under the input names, performs the operation, and
stores the result under the output name. For example, ``fn.u_mul_e('h', 'w', 'm')`` above is
equivalent to the following user-defined function:

.. code:: python

   def udf_u_mul_e(edges):
       # multiply each edge's source node feature 'h' by its edge feature 'w'
       return {'m': edges.src['h'] * edges.data['w']}

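To make these semantics concrete without DGL, the NumPy sketch below emulates binary message functions on a graph stored as plain ``src``/``dst`` index arrays. The helper name ``binary_message`` and the edge-list representation are assumptions for illustration. DGL itself maps built-ins to efficient low-level kernels rather than running code like this.

.. code:: python

   import numpy as np

   # Element-wise operators; 'dot' sums over the last dimension and keeps it with size 1.
   OPS = {
       'add': np.add,
       'sub': np.subtract,
       'mul': np.multiply,
       'div': np.divide,
       'dot': lambda a, b: np.sum(a * b, axis=-1, keepdims=True),
   }

   def _operand(target, field, ndata, edata, src, dst):
       """Return one row per edge for target 'u', 'v' or 'e'."""
       if target == 'u':
           return ndata[field][src]  # feature of each edge's source node
       if target == 'v':
           return ndata[field][dst]  # feature of each edge's destination node
       if target == 'e':
           return edata[field]       # edge features are already one row per edge
       raise ValueError(f'unknown target: {target!r}')

   def binary_message(name, lhs_field, rhs_field, out_field, ndata, edata, src, dst):
       """Hypothetical emulation of e.g. fn.u_mul_e(lhs_field, rhs_field, out_field)."""
       lhs, op, rhs = name.split('_')
       a = _operand(lhs, lhs_field, ndata, edata, src, dst)
       b = _operand(rhs, rhs_field, ndata, edata, src, dst)
       return {out_field: OPS[op](a, b)}  # NumPy broadcasting applies here

   # toy graph with edges 0->1, 1->2, 2->0
   src, dst = np.array([0, 1, 2]), np.array([1, 2, 0])
   ndata = {'h': np.arange(6.0).reshape(3, 2)}     # 3 nodes, feature size 2
   edata = {'w': np.array([[1.0], [2.0], [3.0]])}  # 3 edges, feature size 1

   msgs = binary_message('u_mul_e', 'h', 'w', 'm', ndata, edata, src, dst)
   print(msgs['m'])  # rows: [0, 1], [4, 6], [12, 15]

Each message row is the source node's ``h`` scaled by that edge's ``w``, which is what ``udf_u_mul_e`` computes batch by batch.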
To define a reduce function, specify one input message name and one output node feature name.
For example, ``fn.max('m', 'h_max')`` above is equivalent to the following
user-defined function:

.. code:: python

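   def udf_max(nodes):
       # nodes.mailbox['m'] stacks the received messages along dim 1;
       # th.max returns (values, indices), so keep only the values
       return {'h_max': th.max(nodes.mailbox['m'], 1)[0]}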

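Reduce functions can be emulated the same way. The NumPy sketch below (``reduce_messages`` is a hypothetical name) scatters per-edge messages into their destination nodes with ``ufunc.at``, which accumulates correctly when several edges share a destination. Filling nodes that receive no messages with zeros is a choice made by this sketch.

.. code:: python

   import numpy as np

   def reduce_messages(msgs, dst, num_nodes, op):
       """Aggregate per-edge messages into destination nodes (sum/max/min/mean)."""
       msgs = np.asarray(msgs, dtype=np.float64)
       dst = np.asarray(dst)
       out_shape = (num_nodes,) + msgs.shape[1:]
       # in-degree of every node, i.e. the number of messages in its mailbox
       deg = np.bincount(dst, minlength=num_nodes)
       if op in ('sum', 'mean'):
           out = np.zeros(out_shape)
           np.add.at(out, dst, msgs)  # unbuffered scatter-add over repeated dst
           if op == 'mean':
               # divide by the in-degree; max(deg, 1) avoids 0/0 for nodes without messages
               denom = np.maximum(deg, 1).reshape((num_nodes,) + (1,) * (msgs.ndim - 1))
               out = out / denom
           return out
       if op == 'max':
           out = np.full(out_shape, -np.inf)
           np.maximum.at(out, dst, msgs)
       elif op == 'min':
           out = np.full(out_shape, np.inf)
           np.minimum.at(out, dst, msgs)
       else:
           raise ValueError(f'unknown reducer: {op!r}')
       out[deg == 0] = 0.0  # nodes with an empty mailbox get zeros in this sketch
       return out

   # three edges: edges 0 and 1 point to node 1, edge 2 points to node 2
   msgs = np.array([[1.0], [3.0], [2.0]])
   dst = np.array([1, 1, 2])
   for op in ('sum', 'max', 'min', 'mean'):
       print(op, reduce_messages(msgs, dst, 3, op).ravel())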
All binary message functions support **broadcasting**, a mechanism for extending element-wise
operations to tensor inputs with different shapes. Broadcasting is applied to the feature
dimensions only, i.e., excluding the leading node or edge dimension. DGL generally follows the standard
broadcasting semantics of `NumPy <https://docs.scipy.org/doc/numpy/user/basics.broadcasting.html>`_
and `PyTorch <https://pytorch.org/docs/stable/notes/broadcasting.html>`_. Below are some
examples:

.. code:: python

   import dgl
   import dgl.function as fn
   import torch as th
   g = dgl.graph((th.tensor([0, 1, 2]), th.tensor([1, 2, 0])))  # a small example DGLGraph

   # case 1
   g.ndata['h'] = th.randn((g.num_nodes(), 10))
   g.edata['w'] = th.randn((g.num_edges(), 1))
   # OK, valid broadcasting between feature shapes (10,) and (1,)
   g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h_new'))
   g.ndata['h_new']  # shape: (g.num_nodes(), 10)

   # case 2
   g.ndata['h'] = th.randn((g.num_nodes(), 5, 10))
   g.edata['w'] = th.randn((g.num_edges(), 10))
   # OK, valid broadcasting between feature shapes (5, 10) and (10,)
   g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h_new'))
   g.ndata['h_new']  # shape: (g.num_nodes(), 5, 10)

   # case 3
   g.ndata['h'] = th.randn((g.num_nodes(), 5, 10))
   g.edata['w'] = th.randn((g.num_edges(), 5))
   # NOT OK, invalid broadcasting between feature shapes (5, 10) and (5,):
   # shapes are aligned from the right, so the trailing dimensions 10 and 5 do not match
   g.update_all(fn.u_mul_e('h', 'w', 'm'), fn.sum('m', 'h_new'))

   # case 4
   g.ndata['h1'] = th.randn((g.num_nodes(), 1, 10))
   g.ndata['h2'] = th.randn((g.num_nodes(), 5, 1))
   # OK, valid broadcasting between feature shapes (1, 10) and (5, 1)
   g.apply_edges(fn.u_add_v('h1', 'h2', 'x'))  # apply_edges also supports broadcasting
   g.edata['x']  # shape: (g.num_edges(), 5, 10)

   # case 5
   g.ndata['h1'] = th.randn((g.num_nodes(), 1, 10, 128))
   g.ndata['h2'] = th.randn((g.num_nodes(), 5, 1, 128))
   # OK, u_dot_v supports broadcasting but requires the last dimension to match
   g.apply_edges(fn.u_dot_v('h1', 'h2', 'x'))
   g.edata['x']  # shape: (g.num_edges(), 5, 10, 1)

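The shape rules illustrated by these cases fit in a few lines. The following pure-Python sketch (``broadcast_feature_shape`` is a hypothetical name, not a DGL function) computes the output feature shape of a binary built-in from its two input feature shapes, excluding the leading node or edge dimension, and raises ``ValueError`` where the cases above say NOT OK.

.. code:: python

   def broadcast_feature_shape(op, lhs_shape, rhs_shape):
       """Output feature shape of a binary built-in, following NumPy-style rules."""
       lhs, rhs = tuple(lhs_shape), tuple(rhs_shape)
       if op == 'dot':
           # dot consumes the last dimension, which must match exactly
           if not lhs or not rhs or lhs[-1] != rhs[-1]:
               raise ValueError(f'dot needs equal last dimensions: {lhs_shape} vs {rhs_shape}')
           lhs, rhs = lhs[:-1], rhs[:-1]
       # align from the right by left-padding the shorter shape with 1s
       ndim = max(len(lhs), len(rhs))
       lhs = (1,) * (ndim - len(lhs)) + lhs
       rhs = (1,) * (ndim - len(rhs)) + rhs
       out = []
       for a, b in zip(lhs, rhs):
           if a != b and a != 1 and b != 1:
               raise ValueError(f'cannot broadcast {lhs_shape} with {rhs_shape}')
           out.append(b if a == 1 else a)
       if op == 'dot':
           out.append(1)  # the reduced dimension is kept with size 1
       return tuple(out)

   print(broadcast_feature_shape('mul', (10,), (1,)))                # (10,)
   print(broadcast_feature_shape('mul', (5, 10), (10,)))             # (5, 10)
   print(broadcast_feature_shape('add', (1, 10), (5, 1)))            # (5, 10)
   print(broadcast_feature_shape('dot', (1, 10, 128), (5, 1, 128)))  # (5, 10, 1)

Calling it with ``('mul', (5, 10), (5,))``, the shapes from case 3, raises ``ValueError`` because the trailing dimensions 10 and 5 are aligned against each other.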

.. _api-built-in:

DGL Built-in Functions
-------------------------

Here is a cheatsheet of all the DGL built-in functions.

+-------------------------+-----------------------------------------------------------------+-----------------------+
| Category                | Functions                                                       | Memo                  |
+=========================+=================================================================+=======================+
| Unary message function  | ``copy_u``                                                      |                       |
|                         +-----------------------------------------------------------------+-----------------------+
|                         | ``copy_e``                                                      |                       |
|                         +-----------------------------------------------------------------+-----------------------+
|                         | ``copy_src``                                                    |  alias of ``copy_u``  |
|                         +-----------------------------------------------------------------+-----------------------+
|                         | ``copy_edge``                                                   |  alias of ``copy_e``  |
+-------------------------+-----------------------------------------------------------------+-----------------------+
| Binary message function | ``u_add_v``, ``u_sub_v``, ``u_mul_v``, ``u_div_v``, ``u_dot_v`` |                       |
|                         +-----------------------------------------------------------------+-----------------------+
|                         | ``u_add_e``, ``u_sub_e``, ``u_mul_e``, ``u_div_e``, ``u_dot_e`` |                       |
|                         +-----------------------------------------------------------------+-----------------------+
|                         | ``v_add_u``, ``v_sub_u``, ``v_mul_u``, ``v_div_u``, ``v_dot_u`` |                       |
|                         +-----------------------------------------------------------------+-----------------------+
|                         | ``v_add_e``, ``v_sub_e``, ``v_mul_e``, ``v_div_e``, ``v_dot_e`` |                       |
|                         +-----------------------------------------------------------------+-----------------------+
|                         | ``e_add_u``, ``e_sub_u``, ``e_mul_u``, ``e_div_u``, ``e_dot_u`` |                       |
|                         +-----------------------------------------------------------------+-----------------------+
|                         | ``e_add_v``, ``e_sub_v``, ``e_mul_v``, ``e_div_v``, ``e_dot_v`` |                       |
|                         +-----------------------------------------------------------------+-----------------------+
|                         | ``src_mul_edge``                                                |  alias of ``u_mul_e`` |
+-------------------------+-----------------------------------------------------------------+-----------------------+
| Reduce function         | ``max``                                                         |                       |
|                         +-----------------------------------------------------------------+-----------------------+
|                         | ``min``                                                         |                       |
|                         +-----------------------------------------------------------------+-----------------------+
|                         | ``sum``                                                         |                       |
|                         +-----------------------------------------------------------------+-----------------------+
|                         | ``mean``                                                        |                       |
+-------------------------+-----------------------------------------------------------------+-----------------------+
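
Because the binary message functions follow the ``<lhs>_<op>_<rhs>`` convention over every ordered pair of distinct targets, the 30 binary names in the table (excluding the ``src_mul_edge`` alias) can be regenerated mechanically. A small sketch, where ``binary_builtin_names`` is a hypothetical helper:

.. code:: python

   from itertools import permutations

   TARGETS = ('u', 'v', 'e')
   OPS = ('add', 'sub', 'mul', 'div', 'dot')

   def binary_builtin_names():
       """All <lhs>_<op>_<rhs> names for ordered pairs of distinct targets."""
       return [f'{lhs}_{op}_{rhs}'
               for lhs, rhs in permutations(TARGETS, 2)  # (u, v), (u, e), (v, u), ...
               for op in OPS]

   names = binary_builtin_names()
   print(len(names))  # 30
   print(names[:5])   # ['u_add_v', 'u_sub_v', 'u_mul_v', 'u_div_v', 'u_dot_v']

The generated order matches the row order of the table. With DGL installed, one could additionally compare the list against the module, e.g. ``all(hasattr(dgl.function, n) for n in names)``; that check is not part of this sketch.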

Message functions
-----------------

.. autosummary::
    :toctree: ../../generated/

    copy_src
    copy_edge
    src_mul_edge
    copy_u
    copy_e
    u_add_v
    u_sub_v
    u_mul_v
    u_div_v
    u_add_e
    u_sub_e
    u_mul_e
    u_div_e
    v_add_u
    v_sub_u
    v_mul_u
    v_div_u
    v_add_e
    v_sub_e
    v_mul_e
    v_div_e
    e_add_u
    e_sub_u
    e_mul_u
    e_div_u
    e_add_v
    e_sub_v
    e_mul_v
    e_div_v
    u_dot_v
    u_dot_e
    v_dot_e
    v_dot_u
    e_dot_u
    e_dot_v

Reduce functions
----------------

.. autosummary::
    :toctree: ../../generated/

    sum
    max
    min
    mean