Unverified Commit d6511699 authored by Ignacio Pickering's avatar Ignacio Pickering Committed by GitHub
Browse files

ChemicalSymbolsToInts is being used wrong (#425)

* Add warning to ChemicalSymbolsToInts

* Change ocurrences of ChemicalSymbolsToInts in the code to reflect safe usage

* flake8

* Add warning to call also

* Add clarification

* docs dont fetch magic functions so documentation is moved to the class docstring

* fix clarification

* flake8
parent 46f05aeb
......@@ -61,7 +61,7 @@ ShfA = torch.tensor([9.0000000e-01, 1.5500000e+00, 2.2000000e+00, 2.8500000e+00]
num_species = 4
aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species)
energy_shifter = torchani.utils.EnergyShifter(None)
species_to_tensor = torchani.utils.ChemicalSymbolsToInts('HCNO')
species_to_tensor = torchani.utils.ChemicalSymbolsToInts(['H', 'C', 'N', 'O'])
###############################################################################
# Now let's setup datasets. These paths assumes the user run this script under
......
......@@ -38,7 +38,7 @@ ShfA = torch.tensor([9.0000000e-01, 1.5500000e+00, 2.2000000e+00, 2.8500000e+00]
num_species = 4
aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species)
energy_shifter = torchani.utils.EnergyShifter(None)
species_to_tensor = torchani.utils.ChemicalSymbolsToInts('HCNO')
species_to_tensor = torchani.utils.ChemicalSymbolsToInts(['H', 'C', 'N', 'O'])
try:
......
......@@ -6,7 +6,7 @@ import torchani
class TestUtils(unittest.TestCase):
def testChemicalSymbolsToInts(self):
str2i = torchani.utils.ChemicalSymbolsToInts('ABCDEF')
str2i = torchani.utils.ChemicalSymbolsToInts(['A', 'B', 'C', 'D', 'E', 'F'])
self.assertEqual(len(str2i), 6)
self.assertListEqual(str2i('BACCC').tolist(), [1, 0, 2, 2, 2])
......
......@@ -104,7 +104,7 @@ if __name__ == "__main__":
print('using original dataset API')
print('=> loading dataset...')
energy_shifter = torchani.utils.EnergyShifter(None)
species_to_tensor = torchani.utils.ChemicalSymbolsToInts('HCNO')
species_to_tensor = torchani.utils.ChemicalSymbolsToInts(['H', 'C', 'N', 'O'])
dataset = torchani.data.load_ani_dataset(parser.dataset_path, species_to_tensor,
parser.batch_size, device=parser.device,
transform=[energy_shifter.subtract_from_dataset])
......
......@@ -112,7 +112,7 @@ if __name__ == "__main__":
print('using original dataset API')
print('=> loading dataset...')
energy_shifter = torchani.utils.EnergyShifter(None)
species_to_tensor = torchani.utils.ChemicalSymbolsToInts('HCNO')
species_to_tensor = torchani.utils.ChemicalSymbolsToInts(['H', 'C', 'N', 'O'])
dataset = torchani.data.load_ani_dataset(parser.dataset_path, species_to_tensor,
parser.batch_size, device=parser.device,
transform=[energy_shifter.subtract_from_dataset])
......
......@@ -197,7 +197,36 @@ class EnergyShifter(torch.nn.Module):
class ChemicalSymbolsToInts:
"""Helper that can be called to convert chemical symbol string to integers
r"""Helper that can be called to convert chemical symbol string to integers
On initialization the class should be supplied with a :class:`list` (or in
general :class:`collections.abc.Sequence`) of :class:`str`. The returned
instance is a callable object, which can be called with an arbitrary list
of the supported species that is converted into a tensor of dtype
:class:`torch.long`. Usage example:
.. code-block:: python
from torchani.utils import ChemicalSymbolsToInts
# We initialize ChemicalSymbolsToInts with the supported species
species_to_tensor = ChemicalSymbolsToInts(['H', 'C', 'Fe', 'Cl'])
# We have a species list which we want to convert to an index tensor
index_tensor = species_to_tensor(['H', 'C', 'H', 'H', 'C', 'Cl', 'Fe'])
# index_tensor is now [0 1 0 0 1 3 2]
.. warning::
If the input is a string python will iterate over
characters, this means that a string such as 'CHClFe' will be
intepreted as 'C' 'H' 'C' 'l' 'F' 'e'. It is recommended that you
input either a :class:`list` or a :class:`numpy.ndarray` ['C', 'H', 'Cl', 'Fe'],
and not a string. The output of a call does NOT correspond to a
tensor of atomic numbers.
Arguments:
all_species (:class:`collections.abc.Sequence` of :class:`str`):
......@@ -208,7 +237,7 @@ class ChemicalSymbolsToInts:
self.rev_species = {s: i for i, s in enumerate(all_species)}
def __call__(self, species):
"""Convert species from squence of strings to 1D tensor"""
r"""Convert species from sequence of strings to 1D tensor"""
rev = [self.rev_species[s] for s in species]
return torch.tensor(rev, dtype=torch.long)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment