Commit 493731b5 authored by Gao, Xiang's avatar Gao, Xiang Committed by Farhad Ramezanghorbani
Browse files

Support periodic table indexing in builtin models (#399)

* Support  periodic table Indexing in builtin models

* flake8

* more

* fix

* fix cuda
parent 1055f1f5
......@@ -18,7 +18,7 @@ jobs:
python-version: [3.6, 3.7]
test-filenames: [
test_aev.py, test_aev_benzene_md.py, test_aev_nist.py, test_aev_tripeptide_md.py,
test_data_new.py, test_utils.py, test_ase.py, test_energies.py, test_nn.py,
test_data_new.py, test_utils.py, test_ase.py, test_energies.py, test_periodic_table_indexing.py,
test_neurochem.py, test_vibrational.py, test_ensemble.py, test_padding.py,
test_data.py, test_forces.py, test_structure_optim.py, test_jit_builtin_models.py]
......
......@@ -14,7 +14,7 @@ import torchani
###############################################################################
# Let's now manually specify the device we want TorchANI to run:
device = torch.device('cpu')
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
###############################################################################
# Let's now load the built-in ANI-1ccx models. The builtin ANI-1ccx contains 8
......@@ -22,7 +22,10 @@ device = torch.device('cpu')
# using the average of the 8 models outperform using a single model, so it is
# always recommended to use an ensemble, unless the speed of computation is an
# issue in your application.
model = torchani.models.ANI1ccx()
#
# The ``periodic_table_index`` arguments tells TorchANI to use element index
# in periodic table to index species.
model = torchani.models.ANI1ccx(periodic_table_index=True).to(device)
###############################################################################
# Now let's define the coordinate and species. If you just want to compute the
......@@ -40,7 +43,8 @@ coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
[0.45554739, 0.54289633, 0.81170881],
[0.66091919, -0.16799635, -0.91037834]]],
requires_grad=True, device=device)
species = model.species_to_tensor('CHHHH').to(device).unsqueeze(0)
# In periodic table, C = 6 and H = 1
species = torch.tensor([[6, 1, 1, 1, 1]], device=device)
###############################################################################
# Now let's compute energy and force:
......
......@@ -16,7 +16,7 @@ import torchani
###############################################################################
# Let's now load the built-in ANI-1ccx models. The builtin ANI-1ccx contains 8
# models trained with diffrent initialization.
model = torchani.models.ANI1ccx()
model = torchani.models.ANI1ccx(periodic_table_index=True)
###############################################################################
# It is very easy to compile and save the model using `torch.jit`.
......@@ -42,7 +42,8 @@ coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
[-0.66518241, -0.84461308, 0.20759389],
[0.45554739, 0.54289633, 0.81170881],
[0.66091919, -0.16799635, -0.91037834]]])
species = model.species_to_tensor('CHHHH').unsqueeze(0)
# In periodic table, C = 6 and H = 1
species = torch.tensor([[6, 1, 1, 1, 1]])
###############################################################################
# And here is the result:
......
......@@ -29,5 +29,36 @@ class TestSpeciesConverterJIT(TestSpeciesConverter):
self.c = torch.jit.script(self.c)
class TestBuiltinNetPeriodicTableIndex(unittest.TestCase):
def setUp(self):
self.model1 = torchani.models.ANI1x()
self.model2 = torchani.models.ANI1x(periodic_table_index=True)
self.coordinates = torch.tensor([[[0.03192167, 0.00638559, 0.01301679],
[-0.83140486, 0.39370209, -0.26395324],
[-0.66518241, -0.84461308, 0.20759389],
[0.45554739, 0.54289633, 0.81170881],
[0.66091919, -0.16799635, -0.91037834]]],
requires_grad=True)
self.species1 = self.model1.species_to_tensor('CHHHH').unsqueeze(0)
self.species2 = torch.tensor([[6, 1, 1, 1, 1]])
def testCH4Ensemble(self):
energy1 = self.model1((self.species1, self.coordinates)).energies
energy2 = self.model2((self.species2, self.coordinates)).energies
derivative1 = torch.autograd.grad(energy1.sum(), self.coordinates)[0]
derivative2 = torch.autograd.grad(energy2.sum(), self.coordinates)[0]
self.assertTrue(torch.allclose(energy1, energy2))
self.assertTrue(torch.allclose(derivative1, derivative2))
def testCH4Single(self):
energy1 = self.model1[0]((self.species1, self.coordinates)).energies
energy2 = self.model2[0]((self.species2, self.coordinates)).energies
derivative1 = torch.autograd.grad(energy1.sum(), self.coordinates)[0]
derivative2 = torch.autograd.grad(energy2.sum(), self.coordinates)[0]
self.assertTrue(torch.allclose(energy1, energy2))
self.assertTrue(torch.allclose(derivative1, derivative2))
if __name__ == '__main__':
unittest.main()
......@@ -32,7 +32,7 @@ from torch import Tensor
from typing import Tuple, Optional
from pkg_resources import resource_filename
from . import neurochem
from .nn import Sequential
from .nn import Sequential, SpeciesConverter
from .aev import AEVComputer
......@@ -61,10 +61,15 @@ class BuiltinNet(torch.nn.Module):
aev_computer (:class:`torchani.AEVComputer`): AEV computer with
builtin constants
neural_networks (:class:`torchani.Ensemble`): Ensemble of ANIModel networks
periodic_table_index (bool): Whether to use element number in periodic table
to index species. If set to `False`, then indices must be `0, 1, 2, ..., N - 1`
where `N` is the number of parametrized species.
"""
def __init__(self, info_file):
def __init__(self, info_file, periodic_table_index=False):
super(BuiltinNet, self).__init__()
self.periodic_table_index = periodic_table_index
package_name = '.'.join(__name__.split('.')[:-1])
info_file = 'resources/' + info_file
self.info_file = resource_filename(package_name, info_file)
......@@ -84,6 +89,7 @@ class BuiltinNet(torch.nn.Module):
self.consts = neurochem.Constants(self.const_file)
self.species = self.consts.species
self.species_converter = SpeciesConverter(self.species)
self.aev_computer = AEVComputer(**self.consts)
self.energy_shifter = neurochem.load_sae(self.sae_file)
self.neural_networks = neurochem.load_model_ensemble(
......@@ -105,6 +111,8 @@ class BuiltinNet(torch.nn.Module):
.. note:: The coordinates, and cell are in Angstrom, and the energies
will be in Hartree.
"""
if self.periodic_table_index:
species_coordinates = self.species_converter(species_coordinates)
species_aevs = self.aev_computer(species_coordinates, cell=cell, pbc=pbc)
species_energies = self.neural_networks(species_aevs)
return self.energy_shifter(species_energies)
......@@ -124,6 +132,14 @@ class BuiltinNet(torch.nn.Module):
ret: (:class:`Sequential`): Sequential model ready for
calculations
"""
if self.periodic_table_index:
ret = Sequential(
self.species_converter,
self.aev_computer,
self.neural_networks[index],
self.energy_shifter
)
else:
ret = Sequential(
self.aev_computer,
self.neural_networks[index],
......@@ -189,8 +205,8 @@ class ANI1x(BuiltinNet):
https://aip.scitation.org/doi/abs/10.1063/1.5023802
"""
def __init__(self):
super().__init__('ani-1x_8x.info')
def __init__(self, *args, **kwargs):
super().__init__('ani-1x_8x.info', *args, **kwargs)
class ANI1ccx(BuiltinNet):
......@@ -209,5 +225,5 @@ class ANI1ccx(BuiltinNet):
https://doi.org/10.26434/chemrxiv.6744440.v1
"""
def __init__(self):
super().__init__('ani-1ccx_8x.info')
def __init__(self, *args, **kwargs):
super().__init__('ani-1ccx_8x.info', *args, **kwargs)
......@@ -47,8 +47,6 @@ class ANIModel(torch.nn.ModuleDict):
def forward(self, species_aev: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> SpeciesEnergies:
assert cell is None
assert pbc is None
species, aev = species_aev
species_ = species.flatten()
aev = aev.flatten(0, 1)
......@@ -75,8 +73,6 @@ class Ensemble(torch.nn.ModuleList):
def forward(self, species_input: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> SpeciesEnergies:
assert cell is None
assert pbc is None
sum_ = 0
for x in self:
sum_ += x(species_input)[1]
......@@ -95,8 +91,6 @@ class Sequential(torch.nn.ModuleList):
pbc: Optional[Tensor] = None):
for module in self:
input_ = module(input_, cell=cell, pbc=pbc)
cell = None
pbc = None
return input_
......@@ -123,7 +117,7 @@ class SpeciesConverter(torch.nn.Module):
super().__init__()
rev_idx = {s: k for k, s in enumerate(self.periodic_table, 1)}
maxidx = max(rev_idx.values())
self.conv_tensor = torch.full((maxidx + 2,), -1, dtype=torch.long)
self.register_buffer('conv_tensor', torch.full((maxidx + 2,), -1, dtype=torch.long))
for i, s in enumerate(species):
self.conv_tensor[rev_idx[s]] = i
......
......@@ -217,8 +217,6 @@ class EnergyShifter(torch.nn.Module):
pbc: Optional[Tensor] = None) -> SpeciesEnergies:
"""(species, molecular energies)->(species, molecular energies + sae)
"""
assert cell is None
assert pbc is None
species, energies = species_energies
sae = self.sae(species).to(energies.device)
return SpeciesEnergies(species, energies.to(sae.dtype) + sae)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment