Commit 1888d734, authored by Farhad Ramezanghorbani, committed by Gao, Xiang
Browse files

calculate intercept when fitting, discard outliers from dataset (#263)

parent ac736c33
......@@ -216,8 +216,8 @@ class BatchedANIDataset(PaddedBatchChunkDataset):
def load_ani_dataset(path, species_tensor_converter, batch_size, shuffle=True,
properties=('energies',), atomic_properties=(), transform=(),
dtype=torch.get_default_dtype(), device=default_device,
rm_outlier=False, properties=('energies',), atomic_properties=(),
transform=(), dtype=torch.get_default_dtype(), device=default_device,
split=(None,)):
"""Load ANI dataset from hdf5 files, and split into subsets.
......@@ -255,6 +255,8 @@ def load_ani_dataset(path, species_tensor_converter, batch_size, shuffle=True,
batch_size (int): Number of different 3D structures in a single
minibatch.
shuffle (bool): Whether to shuffle the whole dataset.
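rm_outlier (bool): Whether to discard outlier-energy conformers from
the dataset. After all transforms are applied, a conformer is kept
only if ``|E| / sqrt(number of atoms)`` is below
``mean(E) + 15 * std(E)``; requires ``'energies'`` in `properties`.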
properties (list): List of keys of `molecular` properties in the
dataset to be loaded. Here `molecular` means that, regardless of the
number of atoms, the property always has a fixed size, i.e. the tensor
......@@ -298,13 +300,32 @@ def load_ani_dataset(path, species_tensor_converter, batch_size, shuffle=True,
atomic_properties_, properties_ = load_and_pad_whole_dataset(
path, species_tensor_converter, shuffle, properties, atomic_properties)
molecules = atomic_properties_['species'].shape[0]
atomic_keys = ['species', 'coordinates', *atomic_properties]
keys = properties
# do transformations on data
for t in transform:
atomic_properties_, properties_ = t(atomic_properties_, properties_)
molecules = atomic_properties_['species'].shape[0]
atomic_keys = ['species', 'coordinates', *atomic_properties]
keys = properties
if rm_outlier:
transformed_energies = properties_['energies']
num_atoms = (atomic_properties_['species'] >= 0).sum(dim=1).to(transformed_energies.dtype)
scaled_diff = transformed_energies / num_atoms.sqrt()
mean = transformed_energies.mean()
std = transformed_energies.std()
tol = 15.0 * std + mean
low_idx = (torch.abs(scaled_diff) < tol).nonzero().squeeze()
outlier_count = molecules - low_idx.numel()
# discard outlier energy conformers if exist
if outlier_count > 0:
print(f'Note: {outlier_count} outlier energy conformers have been discarded from dataset')
for key, val in atomic_properties_.items():
atomic_properties_[key] = val[low_idx]
for key, val in properties_.items():
properties_[key] = val[low_idx]
# compute size of each subset
split_ = []
......
......@@ -142,18 +142,20 @@ class EnergyShifter(torch.nn.Module):
self_energies (:class:`collections.abc.Sequence`): Sequence of floating
numbers for the self energy of each atom type. The numbers should
be in order, i.e. ``self_energies[i]`` should be atom type ``i``.
fit_intercept (bool): Whether to calculate the intercept during the
least-squares fit. When enabled, the intercept is stored as the last
element of ``self_energies`` and is added to every shifted energy.
"""
def __init__(self, self_energies):
def __init__(self, self_energies, fit_intercept=False):
super(EnergyShifter, self).__init__()
self.fit_intercept = fit_intercept
if self_energies is not None:
self_energies = torch.tensor(self_energies, dtype=torch.double)
self.register_buffer('self_energies', self_energies)
@staticmethod
def sae_from_dataset(atomic_properties, properties):
def sae_from_dataset(self, atomic_properties, properties):
"""Compute atomic self energies from dataset.
Least-squares solution to a linear equation is calculated to output
......@@ -164,6 +166,9 @@ class EnergyShifter(torch.nn.Module):
energies = properties['energies']
present_species_ = present_species(species)
X = (species.unsqueeze(-1) == present_species_).sum(dim=1).to(torch.double)
# Concatenate a vector of ones to find fit intercept
if self.fit_intercept:
X = torch.cat((X, torch.ones(X.shape[0], 1).to(torch.double)), dim=-1)
y = energies.unsqueeze(dim=-1)
coeff_, _, _, _ = np.linalg.lstsq(X, y, rcond=None)
return coeff_.squeeze()
......@@ -181,9 +186,13 @@ class EnergyShifter(torch.nn.Module):
:class:`torch.Tensor`: 1D vector in shape ``(conformations,)``
for molecular self energies.
"""
intercept = 0.0
if self.fit_intercept:
intercept = self.self_energies[-1]
self_energies = self.self_energies[species]
self_energies[species == -1] = 0
return self_energies.sum(dim=1)
return self_energies.sum(dim=1) + intercept
def subtract_from_dataset(self, atomic_properties, properties):
"""Transformer for :class:`torchani.data.BatchedANIDataset` that
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment