"vscode:/vscode.git/clone" did not exist on "e6bae8d916d976d227489fe0daa57586472c6c4e"
tgn.py 2.99 KB
Newer Older
1
import copy
2

3
import torch.nn as nn
4
5
6
7
8
9
10
11
from modules import (
    MemoryModule,
    MemoryOperation,
    MsgLinkPredictor,
    TemporalTransformerConv,
    TimeEncode,
)

12
import dgl
13

14
15

class TGN(nn.Module):
    """Temporal Graph Network for link prediction on dynamic graphs.

    Composes four sub-modules (all imported from ``modules``):
    a per-node :class:`MemoryModule`, a :class:`MemoryOperation` that
    updates memories from observed interactions, a
    :class:`TemporalTransformerConv` that turns memories into temporal
    embeddings, and a :class:`MsgLinkPredictor` that scores candidate
    edges from those embeddings.

    Parameters
    ----------
    edge_feat_dim : int
        Dimension of raw edge (interaction) features.
    memory_dim : int
        Dimension of each node's memory vector.
    temporal_dim : int
        Dimension of the time encoding produced by ``TimeEncode``.
    embedding_dim : int
        Output dimension of the temporal embedding / link predictor input.
    num_heads : int
        Number of attention heads in the transformer convolution.
    num_nodes : int
        Total number of nodes whose memories are tracked.
    n_neighbors : int, optional
        Neighbor budget kept for sampling (default 10; stored only —
        not consumed inside this class as visible here).
    memory_updater_type : str, optional
        Recurrent cell used by ``MemoryOperation`` (default ``"gru"``).
    layers : int, optional
        Number of transformer convolution layers (default 1).
    """

    def __init__(
        self,
        edge_feat_dim,
        memory_dim,
        temporal_dim,
        embedding_dim,
        num_heads,
        num_nodes,
        n_neighbors=10,
        memory_updater_type="gru",
        layers=1,
    ):
        super().__init__()
        self.memory_dim = memory_dim
        self.edge_feat_dim = edge_feat_dim
        self.temporal_dim = temporal_dim
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.n_neighbors = n_neighbors
        self.memory_updater_type = memory_updater_type
        self.num_nodes = num_nodes
        self.layers = layers

        # Shared time encoder: used both when updating memories and when
        # computing attention over temporal neighborhoods.
        self.temporal_encoder = TimeEncode(self.temporal_dim)

        self.memory = MemoryModule(self.num_nodes, self.memory_dim)

        self.memory_ops = MemoryOperation(
            self.memory_updater_type,
            self.memory,
            self.edge_feat_dim,
            self.temporal_encoder,
        )

        self.embedding_attn = TemporalTransformerConv(
            self.edge_feat_dim,
            self.memory_dim,
            self.temporal_encoder,
            self.embedding_dim,
            self.num_heads,
            layers=self.layers,
            allow_zero_in_degree=True,
        )

        self.msg_linkpredictor = MsgLinkPredictor(embedding_dim)

    def embed(self, postive_graph, negative_graph, blocks):
        """Score positive and negative candidate edges.

        Parameters
        ----------
        postive_graph : dgl.DGLGraph
            Graph holding the observed (positive) edges to score.
        negative_graph : dgl.DGLGraph
            Graph holding sampled negative edges to score.
        blocks : list of dgl.DGLGraph
            Sampled message-passing blocks; only ``blocks[0]`` is used
            as the embedding graph here.

        Returns
        -------
        tuple
            ``(pred_pos, pred_neg)`` scores from the link predictor.
        """
        emb_graph = blocks[0]
        # Gather the current memory rows for the nodes present in the block.
        emb_memory = self.memory.memory[emb_graph.ndata[dgl.NID], :]
        emb_t = emb_graph.ndata["timestamp"]
        embedding = self.embedding_attn(emb_graph, emb_memory, emb_t)
        # Map original (global) node IDs back to local row indices in
        # `embedding` so we can look up features for the prediction graphs.
        emb2pred = dict(
            zip(emb_graph.ndata[dgl.NID].tolist(), emb_graph.nodes().tolist())
        )
        # The positive and negative graphs share the same ID mapping,
        # so one lookup built from the positive graph serves both.
        feat_id = [emb2pred[int(n)] for n in postive_graph.ndata[dgl.NID]]
        feat = embedding[feat_id]
        pred_pos, pred_neg = self.msg_linkpredictor(
            feat, postive_graph, negative_graph
        )
        return pred_pos, pred_neg

    def update_memory(self, subg):
        """Run the memory updater on ``subg`` and write back the results.

        Updates both the memory vectors and the per-node last-update
        timestamps for the nodes appearing in ``subg``.
        """
        new_g = self.memory_ops(subg)
        self.memory.set_memory(new_g.ndata[dgl.NID], new_g.ndata["memory"])
        self.memory.set_last_update_t(
            new_g.ndata[dgl.NID], new_g.ndata["timestamp"]
        )

    # Some memory operation wrappers
    def detach_memory(self):
        """Detach memory from the autograd graph (truncates BPTT)."""
        self.memory.detach_memory()

    def reset_memory(self):
        """Reset all node memories to their initial state."""
        self.memory.reset_memory()

    def store_memory(self):
        """Snapshot the current memory state.

        Returns
        -------
        dict
            Deep copies of the memory tensor (``"memory"``) and the
            last-update timestamps (``"last_t"``).
        """
        memory_checkpoint = {}
        memory_checkpoint["memory"] = copy.deepcopy(self.memory.memory)
        memory_checkpoint["last_t"] = copy.deepcopy(self.memory.last_update_t)
        return memory_checkpoint

    def restore_memory(self, memory_checkpoint):
        """Restore memory state from a ``store_memory`` checkpoint.

        Parameters
        ----------
        memory_checkpoint : dict
            Checkpoint produced by :meth:`store_memory`.
        """
        self.memory.memory = memory_checkpoint["memory"]
        # BUG FIX: previously assigned `last_update_time`, a dead attribute —
        # `store_memory` and `set_last_update_t` both use `last_update_t`,
        # so the timestamps were silently never restored.
        self.memory.last_update_t = memory_checkpoint["last_t"]