Commit 1888d734 authored by Farhad Ramezanghorbani, committed by Gao, Xiang
Browse files

calculate intercept when fitting, discard outliers from dataset (#263)

parent ac736c33
...@@ -216,8 +216,8 @@ class BatchedANIDataset(PaddedBatchChunkDataset): ...@@ -216,8 +216,8 @@ class BatchedANIDataset(PaddedBatchChunkDataset):
def load_ani_dataset(path, species_tensor_converter, batch_size, shuffle=True, def load_ani_dataset(path, species_tensor_converter, batch_size, shuffle=True,
properties=('energies',), atomic_properties=(), transform=(), rm_outlier=False, properties=('energies',), atomic_properties=(),
dtype=torch.get_default_dtype(), device=default_device, transform=(), dtype=torch.get_default_dtype(), device=default_device,
split=(None,)): split=(None,)):
"""Load ANI dataset from hdf5 files, and split into subsets. """Load ANI dataset from hdf5 files, and split into subsets.
...@@ -255,6 +255,8 @@ def load_ani_dataset(path, species_tensor_converter, batch_size, shuffle=True, ...@@ -255,6 +255,8 @@ def load_ani_dataset(path, species_tensor_converter, batch_size, shuffle=True,
batch_size (int): Number of different 3D structures in a single batch_size (int): Number of different 3D structures in a single
minibatch. minibatch.
shuffle (bool): Whether to shuffle the whole dataset. shuffle (bool): Whether to shuffle the whole dataset.
rm_outlier (bool): Whether to discard the outlier energy conformers
from a given dataset.
properties (list): List of keys of `molecular` properties in the properties (list): List of keys of `molecular` properties in the
dataset to be loaded. Here `molecular` means, no matter the number dataset to be loaded. Here `molecular` means, no matter the number
of atoms that property always have fixed size, i.e. the tensor of atoms that property always have fixed size, i.e. the tensor
...@@ -298,13 +300,32 @@ def load_ani_dataset(path, species_tensor_converter, batch_size, shuffle=True, ...@@ -298,13 +300,32 @@ def load_ani_dataset(path, species_tensor_converter, batch_size, shuffle=True,
atomic_properties_, properties_ = load_and_pad_whole_dataset( atomic_properties_, properties_ = load_and_pad_whole_dataset(
path, species_tensor_converter, shuffle, properties, atomic_properties) path, species_tensor_converter, shuffle, properties, atomic_properties)
molecules = atomic_properties_['species'].shape[0]
atomic_keys = ['species', 'coordinates', *atomic_properties]
keys = properties
# do transformations on data # do transformations on data
for t in transform: for t in transform:
atomic_properties_, properties_ = t(atomic_properties_, properties_) atomic_properties_, properties_ = t(atomic_properties_, properties_)
molecules = atomic_properties_['species'].shape[0] if rm_outlier:
atomic_keys = ['species', 'coordinates', *atomic_properties] transformed_energies = properties_['energies']
keys = properties num_atoms = (atomic_properties_['species'] >= 0).sum(dim=1).to(transformed_energies.dtype)
scaled_diff = transformed_energies / num_atoms.sqrt()
mean = transformed_energies.mean()
std = transformed_energies.std()
tol = 15.0 * std + mean
low_idx = (torch.abs(scaled_diff) < tol).nonzero().squeeze()
outlier_count = molecules - low_idx.numel()
# discard outlier energy conformers if exist
if outlier_count > 0:
print(f'Note: {outlier_count} outlier energy conformers have been discarded from dataset')
for key, val in atomic_properties_.items():
atomic_properties_[key] = val[low_idx]
for key, val in properties_.items():
properties_[key] = val[low_idx]
# compute size of each subset # compute size of each subset
split_ = [] split_ = []
......
...@@ -142,18 +142,20 @@ class EnergyShifter(torch.nn.Module): ...@@ -142,18 +142,20 @@ class EnergyShifter(torch.nn.Module):
self_energies (:class:`collections.abc.Sequence`): Sequence of floating self_energies (:class:`collections.abc.Sequence`): Sequence of floating
numbers for the self energy of each atom type. The numbers should numbers for the self energy of each atom type. The numbers should
be in order, i.e. ``self_energies[i]`` should be atom type ``i``. be in order, i.e. ``self_energies[i]`` should be atom type ``i``.
fit_intercept (bool): Whether to calculate the intercept during the LSTSQ
fit. The intercept will also be taken into account to shift energies.
""" """
def __init__(self, self_energies): def __init__(self, self_energies, fit_intercept=False):
super(EnergyShifter, self).__init__() super(EnergyShifter, self).__init__()
self.fit_intercept = fit_intercept
if self_energies is not None: if self_energies is not None:
self_energies = torch.tensor(self_energies, dtype=torch.double) self_energies = torch.tensor(self_energies, dtype=torch.double)
self.register_buffer('self_energies', self_energies) self.register_buffer('self_energies', self_energies)
@staticmethod def sae_from_dataset(self, atomic_properties, properties):
def sae_from_dataset(atomic_properties, properties):
"""Compute atomic self energies from dataset. """Compute atomic self energies from dataset.
Least-squares solution to a linear equation is calculated to output Least-squares solution to a linear equation is calculated to output
...@@ -164,6 +166,9 @@ class EnergyShifter(torch.nn.Module): ...@@ -164,6 +166,9 @@ class EnergyShifter(torch.nn.Module):
energies = properties['energies'] energies = properties['energies']
present_species_ = present_species(species) present_species_ = present_species(species)
X = (species.unsqueeze(-1) == present_species_).sum(dim=1).to(torch.double) X = (species.unsqueeze(-1) == present_species_).sum(dim=1).to(torch.double)
# Concatenate a vector of ones to find fit intercept
if self.fit_intercept:
X = torch.cat((X, torch.ones(X.shape[0], 1).to(torch.double)), dim=-1)
y = energies.unsqueeze(dim=-1) y = energies.unsqueeze(dim=-1)
coeff_, _, _, _ = np.linalg.lstsq(X, y, rcond=None) coeff_, _, _, _ = np.linalg.lstsq(X, y, rcond=None)
return coeff_.squeeze() return coeff_.squeeze()
...@@ -181,9 +186,13 @@ class EnergyShifter(torch.nn.Module): ...@@ -181,9 +186,13 @@ class EnergyShifter(torch.nn.Module):
:class:`torch.Tensor`: 1D vector in shape ``(conformations,)`` :class:`torch.Tensor`: 1D vector in shape ``(conformations,)``
for molecular self energies. for molecular self energies.
""" """
intercept = 0.0
if self.fit_intercept:
intercept = self.self_energies[-1]
self_energies = self.self_energies[species] self_energies = self.self_energies[species]
self_energies[species == -1] = 0 self_energies[species == -1] = 0
return self_energies.sum(dim=1) return self_energies.sum(dim=1) + intercept
def subtract_from_dataset(self, atomic_properties, properties): def subtract_from_dataset(self, atomic_properties, properties):
"""Transformer for :class:`torchani.data.BatchedANIDataset` that """Transformer for :class:`torchani.data.BatchedANIDataset` that
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment