Unverified Commit 32e00bf3 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Add support for analytical stress (#218)

* implement stress

* fix

* more

* Add analytical stress support
parent 19af71ea
from ase.lattice.cubic import Diamond from ase.lattice.cubic import Diamond
from ase.md.langevin import Langevin from ase.md.langevin import Langevin
from ase.md.nptberendsen import NPTBerendsen
from ase import units from ase import units
from ase.io import read
from ase.calculators.test import numeric_force from ase.calculators.test import numeric_force
import torch import torch
import torchani import torchani
...@@ -24,10 +26,7 @@ class TestASE(unittest.TestCase): ...@@ -24,10 +26,7 @@ class TestASE(unittest.TestCase):
def testWithNumericalForceWithPBCEnabled(self): def testWithNumericalForceWithPBCEnabled(self):
atoms = Diamond(symbol="C", pbc=True) atoms = Diamond(symbol="C", pbc=True)
builtin = torchani.neurochem.Builtins() calculator = torchani.models.ANI1x().ase()
calculator = torchani.ase.Calculator(
builtin.species, builtin.aev_computer,
builtin.models, builtin.energy_shifter)
atoms.set_calculator(calculator) atoms.set_calculator(calculator)
dyn = Langevin(atoms, 5 * units.fs, 30000000 * units.kB, 0.002) dyn = Langevin(atoms, 5 * units.fs, 30000000 * units.kB, 0.002)
dyn.run(100) dyn.run(100)
...@@ -38,6 +37,24 @@ class TestASE(unittest.TestCase): ...@@ -38,6 +37,24 @@ class TestASE(unittest.TestCase):
if avgf > 0: if avgf > 0:
self.assertLess(df / avgf, 0.1) self.assertLess(df / avgf, 0.1)
def testWithNumericalStressWithPBCEnabled(self):
filename = os.path.join(path, '../tools/generate-unit-test-expect/others/Benzene.cif')
benzene = read(filename)
calculator = torchani.models.ANI1x().ase()
benzene.set_calculator(calculator)
dyn = NPTBerendsen(benzene, timestep=0.1 * units.fs,
temperature=300 * units.kB,
taut=0.1 * 1000 * units.fs, pressure=1.01325,
taup=1.0 * 1000 * units.fs, compressibility=4.57e-5)
def test_stress():
stress = benzene.get_stress()
numerical_stress = calculator.calculate_numerical_stress(benzene)
diff = torch.from_numpy(stress - numerical_stress).abs().max().item()
self.assertLess(diff, tol)
dyn.attach(test_stress, interval=30)
dyn.run(120)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -32,7 +32,7 @@ class Calculator(ase.calculators.calculator.Calculator): ...@@ -32,7 +32,7 @@ class Calculator(ase.calculators.calculator.Calculator):
object with the wrapped positions. object with the wrapped positions.
""" """
implemented_properties = ['energy', 'forces'] implemented_properties = ['energy', 'forces', 'stress', 'free_energy']
def __init__(self, species, aev_computer, model, energy_shifter, dtype=torch.float64, overwrite=False): def __init__(self, species, aev_computer, model, energy_shifter, dtype=torch.float64, overwrite=False):
super(Calculator, self).__init__() super(Calculator, self).__init__()
...@@ -53,12 +53,22 @@ class Calculator(ase.calculators.calculator.Calculator): ...@@ -53,12 +53,22 @@ class Calculator(ase.calculators.calculator.Calculator):
self.energy_shifter self.energy_shifter
).to(dtype) ).to(dtype)
@staticmethod
def strain(tensor, displacement, surface_normal_axis):
rest_axes = {0, 1, 2} - set([surface_normal_axis])
displacement_normal = displacement[surface_normal_axis]
displacement_of_tensor = torch.zeros_like(tensor)
displacement_of_tensor[..., surface_normal_axis] = tensor[..., surface_normal_axis] * displacement_normal
for axis in rest_axes:
displacement_axis = displacement[axis]
displacement_of_tensor[..., axis] = tensor[..., surface_normal_axis] * displacement_axis
return displacement_of_tensor
def calculate(self, atoms=None, properties=['energy'], def calculate(self, atoms=None, properties=['energy'],
system_changes=ase.calculators.calculator.all_changes): system_changes=ase.calculators.calculator.all_changes):
super(Calculator, self).calculate(atoms, properties, system_changes) super(Calculator, self).calculate(atoms, properties, system_changes)
cell = torch.tensor(self.atoms.get_cell(complete=True), cell = torch.tensor(self.atoms.get_cell(complete=True),
requires_grad=True, dtype=self.dtype, dtype=self.dtype, device=self.device)
device=self.device)
pbc = torch.tensor(self.atoms.get_pbc().astype(numpy.uint8), dtype=torch.uint8, pbc = torch.tensor(self.atoms.get_pbc().astype(numpy.uint8), dtype=torch.uint8,
device=self.device) device=self.device)
pbc_enabled = bool(pbc.any().item()) pbc_enabled = bool(pbc.any().item())
...@@ -67,15 +77,43 @@ class Calculator(ase.calculators.calculator.Calculator): ...@@ -67,15 +77,43 @@ class Calculator(ase.calculators.calculator.Calculator):
coordinates = torch.tensor(self.atoms.get_positions()) coordinates = torch.tensor(self.atoms.get_positions())
coordinates = coordinates.unsqueeze(0).to(self.device).to(self.dtype) \ coordinates = coordinates.unsqueeze(0).to(self.device).to(self.dtype) \
.requires_grad_('forces' in properties) .requires_grad_('forces' in properties)
if 'stress' in properties:
displacement_x = torch.zeros(3, requires_grad=True,
dtype=self.dtype, device=self.device)
displacement_y = torch.zeros(3, requires_grad=True,
dtype=self.dtype, device=self.device)
displacement_z = torch.zeros(3, requires_grad=True,
dtype=self.dtype, device=self.device)
strain_x = self.strain(coordinates, displacement_x, 0)
strain_y = self.strain(coordinates, displacement_y, 1)
strain_z = self.strain(coordinates, displacement_z, 2)
coordinates = coordinates + strain_x + strain_y + strain_z
if pbc_enabled: if pbc_enabled:
coordinates = utils.map2central(cell, coordinates, pbc) coordinates = utils.map2central(cell, coordinates, pbc)
if self.overwrite and atoms is not None: if self.overwrite and atoms is not None:
atoms.set_positions(coordinates.detach().cpu().reshape(-1, 3).numpy()) atoms.set_positions(coordinates.detach().cpu().reshape(-1, 3).numpy())
if 'stress' in properties:
strain_x = self.strain(cell, displacement_x, 0)
strain_y = self.strain(cell, displacement_y, 1)
strain_z = self.strain(cell, displacement_z, 2)
cell = cell + strain_x + strain_y + strain_z
_, energy = self.whole((species, coordinates, cell, pbc)) _, energy = self.whole((species, coordinates, cell, pbc))
else: else:
_, energy = self.whole((species, coordinates)) _, energy = self.whole((species, coordinates))
energy *= ase.units.Hartree energy *= ase.units.Hartree
self.results['energy'] = energy.item() self.results['energy'] = energy.item()
self.results['free_energy'] = energy.item()
if 'forces' in properties: if 'forces' in properties:
forces = -torch.autograd.grad(energy.squeeze(), coordinates)[0] forces = -torch.autograd.grad(energy.squeeze(), coordinates)[0]
self.results['forces'] = forces.squeeze().to('cpu').numpy() self.results['forces'] = forces.squeeze().to('cpu').numpy()
if 'stress' in properties:
volume = self.atoms.get_volume()
stress = torch.stack([
torch.autograd.grad(energy.squeeze(), displacement_x, retain_graph=True)[0],
torch.autograd.grad(energy.squeeze(), displacement_y, retain_graph=True)[0],
torch.autograd.grad(energy.squeeze(), displacement_z)[0],
]) / volume
self.results['stress'] = stress.cpu().numpy()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment