Unverified Commit ee31b752 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

remove force from ANIModel (#39)

parent 2cb357bd
......@@ -19,10 +19,12 @@ coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
[0.45554739, 0.54289633, 0.81170881],
[0.66091919, -0.16799635, -0.91037834]]],
dtype=aev_computer.dtype,
device=aev_computer.device)
device=aev_computer.device,
requires_grad=True)
species = ['C', 'H', 'H', 'H', 'H']
energy, derivative = nn(coordinates, species)
energy = nn(coordinates, species)
derivative = torch.autograd.grad(energy.sum(), coordinates)[0]
energy = shift_energy.add_sae(energy, species)
force = -derivative
......
......@@ -100,10 +100,6 @@ class TestBenchmark(unittest.TestCase):
model = torchani.models.NeuroChemNNP(
aev_computer, benchmark=True)
self._testModule(model, ['forward>nn'])
model = torchani.models.NeuroChemNNP(
aev_computer, benchmark=True, derivative=True)
self._testModule(
model, ['forward>nn', 'forward>derivative'])
if __name__ == '__main__':
......
......@@ -15,21 +15,21 @@ class TestEnsemble(unittest.TestCase):
self.conformations = 20
def _test_molecule(self, coordinates, species):
coordinates = torch.tensor(coordinates, requires_grad=True)
n = torchani.buildin_ensemble
prefix = torchani.buildin_model_prefix
aev = torchani.SortedAEV(device=torch.device('cpu'))
ensemble = torchani.models.NeuroChemNNP(aev, derivative=True,
ensemble=True)
ensemble = torchani.models.NeuroChemNNP(aev, ensemble=True)
models = [torchani.models.
NeuroChemNNP(aev, derivative=True,
ensemble=False,
NeuroChemNNP(aev, ensemble=False,
from_=prefix + '{}/networks/'.format(i))
for i in range(n)]
energy1, force1 = ensemble(coordinates, species)
energy2, force2 = zip(*[m(coordinates, species) for m in models])
energy1 = ensemble(coordinates, species)
force1 = torch.autograd.grad(energy1.sum(), coordinates)[0]
energy2 = [m(coordinates, species) for m in models]
energy2 = sum(energy2) / n
force2 = sum(force2) / n
force2 = torch.autograd.grad(energy2.sum(), coordinates)[0]
energy_diff = (energy1 - energy2).abs().max().item()
force_diff = (force1 - force2).abs().max().item()
self.assertLess(energy_diff, self.tol)
......
......@@ -16,10 +16,12 @@ class TestForce(unittest.TestCase):
self.aev_computer = torchani.SortedAEV(
dtype=dtype, device=torch.device('cpu'))
self.nnp = torchani.models.NeuroChemNNP(
self.aev_computer, derivative=True)
self.aev_computer)
def _test_molecule(self, coordinates, species, forces):
_, derivative = self.nnp(coordinates, species)
coordinates = torch.tensor(coordinates, requires_grad=True)
energies = self.nnp(coordinates, species)
derivative = torch.autograd.grad(energies.sum(), coordinates)[0]
max_diff = (forces + derivative).abs().max().item()
self.assertLess(max_diff, self.tolerance)
......
......@@ -43,7 +43,7 @@ class ANIModel(BenchmarkedModule):
"""
def __init__(self, aev_computer, suffixes, reducer, output_length, models,
derivative=False, derivative_graph=False, benchmark=False):
benchmark=False):
super(ANIModel, self).__init__(benchmark)
if not isinstance(aev_computer, AEVComputer):
raise TypeError(
......@@ -56,22 +56,9 @@ class ANIModel(BenchmarkedModule):
for i in models:
setattr(self, i, models[i])
self.derivative = derivative
if not derivative and derivative_graph:
raise ValueError(
'''BySpeciesModel: can not create graph for derivative if the
computation of derivative is turned off''')
self.derivative_graph = derivative_graph
if derivative and self.output_length != 1:
raise ValueError(
'derivative can only be computed for output length 1')
if benchmark:
self.aev_to_output = self._enable_benchmark(
self.aev_to_output, 'nn')
if derivative:
self.compute_derivative = self._enable_benchmark(
self.compute_derivative, 'derivative')
self.forward = self._enable_benchmark(self.forward, 'forward')
def aev_to_output(self, aev, species):
......@@ -116,15 +103,6 @@ class ANIModel(BenchmarkedModule):
molecule_output = self.reducer(per_species_outputs, dim=1)
return molecule_output
def compute_derivative(self, output, coordinates):
"""Compute the gradient d(output)/d(coordinates)"""
# Since different conformations are independent, computing
# the derivatives of all outputs w.r.t. its own coordinate is
# equivalent to compute the derivative of the sum of all outputs
# w.r.t. all coordinates.
return torch.autograd.grad(output.sum(), coordinates,
create_graph=self.derivative_graph)[0]
def forward(self, coordinates, species):
"""Feed forward
......@@ -138,26 +116,12 @@ class ANIModel(BenchmarkedModule):
Returns
-------
torch.Tensor or (torch.Tensor, torch.Tensor)
If derivative is turned off, then this function will return a
pytorch tensor of shape (conformations, output_length) for the
torch.Tensor
Tensor of shape (conformations, output_length) for the
output of each conformation.
If derivative is turned on, then this function will return a pair
of pytorch tensors where the first tensor is the output tensor as
when the derivative is off, and the second tensor is a tensor of
shape (conformation, atoms, 3) storing the d(output)/dR.
"""
species = self.aev_computer.species_to_tensor(species)
if not self.derivative:
coordinates = coordinates.detach()
else:
coordinates = torch.tensor(coordinates, requires_grad=True)
_species, _coordinates, = self.aev_computer.sort_by_species(
species, coordinates)
aev = self.aev_computer((_coordinates, _species))
output = self.aev_to_output(aev, _species)
if not self.derivative:
return output
else:
derivative = self.compute_derivative(output, coordinates)
return output, derivative
return self.aev_to_output(aev, _species)
......@@ -32,7 +32,6 @@ class CustomModel(ANIModel):
'''output length of each atomic neural network must
match''')
super(CustomModel, self).__init__(aev_computer, suffixes, reducer,
output_length, models, derivative,
derivative_graph, benchmark)
output_length, models, benchmark)
for i in per_species:
setattr(self, 'model_' + i, per_species[i])
......@@ -60,5 +60,4 @@ class NeuroChemNNP(ANIModel):
must match''')
models['model_' + i + suffix] = model_X
super(NeuroChemNNP, self).__init__(aev_computer, suffixes, reducer,
output_length, models, derivative,
derivative_graph, benchmark)
output_length, models, benchmark)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment