Unverified Commit be8763fa authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Black auto fix. (#4679)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent eae6ce2a
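Note: the hunks below are mechanical rewrites produced by the Black auto-formatter; no behavioral changes are intended. As a minimal illustrative sketch (assuming a Black configuration in the repository that is not shown in this diff), a hand-wrapped argparse call like the ones on the removed side of the hunks is rewritten as follows:

import argparse

parser = argparse.ArgumentParser(description="example")

# Pre-Black style (hand-wrapped continuation line, as on the removed side of the hunks):
# parser.add_argument("--n-epochs", type=int, default=2,
#                     help="number of training epochs")

# Black-formatted equivalent: arguments are pulled inside the parentheses,
# and string quotes are normalized to double quotes.
parser.add_argument(
    "--n-epochs", type=int, default=2, help="number of training epochs"
)

Running a command such as "black ." from the repository root with the project's configured settings reproduces this kind of output. The alphabetical import reordering visible in some hunks suggests an import sorter was applied as well, although the commit message only mentions Black.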
......@@ -5,18 +5,24 @@ Paper: https://arxiv.org/abs/1902.07153
Code: https://github.com/Tiiiger/SGC
SGC implementation in DGL.
"""
import argparse, time, math
import argparse
import math
import time
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
from dgl.data import load_data, register_data_args
from dgl.nn.pytorch.conv import SGConv
def normalize(h):
return (h-h.mean(0))/h.std(0)
return (h - h.mean(0)) / h.std(0)
def evaluate(model, features, graph, labels, mask):
model.eval()
......@@ -27,6 +33,7 @@ def evaluate(model, features, graph, labels, mask):
correct = torch.sum(indices == labels)
return correct.item() * 1.0 / len(labels)
def main(args):
# load and preprocess dataset
args.dataset = "reddit-self-loop"
......@@ -38,24 +45,29 @@ def main(args):
cuda = True
g = g.int().to(args.gpu)
features = g.ndata['feat']
labels = g.ndata['label']
train_mask = g.ndata['train_mask']
val_mask = g.ndata['val_mask']
test_mask = g.ndata['test_mask']
features = g.ndata["feat"]
labels = g.ndata["label"]
train_mask = g.ndata["train_mask"]
val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"]
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = g.number_of_edges()
print("""----Data statistics------'
print(
"""----Data statistics------'
#Edges %d
#Classes %d
#Train samples %d
#Val samples %d
#Test samples %d""" %
(n_edges, n_classes,
g.ndata['train_mask'].int().sum().item(),
g.ndata['val_mask'].int().sum().item(),
g.ndata['test_mask'].int().sum().item()))
#Test samples %d"""
% (
n_edges,
n_classes,
g.ndata["train_mask"].int().sum().item(),
g.ndata["val_mask"].int().sum().item(),
g.ndata["test_mask"].int().sum().item(),
)
)
# graph preprocess and calculate normalization factor
n_edges = g.number_of_edges()
......@@ -63,10 +75,12 @@ def main(args):
degs = g.in_degrees().float()
norm = torch.pow(degs, -0.5)
norm[torch.isinf(norm)] = 0
g.ndata['norm'] = norm.unsqueeze(1)
g.ndata["norm"] = norm.unsqueeze(1)
# create SGC model
model = SGConv(in_feats, n_classes, k=2, cached=True, bias=True, norm=normalize)
model = SGConv(
in_feats, n_classes, k=2, cached=True, bias=True, norm=normalize
)
if args.gpu >= 0:
model = model.cuda()
......@@ -90,15 +104,16 @@ def main(args):
print("Test Accuracy {:.4f}".format(acc))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='SGC')
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="SGC")
register_data_args(parser)
parser.add_argument("--gpu", type=int, default=-1,
help="gpu")
parser.add_argument("--bias", action='store_true', default=False,
help="flag to use bias")
parser.add_argument("--n-epochs", type=int, default=2,
help="number of training epochs")
parser.add_argument("--gpu", type=int, default=-1, help="gpu")
parser.add_argument(
"--bias", action="store_true", default=False, help="flag to use bias"
)
parser.add_argument(
"--n-epochs", type=int, default=2, help="number of training epochs"
)
args = parser.parse_args()
print(args)
......
import torch
import numpy as np
import torch
import dgl
......@@ -7,6 +8,7 @@ def load_dataset(name):
dataset = name.lower()
if dataset == "amazon":
from ogb.nodeproppred.dataset_dgl import DglNodePropPredDataset
dataset = DglNodePropPredDataset(name="ogbn-products")
splitted_idx = dataset.get_idx_split()
train_nid = splitted_idx["train"]
......@@ -14,26 +16,28 @@ def load_dataset(name):
test_nid = splitted_idx["test"]
g, labels = dataset[0]
n_classes = int(labels.max() - labels.min() + 1)
g.ndata['label'] = labels.squeeze()
g.ndata['feat'] = g.ndata['feat'].float()
g.ndata["label"] = labels.squeeze()
g.ndata["feat"] = g.ndata["feat"].float()
elif dataset in ["reddit", "cora"]:
if dataset == "reddit":
from dgl.data import RedditDataset
data = RedditDataset(self_loop=True)
g = data[0]
else:
from dgl.data import CitationGraphDataset
data = CitationGraphDataset('cora')
data = CitationGraphDataset("cora")
g = data[0]
n_classes = data.num_labels
train_mask = g.ndata['train_mask']
val_mask = g.ndata['val_mask']
test_mask = g.ndata['test_mask']
train_mask = g.ndata["train_mask"]
val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"]
train_nid = torch.LongTensor(train_mask.nonzero().squeeze())
val_nid = torch.LongTensor(val_mask.nonzero().squeeze())
test_nid = torch.LongTensor(test_mask.nonzero().squeeze())
else:
print("Dataset {} is not supported".format(name))
assert(0)
assert 0
return g, n_classes, train_nid, val_nid, test_nid
import argparse
import os
import time
import torch
import torch.nn as nn
import torch.nn.functional as F
from dataset import load_dataset
import dgl
import dgl.function as fn
from dataset import load_dataset
class FeedForwardNet(nn.Module):
......@@ -48,10 +50,12 @@ class Model(nn.Module):
self.inception_ffs = nn.ModuleList()
for hop in range(R + 1):
self.inception_ffs.append(
FeedForwardNet(in_feats, hidden, hidden, n_layers, dropout))
FeedForwardNet(in_feats, hidden, hidden, n_layers, dropout)
)
# self.linear = nn.Linear(hidden * (R + 1), out_feats)
self.project = FeedForwardNet((R + 1) * hidden, hidden, out_feats,
n_layers, dropout)
self.project = FeedForwardNet(
(R + 1) * hidden, hidden, out_feats, n_layers, dropout
)
def forward(self, feats):
hidden = []
......@@ -84,8 +88,10 @@ def preprocess(g, features, args):
g.edata["weight"] = calc_weight(g)
g.ndata["feat_0"] = features
for hop in range(1, args.R + 1):
g.update_all(fn.u_mul_e(f"feat_{hop-1}", "weight", "msg"),
fn.sum("msg", f"feat_{hop}"))
g.update_all(
fn.u_mul_e(f"feat_{hop-1}", "weight", "msg"),
fn.sum("msg", f"feat_{hop}"),
)
res = []
for hop in range(args.R + 1):
res.append(g.ndata.pop(f"feat_{hop}"))
......@@ -96,17 +102,26 @@ def prepare_data(device, args):
data = load_dataset(args.dataset)
g, n_classes, train_nid, val_nid, test_nid = data
g = g.to(device)
in_feats = g.ndata['feat'].shape[1]
feats = preprocess(g, g.ndata['feat'], args)
labels = g.ndata['label']
in_feats = g.ndata["feat"].shape[1]
feats = preprocess(g, g.ndata["feat"], args)
labels = g.ndata["label"]
# move to device
train_nid = train_nid.to(device)
val_nid = val_nid.to(device)
test_nid = test_nid.to(device)
train_feats = [x[train_nid] for x in feats]
train_labels = labels[train_nid]
return feats, labels, train_feats, train_labels, in_feats, \
n_classes, train_nid, val_nid, test_nid
return (
feats,
labels,
train_feats,
train_labels,
in_feats,
n_classes,
train_nid,
val_nid,
test_nid,
)
def evaluate(epoch, args, model, feats, labels, train, val, test):
......@@ -121,7 +136,7 @@ def evaluate(epoch, args, model, feats, labels, train, val, test):
for i in range(n_batch):
batch_start = i * batch_size
batch_end = min((i + 1) * batch_size, num_nodes)
batch_feats = [feat[batch_start: batch_end] for feat in feats]
batch_feats = [feat[batch_start:batch_end] for feat in feats]
pred.append(model(batch_feats))
pred = torch.cat(pred)
......@@ -140,15 +155,31 @@ def main(args):
device = "cuda:{}".format(args.gpu)
data = prepare_data(device, args)
feats, labels, train_feats, train_labels, in_size, num_classes, \
train_nid, val_nid, test_nid = data
model = Model(in_size, args.num_hidden, num_classes, args.R, args.ff_layer,
args.dropout)
(
feats,
labels,
train_feats,
train_labels,
in_size,
num_classes,
train_nid,
val_nid,
test_nid,
) = data
model = Model(
in_size,
args.num_hidden,
num_classes,
args.R,
args.ff_layer,
args.dropout,
)
model = model.to(device)
loss_fcn = nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=args.lr,
weight_decay=args.weight_decay)
optimizer = torch.optim.Adam(
model.parameters(), lr=args.lr, weight_decay=args.weight_decay
)
best_epoch = 0
best_val = 0
......@@ -164,38 +195,47 @@ def main(args):
if epoch % args.eval_every == 0:
model.eval()
acc = evaluate(epoch, args, model, feats, labels,
train_nid, val_nid, test_nid)
acc = evaluate(
epoch, args, model, feats, labels, train_nid, val_nid, test_nid
)
end = time.time()
log = "Epoch {}, Times(s): {:.4f}".format(epoch, end - start)
log += ", Accuracy: Train {:.4f}, Val {:.4f}, Test {:.4f}" \
.format(*acc)
log += ", Accuracy: Train {:.4f}, Val {:.4f}, Test {:.4f}".format(
*acc
)
print(log)
if acc[1] > best_val:
best_val = acc[1]
best_epoch = epoch
best_test = acc[2]
print("Best Epoch {}, Val {:.4f}, Test {:.4f}".format(
best_epoch, best_val, best_test))
print(
"Best Epoch {}, Val {:.4f}, Test {:.4f}".format(
best_epoch, best_val, best_test
)
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="SIGN")
parser.add_argument("--num-epochs", type=int, default=1000)
parser.add_argument("--num-hidden", type=int, default=256)
parser.add_argument("--R", type=int, default=3,
help="number of hops")
parser.add_argument("--R", type=int, default=3, help="number of hops")
parser.add_argument("--lr", type=float, default=0.003)
parser.add_argument("--dataset", type=str, default="amazon")
parser.add_argument("--dropout", type=float, default=0.5)
parser.add_argument("--gpu", type=int, default=0)
parser.add_argument("--weight-decay", type=float, default=0)
parser.add_argument("--eval-every", type=int, default=50)
parser.add_argument("--eval-batch-size", type=int, default=250000,
help="evaluation batch size, -1 for full batch")
parser.add_argument("--ff-layer", type=int, default=2,
help="number of feed-forward layers")
parser.add_argument(
"--eval-batch-size",
type=int,
default=250000,
help="evaluation batch size, -1 for full batch",
)
parser.add_argument(
"--ff-layer", type=int, default=2, help="number of feed-forward layers"
)
args = parser.parse_args()
print(args)
......
import torch
import numpy as np
import pandas as pd
import torch
def load_data(file_path, len_train, len_val):
df = pd.read_csv(file_path, header=None).values.astype(float)
train = df[: len_train]
val = df[len_train: len_train + len_val]
test = df[len_train + len_val:]
train = df[:len_train]
val = df[len_train : len_train + len_val]
test = df[len_train + len_val :]
return train, val, test
......@@ -16,15 +15,15 @@ def data_transform(data, n_his, n_pred, device):
# produce data slices for training and testing
n_route = data.shape[1]
l = len(data)
num = l-n_his-n_pred
num = l - n_his - n_pred
x = np.zeros([num, 1, n_his, n_route])
y = np.zeros([num, n_route])
cnt = 0
for i in range(l-n_his-n_pred):
for i in range(l - n_his - n_pred):
head = i
tail = i + n_his
x[cnt, :, :, :] = data[head: tail].reshape(1, n_his, n_route)
x[cnt, :, :, :] = data[head:tail].reshape(1, n_his, n_route)
y[cnt] = data[tail + n_pred - 1]
cnt += 1
return torch.Tensor(x).to(device), torch.Tensor(y).to(device)
import dgl
import argparse
import random
import torch
import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
import scipy.sparse as sp
import torch
import torch.nn as nn
from load_data import *
from utils import *
from model import *
from sensors2graph import *
import torch.nn as nn
import argparse
import scipy.sparse as sp
from sklearn.preprocessing import StandardScaler
from utils import *
parser = argparse.ArgumentParser(description='STGCN_WAVE')
parser.add_argument('--lr', default=0.001, type=float, help='learning rate')
parser.add_argument('--disablecuda', action='store_true', help='Disable CUDA')
parser.add_argument('--batch_size', type=int, default=50, help='batch size for training and validation (default: 50)')
parser.add_argument('--epochs', type=int, default=50, help='epochs for training (default: 50)')
parser.add_argument('--num_layers', type=int, default=9, help='number of layers')
parser.add_argument('--window', type=int, default=144, help='window length')
parser.add_argument('--sensorsfilepath', type=str, default='./data/sensor_graph/graph_sensor_ids.txt', help='sensors file path')
parser.add_argument('--disfilepath', type=str, default='./data/sensor_graph/distances_la_2012.csv', help='distance file path')
parser.add_argument('--tsfilepath', type=str, default='./data/metr-la.h5', help='ts file path')
parser.add_argument('--savemodelpath', type=str, default='stgcnwavemodel.pt', help='save model path')
parser.add_argument('--pred_len', type=int, default=5, help='how many steps away we want to predict')
parser.add_argument('--control_str', type=str, default='TNTSTNTST', help='model structure controller, T: Temporal Layer, S: Spatio Layer, N: Norm Layer')
parser.add_argument('--channels', type=int, nargs='+', default=[1, 16, 32, 64, 32, 128], help='number of channels in each layer')
import dgl
parser = argparse.ArgumentParser(description="STGCN_WAVE")
parser.add_argument("--lr", default=0.001, type=float, help="learning rate")
parser.add_argument("--disablecuda", action="store_true", help="Disable CUDA")
parser.add_argument(
"--batch_size",
type=int,
default=50,
help="batch size for training and validation (default: 50)",
)
parser.add_argument(
"--epochs", type=int, default=50, help="epochs for training (default: 50)"
)
parser.add_argument(
"--num_layers", type=int, default=9, help="number of layers"
)
parser.add_argument("--window", type=int, default=144, help="window length")
parser.add_argument(
"--sensorsfilepath",
type=str,
default="./data/sensor_graph/graph_sensor_ids.txt",
help="sensors file path",
)
parser.add_argument(
"--disfilepath",
type=str,
default="./data/sensor_graph/distances_la_2012.csv",
help="distance file path",
)
parser.add_argument(
"--tsfilepath", type=str, default="./data/metr-la.h5", help="ts file path"
)
parser.add_argument(
"--savemodelpath",
type=str,
default="stgcnwavemodel.pt",
help="save model path",
)
parser.add_argument(
"--pred_len",
type=int,
default=5,
help="how many steps away we want to predict",
)
parser.add_argument(
"--control_str",
type=str,
default="TNTSTNTST",
help="model structure controller, T: Temporal Layer, S: Spatio Layer, N: Norm Layer",
)
parser.add_argument(
"--channels",
type=int,
nargs="+",
default=[1, 16, 32, 64, 32, 128],
help="number of channels in each layer",
)
args = parser.parse_args()
device = torch.device("cuda") if torch.cuda.is_available() and not args.disablecuda else torch.device("cpu")
device = (
torch.device("cuda")
if torch.cuda.is_available() and not args.disablecuda
else torch.device("cpu")
)
with open(args.sensorsfilepath) as f:
sensor_ids = f.read().strip().split(',')
sensor_ids = f.read().strip().split(",")
distance_df = pd.read_csv(args.disfilepath, dtype={'from': 'str', 'to': 'str'})
distance_df = pd.read_csv(args.disfilepath, dtype={"from": "str", "to": "str"})
adj_mx = get_adjacency_matrix(distance_df, sensor_ids)
sp_mx = sp.coo_matrix(adj_mx)
......@@ -51,7 +99,6 @@ n_his = args.window
save_path = args.savemodelpath
n_pred = args.pred_len
n_route = num_nodes
blocks = args.channels
......@@ -67,9 +114,9 @@ lr = args.lr
W = adj_mx
len_val = round(num_samples * 0.1)
len_train = round(num_samples * 0.7)
train = df[: len_train]
val = df[len_train: len_train + len_val]
test = df[len_train + len_val:]
train = df[:len_train]
val = df[len_train : len_train + len_val]
test = df[len_train + len_val :]
scaler = StandardScaler()
train = scaler.fit_transform(train)
......@@ -91,7 +138,9 @@ test_iter = torch.utils.data.DataLoader(test_data, batch_size)
loss = nn.MSELoss()
G = G.to(device)
model = STGCN_WAVE(blocks, n_his, n_route, G, drop_prob, num_layers, device, args.control_str).to(device)
model = STGCN_WAVE(
blocks, n_his, n_route, G, drop_prob, num_layers, device, args.control_str
).to(device)
optimizer = torch.optim.RMSprop(model.parameters(), lr=lr)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=5, gamma=0.7)
......@@ -113,10 +162,19 @@ for epoch in range(1, epochs + 1):
if val_loss < min_val_loss:
min_val_loss = val_loss
torch.save(model.state_dict(), save_path)
print("epoch", epoch, ", train loss:", l_sum / n, ", validation loss:", val_loss)
best_model = STGCN_WAVE(blocks, n_his, n_route, G, drop_prob, num_layers, device, args.control_str).to(device)
print(
"epoch",
epoch,
", train loss:",
l_sum / n,
", validation loss:",
val_loss,
)
best_model = STGCN_WAVE(
blocks, n_his, n_route, G, drop_prob, num_layers, device, args.control_str
).to(device)
best_model.load_state_dict(torch.load(save_path))
......
import math
import torch
import torch.nn as nn
import torch.nn.init as init
import torch.nn.functional as F
import torch.nn.init as init
from dgl.nn.pytorch import GraphConv
from dgl.nn.pytorch.conv import ChebConv
class TemporalConvLayer(nn.Module):
''' Temporal convolution layer.
"""Temporal convolution layer.
arguments
---------
......@@ -17,13 +20,15 @@ class TemporalConvLayer(nn.Module):
The number of output channels (features)
dia : int
The dilation size
'''
def __init__(self, c_in, c_out, dia = 1):
"""
def __init__(self, c_in, c_out, dia=1):
super(TemporalConvLayer, self).__init__()
self.c_out = c_out
self.c_in = c_in
self.conv = nn.Conv2d(c_in, c_out, (2, 1), 1, dilation = dia, padding = (0,0))
self.conv = nn.Conv2d(
c_in, c_out, (2, 1), 1, dilation=dia, padding=(0, 0)
)
def forward(self, x):
return torch.relu(self.conv(x))
......@@ -37,7 +42,7 @@ class SpatioConvLayer(nn.Module):
# self.gc = ChebConv(c, c, 3)
def init(self):
stdv = 1. / math.sqrt(self.W.weight.size(1))
stdv = 1.0 / math.sqrt(self.W.weight.size(1))
self.W.weight.data.uniform_(-stdv, stdv)
def forward(self, x):
......@@ -48,6 +53,7 @@ class SpatioConvLayer(nn.Module):
output = output.transpose(0, 3)
return torch.relu(output)
class FullyConvLayer(nn.Module):
def __init__(self, c):
super(FullyConvLayer, self).__init__()
......@@ -56,12 +62,13 @@ class FullyConvLayer(nn.Module):
def forward(self, x):
return self.conv(x)
class OutputLayer(nn.Module):
def __init__(self, c, T, n):
super(OutputLayer, self).__init__()
self.tconv1 = nn.Conv2d(c, c, (T, 1), 1, dilation = 1, padding = (0,0))
self.tconv1 = nn.Conv2d(c, c, (T, 1), 1, dilation=1, padding=(0, 0))
self.ln = nn.LayerNorm([n, c])
self.tconv2 = nn.Conv2d(c, c, (1, 1), 1, dilation = 1, padding = (0,0))
self.tconv2 = nn.Conv2d(c, c, (1, 1), 1, dilation=1, padding=(0, 0))
self.fc = FullyConvLayer(c)
def forward(self, x):
......@@ -70,8 +77,11 @@ class OutputLayer(nn.Module):
x_t2 = self.tconv2(x_ln)
return self.fc(x_t2)
class STGCN_WAVE(nn.Module):
def __init__(self, c, T, n, Lk, p, num_layers, device, control_str = 'TNTSTNTST'):
def __init__(
self, c, T, n, Lk, p, num_layers, device, control_str="TNTSTNTST"
):
super(STGCN_WAVE, self).__init__()
self.control_str = control_str # model structure controller
self.num_layers = len(control_str)
......@@ -80,21 +90,24 @@ class STGCN_WAVE(nn.Module):
diapower = 0
for i in range(self.num_layers):
i_layer = control_str[i]
if i_layer == 'T': # Temporal Layer
self.layers.append(TemporalConvLayer(c[cnt], c[cnt + 1], dia = 2**diapower))
if i_layer == "T": # Temporal Layer
self.layers.append(
TemporalConvLayer(c[cnt], c[cnt + 1], dia=2**diapower)
)
diapower += 1
cnt += 1
if i_layer == 'S': # Spatio Layer
if i_layer == "S": # Spatio Layer
self.layers.append(SpatioConvLayer(c[cnt], Lk))
if i_layer == 'N': # Norm Layer
self.layers.append(nn.LayerNorm([n,c[cnt]]))
self.output = OutputLayer(c[cnt], T + 1 - 2**(diapower), n)
if i_layer == "N": # Norm Layer
self.layers.append(nn.LayerNorm([n, c[cnt]]))
self.output = OutputLayer(c[cnt], T + 1 - 2 ** (diapower), n)
for layer in self.layers:
layer = layer.to(device)
def forward(self, x):
for i in range(self.num_layers):
i_layer = self.control_str[i]
if i_layer == 'N':
if i_layer == "N":
x = self.layers[i](x.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
else:
x = self.layers[i](x)
......
import numpy as np
def get_adjacency_matrix(distance_df, sensor_ids, normalized_k=0.1):
"""
:param distance_df: data frame with three columns: [from, to, distance].
......
import torch
import numpy as np
import torch
def evaluate_model(model, loss, data_iter):
......@@ -21,11 +20,13 @@ def evaluate_metric(model, data_iter, scaler):
mae, mape, mse = [], [], []
for x, y in data_iter:
y = scaler.inverse_transform(y.cpu().numpy()).reshape(-1)
y_pred = scaler.inverse_transform(model(x).view(len(x), -1).cpu().numpy()).reshape(-1)
y_pred = scaler.inverse_transform(
model(x).view(len(x), -1).cpu().numpy()
).reshape(-1)
d = np.abs(y - y_pred)
mae += d.tolist()
mape += (d / y).tolist()
mse += (d ** 2).tolist()
mse += (d**2).tolist()
MAE = np.array(mae).mean()
MAPE = np.array(mape).mean()
RMSE = np.sqrt(np.array(mse).mean())
......
......@@ -6,17 +6,14 @@ References:
"""
import torch
import torch.nn as nn
from dgl.nn.pytorch.conv import TAGConv
class TAGCN(nn.Module):
def __init__(self,
g,
in_feats,
n_hidden,
n_classes,
n_layers,
activation,
dropout):
def __init__(
self, g, in_feats, n_hidden, n_classes, n_layers, activation, dropout
):
super(TAGCN, self).__init__()
self.g = g
self.layers = nn.ModuleList()
......@@ -24,9 +21,11 @@ class TAGCN(nn.Module):
self.layers.append(TAGConv(in_feats, n_hidden, activation=activation))
# hidden layers
for i in range(n_layers - 1):
self.layers.append(TAGConv(n_hidden, n_hidden, activation=activation))
self.layers.append(
TAGConv(n_hidden, n_hidden, activation=activation)
)
# output layer
self.layers.append(TAGConv(n_hidden, n_classes)) #activation=None
self.layers.append(TAGConv(n_hidden, n_classes)) # activation=None
self.dropout = nn.Dropout(p=dropout)
def forward(self, features):
......
import argparse, time
import numpy as np
import argparse
import time
import networkx as nx
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from tagcn import TAGCN
from dgl import DGLGraph
from dgl.data import register_data_args, load_data
from dgl.data import load_data, register_data_args
from tagcn import TAGCN
def evaluate(model, features, labels, mask):
model.eval()
......@@ -19,6 +22,7 @@ def evaluate(model, features, labels, mask):
correct = torch.sum(indices == labels)
return correct.item() * 1.0 / len(labels)
def main(args):
# load and preprocess dataset
data = load_data(args)
......@@ -28,24 +32,29 @@ def main(args):
else:
cuda = True
g = g.to(args.gpu)
features = g.ndata['feat']
labels = g.ndata['label']
train_mask = g.ndata['train_mask']
val_mask = g.ndata['val_mask']
test_mask = g.ndata['test_mask']
features = g.ndata["feat"]
labels = g.ndata["label"]
train_mask = g.ndata["train_mask"]
val_mask = g.ndata["val_mask"]
test_mask = g.ndata["test_mask"]
in_feats = features.shape[1]
n_classes = data.num_labels
n_edges = g.number_of_edges()
print("""----Data statistics------'
print(
"""----Data statistics------'
#Edges %d
#Classes %d
#Train samples %d
#Val samples %d
#Test samples %d""" %
(n_edges, n_classes,
#Test samples %d"""
% (
n_edges,
n_classes,
train_mask.int().sum().item(),
val_mask.int().sum().item(),
test_mask.int().sum().item()))
test_mask.int().sum().item(),
)
)
# graph preprocess and calculate normalization factor
# add self loop
......@@ -54,22 +63,24 @@ def main(args):
n_edges = g.number_of_edges()
# create TAGCN model
model = TAGCN(g,
model = TAGCN(
g,
in_feats,
args.n_hidden,
n_classes,
args.n_layers,
F.relu,
args.dropout)
args.dropout,
)
if cuda:
model.cuda()
loss_fcn = torch.nn.CrossEntropyLoss()
# use optimizer
optimizer = torch.optim.Adam(model.parameters(),
lr=args.lr,
weight_decay=args.weight_decay)
optimizer = torch.optim.Adam(
model.parameters(), lr=args.lr, weight_decay=args.weight_decay
)
# initialize graph
dur = []
......@@ -89,34 +100,47 @@ def main(args):
dur.append(time.time() - t0)
acc = evaluate(model, features, labels, val_mask)
print("Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
"ETputs(KTEPS) {:.2f}". format(epoch, np.mean(dur), loss.item(),
acc, n_edges / np.mean(dur) / 1000))
print(
"Epoch {:05d} | Time(s) {:.4f} | Loss {:.4f} | Accuracy {:.4f} | "
"ETputs(KTEPS) {:.2f}".format(
epoch,
np.mean(dur),
loss.item(),
acc,
n_edges / np.mean(dur) / 1000,
)
)
print()
acc = evaluate(model, features, labels, test_mask)
print("Test Accuracy {:.4f}".format(acc))
if __name__ == '__main__':
parser = argparse.ArgumentParser(description='TAGCN')
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="TAGCN")
register_data_args(parser)
parser.add_argument("--dropout", type=float, default=0.5,
help="dropout probability")
parser.add_argument("--gpu", type=int, default=-1,
help="gpu")
parser.add_argument("--lr", type=float, default=1e-2,
help="learning rate")
parser.add_argument("--n-epochs", type=int, default=200,
help="number of training epochs")
parser.add_argument("--n-hidden", type=int, default=16,
help="number of hidden tagcn units")
parser.add_argument("--n-layers", type=int, default=1,
help="number of hidden tagcn layers")
parser.add_argument("--weight-decay", type=float, default=5e-4,
help="Weight for L2 loss")
parser.add_argument("--self-loop", action='store_true',
help="graph self-loop (default=False)")
parser.add_argument(
"--dropout", type=float, default=0.5, help="dropout probability"
)
parser.add_argument("--gpu", type=int, default=-1, help="gpu")
parser.add_argument("--lr", type=float, default=1e-2, help="learning rate")
parser.add_argument(
"--n-epochs", type=int, default=200, help="number of training epochs"
)
parser.add_argument(
"--n-hidden", type=int, default=16, help="number of hidden tagcn units"
)
parser.add_argument(
"--n-layers", type=int, default=1, help="number of hidden tagcn layers"
)
parser.add_argument(
"--weight-decay", type=float, default=5e-4, help="Weight for L2 loss"
)
parser.add_argument(
"--self-loop",
action="store_true",
help="graph self-loop (default=False)",
)
parser.set_defaults(self_loop=False)
args = parser.parse_args()
print(args)
......
import os
import ssl
from six.moves import urllib
import pandas as pd
import numpy as np
import pandas as pd
import torch
from six.moves import urllib
import dgl
# === The data preprocessing code below is based on
......@@ -13,6 +13,7 @@ import dgl
# Preprocess the raw data and split out each feature
def preprocess(data_name):
u_list, i_list, ts_list, label_list = [], [], [], []
feat_l = []
......@@ -21,7 +22,7 @@ def preprocess(data_name):
with open(data_name) as f:
s = next(f)
for idx, line in enumerate(f):
e = line.strip().split(',')
e = line.strip().split(",")
u = int(e[0])
i = int(e[1])
......@@ -37,18 +38,23 @@ def preprocess(data_name):
idx_list.append(idx)
feat_l.append(feat)
return pd.DataFrame({'u': u_list,
'i': i_list,
'ts': ts_list,
'label': label_list,
'idx': idx_list}), np.array(feat_l)
return pd.DataFrame(
{
"u": u_list,
"i": i_list,
"ts": ts_list,
"label": label_list,
"idx": idx_list,
}
), np.array(feat_l)
# Re-index nodes for DGL convenience
def reindex(df, bipartite=True):
new_df = df.copy()
if bipartite:
assert (df.u.max() - df.u.min() + 1 == len(df.u.unique()))
assert (df.i.max() - df.i.min() + 1 == len(df.i.unique()))
assert df.u.max() - df.u.min() + 1 == len(df.u.unique())
assert df.i.max() - df.i.min() + 1 == len(df.i.unique())
upper_u = df.u.max() + 1
new_i = df.i + upper_u
......@@ -64,12 +70,13 @@ def reindex(df, bipartite=True):
return new_df
# Save the edge list and features in separate files for easier processing
def run(data_name, bipartite=True):
PATH = './data/{}.csv'.format(data_name)
OUT_DF = './data/ml_{}.csv'.format(data_name)
OUT_FEAT = './data/ml_{}.npy'.format(data_name)
OUT_NODE_FEAT = './data/ml_{}_node.npy'.format(data_name)
PATH = "./data/{}.csv".format(data_name)
OUT_DF = "./data/ml_{}.csv".format(data_name)
OUT_FEAT = "./data/ml_{}.npy".format(data_name)
OUT_NODE_FEAT = "./data/ml_{}_node.npy".format(data_name)
df, feat = preprocess(PATH)
new_df = reindex(df, bipartite)
......@@ -84,18 +91,20 @@ def run(data_name, bipartite=True):
np.save(OUT_FEAT, feat)
np.save(OUT_NODE_FEAT, rand_feat)
# === code from twitter-research-tgn end ===
# If you have a new dataset that follows the same format as JODIE,
# you can retrieve it directly by name
def TemporalDataset(dataset):
if not os.path.exists('./data/{}.bin'.format(dataset)):
if not os.path.exists('./data/{}.csv'.format(dataset)):
if not os.path.exists('./data'):
os.mkdir('./data')
if not os.path.exists("./data/{}.bin".format(dataset)):
if not os.path.exists("./data/{}.csv".format(dataset)):
if not os.path.exists("./data"):
os.mkdir("./data")
url = 'https://snap.stanford.edu/jodie/{}.csv'.format(dataset)
url = "https://snap.stanford.edu/jodie/{}.csv".format(dataset)
print("Start Downloading File....")
context = ssl._create_unverified_context()
data = urllib.request.urlopen(url, context=context)
......@@ -104,27 +113,28 @@ def TemporalDataset(dataset):
print("Start Process Data ...")
run(dataset)
raw_connection = pd.read_csv('./data/ml_{}.csv'.format(dataset))
raw_feature = np.load('./data/ml_{}.npy'.format(dataset))
raw_connection = pd.read_csv("./data/ml_{}.csv".format(dataset))
raw_feature = np.load("./data/ml_{}.npy".format(dataset))
# -1 to re-index the nodes
src = raw_connection['u'].to_numpy()-1
dst = raw_connection['i'].to_numpy()-1
src = raw_connection["u"].to_numpy() - 1
dst = raw_connection["i"].to_numpy() - 1
# Create directed graph
g = dgl.graph((src, dst))
g.edata['timestamp'] = torch.from_numpy(
raw_connection['ts'].to_numpy())
g.edata['label'] = torch.from_numpy(raw_connection['label'].to_numpy())
g.edata['feats'] = torch.from_numpy(raw_feature[1:, :]).float()
dgl.save_graphs('./data/{}.bin'.format(dataset), [g])
g.edata["timestamp"] = torch.from_numpy(raw_connection["ts"].to_numpy())
g.edata["label"] = torch.from_numpy(raw_connection["label"].to_numpy())
g.edata["feats"] = torch.from_numpy(raw_feature[1:, :]).float()
dgl.save_graphs("./data/{}.bin".format(dataset), [g])
else:
print("Data already exists, loading directly.")
gs, _ = dgl.load_graphs('./data/{}.bin'.format(dataset))
gs, _ = dgl.load_graphs("./data/{}.bin".format(dataset))
g = gs[0]
return g
def TemporalWikipediaDataset():
# Download the dataset
return TemporalDataset('wikipedia')
return TemporalDataset("wikipedia")
def TemporalRedditDataset():
return TemporalDataset('reddit')
return TemporalDataset("reddit")
......@@ -4,9 +4,9 @@ import torch.nn as nn
import torch.nn.functional as F
import dgl
import dgl.function as fn
from dgl.base import DGLError
from dgl.ops import edge_softmax
import dgl.function as fn
class Identity(nn.Module):
......@@ -60,22 +60,22 @@ class MsgLinkPredictor(nn.Module):
self.out_fc = nn.Linear(emb_dim, 1)
def link_pred(self, edges):
src_hid = self.src_fc(edges.src['embedding'])
dst_hid = self.dst_fc(edges.dst['embedding'])
score = F.relu(src_hid+dst_hid)
src_hid = self.src_fc(edges.src["embedding"])
dst_hid = self.dst_fc(edges.dst["embedding"])
score = F.relu(src_hid + dst_hid)
score = self.out_fc(score)
return {'score': score}
return {"score": score}
def forward(self, x, pos_g, neg_g):
# Local Scope?
pos_g.ndata['embedding'] = x
neg_g.ndata['embedding'] = x
pos_g.ndata["embedding"] = x
neg_g.ndata["embedding"] = x
pos_g.apply_edges(self.link_pred)
neg_g.apply_edges(self.link_pred)
pos_escore = pos_g.edata['score']
neg_escore = neg_g.edata['score']
pos_escore = pos_g.edata["score"]
neg_escore = neg_g.edata["score"]
return pos_escore, neg_escore
......@@ -106,8 +106,11 @@ class TimeEncode(nn.Module):
self.dimension = dimension
self.w = torch.nn.Linear(1, dimension)
self.w.weight = torch.nn.Parameter((torch.from_numpy(1 / 10 ** np.linspace(0, 9, dimension)))
.double().reshape(dimension, -1))
self.w.weight = torch.nn.Parameter(
(torch.from_numpy(1 / 10 ** np.linspace(0, 9, dimension)))
.double()
.reshape(dimension, -1)
)
self.w.bias = torch.nn.Parameter(torch.zeros(dimension).double())
def forward(self, t):
......@@ -143,10 +146,13 @@ class MemoryModule(nn.Module):
self.reset_memory()
def reset_memory(self):
self.last_update_t = nn.Parameter(torch.zeros(
self.n_node).float(), requires_grad=False)
self.memory = nn.Parameter(torch.zeros(
(self.n_node, self.hidden_dim)).float(), requires_grad=False)
self.last_update_t = nn.Parameter(
torch.zeros(self.n_node).float(), requires_grad=False
)
self.memory = nn.Parameter(
torch.zeros((self.n_node, self.hidden_dim)).float(),
requires_grad=False,
)
def backup_memory(self):
"""
......@@ -192,7 +198,7 @@ class MemoryModule(nn.Module):
class MemoryOperation(nn.Module):
""" Memory update using message passing manner, update memory based on positive
"""Memory update using message passing manner, update memory based on positive
pair graph of each batch with recurrent module GRU or RNN
Message function
......@@ -235,41 +241,66 @@ class MemoryOperation(nn.Module):
def __init__(self, updater_type, memory, e_feat_dim, temporal_encoder):
super(MemoryOperation, self).__init__()
updater_dict = {'gru': nn.GRUCell, 'rnn': nn.RNNCell}
updater_dict = {"gru": nn.GRUCell, "rnn": nn.RNNCell}
self.memory = memory
memory_dim = self.memory.hidden_dim
self.temporal_encoder = temporal_encoder
self.message_dim = memory_dim+memory_dim + \
e_feat_dim+self.temporal_encoder.dimension
self.updater = updater_dict[updater_type](input_size=self.message_dim,
hidden_size=memory_dim)
self.message_dim = (
memory_dim
+ memory_dim
+ e_feat_dim
+ self.temporal_encoder.dimension
)
self.updater = updater_dict[updater_type](
input_size=self.message_dim, hidden_size=memory_dim
)
self.memory = memory
# Here assume g is a subgraph from each iteration
def stick_feat_to_graph(self, g):
# How can I ensure order of the node ID
g.ndata['timestamp'] = self.memory.last_update_t[g.ndata[dgl.NID]]
g.ndata['memory'] = self.memory.memory[g.ndata[dgl.NID]]
g.ndata["timestamp"] = self.memory.last_update_t[g.ndata[dgl.NID]]
g.ndata["memory"] = self.memory.memory[g.ndata[dgl.NID]]
def msg_fn_cat(self, edges):
src_delta_time = edges.data['timestamp'] - edges.src['timestamp']
time_encode = self.temporal_encoder(src_delta_time.unsqueeze(
dim=1)).view(len(edges.data['timestamp']), -1)
ret = torch.cat([edges.src['memory'], edges.dst['memory'],
edges.data['feats'], time_encode], dim=1)
return {'message': ret, 'timestamp': edges.data['timestamp']}
src_delta_time = edges.data["timestamp"] - edges.src["timestamp"]
time_encode = self.temporal_encoder(
src_delta_time.unsqueeze(dim=1)
).view(len(edges.data["timestamp"]), -1)
ret = torch.cat(
[
edges.src["memory"],
edges.dst["memory"],
edges.data["feats"],
time_encode,
],
dim=1,
)
return {"message": ret, "timestamp": edges.data["timestamp"]}
def agg_last(self, nodes):
timestamp, latest_idx = torch.max(nodes.mailbox['timestamp'], dim=1)
ret = nodes.mailbox['message'].gather(1, latest_idx.repeat(
self.message_dim).view(-1, 1, self.message_dim)).view(-1, self.message_dim)
return {'message_bar': ret.reshape(-1, self.message_dim), 'timestamp': timestamp}
timestamp, latest_idx = torch.max(nodes.mailbox["timestamp"], dim=1)
ret = (
nodes.mailbox["message"]
.gather(
1,
latest_idx.repeat(self.message_dim).view(
-1, 1, self.message_dim
),
)
.view(-1, self.message_dim)
)
return {
"message_bar": ret.reshape(-1, self.message_dim),
"timestamp": timestamp,
}
def update_memory(self, nodes):
# It should pass the feature through RNN
ret = self.updater(
nodes.data['message_bar'].float(), nodes.data['memory'].float())
return {'memory': ret}
nodes.data["message_bar"].float(), nodes.data["memory"].float()
)
return {"memory": ret}
def forward(self, g):
self.stick_feat_to_graph(g)
......@@ -278,7 +309,7 @@ class MemoryOperation(nn.Module):
class EdgeGATConv(nn.Module):
'''Edge graph attention: compute the graph attention from node and edge features, then aggregate both node and
"""Edge graph attention: compute the graph attention from node and edge features, then aggregate both node and
edge features.
Parameter
......@@ -314,19 +345,21 @@ class EdgeGATConv(nn.Module):
0-in-degree nodes in input graph. By setting ``True``, it will suppress the check
and let the users handle it by themselves. Defaults: ``False``.
'''
"""
def __init__(self,
def __init__(
self,
node_feats,
edge_feats,
out_feats,
num_heads,
feat_drop=0.,
attn_drop=0.,
feat_drop=0.0,
attn_drop=0.0,
negative_slope=0.2,
residual=False,
activation=None,
allow_zero_in_degree=False):
allow_zero_in_degree=False,
):
super(EdgeGATConv, self).__init__()
self._num_heads = num_heads
self._node_feats = node_feats
......@@ -334,15 +367,20 @@ class EdgeGATConv(nn.Module):
self._out_feats = out_feats
self._allow_zero_in_degree = allow_zero_in_degree
self.fc_node = nn.Linear(
self._node_feats, self._out_feats*self._num_heads)
self._node_feats, self._out_feats * self._num_heads
)
self.fc_edge = nn.Linear(
self._edge_feats, self._out_feats*self._num_heads)
self.attn_l = nn.Parameter(torch.FloatTensor(
size=(1, self._num_heads, self._out_feats)))
self.attn_r = nn.Parameter(torch.FloatTensor(
size=(1, self._num_heads, self._out_feats)))
self.attn_e = nn.Parameter(torch.FloatTensor(
size=(1, self._num_heads, self._out_feats)))
self._edge_feats, self._out_feats * self._num_heads
)
self.attn_l = nn.Parameter(
torch.FloatTensor(size=(1, self._num_heads, self._out_feats))
)
self.attn_r = nn.Parameter(
torch.FloatTensor(size=(1, self._num_heads, self._out_feats))
)
self.attn_e = nn.Parameter(
torch.FloatTensor(size=(1, self._num_heads, self._out_feats))
)
self.feat_drop = nn.Dropout(feat_drop)
self.attn_drop = nn.Dropout(attn_drop)
self.leaky_relu = nn.LeakyReLU(negative_slope)
......@@ -350,14 +388,17 @@ class EdgeGATConv(nn.Module):
if residual:
if self._node_feats != self._out_feats:
self.res_fc = nn.Linear(
self._node_feats, self._out_feats*self._num_heads, bias=False)
self._node_feats,
self._out_feats * self._num_heads,
bias=False,
)
else:
self.res_fc = Identity()
self.reset_parameters()
self.activation = activation
def reset_parameters(self):
gain = nn.init.calculate_gain('relu')
gain = nn.init.calculate_gain("relu")
nn.init.xavier_normal_(self.fc_node.weight, gain=gain)
nn.init.xavier_normal_(self.fc_edge.weight, gain=gain)
nn.init.xavier_normal_(self.attn_l, gain=gain)
......@@ -367,62 +408,69 @@ class EdgeGATConv(nn.Module):
nn.init.xavier_normal_(self.res_fc.weight, gain=gain)
def msg_fn(self, edges):
ret = edges.data['a'].view(-1, self._num_heads,
1)*edges.data['el_prime']
return {'m': ret}
ret = (
edges.data["a"].view(-1, self._num_heads, 1)
* edges.data["el_prime"]
)
return {"m": ret}
def forward(self, graph, nfeat, efeat, get_attention=False):
with graph.local_scope():
if not self._allow_zero_in_degree:
if (graph.in_degrees() == 0).any():
raise DGLError('There are 0-in-degree nodes in the graph, '
'output for those nodes will be invalid. '
'This is harmful for some applications, '
'causing silent performance regression. '
'Adding self-loop on the input graph by '
'calling `g = dgl.add_self_loop(g)` will resolve '
'the issue. Setting ``allow_zero_in_degree`` '
'to be `True` when constructing this module will '
'suppress the check and let the code run.')
raise DGLError(
"There are 0-in-degree nodes in the graph, "
"output for those nodes will be invalid. "
"This is harmful for some applications, "
"causing silent performance regression. "
"Adding self-loop on the input graph by "
"calling `g = dgl.add_self_loop(g)` will resolve "
"the issue. Setting ``allow_zero_in_degree`` "
"to be `True` when constructing this module will "
"suppress the check and let the code run."
)
nfeat = self.feat_drop(nfeat)
efeat = self.feat_drop(efeat)
node_feat = self.fc_node(
nfeat).view(-1, self._num_heads, self._out_feats)
edge_feat = self.fc_edge(
efeat).view(-1, self._num_heads, self._out_feats)
el = (node_feat*self.attn_l).sum(dim=-1).unsqueeze(-1)
er = (node_feat*self.attn_r).sum(dim=-1).unsqueeze(-1)
ee = (edge_feat*self.attn_e).sum(dim=-1).unsqueeze(-1)
graph.ndata['ft'] = node_feat
graph.ndata['el'] = el
graph.ndata['er'] = er
graph.edata['ee'] = ee
graph.apply_edges(fn.u_add_e('el', 'ee', 'el_prime'))
graph.apply_edges(fn.e_add_v('el_prime', 'er', 'e'))
e = self.leaky_relu(graph.edata['e'])
graph.edata['a'] = self.attn_drop(edge_softmax(graph, e))
graph.edata['efeat'] = edge_feat
graph.update_all(self.msg_fn, fn.sum('m', 'ft'))
rst = graph.ndata['ft']
node_feat = self.fc_node(nfeat).view(
-1, self._num_heads, self._out_feats
)
edge_feat = self.fc_edge(efeat).view(
-1, self._num_heads, self._out_feats
)
el = (node_feat * self.attn_l).sum(dim=-1).unsqueeze(-1)
er = (node_feat * self.attn_r).sum(dim=-1).unsqueeze(-1)
ee = (edge_feat * self.attn_e).sum(dim=-1).unsqueeze(-1)
graph.ndata["ft"] = node_feat
graph.ndata["el"] = el
graph.ndata["er"] = er
graph.edata["ee"] = ee
graph.apply_edges(fn.u_add_e("el", "ee", "el_prime"))
graph.apply_edges(fn.e_add_v("el_prime", "er", "e"))
e = self.leaky_relu(graph.edata["e"])
graph.edata["a"] = self.attn_drop(edge_softmax(graph, e))
graph.edata["efeat"] = edge_feat
graph.update_all(self.msg_fn, fn.sum("m", "ft"))
rst = graph.ndata["ft"]
if self.residual:
resval = self.res_fc(nfeat).view(
nfeat.shape[0], -1, self._out_feats)
nfeat.shape[0], -1, self._out_feats
)
rst = rst + resval
if self.activation:
rst = self.activation(rst)
if get_attention:
return rst, graph.edata['a']
return rst, graph.edata["a"]
else:
return rst
class TemporalEdgePreprocess(nn.Module):
'''Preprocess layer, which finishes the time encoding and concatenates
"""Preprocess layer, which finishes the time encoding and concatenates
the time encoding to the edge features.
Parameter
......@@ -432,7 +480,7 @@ class TemporalEdgePreprocess(nn.Module):
temporal_encoder : torch.nn.Module
time encoder model
'''
"""
def __init__(self, edge_feats, temporal_encoder):
super(TemporalEdgePreprocess, self).__init__()
......@@ -440,29 +488,32 @@ class TemporalEdgePreprocess(nn.Module):
self.temporal_encoder = temporal_encoder
def edge_fn(self, edges):
t0 = torch.zeros_like(edges.dst['timestamp'])
time_diff = edges.data['timestamp'] - edges.src['timestamp']
time_encode = self.temporal_encoder(
time_diff.unsqueeze(dim=1)).view(t0.shape[0], -1)
edge_feat = torch.cat([edges.data['feats'], time_encode], dim=1)
return {'efeat': edge_feat}
t0 = torch.zeros_like(edges.dst["timestamp"])
time_diff = edges.data["timestamp"] - edges.src["timestamp"]
time_encode = self.temporal_encoder(time_diff.unsqueeze(dim=1)).view(
t0.shape[0], -1
)
edge_feat = torch.cat([edges.data["feats"], time_encode], dim=1)
return {"efeat": edge_feat}
def forward(self, graph):
graph.apply_edges(self.edge_fn)
efeat = graph.edata['efeat']
efeat = graph.edata["efeat"]
return efeat
class TemporalTransformerConv(nn.Module):
def __init__(self,
def __init__(
self,
edge_feats,
memory_feats,
temporal_encoder,
out_feats,
num_heads,
allow_zero_in_degree=False,
layers=1):
'''Temporal Transformer model for TGN and TGAT
layers=1,
):
"""Temporal Transformer model for TGN and TGAT
Parameter
==========
......@@ -487,7 +538,7 @@ class TemporalTransformerConv(nn.Module):
causing silent performance regression. This module will raise a DGLError if it detects
0-in-degree nodes in input graph. By setting ``True``, it will suppress the check
and let the users handle it by themselves. Defaults: ``False``.
'''
"""
super(TemporalTransformerConv, self).__init__()
self._edge_feats = edge_feats
self._memory_feats = memory_feats
......@@ -498,32 +549,42 @@ class TemporalTransformerConv(nn.Module):
self.layers = layers
self.preprocessor = TemporalEdgePreprocess(
self._edge_feats, self.temporal_encoder)
self._edge_feats, self.temporal_encoder
)
self.layer_list = nn.ModuleList()
self.layer_list.append(EdgeGATConv(node_feats=self._memory_feats,
edge_feats=self._edge_feats+self.temporal_encoder.dimension,
self.layer_list.append(
EdgeGATConv(
node_feats=self._memory_feats,
edge_feats=self._edge_feats + self.temporal_encoder.dimension,
out_feats=self._out_feats,
num_heads=self._num_heads,
feat_drop=0.6,
attn_drop=0.6,
residual=True,
allow_zero_in_degree=allow_zero_in_degree))
for i in range(self.layers-1):
self.layer_list.append(EdgeGATConv(node_feats=self._out_feats*self._num_heads,
edge_feats=self._edge_feats+self.temporal_encoder.dimension,
allow_zero_in_degree=allow_zero_in_degree,
)
)
for i in range(self.layers - 1):
self.layer_list.append(
EdgeGATConv(
node_feats=self._out_feats * self._num_heads,
edge_feats=self._edge_feats
+ self.temporal_encoder.dimension,
out_feats=self._out_feats,
num_heads=self._num_heads,
feat_drop=0.6,
attn_drop=0.6,
residual=True,
allow_zero_in_degree=allow_zero_in_degree))
allow_zero_in_degree=allow_zero_in_degree,
)
)
def forward(self, graph, memory, ts):
graph = graph.local_var()
graph.ndata['timestamp'] = ts
graph.ndata["timestamp"] = ts
efeat = self.preprocessor(graph).float()
rst = memory
for i in range(self.layers-1):
for i in range(self.layers - 1):
rst = self.layer_list[i](graph, rst, efeat).flatten(1)
rst = self.layer_list[-1](graph, rst, efeat).mean(1)
return rst
import copy
import torch.nn as nn
from modules import (
MemoryModule,
MemoryOperation,
MsgLinkPredictor,
TemporalTransformerConv,
TimeEncode,
)
import dgl
from modules import MemoryModule, MemoryOperation, MsgLinkPredictor, TemporalTransformerConv, TimeEncode
class TGN(nn.Module):
def __init__(self,
def __init__(
self,
edge_feat_dim,
memory_dim,
temporal_dim,
......@@ -13,8 +22,9 @@ class TGN(nn.Module):
num_heads,
num_nodes,
n_neighbors=10,
memory_updater_type='gru',
layers=1):
memory_updater_type="gru",
layers=1,
):
super(TGN, self).__init__()
self.memory_dim = memory_dim
self.edge_feat_dim = edge_feat_dim
......@@ -28,43 +38,49 @@ class TGN(nn.Module):
self.temporal_encoder = TimeEncode(self.temporal_dim)
self.memory = MemoryModule(self.num_nodes,
self.memory_dim)
self.memory = MemoryModule(self.num_nodes, self.memory_dim)
self.memory_ops = MemoryOperation(self.memory_updater_type,
self.memory_ops = MemoryOperation(
self.memory_updater_type,
self.memory,
self.edge_feat_dim,
self.temporal_encoder)
self.temporal_encoder,
)
self.embedding_attn = TemporalTransformerConv(self.edge_feat_dim,
self.embedding_attn = TemporalTransformerConv(
self.edge_feat_dim,
self.memory_dim,
self.temporal_encoder,
self.embedding_dim,
self.num_heads,
layers=self.layers,
allow_zero_in_degree=True)
allow_zero_in_degree=True,
)
self.msg_linkpredictor = MsgLinkPredictor(embedding_dim)
def embed(self, postive_graph, negative_graph, blocks):
emb_graph = blocks[0]
emb_memory = self.memory.memory[emb_graph.ndata[dgl.NID], :]
emb_t = emb_graph.ndata['timestamp']
emb_t = emb_graph.ndata["timestamp"]
embedding = self.embedding_attn(emb_graph, emb_memory, emb_t)
emb2pred = dict(
zip(emb_graph.ndata[dgl.NID].tolist(), emb_graph.nodes().tolist()))
zip(emb_graph.ndata[dgl.NID].tolist(), emb_graph.nodes().tolist())
)
# Since the positive graph and the negative graph share the same node ID mapping
feat_id = [emb2pred[int(n)] for n in postive_graph.ndata[dgl.NID]]
feat = embedding[feat_id]
pred_pos, pred_neg = self.msg_linkpredictor(
feat, postive_graph, negative_graph)
feat, postive_graph, negative_graph
)
return pred_pos, pred_neg
def update_memory(self, subg):
new_g = self.memory_ops(subg)
self.memory.set_memory(new_g.ndata[dgl.NID], new_g.ndata['memory'])
self.memory.set_memory(new_g.ndata[dgl.NID], new_g.ndata["memory"])
self.memory.set_last_update_t(
new_g.ndata[dgl.NID], new_g.ndata['timestamp'])
new_g.ndata[dgl.NID], new_g.ndata["timestamp"]
)
# Some memory operation wrappers
def detach_memory(self):
......@@ -75,10 +91,10 @@ class TGN(nn.Module):
def store_memory(self):
memory_checkpoint = {}
memory_checkpoint['memory'] = copy.deepcopy(self.memory.memory)
memory_checkpoint['last_t'] = copy.deepcopy(self.memory.last_update_t)
memory_checkpoint["memory"] = copy.deepcopy(self.memory.memory)
memory_checkpoint["last_t"] = copy.deepcopy(self.memory.last_update_t)
return memory_checkpoint
def restore_memory(self, memory_checkpoint):
self.memory.memory = memory_checkpoint['memory']
self.memory.last_update_time = memory_checkpoint['last_t']
self.memory.memory = memory_checkpoint["memory"]
self.memory.last_update_time = memory_checkpoint["last_t"]
import argparse
import traceback
import time
import copy
import time
import traceback
import numpy as np
import dgl
import torch
from tgn import TGN
from data_preprocess import TemporalWikipediaDataset, TemporalRedditDataset, TemporalDataset
from dataloading import (FastTemporalEdgeCollator, FastTemporalSampler,
SimpleTemporalEdgeCollator, SimpleTemporalSampler,
TemporalEdgeDataLoader, TemporalSampler, TemporalEdgeCollator)
from data_preprocess import (
TemporalDataset,
TemporalRedditDataset,
TemporalWikipediaDataset,
)
from dataloading import (
FastTemporalEdgeCollator,
FastTemporalSampler,
SimpleTemporalEdgeCollator,
SimpleTemporalSampler,
TemporalEdgeCollator,
TemporalEdgeDataLoader,
TemporalSampler,
)
from sklearn.metrics import average_precision_score, roc_auc_score
from tgn import TGN
import dgl
TRAIN_SPLIT = 0.7
VALID_SPLIT = 0.85
......@@ -32,10 +40,11 @@ def train(model, dataloader, sampler, criterion, optimizer, args):
for _, positive_pair_g, negative_pair_g, blocks in dataloader:
optimizer.zero_grad()
pred_pos, pred_neg = model.embed(
positive_pair_g, negative_pair_g, blocks)
positive_pair_g, negative_pair_g, blocks
)
loss = criterion(pred_pos, torch.ones_like(pred_pos))
loss += criterion(pred_neg, torch.zeros_like(pred_neg))
total_loss += float(loss)*args.batch_size
total_loss += float(loss) * args.batch_size
retain_graph = True if batch_cnt == 0 and not args.fast_mode else False
loss.backward(retain_graph=retain_graph)
optimizer.step()
......@@ -44,7 +53,7 @@ def train(model, dataloader, sampler, criterion, optimizer, args):
model.update_memory(positive_pair_g)
if args.fast_mode:
sampler.attach_last_update(model.memory.last_update_t)
print("Batch: ", batch_cnt, "Time: ", time.time()-last_t)
print("Batch: ", batch_cnt, "Time: ", time.time() - last_t)
last_t = time.time()
batch_cnt += 1
return total_loss
......@@ -59,13 +68,16 @@ def test_val(model, dataloader, sampler, criterion, args):
with torch.no_grad():
for _, postive_pair_g, negative_pair_g, blocks in dataloader:
pred_pos, pred_neg = model.embed(
postive_pair_g, negative_pair_g, blocks)
postive_pair_g, negative_pair_g, blocks
)
loss = criterion(pred_pos, torch.ones_like(pred_pos))
loss += criterion(pred_neg, torch.zeros_like(pred_neg))
total_loss += float(loss)*batch_size
total_loss += float(loss) * batch_size
y_pred = torch.cat([pred_pos, pred_neg], dim=0).sigmoid().cpu()
y_true = torch.cat(
[torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))], dim=0)
[torch.ones(pred_pos.size(0)), torch.zeros(pred_neg.size(0))],
dim=0,
)
if not args.not_use_memory:
model.update_memory(postive_pair_g)
if args.fast_mode:
......@@ -79,52 +91,108 @@ def test_val(model, dataloader, sampler, criterion, args):
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument("--epochs", type=int, default=50,
help='epochs for training on entire dataset')
parser.add_argument("--batch_size", type=int,
default=200, help="Size of each batch")
parser.add_argument("--embedding_dim", type=int, default=100,
help="Embedding dim for link prediction")
parser.add_argument("--memory_dim", type=int, default=100,
help="dimension of memory")
parser.add_argument("--temporal_dim", type=int, default=100,
help="Temporal dimension for time encoding")
parser.add_argument("--memory_updater", type=str, default='gru',
help="Recurrent unit for memory update")
parser.add_argument("--aggregator", type=str, default='last',
help="Aggregation method for memory update")
parser.add_argument("--n_neighbors", type=int, default=10,
help="number of neighbors while doing embedding")
parser.add_argument("--sampling_method", type=str, default='topk',
help="how a node aggregates from its neighbors during embedding")
parser.add_argument("--num_heads", type=int, default=8,
help="Number of heads for multihead attention mechanism")
parser.add_argument("--fast_mode", action="store_true", default=False,
help="Fast mode uses batch temporal sampling; history within the same batch cannot be obtained")
parser.add_argument("--simple_mode", action="store_true", default=False,
help="Simple mode directly deletes the temporal edges from the original static graph")
parser.add_argument("--num_negative_samples", type=int, default=1,
help="number of negative samples per positive sample")
parser.add_argument("--dataset", type=str, default="wikipedia",
help="dataset selection wikipedia/reddit")
parser.add_argument("--k_hop", type=int, default=1,
help="sampling k-hop neighborhood")
parser.add_argument("--not_use_memory", action="store_true", default=False,
help="Disable the memory module of the TGN model")
parser.add_argument(
"--epochs",
type=int,
default=50,
help="epochs for training on entire dataset",
)
parser.add_argument(
"--batch_size", type=int, default=200, help="Size of each batch"
)
parser.add_argument(
"--embedding_dim",
type=int,
default=100,
help="Embedding dim for link prediction",
)
parser.add_argument(
"--memory_dim", type=int, default=100, help="dimension of memory"
)
parser.add_argument(
"--temporal_dim",
type=int,
default=100,
help="Temporal dimension for time encoding",
)
parser.add_argument(
"--memory_updater",
type=str,
default="gru",
help="Recurrent unit for memory update",
)
parser.add_argument(
"--aggregator",
type=str,
default="last",
help="Aggregation method for memory update",
)
parser.add_argument(
"--n_neighbors",
type=int,
default=10,
help="number of neighbors while doing embedding",
)
parser.add_argument(
"--sampling_method",
type=str,
default="topk",
help="how a node aggregates from its neighbors during embedding",
)
parser.add_argument(
"--num_heads",
type=int,
default=8,
help="Number of heads for multihead attention mechanism",
)
parser.add_argument(
"--fast_mode",
action="store_true",
default=False,
help="Fast mode uses batch temporal sampling; history within the same batch cannot be obtained",
)
parser.add_argument(
"--simple_mode",
action="store_true",
default=False,
help="Simple mode directly deletes the temporal edges from the original static graph",
)
parser.add_argument(
"--num_negative_samples",
type=int,
default=1,
help="number of negative samples per positive sample",
)
parser.add_argument(
"--dataset",
type=str,
default="wikipedia",
help="dataset selection wikipedia/reddit",
)
parser.add_argument(
"--k_hop", type=int, default=1, help="sampling k-hop neighborhood"
)
parser.add_argument(
"--not_use_memory",
action="store_true",
default=False,
help="Disable the memory module of the TGN model",
)
args = parser.parse_args()
assert not (
args.fast_mode and args.simple_mode), "you can only choose one sampling mode"
args.fast_mode and args.simple_mode
), "you can only choose one sampling mode"
if args.k_hop != 1:
assert args.simple_mode, "this k-hop parameter only support simple mode"
if args.dataset == 'wikipedia':
if args.dataset == "wikipedia":
data = TemporalWikipediaDataset()
elif args.dataset == 'reddit':
elif args.dataset == "reddit":
data = TemporalRedditDataset()
else:
print("Warning Using Untested Dataset: "+args.dataset)
print("Warning Using Untested Dataset: " + args.dataset)
data = TemporalDataset(args.dataset)
# Pre-process data, mask new node in test set from original graph
......@@ -132,22 +200,33 @@ if __name__ == "__main__":
num_edges = data.num_edges()
num_edges = data.num_edges()
trainval_div = int(VALID_SPLIT*num_edges)
trainval_div = int(VALID_SPLIT * num_edges)
# Select new nodes from the test set and remove them from the entire graph
test_split_ts = data.edata['timestamp'][trainval_div]
test_nodes = torch.cat([data.edges()[0][trainval_div:], data.edges()[
1][trainval_div:]]).unique().numpy()
test_split_ts = data.edata["timestamp"][trainval_div]
test_nodes = (
torch.cat(
[data.edges()[0][trainval_div:], data.edges()[1][trainval_div:]]
)
.unique()
.numpy()
)
test_new_nodes = np.random.choice(
test_nodes, int(0.1*len(test_nodes)), replace=False)
test_nodes, int(0.1 * len(test_nodes)), replace=False
)
in_subg = dgl.in_subgraph(data, test_new_nodes)
out_subg = dgl.out_subgraph(data, test_new_nodes)
# Remove edges that occur before the test set to prevent leaking connection information
new_node_in_eid_delete = in_subg.edata[dgl.EID][in_subg.edata['timestamp'] < test_split_ts]
new_node_out_eid_delete = out_subg.edata[dgl.EID][out_subg.edata['timestamp'] < test_split_ts]
new_node_in_eid_delete = in_subg.edata[dgl.EID][
in_subg.edata["timestamp"] < test_split_ts
]
new_node_out_eid_delete = out_subg.edata[dgl.EID][
out_subg.edata["timestamp"] < test_split_ts
]
new_node_eid_delete = torch.cat(
[new_node_in_eid_delete, new_node_out_eid_delete]).unique()
[new_node_in_eid_delete, new_node_out_eid_delete]
).unique()
graph_new_node = copy.deepcopy(data)
# relative order preserved
......@@ -178,26 +257,39 @@ if __name__ == "__main__":
edge_collator = TemporalEdgeCollator
neg_sampler = dgl.dataloading.negative_sampler.Uniform(
k=args.num_negative_samples)
k=args.num_negative_samples
)
# Set Train, validation, test and new node test id
train_seed = torch.arange(int(TRAIN_SPLIT*graph_no_new_node.num_edges()))
valid_seed = torch.arange(int(
TRAIN_SPLIT*graph_no_new_node.num_edges()), trainval_div-new_node_eid_delete.size(0))
train_seed = torch.arange(int(TRAIN_SPLIT * graph_no_new_node.num_edges()))
valid_seed = torch.arange(
int(TRAIN_SPLIT * graph_no_new_node.num_edges()),
trainval_div - new_node_eid_delete.size(0),
)
test_seed = torch.arange(
trainval_div-new_node_eid_delete.size(0), graph_no_new_node.num_edges())
trainval_div - new_node_eid_delete.size(0),
graph_no_new_node.num_edges(),
)
test_new_node_seed = torch.arange(
trainval_div-new_node_eid_delete.size(0), graph_new_node.num_edges())
g_sampling = None if args.fast_mode else dgl.add_reverse_edges(
graph_no_new_node, copy_edata=True)
new_node_g_sampling = None if args.fast_mode else dgl.add_reverse_edges(
graph_new_node, copy_edata=True)
trainval_div - new_node_eid_delete.size(0), graph_new_node.num_edges()
)
g_sampling = (
None
if args.fast_mode
else dgl.add_reverse_edges(graph_no_new_node, copy_edata=True)
)
new_node_g_sampling = (
None
if args.fast_mode
else dgl.add_reverse_edges(graph_new_node, copy_edata=True)
)
if not args.fast_mode:
new_node_g_sampling.ndata[dgl.NID] = new_node_g_sampling.nodes()
g_sampling.ndata[dgl.NID] = new_node_g_sampling.nodes()
# We highly recommend always setting num_workers=0; otherwise the sampled subgraph may not be correct.
train_dataloader = TemporalEdgeDataLoader(graph_no_new_node,
train_dataloader = TemporalEdgeDataLoader(
graph_no_new_node,
train_seed,
sampler,
batch_size=args.batch_size,
......@@ -206,9 +298,11 @@ if __name__ == "__main__":
drop_last=False,
num_workers=0,
collator=edge_collator,
g_sampling=g_sampling)
g_sampling=g_sampling,
)
valid_dataloader = TemporalEdgeDataLoader(graph_no_new_node,
valid_dataloader = TemporalEdgeDataLoader(
graph_no_new_node,
valid_seed,
sampler,
batch_size=args.batch_size,
......@@ -217,9 +311,11 @@ if __name__ == "__main__":
drop_last=False,
num_workers=0,
collator=edge_collator,
g_sampling=g_sampling)
g_sampling=g_sampling,
)
test_dataloader = TemporalEdgeDataLoader(graph_no_new_node,
test_dataloader = TemporalEdgeDataLoader(
graph_no_new_node,
test_seed,
sampler,
batch_size=args.batch_size,
......@@ -228,9 +324,11 @@ if __name__ == "__main__":
drop_last=False,
num_workers=0,
collator=edge_collator,
g_sampling=g_sampling)
g_sampling=g_sampling,
)
test_new_node_dataloader = TemporalEdgeDataLoader(graph_new_node,
test_new_node_dataloader = TemporalEdgeDataLoader(
graph_new_node,
test_new_node_seed,
new_node_sampler if args.fast_mode else sampler,
batch_size=args.batch_size,
......@@ -239,12 +337,14 @@ if __name__ == "__main__":
drop_last=False,
num_workers=0,
collator=edge_collator,
g_sampling=new_node_g_sampling)
g_sampling=new_node_g_sampling,
)
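# Editorial sketch (assumed iteration protocol, not part of this diff): the
# TemporalEdgeDataLoader used here is expected to behave like DGL's
# EdgeDataLoader, yielding positive/negative pair graphs plus sampled blocks
# for each mini-batch, roughly:
#     for _, positive_pair_g, negative_pair_g, blocks in train_dataloader:
#         ...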
edge_dim = data.edata['feats'].shape[1]
edge_dim = data.edata["feats"].shape[1]
num_node = data.num_nodes()
model = TGN(edge_feat_dim=edge_dim,
model = TGN(
edge_feat_dim=edge_dim,
memory_dim=args.memory_dim,
temporal_dim=args.temporal_dim,
embedding_dim=args.embedding_dim,
......@@ -252,44 +352,58 @@ if __name__ == "__main__":
num_nodes=num_node,
n_neighbors=args.n_neighbors,
memory_updater_type=args.memory_updater,
layers=args.k_hop)
layers=args.k_hop,
)
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=0.0001)
# Implement Logging mechanism
f = open("logging.txt", 'w')
f = open("logging.txt", "w")
if args.fast_mode:
sampler.reset()
try:
for i in range(args.epochs):
train_loss = train(model, train_dataloader, sampler,
criterion, optimizer, args)
train_loss = train(
model, train_dataloader, sampler, criterion, optimizer, args
)
val_ap, val_auc = test_val(
model, valid_dataloader, sampler, criterion, args)
model, valid_dataloader, sampler, criterion, args
)
memory_checkpoint = model.store_memory()
if args.fast_mode:
new_node_sampler.sync(sampler)
test_ap, test_auc = test_val(
model, test_dataloader, sampler, criterion, args)
model, test_dataloader, sampler, criterion, args
)
model.restore_memory(memory_checkpoint)
if args.fast_mode:
sample_nn = new_node_sampler
else:
sample_nn = sampler
nn_test_ap, nn_test_auc = test_val(
model, test_new_node_dataloader, sample_nn, criterion, args)
model, test_new_node_dataloader, sample_nn, criterion, args
)
log_content = []
log_content.append("Epoch: {}; Training Loss: {} | Validation AP: {:.3f} AUC: {:.3f}\n".format(
i, train_loss, val_ap, val_auc))
log_content.append(
"Epoch: {}; Test AP: {:.3f} AUC: {:.3f}\n".format(i, test_ap, test_auc))
log_content.append("Epoch: {}; Test New Node AP: {:.3f} AUC: {:.3f}\n".format(
i, nn_test_ap, nn_test_auc))
"Epoch: {}; Training Loss: {} | Validation AP: {:.3f} AUC: {:.3f}\n".format(
i, train_loss, val_ap, val_auc
)
)
log_content.append(
"Epoch: {}; Test AP: {:.3f} AUC: {:.3f}\n".format(
i, test_ap, test_auc
)
)
log_content.append(
"Epoch: {}; Test New Node AP: {:.3f} AUC: {:.3f}\n".format(
i, nn_test_ap, nn_test_auc
)
)
f.writelines(log_content)
model.reset_memory()
if i < args.epochs-1 and args.fast_mode:
if i < args.epochs - 1 and args.fast_mode:
sampler.reset()
print(log_content[0], log_content[1], log_content[2])
except KeyboardInterrupt:
......
from .graph import *
from .fields import *
from .utils import prepare_dataset
import os
import random
from .fields import *
from .graph import *
from .utils import prepare_dataset
class ClassificationDataset(object):
"Dataset class for classification task."
def __init__(self):
raise NotImplementedError
class TranslationDataset(object):
'''
"""
Dataset class for translation task.
By default, the source language shares the same vocabulary with the target language.
'''
INIT_TOKEN = '<sos>'
EOS_TOKEN = '<eos>'
PAD_TOKEN = '<pad>'
"""
INIT_TOKEN = "<sos>"
EOS_TOKEN = "<eos>"
PAD_TOKEN = "<pad>"
MAX_LENGTH = 50
def __init__(self, path, exts, train='train', valid='valid', test='test', vocab='vocab.txt', replace_oov=None):
def __init__(
self,
path,
exts,
train="train",
valid="valid",
test="test",
vocab="vocab.txt",
replace_oov=None,
):
vocab_path = os.path.join(path, vocab)
self.src = {}
self.tgt = {}
with open(os.path.join(path, train + '.' + exts[0]), 'r', encoding='utf-8') as f:
self.src['train'] = f.readlines()
with open(os.path.join(path, train + '.' + exts[1]), 'r', encoding='utf-8') as f:
self.tgt['train'] = f.readlines()
with open(os.path.join(path, valid + '.' + exts[0]), 'r', encoding='utf-8') as f:
self.src['valid'] = f.readlines()
with open(os.path.join(path, valid + '.' + exts[1]), 'r', encoding='utf-8') as f:
self.tgt['valid'] = f.readlines()
with open(os.path.join(path, test + '.' + exts[0]), 'r', encoding='utf-8') as f:
self.src['test'] = f.readlines()
with open(os.path.join(path, test + '.' + exts[1]), 'r', encoding='utf-8') as f:
self.tgt['test'] = f.readlines()
with open(
os.path.join(path, train + "." + exts[0]), "r", encoding="utf-8"
) as f:
self.src["train"] = f.readlines()
with open(
os.path.join(path, train + "." + exts[1]), "r", encoding="utf-8"
) as f:
self.tgt["train"] = f.readlines()
with open(
os.path.join(path, valid + "." + exts[0]), "r", encoding="utf-8"
) as f:
self.src["valid"] = f.readlines()
with open(
os.path.join(path, valid + "." + exts[1]), "r", encoding="utf-8"
) as f:
self.tgt["valid"] = f.readlines()
with open(
os.path.join(path, test + "." + exts[0]), "r", encoding="utf-8"
) as f:
self.src["test"] = f.readlines()
with open(
os.path.join(path, test + "." + exts[1]), "r", encoding="utf-8"
) as f:
self.tgt["test"] = f.readlines()
if not os.path.exists(vocab_path):
self._make_vocab(vocab_path)
vocab = Vocab(init_token=self.INIT_TOKEN,
vocab = Vocab(
init_token=self.INIT_TOKEN,
eos_token=self.EOS_TOKEN,
pad_token=self.PAD_TOKEN,
unk_token=replace_oov)
unk_token=replace_oov,
)
vocab.load(vocab_path)
self.vocab = vocab
strip_func = lambda x: x[:self.MAX_LENGTH]
self.src_field = Field(vocab,
preprocessing=None,
postprocessing=strip_func)
self.tgt_field = Field(vocab,
preprocessing=lambda seq: [self.INIT_TOKEN] + seq + [self.EOS_TOKEN],
postprocessing=strip_func)
def get_seq_by_id(self, idx, mode='train', field='src'):
strip_func = lambda x: x[: self.MAX_LENGTH]
self.src_field = Field(
vocab, preprocessing=None, postprocessing=strip_func
)
self.tgt_field = Field(
vocab,
preprocessing=lambda seq: [self.INIT_TOKEN]
+ seq
+ [self.EOS_TOKEN],
postprocessing=strip_func,
)
def get_seq_by_id(self, idx, mode="train", field="src"):
"get raw sequence in dataset by specifying index, mode(train/valid/test), field(src/tgt)"
if field == 'src':
if field == "src":
return self.src[mode][idx].strip().split()
else:
return [self.INIT_TOKEN] + self.tgt[mode][idx].strip().split() + [self.EOS_TOKEN]
return (
[self.INIT_TOKEN]
+ self.tgt[mode][idx].strip().split()
+ [self.EOS_TOKEN]
)
def _make_vocab(self, path, thres=2):
word_dict = {}
for mode in ['train', 'valid', 'test']:
for mode in ["train", "valid", "test"]:
for line in self.src[mode] + self.tgt[mode]:
for token in line.strip().split():
if token not in word_dict:
......@@ -69,7 +106,7 @@ class TranslationDataset(object):
else:
word_dict[token] += 1
with open(path, 'w') as f:
with open(path, "w") as f:
for k, v in word_dict.items():
if v > thres:
print(k, file=f)
......@@ -90,9 +127,17 @@ class TranslationDataset(object):
def eos_id(self):
return self.vocab[self.EOS_TOKEN]
def __call__(self, graph_pool, mode='train', batch_size=32, k=1,
device='cpu', dev_rank=0, ndev=1):
'''
def __call__(
self,
graph_pool,
mode="train",
batch_size=32,
k=1,
device="cpu",
dev_rank=0,
ndev=1,
):
"""
Create a batched graph corresponding to a mini-batch of the dataset.
args:
graph_pool: a GraphPool object used to accelerate graph construction.
......@@ -102,7 +147,7 @@ class TranslationDataset(object):
device: str or torch.device
dev_rank: rank (id) of current device
ndev: number of devices
'''
"""
src_data, tgt_data = self.src[mode], self.tgt[mode]
n = len(src_data)
# make sure all devices have the same number of batches
......@@ -111,28 +156,30 @@ class TranslationDataset(object):
# XXX: partition then shuffle may not be equivalent to shuffle then
# partition
order = list(range(dev_rank, n, ndev))
if mode == 'train':
if mode == "train":
random.shuffle(order)
src_buf, tgt_buf = [], []
for idx in order:
src_sample = self.src_field(
src_data[idx].strip().split())
tgt_sample = self.tgt_field(
tgt_data[idx].strip().split())
src_sample = self.src_field(src_data[idx].strip().split())
tgt_sample = self.tgt_field(tgt_data[idx].strip().split())
src_buf.append(src_sample)
tgt_buf.append(tgt_sample)
if len(src_buf) == batch_size:
if mode == 'test':
yield graph_pool.beam(src_buf, self.sos_id, self.MAX_LENGTH, k, device=device)
if mode == "test":
yield graph_pool.beam(
src_buf, self.sos_id, self.MAX_LENGTH, k, device=device
)
else:
yield graph_pool(src_buf, tgt_buf, device=device)
src_buf, tgt_buf = [], []
if len(src_buf) != 0:
if mode == 'test':
yield graph_pool.beam(src_buf, self.sos_id, self.MAX_LENGTH, k, device=device)
if mode == "test":
yield graph_pool.beam(
src_buf, self.sos_id, self.MAX_LENGTH, k, device=device
)
else:
yield graph_pool(src_buf, tgt_buf, device=device)
......@@ -145,38 +192,46 @@ class TranslationDataset(object):
l = seq.index(self.eos_id)
except ValueError:
l = len(seq)
ret.append(' '.join(self.vocab[token] for token in seq[:l] if not token in filter_list))
ret.append(
" ".join(
self.vocab[token]
for token in seq[:l]
if token not in filter_list
)
)
return ret
def get_dataset(dataset):
"we wrapped a set of datasets as example"
prepare_dataset(dataset)
if dataset == 'babi':
if dataset == "babi":
raise NotImplementedError
elif dataset == 'copy' or dataset == 'sort':
elif dataset == "copy" or dataset == "sort":
return TranslationDataset(
'data/{}'.format(dataset),
('in', 'out'),
train='train',
valid='valid',
test='test',
"data/{}".format(dataset),
("in", "out"),
train="train",
valid="valid",
test="test",
)
elif dataset == 'multi30k':
elif dataset == "multi30k":
return TranslationDataset(
'data/multi30k',
('en.atok', 'de.atok'),
train='train',
valid='val',
test='test2016',
replace_oov='<unk>'
"data/multi30k",
("en.atok", "de.atok"),
train="train",
valid="val",
test="test2016",
replace_oov="<unk>",
)
elif dataset == 'wmt14':
elif dataset == "wmt14":
return TranslationDataset(
'data/wmt14',
('en', 'de'),
train='train.tok.clean.bpe.32000',
valid='newstest2013.tok.bpe.32000',
test='newstest2014.tok.bpe.32000.ende',
vocab='vocab.bpe.32000')
"data/wmt14",
("en", "de"),
train="train.tok.clean.bpe.32000",
valid="newstest2013.tok.bpe.32000",
test="newstest2014.tok.bpe.32000.ende",
vocab="vocab.bpe.32000",
)
else:
raise KeyError()
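# Editorial usage sketch (not part of this diff; the dataset files are
# downloaded or generated by prepare_dataset inside get_dataset):
#     dataset = get_dataset("copy")
#     print(len(dataset.vocab))                                   # vocabulary size
#     print(dataset.get_seq_by_id(0, mode="train", field="src"))  # raw tokens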
class Vocab:
def __init__(self, init_token=None, eos_token=None, pad_token=None, unk_token=None):
def __init__(
self, init_token=None, eos_token=None, pad_token=None, unk_token=None
):
self.init_token = init_token
self.eos_token = eos_token
self.pad_token = pad_token
......@@ -16,13 +18,11 @@ class Vocab:
self.vocab_lst.append(self.pad_token)
if self.unk_token is not None:
self.vocab_lst.append(self.unk_token)
with open(path, 'r', encoding='utf-8') as f:
with open(path, "r", encoding="utf-8") as f:
for token in f.readlines():
token = token.strip()
self.vocab_lst.append(token)
self.vocab_dict = {
v: k for k, v in enumerate(self.vocab_lst)
}
self.vocab_dict = {v: k for k, v in enumerate(self.vocab_lst)}
def __len__(self):
return len(self.vocab_lst)
......@@ -36,6 +36,7 @@ class Vocab:
else:
return self.vocab_lst[key]
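# Editorial sketch (behaviour inferred from the surrounding code, not part of
# this diff): Vocab maps tokens to ids via vocab_dict and ids back to tokens
# via vocab_lst, so indexing works in both directions:
#     vocab = Vocab(init_token="<sos>", eos_token="<eos>", pad_token="<pad>", unk_token="<unk>")
#     vocab.load("vocab.txt")      # one token per line
#     idx = vocab["<eos>"]         # token -> id
#     tok = vocab[idx]             # id -> token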
class Field:
def __init__(self, vocab, preprocessing=None, postprocessing=None):
self.vocab = vocab
......@@ -56,8 +57,4 @@ class Field:
return [self.vocab[token] for token in x]
def __call__(self, x):
return self.postprocess(
self.numericalize(
self.preprocess(x)
)
)
return self.postprocess(self.numericalize(self.preprocess(x)))
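# Editorial sketch (not part of this diff): a Field simply chains
# preprocess -> numericalize -> postprocess, so calling it on a tokenized
# sequence yields truncated token ids:
#     field = Field(vocab, preprocessing=None, postprocessing=lambda x: x[:50])
#     ids = field("a b c".split())     # list of ids for "a", "b", "c"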
import dgl
import torch as th
import numpy as np
import itertools
import time
from collections import *
Graph = namedtuple('Graph',
['g', 'src', 'tgt', 'tgt_y', 'nids', 'eids', 'nid_arr', 'n_nodes', 'n_edges', 'n_tokens'])
import numpy as np
import torch as th
import dgl
Graph = namedtuple(
"Graph",
[
"g",
"src",
"tgt",
"tgt_y",
"nids",
"eids",
"nid_arr",
"n_nodes",
"n_edges",
"n_tokens",
],
)
class GraphPool:
"Create a graph pool in advance to accelerate graph building phase in Transformer."
def __init__(self, n=50, m=50):
'''
"""
args:
n: maximum length of input sequence.
m: maximum length of output sequence.
'''
print('start creating graph pool...')
"""
print("start creating graph pool...")
tic = time.time()
self.n, self.m = n, m
g_pool = [[dgl.graph(([], [])) for _ in range(m)] for _ in range(n)]
num_edges = {
'ee': np.zeros((n, n)).astype(int),
'ed': np.zeros((n, m)).astype(int),
'dd': np.zeros((m, m)).astype(int)
"ee": np.zeros((n, n)).astype(int),
"ed": np.zeros((n, m)).astype(int),
"dd": np.zeros((m, m)).astype(int),
}
for i, j in itertools.product(range(n), range(m)):
src_length = i + 1
......@@ -37,25 +54,29 @@ class GraphPool:
us = enc_nodes.unsqueeze(-1).repeat(1, src_length).view(-1)
vs = enc_nodes.repeat(src_length)
g_pool[i][j].add_edges(us, vs)
num_edges['ee'][i][j] = len(us)
num_edges["ee"][i][j] = len(us)
# enc -> dec
us = enc_nodes.unsqueeze(-1).repeat(1, tgt_length).view(-1)
vs = dec_nodes.repeat(src_length)
g_pool[i][j].add_edges(us, vs)
num_edges['ed'][i][j] = len(us)
num_edges["ed"][i][j] = len(us)
# dec -> dec
indices = th.triu(th.ones(tgt_length, tgt_length)) == 1
us = dec_nodes.unsqueeze(-1).repeat(1, tgt_length)[indices]
vs = dec_nodes.unsqueeze(0).repeat(tgt_length, 1)[indices]
g_pool[i][j].add_edges(us, vs)
num_edges['dd'][i][j] = len(us)
num_edges["dd"][i][j] = len(us)
print('successfully created graph pool, time: {0:0.3f}s'.format(time.time() - tic))
print(
"successfully created graph pool, time: {0:0.3f}s".format(
time.time() - tic
)
)
self.g_pool = g_pool
self.num_edges = num_edges
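# Editorial note (a sketch, not part of this diff): g_pool[i][j] caches the
# fully connected encoder/decoder attention graph for a source of length
# i + 1 and a target of length j + 1, so batching later only stitches cached
# graphs together:
#     pool = GraphPool(n=50, m=50)
#     g = pool.g_pool[9][14]                       # graph for src_len = 10, tgt_len = 15
#     print(pool.num_edges["ee"][9][14], g.num_edges())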
def beam(self, src_buf, start_sym, max_len, k, device='cpu'):
'''
def beam(self, src_buf, start_sym, max_len, k, device="cpu"):
"""
Return a batched graph for beam search during Transformer inference.
args:
src_buf: a list of input sequence
......@@ -63,16 +84,16 @@ class GraphPool:
max_len: maximum length for decoding
k: beam size
device: 'cpu' or 'cuda:*'
'''
"""
g_list = []
src_lens = [len(_) for _ in src_buf]
tgt_lens = [max_len] * len(src_buf)
num_edges = {'ee': [], 'ed': [], 'dd': []}
num_edges = {"ee": [], "ed": [], "dd": []}
for src_len, tgt_len in zip(src_lens, tgt_lens):
i, j = src_len - 1, tgt_len - 1
for _ in range(k):
g_list.append(self.g_pool[i][j])
for key in ['ee', 'ed', 'dd']:
for key in ["ee", "ed", "dd"]:
num_edges[key].append(int(self.num_edges[key][i][j]))
g = dgl.batch(g_list)
......@@ -81,57 +102,85 @@ class GraphPool:
enc_ids, dec_ids = [], []
e2e_eids, e2d_eids, d2d_eids = [], [], []
n_nodes, n_edges, n_tokens = 0, 0, 0
for src_sample, n, n_ee, n_ed, n_dd in zip(src_buf, src_lens, num_edges['ee'], num_edges['ed'], num_edges['dd']):
for src_sample, n, n_ee, n_ed, n_dd in zip(
src_buf, src_lens, num_edges["ee"], num_edges["ed"], num_edges["dd"]
):
for _ in range(k):
src.append(th.tensor(src_sample, dtype=th.long, device=device))
src_pos.append(th.arange(n, dtype=th.long, device=device))
enc_ids.append(th.arange(n_nodes, n_nodes + n, dtype=th.long, device=device))
enc_ids.append(
th.arange(
n_nodes, n_nodes + n, dtype=th.long, device=device
)
)
n_nodes += n
e2e_eids.append(th.arange(n_edges, n_edges + n_ee, dtype=th.long, device=device))
e2e_eids.append(
th.arange(
n_edges, n_edges + n_ee, dtype=th.long, device=device
)
)
n_edges += n_ee
tgt_seq = th.zeros(max_len, dtype=th.long, device=device)
tgt_seq[0] = start_sym
tgt.append(tgt_seq)
tgt_pos.append(th.arange(max_len, dtype=th.long, device=device))
dec_ids.append(th.arange(n_nodes, n_nodes + max_len, dtype=th.long, device=device))
dec_ids.append(
th.arange(
n_nodes, n_nodes + max_len, dtype=th.long, device=device
)
)
n_nodes += max_len
e2d_eids.append(th.arange(n_edges, n_edges + n_ed, dtype=th.long, device=device))
e2d_eids.append(
th.arange(
n_edges, n_edges + n_ed, dtype=th.long, device=device
)
)
n_edges += n_ed
d2d_eids.append(th.arange(n_edges, n_edges + n_dd, dtype=th.long, device=device))
d2d_eids.append(
th.arange(
n_edges, n_edges + n_dd, dtype=th.long, device=device
)
)
n_edges += n_dd
g.set_n_initializer(dgl.init.zero_initializer)
g.set_e_initializer(dgl.init.zero_initializer)
g = g.to(device).long()
return Graph(g=g,
return Graph(
g=g,
src=(th.cat(src), th.cat(src_pos)),
tgt=(th.cat(tgt), th.cat(tgt_pos)),
tgt_y=None,
nids = {'enc': th.cat(enc_ids), 'dec': th.cat(dec_ids)},
eids = {'ee': th.cat(e2e_eids), 'ed': th.cat(e2d_eids), 'dd': th.cat(d2d_eids)},
nid_arr = {'enc': enc_ids, 'dec': dec_ids},
nids={"enc": th.cat(enc_ids), "dec": th.cat(dec_ids)},
eids={
"ee": th.cat(e2e_eids),
"ed": th.cat(e2d_eids),
"dd": th.cat(d2d_eids),
},
nid_arr={"enc": enc_ids, "dec": dec_ids},
n_nodes=n_nodes,
n_edges=n_edges,
n_tokens=n_tokens)
n_tokens=n_tokens,
)
def __call__(self, src_buf, tgt_buf, device='cpu'):
'''
def __call__(self, src_buf, tgt_buf, device="cpu"):
"""
Return a batched graph for the training phase of the Transformer.
args:
src_buf: a set of input sequence arrays.
tgt_buf: a set of output sequence arrays.
device: 'cpu' or 'cuda:*'
'''
"""
g_list = []
src_lens = [len(_) for _ in src_buf]
tgt_lens = [len(_) - 1 for _ in tgt_buf]
num_edges = {'ee': [], 'ed': [], 'dd': []}
num_edges = {"ee": [], "ed": [], "dd": []}
for src_len, tgt_len in zip(src_lens, tgt_lens):
i, j = src_len - 1, tgt_len - 1
g_list.append(self.g_pool[i][j])
for key in ['ee', 'ed', 'dd']:
for key in ["ee", "ed", "dd"]:
num_edges[key].append(int(self.num_edges[key][i][j]))
g = dgl.batch(g_list)
......@@ -140,36 +189,61 @@ class GraphPool:
enc_ids, dec_ids = [], []
e2e_eids, d2d_eids, e2d_eids = [], [], []
n_nodes, n_edges, n_tokens = 0, 0, 0
for src_sample, tgt_sample, n, m, n_ee, n_ed, n_dd in zip(src_buf, tgt_buf, src_lens, tgt_lens, num_edges['ee'], num_edges['ed'], num_edges['dd']):
for src_sample, tgt_sample, n, m, n_ee, n_ed, n_dd in zip(
src_buf,
tgt_buf,
src_lens,
tgt_lens,
num_edges["ee"],
num_edges["ed"],
num_edges["dd"],
):
src.append(th.tensor(src_sample, dtype=th.long, device=device))
tgt.append(th.tensor(tgt_sample[:-1], dtype=th.long, device=device))
tgt_y.append(th.tensor(tgt_sample[1:], dtype=th.long, device=device))
tgt_y.append(
th.tensor(tgt_sample[1:], dtype=th.long, device=device)
)
src_pos.append(th.arange(n, dtype=th.long, device=device))
tgt_pos.append(th.arange(m, dtype=th.long, device=device))
enc_ids.append(th.arange(n_nodes, n_nodes + n, dtype=th.long, device=device))
enc_ids.append(
th.arange(n_nodes, n_nodes + n, dtype=th.long, device=device)
)
n_nodes += n
dec_ids.append(th.arange(n_nodes, n_nodes + m, dtype=th.long, device=device))
dec_ids.append(
th.arange(n_nodes, n_nodes + m, dtype=th.long, device=device)
)
n_nodes += m
e2e_eids.append(th.arange(n_edges, n_edges + n_ee, dtype=th.long, device=device))
e2e_eids.append(
th.arange(n_edges, n_edges + n_ee, dtype=th.long, device=device)
)
n_edges += n_ee
e2d_eids.append(th.arange(n_edges, n_edges + n_ed, dtype=th.long, device=device))
e2d_eids.append(
th.arange(n_edges, n_edges + n_ed, dtype=th.long, device=device)
)
n_edges += n_ed
d2d_eids.append(th.arange(n_edges, n_edges + n_dd, dtype=th.long, device=device))
d2d_eids.append(
th.arange(n_edges, n_edges + n_dd, dtype=th.long, device=device)
)
n_edges += n_dd
n_tokens += m
g.set_n_initializer(dgl.init.zero_initializer)
g.set_e_initializer(dgl.init.zero_initializer)
g = g.to(device).long()
return Graph(g=g,
return Graph(
g=g,
src=(th.cat(src), th.cat(src_pos)),
tgt=(th.cat(tgt), th.cat(tgt_pos)),
tgt_y=th.cat(tgt_y),
nids = {'enc': th.cat(enc_ids), 'dec': th.cat(dec_ids)},
eids = {'ee': th.cat(e2e_eids), 'ed': th.cat(e2d_eids), 'dd': th.cat(d2d_eids)},
nid_arr = {'enc': enc_ids, 'dec': dec_ids},
nids={"enc": th.cat(enc_ids), "dec": th.cat(dec_ids)},
eids={
"ee": th.cat(e2e_eids),
"ed": th.cat(e2d_eids),
"dd": th.cat(d2d_eids),
},
nid_arr={"enc": enc_ids, "dec": dec_ids},
n_nodes=n_nodes,
n_edges=n_edges,
n_tokens=n_tokens)
n_tokens=n_tokens,
)
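# Editorial usage sketch (not part of this diff): combining GraphPool with the
# TranslationDataset defined above, a training epoch iterates batched graphs
# roughly like:
#     graph_pool = GraphPool(n=50, m=50)
#     dataset = get_dataset("copy")
#     for batch in dataset(graph_pool, mode="train", batch_size=32, device="cpu"):
#         print(batch.g.num_nodes(), batch.n_tokens)
#         break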
import os
import numpy as np
import torch as th
import os
from dgl.data.utils import *
_urls = {
'wmt': 'https://data.dgl.ai/dataset/wmt14bpe_de_en.zip',
'scripts': 'https://data.dgl.ai/dataset/transformer_scripts.zip',
"wmt": "https://data.dgl.ai/dataset/wmt14bpe_de_en.zip",
"scripts": "https://data.dgl.ai/dataset/transformer_scripts.zip",
}
def prepare_dataset(dataset_name):
"download and generate datasets"
script_dir = os.path.join('scripts')
script_dir = os.path.join("scripts")
if not os.path.exists(script_dir):
download(_urls['scripts'], path='scripts.zip')
extract_archive('scripts.zip', 'scripts')
download(_urls["scripts"], path="scripts.zip")
extract_archive("scripts.zip", "scripts")
directory = os.path.join('data', dataset_name)
directory = os.path.join("data", dataset_name)
if not os.path.exists(directory):
os.makedirs(directory)
else:
return
if dataset_name == 'multi30k':
os.system('bash scripts/prepare-multi30k.sh')
elif dataset_name == 'wmt14':
download(_urls['wmt'], path='wmt14.zip')
os.system('bash scripts/prepare-wmt14.sh')
elif dataset_name == 'copy' or dataset_name == 'tiny_copy':
if dataset_name == "multi30k":
os.system("bash scripts/prepare-multi30k.sh")
elif dataset_name == "wmt14":
download(_urls["wmt"], path="wmt14.zip")
os.system("bash scripts/prepare-wmt14.sh")
elif dataset_name == "copy" or dataset_name == "tiny_copy":
train_size = 9000
valid_size = 1000
test_size = 1000
char_list = [chr(i) for i in range(ord('a'), ord('z') + 1)]
with open(os.path.join(directory, 'train.in'), 'w') as f_in,\
open(os.path.join(directory, 'train.out'), 'w') as f_out:
for i, l in zip(range(train_size), np.random.normal(15, 3, train_size).astype(int)):
char_list = [chr(i) for i in range(ord("a"), ord("z") + 1)]
with open(os.path.join(directory, "train.in"), "w") as f_in, open(
os.path.join(directory, "train.out"), "w"
) as f_out:
for i, l in zip(
range(train_size),
np.random.normal(15, 3, train_size).astype(int),
):
l = max(l, 1)
line = ' '.join(np.random.choice(char_list, l)) + '\n'
line = " ".join(np.random.choice(char_list, l)) + "\n"
f_in.write(line)
f_out.write(line)
with open(os.path.join(directory, 'valid.in'), 'w') as f_in,\
open(os.path.join(directory, 'valid.out'), 'w') as f_out:
for i, l in zip(range(valid_size), np.random.normal(15, 3, valid_size).astype(int)):
with open(os.path.join(directory, "valid.in"), "w") as f_in, open(
os.path.join(directory, "valid.out"), "w"
) as f_out:
for i, l in zip(
range(valid_size),
np.random.normal(15, 3, valid_size).astype(int),
):
l = max(l, 1)
line = ' '.join(np.random.choice(char_list, l)) + '\n'
line = " ".join(np.random.choice(char_list, l)) + "\n"
f_in.write(line)
f_out.write(line)
with open(os.path.join(directory, 'test.in'), 'w') as f_in,\
open(os.path.join(directory, 'test.out'), 'w') as f_out:
for i, l in zip(range(test_size), np.random.normal(15, 3, test_size).astype(int)):
with open(os.path.join(directory, "test.in"), "w") as f_in, open(
os.path.join(directory, "test.out"), "w"
) as f_out:
for i, l in zip(
range(test_size), np.random.normal(15, 3, test_size).astype(int)
):
l = max(l, 1)
line = ' '.join(np.random.choice(char_list, l)) + '\n'
line = " ".join(np.random.choice(char_list, l)) + "\n"
f_in.write(line)
f_out.write(line)
with open(os.path.join(directory, 'vocab.txt'), 'w') as f:
with open(os.path.join(directory, "vocab.txt"), "w") as f:
for c in char_list:
f.write(c + '\n')
f.write(c + "\n")
elif dataset_name == 'sort' or dataset_name == 'tiny_sort':
elif dataset_name == "sort" or dataset_name == "tiny_sort":
train_size = 9000
valid_size = 1000
test_size = 1000
char_list = [chr(i) for i in range(ord('a'), ord('z') + 1)]
with open(os.path.join(directory, 'train.in'), 'w') as f_in,\
open(os.path.join(directory, 'train.out'), 'w') as f_out:
for i, l in zip(range(train_size), np.random.normal(15, 3, train_size).astype(int)):
char_list = [chr(i) for i in range(ord("a"), ord("z") + 1)]
with open(os.path.join(directory, "train.in"), "w") as f_in, open(
os.path.join(directory, "train.out"), "w"
) as f_out:
for i, l in zip(
range(train_size),
np.random.normal(15, 3, train_size).astype(int),
):
l = max(l, 1)
seq = np.random.choice(char_list, l)
f_in.write(' '.join(seq) + '\n')
f_out.write(' '.join(np.sort(seq)) + '\n')
f_in.write(" ".join(seq) + "\n")
f_out.write(" ".join(np.sort(seq)) + "\n")
with open(os.path.join(directory, 'valid.in'), 'w') as f_in,\
open(os.path.join(directory, 'valid.out'), 'w') as f_out:
for i, l in zip(range(valid_size), np.random.normal(15, 3, valid_size).astype(int)):
with open(os.path.join(directory, "valid.in"), "w") as f_in, open(
os.path.join(directory, "valid.out"), "w"
) as f_out:
for i, l in zip(
range(valid_size),
np.random.normal(15, 3, valid_size).astype(int),
):
l = max(l, 1)
seq = np.random.choice(char_list, l)
f_in.write(' '.join(seq) + '\n')
f_out.write(' '.join(np.sort(seq)) + '\n')
f_in.write(" ".join(seq) + "\n")
f_out.write(" ".join(np.sort(seq)) + "\n")
with open(os.path.join(directory, 'test.in'), 'w') as f_in,\
open(os.path.join(directory, 'test.out'), 'w') as f_out:
for i, l in zip(range(test_size), np.random.normal(15, 3, test_size).astype(int)):
with open(os.path.join(directory, "test.in"), "w") as f_in, open(
os.path.join(directory, "test.out"), "w"
) as f_out:
for i, l in zip(
range(test_size), np.random.normal(15, 3, test_size).astype(int)
):
l = max(l, 1)
seq = np.random.choice(char_list, l)
f_in.write(' '.join(seq) + '\n')
f_out.write(' '.join(np.sort(seq)) + '\n')
f_in.write(" ".join(seq) + "\n")
f_out.write(" ".join(np.sort(seq)) + "\n")
with open(os.path.join(directory, 'vocab.txt'), 'w') as f:
with open(os.path.join(directory, "vocab.txt"), "w") as f:
for c in char_list:
f.write(c + '\n')
f.write(c + "\n")
import torch as T
import torch.distributed as dist
import torch.nn as nn
import torch.nn.functional as F
import torch.distributed as dist
class LabelSmoothing(nn.Module):
"""
Compute the loss at one time step.
"""
def __init__(self, size, padding_idx, smoothing=0.0):
"""Label Smoothing module
args:
......@@ -15,7 +17,7 @@ class LabelSmoothing(nn.Module):
smoothing: smoothing ratio
"""
super(LabelSmoothing, self).__init__()
self.criterion = nn.KLDivLoss(reduction='sum')
self.criterion = nn.KLDivLoss(reduction="sum")
self.size = size
self.padding_idx = padding_idx
self.smoothing = smoothing
......@@ -26,7 +28,9 @@ class LabelSmoothing(nn.Module):
assert x.size(1) == self.size
with T.no_grad():
tgt_dist = T.zeros_like(x, dtype=T.float)
tgt_dist.fill_(self.smoothing / (self.size - 2)) # one for padding, another for label
tgt_dist.fill_(
self.smoothing / (self.size - 2)
) # one for padding, another for label
tgt_dist[:, self.padding_idx] = 0
tgt_dist.scatter_(1, target.unsqueeze(1), 1 - self.smoothing)
......@@ -36,8 +40,10 @@ class LabelSmoothing(nn.Module):
return self.criterion(x, tgt_dist)
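# Editorial sketch (not part of this diff): with size=5, padding_idx=0 and
# smoothing=0.1, a gold class of 3 produces the smoothed target distribution
#     [0.0, 0.0333, 0.0333, 0.9, 0.0333]
# i.e. 1 - smoothing on the gold class, smoothing / (size - 2) on the other
# non-padding classes, and 0 on the padding column; the KL divergence is then
# taken against the log-probabilities x.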
class SimpleLossCompute(nn.Module):
eps=1e-8
eps = 1e-8
def __init__(self, criterion, grad_accum, opt=None):
"""Loss function and optimizer for single device
......@@ -96,11 +102,16 @@ class SimpleLossCompute(nn.Module):
self.loss = self.criterion(y_pred, y) / norm
if self.opt is not None:
self.backward_and_step()
self.n_correct += ((y_pred.max(dim=-1)[1] == y) & (y != self.criterion.padding_idx)).sum().item()
self.n_correct += (
((y_pred.max(dim=-1)[1] == y) & (y != self.criterion.padding_idx))
.sum()
.item()
)
self.acc_loss += self.loss.item() * norm
self.norm_term += norm
return self.loss.item() * norm
class MultiGPULossCompute(SimpleLossCompute):
def __init__(self, criterion, ndev, grad_accum, model, opt=None):
"""Loss function and optimizer for multiple devices
......@@ -119,7 +130,9 @@ class MultiGPULossCompute(SimpleLossCompute):
Model optimizer to use. If None, then no backward and update will be
performed
"""
super(MultiGPULossCompute, self).__init__(criterion, grad_accum, opt=opt)
super(MultiGPULossCompute, self).__init__(
criterion, grad_accum, opt=opt
)
self.ndev = ndev
self.model = model
......
import numpy as np
import torch as th
import torch.nn as nn
import numpy as np
from .layers import clones
class MultiHeadAttention(nn.Module):
"Multi-Head Attention"
def __init__(self, h, dim_model):
"h: number of heads; dim_model: hidden dimension"
super(MultiHeadAttention, self).__init__()
self.d_k = dim_model // h
self.h = h
# W_q, W_k, W_v, W_o
self.linears = clones(
nn.Linear(dim_model, dim_model, bias=False), 4
)
self.linears = clones(nn.Linear(dim_model, dim_model, bias=False), 4)
def get(self, x, fields='qkv'):
def get(self, x, fields="qkv"):
"Return a dict of queries / keys / values."
batch_size = x.shape[0]
ret = {}
if 'q' in fields:
ret['q'] = self.linears[0](x).view(batch_size, self.h, self.d_k)
if 'k' in fields:
ret['k'] = self.linears[1](x).view(batch_size, self.h, self.d_k)
if 'v' in fields:
ret['v'] = self.linears[2](x).view(batch_size, self.h, self.d_k)
if "q" in fields:
ret["q"] = self.linears[0](x).view(batch_size, self.h, self.d_k)
if "k" in fields:
ret["k"] = self.linears[1](x).view(batch_size, self.h, self.d_k)
if "v" in fields:
ret["v"] = self.linears[2](x).view(batch_size, self.h, self.d_k)
return ret
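# Editorial sketch (not part of this diff): `get` projects node features into
# per-head queries/keys/values, so with h=8 and dim_model=512 each head sees
# d_k = 64 dimensions:
#     attn = MultiHeadAttention(h=8, dim_model=512)
#     qk = attn.get(th.randn(10, 512), fields="qk")
#     # qk["q"].shape == qk["k"].shape == (10, 8, 64)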
def get_o(self, x):
......