# datautils.py (8.43 KB)
1
import torch
2
3
4
from torch.utils.data import Dataset

import dgl
5
from dgl.data.utils import download, extract_archive, get_download_dir, _get_dgl_url
6
7
8
9
10
from .mol_tree_nx import DGLMolTree
from .mol_tree import Vocab

from .mpn import mol2dgl_single as mol2dgl_enc
from .jtmpn import mol2dgl_single as mol2dgl_dec
11
12
from .jtmpn import ATOM_FDIM as ATOM_FDIM_DEC
from .jtmpn import BOND_FDIM as BOND_FDIM_DEC
13

14
15
16
17
18
19
20
21
22
23
24
def _unpack_field(examples, field):
    return [e[field] for e in examples]

def _set_node_id(mol_tree, vocab):
    wid = []
    for i, node in enumerate(mol_tree.nodes_dict):
        mol_tree.nodes_dict[node]['idx'] = i
        wid.append(vocab.get_index(mol_tree.nodes_dict[node]['smiles']))

    return wid

25
class JTNNDataset(Dataset):
    """Dataset that turns SMILES strings into junction-tree examples.

    On construction the ``jtnn.zip`` archive is downloaded into the DGL
    download directory and extracted; the SMILES list and the vocabulary are
    then read from text files inside the extracted ``jtnn`` folder.

    Parameters
    ----------
    data : str
        Base name (without ``.txt``) of the SMILES file inside the extracted
        archive. Only the first whitespace-separated token of every line is
        kept.
    vocab : str
        Base name (without ``.txt``) of the vocabulary file inside the
        extracted archive; one vocabulary entry per line.
    training : bool, optional
        When False, ``__getitem__`` returns only the tree and the encoder
        molecule graph and skips candidate / stereoisomer preparation.
    """

    def __init__(self, data, vocab, training=True):
        self.dir = get_download_dir()
        self.zip_file_path = '{}/jtnn.zip'.format(self.dir)

        download(_get_dgl_url('dgllife/jtnn.zip'), path=self.zip_file_path)
        extract_archive(self.zip_file_path, '{}/jtnn'.format(self.dir))

        print('Loading data...')
        data_file = '{}/jtnn/{}.txt'.format(self.dir, data)
        with open(data_file) as f:
            # Keep the first token of each line. Blank / whitespace-only
            # lines produce an empty split and are skipped rather than
            # raising IndexError.
            self.data = [fields[0] for fields in map(str.split, f) if fields]
        self.vocab_file = '{}/jtnn/{}.txt'.format(self.dir, vocab)
        print('Loading finished.')
        print('\tNum samples:', len(self.data))
        print('\tVocab file:', self.vocab_file)

        self.training = training
        # Read the vocabulary inside a context manager so the file handle is
        # closed deterministically.
        with open(self.vocab_file) as f:
            self.vocab = Vocab([x.strip("\r\n ") for x in f])

    def __len__(self):
        """Return the number of SMILES strings loaded."""
        return len(self.data)

    def __getitem__(self, idx):
        """Build the example dict for the ``idx``-th SMILES string.

        Returns
        -------
        dict
            Always contains ``'mol_tree'``, ``'mol_graph'``, ``'atom_x_enc'``,
            ``'bond_x_enc'`` and ``'wid'``. When ``self.training`` is true it
            additionally contains ``'cand_graphs'``, ``'atom_x_dec'``,
            ``'bond_x_dec'``, ``'tree_mess_src_e'``, ``'tree_mess_tgt_e'``,
            ``'tree_mess_tgt_n'``, ``'stereo_cand_graphs'``,
            ``'stereo_atom_x_enc'``, ``'stereo_bond_x_enc'`` and
            ``'stereo_cand_label'``.
        """
        smiles = self.data[idx]
        mol_tree = DGLMolTree(smiles)
        mol_tree.recover()
        mol_tree.assemble()

        wid = _set_node_id(mol_tree, self.vocab)

        # prebuild the molecule graph
        mol_graph, atom_x_enc, bond_x_enc = mol2dgl_enc(mol_tree.smiles)

        result = {
            'mol_tree': mol_tree,
            'mol_graph': mol_graph,
            'atom_x_enc': atom_x_enc,
            'bond_x_enc': bond_x_enc,
            'wid': wid,
        }

        if not self.training:
            return result

        # prebuild the candidate graph list
        cands = []
        for node_id, node in mol_tree.nodes_dict.items():
            # fill in ground truth
            if node['label'] not in node['cands']:
                node['cands'].append(node['label'])
                node['cand_mols'].append(node['label_mol'])

            # Leaves and nodes with a single candidate contribute nothing.
            if node['is_leaf'] or len(node['cands']) == 1:
                continue
            cands.extend((cand, mol_tree, node_id)
                         for cand in node['cand_mols'])

        if cands:
            cand_graphs, atom_x_dec, bond_x_dec, tree_mess_src_e, \
                tree_mess_tgt_e, tree_mess_tgt_n = mol2dgl_dec(cands)
        else:
            # No candidates: emit empty placeholders so that the collator can
            # still concatenate across examples.
            cand_graphs = []
            atom_x_dec = torch.zeros(0, ATOM_FDIM_DEC)
            bond_x_dec = torch.zeros(0, BOND_FDIM_DEC)
            tree_mess_src_e = torch.zeros(0, 2).long()
            tree_mess_tgt_e = torch.zeros(0, 2).long()
            tree_mess_tgt_n = torch.zeros(0).long()

        # prebuild the stereoisomers
        cands = mol_tree.stereo_cands
        if len(cands) > 1:
            # Make sure the ground-truth 3D SMILES is among the candidates.
            if mol_tree.smiles3D not in cands:
                cands.append(mol_tree.smiles3D)

            stereo_graphs = [mol2dgl_enc(c) for c in cands]
            stereo_cand_graphs, stereo_atom_x_enc, stereo_bond_x_enc = \
                zip(*stereo_graphs)
            stereo_atom_x_enc = torch.cat(stereo_atom_x_enc)
            stereo_bond_x_enc = torch.cat(stereo_bond_x_enc)
            # (index of the ground truth, number of candidates)
            stereo_cand_label = [(cands.index(mol_tree.smiles3D), len(cands))]
        else:
            stereo_cand_graphs = []
            stereo_atom_x_enc = torch.zeros(0, atom_x_enc.shape[1])
            stereo_bond_x_enc = torch.zeros(0, bond_x_enc.shape[1])
            stereo_cand_label = []

        result.update({
            'cand_graphs': cand_graphs,
            'atom_x_dec': atom_x_dec,
            'bond_x_dec': bond_x_dec,
            'tree_mess_src_e': tree_mess_src_e,
            'tree_mess_tgt_e': tree_mess_tgt_e,
            'tree_mess_tgt_n': tree_mess_tgt_n,
            'stereo_cand_graphs': stereo_cand_graphs,
            'stereo_atom_x_enc': stereo_atom_x_enc,
            'stereo_bond_x_enc': stereo_bond_x_enc,
            'stereo_cand_label': stereo_cand_label,
        })

        return result

class JTNNCollator(object):
    """Collate a list of per-molecule example dicts into one batch dict.

    Parameters
    ----------
    vocab : object
        Stored on the instance; not read by the methods of this class.
    training : bool
        When False, ``__call__`` returns only ``'mol_trees'`` and
        ``'mol_graph_batch'`` and skips candidate / stereoisomer batching.
    """

    def __init__(self, vocab, training):
        self.vocab = vocab
        self.training = training

    @staticmethod
    def _batch_and_set(graphs, atom_x, bond_x, flatten):
        """Batch ``graphs`` with ``dgl.batch`` and attach the feature tensors.

        ``atom_x`` is stored as node data ``'x'``, ``bond_x`` as edge data
        ``'x'``, and a zero tensor with one row per row of ``bond_x`` and as
        many columns as ``atom_x`` is stored as edge data ``'src_x'``.

        If ``flatten`` is true, ``graphs`` is a list of lists of graphs and is
        flattened by one level first.
        """
        if flatten:
            graphs = [g for f in graphs for g in f]
        graph_batch = dgl.batch(graphs)
        graph_batch.ndata['x'] = atom_x
        graph_batch.edata.update({
            'x': bond_x,
            # Same dtype/device as atom_x, zero-filled.
            'src_x': atom_x.new_zeros((bond_x.shape[0], atom_x.shape[1])),
        })
        return graph_batch

    @staticmethod
    def _offset_and_concat(tensors, offsets):
        """Return ``torch.cat`` of ``tensors[i] + offsets[i]``.

        The addition is out of place, so the input tensors (which belong to
        the caller's examples) are left unmodified.
        """
        return torch.cat([t + off for t, off in zip(tensors, offsets)])

    def __call__(self, examples):
        """Merge ``examples`` into a single batch.

        Parameters
        ----------
        examples : list of dict
            Each dict must provide ``'mol_tree'``, ``'wid'``, ``'mol_graph'``,
            ``'atom_x_enc'`` and ``'bond_x_enc'``; in training mode also the
            candidate and stereoisomer fields read below.

        Returns
        -------
        dict
            ``'mol_trees'`` and ``'mol_graph_batch'``; in training mode also
            ``'cand_graph_batch'``, ``'cand_batch_idx'``,
            ``'tree_mess_tgt_e'``, ``'tree_mess_src_e'``,
            ``'tree_mess_tgt_n'``, ``'stereo_cand_graph_batch'`` (``None``
            when no example has stereo candidates),
            ``'stereo_cand_batch_idx'``, ``'stereo_cand_labels'`` and
            ``'stereo_cand_lengths'``.
        """
        # get list of trees
        mol_trees = _unpack_field(examples, 'mol_tree')
        wid = _unpack_field(examples, 'wid')
        for _wid, mol_tree in zip(wid, mol_trees):
            mol_tree.graph.ndata['wid'] = torch.LongTensor(_wid)

        # TODO: either support pickling or get around ctypes pointers using scipy
        # batch molecule graphs
        mol_graphs = _unpack_field(examples, 'mol_graph')
        atom_x = torch.cat(_unpack_field(examples, 'atom_x_enc'))
        bond_x = torch.cat(_unpack_field(examples, 'bond_x_enc'))
        mol_graph_batch = self._batch_and_set(mol_graphs, atom_x, bond_x, False)

        result = {
            'mol_trees': mol_trees,
            'mol_graph_batch': mol_graph_batch,
        }

        if not self.training:
            return result

        # batch candidate graphs
        cand_graphs = _unpack_field(examples, 'cand_graphs')
        atom_x = torch.cat(_unpack_field(examples, 'atom_x_dec'))
        bond_x = torch.cat(_unpack_field(examples, 'bond_x_dec'))

        # Per-example index offsets: candidate-graph node ids are shifted by
        # the number of candidate-graph nodes seen so far, tree ids by the
        # number of tree nodes seen so far.
        cand_batch_idx = []
        graph_offsets = []
        tree_offsets = []
        n_graph_nodes = 0
        n_tree_nodes = 0
        for i, (graphs, mol_tree) in enumerate(zip(cand_graphs, mol_trees)):
            graph_offsets.append(n_graph_nodes)
            tree_offsets.append(n_tree_nodes)
            n_graph_nodes += sum(g.number_of_nodes() for g in graphs)
            n_tree_nodes += mol_tree.graph.number_of_nodes()
            cand_batch_idx.extend([i] * len(graphs))

        # Offsets are applied out of place so the tensors stored in
        # `examples` are not modified; collating the same examples again
        # therefore yields the same result.
        tree_mess_tgt_e = self._offset_and_concat(
            _unpack_field(examples, 'tree_mess_tgt_e'), graph_offsets)
        tree_mess_src_e = self._offset_and_concat(
            _unpack_field(examples, 'tree_mess_src_e'), tree_offsets)
        tree_mess_tgt_n = self._offset_and_concat(
            _unpack_field(examples, 'tree_mess_tgt_n'), graph_offsets)

        cand_graph_batch = self._batch_and_set(cand_graphs, atom_x, bond_x, True)

        # batch stereoisomers
        stereo_cand_graphs = _unpack_field(examples, 'stereo_cand_graphs')
        atom_x = torch.cat(_unpack_field(examples, 'stereo_atom_x_enc'))
        bond_x = torch.cat(_unpack_field(examples, 'stereo_bond_x_enc'))
        stereo_cand_batch_idx = []
        for i, graphs in enumerate(stereo_cand_graphs):
            stereo_cand_batch_idx.extend([i] * len(graphs))

        if stereo_cand_batch_idx:
            stereo_cand_labels = [
                (label, length)
                for ex in _unpack_field(examples, 'stereo_cand_label')
                for label, length in ex
            ]
            stereo_cand_labels, stereo_cand_lengths = zip(*stereo_cand_labels)
            stereo_cand_graph_batch = self._batch_and_set(
                stereo_cand_graphs, atom_x, bond_x, True)
        else:
            # No stereo candidates in this batch.
            stereo_cand_labels = []
            stereo_cand_lengths = []
            stereo_cand_graph_batch = None

        result.update({
            'cand_graph_batch': cand_graph_batch,
            'cand_batch_idx': cand_batch_idx,
            'tree_mess_tgt_e': tree_mess_tgt_e,
            'tree_mess_src_e': tree_mess_src_e,
            'tree_mess_tgt_n': tree_mess_tgt_n,
            'stereo_cand_graph_batch': stereo_cand_graph_batch,
            'stereo_cand_batch_idx': stereo_cand_batch_idx,
            'stereo_cand_labels': stereo_cand_labels,
            'stereo_cand_lengths': stereo_cand_lengths,
        })

        return result