.. _guide-minibatch-edge-classification-sampler:

6.2 Training GNN for Edge Classification with Neighborhood Sampling
----------------------------------------------------------------------

:ref:`(中文版) <guide_cn-minibatch-edge-classification-sampler>`

Training for edge classification/regression is somewhat similar to that
of node classification/regression with several notable differences.

Define a neighborhood sampler and data loader
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

You can use the
:ref:`same neighborhood samplers as node classification <guide-minibatch-node-classification-sampler>`.

.. code:: python

    datapipe = datapipe.sample_neighbor(g, [10, 10])
    # Or equivalently
    datapipe = dgl.graphbolt.NeighborSampler(datapipe, g, [10, 10])

The code for defining a data loader is also the same as that of node
classification. The only difference is that it iterates over the
edges (namely, node pairs) in the training set instead of the nodes.

.. code:: python

    import torch
    import dgl.graphbolt as gb

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    g = gb.SamplingGraph()
    node_pairs = torch.arange(0, 1000).reshape(-1, 2)  # 500 seed edges as (src, dst) pairs.
    labels = torch.randint(0, 2, (500,))  # One label per seed edge.
    train_set = gb.ItemSet((node_pairs, labels), names=("node_pairs", "labels"))
    datapipe = gb.ItemSampler(train_set, batch_size=128, shuffle=True)
    datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
    # Or equivalently:
    # datapipe = gb.NeighborSampler(datapipe, g, [10, 10])
    # `feature` is assumed to be a graphbolt feature store holding "feat".
    datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
    datapipe = datapipe.copy_to(device)
    dataloader = gb.DataLoader(datapipe, num_workers=0)

Iterating over the DataLoader will yield :class:`~dgl.graphbolt.MiniBatch`
which contains a list of specially created graphs representing the computation
dependencies on each layer. In order to train with DGL, you need to convert them
to :class:`~dgl.graphbolt.DGLMiniBatch`. Then you can access the
*message flow graphs* (MFGs).

.. code:: python

    mini_batch = next(iter(dataloader))
    mini_batch = mini_batch.to_dgl()
    print(mini_batch.blocks)

.. note::

   See the :doc:`Stochastic Training Tutorial
   <../notebooks/stochastic_training/neighbor_sampling_overview.nblink>`
   for the concept of message flow graph.

   If you wish to develop your own neighborhood sampler or you want a more
   detailed explanation of the concept of MFGs, please refer to
   :ref:`guide-minibatch-customizing-neighborhood-sampler`.

.. _guide-minibatch-edge-classification-sampler-exclude:

Removing edges in the minibatch from the original graph for neighbor sampling
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

When training edge classification models, sometimes you wish to remove
the edges appearing in the training data from the computation dependency,
as if they never existed. Otherwise, the model will know that an edge
exists between the two nodes and may exploit this information as a
shortcut.

Therefore, in edge classification you sometimes would like to exclude the
seed edges as well as their reverse edges from the sampled minibatch.
You can use :func:`~dgl.graphbolt.exclude_seed_edges` together with
:class:`~dgl.graphbolt.MiniBatchTransformer` to achieve this.

.. code:: python

    import torch
    import dgl.graphbolt as gb
    from functools import partial

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    g = gb.SamplingGraph()
    node_pairs = torch.arange(0, 1000).reshape(-1, 2)
    labels = torch.randint(0, 2, (500,))
    train_set = gb.ItemSet((node_pairs, labels), names=("node_pairs", "labels"))
    datapipe = gb.ItemSampler(train_set, batch_size=128, shuffle=True)
    datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
    exclude_seed_edges = partial(gb.exclude_seed_edges, include_reverse_edges=True)
    datapipe = datapipe.transform(exclude_seed_edges)
    datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
    datapipe = datapipe.copy_to(device)
    dataloader = gb.DataLoader(datapipe, num_workers=0)
    

Adapt your model for minibatch training
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The edge classification model usually consists of two parts:

-  One part that obtains the representation of incident nodes.
-  The other part that computes the edge score from the incident node
   representations.

The former part is exactly the same as
:ref:`that from node classification <guide-minibatch-node-classification-model>`
and we can simply reuse it. The input is still the list of
MFGs generated from a data loader provided by DGL, as well as the
input features.

.. code:: python

    import torch.nn as nn
    import torch.nn.functional as F
    import dgl.nn as dglnn

    class StochasticTwoLayerGCN(nn.Module):
        def __init__(self, in_features, hidden_features, out_features):
            super().__init__()
            self.conv1 = dglnn.GraphConv(in_features, hidden_features)
            self.conv2 = dglnn.GraphConv(hidden_features, out_features)

        def forward(self, blocks, x):
            x = F.relu(self.conv1(blocks[0], x))
            x = F.relu(self.conv2(blocks[1], x))
            return x

The input to the latter part is usually the output from the
former part, as well as the subgraph (namely, node pairs) of the original
graph induced by the edges in the minibatch. The subgraph is yielded from the
same data loader.

The following code shows an example of predicting scores on the edges by
concatenating the incident node features and projecting the result with a
dense layer.

.. code:: python

    class ScorePredictor(nn.Module):
        def __init__(self, num_classes, in_features):
            super().__init__()
            self.W = nn.Linear(2 * in_features, num_classes)

        def forward(self, node_pairs, x):
            # node_pairs[0] / node_pairs[1] hold the compacted source /
            # destination node IDs, which index directly into the GNN output x.
            src_x = x[node_pairs[0]]
            dst_x = x[node_pairs[1]]
            data = torch.cat([src_x, dst_x], 1)
            return self.W(data)

The entire model will take the list of MFGs and the edges generated by the data
loader, as well as the input node features as follows:

.. code:: python

    class Model(nn.Module):
        def __init__(self, in_features, hidden_features, out_features, num_classes):
            super().__init__()
            self.gcn = StochasticTwoLayerGCN(
                in_features, hidden_features, out_features)
            self.predictor = ScorePredictor(num_classes, out_features)

        def forward(self, blocks, x, node_pairs):
            x = self.gcn(blocks, x)
            return self.predictor(node_pairs, x)

DGL ensures that the nodes in the edge subgraph are the same as the
output nodes of the last MFG in the generated list of MFGs.
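
The following is a minimal sanity-check sketch of this guarantee (it reuses the
homogeneous ``dataloader`` defined earlier, which is an assumption for
illustration): the compacted seed node pairs index directly into the output
nodes of the last MFG, so no extra ID remapping is needed before scoring.

.. code:: python

    # Hedged sketch: inspect one minibatch and verify that the compacted
    # node pairs are valid indices into the last MFG's output nodes.
    mini_batch = next(iter(dataloader)).to_dgl()
    node_pairs = mini_batch.positive_node_pairs
    num_output_nodes = mini_batch.blocks[-1].num_dst_nodes()
    assert int(node_pairs[0].max()) < num_output_nodes
    assert int(node_pairs[1].max()) < num_output_nodes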

Training Loop
~~~~~~~~~~~~~

The training loop is very similar to node classification. You can
iterate over the dataloader and get a subgraph induced by the edges in
the minibatch, as well as the list of MFGs necessary for computing
their incident node representations.

.. code:: python

    import torch.nn.functional as F

    model = Model(in_features, hidden_features, out_features, num_classes)
    model = model.to(device)
    opt = torch.optim.Adam(model.parameters())

    for data in dataloader:
        # Convert the MiniBatch so its MFGs (blocks) and features can be
        # consumed by the DGL model.
        data = data.to_dgl()
        blocks = data.blocks
        x = data.node_features["feat"]
        y_hat = model(data.blocks, x, data.positive_node_pairs)
        loss = F.cross_entropy(y_hat, data.labels)
        opt.zero_grad()
        loss.backward()
        opt.step()

For heterogeneous graphs
~~~~~~~~~~~~~~~~~~~~~~~~

The models computing the node representations on heterogeneous graphs
can also be used for computing incident node representations for edge
classification/regression.

.. code:: python

    class StochasticTwoLayerRGCN(nn.Module):
        def __init__(self, in_feat, hidden_feat, out_feat, rel_names):
            super().__init__()
            self.conv1 = dglnn.HeteroGraphConv({
                    rel : dglnn.GraphConv(in_feat, hidden_feat, norm='right')
                    for rel in rel_names
                })
            self.conv2 = dglnn.HeteroGraphConv({
                    rel : dglnn.GraphConv(hidden_feat, out_feat, norm='right')
                    for rel in rel_names
                })

        def forward(self, blocks, x):
            x = self.conv1(blocks[0], x)
            x = self.conv2(blocks[1], x)
            return x

For score prediction, the only implementation difference between the
homogeneous graph and the heterogeneous graph is that we are looping
over the edge types.

.. code:: python

    class ScorePredictor(nn.Module):
        def __init__(self, num_classes, in_features):
            super().__init__()
            self.W = nn.Linear(2 * in_features, num_classes)

        def forward(self, node_pairs, x):
            scores = {}
            for etype in node_pairs.keys():
                src, dst = node_pairs[etype]
                # `x` is keyed by node type, so look up the source and
                # destination node types from the canonical edge type string
                # "src_type:relation:dst_type".
                src_type, _, dst_type = etype.split(":")
                data = torch.cat([x[src_type][src], x[dst_type][dst]], 1)
                scores[etype] = self.W(data)
            return scores

    class Model(nn.Module):
        def __init__(self, in_features, hidden_features, out_features, num_classes,
                     etypes):
            super().__init__()
            self.rgcn = StochasticTwoLayerRGCN(
                in_features, hidden_features, out_features, etypes)
            self.pred = ScorePredictor(num_classes, out_features)

        def forward(self, blocks, x, node_pairs):
            x = self.rgcn(blocks, x)
            return self.pred(node_pairs, x)

Data loader definition is almost identical to that of the homogeneous graph.
The only difference is that ``train_set`` is now an instance of
:class:`~dgl.graphbolt.ItemSetDict` instead of :class:`~dgl.graphbolt.ItemSet`.

.. code:: python

    import torch
    import dgl.graphbolt as gb

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    g = gb.SamplingGraph()
    node_pairs = torch.arange(0, 1000).reshape(-1, 2)
    labels = torch.randint(0, 3, (500,))
    node_pairs_labels = {
        "user:like:item": gb.ItemSet(
            (node_pairs, labels), names=("node_pairs", "labels")
        ),
        "user:follow:user": gb.ItemSet(
            (node_pairs, labels), names=("node_pairs", "labels")
        ),
    }
    train_set = gb.ItemSetDict(node_pairs_labels)
    datapipe = gb.ItemSampler(train_set, batch_size=128, shuffle=True)
    datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
    datapipe = datapipe.fetch_feature(
        feature, node_feature_keys={"item": ["feat"], "user": ["feat"]}
    )
    datapipe = datapipe.copy_to(device)
    dataloader = gb.DataLoader(datapipe, num_workers=0)

Things become a little different if you wish to exclude the reverse
edges on heterogeneous graphs. On heterogeneous graphs, reverse edges
usually have a different edge type from the edges themselves, in order
to differentiate the forward and backward relationships (e.g.
``follow`` and ``followed_by`` are reverse relations of each other,
``like`` and ``liked_by`` are reverse relations of each other,
etc.).

If each edge in a type has a reverse edge with the same ID in another
type, you can specify the mapping between edge types and their reverse
types. The way to exclude the edges in the minibatch as well as their
reverse edges then goes as follows.

.. code:: python


    exclude_seed_edges = partial(
        gb.exclude_seed_edges,
        include_reverse_edges=True,
        reverse_etypes_mapping={
            "user:like:item": "item:liked_by:user",
            "user:follow:user": "user:followed_by:user",
        },
    )
    datapipe = datapipe.transform(exclude_seed_edges)


The training loop is again almost the same as that on the homogeneous graph,
except for the implementation of ``compute_loss``, which here takes two
dictionaries keyed by edge type: the labels and the predictions.
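
Since ``ScorePredictor`` returns one score tensor per edge type, a minimal
``compute_loss`` sketch could sum a cross-entropy term over the edge types.
The exact layout of ``labels`` as a per-edge-type dictionary is an assumption
for illustration rather than a fixed API:

.. code:: python

    import torch.nn.functional as F

    def compute_loss(labels, scores):
        # Accumulate the cross entropy over every edge type in the minibatch.
        loss = 0
        for etype, score in scores.items():
            loss = loss + F.cross_entropy(score, labels[etype])
        return loss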

.. code:: python

    import torch.nn.functional as F

    model = Model(in_features, hidden_features, out_features, num_classes, etypes)
    model = model.to(device)
    opt = torch.optim.Adam(model.parameters())

    for data in dataloader:
        data = data.to_dgl()
        blocks = data.blocks
        # Gather the input features for every node type. The exact
        # (node type, feature name) key layout is an assumption; adapt it to
        # how your feature store exposes the fetched node features.
        x = {
            "user": data.node_features[("user", "feat")],
            "item": data.node_features[("item", "feat")],
        }
        y_hat = model(data.blocks, x, data.positive_node_pairs)
        loss = compute_loss(data.labels, y_hat)
        opt.zero_grad()
        loss.backward()
        opt.step()