Unverified Commit ba02a674 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Add Benzene and Tripeptide MD trajectory to unit test (#213)

parent db36c1a0
...@@ -2,4 +2,7 @@ ...@@ -2,4 +2,7 @@
tests/test_data/ANI1_subset/* filter=lfs diff=lfs merge=lfs -text tests/test_data/ANI1_subset/* filter=lfs diff=lfs merge=lfs -text
tests/test_data/NIST/* filter=lfs diff=lfs merge=lfs -text tests/test_data/NIST/* filter=lfs diff=lfs merge=lfs -text
tests/test_data/NeuroChemOptimized/* filter=lfs diff=lfs merge=lfs -text tests/test_data/NeuroChemOptimized/* filter=lfs diff=lfs merge=lfs -text
tests/test_data/benzene-md/* filter=lfs diff=lfs merge=lfs -text
tests/test_data/tripeptide-md/* filter=lfs diff=lfs merge=lfs -text
tools/generate-unit-test-expect/nist-dataset/result.json filter=lfs diff=lfs merge=lfs -text tools/generate-unit-test-expect/nist-dataset/result.json filter=lfs diff=lfs merge=lfs -text
...@@ -11,6 +11,7 @@ import math ...@@ -11,6 +11,7 @@ import math
path = os.path.dirname(os.path.realpath(__file__)) path = os.path.dirname(os.path.realpath(__file__))
N = 97 N = 97
tolerance = 1e-5
class TestAEV(unittest.TestCase): class TestAEV(unittest.TestCase):
...@@ -19,7 +20,7 @@ class TestAEV(unittest.TestCase): ...@@ -19,7 +20,7 @@ class TestAEV(unittest.TestCase):
builtins = torchani.neurochem.Builtins() builtins = torchani.neurochem.Builtins()
self.aev_computer = builtins.aev_computer self.aev_computer = builtins.aev_computer
self.radial_length = self.aev_computer.radial_length self.radial_length = self.aev_computer.radial_length
self.tolerance = 1e-5 self.debug = False
def random_skip(self, prob=0): def random_skip(self, prob=0):
return random.random() < prob return random.random() < prob
...@@ -27,15 +28,18 @@ class TestAEV(unittest.TestCase): ...@@ -27,15 +28,18 @@ class TestAEV(unittest.TestCase):
def transform(self, x): def transform(self, x):
return x return x
def _assertAEVEqual(self, expected_radial, expected_angular, aev): def assertAEVEqual(self, expected_radial, expected_angular, aev, tolerance=tolerance):
radial = aev[..., :self.radial_length] radial = aev[..., :self.radial_length]
angular = aev[..., self.radial_length:] angular = aev[..., self.radial_length:]
radial_diff = expected_radial - radial radial_diff = expected_radial - radial
if self.debug:
aid = 1
print(torch.stack([expected_radial[0, aid, :], radial[0, aid, :], radial_diff.abs()[0, aid, :]], dim=1))
radial_max_error = torch.max(torch.abs(radial_diff)).item() radial_max_error = torch.max(torch.abs(radial_diff)).item()
angular_diff = expected_angular - angular angular_diff = expected_angular - angular
angular_max_error = torch.max(torch.abs(angular_diff)).item() angular_max_error = torch.max(torch.abs(angular_diff)).item()
self.assertLess(radial_max_error, self.tolerance) self.assertLess(radial_max_error, tolerance)
self.assertLess(angular_max_error, self.tolerance) self.assertLess(angular_max_error, tolerance)
def testIsomers(self): def testIsomers(self):
for i in range(N): for i in range(N):
...@@ -52,7 +56,46 @@ class TestAEV(unittest.TestCase): ...@@ -52,7 +56,46 @@ class TestAEV(unittest.TestCase):
expected_radial = self.transform(expected_radial) expected_radial = self.transform(expected_radial)
expected_angular = self.transform(expected_angular) expected_angular = self.transform(expected_angular)
_, aev = self.aev_computer((species, coordinates)) _, aev = self.aev_computer((species, coordinates))
self._assertAEVEqual(expected_radial, expected_angular, aev) self.assertAEVEqual(expected_radial, expected_angular, aev)
@unittest.skipIf(True, "WIP")
def testBenzeneMD(self):
for i in range(100):
datafile = os.path.join(path, 'test_data/benzene-md/{}.dat'.format(i))
with open(datafile, 'rb') as f:
coordinates, species, expected_radial, expected_angular, _, _, cell, pbc \
= pickle.load(f)
coordinates = torch.from_numpy(coordinates).float().unsqueeze(0)
species = torch.from_numpy(species).unsqueeze(0)
expected_radial = torch.from_numpy(expected_radial).float().unsqueeze(0)
expected_angular = torch.from_numpy(expected_angular).float().unsqueeze(0)
cell = torch.from_numpy(cell).float()
pbc = torch.from_numpy(pbc)
coordinates = torchani.utils.map2central(cell, coordinates, pbc)
coordinates = self.transform(coordinates)
species = self.transform(species)
expected_radial = self.transform(expected_radial)
expected_angular = self.transform(expected_angular)
_, aev = self.aev_computer((species, coordinates, cell, pbc))
self.assertAEVEqual(expected_radial, expected_angular, aev)
def testTripeptideMD(self):
tol = 5e-6
for i in range(100):
datafile = os.path.join(path, 'test_data/tripeptide-md/{}.dat'.format(i))
with open(datafile, 'rb') as f:
coordinates, species, expected_radial, expected_angular, _, _, _, _ \
= pickle.load(f)
coordinates = torch.from_numpy(coordinates).float().unsqueeze(0)
species = torch.from_numpy(species).unsqueeze(0)
expected_radial = torch.from_numpy(expected_radial).float().unsqueeze(0)
expected_angular = torch.from_numpy(expected_angular).float().unsqueeze(0)
coordinates = self.transform(coordinates)
species = self.transform(species)
expected_radial = self.transform(expected_radial)
expected_angular = self.transform(expected_angular)
_, aev = self.aev_computer((species, coordinates))
self.assertAEVEqual(expected_radial, expected_angular, aev, tol)
def testPadding(self): def testPadding(self):
species_coordinates = [] species_coordinates = []
...@@ -80,7 +123,7 @@ class TestAEV(unittest.TestCase): ...@@ -80,7 +123,7 @@ class TestAEV(unittest.TestCase):
atoms = expected_radial.shape[1] atoms = expected_radial.shape[1]
aev_ = aev[start:(start + conformations), 0:atoms] aev_ = aev[start:(start + conformations), 0:atoms]
start += conformations start += conformations
self._assertAEVEqual(expected_radial, expected_angular, aev_) self.assertAEVEqual(expected_radial, expected_angular, aev_)
def testNIST(self): def testNIST(self):
datafile = os.path.join(path, 'test_data/NIST/all') datafile = os.path.join(path, 'test_data/NIST/all')
...@@ -94,7 +137,7 @@ class TestAEV(unittest.TestCase): ...@@ -94,7 +137,7 @@ class TestAEV(unittest.TestCase):
radial = torch.from_numpy(radial).to(torch.float) radial = torch.from_numpy(radial).to(torch.float)
angular = torch.from_numpy(angular).to(torch.float) angular = torch.from_numpy(angular).to(torch.float)
_, aev = self.aev_computer((species, coordinates)) _, aev = self.aev_computer((species, coordinates))
self._assertAEVEqual(radial, angular, aev) self.assertAEVEqual(radial, angular, aev)
@unittest.skipIf(not torch.cuda.is_available(), "Too slow on CPU") @unittest.skipIf(not torch.cuda.is_available(), "Too slow on CPU")
def testGradient(self): def testGradient(self):
......
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
This source diff could not be displayed because it is stored in LFS. You can view the blob instead.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment