Commit d400d8f0 authored by Gao, Xiang's avatar Gao, Xiang Committed by Farhad Ramezanghorbani
Browse files

improve analytical stress calculation (#387)

parent e784666f
......@@ -52,13 +52,6 @@ class Calculator(ase.calculators.calculator.Calculator):
self.energy_shifter
).to(dtype)
@staticmethod
def strain(tensor, displacement, surface_normal_axis):
displacement_of_tensor = torch.zeros_like(tensor)
for axis in range(3):
displacement_of_tensor[..., axis] = tensor[..., surface_normal_axis] * displacement[axis]
return displacement_of_tensor
def calculate(self, atoms=None, properties=['energy'],
system_changes=ase.calculators.calculator.all_changes):
super(Calculator, self).calculate(atoms, properties, system_changes)
......@@ -70,7 +63,7 @@ class Calculator(ase.calculators.calculator.Calculator):
species = self.species_to_tensor(self.atoms.get_chemical_symbols()).to(self.device)
species = species.unsqueeze(0)
coordinates = torch.tensor(self.atoms.get_positions())
coordinates = coordinates.unsqueeze(0).to(self.device).to(self.dtype) \
coordinates = coordinates.to(self.device).to(self.dtype) \
.requires_grad_('forces' in properties)
if pbc_enabled:
......@@ -79,20 +72,13 @@ class Calculator(ase.calculators.calculator.Calculator):
atoms.set_positions(coordinates.detach().cpu().reshape(-1, 3).numpy())
if 'stress' in properties:
displacements = torch.zeros(3, 3, requires_grad=True,
dtype=self.dtype, device=self.device)
displacement_x, displacement_y, displacement_z = displacements
strain_x = self.strain(coordinates, displacement_x, 0)
strain_y = self.strain(coordinates, displacement_y, 1)
strain_z = self.strain(coordinates, displacement_z, 2)
coordinates = coordinates + strain_x + strain_y + strain_z
scaling = torch.eye(3, requires_grad=True, dtype=self.dtype, device=self.device)
coordinates = coordinates @ scaling
coordinates = coordinates.unsqueeze(0)
if pbc_enabled:
if 'stress' in properties:
strain_x = self.strain(cell, displacement_x, 0)
strain_y = self.strain(cell, displacement_y, 1)
strain_z = self.strain(cell, displacement_z, 2)
cell = cell + strain_x + strain_y + strain_z
cell = cell @ scaling
aev = self.aev_computer((species, coordinates), cell=cell, pbc=pbc).aevs
else:
aev = self.aev_computer((species, coordinates)).aevs
......@@ -108,5 +94,5 @@ class Calculator(ase.calculators.calculator.Calculator):
if 'stress' in properties:
volume = self.atoms.get_volume()
stress = torch.autograd.grad(energy.squeeze(), displacements)[0] / volume
stress = torch.autograd.grad(energy.squeeze(), scaling)[0] / volume
self.results['stress'] = stress.cpu().numpy()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment