modules.py 18.6 KB
Newer Older
1
2
3
4
5
6
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
7
import dgl.function as fn
8
9
10
11
from dgl.base import DGLError
from dgl.ops import edge_softmax


12
13
14
15
class Identity(nn.Module):
    """A placeholder identity operator that is argument-insensitive.

    ``torch.nn.Identity`` has provided the same behaviour since PyTorch 1.2;
    this local copy is kept so the file does not depend on that version and
    can be swapped for ``torch.nn.Identity`` later.

    Any positional or keyword arguments given to the constructor are
    accepted and ignored, so the class can stand in for another layer
    without changing the call site.
    """

    def __init__(self, *args, **kwargs):
        # Constructor arguments are deliberately discarded: the module has
        # no state and nothing to configure.
        super(Identity, self).__init__()

    def forward(self, x):
        """Return the input object unchanged."""
        return x
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62


class MsgLinkPredictor(nn.Module):
    """Predict pair-wise links on a positive and a negative graph.

    A two layer MLP is evaluated on every edge: the source and destination
    embeddings are projected separately, summed, passed through ReLU and
    mapped to a single score.

    Parameters
    ----------
    emb_dim : int
        dimension of each node embedding

    Example
    ----------
    >>> linkpred = MsgLinkPredictor(10)
    >>> pos_g = dgl.graph(([0,1,2,3,4],[1,2,3,4,0]))
    >>> neg_g = dgl.graph(([0,1,2,3,4],[2,1,4,3,0]))
    >>> x = torch.ones(5,10)
    >>> linkpred(x,pos_g,neg_g)
    (tensor([[0.0902],
         [0.0902],
         [0.0902],
         [0.0902],
         [0.0902]], grad_fn=<AddmmBackward>),
    tensor([[0.0902],
         [0.0902],
         [0.0902],
         [0.0902],
         [0.0902]], grad_fn=<AddmmBackward>))
    """

    def __init__(self, emb_dim):
        super(MsgLinkPredictor, self).__init__()
        self.src_fc = nn.Linear(emb_dim, emb_dim)
        self.dst_fc = nn.Linear(emb_dim, emb_dim)
        self.out_fc = nn.Linear(emb_dim, 1)

    def link_pred(self, edges):
        """Edge UDF: score every edge from its endpoint embeddings.

        Reads ``edges.src["embedding"]`` and ``edges.dst["embedding"]`` and
        returns ``{"score": ...}`` with one output column per edge.
        """
        src_hid = self.src_fc(edges.src["embedding"])
        dst_hid = self.dst_fc(edges.dst["embedding"])
        hidden = F.relu(src_hid + dst_hid)
        return {"score": self.out_fc(hidden)}

    def forward(self, x, pos_g, neg_g):
        """Score the edges of ``pos_g`` and ``neg_g`` using embeddings ``x``.

        Parameters
        ----------
        x : torch.Tensor
            node embeddings, attached to both graphs as ``ndata["embedding"]``
        pos_g, neg_g : graph
            graphs whose edges are scored; both receive the same ``x``

        Returns
        -------
        (pos_escore, neg_escore)
            the ``edata["score"]`` tensors of ``pos_g`` and ``neg_g``
        """
        # NOTE: no local scope is used, so both input graphs keep the
        # "embedding" node field and the "score" edge field after the call.
        pos_g.ndata["embedding"] = x
        neg_g.ndata["embedding"] = x

        pos_g.apply_edges(self.link_pred)
        neg_g.apply_edges(self.link_pred)

        pos_escore = pos_g.edata["score"]
        neg_escore = neg_g.edata["score"]
        return pos_escore, neg_escore


class TimeEncode(nn.Module):
    r"""Use a finite fourier series with different phases and frequencies to
    encode the time difference between two events.

    .. math::
        \Phi(t) = [\cos(\omega_0t+\psi_0),\cos(\omega_1t+\psi_1),...,\cos(\omega_nt+\psi_n)]

    Parameters
    ----------
    dimension : int
        Length of the fourier series. The longer it is,
        the more timescale information it can capture

    Example
    ----------
    >>> tecd = TimeEncode(10)
    >>> t = torch.tensor([[1]])
    >>> tecd(t)
    tensor([[[0.5403, 0.9950, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000, 1.0000,
          1.0000, 1.0000]]], dtype=torch.float64, grad_fn=<CosBackward>)
    """

    def __init__(self, dimension):
        super(TimeEncode, self).__init__()

        self.dimension = dimension
        self.w = torch.nn.Linear(1, dimension)
        # Initial frequencies are 10**0 ... 10**-9, spaced evenly in the
        # exponent; phases (the bias) start at zero. Both are stored in
        # double precision.
        frequencies = torch.from_numpy(1 / 10 ** np.linspace(0, 9, dimension))
        self.w.weight = torch.nn.Parameter(
            frequencies.double().reshape(dimension, -1)
        )
        self.w.bias = torch.nn.Parameter(torch.zeros(dimension).double())

    def forward(self, t):
        """Encode ``t``; a trailing axis of size ``dimension`` is appended.

        ``t`` is unsqueezed at dim 2, so it must have at least two axes.
        """
        # The linear layer holds double weights; cast the input to the
        # weight dtype so integer or float32 timestamps do not raise a
        # dtype-mismatch error. Inputs already in that dtype are untouched.
        t = t.unsqueeze(dim=2).to(self.w.weight.dtype)
        return torch.cos(self.w(t))


class MemoryModule(nn.Module):
    """Memory module as well as update interface

    The memory module stores, for every node, both its historical
    representation (``memory``) and the time of its last update
    (``last_update_t``). Both are held as non-trainable parameters.

    Parameters
    ----------
    n_node : int
        number of node of the entire graph

    hidden_dim : int
        dimension of memory of each node

    Example
    ----------
    Please refers to examples/pytorch/tgn/tgn.py;
                     examples/pytorch/tgn/train.py

    """

    def __init__(self, n_node, hidden_dim):
        super(MemoryModule, self).__init__()
        self.n_node = n_node
        self.hidden_dim = hidden_dim
        self.reset_memory()

    def reset_memory(self):
        """Zero the memory and the last-update timestamps."""
        # On the first call (from __init__) no memory exists yet and the
        # default device is used; afterwards the new storage is created on
        # the device of the current memory so that a previous
        # ``module.to(device)`` is not silently undone.
        current = getattr(self, "memory", None)
        device = current.device if current is not None else None
        self.last_update_t = nn.Parameter(
            torch.zeros(self.n_node, dtype=torch.float32, device=device),
            requires_grad=False,
        )
        self.memory = nn.Parameter(
            torch.zeros(
                (self.n_node, self.hidden_dim),
                dtype=torch.float32,
                device=device,
            ),
            requires_grad=False,
        )

    def backup_memory(self):
        """
        Return a deep copy of memory state and last_update_t
        For test new node, since new node need to use memory upto validation set
        After validation, memory need to be backed up before run test set without new node
        so finally, we can use backup memory to update the new node test set
        """
        return self.memory.clone(), self.last_update_t.clone()

    def restore_memory(self, memory_backup):
        """Restore the memory from validation set

        Parameters
        ----------
        memory_backup : (memory,last_update_t)
            restore memory based on input tuple
        """
        # ``memory`` and ``last_update_t`` are registered parameters, and
        # nn.Module refuses to assign a plain tensor to a parameter name
        # (TypeError). The copies are therefore wrapped as parameters again;
        # ``detach`` drops any autograd history carried by the backup.
        self.memory = nn.Parameter(
            memory_backup[0].detach().clone(), requires_grad=False
        )
        self.last_update_t = nn.Parameter(
            memory_backup[1].detach().clone(), requires_grad=False
        )

    # Which is used for attach to subgraph
    def get_memory(self, node_idxs):
        """Return the memory rows selected by ``node_idxs``."""
        return self.memory[node_idxs, :]

    # When the memory need to be updated
    def set_memory(self, node_idxs, values):
        """Overwrite, in place, the memory rows selected by ``node_idxs``."""
        self.memory[node_idxs, :] = values

    def set_last_update_t(self, node_idxs, values):
        """Overwrite, in place, the last-update times of ``node_idxs``."""
        self.last_update_t[node_idxs] = values

    # For safety check
    def get_last_update(self, node_idxs):
        """Return the last-update times of ``node_idxs``."""
        return self.last_update_t[node_idxs]

    def detach_memory(self):
        """
        Disconnect the memory from computation graph to prevent gradient be propagated multiple
        times
        """
        self.memory.detach_()


class MemoryOperation(nn.Module):
    r"""Memory update using message passing manner, update memory based on positive
    pair graph of each batch with recurrent module GRU or RNN

    Message function (one message per edge, sent from source ``i`` to
    destination ``j``)

    .. math::
        m_{ij}(t) = concat(memory_i(t^-), memory_j(t^-), e_{ij}(t), TimeEncode(t - t_i^-))

    where :math:`e_{ij}` is the edge feature and :math:`t_i^-` is the last
    update time of the source node.

    Aggregation function

    .. math::
        \bar{m}_i(t) = last(m_i(t_1),...,m_i(t_b))

    Update function

    .. math::
        memory_i(t) = GRU(\bar{m}_i(t),memory_i(t-1))

    Parameters
    ----------

    updater_type : str
        indicator string to specify updater

        'rnn' : use Vanilla RNN as updater

        'gru' : use GRU as updater

    memory : MemoryModule
        memory content for update

    e_feat_dim : int
        dimension of edge feature

    temporal_encoder : torch.nn.Module
        time encoder; its ``dimension`` attribute gives the length of the
        time encoding appended to every message

    Example
    ----------
    Please refers to examples/pytorch/tgn/tgn.py
    """

    def __init__(self, updater_type, memory, e_feat_dim, temporal_encoder):
        super(MemoryOperation, self).__init__()
        updater_dict = {"gru": nn.GRUCell, "rnn": nn.RNNCell}
        self.memory = memory
        memory_dim = self.memory.hidden_dim
        self.temporal_encoder = temporal_encoder
        # A message is [src memory | dst memory | edge feature | time code].
        self.message_dim = (
            memory_dim
            + memory_dim
            + e_feat_dim
            + self.temporal_encoder.dimension
        )
        # An unknown ``updater_type`` raises KeyError here.
        self.updater = updater_dict[updater_type](
            input_size=self.message_dim, hidden_size=memory_dim
        )

    # Here assume g is a subgraph from each iteration
    def stick_feat_to_graph(self, g):
        """Attach last-update time and memory of the original nodes to ``g``.

        Rows are looked up through ``g.ndata[dgl.NID]``.
        """
        # How can I ensure order of the node ID
        g.ndata["timestamp"] = self.memory.last_update_t[g.ndata[dgl.NID]]
        g.ndata["memory"] = self.memory.memory[g.ndata[dgl.NID]]

    def msg_fn_cat(self, edges):
        """Message UDF: concatenate endpoint memories, edge feature and the
        encoding of (edge timestamp - source last-update time)."""
        src_delta_time = edges.data["timestamp"] - edges.src["timestamp"]
        time_encode = self.temporal_encoder(
            src_delta_time.unsqueeze(dim=1)
        ).view(len(edges.data["timestamp"]), -1)
        ret = torch.cat(
            [
                edges.src["memory"],
                edges.dst["memory"],
                edges.data["feats"],
                time_encode,
            ],
            dim=1,
        )
        return {"message": ret, "timestamp": edges.data["timestamp"]}

    def agg_last(self, nodes):
        """Reduce UDF: keep, for every node, the message with the largest
        timestamp in its mailbox."""
        timestamp, latest_idx = torch.max(nodes.mailbox["timestamp"], dim=1)
        messages = nodes.mailbox["message"]
        # Pick row ``latest_idx[i]`` of node i's mailbox as a whole. The
        # previous gather index was built by tiling ``latest_idx`` and
        # reshaping, which mixes the indices of different nodes across the
        # feature axis whenever they are not all equal.
        rows = torch.arange(messages.shape[0], device=messages.device)
        latest = messages[rows, latest_idx]
        return {
            "message_bar": latest.reshape(-1, self.message_dim),
            "timestamp": timestamp,
        }

    def update_memory(self, nodes):
        """Apply UDF: feed the aggregated message and the current memory
        through the recurrent updater (both cast to float32)."""
        ret = self.updater(
            nodes.data["message_bar"].float(), nodes.data["memory"].float()
        )
        return {"memory": ret}

    def forward(self, g):
        """Run one round of memory update on ``g`` and return ``g``.

        ``g`` is modified: "timestamp" and "memory" node fields are written
        before message passing and overwritten by its results.
        """
        self.stick_feat_to_graph(g)
        g.update_all(self.msg_fn_cat, self.agg_last, self.update_memory)
        return g


311
class EdgeGATConv(nn.Module):
    """Edge Graph attention compute the graph attention from node and edge feature then aggregate both node and
    edge feature.

    Parameters
    ----------
    node_feats : int
        number of node features

    edge_feats : int
        number of edge features

    out_feats : int
        number of output features

    num_heads : int
        number of heads in multihead attention

    feat_drop : float, optional
        drop out rate on the feature

    attn_drop : float, optional
        drop out rate on the attention weight

    negative_slope : float, optional
        LeakyReLU angle of negative slope.

    residual : bool, optional
        whether use residual connection

    activation : callable, optional
        if set, applied to the output before it is returned

    allow_zero_in_degree : bool, optional
        If there are 0-in-degree nodes in the graph, output for those nodes will be invalid
        since no message will be passed to those nodes. This is harmful for some applications
        causing silent performance regression. This module will raise a DGLError if it detects
        0-in-degree nodes in input graph. By setting ``True``, it will suppress the check
        and let the users handle it by themselves. Defaults: ``False``.
    """

    def __init__(
        self,
        node_feats,
        edge_feats,
        out_feats,
        num_heads,
        feat_drop=0.0,
        attn_drop=0.0,
        negative_slope=0.2,
        residual=False,
        activation=None,
        allow_zero_in_degree=False,
    ):
        super(EdgeGATConv, self).__init__()
        self._num_heads = num_heads
        self._node_feats = node_feats
        self._edge_feats = edge_feats
        self._out_feats = out_feats
        self._allow_zero_in_degree = allow_zero_in_degree

        # Separate projections for node and edge inputs, each producing
        # ``num_heads`` blocks of ``out_feats`` values.
        self.fc_node = nn.Linear(
            self._node_feats, self._out_feats * self._num_heads
        )
        self.fc_edge = nn.Linear(
            self._edge_feats, self._out_feats * self._num_heads
        )

        # Attention vectors for source node, destination node and edge.
        # Allocated uninitialised; reset_parameters() fills them.
        attn_shape = (1, self._num_heads, self._out_feats)
        self.attn_l = nn.Parameter(torch.empty(attn_shape, dtype=torch.float32))
        self.attn_r = nn.Parameter(torch.empty(attn_shape, dtype=torch.float32))
        self.attn_e = nn.Parameter(torch.empty(attn_shape, dtype=torch.float32))

        self.feat_drop = nn.Dropout(feat_drop)
        self.attn_drop = nn.Dropout(attn_drop)
        self.leaky_relu = nn.LeakyReLU(negative_slope)
        self.residual = residual
        if residual:
            if self._node_feats != self._out_feats:
                self.res_fc = nn.Linear(
                    self._node_feats,
                    self._out_feats * self._num_heads,
                    bias=False,
                )
            else:
                # Same width: the input itself is the residual and is
                # broadcast over the head axis in forward().
                self.res_fc = Identity()
        self.reset_parameters()
        self.activation = activation

    def reset_parameters(self):
        """Re-initialise all learnable weights with Xavier normal (ReLU gain).

        Biases of the linear layers are left as they are.
        """
        gain = nn.init.calculate_gain("relu")
        nn.init.xavier_normal_(self.fc_node.weight, gain=gain)
        nn.init.xavier_normal_(self.fc_edge.weight, gain=gain)
        nn.init.xavier_normal_(self.attn_l, gain=gain)
        nn.init.xavier_normal_(self.attn_r, gain=gain)
        nn.init.xavier_normal_(self.attn_e, gain=gain)
        if self.residual and isinstance(self.res_fc, nn.Linear):
            nn.init.xavier_normal_(self.res_fc.weight, gain=gain)

    def msg_fn(self, edges):
        """Message UDF: attention weight "a" times the edge field "el_prime"."""
        # NOTE(review): "el_prime" is the (source + edge) attention logit,
        # not a feature vector, and the "efeat" edge field written in
        # forward() is never read. Kept as is because changing it would
        # alter model outputs; confirm whether this is intended.
        ret = (
            edges.data["a"].view(-1, self._num_heads, 1)
            * edges.data["el_prime"]
        )
        return {"m": ret}

    def forward(self, graph, nfeat, efeat, get_attention=False):
        """Compute edge-aware graph attention on ``graph``.

        Parameters
        ----------
        graph : DGLGraph
            input graph; it is only modified inside a local scope
        nfeat : torch.Tensor
            node features, one row per node
        efeat : torch.Tensor
            edge features, one row per edge
        get_attention : bool, optional
            also return the attention values stored in ``edata["a"]``

        Returns
        -------
        torch.Tensor or (torch.Tensor, torch.Tensor)
            the per-node result, plus the attention when ``get_attention``
            is set

        Raises
        ------
        DGLError
            if the graph has 0-in-degree nodes and ``allow_zero_in_degree``
            is ``False``
        """
        with graph.local_scope():
            if not self._allow_zero_in_degree:
                if (graph.in_degrees() == 0).any():
                    raise DGLError(
                        "There are 0-in-degree nodes in the graph, "
                        "output for those nodes will be invalid. "
                        "This is harmful for some applications, "
                        "causing silent performance regression. "
                        "Adding self-loop on the input graph by "
                        "calling `g = dgl.add_self_loop(g)` will resolve "
                        "the issue. Setting ``allow_zero_in_degree`` "
                        "to be `True` when constructing this module will "
                        "suppress the check and let the code run."
                    )

            nfeat = self.feat_drop(nfeat)
            efeat = self.feat_drop(efeat)

            node_feat = self.fc_node(nfeat).view(
                -1, self._num_heads, self._out_feats
            )
            edge_feat = self.fc_edge(efeat).view(
                -1, self._num_heads, self._out_feats
            )

            # Per-head scalar attention terms, each with a trailing axis of 1.
            el = (node_feat * self.attn_l).sum(dim=-1).unsqueeze(-1)
            er = (node_feat * self.attn_r).sum(dim=-1).unsqueeze(-1)
            ee = (edge_feat * self.attn_e).sum(dim=-1).unsqueeze(-1)
            graph.ndata["ft"] = node_feat
            graph.ndata["el"] = el
            graph.ndata["er"] = er
            graph.edata["ee"] = ee
            # logit = el(src) + ee(edge) + er(dst)
            graph.apply_edges(fn.u_add_e("el", "ee", "el_prime"))
            graph.apply_edges(fn.e_add_v("el_prime", "er", "e"))
            e = self.leaky_relu(graph.edata["e"])
            graph.edata["a"] = self.attn_drop(edge_softmax(graph, e))
            graph.edata["efeat"] = edge_feat
            graph.update_all(self.msg_fn, fn.sum("m", "ft"))
            rst = graph.ndata["ft"]

            if self.residual:
                # Residual is reshaped to (nodes, -1, out_feats) and added
                # with broadcasting.
                resval = self.res_fc(nfeat).view(
                    nfeat.shape[0], -1, self._out_feats
                )
                rst = rst + resval

            if self.activation:
                rst = self.activation(rst)

            if get_attention:
                return rst, graph.edata["a"]
            else:
                return rst


class TemporalEdgePreprocess(nn.Module):
    """Preprocess layer, which finish time encoding and concatenate
    the time encoding to edge feature.

    Parameters
    ----------
    edge_feats : int
        number of original edge feature

    temporal_encoder : torch.nn.Module
        time encoder model
    """

    def __init__(self, edge_feats, temporal_encoder):
        super(TemporalEdgePreprocess, self).__init__()
        self.edge_feats = edge_feats
        self.temporal_encoder = temporal_encoder

    def edge_fn(self, edges):
        """Edge UDF: append the encoding of (edge timestamp - source
        timestamp) to the "feats" edge field and return it as "efeat"."""
        time_diff = edges.data["timestamp"] - edges.src["timestamp"]
        # One row per edge; the encoder output is flattened per edge. The
        # edge count is taken from ``time_diff`` directly instead of
        # allocating a zero tensor just to read its length.
        time_encode = self.temporal_encoder(time_diff.unsqueeze(dim=1)).view(
            time_diff.shape[0], -1
        )
        edge_feat = torch.cat([edges.data["feats"], time_encode], dim=1)
        return {"efeat": edge_feat}

    def forward(self, graph):
        """Compute the time-augmented edge feature of ``graph``.

        The result is also left on the graph as ``edata["efeat"]``.
        """
        graph.apply_edges(self.edge_fn)
        efeat = graph.edata["efeat"]
        return efeat


class TemporalTransformerConv(nn.Module):
    """Temporal Transformer model for TGN and TGAT

    Edge features are first extended with a time encoding
    (``TemporalEdgePreprocess``) and then passed, together with the node
    memory, through a stack of ``EdgeGATConv`` layers.

    Parameters
    ----------
    edge_feats : int
        number of edge features

    memory_feats : int
        dimension of memory vector

    temporal_encoder : torch.nn.Module
        compute fourier time encoding

    out_feats : int
        number of out features

    num_heads : int
        number of attention head

    allow_zero_in_degree : bool, optional
        If there are 0-in-degree nodes in the graph, output for those nodes will be invalid
        since no message will be passed to those nodes. This is harmful for some applications
        causing silent performance regression. This module will raise a DGLError if it detects
        0-in-degree nodes in input graph. By setting ``True``, it will suppress the check
        and let the users handle it by themselves. Defaults: ``False``.

    layers : int, optional
        number of attention layers; at least one layer is always built.
        Defaults: ``1``.

    feat_drop : float, optional
        feature drop out rate of every attention layer. Defaults: ``0.6``.

    attn_drop : float, optional
        attention drop out rate of every attention layer. Defaults: ``0.6``.
    """

    def __init__(
        self,
        edge_feats,
        memory_feats,
        temporal_encoder,
        out_feats,
        num_heads,
        allow_zero_in_degree=False,
        layers=1,
        feat_drop=0.6,
        attn_drop=0.6,
    ):
        super(TemporalTransformerConv, self).__init__()
        self._edge_feats = edge_feats
        self._memory_feats = memory_feats
        self.temporal_encoder = temporal_encoder
        self._out_feats = out_feats
        self._allow_zero_in_degree = allow_zero_in_degree
        self._num_heads = num_heads
        self.layers = layers

        self.preprocessor = TemporalEdgePreprocess(
            self._edge_feats, self.temporal_encoder
        )

        # The first layer consumes the memory vectors; every further layer
        # consumes the flattened multi-head output of the previous one.
        node_in_feats = [self._memory_feats] + [
            self._out_feats * self._num_heads
        ] * (self.layers - 1)
        self.layer_list = nn.ModuleList()
        for in_feats in node_in_feats:
            self.layer_list.append(
                EdgeGATConv(
                    node_feats=in_feats,
                    edge_feats=self._edge_feats
                    + self.temporal_encoder.dimension,
                    out_feats=self._out_feats,
                    num_heads=self._num_heads,
                    feat_drop=feat_drop,
                    attn_drop=attn_drop,
                    residual=True,
                    allow_zero_in_degree=allow_zero_in_degree,
                )
            )

    def forward(self, graph, memory, ts):
        """Embed the nodes of ``graph`` from their memory and timestamps.

        Parameters
        ----------
        graph : DGLGraph
            input graph; a local copy is used, the input is not modified
        memory : torch.Tensor
            node input of the first attention layer
        ts : torch.Tensor
            node timestamps, stored as ``ndata["timestamp"]``

        Returns
        -------
        torch.Tensor
            output of the last layer averaged over the head axis
        """
        graph = graph.local_var()
        graph.ndata["timestamp"] = ts
        efeat = self.preprocessor(graph).float()
        rst = memory
        # Intermediate layers keep all heads (flattened); the last layer's
        # heads are averaged.
        for i in range(self.layers - 1):
            rst = self.layer_list[i](graph, rst, efeat).flatten(1)
        rst = self.layer_list[-1](graph, rst, efeat).mean(1)
        return rst