aev.py 18.5 KB
Newer Older
Xiang Gao's avatar
Xiang Gao committed
1
2
import torch
import itertools
3
import math
4
from .env import buildin_const_file
5
from .benchmarked import BenchmarkedModule
6
from . import padding
7
8


9
class AEVComputerBase(BenchmarkedModule):
    """Base class of various implementations of AEV computer

    Attributes
    ----------
    benchmark : boolean
        Whether to enable benchmark
    const_file : str
        The name of the original file that stores constant.
    Rcr, Rca : float
        Cutoff radius
    EtaR, ShfR, Zeta, ShfZ, EtaA, ShfA : torch.Tensor
        Tensor storing constants, registered as buffers and reshaped so that
        they broadcast against each other (see ``__init__``).
    species : list of str
        Species names read from the ``Atyp`` entry of the const file.
    radial_sublength : int
        The length of radial subaev of a single species
    radial_length : int
        The length of full radial aev
    angular_sublength : int
        The length of angular subaev of a single species
    angular_length : int
        The length of full angular aev
    aev_length : int
        The length of full aev
    """

    __constants__ = ['Rcr', 'Rca', 'radial_sublength', 'radial_length',
                     'angular_sublength', 'angular_length', 'aev_length']

    def __init__(self, benchmark=False, const_file=buildin_const_file):
        super(AEVComputerBase, self).__init__(benchmark)
        self.const_file = const_file

        tensor_names = ('EtaR', 'ShfR', 'Zeta', 'ShfZ', 'EtaA', 'ShfA')

        def split_list(text):
            # "[a, b, c]" -> ['a', 'b', 'c']
            return [x.strip() for x in
                    text.replace('[', '').replace(']', '').split(',')]

        # load constants from const file; relevant lines are `name = value`
        const = {}
        species = None
        with open(const_file) as f:
            for line in f:
                # A line without an assignment (e.g. a blank line) carries no
                # constant, so skip it instead of reporting a parse error.
                if '=' not in line:
                    continue
                fields = [x.strip() for x in line.split('=')]
                name, value = fields[0], fields[1]
                try:
                    if name == 'Rcr' or name == 'Rca':
                        setattr(self, name, float(value))
                    elif name in tensor_names:
                        const[name] = torch.tensor(
                            [float(x) for x in split_list(value)])
                    elif name == 'Atyp':
                        species = split_list(value)
                except ValueError as err:
                    raise ValueError('unable to parse const file') from err

        # Fail early with a clear message instead of a later KeyError or
        # AttributeError when a required constant is absent.
        missing = [n for n in tensor_names if n not in const]
        if species is None:
            missing.append('Atyp')
        if missing:
            raise ValueError('const file is missing constants: ' +
                             ', '.join(missing))
        self.species = species

        # Compute lengths
        num_species = len(self.species)
        self.radial_sublength = const['EtaR'].shape[0] * const['ShfR'].shape[0]
        self.radial_length = num_species * self.radial_sublength
        self.angular_sublength = const['EtaA'].shape[0] * \
            const['Zeta'].shape[0] * const['ShfA'].shape[0] * \
            const['ShfZ'].shape[0]
        # one angular block per unordered species pair, (s, s) included
        num_pairs = num_species * (num_species + 1) // 2
        self.angular_length = num_pairs * self.angular_sublength
        self.aev_length = self.radial_length + self.angular_length

        # convert constant tensors to a ready-to-broadcast shape
        # shape convention (..., EtaR, ShfR)
        const['EtaR'] = const['EtaR'].view(-1, 1)
        const['ShfR'] = const['ShfR'].view(1, -1)
        # shape convention (..., EtaA, Zeta, ShfA, ShfZ)
        const['EtaA'] = const['EtaA'].view(-1, 1, 1, 1)
        const['Zeta'] = const['Zeta'].view(1, -1, 1, 1)
        const['ShfA'] = const['ShfA'].view(1, 1, -1, 1)
        const['ShfZ'] = const['ShfZ'].view(1, 1, 1, -1)

        # register buffers, in the order they appeared in the const file
        for name, tensor in const.items():
            self.register_buffer(name, tensor)

    def forward(self, coordinates_species):
        """Compute AEV from coordinates and species

        Parameters
        ----------
        coordinates_species : (torch.LongTensor, torch.Tensor)
            Tuple ``(species, coordinates)``.
            species : torch.LongTensor
                Long tensor for the species, where a value k means the
                species is the same as self.species[k]
            coordinates : torch.Tensor
                The tensor that specifies the xyz coordinates of atoms in the
                molecule. The tensor must have shape (conformations, atoms, 3)

        Returns
        -------
        (torch.LongTensor, torch.Tensor)
            Returns species and full AEV, in that order.

        Raises
        ------
        NotImplementedError
            Always; subclasses must override this method.
        """
        raise NotImplementedError('subclass must override this method')


Xiang Gao's avatar
Xiang Gao committed
110
111
112
def _cutoff_cosine(distances, cutoff):
    """Evaluate the cosine cutoff function elementwise.

    This is equation 2 of https://arxiv.org/pdf/1610.08935.pdf:
    ``fc(R) = 0.5 * cos(pi * R / Rc) + 0.5`` when ``R <= Rc`` and
    ``fc(R) = 0`` otherwise.

    Parameters
    ----------
    distances : torch.Tensor
        Rij values, of any shape; the function is applied per element.
    cutoff : float
        The cutoff radius Rc. Elements with Rij > Rc map to zero.

    Returns
    -------
    torch.Tensor
        A tensor shaped like `distances` holding the function values.
    """
    within_cutoff = distances <= cutoff
    smooth_part = 0.5 * torch.cos(math.pi * distances / cutoff) + 0.5
    # torch.where discards smooth_part outside the cutoff, which also hides
    # the NaN that cos() yields for infinite (padding) distances.
    outside_value = torch.zeros_like(distances)
    return torch.where(within_cutoff, smooth_part, outside_value)
Xiang Gao's avatar
Xiang Gao committed
137
138


139
class AEVComputer(AEVComputerBase):
    """The AEV computer assuming input coordinates sorted by species

    Attributes
    ----------
    timers : dict
        Dictionary storing the benchmark result. When constructed with
        ``benchmark=True``, ``__init__`` wraps each stage with
        ``_enable_benchmark`` under these labels: 'radial terms',
        'angular terms', 'terms and indices', 'combinations', 'mask_r',
        'mask_a', 'assemble' and 'total'. NOTE(review): the timer storage
        itself lives in ``BenchmarkedModule``; presumably it is keyed by
        those labels. Confirm there.
    """

    def __init__(self, benchmark=False, const_file=buildin_const_file):
        super(AEVComputer, self).__init__(benchmark, const_file)
        if benchmark:
            # Wrap every stage so that its runtime is measured separately.
            self.radial_subaev_terms = self._enable_benchmark(
                self.radial_subaev_terms, 'radial terms')
            self.angular_subaev_terms = self._enable_benchmark(
                self.angular_subaev_terms, 'angular terms')
            self.terms_and_indices = self._enable_benchmark(
                self.terms_and_indices, 'terms and indices')
            self.combinations = self._enable_benchmark(
                self.combinations, 'combinations')
            self.compute_mask_r = self._enable_benchmark(
                self.compute_mask_r, 'mask_r')
            self.compute_mask_a = self._enable_benchmark(
                self.compute_mask_a, 'mask_a')
            self.assemble = self._enable_benchmark(self.assemble, 'assemble')
            self.forward = self._enable_benchmark(self.forward, 'total')

    def radial_subaev_terms(self, distances):
        """Compute the radial subAEV terms of the center atom given neighbors

        The radial AEV is defined in
        https://arxiv.org/pdf/1610.08935.pdf equation 3.
        This method does NOT sum over neighbors: it returns one term per
        neighbor. The per-species sum is done later in `assemble`.

        Parameters
        ----------
        distances : torch.Tensor
            Pytorch tensor of shape (..., neighbors) storing the |Rij|
            length where i are the center atoms, and j are their neighbors.

        Returns
        -------
        torch.Tensor
            A tensor of shape (..., neighbors, `radial_sublength`) storing
            the subAEV terms.
        """
        # (..., neighbors, 1, 1) broadcasts against EtaR (n, 1), ShfR (1, m)
        distances = distances.unsqueeze(-1).unsqueeze(-1)
        fc = _cutoff_cosine(distances, self.Rcr)
        # Note that in the equation in the paper there is no 0.25
        # coefficient, but in NeuroChem there is such a coefficient.
        # We choose to be consistent with NeuroChem instead of the paper here.
        ret = 0.25 * torch.exp(-self.EtaR * (distances - self.ShfR)**2) * fc
        return ret.flatten(start_dim=-2)

    def angular_subaev_terms(self, vectors1, vectors2):
        """Compute the angular subAEV terms of the center atom given neighbor pairs.

        The angular AEV is defined in
        https://arxiv.org/pdf/1610.08935.pdf equation 4.
        This method does NOT sum over pairs: it returns one term per pair.
        The per-species-pair sum is done later in `assemble`.

        Parameters
        ----------
        vectors1, vectors2: torch.Tensor
            Tensor of shape (..., pairs, 3) storing the Rij vectors of pairs
            of neighbors. The vectors1(..., j, :) and vectors2(..., j, :) are
            the Rij vectors of the two atoms of pair j.

        Returns
        -------
        torch.Tensor
            Tensor of shape (..., pairs, `angular_sublength`) storing the
            subAEV terms.
        """
        # (..., pairs, 3, 1, 1, 1, 1): the xyz axis is dim -5 and the four
        # trailing axes broadcast against EtaA, Zeta, ShfA, ShfZ.
        vectors1 = vectors1.unsqueeze(
            -1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        vectors2 = vectors2.unsqueeze(
            -1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
        distances1 = vectors1.norm(2, dim=-5)
        distances2 = vectors2.norm(2, dim=-5)

        # 0.95 is multiplied to the cos values to prevent acos from
        # returning NaN (cosine_similarity can round slightly past +-1) and
        # to keep the derivative of acos finite.
        cos_angles = 0.95 * \
            torch.nn.functional.cosine_similarity(
                vectors1, vectors2, dim=-5)
        angles = torch.acos(cos_angles)

        fcj1 = _cutoff_cosine(distances1, self.Rca)
        fcj2 = _cutoff_cosine(distances2, self.Rca)
        factor1 = ((1 + torch.cos(angles - self.ShfZ)) / 2) ** self.Zeta
        factor2 = torch.exp(-self.EtaA *
                            ((distances1 + distances2) / 2 - self.ShfA) ** 2)
        ret = 2 * factor1 * factor2 * fcj1 * fcj2
        # ret now has shape (..., pairs, EtaA, Zeta, ShfA, ShfZ); flatten the
        # last 4 dimensions to view the subAEV as one dimension vector
        return ret.flatten(start_dim=-4)

    def terms_and_indices(self, species, coordinates):
        """Compute radial and angular subAEV terms, and original indices.

        Terms will be sorted according to their distances to central atoms,
        and only the neighbor slots that lie within the cutoff radius for at
        least one center atom are kept. The returned indices give the
        original (unsorted) index of each kept neighbor.

        Parameters
        ----------
        species : torch.Tensor
            The tensor that specifies the species of atoms in the molecule.
            The tensor must have shape (conformations, atoms); -1 marks
            padding atoms.
        coordinates : torch.Tensor
            The tensor that specifies the xyz coordinates of atoms in the
            molecule. The tensor must have shape (conformations, atoms, 3)

        Returns
        -------
        (radial_terms, angular_terms, indices_r, indices_a)
        radial_terms : torch.Tensor
            Tensor shaped (conformations, atoms, neighbors, `radial_sublength`)
            for the (unsummed) radial subAEV terms.
        angular_terms : torch.Tensor
            Tensor of shape (conformations, atoms, pairs, `angular_sublength`)
            for the (unsummed) angular subAEV terms.
        indices_r : torch.Tensor
            Tensor of shape (conformations, atoms, neighbors).
            Let l = indices_r(i,j,k), then this means that
            radial_terms(i,j,k,:) is in the subAEV term of conformation i
            between atom j and atom l.
        indices_a : torch.Tensor
            Same as indices_r, except that the cutoff radius is Rca instead of
            Rcr.
        """
        # Shape (conformations, atoms, atoms, 3) storing Rij vectors
        vec = coordinates.unsqueeze(2) - coordinates.unsqueeze(1)

        # Shape (conformations, atoms, atoms) storing Rij distances
        distances = vec.norm(2, -1)

        # Padding neighbors are pushed infinitely far away so that they fall
        # outside every cutoff and sort to the end.
        padding_mask = (species == -1).unsqueeze(1)
        distances = torch.where(
            padding_mask,
            torch.tensor(math.inf, dtype=self.EtaR.dtype,
                         device=self.EtaR.device),
            distances)

        distances, indices = distances.sort(-1)

        # min over all (conformation, center atom) rows of the k-th nearest
        # distance. Because every row is sorted, this is non-decreasing in k,
        # so the in-cutoff slots form a contiguous prefix. Slot 0 holds each
        # real atom's distance to itself (0) and is dropped with [1:].
        # NOTE(review): for a padding center atom, slot 0 is a real neighbor
        # instead; presumably padding AEVs are ignored downstream — confirm.
        min_distances, _ = distances.flatten(end_dim=1).min(0)
        inRcr = (min_distances <= self.Rcr).nonzero().flatten()[
            1:]  # TODO: can we use something like find_first?
        inRca = (min_distances <= self.Rca).nonzero().flatten()[1:]

        distances = distances.index_select(-1, inRcr)
        indices_r = indices.index_select(-1, inRcr)
        radial_terms = self.radial_subaev_terms(distances)

        indices_a = indices.index_select(-1, inRca)
        new_shape = list(indices_a.shape) + [3]
        # TODO: can we add something like expand_dim(dim=0, repeat=3)
        _indices_a = indices_a.unsqueeze(-1).expand(*new_shape)
        # TODO: can we make gather broadcast??
        vec = vec.gather(-2, _indices_a)
        # TODO: can we move combinations to ATen?
        vec = self.combinations(vec, -2)
        angular_terms = self.angular_subaev_terms(*vec)

        return radial_terms, angular_terms, indices_r, indices_a

    def combinations(self, tensor, dim=0):
        """Select every unordered pair of distinct entries along ``dim``.

        Parameters
        ----------
        tensor : torch.Tensor
            Tensor whose entries along ``dim`` are paired up.
        dim : int
            Dimension along which pairs are formed.

        Returns
        -------
        (torch.Tensor, torch.Tensor)
            Two tensors in which the size of ``dim`` is replaced by
            ``n * (n - 1) / 2``. Entry p of the first tensor and entry p of
            the second tensor form pair p. Pairs are (j, i) for every i < j,
            enumerated with i as the outer loop. The same ``n`` always yields
            the same order, so masks and vectors stay aligned.
        """
        n = tensor.shape[dim]
        r = torch.arange(n, dtype=torch.long, device=tensor.device)
        grid_x, grid_y = torch.meshgrid([r, r])
        # Build the strict upper-triangle mask once, on the same device as
        # the index grids (masked_select needs both on one device).
        upper = torch.triu(torch.ones(n, n, device=tensor.device),
                           diagonal=1) == 1
        index1 = grid_y.masked_select(upper)
        index2 = grid_x.masked_select(upper)
        return tensor.index_select(dim, index1), \
            tensor.index_select(dim, index2)

    def compute_mask_r(self, species, indices_r):
        """Partition indices according to their species, radial part

        Parameters
        ----------
        species : torch.Tensor
            Long tensor of shape (conformations, atoms, atoms) where
            species(i, j, l) is the species of atom l of conformation i
            (``forward`` passes the species expanded along dimension 1).
            Padding atoms have species -1.
        indices_r : torch.Tensor
            Tensor of shape (conformations, atoms, neighbors).
            Let l = indices_r(i,j,k), then this means that
            radial_terms(i,j,k,:) is in the subAEV term of conformation i
            between atom j and atom l.

        Returns
        -------
        torch.Tensor
            Mask of shape (conformations, atoms, neighbors, all species);
            entry s of the last dimension is set when the neighbor has
            species s. Padding neighbors (species -1) match no species.
        """
        # (conformations, atoms, neighbors): species of each neighbor
        species_r = species.gather(-1, indices_r)
        mask_r = (species_r.unsqueeze(-1) ==
                  torch.arange(len(self.species), device=self.EtaR.device))
        return mask_r

    def compute_mask_a(self, species, indices_a, present_species):
        """Partition indices according to their species, angular part

        Parameters
        ----------
        species : torch.Tensor
            Long tensor of shape (conformations, atoms, atoms); same layout
            as in `compute_mask_r`.
        indices_a : torch.Tensor
            Tensor of shape (conformations, atoms, neighbors) holding the
            original indices of the neighbors within Rca.
        present_species : torch.Tensor
            Long tensor for the species, already uniqued.

        Returns
        -------
        torch.Tensor
            Mask of shape (conformations, atoms, pairs, present species,
            present species). Entry (..., p, u1, u2) is set when pair p
            consists of one atom of species present_species[u1] and one of
            present_species[u2], in either order.
        """
        # (conformations, atoms, neighbors): species of each neighbor
        species_a = species.gather(-1, indices_a)
        species_a1, species_a2 = self.combinations(species_a, -1)
        # (conformations, atoms, pairs, present species, 1)
        mask_a1 = (species_a1.unsqueeze(-1) == present_species).unsqueeze(-1)
        # (conformations, atoms, pairs, 1, present species)
        mask_a2 = (species_a2.unsqueeze(-1).unsqueeze(-1) == present_species)
        mask = mask_a1 & mask_a2
        # A pair belongs to species pair (s1, s2) regardless of the order of
        # its two atoms, so symmetrize over the last two dimensions.
        return mask | mask.transpose(-1, -2)

    def assemble(self, radial_terms, angular_terms, present_species,
                 mask_r, mask_a):
        """Assemble radial and angular AEV from computed terms according
        to the given partition information.

        Parameters
        ----------
        radial_terms : torch.Tensor
            Tensor shaped (conformations, atoms, neighbors, `radial_sublength`)
            for the (unsummed) radial subAEV terms.
        angular_terms : torch.Tensor
            Tensor of shape (conformations, atoms, pairs, `angular_sublength`)
            for the (unsummed) angular subAEV terms.
        present_species : torch.Tensor
            Long tensor for species of atoms present in the molecules.
        mask_r : torch.Tensor
            Tensor of shape (conformations, atoms, neighbors, all species)
            storing the mask for each species.
        mask_a : torch.Tensor
            Tensor of shape (conformations, atoms, pairs, present species,
            present species) storing the mask for each pair.

        Returns
        -------
        (torch.Tensor, torch.Tensor)
            Returns (radial AEV, angular AEV), both are pytorch tensor of
            `dtype`. The radial AEV must be of shape (conformations, atoms,
            radial_length) The angular AEV must be of shape (conformations,
            atoms, angular_length)
        """
        conformations = radial_terms.shape[0]
        atoms = radial_terms.shape[1]

        # assemble radial subaev: summing the masked neighbor terms gives
        # (conformations, atoms, all species, radial_sublength), and the
        # per-species blocks are then concatenated along the last dimension.
        present_radial_aevs = (
            radial_terms.unsqueeze(-2) *
            mask_r.unsqueeze(-1).type(radial_terms.dtype)
        ).sum(-3)
        radial_aevs = present_radial_aevs.flatten(start_dim=2)

        # assemble angular subaev
        # species id -> its position along the present-species dims of mask_a
        rev_indices = {s: i for i, s in enumerate(present_species.tolist())}
        angular_aevs = []
        # species pairs absent from the batch contribute an all-zero block
        zero_angular_subaev = torch.zeros(
            # TODO: can we make stack and cat broadcast?
            conformations, atoms, self.angular_sublength,
            dtype=self.EtaR.dtype, device=self.EtaR.device)
        for s1, s2 in itertools.combinations_with_replacement(
                range(len(self.species)), 2):
            if s1 in rev_indices and s2 in rev_indices:
                i1 = rev_indices[s1]
                i2 = rev_indices[s2]
                mask = mask_a[..., i1, i2].unsqueeze(-1).type(self.EtaR.dtype)
                subaev = (angular_terms * mask).sum(-2)
            else:
                subaev = zero_angular_subaev
            angular_aevs.append(subaev)

        return radial_aevs, torch.cat(angular_aevs, dim=2)

    def forward(self, species_coordinates):
        """Compute the full AEV of every atom.

        Parameters
        ----------
        species_coordinates : (torch.LongTensor, torch.Tensor)
            Tuple ``(species, coordinates)``: species of shape
            (conformations, atoms) with -1 marking padding atoms, and
            coordinates of shape (conformations, atoms, 3).

        Returns
        -------
        (torch.LongTensor, torch.Tensor)
            ``species`` unchanged, and the full AEV of shape
            (conformations, atoms, aev_length): the radial part followed by
            the angular part.
        """
        species, coordinates = species_coordinates

        present_species = padding.present_species(species)

        # TODO: remove this workaround after gather support broadcasting
        atoms = coordinates.shape[1]
        species_ = species.unsqueeze(1).expand(-1, atoms, -1)

        radial_terms, angular_terms, indices_r, indices_a = \
            self.terms_and_indices(species, coordinates)
        mask_r = self.compute_mask_r(species_, indices_r)
        mask_a = self.compute_mask_a(species_, indices_a, present_species)

        radial, angular = self.assemble(radial_terms, angular_terms,
                                        present_species, mask_r, mask_a)
        fullaev = torch.cat([radial, angular], dim=2)
        return species, fullaev