Unverified Commit db36c1a0 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Write wrapped position to atoms (#215)

parent 3596f612
...@@ -27,11 +27,14 @@ class Calculator(ase.calculators.calculator.Calculator): ...@@ -27,11 +27,14 @@ class Calculator(ase.calculators.calculator.Calculator):
energy_shifter (:class:`torchani.EnergyShifter`): Energy shifter. energy_shifter (:class:`torchani.EnergyShifter`): Energy shifter.
dtype (:class:`torchani.EnergyShifter`): data type to use, dtype (:class:`torchani.EnergyShifter`): data type to use,
by dafault ``torch.float64``. by dafault ``torch.float64``.
overwrite (bool): After wrapping atoms into central box, whether
to replace the original positions stored in :class:`ase.Atoms`
object with the wrapped positions.
""" """
implemented_properties = ['energy', 'forces'] implemented_properties = ['energy', 'forces']
def __init__(self, species, aev_computer, model, energy_shifter, dtype=torch.float64): def __init__(self, species, aev_computer, model, energy_shifter, dtype=torch.float64, overwrite=False):
super(Calculator, self).__init__() super(Calculator, self).__init__()
self.species_to_tensor = utils.ChemicalSymbolsToInts(species) self.species_to_tensor = utils.ChemicalSymbolsToInts(species)
# aev_computer.neighborlist will be changed later, so we need a copy to # aev_computer.neighborlist will be changed later, so we need a copy to
...@@ -39,6 +42,7 @@ class Calculator(ase.calculators.calculator.Calculator): ...@@ -39,6 +42,7 @@ class Calculator(ase.calculators.calculator.Calculator):
self.aev_computer = copy.deepcopy(aev_computer) self.aev_computer = copy.deepcopy(aev_computer)
self.model = copy.deepcopy(model) self.model = copy.deepcopy(model)
self.energy_shifter = copy.deepcopy(energy_shifter) self.energy_shifter = copy.deepcopy(energy_shifter)
self.overwrite = overwrite
self.device = self.aev_computer.EtaR.device self.device = self.aev_computer.EtaR.device
self.dtype = dtype self.dtype = dtype
...@@ -65,6 +69,8 @@ class Calculator(ase.calculators.calculator.Calculator): ...@@ -65,6 +69,8 @@ class Calculator(ase.calculators.calculator.Calculator):
.requires_grad_('forces' in properties) .requires_grad_('forces' in properties)
if pbc_enabled: if pbc_enabled:
coordinates = utils.map2central(cell, coordinates, pbc) coordinates = utils.map2central(cell, coordinates, pbc)
if self.overwrite and atoms is not None:
atoms.set_positions(coordinates.detach().cpu().reshape(-1, 3).numpy())
_, energy = self.whole((species, coordinates, cell, pbc)) _, energy = self.whole((species, coordinates, cell, pbc))
else: else:
_, energy = self.whole((species, coordinates)) _, energy = self.whole((species, coordinates))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment