Unverified Commit ef834586 authored by Jinze (Richard) Xue's avatar Jinze (Richard) Xue Committed by GitHub
Browse files

fix jit (#596)

parent a85e3305
......@@ -3,7 +3,7 @@ from torch import Tensor
import torch.utils.data
import math
from collections import defaultdict
from typing import Tuple, NamedTuple, Optional
from typing import Tuple, NamedTuple, Optional, Sequence, List, Dict
from torchani.units import sqrt_mhessian2invcm, sqrt_mhessian2milliev, mhessian2fconst
from .nn import SpeciesEnergies
......@@ -190,51 +190,45 @@ class EnergyShifter(torch.nn.Module):
return SpeciesEnergies(species, energies + sae)
class ChemicalSymbolsToInts:
class ChemicalSymbolsToInts(torch.nn.Module):
r"""Helper that can be called to convert chemical symbol string to integers
On initialization the class should be supplied with a :class:`list` (or in
general :class:`collections.abc.Sequence`) of :class:`str`. The returned
instance is a callable object, which can be called with an arbitrary list
of the supported species that is converted into a tensor of dtype
:class:`torch.long`. Usage example:
.. code-block:: python
from torchani.utils import ChemicalSymbolsToInts
# We initialize ChemicalSymbolsToInts with the supported species
species_to_tensor = ChemicalSymbolsToInts(['H', 'C', 'Fe', 'Cl'])
# We have a species list which we want to convert to an index tensor
index_tensor = species_to_tensor(['H', 'C', 'H', 'H', 'C', 'Cl', 'Fe'])
# index_tensor is now [0 1 0 0 1 3 2]
.. warning::
If the input is a string python will iterate over
characters, this means that a string such as 'CHClFe' will be
intepreted as 'C' 'H' 'C' 'l' 'F' 'e'. It is recommended that you
input either a :class:`list` or a :class:`numpy.ndarray` ['C', 'H', 'Cl', 'Fe'],
and not a string. The output of a call does NOT correspond to a
tensor of atomic numbers.
Arguments:
all_species (:class:`collections.abc.Sequence` of :class:`str`):
sequence of all supported species, in order (it is recommended to order
according to atomic number).
"""
_dummy: Tensor
rev_species: Dict[str, int]
def __init__(self, all_species):
def __init__(self, all_species: Sequence[str]):
super().__init__()
self.rev_species = {s: i for i, s in enumerate(all_species)}
# dummy tensor to hold output device
self.register_buffer('_dummy', torch.empty(0), persistent=False)
def __call__(self, species):
def forward(self, species: List[str]) -> Tensor:
r"""Convert species from sequence of strings to 1D tensor"""
rev = [self.rev_species[s] for s in species]
return torch.tensor(rev, dtype=torch.long)
return torch.tensor(rev, dtype=torch.long, device=self._dummy.device)
def __len__(self):
return len(self.rev_species)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment