.. _guide-minibatch-link-classification-sampler:

6.3 Training GNN for Link Prediction with Neighborhood Sampling
--------------------------------------------------------------------

:ref:`(中文版) <guide_cn-minibatch-link-classification-sampler>`

Define a data loader with neighbor and negative sampling
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

You can still use the same data loader as the one in node/edge classification.
The only difference is that you need to add a negative sampling stage before
the neighbor sampling stage. The following data loader picks 5 negative
destination nodes uniformly for each source node of an edge.

.. code:: python

    datapipe = datapipe.sample_uniform_negative(graph, 5)

The whole data loader pipeline is as follows:

.. code:: python

    datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=True)
    datapipe = datapipe.sample_uniform_negative(graph, 5)
    datapipe = datapipe.sample_neighbor(graph, [10, 10, 10]) # 3 layers.
    datapipe = datapipe.transform(gb.exclude_seed_edges)
    datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
    datapipe = datapipe.copy_to(device)
    dataloader = gb.DataLoader(datapipe, num_workers=0)

For details about the built-in uniform negative sampler, please see
:class:`~dgl.graphbolt.UniformNegativeSampler`.
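
To get a quick sense of what negative sampling adds to each minibatch, you can
inspect a single minibatch produced by the pipeline above; a minimal sketch,
assuming the pipeline has been assembled into ``dataloader`` as shown and using
the same positive/negative node pair attributes that appear in the training
loop later in this section.

.. code:: python

    # Pull one minibatch and convert it to a DGLMiniBatch.
    data = next(iter(dataloader)).to_dgl()
    pos_src, pos_dst = data.positive_node_pairs
    neg_src, neg_dst = data.negative_node_pairs
    # With 5 negatives per positive edge, there should be 5x as many
    # negative pairs as positive pairs.
    print(pos_src.shape, neg_src.shape)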

You can also provide your own negative sampler, as long as it inherits from
:class:`~dgl.graphbolt.NegativeSampler` and overrides the
:meth:`~dgl.graphbolt.NegativeSampler._sample_with_etype` method, which takes
the node pairs in the minibatch and returns the negative node pairs.

The following gives an example of a custom negative sampler that samples
negative destination nodes according to a probability distribution
proportional to a power of node degrees.

.. code:: python

    @functional_datapipe("customized_sample_negative")
    class CustomizedNegativeSampler(dgl.graphbolt.NegativeSampler):
        def __init__(self, datapipe, k, node_degrees):
            super().__init__(datapipe, k)
            # Cache the sampling distribution, proportional to degree^0.75.
            self.weights = node_degrees ** 0.75
            self.k = k

        def _sample_with_etype(self, node_pairs, etype=None):
            src, _ = node_pairs
            # Draw k negative destinations per positive edge.
            src = src.repeat_interleave(self.k)
            dst = self.weights.multinomial(len(src), replacement=True)
            return src, dst

    datapipe = datapipe.customized_sample_negative(5, node_degrees)
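
Note that ``node_degrees`` is supplied by you rather than produced by the
pipeline. How you compute it is up to you; below is a minimal sketch of one
way to build it, assuming the graph is a
:class:`~dgl.graphbolt.FusedCSCSamplingGraph` whose ``csc_indptr`` offsets
encode each node's in-degree.

.. code:: python

    # One possible way to build the degree tensor (hypothetical; any
    # non-negative weight tensor with one entry per node would work).
    node_degrees = torch.diff(graph.csc_indptr).float()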


Define a GraphSAGE model for minibatch training
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. code:: python

    class SAGE(nn.Module):
        def __init__(self, in_size, hidden_size):
            super().__init__()
            self.layers = nn.ModuleList()
            self.layers.append(dglnn.SAGEConv(in_size, hidden_size, "mean"))
            self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, "mean"))
            self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, "mean"))
            self.hidden_size = hidden_size
            self.predictor = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, 1),
            )

        def forward(self, blocks, x):
            hidden_x = x
            for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
                hidden_x = layer(block, hidden_x)
                is_last_layer = layer_idx == len(self.layers) - 1
                if not is_last_layer:
                    hidden_x = F.relu(hidden_x)
            return hidden_x

When a negative sampler is provided, the data loader will generate positive and
negative node pairs for each minibatch, in addition to the *Message Flow Graphs*
(MFGs). Let's define a utility function that merges the positive and negative
node pairs and builds the corresponding labels:

.. code:: python

    def to_binary_link_dgl_computing_pack(data: gb.DGLMiniBatch):
        """Convert the minibatch to a training pair and a label tensor."""
        pos_src, pos_dst = data.positive_node_pairs
        neg_src, neg_dst = data.negative_node_pairs
        node_pairs = (
            torch.cat((pos_src, neg_src), dim=0),
            torch.cat((pos_dst, neg_dst), dim=0),
        )
        pos_label = torch.ones_like(pos_src)
        neg_label = torch.zeros_like(neg_src)
        labels = torch.cat([pos_label, neg_label], dim=0)
        return (node_pairs, labels.float())


Training loop
~~~~~~~~~~~~~

The training loop simply involves iterating over the data loader and
feeding in the graphs as well as the input features to the model defined
above.
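
Before entering the loop, the model and optimizer need to be created; a minimal
sketch, assuming a hypothetical ``in_size`` variable that holds the
dimensionality of the ``"feat"`` feature and a hidden size of 256.

.. code:: python

    # Hypothetical setup; adjust `in_size` and the hidden size to your data.
    model = SAGE(in_size, 256).to(device)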

.. code:: python

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    for epoch in tqdm.trange(args.epochs):
        model.train()
        total_loss = 0
        start_epoch_time = time.time()
        for step, data in enumerate(dataloader):
            # Convert MiniBatch to DGLMiniBatch.
            data = data.to_dgl()
            # Unpack MiniBatch.
            compacted_pairs, labels = to_binary_link_dgl_computing_pack(data)
            node_feature = data.node_features["feat"]
            # Convert sampled subgraphs to DGL blocks.
            blocks = data.blocks

            # Get the embeddings of the input nodes.
            y = model(blocks, node_feature)
            logits = model.predictor(
                y[compacted_pairs[0]] * y[compacted_pairs[1]]
            ).squeeze()

            # Compute loss.
            loss = F.binary_cross_entropy_with_logits(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
        end_epoch_time = time.time()


DGL provides the
`unsupervised learning GraphSAGE <https://github.com/dmlc/dgl/blob/master/examples/sampling/graphbolt/link_prediction.py>`__
example that shows link prediction on homogeneous graphs.

For heterogeneous graphs
~~~~~~~~~~~~~~~~~~~~~~~~

The previous model can easily be extended to heterogeneous graphs. The only
difference is that you need to use :class:`~dgl.nn.HeteroGraphConv` to wrap
:class:`~dgl.nn.SAGEConv` according to the edge types.

.. code:: python

    class SAGE(nn.Module):
        def __init__(self, in_size, hidden_size):
            super().__init__()
            self.layers = nn.ModuleList()
            self.layers.append(dglnn.HeteroGraphConv({
                    rel: dglnn.SAGEConv(in_size, hidden_size, "mean")
                    for rel in rel_names
                }))
            self.layers.append(dglnn.HeteroGraphConv({
                    rel: dglnn.SAGEConv(hidden_size, hidden_size, "mean")
                    for rel in rel_names
                }))
            self.layers.append(dglnn.HeteroGraphConv({
                    rel: dglnn.SAGEConv(hidden_size, hidden_size, "mean")
                    for rel in rel_names
                }))
            self.hidden_size = hidden_size
            self.predictor = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, 1),
            )

        def forward(self, blocks, x):
            hidden_x = x
            for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
                hidden_x = layer(block, hidden_x)
                is_last_layer = layer_idx == len(self.layers) - 1
                if not is_last_layer:
                    # HeteroGraphConv returns a dict of node type to tensor,
                    # so apply the nonlinearity per node type.
                    hidden_x = {
                        ntype: F.relu(h) for ntype, h in hidden_x.items()
                    }
            return hidden_x


The data loader definition is also very similar to that for the homogeneous
graph. The only difference is that you need to specify the feature keys per
node type when fetching node features.

.. code:: python

    datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=True)
    datapipe = datapipe.sample_uniform_negative(graph, 5)
    datapipe = datapipe.sample_neighbor(graph, [10, 10, 10]) # 3 layers.
    datapipe = datapipe.transform(gb.exclude_seed_edges)
    datapipe = datapipe.fetch_feature(
        feature,
        node_feature_keys={"user": ["feat"], "item": ["feat"]}
    )
    datapipe = datapipe.copy_to(device)
    dataloader = gb.DataLoader(datapipe, num_workers=0)

If you want to provide your own negative sampling function, just inherit from
the :class:`~dgl.graphbolt.NegativeSampler` class and override the
:meth:`~dgl.graphbolt.NegativeSampler._sample_with_etype` method.

.. code:: python

    @functional_datapipe("customized_sample_negative")
    class CustomizedNegativeSampler(dgl.graphbolt.NegativeSampler):
        def __init__(self, datapipe, k, node_degrees):
            super().__init__(datapipe, k)
            # Cache one sampling distribution per edge type, each
            # proportional to degree^0.75.
            self.weights = {
                etype: node_degrees[etype] ** 0.75 for etype in node_degrees
            }
            self.k = k

        def _sample_with_etype(self, node_pairs, etype):
            src, _ = node_pairs
            src = src.repeat_interleave(self.k)
            dst = self.weights[etype].multinomial(len(src), replacement=True)
            return src, dst

    datapipe = datapipe.customized_sample_negative(5, node_degrees)


For heterogeneous graphs, node pairs are grouped by edge types.

.. code:: python

    def to_binary_link_dgl_computing_pack(data: gb.DGLMiniBatch, etype):
        """Convert the minibatch to a training pair and a label tensor."""
        pos_src, pos_dst = data.positive_node_pairs[etype]
        neg_src, neg_dst = data.negative_node_pairs[etype]
        node_pairs = (
            torch.cat((pos_src, neg_src), dim=0),
            torch.cat((pos_dst, neg_dst), dim=0),
        )
        pos_label = torch.ones_like(pos_src)
        neg_label = torch.zeros_like(neg_src)
        labels = torch.cat([pos_label, neg_label], dim=0)
        return (node_pairs, labels.float())


The training loop is again almost the same as that on the homogeneous graph,
except that the loss is computed on a specific edge type.

.. code:: python

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    category = "user"
    for epoch in tqdm.trange(args.epochs):
        model.train()
        total_loss = 0
        start_epoch_time = time.time()
        for step, data in enumerate(dataloader):
            # Convert MiniBatch to DGLMiniBatch.
            data = data.to_dgl()
            # Unpack MiniBatch.
            compacted_pairs, labels = to_binary_link_dgl_computing_pack(data, category)
            node_features = {
                ntype: data.node_features[(ntype, "feat")]
                for ntype in data.blocks[0].srctypes
            }
            # Convert sampled subgraphs to DGL blocks.
            blocks = data.blocks
            # Get the embeddings of the input nodes.
            y = model(blocks, node_features)
            logits = model.predictor(
                y[category][compacted_pairs[0]] * y[category][compacted_pairs[1]]
            ).squeeze()

            # Compute loss.
            loss = F.binary_cross_entropy_with_logits(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
        end_epoch_time = time.time()