Unverified Commit e30b98a8 authored by Farhad Ramezanghorbani's avatar Farhad Ramezanghorbani Committed by GitHub
Browse files

compute atomic self energies for a given data set prior to training (#255)

* using least-squares to compute atomic self energies from the dataset

* self atomic energy calculation in the example training file
parent cecaf992
......@@ -30,7 +30,10 @@ device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
###############################################################################
# Now let's setup constants and construct an AEV computer. These numbers could
# be found in `rHCNO-5.2R_16-3.5A_a4-8.params`_ and `sae_linfit.dat`_
# be found in `rHCNO-5.2R_16-3.5A_a4-8.params`_.
# The atomic self energies given in `sae_linfit.dat`_ are computed from the
# ANI-1x dataset. These constants can be calculated for any given dataset by
# passing ``None`` as the argument when constructing :class:`EnergyShifter`.
#
# .. note::
#
......@@ -52,12 +55,7 @@ EtaA = torch.tensor([8.0000000e+00], device=device)
# Angular shift values, created on the selected device.
ShfA = torch.tensor([9.0000000e-01, 1.5500000e+00, 2.2000000e+00, 2.8500000e+00], device=device)
num_species = 4
# Build the AEV computer from the constants above (Rcr, Rca, EtaR, ShfR, Zeta
# and ShfZ are presumably defined earlier in this script -- not visible here).
aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species)
# NOTE(review): this hard-coded shifter is rebound by the
# ``EnergyShifter(None)`` assignment right below, so these values are never
# used. It looks like the pre-change side of the diff -- confirm which of the
# two assignments is meant to stay.
energy_shifter = torchani.utils.EnergyShifter([
-0.600952980000, # H
-38.08316124000, # C
-54.70775770000, # N
-75.19446356000, # O
])
# Passing ``None`` leaves the self energies unset so that they can be
# calculated from the dataset instead of being hard-coded.
energy_shifter = torchani.utils.EnergyShifter(None)
# Converter built for the symbols 'HCNO'; it is passed to the dataset loader.
species_to_tensor = torchani.utils.ChemicalSymbolsToInts('HCNO')
......@@ -85,6 +83,7 @@ training, validation = torchani.data.load_ani_dataset(
dspath, species_to_tensor, batch_size, device=device,
transform=[energy_shifter.subtract_from_dataset], split=[0.8, None])
print('H,C,N,O self energies: ', energy_shifter.self_energies)
###############################################################################
# When iterating the dataset, we will get pairs of input and output
# ``(species_coordinates, properties)``, where ``species_coordinates`` is the
......
import torch
import torch.utils.data
import math
import numpy as np
from collections import defaultdict
......@@ -142,9 +143,28 @@ class EnergyShifter(torch.nn.Module):
def __init__(self, self_energies):
    """Store the atomic self energies as a module buffer.

    Arguments:
        self_energies: sequence of atomic self energies, or ``None`` to
            leave them unset so that they can be fitted from a dataset
            later (see :meth:`sae_from_dataset`).
    """
    super(EnergyShifter, self).__init__()
    # ``torch.tensor(None)`` raises, so only real values are converted;
    # ``None`` is registered as-is and acts as a "not computed yet" marker.
    if self_energies is not None:
        self_energies = torch.tensor(self_energies, dtype=torch.double)
    self.register_buffer('self_energies', self_energies)
@staticmethod
def sae_from_dataset(atomic_properties, properties):
    """Compute atomic self energies from a dataset by linear least squares.

    Every molecule contributes one linear equation: the number of atoms of
    each present species, weighted by that species' unknown self energy,
    should add up to the molecule's energy.

    Arguments:
        atomic_properties: mapping whose ``'species'`` entry is a tensor of
            species indices; the code sums over dimension 1, so it is
            presumably laid out as (molecule, atom) -- confirm with caller.
        properties: mapping whose ``'energies'`` entry is a tensor with one
            energy per molecule.

    Returns:
        A 1-D :class:`numpy.ndarray` of fitted self energies, ordered like
        the output of ``present_species(species)``.
    """
    species = atomic_properties['species']
    energies = properties['energies']
    present_species_ = present_species(species)
    # counts[i, j] is the number of atoms in molecule i whose species equals
    # present_species_[j].
    counts = (species.unsqueeze(-1) == present_species_).sum(dim=1)
    # NumPy cannot implicitly read tensors that live off the CPU, so move
    # the operands over explicitly before solving.
    X = counts.to(torch.double).detach().cpu().numpy()
    y = energies.unsqueeze(dim=-1).detach().cpu().numpy()
    coeff_, _, _, _ = np.linalg.lstsq(X, y, rcond=None)
    # Drop only the trailing column axis: a bare ``squeeze()`` would turn a
    # single-species result into a 0-d array that cannot be indexed.
    return coeff_.squeeze(-1)
def sae(self, species):
"""Compute self energies for molecules.
......@@ -166,6 +186,10 @@ class EnergyShifter(torch.nn.Module):
"""Transformer for :class:`torchani.data.BatchedANIDataset` that
subtract self energies.
"""
if self.self_energies is None:
self_energies = self.sae_from_dataset(atomic_properties, properties)
self.self_energies = torch.tensor(self_energies, dtype=torch.double)
species = atomic_properties['species']
energies = properties['energies']
device = energies.device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment