Unverified Commit db36c1a0 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Write wrapped position to atoms (#215)

parent 3596f612
......@@ -27,11 +27,14 @@ class Calculator(ase.calculators.calculator.Calculator):
energy_shifter (:class:`torchani.EnergyShifter`): Energy shifter.
dtype (:class:`torchani.EnergyShifter`): data type to use,
by dafault ``torch.float64``.
overwrite (bool): After wrapping atoms into central box, whether
to replace the original positions stored in :class:`ase.Atoms`
object with the wrapped positions.
"""
implemented_properties = ['energy', 'forces']
def __init__(self, species, aev_computer, model, energy_shifter, dtype=torch.float64):
def __init__(self, species, aev_computer, model, energy_shifter, dtype=torch.float64, overwrite=False):
super(Calculator, self).__init__()
self.species_to_tensor = utils.ChemicalSymbolsToInts(species)
# aev_computer.neighborlist will be changed later, so we need a copy to
......@@ -39,6 +42,7 @@ class Calculator(ase.calculators.calculator.Calculator):
self.aev_computer = copy.deepcopy(aev_computer)
self.model = copy.deepcopy(model)
self.energy_shifter = copy.deepcopy(energy_shifter)
self.overwrite = overwrite
self.device = self.aev_computer.EtaR.device
self.dtype = dtype
......@@ -65,6 +69,8 @@ class Calculator(ase.calculators.calculator.Calculator):
.requires_grad_('forces' in properties)
if pbc_enabled:
coordinates = utils.map2central(cell, coordinates, pbc)
if self.overwrite and atoms is not None:
atoms.set_positions(coordinates.detach().cpu().reshape(-1, 3).numpy())
_, energy = self.whole((species, coordinates, cell, pbc))
else:
_, energy = self.whole((species, coordinates))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment