Unverified Commit f9db30c3 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

ASE calculator (#121)

parent 84fc8d80
......@@ -27,6 +27,8 @@ Utilities
.. autofunction:: torchani.utils.pad_coordinates
.. autofunction:: torchani.utils.present_species
.. autofunction:: torchani.utils.strip_redundant_padding
.. autoclass:: torchani.utils.ChemicalSymbolsToInts
:members:
NeuroChem
......@@ -51,6 +53,8 @@ ASE Interface
.. automodule:: torchani.ase
.. autoclass:: torchani.ase.NeighborList
:members:
.. autoclass:: torchani.ase.Calculator
:members:
Ignite Helpers
==============
......
......@@ -65,7 +65,11 @@ class AEVComputer(torch.nn.Module):
ShfZ (:class:`torch.Tensor`): The 1D tensor of :math:`\theta_s` in
equation (4) in the `ANI paper`_.
num_species (int): Number of supported atom types.
neighborlist_computer (:class:`collections.abc.Callable`): The callable
neighborlist_computer (:class:`collections.abc.Callable`): initial
value of :attr:`neighborlist`
Attributes:
neighborlist (:class:`collections.abc.Callable`): The callable
(species:Tensor, coordinates:Tensor, cutoff:float)
-> Tuple[Tensor, Tensor, Tensor] that returns the species,
distances and relative coordinates of neighbor atoms. The input
......
......@@ -9,6 +9,8 @@ import math
import torch
import ase.neighborlist
from . import utils
import ase.calculators.calculator
import ase.units
class NeighborList:
......@@ -80,3 +82,48 @@ class NeighborList:
return neighbor_species.permute(0, 2, 1), \
neighbor_distances.permute(0, 2, 1), \
neighbor_vecs.permute(0, 2, 1, 3)
class Calculator(ase.calculators.calculator.Calculator):
"""TorchANI calculator for ASE
Arguments:
species (:class:`collections.abc.Sequence` of :class:`str`):
sequence of all supported species, in order.
aev_computer (:class:`torchani.AEVComputer`): AEV computer.
model (:class:`torchani.ANIModel` or :class:`torchani.Ensemble`):
neural network potential models.
energy_shifter (:class:`torchani.EnergyShifter`): Energy shifter.
"""
def __init__(self, species, aev_computer, model, energy_shifter):
self.species_to_tensor = utils.ChemicalSymbolsToInts(species)
self.aev_computer = aev_computer
self.model = model
self.energy_shifter = energy_shifter
self.device = self.aev_computer.EtaR.device
self.dtype = self.aev_computer.EtaR.dtype
self.whole = torch.nn.Sequential(
self.aev_computer,
self.model,
self.energy_shifter
)
def calculate(self, atoms=None, properties=['energy'],
system_changes=ase.calculators.calculator.all_changes):
super(Calculator, self).calculate(atoms, properties, system_changes)
self.aev_computer.neighbor_list = NeighborList(
cell=self.atoms.get_cell(), pbc=self.atoms.get_pbc())
species = self.species_to_tensor(self.atoms.get_chemical_symbols())
coordinates = self.atoms.get_positions(wrap=True).unsqueeze(0)
coordinates = torch.tensor(coordinates,
device=self.device,
dtype=self.dtype,
requires_grad=('forces' in properties))
_, energy = self.whole((species, coordinates)) * ase.units.Hartree
self.results['energy'] = energy.item()
if 'forces' in properties:
forces = -torch.autograd.grad(energy.squeeze(), coordinates)[0]
self.results['forces'] = forces.item()
......@@ -13,7 +13,7 @@ import math
import timeit
from collections.abc import Mapping
from ..nn import ANIModel, Ensemble, Gaussian
from ..utils import EnergyShifter
from ..utils import EnergyShifter, ChemicalSymbolsToInts
from ..aev import AEVComputer
from ..ignite import Container, MSELoss, TransformedLoss, RMSEMetric, MAEMetric
......@@ -21,6 +21,10 @@ from ..ignite import Container, MSELoss, TransformedLoss, RMSEMetric, MAEMetric
class Constants(Mapping):
"""NeuroChem constants. Objects of this class can be used as arguments
to :class:`torchani.AEVComputer`, like ``torchani.AEVComputer(**consts)``.
Attributes:
species_to_tensor (:class:`ChemicalSymbolsToInts`): call to convert
string chemical symbols to 1d long tensor.
"""
def __init__(self, filename):
......@@ -45,10 +49,7 @@ class Constants(Mapping):
except Exception:
raise ValueError('unable to parse const file')
self.num_species = len(self.species)
self.rev_species = {}
for i in range(len(self.species)):
s = self.species[i]
self.rev_species[s] = i
self.species_to_tensor = ChemicalSymbolsToInts(self.species)
def __iter__(self):
yield 'Rcr'
......@@ -67,11 +68,6 @@ class Constants(Mapping):
def __getitem__(self, item):
return getattr(self, item)
def species_to_tensor(self, species):
"""Convert species from squence of strings to 1D tensor"""
rev = [self.rev_species[s] for s in species]
return torch.tensor(rev, dtype=torch.long)
def load_sae(filename):
"""Returns an object of :class:`EnergyShifter` with self energies from
......
......@@ -151,5 +151,25 @@ class EnergyShifter(torch.nn.Module):
return species, energies + sae
class ChemicalSymbolsToInts:
"""Helper that can be called to convert chemical symbol string to integers
Arguments:
all_species (:class:`collections.abc.Sequence` of :class:`str`):
sequence of all supported species, in order.
"""
def __init__(self, all_species):
self.rev_species = {}
for i in range(len(all_species)):
s = all_species[i]
self.rev_species[s] = i
def __call__(self, species):
"""Convert species from squence of strings to 1D tensor"""
rev = [self.rev_species[s] for s in species]
return torch.tensor(rev, dtype=torch.long)
__all__ = ['pad', 'pad_coordinates', 'present_species',
'strip_redundant_padding']
'strip_redundant_padding', 'ChemicalSymbolsToInts']
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment