# -*- coding: utf-8 -*-
"""Tools for loading, shuffling, and batching ANI datasets"""

from torch.utils.data import Dataset
from os.path import join, isfile, isdir
import os
from ._pyanitools import anidataloader
import torch
from .. import utils, neurochem, aev
import pickle

default_device = 'cuda' if torch.cuda.is_available() else 'cpu'

def chunk_counts(counts, split):
    """Given ``counts``, a list of ``[natoms, count]`` pairs sorted by
    ``natoms``, and a list of split points, return the number of conformations
    and the maximum number of atoms in each chunk."""
    split = [x + 1 for x in split] + [None]
    count_chunks = []
    start = 0
    for i in split:
        count_chunks.append(counts[start:i])
        start = i
    chunk_conformations = [sum([y[1] for y in x]) for x in count_chunks]
    chunk_maxatoms = [x[-1][0] for x in count_chunks]
    return chunk_conformations, chunk_maxatoms


def split_cost(counts, split):
    """Estimate the cost of evaluating a batch split at the given points: each
    chunk costs roughly ``conformations * maxatoms ** 2``, with a floor so
    that producing many tiny chunks is not favored."""
    split_min_cost = 40000
    cost = 0
    chunk_conformations, chunk_maxatoms = chunk_counts(counts, split)
    for conformations, maxatoms in zip(chunk_conformations, chunk_maxatoms):
        cost += max(conformations * maxatoms ** 2, split_min_cost)
    return cost
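
# For illustration (the numbers are made up): if a sorted batch contains 3000
# conformations of 2-atom molecules followed by 2000 conformations of 20-atom
# molecules, keeping it as one chunk costs 5000 * 20 ** 2 = 2,000,000, while
# splitting after the 2-atom group costs
# max(3000 * 2 ** 2, 40000) + max(2000 * 20 ** 2, 40000) = 840,000, so the
# greedy search in split_batch below prefers the split.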


def split_batch(natoms, species, coordinates):
    """Split a batch sorted by ``natoms`` into chunks of molecules of similar
    size, using a greedy search that minimizes :func:`split_cost`."""
    # count the number of conformations for each number of atoms
    natoms = natoms.tolist()
    counts = []
    for i in natoms:
        if len(counts) == 0:
            counts.append([i, 1])
            continue
        if i == counts[-1][0]:
            counts[-1][1] += 1
        else:
            counts.append([i, 1])
    # find the best split using a greedy strategy
    split = []
    cost = split_cost(counts, split)
    improved = True
    while improved:
        improved = False
        cycle_split = split
        cycle_cost = cost
        for i in range(len(counts)-1):
            if i not in split:
                s = sorted(split + [i])
                c = split_cost(counts, s)
                if c < cycle_cost:
                    improved = True
                    cycle_cost = c
                    cycle_split = s
        if improved:
            split = cycle_split
            cost = cycle_cost
    # do the split
    start = 0
    species_coordinates = []
    chunk_conformations, _ = chunk_counts(counts, split)
    for i in chunk_conformations:
        end = start + i
        s = species[start:end, ...]
        c = coordinates[start:end, ...]
        s, c = utils.strip_redundant_padding(s, c)
        species_coordinates.append((s, c))
        start = end
    return species_coordinates


class BatchedANIDataset(Dataset):
    """Load data from hdf5 files, create minibatches, and convert to tensors.

    This is already a dataset of batches, so when iterated, a batch rather
    than a single data point will be yielded.

    Since each batch might contain molecules of very different sizes, putting
    the whole batch into a single tensor would require adding ghost atoms to
    pad everything to the size of the largest molecule. As a result, a huge
    amount of computation would be wasted on ghost atoms. To avoid this issue,
    the inputs of each batch, i.e. species and coordinates, are further
    divided into chunks according to some heuristics, so that each chunk only
    contains molecules of similar size, minimizing the padding required.

    So, when iterating on this dataset, a tuple will be yielded. The first
    element of this tuple is a list of (species, coordinates) pairs. Each pair
    is a chunk of molecules of similar size. The second element of this tuple
    is a dictionary whose keys are those specified in the argument
    :attr:`properties`, and whose values are single tensors of the whole batch
    (properties are not split into chunks).
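
    For illustration, iterating over the dataset looks roughly like the
    following sketch (the file name and species converter are placeholders)::

        dataset = BatchedANIDataset('data.h5', species_to_tensor, 256)
        for species_coordinates, properties in dataset:
            energies = properties['energies']
            for species, coordinates in species_coordinates:
                pass  # molecules within a chunk have similar sizes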

    Splitting a batch into chunks leads to some inconvenience during training,
    especially when using high level libraries like ``ignite``. To overcome
    this inconvenience, :class:`torchani.ignite.Container` is provided for
    working with ignite.

    Arguments:
        path (str): Path to hdf5 files. If :attr:`path` is a file, then that
            file will be loaded using `pyanitools.py`_. If :attr:`path` is
            a directory, then all files with suffix `.h5` or `.hdf5` will be
            loaded.
        species_tensor_converter (:class:`collections.abc.Callable`): A
            callable that converts species in the format of a list of strings
            to a 1D tensor.
        batch_size (int): Number of different 3D structures in a single
            minibatch.
        shuffle (bool): Whether to shuffle the whole dataset.
        properties (list): List of keys in the dataset to be loaded.
            ``'species'`` and ``'coordinates'`` are always loaded and need not
            be specified here.
        transform (list): List of :class:`collections.abc.Callable` that
            transform the data. Callables must take species, coordinates,
            and properties of the whole dataset as arguments, and return
            the transformed species, coordinates, and properties.
        dtype (:class:`torch.dtype`): dtype to convert coordinates and
            properties of the dataset to.
        device (:class:`torch.device`): device to put tensors on when
            iterating.

    .. _pyanitools.py:
        https://github.com/isayev/ASE_ANI/blob/master/lib/pyanitools.py
    """

    def __init__(self, path, species_tensor_converter, batch_size,
                 shuffle=True, properties=['energies'], transform=(),
                 dtype=torch.get_default_dtype(), device=default_device):
        super(BatchedANIDataset, self).__init__()
        self.properties = properties
        self.device = device

        # get name of files storing data
        files = []
        if isdir(path):
            for f in os.listdir(path):
                f = join(path, f)
                if isfile(f) and (f.endswith('.h5') or f.endswith('.hdf5')):
                    files.append(f)
        elif isfile(path):
            files = [path]
        else:
            raise ValueError('Bad path: {}'.format(path))

        # load full dataset
        species_coordinates = []
        properties = {k: [] for k in self.properties}
        for f in files:
            for m in anidataloader(f):
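                # each item is a dict holding 'species', 'coordinates', and
                # the other requested properties for a group of conformations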
                s = species_tensor_converter(m['species'])
                c = torch.from_numpy(m['coordinates']).to(torch.double)
                species_coordinates.append((s, c))
                for i in properties:
                    p = torch.from_numpy(m[i]).to(torch.double)
                    properties[i].append(p)
        species, coordinates = utils.pad_coordinates(species_coordinates)
        for i in properties:
            properties[i] = torch.cat(properties[i])

        # shuffle if required
        conformations = coordinates.shape[0]
        if shuffle:
            indices = torch.randperm(conformations)
            species = species.index_select(0, indices)
            coordinates = coordinates.index_select(0, indices)
            for i in properties:
                properties[i] = properties[i].index_select(0, indices)

        # do transformations on data
        for t in transform:
            species, coordinates, properties = t(species, coordinates,
                                                 properties)

        # convert to the desired dtype (species indices stay integral)
        coordinates = coordinates.to(dtype)
        for k in properties:
            properties[k] = properties[k].to(dtype)

        # split into minibatches, and strip redundant padding
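        # padded (ghost) atoms are marked with negative species, so counting
        # nonnegative entries gives the number of real atoms per conformation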
        natoms = (species >= 0).to(torch.long).sum(1)
        batches = []
        num_batches = (conformations + batch_size - 1) // batch_size
        for i in range(num_batches):
            start = i * batch_size
            end = min((i + 1) * batch_size, conformations)
            natoms_batch = natoms[start:end]
            # sort batch by number of atoms to prepare for splitting
            natoms_batch, indices = natoms_batch.sort()
            species_batch = species[start:end, ...].index_select(0, indices)
            coordinates_batch = coordinates[start:end, ...] \
                .index_select(0, indices)
            properties_batch = {
                k: properties[k][start:end, ...].index_select(0, indices)
                for k in properties
            }
            # further split batch into chunks
            species_coordinates = split_batch(natoms_batch, species_batch,
                                              coordinates_batch)
            batch = species_coordinates, properties_batch
            batches.append(batch)
        self.batches = batches

    def __getitem__(self, idx):
        species_coordinates, properties = self.batches[idx]
        species_coordinates = [(s.to(self.device), c.to(self.device))
                               for s, c in species_coordinates]
        properties = {
            k: properties[k].to(self.device) for k in properties
        }
        return species_coordinates, properties

    def __len__(self):
        return len(self.batches)


def _disk_cache_loader(index_queue, tensor_queue, disk_cache, device):
    """Get index and load from disk cache."""
    while True:
        index = index_queue.get()
        aev_path = os.path.join(disk_cache, str(index))
        with open(aev_path, 'rb') as f:
            tensor_queue.put(pickle.load(f))


class AEVCacheLoader:
    """Build a factory for AEV.

    The computation of AEV is the most time consuming part during training.
    Since during training, the AEV never changes, it is not hard to see that,
    If we have a fast enough storage (this is usually the case for good SSDs,
    but not for HDD), we could cache the computed AEVs into disk and load it
    rather than compute it from scratch everytime we use it.

    Arguments:
        disk_cache (str): Directory storing disk caches.
        in_memory_size (int): Number of batches to keep queued (prefetched)
            by the background loader process.
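
    A minimal usage sketch, pairing this loader with :func:`cache_aev` defined
    below (the cache directory and dataset path are placeholders)::

        cache_aev('./aev_cache', './data.h5', 256)
        loader = AEVCacheLoader('./aev_cache')
        for species_aevs, properties in loader:
            ...  # train on the precomputed AEVs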
    """

    def __init__(self, disk_cache=None, in_memory_size=64):
        self.current = 0
        self.disk_cache = disk_cache

        # load dataset from disk cache
        dataset_path = os.path.join(disk_cache, 'dataset')
        with open(dataset_path, 'rb') as f:
            self.dataset = pickle.load(f)
        # initialize queues and processes
        self.tensor_queue = torch.multiprocessing.Queue()
        self.index_queue = torch.multiprocessing.Queue()
        self.in_memory_size = in_memory_size
        if len(self.dataset) < in_memory_size:
            self.in_memory_size = len(self.dataset)
        for i in range(self.in_memory_size):
            self.index_queue.put(i)
        self.loader = torch.multiprocessing.Process(
            target=_disk_cache_loader,
            args=(self.index_queue, self.tensor_queue, disk_cache,
                  self.dataset.device)
        )
        self.loader.start()

    def __iter__(self):
        if self.current != 0:
            raise ValueError('Only one iterator of AEVCacheLoader is allowed')
        else:
            return self

    def __next__(self):
        if self.current < len(self.dataset):
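            # ask the background loader to prefetch the batch that is
            # in_memory_size ahead, so the tensor queue stays filled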
            new_idx = (self.current + self.in_memory_size) % len(self.dataset)
            self.index_queue.put(new_idx)
            species_aevs = self.tensor_queue.get()
            species_aevs = [(x.to(self.dataset.device),
                             y.to(self.dataset.device))
                            for x, y in species_aevs]
            _, output = self.dataset[self.current]
            self.current += 1
            return species_aevs, output
        else:
            self.current = 0
            raise StopIteration

    def __del__(self):
        self.loader.terminate()

    def __len__(self):
        return len(self.dataset)


builtin = neurochem.Builtins()


def cache_aev(output, dataset_path, batchsize, device=default_device,
              constfile=builtin.const_file, subtract_sae=False,
              sae_file=builtin.sae_file, enable_tqdm=True, **kwargs):
    """Compute AEVs for the whole dataset and pickle them, one file per batch,
    into the ``output`` directory for later loading with
    :class:`AEVCacheLoader`."""
    # if output directory does not exist, then create it
    if not os.path.exists(output):
        os.makedirs(output)

    device = torch.device(device)
    consts = neurochem.Constants(constfile)
    aev_computer = aev.AEVComputer(**consts).to(device)

    if subtract_sae:
        energy_shifter = neurochem.load_sae(sae_file)
        transform = (energy_shifter.subtract_from_dataset,)
    else:
        transform = ()

    dataset = BatchedANIDataset(
        dataset_path, consts.species_to_tensor, batchsize,
        device=device, transform=transform, **kwargs
    )

    # dump out the dataset
    filename = os.path.join(output, 'dataset')
    with open(filename, 'wb') as f:
        pickle.dump(dataset, f)

    if enable_tqdm:
        import tqdm
        indices = tqdm.trange(len(dataset))
    else:
        indices = range(len(dataset))
    for i in indices:
        input_, _ = dataset[i]
        aevs = [aev_computer(j) for j in input_]
        aevs = [(x.cpu(), y.cpu()) for x, y in aevs]
        filename = os.path.join(output, '{}'.format(i))
        with open(filename, 'wb') as f:
            pickle.dump(aevs, f)


__all__ = ['BatchedANIDataset', 'AEVCacheLoader', 'cache_aev']