Unverified Commit 9cae6d3f authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Docs improvements (#77)

parent 3cced1e6
# -*- coding: utf-8 -*-
"""Helpers for working with ignite."""
import torch import torch
from . import utils from . import utils
from torch.nn.modules.loss import _Loss from torch.nn.modules.loss import _Loss
...@@ -6,11 +9,26 @@ from ignite.metrics import RootMeanSquaredError ...@@ -6,11 +9,26 @@ from ignite.metrics import RootMeanSquaredError
class Container(torch.nn.ModuleDict): class Container(torch.nn.ModuleDict):
"""Each minibatch is splitted into chunks, as explained in the docstring of
:class:`torchani.data.BatchedANIDataset`, as a result, it is impossible to
use :class:`torchani.AEVComputer`, :class:`torchani.ANIModel` directly with
ignite. This class is designed to solve this issue.
Arguments:
modules (:class:`collections.abc.Mapping`): same as the argument in
:class:`torch.nn.ModuleDict`.
"""
def __init__(self, modules): def __init__(self, modules):
super(Container, self).__init__(modules) super(Container, self).__init__(modules)
def forward(self, species_coordinates): def forward(self, species_coordinates):
"""Takes sequence of species, coordinates pair as input, and returns
computed properties as a dictionary. Same property from different
chunks will be concatenated to form a single tensor for a batch. The
input, i.e. species and coordinates of chunks, will also be batched by
:func:`torchani.utils.pad_and_batch` and copied to output.
"""
results = {k: [] for k in self} results = {k: [] for k in self}
for sc in species_coordinates: for sc in species_coordinates:
for k in self: for k in self:
...@@ -24,6 +42,11 @@ class Container(torch.nn.ModuleDict): ...@@ -24,6 +42,11 @@ class Container(torch.nn.ModuleDict):
class DictLoss(_Loss): class DictLoss(_Loss):
"""Since :class:`Container` output dictionaries, losses defined in
:attr:`torch.nn` needs to be wrapped before used. This class wraps losses
that directly work on tensors with a key by calling the wrapped loss on the
associated value of that key.
"""
def __init__(self, key, loss): def __init__(self, key, loss):
super(DictLoss, self).__init__() super(DictLoss, self).__init__()
...@@ -34,7 +57,11 @@ class DictLoss(_Loss): ...@@ -34,7 +57,11 @@ class DictLoss(_Loss):
return self.loss(input[self.key], other[self.key]) return self.loss(input[self.key], other[self.key])
class _PerAtomDictLoss(DictLoss): class PerAtomDictLoss(DictLoss):
"""Similar to :class:`DictLoss`, but scale the loss values by the number of
atoms for each structure. The `loss` argument must be set to not to reduce
by the caller. Currently the only reduce operation supported is averaging.
"""
def forward(self, input, other): def forward(self, input, other):
loss = self.loss(input[self.key], other[self.key]) loss = self.loss(input[self.key], other[self.key])
...@@ -45,6 +72,7 @@ class _PerAtomDictLoss(DictLoss): ...@@ -45,6 +72,7 @@ class _PerAtomDictLoss(DictLoss):
class DictMetric(Metric): class DictMetric(Metric):
"""Similar to :class:`DictLoss`, but this is for metric, not loss."""
def __init__(self, key, metric): def __init__(self, key, metric):
self.key = key self.key = key
...@@ -63,13 +91,15 @@ class DictMetric(Metric): ...@@ -63,13 +91,15 @@ class DictMetric(Metric):
def MSELoss(key, per_atom=True):
    """Create an MSE loss on the specified dictionary key.

    Arguments:
        key (str): Key of the property the loss is computed on.
        per_atom (bool): If ``True``, return a loss that scales each
            structure's loss value by its number of atoms before reducing.

    Returns:
        :class:`torch.nn.modules.loss._Loss`: a :class:`PerAtomDictLoss`
        (per-atom scaled) or a plain :class:`DictLoss` wrapping
        :class:`torch.nn.MSELoss`.
    """
    if per_atom:
        # reduction='none' keeps per-element loss values so PerAtomDictLoss
        # can divide by atom counts itself before averaging.
        return PerAtomDictLoss(key, torch.nn.MSELoss(reduction='none'))
    return DictLoss(key, torch.nn.MSELoss())
class TransformedLoss(_Loss): class TransformedLoss(_Loss):
"""Do a transformation on loss values."""
def __init__(self, origin, transform): def __init__(self, origin, transform):
super(TransformedLoss, self).__init__() super(TransformedLoss, self).__init__()
...@@ -81,4 +111,5 @@ class TransformedLoss(_Loss): ...@@ -81,4 +111,5 @@ class TransformedLoss(_Loss):
def RMSEMetric(key):
    """Create an RMSE metric on the specified dictionary key.

    Arguments:
        key (str): Key of the property the metric is computed on.

    Returns:
        DictMetric: wrapper applying
        :class:`ignite.metrics.RootMeanSquaredError` to the value stored
        under ``key``.
    """
    return DictMetric(key, RootMeanSquaredError())
# -*- coding: utf-8 -*-
"""Tools for loading NeuroChem input files."""
import pkg_resources import pkg_resources
import torch import torch
import os import os
...@@ -5,12 +8,15 @@ import bz2 ...@@ -5,12 +8,15 @@ import bz2
import lark import lark
import struct import struct
from collections.abc import Mapping from collections.abc import Mapping
from .models import ANIModel, Ensemble from .nn import ANIModel, Ensemble
from .utils import EnergyShifter from .utils import EnergyShifter
from .aev import AEVComputer from .aev import AEVComputer
class Constants(Mapping): class Constants(Mapping):
"""NeuroChem constants. Objects of this class can be used as arguments
to :class:`torchani.AEVComputer`, like ``torchani.AEVComputer(**consts)``.
"""
def __init__(self, filename): def __init__(self, filename):
self.filename = filename self.filename = filename
...@@ -57,12 +63,14 @@ class Constants(Mapping): ...@@ -57,12 +63,14 @@ class Constants(Mapping):
return getattr(self, item) return getattr(self, item)
def species_to_tensor(self, species):
    """Convert species from a sequence of chemical-symbol strings to a
    1D long tensor of the corresponding integer indices.

    Arguments:
        species (:class:`collections.abc.Sequence`): sequence of strings
            of chemical symbols; each must be a key of ``self.rev_species``.

    Returns:
        :class:`torch.Tensor`: 1D tensor of dtype ``torch.long`` holding
        the index of each symbol.
    """
    indices = [self.rev_species[symbol] for symbol in species]
    return torch.tensor(indices, dtype=torch.long)
def load_sae(filename): def load_sae(filename):
"""Load self energies from NeuroChem sae file""" """Returns an object of :class:`EnergyShifter` with self energies from
NeuroChem sae file"""
self_energies = [] self_energies = []
with open(filename) as f: with open(filename) as f:
for i in f: for i in f:
...@@ -75,20 +83,8 @@ def load_sae(filename): ...@@ -75,20 +83,8 @@ def load_sae(filename):
def load_atomic_network(filename): def load_atomic_network(filename):
"""Load atomic network from NeuroChem's .nnf, .wparam and .bparam files """Returns an instance of :class:`torch.nn.Sequential` with hyperparameters
and parameters loaded NeuroChem's .nnf, .wparam and .bparam files."""
Parameters
----------
filename : string
The file name for the `.nnf` file that store network
hyperparameters. The `.bparam` and `.wparam` must be
in the same directory
Returns
-------
torch.nn.Sequential
The loaded atomic network
"""
def decompress_nnf(buffer): def decompress_nnf(buffer):
while buffer[0] != b'='[0]: while buffer[0] != b'='[0]:
...@@ -227,15 +223,33 @@ def load_atomic_network(filename): ...@@ -227,15 +223,33 @@ def load_atomic_network(filename):
return torch.nn.Sequential(*layers) return torch.nn.Sequential(*layers)
def load_model(species, dir):
    """Returns an instance of :class:`torchani.ANIModel` loaded from
    NeuroChem's network directory.

    Arguments:
        species (:class:`collections.abc.Sequence`): Sequence of strings for
            chemical symbols of each supported atom type in correct order.
        dir (str): String for directory storing network configurations.
    """
    # NOTE(review): `dir` shadows the builtin, but renaming the parameter
    # would break keyword-argument callers, so the name is kept.
    atomic_networks = []
    for symbol in species:
        nnf_path = os.path.join(dir, 'ANN-{}.nnf'.format(symbol))
        atomic_networks.append(load_atomic_network(nnf_path))
    return ANIModel(atomic_networks)
def load_model_ensemble(species, prefix, count): def load_model_ensemble(species, prefix, count):
"""Returns an instance of :class:`torchani.Ensemble` loaded from
NeuroChem's network directories beginning with the given prefix.
Arguments:
species (:class:`collections.abc.Sequence`): Sequence of strings for
chemical symbols of each supported atom type in correct order.
prefix (str): Prefix of paths of directory that networks configurations
are stored.
count (int): Number of models in the ensemble.
"""
models = [] models = []
for i in range(count): for i in range(count):
network_dir = os.path.join('{}{}'.format(prefix, i), 'networks') network_dir = os.path.join('{}{}'.format(prefix, i), 'networks')
...@@ -244,6 +258,21 @@ def load_model_ensemble(species, prefix, count): ...@@ -244,6 +258,21 @@ def load_model_ensemble(species, prefix, count):
class Buildins: class Buildins:
"""Container for all builtin stuffs.
Attributes:
const_file (:class:`str`): Path to the builtin constant file.
consts (:class:`Constants`): Constants loaded from builtin constant
file.
aev_computer (:class:`torchani.AEVComputer`): AEV computer with builtin
constants.
sae_file (:class:`str`): Path to the builtin self atomic energy file.
energy_shifter (:class:`torchani.EnergyShifter`): AEV computer with
builtin constants.
ensemble_size (:class:`int`): Number of models in model ensemble.
ensemble_prefix (:class:`str`): Prefix of directories of models.
models (:class:`torchani.Ensemble`): Ensemble of models.
"""
def __init__(self): def __init__(self):
self.const_file = pkg_resources.resource_filename( self.const_file = pkg_resources.resource_filename(
...@@ -264,4 +293,6 @@ class Buildins: ...@@ -264,4 +293,6 @@ class Buildins:
self.ensemble_size) self.ensemble_size)
buildins = Buildins() class Trainer:
"""NeuroChem training configurations"""
pass
...@@ -3,47 +3,37 @@ from . import utils ...@@ -3,47 +3,37 @@ from . import utils
class ANIModel(torch.nn.ModuleList): class ANIModel(torch.nn.ModuleList):
"""ANI model that compute properties from species and AEVs.
Different atom types might have different modules, when computing
properties, for each atom, the module for its corresponding atom type will
be applied to its AEV, after that, outputs of modules will be reduced along
different atoms to obtain molecular properties.
Arguments:
modules (:class:`collections.abc.Sequence`): Modules for each atom
types. Atom types are distinguished by their order in
:attr:`modules`, which means, for example ``modules[i]`` must be
the module for atom type ``i``. Different atom types can share a
module by putting the same reference in :attr:`modules`.
reducer (:class:`collections.abc.Callable`): The callable that reduce
atomic outputs into molecular outputs. It must have signature
``(tensor, dim)->tensor``.
padding_fill (float): The value to fill output of padding atoms.
Padding values will participate in reducing, so this value should
be appropriately chosen so that it has no effect on the result. For
example, if the reducer is :func:`torch.sum`, then
:attr:`padding_fill` should be 0, and if the reducer is
:func:`torch.min`, then :attr:`padding_fill` should be
:obj:`math.inf`.
"""
def __init__(self, modules, reducer=torch.sum, padding_fill=0): def __init__(self, modules, reducer=torch.sum, padding_fill=0):
"""
Parameters
----------
modules : seq(torch.nn.Module)
Modules for all species.
reducer : function
Function of (input, dim)->output that reduce the input tensor along
the given dimension to get an output tensor. This function will be
called with the per atom output tensor with internal shape as input
, and desired reduction dimension as dim, and should reduce the
input into the tensor containing desired output.
padding_fill : float
Default value used to fill padding atoms
"""
super(ANIModel, self).__init__(modules) super(ANIModel, self).__init__(modules)
self.reducer = reducer self.reducer = reducer
self.padding_fill = padding_fill self.padding_fill = padding_fill
def forward(self, species_aev): def forward(self, species_aev):
"""Compute output from aev
Parameters
----------
(species, aev)
species : torch.Tensor
Tensor storing the species for each atom.
aev : torch.Tensor
Pytorch tensor of shape (conformations, atoms, aev_length) storing
the computed AEVs.
Returns
-------
(species, output)
species : torch.Tensor
Tensor storing the species for each atom.
output : torch.Tensor
Pytorch tensor of shape (conformations, output_length) for the
output of each conformation.
"""
species, aev = species_aev species, aev = species_aev
species_ = species.flatten() species_ = species.flatten()
present_species = utils.present_species(species) present_species = utils.present_species(species)
...@@ -60,8 +50,9 @@ class ANIModel(torch.nn.ModuleList): ...@@ -60,8 +50,9 @@ class ANIModel(torch.nn.ModuleList):
class Ensemble(torch.nn.ModuleList):
    """Compute the average output of an ensemble of modules.

    Every member module receives the same ``(species, input)`` pair; the
    species tensor is passed through unchanged and the members' outputs
    are reduced by arithmetic mean.
    """

    def forward(self, species_input):
        species, _ = species_input
        member_outputs = [module(species_input)[1] for module in self]
        averaged = sum(member_outputs) / len(member_outputs)
        return species, averaged
...@@ -2,6 +2,22 @@ import torch ...@@ -2,6 +2,22 @@ import torch
def pad_and_batch(species_coordinates): def pad_and_batch(species_coordinates):
"""Put different species and coordinates together into single tensor.
If the species and coordinates are from molecules of different number of
total atoms, then ghost atoms with atom type -1 and coordinate (0, 0, 0)
will be added to make it fit into the same shape.
Arguments:
species_coordinates (:class:`collections.abc.Sequence`): sequence of
pairs of species and coordinates. Species must be of shape
``(N, A)`` and coordinates must be of shape ``(N, A, 3)``, where
``N`` is the number of 3D structures, ``A`` is the number of atoms.
Returns:
(:class:`torch.Tensor`, :class:`torch.Tensor`): Species, and
coordinates batched together.
"""
max_atoms = max([c.shape[1] for _, c in species_coordinates]) max_atoms = max([c.shape[1] for _, c in species_coordinates])
species = [] species = []
coordinates = [] coordinates = []
...@@ -23,6 +39,14 @@ def pad_and_batch(species_coordinates): ...@@ -23,6 +39,14 @@ def pad_and_batch(species_coordinates):
def present_species(species): def present_species(species):
"""Given a vector of species of atoms, compute the unique species present.
Arguments:
species (:class:`torch.Tensor`): 1D vector of shape ``(atoms,)``
Returns:
:class:`torch.Tensor`: 1D vector storing present atom types sorted.
"""
present_species = species.flatten().unique(sorted=True) present_species = species.flatten().unique(sorted=True)
if present_species[0].item() == -1: if present_species[0].item() == -1:
present_species = present_species[1:] present_species = present_species[1:]
...@@ -30,6 +54,18 @@ def present_species(species): ...@@ -30,6 +54,18 @@ def present_species(species):
def strip_redundant_padding(species, coordinates): def strip_redundant_padding(species, coordinates):
"""Strip trailing padding atoms.
Arguments:
species (:class:`torch.Tensor`): Long tensor of shape
``(conformations, atoms)``.
coordinates (:class:`torch.Tensor`): Tensor of shape
``(conformations, atoms, 3)``.
Returns:
(:class:`torch.Tensor`, :class:`torch.Tensor`): species and coordinates
with redundant padding atoms stripped.
"""
non_padding = (species >= 0).any(dim=0).nonzero().squeeze() non_padding = (species >= 0).any(dim=0).nonzero().squeeze()
species = species.index_select(1, non_padding) species = species.index_select(1, non_padding)
coordinates = coordinates.index_select(1, non_padding) coordinates = coordinates.index_select(1, non_padding)
...@@ -37,6 +73,16 @@ def strip_redundant_padding(species, coordinates): ...@@ -37,6 +73,16 @@ def strip_redundant_padding(species, coordinates):
class EnergyShifter(torch.nn.Module): class EnergyShifter(torch.nn.Module):
"""Helper class for adding and subtracting self atomic energies
This is a subclass of :class:`torch.nn.Module`, so it can be used directly
in a pipeline as ``[input->AEVComputer->ANIModel->EnergyShifter->output]``.
Arguments:
self_energies (:class:`collections.abc.Sequence`): Sequence of floating
numbers for the self energy of each atom type. The numbers should
be in order, i.e. ``self_energies[i]`` should be atom type ``i``.
"""
def __init__(self, self_energies): def __init__(self, self_energies):
super(EnergyShifter, self).__init__() super(EnergyShifter, self).__init__()
...@@ -44,11 +90,26 @@ class EnergyShifter(torch.nn.Module): ...@@ -44,11 +90,26 @@ class EnergyShifter(torch.nn.Module):
self.register_buffer('self_energies', self_energies) self.register_buffer('self_energies', self_energies)
def sae(self, species):
    """Compute self energies for molecules.

    Padding atoms (species ``-1``) are automatically excluded.

    Arguments:
        species (:class:`torch.Tensor`): Long tensor in shape
            ``(conformations, atoms)``.

    Returns:
        :class:`torch.Tensor`: 1D vector in shape ``(conformations,)``
        for molecular self energies.
    """
    # Fancy indexing: a padding value of -1 picks the last table entry,
    # which is immediately zeroed out below so it never contributes.
    per_atom = self.self_energies[species]
    per_atom[species == -1] = 0
    return per_atom.sum(dim=1)
def subtract_from_dataset(self, species, coordinates, properties): def subtract_from_dataset(self, species, coordinates, properties):
"""Transformer for :class:`torchani.data.BatchedANIDataset` that
subtract self energies.
"""
energies = properties['energies'] energies = properties['energies']
device = energies.device device = energies.device
energies = energies.to(torch.double) - self.sae(species).to(device) energies = energies.to(torch.double) - self.sae(species).to(device)
...@@ -56,6 +117,8 @@ class EnergyShifter(torch.nn.Module): ...@@ -56,6 +117,8 @@ class EnergyShifter(torch.nn.Module):
return species, coordinates, properties return species, coordinates, properties
def forward(self, species_energies):
    """(species, molecular energies)->(species, molecular energies + sae)

    Adds the molecular self energies back onto the given energies,
    converted to the energies' dtype and device.
    """
    species, energies = species_energies
    shift = self.sae(species).to(energies.dtype).to(energies.device)
    return species, energies + shift
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment