Unverified Commit 9cae6d3f authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Docs improvements (#77)

parent 3cced1e6
......@@ -18,6 +18,8 @@ After installing the correct PyTorch, all you need is clone the repository and d
pip install .
```
After TorchANI has been installed, you can build the documents by running `sphinx-build docs build`.
# Paper
The original ANI-1 paper is:
......
......@@ -31,3 +31,8 @@ steps:
- python examples/training-benchmark.py ./dataset/ani_gdb_s01.h5 # run twice to test if checkpoint is working
- python examples/energy_force.py
- python examples/neurochem-test.py ./dataset/ani_gdb_s01.h5
Docs:
image: '${{BuildTorchANI}}'
commands:
- sphinx-build docs build
# -*- coding: utf-8 -*-
#
# Configuration file for the Sphinx documentation builder.
#
# This file does only contain a selection of the most common options. For a
# full list see the documentation:
# http://www.sphinx-doc.org/en/master/config
# -- Path setup --------------------------------------------------------------
# If extensions (or modules to document with autodoc) are in another directory,
# add these directories to sys.path here. If the directory is relative to the
# documentation root, use os.path.abspath to make it absolute, like shown here.
#
import os
import sys
sys.path.insert(0, os.path.abspath('.'))
# -- Project information -----------------------------------------------------
project = 'torchani'
copyright = '2018, Xiang Gao'
author = 'Xiang Gao'
# The short X.Y version
version = ''
# The full version, including alpha/beta/rc tags
release = ''
# -- General configuration ---------------------------------------------------
# If your documentation needs a minimal Sphinx version, state it here.
#
# needs_sphinx = '1.0'
# Add any Sphinx extension module names here, as strings. They can be
# extensions coming with Sphinx (named 'sphinx.ext.*') or your custom
# ones.
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.intersphinx',
'sphinx.ext.todo',
'sphinx.ext.coverage',
'sphinx.ext.mathjax',
'sphinx.ext.ifconfig',
'sphinx.ext.viewcode',
'sphinx.ext.githubpages',
'sphinx.ext.napoleon',
]
# Add any paths that contain templates here, relative to this directory.
templates_path = ['sphinx_templates']
# The suffix(es) of source filenames.
# You can specify multiple suffix as a list of string:
#
# source_suffix = ['.rst', '.md']
source_suffix = '.md'
# The master toctree document.
master_doc = 'index'
# The language for content autogenerated by Sphinx. Refer to documentation
# for a list of supported languages.
#
# This is also used if you do content translation via gettext catalogs.
# Usually you set "language" from the command line for these cases.
language = None
# List of patterns, relative to source directory, that match files and
# directories to ignore when looking for source files.
# This pattern also affects html_static_path and html_extra_path .
exclude_patterns = []
# The name of the Pygments (syntax highlighting) style to use.
pygments_style = 'sphinx'
# -- Options for HTML output -------------------------------------------------
# The theme to use for HTML and HTML Help pages. See the documentation for
# a list of builtin themes.
#
html_theme = 'sphinx_rtd_theme'
# Theme options are theme-specific and customize the look and feel of a theme
# further. For a list of options available for each theme, see the
# documentation.
#
# html_theme_options = {}
# Add any paths that contain custom static files (such as style sheets) here,
# relative to this directory. They are copied after the builtin static files,
# so a file named "default.css" will overwrite the builtin "default.css".
html_static_path = ['sphinx_static']
# Custom sidebar templates, must be a dictionary that maps document names
# to template names.
#
# The default sidebars (for documents that don't match any pattern) are
# defined by theme itself. Builtin themes are using these templates by
# default: ``['localtoc.html', 'relations.html', 'sourcelink.html',
# 'searchbox.html']``.
#
# html_sidebars = {}
# -- Options for HTMLHelp output ---------------------------------------------
# Output file base name for HTML help builder.
htmlhelp_basename = 'torchanidoc'
# -- Options for LaTeX output ------------------------------------------------
latex_elements = {
# The paper size ('letterpaper' or 'a4paper').
#
# 'papersize': 'letterpaper',
# The font size ('10pt', '11pt' or '12pt').
#
# 'pointsize': '10pt',
# Additional stuff for the LaTeX preamble.
#
# 'preamble': '',
# Latex figure (float) alignment
#
# 'figure_align': 'htbp',
}
# Grouping the document tree into LaTeX files. List of tuples
# (source start file, target name, title,
# author, documentclass [howto, manual, or own class]).
latex_documents = [
(master_doc, 'torchani.tex', 'torchani Documentation',
'Xiang Gao', 'manual'),
]
# -- Options for manual page output ------------------------------------------
# One entry per manual page. List of tuples
# (source start file, name, description, authors, manual section).
man_pages = [
(master_doc, 'torchani', 'torchani Documentation',
[author], 1)
]
# -- Options for Texinfo output ----------------------------------------------
# Grouping the document tree into Texinfo files. List of tuples
# (source start file, target name, title, author,
# dir menu entry, description, category)
texinfo_documents = [
(master_doc, 'torchani', 'torchani Documentation',
author, 'torchani', 'One line description of project.',
'Miscellaneous'),
]
# -- Extension configuration -------------------------------------------------
# -- Options for intersphinx extension ---------------------------------------
# Example configuration for intersphinx: refer to the Python standard library.
intersphinx_mapping = {'https://docs.python.org/': None}
# -- Options for todo extension ----------------------------------------------
# If true, `todo` and `todoList` produce output, else they produce nothing.
todo_include_todos = True
autoclass_content = 'both'
import torchani # noqa: F401
import sphinx_rtd_theme
project = 'TorchANI'
copyright = '2018, Roitberg Group'
author = 'Xiang Gao'
version = '0.1'
release = '0.1alpha'
extensions = [
'sphinx.ext.autodoc',
'sphinx.ext.napoleon',
'sphinx.ext.intersphinx',
'sphinx.ext.mathjax',
'sphinx.ext.viewcode',
]
templates_path = ['_templates']
html_static_path = ['_static']
source_suffix = '.rst'
master_doc = 'index'
pygments_style = 'sphinx'
html_theme = 'sphinx_rtd_theme'
html_theme_path = [sphinx_rtd_theme.get_html_theme_path()]
htmlhelp_basename = 'TorchANIdoc'
intersphinx_mapping = {
'python': ('https://docs.python.org/', None),
'numpy': ('http://docs.scipy.org/doc/numpy/', None),
'torch': ('https://pytorch.org/docs/master/', None),
'ignite': ('https://pytorch.org/ignite/', None),
}
latex_documents = [
(master_doc, 'TorchANI.tex', 'TorchANI Documentation',
'Xiang Gao', 'manual'),
]
man_pages = [
(master_doc, 'torchani', 'TorchANI Documentation',
[author], 1)
]
texinfo_documents = [
(master_doc, 'TorchANI', 'TorchANI Documentation',
author, 'TorchANI', 'One line description of project.',
'Miscellaneous'),
]
TorchANI
========
.. automodule:: torchani
.. autoclass:: torchani.AEVComputer
:members:
.. autoclass:: torchani.ANIModel
.. autoclass:: torchani.Ensemble
.. autoclass:: torchani.EnergyShifter
:members:
Datasets
========
.. automodule:: torchani.data
.. autoclass:: torchani.data.BatchedANIDataset
Utilities
=========
.. automodule:: torchani.utils
.. autofunction:: torchani.utils.pad_and_batch
.. autofunction:: torchani.utils.present_species
.. autofunction:: torchani.utils.strip_redundant_padding
NeuroChem Importers
===================
.. automodule:: torchani.neurochem
.. autoclass:: torchani.neurochem.Constants
:members:
.. autofunction:: torchani.neurochem.load_sae
.. autofunction:: torchani.neurochem.load_atomic_network
.. autofunction:: torchani.neurochem.load_model
.. autofunction:: torchani.neurochem.load_model_ensemble
.. autoclass:: torchani.neurochem.Buildins
Ignite Helpers
==============
.. automodule:: torchani.ignite
.. autoclass:: torchani.ignite.Container
:members:
.. autoclass:: torchani.ignite.DictLoss
.. autoclass:: torchani.ignite.PerAtomDictLoss
.. autoclass:: torchani.ignite.TransformedLoss
.. autofunction:: torchani.ignite.MSELoss
.. autoclass:: torchani.ignite.DictMetric
.. autofunction:: torchani.ignite.RMSEMetric
......@@ -3,8 +3,10 @@ import torchani
import os
consts = torchani.buildins.consts
aev_computer = torchani.buildins.aev_computer
buildins = torchani.neurochem.Buildins()
consts = buildins.consts
aev_computer = buildins.aev_computer
shift_energy = buildins.energy_shifter
def atomic():
......
......@@ -5,6 +5,9 @@ import ignite
import pickle
import argparse
buildins = torchani.neurochem.Buildins()
# parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument('dataset_path',
......@@ -19,13 +22,13 @@ parser.add_argument('--batch_size',
default=1024, type=int)
parser.add_argument('--const_file',
help='File storing constants',
default=torchani.buildins.const_file)
default=buildins.const_file)
parser.add_argument('--sae_file',
help='File storing self atomic energies',
default=torchani.buildins.sae_file)
default=buildins.sae_file)
parser.add_argument('--network_dir',
help='Directory or prefix of directories storing networks',
default=torchani.buildins.ensemble_prefix + '0/networks')
default=buildins.ensemble_prefix + '0/networks')
parser = parser.parse_args()
# load modules and datasets
......
......@@ -52,15 +52,14 @@ writer = tensorboardX.SummaryWriter(log_dir=parser.log)
start = timeit.default_timer()
nnp = model.get_or_create_model(parser.model_checkpoint, device=device)
shift_energy = torchani.buildins.energy_shifter
training = torchani.data.BatchedANIDataset(
parser.training_path, model.consts.species_to_tensor,
parser.batch_size, device=device,
transform=[shift_energy.subtract_from_dataset])
transform=[model.shift_energy.subtract_from_dataset])
validation = torchani.data.BatchedANIDataset(
parser.validation_path, model.consts.species_to_tensor,
parser.batch_size, device=device,
transform=[shift_energy.subtract_from_dataset])
transform=[model.shift_energy.subtract_from_dataset])
container = torchani.ignite.Container({'energies': nnp})
parser.optim_args = json.loads(parser.optim_args)
......
......@@ -22,11 +22,10 @@ parser = parser.parse_args()
# set up benchmark
device = torch.device(parser.device)
nnp = model.get_or_create_model('/tmp/model.pt', device=device)
shift_energy = torchani.buildins.energy_shifter
dataset = torchani.data.BatchedANIDataset(
parser.dataset_path, model.consts.species_to_tensor,
parser.batch_size, device=device,
transform=[shift_energy.subtract_from_dataset])
transform=[model.shift_energy.subtract_from_dataset])
container = torchani.ignite.Container({'energies': nnp})
optimizer = torch.optim.Adam(nnp.parameters())
......@@ -66,16 +65,16 @@ def time_func(key, func):
# enable timers
nnp[0].radial_subaev_terms = time_func('radial terms',
nnp[0].radial_subaev_terms)
nnp[0].angular_subaev_terms = time_func('angular terms',
nnp[0].angular_subaev_terms)
nnp[0].terms_and_indices = time_func('terms and indices',
nnp[0].terms_and_indices)
nnp[0].combinations = time_func('combinations', nnp[0].combinations)
nnp[0].compute_mask_r = time_func('mask_r', nnp[0].compute_mask_r)
nnp[0].compute_mask_a = time_func('mask_a', nnp[0].compute_mask_a)
nnp[0].assemble = time_func('assemble', nnp[0].assemble)
nnp[0]._radial_subaev_terms = time_func('radial terms',
nnp[0]._radial_subaev_terms)
nnp[0]._angular_subaev_terms = time_func('angular terms',
nnp[0]._angular_subaev_terms)
nnp[0]._terms_and_indices = time_func('terms and indices',
nnp[0]._terms_and_indices)
nnp[0]._combinations = time_func('combinations', nnp[0]._combinations)
nnp[0]._compute_mask_r = time_func('mask_r', nnp[0]._compute_mask_r)
nnp[0]._compute_mask_a = time_func('mask_a', nnp[0]._compute_mask_a)
nnp[0]._assemble = time_func('assemble', nnp[0]._assemble)
nnp[0].forward = time_func('total', nnp[0].forward)
nnp[1].forward = time_func('forward', nnp[1].forward)
......
.. torchani documentation master file, created by
sphinx-quickstart on Tue May 1 00:00:05 2018.
You can adapt this file completely to your liking, but it should at least
contain the root `toctree` directive.
Welcome to torchani's documentation!
====================================
.. toctree::
:maxdepth: 2
:caption: Contents:
.. automodule:: torchani
:members:
:special-members: __init__, __call__
Indices and tables
==================
* :ref:`genindex`
* :ref:`modindex`
* :ref:`search`
......@@ -20,10 +20,4 @@ setup_attrs = {
'tests_require': ['nose'],
}
try:
from sphinx.setup_command import BuildDoc
setup_attrs['cmdclass'] = {'build_sphinx': BuildDoc}
except ModuleNotFoundError:
pass
setup(**setup_attrs)
......@@ -11,7 +11,8 @@ N = 97
class TestAEV(unittest.TestCase):
def setUp(self):
self.aev_computer = torchani.buildins.aev_computer
buildins = torchani.neurochem.Buildins()
self.aev_computer = buildins.aev_computer
self.radial_length = self.aev_computer.radial_length()
self.tolerance = 1e-5
......
......@@ -6,7 +6,8 @@ import unittest
path = os.path.dirname(os.path.realpath(__file__))
dataset_path = os.path.join(path, '../dataset')
batch_size = 256
consts = torchani.buildins.consts
buildins = torchani.neurochem.Buildins()
consts = buildins.consts
class TestData(unittest.TestCase):
......
......@@ -13,9 +13,10 @@ class TestEnergies(unittest.TestCase):
def setUp(self):
self.tolerance = 5e-5
aev_computer = torchani.buildins.aev_computer
nnp = torchani.buildins.models[0]
shift_energy = torchani.buildins.energy_shifter
buildins = torchani.neurochem.Buildins()
aev_computer = buildins.aev_computer
nnp = buildins.models[0]
shift_energy = buildins.energy_shifter
self.model = torch.nn.Sequential(aev_computer, nnp, shift_energy)
def testIsomers(self):
......
......@@ -15,9 +15,10 @@ class TestEnsemble(unittest.TestCase):
self.conformations = 20
def _test_molecule(self, coordinates, species):
buildins = torchani.neurochem.Buildins()
coordinates = torch.tensor(coordinates, requires_grad=True)
aev = torchani.buildins.aev_computer
ensemble = torchani.buildins.models
aev = buildins.aev_computer
ensemble = buildins.models
models = [torch.nn.Sequential(aev, m) for m in ensemble]
ensemble = torch.nn.Sequential(aev, ensemble)
......
......@@ -12,8 +12,9 @@ class TestForce(unittest.TestCase):
def setUp(self):
self.tolerance = 1e-5
aev_computer = torchani.buildins.aev_computer
nnp = torchani.buildins.models[0]
buildins = torchani.neurochem.Buildins()
aev_computer = buildins.aev_computer
nnp = buildins.models[0]
self.model = torch.nn.Sequential(aev_computer, nnp)
def testIsomers(self):
......
......@@ -16,11 +16,12 @@ threshold = 1e-5
class TestIgnite(unittest.TestCase):
def testIgnite(self):
aev_computer = torchani.buildins.aev_computer
nnp = copy.deepcopy(torchani.buildins.models[0])
shift_energy = torchani.buildins.energy_shifter
buildins = torchani.neurochem.Buildins()
aev_computer = buildins.aev_computer
nnp = copy.deepcopy(buildins.models[0])
shift_energy = buildins.energy_shifter
ds = torchani.data.BatchedANIDataset(
path, torchani.buildins.consts.species_to_tensor, batchsize,
path, buildins.consts.species_to_tensor, batchsize,
transform=[shift_energy.subtract_from_dataset])
ds = torch.utils.data.Subset(ds, [0])
......
# -*- coding: utf-8 -*-
"""TorchANI is a PyTorch implementation of `ANI`_, created and maintained by
the `Roitberg group`_. TorchANI contains classes like
:class:`AEVComputer`, :class:`ANIModel`, and :class:`EnergyShifter` that can
be pipelined to compute molecular energies from the 3D coordinates of
molecules. It also includes tools to: deal with ANI datasets (e.g. `ANI-1`_,
`ANI-1x`_, `ANI-1ccx`_, etc.) at :attr:`torchani.data`, import various file
formats of NeuroChem at :attr:`torchani.neurochem`, help working with ignite
at :attr:`torchani.ignite`, and more at :attr:`torchani.utils`.
.. _ANI:
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
.. _Roitberg group:
https://roitberg.chem.ufl.edu/
.. _ANI-1:
https://www.nature.com/articles/sdata2017193
.. _ANI-1x:
https://aip.scitation.org/doi/abs/10.1063/1.5023802
.. _ANI-1ccx:
https://doi.org/10.26434/chemrxiv.6744440.v1
"""
from .utils import EnergyShifter
from .models import ANIModel, Ensemble
from .nn import ANIModel, Ensemble
from .aev import AEVComputer
from . import ignite
from . import utils
from . import neurochem
from . import data
from .neurochem import buildins
__all__ = ['AEVComputer', 'EnergyShifter', 'ANIModel', 'Ensemble', 'buildins',
__all__ = ['AEVComputer', 'EnergyShifter', 'ANIModel', 'Ensemble',
'ignite', 'utils', 'neurochem', 'data']
......@@ -5,27 +5,6 @@ from . import utils
def _cutoff_cosine(distances, cutoff):
"""Compute the elementwise cutoff cosine function
The cutoff cosine function is define in
https://arxiv.org/pdf/1610.08935.pdf equation 2
Parameters
----------
distances : torch.Tensor
The pytorch tensor that stores Rij values. This tensor can
have any shape since the cutoff cosine function is computed
elementwise.
cutoff : float
The cutoff radius, i.e. the Rc in the equation. For any Rij > Rc,
the function value is defined to be zero.
Returns
-------
torch.Tensor
The tensor of the same shape as `distances` that stores the
computed function values.
"""
return torch.where(
distances <= cutoff,
0.5 * torch.cos(math.pi * distances / cutoff) + 0.5,
......@@ -34,16 +13,29 @@ def _cutoff_cosine(distances, cutoff):
class AEVComputer(torch.nn.Module):
"""AEV computer
Attributes
----------
filename : str
The name of the file that stores constant.
Rcr, Rca, EtaR, ShfR, Zeta, ShfZ, EtaA, ShfA : torch.Tensor
Tensor storing constants.
num_species : int
Number of supported atom types
r"""The AEV computer that takes coordinates as input and outputs aevs.
Arguments:
Rcr (:class:`torch.Tensor`): The scalar tensor of :math:`R_C` in
equation (2) when used at equation (3) in the `ANI paper`_.
Rca (:class:`torch.Tensor`): The scalar tensor of :math:`R_C` in
equation (2) when used at equation (4) in the `ANI paper`_.
EtaR (:class:`torch.Tensor`): The 1D tensor of :math:`\eta` in
equation (3) in the `ANI paper`_.
ShfR (:class:`torch.Tensor`): The 1D tensor of :math:`R_s` in
equation (3) in the `ANI paper`_.
EtaA (:class:`torch.Tensor`): The 1D tensor of :math:`\eta` in
equation (4) in the `ANI paper`_.
Zeta (:class:`torch.Tensor`): The 1D tensor of :math:`\zeta` in
equation (4) in the `ANI paper`_.
ShfA (:class:`torch.Tensor`): The 1D tensor of :math:`R_s` in
equation (4) in the `ANI paper`_.
ShfZ (:class:`torch.Tensor`): The 1D tensor of :math:`\theta_s` in
equation (4) in the `ANI paper`_.
num_species (int): Number of supported atom types.
.. _ANI paper:
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
"""
def __init__(self, Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ,
......@@ -85,26 +77,18 @@ class AEVComputer(torch.nn.Module):
"""Returns the length of full aev"""
return self.radial_length() + self.angular_length()
def radial_subaev_terms(self, distances):
def _radial_subaev_terms(self, distances):
"""Compute the radial subAEV terms of the center atom given neighbors
The radial AEV is define in
https://arxiv.org/pdf/1610.08935.pdf equation 3.
The sum computed by this method is over all given neighbors,
so the caller of this method need to select neighbors if the
caller want a per species subAEV.
Parameters
----------
distances : torch.Tensor
Pytorch tensor of shape (..., neighbors) storing the |Rij|
length where i are the center atoms, and j are their neighbors.
Returns
-------
torch.Tensor
A tensor of shape (..., neighbors, `radial_sublength`) storing
the subAEVs.
This corresponds to equation (3) in the `ANI paper`_. This function only
computes the terms; the sum in the equation is not computed.
The input tensor has shape (conformations, atoms, N), where ``N``
is the number of neighbor atoms within the cutoff radius, and the output
tensor has shape
(conformations, atoms, ``self.radial_sublength()``)
.. _ANI paper:
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
"""
distances = distances.unsqueeze(-1).unsqueeze(-1)
fc = _cutoff_cosine(distances, self.Rcr)
......@@ -112,29 +96,24 @@ class AEVComputer(torch.nn.Module):
# coefficient, but in NeuroChem there is such a coefficient.
# We choose to be consistent with NeuroChem instead of the paper here.
ret = 0.25 * torch.exp(-self.EtaR * (distances - self.ShfR)**2) * fc
# At this point, ret now have shape
# (conformations, atoms, N, ?, ?) where ? depend on constants.
# We then should flat the last 4 dimensions to view the subAEV as one
# dimension vector
return ret.flatten(start_dim=-2)
def angular_subaev_terms(self, vectors1, vectors2):
def _angular_subaev_terms(self, vectors1, vectors2):
"""Compute the angular subAEV terms of the center atom given neighbor pairs.
The angular AEV is define in
https://arxiv.org/pdf/1610.08935.pdf equation 4.
The sum computed by this method is over all given neighbor pairs,
so the caller of this method need to select neighbors if the caller
want a per species subAEV.
Parameters
----------
vectors1, vectors2: torch.Tensor
Tensor of shape (..., pairs, 3) storing the Rij vectors of pairs
of neighbors. The vectors1(..., j, :) and vectors2(..., j, :) are
the Rij vectors of the two atoms of pair j.
Returns
-------
torch.Tensor
Tensor of shape (..., pairs, `angular_sublength`) storing the
subAEVs.
This corresponds to equation (4) in the `ANI paper`_. This function only
computes the terms; the sum in the equation is not computed.
The input tensors have shape (conformations, atoms, N), where ``N``
is the number of neighbor atom pairs within the cutoff radius, and the
output tensor has shape
(conformations, atoms, ``self.angular_sublength()``)
.. _ANI paper:
http://pubs.rsc.org/en/Content/ArticleLanding/2017/SC/C6SC05720A#!divAbstract
"""
vectors1 = vectors1.unsqueeze(
-1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
......@@ -156,45 +135,17 @@ class AEVComputer(torch.nn.Module):
factor2 = torch.exp(-self.EtaA *
((distances1 + distances2) / 2 - self.ShfA) ** 2)
ret = 2 * factor1 * factor2 * fcj1 * fcj2
# ret now have shape (..., pairs, ?, ?, ?, ?) where ? depend on
# constants
# flat the last 4 dimensions to view the subAEV as one dimension vector
# At this point, ret now have shape
# (conformations, atoms, N, ?, ?, ?, ?) where ? depend on constants.
# We then should flat the last 4 dimensions to view the subAEV as one
# dimension vector
return ret.flatten(start_dim=-4)
def terms_and_indices(self, species, coordinates):
"""Compute radial and angular subAEV terms, and original indices.
Terms will be sorted according to their distances to central atoms,
and only these within cutoff radius are valid. The returned indices
contains what would their original indices be if they were unsorted.
Parameters
----------
species : torch.Tensor
The tensor that specifies the species of atoms in the molecule.
The tensor must have shape (conformations, atoms)
coordinates : torch.Tensor
The tensor that specifies the xyz coordinates of atoms in the
molecule. The tensor must have shape (conformations, atoms, 3)
Returns
-------
(radial_terms, angular_terms, indices_r, indices_a)
radial_terms : torch.Tensor
Tensor shaped (conformations, atoms, neighbors, `radial_sublength`)
for the (unsummed) radial subAEV terms.
angular_terms : torch.Tensor
Tensor of shape (conformations, atoms, pairs, `angular_sublength`)
for the (unsummed) angular subAEV terms.
indices_r : torch.Tensor
Tensor of shape (conformations, atoms, neighbors).
Let l = indices_r(i,j,k), then this means that
radial_terms(i,j,k,:) is in the subAEV term of conformation i
between atom j and atom l.
indices_a : torch.Tensor
Same as indices_r, except that the cutoff radius is Rca instead of
Rcr.
def _terms_and_indices(self, species, coordinates):
"""Returns radial and angular subAEV terms, these terms will be sorted
according to their distances to central atoms, and only these within
cutoff radius are valid. The returned indices stores the source of data
before sorting.
"""
vec = coordinates.unsqueeze(2) - coordinates.unsqueeze(1)
......@@ -213,27 +164,34 @@ class AEVComputer(torch.nn.Module):
distances, indices = distances.sort(-1)
min_distances, _ = distances.flatten(end_dim=1).min(0)
inRcr = (min_distances <= self.Rcr).nonzero().flatten()[
1:] # TODO: can we use something like find_first?
inRcr = (min_distances <= self.Rcr).nonzero().flatten()[1:]
inRca = (min_distances <= self.Rca).nonzero().flatten()[1:]
distances = distances.index_select(-1, inRcr)
indices_r = indices.index_select(-1, inRcr)
radial_terms = self.radial_subaev_terms(distances)
radial_terms = self._radial_subaev_terms(distances)
indices_a = indices.index_select(-1, inRca)
new_shape = list(indices_a.shape) + [3]
# TODO: can we add something like expand_dim(dim=0, repeat=3)
# TODO: remove this workaround when gather support broadcasting
# https://github.com/pytorch/pytorch/pull/9532
_indices_a = indices_a.unsqueeze(-1).expand(*new_shape)
# TODO: can we make gather broadcast??
vec = vec.gather(-2, _indices_a)
# TODO: can we move combinations to ATen?
vec = self.combinations(vec, -2)
angular_terms = self.angular_subaev_terms(*vec)
vec = self._combinations(vec, -2)
angular_terms = self._angular_subaev_terms(*vec)
# Returned tensors has shape:
# (conformations, atoms, neighbors, ``self.radial_sublength()``)
# (conformations, atoms, pairs, ``self.angular_sublength()``)
# (conformations, atoms, neighbors)
# (conformations, atoms, pairs)
return radial_terms, angular_terms, indices_r, indices_a
def combinations(self, tensor, dim=0):
def _combinations(self, tensor, dim=0):
# TODO: remove this when combinations is merged into PyTorch
# https://github.com/pytorch/pytorch/pull/9393
n = tensor.shape[dim]
r = torch.arange(n, dtype=torch.long, device=tensor.device)
grid_x, grid_y = torch.meshgrid([r, r])
......@@ -246,49 +204,17 @@ class AEVComputer(torch.nn.Module):
return tensor.index_select(dim, index1), \
tensor.index_select(dim, index2)
def compute_mask_r(self, species, indices_r):
"""Partition indices according to their species, radial part
Parameters
----------
indices_r : torch.Tensor
Tensor of shape (conformations, atoms, neighbors).
Let l = indices_r(i,j,k), then this means that
radial_terms(i,j,k,:) is in the subAEV term of conformation i
between atom j and atom l.
Returns
-------
torch.Tensor
Tensor of shape (conformations, atoms, neighbors, all species)
storing the mask for each species.
"""
def _compute_mask_r(self, species, indices_r):
    """Get mask of radial terms for each supported species from indices.

    Arguments:
        species (:class:`torch.Tensor`): shape (conformations, atoms,
            neighbors); species ids broadcast per center atom.
        indices_r (:class:`torch.Tensor`): shape (conformations, atoms,
            neighbors); original indices of the distance-sorted neighbors.

    Returns:
        :class:`torch.Tensor`: boolean mask of shape (conformations, atoms,
        neighbors, num_species), one slice per supported species.
    """
    # Species of each neighbor, gathered through the sort indices.
    # Shape: (conformations, atoms, neighbors).
    # NOTE: was previously followed by a stray no-op string literal
    # ("attribute docstring"); replaced with this comment.
    species_r = species.gather(-1, indices_r)
    # Compare against every supported species id at once; broadcasting the
    # trailing dim against arange(num_species) yields one boolean mask
    # slice per species.
    mask_r = (species_r.unsqueeze(-1) ==
              torch.arange(self.num_species, device=self.EtaR.device))
    return mask_r
def compute_mask_a(self, species, indices_a, present_species):
"""Partition indices according to their species, angular part
Parameters
----------
species_a : torch.Tensor
Tensor of shape (conformations, atoms, neighbors) storing the
species of neighbors.
present_species : torch.Tensor
Long tensor for the species, already uniqued.
Returns
-------
torch.Tensor
Tensor of shape (conformations, atoms, pairs, present species,
present species) storing the mask for each pair.
"""
def _compute_mask_a(self, species, indices_a, present_species):
"""Get mask of angular terms for each supported species from indices"""
species_a = species.gather(-1, indices_a)
species_a1, species_a2 = self.combinations(species_a, -1)
species_a1, species_a2 = self._combinations(species_a, -1)
mask_a1 = (species_a1.unsqueeze(-1) == present_species).unsqueeze(-1)
mask_a2 = (species_a2.unsqueeze(-1).unsqueeze(-1) == present_species)
mask = mask_a1 * mask_a2
......@@ -296,35 +222,22 @@ class AEVComputer(torch.nn.Module):
mask_a = (mask + mask_rev) > 0
return mask_a
def assemble(self, radial_terms, angular_terms, present_species,
mask_r, mask_a):
"""Assemble radial and angular AEV from computed terms according
def _assemble(self, radial_terms, angular_terms, present_species,
mask_r, mask_a):
"""Returns radial and angular AEV computed from terms according
to the given partition information.
Parameters
----------
radial_terms : torch.Tensor
Tensor shaped (conformations, atoms, neighbors, `radial_sublength`)
for the (unsummed) radial subAEV terms.
angular_terms : torch.Tensor
Tensor of shape (conformations, atoms, pairs, `angular_sublength`)
for the (unsummed) angular subAEV terms.
present_species : torch.Tensor
Long tensor for species of atoms present in the molecules.
mask_r : torch.Tensor
Tensor of shape (conformations, atoms, neighbors, present species)
storing the mask for each species.
mask_a : torch.Tensor
Tensor of shape (conformations, atoms, pairs, present species,
present species) storing the mask for each pair.
Returns
-------
(torch.Tensor, torch.Tensor)
Returns (radial AEV, angular AEV), both are pytorch tensor of
`dtype`. The radial AEV must be of shape (conformations, atoms,
radial_length) The angular AEV must be of shape (conformations,
atoms, angular_length)
Arguments:
radial_terms (:class:`torch.Tensor`): shape (conformations, atoms,
neighbors, ``self.radial_sublength()``)
angular_terms (:class:`torch.Tensor`): shape (conformations, atoms,
pairs, ``self.angular_sublength()``)
present_species (:class:`torch.Tensor`): Long tensor for species
of atoms present in the molecules.
mask_r (:class:`torch.Tensor`): shape (conformations, atoms,
neighbors, supported species)
mask_a (:class:`torch.Tensor`): shape (conformations, atoms,
pairs, present species, present species)
"""
conformations = radial_terms.shape[0]
atoms = radial_terms.shape[1]
......@@ -338,14 +251,10 @@ class AEVComputer(torch.nn.Module):
radial_aevs = present_radial_aevs.flatten(start_dim=2)
# assemble angular subaev
# TODO: can we use find_first?
rev_indices = {present_species[i].item(): i
for i in range(len(present_species))}
"""shape (conformations, atoms, present species,
present species, angular_length)"""
angular_aevs = []
zero_angular_subaev = torch.zeros(
# TODO: can we make stack and cat broadcast?
conformations, atoms, self.angular_sublength(),
dtype=self.EtaR.dtype, device=self.EtaR.device)
for s1, s2 in itertools.combinations_with_replacement(
......@@ -362,6 +271,19 @@ class AEVComputer(torch.nn.Module):
return radial_aevs, torch.cat(angular_aevs, dim=2)
def forward(self, species_coordinates):
"""Compute AEVs
Arguments:
species_coordinates (tuple): Two tensors: species and coordinates.
species must have shape ``(C, A)`` and coordinates must have
shape ``(C, A, 3)``, where ``C`` is the number of conformations
in a chunk, and ``A`` is the number of atoms.
Returns:
tuple: Species and AEVs. species are the species from the input
unchanged, and AEVs is a tensor of shape
``(C, A, self.aev_length())``
"""
species, coordinates = species_coordinates
present_species = utils.present_species(species)
......@@ -371,11 +293,11 @@ class AEVComputer(torch.nn.Module):
species_ = species.unsqueeze(1).expand(-1, atoms, -1)
radial_terms, angular_terms, indices_r, indices_a = \
self.terms_and_indices(species, coordinates)
mask_r = self.compute_mask_r(species_, indices_r)
mask_a = self.compute_mask_a(species_, indices_a, present_species)
self._terms_and_indices(species, coordinates)
mask_r = self._compute_mask_r(species_, indices_r)
mask_a = self._compute_mask_a(species_, indices_a, present_species)
radial, angular = self.assemble(radial_terms, angular_terms,
present_species, mask_r, mask_a)
radial, angular = self._assemble(radial_terms, angular_terms,
present_species, mask_r, mask_a)
fullaev = torch.cat([radial, angular], dim=2)
return species, fullaev
# -*- coding: utf-8 -*-
"""Tools for loading, shuffling, and batching ANI datasets"""
from torch.utils.data import Dataset
from os.path import join, isfile, isdir
import os
......@@ -74,6 +77,56 @@ def split_batch(natoms, species, coordinates):
class BatchedANIDataset(Dataset):
"""Load data from hdf5 files, create minibatches, and convert to tensors.
This is already a dataset of batches, so when iterated, a batch rather
than a single data point will be yielded.
Since each batch might contain molecules of very different sizes, putting
the whole batch into a single tensor would require adding ghost atoms to
pad everything to the size of the largest molecule. As a result, a huge
amount of computation would be wasted on ghost atoms. To avoid this issue,
the input of each batch, i.e. species and coordinates, are further divided
into chunks according to some heuristics, so that each chunk would only
have molecules of similar size, to minimize the padding required.
So, when iterating on this dataset, a tuple will be yielded. The first
element of this tuple is a list of (species, coordinates) pairs. Each pair
is a chunk of molecules of similar size. The second element of this tuple
is a dictionary, where the keys are those specified in the argument
:attr:`properties`, and the values are a single tensor for the whole batch
(properties are not split into chunks).
Splitting batch into chunks leads to some inconvenience on training,
especially when using high level libraries like ``ignite``. To overcome
this inconvenience, :class:`torchani.ignite.Container` is created for
working with ignite.
Arguments:
path (str): Path to hdf5 files. If :attr:`path` is a file, then that
file would be loaded using `pyanitools.py`_. If :attr:`path` is
a directory, then all files with suffix `.h5` or `.hdf5` will be
loaded.
species_tensor_converter (:class:`collections.abc.Callable`): A
callable that convert species in the format of list of strings
to 1D tensor.
batch_size (int): Number of different 3D structures in a single
minibatch.
shuffle (bool): Whether to shuffle the whole dataset.
properties (list): List of keys in the dataset to be loaded.
``'species'`` and ``'coordinates'`` are always loaded and need not
be specified here.
transform (list): List of :class:`collections.abc.Callable` that
transform the data. Callables must take species, coordinates,
and properties of the whole dataset as arguments, and return
the transformed species, coordinates, and properties.
dtype (:class:`torch.dtype`): dtype to convert coordinates and
properties of the dataset to.
device (:class:`torch.device`): device to put tensors on when iterating.
.. _pyanitools.py:
https://github.com/isayev/ASE_ANI/blob/master/lib/pyanitools.py
"""
def __init__(self, path, species_tensor_converter, batch_size,
shuffle=True, properties=['energies'], transform=(),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment