Unverified Commit ef834586 authored by Jinze (Richard) Xue's avatar Jinze (Richard) Xue Committed by GitHub
Browse files

fix jit (#596)

parent a85e3305
...@@ -3,7 +3,7 @@ from torch import Tensor ...@@ -3,7 +3,7 @@ from torch import Tensor
import torch.utils.data import torch.utils.data
import math import math
from collections import defaultdict from collections import defaultdict
from typing import Tuple, NamedTuple, Optional from typing import Tuple, NamedTuple, Optional, Sequence, List, Dict
from torchani.units import sqrt_mhessian2invcm, sqrt_mhessian2milliev, mhessian2fconst from torchani.units import sqrt_mhessian2invcm, sqrt_mhessian2milliev, mhessian2fconst
from .nn import SpeciesEnergies from .nn import SpeciesEnergies
...@@ -190,51 +190,45 @@ class EnergyShifter(torch.nn.Module): ...@@ -190,51 +190,45 @@ class EnergyShifter(torch.nn.Module):
return SpeciesEnergies(species, energies + sae) return SpeciesEnergies(species, energies + sae)
class ChemicalSymbolsToInts: class ChemicalSymbolsToInts(torch.nn.Module):
r"""Helper that can be called to convert chemical symbol string to integers r"""Helper that can be called to convert chemical symbol string to integers
On initialization the class should be supplied with a :class:`list` (or in On initialization the class should be supplied with a :class:`list` (or in
general :class:`collections.abc.Sequence`) of :class:`str`. The returned general :class:`collections.abc.Sequence`) of :class:`str`. The returned
instance is a callable object, which can be called with an arbitrary list instance is a callable object, which can be called with an arbitrary list
of the supported species that is converted into a tensor of dtype of the supported species that is converted into a tensor of dtype
:class:`torch.long`. Usage example: :class:`torch.long`. Usage example:
.. code-block:: python .. code-block:: python
from torchani.utils import ChemicalSymbolsToInts from torchani.utils import ChemicalSymbolsToInts
# We initialize ChemicalSymbolsToInts with the supported species # We initialize ChemicalSymbolsToInts with the supported species
species_to_tensor = ChemicalSymbolsToInts(['H', 'C', 'Fe', 'Cl']) species_to_tensor = ChemicalSymbolsToInts(['H', 'C', 'Fe', 'Cl'])
# We have a species list which we want to convert to an index tensor # We have a species list which we want to convert to an index tensor
index_tensor = species_to_tensor(['H', 'C', 'H', 'H', 'C', 'Cl', 'Fe']) index_tensor = species_to_tensor(['H', 'C', 'H', 'H', 'C', 'Cl', 'Fe'])
# index_tensor is now [0 1 0 0 1 3 2] # index_tensor is now [0 1 0 0 1 3 2]
.. warning:: .. warning::
If the input is a string python will iterate over If the input is a string python will iterate over
characters, this means that a string such as 'CHClFe' will be characters, this means that a string such as 'CHClFe' will be
intepreted as 'C' 'H' 'C' 'l' 'F' 'e'. It is recommended that you intepreted as 'C' 'H' 'C' 'l' 'F' 'e'. It is recommended that you
input either a :class:`list` or a :class:`numpy.ndarray` ['C', 'H', 'Cl', 'Fe'], input either a :class:`list` or a :class:`numpy.ndarray` ['C', 'H', 'Cl', 'Fe'],
and not a string. The output of a call does NOT correspond to a and not a string. The output of a call does NOT correspond to a
tensor of atomic numbers. tensor of atomic numbers.
Arguments: Arguments:
all_species (:class:`collections.abc.Sequence` of :class:`str`): all_species (:class:`collections.abc.Sequence` of :class:`str`):
sequence of all supported species, in order (it is recommended to order sequence of all supported species, in order (it is recommended to order
according to atomic number). according to atomic number).
""" """
_dummy: Tensor
rev_species: Dict[str, int]
def __init__(self, all_species): def __init__(self, all_species: Sequence[str]):
super().__init__()
self.rev_species = {s: i for i, s in enumerate(all_species)} self.rev_species = {s: i for i, s in enumerate(all_species)}
# dummy tensor to hold output device
self.register_buffer('_dummy', torch.empty(0), persistent=False)
def __call__(self, species): def forward(self, species: List[str]) -> Tensor:
r"""Convert species from sequence of strings to 1D tensor""" r"""Convert species from sequence of strings to 1D tensor"""
rev = [self.rev_species[s] for s in species] rev = [self.rev_species[s] for s in species]
return torch.tensor(rev, dtype=torch.long) return torch.tensor(rev, dtype=torch.long, device=self._dummy.device)
def __len__(self): def __len__(self):
return len(self.rev_species) return len(self.rev_species)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment