# -*- coding: utf-8 -*-
"""Tools for loading, shuffling, and batching ANI datasets

`torchani.data.load(path)` creates an iterable of raw data, where species
are strings and coordinates are numpy ndarrays.

You can transform this iterable by chaining transformations: to apply a
transformation, just call `it.transformation_name()`.

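For instance (with `dataset.h5` as a placeholder filename):

.. code-block:: python

    it = torchani.data.load('dataset.h5').shuffle()
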
Available transformations are listed below:

- `species_to_indices` converts species from strings to indices.
- `subtract_self_energies` subtracts self energies. You can pass a dict
    of self energies, or an `EnergyShifter` to let it infer the self
    energies from the dataset and store the result in the given shifter
    (see the example below).
- `remove_outliers` removes conformers whose energies are outliers; it can
    only be run after `subtract_self_energies`.
- `shuffle` shuffles the dataset.
- `cache` caches the result of the previous transformations in memory.
- `collate` pads the dataset, converts it to tensors, and stacks the
    conformers together into batches.
- `pin_memory` copies the tensors to pinned memory so that later transfer
    to CUDA can be faster.

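For example, to subtract precomputed self energies, you can pass a dict (the
values below are illustrative placeholders, not reference data):

.. code-block:: python

    it = it.subtract_self_energies({'H': -0.5, 'C': -37.8})
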
By default `species_to_indices` and `subtract_self_energies` order atoms by
atomic number. A special ordering can be used if requested, by calling
`species_to_indices(species_order)` or `subtract_self_energies(energy_shifter,
species_order)`. However, this is NOT recommended; it is best to always
order according to atomic number.
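
For illustration only, a hypothetical custom order would be passed as:

.. code-block:: python

    it = it.species_to_indices(('C', 'H', 'N', 'O'))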

You can also use `split` to split the iterable into pieces. Use `split` as:

.. code-block:: python

    it.split(ratio1, ratio2, None)

where the `None` at the end indicates that we want to use all of the rest.

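For instance, a hypothetical 80/10/10 split:

.. code-block:: python

    training, validation, test = it.split(0.8, 0.1, None)
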
Example:

.. code-block:: python

    energy_shifter = torchani.utils.EnergyShifter(None)
    training, validation = torchani.data.load(dspath).subtract_self_energies(energy_shifter).species_to_indices().shuffle().split(0.8, None)
    training = training.collate(batch_size).cache()
    validation = validation.collate(batch_size).cache()

If the above approach takes too much memory for you, you can instead use a
PyTorch `DataLoader` with multiprocessing to achieve comparable performance
with less memory usage:

.. code-block:: python

    training, validation = torchani.data.load(dspath).subtract_self_energies(energy_shifter).species_to_indices().shuffle().split(0.8, None)
    training = torch.utils.data.DataLoader(list(training), batch_size=batch_size, collate_fn=torchani.data.collate_fn, num_workers=64)
    validation = torch.utils.data.DataLoader(list(validation), batch_size=batch_size, collate_fn=torchani.data.collate_fn, num_workers=64)
"""

import functools
import gc
import importlib.util
import math
import os
import random
from collections import Counter
from os.path import join, isfile, isdir

import numpy

from ._pyanitools import anidataloader
from .. import utils

PKBAR_INSTALLED = importlib.util.find_spec('pkbar') is not None  # type: ignore
if PKBAR_INSTALLED:
    import pkbar

# Set to False to silence the progress bar printed while loading datasets
verbose = True


PROPERTIES = ('energies',)
PADDING = {
    'species': -1,
    'coordinates': 0.0,
    'forces': 0.0,
    'energies': 0.0
}


def collate_fn(samples):
    """Pad a list of conformation dicts and stack them into a single batch."""
    return utils.stack_with_padding(samples, PADDING)


class IterableAdapter:
    """https://stackoverflow.com/a/39564774"""
    def __init__(self, iterable_factory, length=None):
        self.iterable_factory = iterable_factory
        self.length = length

    def __iter__(self):
        return iter(self.iterable_factory())


class IterableAdapterWithLength(IterableAdapter):

    def __init__(self, iterable_factory, length):
        super().__init__(iterable_factory)
        self.length = length

    def __len__(self):
        return self.length


class Transformations:
    """Convert one reenterable iterable to another reenterable iterable"""

    @staticmethod
    def species_to_indices(reenterable_iterable, species_order=('H', 'C', 'N', 'O', 'F', 'S', 'Cl')):
        if species_order == 'periodic_table':
            species_order = utils.PERIODIC_TABLE
        idx = {k: i for i, k in enumerate(species_order)}

        def reenterable_iterable_factory():
            for d in reenterable_iterable:
                d['species'] = numpy.array([idx[s] for s in d['species']])
                yield d
        try:
            return IterableAdapterWithLength(reenterable_iterable_factory, len(reenterable_iterable))
        except TypeError:
            return IterableAdapter(reenterable_iterable_factory)

    @staticmethod
    def subtract_self_energies(reenterable_iterable, self_energies=None, species_order=None):
        intercept = 0.0
        shape_inference = False
        if isinstance(self_energies, utils.EnergyShifter):
            shape_inference = True
            shifter = self_energies
            self_energies = {}
            counts = {}
            Y = []
            for n, d in enumerate(reenterable_iterable):
                species = d['species']
                count = Counter()
                for s in species:
                    count[s] += 1
                for s, c in count.items():
                    if s not in counts:
                        counts[s] = [0] * n
                    counts[s].append(c)
                for s in counts:
                    if len(counts[s]) != n + 1:
                        counts[s].append(0)
                Y.append(d['energies'])
            # make n the total number of conformations rather than the last index
            n += 1

            # sort based on the order in periodic table by default
            if species_order is None:
                species_order = utils.PERIODIC_TABLE

            species = sorted(list(counts.keys()), key=lambda x: species_order.index(x))

            # least-squares fit: each row of X holds the per-species atom
            # counts of one conformation (plus a column of ones when an
            # intercept is fitted), so we solve X @ sae ~= Y
            X = [counts[s] for s in species]
            if shifter.fit_intercept:
                X.append([1] * n)
            X = numpy.array(X).transpose()
            Y = numpy.array(Y)
            sae, _, _, _ = numpy.linalg.lstsq(X, Y, rcond=None)
            sae_ = sae
            if shifter.fit_intercept:
                intercept = sae[-1]
                sae_ = sae[:-1]
            for s, e in zip(species, sae_):
                self_energies[s] = e
            # re-initialize the user-provided shifter with the fitted values
            shifter.__init__(sae, shifter.fit_intercept)
        gc.collect()

        def reenterable_iterable_factory():
            for d in reenterable_iterable:
                e = intercept
                for s in d['species']:
                    e += self_energies[s]
                d['energies'] -= e
                yield d
        if shape_inference:
            return IterableAdapterWithLength(reenterable_iterable_factory, n)
        return IterableAdapter(reenterable_iterable_factory)

    @staticmethod
    def remove_outliers(reenterable_iterable, threshold1=15.0, threshold2=8.0):
        # this transformation assumes `subtract_self_energies` has already been
        # applied; the thresholds below are meaningless for unshifted energies

        # pass 1: remove everything that has per-atom energy > threshold1
        def scaled_energy(x):
            num_atoms = len(x['species'])
            return abs(x['energies']) / math.sqrt(num_atoms)
        filtered = IterableAdapter(lambda: (x for x in reenterable_iterable if scaled_energy(x) < threshold1))

        # pass 2: remove everything whose energy deviates from the mean by
        # more than threshold2 * std
        n = 0
        mean = 0
        std = 0
        for m in filtered:
            n += 1
            mean += m['energies']
            std += m['energies'] ** 2
        mean /= n
        std = math.sqrt(std / n - mean ** 2)

        return IterableAdapter(lambda: filter(lambda x: abs(x['energies'] - mean) < threshold2 * std, filtered))

    @staticmethod
    def shuffle(reenterable_iterable):
        list_ = list(reenterable_iterable)
        del reenterable_iterable
        gc.collect()
        random.shuffle(list_)
        return list_

    @staticmethod
    def cache(reenterable_iterable):
        ret = list(reenterable_iterable)
        del reenterable_iterable
        gc.collect()
        return ret

    @staticmethod
    def collate(reenterable_iterable, batch_size):
        def reenterable_iterable_factory():
            batch = []
            i = 0
            for d in reenterable_iterable:
                batch.append(d)
                i += 1
                if i == batch_size:
                    i = 0
                    yield collate_fn(batch)
                    batch = []
            if len(batch) > 0:
                yield collate_fn(batch)
        try:
            length = (len(reenterable_iterable) + batch_size - 1) // batch_size
            return IterableAdapterWithLength(reenterable_iterable_factory, length)
        except TypeError:
            return IterableAdapter(reenterable_iterable_factory)

    @staticmethod
    def pin_memory(reenterable_iterable):
        def reenterable_iterable_factory():
            for d in reenterable_iterable:
                yield {k: d[k].pin_memory() for k in d}
        try:
            return IterableAdapterWithLength(reenterable_iterable_factory, len(reenterable_iterable))
        except TypeError:
            return IterableAdapter(reenterable_iterable_factory)


class TransformableIterable:
    def __init__(self, wrapped_iterable, transformations=()):
        self.wrapped_iterable = wrapped_iterable
        self.transformations = transformations

    def __iter__(self):
        return iter(self.wrapped_iterable)

    def __getattr__(self, name):
        transformation = getattr(Transformations, name)

        @functools.wraps(transformation)
        def f(*args, **kwargs):
            return TransformableIterable(
                transformation(self.wrapped_iterable, *args, **kwargs),
                self.transformations + (name,))

        return f

    def split(self, *nums):
        length = len(self)
        iters = []
        self_iter = iter(self)
        for n in nums:
            list_ = []
            if n is not None:
                for _ in range(int(n * length)):
                    list_.append(next(self_iter))
            else:
                for i in self_iter:
                    list_.append(i)
            iters.append(TransformableIterable(list_, self.transformations + ('split',)))
        del self_iter
        gc.collect()
        return iters

    def __len__(self):
        return len(self.wrapped_iterable)


def load(path, additional_properties=()):
    """Load an ANI dataset from `path`, which may be an HDF5 file or a
    directory that is searched recursively for `.h5`/`.hdf5` files, and
    return a `TransformableIterable` over individual conformations."""
    properties = PROPERTIES + additional_properties

    def h5_files(path):
        """yield file name of all h5 files in a path"""
        if isdir(path):
            for f in os.listdir(path):
                f = join(path, f)
                yield from h5_files(f)
        elif isfile(path) and (path.endswith('.h5') or path.endswith('.hdf5')):
            yield path

    def molecules():
        for f in h5_files(path):
            anidata = anidataloader(f)
            anidata_size = anidata.group_size()
            use_pbar = PKBAR_INSTALLED and verbose
            if use_pbar:
                pbar = pkbar.Pbar('=> loading {}, total molecules: {}'.format(f, anidata_size), anidata_size)
            for i, m in enumerate(anidata):
                yield m
                if use_pbar:
                    pbar.update(i)

    def conformations():
        for m in molecules():
            species = m['species']
            coordinates = m['coordinates']
            for i in range(coordinates.shape[0]):
                ret = {'species': species, 'coordinates': coordinates[i]}
                for k in properties:
                    if k in m:
                        ret[k] = m[k][i]
                yield ret

    return TransformableIterable(IterableAdapter(lambda: conformations()))


__all__ = ['load', 'collate_fn']