"src/git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "c0f058265161178f2a88849e92b37ffdc81f1dcc"
Commit 66c3743c authored by Gao, Xiang's avatar Gao, Xiang Committed by Farhad Ramezanghorbani
Browse files

Add helper module to convert species from element id in periodic table to 0,...

Add helper module to convert species from element id in periodic table to 0, 1, 2, 3, ... format (#396)

* Init

* fix test

* flake8

* try fix

* Fix stupidity of len(self)
parent eb89457c
...@@ -18,7 +18,7 @@ jobs: ...@@ -18,7 +18,7 @@ jobs:
python-version: [3.6, 3.7] python-version: [3.6, 3.7]
test-filenames: [ test-filenames: [
test_aev.py, test_aev_benzene_md.py, test_aev_nist.py, test_aev_tripeptide_md.py, test_aev.py, test_aev_benzene_md.py, test_aev_nist.py, test_aev_tripeptide_md.py,
test_data_new.py, test_utils.py, test_ase.py, test_energies.py, test_data_new.py, test_utils.py, test_ase.py, test_energies.py, test_nn.py,
test_neurochem.py, test_vibrational.py, test_ensemble.py, test_padding.py, test_neurochem.py, test_vibrational.py, test_ensemble.py, test_padding.py,
test_data.py, test_forces.py, test_structure_optim.py, test_jit_builtin_models.py] test_data.py, test_forces.py, test_structure_optim.py, test_jit_builtin_models.py]
......
import unittest
import torch
import torchani
class TestSpeciesConverter(unittest.TestCase):
def setUp(self):
self.c = torchani.SpeciesConverter(['H', 'C', 'N', 'O'])
def testSpeciesConverter(self):
input_ = torch.tensor([
[1, 6, 7, 8, -1],
[1, 1, -1, 8, 1],
], dtype=torch.long)
expect = torch.tensor([
[0, 1, 2, 3, -1],
[0, 0, -1, 3, 0],
], dtype=torch.long)
dummy_coordinates = torch.empty(2, 5, 3)
output = self.c((input_, dummy_coordinates)).species
self.assertTrue(torch.allclose(output, expect))
class TestSpeciesConverterJIT(TestSpeciesConverter):
def setUp(self):
super().setUp()
self.c = torch.jit.script(self.c)
if __name__ == '__main__':
unittest.main()
...@@ -42,7 +42,7 @@ def by_batch(species, coordinates, model): ...@@ -42,7 +42,7 @@ def by_batch(species, coordinates, model):
energies = [] energies = []
forces = [] forces = []
for s, c in zip(species, coordinates): for s, c in zip(species, coordinates):
_, e = model((s, c)) e = model((s, c)).energies
f, = torch.autograd.grad(e.sum(), c) f, = torch.autograd.grad(e.sum(), c)
energies.append(e) energies.append(e)
forces.append(f) forces.append(f)
......
...@@ -24,14 +24,13 @@ formats of NeuroChem at :attr:`torchani.neurochem`, and more at :attr:`torchani. ...@@ -24,14 +24,13 @@ formats of NeuroChem at :attr:`torchani.neurochem`, and more at :attr:`torchani.
""" """
from .utils import EnergyShifter from .utils import EnergyShifter
from .nn import ANIModel, Ensemble from .nn import ANIModel, Ensemble, SpeciesConverter
from .aev import AEVComputer from .aev import AEVComputer
from . import utils from . import utils
from . import neurochem from . import neurochem
from . import models from . import models
from . import optim from . import optim
from pkg_resources import get_distribution, DistributionNotFound from pkg_resources import get_distribution, DistributionNotFound
import sys
try: try:
__version__ = get_distribution(__name__).version __version__ = get_distribution(__name__).version
...@@ -39,7 +38,7 @@ except DistributionNotFound: ...@@ -39,7 +38,7 @@ except DistributionNotFound:
# package is not installed # package is not installed
pass pass
__all__ = ['AEVComputer', 'EnergyShifter', 'ANIModel', 'Ensemble', __all__ = ['AEVComputer', 'EnergyShifter', 'ANIModel', 'Ensemble', 'SpeciesConverter',
'utils', 'neurochem', 'models', 'optim'] 'utils', 'neurochem', 'models', 'optim']
try: try:
...@@ -48,10 +47,8 @@ try: ...@@ -48,10 +47,8 @@ try:
except ImportError: except ImportError:
pass pass
try:
if sys.version_info[0] > 2: from . import data # noqa: F401
try: __all__.append('data')
from . import data # noqa: F401 except ImportError:
__all__.append('data') pass
except ImportError:
pass
...@@ -190,7 +190,7 @@ class ANI1x(BuiltinNet): ...@@ -190,7 +190,7 @@ class ANI1x(BuiltinNet):
""" """
def __init__(self): def __init__(self):
super(ANI1x, self).__init__('ani-1x_8x.info') super().__init__('ani-1x_8x.info')
class ANI1ccx(BuiltinNet): class ANI1ccx(BuiltinNet):
...@@ -210,4 +210,4 @@ class ANI1ccx(BuiltinNet): ...@@ -210,4 +210,4 @@ class ANI1ccx(BuiltinNet):
""" """
def __init__(self): def __init__(self):
super(ANI1ccx, self).__init__('ani-1ccx_8x.info') super().__init__('ani-1ccx_8x.info')
...@@ -8,6 +8,11 @@ class SpeciesEnergies(NamedTuple): ...@@ -8,6 +8,11 @@ class SpeciesEnergies(NamedTuple):
energies: Tensor energies: Tensor
class SpeciesCoordinates(NamedTuple):
species: Tensor
coordinates: Tensor
class ANIModel(torch.nn.ModuleList): class ANIModel(torch.nn.ModuleList):
"""ANI model that compute energies from species and AEVs. """ANI model that compute energies from species and AEVs.
...@@ -26,9 +31,6 @@ class ANIModel(torch.nn.ModuleList): ...@@ -26,9 +31,6 @@ class ANIModel(torch.nn.ModuleList):
module by putting the same reference in :attr:`modules`. module by putting the same reference in :attr:`modules`.
""" """
def __init__(self, modules):
super(ANIModel, self).__init__(modules)
def forward(self, species_aev: Tuple[Tensor, Tensor], def forward(self, species_aev: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None, cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> SpeciesEnergies: pbc: Optional[Tensor] = None) -> SpeciesEnergies:
...@@ -54,7 +56,7 @@ class Ensemble(torch.nn.ModuleList): ...@@ -54,7 +56,7 @@ class Ensemble(torch.nn.ModuleList):
"""Compute the average output of an ensemble of modules.""" """Compute the average output of an ensemble of modules."""
def __init__(self, modules): def __init__(self, modules):
super(Ensemble, self).__init__(modules) super().__init__(modules)
self.size = len(modules) self.size = len(modules)
def forward(self, species_input: Tuple[Tensor, Tensor], def forward(self, species_input: Tuple[Tensor, Tensor],
...@@ -89,3 +91,31 @@ class Gaussian(torch.nn.Module): ...@@ -89,3 +91,31 @@ class Gaussian(torch.nn.Module):
"""Gaussian activation""" """Gaussian activation"""
def forward(self, x: Tensor) -> Tensor: def forward(self, x: Tensor) -> Tensor:
return torch.exp(- x * x) return torch.exp(- x * x)
class SpeciesConverter(torch.nn.Module):
"""Convert from element index in the periodic table to 0, 1, 2, 3, ..."""
periodic_table = """
H He
Li Be B C N O F Ne
Na Mg Al Si P S Cl Ar
K Ca Sc Ti V Cr Mn Fe Co Ni Cu Zn Ga Ge As Se Br Kr
Rb Sr Y Zr Nb Mo Tc Ru Rh Pd Ag Cd In Sn Sb Te I Xe
Cs Ba La Ce Pr Nd Pm Sm Eu Gd Tb Dy Ho Er Tm Yb Lu Hf Ta W Re Os Ir Pt Au Hg Tl Pb Bi Po At Rn
Fr Ra Ac Th Pa U Np Pu Am Cm Bk Cf Es Fm Md No Lr Rf Db Sg Bh Hs Mt Ds Rg Cn Nh Fl Mc Lv Ts Og
""".strip().split()
def __init__(self, species):
super().__init__()
rev_idx = {s: k for k, s in enumerate(self.periodic_table, 1)}
maxidx = max(rev_idx.values())
self.conv_tensor = torch.full((maxidx + 2,), -1, dtype=torch.long)
for i, s in enumerate(species):
self.conv_tensor[rev_idx[s]] = i
def forward(self, input_: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None):
species, coordinates = input_
return SpeciesCoordinates(self.conv_tensor[species], coordinates)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment