.. _guide-minibatch-customizing-neighborhood-sampler:

6.4 Customizing Neighborhood Sampler
----------------------------------------------

:ref:`(中文版) <guide_cn-minibatch-customizing-neighborhood-sampler>`

Although DGL provides a number of built-in neighborhood sampling strategies,
sometimes users may want to write their own. This section explains how to
write a custom sampling strategy and plug it into your stochastic GNN
training framework.

Recall that in `How Powerful are Graph Neural
Networks <https://arxiv.org/pdf/1810.00826.pdf>`__, the definition of message
passing is:

.. math::


   \begin{gathered}
     \boldsymbol{a}_v^{(l)} = \rho^{(l)} \left(
       \left\lbrace
         \boldsymbol{h}_u^{(l-1)} : u \in \mathcal{N} \left( v \right)
       \right\rbrace
     \right)
   \\
     \boldsymbol{h}_v^{(l)} = \phi^{(l)} \left(
       \boldsymbol{h}_v^{(l-1)}, \boldsymbol{a}_v^{(l)}
     \right)
   \end{gathered}

where :math:`\rho^{(l)}` and :math:`\phi^{(l)}` are parameterized
functions, and :math:`\mathcal{N}(v)` is defined as the set of
predecessors (or *neighbors* if the graph is undirected) of :math:`v` on graph
:math:`\mathcal{G}`.
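
To make the notation concrete, here is a minimal sketch of how one such layer
could be written in DGL, with summation as one choice of :math:`\rho^{(l)}`
and a linear layer followed by a ReLU as :math:`\phi^{(l)}` (the class and
variable names are only illustrative):

.. code:: python

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import dgl.function as fn

    class ToyLayer(nn.Module):
        def __init__(self, in_feats, out_feats):
            super().__init__()
            self.linear = nn.Linear(2 * in_feats, out_feats)

        def forward(self, g, h):
            with g.local_scope():
                g.ndata['h'] = h
                # rho: aggregate the features of each node's predecessors
                g.update_all(fn.copy_u('h', 'm'), fn.sum('m', 'a'))
                # phi: combine each node's own feature with the aggregation
                return F.relu(self.linear(torch.cat([h, g.ndata['a']], dim=1)))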

For instance, to perform message passing to update the red node in
the following graph:

.. figure:: https://data.dgl.ai/asset/image/guide_6_4_0.png
   :alt: Imgur


One needs to aggregate the node features of its neighbors, shown as
green nodes:

.. figure:: https://data.dgl.ai/asset/image/guide_6_4_1.png
   :alt: Imgur


Neighborhood sampling with pencil and paper
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Let's first define a DGL graph according to the above image.

.. code:: python

    import torch
    import dgl

    src = torch.LongTensor(
        [0, 0, 0, 1, 2, 2, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 8, 9, 10,
         1, 2, 3, 3, 3, 4, 5, 5, 6, 5, 8, 6, 8, 9, 8, 11, 11, 10, 11])
    dst = torch.LongTensor(
        [1, 2, 3, 3, 3, 4, 5, 5, 6, 5, 8, 6, 8, 9, 8, 11, 11, 10, 11,
         0, 0, 0, 1, 2, 2, 2, 3, 3, 4, 4, 5, 5, 6, 7, 7, 8, 9, 10])
    g = dgl.graph((src, dst))
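
As a quick check, the graph has 12 nodes and 38 directed edges (each of the
19 undirected edges in the picture appears in both directions):

.. code:: python

    print(g.number_of_nodes(), g.number_of_edges())    # 12 38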

We then consider how multi-layer message passing works for computing the
output of a single node. In the following text we refer to the nodes
whose GNN outputs are to be computed as *seed nodes*.

Finding the message passing dependency
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

Consider computing with a 2-layer GNN the output of the seed node 8,
colored red, in the following graph:

.. figure:: https://data.dgl.ai/asset/image/guide_6_4_2.png
   :alt: Imgur



By the formulation:

.. math::


   \begin{gathered}
     \boldsymbol{a}_8^{(2)} = \rho^{(2)} \left(
       \left\lbrace
         \boldsymbol{h}_u^{(1)} : u \in \mathcal{N} \left( 8 \right)
       \right\rbrace
     \right) = \rho^{(2)} \left(
       \left\lbrace
         \boldsymbol{h}_4^{(1)}, \boldsymbol{h}_5^{(1)},
         \boldsymbol{h}_7^{(1)}, \boldsymbol{h}_{11}^{(1)}
       \right\rbrace
     \right)
   \\
     \boldsymbol{h}_8^{(2)} = \phi^{(2)} \left(
       \boldsymbol{h}_8^{(1)}, \boldsymbol{a}_8^{(2)}
     \right)
   \end{gathered}

We can tell from the formulation that to compute
:math:`\boldsymbol{h}_8^{(2)}` we need messages from nodes 4, 5, 7, and 11
(colored green) along the edges visualized below.

.. figure:: https://data.dgl.ai/asset/image/guide_6_4_3.png
   :alt: Imgur



This graph contains all the nodes in the original graph but only the
edges necessary for message passing to the given output nodes. We call
that the *frontier* of the second GNN layer for the red node 8.
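
As a quick sanity check, the predecessors of node 8 in ``g`` are indeed nodes
4, 5, 7, and 11:

.. code:: python

    print(g.predecessors(8))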

Several functions can generate frontiers. For instance,
:func:`dgl.in_subgraph()` induces a subgraph that keeps all the nodes in the
original graph but only the incoming edges of the given nodes. You can use it
as a frontier for message passing along all the incoming edges.

.. code:: python

    frontier = dgl.in_subgraph(g, [8])
    print(frontier.all_edges())

For a concrete list, please refer to :ref:`api-subgraph-extraction` and
:ref:`api-sampling`.

Technically, any graph that has the same set of nodes as the original
graph can serve as a frontier. This serves as the basis for
:ref:`guide-minibatch-customizing-neighborhood-sampler-impl`.
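
For instance, here is a small sketch of an alternative frontier that keeps
only a random subset of node 8's incoming edges (the variable names are
illustrative):

.. code:: python

    src, dst = dgl.in_subgraph(g, [8]).all_edges()
    keep = torch.rand(src.shape[0]) < 0.5     # keep roughly half of the edges
    random_frontier = dgl.graph((src[keep], dst[keep]), num_nodes=g.number_of_nodes())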

The Bipartite Structure for Multi-layer Minibatch Message Passing
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

However, to compute :math:`\boldsymbol{h}_8^{(2)}` from
:math:`\boldsymbol{h}_\cdot^{(1)}`, we cannot simply perform message
passing on the frontier directly, because it still contains all the
nodes from the original graph. Namely, we only need nodes 4, 5, 7, 8,
and 11 (green and red nodes) as input, as well as node 8 (red node) as output.
Since the numbers of input and output nodes differ, we need to perform
message passing on a small, bipartite-structured graph instead. We call such a
bipartite-structured graph, which only contains the necessary input nodes
(referred to as *source* nodes) and output nodes (referred to as *destination*
nodes), a *message flow graph* (MFG).

The following figure shows the MFG of the second GNN layer for node 8.

.. figure:: https://data.dgl.ai/asset/image/guide_6_4_4.png
   :alt: Imgur

.. note::

   See the :doc:`Stochastic Training Tutorial
   <tutorials/large/L0_neighbor_sampling_overview>` for the concept of
   message flow graph.

Note that the destination nodes also appear in the source nodes. The reason is
that representations of destination nodes from the previous layer are needed
for feature combination after message passing (i.e. :math:`\phi^{(2)}`).

DGL provides :func:`dgl.to_block` to convert any frontier to a MFG, where the
first argument specifies the frontier and the second argument specifies the
destination nodes. For instance, the frontier above can be converted to a MFG
with destination node 8 as follows.

.. code:: python

    dst_nodes = torch.LongTensor([8])
    block = dgl.to_block(frontier, dst_nodes)

To find the number of source nodes and destination nodes of a given node type,
one can use the :meth:`dgl.DGLHeteroGraph.number_of_src_nodes` and
:meth:`dgl.DGLHeteroGraph.number_of_dst_nodes` methods.

.. code:: python

    num_src_nodes, num_dst_nodes = block.number_of_src_nodes(), block.number_of_dst_nodes()
    print(num_src_nodes, num_dst_nodes)

The MFG’s source node features can be accessed via
:attr:`dgl.DGLHeteroGraph.srcdata` and :attr:`dgl.DGLHeteroGraph.srcnodes`, and
its destination node features can be accessed via
:attr:`dgl.DGLHeteroGraph.dstdata` and :attr:`dgl.DGLHeteroGraph.dstnodes`. The
syntax of ``srcdata``/``dstdata`` and ``srcnodes``/``dstnodes`` is identical to
that of :attr:`dgl.DGLHeteroGraph.ndata` and :attr:`dgl.DGLHeteroGraph.nodes`
in normal graphs.

.. code:: python

    block.srcdata['h'] = torch.randn(num_src_nodes, 5)
    block.dstdata['h'] = torch.randn(num_dst_nodes, 5)

If a MFG is converted from a frontier, which is in turn induced from a graph
whose nodes carry features, then those features are carried over, and one can
directly read the features of the MFG’s source and destination nodes.
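
Our toy graph ``g`` was created without node features, so as a small setup
sketch (the feature names ``x`` and ``y`` below are purely illustrative), one
can assign features and then re-extract the frontier and the MFG so that they
pick the features up:

.. code:: python

    # Hypothetical node features; any names and shapes would do.
    g.ndata['x'] = torch.randn(g.number_of_nodes(), 5)
    g.ndata['y'] = torch.randn(g.number_of_nodes(), 3)

    # Re-extract so that the features are carried into the frontier and the MFG.
    frontier = dgl.in_subgraph(g, [8])
    block = dgl.to_block(frontier, torch.LongTensor([8]))

The features can then be read directly: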

.. code:: python

    print(block.srcdata['x'])
    print(block.dstdata['y'])

.. note::

   The original node IDs of the source nodes and destination nodes in the MFG
   can be found as the feature ``dgl.NID``, and the mapping from the
   MFG’s edge IDs to the input frontier’s edge IDs can be found as the
   feature ``dgl.EID``.
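
As a small illustration, the mapping back to the frontier's edge IDs can be
inspected directly:

.. code:: python

    # IDs (in the frontier) of the edges that were carried into the MFG.
    print(block.edata[dgl.EID])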

DGL ensures that the destination nodes of a MFG will always appear in the
source nodes, and that the destination nodes are always indexed first among
the source nodes.

.. code:: python

    src_nodes = block.srcdata[dgl.NID]
    dst_nodes = block.dstdata[dgl.NID]
    assert torch.equal(src_nodes[:len(dst_nodes)], dst_nodes)

As a result, the destination nodes must cover all nodes that are the
destination of an edge in the frontier.

For example, consider the following frontier, which we will call ``frontier2``:

.. figure:: https://data.dgl.ai/asset/image/guide_6_4_5.png
   :alt: Imgur



where the red and green nodes (i.e. nodes 4, 5, 7, 8, and 11) are all nodes
that are the destination of at least one edge.
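
To make the snippet below runnable, ``frontier2`` can be constructed as a
stand-in for the frontier in the figure, assuming it consists of all the
incoming edges of nodes 4, 5, 7, 8, and 11:

.. code:: python

    frontier2 = dgl.in_subgraph(g, [4, 5, 7, 8, 11])

The following code then raises an error, because the given destination nodes
do not cover all of those nodes (nodes 7, 8, and 11 are missing):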

.. code:: python

    dgl.to_block(frontier2, torch.LongTensor([4, 5]))   # ERROR

However, the destination nodes can include more nodes than the above. In this
case, there will be isolated destination nodes that do not have any edge
connecting to them. The isolated nodes are included in both the source nodes
and the destination nodes.

.. code:: python

    # Node 3 is an isolated node that does not have any edge pointing to it.
    block3 = dgl.to_block(frontier2, torch.LongTensor([4, 5, 7, 8, 11, 3]))
    print(block3.srcdata[dgl.NID])
    print(block3.dstdata[dgl.NID])

Heterogeneous Graphs
^^^^^^^^^^^^^^^^^^^^

MFGs also work on heterogeneous graphs. Let’s say that we have the
following frontier:

.. code:: python

    hetero_frontier = dgl.heterograph({
        ('user', 'follow', 'user'): ([1, 3, 7], [3, 6, 8]),
        ('user', 'play', 'game'): ([5, 5, 4], [6, 6, 2]),
        ('game', 'played-by', 'user'): ([2], [6])
    }, num_nodes_dict={'user': 10, 'game': 10})

One can also create a MFG with destination nodes User #3, #6, and #8, as
well as Game #2 and #6.

.. code:: python

    hetero_block = dgl.to_block(hetero_frontier, {'user': [3, 6, 8], 'game': [2, 6]})

One can also get the source nodes and destination nodes by type:

.. code:: python

    # source users and games
    print(hetero_block.srcnodes['user'].data[dgl.NID], hetero_block.srcnodes['game'].data[dgl.NID])
    # destination users and games
    print(hetero_block.dstnodes['user'].data[dgl.NID], hetero_block.dstnodes['game'].data[dgl.NID])


.. _guide-minibatch-customizing-neighborhood-sampler-impl:

Implementing a Custom Neighbor Sampler
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

Recall that the following code performs neighbor sampling for node
classification.

.. code:: python

    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)

To implement your own neighborhood sampling strategy, you basically
replace the ``sampler`` object with your own. To do that, let’s first
see what :class:`~dgl.dataloading.dataloader.BlockSampler`, the parent class of
:class:`~dgl.dataloading.neighbor.MultiLayerFullNeighborSampler`, is.

:class:`~dgl.dataloading.dataloader.BlockSampler` is responsible for
generating the list of MFGs starting from the last layer, via its
:meth:`~dgl.dataloading.dataloader.BlockSampler.sample_blocks` method. The
default implementation of ``sample_blocks`` iterates backwards over the
layers, generating the frontiers and converting them to MFGs.
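
Conceptually, ``sample_blocks`` does something like the following (a
simplified sketch, not the actual library code):

.. code:: python

    def sample_blocks(self, g, seed_nodes):
        blocks = []
        for block_id in reversed(range(self.num_layers)):
            frontier = self.sample_frontier(block_id, g, seed_nodes)
            block = dgl.to_block(frontier, seed_nodes)
            # The inputs of this layer become the seeds of the previous layer.
            seed_nodes = block.srcdata[dgl.NID]
            blocks.insert(0, block)
        return blocks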

Therefore, for neighborhood sampling, **you only need to implement
the**\ :meth:`~dgl.dataloading.dataloader.BlockSampler.sample_frontier`\ **method**. Given the
layer the sampler is generating the frontier for, as well as the original
graph and the nodes whose representations are to be computed, this method is
responsible for generating a frontier for them.

You also need to pass the number of GNN layers to the parent class.

For example, the implementation of
:class:`~dgl.dataloading.neighbor.MultiLayerFullNeighborSampler` can
go as follows.

.. code:: python

    class MultiLayerFullNeighborSampler(dgl.dataloading.BlockSampler):
        def __init__(self, n_layers):
            super().__init__(n_layers)
    
        def sample_frontier(self, block_id, g, seed_nodes):
            frontier = dgl.in_subgraph(g, seed_nodes)
            return frontier

:class:`dgl.dataloading.neighbor.MultiLayerNeighborSampler`, a more
complicated neighbor sampler class that lets you sample a small number of
neighbors to gather messages for each node, goes as follows.

.. code:: python

    class MultiLayerNeighborSampler(dgl.dataloading.BlockSampler):
        def __init__(self, fanouts):
            super().__init__(len(fanouts))
    
            self.fanouts = fanouts
    
        def sample_frontier(self, block_id, g, seed_nodes):
            fanout = self.fanouts[block_id]
            if fanout is None:
                frontier = dgl.in_subgraph(g, seed_nodes)
            else:
                frontier = dgl.sampling.sample_neighbors(g, seed_nodes, fanout)
            return frontier
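
For instance, a sampler constructed as below would sample at most 5
in-neighbors per node for the first GNN layer's frontier and at most 10 for
the second, while a ``None`` entry would fall back to taking all in-neighbors:

.. code:: python

    sampler = MultiLayerNeighborSampler([5, 10])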

Although the functions above can generate a frontier, any graph that has
the same nodes as the original graph can serve as a frontier.

For example, if one wants to randomly drop inbound edges to the seed
nodes with a given probability, one can simply define the sampler as follows:

.. code:: python

    class MultiLayerDropoutSampler(dgl.dataloading.BlockSampler):
        def __init__(self, p, num_layers):
            super().__init__(num_layers)

            self.p = p

        def sample_frontier(self, block_id, g, seed_nodes, *args, **kwargs):
            # Get all inbound edges to `seed_nodes`
            src, dst = dgl.in_subgraph(g, seed_nodes).all_edges()
            # Randomly select edges with a probability of p (boolean mask)
            mask = torch.zeros(src.shape[0]).bernoulli_(self.p).bool()
            src = src[mask]
            dst = dst[mask]
            # Return a new graph with the same nodes as the original graph as a
            # frontier
            frontier = dgl.graph((src, dst), num_nodes=g.number_of_nodes())
            return frontier

        def __len__(self):
            return self.num_layers

After implementing your sampler, you can create a data loader that takes
in your sampler and it will keep generating lists of MFGs while
iterating over the seed nodes as usual.

.. code:: python

    sampler = MultiLayerDropoutSampler(0.5, 2)
    dataloader = dgl.dataloading.NodeDataLoader(
        g, train_nids, sampler,
        batch_size=1024,
        shuffle=True,
        drop_last=False,
        num_workers=4)
    
    model = StochasticTwoLayerRGCN(in_features, hidden_features, out_features)
    model = model.cuda()
    opt = torch.optim.Adam(model.parameters())
    
    for input_nodes, output_nodes, blocks in dataloader:
        blocks = [b.to(torch.device('cuda')) for b in blocks]
        input_features = blocks[0].srcdata     # returns a dict
        output_labels = blocks[-1].dstdata     # returns a dict
        output_predictions = model(blocks, input_features)
        loss = compute_loss(output_labels, output_predictions)
        opt.zero_grad()
        loss.backward()
        opt.step()

Heterogeneous Graphs
^^^^^^^^^^^^^^^^^^^^

Generating a frontier for a heterogeneous graph is no different from doing so
for a homogeneous graph. Just make the returned graph have the
same nodes as the original graph, and it should work fine. For example,
we can rewrite the ``MultiLayerDropoutSampler`` above to iterate over
all edge types, so that it can work on heterogeneous graphs as well.

.. code:: python

    class MultiLayerDropoutSampler(dgl.dataloading.BlockSampler):
        def __init__(self, p, num_layers):
            super().__init__(num_layers)

            self.p = p

        def sample_frontier(self, block_id, g, seed_nodes, *args, **kwargs):
            # Get all inbound edges to `seed_nodes`
            sg = dgl.in_subgraph(g, seed_nodes)

            new_edges_masks = {}
            # Iterate over all edge types
            for etype in sg.canonical_etypes:
                edge_mask = torch.zeros(sg.number_of_edges(etype))
                edge_mask.bernoulli_(self.p)
                new_edges_masks[etype] = edge_mask.bool()

            # Return a new graph with the same nodes as the original graph as a
            # frontier, keeping only the sampled edges of each type
            frontier = dgl.edge_subgraph(sg, new_edges_masks, relabel_nodes=False)
            return frontier

        def __len__(self):
            return self.num_layers
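
As before, the sampler can be plugged into a data loader; a brief usage sketch
(the names ``hetero_g`` and ``train_nids_dict`` are hypothetical, and for a
heterogeneous graph the seed nodes are given as a dictionary mapping node
types to node ID tensors):

.. code:: python

    sampler = MultiLayerDropoutSampler(0.5, 2)
    dataloader = dgl.dataloading.NodeDataLoader(
        hetero_g, train_nids_dict, sampler,
        batch_size=1024,
        shuffle=True,
        drop_last=False,
        num_workers=4)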