Unverified Commit 69168662 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

More test for ASE (#125)

parent 06cd86db
from ase.lattice.cubic import Diamond
from ase.md.langevin import Langevin
from ase import units
from ase import units, Atoms
from ase.calculators.test import numeric_force
import torch
import torchani
import unittest
import numpy
def get_numeric_force(atoms, eps):
......@@ -17,8 +18,8 @@ def get_numeric_force(atoms, eps):
class TestASE(unittest.TestCase):
def testForceWithPBCEnabled(self):
atoms = Diamond(symbol="C", pbc=True)
def _testForce(self, pbc):
atoms = Diamond(symbol="C", pbc=pbc)
builtin = torchani.neurochem.Builtins()
calculator = torchani.ase.Calculator(
builtin.species, builtin.aev_computer,
......@@ -30,8 +31,61 @@ class TestASE(unittest.TestCase):
fn = get_numeric_force(atoms, 0.001)
df = (f - fn).abs().max()
avgf = f.abs().mean()
if avgf > 0:
self.assertLess(df / avgf, 0.1)
def testForceWithPBCEnabled(self):
self._testForce(True)
def testForceWithPBCDisabled(self):
self._testForce(False)
def testForceAgainstDefaultNeighborList(self):
atoms = Diamond(symbol="C", pbc=False)
builtin = torchani.neurochem.Builtins()
calculator = torchani.ase.Calculator(
builtin.species, builtin.aev_computer,
builtin.models, builtin.energy_shifter)
default_neighborlist_calculator = torchani.ase.Calculator(
builtin.species, builtin.aev_computer,
builtin.models, builtin.energy_shifter, True)
atoms.set_calculator(calculator)
dyn = Langevin(atoms, 5 * units.fs, 50 * units.kB, 0.002)
def test_energy(a=atoms):
a = a.copy()
a.set_calculator(calculator)
e1 = a.get_potential_energy()
a.set_calculator(default_neighborlist_calculator)
e2 = a.get_potential_energy()
self.assertEqual(e1, e2)
dyn.attach(test_energy, interval=1)
dyn.run(500)
def testTranslationalInvariancePBC(self):
atoms = Atoms('CH4', [[0, 0, 0],
[1, 0, 0],
[0, 1, 0],
[0, 0, 1],
[0, 1, 1]],
cell=[2, 2, 2], pbc=True)
builtin = torchani.neurochem.Builtins()
calculator = torchani.ase.Calculator(
builtin.species, builtin.aev_computer,
builtin.models, builtin.energy_shifter)
atoms.set_calculator(calculator)
e = atoms.get_potential_energy()
for _ in range(100):
positions = atoms.get_positions()
translation = (numpy.random.rand(3) - 0.5) * 2
atoms.set_positions(positions + translation)
self.assertEqual(e, atoms.get_potential_energy())
if __name__ == '__main__':
unittest.main()
......@@ -213,6 +213,8 @@ class AEVComputer(torch.nn.Module):
# TODO: remove this when combinations is merged into PyTorch
# https://github.com/pytorch/pytorch/pull/9393
n = tensor.shape[dim]
if n == 0:
return tensor, tensor
r = torch.arange(n, dtype=torch.long, device=tensor.device)
grid_x, grid_y = torch.meshgrid([r, r])
index1 = grid_y.masked_select(
......
......@@ -11,6 +11,7 @@ import ase.neighborlist
from . import utils
import ase.calculators.calculator
import ase.units
import copy
class NeighborList:
......@@ -102,14 +103,20 @@ class Calculator(ase.calculators.calculator.Calculator):
model (:class:`torchani.ANIModel` or :class:`torchani.Ensemble`):
neural network potential models.
energy_shifter (:class:`torchani.EnergyShifter`): Energy shifter.
_default_neighborlist (bool): Whether to ignore pbc setting and always
use default neighborlist computer. This is for internal use only.
"""
implemented_properties = ['energy', 'forces']
def __init__(self, species, aev_computer, model, energy_shifter):
def __init__(self, species, aev_computer, model, energy_shifter,
_default_neighborlist=False):
super(Calculator, self).__init__()
self._default_neighborlist = _default_neighborlist
self.species_to_tensor = utils.ChemicalSymbolsToInts(species)
self.aev_computer = aev_computer
# aev_computer.neighborlist will be changed later, so we need a copy to
# make sure we do not change the original object
self.aev_computer = copy.copy(aev_computer)
self.model = model
self.energy_shifter = energy_shifter
......@@ -125,8 +132,10 @@ class Calculator(ase.calculators.calculator.Calculator):
def calculate(self, atoms=None, properties=['energy'],
system_changes=ase.calculators.calculator.all_changes):
super(Calculator, self).calculate(atoms, properties, system_changes)
if not self._default_neighborlist:
self.aev_computer.neighborlist = NeighborList(
cell=self.atoms.get_cell(complete=True), pbc=self.atoms.get_pbc())
cell=self.atoms.get_cell(complete=True),
pbc=self.atoms.get_pbc())
species = self.species_to_tensor(self.atoms.get_chemical_symbols())
species = species.unsqueeze(0)
coordinates = torch.tensor(self.atoms.get_positions(wrap=True))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment