Unverified Commit 8292fa97 authored by Ignacio Pickering's avatar Ignacio Pickering Committed by GitHub
Browse files

Fix periodic table index ase (#426)

* Fix bug that prevented using ASE together with periodic_table_index

* ASE has to know if the model has PTI or not

* Allow for models that don't have a periodic_table_index attribute
parent d6511699
......@@ -59,5 +59,30 @@ class TestASE(unittest.TestCase):
dyn.run(120)
class TestASEWithPTI(unittest.TestCase):
def setUp(self):
self.model_pti = torchani.models.ANI1x(periodic_table_index=True).double()
self.model = torchani.models.ANI1x().double()
def testEqualEnsemblePTI(self):
calculator_pti = self.model_pti.ase()
calculator = self.model.ase()
atoms = Diamond(symbol="C", pbc=True)
atoms_pti = Diamond(symbol="C", pbc=True)
atoms.set_calculator(calculator)
atoms_pti.set_calculator(calculator_pti)
self.assertEqual(atoms.get_potential_energy(), atoms_pti.get_potential_energy())
def testEqualOneModelPTI(self):
calculator_pti = self.model_pti[0].ase()
calculator = self.model[0].ase()
atoms = Diamond(symbol="C", pbc=True)
atoms_pti = Diamond(symbol="C", pbc=True)
atoms.set_calculator(calculator)
atoms_pti.set_calculator(calculator_pti)
self.assertEqual(atoms.get_potential_energy(), atoms_pti.get_potential_energy())
if __name__ == '__main__':
unittest.main()
......@@ -35,6 +35,14 @@ class Calculator(ase.calculators.calculator.Calculator):
a_parameter = next(self.model.parameters())
self.device = a_parameter.device
self.dtype = a_parameter.dtype
try:
# We assume that the model has a "periodic_table_index" attribute
# if it doesn't we set the calculator's attribute to false and we
# assume that species will be correctly transformed by
# species_to_tensor
self.periodic_table_index = model.periodic_table_index
except AttributeError:
self.periodic_table_index = False
def calculate(self, atoms=None, properties=['energy'],
system_changes=ase.calculators.calculator.all_changes):
......@@ -44,7 +52,12 @@ class Calculator(ase.calculators.calculator.Calculator):
pbc = torch.tensor(self.atoms.get_pbc(), dtype=torch.bool,
device=self.device)
pbc_enabled = pbc.any().item()
species = self.species_to_tensor(self.atoms.get_chemical_symbols()).to(self.device)
if self.periodic_table_index:
species = torch.tensor(self.atoms.get_atomic_numbers(), dtype=torch.long, device=self.device)
else:
species = self.species_to_tensor(self.atoms.get_chemical_symbols()).to(self.device)
species = species.unsqueeze(0)
coordinates = torch.tensor(self.atoms.get_positions())
coordinates = coordinates.to(self.device).to(self.dtype) \
......
......@@ -153,6 +153,7 @@ class BuiltinNet(torch.nn.Module):
ret.ase = ase
ret.species_to_tensor = self.consts.species_to_tensor
ret.periodic_table_index = self.periodic_table_index
return ret
def __len__(self):
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment