Unverified Commit 2cc64cb7 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Add 3D structures downloaded from NIST to unit test set (#146)

parent 91cec854
*.h5 filter=lfs diff=lfs merge=lfs -text
tests/test_data/* filter=lfs diff=lfs merge=lfs -text
\ No newline at end of file
tests/test_data/ANI1_subset/* filter=lfs diff=lfs merge=lfs -text
tests/test_data/NIST/* filter=lfs diff=lfs merge=lfs -text
tools/diverse_test_set/result.json filter=lfs diff=lfs merge=lfs -text
queue:
name: Hosted Linux Preview
name: Hosted Ubuntu 1604
timeoutInMinutes: 300
variables:
......
queue:
name: Hosted Linux Preview
name: Hosted Ubuntu 1604
timeoutInMinutes: 10
variables:
......
queue:
name: Hosted Linux Preview
name: Hosted Ubuntu 1604
timeoutInMinutes: 300
variables:
......
queue:
name: Hosted Linux Preview
timeoutInMinutes: 300
name: Hosted Ubuntu 1604
timeoutInMinutes: 6000
variables:
python.version: '3.7'
......
queue:
name: Hosted Linux Preview
name: Hosted Ubuntu 1604
timeoutInMinutes: 30
variables:
......
......@@ -3,6 +3,7 @@ import torchani
import unittest
import os
import pickle
import random
path = os.path.dirname(os.path.realpath(__file__))
N = 97
......@@ -16,6 +17,12 @@ class TestAEV(unittest.TestCase):
self.radial_length = self.aev_computer.radial_length()
self.tolerance = 1e-5
def random_skip(self):
return False
def transform(self, x):
return x
def _assertAEVEqual(self, expected_radial, expected_angular, aev):
radial = aev[..., :self.radial_length]
angular = aev[..., self.radial_length:]
......@@ -28,10 +35,18 @@ class TestAEV(unittest.TestCase):
def testIsomers(self):
for i in range(N):
datafile = os.path.join(path, 'test_data/{}'.format(i))
datafile = os.path.join(path, 'test_data/ANI1_subset/{}'.format(i))
with open(datafile, 'rb') as f:
coordinates, species, expected_radial, expected_angular, _, _ \
= pickle.load(f)
coordinates = torch.from_numpy(coordinates)
species = torch.from_numpy(species)
expected_radial = torch.from_numpy(expected_radial)
expected_angular = torch.from_numpy(expected_angular)
coordinates = self.transform(coordinates)
species = self.transform(species)
expected_radial = self.transform(expected_radial)
expected_angular = self.transform(expected_angular)
_, aev = self.aev_computer((species, coordinates))
self._assertAEVEqual(expected_radial, expected_angular, aev)
......@@ -39,9 +54,17 @@ class TestAEV(unittest.TestCase):
species_coordinates = []
radial_angular = []
for i in range(N):
datafile = os.path.join(path, 'test_data/{}'.format(i))
datafile = os.path.join(path, 'test_data/ANI1_subset/{}'.format(i))
with open(datafile, 'rb') as f:
coordinates, species, radial, angular, _, _ = pickle.load(f)
coordinates = torch.from_numpy(coordinates)
species = torch.from_numpy(species)
radial = torch.from_numpy(radial)
angular = torch.from_numpy(angular)
coordinates = self.transform(coordinates)
species = self.transform(species)
radial = self.transform(radial)
angular = self.transform(angular)
species_coordinates.append((species, coordinates))
radial_angular.append((radial, angular))
species, coordinates = torchani.utils.pad_coordinates(
......@@ -55,6 +78,20 @@ class TestAEV(unittest.TestCase):
start += conformations
self._assertAEVEqual(expected_radial, expected_angular, aev_)
def testNIST(self):
datafile = os.path.join(path, 'test_data/NIST/all')
with open(datafile, 'rb') as f:
data = pickle.load(f)
for coordinates, species, radial, angular, _, _ in data:
if self.random_skip():
continue
coordinates = torch.from_numpy(coordinates).to(torch.float)
species = torch.from_numpy(species)
radial = torch.from_numpy(radial).to(torch.float)
angular = torch.from_numpy(angular).to(torch.float)
_, aev = self.aev_computer((species, coordinates))
self._assertAEVEqual(radial, angular, aev)
class TestAEVASENeighborList(TestAEV):
......@@ -62,6 +99,14 @@ class TestAEVASENeighborList(TestAEV):
super(TestAEVASENeighborList, self).setUp()
self.aev_computer.neighborlist = torchani.ase.NeighborList()
def transform(self, x):
"""To reduce the size of test cases for faster test speed"""
return x[:2, ...]
def random_skip(self):
"""To reduce the size of test cases for faster test speed"""
return random.random() < 0.95
if __name__ == '__main__':
unittest.main()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment