Unverified Commit ba02a674 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Add Benzene and Tripeptide MD trajectory to unit test (#213)

parent db36c1a0
......@@ -2,4 +2,7 @@
tests/test_data/ANI1_subset/* filter=lfs diff=lfs merge=lfs -text
tests/test_data/NIST/* filter=lfs diff=lfs merge=lfs -text
tests/test_data/NeuroChemOptimized/* filter=lfs diff=lfs merge=lfs -text
tests/test_data/benzene-md/* filter=lfs diff=lfs merge=lfs -text
tests/test_data/tripeptide-md/* filter=lfs diff=lfs merge=lfs -text
tools/generate-unit-test-expect/nist-dataset/result.json filter=lfs diff=lfs merge=lfs -text
......@@ -11,6 +11,7 @@ import math
path = os.path.dirname(os.path.realpath(__file__))
N = 97
tolerance = 1e-5
class TestAEV(unittest.TestCase):
......@@ -19,7 +20,7 @@ class TestAEV(unittest.TestCase):
builtins = torchani.neurochem.Builtins()
self.aev_computer = builtins.aev_computer
self.radial_length = self.aev_computer.radial_length
self.tolerance = 1e-5
self.debug = False
def random_skip(self, prob=0):
return random.random() < prob
......@@ -27,15 +28,18 @@ class TestAEV(unittest.TestCase):
def transform(self, x):
return x
def _assertAEVEqual(self, expected_radial, expected_angular, aev):
def assertAEVEqual(self, expected_radial, expected_angular, aev, tolerance=tolerance):
radial = aev[..., :self.radial_length]
angular = aev[..., self.radial_length:]
radial_diff = expected_radial - radial
if self.debug:
aid = 1
print(torch.stack([expected_radial[0, aid, :], radial[0, aid, :], radial_diff.abs()[0, aid, :]], dim=1))
radial_max_error = torch.max(torch.abs(radial_diff)).item()
angular_diff = expected_angular - angular
angular_max_error = torch.max(torch.abs(angular_diff)).item()
self.assertLess(radial_max_error, self.tolerance)
self.assertLess(angular_max_error, self.tolerance)
self.assertLess(radial_max_error, tolerance)
self.assertLess(angular_max_error, tolerance)
def testIsomers(self):
for i in range(N):
......@@ -52,7 +56,46 @@ class TestAEV(unittest.TestCase):
expected_radial = self.transform(expected_radial)
expected_angular = self.transform(expected_angular)
_, aev = self.aev_computer((species, coordinates))
self._assertAEVEqual(expected_radial, expected_angular, aev)
self.assertAEVEqual(expected_radial, expected_angular, aev)
@unittest.skipIf(True, "WIP")
def testBenzeneMD(self):
for i in range(100):
datafile = os.path.join(path, 'test_data/benzene-md/{}.dat'.format(i))
with open(datafile, 'rb') as f:
coordinates, species, expected_radial, expected_angular, _, _, cell, pbc \
= pickle.load(f)
coordinates = torch.from_numpy(coordinates).float().unsqueeze(0)
species = torch.from_numpy(species).unsqueeze(0)
expected_radial = torch.from_numpy(expected_radial).float().unsqueeze(0)
expected_angular = torch.from_numpy(expected_angular).float().unsqueeze(0)
cell = torch.from_numpy(cell).float()
pbc = torch.from_numpy(pbc)
coordinates = torchani.utils.map2central(cell, coordinates, pbc)
coordinates = self.transform(coordinates)
species = self.transform(species)
expected_radial = self.transform(expected_radial)
expected_angular = self.transform(expected_angular)
_, aev = self.aev_computer((species, coordinates, cell, pbc))
self.assertAEVEqual(expected_radial, expected_angular, aev)
def testTripeptideMD(self):
tol = 5e-6
for i in range(100):
datafile = os.path.join(path, 'test_data/tripeptide-md/{}.dat'.format(i))
with open(datafile, 'rb') as f:
coordinates, species, expected_radial, expected_angular, _, _, _, _ \
= pickle.load(f)
coordinates = torch.from_numpy(coordinates).float().unsqueeze(0)
species = torch.from_numpy(species).unsqueeze(0)
expected_radial = torch.from_numpy(expected_radial).float().unsqueeze(0)
expected_angular = torch.from_numpy(expected_angular).float().unsqueeze(0)
coordinates = self.transform(coordinates)
species = self.transform(species)
expected_radial = self.transform(expected_radial)
expected_angular = self.transform(expected_angular)
_, aev = self.aev_computer((species, coordinates))
self.assertAEVEqual(expected_radial, expected_angular, aev, tol)
def testPadding(self):
species_coordinates = []
......@@ -80,7 +123,7 @@ class TestAEV(unittest.TestCase):
atoms = expected_radial.shape[1]
aev_ = aev[start:(start + conformations), 0:atoms]
start += conformations
self._assertAEVEqual(expected_radial, expected_angular, aev_)
self.assertAEVEqual(expected_radial, expected_angular, aev_)
def testNIST(self):
datafile = os.path.join(path, 'test_data/NIST/all')
......@@ -94,7 +137,7 @@ class TestAEV(unittest.TestCase):
radial = torch.from_numpy(radial).to(torch.float)
angular = torch.from_numpy(angular).to(torch.float)
_, aev = self.aev_computer((species, coordinates))
self._assertAEVEqual(radial, angular, aev)
self.assertAEVEqual(radial, angular, aev)
@unittest.skipIf(not torch.cuda.is_available(), "Too slow on CPU")
def testGradient(self):
......
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment