Unverified Commit 2ec2fb6d authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Modify dataset API to allow atomic properties (#231)

parent 4f63c32d
......@@ -34,7 +34,7 @@ Utilities
.. automodule:: torchani.utils
.. autofunction:: torchani.utils.pad
.. autofunction:: torchani.utils.pad_coordinates
.. autofunction:: torchani.utils.pad_atomic_properties
.. autofunction:: torchani.utils.present_species
.. autofunction:: torchani.utils.strip_redundant_padding
.. autofunction:: torchani.utils.map2central
......
......@@ -113,11 +113,11 @@ class TestAEV(unittest.TestCase):
species = self.transform(species)
radial = self.transform(radial)
angular = self.transform(angular)
species_coordinates.append((species, coordinates))
species_coordinates.append({'species': species, 'coordinates': coordinates})
radial_angular.append((radial, angular))
species, coordinates = torchani.utils.pad_coordinates(
species_coordinates = torchani.utils.pad_atomic_properties(
species_coordinates)
_, aev = self.aev_computer((species, coordinates))
_, aev = self.aev_computer((species_coordinates['species'], species_coordinates['coordinates']))
start = 0
for expected_radial, expected_angular in radial_angular:
conformations = expected_radial.shape[0]
......
......@@ -30,16 +30,20 @@ class TestData(unittest.TestCase):
coordinates2 = torch.randn(2, 8, 3)
species3 = torch.randint(4, (10, 20), dtype=torch.long)
coordinates3 = torch.randn(10, 20, 3)
species, coordinates = torchani.utils.pad_coordinates([
(species1, coordinates1),
(species2, coordinates2),
(species3, coordinates3),
species_coordinates = torchani.utils.pad_atomic_properties([
{'species': species1, 'coordinates': coordinates1},
{'species': species2, 'coordinates': coordinates2},
{'species': species3, 'coordinates': coordinates3},
])
species = species_coordinates['species']
coordinates = species_coordinates['coordinates']
natoms = (species >= 0).to(torch.long).sum(1)
chunks = torchani.data.split_batch(natoms, species, coordinates)
chunks = torchani.data.split_batch(natoms, species_coordinates)
start = 0
last = None
for s, c in chunks:
for chunk in chunks:
s = chunk['species']
c = chunk['coordinates']
n = (s >= 0).to(torch.long).sum(1)
if last is not None:
self.assertNotEqual(last[-1], n[0])
......@@ -47,19 +51,26 @@ class TestData(unittest.TestCase):
self.assertGreater(conformations, 0)
s_ = species[start:(start + conformations), ...]
c_ = coordinates[start:(start + conformations), ...]
s_, c_ = torchani.utils.strip_redundant_padding(s_, c_)
sc = torchani.utils.strip_redundant_padding({'species': s_, 'coordinates': c_})
s_ = sc['species']
c_ = sc['coordinates']
self._assertTensorEqual(s, s_)
self._assertTensorEqual(c, c_)
start += conformations
s, c = torchani.utils.pad_coordinates(chunks)
sc = torchani.utils.pad_atomic_properties(chunks)
s = sc['species']
c = sc['coordinates']
self._assertTensorEqual(s, species)
self._assertTensorEqual(c, coordinates)
def testTensorShape(self):
for i in self.ds:
input_, output = i
species, coordinates = torchani.utils.pad_coordinates(input_)
input_ = [{'species': x[0], 'coordinates': x[1]} for x in input_]
species_coordinates = torchani.utils.pad_atomic_properties(input_)
species = species_coordinates['species']
coordinates = species_coordinates['coordinates']
energies = output['energies']
self.assertEqual(len(species.shape), 2)
self.assertLessEqual(species.shape[0], batch_size)
......
......@@ -89,12 +89,12 @@ class TestEnergies(unittest.TestCase):
coordinates = self.transform(coordinates)
species = self.transform(species)
e = self.transform(e)
species_coordinates.append((species, coordinates))
species_coordinates.append({'species': species, 'coordinates': coordinates})
energies.append(e)
species, coordinates = torchani.utils.pad_coordinates(
species_coordinates = torchani.utils.pad_atomic_properties(
species_coordinates)
energies = torch.cat(energies)
_, energies_ = self.model((species, coordinates))
_, energies_ = self.model((species_coordinates['species'], species_coordinates['coordinates']))
max_diff = (energies - energies_).abs().max().item()
self.assertLess(max_diff, self.tolerance)
......
......@@ -55,11 +55,10 @@ class TestForce(unittest.TestCase):
species = self.transform(species)
forces = self.transform(forces)
coordinates.requires_grad_(True)
species_coordinates.append((species, coordinates))
coordinates_forces.append((coordinates, forces))
species, coordinates = torchani.utils.pad_coordinates(
species_coordinates.append({'species': species, 'coordinates': coordinates})
species_coordinates = torchani.utils.pad_atomic_properties(
species_coordinates)
_, energies = self.model((species, coordinates))
_, energies = self.model((species_coordinates['species'], species_coordinates['coordinates']))
energies = energies.sum()
for coordinates, forces in coordinates_forces:
derivative = torch.autograd.grad(energies, coordinates,
......
......@@ -6,17 +6,17 @@ import torchani
class TestPaddings(unittest.TestCase):
def testVectorSpecies(self):
species1 = torch.LongTensor([0, 2, 3, 1])
species1 = torch.tensor([[0, 2, 3, 1]])
coordinates1 = torch.zeros(5, 4, 3)
species2 = torch.LongTensor([3, 2, 0, 1, 0])
species2 = torch.tensor([[3, 2, 0, 1, 0]])
coordinates2 = torch.zeros(2, 5, 3)
species, coordinates = torchani.utils.pad_coordinates([
(species1, coordinates1),
(species2, coordinates2),
atomic_properties = torchani.utils.pad_atomic_properties([
{'species': species1, 'coordinates': coordinates1},
{'species': species2, 'coordinates': coordinates2},
])
self.assertEqual(species.shape[0], 7)
self.assertEqual(species.shape[1], 5)
expected_species = torch.LongTensor([
self.assertEqual(atomic_properties['species'].shape[0], 7)
self.assertEqual(atomic_properties['species'].shape[1], 5)
expected_species = torch.tensor([
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
......@@ -25,21 +25,21 @@ class TestPaddings(unittest.TestCase):
[3, 2, 0, 1, 0],
[3, 2, 0, 1, 0],
])
self.assertEqual((species - expected_species).abs().max().item(), 0)
self.assertEqual(coordinates.abs().max().item(), 0)
self.assertEqual((atomic_properties['species'] - expected_species).abs().max().item(), 0)
self.assertEqual(atomic_properties['coordinates'].abs().max().item(), 0)
def testTensorShape1NSpecies(self):
species1 = torch.LongTensor([[0, 2, 3, 1]])
species1 = torch.tensor([[0, 2, 3, 1]])
coordinates1 = torch.zeros(5, 4, 3)
species2 = torch.LongTensor([3, 2, 0, 1, 0])
species2 = torch.tensor([[3, 2, 0, 1, 0]])
coordinates2 = torch.zeros(2, 5, 3)
species, coordinates = torchani.utils.pad_coordinates([
(species1, coordinates1),
(species2, coordinates2),
atomic_properties = torchani.utils.pad_atomic_properties([
{'species': species1, 'coordinates': coordinates1},
{'species': species2, 'coordinates': coordinates2},
])
self.assertEqual(species.shape[0], 7)
self.assertEqual(species.shape[1], 5)
expected_species = torch.LongTensor([
self.assertEqual(atomic_properties['species'].shape[0], 7)
self.assertEqual(atomic_properties['species'].shape[1], 5)
expected_species = torch.tensor([
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
......@@ -48,11 +48,11 @@ class TestPaddings(unittest.TestCase):
[3, 2, 0, 1, 0],
[3, 2, 0, 1, 0],
])
self.assertEqual((species - expected_species).abs().max().item(), 0)
self.assertEqual(coordinates.abs().max().item(), 0)
self.assertEqual((atomic_properties['species'] - expected_species).abs().max().item(), 0)
self.assertEqual(atomic_properties['coordinates'].abs().max().item(), 0)
def testTensorSpecies(self):
species1 = torch.LongTensor([
species1 = torch.tensor([
[0, 2, 3, 1],
[0, 2, 3, 1],
[0, 2, 3, 1],
......@@ -60,15 +60,15 @@ class TestPaddings(unittest.TestCase):
[0, 2, 3, 1],
])
coordinates1 = torch.zeros(5, 4, 3)
species2 = torch.LongTensor([3, 2, 0, 1, 0])
species2 = torch.tensor([[3, 2, 0, 1, 0]])
coordinates2 = torch.zeros(2, 5, 3)
species, coordinates = torchani.utils.pad_coordinates([
(species1, coordinates1),
(species2, coordinates2),
atomic_properties = torchani.utils.pad_atomic_properties([
{'species': species1, 'coordinates': coordinates1},
{'species': species2, 'coordinates': coordinates2},
])
self.assertEqual(species.shape[0], 7)
self.assertEqual(species.shape[1], 5)
expected_species = torch.LongTensor([
self.assertEqual(atomic_properties['species'].shape[0], 7)
self.assertEqual(atomic_properties['species'].shape[1], 5)
expected_species = torch.tensor([
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
......@@ -77,22 +77,22 @@ class TestPaddings(unittest.TestCase):
[3, 2, 0, 1, 0],
[3, 2, 0, 1, 0],
])
self.assertEqual((species - expected_species).abs().max().item(), 0)
self.assertEqual(coordinates.abs().max().item(), 0)
self.assertEqual((atomic_properties['species'] - expected_species).abs().max().item(), 0)
self.assertEqual(atomic_properties['coordinates'].abs().max().item(), 0)
def testPadSpecies(self):
species1 = torch.LongTensor([
species1 = torch.tensor([
[0, 2, 3, 1],
[0, 2, 3, 1],
[0, 2, 3, 1],
[0, 2, 3, 1],
[0, 2, 3, 1],
])
species2 = torch.LongTensor([3, 2, 0, 1, 0]).expand(2, 5)
species2 = torch.tensor([[3, 2, 0, 1, 0]]).expand(2, 5)
species = torchani.utils.pad([species1, species2])
self.assertEqual(species.shape[0], 7)
self.assertEqual(species.shape[1], 5)
expected_species = torch.LongTensor([
expected_species = torch.tensor([
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
[0, 2, 3, 1, -1],
......@@ -104,9 +104,9 @@ class TestPaddings(unittest.TestCase):
self.assertEqual((species - expected_species).abs().max().item(), 0)
def testPresentSpecies(self):
species = torch.LongTensor([0, 1, 1, 0, 3, 7, -1, -1])
species = torch.tensor([0, 1, 1, 0, 3, 7, -1, -1])
present_species = torchani.utils.present_species(species)
expected = torch.LongTensor([0, 1, 3, 7])
expected = torch.tensor([0, 1, 3, 7])
self.assertEqual((expected - present_species).abs().max().item(), 0)
......@@ -120,23 +120,31 @@ class TestStripRedundantPadding(unittest.TestCase):
coordinates1 = torch.randn(5, 4, 3)
species2 = torch.randint(4, (2, 5), dtype=torch.long)
coordinates2 = torch.randn(2, 5, 3)
species12, coordinates12 = torchani.utils.pad_coordinates([
(species1, coordinates1),
(species2, coordinates2),
atomic_properties12 = torchani.utils.pad_atomic_properties([
{'species': species1, 'coordinates': coordinates1},
{'species': species2, 'coordinates': coordinates2},
])
species12 = atomic_properties12['species']
coordinates12 = atomic_properties12['coordinates']
species3 = torch.randint(4, (2, 10), dtype=torch.long)
coordinates3 = torch.randn(2, 10, 3)
species123, coordinates123 = torchani.utils.pad_coordinates([
(species1, coordinates1),
(species2, coordinates2),
(species3, coordinates3),
atomic_properties123 = torchani.utils.pad_atomic_properties([
{'species': species1, 'coordinates': coordinates1},
{'species': species2, 'coordinates': coordinates2},
{'species': species3, 'coordinates': coordinates3},
])
species1_, coordinates1_ = torchani.utils.strip_redundant_padding(
species123[:5, ...], coordinates123[:5, ...])
species123 = atomic_properties123['species']
coordinates123 = atomic_properties123['coordinates']
species_coordinates1_ = torchani.utils.strip_redundant_padding(
{'species': species123[:5, ...], 'coordinates': coordinates123[:5, ...]})
species1_ = species_coordinates1_['species']
coordinates1_ = species_coordinates1_['coordinates']
self._assertTensorEqual(species1_, species1)
self._assertTensorEqual(coordinates1_, coordinates1)
species12_, coordinates12_ = torchani.utils.strip_redundant_padding(
species123[:7, ...], coordinates123[:7, ...])
species_coordinates12_ = torchani.utils.strip_redundant_padding(
{'species': species123[:7, ...], 'coordinates': coordinates123[:7, ...]})
species12_ = species_coordinates12_['species']
coordinates12_ = species_coordinates12_['coordinates']
self._assertTensorEqual(species12_, species12)
self._assertTensorEqual(coordinates12_, coordinates12)
......
......@@ -21,21 +21,22 @@ def chunk_counts(counts, split):
for i in split:
count_chunks.append(counts[start:i])
start = i
chunk_conformations = [sum([y[1] for y in x]) for x in count_chunks]
chunk_molecules = [sum([y[1] for y in x]) for x in count_chunks]
chunk_maxatoms = [x[-1][0] for x in count_chunks]
return chunk_conformations, chunk_maxatoms
return chunk_molecules, chunk_maxatoms
def split_cost(counts, split):
split_min_cost = 40000
cost = 0
chunk_conformations, chunk_maxatoms = chunk_counts(counts, split)
for conformations, maxatoms in zip(chunk_conformations, chunk_maxatoms):
cost += max(conformations * maxatoms ** 2, split_min_cost)
chunk_molecules, chunk_maxatoms = chunk_counts(counts, split)
for molecules, maxatoms in zip(chunk_molecules, chunk_maxatoms):
cost += max(molecules * maxatoms ** 2, split_min_cost)
return cost
def split_batch(natoms, species, coordinates):
def split_batch(natoms, atomic_properties):
# count number of conformation by natoms
natoms = natoms.tolist()
counts = []
......@@ -47,6 +48,7 @@ def split_batch(natoms, species, coordinates):
counts[-1][1] += 1
else:
counts.append([i, 1])
# find best split using greedy strategy
split = []
cost = split_cost(counts, split)
......@@ -66,19 +68,21 @@ def split_batch(natoms, species, coordinates):
if improved:
split = cycle_split
cost = cycle_cost
# do split
start = 0
species_coordinates = []
chunk_conformations, _ = chunk_counts(counts, split)
for i in chunk_conformations:
s = species
end = start + i
s = species[start:end, ...]
c = coordinates[start:end, ...]
s, c = utils.strip_redundant_padding(s, c)
species_coordinates.append((s, c))
start = end
return species_coordinates
chunk_molecules, _ = chunk_counts(counts, split)
num_chunks = None
for k in atomic_properties:
atomic_properties[k] = atomic_properties[k].split(chunk_molecules)
if num_chunks is None:
num_chunks = len(atomic_properties[k])
else:
assert num_chunks == len(atomic_properties[k])
chunks = []
for i in range(num_chunks):
chunk = {k: atomic_properties[k][i] for k in atomic_properties}
chunks.append(utils.strip_redundant_padding(chunk))
return chunks
class BatchedANIDataset(Dataset):
......@@ -118,13 +122,24 @@ class BatchedANIDataset(Dataset):
batch_size (int): Number of different 3D structures in a single
minibatch.
shuffle (bool): Whether to shuffle the whole dataset.
properties (list): List of keys in the dataset to be loaded.
``'species'`` and ``'coordinates'`` are always loaded and need not
to be specified here.
properties (list): List of keys of `molecular` properties in the
dataset to be loaded. Here `molecular` means, no matter the number
of atoms that property always have fixed size, i.e. the tensor
shape of molecular properties should be (molecule, ...). An example
of molecular property is the molecular energies. ``'species'`` and
``'coordinates'`` are always loaded and need not to be specified
anywhere.
atomic_properties (list): List of keys of `atomic` properties in the
dataset to be loaded. Here `atomic` means, the size of property
is proportional to the number of atoms in the molecule, i.e. the
tensor shape of atomic properties should be (molecule, atoms, ...).
An example of atomic property is the forces. ``'species'`` and
``'coordinates'`` are always loaded and need not to be specified
anywhere.
transform (list): List of :class:`collections.abc.Callable` that
transform the data. Callables must take species, coordinates,
and properties of the whole dataset as arguments, and return
the transformed species, coordinates, and properties.
transform the data. Callables must take atomic properties,
properties as arguments, and return the transformed atomic
properties and properties.
dtype (:class:`torch.dtype`): dtype of coordinates and properties to
to convert the dataset to.
device (:class:`torch.dtype`): device to put tensors when iterating.
......@@ -134,7 +149,7 @@ class BatchedANIDataset(Dataset):
"""
def __init__(self, path, species_tensor_converter, batch_size,
shuffle=True, properties=['energies'], transform=(),
shuffle=True, properties=('energies',), atomic_properties=(), transform=(),
dtype=torch.get_default_dtype(), device=default_device):
super(BatchedANIDataset, self).__init__()
self.properties = properties
......@@ -153,68 +168,81 @@ class BatchedANIDataset(Dataset):
raise ValueError('Bad path')
# load full dataset
species_coordinates = []
atomic_properties_ = []
properties = {k: [] for k in self.properties}
for f in files:
for m in anidataloader(f):
s = species_tensor_converter(m['species'])
c = torch.from_numpy(m['coordinates']).to(torch.double)
species_coordinates.append((s, c))
atomic_properties_.append(dict(
species=species_tensor_converter(m['species']).unsqueeze(0),
**{
k: torch.from_numpy(m[k]).to(torch.double)
for k in ['coordinates'] + list(atomic_properties)
}
))
for i in properties:
p = torch.from_numpy(m[i]).to(torch.double)
properties[i].append(p)
species, coordinates = utils.pad_coordinates(species_coordinates)
atomic_properties = utils.pad_atomic_properties(atomic_properties_)
for i in properties:
properties[i] = torch.cat(properties[i])
# shuffle if required
conformations = coordinates.shape[0]
molecules = atomic_properties['species'].shape[0]
if shuffle:
indices = torch.randperm(conformations)
species = species.index_select(0, indices)
coordinates = coordinates.index_select(0, indices)
indices = torch.randperm(molecules)
for i in properties:
properties[i] = properties[i].index_select(0, indices)
for i in atomic_properties:
atomic_properties[i] = atomic_properties[i].index_select(0, indices)
# do transformations on data
for t in transform:
species, coordinates, properties = t(species, coordinates,
properties)
atomic_properties, properties = t(atomic_properties, properties)
# convert to desired dtype
species = species
coordinates = coordinates.to(dtype)
for k in properties:
properties[k] = properties[k].to(dtype)
for k in atomic_properties:
if k == 'species':
continue
atomic_properties[k] = atomic_properties[k].to(dtype)
# split into minibatches, and strip redundant padding
natoms = (species >= 0).to(torch.long).sum(1)
batches = []
num_batches = (conformations + batch_size - 1) // batch_size
# split into minibatches
for k in properties:
properties[k] = properties[k].split(batch_size)
for k in atomic_properties:
atomic_properties[k] = atomic_properties[k].split(batch_size)
# further split batch into chunks and strip redundant padding
self.batches = []
num_batches = (molecules + batch_size - 1) // batch_size
for i in range(num_batches):
start = i * batch_size
end = min((i + 1) * batch_size, conformations)
natoms_batch = natoms[start:end]
batch_properties = {k: v[i] for k, v in properties.items()}
batch_atomic_properties = {k: v[i] for k, v in atomic_properties.items()}
species = batch_atomic_properties['species']
natoms = (species >= 0).to(torch.long).sum(1)
# sort batch by number of atoms to prepare for splitting
natoms_batch, indices = natoms_batch.sort()
species_batch = species[start:end, ...].index_select(0, indices)
coordinates_batch = coordinates[start:end, ...] \
.index_select(0, indices)
properties_batch = {
k: properties[k][start:end, ...].index_select(0, indices)
.to(self.device) for k in properties
}
# further split batch into chunks
species_coordinates = split_batch(natoms_batch, species_batch,
coordinates_batch)
batch = species_coordinates, properties_batch
batches.append(batch)
self.batches = batches
natoms, indices = natoms.sort()
for k in batch_properties:
batch_properties[k] = batch_properties[k].index_select(0, indices)
for k in batch_atomic_properties:
batch_atomic_properties[k] = batch_atomic_properties[k].index_select(0, indices)
batch_atomic_properties = split_batch(natoms, batch_atomic_properties)
self.batches.append((batch_atomic_properties, batch_properties))
def __getitem__(self, idx):
species_coordinates, properties = self.batches[idx]
species_coordinates = [(s.to(self.device), c.to(self.device))
for s, c in species_coordinates]
atomic_properties, properties = self.batches[idx]
atomic_properties, properties = atomic_properties.copy(), properties.copy()
species_coordinates = []
for chunk in atomic_properties:
for k in chunk:
chunk[k] = chunk[k].to(self.device)
species_coordinates.append((chunk['species'], chunk['coordinates']))
for k in properties:
properties[k] = properties[k].to(self.device)
properties['atomic'] = atomic_properties
return species_coordinates, properties
def __len__(self):
......
import torch
import torch.utils.data
import math
from collections import defaultdict
def pad(species):
......@@ -30,41 +31,35 @@ def pad(species):
return torch.cat(padded_species)
def pad_coordinates(species_coordinates):
"""Put different species and coordinates together into single tensor.
def pad_atomic_properties(atomic_properties, padding_values=defaultdict(lambda: 0.0, species=-1)):
"""Put a sequence of atomic properties together into single tensor.
If the species and coordinates are from molecules of different number of
total atoms, then ghost atoms with atom type -1 and coordinate (0, 0, 0)
will be added to make it fit into the same shape.
Inputs are `[{'species': ..., ...}, {'species': ..., ...}, ...]` and the outputs
are `{'species': padded_tensor, ...}`
Arguments:
species_coordinates (:class:`collections.abc.Sequence`): sequence of
pairs of species and coordinates. Species must be of shape
``(N, A)`` and coordinates must be of shape ``(N, A, 3)``, where
``N`` is the number of 3D structures, ``A`` is the number of atoms.
Returns:
(:class:`torch.Tensor`, :class:`torch.Tensor`): Species, and
coordinates batched together.
atomic properties.
padding_values (dict): the value to fill to pad tensors to same size
"""
max_atoms = max([c.shape[1] for _, c in species_coordinates])
species = []
coordinates = []
for s, c in species_coordinates:
natoms = c.shape[1]
if len(s.shape) == 1:
s = s.unsqueeze(0)
if natoms < max_atoms:
padding = torch.full((s.shape[0], max_atoms - natoms), -1,
dtype=torch.long, device=s.device)
s = torch.cat([s, padding], dim=1)
padding = torch.full((c.shape[0], max_atoms - natoms, 3), 0,
dtype=c.dtype, device=c.device)
c = torch.cat([c, padding], dim=1)
s = s.expand(c.shape[0], max_atoms)
species.append(s)
coordinates.append(c)
return torch.cat(species), torch.cat(coordinates)
keys = list(atomic_properties[0])
anykey = keys[0]
max_atoms = max(x[anykey].shape[1] for x in atomic_properties)
padded = {k: [] for k in keys}
for p in atomic_properties:
num_molecules = max(v.shape[0] for v in p.values())
for k, v in p.items():
shape = list(v.shape)
padatoms = max_atoms - shape[1]
shape[1] = padatoms
padding = v.new_full(shape, padding_values[k])
v = torch.cat([v, padding], dim=1)
if v.shape[0] < num_molecules:
shape = list(v.shape)
shape[0] = num_molecules
v = v.expand(*shape)
padded[k].append(v)
return {k: torch.cat(v) for k, v in padded.items()}
# @torch.jit.script
......@@ -84,23 +79,20 @@ def present_species(species):
return present_species
def strip_redundant_padding(species, coordinates):
def strip_redundant_padding(atomic_properties):
"""Strip trailing padding atoms.
Arguments:
species (:class:`torch.Tensor`): Long tensor of shape
``(molecules, atoms)``.
coordinates (:class:`torch.Tensor`): Tensor of shape
``(molecules, atoms, 3)``.
atomic_properties (dict): properties to strip
Returns:
(:class:`torch.Tensor`, :class:`torch.Tensor`): species and coordinates
with redundant padding atoms stripped.
dict: same set of properties with redundant padding atoms stripped.
"""
species = atomic_properties['species']
non_padding = (species >= 0).any(dim=0).nonzero().squeeze()
species = species.index_select(1, non_padding)
coordinates = coordinates.index_select(1, non_padding)
return species, coordinates
for k in atomic_properties:
atomic_properties[k] = atomic_properties[k].index_select(1, non_padding)
return atomic_properties
def map2central(cell, coordinates, pbc):
......@@ -170,15 +162,16 @@ class EnergyShifter(torch.nn.Module):
self_energies[species == -1] = 0
return self_energies.sum(dim=1)
def subtract_from_dataset(self, species, coordinates, properties):
def subtract_from_dataset(self, atomic_properties, properties):
"""Transformer for :class:`torchani.data.BatchedANIDataset` that
subtract self energies.
"""
species = atomic_properties['species']
energies = properties['energies']
device = energies.device
energies = energies.to(torch.double) - self.sae(species).to(device)
properties['energies'] = energies
return species, coordinates, properties
return atomic_properties, properties
def forward(self, species_energies):
"""(species, molecular energies)->(species, molecular energies + sae)
......@@ -263,6 +256,6 @@ def vibrational_analysis(masses, hessian, unit='cm^-1'):
return wavenumbers, modes
__all__ = ['pad', 'pad_coordinates', 'present_species', 'hessian',
__all__ = ['pad', 'pad_atomic_properties', 'present_species', 'hessian',
'vibrational_analysis', 'strip_redundant_padding',
'ChemicalSymbolsToInts']
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment