Commit 124f239e authored by Gao, Xiang's avatar Gao, Xiang Committed by Farhad Ramezanghorbani
Browse files

Add example to customize JIT models (#401)

* add example to customize JIT models

* cleanup

* fix list

* Update jit.py
parent 9833dd63
......@@ -12,8 +12,13 @@ models from a Python process and loaded in a process where there is no Python de
# To begin with, let's first import the modules we will use:
import torch
import torchani
from typing import Tuple, Optional
from torch import Tensor
###############################################################################
# Scripting builtin model directly
# --------------------------------
#
# Let's now load the built-in ANI-1ccx models. The builtin ANI-1ccx contains 8
# models trained with diffrent initialization.
model = torchani.models.ANI1ccx(periodic_table_index=True)
......@@ -53,3 +58,57 @@ energies_ensemble_jit = loaded_compiled_model((species, coordinates)).energies
energies_single_jit = loaded_compiled_model0((species, coordinates)).energies
print('Ensemble energy, eager mode vs loaded jit:', energies_ensemble.item(), energies_ensemble_jit.item())
print('Single network energy, eager mode vs loaded jit:', energies_single.item(), energies_single_jit.item())
###############################################################################
# Customize the model and script
# ------------------------------
#
# You could also customize the model you want to export. For example, let's do
# the following customization to the model:
#
# - uses double as dtype instead of float
# - don't care about periodic boundary condition
# - in addition to energies, allow returnsing optionally forces, and hessians
# - when indexing atom species, use its index in the periodic table instead of 0, 1, 2, 3, ...
#
# you could do the following:
class CustomModule(torch.nn.Module):
def __init__(self):
super().__init__()
self.model = torchani.models.ANI1x(periodic_table_index=True).double()
# self.model = torchani.models.ANI1x(periodic_table_index=True)[0].double()
# self.model = torchani.models.ANI1ccx(periodic_table_index=True).double()
def forward(self, species: Tensor, coordinates: Tensor, return_forces: bool = False,
return_hessians: bool = False) -> Tuple[Tensor, Optional[Tensor], Optional[Tensor]]:
if return_forces or return_hessians:
coordinates.requires_grad_(True)
energies = self.model((species, coordinates)).energies
forces: Optional[Tensor] = None # noqa: E701
hessians: Optional[Tensor] = None
if return_forces or return_hessians:
grad = torch.autograd.grad([energies.sum()], [coordinates], create_graph=return_hessians)[0]
assert grad is not None
forces = -grad
if return_hessians:
hessians = torchani.utils.hessian(coordinates, forces=forces)
return energies, forces, hessians
custom_model = CustomModule()
compiled_custom_model = torch.jit.script(custom_model)
torch.jit.save(compiled_custom_model, 'compiled_custom_model.pt')
loaded_compiled_custom_model = torch.jit.load('compiled_custom_model.pt')
energies, forces, hessians = custom_model(species, coordinates, True, True)
energies_jit, forces_jit, hessians_jit = loaded_compiled_custom_model(species, coordinates, True, True)
print('Energy, eager mode vs loaded jit:', energies.item(), energies_jit.item())
print()
print('Force, eager mode vs loaded jit:\n', forces.squeeze(0), '\n', forces_jit.squeeze(0))
print()
torch.set_printoptions(sci_mode=False, linewidth=1000)
print('Hessian, eager mode vs loaded jit:\n', hessians.squeeze(0), '\n', hessians_jit.squeeze(0))
......@@ -32,7 +32,7 @@ from torch import Tensor
from typing import Tuple, Optional
from pkg_resources import resource_filename
from . import neurochem
from .nn import Sequential, SpeciesConverter
from .nn import Sequential, SpeciesConverter, SpeciesEnergies
from .aev import AEVComputer
......@@ -97,7 +97,7 @@ class BuiltinNet(torch.nn.Module):
def forward(self, species_coordinates: Tuple[Tensor, Tensor],
cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
pbc: Optional[Tensor] = None) -> SpeciesEnergies:
"""Calculates predicted properties for minibatch of configurations
Args:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment