Commit bc4ab994 authored by Kevin Ryan's avatar Kevin Ryan Committed by Gao, Xiang
Browse files

Added finite differences gradient check for aev_computer. (#200)

parent 6e50a99e
......@@ -3,6 +3,8 @@ import torchani
import unittest
import os
import pickle
import random
import copy
import itertools
import ase
import math
......@@ -19,8 +21,8 @@ class TestAEV(unittest.TestCase):
self.radial_length = self.aev_computer.radial_length
self.tolerance = 1e-5
def random_skip(self):
return False
def random_skip(self, prob=0):
return random.random() < prob
def transform(self, x):
return x
......@@ -94,6 +96,41 @@ class TestAEV(unittest.TestCase):
_, aev = self.aev_computer((species, coordinates))
self._assertAEVEqual(radial, angular, aev)
@unittest.skipIf(not torch.cuda.is_available(), "Too slow on CPU")
def testGradient(self):
"""Test validity of autodiff by comparing analytical and numerical
gradients.
"""
datafile = os.path.join(path, 'test_data/NIST/all')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# Create local copy of aev_computer to avoid interference with other
# tests.
aev_computer = copy.deepcopy(self.aev_computer).to(device).to(torch.float64)
with open(datafile, 'rb') as f:
data = pickle.load(f)
for coordinates, species, _, _, _, _ in data:
coordinates = torch.from_numpy(coordinates).to(device).to(torch.float64)
coordinates.requires_grad_(True)
species = torch.from_numpy(species).to(device)
# PyTorch gradcheck expects to test a funtion with inputs and
# outputs of type torch.Tensor. The numerical estimation of
# the deriviate involves making small modifications to the
# input and observing how it affects the output. The species
# tensor needs to be removed from the input so that gradcheck
# does not attempt to estimate the gradient with respect to
# species and fail.
# Create simple function wrapper to handle this.
def aev_forward_wrapper(coords):
# Return only the aev portion of the output.
return aev_computer((species, coords))[1]
# Sanity Check: Forward wrapper returns aev without error.
aev_forward_wrapper(coordinates)
torch.autograd.gradcheck(
aev_forward_wrapper,
coordinates
)
class TestPBCSeeEachOther(unittest.TestCase):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment