"examples/pytorch/vscode:/vscode.git/clone" did not exist on "f2c80b440e80226441dc6c11a95ade10defaaf11"
Commit bc4ab994 authored by Kevin Ryan's avatar Kevin Ryan Committed by Gao, Xiang
Browse files

Added finite differences gradient check for aev_computer. (#200)

parent 6e50a99e
...@@ -3,6 +3,8 @@ import torchani ...@@ -3,6 +3,8 @@ import torchani
import unittest import unittest
import os import os
import pickle import pickle
import random
import copy
import itertools import itertools
import ase import ase
import math import math
...@@ -19,8 +21,8 @@ class TestAEV(unittest.TestCase): ...@@ -19,8 +21,8 @@ class TestAEV(unittest.TestCase):
self.radial_length = self.aev_computer.radial_length self.radial_length = self.aev_computer.radial_length
self.tolerance = 1e-5 self.tolerance = 1e-5
def random_skip(self): def random_skip(self, prob=0):
return False return random.random() < prob
def transform(self, x): def transform(self, x):
return x return x
...@@ -94,6 +96,41 @@ class TestAEV(unittest.TestCase): ...@@ -94,6 +96,41 @@ class TestAEV(unittest.TestCase):
_, aev = self.aev_computer((species, coordinates)) _, aev = self.aev_computer((species, coordinates))
self._assertAEVEqual(radial, angular, aev) self._assertAEVEqual(radial, angular, aev)
@unittest.skipIf(not torch.cuda.is_available(), "Too slow on CPU")
def testGradient(self):
"""Test validity of autodiff by comparing analytical and numerical
gradients.
"""
datafile = os.path.join(path, 'test_data/NIST/all')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Create local copy of aev_computer to avoid interference with other
# tests.
aev_computer = copy.deepcopy(self.aev_computer).to(device).to(torch.float64)
with open(datafile, 'rb') as f:
data = pickle.load(f)
for coordinates, species, _, _, _, _ in data:
coordinates = torch.from_numpy(coordinates).to(device).to(torch.float64)
coordinates.requires_grad_(True)
species = torch.from_numpy(species).to(device)
# PyTorch gradcheck expects to test a funtion with inputs and
# outputs of type torch.Tensor. The numerical estimation of
# the deriviate involves making small modifications to the
# input and observing how it affects the output. The species
# tensor needs to be removed from the input so that gradcheck
# does not attempt to estimate the gradient with respect to
# species and fail.
# Create simple function wrapper to handle this.
def aev_forward_wrapper(coords):
# Return only the aev portion of the output.
return aev_computer((species, coords))[1]
# Sanity Check: Forward wrapper returns aev without error.
aev_forward_wrapper(coordinates)
torch.autograd.gradcheck(
aev_forward_wrapper,
coordinates
)
class TestPBCSeeEachOther(unittest.TestCase): class TestPBCSeeEachOther(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment