Unverified commit 87c4184a, authored by Gao, Xiang and committed by GitHub

add style check to unit test (#12)

parent a61a1b3e
@@ -6,7 +6,8 @@ import ase_interface
 import numpy
 import torchani
 import pickle
-from torchani import buildin_const_file, buildin_sae_file, buildin_network_dir, default_dtype, default_device
+from torchani import buildin_const_file, buildin_sae_file, \
+    buildin_network_dir, default_dtype, default_device
 import torchani.pyanitools
 path = os.path.dirname(os.path.realpath(__file__))
@@ -15,7 +16,9 @@ conv_au_ev = 27.21138505
 class NeuroChem (torchani.aev_base.AEVComputer):
-    def __init__(self, dtype=default_dtype, device=default_device, const_file=buildin_const_file, sae_file=buildin_sae_file, network_dir=buildin_network_dir):
+    def __init__(self, dtype=default_dtype, device=default_device,
+                 const_file=buildin_const_file, sae_file=buildin_sae_file,
+                 network_dir=buildin_network_dir):
         super(NeuroChem, self).__init__(False, dtype, device, const_file)
         self.sae_file = sae_file
         self.network_dir = network_dir
@@ -52,7 +55,9 @@ class NeuroChem (torchani.aev_base.AEVComputer):
             self.dtype).to(self.device)
         forces = torch.from_numpy(numpy.stack(forces)).type(
             self.dtype).to(self.device)
-        return self._get_radial_part(aevs), self._get_angular_part(aevs), energies, forces
+        return self._get_radial_part(aevs), \
+            self._get_angular_part(aevs), \
+            energies, forces
 aev = torchani.SortedAEV(device=torch.device('cpu'))
...
-import pkg_resources
-import torch
-buildin_const_file = pkg_resources.resource_filename(
-    __name__, 'resources/ani-1x_dft_x8ens/rHCNO-5.2R_16-3.5A_a4-8.params')
-buildin_sae_file = pkg_resources.resource_filename(
-    __name__, 'resources/ani-1x_dft_x8ens/sae_linfit.dat')
-buildin_network_dir = pkg_resources.resource_filename(
-    __name__, 'resources/ani-1x_dft_x8ens/train0/networks/')
-buildin_model_prefix = pkg_resources.resource_filename(
-    __name__, 'resources/ani-1x_dft_x8ens/train')
-default_dtype = torch.float32
-default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
 from .energyshifter import EnergyShifter
 from .nn import ModelOnAEV, PerSpeciesFromNeuroChem
 from .aev import SortedAEV
-import logging
+from .env import buildin_const_file, buildin_sae_file, buildin_network_dir, \
+    buildin_model_prefix, default_dtype, default_device
-__all__ = ['SortedAEV', 'EnergyShifter', 'ModelOnAEV', 'PerSpeciesFromNeuroChem', 'data',
-           'buildin_const_file', 'buildin_sae_file', 'buildin_network_dir', 'buildin_dataset_dir',
-           'default_dtype', 'default_device']
+__all__ = ['SortedAEV', 'EnergyShifter', 'ModelOnAEV',
+           'PerSpeciesFromNeuroChem', 'data', 'buildin_const_file',
+           'buildin_sae_file', 'buildin_network_dir', 'buildin_dataset_dir',
+           'buildin_model_prefix', 'default_dtype', 'default_device']
@@ -2,28 +2,36 @@ import torch
 import itertools
 import numpy
 from .aev_base import AEVComputer
-from . import buildin_const_file, default_dtype, default_device
+from .env import buildin_const_file, default_dtype, default_device
 def _cutoff_cosine(distances, cutoff):
     """Compute the elementwise cutoff cosine function
-    The cutoff cosine function is define in https://arxiv.org/pdf/1610.08935.pdf equation 2
+    The cutoff cosine function is define in
+    https://arxiv.org/pdf/1610.08935.pdf equation 2
     Parameters
     ----------
     distances : torch.Tensor
-        The pytorch tensor that stores Rij values. This tensor can have any shape since the cutoff
-        cosine function is computed elementwise.
+        The pytorch tensor that stores Rij values. This tensor can
+        have any shape since the cutoff cosine function is computed
+        elementwise.
     cutoff : float
-        The cutoff radius, i.e. the Rc in the equation. For any Rij > Rc, the function value is defined to be zero.
+        The cutoff radius, i.e. the Rc in the equation. For any Rij > Rc,
+        the function value is defined to be zero.
     Returns
     -------
     torch.Tensor
-        The tensor of the same shape as `distances` that stores the computed function values.
+        The tensor of the same shape as `distances` that stores the
+        computed function values.
     """
-    return torch.where(distances <= cutoff, 0.5 * torch.cos(numpy.pi * distances / cutoff) + 0.5, torch.zeros_like(distances))
+    return torch.where(
+        distances <= cutoff,
+        0.5 * torch.cos(numpy.pi * distances / cutoff) + 0.5,
+        torch.zeros_like(distances)
+    )
 class SortedAEV(AEVComputer):
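For orientation, the cutoff cosine reformatted above (equation 2 of the paper) can be exercised on its own; a minimal standalone sketch, with `math.pi` in place of `numpy.pi` and sample distances and the 5.2 cutoff chosen purely for illustration:

```python
import math
import torch

def cutoff_cosine(distances, cutoff):
    # 0.5 * cos(pi * Rij / Rc) + 0.5 inside the cutoff, exactly zero outside
    return torch.where(distances <= cutoff,
                       0.5 * torch.cos(math.pi * distances / cutoff) + 0.5,
                       torch.zeros_like(distances))

r = torch.tensor([0.0, 2.6, 5.2, 6.0])
print(cutoff_cosine(r, 5.2))  # tensor([1.0000, 0.5000, 0.0000, 0.0000])
```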
@@ -38,7 +46,8 @@ class SortedAEV(AEVComputer):
     total : total time for computing everything.
     """
-    def __init__(self, benchmark=False, device=default_device, dtype=default_dtype, const_file=buildin_const_file):
+    def __init__(self, benchmark=False, device=default_device,
+                 dtype=default_dtype, const_file=buildin_const_file):
         super(SortedAEV, self).__init__(benchmark, dtype, device, const_file)
         if benchmark:
             self.radial_subaev_terms = self._enable_benchmark(
@@ -77,55 +86,63 @@ class SortedAEV(AEVComputer):
     def radial_subaev_terms(self, distances):
         """Compute the radial subAEV terms of the center atom given neighbors
-        The radial AEV is define in https://arxiv.org/pdf/1610.08935.pdf equation 3.
-        The sum computed by this method is over all given neighbors, so the caller
-        of this method need to select neighbors if the caller want a per species subAEV.
+        The radial AEV is define in
+        https://arxiv.org/pdf/1610.08935.pdf equation 3.
+        The sum computed by this method is over all given neighbors,
+        so the caller of this method need to select neighbors if the
+        caller want a per species subAEV.
         Parameters
         ----------
         distances : torch.Tensor
-            Pytorch tensor of shape (..., neighbors) storing the |Rij| length where i are the
-            center atoms, and j are their neighbors.
+            Pytorch tensor of shape (..., neighbors) storing the |Rij|
+            length where i are the center atoms, and j are their neighbors.
         Returns
         -------
         torch.Tensor
-            A tensor of shape (..., neighbors, `radial_sublength`) storing the subAEVs.
+            A tensor of shape (..., neighbors, `radial_sublength`) storing
+            the subAEVs.
         """
-        distances = distances.unsqueeze(
-            -1).unsqueeze(-1)  # TODO: allow unsqueeze to insert multiple dimensions
+        distances = distances.unsqueeze(-1).unsqueeze(-1)
+        # TODO: allow unsqueeze to insert multiple dimensions
         fc = _cutoff_cosine(distances, self.Rcr)
-        # Note that in the equation in the paper there is no 0.25 coefficient, but in NeuroChem there is such a coefficient. We choose to be consistent with NeuroChem instead of the paper here.
+        # Note that in the equation in the paper there is no 0.25
+        # coefficient, but in NeuroChem there is such a coefficient.
+        # We choose to be consistent with NeuroChem instead of the paper here.
         ret = 0.25 * torch.exp(-self.EtaR * (distances - self.ShfR)**2) * fc
         return ret.flatten(start_dim=-2)
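The broadcasting that `radial_subaev_terms` relies on is easy to check with toy constants; a sketch under assumed parameter shapes (the real EtaR/ShfR grids come from the const file):

```python
import math
import torch

EtaR = torch.tensor([[16.0]])           # (1, 1): a single eta, assumed value
ShfR = torch.tensor([[0.9, 1.5, 2.1]])  # (1, 3): three radial shifts, assumed
Rcr = 5.2                               # assumed cutoff radius

distances = torch.rand(2, 4, 5) * Rcr        # (conformations, atoms, neighbors)
d = distances.unsqueeze(-1).unsqueeze(-1)    # (..., neighbors, 1, 1)
fc = 0.5 * torch.cos(math.pi * d / Rcr) + 0.5
terms = 0.25 * torch.exp(-EtaR * (d - ShfR) ** 2) * fc  # broadcasts to (..., 5, 1, 3)
print(terms.flatten(start_dim=-2).shape)     # torch.Size([2, 4, 5, 3])
```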
     def angular_subaev_terms(self, vectors1, vectors2):
         """Compute the angular subAEV terms of the center atom given neighbor pairs.
-        The angular AEV is define in https://arxiv.org/pdf/1610.08935.pdf equation 4.
-        The sum computed by this method is over all given neighbor pairs, so the caller
-        of this method need to select neighbors if the caller want a per species subAEV.
+        The angular AEV is define in
+        https://arxiv.org/pdf/1610.08935.pdf equation 4.
+        The sum computed by this method is over all given neighbor pairs,
+        so the caller of this method need to select neighbors if the caller
+        want a per species subAEV.
         Parameters
         ----------
         vectors1, vectors2: torch.Tensor
-            Tensor of shape (..., pairs, 3) storing the Rij vectors of pairs of neighbors.
-            The vectors1(..., j, :) and vectors2(..., j, :) are the Rij vectors of the
-            two atoms of pair j.
+            Tensor of shape (..., pairs, 3) storing the Rij vectors of pairs
+            of neighbors. The vectors1(..., j, :) and vectors2(..., j, :) are
+            the Rij vectors of the two atoms of pair j.
         Returns
         -------
         torch.Tensor
-            Tensor of shape (..., pairs, `angular_sublength`) storing the subAEVs.
+            Tensor of shape (..., pairs, `angular_sublength`) storing the
+            subAEVs.
         """
         vectors1 = vectors1.unsqueeze(
-            -1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)  # TODO: allow unsqueeze to plug in multiple dims
+            -1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
         vectors2 = vectors2.unsqueeze(
-            -1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)  # TODO: allow unsqueeze to plug in multiple dims
+            -1).unsqueeze(-1).unsqueeze(-1).unsqueeze(-1)
         distances1 = vectors1.norm(2, dim=-5)
         distances2 = vectors2.norm(2, dim=-5)
-        # 0.95 is multiplied to the cos values to prevent acos from returning NaN.
+        # 0.95 is multiplied to the cos values to prevent acos from
+        # returning NaN.
         cos_angles = 0.95 * \
             torch.nn.functional.cosine_similarity(
                 vectors1, vectors2, dim=-5)
@@ -137,38 +154,42 @@ class SortedAEV(AEVComputer):
         factor2 = torch.exp(-self.EtaA *
                             ((distances1 + distances2) / 2 - self.ShfA) ** 2)
         ret = 2 * factor1 * factor2 * fcj1 * fcj2
-        # ret now have shape (..., pairs, ?, ?, ?, ?) where ? depend on constants
+        # ret now have shape (..., pairs, ?, ?, ?, ?) where ? depend on
+        # constants
         # flat the last 4 dimensions to view the subAEV as one dimension vector
         return ret.flatten(start_dim=-4)
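A note on the 0.95 factor discussed in the comment above: in floating point, `cosine_similarity` of nearly parallel vectors can come out a hair above 1, where `acos` returns NaN; scaling into a safe range avoids that. A toy demonstration (vectors made up for illustration):

```python
import torch

v = torch.tensor([[0.1, 0.2, 0.3]])
cos = torch.nn.functional.cosine_similarity(v, v, dim=-1)
# cos may be 1.0000001 instead of exactly 1.0 due to rounding
print(torch.acos(cos))         # nan whenever cos > 1
print(torch.acos(0.95 * cos))  # always well inside acos's domain
```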
     def terms_and_indices(self, coordinates):
         """Compute radial and angular subAEV terms, and original indices.
-        Terms will be sorted according to their distances to central atoms, and only
-        these within cutoff radius are valid. The returned indices contains what would
-        their original indices be if they were unsorted.
+        Terms will be sorted according to their distances to central atoms,
+        and only these within cutoff radius are valid. The returned indices
+        contains what would their original indices be if they were unsorted.
         Parameters
         ----------
         coordinates : torch.Tensor
-            The tensor that specifies the xyz coordinates of atoms in the molecule.
-            The tensor must have shape (conformations, atoms, 3)
+            The tensor that specifies the xyz coordinates of atoms in the
+            molecule. The tensor must have shape (conformations, atoms, 3)
         Returns
         -------
         (radial_terms, angular_terms, indices_r, indices_a)
         radial_terms : torch.Tensor
-            Tensor of shape (conformations, atoms, neighbors, `radial_sublength`) for
-            the (unsummed) radial subAEV terms.
+            Tensor shaped (conformations, atoms, neighbors, `radial_sublength`)
+            for the (unsummed) radial subAEV terms.
         angular_terms : torch.Tensor
-            Tensor of shape (conformations, atoms, pairs, `angular_sublength`) for the
-            (unsummed) angular subAEV terms.
+            Tensor of shape (conformations, atoms, pairs, `angular_sublength`)
+            for the (unsummed) angular subAEV terms.
         indices_r : torch.Tensor
-            Tensor of shape (conformations, atoms, neighbors). Let l = indices_r(i,j,k),
-            then this means that radial_terms(i,j,k,:) is in the subAEV term of conformation
-            i between atom j and atom l.
+            Tensor of shape (conformations, atoms, neighbors).
+            Let l = indices_r(i,j,k), then this means that
+            radial_terms(i,j,k,:) is in the subAEV term of conformation i
+            between atom j and atom l.
         indices_a : torch.Tensor
-            Same as indices_r, except that the cutoff radius is Rca instead of Rcr.
+            Same as indices_r, except that the cutoff radius is Rca instead of
+            Rcr.
         """
         vec = coordinates.unsqueeze(2) - coordinates.unsqueeze(1)
@@ -208,10 +229,12 @@ class SortedAEV(AEVComputer):
         index1 = grid_y[torch.triu(torch.ones(n, n), diagonal=1) == 1]
         index2 = grid_x[torch.triu(torch.ones(n, n), diagonal=1) == 1]
         if torch.numel(index1) == 0:
-            # TODO: pytorch are unable to handle size 0 tensor well. Is this an expected behavior?
+            # TODO: pytorch are unable to handle size 0 tensor well.
+            # Is this an expected behavior?
             # See: https://github.com/pytorch/pytorch/issues/5014
             return None
-        return tensor.index_select(dim, index1), tensor.index_select(dim, index2)
+        return tensor.index_select(dim, index1), \
+            tensor.index_select(dim, index2)
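The `torch.triu` trick above enumerates all unordered pairs along a dimension. The same indexing in isolation (a hypothetical `n` and data tensor; 'ij' meshgrid indexing, the default in the PyTorch of this era):

```python
import torch

n = 4
grid_x, grid_y = torch.meshgrid(torch.arange(n), torch.arange(n))  # 'ij' indexing
mask = torch.triu(torch.ones(n, n), diagonal=1) == 1               # strict upper triangle
index1, index2 = grid_y[mask], grid_x[mask]
print(index1.tolist())  # [1, 2, 3, 2, 3, 3]
print(index2.tolist())  # [0, 0, 0, 1, 1, 2]

data = torch.arange(10, 10 + n)  # stand-in for per-neighbor data
pairs = data.index_select(0, index1), data.index_select(0, index2)
```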
     def compute_mask_r(self, species_r):
         """Partition indices according to their species, radial part
@@ -219,14 +242,14 @@ class SortedAEV(AEVComputer):
         Parameters
         ----------
         species_r : torch.Tensor
-            Tensor of shape (conformations, atoms, neighbors) storing species of
-            neighbors.
+            Tensor of shape (conformations, atoms, neighbors) storing
+            species of neighbors.
         Returns
         -------
         torch.Tensor
-            Tensor of shape (conformations, atoms, neighbors, all species) storing
-            the mask for each species.
+            Tensor of shape (conformations, atoms, neighbors, all species)
+            storing the mask for each species.
         """
         mask_r = (species_r.unsqueeze(-1) ==
                   torch.arange(len(self.species), device=self.device))
@@ -238,16 +261,16 @@ class SortedAEV(AEVComputer):
         Parameters
         ----------
         species_a : torch.Tensor
-            Tensor of shape (conformations, atoms, neighbors) storing the species of
-            neighbors
+            Tensor of shape (conformations, atoms, neighbors) storing the
+            species of neighbors.
         present_species : torch.Tensor
             Long tensor for the species, already uniqued.
         Returns
         -------
         torch.Tensor
-            Tensor of shape (conformations, atoms, pairs, present species, present species)
-            storing the mask for each pair.
+            Tensor of shape (conformations, atoms, pairs, present species,
+            present species) storing the mask for each pair.
         """
         species_a = self.combinations(species_a, -1)
         if species_a is not None:
@@ -266,32 +289,35 @@ class SortedAEV(AEVComputer):
         else:
             return None
-    def assemble(self, radial_terms, angular_terms, present_species, mask_r, mask_a):
-        """Assemble radial and angular AEV from computed terms according to the given partition information.
+    def assemble(self, radial_terms, angular_terms, present_species,
+                 mask_r, mask_a):
+        """Assemble radial and angular AEV from computed terms according
+        to the given partition information.
         Parameters
         ----------
         radial_terms : torch.Tensor
-            Tensor of shape (conformations, atoms, neighbors, `radial_sublength`) for
-            the (unsummed) radial subAEV terms.
+            Tensor shaped (conformations, atoms, neighbors, `radial_sublength`)
+            for the (unsummed) radial subAEV terms.
         angular_terms : torch.Tensor
-            Tensor of shape (conformations, atoms, pairs, `angular_sublength`) for the
-            (unsummed) angular subAEV terms.
+            Tensor of shape (conformations, atoms, pairs, `angular_sublength`)
+            for the (unsummed) angular subAEV terms.
         present_species : torch.Tensor
             Long tensor for species of atoms present in the molecules.
         mask_r : torch.Tensor
-            Tensor of shape (conformations, atoms, neighbors, present species) storing
-            the mask for each species.
+            Tensor of shape (conformations, atoms, neighbors, present species)
+            storing the mask for each species.
         mask_a : torch.Tensor
-            Tensor of shape (conformations, atoms, pairs, present species, present species)
-            storing the mask for each pair.
+            Tensor of shape (conformations, atoms, pairs, present species,
+            present species) storing the mask for each pair.
         Returns
         -------
         (torch.Tensor, torch.Tensor)
-            Returns (radial AEV, angular AEV), both are pytorch tensor of `dtype`.
-            The radial AEV must be of shape (conformations, atoms, radial_length)
-            The angular AEV must be of shape (conformations, atoms, angular_length)
+            Returns (radial AEV, angular AEV), both are pytorch tensor of
+            `dtype`. The radial AEV must be of shape (conformations, atoms,
+            radial_length) The angular AEV must be of shape (conformations,
+            atoms, angular_length)
         """
         conformations = radial_terms.shape[0]
         atoms = radial_terms.shape[1]
@@ -299,17 +325,22 @@ class SortedAEV(AEVComputer):
         # assemble radial subaev
         present_radial_aevs = (radial_terms.unsqueeze(-2)
                                * mask_r.unsqueeze(-1).type(self.dtype)).sum(-3)
-        """Tensor of shape (conformations, atoms, present species, radial_length)"""
+        """shape (conformations, atoms, present species, radial_length)"""
         radial_aevs = present_radial_aevs.flatten(start_dim=2)
         # assemble angular subaev
-        rev_indices = {present_species[i].item(): i  # TODO: can we use find_first?
+        # TODO: can we use find_first?
+        rev_indices = {present_species[i].item(): i
                        for i in range(len(present_species))}
-        """Tensor of shape (conformations, atoms, present species, present species, angular_length)"""
+        """shape (conformations, atoms, present species,
+        present species, angular_length)"""
         angular_aevs = []
-        zero_angular_subaev = torch.zeros(  # TODO: can we make stack and cat broadcast?
-            conformations, atoms, self.angular_sublength, dtype=self.dtype, device=self.device)  # TODO: can we make torch.zeros, torch.ones typeless and deviceless?
-        for s1, s2 in itertools.combinations_with_replacement(range(len(self.species)), 2):
+        zero_angular_subaev = torch.zeros(
+            conformations, atoms, self.angular_sublength,
+            dtype=self.dtype, device=self.device)
+        for s1, s2 in itertools.combinations_with_replacement(
+                range(len(self.species)), 2):
             # TODO: can we remove this if pytorch support 0 size tensors?
             if s1 in rev_indices and s2 in rev_indices and mask_a is not None:
                 i1 = rev_indices[s1]
@@ -326,15 +357,16 @@ class SortedAEV(AEVComputer):
         species = self.species_to_tensor(species)
         present_species = species.unique(sorted=True)
-        radial_terms, angular_terms, indices_r, indices_a = self.terms_and_indices(
-            coordinates)
+        radial_terms, angular_terms, indices_r, indices_a = \
+            self.terms_and_indices(coordinates)
         species_r = species[indices_r]
         mask_r = self.compute_mask_r(species_r)
         species_a = species[indices_a]
         mask_a = self.compute_mask_a(species_a, present_species)
-        return self.assemble(radial_terms, angular_terms, present_species, mask_r, mask_a)
+        return self.assemble(radial_terms, angular_terms, present_species,
+                             mask_r, mask_a)
     def export_radial_subaev_onnx(self, filename):
         """Export the operation that compute radial subaev into onnx format
...
 import torch
-import torch.nn as nn
-from . import buildin_const_file, default_dtype, default_device
+from .env import buildin_const_file, default_dtype, default_device
 from .benchmarked import BenchmarkedModule
 class AEVComputer(BenchmarkedModule):
-    __constants__ = ['Rcr', 'Rca', 'dtype', 'device', 'radial_sublength',
-                     'radial_length', 'angular_sublength', 'angular_length', 'aev_length']
+    __constants__ = ['Rcr', 'Rca', 'dtype', 'device', 'radial_sublength',
+                     'radial_length', 'angular_sublength', 'angular_length',
+                     'aev_length']
     """Base class of various implementations of AEV computer
@@ -15,8 +15,8 @@ class AEVComputer(BenchmarkedModule):
     benchmark : boolean
         Whether to enable benchmark
     dtype : torch.dtype
-        Data type of pytorch tensors for all the computations. This is also used
-        to specify whether to use CPU or GPU.
+        Data type of pytorch tensors for all the computations. This is
+        also used to specify whether to use CPU or GPU.
     device : torch.Device
         The device where tensors should be.
     const_file : str
@@ -37,7 +37,8 @@ class AEVComputer(BenchmarkedModule):
         The length of full aev
     """
-    def __init__(self, benchmark=False, dtype=default_dtype, device=default_device, const_file=buildin_const_file):
+    def __init__(self, benchmark=False, dtype=default_dtype,
+                 device=default_device, const_file=buildin_const_file):
         super(AEVComputer, self).__init__(benchmark)
         self.dtype = dtype
@@ -53,7 +54,8 @@ class AEVComputer(BenchmarkedModule):
                 value = line[1]
                 if name == 'Rcr' or name == 'Rca':
                     setattr(self, name, float(value))
-                elif name in ['EtaR', 'ShfR', 'Zeta', 'ShfZ', 'EtaA', 'ShfA']:
+                elif name in ['EtaR', 'ShfR', 'Zeta',
+                              'ShfZ', 'EtaA', 'ShfA']:
                     value = [float(x.strip()) for x in value.replace(
                         '[', '').replace(']', '').split(',')]
                     value = torch.tensor(value, dtype=dtype, device=device)
@@ -62,7 +64,7 @@ class AEVComputer(BenchmarkedModule):
                     value = [x.strip() for x in value.replace(
                         '[', '').replace(']', '').split(',')]
                     self.species = value
-        except:
+        except Exception:
             raise ValueError('unable to parse const file')
         # Compute lengths
@@ -112,8 +114,8 @@ class AEVComputer(BenchmarkedModule):
         Parameters
         ----------
         coordinates : torch.Tensor
-            The tensor that specifies the xyz coordinates of atoms in the molecule.
-            The tensor must have shape (conformations, atoms, 3)
+            The tensor that specifies the xyz coordinates of atoms in the
+            molecule. The tensor must have shape (conformations, atoms, 3)
         species : torch.LongTensor
             Long tensor for the species, where a value k means the species is
             the same as self.species[k]
@@ -121,8 +123,9 @@ class AEVComputer(BenchmarkedModule):
         Returns
         -------
         (torch.Tensor, torch.Tensor)
-            Returns (radial AEV, angular AEV), both are pytorch tensor of `dtype`.
-            The radial AEV must be of shape (conformations, atoms, radial_length)
-            The angular AEV must be of shape (conformations, atoms, angular_length)
+            Returns (radial AEV, angular AEV), both are pytorch tensor
+            of `dtype`. The radial AEV must be of shape
+            (conformations, atoms, radial_length). The angular AEV must
+            be of shape (conformations, atoms, angular_length)
         """
         raise NotImplementedError('subclass must override this method')
@@ -8,12 +8,13 @@ class BenchmarkedModule(torch.jit.ScriptModule):
     The benchmarking is done by wrapping the original member function with
     a wrapped function. The wrapped function will call the original function,
-    and accumulate its running time into `self.timers`. Different accumulators are
-    distinguished by different keys. All times should have unit seconds.
+    and accumulate its running time into `self.timers`. Different accumulators
+    are distinguished by different keys. All times should have unit seconds.
     To enable benchmarking for member functions in a subclass, simply
-    call the `__init__` of this class with `benchmark=True`, and add the following
-    code to your subclass's `__init__`:
+    call the `__init__` of this class with `benchmark=True`, and add the
+    following code to your subclass's `__init__`:
     ```
     if self.benchmark:
         self._enable_benchmark(self.function_to_be_benchmarked, 'key1', 'key2')
@@ -21,8 +22,8 @@ class BenchmarkedModule(torch.jit.ScriptModule):
     Example
     -------
-    The following code implements a subclass for timing the running time of member function
-    `f` and `g` and the total of these two::
+    The following code implements a subclass for timing the running time of
+    member function `f` and `g` and the total of these two::
     ```
     class BenchmarkFG(BenchmarkedModule):
         def __init__(self, benchmark=False)
@@ -47,21 +48,22 @@ class BenchmarkedModule(torch.jit.ScriptModule):
     """
     def _enable_benchmark(self, fun, *keys):
-        """Wrap a function to automatically benchmark it, and assign a key for it.
+        """Wrap a function to automatically benchmark it, and assign a key
+        for it.
         Parameters
         ----------
         keys
-            The keys in `self.timers` assigned. If multiple keys are specified, then
-            the time will be accumulated to all the keys.
+            The keys in `self.timers` assigned. If multiple keys are specified,
+            then the time will be accumulated to all the keys.
         func : function
             The function to be benchmarked.
         Returns
         -------
         function
-            Wrapped function that time the original function and update the corresponding
-            value in `self.timers` automatically.
+            Wrapped function that time the original function and update the
+            corresponding value in `self.timers` automatically.
         """
         for key in keys:
             self.timers[key] = 0
@@ -77,7 +79,8 @@ class BenchmarkedModule(torch.jit.ScriptModule):
         return wrapped
     def reset_timers(self):
-        """Reset all timers. If benchmark is not enabled, a `ValueError` will be raised."""
+        """Reset all timers. If benchmark is not enabled, a `ValueError`
+        will be raised."""
         if not self.benchmark:
             raise ValueError('Can not reset timers, benchmark not enabled')
         for i in self.timers:
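The timing wrapper described in the docstring is straightforward to sketch outside the class; the names and the `time.time` clock below are assumptions, not the module's exact code:

```python
import time
import functools

def enable_benchmark(timers, fun, *keys):
    """Accumulate the wall-clock time of every call of `fun` into each key."""
    for key in keys:
        timers[key] = 0

    @functools.wraps(fun)
    def wrapped(*args, **kwargs):
        start = time.time()
        result = fun(*args, **kwargs)
        elapsed = time.time() - start
        for key in keys:
            timers[key] += elapsed
        return result
    return wrapped

timers = {}
timed_sum = enable_benchmark(timers, sum, 'f', 'total')
timed_sum(range(1000000))
print(timers)  # e.g. {'f': 0.01..., 'total': 0.01...}
```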
...
@@ -142,10 +142,11 @@ def random_split(dataset, num_chunks, chunk_size):
     Randomly split a dataset into non-overlapping new datasets of given lengths
     The splitting is by chunk, which makes it possible for batching: The whole
-    dataset is first splitted into chunks of specified size, each chunk are different
-    conformation of the same isomer/molecule, then these chunks are randomly shuffled
-    and splitted accorting to the given `num_chunks`. After splitted, chunks belong to
-    the same molecule/isomer of the same subset will be merged to allow larger batch.
+    dataset is first splitted into chunks of specified size, each chunk are
+    different conformation of the same isomer/molecule, then these chunks are
+    randomly shuffled and splitted accorting to the given `num_chunks`. After
+    splitted, chunks belong to the same molecule/isomer of the same subset will
+    be merged to allow larger batch.
     Parameters
     ----------
@@ -160,7 +161,8 @@ def random_split(dataset, num_chunks, chunk_size):
     shuffle(chunks)
     if sum(num_chunks) != len(chunks):
         raise ValueError(
-            "Sum of input number of chunks does not equal the length of the total dataset!")
+            """Sum of input number of chunks does not equal the length of the
+            total dataset!""")
     offset = 0
     subsets = []
     for i in num_chunks:
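To make the chunk bookkeeping above concrete, a toy count (the numbers are hypothetical; only the `sum(num_chunks) == len(chunks)` invariant comes from the code):

```python
# 2 molecules x 100 conformations each, chunk_size = 25
# -> each molecule contributes 4 chunks, 8 chunks in total
chunks_total = 2 * (100 // 25)
num_chunks = (6, 2)  # e.g. 6 chunks for training, 2 for validation
assert sum(num_chunks) == chunks_total  # otherwise random_split raises ValueError
```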
...
-from . import buildin_sae_file
+from .env import buildin_sae_file
 class EnergyShifter:
@@ -20,7 +20,7 @@ class EnergyShifter:
                 name = line[0].split(',')[0].strip()
                 value = float(line[1])
                 self.self_energies[name] = value
-        except:
+        except Exception:
             pass  # ignore unrecognizable line
     def subtract_sae(self, energies, species):
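The `except:` to `except Exception:` fixes here and in aev_base.py are exactly the kind of thing the new style check flags (flake8's E722): a bare `except` also swallows `KeyboardInterrupt` and `SystemExit`. A minimal illustration:

```python
try:
    value = float('not a number')
except Exception:  # catches ValueError; lets KeyboardInterrupt/SystemExit through
    value = 0.0
```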
...
+import pkg_resources
+import torch
+buildin_const_file = pkg_resources.resource_filename(
+    __name__, 'resources/ani-1x_dft_x8ens/rHCNO-5.2R_16-3.5A_a4-8.params')
+buildin_sae_file = pkg_resources.resource_filename(
+    __name__, 'resources/ani-1x_dft_x8ens/sae_linfit.dat')
+buildin_network_dir = pkg_resources.resource_filename(
+    __name__, 'resources/ani-1x_dft_x8ens/train0/networks/')
+buildin_model_prefix = pkg_resources.resource_filename(
+    __name__, 'resources/ani-1x_dft_x8ens/train')
+default_dtype = torch.float32
+default_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
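Collecting the resource paths and defaults into this one module lets every submodule import them without going through the package root. Downstream use would look roughly like this (a sketch, not part of the diff):

```python
import torch
import torchani

# the constants above are re-exported from the package root
aev = torchani.SortedAEV(dtype=torchani.default_dtype,
                         device=torch.device('cpu'),
                         const_file=torchani.buildin_const_file)
```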
@@ -4,9 +4,8 @@ import bz2
 import os
 import lark
 import struct
-import copy
 import math
-from . import buildin_network_dir, buildin_model_prefix
+from .env import buildin_network_dir, buildin_model_prefix
 from .benchmarked import BenchmarkedModule
 # For python 2 compatibility
@@ -15,7 +14,8 @@ if not hasattr(math, 'inf'):
 class PerSpeciesFromNeuroChem(torch.jit.ScriptModule):
-    """Subclass of `torch.nn.Module` for the per atom aev->y transformation, loaded from NeuroChem network dir.
+    """Subclass of `torch.nn.Module` for the per atom aev->y
+    transformation, loaded from NeuroChem network dir.
     Attributes
     ----------
@@ -30,7 +30,8 @@ class PerSpeciesFromNeuroChem(torch.jit.ScriptModule):
     layerN : torch.nn.Linear
         Linear model for each layer.
     activation : function
-        Function for computing the activation for all layers but the last layer.
+        Function for computing the activation for all layers but the
+        last layer.
     activation_index : int
         The NeuroChem index for activation.
     """
@@ -43,8 +44,9 @@ class PerSpeciesFromNeuroChem(torch.jit.ScriptModule):
         dtype : torch.dtype
             Pytorch data type for tensors
         filename : string
-            The file name for the `.nnf` file that store network hyperparameters. The `.bparam` and `.wparam`
-            must be in the same directory
+            The file name for the `.nnf` file that store network
+            hyperparameters. The `.bparam` and `.wparam` must be
+            in the same directory
         """
         super(PerSpeciesFromNeuroChem, self).__init__()
@@ -87,8 +89,9 @@ class PerSpeciesFromNeuroChem(torch.jit.ScriptModule):
         Returns
         -------
         list of dict
-            Parsed setups as list of dictionary storing the parsed `.nnf` file content.
-            Each dictionary in the list is the hyperparameters for a layer.
+            Parsed setups as list of dictionary storing the parsed `.nnf`
+            file content. Each dictionary in the list is the hyperparameters
+            for a layer.
         """
         # parse input file
         parser = lark.Lark(r'''
@@ -170,8 +173,9 @@ class PerSpeciesFromNeuroChem(torch.jit.ScriptModule):
         Parameters
         ----------
         setups : list of dict
-            Parsed setups as list of dictionary storing the parsed `.nnf` file content.
-            Each dictionary in the list is the hyperparameters for a layer.
+            Parsed setups as list of dictionary storing the parsed `.nnf`
+            file content. Each dictionary in the list is the hyperparameters
+            for a layer.
         dirname : string
             The directory where network files are stored.
         """
@@ -205,7 +209,8 @@ class PerSpeciesFromNeuroChem(torch.jit.ScriptModule):
                     'Unexpected activation {}'.format(activation))
             elif self.activation_index != activation:
                 raise NotImplementedError(
-                    'different activation on different layers are not supported')
+                    '''different activation on different
+                    layers are not supported''')
             linear = torch.nn.Linear(in_size, out_size).type(self.dtype)
             name = 'layer{}'.format(i)
             setattr(self, name, linear)
@@ -238,12 +243,13 @@ class PerSpeciesFromNeuroChem(torch.jit.ScriptModule):
         Parameters
         ----------
         aev : torch.Tensor
-            The pytorch tensor of shape (conformations, aev_length) storing AEV as input to this model.
+            The pytorch tensor of shape (conformations, aev_length) storing AEV
+            as input to this model.
         layer : int
-            The layer whose activation is desired. The index starts at zero, that is
-            `layer=0` means the `activation(layer0(aev))` instead of `aev`. If the given
-            layer is larger than the total number of layers, then the activation of the last
-            layer will be returned.
+            The layer whose activation is desired. The index starts at zero,
+            that is `layer=0` means the `activation(layer0(aev))` instead of
+            `aev`. If the given layer is larger than the total number of
+            layers, then the activation of the last layer will be returned.
         Returns
         -------
@@ -268,18 +274,21 @@ class PerSpeciesFromNeuroChem(torch.jit.ScriptModule):
         Parameters
         ----------
         aev : torch.Tensor
-            The pytorch tensor of shape (conformations, aev_length) storing AEV as input to this model.
+            The pytorch tensor of shape (conformations, aev_length) storing
+            AEV as input to this model.
         Returns
         -------
         torch.Tensor
-            The pytorch tensor of shape (conformations, output_length) for output.
+            The pytorch tensor of shape (conformations, output_length) for
+            output.
         """
         return self.get_activations(aev, math.inf)
 class ModelOnAEV(BenchmarkedModule):
-    """Subclass of `torch.nn.Module` for the [xyz]->[aev]->[per_atom_y]->y pipeline.
+    """Subclass of `torch.nn.Module` for the [xyz]->[aev]->[per_atom_y]->y
+    pipeline.
     Attributes
     ----------
@@ -288,58 +297,67 @@ class ModelOnAEV(BenchmarkedModule):
     output_length : int
         The length of output vector
     derivative : boolean
-        Whether to support computing the derivative w.r.t coordinates, i.e. d(output)/dR
+        Whether to support computing the derivative w.r.t coordinates,
+        i.e. d(output)/dR
     derivative_graph : boolean
-        Whether to generate a graph for the derivative. This would be required only if the
-        derivative is included as part of the loss function.
+        Whether to generate a graph for the derivative. This would be required
+        only if the derivative is included as part of the loss function.
     model_X : nn.Module
-        Model for species X. There should be one such attribute for each supported species.
+        Model for species X. There should be one such attribute for each
+        supported species.
     reducer : function
-        Function of (input, dim)->output that reduce the input tensor along the given dimension
-        to get an output tensor. This function will be called with the per atom output tensor
-        with internal shape as input, and desired reduction dimension as dim, and should reduce
-        the input into the tensor containing desired output.
+        Function of (input, dim)->output that reduce the input tensor along the
+        given dimension to get an output tensor. This function will be called
+        with the per atom output tensor with internal shape as input, and
+        desired reduction dimension as dim, and should reduce the input into
+        the tensor containing desired output.
     timers : dict
         Dictionary storing the the benchmark result. It has the following keys:
         aev : time spent on computing AEV.
         nn : time spent on computing output from AEV.
-        derivative : time spend on computing derivative w.r.t. coordinates after the outputs
-            is given. This key is only available if derivative computation is turned on.
+        derivative : time spend on computing derivative w.r.t. coordinates
+            after the outputs is given. This key is only available if
+            derivative computation is turned on.
         forward : total time for the forward pass
     """
-    def __init__(self, aev_computer, derivative=False, derivative_graph=False, benchmark=False, **kwargs):
-        """Initialize object from manual setup or from NeuroChem network directory.
+    def __init__(self, aev_computer, derivative=False, derivative_graph=False,
+                 benchmark=False, **kwargs):
+        """Initialize object from manual setup or from NeuroChem network
+        directory.
-        The caller must set either `from_nc` in order to load from NeuroChem network directory,
-        or set `per_species` and `reducer`.
+        The caller must set either `from_nc` in order to load from NeuroChem
+        network directory, or set `per_species` and `reducer`.
         Parameters
         ----------
         aev_computer : AEVComputer
             The AEV computer.
         derivative : boolean
-            Whether to support computing the derivative w.r.t coordinates, i.e. d(output)/dR
+            Whether to support computing the derivative w.r.t coordinates,
+            i.e. d(output)/dR
         derivative_graph : boolean
-            Whether to generate a graph for the derivative. This would be required only if the
-            derivative is included as part of the loss function. This argument must be set to
-            False if `derivative` is set to False.
+            Whether to generate a graph for the derivative. This would be
+            required only if the derivative is included as part of the loss
+            function. This argument must be set to False if `derivative` is
+            set to False.
         benchmark : boolean
             Whether to enable benchmarking
         Other Parameters
         ----------------
         from_nc : string
-            Path to the NeuroChem network directory. If this parameter is set, then `per_species` and
-            `reducer` should not be set. If set to `None`, then the network ship with torchani will be
-            used.
+            Path to the NeuroChem network directory. If this parameter is set,
+            then `per_species` and `reducer` should not be set. If set to
+            `None`, then the network ship with torchani will be used.
         ensemble : int
-            Number of models in the model ensemble. If this is not set, then `from_nc` would refer to
-            the directory storing the model. If set to a number, then `from_nc` would refer to the prefix
-            of directories.
+            Number of models in the model ensemble. If this is not set, then
+            `from_nc` would refer to the directory storing the model. If set to
+            a number, then `from_nc` would refer to the prefix of directories.
         per_species : dict
-            Dictionary with supported species as keys and objects of `torch.nn.Model` as values, storing
-            the model for each supported species. These models will finally become `model_X` attributes.
+            Dictionary with supported species as keys and objects of
+            `torch.nn.Model` as values, storing the model for each supported
+            species. These models will finally become `model_X` attributes.
         reducer : function
             The desired `reducer` attribute.
@@ -354,7 +372,8 @@ class ModelOnAEV(BenchmarkedModule):
         self.output_length = None
         if not derivative and derivative_graph:
             raise ValueError(
-                'ModelOnAEV: can not create graph for derivative if the computation of derivative is turned off')
+                '''ModelOnAEV: can not create graph for derivative if the
+                computation of derivative is turned off''')
         self.derivative_graph = derivative_graph
         if benchmark:
@@ -371,7 +390,8 @@ class ModelOnAEV(BenchmarkedModule):
                 "ModelOnAEV: aev_computer must be a subclass of AEVComputer")
         self.aev_computer = aev_computer
-        if 'from_nc' in kwargs and 'per_species' not in kwargs and 'reducer' not in kwargs:
+        if 'from_nc' in kwargs and 'per_species' not in kwargs and \
+                'reducer' not in kwargs:
             if 'ensemble' not in kwargs:
                 if kwargs['from_nc'] is None:
                     kwargs['from_nc'] = buildin_network_dir
@@ -396,26 +416,31 @@ class ModelOnAEV(BenchmarkedModule):
                 filename = os.path.join(
                     network_dir, 'ANN-{}.nnf'.format(i))
                 model_X = PerSpeciesFromNeuroChem(
-                    self.aev_computer.dtype, self.aev_computer.device, filename)
+                    self.aev_computer.dtype, self.aev_computer.device,
+                    filename)
                 if self.output_length is None:
                     self.output_length = model_X.output_length
                 elif self.output_length != model_X.output_length:
                     raise ValueError(
-                        'output length of each atomic neural network must match')
+                        '''output length of each atomic neural networt
+                        must match''')
                 setattr(self, 'model_' + i + suffix, model_X)
-        elif 'from_nc' not in kwargs and 'per_species' in kwargs and 'reducer' in kwargs:
+        elif 'from_nc' not in kwargs and 'per_species' in kwargs and \
+                'reducer' in kwargs:
             self.suffixes = ['']
             per_species = kwargs['per_species']
             for i in per_species:
                 model_X = per_species[i]
                 if not hasattr(model_X, 'output_length'):
                     raise ValueError(
-                        'atomic neural network must explicitly specify output length')
+                        '''atomic neural network must explicitly specify
+                        output length''')
                 elif self.output_length is None:
                     self.output_length = model_X.output_length
                 elif self.output_length != model_X.output_length:
                     raise ValueError(
-                        'output length of each atomic neural network must match')
+                        '''output length of each atomic neural network must
+                        match''')
                 setattr(self, 'model_' + i, model_X)
             self.reducer = kwargs['reducer']
         else:
@@ -491,9 +516,11 @@ class ModelOnAEV(BenchmarkedModule):
     def compute_derivative(self, output, coordinates):
         """Compute the gradient d(output)/d(coordinates)"""
         # Since different conformations are independent, computing
-        # the derivatives of all outputs w.r.t. its own coordinate is equivalent
-        # to compute the derivative of the sum of all outputs w.r.t. all coordinates.
+        # the derivatives of all outputs w.r.t. its own coordinate is
+        # equivalent to compute the derivative of the sum of all outputs
+        # w.r.t. all coordinates.
-        return torch.autograd.grad(output.sum(), coordinates, create_graph=self.derivative_graph)[0]
+        return torch.autograd.grad(output.sum(), coordinates,
+                                   create_graph=self.derivative_graph)[0]
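The sum trick in `compute_derivative` deserves a standalone check: each conformation's output depends only on its own coordinates, so differentiating the summed output yields every per-conformation gradient in one backward pass. A toy verification (made-up function, not the model):

```python
import torch

coordinates = torch.randn(3, 4, 2, requires_grad=True)  # (conformations, atoms, xy)
output = (coordinates ** 2).sum(dim=(1, 2))             # one scalar per conformation

grad = torch.autograd.grad(output.sum(), coordinates)[0]
assert torch.allclose(grad, 2 * coordinates)  # d(x^2)/dx = 2x, no cross-mixing
```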
     def forward(self, coordinates, species):
         """Feed forward
@@ -509,13 +536,13 @@ class ModelOnAEV(BenchmarkedModule):
         Returns
         -------
         torch.Tensor or (torch.Tensor, torch.Tensor)
-            If derivative is turned off, then this function will return a pytorch
-            tensor of shape (conformations, output_length) for the output of each
-            conformation.
-            If derivative is turned on, then this function will return a pair of
-            pytorch tensors where the first tensor is the output tensor as when the
-            derivative is off, and the second tensor is a tensor of shape
-            (conformation, atoms, 3) storing the d(output)/dR.
+            If derivative is turned off, then this function will return a
+            pytorch tensor of shape (conformations, output_length) for the
+            output of each conformation.
+            If derivative is turned on, then this function will return a pair
+            of pytorch tensors where the first tensor is the output tensor as
+            when the derivative is off, and the second tensor is a tensor of
+            shape (conformation, atoms, 3) storing the d(output)/dR.
         """
         if not self.derivative:
             coordinates = coordinates.detach()
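Put together, the pipeline is driven like this; a sketch assuming the built-in network (`from_nc=None`) and made-up coordinates for a 3-atom molecule:

```python
import torch
import torchani

aev = torchani.SortedAEV(device=torch.device('cpu'))
model = torchani.ModelOnAEV(aev, derivative=True, from_nc=None)

coordinates = torch.rand(1, 3, 3)  # (conformations, atoms, xyz), made up
species = ['O', 'H', 'H']          # one label per atom
energies, derivatives = model(coordinates, species)  # d(output)/dR is (1, 3, 3)
```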
...
@@ -49,7 +49,8 @@ class anidataloader(object):
             exit('Error: file not found - '+store_file)
         self.store = h5py.File(store_file, 'r')
-    ''' Group recursive iterator (iterate through all groups in all branches and return datasets in dicts) '''
+    ''' Group recursive iterator (iterate through all groups
+        in all branches and return datasets in dicts) '''
     def h5py_dataset_iterator(self, g, prefix=''):
         for key in g.keys():
...