Unverified Commit ebc6a85a authored by Chen Sirui, committed by GitHub

[Bugfix] TGN model bugfix and improvement (#2860)



* Add DCRNN Example

* Bugfix DCRNN

* Bugfix DCRNN

* Bugfix Train/eval have different flow

* Performance Matched

* modified tgn for research

* kind of fixed 2hop issue

* remove data

* bugfix and improve the performance

* Refactor code
Co-authored-by: Ubuntu <ubuntu@ip-172-31-45-47.ap-northeast-1.compute.internal>
Co-authored-by: Tianjun Xiao <xiaotj1990327@gmail.com>
parent d76af4d4
@@ -56,6 +56,12 @@ If you want to change memory updating module:
 python train.py --dataset wikipedia --memory_updater [rnn/gru]
 ```
+If you want to use TGAT:
+```python
+python train.py --dataset wikipedia --not_use_memory --k_hop 2
+```
 ## Performance
 #### Without New Node in test set
@@ -96,5 +102,10 @@ Normally temporal encoding needs each node to use incoming time frame as current
 **What is New Node test**
-To test the model has the ability to predict link between unseen nodes based on neighboring information of seen nodes. This model deliberately select 10 % of node in test graph and mask them out during the training
+To test whether the model can predict links between unseen nodes based on the neighboring information of seen nodes, this example deliberately selects 10% of the nodes in the test graph and masks them out during training.
+
+**Why the attention module is not exactly the same as in the original TGN paper**
+The attention module used in this model is adapted from DGL's GATConv and additionally takes edge features and the time encoding into account. It is more memory efficient and faster to compute than the attention module proposed in the paper, and in our tests it matches the accuracy of the paper's module.
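To make "GATConv-style attention extended with edge features and time encoding" concrete, here is a rough, single-head sketch. It is a simplified illustration under our own variable names, not the exact module added by this commit (the full `EdgeGATConv` appears in the modules.py diff below):

```python
import dgl
import dgl.function as fn
import torch
import torch.nn.functional as F
from dgl.ops import edge_softmax

# Toy star graph; self-loops ensure no node has 0 in-degree.
g = dgl.add_self_loop(dgl.graph(([1, 2, 3], [0, 0, 0])))
h = torch.randn(g.num_nodes(), 4)   # projected node (memory) features
e = torch.randn(g.num_edges(), 4)   # projected edge feature + time encoding

a_l, a_r, a_e = (torch.randn(4) for _ in range(3))
src, dst = g.edges()

# GAT-style additive score with an extra edge term, softmax-normalized per destination node.
score = F.leaky_relu((h[src] * a_l).sum(-1) + (e * a_e).sum(-1) + (h[dst] * a_r).sum(-1))
alpha = edge_softmax(g, score.unsqueeze(-1))            # (num_edges, 1)

# Aggregate attention-weighted messages (source feature + edge feature) into destinations.
g.edata['m'] = alpha * (h[src] + e)
g.update_all(fn.copy_e('m', 'm'), fn.sum('m', 'out'))
print(g.ndata['out'].shape)                             # (num_nodes, 4)
```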
@@ -9,50 +9,18 @@ from dgl.ops import edge_softmax
 import dgl.function as fn
 
-class MergeLayer(nn.Module):
-    """Merge two tensor into one
-    Which is useful as skip connection in Merge GAT's input with output
-
-    Parameter
-    ----------
-    dim1 : int
-        dimension of first input tensor
-    dim2 : int
-        dimension of second input tensor
-    dim3 : int
-        hidden dimension after first merging
-    dim4 : int
-        output dimension
-
-    Example
-    ----------
-    >>> merger = MergeLayer(10,10,10,5)
-    >>> input1 = torch.ones(4,10)
-    >>> input2 = torch.ones(4,10)
-    >>> merger(input1,input2)
-    tensor([[-0.1578, 0.1842, 0.2373, 1.2656, 1.0362],
-            [-0.1578, 0.1842, 0.2373, 1.2656, 1.0362],
-            [-0.1578, 0.1842, 0.2373, 1.2656, 1.0362],
-            [-0.1578, 0.1842, 0.2373, 1.2656, 1.0362]],
-           grad_fn=<AddmmBackward>)
-    """
-
-    def __init__(self, dim1, dim2, dim3, dim4):
-        super(MergeLayer, self).__init__()
-        self.fc1 = torch.nn.Linear(dim1 + dim2, dim3)
-        self.fc2 = torch.nn.Linear(dim3, dim4)
-        self.act = torch.nn.ReLU()
-
-        torch.nn.init.xavier_normal_(self.fc1.weight)
-        torch.nn.init.xavier_normal_(self.fc2.weight)
-
-    def forward(self, x1, x2):
-        x = torch.cat([x1, x2], dim=1)
-        h = self.act(self.fc1(x))
-        return self.fc2(h)
+class Identity(nn.Module):
+    """A placeholder identity operator that is argument-insensitive.
+    (Identity has already been supported by PyTorch 1.2, we will directly
+    import torch.nn.Identity in the future)
+    """
+
+    def __init__(self):
+        super(Identity, self).__init__()
+
+    def forward(self, x):
+        """Return input"""
+        return x
 
 
 class MsgLinkPredictor(nn.Module):
@@ -138,7 +106,6 @@ class TimeEncode(nn.Module):
         self.dimension = dimension
         self.w = torch.nn.Linear(1, dimension)
         self.w.weight = torch.nn.Parameter((torch.from_numpy(1 / 10 ** np.linspace(0, 9, dimension)))
                                            .double().reshape(dimension, -1))
         self.w.bias = torch.nn.Parameter(torch.zeros(dimension).double())
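For reference, the fixed geometric frequencies initialized above amount to a Fourier-style time encoding. A minimal standalone sketch of what this encoder computes, assuming the TGAT-style cosine mapping that TimeEncode applies in its forward pass (the forward itself is not shown in this hunk):

```python
import numpy as np
import torch

dimension = 100
# Fixed geometric frequencies, matching the initialization above.
freq = torch.from_numpy(1 / 10 ** np.linspace(0, 9, dimension))
bias = torch.zeros(dimension, dtype=torch.float64)

def time_encode(t):
    """Map a batch of time deltas (N,) to an (N, dimension) cosine encoding."""
    return torch.cos(t.to(torch.float64).unsqueeze(-1) * freq + bias)

print(time_encode(torch.tensor([0.0, 1.0, 10.0])).shape)  # torch.Size([3, 100])
```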
@@ -266,16 +233,17 @@ class MemoryOperation(nn.Module):
     Please refers to examples/pytorch/tgn/tgn.py
     """
 
-    def __init__(self, updater_type, memory, e_feat_dim, temporal_dim):
+    def __init__(self, updater_type, memory, e_feat_dim, temporal_encoder):
         super(MemoryOperation, self).__init__()
         updater_dict = {'gru': nn.GRUCell, 'rnn': nn.RNNCell}
         self.memory = memory
         memory_dim = self.memory.hidden_dim
-        self.message_dim = memory_dim+memory_dim+e_feat_dim+temporal_dim
+        self.temporal_encoder = temporal_encoder
+        self.message_dim = memory_dim+memory_dim + \
+            e_feat_dim+self.temporal_encoder.dimension
         self.updater = updater_dict[updater_type](input_size=self.message_dim,
                                                   hidden_size=memory_dim)
         self.memory = memory
-        self.temporal_encoder = TimeEncode(temporal_dim)
 
     # Here assume g is a subgraph from each iteration
     def stick_feat_to_graph(self, g):
@@ -309,134 +277,253 @@ class MemoryOperation(nn.Module):
         return g
 
-class TemporalGATConv(nn.Module):
-    """Dot Product based embedding with temporal encoding
-    it will each node will compute the attention weight with other
-    nodes using other node's memory as well as edge features
-
-    Aggregation:
-    ..math::
-        h_i^{(l+1)} = ReLu(\sum_{j\in \mathcal{N}(i)} \alpha_{i, j} (h_j^{(l)}(t)||e_{jj}||TimeEncode(t-t_j)))
-
-        \alpha_{i, j} = \mathrm{softmax_i}(\frac{QK^T}{\sqrt{d_k}})V
-
-    K,Q,V computation:
-    ..math::
-        K = W_k[memory_{src}(t),memory_{dst}(t),TimeEncode(t_{dst}-t_{src})]
-        Q = W_q[memory_{src}(t),memory_{dst}(t),TimeEncode(t_{dst}-t_{src})]
-        V = W_v[memory_{src}(t),memory_{dst}(t),TimeEncode(t_{dst}-t_{src})]
-
-    Parameters
-    ----------
-    edge_feats : int
-        dimension of edge feats
-    memory_feats : int
-        dimension of memory feats
-    temporal_feats : int
-        length of fourier series of time encoding
-    num_heads : int
-        number of head in multihead attention
-    allow_zero_in_degree : bool
-        Whether allow some node have indegree == 0 to prevent silence evaluation
-
-    Example
-    ----------
-    >>> attn = TemporalGATConv(2,2,5,2,2,False)
-    >>> star_graph = dgl.graph(([1,2,3,4,5],[0,0,0,0,0]))
-    >>> star_graph.edata['feats'] = torch.ones(5,2).double()
-    >>> star_graph.edata['timestamp'] = torch.zeros(5).double()
-    >>> memory = torch.ones(6,2)
-    >>> ts = torch.random.rand(6).double()
-    >>> star_graph = dgl.add_self_loop(star_graph)
-    >>> attn(graph,memory,ts)
-    tensor([[-0.0924, -0.3842],
-            [-0.0840, -0.3539],
-            [-0.0842, -0.3543],
-            [-0.0838, -0.3536],
-            [-0.0856, -0.3568],
-            [-0.0858, -0.3572]], grad_fn=<AddmmBackward>)
-    """
+class EdgeGATConv(nn.Module):
+    '''Edge Graph attention compute the graph attention from node and edge feature then aggregate both node and
+    edge feature.
+
+    Parameter
+    ==========
+    node_feats : int
+        number of node features
+    edge_feats : int
+        number of edge features
+    out_feats : int
+        number of output features
+    num_heads : int
+        number of heads in multihead attention
+    feat_drop : float, optional
+        drop out rate on the feature
+    attn_drop : float, optional
+        drop out rate on the attention weight
+    negative_slope : float, optional
+        LeakyReLU angle of negative slope.
+    residual : bool, optional
+        whether use residual connection
+    allow_zero_in_degree : bool, optional
+        If there are 0-in-degree nodes in the graph, output for those nodes will be invalid
+        since no message will be passed to those nodes. This is harmful for some applications
+        causing silent performance regression. This module will raise a DGLError if it detects
+        0-in-degree nodes in input graph. By setting ``True``, it will suppress the check
+        and let the users handle it by themselves. Defaults: ``False``.
+    '''
     def __init__(self,
+                 node_feats,
                  edge_feats,
-                 memory_feats,
-                 temporal_feats,
                  out_feats,
                  num_heads,
+                 feat_drop=0.,
+                 attn_drop=0.,
+                 negative_slope=0.2,
+                 residual=False,
+                 activation=None,
                  allow_zero_in_degree=False):
-        super(TemporalGATConv, self).__init__()
+        super(EdgeGATConv, self).__init__()
+        self._num_heads = num_heads
+        self._node_feats = node_feats
         self._edge_feats = edge_feats
-        self._memory_feats = memory_feats
-        self._temporal_feats = temporal_feats
         self._out_feats = out_feats
         self._allow_zero_in_degree = allow_zero_in_degree
-        self._num_heads = num_heads
+        self.fc_node = nn.Linear(
+            self._node_feats, self._out_feats*self._num_heads)
+        self.fc_edge = nn.Linear(
+            self._edge_feats, self._out_feats*self._num_heads)
+        self.attn_l = nn.Parameter(torch.FloatTensor(
+            size=(1, self._num_heads, self._out_feats)))
+        self.attn_r = nn.Parameter(torch.FloatTensor(
+            size=(1, self._num_heads, self._out_feats)))
+        self.attn_e = nn.Parameter(torch.FloatTensor(
+            size=(1, self._num_heads, self._out_feats)))
+        self.feat_drop = nn.Dropout(feat_drop)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.leaky_relu = nn.LeakyReLU(negative_slope)
+        self.residual = residual
+        if residual:
+            if self._node_feats != self._out_feats:
+                self.res_fc = nn.Linear(
+                    self._node_feats, self._out_feats*self._num_heads, bias=False)
+            else:
+                self.res_fc = Identity()
+        self.reset_parameters()
+        self.activation = activation
+    def reset_parameters(self):
+        gain = nn.init.calculate_gain('relu')
+        nn.init.xavier_normal_(self.fc_node.weight, gain=gain)
+        nn.init.xavier_normal_(self.fc_edge.weight, gain=gain)
+        nn.init.xavier_normal_(self.attn_l, gain=gain)
+        nn.init.xavier_normal_(self.attn_r, gain=gain)
+        nn.init.xavier_normal_(self.attn_e, gain=gain)
+        if self.residual and isinstance(self.res_fc, nn.Linear):
+            nn.init.xavier_normal_(self.res_fc.weight, gain=gain)
+
+    def msg_fn(self, edges):
+        ret = edges.data['a'].view(-1, self._num_heads,
+                                   1)*edges.data['el_prime']
+        return {'m': ret}
+
+    def forward(self, graph, nfeat, efeat, get_attention=False):
+        with graph.local_scope():
+            if not self._allow_zero_in_degree:
+                if (graph.in_degrees() == 0).any():
+                    raise DGLError('There are 0-in-degree nodes in the graph, '
+                                   'output for those nodes will be invalid. '
+                                   'This is harmful for some applications, '
+                                   'causing silent performance regression. '
+                                   'Adding self-loop on the input graph by '
+                                   'calling `g = dgl.add_self_loop(g)` will resolve '
+                                   'the issue. Setting ``allow_zero_in_degree`` '
+                                   'to be `True` when constructing this module will '
+                                   'suppress the check and let the code run.')
+            nfeat = self.feat_drop(nfeat)
+            efeat = self.feat_drop(efeat)
+            node_feat = self.fc_node(
+                nfeat).view(-1, self._num_heads, self._out_feats)
+            edge_feat = self.fc_edge(
+                efeat).view(-1, self._num_heads, self._out_feats)
+            el = (node_feat*self.attn_l).sum(dim=-1).unsqueeze(-1)
+            er = (node_feat*self.attn_r).sum(dim=-1).unsqueeze(-1)
+            ee = (edge_feat*self.attn_e).sum(dim=-1).unsqueeze(-1)
+            graph.ndata['ft'] = node_feat
+            graph.ndata['el'] = el
+            graph.ndata['er'] = er
+            graph.edata['ee'] = ee
+            graph.apply_edges(fn.u_add_e('el', 'ee', 'el_prime'))
+            graph.apply_edges(fn.e_add_v('el_prime', 'er', 'e'))
+            e = self.leaky_relu(graph.edata['e'])
+            graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
+            graph.edata['efeat'] = edge_feat
+            graph.update_all(self.msg_fn, fn.sum('m', 'ft'))
+            rst = graph.ndata['ft']
+            if self.residual:
+                resval = self.res_fc(nfeat).view(
+                    nfeat.shape[0], -1, self._out_feats)
+                rst = rst + resval
+            if self.activation:
+                rst = self.activation(rst)
+            if get_attention:
+                return rst, graph.edata['a']
+            else:
+                return rst
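To make the new layer's interface concrete, here is a hedged usage sketch. It assumes `EdgeGATConv` is importable from this example's modules.py; the graph and feature sizes are made up for illustration:

```python
import dgl
import torch
from modules import EdgeGATConv  # this example's modules.py

# Small star graph; self-loops avoid the 0-in-degree check in forward().
g = dgl.add_self_loop(dgl.graph(([1, 2, 3, 4, 5], [0, 0, 0, 0, 0])))
nfeat = torch.randn(g.num_nodes(), 16)   # node (memory) features
efeat = torch.randn(g.num_edges(), 8)    # edge features (plus time encoding)

conv = EdgeGATConv(node_feats=16, edge_feats=8, out_feats=4,
                   num_heads=2, residual=True)
out = conv(g, nfeat, efeat)
print(out.shape)                          # torch.Size([6, 2, 4]): one vector per head
```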
-        self.fc_Q = nn.Linear(self._memory_feats+self._temporal_feats,
-                              self._out_feats*self._num_heads, bias=False)
-        self.fc_K = nn.Linear(self._memory_feats+self._edge_feats +
-                              self._temporal_feats, self._out_feats*self._num_heads, bias=False)
-        self.fc_V = nn.Linear(self._memory_feats+self._edge_feats +
-                              self._temporal_feats, self._out_feats*self._num_heads, bias=False)
-        self.merge = MergeLayer(
-            self._memory_feats, self._out_feats*self._num_heads, 512, self._out_feats)
-        self.temporal_encoder = TimeEncode(self._temporal_feats)
-
-    def weight_fn(self, edges):
-        t0 = torch.zeros_like(edges.dst['timestamp'])
-        q = torch.cat([edges.dst['s'],
-                       self.temporal_encoder(t0.unsqueeze(dim=1)).view(len(t0), -1)], dim=1)
-        time_diff = edges.data['timestamp']-edges.src['timestamp']
-        time_encode = self.temporal_encoder(
-            time_diff.unsqueeze(dim=1)).view(len(t0), -1)
-        k = torch.cat(
-            [edges.src['s'], edges.data['feats'], time_encode], dim=1)
-        squeezed_k = self.fc_K(
-            k.float()).view(-1, self._num_heads, self._out_feats)
-        squeezed_q = self.fc_Q(
-            q.float()).view(-1, self._num_heads, self._out_feats)
-        ret = torch.sum(squeezed_q*squeezed_k, dim=2)
-        return {'a': ret, 'efeat': squeezed_k}
-
-    def msg_fn(self, edges):
-        ret = edges.data['sa'].view(-1, self._num_heads, 1)*edges.data['efeat']
-        return {'attn': ret}
+class TemporalEdgePreprocess(nn.Module):
+    '''Preprocess layer, which finish time encoding and concatenate
+    the time encoding to edge feature.
+
+    Parameter
+    ==========
+    edge_feats : int
+        number of orginal edge feature
+    temporal_encoder : torch.nn.Module
+        time encoder model
+    '''
+
+    def __init__(self, edge_feats, temporal_encoder):
+        super(TemporalEdgePreprocess, self).__init__()
+        self.edge_feats = edge_feats
+        self.temporal_encoder = temporal_encoder
+
+    def edge_fn(self, edges):
+        t0 = torch.zeros_like(edges.dst['timestamp'])
+        time_diff = edges.data['timestamp'] - edges.src['timestamp']
+        time_encode = self.temporal_encoder(
+            time_diff.unsqueeze(dim=1)).view(t0.shape[0], -1)
+        edge_feat = torch.cat([edges.data['feats'], time_encode], dim=1)
+        return {'efeat': edge_feat}
+
+    def forward(self, graph):
+        graph.apply_edges(self.edge_fn)
+        efeat = graph.edata['efeat']
+        return efeat
+class TemporalTransformerConv(nn.Module):
+    def __init__(self,
+                 edge_feats,
+                 memory_feats,
+                 temporal_encoder,
+                 out_feats,
+                 num_heads,
+                 allow_zero_in_degree=False,
+                 layers=1):
+        '''Temporal Transformer model for TGN and TGAT
+
+        Parameter
+        ==========
+        edge_feats : int
+            number of edge features
+        memory_feats : int
+            dimension of memory vector
+        temporal_encoder : torch.nn.Module
+            compute fourier time encoding
+        out_feats : int
+            number of out features
+        num_heads : int
+            number of attention head
+        allow_zero_in_degree : bool, optional
+            If there are 0-in-degree nodes in the graph, output for those nodes will be invalid
+            since no message will be passed to those nodes. This is harmful for some applications
+            causing silent performance regression. This module will raise a DGLError if it detects
+            0-in-degree nodes in input graph. By setting ``True``, it will suppress the check
+            and let the users handle it by themselves. Defaults: ``False``.
+        '''
+        super(TemporalTransformerConv, self).__init__()
+        self._edge_feats = edge_feats
+        self._memory_feats = memory_feats
+        self.temporal_encoder = temporal_encoder
+        self._out_feats = out_feats
+        self._allow_zero_in_degree = allow_zero_in_degree
+        self._num_heads = num_heads
+        self.layers = layers
+
+        self.preprocessor = TemporalEdgePreprocess(
+            self._edge_feats, self.temporal_encoder)
+        self.layer_list = nn.ModuleList()
+        self.layer_list.append(EdgeGATConv(node_feats=self._memory_feats,
+                                           edge_feats=self._edge_feats+self.temporal_encoder.dimension,
+                                           out_feats=self._out_feats,
+                                           num_heads=self._num_heads,
+                                           feat_drop=0.6,
+                                           attn_drop=0.6,
+                                           residual=True,
+                                           allow_zero_in_degree=allow_zero_in_degree))
+        for i in range(self.layers-1):
+            self.layer_list.append(EdgeGATConv(node_feats=self._out_feats*self._num_heads,
+                                               edge_feats=self._edge_feats+self.temporal_encoder.dimension,
+                                               out_feats=self._out_feats,
+                                               num_heads=self._num_heads,
+                                               feat_drop=0.6,
+                                               attn_drop=0.6,
+                                               residual=True,
+                                               allow_zero_in_degree=allow_zero_in_degree))
     def forward(self, graph, memory, ts):
-        graph = graph.local_var()  # Using local scope for graph
-        if not self._allow_zero_in_degree:
-            if(graph.in_degrees() == 0).any():
-                raise DGLError('There are 0-in-degree nodes in the graph, '
-                               'output for those nodes will be invalid. '
-                               'This is harmful for some applications, '
-                               'causing silent performance regression. '
-                               'Adding self-loop on the input graph by '
-                               'calling `g = dgl.add_self_loop(g)` will resolve '
-                               'the issue. Setting ``allow_zero_in_degree`` '
-                               'to be `True` when constructing this module will '
-                               'suppress the check and let the code run.')
-        #print("Shape: ",memory.shape,ts.shape)
-        graph.srcdata.update({'s': memory, 'timestamp': ts})
-        graph.dstdata.update({'s': memory, 'timestamp': ts})
-        # Dot product Calculate the attentio weight
-        graph.apply_edges(self.weight_fn)
-
-        # Edge softmax
-        graph.edata['sa'] = edge_softmax(
-            graph, graph.edata['a'])/(self._out_feats**0.5)
-
-        # Update dst node Here msg_fn include edge feature
-        graph.update_all(self.msg_fn, fn.sum('attn', 'agg_u'))
-        rst = graph.dstdata['agg_u']
-
-        # Implement skip connection
-        rst = self.merge(rst.view(-1, self._num_heads *
-                                  self._out_feats), graph.dstdata['s'])
+        graph = graph.local_var()
+        graph.ndata['timestamp'] = ts
+        efeat = self.preprocessor(graph).float()
+        rst = memory
+        for i in range(self.layers-1):
+            rst = self.layer_list[i](graph, rst, efeat).flatten(1)
+        rst = self.layer_list[-1](graph, rst, efeat).mean(1)
         return rst
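End to end, the new embedding stack is: stash node timestamps on the graph, concatenate each edge's raw features with the time encoding of its time delta, then run one or more EdgeGATConv layers over the node memory. A hedged sketch of running it on a toy temporal graph (assumes `TimeEncode` and `TemporalTransformerConv` are importable from this example's modules.py and that `TimeEncode` takes only the encoding dimension, consistent with the hunks above; all sizes are illustrative):

```python
import dgl
import torch
from modules import TimeEncode, TemporalTransformerConv  # this example's modules.py

g = dgl.add_self_loop(dgl.graph(([1, 2, 3, 4, 5], [0, 0, 0, 0, 0])))
g.edata['feats'] = torch.randn(g.num_edges(), 8).double()     # raw edge features
g.edata['timestamp'] = torch.rand(g.num_edges()).double()     # edge event times

temporal_encoder = TimeEncode(10)                              # 10-dim time encoding
conv = TemporalTransformerConv(edge_feats=8, memory_feats=16,
                               temporal_encoder=temporal_encoder,
                               out_feats=32, num_heads=4,
                               allow_zero_in_degree=True, layers=1)

memory = torch.randn(g.num_nodes(), 16)                        # node memory vectors
ts = torch.rand(g.num_nodes()).double()                        # node timestamps
emb = conv(g, memory, ts)
print(emb.shape)                                               # torch.Size([6, 32])
```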
 import copy
 import torch.nn as nn
 import dgl
 
-from modules import MemoryModule, MemoryOperation, TemporalGATConv, MsgLinkPredictor
+from modules import MemoryModule, MemoryOperation, MsgLinkPredictor, TemporalTransformerConv, TimeEncode
 
 
 class TGN(nn.Module):
     def __init__(self,
@@ -10,9 +11,10 @@ class TGN(nn.Module):
                  temporal_dim,
                  embedding_dim,
                  num_heads,
-                 num_nodes,  # entire graph
+                 num_nodes,
                  n_neighbors=10,
-                 memory_updater_type='gru'):
+                 memory_updater_type='gru',
+                 layers=1):
         super(TGN, self).__init__()
         self.memory_dim = memory_dim
         self.edge_feat_dim = edge_feat_dim
@@ -22,6 +24,9 @@ class TGN(nn.Module):
         self.n_neighbors = n_neighbors
         self.memory_updater_type = memory_updater_type
         self.num_nodes = num_nodes
+        self.layers = layers
+        self.temporal_encoder = TimeEncode(self.temporal_dim)
 
         self.memory = MemoryModule(self.num_nodes,
                                    self.memory_dim)
@@ -29,15 +34,15 @@ class TGN(nn.Module):
         self.memory_ops = MemoryOperation(self.memory_updater_type,
                                           self.memory,
                                           self.edge_feat_dim,
-                                          self.temporal_dim)
-
-        self.embedding_attn = TemporalGATConv(self.edge_feat_dim,
-                                              self.memory_dim,
-                                              self.temporal_dim,
-                                              self.embedding_dim,
-                                              self.num_heads,
-                                              allow_zero_in_degree=True)
+                                          self.temporal_encoder)
+
+        self.embedding_attn = TemporalTransformerConv(self.edge_feat_dim,
+                                                      self.memory_dim,
+                                                      self.temporal_encoder,
+                                                      self.embedding_dim,
+                                                      self.num_heads,
+                                                      layers=self.layers,
+                                                      allow_zero_in_degree=True)
 
         self.msg_linkpredictor = MsgLinkPredictor(embedding_dim)
 ...
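The training script below constructs this updated TGN with `layers=args.k_hop`. A hedged construction sketch (the keyword names follow the diff; the concrete sizes are assumptions for illustration, not values taken from this commit):

```python
from tgn import TGN  # this example's tgn.py

# layers=2 lets the embedding module attend over a 2-hop neighborhood (TGAT-style).
model = TGN(edge_feat_dim=172,      # illustrative value
            memory_dim=100,
            temporal_dim=100,
            embedding_dim=100,
            num_heads=8,
            num_nodes=10000,        # illustrative value
            n_neighbors=10,
            memory_updater_type='gru',
            layers=2)
```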
@@ -24,7 +24,7 @@ np.random.seed(2021)
 torch.manual_seed(2021)
 
 
-def train(model, dataloader, sampler, criterion, optimizer, batch_size, fast_mode):
+def train(model, dataloader, sampler, criterion, optimizer, args):
     model.train()
     total_loss = 0
     batch_cnt = 0
@@ -35,13 +35,14 @@ def train(model, dataloader, sampler, criterion, optimizer, batch_size, fast_mod
             positive_pair_g, negative_pair_g, blocks)
         loss = criterion(pred_pos, torch.ones_like(pred_pos))
         loss += criterion(pred_neg, torch.zeros_like(pred_neg))
-        total_loss += float(loss)*batch_size
-        retain_graph = True if batch_cnt == 0 and not fast_mode else False
+        total_loss += float(loss)*args.batch_size
+        retain_graph = True if batch_cnt == 0 and not args.fast_mode else False
         loss.backward(retain_graph=retain_graph)
         optimizer.step()
         model.detach_memory()
-        model.update_memory(positive_pair_g)
-        if fast_mode:
+        if not args.not_use_memory:
+            model.update_memory(positive_pair_g)
+        if args.fast_mode:
             sampler.attach_last_update(model.memory.last_update_t)
         print("Batch: ", batch_cnt, "Time: ", time.time()-last_t)
         last_t = time.time()
@@ -49,9 +50,9 @@ def train(model, dataloader, sampler, criterion, optimizer, batch_size, fast_mod
     return total_loss
 
 
-def test_val(model, dataloader, sampler, criterion, batch_size, fast_mode):
+def test_val(model, dataloader, sampler, criterion, args):
     model.eval()
-    batch_size = batch_size
+    batch_size = args.batch_size
     total_loss = 0
     aps, aucs = [], []
     batch_cnt = 0
@@ -65,8 +66,9 @@ def test_val(model, dataloader, sampler, criterion, batch_size, fast_mode):
             y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
             y_true = torch.cat(
                 [torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
-            model.update_memory(postive_pair_g)
-            if fast_mode:
+            if not args.not_use_memory:
+                model.update_memory(postive_pair_g)
+            if args.fast_mode:
                 sampler.attach_last_update(model.memory.last_update_t)
             aps.append(average_precision_score(y_true, y_pred))
             aucs.append(roc_auc_score(y_true, y_pred))
@@ -106,11 +108,14 @@ if __name__ == "__main__":
     parser.add_argument("--dataset", type=str, default="wikipedia",
                         help="dataset selection wikipedia/reddit")
     parser.add_argument("--k_hop", type=int, default=1,
                         help="sampling k-hop neighborhood")
+    parser.add_argument("--not_use_memory", action="store_true", default=False,
+                        help="disable the memory module of the TGN model")
     args = parser.parse_args()
 
-    assert not (args.fast_mode and args.simple_mode), "you can only choose one sampling mode"
+    assert not (
+        args.fast_mode and args.simple_mode), "you can only choose one sampling mode"
     if args.k_hop != 1:
         assert args.simple_mode, "this k-hop parameter only support simple mode"
@@ -246,7 +251,8 @@ if __name__ == "__main__":
                 num_heads=args.num_heads,
                 num_nodes=num_node,
                 n_neighbors=args.n_neighbors,
-                memory_updater_type=args.memory_updater)
+                memory_updater_type=args.memory_updater,
+                layers=args.k_hop)
     criterion = torch.nn.BCEWithLogitsLoss()
     optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
@@ -257,21 +263,22 @@ if __name__ == "__main__":
     try:
         for i in range(args.epochs):
             train_loss = train(model, train_dataloader, sampler,
-                               criterion, optimizer, args.batch_size, args.fast_mode)
+                               criterion, optimizer, args)
             val_ap, val_auc = test_val(
-                model, valid_dataloader, sampler, criterion, args.batch_size, args.fast_mode)
+                model, valid_dataloader, sampler, criterion, args)
             memory_checkpoint = model.store_memory()
             if args.fast_mode:
                 new_node_sampler.sync(sampler)
             test_ap, test_auc = test_val(
-                model, test_dataloader, sampler, criterion, args.batch_size, args.fast_mode)
+                model, test_dataloader, sampler, criterion, args)
             model.restore_memory(memory_checkpoint)
             if args.fast_mode:
                 sample_nn = new_node_sampler
             else:
                 sample_nn = sampler
             nn_test_ap, nn_test_auc = test_val(
-                model, test_new_node_dataloader, sample_nn, criterion, args.batch_size, args.fast_mode)
+                model, test_new_node_dataloader, sample_nn, criterion, args)
             log_content = []
             log_content.append("Epoch: {}; Training Loss: {} | Validation AP: {:.3f} AUC: {:.3f}\n".format(
                 i, train_loss, val_ap, val_auc))
@@ -290,5 +297,4 @@ if __name__ == "__main__":
         error_content = "Training Interreputed!"
         f.writelines(error_content)
         f.close()
-        # exit(-1)
     print("========Training is Done========")