Unverified Commit 2ec2fb6d authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

Modify dataset API to allow atomic properties (#231)

parent 4f63c32d
...@@ -34,7 +34,7 @@ Utilities ...@@ -34,7 +34,7 @@ Utilities
.. automodule:: torchani.utils .. automodule:: torchani.utils
.. autofunction:: torchani.utils.pad .. autofunction:: torchani.utils.pad
.. autofunction:: torchani.utils.pad_coordinates .. autofunction:: torchani.utils.pad_atomic_properties
.. autofunction:: torchani.utils.present_species .. autofunction:: torchani.utils.present_species
.. autofunction:: torchani.utils.strip_redundant_padding .. autofunction:: torchani.utils.strip_redundant_padding
.. autofunction:: torchani.utils.map2central .. autofunction:: torchani.utils.map2central
......
...@@ -113,11 +113,11 @@ class TestAEV(unittest.TestCase): ...@@ -113,11 +113,11 @@ class TestAEV(unittest.TestCase):
species = self.transform(species) species = self.transform(species)
radial = self.transform(radial) radial = self.transform(radial)
angular = self.transform(angular) angular = self.transform(angular)
species_coordinates.append((species, coordinates)) species_coordinates.append({'species': species, 'coordinates': coordinates})
radial_angular.append((radial, angular)) radial_angular.append((radial, angular))
species, coordinates = torchani.utils.pad_coordinates( species_coordinates = torchani.utils.pad_atomic_properties(
species_coordinates) species_coordinates)
_, aev = self.aev_computer((species, coordinates)) _, aev = self.aev_computer((species_coordinates['species'], species_coordinates['coordinates']))
start = 0 start = 0
for expected_radial, expected_angular in radial_angular: for expected_radial, expected_angular in radial_angular:
conformations = expected_radial.shape[0] conformations = expected_radial.shape[0]
......
...@@ -30,16 +30,20 @@ class TestData(unittest.TestCase): ...@@ -30,16 +30,20 @@ class TestData(unittest.TestCase):
coordinates2 = torch.randn(2, 8, 3) coordinates2 = torch.randn(2, 8, 3)
species3 = torch.randint(4, (10, 20), dtype=torch.long) species3 = torch.randint(4, (10, 20), dtype=torch.long)
coordinates3 = torch.randn(10, 20, 3) coordinates3 = torch.randn(10, 20, 3)
species, coordinates = torchani.utils.pad_coordinates([ species_coordinates = torchani.utils.pad_atomic_properties([
(species1, coordinates1), {'species': species1, 'coordinates': coordinates1},
(species2, coordinates2), {'species': species2, 'coordinates': coordinates2},
(species3, coordinates3), {'species': species3, 'coordinates': coordinates3},
]) ])
species = species_coordinates['species']
coordinates = species_coordinates['coordinates']
natoms = (species >= 0).to(torch.long).sum(1) natoms = (species >= 0).to(torch.long).sum(1)
chunks = torchani.data.split_batch(natoms, species, coordinates) chunks = torchani.data.split_batch(natoms, species_coordinates)
start = 0 start = 0
last = None last = None
for s, c in chunks: for chunk in chunks:
s = chunk['species']
c = chunk['coordinates']
n = (s >= 0).to(torch.long).sum(1) n = (s >= 0).to(torch.long).sum(1)
if last is not None: if last is not None:
self.assertNotEqual(last[-1], n[0]) self.assertNotEqual(last[-1], n[0])
...@@ -47,19 +51,26 @@ class TestData(unittest.TestCase): ...@@ -47,19 +51,26 @@ class TestData(unittest.TestCase):
self.assertGreater(conformations, 0) self.assertGreater(conformations, 0)
s_ = species[start:(start + conformations), ...] s_ = species[start:(start + conformations), ...]
c_ = coordinates[start:(start + conformations), ...] c_ = coordinates[start:(start + conformations), ...]
s_, c_ = torchani.utils.strip_redundant_padding(s_, c_) sc = torchani.utils.strip_redundant_padding({'species': s_, 'coordinates': c_})
s_ = sc['species']
c_ = sc['coordinates']
self._assertTensorEqual(s, s_) self._assertTensorEqual(s, s_)
self._assertTensorEqual(c, c_) self._assertTensorEqual(c, c_)
start += conformations start += conformations
s, c = torchani.utils.pad_coordinates(chunks) sc = torchani.utils.pad_atomic_properties(chunks)
s = sc['species']
c = sc['coordinates']
self._assertTensorEqual(s, species) self._assertTensorEqual(s, species)
self._assertTensorEqual(c, coordinates) self._assertTensorEqual(c, coordinates)
def testTensorShape(self): def testTensorShape(self):
for i in self.ds: for i in self.ds:
input_, output = i input_, output = i
species, coordinates = torchani.utils.pad_coordinates(input_) input_ = [{'species': x[0], 'coordinates': x[1]} for x in input_]
species_coordinates = torchani.utils.pad_atomic_properties(input_)
species = species_coordinates['species']
coordinates = species_coordinates['coordinates']
energies = output['energies'] energies = output['energies']
self.assertEqual(len(species.shape), 2) self.assertEqual(len(species.shape), 2)
self.assertLessEqual(species.shape[0], batch_size) self.assertLessEqual(species.shape[0], batch_size)
......
...@@ -89,12 +89,12 @@ class TestEnergies(unittest.TestCase): ...@@ -89,12 +89,12 @@ class TestEnergies(unittest.TestCase):
coordinates = self.transform(coordinates) coordinates = self.transform(coordinates)
species = self.transform(species) species = self.transform(species)
e = self.transform(e) e = self.transform(e)
species_coordinates.append((species, coordinates)) species_coordinates.append({'species': species, 'coordinates': coordinates})
energies.append(e) energies.append(e)
species, coordinates = torchani.utils.pad_coordinates( species_coordinates = torchani.utils.pad_atomic_properties(
species_coordinates) species_coordinates)
energies = torch.cat(energies) energies = torch.cat(energies)
_, energies_ = self.model((species, coordinates)) _, energies_ = self.model((species_coordinates['species'], species_coordinates['coordinates']))
max_diff = (energies - energies_).abs().max().item() max_diff = (energies - energies_).abs().max().item()
self.assertLess(max_diff, self.tolerance) self.assertLess(max_diff, self.tolerance)
......
...@@ -55,11 +55,10 @@ class TestForce(unittest.TestCase): ...@@ -55,11 +55,10 @@ class TestForce(unittest.TestCase):
species = self.transform(species) species = self.transform(species)
forces = self.transform(forces) forces = self.transform(forces)
coordinates.requires_grad_(True) coordinates.requires_grad_(True)
species_coordinates.append((species, coordinates)) species_coordinates.append({'species': species, 'coordinates': coordinates})
coordinates_forces.append((coordinates, forces)) species_coordinates = torchani.utils.pad_atomic_properties(
species, coordinates = torchani.utils.pad_coordinates(
species_coordinates) species_coordinates)
_, energies = self.model((species, coordinates)) _, energies = self.model((species_coordinates['species'], species_coordinates['coordinates']))
energies = energies.sum() energies = energies.sum()
for coordinates, forces in coordinates_forces: for coordinates, forces in coordinates_forces:
derivative = torch.autograd.grad(energies, coordinates, derivative = torch.autograd.grad(energies, coordinates,
......
...@@ -6,17 +6,17 @@ import torchani ...@@ -6,17 +6,17 @@ import torchani
class TestPaddings(unittest.TestCase): class TestPaddings(unittest.TestCase):
def testVectorSpecies(self): def testVectorSpecies(self):
species1 = torch.LongTensor([0, 2, 3, 1]) species1 = torch.tensor([[0, 2, 3, 1]])
coordinates1 = torch.zeros(5, 4, 3) coordinates1 = torch.zeros(5, 4, 3)
species2 = torch.LongTensor([3, 2, 0, 1, 0]) species2 = torch.tensor([[3, 2, 0, 1, 0]])
coordinates2 = torch.zeros(2, 5, 3) coordinates2 = torch.zeros(2, 5, 3)
species, coordinates = torchani.utils.pad_coordinates([ atomic_properties = torchani.utils.pad_atomic_properties([
(species1, coordinates1), {'species': species1, 'coordinates': coordinates1},
(species2, coordinates2), {'species': species2, 'coordinates': coordinates2},
]) ])
self.assertEqual(species.shape[0], 7) self.assertEqual(atomic_properties['species'].shape[0], 7)
self.assertEqual(species.shape[1], 5) self.assertEqual(atomic_properties['species'].shape[1], 5)
expected_species = torch.LongTensor([ expected_species = torch.tensor([
[0, 2, 3, 1, -1], [0, 2, 3, 1, -1],
[0, 2, 3, 1, -1], [0, 2, 3, 1, -1],
[0, 2, 3, 1, -1], [0, 2, 3, 1, -1],
...@@ -25,21 +25,21 @@ class TestPaddings(unittest.TestCase): ...@@ -25,21 +25,21 @@ class TestPaddings(unittest.TestCase):
[3, 2, 0, 1, 0], [3, 2, 0, 1, 0],
[3, 2, 0, 1, 0], [3, 2, 0, 1, 0],
]) ])
self.assertEqual((species - expected_species).abs().max().item(), 0) self.assertEqual((atomic_properties['species'] - expected_species).abs().max().item(), 0)
self.assertEqual(coordinates.abs().max().item(), 0) self.assertEqual(atomic_properties['coordinates'].abs().max().item(), 0)
def testTensorShape1NSpecies(self): def testTensorShape1NSpecies(self):
species1 = torch.LongTensor([[0, 2, 3, 1]]) species1 = torch.tensor([[0, 2, 3, 1]])
coordinates1 = torch.zeros(5, 4, 3) coordinates1 = torch.zeros(5, 4, 3)
species2 = torch.LongTensor([3, 2, 0, 1, 0]) species2 = torch.tensor([[3, 2, 0, 1, 0]])
coordinates2 = torch.zeros(2, 5, 3) coordinates2 = torch.zeros(2, 5, 3)
species, coordinates = torchani.utils.pad_coordinates([ atomic_properties = torchani.utils.pad_atomic_properties([
(species1, coordinates1), {'species': species1, 'coordinates': coordinates1},
(species2, coordinates2), {'species': species2, 'coordinates': coordinates2},
]) ])
self.assertEqual(species.shape[0], 7) self.assertEqual(atomic_properties['species'].shape[0], 7)
self.assertEqual(species.shape[1], 5) self.assertEqual(atomic_properties['species'].shape[1], 5)
expected_species = torch.LongTensor([ expected_species = torch.tensor([
[0, 2, 3, 1, -1], [0, 2, 3, 1, -1],
[0, 2, 3, 1, -1], [0, 2, 3, 1, -1],
[0, 2, 3, 1, -1], [0, 2, 3, 1, -1],
...@@ -48,11 +48,11 @@ class TestPaddings(unittest.TestCase): ...@@ -48,11 +48,11 @@ class TestPaddings(unittest.TestCase):
[3, 2, 0, 1, 0], [3, 2, 0, 1, 0],
[3, 2, 0, 1, 0], [3, 2, 0, 1, 0],
]) ])
self.assertEqual((species - expected_species).abs().max().item(), 0) self.assertEqual((atomic_properties['species'] - expected_species).abs().max().item(), 0)
self.assertEqual(coordinates.abs().max().item(), 0) self.assertEqual(atomic_properties['coordinates'].abs().max().item(), 0)
def testTensorSpecies(self): def testTensorSpecies(self):
species1 = torch.LongTensor([ species1 = torch.tensor([
[0, 2, 3, 1], [0, 2, 3, 1],
[0, 2, 3, 1], [0, 2, 3, 1],
[0, 2, 3, 1], [0, 2, 3, 1],
...@@ -60,15 +60,15 @@ class TestPaddings(unittest.TestCase): ...@@ -60,15 +60,15 @@ class TestPaddings(unittest.TestCase):
[0, 2, 3, 1], [0, 2, 3, 1],
]) ])
coordinates1 = torch.zeros(5, 4, 3) coordinates1 = torch.zeros(5, 4, 3)
species2 = torch.LongTensor([3, 2, 0, 1, 0]) species2 = torch.tensor([[3, 2, 0, 1, 0]])
coordinates2 = torch.zeros(2, 5, 3) coordinates2 = torch.zeros(2, 5, 3)
species, coordinates = torchani.utils.pad_coordinates([ atomic_properties = torchani.utils.pad_atomic_properties([
(species1, coordinates1), {'species': species1, 'coordinates': coordinates1},
(species2, coordinates2), {'species': species2, 'coordinates': coordinates2},
]) ])
self.assertEqual(species.shape[0], 7) self.assertEqual(atomic_properties['species'].shape[0], 7)
self.assertEqual(species.shape[1], 5) self.assertEqual(atomic_properties['species'].shape[1], 5)
expected_species = torch.LongTensor([ expected_species = torch.tensor([
[0, 2, 3, 1, -1], [0, 2, 3, 1, -1],
[0, 2, 3, 1, -1], [0, 2, 3, 1, -1],
[0, 2, 3, 1, -1], [0, 2, 3, 1, -1],
...@@ -77,22 +77,22 @@ class TestPaddings(unittest.TestCase): ...@@ -77,22 +77,22 @@ class TestPaddings(unittest.TestCase):
[3, 2, 0, 1, 0], [3, 2, 0, 1, 0],
[3, 2, 0, 1, 0], [3, 2, 0, 1, 0],
]) ])
self.assertEqual((species - expected_species).abs().max().item(), 0) self.assertEqual((atomic_properties['species'] - expected_species).abs().max().item(), 0)
self.assertEqual(coordinates.abs().max().item(), 0) self.assertEqual(atomic_properties['coordinates'].abs().max().item(), 0)
def testPadSpecies(self): def testPadSpecies(self):
species1 = torch.LongTensor([ species1 = torch.tensor([
[0, 2, 3, 1], [0, 2, 3, 1],
[0, 2, 3, 1], [0, 2, 3, 1],
[0, 2, 3, 1], [0, 2, 3, 1],
[0, 2, 3, 1], [0, 2, 3, 1],
[0, 2, 3, 1], [0, 2, 3, 1],
]) ])
species2 = torch.LongTensor([3, 2, 0, 1, 0]).expand(2, 5) species2 = torch.tensor([[3, 2, 0, 1, 0]]).expand(2, 5)
species = torchani.utils.pad([species1, species2]) species = torchani.utils.pad([species1, species2])
self.assertEqual(species.shape[0], 7) self.assertEqual(species.shape[0], 7)
self.assertEqual(species.shape[1], 5) self.assertEqual(species.shape[1], 5)
expected_species = torch.LongTensor([ expected_species = torch.tensor([
[0, 2, 3, 1, -1], [0, 2, 3, 1, -1],
[0, 2, 3, 1, -1], [0, 2, 3, 1, -1],
[0, 2, 3, 1, -1], [0, 2, 3, 1, -1],
...@@ -104,9 +104,9 @@ class TestPaddings(unittest.TestCase): ...@@ -104,9 +104,9 @@ class TestPaddings(unittest.TestCase):
self.assertEqual((species - expected_species).abs().max().item(), 0) self.assertEqual((species - expected_species).abs().max().item(), 0)
def testPresentSpecies(self): def testPresentSpecies(self):
species = torch.LongTensor([0, 1, 1, 0, 3, 7, -1, -1]) species = torch.tensor([0, 1, 1, 0, 3, 7, -1, -1])
present_species = torchani.utils.present_species(species) present_species = torchani.utils.present_species(species)
expected = torch.LongTensor([0, 1, 3, 7]) expected = torch.tensor([0, 1, 3, 7])
self.assertEqual((expected - present_species).abs().max().item(), 0) self.assertEqual((expected - present_species).abs().max().item(), 0)
...@@ -120,23 +120,31 @@ class TestStripRedundantPadding(unittest.TestCase): ...@@ -120,23 +120,31 @@ class TestStripRedundantPadding(unittest.TestCase):
coordinates1 = torch.randn(5, 4, 3) coordinates1 = torch.randn(5, 4, 3)
species2 = torch.randint(4, (2, 5), dtype=torch.long) species2 = torch.randint(4, (2, 5), dtype=torch.long)
coordinates2 = torch.randn(2, 5, 3) coordinates2 = torch.randn(2, 5, 3)
species12, coordinates12 = torchani.utils.pad_coordinates([ atomic_properties12 = torchani.utils.pad_atomic_properties([
(species1, coordinates1), {'species': species1, 'coordinates': coordinates1},
(species2, coordinates2), {'species': species2, 'coordinates': coordinates2},
]) ])
species12 = atomic_properties12['species']
coordinates12 = atomic_properties12['coordinates']
species3 = torch.randint(4, (2, 10), dtype=torch.long) species3 = torch.randint(4, (2, 10), dtype=torch.long)
coordinates3 = torch.randn(2, 10, 3) coordinates3 = torch.randn(2, 10, 3)
species123, coordinates123 = torchani.utils.pad_coordinates([ atomic_properties123 = torchani.utils.pad_atomic_properties([
(species1, coordinates1), {'species': species1, 'coordinates': coordinates1},
(species2, coordinates2), {'species': species2, 'coordinates': coordinates2},
(species3, coordinates3), {'species': species3, 'coordinates': coordinates3},
]) ])
species1_, coordinates1_ = torchani.utils.strip_redundant_padding( species123 = atomic_properties123['species']
species123[:5, ...], coordinates123[:5, ...]) coordinates123 = atomic_properties123['coordinates']
species_coordinates1_ = torchani.utils.strip_redundant_padding(
{'species': species123[:5, ...], 'coordinates': coordinates123[:5, ...]})
species1_ = species_coordinates1_['species']
coordinates1_ = species_coordinates1_['coordinates']
self._assertTensorEqual(species1_, species1) self._assertTensorEqual(species1_, species1)
self._assertTensorEqual(coordinates1_, coordinates1) self._assertTensorEqual(coordinates1_, coordinates1)
species12_, coordinates12_ = torchani.utils.strip_redundant_padding( species_coordinates12_ = torchani.utils.strip_redundant_padding(
species123[:7, ...], coordinates123[:7, ...]) {'species': species123[:7, ...], 'coordinates': coordinates123[:7, ...]})
species12_ = species_coordinates12_['species']
coordinates12_ = species_coordinates12_['coordinates']
self._assertTensorEqual(species12_, species12) self._assertTensorEqual(species12_, species12)
self._assertTensorEqual(coordinates12_, coordinates12) self._assertTensorEqual(coordinates12_, coordinates12)
......
...@@ -21,21 +21,22 @@ def chunk_counts(counts, split): ...@@ -21,21 +21,22 @@ def chunk_counts(counts, split):
for i in split: for i in split:
count_chunks.append(counts[start:i]) count_chunks.append(counts[start:i])
start = i start = i
chunk_conformations = [sum([y[1] for y in x]) for x in count_chunks] chunk_molecules = [sum([y[1] for y in x]) for x in count_chunks]
chunk_maxatoms = [x[-1][0] for x in count_chunks] chunk_maxatoms = [x[-1][0] for x in count_chunks]
return chunk_conformations, chunk_maxatoms return chunk_molecules, chunk_maxatoms
def split_cost(counts, split): def split_cost(counts, split):
split_min_cost = 40000 split_min_cost = 40000
cost = 0 cost = 0
chunk_conformations, chunk_maxatoms = chunk_counts(counts, split) chunk_molecules, chunk_maxatoms = chunk_counts(counts, split)
for conformations, maxatoms in zip(chunk_conformations, chunk_maxatoms): for molecules, maxatoms in zip(chunk_molecules, chunk_maxatoms):
cost += max(conformations * maxatoms ** 2, split_min_cost) cost += max(molecules * maxatoms ** 2, split_min_cost)
return cost return cost
def split_batch(natoms, species, coordinates): def split_batch(natoms, atomic_properties):
# count number of conformation by natoms # count number of conformation by natoms
natoms = natoms.tolist() natoms = natoms.tolist()
counts = [] counts = []
...@@ -47,6 +48,7 @@ def split_batch(natoms, species, coordinates): ...@@ -47,6 +48,7 @@ def split_batch(natoms, species, coordinates):
counts[-1][1] += 1 counts[-1][1] += 1
else: else:
counts.append([i, 1]) counts.append([i, 1])
# find best split using greedy strategy # find best split using greedy strategy
split = [] split = []
cost = split_cost(counts, split) cost = split_cost(counts, split)
...@@ -66,19 +68,21 @@ def split_batch(natoms, species, coordinates): ...@@ -66,19 +68,21 @@ def split_batch(natoms, species, coordinates):
if improved: if improved:
split = cycle_split split = cycle_split
cost = cycle_cost cost = cycle_cost
# do split # do split
start = 0 chunk_molecules, _ = chunk_counts(counts, split)
species_coordinates = [] num_chunks = None
chunk_conformations, _ = chunk_counts(counts, split) for k in atomic_properties:
for i in chunk_conformations: atomic_properties[k] = atomic_properties[k].split(chunk_molecules)
s = species if num_chunks is None:
end = start + i num_chunks = len(atomic_properties[k])
s = species[start:end, ...] else:
c = coordinates[start:end, ...] assert num_chunks == len(atomic_properties[k])
s, c = utils.strip_redundant_padding(s, c) chunks = []
species_coordinates.append((s, c)) for i in range(num_chunks):
start = end chunk = {k: atomic_properties[k][i] for k in atomic_properties}
return species_coordinates chunks.append(utils.strip_redundant_padding(chunk))
return chunks
class BatchedANIDataset(Dataset): class BatchedANIDataset(Dataset):
...@@ -118,13 +122,24 @@ class BatchedANIDataset(Dataset): ...@@ -118,13 +122,24 @@ class BatchedANIDataset(Dataset):
batch_size (int): Number of different 3D structures in a single batch_size (int): Number of different 3D structures in a single
minibatch. minibatch.
shuffle (bool): Whether to shuffle the whole dataset. shuffle (bool): Whether to shuffle the whole dataset.
properties (list): List of keys in the dataset to be loaded. properties (list): List of keys of `molecular` properties in the
``'species'`` and ``'coordinates'`` are always loaded and need not dataset to be loaded. Here `molecular` means, no matter the number
to be specified here. of atoms that property always have fixed size, i.e. the tensor
shape of molecular properties should be (molecule, ...). An example
of molecular property is the molecular energies. ``'species'`` and
``'coordinates'`` are always loaded and need not to be specified
anywhere.
atomic_properties (list): List of keys of `atomic` properties in the
dataset to be loaded. Here `atomic` means, the size of property
is proportional to the number of atoms in the molecule, i.e. the
tensor shape of atomic properties should be (molecule, atoms, ...).
An example of atomic property is the forces. ``'species'`` and
``'coordinates'`` are always loaded and need not to be specified
anywhere.
transform (list): List of :class:`collections.abc.Callable` that transform (list): List of :class:`collections.abc.Callable` that
transform the data. Callables must take species, coordinates, transform the data. Callables must take atomic properties,
and properties of the whole dataset as arguments, and return properties as arguments, and return the transformed atomic
the transformed species, coordinates, and properties. properties and properties.
dtype (:class:`torch.dtype`): dtype of coordinates and properties to dtype (:class:`torch.dtype`): dtype of coordinates and properties to
to convert the dataset to. to convert the dataset to.
device (:class:`torch.dtype`): device to put tensors when iterating. device (:class:`torch.dtype`): device to put tensors when iterating.
...@@ -134,7 +149,7 @@ class BatchedANIDataset(Dataset): ...@@ -134,7 +149,7 @@ class BatchedANIDataset(Dataset):
""" """
def __init__(self, path, species_tensor_converter, batch_size, def __init__(self, path, species_tensor_converter, batch_size,
shuffle=True, properties=['energies'], transform=(), shuffle=True, properties=('energies',), atomic_properties=(), transform=(),
dtype=torch.get_default_dtype(), device=default_device): dtype=torch.get_default_dtype(), device=default_device):
super(BatchedANIDataset, self).__init__() super(BatchedANIDataset, self).__init__()
self.properties = properties self.properties = properties
...@@ -153,68 +168,81 @@ class BatchedANIDataset(Dataset): ...@@ -153,68 +168,81 @@ class BatchedANIDataset(Dataset):
raise ValueError('Bad path') raise ValueError('Bad path')
# load full dataset # load full dataset
species_coordinates = [] atomic_properties_ = []
properties = {k: [] for k in self.properties} properties = {k: [] for k in self.properties}
for f in files: for f in files:
for m in anidataloader(f): for m in anidataloader(f):
s = species_tensor_converter(m['species']) atomic_properties_.append(dict(
c = torch.from_numpy(m['coordinates']).to(torch.double) species=species_tensor_converter(m['species']).unsqueeze(0),
species_coordinates.append((s, c)) **{
k: torch.from_numpy(m[k]).to(torch.double)
for k in ['coordinates'] + list(atomic_properties)
}
))
for i in properties: for i in properties:
p = torch.from_numpy(m[i]).to(torch.double) p = torch.from_numpy(m[i]).to(torch.double)
properties[i].append(p) properties[i].append(p)
species, coordinates = utils.pad_coordinates(species_coordinates) atomic_properties = utils.pad_atomic_properties(atomic_properties_)
for i in properties: for i in properties:
properties[i] = torch.cat(properties[i]) properties[i] = torch.cat(properties[i])
# shuffle if required # shuffle if required
conformations = coordinates.shape[0] molecules = atomic_properties['species'].shape[0]
if shuffle: if shuffle:
indices = torch.randperm(conformations) indices = torch.randperm(molecules)
species = species.index_select(0, indices)
coordinates = coordinates.index_select(0, indices)
for i in properties: for i in properties:
properties[i] = properties[i].index_select(0, indices) properties[i] = properties[i].index_select(0, indices)
for i in atomic_properties:
atomic_properties[i] = atomic_properties[i].index_select(0, indices)
# do transformations on data # do transformations on data
for t in transform: for t in transform:
species, coordinates, properties = t(species, coordinates, atomic_properties, properties = t(atomic_properties, properties)
properties)
# convert to desired dtype # convert to desired dtype
species = species
coordinates = coordinates.to(dtype)
for k in properties: for k in properties:
properties[k] = properties[k].to(dtype) properties[k] = properties[k].to(dtype)
for k in atomic_properties:
if k == 'species':
continue
atomic_properties[k] = atomic_properties[k].to(dtype)
# split into minibatches, and strip redundant padding # split into minibatches
natoms = (species >= 0).to(torch.long).sum(1) for k in properties:
batches = [] properties[k] = properties[k].split(batch_size)
num_batches = (conformations + batch_size - 1) // batch_size for k in atomic_properties:
atomic_properties[k] = atomic_properties[k].split(batch_size)
# further split batch into chunks and strip redundant padding
self.batches = []
num_batches = (molecules + batch_size - 1) // batch_size
for i in range(num_batches): for i in range(num_batches):
start = i * batch_size batch_properties = {k: v[i] for k, v in properties.items()}
end = min((i + 1) * batch_size, conformations) batch_atomic_properties = {k: v[i] for k, v in atomic_properties.items()}
natoms_batch = natoms[start:end] species = batch_atomic_properties['species']
natoms = (species >= 0).to(torch.long).sum(1)
# sort batch by number of atoms to prepare for splitting # sort batch by number of atoms to prepare for splitting
natoms_batch, indices = natoms_batch.sort() natoms, indices = natoms.sort()
species_batch = species[start:end, ...].index_select(0, indices) for k in batch_properties:
coordinates_batch = coordinates[start:end, ...] \ batch_properties[k] = batch_properties[k].index_select(0, indices)
.index_select(0, indices) for k in batch_atomic_properties:
properties_batch = { batch_atomic_properties[k] = batch_atomic_properties[k].index_select(0, indices)
k: properties[k][start:end, ...].index_select(0, indices)
.to(self.device) for k in properties batch_atomic_properties = split_batch(natoms, batch_atomic_properties)
} self.batches.append((batch_atomic_properties, batch_properties))
# further split batch into chunks
species_coordinates = split_batch(natoms_batch, species_batch,
coordinates_batch)
batch = species_coordinates, properties_batch
batches.append(batch)
self.batches = batches
def __getitem__(self, idx): def __getitem__(self, idx):
species_coordinates, properties = self.batches[idx] atomic_properties, properties = self.batches[idx]
species_coordinates = [(s.to(self.device), c.to(self.device)) atomic_properties, properties = atomic_properties.copy(), properties.copy()
for s, c in species_coordinates] species_coordinates = []
for chunk in atomic_properties:
for k in chunk:
chunk[k] = chunk[k].to(self.device)
species_coordinates.append((chunk['species'], chunk['coordinates']))
for k in properties:
properties[k] = properties[k].to(self.device)
properties['atomic'] = atomic_properties
return species_coordinates, properties return species_coordinates, properties
def __len__(self): def __len__(self):
......
import torch import torch
import torch.utils.data import torch.utils.data
import math import math
from collections import defaultdict
def pad(species): def pad(species):
...@@ -30,41 +31,35 @@ def pad(species): ...@@ -30,41 +31,35 @@ def pad(species):
return torch.cat(padded_species) return torch.cat(padded_species)
def pad_coordinates(species_coordinates): def pad_atomic_properties(atomic_properties, padding_values=defaultdict(lambda: 0.0, species=-1)):
"""Put different species and coordinates together into single tensor. """Put a sequence of atomic properties together into single tensor.
If the species and coordinates are from molecules of different number of Inputs are `[{'species': ..., ...}, {'species': ..., ...}, ...]` and the outputs
total atoms, then ghost atoms with atom type -1 and coordinate (0, 0, 0) are `{'species': padded_tensor, ...}`
will be added to make it fit into the same shape.
Arguments: Arguments:
species_coordinates (:class:`collections.abc.Sequence`): sequence of species_coordinates (:class:`collections.abc.Sequence`): sequence of
pairs of species and coordinates. Species must be of shape atomic properties.
``(N, A)`` and coordinates must be of shape ``(N, A, 3)``, where padding_values (dict): the value to fill to pad tensors to same size
``N`` is the number of 3D structures, ``A`` is the number of atoms.
Returns:
(:class:`torch.Tensor`, :class:`torch.Tensor`): Species, and
coordinates batched together.
""" """
max_atoms = max([c.shape[1] for _, c in species_coordinates]) keys = list(atomic_properties[0])
species = [] anykey = keys[0]
coordinates = [] max_atoms = max(x[anykey].shape[1] for x in atomic_properties)
for s, c in species_coordinates: padded = {k: [] for k in keys}
natoms = c.shape[1] for p in atomic_properties:
if len(s.shape) == 1: num_molecules = max(v.shape[0] for v in p.values())
s = s.unsqueeze(0) for k, v in p.items():
if natoms < max_atoms: shape = list(v.shape)
padding = torch.full((s.shape[0], max_atoms - natoms), -1, padatoms = max_atoms - shape[1]
dtype=torch.long, device=s.device) shape[1] = padatoms
s = torch.cat([s, padding], dim=1) padding = v.new_full(shape, padding_values[k])
padding = torch.full((c.shape[0], max_atoms - natoms, 3), 0, v = torch.cat([v, padding], dim=1)
dtype=c.dtype, device=c.device) if v.shape[0] < num_molecules:
c = torch.cat([c, padding], dim=1) shape = list(v.shape)
s = s.expand(c.shape[0], max_atoms) shape[0] = num_molecules
species.append(s) v = v.expand(*shape)
coordinates.append(c) padded[k].append(v)
return torch.cat(species), torch.cat(coordinates) return {k: torch.cat(v) for k, v in padded.items()}
# @torch.jit.script # @torch.jit.script
...@@ -84,23 +79,20 @@ def present_species(species): ...@@ -84,23 +79,20 @@ def present_species(species):
return present_species return present_species
def strip_redundant_padding(species, coordinates): def strip_redundant_padding(atomic_properties):
"""Strip trailing padding atoms. """Strip trailing padding atoms.
Arguments: Arguments:
species (:class:`torch.Tensor`): Long tensor of shape atomic_properties (dict): properties to strip
``(molecules, atoms)``.
coordinates (:class:`torch.Tensor`): Tensor of shape
``(molecules, atoms, 3)``.
Returns: Returns:
(:class:`torch.Tensor`, :class:`torch.Tensor`): species and coordinates dict: same set of properties with redundant padding atoms stripped.
with redundant padding atoms stripped.
""" """
species = atomic_properties['species']
non_padding = (species >= 0).any(dim=0).nonzero().squeeze() non_padding = (species >= 0).any(dim=0).nonzero().squeeze()
species = species.index_select(1, non_padding) for k in atomic_properties:
coordinates = coordinates.index_select(1, non_padding) atomic_properties[k] = atomic_properties[k].index_select(1, non_padding)
return species, coordinates return atomic_properties
def map2central(cell, coordinates, pbc): def map2central(cell, coordinates, pbc):
...@@ -170,15 +162,16 @@ class EnergyShifter(torch.nn.Module): ...@@ -170,15 +162,16 @@ class EnergyShifter(torch.nn.Module):
self_energies[species == -1] = 0 self_energies[species == -1] = 0
return self_energies.sum(dim=1) return self_energies.sum(dim=1)
def subtract_from_dataset(self, species, coordinates, properties): def subtract_from_dataset(self, atomic_properties, properties):
"""Transformer for :class:`torchani.data.BatchedANIDataset` that """Transformer for :class:`torchani.data.BatchedANIDataset` that
subtract self energies. subtract self energies.
""" """
species = atomic_properties['species']
energies = properties['energies'] energies = properties['energies']
device = energies.device device = energies.device
energies = energies.to(torch.double) - self.sae(species).to(device) energies = energies.to(torch.double) - self.sae(species).to(device)
properties['energies'] = energies properties['energies'] = energies
return species, coordinates, properties return atomic_properties, properties
def forward(self, species_energies): def forward(self, species_energies):
"""(species, molecular energies)->(species, molecular energies + sae) """(species, molecular energies)->(species, molecular energies + sae)
...@@ -263,6 +256,6 @@ def vibrational_analysis(masses, hessian, unit='cm^-1'): ...@@ -263,6 +256,6 @@ def vibrational_analysis(masses, hessian, unit='cm^-1'):
return wavenumbers, modes return wavenumbers, modes
__all__ = ['pad', 'pad_coordinates', 'present_species', 'hessian', __all__ = ['pad', 'pad_atomic_properties', 'present_species', 'hessian',
'vibrational_analysis', 'strip_redundant_padding', 'vibrational_analysis', 'strip_redundant_padding',
'ChemicalSymbolsToInts'] 'ChemicalSymbolsToInts']
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment