".github/vscode:/vscode.git/clone" did not exist on "4011273a86e8eaab435d3e0965df79512b3813ca"
Commit 1a2b4504 authored by Gao, Xiang's avatar Gao, Xiang Committed by Farhad Ramezanghorbani
Browse files

Allow passing pbc and cell to sequential, and simplify ASE interface (#386)

* Allow passing pbc and cell to sequential

* revert

* Simplify ase interface

* fix

* fix

* fix

* flake8

* Update ase.py

* fix
parent d400d8f0
...@@ -55,10 +55,8 @@ print(nnp2) ...@@ -55,10 +55,8 @@ print(nnp2)
############################################################################### ###############################################################################
# You can also create an ASE calculator using the ensemble or single model: # You can also create an ASE calculator using the ensemble or single model:
calculator1 = torchani.ase.Calculator(consts.species, aev_computer, calculator1 = torchani.ase.Calculator(consts.species, nnp1)
ensemble, energy_shifter) calculator2 = torchani.ase.Calculator(consts.species, nnp2)
calculator2 = torchani.ase.Calculator(consts.species, aev_computer,
model, energy_shifter)
print(calculator1) print(calculator1)
print(calculator1) print(calculator1)
......
...@@ -24,9 +24,12 @@ def get_numeric_force(atoms, eps): ...@@ -24,9 +24,12 @@ def get_numeric_force(atoms, eps):
class TestASE(unittest.TestCase): class TestASE(unittest.TestCase):
def setUp(self):
self.model = torchani.models.ANI1x().double()[0]
def testWithNumericalForceWithPBCEnabled(self): def testWithNumericalForceWithPBCEnabled(self):
atoms = Diamond(symbol="C", pbc=True) atoms = Diamond(symbol="C", pbc=True)
calculator = torchani.models.ANI1x().ase() calculator = self.model.ase()
atoms.set_calculator(calculator) atoms.set_calculator(calculator)
dyn = Langevin(atoms, 5 * units.fs, 30000000 * units.kB, 0.002) dyn = Langevin(atoms, 5 * units.fs, 30000000 * units.kB, 0.002)
dyn.run(100) dyn.run(100)
...@@ -40,7 +43,7 @@ class TestASE(unittest.TestCase): ...@@ -40,7 +43,7 @@ class TestASE(unittest.TestCase):
def testWithNumericalStressWithPBCEnabled(self): def testWithNumericalStressWithPBCEnabled(self):
filename = os.path.join(path, '../tools/generate-unit-test-expect/others/Benzene.cif') filename = os.path.join(path, '../tools/generate-unit-test-expect/others/Benzene.cif')
benzene = read(filename) benzene = read(filename)
calculator = torchani.models.ANI1x().ase() calculator = self.model.ase()
benzene.set_calculator(calculator) benzene.set_calculator(calculator)
dyn = NPTBerendsen(benzene, timestep=0.1 * units.fs, dyn = NPTBerendsen(benzene, timestep=0.1 * units.fs,
temperature=300 * units.kB, temperature=300 * units.kB,
......
...@@ -16,9 +16,7 @@ class TestStructureOptimization(unittest.TestCase): ...@@ -16,9 +16,7 @@ class TestStructureOptimization(unittest.TestCase):
def setUp(self): def setUp(self):
self.tolerance = 1e-6 self.tolerance = 1e-6
self.ani1x = torchani.models.ANI1x() self.ani1x = torchani.models.ANI1x()
self.calculator = torchani.ase.Calculator( self.calculator = self.ani1x[0].ase()
self.ani1x.species, self.ani1x.aev_computer,
self.ani1x.neural_networks[0], self.ani1x.energy_shifter)
def testRMSE(self): def testRMSE(self):
datafile = os.path.join(path, 'test_data/NeuroChemOptimized/all') datafile = os.path.join(path, 'test_data/NeuroChemOptimized/all')
......
...@@ -360,7 +360,8 @@ class AEVComputer(torch.nn.Module): ...@@ -360,7 +360,8 @@ class AEVComputer(torch.nn.Module):
def constants(self): def constants(self):
return self.Rcr, self.EtaR, self.ShfR, self.Rca, self.ShfZ, self.EtaA, self.Zeta, self.ShfA return self.Rcr, self.EtaR, self.ShfR, self.Rca, self.ShfZ, self.EtaA, self.Zeta, self.ShfA
def forward(self, input_: Tuple[Tensor, Tensor], cell: Optional[Tensor] = None, def forward(self, input_: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> SpeciesAEV: pbc: Optional[Tensor] = None) -> SpeciesAEV:
"""Compute AEVs """Compute AEVs
......
...@@ -6,12 +6,9 @@ ...@@ -6,12 +6,9 @@
""" """
import torch import torch
from .nn import Sequential
import ase.neighborlist
from . import utils from . import utils
import ase.calculators.calculator import ase.calculators.calculator
import ase.units import ase.units
import copy
class Calculator(ase.calculators.calculator.Calculator): class Calculator(ase.calculators.calculator.Calculator):
...@@ -20,12 +17,8 @@ class Calculator(ase.calculators.calculator.Calculator): ...@@ -20,12 +17,8 @@ class Calculator(ase.calculators.calculator.Calculator):
Arguments: Arguments:
species (:class:`collections.abc.Sequence` of :class:`str`): species (:class:`collections.abc.Sequence` of :class:`str`):
sequence of all supported species, in order. sequence of all supported species, in order.
aev_computer (:class:`torchani.AEVComputer`): AEV computer. model (:class:`torch.nn.Module`): neural network potential model
model (:class:`torchani.ANIModel` or :class:`torchani.Ensemble`): that convert coordinates into energies.
neural network potential models.
energy_shifter (:class:`torchani.EnergyShifter`): Energy shifter.
dtype (:class:`torchani.EnergyShifter`): data type to use,
by dafault ``torch.float64``.
overwrite (bool): After wrapping atoms into central box, whether overwrite (bool): After wrapping atoms into central box, whether
to replace the original positions stored in :class:`ase.Atoms` to replace the original positions stored in :class:`ase.Atoms`
object with the wrapped positions. object with the wrapped positions.
...@@ -33,24 +26,15 @@ class Calculator(ase.calculators.calculator.Calculator): ...@@ -33,24 +26,15 @@ class Calculator(ase.calculators.calculator.Calculator):
implemented_properties = ['energy', 'forces', 'stress', 'free_energy'] implemented_properties = ['energy', 'forces', 'stress', 'free_energy']
def __init__(self, species, aev_computer, model, energy_shifter, dtype=torch.float64, overwrite=False): def __init__(self, species, model, overwrite=False):
super(Calculator, self).__init__() super(Calculator, self).__init__()
self.species_to_tensor = utils.ChemicalSymbolsToInts(species) self.species_to_tensor = utils.ChemicalSymbolsToInts(species)
# aev_computer.neighborlist will be changed later, so we need a copy to self.model = model
# make sure we do not change the original object
aev_computer = copy.deepcopy(aev_computer)
self.aev_computer = aev_computer.to(dtype)
self.model = copy.deepcopy(model)
self.energy_shifter = copy.deepcopy(energy_shifter)
self.overwrite = overwrite self.overwrite = overwrite
self.device = self.aev_computer.EtaR.device a_parameter = next(self.model.parameters())
self.dtype = dtype self.device = a_parameter.device
self.dtype = a_parameter.dtype
self.nn = Sequential(
self.model,
self.energy_shifter
).to(dtype)
def calculate(self, atoms=None, properties=['energy'], def calculate(self, atoms=None, properties=['energy'],
system_changes=ase.calculators.calculator.all_changes): system_changes=ase.calculators.calculator.all_changes):
...@@ -79,11 +63,10 @@ class Calculator(ase.calculators.calculator.Calculator): ...@@ -79,11 +63,10 @@ class Calculator(ase.calculators.calculator.Calculator):
if pbc_enabled: if pbc_enabled:
if 'stress' in properties: if 'stress' in properties:
cell = cell @ scaling cell = cell @ scaling
aev = self.aev_computer((species, coordinates), cell=cell, pbc=pbc).aevs energy = self.model((species, coordinates), cell=cell, pbc=pbc).energies
else: else:
aev = self.aev_computer((species, coordinates)).aevs energy = self.model((species, coordinates)).energies
energy = self.nn((species, aev)).energies
energy *= ase.units.Hartree energy *= ase.units.Hartree
self.results['energy'] = energy.item() self.results['energy'] = energy.item()
self.results['free_energy'] = energy.item() self.results['free_energy'] = energy.item()
......
...@@ -29,7 +29,7 @@ shouldn't be used anymore. ...@@ -29,7 +29,7 @@ shouldn't be used anymore.
import torch import torch
from torch import Tensor from torch import Tensor
from typing import Tuple from typing import Tuple, Optional
from pkg_resources import resource_filename from pkg_resources import resource_filename
from . import neurochem from . import neurochem
from .nn import Sequential from .nn import Sequential
...@@ -89,16 +89,20 @@ class BuiltinNet(torch.nn.Module): ...@@ -89,16 +89,20 @@ class BuiltinNet(torch.nn.Module):
self.neural_networks = neurochem.load_model_ensemble( self.neural_networks = neurochem.load_model_ensemble(
self.species, self.ensemble_prefix, self.ensemble_size) self.species, self.ensemble_prefix, self.ensemble_size)
def forward(self, species_coordinates: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]: def forward(self, species_coordinates: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
"""Calculates predicted properties for minibatch of configurations """Calculates predicted properties for minibatch of configurations
Args: Args:
species_coordinates: minibatch of configurations species_coordinates: minibatch of configurations
cell: the cell used in PBC computation, set to None if PBC is not enabled
pbc: the bool tensor indicating which direction PBC is enabled, set to None if PBC is not enabled
Returns: Returns:
species_energies: energies for the given configurations species_energies: energies for the given configurations
""" """
species_aevs = self.aev_computer(species_coordinates) species_aevs = self.aev_computer(species_coordinates, cell=cell, pbc=pbc)
species_energies = self.neural_networks(species_aevs) species_energies = self.neural_networks(species_aevs)
return self.energy_shifter(species_energies) return self.energy_shifter(species_energies)
...@@ -126,11 +130,7 @@ class BuiltinNet(torch.nn.Module): ...@@ -126,11 +130,7 @@ class BuiltinNet(torch.nn.Module):
def ase(**kwargs): def ase(**kwargs):
"""Attach an ase calculator """ """Attach an ase calculator """
from . import ase from . import ase
return ase.Calculator(self.species, return ase.Calculator(self.species, ret, **kwargs)
self.aev_computer,
self.neural_networks[index],
self.energy_shifter,
**kwargs)
ret.ase = ase ret.ase = ase
ret.species_to_tensor = self.consts.species_to_tensor ret.species_to_tensor = self.consts.species_to_tensor
...@@ -154,9 +154,7 @@ class BuiltinNet(torch.nn.Module): ...@@ -154,9 +154,7 @@ class BuiltinNet(torch.nn.Module):
calculator (:class:`int`): A calculator to be used with ASE calculator (:class:`int`): A calculator to be used with ASE
""" """
from . import ase from . import ase
return ase.Calculator(self.species, self.aev_computer, return ase.Calculator(self.species, self, **kwargs)
self.neural_networks, self.energy_shifter,
**kwargs)
def species_to_tensor(self, *args, **kwargs): def species_to_tensor(self, *args, **kwargs):
"""Convert species from strings to tensor. """Convert species from strings to tensor.
......
import torch import torch
from torch import Tensor from torch import Tensor
from typing import Tuple, NamedTuple from typing import Tuple, NamedTuple, Optional
class SpeciesEnergies(NamedTuple): class SpeciesEnergies(NamedTuple):
...@@ -31,7 +31,11 @@ class ANIModel(torch.nn.Module): ...@@ -31,7 +31,11 @@ class ANIModel(torch.nn.Module):
def __getitem__(self, i): def __getitem__(self, i):
return self.module_list[i] return self.module_list[i]
def forward(self, species_aev: Tuple[Tensor, Tensor]) -> SpeciesEnergies: def forward(self, species_aev: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> SpeciesEnergies:
assert cell is None
assert pbc is None
species, aev = species_aev species, aev = species_aev
species_ = species.flatten() species_ = species.flatten()
aev = aev.flatten(0, 1) aev = aev.flatten(0, 1)
...@@ -56,7 +60,11 @@ class Ensemble(torch.nn.Module): ...@@ -56,7 +60,11 @@ class Ensemble(torch.nn.Module):
self.modules_list = torch.nn.ModuleList(modules) self.modules_list = torch.nn.ModuleList(modules)
self.size = len(self.modules_list) self.size = len(self.modules_list)
def forward(self, species_input: Tuple[Tensor, Tensor]) -> SpeciesEnergies: def forward(self, species_input: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> SpeciesEnergies:
assert cell is None
assert pbc is None
sum_ = 0 sum_ = 0
for x in self.modules_list: for x in self.modules_list:
sum_ += x(species_input)[1] sum_ += x(species_input)[1]
...@@ -74,9 +82,13 @@ class Sequential(torch.nn.Module): ...@@ -74,9 +82,13 @@ class Sequential(torch.nn.Module):
super(Sequential, self).__init__() super(Sequential, self).__init__()
self.modules_list = torch.nn.ModuleList(modules) self.modules_list = torch.nn.ModuleList(modules)
def forward(self, input_: Tuple[Tensor, Tensor]): def forward(self, input_: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None):
for module in self.modules_list: for module in self.modules_list:
input_ = module(input_) input_ = module(input_, cell=cell, pbc=pbc)
cell = None
pbc = None
return input_ return input_
......
...@@ -4,7 +4,7 @@ import torch.utils.data ...@@ -4,7 +4,7 @@ import torch.utils.data
import math import math
import numpy as np import numpy as np
from collections import defaultdict from collections import defaultdict
from typing import Tuple, NamedTuple from typing import Tuple, NamedTuple, Optional
from .nn import SpeciesEnergies from .nn import SpeciesEnergies
...@@ -212,9 +212,13 @@ class EnergyShifter(torch.nn.Module): ...@@ -212,9 +212,13 @@ class EnergyShifter(torch.nn.Module):
properties['energies'] = energies properties['energies'] = energies
return atomic_properties, properties return atomic_properties, properties
def forward(self, species_energies: Tuple[Tensor, Tensor]) -> SpeciesEnergies: def forward(self, species_energies: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> SpeciesEnergies:
"""(species, molecular energies)->(species, molecular energies + sae) """(species, molecular energies)->(species, molecular energies + sae)
""" """
assert cell is None
assert pbc is None
species, energies = species_energies species, energies = species_energies
sae = self.sae(species).to(energies.device) sae = self.sae(species).to(energies.device)
return SpeciesEnergies(species, energies.to(sae.dtype) + sae) return SpeciesEnergies(species, energies.to(sae.dtype) + sae)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment