"vscode:/vscode.git/clone" did not exist on "591b76b53f075f1d9b01b42e2d3eb2751fbeb169"
vaetrain_dgl.py 5.82 KB
Newer Older
1
2
3
4
5
6
7
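"""Training script for the DGL implementation of the junction tree VAE
(JT-VAE). Trains the model on a SMILES dataset with periodic checkpointing;
pass -T/--test to decode molecules with a trained model instead.
"""
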
import math
import random
import sys
from collections import deque
from optparse import OptionParser

import rdkit
import torch
import torch.nn as nn
import torch.optim as optim
import torch.optim.lr_scheduler as lr_scheduler
import tqdm
from jtnn import *
from torch.utils.data import DataLoader

torch.multiprocessing.set_sharing_strategy("file_system")


# Silence RDKit's verbose logging; used both as the DataLoader
# worker_init_fn and called once directly for the main process.
def worker_init_fn(id_):
    lg = rdkit.RDLogger.logger()
    lg.setLevel(rdkit.RDLogger.CRITICAL)


worker_init_fn(None)

parser = OptionParser()
parser.add_option(
    "-t", "--train", dest="train", default="train", help="Training file name"
)
parser.add_option(
    "-v", "--vocab", dest="vocab", default="vocab", help="Vocab file name"
)
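# --save_dir receives checkpoints; --model loads an existing state_dict
# (to resume training, or for --test).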
parser.add_option("-s", "--save_dir", dest="save_path")
parser.add_option("-m", "--model", dest="model_path", default=None)
parser.add_option("-b", "--batch", dest="batch_size", default=40)
parser.add_option("-w", "--hidden", dest="hidden_size", default=200)
parser.add_option("-l", "--latent", dest="latent_size", default=56)
parser.add_option("-d", "--depth", dest="depth", default=3)
parser.add_option("-z", "--beta", dest="beta", default=1.0)
parser.add_option("-q", "--lr", dest="lr", default=1e-3)
parser.add_option("-T", "--test", dest="test", action="store_true")
opts, args = parser.parse_args()
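
# Example invocation (file names are illustrative only):
#   python vaetrain_dgl.py -t train.txt -v vocab.txt -s ./checkpoints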

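# Build junction trees from the training file over the cluster vocabulary;
# test() reuses the same dataset with training=False.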
dataset = JTNNDataset(data=opts.train, vocab=opts.vocab, training=True)
vocab = dataset.vocab

# Command-line values arrive as strings from optparse; cast them explicitly.
batch_size = int(opts.batch_size)
hidden_size = int(opts.hidden_size)
latent_size = int(opts.latent_size)
depth = int(opts.depth)
beta = float(opts.beta)
lr = float(opts.lr)

# DGL-based junction tree VAE; following the JT-VAE design, the latent code
# is split between the tree and graph components.
model = DGLJTNNVAE(vocab, hidden_size, latent_size, depth)

if opts.model_path is not None:
    model.load_state_dict(torch.load(opts.model_path))
else:
    # Biases (1-D parameters) start at zero; weight matrices get Xavier
    # initialization. The in-place variants replace the deprecated
    # nn.init.constant / nn.init.xavier_normal.
    for param in model.parameters():
        if param.dim() == 1:
            nn.init.constant_(param, 0)
        else:
            nn.init.xavier_normal_(param)

model = cuda(model)
print(
    "Model #Params: %dK"
    % (sum(x.nelement() for x in model.parameters()) // 1000,)
)

optimizer = optim.Adam(model.parameters(), lr=lr)
# Multiply the learning rate by 0.9 on every scheduler.step(); note that the
# call below applies one decay immediately, before training starts.
scheduler = lr_scheduler.ExponentialLR(optimizer, 0.9)
scheduler.step()

MAX_EPOCH = 100
PRINT_ITER = 20  # report window-averaged metrics every PRINT_ITER batches


def train():
    dataset.training = True
    dataloader = DataLoader(
        dataset,
        batch_size=batch_size,
        shuffle=True,
        num_workers=4,
        collate_fn=JTNNCollator(vocab, True),
        drop_last=True,
        worker_init_fn=worker_init_fn,
    )

    for epoch in range(MAX_EPOCH):
        # Running accuracies over the current reporting window: word = node
        # label prediction, topo = tree topology, assm = subgraph assembly,
        # steo = stereochemistry.
        word_acc, topo_acc, assm_acc, steo_acc = 0, 0, 0, 0

        for it, batch in enumerate(tqdm.tqdm(dataloader)):
            model.zero_grad()
            try:
                loss, kl_div, wacc, tacc, sacc, dacc = model(batch, beta)
            except Exception:
                # Print the offending SMILES before re-raising.
                print([t.smiles for t in batch["mol_trees"]])
                raise
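            # Backpropagate the combined loss (reconstruction terms plus
            # beta-weighted KL divergence).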
            loss.backward()
            optimizer.step()

            word_acc += wacc
            topo_acc += tacc
            assm_acc += sacc
            steo_acc += dacc

            if (it + 1) % PRINT_ITER == 0:
                # Convert running sums to window-averaged percentages.
                word_acc = word_acc / PRINT_ITER * 100
                topo_acc = topo_acc / PRINT_ITER * 100
                assm_acc = assm_acc / PRINT_ITER * 100
                steo_acc = steo_acc / PRINT_ITER * 100

                print(
                    "KL: %.1f, Word: %.2f, Topo: %.2f, Assm: %.2f, Steo: %.2f, Loss: %.6f"
                    % (
                        kl_div,
                        word_acc,
                        topo_acc,
                        assm_acc,
                        steo_acc,
                        loss.item(),
                    )
                )
                word_acc, topo_acc, assm_acc, steo_acc = 0, 0, 0, 0
                sys.stdout.flush()

            if (it + 1) % 1500 == 0:  # Fast annealing
                scheduler.step()
                print("learning rate: %.6f" % scheduler.get_last_lr()[0])
                torch.save(
                    model.state_dict(),
                    opts.save_path + "/model.iter-%d-%d" % (epoch, it + 1),
                )

        # End-of-epoch learning-rate decay and checkpoint.
        scheduler.step()
        print("learning rate: %.6f" % scheduler.get_last_lr()[0])
        torch.save(
            model.state_dict(), opts.save_path + "/model.iter-" + str(epoch)
        )


def test():
    dataset.training = False
    dataloader = DataLoader(
        dataset,
        batch_size=1,
        shuffle=False,
        num_workers=0,
        collate_fn=JTNNCollator(vocab, False),
        drop_last=True,
        worker_init_fn=worker_init_fn,
    )

    # Just an example of molecule decoding; in reality you may want to sample
    # tree and molecule vectors.
    for it, batch in enumerate(dataloader):
        gt_smiles = batch["mol_trees"][0].smiles
        print(gt_smiles)
        model.move_to_cuda(batch)
        # Encode the ground-truth molecule, perturb the encoding with a
        # reparameterized sample, then decode back to SMILES.
        _, tree_vec, mol_vec = model.encode(batch)
        tree_vec, mol_vec, _, _ = model.sample(tree_vec, mol_vec)
        smiles = model.decode(tree_vec, mol_vec)
        print(smiles)


if __name__ == "__main__":
    if opts.test:
        test()
    else:
        train()

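    # Message-passing statistics accumulated by the DGL modules during the run.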
    print("# passes:", model.n_passes)
    print("Total # nodes processed:", model.n_nodes_total)
    print("Total # edges processed:", model.n_edges_total)
    print("Total # tree nodes processed:", model.n_tree_nodes_total)
    print("Graph decoder: # passes:", model.jtmpn.n_passes)
    print(
        "Graph decoder: Total # candidates processed:",
        model.jtmpn.n_samples_total,
    )
    print("Graph decoder: Total # nodes processed:", model.jtmpn.n_nodes_total)
    print("Graph decoder: Total # edges processed:", model.jtmpn.n_edges_total)
    print("Graph encoder: # passes:", model.mpn.n_passes)
    print(
        "Graph encoder: Total # candidates processed:",
        model.mpn.n_samples_total,
    )
    print("Graph encoder: Total # nodes processed:", model.mpn.n_nodes_total)
    print("Graph encoder: Total # edges processed:", model.mpn.n_edges_total)