.. _guide-minibatch-link-classification-sampler:

6.3 Training GNN for Link Prediction with Neighborhood Sampling
--------------------------------------------------------------------

:ref:`(中文版) <guide_cn-minibatch-link-classification-sampler>`

Define a data loader with neighbor and negative sampling
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

You can still use the same data loader as the one in node/edge classification.
The only difference is that you need to add a *negative sampling* stage before
the neighbor sampling stage. The following data loader will uniformly pick 5
negative destination nodes for each source node of an edge.

.. code:: python

    datapipe = datapipe.sample_uniform_negative(graph, 5)

The whole data loader pipeline is as follows:

.. code:: python

    datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=True)
    datapipe = datapipe.sample_uniform_negative(graph, 5)
    datapipe = datapipe.sample_neighbor(graph, [10, 10]) # 2 layers.
    datapipe = datapipe.transform(gb.exclude_seed_edges)
    datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
    datapipe = datapipe.to_dgl()
    datapipe = datapipe.copy_to(device)
    dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)

For details about the built-in uniform negative sampler, please see
:class:`~dgl.graphbolt.UniformNegativeSampler`.
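
As a quick sanity check, you can pull a single minibatch from the data loader
and inspect the sampled pairs. This is a minimal sketch; it relies on the
``positive_node_pairs`` and ``negative_node_pairs`` minibatch attributes that
are also used by the utility function later in this section.

.. code:: python

    # Fetch one minibatch and inspect the sampled node pairs.
    data = next(iter(dataloader))
    pos_src, pos_dst = data.positive_node_pairs
    neg_src, neg_dst = data.negative_node_pairs
    # With 5 negatives per positive edge, the negative pairs hold 5 times as
    # many entries as the positive ones.
    print(pos_src.shape, neg_src.shape)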

You can also provide your own negative sampler, as long as it inherits from
:class:`~dgl.graphbolt.NegativeSampler` and overrides the
:meth:`~dgl.graphbolt.NegativeSampler._sample_with_etype` method, which takes
the node pairs in the minibatch and returns the negative node pairs.

The following gives an example of a custom negative sampler that samples
negative destination nodes according to a probability distribution
proportional to a power of node degrees.

.. code:: python

    @functional_datapipe("customized_sample_negative")
    class CustomizedNegativeSampler(dgl.graphbolt.NegativeSampler):
        def __init__(self, datapipe, k, node_degrees):
            super().__init__(datapipe, k)
            # Cache the sampling probability distribution, proportional to
            # node_degrees ** 0.75.
            self.weights = node_degrees ** 0.75
            self.k = k

        def _sample_with_etype(self, node_pairs, etype=None):
            src, _ = node_pairs
            # Repeat each source node k times and draw the negative
            # destination nodes from the degree-based distribution.
            src = src.repeat_interleave(self.k)
            dst = self.weights.multinomial(len(src), replacement=True)
            return src, dst

    datapipe = datapipe.customized_sample_negative(5, node_degrees)
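
Note that the ``@functional_datapipe("customized_sample_negative")``
registration is what makes the chained ``datapipe.customized_sample_negative(...)``
call available. ``node_degrees`` here is assumed to be a 1-D tensor of node
degrees, so that ``node_degrees ** 0.75`` yields the unnormalized sampling
weights.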


Define a GraphSAGE model for minibatch training
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. code:: python

    class SAGE(nn.Module):
        def __init__(self, in_size, hidden_size):
            super().__init__()
            self.layers = nn.ModuleList()
            self.layers.append(dglnn.SAGEConv(in_size, hidden_size, "mean"))
            self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, "mean"))
            self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, "mean"))
            self.hidden_size = hidden_size
            self.predictor = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, 1),
            )

        def forward(self, blocks, x):
            hidden_x = x
            for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
                hidden_x = layer(block, hidden_x)
                is_last_layer = layer_idx == len(self.layers) - 1
                if not is_last_layer:
                    hidden_x = F.relu(hidden_x)
            return hidden_x
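
With the model defined, you can instantiate it and move it to the training
device used by the data loader. This is a minimal sketch; ``in_size`` stands
for the dimensionality of the ``"feat"`` node feature and the hidden size of
256 is an arbitrary choice.

.. code:: python

    # A minimal sketch: in_size must match the "feat" feature dimensionality;
    # 256 is an arbitrary hidden size.
    model = SAGE(in_size, 256).to(device)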


When a negative sampler is provided, the data loader will generate positive and
negative node pairs for each minibatch, in addition to the *Message Flow Graphs*
(MFGs). Let's define a utility function that combines the positive and negative
node pairs and generates the corresponding labels:

.. code:: python

    def to_binary_link_dgl_computing_pack(data: gb.DGLMiniBatch):
        """Convert the minibatch to a training pair and a label tensor."""
        pos_src, pos_dst = data.positive_node_pairs
        neg_src, neg_dst = data.negative_node_pairs
        node_pairs = (
            torch.cat((pos_src, neg_src), dim=0),
            torch.cat((pos_dst, neg_dst), dim=0),
        )
        pos_label = torch.ones_like(pos_src)
        neg_label = torch.zeros_like(neg_src)
        labels = torch.cat([pos_label, neg_label], dim=0)
        return (node_pairs, labels.float())
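
The returned ``node_pairs`` concatenate the positive and negative pairs along
dimension 0, and ``labels`` mark positive pairs with 1 and negative pairs with
0, so both can be scored by the predictor in a single pass.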


Training loop
~~~~~~~~~~~~~

The training loop simply involves iterating over the data loader and
feeding in the graphs as well as the input features to the model defined
above.

.. code:: python

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    for epoch in tqdm.trange(args.epochs):
        model.train()
        total_loss = 0
        start_epoch_time = time.time()
        for step, data in enumerate(dataloader):
            # Unpack MiniBatch.
            compacted_pairs, labels = to_binary_link_dgl_computing_pack(data)
            node_feature = data.node_features["feat"]
            # Convert sampled subgraphs to DGL blocks.
            blocks = data.blocks

            # Get the embeddings of the input nodes.
            y = model(blocks, node_feature)
            logits = model.predictor(
                y[compacted_pairs[0]] * y[compacted_pairs[1]]
            ).squeeze()

            # Compute loss.
            loss = F.binary_cross_entropy_with_logits(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
        end_epoch_time = time.time()
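
The loop accumulates ``total_loss`` and records the epoch time; you may want to
report them at the end of each epoch. A minimal sketch, placed at the end of
the epoch loop:

.. code:: python

    # Report the average loss and the elapsed time for the epoch.
    print(
        f"Epoch {epoch:03d} | loss {total_loss / (step + 1):.4f} | "
        f"time {end_epoch_time - start_epoch_time:.1f}s"
    )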


DGL provides an end-to-end
`unsupervised learning GraphSAGE <https://github.com/dmlc/dgl/blob/master/examples/sampling/graphbolt/link_prediction.py>`__
example that shows link prediction on homogeneous graphs.

For heterogeneous graphs
~~~~~~~~~~~~~~~~~~~~~~~~

The previous model can be easily extended to heterogeneous graphs. The only
difference is that you need to use :class:`~dgl.nn.HeteroGraphConv` to wrap
:class:`~dgl.nn.SAGEConv` according to edge types.

.. code:: python

    class SAGE(nn.Module):
        def __init__(self, in_size, hidden_size):
            super().__init__()
            self.layers = nn.ModuleList()
            self.layers.append(dglnn.HeteroGraphConv({
                    rel: dglnn.SAGEConv(in_size, hidden_size, "mean")
                    for rel in rel_names
                }))
            self.layers.append(dglnn.HeteroGraphConv({
                    rel: dglnn.SAGEConv(hidden_size, hidden_size, "mean")
                    for rel in rel_names
                }))
            self.layers.append(dglnn.HeteroGraphConv({
                    rel: dglnn.SAGEConv(hidden_size, hidden_size, "mean")
                    for rel in rel_names
                }))
            self.hidden_size = hidden_size
            self.predictor = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, 1),
            )

        def forward(self, blocks, x):
            hidden_x = x
            for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
                hidden_x = layer(block, hidden_x)
                is_last_layer = layer_idx == len(self.layers) - 1
                if not is_last_layer:
                    hidden_x = F.relu(hidden_x)
            return hidden_x
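
Note that with :class:`~dgl.nn.HeteroGraphConv`, both the input ``x`` and the
returned embeddings are dictionaries keyed by node type, which is why the
training loop below builds a feature dictionary and indexes the model output
by a node type.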


Data loader definition is also very similar to that for homogeneous graphs. The
only difference is that you need to give node types for feature fetching.

.. code:: python

    datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=True)
    datapipe = datapipe.sample_uniform_negative(graph, 5)
    datapipe = datapipe.sample_neighbor(graph, [10, 10]) # 2 layers.
    datapipe = datapipe.transform(gb.exclude_seed_edges)
    datapipe = datapipe.fetch_feature(
        feature,
        node_feature_keys={"user": ["feat"], "item": ["feat"]}
    )
    datapipe = datapipe.to_dgl()
    datapipe = datapipe.copy_to(device)
    dataloader = gb.MultiProcessDataLoader(datapipe, num_workers=0)

If you want to provide your own negative sampling function, just inherit from
the :class:`~dgl.graphbolt.NegativeSampler` class and override the
:meth:`~dgl.graphbolt.NegativeSampler._sample_with_etype` method.

.. code:: python

    @functional_datapipe("customized_sample_negative")
    class CustomizedNegativeSampler(dgl.graphbolt.NegativeSampler):
        def __init__(self, datapipe, k, node_degrees):
            super().__init__(datapipe, k)
            # Cache one sampling probability distribution per edge type,
            # proportional to node_degrees ** 0.75.
            self.weights = {
                etype: node_degrees[etype] ** 0.75 for etype in node_degrees
            }
            self.k = k

        def _sample_with_etype(self, node_pairs, etype):
            src, _ = node_pairs
            src = src.repeat_interleave(self.k)
            dst = self.weights[etype].multinomial(len(src), replacement=True)
            return src, dst

    datapipe = datapipe.customized_sample_negative(5, node_degrees)
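
Here ``node_degrees`` is assumed to be a dictionary mapping each edge type to a
tensor of candidate destination-node degrees, since the sampler builds one
distribution per edge type and indexes ``self.weights[etype]`` when sampling.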


For heterogeneous graphs, node pairs are grouped by edge types.

.. code:: python

    def to_binary_link_dgl_computing_pack(data: gb.DGLMiniBatch, etype):
        """Convert the minibatch to a training pair and a label tensor."""
        pos_src, pos_dst = data.positive_node_pairs[etype]
        neg_src, neg_dst = data.negative_node_pairs[etype]
        node_pairs = (
            torch.cat((pos_src, neg_src), dim=0),
            torch.cat((pos_dst, neg_dst), dim=0),
        )
        pos_label = torch.ones_like(pos_src)
        neg_label = torch.zeros_like(neg_src)
        labels = torch.cat([pos_label, neg_label], dim=0)
        return (node_pairs, labels.float())


The training loop is again almost the same as that on homogeneous graphs,
except that the loss is computed on a specific edge type.

.. code:: python

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    category = "user"
    for epoch in tqdm.trange(args.epochs):
        model.train()
        total_loss = 0
        start_epoch_time = time.time()
        for step, data in enumerate(dataloader):
            # Unpack MiniBatch.
            compacted_pairs, labels = to_binary_link_dgl_computing_pack(data, category)
            node_features = {
                ntype: data.node_features[(ntype, "feat")]
                for ntype in data.blocks[0].srctypes
            }
            # Convert sampled subgraphs to DGL blocks.
            blocks = data.blocks
            # Get the embeddings of the input nodes.
            y = model(blocks, node_features)
            logits = model.predictor(
                y[category][compacted_pairs[0]] * y[category][compacted_pairs[1]]
            ).squeeze()

            # Compute loss.
            loss = F.binary_cross_entropy_with_logits(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
        end_epoch_time = time.time()