.. _guide-minibatch-link-classification-sampler:

6.3 Training GNN for Link Prediction with Neighborhood Sampling
--------------------------------------------------------------------

:ref:`(中文版) <guide_cn-minibatch-link-classification-sampler>`

Define a neighborhood sampler and data loader with negative sampling
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

You can still use the same neighborhood sampler as the one in node/edge
classification.

.. code:: python

    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
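
As a rough mental model of what ``MultiLayerFullNeighborSampler(2)`` computes, the
following plain-Python sketch (no DGL required) expands a set of seed nodes by all of
their in-neighbors, once per layer. The helper ``full_neighbor_layers`` and its
adjacency-dictionary input are hypothetical illustrations, not DGL's implementation.
As in DGL's MFGs, the destination nodes of each layer are also listed first among its
source nodes. In link prediction, the seeds are the endpoints of the positive and
negative edges in the minibatch.

.. code:: python

    def full_neighbor_layers(in_neighbors, seeds, num_layers):
        """Hypothetical sketch of full-neighbor sampling.

        in_neighbors maps each node ID to the list of its in-neighbors.
        Returns one (src_nodes, dst_nodes) pair per layer, ordered from the
        input layer to the output layer.
        """
        layers = []
        dst_nodes = list(seeds)
        for _ in range(num_layers):
            # destination nodes come first among the source nodes
            src_nodes = list(dst_nodes)
            seen = set(src_nodes)
            for v in dst_nodes:
                for u in in_neighbors.get(v, []):
                    if u not in seen:
                        seen.add(u)
                        src_nodes.append(u)
            layers.append((src_nodes, dst_nodes))
            # the next (earlier) layer must produce all of these source nodes
            dst_nodes = src_nodes
        layers.reverse()
        return layers

    in_neighbors = {0: [1, 2], 1: [3], 2: [3, 4], 4: [5]}
    layers = full_neighbor_layers(in_neighbors, seeds=[0], num_layers=2)
    print(layers[0][0])  # input nodes: [0, 1, 2, 3, 4]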

:class:`~dgl.dataloading.pytorch.EdgeDataLoader` in DGL also
supports generating negative samples for link prediction. To use this,
you need to provide a negative sampler.
:class:`~dgl.dataloading.negative_sampler.Uniform` is a negative sampler
that samples uniformly: for each source node of an edge, it draws ``k``
negative destination nodes uniformly at random.

The following data loader will pick 5 negative destination nodes
uniformly for each source node of an edge.

.. code:: python

    dataloader = dgl.dataloading.EdgeDataLoader(
        g, train_seeds, sampler,
        negative_sampler=dgl.dataloading.negative_sampler.Uniform(5),
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
        pin_memory=True,
        num_workers=args.num_workers)
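
To make the behavior of ``Uniform(k)`` concrete, here is a DGL-free sketch in plain
Python. The function ``uniform_negative_sample`` is a hypothetical illustration, not
DGL's code: each source node is repeated ``k`` times (like ``repeat_interleave``), and
each copy is paired with a destination drawn uniformly from all nodes. Nothing prevents
a sampled pair from being an existing edge; such false negatives are typically rare in
large, sparse graphs.

.. code:: python

    import random

    def uniform_negative_sample(src, num_nodes, k, rng=random):
        """Hypothetical sketch of uniform negative sampling.

        src holds the source node of every positive edge in the minibatch.
        Returns (neg_src, neg_dst) with k negative pairs per positive edge,
        grouped contiguously per positive edge.
        """
        neg_src = [s for s in src for _ in range(k)]  # like repeat_interleave(k)
        neg_dst = [rng.randrange(num_nodes) for _ in neg_src]
        return neg_src, neg_dst

    neg_src, neg_dst = uniform_negative_sample([0, 2, 3], num_nodes=10, k=2,
                                               rng=random.Random(0))
    print(neg_src)  # [0, 0, 2, 2, 3, 3]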

For the built-in negative samplers, please see :ref:`api-dataloading-negative-sampling`.

You can also provide your own negative sampler function, as long as it
takes in the original graph ``g`` and the minibatch edge ID array
``eids``, and returns a pair of source and destination node ID arrays.

The following example is a custom negative sampler that samples
negative destination nodes according to a probability distribution
proportional to a power of the in-degrees.

.. code:: python

    class NegativeSampler(object):
        def __init__(self, g, k):
            # caches the probability distribution
            self.weights = g.in_degrees().float() ** 0.75
            self.k = k
    
        def __call__(self, g, eids):
            src, _ = g.find_edges(eids)
            src = src.repeat_interleave(self.k)
            dst = self.weights.multinomial(len(src), replacement=True)
            return src, dst
    
    dataloader = dgl.dataloading.EdgeDataLoader(
        g, train_seeds, sampler,
        negative_sampler=NegativeSampler(g, 5),
        batch_size=args.batch_size,
        shuffle=True,
        drop_last=False,
        pin_memory=True,
        num_workers=args.num_workers)
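
The same idea without DGL or PyTorch: the sketch below computes the unnormalized
weights ``in_degree ** 0.75`` from a plain list of edge destinations and samples with
``random.choices``, which here plays the role of ``multinomial(..., replacement=True)``.
The helper names ``degree_weights`` and ``degree_negative_sample`` are hypothetical.
Raising degrees to the power 0.75 (the same exponent word2vec uses for its negative
sampling distribution) still favors high-degree nodes, but less strongly than sampling
proportional to degree, and nodes with zero in-degree are never drawn.

.. code:: python

    import random
    from collections import Counter

    def degree_weights(dst, num_nodes, power=0.75):
        """Hypothetical helper: in-degree ** power for every node 0..num_nodes-1."""
        in_deg = Counter(dst)
        return [in_deg.get(v, 0) ** power for v in range(num_nodes)]

    def degree_negative_sample(src, weights, k, rng=random):
        """Hypothetical sketch: k negatives per positive edge, dst drawn ~ weights."""
        neg_src = [s for s in src for _ in range(k)]
        neg_dst = rng.choices(range(len(weights)), weights=weights, k=len(neg_src))
        return neg_src, neg_dst

    # node 1 has in-degree 4, node 2 has in-degree 1, node 0 has none
    weights = degree_weights(dst=[1, 1, 1, 1, 2], num_nodes=3)
    neg_src, neg_dst = degree_negative_sample([0, 2], weights, k=3,
                                              rng=random.Random(0))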

Adapt your model for minibatch training
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~

As explained in :ref:`guide-training-link-prediction`, link prediction is trained
by comparing the score of an edge (positive example) against the score of a
non-existent edge (negative example). To compute the edge scores, you
can reuse the node representation model you have seen in
edge classification/regression.

.. code:: python

    import dgl
    import torch.nn as nn
    import torch.nn.functional as F

    class StochasticTwoLayerGCN(nn.Module):
        def __init__(self, in_features, hidden_features, out_features):
            super().__init__()
            self.conv1 = dgl.nn.GraphConv(in_features, hidden_features)
            self.conv2 = dgl.nn.GraphConv(hidden_features, out_features)

        def forward(self, blocks, x):
            x = F.relu(self.conv1(blocks[0], x))
            x = F.relu(self.conv2(blocks[1], x))
            return x

For score prediction, since you only need to predict a scalar score for
each edge instead of a probability distribution, this example shows how
to compute a score with a dot product of incident node representations.

.. code:: python

    class ScorePredictor(nn.Module):
        def forward(self, edge_subgraph, x):
            with edge_subgraph.local_scope():
                edge_subgraph.ndata['x'] = x
                edge_subgraph.apply_edges(dgl.function.u_dot_v('x', 'x', 'score'))
                return edge_subgraph.edata['score']
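
``u_dot_v('x', 'x', 'score')`` computes, for every edge ``(u, v)``, the dot product of
the representations of ``u`` and ``v``. The plain-Python sketch below (the helper
``edge_dot_scores`` is hypothetical) shows the same arithmetic on explicit lists. DGL
additionally keeps a trailing dimension of size 1, so ``edata['score']`` has shape
``(num_edges, 1)``.

.. code:: python

    def edge_dot_scores(src, dst, h):
        """Hypothetical sketch: score[e] = <h[src[e]], h[dst[e]]> for every edge e."""
        return [sum(a * b for a, b in zip(h[u], h[v])) for u, v in zip(src, dst)]

    h = [[1.0, 0.0],   # node 0
         [0.5, 2.0],   # node 1
         [1.0, 1.0]]   # node 2
    print(edge_dot_scores([0, 1], [2, 2], h))  # [1.0, 2.5]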

When a negative sampler is provided, DGL's data loader generates, in
addition to the IDs of the input nodes, three items per minibatch:

-  A positive graph containing all the edges sampled in the minibatch.

-  A negative graph containing all the non-existent edges generated by
   the negative sampler.

-  A list of *message flow graphs* (MFGs) generated by the neighborhood sampler.
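
One detail explains why a single set of node representations can score both graphs:
the positive and negative graphs share one node set, and the MFGs are built so that
their final output nodes are exactly those nodes. The plain-Python sketch below mimics
such a joint relabeling; ``compact_pair_graphs`` is a hypothetical helper, and the node
order DGL actually uses may differ.

.. code:: python

    def compact_pair_graphs(pos_edges, neg_edges):
        """Hypothetical sketch: relabel both edge lists onto one shared node set.

        Returns the original IDs of the shared nodes (the seed nodes for the
        neighborhood sampler) and both edge lists in the new compact IDs.
        """
        new_id = {}
        for u, v in pos_edges + neg_edges:
            for node in (u, v):
                if node not in new_id:
                    new_id[node] = len(new_id)

        def relabel(edges):
            return [(new_id[u], new_id[v]) for u, v in edges]

        seed_nodes = list(new_id)  # dicts preserve insertion order
        return seed_nodes, relabel(pos_edges), relabel(neg_edges)

    seeds, pos, neg = compact_pair_graphs([(10, 20), (30, 20)], [(10, 40), (30, 50)])
    print(seeds)  # [10, 20, 30, 40, 50]
    print(neg)    # [(0, 3), (2, 4)]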

You can therefore define the link prediction model as follows, taking in
these three items as well as the input features.

.. code:: python

    class Model(nn.Module):
        def __init__(self, in_features, hidden_features, out_features):
            super().__init__()
            self.gcn = StochasticTwoLayerGCN(
                in_features, hidden_features, out_features)
            self.predictor = ScorePredictor()
    
        def forward(self, positive_graph, negative_graph, blocks, x):
            x = self.gcn(blocks, x)
            pos_score = self.predictor(positive_graph, x)
            neg_score = self.predictor(negative_graph, x)
            return pos_score, neg_score

Training loop
~~~~~~~~~~~~~

The training loop simply involves iterating over the data loader and
feeding in the graphs as well as the input features to the model defined
above.

.. code:: python

    def compute_loss(pos_score, neg_score):
        # an example hinge loss
        n = pos_score.shape[0]
        return (neg_score.view(n, -1) - pos_score.view(n, -1) + 1).clamp(min=0).mean()

    model = Model(in_features, hidden_features, out_features)
    model = model.cuda()
    opt = torch.optim.Adam(model.parameters())

    for input_nodes, positive_graph, negative_graph, blocks in dataloader:
        blocks = [b.to(torch.device('cuda')) for b in blocks]
        positive_graph = positive_graph.to(torch.device('cuda'))
        negative_graph = negative_graph.to(torch.device('cuda'))
        input_features = blocks[0].srcdata['features']
        pos_score, neg_score = model(positive_graph, negative_graph, blocks, input_features)
        loss = compute_loss(pos_score, neg_score)
        opt.zero_grad()
        loss.backward()
        opt.step()
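
The reshaping in ``compute_loss`` relies on the negatives of each positive edge being
stored contiguously, which holds for the custom sampler above because it builds the
sources with ``repeat_interleave``. With ``n`` positive edges and ``k`` negatives each,
``neg_score.view(n, -1)`` has shape ``(n, k)`` and is compared against
``pos_score.view(n, -1)`` of shape ``(n, 1)`` by broadcasting. The loop-based sketch
below (``hinge_loss`` is a hypothetical helper, with margin 1 as in ``compute_loss``)
spells out the same computation in plain Python.

.. code:: python

    def hinge_loss(pos_score, neg_score, margin=1.0):
        """Hypothetical sketch of compute_loss on plain lists.

        pos_score has n entries; neg_score has n * k entries, where entries
        i * k .. i * k + k - 1 are the negatives of positive edge i.
        """
        n = len(pos_score)
        k = len(neg_score) // n
        total = 0.0
        for i, pos in enumerate(pos_score):
            for j in range(k):
                total += max(0.0, neg_score[i * k + j] - pos + margin)
        return total / (n * k)  # mean over all n * k (positive, negative) pairs

    print(hinge_loss([2.0, 0.0], [0.5, 1.5, -1.0, 0.5]))  # 0.5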

DGL provides the
`unsupervised learning GraphSAGE <https://github.com/dmlc/dgl/blob/master/examples/pytorch/graphsage/train_sampling_unsupervised.py>`__
example, which demonstrates link prediction on homogeneous graphs.

For heterogeneous graphs
~~~~~~~~~~~~~~~~~~~~~~~~
The models that compute node representations on heterogeneous graphs
for edge classification/regression can also be used here to compute the
incident node representations for link prediction.

.. code:: python

    class StochasticTwoLayerRGCN(nn.Module):
        def __init__(self, in_feat, hidden_feat, out_feat, rel_names):
            super().__init__()
            self.conv1 = dglnn.HeteroGraphConv({
                    rel : dglnn.GraphConv(in_feat, hidden_feat, norm='right')
                    for rel in rel_names
                })
            self.conv2 = dglnn.HeteroGraphConv({
                    rel : dglnn.GraphConv(hidden_feat, out_feat, norm='right')
                    for rel in rel_names
                })
    
        def forward(self, blocks, x):
            x = self.conv1(blocks[0], x)
            x = self.conv2(blocks[1], x)
            return x
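
``HeteroGraphConv`` applies one module per relation and, by default, sums the outputs
of all relations that share a destination node type, which is why ``x`` is a
dictionary keyed by node type at every layer. The plain-Python sketch below shows only
that summation step on already-computed per-relation outputs; ``sum_by_dst_type`` and
the toy relation names are hypothetical.

.. code:: python

    def sum_by_dst_type(per_relation_outputs):
        """Hypothetical sketch of HeteroGraphConv's default 'sum' aggregation.

        per_relation_outputs maps a canonical edge type (src_type, etype, dst_type)
        to one output vector per destination node of that relation.
        """
        out = {}
        for (_, _, dst_type), rows in per_relation_outputs.items():
            if dst_type not in out:
                out[dst_type] = [list(row) for row in rows]
            else:
                out[dst_type] = [[a + b for a, b in zip(acc, row)]
                                 for acc, row in zip(out[dst_type], rows)]
        return out

    outputs = {
        ('user', 'follows', 'user'): [[1.0, 0.0], [0.0, 1.0]],
        ('item', 'bought_by', 'user'): [[0.5, 0.5], [2.0, 0.0]],
        ('user', 'buys', 'item'): [[1.0, 1.0]],
    }
    x = sum_by_dst_type(outputs)
    print(x['user'])  # [[1.5, 0.5], [2.0, 1.0]]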

For score prediction, the only implementation difference between
homogeneous and heterogeneous graphs is that you loop over the edge
types when calling :meth:`dgl.DGLHeteroGraph.apply_edges`.

.. code:: python

    class ScorePredictor(nn.Module):
        def forward(self, edge_subgraph, x):
            with edge_subgraph.local_scope():
                edge_subgraph.ndata['x'] = x
                for etype in edge_subgraph.canonical_etypes:
                    edge_subgraph.apply_edges(
                        dgl.function.u_dot_v('x', 'x', 'score'), etype=etype)
                return edge_subgraph.edata['score']

    class Model(nn.Module):
        def __init__(self, in_features, hidden_features, out_features, num_classes,
                     etypes):
            super().__init__()
            self.rgcn = StochasticTwoLayerRGCN(
                in_features, hidden_features, out_features, etypes)
            self.pred = ScorePredictor()

        def forward(self, positive_graph, negative_graph, blocks, x):
            x = self.rgcn(blocks, x)
            pos_score = self.pred(positive_graph, x)
            neg_score = self.pred(negative_graph, x)
            return pos_score, neg_score
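
On a heterogeneous graph with several edge types, ``edge_subgraph.edata['score']`` is a
dictionary keyed by canonical edge type, and the two endpoints of an edge may belong to
different node types. A plain-Python sketch of the per-edge-type dot product
(``hetero_edge_dot_scores`` and the toy data are hypothetical):

.. code:: python

    def hetero_edge_dot_scores(edges, h):
        """Hypothetical sketch of the heterogeneous ScorePredictor.

        edges maps (src_type, etype, dst_type) to (src_ids, dst_ids);
        h maps each node type to a list of per-node vectors.
        """
        scores = {}
        for (src_type, etype, dst_type), (src, dst) in edges.items():
            h_src, h_dst = h[src_type], h[dst_type]
            scores[(src_type, etype, dst_type)] = [
                sum(a * b for a, b in zip(h_src[u], h_dst[v]))
                for u, v in zip(src, dst)]
        return scores

    h = {'user': [[1.0, 2.0], [0.0, 1.0]], 'item': [[3.0, 0.0], [1.0, 1.0]]}
    edges = {('user', 'buys', 'item'): ([0, 1], [1, 0]),
             ('user', 'follows', 'user'): ([0], [1])}
    print(hetero_edge_dot_scores(edges, h))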

Data loader definition is also very similar to that of edge
classification/regression. The only difference is that you need to provide
a negative sampler. As in edge classification/regression, you supply a
dictionary mapping edge types to edge ID tensors, instead of the dictionary
of node types and node ID tensors used in node classification.

.. code:: python

    sampler = dgl.dataloading.MultiLayerFullNeighborSampler(2)
    dataloader = dgl.dataloading.EdgeDataLoader(
        g, train_eid_dict, sampler,
        negative_sampler=dgl.dataloading.negative_sampler.Uniform(5),
        batch_size=1024,
        shuffle=True,
        drop_last=False,
        num_workers=4)

If you want to provide your own negative sampling function, it should
take in the original graph and a dictionary mapping edge types to edge ID
tensors, and return a dictionary mapping edge types to pairs of source and
destination node ID arrays. An example is given below:

.. code:: python

    class NegativeSampler(object):
        def __init__(self, g, k):
            # caches the probability distribution
            self.weights = {
                etype: g.in_degrees(etype=etype).float() ** 0.75
                for etype in g.canonical_etypes}
            self.k = k

        def __call__(self, g, eids_dict):
            result_dict = {}
            for etype, eids in eids_dict.items():
                src, _ = g.find_edges(eids, etype=etype)
                src = src.repeat_interleave(self.k)
                dst = self.weights[etype].multinomial(len(src), replacement=True)
                result_dict[etype] = (src, dst)
            return result_dict
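
One point to watch in a heterogeneous negative sampler is that the negative
destinations must be IDs of the destination node type of each edge type. In the
sampler above this holds because ``g.in_degrees(etype=etype)`` has one entry per node
of that destination type. The plain-Python sketch below makes the same point for
uniform sampling; ``hetero_uniform_negative_sample`` and its ``num_nodes`` dictionary
are hypothetical.

.. code:: python

    import random

    def hetero_uniform_negative_sample(src_by_etype, num_nodes, k, rng=random):
        """Hypothetical sketch: k uniform negatives per positive edge, per edge type.

        src_by_etype maps (src_type, etype, dst_type) to the source IDs of the
        minibatch edges; num_nodes maps each node type to its number of nodes.
        """
        result = {}
        for (src_type, etype, dst_type), src in src_by_etype.items():
            neg_src = [s for s in src for _ in range(k)]
            # destinations are drawn from the destination node type only
            neg_dst = [rng.randrange(num_nodes[dst_type]) for _ in neg_src]
            result[(src_type, etype, dst_type)] = (neg_src, neg_dst)
        return result

    neg = hetero_uniform_negative_sample(
        {('user', 'buys', 'item'): [0, 5]}, num_nodes={'user': 100, 'item': 3},
        k=4, rng=random.Random(0))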

Then you can give the data loader a dictionary of edge types and edge IDs as well as the negative
sampler. For instance, the following iterates over all edges of the heterogeneous graph.

.. code:: python

    train_eid_dict = {
        etype: g.edges(etype=etype, form='eid')
        for etype in g.canonical_etypes}

    dataloader = dgl.dataloading.EdgeDataLoader(
        g, train_eid_dict, sampler,
        negative_sampler=NegativeSampler(g, 5),
        batch_size=1024,
        shuffle=True,
        drop_last=False,
        num_workers=4)

The training loop is again almost the same as that on homogeneous graphs,
except that ``compute_loss`` here takes in two dictionaries that map edge
types to predictions.

.. code:: python

    model = Model(in_features, hidden_features, out_features, num_classes, etypes)
    model = model.cuda()
    opt = torch.optim.Adam(model.parameters())

    for input_nodes, positive_graph, negative_graph, blocks in dataloader:
        blocks = [b.to(torch.device('cuda')) for b in blocks]
        positive_graph = positive_graph.to(torch.device('cuda'))
        negative_graph = negative_graph.to(torch.device('cuda'))
        input_features = blocks[0].srcdata['features']
        pos_score, neg_score = model(positive_graph, negative_graph, blocks, input_features)
        loss = compute_loss(pos_score, neg_score)
        opt.zero_grad()
        loss.backward()
        opt.step()
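
The heterogeneous ``compute_loss`` itself is not shown above. One possible version,
sketched here in plain Python under a hypothetical name and assuming that each edge
type's negatives are stored contiguously (``k`` per positive edge), applies the same
hinge loss within each edge type and averages over all (positive, negative) pairs.
Edge types with no positive edges in the minibatch are skipped.

.. code:: python

    def hetero_hinge_loss(pos_score, neg_score, margin=1.0):
        """Hypothetical sketch of compute_loss for dictionaries keyed by edge type.

        For every edge type, neg_score[etype] holds k contiguous negatives per
        entry of pos_score[etype]. Returns the mean hinge loss over all
        (positive, negative) pairs of all edge types.
        """
        total, num_pairs = 0.0, 0
        for etype, pos in pos_score.items():
            neg = neg_score[etype]
            if not pos:  # edge type absent from this minibatch
                continue
            k = len(neg) // len(pos)
            for i, p in enumerate(pos):
                for j in range(k):
                    total += max(0.0, neg[i * k + j] - p + margin)
            num_pairs += len(pos) * k
        return total / num_pairs if num_pairs else 0.0

    pos = {('user', 'buys', 'item'): [2.0], ('user', 'follows', 'user'): []}
    neg = {('user', 'buys', 'item'): [1.5, 3.0], ('user', 'follows', 'user'): []}
    print(hetero_hinge_loss(pos, neg))  # 1.25

Averaging over all pairs weights each edge type by how many of its edges appear in the
minibatch; averaging the per-type means instead would weight edge types equally.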