"git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "196835695ed6fa3ec53b888088d9d5581e8f8e94"
Unverified Commit 0f006b3b authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

improve analytical stress (#220)

parent 32e00bf3
...@@ -78,12 +78,9 @@ class Calculator(ase.calculators.calculator.Calculator): ...@@ -78,12 +78,9 @@ class Calculator(ase.calculators.calculator.Calculator):
coordinates = coordinates.unsqueeze(0).to(self.device).to(self.dtype) \ coordinates = coordinates.unsqueeze(0).to(self.device).to(self.dtype) \
.requires_grad_('forces' in properties) .requires_grad_('forces' in properties)
if 'stress' in properties: if 'stress' in properties:
displacement_x = torch.zeros(3, requires_grad=True, displacements = torch.zeros(3, 3, requires_grad=True,
dtype=self.dtype, device=self.device) dtype=self.dtype, device=self.device)
displacement_y = torch.zeros(3, requires_grad=True, displacement_x, displacement_y, displacement_z = displacements
dtype=self.dtype, device=self.device)
displacement_z = torch.zeros(3, requires_grad=True,
dtype=self.dtype, device=self.device)
strain_x = self.strain(coordinates, displacement_x, 0) strain_x = self.strain(coordinates, displacement_x, 0)
strain_y = self.strain(coordinates, displacement_y, 1) strain_y = self.strain(coordinates, displacement_y, 1)
strain_z = self.strain(coordinates, displacement_z, 2) strain_z = self.strain(coordinates, displacement_z, 2)
...@@ -111,9 +108,5 @@ class Calculator(ase.calculators.calculator.Calculator): ...@@ -111,9 +108,5 @@ class Calculator(ase.calculators.calculator.Calculator):
if 'stress' in properties: if 'stress' in properties:
volume = self.atoms.get_volume() volume = self.atoms.get_volume()
stress = torch.stack([ stress = torch.autograd.grad(energy.squeeze(), displacements)[0] / volume
torch.autograd.grad(energy.squeeze(), displacement_x, retain_graph=True)[0],
torch.autograd.grad(energy.squeeze(), displacement_y, retain_graph=True)[0],
torch.autograd.grad(energy.squeeze(), displacement_z)[0],
]) / volume
self.results['stress'] = stress.cpu().numpy() self.results['stress'] = stress.cpu().numpy()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment