# test_vibrational.py
import os
import math
import unittest
import torch
import torchani
import ase
import ase.optimize
import ase.vibrations
import numpy


# Absolute-directory-relative path to a water (H2O) XYZ geometry file in the
# repository's dataset folder.
# NOTE(review): TestVibrational in this file builds H2O from explicit
# coordinates and does not read this path; kept in case other code imports it.
path = os.path.join(
    os.path.dirname(os.path.realpath(__file__)),
    '../dataset/xyz_files/H2O.xyz',
)


class TestVibrational(unittest.TestCase):

    def testVibrationalWavenumbers(self):
        model = torchani.models.ANI1x().double()
        d = 0.9575
        t = math.pi / 180 * 104.51
        molecule = ase.Atoms('H2O', positions=[
            (d, 0, 0),
            (d * math.cos(t), d * math.sin(t), 0),
            (0, 0, 0),
        ], calculator=model.ase())
        opt = ase.optimize.BFGS(molecule)
        opt.run(fmax=1e-6)
        masses = torch.tensor([1.008, 12.011, 14.007, 15.999], dtype=torch.double)
        # compute vibrational frequencies by ASE
        vib = ase.vibrations.Vibrations(molecule)
        vib.run()
33
34
35
36
        freq = torch.tensor([numpy.real(x) for x in vib.get_frequencies()[6:]])
        modes = []
        for j in range(6, 6 + len(freq)):
            modes.append(vib.get_mode(j))
Ignacio Pickering's avatar
Ignacio Pickering committed
37
        vib.clean()
38
        modes = torch.tensor(modes)
39
40
41
        # compute vibrational by torchani
        species = model.species_to_tensor(molecule.get_chemical_symbols()).unsqueeze(0)
        coordinates = torch.from_numpy(molecule.get_positions()).unsqueeze(0).requires_grad_(True)
42
43
        _, energies = model((species, coordinates))
        hessian = torchani.utils.hessian(coordinates, energies=energies)
44
        freq2, modes2, _, _ = torchani.utils.vibrational_analysis(masses[species], hessian)
45
46
        freq2 = freq2[6:].float()
        modes2 = modes2[6:]
47
48
49
        ratio = freq2 / freq
        self.assertLess((ratio - 1).abs().max(), 0.02)

50
51
52
53
54
        diff1 = (modes - modes2).abs().max(dim=-1).values.max(dim=-1).values
        diff2 = (modes + modes2).abs().max(dim=-1).values.max(dim=-1).values
        diff = torch.where(diff1 < diff2, diff1, diff2)
        self.assertLess(diff.max(), 0.02)

55
56
57

if __name__ == '__main__':
    # Script entry point: run this module's tests when the file is executed
    # directly; an importing test runner skips this branch.
    unittest.main()