Unverified Commit 35531421 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Completely rewrite AEVComputer (#197)

parent bc5f4312
...@@ -37,6 +37,7 @@ Utilities ...@@ -37,6 +37,7 @@ Utilities
.. autofunction:: torchani.utils.pad_coordinates .. autofunction:: torchani.utils.pad_coordinates
.. autofunction:: torchani.utils.present_species .. autofunction:: torchani.utils.present_species
.. autofunction:: torchani.utils.strip_redundant_padding .. autofunction:: torchani.utils.strip_redundant_padding
.. autofunction:: torchani.utils.map2central
.. autoclass:: torchani.utils.ChemicalSymbolsToInts .. autoclass:: torchani.utils.ChemicalSymbolsToInts
:members: :members:
...@@ -61,8 +62,6 @@ ASE Interface ...@@ -61,8 +62,6 @@ ASE Interface
============= =============
.. automodule:: torchani.ase .. automodule:: torchani.ase
.. autoclass:: torchani.ase.NeighborList
:members:
.. autoclass:: torchani.ase.Calculator .. autoclass:: torchani.ase.Calculator
Ignite Helpers Ignite Helpers
......
...@@ -3,7 +3,9 @@ import torchani ...@@ -3,7 +3,9 @@ import torchani
import unittest import unittest
import os import os
import pickle import pickle
import random import itertools
import ase
import math
path = os.path.dirname(os.path.realpath(__file__)) path = os.path.dirname(os.path.realpath(__file__))
N = 97 N = 97
...@@ -93,19 +95,154 @@ class TestAEV(unittest.TestCase): ...@@ -93,19 +95,154 @@ class TestAEV(unittest.TestCase):
self._assertAEVEqual(radial, angular, aev) self._assertAEVEqual(radial, angular, aev)
class TestAEVASENeighborList(TestAEV): class TestPBCSeeEachOther(unittest.TestCase):
def setUp(self): def setUp(self):
super(TestAEVASENeighborList, self).setUp() self.builtin = torchani.neurochem.Builtins()
self.aev_computer.neighborlist = torchani.ase.NeighborList() self.aev_computer = self.builtin.aev_computer.to(torch.double)
def testTranslationalInvariancePBC(self):
coordinates = torch.tensor(
[[[0, 0, 0],
[1, 0, 0],
[0, 1, 0],
[0, 0, 1],
[0, 1, 1]]],
dtype=torch.double, requires_grad=True)
cell = torch.eye(3, dtype=torch.double) * 2
species = torch.tensor([[1, 0, 0, 0, 0]], dtype=torch.long)
pbc = torch.ones(3, dtype=torch.uint8)
_, aev = self.aev_computer((species, coordinates, cell, pbc))
for _ in range(100):
translation = torch.randn(3, dtype=torch.double)
_, aev2 = self.aev_computer((species, coordinates + translation, cell, pbc))
self.assertTrue(torch.allclose(aev, aev2))
def testPBCConnersSeeEachOther(self):
species = torch.tensor([[0, 0]])
cell = torch.eye(3, dtype=torch.double) * 10
pbc = torch.ones(3, dtype=torch.uint8)
allshifts = torchani.aev.compute_shifts(cell, pbc, 1)
xyz1 = torch.tensor([0.1, 0.1, 0.1])
xyz2s = [
torch.tensor([9.9, 0.0, 0.0]),
torch.tensor([0.0, 9.9, 0.0]),
torch.tensor([0.0, 0.0, 9.9]),
torch.tensor([9.9, 9.9, 0.0]),
torch.tensor([0.0, 9.9, 9.9]),
torch.tensor([9.9, 0.0, 9.9]),
torch.tensor([9.9, 9.9, 9.9]),
]
for xyz2 in xyz2s:
coordinates = torch.stack([xyz1, xyz2]).to(torch.double).unsqueeze(0)
molecule_index, atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
self.assertEqual(molecule_index.tolist(), [0])
self.assertEqual(atom_index1.tolist(), [0])
self.assertEqual(atom_index2.tolist(), [1])
def testPBCSurfaceSeeEachOther(self):
cell = torch.eye(3, dtype=torch.double) * 10
pbc = torch.ones(3, dtype=torch.uint8)
allshifts = torchani.aev.compute_shifts(cell, pbc, 1)
species = torch.tensor([[0, 0]])
for i in range(3):
xyz1 = torch.tensor([5.0, 5.0, 5.0], dtype=torch.double)
xyz1[i] = 0.1
xyz2 = xyz1.clone()
xyz2[i] = 9.9
coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
molecule_index, atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
self.assertEqual(molecule_index.tolist(), [0])
self.assertEqual(atom_index1.tolist(), [0])
self.assertEqual(atom_index2.tolist(), [1])
def testPBCEdgesSeeEachOther(self):
cell = torch.eye(3, dtype=torch.double) * 10
pbc = torch.ones(3, dtype=torch.uint8)
allshifts = torchani.aev.compute_shifts(cell, pbc, 1)
species = torch.tensor([[0, 0]])
for i, j in itertools.combinations(range(3), 2):
xyz1 = torch.tensor([5.0, 5.0, 5.0], dtype=torch.double)
xyz1[i] = 0.1
xyz1[j] = 0.1
for new_i, new_j in [[0.1, 9.9], [9.9, 0.1], [9.9, 9.9]]:
xyz2 = xyz1.clone()
xyz2[i] = new_i
xyz2[j] = new_i
coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
molecule_index, atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
self.assertEqual(molecule_index.tolist(), [0])
self.assertEqual(atom_index1.tolist(), [0])
self.assertEqual(atom_index2.tolist(), [1])
def testNonRectangularPBCConnersSeeEachOther(self):
species = torch.tensor([[0, 0]])
cell = ase.geometry.cellpar_to_cell([10, 10, 10 * math.sqrt(2), 90, 45, 90])
cell = torch.tensor(ase.geometry.complete_cell(cell), dtype=torch.double)
pbc = torch.ones(3, dtype=torch.uint8)
allshifts = torchani.aev.compute_shifts(cell, pbc, 1)
xyz1 = torch.tensor([0.1, 0.1, 0.05], dtype=torch.double)
xyz2 = torch.tensor([10.0, 0.1, 0.1], dtype=torch.double)
coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
molecule_index, atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
self.assertEqual(molecule_index.tolist(), [0])
self.assertEqual(atom_index1.tolist(), [0])
self.assertEqual(atom_index2.tolist(), [1])
class TestAEVOnBoundary(unittest.TestCase):
def transform(self, x): def setUp(self):
"""To reduce the size of test cases for faster test speed""" self.eps = 1e-9
return x[:2, ...] cell = ase.geometry.cellpar_to_cell([100, 100, 100 * math.sqrt(2), 90, 45, 90])
self.cell = torch.tensor(ase.geometry.complete_cell(cell), dtype=torch.double)
def random_skip(self): self.inv_cell = torch.inverse(self.cell)
"""To reduce the size of test cases for faster test speed""" self.coordinates = torch.tensor([[[0.0, 0.0, 0.0],
return random.random() < 0.95 [1.0, -0.1, -0.1],
[-0.1, 1.0, -0.1],
[-0.1, -0.1, 1.0],
[-1.0, -1.0, -1.0]]], dtype=torch.double)
self.species = torch.tensor([[1, 0, 0, 0]])
self.pbc = torch.ones(3, dtype=torch.uint8)
self.v1, self.v2, self.v3 = self.cell
self.center_coordinates = self.coordinates + 0.5 * (self.v1 + self.v2 + self.v3)
builtin = torchani.neurochem.Builtins()
self.aev_computer = builtin.aev_computer.to(torch.double)
_, self.aev = self.aev_computer((self.species, self.center_coordinates, self.cell, self.pbc))
def assertInCell(self, coordinates):
coordinates_cell = coordinates @ self.inv_cell
self.assertTrue(torch.allclose(coordinates, coordinates_cell @ self.cell))
in_cell = (coordinates_cell >= -self.eps) & (coordinates_cell <= 1 + self.eps)
self.assertTrue(in_cell.all())
def assertNotInCell(self, coordinates):
coordinates_cell = coordinates @ self.inv_cell
self.assertTrue(torch.allclose(coordinates, coordinates_cell @ self.cell))
in_cell = (coordinates_cell >= -self.eps) & (coordinates_cell <= 1 + self.eps)
self.assertFalse(in_cell.all())
def testCornerSurfaceAndEdge(self):
for i, j, k in itertools.product([0, 0.5, 1], repeat=3):
if i == 0.5 and j == 0.5 and k == 0.5:
continue
coordinates = self.coordinates + i * self.v1 + j * self.v2 + k * self.v3
self.assertNotInCell(coordinates)
coordinates = torchani.utils.map2central(self.cell, coordinates, self.pbc)
self.assertInCell(coordinates)
_, aev = self.aev_computer((self.species, coordinates, self.cell, self.pbc))
self.assertGreater(aev.abs().max().item(), 0)
self.assertTrue(torch.allclose(aev, self.aev))
if __name__ == '__main__': if __name__ == '__main__':
......
from ase.lattice.cubic import Diamond from ase.lattice.cubic import Diamond
from ase.md.langevin import Langevin from ase.md.langevin import Langevin
from ase import units, Atoms from ase import units
from ase.calculators.test import numeric_force from ase.calculators.test import numeric_force
import torch import torch
import torchani import torchani
import unittest import unittest
import numpy
import itertools
import math
import os import os
import pickle
path = os.path.dirname(os.path.realpath(__file__)) path = os.path.dirname(os.path.realpath(__file__))
N = 97 N = 97
...@@ -26,8 +22,8 @@ def get_numeric_force(atoms, eps): ...@@ -26,8 +22,8 @@ def get_numeric_force(atoms, eps):
class TestASE(unittest.TestCase): class TestASE(unittest.TestCase):
def _testForce(self, pbc): def testWithNumericalForceWithPBCEnabled(self):
atoms = Diamond(symbol="C", pbc=pbc) atoms = Diamond(symbol="C", pbc=True)
builtin = torchani.neurochem.Builtins() builtin = torchani.neurochem.Builtins()
calculator = torchani.ase.Calculator( calculator = torchani.ase.Calculator(
builtin.species, builtin.aev_computer, builtin.species, builtin.aev_computer,
...@@ -42,194 +38,6 @@ class TestASE(unittest.TestCase): ...@@ -42,194 +38,6 @@ class TestASE(unittest.TestCase):
if avgf > 0: if avgf > 0:
self.assertLess(df / avgf, 0.1) self.assertLess(df / avgf, 0.1)
def testForceWithPBCEnabled(self):
self._testForce(True)
def testForceWithPBCDisabled(self):
self._testForce(False)
def testANIDataset(self):
builtin = torchani.neurochem.Builtins()
calculator = torchani.ase.Calculator(
builtin.species, builtin.aev_computer,
builtin.models, builtin.energy_shifter)
default_neighborlist_calculator = torchani.ase.Calculator(
builtin.species, builtin.aev_computer,
builtin.models, builtin.energy_shifter, _default_neighborlist=True)
nnp = torch.nn.Sequential(
builtin.aev_computer,
builtin.models,
builtin.energy_shifter
)
for i in range(N):
datafile = os.path.join(path, 'test_data/ANI1_subset/{}'.format(i))
with open(datafile, 'rb') as f:
coordinates, species, _, _, _, _ = pickle.load(f)
coordinates = coordinates[0]
species = species[0]
species_str = [builtin.consts.species[i] for i in species]
atoms = Atoms(species_str, positions=coordinates)
atoms.set_calculator(calculator)
energy1 = atoms.get_potential_energy() / units.Hartree
forces1 = atoms.get_forces() / units.Hartree
atoms2 = Atoms(species_str, positions=coordinates)
atoms2.set_calculator(default_neighborlist_calculator)
energy2 = atoms2.get_potential_energy() / units.Hartree
forces2 = atoms2.get_forces() / units.Hartree
coordinates = torch.tensor(coordinates,
requires_grad=True).unsqueeze(0)
_, energy3 = nnp((torch.from_numpy(species).unsqueeze(0),
coordinates))
forces3 = -torch.autograd.grad(energy3.squeeze(),
coordinates)[0].numpy()
energy3 = energy3.item()
self.assertLess(abs(energy1 - energy2), tol)
self.assertLess(abs(energy1 - energy3), tol)
diff_f12 = torch.tensor(forces1 - forces2).abs().max().item()
self.assertLess(diff_f12, tol)
diff_f13 = torch.tensor(forces1 - forces3).abs().max().item()
self.assertLess(diff_f13, tol)
def testForceAgainstDefaultNeighborList(self):
atoms = Diamond(symbol="C", pbc=False)
builtin = torchani.neurochem.Builtins()
calculator = torchani.ase.Calculator(
builtin.species, builtin.aev_computer,
builtin.models, builtin.energy_shifter)
default_neighborlist_calculator = torchani.ase.Calculator(
builtin.species, builtin.aev_computer,
builtin.models, builtin.energy_shifter, _default_neighborlist=True)
atoms.set_calculator(calculator)
dyn = Langevin(atoms, 5 * units.fs, 50 * units.kB, 0.002)
def test_energy(a=atoms):
a = a.copy()
a.set_calculator(calculator)
e1 = a.get_potential_energy()
a.set_calculator(default_neighborlist_calculator)
e2 = a.get_potential_energy()
self.assertLess(abs(e1 - e2), tol)
dyn.attach(test_energy, interval=1)
dyn.run(500)
def testTranslationalInvariancePBC(self):
atoms = Atoms('CH4', [[0, 0, 0],
[1, 0, 0],
[0, 1, 0],
[0, 0, 1],
[0, 1, 1]],
cell=[2, 2, 2], pbc=True)
builtin = torchani.neurochem.Builtins()
calculator = torchani.ase.Calculator(
builtin.species, builtin.aev_computer,
builtin.models, builtin.energy_shifter)
atoms.set_calculator(calculator)
e = atoms.get_potential_energy()
for _ in range(100):
positions = atoms.get_positions()
translation = (numpy.random.rand(3) - 0.5) * 2
atoms.set_positions(positions + translation)
self.assertEqual(e, atoms.get_potential_energy())
def assertTensorEqual(self, a, b):
self.assertLess((a - b).abs().max().item(), 1e-6)
def testPBCConnersSeeEachOther(self):
species = torch.tensor([[0, 0]])
neighborlist = torchani.ase.NeighborList(cell=[10, 10, 10], pbc=True)
xyz1 = torch.tensor([0.1, 0.1, 0.1])
xyz2s = [
torch.tensor([9.9, 0.0, 0.0]),
torch.tensor([0.0, 9.9, 0.0]),
torch.tensor([0.0, 0.0, 9.9]),
torch.tensor([9.9, 9.9, 0.0]),
torch.tensor([0.0, 9.9, 9.9]),
torch.tensor([9.9, 0.0, 9.9]),
torch.tensor([9.9, 9.9, 9.9]),
]
for xyz2 in xyz2s:
coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
s, _, D = neighborlist(species, coordinates, 1)
self.assertListEqual(list(s.shape), [1, 2, 1])
neighbor_coordinate = D[0][0].squeeze() + xyz1
mirror = xyz2
for i in range(3):
if mirror[i] > 5:
mirror[i] -= 10
self.assertTensorEqual(neighbor_coordinate, mirror)
def testPBCSurfaceSeeEachOther(self):
species = torch.tensor([[0, 0]])
neighborlist = torchani.ase.NeighborList(cell=[10, 10, 10], pbc=True)
for i in range(3):
xyz1 = torch.tensor([5.0, 5.0, 5.0])
xyz1[i] = 0.1
xyz2 = xyz1.clone()
xyz2[i] = 9.9
coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
s, _, D = neighborlist(species, coordinates, 1)
self.assertListEqual(list(s.shape), [1, 2, 1])
neighbor_coordinate = D[0][0].squeeze() + xyz1
xyz2[i] = -0.1
self.assertTensorEqual(neighbor_coordinate, xyz2)
def testPBCEdgesSeeEachOther(self):
species = torch.tensor([[0, 0]])
neighborlist = torchani.ase.NeighborList(cell=[10, 10, 10], pbc=True)
for i, j in itertools.combinations(range(3), 2):
xyz1 = torch.tensor([5.0, 5.0, 5.0])
xyz1[i] = 0.1
xyz1[j] = 0.1
for new_i, new_j in [[0.1, 9.9], [9.9, 0.1], [9.9, 9.9]]:
xyz2 = xyz1.clone()
xyz2[i] = new_i
xyz2[j] = new_i
coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
s, _, D = neighborlist(species, coordinates, 1)
self.assertListEqual(list(s.shape), [1, 2, 1])
neighbor_coordinate = D[0][0].squeeze() + xyz1
if xyz2[i] > 5:
xyz2[i] = -0.1
if xyz2[j] > 5:
xyz2[j] = -0.1
self.assertTensorEqual(neighbor_coordinate, xyz2)
def testNonRectangularPBCConnersSeeEachOther(self):
species = torch.tensor([[0, 0]])
neighborlist = torchani.ase.NeighborList(
cell=[10, 10, 10 * math.sqrt(2), 90, 45, 90], pbc=True)
xyz1 = torch.tensor([0.1, 0.1, 0.05])
xyz2 = torch.tensor([10.0, 0.1, 0.1])
mirror = torch.tensor([0.0, 0.1, 0.1])
coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
s, _, D = neighborlist(species, coordinates, 1)
self.assertListEqual(list(s.shape), [1, 2, 1])
neighbor_coordinate = D[0][0].squeeze() + xyz1
for i in range(3):
if mirror[i] > 5:
mirror[i] -= 10
self.assertTensorEqual(neighbor_coordinate, mirror)
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -21,7 +21,7 @@ class TestData(unittest.TestCase): ...@@ -21,7 +21,7 @@ class TestData(unittest.TestCase):
batch_size) batch_size)
def _assertTensorEqual(self, t1, t2): def _assertTensorEqual(self, t1, t2):
self.assertEqual((t1 - t2).abs().max().item(), 0) self.assertLess((t1 - t2).abs().max().item(), 1e-6)
def testSplitBatch(self): def testSplitBatch(self):
species1 = torch.randint(4, (5, 4), dtype=torch.long) species1 = torch.randint(4, (5, 4), dtype=torch.long)
......
...@@ -84,7 +84,6 @@ class TestEnergiesASEComputer(TestEnergies): ...@@ -84,7 +84,6 @@ class TestEnergiesASEComputer(TestEnergies):
def setUp(self): def setUp(self):
super(TestEnergiesASEComputer, self).setUp() super(TestEnergiesASEComputer, self).setUp()
self.aev_computer.neighborlist = torchani.ase.NeighborList()
def transform(self, x): def transform(self, x):
"""To reduce the size of test cases for faster test speed""" """To reduce the size of test cases for faster test speed"""
......
...@@ -90,7 +90,6 @@ class TestForceASEComputer(TestForce): ...@@ -90,7 +90,6 @@ class TestForceASEComputer(TestForce):
def setUp(self): def setUp(self):
super(TestForceASEComputer, self).setUp() super(TestForceASEComputer, self).setUp()
self.aev_computer.neighborlist = torchani.ase.NeighborList()
def transform(self, x): def transform(self, x):
"""To reduce the size of test cases for faster test speed""" """To reduce the size of test cases for faster test speed"""
......
...@@ -22,7 +22,8 @@ class TestIgnite(unittest.TestCase): ...@@ -22,7 +22,8 @@ class TestIgnite(unittest.TestCase):
shift_energy = builtins.energy_shifter shift_energy = builtins.energy_shifter
ds = torchani.data.BatchedANIDataset( ds = torchani.data.BatchedANIDataset(
path, builtins.consts.species_to_tensor, batchsize, path, builtins.consts.species_to_tensor, batchsize,
transform=[shift_energy.subtract_from_dataset]) transform=[shift_energy.subtract_from_dataset],
device=aev_computer.EtaR.device)
ds = torch.utils.data.Subset(ds, [0]) ds = torch.utils.data.Subset(ds, [0])
class Flatten(torch.nn.Module): class Flatten(torch.nn.Module):
......
...@@ -2,13 +2,12 @@ from __future__ import division ...@@ -2,13 +2,12 @@ from __future__ import division
import torch import torch
from . import _six # noqa:F401 from . import _six # noqa:F401
import math import math
from . import utils
from torch import Tensor from torch import Tensor
from typing import Tuple from typing import Tuple
@torch.jit.script # @torch.jit.script
def _cutoff_cosine(distances, cutoff): def cutoff_cosine(distances, cutoff):
# type: (Tensor, float) -> Tensor # type: (Tensor, float) -> Tensor
return torch.where( return torch.where(
distances <= cutoff, distances <= cutoff,
...@@ -17,8 +16,8 @@ def _cutoff_cosine(distances, cutoff): ...@@ -17,8 +16,8 @@ def _cutoff_cosine(distances, cutoff):
) )
@torch.jit.script # @torch.jit.script
def _radial_subaev_terms(Rcr, EtaR, ShfR, distances): def radial_terms(Rcr, EtaR, ShfR, distances):
# type: (float, Tensor, Tensor, Tensor) -> Tensor # type: (float, Tensor, Tensor, Tensor) -> Tensor
"""Compute the radial subAEV terms of the center atom given neighbors """Compute the radial subAEV terms of the center atom given neighbors
...@@ -33,7 +32,7 @@ def _radial_subaev_terms(Rcr, EtaR, ShfR, distances): ...@@ -33,7 +32,7 @@ def _radial_subaev_terms(Rcr, EtaR, ShfR, distances):
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
""" """
distances = distances.unsqueeze(-1).unsqueeze(-1) distances = distances.unsqueeze(-1).unsqueeze(-1)
fc = _cutoff_cosine(distances, Rcr) fc = cutoff_cosine(distances, Rcr)
# Note that in the equation in the paper there is no 0.25 # Note that in the equation in the paper there is no 0.25
# coefficient, but in NeuroChem there is such a coefficient. # coefficient, but in NeuroChem there is such a coefficient.
# We choose to be consistent with NeuroChem instead of the paper here. # We choose to be consistent with NeuroChem instead of the paper here.
...@@ -45,8 +44,8 @@ def _radial_subaev_terms(Rcr, EtaR, ShfR, distances): ...@@ -45,8 +44,8 @@ def _radial_subaev_terms(Rcr, EtaR, ShfR, distances):
return ret.flatten(start_dim=-2) return ret.flatten(start_dim=-2)
@torch.jit.script # @torch.jit.script
def _angular_subaev_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2): def angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2):
# type: (float, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tensor # type: (float, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tensor
"""Compute the angular subAEV terms of the center atom given neighbor pairs. """Compute the angular subAEV terms of the center atom given neighbor pairs.
...@@ -60,22 +59,18 @@ def _angular_subaev_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2): ...@@ -60,22 +59,18 @@ def _angular_subaev_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2):
.. _ANI paper: .. _ANI paper:
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
""" """
vectors1 = vectors1.unsqueeze( vectors1 = vectors1.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1) vectors2 = vectors2.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
vectors2 = vectors2.unsqueeze(
-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
distances1 = vectors1.norm(2, dim=-5) distances1 = vectors1.norm(2, dim=-5)
distances2 = vectors2.norm(2, dim=-5) distances2 = vectors2.norm(2, dim=-5)
# 0.95 is multiplied to the cos values to prevent acos from # 0.95 is multiplied to the cos values to prevent acos from
# returning NaN. # returning NaN.
cos_angles = 0.95 * \ cos_angles = 0.95 * torch.nn.functional.cosine_similarity(vectors1, vectors2, dim=-5)
torch.nn.functional.cosine_similarity(
vectors1, vectors2, dim=-5)
angles = torch.acos(cos_angles) angles = torch.acos(cos_angles)
fcj1 = _cutoff_cosine(distances1, Rca) fcj1 = cutoff_cosine(distances1, Rca)
fcj2 = _cutoff_cosine(distances2, Rca) fcj2 = cutoff_cosine(distances2, Rca)
factor1 = ((1 + torch.cos(angles - ShfZ)) / 2) ** Zeta factor1 = ((1 + torch.cos(angles - ShfZ)) / 2) ** Zeta
factor2 = torch.exp(-EtaA * ((distances1 + distances2) / 2 - ShfA) ** 2) factor2 = torch.exp(-EtaA * ((distances1 + distances2) / 2 - ShfA) ** 2)
ret = 2 * factor1 * factor2 * fcj1 * fcj2 ret = 2 * factor1 * factor2 * fcj1 * fcj2
...@@ -86,168 +81,231 @@ def _angular_subaev_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2): ...@@ -86,168 +81,231 @@ def _angular_subaev_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2):
return ret.flatten(start_dim=-4) return ret.flatten(start_dim=-4)
@torch.jit.script
def _combinations(tensor, dim=0):
# type: (Tensor, int) -> Tuple[Tensor, Tensor]
n = tensor.shape[dim]
if n == 0:
return tensor, tensor
r = torch.arange(n, dtype=torch.long, device=tensor.device)
index1, index2 = torch.combinations(r).unbind(-1)
return tensor.index_select(dim, index1), \
tensor.index_select(dim, index2)
@torch.jit.script
def _terms_and_indices(Rcr, EtaR, ShfR, Rca, ShfZ, EtaA, Zeta, ShfA,
distances, vec):
"""Returns radial and angular subAEV terms, these terms will be sorted
according to their distances to central atoms, and only these within
cutoff radius are valid. The returned indices stores the source of data
before sorting.
"""
# type: (float, Tensor, Tensor, float, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor] # noqa: E501
radial_terms = _radial_subaev_terms(Rcr, EtaR,
ShfR, distances)
vec = _combinations(vec, -2)
angular_terms = _angular_subaev_terms(Rca, ShfZ, EtaA,
Zeta, ShfA, *vec)
return radial_terms, angular_terms
# @torch.jit.script # @torch.jit.script
def default_neighborlist(species, coordinates, cutoff): def compute_shifts(cell, pbc, cutoff):
# type: (Tensor, Tensor, float) -> Tuple[Tensor, Tensor, Tensor] """Compute the shifts of unit cell along the given cell vectors to make it
"""Default neighborlist computer""" large enough to contain all pairs of neighbor atoms with PBC under
consideration
vec = coordinates.unsqueeze(2) - coordinates.unsqueeze(1)
# vec has hape (conformations, atoms, atoms, 3) storing Rij vectors
distances = vec.norm(2, -1)
# distances has shape (conformations, atoms, atoms) storing Rij distances
padding_mask = (species == -1).unsqueeze(1)
distances = distances.masked_fill(padding_mask, math.inf)
distances, indices = distances.sort(-1)
min_distances, _ = distances.flatten(end_dim=1).min(0)
in_cutoff = (min_distances <= cutoff).nonzero().flatten()[1:]
indices = indices.index_select(-1, in_cutoff)
# TODO: remove this workaround after gather support broadcasting
atoms = coordinates.shape[1]
species_ = species.unsqueeze(1).expand(-1, atoms, -1)
neighbor_species = species_.gather(-1, indices)
neighbor_distances = distances.index_select(-1, in_cutoff)
# TODO: remove this workaround when gather support broadcasting
# https://github.com/pytorch/pytorch/pull/9532
indices_ = indices.unsqueeze(-1).expand(-1, -1, -1, 3)
neighbor_coordinates = vec.gather(-2, indices_)
return neighbor_species, neighbor_distances, neighbor_coordinates
Arguments:
@torch.jit.script cell (:class:`torch.Tensor`): tensor of shape (3, 3) of the three
def _compute_mask_r(species_r, num_species): vectors defining unit cell:
# type: (Tensor, int) -> Tensor tensor([[x1, y1, z1], [x2, y2, z2], [x3, y3, z3]])
"""Get mask of radial terms for each supported species from indices""" cutoff (float): the cutoff inside which atoms are considered pairs
mask_r = (species_r.unsqueeze(-1) == torch.arange(num_species, dtype=torch.long, device=species_r.device)) pbc (:class:`torch.Tensor`): boolean vector of size 3 storing
return mask_r if pbc is enabled for that direction.
Returns:
@torch.jit.script :class:`torch.Tensor`: long tensor of shifts. the center cell and
def _compute_mask_a(species_a, present_species): symmetric cells are not included.
"""Get mask of angular terms for each supported species from indices""" """
species_a1, species_a2 = _combinations(species_a, -1) # type: (Tensor, Tensor, float) -> Tensor
mask_a1 = (species_a1.unsqueeze(-1) == present_species).unsqueeze(-1) reciprocal_cell = cell.inverse().t()
mask_a2 = (species_a2.unsqueeze(-1).unsqueeze(-1) == present_species) inv_distances = reciprocal_cell.norm(2, -1)
mask = mask_a1 & mask_a2 num_repeats = torch.ceil(cutoff * inv_distances).to(torch.long)
mask_rev = mask.permute(0, 1, 2, 4, 3) num_repeats = torch.where(pbc, num_repeats, torch.zeros_like(num_repeats))
mask_a = mask | mask_rev r1 = torch.arange(1, num_repeats[0] + 1, device=cell.device)
return mask_a r2 = torch.arange(1, num_repeats[1] + 1, device=cell.device)
r3 = torch.arange(1, num_repeats[2] + 1, device=cell.device)
o = torch.zeros(1, dtype=torch.long, device=cell.device)
return torch.cat([
torch.cartesian_prod(r1, r2, r3),
torch.cartesian_prod(r1, r2, o),
torch.cartesian_prod(r1, r2, -r3),
torch.cartesian_prod(r1, o, r3),
torch.cartesian_prod(r1, o, o),
torch.cartesian_prod(r1, o, -r3),
torch.cartesian_prod(r1, -r2, r3),
torch.cartesian_prod(r1, -r2, o),
torch.cartesian_prod(r1, -r2, -r3),
torch.cartesian_prod(o, r2, r3),
torch.cartesian_prod(o, r2, o),
torch.cartesian_prod(o, r2, -r3),
torch.cartesian_prod(o, o, r3),
])
@torch.jit.script # @torch.jit.script
def _assemble(radial_terms, angular_terms, present_species, def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff):
mask_r, mask_a, num_species, angular_sublength): """Compute pairs of atoms that are neighbors
"""Returns radial and angular AEV computed from terms according
to the given partition information.
Arguments: Arguments:
radial_terms (:class:`torch.Tensor`): shape (conformations, atoms, padding_mask (:class:`torch.Tensor`): boolean tensor of shape
neighbors, ``self.radial_sublength()``) (molecules, atoms) for padding mask. 1 == is padding.
angular_terms (:class:`torch.Tensor`): shape (conformations, atoms, coordinates (:class:`torch.Tensor`): tensor of shape
pairs, ``self.angular_sublength()``) (molecules, atoms, 3) for atom coordinates.
present_species (:class:`torch.Tensor`): Long tensor for species cell (:class:`torch.Tensor`): tensor of shape (3, 3) of the three vectors
of atoms present in the molecules. defining unit cell: tensor([[x1, y1, z1], [x2, y2, z2], [x3, y3, z3]])
mask_r (:class:`torch.Tensor`): shape (conformations, atoms, cutoff (float): the cutoff inside which atoms are considered pairs
neighbors, supported species) shifts (:class:`torch.Tensor`): tensor of shape (?, 3) storing shifts
mask_a (:class:`torch.Tensor`): shape (conformations, atoms, """
pairs, present species, present species) # type: (Tensor, Tensor, Tensor, Tensor, float) -> Tuple[Tensor, Tensor, Tensor, Tensor]
coordinates = coordinates.detach()
cell = cell.detach()
num_atoms = padding_mask.shape[1]
all_atoms = torch.arange(num_atoms, device=cell.device)
# Step 2: center cell
p1_center, p2_center = torch.combinations(all_atoms).unbind(-1)
shifts_center = shifts.new_zeros(p1_center.shape[0], 3)
# Step 3: cells with shifts
# shape convention (shift index, molecule index, atom index, 3)
num_shifts = shifts.shape[0]
all_shifts = torch.arange(num_shifts, device=cell.device)
shift_index, p1, p2 = torch.cartesian_prod(all_shifts, all_atoms, all_atoms).unbind(-1)
shifts_outide = shifts.index_select(0, shift_index)
# Step 4: combine results for all cells
shifts_all = torch.cat([shifts_center, shifts_outide])
p1_all = torch.cat([p1_center, p1])
p2_all = torch.cat([p2_center, p2])
shift_values = torch.mm(shifts_all.to(cell.dtype), cell)
# step 5, compute distances, and find all pairs within cutoff
distances = (coordinates.index_select(1, p1_all) - coordinates.index_select(1, p2_all) + shift_values).norm(2, -1)
padding_mask = (padding_mask.index_select(1, p1_all)) | (padding_mask.index_select(1, p2_all))
distances.masked_fill_(padding_mask, math.inf)
in_cutoff = (distances <= cutoff).nonzero()
molecule_index, pair_index = in_cutoff.unbind(1)
atom_index1 = p1_all[pair_index]
atom_index2 = p2_all[pair_index]
shifts = shifts_all.index_select(0, pair_index)
return molecule_index, atom_index1, atom_index2, shifts
# torch.jit.script
def triu_index(num_species):
species = torch.arange(num_species)
species1, species2 = torch.combinations(species, r=2, with_replacement=True).unbind(-1)
pair_index = torch.arange(species1.shape[0])
ret = torch.zeros(num_species, num_species, dtype=torch.long)
ret[species1, species2] = pair_index
ret[species2, species1] = pair_index
return ret
# torch.jit.script
def convert_pair_index(index):
"""Let's say we have a pair:
index: 0 1 2 3 4 5 6 7 8 9 ...
elem1: 0 0 1 0 1 2 0 1 2 3 ...
elem2: 1 2 2 3 3 3 4 4 4 4 ...
This function convert index back to elem1 and elem2
To implement this, divide it into groups, the first group contains 1
elements, the second contains 2 elements, ..., the nth group contains
n elements.
Let's say we want to compute the elem1 and elem2 for index i. We first find
the number of complete groups contained in index 0, 1, ..., i - 1
(all inclusive, not including i), then i will be in the next group. Let's
say there are N complete groups, then these N groups contains
N * (N + 1) / 2 elements, solving for the largest N that satisfies
N * (N + 1) / 2 <= i, will get the N we want.
""" """
# type: (Tensor, Tensor, Tensor, Tensor, Tensor, int, int) -> Tuple[Tensor, Tensor] # noqa: E501 n = (torch.sqrt(1.0 + 8.0 * index.to(torch.float)) - 1.0) / 2.0
n = torch.floor(n).to(torch.long)
conformations = radial_terms.shape[0] num_elems = n * (n + 1) / 2
atoms = radial_terms.shape[1] return index - num_elems, n + 1
# assemble radial subaev
present_radial_aevs = (radial_terms.unsqueeze(-2) * mask_r.unsqueeze(-1).to(radial_terms.dtype)).sum(-3) # torch.jit.script
# present_radial_aevs has shape def cumsum_from_zero(input_):
# (conformations, atoms, present species, radial_length) cumsum = torch.cumsum(input_, dim=0)
radial_aevs = present_radial_aevs.flatten(start_dim=2) cumsum = torch.cat([input_.new_tensor([0]), cumsum[:-1]])
return cumsum
# assemble angular subaev
rev_indices = torch.full((num_species,), -1, dtype=present_species.dtype,
device=present_species.device) # torch.jit.script
rev_indices[present_species] = torch.arange(present_species.numel(), def triple_by_molecule(molecule_index, atom_index1, atom_index2):
dtype=torch.long, """Input: indices for pairs of atoms that are close to each other.
device=radial_terms.device) each pair only appear once, i.e. only one of the pairs (1, 2) and
angular_aevs = [] (2, 1) exists.
zero_angular_subaev = torch.zeros(conformations, atoms, angular_sublength,
dtype=radial_terms.dtype, Output: indices for all central atoms and it pairs of neighbors. For
device=radial_terms.device) example, if input has pair (0, 1), (0, 2), (0, 3), (0, 4), (1, 2),
for s1 in range(num_species): (1, 3), (1, 4), (2, 3), (2, 4), (3, 4), then the output would have
# TODO: make PyTorch support range(start, end) and central atom 0, 1, 2, 3, 4 and for cental atom 0, its pairs of neighbors
# range(start, end, step) and remove the workaround are (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)
# below. The inner for loop should be: """
# for s2 in range(s1, num_species): # convert representation from pair to central-other
for s2 in range(num_species - s1): n = molecule_index.shape[0]
s2 += s1 mi = molecule_index.repeat(2)
i1 = int(rev_indices[s1]) ai1 = torch.cat([atom_index1, atom_index2])
i2 = int(rev_indices[s2])
if i1 >= 0 and i2 >= 0: # sort and compute unique key
mask = mask_a[:, :, :, i1, i2].unsqueeze(-1) \ mi_ai1 = torch.stack([mi, ai1], dim=1)
.to(radial_terms.dtype) m_ac, rev_indices, counts = torch._unique_dim2_temporary_will_remove_soon(mi_ai1, dim=0, sorted=True, return_inverse=True, return_counts=True)
subaev = (angular_terms * mask).sum(-2) uniqued_molecule_index, uniqued_central_atom_index = m_ac.unbind(1)
else:
subaev = zero_angular_subaev # do local combinations within unique key, assuming sorted
angular_aevs.append(subaev) pair_sizes = counts * (counts - 1) // 2
total_size = pair_sizes.sum()
return radial_aevs, torch.cat(angular_aevs, dim=2) molecule_index = torch.repeat_interleave(uniqued_molecule_index, pair_sizes)
central_atom_index = torch.repeat_interleave(uniqued_central_atom_index, pair_sizes)
cumsum = cumsum_from_zero(pair_sizes)
@torch.jit.script cumsum = torch.repeat_interleave(cumsum, pair_sizes)
def _compute_aev(num_species, angular_sublength, Rcr, EtaR, ShfR, Rca, ShfZ, sorted_local_pair_index = torch.arange(total_size, device=molecule_index.device) - cumsum
EtaA, Zeta, ShfA, species, species_, distances, vec): sorted_local_index1, sorted_local_index2 = convert_pair_index(sorted_local_pair_index)
# type: (int, int, float, Tensor, Tensor, float, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor] # noqa: E501 cumsum = cumsum_from_zero(counts)
cumsum = torch.repeat_interleave(cumsum, pair_sizes)
present_species = utils.present_species(species) sorted_local_index1 += cumsum
sorted_local_index2 += cumsum
radial_terms, angular_terms = _terms_and_indices(
Rcr, EtaR, ShfR, Rca, ShfZ, EtaA, Zeta, ShfA, distances, vec) # unsort result from last part
mask_r = _compute_mask_r(species_, num_species) argsort = rev_indices.argsort()
mask_a = _compute_mask_a(species_, present_species) local_index1 = argsort[sorted_local_index1]
local_index2 = argsort[sorted_local_index2]
radial, angular = _assemble(radial_terms, angular_terms,
present_species, mask_r, mask_a, # compute mapping between representation of central-other to pair
num_species, angular_sublength) sign1 = torch.where(local_index1 < n, torch.ones_like(local_index1), -torch.ones_like(local_index1))
fullaev = torch.cat([radial, angular], dim=2) sign2 = torch.where(local_index2 < n, torch.ones_like(local_index2), -torch.ones_like(local_index2))
return species, fullaev pair_index1 = torch.where(local_index1 < n, local_index1, local_index1 - n)
pair_index2 = torch.where(local_index2 < n, local_index2, local_index2 - n)
return molecule_index, central_atom_index, pair_index1, pair_index2, sign1, sign2
# torch.jit.script
def compute_aev(species, coordinates, cell, pbc_switch, triu_index, constants, sizes):
Rcr, EtaR, ShfR, Rca, ShfZ, EtaA, Zeta, ShfA = constants
num_species, radial_sublength, radial_length, angular_sublength, angular_length, aev_length = sizes
num_molecules = species.shape[0]
num_atoms = species.shape[1]
num_species_pairs = angular_length // angular_sublength
cutoff = max(Rcr, Rca)
shifts = compute_shifts(cell, pbc_switch, cutoff)
molecule_index, atom_index1, atom_index2, shifts = neighbor_pairs(species == -1, coordinates, cell, shifts, cutoff)
species1 = species[molecule_index, atom_index1]
species2 = species[molecule_index, atom_index2]
shift_values = torch.mm(shifts.to(cell.dtype), cell)
vec = coordinates[molecule_index, atom_index1, :] - coordinates[molecule_index, atom_index2, :] + shift_values
distances = vec.norm(2, -1)
# compute radial aev
radial_terms_ = radial_terms(Rcr, EtaR, ShfR, distances)
radial_aev = radial_terms_.new_zeros(num_molecules * num_atoms * num_species, radial_sublength)
index1 = (molecule_index * num_atoms + atom_index1) * num_species + species2
index2 = (molecule_index * num_atoms + atom_index2) * num_species + species1
radial_aev.scatter_add_(0, index1.unsqueeze(1).expand(-1, radial_sublength), radial_terms_)
radial_aev.scatter_add_(0, index2.unsqueeze(1).expand(-1, radial_sublength), radial_terms_)
radial_aev = radial_aev.reshape(num_molecules, num_atoms, radial_length)
# compute angular aev
molecule_index, central_atom_index, pair_index1, pair_index2, sign1, sign2 = triple_by_molecule(molecule_index, atom_index1, atom_index2)
vec1 = vec.index_select(0, pair_index1) * sign1.unsqueeze(1).to(vec.dtype)
vec2 = vec.index_select(0, pair_index2) * sign2.unsqueeze(1).to(vec.dtype)
species1_ = torch.where(sign1 == 1, species2[pair_index1], species1[pair_index1])
species2_ = torch.where(sign2 == 1, species2[pair_index2], species1[pair_index2])
angular_terms_ = angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vec1, vec2)
angular_aev = angular_terms_.new_zeros(num_molecules * num_atoms * num_species_pairs, angular_sublength)
index = (molecule_index * num_atoms + central_atom_index) * num_species_pairs + triu_index[species1_, species2_]
angular_aev.scatter_add_(0, index.unsqueeze(1).expand(-1, angular_sublength), angular_terms_)
angular_aev = angular_aev.reshape(num_molecules, num_atoms, angular_length)
return torch.cat([radial_aev, angular_aev], dim=-1)
class AEVComputer(torch.nn.Module): class AEVComputer(torch.nn.Module):
...@@ -271,20 +329,6 @@ class AEVComputer(torch.nn.Module): ...@@ -271,20 +329,6 @@ class AEVComputer(torch.nn.Module):
ShfZ (:class:`torch.Tensor`): The 1D tensor of :math:`\theta_s` in ShfZ (:class:`torch.Tensor`): The 1D tensor of :math:`\theta_s` in
equation (4) in the `ANI paper`_. equation (4) in the `ANI paper`_.
num_species (int): Number of supported atom types. num_species (int): Number of supported atom types.
neighborlist_computer (:class:`collections.abc.Callable`): initial
value of :attr:`neighborlist`
Attributes:
neighborlist (:class:`collections.abc.Callable`): The callable
(species:Tensor, coordinates:Tensor, cutoff:float)
-> Tuple[Tensor, Tensor, Tensor] that returns the species,
distances and relative coordinates of neighbor atoms. The input
species and coordinates tensor have the same shape convention as
the input of :class:`AEVComputer`. The returned neighbor
species and coordinates tensor must have shape ``(C, A, N)`` and
``(C, A, N, 3)`` correspoindingly, where ``C`` is the number of
conformations in a chunk, ``A`` is the number of atoms, and ``N``
is the maximum number of neighbors that an atom could have.
.. _ANI paper: .. _ANI paper:
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
...@@ -293,8 +337,7 @@ class AEVComputer(torch.nn.Module): ...@@ -293,8 +337,7 @@ class AEVComputer(torch.nn.Module):
'radial_length', 'angular_sublength', 'angular_length', 'radial_length', 'angular_sublength', 'angular_length',
'aev_length'] 'aev_length']
def __init__(self, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, def __init__(self, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species):
num_species, neighborlist_computer=default_neighborlist):
super(AEVComputer, self).__init__() super(AEVComputer, self).__init__()
self.Rcr = Rcr self.Rcr = Rcr
self.Rca = Rca self.Rca = Rca
...@@ -309,43 +352,60 @@ class AEVComputer(torch.nn.Module): ...@@ -309,43 +352,60 @@ class AEVComputer(torch.nn.Module):
self.register_buffer('ShfZ', ShfZ.view(1, 1, 1, -1)) self.register_buffer('ShfZ', ShfZ.view(1, 1, 1, -1))
self.num_species = num_species self.num_species = num_species
self.neighborlist = neighborlist_computer
# The length of radial subaev of a single species # The length of radial subaev of a single species
self.radial_sublength = self.EtaR.numel() * self.ShfR.numel() self.radial_sublength = self.EtaR.numel() * self.ShfR.numel()
# The length of full radial aev # The length of full radial aev
self.radial_length = self.num_species * self.radial_sublength self.radial_length = self.num_species * self.radial_sublength
# The length of angular subaev of a single species # The length of angular subaev of a single species
self.angular_sublength = self.EtaA.numel() * self.Zeta.numel() * \ self.angular_sublength = self.EtaA.numel() * self.Zeta.numel() * self.ShfA.numel() * self.ShfZ.numel()
self.ShfA.numel() * self.ShfZ.numel()
# The length of full angular aev # The length of full angular aev
self.angular_length = (self.num_species * (self.num_species + 1)) \ self.angular_length = (self.num_species * (self.num_species + 1)) // 2 * self.angular_sublength
// 2 * self.angular_sublength
# The length of full aev # The length of full aev
self.aev_length = self.radial_length + self.angular_length self.aev_length = self.radial_length + self.angular_length
self.sizes = self.num_species, self.radial_sublength, self.radial_length, self.angular_sublength, self.angular_length, self.aev_length
self.register_buffer('triu_index', triu_index(num_species))
def constants(self):
return self.Rcr, self.EtaR, self.ShfR, self.Rca, self.ShfZ, self.EtaA, self.Zeta, self.ShfA
# @torch.jit.script_method # @torch.jit.script_method
def forward(self, species_coordinates): def forward(self, input):
"""Compute AEVs """Compute AEVs
Arguments: Arguments:
species_coordinates (tuple): Two tensors: species and coordinates. input (tuple): Can be one of the following two cases:
If you don't care about periodic boundary conditions at all,
then input can be a tuple of two tensors: species and coordinates.
species must have shape ``(C, A)`` and coordinates must have species must have shape ``(C, A)`` and coordinates must have
shape ``(C, A, 3)``, where ``C`` is the number of conformations shape ``(C, A, 3)``, where ``C`` is the number of molecules
in a chunk, and ``A`` is the number of atoms. in a chunk, and ``A`` is the number of atoms.
If you want to apply periodic boundary conditions, then the input
would be a tuple of four tensors: species, coordinates, cell, pbc
where species and coordinates are the same as described above, cell
is a tensor of shape (3, 3) of the three vectors defining unit cell:
.. code-block:: python
tensor([[x1, y1, z1],
[x2, y2, z2],
[x3, y3, z3]])
and pbc is boolean vector of size 3 storing if pbc is enabled
for that direction.
Returns: Returns:
tuple: Species and AEVs. species are the species from the input tuple: Species and AEVs. species are the species from the input
unchanged, and AEVs is a tensor of shape unchanged, and AEVs is a tensor of shape
``(C, A, self.aev_length())`` ``(C, A, self.aev_length())``
""" """
# type: (Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor] if len(input) == 2:
species, coordinates = input
species, coordinates = species_coordinates cell = torch.eye(3, dtype=self.EtaR.dtype, device=self.EtaR.device)
max_cutoff = max(self.Rcr, self.Rca) pbc = torch.zeros(3, dtype=torch.uint8, device=self.EtaR.device)
species_, distances, vec = self.neighborlist(species, coordinates, else:
max_cutoff) assert len(input) == 4
return _compute_aev( species, coordinates, cell, pbc = input
self.num_species, self.angular_sublength, self.Rcr, self.EtaR, return species, compute_aev(species, coordinates, cell, pbc, self.triu_index, self.constants(), self.sizes)
self.ShfR, self.Rca, self.ShfZ, self.EtaA, self.Zeta, self.ShfA,
species, species_, distances, vec)
...@@ -6,95 +6,13 @@ ...@@ -6,95 +6,13 @@
""" """
from __future__ import absolute_import from __future__ import absolute_import
import math
import torch import torch
import ase.neighborlist import ase.neighborlist
from . import utils from . import utils
import ase.calculators.calculator import ase.calculators.calculator
import ase.units import ase.units
import copy import copy
import numpy
class NeighborList(torch.nn.Module):
"""ASE neighborlist computer
Arguments:
cell: same as in :class:`ase.Atoms`
pbc: same as in :class:`ase.Atoms`
"""
def __init__(self, cell=None, pbc=None):
# wrap `cell` and `pbc` with `ase.Atoms`
super(NeighborList, self).__init__()
a = ase.Atoms('He', [[0, 0, 0]], cell=cell, pbc=pbc)
self.pbc = a.get_pbc()
self.cell = a.get_cell(complete=True)
def forward(self, species, coordinates, cutoff):
conformations = species.shape[0]
max_atoms = species.shape[1]
neighbor_species = []
neighbor_distances = []
neighbor_vecs = []
for i in range(conformations):
s = species[i].unsqueeze(0)
c = coordinates[i].unsqueeze(0)
s, c = utils.strip_redundant_padding(s, c)
s = s.squeeze()
c = c.squeeze()
atoms = s.shape[0]
atoms_object = ase.Atoms(
['He'] * atoms, # chemical symbols are not important here
positions=c.detach().numpy(),
pbc=self.pbc,
cell=self.cell)
idx1, idx2, shift = ase.neighborlist.neighbor_list(
'ijS', atoms_object, cutoff)
# NB: The absolute distance and distance vectors computed by
# `neighbor_list`can not be used since it does not preserve
# gradient information
idx1 = torch.tensor(idx1, device=coordinates.device,
dtype=torch.long)
idx2 = torch.tensor(idx2, device=coordinates.device,
dtype=torch.long)
D = c.index_select(0, idx2) - c.index_select(0, idx1)
shift = torch.tensor(shift, device=coordinates.device,
dtype=coordinates.dtype)
cell = torch.tensor(self.cell, device=coordinates.device,
dtype=coordinates.dtype)
D += torch.mm(shift, cell)
d = D.norm(2, -1)
neighbor_species1 = []
neighbor_distances1 = []
neighbor_vecs1 = []
for i in range(atoms):
this_atom_indices = (idx1 == i).nonzero().flatten()
neighbor_indices = idx2[this_atom_indices]
neighbor_species1.append(s[neighbor_indices])
neighbor_distances1.append(d[this_atom_indices])
neighbor_vecs1.append(D.index_select(0, this_atom_indices))
for i in range(max_atoms - atoms):
neighbor_species1.append(torch.full((1,), -1))
neighbor_distances1.append(torch.full((1,), math.inf))
neighbor_vecs1.append(torch.full((1, 3), 0))
neighbor_species1 = torch.nn.utils.rnn.pad_sequence(
neighbor_species1, padding_value=-1)
neighbor_distances1 = torch.nn.utils.rnn.pad_sequence(
neighbor_distances1, padding_value=math.inf)
neighbor_vecs1 = torch.nn.utils.rnn.pad_sequence(
neighbor_vecs1, padding_value=0)
neighbor_species.append(neighbor_species1)
neighbor_distances.append(neighbor_distances1)
neighbor_vecs.append(neighbor_vecs1)
neighbor_species = torch.nn.utils.rnn.pad_sequence(
neighbor_species, batch_first=True, padding_value=-1)
neighbor_distances = torch.nn.utils.rnn.pad_sequence(
neighbor_distances, batch_first=True, padding_value=math.inf)
neighbor_vecs = torch.nn.utils.rnn.pad_sequence(
neighbor_vecs, batch_first=True, padding_value=0)
return neighbor_species.permute(0, 2, 1), \
neighbor_distances.permute(0, 2, 1), \
neighbor_vecs.permute(0, 2, 1, 3)
class Calculator(ase.calculators.calculator.Calculator): class Calculator(ase.calculators.calculator.Calculator):
...@@ -109,16 +27,12 @@ class Calculator(ase.calculators.calculator.Calculator): ...@@ -109,16 +27,12 @@ class Calculator(ase.calculators.calculator.Calculator):
energy_shifter (:class:`torchani.EnergyShifter`): Energy shifter. energy_shifter (:class:`torchani.EnergyShifter`): Energy shifter.
dtype (:class:`torchani.EnergyShifter`): data type to use, dtype (:class:`torchani.EnergyShifter`): data type to use,
by dafault ``torch.float64``. by dafault ``torch.float64``.
_default_neighborlist (bool): Whether to ignore pbc setting and always
use default neighborlist computer. This is for internal use only.
""" """
implemented_properties = ['energy', 'forces'] implemented_properties = ['energy', 'forces']
def __init__(self, species, aev_computer, model, energy_shifter, def __init__(self, species, aev_computer, model, energy_shifter, dtype=torch.float64):
dtype=torch.float64, _default_neighborlist=False):
super(Calculator, self).__init__() super(Calculator, self).__init__()
self._default_neighborlist = _default_neighborlist
self.species_to_tensor = utils.ChemicalSymbolsToInts(species) self.species_to_tensor = utils.ChemicalSymbolsToInts(species)
# aev_computer.neighborlist will be changed later, so we need a copy to # aev_computer.neighborlist will be changed later, so we need a copy to
# make sure we do not change the original object # make sure we do not change the original object
...@@ -138,16 +52,18 @@ class Calculator(ase.calculators.calculator.Calculator): ...@@ -138,16 +52,18 @@ class Calculator(ase.calculators.calculator.Calculator):
def calculate(self, atoms=None, properties=['energy'], def calculate(self, atoms=None, properties=['energy'],
system_changes=ase.calculators.calculator.all_changes): system_changes=ase.calculators.calculator.all_changes):
super(Calculator, self).calculate(atoms, properties, system_changes) super(Calculator, self).calculate(atoms, properties, system_changes)
if not self._default_neighborlist: cell = torch.tensor(self.atoms.get_cell(complete=True),
self.aev_computer.neighborlist.pbc = self.atoms.get_pbc() requires_grad=True, dtype=self.dtype,
self.aev_computer.neighborlist.cell = \ device=self.device)
self.atoms.get_cell(complete=True) pbc = torch.tensor(self.atoms.get_pbc().astype(numpy.uint8), dtype=torch.uint8,
device=self.device)
# print(cell, pbc)
species = self.species_to_tensor(self.atoms.get_chemical_symbols()) species = self.species_to_tensor(self.atoms.get_chemical_symbols())
species = species.unsqueeze(0) species = species.unsqueeze(0)
coordinates = torch.tensor(self.atoms.get_positions()) coordinates = torch.tensor(self.atoms.get_positions())
coordinates = coordinates.unsqueeze(0).to(self.device).to(self.dtype) \ coordinates = coordinates.unsqueeze(0).to(self.device).to(self.dtype) \
.requires_grad_('forces' in properties) .requires_grad_('forces' in properties)
_, energy = self.whole((species, coordinates)) _, energy = self.whole((species, coordinates, cell, pbc))
energy *= ase.units.Hartree energy *= ase.units.Hartree
self.results['energy'] = energy.item() self.results['energy'] = energy.item()
if 'forces' in properties: if 'forces' in properties:
......
...@@ -65,7 +65,7 @@ def pad_coordinates(species_coordinates): ...@@ -65,7 +65,7 @@ def pad_coordinates(species_coordinates):
return torch.cat(species), torch.cat(coordinates) return torch.cat(species), torch.cat(coordinates)
@torch.jit.script # @torch.jit.script
def present_species(species): def present_species(species):
"""Given a vector of species of atoms, compute the unique species present. """Given a vector of species of atoms, compute the unique species present.
...@@ -75,7 +75,8 @@ def present_species(species): ...@@ -75,7 +75,8 @@ def present_species(species):
Returns: Returns:
:class:`torch.Tensor`: 1D vector storing present atom types sorted. :class:`torch.Tensor`: 1D vector storing present atom types sorted.
""" """
present_species, _ = species.flatten()._unique(sorted=True) # present_species, _ = species.flatten()._unique(sorted=True)
present_species = species.flatten().unique(sorted=True)
if present_species[0].item() == -1: if present_species[0].item() == -1:
present_species = present_species[1:] present_species = present_species[1:]
return present_species return present_species
...@@ -86,9 +87,9 @@ def strip_redundant_padding(species, coordinates): ...@@ -86,9 +87,9 @@ def strip_redundant_padding(species, coordinates):
Arguments: Arguments:
species (:class:`torch.Tensor`): Long tensor of shape species (:class:`torch.Tensor`): Long tensor of shape
``(conformations, atoms)``. ``(molecules, atoms)``.
coordinates (:class:`torch.Tensor`): Tensor of shape coordinates (:class:`torch.Tensor`): Tensor of shape
``(conformations, atoms, 3)``. ``(molecules, atoms, 3)``.
Returns: Returns:
(:class:`torch.Tensor`, :class:`torch.Tensor`): species and coordinates (:class:`torch.Tensor`, :class:`torch.Tensor`): species and coordinates
...@@ -100,6 +101,39 @@ def strip_redundant_padding(species, coordinates): ...@@ -100,6 +101,39 @@ def strip_redundant_padding(species, coordinates):
return species, coordinates return species, coordinates
def map2central(cell, coordinates, pbc):
"""Map atoms outside the unit cell into the cell using PBC.
Arguments:
cell (:class:`torch.Tensor`): tensor of shape (3, 3) of the three
vectors defining unit cell:
.. code-block:: python
tensor([[x1, y1, z1],
[x2, y2, z2],
[x3, y3, z3]])
coordinates (:class:`torch.Tensor`): Tensor of shape
``(molecules, atoms, 3)``.
pbc (:class:`torch.Tensor`): boolean vector of size 3 storing
if pbc is enabled for that direction.
Returns:
:class:`torch.Tensor`: coordinates of atoms mapped back to unit cell.
"""
# Step 1: convert coordinates from standard cartesian coordinate to unit
# cell coordinates
inv_cell = torch.inverse(cell)
coordinates_cell = torch.matmul(coordinates, inv_cell)
# Step 2: wrap cell coordinates into [0, 1)
coordinates_cell -= coordinates_cell.floor() * pbc.to(coordinates_cell.dtype)
# Step 3: convert from cell coordinates back to standard cartesian
# coordinate
return torch.matmul(coordinates_cell, cell)
class EnergyShifter(torch.nn.Module): class EnergyShifter(torch.nn.Module):
"""Helper class for adding and subtracting self atomic energies """Helper class for adding and subtracting self atomic energies
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment