Unverified Commit 0b9df9d7 authored by Hongzhi (Steve), Chen, committed by GitHub

[Misc] Black auto fix. (#4652)


Co-authored-by: Steve <ubuntu@ip-172-31-34-29.ap-northeast-1.compute.internal>
parent f19f05ce
import os
from time import time
import torch
import torch.optim as optim
from model import NGCF
from utility.batch_test import *
from utility.helper import early_stopping
def main(args):
# Step 1: Prepare graph data and device ================================================================= #
if args.gpu >= 0 and torch.cuda.is_available():
device = "cuda:{}".format(args.gpu)
else:
device = "cpu"
g = data_generator.g
g = g.to(device)
# Step 2: Create model and training components=========================================================== #
model = NGCF(
g, args.embed_size, args.layer_size, args.mess_dropout, args.regs[0]
).to(device)
optimizer = optim.Adam(model.parameters(), lr=args.lr)
# Step 3: training epochs ============================================================================== #
......@@ -27,16 +31,16 @@ def main(args):
loss_loger, pre_loger, rec_loger, ndcg_loger, hit_loger = [], [], [], [], []
for epoch in range(args.epoch):
t1 = time()
loss, mf_loss, emb_loss = 0.0, 0.0, 0.0
for idx in range(n_batch):
users, pos_items, neg_items = data_generator.sample()
u_g_embeddings, pos_i_g_embeddings, neg_i_g_embeddings = model(
g, "user", "item", users, pos_items, neg_items
)
batch_loss, batch_mf_loss, batch_emb_loss = model.create_bpr_loss(
u_g_embeddings, pos_i_g_embeddings, neg_i_g_embeddings
)
optimizer.zero_grad()
batch_loss.backward()
optimizer.step()
......@@ -45,44 +49,71 @@ def main(args):
mf_loss += batch_mf_loss
emb_loss += batch_emb_loss
if (epoch + 1) % 10 != 0:
if args.verbose > 0 and epoch % args.verbose == 0:
perf_str = "Epoch %d [%.1fs]: train==[%.5f=%.5f + %.5f]" % (
epoch,
time() - t1,
loss,
mf_loss,
emb_loss,
)
print(perf_str)
continue  # end the current epoch and move to the next epoch, let the following evaluation run every 10 epochs
# evaluate the model every 10 epochs
t2 = time()
users_to_test = list(data_generator.test_set.keys())
ret = test(model, g, users_to_test)
t3 = time()
loss_loger.append(loss)
rec_loger.append(ret["recall"])
pre_loger.append(ret["precision"])
ndcg_loger.append(ret["ndcg"])
hit_loger.append(ret["hit_ratio"])
if args.verbose > 0:
perf_str = (
"Epoch %d [%.1fs + %.1fs]: train==[%.5f=%.5f + %.5f], recall=[%.5f, %.5f], "
"precision=[%.5f, %.5f], hit=[%.5f, %.5f], ndcg=[%.5f, %.5f]"
% (
epoch,
t2 - t1,
t3 - t2,
loss,
mf_loss,
emb_loss,
ret["recall"][0],
ret["recall"][-1],
ret["precision"][0],
ret["precision"][-1],
ret["hit_ratio"][0],
ret["hit_ratio"][-1],
ret["ndcg"][0],
ret["ndcg"][-1],
)
)
print(perf_str)
cur_best_pre_0, stopping_step, should_stop = early_stopping(
ret["recall"][0],
cur_best_pre_0,
stopping_step,
expected_order="acc",
flag_step=5,
)
# early stop
if should_stop == True:
break
if ret["recall"][0] == cur_best_pre_0 and args.save_flag == 1:
torch.save(model.state_dict(), args.weights_path + args.model_name)
print(
"save the weights in path: ",
args.weights_path + args.model_name,
)
recs = np.array(rec_loger)
pres = np.array(pre_loger)
......@@ -92,14 +123,21 @@ def main(args):
best_rec_0 = max(recs[:, 0])
idx = list(recs[:, 0]).index(best_rec_0)
final_perf = (
"Best Iter=[%d]@[%.1f]\trecall=[%s], precision=[%s], hit=[%s], ndcg=[%s]"
% (
idx,
time() - t0,
"\t".join(["%.5f" % r for r in recs[idx]]),
"\t".join(["%.5f" % r for r in pres[idx]]),
"\t".join(["%.5f" % r for r in hit[idx]]),
"\t".join(["%.5f" % r for r in ndcgs[idx]]),
)
)
print(final_perf)
if __name__ == "__main__":
if not os.path.exists(args.weights_path):
os.mkdir(args.weights_path)
args.mess_dropout = eval(args.mess_dropout)
......@@ -107,4 +145,3 @@ if __name__ == '__main__':
args.regs = eval(args.regs)
print(args)
main(args)
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
class NGCFLayer(nn.Module):
def __init__(self, in_size, out_size, norm_dict, dropout):
super(NGCFLayer, self).__init__()
self.in_size = in_size
self.out_size = out_size
# weights for different types of messages
self.W1 = nn.Linear(in_size, out_size, bias=True)
self.W2 = nn.Linear(in_size, out_size, bias=True)
# leaky relu
self.leaky_relu = nn.LeakyReLU(0.2)
# dropout layer
self.dropout = nn.Dropout(dropout)
# initialization
torch.nn.init.xavier_uniform_(self.W1.weight)
torch.nn.init.constant_(self.W1.bias, 0)
torch.nn.init.xavier_uniform_(self.W2.weight)
torch.nn.init.constant_(self.W2.bias, 0)
# norm
self.norm_dict = norm_dict
def forward(self, g, feat_dict):
funcs = {} # message and reduce functions dict
# for each type of edges, compute messages and reduce them all
for srctype, etype, dsttype in g.canonical_etypes:
if srctype == dsttype: # for self loops
messages = self.W1(feat_dict[srctype])
g.nodes[srctype].data[etype] = messages # store in ndata
funcs[(srctype, etype, dsttype)] = (
fn.copy_u(etype, "m"),
fn.sum("m", "h"),
) # define message and reduce functions
else:
src, dst = g.edges(etype=(srctype, etype, dsttype))
norm = self.norm_dict[(srctype, etype, dsttype)]
messages = norm * (
self.W1(feat_dict[srctype][src])
+ self.W2(feat_dict[srctype][src] * feat_dict[dsttype][dst])
) # compute messages
g.edges[(srctype, etype, dsttype)].data[
etype
] = messages # store in edata
funcs[(srctype, etype, dsttype)] = (
fn.copy_e(etype, "m"),
fn.sum("m", "h"),
) # define message and reduce functions
g.multi_update_all(
funcs, "sum"
) # update all, reduce by first type-wisely then across different types
feature_dict = {}
for ntype in g.ntypes:
h = self.leaky_relu(g.nodes[ntype].data["h"]) # leaky relu
h = self.dropout(h) # dropout
h = F.normalize(h, dim=1, p=2) # l2 normalize
feature_dict[ntype] = h
return feature_dict
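# Hedged standalone sketch (not part of this commit): driving the NGCFLayer defined above
# on a toy user-item heterograph. Node counts, edges and feature sizes are made-up values;
# the norm computation mirrors what NGCF.__init__ does below.
import dgl
import torch

toy_g = dgl.heterograph(
    {
        ("user", "user_self", "user"): ([0, 1], [0, 1]),
        ("item", "item_self", "item"): ([0, 1, 2], [0, 1, 2]),
        ("user", "ui", "item"): ([0, 0, 1], [0, 1, 2]),
        ("item", "iu", "user"): ([0, 1, 2], [0, 0, 1]),
    }
)
toy_norms = {}
for srctype, etype, dsttype in toy_g.canonical_etypes:
    src, dst = toy_g.edges(etype=(srctype, etype, dsttype))
    dst_degree = toy_g.in_degrees(dst, etype=(srctype, etype, dsttype)).float()
    src_degree = toy_g.out_degrees(src, etype=(srctype, etype, dsttype)).float()
    toy_norms[(srctype, etype, dsttype)] = torch.pow(
        src_degree * dst_degree, -0.5
    ).unsqueeze(1)
layer = NGCFLayer(in_size=8, out_size=8, norm_dict=toy_norms, dropout=0.1)
out = layer(toy_g, {"user": torch.randn(2, 8), "item": torch.randn(3, 8)})
# out maps each node type to its L2-normalized 8-dim representation after one propagation step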
class NGCF(nn.Module):
def __init__(self, g, in_size, layer_size, dropout, lmbd=1e-5):
super(NGCF, self).__init__()
......@@ -60,9 +76,15 @@ class NGCF(nn.Module):
self.norm_dict = dict()
for srctype, etype, dsttype in g.canonical_etypes:
src, dst = g.edges(etype=(srctype, etype, dsttype))
dst_degree = g.in_degrees(
dst, etype=(srctype, etype, dsttype)
).float() # obtain degrees
src_degree = g.out_degrees(
src, etype=(srctype, etype, dsttype)
).float()
norm = torch.pow(src_degree * dst_degree, -0.5).unsqueeze(
1
) # compute norm
self.norm_dict[(srctype, etype, dsttype)] = norm
self.layers = nn.ModuleList()
......@@ -70,16 +92,26 @@ class NGCF(nn.Module):
NGCFLayer(in_size, layer_size[0], self.norm_dict, dropout[0])
)
self.num_layers = len(layer_size)
for i in range(self.num_layers - 1):
self.layers.append(
NGCFLayer(
layer_size[i],
layer_size[i + 1],
self.norm_dict,
dropout[i + 1],
)
)
self.initializer = nn.init.xavier_uniform_
# embeddings for different types of nodes
self.feature_dict = nn.ParameterDict(
{
ntype: nn.Parameter(
self.initializer(torch.empty(g.num_nodes(ntype), in_size))
)
for ntype in g.ntypes
}
)
def create_bpr_loss(self, users, pos_items, neg_items):
pos_scores = (users * pos_items).sum(1)
......@@ -88,7 +120,11 @@ class NGCF(nn.Module):
mf_loss = nn.LogSigmoid()(pos_scores - neg_scores).mean()
mf_loss = -1 * mf_loss
regularizer = (
torch.norm(users) ** 2
+ torch.norm(pos_items) ** 2
+ torch.norm(neg_items) ** 2
) / 2
emb_loss = self.lmbd * regularizer / users.shape[0]
return mf_loss + emb_loss, mf_loss, emb_loss
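# Hedged standalone sketch (not part of this commit): the BPR loss computed above, written
# out on toy embeddings so the two terms are easy to inspect; shapes and lmbd = 1e-5 are
# assumptions for the example only.
import torch

users = torch.randn(4, 8)
pos_items = torch.randn(4, 8)
neg_items = torch.randn(4, 8)
pos_scores = (users * pos_items).sum(1)
neg_scores = (users * neg_items).sum(1)
mf_loss = -torch.nn.LogSigmoid()(pos_scores - neg_scores).mean()  # pairwise ranking term
regularizer = (users.norm() ** 2 + pos_items.norm() ** 2 + neg_items.norm() ** 2) / 2
emb_loss = 1e-5 * regularizer / users.shape[0]  # L2 penalty on the batch embeddings
loss = mf_loss + emb_loss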
......@@ -96,9 +132,9 @@ class NGCF(nn.Module):
def rating(self, u_g_embeddings, pos_i_g_embeddings):
return torch.matmul(u_g_embeddings, pos_i_g_embeddings.t())
def forward(self, g, user_key, item_key, users, pos_items, neg_items):
h_dict = {ntype: self.feature_dict[ntype] for ntype in g.ntypes}
# obtain features of each layer and concatenate them all
user_embeds = []
item_embeds = []
user_embeds.append(h_dict[user_key])
......
......@@ -2,22 +2,26 @@
# <https://github.com/xiangwang1223/neural_graph_collaborative_filtering/blob/master/NGCF/utility/batch_test.py>.
# It implements the batch test.
import heapq
import multiprocessing
import utility.metrics as metrics
from utility.parser import parse_args
from utility.load_data import *
cores = multiprocessing.cpu_count()
args = parse_args()
Ks = eval(args.Ks)
data_generator = Data(
path=args.data_path + args.dataset, batch_size=args.batch_size
)
USR_NUM, ITEM_NUM = data_generator.n_users, data_generator.n_items
N_TRAIN, N_TEST = data_generator.n_train, data_generator.n_test
BATCH_SIZE = args.batch_size
def ranklist_by_heapq(user_pos_test, test_items, rating, Ks):
item_score = {}
for i in test_items:
......@@ -32,9 +36,10 @@ def ranklist_by_heapq(user_pos_test, test_items, rating, Ks):
r.append(1)
else:
r.append(0)
auc = 0.0
return r, auc
def get_auc(item_score, user_pos_test):
item_score = sorted(item_score.items(), key=lambda kv: kv[1])
item_score.reverse()
......@@ -50,6 +55,7 @@ def get_auc(item_score, user_pos_test):
auc = metrics.auc(ground_truth=r, prediction=posterior)
return auc
def ranklist_by_sorted(user_pos_test, test_items, rating, Ks):
item_score = {}
for i in test_items:
......@@ -67,6 +73,7 @@ def ranklist_by_sorted(user_pos_test, test_items, rating, Ks):
auc = get_auc(item_score, user_pos_test)
return r, auc
def get_performance(user_pos_test, r, auc, Ks):
precision, recall, ndcg, hit_ratio = [], [], [], []
......@@ -76,28 +83,33 @@ def get_performance(user_pos_test, r, auc, Ks):
ndcg.append(metrics.ndcg_at_k(r, K))
hit_ratio.append(metrics.hit_at_k(r, K))
return {
"recall": np.array(recall),
"precision": np.array(precision),
"ndcg": np.array(ndcg),
"hit_ratio": np.array(hit_ratio),
"auc": auc,
}
def test_one_user(x):
# user u's ratings for user u
rating = x[0]
# uid
u = x[1]
# user u's items in the training set
try:
training_items = data_generator.train_items[u]
except Exception:
training_items = []
# user u's items in the test set
user_pos_test = data_generator.test_set[u]
all_items = set(range(ITEM_NUM))
test_items = list(all_items - set(training_items))
if args.test_flag == "part":
r, auc = ranklist_by_heapq(user_pos_test, test_items, rating, Ks)
else:
r, auc = ranklist_by_sorted(user_pos_test, test_items, rating, Ks)
......@@ -106,8 +118,13 @@ def test_one_user(x):
def test(model, g, users_to_test, batch_test_flag=False):
result = {
"precision": np.zeros(len(Ks)),
"recall": np.zeros(len(Ks)),
"ndcg": np.zeros(len(Ks)),
"hit_ratio": np.zeros(len(Ks)),
"auc": 0.0,
}
pool = multiprocessing.Pool(cores)
......@@ -124,7 +141,7 @@ def test(model, g, users_to_test, batch_test_flag=False):
start = u_batch_id * u_batch_size
end = (u_batch_id + 1) * u_batch_size
user_batch = test_users[start:end]
if batch_test_flag:
# batch-item test
......@@ -138,10 +155,16 @@ def test(model, g, users_to_test, batch_test_flag=False):
item_batch = range(i_start, i_end)
u_g_embeddings, pos_i_g_embeddings, _ = model(
g, "user", "item", user_batch, item_batch, []
)
i_rate_batch = (
model.rating(u_g_embeddings, pos_i_g_embeddings)
.detach()
.cpu()
)
rate_batch[:, i_start:i_end] = i_rate_batch
i_count += i_rate_batch.shape[1]
assert i_count == ITEM_NUM
......@@ -149,20 +172,23 @@ def test(model, g, users_to_test, batch_test_flag=False):
else:
# all-item test
item_batch = range(ITEM_NUM)
u_g_embeddings, pos_i_g_embeddings, _ = model(
g, "user", "item", user_batch, item_batch, []
)
rate_batch = (
model.rating(u_g_embeddings, pos_i_g_embeddings).detach().cpu()
)
user_batch_rating_uid = zip(rate_batch.numpy(), user_batch)
batch_result = pool.map(test_one_user, user_batch_rating_uid)
count += len(batch_result)
for re in batch_result:
result["precision"] += re["precision"] / n_test_users
result["recall"] += re["recall"] / n_test_users
result["ndcg"] += re["ndcg"] / n_test_users
result["hit_ratio"] += re["hit_ratio"] / n_test_users
result["auc"] += re["auc"] / n_test_users
assert count == n_test_users
pool.close()
......
# This file is copied from the NGCF author's implementation
# <https://github.com/xiangwang1223/neural_graph_collaborative_filtering/blob/master/NGCF/utility/helper.py>.
# It implements the helper functions.
"""
Created on Aug 19, 2016
@author: Xiang Wang (xiangwang@u.nus.edu)
"""
__author__ = "xiangwang"
import os
import re
def txt2list(file_src):
orig_file = open(file_src, "r")
lines = orig_file.readlines()
return lines
def ensureDir(dir_path):
d = os.path.dirname(dir_path)
if not os.path.exists(d):
os.makedirs(d)
def uni2str(unicode_str):
return str(unicode_str.encode("ascii", "ignore")).replace("\n", "").strip()
def hasNumbers(inputString):
return bool(re.search(r"\d", inputString))
def delMultiChar(inputString, chars):
for ch in chars:
inputString = inputString.replace(ch, "")
return inputString
def merge_two_dicts(x, y):
z = x.copy() # start with x's keys and values
z.update(y) # modifies z with y's keys and values & returns None
return z
def early_stopping(
log_value, best_value, stopping_step, expected_order="acc", flag_step=100
):
# early stopping strategy:
assert expected_order in ["acc", "dec"]
if (expected_order == "acc" and log_value >= best_value) or (
expected_order == "dec" and log_value <= best_value
):
stopping_step = 0
best_value = log_value
else:
stopping_step += 1
if stopping_step >= flag_step:
print(
"Early stopping is trigger at step: {} log:{}".format(
flag_step, log_value
)
)
should_stop = True
else:
should_stop = False
......
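# Hedged usage sketch (not part of this commit): how the early_stopping helper above is
# typically driven from a training loop; the recall values below are made up.
best, step = 0.0, 0
for recall_at_20 in [0.10, 0.12, 0.11, 0.11, 0.10, 0.09, 0.09]:
    best, step, stop = early_stopping(
        recall_at_20, best, step, expected_order="acc", flag_step=5
    )
    if stop:  # True once flag_step evaluations pass without improving on best
        break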
# This file is based on the NGCF author's implementation
# <https://github.com/xiangwang1223/neural_graph_collaborative_filtering/blob/master/NGCF/utility/load_data.py>.
# It implements the data processing and graph construction.
import numpy as np
import random as rd
import dgl
class Data(object):
def __init__(self, path, batch_size):
self.path = path
self.batch_size = batch_size
train_file = path + "/train.txt"
test_file = path + "/test.txt"
# get number of users and items
self.n_users, self.n_items = 0, 0
self.n_train, self.n_test = 0, 0
self.exist_users = []
......@@ -24,7 +27,7 @@ class Data(object):
with open(train_file) as f:
for l in f.readlines():
if len(l) > 0:
l = l.strip("\n").split(" ")
items = [int(i) for i in l[1:]]
uid = int(l[0])
self.exist_users.append(uid)
......@@ -38,9 +41,9 @@ class Data(object):
with open(test_file) as f:
for l in f.readlines():
if len(l) > 0:
l = l.strip("\n")
try:
items = [int(i) for i in l.split(" ")[1:]]
except Exception:
continue
self.n_items = max(self.n_items, max(items))
......@@ -50,51 +53,51 @@ class Data(object):
self.print_statistics()
# training positive items corresponding to each user; testing positive items corresponding to each user
self.train_items, self.test_set = {}, {}
with open(train_file) as f_train:
with open(test_file) as f_test:
for l in f_train.readlines():
if len(l) == 0:
break
l = l.strip("\n")
items = [int(i) for i in l.split(" ")]
uid, train_items = items[0], items[1:]
self.train_items[uid] = train_items
for l in f_test.readlines():
if len(l) == 0:
break
l = l.strip("\n")
try:
items = [int(i) for i in l.split(" ")]
except Exception:
continue
uid, test_items = items[0], items[1:]
self.test_set[uid] = test_items
# construct graph from the train data and add self-loops
user_selfs = [i for i in range(self.n_users)]
item_selfs = [i for i in range(self.n_items)]
data_dict = {
("user", "user_self", "user"): (user_selfs, user_selfs),
("item", "item_self", "item"): (item_selfs, item_selfs),
("user", "ui", "item"): (user_item_src, user_item_dst),
("item", "iu", "user"): (user_item_dst, user_item_src),
}
num_dict = {"user": self.n_users, "item": self.n_items}
self.g = dgl.heterograph(data_dict, num_nodes_dict=num_dict)
def sample(self):
if self.batch_size <= self.n_users:
users = rd.sample(self.exist_users, self.batch_size)
else:
users = [
rd.choice(self.exist_users) for _ in range(self.batch_size)
]
def sample_pos_items_for_u(u, num):
# sample num pos items for u-th user
......@@ -117,12 +120,14 @@ class Data(object):
while True:
if len(neg_items) == num:
break
neg_id = np.random.randint(low=0, high=self.n_items, size=1)[0]
if (
neg_id not in self.train_items[u]
and neg_id not in neg_items
):
neg_items.append(neg_id)
return neg_items
pos_items, neg_items = [], []
for u in users:
pos_items += sample_pos_items_for_u(u, 1)
......@@ -134,10 +139,13 @@ class Data(object):
return self.n_users, self.n_items
def print_statistics(self):
print("n_users=%d, n_items=%d" % (self.n_users, self.n_items))
print("n_interactions=%d" % (self.n_train + self.n_test))
print(
"n_train=%d, n_test=%d, sparsity=%.5f"
% (
self.n_train,
self.n_test,
(self.n_train + self.n_test) / (self.n_users * self.n_items),
)
)
# This file is copied from the NGCF author's implementation
# <https://github.com/xiangwang1223/neural_graph_collaborative_filtering/blob/master/NGCF/utility/metrics.py>.
# It implements the metrics.
"""
Created on Oct 10, 2018
Tensorflow Implementation of Neural Graph Collaborative Filtering (NGCF) model in:
Wang Xiang et al. Neural Graph Collaborative Filtering. In SIGIR 2019.
@author: Xiang Wang (xiangwang@u.nus.edu)
"""
import numpy as np
from sklearn.metrics import roc_auc_score
def recall(rank, ground_truth, N):
return len(set(rank[:N]) & set(ground_truth)) / float(
len(set(ground_truth))
)
def precision_at_k(r, k):
......@@ -28,7 +31,7 @@ def precision_at_k(r, k):
return np.mean(r)
def average_precision(r, cut):
"""Score is average precision (area under PR curve)
Relevance is binary (nonzero is relevant).
Returns:
......@@ -37,8 +40,8 @@ def average_precision(r,cut):
r = np.asarray(r)
out = [precision_at_k(r, k + 1) for k in range(cut) if r[k]]
if not out:
return 0.0
return np.sum(out) / float(min(cut, np.sum(r)))
def mean_average_precision(rs):
......@@ -64,8 +67,8 @@ def dcg_at_k(r, k, method=1):
elif method == 1:
return np.sum(r / np.log2(np.arange(2, r.size + 2)))
else:
raise ValueError("method must be 0 or 1.")
return 0.0
def ndcg_at_k(r, k, method=1):
......@@ -77,7 +80,7 @@ def ndcg_at_k(r, k, method=1):
"""
dcg_max = dcg_at_k(sorted(r, reverse=True), k, method)
if not dcg_max:
return 0.0
return dcg_at_k(r, k, method) / dcg_max
......@@ -89,19 +92,21 @@ def recall_at_k(r, k, all_pos_num):
def hit_at_k(r, k):
r = np.array(r)[:k]
if np.sum(r) > 0:
return 1.0
else:
return 0.0
def F1(pre, rec):
if pre + rec > 0:
return (2.0 * pre * rec) / (pre + rec)
else:
return 0.0
def auc(ground_truth, prediction):
try:
res = roc_auc_score(y_true=ground_truth, y_score=prediction)
except Exception:
res = 0.0
return res
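# Hedged worked example (not part of this commit): the metrics above on a toy binary
# relevance list r, where r[i] == 1 means the item ranked (i+1)-th is a hit.
r = [1, 0, 1, 0, 0]
print(hit_at_k(r, 2))  # 1.0: at least one hit in the top-2
print(hit_at_k(r, 1))  # 1.0: the top-ranked item is a hit
print(F1(0.5, 0.25))   # 0.3333...: harmonic mean of precision 0.5 and recall 0.25
# dcg_at_k(r, k, method=1) discounts hits by 1 / log2(rank + 1); over the full list above
# that gives 1/log2(2) + 1/log2(4) = 1.5, and ndcg_at_k divides by the ideal ordering's dcg.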
......@@ -3,51 +3,88 @@
import argparse
def parse_args():
parser = argparse.ArgumentParser(description="Run NGCF.")
parser.add_argument(
"--weights_path", nargs="?", default="model/", help="Store model path."
)
parser.add_argument(
"--data_path", nargs="?", default="../Data/", help="Input data path."
)
parser.add_argument(
"--model_name", type=str, default="NGCF.pkl", help="Saved model name."
)
parser.add_argument(
"--dataset",
nargs="?",
default="gowalla",
help="Choose a dataset from {gowalla, yelp2018, amazon-book}",
)
parser.add_argument(
"--verbose", type=int, default=1, help="Interval of evaluation."
)
parser.add_argument(
"--epoch", type=int, default=400, help="Number of epoch."
)
parser.add_argument(
"--embed_size", type=int, default=64, help="Embedding size."
)
parser.add_argument(
"--layer_size",
nargs="?",
default="[64,64,64]",
help="Output sizes of every layer",
)
parser.add_argument(
"--batch_size", type=int, default=1024, help="Batch size."
)
parser.add_argument(
"--regs", nargs="?", default="[1e-5]", help="Regularizations."
)
parser.add_argument(
"--lr", type=float, default=0.0001, help="Learning rate."
)
parser.add_argument(
"--gpu", type=int, default=0, help="0 for NAIS_prod, 1 for NAIS_concat"
)
parser.add_argument(
"--mess_dropout",
nargs="?",
default="[0.1,0.1,0.1]",
help="Keep probability w.r.t. message dropout (i.e., 1-dropout_ratio) for each deep layer. 1: no dropout.",
)
parser.add_argument(
"--Ks",
nargs="?",
default="[20, 40]",
help="Output sizes of every layer",
)
parser.add_argument(
"--save_flag",
type=int,
default=1,
help="0: Disable model saver, 1: Activate model saver",
)
parser.add_argument(
"--test_flag",
nargs="?",
default="part",
help="Specify the test type from {part, full}, indicating whether the reference is done in mini-batch",
)
parser.add_argument(
"--report",
type=int,
default=0,
help="0: Disable performance report w.r.t. sparsity levels, 1: Show performance report w.r.t. sparsity levels",
)
return parser.parse_args()
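# Hedged usage note (not part of this commit): with the defaults above, the NGCF example is
# typically launched along these lines (the entry-point name main.py is an assumption):
#   python main.py --dataset gowalla --gpu 0 --epoch 400 --batch_size 1024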
import os
import dgl
import torch
import warnings
import numpy as np
import torch.nn as nn
from model import PGNN
from sklearn.metrics import roc_auc_score
from utils import get_dataset, preselect_anchor
warnings.filterwarnings("ignore")
def get_loss(p, data, out, loss_func, device, get_auc=True):
edge_mask = np.concatenate(
(
data["positive_edges_{}".format(p)],
data["negative_edges_{}".format(p)],
),
axis=-1,
)
nodes_first = torch.index_select(
out, 0, torch.from_numpy(edge_mask[0, :]).long().to(out.device)
)
nodes_second = torch.index_select(
out, 0, torch.from_numpy(edge_mask[1, :]).long().to(out.device)
)
pred = torch.sum(nodes_first * nodes_second, dim=-1)
label_positive = torch.ones(
[
data["positive_edges_{}".format(p)].shape[1],
],
dtype=pred.dtype,
)
label_negative = torch.zeros(
[
data["negative_edges_{}".format(p)].shape[1],
],
dtype=pred.dtype,
)
label = torch.cat((label_positive, label_negative)).to(device)
loss = loss_func(pred, label)
if get_auc:
auc = roc_auc_score(
label.flatten().cpu().numpy(),
torch.sigmoid(pred).flatten().data.cpu().numpy(),
)
return loss, auc
else:
return loss
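# Hedged illustration (not part of this commit): for p = "train", the mask built above is just
# the positive and negative edge endpoints placed side by side (toy values):
#   positive_edges_train = [[0, 1], [2, 3]]  -> positive pairs (0, 2) and (1, 3)
#   negative_edges_train = [[0, 4], [5, 6]]  -> negative pairs (0, 5) and (4, 6)
#   edge_mask = [[0, 1, 0, 4], [2, 3, 5, 6]], label = [1, 1, 0, 0]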
def train_model(data, model, loss_func, optimizer, device, g_data):
model.train()
out = model(g_data)
loss = get_loss("train", data, out, loss_func, device, get_auc=False)
optimizer.zero_grad()
loss.backward()
......@@ -42,35 +69,41 @@ def train_model(data, model, loss_func, optimizer, device, g_data):
return g_data
def eval_model(data, g_data, model, loss_func, device):
model.eval()
out = model(g_data)
# train loss and auc
tmp_loss, auc_train = get_loss("train", data, out, loss_func, device)
loss_train = tmp_loss.cpu().data.numpy()
# val loss and auc
_, auc_val = get_loss("val", data, out, loss_func, device)
# test loss and auc
_, auc_test = get_loss("test", data, out, loss_func, device)
return loss_train, auc_train, auc_val, auc_test
def main(args):
# The mean and standard deviation of the experiment results
# are stored in the 'results' folder
if not os.path.isdir("results"):
os.mkdir("results")
if torch.cuda.is_available():
device = "cuda:0"
else:
device = "cpu"
print(
"Learning Type: {}".format(
["Transductive", "Inductive"][args.inductive]
),
"Task: {}".format(args.task),
)
results = []
......@@ -78,13 +111,20 @@ def main(args):
data = get_dataset(args)
# pre-sample anchor nodes and compute shortest distance values for all epochs
(
g_list,
anchor_eid_list,
dist_max_list,
edge_weight_list,
) = preselect_anchor(data, args)
# model
model = PGNN(input_dim=data["feature"].shape[1]).to(device)
# loss
optimizer = torch.optim.Adam(
model.parameters(), lr=1e-2, weight_decay=5e-4
)
loss_func = nn.BCEWithLogitsLoss()
best_auc_val = -1
......@@ -93,55 +133,79 @@ def main(args):
for epoch in range(args.epoch_num):
if epoch == 200:
for param_group in optimizer.param_groups:
param_group["lr"] /= 10
g = dgl.graph(g_list[epoch])
g.ndata["feat"] = torch.FloatTensor(data["feature"])
g.edata["sp_dist"] = torch.FloatTensor(edge_weight_list[epoch])
g_data = {
"graph": g.to(device),
"anchor_eid": anchor_eid_list[epoch],
"dists_max": dist_max_list[epoch],
}
train_model(data, model, loss_func, optimizer, device, g_data)
loss_train, auc_train, auc_val, auc_test = eval_model(
data, g_data, model, loss_func, device
)
if auc_val > best_auc_val:
best_auc_val = auc_val
best_auc_test = auc_test
if epoch % args.epoch_log == 0:
print(
repeat,
epoch,
"Loss {:.4f}".format(loss_train),
"Train AUC: {:.4f}".format(auc_train),
"Val AUC: {:.4f}".format(auc_val),
"Test AUC: {:.4f}".format(auc_test),
"Best Val AUC: {:.4f}".format(best_auc_val),
"Best Test AUC: {:.4f}".format(best_auc_test),
)
results.append(best_auc_test)
results = np.array(results)
results_mean = np.mean(results).round(6)
results_std = np.std(results).round(6)
print("-----------------Final-------------------")
print(results_mean, results_std)
with open(
"results/{}_{}_{}.txt".format(
["Transductive", "Inductive"][args.inductive],
args.task,
args.k_hop_dist,
),
"w",
) as f:
f.write("{}, {}\n".format(results_mean, results_std))
if __name__ == "__main__":
from argparse import ArgumentParser
parser = ArgumentParser()
parser.add_argument(
"--task", type=str, default="link", choices=["link", "link_pair"]
)
parser.add_argument(
"--inductive",
action="store_true",
help="Inductive learning or transductive learning",
)
parser.add_argument(
"--k_hop_dist",
default=-1,
type=int,
help="K-hop shortest path distance, -1 means exact shortest path.",
)
parser.add_argument("--epoch_num", type=int, default=2000)
parser.add_argument("--repeat_num", type=int, default=10)
parser.add_argument("--epoch_log", type=int, default=100)
args = parser.parse_args()
main(args)
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
class PGNN_layer(nn.Module):
def __init__(self, input_dim, output_dim):
super(PGNN_layer, self).__init__()
......@@ -17,23 +19,31 @@ class PGNN_layer(nn.Module):
with graph.local_scope():
u_feat = self.linear_hidden_u(feature)
v_feat = self.linear_hidden_v(feature)
graph.srcdata.update({"u_feat": u_feat})
graph.dstdata.update({"v_feat": v_feat})
graph.apply_edges(fn.u_mul_e("u_feat", "sp_dist", "u_message"))
graph.apply_edges(fn.v_add_e("v_feat", "u_message", "message"))
messages = torch.index_select(
graph.edata["message"],
0,
torch.LongTensor(anchor_eid).to(feature.device),
)
messages = messages.reshape(
dists_max.shape[0], dists_max.shape[1], messages.shape[-1]
)
messages = self.act(messages) # n*m*d
out_position = self.linear_out_position(messages).squeeze(
-1
) # n*m_out
out_structure = torch.mean(messages, dim=1) # n*d
return out_position, out_structure
class PGNN(nn.Module):
def __init__(self, input_dim, feature_dim=32, dropout=0.5):
super(PGNN, self).__init__()
......@@ -44,12 +54,16 @@ class PGNN(nn.Module):
self.conv_out = PGNN_layer(feature_dim, feature_dim)
def forward(self, data):
x = data["graph"].ndata["feat"]
graph = data["graph"]
x = self.linear_pre(x)
x_position, x = self.conv_first(
graph, x, data["anchor_eid"], data["dists_max"]
)
x = self.dropout(x)
x_position, x = self.conv_out(
graph, x, data["anchor_eid"], data["dists_max"]
)
x_position = F.normalize(x_position, p=2, dim=-1)
return x_position
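# Hedged usage sketch (not part of this commit): the g_data dict consumed by forward(),
# mirroring how main.py above assembles it (field names taken from that code):
#   g_data = {"graph": g.to(device), "anchor_eid": anchor_eid_list[epoch], "dists_max": dist_max_list[epoch]}
#   positional_embeddings = PGNN(input_dim=data["feature"].shape[1]).to(device)(g_data)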
import random
import networkx as nx
import numpy as np
import torch
from tqdm.auto import tqdm
import multiprocessing as mp
from multiprocessing import get_context
def get_communities(remove_feature):
community_size = 20
......@@ -45,14 +47,15 @@ def get_communities(remove_feature):
feature = np.identity(n)[:, rand_order]
data = {
"edge_index": edge_index,
"feature": feature,
"positive_edges": np.stack(np.nonzero(label)),
"num_nodes": feature.shape[0],
}
return data
def to_single_directed(edges):
edges_new = np.zeros((2, edges.shape[1] // 2), dtype=int)
j = 0
......@@ -63,6 +66,7 @@ def to_single_directed(edges):
return edges_new
# each node at least remain in the new graph
def split_edges(p, edges, data, non_train_ratio=0.2):
e = edges.shape[1]
......@@ -70,15 +74,19 @@ def split_edges(p, edges, data, non_train_ratio=0.2):
split1 = int((1 - non_train_ratio) * e)
split2 = int((1 - non_train_ratio / 2) * e)
data.update(
{
"{}_edges_train".format(p): edges[:, :split1], # 80%
"{}_edges_val".format(p): edges[:, split1:split2], # 10%
"{}_edges_test".format(p): edges[:, split2:], # 10%
}
)
def to_bidirected(edges):
return np.concatenate((edges, edges[::-1, :]), axis=-1)
def get_negative_edges(positive_edges, num_nodes, num_negative_edges):
positive_edge_set = []
positive_edges = to_bidirected(positive_edges)
......@@ -86,54 +94,76 @@ def get_negative_edges(positive_edges, num_nodes, num_negative_edges):
positive_edge_set.append(tuple(positive_edges[:, i]))
positive_edge_set = set(positive_edge_set)
negative_edges = np.zeros(
(2, num_negative_edges), dtype=positive_edges.dtype
)
for i in range(num_negative_edges):
while True:
mask_temp = tuple(
np.random.choice(num_nodes, size=(2,), replace=False)
)
if mask_temp not in positive_edge_set:
negative_edges[:, i] = mask_temp
break
return negative_edges
def get_pos_neg_edges(data, infer_link_positive=True):
if infer_link_positive:
data["positive_edges"] = to_single_directed(data["edge_index"].numpy())
split_edges("positive", data["positive_edges"], data)
# resample edge mask link negative
negative_edges = get_negative_edges(
data["positive_edges"],
data["num_nodes"],
num_negative_edges=data["positive_edges"].shape[1],
)
split_edges("negative", negative_edges, data)
return data
def shortest_path(graph, node_range, cutoff):
dists_dict = {}
for node in tqdm(node_range, leave=False):
dists_dict[node] = nx.single_source_shortest_path_length(
graph, node, cutoff
)
return dists_dict
def merge_dicts(dicts):
result = {}
for dictionary in dicts:
result.update(dictionary)
return result
def all_pairs_shortest_path(graph, cutoff=None, num_workers=4):
nodes = list(graph.nodes)
random.shuffle(nodes)
pool = mp.Pool(processes=num_workers)
interval_size = len(nodes) / num_workers
results = [
pool.apply_async(
shortest_path,
args=(
graph,
nodes[int(interval_size * i) : int(interval_size * (i + 1))],
cutoff,
),
)
for i in range(num_workers)
]
output = [p.get() for p in results]
dists_dict = merge_dicts(output)
pool.close()
pool.join()
return dists_dict
def precompute_dist_data(edge_index, num_nodes, approximate=0):
"""
Here dist is 1/real_dist, higher actually means closer, 0 means disconnected
......@@ -145,7 +175,9 @@ def precompute_dist_data(edge_index, num_nodes, approximate=0):
n = num_nodes
dists_array = np.zeros((n, n))
dists_dict = all_pairs_shortest_path(
graph, cutoff=approximate if approximate > 0 else None
)
node_list = graph.nodes()
for node_i in node_list:
shortest_dist = dists_dict[node_i]
......@@ -155,24 +187,36 @@ def precompute_dist_data(edge_index, num_nodes, approximate=0):
dists_array[node_i, node_j] = 1 / (dist + 1)
return dists_array
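# Hedged illustration (not part of this commit): on a 3-node path graph 0-1-2 the convention
# above gives
#   dists_array[0, 0] = 1 / (0 + 1) = 1.0
#   dists_array[0, 1] = 1 / (1 + 1) = 0.5
#   dists_array[0, 2] = 1 / (2 + 1) = 0.333...,  while disconnected pairs stay 0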
def get_dataset(args):
# Generate graph data
data_info = get_communities(args.inductive)
# Get positive and negative edges
data = get_pos_neg_edges(
data_info, infer_link_positive=True if args.task == "link" else False
)
# Pre-compute shortest path length
if args.task == "link":
dists_removed = precompute_dist_data(
data["positive_edges_train"],
data["num_nodes"],
approximate=args.k_hop_dist,
)
data["dists"] = torch.from_numpy(dists_removed).float()
data["edge_index"] = torch.from_numpy(
to_bidirected(data["positive_edges_train"])
).long()
else:
dists = precompute_dist_data(
data["edge_index"].numpy(),
data["num_nodes"],
approximate=args.k_hop_dist,
)
data["dists"] = torch.from_numpy(dists).float()
return data
def get_anchors(n):
"""Get a list of NumPy arrays, each of them is an anchor node set"""
m = int(np.log2(n))
......@@ -180,9 +224,12 @@ def get_anchors(n):
for i in range(m):
anchor_size = int(n / np.exp2(i + 1))
for _ in range(m):
anchor_set_id.append(
np.random.choice(n, size=anchor_size, replace=False)
)
return anchor_set_id
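# Hedged worked example (not part of this commit): for n = 64 nodes, m = int(log2(64)) = 6,
# so get_anchors returns 6 * 6 = 36 anchor sets of sizes 32, 16, 8, 4, 2, 1 (each size sampled
# 6 times), i.e. the O(log^2 n) anchor sets that P-GNN relies on.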
def get_dist_max(anchor_set_id, dist):
# N x K, N is number of nodes, K is the number of anchor sets
dist_max = torch.zeros((dist.shape[0], len(anchor_set_id)))
......@@ -198,6 +245,7 @@ def get_dist_max(anchor_set_id, dist):
dist_argmax[:, i] = torch.index_select(temp_id, 0, dist_argmax_temp)
return dist_max, dist_argmax
def get_a_graph(dists_max, dists_argmax):
src = []
dst = []
......@@ -207,7 +255,9 @@ def get_a_graph(dists_max, dists_argmax):
dists_max = dists_max.numpy()
for i in range(dists_max.shape[0]):
# Get unique closest anchor nodes for node i across all anchor sets
tmp_dists_argmax, tmp_dists_argmax_idx = np.unique(
dists_argmax[i, :], True
)
src.extend([i] * tmp_dists_argmax.shape[0])
real_src.extend([i] * dists_argmax[i, :].shape[0])
real_dst.extend(list(dists_argmax[i, :].numpy()))
......@@ -218,13 +268,14 @@ def get_a_graph(dists_max, dists_argmax):
g = (dst, src)
return g, anchor_eid, edge_weight
def get_graphs(data, anchor_sets):
graphs = []
anchor_eids = []
dists_max_list = []
edge_weights = []
for anchor_set in tqdm(anchor_sets, leave=False):
dists_max, dists_argmax = get_dist_max(anchor_set, data["dists"])
g, anchor_eid, edge_weight = get_a_graph(dists_max, dists_argmax)
graphs.append(g)
anchor_eids.append(anchor_eid)
......@@ -233,6 +284,7 @@ def get_graphs(data, anchor_sets):
return graphs, anchor_eids, dists_max_list, edge_weights
def merge_result(outputs):
graphs = []
anchor_eids = []
......@@ -247,14 +299,26 @@ def merge_result(outputs):
return graphs, anchor_eids, dists_max_list, edge_weights
def preselect_anchor(data, args, num_workers=4):
pool = get_context("spawn").Pool(processes=num_workers)
# Pre-compute anchor sets, a collection of anchor sets per epoch
anchor_set_ids = [
get_anchors(data["num_nodes"]) for _ in range(args.epoch_num)
]
interval_size = len(anchor_set_ids) / num_workers
results = [
pool.apply_async(
get_graphs,
args=(
data,
anchor_set_ids[
int(interval_size * i) : int(interval_size * (i + 1))
],
),
)
for i in range(num_workers)
]
output = [p.get() for p in results]
graphs, anchor_eids, dists_max_list, edge_weights = merge_result(output)
......
""" Credit: https://github.com/fanyun-sun/InfoGraph """
import math
import torch as th
import torch.nn.functional as F
def get_positive_expectation(p_samples, average=True):
"""Computes the positive part of a JS Divergence.
......@@ -13,8 +14,8 @@ def get_positive_expectation(p_samples, average=True):
Returns:
th.Tensor
"""
log_2 = math.log(2.0)
Ep = log_2 - F.softplus(-p_samples)
if average:
return Ep.mean()
......@@ -30,7 +31,7 @@ def get_negative_expectation(q_samples, average=True):
Returns:
th.Tensor
"""
log_2 = math.log(2.0)
Eq = F.softplus(-q_samples) + q_samples - log_2
if average:
......@@ -51,8 +52,8 @@ def local_global_loss_(l_enc, g_enc, graph_id):
for nodeidx, graphidx in enumerate(graph_id):
pos_mask[nodeidx][graphidx] = 1.0
neg_mask[nodeidx][graphidx] = 0.0
res = th.mm(l_enc, g_enc.t())
......
......@@ -2,39 +2,45 @@
import argparse
import copy
import numpy as np
import torch
import torch.optim as optim
import torch.nn as nn
from model import JKNet
from sklearn.model_selection import train_test_split
from tqdm import trange
from dgl.data import CiteseerGraphDataset, CoraGraphDataset
def main(args):
# Step 1: Prepare graph data and retrieve train/validation/test index ============================= #
# Load from DGL dataset
if args.dataset == "Cora":
dataset = CoraGraphDataset()
elif args.dataset == "Citeseer":
dataset = CiteseerGraphDataset()
else:
raise ValueError("Dataset {} is invalid.".format(args.dataset))
graph = dataset[0]
# check cuda
device = (
f"cuda:{args.gpu}"
if args.gpu >= 0 and torch.cuda.is_available()
else "cpu"
)
# retrieve the number of classes
n_classes = dataset.num_classes
# retrieve labels of ground truth
labels = graph.ndata.pop("label").to(device).long()
# Extract node features
feats = graph.ndata.pop("feat").to(device)
n_features = feats.shape[-1]
# create masks for train / validation / test
......@@ -47,12 +53,14 @@ def main(args):
graph = graph.to(device)
# Step 2: Create model =================================================================== #
model = JKNet(
in_dim=n_features,
hid_dim=args.hid_dim,
out_dim=n_classes,
num_layers=args.num_layers,
mode=args.mode,
dropout=args.dropout,
).to(device)
best_model = copy.deepcopy(model)
......@@ -62,7 +70,7 @@ def main(args):
# Step 4: training epochs =============================================================== #
acc = 0
epochs = trange(args.epochs, desc="Accuracy & Loss")
for _ in epochs:
# Training using a full graph
......@@ -72,7 +80,9 @@ def main(args):
# compute loss
train_loss = loss_fn(logits[train_idx], labels[train_idx])
train_acc = torch.sum(
logits[train_idx].argmax(dim=1) == labels[train_idx]
).item() / len(train_idx)
# backward
opt.zero_grad()
......@@ -84,11 +94,16 @@ def main(args):
with torch.no_grad():
valid_loss = loss_fn(logits[val_idx], labels[val_idx])
valid_acc = torch.sum(
logits[val_idx].argmax(dim=1) == labels[val_idx]
).item() / len(val_idx)
# Print out performance
epochs.set_description(
"Train Acc {:.4f} | Train Loss {:.4f} | Val Acc {:.4f} | Val loss {:.4f}".format(
train_acc, train_loss.item(), valid_acc, valid_loss.item()
)
)
if valid_acc > acc:
acc = valid_acc
......@@ -96,31 +111,57 @@ def main(args):
best_model.eval()
logits = best_model(graph, feats)
test_acc = torch.sum(
logits[test_idx].argmax(dim=1) == labels[test_idx]
).item() / len(test_idx)
print("Test Acc {:.4f}".format(test_acc))
return test_acc
if __name__ == "__main__":
"""
JKNet Hyperparameters
"""
parser = argparse.ArgumentParser(description="JKNet")
# data source params
parser.add_argument(
"--dataset", type=str, default="Cora", help="Name of dataset."
)
# cuda params
parser.add_argument(
"--gpu", type=int, default=-1, help="GPU index. Default: -1, using CPU."
)
# training params
parser.add_argument("--run", type=int, default=10, help="Running times.")
parser.add_argument(
"--epochs", type=int, default=500, help="Training epochs."
)
parser.add_argument(
"--lr", type=float, default=0.005, help="Learning rate."
)
parser.add_argument("--lamb", type=float, default=0.0005, help="L2 reg.")
# model params
parser.add_argument(
"--hid-dim", type=int, default=32, help="Hidden layer dimensionalities."
)
parser.add_argument(
"--num-layers", type=int, default=5, help="Number of GCN layers."
)
parser.add_argument(
"--mode",
type=str,
default="cat",
help="Type of aggregation.",
choices=["cat", "max", "lstm"],
)
parser.add_argument(
"--dropout",
type=float,
default=0.5,
help="Dropout applied at all layers.",
)
args = parser.parse_args()
print(args)
......@@ -132,6 +173,6 @@ if __name__ == "__main__":
mean = np.around(np.mean(acc_lists, axis=0), decimals=3)
std = np.around(np.std(acc_lists, axis=0), decimals=3)
print("total acc: ", acc_lists)
print("mean", mean)
print("std", std)
import torch
import torch.nn as nn
import torch.nn.functional as F
import dgl.function as fn
from dgl.nn import GraphConv, JumpingKnowledge
class JKNet(nn.Module):
def __init__(
self, in_dim, hid_dim, out_dim, num_layers=1, mode="cat", dropout=0.0
):
super(JKNet, self).__init__()
self.mode = mode
......@@ -21,12 +19,12 @@ class JKNet(nn.Module):
for _ in range(num_layers):
self.layers.append(GraphConv(hid_dim, hid_dim, activation=F.relu))
if self.mode == "lstm":
self.jump = JumpingKnowledge(mode, hid_dim, num_layers)
else:
self.jump = JumpingKnowledge(mode)
if self.mode == "cat":
hid_dim = hid_dim * (num_layers + 1)
self.output = nn.Linear(hid_dim, out_dim)
......@@ -44,7 +42,7 @@ class JKNet(nn.Module):
feats = self.dropout(layer(g, feats))
feat_lst.append(feats)
g.ndata["h"] = self.jump(feat_lst)
g.update_all(fn.copy_u("h", "m"), fn.sum("m", "h"))
return self.output(g.ndata["h"])
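# Hedged usage sketch (not part of this commit): instantiating the JKNet above with Cora's
# published statistics (1433 input features, 7 classes), matching the defaults in main.py:
#   model = JKNet(in_dim=1433, hid_dim=32, out_dim=7, num_layers=5, mode="cat", dropout=0.5)
#   logits = model(g, feats)  # g: DGLGraph, feats: (num_nodes, 1433) feature tensor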
from .chemutils import decode_stereo
from .datautils import JTNNCollator, JTNNDataset
from .jtnn_vae import DGLJTNNVAE
from .mol_tree import Vocab
from .mpn import DGLMPN
from .nnutils import cuda
from collections import defaultdict
import rdkit.Chem as Chem
from rdkit.Chem.EnumerateStereoisomers import EnumerateStereoisomers
from scipy.sparse import csr_matrix
from scipy.sparse.csgraph import minimum_spanning_tree
MST_MAX_WEIGHT = 100
MAX_NCAND = 2000
def set_atommap(mol, num=0):
for atom in mol.GetAtoms():
atom.SetAtomMapNum(num)
def get_mol(smiles):
mol = Chem.MolFromSmiles(smiles)
if mol is None:
......@@ -18,26 +21,39 @@ def get_mol(smiles):
Chem.Kekulize(mol)
return mol
def get_smiles(mol):
return Chem.MolToSmiles(mol, kekuleSmiles=True)
def decode_stereo(smiles2D):
mol = Chem.MolFromSmiles(smiles2D)
dec_isomers = list(EnumerateStereoisomers(mol))
dec_isomers = [Chem.MolFromSmiles(Chem.MolToSmiles(mol, isomericSmiles=True)) for mol in dec_isomers]
smiles3D = [Chem.MolToSmiles(mol, isomericSmiles=True) for mol in dec_isomers]
chiralN = [atom.GetIdx() for atom in dec_isomers[0].GetAtoms()
if int(atom.GetChiralTag()) > 0 and atom.GetSymbol() == "N"]
dec_isomers = [
Chem.MolFromSmiles(Chem.MolToSmiles(mol, isomericSmiles=True))
for mol in dec_isomers
]
smiles3D = [
Chem.MolToSmiles(mol, isomericSmiles=True) for mol in dec_isomers
]
chiralN = [
atom.GetIdx()
for atom in dec_isomers[0].GetAtoms()
if int(atom.GetChiralTag()) > 0 and atom.GetSymbol() == "N"
]
if len(chiralN) > 0:
for mol in dec_isomers:
for idx in chiralN:
mol.GetAtomWithIdx(idx).SetChiralTag(Chem.rdchem.ChiralType.CHI_UNSPECIFIED)
mol.GetAtomWithIdx(idx).SetChiralTag(
Chem.rdchem.ChiralType.CHI_UNSPECIFIED
)
smiles3D.append(Chem.MolToSmiles(mol, isomericSmiles=True))
return smiles3D
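# Illustrative call (assumes RDKit; the SMILES is arbitrary): decode_stereo("CC(N)C(=O)O")
# should return a list of isomeric SMILES strings, one per enumerated stereoisomer (two
# for this single-stereocenter input), with chiral nitrogens reset as handled above.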
def sanitize(mol):
try:
smiles = get_smiles(mol)
......@@ -46,14 +62,16 @@ def sanitize(mol):
return None
return mol
def copy_atom(atom):
new_atom = Chem.Atom(atom.GetSymbol())
new_atom.SetFormalCharge(atom.GetFormalCharge())
new_atom.SetAtomMapNum(atom.GetAtomMapNum())
return new_atom
def copy_edit_mol(mol):
new_mol = Chem.RWMol(Chem.MolFromSmiles(''))
new_mol = Chem.RWMol(Chem.MolFromSmiles(""))
for atom in mol.GetAtoms():
new_atom = copy_atom(atom)
new_mol.AddAtom(new_atom)
......@@ -64,13 +82,15 @@ def copy_edit_mol(mol):
new_mol.AddBond(a1, a2, bt)
return new_mol
def get_clique_mol(mol, atoms):
smiles = Chem.MolFragmentToSmiles(mol, atoms, kekuleSmiles=True)
new_mol = Chem.MolFromSmiles(smiles, sanitize=False)
new_mol = copy_edit_mol(new_mol).GetMol()
new_mol = sanitize(new_mol) #We assume this is not None
new_mol = sanitize(new_mol) # We assume this is not None
return new_mol
def tree_decomp(mol):
n_atoms = mol.GetNumAtoms()
if n_atoms == 1:
......@@ -81,7 +101,7 @@ def tree_decomp(mol):
a1 = bond.GetBeginAtom().GetIdx()
a2 = bond.GetEndAtom().GetIdx()
if not bond.IsInRing():
cliques.append([a1,a2])
cliques.append([a1, a2])
ssr = [list(x) for x in Chem.GetSymmSSSR(mol)]
cliques.extend(ssr)
......@@ -91,12 +111,14 @@ def tree_decomp(mol):
for atom in cliques[i]:
nei_list[atom].append(i)
#Merge Rings with intersection > 2 atoms
# Merge Rings with intersection > 2 atoms
for i in range(len(cliques)):
if len(cliques[i]) <= 2: continue
if len(cliques[i]) <= 2:
continue
for atom in cliques[i]:
for j in nei_list[atom]:
if i >= j or len(cliques[j]) <= 2: continue
if i >= j or len(cliques[j]) <= 2:
continue
inter = set(cliques[i]) & set(cliques[j])
if len(inter) > 2:
cliques[i].extend(cliques[j])
......@@ -109,7 +131,7 @@ def tree_decomp(mol):
for atom in cliques[i]:
nei_list[atom].append(i)
#Build edges and add singleton cliques
# Build edges and add singleton cliques
edges = defaultdict(int)
for atom in range(n_atoms):
if len(nei_list[atom]) <= 1:
......@@ -122,37 +144,44 @@ def tree_decomp(mol):
cliques.append([atom])
c2 = len(cliques) - 1
for c1 in cnei:
edges[(c1,c2)] = 1
elif len(rings) > 2: #Multiple (n>2) complex rings
edges[(c1, c2)] = 1
elif len(rings) > 2: # Multiple (n>2) complex rings
cliques.append([atom])
c2 = len(cliques) - 1
for c1 in cnei:
edges[(c1,c2)] = MST_MAX_WEIGHT - 1
edges[(c1, c2)] = MST_MAX_WEIGHT - 1
else:
for i in range(len(cnei)):
for j in range(i + 1, len(cnei)):
c1,c2 = cnei[i],cnei[j]
c1, c2 = cnei[i], cnei[j]
inter = set(cliques[c1]) & set(cliques[c2])
if edges[(c1,c2)] < len(inter):
edges[(c1,c2)] = len(inter) #cnei[i] < cnei[j] by construction
if edges[(c1, c2)] < len(inter):
edges[(c1, c2)] = len(
inter
) # cnei[i] < cnei[j] by construction
edges = [u + (MST_MAX_WEIGHT-v,) for u,v in edges.items()]
edges = [u + (MST_MAX_WEIGHT - v,) for u, v in edges.items()]
if len(edges) == 0:
return cliques, edges
#Compute Maximum Spanning Tree
row,col,data = list(zip(*edges))
# Compute Maximum Spanning Tree
row, col, data = list(zip(*edges))
n_clique = len(cliques)
clique_graph = csr_matrix( (data,(row,col)), shape=(n_clique,n_clique) )
clique_graph = csr_matrix((data, (row, col)), shape=(n_clique, n_clique))
junc_tree = minimum_spanning_tree(clique_graph)
row,col = junc_tree.nonzero()
edges = [(row[i],col[i]) for i in range(len(row))]
row, col = junc_tree.nonzero()
edges = [(row[i], col[i]) for i in range(len(row))]
return (cliques, edges)
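# Sketch of how tree_decomp is typically exercised (example SMILES chosen arbitrarily):
#   mol = get_mol("c1ccccc1CC(=O)O")      # kekulized RDKit molecule
#   cliques, edges = tree_decomp(mol)
#   # cliques: atom-index lists (non-ring bonds, rings, singleton junction atoms)
#   # edges:   junction-tree edges between clique indices, obtained from the maximum
#   #          spanning tree over the clique-intersection weights built above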
def atom_equal(a1, a2):
return a1.GetSymbol() == a2.GetSymbol() and a1.GetFormalCharge() == a2.GetFormalCharge()
return (
a1.GetSymbol() == a2.GetSymbol()
and a1.GetFormalCharge() == a2.GetFormalCharge()
)
#Bond type not considered because all aromatic (so SINGLE matches DOUBLE)
# Bond type not considered because all aromatic (so SINGLE matches DOUBLE)
def ring_bond_equal(b1, b2, reverse=False):
b1 = (b1.GetBeginAtom(), b1.GetEndAtom())
if reverse:
......@@ -161,10 +190,11 @@ def ring_bond_equal(b1, b2, reverse=False):
b2 = (b2.GetBeginAtom(), b2.GetEndAtom())
return atom_equal(b1[0], b2[0]) and atom_equal(b1[1], b2[1])
def attach_mols_nx(ctr_mol, neighbors, prev_nodes, nei_amap):
prev_nids = [node['nid'] for node in prev_nodes]
prev_nids = [node["nid"] for node in prev_nodes]
for nei_node in prev_nodes + neighbors:
nei_id, nei_mol = nei_node['nid'], nei_node['mol']
nei_id, nei_mol = nei_node["nid"], nei_node["mol"]
amap = nei_amap[nei_id]
for atom in nei_mol.GetAtoms():
if atom.GetIdx() not in amap:
......@@ -181,82 +211,116 @@ def attach_mols_nx(ctr_mol, neighbors, prev_nodes, nei_amap):
a2 = amap[bond.GetEndAtom().GetIdx()]
if ctr_mol.GetBondBetweenAtoms(a1, a2) is None:
ctr_mol.AddBond(a1, a2, bond.GetBondType())
elif nei_id in prev_nids: #father node overrides
elif nei_id in prev_nids: # father node overrides
ctr_mol.RemoveBond(a1, a2)
ctr_mol.AddBond(a1, a2, bond.GetBondType())
return ctr_mol
def local_attach_nx(ctr_mol, neighbors, prev_nodes, amap_list):
ctr_mol = copy_edit_mol(ctr_mol)
nei_amap = {nei['nid']: {} for nei in prev_nodes + neighbors}
nei_amap = {nei["nid"]: {} for nei in prev_nodes + neighbors}
for nei_id,ctr_atom,nei_atom in amap_list:
for nei_id, ctr_atom, nei_atom in amap_list:
nei_amap[nei_id][nei_atom] = ctr_atom
ctr_mol = attach_mols_nx(ctr_mol, neighbors, prev_nodes, nei_amap)
return ctr_mol.GetMol()
#This version records idx mapping between ctr_mol and nei_mol
# This version records idx mapping between ctr_mol and nei_mol
def enum_attach_nx(ctr_mol, nei_node, amap, singletons):
nei_mol,nei_idx = nei_node['mol'], nei_node['nid']
nei_mol, nei_idx = nei_node["mol"], nei_node["nid"]
att_confs = []
black_list = [atom_idx for nei_id,atom_idx,_ in amap if nei_id in singletons]
ctr_atoms = [atom for atom in ctr_mol.GetAtoms() if atom.GetIdx() not in black_list]
black_list = [
atom_idx for nei_id, atom_idx, _ in amap if nei_id in singletons
]
ctr_atoms = [
atom for atom in ctr_mol.GetAtoms() if atom.GetIdx() not in black_list
]
ctr_bonds = [bond for bond in ctr_mol.GetBonds()]
if nei_mol.GetNumBonds() == 0: #neighbor singleton
if nei_mol.GetNumBonds() == 0: # neighbor singleton
nei_atom = nei_mol.GetAtomWithIdx(0)
used_list = [atom_idx for _,atom_idx,_ in amap]
used_list = [atom_idx for _, atom_idx, _ in amap]
for atom in ctr_atoms:
if atom_equal(atom, nei_atom) and atom.GetIdx() not in used_list:
new_amap = amap + [(nei_idx, atom.GetIdx(), 0)]
att_confs.append( new_amap )
att_confs.append(new_amap)
elif nei_mol.GetNumBonds() == 1: #neighbor is a bond
elif nei_mol.GetNumBonds() == 1: # neighbor is a bond
bond = nei_mol.GetBondWithIdx(0)
bond_val = int(bond.GetBondTypeAsDouble())
b1,b2 = bond.GetBeginAtom(), bond.GetEndAtom()
b1, b2 = bond.GetBeginAtom(), bond.GetEndAtom()
for atom in ctr_atoms:
#Optimize if atom is carbon (other atoms may change valence)
# Optimize if atom is carbon (other atoms may change valence)
if atom.GetAtomicNum() == 6 and atom.GetTotalNumHs() < bond_val:
continue
if atom_equal(atom, b1):
new_amap = amap + [(nei_idx, atom.GetIdx(), b1.GetIdx())]
att_confs.append( new_amap )
att_confs.append(new_amap)
elif atom_equal(atom, b2):
new_amap = amap + [(nei_idx, atom.GetIdx(), b2.GetIdx())]
att_confs.append( new_amap )
att_confs.append(new_amap)
else:
#intersection is an atom
# intersection is an atom
for a1 in ctr_atoms:
for a2 in nei_mol.GetAtoms():
if atom_equal(a1, a2):
#Optimize if atom is carbon (other atoms may change valence)
if a1.GetAtomicNum() == 6 and a1.GetTotalNumHs() + a2.GetTotalNumHs() < 4:
# Optimize if atom is carbon (other atoms may change valence)
if (
a1.GetAtomicNum() == 6
and a1.GetTotalNumHs() + a2.GetTotalNumHs() < 4
):
continue
new_amap = amap + [(nei_idx, a1.GetIdx(), a2.GetIdx())]
att_confs.append( new_amap )
att_confs.append(new_amap)
#intersection is an bond
# intersection is a bond
if ctr_mol.GetNumBonds() > 1:
for b1 in ctr_bonds:
for b2 in nei_mol.GetBonds():
if ring_bond_equal(b1, b2):
new_amap = amap + [(nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetBeginAtom().GetIdx()),
(nei_idx, b1.GetEndAtom().GetIdx(), b2.GetEndAtom().GetIdx())]
att_confs.append( new_amap )
new_amap = amap + [
(
nei_idx,
b1.GetBeginAtom().GetIdx(),
b2.GetBeginAtom().GetIdx(),
),
(
nei_idx,
b1.GetEndAtom().GetIdx(),
b2.GetEndAtom().GetIdx(),
),
]
att_confs.append(new_amap)
if ring_bond_equal(b1, b2, reverse=True):
new_amap = amap + [(nei_idx, b1.GetBeginAtom().GetIdx(), b2.GetEndAtom().GetIdx()),
(nei_idx, b1.GetEndAtom().GetIdx(), b2.GetBeginAtom().GetIdx())]
att_confs.append( new_amap )
new_amap = amap + [
(
nei_idx,
b1.GetBeginAtom().GetIdx(),
b2.GetEndAtom().GetIdx(),
),
(
nei_idx,
b1.GetEndAtom().GetIdx(),
b2.GetBeginAtom().GetIdx(),
),
]
att_confs.append(new_amap)
return att_confs
#Try rings first: Speed-Up
# Try rings first: Speed-Up
def enum_assemble_nx(node, neighbors, prev_nodes=[], prev_amap=[]):
all_attach_confs = []
singletons = [nei_node['nid'] for nei_node in neighbors + prev_nodes if nei_node['mol'].GetNumAtoms() == 1]
singletons = [
nei_node["nid"]
for nei_node in neighbors + prev_nodes
if nei_node["mol"].GetNumAtoms() == 1
]
def search(cur_amap, depth):
if len(all_attach_confs) > MAX_NCAND:
......@@ -266,11 +330,13 @@ def enum_assemble_nx(node, neighbors, prev_nodes=[], prev_amap=[]):
return
nei_node = neighbors[depth]
cand_amap = enum_attach_nx(node['mol'], nei_node, cur_amap, singletons)
cand_amap = enum_attach_nx(node["mol"], nei_node, cur_amap, singletons)
cand_smiles = set()
candidates = []
for amap in cand_amap:
cand_mol = local_attach_nx(node['mol'], neighbors[:depth+1], prev_nodes, amap)
cand_mol = local_attach_nx(
node["mol"], neighbors[: depth + 1], prev_nodes, amap
)
cand_mol = sanitize(cand_mol)
if cand_mol is None:
continue
......@@ -290,47 +356,69 @@ def enum_assemble_nx(node, neighbors, prev_nodes=[], prev_amap=[]):
cand_smiles = set()
candidates = []
for amap in all_attach_confs:
cand_mol = local_attach_nx(node['mol'], neighbors, prev_nodes, amap)
cand_mol = local_attach_nx(node["mol"], neighbors, prev_nodes, amap)
cand_mol = Chem.MolFromSmiles(Chem.MolToSmiles(cand_mol))
smiles = Chem.MolToSmiles(cand_mol)
if smiles in cand_smiles:
continue
cand_smiles.add(smiles)
Chem.Kekulize(cand_mol)
candidates.append( (smiles,cand_mol,amap) )
candidates.append((smiles, cand_mol, amap))
return candidates
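# Each candidate returned above is a (smiles, kekulized_mol, atom_map) triple; the
# debugging helper dfs_assemble_nx below picks the one matching the node's SMILES label,
# while the VAE decoder scores them with the network instead.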
#Only used for debugging purpose
def dfs_assemble_nx(graph, cur_mol, global_amap, fa_amap, cur_node_id, fa_node_id):
# Only used for debugging purposes
def dfs_assemble_nx(
graph, cur_mol, global_amap, fa_amap, cur_node_id, fa_node_id
):
cur_node = graph.nodes_dict[cur_node_id]
fa_node = graph.nodes_dict[fa_node_id] if fa_node_id is not None else None
fa_nid = fa_node['nid'] if fa_node is not None else -1
fa_nid = fa_node["nid"] if fa_node is not None else -1
prev_nodes = [fa_node] if fa_node is not None else []
children_id = [nei for nei in graph[cur_node_id] if graph.nodes_dict[nei]['nid'] != fa_nid]
children_id = [
nei
for nei in graph[cur_node_id]
if graph.nodes_dict[nei]["nid"] != fa_nid
]
children = [graph.nodes_dict[nei] for nei in children_id]
neighbors = [nei for nei in children if nei['mol'].GetNumAtoms() > 1]
neighbors = sorted(neighbors, key=lambda x:x['mol'].GetNumAtoms(), reverse=True)
singletons = [nei for nei in children if nei['mol'].GetNumAtoms() == 1]
neighbors = [nei for nei in children if nei["mol"].GetNumAtoms() > 1]
neighbors = sorted(
neighbors, key=lambda x: x["mol"].GetNumAtoms(), reverse=True
)
singletons = [nei for nei in children if nei["mol"].GetNumAtoms() == 1]
neighbors = singletons + neighbors
cur_amap = [(fa_nid,a2,a1) for nid,a1,a2 in fa_amap if nid == cur_node['nid']]
cands = enum_assemble_nx(graph.nodes_dict[cur_node_id], neighbors, prev_nodes, cur_amap)
cur_amap = [
(fa_nid, a2, a1) for nid, a1, a2 in fa_amap if nid == cur_node["nid"]
]
cands = enum_assemble_nx(
graph.nodes_dict[cur_node_id], neighbors, prev_nodes, cur_amap
)
if len(cands) == 0:
return
cand_smiles, _, cand_amap = zip(*cands)
label_idx = cand_smiles.index(cur_node['label'])
label_idx = cand_smiles.index(cur_node["label"])
label_amap = cand_amap[label_idx]
for nei_id,ctr_atom,nei_atom in label_amap:
for nei_id, ctr_atom, nei_atom in label_amap:
if nei_id == fa_nid:
continue
global_amap[nei_id][nei_atom] = global_amap[cur_node['nid']][ctr_atom]
global_amap[nei_id][nei_atom] = global_amap[cur_node["nid"]][ctr_atom]
cur_mol = attach_mols_nx(cur_mol, children, [], global_amap) #father is already attached
cur_mol = attach_mols_nx(
cur_mol, children, [], global_amap
) # father is already attached
for nei_node_id, nei_node in zip(children_id, children):
if not nei_node['is_leaf']:
dfs_assemble_nx(graph, cur_mol, global_amap, label_amap, nei_node_id, cur_node_id)
if not nei_node["is_leaf"]:
dfs_assemble_nx(
graph,
cur_mol,
global_amap,
label_amap,
nei_node_id,
cur_node_id,
)
......@@ -2,41 +2,49 @@ import torch
from torch.utils.data import Dataset
import dgl
from dgl.data.utils import download, extract_archive, get_download_dir, _get_dgl_url
from .mol_tree_nx import DGLMolTree
from .mol_tree import Vocab
from dgl.data.utils import (
_get_dgl_url,
download,
extract_archive,
get_download_dir,
)
from .mpn import mol2dgl_single as mol2dgl_enc
from .jtmpn import mol2dgl_single as mol2dgl_dec
from .jtmpn import ATOM_FDIM as ATOM_FDIM_DEC
from .jtmpn import BOND_FDIM as BOND_FDIM_DEC
from .jtmpn import mol2dgl_single as mol2dgl_dec
from .mol_tree import Vocab
from .mol_tree_nx import DGLMolTree
from .mpn import mol2dgl_single as mol2dgl_enc
def _unpack_field(examples, field):
return [e[field] for e in examples]
def _set_node_id(mol_tree, vocab):
wid = []
for i, node in enumerate(mol_tree.nodes_dict):
mol_tree.nodes_dict[node]['idx'] = i
wid.append(vocab.get_index(mol_tree.nodes_dict[node]['smiles']))
mol_tree.nodes_dict[node]["idx"] = i
wid.append(vocab.get_index(mol_tree.nodes_dict[node]["smiles"]))
return wid
class JTNNDataset(Dataset):
def __init__(self, data, vocab, training=True):
self.dir = get_download_dir()
self.zip_file_path='{}/jtnn.zip'.format(self.dir)
self.zip_file_path = "{}/jtnn.zip".format(self.dir)
download(_get_dgl_url('dgllife/jtnn.zip'), path=self.zip_file_path)
extract_archive(self.zip_file_path, '{}/jtnn'.format(self.dir))
print('Loading data...')
data_file = '{}/jtnn/{}.txt'.format(self.dir, data)
download(_get_dgl_url("dgllife/jtnn.zip"), path=self.zip_file_path)
extract_archive(self.zip_file_path, "{}/jtnn".format(self.dir))
print("Loading data...")
data_file = "{}/jtnn/{}.txt".format(self.dir, data)
with open(data_file) as f:
self.data = [line.strip("\r\n ").split()[0] for line in f]
self.vocab_file = '{}/jtnn/{}.txt'.format(self.dir, vocab)
print('Loading finished.')
print('\tNum samples:', len(self.data))
print('\tVocab file:', self.vocab_file)
self.vocab_file = "{}/jtnn/{}.txt".format(self.dir, vocab)
print("Loading finished.")
print("\tNum samples:", len(self.data))
print("\tVocab file:", self.vocab_file)
self.training = training
self.vocab = Vocab([x.strip("\r\n ") for x in open(self.vocab_file)])
......@@ -55,11 +63,11 @@ class JTNNDataset(Dataset):
mol_graph, atom_x_enc, bond_x_enc = mol2dgl_enc(mol_tree.smiles)
result = {
'mol_tree': mol_tree,
'mol_graph': mol_graph,
'atom_x_enc': atom_x_enc,
'bond_x_enc': bond_x_enc,
'wid': wid,
"mol_tree": mol_tree,
"mol_graph": mol_graph,
"atom_x_enc": atom_x_enc,
"bond_x_enc": bond_x_enc,
"wid": wid,
}
if not self.training:
......@@ -69,17 +77,24 @@ class JTNNDataset(Dataset):
cands = []
for node_id, node in mol_tree.nodes_dict.items():
# fill in ground truth
if node['label'] not in node['cands']:
node['cands'].append(node['label'])
node['cand_mols'].append(node['label_mol'])
if node["label"] not in node["cands"]:
node["cands"].append(node["label"])
node["cand_mols"].append(node["label_mol"])
if node['is_leaf'] or len(node['cands']) == 1:
if node["is_leaf"] or len(node["cands"]) == 1:
continue
cands.extend([(cand, mol_tree, node_id)
for cand in node['cand_mols']])
cands.extend(
[(cand, mol_tree, node_id) for cand in node["cand_mols"]]
)
if len(cands) > 0:
cand_graphs, atom_x_dec, bond_x_dec, tree_mess_src_e, \
tree_mess_tgt_e, tree_mess_tgt_n = mol2dgl_dec(cands)
(
cand_graphs,
atom_x_dec,
bond_x_dec,
tree_mess_src_e,
tree_mess_tgt_e,
tree_mess_tgt_n,
) = mol2dgl_dec(cands)
else:
cand_graphs = []
atom_x_dec = torch.zeros(0, ATOM_FDIM_DEC)
......@@ -95,8 +110,9 @@ class JTNNDataset(Dataset):
cands.append(mol_tree.smiles3D)
stereo_graphs = [mol2dgl_enc(c) for c in cands]
stereo_cand_graphs, stereo_atom_x_enc, stereo_bond_x_enc = \
zip(*stereo_graphs)
stereo_cand_graphs, stereo_atom_x_enc, stereo_bond_x_enc = zip(
*stereo_graphs
)
stereo_atom_x_enc = torch.cat(stereo_atom_x_enc)
stereo_bond_x_enc = torch.cat(stereo_bond_x_enc)
stereo_cand_label = [(cands.index(mol_tree.smiles3D), len(cands))]
......@@ -106,21 +122,24 @@ class JTNNDataset(Dataset):
stereo_bond_x_enc = torch.zeros(0, bond_x_enc.shape[1])
stereo_cand_label = []
result.update({
'cand_graphs': cand_graphs,
'atom_x_dec': atom_x_dec,
'bond_x_dec': bond_x_dec,
'tree_mess_src_e': tree_mess_src_e,
'tree_mess_tgt_e': tree_mess_tgt_e,
'tree_mess_tgt_n': tree_mess_tgt_n,
'stereo_cand_graphs': stereo_cand_graphs,
'stereo_atom_x_enc': stereo_atom_x_enc,
'stereo_bond_x_enc': stereo_bond_x_enc,
'stereo_cand_label': stereo_cand_label,
})
result.update(
{
"cand_graphs": cand_graphs,
"atom_x_dec": atom_x_dec,
"bond_x_dec": bond_x_dec,
"tree_mess_src_e": tree_mess_src_e,
"tree_mess_tgt_e": tree_mess_tgt_e,
"tree_mess_tgt_n": tree_mess_tgt_n,
"stereo_cand_graphs": stereo_cand_graphs,
"stereo_atom_x_enc": stereo_atom_x_enc,
"stereo_bond_x_enc": stereo_bond_x_enc,
"stereo_cand_label": stereo_cand_label,
}
)
return result
class JTNNCollator(object):
def __init__(self, vocab, training):
self.vocab = vocab
......@@ -131,43 +150,45 @@ class JTNNCollator(object):
if flatten:
graphs = [g for f in graphs for g in f]
graph_batch = dgl.batch(graphs)
graph_batch.ndata['x'] = atom_x
graph_batch.edata.update({
'x': bond_x,
'src_x': atom_x.new(bond_x.shape[0], atom_x.shape[1]).zero_(),
})
graph_batch.ndata["x"] = atom_x
graph_batch.edata.update(
{
"x": bond_x,
"src_x": atom_x.new(bond_x.shape[0], atom_x.shape[1]).zero_(),
}
)
return graph_batch
def __call__(self, examples):
# get list of trees
mol_trees = _unpack_field(examples, 'mol_tree')
wid = _unpack_field(examples, 'wid')
mol_trees = _unpack_field(examples, "mol_tree")
wid = _unpack_field(examples, "wid")
for _wid, mol_tree in zip(wid, mol_trees):
mol_tree.graph.ndata['wid'] = torch.LongTensor(_wid)
mol_tree.graph.ndata["wid"] = torch.LongTensor(_wid)
# TODO: either support pickling or get around ctypes pointers using scipy
# batch molecule graphs
mol_graphs = _unpack_field(examples, 'mol_graph')
atom_x = torch.cat(_unpack_field(examples, 'atom_x_enc'))
bond_x = torch.cat(_unpack_field(examples, 'bond_x_enc'))
mol_graphs = _unpack_field(examples, "mol_graph")
atom_x = torch.cat(_unpack_field(examples, "atom_x_enc"))
bond_x = torch.cat(_unpack_field(examples, "bond_x_enc"))
mol_graph_batch = self._batch_and_set(mol_graphs, atom_x, bond_x, False)
result = {
'mol_trees': mol_trees,
'mol_graph_batch': mol_graph_batch,
"mol_trees": mol_trees,
"mol_graph_batch": mol_graph_batch,
}
if not self.training:
return result
# batch candidate graphs
cand_graphs = _unpack_field(examples, 'cand_graphs')
cand_graphs = _unpack_field(examples, "cand_graphs")
cand_batch_idx = []
atom_x = torch.cat(_unpack_field(examples, 'atom_x_dec'))
bond_x = torch.cat(_unpack_field(examples, 'bond_x_dec'))
tree_mess_src_e = _unpack_field(examples, 'tree_mess_src_e')
tree_mess_tgt_e = _unpack_field(examples, 'tree_mess_tgt_e')
tree_mess_tgt_n = _unpack_field(examples, 'tree_mess_tgt_n')
atom_x = torch.cat(_unpack_field(examples, "atom_x_dec"))
bond_x = torch.cat(_unpack_field(examples, "bond_x_dec"))
tree_mess_src_e = _unpack_field(examples, "tree_mess_src_e")
tree_mess_tgt_e = _unpack_field(examples, "tree_mess_tgt_e")
tree_mess_tgt_n = _unpack_field(examples, "tree_mess_tgt_n")
n_graph_nodes = 0
n_tree_nodes = 0
......@@ -182,12 +203,14 @@ class JTNNCollator(object):
tree_mess_src_e = torch.cat(tree_mess_src_e)
tree_mess_tgt_n = torch.cat(tree_mess_tgt_n)
cand_graph_batch = self._batch_and_set(cand_graphs, atom_x, bond_x, True)
cand_graph_batch = self._batch_and_set(
cand_graphs, atom_x, bond_x, True
)
# batch stereoisomers
stereo_cand_graphs = _unpack_field(examples, 'stereo_cand_graphs')
atom_x = torch.cat(_unpack_field(examples, 'stereo_atom_x_enc'))
bond_x = torch.cat(_unpack_field(examples, 'stereo_bond_x_enc'))
stereo_cand_graphs = _unpack_field(examples, "stereo_cand_graphs")
atom_x = torch.cat(_unpack_field(examples, "stereo_atom_x_enc"))
bond_x = torch.cat(_unpack_field(examples, "stereo_bond_x_enc"))
stereo_cand_batch_idx = []
for i in range(len(stereo_cand_graphs)):
stereo_cand_batch_idx.extend([i] * len(stereo_cand_graphs[i]))
......@@ -195,28 +218,31 @@ class JTNNCollator(object):
if len(stereo_cand_batch_idx) > 0:
stereo_cand_labels = [
(label, length)
for ex in _unpack_field(examples, 'stereo_cand_label')
for ex in _unpack_field(examples, "stereo_cand_label")
for label, length in ex
]
stereo_cand_labels, stereo_cand_lengths = zip(*stereo_cand_labels)
stereo_cand_graph_batch = self._batch_and_set(
stereo_cand_graphs, atom_x, bond_x, True)
stereo_cand_graphs, atom_x, bond_x, True
)
else:
stereo_cand_labels = []
stereo_cand_lengths = []
stereo_cand_graph_batch = None
stereo_cand_batch_idx = []
result.update({
'cand_graph_batch': cand_graph_batch,
'cand_batch_idx': cand_batch_idx,
'tree_mess_tgt_e': tree_mess_tgt_e,
'tree_mess_src_e': tree_mess_src_e,
'tree_mess_tgt_n': tree_mess_tgt_n,
'stereo_cand_graph_batch': stereo_cand_graph_batch,
'stereo_cand_batch_idx': stereo_cand_batch_idx,
'stereo_cand_labels': stereo_cand_labels,
'stereo_cand_lengths': stereo_cand_lengths,
})
result.update(
{
"cand_graph_batch": cand_graph_batch,
"cand_batch_idx": cand_batch_idx,
"tree_mess_tgt_e": tree_mess_tgt_e,
"tree_mess_src_e": tree_mess_src_e,
"tree_mess_tgt_n": tree_mess_tgt_n,
"stereo_cand_graph_batch": stereo_cand_graph_batch,
"stereo_cand_batch_idx": stereo_cand_batch_idx,
"stereo_cand_labels": stereo_cand_labels,
"stereo_cand_lengths": stereo_cand_lengths,
}
)
return result
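# How the dataset and collator above are typically wired together (a sketch; the
# "train"/"vocab" file names follow the jtnn.zip layout assumed by JTNNDataset and the
# batch size is arbitrary):
#
#   from torch.utils.data import DataLoader
#
#   dataset = JTNNDataset(data="train", vocab="vocab", training=True)
#   collate = JTNNCollator(dataset.vocab, training=True)
#   loader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate)
#   batch = next(iter(loader))   # dict with "mol_trees", "mol_graph_batch",
#                                # "cand_graph_batch", "stereo_cand_graph_batch", ...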
import numpy as np
import torch
import torch.nn as nn
from .nnutils import GRUUpdate, cuda, tocpu
from dgl import batch, bfs_edges_generator, line_graph
import dgl.function as DGLF
import numpy as np
from dgl import batch, bfs_edges_generator, line_graph
from .nnutils import GRUUpdate, cuda, tocpu
MAX_NB = 8
def level_order(forest, roots):
forest = tocpu(forest)
edges = bfs_edges_generator(forest, roots)
......@@ -18,6 +21,7 @@ def level_order(forest, roots):
yield from reversed(edges_back)
yield from edges
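# level_order walks the batched tree forest with BFS from the designated root nodes and
# yields edge-ID frontiers covering both traversal directions; the encoder below pulls
# one frontier at a time so tree messages are propagated level by level.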
class EncoderGatherUpdate(nn.Module):
def __init__(self, hidden_size):
nn.Module.__init__(self)
......@@ -26,10 +30,10 @@ class EncoderGatherUpdate(nn.Module):
self.W = nn.Linear(2 * hidden_size, hidden_size)
def forward(self, nodes):
x = nodes.data['x']
m = nodes.data['m']
x = nodes.data["x"]
m = nodes.data["m"]
return {
'h': torch.relu(self.W(torch.cat([x, m], 1))),
"h": torch.relu(self.W(torch.cat([x, m], 1))),
}
......@@ -52,42 +56,53 @@ class DGLJTNNEncoder(nn.Module):
mol_tree_batch = batch(mol_trees)
# Build line graph to prepare for belief propagation
mol_tree_batch_lg = line_graph(mol_tree_batch, backtracking=False, shared=True)
mol_tree_batch_lg = line_graph(
mol_tree_batch, backtracking=False, shared=True
)
return self.run(mol_tree_batch, mol_tree_batch_lg)
def run(self, mol_tree_batch, mol_tree_batch_lg):
# Since tree roots are designated as node 0, in the batched graph we can
# simply find the corresponding node IDs by looking at node_offset
node_offset = np.cumsum(np.insert(mol_tree_batch.batch_num_nodes().cpu().numpy(), 0, 0))
node_offset = np.cumsum(
np.insert(mol_tree_batch.batch_num_nodes().cpu().numpy(), 0, 0)
)
root_ids = node_offset[:-1]
n_nodes = mol_tree_batch.number_of_nodes()
n_edges = mol_tree_batch.number_of_edges()
# Assign structure embeddings to tree nodes
mol_tree_batch.ndata.update({
'x': self.embedding(mol_tree_batch.ndata['wid']),
'm': cuda(torch.zeros(n_nodes, self.hidden_size)),
'h': cuda(torch.zeros(n_nodes, self.hidden_size)),
})
mol_tree_batch.ndata.update(
{
"x": self.embedding(mol_tree_batch.ndata["wid"]),
"m": cuda(torch.zeros(n_nodes, self.hidden_size)),
"h": cuda(torch.zeros(n_nodes, self.hidden_size)),
}
)
# Initialize the intermediate variables according to Eq (4)-(8).
# Also initialize the src_x and dst_x fields.
# TODO: context?
mol_tree_batch.edata.update({
's': cuda(torch.zeros(n_edges, self.hidden_size)),
'm': cuda(torch.zeros(n_edges, self.hidden_size)),
'r': cuda(torch.zeros(n_edges, self.hidden_size)),
'z': cuda(torch.zeros(n_edges, self.hidden_size)),
'src_x': cuda(torch.zeros(n_edges, self.hidden_size)),
'dst_x': cuda(torch.zeros(n_edges, self.hidden_size)),
'rm': cuda(torch.zeros(n_edges, self.hidden_size)),
'accum_rm': cuda(torch.zeros(n_edges, self.hidden_size)),
})
mol_tree_batch.edata.update(
{
"s": cuda(torch.zeros(n_edges, self.hidden_size)),
"m": cuda(torch.zeros(n_edges, self.hidden_size)),
"r": cuda(torch.zeros(n_edges, self.hidden_size)),
"z": cuda(torch.zeros(n_edges, self.hidden_size)),
"src_x": cuda(torch.zeros(n_edges, self.hidden_size)),
"dst_x": cuda(torch.zeros(n_edges, self.hidden_size)),
"rm": cuda(torch.zeros(n_edges, self.hidden_size)),
"accum_rm": cuda(torch.zeros(n_edges, self.hidden_size)),
}
)
# Send the source/destination node features to edges
mol_tree_batch.apply_edges(
func=lambda edges: {'src_x': edges.src['x'], 'dst_x': edges.dst['x']},
func=lambda edges: {
"src_x": edges.src["x"],
"dst_x": edges.dst["x"],
},
)
# Message passing
......@@ -98,15 +113,19 @@ class DGLJTNNEncoder(nn.Module):
mol_tree_batch_lg.ndata.update(mol_tree_batch.edata)
for eid in level_order(mol_tree_batch, root_ids):
eid = eid.to(mol_tree_batch_lg.device)
mol_tree_batch_lg.pull(eid, DGLF.copy_u('m', 'm'), DGLF.sum('m', 's'))
mol_tree_batch_lg.pull(eid, DGLF.copy_u('rm', 'rm'), DGLF.sum('rm', 'accum_rm'))
mol_tree_batch_lg.pull(
eid, DGLF.copy_u("m", "m"), DGLF.sum("m", "s")
)
mol_tree_batch_lg.pull(
eid, DGLF.copy_u("rm", "rm"), DGLF.sum("rm", "accum_rm")
)
mol_tree_batch_lg.apply_nodes(self.enc_tree_update, v=eid)
# Readout
mol_tree_batch.edata.update(mol_tree_batch_lg.ndata)
mol_tree_batch.update_all(DGLF.copy_e('m', 'm'), DGLF.sum('m', 'm'))
mol_tree_batch.update_all(DGLF.copy_e("m", "m"), DGLF.sum("m", "m"))
mol_tree_batch.apply_nodes(self.enc_tree_gather_update)
root_vecs = mol_tree_batch.nodes[root_ids].data['h']
root_vecs = mol_tree_batch.nodes[root_ids].data["h"]
return mol_tree_batch, root_vecs
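# Sketch of how the encoder is driven (constructor arguments mirror DGLJTNNVAE below;
# the hidden size is illustrative, and the tree graphs must already carry the 'wid'
# vocabulary indices set by the collator):
#
#   embedding = nn.Embedding(len(vocab.vocab), 450)
#   encoder = DGLJTNNEncoder(vocab, 450, embedding)
#   mol_tree_batch, root_vecs = encoder([t.graph for t in mol_trees])
#   # root_vecs holds one hidden vector per tree, read off the designated root node 0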
import copy
import rdkit.Chem as Chem
import torch
import torch.nn as nn
import torch.nn.functional as F
from .nnutils import cuda
from .chemutils import set_atommap, copy_edit_mol, enum_assemble_nx, \
attach_mols_nx, decode_stereo
from .jtnn_enc import DGLJTNNEncoder
from dgl import batch, unbatch
from .chemutils import (
attach_mols_nx,
copy_edit_mol,
decode_stereo,
enum_assemble_nx,
set_atommap,
)
from .jtmpn import DGLJTMPN
from .jtmpn import mol2dgl_single as mol2dgl_dec
from .jtnn_dec import DGLJTNNDecoder
from .jtnn_enc import DGLJTNNEncoder
from .mpn import DGLMPN
from .mpn import mol2dgl_single as mol2dgl_enc
from .jtmpn import DGLJTMPN
from .jtmpn import mol2dgl_single as mol2dgl_dec
import rdkit.Chem as Chem
import copy
from .nnutils import cuda
from dgl import batch, unbatch
class DGLJTNNVAE(nn.Module):
def __init__(self, vocab, hidden_size, latent_size, depth):
super(DGLJTNNVAE, self).__init__()
self.vocab = vocab
......@@ -29,7 +35,8 @@ class DGLJTNNVAE(nn.Module):
self.mpn = DGLMPN(hidden_size, depth)
self.jtnn = DGLJTNNEncoder(vocab, hidden_size, self.embedding)
self.decoder = DGLJTNNDecoder(
vocab, hidden_size, latent_size // 2, self.embedding)
vocab, hidden_size, latent_size // 2, self.embedding
)
self.jtmpn = DGLJTMPN(hidden_size, depth)
self.T_mean = nn.Linear(hidden_size, latent_size // 2)
......@@ -44,24 +51,32 @@ class DGLJTNNVAE(nn.Module):
@staticmethod
def move_to_cuda(mol_batch):
for i in range(len(mol_batch['mol_trees'])):
mol_batch['mol_trees'][i].graph = cuda(mol_batch['mol_trees'][i].graph)
for i in range(len(mol_batch["mol_trees"])):
mol_batch["mol_trees"][i].graph = cuda(
mol_batch["mol_trees"][i].graph
)
mol_batch['mol_graph_batch'] = cuda(mol_batch['mol_graph_batch'])
if 'cand_graph_batch' in mol_batch:
mol_batch['cand_graph_batch'] = cuda(mol_batch['cand_graph_batch'])
if mol_batch.get('stereo_cand_graph_batch') is not None:
mol_batch['stereo_cand_graph_batch'] = cuda(mol_batch['stereo_cand_graph_batch'])
mol_batch["mol_graph_batch"] = cuda(mol_batch["mol_graph_batch"])
if "cand_graph_batch" in mol_batch:
mol_batch["cand_graph_batch"] = cuda(mol_batch["cand_graph_batch"])
if mol_batch.get("stereo_cand_graph_batch") is not None:
mol_batch["stereo_cand_graph_batch"] = cuda(
mol_batch["stereo_cand_graph_batch"]
)
def encode(self, mol_batch):
mol_graphs = mol_batch['mol_graph_batch']
mol_graphs = mol_batch["mol_graph_batch"]
mol_vec = self.mpn(mol_graphs)
mol_tree_batch, tree_vec = self.jtnn([t.graph for t in mol_batch['mol_trees']])
mol_tree_batch, tree_vec = self.jtnn(
[t.graph for t in mol_batch["mol_trees"]]
)
self.n_nodes_total += mol_graphs.number_of_nodes()
self.n_edges_total += mol_graphs.number_of_edges()
self.n_tree_nodes_total += sum(t.graph.number_of_nodes() for t in mol_batch['mol_trees'])
self.n_tree_nodes_total += sum(
t.graph.number_of_nodes() for t in mol_batch["mol_trees"]
)
self.n_passes += 1
return mol_tree_batch, tree_vec, mol_vec
......@@ -85,31 +100,45 @@ class DGLJTNNVAE(nn.Module):
def forward(self, mol_batch, beta=0, e1=None, e2=None):
self.move_to_cuda(mol_batch)
mol_trees = mol_batch['mol_trees']
mol_trees = mol_batch["mol_trees"]
batch_size = len(mol_trees)
mol_tree_batch, tree_vec, mol_vec = self.encode(mol_batch)
tree_vec, mol_vec, z_mean, z_log_var = self.sample(tree_vec, mol_vec, e1, e2)
kl_loss = -0.5 * torch.sum(1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)) / batch_size
tree_vec, mol_vec, z_mean, z_log_var = self.sample(
tree_vec, mol_vec, e1, e2
)
kl_loss = (
-0.5
* torch.sum(
1.0 + z_log_var - z_mean * z_mean - torch.exp(z_log_var)
)
/ batch_size
)
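# The kl_loss above is the closed-form KL divergence between the approximate posterior
# N(z_mean, exp(z_log_var)) and the standard normal prior, averaged over the batch:
# -0.5 * sum(1 + log(sigma^2) - mu^2 - sigma^2) / batch_size.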
word_loss, topo_loss, word_acc, topo_acc = self.decoder([t.graph for t in mol_trees], tree_vec)
word_loss, topo_loss, word_acc, topo_acc = self.decoder(
[t.graph for t in mol_trees], tree_vec
)
assm_loss, assm_acc = self.assm(mol_batch, mol_tree_batch, mol_vec)
stereo_loss, stereo_acc = self.stereo(mol_batch, mol_vec)
loss = word_loss + topo_loss + assm_loss + 2 * stereo_loss + beta * kl_loss
loss = (
word_loss + topo_loss + assm_loss + 2 * stereo_loss + beta * kl_loss
)
return loss, kl_loss, word_acc, topo_acc, assm_acc, stereo_acc
def assm(self, mol_batch, mol_tree_batch, mol_vec):
cands = [mol_batch['cand_graph_batch'],
cuda(mol_batch['tree_mess_src_e']),
cuda(mol_batch['tree_mess_tgt_e']),
cuda(mol_batch['tree_mess_tgt_n'])]
cands = [
mol_batch["cand_graph_batch"],
cuda(mol_batch["tree_mess_src_e"]),
cuda(mol_batch["tree_mess_tgt_e"]),
cuda(mol_batch["tree_mess_tgt_n"]),
]
cand_vec = self.jtmpn(cands, mol_tree_batch)
cand_vec = self.G_mean(cand_vec)
batch_idx = cuda(torch.LongTensor(mol_batch['cand_batch_idx']))
batch_idx = cuda(torch.LongTensor(mol_batch["cand_batch_idx"]))
mol_vec = mol_vec[batch_idx]
mol_vec = mol_vec.view(-1, 1, self.latent_size // 2)
......@@ -118,16 +147,19 @@ class DGLJTNNVAE(nn.Module):
cnt, tot, acc = 0, 0, 0
all_loss = []
for i, mol_tree in enumerate(mol_batch['mol_trees']):
comp_nodes = [node_id for node_id, node in mol_tree.nodes_dict.items()
if len(node['cands']) > 1 and not node['is_leaf']]
for i, mol_tree in enumerate(mol_batch["mol_trees"]):
comp_nodes = [
node_id
for node_id, node in mol_tree.nodes_dict.items()
if len(node["cands"]) > 1 and not node["is_leaf"]
]
cnt += len(comp_nodes)
# segmented accuracy and cross entropy
for node_id in comp_nodes:
node = mol_tree.nodes_dict[node_id]
label = node['cands'].index(node['label'])
ncand = len(node['cands'])
cur_score = scores[tot:tot+ncand]
label = node["cands"].index(node["label"])
ncand = len(node["cands"])
cur_score = scores[tot : tot + ncand]
tot += ncand
if cur_score[label].item() >= cur_score.max().item():
......@@ -135,20 +167,23 @@ class DGLJTNNVAE(nn.Module):
label = cuda(torch.LongTensor([label]))
all_loss.append(
F.cross_entropy(cur_score.view(1, -1), label, size_average=False))
F.cross_entropy(
cur_score.view(1, -1), label, size_average=False
)
)
all_loss = sum(all_loss) / len(mol_batch['mol_trees'])
all_loss = sum(all_loss) / len(mol_batch["mol_trees"])
return all_loss, acc / cnt
def stereo(self, mol_batch, mol_vec):
stereo_cands = mol_batch['stereo_cand_graph_batch']
batch_idx = mol_batch['stereo_cand_batch_idx']
labels = mol_batch['stereo_cand_labels']
lengths = mol_batch['stereo_cand_lengths']
stereo_cands = mol_batch["stereo_cand_graph_batch"]
batch_idx = mol_batch["stereo_cand_batch_idx"]
labels = mol_batch["stereo_cand_labels"]
lengths = mol_batch["stereo_cand_lengths"]
if len(labels) == 0:
# Only one stereoisomer exists; do nothing
return cuda(torch.tensor(0.)), 1.
return cuda(torch.tensor(0.0)), 1.0
batch_idx = cuda(torch.LongTensor(batch_idx))
stereo_cands = self.mpn(stereo_cands)
......@@ -159,12 +194,15 @@ class DGLJTNNVAE(nn.Module):
st, acc = 0, 0
all_loss = []
for label, le in zip(labels, lengths):
cur_scores = scores[st:st+le]
cur_scores = scores[st : st + le]
if cur_scores.data[label].item() >= cur_scores.max().item():
acc += 1
label = cuda(torch.LongTensor([label]))
all_loss.append(
F.cross_entropy(cur_scores.view(1, -1), label, size_average=False))
F.cross_entropy(
cur_scores.view(1, -1), label, size_average=False
)
)
st += le
all_loss = sum(all_loss) / len(labels)
......@@ -175,24 +213,32 @@ class DGLJTNNVAE(nn.Module):
effective_nodes_list = effective_nodes.tolist()
nodes_dict = [nodes_dict[v] for v in effective_nodes_list]
for i, (node_id, node) in enumerate(zip(effective_nodes_list, nodes_dict)):
node['idx'] = i
node['nid'] = i + 1
node['is_leaf'] = True
for i, (node_id, node) in enumerate(
zip(effective_nodes_list, nodes_dict)
):
node["idx"] = i
node["nid"] = i + 1
node["is_leaf"] = True
if mol_tree.graph.in_degrees(node_id) > 1:
node['is_leaf'] = False
set_atommap(node['mol'], node['nid'])
node["is_leaf"] = False
set_atommap(node["mol"], node["nid"])
mol_tree_sg = mol_tree.graph.subgraph(effective_nodes.to(tree_vec.device))
mol_tree_sg = mol_tree.graph.subgraph(
effective_nodes.to(tree_vec.device)
)
mol_tree_msg, _ = self.jtnn([mol_tree_sg])
mol_tree_msg = unbatch(mol_tree_msg)[0]
mol_tree_msg.nodes_dict = nodes_dict
cur_mol = copy_edit_mol(nodes_dict[0]['mol'])
cur_mol = copy_edit_mol(nodes_dict[0]["mol"])
global_amap = [{}] + [{} for node in nodes_dict]
global_amap[1] = {atom.GetIdx(): atom.GetIdx() for atom in cur_mol.GetAtoms()}
global_amap[1] = {
atom.GetIdx(): atom.GetIdx() for atom in cur_mol.GetAtoms()
}
cur_mol = self.dfs_assemble(mol_tree_msg, mol_vec, cur_mol, global_amap, [], 0, None)
cur_mol = self.dfs_assemble(
mol_tree_msg, mol_vec, cur_mol, global_amap, [], 0, None
)
if cur_mol is None:
return None
......@@ -207,56 +253,86 @@ class DGLJTNNVAE(nn.Module):
if len(stereo_cands) == 1:
return stereo_cands[0]
stereo_graphs = [mol2dgl_enc(c) for c in stereo_cands]
stereo_cand_graphs, atom_x, bond_x = \
zip(*stereo_graphs)
stereo_cand_graphs, atom_x, bond_x = zip(*stereo_graphs)
stereo_cand_graphs = cuda(batch(stereo_cand_graphs))
atom_x = cuda(torch.cat(atom_x))
bond_x = cuda(torch.cat(bond_x))
stereo_cand_graphs.ndata['x'] = atom_x
stereo_cand_graphs.edata['x'] = bond_x
stereo_cand_graphs.edata['src_x'] = atom_x.new(
bond_x.shape[0], atom_x.shape[1]).zero_()
stereo_cand_graphs.ndata["x"] = atom_x
stereo_cand_graphs.edata["x"] = bond_x
stereo_cand_graphs.edata["src_x"] = atom_x.new(
bond_x.shape[0], atom_x.shape[1]
).zero_()
stereo_vecs = self.mpn(stereo_cand_graphs)
stereo_vecs = self.G_mean(stereo_vecs)
scores = F.cosine_similarity(stereo_vecs, mol_vec)
_, max_id = scores.max(0)
return stereo_cands[max_id.item()]
def dfs_assemble(self, mol_tree_msg, mol_vec, cur_mol,
global_amap, fa_amap, cur_node_id, fa_node_id):
def dfs_assemble(
self,
mol_tree_msg,
mol_vec,
cur_mol,
global_amap,
fa_amap,
cur_node_id,
fa_node_id,
):
nodes_dict = mol_tree_msg.nodes_dict
fa_node = nodes_dict[fa_node_id] if fa_node_id is not None else None
cur_node = nodes_dict[cur_node_id]
fa_nid = fa_node['nid'] if fa_node is not None else -1
fa_nid = fa_node["nid"] if fa_node is not None else -1
prev_nodes = [fa_node] if fa_node is not None else []
children_node_id = [v for v in mol_tree_msg.successors(cur_node_id).tolist()
if nodes_dict[v]['nid'] != fa_nid]
children_node_id = [
v
for v in mol_tree_msg.successors(cur_node_id).tolist()
if nodes_dict[v]["nid"] != fa_nid
]
children = [nodes_dict[v] for v in children_node_id]
neighbors = [nei for nei in children if nei['mol'].GetNumAtoms() > 1]
neighbors = sorted(neighbors, key=lambda x: x['mol'].GetNumAtoms(), reverse=True)
singletons = [nei for nei in children if nei['mol'].GetNumAtoms() == 1]
neighbors = [nei for nei in children if nei["mol"].GetNumAtoms() > 1]
neighbors = sorted(
neighbors, key=lambda x: x["mol"].GetNumAtoms(), reverse=True
)
singletons = [nei for nei in children if nei["mol"].GetNumAtoms() == 1]
neighbors = singletons + neighbors
cur_amap = [(fa_nid, a2, a1) for nid, a1, a2 in fa_amap if nid == cur_node['nid']]
cur_amap = [
(fa_nid, a2, a1)
for nid, a1, a2 in fa_amap
if nid == cur_node["nid"]
]
cands = enum_assemble_nx(cur_node, neighbors, prev_nodes, cur_amap)
if len(cands) == 0:
return None
cand_smiles, cand_mols, cand_amap = list(zip(*cands))
cands = [(candmol, mol_tree_msg, cur_node_id) for candmol in cand_mols]
cand_graphs, atom_x, bond_x, tree_mess_src_edges, \
tree_mess_tgt_edges, tree_mess_tgt_nodes = mol2dgl_dec(cands)
(
cand_graphs,
atom_x,
bond_x,
tree_mess_src_edges,
tree_mess_tgt_edges,
tree_mess_tgt_nodes,
) = mol2dgl_dec(cands)
cand_graphs = batch([g.to(mol_vec.device) for g in cand_graphs])
atom_x = cuda(atom_x)
bond_x = cuda(bond_x)
cand_graphs.ndata['x'] = atom_x
cand_graphs.edata['x'] = bond_x
cand_graphs.edata['src_x'] = atom_x.new(bond_x.shape[0], atom_x.shape[1]).zero_()
cand_graphs.ndata["x"] = atom_x
cand_graphs.edata["x"] = bond_x
cand_graphs.edata["src_x"] = atom_x.new(
bond_x.shape[0], atom_x.shape[1]
).zero_()
cand_vecs = self.jtmpn(
(cand_graphs, tree_mess_src_edges, tree_mess_tgt_edges, tree_mess_tgt_nodes),
(
cand_graphs,
tree_mess_src_edges,
tree_mess_tgt_edges,
tree_mess_tgt_nodes,
),
mol_tree_msg,
)
cand_vecs = self.G_mean(cand_vecs)
......@@ -274,7 +350,9 @@ class DGLJTNNVAE(nn.Module):
for nei_id, ctr_atom, nei_atom in pred_amap:
if nei_id == fa_nid:
continue
new_global_amap[nei_id][nei_atom] = new_global_amap[cur_node['nid']][ctr_atom]
new_global_amap[nei_id][nei_atom] = new_global_amap[
cur_node["nid"]
][ctr_atom]
cur_mol = attach_mols_nx(cur_mol, children, [], new_global_amap)
new_mol = cur_mol.GetMol()
......@@ -285,11 +363,17 @@ class DGLJTNNVAE(nn.Module):
result = True
for nei_node_id, nei_node in zip(children_node_id, children):
if nei_node['is_leaf']:
if nei_node["is_leaf"]:
continue
cur_mol = self.dfs_assemble(
mol_tree_msg, mol_vec, cur_mol, new_global_amap, pred_amap,
nei_node_id, cur_node_id)
mol_tree_msg,
mol_vec,
cur_mol,
new_global_amap,
pred_amap,
nei_node_id,
cur_node_id,
)
if cur_mol is None:
result = False
break
......
'''
"""
line_profiler integration
'''
"""
import os
if os.getenv('PROFILE', 0):
import line_profiler
if os.getenv("PROFILE", 0):
import atexit
import line_profiler
profile = line_profiler.LineProfiler()
profile_output = os.getenv('PROFILE_OUTPUT', None)
profile_output = os.getenv("PROFILE_OUTPUT", None)
if profile_output:
from functools import partial
atexit.register(partial(profile.dump_stats, profile_output))
else:
atexit.register(profile.print_stats)
else:
def profile(f):
return f
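# Usage sketch for the profiling shim above (environment-variable driven; the script
# name "pretrain.py" is just a placeholder):
#   PROFILE=1 PROFILE_OUTPUT=jtnn.lprof python pretrain.py ...
# Functions decorated with @profile are then line-profiled and the stats are dumped or
# printed at exit; with PROFILE unset the decorator is a no-op passthrough.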
import rdkit.Chem as Chem
import copy
import rdkit.Chem as Chem
def get_slots(smiles):
mol = Chem.MolFromSmiles(smiles)
return [(atom.GetSymbol(), atom.GetFormalCharge(), atom.GetTotalNumHs()) for atom in mol.GetAtoms()]
return [
(atom.GetSymbol(), atom.GetFormalCharge(), atom.GetTotalNumHs())
for atom in mol.GetAtoms()
]
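# Illustrative output (assumes RDKit): each atom contributes (symbol, formal charge,
# total H count), e.g. get_slots("CO") -> [("C", 0, 3), ("O", 0, 1)].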
class Vocab(object):
def __init__(self, smiles_list):
self.vocab = smiles_list
self.vmap = {x:i for i,x in enumerate(self.vocab)}
self.vmap = {x: i for i, x in enumerate(self.vocab)}
self.slots = [get_slots(smiles) for smiles in self.vocab]
def get_index(self, smiles):
......