Unverified Commit 95c3206b authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Fix ase interface on GPU (#212)

parent cbfbab3f
......@@ -58,7 +58,7 @@ class Calculator(ase.calculators.calculator.Calculator):
pbc = torch.tensor(self.atoms.get_pbc().astype(numpy.uint8), dtype=torch.uint8,
device=self.device)
# print(cell, pbc)
species = self.species_to_tensor(self.atoms.get_chemical_symbols())
species = self.species_to_tensor(self.atoms.get_chemical_symbols()).to(self.device)
species = species.unsqueeze(0)
coordinates = torch.tensor(self.atoms.get_positions())
coordinates = coordinates.unsqueeze(0).to(self.device).to(self.dtype) \
......@@ -68,4 +68,4 @@ class Calculator(ase.calculators.calculator.Calculator):
self.results['energy'] = energy.item()
if 'forces' in properties:
forces = -torch.autograd.grad(energy.squeeze(), coordinates)[0]
self.results['forces'] = forces.squeeze().numpy()
self.results['forces'] = forces.squeeze().to('cpu').numpy()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment