Unverified Commit 14d2ef70 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Add support for analytical hessian and vibrational analysis (#222)

parent 0f006b3b
...@@ -24,3 +24,5 @@ benchmark_xyz ...@@ -24,3 +24,5 @@ benchmark_xyz
*_cache *_cache
datacache datacache
dist dist
*.pckl
...@@ -40,6 +40,8 @@ Utilities ...@@ -40,6 +40,8 @@ Utilities
.. autofunction:: torchani.utils.map2central .. autofunction:: torchani.utils.map2central
.. autoclass:: torchani.utils.ChemicalSymbolsToInts .. autoclass:: torchani.utils.ChemicalSymbolsToInts
:members: :members:
.. autofunction:: torchani.utils.hessian
.. autofunction:: torchani.utils.vibrational_analysis
NeuroChem NeuroChem
......
...@@ -16,6 +16,7 @@ Welcome to TorchANI's documentation! ...@@ -16,6 +16,7 @@ Welcome to TorchANI's documentation!
examples/energy_force examples/energy_force
examples/ase_interface examples/ase_interface
examples/vibration_analysis
examples/load_from_neurochem examples/load_from_neurochem
examples/nnp_training examples/nnp_training
examples/cache_aev examples/cache_aev
......
# -*- coding: utf-8 -*-
"""
Computing Vibrational Frequencies Using Analytical Hessian
==========================================================
TorchANI is able to use ASE interface to do structure optimization and
vibration analysis, but the Hessian in ASE's vibration analysis is computed
numerically, which is slow and less accurate.
TorchANI therefore provide an interface to compute the Hessian matrix and do
vibration analysis analytically, thanks to the super power of `torch.autograd`.
"""
import ase
import ase.optimize
import torch
import torchani
import math
###############################################################################
# Let's now manually specify the device we want TorchANI to run:
model = torchani.models.ANI1x().double()
###############################################################################
# Let's first construct a water molecule and do structure optimization:
d = 0.9575
t = math.pi / 180 * 104.51
molecule = ase.Atoms('H2O', positions=[
(d, 0, 0),
(d * math.cos(t), d * math.sin(t), 0),
(0, 0, 0),
], calculator=model.ase())
opt = ase.optimize.BFGS(molecule)
opt.run(fmax=1e-6)
###############################################################################
# Now let's extract coordinates and species from ASE to use it directly with
# TorchANI:
species = model.species_to_tensor(molecule.get_chemical_symbols()).unsqueeze(0)
coordinates = torch.from_numpy(molecule.get_positions()).unsqueeze(0).requires_grad_(True)
###############################################################################
# TorchANI needs to know the mass of each atom in amu in order to do vibration
# analysis:
element_masses = torch.tensor([
1.008, # H
12.011, # C
14.007, # N
15.999, # O
], dtype=torch.double)
masses = element_masses[species]
###############################################################################
# To do vibration analysis, we first need to generate a graph that computes
# energies from species and coordinates. The code to generate a graph of energy
# is the same as the code to compute energy:
_, energies = model((species, coordinates))
###############################################################################
# We can now use the energy graph to compute analytical Hessian matrix:
hessian = torchani.utils.hessian(coordinates, energies=energies)
###############################################################################
# The Hessian matrix should have shape `(1, 9, 9)`, where 1 means there is only
# one molecule to compute, 9 means `3 atoms * 3D space = 9 degree of freedom`.
print(hessian.shape)
###############################################################################
# We are now ready to compute vibrational frequencies. The output has unit
# cm^-1. Since there are in total 9 degree of freedom, there are in total 9
# frequencies. Only the frequencies of the 3 vibrational modes are interesting.
freq = torchani.utils.vibrational_analysis(masses, hessian)[-3:]
print(freq)
import os
import math
import unittest
import torch
import torchani
import ase
import ase.optimize
import ase.vibrations
import numpy
path = os.path.dirname(os.path.realpath(__file__))
path = os.path.join(path, '../dataset/xyz_files/H2O.xyz')
class TestVibrational(unittest.TestCase):
def testVibrationalWavenumbers(self):
model = torchani.models.ANI1x().double()
d = 0.9575
t = math.pi / 180 * 104.51
molecule = ase.Atoms('H2O', positions=[
(d, 0, 0),
(d * math.cos(t), d * math.sin(t), 0),
(0, 0, 0),
], calculator=model.ase())
opt = ase.optimize.BFGS(molecule)
opt.run(fmax=1e-6)
masses = torch.tensor([1.008, 12.011, 14.007, 15.999], dtype=torch.double)
# compute vibrational frequencies by ASE
vib = ase.vibrations.Vibrations(molecule)
vib.run()
freq = torch.tensor([numpy.real(x) for x in vib.get_frequencies()[-3:]])
# compute vibrational by torchani
species = model.species_to_tensor(molecule.get_chemical_symbols()).unsqueeze(0)
coordinates = torch.from_numpy(molecule.get_positions()).unsqueeze(0).requires_grad_(True)
_, energies = model((species, coordinates))
hessian = torchani.utils.hessian(coordinates, energies=energies)
freq2 = torchani.utils.vibrational_analysis(masses[species], hessian)[-3:].float()
ratio = freq2 / freq
self.assertLess((ratio - 1).abs().max(), 0.02)
if __name__ == '__main__':
unittest.main()
import torch import torch
import math
def pad(species): def pad(species):
...@@ -205,5 +206,61 @@ class ChemicalSymbolsToInts: ...@@ -205,5 +206,61 @@ class ChemicalSymbolsToInts:
return torch.tensor(rev, dtype=torch.long) return torch.tensor(rev, dtype=torch.long)
__all__ = ['pad', 'pad_coordinates', 'present_species', def hessian(coordinates, energies=None, forces=None):
'strip_redundant_padding', 'ChemicalSymbolsToInts'] """Compute analytical hessian from the energy graph or force graph.
Arguments:
coordinates (:class:`torch.Tensor`): Tensor of shape `(molecules, atoms, 3)`
energies (:class:`torch.Tensor`): Tensor of shape `(molecules,)`, if specified,
then `forces` must be `None`. This energies must be computed from
`coordinates` in a graph.
forces (:class:`torch.Tensor`): Tensor of shape `(molecules, atoms, 3)`, if specified,
then `energies` must be `None`. This forces must be computed from
`coordinates` in a graph.
Returns:
:class:`torch.Tensor`: Tensor of shape `(molecules, 3A, 3A)` where A is the number of
atoms in each molecule
"""
if energies is None and forces is None:
raise ValueError('Energies or forces must be specified')
if energies is not None and forces is not None:
raise ValueError('Energies or forces can not be specified at the same time')
if forces is None:
forces = -torch.autograd.grad(energies.sum(), coordinates, create_graph=True)[0]
flattened_force = forces.flatten(start_dim=1)
force_components = flattened_force.unbind(dim=1)
return -torch.stack([
torch.autograd.grad(f.sum(), coordinates, retain_graph=True)[0].flatten(start_dim=1)
for f in force_components
], dim=1)
def vibrational_analysis(masses, hessian, unit='cm^-1'):
"""Computing the vibrational wavenumbers from hessian."""
if unit != 'cm^-1':
raise ValueError('Only cm^-1 are supported right now')
assert hessian.shape[0] == 1, 'Currently only supporting computing one molecule a time'
# Solving the eigenvalue problem: Hq = w^2 * T q
# where H is the Hessian matrix, q is the normal coordinates,
# T = diag(m1, m1, m1, m2, m2, m2, ....) is the mass
# We solve this eigenvalue problem through Lowdin diagnolization:
# Hq = w^2 * Tq ==> Hq = w^2 * T^(1/2) T^(1/2) q
# Letting q' = T^(1/2) q, we then have
# T^(-1/2) H T^(1/2) q' = w^2 * q'
inv_sqrt_mass = (1 / masses.sqrt()).repeat_interleave(3, dim=1) # shape (molecule, 3 * atoms)
mass_scaled_hessian = hessian * inv_sqrt_mass.unsqueeze(1) * inv_sqrt_mass.unsqueeze(2)
if mass_scaled_hessian.shape[0] != 1:
raise ValueError('The input should contain only one molecule')
mass_scaled_hessian = mass_scaled_hessian.squeeze(0)
eigenvalues = torch.symeig(mass_scaled_hessian).eigenvalues
angular_frequencies = eigenvalues.sqrt()
frequencies = angular_frequencies / (2 * math.pi)
# converting from sqrt(hartree / (amu * angstrom^2)) to cm^-1
wavenumbers = frequencies * 17092
return wavenumbers
__all__ = ['pad', 'pad_coordinates', 'present_species', 'hessian',
'vibrational_analysis', 'strip_redundant_padding',
'ChemicalSymbolsToInts']
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment