# -*- coding: utf-8 -*-
"""Tools for loading, shuffling, and batching ANI datasets"""

from torch.utils.data import Dataset
from os.path import join, isfile, isdir
import os
from ._pyanitools import anidataloader
import torch
from .. import utils, neurochem, aev, models
import pickle
import numpy as np
from scipy.sparse import bsr_matrix
import warnings
from .new import CachedDataset, ShuffledDataset, find_threshold

default_device = 'cuda' if torch.cuda.is_available() else 'cpu'


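# The helpers below work on a run-length encoding of the per-molecule atom
# counts of a batch: `counts` is a list of [natoms, n_molecules] pairs for a
# batch sorted by natoms, and a `split` is a list of indices into `counts`
# marking chunk boundaries. An illustrative sketch (the values are made up):
#
#     counts = [[2, 3], [5, 4], [8, 1]]  # 3 molecules of 2 atoms, 4 of 5, 1 of 8
#     chunk_counts(counts, [0])          # -> ([3, 5], [2, 8])
#
# i.e. splitting after the first entry yields two chunks containing 3 and 5
# molecules, whose largest molecules have 2 and 8 atoms respectively.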
def chunk_counts(counts, split):
    split = [x + 1 for x in split] + [None]
    count_chunks = []
    start = 0
    for i in split:
        count_chunks.append(counts[start:i])
        start = i
    chunk_molecules = [sum([y[1] for y in x]) for x in count_chunks]
    chunk_maxatoms = [x[-1][0] for x in count_chunks]
    return chunk_molecules, chunk_maxatoms


def split_cost(counts, split):
    split_min_cost = 40000
    cost = 0
    chunk_molecules, chunk_maxatoms = chunk_counts(counts, split)
    for molecules, maxatoms in zip(chunk_molecules, chunk_maxatoms):
        cost += max(molecules * maxatoms ** 2, split_min_cost)
    return cost


def split_batch(natoms, atomic_properties):
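    """Split a sorted batch into chunks of molecules with similar sizes.

    ``natoms`` is assumed to be sorted in ascending order (callers sort the
    batch beforehand). A greedy search over split points minimizes the
    estimated cost of the padded chunks, and redundant padding is stripped
    from each resulting chunk.
    """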

    # count conformations per number of atoms (natoms is sorted)
    natoms = natoms.tolist()
    counts = []
    for i in natoms:
        if not counts:
            counts.append([i, 1])
            continue
        if i == counts[-1][0]:
            counts[-1][1] += 1
        else:
            counts.append([i, 1])

    # find best split using greedy strategy
    split = []
    cost = split_cost(counts, split)
    improved = True
    while improved:
        improved = False
        cycle_split = split
        cycle_cost = cost
        for i in range(len(counts) - 1):
            if i not in split:
                s = sorted(split + [i])
                c = split_cost(counts, s)
                if c < cycle_cost:
                    improved = True
                    cycle_cost = c
                    cycle_split = s
        if improved:
            split = cycle_split
            cost = cycle_cost

    # do split
    chunk_molecules, _ = chunk_counts(counts, split)
    num_chunks = None
    for k in atomic_properties:
        atomic_properties[k] = atomic_properties[k].split(chunk_molecules)
        if num_chunks is None:
            num_chunks = len(atomic_properties[k])
        else:
            assert num_chunks == len(atomic_properties[k])
    chunks = []
    for i in range(num_chunks):
        chunk = {k: atomic_properties[k][i] for k in atomic_properties}
        chunks.append(utils.strip_redundant_padding(chunk))
    return chunks


def load_and_pad_whole_dataset(path, species_tensor_converter, shuffle=True,
                               properties=('energies',), atomic_properties=()):
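    """Load all molecules from ``path`` into whole-dataset tensors.

    Returns ``(atomic_properties, properties)``: atomic properties (including
    ``'species'`` and ``'coordinates'``) are padded across all molecules,
    molecular properties are concatenated, and, if ``shuffle`` is true, both
    are permuted by a single random permutation.
    """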
    # get name of files storing data
    files = []
    if isdir(path):
        for f in os.listdir(path):
            f = join(path, f)
            if isfile(f) and (f.endswith('.h5') or f.endswith('.hdf5')):
                files.append(f)
    elif isfile(path):
        files = [path]
    else:
        raise ValueError('Bad path: {}'.format(path))

    # load full dataset
    atomic_properties_ = []
    properties = {k: [] for k in properties}
    for f in files:
        for m in anidataloader(f):
            atomic_properties_.append(dict(
                species=species_tensor_converter(m['species']).unsqueeze(0),
                **{
                    k: torch.from_numpy(m[k]).to(torch.double)
                    for k in ['coordinates'] + list(atomic_properties)
                }
            ))
            for i in properties:
                p = torch.from_numpy(m[i]).to(torch.double)
                properties[i].append(p)
    atomic_properties = utils.pad_atomic_properties(atomic_properties_)
    for i in properties:
        properties[i] = torch.cat(properties[i])

    # shuffle if required
    molecules = atomic_properties['species'].shape[0]
    if shuffle:
        indices = torch.randperm(molecules)
        for i in properties:
            properties[i] = properties[i].index_select(0, indices)
        for i in atomic_properties:
            atomic_properties[i] = atomic_properties[i].index_select(0, indices)
    return atomic_properties, properties


def split_whole_into_batches_and_chunks(atomic_properties, properties, batch_size):
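    """Split whole-dataset tensors into batches, and each batch into chunks.

    Each batch is sorted by molecule size and passed to :func:`split_batch`
    so that every chunk contains molecules of similar size. Returns a list of
    ``(chunked_atomic_properties, properties)`` tuples, one per batch.
    """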
    molecules = atomic_properties['species'].shape[0]
    # split into minibatches
    for k in properties:
        properties[k] = properties[k].split(batch_size)
    for k in atomic_properties:
        atomic_properties[k] = atomic_properties[k].split(batch_size)

    # further split batch into chunks and strip redundant padding
    batches = []
    num_batches = (molecules + batch_size - 1) // batch_size
    for i in range(num_batches):
        batch_properties = {k: v[i] for k, v in properties.items()}
        batch_atomic_properties = {k: v[i] for k, v in atomic_properties.items()}
        species = batch_atomic_properties['species']
        natoms = (species >= 0).to(torch.long).sum(1)

        # sort batch by number of atoms to prepare for splitting
        natoms, indices = natoms.sort()
        for k in batch_properties:
            batch_properties[k] = batch_properties[k].index_select(0, indices)
        for k in batch_atomic_properties:
            batch_atomic_properties[k] = batch_atomic_properties[k].index_select(0, indices)

        batch_atomic_properties = split_batch(natoms, batch_atomic_properties)
        batches.append((batch_atomic_properties, batch_properties))

    return batches


class PaddedBatchChunkDataset(Dataset):
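    """Dataset of padded batches, each further split into chunks.

    Indexing yields a ``(species_coordinates, properties)`` tuple as described
    in :func:`load_ani_dataset`: ``species_coordinates`` is a list of
    ``(species, coordinates)`` chunks moved to ``self.device``, and
    ``properties`` is a dictionary of whole-batch tensors, plus the key
    ``'atomic'`` holding the chunked atomic properties.
    """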

    def __init__(self, atomic_properties, properties, batch_size,
                 dtype=torch.get_default_dtype(), device=default_device):
        super().__init__()
        self.device = device
        self.dtype = dtype

        # convert to desired dtype
        for k in properties:
            properties[k] = properties[k].to(dtype)
        for k in atomic_properties:
            if k == 'species':
                continue
            atomic_properties[k] = atomic_properties[k].to(dtype)

        self.batches = split_whole_into_batches_and_chunks(atomic_properties, properties, batch_size)

    def __getitem__(self, idx):
        atomic_properties, properties = self.batches[idx]
        atomic_properties, properties = atomic_properties.copy(), properties.copy()
        species_coordinates = []
        for chunk in atomic_properties:
            for k in chunk:
                chunk[k] = chunk[k].to(self.device)
            species_coordinates.append((chunk['species'], chunk['coordinates']))
        for k in properties:
            properties[k] = properties[k].to(self.device)
        properties['atomic'] = atomic_properties
        return species_coordinates, properties

    def __len__(self):
        return len(self.batches)


class BatchedANIDataset(PaddedBatchChunkDataset):
    """Same as :func:`torchani.data.load_ani_dataset`. This API has been deprecated."""

    def __init__(self, path, species_tensor_converter, batch_size,
                 shuffle=True, properties=('energies',), atomic_properties=(), transform=(),
                 dtype=torch.get_default_dtype(), device=default_device):
        self.properties = properties
        self.atomic_properties = atomic_properties
        warnings.warn("BatchedANIDataset is deprecated; use load_ani_dataset()", DeprecationWarning)

        atomic_properties, properties = load_and_pad_whole_dataset(
            path, species_tensor_converter, shuffle, properties, atomic_properties)

        # do transformations on data
        for t in transform:
            atomic_properties, properties = t(atomic_properties, properties)

        super().__init__(atomic_properties, properties, batch_size, dtype, device)


def load_ani_dataset(path, species_tensor_converter, batch_size, shuffle=True,
                     rm_outlier=False, properties=('energies',), atomic_properties=(),
                     transform=(), dtype=torch.get_default_dtype(), device=default_device,
                     split=(None,)):
    """Load ANI dataset from hdf5 files, and split into subsets.

    The returned datasets are datasets of batches, so when iterated, a batch
    rather than a single data point is yielded.

    Since each batch might contain molecules of very different sizes, putting
    the whole batch into a single tensor would require adding ghost atoms to
    pad everything to the size of the largest molecule. As a result, a huge
    amount of computation would be wasted on ghost atoms. To avoid this issue,
    the inputs of each batch, i.e. species and coordinates, are further divided
    into chunks according to a heuristic, so that each chunk only contains
    molecules of similar size and the padding required is minimized.

    So, when iterating over this dataset, a tuple is yielded. The first
    element of this tuple is a list of (species, coordinates) pairs, each of
    which is a chunk of molecules of similar size. The second element is a
    dictionary whose keys are those specified in the argument
    :attr:`properties` and whose values are single tensors for the whole
    batch (properties are not split into chunks).

    Splitting a batch into chunks is somewhat inconvenient during training,
    especially when using high-level libraries like ``ignite``. To ease this,
    :class:`torchani.ignite.Container` is provided for working with ignite.

    Arguments:
        path (str): Path to hdf5 files. If :attr:`path` is a file, then that
            file will be loaded using `pyanitools.py`_. If :attr:`path` is
            a directory, then all files with suffix `.h5` or `.hdf5` will be
            loaded.
        species_tensor_converter (:class:`collections.abc.Callable`): A
            callable that converts species, given as a list of strings, to a
            1D tensor.
        batch_size (int): Number of different 3D structures in a single
            minibatch.
        shuffle (bool): Whether to shuffle the whole dataset.
        rm_outlier (bool): Whether to discard the outlier energy conformers
            from a given dataset.
        properties (list): List of keys of `molecular` properties in the
            dataset to be loaded. Here `molecular` means that the property has
            a fixed size regardless of the number of atoms, i.e. the tensor
            shape of a molecular property should be (molecules, ...). An
            example of a molecular property is the molecular energy.
            ``'species'`` and ``'coordinates'`` are always loaded and need not
            be specified.
        atomic_properties (list): List of keys of `atomic` properties in the
            dataset to be loaded. Here `atomic` means that the size of the
            property is proportional to the number of atoms in the molecule,
            i.e. the tensor shape of an atomic property should be
            (molecules, atoms, ...). An example of an atomic property is the
            forces. ``'species'`` and ``'coordinates'`` are always loaded and
            need not be specified.
        transform (list): List of :class:`collections.abc.Callable` that
            transform the data. Each callable must take
            ``(atomic_properties, properties)`` as arguments and return the
            transformed atomic properties and properties.
        dtype (:class:`torch.dtype`): dtype to convert coordinates and
            properties to.
        device (:class:`torch.device`): device to put tensors on when
            iterating.
        split (list): A sequence of integers, floats, or ``None``. Integers
            are interpreted as numbers of elements, floats as fractions of the
            dataset, and ``None`` as the rest of the dataset; ``None`` can
            only appear as the last element of :attr:`split`. For example, if
            the whole dataset has 10000 entries and split is
            ``(5000, 0.1, None)``, then this function creates 3 datasets,
            where the first dataset contains 5000 elements, the second
            contains ``int(0.1 * 10000)``, i.e. 1000, and the third contains
            the remaining ``10000 - 5000 - 1000`` elements. By default this
            creates a single dataset.

    Returns:
        An instance of :class:`torchani.data.PaddedBatchChunkDataset` if there
        is only one element in :attr:`split`, otherwise a tuple of such
        datasets, one per element of :attr:`split`.
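
    Example:
        An illustrative sketch (the file name ``dataset.h5`` is hypothetical;
        ``neurochem.Constants`` and ``ani1x`` refer to names defined in this
        module)::

            consts = neurochem.Constants(ani1x.const_file)
            training, validation = load_ani_dataset(
                'dataset.h5', consts.species_to_tensor, batch_size=256,
                split=(0.8, None))
            for species_coordinates, batch in training:
                # species_coordinates: list of (species, coordinates) chunks
                # batch['energies']: energies of the whole batch
                pass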

    .. _pyanitools.py:
        https://github.com/isayev/ASE_ANI/blob/master/lib/pyanitools.py
    """
    atomic_properties_, properties_ = load_and_pad_whole_dataset(
        path, species_tensor_converter, shuffle, properties, atomic_properties)

    molecules = atomic_properties_['species'].shape[0]
    atomic_keys = ['species', 'coordinates', *atomic_properties]
    keys = properties

    # do transformations on data
    for t in transform:
        atomic_properties_, properties_ = t(atomic_properties_, properties_)

    if rm_outlier:
        transformed_energies = properties_['energies']
        num_atoms = (atomic_properties_['species'] >= 0).to(transformed_energies.dtype).sum(dim=1)
        scaled_diff = transformed_energies / num_atoms.sqrt()

        mean = scaled_diff[torch.abs(scaled_diff) < 15.0].mean()
        std = scaled_diff[torch.abs(scaled_diff) < 15.0].std()

        # keep conformers with |scaled_diff| < 8 * std + mean
        tol = 8.0 * std + mean
        low_idx = (torch.abs(scaled_diff) < tol).nonzero().squeeze()
        outlier_count = molecules - low_idx.numel()

        # discard outlier energy conformers if any exist
        if outlier_count > 0:
            print("Note: {} outlier energy conformers have been discarded from dataset".format(outlier_count))
            for key, val in atomic_properties_.items():
                atomic_properties_[key] = val[low_idx]
            for key, val in properties_.items():
                properties_[key] = val[low_idx]
            molecules = low_idx.numel()

    # compute size of each subset
    split_ = []
    total = 0
    for index, size in enumerate(split):
        if isinstance(size, float):
            size = int(size * molecules)
        if size is None:
            assert index == len(split) - 1
            size = molecules - total
        split_.append(size)
        total += size

    # split
    start = 0
    splitted = []
    for size in split_:
        ap = {k: atomic_properties_[k][start:start + size] for k in atomic_keys}
        p = {k: properties_[k][start:start + size] for k in keys}
        start += size
        splitted.append((ap, p))

    # construct batched datasets
    ret = []
    for ap, p in splitted:
        ds = PaddedBatchChunkDataset(ap, p, batch_size, dtype, device)
        ds.properties = properties
        ds.atomic_properties = atomic_properties
        ret.append(ds)
    if len(ret) == 1:
        return ret[0]
    return tuple(ret)


class AEVCacheLoader(Dataset):
    """Build a factory for AEV.

    The computation of AEV is the most time consuming part during training.
    Since during training, the AEV never changes, it is not hard to see that,
    If we have a fast enough storage (this is usually the case for good SSDs,
    but not for HDD), we could cache the computed AEVs into disk and load it
    rather than compute it from scratch everytime we use it.

    Arguments:
        disk_cache (str): Directory storing disk caches.
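
    Example:
        An illustrative sketch (file and directory names are hypothetical)::

            cache_aev('aev_cache', 'dataset.h5', 256)
            loader = AEVCacheLoader('aev_cache')
            species_aevs, properties = loader[0]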
    """

    def __init__(self, disk_cache=None):
        super(AEVCacheLoader, self).__init__()
        self.disk_cache = disk_cache

        # load dataset from disk cache
        dataset_path = os.path.join(disk_cache, 'dataset')
        with open(dataset_path, 'rb') as f:
            self.dataset = pickle.load(f)

    def __getitem__(self, index):
        _, output = self.dataset.batches[index]
        aev_path = os.path.join(self.disk_cache, str(index))
        with open(aev_path, 'rb') as f:
            species_aevs = pickle.load(f)
            for i, sa in enumerate(species_aevs):
                species, aevs = self.decode_aev(*sa)
                species_aevs[i] = (
                    species.to(self.dataset.device),
                    aevs.to(self.dataset.device)
                )
        return species_aevs, output

    def __len__(self):
        return len(self.dataset)

    @staticmethod
    def decode_aev(encoded_species, encoded_aev):
        return encoded_species, encoded_aev

    @staticmethod
    def encode_aev(species, aev):
        return species, aev


class SparseAEVCacheLoader(AEVCacheLoader):
    """Build a factory for AEV.

    The computation of AEV is the most time-consuming part of the training.
    AEV never changes during training and contains a large number of zeros.
    Therefore, we can store the computed AEVs as sparse representation and
    load it during the training rather than compute it from scratch. The
    storage requirement for ```'cache_sparse_aev'``` is considerably less
    than ```'cache_aev'```.

    Arguments:
        disk_cache (str): Directory storing disk caches.
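
    Example:
        An illustrative sketch (file and directory names are hypothetical)::

            cache_sparse_aev('sparse_cache', 'dataset.h5', 256)
            loader = SparseAEVCacheLoader('sparse_cache')
            species_aevs, properties = loader[0]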
    """

    @staticmethod
    def decode_aev(encoded_species, encoded_aev):
        species = torch.from_numpy(encoded_species.todense())
        aevs_np = np.stack([np.array(i.todense()) for i in encoded_aev], axis=0)
        aevs = torch.from_numpy(aevs_np)
        return species, aevs

    @staticmethod
    def encode_aev(species, aev):
        encoded_species = bsr_matrix(species.cpu().numpy())
        encoded_aev = [bsr_matrix(i.cpu().numpy()) for i in aev]
        return encoded_species, encoded_aev


ani1x = models.ANI1x()


def create_aev_cache(dataset, aev_computer, output, progress_bar=True, encoder=lambda *x: x):
    """Cache AEV for the given dataset.

    Arguments:
        dataset (:class:`torchani.data.PaddedBatchChunkDataset`): the dataset to be cached
        aev_computer (:class:`torchani.AEVComputer`): the AEV computer used to
            compute the AEVs
        output (str): path to the directory where the cache will be stored
        progress_bar (bool): whether to show a progress bar
        encoder (:class:`collections.abc.Callable`): a callable
            ``(species, aev) -> (encoded_species, encoded_aev)`` that encodes
            species and AEVs before caching
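
    Example:
        An illustrative sketch (file and directory names are hypothetical)::

            consts = neurochem.Constants(ani1x.const_file)
            aev_computer = aev.AEVComputer(**consts)
            dataset = load_ani_dataset('dataset.h5', consts.species_to_tensor, 256)
            os.makedirs('aev_cache', exist_ok=True)  # output dir must exist
            create_aev_cache(dataset, aev_computer, 'aev_cache')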
    """
    # dump out the dataset
    filename = os.path.join(output, 'dataset')
    with open(filename, 'wb') as f:
        pickle.dump(dataset, f)

    if progress_bar:
        import tqdm
        indices = tqdm.trange(len(dataset))
    else:
        indices = range(len(dataset))
    for i in indices:
        input_, _ = dataset[i]
        aevs = [encoder(*aev_computer(j)) for j in input_]
        filename = os.path.join(output, '{}'.format(i))
        with open(filename, 'wb') as f:
            pickle.dump(aevs, f)


def _cache_aev(output, dataset_path, batchsize, device, constfile,
               subtract_sae, sae_file, enable_tqdm, encoder, **kwargs):
    # if output directory does not exist, then create it
    if not os.path.exists(output):
        os.makedirs(output)

    device = torch.device(device)
    consts = neurochem.Constants(constfile)
    aev_computer = aev.AEVComputer(**consts).to(device)

    if subtract_sae:
        energy_shifter = neurochem.load_sae(sae_file)
        transform = (energy_shifter.subtract_from_dataset,)
    else:
        transform = ()

    dataset = load_ani_dataset(
        dataset_path, consts.species_to_tensor, batchsize,
        device=device, transform=transform, **kwargs
    )

    create_aev_cache(dataset, aev_computer, output, enable_tqdm, encoder)


def cache_aev(output, dataset_path, batchsize, device=default_device,
              constfile=ani1x.const_file, subtract_sae=False,
              sae_file=ani1x.sae_file, enable_tqdm=True, **kwargs):
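    """Compute AEVs for the dataset at ``dataset_path`` and pickle them
    (dense encoding) into the ``output`` directory, for later use with
    :class:`AEVCacheLoader`. Extra keyword arguments are forwarded to
    :func:`load_ani_dataset`.
    """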
    _cache_aev(output, dataset_path, batchsize, device, constfile,
               subtract_sae, sae_file, enable_tqdm, AEVCacheLoader.encode_aev,
               **kwargs)


def cache_sparse_aev(output, dataset_path, batchsize, device=default_device,
                     constfile=ani1x.const_file, subtract_sae=False,
                     sae_file=ani1x.sae_file, enable_tqdm=True, **kwargs):
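    """Same as :func:`cache_aev`, except that AEVs are converted to sparse
    matrices (:class:`scipy.sparse.bsr_matrix`) before being pickled, for
    later use with :class:`SparseAEVCacheLoader`.
    """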
    _cache_aev(output, dataset_path, batchsize, device, constfile,
               subtract_sae, sae_file, enable_tqdm,
               SparseAEVCacheLoader.encode_aev, **kwargs)


__all__ = ['load_ani_dataset', 'BatchedANIDataset', 'AEVCacheLoader',
           'SparseAEVCacheLoader', 'cache_aev', 'cache_sparse_aev',
           'CachedDataset', 'ShuffledDataset', 'find_threshold']