# test_energyshifter.py
# Unit tests for torchani.EnergyShifter (author: Gao, Xiang).
import torch
import torchani
import unittest
import random


class TestEnergyShifter(unittest.TestCase):
    """Consistency tests for ``torchani.EnergyShifter``."""

    # Number of random molecules generated per test, and the half-open
    # range [MIN_ATOMS, MAX_ATOMS) their atom counts are drawn from.
    NUM_MOLECULES = 10
    MIN_ATOMS, MAX_ATOMS = 5, 30

    def setUp(self):
        """Build the species list, input preparer and energy shifter."""
        self.tol = 1e-5
        # Dedicated, seeded RNG: molecules are random but reproducible, so a
        # failure can be replayed, and the global ``random`` state relied on
        # by other tests is left untouched.
        self.rng = random.Random(0)
        self.species = torchani.AEVComputer().species
        self.prepare = torchani.PrepareInput(self.species)
        self.shift_energy = torchani.EnergyShifter(self.species)

    def testSAEMatch(self):
        """Check that ``sae_from_list`` agrees with ``sae_from_tensor``.

        For several random molecules, the per-molecule value returned by
        ``sae_from_list`` (given the raw species list) must match, within
        ``self.tol``, the value ``sae_from_tensor`` computes on the padded
        batch built from the same molecules.
        """
        species_coordinates = []
        saes = []
        for _ in range(self.NUM_MOLECULES):
            k = self.rng.randrange(self.MIN_ATOMS, self.MAX_ATOMS)
            species = self.rng.choices(self.species, k=k)
            # Neither sae_* call receives coordinates; zeros (instead of
            # ``torch.empty``) avoid pushing uninitialized memory, which may
            # hold NaN/inf, through ``prepare`` and padding.
            coordinates = torch.zeros(1, k, 3)
            species_coordinates.append(self.prepare((species, coordinates)))
            saes.append(self.shift_energy.sae_from_list(species))
        species, _ = torchani.padding.pad_and_batch(species_coordinates)
        saes_ = self.shift_energy.sae_from_tensor(species)
        saes = torch.tensor(saes, dtype=saes_.dtype, device=saes_.device)
        # Without this, broadcasting could silently hide a shape mismatch
        # between the per-molecule list and the batched result.
        self.assertEqual(saes.shape, saes_.shape)
        self.assertLess((saes - saes_).abs().max().item(), self.tol)


if __name__ == "__main__":
    # Executing this file directly runs the unittest test cases in it.
    unittest.main()