Commit 004f5a52 authored by Gao, Xiang's avatar Gao, Xiang Committed by Farhad Ramezanghorbani
Browse files

Use namedtuple to improve API while still maintaining backward compatibility (#380)

* Use namedtuple to improve API

* improve
parent 92c307dc
...@@ -16,7 +16,6 @@ calculator. ...@@ -16,7 +16,6 @@ calculator.
############################################################################### ###############################################################################
# To begin with, let's first import the modules we will use: # To begin with, let's first import the modules we will use:
from __future__ import print_function
from ase.lattice.cubic import Diamond from ase.lattice.cubic import Diamond
from ase.md.langevin import Langevin from ase.md.langevin import Langevin
from ase.optimize import BFGS from ase.optimize import BFGS
......
...@@ -9,7 +9,6 @@ TorchANI and can be used directly. ...@@ -9,7 +9,6 @@ TorchANI and can be used directly.
############################################################################### ###############################################################################
# To begin with, let's first import the modules we will use: # To begin with, let's first import the modules we will use:
from __future__ import print_function
import torch import torch
import torchani import torchani
...@@ -43,7 +42,7 @@ species = model.species_to_tensor('CHHHH').to(device).unsqueeze(0) ...@@ -43,7 +42,7 @@ species = model.species_to_tensor('CHHHH').to(device).unsqueeze(0)
############################################################################### ###############################################################################
# Now let's compute energy and force: # Now let's compute energy and force:
_, energy = model((species, coordinates)) energy = model((species, coordinates)).energies
derivative = torch.autograd.grad(energy.sum(), coordinates)[0] derivative = torch.autograd.grad(energy.sum(), coordinates)[0]
force = -derivative force = -derivative
......
...@@ -46,9 +46,9 @@ species = model.species_to_tensor('CHHHH').unsqueeze(0) ...@@ -46,9 +46,9 @@ species = model.species_to_tensor('CHHHH').unsqueeze(0)
############################################################################### ###############################################################################
# And here is the result: # And here is the result:
_, energies_ensemble = model((species, coordinates)) energies_ensemble = model((species, coordinates)).energies
_, energies_single = model[0]((species, coordinates)) energies_single = model[0]((species, coordinates)).energies
_, energies_ensemble_jit = loaded_compiled_model((species, coordinates)) energies_ensemble_jit = loaded_compiled_model((species, coordinates)).energies
_, energies_single_jit = loaded_compiled_model0((species, coordinates)) energies_single_jit = loaded_compiled_model0((species, coordinates)).energies
print('Ensemble energy, eager mode vs loaded jit:', energies_ensemble.item(), energies_ensemble_jit.item()) print('Ensemble energy, eager mode vs loaded jit:', energies_ensemble.item(), energies_ensemble_jit.item())
print('Single network energy, eager mode vs loaded jit:', energies_single.item(), energies_single_jit.item()) print('Single network energy, eager mode vs loaded jit:', energies_single.item(), energies_single_jit.item())
...@@ -75,7 +75,7 @@ methane = ase.Atoms('CHHHH', positions=coordinates.squeeze().detach().numpy()) ...@@ -75,7 +75,7 @@ methane = ase.Atoms('CHHHH', positions=coordinates.squeeze().detach().numpy())
############################################################################### ###############################################################################
# Now let's compute energies using the ensemble directly: # Now let's compute energies using the ensemble directly:
_, energy = nnp1((species, coordinates)) energy = nnp1((species, coordinates)).energies
derivative = torch.autograd.grad(energy.sum(), coordinates)[0] derivative = torch.autograd.grad(energy.sum(), coordinates)[0]
force = -derivative force = -derivative
print('Energy:', energy.item()) print('Energy:', energy.item())
...@@ -89,7 +89,7 @@ print('Force:', methane.get_forces() / ase.units.Hartree) ...@@ -89,7 +89,7 @@ print('Force:', methane.get_forces() / ase.units.Hartree)
############################################################################### ###############################################################################
# We can do the same thing with the single model: # We can do the same thing with the single model:
_, energy = nnp2((species, coordinates)) energy = nnp2((species, coordinates)).energies
derivative = torch.autograd.grad(energy.sum(), coordinates)[0] derivative = torch.autograd.grad(energy.sum(), coordinates)[0]
force = -derivative force = -derivative
print('Energy:', energy.item()) print('Energy:', energy.item())
......
...@@ -286,7 +286,7 @@ def validate(): ...@@ -286,7 +286,7 @@ def validate():
true_energies = batch_y['energies'] true_energies = batch_y['energies']
predicted_energies = [] predicted_energies = []
for chunk_species, chunk_coordinates in batch_x: for chunk_species, chunk_coordinates in batch_x:
_, chunk_energies = model((chunk_species, chunk_coordinates)) chunk_energies = model((chunk_species, chunk_coordinates)).energies
predicted_energies.append(chunk_energies) predicted_energies.append(chunk_energies)
predicted_energies = torch.cat(predicted_energies) predicted_energies = torch.cat(predicted_energies)
total_mse += mse_sum(predicted_energies, true_energies).item() total_mse += mse_sum(predicted_energies, true_energies).item()
...@@ -343,7 +343,7 @@ for _ in range(AdamW_scheduler.last_epoch + 1, max_epochs): ...@@ -343,7 +343,7 @@ for _ in range(AdamW_scheduler.last_epoch + 1, max_epochs):
for chunk_species, chunk_coordinates in batch_x: for chunk_species, chunk_coordinates in batch_x:
num_atoms.append((chunk_species >= 0).to(true_energies.dtype).sum(dim=1)) num_atoms.append((chunk_species >= 0).to(true_energies.dtype).sum(dim=1))
_, chunk_energies = model((chunk_species, chunk_coordinates)) chunk_energies = model((chunk_species, chunk_coordinates)).energies
predicted_energies.append(chunk_energies) predicted_energies.append(chunk_energies)
num_atoms = torch.cat(num_atoms) num_atoms = torch.cat(num_atoms)
......
...@@ -231,7 +231,7 @@ def validate(): ...@@ -231,7 +231,7 @@ def validate():
true_energies = batch_y['energies'] true_energies = batch_y['energies']
predicted_energies = [] predicted_energies = []
for chunk_species, chunk_coordinates in batch_x: for chunk_species, chunk_coordinates in batch_x:
_, chunk_energies = model((chunk_species, chunk_coordinates)) chunk_energies = model((chunk_species, chunk_coordinates)).energies
predicted_energies.append(chunk_energies) predicted_energies.append(chunk_energies)
predicted_energies = torch.cat(predicted_energies) predicted_energies = torch.cat(predicted_energies)
total_mse += mse_sum(predicted_energies, true_energies).item() total_mse += mse_sum(predicted_energies, true_energies).item()
...@@ -299,7 +299,7 @@ for _ in range(AdamW_scheduler.last_epoch + 1, max_epochs): ...@@ -299,7 +299,7 @@ for _ in range(AdamW_scheduler.last_epoch + 1, max_epochs):
# that we could compute force from it # that we could compute force from it
chunk_coordinates.requires_grad_(True) chunk_coordinates.requires_grad_(True)
_, chunk_energies = model((chunk_species, chunk_coordinates)) chunk_energies = model((chunk_species, chunk_coordinates)).energies
# We can use torch.autograd.grad to compute force. Remember to # We can use torch.autograd.grad to compute force. Remember to
# create graph so that the loss of the force can contribute to # create graph so that the loss of the force can contribute to
......
...@@ -54,7 +54,7 @@ masses = element_masses[species] ...@@ -54,7 +54,7 @@ masses = element_masses[species]
# To do vibration analysis, we first need to generate a graph that computes # To do vibration analysis, we first need to generate a graph that computes
# energies from species and coordinates. The code to generate a graph of energy # energies from species and coordinates. The code to generate a graph of energy
# is the same as the code to compute energy: # is the same as the code to compute energy:
_, energies = model((species, coordinates)) energies = model((species, coordinates)).energies
############################################################################### ###############################################################################
# We can now use the energy graph to compute analytical Hessian matrix: # We can now use the energy graph to compute analytical Hessian matrix:
......
import torch import torch
from torch import Tensor from torch import Tensor
import math import math
from typing import Tuple, Optional from typing import Tuple, Optional, NamedTuple
from torch.jit import Final from torch.jit import Final
class SpeciesAEV(NamedTuple):
species: Tensor
aevs: Tensor
def cutoff_cosine(distances: Tensor, cutoff: float) -> Tensor: def cutoff_cosine(distances: Tensor, cutoff: float) -> Tensor:
# assuming all elements in distances are smaller than cutoff # assuming all elements in distances are smaller than cutoff
return 0.5 * torch.cos(distances * (math.pi / cutoff)) + 0.5 return 0.5 * torch.cos(distances * (math.pi / cutoff)) + 0.5
...@@ -356,7 +361,7 @@ class AEVComputer(torch.nn.Module): ...@@ -356,7 +361,7 @@ class AEVComputer(torch.nn.Module):
return self.Rcr, self.EtaR, self.ShfR, self.Rca, self.ShfZ, self.EtaA, self.Zeta, self.ShfA return self.Rcr, self.EtaR, self.ShfR, self.Rca, self.ShfZ, self.EtaA, self.Zeta, self.ShfA
def forward(self, input_: Tuple[Tensor, Tensor], cell: Optional[Tensor] = None, def forward(self, input_: Tuple[Tensor, Tensor], cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]: pbc: Optional[Tensor] = None) -> SpeciesAEV:
"""Compute AEVs """Compute AEVs
Arguments: Arguments:
...@@ -384,7 +389,7 @@ class AEVComputer(torch.nn.Module): ...@@ -384,7 +389,7 @@ class AEVComputer(torch.nn.Module):
for that direction. for that direction.
Returns: Returns:
tuple: Species and AEVs. species are the species from the input NamedTuple: Species and AEVs. species are the species from the input
unchanged, and AEVs is a tensor of shape unchanged, and AEVs is a tensor of shape
``(C, A, self.aev_length())`` ``(C, A, self.aev_length())``
""" """
...@@ -398,4 +403,5 @@ class AEVComputer(torch.nn.Module): ...@@ -398,4 +403,5 @@ class AEVComputer(torch.nn.Module):
cutoff = max(self.Rcr, self.Rca) cutoff = max(self.Rcr, self.Rca)
shifts = compute_shifts(cell, pbc, cutoff) shifts = compute_shifts(cell, pbc, cutoff)
return species, compute_aev(species, coordinates, cell, shifts, self.triu_index, self.constants(), self.sizes) aev = compute_aev(species, coordinates, cell, shifts, self.triu_index, self.constants(), self.sizes)
return SpeciesAEV(species, aev)
...@@ -93,11 +93,11 @@ class Calculator(ase.calculators.calculator.Calculator): ...@@ -93,11 +93,11 @@ class Calculator(ase.calculators.calculator.Calculator):
strain_y = self.strain(cell, displacement_y, 1) strain_y = self.strain(cell, displacement_y, 1)
strain_z = self.strain(cell, displacement_z, 2) strain_z = self.strain(cell, displacement_z, 2)
cell = cell + strain_x + strain_y + strain_z cell = cell + strain_x + strain_y + strain_z
_, aev = self.aev_computer((species, coordinates), cell=cell, pbc=pbc) aev = self.aev_computer((species, coordinates), cell=cell, pbc=pbc).aevs
else: else:
_, aev = self.aev_computer((species, coordinates)) aev = self.aev_computer((species, coordinates)).aevs
_, energy = self.nn((species, aev)) energy = self.nn((species, aev)).energies
energy *= ase.units.Hartree energy *= ase.units.Hartree
self.results['energy'] = energy.item() self.results['energy'] = energy.item()
self.results['free_energy'] = energy.item() self.results['free_energy'] = energy.item()
......
import torch import torch
from torch import Tensor from torch import Tensor
from typing import Tuple from typing import Tuple, NamedTuple
class SpeciesEnergies(NamedTuple):
species: Tensor
energies: Tensor
class ANIModel(torch.nn.Module): class ANIModel(torch.nn.Module):
...@@ -26,7 +31,7 @@ class ANIModel(torch.nn.Module): ...@@ -26,7 +31,7 @@ class ANIModel(torch.nn.Module):
def __getitem__(self, i): def __getitem__(self, i):
return self.module_list[i] return self.module_list[i]
def forward(self, species_aev: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]: def forward(self, species_aev: Tuple[Tensor, Tensor]) -> SpeciesEnergies:
species, aev = species_aev species, aev = species_aev
species_ = species.flatten() species_ = species.flatten()
aev = aev.flatten(0, 1) aev = aev.flatten(0, 1)
...@@ -40,7 +45,7 @@ class ANIModel(torch.nn.Module): ...@@ -40,7 +45,7 @@ class ANIModel(torch.nn.Module):
input_ = aev.index_select(0, midx) input_ = aev.index_select(0, midx)
output.masked_scatter_(mask, m(input_).flatten()) output.masked_scatter_(mask, m(input_).flatten())
output = output.view_as(species) output = output.view_as(species)
return species, torch.sum(output, dim=1) return SpeciesEnergies(species, torch.sum(output, dim=1))
class Ensemble(torch.nn.Module): class Ensemble(torch.nn.Module):
...@@ -51,12 +56,12 @@ class Ensemble(torch.nn.Module): ...@@ -51,12 +56,12 @@ class Ensemble(torch.nn.Module):
self.modules_list = torch.nn.ModuleList(modules) self.modules_list = torch.nn.ModuleList(modules)
self.size = len(self.modules_list) self.size = len(self.modules_list)
def forward(self, species_input: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]: def forward(self, species_input: Tuple[Tensor, Tensor]) -> SpeciesEnergies:
sum_ = 0 sum_ = 0
for x in self.modules_list: for x in self.modules_list:
sum_ += x(species_input)[1] sum_ += x(species_input)[1]
species, _ = species_input species, _ = species_input
return species, sum_ / self.size return SpeciesEnergies(species, sum_ / self.size)
def __getitem__(self, i): def __getitem__(self, i):
return self.modules_list[i] return self.modules_list[i]
...@@ -69,7 +74,7 @@ class Sequential(torch.nn.Module): ...@@ -69,7 +74,7 @@ class Sequential(torch.nn.Module):
super(Sequential, self).__init__() super(Sequential, self).__init__()
self.modules_list = torch.nn.ModuleList(modules) self.modules_list = torch.nn.ModuleList(modules)
def forward(self, input_: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]: def forward(self, input_: Tuple[Tensor, Tensor]):
for module in self.modules_list: for module in self.modules_list:
input_ = module(input_) input_ = module(input_)
return input_ return input_
......
...@@ -4,7 +4,8 @@ import torch.utils.data ...@@ -4,7 +4,8 @@ import torch.utils.data
import math import math
import numpy as np import numpy as np
from collections import defaultdict from collections import defaultdict
from typing import Tuple from typing import Tuple, NamedTuple
from .nn import SpeciesEnergies
def pad(species): def pad(species):
...@@ -211,12 +212,12 @@ class EnergyShifter(torch.nn.Module): ...@@ -211,12 +212,12 @@ class EnergyShifter(torch.nn.Module):
properties['energies'] = energies properties['energies'] = energies
return atomic_properties, properties return atomic_properties, properties
def forward(self, species_energies: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]: def forward(self, species_energies: Tuple[Tensor, Tensor]) -> SpeciesEnergies:
"""(species, molecular energies)->(species, molecular energies + sae) """(species, molecular energies)->(species, molecular energies + sae)
""" """
species, energies = species_energies species, energies = species_energies
sae = self.sae(species).to(energies.device) sae = self.sae(species).to(energies.device)
return species, energies.to(sae.dtype) + sae return SpeciesEnergies(species, energies.to(sae.dtype) + sae)
class ChemicalSymbolsToInts: class ChemicalSymbolsToInts:
...@@ -269,6 +270,11 @@ def hessian(coordinates, energies=None, forces=None): ...@@ -269,6 +270,11 @@ def hessian(coordinates, energies=None, forces=None):
], dim=1) ], dim=1)
class FreqsModes(NamedTuple):
freqs: Tensor
modes: Tensor
def vibrational_analysis(masses, hessian, unit='cm^-1'): def vibrational_analysis(masses, hessian, unit='cm^-1'):
"""Computing the vibrational wavenumbers from hessian.""" """Computing the vibrational wavenumbers from hessian."""
if unit != 'cm^-1': if unit != 'cm^-1':
...@@ -292,7 +298,7 @@ def vibrational_analysis(masses, hessian, unit='cm^-1'): ...@@ -292,7 +298,7 @@ def vibrational_analysis(masses, hessian, unit='cm^-1'):
# converting from sqrt(hartree / (amu * angstrom^2)) to cm^-1 # converting from sqrt(hartree / (amu * angstrom^2)) to cm^-1
wavenumbers = frequencies * 17092 wavenumbers = frequencies * 17092
modes = (eigenvectors.t() * inv_sqrt_mass).reshape(frequencies.numel(), -1, 3) modes = (eigenvectors.t() * inv_sqrt_mass).reshape(frequencies.numel(), -1, 3)
return wavenumbers, modes return FreqsModes(wavenumbers, modes)
__all__ = ['pad', 'pad_atomic_properties', 'present_species', 'hessian', __all__ = ['pad', 'pad_atomic_properties', 'present_species', 'hessian',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment