# Training / evaluation script for the DGL junction-tree VAE (JT-VAE).
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
from torch.utils.data import DataLoader

import math, random, sys
from optparse import OptionParser
from collections import deque

import rdkit
import tqdm

from jtnn import *

torch.multiprocessing.set_sharing_strategy('file_system')

def worker_init_fn(id_):
    # Silence RDKit logging in every dataloader worker (and in the main process below).
    lg = rdkit.RDLogger.logger()
    lg.setLevel(rdkit.RDLogger.CRITICAL)

worker_init_fn(None)

parser = OptionParser()
parser.add_option("-t", "--train", dest="train", default='train', help='Training file name')
parser.add_option("-v", "--vocab", dest="vocab", default='vocab', help='Vocab file name')
parser.add_option("-s", "--save_dir", dest="save_path")
parser.add_option("-m", "--model", dest="model_path", default=None)
parser.add_option("-b", "--batch", dest="batch_size", default=40)
parser.add_option("-w", "--hidden", dest="hidden_size", default=200)
parser.add_option("-l", "--latent", dest="latent_size", default=56)
parser.add_option("-d", "--depth", dest="depth", default=3)
parser.add_option("-z", "--beta", dest="beta", default=1.0)
parser.add_option("-q", "--lr", dest="lr", default=1e-3)
parser.add_option("-T", "--test", dest="test", action="store_true")
opts, args = parser.parse_args()

dataset = JTNNDataset(data=opts.train, vocab=opts.vocab, training=True)
vocab = dataset.vocab

batch_size = int(opts.batch_size)
hidden_size = int(opts.hidden_size)
latent_size = int(opts.latent_size)
depth = int(opts.depth)
beta = float(opts.beta)
lr = float(opts.lr)

model = DGLJTNNVAE(vocab, hidden_size, latent_size, depth)

if opts.model_path is not None:
    model.load_state_dict(torch.load(opts.model_path))
else:
    # Initialize biases to zero and weight matrices with Xavier normal.
    for param in model.parameters():
        if param.dim() == 1:
            nn.init.constant_(param, 0)
        else:
            nn.init.xavier_normal_(param)

model = cuda(model)
print("Model #Params: %dK" % (sum([x.nelement() for x in model.parameters()]) / 1000,))

optimizer = optim.Adam(model.parameters(), lr=lr)
scheduler = lr_scheduler.ExponentialLR(optimizer, 0.9)
scheduler.step()

MAX_EPOCH = 100
PRINT_ITER = 20

def train():
    dataset.training = True
    dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=4,
            collate_fn=JTNNCollator(vocab, True),
            drop_last=True,
            worker_init_fn=worker_init_fn)

    for epoch in range(MAX_EPOCH):
        word_acc, topo_acc, assm_acc, steo_acc = 0, 0, 0, 0

        for it, batch in enumerate(tqdm.tqdm(dataloader)):
            model.zero_grad()
            try:
                loss, kl_div, wacc, tacc, sacc, dacc = model(batch, beta)
            except Exception:
                # Print the offending SMILES before re-raising to ease debugging.
                print([t.smiles for t in batch['mol_trees']])
                raise
            loss.backward()
            optimizer.step()

            word_acc += wacc
            topo_acc += tacc
            assm_acc += sacc
            steo_acc += dacc

            if (it + 1) % PRINT_ITER == 0:
                word_acc = word_acc / PRINT_ITER * 100
                topo_acc = topo_acc / PRINT_ITER * 100
                assm_acc = assm_acc / PRINT_ITER * 100
                steo_acc = steo_acc / PRINT_ITER * 100

                print("KL: %.1f, Word: %.2f, Topo: %.2f, Assm: %.2f, Steo: %.2f, Loss: %.6f" % (
                    kl_div, word_acc, topo_acc, assm_acc, steo_acc, loss.item()))
                word_acc, topo_acc, assm_acc, steo_acc = 0, 0, 0, 0
                sys.stdout.flush()

            if (it + 1) % 1500 == 0:  # Fast annealing
                scheduler.step()
                print("learning rate: %.6f" % scheduler.get_lr()[0])
                torch.save(model.state_dict(),
                           opts.save_path + "/model.iter-%d-%d" % (epoch, it + 1))

        scheduler.step()
        print("learning rate: %.6f" % scheduler.get_lr()[0])
        torch.save(model.state_dict(), opts.save_path + "/model.iter-" + str(epoch))

def test():
    dataset.training = False
    dataloader = DataLoader(
            dataset,
            batch_size=1,
            shuffle=False,
            num_workers=0,
            collate_fn=JTNNCollator(vocab, False),
            drop_last=True,
            worker_init_fn=worker_init_fn)

    # Just an example of molecule decoding; in reality you may want to sample
    # tree and molecule vectors.
    for it, batch in enumerate(dataloader):
        gt_smiles = batch['mol_trees'][0].smiles
        print(gt_smiles)
        model.move_to_cuda(batch)
        _, tree_vec, mol_vec = model.encode(batch)
        tree_vec, mol_vec, _, _ = model.sample(tree_vec, mol_vec)
        smiles = model.decode(tree_vec, mol_vec)
        print(smiles)

if __name__ == '__main__':
    if opts.test:
        test()
    else:
        train()

    print('# passes:', model.n_passes)
    print('Total # nodes processed:', model.n_nodes_total)
    print('Total # edges processed:', model.n_edges_total)
    print('Total # tree nodes processed:', model.n_tree_nodes_total)

    print('Graph decoder: # passes:', model.jtmpn.n_passes)
    print('Graph decoder: Total # candidates processed:', model.jtmpn.n_samples_total)
    print('Graph decoder: Total # nodes processed:', model.jtmpn.n_nodes_total)
    print('Graph decoder: Total # edges processed:', model.jtmpn.n_edges_total)

    print('Graph encoder: # passes:', model.mpn.n_passes)
    print('Graph encoder: Total # candidates processed:', model.mpn.n_samples_total)
    print('Graph encoder: Total # nodes processed:', model.mpn.n_nodes_total)
    print('Graph encoder: Total # edges processed:', model.mpn.n_edges_total)
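
# A minimal sketch of prior sampling, following the note in test() above: instead of
# encoding a ground-truth molecule, draw tree and molecule vectors from the standard-normal
# prior and decode them. It assumes DGLJTNNVAE splits its latent space evenly between the
# tree and graph halves (latent_size // 2 each), which matches the JT-VAE design but is an
# assumption not verified against this codebase; only model.decode() is taken from test() above.
def sample_prior(num_samples=10):
    device = next(model.parameters()).device
    model.eval()
    with torch.no_grad():
        for _ in range(num_samples):
            # Assumption: half of the latent dimensions encode the junction tree,
            # the other half the molecular graph.
            tree_vec = torch.randn(1, latent_size // 2, device=device)
            mol_vec = torch.randn(1, latent_size // 2, device=device)
            print(model.decode(tree_vec, mol_vec))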