Commit 1a2b4504 authored by Gao, Xiang's avatar Gao, Xiang Committed by Farhad Ramezanghorbani
Browse files

Allow passing pbc and cell to sequential, and simplify ASE interface (#386)

* Allow passing pbc and cell to sequential

* revert

* Simplify ase interface

* fix

* fix

* fix

* flake8

* Update ase.py

* fix
parent d400d8f0
......@@ -55,10 +55,8 @@ print(nnp2)
###############################################################################
# You can also create an ASE calculator using the ensemble or single model:
calculator1 = torchani.ase.Calculator(consts.species, aev_computer,
ensemble, energy_shifter)
calculator2 = torchani.ase.Calculator(consts.species, aev_computer,
model, energy_shifter)
calculator1 = torchani.ase.Calculator(consts.species, nnp1)
calculator2 = torchani.ase.Calculator(consts.species, nnp2)
print(calculator1)
print(calculator1)
......
......@@ -24,9 +24,12 @@ def get_numeric_force(atoms, eps):
class TestASE(unittest.TestCase):
def setUp(self):
self.model = torchani.models.ANI1x().double()[0]
def testWithNumericalForceWithPBCEnabled(self):
atoms = Diamond(symbol="C", pbc=True)
calculator = torchani.models.ANI1x().ase()
calculator = self.model.ase()
atoms.set_calculator(calculator)
dyn = Langevin(atoms, 5 * units.fs, 30000000 * units.kB, 0.002)
dyn.run(100)
......@@ -40,7 +43,7 @@ class TestASE(unittest.TestCase):
def testWithNumericalStressWithPBCEnabled(self):
filename = os.path.join(path, '../tools/generate-unit-test-expect/others/Benzene.cif')
benzene = read(filename)
calculator = torchani.models.ANI1x().ase()
calculator = self.model.ase()
benzene.set_calculator(calculator)
dyn = NPTBerendsen(benzene, timestep=0.1 * units.fs,
temperature=300 * units.kB,
......
......@@ -16,9 +16,7 @@ class TestStructureOptimization(unittest.TestCase):
def setUp(self):
self.tolerance = 1e-6
self.ani1x = torchani.models.ANI1x()
self.calculator = torchani.ase.Calculator(
self.ani1x.species, self.ani1x.aev_computer,
self.ani1x.neural_networks[0], self.ani1x.energy_shifter)
self.calculator = self.ani1x[0].ase()
def testRMSE(self):
datafile = os.path.join(path, 'test_data/NeuroChemOptimized/all')
......
......@@ -360,7 +360,8 @@ class AEVComputer(torch.nn.Module):
def constants(self):
return self.Rcr, self.EtaR, self.ShfR, self.Rca, self.ShfZ, self.EtaA, self.Zeta, self.ShfA
def forward(self, input_: Tuple[Tensor, Tensor], cell: Optional[Tensor] = None,
def forward(self, input_: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> SpeciesAEV:
"""Compute AEVs
......
......@@ -6,12 +6,9 @@
"""
import torch
from .nn import Sequential
import ase.neighborlist
from . import utils
import ase.calculators.calculator
import ase.units
import copy
class Calculator(ase.calculators.calculator.Calculator):
......@@ -20,12 +17,8 @@ class Calculator(ase.calculators.calculator.Calculator):
Arguments:
species (:class:`collections.abc.Sequence` of :class:`str`):
sequence of all supported species, in order.
aev_computer (:class:`torchani.AEVComputer`): AEV computer.
model (:class:`torchani.ANIModel` or :class:`torchani.Ensemble`):
neural network potential models.
energy_shifter (:class:`torchani.EnergyShifter`): Energy shifter.
dtype (:class:`torchani.EnergyShifter`): data type to use,
by dafault ``torch.float64``.
model (:class:`torch.nn.Module`): neural network potential model
that convert coordinates into energies.
overwrite (bool): After wrapping atoms into central box, whether
to replace the original positions stored in :class:`ase.Atoms`
object with the wrapped positions.
......@@ -33,24 +26,15 @@ class Calculator(ase.calculators.calculator.Calculator):
implemented_properties = ['energy', 'forces', 'stress', 'free_energy']
def __init__(self, species, aev_computer, model, energy_shifter, dtype=torch.float64, overwrite=False):
def __init__(self, species, model, overwrite=False):
super(Calculator, self).__init__()
self.species_to_tensor = utils.ChemicalSymbolsToInts(species)
# aev_computer.neighborlist will be changed later, so we need a copy to
# make sure we do not change the original object
aev_computer = copy.deepcopy(aev_computer)
self.aev_computer = aev_computer.to(dtype)
self.model = copy.deepcopy(model)
self.energy_shifter = copy.deepcopy(energy_shifter)
self.model = model
self.overwrite = overwrite
self.device = self.aev_computer.EtaR.device
self.dtype = dtype
self.nn = Sequential(
self.model,
self.energy_shifter
).to(dtype)
a_parameter = next(self.model.parameters())
self.device = a_parameter.device
self.dtype = a_parameter.dtype
def calculate(self, atoms=None, properties=['energy'],
system_changes=ase.calculators.calculator.all_changes):
......@@ -79,11 +63,10 @@ class Calculator(ase.calculators.calculator.Calculator):
if pbc_enabled:
if 'stress' in properties:
cell = cell @ scaling
aev = self.aev_computer((species, coordinates), cell=cell, pbc=pbc).aevs
energy = self.model((species, coordinates), cell=cell, pbc=pbc).energies
else:
aev = self.aev_computer((species, coordinates)).aevs
energy = self.model((species, coordinates)).energies
energy = self.nn((species, aev)).energies
energy *= ase.units.Hartree
self.results['energy'] = energy.item()
self.results['free_energy'] = energy.item()
......
......@@ -29,7 +29,7 @@ shouldn't be used anymore.
import torch
from torch import Tensor
from typing import Tuple
from typing import Tuple, Optional
from pkg_resources import resource_filename
from . import neurochem
from .nn import Sequential
......@@ -89,16 +89,20 @@ class BuiltinNet(torch.nn.Module):
self.neural_networks = neurochem.load_model_ensemble(
self.species, self.ensemble_prefix, self.ensemble_size)
def forward(self, species_coordinates: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
def forward(self, species_coordinates: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
"""Calculates predicted properties for minibatch of configurations
Args:
species_coordinates: minibatch of configurations
cell: the cell used in PBC computation, set to None if PBC is not enabled
pbc: the bool tensor indicating which direction PBC is enabled, set to None if PBC is not enabled
Returns:
species_energies: energies for the given configurations
"""
species_aevs = self.aev_computer(species_coordinates)
species_aevs = self.aev_computer(species_coordinates, cell=cell, pbc=pbc)
species_energies = self.neural_networks(species_aevs)
return self.energy_shifter(species_energies)
......@@ -126,11 +130,7 @@ class BuiltinNet(torch.nn.Module):
def ase(**kwargs):
"""Attach an ase calculator """
from . import ase
return ase.Calculator(self.species,
self.aev_computer,
self.neural_networks[index],
self.energy_shifter,
**kwargs)
return ase.Calculator(self.species, ret, **kwargs)
ret.ase = ase
ret.species_to_tensor = self.consts.species_to_tensor
......@@ -154,9 +154,7 @@ class BuiltinNet(torch.nn.Module):
calculator (:class:`int`): A calculator to be used with ASE
"""
from . import ase
return ase.Calculator(self.species, self.aev_computer,
self.neural_networks, self.energy_shifter,
**kwargs)
return ase.Calculator(self.species, self, **kwargs)
def species_to_tensor(self, *args, **kwargs):
"""Convert species from strings to tensor.
......
import torch
from torch import Tensor
from typing import Tuple, NamedTuple
from typing import Tuple, NamedTuple, Optional
class SpeciesEnergies(NamedTuple):
......@@ -31,7 +31,11 @@ class ANIModel(torch.nn.Module):
def __getitem__(self, i):
return self.module_list[i]
def forward(self, species_aev: Tuple[Tensor, Tensor]) -> SpeciesEnergies:
def forward(self, species_aev: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> SpeciesEnergies:
assert cell is None
assert pbc is None
species, aev = species_aev
species_ = species.flatten()
aev = aev.flatten(0, 1)
......@@ -56,7 +60,11 @@ class Ensemble(torch.nn.Module):
self.modules_list = torch.nn.ModuleList(modules)
self.size = len(self.modules_list)
def forward(self, species_input: Tuple[Tensor, Tensor]) -> SpeciesEnergies:
def forward(self, species_input: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> SpeciesEnergies:
assert cell is None
assert pbc is None
sum_ = 0
for x in self.modules_list:
sum_ += x(species_input)[1]
......@@ -74,9 +82,13 @@ class Sequential(torch.nn.Module):
super(Sequential, self).__init__()
self.modules_list = torch.nn.ModuleList(modules)
def forward(self, input_: Tuple[Tensor, Tensor]):
def forward(self, input_: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None):
for module in self.modules_list:
input_ = module(input_)
input_ = module(input_, cell=cell, pbc=pbc)
cell = None
pbc = None
return input_
......
......@@ -4,7 +4,7 @@ import torch.utils.data
import math
import numpy as np
from collections import defaultdict
from typing import Tuple, NamedTuple
from typing import Tuple, NamedTuple, Optional
from .nn import SpeciesEnergies
......@@ -212,9 +212,13 @@ class EnergyShifter(torch.nn.Module):
properties['energies'] = energies
return atomic_properties, properties
def forward(self, species_energies: Tuple[Tensor, Tensor]) -> SpeciesEnergies:
def forward(self, species_energies: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> SpeciesEnergies:
"""(species, molecular energies)->(species, molecular energies + sae)
"""
assert cell is None
assert pbc is None
species, energies = species_energies
sae = self.sae(species).to(energies.device)
return SpeciesEnergies(species, energies.to(sae.dtype) + sae)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment