.. _guide-minibatch-link-classification-sampler:

6.3 Training GNN for Link Prediction with Neighborhood Sampling
--------------------------------------------------------------------

:ref:`(中文版) <guide_cn-minibatch-link-classification-sampler>`

Define a data loader with neighbor and negative sampling
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

You can still use the same data loader as the one in node/edge classification.
The only difference is that you need to add an additional *negative sampling*
stage before the neighbor sampling stage. The following data loader will pick
5 negative destination nodes uniformly for each source node of an edge.

.. code:: python

    datapipe = datapipe.sample_uniform_negative(graph, 5)

The whole data loader pipeline is as follows:

.. code:: python

    datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=True)
    datapipe = datapipe.sample_uniform_negative(graph, 5)
    datapipe = datapipe.sample_neighbor(graph, [10, 10, 10])  # 3 layers, matching the model below.
    # Exclude the seed (positive) edges from the sampled subgraphs to avoid
    # information leakage.
    datapipe = datapipe.transform(gb.exclude_seed_edges)
    datapipe = datapipe.fetch_feature(feature, node_feature_keys=["feat"])
    datapipe = datapipe.copy_to(device)
    dataloader = gb.DataLoader(datapipe)

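The ``sample_uniform_negative`` stage above is the functional form of the
built-in :class:`~dgl.graphbolt.UniformNegativeSampler`. As a minimal sketch
(reusing the ``datapipe`` and ``graph`` objects above), the same stage could
also be attached by constructing it explicitly:

.. code:: python

    # Equivalent to datapipe.sample_uniform_negative(graph, 5): the functional
    # form simply passes the upstream datapipe as the first constructor argument.
    datapipe = gb.UniformNegativeSampler(datapipe, graph, 5)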

For details about the built-in uniform negative sampler, please see
:class:`~dgl.graphbolt.UniformNegativeSampler`.

You can also provide your own negative sampler, as long as it inherits from
:class:`~dgl.graphbolt.NegativeSampler` and overrides the
:meth:`~dgl.graphbolt.NegativeSampler._sample_with_etype` method, which takes
the node pairs in the minibatch and returns the negative node pairs.

The following gives an example of a custom negative sampler that samples
negative destination nodes according to a probability distribution
proportional to a power of node degrees.

.. code:: python

    @functional_datapipe("customized_sample_negative")
    class CustomizedNegativeSampler(dgl.graphbolt.NegativeSampler):
        def __init__(self, datapipe, k, node_degrees):
            super().__init__(datapipe, k)
            # Caches the probability distribution (proportional to degree ** 0.75).
            self.weights = node_degrees ** 0.75
            self.k = k

        def _sample_with_etype(self, seeds, etype=None):
            src, _ = seeds.T
            src = src.repeat_interleave(self.k)
            dst = self.weights.multinomial(len(src), replacement=True)
            return src, dst

    datapipe = datapipe.customized_sample_negative(5, node_degrees)

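The ``node_degrees`` tensor used above is assumed to be precomputed. A minimal
sketch of deriving it, assuming ``graph`` is a
:class:`~dgl.graphbolt.FusedCSCSamplingGraph` whose CSC layout stores in-edges:

.. code:: python

    # In-degree of each node, i.e. the difference of consecutive CSC index
    # pointers; cast to float because multinomial() expects float weights.
    node_degrees = (graph.csc_indptr[1:] - graph.csc_indptr[:-1]).float()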

Define a GraphSAGE model for minibatch training
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

The model below consists of a three-layer GraphSAGE encoder and an MLP
predictor that scores a pair of node embeddings.

.. code:: python

    class SAGE(nn.Module):
        def __init__(self, in_size, hidden_size):
            super().__init__()
            self.layers = nn.ModuleList()
            self.layers.append(dglnn.SAGEConv(in_size, hidden_size, "mean"))
            self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, "mean"))
            self.layers.append(dglnn.SAGEConv(hidden_size, hidden_size, "mean"))
            self.hidden_size = hidden_size
            self.predictor = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, 1),
            )

        def forward(self, blocks, x):
            hidden_x = x
            for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
                hidden_x = layer(block, hidden_x)
                is_last_layer = layer_idx == len(self.layers) - 1
                if not is_last_layer:
                    hidden_x = F.relu(hidden_x)
            return hidden_x

When a negative sampler is provided, the data loader will generate positive and
negative node pairs for each minibatch, in addition to the *Message Flow Graphs*
(MFGs). Use ``compacted_seeds`` and ``labels`` to get the compact node pairs and
their corresponding labels.

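A minimal sketch of inspecting one minibatch, assuming the homogeneous pipeline
above with ``batch_size=1024`` and 5 negatives per seed edge:

.. code:: python

    data = next(iter(dataloader))
    # Each positive seed edge is paired with 5 negatives, so a full batch yields
    # 1024 * (1 + 5) = 6144 (src, dst) pairs and matching 0/1 labels.
    print(data.compacted_seeds.shape)  # e.g. torch.Size([6144, 2])
    print(data.labels.shape)           # e.g. torch.Size([6144])
    print(len(data.blocks))            # one MFG per GNN layer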

Training loop
~~~~~~~~~~~~~

The training loop simply involves iterating over the data loader and feeding
the sampled subgraphs (MFGs) and the input features to the model defined
above.

.. code:: python

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    for epoch in tqdm.trange(args.epochs):
        model.train()
        total_loss = 0
        start_epoch_time = time.time()
        for step, data in enumerate(dataloader):
            # Unpack MiniBatch.
            compacted_seeds = data.compacted_seeds.T
            labels = data.labels
            node_feature = data.node_features["feat"]
            # Convert sampled subgraphs to DGL blocks.
            blocks = data.blocks

            # Get the embeddings of the input nodes.
            y = model(blocks, node_feature)
            logits = model.predictor(
                y[compacted_seeds[0]] * y[compacted_seeds[1]]
            ).squeeze()

            # Compute loss.
            loss = F.binary_cross_entropy_with_logits(logits, labels)
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
        end_epoch_time = time.time()

DGL provides an end-to-end
`unsupervised learning GraphSAGE example <https://github.com/dmlc/dgl/blob/master/examples/sampling/graphbolt/link_prediction.py>`__
that demonstrates link prediction on homogeneous graphs.

For heterogeneous graphs
~~~~~~~~~~~~~~~~~~~~~~~~

The previous model can be easily extended to heterogeneous graphs. The only
difference is that you need to use :class:`~dgl.nn.HeteroGraphConv` to wrap
:class:`~dgl.nn.SAGEConv` according to the edge types.

.. code:: python

    class SAGE(nn.Module):
        def __init__(self, in_size, hidden_size):
            super().__init__()
            self.layers = nn.ModuleList()
            self.layers.append(dglnn.HeteroGraphConv({
                    rel : dglnn.SAGEConv(in_size, hidden_size, "mean")
                    for rel in rel_names
                }))
            self.layers.append(dglnn.HeteroGraphConv({
                    rel : dglnn.SAGEConv(hidden_size, hidden_size, "mean")
                    for rel in rel_names
                }))
            self.layers.append(dglnn.HeteroGraphConv({
                    rel : dglnn.SAGEConv(hidden_size, hidden_size, "mean")
                    for rel in rel_names
                }))
            self.hidden_size = hidden_size
            self.predictor = nn.Sequential(
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, hidden_size),
                nn.ReLU(),
                nn.Linear(hidden_size, 1),
            )

        def forward(self, blocks, x):
            hidden_x = x
            for layer_idx, (layer, block) in enumerate(zip(self.layers, blocks)):
                hidden_x = layer(block, hidden_x)
                is_last_layer = layer_idx == len(self.layers) - 1
                if not is_last_layer:
                    # HeteroGraphConv returns a dict of node type -> features,
                    # so apply the nonlinearity per node type.
                    hidden_x = {
                        ntype: F.relu(h) for ntype, h in hidden_x.items()
                    }
            return hidden_x

The data loader definition is also very similar to that for homogeneous graphs.
The only difference is that you need to give node types for feature fetching.

.. code:: python

    datapipe = gb.ItemSampler(itemset, batch_size=1024, shuffle=True)
    datapipe = datapipe.sample_uniform_negative(graph, 5)
    datapipe = datapipe.sample_neighbor(graph, [10, 10, 10])  # 3 layers, matching the model above.
    datapipe = datapipe.transform(gb.exclude_seed_edges)
    datapipe = datapipe.fetch_feature(
        feature,
        node_feature_keys={"user": ["feat"], "item": ["feat"]}
    )
    datapipe = datapipe.copy_to(device)
    dataloader = gb.DataLoader(datapipe)

If you want to provide your own negative sampling function, just inherit from
the :class:`~dgl.graphbolt.NegativeSampler` class and override the
:meth:`~dgl.graphbolt.NegativeSampler._sample_with_etype` method.

.. code:: python

    @functional_datapipe("customized_sample_negative")
    class CustomizedNegativeSampler(dgl.graphbolt.NegativeSampler):
        def __init__(self, datapipe, k, node_degrees):
            super().__init__(datapipe, k)
            # Caches the probability distribution for each edge type.
            self.weights = {
                etype: node_degrees[etype] ** 0.75 for etype in node_degrees
            }
            self.k = k

        def _sample_with_etype(self, seeds, etype):
            src, _ = seeds.T
            src = src.repeat_interleave(self.k)
            dst = self.weights[etype].multinomial(len(src), replacement=True)
            return src, dst

    datapipe = datapipe.customized_sample_negative(5, node_degrees)


For heterogeneous graphs, node pairs are grouped by edge type. The training
loop is again almost the same as that on a homogeneous graph, except that the
loss is computed on a specific edge type.

.. code:: python

    optimizer = torch.optim.Adam(model.parameters(), lr=0.01)

    category = "user"
    for epoch in tqdm.trange(args.epochs):
        model.train()
        total_loss = 0
        start_epoch_time = time.time()
        for step, data in enumerate(dataloader):
            # Unpack MiniBatch.
            compacted_seeds = data.compacted_seeds
            labels = data.labels
            node_features = {
                ntype: data.node_features[(ntype, "feat")]
                for ntype in data.blocks[0].srctypes
            }
            # Convert sampled subgraphs to DGL blocks.
            blocks = data.blocks
            # Get the embeddings of the input nodes.
            y = model(blocks, node_features)
            logits = model.predictor(
                y[category][compacted_seeds[category][:, 0]]
                * y[category][compacted_seeds[category][:, 1]]
            ).squeeze()

            # Compute loss.
            loss = F.binary_cross_entropy_with_logits(logits, labels[category])
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

            total_loss += loss.item()
        end_epoch_time = time.time()