Commit 493731b5 authored by Gao, Xiang's avatar Gao, Xiang Committed by Farhad Ramezanghorbani
Browse files

Support periodic table indexing in builtin models (#399)

* Support  periodic table Indexing in builtin models

* flake8

* more

* fix

* fix cuda
parent 1055f1f5
...@@ -18,7 +18,7 @@ jobs: ...@@ -18,7 +18,7 @@ jobs:
python-version: [3.6, 3.7] python-version: [3.6, 3.7]
test-filenames: [ test-filenames: [
test_aev.py, test_aev_benzene_md.py, test_aev_nist.py, test_aev_tripeptide_md.py, test_aev.py, test_aev_benzene_md.py, test_aev_nist.py, test_aev_tripeptide_md.py,
test_data_new.py, test_utils.py, test_ase.py, test_energies.py, test_nn.py, test_data_new.py, test_utils.py, test_ase.py, test_energies.py, test_periodic_table_indexing.py,
test_neurochem.py, test_vibrational.py, test_ensemble.py, test_padding.py, test_neurochem.py, test_vibrational.py, test_ensemble.py, test_padding.py,
test_data.py, test_forces.py, test_structure_optim.py, test_jit_builtin_models.py] test_data.py, test_forces.py, test_structure_optim.py, test_jit_builtin_models.py]
......
...@@ -14,7 +14,7 @@ import torchani ...@@ -14,7 +14,7 @@ import torchani
############################################################################### ###############################################################################
# Let's now manually specify the device we want TorchANI to run: # Let's now manually specify the device we want TorchANI to run:
device = torch.device('cpu') device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
############################################################################### ###############################################################################
# Let's now load the built-in ANI-1ccx models. The builtin ANI-1ccx contains 8 # Let's now load the built-in ANI-1ccx models. The builtin ANI-1ccx contains 8
...@@ -22,7 +22,10 @@ device = torch.device('cpu') ...@@ -22,7 +22,10 @@ device = torch.device('cpu')
# using the average of the 8 models outperform using a single model, so it is # using the average of the 8 models outperform using a single model, so it is
# always recommended to use an ensemble, unless the speed of computation is an # always recommended to use an ensemble, unless the speed of computation is an
# issue in your application. # issue in your application.
model = torchani.models.ANI1ccx() #
# The ``periodic_table_index`` arguments tells TorchANI to use element index
# in periodic table to index species.
model = torchani.models.ANI1ccx(periodic_table_index=True).to(device)
############################################################################### ###############################################################################
# Now let's define the coordinate and species. If you just want to compute the # Now let's define the coordinate and species. If you just want to compute the
...@@ -40,7 +43,8 @@ coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679], ...@@ -40,7 +43,8 @@ coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
[0.45554739, 0.54289633, 0.81170881], [0.45554739, 0.54289633, 0.81170881],
[0.66091919, -0.16799635, -0.91037834]]], [0.66091919, -0.16799635, -0.91037834]]],
requires_grad=True, device=device) requires_grad=True, device=device)
species = model.species_to_tensor('CHHHH').to(device).unsqueeze(0) # In periodic table, C = 6 and H = 1
species = torch.tensor([[6, 1, 1, 1, 1]], device=device)
############################################################################### ###############################################################################
# Now let's compute energy and force: # Now let's compute energy and force:
......
...@@ -16,7 +16,7 @@ import torchani ...@@ -16,7 +16,7 @@ import torchani
############################################################################### ###############################################################################
# Let's now load the built-in ANI-1ccx models. The builtin ANI-1ccx contains 8 # Let's now load the built-in ANI-1ccx models. The builtin ANI-1ccx contains 8
# models trained with diffrent initialization. # models trained with diffrent initialization.
model = torchani.models.ANI1ccx() model = torchani.models.ANI1ccx(periodic_table_index=True)
############################################################################### ###############################################################################
# It is very easy to compile and save the model using `torch.jit`. # It is very easy to compile and save the model using `torch.jit`.
...@@ -42,7 +42,8 @@ coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679], ...@@ -42,7 +42,8 @@ coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
[-0.66518241, -0.84461308, 0.20759389], [-0.66518241, -0.84461308, 0.20759389],
[0.45554739, 0.54289633, 0.81170881], [0.45554739, 0.54289633, 0.81170881],
[0.66091919, -0.16799635, -0.91037834]]]) [0.66091919, -0.16799635, -0.91037834]]])
species = model.species_to_tensor('CHHHH').unsqueeze(0) # In periodic table, C = 6 and H = 1
species = torch.tensor([[6, 1, 1, 1, 1]])
############################################################################### ###############################################################################
# And here is the result: # And here is the result:
......
...@@ -29,5 +29,36 @@ class TestSpeciesConverterJIT(TestSpeciesConverter): ...@@ -29,5 +29,36 @@ class TestSpeciesConverterJIT(TestSpeciesConverter):
self.c = torch.jit.script(self.c) self.c = torch.jit.script(self.c)
class TestBuiltinNetPeriodicTableIndex(unittest.TestCase):
def setUp(self):
self.model1 = torchani.models.ANI1x()
self.model2 = torchani.models.ANI1x(periodic_table_index=True)
self.coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
[-0.83140486, 0.39370209, -0.26395324],
[-0.66518241, -0.84461308, 0.20759389],
[0.45554739, 0.54289633, 0.81170881],
[0.66091919, -0.16799635, -0.91037834]]],
requires_grad=True)
self.species1 = self.model1.species_to_tensor('CHHHH').unsqueeze(0)
self.species2 = torch.tensor([[6, 1, 1, 1, 1]])
def testCH4Ensemble(self):
energy1 = self.model1((self.species1, self.coordinates)).energies
energy2 = self.model2((self.species2, self.coordinates)).energies
derivative1 = torch.autograd.grad(energy1.sum(), self.coordinates)[0]
derivative2 = torch.autograd.grad(energy2.sum(), self.coordinates)[0]
self.assertTrue(torch.allclose(energy1, energy2))
self.assertTrue(torch.allclose(derivative1, derivative2))
def testCH4Single(self):
energy1 = self.model1[0]((self.species1, self.coordinates)).energies
energy2 = self.model2[0]((self.species2, self.coordinates)).energies
derivative1 = torch.autograd.grad(energy1.sum(), self.coordinates)[0]
derivative2 = torch.autograd.grad(energy2.sum(), self.coordinates)[0]
self.assertTrue(torch.allclose(energy1, energy2))
self.assertTrue(torch.allclose(derivative1, derivative2))
if __name__ == '__main__': if __name__ == '__main__':
unittest.main() unittest.main()
...@@ -32,7 +32,7 @@ from torch import Tensor ...@@ -32,7 +32,7 @@ from torch import Tensor
from typing import Tuple, Optional from typing import Tuple, Optional
from pkg_resources import resource_filename from pkg_resources import resource_filename
from . import neurochem from . import neurochem
from .nn import Sequential from .nn import Sequential, SpeciesConverter
from .aev import AEVComputer from .aev import AEVComputer
...@@ -61,10 +61,15 @@ class BuiltinNet(torch.nn.Module): ...@@ -61,10 +61,15 @@ class BuiltinNet(torch.nn.Module):
aev_computer (:class:`torchani.AEVComputer`): AEV computer with aev_computer (:class:`torchani.AEVComputer`): AEV computer with
builtin constants builtin constants
neural_networks (:class:`torchani.Ensemble`): Ensemble of ANIModel networks neural_networks (:class:`torchani.Ensemble`): Ensemble of ANIModel networks
periodic_table_index (bool): Whether to use element number in periodic table
to index species. If set to `False`, then indices must be `0, 1, 2, ..., N - 1`
where `N` is the number of parametrized species.
""" """
def __init__(self, info_file): def __init__(self, info_file, periodic_table_index=False):
super(BuiltinNet, self).__init__() super(BuiltinNet, self).__init__()
self.periodic_table_index = periodic_table_index
package_name = '.'.join(__name__.split('.')[:-1]) package_name = '.'.join(__name__.split('.')[:-1])
info_file = 'resources/' + info_file info_file = 'resources/' + info_file
self.info_file = resource_filename(package_name, info_file) self.info_file = resource_filename(package_name, info_file)
...@@ -84,6 +89,7 @@ class BuiltinNet(torch.nn.Module): ...@@ -84,6 +89,7 @@ class BuiltinNet(torch.nn.Module):
self.consts = neurochem.Constants(self.const_file) self.consts = neurochem.Constants(self.const_file)
self.species = self.consts.species self.species = self.consts.species
self.species_converter = SpeciesConverter(self.species)
self.aev_computer = AEVComputer(**self.consts) self.aev_computer = AEVComputer(**self.consts)
self.energy_shifter = neurochem.load_sae(self.sae_file) self.energy_shifter = neurochem.load_sae(self.sae_file)
self.neural_networks = neurochem.load_model_ensemble( self.neural_networks = neurochem.load_model_ensemble(
...@@ -105,6 +111,8 @@ class BuiltinNet(torch.nn.Module): ...@@ -105,6 +111,8 @@ class BuiltinNet(torch.nn.Module):
.. note:: The coordinates, and cell are in Angstrom, and the energies .. note:: The coordinates, and cell are in Angstrom, and the energies
will be in Hartree. will be in Hartree.
""" """
if self.periodic_table_index:
species_coordinates = self.species_converter(species_coordinates)
species_aevs = self.aev_computer(species_coordinates, cell=cell, pbc=pbc) species_aevs = self.aev_computer(species_coordinates, cell=cell, pbc=pbc)
species_energies = self.neural_networks(species_aevs) species_energies = self.neural_networks(species_aevs)
return self.energy_shifter(species_energies) return self.energy_shifter(species_energies)
...@@ -124,11 +132,19 @@ class BuiltinNet(torch.nn.Module): ...@@ -124,11 +132,19 @@ class BuiltinNet(torch.nn.Module):
ret: (:class:`Sequential`): Sequential model ready for ret: (:class:`Sequential`): Sequential model ready for
calculations calculations
""" """
ret = Sequential( if self.periodic_table_index:
self.aev_computer, ret = Sequential(
self.neural_networks[index], self.species_converter,
self.energy_shifter self.aev_computer,
) self.neural_networks[index],
self.energy_shifter
)
else:
ret = Sequential(
self.aev_computer,
self.neural_networks[index],
self.energy_shifter
)
def ase(**kwargs): def ase(**kwargs):
"""Attach an ase calculator """ """Attach an ase calculator """
...@@ -189,8 +205,8 @@ class ANI1x(BuiltinNet): ...@@ -189,8 +205,8 @@ class ANI1x(BuiltinNet):
https://aip.scitation.org/doi/abs/10.1063/1.5023802 https://aip.scitation.org/doi/abs/10.1063/1.5023802
""" """
def __init__(self): def __init__(self, *args, **kwargs):
super().__init__('ani-1x_8x.info') super().__init__('ani-1x_8x.info', *args, **kwargs)
class ANI1ccx(BuiltinNet): class ANI1ccx(BuiltinNet):
...@@ -209,5 +225,5 @@ class ANI1ccx(BuiltinNet): ...@@ -209,5 +225,5 @@ class ANI1ccx(BuiltinNet):
https://doi.org/10.26434/chemrxiv.6744440.v1 https://doi.org/10.26434/chemrxiv.6744440.v1
""" """
def __init__(self): def __init__(self, *args, **kwargs):
super().__init__('ani-1ccx_8x.info') super().__init__('ani-1ccx_8x.info', *args, **kwargs)
...@@ -47,8 +47,6 @@ class ANIModel(torch.nn.ModuleDict): ...@@ -47,8 +47,6 @@ class ANIModel(torch.nn.ModuleDict):
def forward(self, species_aev: Tuple[Tensor, Tensor], def forward(self, species_aev: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None, cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> SpeciesEnergies: pbc: Optional[Tensor] = None) -> SpeciesEnergies:
assert cell is None
assert pbc is None
species, aev = species_aev species, aev = species_aev
species_ = species.flatten() species_ = species.flatten()
aev = aev.flatten(0, 1) aev = aev.flatten(0, 1)
...@@ -75,8 +73,6 @@ class Ensemble(torch.nn.ModuleList): ...@@ -75,8 +73,6 @@ class Ensemble(torch.nn.ModuleList):
def forward(self, species_input: Tuple[Tensor, Tensor], def forward(self, species_input: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None, cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> SpeciesEnergies: pbc: Optional[Tensor] = None) -> SpeciesEnergies:
assert cell is None
assert pbc is None
sum_ = 0 sum_ = 0
for x in self: for x in self:
sum_ += x(species_input)[1] sum_ += x(species_input)[1]
...@@ -95,8 +91,6 @@ class Sequential(torch.nn.ModuleList): ...@@ -95,8 +91,6 @@ class Sequential(torch.nn.ModuleList):
pbc: Optional[Tensor] = None): pbc: Optional[Tensor] = None):
for module in self: for module in self:
input_ = module(input_, cell=cell, pbc=pbc) input_ = module(input_, cell=cell, pbc=pbc)
cell = None
pbc = None
return input_ return input_
...@@ -123,7 +117,7 @@ class SpeciesConverter(torch.nn.Module): ...@@ -123,7 +117,7 @@ class SpeciesConverter(torch.nn.Module):
super().__init__() super().__init__()
rev_idx = {s: k for k, s in enumerate(self.periodic_table, 1)} rev_idx = {s: k for k, s in enumerate(self.periodic_table, 1)}
maxidx = max(rev_idx.values()) maxidx = max(rev_idx.values())
self.conv_tensor = torch.full((maxidx + 2,), -1, dtype=torch.long) self.register_buffer('conv_tensor', torch.full((maxidx + 2,), -1, dtype=torch.long))
for i, s in enumerate(species): for i, s in enumerate(species):
self.conv_tensor[rev_idx[s]] = i self.conv_tensor[rev_idx[s]] = i
......
...@@ -217,8 +217,6 @@ class EnergyShifter(torch.nn.Module): ...@@ -217,8 +217,6 @@ class EnergyShifter(torch.nn.Module):
pbc: Optional[Tensor] = None) -> SpeciesEnergies: pbc: Optional[Tensor] = None) -> SpeciesEnergies:
"""(species, molecular energies)->(species, molecular energies + sae) """(species, molecular energies)->(species, molecular energies + sae)
""" """
assert cell is None
assert pbc is None
species, energies = species_energies species, energies = species_energies
sae = self.sae(species).to(energies.device) sae = self.sae(species).to(energies.device)
return SpeciesEnergies(species, energies.to(sae.dtype) + sae) return SpeciesEnergies(species, energies.to(sae.dtype) + sae)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment