Unverified Commit 35531421 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Completely rewrite AEVComputer (#197)

parent bc5f4312
...@@ -37,6 +37,7 @@ Utilities ...@@ -37,6 +37,7 @@ Utilities
.. autofunction:: torchani.utils.pad_coordinates .. autofunction:: torchani.utils.pad_coordinates
.. autofunction:: torchani.utils.present_species .. autofunction:: torchani.utils.present_species
.. autofunction:: torchani.utils.strip_redundant_padding .. autofunction:: torchani.utils.strip_redundant_padding
.. autofunction:: torchani.utils.map2central
.. autoclass:: torchani.utils.ChemicalSymbolsToInts .. autoclass:: torchani.utils.ChemicalSymbolsToInts
:members: :members:
...@@ -61,8 +62,6 @@ ASE Interface ...@@ -61,8 +62,6 @@ ASE Interface
============= =============
.. automodule:: torchani.ase .. automodule:: torchani.ase
.. autoclass:: torchani.ase.NeighborList
:members:
.. autoclass:: torchani.ase.Calculator .. autoclass:: torchani.ase.Calculator
Ignite Helpers Ignite Helpers
......
...@@ -3,7 +3,9 @@ import torchani ...@@ -3,7 +3,9 @@ import torchani
import unittest import unittest
import os import os
import pickle import pickle
import random import itertools
import ase
import math
path = os.path.dirname(os.path.realpath(__file__)) path = os.path.dirname(os.path.realpath(__file__))
N = 97 N = 97
...@@ -93,19 +95,154 @@ class TestAEV(unittest.TestCase): ...@@ -93,19 +95,154 @@ class TestAEV(unittest.TestCase):
self._assertAEVEqual(radial, angular, aev) self._assertAEVEqual(radial, angular, aev)
class TestAEVASENeighborList(TestAEV): class TestPBCSeeEachOther(unittest.TestCase):
def setUp(self): def setUp(self):
super(TestAEVASENeighborList, self).setUp() self.builtin = torchani.neurochem.Builtins()
self.aev_computer.neighborlist = torchani.ase.NeighborList() self.aev_computer = self.builtin.aev_computer.to(torch.double)
def testTranslationalInvariancePBC(self):
coordinates = torch.tensor(
[[[0, 0, 0],
[1, 0, 0],
[0, 1, 0],
[0, 0, 1],
[0, 1, 1]]],
dtype=torch.double, requires_grad=True)
cell = torch.eye(3, dtype=torch.double) * 2
species = torch.tensor([[1, 0, 0, 0, 0]], dtype=torch.long)
pbc = torch.ones(3, dtype=torch.uint8)
_, aev = self.aev_computer((species, coordinates, cell, pbc))
for _ in range(100):
translation = torch.randn(3, dtype=torch.double)
_, aev2 = self.aev_computer((species, coordinates + translation, cell, pbc))
self.assertTrue(torch.allclose(aev, aev2))
def testPBCConnersSeeEachOther(self):
species = torch.tensor([[0, 0]])
cell = torch.eye(3, dtype=torch.double) * 10
pbc = torch.ones(3, dtype=torch.uint8)
allshifts = torchani.aev.compute_shifts(cell, pbc, 1)
xyz1 = torch.tensor([0.1, 0.1, 0.1])
xyz2s = [
torch.tensor([9.9, 0.0, 0.0]),
torch.tensor([0.0, 9.9, 0.0]),
torch.tensor([0.0, 0.0, 9.9]),
torch.tensor([9.9, 9.9, 0.0]),
torch.tensor([0.0, 9.9, 9.9]),
torch.tensor([9.9, 0.0, 9.9]),
torch.tensor([9.9, 9.9, 9.9]),
]
for xyz2 in xyz2s:
coordinates = torch.stack([xyz1, xyz2]).to(torch.double).unsqueeze(0)
molecule_index, atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
self.assertEqual(molecule_index.tolist(), [0])
self.assertEqual(atom_index1.tolist(), [0])
self.assertEqual(atom_index2.tolist(), [1])
def testPBCSurfaceSeeEachOther(self):
cell = torch.eye(3, dtype=torch.double) * 10
pbc = torch.ones(3, dtype=torch.uint8)
allshifts = torchani.aev.compute_shifts(cell, pbc, 1)
species = torch.tensor([[0, 0]])
for i in range(3):
xyz1 = torch.tensor([5.0, 5.0, 5.0], dtype=torch.double)
xyz1[i] = 0.1
xyz2 = xyz1.clone()
xyz2[i] = 9.9
coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
molecule_index, atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
self.assertEqual(molecule_index.tolist(), [0])
self.assertEqual(atom_index1.tolist(), [0])
self.assertEqual(atom_index2.tolist(), [1])
def testPBCEdgesSeeEachOther(self):
cell = torch.eye(3, dtype=torch.double) * 10
pbc = torch.ones(3, dtype=torch.uint8)
allshifts = torchani.aev.compute_shifts(cell, pbc, 1)
species = torch.tensor([[0, 0]])
for i, j in itertools.combinations(range(3), 2):
xyz1 = torch.tensor([5.0, 5.0, 5.0], dtype=torch.double)
xyz1[i] = 0.1
xyz1[j] = 0.1
for new_i, new_j in [[0.1, 9.9], [9.9, 0.1], [9.9, 9.9]]:
xyz2 = xyz1.clone()
xyz2[i] = new_i
xyz2[j] = new_i
coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
molecule_index, atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
self.assertEqual(molecule_index.tolist(), [0])
self.assertEqual(atom_index1.tolist(), [0])
self.assertEqual(atom_index2.tolist(), [1])
def testNonRectangularPBCConnersSeeEachOther(self):
species = torch.tensor([[0, 0]])
cell = ase.geometry.cellpar_to_cell([10, 10, 10 * math.sqrt(2), 90, 45, 90])
cell = torch.tensor(ase.geometry.complete_cell(cell), dtype=torch.double)
pbc = torch.ones(3, dtype=torch.uint8)
allshifts = torchani.aev.compute_shifts(cell, pbc, 1)
xyz1 = torch.tensor([0.1, 0.1, 0.05], dtype=torch.double)
xyz2 = torch.tensor([10.0, 0.1, 0.1], dtype=torch.double)
coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
molecule_index, atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
self.assertEqual(molecule_index.tolist(), [0])
self.assertEqual(atom_index1.tolist(), [0])
self.assertEqual(atom_index2.tolist(), [1])
class TestAEVOnBoundary(unittest.TestCase):
def transform(self, x): def setUp(self):
"""To reduce the size of test cases for faster test speed""" self.eps = 1e-9
return x[:2, ...] cell = ase.geometry.cellpar_to_cell([100, 100, 100 * math.sqrt(2), 90, 45, 90])
self.cell = torch.tensor(ase.geometry.complete_cell(cell), dtype=torch.double)
def random_skip(self): self.inv_cell = torch.inverse(self.cell)
"""To reduce the size of test cases for faster test speed""" self.coordinates = torch.tensor([[[0.0, 0.0, 0.0],
return random.random() < 0.95 [1.0, -0.1, -0.1],
[-0.1, 1.0, -0.1],
[-0.1, -0.1, 1.0],
[-1.0, -1.0, -1.0]]], dtype=torch.double)
self.species = torch.tensor([[1, 0, 0, 0]])
self.pbc = torch.ones(3, dtype=torch.uint8)
self.v1, self.v2, self.v3 = self.cell
self.center_coordinates = self.coordinates + 0.5 * (self.v1 + self.v2 + self.v3)
builtin = torchani.neurochem.Builtins()
self.aev_computer = builtin.aev_computer.to(torch.double)
_, self.aev = self.aev_computer((self.species, self.center_coordinates, self.cell, self.pbc))
def assertInCell(self, coordinates):
coordinates_cell = coordinates @ self.inv_cell
self.assertTrue(torch.allclose(coordinates, coordinates_cell @ self.cell))
in_cell = (coordinates_cell >= -self.eps) & (coordinates_cell <= 1 + self.eps)
self.assertTrue(in_cell.all())
def assertNotInCell(self, coordinates):
coordinates_cell = coordinates @ self.inv_cell
self.assertTrue(torch.allclose(coordinates, coordinates_cell @ self.cell))
in_cell = (coordinates_cell >= -self.eps) & (coordinates_cell <= 1 + self.eps)
self.assertFalse(in_cell.all())
def testCornerSurfaceAndEdge(self):
for i, j, k in itertools.product([0, 0.5, 1], repeat=3):
if i == 0.5 and j == 0.5 and k == 0.5:
continue
coordinates = self.coordinates + i * self.v1 + j * self.v2 + k * self.v3
self.assertNotInCell(coordinates)
coordinates = torchani.utils.map2central(self.cell, coordinates, self.pbc)
self.assertInCell(coordinates)
_, aev = self.aev_computer((self.species, coordinates, self.cell, self.pbc))
self.assertGreater(aev.abs().max().item(), 0)
self.assertTrue(torch.allclose(aev, self.aev))
if __name__ == '__main__': if __name__ == '__main__':
......
from ase.lattice.cubic import Diamond from ase.lattice.cubic import Diamond
from ase.md.langevin import Langevin from ase.md.langevin import Langevin
from ase import units, Atoms from ase import units
from ase.calculators.test import numeric_force from ase.calculators.test import numeric_force
import torch import torch
import torchani import torchani
import unittest import unittest
import numpy
import itertools
import math
import os import os
import pickle
path = os.path.dirname(os.path.realpath(__file__)) path = os.path.dirname(os.path.realpath(__file__))
N = 97 N = 97
...@@ -26,8 +22,8 @@ def get_numeric_force(atoms, eps): ...@@ -26,8 +22,8 @@ def get_numeric_force(atoms, eps):
class TestASE(unittest.TestCase): class TestASE(unittest.TestCase):
def _testForce(self, pbc): def testWithNumericalForceWithPBCEnabled(self):
atoms = Diamond(symbol="C", pbc=pbc) atoms = Diamond(symbol="C", pbc=True)
builtin = torchani.neurochem.Builtins() builtin = torchani.neurochem.Builtins()
calculator = torchani.ase.Calculator( calculator = torchani.ase.Calculator(
builtin.species, builtin.aev_computer, builtin.species, builtin.aev_computer,
...@@ -42,194 +38,6 @@ class TestASE(unittest.TestCase): ...@@ -42,194 +38,6 @@ class TestASE(unittest.TestCase):
if avgf > 0: if avgf > 0:
self.assertLess(df / avgf, 0.1) self.assertLess(df / avgf, 0.1)
def testForceWithPBCEnabled(self):
self._testForce(True)
def testForceWithPBCDisabled(self):
self._testForce(False)
def testANIDataset(self):
builtin = torchani.neurochem.Builtins()
calculator = torchani.ase.Calculator(
builtin.species, builtin.aev_computer,
builtin.models, builtin.energy_shifter)
default_neighborlist_calculator = torchani.ase.Calculator(
builtin.species, builtin.aev_computer,
builtin.models, builtin.energy_shifter, _default_neighborlist=True)
nnp = torch.nn.Sequential(
builtin.aev_computer,
builtin.models,
builtin.energy_shifter
)
for i in range(N):
datafile = os.path.join(path, 'test_data/ANI1_subset/{}'.format(i))
with open(datafile, 'rb') as f:
coordinates, species, _, _, _, _ = pickle.load(f)
coordinates = coordinates[0]
species = species[0]
species_str = [builtin.consts.species[i] for i in species]
atoms = Atoms(species_str, positions=coordinates)
atoms.set_calculator(calculator)
energy1 = atoms.get_potential_energy() / units.Hartree
forces1 = atoms.get_forces() / units.Hartree
atoms2 = Atoms(species_str, positions=coordinates)
atoms2.set_calculator(default_neighborlist_calculator)
energy2 = atoms2.get_potential_energy() / units.Hartree
forces2 = atoms2.get_forces() / units.Hartree
coordinates = torch.tensor(coordinates,
requires_grad=True).unsqueeze(0)
_, energy3 = nnp((torch.from_numpy(species).unsqueeze(0),
coordinates))
forces3 = -torch.autograd.grad(energy3.squeeze(),
coordinates)[0].numpy()
energy3 = energy3.item()
self.assertLess(abs(energy1 - energy2), tol)
self.assertLess(abs(energy1 - energy3), tol)
diff_f12 = torch.tensor(forces1 - forces2).abs().max().item()
self.assertLess(diff_f12, tol)
diff_f13 = torch.tensor(forces1 - forces3).abs().max().item()
self.assertLess(diff_f13, tol)
def testForceAgainstDefaultNeighborList(self):
atoms = Diamond(symbol="C", pbc=False)
builtin = torchani.neurochem.Builtins()
calculator = torchani.ase.Calculator(
builtin.species, builtin.aev_computer,
builtin.models, builtin.energy_shifter)
default_neighborlist_calculator = torchani.ase.Calculator(
builtin.species, builtin.aev_computer,
builtin.models, builtin.energy_shifter, _default_neighborlist=True)
atoms.set_calculator(calculator)
dyn = Langevin(atoms, 5 * units.fs, 50 * units.kB, 0.002)
def test_energy(a=atoms):
a = a.copy()
a.set_calculator(calculator)
e1 = a.get_potential_energy()
a.set_calculator(default_neighborlist_calculator)
e2 = a.get_potential_energy()
self.assertLess(abs(e1 - e2), tol)
dyn.attach(test_energy, interval=1)
dyn.run(500)
def testTranslationalInvariancePBC(self):
atoms = Atoms('CH4', [[0, 0, 0],
[1, 0, 0],
[0, 1, 0],
[0, 0, 1],
[0, 1, 1]],
cell=[2, 2, 2], pbc=True)
builtin = torchani.neurochem.Builtins()
calculator = torchani.ase.Calculator(
builtin.species, builtin.aev_computer,
builtin.models, builtin.energy_shifter)
atoms.set_calculator(calculator)
e = atoms.get_potential_energy()
for _ in range(100):
positions = atoms.get_positions()
translation = (numpy.random.rand(3) - 0.5) * 2
atoms.set_positions(positions + translation)
self.assertEqual(e, atoms.get_potential_energy())
def assertTensorEqual(self, a, b):
self.assertLess((a - b).abs().max().item(), 1e-6)
def testPBCConnersSeeEachOther(self):
species = torch.tensor([[0, 0]])
neighborlist = torchani.ase.NeighborList(cell=[10, 10, 10], pbc=True)
xyz1 = torch.tensor([0.1, 0.1, 0.1])
xyz2s = [
torch.tensor([9.9, 0.0, 0.0]),
torch.tensor([0.0, 9.9, 0.0]),
torch.tensor([0.0, 0.0, 9.9]),
torch.tensor([9.9, 9.9, 0.0]),
torch.tensor([0.0, 9.9, 9.9]),
torch.tensor([9.9, 0.0, 9.9]),
torch.tensor([9.9, 9.9, 9.9]),
]
for xyz2 in xyz2s:
coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
s, _, D = neighborlist(species, coordinates, 1)
self.assertListEqual(list(s.shape), [1, 2, 1])
neighbor_coordinate = D[0][0].squeeze() + xyz1
mirror = xyz2
for i in range(3):
if mirror[i] > 5:
mirror[i] -= 10
self.assertTensorEqual(neighbor_coordinate, mirror)
def testPBCSurfaceSeeEachOther(self):
species = torch.tensor([[0, 0]])
neighborlist = torchani.ase.NeighborList(cell=[10, 10, 10], pbc=True)
for i in range(3):
xyz1 = torch.tensor([5.0, 5.0, 5.0])
xyz1[i] = 0.1
xyz2 = xyz1.clone()
xyz2[i] = 9.9
coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
s, _, D = neighborlist(species, coordinates, 1)
self.assertListEqual(list(s.shape), [1, 2, 1])
neighbor_coordinate = D[0][0].squeeze() + xyz1
xyz2[i] = -0.1
self.assertTensorEqual(neighbor_coordinate, xyz2)
def testPBCEdgesSeeEachOther(self):
species = torch.tensor([[0, 0]])
neighborlist = torchani.ase.NeighborList(cell=[10, 10, 10], pbc=True)
for i, j in itertools.combinations(range(3), 2):
xyz1 = torch.tensor([5.0, 5.0, 5.0])
xyz1[i] = 0.1
xyz1[j] = 0.1
for new_i, new_j in [[0.1, 9.9], [9.9, 0.1], [9.9, 9.9]]:
xyz2 = xyz1.clone()
xyz2[i] = new_i
xyz2[j] = new_i
coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
s, _, D = neighborlist(species, coordinates, 1)
self.assertListEqual(list(s.shape), [1, 2, 1])
neighbor_coordinate = D[0][0].squeeze() + xyz1
if xyz2[i] > 5:
xyz2[i] = -0.1
if xyz2[j] > 5:
xyz2[j] = -0.1
self.assertTensorEqual(neighbor_coordinate, xyz2)
def testNonRectangularPBCConnersSeeEachOther(self):
species = torch.tensor([[0, 0]])
neighborlist = torchani.ase.NeighborList(
cell=[10, 10, 10 * math.sqrt(2), 90, 45, 90], pbc=True)
xyz1 = torch.tensor([0.1, 0.1, 0.05])
xyz2 = torch.tensor([10.0, 0.1, 0.1])
mirror = torch.tensor([0.0, 0.1, 0.1])
coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
s, _, D = neighborlist(species, coordinates, 1)
self.assertListEqual(list(s.shape), [1, 2, 1])
neighbor_coordinate = D[0][0].squeeze() + xyz1
for i in range(3):
if mirror[i] > 5:
mirror[i] -= 10
self.assertTensorEqual(neighbor_coordinate, mirror)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -21,7 +21,7 @@ class TestData(unittest.TestCase): ...@@ -21,7 +21,7 @@ class TestData(unittest.TestCase):
batch_size) batch_size)
def _assertTensorEqual(self, t1, t2): def _assertTensorEqual(self, t1, t2):
self.assertEqual((t1 - t2).abs().max().item(), 0) self.assertLess((t1 - t2).abs().max().item(), 1e-6)
def testSplitBatch(self): def testSplitBatch(self):
species1 = torch.randint(4, (5, 4), dtype=torch.long) species1 = torch.randint(4, (5, 4), dtype=torch.long)
......
...@@ -84,7 +84,6 @@ class TestEnergiesASEComputer(TestEnergies): ...@@ -84,7 +84,6 @@ class TestEnergiesASEComputer(TestEnergies):
def setUp(self): def setUp(self):
super(TestEnergiesASEComputer, self).setUp() super(TestEnergiesASEComputer, self).setUp()
self.aev_computer.neighborlist = torchani.ase.NeighborList()
def transform(self, x): def transform(self, x):
"""To reduce the size of test cases for faster test speed""" """To reduce the size of test cases for faster test speed"""
......
...@@ -90,7 +90,6 @@ class TestForceASEComputer(TestForce): ...@@ -90,7 +90,6 @@ class TestForceASEComputer(TestForce):
def setUp(self): def setUp(self):
super(TestForceASEComputer, self).setUp() super(TestForceASEComputer, self).setUp()
self.aev_computer.neighborlist = torchani.ase.NeighborList()
def transform(self, x): def transform(self, x):
"""To reduce the size of test cases for faster test speed""" """To reduce the size of test cases for faster test speed"""
......
...@@ -22,7 +22,8 @@ class TestIgnite(unittest.TestCase): ...@@ -22,7 +22,8 @@ class TestIgnite(unittest.TestCase):
shift_energy = builtins.energy_shifter shift_energy = builtins.energy_shifter
ds = torchani.data.BatchedANIDataset( ds = torchani.data.BatchedANIDataset(
path, builtins.consts.species_to_tensor, batchsize, path, builtins.consts.species_to_tensor, batchsize,
transform=[shift_energy.subtract_from_dataset]) transform=[shift_energy.subtract_from_dataset],
device=aev_computer.EtaR.device)
ds = torch.utils.data.Subset(ds, [0]) ds = torch.utils.data.Subset(ds, [0])
class Flatten(torch.nn.Module): class Flatten(torch.nn.Module):
......
This diff is collapsed.
...@@ -6,95 +6,13 @@ ...@@ -6,95 +6,13 @@
""" """
from __future__ import absolute_import from __future__ import absolute_import
import math
import torch import torch
import ase.neighborlist import ase.neighborlist
from . import utils from . import utils
import ase.calculators.calculator import ase.calculators.calculator
import ase.units import ase.units
import copy import copy
import numpy
class NeighborList(torch.nn.Module):
"""ASE neighborlist computer
Arguments:
cell: same as in :class:`ase.Atoms`
pbc: same as in :class:`ase.Atoms`
"""
def __init__(self, cell=None, pbc=None):
# wrap `cell` and `pbc` with `ase.Atoms`
super(NeighborList, self).__init__()
a = ase.Atoms('He', [[0, 0, 0]], cell=cell, pbc=pbc)
self.pbc = a.get_pbc()
self.cell = a.get_cell(complete=True)
def forward(self, species, coordinates, cutoff):
conformations = species.shape[0]
max_atoms = species.shape[1]
neighbor_species = []
neighbor_distances = []
neighbor_vecs = []
for i in range(conformations):
s = species[i].unsqueeze(0)
c = coordinates[i].unsqueeze(0)
s, c = utils.strip_redundant_padding(s, c)
s = s.squeeze()
c = c.squeeze()
atoms = s.shape[0]
atoms_object = ase.Atoms(
['He'] * atoms, # chemical symbols are not important here
positions=c.detach().numpy(),
pbc=self.pbc,
cell=self.cell)
idx1, idx2, shift = ase.neighborlist.neighbor_list(
'ijS', atoms_object, cutoff)
# NB: The absolute distance and distance vectors computed by
# `neighbor_list`can not be used since it does not preserve
# gradient information
idx1 = torch.tensor(idx1, device=coordinates.device,
dtype=torch.long)
idx2 = torch.tensor(idx2, device=coordinates.device,
dtype=torch.long)
D = c.index_select(0, idx2) - c.index_select(0, idx1)
shift = torch.tensor(shift, device=coordinates.device,
dtype=coordinates.dtype)
cell = torch.tensor(self.cell, device=coordinates.device,
dtype=coordinates.dtype)
D += torch.mm(shift, cell)
d = D.norm(2, -1)
neighbor_species1 = []
neighbor_distances1 = []
neighbor_vecs1 = []
for i in range(atoms):
this_atom_indices = (idx1 == i).nonzero().flatten()
neighbor_indices = idx2[this_atom_indices]
neighbor_species1.append(s[neighbor_indices])
neighbor_distances1.append(d[this_atom_indices])
neighbor_vecs1.append(D.index_select(0, this_atom_indices))
for i in range(max_atoms - atoms):
neighbor_species1.append(torch.full((1,), -1))
neighbor_distances1.append(torch.full((1,), math.inf))
neighbor_vecs1.append(torch.full((1, 3), 0))
neighbor_species1 = torch.nn.utils.rnn.pad_sequence(
neighbor_species1, padding_value=-1)
neighbor_distances1 = torch.nn.utils.rnn.pad_sequence(
neighbor_distances1, padding_value=math.inf)
neighbor_vecs1 = torch.nn.utils.rnn.pad_sequence(
neighbor_vecs1, padding_value=0)
neighbor_species.append(neighbor_species1)
neighbor_distances.append(neighbor_distances1)
neighbor_vecs.append(neighbor_vecs1)
neighbor_species = torch.nn.utils.rnn.pad_sequence(
neighbor_species, batch_first=True, padding_value=-1)
neighbor_distances = torch.nn.utils.rnn.pad_sequence(
neighbor_distances, batch_first=True, padding_value=math.inf)
neighbor_vecs = torch.nn.utils.rnn.pad_sequence(
neighbor_vecs, batch_first=True, padding_value=0)
return neighbor_species.permute(0, 2, 1), \
neighbor_distances.permute(0, 2, 1), \
neighbor_vecs.permute(0, 2, 1, 3)
class Calculator(ase.calculators.calculator.Calculator): class Calculator(ase.calculators.calculator.Calculator):
...@@ -109,16 +27,12 @@ class Calculator(ase.calculators.calculator.Calculator): ...@@ -109,16 +27,12 @@ class Calculator(ase.calculators.calculator.Calculator):
energy_shifter (:class:`torchani.EnergyShifter`): Energy shifter. energy_shifter (:class:`torchani.EnergyShifter`): Energy shifter.
dtype (:class:`torchani.EnergyShifter`): data type to use, dtype (:class:`torchani.EnergyShifter`): data type to use,
by dafault ``torch.float64``. by dafault ``torch.float64``.
_default_neighborlist (bool): Whether to ignore pbc setting and always
use default neighborlist computer. This is for internal use only.
""" """
implemented_properties = ['energy', 'forces'] implemented_properties = ['energy', 'forces']
def __init__(self, species, aev_computer, model, energy_shifter, def __init__(self, species, aev_computer, model, energy_shifter, dtype=torch.float64):
dtype=torch.float64, _default_neighborlist=False):
super(Calculator, self).__init__() super(Calculator, self).__init__()
self._default_neighborlist = _default_neighborlist
self.species_to_tensor = utils.ChemicalSymbolsToInts(species) self.species_to_tensor = utils.ChemicalSymbolsToInts(species)
# aev_computer.neighborlist will be changed later, so we need a copy to # aev_computer.neighborlist will be changed later, so we need a copy to
# make sure we do not change the original object # make sure we do not change the original object
...@@ -138,16 +52,18 @@ class Calculator(ase.calculators.calculator.Calculator): ...@@ -138,16 +52,18 @@ class Calculator(ase.calculators.calculator.Calculator):
def calculate(self, atoms=None, properties=['energy'], def calculate(self, atoms=None, properties=['energy'],
system_changes=ase.calculators.calculator.all_changes): system_changes=ase.calculators.calculator.all_changes):
super(Calculator, self).calculate(atoms, properties, system_changes) super(Calculator, self).calculate(atoms, properties, system_changes)
if not self._default_neighborlist: cell = torch.tensor(self.atoms.get_cell(complete=True),
self.aev_computer.neighborlist.pbc = self.atoms.get_pbc() requires_grad=True, dtype=self.dtype,
self.aev_computer.neighborlist.cell = \ device=self.device)
self.atoms.get_cell(complete=True) pbc = torch.tensor(self.atoms.get_pbc().astype(numpy.uint8), dtype=torch.uint8,
device=self.device)
# print(cell, pbc)
species = self.species_to_tensor(self.atoms.get_chemical_symbols()) species = self.species_to_tensor(self.atoms.get_chemical_symbols())
species = species.unsqueeze(0) species = species.unsqueeze(0)
coordinates = torch.tensor(self.atoms.get_positions()) coordinates = torch.tensor(self.atoms.get_positions())
coordinates = coordinates.unsqueeze(0).to(self.device).to(self.dtype) \ coordinates = coordinates.unsqueeze(0).to(self.device).to(self.dtype) \
.requires_grad_('forces' in properties) .requires_grad_('forces' in properties)
_, energy = self.whole((species, coordinates)) _, energy = self.whole((species, coordinates, cell, pbc))
energy *= ase.units.Hartree energy *= ase.units.Hartree
self.results['energy'] = energy.item() self.results['energy'] = energy.item()
if 'forces' in properties: if 'forces' in properties:
......
...@@ -65,7 +65,7 @@ def pad_coordinates(species_coordinates): ...@@ -65,7 +65,7 @@ def pad_coordinates(species_coordinates):
return torch.cat(species), torch.cat(coordinates) return torch.cat(species), torch.cat(coordinates)
@torch.jit.script # @torch.jit.script
def present_species(species): def present_species(species):
"""Given a vector of species of atoms, compute the unique species present. """Given a vector of species of atoms, compute the unique species present.
...@@ -75,7 +75,8 @@ def present_species(species): ...@@ -75,7 +75,8 @@ def present_species(species):
Returns: Returns:
:class:`torch.Tensor`: 1D vector storing present atom types sorted. :class:`torch.Tensor`: 1D vector storing present atom types sorted.
""" """
present_species, _ = species.flatten()._unique(sorted=True) # present_species, _ = species.flatten()._unique(sorted=True)
present_species = species.flatten().unique(sorted=True)
if present_species[0].item() == -1: if present_species[0].item() == -1:
present_species = present_species[1:] present_species = present_species[1:]
return present_species return present_species
...@@ -86,9 +87,9 @@ def strip_redundant_padding(species, coordinates): ...@@ -86,9 +87,9 @@ def strip_redundant_padding(species, coordinates):
Arguments: Arguments:
species (:class:`torch.Tensor`): Long tensor of shape species (:class:`torch.Tensor`): Long tensor of shape
``(conformations, atoms)``. ``(molecules, atoms)``.
coordinates (:class:`torch.Tensor`): Tensor of shape coordinates (:class:`torch.Tensor`): Tensor of shape
``(conformations, atoms, 3)``. ``(molecules, atoms, 3)``.
Returns: Returns:
(:class:`torch.Tensor`, :class:`torch.Tensor`): species and coordinates (:class:`torch.Tensor`, :class:`torch.Tensor`): species and coordinates
...@@ -100,6 +101,39 @@ def strip_redundant_padding(species, coordinates): ...@@ -100,6 +101,39 @@ def strip_redundant_padding(species, coordinates):
return species, coordinates return species, coordinates
def map2central(cell, coordinates, pbc):
"""Map atoms outside the unit cell into the cell using PBC.
Arguments:
cell (:class:`torch.Tensor`): tensor of shape (3, 3) of the three
vectors defining unit cell:
.. code-block:: python
tensor([[x1, y1, z1],
[x2, y2, z2],
[x3, y3, z3]])
coordinates (:class:`torch.Tensor`): Tensor of shape
``(molecules, atoms, 3)``.
pbc (:class:`torch.Tensor`): boolean vector of size 3 storing
if pbc is enabled for that direction.
Returns:
:class:`torch.Tensor`: coordinates of atoms mapped back to unit cell.
"""
# Step 1: convert coordinates from standard cartesian coordinate to unit
# cell coordinates
inv_cell = torch.inverse(cell)
coordinates_cell = torch.matmul(coordinates, inv_cell)
# Step 2: wrap cell coordinates into [0, 1)
coordinates_cell -= coordinates_cell.floor() * pbc.to(coordinates_cell.dtype)
# Step 3: convert from cell coordinates back to standard cartesian
# coordinate
return torch.matmul(coordinates_cell, cell)
class EnergyShifter(torch.nn.Module): class EnergyShifter(torch.nn.Module):
"""Helper class for adding and subtracting self atomic energies """Helper class for adding and subtracting self atomic energies
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment