Commit 004f5a52 authored by Gao, Xiang's avatar Gao, Xiang Committed by Farhad Ramezanghorbani
Browse files

Use namedtuple to improve API while still maintaining backward compatibility (#380)

* Use namedtuple to improve API

* improve
parent 92c307dc
......@@ -16,7 +16,6 @@ calculator.
###############################################################################
# To begin with, let's first import the modules we will use:
from __future__ import print_function
from ase.lattice.cubic import Diamond
from ase.md.langevin import Langevin
from ase.optimize import BFGS
......
......@@ -9,7 +9,6 @@ TorchANI and can be used directly.
###############################################################################
# To begin with, let's first import the modules we will use:
from __future__ import print_function
import torch
import torchani
......@@ -43,7 +42,7 @@ species = model.species_to_tensor('CHHHH').to(device).unsqueeze(0)
###############################################################################
# Now let's compute energy and force:
_, energy = model((species, coordinates))
energy = model((species, coordinates)).energies
derivative = torch.autograd.grad(energy.sum(), coordinates)[0]
force = -derivative
......
......@@ -46,9 +46,9 @@ species = model.species_to_tensor('CHHHH').unsqueeze(0)
###############################################################################
# And here is the result:
_, energies_ensemble = model((species, coordinates))
_, energies_single = model[0]((species, coordinates))
_, energies_ensemble_jit = loaded_compiled_model((species, coordinates))
_, energies_single_jit = loaded_compiled_model0((species, coordinates))
energies_ensemble = model((species, coordinates)).energies
energies_single = model[0]((species, coordinates)).energies
energies_ensemble_jit = loaded_compiled_model((species, coordinates)).energies
energies_single_jit = loaded_compiled_model0((species, coordinates)).energies
print('Ensemble energy, eager mode vs loaded jit:', energies_ensemble.item(), energies_ensemble_jit.item())
print('Single network energy, eager mode vs loaded jit:', energies_single.item(), energies_single_jit.item())
......@@ -75,7 +75,7 @@ methane = ase.Atoms('CHHHH', positions=coordinates.squeeze().detach().numpy())
###############################################################################
# Now let's compute energies using the ensemble directly:
_, energy = nnp1((species, coordinates))
energy = nnp1((species, coordinates)).energies
derivative = torch.autograd.grad(energy.sum(), coordinates)[0]
force = -derivative
print('Energy:', energy.item())
......@@ -89,7 +89,7 @@ print('Force:', methane.get_forces() / ase.units.Hartree)
###############################################################################
# We can do the same thing with the single model:
_, energy = nnp2((species, coordinates))
energy = nnp2((species, coordinates)).energies
derivative = torch.autograd.grad(energy.sum(), coordinates)[0]
force = -derivative
print('Energy:', energy.item())
......
......@@ -286,7 +286,7 @@ def validate():
true_energies = batch_y['energies']
predicted_energies = []
for chunk_species, chunk_coordinates in batch_x:
_, chunk_energies = model((chunk_species, chunk_coordinates))
chunk_energies = model((chunk_species, chunk_coordinates)).energies
predicted_energies.append(chunk_energies)
predicted_energies = torch.cat(predicted_energies)
total_mse += mse_sum(predicted_energies, true_energies).item()
......@@ -343,7 +343,7 @@ for _ in range(AdamW_scheduler.last_epoch + 1, max_epochs):
for chunk_species, chunk_coordinates in batch_x:
num_atoms.append((chunk_species >= 0).to(true_energies.dtype).sum(dim=1))
_, chunk_energies = model((chunk_species, chunk_coordinates))
chunk_energies = model((chunk_species, chunk_coordinates)).energies
predicted_energies.append(chunk_energies)
num_atoms = torch.cat(num_atoms)
......
......@@ -231,7 +231,7 @@ def validate():
true_energies = batch_y['energies']
predicted_energies = []
for chunk_species, chunk_coordinates in batch_x:
_, chunk_energies = model((chunk_species, chunk_coordinates))
chunk_energies = model((chunk_species, chunk_coordinates)).energies
predicted_energies.append(chunk_energies)
predicted_energies = torch.cat(predicted_energies)
total_mse += mse_sum(predicted_energies, true_energies).item()
......@@ -299,7 +299,7 @@ for _ in range(AdamW_scheduler.last_epoch + 1, max_epochs):
# that we could compute force from it
chunk_coordinates.requires_grad_(True)
_, chunk_energies = model((chunk_species, chunk_coordinates))
chunk_energies = model((chunk_species, chunk_coordinates)).energies
# We can use torch.autograd.grad to compute force. Remember to
# create graph so that the loss of the force can contribute to
......
......@@ -54,7 +54,7 @@ masses = element_masses[species]
# To do vibration analysis, we first need to generate a graph that computes
# energies from species and coordinates. The code to generate a graph of energy
# is the same as the code to compute energy:
_, energies = model((species, coordinates))
energies = model((species, coordinates)).energies
###############################################################################
# We can now use the energy graph to compute analytical Hessian matrix:
......
import torch
from torch import Tensor
import math
from typing import Tuple, Optional
from typing import Tuple, Optional, NamedTuple
from torch.jit import Final
class SpeciesAEV(NamedTuple):
species: Tensor
aevs: Tensor
def cutoff_cosine(distances: Tensor, cutoff: float) -> Tensor:
# assuming all elements in distances are smaller than cutoff
return 0.5 * torch.cos(distances * (math.pi / cutoff)) + 0.5
......@@ -356,7 +361,7 @@ class AEVComputer(torch.nn.Module):
return self.Rcr, self.EtaR, self.ShfR, self.Rca, self.ShfZ, self.EtaA, self.Zeta, self.ShfA
def forward(self, input_: Tuple[Tensor, Tensor], cell: Optional[Tensor] = None,
pbc: Optional[Tensor] = None) -> Tuple[Tensor, Tensor]:
pbc: Optional[Tensor] = None) -> SpeciesAEV:
"""Compute AEVs
Arguments:
......@@ -384,7 +389,7 @@ class AEVComputer(torch.nn.Module):
for that direction.
Returns:
tuple: Species and AEVs. species are the species from the input
NamedTuple: Species and AEVs. species are the species from the input
unchanged, and AEVs is a tensor of shape
``(C, A, self.aev_length())``
"""
......@@ -398,4 +403,5 @@ class AEVComputer(torch.nn.Module):
cutoff = max(self.Rcr, self.Rca)
shifts = compute_shifts(cell, pbc, cutoff)
return species, compute_aev(species, coordinates, cell, shifts, self.triu_index, self.constants(), self.sizes)
aev = compute_aev(species, coordinates, cell, shifts, self.triu_index, self.constants(), self.sizes)
return SpeciesAEV(species, aev)
......@@ -93,11 +93,11 @@ class Calculator(ase.calculators.calculator.Calculator):
strain_y = self.strain(cell, displacement_y, 1)
strain_z = self.strain(cell, displacement_z, 2)
cell = cell + strain_x + strain_y + strain_z
_, aev = self.aev_computer((species, coordinates), cell=cell, pbc=pbc)
aev = self.aev_computer((species, coordinates), cell=cell, pbc=pbc).aevs
else:
_, aev = self.aev_computer((species, coordinates))
aev = self.aev_computer((species, coordinates)).aevs
_, energy = self.nn((species, aev))
energy = self.nn((species, aev)).energies
energy *= ase.units.Hartree
self.results['energy'] = energy.item()
self.results['free_energy'] = energy.item()
......
import torch
from torch import Tensor
from typing import Tuple
from typing import Tuple, NamedTuple
class SpeciesEnergies(NamedTuple):
species: Tensor
energies: Tensor
class ANIModel(torch.nn.Module):
......@@ -26,7 +31,7 @@ class ANIModel(torch.nn.Module):
def __getitem__(self, i):
return self.module_list[i]
def forward(self, species_aev: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
def forward(self, species_aev: Tuple[Tensor, Tensor]) -> SpeciesEnergies:
species, aev = species_aev
species_ = species.flatten()
aev = aev.flatten(0, 1)
......@@ -40,7 +45,7 @@ class ANIModel(torch.nn.Module):
input_ = aev.index_select(0, midx)
output.masked_scatter_(mask, m(input_).flatten())
output = output.view_as(species)
return species, torch.sum(output, dim=1)
return SpeciesEnergies(species, torch.sum(output, dim=1))
class Ensemble(torch.nn.Module):
......@@ -51,12 +56,12 @@ class Ensemble(torch.nn.Module):
self.modules_list = torch.nn.ModuleList(modules)
self.size = len(self.modules_list)
def forward(self, species_input: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
def forward(self, species_input: Tuple[Tensor, Tensor]) -> SpeciesEnergies:
sum_ = 0
for x in self.modules_list:
sum_ += x(species_input)[1]
species, _ = species_input
return species, sum_ / self.size
return SpeciesEnergies(species, sum_ / self.size)
def __getitem__(self, i):
return self.modules_list[i]
......@@ -69,7 +74,7 @@ class Sequential(torch.nn.Module):
super(Sequential, self).__init__()
self.modules_list = torch.nn.ModuleList(modules)
def forward(self, input_: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
def forward(self, input_: Tuple[Tensor, Tensor]):
for module in self.modules_list:
input_ = module(input_)
return input_
......
......@@ -4,7 +4,8 @@ import torch.utils.data
import math
import numpy as np
from collections import defaultdict
from typing import Tuple
from typing import Tuple, NamedTuple
from .nn import SpeciesEnergies
def pad(species):
......@@ -211,12 +212,12 @@ class EnergyShifter(torch.nn.Module):
properties['energies'] = energies
return atomic_properties, properties
def forward(self, species_energies: Tuple[Tensor, Tensor]) -> Tuple[Tensor, Tensor]:
def forward(self, species_energies: Tuple[Tensor, Tensor]) -> SpeciesEnergies:
"""(species, molecular energies)->(species, molecular energies + sae)
"""
species, energies = species_energies
sae = self.sae(species).to(energies.device)
return species, energies.to(sae.dtype) + sae
return SpeciesEnergies(species, energies.to(sae.dtype) + sae)
class ChemicalSymbolsToInts:
......@@ -269,6 +270,11 @@ def hessian(coordinates, energies=None, forces=None):
], dim=1)
class FreqsModes(NamedTuple):
freqs: Tensor
modes: Tensor
def vibrational_analysis(masses, hessian, unit='cm^-1'):
"""Computing the vibrational wavenumbers from hessian."""
if unit != 'cm^-1':
......@@ -292,7 +298,7 @@ def vibrational_analysis(masses, hessian, unit='cm^-1'):
# converting from sqrt(hartree / (amu * angstrom^2)) to cm^-1
wavenumbers = frequencies * 17092
modes = (eigenvectors.t() * inv_sqrt_mass).reshape(frequencies.numel(), -1, 3)
return wavenumbers, modes
return FreqsModes(wavenumbers, modes)
__all__ = ['pad', 'pad_atomic_properties', 'present_species', 'hessian',
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment