# test_vibrational.py
import math
import os
import tempfile
import unittest

import ase
import ase.optimize
import ase.vibrations
import numpy
import torch
import torchani
from torchani.testing import TestCase


# Location of the water geometry in the test dataset, resolved against the
# real (symlink-free) directory of this file so it does not depend on the cwd.
# NOTE(review): nothing visible in this file reads `path`; confirm it is still needed.
path = os.path.join(
    os.path.dirname(os.path.realpath(__file__)),
    '../dataset/xyz_files/H2O.xyz',
)


class TestVibrational(TestCase):
    """Cross-check TorchANI's vibrational analysis against ASE for water."""

    def testVibrationalWavenumbers(self):
        """Compare H2O wavenumbers and normal modes from ASE and TorchANI.

        The molecule is relaxed with the ANI-1x ASE calculator. ASE's
        ``Vibrations`` then provides the reference frequencies and modes,
        while TorchANI derives them from ``torchani.utils.hessian`` followed by
        ``torchani.utils.vibrational_analysis``. The first six modes are
        dropped from both before comparing. Wavenumbers must agree to within
        2% relative tolerance. Each mode, up to an overall sign, must agree to
        within 0.02 in its largest absolute component.
        """
        model = torchani.models.ANI1x().double()
        # starting geometry: O-H bond length and H-O-H angle (converted to radians)
        d = 0.9575
        t = math.pi / 180 * 104.51
        molecule = ase.Atoms('H2O', positions=[
            (d, 0, 0),
            (d * math.cos(t), d * math.sin(t), 0),
            (0, 0, 0),
        ], calculator=model.ase())
        # Harmonic analysis is only meaningful at a stationary point, hence the
        # very tight force threshold. logfile=None keeps the optimizer quiet.
        opt = ase.optimize.BFGS(molecule, logfile=None)
        opt.run(fmax=1e-6)

        # compute vibrational frequencies by ASE. Its displacement cache is
        # kept in a private temporary directory. This means a stale cache left
        # in the working directory (e.g. by an interrupted run) is never
        # reused, and the files are removed even if something here raises.
        with tempfile.TemporaryDirectory() as tmpdir:
            vib = ase.vibrations.Vibrations(molecule, name=os.path.join(tmpdir, 'vib'))
            vib.run()
            # get_frequencies() may be complex-valued; keep only the real part.
            freq = torch.tensor(numpy.real(vib.get_frequencies()[6:]))
            # Stack into a single ndarray first: building a tensor from a list
            # of ndarrays is slow and makes torch emit a UserWarning.
            modes = torch.tensor(numpy.stack(
                [vib.get_mode(j) for j in range(6, 6 + len(freq))]))

        # compute vibrational by torchani
        species = model.species_to_tensor(molecule.get_chemical_symbols()).unsqueeze(0)
        coordinates = torch.from_numpy(molecule.get_positions()).unsqueeze(0).requires_grad_(True)
        # Take the per-atom masses from ASE so both analyses mass-weight the
        # Hessian identically. A hard-coded per-species table would also
        # silently depend on the model's species ordering.
        masses = torch.from_numpy(molecule.get_masses()).unsqueeze(0)
        _, energies = model((species, coordinates))
        hessian = torchani.utils.hessian(coordinates, energies=energies)
        freq2, modes2, _, _ = torchani.utils.vibrational_analysis(masses, hessian)
        freq2 = freq2[6:].float()
        modes2 = modes2[6:]
        self.assertEqual(freq, freq2, atol=0, rtol=0.02, exact_dtype=False)

        # Normal modes are only defined up to an overall sign. For each mode,
        # measure the max absolute deviation against both +modes2 and -modes2
        # and keep the smaller of the two.
        diff1 = (modes - modes2).abs().max(dim=-1).values.max(dim=-1).values
        diff2 = (modes + modes2).abs().max(dim=-1).values.max(dim=-1).values
        diff = torch.where(diff1 < diff2, diff1, diff2)
        self.assertLess(diff.max(), 0.02)

if __name__ == '__main__':
    # Allow running this test module directly as a script via unittest's runner.
    unittest.main()