Unverified Commit 35531421 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Completely rewrite AEVComputer (#197)

parent bc5f4312
......@@ -37,6 +37,7 @@ Utilities
.. autofunction:: torchani.utils.pad_coordinates
.. autofunction:: torchani.utils.present_species
.. autofunction:: torchani.utils.strip_redundant_padding
.. autofunction:: torchani.utils.map2central
.. autoclass:: torchani.utils.ChemicalSymbolsToInts
:members:
......@@ -61,8 +62,6 @@ ASE Interface
=============
.. automodule:: torchani.ase
.. autoclass:: torchani.ase.NeighborList
:members:
.. autoclass:: torchani.ase.Calculator
Ignite Helpers
......
......@@ -3,7 +3,9 @@ import torchani
import unittest
import os
import pickle
import random
import itertools
import ase
import math
path = os.path.dirname(os.path.realpath(__file__))
N = 97
......@@ -93,19 +95,154 @@ class TestAEV(unittest.TestCase):
self._assertAEVEqual(radial, angular, aev)
class TestAEVASENeighborList(TestAEV):
class TestPBCSeeEachOther(unittest.TestCase):
def setUp(self):
super(TestAEVASENeighborList, self).setUp()
self.aev_computer.neighborlist = torchani.ase.NeighborList()
self.builtin = torchani.neurochem.Builtins()
self.aev_computer = self.builtin.aev_computer.to(torch.double)
def testTranslationalInvariancePBC(self):
coordinates = torch.tensor(
[[[0, 0, 0],
[1, 0, 0],
[0, 1, 0],
[0, 0, 1],
[0, 1, 1]]],
dtype=torch.double, requires_grad=True)
cell = torch.eye(3, dtype=torch.double) * 2
species = torch.tensor([[1, 0, 0, 0, 0]], dtype=torch.long)
pbc = torch.ones(3, dtype=torch.uint8)
_, aev = self.aev_computer((species, coordinates, cell, pbc))
for _ in range(100):
translation = torch.randn(3, dtype=torch.double)
_, aev2 = self.aev_computer((species, coordinates + translation, cell, pbc))
self.assertTrue(torch.allclose(aev, aev2))
def testPBCConnersSeeEachOther(self):
species = torch.tensor([[0, 0]])
cell = torch.eye(3, dtype=torch.double) * 10
pbc = torch.ones(3, dtype=torch.uint8)
allshifts = torchani.aev.compute_shifts(cell, pbc, 1)
xyz1 = torch.tensor([0.1, 0.1, 0.1])
xyz2s = [
torch.tensor([9.9, 0.0, 0.0]),
torch.tensor([0.0, 9.9, 0.0]),
torch.tensor([0.0, 0.0, 9.9]),
torch.tensor([9.9, 9.9, 0.0]),
torch.tensor([0.0, 9.9, 9.9]),
torch.tensor([9.9, 0.0, 9.9]),
torch.tensor([9.9, 9.9, 9.9]),
]
for xyz2 in xyz2s:
coordinates = torch.stack([xyz1, xyz2]).to(torch.double).unsqueeze(0)
molecule_index, atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
self.assertEqual(molecule_index.tolist(), [0])
self.assertEqual(atom_index1.tolist(), [0])
self.assertEqual(atom_index2.tolist(), [1])
def testPBCSurfaceSeeEachOther(self):
cell = torch.eye(3, dtype=torch.double) * 10
pbc = torch.ones(3, dtype=torch.uint8)
allshifts = torchani.aev.compute_shifts(cell, pbc, 1)
species = torch.tensor([[0, 0]])
for i in range(3):
xyz1 = torch.tensor([5.0, 5.0, 5.0], dtype=torch.double)
xyz1[i] = 0.1
xyz2 = xyz1.clone()
xyz2[i] = 9.9
coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
molecule_index, atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
self.assertEqual(molecule_index.tolist(), [0])
self.assertEqual(atom_index1.tolist(), [0])
self.assertEqual(atom_index2.tolist(), [1])
def testPBCEdgesSeeEachOther(self):
cell = torch.eye(3, dtype=torch.double) * 10
pbc = torch.ones(3, dtype=torch.uint8)
allshifts = torchani.aev.compute_shifts(cell, pbc, 1)
species = torch.tensor([[0, 0]])
for i, j in itertools.combinations(range(3), 2):
xyz1 = torch.tensor([5.0, 5.0, 5.0], dtype=torch.double)
xyz1[i] = 0.1
xyz1[j] = 0.1
for new_i, new_j in [[0.1, 9.9], [9.9, 0.1], [9.9, 9.9]]:
xyz2 = xyz1.clone()
xyz2[i] = new_i
xyz2[j] = new_i
coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
molecule_index, atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
self.assertEqual(molecule_index.tolist(), [0])
self.assertEqual(atom_index1.tolist(), [0])
self.assertEqual(atom_index2.tolist(), [1])
def testNonRectangularPBCConnersSeeEachOther(self):
species = torch.tensor([[0, 0]])
cell = ase.geometry.cellpar_to_cell([10, 10, 10 * math.sqrt(2), 90, 45, 90])
cell = torch.tensor(ase.geometry.complete_cell(cell), dtype=torch.double)
pbc = torch.ones(3, dtype=torch.uint8)
allshifts = torchani.aev.compute_shifts(cell, pbc, 1)
xyz1 = torch.tensor([0.1, 0.1, 0.05], dtype=torch.double)
xyz2 = torch.tensor([10.0, 0.1, 0.1], dtype=torch.double)
coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
molecule_index, atom_index1, atom_index2, _ = torchani.aev.neighbor_pairs(species == -1, coordinates, cell, allshifts, 1)
self.assertEqual(molecule_index.tolist(), [0])
self.assertEqual(atom_index1.tolist(), [0])
self.assertEqual(atom_index2.tolist(), [1])
class TestAEVOnBoundary(unittest.TestCase):
def transform(self, x):
"""To reduce the size of test cases for faster test speed"""
return x[:2, ...]
def random_skip(self):
"""To reduce the size of test cases for faster test speed"""
return random.random() < 0.95
def setUp(self):
self.eps = 1e-9
cell = ase.geometry.cellpar_to_cell([100, 100, 100 * math.sqrt(2), 90, 45, 90])
self.cell = torch.tensor(ase.geometry.complete_cell(cell), dtype=torch.double)
self.inv_cell = torch.inverse(self.cell)
self.coordinates = torch.tensor([[[0.0, 0.0, 0.0],
[1.0, -0.1, -0.1],
[-0.1, 1.0, -0.1],
[-0.1, -0.1, 1.0],
[-1.0, -1.0, -1.0]]], dtype=torch.double)
self.species = torch.tensor([[1, 0, 0, 0]])
self.pbc = torch.ones(3, dtype=torch.uint8)
self.v1, self.v2, self.v3 = self.cell
self.center_coordinates = self.coordinates + 0.5 * (self.v1 + self.v2 + self.v3)
builtin = torchani.neurochem.Builtins()
self.aev_computer = builtin.aev_computer.to(torch.double)
_, self.aev = self.aev_computer((self.species, self.center_coordinates, self.cell, self.pbc))
def assertInCell(self, coordinates):
coordinates_cell = coordinates @ self.inv_cell
self.assertTrue(torch.allclose(coordinates, coordinates_cell @ self.cell))
in_cell = (coordinates_cell >= -self.eps) & (coordinates_cell <= 1 + self.eps)
self.assertTrue(in_cell.all())
def assertNotInCell(self, coordinates):
coordinates_cell = coordinates @ self.inv_cell
self.assertTrue(torch.allclose(coordinates, coordinates_cell @ self.cell))
in_cell = (coordinates_cell >= -self.eps) & (coordinates_cell <= 1 + self.eps)
self.assertFalse(in_cell.all())
def testCornerSurfaceAndEdge(self):
for i, j, k in itertools.product([0, 0.5, 1], repeat=3):
if i == 0.5 and j == 0.5 and k == 0.5:
continue
coordinates = self.coordinates + i * self.v1 + j * self.v2 + k * self.v3
self.assertNotInCell(coordinates)
coordinates = torchani.utils.map2central(self.cell, coordinates, self.pbc)
self.assertInCell(coordinates)
_, aev = self.aev_computer((self.species, coordinates, self.cell, self.pbc))
self.assertGreater(aev.abs().max().item(), 0)
self.assertTrue(torch.allclose(aev, self.aev))
if __name__ == '__main__':
......
from ase.lattice.cubic import Diamond
from ase.md.langevin import Langevin
from ase import units, Atoms
from ase import units
from ase.calculators.test import numeric_force
import torch
import torchani
import unittest
import numpy
import itertools
import math
import os
import pickle
path = os.path.dirname(os.path.realpath(__file__))
N = 97
......@@ -26,8 +22,8 @@ def get_numeric_force(atoms, eps):
class TestASE(unittest.TestCase):
def _testForce(self, pbc):
atoms = Diamond(symbol="C", pbc=pbc)
def testWithNumericalForceWithPBCEnabled(self):
atoms = Diamond(symbol="C", pbc=True)
builtin = torchani.neurochem.Builtins()
calculator = torchani.ase.Calculator(
builtin.species, builtin.aev_computer,
......@@ -42,194 +38,6 @@ class TestASE(unittest.TestCase):
if avgf > 0:
self.assertLess(df / avgf, 0.1)
def testForceWithPBCEnabled(self):
self._testForce(True)
def testForceWithPBCDisabled(self):
self._testForce(False)
def testANIDataset(self):
builtin = torchani.neurochem.Builtins()
calculator = torchani.ase.Calculator(
builtin.species, builtin.aev_computer,
builtin.models, builtin.energy_shifter)
default_neighborlist_calculator = torchani.ase.Calculator(
builtin.species, builtin.aev_computer,
builtin.models, builtin.energy_shifter, _default_neighborlist=True)
nnp = torch.nn.Sequential(
builtin.aev_computer,
builtin.models,
builtin.energy_shifter
)
for i in range(N):
datafile = os.path.join(path, 'test_data/ANI1_subset/{}'.format(i))
with open(datafile, 'rb') as f:
coordinates, species, _, _, _, _ = pickle.load(f)
coordinates = coordinates[0]
species = species[0]
species_str = [builtin.consts.species[i] for i in species]
atoms = Atoms(species_str, positions=coordinates)
atoms.set_calculator(calculator)
energy1 = atoms.get_potential_energy() / units.Hartree
forces1 = atoms.get_forces() / units.Hartree
atoms2 = Atoms(species_str, positions=coordinates)
atoms2.set_calculator(default_neighborlist_calculator)
energy2 = atoms2.get_potential_energy() / units.Hartree
forces2 = atoms2.get_forces() / units.Hartree
coordinates = torch.tensor(coordinates,
requires_grad=True).unsqueeze(0)
_, energy3 = nnp((torch.from_numpy(species).unsqueeze(0),
coordinates))
forces3 = -torch.autograd.grad(energy3.squeeze(),
coordinates)[0].numpy()
energy3 = energy3.item()
self.assertLess(abs(energy1 - energy2), tol)
self.assertLess(abs(energy1 - energy3), tol)
diff_f12 = torch.tensor(forces1 - forces2).abs().max().item()
self.assertLess(diff_f12, tol)
diff_f13 = torch.tensor(forces1 - forces3).abs().max().item()
self.assertLess(diff_f13, tol)
def testForceAgainstDefaultNeighborList(self):
atoms = Diamond(symbol="C", pbc=False)
builtin = torchani.neurochem.Builtins()
calculator = torchani.ase.Calculator(
builtin.species, builtin.aev_computer,
builtin.models, builtin.energy_shifter)
default_neighborlist_calculator = torchani.ase.Calculator(
builtin.species, builtin.aev_computer,
builtin.models, builtin.energy_shifter, _default_neighborlist=True)
atoms.set_calculator(calculator)
dyn = Langevin(atoms, 5 * units.fs, 50 * units.kB, 0.002)
def test_energy(a=atoms):
a = a.copy()
a.set_calculator(calculator)
e1 = a.get_potential_energy()
a.set_calculator(default_neighborlist_calculator)
e2 = a.get_potential_energy()
self.assertLess(abs(e1 - e2), tol)
dyn.attach(test_energy, interval=1)
dyn.run(500)
def testTranslationalInvariancePBC(self):
atoms = Atoms('CH4', [[0, 0, 0],
[1, 0, 0],
[0, 1, 0],
[0, 0, 1],
[0, 1, 1]],
cell=[2, 2, 2], pbc=True)
builtin = torchani.neurochem.Builtins()
calculator = torchani.ase.Calculator(
builtin.species, builtin.aev_computer,
builtin.models, builtin.energy_shifter)
atoms.set_calculator(calculator)
e = atoms.get_potential_energy()
for _ in range(100):
positions = atoms.get_positions()
translation = (numpy.random.rand(3) - 0.5) * 2
atoms.set_positions(positions + translation)
self.assertEqual(e, atoms.get_potential_energy())
def assertTensorEqual(self, a, b):
self.assertLess((a - b).abs().max().item(), 1e-6)
def testPBCConnersSeeEachOther(self):
species = torch.tensor([[0, 0]])
neighborlist = torchani.ase.NeighborList(cell=[10, 10, 10], pbc=True)
xyz1 = torch.tensor([0.1, 0.1, 0.1])
xyz2s = [
torch.tensor([9.9, 0.0, 0.0]),
torch.tensor([0.0, 9.9, 0.0]),
torch.tensor([0.0, 0.0, 9.9]),
torch.tensor([9.9, 9.9, 0.0]),
torch.tensor([0.0, 9.9, 9.9]),
torch.tensor([9.9, 0.0, 9.9]),
torch.tensor([9.9, 9.9, 9.9]),
]
for xyz2 in xyz2s:
coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
s, _, D = neighborlist(species, coordinates, 1)
self.assertListEqual(list(s.shape), [1, 2, 1])
neighbor_coordinate = D[0][0].squeeze() + xyz1
mirror = xyz2
for i in range(3):
if mirror[i] > 5:
mirror[i] -= 10
self.assertTensorEqual(neighbor_coordinate, mirror)
def testPBCSurfaceSeeEachOther(self):
species = torch.tensor([[0, 0]])
neighborlist = torchani.ase.NeighborList(cell=[10, 10, 10], pbc=True)
for i in range(3):
xyz1 = torch.tensor([5.0, 5.0, 5.0])
xyz1[i] = 0.1
xyz2 = xyz1.clone()
xyz2[i] = 9.9
coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
s, _, D = neighborlist(species, coordinates, 1)
self.assertListEqual(list(s.shape), [1, 2, 1])
neighbor_coordinate = D[0][0].squeeze() + xyz1
xyz2[i] = -0.1
self.assertTensorEqual(neighbor_coordinate, xyz2)
def testPBCEdgesSeeEachOther(self):
species = torch.tensor([[0, 0]])
neighborlist = torchani.ase.NeighborList(cell=[10, 10, 10], pbc=True)
for i, j in itertools.combinations(range(3), 2):
xyz1 = torch.tensor([5.0, 5.0, 5.0])
xyz1[i] = 0.1
xyz1[j] = 0.1
for new_i, new_j in [[0.1, 9.9], [9.9, 0.1], [9.9, 9.9]]:
xyz2 = xyz1.clone()
xyz2[i] = new_i
xyz2[j] = new_i
coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
s, _, D = neighborlist(species, coordinates, 1)
self.assertListEqual(list(s.shape), [1, 2, 1])
neighbor_coordinate = D[0][0].squeeze() + xyz1
if xyz2[i] > 5:
xyz2[i] = -0.1
if xyz2[j] > 5:
xyz2[j] = -0.1
self.assertTensorEqual(neighbor_coordinate, xyz2)
def testNonRectangularPBCConnersSeeEachOther(self):
species = torch.tensor([[0, 0]])
neighborlist = torchani.ase.NeighborList(
cell=[10, 10, 10 * math.sqrt(2), 90, 45, 90], pbc=True)
xyz1 = torch.tensor([0.1, 0.1, 0.05])
xyz2 = torch.tensor([10.0, 0.1, 0.1])
mirror = torch.tensor([0.0, 0.1, 0.1])
coordinates = torch.stack([xyz1, xyz2]).unsqueeze(0)
s, _, D = neighborlist(species, coordinates, 1)
self.assertListEqual(list(s.shape), [1, 2, 1])
neighbor_coordinate = D[0][0].squeeze() + xyz1
for i in range(3):
if mirror[i] > 5:
mirror[i] -= 10
self.assertTensorEqual(neighbor_coordinate, mirror)
if __name__ == '__main__':
unittest.main()
......@@ -21,7 +21,7 @@ class TestData(unittest.TestCase):
batch_size)
def _assertTensorEqual(self, t1, t2):
self.assertEqual((t1 - t2).abs().max().item(), 0)
self.assertLess((t1 - t2).abs().max().item(), 1e-6)
def testSplitBatch(self):
species1 = torch.randint(4, (5, 4), dtype=torch.long)
......
......@@ -84,7 +84,6 @@ class TestEnergiesASEComputer(TestEnergies):
def setUp(self):
super(TestEnergiesASEComputer, self).setUp()
self.aev_computer.neighborlist = torchani.ase.NeighborList()
def transform(self, x):
"""To reduce the size of test cases for faster test speed"""
......
......@@ -90,7 +90,6 @@ class TestForceASEComputer(TestForce):
def setUp(self):
super(TestForceASEComputer, self).setUp()
self.aev_computer.neighborlist = torchani.ase.NeighborList()
def transform(self, x):
"""To reduce the size of test cases for faster test speed"""
......
......@@ -22,7 +22,8 @@ class TestIgnite(unittest.TestCase):
shift_energy = builtins.energy_shifter
ds = torchani.data.BatchedANIDataset(
path, builtins.consts.species_to_tensor, batchsize,
transform=[shift_energy.subtract_from_dataset])
transform=[shift_energy.subtract_from_dataset],
device=aev_computer.EtaR.device)
ds = torch.utils.data.Subset(ds, [0])
class Flatten(torch.nn.Module):
......
......@@ -2,13 +2,12 @@ from __future__ import division
import torch
from . import _six # noqa:F401
import math
from . import utils
from torch import Tensor
from typing import Tuple
@torch.jit.script
def _cutoff_cosine(distances, cutoff):
# @torch.jit.script
def cutoff_cosine(distances, cutoff):
# type: (Tensor, float) -> Tensor
return torch.where(
distances <= cutoff,
......@@ -17,8 +16,8 @@ def _cutoff_cosine(distances, cutoff):
)
@torch.jit.script
def _radial_subaev_terms(Rcr, EtaR, ShfR, distances):
# @torch.jit.script
def radial_terms(Rcr, EtaR, ShfR, distances):
# type: (float, Tensor, Tensor, Tensor) -> Tensor
"""Compute the radial subAEV terms of the center atom given neighbors
......@@ -33,7 +32,7 @@ def _radial_subaev_terms(Rcr, EtaR, ShfR, distances):
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
"""
distances = distances.unsqueeze(-1).unsqueeze(-1)
fc = _cutoff_cosine(distances, Rcr)
fc = cutoff_cosine(distances, Rcr)
# Note that in the equation in the paper there is no 0.25
# coefficient, but in NeuroChem there is such a coefficient.
# We choose to be consistent with NeuroChem instead of the paper here.
......@@ -45,8 +44,8 @@ def _radial_subaev_terms(Rcr, EtaR, ShfR, distances):
return ret.flatten(start_dim=-2)
@torch.jit.script
def _angular_subaev_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2):
# @torch.jit.script
def angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2):
# type: (float, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tensor
"""Compute the angular subAEV terms of the center atom given neighbor pairs.
......@@ -60,22 +59,18 @@ def _angular_subaev_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2):
.. _ANI paper:
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
"""
vectors1 = vectors1.unsqueeze(
-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
vectors2 = vectors2.unsqueeze(
-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
vectors1 = vectors1.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
vectors2 = vectors2.unsqueeze(-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
distances1 = vectors1.norm(2, dim=-5)
distances2 = vectors2.norm(2, dim=-5)
# 0.95 is multiplied to the cos values to prevent acos from
# returning NaN.
cos_angles = 0.95 * \
torch.nn.functional.cosine_similarity(
vectors1, vectors2, dim=-5)
cos_angles = 0.95 * torch.nn.functional.cosine_similarity(vectors1, vectors2, dim=-5)
angles = torch.acos(cos_angles)
fcj1 = _cutoff_cosine(distances1, Rca)
fcj2 = _cutoff_cosine(distances2, Rca)
fcj1 = cutoff_cosine(distances1, Rca)
fcj2 = cutoff_cosine(distances2, Rca)
factor1 = ((1 + torch.cos(angles - ShfZ)) / 2) ** Zeta
factor2 = torch.exp(-EtaA * ((distances1 + distances2) / 2 - ShfA) ** 2)
ret = 2 * factor1 * factor2 * fcj1 * fcj2
......@@ -86,168 +81,231 @@ def _angular_subaev_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vectors1, vectors2):
return ret.flatten(start_dim=-4)
@torch.jit.script
def _combinations(tensor, dim=0):
# type: (Tensor, int) -> Tuple[Tensor, Tensor]
n = tensor.shape[dim]
if n == 0:
return tensor, tensor
r = torch.arange(n, dtype=torch.long, device=tensor.device)
index1, index2 = torch.combinations(r).unbind(-1)
return tensor.index_select(dim, index1), \
tensor.index_select(dim, index2)
@torch.jit.script
def _terms_and_indices(Rcr, EtaR, ShfR, Rca, ShfZ, EtaA, Zeta, ShfA,
distances, vec):
"""Returns radial and angular subAEV terms, these terms will be sorted
according to their distances to central atoms, and only these within
cutoff radius are valid. The returned indices stores the source of data
before sorting.
"""
# type: (float, Tensor, Tensor, float, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor] # noqa: E501
radial_terms = _radial_subaev_terms(Rcr, EtaR,
ShfR, distances)
vec = _combinations(vec, -2)
angular_terms = _angular_subaev_terms(Rca, ShfZ, EtaA,
Zeta, ShfA, *vec)
return radial_terms, angular_terms
# @torch.jit.script
def default_neighborlist(species, coordinates, cutoff):
# type: (Tensor, Tensor, float) -> Tuple[Tensor, Tensor, Tensor]
"""Default neighborlist computer"""
vec = coordinates.unsqueeze(2) - coordinates.unsqueeze(1)
# vec has hape (conformations, atoms, atoms, 3) storing Rij vectors
distances = vec.norm(2, -1)
# distances has shape (conformations, atoms, atoms) storing Rij distances
padding_mask = (species == -1).unsqueeze(1)
distances = distances.masked_fill(padding_mask, math.inf)
distances, indices = distances.sort(-1)
min_distances, _ = distances.flatten(end_dim=1).min(0)
in_cutoff = (min_distances <= cutoff).nonzero().flatten()[1:]
indices = indices.index_select(-1, in_cutoff)
# TODO: remove this workaround after gather support broadcasting
atoms = coordinates.shape[1]
species_ = species.unsqueeze(1).expand(-1, atoms, -1)
neighbor_species = species_.gather(-1, indices)
neighbor_distances = distances.index_select(-1, in_cutoff)
# TODO: remove this workaround when gather support broadcasting
# https://github.com/pytorch/pytorch/pull/9532
indices_ = indices.unsqueeze(-1).expand(-1, -1, -1, 3)
neighbor_coordinates = vec.gather(-2, indices_)
return neighbor_species, neighbor_distances, neighbor_coordinates
def compute_shifts(cell, pbc, cutoff):
"""Compute the shifts of unit cell along the given cell vectors to make it
large enough to contain all pairs of neighbor atoms with PBC under
consideration
@torch.jit.script
def _compute_mask_r(species_r, num_species):
# type: (Tensor, int) -> Tensor
"""Get mask of radial terms for each supported species from indices"""
mask_r = (species_r.unsqueeze(-1) == torch.arange(num_species, dtype=torch.long, device=species_r.device))
return mask_r
@torch.jit.script
def _compute_mask_a(species_a, present_species):
"""Get mask of angular terms for each supported species from indices"""
species_a1, species_a2 = _combinations(species_a, -1)
mask_a1 = (species_a1.unsqueeze(-1) == present_species).unsqueeze(-1)
mask_a2 = (species_a2.unsqueeze(-1).unsqueeze(-1) == present_species)
mask = mask_a1 & mask_a2
mask_rev = mask.permute(0, 1, 2, 4, 3)
mask_a = mask | mask_rev
return mask_a
Arguments:
cell (:class:`torch.Tensor`): tensor of shape (3, 3) of the three
vectors defining unit cell:
tensor([[x1, y1, z1], [x2, y2, z2], [x3, y3, z3]])
cutoff (float): the cutoff inside which atoms are considered pairs
pbc (:class:`torch.Tensor`): boolean vector of size 3 storing
if pbc is enabled for that direction.
Returns:
:class:`torch.Tensor`: long tensor of shifts. the center cell and
symmetric cells are not included.
"""
# type: (Tensor, Tensor, float) -> Tensor
reciprocal_cell = cell.inverse().t()
inv_distances = reciprocal_cell.norm(2, -1)
num_repeats = torch.ceil(cutoff * inv_distances).to(torch.long)
num_repeats = torch.where(pbc, num_repeats, torch.zeros_like(num_repeats))
r1 = torch.arange(1, num_repeats[0] + 1, device=cell.device)
r2 = torch.arange(1, num_repeats[1] + 1, device=cell.device)
r3 = torch.arange(1, num_repeats[2] + 1, device=cell.device)
o = torch.zeros(1, dtype=torch.long, device=cell.device)
return torch.cat([
torch.cartesian_prod(r1, r2, r3),
torch.cartesian_prod(r1, r2, o),
torch.cartesian_prod(r1, r2, -r3),
torch.cartesian_prod(r1, o, r3),
torch.cartesian_prod(r1, o, o),
torch.cartesian_prod(r1, o, -r3),
torch.cartesian_prod(r1, -r2, r3),
torch.cartesian_prod(r1, -r2, o),
torch.cartesian_prod(r1, -r2, -r3),
torch.cartesian_prod(o, r2, r3),
torch.cartesian_prod(o, r2, o),
torch.cartesian_prod(o, r2, -r3),
torch.cartesian_prod(o, o, r3),
])
@torch.jit.script
def _assemble(radial_terms, angular_terms, present_species,
mask_r, mask_a, num_species, angular_sublength):
"""Returns radial and angular AEV computed from terms according
to the given partition information.
# @torch.jit.script
def neighbor_pairs(padding_mask, coordinates, cell, shifts, cutoff):
"""Compute pairs of atoms that are neighbors
Arguments:
radial_terms (:class:`torch.Tensor`): shape (conformations, atoms,
neighbors, ``self.radial_sublength()``)
angular_terms (:class:`torch.Tensor`): shape (conformations, atoms,
pairs, ``self.angular_sublength()``)
present_species (:class:`torch.Tensor`): Long tensor for species
of atoms present in the molecules.
mask_r (:class:`torch.Tensor`): shape (conformations, atoms,
neighbors, supported species)
mask_a (:class:`torch.Tensor`): shape (conformations, atoms,
pairs, present species, present species)
padding_mask (:class:`torch.Tensor`): boolean tensor of shape
(molecules, atoms) for padding mask. 1 == is padding.
coordinates (:class:`torch.Tensor`): tensor of shape
(molecules, atoms, 3) for atom coordinates.
cell (:class:`torch.Tensor`): tensor of shape (3, 3) of the three vectors
defining unit cell: tensor([[x1, y1, z1], [x2, y2, z2], [x3, y3, z3]])
cutoff (float): the cutoff inside which atoms are considered pairs
shifts (:class:`torch.Tensor`): tensor of shape (?, 3) storing shifts
"""
# type: (Tensor, Tensor, Tensor, Tensor, float) -> Tuple[Tensor, Tensor, Tensor, Tensor]
coordinates = coordinates.detach()
cell = cell.detach()
num_atoms = padding_mask.shape[1]
all_atoms = torch.arange(num_atoms, device=cell.device)
# Step 2: center cell
p1_center, p2_center = torch.combinations(all_atoms).unbind(-1)
shifts_center = shifts.new_zeros(p1_center.shape[0], 3)
# Step 3: cells with shifts
# shape convention (shift index, molecule index, atom index, 3)
num_shifts = shifts.shape[0]
all_shifts = torch.arange(num_shifts, device=cell.device)
shift_index, p1, p2 = torch.cartesian_prod(all_shifts, all_atoms, all_atoms).unbind(-1)
shifts_outide = shifts.index_select(0, shift_index)
# Step 4: combine results for all cells
shifts_all = torch.cat([shifts_center, shifts_outide])
p1_all = torch.cat([p1_center, p1])
p2_all = torch.cat([p2_center, p2])
shift_values = torch.mm(shifts_all.to(cell.dtype), cell)
# step 5, compute distances, and find all pairs within cutoff
distances = (coordinates.index_select(1, p1_all) - coordinates.index_select(1, p2_all) + shift_values).norm(2, -1)
padding_mask = (padding_mask.index_select(1, p1_all)) | (padding_mask.index_select(1, p2_all))
distances.masked_fill_(padding_mask, math.inf)
in_cutoff = (distances <= cutoff).nonzero()
molecule_index, pair_index = in_cutoff.unbind(1)
atom_index1 = p1_all[pair_index]
atom_index2 = p2_all[pair_index]
shifts = shifts_all.index_select(0, pair_index)
return molecule_index, atom_index1, atom_index2, shifts
# torch.jit.script
def triu_index(num_species):
species = torch.arange(num_species)
species1, species2 = torch.combinations(species, r=2, with_replacement=True).unbind(-1)
pair_index = torch.arange(species1.shape[0])
ret = torch.zeros(num_species, num_species, dtype=torch.long)
ret[species1, species2] = pair_index
ret[species2, species1] = pair_index
return ret
# torch.jit.script
def convert_pair_index(index):
"""Let's say we have a pair:
index: 0 1 2 3 4 5 6 7 8 9 ...
elem1: 0 0 1 0 1 2 0 1 2 3 ...
elem2: 1 2 2 3 3 3 4 4 4 4 ...
This function convert index back to elem1 and elem2
To implement this, divide it into groups, the first group contains 1
elements, the second contains 2 elements, ..., the nth group contains
n elements.
Let's say we want to compute the elem1 and elem2 for index i. We first find
the number of complete groups contained in index 0, 1, ..., i - 1
(all inclusive, not including i), then i will be in the next group. Let's
say there are N complete groups, then these N groups contains
N * (N + 1) / 2 elements, solving for the largest N that satisfies
N * (N + 1) / 2 <= i, will get the N we want.
"""
# type: (Tensor, Tensor, Tensor, Tensor, Tensor, int, int) -> Tuple[Tensor, Tensor] # noqa: E501
conformations = radial_terms.shape[0]
atoms = radial_terms.shape[1]
# assemble radial subaev
present_radial_aevs = (radial_terms.unsqueeze(-2) * mask_r.unsqueeze(-1).to(radial_terms.dtype)).sum(-3)
# present_radial_aevs has shape
# (conformations, atoms, present species, radial_length)
radial_aevs = present_radial_aevs.flatten(start_dim=2)
# assemble angular subaev
rev_indices = torch.full((num_species,), -1, dtype=present_species.dtype,
device=present_species.device)
rev_indices[present_species] = torch.arange(present_species.numel(),
dtype=torch.long,
device=radial_terms.device)
angular_aevs = []
zero_angular_subaev = torch.zeros(conformations, atoms, angular_sublength,
dtype=radial_terms.dtype,
device=radial_terms.device)
for s1 in range(num_species):
# TODO: make PyTorch support range(start, end) and
# range(start, end, step) and remove the workaround
# below. The inner for loop should be:
# for s2 in range(s1, num_species):
for s2 in range(num_species - s1):
s2 += s1
i1 = int(rev_indices[s1])
i2 = int(rev_indices[s2])
if i1 >= 0 and i2 >= 0:
mask = mask_a[:, :, :, i1, i2].unsqueeze(-1) \
.to(radial_terms.dtype)
subaev = (angular_terms * mask).sum(-2)
else:
subaev = zero_angular_subaev
angular_aevs.append(subaev)
return radial_aevs, torch.cat(angular_aevs, dim=2)
@torch.jit.script
def _compute_aev(num_species, angular_sublength, Rcr, EtaR, ShfR, Rca, ShfZ,
EtaA, Zeta, ShfA, species, species_, distances, vec):
# type: (int, int, float, Tensor, Tensor, float, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor) -> Tuple[Tensor, Tensor] # noqa: E501
present_species = utils.present_species(species)
radial_terms, angular_terms = _terms_and_indices(
Rcr, EtaR, ShfR, Rca, ShfZ, EtaA, Zeta, ShfA, distances, vec)
mask_r = _compute_mask_r(species_, num_species)
mask_a = _compute_mask_a(species_, present_species)
radial, angular = _assemble(radial_terms, angular_terms,
present_species, mask_r, mask_a,
num_species, angular_sublength)
fullaev = torch.cat([radial, angular], dim=2)
return species, fullaev
n = (torch.sqrt(1.0 + 8.0 * index.to(torch.float)) - 1.0) / 2.0
n = torch.floor(n).to(torch.long)
num_elems = n * (n + 1) / 2
return index - num_elems, n + 1
# torch.jit.script
def cumsum_from_zero(input_):
cumsum = torch.cumsum(input_, dim=0)
cumsum = torch.cat([input_.new_tensor([0]), cumsum[:-1]])
return cumsum
# torch.jit.script
def triple_by_molecule(molecule_index, atom_index1, atom_index2):
"""Input: indices for pairs of atoms that are close to each other.
each pair only appear once, i.e. only one of the pairs (1, 2) and
(2, 1) exists.
Output: indices for all central atoms and it pairs of neighbors. For
example, if input has pair (0, 1), (0, 2), (0, 3), (0, 4), (1, 2),
(1, 3), (1, 4), (2, 3), (2, 4), (3, 4), then the output would have
central atom 0, 1, 2, 3, 4 and for cental atom 0, its pairs of neighbors
are (1, 2), (1, 3), (1, 4), (2, 3), (2, 4), (3, 4)
"""
# convert representation from pair to central-other
n = molecule_index.shape[0]
mi = molecule_index.repeat(2)
ai1 = torch.cat([atom_index1, atom_index2])
# sort and compute unique key
mi_ai1 = torch.stack([mi, ai1], dim=1)
m_ac, rev_indices, counts = torch._unique_dim2_temporary_will_remove_soon(mi_ai1, dim=0, sorted=True, return_inverse=True, return_counts=True)
uniqued_molecule_index, uniqued_central_atom_index = m_ac.unbind(1)
# do local combinations within unique key, assuming sorted
pair_sizes = counts * (counts - 1) // 2
total_size = pair_sizes.sum()
molecule_index = torch.repeat_interleave(uniqued_molecule_index, pair_sizes)
central_atom_index = torch.repeat_interleave(uniqued_central_atom_index, pair_sizes)
cumsum = cumsum_from_zero(pair_sizes)
cumsum = torch.repeat_interleave(cumsum, pair_sizes)
sorted_local_pair_index = torch.arange(total_size, device=molecule_index.device) - cumsum
sorted_local_index1, sorted_local_index2 = convert_pair_index(sorted_local_pair_index)
cumsum = cumsum_from_zero(counts)
cumsum = torch.repeat_interleave(cumsum, pair_sizes)
sorted_local_index1 += cumsum
sorted_local_index2 += cumsum
# unsort result from last part
argsort = rev_indices.argsort()
local_index1 = argsort[sorted_local_index1]
local_index2 = argsort[sorted_local_index2]
# compute mapping between representation of central-other to pair
sign1 = torch.where(local_index1 < n, torch.ones_like(local_index1), -torch.ones_like(local_index1))
sign2 = torch.where(local_index2 < n, torch.ones_like(local_index2), -torch.ones_like(local_index2))
pair_index1 = torch.where(local_index1 < n, local_index1, local_index1 - n)
pair_index2 = torch.where(local_index2 < n, local_index2, local_index2 - n)
return molecule_index, central_atom_index, pair_index1, pair_index2, sign1, sign2
# torch.jit.script
def compute_aev(species, coordinates, cell, pbc_switch, triu_index, constants, sizes):
Rcr, EtaR, ShfR, Rca, ShfZ, EtaA, Zeta, ShfA = constants
num_species, radial_sublength, radial_length, angular_sublength, angular_length, aev_length = sizes
num_molecules = species.shape[0]
num_atoms = species.shape[1]
num_species_pairs = angular_length // angular_sublength
cutoff = max(Rcr, Rca)
shifts = compute_shifts(cell, pbc_switch, cutoff)
molecule_index, atom_index1, atom_index2, shifts = neighbor_pairs(species == -1, coordinates, cell, shifts, cutoff)
species1 = species[molecule_index, atom_index1]
species2 = species[molecule_index, atom_index2]
shift_values = torch.mm(shifts.to(cell.dtype), cell)
vec = coordinates[molecule_index, atom_index1, :] - coordinates[molecule_index, atom_index2, :] + shift_values
distances = vec.norm(2, -1)
# compute radial aev
radial_terms_ = radial_terms(Rcr, EtaR, ShfR, distances)
radial_aev = radial_terms_.new_zeros(num_molecules * num_atoms * num_species, radial_sublength)
index1 = (molecule_index * num_atoms + atom_index1) * num_species + species2
index2 = (molecule_index * num_atoms + atom_index2) * num_species + species1
radial_aev.scatter_add_(0, index1.unsqueeze(1).expand(-1, radial_sublength), radial_terms_)
radial_aev.scatter_add_(0, index2.unsqueeze(1).expand(-1, radial_sublength), radial_terms_)
radial_aev = radial_aev.reshape(num_molecules, num_atoms, radial_length)
# compute angular aev
molecule_index, central_atom_index, pair_index1, pair_index2, sign1, sign2 = triple_by_molecule(molecule_index, atom_index1, atom_index2)
vec1 = vec.index_select(0, pair_index1) * sign1.unsqueeze(1).to(vec.dtype)
vec2 = vec.index_select(0, pair_index2) * sign2.unsqueeze(1).to(vec.dtype)
species1_ = torch.where(sign1 == 1, species2[pair_index1], species1[pair_index1])
species2_ = torch.where(sign2 == 1, species2[pair_index2], species1[pair_index2])
angular_terms_ = angular_terms(Rca, ShfZ, EtaA, Zeta, ShfA, vec1, vec2)
angular_aev = angular_terms_.new_zeros(num_molecules * num_atoms * num_species_pairs, angular_sublength)
index = (molecule_index * num_atoms + central_atom_index) * num_species_pairs + triu_index[species1_, species2_]
angular_aev.scatter_add_(0, index.unsqueeze(1).expand(-1, angular_sublength), angular_terms_)
angular_aev = angular_aev.reshape(num_molecules, num_atoms, angular_length)
return torch.cat([radial_aev, angular_aev], dim=-1)
class AEVComputer(torch.nn.Module):
......@@ -271,20 +329,6 @@ class AEVComputer(torch.nn.Module):
ShfZ (:class:`torch.Tensor`): The 1D tensor of :math:`\theta_s` in
equation (4) in the `ANI paper`_.
num_species (int): Number of supported atom types.
neighborlist_computer (:class:`collections.abc.Callable`): initial
value of :attr:`neighborlist`
Attributes:
neighborlist (:class:`collections.abc.Callable`): The callable
(species:Tensor, coordinates:Tensor, cutoff:float)
-> Tuple[Tensor, Tensor, Tensor] that returns the species,
distances and relative coordinates of neighbor atoms. The input
species and coordinates tensor have the same shape convention as
the input of :class:`AEVComputer`. The returned neighbor
species and coordinates tensor must have shape ``(C, A, N)`` and
``(C, A, N, 3)`` correspoindingly, where ``C`` is the number of
conformations in a chunk, ``A`` is the number of atoms, and ``N``
is the maximum number of neighbors that an atom could have.
.. _ANI paper:
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
......@@ -293,8 +337,7 @@ class AEVComputer(torch.nn.Module):
'radial_length', 'angular_sublength', 'angular_length',
'aev_length']
def __init__(self, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ,
num_species, neighborlist_computer=default_neighborlist):
def __init__(self, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species):
super(AEVComputer, self).__init__()
self.Rcr = Rcr
self.Rca = Rca
......@@ -309,43 +352,60 @@ class AEVComputer(torch.nn.Module):
self.register_buffer('ShfZ', ShfZ.view(1, 1, 1, -1))
self.num_species = num_species
self.neighborlist = neighborlist_computer
# The length of radial subaev of a single species
self.radial_sublength = self.EtaR.numel() * self.ShfR.numel()
# The length of full radial aev
self.radial_length = self.num_species * self.radial_sublength
# The length of angular subaev of a single species
self.angular_sublength = self.EtaA.numel() * self.Zeta.numel() * \
self.ShfA.numel() * self.ShfZ.numel()
self.angular_sublength = self.EtaA.numel() * self.Zeta.numel() * self.ShfA.numel() * self.ShfZ.numel()
# The length of full angular aev
self.angular_length = (self.num_species * (self.num_species + 1)) \
// 2 * self.angular_sublength
self.angular_length = (self.num_species * (self.num_species + 1)) // 2 * self.angular_sublength
# The length of full aev
self.aev_length = self.radial_length + self.angular_length
self.sizes = self.num_species, self.radial_sublength, self.radial_length, self.angular_sublength, self.angular_length, self.aev_length
self.register_buffer('triu_index', triu_index(num_species))
def constants(self):
return self.Rcr, self.EtaR, self.ShfR, self.Rca, self.ShfZ, self.EtaA, self.Zeta, self.ShfA
# @torch.jit.script_method
def forward(self, species_coordinates):
def forward(self, input):
"""Compute AEVs
Arguments:
species_coordinates (tuple): Two tensors: species and coordinates.
input (tuple): Can be one of the following two cases:
If you don't care about periodic boundary conditions at all,
then input can be a tuple of two tensors: species and coordinates.
species must have shape ``(C, A)`` and coordinates must have
shape ``(C, A, 3)``, where ``C`` is the number of conformations
shape ``(C, A, 3)``, where ``C`` is the number of molecules
in a chunk, and ``A`` is the number of atoms.
If you want to apply periodic boundary conditions, then the input
would be a tuple of four tensors: species, coordinates, cell, pbc
where species and coordinates are the same as described above, cell
is a tensor of shape (3, 3) of the three vectors defining unit cell:
.. code-block:: python
tensor([[x1, y1, z1],
[x2, y2, z2],
[x3, y3, z3]])
and pbc is boolean vector of size 3 storing if pbc is enabled
for that direction.
Returns:
tuple: Species and AEVs. species are the species from the input
unchanged, and AEVs is a tensor of shape
``(C, A, self.aev_length())``
"""
# type: (Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]
species, coordinates = species_coordinates
max_cutoff = max(self.Rcr, self.Rca)
species_, distances, vec = self.neighborlist(species, coordinates,
max_cutoff)
return _compute_aev(
self.num_species, self.angular_sublength, self.Rcr, self.EtaR,
self.ShfR, self.Rca, self.ShfZ, self.EtaA, self.Zeta, self.ShfA,
species, species_, distances, vec)
if len(input) == 2:
species, coordinates = input
cell = torch.eye(3, dtype=self.EtaR.dtype, device=self.EtaR.device)
pbc = torch.zeros(3, dtype=torch.uint8, device=self.EtaR.device)
else:
assert len(input) == 4
species, coordinates, cell, pbc = input
return species, compute_aev(species, coordinates, cell, pbc, self.triu_index, self.constants(), self.sizes)
......@@ -6,95 +6,13 @@
"""
from __future__ import absolute_import
import math
import torch
import ase.neighborlist
from . import utils
import ase.calculators.calculator
import ase.units
import copy
class NeighborList(torch.nn.Module):
"""ASE neighborlist computer
Arguments:
cell: same as in :class:`ase.Atoms`
pbc: same as in :class:`ase.Atoms`
"""
def __init__(self, cell=None, pbc=None):
# wrap `cell` and `pbc` with `ase.Atoms`
super(NeighborList, self).__init__()
a = ase.Atoms('He', [[0, 0, 0]], cell=cell, pbc=pbc)
self.pbc = a.get_pbc()
self.cell = a.get_cell(complete=True)
def forward(self, species, coordinates, cutoff):
conformations = species.shape[0]
max_atoms = species.shape[1]
neighbor_species = []
neighbor_distances = []
neighbor_vecs = []
for i in range(conformations):
s = species[i].unsqueeze(0)
c = coordinates[i].unsqueeze(0)
s, c = utils.strip_redundant_padding(s, c)
s = s.squeeze()
c = c.squeeze()
atoms = s.shape[0]
atoms_object = ase.Atoms(
['He'] * atoms, # chemical symbols are not important here
positions=c.detach().numpy(),
pbc=self.pbc,
cell=self.cell)
idx1, idx2, shift = ase.neighborlist.neighbor_list(
'ijS', atoms_object, cutoff)
# NB: The absolute distance and distance vectors computed by
# `neighbor_list`can not be used since it does not preserve
# gradient information
idx1 = torch.tensor(idx1, device=coordinates.device,
dtype=torch.long)
idx2 = torch.tensor(idx2, device=coordinates.device,
dtype=torch.long)
D = c.index_select(0, idx2) - c.index_select(0, idx1)
shift = torch.tensor(shift, device=coordinates.device,
dtype=coordinates.dtype)
cell = torch.tensor(self.cell, device=coordinates.device,
dtype=coordinates.dtype)
D += torch.mm(shift, cell)
d = D.norm(2, -1)
neighbor_species1 = []
neighbor_distances1 = []
neighbor_vecs1 = []
for i in range(atoms):
this_atom_indices = (idx1 == i).nonzero().flatten()
neighbor_indices = idx2[this_atom_indices]
neighbor_species1.append(s[neighbor_indices])
neighbor_distances1.append(d[this_atom_indices])
neighbor_vecs1.append(D.index_select(0, this_atom_indices))
for i in range(max_atoms - atoms):
neighbor_species1.append(torch.full((1,), -1))
neighbor_distances1.append(torch.full((1,), math.inf))
neighbor_vecs1.append(torch.full((1, 3), 0))
neighbor_species1 = torch.nn.utils.rnn.pad_sequence(
neighbor_species1, padding_value=-1)
neighbor_distances1 = torch.nn.utils.rnn.pad_sequence(
neighbor_distances1, padding_value=math.inf)
neighbor_vecs1 = torch.nn.utils.rnn.pad_sequence(
neighbor_vecs1, padding_value=0)
neighbor_species.append(neighbor_species1)
neighbor_distances.append(neighbor_distances1)
neighbor_vecs.append(neighbor_vecs1)
neighbor_species = torch.nn.utils.rnn.pad_sequence(
neighbor_species, batch_first=True, padding_value=-1)
neighbor_distances = torch.nn.utils.rnn.pad_sequence(
neighbor_distances, batch_first=True, padding_value=math.inf)
neighbor_vecs = torch.nn.utils.rnn.pad_sequence(
neighbor_vecs, batch_first=True, padding_value=0)
return neighbor_species.permute(0, 2, 1), \
neighbor_distances.permute(0, 2, 1), \
neighbor_vecs.permute(0, 2, 1, 3)
import numpy
class Calculator(ase.calculators.calculator.Calculator):
......@@ -109,16 +27,12 @@ class Calculator(ase.calculators.calculator.Calculator):
energy_shifter (:class:`torchani.EnergyShifter`): Energy shifter.
dtype (:class:`torchani.EnergyShifter`): data type to use,
by dafault ``torch.float64``.
_default_neighborlist (bool): Whether to ignore pbc setting and always
use default neighborlist computer. This is for internal use only.
"""
implemented_properties = ['energy', 'forces']
def __init__(self, species, aev_computer, model, energy_shifter,
dtype=torch.float64, _default_neighborlist=False):
def __init__(self, species, aev_computer, model, energy_shifter, dtype=torch.float64):
super(Calculator, self).__init__()
self._default_neighborlist = _default_neighborlist
self.species_to_tensor = utils.ChemicalSymbolsToInts(species)
# aev_computer.neighborlist will be changed later, so we need a copy to
# make sure we do not change the original object
......@@ -138,16 +52,18 @@ class Calculator(ase.calculators.calculator.Calculator):
def calculate(self, atoms=None, properties=['energy'],
system_changes=ase.calculators.calculator.all_changes):
super(Calculator, self).calculate(atoms, properties, system_changes)
if not self._default_neighborlist:
self.aev_computer.neighborlist.pbc = self.atoms.get_pbc()
self.aev_computer.neighborlist.cell = \
self.atoms.get_cell(complete=True)
cell = torch.tensor(self.atoms.get_cell(complete=True),
requires_grad=True, dtype=self.dtype,
device=self.device)
pbc = torch.tensor(self.atoms.get_pbc().astype(numpy.uint8), dtype=torch.uint8,
device=self.device)
# print(cell, pbc)
species = self.species_to_tensor(self.atoms.get_chemical_symbols())
species = species.unsqueeze(0)
coordinates = torch.tensor(self.atoms.get_positions())
coordinates = coordinates.unsqueeze(0).to(self.device).to(self.dtype) \
.requires_grad_('forces' in properties)
_, energy = self.whole((species, coordinates))
_, energy = self.whole((species, coordinates, cell, pbc))
energy *= ase.units.Hartree
self.results['energy'] = energy.item()
if 'forces' in properties:
......
......@@ -65,7 +65,7 @@ def pad_coordinates(species_coordinates):
return torch.cat(species), torch.cat(coordinates)
@torch.jit.script
# @torch.jit.script
def present_species(species):
"""Given a vector of species of atoms, compute the unique species present.
......@@ -75,7 +75,8 @@ def present_species(species):
Returns:
:class:`torch.Tensor`: 1D vector storing present atom types sorted.
"""
present_species, _ = species.flatten()._unique(sorted=True)
# present_species, _ = species.flatten()._unique(sorted=True)
present_species = species.flatten().unique(sorted=True)
if present_species[0].item() == -1:
present_species = present_species[1:]
return present_species
......@@ -86,9 +87,9 @@ def strip_redundant_padding(species, coordinates):
Arguments:
species (:class:`torch.Tensor`): Long tensor of shape
``(conformations, atoms)``.
``(molecules, atoms)``.
coordinates (:class:`torch.Tensor`): Tensor of shape
``(conformations, atoms, 3)``.
``(molecules, atoms, 3)``.
Returns:
(:class:`torch.Tensor`, :class:`torch.Tensor`): species and coordinates
......@@ -100,6 +101,39 @@ def strip_redundant_padding(species, coordinates):
return species, coordinates
def map2central(cell, coordinates, pbc):
"""Map atoms outside the unit cell into the cell using PBC.
Arguments:
cell (:class:`torch.Tensor`): tensor of shape (3, 3) of the three
vectors defining unit cell:
.. code-block:: python
tensor([[x1, y1, z1],
[x2, y2, z2],
[x3, y3, z3]])
coordinates (:class:`torch.Tensor`): Tensor of shape
``(molecules, atoms, 3)``.
pbc (:class:`torch.Tensor`): boolean vector of size 3 storing
if pbc is enabled for that direction.
Returns:
:class:`torch.Tensor`: coordinates of atoms mapped back to unit cell.
"""
# Step 1: convert coordinates from standard cartesian coordinate to unit
# cell coordinates
inv_cell = torch.inverse(cell)
coordinates_cell = torch.matmul(coordinates, inv_cell)
# Step 2: wrap cell coordinates into [0, 1)
coordinates_cell -= coordinates_cell.floor() * pbc.to(coordinates_cell.dtype)
# Step 3: convert from cell coordinates back to standard cartesian
# coordinate
return torch.matmul(coordinates_cell, cell)
class EnergyShifter(torch.nn.Module):
"""Helper class for adding and subtracting self atomic energies
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment