Unverified Commit 23c9816c authored by Jinze Xue's avatar Jinze Xue Committed by GitHub
Browse files

CUAEV backward (#554)



* preparation

* radial preparation 30%

* radial backward kernel done

* reuse Gmr (exp part) result for gradient

* radial kernel every block run by column major, to avoid atomicAdd waiting

* apply code review

* static_cast

* implicit cast

* format

* angular preparation

* angular backward works, but slow, AtomicAdd should be avoided

* angular opti: use share memory to avoid AtomicAdd

* format

* equation optimization

* remove unnecessary shared mem for atomi

* remove a lot (warpsize * nbr) unnecessary shared mem for atomj

* format

* update

* clean

* fix

* fix

* test file

* fix
Co-authored-by: default avatarGao, Xiang <qasdfgtyuiop@gmail.com>
parent 7cf6823a
......@@ -65,24 +65,70 @@ class TestCUAEV(TestCase):
[-4.4978600, 0.8211300, 0.5604100],
[-4.4978700, -0.8000100, 0.4155600],
[0.00000000, -0.00000000, -0.00000000]]
], requires_grad=True, device=self.device)
], device=self.device)
species = torch.tensor([[1, 0, 0, 0, 0], [2, 0, 0, 0, -1]], device=self.device)
_, aev = self.aev_computer((species, coordinates))
_, cu_aev = self.cuaev_computer((species, coordinates))
self.assertEqual(cu_aev, aev)
def testSimpleBackward(self):
coordinates = torch.tensor([
[[0.03192167, 0.00638559, 0.01301679],
[-0.83140486, 0.39370209, -0.26395324],
[-0.66518241, -0.84461308, 0.20759389],
[0.45554739, 0.54289633, 0.81170881],
[0.66091919, -0.16799635, -0.91037834]],
[[-4.1862600, 0.0575700, -0.0381200],
[-3.1689400, 0.0523700, 0.0200000],
[-4.4978600, 0.8211300, 0.5604100],
[-4.4978700, -0.8000100, 0.4155600],
[0.00000000, -0.00000000, -0.00000000]]
], requires_grad=True, device=self.device)
species = torch.tensor([[1, 0, 0, 0, 0], [2, 0, 0, 0, -1]], device=self.device)
_, aev = self.aev_computer((species, coordinates))
aev.backward(torch.ones_like(aev))
aev_grad = coordinates.grad
coordinates = coordinates.clone().detach()
coordinates.requires_grad_()
_, cu_aev = self.cuaev_computer((species, coordinates))
cu_aev.backward(torch.ones_like(cu_aev))
cuaev_grad = coordinates.grad
self.assertEqual(cu_aev, aev, f'cu_aev: {cu_aev}\n aev: {aev}')
self.assertEqual(cuaev_grad, aev_grad, f'\ncuaev_grad: {cuaev_grad}\n aev_grad: {aev_grad}')
def testTripeptideMD(self):
for i in range(100):
datafile = os.path.join(path, 'test_data/tripeptide-md/{}.dat'.format(i))
with open(datafile, 'rb') as f:
coordinates, species, _, _, _, _, _, _ = pickle.load(f)
coordinates, species, *_ = pickle.load(f)
coordinates = torch.from_numpy(coordinates).float().unsqueeze(0).to(self.device)
species = torch.from_numpy(species).unsqueeze(0).to(self.device)
_, aev = self.aev_computer((species, coordinates))
_, cu_aev = self.cuaev_computer((species, coordinates))
self.assertEqual(cu_aev, aev)
def testTripeptideMDBackward(self):
for i in range(100):
datafile = os.path.join(path, 'test_data/tripeptide-md/{}.dat'.format(i))
with open(datafile, 'rb') as f:
coordinates, species, *_ = pickle.load(f)
coordinates = torch.from_numpy(coordinates).float().unsqueeze(0).to(self.device).requires_grad_(True)
species = torch.from_numpy(species).unsqueeze(0).to(self.device)
_, aev = self.aev_computer((species, coordinates))
aev.backward(torch.ones_like(aev))
aev_grad = coordinates.grad
coordinates = coordinates.clone().detach()
coordinates.requires_grad_()
_, cu_aev = self.cuaev_computer((species, coordinates))
cu_aev.backward(torch.ones_like(cu_aev))
cuaev_grad = coordinates.grad
self.assertEqual(cu_aev, aev)
self.assertEqual(cuaev_grad, aev_grad, atol=5e-5, rtol=5e-5)
def testNIST(self):
datafile = os.path.join(path, 'test_data/NIST/all')
with open(datafile, 'rb') as f:
......@@ -94,11 +140,33 @@ class TestCUAEV(TestCase):
_, cu_aev = self.cuaev_computer((species, coordinates))
self.assertEqual(cu_aev, aev)
def testNISTBackward(self):
datafile = os.path.join(path, 'test_data/NIST/all')
with open(datafile, 'rb') as f:
data = pickle.load(f)
for coordinates, species, _, _, _, _ in data:
coordinates = torch.from_numpy(coordinates).to(torch.float).to(self.device).requires_grad_(True)
species = torch.from_numpy(species).to(self.device)
_, aev = self.aev_computer((species, coordinates))
aev.backward(torch.ones_like(aev))
aev_grad = coordinates.grad
coordinates = coordinates.clone().detach()
coordinates.requires_grad_()
_, cu_aev = self.cuaev_computer((species, coordinates))
cu_aev.backward(torch.ones_like(cu_aev))
cuaev_grad = coordinates.grad
self.assertEqual(cu_aev, aev)
self.assertEqual(cuaev_grad, aev_grad, atol=5e-5, rtol=5e-5)
def testVeryDenseMolecule(self):
"""
Test very dense molecule for aev correctness, especially for angular part
"""
for i in range(100):
datafile = os.path.join(path, 'test_data/tripeptide-md/{}.dat'.format(i))
with open(datafile, 'rb') as f:
coordinates, species, _, _, _, _, _, _ = pickle.load(f)
coordinates, species, *_ = pickle.load(f)
# change angstrom coordinates to 10 times smaller
coordinates = 0.1 * torch.from_numpy(coordinates).float().unsqueeze(0).to(self.device)
species = torch.from_numpy(species).unsqueeze(0).to(self.device)
......@@ -106,6 +174,28 @@ class TestCUAEV(TestCase):
_, cu_aev = self.cuaev_computer((species, coordinates))
self.assertEqual(cu_aev, aev, atol=5e-5, rtol=5e-5)
def testVeryDenseMoleculeBackward(self):
for i in range(100):
datafile = os.path.join(path, 'test_data/tripeptide-md/{}.dat'.format(i))
with open(datafile, 'rb') as f:
coordinates, species, *_ = pickle.load(f)
# change angstrom coordinates to 10 times smaller
coordinates = 0.1 * torch.from_numpy(coordinates).float().unsqueeze(0).to(self.device)
coordinates.requires_grad_(True)
species = torch.from_numpy(species).unsqueeze(0).to(self.device)
_, aev = self.aev_computer((species, coordinates))
aev.backward(torch.ones_like(aev))
aev_grad = coordinates.grad
coordinates = coordinates.clone().detach()
coordinates.requires_grad_()
_, cu_aev = self.cuaev_computer((species, coordinates))
cu_aev.backward(torch.ones_like(cu_aev))
cuaev_grad = coordinates.grad
self.assertEqual(cu_aev, aev, atol=5e-5, rtol=5e-5)
self.assertEqual(cuaev_grad, aev_grad, atol=5e-4, rtol=5e-4)
if __name__ == '__main__':
unittest.main()
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment