# common_aev_test.py
import torch
import torchani
import os
from torchani.testing import TestCase


class _TestAEVBase(TestCase):
    """Shared fixture and assertion helper for AEV test cases.

    ``setUp`` builds an ``AEVComputer`` from a constants file located
    relative to this test file and records its ``radial_length``.
    ``assertAEVEqual`` uses that length to split an AEV tensor along its
    last dimension before comparing the two parts to expected values.
    """

    def setUp(self):
        """Create the AEV computer and the attributes the helper relies on."""
        # Resolve the directory of this file so the relative resource path
        # below does not depend on the current working directory.
        path = os.path.dirname(os.path.realpath(__file__))
        const_file = os.path.join(path, '../torchani/resources/ani-1x_8x/rHCNO-5.2R_16-3.5A_a4-8.params')  # noqa: E501
        consts = torchani.neurochem.Constants(const_file)
        self.aev_computer = torchani.AEVComputer(**consts)
        # Split point used by assertAEVEqual on the last dimension.
        self.radial_length = self.aev_computer.radial_length
        # Set to True to print expected vs. actual radial values while
        # debugging a failing comparison.
        self.debug = False

    def assertAEVEqual(self, expected_radial, expected_angular, aev):
        """Assert that ``aev`` matches the expected radial and angular parts.

        Args:
            expected_radial: expected values for the first
                ``self.radial_length`` entries of the last dimension.
            expected_angular: expected values for the remaining entries of
                the last dimension.
            aev: tensor whose last dimension is split at
                ``self.radial_length``.
        """
        radial = aev[..., :self.radial_length]
        angular = aev[..., self.radial_length:]
        if self.debug:
            # NOTE(review): ``aid`` looks like an atom index and the leading
            # 0 a batch/molecule index -- confirm against callers. The
            # indexing requires both tensors to be at least 3-dimensional.
            aid = 1
            print(torch.stack([expected_radial[0, aid, :], radial[0, aid, :]]))
        self.assertEqual(expected_radial, radial)
        self.assertEqual(expected_angular, angular)