Commit 66c3743c authored by Gao, Xiang's avatar Gao, Xiang Committed by Farhad Ramezanghorbani
Browse files

Add helper module to convert species from element id in periodic table to 0,...

Add helper module to convert species from element id in periodic table to 0, 1, 2, 3, ... format (#396)

* Init

* fix test

* flake8

* try fix

* Fix stupidity of len(self)
parent eb89457c
......@@ -18,7 +18,7 @@ jobs:
python-version: [3.6, 3.7]
test-filenames: [
test_aev.py, test_aev_benzene_md.py, test_aev_nist.py, test_aev_tripeptide_md.py,
test_data_new.py, test_utils.py, test_ase.py, test_energies.py,
test_data_new.py, test_utils.py, test_ase.py, test_energies.py, test_nn.py,
test_neurochem.py, test_vibrational.py, test_ensemble.py, test_padding.py,
test_data.py, test_forces.py, test_structure_optim.py, test_jit_builtin_models.py]
......
import unittest
import torch
import torchani
class TestSpeciesConverter(unittest.TestCase):
def setUp(self):
self.c = torchani.SpeciesConverter(['H', 'C', 'N', 'O'])
def testSpeciesConverter(self):
input_ = torch.tensor([
[1, 6, 7, 8, -1],
[1, 1, -1, 8, 1],
], dtype=torch.long)
expect = torch.tensor([
[0, 1, 2, 3, -1],
[0, 0, -1, 3, 0],
], dtype=torch.long)
dummy_coordinates = torch.empty(2, 5, 3)
output = self.c((input_, dummy_coordinates)).species
self.assertTrue(torch.allclose(output, expect))
class TestSpeciesConverterJIT(TestSpeciesConverter):
def setUp(self):
super().setUp()
self.c = torch.jit.script(self.c)
if __name__ == '__main__':
unittest.main()
......@@ -42,7 +42,7 @@ def by_batch(species, coordinates, model):
energies = []
forces = []
for s, c in zip(species, coordinates):
_, e = model((s, c))
e = model((s, c)).energies
f, = torch.autograd.grad(e.sum(), c)
energies.append(e)
forces.append(f)
......
......@@ -24,14 +24,13 @@ formats of NeuroChem at :attr:`torchani.neurochem`, and more at :attr:`torchani.
"""
from .utils import EnergyShifter
from .nn import ANIModel, Ensemble
from .nn import ANIModel, Ensemble, SpeciesConverter
from .aev import AEVComputer
from . import utils
from . import neurochem
from . import models
from . import optim
from pkg_resources import get_distribution, DistributionNotFound
import sys
try:
__version__ = get_distribution(__name__).version
......@@ -39,7 +38,7 @@ except DistributionNotFound:
# package is not installed
pass
__all__ = ['AEVComputer', 'EnergyShifter', 'ANIModel', 'Ensemble',
__all__ = ['AEVComputer', 'EnergyShifter', 'ANIModel', 'Ensemble', 'SpeciesConverter',
'utils', 'neurochem', 'models', 'optim']
try:
......@@ -48,10 +47,8 @@ try:
except ImportError:
pass
if sys.version_info[0] > 2:
try:
from . import data # noqa: F401
__all__.append('data')
except ImportError:
pass
try:
from . import data # noqa: F401
__all__.append('data')
except ImportError:
pass
......@@ -190,7 +190,7 @@ class ANI1x(BuiltinNet):
"""
def __init__(self):
super(ANI1x, self).__init__('ani-1x_8x.info')
super().__init__('ani-1x_8x.info')
class ANI1ccx(BuiltinNet):
......@@ -210,4 +210,4 @@ class ANI1ccx(BuiltinNet):
"""
def __init__(self):
super(ANI1ccx, self).__init__('ani-1ccx_8x.info')
super().__init__('ani-1ccx_8x.info')
......@@ -8,6 +8,11 @@ class SpeciesEnergies(NamedTuple):
energies: Tensor
class SpeciesCoordinates(NamedTuple):
species: Tensor
coordinates: Tensor
class ANIModel(torch.nn.ModuleList):
"""ANI model that compute energies from species and AEVs.
......@@ -26,9 +31,6 @@ class ANIModel(torch.nn.ModuleList):
module by putting the same reference in :attr:`modules`.
"""
def __init__(self, modules):
super(ANIModel, self).__init__(modules)
def forward(self, species_aev: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> SpeciesEnergies:
......@@ -54,7 +56,7 @@ class Ensemble(torch.nn.ModuleList):
"""Compute the average output of an ensemble of modules."""
def __init__(self, modules):
super(Ensemble, self).__init__(modules)
super().__init__(modules)
self.size = len(modules)
def forward(self, species_input: Tuple[Tensor, Tensor],
......@@ -89,3 +91,31 @@ class Gaussian(torch.nn.Module):
"""Gaussian activation"""
def forward(self, x: Tensor) -> Tensor:
return torch.exp(- x * x)
class SpeciesConverter(torch.nn.Module):
"""Convert from element index in the periodic table to 0, 1, 2, 3, ..."""
periodic_table = """
H He
Li Be B C N O F Ne
Na Mg Al Si P S Cl Ar
K Ca Sc Ti V Cr Mn Fe Co Ni Cu Zn Ga Ge As Se Br Kr
Rb Sr Y Zr Nb Mo Tc Ru Rh Pd Ag Cd In Sn Sb Te I Xe
Cs Ba La Ce Pr Nd Pm Sm Eu Gd Tb Dy Ho Er Tm Yb Lu Hf Ta W Re Os Ir Pt Au Hg Tl Pb Bi Po At Rn
Fr Ra Ac Th Pa U Np Pu Am Cm Bk Cf Es Fm Md No Lr Rf Db Sg Bh Hs Mt Ds Rg Cn Nh Fl Mc Lv Ts Og
""".strip().split()
def __init__(self, species):
super().__init__()
rev_idx = {s: k for k, s in enumerate(self.periodic_table, 1)}
maxidx = max(rev_idx.values())
self.conv_tensor = torch.full((maxidx + 2,), -1, dtype=torch.long)
for i, s in enumerate(species):
self.conv_tensor[rev_idx[s]] = i
def forward(self, input_: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None):
species, coordinates = input_
return SpeciesCoordinates(self.conv_tensor[species], coordinates)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment