Unverified Commit 6b058c6e authored by Gao, Xiang, committed by GitHub

Remove everything about chunking (#432)

* Remove everything about chunking

* aev.py

* neurochem trainer

* training-benchmark-nsys-profile.py

* fix eval

* training-benchmark.py

* nnp_training.py

* flake8

* nnp_training_force.py

* fix dtype of species

* fix

* flake8

* requires_grad_

* git ignore

* fix

* original

* fix

* fix

* fix

* fix

* save

* save

* save

* save

* save

* save

* save

* save

* save

* collate

* fix

* save

* fix

* save

* save

* fix

* save

* fix

* fix

* no len

* float

* save

* save

* save

* save

* save

* save

* save

* save

* save

* fix

* save

* save

* save

* save

* fix

* fix

* fix

* fix mypy

* don't remove outliers

* save

* save

* save

* fix

* flake8

* save

* fix

* flake8

* docs

* more docs

* fix test_data

* remove test_data_new

* fix
parent 338f896a
@@ -39,3 +39,5 @@ jobs:
       run: python tools/comp6.py dataset/COMP6/COMP6v1/s66x8
     - name: Training Benchmark
       run: python tools/training-benchmark.py dataset/ani1-up_to_gdb4/ani_gdb_s01.h5
+    - name: Training Benchmark Nsight System
+      run: python tools/training-benchmark-nsys-profile.py --dry-run dataset/ani1-up_to_gdb4/ani_gdb_s01.h5
@@ -18,7 +18,7 @@ jobs:
       python-version: [3.6, 3.8]
       test-filenames: [
         test_aev.py, test_aev_benzene_md.py, test_aev_nist.py, test_aev_tripeptide_md.py,
-        test_data_new.py, test_utils.py, test_ase.py, test_energies.py, test_periodic_table_indexing.py,
+        test_utils.py, test_ase.py, test_energies.py, test_periodic_table_indexing.py,
         test_neurochem.py, test_vibrational.py, test_ensemble.py, test_padding.py,
         test_data.py, test_forces.py, test_structure_optim.py, test_jit_builtin_models.py]
...
@@ -32,3 +32,5 @@ dist
 /download.tar.xz
 *.qdrep
 *.qdstrm
+*.zip
+Untitled.ipynb
@@ -26,12 +26,6 @@ Datasets
 ========
 .. automodule:: torchani.data
-.. autofunction:: torchani.data.find_threshold
-.. autofunction:: torchani.data.ShuffledDataset
-.. autoclass:: torchani.data.CachedDataset
-    :members:
-.. autofunction:: torchani.data.load_ani_dataset
-.. autoclass:: torchani.data.PaddedBatchChunkDataset
...
@@ -80,37 +80,17 @@ try:
 except NameError:
     path = os.getcwd()
 dspath = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')
 batch_size = 2560
-training, validation = torchani.data.load_ani_dataset(
-    dspath, species_to_tensor, batch_size, rm_outlier=True, device=device,
-    transform=[energy_shifter.subtract_from_dataset], split=[0.8, None])
+dataset = torchani.data.load(dspath).subtract_self_energies(energy_shifter).species_to_indices().shuffle()
+size = len(dataset)
+training, validation = dataset.split(int(0.8 * size), None)
+training = training.collate(batch_size).cache()
+validation = validation.collate(batch_size).cache()
 print('Self atomic energies: ', energy_shifter.self_energies)
 ###############################################################################
-# When iterating the dataset, we will get pairs of input and output
-# ``(species_coordinates, properties)``, where ``species_coordinates`` is the
-# input and ``properties`` is the output.
-#
-# ``species_coordinates`` is a list of species-coordinate pairs, with shape
-# ``(N, Na)`` and ``(N, Na, 3)``. The reason for getting this type is, when
-# loading the dataset and generating minibatches, the whole dataset are
-# shuffled and each minibatch contains structures of molecules with a wide
-# range of number of atoms. Molecules of different number of atoms are batched
-# into single by padding. The way padding works is: adding ghost atoms, with
-# species 'X', and do computations as if they were normal atoms. But when
-# computing AEVs, atoms with species `X` would be ignored. To avoid computation
-# wasting on padding atoms, minibatches are further splitted into chunks. Each
-# chunk contains structures of molecules of similar size, which minimize the
-# total number of padding atoms required to add. The input list
-# ``species_coordinates`` contains chunks of that minibatch we are getting. The
-# batching and chunking happens automatically, so the user does not need to
-# worry how to construct chunks, but the user need to compute the energies for
-# each chunk and concat them into single tensor.
-#
-# The output, i.e. ``properties`` is a dictionary holding each property. This
-# allows us to extend TorchANI in the future to training forces and properties.
+# When iterating the dataset, we will get a dict of name->property mapping
 #
 ###############################################################################
 # Now let's define atomic neural networks.
@@ -279,16 +259,11 @@ def validate():
     mse_sum = torch.nn.MSELoss(reduction='sum')
     total_mse = 0.0
     count = 0
-    for batch_x, batch_y in validation:
-        true_energies = batch_y['energies']
-        predicted_energies = []
-        atomic_properties = []
-        for chunk_species, chunk_coordinates in batch_x:
-            atomic_chunk = {'species': chunk_species, 'coordinates': chunk_coordinates}
-            atomic_properties.append(atomic_chunk)
-        atomic_properties = torchani.utils.pad_atomic_properties(atomic_properties)
-        predicted_energies = model((atomic_properties['species'], atomic_properties['coordinates'])).energies
+    for properties in validation:
+        species = properties['species'].to(device)
+        coordinates = properties['coordinates'].to(device).float()
+        true_energies = properties['energies'].to(device).float()
+        _, predicted_energies = model((species, coordinates))
         total_mse += mse_sum(predicted_energies, true_energies).item()
         count += predicted_energies.shape[0]
     return hartree2kcalmol(math.sqrt(total_mse / count))
@@ -331,26 +306,17 @@ for _ in range(AdamW_scheduler.last_epoch + 1, max_epochs):
     tensorboard.add_scalar('best_validation_rmse', AdamW_scheduler.best, AdamW_scheduler.last_epoch)
     tensorboard.add_scalar('learning_rate', learning_rate, AdamW_scheduler.last_epoch)
-    for i, (batch_x, batch_y) in tqdm.tqdm(
+    for i, properties in tqdm.tqdm(
         enumerate(training),
         total=len(training),
         desc="epoch {}".format(AdamW_scheduler.last_epoch)
     ):
-        true_energies = batch_y['energies']
-        predicted_energies = []
-        num_atoms = []
-        atomic_properties = []
-        for chunk_species, chunk_coordinates in batch_x:
-            atomic_chunk = {'species': chunk_species, 'coordinates': chunk_coordinates}
-            atomic_properties.append(atomic_chunk)
-            num_atoms.append((chunk_species >= 0).to(true_energies.dtype).sum(dim=1))
-        atomic_properties = torchani.utils.pad_atomic_properties(atomic_properties)
-        predicted_energies = model((atomic_properties['species'], atomic_properties['coordinates'])).energies
-        num_atoms = torch.cat(num_atoms)
+        species = properties['species'].to(device)
+        coordinates = properties['coordinates'].to(device).float()
+        true_energies = properties['energies'].to(device).float()
+        num_atoms = (species >= 0).sum(dim=1, dtype=true_energies.dtype)
+        _, predicted_energies = model((species, coordinates))
         loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()
         AdamW.zero_grad()
...
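For readers following this commit, here is a minimal, self-contained sketch of how the new `torchani.data` pipeline used above is consumed end to end. The dataset path and batch size are placeholders, not values fixed by this commit:

import torchani

dspath = 'dataset/ani1-up_to_gdb4/ani_gdb_s01.h5'   # placeholder path
batch_size = 2560

energy_shifter = torchani.utils.EnergyShifter(None)
dataset = torchani.data.load(dspath) \
    .subtract_self_energies(energy_shifter) \
    .species_to_indices() \
    .shuffle()
training, validation = dataset.split(int(0.8 * len(dataset)), None)
training = training.collate(batch_size).cache()
validation = validation.collate(batch_size).cache()

# each batch is now a plain dict of padded tensors; padded atoms have species -1
for properties in training:
    species = properties['species']                  # (molecules, atoms), int64
    coordinates = properties['coordinates'].float()  # (molecules, atoms, 3)
    energies = properties['energies'].float()        # (molecules,)
    break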
@@ -49,34 +49,14 @@ dspath = os.path.join(path, '../dataset/ani-1x/sample.h5')
 batch_size = 2560
-###############################################################################
-# The code to create the dataset is a bit different: we need to manually
-# specify that ``atomic_properties=['forces']`` so that forces will be read
-# from hdf5 files.
-training, validation = torchani.data.load_ani_dataset(
-    dspath, species_to_tensor, batch_size, rm_outlier=True,
-    device=device, atomic_properties=['forces'],
-    transform=[energy_shifter.subtract_from_dataset], split=[0.8, None])
+dataset = torchani.data.load(dspath).subtract_self_energies(energy_shifter).species_to_indices().shuffle()
+size = len(dataset)
+training, validation = dataset.split(int(0.8 * size), None)
+training = training.collate(batch_size).cache()
+validation = validation.collate(batch_size).cache()
 print('Self atomic energies: ', energy_shifter.self_energies)
-###############################################################################
-# When iterating the dataset, we will get pairs of input and output
-# ``(species_coordinates, properties)``, in this case, ``properties`` would
-# contain a key ``'atomic'`` where ``properties['atomic']`` is a list of dict
-# containing forces:
-data = training[0]
-properties = data[1]
-atomic_properties = properties['atomic']
-print(type(atomic_properties))
-print(list(atomic_properties[0].keys()))
-###############################################################################
-# Due to padding, part of the forces might be 0
-print(atomic_properties[0]['forces'][0])
 ###############################################################################
 # The code to define networks, optimizers, are mostly the same
@@ -225,13 +205,11 @@ def validate():
     mse_sum = torch.nn.MSELoss(reduction='sum')
     total_mse = 0.0
     count = 0
-    for batch_x, batch_y in validation:
-        true_energies = batch_y['energies']
-        predicted_energies = []
-        for chunk_species, chunk_coordinates in batch_x:
-            chunk_energies = model((chunk_species, chunk_coordinates)).energies
-            predicted_energies.append(chunk_energies)
-        predicted_energies = torch.cat(predicted_energies)
+    for properties in validation:
+        species = properties['species'].to(device)
+        coordinates = properties['coordinates'].to(device).float()
+        true_energies = properties['energies'].to(device).float()
+        _, predicted_energies = model((species, coordinates))
         total_mse += mse_sum(predicted_energies, true_energies).item()
         count += predicted_energies.shape[0]
     return hartree2kcalmol(math.sqrt(total_mse / count))
@@ -275,49 +253,28 @@ for _ in range(AdamW_scheduler.last_epoch + 1, max_epochs):
     # Besides being stored in x, species and coordinates are also stored in y.
     # So here, for simplicity, we just ignore the x and use y for everything.
-    for i, (_, batch_y) in tqdm.tqdm(
+    for i, properties in tqdm.tqdm(
         enumerate(training),
         total=len(training),
         desc="epoch {}".format(AdamW_scheduler.last_epoch)
     ):
-        true_energies = batch_y['energies']
-        predicted_energies = []
-        num_atoms = []
-        force_loss = []
-        for chunk in batch_y['atomic']:
-            chunk_species = chunk['species']
-            chunk_coordinates = chunk['coordinates']
-            chunk_true_forces = chunk['forces']
-            chunk_num_atoms = (chunk_species >= 0).to(true_energies.dtype).sum(dim=1)
-            num_atoms.append(chunk_num_atoms)
-            # We must set `chunk_coordinates` to make it requires grad, so
-            # that we could compute force from it
-            chunk_coordinates.requires_grad_(True)
-            chunk_energies = model((chunk_species, chunk_coordinates)).energies
-            # We can use torch.autograd.grad to compute force. Remember to
-            # create graph so that the loss of the force can contribute to
-            # the gradient of parameters, and also to retain graph so that
-            # we can backward through it a second time when computing gradient
-            # w.r.t. parameters.
-            chunk_forces = -torch.autograd.grad(chunk_energies.sum(), chunk_coordinates, create_graph=True, retain_graph=True)[0]
-            # Now let's compute loss for force of this chunk
-            chunk_force_loss = mse(chunk_true_forces, chunk_forces).sum(dim=(1, 2)) / chunk_num_atoms
-            predicted_energies.append(chunk_energies)
-            force_loss.append(chunk_force_loss)
-        num_atoms = torch.cat(num_atoms)
-        predicted_energies = torch.cat(predicted_energies)
+        species = properties['species'].to(device)
+        coordinates = properties['coordinates'].to(device).float().requires_grad_(True)
+        true_energies = properties['energies'].to(device).float()
+        true_forces = properties['forces'].to(device).float()
+        num_atoms = (species >= 0).sum(dim=1, dtype=true_energies.dtype)
+        _, predicted_energies = model((species, coordinates))
+        # We can use torch.autograd.grad to compute force. Remember to
+        # create graph so that the loss of the force can contribute to
+        # the gradient of parameters, and also to retain graph so that
+        # we can backward through it a second time when computing gradient
+        # w.r.t. parameters.
+        forces = -torch.autograd.grad(predicted_energies.sum(), coordinates, create_graph=True, retain_graph=True)[0]
         # Now the total loss has two parts, energy loss and force loss
         energy_loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()
-        force_loss = torch.cat(force_loss).mean()
+        force_loss = (mse(true_forces, forces).sum(dim=(1, 2)) / num_atoms).mean()
        loss = energy_loss + force_coefficient * force_loss
        AdamW.zero_grad()
...
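As a standalone illustration of the force-from-energy pattern used in the new training loop, the following sketch computes forces for one conformation with the built-in ANI1x model. The geometry and the explicit species indices (assuming the 'H', 'C', 'N', 'O' ordering of ANI-1x) are illustrative and not taken from this commit:

import torch
import torchani

model = torchani.models.ANI1x()
# illustrative methane-like geometry in Angstrom
coordinates = torch.tensor([[[0.03, 0.00, 0.07],
                             [-0.83, 0.61, -0.34],
                             [0.45, -0.44, -0.82],
                             [0.66, 0.76, 0.44],
                             [-0.31, -0.93, 0.65]]], requires_grad=True)
# one carbon followed by four hydrogens, as indices into the model's species order
species = torch.tensor([[1, 0, 0, 0, 0]])

_, energies = model((species, coordinates))
# create_graph=True keeps the graph so that a loss defined on the forces can
# still backpropagate into the network parameters during training
forces = -torch.autograd.grad(energies.sum(), coordinates, create_graph=True)[0]
print(forces.shape)   # torch.Size([1, 5, 3])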
@@ -118,7 +118,8 @@ class TestAEV(_TestAEVBase):
             species = self.transform(species)
             radial = self.transform(radial)
             angular = self.transform(angular)
-            species_coordinates.append({'species': species, 'coordinates': coordinates})
+            species_coordinates.append(torchani.utils.broadcast_first_dim(
+                {'species': species, 'coordinates': coordinates}))
             radial_angular.append((radial, angular))
         species_coordinates = torchani.utils.pad_atomic_properties(
             species_coordinates)
...
 import os
-import torch
 import torchani
 import unittest
 path = os.path.dirname(os.path.realpath(__file__))
-dataset_path = os.path.join(path, '../dataset/ani1-up_to_gdb4')
-dataset_path2 = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')
+dataset_path = os.path.join(path, 'dataset/ani-1x/sample.h5')
 batch_size = 256
 ani1x = torchani.models.ANI1x()
 consts = ani1x.consts
+sae_dict = ani1x.sae_dict
 aev_computer = ani1x.aev_computer
 class TestData(unittest.TestCase):
def setUp(self):
self.ds = torchani.data.load_ani_dataset(dataset_path,
consts.species_to_tensor,
batch_size)
def _assertTensorEqual(self, t1, t2):
self.assertLess((t1 - t2).abs().max().item(), 1e-6)
def testSplitBatch(self):
species1 = torch.randint(4, (5, 4), dtype=torch.long)
coordinates1 = torch.randn(5, 4, 3)
species2 = torch.randint(4, (2, 8), dtype=torch.long)
coordinates2 = torch.randn(2, 8, 3)
species3 = torch.randint(4, (10, 20), dtype=torch.long)
coordinates3 = torch.randn(10, 20, 3)
species_coordinates = torchani.utils.pad_atomic_properties([
{'species': species1, 'coordinates': coordinates1},
{'species': species2, 'coordinates': coordinates2},
{'species': species3, 'coordinates': coordinates3},
])
species = species_coordinates['species']
coordinates = species_coordinates['coordinates']
natoms = (species >= 0).to(torch.long).sum(1)
chunks = torchani.data.split_batch(natoms, species_coordinates)
start = 0
last = None
for chunk in chunks:
s = chunk['species']
c = chunk['coordinates']
n = (s >= 0).to(torch.long).sum(1)
if last is not None:
self.assertNotEqual(last[-1], n[0])
conformations = s.shape[0]
self.assertGreater(conformations, 0)
s_ = species[start:(start + conformations), ...]
c_ = coordinates[start:(start + conformations), ...]
sc = torchani.utils.strip_redundant_padding({'species': s_, 'coordinates': c_})
s_ = sc['species']
c_ = sc['coordinates']
self._assertTensorEqual(s, s_)
self._assertTensorEqual(c, c_)
start += conformations
sc = torchani.utils.pad_atomic_properties(chunks)
s = sc['species']
c = sc['coordinates']
self._assertTensorEqual(s, species)
self._assertTensorEqual(c, coordinates)
     def testTensorShape(self):
-        for i in self.ds:
-            input_, output = i
-            input_ = [{'species': x[0], 'coordinates': x[1]} for x in input_]
-            species_coordinates = torchani.utils.pad_atomic_properties(input_)
-            species = species_coordinates['species']
-            coordinates = species_coordinates['coordinates']
-            energies = output['energies']
+        ds = torchani.data.load(dataset_path).subtract_self_energies(sae_dict).species_to_indices().shuffle().collate(batch_size).cache()
+        for d in ds:
+            species = d['species']
+            coordinates = d['coordinates']
+            energies = d['energies']
             self.assertEqual(len(species.shape), 2)
             self.assertLessEqual(species.shape[0], batch_size)
             self.assertEqual(len(coordinates.shape), 3)
@@ -80,11 +28,11 @@ class TestData(unittest.TestCase):
             self.assertEqual(coordinates.shape[0], energies.shape[0])
     def testNoUnnecessaryPadding(self):
-        for i in self.ds:
-            for input_ in i[0]:
-                species, _ = input_
+        ds = torchani.data.load(dataset_path).subtract_self_energies(sae_dict).species_to_indices().shuffle().collate(batch_size).cache()
+        for d in ds:
+            species = d['species']
             non_padding = (species >= 0)[:, -1].nonzero()
             self.assertGreater(non_padding.numel(), 0)
 if __name__ == '__main__':
...
import torchani
import unittest
import pkbar
import torch
import os
path = os.path.dirname(os.path.realpath(__file__))
dspath = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s03.h5')
batch_size = 2560
chunk_threshold = 5
other_properties = {'properties': ['dipoles', 'forces', 'energies'],
'padding_values': [None, 0, None],
'padded_shapes': [(batch_size, 3), (batch_size, -1, 3), (batch_size, )],
'dtypes': [torch.float32, torch.float32, torch.float64],
}
other_properties = {'properties': ['energies'],
'padding_values': [None],
'padded_shapes': [(batch_size, )],
'dtypes': [torch.float64],
}
class TestFindThreshold(unittest.TestCase):
def setUp(self):
print('.. check find threshold to split chunks')
def testFindThreshould(self):
torchani.data.find_threshold(dspath, batch_size=batch_size, threshold_max=10)
class TestShuffledData(unittest.TestCase):
def setUp(self):
print('.. setup shuffle dataset')
self.ds = torchani.data.ShuffledDataset(dspath, batch_size=batch_size,
chunk_threshold=chunk_threshold,
num_workers=2,
other_properties=other_properties,
subtract_self_energies=True)
self.chunks, self.properties = iter(self.ds).next()
def testTensorShape(self):
print('=> checking tensor shape')
print('the first batch is ([chunk1, chunk2, ...], {"energies", "force", ...}) in which chunk1=(species, coordinates)')
batch_len = 0
print('1. chunks')
for i, chunk in enumerate(self.chunks):
print('chunk{}'.format(i + 1), 'species:', list(chunk[0].size()), chunk[0].dtype,
'coordinates:', list(chunk[1].size()), chunk[1].dtype)
# check dtype
self.assertEqual(chunk[0].dtype, torch.int64)
self.assertEqual(chunk[1].dtype, torch.float32)
# check shape
self.assertEqual(chunk[1].shape[2], 3)
self.assertEqual(chunk[1].shape[:2], chunk[0].shape[:2])
batch_len += chunk[0].shape[0]
print('2. properties')
for i, key in enumerate(other_properties['properties']):
print(key, list(self.properties[key].size()), self.properties[key].dtype)
# check dtype
self.assertEqual(self.properties[key].dtype, other_properties['dtypes'][i])
# shape[0] == batch_size
self.assertEqual(self.properties[key].shape[0], other_properties['padded_shapes'][i][0])
# check len(shape)
self.assertEqual(len(self.properties[key].shape), len(other_properties['padded_shapes'][i]))
def testLoadDataset(self):
print('=> test loading all dataset')
pbar = pkbar.Pbar('loading and processing dataset into cpu memory, total '
+ 'batches: {}, batch_size: {}'.format(len(self.ds), batch_size),
len(self.ds))
for i, _ in enumerate(self.ds):
pbar.update(i)
def testSplitDataset(self):
print('=> test splitting dataset')
train_ds, val_ds = torchani.data.ShuffledDataset(dspath, batch_size=batch_size, chunk_threshold=chunk_threshold, num_workers=2, validation_split=0.1)
frac = len(val_ds) / (len(val_ds) + len(train_ds))
self.assertLess(abs(frac - 0.1), 0.05)
def testNoUnnecessaryPadding(self):
print('=> checking No Unnecessary Padding')
for i, chunk in enumerate(self.chunks):
species, _ = chunk
non_padding = (species >= 0)[:, -1].nonzero()
self.assertGreater(non_padding.numel(), 0)
class TestCachedData(unittest.TestCase):
def setUp(self):
print('.. setup cached dataset')
self.ds = torchani.data.CachedDataset(dspath, batch_size=batch_size, device='cpu',
chunk_threshold=chunk_threshold,
other_properties=other_properties,
subtract_self_energies=True)
self.chunks, self.properties = self.ds[0]
def testTensorShape(self):
print('=> checking tensor shape')
print('the first batch is ([chunk1, chunk2, ...], {"energies", "force", ...}) in which chunk1=(species, coordinates)')
batch_len = 0
print('1. chunks')
for i, chunk in enumerate(self.chunks):
print('chunk{}'.format(i + 1), 'species:', list(chunk[0].size()), chunk[0].dtype,
'coordinates:', list(chunk[1].size()), chunk[1].dtype)
# check dtype
self.assertEqual(chunk[0].dtype, torch.int64)
self.assertEqual(chunk[1].dtype, torch.float32)
# check shape
self.assertEqual(chunk[1].shape[2], 3)
self.assertEqual(chunk[1].shape[:2], chunk[0].shape[:2])
batch_len += chunk[0].shape[0]
print('2. properties')
for i, key in enumerate(other_properties['properties']):
print(key, list(self.properties[key].size()), self.properties[key].dtype)
# check dtype
self.assertEqual(self.properties[key].dtype, other_properties['dtypes'][i])
# shape[0] == batch_size
self.assertEqual(self.properties[key].shape[0], other_properties['padded_shapes'][i][0])
# check len(shape)
self.assertEqual(len(self.properties[key].shape), len(other_properties['padded_shapes'][i]))
def testLoadDataset(self):
print('=> test loading all dataset')
self.ds.load()
def testSplitDataset(self):
print('=> test splitting dataset')
train_dataset, val_dataset = self.ds.split(0.1)
frac = len(val_dataset) / len(self.ds)
self.assertLess(abs(frac - 0.1), 0.05)
def testNoUnnecessaryPadding(self):
print('=> checking No Unnecessary Padding')
for i, chunk in enumerate(self.chunks):
species, _ = chunk
non_padding = (species >= 0)[:, -1].nonzero()
self.assertGreater(non_padding.numel(), 0)
if __name__ == "__main__":
unittest.main()
@@ -54,7 +54,8 @@ class TestEnergies(unittest.TestCase):
             coordinates = self.transform(coordinates)
             species = self.transform(species)
             e = self.transform(e)
-            species_coordinates.append({'species': species, 'coordinates': coordinates})
+            species_coordinates.append(
+                torchani.utils.broadcast_first_dim({'species': species, 'coordinates': coordinates}))
             energies.append(e)
         species_coordinates = torchani.utils.pad_atomic_properties(
             species_coordinates)
...
@@ -55,7 +55,8 @@ class TestForce(unittest.TestCase):
             species = self.transform(species)
             forces = self.transform(forces)
             coordinates.requires_grad_(True)
-            species_coordinates.append({'species': species, 'coordinates': coordinates})
+            species_coordinates.append(torchani.utils.broadcast_first_dim(
+                {'species': species, 'coordinates': coordinates}))
         species_coordinates = torchani.utils.pad_atomic_properties(
             species_coordinates)
         _, energies = self.model((species_coordinates['species'], species_coordinates['coordinates']))
...
@@ -18,16 +18,14 @@ other_properties = {'properties': ['energies'],
 class TestBuiltinModelsJIT(unittest.TestCase):
     def setUp(self):
-        self.ds = torchani.data.CachedDataset(dspath, batch_size=batch_size, device='cpu',
-                                              chunk_threshold=chunk_threshold,
-                                              other_properties=other_properties,
-                                              subtract_self_energies=True)
         self.ani1ccx = torchani.models.ANI1ccx()
+        self.ds = torchani.data.load(dspath).subtract_self_energies(self.ani1ccx.sae_dict).species_to_indices().shuffle().collate(256).cache()
     def _test_model(self, model):
-        chunk = self.ds[0][0][0]
-        _, e = model(chunk)
-        _, e2 = torch.jit.script(model)(chunk)
+        properties = next(iter(self.ds))
+        input_ = (properties['species'], properties['coordinates'].float())
+        _, e = model(input_)
+        _, e2 = torch.jit.script(model)(input_)
         self.assertTrue(torch.allclose(e, e2))
     def _test_ensemble(self, ensemble):
...
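A hedged sketch of what this test exercises: the built-in models can be compiled with TorchScript and should agree with eager execution. The species indices and random coordinates below are illustrative only:

import torch
import torchani

model = torchani.models.ANI1ccx()
scripted = torch.jit.script(model)

# one conformation with 5 atoms; indices refer to the model's own species order
species = torch.tensor([[1, 0, 0, 0, 0]])
coordinates = torch.randn(1, 5, 3, dtype=torch.float)

_, e_eager = model((species, coordinates))
_, e_scripted = scripted((species, coordinates))
assert torch.allclose(e_eager, e_scripted)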
@@ -3,6 +3,9 @@ import torch
 import torchani
+b = torchani.utils.broadcast_first_dim
 class TestPaddings(unittest.TestCase):
     def testVectorSpecies(self):
@@ -11,8 +14,8 @@ class TestPaddings(unittest.TestCase):
         species2 = torch.tensor([[3, 2, 0, 1, 0]])
         coordinates2 = torch.zeros(2, 5, 3)
         atomic_properties = torchani.utils.pad_atomic_properties([
-            {'species': species1, 'coordinates': coordinates1},
-            {'species': species2, 'coordinates': coordinates2},
+            b({'species': species1, 'coordinates': coordinates1}),
+            b({'species': species2, 'coordinates': coordinates2}),
         ])
         self.assertEqual(atomic_properties['species'].shape[0], 7)
         self.assertEqual(atomic_properties['species'].shape[1], 5)
@@ -34,8 +37,8 @@ class TestPaddings(unittest.TestCase):
         species2 = torch.tensor([[3, 2, 0, 1, 0]])
         coordinates2 = torch.zeros(2, 5, 3)
         atomic_properties = torchani.utils.pad_atomic_properties([
-            {'species': species1, 'coordinates': coordinates1},
-            {'species': species2, 'coordinates': coordinates2},
+            b({'species': species1, 'coordinates': coordinates1}),
+            b({'species': species2, 'coordinates': coordinates2}),
         ])
         self.assertEqual(atomic_properties['species'].shape[0], 7)
         self.assertEqual(atomic_properties['species'].shape[1], 5)
@@ -63,8 +66,8 @@ class TestPaddings(unittest.TestCase):
         species2 = torch.tensor([[3, 2, 0, 1, 0]])
         coordinates2 = torch.zeros(2, 5, 3)
         atomic_properties = torchani.utils.pad_atomic_properties([
-            {'species': species1, 'coordinates': coordinates1},
-            {'species': species2, 'coordinates': coordinates2},
+            b({'species': species1, 'coordinates': coordinates1}),
+            b({'species': species2, 'coordinates': coordinates2}),
         ])
         self.assertEqual(atomic_properties['species'].shape[0], 7)
         self.assertEqual(atomic_properties['species'].shape[1], 5)
@@ -98,28 +101,28 @@ class TestStripRedundantPadding(unittest.TestCase):
         species2 = torch.randint(4, (2, 5), dtype=torch.long)
         coordinates2 = torch.randn(2, 5, 3)
         atomic_properties12 = torchani.utils.pad_atomic_properties([
-            {'species': species1, 'coordinates': coordinates1},
-            {'species': species2, 'coordinates': coordinates2},
+            b({'species': species1, 'coordinates': coordinates1}),
+            b({'species': species2, 'coordinates': coordinates2}),
         ])
         species12 = atomic_properties12['species']
         coordinates12 = atomic_properties12['coordinates']
         species3 = torch.randint(4, (2, 10), dtype=torch.long)
         coordinates3 = torch.randn(2, 10, 3)
         atomic_properties123 = torchani.utils.pad_atomic_properties([
-            {'species': species1, 'coordinates': coordinates1},
-            {'species': species2, 'coordinates': coordinates2},
-            {'species': species3, 'coordinates': coordinates3},
+            b({'species': species1, 'coordinates': coordinates1}),
+            b({'species': species2, 'coordinates': coordinates2}),
+            b({'species': species3, 'coordinates': coordinates3}),
         ])
         species123 = atomic_properties123['species']
         coordinates123 = atomic_properties123['coordinates']
         species_coordinates1_ = torchani.utils.strip_redundant_padding(
-            {'species': species123[:5, ...], 'coordinates': coordinates123[:5, ...]})
+            b({'species': species123[:5, ...], 'coordinates': coordinates123[:5, ...]}))
         species1_ = species_coordinates1_['species']
         coordinates1_ = species_coordinates1_['coordinates']
         self._assertTensorEqual(species1_, species1)
         self._assertTensorEqual(coordinates1_, coordinates1)
         species_coordinates12_ = torchani.utils.strip_redundant_padding(
-            {'species': species123[:7, ...], 'coordinates': coordinates123[:7, ...]})
+            b({'species': species123[:7, ...], 'coordinates': coordinates123[:7, ...]}))
         species12_ = species_coordinates12_['species']
         coordinates12_ = species_coordinates12_['coordinates']
         self._assertTensorEqual(species12_, species12)
...
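To make the `broadcast_first_dim` plus `pad_atomic_properties` combination used above concrete, here is a small sketch (species indices and sizes are illustrative) of what the padded batch looks like; species are padded with -1 and coordinates with zeros:

import torch
import torchani

b = torchani.utils.broadcast_first_dim

small = b({'species': torch.tensor([[0, 1]]), 'coordinates': torch.zeros(1, 2, 3)})
large = b({'species': torch.tensor([[2, 3, 0, 1]]), 'coordinates': torch.zeros(1, 4, 3)})
padded = torchani.utils.pad_atomic_properties([small, large])

print(padded['species'])
# tensor([[ 0,  1, -1, -1],
#         [ 2,  3,  0,  1]])
print(padded['coordinates'].shape)   # torch.Size([2, 4, 3])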
@@ -55,24 +55,11 @@ if __name__ == "__main__":
     parser.add_argument('-b', '--batch_size',
                         help='Number of conformations of each batch',
                         default=2560, type=int)
-    parser.add_argument('-o', '--original_dataset_api',
-                        help='use original dataset api',
-                        dest='dataset',
-                        action='store_const',
-                        const='original')
-    parser.add_argument('-s', '--shuffle_dataset_api',
-                        help='use shuffle dataset api',
-                        dest='dataset',
-                        action='store_const',
-                        const='shuffle')
-    parser.add_argument('-c', '--cache_dataset_api',
-                        help='use cache dataset api',
-                        dest='dataset',
-                        action='store_const',
-                        const='cache')
-    parser.set_defaults(dataset='shuffle')
+    parser.add_argument('-d', '--dry-run',
+                        help='just run it in a CI without GPU',
+                        action='store_true')
     parser = parser.parse_args()
-    parser.device = torch.device('cuda')
+    parser.device = torch.device('cpu' if parser.dry_run else 'cuda')
     Rcr = 5.2000e+00
     Rca = 3.5000e+00
@@ -90,45 +77,9 @@ if __name__ == "__main__":
     optimizer = torch.optim.Adam(model.parameters(), lr=0.000001)
     mse = torch.nn.MSELoss(reduction='none')
-    if parser.dataset == 'shuffle':
-        print('using shuffle dataset API')
-        print('=> loading dataset...')
+    print('=> loading dataset...')
+    shifter = torchani.EnergyShifter(None)
+    dataset = list(torchani.data.load(parser.dataset_path).subtract_self_energies(shifter).species_to_indices().shuffle().collate(parser.batch_size))
dataset = torchani.data.ShuffledDataset(file_path=parser.dataset_path,
species_order=['H', 'C', 'N', 'O'],
subtract_self_energies=True,
batch_size=parser.batch_size,
num_workers=2)
print('=> the first batch is ([chunk1, chunk2, ...], {"energies", "force", ...}) in which chunk1=(species, coordinates)')
chunks, properties = iter(dataset).next()
elif parser.dataset == 'original':
print('using original dataset API')
print('=> loading dataset...')
energy_shifter = torchani.utils.EnergyShifter(None)
species_to_tensor = torchani.utils.ChemicalSymbolsToInts(['H', 'C', 'N', 'O'])
dataset = torchani.data.load_ani_dataset(parser.dataset_path, species_to_tensor,
parser.batch_size, device=parser.device,
transform=[energy_shifter.subtract_from_dataset])
print('=> the first batch is ([chunk1, chunk2, ...], {"energies", "force", ...}) in which chunk1=(species, coordinates)')
chunks, properties = dataset[0]
elif parser.dataset == 'cache':
print('using cache dataset API')
print('=> loading dataset...')
dataset = torchani.data.CachedDataset(file_path=parser.dataset_path,
species_order=['H', 'C', 'N', 'O'],
subtract_self_energies=True,
batch_size=parser.batch_size)
print('=> caching all dataset into cpu')
pbar = pkbar.Pbar('loading and processing dataset into cpu memory, total '
+ 'batches: {}, batch_size: {}'.format(len(dataset), parser.batch_size),
len(dataset))
for i, t in enumerate(dataset):
pbar.update(i)
print('=> the first batch is ([chunk1, chunk2, ...], {"energies", "force", ...}) in which chunk1=(species, coordinates)')
chunks, properties = dataset[0]
for i, chunk in enumerate(chunks):
print('chunk{}'.format(i + 1), list(chunk[0].size()), list(chunk[1].size()))
print('energies', list(properties['energies'].size()))
     print('=> start warming up')
     total_batch_counter = 0
@@ -137,36 +88,24 @@ if __name__ == "__main__":
         print('Epoch: %d/inf' % (epoch + 1,))
         progbar = pkbar.Kbar(target=len(dataset) - 1, width=8)
-        for i, (batch_x, batch_y) in enumerate(dataset):
-            if total_batch_counter == WARM_UP_BATCHES:
+        for i, properties in enumerate(dataset):
+            if not parser.dry_run and total_batch_counter == WARM_UP_BATCHES:
                 print('=> warm up finished, start profiling')
                 enable_timers(model)
                 torch.cuda.cudart().cudaProfilerStart()
-            PROFILING_STARTED = (total_batch_counter >= WARM_UP_BATCHES)
+            PROFILING_STARTED = not parser.dry_run and (total_batch_counter >= WARM_UP_BATCHES)
             if PROFILING_STARTED:
                 torch.cuda.nvtx.range_push("batch{}".format(total_batch_counter))
-            true_energies = batch_y['energies'].to(parser.device)
-            predicted_energies = []
-            num_atoms = []
-            for j, (chunk_species, chunk_coordinates) in enumerate(batch_x):
-                if PROFILING_STARTED:
+            species = properties['species'].to(parser.device)
+            coordinates = properties['coordinates'].to(parser.device).float()
+            true_energies = properties['energies'].to(parser.device).float()
+            num_atoms = (species >= 0).sum(dim=1, dtype=true_energies.dtype)
+            with torch.autograd.profiler.emit_nvtx(enabled=PROFILING_STARTED, record_shapes=True):
+                _, predicted_energies = model((species, coordinates))
torch.cuda.nvtx.range_push("chunk{}".format(j))
chunk_species = chunk_species.to(parser.device)
chunk_coordinates = chunk_coordinates.to(parser.device)
num_atoms.append((chunk_species >= 0).to(true_energies.dtype).sum(dim=1))
with torch.autograd.profiler.emit_nvtx(enabled=PROFILING_STARTED, record_shapes=True):
_, chunk_energies = model((chunk_species, chunk_coordinates))
predicted_energies.append(chunk_energies)
if PROFILING_STARTED:
torch.cuda.nvtx.range_pop()
num_atoms = torch.cat(num_atoms)
predicted_energies = torch.cat(predicted_energies).to(true_energies.dtype)
             loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()
             rmse = hartree2kcalmol((mse(predicted_energies, true_energies)).mean()).detach().cpu().numpy()
...
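The profiling pattern used in this benchmark, reduced to a hedged standalone sketch; `model` and `batch` are placeholders for whatever is being benchmarked, and the NVTX ranges only become visible when the script runs under Nsight Systems:

import torch

def profiled_step(model, batch, profiling=False):
    # wrap one step in an NVTX range so it shows up in the Nsight Systems timeline
    if profiling:
        torch.cuda.nvtx.range_push('step')
    # emit_nvtx additionally annotates every autograd op, including input shapes
    with torch.autograd.profiler.emit_nvtx(enabled=profiling, record_shapes=True):
        out = model(batch)
    if profiling:
        torch.cuda.nvtx.range_pop()
    return out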
@@ -49,25 +49,9 @@ if __name__ == "__main__":
     parser.add_argument('-b', '--batch_size',
                         help='Number of conformations of each batch',
                         default=2560, type=int)
parser.add_argument('-o', '--original_dataset_api',
help='use original dataset api',
dest='dataset',
action='store_const',
const='original')
parser.add_argument('-s', '--shuffle_dataset_api',
help='use shuffle dataset api',
dest='dataset',
action='store_const',
const='shuffle')
parser.add_argument('-c', '--cache_dataset_api',
help='use cache dataset api',
dest='dataset',
action='store_const',
const='cache')
     parser.add_argument('-y', '--synchronize',
                         action='store_true',
                         help='whether to insert torch.cuda.synchronize() at the end of each function')
-    parser.set_defaults(dataset='shuffle')
     parser.add_argument('-n', '--num_epochs',
                         help='epochs',
                         default=1, type=int)
@@ -107,48 +91,9 @@ if __name__ == "__main__":
     model[0].forward = time_func('total', model[0].forward)
     model[1].forward = time_func('forward', model[1].forward)
-    if parser.dataset == 'shuffle':
-        torchani.data.ShuffledDataset = time_func('data_loading', torchani.data.ShuffledDataset)
-        print('using shuffle dataset API')
+    print('=> loading dataset...')
+    shifter = torchani.EnergyShifter(None)
+    dataset = list(torchani.data.load(parser.dataset_path).subtract_self_energies(shifter).species_to_indices().shuffle().collate(parser.batch_size))
print('=> loading dataset...')
dataset = torchani.data.ShuffledDataset(file_path=parser.dataset_path,
species_order=['H', 'C', 'N', 'O'],
subtract_self_energies=True,
batch_size=parser.batch_size,
num_workers=2)
print('=> the first batch is ([chunk1, chunk2, ...], {"energies", "force", ...}) in which chunk1=(species, coordinates)')
chunks, properties = iter(dataset).next()
elif parser.dataset == 'original':
torchani.data.load_ani_dataset = time_func('data_loading', torchani.data.load_ani_dataset)
print('using original dataset API')
print('=> loading dataset...')
energy_shifter = torchani.utils.EnergyShifter(None)
species_to_tensor = torchani.utils.ChemicalSymbolsToInts(['H', 'C', 'N', 'O'])
dataset = torchani.data.load_ani_dataset(parser.dataset_path, species_to_tensor,
parser.batch_size, device=parser.device,
transform=[energy_shifter.subtract_from_dataset])
print('=> the first batch is ([chunk1, chunk2, ...], {"energies", "force", ...}) in which chunk1=(species, coordinates)')
chunks, properties = dataset[0]
elif parser.dataset == 'cache':
torchani.data.CachedDataset = time_func('data_loading', torchani.data.CachedDataset)
print('using cache dataset API')
print('=> loading dataset...')
dataset = torchani.data.CachedDataset(file_path=parser.dataset_path,
species_order=['H', 'C', 'N', 'O'],
subtract_self_energies=True,
batch_size=parser.batch_size)
print('=> caching all dataset into cpu')
pbar = pkbar.Pbar('loading and processing dataset into cpu memory, total '
+ 'batches: {}, batch_size: {}'.format(len(dataset), parser.batch_size),
len(dataset))
for i, t in enumerate(dataset):
pbar.update(i)
print('=> the first batch is ([chunk1, chunk2, ...], {"energies", "force", ...}) in which chunk1=(species, coordinates)')
chunks, properties = dataset[0]
for i, chunk in enumerate(chunks):
print('chunk{}'.format(i + 1), list(chunk[0].size()), list(chunk[1].size()))
print('energies', list(properties['energies'].size()))
     print('=> start training')
     start = time.time()
@@ -158,24 +103,12 @@ if __name__ == "__main__":
         print('Epoch: %d/%d' % (epoch + 1, parser.num_epochs))
         progbar = pkbar.Kbar(target=len(dataset) - 1, width=8)
-        for i, (batch_x, batch_y) in enumerate(dataset):
-            true_energies = batch_y['energies'].to(parser.device)
-            predicted_energies = []
-            num_atoms = []
-            atomic_properties = []
+        for i, properties in enumerate(dataset):
+            species = properties['species'].to(parser.device)
+            coordinates = properties['coordinates'].to(parser.device).float()
+            true_energies = properties['energies'].to(parser.device).float()
+            num_atoms = (species >= 0).sum(dim=1, dtype=true_energies.dtype)
+            _, predicted_energies = model((species, coordinates))
for chunk_species, chunk_coordinates in batch_x:
chunk_species = chunk_species.to(parser.device)
chunk_coordiantes = chunk_coordinates.to(parser.device)
atomic_chunk = {'species': chunk_species, 'coordinates': chunk_coordinates}
atomic_properties.append(atomic_chunk)
num_atoms.append((chunk_species >= 0).to(true_energies.dtype).sum(dim=1))
atomic_properties = torchani.utils.pad_atomic_properties(atomic_properties)
predicted_energies = model((atomic_properties['species'], atomic_properties['coordinates'])).energies.to(true_energies.dtype)
num_atoms = torch.cat(num_atoms)
             loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()
             rmse = hartree2kcalmol((mse(predicted_energies, true_energies)).mean()).detach().cpu().numpy()
             loss.backward()
@@ -191,6 +124,5 @@ if __name__ == "__main__":
         if k.startswith('torchani.'):
             print('{} - {:.1f}s'.format(k, timers[k]))
     print('Total AEV - {:.1f}s'.format(timers['total']))
-    print('Data Loading - {:.1f}s'.format(timers['data_loading']))
     print('NN - {:.1f}s'.format(timers['forward']))
     print('Epoch time - {:.1f}s'.format(stop - start))
@@ -408,8 +408,8 @@ class AEVComputer(torch.nn.Module):
             If you don't care about periodic boundary conditions at all,
             then input can be a tuple of two tensors: species, coordinates.
-            species must have shape ``(C, A)``, coordinates must have shape
-            ``(C, A, 3)`` where ``C`` is the number of molecules in a chunk,
+            species must have shape ``(N, A)``, coordinates must have shape
+            ``(N, A, 3)`` where ``N`` is the number of molecules in a batch,
             and ``A`` is the number of atoms.
         .. warning::
@@ -437,7 +437,7 @@ class AEVComputer(torch.nn.Module):
         Returns:
             NamedTuple: Species and AEVs. species are the species from the input
-            unchanged, and AEVs is a tensor of shape ``(C, A, self.aev_length())``
+            unchanged, and AEVs is a tensor of shape ``(N, A, self.aev_length())``
         """
         species, coordinates = input_
...
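A short sketch of the shapes described by the corrected docstring, using the AEV computer taken from the built-in ANI-1x model; the species indices and random coordinates are illustrative, and the AEV length of 384 is specific to the ANI-1x parameters:

import torch
import torchani

aev_computer = torchani.models.ANI1x().aev_computer

# N = 2 padded molecules, A = 4 atom slots; -1 marks padding atoms
species = torch.tensor([[1, 0, 0, 0],
                        [3, 0, -1, -1]])
coordinates = torch.randn(2, 4, 3)

_, aevs = aev_computer((species, coordinates))
print(aevs.shape)   # torch.Size([2, 4, 384]) for the ANI-1x AEV definition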
 # -*- coding: utf-8 -*-
-"""Tools for loading, shuffling, and batching ANI datasets"""
+"""Tools for loading, shuffling, and batching ANI datasets
`torchani.data.load(path)` creates an iterable of raw data,
where species are strings and coordinates are numpy ndarrays.
You can transform this iterable by applying transformations:
just call `it.transformation_name()`.
Available transformations are listed below:
- `species_to_indices` converts species from strings to numbers.
- `subtract_self_energies` subtracts self energies. You can pass
  a dict of self energies, or an `EnergyShifter` to let it infer
  self energies from the dataset and store the result in the given shifter.
- `remove_outliers` removes high-energy outlier conformations (it must run
  after `subtract_self_energies`).
- `shuffle` shuffles the dataset.
- `cache` caches the result of the previous transformations.
- `collate` pads the dataset, converts it to tensors, and stacks them
  together to get a batch.
- `pin_memory` copies the tensors to pinned memory so that a later transfer
  to CUDA can be faster.
You can also use `split` to split the iterable into pieces. Use `split` as:
.. code-block:: python
    it.split(size1, size2, None)
where the None at the end indicates that we want to use all of the rest.
Example:
.. code-block:: python
energy_shifter = torchani.utils.EnergyShifter(None)
dataset = torchani.data.load(path).subtract_self_energies(energy_shifter).species_to_indices().shuffle()
size = len(dataset)
training, validation = dataset.split(int(0.8 * size), None)
training = training.collate(batch_size).cache()
validation = validation.collate(batch_size).cache()
"""
-from torch.utils.data import Dataset
 from os.path import join, isfile, isdir
 import os
 from ._pyanitools import anidataloader
 import torch
 from .. import utils
-from .new import CachedDataset, ShuffledDataset, find_threshold
-default_device = 'cuda' if torch.cuda.is_available() else 'cpu'
+import importlib
+import functools
+import math
+import random
+from collections import Counter
def chunk_counts(counts, split): import numpy
split = [x + 1 for x in split] + [None]
count_chunks = [] PKBAR_INSTALLED = importlib.util.find_spec('pkbar') is not None # type: ignore
start = 0 if PKBAR_INSTALLED:
for i in split: import pkbar
count_chunks.append(counts[start:i])
start = i verbose = True
chunk_molecules = [sum([y[1] for y in x]) for x in count_chunks]
chunk_maxatoms = [x[-1][0] for x in count_chunks]
return chunk_molecules, chunk_maxatoms PROPERTIES = ('energies', 'forces')
PADDING = {
'species': -1,
def split_cost(counts, split): 'coordinates': 0.0,
split_min_cost = 40000 'forces': 0.0,
cost = 0 'energies': 0.0
chunk_molecules, chunk_maxatoms = chunk_counts(counts, split) }
for molecules, maxatoms in zip(chunk_molecules, chunk_maxatoms):
cost += max(molecules * maxatoms ** 2, split_min_cost)
return cost class Transformations:
@staticmethod
def split_batch(natoms, atomic_properties): def species_to_indices(iter_, species_order=('H', 'C', 'N', 'O', 'F', 'Cl', 'S')):
if species_order == 'periodic_table':
# count number of conformation by natoms species_order = utils.PERIODIC_TABLE
natoms = natoms.tolist() idx = {k: i for i, k in enumerate(species_order)}
counts = [] for d in iter_:
for i in natoms: d['species'] = numpy.array([idx[s] for s in d['species']])
if not counts: yield d
counts.append([i, 1])
continue @staticmethod
if i == counts[-1][0]: def subtract_self_energies(iter_, self_energies=None):
counts[-1][1] += 1 iter_ = list(iter_)
else: intercept = 0.0
counts.append([i, 1]) if isinstance(self_energies, utils.EnergyShifter):
shifter = self_energies
# find best split using greedy strategy self_energies = {}
split = [] counts = {}
cost = split_cost(counts, split) Y = []
improved = True for n, d in enumerate(iter_):
while improved: species = d['species']
improved = False count = Counter()
cycle_split = split for s in species:
cycle_cost = cost count[s] += 1
for i in range(len(counts) - 1): for s, c in count.items():
if i not in split: if s not in counts:
s = sorted(split + [i]) counts[s] = [0] * n
c = split_cost(counts, s) counts[s].append(c)
if c < cycle_cost: for s in counts:
improved = True if len(counts[s]) != n + 1:
cycle_cost = c counts[s].append(0)
cycle_split = s Y.append(d['energies'])
if improved: species = sorted(list(counts.keys()))
split = cycle_split X = [counts[s] for s in species]
cost = cycle_cost if shifter.fit_intercept:
X.append([1] * n)
# do split X = numpy.array(X).transpose()
chunk_molecules, _ = chunk_counts(counts, split) Y = numpy.array(Y)
num_chunks = None sae, _, _, _ = numpy.linalg.lstsq(X, Y, rcond=None)
for k in atomic_properties: sae_ = sae
atomic_properties[k] = atomic_properties[k].split(chunk_molecules) if shifter.fit_intercept:
if num_chunks is None: intercept = sae[-1]
num_chunks = len(atomic_properties[k]) sae_ = sae[:-1]
else: for s, e in zip(species, sae_):
assert num_chunks == len(atomic_properties[k]) self_energies[s] = e
chunks = [] shifter.__init__(sae, shifter.fit_intercept)
for i in range(num_chunks): for d in iter_:
chunk = {k: atomic_properties[k][i] for k in atomic_properties} e = intercept
chunks.append(utils.strip_redundant_padding(chunk)) for s in d['species']:
return chunks e += self_energies[s]
d['energies'] -= e
yield d
def load_and_pad_whole_dataset(path, species_tensor_converter, shuffle=True,
properties=('energies',), atomic_properties=()): @staticmethod
# get name of files storing data def remove_outliers(iter_, threshold1=15.0, threshold2=8.0):
files = [] assert 'subtract_self_energies', "Transformation remove_outliers can only run after subtract_self_energies"
if isdir(path):
for f in os.listdir(path): # pass 1: remove everything that has per-atom energy > threshold1
f = join(path, f) def scaled_energy(x):
if isfile(f) and (f.endswith('.h5') or f.endswith('.hdf5')): num_atoms = len(x['species'])
files.append(f) return abs(x['energies']) / math.sqrt(num_atoms)
elif isfile(path): filtered = [x for x in iter_ if scaled_energy(x) < threshold1]
files = [path]
else: # pass 2: compute those that are outside the mean by threshold2 * std
raise ValueError('Bad path') n = 0
mean = 0
# load full dataset std = 0
atomic_properties_ = [] for m in filtered:
properties = {k: [] for k in properties} n += 1
for f in files: mean += m['energies']
for m in anidataloader(f): std += m['energies'] ** 2
atomic_properties_.append(dict( mean /= n
species=species_tensor_converter(m['species']).unsqueeze(0), std = math.sqrt(std / n - mean ** 2)
**{
k: torch.from_numpy(m[k]).to(torch.double) return filter(lambda x: abs(x['energies'] - mean) < threshold2 * std, filtered)
for k in ['coordinates'] + list(atomic_properties)
} @staticmethod
)) def shuffle(iter_):
for i in properties: list_ = list(iter_)
p = torch.from_numpy(m[i]).to(torch.double) random.shuffle(list_)
properties[i].append(p) return list_
atomic_properties = utils.pad_atomic_properties(atomic_properties_)
for i in properties: @staticmethod
properties[i] = torch.cat(properties[i]) def cache(iter_):
return list(iter_)
# shuffle if required
molecules = atomic_properties['species'].shape[0] @staticmethod
if shuffle: def collate(iter_, batch_size):
indices = torch.randperm(molecules) batch = []
for i in properties: i = 0
properties[i] = properties[i].index_select(0, indices) for d in iter_:
for i in atomic_properties: d = {k: torch.as_tensor(d[k]) for k in d}
atomic_properties[i] = atomic_properties[i].index_select(0, indices) batch.append(d)
return atomic_properties, properties i += 1
if i == batch_size:
i = 0
def split_whole_into_batches_and_chunks(atomic_properties, properties, batch_size): yield utils.stack_with_padding(batch, PADDING)
molecules = atomic_properties['species'].shape[0] batch = []
# split into minibatches if len(batch) > 0:
for k in properties: yield utils.stack_with_padding(batch, PADDING)
properties[k] = properties[k].split(batch_size)
for k in atomic_properties: @staticmethod
atomic_properties[k] = atomic_properties[k].split(batch_size) def pin_memory(iter_):
for d in iter_:
# further split batch into chunks and strip redundant padding yield {k: d[k].pin_memory() for k in d}
batches = []
num_batches = (molecules + batch_size - 1) // batch_size
for i in range(num_batches): class TransformableIterable:
batch_properties = {k: v[i] for k, v in properties.items()} def __init__(self, wrapped_iter, transformations=()):
batch_atomic_properties = {k: v[i] for k, v in atomic_properties.items()} self.wrapped_iter = wrapped_iter
species = batch_atomic_properties['species'] self.transformations = transformations
natoms = (species >= 0).to(torch.long).sum(1)
def __iter__(self):
# sort batch by number of atoms to prepare for splitting return iter(self.wrapped_iter)
natoms, indices = natoms.sort()
for k in batch_properties: def __next__(self):
batch_properties[k] = batch_properties[k].index_select(0, indices) return next(self.wrapped_iter)
for k in batch_atomic_properties:
batch_atomic_properties[k] = batch_atomic_properties[k].index_select(0, indices) def __getattr__(self, name):
transformation = getattr(Transformations, name)
batch_atomic_properties = split_batch(natoms, batch_atomic_properties)
batches.append((batch_atomic_properties, batch_properties)) @functools.wraps(transformation)
def f(*args, **kwargs):
return batches return TransformableIterable(
transformation(self, *args, **kwargs),
self.transformations + (name,))
class PaddedBatchChunkDataset(Dataset):
r""" Dataset that contains batches in 'chunks', with padded structures return f
This dataset acts as a container of batches to be used when training. Each def split(self, *nums):
of the batches is broken up into 'chunks', each of which is a tensor has iters = []
molecules with a smiliar number of atoms, but which have been padded with self_iter = iter(self)
dummy atoms in order for them to have the same tensor dimensions. for n in nums:
""" list_ = []
if n is not None:
def __init__(self, atomic_properties, properties, batch_size, for _ in range(n):
dtype=torch.get_default_dtype(), device=default_device): list_.append(next(self_iter))
super().__init__() else:
self.device = device for i in self_iter:
self.dtype = dtype list_.append(i)
iters.append(TransformableIterable(list_, self.transformations + ('split',)))
# convert to desired dtype return iters
for k in properties:
properties[k] = properties[k].to(dtype)
for k in atomic_properties:
if k == 'species':
continue
atomic_properties[k] = atomic_properties[k].to(dtype)
self.batches = split_whole_into_batches_and_chunks(atomic_properties, properties, batch_size)
def __getitem__(self, idx):
atomic_properties, properties = self.batches[idx]
atomic_properties, properties = atomic_properties.copy(), properties.copy()
species_coordinates = []
for chunk in atomic_properties:
for k in chunk:
chunk[k] = chunk[k].to(self.device)
species_coordinates.append((chunk['species'], chunk['coordinates']))
for k in properties:
properties[k] = properties[k].to(self.device)
properties['atomic'] = atomic_properties
return species_coordinates, properties
def __len__(self): def __len__(self):
return len(self.batches) return len(self.wrapped_iter)
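The replacement pipeline on the right chains its transformations lazily: ``TransformableIterable.__getattr__`` looks up a generator on ``Transformations`` and wraps its output in a new chainable iterable. A minimal, self-contained sketch of that pattern (illustrative names only, not the torchani API):

import functools

class _Transforms:
    # toy transformations standing in for shuffle/collate/cache above
    @staticmethod
    def double(iter_):
        for x in iter_:
            yield 2 * x

    @staticmethod
    def take(iter_, n):
        for i, x in enumerate(iter_):
            if i == n:
                return
            yield x

class Chainable:
    def __init__(self, wrapped_iter):
        self.wrapped_iter = wrapped_iter

    def __iter__(self):
        return iter(self.wrapped_iter)

    def __getattr__(self, name):
        transformation = getattr(_Transforms, name)

        @functools.wraps(transformation)
        def f(*args, **kwargs):
            return Chainable(transformation(self, *args, **kwargs))
        return f

print(list(Chainable(range(10)).double().take(3)))  # prints [0, 2, 4]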
def load_ani_dataset(path, species_tensor_converter, batch_size, shuffle=True, def load(path, additional_properties=()):
rm_outlier=False, properties=('energies',), atomic_properties=(), properties = PROPERTIES + additional_properties
transform=(), dtype=torch.get_default_dtype(), device=default_device,
split=(None,)): def h5_files(path):
"""Load ANI dataset from hdf5 files, and split into subsets. """yield file name of all h5 files in a path"""
if isdir(path):
The returned datasets are already datasets of batches, so when iterated, a for f in os.listdir(path):
batch rather than a single data point will be yielded. f = join(path, f)
yield from h5_files(f)
Since each batch might contain molecules of very different sizes, putting elif isfile(path) and (path.endswith('.h5') or path.endswith('.hdf5')):
the whole batch into a single tensor would require adding ghost atoms to yield path
pad everything to the size of the largest molecule. As a result, a huge
amount of computation would be wasted on ghost atoms. To avoid this issue, def molecules():
the input of each batch, i.e. species and coordinates, are further divided for f in h5_files(path):
into chunks according to some heuristics, so that each chunk would only anidata = anidataloader(f)
have molecules of similar size, to minimize the padding required. anidata_size = anidata.size()
use_pbar = PKBAR_INSTALLED and verbose
So, when iterating on this dataset, a tuple will be yielded. The first if use_pbar:
element of this tuple is a list of (species, coordinates) pairs. Each pair pbar = pkbar.Pbar('=> loading {}, total molecules: {}'.format(f, anidata_size), anidata_size)
is a chunk of molecules of similar size. The second element of this tuple for i, m in enumerate(anidata):
would be a dictionary, where the keys are those specified in the argument yield m
:attr:`properties`, and values are a single tensor of the whole batch if use_pbar:
(properties are not split into chunks). pbar.update(i)
Splitting a batch into chunks leads to some inconvenience in training, def conformations():
especially when using high level libraries like ``ignite``. To overcome for m in molecules():
this inconvenience, :class:`torchani.ignite.Container` is created for species = m['species']
working with ignite. coordinates = m['coordinates']
for i in range(coordinates.shape[0]):
Arguments: ret = {'species': species, 'coordinates': coordinates[i]}
path (str): Path to hdf5 files. If :attr:`path` is a file, then that for k in properties:
file would be loaded using `pyanitools.py`_. If :attr:`path` is if k in m:
a directory, then all files with suffix `.h5` or `.hdf5` will be ret[k] = m[k][i]
loaded. yield ret
species_tensor_converter (:class:`collections.abc.Callable`): A
callable that converts species in the format of a list of strings return TransformableIterable(conformations())
to a 1D tensor.
batch_size (int): Number of different 3D structures in a single
minibatch. __all__ = ['load']
shuffle (bool): Whether to shuffle the whole dataset.
rm_outlier (bool): Whether to discard outlier energy conformers
from the dataset.
properties (list): List of keys of `molecular` properties in the
dataset to be loaded. Here `molecular` means that, no matter the number
of atoms, the property always has a fixed size, i.e. the tensor
shape of molecular properties should be (molecule, ...). An example
of a molecular property is the molecular energy. ``'species'`` and
``'coordinates'`` are always loaded and need not be specified
anywhere.
atomic_properties (list): List of keys of `atomic` properties in the
dataset to be loaded. Here `atomic` means that the size of the property
is proportional to the number of atoms in the molecule, i.e. the
tensor shape of atomic properties should be (molecule, atoms, ...).
An example of an atomic property is the forces. ``'species'`` and
``'coordinates'`` are always loaded and need not be specified
anywhere.
transform (list): List of :class:`collections.abc.Callable` that
transform the data. Callables must take atomic properties and
properties as arguments, and return the transformed atomic
properties and properties.
dtype (:class:`torch.dtype`): dtype of coordinates and properties to
convert the dataset to.
device (:class:`torch.device`): device to put tensors on when iterating.
split (list): a sequence of integers, floats, or ``None``. Integers
are interpreted as numbers of elements, floats are interpreted as
percentages, and ``None`` is interpreted as the rest of the dataset
and can only appear as the last element of :attr:`split`. For
example, if the whole dataset has 10000 entries, and split is
``(5000, 0.1, None)``, then this function will create 3 datasets,
where the first dataset contains 5000 elements, the second dataset
contains ``int(0.1 * 10000)``, which is 1000, and the third dataset
will contain ``10000 - 5000 - 1000`` elements. By default this
creates only a single dataset.
Returns:
An instance of :class:`torchani.data.PaddedBatchChunkDataset` if there is
only one element in :attr:`split`, otherwise returns a tuple of the same
classes according to :attr:`split`.
.. _pyanitools.py:
https://github.com/isayev/ASE_ANI/blob/master/lib/pyanitools.py
"""
atomic_properties_, properties_ = load_and_pad_whole_dataset(
path, species_tensor_converter, shuffle, properties, atomic_properties)
molecules = atomic_properties_['species'].shape[0]
atomic_keys = ['species', 'coordinates', *atomic_properties]
keys = properties
# do transformations on data
for t in transform:
atomic_properties_, properties_ = t(atomic_properties_, properties_)
if rm_outlier:
transformed_energies = properties_['energies']
num_atoms = (atomic_properties_['species'] >= 0).to(transformed_energies.dtype).sum(dim=1)
scaled_diff = transformed_energies / num_atoms.sqrt()
mean = scaled_diff[torch.abs(scaled_diff) < 15.0].mean()
std = scaled_diff[torch.abs(scaled_diff) < 15.0].std()
# -8 * std + mean < scaled_diff < +8 * std + mean
tol = 8.0 * std + mean
low_idx = (torch.abs(scaled_diff) < tol).nonzero().squeeze()
outlier_count = molecules - low_idx.numel()
# discard outlier energy conformers if any exist
if outlier_count > 0:
print("Note: {} outlier energy conformers have been discarded from dataset".format(outlier_count))
for key, val in atomic_properties_.items():
atomic_properties_[key] = val[low_idx]
for key, val in properties_.items():
properties_[key] = val[low_idx]
molecules = low_idx.numel()
# compute size of each subset
split_ = []
total = 0
for index, size in enumerate(split):
if isinstance(size, float):
size = int(size * molecules)
if size is None:
assert index == len(split) - 1
size = molecules - total
split_.append(size)
total += size
# split
start = 0
splitted = []
for size in split_:
ap = {k: atomic_properties_[k][start:start + size] for k in atomic_keys}
p = {k: properties_[k][start:start + size] for k in keys}
start += size
splitted.append((ap, p))
# construct batched datasets
ret = []
for ap, p in splitted:
ds = PaddedBatchChunkDataset(ap, p, batch_size, dtype, device)
ds.properties = properties
ds.atomic_properties = atomic_properties
ret.append(ds)
if len(ret) == 1:
return ret[0]
return tuple(ret)
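For reference, a minimal sketch of how a dataset returned by ``load_ani_dataset`` was consumed; ``model`` is assumed to map ``(species, coordinates)`` to ``(species, energies)`` as elsewhere in this diff, and the helper name is illustrative:

import torch

def epoch_losses(dataset, model):
    # dataset yields (species_coordinates, properties) batches; the chunks in
    # species_coordinates group molecules of similar size, while properties
    # holds whole-batch tensors such as 'energies'
    for species_coordinates, properties in dataset:
        true_energies = properties['energies'].float()
        predicted = []
        for chunk_species, chunk_coordinates in species_coordinates:
            _, chunk_energies = model((chunk_species, chunk_coordinates))
            predicted.append(chunk_energies)
        predicted = torch.cat(predicted)
        yield torch.nn.functional.mse_loss(predicted, true_energies)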
__all__ = ['load_ani_dataset', 'PaddedBatchChunkDataset', 'CachedDataset', 'ShuffledDataset', 'find_threshold']
import numpy as np
import torch
import functools
from ._pyanitools import anidataloader
from importlib import util as u
import gc
PKBAR_INSTALLED = u.find_spec('pkbar') is not None
if PKBAR_INSTALLED:
import pkbar
def find_threshold(file_path, batch_size, threshold_max=100):
"""Find a reasonable threshold for splitting chunks before using ``torchani.data.CachedDataset`` or ``torchani.data.ShuffledDataset``.
Arguments:
file_path (str): Path to one hdf5 file.
batch_size (int): batch size.
threshold_max (int): maximum threshold to test.
"""
ds = CachedDataset(file_path=file_path, batch_size=batch_size)
ds.find_threshold(threshold_max + 1)
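A usage sketch (the file path is a placeholder):

# prints [chunk_size, chunk_max] pairs for each candidate threshold so a
# reasonable chunk_threshold can be picked for CachedDataset / ShuffledDataset
find_threshold('dataset/ani_gdb_s01.h5', batch_size=2560, threshold_max=100)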
class CachedDataset(torch.utils.data.Dataset):
""" Cached dataset that is shuffled once; the dataset stays the same at every epoch.
Arguments:
file_path (str): Path to one hdf5 file.
batch_size (int): batch size.
device (str): ``'cuda'`` or ``'cpu'``, whether to cache on GPU or CPU. Usually 'cpu' is already fast enough.
Default is ``'cpu'``.
chunk_threshold (int): threshold used to split a batch into chunks. Setting it to ``None`` disables chunk splitting.
Use ``torchani.data.find_threshold`` to find a reasonable ``chunk_threshold``.
other_properties (dict): A dict which is used to extract properties other than
``energies`` from the dataset with the correct padding, shape and dtype.\n
The example below will extract ``dipoles`` and ``forces``.\n
``padding_values``: setting a value to ``None`` means that property needs no padding.
.. code-block:: python
other_properties = {'properties': ['dipoles', 'forces'],
'padding_values': [None, 0],
'padded_shapes': [(batch_size, 3), (batch_size, -1, 3)],
'dtypes': [torch.float32, torch.float32]
}
include_energies (bool): Whether to include energies in the properties. Default is ``True``.
species_order (list): a list that specifies how species are mapped to integers.
for example: ``['H', 'C', 'N', 'O']`` means ``{'H': 0, 'C': 1, 'N': 2, 'O': 3}``.
subtract_self_energies (bool): whether to subtract self energies from ``energies``.
self_energies (list): if `subtract_self_energies` is True, the order should be
the same as ``species_order``.
for example: ``[-0.600953, -38.08316, -54.707756, -75.194466]`` will be converted
to ``{'H': -0.600953, 'C': -38.08316, 'N': -54.707756, 'O': -75.194466}``.
.. note::
The resulting dataset will be:
``([chunk1, chunk2, ...], {'energies', 'force', ...})`` in which chunk1 is a
tuple of ``(species, coordinates)``.
e.g. the shape of\n
chunk1: ``[[1807, 21], [1807, 21, 3]]``\n
chunk2: ``[[193, 50], [193, 50, 3]]``\n
'energies': ``[2000, 1]``
"""
def __init__(self, file_path,
batch_size=1000,
device='cpu',
chunk_threshold=20,
other_properties={},
include_energies=True,
species_order=['H', 'C', 'N', 'O'],
subtract_self_energies=False,
self_energies=[-0.600953, -38.08316, -54.707756, -75.194466]):
super(CachedDataset, self).__init__()
# an example species_dict looks like:
# species_dict: {'H': 0, 'C': 1, 'N': 2, 'O': 3}
# self_energies_dict: {'H': -0.600953, 'C': -38.08316, 'N': -54.707756, 'O': -75.194466}
species_dict = {}
self_energies_dict = {}
for i, s in enumerate(species_order):
species_dict[s] = i
self_energies_dict[s] = self_energies[i]
self.batch_size = batch_size
self.data_species = []
self.data_coordinates = []
data_self_energies = []
self.data_properties = {}
self.properties_info = other_properties
# whether to include energies in the properties
if include_energies:
self.add_energies_to_properties()
# let the user check which properties will be loaded
self.check_properties()
# anidataloader
anidata = anidataloader(file_path)
anidata_size = anidata.group_size()
self.enable_pkbar = anidata_size > 5 and PKBAR_INSTALLED
if self.enable_pkbar:
pbar = pkbar.Pbar('=> loading h5 dataset into cpu memory, total molecules: {}'.format(anidata_size), anidata_size)
# load h5 data into cpu memory as lists
for i, molecule in enumerate(anidata):
# conformations
num_conformations = len(molecule['coordinates'])
# species and coordinates
self.data_coordinates += list(molecule['coordinates'].reshape(num_conformations, -1).astype(np.float32))
species = np.array([species_dict[x] for x in molecule['species']])
self.data_species += list(np.tile(species, (num_conformations, 1)))
# if subtract_self_energies
if subtract_self_energies:
self_energies = np.array(sum([self_energies_dict[x] for x in molecule['species']]))
data_self_energies += list(np.tile(self_energies, (num_conformations, 1)))
# properties
for key in self.data_properties:
self.data_properties[key] += list(molecule[key].reshape(num_conformations, -1))
# pkbar update
if self.enable_pkbar:
pbar.update(i)
# if subtract self energies
if subtract_self_energies and 'energies' in self.properties_info['properties']:
self.data_properties['energies'] = np.array(self.data_properties['energies']) - np.array(data_self_energies)
del data_self_energies
gc.collect()
self.length = (len(self.data_species) + self.batch_size - 1) // self.batch_size
self.device = device
self.shuffled_index = np.arange(len(self.data_species))
np.random.shuffle(self.shuffled_index)
self.chunk_threshold = chunk_threshold
if not self.chunk_threshold:
self.chunk_threshold = np.inf
# clean trash
anidata.cleanup()
del num_conformations
del species
del anidata
gc.collect()
@functools.lru_cache(maxsize=None)
def __getitem__(self, index):
if index >= self.length:
raise IndexError()
batch_indices = slice(index * self.batch_size, (index + 1) * self.batch_size)
batch_indices_shuffled = self.shuffled_index[batch_indices]
batch_species = [self.data_species[i] for i in batch_indices_shuffled]
batch_coordinates = [self.data_coordinates[i] for i in batch_indices_shuffled]
# get sort index
num_atoms_each_mole = [b.shape[0] for b in batch_species]
atoms = torch.tensor(num_atoms_each_mole, dtype=torch.int32)
sorted_atoms, sorted_atoms_idx = torch.sort(atoms)
# sort each batch of data
batch_species = self.sort_list_with_index(batch_species, sorted_atoms_idx.numpy())
batch_coordinates = self.sort_list_with_index(batch_coordinates, sorted_atoms_idx.numpy())
# get chunk size
output, count = torch.unique(atoms, sorted=True, return_counts=True)
counts = torch.cat((output.unsqueeze(-1).int(), count.unsqueeze(-1).int()), dim=-1)
chunk_size_list, chunk_max_list = split_to_chunks(counts, chunk_threshold=self.chunk_threshold * self.batch_size * 20)
chunk_size_list = torch.stack(chunk_size_list).flatten()
# split into chunks
chunks_batch_species = self.split_list_with_size(batch_species, chunk_size_list.numpy())
chunks_batch_coordinates = self.split_list_with_size(batch_coordinates, chunk_size_list.numpy())
# padding each data
chunks_batch_species = self.pad_and_convert_to_tensor(chunks_batch_species, padding_value=-1)
chunks_batch_coordinates = self.pad_and_convert_to_tensor(chunks_batch_coordinates)
# chunks
chunks = list(zip(chunks_batch_species, chunks_batch_coordinates))
for i, _ in enumerate(chunks):
chunks[i] = (chunks[i][0], chunks[i][1].reshape(chunks[i][1].shape[0], -1, 3))
# properties
properties = {}
for i, key in enumerate(self.properties_info['properties']):
# get a batch of property
prop = [self.data_properties[key][i] for i in batch_indices_shuffled]
# sort with number of atoms
prop = self.sort_list_with_index(prop, sorted_atoms_idx.numpy())
# padding and convert to tensor
if self.properties_info['padding_values'][i] is None:
prop = self.pad_and_convert_to_tensor([prop], no_padding=True)[0]
else:
prop = self.pad_and_convert_to_tensor([prop], padding_value=self.properties_info['padding_values'][i])[0]
# set property shape and dtype
padded_shape = list(self.properties_info['padded_shapes'][i])
padded_shape[0] = prop.shape[0] # the last batch may not contain a full batch of data
properties[key] = prop.reshape(padded_shape).to(self.properties_info['dtypes'][i])
# return: [chunk1, chunk2, ...], {"energies", "force", ...} in which chunk1=(species, coordinates)
# e.g. chunk1 = [[1807, 21], [1807, 21, 3]], chunk2 = [[193, 50], [193, 50, 3]]
# 'energies' = [2000, 1]
return chunks, properties
def __len__(self):
return self.length
def split(self, validation_split):
"""Split the dataset into training and validation.
Arguments:
validation_split (float): Float between 0 and 1. Fraction of the dataset to be used
as validation data.
"""
val_size = int(validation_split * len(self))
train_size = len(self) - val_size
ds = []
if self.enable_pkbar:
message = ('=> processing, splitting and caching dataset into cpu memory: \n'
+ 'total batches: {}, train batches: {}, val batches: {}, batch_size: {}')
pbar = pkbar.Pbar(message.format(len(self), train_size, val_size, self.batch_size),
len(self))
for i, _ in enumerate(self):
ds.append(self[i])
if self.enable_pkbar:
pbar.update(i)
train_dataset = ds[:train_size]
val_dataset = ds[train_size:]
return train_dataset, val_dataset
def load(self):
"""Cache the dataset in CPU memory. If not called, the dataset will be cached during the first epoch.
"""
if self.enable_pkbar:
pbar = pkbar.Pbar('=> processing and caching dataset into cpu memory: \ntotal '
+ 'batches: {}, batch_size: {}'.format(len(self), self.batch_size),
len(self))
for i, _ in enumerate(self):
if self.enable_pkbar:
pbar.update(i)
def add_energies_to_properties(self):
# if user does not provide energies info
if 'properties' in self.properties_info and 'energies' not in self.properties_info['properties']:
# setup energies info, so the user does not need to input energies
self.properties_info['properties'].append('energies')
self.properties_info['padding_values'].append(None)
self.properties_info['padded_shapes'].append((self.batch_size, ))
self.properties_info['dtypes'].append(torch.float64)
# if no properties provided
if 'properties' not in self.properties_info:
self.properties_info = {'properties': ['energies'],
'padding_values': [None],
'padded_shapes': [(self.batch_size, )],
'dtypes': [torch.float64],
}
def check_properties(self):
# print properties information
print('... The following properties will be loaded:')
for i, prop in enumerate(self.properties_info['properties']):
self.data_properties[prop] = []
message = '{}: (dtype: {}, padding_value: {}, padded_shape: {})'
print(message.format(prop, self.properties_info['dtypes'][i],
self.properties_info['padding_values'][i],
self.properties_info['padded_shapes'][i]))
@staticmethod
def sort_list_with_index(inputs, index):
return [inputs[i] for i in index]
@staticmethod
def split_list_with_size(inputs, split_size):
output = []
for i, _ in enumerate(split_size):
start_index = np.sum(split_size[:i])
stop_index = np.sum(split_size[:i + 1])
output.append(inputs[start_index:stop_index])
return output
def pad_and_convert_to_tensor(self, inputs, padding_value=0, no_padding=False):
if no_padding:
for i, input_tmp in enumerate(inputs):
inputs[i] = torch.from_numpy(np.stack(input_tmp)).to(self.device)
else:
for i, input_tmp in enumerate(inputs):
inputs[i] = torch.nn.utils.rnn.pad_sequence(
[torch.from_numpy(b) for b in inputs[i]],
batch_first=True, padding_value=padding_value).to(self.device)
return inputs
def find_threshold(self, threshold_max=100):
batch_indices = slice(0, self.batch_size)
batch_indices_shuffled = self.shuffled_index[batch_indices]
batch_species = [self.data_species[i] for i in batch_indices_shuffled]
num_atoms_each_mole = [b.shape[0] for b in batch_species]
atoms = torch.tensor(num_atoms_each_mole, dtype=torch.int32)
output, count = torch.unique(atoms, sorted=True, return_counts=True)
counts = torch.cat((output.unsqueeze(-1).int(), count.unsqueeze(-1).int()), dim=-1)
print('=> choose a reasonable threshold to split chunks')
print('format is [chunk_size, chunk_max]')
for b in range(0, threshold_max, 1):
test_chunk_size_list, test_chunk_max_list = split_to_chunks(counts, chunk_threshold=b * self.batch_size * 20)
size_max = []
for i, _ in enumerate(test_chunk_size_list):
size_max.append([list(test_chunk_size_list[i].numpy())[0],
list(test_chunk_max_list[i].numpy())[0]])
print('chunk_threshold = {}'.format(b))
print(size_max)
def release_h5(self):
del self.data_species
del self.data_coordinates
del self.data_properties
gc.collect()
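A minimal usage sketch of ``CachedDataset`` (file path, batch size and the extra ``forces`` property are illustrative, following the ``other_properties`` format documented above):

import torch

batch_size = 2560
other_properties = {'properties': ['forces'],
                    'padding_values': [0],
                    'padded_shapes': [(batch_size, -1, 3)],
                    'dtypes': [torch.float32]}
dataset = CachedDataset('dataset/ani_gdb_s01.h5',
                        batch_size=batch_size,
                        chunk_threshold=20,
                        other_properties=other_properties,
                        subtract_self_energies=True)
training, validation = dataset.split(0.1)
chunks, properties = training[0]  # chunks is a list of (species, coordinates)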
def ShuffledDataset(file_path,
batch_size=1000, num_workers=0, shuffle=True,
chunk_threshold=20,
other_properties={},
include_energies=True,
validation_split=0.0,
species_order=['H', 'C', 'N', 'O'],
subtract_self_energies=False,
self_energies=[-0.600953, -38.08316, -54.707756, -75.194466]):
""" Shuffled dataset using `torch.utils.data.DataLoader`; it reshuffles at every epoch.
Arguments:
file_path (str): Path to one hdf5 file.
batch_size (int): batch size.
num_workers (int): number of worker processes that prepare the dataset in the
background while training is running.
shuffle (bool): whether to shuffle.
chunk_threshold (int): threshold used to split a batch into chunks. Setting it to ``None`` disables chunk splitting.
Use ``torchani.data.find_threshold`` to find a reasonable ``chunk_threshold``.
other_properties (dict): A dict which is used to extract properties other than
``energies`` from the dataset with the correct padding, shape and dtype.\n
The example below will extract ``dipoles`` and ``forces``.\n
``padding_values``: setting a value to ``None`` means that property needs no padding.
.. code-block:: python
other_properties = {'properties': ['dipoles', 'forces'],
'padding_values': [None, 0],
'padded_shapes': [(batch_size, 3), (batch_size, -1, 3)],
'dtypes': [torch.float32, torch.float32]
}
include_energies (bool): Whether to include energies in the properties. Default is ``True``.
validation_split (float): Float between 0 and 1. Fraction of the dataset to be used
as validation data.
species_order (list): a list that specifies how species are mapped to integers.
for example: ``['H', 'C', 'N', 'O']`` means ``{'H': 0, 'C': 1, 'N': 2, 'O': 3}``.
subtract_self_energies (bool): whether to subtract self energies from ``energies``.
self_energies (list): if `subtract_self_energies` is True, the order should be
the same as ``species_order``.
for example: ``[-0.600953, -38.08316, -54.707756, -75.194466]`` will be
converted to ``{'H': -0.600953, 'C': -38.08316, 'N': -54.707756, 'O': -75.194466}``.
.. note::
Returns a dataloader that, when iterated, yields
``([chunk1, chunk2, ...], {'energies', 'force', ...})`` in which chunk1 is a
tuple of ``(species, coordinates)``.\n
e.g. the shape of\n
chunk1: ``[[1807, 21], [1807, 21, 3]]``\n
chunk2: ``[[193, 50], [193, 50, 3]]``\n
'energies': ``[2000, 1]``
"""
dataset = TorchData(file_path,
batch_size,
other_properties,
include_energies,
species_order,
subtract_self_energies,
self_energies)
properties_info = dataset.get_properties_info()
if not chunk_threshold:
chunk_threshold = np.inf
def my_collate_fn(data, chunk_threshold=chunk_threshold, properties_info=properties_info):
return collate_fn(data, chunk_threshold, properties_info)
val_size = int(validation_split * len(dataset))
train_size = len(dataset) - val_size
train_dataset, val_dataset = torch.utils.data.random_split(dataset, [train_size, val_size])
train_data_loader = torch.utils.data.DataLoader(dataset=train_dataset,
batch_size=batch_size,
shuffle=shuffle,
num_workers=num_workers,
pin_memory=False,
collate_fn=my_collate_fn)
if val_size == 0:
return train_data_loader
val_data_loader = torch.utils.data.DataLoader(dataset=val_dataset,
batch_size=batch_size,
shuffle=False,
num_workers=num_workers,
pin_memory=False,
collate_fn=my_collate_fn)
return train_data_loader, val_data_loader
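And a corresponding sketch for the DataLoader-based variant (file path and worker count are placeholders); each batch has the same ``(chunks, properties)`` layout:

train_loader, val_loader = ShuffledDataset('dataset/ani_gdb_s01.h5',
                                           batch_size=2560,
                                           num_workers=2,
                                           validation_split=0.1,
                                           subtract_self_energies=True)
for chunks, properties in train_loader:
    true_energies = properties['energies']
    for species, coordinates in chunks:
        pass  # feed each similarly sized chunk to the model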
class TorchData(torch.utils.data.Dataset):
def __init__(self, file_path,
batch_size,
other_properties,
include_energies,
species_order,
subtract_self_energies,
self_energies):
super(TorchData, self).__init__()
species_dict = {}
self_energies_dict = {}
for i, s in enumerate(species_order):
species_dict[s] = i
self_energies_dict[s] = self_energies[i]
self.batch_size = batch_size
self.data_species = []
self.data_coordinates = []
data_self_energies = []
self.data_properties = {}
self.properties_info = other_properties
# whether to include energies in the properties
if include_energies:
self.add_energies_to_properties()
# let the user check which properties will be loaded
self.check_properties()
# anidataloader
anidata = anidataloader(file_path)
anidata_size = anidata.group_size()
self.enable_pkbar = anidata_size > 5 and PKBAR_INSTALLED
if self.enable_pkbar:
pbar = pkbar.Pbar('=> loading h5 dataset into cpu memory, total molecules: {}'.format(anidata_size), anidata_size)
# load h5 data into cpu memory as lists
for i, molecule in enumerate(anidata):
# conformations
num_conformations = len(molecule['coordinates'])
# species and coordinates
self.data_coordinates += list(molecule['coordinates'].reshape(num_conformations, -1).astype(np.float32))
species = np.array([species_dict[x] for x in molecule['species']])
self.data_species += list(np.tile(species, (num_conformations, 1)))
# if subtract_self_energies
if subtract_self_energies:
self_energies = np.array(sum([self_energies_dict[x] for x in molecule['species']]))
data_self_energies += list(np.tile(self_energies, (num_conformations, 1)))
# properties
for key in self.data_properties:
self.data_properties[key] += list(molecule[key].reshape(num_conformations, -1))
# pkbar update
if self.enable_pkbar:
pbar.update(i)
# if subtract self energies
if subtract_self_energies and 'energies' in self.properties_info['properties']:
self.data_properties['energies'] = np.array(self.data_properties['energies']) - np.array(data_self_energies)
del data_self_energies
gc.collect()
self.length = len(self.data_species)
# clean trash
anidata.cleanup()
del num_conformations
del species
del anidata
gc.collect()
def __getitem__(self, index):
if index >= self.length:
raise IndexError()
species = torch.from_numpy(self.data_species[index])
coordinates = torch.from_numpy(self.data_coordinates[index]).float()
properties = {}
for key in self.data_properties:
properties[key] = torch.from_numpy(self.data_properties[key][index])
return [species, coordinates, properties]
def __len__(self):
return self.length
def add_energies_to_properties(self):
# if user does not provide energies info
if 'properties' in self.properties_info and 'energies' not in self.properties_info['properties']:
# setup energies info, so the user does not need to input energies
self.properties_info['properties'].append('energies')
self.properties_info['padding_values'].append(None)
self.properties_info['padded_shapes'].append((self.batch_size, ))
self.properties_info['dtypes'].append(torch.float64)
# if no properties provided
if 'properties' not in self.properties_info:
self.properties_info = {'properties': ['energies'],
'padding_values': [None],
'padded_shapes': [(self.batch_size, )],
'dtypes': [torch.float64],
}
def check_properties(self):
# print properties information
print('... The following properties will be loaded:')
for i, prop in enumerate(self.properties_info['properties']):
self.data_properties[prop] = []
message = '{}: (dtype: {}, padding_value: {}, padded_shape: {})'
print(message.format(prop, self.properties_info['dtypes'][i],
self.properties_info['padding_values'][i],
self.properties_info['padded_shapes'][i]))
def get_properties_info(self):
return self.properties_info
def collate_fn(data, chunk_threshold, properties_info):
"""Creates a batch of chunked data.
"""
# unzip a batch of molecules (each molecule is a list)
batch_species, batch_coordinates, batch_properties = zip(*data)
batch_size = len(batch_species)
# padding - time: 13.2s
batch_species = torch.nn.utils.rnn.pad_sequence(batch_species,
batch_first=True,
padding_value=-1)
batch_coordinates = torch.nn.utils.rnn.pad_sequence(batch_coordinates,
batch_first=True,
padding_value=np.inf)
# sort - time: 0.7s
atoms = torch.sum(~(batch_species == -1), dim=-1, dtype=torch.int32)
sorted_atoms, sorted_atoms_idx = torch.sort(atoms)
batch_species = torch.index_select(batch_species, dim=0, index=sorted_atoms_idx)
batch_coordinates = torch.index_select(batch_coordinates, dim=0, index=sorted_atoms_idx)
# get chunk size - time: 2.1s
output, count = torch.unique(atoms, sorted=True, return_counts=True)
counts = torch.cat((output.unsqueeze(-1).int(), count.unsqueeze(-1).int()), dim=-1)
chunk_size_list, chunk_max_list = split_to_chunks(counts, chunk_threshold=chunk_threshold * batch_size * 20)
# split into chunks - time: 0.3s
chunks_batch_species = torch.split(batch_species, chunk_size_list, dim=0)
chunks_batch_coordinates = torch.split(batch_coordinates, chunk_size_list, dim=0)
# truncate redundant padding - time: 1.3s
chunks_batch_species = trunc_pad(list(chunks_batch_species), padding_value=-1)
chunks_batch_coordinates = trunc_pad(list(chunks_batch_coordinates), padding_value=np.inf)
for i, c in enumerate(chunks_batch_coordinates):
chunks_batch_coordinates[i] = c.reshape(c.shape[0], -1, 3)
chunks = list(zip(chunks_batch_species, chunks_batch_coordinates))
for i, _ in enumerate(chunks):
chunks[i] = (chunks[i][0], chunks[i][1])
# properties
properties = {}
for i, key in enumerate(properties_info['properties']):
# get a batch of property
prop = tuple(p[key] for p in batch_properties)
# padding and convert to tensor
if properties_info['padding_values'][i] is None:
prop = torch.stack(prop)
else:
prop = torch.nn.utils.rnn.pad_sequence(prop,
batch_first=True,
padding_value=properties_info['padding_values'][i])
# sort with number of atoms
prop = torch.index_select(prop, dim=0, index=sorted_atoms_idx)
# set property shape and dtype
padded_shape = list(properties_info['padded_shapes'][i])
padded_shape[0] = prop.shape[0] # the last batch may not contain a full batch of data
properties[key] = prop.reshape(padded_shape).to(properties_info['dtypes'][i])
# return: [chunk1, chunk2, ...], {"energies", "force", ...} in which chunk1=(species, coordinates)
# e.g. chunk1 = [[1807, 21], [1807, 21, 3]], chunk2 = [[193, 50], [193, 50, 3]]
# 'energies' = [2000, 1]
return chunks, properties
def split_to_two_chunks(counts, chunk_threshold):
counts = counts.cpu()
# NB (@yueyericardo): In principle this dtype should be `torch.bool`, but unfortunately
# `triu` is not implemented for bool tensor right now. This should be fixed when PyTorch
# add support for it.
left_mask = torch.triu(torch.ones([counts.shape[0], counts.shape[0]], dtype=torch.uint8))
left_mask = left_mask.t()
counts_atoms = counts[:, 0].repeat(counts.shape[0], 1)
counts_counts = counts[:, 1].repeat(counts.shape[0], 1)
counts_atoms_left = torch.where(left_mask, counts_atoms, torch.zeros_like(counts_atoms))
counts_atoms_right = torch.where(~left_mask, counts_atoms, torch.zeros_like(counts_atoms))
counts_counts_left = torch.where(left_mask, counts_counts, torch.zeros_like(counts_atoms))
counts_counts_right = torch.where(~left_mask, counts_counts, torch.zeros_like(counts_atoms))
# chunk max
chunk_max_left = torch.max(counts_atoms_left, dim=-1, keepdim=True).values
chunk_max_right = torch.max(counts_atoms_right, dim=-1, keepdim=True).values
# chunk size
chunk_size_left = torch.sum(counts_counts_left, dim=-1, keepdim=True, dtype=torch.int32)
chunk_size_right = torch.sum(counts_counts_right, dim=-1, keepdim=True, dtype=torch.int32)
# calculate cost
min_cost_threshold = torch.tensor([chunk_threshold], dtype=torch.int32)
cost = (torch.max(chunk_size_left * chunk_max_left * chunk_max_left, min_cost_threshold)
+ torch.max(chunk_size_right * chunk_max_right * chunk_max_right, min_cost_threshold))
# find smallest cost
cost_min, cost_min_index = torch.min(cost.squeeze(), dim=-1)
# find the chunk_size with the smallest cost; if not split, it will be [max_chunk_size, 0]
final_chunk_size = [chunk_size_left[cost_min_index], chunk_size_right[cost_min_index]]
final_chunk_max = [chunk_max_left[cost_min_index], chunk_max_right[cost_min_index]]
# if not split
if cost_min_index == (counts.shape[0] - 1):
return False, counts, [final_chunk_size[0]], [final_chunk_max[0]], cost_min
# if split
return True, [counts[:cost_min_index + 1], counts[(cost_min_index + 1):]], \
final_chunk_size, final_chunk_max, cost_min
def split_to_chunks(counts, chunk_threshold=np.inf):
splitted, counts_list, chunk_size, chunk_max, cost = split_to_two_chunks(counts, chunk_threshold)
final_chunk_size = []
final_chunk_max = []
if (splitted):
for i, _ in enumerate(counts_list):
tmp_chunk_size, tmp_chunk_max = split_to_chunks(counts_list[i], chunk_threshold)
final_chunk_size.extend(tmp_chunk_size)
final_chunk_max.extend(tmp_chunk_max)
return final_chunk_size, final_chunk_max
# if not split
return chunk_size, chunk_max
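The splitter above searches recursively for the cut that minimizes a padded-compute cost of ``chunk_size * chunk_max ** 2`` per chunk (floored at ``chunk_threshold``). A toy calculation with made-up molecule counts shows why splitting a mixed-size batch pays off:

# 1800 molecules of 21 atoms and 200 molecules of 50 atoms
sizes = [(21, 1800), (50, 200)]

# one chunk: every molecule padded to 50 atoms
single_chunk_cost = sum(n for _, n in sizes) * max(a for a, _ in sizes) ** 2

# two chunks: each padded only to its own largest molecule
split_cost = sum(n * a ** 2 for a, n in sizes)

print(single_chunk_cost, split_cost)  # 5000000 1293800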
def trunc_pad(chunks, padding_value=0):
for i, _ in enumerate(chunks):
lengths = torch.sum(~(chunks[i] == padding_value), dim=-1, dtype=torch.int32)
chunks[i] = chunks[i][..., :lengths.max()]
return chunks
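For example, ``trunc_pad`` keeps only as many trailing columns as the longest molecule in the chunk actually needs (toy tensor):

import torch

chunk = torch.tensor([[1, 2, -1, -1],
                      [3, -1, -1, -1]])
print(trunc_pad([chunk], padding_value=-1)[0])
# tensor([[ 1,  2],
#         [ 3, -1]])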
...@@ -91,7 +91,7 @@ class BuiltinNet(torch.nn.Module): ...@@ -91,7 +91,7 @@ class BuiltinNet(torch.nn.Module):
self.species = self.consts.species self.species = self.consts.species
self.species_converter = SpeciesConverter(self.species) self.species_converter = SpeciesConverter(self.species)
self.aev_computer = AEVComputer(**self.consts) self.aev_computer = AEVComputer(**self.consts)
self.energy_shifter = neurochem.load_sae(self.sae_file) self.energy_shifter, self.sae_dict = neurochem.load_sae(self.sae_file, return_dict=True)
self.neural_networks = neurochem.load_model_ensemble( self.neural_networks = neurochem.load_model_ensemble(
self.species, self.ensemble_prefix, self.ensemble_size) self.species, self.ensemble_prefix, self.ensemble_size)
......
...@@ -70,17 +70,22 @@ class Constants(collections.abc.Mapping): ...@@ -70,17 +70,22 @@ class Constants(collections.abc.Mapping):
return getattr(self, item) return getattr(self, item)
def load_sae(filename): def load_sae(filename, return_dict=False):
"""Returns an object of :class:`EnergyShifter` with self energies from """Returns an object of :class:`EnergyShifter` with self energies from
NeuroChem sae file""" NeuroChem sae file"""
self_energies = [] self_energies = []
d = {}
with open(filename) as f: with open(filename) as f:
for i in f: for i in f:
line = [x.strip() for x in i.split('=')] line = [x.strip() for x in i.split('=')]
species = line[0].split(',')[0].strip()
index = int(line[0].split(',')[1].strip()) index = int(line[0].split(',')[1].strip())
value = float(line[1]) value = float(line[1])
d[species] = value
self_energies.append((index, value)) self_energies.append((index, value))
self_energies = [i for _, i in sorted(self_energies)] self_energies = [i for _, i in sorted(self_energies)]
if return_dict:
return EnergyShifter(self_energies), d
return EnergyShifter(self_energies) return EnergyShifter(self_energies)
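Given the parser above, the sae file is expected to contain one ``species,index=value`` line per element. A hedged sketch of the new ``return_dict`` path (file name and contents are illustrative):

# example sae file contents:
#   H,0=-0.600953
#   C,1=-38.08316
#   N,2=-54.707756
#   O,3=-75.194466
shifter, sae_dict = load_sae('sae_linfit.dat', return_dict=True)
# sae_dict == {'H': -0.600953, 'C': -38.08316, 'N': -54.707756, 'O': -75.194466}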
...@@ -285,13 +290,13 @@ if sys.version_info[0] > 2: ...@@ -285,13 +290,13 @@ if sys.version_info[0] > 2:
def __init__(self, filename, device=torch.device('cuda'), tqdm=False, def __init__(self, filename, device=torch.device('cuda'), tqdm=False,
tensorboard=None, checkpoint_name='model.pt'): tensorboard=None, checkpoint_name='model.pt'):
from ..data import load_ani_dataset # noqa: E402 from ..data import load # noqa: E402
class dummy: class dummy:
pass pass
self.imports = dummy() self.imports = dummy()
self.imports.load_ani_dataset = load_ani_dataset self.imports.load = load
self.filename = filename self.filename = filename
self.device = device self.device = device
...@@ -467,7 +472,7 @@ if sys.version_info[0] > 2: ...@@ -467,7 +472,7 @@ if sys.version_info[0] > 2:
self.aev_computer = AEVComputer(**self.consts) self.aev_computer = AEVComputer(**self.consts)
del params['sflparamsfile'] del params['sflparamsfile']
self.sae_file = os.path.join(dir_, params['atomEnergyFile']) self.sae_file = os.path.join(dir_, params['atomEnergyFile'])
self.shift_energy = load_sae(self.sae_file) self.shift_energy, self.sae = load_sae(self.sae_file, return_dict=True)
del params['atomEnergyFile'] del params['atomEnergyFile']
network_dir = os.path.join(dir_, params['ntwkStoreDir']) network_dir = os.path.join(dir_, params['ntwkStoreDir'])
if not os.path.exists(network_dir): if not os.path.exists(network_dir):
...@@ -552,26 +557,18 @@ if sys.version_info[0] > 2: ...@@ -552,26 +557,18 @@ if sys.version_info[0] > 2:
def load_data(self, training_path, validation_path): def load_data(self, training_path, validation_path):
"""Load training and validation dataset from file.""" """Load training and validation dataset from file."""
self.training_set = self.imports.load_ani_dataset( self.training_set = self.imports.load(training_path).subtract_self_energies(self.sae).species_to_indices().shuffle().collate(self.training_batch_size).cache()
training_path, self.consts.species_to_tensor, self.validation_set = self.imports.load(validation_path).subtract_self_energies(self.sae).species_to_indices().shuffle().collate(self.validation_batch_size).cache()
self.training_batch_size, rm_outlier=True, device=self.device,
transform=[self.shift_energy.subtract_from_dataset])
self.validation_set = self.imports.load_ani_dataset(
validation_path, self.consts.species_to_tensor,
self.validation_batch_size, rm_outlier=True, device=self.device,
transform=[self.shift_energy.subtract_from_dataset])
def evaluate(self, dataset): def evaluate(self, dataset):
"""Run the evaluation""" """Run the evaluation"""
total_mse = 0.0 total_mse = 0.0
count = 0 count = 0
for batch_x, batch_y in dataset: for properties in dataset:
true_energies = batch_y['energies'] species = properties['species'].to(self.device)
predicted_energies = [] coordinates = properties['coordinates'].to(self.device).float()
for chunk_species, chunk_coordinates in batch_x: true_energies = properties['energies'].to(self.device).float()
_, chunk_energies = self.model((chunk_species, chunk_coordinates)) _, predicted_energies = self.model((species, coordinates))
predicted_energies.append(chunk_energies)
predicted_energies = torch.cat(predicted_energies)
total_mse += self.mse_sum(predicted_energies, true_energies).item() total_mse += self.mse_sum(predicted_energies, true_energies).item()
count += predicted_energies.shape[0] count += predicted_energies.shape[0]
return hartree2kcalmol(math.sqrt(total_mse / count)) return hartree2kcalmol(math.sqrt(total_mse / count))
...@@ -620,21 +617,16 @@ if sys.version_info[0] > 2: ...@@ -620,21 +617,16 @@ if sys.version_info[0] > 2:
self.tensorboard.add_scalar('learning_rate', learning_rate, AdamW_scheduler.last_epoch) self.tensorboard.add_scalar('learning_rate', learning_rate, AdamW_scheduler.last_epoch)
self.tensorboard.add_scalar('no_improve_count_vs_epoch', no_improve_count, AdamW_scheduler.last_epoch) self.tensorboard.add_scalar('no_improve_count_vs_epoch', no_improve_count, AdamW_scheduler.last_epoch)
for i, (batch_x, batch_y) in self.tqdm( for i, properties in self.tqdm(
enumerate(self.training_set), enumerate(self.training_set),
total=len(self.training_set), total=len(self.training_set),
desc='epoch {}'.format(AdamW_scheduler.last_epoch) desc='epoch {}'.format(AdamW_scheduler.last_epoch)
): ):
species = properties['species'].to(self.device)
true_energies = batch_y['energies'] coordinates = properties['coordinates'].to(self.device).float()
predicted_energies = [] true_energies = properties['energies'].to(self.device).float()
num_atoms = [] num_atoms = (species >= 0).sum(dim=1, dtype=true_energies.dtype)
for chunk_species, chunk_coordinates in batch_x: _, predicted_energies = self.model((species, coordinates))
num_atoms.append((chunk_species >= 0).sum(dim=1))
_, chunk_energies = self.model((chunk_species, chunk_coordinates))
predicted_energies.append(chunk_energies)
num_atoms = torch.cat(num_atoms).to(true_energies.dtype)
predicted_energies = torch.cat(predicted_energies)
loss = (self.mse_se(predicted_energies, true_energies) / num_atoms.sqrt()).mean() loss = (self.mse_se(predicted_energies, true_energies) / num_atoms.sqrt()).mean()
AdamW_optim.zero_grad() AdamW_optim.zero_grad()
SGD_optim.zero_grad() SGD_optim.zero_grad()
......