Unverified Commit f9db30c3 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

ASE calculator (#121)

parent 84fc8d80
...@@ -27,6 +27,8 @@ Utilities ...@@ -27,6 +27,8 @@ Utilities
.. autofunction:: torchani.utils.pad_coordinates .. autofunction:: torchani.utils.pad_coordinates
.. autofunction:: torchani.utils.present_species .. autofunction:: torchani.utils.present_species
.. autofunction:: torchani.utils.strip_redundant_padding .. autofunction:: torchani.utils.strip_redundant_padding
.. autoclass:: torchani.utils.ChemicalSymbolsToInts
:members:
NeuroChem NeuroChem
...@@ -51,6 +53,8 @@ ASE Interface ...@@ -51,6 +53,8 @@ ASE Interface
.. automodule:: torchani.ase .. automodule:: torchani.ase
.. autoclass:: torchani.ase.NeighborList .. autoclass:: torchani.ase.NeighborList
:members: :members:
.. autoclass:: torchani.ase.Calculator
:members:
Ignite Helpers Ignite Helpers
============== ==============
......
...@@ -65,7 +65,11 @@ class AEVComputer(torch.nn.Module): ...@@ -65,7 +65,11 @@ class AEVComputer(torch.nn.Module):
ShfZ (:class:`torch.Tensor`): The 1D tensor of :math:`\theta_s` in ShfZ (:class:`torch.Tensor`): The 1D tensor of :math:`\theta_s` in
equation (4) in the `ANI paper`_. equation (4) in the `ANI paper`_.
num_species (int): Number of supported atom types. num_species (int): Number of supported atom types.
neighborlist_computer (:class:`collections.abc.Callable`): The callable neighborlist_computer (:class:`collections.abc.Callable`): initial
value of :attr:`neighborlist`
Attributes:
neighborlist (:class:`collections.abc.Callable`): The callable
(species:Tensor, coordinates:Tensor, cutoff:float) (species:Tensor, coordinates:Tensor, cutoff:float)
-> Tuple[Tensor, Tensor, Tensor] that returns the species, -> Tuple[Tensor, Tensor, Tensor] that returns the species,
distances and relative coordinates of neighbor atoms. The input distances and relative coordinates of neighbor atoms. The input
......
...@@ -9,6 +9,8 @@ import math ...@@ -9,6 +9,8 @@ import math
import torch import torch
import ase.neighborlist import ase.neighborlist
from . import utils from . import utils
import ase.calculators.calculator
import ase.units
class NeighborList: class NeighborList:
...@@ -80,3 +82,48 @@ class NeighborList: ...@@ -80,3 +82,48 @@ class NeighborList:
return neighbor_species.permute(0, 2, 1), \ return neighbor_species.permute(0, 2, 1), \
neighbor_distances.permute(0, 2, 1), \ neighbor_distances.permute(0, 2, 1), \
neighbor_vecs.permute(0, 2, 1, 3) neighbor_vecs.permute(0, 2, 1, 3)
class Calculator(ase.calculators.calculator.Calculator):
"""TorchANI calculator for ASE
Arguments:
species (:class:`collections.abc.Sequence` of :class:`str`):
sequence of all supported species, in order.
aev_computer (:class:`torchani.AEVComputer`): AEV computer.
model (:class:`torchani.ANIModel` or :class:`torchani.Ensemble`):
neural network potential models.
energy_shifter (:class:`torchani.EnergyShifter`): Energy shifter.
"""
def __init__(self, species, aev_computer, model, energy_shifter):
self.species_to_tensor = utils.ChemicalSymbolsToInts(species)
self.aev_computer = aev_computer
self.model = model
self.energy_shifter = energy_shifter
self.device = self.aev_computer.EtaR.device
self.dtype = self.aev_computer.EtaR.dtype
self.whole = torch.nn.Sequential(
self.aev_computer,
self.model,
self.energy_shifter
)
def calculate(self, atoms=None, properties=['energy'],
system_changes=ase.calculators.calculator.all_changes):
super(Calculator, self).calculate(atoms, properties, system_changes)
self.aev_computer.neighbor_list = NeighborList(
cell=self.atoms.get_cell(), pbc=self.atoms.get_pbc())
species = self.species_to_tensor(self.atoms.get_chemical_symbols())
coordinates = self.atoms.get_positions(wrap=True).unsqueeze(0)
coordinates = torch.tensor(coordinates,
device=self.device,
dtype=self.dtype,
requires_grad=('forces' in properties))
_, energy = self.whole((species, coordinates)) * ase.units.Hartree
self.results['energy'] = energy.item()
if 'forces' in properties:
forces = -torch.autograd.grad(energy.squeeze(), coordinates)[0]
self.results['forces'] = forces.item()
...@@ -13,7 +13,7 @@ import math ...@@ -13,7 +13,7 @@ import math
import timeit import timeit
from collections.abc import Mapping from collections.abc import Mapping
from ..nn import ANIModel, Ensemble, Gaussian from ..nn import ANIModel, Ensemble, Gaussian
from ..utils import EnergyShifter from ..utils import EnergyShifter, ChemicalSymbolsToInts
from ..aev import AEVComputer from ..aev import AEVComputer
from ..ignite import Container, MSELoss, TransformedLoss, RMSEMetric, MAEMetric from ..ignite import Container, MSELoss, TransformedLoss, RMSEMetric, MAEMetric
...@@ -21,6 +21,10 @@ from ..ignite import Container, MSELoss, TransformedLoss, RMSEMetric, MAEMetric ...@@ -21,6 +21,10 @@ from ..ignite import Container, MSELoss, TransformedLoss, RMSEMetric, MAEMetric
class Constants(Mapping): class Constants(Mapping):
"""NeuroChem constants. Objects of this class can be used as arguments """NeuroChem constants. Objects of this class can be used as arguments
to :class:`torchani.AEVComputer`, like ``torchani.AEVComputer(**consts)``. to :class:`torchani.AEVComputer`, like ``torchani.AEVComputer(**consts)``.
Attributes:
species_to_tensor (:class:`ChemicalSymbolsToInts`): call to convert
string chemical symbols to 1d long tensor.
""" """
def __init__(self, filename): def __init__(self, filename):
...@@ -45,10 +49,7 @@ class Constants(Mapping): ...@@ -45,10 +49,7 @@ class Constants(Mapping):
except Exception: except Exception:
raise ValueError('unable to parse const file') raise ValueError('unable to parse const file')
self.num_species = len(self.species) self.num_species = len(self.species)
self.rev_species = {} self.species_to_tensor = ChemicalSymbolsToInts(self.species)
for i in range(len(self.species)):
s = self.species[i]
self.rev_species[s] = i
def __iter__(self): def __iter__(self):
yield 'Rcr' yield 'Rcr'
...@@ -67,11 +68,6 @@ class Constants(Mapping): ...@@ -67,11 +68,6 @@ class Constants(Mapping):
def __getitem__(self, item): def __getitem__(self, item):
return getattr(self, item) return getattr(self, item)
def species_to_tensor(self, species):
"""Convert species from squence of strings to 1D tensor"""
rev = [self.rev_species[s] for s in species]
return torch.tensor(rev, dtype=torch.long)
def load_sae(filename): def load_sae(filename):
"""Returns an object of :class:`EnergyShifter` with self energies from """Returns an object of :class:`EnergyShifter` with self energies from
......
...@@ -151,5 +151,25 @@ class EnergyShifter(torch.nn.Module): ...@@ -151,5 +151,25 @@ class EnergyShifter(torch.nn.Module):
return species, energies + sae return species, energies + sae
class ChemicalSymbolsToInts:
"""Helper that can be called to convert chemical symbol string to integers
Arguments:
all_species (:class:`collections.abc.Sequence` of :class:`str`):
sequence of all supported species, in order.
"""
def __init__(self, all_species):
self.rev_species = {}
for i in range(len(all_species)):
s = all_species[i]
self.rev_species[s] = i
def __call__(self, species):
"""Convert species from squence of strings to 1D tensor"""
rev = [self.rev_species[s] for s in species]
return torch.tensor(rev, dtype=torch.long)
__all__ = ['pad', 'pad_coordinates', 'present_species', __all__ = ['pad', 'pad_coordinates', 'present_species',
'strip_redundant_padding'] 'strip_redundant_padding', 'ChemicalSymbolsToInts']
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment