Unverified Commit 69168662 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

More test for ASE (#125)

parent 06cd86db
from ase.lattice.cubic import Diamond from ase.lattice.cubic import Diamond
from ase.md.langevin import Langevin from ase.md.langevin import Langevin
from ase import units from ase import units, Atoms
from ase.calculators.test import numeric_force from ase.calculators.test import numeric_force
import torch import torch
import torchani import torchani
import unittest import unittest
import numpy
def get_numeric_force(atoms, eps): def get_numeric_force(atoms, eps):
...@@ -17,8 +18,8 @@ def get_numeric_force(atoms, eps): ...@@ -17,8 +18,8 @@ def get_numeric_force(atoms, eps):
class TestASE(unittest.TestCase): class TestASE(unittest.TestCase):
def testForceWithPBCEnabled(self): def _testForce(self, pbc):
atoms = Diamond(symbol="C", pbc=True) atoms = Diamond(symbol="C", pbc=pbc)
builtin = torchani.neurochem.Builtins() builtin = torchani.neurochem.Builtins()
calculator = torchani.ase.Calculator( calculator = torchani.ase.Calculator(
builtin.species, builtin.aev_computer, builtin.species, builtin.aev_computer,
...@@ -30,8 +31,61 @@ class TestASE(unittest.TestCase): ...@@ -30,8 +31,61 @@ class TestASE(unittest.TestCase):
fn = get_numeric_force(atoms, 0.001) fn = get_numeric_force(atoms, 0.001)
df = (f - fn).abs().max() df = (f - fn).abs().max()
avgf = f.abs().mean() avgf = f.abs().mean()
if avgf > 0:
self.assertLess(df / avgf, 0.1) self.assertLess(df / avgf, 0.1)
def testForceWithPBCEnabled(self):
self._testForce(True)
def testForceWithPBCDisabled(self):
self._testForce(False)
def testForceAgainstDefaultNeighborList(self):
atoms = Diamond(symbol="C", pbc=False)
builtin = torchani.neurochem.Builtins()
calculator = torchani.ase.Calculator(
builtin.species, builtin.aev_computer,
builtin.models, builtin.energy_shifter)
default_neighborlist_calculator = torchani.ase.Calculator(
builtin.species, builtin.aev_computer,
builtin.models, builtin.energy_shifter, True)
atoms.set_calculator(calculator)
dyn = Langevin(atoms, 5 * units.fs, 50 * units.kB, 0.002)
def test_energy(a=atoms):
a = a.copy()
a.set_calculator(calculator)
e1 = a.get_potential_energy()
a.set_calculator(default_neighborlist_calculator)
e2 = a.get_potential_energy()
self.assertEqual(e1, e2)
dyn.attach(test_energy, interval=1)
dyn.run(500)
def testTranslationalInvariancePBC(self):
atoms = Atoms('CH4', [[0, 0, 0],
[1, 0, 0],
[0, 1, 0],
[0, 0, 1],
[0, 1, 1]],
cell=[2, 2, 2], pbc=True)
builtin = torchani.neurochem.Builtins()
calculator = torchani.ase.Calculator(
builtin.species, builtin.aev_computer,
builtin.models, builtin.energy_shifter)
atoms.set_calculator(calculator)
e = atoms.get_potential_energy()
for _ in range(100):
positions = atoms.get_positions()
translation = (numpy.random.rand(3) - 0.5) * 2
atoms.set_positions(positions + translation)
self.assertEqual(e, atoms.get_potential_energy())
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -213,6 +213,8 @@ class AEVComputer(torch.nn.Module): ...@@ -213,6 +213,8 @@ class AEVComputer(torch.nn.Module):
# TODO: remove this when combinations is merged into PyTorch # TODO: remove this when combinations is merged into PyTorch
# https://github.com/pytorch/pytorch/pull/9393 # https://github.com/pytorch/pytorch/pull/9393
n = tensor.shape[dim] n = tensor.shape[dim]
if n == 0:
return tensor, tensor
r = torch.arange(n, dtype=torch.long, device=tensor.device) r = torch.arange(n, dtype=torch.long, device=tensor.device)
grid_x, grid_y = torch.meshgrid([r, r]) grid_x, grid_y = torch.meshgrid([r, r])
index1 = grid_y.masked_select( index1 = grid_y.masked_select(
......
...@@ -11,6 +11,7 @@ import ase.neighborlist ...@@ -11,6 +11,7 @@ import ase.neighborlist
from . import utils from . import utils
import ase.calculators.calculator import ase.calculators.calculator
import ase.units import ase.units
import copy
class NeighborList: class NeighborList:
...@@ -102,14 +103,20 @@ class Calculator(ase.calculators.calculator.Calculator): ...@@ -102,14 +103,20 @@ class Calculator(ase.calculators.calculator.Calculator):
model (:class:`torchani.ANIModel` or :class:`torchani.Ensemble`): model (:class:`torchani.ANIModel` or :class:`torchani.Ensemble`):
neural network potential models. neural network potential models.
energy_shifter (:class:`torchani.EnergyShifter`): Energy shifter. energy_shifter (:class:`torchani.EnergyShifter`): Energy shifter.
_default_neighborlist (bool): Whether to ignore pbc setting and always
use default neighborlist computer. This is for internal use only.
""" """
implemented_properties = ['energy', 'forces'] implemented_properties = ['energy', 'forces']
def __init__(self, species, aev_computer, model, energy_shifter): def __init__(self, species, aev_computer, model, energy_shifter,
_default_neighborlist=False):
super(Calculator, self).__init__() super(Calculator, self).__init__()
self._default_neighborlist = _default_neighborlist
self.species_to_tensor = utils.ChemicalSymbolsToInts(species) self.species_to_tensor = utils.ChemicalSymbolsToInts(species)
self.aev_computer = aev_computer # aev_computer.neighborlist will be changed later, so we need a copy to
# make sure we do not change the original object
self.aev_computer = copy.copy(aev_computer)
self.model = model self.model = model
self.energy_shifter = energy_shifter self.energy_shifter = energy_shifter
...@@ -125,8 +132,10 @@ class Calculator(ase.calculators.calculator.Calculator): ...@@ -125,8 +132,10 @@ class Calculator(ase.calculators.calculator.Calculator):
def calculate(self, atoms=None, properties=['energy'], def calculate(self, atoms=None, properties=['energy'],
system_changes=ase.calculators.calculator.all_changes): system_changes=ase.calculators.calculator.all_changes):
super(Calculator, self).calculate(atoms, properties, system_changes) super(Calculator, self).calculate(atoms, properties, system_changes)
if not self._default_neighborlist:
self.aev_computer.neighborlist = NeighborList( self.aev_computer.neighborlist = NeighborList(
cell=self.atoms.get_cell(complete=True), pbc=self.atoms.get_pbc()) cell=self.atoms.get_cell(complete=True),
pbc=self.atoms.get_pbc())
species = self.species_to_tensor(self.atoms.get_chemical_symbols()) species = self.species_to_tensor(self.atoms.get_chemical_symbols())
species = species.unsqueeze(0) species = species.unsqueeze(0)
coordinates = torch.tensor(self.atoms.get_positions(wrap=True)) coordinates = torch.tensor(self.atoms.get_positions(wrap=True))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment