Unverified Commit 32e00bf3 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Add support for analytical stress (#218)

* implement stress

* fix

* more

* Add analytical stress support
parent 19af71ea
from ase.lattice.cubic import Diamond
from ase.md.langevin import Langevin
from ase.md.nptberendsen import NPTBerendsen
from ase import units
from ase.io import read
from ase.calculators.test import numeric_force
import torch
import torchani
......@@ -24,10 +26,7 @@ class TestASE(unittest.TestCase):
def testWithNumericalForceWithPBCEnabled(self):
atoms = Diamond(symbol="C", pbc=True)
builtin = torchani.neurochem.Builtins()
calculator = torchani.ase.Calculator(
builtin.species, builtin.aev_computer,
builtin.models, builtin.energy_shifter)
calculator = torchani.models.ANI1x().ase()
atoms.set_calculator(calculator)
dyn = Langevin(atoms, 5 * units.fs, 30000000 * units.kB, 0.002)
dyn.run(100)
......@@ -38,6 +37,24 @@ class TestASE(unittest.TestCase):
if avgf > 0:
self.assertLess(df / avgf, 0.1)
def testWithNumericalStressWithPBCEnabled(self):
filename = os.path.join(path, '../tools/generate-unit-test-expect/others/Benzene.cif')
benzene = read(filename)
calculator = torchani.models.ANI1x().ase()
benzene.set_calculator(calculator)
dyn = NPTBerendsen(benzene, timestep=0.1 * units.fs,
temperature=300 * units.kB,
taut=0.1 * 1000 * units.fs, pressure=1.01325,
taup=1.0 * 1000 * units.fs, compressibility=4.57e-5)
def test_stress():
stress = benzene.get_stress()
numerical_stress = calculator.calculate_numerical_stress(benzene)
diff = torch.from_numpy(stress - numerical_stress).abs().max().item()
self.assertLess(diff, tol)
dyn.attach(test_stress, interval=30)
dyn.run(120)
if __name__ == '__main__':
unittest.main()
......@@ -32,7 +32,7 @@ class Calculator(ase.calculators.calculator.Calculator):
object with the wrapped positions.
"""
implemented_properties = ['energy', 'forces']
implemented_properties = ['energy', 'forces', 'stress', 'free_energy']
def __init__(self, species, aev_computer, model, energy_shifter, dtype=torch.float64, overwrite=False):
super(Calculator, self).__init__()
......@@ -53,12 +53,22 @@ class Calculator(ase.calculators.calculator.Calculator):
self.energy_shifter
).to(dtype)
@staticmethod
def strain(tensor, displacement, surface_normal_axis):
rest_axes = {0, 1, 2} - set([surface_normal_axis])
displacement_normal = displacement[surface_normal_axis]
displacement_of_tensor = torch.zeros_like(tensor)
displacement_of_tensor[..., surface_normal_axis] = tensor[..., surface_normal_axis] * displacement_normal
for axis in rest_axes:
displacement_axis = displacement[axis]
displacement_of_tensor[..., axis] = tensor[..., surface_normal_axis] * displacement_axis
return displacement_of_tensor
def calculate(self, atoms=None, properties=['energy'],
system_changes=ase.calculators.calculator.all_changes):
super(Calculator, self).calculate(atoms, properties, system_changes)
cell = torch.tensor(self.atoms.get_cell(complete=True),
requires_grad=True, dtype=self.dtype,
device=self.device)
dtype=self.dtype, device=self.device)
pbc = torch.tensor(self.atoms.get_pbc().astype(numpy.uint8), dtype=torch.uint8,
device=self.device)
pbc_enabled = bool(pbc.any().item())
......@@ -67,15 +77,43 @@ class Calculator(ase.calculators.calculator.Calculator):
coordinates = torch.tensor(self.atoms.get_positions())
coordinates = coordinates.unsqueeze(0).to(self.device).to(self.dtype) \
.requires_grad_('forces' in properties)
if 'stress' in properties:
displacement_x = torch.zeros(3, requires_grad=True,
dtype=self.dtype, device=self.device)
displacement_y = torch.zeros(3, requires_grad=True,
dtype=self.dtype, device=self.device)
displacement_z = torch.zeros(3, requires_grad=True,
dtype=self.dtype, device=self.device)
strain_x = self.strain(coordinates, displacement_x, 0)
strain_y = self.strain(coordinates, displacement_y, 1)
strain_z = self.strain(coordinates, displacement_z, 2)
coordinates = coordinates + strain_x + strain_y + strain_z
if pbc_enabled:
coordinates = utils.map2central(cell, coordinates, pbc)
if self.overwrite and atoms is not None:
atoms.set_positions(coordinates.detach().cpu().reshape(-1, 3).numpy())
if 'stress' in properties:
strain_x = self.strain(cell, displacement_x, 0)
strain_y = self.strain(cell, displacement_y, 1)
strain_z = self.strain(cell, displacement_z, 2)
cell = cell + strain_x + strain_y + strain_z
_, energy = self.whole((species, coordinates, cell, pbc))
else:
_, energy = self.whole((species, coordinates))
energy *= ase.units.Hartree
self.results['energy'] = energy.item()
self.results['free_energy'] = energy.item()
if 'forces' in properties:
forces = -torch.autograd.grad(energy.squeeze(), coordinates)[0]
self.results['forces'] = forces.squeeze().to('cpu').numpy()
if 'stress' in properties:
volume = self.atoms.get_volume()
stress = torch.stack([
torch.autograd.grad(energy.squeeze(), displacement_x, retain_graph=True)[0],
torch.autograd.grad(energy.squeeze(), displacement_y, retain_graph=True)[0],
torch.autograd.grad(energy.squeeze(), displacement_z)[0],
]) / volume
self.results['stress'] = stress.cpu().numpy()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment