import copy

import dgl
import torch.nn as nn

from modules import (
    MemoryModule,
    MemoryOperation,
    MsgLinkPredictor,
    TemporalTransformerConv,
    TimeEncode,
)


class TGN(nn.Module):
    """TGN model: node memory, memory updater, temporal attention, link predictor.

    The model wires together five components imported from ``modules``:

    * ``TimeEncode`` -- time encoder shared by the memory updater and the
      embedding layer.
    * ``MemoryModule`` -- per-node memory storage (``memory`` and
      ``last_update_t``).
    * ``MemoryOperation`` -- computes new memory for the nodes of a subgraph.
    * ``TemporalTransformerConv`` -- attention layer that turns memory and
      timestamps into node embeddings.
    * ``MsgLinkPredictor`` -- scores edges of a positive and a negative graph.

    Args:
        edge_feat_dim: Edge feature size; forwarded to ``MemoryOperation`` and
            ``TemporalTransformerConv``.
        memory_dim: Memory size; forwarded to ``MemoryModule`` and
            ``TemporalTransformerConv``.
        temporal_dim: Forwarded to ``TimeEncode``.
        embedding_dim: Forwarded to ``TemporalTransformerConv`` and
            ``MsgLinkPredictor``.
        num_heads: Forwarded to ``TemporalTransformerConv``.
        num_nodes: Forwarded to ``MemoryModule``.
        n_neighbors: Stored on the instance only; not used by any method of
            this class.
        memory_updater_type: Forwarded to ``MemoryOperation`` (default
            ``"gru"``).
        layers: Forwarded to ``TemporalTransformerConv`` as ``layers``.
    """

    def __init__(
        self,
        edge_feat_dim,
        memory_dim,
        temporal_dim,
        embedding_dim,
        num_heads,
        num_nodes,
        n_neighbors=10,
        memory_updater_type="gru",
        layers=1,
    ):
        super().__init__()
        # Keep the raw hyper-parameters around as plain attributes.
        self.memory_dim = memory_dim
        self.edge_feat_dim = edge_feat_dim
        self.temporal_dim = temporal_dim
        self.embedding_dim = embedding_dim
        self.num_heads = num_heads
        self.n_neighbors = n_neighbors
        self.memory_updater_type = memory_updater_type
        self.num_nodes = num_nodes
        self.layers = layers

        # One time encoder instance is shared by the memory updater and the
        # embedding layer below.
        self.temporal_encoder = TimeEncode(self.temporal_dim)

        self.memory = MemoryModule(self.num_nodes, self.memory_dim)

        self.memory_ops = MemoryOperation(
            self.memory_updater_type,
            self.memory,
            self.edge_feat_dim,
            self.temporal_encoder,
        )

        self.embedding_attn = TemporalTransformerConv(
            self.edge_feat_dim,
            self.memory_dim,
            self.temporal_encoder,
            self.embedding_dim,
            self.num_heads,
            layers=self.layers,
            allow_zero_in_degree=True,
        )

        self.msg_linkpredictor = MsgLinkPredictor(embedding_dim)

    def embed(self, postive_graph, negative_graph, blocks):
        """Compute link-prediction outputs for a positive/negative graph pair.

        Args:
            postive_graph: Graph whose ``ndata[dgl.NID]`` selects which node
                embeddings are fed to the predictor.
            negative_graph: Passed through to the link predictor unchanged.
            blocks: Indexable container; only ``blocks[0]`` is used. It must
                carry ``dgl.NID`` and ``"timestamp"`` node data.

        Returns:
            The ``(pred_pos, pred_neg)`` pair returned by the link predictor.

        Raises:
            KeyError: If ``postive_graph`` contains a node ID that is not
                present in ``blocks[0]``.
        """
        emb_graph = blocks[0]
        emb_nids = emb_graph.ndata[dgl.NID]

        # Look up the stored memory rows for the nodes of the embedding graph.
        emb_memory = self.memory.memory[emb_nids, :]
        emb_t = emb_graph.ndata["timestamp"]
        embedding = self.embedding_attn(emb_graph, emb_memory, emb_t)

        # Map each dgl.NID value of the embedding graph to that node's local
        # index, i.e. its row in `embedding`.
        emb2pred = dict(zip(emb_nids.tolist(), emb_graph.nodes().tolist()))

        # The positive and negative graphs share the same ID mapping, so the
        # positive graph's node IDs are enough to select the feature rows.
        feat_id = [emb2pred[nid] for nid in postive_graph.ndata[dgl.NID].tolist()]
        feat = embedding[feat_id]

        pred_pos, pred_neg = self.msg_linkpredictor(
            feat, postive_graph, negative_graph
        )
        return pred_pos, pred_neg

    def update_memory(self, subg):
        """Run the memory updater on ``subg`` and write the results back.

        The graph returned by ``self.memory_ops(subg)`` supplies, per node,
        the ``dgl.NID`` to write to, the new ``"memory"`` value, and the new
        ``"timestamp"`` stored as the last-update time.
        """
        new_g = self.memory_ops(subg)
        node_ids = new_g.ndata[dgl.NID]
        self.memory.set_memory(node_ids, new_g.ndata["memory"])
        self.memory.set_last_update_t(node_ids, new_g.ndata["timestamp"])

    # Some memory operation wrappers
    def detach_memory(self):
        """Delegate to ``MemoryModule.detach_memory``."""
        self.memory.detach_memory()

    def reset_memory(self):
        """Delegate to ``MemoryModule.reset_memory``."""
        self.memory.reset_memory()

    def store_memory(self):
        """Return a deep-copied snapshot of the memory state.

        Returns:
            dict with keys ``"memory"`` (copy of ``self.memory.memory``) and
            ``"last_t"`` (copy of ``self.memory.last_update_t``).
        """
        return {
            "memory": copy.deepcopy(self.memory.memory),
            "last_t": copy.deepcopy(self.memory.last_update_t),
        }

    def restore_memory(self, memory_checkpoint):
        """Restore memory state from a snapshot made by ``store_memory``.

        Args:
            memory_checkpoint: dict with keys ``"memory"`` and ``"last_t"``.

        The objects from the checkpoint are assigned directly (no copy).
        NOTE(review): if the memory module updates its state in place, later
        updates would also change the checkpoint objects -- take a fresh
        snapshot if the checkpoint must be reused.
        """
        self.memory.memory = memory_checkpoint["memory"]
        # Write to `last_update_t`, the attribute that `store_memory` reads,
        # so the restored timestamps actually replace the live ones.
        self.memory.last_update_t = memory_checkpoint["last_t"]