"src/vscode:/vscode.git/clone" did not exist on "7d6f30e89ba3460dd26235c298c54d2ddb9d1590"
nn.rst 14.7 KB
Newer Older
1
2
.. _guide-nn:

Chapter 3: Building GNN Modules
=====================================

DGL NN module is the building block for your GNN model. It inherits
from `Pytorch’s NN Module <https://pytorch.org/docs/1.2.0/_modules/torch/nn/modules/module.html>`__, `MXNet Gluon’s NN Block  <http://mxnet.incubator.apache.org/versions/1.6/api/python/docs/api/gluon/nn/index.html>`__ and `TensorFlow’s Keras
Layer <https://www.tensorflow.org/api_docs/python/tf/keras/layers>`__, depending on the DNN framework backend in use. In a DGL NN
module, the parameter registration in the construction function and the tensor
operations in the forward function are the same as in the backend framework.
In this way, DGL code can be seamlessly integrated into the backend
framework code. The major difference lies in the message passing
operations that are unique to DGL.

DGL has integrated many commonly used
:ref:`apinn-pytorch-conv`, :ref:`apinn-pytorch-dense-conv`, :ref:`apinn-pytorch-pooling`,
and
:ref:`apinn-pytorch-util`. We welcome your contribution!

In this section, we will use
:class:`~dgl.nn.pytorch.conv.SAGEConv`
with the Pytorch backend as an example to introduce how to build your own
DGL NN module.

DGL NN Module Construction Function
-----------------------------------

The construction function will do the following:

1. Set options.
2. Register learnable parameters or submodules.
3. Reset parameters.

.. code::

    import torch as th
    from torch import nn
    from torch.nn import init
    import torch.nn.functional as F  # used as F.relu in the pooling aggregator below

    from .... import function as fn
    from ....base import DGLError
    from ....utils import expand_as_pair, check_eq_shape

    class SAGEConv(nn.Module):
        def __init__(self,
                     in_feats,
                     out_feats,
                     aggregator_type,
                     bias=True,
                     norm=None,
                     activation=None):
            super(SAGEConv, self).__init__()

            self._in_src_feats, self._in_dst_feats = expand_as_pair(in_feats)
            self._out_feats = out_feats
            self._aggre_type = aggregator_type
            self.norm = norm
            self.activation = activation

In the construction function, we first need to set the data dimensions. For
a general Pytorch module, the dimensions are usually the input dimension,
the output dimension, and the hidden dimensions. For graph neural networks,
the input dimension can be split into a source node dimension and a
destination node dimension.
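
For example, here is a hedged sketch of how ``in_feats`` is expanded into
source and destination dimensions (assuming ``expand_as_pair`` is importable
from ``dgl.utils``; the full function is shown later in this chapter):

.. code::

    from dgl.utils import expand_as_pair

    # A single integer is used for both source and destination dimensions,
    # while a tuple already specifies them separately.
    expand_as_pair(5)        # -> (5, 5)
    expand_as_pair((5, 10))  # -> (5, 10)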

Besides data dimensions, a typical option for a graph neural network is the
aggregation type (``self._aggre_type``). The aggregation type determines how
messages on different edges are aggregated for a certain destination
node. Commonly used aggregation types include ``mean``, ``sum``,
``max``, and ``min``. Some modules may apply more complicated aggregation
such as an ``lstm``.

``norm`` here is a callable function for feature normalization. In the
GraphSAGE paper, such normalization can be the l2 norm:
:math:`h_v = h_v / \lVert h_v \rVert_2`.
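
For instance, a minimal sketch of such a callable (purely illustrative, not
part of SAGEConv itself) is an l2 normalization over the feature dimension:

.. code::

    import torch as th

    def l2_norm(h, eps=1e-12):
        # Normalize each node feature vector to unit l2 length.
        return h / th.clamp(h.norm(p=2, dim=-1, keepdim=True), min=eps)

    # It can then be passed in as SAGEConv(..., norm=l2_norm).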

.. code::

            # aggregator type: mean, max_pool, lstm, gcn
            if aggregator_type not in ['mean', 'max_pool', 'lstm', 'gcn']:
                raise KeyError('Aggregator type {} not supported.'.format(aggregator_type))
            if aggregator_type == 'max_pool':
                self.fc_pool = nn.Linear(self._in_src_feats, self._in_src_feats)
            if aggregator_type == 'lstm':
                self.lstm = nn.LSTM(self._in_src_feats, self._in_src_feats, batch_first=True)
            if aggregator_type in ['mean', 'max_pool', 'lstm']:
                self.fc_self = nn.Linear(self._in_dst_feats, out_feats, bias=bias)
            self.fc_neigh = nn.Linear(self._in_src_feats, out_feats, bias=bias)
            self.reset_parameters()

The code above registers parameters and submodules. In SAGEConv, the
submodules vary according to the aggregation type. Those modules are pure
Pytorch nn modules like ``nn.Linear``, ``nn.LSTM``, etc. At the end of the
construction function, weight initialization is applied by calling
``reset_parameters()``.

.. code::

        def reset_parameters(self):
            """Reinitialize learnable parameters."""
            gain = nn.init.calculate_gain('relu')
            if self._aggre_type == 'max_pool':
                nn.init.xavier_uniform_(self.fc_pool.weight, gain=gain)
            if self._aggre_type == 'lstm':
                self.lstm.reset_parameters()
            if self._aggre_type != 'gcn':
                nn.init.xavier_uniform_(self.fc_self.weight, gain=gain)
            nn.init.xavier_uniform_(self.fc_neigh.weight, gain=gain)
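
Since the submodules are plain Pytorch modules, the registered parameters can
be inspected with the usual Pytorch utilities. A small sketch (the sizes are
illustrative, and the exact parameter names may differ across DGL versions):

.. code::

    from dgl.nn.pytorch import SAGEConv

    conv = SAGEConv(in_feats=5, out_feats=2, aggregator_type='mean')
    for name, param in conv.named_parameters():
        print(name, tuple(param.shape))
    # prints entries like fc_self.weight, fc_neigh.weight, ...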

DGL NN Module Forward Function
----------------------------------

In an NN module, the ``forward()`` function does the actual message passing and
computation. Compared with Pytorch’s NN module, which usually takes
tensors as its parameters, the DGL NN module takes an additional parameter
:class:`dgl.DGLGraph`. The
workload of the ``forward()`` function can be split into three parts:

-  Graph checking and graph type specification.

-  Message passing and reducing.

-  Update feature after reducing for output.

Let’s dive deep into the ``forward()`` function of the SAGEConv example.

Graph checking and graph type specification
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. code::

        def forward(self, graph, feat):
            with graph.local_scope():
                # Specify graph type then expand input feature according to graph type
                feat_src, feat_dst = expand_as_pair(feat, graph)

``forward()`` needs to handle many corner cases on the input that can
lead to invalid values in computing and message passing. One typical check in
conv modules like :class:`~dgl.nn.pytorch.conv.GraphConv` is to verify that
the input graph has no 0-in-degree nodes. When a node has 0 in-degree, its
``mailbox`` will be empty and the reduce function will produce all-zero
values, which may cause a silent regression in model performance. However, in
the :class:`~dgl.nn.pytorch.conv.SAGEConv` module, the aggregated
representation is concatenated with the original node feature, so the output
of ``forward()`` will not be all-zero and no such check is needed in this case.
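
For reference, a minimal sketch of such a check (along the lines of what
``GraphConv`` does; the exact error message is illustrative) could look like:

.. code::

    # Assuming ``graph`` is the input DGLGraph and DGLError is imported as in
    # the construction code above; raise early instead of silently producing
    # all-zero outputs for 0-in-degree nodes.
    if (graph.in_degrees() == 0).any():
        raise DGLError('There are 0-in-degree nodes in the graph, '
                       'output for those nodes will be invalid.')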

DGL NN modules should be reusable across different types of graph input,
including: homogeneous graphs, heterogeneous
graphs (:ref:`guide-graph-heterogeneous`), and subgraph
blocks (:ref:`guide-minibatch`).

The math formulas for SAGEConv are:

.. math::


   h_{\mathcal{N}(dst)}^{(l+1)}  = \mathrm{aggregate}
           \left(\{h_{src}^{(l)}, \forall src \in \mathcal{N}(dst) \}\right)

.. math::

    h_{dst}^{(l+1)} = \sigma \left(W \cdot \mathrm{concat}
           (h_{dst}^{(l)}, h_{\mathcal{N}(dst)}^{(l+1)}) + b \right)

.. math::

    h_{dst}^{(l+1)} = \mathrm{norm}(h_{dst}^{(l+1)})

We need to specify the source node feature ``feat_src`` and the destination
node feature ``feat_dst`` according to the graph type. The function that
specifies the graph type and expands ``feat`` into ``feat_src`` and
``feat_dst`` is
``expand_as_pair()``.
The details of this function are shown below.

.. code::

    # In the DGL source, ``Mapping`` comes from ``collections.abc`` and ``F``
    # refers to DGL's tensor backend (which provides ``narrow_row``).
    def expand_as_pair(input_, g=None):
        if isinstance(input_, tuple):
            # Bipartite graph case
            return input_
        elif g is not None and g.is_block:
            # Subgraph block case
            if isinstance(input_, Mapping):
                input_dst = {
                    k: F.narrow_row(v, 0, g.number_of_dst_nodes(k))
                    for k, v in input_.items()}
            else:
                input_dst = F.narrow_row(input_, 0, g.number_of_dst_nodes())
            return input_, input_dst
        else:
            # Homograph case
            return input_, input_

For whole-graph training on a homogeneous graph, the source nodes and
destination nodes are the same: they are all the nodes in the graph.

For the heterogeneous case, the graph can be split into several bipartite
graphs, one for each relation. The relations are represented as
``(src_type, edge_type, dst_type)``. When the input feature
``feat`` is a tuple, the graph is treated as bipartite. The first
element in the tuple is the source node feature and the second
element is the destination node feature.

In mini-batch training, the computation is applied on a subgraph sampled
for a given batch of destination nodes. The subgraph is called a
``block`` in DGL. After message passing, only those destination nodes
are updated, since they have the same neighborhood as they do in the
original full graph. In the block creation phase,
``dst nodes`` are placed at the front of the node list, so we can find
``feat_dst`` by the index ``[0:g.number_of_dst_nodes()]``.

After determining ``feat_src`` and ``feat_dst``, the computation for the
above three graph types is the same.
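
As a rough usage sketch (the graphs and feature sizes below are made-up
examples), the same SAGEConv module can consume a homogeneous graph with a
single feature tensor, or a bipartite graph with a ``(feat_src, feat_dst)``
tuple:

.. code::

    import dgl
    import torch as th
    from dgl.nn.pytorch import SAGEConv

    conv = SAGEConv(in_feats=5, out_feats=2, aggregator_type='mean')

    # Homogeneous graph: feat_src and feat_dst are the same tensor.
    g = dgl.graph(([0, 1, 2, 3], [1, 2, 3, 4]))
    h = conv(g, th.randn(g.num_nodes(), 5))

    # Bipartite graph (one relation): pass a (feat_src, feat_dst) tuple.
    bg = dgl.heterograph({('user', 'plays', 'game'): ([0, 1, 2], [0, 1, 1])})
    h_game = conv(bg, (th.randn(3, 5), th.randn(2, 5)))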

Message passing and reducing
~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. code::

                h_self = feat_dst  # keep the original destination features for the self connection
                if self._aggre_type == 'mean':
                    graph.srcdata['h'] = feat_src
                    graph.update_all(fn.copy_u('h', 'm'), fn.mean('m', 'neigh'))
                    h_neigh = graph.dstdata['neigh']
                elif self._aggre_type == 'gcn':
                    check_eq_shape(feat)
                    graph.srcdata['h'] = feat_src
                    graph.dstdata['h'] = feat_dst     # same as above if homogeneous
                    graph.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'neigh'))
                    # divide in_degrees
                    degs = graph.in_degrees().to(feat_dst)
                    h_neigh = (graph.dstdata['neigh'] + graph.dstdata['h']) / (degs.unsqueeze(-1) + 1)
                elif self._aggre_type == 'max_pool':
                    graph.srcdata['h'] = F.relu(self.fc_pool(feat_src))
                    graph.update_all(fn.copy_u('h', 'm'), fn.max('m', 'neigh'))
                    h_neigh = graph.dstdata['neigh']
                else:
                    raise KeyError('Aggregator type {} not recognized.'.format(self._aggre_type))

                # GraphSAGE GCN does not require fc_self.
                if self._aggre_type == 'gcn':
                    rst = self.fc_neigh(h_neigh)
                else:
                    rst = self.fc_self(h_self) + self.fc_neigh(h_neigh)

This code performs the actual message passing and reducing computation. This
part of the code varies from module to module. Note that all the message
passing in the above code is implemented using the ``update_all()`` API and
``built-in`` message/reduce functions to fully utilize DGL’s performance
optimization, as described in :ref:`guide-message-passing`.

Update feature after reducing for output
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. code::

                # activation
                if self.activation is not None:
                    rst = self.activation(rst)
                # normalization
                if self.norm is not None:
                    rst = self.norm(rst)
                return rst

The last part of the ``forward()`` function updates the features after
the reduce function. Common update operations include applying the
activation function and normalization, according to the options set in the
object construction phase.
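
Putting the options together, here is a short sketch of constructing and
calling the module with an activation and a norm callable (the sizes and the
norm are illustrative assumptions):

.. code::

    import dgl
    import torch as th
    import torch.nn.functional as F
    from dgl.nn.pytorch import SAGEConv

    # The activation is applied first, then the norm, as in forward() above.
    conv = SAGEConv(in_feats=5, out_feats=2, aggregator_type='mean',
                    activation=F.relu,
                    norm=lambda h: F.normalize(h, p=2, dim=-1))

    g = dgl.graph(([0, 1, 2], [1, 2, 0]))
    out = conv(g, th.randn(g.num_nodes(), 5))   # shape: (3, 2)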

Heterogeneous GraphConv Module
------------------------------

:class:`dgl.nn.pytorch.HeteroGraphConv`
is a module-level encapsulation to run DGL NN modules on heterogeneous
graphs. The implementation logic is the same as the message passing level API
``multi_update_all()``:

-  apply a DGL NN module within each relation :math:`r`;
-  apply a reduction that merges the results on the same node type from
   multiple relations.

This can be formulated as:

.. math::  h_{dst}^{(l+1)} = \underset{r\in\mathcal{R}, r_{dst}=dst}{AGG} (f_r(g_r, h_{r_{src}}^l, h_{r_{dst}}^l))

where :math:`f_r` is the NN module for each relation :math:`r`,
:math:`AGG` is the aggregation function.

HeteroGraphConv implementation logic
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. code::

    class HeteroGraphConv(nn.Module):
        def __init__(self, mods, aggregate='sum'):
            super(HeteroGraphConv, self).__init__()
            self.mods = nn.ModuleDict(mods)
            if isinstance(aggregate, str):
                self.agg_fn = get_aggregate_fn(aggregate)
            else:
                self.agg_fn = aggregate

The heterograph convolution takes a dictionary ``mods`` that maps each
relation to an NN module, and sets the function that aggregates results on
the same node type from multiple relations.
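
For example, a minimal construction sketch (the relation names and feature
sizes are illustrative assumptions):

.. code::

    from dgl.nn.pytorch import HeteroGraphConv, SAGEConv

    # One NN module per relation; results landing on the same destination
    # node type will be aggregated with 'sum'.
    conv = HeteroGraphConv({
        'follows': SAGEConv(5, 2, 'mean'),
        'plays': SAGEConv(5, 2, 'mean')},
        aggregate='sum')
    print(conv.mods)  # an nn.ModuleDict keyed by relation name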

.. code::

    def forward(self, g, inputs, mod_args=None, mod_kwargs=None):
        if mod_args is None:
            mod_args = {}
        if mod_kwargs is None:
            mod_kwargs = {}
        outputs = {nty : [] for nty in g.dsttypes}

Besides the input graph and input tensors, the ``forward()`` function takes
two additional dictionary parameters ``mod_args`` and ``mod_kwargs``.
These two dictionaries have the same keys as ``self.mods``. They are
used as customized parameters when calling the corresponding NN
modules in ``self.mods`` for different types of relations.

An output dictionary is created to hold the output tensors for each
destination type ``nty``. Note that the value for each ``nty`` is a
list, indicating that a single node type may get multiple outputs if more
than one relation has ``nty`` as the destination type. We hold them in a
list for further aggregation.

.. code::

        if g.is_block:
            src_inputs = inputs
            dst_inputs = {k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items()}
        else:
            src_inputs = dst_inputs = inputs

        for stype, etype, dtype in g.canonical_etypes:
            rel_graph = g[stype, etype, dtype]
            if rel_graph.number_of_edges() == 0:
                continue
            if stype not in src_inputs or dtype not in dst_inputs:
                continue
            dstdata = self.mods[etype](
                rel_graph,
                (src_inputs[stype], dst_inputs[dtype]),
                *mod_args.get(etype, ()),
                **mod_kwargs.get(etype, {}))
            outputs[dtype].append(dstdata)

The input ``g`` can be a heterogeneous graph or a subgraph block from a
heterogeneous graph. As with an ordinary NN module, the ``forward()``
function needs to handle different input graph types separately.

Each relation is represented as a ``canonical_etype``, which is
``(stype, etype, dtype)``. Using ``canonical_etype`` as the key, we can
extract a bipartite graph ``rel_graph``. For a bipartite graph, the
input feature is organized as a tuple
``(src_inputs[stype], dst_inputs[dtype])``. The NN module for each
relation is called and the output is saved. To avoid unnecessary calls,
relations with no edges or no nodes of the source type are skipped.

.. code::

        rsts = {}
        for nty, alist in outputs.items():
            if len(alist) != 0:
                rsts[nty] = self.agg_fn(alist, nty)

Finally, the results on the same destination node type from multiple
relations are aggregated using the ``self.agg_fn`` function. Examples can be found in the API Doc for :class:`dgl.nn.pytorch.HeteroGraphConv`.
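
As a rough end-to-end sketch (the graph schema and feature sizes here are
illustrative assumptions, not from the original text):

.. code::

    import dgl
    import torch as th
    from dgl.nn.pytorch import HeteroGraphConv, SAGEConv

    conv = HeteroGraphConv({
        'follows': SAGEConv(5, 2, 'mean'),
        'plays': SAGEConv(5, 2, 'mean')},
        aggregate='sum')

    g = dgl.heterograph({
        ('user', 'follows', 'user'): ([0, 1], [1, 2]),
        ('user', 'plays', 'game'): ([0, 1, 2], [0, 1, 1])})

    inputs = {'user': th.randn(3, 5), 'game': th.randn(2, 5)}
    outputs = conv(g, inputs)
    # outputs maps each destination node type to its aggregated result, e.g.
    # {'user': tensor of shape (3, 2), 'game': tensor of shape (2, 2)}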