Unverified Commit be8763fa authored by Hongzhi (Steve), Chen and committed by GitHub

[Misc] Black auto fix. (#4679)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent eae6ce2a
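The commit message describes an automated Black pass over the example scripts. As a point of reference only, the short sketch below shows how one of the wraps in this diff can be reproduced with Black's Python API (black.format_str). The 79-column limit is an assumption inferred from the wrapping visible in the diff (the repository's real line length is not shown here, and the import re-grouping further down looks like isort's work rather than Black's).

import black  # assumes the "black" package is installed

# One of the lines touched in the SGC example, before formatting (hypothetical input).
src = (
    "model = SGConv(in_feats, n_classes, k=2, cached=True,"
    " bias=True, norm=normalize)\n"
)

# Assumed settings: Black defaults except for a 79-column line length.
print(black.format_str(src, mode=black.Mode(line_length=79)))
# Prints the same three-line call that appears in the SGC diff below:
#     model = SGConv(
#         in_feats, n_classes, k=2, cached=True, bias=True, norm=normalize
#     )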
@@ -5,28 +5,35 @@ Paper: https://arxiv.org/abs/1902.07153
Code: https://github.com/Tiiiger/SGC
SGC implementation in DGL.
"""
import argparse
import math
import time

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl.function as fn
from dgl import DGLGraph
from dgl.data import load_data, register_data_args
from dgl.nn.pytorch.conv import SGConv


def normalize(h):
    return (h - h.mean(0)) / h.std(0)


def evaluate(model, features, graph, labels, mask):
    model.eval()
    with torch.no_grad():
        logits = model(graph, features)[mask]  # only compute the evaluation set
        labels = labels[mask]
        _, indices = torch.max(logits, dim=1)
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)


def main(args):
    # load and preprocess dataset
    args.dataset = "reddit-self-loop"
@@ -38,24 +45,29 @@ def main(args):
        cuda = True
        g = g.int().to(args.gpu)

    features = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]
    in_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = g.number_of_edges()
    print(
        """----Data statistics------'
      #Edges %d
      #Classes %d
      #Train samples %d
      #Val samples %d
      #Test samples %d"""
        % (
            n_edges,
            n_classes,
            g.ndata["train_mask"].int().sum().item(),
            g.ndata["val_mask"].int().sum().item(),
            g.ndata["test_mask"].int().sum().item(),
        )
    )

    # graph preprocess and calculate normalization factor
    n_edges = g.number_of_edges()
@@ -63,10 +75,12 @@ def main(args):
    degs = g.in_degrees().float()
    norm = torch.pow(degs, -0.5)
    norm[torch.isinf(norm)] = 0
    g.ndata["norm"] = norm.unsqueeze(1)

    # create SGC model
    model = SGConv(
        in_feats, n_classes, k=2, cached=True, bias=True, norm=normalize
    )

    if args.gpu >= 0:
        model = model.cuda()
@@ -90,15 +104,16 @@ def main(args):
    print("Test Accuracy {:.4f}".format(acc))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="SGC")
    register_data_args(parser)
    parser.add_argument("--gpu", type=int, default=-1, help="gpu")
    parser.add_argument(
        "--bias", action="store_true", default=False, help="flag to use bias"
    )
    parser.add_argument(
        "--n-epochs", type=int, default=2, help="number of training epochs"
    )
    args = parser.parse_args()
    print(args)
...
import numpy as np
import torch

import dgl
@@ -7,6 +8,7 @@ def load_dataset(name):
    dataset = name.lower()
    if dataset == "amazon":
        from ogb.nodeproppred.dataset_dgl import DglNodePropPredDataset

        dataset = DglNodePropPredDataset(name="ogbn-products")
        splitted_idx = dataset.get_idx_split()
        train_nid = splitted_idx["train"]
@@ -14,26 +16,28 @@ def load_dataset(name):
        test_nid = splitted_idx["test"]
        g, labels = dataset[0]
        n_classes = int(labels.max() - labels.min() + 1)
        g.ndata["label"] = labels.squeeze()
        g.ndata["feat"] = g.ndata["feat"].float()
    elif dataset in ["reddit", "cora"]:
        if dataset == "reddit":
            from dgl.data import RedditDataset

            data = RedditDataset(self_loop=True)
            g = data[0]
        else:
            from dgl.data import CitationGraphDataset

            data = CitationGraphDataset("cora")
            g = data[0]
        n_classes = data.num_labels
        train_mask = g.ndata["train_mask"]
        val_mask = g.ndata["val_mask"]
        test_mask = g.ndata["test_mask"]
        train_nid = torch.LongTensor(train_mask.nonzero().squeeze())
        val_nid = torch.LongTensor(val_mask.nonzero().squeeze())
        test_nid = torch.LongTensor(test_mask.nonzero().squeeze())
    else:
        print("Dataset {} is not supported".format(name))
        assert 0
    return g, n_classes, train_nid, val_nid, test_nid
import argparse
import os
import time

import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl
import dgl.function as fn
from dataset import load_dataset


class FeedForwardNet(nn.Module):
@@ -48,10 +50,12 @@ class Model(nn.Module):
        self.inception_ffs = nn.ModuleList()
        for hop in range(R + 1):
            self.inception_ffs.append(
                FeedForwardNet(in_feats, hidden, hidden, n_layers, dropout)
            )
        # self.linear = nn.Linear(hidden * (R + 1), out_feats)
        self.project = FeedForwardNet(
            (R + 1) * hidden, hidden, out_feats, n_layers, dropout
        )

    def forward(self, feats):
        hidden = []
@@ -84,8 +88,10 @@ def preprocess(g, features, args):
    g.edata["weight"] = calc_weight(g)
    g.ndata["feat_0"] = features
    for hop in range(1, args.R + 1):
        g.update_all(
            fn.u_mul_e(f"feat_{hop-1}", "weight", "msg"),
            fn.sum("msg", f"feat_{hop}"),
        )
    res = []
    for hop in range(args.R + 1):
        res.append(g.ndata.pop(f"feat_{hop}"))
@@ -96,17 +102,26 @@ def prepare_data(device, args):
    data = load_dataset(args.dataset)
    g, n_classes, train_nid, val_nid, test_nid = data
    g = g.to(device)
    in_feats = g.ndata["feat"].shape[1]
    feats = preprocess(g, g.ndata["feat"], args)
    labels = g.ndata["label"]
    # move to device
    train_nid = train_nid.to(device)
    val_nid = val_nid.to(device)
    test_nid = test_nid.to(device)
    train_feats = [x[train_nid] for x in feats]
    train_labels = labels[train_nid]
    return (
        feats,
        labels,
        train_feats,
        train_labels,
        in_feats,
        n_classes,
        train_nid,
        val_nid,
        test_nid,
    )


def evaluate(epoch, args, model, feats, labels, train, val, test):
@@ -121,7 +136,7 @@ def evaluate(epoch, args, model, feats, labels, train, val, test):
    for i in range(n_batch):
        batch_start = i * batch_size
        batch_end = min((i + 1) * batch_size, num_nodes)
        batch_feats = [feat[batch_start:batch_end] for feat in feats]
        pred.append(model(batch_feats))
    pred = torch.cat(pred)
@@ -140,15 +155,31 @@ def main(args):
    device = "cuda:{}".format(args.gpu)
    data = prepare_data(device, args)
    (
        feats,
        labels,
        train_feats,
        train_labels,
        in_size,
        num_classes,
        train_nid,
        val_nid,
        test_nid,
    ) = data

    model = Model(
        in_size,
        args.num_hidden,
        num_classes,
        args.R,
        args.ff_layer,
        args.dropout,
    )
    model = model.to(device)
    loss_fcn = nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )

    best_epoch = 0
    best_val = 0
@@ -164,38 +195,47 @@ def main(args):
        if epoch % args.eval_every == 0:
            model.eval()
            acc = evaluate(
                epoch, args, model, feats, labels, train_nid, val_nid, test_nid
            )
            end = time.time()
            log = "Epoch {}, Times(s): {:.4f}".format(epoch, end - start)
            log += ", Accuracy: Train {:.4f}, Val {:.4f}, Test {:.4f}".format(
                *acc
            )
            print(log)
            if acc[1] > best_val:
                best_val = acc[1]
                best_epoch = epoch
                best_test = acc[2]

    print(
        "Best Epoch {}, Val {:.4f}, Test {:.4f}".format(
            best_epoch, best_val, best_test
        )
    )


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="SIGN")
    parser.add_argument("--num-epochs", type=int, default=1000)
    parser.add_argument("--num-hidden", type=int, default=256)
    parser.add_argument("--R", type=int, default=3, help="number of hops")
    parser.add_argument("--lr", type=float, default=0.003)
    parser.add_argument("--dataset", type=str, default="amazon")
    parser.add_argument("--dropout", type=float, default=0.5)
    parser.add_argument("--gpu", type=int, default=0)
    parser.add_argument("--weight-decay", type=float, default=0)
    parser.add_argument("--eval-every", type=int, default=50)
    parser.add_argument(
        "--eval-batch-size",
        type=int,
        default=250000,
        help="evaluation batch size, -1 for full batch",
    )
    parser.add_argument(
        "--ff-layer", type=int, default=2, help="number of feed-forward layers"
    )
    args = parser.parse_args()
    print(args)
...
import numpy as np
import pandas as pd
import torch


def load_data(file_path, len_train, len_val):
    df = pd.read_csv(file_path, header=None).values.astype(float)
    train = df[:len_train]
    val = df[len_train : len_train + len_val]
    test = df[len_train + len_val :]
    return train, val, test
@@ -16,15 +15,15 @@ def data_transform(data, n_his, n_pred, device):
    # produce data slices for training and testing
    n_route = data.shape[1]
    l = len(data)
    num = l - n_his - n_pred
    x = np.zeros([num, 1, n_his, n_route])
    y = np.zeros([num, n_route])
    cnt = 0
    for i in range(l - n_his - n_pred):
        head = i
        tail = i + n_his
        x[cnt, :, :, :] = data[head:tail].reshape(1, n_his, n_route)
        y[cnt] = data[tail + n_pred - 1]
        cnt += 1
    return torch.Tensor(x).to(device), torch.Tensor(y).to(device)
import argparse
import random

import numpy as np
import pandas as pd
import scipy.sparse as sp
import torch
import torch.nn as nn
from load_data import *
from model import *
from sensors2graph import *
from sklearn.preprocessing import StandardScaler
from utils import *

import dgl

parser = argparse.ArgumentParser(description="STGCN_WAVE")
parser.add_argument("--lr", default=0.001, type=float, help="learning rate")
parser.add_argument("--disablecuda", action="store_true", help="Disable CUDA")
parser.add_argument(
    "--batch_size",
    type=int,
    default=50,
    help="batch size for training and validation (default: 50)",
)
parser.add_argument(
    "--epochs", type=int, default=50, help="epochs for training (default: 50)"
)
parser.add_argument(
    "--num_layers", type=int, default=9, help="number of layers"
)
parser.add_argument("--window", type=int, default=144, help="window length")
parser.add_argument(
    "--sensorsfilepath",
    type=str,
    default="./data/sensor_graph/graph_sensor_ids.txt",
    help="sensors file path",
)
parser.add_argument(
    "--disfilepath",
    type=str,
    default="./data/sensor_graph/distances_la_2012.csv",
    help="distance file path",
)
parser.add_argument(
    "--tsfilepath", type=str, default="./data/metr-la.h5", help="ts file path"
)
parser.add_argument(
    "--savemodelpath",
    type=str,
    default="stgcnwavemodel.pt",
    help="save model path",
)
parser.add_argument(
    "--pred_len",
    type=int,
    default=5,
    help="how many steps away we want to predict",
)
parser.add_argument(
    "--control_str",
    type=str,
    default="TNTSTNTST",
    help="model strcture controller, T: Temporal Layer, S: Spatio Layer, N: Norm Layer",
)
parser.add_argument(
    "--channels",
    type=int,
    nargs="+",
    default=[1, 16, 32, 64, 32, 128],
    help="model strcture controller, T: Temporal Layer, S: Spatio Layer, N: Norm Layer",
)
args = parser.parse_args()

device = (
    torch.device("cuda")
    if torch.cuda.is_available() and not args.disablecuda
    else torch.device("cpu")
)

with open(args.sensorsfilepath) as f:
    sensor_ids = f.read().strip().split(",")
distance_df = pd.read_csv(args.disfilepath, dtype={"from": "str", "to": "str"})

adj_mx = get_adjacency_matrix(distance_df, sensor_ids)
sp_mx = sp.coo_matrix(adj_mx)
@@ -51,7 +99,6 @@ n_his = args.window
save_path = args.savemodelpath
n_pred = args.pred_len
n_route = num_nodes
blocks = args.channels
@@ -67,9 +114,9 @@ lr = args.lr
W = adj_mx
len_val = round(num_samples * 0.1)
len_train = round(num_samples * 0.7)
train = df[:len_train]
val = df[len_train : len_train + len_val]
test = df[len_train + len_val :]

scaler = StandardScaler()
train = scaler.fit_transform(train)
@@ -91,7 +138,9 @@ test_iter = torch.utils.data.DataLoader(test_data, batch_size)

loss = nn.MSELoss()
G = G.to(device)
model = STGCN_WAVE(
    blocks, n_his, n_route, G, drop_prob, num_layers, device, args.control_str
).to(device)
optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.7)
@@ -113,10 +162,19 @@ for epoch in range(1, epochs + 1):
    if val_loss < min_val_loss:
        min_val_loss = val_loss
        torch.save(model.state_dict(), save_path)
    print(
        "epoch",
        epoch,
        ", train loss:",
        l_sum / n,
        ", validation loss:",
        val_loss,
    )

best_model = STGCN_WAVE(
    blocks, n_his, n_route, G, drop_prob, num_layers, device, args.control_str
).to(device)
best_model.load_state_dict(torch.load(save_path))
...
import math

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init

from dgl.nn.pytorch import GraphConv
from dgl.nn.pytorch.conv import ChebConv


class TemporalConvLayer(nn.Module):
    """Temporal convolution layer.
    arguments
    ---------
    c_in : int
@@ -17,27 +20,29 @@ class TemporalConvLayer(nn.Module):
        The number of output channels (features)
    dia : int
        The dilation size
    """

    def __init__(self, c_in, c_out, dia=1):
        super(TemporalConvLayer, self).__init__()
        self.c_out = c_out
        self.c_in = c_in
        self.conv = nn.Conv2d(
            c_in, c_out, (2, 1), 1, dilation=dia, padding=(0, 0)
        )

    def forward(self, x):
        return torch.relu(self.conv(x))


class SpatioConvLayer(nn.Module):
    def __init__(self, c, Lk):  # c : hidden dimension Lk: graph matrix
        super(SpatioConvLayer, self).__init__()
        self.g = Lk
        self.gc = GraphConv(c, c, activation=F.relu)
        # self.gc = ChebConv(c, c, 3)

    def init(self):
        stdv = 1.0 / math.sqrt(self.W.weight.size(1))
        self.W.weight.data.uniform_(-stdv, stdv)

    def forward(self, x):
@@ -48,6 +53,7 @@ class SpatioConvLayer(nn.Module):
        output = output.transpose(0, 3)
        return torch.relu(output)


class FullyConvLayer(nn.Module):
    def __init__(self, c):
        super(FullyConvLayer, self).__init__()
@@ -56,12 +62,13 @@ class FullyConvLayer(nn.Module):
    def forward(self, x):
        return self.conv(x)


class OutputLayer(nn.Module):
    def __init__(self, c, T, n):
        super(OutputLayer, self).__init__()
        self.tconv1 = nn.Conv2d(c, c, (T, 1), 1, dilation=1, padding=(0, 0))
        self.ln = nn.LayerNorm([n, c])
        self.tconv2 = nn.Conv2d(c, c, (1, 1), 1, dilation=1, padding=(0, 0))
        self.fc = FullyConvLayer(c)

    def forward(self, x):
@@ -70,32 +77,38 @@ class OutputLayer(nn.Module):
        x_t2 = self.tconv2(x_ln)
        return self.fc(x_t2)


class STGCN_WAVE(nn.Module):
    def __init__(
        self, c, T, n, Lk, p, num_layers, device, control_str="TNTSTNTST"
    ):
        super(STGCN_WAVE, self).__init__()
        self.control_str = control_str  # model structure controller
        self.num_layers = len(control_str)
        self.layers = nn.ModuleList([])
        cnt = 0
        diapower = 0
        for i in range(self.num_layers):
            i_layer = control_str[i]
            if i_layer == "T":  # Temporal Layer
                self.layers.append(
                    TemporalConvLayer(c[cnt], c[cnt + 1], dia=2**diapower)
                )
                diapower += 1
                cnt += 1
            if i_layer == "S":  # Spatio Layer
                self.layers.append(SpatioConvLayer(c[cnt], Lk))
            if i_layer == "N":  # Norm Layer
                self.layers.append(nn.LayerNorm([n, c[cnt]]))
        self.output = OutputLayer(c[cnt], T + 1 - 2 ** (diapower), n)
        for layer in self.layers:
            layer = layer.to(device)

    def forward(self, x):
        for i in range(self.num_layers):
            i_layer = self.control_str[i]
            if i_layer == "N":
                x = self.layers[i](x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
            else:
                x = self.layers[i](x)
        return self.output(x)
import numpy as np


def get_adjacency_matrix(distance_df, sensor_ids, normalized_k=0.1):
    """
    :param distance_df: data frame with three columns: [from, to, distance].
@@ -28,4 +30,4 @@ def get_adjacency_matrix(distance_df, sensor_ids, normalized_k=0.1):
    # Sets entries that lower than a threshold, i.e., k, to zero for sparsity.
    adj_mx[adj_mx < normalized_k] = 0
    return adj_mx
\ No newline at end of file
import numpy as np
import torch


def evaluate_model(model, loss, data_iter):
@@ -21,11 +20,13 @@ def evaluate_metric(model, data_iter, scaler):
        mae, mape, mse = [], [], []
        for x, y in data_iter:
            y = scaler.inverse_transform(y.cpu().numpy()).reshape(-1)
            y_pred = scaler.inverse_transform(
                model(x).view(len(x), -1).cpu().numpy()
            ).reshape(-1)
            d = np.abs(y - y_pred)
            mae += d.tolist()
            mape += (d / y).tolist()
            mse += (d**2).tolist()
        MAE = np.array(mae).mean()
        MAPE = np.array(mape).mean()
        RMSE = np.sqrt(np.array(mse).mean())
...
@@ -6,17 +6,14 @@ References:
"""
import torch
import torch.nn as nn

from dgl.nn.pytorch.conv import TAGConv


class TAGCN(nn.Module):
    def __init__(
        self, g, in_feats, n_hidden, n_classes, n_layers, activation, dropout
    ):
        super(TAGCN, self).__init__()
        self.g = g
        self.layers = nn.ModuleList()
@@ -24,9 +21,11 @@ class TAGCN(nn.Module):
        self.layers.append(TAGConv(in_feats, n_hidden, activation=activation))
        # hidden layers
        for i in range(n_layers - 1):
            self.layers.append(
                TAGConv(n_hidden, n_hidden, activation=activation)
            )
        # output layer
        self.layers.append(TAGConv(n_hidden, n_classes))  # activation=None
        self.dropout = nn.Dropout(p=dropout)

    def forward(self, features):
...
import argparse
import time

import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from dgl import DGLGraph
from dgl.data import load_data, register_data_args
from tagcn import TAGCN


def evaluate(model, features, labels, mask):
    model.eval()
@@ -19,6 +22,7 @@ def evaluate(model, features, labels, mask):
        correct = torch.sum(indices == labels)
        return correct.item() * 1.0 / len(labels)


def main(args):
    # load and preprocess dataset
    data = load_data(args)
@@ -28,24 +32,29 @@ def main(args):
    else:
        cuda = True
        g = g.to(args.gpu)

    features = g.ndata["feat"]
    labels = g.ndata["label"]
    train_mask = g.ndata["train_mask"]
    val_mask = g.ndata["val_mask"]
    test_mask = g.ndata["test_mask"]
    in_feats = features.shape[1]
    n_classes = data.num_labels
    n_edges = g.number_of_edges()
    print(
        """----Data statistics------'
      #Edges %d
      #Classes %d
      #Train samples %d
      #Val samples %d
      #Test samples %d"""
        % (
            n_edges,
            n_classes,
            train_mask.int().sum().item(),
            val_mask.int().sum().item(),
            test_mask.int().sum().item(),
        )
    )

    # graph preprocess and calculate normalization factor
    # add self loop
@@ -54,22 +63,24 @@ def main(args):
    n_edges = g.number_of_edges()

    # create TAGCN model
    model = TAGCN(
        g,
        in_feats,
        args.n_hidden,
        n_classes,
        args.n_layers,
        F.relu,
        args.dropout,
    )

    if cuda:
        model.cuda()
    loss_fcn = torch.nn.CrossEntropyLoss()

    # use optimizer
    optimizer = torch.optim.Adam(
        model.parameters(), lr=args.lr, weight_decay=args.weight_decay
    )

    # initialize graph
    dur = []
@@ -89,34 +100,47 @@ def main(args):
            dur.append(time.time() - t0)

        acc = evaluate(model, features, labels, val_mask)
        print(
            "Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
            "ETputs(KTEPS) {:.2f}".format(
                epoch,
                np.mean(dur),
                loss.item(),
                acc,
                n_edges / np.mean(dur) / 1000,
            )
        )

    print()
    acc = evaluate(model, features, labels, test_mask)
    print("Test Accuracy {:.4f}".format(acc))


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="TAGCN")
    register_data_args(parser)
    parser.add_argument(
        "--dropout", type=float, default=0.5, help="dropout probability"
    )
    parser.add_argument("--gpu", type=int, default=-1, help="gpu")
    parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
    parser.add_argument(
        "--n-epochs", type=int, default=200, help="number of training epochs"
    )
    parser.add_argument(
        "--n-hidden", type=int, default=16, help="number of hidden tagcn units"
    )
    parser.add_argument(
        "--n-layers", type=int, default=1, help="number of hidden tagcn layers"
    )
    parser.add_argument(
        "--weight-decay", type=float, default=5e-4, help="Weight for L2 loss"
    )
    parser.add_argument(
        "--self-loop",
        action="store_true",
        help="graph self-loop (default=False)",
    )
    parser.set_defaults(self_loop=False)
    args = parser.parse_args()
    print(args)
...
import os
import ssl

import numpy as np
import pandas as pd
import torch
from six.moves import urllib

import dgl

# === Below data preprocessing code are based on
@@ -13,6 +13,7 @@ import dgl
# Preprocess the raw data split each features
def preprocess(data_name):
    u_list, i_list, ts_list, label_list = [], [], [], []
    feat_l = []
@@ -21,7 +22,7 @@ def preprocess(data_name):
    with open(data_name) as f:
        s = next(f)
        for idx, line in enumerate(f):
            e = line.strip().split(",")
            u = int(e[0])
            i = int(e[1])
@@ -37,18 +38,23 @@ def preprocess(data_name):
            idx_list.append(idx)
            feat_l.append(feat)
    return pd.DataFrame(
        {
            "u": u_list,
            "i": i_list,
            "ts": ts_list,
            "label": label_list,
            "idx": idx_list,
        }
    ), np.array(feat_l)


# Re index nodes for DGL convience
def reindex(df, bipartite=True):
    new_df = df.copy()
    if bipartite:
        assert df.u.max() - df.u.min() + 1 == len(df.u.unique())
        assert df.i.max() - df.i.min() + 1 == len(df.i.unique())
        upper_u = df.u.max() + 1
        new_i = df.i + upper_u
@@ -64,12 +70,13 @@ def reindex(df, bipartite=True):
    return new_df


# Save edge list, features in different file for data easy process data
def run(data_name, bipartite=True):
    PATH = "./data/{}.csv".format(data_name)
    OUT_DF = "./data/ml_{}.csv".format(data_name)
    OUT_FEAT = "./data/ml_{}.npy".format(data_name)
    OUT_NODE_FEAT = "./data/ml_{}_node.npy".format(data_name)
    df, feat = preprocess(PATH)
    new_df = reindex(df, bipartite)
@@ -84,18 +91,20 @@ def run(data_name, bipartite=True):
    np.save(OUT_FEAT, feat)
    np.save(OUT_NODE_FEAT, rand_feat)


# === code from twitter-research-tgn end ===

# If you have new dataset follow by same format in Jodie,
# you can directly use name to retrieve dataset
def TemporalDataset(dataset):
    if not os.path.exists("./data/{}.bin".format(dataset)):
        if not os.path.exists("./data/{}.csv".format(dataset)):
            if not os.path.exists("./data"):
                os.mkdir("./data")
            url = "https://snap.stanford.edu/jodie/{}.csv".format(dataset)
            print("Start Downloading File....")
            context = ssl._create_unverified_context()
            data = urllib.request.urlopen(url, context=context)
@@ -104,27 +113,28 @@ def TemporalDataset(dataset):
        print("Start Process Data ...")
        run(dataset)
        raw_connection = pd.read_csv("./data/ml_{}.csv".format(dataset))
        raw_feature = np.load("./data/ml_{}.npy".format(dataset))
        # -1 for re-index the node
        src = raw_connection["u"].to_numpy() - 1
        dst = raw_connection["i"].to_numpy() - 1
        # Create directed graph
        g = dgl.graph((src, dst))
        g.edata["timestamp"] = torch.from_numpy(raw_connection["ts"].to_numpy())
        g.edata["label"] = torch.from_numpy(raw_connection["label"].to_numpy())
        g.edata["feats"] = torch.from_numpy(raw_feature[1:, :]).float()
        dgl.save_graphs("./data/{}.bin".format(dataset), [g])
    else:
        print("Data is exist directly loaded.")
    gs, _ = dgl.load_graphs("./data/{}.bin".format(dataset))
    g = gs[0]
    return g


def TemporalWikipediaDataset():
    # Download the dataset
    return TemporalDataset("wikipedia")


def TemporalRedditDataset():
    return TemporalDataset("reddit")
...@@ -4,9 +4,9 @@ import torch.nn as nn ...@@ -4,9 +4,9 @@ import torch.nn as nn
import torch.nn.functional as F import torch.nn.functional as F
import dgl import dgl
import dgl.function as fn
from dgl.base import DGLError from dgl.base import DGLError
from dgl.ops import edge_softmax from dgl.ops import edge_softmax
import dgl.function as fn
class Identity(nn.Module): class Identity(nn.Module):
...@@ -60,22 +60,22 @@ class MsgLinkPredictor(nn.Module): ...@@ -60,22 +60,22 @@ class MsgLinkPredictor(nn.Module):
self.out_fc = nn.Linear(emb_dim, 1) self.out_fc = nn.Linear(emb_dim, 1)
def link_pred(self, edges): def link_pred(self, edges):
src_hid = self.src_fc(edges.src['embedding']) src_hid = self.src_fc(edges.src["embedding"])
dst_hid = self.dst_fc(edges.dst['embedding']) dst_hid = self.dst_fc(edges.dst["embedding"])
score = F.relu(src_hid+dst_hid) score = F.relu(src_hid + dst_hid)
score = self.out_fc(score) score = self.out_fc(score)
return {'score': score} return {"score": score}
def forward(self, x, pos_g, neg_g): def forward(self, x, pos_g, neg_g):
# Local Scope? # Local Scope?
pos_g.ndata['embedding'] = x pos_g.ndata["embedding"] = x
neg_g.ndata['embedding'] = x neg_g.ndata["embedding"] = x
pos_g.apply_edges(self.link_pred) pos_g.apply_edges(self.link_pred)
neg_g.apply_edges(self.link_pred) neg_g.apply_edges(self.link_pred)
pos_escore = pos_g.edata['score'] pos_escore = pos_g.edata["score"]
neg_escore = neg_g.edata['score'] neg_escore = neg_g.edata["score"]
return pos_escore, neg_escore return pos_escore, neg_escore
...@@ -84,12 +84,12 @@ class TimeEncode(nn.Module): ...@@ -84,12 +84,12 @@ class TimeEncode(nn.Module):
time different between two event time different between two event
..math:: ..math::
\Phi(t) = [\cos(\omega_0t+\psi_0),\cos(\omega_1t+\psi_1),...,\cos(\omega_nt+\psi_n)] \Phi(t) = [\cos(\omega_0t+\psi_0),\cos(\omega_1t+\psi_1),...,\cos(\omega_nt+\psi_n)]
Parameter Parameter
---------- ----------
dimension : int dimension : int
Length of the fourier series. The longer it is , Length of the fourier series. The longer it is ,
the more timescale information it can capture the more timescale information it can capture
Example Example
...@@ -106,8 +106,11 @@ class TimeEncode(nn.Module): ...@@ -106,8 +106,11 @@ class TimeEncode(nn.Module):
self.dimension = dimension self.dimension = dimension
self.w = torch.nn.Linear(1, dimension) self.w = torch.nn.Linear(1, dimension)
self.w.weight = torch.nn.Parameter((torch.from_numpy(1 / 10 ** np.linspace(0, 9, dimension))) self.w.weight = torch.nn.Parameter(
.double().reshape(dimension, -1)) (torch.from_numpy(1 / 10 ** np.linspace(0, 9, dimension)))
.double()
.reshape(dimension, -1)
)
self.w.bias = torch.nn.Parameter(torch.zeros(dimension).double()) self.w.bias = torch.nn.Parameter(torch.zeros(dimension).double())
def forward(self, t): def forward(self, t):
...@@ -132,7 +135,7 @@ class MemoryModule(nn.Module): ...@@ -132,7 +135,7 @@ class MemoryModule(nn.Module):
Example Example
---------- ----------
Please refers to examples/pytorch/tgn/tgn.py; Please refers to examples/pytorch/tgn/tgn.py;
examples/pytorch/tgn/train.py examples/pytorch/tgn/train.py
""" """
...@@ -143,10 +146,13 @@ class MemoryModule(nn.Module): ...@@ -143,10 +146,13 @@ class MemoryModule(nn.Module):
self.reset_memory() self.reset_memory()
def reset_memory(self): def reset_memory(self):
self.last_update_t = nn.Parameter(torch.zeros( self.last_update_t = nn.Parameter(
self.n_node).float(), requires_grad=False) torch.zeros(self.n_node).float(), requires_grad=False
self.memory = nn.Parameter(torch.zeros( )
(self.n_node, self.hidden_dim)).float(), requires_grad=False) self.memory = nn.Parameter(
torch.zeros((self.n_node, self.hidden_dim)).float(),
requires_grad=False,
)
def backup_memory(self): def backup_memory(self):
""" """
...@@ -192,7 +198,7 @@ class MemoryModule(nn.Module): ...@@ -192,7 +198,7 @@ class MemoryModule(nn.Module):
class MemoryOperation(nn.Module): class MemoryOperation(nn.Module):
""" Memory update using message passing manner, update memory based on positive """Memory update using message passing manner, update memory based on positive
pair graph of each batch with recurrent module GRU or RNN pair graph of each batch with recurrent module GRU or RNN
Message function Message function
...@@ -235,41 +241,66 @@ class MemoryOperation(nn.Module): ...@@ -235,41 +241,66 @@ class MemoryOperation(nn.Module):
def __init__(self, updater_type, memory, e_feat_dim, temporal_encoder): def __init__(self, updater_type, memory, e_feat_dim, temporal_encoder):
super(MemoryOperation, self).__init__() super(MemoryOperation, self).__init__()
updater_dict = {'gru': nn.GRUCell, 'rnn': nn.RNNCell} updater_dict = {"gru": nn.GRUCell, "rnn": nn.RNNCell}
self.memory = memory self.memory = memory
memory_dim = self.memory.hidden_dim memory_dim = self.memory.hidden_dim
self.temporal_encoder = temporal_encoder self.temporal_encoder = temporal_encoder
self.message_dim = memory_dim+memory_dim + \ self.message_dim = (
e_feat_dim+self.temporal_encoder.dimension memory_dim
self.updater = updater_dict[updater_type](input_size=self.message_dim, + memory_dim
hidden_size=memory_dim) + e_feat_dim
+ self.temporal_encoder.dimension
)
self.updater = updater_dict[updater_type](
input_size=self.message_dim, hidden_size=memory_dim
)
self.memory = memory self.memory = memory
# Here assume g is a subgraph from each iteration # Here assume g is a subgraph from each iteration
def stick_feat_to_graph(self, g): def stick_feat_to_graph(self, g):
# How can I ensure order of the node ID # How can I ensure order of the node ID
g.ndata['timestamp'] = self.memory.last_update_t[g.ndata[dgl.NID]] g.ndata["timestamp"] = self.memory.last_update_t[g.ndata[dgl.NID]]
g.ndata['memory'] = self.memory.memory[g.ndata[dgl.NID]] g.ndata["memory"] = self.memory.memory[g.ndata[dgl.NID]]
def msg_fn_cat(self, edges): def msg_fn_cat(self, edges):
src_delta_time = edges.data['timestamp'] - edges.src['timestamp'] src_delta_time = edges.data["timestamp"] - edges.src["timestamp"]
time_encode = self.temporal_encoder(src_delta_time.unsqueeze( time_encode = self.temporal_encoder(
dim=1)).view(len(edges.data['timestamp']), -1) src_delta_time.unsqueeze(dim=1)
ret = torch.cat([edges.src['memory'], edges.dst['memory'], ).view(len(edges.data["timestamp"]), -1)
edges.data['feats'], time_encode], dim=1) ret = torch.cat(
return {'message': ret, 'timestamp': edges.data['timestamp']} [
edges.src["memory"],
edges.dst["memory"],
edges.data["feats"],
time_encode,
],
dim=1,
)
return {"message": ret, "timestamp": edges.data["timestamp"]}
def agg_last(self, nodes): def agg_last(self, nodes):
timestamp, latest_idx = torch.max(nodes.mailbox['timestamp'], dim=1) timestamp, latest_idx = torch.max(nodes.mailbox["timestamp"], dim=1)
ret = nodes.mailbox['message'].gather(1, latest_idx.repeat( ret = (
self.message_dim).view(-1, 1, self.message_dim)).view(-1, self.message_dim) nodes.mailbox["message"]
return {'message_bar': ret.reshape(-1, self.message_dim), 'timestamp': timestamp} .gather(
1,
latest_idx.repeat(self.message_dim).view(
-1, 1, self.message_dim
),
)
.view(-1, self.message_dim)
)
return {
"message_bar": ret.reshape(-1, self.message_dim),
"timestamp": timestamp,
}
def update_memory(self, nodes): def update_memory(self, nodes):
# It should pass the feature through RNN # It should pass the feature through RNN
ret = self.updater( ret = self.updater(
nodes.data['message_bar'].float(), nodes.data['memory'].float()) nodes.data["message_bar"].float(), nodes.data["memory"].float()
return {'memory': ret} )
return {"memory": ret}
def forward(self, g): def forward(self, g):
self.stick_feat_to_graph(g) self.stick_feat_to_graph(g)
...@@ -278,7 +309,7 @@ class MemoryOperation(nn.Module): ...@@ -278,7 +309,7 @@ class MemoryOperation(nn.Module):
class EdgeGATConv(nn.Module): class EdgeGATConv(nn.Module):
'''Edge Graph attention compute the graph attention from node and edge feature then aggregate both node and """Edge Graph attention compute the graph attention from node and edge feature then aggregate both node and
edge feature. edge feature.
Parameter Parameter
...@@ -314,19 +345,21 @@ class EdgeGATConv(nn.Module): ...@@ -314,19 +345,21 @@ class EdgeGATConv(nn.Module):
0-in-degree nodes in input graph. By setting ``True``, it will suppress the check 0-in-degree nodes in input graph. By setting ``True``, it will suppress the check
and let the users handle it by themselves. Defaults: ``False``. and let the users handle it by themselves. Defaults: ``False``.
''' """
def __init__(self, def __init__(
node_feats, self,
edge_feats, node_feats,
out_feats, edge_feats,
num_heads, out_feats,
feat_drop=0., num_heads,
attn_drop=0., feat_drop=0.0,
negative_slope=0.2, attn_drop=0.0,
residual=False, negative_slope=0.2,
activation=None, residual=False,
allow_zero_in_degree=False): activation=None,
allow_zero_in_degree=False,
):
super(EdgeGATConv, self).__init__() super(EdgeGATConv, self).__init__()
self._num_heads = num_heads self._num_heads = num_heads
self._node_feats = node_feats self._node_feats = node_feats
...@@ -334,15 +367,20 @@ class EdgeGATConv(nn.Module): ...@@ -334,15 +367,20 @@ class EdgeGATConv(nn.Module):
self._out_feats = out_feats self._out_feats = out_feats
self._allow_zero_in_degree = allow_zero_in_degree self._allow_zero_in_degree = allow_zero_in_degree
self.fc_node = nn.Linear( self.fc_node = nn.Linear(
self._node_feats, self._out_feats*self._num_heads) self._node_feats, self._out_feats * self._num_heads
)
self.fc_edge = nn.Linear( self.fc_edge = nn.Linear(
self._edge_feats, self._out_feats*self._num_heads) self._edge_feats, self._out_feats * self._num_heads
self.attn_l = nn.Parameter(torch.FloatTensor( )
size=(1, self._num_heads, self._out_feats))) self.attn_l = nn.Parameter(
self.attn_r = nn.Parameter(torch.FloatTensor( torch.FloatTensor(size=(1, self._num_heads, self._out_feats))
size=(1, self._num_heads, self._out_feats))) )
self.attn_e = nn.Parameter(torch.FloatTensor( self.attn_r = nn.Parameter(
size=(1, self._num_heads, self._out_feats))) torch.FloatTensor(size=(1, self._num_heads, self._out_feats))
)
self.attn_e = nn.Parameter(
torch.FloatTensor(size=(1, self._num_heads, self._out_feats))
)
self.feat_drop = nn.Dropout(feat_drop) self.feat_drop = nn.Dropout(feat_drop)
self.attn_drop = nn.Dropout(attn_drop) self.attn_drop = nn.Dropout(attn_drop)
self.leaky_relu = nn.LeakyReLU(negative_slope) self.leaky_relu = nn.LeakyReLU(negative_slope)
...@@ -350,14 +388,17 @@ class EdgeGATConv(nn.Module): ...@@ -350,14 +388,17 @@ class EdgeGATConv(nn.Module):
if residual: if residual:
if self._node_feats != self._out_feats: if self._node_feats != self._out_feats:
self.res_fc = nn.Linear( self.res_fc = nn.Linear(
self._node_feats, self._out_feats*self._num_heads, bias=False) self._node_feats,
self._out_feats * self._num_heads,
bias=False,
)
else: else:
self.res_fc = Identity() self.res_fc = Identity()
self.reset_parameters() self.reset_parameters()
self.activation = activation self.activation = activation
def reset_parameters(self): def reset_parameters(self):
gain = nn.init.calculate_gain('relu') gain = nn.init.calculate_gain("relu")
nn.init.xavier_normal_(self.fc_node.weight, gain=gain) nn.init.xavier_normal_(self.fc_node.weight, gain=gain)
nn.init.xavier_normal_(self.fc_edge.weight, gain=gain) nn.init.xavier_normal_(self.fc_edge.weight, gain=gain)
nn.init.xavier_normal_(self.attn_l, gain=gain) nn.init.xavier_normal_(self.attn_l, gain=gain)
...@@ -367,62 +408,69 @@ class EdgeGATConv(nn.Module): ...@@ -367,62 +408,69 @@ class EdgeGATConv(nn.Module):
nn.init.xavier_normal_(self.res_fc.weight, gain=gain) nn.init.xavier_normal_(self.res_fc.weight, gain=gain)
def msg_fn(self, edges): def msg_fn(self, edges):
ret = edges.data['a'].view(-1, self._num_heads, ret = (
1)*edges.data['el_prime'] edges.data["a"].view(-1, self._num_heads, 1)
return {'m': ret} * edges.data["el_prime"]
)
return {"m": ret}
def forward(self, graph, nfeat, efeat, get_attention=False): def forward(self, graph, nfeat, efeat, get_attention=False):
with graph.local_scope(): with graph.local_scope():
if not self._allow_zero_in_degree: if not self._allow_zero_in_degree:
if (graph.in_degrees() == 0).any(): if (graph.in_degrees() == 0).any():
raise DGLError('There are 0-in-degree nodes in the graph, ' raise DGLError(
'output for those nodes will be invalid. ' "There are 0-in-degree nodes in the graph, "
'This is harmful for some applications, ' "output for those nodes will be invalid. "
'causing silent performance regression. ' "This is harmful for some applications, "
'Adding self-loop on the input graph by ' "causing silent performance regression. "
'calling `g = dgl.add_self_loop(g)` will resolve ' "Adding self-loop on the input graph by "
'the issue. Setting ``allow_zero_in_degree`` ' "calling `g = dgl.add_self_loop(g)` will resolve "
'to be `True` when constructing this module will ' "the issue. Setting ``allow_zero_in_degree`` "
'suppress the check and let the code run.') "to be `True` when constructing this module will "
"suppress the check and let the code run."
)
nfeat = self.feat_drop(nfeat) nfeat = self.feat_drop(nfeat)
efeat = self.feat_drop(efeat) efeat = self.feat_drop(efeat)
node_feat = self.fc_node( node_feat = self.fc_node(nfeat).view(
nfeat).view(-1, self._num_heads, self._out_feats) -1, self._num_heads, self._out_feats
edge_feat = self.fc_edge( )
efeat).view(-1, self._num_heads, self._out_feats) edge_feat = self.fc_edge(efeat).view(
-1, self._num_heads, self._out_feats
el = (node_feat*self.attn_l).sum(dim=-1).unsqueeze(-1) )
er = (node_feat*self.attn_r).sum(dim=-1).unsqueeze(-1)
ee = (edge_feat*self.attn_e).sum(dim=-1).unsqueeze(-1) el = (node_feat * self.attn_l).sum(dim=-1).unsqueeze(-1)
graph.ndata['ft'] = node_feat er = (node_feat * self.attn_r).sum(dim=-1).unsqueeze(-1)
graph.ndata['el'] = el ee = (edge_feat * self.attn_e).sum(dim=-1).unsqueeze(-1)
graph.ndata['er'] = er graph.ndata["ft"] = node_feat
graph.edata['ee'] = ee graph.ndata["el"] = el
graph.apply_edges(fn.u_add_e('el', 'ee', 'el_prime')) graph.ndata["er"] = er
graph.apply_edges(fn.e_add_v('el_prime', 'er', 'e')) graph.edata["ee"] = ee
e = self.leaky_relu(graph.edata['e']) graph.apply_edges(fn.u_add_e("el", "ee", "el_prime"))
graph.edata['a'] = self.attn_drop(edge_softmax(graph, e)) graph.apply_edges(fn.e_add_v("el_prime", "er", "e"))
graph.edata['efeat'] = edge_feat e = self.leaky_relu(graph.edata["e"])
graph.update_all(self.msg_fn, fn.sum('m', 'ft')) graph.edata["a"] = self.attn_drop(edge_softmax(graph, e))
rst = graph.ndata['ft'] graph.edata["efeat"] = edge_feat
graph.update_all(self.msg_fn, fn.sum("m", "ft"))
rst = graph.ndata["ft"]
if self.residual: if self.residual:
resval = self.res_fc(nfeat).view( resval = self.res_fc(nfeat).view(
nfeat.shape[0], -1, self._out_feats) nfeat.shape[0], -1, self._out_feats
)
rst = rst + resval rst = rst + resval
if self.activation: if self.activation:
rst = self.activation(rst) rst = self.activation(rst)
if get_attention: if get_attention:
return rst, graph.edata['a'] return rst, graph.edata["a"]
else: else:
return rst return rst
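A minimal usage sketch for this layer, assuming the constructor keywords used by TemporalTransformerConv below (node_feats, edge_feats, out_feats, num_heads, feat_drop, attn_drop, residual), arbitrary feature sizes, and the add_self_loop workaround suggested by the DGLError message above:

import dgl
import torch

g = dgl.rand_graph(6, 12)
g = dgl.add_self_loop(g)  # avoids the 0-in-degree error raised in forward()
nfeat = torch.randn(g.num_nodes(), 16)
efeat = torch.randn(g.num_edges(), 8)
conv = EdgeGATConv(node_feats=16, edge_feats=8, out_feats=4, num_heads=2,
                   feat_drop=0.0, attn_drop=0.0, residual=True)
out, attn = conv(g, nfeat, efeat, get_attention=True)
# out: (num_nodes, num_heads, out_feats); attn: one weight per edge and head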
class TemporalEdgePreprocess(nn.Module): class TemporalEdgePreprocess(nn.Module):
'''Preprocess layer, which finishes the time encoding and concatenates """Preprocess layer, which finishes the time encoding and concatenates
it to the edge features. it to the edge features.

Parameter Parameter
...@@ -432,7 +480,7 @@ class TemporalEdgePreprocess(nn.Module): ...@@ -432,7 +480,7 @@ class TemporalEdgePreprocess(nn.Module):
temporal_encoder : torch.nn.Module temporal_encoder : torch.nn.Module
time encoder model time encoder model
''' """
def __init__(self, edge_feats, temporal_encoder): def __init__(self, edge_feats, temporal_encoder):
super(TemporalEdgePreprocess, self).__init__() super(TemporalEdgePreprocess, self).__init__()
...@@ -440,29 +488,32 @@ class TemporalEdgePreprocess(nn.Module): ...@@ -440,29 +488,32 @@ class TemporalEdgePreprocess(nn.Module):
self.temporal_encoder = temporal_encoder self.temporal_encoder = temporal_encoder
def edge_fn(self, edges): def edge_fn(self, edges):
t0 = torch.zeros_like(edges.dst['timestamp']) t0 = torch.zeros_like(edges.dst["timestamp"])
time_diff = edges.data['timestamp'] - edges.src['timestamp'] time_diff = edges.data["timestamp"] - edges.src["timestamp"]
time_encode = self.temporal_encoder( time_encode = self.temporal_encoder(time_diff.unsqueeze(dim=1)).view(
time_diff.unsqueeze(dim=1)).view(t0.shape[0], -1) t0.shape[0], -1
edge_feat = torch.cat([edges.data['feats'], time_encode], dim=1) )
return {'efeat': edge_feat} edge_feat = torch.cat([edges.data["feats"], time_encode], dim=1)
return {"efeat": edge_feat}
def forward(self, graph): def forward(self, graph):
graph.apply_edges(self.edge_fn) graph.apply_edges(self.edge_fn)
efeat = graph.edata['efeat'] efeat = graph.edata["efeat"]
return efeat return efeat
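A short sketch of what this preprocessor consumes and produces, assuming TimeEncode is the example's time encoder (constructed with its output dimension, as in TemporalTransformerConv below) and a graph that carries "timestamp" on nodes and edges plus raw "feats" on edges:

import dgl
import torch

g = dgl.rand_graph(5, 20)
g.ndata["timestamp"] = torch.rand(g.num_nodes())
g.edata["timestamp"] = torch.rand(g.num_edges())
g.edata["feats"] = torch.randn(g.num_edges(), 8)
prep = TemporalEdgePreprocess(edge_feats=8, temporal_encoder=TimeEncode(16))
efeat = prep(g)  # shape (num_edges, 8 + time-encoding dimension)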
class TemporalTransformerConv(nn.Module): class TemporalTransformerConv(nn.Module):
def __init__(self, def __init__(
edge_feats, self,
memory_feats, edge_feats,
temporal_encoder, memory_feats,
out_feats, temporal_encoder,
num_heads, out_feats,
allow_zero_in_degree=False, num_heads,
layers=1): allow_zero_in_degree=False,
'''Temporal Transformer model for TGN and TGAT layers=1,
):
"""Temporal Transformer model for TGN and TGAT
Parameter Parameter
========== ==========
...@@ -487,7 +538,7 @@ class TemporalTransformerConv(nn.Module): ...@@ -487,7 +538,7 @@ class TemporalTransformerConv(nn.Module):
causing silent performance regression. This module will raise a DGLError if it detects causing silent performance regression. This module will raise a DGLError if it detects
0-in-degree nodes in input graph. By setting ``True``, it will suppress the check 0-in-degree nodes in input graph. By setting ``True``, it will suppress the check
and let the users handle it by themselves. Defaults: ``False``. and let the users handle it by themselves. Defaults: ``False``.
''' """
super(TemporalTransformerConv, self).__init__() super(TemporalTransformerConv, self).__init__()
self._edge_feats = edge_feats self._edge_feats = edge_feats
self._memory_feats = memory_feats self._memory_feats = memory_feats
...@@ -498,32 +549,42 @@ class TemporalTransformerConv(nn.Module): ...@@ -498,32 +549,42 @@ class TemporalTransformerConv(nn.Module):
self.layers = layers self.layers = layers
self.preprocessor = TemporalEdgePreprocess( self.preprocessor = TemporalEdgePreprocess(
self._edge_feats, self.temporal_encoder) self._edge_feats, self.temporal_encoder
)
self.layer_list = nn.ModuleList() self.layer_list = nn.ModuleList()
self.layer_list.append(EdgeGATConv(node_feats=self._memory_feats, self.layer_list.append(
edge_feats=self._edge_feats+self.temporal_encoder.dimension, EdgeGATConv(
out_feats=self._out_feats, node_feats=self._memory_feats,
num_heads=self._num_heads, edge_feats=self._edge_feats + self.temporal_encoder.dimension,
feat_drop=0.6, out_feats=self._out_feats,
attn_drop=0.6, num_heads=self._num_heads,
residual=True, feat_drop=0.6,
allow_zero_in_degree=allow_zero_in_degree)) attn_drop=0.6,
for i in range(self.layers-1): residual=True,
self.layer_list.append(EdgeGATConv(node_feats=self._out_feats*self._num_heads, allow_zero_in_degree=allow_zero_in_degree,
edge_feats=self._edge_feats+self.temporal_encoder.dimension, )
out_feats=self._out_feats, )
num_heads=self._num_heads, for i in range(self.layers - 1):
feat_drop=0.6, self.layer_list.append(
attn_drop=0.6, EdgeGATConv(
residual=True, node_feats=self._out_feats * self._num_heads,
allow_zero_in_degree=allow_zero_in_degree)) edge_feats=self._edge_feats
+ self.temporal_encoder.dimension,
out_feats=self._out_feats,
num_heads=self._num_heads,
feat_drop=0.6,
attn_drop=0.6,
residual=True,
allow_zero_in_degree=allow_zero_in_degree,
)
)
def forward(self, graph, memory, ts): def forward(self, graph, memory, ts):
graph = graph.local_var() graph = graph.local_var()
graph.ndata['timestamp'] = ts graph.ndata["timestamp"] = ts
efeat = self.preprocessor(graph).float() efeat = self.preprocessor(graph).float()
rst = memory rst = memory
for i in range(self.layers-1): for i in range(self.layers - 1):
rst = self.layer_list[i](graph, rst, efeat).flatten(1) rst = self.layer_list[i](graph, rst, efeat).flatten(1)
rst = self.layer_list[-1](graph, rst, efeat).mean(1) rst = self.layer_list[-1](graph, rst, efeat).mean(1)
return rst return rst
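Putting the pieces together, a usage sketch with the constructor keywords from __init__ above (TimeEncode from the example's modules.py; all sizes arbitrary):

import dgl
import torch

g = dgl.rand_graph(10, 40)
g.edata["timestamp"] = torch.rand(g.num_edges())
g.edata["feats"] = torch.randn(g.num_edges(), 8)
conv = TemporalTransformerConv(
    edge_feats=8,
    memory_feats=32,
    temporal_encoder=TimeEncode(16),
    out_feats=16,
    num_heads=4,
    allow_zero_in_degree=True,  # rand_graph may contain 0-in-degree nodes
    layers=1,
)
memory = torch.randn(g.num_nodes(), 32)
ts = torch.rand(g.num_nodes())
emb = conv(g, memory, ts)  # (num_nodes, out_feats) after the final mean over heads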
import copy import copy
import torch.nn as nn import torch.nn as nn
from modules import (
MemoryModule,
MemoryOperation,
MsgLinkPredictor,
TemporalTransformerConv,
TimeEncode,
)
import dgl import dgl
from modules import MemoryModule, MemoryOperation, MsgLinkPredictor, TemporalTransformerConv, TimeEncode
class TGN(nn.Module): class TGN(nn.Module):
def __init__(self, def __init__(
edge_feat_dim, self,
memory_dim, edge_feat_dim,
temporal_dim, memory_dim,
embedding_dim, temporal_dim,
num_heads, embedding_dim,
num_nodes, num_heads,
n_neighbors=10, num_nodes,
memory_updater_type='gru', n_neighbors=10,
layers=1): memory_updater_type="gru",
layers=1,
):
super(TGN, self).__init__() super(TGN, self).__init__()
self.memory_dim = memory_dim self.memory_dim = memory_dim
self.edge_feat_dim = edge_feat_dim self.edge_feat_dim = edge_feat_dim
...@@ -28,43 +38,49 @@ class TGN(nn.Module): ...@@ -28,43 +38,49 @@ class TGN(nn.Module):
self.temporal_encoder = TimeEncode(self.temporal_dim) self.temporal_encoder = TimeEncode(self.temporal_dim)
self.memory = MemoryModule(self.num_nodes, self.memory = MemoryModule(self.num_nodes, self.memory_dim)
self.memory_dim)
self.memory_ops = MemoryOperation(self.memory_updater_type, self.memory_ops = MemoryOperation(
self.memory, self.memory_updater_type,
self.edge_feat_dim, self.memory,
self.temporal_encoder) self.edge_feat_dim,
self.temporal_encoder,
)
self.embedding_attn = TemporalTransformerConv(self.edge_feat_dim, self.embedding_attn = TemporalTransformerConv(
self.memory_dim, self.edge_feat_dim,
self.temporal_encoder, self.memory_dim,
self.embedding_dim, self.temporal_encoder,
self.num_heads, self.embedding_dim,
layers=self.layers, self.num_heads,
allow_zero_in_degree=True) layers=self.layers,
allow_zero_in_degree=True,
)
self.msg_linkpredictor = MsgLinkPredictor(embedding_dim) self.msg_linkpredictor = MsgLinkPredictor(embedding_dim)
def embed(self, postive_graph, negative_graph, blocks): def embed(self, postive_graph, negative_graph, blocks):
emb_graph = blocks[0] emb_graph = blocks[0]
emb_memory = self.memory.memory[emb_graph.ndata[dgl.NID], :] emb_memory = self.memory.memory[emb_graph.ndata[dgl.NID], :]
emb_t = emb_graph.ndata['timestamp'] emb_t = emb_graph.ndata["timestamp"]
embedding = self.embedding_attn(emb_graph, emb_memory, emb_t) embedding = self.embedding_attn(emb_graph, emb_memory, emb_t)
emb2pred = dict( emb2pred = dict(
zip(emb_graph.ndata[dgl.NID].tolist(), emb_graph.nodes().tolist())) zip(emb_graph.ndata[dgl.NID].tolist(), emb_graph.nodes().tolist())
)
# Since the positive graph and the negative graph share the same id mapping # Since the positive graph and the negative graph share the same id mapping
feat_id = [emb2pred[int(n)] for n in postive_graph.ndata[dgl.NID]] feat_id = [emb2pred[int(n)] for n in postive_graph.ndata[dgl.NID]]
feat = embedding[feat_id] feat = embedding[feat_id]
pred_pos, pred_neg = self.msg_linkpredictor( pred_pos, pred_neg = self.msg_linkpredictor(
feat, postive_graph, negative_graph) feat, postive_graph, negative_graph
)
return pred_pos, pred_neg return pred_pos, pred_neg
def update_memory(self, subg): def update_memory(self, subg):
new_g = self.memory_ops(subg) new_g = self.memory_ops(subg)
self.memory.set_memory(new_g.ndata[dgl.NID], new_g.ndata['memory']) self.memory.set_memory(new_g.ndata[dgl.NID], new_g.ndata["memory"])
self.memory.set_last_update_t( self.memory.set_last_update_t(
new_g.ndata[dgl.NID], new_g.ndata['timestamp']) new_g.ndata[dgl.NID], new_g.ndata["timestamp"]
)
# Some memory operation wrappers # Some memory operation wrappers
def detach_memory(self): def detach_memory(self):
...@@ -75,10 +91,10 @@ class TGN(nn.Module): ...@@ -75,10 +91,10 @@ class TGN(nn.Module):
def store_memory(self): def store_memory(self):
memory_checkpoint = {} memory_checkpoint = {}
memory_checkpoint['memory'] = copy.deepcopy(self.memory.memory) memory_checkpoint["memory"] = copy.deepcopy(self.memory.memory)
memory_checkpoint['last_t'] = copy.deepcopy(self.memory.last_update_t) memory_checkpoint["last_t"] = copy.deepcopy(self.memory.last_update_t)
return memory_checkpoint return memory_checkpoint
def restore_memory(self, memory_checkpoint): def restore_memory(self, memory_checkpoint):
self.memory.memory = memory_checkpoint['memory'] self.memory.memory = memory_checkpoint["memory"]
self.memory.last_update_t = memory_checkpoint['last_t'] self.memory.last_update_t = memory_checkpoint["last_t"]
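The memory wrappers above let the training script snapshot node memory before evaluating on test edges and roll it back afterwards; a condensed sketch of that pattern (the full loop, including the fast-mode sampler sync, is in the script below):

memory_checkpoint = model.store_memory()    # snapshot before test-time updates
test_ap, test_auc = test_val(model, test_dataloader, sampler, criterion, args)
model.restore_memory(memory_checkpoint)     # roll back to the post-validation state
model.reset_memory()                        # clear memory entirely at the end of an epoch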
import argparse import argparse
import traceback
import time
import copy import copy
import time
import traceback
import numpy as np import numpy as np
import dgl
import torch import torch
from data_preprocess import (
from tgn import TGN TemporalDataset,
from data_preprocess import TemporalWikipediaDataset, TemporalRedditDataset, TemporalDataset TemporalRedditDataset,
from dataloading import (FastTemporalEdgeCollator, FastTemporalSampler, TemporalWikipediaDataset,
SimpleTemporalEdgeCollator, SimpleTemporalSampler, )
TemporalEdgeDataLoader, TemporalSampler, TemporalEdgeCollator) from dataloading import (
FastTemporalEdgeCollator,
FastTemporalSampler,
SimpleTemporalEdgeCollator,
SimpleTemporalSampler,
TemporalEdgeCollator,
TemporalEdgeDataLoader,
TemporalSampler,
)
from sklearn.metrics import average_precision_score, roc_auc_score from sklearn.metrics import average_precision_score, roc_auc_score
from tgn import TGN
import dgl
TRAIN_SPLIT = 0.7 TRAIN_SPLIT = 0.7
VALID_SPLIT = 0.85 VALID_SPLIT = 0.85
...@@ -32,10 +40,11 @@ def train(model, dataloader, sampler, criterion, optimizer, args): ...@@ -32,10 +40,11 @@ def train(model, dataloader, sampler, criterion, optimizer, args):
for _, positive_pair_g, negative_pair_g, blocks in dataloader: for _, positive_pair_g, negative_pair_g, blocks in dataloader:
optimizer.zero_grad() optimizer.zero_grad()
pred_pos, pred_neg = model.embed( pred_pos, pred_neg = model.embed(
positive_pair_g, negative_pair_g, blocks) positive_pair_g, negative_pair_g, blocks
)
loss = criterion(pred_pos, torch.ones_like(pred_pos)) loss = criterion(pred_pos, torch.ones_like(pred_pos))
loss += criterion(pred_neg, torch.zeros_like(pred_neg)) loss += criterion(pred_neg, torch.zeros_like(pred_neg))
total_loss += float(loss)*args.batch_size total_loss += float(loss) * args.batch_size
retain_graph = True if batch_cnt == 0 and not args.fast_mode else False retain_graph = True if batch_cnt == 0 and not args.fast_mode else False
loss.backward(retain_graph=retain_graph) loss.backward(retain_graph=retain_graph)
optimizer.step() optimizer.step()
...@@ -44,7 +53,7 @@ def train(model, dataloader, sampler, criterion, optimizer, args): ...@@ -44,7 +53,7 @@ def train(model, dataloader, sampler, criterion, optimizer, args):
model.update_memory(positive_pair_g) model.update_memory(positive_pair_g)
if args.fast_mode: if args.fast_mode:
sampler.attach_last_update(model.memory.last_update_t) sampler.attach_last_update(model.memory.last_update_t)
print("Batch: ", batch_cnt, "Time: ", time.time()-last_t) print("Batch: ", batch_cnt, "Time: ", time.time() - last_t)
last_t = time.time() last_t = time.time()
batch_cnt += 1 batch_cnt += 1
return total_loss return total_loss
...@@ -59,13 +68,16 @@ def test_val(model, dataloader, sampler, criterion, args): ...@@ -59,13 +68,16 @@ def test_val(model, dataloader, sampler, criterion, args):
with torch.no_grad(): with torch.no_grad():
for _, postive_pair_g, negative_pair_g, blocks in dataloader: for _, postive_pair_g, negative_pair_g, blocks in dataloader:
pred_pos, pred_neg = model.embed( pred_pos, pred_neg = model.embed(
postive_pair_g, negative_pair_g, blocks) postive_pair_g, negative_pair_g, blocks
)
loss = criterion(pred_pos, torch.ones_like(pred_pos)) loss = criterion(pred_pos, torch.ones_like(pred_pos))
loss += criterion(pred_neg, torch.zeros_like(pred_neg)) loss += criterion(pred_neg, torch.zeros_like(pred_neg))
total_loss += float(loss)*batch_size total_loss += float(loss) * batch_size
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu() y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat( y_true = torch.cat(
[torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0) [torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))],
dim=0,
)
if not args.not_use_memory: if not args.not_use_memory:
model.update_memory(postive_pair_g) model.update_memory(postive_pair_g)
if args.fast_mode: if args.fast_mode:
...@@ -79,52 +91,108 @@ def test_val(model, dataloader, sampler, criterion, args): ...@@ -79,52 +91,108 @@ def test_val(model, dataloader, sampler, criterion, args):
if __name__ == "__main__": if __name__ == "__main__":
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int, default=50, parser.add_argument(
help='epochs for training on entire dataset') "--epochs",
parser.add_argument("--batch_size", type=int, type=int,
default=200, help="Size of each batch") default=50,
parser.add_argument("--embedding_dim", type=int, default=100, help="epochs for training on entire dataset",
help="Embedding dim for link prediction") )
parser.add_argument("--memory_dim", type=int, default=100, parser.add_argument(
help="dimension of memory") "--batch_size", type=int, default=200, help="Size of each batch"
parser.add_argument("--temporal_dim", type=int, default=100, )
help="Temporal dimension for time encoding") parser.add_argument(
parser.add_argument("--memory_updater", type=str, default='gru', "--embedding_dim",
help="Recurrent unit for memory update") type=int,
parser.add_argument("--aggregator", type=str, default='last', default=100,
help="Aggregation method for memory update") help="Embedding dim for link prediction",
parser.add_argument("--n_neighbors", type=int, default=10, )
help="number of neighbors while doing embedding") parser.add_argument(
parser.add_argument("--sampling_method", type=str, default='topk', "--memory_dim", type=int, default=100, help="dimension of memory"
help="In embedding how node aggregate from its neighor") )
parser.add_argument("--num_heads", type=int, default=8, parser.add_argument(
help="Number of heads for multihead attention mechanism") "--temporal_dim",
parser.add_argument("--fast_mode", action="store_true", default=False, type=int,
help="Fast Mode uses batch temporal sampling, history within same batch cannot be obtained") default=100,
parser.add_argument("--simple_mode", action="store_true", default=False, help="Temporal dimension for time encoding",
help="Simple Mode directly delete the temporal edges from the original static graph") )
parser.add_argument("--num_negative_samples", type=int, default=1, parser.add_argument(
help="number of negative samplers per positive samples") "--memory_updater",
parser.add_argument("--dataset", type=str, default="wikipedia", type=str,
help="dataset selection wikipedia/reddit") default="gru",
parser.add_argument("--k_hop", type=int, default=1, help="Recurrent unit for memory update",
help="sampling k-hop neighborhood") )
parser.add_argument("--not_use_memory", action="store_true", default=False, parser.add_argument(
help="Enable memory for TGN Model disable memory for TGN Model") "--aggregator",
type=str,
default="last",
help="Aggregation method for memory update",
)
parser.add_argument(
"--n_neighbors",
type=int,
default=10,
help="number of neighbors while doing embedding",
)
parser.add_argument(
"--sampling_method",
type=str,
default="topk",
help="In embedding how node aggregate from its neighor",
)
parser.add_argument(
"--num_heads",
type=int,
default=8,
help="Number of heads for multihead attention mechanism",
)
parser.add_argument(
"--fast_mode",
action="store_true",
default=False,
help="Fast Mode uses batch temporal sampling, history within same batch cannot be obtained",
)
parser.add_argument(
"--simple_mode",
action="store_true",
default=False,
help="Simple Mode directly delete the temporal edges from the original static graph",
)
parser.add_argument(
"--num_negative_samples",
type=int,
default=1,
help="number of negative samplers per positive samples",
)
parser.add_argument(
"--dataset",
type=str,
default="wikipedia",
help="dataset selection wikipedia/reddit",
)
parser.add_argument(
"--k_hop", type=int, default=1, help="sampling k-hop neighborhood"
)
parser.add_argument(
"--not_use_memory",
action="store_true",
default=False,
help="Enable memory for TGN Model disable memory for TGN Model",
)
args = parser.parse_args() args = parser.parse_args()
assert not ( assert not (
args.fast_mode and args.simple_mode), "you can only choose one sampling mode" args.fast_mode and args.simple_mode
), "you can only choose one sampling mode"
if args.k_hop != 1: if args.k_hop != 1:
assert args.simple_mode, "this k-hop parameter only supports simple mode" assert args.simple_mode, "this k-hop parameter only supports simple mode"
if args.dataset == 'wikipedia': if args.dataset == "wikipedia":
data = TemporalWikipediaDataset() data = TemporalWikipediaDataset()
elif args.dataset == 'reddit': elif args.dataset == "reddit":
data = TemporalRedditDataset() data = TemporalRedditDataset()
else: else:
print("Warning Using Untested Dataset: "+args.dataset) print("Warning Using Untested Dataset: " + args.dataset)
data = TemporalDataset(args.dataset) data = TemporalDataset(args.dataset)
# Pre-process data, mask new node in test set from original graph # Pre-process data, mask new node in test set from original graph
...@@ -132,22 +200,33 @@ if __name__ == "__main__": ...@@ -132,22 +200,33 @@ if __name__ == "__main__":
num_edges = data.num_edges() num_edges = data.num_edges()
num_edges = data.num_edges() num_edges = data.num_edges()
trainval_div = int(VALID_SPLIT*num_edges) trainval_div = int(VALID_SPLIT * num_edges)
# Select new nodes from the test set and remove them from the entire graph # Select new nodes from the test set and remove them from the entire graph
test_split_ts = data.edata['timestamp'][trainval_div] test_split_ts = data.edata["timestamp"][trainval_div]
test_nodes = torch.cat([data.edges()[0][trainval_div:], data.edges()[ test_nodes = (
1][trainval_div:]]).unique().numpy() torch.cat(
[data.edges()[0][trainval_div:], data.edges()[1][trainval_div:]]
)
.unique()
.numpy()
)
test_new_nodes = np.random.choice( test_new_nodes = np.random.choice(
test_nodes, int(0.1*len(test_nodes)), replace=False) test_nodes, int(0.1 * len(test_nodes)), replace=False
)
in_subg = dgl.in_subgraph(data, test_new_nodes) in_subg = dgl.in_subgraph(data, test_new_nodes)
out_subg = dgl.out_subgraph(data, test_new_nodes) out_subg = dgl.out_subgraph(data, test_new_nodes)
# Remove edges that happen before the test split to prevent learning the connection info # Remove edges that happen before the test split to prevent learning the connection info
new_node_in_eid_delete = in_subg.edata[dgl.EID][in_subg.edata['timestamp'] < test_split_ts] new_node_in_eid_delete = in_subg.edata[dgl.EID][
new_node_out_eid_delete = out_subg.edata[dgl.EID][out_subg.edata['timestamp'] < test_split_ts] in_subg.edata["timestamp"] < test_split_ts
]
new_node_out_eid_delete = out_subg.edata[dgl.EID][
out_subg.edata["timestamp"] < test_split_ts
]
new_node_eid_delete = torch.cat( new_node_eid_delete = torch.cat(
[new_node_in_eid_delete, new_node_out_eid_delete]).unique() [new_node_in_eid_delete, new_node_out_eid_delete]
).unique()
graph_new_node = copy.deepcopy(data) graph_new_node = copy.deepcopy(data)
# relative order preserved # relative order preserved
...@@ -178,118 +257,153 @@ if __name__ == "__main__": ...@@ -178,118 +257,153 @@ if __name__ == "__main__":
edge_collator = TemporalEdgeCollator edge_collator = TemporalEdgeCollator
neg_sampler = dgl.dataloading.negative_sampler.Uniform( neg_sampler = dgl.dataloading.negative_sampler.Uniform(
k=args.num_negative_samples) k=args.num_negative_samples
)
# Set Train, validation, test and new node test id # Set Train, validation, test and new node test id
train_seed = torch.arange(int(TRAIN_SPLIT*graph_no_new_node.num_edges())) train_seed = torch.arange(int(TRAIN_SPLIT * graph_no_new_node.num_edges()))
valid_seed = torch.arange(int( valid_seed = torch.arange(
TRAIN_SPLIT*graph_no_new_node.num_edges()), trainval_div-new_node_eid_delete.size(0)) int(TRAIN_SPLIT * graph_no_new_node.num_edges()),
trainval_div - new_node_eid_delete.size(0),
)
test_seed = torch.arange( test_seed = torch.arange(
trainval_div-new_node_eid_delete.size(0), graph_no_new_node.num_edges()) trainval_div - new_node_eid_delete.size(0),
graph_no_new_node.num_edges(),
)
test_new_node_seed = torch.arange( test_new_node_seed = torch.arange(
trainval_div-new_node_eid_delete.size(0), graph_new_node.num_edges()) trainval_div - new_node_eid_delete.size(0), graph_new_node.num_edges()
)
g_sampling = None if args.fast_mode else dgl.add_reverse_edges(
graph_no_new_node, copy_edata=True) g_sampling = (
new_node_g_sampling = None if args.fast_mode else dgl.add_reverse_edges( None
graph_new_node, copy_edata=True) if args.fast_mode
else dgl.add_reverse_edges(graph_no_new_node, copy_edata=True)
)
new_node_g_sampling = (
None
if args.fast_mode
else dgl.add_reverse_edges(graph_new_node, copy_edata=True)
)
if not args.fast_mode: if not args.fast_mode:
new_node_g_sampling.ndata[dgl.NID] = new_node_g_sampling.nodes() new_node_g_sampling.ndata[dgl.NID] = new_node_g_sampling.nodes()
g_sampling.ndata[dgl.NID] = new_node_g_sampling.nodes() g_sampling.ndata[dgl.NID] = new_node_g_sampling.nodes()
# we highly recommend always setting num_workers=0; otherwise the sampled subgraph may not be correct. # we highly recommend always setting num_workers=0; otherwise the sampled subgraph may not be correct.
train_dataloader = TemporalEdgeDataLoader(graph_no_new_node, train_dataloader = TemporalEdgeDataLoader(
train_seed, graph_no_new_node,
sampler, train_seed,
batch_size=args.batch_size, sampler,
negative_sampler=neg_sampler, batch_size=args.batch_size,
shuffle=False, negative_sampler=neg_sampler,
drop_last=False, shuffle=False,
num_workers=0, drop_last=False,
collator=edge_collator, num_workers=0,
g_sampling=g_sampling) collator=edge_collator,
g_sampling=g_sampling,
valid_dataloader = TemporalEdgeDataLoader(graph_no_new_node, )
valid_seed,
sampler, valid_dataloader = TemporalEdgeDataLoader(
batch_size=args.batch_size, graph_no_new_node,
negative_sampler=neg_sampler, valid_seed,
shuffle=False, sampler,
drop_last=False, batch_size=args.batch_size,
num_workers=0, negative_sampler=neg_sampler,
collator=edge_collator, shuffle=False,
g_sampling=g_sampling) drop_last=False,
num_workers=0,
test_dataloader = TemporalEdgeDataLoader(graph_no_new_node, collator=edge_collator,
test_seed, g_sampling=g_sampling,
sampler, )
batch_size=args.batch_size,
negative_sampler=neg_sampler, test_dataloader = TemporalEdgeDataLoader(
shuffle=False, graph_no_new_node,
drop_last=False, test_seed,
num_workers=0, sampler,
collator=edge_collator, batch_size=args.batch_size,
g_sampling=g_sampling) negative_sampler=neg_sampler,
shuffle=False,
test_new_node_dataloader = TemporalEdgeDataLoader(graph_new_node, drop_last=False,
test_new_node_seed, num_workers=0,
new_node_sampler if args.fast_mode else sampler, collator=edge_collator,
batch_size=args.batch_size, g_sampling=g_sampling,
negative_sampler=neg_sampler, )
shuffle=False,
drop_last=False, test_new_node_dataloader = TemporalEdgeDataLoader(
num_workers=0, graph_new_node,
collator=edge_collator, test_new_node_seed,
g_sampling=new_node_g_sampling) new_node_sampler if args.fast_mode else sampler,
batch_size=args.batch_size,
edge_dim = data.edata['feats'].shape[1] negative_sampler=neg_sampler,
shuffle=False,
drop_last=False,
num_workers=0,
collator=edge_collator,
g_sampling=new_node_g_sampling,
)
edge_dim = data.edata["feats"].shape[1]
num_node = data.num_nodes() num_node = data.num_nodes()
model = TGN(edge_feat_dim=edge_dim, model = TGN(
memory_dim=args.memory_dim, edge_feat_dim=edge_dim,
temporal_dim=args.temporal_dim, memory_dim=args.memory_dim,
embedding_dim=args.embedding_dim, temporal_dim=args.temporal_dim,
num_heads=args.num_heads, embedding_dim=args.embedding_dim,
num_nodes=num_node, num_heads=args.num_heads,
n_neighbors=args.n_neighbors, num_nodes=num_node,
memory_updater_type=args.memory_updater, n_neighbors=args.n_neighbors,
layers=args.k_hop) memory_updater_type=args.memory_updater,
layers=args.k_hop,
)
criterion = torch.nn.BCEWithLogitsLoss() criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001) optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
# Implement Logging mechanism # Implement Logging mechanism
f = open("logging.txt", 'w') f = open("logging.txt", "w")
if args.fast_mode: if args.fast_mode:
sampler.reset() sampler.reset()
try: try:
for i in range(args.epochs): for i in range(args.epochs):
train_loss = train(model, train_dataloader, sampler, train_loss = train(
criterion, optimizer, args) model, train_dataloader, sampler, criterion, optimizer, args
)
val_ap, val_auc = test_val( val_ap, val_auc = test_val(
model, valid_dataloader, sampler, criterion, args) model, valid_dataloader, sampler, criterion, args
)
memory_checkpoint = model.store_memory() memory_checkpoint = model.store_memory()
if args.fast_mode: if args.fast_mode:
new_node_sampler.sync(sampler) new_node_sampler.sync(sampler)
test_ap, test_auc = test_val( test_ap, test_auc = test_val(
model, test_dataloader, sampler, criterion, args) model, test_dataloader, sampler, criterion, args
)
model.restore_memory(memory_checkpoint) model.restore_memory(memory_checkpoint)
if args.fast_mode: if args.fast_mode:
sample_nn = new_node_sampler sample_nn = new_node_sampler
else: else:
sample_nn = sampler sample_nn = sampler
nn_test_ap, nn_test_auc = test_val( nn_test_ap, nn_test_auc = test_val(
model, test_new_node_dataloader, sample_nn, criterion, args) model, test_new_node_dataloader, sample_nn, criterion, args
)
log_content = [] log_content = []
log_content.append("Epoch: {}; Training Loss: {} | Validation AP: {:.3f} AUC: {:.3f}\n".format(
i, train_loss, val_ap, val_auc))
log_content.append( log_content.append(
"Epoch: {}; Test AP: {:.3f} AUC: {:.3f}\n".format(i, test_ap, test_auc)) "Epoch: {}; Training Loss: {} | Validation AP: {:.3f} AUC: {:.3f}\n".format(
log_content.append("Epoch: {}; Test New Node AP: {:.3f} AUC: {:.3f}\n".format( i, train_loss, val_ap, val_auc
i, nn_test_ap, nn_test_auc)) )
)
log_content.append(
"Epoch: {}; Test AP: {:.3f} AUC: {:.3f}\n".format(
i, test_ap, test_auc
)
)
log_content.append(
"Epoch: {}; Test New Node AP: {:.3f} AUC: {:.3f}\n".format(
i, nn_test_ap, nn_test_auc
)
)
f.writelines(log_content) f.writelines(log_content)
model.reset_memory() model.reset_memory()
if i < args.epochs-1 and args.fast_mode: if i < args.epochs - 1 and args.fast_mode:
sampler.reset() sampler.reset()
print(log_content[0], log_content[1], log_content[2]) print(log_content[0], log_content[1], log_content[2])
except KeyboardInterrupt: except KeyboardInterrupt:
......
from .graph import *
from .fields import *
from .utils import prepare_dataset
import os import os
import random import random
from .fields import *
from .graph import *
from .utils import prepare_dataset
class ClassificationDataset(object): class ClassificationDataset(object):
"Dataset class for classification task." "Dataset class for classification task."
def __init__(self): def __init__(self):
raise NotImplementedError raise NotImplementedError
class TranslationDataset(object): class TranslationDataset(object):
''' """
Dataset class for translation task. Dataset class for translation task.
By default, the source language shares the same vocabulary with the target language. By default, the source language shares the same vocabulary with the target language.
''' """
INIT_TOKEN = '<sos>'
EOS_TOKEN = '<eos>' INIT_TOKEN = "<sos>"
PAD_TOKEN = '<pad>' EOS_TOKEN = "<eos>"
PAD_TOKEN = "<pad>"
MAX_LENGTH = 50 MAX_LENGTH = 50
def __init__(self, path, exts, train='train', valid='valid', test='test', vocab='vocab.txt', replace_oov=None):
def __init__(
self,
path,
exts,
train="train",
valid="valid",
test="test",
vocab="vocab.txt",
replace_oov=None,
):
vocab_path = os.path.join(path, vocab) vocab_path = os.path.join(path, vocab)
self.src = {} self.src = {}
self.tgt = {} self.tgt = {}
with open(os.path.join(path, train + '.' + exts[0]), 'r', encoding='utf-8') as f: with open(
self.src['train'] = f.readlines() os.path.join(path, train + "." + exts[0]), "r", encoding="utf-8"
with open(os.path.join(path, train + '.' + exts[1]), 'r', encoding='utf-8') as f: ) as f:
self.tgt['train'] = f.readlines() self.src["train"] = f.readlines()
with open(os.path.join(path, valid + '.' + exts[0]), 'r', encoding='utf-8') as f: with open(
self.src['valid'] = f.readlines() os.path.join(path, train + "." + exts[1]), "r", encoding="utf-8"
with open(os.path.join(path, valid + '.' + exts[1]), 'r', encoding='utf-8') as f: ) as f:
self.tgt['valid'] = f.readlines() self.tgt["train"] = f.readlines()
with open(os.path.join(path, test + '.' + exts[0]), 'r', encoding='utf-8') as f: with open(
self.src['test'] = f.readlines() os.path.join(path, valid + "." + exts[0]), "r", encoding="utf-8"
with open(os.path.join(path, test + '.' + exts[1]), 'r', encoding='utf-8') as f: ) as f:
self.tgt['test'] = f.readlines() self.src["valid"] = f.readlines()
with open(
os.path.join(path, valid + "." + exts[1]), "r", encoding="utf-8"
) as f:
self.tgt["valid"] = f.readlines()
with open(
os.path.join(path, test + "." + exts[0]), "r", encoding="utf-8"
) as f:
self.src["test"] = f.readlines()
with open(
os.path.join(path, test + "." + exts[1]), "r", encoding="utf-8"
) as f:
self.tgt["test"] = f.readlines()
if not os.path.exists(vocab_path): if not os.path.exists(vocab_path):
self._make_vocab(vocab_path) self._make_vocab(vocab_path)
vocab = Vocab(init_token=self.INIT_TOKEN, vocab = Vocab(
eos_token=self.EOS_TOKEN, init_token=self.INIT_TOKEN,
pad_token=self.PAD_TOKEN, eos_token=self.EOS_TOKEN,
unk_token=replace_oov) pad_token=self.PAD_TOKEN,
unk_token=replace_oov,
)
vocab.load(vocab_path) vocab.load(vocab_path)
self.vocab = vocab self.vocab = vocab
strip_func = lambda x: x[:self.MAX_LENGTH] strip_func = lambda x: x[: self.MAX_LENGTH]
self.src_field = Field(vocab, self.src_field = Field(
preprocessing=None, vocab, preprocessing=None, postprocessing=strip_func
postprocessing=strip_func) )
self.tgt_field = Field(vocab, self.tgt_field = Field(
preprocessing=lambda seq: [self.INIT_TOKEN] + seq + [self.EOS_TOKEN], vocab,
postprocessing=strip_func) preprocessing=lambda seq: [self.INIT_TOKEN]
+ seq
def get_seq_by_id(self, idx, mode='train', field='src'): + [self.EOS_TOKEN],
postprocessing=strip_func,
)
def get_seq_by_id(self, idx, mode="train", field="src"):
"get raw sequence in dataset by specifying index, mode(train/valid/test), field(src/tgt)" "get raw sequence in dataset by specifying index, mode(train/valid/test), field(src/tgt)"
if field == 'src': if field == "src":
return self.src[mode][idx].strip().split() return self.src[mode][idx].strip().split()
else: else:
return [self.INIT_TOKEN] + self.tgt[mode][idx].strip().split() + [self.EOS_TOKEN] return (
[self.INIT_TOKEN]
+ self.tgt[mode][idx].strip().split()
+ [self.EOS_TOKEN]
)
def _make_vocab(self, path, thres=2): def _make_vocab(self, path, thres=2):
word_dict = {} word_dict = {}
for mode in ['train', 'valid', 'test']: for mode in ["train", "valid", "test"]:
for line in self.src[mode] + self.tgt[mode]: for line in self.src[mode] + self.tgt[mode]:
for token in line.strip().split(): for token in line.strip().split():
if token not in word_dict: if token not in word_dict:
...@@ -69,7 +106,7 @@ class TranslationDataset(object): ...@@ -69,7 +106,7 @@ class TranslationDataset(object):
else: else:
word_dict[token] += 1 word_dict[token] += 1
with open(path, 'w') as f: with open(path, "w") as f:
for k, v in word_dict.items(): for k, v in word_dict.items():
if v > thres: if v > thres:
print(k, file=f) print(k, file=f)
...@@ -90,9 +127,17 @@ class TranslationDataset(object): ...@@ -90,9 +127,17 @@ class TranslationDataset(object):
def eos_id(self): def eos_id(self):
return self.vocab[self.EOS_TOKEN] return self.vocab[self.EOS_TOKEN]
def __call__(self, graph_pool, mode='train', batch_size=32, k=1, def __call__(
device='cpu', dev_rank=0, ndev=1): self,
''' graph_pool,
mode="train",
batch_size=32,
k=1,
device="cpu",
dev_rank=0,
ndev=1,
):
"""
Create a batched graph corresponding to a mini-batch of the dataset. Create a batched graph corresponding to a mini-batch of the dataset.
args: args:
graph_pool: a GraphPool object for accelerating. graph_pool: a GraphPool object for accelerating.
...@@ -102,7 +147,7 @@ class TranslationDataset(object): ...@@ -102,7 +147,7 @@ class TranslationDataset(object):
device: str or torch.device device: str or torch.device
dev_rank: rank (id) of current device dev_rank: rank (id) of current device
ndev: number of devices ndev: number of devices
''' """
src_data, tgt_data = self.src[mode], self.tgt[mode] src_data, tgt_data = self.src[mode], self.tgt[mode]
n = len(src_data) n = len(src_data)
# make sure all devices have the same number of batch # make sure all devices have the same number of batch
...@@ -111,28 +156,30 @@ class TranslationDataset(object): ...@@ -111,28 +156,30 @@ class TranslationDataset(object):
# XXX: partition then shuffle may not be equivalent to shuffle then # XXX: partition then shuffle may not be equivalent to shuffle then
# partition # partition
order = list(range(dev_rank, n, ndev)) order = list(range(dev_rank, n, ndev))
if mode == 'train': if mode == "train":
random.shuffle(order) random.shuffle(order)
src_buf, tgt_buf = [], [] src_buf, tgt_buf = [], []
for idx in order: for idx in order:
src_sample = self.src_field( src_sample = self.src_field(src_data[idx].strip().split())
src_data[idx].strip().split()) tgt_sample = self.tgt_field(tgt_data[idx].strip().split())
tgt_sample = self.tgt_field(
tgt_data[idx].strip().split())
src_buf.append(src_sample) src_buf.append(src_sample)
tgt_buf.append(tgt_sample) tgt_buf.append(tgt_sample)
if len(src_buf) == batch_size: if len(src_buf) == batch_size:
if mode == 'test': if mode == "test":
yield graph_pool.beam(src_buf, self.sos_id, self.MAX_LENGTH, k, device=device) yield graph_pool.beam(
src_buf, self.sos_id, self.MAX_LENGTH, k, device=device
)
else: else:
yield graph_pool(src_buf, tgt_buf, device=device) yield graph_pool(src_buf, tgt_buf, device=device)
src_buf, tgt_buf = [], [] src_buf, tgt_buf = [], []
if len(src_buf) != 0: if len(src_buf) != 0:
if mode == 'test': if mode == "test":
yield graph_pool.beam(src_buf, self.sos_id, self.MAX_LENGTH, k, device=device) yield graph_pool.beam(
src_buf, self.sos_id, self.MAX_LENGTH, k, device=device
)
else: else:
yield graph_pool(src_buf, tgt_buf, device=device) yield graph_pool(src_buf, tgt_buf, device=device)
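A sketch of consuming this generator, assuming get_dataset() below and the GraphPool class from the example's graph.py (shown further down); each yielded item is the Graph namedtuple built by the pool:

dataset = get_dataset("copy")
graph_pool = GraphPool()
for g_batch in dataset(graph_pool, mode="train", batch_size=32, device="cpu"):
    src_tokens, src_pos = g_batch.src
    tgt_tokens, tgt_pos = g_batch.tgt
    # feed g_batch.g with the token/position tensors into the Transformer model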
...@@ -145,38 +192,46 @@ class TranslationDataset(object): ...@@ -145,38 +192,46 @@ class TranslationDataset(object):
l = seq.index(self.eos_id) l = seq.index(self.eos_id)
except ValueError: except ValueError:
l = len(seq) l = len(seq)
ret.append(' '.join(self.vocab[token] for token in seq[:l] if not token in filter_list)) ret.append(
" ".join(
self.vocab[token]
for token in seq[:l]
if not token in filter_list
)
)
return ret return ret
def get_dataset(dataset): def get_dataset(dataset):
"we wrapped a set of datasets as example" "we wrapped a set of datasets as example"
prepare_dataset(dataset) prepare_dataset(dataset)
if dataset == 'babi': if dataset == "babi":
raise NotImplementedError raise NotImplementedError
elif dataset == 'copy' or dataset == 'sort': elif dataset == "copy" or dataset == "sort":
return TranslationDataset( return TranslationDataset(
'data/{}'.format(dataset), "data/{}".format(dataset),
('in', 'out'), ("in", "out"),
train='train', train="train",
valid='valid', valid="valid",
test='test', test="test",
) )
elif dataset == 'multi30k': elif dataset == "multi30k":
return TranslationDataset( return TranslationDataset(
'data/multi30k', "data/multi30k",
('en.atok', 'de.atok'), ("en.atok", "de.atok"),
train='train', train="train",
valid='val', valid="val",
test='test2016', test="test2016",
replace_oov='<unk>' replace_oov="<unk>",
) )
elif dataset == 'wmt14': elif dataset == "wmt14":
return TranslationDataset( return TranslationDataset(
'data/wmt14', "data/wmt14",
('en', 'de'), ("en", "de"),
train='train.tok.clean.bpe.32000', train="train.tok.clean.bpe.32000",
valid='newstest2013.tok.bpe.32000', valid="newstest2013.tok.bpe.32000",
test='newstest2014.tok.bpe.32000.ende', test="newstest2014.tok.bpe.32000.ende",
vocab='vocab.bpe.32000') vocab="vocab.bpe.32000",
)
else: else:
raise KeyError() raise KeyError()
class Vocab: class Vocab:
def __init__(self, init_token=None, eos_token=None, pad_token=None, unk_token=None): def __init__(
self, init_token=None, eos_token=None, pad_token=None, unk_token=None
):
self.init_token = init_token self.init_token = init_token
self.eos_token = eos_token self.eos_token = eos_token
self.pad_token = pad_token self.pad_token = pad_token
...@@ -16,13 +18,11 @@ class Vocab: ...@@ -16,13 +18,11 @@ class Vocab:
self.vocab_lst.append(self.pad_token) self.vocab_lst.append(self.pad_token)
if self.unk_token is not None: if self.unk_token is not None:
self.vocab_lst.append(self.unk_token) self.vocab_lst.append(self.unk_token)
with open(path, 'r', encoding='utf-8') as f: with open(path, "r", encoding="utf-8") as f:
for token in f.readlines(): for token in f.readlines():
token = token.strip() token = token.strip()
self.vocab_lst.append(token) self.vocab_lst.append(token)
self.vocab_dict = { self.vocab_dict = {v: k for k, v in enumerate(self.vocab_lst)}
v: k for k, v in enumerate(self.vocab_lst)
}
def __len__(self): def __len__(self):
return len(self.vocab_lst) return len(self.vocab_lst)
...@@ -36,6 +36,7 @@ class Vocab: ...@@ -36,6 +36,7 @@ class Vocab:
else: else:
return self.vocab_lst[key] return self.vocab_lst[key]
class Field: class Field:
def __init__(self, vocab, preprocessing=None, postprocessing=None): def __init__(self, vocab, preprocessing=None, postprocessing=None):
self.vocab = vocab self.vocab = vocab
...@@ -56,8 +57,4 @@ class Field: ...@@ -56,8 +57,4 @@ class Field:
return [self.vocab[token] for token in x] return [self.vocab[token] for token in x]
def __call__(self, x): def __call__(self, x):
return self.postprocess( return self.postprocess(self.numericalize(self.preprocess(x)))
self.numericalize(
self.preprocess(x)
)
)
import dgl
import torch as th
import numpy as np
import itertools import itertools
import time import time
from collections import * from collections import *
Graph = namedtuple('Graph', import numpy as np
['g', 'src', 'tgt', 'tgt_y', 'nids', 'eids', 'nid_arr', 'n_nodes', 'n_edges', 'n_tokens']) import torch as th
import dgl
Graph = namedtuple(
"Graph",
[
"g",
"src",
"tgt",
"tgt_y",
"nids",
"eids",
"nid_arr",
"n_nodes",
"n_edges",
"n_tokens",
],
)
class GraphPool: class GraphPool:
"Create a graph pool in advance to accelerate graph building phase in Transformer." "Create a graph pool in advance to accelerate graph building phase in Transformer."
def __init__(self, n=50, m=50): def __init__(self, n=50, m=50):
''' """
args: args:
n: maximum length of input sequence. n: maximum length of input sequence.
m: maximum length of output sequence. m: maximum length of output sequence.
''' """
print('start creating graph pool...') print("start creating graph pool...")
tic = time.time() tic = time.time()
self.n, self.m = n, m self.n, self.m = n, m
g_pool = [[dgl.graph(([], [])) for _ in range(m)] for _ in range(n)] g_pool = [[dgl.graph(([], [])) for _ in range(m)] for _ in range(n)]
num_edges = { num_edges = {
'ee': np.zeros((n, n)).astype(int), "ee": np.zeros((n, n)).astype(int),
'ed': np.zeros((n, m)).astype(int), "ed": np.zeros((n, m)).astype(int),
'dd': np.zeros((m, m)).astype(int) "dd": np.zeros((m, m)).astype(int),
} }
for i, j in itertools.product(range(n), range(m)): for i, j in itertools.product(range(n), range(m)):
src_length = i + 1 src_length = i + 1
...@@ -37,42 +54,46 @@ class GraphPool: ...@@ -37,42 +54,46 @@ class GraphPool:
us = enc_nodes.unsqueeze(-1).repeat(1, src_length).view(-1) us = enc_nodes.unsqueeze(-1).repeat(1, src_length).view(-1)
vs = enc_nodes.repeat(src_length) vs = enc_nodes.repeat(src_length)
g_pool[i][j].add_edges(us, vs) g_pool[i][j].add_edges(us, vs)
num_edges['ee'][i][j] = len(us) num_edges["ee"][i][j] = len(us)
# enc -> dec # enc -> dec
us = enc_nodes.unsqueeze(-1).repeat(1, tgt_length).view(-1) us = enc_nodes.unsqueeze(-1).repeat(1, tgt_length).view(-1)
vs = dec_nodes.repeat(src_length) vs = dec_nodes.repeat(src_length)
g_pool[i][j].add_edges(us, vs) g_pool[i][j].add_edges(us, vs)
num_edges['ed'][i][j] = len(us) num_edges["ed"][i][j] = len(us)
# dec -> dec # dec -> dec
indices = th.triu(th.ones(tgt_length, tgt_length)) == 1 indices = th.triu(th.ones(tgt_length, tgt_length)) == 1
us = dec_nodes.unsqueeze(-1).repeat(1, tgt_length)[indices] us = dec_nodes.unsqueeze(-1).repeat(1, tgt_length)[indices]
vs = dec_nodes.unsqueeze(0).repeat(tgt_length, 1)[indices] vs = dec_nodes.unsqueeze(0).repeat(tgt_length, 1)[indices]
g_pool[i][j].add_edges(us, vs) g_pool[i][j].add_edges(us, vs)
num_edges['dd'][i][j] = len(us) num_edges["dd"][i][j] = len(us)
print('successfully created graph pool, time: {0:0.3f}s'.format(time.time() - tic)) print(
"successfully created graph pool, time: {0:0.3f}s".format(
time.time() - tic
)
)
self.g_pool = g_pool self.g_pool = g_pool
self.num_edges = num_edges self.num_edges = num_edges
def beam(self, src_buf, start_sym, max_len, k, device='cpu'): def beam(self, src_buf, start_sym, max_len, k, device="cpu"):
''' """
Return a batched graph for beam search during inference of Transformer. Return a batched graph for beam search during inference of Transformer.
args: args:
src_buf: a list of input sequence src_buf: a list of input sequence
start_sym: the index of start-of-sequence symbol start_sym: the index of start-of-sequence symbol
max_len: maximum length for decoding max_len: maximum length for decoding
k: beam size k: beam size
device: 'cpu' or 'cuda:*' device: 'cpu' or 'cuda:*'
''' """
g_list = [] g_list = []
src_lens = [len(_) for _ in src_buf] src_lens = [len(_) for _ in src_buf]
tgt_lens = [max_len] * len(src_buf) tgt_lens = [max_len] * len(src_buf)
num_edges = {'ee': [], 'ed': [], 'dd': []} num_edges = {"ee": [], "ed": [], "dd": []}
for src_len, tgt_len in zip(src_lens, tgt_lens): for src_len, tgt_len in zip(src_lens, tgt_lens):
i, j = src_len - 1, tgt_len - 1 i, j = src_len - 1, tgt_len - 1
for _ in range(k): for _ in range(k):
g_list.append(self.g_pool[i][j]) g_list.append(self.g_pool[i][j])
for key in ['ee', 'ed', 'dd']: for key in ["ee", "ed", "dd"]:
num_edges[key].append(int(self.num_edges[key][i][j])) num_edges[key].append(int(self.num_edges[key][i][j]))
g = dgl.batch(g_list) g = dgl.batch(g_list)
...@@ -81,57 +102,85 @@ class GraphPool: ...@@ -81,57 +102,85 @@ class GraphPool:
enc_ids, dec_ids = [], [] enc_ids, dec_ids = [], []
e2e_eids, e2d_eids, d2d_eids = [], [], [] e2e_eids, e2d_eids, d2d_eids = [], [], []
n_nodes, n_edges, n_tokens = 0, 0, 0 n_nodes, n_edges, n_tokens = 0, 0, 0
for src_sample, n, n_ee, n_ed, n_dd in zip(src_buf, src_lens, num_edges['ee'], num_edges['ed'], num_edges['dd']): for src_sample, n, n_ee, n_ed, n_dd in zip(
src_buf, src_lens, num_edges["ee"], num_edges["ed"], num_edges["dd"]
):
for _ in range(k): for _ in range(k):
src.append(th.tensor(src_sample, dtype=th.long, device=device)) src.append(th.tensor(src_sample, dtype=th.long, device=device))
src_pos.append(th.arange(n, dtype=th.long, device=device)) src_pos.append(th.arange(n, dtype=th.long, device=device))
enc_ids.append(th.arange(n_nodes, n_nodes + n, dtype=th.long, device=device)) enc_ids.append(
th.arange(
n_nodes, n_nodes + n, dtype=th.long, device=device
)
)
n_nodes += n n_nodes += n
e2e_eids.append(th.arange(n_edges, n_edges + n_ee, dtype=th.long, device=device)) e2e_eids.append(
th.arange(
n_edges, n_edges + n_ee, dtype=th.long, device=device
)
)
n_edges += n_ee n_edges += n_ee
tgt_seq = th.zeros(max_len, dtype=th.long, device=device) tgt_seq = th.zeros(max_len, dtype=th.long, device=device)
tgt_seq[0] = start_sym tgt_seq[0] = start_sym
tgt.append(tgt_seq) tgt.append(tgt_seq)
tgt_pos.append(th.arange(max_len, dtype=th.long, device=device)) tgt_pos.append(th.arange(max_len, dtype=th.long, device=device))
dec_ids.append(th.arange(n_nodes, n_nodes + max_len, dtype=th.long, device=device)) dec_ids.append(
th.arange(
n_nodes, n_nodes + max_len, dtype=th.long, device=device
)
)
n_nodes += max_len n_nodes += max_len
e2d_eids.append(th.arange(n_edges, n_edges + n_ed, dtype=th.long, device=device)) e2d_eids.append(
th.arange(
n_edges, n_edges + n_ed, dtype=th.long, device=device
)
)
n_edges += n_ed n_edges += n_ed
d2d_eids.append(th.arange(n_edges, n_edges + n_dd, dtype=th.long, device=device)) d2d_eids.append(
th.arange(
n_edges, n_edges + n_dd, dtype=th.long, device=device
)
)
n_edges += n_dd n_edges += n_dd
g.set_n_initializer(dgl.init.zero_initializer) g.set_n_initializer(dgl.init.zero_initializer)
g.set_e_initializer(dgl.init.zero_initializer) g.set_e_initializer(dgl.init.zero_initializer)
g = g.to(device).long() g = g.to(device).long()
return Graph(g=g, return Graph(
src=(th.cat(src), th.cat(src_pos)), g=g,
tgt=(th.cat(tgt), th.cat(tgt_pos)), src=(th.cat(src), th.cat(src_pos)),
tgt_y=None, tgt=(th.cat(tgt), th.cat(tgt_pos)),
nids = {'enc': th.cat(enc_ids), 'dec': th.cat(dec_ids)}, tgt_y=None,
eids = {'ee': th.cat(e2e_eids), 'ed': th.cat(e2d_eids), 'dd': th.cat(d2d_eids)}, nids={"enc": th.cat(enc_ids), "dec": th.cat(dec_ids)},
nid_arr = {'enc': enc_ids, 'dec': dec_ids}, eids={
n_nodes=n_nodes, "ee": th.cat(e2e_eids),
n_edges=n_edges, "ed": th.cat(e2d_eids),
n_tokens=n_tokens) "dd": th.cat(d2d_eids),
},
def __call__(self, src_buf, tgt_buf, device='cpu'): nid_arr={"enc": enc_ids, "dec": dec_ids},
''' n_nodes=n_nodes,
n_edges=n_edges,
n_tokens=n_tokens,
)
def __call__(self, src_buf, tgt_buf, device="cpu"):
"""
Return a batched graph for the training phase of Transformer. Return a batched graph for the training phase of Transformer.
args: args:
src_buf: a set of input sequence arrays. src_buf: a set of input sequence arrays.
tgt_buf: a set of output sequence arrays. tgt_buf: a set of output sequence arrays.
device: 'cpu' or 'cuda:*' device: 'cpu' or 'cuda:*'
''' """
g_list = [] g_list = []
src_lens = [len(_) for _ in src_buf] src_lens = [len(_) for _ in src_buf]
tgt_lens = [len(_) - 1 for _ in tgt_buf] tgt_lens = [len(_) - 1 for _ in tgt_buf]
num_edges = {'ee': [], 'ed': [], 'dd': []} num_edges = {"ee": [], "ed": [], "dd": []}
for src_len, tgt_len in zip(src_lens, tgt_lens): for src_len, tgt_len in zip(src_lens, tgt_lens):
i, j = src_len - 1, tgt_len - 1 i, j = src_len - 1, tgt_len - 1
g_list.append(self.g_pool[i][j]) g_list.append(self.g_pool[i][j])
for key in ['ee', 'ed', 'dd']: for key in ["ee", "ed", "dd"]:
num_edges[key].append(int(self.num_edges[key][i][j])) num_edges[key].append(int(self.num_edges[key][i][j]))
g = dgl.batch(g_list) g = dgl.batch(g_list)
...@@ -140,36 +189,61 @@ class GraphPool: ...@@ -140,36 +189,61 @@ class GraphPool:
enc_ids, dec_ids = [], [] enc_ids, dec_ids = [], []
e2e_eids, d2d_eids, e2d_eids = [], [], [] e2e_eids, d2d_eids, e2d_eids = [], [], []
n_nodes, n_edges, n_tokens = 0, 0, 0 n_nodes, n_edges, n_tokens = 0, 0, 0
for src_sample, tgt_sample, n, m, n_ee, n_ed, n_dd in zip(src_buf, tgt_buf, src_lens, tgt_lens, num_edges['ee'], num_edges['ed'], num_edges['dd']): for src_sample, tgt_sample, n, m, n_ee, n_ed, n_dd in zip(
src_buf,
tgt_buf,
src_lens,
tgt_lens,
num_edges["ee"],
num_edges["ed"],
num_edges["dd"],
):
src.append(th.tensor(src_sample, dtype=th.long, device=device)) src.append(th.tensor(src_sample, dtype=th.long, device=device))
tgt.append(th.tensor(tgt_sample[:-1], dtype=th.long, device=device)) tgt.append(th.tensor(tgt_sample[:-1], dtype=th.long, device=device))
tgt_y.append(th.tensor(tgt_sample[1:], dtype=th.long, device=device)) tgt_y.append(
th.tensor(tgt_sample[1:], dtype=th.long, device=device)
)
src_pos.append(th.arange(n, dtype=th.long, device=device)) src_pos.append(th.arange(n, dtype=th.long, device=device))
tgt_pos.append(th.arange(m, dtype=th.long, device=device)) tgt_pos.append(th.arange(m, dtype=th.long, device=device))
enc_ids.append(th.arange(n_nodes, n_nodes + n, dtype=th.long, device=device)) enc_ids.append(
th.arange(n_nodes, n_nodes + n, dtype=th.long, device=device)
)
n_nodes += n n_nodes += n
dec_ids.append(th.arange(n_nodes, n_nodes + m, dtype=th.long, device=device)) dec_ids.append(
th.arange(n_nodes, n_nodes + m, dtype=th.long, device=device)
)
n_nodes += m n_nodes += m
e2e_eids.append(th.arange(n_edges, n_edges + n_ee, dtype=th.long, device=device)) e2e_eids.append(
th.arange(n_edges, n_edges + n_ee, dtype=th.long, device=device)
)
n_edges += n_ee n_edges += n_ee
e2d_eids.append(th.arange(n_edges, n_edges + n_ed, dtype=th.long, device=device)) e2d_eids.append(
th.arange(n_edges, n_edges + n_ed, dtype=th.long, device=device)
)
n_edges += n_ed n_edges += n_ed
d2d_eids.append(th.arange(n_edges, n_edges + n_dd, dtype=th.long, device=device)) d2d_eids.append(
th.arange(n_edges, n_edges + n_dd, dtype=th.long, device=device)
)
n_edges += n_dd n_edges += n_dd
n_tokens += m n_tokens += m
g.set_n_initializer(dgl.init.zero_initializer) g.set_n_initializer(dgl.init.zero_initializer)
g.set_e_initializer(dgl.init.zero_initializer) g.set_e_initializer(dgl.init.zero_initializer)
g = g.to(device).long() g = g.to(device).long()
return Graph(g=g, return Graph(
src=(th.cat(src), th.cat(src_pos)), g=g,
tgt=(th.cat(tgt), th.cat(tgt_pos)), src=(th.cat(src), th.cat(src_pos)),
tgt_y=th.cat(tgt_y), tgt=(th.cat(tgt), th.cat(tgt_pos)),
nids = {'enc': th.cat(enc_ids), 'dec': th.cat(dec_ids)}, tgt_y=th.cat(tgt_y),
eids = {'ee': th.cat(e2e_eids), 'ed': th.cat(e2d_eids), 'dd': th.cat(d2d_eids)}, nids={"enc": th.cat(enc_ids), "dec": th.cat(dec_ids)},
nid_arr = {'enc': enc_ids, 'dec': dec_ids}, eids={
n_nodes=n_nodes, "ee": th.cat(e2e_eids),
n_edges=n_edges, "ed": th.cat(e2d_eids),
n_tokens=n_tokens) "dd": th.cat(d2d_eids),
},
nid_arr={"enc": enc_ids, "dec": dec_ids},
n_nodes=n_nodes,
n_edges=n_edges,
n_tokens=n_tokens,
)
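A small usage sketch for the pool: sequences are lists of token ids, and the pool sizes n and m must be at least the longest source/target lengths (the ids below are arbitrary placeholders).

pool = GraphPool(n=20, m=20)
src_buf = [[3, 4, 5], [6, 7]]
tgt_buf = [[1, 8, 9, 2], [1, 10, 2]]  # e.g. <sos> tokens ... <eos>
train_batch = pool(src_buf, tgt_buf, device="cpu")  # training-phase Graph namedtuple
beam_batch = pool.beam(src_buf, start_sym=1, max_len=10, k=2, device="cpu")  # inference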
import os
import numpy as np import numpy as np
import torch as th import torch as th
import os
from dgl.data.utils import * from dgl.data.utils import *
_urls = { _urls = {
'wmt': 'https://data.dgl.ai/dataset/wmt14bpe_de_en.zip', "wmt": "https://data.dgl.ai/dataset/wmt14bpe_de_en.zip",
'scripts': 'https://data.dgl.ai/dataset/transformer_scripts.zip', "scripts": "https://data.dgl.ai/dataset/transformer_scripts.zip",
} }
def prepare_dataset(dataset_name): def prepare_dataset(dataset_name):
"download and generate datasets" "download and generate datasets"
script_dir = os.path.join('scripts') script_dir = os.path.join("scripts")
if not os.path.exists(script_dir): if not os.path.exists(script_dir):
download(_urls['scripts'], path='scripts.zip') download(_urls["scripts"], path="scripts.zip")
extract_archive("scripts.zip", "scripts")
directory = os.path.join("data", dataset_name)
if not os.path.exists(directory):
    os.makedirs(directory)
else:
    return
if dataset_name == "multi30k":
    os.system("bash scripts/prepare-multi30k.sh")
elif dataset_name == "wmt14":
    download(_urls["wmt"], path="wmt14.zip")
    os.system("bash scripts/prepare-wmt14.sh")
elif dataset_name == "copy" or dataset_name == "tiny_copy":
    train_size = 9000
    valid_size = 1000
    test_size = 1000
    char_list = [chr(i) for i in range(ord("a"), ord("z") + 1)]
    with open(os.path.join(directory, "train.in"), "w") as f_in, open(
        os.path.join(directory, "train.out"), "w"
    ) as f_out:
        for i, l in zip(
            range(train_size),
            np.random.normal(15, 3, train_size).astype(int),
        ):
            l = max(l, 1)
            line = " ".join(np.random.choice(char_list, l)) + "\n"
            f_in.write(line)
            f_out.write(line)
    with open(os.path.join(directory, "valid.in"), "w") as f_in, open(
        os.path.join(directory, "valid.out"), "w"
    ) as f_out:
        for i, l in zip(
            range(valid_size),
            np.random.normal(15, 3, valid_size).astype(int),
        ):
            l = max(l, 1)
            line = " ".join(np.random.choice(char_list, l)) + "\n"
            f_in.write(line)
            f_out.write(line)
    with open(os.path.join(directory, "test.in"), "w") as f_in, open(
        os.path.join(directory, "test.out"), "w"
    ) as f_out:
        for i, l in zip(
            range(test_size), np.random.normal(15, 3, test_size).astype(int)
        ):
            l = max(l, 1)
            line = " ".join(np.random.choice(char_list, l)) + "\n"
            f_in.write(line)
            f_out.write(line)
    with open(os.path.join(directory, "vocab.txt"), "w") as f:
        for c in char_list:
            f.write(c + "\n")
elif dataset_name == "sort" or dataset_name == "tiny_sort":
    train_size = 9000
    valid_size = 1000
    test_size = 1000
    char_list = [chr(i) for i in range(ord("a"), ord("z") + 1)]
    with open(os.path.join(directory, "train.in"), "w") as f_in, open(
        os.path.join(directory, "train.out"), "w"
    ) as f_out:
        for i, l in zip(
            range(train_size),
            np.random.normal(15, 3, train_size).astype(int),
        ):
            l = max(l, 1)
            seq = np.random.choice(char_list, l)
            f_in.write(" ".join(seq) + "\n")
            f_out.write(" ".join(np.sort(seq)) + "\n")
    with open(os.path.join(directory, "valid.in"), "w") as f_in, open(
        os.path.join(directory, "valid.out"), "w"
    ) as f_out:
        for i, l in zip(
            range(valid_size),
            np.random.normal(15, 3, valid_size).astype(int),
        ):
            l = max(l, 1)
            seq = np.random.choice(char_list, l)
            f_in.write(" ".join(seq) + "\n")
            f_out.write(" ".join(np.sort(seq)) + "\n")
    with open(os.path.join(directory, "test.in"), "w") as f_in, open(
        os.path.join(directory, "test.out"), "w"
    ) as f_out:
        for i, l in zip(
            range(test_size), np.random.normal(15, 3, test_size).astype(int)
        ):
            l = max(l, 1)
            seq = np.random.choice(char_list, l)
            f_in.write(" ".join(seq) + "\n")
            f_out.write(" ".join(np.sort(seq)) + "\n")
    with open(os.path.join(directory, "vocab.txt"), "w") as f:
        for c in char_list:
            f.write(c + "\n")
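# --- Illustrative sketch (not part of this commit) ---
# The synthetic "copy"/"sort" datasets above write one space-separated character
# sequence per line; lengths are drawn from N(15, 3) and clipped to at least 1.
# A minimal, self-contained reproduction of one sample pair:
import numpy as np

length = max(int(np.random.normal(15, 3)), 1)             # sampled sequence length
chars = [chr(i) for i in range(ord("a"), ord("z") + 1)]   # vocabulary a..z
seq = np.random.choice(chars, length)
copy_src, copy_tgt = " ".join(seq), " ".join(seq)          # copy task: target equals source
sort_src, sort_tgt = " ".join(seq), " ".join(np.sort(seq)) # sort task: target is the sorted source
print(copy_src, "->", copy_tgt)
print(sort_src, "->", sort_tgt)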
import torch as T
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
class LabelSmoothing(nn.Module):
    """
    Compute loss at one time step.
    """

    def __init__(self, size, padding_idx, smoothing=0.0):
        """Label Smoothing module
        args:
@@ -15,7 +17,7 @@ class LabelSmoothing(nn.Module):
            smoothing: smoothing ratio
        """
        super(LabelSmoothing, self).__init__()
        self.criterion = nn.KLDivLoss(reduction="sum")
        self.size = size
        self.padding_idx = padding_idx
        self.smoothing = smoothing
@@ -26,7 +28,9 @@ class LabelSmoothing(nn.Module):
        assert x.size(1) == self.size
        with T.no_grad():
            tgt_dist = T.zeros_like(x, dtype=T.float)
            tgt_dist.fill_(
                self.smoothing / (self.size - 2)
            )  # one for padding, another for label
            tgt_dist[:, self.padding_idx] = 0
            tgt_dist.scatter_(1, target.unsqueeze(1), 1 - self.smoothing)
@@ -36,8 +40,10 @@ class LabelSmoothing(nn.Module):
        return self.criterion(x, tgt_dist)
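# --- Illustrative sketch (not part of this commit) ---
# How the smoothed target distribution built above looks for a tiny vocabulary:
# size=5, padding_idx=0, smoothing=0.1, gold label=3. The gold class keeps
# 1 - smoothing of the mass, the remaining 0.1 is spread over the size - 2
# non-gold, non-padding classes, and the padding column is zeroed.
import torch

size, padding_idx, smoothing = 5, 0, 0.1
target = torch.tensor([3])
x = torch.log_softmax(torch.randn(1, size), dim=-1)  # model log-probabilities (assumed input)

tgt_dist = torch.zeros_like(x)
tgt_dist.fill_(smoothing / (size - 2))
tgt_dist[:, padding_idx] = 0
tgt_dist.scatter_(1, target.unsqueeze(1), 1 - smoothing)
print(tgt_dist)  # e.g. tensor([[0.0000, 0.0333, 0.0333, 0.9000, 0.0333]])
print(torch.nn.KLDivLoss(reduction="sum")(x, tgt_dist))  # same loss form as above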
class SimpleLossCompute(nn.Module):
    eps = 1e-8

    def __init__(self, criterion, grad_accum, opt=None):
        """Loss function and optimizer for single device
@@ -96,11 +102,16 @@ class SimpleLossCompute(nn.Module):
        self.loss = self.criterion(y_pred, y) / norm
        if self.opt is not None:
            self.backward_and_step()
        self.n_correct += (
            ((y_pred.max(dim=-1)[1] == y) & (y != self.criterion.padding_idx))
            .sum()
            .item()
        )
        self.acc_loss += self.loss.item() * norm
        self.norm_term += norm
        return self.loss.item() * norm
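# --- Illustrative sketch (not part of this commit) ---
# The n_correct update above counts correct token predictions while excluding
# padding positions. A small, self-contained reproduction of that counting:
import torch

padding_idx = 0
y = torch.tensor([3, 1, 0, 2])       # gold labels; 0 is padding
y_pred = torch.tensor([              # per-token scores over 4 classes
    [0.1, 0.2, 0.3, 2.0],
    [0.1, 1.5, 0.3, 0.2],
    [2.0, 0.1, 0.1, 0.1],
    [0.1, 0.2, 1.8, 0.3],
])
n_correct = (
    ((y_pred.max(dim=-1)[1] == y) & (y != padding_idx)).sum().item()
)
print(n_correct)  # 3: the padded position is excluded even though its argmax matches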
class MultiGPULossCompute(SimpleLossCompute):
    def __init__(self, criterion, ndev, grad_accum, model, opt=None):
        """Loss function and optimizer for multiple devices
@@ -119,7 +130,9 @@ class MultiGPULossCompute(SimpleLossCompute):
            Model optimizer to use. If None, then no backward and update will be
            performed
        """
        super(MultiGPULossCompute, self).__init__(
            criterion, grad_accum, opt=opt
        )
        self.ndev = ndev
        self.model = model
...
import numpy as np
import torch as th
import torch.nn as nn

from .layers import clones


class MultiHeadAttention(nn.Module):
    "Multi-Head Attention"

    def __init__(self, h, dim_model):
        "h: number of heads; dim_model: hidden dimension"
        super(MultiHeadAttention, self).__init__()
        self.d_k = dim_model // h
        self.h = h
        # W_q, W_k, W_v, W_o
        self.linears = clones(nn.Linear(dim_model, dim_model, bias=False), 4)

    def get(self, x, fields="qkv"):
        "Return a dict of queries / keys / values."
        batch_size = x.shape[0]
        ret = {}
        if "q" in fields:
            ret["q"] = self.linears[0](x).view(batch_size, self.h, self.d_k)
        if "k" in fields:
            ret["k"] = self.linears[1](x).view(batch_size, self.h, self.d_k)
        if "v" in fields:
            ret["v"] = self.linears[2](x).view(batch_size, self.h, self.d_k)
        return ret

    def get_o(self, x):
...
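# --- Illustrative sketch (not part of this commit) ---
# get() above projects per-token features of shape (N, dim_model) with separate
# linear layers and reshapes each result to (N, h, d_k), giving every head its
# own d_k-dimensional slice. A self-contained shape check, with clones()
# replaced by a plain ModuleList for illustration:
import torch
import torch.nn as nn

h, dim_model = 8, 512
d_k = dim_model // h
linears = nn.ModuleList(
    nn.Linear(dim_model, dim_model, bias=False) for _ in range(4)
)  # stand-in for clones(nn.Linear(...), 4): W_q, W_k, W_v, W_o
x = torch.randn(10, dim_model)              # 10 tokens
q = linears[0](x).view(x.shape[0], h, d_k)
print(q.shape)                              # torch.Size([10, 8, 64])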