.. _guide-minibatch-link-classification-sampler:

6.3 Training GNN for Link Prediction with Neighborhood Sampling
--------------------------------------------------------------------

:ref:`(中文版) <guide_cn-minibatch-link-classification-sampler>`

Define a data loader with neighbor and negative sampling
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

You can still use the same data loader as the one in node/edge classification.
The only difference is that you need to add an additional *negative sampling*
stage before the neighbor sampling stage. The following data loader will pick
5 negative destination nodes uniformly for each source node of an edge.

.. code:: python

    datapipe = datapipe.sample_uniform_negative(graph, 5)

The whole data loader pipeline is as follows:

.. code:: python

    datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=True)
    datapipe = datapipe.sample_uniform_negative(graph, 5)
    datapipe = datapipe.sample_neighbor(graph, [10, 10, 10])  # 3 layers.
    datapipe = datapipe.transform(gb.exclude_seed_edges)
    datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
    datapipe = datapipe.copy_to(device)
    dataloader = gb.DataLoader(datapipe)
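
The snippet above assumes that ``gb`` refers to :mod:`dgl.graphbolt` and that
``itemset``, ``graph``, ``feature`` and ``device`` have already been prepared.
A minimal setup sketch, assuming the builtin ``ogbl-citation2`` dataset and
that its first task is the link prediction task, could look like this:

.. code:: python

    import torch
    import dgl.graphbolt as gb

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load a builtin dataset and pull out the pieces used by the pipeline.
    dataset = gb.BuiltinDataset("ogbl-citation2").load()
    graph = dataset.graph                   # the sampling graph
    feature = dataset.feature               # the feature store
    itemset = dataset.tasks[0].train_set    # positive (seed) edges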


For details about the built-in uniform negative sampler, please refer to
:class:`~dgl.graphbolt.UniformNegativeSampler`.

You can also provide your own negative sampler, as long as it inherits from
:class:`~dgl.graphbolt.NegativeSampler` and overrides the
:meth:`~dgl.graphbolt.NegativeSampler._sample_with_etype` method, which takes
the node pairs in the minibatch and returns the negative node pairs.

The following gives an example of a custom negative sampler that samples
negative destination nodes according to a probability distribution
proportional to a power of the node degrees.

.. code:: python

    @functional_datapipe("customized_sample_negative")
    class CustomizedNegativeSampler(dgl.graphbolt.NegativeSampler):
        def __init__(self, datapipe, k, node_degrees):
            super().__init__(datapipe, k)
            # Cache a probability distribution proportional to degree^0.75.
            self.weights = node_degrees ** 0.75
            self.k = k

        def _sample_with_etype(self, node_pairs, etype=None):
            src, _ = node_pairs
            # Each source node draws k negative destinations from the
            # degree-based distribution.
            src = src.repeat_interleave(self.k)
            dst = self.weights.multinomial(len(src), replacement=True)
            return src, dst

    datapipe = datapipe.customized_sample_negative(5, node_degrees)
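
Note that the ``functional_datapipe`` decorator (available from
``torch.utils.data`` or ``torchdata.datapipes``) registers the class under the
given name so that it can be chained onto the pipeline as shown above. The
``node_degrees`` tensor is not provided by GraphBolt; one possible (assumed)
way to build it, if ``graph`` is a :class:`~dgl.graphbolt.FusedCSCSamplingGraph`
stored in CSC format, is to take the difference of its index pointer array:

.. code:: python

    import torch
    from torch.utils.data import functional_datapipe

    # Assumed: per-node in-degrees derived from the CSC index pointer.
    node_degrees = torch.diff(graph.csc_indptr).float()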


Define a GraphSAGE model for minibatch training
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

.. code:: python

    class SAGE(nn.Module):
        def __init__(self, in_size, hidden_size):
            super().__init__()
            self.layers = nn.ModuleList()
            self.layers.append(dglnn.SAGEConv(in_size, hidden_size, "mean"))
            self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, "mean"))
            self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, "mean"))
            self.hidden_size = hidden_size
            self.predictor = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, 1),
            )

        def forward(self, blocks, x):
            hidden_x = x
            for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
                hidden_x = layer(block, hidden_x)
                is_last_layer = layer_idx == len(self.layers) - 1
                if not is_last_layer:
                    hidden_x = F.relu(hidden_x)
            return hidden_x
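
The class definition above relies on ``torch.nn`` (``nn``),
``torch.nn.functional`` (``F``) and ``dgl.nn`` (``dglnn``). A small usage
sketch, assuming the feature store from the data loader section and a
hypothetical hidden size of 256:

.. code:: python

    import torch.nn as nn
    import torch.nn.functional as F
    import dgl.nn as dglnn

    # Assumed: the feature store exposes the per-node feature shape.
    in_size = feature.size("node", None, "feat")[0]
    model = SAGE(in_size, 256).to(device)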


When a negative sampler is provided, the data loader will generate positive and
negative node pairs for each minibatch in addition to the *Message Flow Graphs*
(MFGs). Let's define a utility function that concatenates the positive and
negative node pairs and builds the corresponding binary labels:

.. code:: python

    def to_binary_link_dgl_computing_pack(data: gb.DGLMiniBatch):
        """Convert the minibatch to a training pair and a label tensor."""
        pos_src, pos_dst = data.positive_node_pairs
        neg_src, neg_dst = data.negative_node_pairs
        node_pairs = (
            torch.cat((pos_src, neg_src), dim=0),
            torch.cat((pos_dst, neg_dst), dim=0),
        )
        pos_label = torch.ones_like(pos_src)
        neg_label = torch.zeros_like(neg_src)
        labels = torch.cat([pos_label, neg_label], dim=0)
        return (node_pairs, labels.float())


Training loop
~~~~~~~~~~~~~

The training loop simply involves iterating over the data loader and feeding
the sampled blocks as well as the input node features into the model defined
above.

.. code:: python

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    for epoch in tqdm.trange(args.epochs):
        model.train()
        total_loss = 0
        start_epoch_time = time.time()
        for step, data in enumerate(dataloader):
            # Unpack MiniBatch.
            compacted_pairs, labels = to_binary_link_dgl_computing_pack(data)
            node_feature = data.node_features["feat"]
            # Convert sampled subgraphs to DGL blocks.
            blocks = data.blocks

            # Get the embeddings of the input nodes.
            y = model(blocks, node_feature)
            logits = model.predictor(
                y[compacted_pairs[0]] * y[compacted_pairs[1]]
            ).squeeze()

            # Compute loss.
            loss = F.binary_cross_entropy_with_logits(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
        end_epoch_time = time.time()
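
        # A simple logging sketch: report average loss and epoch wall-clock time.
        print(
            f"Epoch {epoch:03d} | loss {total_loss / (step + 1):.4f} | "
            f"time {end_epoch_time - start_epoch_time:.2f}s"
        )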

DGL provides the
`unsupervised learning GraphSAGE <https://github.com/dmlc/dgl/blob/master/examples/sampling/graphbolt/link_prediction.py>`__
example, which demonstrates link prediction on homogeneous graphs.

For heterogeneous graphs
~~~~~~~~~~~~~~~~~~~~~~~~

The previous model can easily be extended to heterogeneous graphs. The only
difference is that you need to use :class:`~dgl.nn.HeteroGraphConv` to wrap
:class:`~dgl.nn.SAGEConv` according to the edge types.

.. code:: python

    class SAGE(nn.Module):
        def __init__(self, in_size, hidden_size):
            super().__init__()
            self.layers = nn.ModuleList()
            self.layers.append(dglnn.HeteroGraphConv({
                    rel: dglnn.SAGEConv(in_size, hidden_size, "mean")
                    for rel in rel_names
                }))
            self.layers.append(dglnn.HeteroGraphConv({
                    rel: dglnn.SAGEConv(hidden_size, hidden_size, "mean")
                    for rel in rel_names
                }))
            self.layers.append(dglnn.HeteroGraphConv({
                    rel: dglnn.SAGEConv(hidden_size, hidden_size, "mean")
                    for rel in rel_names
                }))
            self.hidden_size = hidden_size
            self.predictor = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, 1),
            )

        def forward(self, blocks, x):
            hidden_x = x
            for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
                hidden_x = layer(block, hidden_x)
                is_last_layer = layer_idx == len(self.layers) - 1
                if not is_last_layer:
                    hidden_x = {ntype: F.relu(h) for ntype, h in hidden_x.items()}
            return hidden_x
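
The heterogeneous model assumes a ``rel_names`` variable holding the relation
(edge type) names of the graph, and the output of each layer is a dictionary
keyed by node type. A hypothetical instantiation for a user-item graph:

.. code:: python

    # Hypothetical relation names for a user-item graph.
    rel_names = ["click", "clicked-by"]

    model = SAGE(in_size, 256).to(device)  # in_size and device as before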


The data loader definition is also very similar to that for homogeneous graphs.
The only difference is that you need to specify node types for feature
fetching.

.. code:: python

    datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=True)
    datapipe = datapipe.sample_uniform_negative(graph, 5)
    datapipe = datapipe.sample_neighbor(graph, [10, 10, 10])  # 3 layers.
    datapipe = datapipe.transform(gb.exclude_seed_edges)
    datapipe = datapipe.fetch_feature(
        feature,
        node_feature_keys={"user": ["feat"], "item": ["feat"]}
    )
    datapipe = datapipe.copy_to(device)
    dataloader = gb.DataLoader(datapipe)
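
For heterogeneous graphs the ``itemset`` groups the training edges by edge
type. A hypothetical sketch, assuming GraphBolt's ``ItemSetDict`` wrapping one
``ItemSet`` of seed node pairs per edge type:

.. code:: python

    # Hypothetical: one ItemSet of (src, dst) training pairs per edge type.
    itemset = gb.ItemSetDict({
        "user:click:item": gb.ItemSet(click_pairs, names="node_pairs"),
        "user:follow:user": gb.ItemSet(follow_pairs, names="node_pairs"),
    })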

If you want to provide your own negative sampling function, just inherit from
the :class:`~dgl.graphbolt.NegativeSampler` class and override the
:meth:`~dgl.graphbolt.NegativeSampler._sample_with_etype` method.

.. code:: python

    @functional_datapipe("customized_sample_negative")
    class CustomizedNegativeSampler(dgl.graphbolt.NegativeSampler):
        def __init__(self, datapipe, k, node_degrees):
            super().__init__(datapipe, k)
            # caches the probability distribution
            self.weights = {
                etype: node_degrees[etype] ** 0.75 for etype in node_degrees
            }
            self.k = k
    
        def _sample_with_etype(self, node_pairs, etype):
            src, _ = node_pairs
            src = src.repeat_interleave(self.k)
            dst = self.weights[etype].multinomial(len(src), replacement=True)
            return src, dst

    datapipe = datapipe.customized_sample_negative(5, node_degrees)


For heterogeneous graphs, node pairs are grouped by edge types.

.. code:: python

    def to_binary_link_dgl_computing_pack(data: gb.DGLMiniBatch, etype):
        """Convert the minibatch to a training pair and a label tensor."""
        pos_src, pos_dst = data.positive_node_pairs[etype]
        neg_src, neg_dst = data.negative_node_pairs[etype]
        node_pairs = (
            torch.cat((pos_src, neg_src), dim=0),
            torch.cat((pos_dst, neg_dst), dim=0),
        )
        pos_label = torch.ones_like(pos_src)
        neg_label = torch.zeros_like(neg_src)
        labels = torch.cat([pos_label, neg_label], dim=0)
        return (node_pairs, labels.float())


The training loop is again almost the same as that for homogeneous graphs,
except that the loss is computed on a specific edge type.

.. code:: python

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    category = "user"
    for epoch in tqdm.trange(args.epochs):
        model.train()
        total_loss = 0
        start_epoch_time = time.time()
        for step, data in enumerate(dataloader):
            # Unpack MiniBatch.
            compacted_pairs, labels = to_binary_link_dgl_computing_pack(data, category)
            node_features = {
                ntype: data.node_features[(ntype, "feat")]
                for ntype in data.blocks[0].srctypes
            }
            # Convert sampled subgraphs to DGL blocks.
            blocks = data.blocks
            # Get the embeddings of the input nodes.
            y = model(blocks, node_features)
            logits = model.predictor(
                y[category][compacted_pairs[0]] * y[category][compacted_pairs[1]]
            ).squeeze()

            # Compute loss.
            loss = F.binary_cross_entropy_with_logits(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
        end_epoch_time = time.time()