.. _guide-minibatch-edge-classification-sampler:

6.2 Training GNN for Edge Classification with Neighborhood Sampling
----------------------------------------------------------------------

:ref:`(中文版) <guide_cn-minibatch-edge-classification-sampler>`

Training for edge classification/regression is somewhat similar to that
of node classification/regression with several notable differences.

Define a neighborhood sampler and data loader
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

You can use the
:ref:`same neighborhood samplers as node classification <guide-minibatch-node-classification-sampler>`.

.. code:: python

    datapipe = datapipe.sample_neighbor(g, [10, 10])
    # Or equivalently
    datapipe = dgl.graphbolt.NeighborSampler(datapipe, g, [10, 10])

The code for defining a data loader is also the same as that of node
classification. The only difference is that it iterates over the
edges (namely, node pairs) in the training set instead of the nodes.

.. code:: python

    import dgl.graphbolt as gb
    import torch

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    g = gb.SamplingGraph()
    node_pairs = torch.arange(0, 1000).reshape(-1, 2)  # 500 node pairs.
    labels = torch.randint(0, 2, (500,))
    train_set = gb.ItemSet((node_pairs, labels), names=("node_pairs", "labels"))
    datapipe = gb.ItemSampler(train_set, batch_size=128, shuffle=True)
    datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
    # Or equivalently:
    # datapipe = gb.NeighborSampler(datapipe, g, [10, 10])
    datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
    datapipe = datapipe.to_dgl()
    datapipe = datapipe.copy_to(device)
    dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)

Iterating over the DataLoader will yield :class:`~dgl.graphbolt.DGLMiniBatch`
which contains a list of specially created graphs representing the computation
dependencies on each layer. They are called *message flow graphs* (MFGs) in DGL.

.. code:: python

    mini_batch = next(iter(dataloader))
    print(mini_batch.blocks)

.. note::

   See the :doc:`Stochastic Training Tutorial
   <../notebooks/stochastic_training/neighbor_sampling_overview.nblink>`
   for the concept of message flow graph.

   If you wish to develop your own neighborhood sampler or you want a more
   detailed explanation of the concept of MFGs, please refer to
   :ref:`guide-minibatch-customizing-neighborhood-sampler`.

.. _guide-minibatch-edge-classification-sampler-exclude:

Removing edges in the minibatch from the original graph for neighbor sampling
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^

When training edge classification models, sometimes you wish to remove
the edges appearing in the training data from the computation dependency,
as if they never existed. Otherwise, the model will know that an edge
exists between the two nodes and could exploit that fact.

Therefore in edge classification you sometimes would like to exclude the
seed edges as well as their reverse edges from the sampled minibatch.
You can use :func:`~dgl.graphbolt.exclude_seed_edges` together with
:class:`~dgl.graphbolt.MiniBatchTransformer` to achieve this.

.. code:: python

    import dgl.graphbolt as gb
    import torch
    from functools import partial

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    g = gb.SamplingGraph()
    node_pairs = torch.arange(0, 1000).reshape(-1, 2)  # 500 node pairs.
    labels = torch.randint(0, 2, (500,))
    train_set = gb.ItemSet((node_pairs, labels), names=("node_pairs", "labels"))
    datapipe = gb.ItemSampler(train_set, batch_size=128, shuffle=True)
    datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
    exclude_seed_edges = partial(gb.exclude_seed_edges, include_reverse_edges=True)
    datapipe = datapipe.transform(exclude_seed_edges)
    datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
    datapipe = datapipe.to_dgl()
    datapipe = datapipe.copy_to(device)
    dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)

Adapt your model for minibatch training
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The edge classification model usually consists of two parts:

-  One part that obtains the representation of incident nodes.
-  The other part that computes the edge score from the incident node
   representations.

The former part is exactly the same as
:ref:`that from node classification <guide-minibatch-node-classification-model>`
and we can simply reuse it. The input is still the list of
MFGs generated from a data loader provided by DGL, as well as the
input features.

.. code:: python

    class StochasticTwoLayerGCN(nn.Module):
        def __init__(self, in_features, hidden_features, out_features):
            super().__init__()
            self.conv1 = dglnn.GraphConv(in_features, hidden_features)
            self.conv2 = dglnn.GraphConv(hidden_features, out_features)
    
        def forward(self, blocks, x):
            x = F.relu(self.conv1(blocks[0], x))
            x = F.relu(self.conv2(blocks[1], x))
            return x

The input to the latter part is usually the output from the
former part, as well as the subgraph (namely, node pairs) of the original graph
induced by the edges in the minibatch. The subgraph is yielded from the same
data loader.

The following code shows an example of predicting scores on the edges by
concatenating the incident node features and projecting them with a dense layer.

.. code:: python

    class ScorePredictor(nn.Module):
        def __init__(self, num_classes, in_features):
            super().__init__()
            self.W = nn.Linear(2 * in_features, num_classes)
    
        def forward(self, node_pairs, x):
            src_x = x[node_pairs[0]]
            dst_x = x[node_pairs[1]]
            data = torch.cat([src_x, dst_x], 1)
            return self.W(data)

The entire model will take the list of MFGs and the edges generated by the data
loader, as well as the input node features as follows:

.. code:: python

    class Model(nn.Module):
        def __init__(self, in_features, hidden_features, out_features, num_classes):
            super().__init__()
            self.gcn = StochasticTwoLayerGCN(
                in_features, hidden_features, out_features)
            self.predictor = ScorePredictor(num_classes, out_features)

        def forward(self, blocks, x, node_pairs):
            x = self.gcn(blocks, x)
            return self.predictor(node_pairs, x)

DGL ensures that the nodes in the edge subgraph are the same as the
output nodes of the last MFG in the generated list of MFGs.
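
For illustration, you can check this alignment on a single mini-batch: the
compacted node pair IDs are valid row indices into the representations
computed for the output nodes of the last MFG. The following is only a small
sanity-check sketch.

.. code:: python

    data = next(iter(dataloader))
    src, dst = data.positive_node_pairs
    # The pair IDs refer to the output (destination) nodes of the last MFG,
    # i.e. the rows produced by the GNN part of the model.
    assert src.max() < data.blocks[-1].num_dst_nodes()
    assert dst.max() < data.blocks[-1].num_dst_nodes()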

Training Loop
~~~~~~~~~~~~~

The training loop is very similar to node classification. You can
iterate over the dataloader and get a subgraph induced by the edges in
the minibatch, as well as the list of MFGs necessary for computing
their incident node representations.

.. code:: python

    import torch.nn.functional as F
    model = Model(in_features, hidden_features, out_features, num_classes)
    model = model.to(device)
    opt = torch.optim.Adam(model.parameters())

    for data in dataloader:
        blocks = data.blocks
        x = data.node_features["feat"]
        y_hat = model(data.blocks, x, data.positive_node_pairs)
        loss = F.cross_entropy(y_hat, data.labels)
        opt.zero_grad()
        loss.backward()
        opt.step()

For heterogeneous graphs
~~~~~~~~~~~~~~~~~~~~~~~~

The models computing the node representations on heterogeneous graphs
can also be used for computing incident node representations for edge
classification/regression.

.. code:: python

    class StochasticTwoLayerRGCN(nn.Module):
        def __init__(self, in_feat, hidden_feat, out_feat, rel_names):
            super().__init__()
            self.conv1 = dglnn.HeteroGraphConv({
                    rel : dglnn.GraphConv(in_feat, hidden_feat, norm='right')
                    for rel in rel_names
                })
            self.conv2 = dglnn.HeteroGraphConv({
                    rel : dglnn.GraphConv(hidden_feat, out_feat, norm='right')
                    for rel in rel_names
                })
    
        def forward(self, blocks, x):
            x = self.conv1(blocks[0], x)
            x = self.conv2(blocks[1], x)
            return x

For score prediction, the only implementation difference between the
homogeneous graph and the heterogeneous graph is that we are looping
over the edge types.

.. code:: python

    class ScorePredictor(nn.Module):
        def __init__(self, num_classes, in_features):
            super().__init__()
            self.W = nn.Linear(2 * in_features, num_classes)
    
        def forward(self, node_pairs, x):
            scores = {}
            for etype in node_pairs.keys():
                src, dst = node_pairs[etype]
                # Representations are keyed by node type, so look up the source
                # and destination node types from the canonical edge type.
                src_type, _, dst_type = etype.split(":")
                data = torch.cat([x[src_type][src], x[dst_type][dst]], 1)
                scores[etype] = self.W(data)
            return scores

    class Model(nn.Module):
        def __init__(self, in_features, hidden_features, out_features, num_classes,
                     etypes):
            super().__init__()
            self.rgcn = StochasticTwoLayerRGCN(
                in_features, hidden_features, out_features, etypes)
            self.pred = ScorePredictor(num_classes, out_features)

        def forward(self, node_pairs, blocks, x):
            x = self.rgcn(blocks, x)
            return self.pred(node_pairs, x)

Data loader definition is almost identical to that of the homogeneous graph. The
only difference is that the ``train_set`` is now an instance of
:class:`~dgl.graphbolt.ItemSetDict` instead of :class:`~dgl.graphbolt.ItemSet`.

.. code:: python

    import dgl.graphbolt as gb
    import torch

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    g = gb.SamplingGraph()
    node_pairs = torch.arange(0, 1000).reshape(-1, 2)  # 500 node pairs.
    labels = torch.randint(0, 3, (500,))
    node_pairs_labels = {
        "user:like:item": gb.ItemSet(
            (node_pairs, labels), names=("node_pairs", "labels")
        ),
        "user:follow:user": gb.ItemSet(
            (node_pairs, labels), names=("node_pairs", "labels")
        ),
    }
    train_set = gb.ItemSetDict(node_pairs_labels)
    datapipe = gb.ItemSampler(train_set, batch_size=128, shuffle=True)
    datapipe = datapipe.sample_neighbor(g, [10, 10]) # 2 layers.
    datapipe = datapipe.fetch_feature(
        feature, node_feature_keys={"item": ["feat"], "user": ["feat"]}
    )
    datapipe = datapipe.to_dgl()
    datapipe = datapipe.copy_to(device)
    dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)

Things become a little different if you wish to exclude the reverse
edges on heterogeneous graphs. On heterogeneous graphs, reverse edges
usually have a different edge type from the edges themselves, in order
to differentiate the forward and backward relationships (e.g. ``follow``
and ``followed_by`` are reverse relations of each other, and so are
``like`` and ``liked_by``).

If each edge in a type has a reverse edge with the same ID in another
type, you can specify the mapping between edge types and their reverse
types. The way to exclude the edges in the minibatch as well as their
reverse edges then goes as follows.

.. code:: python


    exclude_seed_edges = partial(
        gb.exclude_seed_edges,
        include_reverse_edges=True,
        reverse_etypes_mapping={
            "user:like:item": "item:liked_by:user",
            "user:follow:user": "user:followed_by:user",
        },
    )
    datapipe = datapipe.transform(exclude_seed_edges)


The training loop is again almost the same as that on the homogeneous graph,
except for the implementation of ``compute_loss``, which here takes in two
dictionaries mapping edge types to labels and predictions.
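
For reference, a minimal ``compute_loss`` consistent with this setup could sum
a cross-entropy term per edge type. This is only a sketch, assuming a
classification task with both arguments keyed by the edge type string:

.. code:: python

    import torch.nn.functional as F

    def compute_loss(labels, predictions):
        # ``labels`` and ``predictions`` are dictionaries keyed by the
        # canonical edge type string, e.g. "user:like:item".
        loss = 0
        for etype, y_hat in predictions.items():
            loss = loss + F.cross_entropy(y_hat, labels[etype])
        return loss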

.. code:: python

    model = Model(in_features, hidden_features, out_features, num_classes, etypes)
    model = model.to(device)
    opt = torch.optim.Adam(model.parameters())

    for data in dataloader:
        blocks = data.blocks
        # Node features of a heterogeneous mini-batch are keyed by
        # (node type, feature name) pairs.
        x = {
            "user": data.node_features[("user", "feat")],
            "item": data.node_features[("item", "feat")],
        }
        y_hat = model(data.positive_node_pairs, blocks, x)
        loss = compute_loss(data.labels, y_hat)
        opt.zero_grad()
        loss.backward()
        opt.step()