common_aev_test.py 1.1 KB
Newer Older
Gao, Xiang's avatar
Gao, Xiang committed
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import random
import unittest
import torch
import torchani

# Default strict upper bound on the maximum absolute elementwise error that
# _TestAEVBase.assertAEVEqual accepts.
tolerance = 1e-5


class _TestAEVBase(unittest.TestCase):
    """Shared fixture and assertion helpers for AEV tests.

    NOTE(review): this base class defines no ``test_*`` methods itself;
    presumably concrete test cases subclass it — confirm against callers.
    """

    def setUp(self):
        """Build the ANI-1x AEV computer used by the test.

        unittest calls this before every test method, so the ANI1x model is
        constructed once per test.
        """
        ani1x = torchani.models.ANI1x()
        self.aev_computer = ani1x.aev_computer
        # Number of leading features in the last AEV dimension treated as the
        # radial part by assertAEVEqual; the remaining features are angular.
        self.radial_length = self.aev_computer.radial_length
        # When True, assertAEVEqual prints a radial comparison table.
        self.debug = False

    def transform(self, x):
        """Return ``x`` unchanged.

        Identity hook; presumably overridden by subclasses to convert inputs
        (e.g. dtype/device) — verify against subclasses.
        """
        return x

    def random_skip(self, prob=0):
        """Return True with probability ``prob`` (never, for the default 0)."""
        return random.random() < prob

    def assertAEVEqual(self, expected_radial, expected_angular, aev, tolerance=tolerance):
        """Assert that ``aev`` matches the expected radial and angular parts.

        The last dimension of ``aev`` is split at ``self.radial_length``:
        features before that index are compared with ``expected_radial`` and
        features from it onward with ``expected_angular``.

        Args:
            expected_radial: Tensor of expected radial features.
            expected_angular: Tensor of expected angular features.
            aev: Computed AEV tensor (radial features followed by angular).
            tolerance: Strict upper bound on the maximum absolute elementwise
                error allowed in each part.

        Raises:
            AssertionError: If either part's shape differs from its expected
                shape, or its maximum absolute error is not below
                ``tolerance`` (a NaN error also fails).
        """
        radial = aev[..., :self.radial_length]
        angular = aev[..., self.radial_length:]
        # Check shapes explicitly: the subtractions below broadcast, so a
        # mis-shaped expectation could otherwise be compared against many
        # entries of the actual tensor and pass silently.
        self.assertEqual(tuple(expected_radial.shape), tuple(radial.shape),
                         'radial AEV shape mismatch')
        self.assertEqual(tuple(expected_angular.shape), tuple(angular.shape),
                         'angular AEV shape mismatch')
        radial_diff = expected_radial - radial
        if self.debug:
            # Columns: expected, actual, |diff| for index 1 along dim 1 of
            # the first entry along dim 0 (presumably atom 1 of the first
            # molecule); requires >= 3 dims and >= 2 entries along dim 1.
            aid = 1
            print(torch.stack([expected_radial[0, aid, :], radial[0, aid, :], radial_diff.abs()[0, aid, :]], dim=1))
        angular_diff = expected_angular - angular
        self.assertLess(self._max_abs_error(radial_diff), tolerance,
                        'radial AEV max error too large')
        self.assertLess(self._max_abs_error(angular_diff), tolerance,
                        'angular AEV max error too large')

    @staticmethod
    def _max_abs_error(diff):
        """Return the largest absolute entry of ``diff`` as a float (0.0 if empty)."""
        # torch.max raises on an empty tensor; an empty difference has no
        # entries that can disagree, so report zero error.
        if diff.numel() == 0:
            return 0.0
        return torch.max(torch.abs(diff)).item()