Unverified Commit ebc6a85a authored by Chen Sirui's avatar Chen Sirui Committed by GitHub
Browse files

[Bugfix] TGN model bugfix and improvement (#2860)



* Add DCRNN Example

* Bugfix DCRNN

* Bugfix DCRNN

* Bugfix Train/eval have different flow

* Performance Matched

* modified tgn for research

* kind of fixed 2hop issue

* remove data

* bugfix and improve the performance

* Refactor Code
Co-authored-by: default avatarUbuntu <ubuntu@ip-172-31-45-47.ap-northeast-1.compute.internal>
Co-authored-by: default avatarTianjun Xiao <xiaotj1990327@gmail.com>
parent d76af4d4
......@@ -56,6 +56,12 @@ If you want to change memory updating module:
python train.py --dataset wikipedia --memory_updater [rnn/gru]
```
If you want to use TGAT:
```bash
python train.py --dataset wikipedia --not_use_memory --k_hop 2
```
## Performance
#### Without New Node in test set
......@@ -96,5 +102,10 @@ Normally temporal encoding needs each node to use incoming time frame as current
**What is New Node test**
To test that the model has the ability to predict links between unseen nodes based on neighboring information of seen nodes, this model deliberately selects 10% of the nodes in the test graph and masks them out during training.
**Why the attention module is not exactly same as TGN original paper**
The attention module used in this model is adapted from DGL's GATConv, extended to consider edge features and time encoding. It is more memory-efficient and faster to compute than the attention module proposed in the paper, while, according to our tests, its accuracy matches that of the module in the paper.
......@@ -9,50 +9,18 @@ from dgl.ops import edge_softmax
import dgl.function as fn
class MergeLayer(nn.Module):
"""Merge two tensor into one
Which is useful as skip connection in Merge GAT's input with output
Parameter
----------
dim1 : int
dimension of first input tensor
dim2 : int
dimension of second input tensor
dim3 : int
hidden dimension after first merging
dim4 : int
output dimension
Example
----------
>>> merger = MergeLayer(10,10,10,5)
>>> input1 = torch.ones(4,10)
>>> input2 = torch.ones(4,10)
>>> merger(input1,input2)
tensor([[-0.1578, 0.1842, 0.2373, 1.2656, 1.0362],
[-0.1578, 0.1842, 0.2373, 1.2656, 1.0362],
[-0.1578, 0.1842, 0.2373, 1.2656, 1.0362],
[-0.1578, 0.1842, 0.2373, 1.2656, 1.0362]],
grad_fn=<AddmmBackward>)
class Identity(nn.Module):
"""A placeholder identity operator that is argument-insensitive.
(Identity has already been supported by PyTorch 1.2, we will directly
import torch.nn.Identity in the future)
"""
def __init__(self, dim1, dim2, dim3, dim4):
super(MergeLayer, self).__init__()
self.fc1 = torch.nn.Linear(dim1 + dim2, dim3)
self.fc2 = torch.nn.Linear(dim3, dim4)
self.act = torch.nn.ReLU()
def __init__(self):
super(Identity, self).__init__()
torch.nn.init.xavier_normal_(self.fc1.weight)
torch.nn.init.xavier_normal_(self.fc2.weight)
def forward(self, x1, x2):
x = torch.cat([x1, x2], dim=1)
h = self.act(self.fc1(x))
return self.fc2(h)
def forward(self, x):
"""Return input"""
return x
class MsgLinkPredictor(nn.Module):
......@@ -138,7 +106,6 @@ class TimeEncode(nn.Module):
self.dimension = dimension
self.w = torch.nn.Linear(1, dimension)
self.w.weight = torch.nn.Parameter((torch.from_numpy(1 / 10 ** np.linspace(0, 9, dimension)))
.double().reshape(dimension, -1))
self.w.bias = torch.nn.Parameter(torch.zeros(dimension).double())
......@@ -266,16 +233,17 @@ class MemoryOperation(nn.Module):
Please refers to examples/pytorch/tgn/tgn.py
"""
def __init__(self, updater_type, memory, e_feat_dim, temporal_dim):
def __init__(self, updater_type, memory, e_feat_dim, temporal_encoder):
super(MemoryOperation, self).__init__()
updater_dict = {'gru': nn.GRUCell, 'rnn': nn.RNNCell}
self.memory = memory
memory_dim = self.memory.hidden_dim
self.message_dim = memory_dim+memory_dim+e_feat_dim+temporal_dim
self.temporal_encoder = temporal_encoder
self.message_dim = memory_dim+memory_dim + \
e_feat_dim+self.temporal_encoder.dimension
self.updater = updater_dict[updater_type](input_size=self.message_dim,
hidden_size=memory_dim)
self.memory = memory
self.temporal_encoder = TimeEncode(temporal_dim)
# Here assume g is a subgraph from each iteration
def stick_feat_to_graph(self, g):
......@@ -309,134 +277,253 @@ class MemoryOperation(nn.Module):
return g
class TemporalGATConv(nn.Module):
"""Dot Product based embedding with temporal encoding
it will each node will compute the attention weight with other
nodes using other node's memory as well as edge features
class EdgeGATConv(nn.Module):
'''Edge Graph attention compute the graph attention from node and edge feature then aggregate both node and
edge feature.
Aggregation:
..math::
h_i^{(l+1)} = ReLu(\sum_{j\in \mathcal{N}(i)} \alpha_{i, j} (h_j^{(l)}(t)||e_{jj}||TimeEncode(t-t_j)))
Parameter
==========
node_feats : int
number of node features
\alpha_{i, j} = \mathrm{softmax_i}(\frac{QK^T}{\sqrt{d_k}})V
edge_feats : int
number of edge features
K,Q,V computation:
..math::
K = W_k[memory_{src}(t),memory_{dst}(t),TimeEncode(t_{dst}-t_{src})]
Q = W_q[memory_{src}(t),memory_{dst}(t),TimeEncode(t_{dst}-t_{src})]
V = W_v[memory_{src}(t),memory_{dst}(t),TimeEncode(t_{dst}-t_{src})]
out_feats : int
number of output features
Parameters
----------
num_heads : int
number of heads in multihead attention
edge_feats : int
dimension of edge feats
feat_drop : float, optional
drop out rate on the feature
memory_feats : int
dimension of memory feats
attn_drop : float, optional
drop out rate on the attention weight
temporal_feats : int
length of fourier series of time encoding
negative_slope : float, optional
LeakyReLU angle of negative slope.
num_heads : int
number of head in multihead attention
residual : bool, optional
whether use residual connection
allow_zero_in_degree : bool
Whether allow some node have indegree == 0 to prevent silence evaluation
allow_zero_in_degree : bool, optional
If there are 0-in-degree nodes in the graph, output for those nodes will be invalid
since no message will be passed to those nodes. This is harmful for some applications
causing silent performance regression. This module will raise a DGLError if it detects
0-in-degree nodes in input graph. By setting ``True``, it will suppress the check
and let the users handle it by themselves. Defaults: ``False``.
Example
----------
>>> attn = TemporalGATConv(2,2,5,2,2,False)
>>> star_graph = dgl.graph(([1,2,3,4,5],[0,0,0,0,0]))
>>> star_graph.edata['feats'] = torch.ones(5,2).double()
>>> star_graph.edata['timestamp'] = torch.zeros(5).double()
>>> memory = torch.ones(6,2)
>>> ts = torch.random.rand(6).double()
>>> star_graph = dgl.add_self_loop(star_graph)
>>> attn(graph,memory,ts)
tensor([[-0.0924, -0.3842],
[-0.0840, -0.3539],
[-0.0842, -0.3543],
[-0.0838, -0.3536],
[-0.0856, -0.3568],
[-0.0858, -0.3572]], grad_fn=<AddmmBackward>)
"""
'''
def __init__(self,
node_feats,
edge_feats,
memory_feats,
temporal_feats,
out_feats,
num_heads,
feat_drop=0.,
attn_drop=0.,
negative_slope=0.2,
residual=False,
activation=None,
allow_zero_in_degree=False):
super(TemporalGATConv, self).__init__()
super(EdgeGATConv, self).__init__()
self._num_heads = num_heads
self._node_feats = node_feats
self._edge_feats = edge_feats
self._memory_feats = memory_feats
self._temporal_feats = temporal_feats
self._out_feats = out_feats
self._allow_zero_in_degree = allow_zero_in_degree
self._num_heads = num_heads
self.fc_node = nn.Linear(
self._node_feats, self._out_feats*self._num_heads)
self.fc_edge = nn.Linear(
self._edge_feats, self._out_feats*self._num_heads)
self.attn_l = nn.Parameter(torch.FloatTensor(
size=(1, self._num_heads, self._out_feats)))
self.attn_r = nn.Parameter(torch.FloatTensor(
size=(1, self._num_heads, self._out_feats)))
self.attn_e = nn.Parameter(torch.FloatTensor(
size=(1, self._num_heads, self._out_feats)))
self.feat_drop = nn.Dropout(feat_drop)
self.attn_drop = nn.Dropout(attn_drop)
self.leaky_relu = nn.LeakyReLU(negative_slope)
self.residual = residual
if residual:
if self._node_feats != self._out_feats:
self.res_fc = nn.Linear(
self._node_feats, self._out_feats*self._num_heads, bias=False)
else:
self.res_fc = Identity()
self.reset_parameters()
self.activation = activation
def reset_parameters(self):
gain = nn.init.calculate_gain('relu')
nn.init.xavier_normal_(self.fc_node.weight, gain=gain)
nn.init.xavier_normal_(self.fc_edge.weight, gain=gain)
nn.init.xavier_normal_(self.attn_l, gain=gain)
nn.init.xavier_normal_(self.attn_r, gain=gain)
nn.init.xavier_normal_(self.attn_e, gain=gain)
if self.residual and isinstance(self.res_fc, nn.Linear):
nn.init.xavier_normal_(self.res_fc.weight, gain=gain)
def msg_fn(self, edges):
ret = edges.data['a'].view(-1, self._num_heads,
1)*edges.data['el_prime']
return {'m': ret}
def forward(self, graph, nfeat, efeat, get_attention=False):
with graph.local_scope():
if not self._allow_zero_in_degree:
if (graph.in_degrees() == 0).any():
raise DGLError('There are 0-in-degree nodes in the graph, '
'output for those nodes will be invalid. '
'This is harmful for some applications, '
'causing silent performance regression. '
'Adding self-loop on the input graph by '
'calling `g = dgl.add_self_loop(g)` will resolve '
'the issue. Setting ``allow_zero_in_degree`` '
'to be `True` when constructing this module will '
'suppress the check and let the code run.')
nfeat = self.feat_drop(nfeat)
efeat = self.feat_drop(efeat)
node_feat = self.fc_node(
nfeat).view(-1, self._num_heads, self._out_feats)
edge_feat = self.fc_edge(
efeat).view(-1, self._num_heads, self._out_feats)
el = (node_feat*self.attn_l).sum(dim=-1).unsqueeze(-1)
er = (node_feat*self.attn_r).sum(dim=-1).unsqueeze(-1)
ee = (edge_feat*self.attn_e).sum(dim=-1).unsqueeze(-1)
graph.ndata['ft'] = node_feat
graph.ndata['el'] = el
graph.ndata['er'] = er
graph.edata['ee'] = ee
graph.apply_edges(fn.u_add_e('el', 'ee', 'el_prime'))
graph.apply_edges(fn.e_add_v('el_prime', 'er', 'e'))
e = self.leaky_relu(graph.edata['e'])
graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
graph.edata['efeat'] = edge_feat
graph.update_all(self.msg_fn, fn.sum('m', 'ft'))
rst = graph.ndata['ft']
if self.residual:
resval = self.res_fc(nfeat).view(
nfeat.shape[0], -1, self._out_feats)
rst = rst + resval
if self.activation:
rst = self.activation(rst)
if get_attention:
return rst, graph.edata['a']
else:
return rst
class TemporalEdgePreprocess(nn.Module):
'''Preprocess layer, which finish time encoding and concatenate
the time encoding to edge feature.
Parameter
==========
edge_feats : int
number of original edge features
temporal_encoder : torch.nn.Module
time encoder model
'''
self.fc_Q = nn.Linear(self._memory_feats+self._temporal_feats,
self._out_feats*self._num_heads, bias=False)
self.fc_K = nn.Linear(self._memory_feats+self._edge_feats +
self._temporal_feats, self._out_feats*self._num_heads, bias=False)
self.fc_V = nn.Linear(self._memory_feats+self._edge_feats +
self._temporal_feats, self._out_feats*self._num_heads, bias=False)
self.merge = MergeLayer(
self._memory_feats, self._out_feats*self._num_heads, 512, self._out_feats)
self.temporal_encoder = TimeEncode(self._temporal_feats)
def weight_fn(self, edges):
def __init__(self, edge_feats, temporal_encoder):
super(TemporalEdgePreprocess, self).__init__()
self.edge_feats = edge_feats
self.temporal_encoder = temporal_encoder
def edge_fn(self, edges):
t0 = torch.zeros_like(edges.dst['timestamp'])
q = torch.cat([edges.dst['s'],
self.temporal_encoder(t0.unsqueeze(dim=1)).view(len(t0), -1)], dim=1)
time_diff = edges.data['timestamp']-edges.src['timestamp']
time_diff = edges.data['timestamp'] - edges.src['timestamp']
time_encode = self.temporal_encoder(
time_diff.unsqueeze(dim=1)).view(len(t0), -1)
k = torch.cat(
[edges.src['s'], edges.data['feats'], time_encode], dim=1)
squeezed_k = self.fc_K(
k.float()).view(-1, self._num_heads, self._out_feats)
squeezed_q = self.fc_Q(
q.float()).view(-1, self._num_heads, self._out_feats)
ret = torch.sum(squeezed_q*squeezed_k, dim=2)
return {'a': ret, 'efeat': squeezed_k}
time_diff.unsqueeze(dim=1)).view(t0.shape[0], -1)
edge_feat = torch.cat([edges.data['feats'], time_encode], dim=1)
return {'efeat': edge_feat}
def msg_fn(self, edges):
ret = edges.data['sa'].view(-1, self._num_heads, 1)*edges.data['efeat']
return {'attn': ret}
def forward(self, graph):
graph.apply_edges(self.edge_fn)
efeat = graph.edata['efeat']
return efeat
class TemporalTransformerConv(nn.Module):
def __init__(self,
edge_feats,
memory_feats,
temporal_encoder,
out_feats,
num_heads,
allow_zero_in_degree=False,
layers=1):
'''Temporal Transformer model for TGN and TGAT
Parameter
==========
edge_feats : int
number of edge features
memory_feats : int
dimension of memory vector
temporal_encoder : torch.nn.Module
compute fourier time encoding
out_feats : int
number of out features
num_heads : int
number of attention head
allow_zero_in_degree : bool, optional
If there are 0-in-degree nodes in the graph, output for those nodes will be invalid
since no message will be passed to those nodes. This is harmful for some applications
causing silent performance regression. This module will raise a DGLError if it detects
0-in-degree nodes in input graph. By setting ``True``, it will suppress the check
and let the users handle it by themselves. Defaults: ``False``.
'''
super(TemporalTransformerConv, self).__init__()
self._edge_feats = edge_feats
self._memory_feats = memory_feats
self.temporal_encoder = temporal_encoder
self._out_feats = out_feats
self._allow_zero_in_degree = allow_zero_in_degree
self._num_heads = num_heads
self.layers = layers
self.preprocessor = TemporalEdgePreprocess(
self._edge_feats, self.temporal_encoder)
self.layer_list = nn.ModuleList()
self.layer_list.append(EdgeGATConv(node_feats=self._memory_feats,
edge_feats=self._edge_feats+self.temporal_encoder.dimension,
out_feats=self._out_feats,
num_heads=self._num_heads,
feat_drop=0.6,
attn_drop=0.6,
residual=True,
allow_zero_in_degree=allow_zero_in_degree))
for i in range(self.layers-1):
self.layer_list.append(EdgeGATConv(node_feats=self._out_feats*self._num_heads,
edge_feats=self._edge_feats+self.temporal_encoder.dimension,
out_feats=self._out_feats,
num_heads=self._num_heads,
feat_drop=0.6,
attn_drop=0.6,
residual=True,
allow_zero_in_degree=allow_zero_in_degree))
def forward(self, graph, memory, ts):
graph = graph.local_var() # Using local scope for graph
if not self._allow_zero_in_degree:
if(graph.in_degrees() == 0).any():
raise DGLError('There are 0-in-degree nodes in the graph, '
'output for those nodes will be invalid. '
'This is harmful for some applications, '
'causing silent performance regression. '
'Adding self-loop on the input graph by '
'calling `g = dgl.add_self_loop(g)` will resolve '
'the issue. Setting ``allow_zero_in_degree`` '
'to be `True` when constructing this module will '
'suppress the check and let the code run.')
#print("Shape: ",memory.shape,ts.shape)
graph.srcdata.update({'s': memory, 'timestamp': ts})
graph.dstdata.update({'s': memory, 'timestamp': ts})
# Dot product Calculate the attentio weight
graph.apply_edges(self.weight_fn)
# Edge softmax
graph.edata['sa'] = edge_softmax(
graph, graph.edata['a'])/(self._out_feats**0.5)
# Update dst node Here msg_fn include edge feature
graph.update_all(self.msg_fn, fn.sum('attn', 'agg_u'))
rst = graph.dstdata['agg_u']
# Implement skip connection
rst = self.merge(rst.view(-1, self._num_heads *
self._out_feats), graph.dstdata['s'])
graph = graph.local_var()
graph.ndata['timestamp'] = ts
efeat = self.preprocessor(graph).float()
rst = memory
for i in range(self.layers-1):
rst = self.layer_list[i](graph, rst, efeat).flatten(1)
rst = self.layer_list[-1](graph, rst, efeat).mean(1)
return rst
import copy
import torch.nn as nn
import dgl
from modules import MemoryModule, MemoryOperation, TemporalGATConv, MsgLinkPredictor
from modules import MemoryModule, MemoryOperation, MsgLinkPredictor, TemporalTransformerConv, TimeEncode
class TGN(nn.Module):
def __init__(self,
......@@ -10,9 +11,10 @@ class TGN(nn.Module):
temporal_dim,
embedding_dim,
num_heads,
num_nodes, # entire graph
num_nodes,
n_neighbors=10,
memory_updater_type='gru'):
memory_updater_type='gru',
layers=1):
super(TGN, self).__init__()
self.memory_dim = memory_dim
self.edge_feat_dim = edge_feat_dim
......@@ -22,6 +24,9 @@ class TGN(nn.Module):
self.n_neighbors = n_neighbors
self.memory_updater_type = memory_updater_type
self.num_nodes = num_nodes
self.layers = layers
self.temporal_encoder = TimeEncode(self.temporal_dim)
self.memory = MemoryModule(self.num_nodes,
self.memory_dim)
......@@ -29,15 +34,15 @@ class TGN(nn.Module):
self.memory_ops = MemoryOperation(self.memory_updater_type,
self.memory,
self.edge_feat_dim,
self.temporal_dim)
self.temporal_encoder)
self.embedding_attn = TemporalGATConv(self.edge_feat_dim,
self.memory_dim,
self.temporal_dim,
self.embedding_dim,
self.num_heads,
allow_zero_in_degree=True)
self.embedding_attn = TemporalTransformerConv(self.edge_feat_dim,
self.memory_dim,
self.temporal_encoder,
self.embedding_dim,
self.num_heads,
layers=self.layers,
allow_zero_in_degree=True)
self.msg_linkpredictor = MsgLinkPredictor(embedding_dim)
......
......@@ -24,7 +24,7 @@ np.random.seed(2021)
torch.manual_seed(2021)
def train(model, dataloader, sampler, criterion, optimizer, batch_size, fast_mode):
def train(model, dataloader, sampler, criterion, optimizer, args):
model.train()
total_loss = 0
batch_cnt = 0
......@@ -35,13 +35,14 @@ def train(model, dataloader, sampler, criterion, optimizer, batch_size, fast_mod
positive_pair_g, negative_pair_g, blocks)
loss = criterion(pred_pos, torch.ones_like(pred_pos))
loss += criterion(pred_neg, torch.zeros_like(pred_neg))
total_loss += float(loss)*batch_size
retain_graph = True if batch_cnt == 0 and not fast_mode else False
total_loss += float(loss)*args.batch_size
retain_graph = True if batch_cnt == 0 and not args.fast_mode else False
loss.backward(retain_graph=retain_graph)
optimizer.step()
model.detach_memory()
model.update_memory(positive_pair_g)
if fast_mode:
if not args.not_use_memory:
model.update_memory(positive_pair_g)
if args.fast_mode:
sampler.attach_last_update(model.memory.last_update_t)
print("Batch: ", batch_cnt, "Time: ", time.time()-last_t)
last_t = time.time()
......@@ -49,9 +50,9 @@ def train(model, dataloader, sampler, criterion, optimizer, batch_size, fast_mod
return total_loss
def test_val(model, dataloader, sampler, criterion, batch_size, fast_mode):
def test_val(model, dataloader, sampler, criterion, args):
model.eval()
batch_size = batch_size
batch_size = args.batch_size
total_loss = 0
aps, aucs = [], []
batch_cnt = 0
......@@ -65,8 +66,9 @@ def test_val(model, dataloader, sampler, criterion, batch_size, fast_mode):
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat(
[torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
model.update_memory(postive_pair_g)
if fast_mode:
if not args.not_use_memory:
model.update_memory(postive_pair_g)
if args.fast_mode:
sampler.attach_last_update(model.memory.last_update_t)
aps.append(average_precision_score(y_true, y_pred))
aucs.append(roc_auc_score(y_true, y_pred))
......@@ -106,11 +108,14 @@ if __name__ == "__main__":
parser.add_argument("--dataset", type=str, default="wikipedia",
help="dataset selection wikipedia/reddit")
parser.add_argument("--k_hop", type=int, default=1,
help="sampling k-hop neighborhood")
help="sampling k-hop neighborhood")
parser.add_argument("--not_use_memory", action="store_true", default=False,
help="Enable memory for TGN Model disable memory for TGN Model")
args = parser.parse_args()
assert not (args.fast_mode and args.simple_mode), "you can only choose one sampling mode"
assert not (
args.fast_mode and args.simple_mode), "you can only choose one sampling mode"
if args.k_hop != 1:
assert args.simple_mode, "this k-hop parameter only support simple mode"
......@@ -246,7 +251,8 @@ if __name__ == "__main__":
num_heads=args.num_heads,
num_nodes=num_node,
n_neighbors=args.n_neighbors,
memory_updater_type=args.memory_updater)
memory_updater_type=args.memory_updater,
layers=args.k_hop)
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
......@@ -257,21 +263,22 @@ if __name__ == "__main__":
try:
for i in range(args.epochs):
train_loss = train(model, train_dataloader, sampler,
criterion, optimizer, args.batch_size, args.fast_mode)
criterion, optimizer, args)
val_ap, val_auc = test_val(
model, valid_dataloader, sampler, criterion, args.batch_size, args.fast_mode)
model, valid_dataloader, sampler, criterion, args)
memory_checkpoint = model.store_memory()
if args.fast_mode:
new_node_sampler.sync(sampler)
test_ap, test_auc = test_val(
model, test_dataloader, sampler, criterion, args.batch_size, args.fast_mode)
model, test_dataloader, sampler, criterion, args)
model.restore_memory(memory_checkpoint)
if args.fast_mode:
sample_nn = new_node_sampler
else:
sample_nn = sampler
nn_test_ap, nn_test_auc = test_val(
model, test_new_node_dataloader, sample_nn, criterion, args.batch_size, args.fast_mode)
model, test_new_node_dataloader, sample_nn, criterion, args)
log_content = []
log_content.append("Epoch: {}; Training Loss: {} | Validation AP: {:.3f} AUC: {:.3f}\n".format(
i, train_loss, val_ap, val_auc))
......@@ -290,5 +297,4 @@ if __name__ == "__main__":
error_content = "Training Interreputed!"
f.writelines(error_content)
f.close()
# exit(-1)
print("========Training is Done========")
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment