Unverified Commit bd9d888a authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Use PyTorch autograd's hessian (#532)

* Use PyTorch autograd's hessian

* fix test

* save

* clean

* save

* save

* drop hessian from jit example
parent ea51fadb
......@@ -41,7 +41,6 @@ Utilities
.. autofunction:: torchani.utils.map2central
.. autoclass:: torchani.utils.ChemicalSymbolsToInts
:members:
.. autofunction:: torchani.utils.hessian
.. autofunction:: torchani.utils.vibrational_analysis
.. autofunction:: torchani.utils.get_atomic_masses
......
......@@ -69,7 +69,7 @@ print('Single network energy, eager mode vs loaded jit:', energies_single.item()
#
# - uses double as dtype instead of float
# - don't care about periodic boundary condition
# - in addition to energies, allow returnsing optionally forces, and hessians
# - in addition to energies, allow returning optionally forces
# - when indexing atom species, use its index in the periodic table instead of 0, 1, 2, 3, ...
#
# you could do the following:
......@@ -81,34 +81,28 @@ class CustomModule(torch.nn.Module):
# self.model = torchani.models.ANI1x(periodic_table_index=True)[0].double()
# self.model = torchani.models.ANI1ccx(periodic_table_index=True).double()
def forward(self, species: Tensor, coordinates: Tensor, return_forces: bool = False,
return_hessians: bool = False) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
if return_forces or return_hessians:
def forward(self, species: Tensor, coordinates: Tensor, return_forces: bool = False) -> Tuple[Tensor, Optional[Tensor]]:
if return_forces:
coordinates.requires_grad_(True)
energies = self.model((species, coordinates)).energies
forces: Optional[Tensor] = None # noqa: E701
hessians: Optional[Tensor] = None
if return_forces or return_hessians:
grad = torch.autograd.grad([energies.sum()], [coordinates], create_graph=return_hessians)[0]
if return_forces:
grad = torch.autograd.grad([energies.sum()], [coordinates])[0]
assert grad is not None
forces = -grad
if return_hessians:
hessians = torchani.utils.hessian(coordinates, forces=forces)
return energies, forces, hessians
return energies, forces
custom_model = CustomModule()
compiled_custom_model = torch.jit.script(custom_model)
torch.jit.save(compiled_custom_model, 'compiled_custom_model.pt')
loaded_compiled_custom_model = torch.jit.load('compiled_custom_model.pt')
energies, forces, hessians = custom_model(species, coordinates, True, True)
energies_jit, forces_jit, hessians_jit = loaded_compiled_custom_model(species, coordinates, True, True)
energies, forces = custom_model(species, coordinates, True)
energies_jit, forces_jit = loaded_compiled_custom_model(species, coordinates, True)
print('Energy, eager mode vs loaded jit:', energies.item(), energies_jit.item())
print()
print('Force, eager mode vs loaded jit:\n', forces.squeeze(0), '\n', forces_jit.squeeze(0))
print()
torch.set_printoptions(sci_mode=False, linewidth=1000)
print('Hessian, eager mode vs loaded jit:\n', hessians.squeeze(0), '\n', hessians_jit.squeeze(0))
......@@ -47,18 +47,12 @@ coordinates = torch.from_numpy(molecule.get_positions()).unsqueeze(0).requires_g
masses = torchani.utils.get_atomic_masses(species)
###############################################################################
# To do vibration analysis, we first need to generate a graph that computes
# energies from species and coordinates. The code to generate a graph of energy
# is the same as the code to compute energy:
energies = model((species, coordinates)).energies
# We can use :func:`torch.autograd.functional.hessian` to compute hessian:
hessian = torch.autograd.functional.hessian(lambda x: model((species, x)).energies, coordinates)
###############################################################################
# We can now use the energy graph to compute analytical Hessian matrix:
hessian = torchani.utils.hessian(coordinates, energies=energies)
###############################################################################
# The Hessian matrix should have shape `(1, 9, 9)`, where 1 means there is only
# one molecule to compute, 9 means `3 atoms * 3D space = 9 degree of freedom`.
# The Hessian matrix should have shape `(1, 3, 3, 1, 3, 3)`, where 1 means there
# is only one molecule to compute, 3 means 3 atoms and 3D space.
print(hessian.shape)
###############################################################################
......
import unittest
import torch
import torchani
......@@ -10,9 +9,6 @@ class TestUtils(unittest.TestCase):
self.assertEqual(len(str2i), 6)
self.assertListEqual(str2i('BACCC').tolist(), [1, 0, 2, 2, 2])
def testHessianJIT(self):
torch.jit.script(torchani.utils.hessian)
if __name__ == '__main__':
unittest.main()
......@@ -39,8 +39,7 @@ class TestVibrational(unittest.TestCase):
# compute vibrational by torchani
species = model.species_to_tensor(molecule.get_chemical_symbols()).unsqueeze(0)
coordinates = torch.from_numpy(molecule.get_positions()).unsqueeze(0).requires_grad_(True)
_, energies = model((species, coordinates))
hessian = torchani.utils.hessian(coordinates, energies=energies)
hessian = torch.autograd.functional.hessian(lambda x: model((species, x)).energies, coordinates)
freq2, modes2, _, _ = torchani.utils.vibrational_analysis(masses[species], hessian)
freq2 = freq2[6:].float()
modes2 = modes2[6:]
......
......@@ -241,43 +241,6 @@ class ChemicalSymbolsToInts:
return len(self.rev_species)
def _get_derivatives_not_none(x: Tensor, y: Tensor, retain_graph: Optional[bool] = None, create_graph: bool = False) -> Tensor:
ret = torch.autograd.grad([y.sum()], [x], retain_graph=retain_graph, create_graph=create_graph)[0]
assert ret is not None
return ret
def hessian(coordinates: Tensor, energies: Optional[Tensor] = None, forces: Optional[Tensor] = None) -> Tensor:
"""Compute analytical hessian from the energy graph or force graph.
Arguments:
coordinates (:class:`torch.Tensor`): Tensor of shape `(molecules, atoms, 3)`
energies (:class:`torch.Tensor`): Tensor of shape `(molecules,)`, if specified,
then `forces` must be `None`. This energies must be computed from
`coordinates` in a graph.
forces (:class:`torch.Tensor`): Tensor of shape `(molecules, atoms, 3)`, if specified,
then `energies` must be `None`. This forces must be computed from
`coordinates` in a graph.
Returns:
:class:`torch.Tensor`: Tensor of shape `(molecules, 3A, 3A)` where A is the number of
atoms in each molecule
"""
if energies is None and forces is None:
raise ValueError('Energies or forces must be specified')
if energies is not None and forces is not None:
raise ValueError('Energies or forces can not be specified at the same time')
if forces is None:
assert energies is not None
forces = -_get_derivatives_not_none(coordinates, energies, create_graph=True)
flattened_force = forces.flatten(start_dim=1)
force_components = flattened_force.unbind(dim=1)
return -torch.stack([
_get_derivatives_not_none(coordinates, f, retain_graph=True).flatten(start_dim=1)
for f in force_components
], dim=1)
class FreqsModes(NamedTuple):
freqs: Tensor
modes: Tensor
......@@ -317,6 +280,8 @@ def vibrational_analysis(masses, hessian, mode_type='MDU', unit='cm^-1'):
raise ValueError('Only meV and cm^-1 are supported right now')
assert hessian.shape[0] == 1, 'Currently only supporting computing one molecule a time'
degree_of_freedom = hessian.shape[1] * hessian.shape[2]
hessian = hessian.reshape(1, degree_of_freedom, degree_of_freedom)
# Solving the eigenvalue problem: Hq = w^2 * T q
# where H is the Hessian matrix, q is the normal coordinates,
# T = diag(m1, m1, m1, m2, m2, m2, ....) is the mass
......@@ -423,6 +388,5 @@ PERIODIC_TABLE = """
""".strip().split()
__all__ = ['pad_atomic_properties', 'present_species', 'hessian',
'vibrational_analysis', 'strip_redundant_padding',
'ChemicalSymbolsToInts', 'get_atomic_masses']
__all__ = ['pad_atomic_properties', 'present_species', 'vibrational_analysis',
'strip_redundant_padding', 'ChemicalSymbolsToInts', 'get_atomic_masses']
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment