# model.py
"""RGCN layer implementation"""
from collections import defaultdict

import torch as th
import torch.nn as nn
import torch.nn.functional as F
7
8
import tqdm

9
import dgl
10
import dgl.function as fn
11
import dgl.nn as dglnn
12

13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37

class RelGraphConvLayer(nn.Module):
    r"""Relational graph convolution layer built on ``dglnn.HeteroGraphConv``.

    Every relation gets its own ``dglnn.GraphConv`` (``norm="right"``, no
    internal weight or bias). The per-relation weight matrices are owned by
    this layer and handed to each ``GraphConv`` through ``mod_kwargs``, so they
    can optionally be produced by a shared basis (``dglnn.WeightBasis``).

    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    rel_names : list[str]
        Relation names.
    num_bases : int, optional
        Number of bases. If None, use the number of relations (no basis
        decomposition). Default: None.
    weight : bool, optional
        True to apply a per-relation weight matrix during message passing.
        Default: True
    bias : bool, optional
        True if bias is added. Default: True
    activation : callable, optional
        Activation function. Default: None
    self_loop : bool, optional
        True to include self loop message. Default: False
    dropout : float, optional
        Dropout rate. Default: 0.0
    """

    def __init__(
        self,
        in_feat,
        out_feat,
        rel_names,
        num_bases=None,
        *,
        weight=True,
        bias=True,
        activation=None,
        self_loop=False,
        dropout=0.0
    ):
        super().__init__()
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.rel_names = rel_names
        # Honour the documented default: None means one weight per relation.
        self.num_bases = len(rel_names) if num_bases is None else num_bases
        self.bias = bias
        self.activation = activation
        self.self_loop = self_loop

        # One GraphConv per relation. They carry no weights of their own
        # (weight=False); weights are injected per call via ``mod_kwargs``.
        self.conv = dglnn.HeteroGraphConv(
            {
                rel: dglnn.GraphConv(
                    in_feat, out_feat, norm="right", weight=False, bias=False
                )
                for rel in rel_names
            }
        )

        self.use_weight = weight
        # Basis decomposition is only used with fewer bases than relations.
        self.use_basis = self.num_bases < len(self.rel_names) and weight
        relu_gain = nn.init.calculate_gain("relu")
        if self.use_weight:
            if self.use_basis:
                self.basis = dglnn.WeightBasis(
                    (in_feat, out_feat), self.num_bases, len(self.rel_names)
                )
            else:
                # One (in_feat, out_feat) matrix per relation, in rel_names order.
                self.weight = nn.Parameter(
                    th.empty(len(self.rel_names), in_feat, out_feat)
                )
                nn.init.xavier_uniform_(self.weight, gain=relu_gain)

        # bias
        if bias:
            self.h_bias = nn.Parameter(th.empty(out_feat))
            nn.init.zeros_(self.h_bias)

        # weight for self loop
        if self.self_loop:
            self.loop_weight = nn.Parameter(th.empty(in_feat, out_feat))
            nn.init.xavier_uniform_(self.loop_weight, gain=relu_gain)

        self.dropout = nn.Dropout(dropout)

    def forward(self, g, inputs):
        """Forward computation

        Parameters
        ----------
        g : DGLHeteroGraph
            Input graph (a full heterograph or a sampled block).
        inputs : dict[str, torch.Tensor]
            Node feature for each node type (source-side features for a block).

        Returns
        -------
        dict[str, torch.Tensor]
            New node features for each node type returned by the
            ``HeteroGraphConv`` module.
        """
        g = g.local_var()
        if self.use_weight:
            weight = self.basis() if self.use_basis else self.weight
            # Row i of ``weight`` belongs to self.rel_names[i].
            wdict = {
                rel: {"weight": rel_weight}
                for rel, rel_weight in zip(self.rel_names, weight)
            }
        else:
            wdict = {}

        if g.is_block:
            # The self-loop term needs destination-side features. This slice
            # relies on DGL's block convention that destination nodes come
            # first among the source nodes of each type.
            inputs_dst = {
                k: v[: g.number_of_dst_nodes(k)] for k, v in inputs.items()
            }
        else:
            inputs_dst = inputs

        hs = self.conv(g, inputs, mod_kwargs=wdict)

        def _apply(ntype, h):
            if self.self_loop:
                h = h + th.matmul(inputs_dst[ntype], self.loop_weight)
            if self.bias:
                h = h + self.h_bias
            if self.activation:
                h = self.activation(h)
            return self.dropout(h)

        return {ntype: _apply(ntype, h) for ntype, h in hs.items()}

145

146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
class RelGraphConvLayerHeteroAPI(nn.Module):
    r"""Relational graph convolution layer written with DGL's low-level
    message-passing API (``apply_edges`` / ``update_all``).

    For every relation, source features are copied onto the edges, optionally
    multiplied by that relation's weight matrix, and then summed
    (``fn.sum``) into the destination nodes.

    Parameters
    ----------
    in_feat : int
        Input feature size.
    out_feat : int
        Output feature size.
    rel_names : list[str]
        Relation names.
    num_bases : int, optional
        Number of bases. If None, use the number of relations (no basis
        decomposition). Default: None.
    weight : bool, optional
        True to multiply messages by a per-relation weight matrix.
        Default: True
    bias : bool, optional
        True if bias is added. Default: True
    activation : callable, optional
        Activation function. Default: None
    self_loop : bool, optional
        True to include self loop message. Default: False
    dropout : float, optional
        Dropout rate. Default: 0.0
    """

    def __init__(
        self,
        in_feat,
        out_feat,
        rel_names,
        num_bases=None,
        *,
        weight=True,
        bias=True,
        activation=None,
        self_loop=False,
        dropout=0.0
    ):
        super().__init__()
        self.in_feat = in_feat
        self.out_feat = out_feat
        self.rel_names = rel_names
        # Honour the documented default: None means one weight per relation.
        self.num_bases = len(rel_names) if num_bases is None else num_bases
        self.bias = bias
        self.activation = activation
        self.self_loop = self_loop

        self.use_weight = weight
        # Basis decomposition is only used with fewer bases than relations.
        self.use_basis = self.num_bases < len(self.rel_names) and weight
        relu_gain = nn.init.calculate_gain("relu")
        if self.use_weight:
            if self.use_basis:
                self.basis = dglnn.WeightBasis(
                    (in_feat, out_feat), self.num_bases, len(self.rel_names)
                )
            else:
                # One (in_feat, out_feat) matrix per relation, in rel_names order.
                self.weight = nn.Parameter(
                    th.empty(len(self.rel_names), in_feat, out_feat)
                )
                nn.init.xavier_uniform_(self.weight, gain=relu_gain)

        # bias
        if bias:
            self.h_bias = nn.Parameter(th.empty(out_feat))
            nn.init.zeros_(self.h_bias)

        # weight for self loop
        if self.self_loop:
            self.loop_weight = nn.Parameter(th.empty(in_feat, out_feat))
            nn.init.xavier_uniform_(self.loop_weight, gain=relu_gain)

        self.dropout = nn.Dropout(dropout)

    def forward(self, g, inputs):
        """Forward computation

        Parameters
        ----------
        g : DGLHeteroGraph
            Input graph (a full heterograph or a sampled block).
        inputs : dict[str, torch.Tensor]
            Node feature for each node type (source-side features for a block).

        Returns
        -------
        dict[str, torch.Tensor]
            New node features for each destination node type.
        """
        g = g.local_var()
        if self.use_weight:
            weight = self.basis() if self.use_basis else self.weight
            # Row i of ``weight`` belongs to self.rel_names[i].
            rel_weights = dict(zip(self.rel_names, weight))
        else:
            rel_weights = {}

        if g.is_block:
            # Destination-side features for the self-loop term. This slice
            # relies on DGL's block convention that destination nodes come
            # first among the source nodes of each type.
            inputs_dst = {
                k: v[: g.number_of_dst_nodes(k)] for k, v in inputs.items()
            }
        else:
            inputs_dst = inputs

        # Write into the *source* side explicitly: on a block, ``g.nodes[t]``
        # is ambiguous when t is both a source and a destination type.
        for srctype, _, _ in g.canonical_etypes:
            g.srcnodes[srctype].data["h"] = inputs[srctype]

        # Build per-edge messages relation by relation. Reading edge data
        # through ``g.edges[rel]`` works whether the graph has one or many
        # edge types (``g.edata`` is only a dict in the multi-type case).
        for rel in g.canonical_etypes:
            _, etype, _ = rel
            g.apply_edges(fn.copy_u("h", "m"), etype=rel)
            msg = g.edges[rel].data["m"]
            if self.use_weight:
                msg = th.matmul(msg, rel_weights[etype])
            g.edges[rel].data["h*w_r"] = msg

        g.update_all(fn.copy_e("h*w_r", "m"), fn.sum("m", "h"))

        def _apply(ntype):
            # Aggregated results live on the destination side.
            h = g.dstnodes[ntype].data["h"]
            if self.self_loop:
                h = h + th.matmul(inputs_dst[ntype], self.loop_weight)
            if self.bias:
                h = h + self.h_bias
            if self.activation:
                h = self.activation(h)
            return self.dropout(h)

        return {ntype: _apply(ntype) for ntype in g.dsttypes}

277

278
279
class RelGraphEmbed(nn.Module):
    r"""Learnable embedding tables for the nodes of a featureless heterograph.

    One ``(number_of_nodes(ntype), embed_size)`` parameter is created per node
    type and initialised with Xavier-uniform (ReLU gain).

    Parameters
    ----------
    g : DGLHeteroGraph
        Graph whose ``ntypes`` and per-type node counts size the tables.
    embed_size : int
        Embedding dimension.
    embed_name : str, optional
        Stored on the module; not used by :meth:`forward`. Default: "embed".
    activation : callable, optional
        Stored on the module; not applied by :meth:`forward`. Default: None.
    dropout : float, optional
        Rate of an ``nn.Dropout`` that is created but not applied by
        :meth:`forward`. Default: 0.0.
    """

    def __init__(
        self, g, embed_size, embed_name="embed", activation=None, dropout=0.0
    ):
        super().__init__()
        self.g = g
        self.embed_size = embed_size
        self.embed_name = embed_name
        self.activation = activation
        self.dropout = nn.Dropout(dropout)

        # One learnable embedding matrix per node type.
        self.embeds = nn.ParameterDict()
        relu_gain = nn.init.calculate_gain("relu")
        for ntype in g.ntypes:
            embed = nn.Parameter(
                th.empty(g.number_of_nodes(ntype), self.embed_size)
            )
            nn.init.xavier_uniform_(embed, gain=relu_gain)
            self.embeds[ntype] = embed

    def forward(self, block=None):
        """Return the per-node-type embedding tables.

        Parameters
        ----------
        block : DGLHeteroGraph, optional
            Accepted for interface compatibility but ignored; the full
            embedding tables are always returned.

        Returns
        -------
        nn.ParameterDict
            Mapping from node type to its ``(num_nodes, embed_size)``
            embedding parameter.
        """
        return self.embeds

317

318
class EntityClassify(nn.Module):
    """Multi-layer RGCN over a featureless heterograph.

    Inputs come from a learnable :class:`RelGraphEmbed`. They pass through an
    input layer without relation weights, ``num_hidden_layers`` hidden layers
    and an output layer, all :class:`RelGraphConvLayer`.

    Parameters
    ----------
    g : DGLHeteroGraph
        Full graph used for full-graph training and for sizing the embeddings.
    h_dim : int
        Embedding and hidden feature size.
    out_dim : int
        Output feature size of the last layer.
    num_bases : int or None
        Number of weight bases. None, negative values, or values larger than
        the number of distinct relations fall back to one basis per relation.
    num_hidden_layers : int, optional
        Number of hidden-to-hidden layers. Default: 1.
    dropout : float, optional
        Dropout rate of the input and hidden layers. Default: 0.
    use_self_loop : bool, optional
        Whether every layer adds a self-loop term. Default: False.
    """

    def __init__(
        self,
        g,
        h_dim,
        out_dim,
        num_bases,
        num_hidden_layers=1,
        dropout=0,
        use_self_loop=False,
    ):
        super().__init__()
        self.g = g
        self.h_dim = h_dim
        self.out_dim = out_dim
        self.rel_names = sorted(set(g.etypes))
        if (
            num_bases is None
            or num_bases < 0
            or num_bases > len(self.rel_names)
        ):
            self.num_bases = len(self.rel_names)
        else:
            self.num_bases = num_bases
        self.num_hidden_layers = num_hidden_layers
        self.dropout = dropout
        self.use_self_loop = use_self_loop

        self.embed_layer = RelGraphEmbed(g, self.h_dim)
        self.layers = nn.ModuleList()
        # i2h: no relation weights (weight=False); this layer only aggregates
        # the learned embeddings.
        self.layers.append(
            RelGraphConvLayer(
                self.h_dim,
                self.h_dim,
                self.rel_names,
                self.num_bases,
                activation=F.relu,
                self_loop=self.use_self_loop,
                dropout=self.dropout,
                weight=False,
            )
        )
        # h2h
        for _ in range(self.num_hidden_layers):
            self.layers.append(
                RelGraphConvLayer(
                    self.h_dim,
                    self.h_dim,
                    self.rel_names,
                    self.num_bases,
                    activation=F.relu,
                    self_loop=self.use_self_loop,
                    dropout=self.dropout,
                )
            )
        # h2o: no activation and no dropout on the output layer
        self.layers.append(
            RelGraphConvLayer(
                self.h_dim,
                self.out_dim,
                self.rel_names,
                self.num_bases,
                activation=None,
                self_loop=self.use_self_loop,
            )
        )

    def forward(self, h=None, blocks=None):
        """Run all layers.

        Parameters
        ----------
        h : dict[str, torch.Tensor], optional
            Input features per node type. Defaults to the full learned
            embedding tables. For minibatch training, callers presumably pass
            the features of the first block's input nodes (not verified here).
        blocks : list of DGL blocks, optional
            If given, layer i runs on ``blocks[i]`` (minibatch training).
            Otherwise every layer runs on the full graph ``self.g``.

        Returns
        -------
        dict[str, torch.Tensor]
            Output of the last layer per node type.
        """
        if h is None:
            h = self.embed_layer()
        if blocks is None:
            # full graph training
            for layer in self.layers:
                h = layer(self.g, h)
        else:
            # minibatch training
            for layer, block in zip(self.layers, blocks):
                h = layer(block, h)
        return h

    def inference(self, g, batch_size, device, num_workers, x=None):
        """Minibatch inference of final representation over all node types.

        Layers are applied one at a time to every node, using single-hop
        full-neighbour blocks of ``batch_size`` seed nodes. Each layer's
        output is collected on the CPU and becomes the next layer's input.

        ***NOTE***
        For node classification, the model is trained to predict on only one node type's
        label.  Therefore, only that type's final representation is meaningful.

        Parameters
        ----------
        g : DGLHeteroGraph
            Graph to run inference on.
        batch_size : int
            Number of seed nodes per minibatch.
        device : torch.device or str
            Device that blocks and input features are moved to.
        num_workers : int
            Number of DataLoader workers.
        x : dict[str, torch.Tensor], optional
            Input features per node type. Defaults to the learned embeddings.

        Returns
        -------
        dict[str, torch.Tensor]
            Final-layer output per node type, on the CPU. Node types that
            received no output from the last layer stay zero.
        """
        if x is None:
            x = self.embed_layer()

        sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
        num_layers = len(self.layers)
        y = x
        # Gradients are not needed here. Without no_grad, scattering batch
        # outputs into ``y`` would keep autograd history for the whole pass.
        with th.no_grad():
            for layer_idx, layer in enumerate(self.layers):
                out_size = (
                    self.out_dim if layer_idx == num_layers - 1 else self.h_dim
                )
                y = {
                    k: th.zeros(g.number_of_nodes(k), out_size)
                    for k in g.ntypes
                }

                dataloader = dgl.dataloading.DataLoader(
                    g,
                    {k: th.arange(g.number_of_nodes(k)) for k in g.ntypes},
                    sampler,
                    batch_size=batch_size,
                    # Order is irrelevant for inference.
                    shuffle=False,
                    drop_last=False,
                    num_workers=num_workers,
                )

                for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                    block = blocks[0].to(device)
                    h = {
                        k: x[k][ids].to(device) for k, ids in input_nodes.items()
                    }
                    h = layer(block, h)

                    # A layer may return no entry for a node type that received
                    # no messages; leave those rows at zero.
                    for k, ids in output_nodes.items():
                        if k in h:
                            y[k][ids] = h[k].cpu()

                x = y
        return y

443

444
class EntityClassify_HeteroAPI(nn.Module):
    """Multi-layer RGCN over a featureless heterograph, using
    :class:`RelGraphConvLayerHeteroAPI` layers.

    Inputs come from a learnable :class:`RelGraphEmbed`. They pass through an
    input layer without relation weights, ``num_hidden_layers`` hidden layers
    and an output layer.

    Parameters
    ----------
    g : DGLHeteroGraph
        Full graph used for full-graph training and for sizing the embeddings.
    h_dim : int
        Embedding and hidden feature size.
    out_dim : int
        Output feature size of the last layer.
    num_bases : int or None
        Number of weight bases. None, negative values, or values larger than
        the number of distinct relations fall back to one basis per relation.
    num_hidden_layers : int, optional
        Number of hidden-to-hidden layers. Default: 1.
    dropout : float, optional
        Dropout rate of the input and hidden layers. Default: 0.
    use_self_loop : bool, optional
        Whether every layer adds a self-loop term. Default: False.
    """

    def __init__(
        self,
        g,
        h_dim,
        out_dim,
        num_bases,
        num_hidden_layers=1,
        dropout=0,
        use_self_loop=False,
    ):
        super().__init__()
        self.g = g
        self.h_dim = h_dim
        self.out_dim = out_dim
        self.rel_names = sorted(set(g.etypes))
        if (
            num_bases is None
            or num_bases < 0
            or num_bases > len(self.rel_names)
        ):
            self.num_bases = len(self.rel_names)
        else:
            self.num_bases = num_bases
        self.num_hidden_layers = num_hidden_layers
        self.dropout = dropout
        self.use_self_loop = use_self_loop

        self.embed_layer = RelGraphEmbed(g, self.h_dim)
        self.layers = nn.ModuleList()
        # i2h: no relation weights (weight=False); this layer only aggregates
        # the learned embeddings.
        self.layers.append(
            RelGraphConvLayerHeteroAPI(
                self.h_dim,
                self.h_dim,
                self.rel_names,
                self.num_bases,
                activation=F.relu,
                self_loop=self.use_self_loop,
                dropout=self.dropout,
                weight=False,
            )
        )
        # h2h
        for _ in range(self.num_hidden_layers):
            self.layers.append(
                RelGraphConvLayerHeteroAPI(
                    self.h_dim,
                    self.h_dim,
                    self.rel_names,
                    self.num_bases,
                    activation=F.relu,
                    self_loop=self.use_self_loop,
                    dropout=self.dropout,
                )
            )
        # h2o: no activation and no dropout on the output layer
        self.layers.append(
            RelGraphConvLayerHeteroAPI(
                self.h_dim,
                self.out_dim,
                self.rel_names,
                self.num_bases,
                activation=None,
                self_loop=self.use_self_loop,
            )
        )

    def forward(self, h=None, blocks=None):
        """Run all layers.

        Parameters
        ----------
        h : dict[str, torch.Tensor], optional
            Input features per node type. Defaults to the full learned
            embedding tables. For minibatch training, callers presumably pass
            the features of the first block's input nodes (not verified here).
        blocks : list of DGL blocks, optional
            If given, layer i runs on ``blocks[i]`` (minibatch training).
            Otherwise every layer runs on the full graph ``self.g``.

        Returns
        -------
        dict[str, torch.Tensor]
            Output of the last layer per node type.
        """
        if h is None:
            h = self.embed_layer()
        if blocks is None:
            # full graph training
            for layer in self.layers:
                h = layer(self.g, h)
        else:
            # minibatch training
            for layer, block in zip(self.layers, blocks):
                h = layer(block, h)
        return h

    def inference(self, g, batch_size, device, num_workers, x=None):
        """Minibatch inference of final representation over all node types.

        Layers are applied one at a time to every node, using single-hop
        full-neighbour blocks of ``batch_size`` seed nodes. Each layer's
        output is collected on the CPU and becomes the next layer's input.

        ***NOTE***
        For node classification, the model is trained to predict on only one node type's
        label.  Therefore, only that type's final representation is meaningful.

        Parameters
        ----------
        g : DGLHeteroGraph
            Graph to run inference on.
        batch_size : int
            Number of seed nodes per minibatch.
        device : torch.device or str
            Device that blocks and input features are moved to.
        num_workers : int
            Number of DataLoader workers.
        x : dict[str, torch.Tensor], optional
            Input features per node type. Defaults to the learned embeddings.

        Returns
        -------
        dict[str, torch.Tensor]
            Final-layer output per node type, on the CPU.
        """
        if x is None:
            x = self.embed_layer()

        sampler = dgl.dataloading.MultiLayerFullNeighborSampler(1)
        num_layers = len(self.layers)
        y = x
        # Gradients are not needed here. Without no_grad, scattering batch
        # outputs into ``y`` would keep autograd history for the whole pass.
        with th.no_grad():
            for layer_idx, layer in enumerate(self.layers):
                out_size = (
                    self.out_dim if layer_idx == num_layers - 1 else self.h_dim
                )
                y = {
                    k: th.zeros(g.number_of_nodes(k), out_size)
                    for k in g.ntypes
                }

                dataloader = dgl.dataloading.DataLoader(
                    g,
                    {k: th.arange(g.number_of_nodes(k)) for k in g.ntypes},
                    sampler,
                    batch_size=batch_size,
                    # Order is irrelevant for inference.
                    shuffle=False,
                    drop_last=False,
                    num_workers=num_workers,
                )

                for input_nodes, output_nodes, blocks in tqdm.tqdm(dataloader):
                    block = blocks[0].to(device)
                    h = {
                        k: x[k][ids].to(device) for k, ids in input_nodes.items()
                    }
                    h = layer(block, h)

                    # Scatter only node types present on both sides, so a
                    # missing entry on either side cannot raise KeyError.
                    for k, ids in output_nodes.items():
                        if k in h:
                            y[k][ids] = h[k].cpu()

                x = y
        return y