Commit 0b3e26ee authored by Gao, Xiang's avatar Gao, Xiang Committed by Farhad Ramezanghorbani
Browse files

Support len in ChemicalSymbolsToInts, len of it returns number of supported elements (#286)

* Support len in ChemicalSymbolsToInts

* fix

* flake8
parent b2da3199
import unittest
import torchani
class TestUtils(unittest.TestCase):
def testChemicalSymbolsToInts(self):
str2i = torchani.utils.ChemicalSymbolsToInts('ABCDEF')
self.assertEqual(len(str2i), 6)
self.assertListEqual(str2i('BACCC').tolist(), [1, 0, 2, 2, 2])
if __name__ == '__main__':
unittest.main()
......@@ -226,15 +226,16 @@ class ChemicalSymbolsToInts:
"""
def __init__(self, all_species):
self.rev_species = {}
for i, s in enumerate(all_species):
self.rev_species[s] = i
self.rev_species = {s: i for i, s in enumerate(all_species)}
def __call__(self, species):
"""Convert species from squence of strings to 1D tensor"""
rev = [self.rev_species[s] for s in species]
return torch.tensor(rev, dtype=torch.long)
def __len__(self):
return len(self.rev_species)
def hessian(coordinates, energies=None, forces=None):
"""Compute analytical hessian from the energy graph or force graph.
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment