.. _guide-minibatch-node-classification-sampler:

6.1 Training GNN for Node Classification with Neighborhood Sampling
-----------------------------------------------------------------------

:ref:`(Chinese version) <guide_cn-minibatch-node-classification-sampler>`

To train your model stochastically on minibatches, you need to do the
following:

-  Define a neighborhood sampler.
-  Adapt your model for minibatch training.
-  Modify your training loop.

The following sub-subsections address these steps one by one.

Define a neighborhood sampler and data loader
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

DGL provides several neighborhood sampler classes that generate the
computation dependencies needed for each layer given the nodes we wish
to compute on.

The simplest neighborhood sampler is :class:`~dgl.graphbolt.NeighborSampler`
or the equivalent function-like interface :func:`~dgl.graphbolt.sample_neighbor`,
which makes each node gather messages from its neighbors.

To use a sampler provided by DGL, one also needs to combine it with
:class:`~dgl.graphbolt.DataLoader`, which iterates
over a set of indices (nodes in this case) in minibatches.

For example, the following code creates a DataLoader that
iterates over the training node ID set of ``ogbn-arxiv`` in batches,
putting the list of generated MFGs onto the GPU.

.. code:: python

    import dgl
    import dgl.graphbolt as gb
    import dgl.nn as dglnn
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = gb.BuiltinDataset("ogbn-arxiv").load()
    g = dataset.graph
    feature = dataset.feature
    train_set = dataset.tasks[0].train_set
    datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
    datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
    # Or equivalently:
    # datapipe = gb.NeighborSampler(datapipe, g, [10, 10])
    datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
    datapipe = datapipe.copy_to(device)
    dataloader = gb.DataLoader(datapipe, num_workers=0)


Iterating over the DataLoader yields a :class:`~dgl.graphbolt.MiniBatch`,
which contains a list of specially created graphs representing the computation
dependencies on each layer. In order to train with DGL, you need to convert it
to a :class:`~dgl.graphbolt.DGLMiniBatch`. Then you can access the
*message flow graphs* (MFGs).

.. code:: python

    mini_batch = next(iter(dataloader))
    mini_batch = mini_batch.to_dgl()
    print(mini_batch.blocks)


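Each MFG (block) behaves like a regular :class:`~dgl.DGLGraph`, so the usual
graph queries apply to it. As a quick sanity check, one can inspect the sizes
of the sampled blocks (a sketch based on the ``mini_batch`` above):

.. code:: python

    # A quick inspection sketch, assuming the ``mini_batch`` created above.
    # The source nodes of each block are the destination nodes together with
    # their sampled neighbors.
    for layer, block in enumerate(mini_batch.blocks):
        print(layer, block.num_src_nodes(), block.num_dst_nodes(), block.num_edges())
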
.. note::

   See the `Stochastic Training Tutorial
   <../notebooks/stochastic_training/neighbor_sampling_overview.nblink>`__
   for the concept of message flow graph.

   If you wish to develop your own neighborhood sampler or you want a more
   detailed explanation of the concept of MFGs, please refer to
   :ref:`guide-minibatch-customizing-neighborhood-sampler`.


.. _guide-minibatch-node-classification-model:

Adapt your model for minibatch training
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

If your message passing modules are all provided by DGL, the changes
required to adapt your model to minibatch training are minimal. Take a
multi-layer GCN as an example. If your model on the full graph is
implemented as follows:

.. code:: python

    class TwoLayerGCN(nn.Module):
        def __init__(self, in_features, hidden_features, out_features):
            super().__init__()
            self.conv1 = dglnn.GraphConv(in_features, hidden_features)
            self.conv2 = dglnn.GraphConv(hidden_features, out_features)
    
        def forward(self, g, x):
            x = F.relu(self.conv1(g, x))
            x = F.relu(self.conv2(g, x))
            return x

Then all you need to do is replace ``g`` with the ``blocks`` generated above.

.. code:: python

    class StochasticTwoLayerGCN(nn.Module):
        def __init__(self, in_features, hidden_features, out_features):
            super().__init__()
            self.conv1 = dglnn.GraphConv(in_features, hidden_features)
            self.conv2 = dglnn.GraphConv(hidden_features, out_features)
    
        def forward(self, blocks, x):
            x = F.relu(self.conv1(blocks[0], x))
            x = F.relu(self.conv2(blocks[1], x))
            return x

The DGL ``GraphConv`` modules above accept an element in ``blocks``
generated by the data loader as an argument.

:ref:`The API reference of each NN module <apinn>` will tell you
whether it supports accepting an MFG as an argument.

If you wish to use your own message passing module, please refer to
:ref:`guide-minibatch-custom-gnn-module`.

Training Loop
~~~~~~~~~~~~~

The training loop simply consists of iterating over the dataset with the
customized batching iterator. During each iteration that yields a
:class:`~dgl.graphbolt.MiniBatch`, we:

1. Convert the :class:`~dgl.graphbolt.MiniBatch` to
   :class:`~dgl.graphbolt.DGLMiniBatch`.

2. Access the node features corresponding to the input nodes via
   ``data.node_features["feat"]``. These features are already moved to the
   target device (CPU or GPU) by the data loader.

3. Access the node labels corresponding to the output nodes via
   ``data.labels``. These labels are already moved to the target device
   (CPU or GPU) by the data loader.

4. Feed the list of MFGs and the input node features to the multilayer
   GNN and get the outputs.

5. Compute the loss and backpropagate.

.. code:: python

    model = StochasticTwoLayerGCN(in_features, hidden_features, out_features)
    model = model.to(device)
    opt = torch.optim.Adam(model.parameters())

    for data in dataloader:
        data = data.to_dgl()
        input_features = data.node_features["feat"]
        output_labels = data.labels
        output_predictions = model(data.blocks, input_features)
        loss = compute_loss(output_labels, output_predictions)
        opt.zero_grad()
        loss.backward()
        opt.step()


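Note that ``compute_loss`` above is a placeholder. A minimal sketch for
multi-class node classification, assuming integer class labels and raw
(unnormalized) prediction scores, could simply use cross-entropy:

.. code:: python

    # A minimal sketch of the ``compute_loss`` placeholder used above,
    # assuming integer class labels and raw prediction scores.
    def compute_loss(labels, predictions):
        return F.cross_entropy(predictions, labels)
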
DGL provides an end-to-end stochastic training example `GraphSAGE
implementation <https://github.com/dmlc/dgl/blob/master/examples/sampling/graphbolt/node_classification.py>`__.

For heterogeneous graphs
~~~~~~~~~~~~~~~~~~~~~~~~

Training a graph neural network for node classification on a heterogeneous
graph is similar.

For instance, we have previously seen
:ref:`how to train a 2-layer RGCN on the full graph <guide-training-rgcn-node-classification>`.
The RGCN implementation for minibatch training looks very similar to it
(with self-loops, non-linearity and basis decomposition removed for
simplicity):

.. code:: python

    class StochasticTwoLayerRGCN(nn.Module):
        def __init__(self, in_feat, hidden_feat, out_feat, rel_names):
            super().__init__()
            self.conv1 = dglnn.HeteroGraphConv({
                    rel : dglnn.GraphConv(in_feat, hidden_feat, norm='right')
                    for rel in rel_names
                })
            self.conv2 = dglnn.HeteroGraphConv({
                    rel : dglnn.GraphConv(hidden_feat, out_feat, norm='right')
                    for rel in rel_names
                })
    
        def forward(self, blocks, x):
            x = self.conv1(blocks[0], x)
            x = self.conv2(blocks[1], x)
            return x

The samplers provided by DGL also support heterogeneous graphs.
For example, one can still use the provided
:class:`~dgl.graphbolt.NeighborSampler` class and
:class:`~dgl.graphbolt.DataLoader` class for
stochastic training. The only difference is that the item set is now an
instance of :class:`~dgl.graphbolt.ItemSetDict`, which is a dictionary
mapping node types to node IDs.

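For illustration, such an item set could be constructed as in the following
sketch, where ``paper_train_ids`` is a hypothetical tensor of training node
IDs for the ``paper`` node type:

.. code:: python

    # A hypothetical illustration of a heterogeneous item set: node types map
    # to their own item sets of seed node IDs. ``paper_train_ids`` is a
    # placeholder tensor of "paper" node IDs.
    hetero_train_set = gb.ItemSetDict({
        "paper": gb.ItemSet(paper_train_ids, names="seed_nodes"),
    })

The built-in ``ogbn-mag`` dataset below already provides its training set in
this form, so we can use it directly.
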
.. code:: python

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    dataset = gb.BuiltinDataset("ogbn-mag").load()
    g = dataset.graph
    feature = dataset.feature
    train_set = dataset.tasks[0].train_set
    datapipe = gb.ItemSampler(train_set, batch_size=1024, shuffle=True)
    datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
    # Or equivalently:
    # datapipe = gb.NeighborSampler(datapipe, g, [10, 10])
    # For heterogeneous graphs, we need to specify the node feature keys
    # for each node type.
    datapipe = datapipe.fetch_feature(
        feature, node_feature_keys={"author": ["feat"], "paper": ["feat"]}
    )
    datapipe = datapipe.copy_to(device)
    dataloader = gb.DataLoader(datapipe, num_workers=0)

The training loop is almost the same as that of homogeneous graphs,
except for the implementation of ``compute_loss``, which here takes in the
labels of the ``paper`` nodes and a dictionary mapping node types to
predictions.

.. code:: python

    model = StochasticTwoLayerRGCN(in_features, hidden_features, out_features, etypes)
    model = model.to(device)
    opt = torch.optim.Adam(model.parameters())
    
    for data in dataloader:
        data = data.to_dgl()
        # For heterogeneous graphs, we need to specify the node type and
        # feature name when accessing the node features. The same goes for
        # the labels.
        input_features = {
            "author": data.node_features[("author", "feat")],
            "paper": data.node_features[("paper", "feat")]
        }
        output_labels = data.labels["paper"]
        output_predictions = model(data.blocks, input_features)
        loss = compute_loss(output_labels, output_predictions)
        opt.zero_grad()
        loss.backward()
        opt.step()

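As in the homogeneous case, ``compute_loss`` is a placeholder. A minimal
sketch, assuming that only ``paper`` nodes carry labels and that the
predictions are a dictionary keyed by node type, could be:

.. code:: python

    # A minimal sketch of the heterogeneous ``compute_loss`` placeholder used
    # above, assuming ``predictions`` maps node types to prediction tensors
    # and only ``paper`` nodes are labeled.
    def compute_loss(labels, predictions):
        return F.cross_entropy(predictions["paper"], labels)
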
DGL provides an end-to-end stochastic training example `RGCN
implementation <https://github.com/dmlc/dgl/blob/master/examples/sampling/graphbolt/rgcn/hetero_rgcn.py>`__.