"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "9ff72433fa5a4d9f9e2f2c599e394480b581c614"
Unverified Commit 0b9df9d7 authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Black auto fix. (#4652)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent f19f05ce
import os
from time import time

import torch
import torch.optim as optim
from model import NGCF
from utility.batch_test import *
from utility.helper import early_stopping


def main(args):
    # Step 1: Prepare graph data and device ================================================================= #
    if args.gpu >= 0 and torch.cuda.is_available():
        device = "cuda:{}".format(args.gpu)
    else:
        device = "cpu"

    g = data_generator.g
    g = g.to(device)

    # Step 2: Create model and training components=========================================================== #
    model = NGCF(g, args.embed_size, args.layer_size, args.mess_dropout, args.regs[0]).to(device)
    optimizer = optim.Adam(model.parameters(), lr=args.lr)

    # Step 3: training epoches ============================================================================== #
@@ -27,16 +31,16 @@ def main(args):
    loss_loger, pre_loger, rec_loger, ndcg_loger, hit_loger = [], [], [], [], []
    for epoch in range(args.epoch):
        t1 = time()
        loss, mf_loss, emb_loss = 0.0, 0.0, 0.0
        for idx in range(n_batch):
            users, pos_items, neg_items = data_generator.sample()
            u_g_embeddings, pos_i_g_embeddings, neg_i_g_embeddings = model(
                g, "user", "item", users, pos_items, neg_items
            )
            batch_loss, batch_mf_loss, batch_emb_loss = model.create_bpr_loss(
                u_g_embeddings, pos_i_g_embeddings, neg_i_g_embeddings
            )
            optimizer.zero_grad()
            batch_loss.backward()
            optimizer.step()
@@ -45,44 +49,71 @@ def main(args):
            mf_loss += batch_mf_loss
            emb_loss += batch_emb_loss

        if (epoch + 1) % 10 != 0:
            if args.verbose > 0 and epoch % args.verbose == 0:
                perf_str = "Epoch %d [%.1fs]: train==[%.5f=%.5f + %.5f]" % (
                    epoch,
                    time() - t1,
                    loss,
                    mf_loss,
                    emb_loss,
                )
                print(perf_str)
            continue  # end the current epoch and move to the next epoch, let the following evaluation run every 10 epoches

        # evaluate the model every 10 epoches
        t2 = time()
        users_to_test = list(data_generator.test_set.keys())
        ret = test(model, g, users_to_test)
        t3 = time()

        loss_loger.append(loss)
        rec_loger.append(ret["recall"])
        pre_loger.append(ret["precision"])
        ndcg_loger.append(ret["ndcg"])
        hit_loger.append(ret["hit_ratio"])

        if args.verbose > 0:
            perf_str = (
                "Epoch %d [%.1fs + %.1fs]: train==[%.5f=%.5f + %.5f], recall=[%.5f, %.5f], "
                "precision=[%.5f, %.5f], hit=[%.5f, %.5f], ndcg=[%.5f, %.5f]"
                % (
                    epoch,
                    t2 - t1,
                    t3 - t2,
                    loss,
                    mf_loss,
                    emb_loss,
                    ret["recall"][0],
                    ret["recall"][-1],
                    ret["precision"][0],
                    ret["precision"][-1],
                    ret["hit_ratio"][0],
                    ret["hit_ratio"][-1],
                    ret["ndcg"][0],
                    ret["ndcg"][-1],
                )
            )
            print(perf_str)

        cur_best_pre_0, stopping_step, should_stop = early_stopping(
            ret["recall"][0],
            cur_best_pre_0,
            stopping_step,
            expected_order="acc",
            flag_step=5,
        )

        # early stop
        if should_stop == True:
            break

        if ret["recall"][0] == cur_best_pre_0 and args.save_flag == 1:
            torch.save(model.state_dict(), args.weights_path + args.model_name)
            print("save the weights in path: ", args.weights_path + args.model_name)

    recs = np.array(rec_loger)
    pres = np.array(pre_loger)
@@ -92,14 +123,21 @@ def main(args):
    best_rec_0 = max(recs[:, 0])
    idx = list(recs[:, 0]).index(best_rec_0)

    final_perf = (
        "Best Iter=[%d]@[%.1f]\trecall=[%s], precision=[%s], hit=[%s], ndcg=[%s]"
        % (
            idx,
            time() - t0,
            "\t".join(["%.5f" % r for r in recs[idx]]),
            "\t".join(["%.5f" % r for r in pres[idx]]),
            "\t".join(["%.5f" % r for r in hit[idx]]),
            "\t".join(["%.5f" % r for r in ndcgs[idx]]),
        )
    )
    print(final_perf)


if __name__ == "__main__":
    if not os.path.exists(args.weights_path):
        os.mkdir(args.weights_path)
    args.mess_dropout = eval(args.mess_dropout)
@@ -107,4 +145,3 @@ if __name__ == '__main__':
    args.regs = eval(args.regs)
    print(args)
    main(args)
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl.function as fn


class NGCFLayer(nn.Module):
    def __init__(self, in_size, out_size, norm_dict, dropout):
        super(NGCFLayer, self).__init__()
        self.in_size = in_size
        self.out_size = out_size

        # weights for different types of messages
        self.W1 = nn.Linear(in_size, out_size, bias=True)
        self.W2 = nn.Linear(in_size, out_size, bias=True)

        # leaky relu
        self.leaky_relu = nn.LeakyReLU(0.2)

        # dropout layer
        self.dropout = nn.Dropout(dropout)

        # initialization
        torch.nn.init.xavier_uniform_(self.W1.weight)
        torch.nn.init.constant_(self.W1.bias, 0)
        torch.nn.init.xavier_uniform_(self.W2.weight)
        torch.nn.init.constant_(self.W2.bias, 0)

        # norm
        self.norm_dict = norm_dict

    def forward(self, g, feat_dict):
        funcs = {}  # message and reduce functions dict
        # for each type of edges, compute messages and reduce them all
        for srctype, etype, dsttype in g.canonical_etypes:
            if srctype == dsttype:  # for self loops
                messages = self.W1(feat_dict[srctype])
                g.nodes[srctype].data[etype] = messages  # store in ndata
                funcs[(srctype, etype, dsttype)] = (
                    fn.copy_u(etype, "m"),
                    fn.sum("m", "h"),
                )  # define message and reduce functions
            else:
                src, dst = g.edges(etype=(srctype, etype, dsttype))
                norm = self.norm_dict[(srctype, etype, dsttype)]
                messages = norm * (
                    self.W1(feat_dict[srctype][src])
                    + self.W2(feat_dict[srctype][src] * feat_dict[dsttype][dst])
                )  # compute messages
                g.edges[(srctype, etype, dsttype)].data[etype] = messages  # store in edata
                funcs[(srctype, etype, dsttype)] = (
                    fn.copy_e(etype, "m"),
                    fn.sum("m", "h"),
                )  # define message and reduce functions

        g.multi_update_all(funcs, "sum")  # update all, reduce by first type-wisely then across different types
        feature_dict = {}
        for ntype in g.ntypes:
            h = self.leaky_relu(g.nodes[ntype].data["h"])  # leaky relu
            h = self.dropout(h)  # dropout
            h = F.normalize(h, dim=1, p=2)  # l2 normalize
            feature_dict[ntype] = h
        return feature_dict


class NGCF(nn.Module):
    def __init__(self, g, in_size, layer_size, dropout, lmbd=1e-5):
        super(NGCF, self).__init__()
@@ -60,9 +76,15 @@ class NGCF(nn.Module):
        self.norm_dict = dict()
        for srctype, etype, dsttype in g.canonical_etypes:
            src, dst = g.edges(etype=(srctype, etype, dsttype))
            dst_degree = g.in_degrees(dst, etype=(srctype, etype, dsttype)).float()  # obtain degrees
            src_degree = g.out_degrees(src, etype=(srctype, etype, dsttype)).float()
            norm = torch.pow(src_degree * dst_degree, -0.5).unsqueeze(1)  # compute norm
            self.norm_dict[(srctype, etype, dsttype)] = norm

        self.layers = nn.ModuleList()
@@ -70,16 +92,26 @@ class NGCF(nn.Module):
            NGCFLayer(in_size, layer_size[0], self.norm_dict, dropout[0])
        )
        self.num_layers = len(layer_size)
        for i in range(self.num_layers - 1):
            self.layers.append(
                NGCFLayer(layer_size[i], layer_size[i + 1], self.norm_dict, dropout[i + 1])
            )
        self.initializer = nn.init.xavier_uniform_

        # embeddings for different types of nodes
        self.feature_dict = nn.ParameterDict(
            {
                ntype: nn.Parameter(self.initializer(torch.empty(g.num_nodes(ntype), in_size)))
                for ntype in g.ntypes
            }
        )

    def create_bpr_loss(self, users, pos_items, neg_items):
        pos_scores = (users * pos_items).sum(1)
@@ -88,7 +120,11 @@ class NGCF(nn.Module):
        mf_loss = nn.LogSigmoid()(pos_scores - neg_scores).mean()
        mf_loss = -1 * mf_loss

        regularizer = (
            torch.norm(users) ** 2 + torch.norm(pos_items) ** 2 + torch.norm(neg_items) ** 2
        ) / 2
        emb_loss = self.lmbd * regularizer / users.shape[0]

        return mf_loss + emb_loss, mf_loss, emb_loss

@@ -96,9 +132,9 @@ class NGCF(nn.Module):
    def rating(self, u_g_embeddings, pos_i_g_embeddings):
        return torch.matmul(u_g_embeddings, pos_i_g_embeddings.t())

    def forward(self, g, user_key, item_key, users, pos_items, neg_items):
        h_dict = {ntype: self.feature_dict[ntype] for ntype in g.ntypes}
        # obtain features of each layer and concatenate them all
        user_embeds = []
        item_embeds = []
        user_embeds.append(h_dict[user_key])
...
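# Editorial sketch (not part of this commit), assuming random embeddings and
# illustrative sizes: what create_bpr_loss above computes, i.e.
# loss = -mean(log sigmoid(s_pos - s_neg)) + lmbd * ||emb||^2 / (2 * batch_size).
import torch
import torch.nn as nn

batch_size, dim, lmbd = 4, 8, 1e-5  # made-up values for illustration
users = torch.randn(batch_size, dim)
pos_items = torch.randn(batch_size, dim)
neg_items = torch.randn(batch_size, dim)

pos_scores = (users * pos_items).sum(1)  # scores of observed user-item pairs
neg_scores = (users * neg_items).sum(1)  # scores of sampled negatives
mf_loss = -nn.LogSigmoid()(pos_scores - neg_scores).mean()
regularizer = (users.norm() ** 2 + pos_items.norm() ** 2 + neg_items.norm() ** 2) / 2
emb_loss = lmbd * regularizer / users.shape[0]
print(mf_loss + emb_loss, mf_loss, emb_loss)  # the three values the method returns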
@@ -2,22 +2,26 @@
# <https://github.com/xiangwang1223/neural_graph_collaborative_filtering/blob/master/NGCF/utility/batch_test.py>.
# It implements the batch test.
import heapq
import multiprocessing

import utility.metrics as metrics
from utility.load_data import *
from utility.parser import parse_args

cores = multiprocessing.cpu_count()

args = parse_args()
Ks = eval(args.Ks)

data_generator = Data(path=args.data_path + args.dataset, batch_size=args.batch_size)
USR_NUM, ITEM_NUM = data_generator.n_users, data_generator.n_items
N_TRAIN, N_TEST = data_generator.n_train, data_generator.n_test
BATCH_SIZE = args.batch_size


def ranklist_by_heapq(user_pos_test, test_items, rating, Ks):
    item_score = {}
    for i in test_items:
@@ -32,9 +36,10 @@ def ranklist_by_heapq(user_pos_test, test_items, rating, Ks):
            r.append(1)
        else:
            r.append(0)
    auc = 0.0
    return r, auc


def get_auc(item_score, user_pos_test):
    item_score = sorted(item_score.items(), key=lambda kv: kv[1])
    item_score.reverse()
@@ -50,6 +55,7 @@ def get_auc(item_score, user_pos_test):
    auc = metrics.auc(ground_truth=r, prediction=posterior)
    return auc


def ranklist_by_sorted(user_pos_test, test_items, rating, Ks):
    item_score = {}
    for i in test_items:
@@ -67,6 +73,7 @@ def ranklist_by_sorted(user_pos_test, test_items, rating, Ks):
    auc = get_auc(item_score, user_pos_test)
    return r, auc


def get_performance(user_pos_test, r, auc, Ks):
    precision, recall, ndcg, hit_ratio = [], [], [], []
@@ -76,28 +83,33 @@ def get_performance(user_pos_test, r, auc, Ks):
        ndcg.append(metrics.ndcg_at_k(r, K))
        hit_ratio.append(metrics.hit_at_k(r, K))

    return {
        "recall": np.array(recall),
        "precision": np.array(precision),
        "ndcg": np.array(ndcg),
        "hit_ratio": np.array(hit_ratio),
        "auc": auc,
    }


def test_one_user(x):
    # user u's ratings for user u
    rating = x[0]
    # uid
    u = x[1]
    # user u's items in the training set
    try:
        training_items = data_generator.train_items[u]
    except Exception:
        training_items = []
    # user u's items in the test set
    user_pos_test = data_generator.test_set[u]

    all_items = set(range(ITEM_NUM))

    test_items = list(all_items - set(training_items))

    if args.test_flag == "part":
        r, auc = ranklist_by_heapq(user_pos_test, test_items, rating, Ks)
    else:
        r, auc = ranklist_by_sorted(user_pos_test, test_items, rating, Ks)
@@ -106,8 +118,13 @@ def test_one_user(x):

def test(model, g, users_to_test, batch_test_flag=False):
    result = {
        "precision": np.zeros(len(Ks)),
        "recall": np.zeros(len(Ks)),
        "ndcg": np.zeros(len(Ks)),
        "hit_ratio": np.zeros(len(Ks)),
        "auc": 0.0,
    }

    pool = multiprocessing.Pool(cores)
@@ -124,7 +141,7 @@ def test(model, g, users_to_test, batch_test_flag=False):
        start = u_batch_id * u_batch_size
        end = (u_batch_id + 1) * u_batch_size

        user_batch = test_users[start:end]

        if batch_test_flag:
            # batch-item test
@@ -138,10 +155,16 @@ def test(model, g, users_to_test, batch_test_flag=False):
                item_batch = range(i_start, i_end)

                u_g_embeddings, pos_i_g_embeddings, _ = model(
                    g, "user", "item", user_batch, item_batch, []
                )
                i_rate_batch = model.rating(u_g_embeddings, pos_i_g_embeddings).detach().cpu()

                rate_batch[:, i_start:i_end] = i_rate_batch
                i_count += i_rate_batch.shape[1]

            assert i_count == ITEM_NUM
@@ -149,20 +172,23 @@ def test(model, g, users_to_test, batch_test_flag=False):
        else:
            # all-item test
            item_batch = range(ITEM_NUM)
            u_g_embeddings, pos_i_g_embeddings, _ = model(
                g, "user", "item", user_batch, item_batch, []
            )
            rate_batch = model.rating(u_g_embeddings, pos_i_g_embeddings).detach().cpu()

        user_batch_rating_uid = zip(rate_batch.numpy(), user_batch)
        batch_result = pool.map(test_one_user, user_batch_rating_uid)
        count += len(batch_result)

        for re in batch_result:
            result["precision"] += re["precision"] / n_test_users
            result["recall"] += re["recall"] / n_test_users
            result["ndcg"] += re["ndcg"] / n_test_users
            result["hit_ratio"] += re["hit_ratio"] / n_test_users
            result["auc"] += re["auc"] / n_test_users

    assert count == n_test_users
    pool.close()
...
# This file is copied from the NGCF author's implementation
# <https://github.com/xiangwang1223/neural_graph_collaborative_filtering/blob/master/NGCF/utility/helper.py>.
# It implements the helper functions.
"""
Created on Aug 19, 2016
@author: Xiang Wang (xiangwang@u.nus.edu)
"""
__author__ = "xiangwang"
import os
import re


def txt2list(file_src):
    orig_file = open(file_src, "r")
    lines = orig_file.readlines()
    return lines


def ensureDir(dir_path):
    d = os.path.dirname(dir_path)
    if not os.path.exists(d):
        os.makedirs(d)


def uni2str(unicode_str):
    return str(unicode_str.encode("ascii", "ignore")).replace("\n", "").strip()


def hasNumbers(inputString):
    return bool(re.search(r"\d", inputString))


def delMultiChar(inputString, chars):
    for ch in chars:
        inputString = inputString.replace(ch, "")
    return inputString


def merge_two_dicts(x, y):
    z = x.copy()  # start with x's keys and values
    z.update(y)  # modifies z with y's keys and values & returns None
    return z


def early_stopping(log_value, best_value, stopping_step, expected_order="acc", flag_step=100):
    # early stopping strategy:
    assert expected_order in ["acc", "dec"]
    if (expected_order == "acc" and log_value >= best_value) or (
        expected_order == "dec" and log_value <= best_value
    ):
        stopping_step = 0
        best_value = log_value
    else:
        stopping_step += 1

    if stopping_step >= flag_step:
        print("Early stopping is trigger at step: {} log:{}".format(flag_step, log_value))
        should_stop = True
    else:
        should_stop = False
...
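# Editorial sketch (not part of this commit), assuming the collapsed tail of
# early_stopping above returns (best_value, stopping_step, should_stop), as its
# call site in main.py expects. With expected_order="acc", a value that matches
# or beats the best resets the counter; flag_step consecutive misses in a row
# flip should_stop to True.
best_value, stopping_step = 0.0, 0
for recall_at_20 in [0.10, 0.12, 0.11, 0.11, 0.11, 0.11, 0.11, 0.11]:  # made-up metric values
    best_value, stopping_step, should_stop = early_stopping(
        recall_at_20, best_value, stopping_step, expected_order="acc", flag_step=5
    )
    if should_stop:
        break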
# This file is based on the NGCF author's implementation
# <https://github.com/xiangwang1223/neural_graph_collaborative_filtering/blob/master/NGCF/utility/load_data.py>.
# It implements the data processing and graph construction.
import random as rd

import numpy as np

import dgl


class Data(object):
    def __init__(self, path, batch_size):
        self.path = path
        self.batch_size = batch_size

        train_file = path + "/train.txt"
        test_file = path + "/test.txt"

        # get number of users and items
        self.n_users, self.n_items = 0, 0
        self.n_train, self.n_test = 0, 0
        self.exist_users = []
@@ -24,7 +27,7 @@ class Data(object):
        with open(train_file) as f:
            for l in f.readlines():
                if len(l) > 0:
                    l = l.strip("\n").split(" ")
                    items = [int(i) for i in l[1:]]
                    uid = int(l[0])
                    self.exist_users.append(uid)
@@ -38,9 +41,9 @@ class Data(object):
        with open(test_file) as f:
            for l in f.readlines():
                if len(l) > 0:
                    l = l.strip("\n")
                    try:
                        items = [int(i) for i in l.split(" ")[1:]]
                    except Exception:
                        continue
                    self.n_items = max(self.n_items, max(items))
@@ -50,51 +53,51 @@ class Data(object):
        self.print_statistics()

        # training positive items corresponding to each user; testing positive items corresponding to each user
        self.train_items, self.test_set = {}, {}
        with open(train_file) as f_train:
            with open(test_file) as f_test:
                for l in f_train.readlines():
                    if len(l) == 0:
                        break
                    l = l.strip("\n")
                    items = [int(i) for i in l.split(" ")]
                    uid, train_items = items[0], items[1:]
                    self.train_items[uid] = train_items

                for l in f_test.readlines():
                    if len(l) == 0:
                        break
                    l = l.strip("\n")
                    try:
                        items = [int(i) for i in l.split(" ")]
                    except Exception:
                        continue
                    uid, test_items = items[0], items[1:]
                    self.test_set[uid] = test_items

        # construct graph from the train data and add self-loops
        user_selfs = [i for i in range(self.n_users)]
        item_selfs = [i for i in range(self.n_items)]

        data_dict = {
            ("user", "user_self", "user"): (user_selfs, user_selfs),
            ("item", "item_self", "item"): (item_selfs, item_selfs),
            ("user", "ui", "item"): (user_item_src, user_item_dst),
            ("item", "iu", "user"): (user_item_dst, user_item_src),
        }
        num_dict = {"user": self.n_users, "item": self.n_items}

        self.g = dgl.heterograph(data_dict, num_nodes_dict=num_dict)

    def sample(self):
        if self.batch_size <= self.n_users:
            users = rd.sample(self.exist_users, self.batch_size)
        else:
            users = [rd.choice(self.exist_users) for _ in range(self.batch_size)]

        def sample_pos_items_for_u(u, num):
            # sample num pos items for u-th user
@@ -117,12 +120,14 @@ class Data(object):
            while True:
                if len(neg_items) == num:
                    break
                neg_id = np.random.randint(low=0, high=self.n_items, size=1)[0]
                if neg_id not in self.train_items[u] and neg_id not in neg_items:
                    neg_items.append(neg_id)
            return neg_items

        pos_items, neg_items = [], []
        for u in users:
            pos_items += sample_pos_items_for_u(u, 1)
@@ -134,10 +139,13 @@ class Data(object):
        return self.n_users, self.n_items

    def print_statistics(self):
        print("n_users=%d, n_items=%d" % (self.n_users, self.n_items))
        print("n_interactions=%d" % (self.n_train + self.n_test))
        print(
            "n_train=%d, n_test=%d, sparsity=%.5f"
            % (self.n_train, self.n_test, (self.n_train + self.n_test) / (self.n_users * self.n_items))
        )
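# Editorial sketch (not part of this commit): a toy version of the heterograph
# built in Data.__init__ above, with 3 users, 4 items and a few interactions.
# The node and edge types mirror the ("user_self", "item_self", "ui", "iu")
# relations used by the NGCF example.
import dgl

toy_user_item_src = [0, 0, 1, 2]  # user ids
toy_user_item_dst = [1, 3, 0, 2]  # item ids they interacted with
toy_g = dgl.heterograph(
    {
        ("user", "user_self", "user"): (list(range(3)), list(range(3))),
        ("item", "item_self", "item"): (list(range(4)), list(range(4))),
        ("user", "ui", "item"): (toy_user_item_src, toy_user_item_dst),
        ("item", "iu", "user"): (toy_user_item_dst, toy_user_item_src),
    },
    num_nodes_dict={"user": 3, "item": 4},
)
print(toy_g)  # 2 node types, 4 edge types, same schema as data_generator.g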
# This file is copied from the NGCF author's implementation
# <https://github.com/xiangwang1223/neural_graph_collaborative_filtering/blob/master/NGCF/utility/metrics.py>.
# It implements the metrics.
"""
Created on Oct 10, 2018
Tensorflow Implementation of Neural Graph Collaborative Filtering (NGCF) model in:
Wang Xiang et al. Neural Graph Collaborative Filtering. In SIGIR 2019.
@author: Xiang Wang (xiangwang@u.nus.edu)
"""
import numpy as np
from sklearn.metrics import roc_auc_score


def recall(rank, ground_truth, N):
    return len(set(rank[:N]) & set(ground_truth)) / float(len(set(ground_truth)))


def precision_at_k(r, k):
@@ -28,7 +31,7 @@ def precision_at_k(r, k):
    return np.mean(r)


def average_precision(r, cut):
    """Score is average precision (area under PR curve)
    Relevance is binary (nonzero is relevant).
    Returns:
@@ -37,8 +40,8 @@ def average_precision(r,cut):
    r = np.asarray(r)
    out = [precision_at_k(r, k + 1) for k in range(cut) if r[k]]
    if not out:
        return 0.0
    return np.sum(out) / float(min(cut, np.sum(r)))


def mean_average_precision(rs):
@@ -64,8 +67,8 @@ def dcg_at_k(r, k, method=1):
        elif method == 1:
            return np.sum(r / np.log2(np.arange(2, r.size + 2)))
        else:
            raise ValueError("method must be 0 or 1.")
    return 0.0


def ndcg_at_k(r, k, method=1):
@@ -77,7 +80,7 @@ def ndcg_at_k(r, k, method=1):
    """
    dcg_max = dcg_at_k(sorted(r, reverse=True), k, method)
    if not dcg_max:
        return 0.0
    return dcg_at_k(r, k, method) / dcg_max

@@ -89,19 +92,21 @@ def recall_at_k(r, k, all_pos_num):
def hit_at_k(r, k):
    r = np.array(r)[:k]
    if np.sum(r) > 0:
        return 1.0
    else:
        return 0.0


def F1(pre, rec):
    if pre + rec > 0:
        return (2.0 * pre * rec) / (pre + rec)
    else:
        return 0.0


def auc(ground_truth, prediction):
    try:
        res = roc_auc_score(y_true=ground_truth, y_score=prediction)
    except Exception:
        res = 0.0
    return res
@@ -3,51 +3,88 @@
import argparse


def parse_args():
    parser = argparse.ArgumentParser(description="Run NGCF.")
    parser.add_argument("--weights_path", nargs="?", default="model/", help="Store model path.")
    parser.add_argument("--data_path", nargs="?", default="../Data/", help="Input data path.")
    parser.add_argument("--model_name", type=str, default="NGCF.pkl", help="Saved model name.")

    parser.add_argument(
        "--dataset",
        nargs="?",
        default="gowalla",
        help="Choose a dataset from {gowalla, yelp2018, amazon-book}",
    )
    parser.add_argument("--verbose", type=int, default=1, help="Interval of evaluation.")
    parser.add_argument("--epoch", type=int, default=400, help="Number of epoch.")

    parser.add_argument("--embed_size", type=int, default=64, help="Embedding size.")
    parser.add_argument(
        "--layer_size", nargs="?", default="[64,64,64]", help="Output sizes of every layer"
    )
    parser.add_argument("--batch_size", type=int, default=1024, help="Batch size.")

    parser.add_argument("--regs", nargs="?", default="[1e-5]", help="Regularizations.")
    parser.add_argument("--lr", type=float, default=0.0001, help="Learning rate.")

    parser.add_argument("--gpu", type=int, default=0, help="0 for NAIS_prod, 1 for NAIS_concat")

    parser.add_argument(
        "--mess_dropout",
        nargs="?",
        default="[0.1,0.1,0.1]",
        help="Keep probability w.r.t. message dropout (i.e., 1-dropout_ratio) for each deep layer. 1: no dropout.",
    )

    parser.add_argument("--Ks", nargs="?", default="[20, 40]", help="Output sizes of every layer")

    parser.add_argument(
        "--save_flag",
        type=int,
        default=1,
        help="0: Disable model saver, 1: Activate model saver",
    )

    parser.add_argument(
        "--test_flag",
        nargs="?",
        default="part",
        help="Specify the test type from {part, full}, indicating whether the reference is done in mini-batch",
    )

    parser.add_argument(
        "--report",
        type=int,
        default=0,
        help="0: Disable performance report w.r.t. sparsity levels, 1: Show performance report w.r.t. sparsity levels",
    )
    return parser.parse_args()
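# Editorial note (not part of this commit): the list-valued flags above are
# plain strings; main.py converts them with eval(), e.g.
#     args.mess_dropout = eval(args.mess_dropout)
#     args.regs = eval(args.regs)
# A quick sanity check of that conversion on the defaults:
defaults = {"layer_size": "[64,64,64]", "mess_dropout": "[0.1,0.1,0.1]", "regs": "[1e-5]"}
print({k: eval(v) for k, v in defaults.items()})
# {'layer_size': [64, 64, 64], 'mess_dropout': [0.1, 0.1, 0.1], 'regs': [1e-05]}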
import os
import warnings

import numpy as np
import torch
import torch.nn as nn
from model import PGNN
from sklearn.metrics import roc_auc_score
from utils import get_dataset, preselect_anchor

import dgl

warnings.filterwarnings("ignore")


def get_loss(p, data, out, loss_func, device, get_auc=True):
    edge_mask = np.concatenate(
        (data["positive_edges_{}".format(p)], data["negative_edges_{}".format(p)]), axis=-1
    )

    nodes_first = torch.index_select(
        out, 0, torch.from_numpy(edge_mask[0, :]).long().to(out.device)
    )
    nodes_second = torch.index_select(
        out, 0, torch.from_numpy(edge_mask[1, :]).long().to(out.device)
    )

    pred = torch.sum(nodes_first * nodes_second, dim=-1)

    label_positive = torch.ones([data["positive_edges_{}".format(p)].shape[1]], dtype=pred.dtype)
    label_negative = torch.zeros([data["negative_edges_{}".format(p)].shape[1]], dtype=pred.dtype)
    label = torch.cat((label_positive, label_negative)).to(device)
    loss = loss_func(pred, label)

    if get_auc:
        auc = roc_auc_score(
            label.flatten().cpu().numpy(), torch.sigmoid(pred).flatten().data.cpu().numpy()
        )
        return loss, auc
    else:
        return loss


def train_model(data, model, loss_func, optimizer, device, g_data):
    model.train()
    out = model(g_data)

    loss = get_loss("train", data, out, loss_func, device, get_auc=False)

    optimizer.zero_grad()
    loss.backward()
@@ -42,35 +69,41 @@ def train_model(data, model, loss_func, optimizer, device, g_data):
    return g_data


def eval_model(data, g_data, model, loss_func, device):
    model.eval()
    out = model(g_data)

    # train loss and auc
    tmp_loss, auc_train = get_loss("train", data, out, loss_func, device)
    loss_train = tmp_loss.cpu().data.numpy()

    # val loss and auc
    _, auc_val = get_loss("val", data, out, loss_func, device)

    # test loss and auc
    _, auc_test = get_loss("test", data, out, loss_func, device)

    return loss_train, auc_train, auc_val, auc_test


def main(args):
    # The mean and standard deviation of the experiment results
    # are stored in the 'results' folder
    if not os.path.isdir("results"):
        os.mkdir("results")

    if torch.cuda.is_available():
        device = "cuda:0"
    else:
        device = "cpu"

    print(
        "Learning Type: {}".format(["Transductive", "Inductive"][args.inductive]),
        "Task: {}".format(args.task),
    )

    results = []
@@ -78,13 +111,20 @@ def main(args):
        data = get_dataset(args)

        # pre-sample anchor nodes and compute shortest distance values for all epochs
        g_list, anchor_eid_list, dist_max_list, edge_weight_list = preselect_anchor(data, args)

        # model
        model = PGNN(input_dim=data["feature"].shape[1]).to(device)

        # loss
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-2, weight_decay=5e-4)
        loss_func = nn.BCEWithLogitsLoss()

        best_auc_val = -1
@@ -93,55 +133,79 @@ def main(args):
        for epoch in range(args.epoch_num):
            if epoch == 200:
                for param_group in optimizer.param_groups:
                    param_group["lr"] /= 10

            g = dgl.graph(g_list[epoch])
            g.ndata["feat"] = torch.FloatTensor(data["feature"])
            g.edata["sp_dist"] = torch.FloatTensor(edge_weight_list[epoch])
            g_data = {
                "graph": g.to(device),
                "anchor_eid": anchor_eid_list[epoch],
                "dists_max": dist_max_list[epoch],
            }

            train_model(data, model, loss_func, optimizer, device, g_data)

            loss_train, auc_train, auc_val, auc_test = eval_model(
                data, g_data, model, loss_func, device
            )
            if auc_val > best_auc_val:
                best_auc_val = auc_val
                best_auc_test = auc_test

            if epoch % args.epoch_log == 0:
                print(
                    repeat,
                    epoch,
                    "Loss {:.4f}".format(loss_train),
                    "Train AUC: {:.4f}".format(auc_train),
                    "Val AUC: {:.4f}".format(auc_val),
                    "Test AUC: {:.4f}".format(auc_test),
                    "Best Val AUC: {:.4f}".format(best_auc_val),
                    "Best Test AUC: {:.4f}".format(best_auc_test),
                )

        results.append(best_auc_test)

    results = np.array(results)
    results_mean = np.mean(results).round(6)
    results_std = np.std(results).round(6)
    print("-----------------Final-------------------")
    print(results_mean, results_std)

    with open(
        "results/{}_{}_{}.txt".format(
            ["Transductive", "Inductive"][args.inductive], args.task, args.k_hop_dist
        ),
        "w",
    ) as f:
        f.write("{}, {}\n".format(results_mean, results_std))


if __name__ == "__main__":
    from argparse import ArgumentParser

    parser = ArgumentParser()
    parser.add_argument("--task", type=str, default="link", choices=["link", "link_pair"])
    parser.add_argument(
        "--inductive", action="store_true", help="Inductive learning or transductive learning"
    )
    parser.add_argument(
        "--k_hop_dist",
        default=-1,
        type=int,
        help="K-hop shortest path distance, -1 means exact shortest path.",
    )

    parser.add_argument("--epoch_num", type=int, default=2000)
    parser.add_argument("--repeat_num", type=int, default=10)
    parser.add_argument("--epoch_log", type=int, default=100)

    args = parser.parse_args()
    main(args)
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl.function as fn


class PGNN_layer(nn.Module):
    def __init__(self, input_dim, output_dim):
        super(PGNN_layer, self).__init__()
@@ -17,23 +19,31 @@ class PGNN_layer(nn.Module):
        with graph.local_scope():
            u_feat = self.linear_hidden_u(feature)
            v_feat = self.linear_hidden_v(feature)
            graph.srcdata.update({"u_feat": u_feat})
            graph.dstdata.update({"v_feat": v_feat})
            graph.apply_edges(fn.u_mul_e("u_feat", "sp_dist", "u_message"))
            graph.apply_edges(fn.v_add_e("v_feat", "u_message", "message"))
            messages = torch.index_select(
                graph.edata["message"], 0, torch.LongTensor(anchor_eid).to(feature.device)
            )
            messages = messages.reshape(dists_max.shape[0], dists_max.shape[1], messages.shape[-1])

            messages = self.act(messages)  # n*m*d

            out_position = self.linear_out_position(messages).squeeze(-1)  # n*m_out
            out_structure = torch.mean(messages, dim=1)  # n*d

            return out_position, out_structure


class PGNN(nn.Module):
    def __init__(self, input_dim, feature_dim=32, dropout=0.5):
        super(PGNN, self).__init__()
@@ -44,12 +54,16 @@ class PGNN(nn.Module):
        self.conv_out = PGNN_layer(feature_dim, feature_dim)

    def forward(self, data):
        x = data["graph"].ndata["feat"]
        graph = data["graph"]
        x = self.linear_pre(x)
        x_position, x = self.conv_first(graph, x, data["anchor_eid"], data["dists_max"])
        x = self.dropout(x)
        x_position, x = self.conv_out(graph, x, data["anchor_eid"], data["dists_max"])
        x_position = F.normalize(x_position, p=2, dim=-1)
        return x_position
import torch import multiprocessing as mp
import random import random
import numpy as np from multiprocessing import get_context
import networkx as nx import networkx as nx
import numpy as np
import torch
from tqdm.auto import tqdm from tqdm.auto import tqdm
import multiprocessing as mp
from multiprocessing import get_context
def get_communities(remove_feature): def get_communities(remove_feature):
community_size = 20 community_size = 20
...@@ -45,14 +47,15 @@ def get_communities(remove_feature): ...@@ -45,14 +47,15 @@ def get_communities(remove_feature):
feature = np.identity(n)[:, rand_order] feature = np.identity(n)[:, rand_order]
data = { data = {
'edge_index': edge_index, "edge_index": edge_index,
'feature': feature, "feature": feature,
'positive_edges': np.stack(np.nonzero(label)), "positive_edges": np.stack(np.nonzero(label)),
'num_nodes': feature.shape[0] "num_nodes": feature.shape[0],
} }
return data return data
def to_single_directed(edges): def to_single_directed(edges):
edges_new = np.zeros((2, edges.shape[1] // 2), dtype=int) edges_new = np.zeros((2, edges.shape[1] // 2), dtype=int)
j = 0 j = 0
...@@ -63,6 +66,7 @@ def to_single_directed(edges): ...@@ -63,6 +66,7 @@ def to_single_directed(edges):
return edges_new return edges_new
# each node at least remain in the new graph # each node at least remain in the new graph
def split_edges(p, edges, data, non_train_ratio=0.2): def split_edges(p, edges, data, non_train_ratio=0.2):
e = edges.shape[1] e = edges.shape[1]
...@@ -70,15 +74,19 @@ def split_edges(p, edges, data, non_train_ratio=0.2): ...@@ -70,15 +74,19 @@ def split_edges(p, edges, data, non_train_ratio=0.2):
split1 = int((1 - non_train_ratio) * e) split1 = int((1 - non_train_ratio) * e)
split2 = int((1 - non_train_ratio / 2) * e) split2 = int((1 - non_train_ratio / 2) * e)
data.update({ data.update(
'{}_edges_train'.format(p): edges[:, :split1], # 80% {
'{}_edges_val'.format(p): edges[:, split1:split2], # 10% "{}_edges_train".format(p): edges[:, :split1], # 80%
'{}_edges_test'.format(p): edges[:, split2:] # 10% "{}_edges_val".format(p): edges[:, split1:split2], # 10%
}) "{}_edges_test".format(p): edges[:, split2:], # 10%
}
)
def to_bidirected(edges):
    return np.concatenate((edges, edges[::-1, :]), axis=-1)


def get_negative_edges(positive_edges, num_nodes, num_negative_edges):
    positive_edge_set = []
    positive_edges = to_bidirected(positive_edges)
@@ -86,54 +94,76 @@ def get_negative_edges(positive_edges, num_nodes, num_negative_edges):
        positive_edge_set.append(tuple(positive_edges[:, i]))
    positive_edge_set = set(positive_edge_set)
    negative_edges = np.zeros(
        (2, num_negative_edges), dtype=positive_edges.dtype
    )
    for i in range(num_negative_edges):
        while True:
            mask_temp = tuple(
                np.random.choice(num_nodes, size=(2,), replace=False)
            )
            if mask_temp not in positive_edge_set:
                negative_edges[:, i] = mask_temp
                break
    return negative_edges


def get_pos_neg_edges(data, infer_link_positive=True):
    if infer_link_positive:
        data["positive_edges"] = to_single_directed(data["edge_index"].numpy())
    split_edges("positive", data["positive_edges"], data)

    # resample edge mask link negative
    negative_edges = get_negative_edges(
        data["positive_edges"],
        data["num_nodes"],
        num_negative_edges=data["positive_edges"].shape[1],
    )
    split_edges("negative", negative_edges, data)
    return data


def shortest_path(graph, node_range, cutoff):
    dists_dict = {}
    for node in tqdm(node_range, leave=False):
        dists_dict[node] = nx.single_source_shortest_path_length(
            graph, node, cutoff
        )
    return dists_dict


def merge_dicts(dicts):
    result = {}
    for dictionary in dicts:
        result.update(dictionary)
    return result


def all_pairs_shortest_path(graph, cutoff=None, num_workers=4):
    nodes = list(graph.nodes)
    random.shuffle(nodes)
    pool = mp.Pool(processes=num_workers)
    interval_size = len(nodes) / num_workers
    results = [
        pool.apply_async(
            shortest_path,
            args=(
                graph,
                nodes[int(interval_size * i) : int(interval_size * (i + 1))],
                cutoff,
            ),
        )
        for i in range(num_workers)
    ]
    output = [p.get() for p in results]
    dists_dict = merge_dicts(output)
    pool.close()
    pool.join()
    return dists_dict


def precompute_dist_data(edge_index, num_nodes, approximate=0):
    """
    Here dist is 1/real_dist, higher actually means closer, 0 means disconnected
@@ -145,7 +175,9 @@ def precompute_dist_data(edge_index, num_nodes, approximate=0):
    n = num_nodes
    dists_array = np.zeros((n, n))
    dists_dict = all_pairs_shortest_path(
        graph, cutoff=approximate if approximate > 0 else None
    )
    node_list = graph.nodes()
    for node_i in node_list:
        shortest_dist = dists_dict[node_i]
@@ -155,24 +187,36 @@ def precompute_dist_data(edge_index, num_nodes, approximate=0):
                dists_array[node_i, node_j] = 1 / (dist + 1)
    return dists_array
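
# A minimal sketch of the encoding above, assuming networkx is available as nx
# (as elsewhere in this module): connected pairs at hop distance d are stored
# as 1 / (d + 1), while unreachable pairs keep the default value 0.
_demo_g = nx.path_graph(3)  # hypothetical toy graph: 0 - 1 - 2
_demo_d = nx.single_source_shortest_path_length(_demo_g, 0)
_demo_enc = {j: 1 / (d + 1) for j, d in _demo_d.items() if j != 0}
assert _demo_enc == {1: 1 / 2, 2: 1 / 3}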
def get_dataset(args):
    # Generate graph data
    data_info = get_communities(args.inductive)
    # Get positive and negative edges
    data = get_pos_neg_edges(
        data_info, infer_link_positive=True if args.task == "link" else False
    )
    # Pre-compute shortest path length
    if args.task == "link":
        dists_removed = precompute_dist_data(
            data["positive_edges_train"],
            data["num_nodes"],
            approximate=args.k_hop_dist,
        )
        data["dists"] = torch.from_numpy(dists_removed).float()
        data["edge_index"] = torch.from_numpy(
            to_bidirected(data["positive_edges_train"])
        ).long()
    else:
        dists = precompute_dist_data(
            data["edge_index"].numpy(),
            data["num_nodes"],
            approximate=args.k_hop_dist,
        )
        data["dists"] = torch.from_numpy(dists).float()

    return data


def get_anchors(n):
    """Get a list of NumPy arrays, each of them is an anchor node set"""
    m = int(np.log2(n))
@@ -180,9 +224,12 @@ def get_anchors(n):
    for i in range(m):
        anchor_size = int(n / np.exp2(i + 1))
        for _ in range(m):
            anchor_set_id.append(
                np.random.choice(n, size=anchor_size, replace=False)
            )
    return anchor_set_id
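
# A minimal sketch of what get_anchors produces, assuming a hypothetical node
# count of 64: m = int(log2(64)) = 6, so it returns 6 * 6 = 36 anchor sets, and
# the i-th block of 6 sets samples int(64 / 2^(i + 1)) nodes each (32 down to 1).
_demo_sets = get_anchors(64)
assert len(_demo_sets) == 36
assert _demo_sets[0].shape[0] == 32 and _demo_sets[-1].shape[0] == 1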
def get_dist_max(anchor_set_id, dist):
    # N x K, N is number of nodes, K is the number of anchor sets
    dist_max = torch.zeros((dist.shape[0], len(anchor_set_id)))
@@ -198,6 +245,7 @@ def get_dist_max(anchor_set_id, dist):
        dist_argmax[:, i] = torch.index_select(temp_id, 0, dist_argmax_temp)
    return dist_max, dist_argmax


def get_a_graph(dists_max, dists_argmax):
    src = []
    dst = []
@@ -207,7 +255,9 @@ def get_a_graph(dists_max, dists_argmax):
    dists_max = dists_max.numpy()
    for i in range(dists_max.shape[0]):
        # Get unique closest anchor nodes for node i across all anchor sets
        tmp_dists_argmax, tmp_dists_argmax_idx = np.unique(
            dists_argmax[i, :], True
        )
        src.extend([i] * tmp_dists_argmax.shape[0])
        real_src.extend([i] * dists_argmax[i, :].shape[0])
        real_dst.extend(list(dists_argmax[i, :].numpy()))
@@ -218,13 +268,14 @@ def get_a_graph(dists_max, dists_argmax):
    g = (dst, src)
    return g, anchor_eid, edge_weight


def get_graphs(data, anchor_sets):
    graphs = []
    anchor_eids = []
    dists_max_list = []
    edge_weights = []
    for anchor_set in tqdm(anchor_sets, leave=False):
        dists_max, dists_argmax = get_dist_max(anchor_set, data["dists"])
        g, anchor_eid, edge_weight = get_a_graph(dists_max, dists_argmax)
        graphs.append(g)
        anchor_eids.append(anchor_eid)
@@ -233,6 +284,7 @@ def get_graphs(data, anchor_sets):
    return graphs, anchor_eids, dists_max_list, edge_weights


def merge_result(outputs):
    graphs = []
    anchor_eids = []
@@ -247,14 +299,26 @@ def merge_result(outputs):
    return graphs, anchor_eids, dists_max_list, edge_weights


def preselect_anchor(data, args, num_workers=4):
    pool = get_context("spawn").Pool(processes=num_workers)
    # Pre-compute anchor sets, a collection of anchor sets per epoch
    anchor_set_ids = [
        get_anchors(data["num_nodes"]) for _ in range(args.epoch_num)
    ]
    interval_size = len(anchor_set_ids) / num_workers
    results = [
        pool.apply_async(
            get_graphs,
            args=(
                data,
                anchor_set_ids[
                    int(interval_size * i) : int(interval_size * (i + 1))
                ],
            ),
        )
        for i in range(num_workers)
    ]
    output = [p.get() for p in results]
    graphs, anchor_eids, dists_max_list, edge_weights = merge_result(output)
""" Credit: https://github.com/fanyun-sun/InfoGraph """
import math

import torch as th
import torch.nn.functional as F


def get_positive_expectation(p_samples, average=True):
    """Computes the positive part of a JS Divergence.
@@ -13,8 +14,8 @@ def get_positive_expectation(p_samples, average=True):
    Returns:
        th.Tensor
    """
    log_2 = math.log(2.0)
    Ep = log_2 - F.softplus(-p_samples)
    if average:
        return Ep.mean()
@@ -30,7 +31,7 @@ def get_negative_expectation(q_samples, average=True):
    Returns:
        th.Tensor
    """
    log_2 = math.log(2.0)
    Eq = F.softplus(-q_samples) + q_samples - log_2
    if average:
@@ -51,8 +52,8 @@ def local_global_loss_(l_enc, g_enc, graph_id):
    for nodeidx, graphidx in enumerate(graph_id):
        pos_mask[nodeidx][graphidx] = 1.0
        neg_mask[nodeidx][graphidx] = 0.0

    res = th.mm(l_enc, g_enc.t())
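
# A minimal numeric sketch of the two expectations above, assuming torch and
# math as imported in this module: for a discriminator score of 0 both terms
# reduce to 0, since softplus(0) = log 2, so E_p = log 2 - log 2 and
# E_q = log 2 + 0 - log 2.
_demo_scores = th.zeros(4)
_demo_ep = math.log(2.0) - F.softplus(-_demo_scores)
_demo_eq = F.softplus(-_demo_scores) + _demo_scores - math.log(2.0)
assert th.allclose(_demo_ep, th.zeros(4), atol=1e-6)
assert th.allclose(_demo_eq, th.zeros(4), atol=1e-6)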
@@ -2,39 +2,45 @@
import argparse
import copy

import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from model import JKNet
from sklearn.model_selection import train_test_split
from tqdm import trange

from dgl.data import CiteseerGraphDataset, CoraGraphDataset


def main(args):
    # Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
    # Load from DGL dataset
    if args.dataset == "Cora":
        dataset = CoraGraphDataset()
    elif args.dataset == "Citeseer":
        dataset = CiteseerGraphDataset()
    else:
        raise ValueError("Dataset {} is invalid.".format(args.dataset))
    graph = dataset[0]

    # check cuda
    device = (
        f"cuda:{args.gpu}"
        if args.gpu >= 0 and torch.cuda.is_available()
        else "cpu"
    )

    # retrieve the number of classes
    n_classes = dataset.num_classes

    # retrieve labels of ground truth
    labels = graph.ndata.pop("label").to(device).long()

    # Extract node features
    feats = graph.ndata.pop("feat").to(device)
    n_features = feats.shape[-1]

    # create masks for train / validation / test
@@ -47,12 +53,14 @@ def main(args):
    graph = graph.to(device)

    # Step 2: Create model =================================================================== #
    model = JKNet(
        in_dim=n_features,
        hid_dim=args.hid_dim,
        out_dim=n_classes,
        num_layers=args.num_layers,
        mode=args.mode,
        dropout=args.dropout,
    ).to(device)
    best_model = copy.deepcopy(model)
@@ -62,7 +70,7 @@ def main(args):
    # Step 4: training epochs =============================================================== #
    acc = 0
    epochs = trange(args.epochs, desc="Accuracy & Loss")

    for _ in epochs:
        # Training using a full graph
@@ -72,7 +80,9 @@ def main(args):
        # compute loss
        train_loss = loss_fn(logits[train_idx], labels[train_idx])
        train_acc = torch.sum(
            logits[train_idx].argmax(dim=1) == labels[train_idx]
        ).item() / len(train_idx)

        # backward
        opt.zero_grad()
@@ -84,11 +94,16 @@ def main(args):
        with torch.no_grad():
            valid_loss = loss_fn(logits[val_idx], labels[val_idx])
            valid_acc = torch.sum(
                logits[val_idx].argmax(dim=1) == labels[val_idx]
            ).item() / len(val_idx)

            # Print out performance
            epochs.set_description(
                "Train Acc {:.4f} | Train Loss {:.4f} | Val Acc {:.4f} | Val loss {:.4f}".format(
                    train_acc, train_loss.item(), valid_acc, valid_loss.item()
                )
            )

            if valid_acc > acc:
                acc = valid_acc
@@ -96,31 +111,57 @@ def main(args):
    best_model.eval()
    logits = best_model(graph, feats)
    test_acc = torch.sum(
        logits[test_idx].argmax(dim=1) == labels[test_idx]
    ).item() / len(test_idx)
    print("Test Acc {:.4f}".format(test_acc))
    return test_acc


if __name__ == "__main__":
    """
    JKNet Hyperparameters
    """
    parser = argparse.ArgumentParser(description="JKNet")

    # data source params
    parser.add_argument(
        "--dataset", type=str, default="Cora", help="Name of dataset."
    )
    # cuda params
    parser.add_argument(
        "--gpu", type=int, default=-1, help="GPU index. Default: -1, using CPU."
    )
    # training params
    parser.add_argument("--run", type=int, default=10, help="Running times.")
    parser.add_argument(
        "--epochs", type=int, default=500, help="Training epochs."
    )
    parser.add_argument(
        "--lr", type=float, default=0.005, help="Learning rate."
    )
    parser.add_argument("--lamb", type=float, default=0.0005, help="L2 reg.")
    # model params
    parser.add_argument(
        "--hid-dim", type=int, default=32, help="Hidden layer dimensionalities."
    )
    parser.add_argument(
        "--num-layers", type=int, default=5, help="Number of GCN layers."
    )
    parser.add_argument(
        "--mode",
        type=str,
        default="cat",
        help="Type of aggregation.",
        choices=["cat", "max", "lstm"],
    )
    parser.add_argument(
        "--dropout",
        type=float,
        default=0.5,
        help="Dropout applied at all layers.",
    )

    args = parser.parse_args()
    print(args)
@@ -132,6 +173,6 @@ if __name__ == "__main__":
    mean = np.around(np.mean(acc_lists, axis=0), decimals=3)
    std = np.around(np.std(acc_lists, axis=0), decimals=3)

    print("total acc: ", acc_lists)
    print("mean", mean)
    print("std", std)
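
# A hedged usage sketch for the entry point above (the script's file name is
# assumed here to be main.py, since it is not shown in this diff):
#
#     python main.py --dataset Cora --gpu 0 --mode cat --num-layers 5
#
# With the defaults this repeats training --run (10) times for --epochs (500)
# epochs each and prints the mean and standard deviation of the test accuracy.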
import torch
import torch.nn as nn
import torch.nn.functional as F

import dgl.function as fn
from dgl.nn import GraphConv, JumpingKnowledge


class JKNet(nn.Module):
    def __init__(
        self, in_dim, hid_dim, out_dim, num_layers=1, mode="cat", dropout=0.0
    ):
        super(JKNet, self).__init__()

        self.mode = mode
@@ -21,12 +19,12 @@ class JKNet(nn.Module):
        for _ in range(num_layers):
            self.layers.append(GraphConv(hid_dim, hid_dim, activation=F.relu))

        if self.mode == "lstm":
            self.jump = JumpingKnowledge(mode, hid_dim, num_layers)
        else:
            self.jump = JumpingKnowledge(mode)

        if self.mode == "cat":
            hid_dim = hid_dim * (num_layers + 1)

        self.output = nn.Linear(hid_dim, out_dim)
@@ -44,7 +42,7 @@ class JKNet(nn.Module):
            feats = self.dropout(layer(g, feats))
            feat_lst.append(feats)

        g.ndata["h"] = self.jump(feat_lst)
        g.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
        return self.output(g.ndata["h"])
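
# A minimal smoke test of the module above, with made-up toy sizes (16 input
# features, 7 classes) and a random DGL graph; self-loops are added so that
# GraphConv sees no zero-in-degree nodes.
if __name__ == "__main__":
    import dgl

    _demo_g = dgl.add_self_loop(dgl.rand_graph(10, 20))
    _demo_feats = torch.randn(10, 16)
    _demo_net = JKNet(in_dim=16, hid_dim=32, out_dim=7, num_layers=3)
    assert _demo_net(_demo_g, _demo_feats).shape == (10, 7)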
from .chemutils import decode_stereo
from .datautils import JTNNCollator, JTNNDataset
from .jtnn_vae import DGLJTNNVAE
from .mol_tree import Vocab
from .mpn import DGLMPN
from .nnutils import cuda
from collections import defaultdict

import rdkit.Chem as Chem
from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree

MST_MAX_WEIGHT = 100
MAX_NCAND = 2000


def set_atommap(mol, num=0):
    for atom in mol.GetAtoms():
        atom.SetAtomMapNum(num)


def get_mol(smiles):
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
@@ -18,26 +21,39 @@ def get_mol(smiles):
    Chem.Kekulize(mol)
    return mol


def get_smiles(mol):
    return Chem.MolToSmiles(mol, kekuleSmiles=True)
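
# A minimal round-trip sketch, assuming RDKit's usual SMILES output: get_mol
# kekulizes the molecule, so get_smiles writes benzene with explicit
# alternating double bonds rather than aromatic lowercase atoms.
_demo_smiles = get_smiles(get_mol("c1ccccc1"))
assert _demo_smiles.count("=") == 3  # kekulized benzene has three double bonds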
def decode_stereo(smiles2D):
    mol = Chem.MolFromSmiles(smiles2D)
    dec_isomers = list(EnumerateStereoisomers(mol))

    dec_isomers = [
        Chem.MolFromSmiles(Chem.MolToSmiles(mol, isomericSmiles=True))
        for mol in dec_isomers
    ]
    smiles3D = [
        Chem.MolToSmiles(mol, isomericSmiles=True) for mol in dec_isomers
    ]

    chiralN = [
        atom.GetIdx()
        for atom in dec_isomers[0].GetAtoms()
        if int(atom.GetChiralTag()) > 0 and atom.GetSymbol() == "N"
    ]
    if len(chiralN) > 0:
        for mol in dec_isomers:
            for idx in chiralN:
                mol.GetAtomWithIdx(idx).SetChiralTag(
                    Chem.rdchem.ChiralType.CHI_UNSPECIFIED
                )
            smiles3D.append(Chem.MolToSmiles(mol, isomericSmiles=True))
    return smiles3D


def sanitize(mol):
    try:
        smiles = get_smiles(mol)
@@ -46,14 +62,16 @@ def sanitize(mol):
        return None
    return mol


def copy_atom(atom):
    new_atom = Chem.Atom(atom.GetSymbol())
    new_atom.SetFormalCharge(atom.GetFormalCharge())
    new_atom.SetAtomMapNum(atom.GetAtomMapNum())
    return new_atom


def copy_edit_mol(mol):
    new_mol = Chem.RWMol(Chem.MolFromSmiles(""))
    for atom in mol.GetAtoms():
        new_atom = copy_atom(atom)
        new_mol.AddAtom(new_atom)
@@ -64,13 +82,15 @@ def copy_edit_mol(mol):
        new_mol.AddBond(a1, a2, bt)
    return new_mol


def get_clique_mol(mol, atoms):
    smiles = Chem.MolFragmentToSmiles(mol, atoms, kekuleSmiles=True)
    new_mol = Chem.MolFromSmiles(smiles, sanitize=False)
    new_mol = copy_edit_mol(new_mol).GetMol()
    new_mol = sanitize(new_mol)  # We assume this is not None
    return new_mol


def tree_decomp(mol):
    n_atoms = mol.GetNumAtoms()
    if n_atoms == 1:
@@ -81,7 +101,7 @@ def tree_decomp(mol):
        a1 = bond.GetBeginAtom().GetIdx()
        a2 = bond.GetEndAtom().GetIdx()
        if not bond.IsInRing():
            cliques.append([a1, a2])

    ssr = [list(x) for x in Chem.GetSymmSSSR(mol)]
    cliques.extend(ssr)
@@ -91,12 +111,14 @@ def tree_decomp(mol):
        for atom in cliques[i]:
            nei_list[atom].append(i)

    # Merge Rings with intersection > 2 atoms
    for i in range(len(cliques)):
        if len(cliques[i]) <= 2:
            continue
        for atom in cliques[i]:
            for j in nei_list[atom]:
                if i >= j or len(cliques[j]) <= 2:
                    continue
                inter = set(cliques[i]) & set(cliques[j])
                if len(inter) > 2:
                    cliques[i].extend(cliques[j])
@@ -109,7 +131,7 @@ def tree_decomp(mol):
        for atom in cliques[i]:
            nei_list[atom].append(i)

    # Build edges and add singleton cliques
    edges = defaultdict(int)
    for atom in range(n_atoms):
        if len(nei_list[atom]) <= 1:
@@ -122,37 +144,44 @@ def tree_decomp(mol):
            cliques.append([atom])
            c2 = len(cliques) - 1
            for c1 in cnei:
                edges[(c1, c2)] = 1
        elif len(rings) > 2:  # Multiple (n>2) complex rings
            cliques.append([atom])
            c2 = len(cliques) - 1
            for c1 in cnei:
                edges[(c1, c2)] = MST_MAX_WEIGHT - 1
        else:
            for i in range(len(cnei)):
                for j in range(i + 1, len(cnei)):
                    c1, c2 = cnei[i], cnei[j]
                    inter = set(cliques[c1]) & set(cliques[c2])
                    if edges[(c1, c2)] < len(inter):
                        edges[(c1, c2)] = len(
                            inter
                        )  # cnei[i] < cnei[j] by construction

    edges = [u + (MST_MAX_WEIGHT - v,) for u, v in edges.items()]
    if len(edges) == 0:
        return cliques, edges

    # Compute Maximum Spanning Tree
    row, col, data = list(zip(*edges))
    n_clique = len(cliques)
    clique_graph = csr_matrix((data, (row, col)), shape=(n_clique, n_clique))
    junc_tree = minimum_spanning_tree(clique_graph)
    row, col = junc_tree.nonzero()
    edges = [(row[i], col[i]) for i in range(len(row))]
    return (cliques, edges)
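
# A small hedged usage sketch of tree_decomp, assuming RDKit behaves as in the
# standard junction-tree decomposition: methylcyclohexane splits into one
# 6-atom ring clique and one 2-atom bond clique joined by a single tree edge.
_demo_cliques, _demo_edges = tree_decomp(get_mol("CC1CCCCC1"))
assert sorted(len(c) for c in _demo_cliques) == [2, 6]
assert len(_demo_edges) == 1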
def atom_equal(a1, a2):
    return (
        a1.GetSymbol() == a2.GetSymbol()
        and a1.GetFormalCharge() == a2.GetFormalCharge()
    )


# Bond type not considered because all aromatic (so SINGLE matches DOUBLE)
def ring_bond_equal(b1, b2, reverse=False):
    b1 = (b1.GetBeginAtom(), b1.GetEndAtom())
    if reverse:
@@ -161,10 +190,11 @@ def ring_bond_equal(b1, b2, reverse=False):
        b2 = (b2.GetBeginAtom(), b2.GetEndAtom())
    return atom_equal(b1[0], b2[0]) and atom_equal(b1[1], b2[1])


def attach_mols_nx(ctr_mol, neighbors, prev_nodes, nei_amap):
    prev_nids = [node["nid"] for node in prev_nodes]
    for nei_node in prev_nodes + neighbors:
        nei_id, nei_mol = nei_node["nid"], nei_node["mol"]
        amap = nei_amap[nei_id]
        for atom in nei_mol.GetAtoms():
            if atom.GetIdx() not in amap:
@@ -181,82 +211,116 @@ def attach_mols_nx(ctr_mol, neighbors, prev_nodes, nei_amap):
                a2 = amap[bond.GetEndAtom().GetIdx()]
                if ctr_mol.GetBondBetweenAtoms(a1, a2) is None:
                    ctr_mol.AddBond(a1, a2, bond.GetBondType())
                elif nei_id in prev_nids:  # father node overrides
                    ctr_mol.RemoveBond(a1, a2)
                    ctr_mol.AddBond(a1, a2, bond.GetBondType())
    return ctr_mol


def local_attach_nx(ctr_mol, neighbors, prev_nodes, amap_list):
    ctr_mol = copy_edit_mol(ctr_mol)
    nei_amap = {nei["nid"]: {} for nei in prev_nodes + neighbors}

    for nei_id, ctr_atom, nei_atom in amap_list:
        nei_amap[nei_id][nei_atom] = ctr_atom

    ctr_mol = attach_mols_nx(ctr_mol, neighbors, prev_nodes, nei_amap)
    return ctr_mol.GetMol()


# This version records idx mapping between ctr_mol and nei_mol
def enum_attach_nx(ctr_mol, nei_node, amap, singletons):
    nei_mol, nei_idx = nei_node["mol"], nei_node["nid"]
    att_confs = []
    black_list = [
        atom_idx for nei_id, atom_idx, _ in amap if nei_id in singletons
    ]
    ctr_atoms = [
        atom for atom in ctr_mol.GetAtoms() if atom.GetIdx() not in black_list
    ]
    ctr_bonds = [bond for bond in ctr_mol.GetBonds()]

    if nei_mol.GetNumBonds() == 0:  # neighbor singleton
        nei_atom = nei_mol.GetAtomWithIdx(0)
        used_list = [atom_idx for _, atom_idx, _ in amap]
        for atom in ctr_atoms:
            if atom_equal(atom, nei_atom) and atom.GetIdx() not in used_list:
                new_amap = amap + [(nei_idx, atom.GetIdx(), 0)]
                att_confs.append(new_amap)

    elif nei_mol.GetNumBonds() == 1:  # neighbor is a bond
        bond = nei_mol.GetBondWithIdx(0)
        bond_val = int(bond.GetBondTypeAsDouble())
        b1, b2 = bond.GetBeginAtom(), bond.GetEndAtom()

        for atom in ctr_atoms:
            # Optimize if atom is carbon (other atoms may change valence)
            if atom.GetAtomicNum() == 6 and atom.GetTotalNumHs() < bond_val:
                continue
            if atom_equal(atom, b1):
                new_amap = amap + [(nei_idx, atom.GetIdx(), b1.GetIdx())]
                att_confs.append(new_amap)
            elif atom_equal(atom, b2):
                new_amap = amap + [(nei_idx, atom.GetIdx(), b2.GetIdx())]
                att_confs.append(new_amap)
    else:
        # intersection is an atom
        for a1 in ctr_atoms:
            for a2 in nei_mol.GetAtoms():
                if atom_equal(a1, a2):
                    # Optimize if atom is carbon (other atoms may change valence)
                    if (
                        a1.GetAtomicNum() == 6
                        and a1.GetTotalNumHs() + a2.GetTotalNumHs() < 4
                    ):
                        continue
                    new_amap = amap + [(nei_idx, a1.GetIdx(), a2.GetIdx())]
                    att_confs.append(new_amap)

        # intersection is an bond
        if ctr_mol.GetNumBonds() > 1:
            for b1 in ctr_bonds:
                for b2 in nei_mol.GetBonds():
                    if ring_bond_equal(b1, b2):
                        new_amap = amap + [
                            (
                                nei_idx,
                                b1.GetBeginAtom().GetIdx(),
                                b2.GetBeginAtom().GetIdx(),
                            ),
                            (
                                nei_idx,
                                b1.GetEndAtom().GetIdx(),
                                b2.GetEndAtom().GetIdx(),
                            ),
                        ]
                        att_confs.append(new_amap)

                    if ring_bond_equal(b1, b2, reverse=True):
                        new_amap = amap + [
                            (
                                nei_idx,
                                b1.GetBeginAtom().GetIdx(),
                                b2.GetEndAtom().GetIdx(),
                            ),
                            (
                                nei_idx,
                                b1.GetEndAtom().GetIdx(),
                                b2.GetBeginAtom().GetIdx(),
                            ),
                        ]
                        att_confs.append(new_amap)
    return att_confs


# Try rings first: Speed-Up
def enum_assemble_nx(node, neighbors, prev_nodes=[], prev_amap=[]):
    all_attach_confs = []
    singletons = [
        nei_node["nid"]
        for nei_node in neighbors + prev_nodes
        if nei_node["mol"].GetNumAtoms() == 1
    ]

    def search(cur_amap, depth):
        if len(all_attach_confs) > MAX_NCAND:
@@ -266,11 +330,13 @@ def enum_assemble_nx(node, neighbors, prev_nodes=[], prev_amap=[]):
            return

        nei_node = neighbors[depth]
        cand_amap = enum_attach_nx(node["mol"], nei_node, cur_amap, singletons)
        cand_smiles = set()
        candidates = []
        for amap in cand_amap:
            cand_mol = local_attach_nx(
                node["mol"], neighbors[: depth + 1], prev_nodes, amap
            )
            cand_mol = sanitize(cand_mol)
            if cand_mol is None:
                continue
@@ -290,47 +356,69 @@ def enum_assemble_nx(node, neighbors, prev_nodes=[], prev_amap=[]):
    cand_smiles = set()
    candidates = []
    for amap in all_attach_confs:
        cand_mol = local_attach_nx(node["mol"], neighbors, prev_nodes, amap)
        cand_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cand_mol))
        smiles = Chem.MolToSmiles(cand_mol)
        if smiles in cand_smiles:
            continue
        cand_smiles.add(smiles)
        Chem.Kekulize(cand_mol)
        candidates.append((smiles, cand_mol, amap))

    return candidates


# Only used for debugging purpose
def dfs_assemble_nx(
    graph, cur_mol, global_amap, fa_amap, cur_node_id, fa_node_id
):
    cur_node = graph.nodes_dict[cur_node_id]
    fa_node = graph.nodes_dict[fa_node_id] if fa_node_id is not None else None

    fa_nid = fa_node["nid"] if fa_node is not None else -1
    prev_nodes = [fa_node] if fa_node is not None else []

    children_id = [
        nei
        for nei in graph[cur_node_id]
        if graph.nodes_dict[nei]["nid"] != fa_nid
    ]
    children = [graph.nodes_dict[nei] for nei in children_id]
    neighbors = [nei for nei in children if nei["mol"].GetNumAtoms() > 1]
    neighbors = sorted(
        neighbors, key=lambda x: x["mol"].GetNumAtoms(), reverse=True
    )
    singletons = [nei for nei in children if nei["mol"].GetNumAtoms() == 1]
    neighbors = singletons + neighbors

    cur_amap = [
        (fa_nid, a2, a1) for nid, a1, a2 in fa_amap if nid == cur_node["nid"]
    ]
    cands = enum_assemble_nx(
        graph.nodes_dict[cur_node_id], neighbors, prev_nodes, cur_amap
    )
    if len(cands) == 0:
        return

    cand_smiles, _, cand_amap = zip(*cands)
    label_idx = cand_smiles.index(cur_node["label"])
    label_amap = cand_amap[label_idx]

    for nei_id, ctr_atom, nei_atom in label_amap:
        if nei_id == fa_nid:
            continue
        global_amap[nei_id][nei_atom] = global_amap[cur_node["nid"]][ctr_atom]

    cur_mol = attach_mols_nx(
        cur_mol, children, [], global_amap
    )  # father is already attached
    for nei_node_id, nei_node in zip(children_id, children):
        if not nei_node["is_leaf"]:
            dfs_assemble_nx(
                graph,
                cur_mol,
                global_amap,
                label_amap,
                nei_node_id,
                cur_node_id,
            )
@@ -2,41 +2,49 @@ import torch
from torch.utils.data import Dataset

import dgl
from dgl.data.utils import (
    _get_dgl_url,
    download,
    extract_archive,
    get_download_dir,
)

from .jtmpn import ATOM_FDIM as ATOM_FDIM_DEC
from .jtmpn import BOND_FDIM as BOND_FDIM_DEC
from .jtmpn import mol2dgl_single as mol2dgl_dec
from .mol_tree import Vocab
from .mol_tree_nx import DGLMolTree
from .mpn import mol2dgl_single as mol2dgl_enc


def _unpack_field(examples, field):
    return [e[field] for e in examples]


def _set_node_id(mol_tree, vocab):
    wid = []
    for i, node in enumerate(mol_tree.nodes_dict):
        mol_tree.nodes_dict[node]["idx"] = i
        wid.append(vocab.get_index(mol_tree.nodes_dict[node]["smiles"]))

    return wid


class JTNNDataset(Dataset):
    def __init__(self, data, vocab, training=True):
        self.dir = get_download_dir()
        self.zip_file_path = "{}/jtnn.zip".format(self.dir)

        download(_get_dgl_url("dgllife/jtnn.zip"), path=self.zip_file_path)
        extract_archive(self.zip_file_path, "{}/jtnn".format(self.dir))

        print("Loading data...")
        data_file = "{}/jtnn/{}.txt".format(self.dir, data)
        with open(data_file) as f:
            self.data = [line.strip("\r\n ").split()[0] for line in f]
        self.vocab_file = "{}/jtnn/{}.txt".format(self.dir, vocab)
        print("Loading finished.")
        print("\tNum samples:", len(self.data))
        print("\tVocab file:", self.vocab_file)
        self.training = training
        self.vocab = Vocab([x.strip("\r\n ") for x in open(self.vocab_file)])
@@ -55,11 +63,11 @@ class JTNNDataset(Dataset):
        mol_graph, atom_x_enc, bond_x_enc = mol2dgl_enc(mol_tree.smiles)

        result = {
            "mol_tree": mol_tree,
            "mol_graph": mol_graph,
            "atom_x_enc": atom_x_enc,
            "bond_x_enc": bond_x_enc,
            "wid": wid,
        }

        if not self.training:
@@ -69,17 +77,24 @@ class JTNNDataset(Dataset):
        cands = []
        for node_id, node in mol_tree.nodes_dict.items():
            # fill in ground truth
            if node["label"] not in node["cands"]:
                node["cands"].append(node["label"])
                node["cand_mols"].append(node["label_mol"])

            if node["is_leaf"] or len(node["cands"]) == 1:
                continue
            cands.extend(
                [(cand, mol_tree, node_id) for cand in node["cand_mols"]]
            )

        if len(cands) > 0:
            (
                cand_graphs,
                atom_x_dec,
                bond_x_dec,
                tree_mess_src_e,
                tree_mess_tgt_e,
                tree_mess_tgt_n,
            ) = mol2dgl_dec(cands)
        else:
            cand_graphs = []
            atom_x_dec = torch.zeros(0, ATOM_FDIM_DEC)
@@ -95,8 +110,9 @@ class JTNNDataset(Dataset):
            cands.append(mol_tree.smiles3D)
            stereo_graphs = [mol2dgl_enc(c) for c in cands]
            stereo_cand_graphs, stereo_atom_x_enc, stereo_bond_x_enc = zip(
                *stereo_graphs
            )
            stereo_atom_x_enc = torch.cat(stereo_atom_x_enc)
            stereo_bond_x_enc = torch.cat(stereo_bond_x_enc)
            stereo_cand_label = [(cands.index(mol_tree.smiles3D), len(cands))]
@@ -106,21 +122,24 @@ class JTNNDataset(Dataset):
            stereo_bond_x_enc = torch.zeros(0, bond_x_enc.shape[1])
            stereo_cand_label = []

        result.update(
            {
                "cand_graphs": cand_graphs,
                "atom_x_dec": atom_x_dec,
                "bond_x_dec": bond_x_dec,
                "tree_mess_src_e": tree_mess_src_e,
                "tree_mess_tgt_e": tree_mess_tgt_e,
                "tree_mess_tgt_n": tree_mess_tgt_n,
                "stereo_cand_graphs": stereo_cand_graphs,
                "stereo_atom_x_enc": stereo_atom_x_enc,
                "stereo_bond_x_enc": stereo_bond_x_enc,
                "stereo_cand_label": stereo_cand_label,
            }
        )

        return result


class JTNNCollator(object):
    def __init__(self, vocab, training):
        self.vocab = vocab
@@ -131,43 +150,45 @@ class JTNNCollator(object):
        if flatten:
            graphs = [g for f in graphs for g in f]

        graph_batch = dgl.batch(graphs)
        graph_batch.ndata["x"] = atom_x
        graph_batch.edata.update(
            {
                "x": bond_x,
                "src_x": atom_x.new(bond_x.shape[0], atom_x.shape[1]).zero_(),
            }
        )
        return graph_batch

    def __call__(self, examples):
        # get list of trees
        mol_trees = _unpack_field(examples, "mol_tree")
        wid = _unpack_field(examples, "wid")
        for _wid, mol_tree in zip(wid, mol_trees):
            mol_tree.graph.ndata["wid"] = torch.LongTensor(_wid)

        # TODO: either support pickling or get around ctypes pointers using scipy
        # batch molecule graphs
        mol_graphs = _unpack_field(examples, "mol_graph")
        atom_x = torch.cat(_unpack_field(examples, "atom_x_enc"))
        bond_x = torch.cat(_unpack_field(examples, "bond_x_enc"))
        mol_graph_batch = self._batch_and_set(mol_graphs, atom_x, bond_x, False)

        result = {
            "mol_trees": mol_trees,
            "mol_graph_batch": mol_graph_batch,
        }

        if not self.training:
            return result

        # batch candidate graphs
        cand_graphs = _unpack_field(examples, "cand_graphs")
        cand_batch_idx = []
        atom_x = torch.cat(_unpack_field(examples, "atom_x_dec"))
        bond_x = torch.cat(_unpack_field(examples, "bond_x_dec"))
        tree_mess_src_e = _unpack_field(examples, "tree_mess_src_e")
        tree_mess_tgt_e = _unpack_field(examples, "tree_mess_tgt_e")
        tree_mess_tgt_n = _unpack_field(examples, "tree_mess_tgt_n")

        n_graph_nodes = 0
        n_tree_nodes = 0
@@ -182,12 +203,14 @@ class JTNNCollator(object):
        tree_mess_src_e = torch.cat(tree_mess_src_e)
        tree_mess_tgt_n = torch.cat(tree_mess_tgt_n)

        cand_graph_batch = self._batch_and_set(
            cand_graphs, atom_x, bond_x, True
        )

        # batch stereoisomers
        stereo_cand_graphs = _unpack_field(examples, "stereo_cand_graphs")
        atom_x = torch.cat(_unpack_field(examples, "stereo_atom_x_enc"))
        bond_x = torch.cat(_unpack_field(examples, "stereo_bond_x_enc"))
        stereo_cand_batch_idx = []
        for i in range(len(stereo_cand_graphs)):
            stereo_cand_batch_idx.extend([i] * len(stereo_cand_graphs[i]))
@@ -195,28 +218,31 @@ class JTNNCollator(object):
        if len(stereo_cand_batch_idx) > 0:
            stereo_cand_labels = [
                (label, length)
                for ex in _unpack_field(examples, "stereo_cand_label")
                for label, length in ex
            ]
            stereo_cand_labels, stereo_cand_lengths = zip(*stereo_cand_labels)
            stereo_cand_graph_batch = self._batch_and_set(
                stereo_cand_graphs, atom_x, bond_x, True
            )
        else:
            stereo_cand_labels = []
            stereo_cand_lengths = []
            stereo_cand_graph_batch = None
            stereo_cand_batch_idx = []

        result.update(
            {
                "cand_graph_batch": cand_graph_batch,
                "cand_batch_idx": cand_batch_idx,
                "tree_mess_tgt_e": tree_mess_tgt_e,
                "tree_mess_src_e": tree_mess_src_e,
                "tree_mess_tgt_n": tree_mess_tgt_n,
                "stereo_cand_graph_batch": stereo_cand_graph_batch,
                "stereo_cand_batch_idx": stereo_cand_batch_idx,
                "stereo_cand_labels": stereo_cand_labels,
                "stereo_cand_lengths": stereo_cand_lengths,
            }
        )
        return result
import numpy as np
import torch import torch
import torch.nn as nn import torch.nn as nn
from .nnutils import GRUUpdate, cuda, tocpu
from dgl import batch, bfs_edges_generator, line_graph
import dgl.function as DGLF import dgl.function as DGLF
import numpy as np from dgl import batch, bfs_edges_generator, line_graph
from .nnutils import GRUUpdate, cuda, tocpu
MAX_NB = 8 MAX_NB = 8
def level_order(forest, roots): def level_order(forest, roots):
forest = tocpu(forest) forest = tocpu(forest)
edges = bfs_edges_generator(forest, roots) edges = bfs_edges_generator(forest, roots)
...@@ -18,6 +21,7 @@ def level_order(forest, roots): ...@@ -18,6 +21,7 @@ def level_order(forest, roots):
yield from reversed(edges_back) yield from reversed(edges_back)
yield from edges yield from edges
class EncoderGatherUpdate(nn.Module): class EncoderGatherUpdate(nn.Module):
def __init__(self, hidden_size): def __init__(self, hidden_size):
nn.Module.__init__(self) nn.Module.__init__(self)
...@@ -26,10 +30,10 @@ class EncoderGatherUpdate(nn.Module): ...@@ -26,10 +30,10 @@ class EncoderGatherUpdate(nn.Module):
self.W = nn.Linear(2 * hidden_size, hidden_size) self.W = nn.Linear(2 * hidden_size, hidden_size)
def forward(self, nodes): def forward(self, nodes):
x = nodes.data['x'] x = nodes.data["x"]
m = nodes.data['m'] m = nodes.data["m"]
return { return {
'h': torch.relu(self.W(torch.cat([x, m], 1))), "h": torch.relu(self.W(torch.cat([x, m], 1))),
} }
...@@ -52,42 +56,53 @@ class DGLJTNNEncoder(nn.Module): ...@@ -52,42 +56,53 @@ class DGLJTNNEncoder(nn.Module):
mol_tree_batch = batch(mol_trees) mol_tree_batch = batch(mol_trees)
# Build line graph to prepare for belief propagation # Build line graph to prepare for belief propagation
mol_tree_batch_lg = line_graph(mol_tree_batch, backtracking=False, shared=True) mol_tree_batch_lg = line_graph(
mol_tree_batch, backtracking=False, shared=True
)
return self.run(mol_tree_batch, mol_tree_batch_lg) return self.run(mol_tree_batch, mol_tree_batch_lg)
def run(self, mol_tree_batch, mol_tree_batch_lg): def run(self, mol_tree_batch, mol_tree_batch_lg):
# Since tree roots are designated to 0. In the batched graph we can # Since tree roots are designated to 0. In the batched graph we can
# simply find the corresponding node ID by looking at node_offset # simply find the corresponding node ID by looking at node_offset
        node_offset = np.cumsum(
            np.insert(mol_tree_batch.batch_num_nodes().cpu().numpy(), 0, 0)
        )
        root_ids = node_offset[:-1]
        n_nodes = mol_tree_batch.number_of_nodes()
        n_edges = mol_tree_batch.number_of_edges()

        # Assign structure embeddings to tree nodes
        mol_tree_batch.ndata.update(
            {
                "x": self.embedding(mol_tree_batch.ndata["wid"]),
                "m": cuda(torch.zeros(n_nodes, self.hidden_size)),
                "h": cuda(torch.zeros(n_nodes, self.hidden_size)),
            }
        )

        # Initialize the intermediate variables according to Eq (4)-(8).
        # Also initialize the src_x and dst_x fields.
        # TODO: context?
        mol_tree_batch.edata.update(
            {
                "s": cuda(torch.zeros(n_edges, self.hidden_size)),
                "m": cuda(torch.zeros(n_edges, self.hidden_size)),
                "r": cuda(torch.zeros(n_edges, self.hidden_size)),
                "z": cuda(torch.zeros(n_edges, self.hidden_size)),
                "src_x": cuda(torch.zeros(n_edges, self.hidden_size)),
                "dst_x": cuda(torch.zeros(n_edges, self.hidden_size)),
                "rm": cuda(torch.zeros(n_edges, self.hidden_size)),
                "accum_rm": cuda(torch.zeros(n_edges, self.hidden_size)),
            }
        )

        # Send the source/destination node features to edges
        mol_tree_batch.apply_edges(
            func=lambda edges: {
                "src_x": edges.src["x"],
                "dst_x": edges.dst["x"],
            },
        )

        # Message passing
@@ -98,15 +113,19 @@ class DGLJTNNEncoder(nn.Module):
        mol_tree_batch_lg.ndata.update(mol_tree_batch.edata)
        for eid in level_order(mol_tree_batch, root_ids):
            eid = eid.to(mol_tree_batch_lg.device)
            mol_tree_batch_lg.pull(
                eid, DGLF.copy_u("m", "m"), DGLF.sum("m", "s")
            )
            mol_tree_batch_lg.pull(
                eid, DGLF.copy_u("rm", "rm"), DGLF.sum("rm", "accum_rm")
            )
            mol_tree_batch_lg.apply_nodes(self.enc_tree_update, v=eid)

        # Readout
        mol_tree_batch.edata.update(mol_tree_batch_lg.ndata)
        mol_tree_batch.update_all(DGLF.copy_e("m", "m"), DGLF.sum("m", "m"))
        mol_tree_batch.apply_nodes(self.enc_tree_gather_update)
        root_vecs = mol_tree_batch.nodes[root_ids].data["h"]
        return mol_tree_batch, root_vecs
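A minimal sketch of the line-graph message passing used in run(), with a made-up tree and feature size: every node of the line graph stands for one edge of the tree, backtracking=False drops the immediate reverse edge, and pull() sums the incoming messages into the pulled nodes before the per-node update is applied.

import dgl
import dgl.function as fn
import torch

g = dgl.graph((torch.tensor([0, 0, 1]), torch.tensor([1, 2, 3])))
lg = dgl.line_graph(g, backtracking=False)
lg.ndata["m"] = torch.randn(g.num_edges(), 4)
lg.ndata["s"] = torch.zeros(g.num_edges(), 4)
# Edge 1->3 (ID 2) receives the message carried by edge 0->1 (ID 0).
lg.pull(torch.tensor([2]), fn.copy_u("m", "m"), fn.sum("m", "s"))
print(lg.ndata["s"][2])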
import copy

import rdkit.Chem as Chem
import torch
import torch.nn as nn
import torch.nn.functional as F
from dgl import batch, unbatch

from .chemutils import (
    attach_mols_nx,
    copy_edit_mol,
    decode_stereo,
    enum_assemble_nx,
    set_atommap,
)
from .jtmpn import DGLJTMPN
from .jtmpn import mol2dgl_single as mol2dgl_dec
from .jtnn_dec import DGLJTNNDecoder
from .jtnn_enc import DGLJTNNEncoder
from .mpn import DGLMPN
from .mpn import mol2dgl_single as mol2dgl_enc
from .nnutils import cuda


class DGLJTNNVAE(nn.Module):
    def __init__(self, vocab, hidden_size, latent_size, depth):
        super(DGLJTNNVAE, self).__init__()
        self.vocab = vocab
@@ -29,7 +35,8 @@ class DGLJTNNVAE(nn.Module):
        self.mpn = DGLMPN(hidden_size, depth)
        self.jtnn = DGLJTNNEncoder(vocab, hidden_size, self.embedding)
        self.decoder = DGLJTNNDecoder(
            vocab, hidden_size, latent_size // 2, self.embedding
        )
        self.jtmpn = DGLJTMPN(hidden_size, depth)
        self.T_mean = nn.Linear(hidden_size, latent_size // 2)
@@ -44,24 +51,32 @@ class DGLJTNNVAE(nn.Module):
    @staticmethod
    def move_to_cuda(mol_batch):
        for i in range(len(mol_batch["mol_trees"])):
            mol_batch["mol_trees"][i].graph = cuda(
                mol_batch["mol_trees"][i].graph
            )
        mol_batch["mol_graph_batch"] = cuda(mol_batch["mol_graph_batch"])
        if "cand_graph_batch" in mol_batch:
            mol_batch["cand_graph_batch"] = cuda(mol_batch["cand_graph_batch"])
        if mol_batch.get("stereo_cand_graph_batch") is not None:
            mol_batch["stereo_cand_graph_batch"] = cuda(
                mol_batch["stereo_cand_graph_batch"]
            )

    def encode(self, mol_batch):
        mol_graphs = mol_batch["mol_graph_batch"]
        mol_vec = self.mpn(mol_graphs)
        mol_tree_batch, tree_vec = self.jtnn(
            [t.graph for t in mol_batch["mol_trees"]]
        )
        self.n_nodes_total += mol_graphs.number_of_nodes()
        self.n_edges_total += mol_graphs.number_of_edges()
        self.n_tree_nodes_total += sum(
            t.graph.number_of_nodes() for t in mol_batch["mol_trees"]
        )
        self.n_passes += 1
        return mol_tree_batch, tree_vec, mol_vec
@@ -85,31 +100,45 @@ class DGLJTNNVAE(nn.Module):
    def forward(self, mol_batch, beta=0, e1=None, e2=None):
        self.move_to_cuda(mol_batch)
        mol_trees = mol_batch["mol_trees"]
        batch_size = len(mol_trees)
        mol_tree_batch, tree_vec, mol_vec = self.encode(mol_batch)
        tree_vec, mol_vec, z_mean, z_log_var = self.sample(
            tree_vec, mol_vec, e1, e2
        )
        kl_loss = (
            -0.5
            * torch.sum(
                1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)
            )
            / batch_size
        )
        word_loss, topo_loss, word_acc, topo_acc = self.decoder(
            [t.graph for t in mol_trees], tree_vec
        )
        assm_loss, assm_acc = self.assm(mol_batch, mol_tree_batch, mol_vec)
        stereo_loss, stereo_acc = self.stereo(mol_batch, mol_vec)
        loss = (
            word_loss + topo_loss + assm_loss + 2 * stereo_loss + beta * kl_loss
        )
        return loss, kl_loss, word_acc, topo_acc, assm_acc, stereo_acc
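A quick numerical restatement of the KL term above, with made-up shapes: it is the closed-form KL divergence between the diagonal Gaussian N(z_mean, exp(z_log_var)) and a standard normal, summed over latent dimensions and averaged over the batch.

import torch

z_mean = torch.randn(8, 28)
z_log_var = torch.randn(8, 28)
batch_size = z_mean.shape[0]
kl_loss = (
    -0.5
    * torch.sum(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var))
    / batch_size
)
print(kl_loss.item() >= 0)  # KL divergence is always non-negative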
    def assm(self, mol_batch, mol_tree_batch, mol_vec):
        cands = [
            mol_batch["cand_graph_batch"],
            cuda(mol_batch["tree_mess_src_e"]),
            cuda(mol_batch["tree_mess_tgt_e"]),
            cuda(mol_batch["tree_mess_tgt_n"]),
        ]
        cand_vec = self.jtmpn(cands, mol_tree_batch)
        cand_vec = self.G_mean(cand_vec)
        batch_idx = cuda(torch.LongTensor(mol_batch["cand_batch_idx"]))
        mol_vec = mol_vec[batch_idx]
        mol_vec = mol_vec.view(-1, 1, self.latent_size // 2)
@@ -118,16 +147,19 @@ class DGLJTNNVAE(nn.Module):
        cnt, tot, acc = 0, 0, 0
        all_loss = []
        for i, mol_tree in enumerate(mol_batch["mol_trees"]):
            comp_nodes = [
                node_id
                for node_id, node in mol_tree.nodes_dict.items()
                if len(node["cands"]) > 1 and not node["is_leaf"]
            ]
            cnt += len(comp_nodes)

            # segmented accuracy and cross entropy
            for node_id in comp_nodes:
                node = mol_tree.nodes_dict[node_id]
                label = node["cands"].index(node["label"])
                ncand = len(node["cands"])
                cur_score = scores[tot : tot + ncand]
                tot += ncand
                if cur_score[label].item() >= cur_score.max().item():
@@ -135,20 +167,23 @@ class DGLJTNNVAE(nn.Module):
                label = cuda(torch.LongTensor([label]))
                all_loss.append(
                    F.cross_entropy(
                        cur_score.view(1, -1), label, size_average=False
                    )
                )
        all_loss = sum(all_loss) / len(mol_batch["mol_trees"])
        return all_loss, acc / cnt
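A minimal sketch of the per-node scoring inside the loop above, with made-up scores: each node contributes one variable-length block of candidate scores, and the cross entropy is taken over that block only (size_average=False corresponds to reduction="sum" in current PyTorch).

import torch
import torch.nn.functional as F

cur_score = torch.randn(5)  # scores for the 5 assembly candidates of one node
label = torch.tensor([2])   # index of the ground-truth candidate
loss = F.cross_entropy(cur_score.view(1, -1), label, reduction="sum")
print(loss.item())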
    def stereo(self, mol_batch, mol_vec):
        stereo_cands = mol_batch["stereo_cand_graph_batch"]
        batch_idx = mol_batch["stereo_cand_batch_idx"]
        labels = mol_batch["stereo_cand_labels"]
        lengths = mol_batch["stereo_cand_lengths"]
        if len(labels) == 0:
            # Only one stereoisomer exists; do nothing
            return cuda(torch.tensor(0.0)), 1.0
        batch_idx = cuda(torch.LongTensor(batch_idx))
        stereo_cands = self.mpn(stereo_cands)
@@ -159,12 +194,15 @@ class DGLJTNNVAE(nn.Module):
        st, acc = 0, 0
        all_loss = []
        for label, le in zip(labels, lengths):
            cur_scores = scores[st : st + le]
            if cur_scores.data[label].item() >= cur_scores.max().item():
                acc += 1
            label = cuda(torch.LongTensor([label]))
            all_loss.append(
                F.cross_entropy(
                    cur_scores.view(1, -1), label, size_average=False
                )
            )
            st += le
        all_loss = sum(all_loss) / len(labels)
@@ -175,24 +213,32 @@ class DGLJTNNVAE(nn.Module):
        effective_nodes_list = effective_nodes.tolist()
        nodes_dict = [nodes_dict[v] for v in effective_nodes_list]
        for i, (node_id, node) in enumerate(
            zip(effective_nodes_list, nodes_dict)
        ):
            node["idx"] = i
            node["nid"] = i + 1
            node["is_leaf"] = True
            if mol_tree.graph.in_degrees(node_id) > 1:
                node["is_leaf"] = False
                set_atommap(node["mol"], node["nid"])

        mol_tree_sg = mol_tree.graph.subgraph(
            effective_nodes.to(tree_vec.device)
        )
        mol_tree_msg, _ = self.jtnn([mol_tree_sg])
        mol_tree_msg = unbatch(mol_tree_msg)[0]
        mol_tree_msg.nodes_dict = nodes_dict
        cur_mol = copy_edit_mol(nodes_dict[0]["mol"])
        global_amap = [{}] + [{} for node in nodes_dict]
        global_amap[1] = {
            atom.GetIdx(): atom.GetIdx() for atom in cur_mol.GetAtoms()
        }
        cur_mol = self.dfs_assemble(
            mol_tree_msg, mol_vec, cur_mol, global_amap, [], 0, None
        )
        if cur_mol is None:
            return None
@@ -207,56 +253,86 @@ class DGLJTNNVAE(nn.Module):
        if len(stereo_cands) == 1:
            return stereo_cands[0]
        stereo_graphs = [mol2dgl_enc(c) for c in stereo_cands]
        stereo_cand_graphs, atom_x, bond_x = zip(*stereo_graphs)
        stereo_cand_graphs = cuda(batch(stereo_cand_graphs))
        atom_x = cuda(torch.cat(atom_x))
        bond_x = cuda(torch.cat(bond_x))
        stereo_cand_graphs.ndata["x"] = atom_x
        stereo_cand_graphs.edata["x"] = bond_x
        stereo_cand_graphs.edata["src_x"] = atom_x.new(
            bond_x.shape[0], atom_x.shape[1]
        ).zero_()
        stereo_vecs = self.mpn(stereo_cand_graphs)
        stereo_vecs = self.G_mean(stereo_vecs)
        scores = F.cosine_similarity(stereo_vecs, mol_vec)
        _, max_id = scores.max(0)
        return stereo_cands[max_id.item()]
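A minimal sketch of the stereoisomer ranking above, with made-up tensors: cosine similarity scores every candidate vector against the (broadcast) molecule latent vector, and the highest-scoring candidate is returned.

import torch
import torch.nn.functional as F

stereo_vecs = torch.randn(4, 28)  # one row per stereo candidate
mol_vec = torch.randn(1, 28)      # broadcast against every candidate row
scores = F.cosine_similarity(stereo_vecs, mol_vec)  # shape (4,)
_, max_id = scores.max(0)
print(max_id.item())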
    def dfs_assemble(
        self,
        mol_tree_msg,
        mol_vec,
        cur_mol,
        global_amap,
        fa_amap,
        cur_node_id,
        fa_node_id,
    ):
        nodes_dict = mol_tree_msg.nodes_dict
        fa_node = nodes_dict[fa_node_id] if fa_node_id is not None else None
        cur_node = nodes_dict[cur_node_id]
        fa_nid = fa_node["nid"] if fa_node is not None else -1
        prev_nodes = [fa_node] if fa_node is not None else []
        children_node_id = [
            v
            for v in mol_tree_msg.successors(cur_node_id).tolist()
            if nodes_dict[v]["nid"] != fa_nid
        ]
        children = [nodes_dict[v] for v in children_node_id]
        neighbors = [nei for nei in children if nei["mol"].GetNumAtoms() > 1]
        neighbors = sorted(
            neighbors, key=lambda x: x["mol"].GetNumAtoms(), reverse=True
        )
        singletons = [nei for nei in children if nei["mol"].GetNumAtoms() == 1]
        neighbors = singletons + neighbors
        cur_amap = [
            (fa_nid, a2, a1)
            for nid, a1, a2 in fa_amap
            if nid == cur_node["nid"]
        ]
        cands = enum_assemble_nx(cur_node, neighbors, prev_nodes, cur_amap)
        if len(cands) == 0:
            return None
        cand_smiles, cand_mols, cand_amap = list(zip(*cands))
        cands = [(candmol, mol_tree_msg, cur_node_id) for candmol in cand_mols]
        (
            cand_graphs,
            atom_x,
            bond_x,
            tree_mess_src_edges,
            tree_mess_tgt_edges,
            tree_mess_tgt_nodes,
        ) = mol2dgl_dec(cands)
        cand_graphs = batch([g.to(mol_vec.device) for g in cand_graphs])
        atom_x = cuda(atom_x)
        bond_x = cuda(bond_x)
        cand_graphs.ndata["x"] = atom_x
        cand_graphs.edata["x"] = bond_x
        cand_graphs.edata["src_x"] = atom_x.new(
            bond_x.shape[0], atom_x.shape[1]
        ).zero_()
        cand_vecs = self.jtmpn(
            (
                cand_graphs,
                tree_mess_src_edges,
                tree_mess_tgt_edges,
                tree_mess_tgt_nodes,
            ),
            mol_tree_msg,
        )
        cand_vecs = self.G_mean(cand_vecs)
@@ -274,7 +350,9 @@ class DGLJTNNVAE(nn.Module):
        for nei_id, ctr_atom, nei_atom in pred_amap:
            if nei_id == fa_nid:
                continue
            new_global_amap[nei_id][nei_atom] = new_global_amap[
                cur_node["nid"]
            ][ctr_atom]
        cur_mol = attach_mols_nx(cur_mol, children, [], new_global_amap)
        new_mol = cur_mol.GetMol()
@@ -285,11 +363,17 @@ class DGLJTNNVAE(nn.Module):
        result = True
        for nei_node_id, nei_node in zip(children_node_id, children):
            if nei_node["is_leaf"]:
                continue
            cur_mol = self.dfs_assemble(
                mol_tree_msg,
                mol_vec,
                cur_mol,
                new_global_amap,
                pred_amap,
                nei_node_id,
                cur_node_id,
            )
            if cur_mol is None:
                result = False
                break
...
"""
line_profiler integration
"""
import os

if os.getenv("PROFILE", 0):
    import atexit
    import line_profiler

    profile = line_profiler.LineProfiler()
    profile_output = os.getenv("PROFILE_OUTPUT", None)
    if profile_output:
        from functools import partial

        atexit.register(partial(profile.dump_stats, profile_output))
    else:
        atexit.register(profile.print_stats)
else:

    def profile(f):
        return f
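A hedged usage sketch for the shim above (hot_function is hypothetical, and the snippet relies on the profile name defined just above): run with PROFILE=1, and optionally PROFILE_OUTPUT=<path>, to collect per-line timings; with PROFILE unset, @profile is a plain passthrough.

@profile
def hot_function(xs):
    return sum(x * x for x in xs)

print(hot_function(range(1000)))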
import copy

import rdkit.Chem as Chem


def get_slots(smiles):
    mol = Chem.MolFromSmiles(smiles)
    return [
        (atom.GetSymbol(), atom.GetFormalCharge(), atom.GetTotalNumHs())
        for atom in mol.GetAtoms()
    ]
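A quick check of what get_slots returns, using benzene as a made-up input: one (symbol, formal charge, total hydrogen count) tuple per atom of the vocabulary SMILES.

print(get_slots("c1ccccc1"))
# [('C', 0, 1), ('C', 0, 1), ('C', 0, 1), ('C', 0, 1), ('C', 0, 1), ('C', 0, 1)]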
class Vocab(object):
    def __init__(self, smiles_list):
        self.vocab = smiles_list
        self.vmap = {x: i for i, x in enumerate(self.vocab)}
        self.slots = [get_slots(smiles) for smiles in self.vocab]

    def get_index(self, smiles):
...