Unverified Commit 6b058c6e authored by Gao, Xiang, committed by GitHub

Remove everything about chunking (#432)

* Remove everything about chunking

* aev.py

* neurochem trainer

* training-benchmark-nsys-profile.py

* fix eval

* training-benchmark.py

* nnp_training.py

* flake8

* nnp_training_force.py

* fix dtype of species

* fix

* flake8

* requires_grad_

* git ignore

* fix

* original

* fix

* fix

* fix

* fix

* save

* save

* save

* save

* save

* save

* save

* save

* save

* collate

* fix

* save

* fix

* save

* save

* fix

* save

* fix

* fix

* no len

* float

* save

* save

* save

* save

* save

* save

* save

* save

* save

* fix

* save

* save

* save

* save

* fix

* fix

* fix

* fix mypy

* don't remove outliers

* save

* save

* save

* fix

* flake8

* save

* fix

* flake8

* docs

* more docs

* fix test_data

* remove test_data_new

* fix
parent 338f896a
@@ -39,3 +39,5 @@ jobs:
       run: python tools/comp6.py dataset/COMP6/COMP6v1/s66x8
     - name: Training Benchmark
       run: python tools/training-benchmark.py dataset/ani1-up_to_gdb4/ani_gdb_s01.h5
+    - name: Training Benchmark Nsight System
+      run: python tools/training-benchmark-nsys-profile.py --dry-run dataset/ani1-up_to_gdb4/ani_gdb_s01.h5
@@ -18,7 +18,7 @@ jobs:
         python-version: [3.6, 3.8]
         test-filenames: [
           test_aev.py, test_aev_benzene_md.py, test_aev_nist.py, test_aev_tripeptide_md.py,
-          test_data_new.py, test_utils.py, test_ase.py, test_energies.py, test_periodic_table_indexing.py,
+          test_utils.py, test_ase.py, test_energies.py, test_periodic_table_indexing.py,
           test_neurochem.py, test_vibrational.py, test_ensemble.py, test_padding.py,
           test_data.py, test_forces.py, test_structure_optim.py, test_jit_builtin_models.py]
...
@@ -32,3 +32,5 @@ dist
 /download.tar.xz
 *.qdrep
 *.qdstrm
+*.zip
+Untitled.ipynb
@@ -26,12 +26,6 @@ Datasets
 ========
 .. automodule:: torchani.data
-.. autofunction:: torchani.data.find_threshold
-.. autofunction:: torchani.data.ShuffledDataset
-.. autoclass:: torchani.data.CachedDataset
-    :members:
-.. autofunction:: torchani.data.load_ani_dataset
-.. autoclass:: torchani.data.PaddedBatchChunkDataset
...
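The entries removed above are superseded by the chainable pipeline that the rest of this commit migrates to. A minimal sketch of the new API, assembled only from calls that appear in this diff (the dataset path is a placeholder):

```python
import torchani

energy_shifter = torchani.EnergyShifter(None)
dataset = torchani.data.load('ani_gdb_s01.h5') \
    .subtract_self_energies(energy_shifter) \
    .species_to_indices() \
    .shuffle()
training, validation = dataset.split(int(0.8 * len(dataset)), None)
training = training.collate(2560).cache()
for properties in training:
    # each batch is a plain dict of padded tensors
    print(properties['species'].shape)      # (N, A), padding encoded as -1
    print(properties['coordinates'].shape)  # (N, A, 3)
    print(properties['energies'].shape)     # (N,)
    break
```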
@@ -80,37 +80,17 @@ try:
 except NameError:
     path = os.getcwd()
 dspath = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')
 batch_size = 2560
-training, validation = torchani.data.load_ani_dataset(
-    dspath, species_to_tensor, batch_size, rm_outlier=True, device=device,
-    transform=[energy_shifter.subtract_from_dataset], split=[0.8, None])
+dataset = torchani.data.load(dspath).subtract_self_energies(energy_shifter).species_to_indices().shuffle()
+size = len(dataset)
+training, validation = dataset.split(int(0.8 * size), None)
+training = training.collate(batch_size).cache()
+validation = validation.collate(batch_size).cache()
 print('Self atomic energies: ', energy_shifter.self_energies)
 ###############################################################################
-# When iterating the dataset, we will get pairs of input and output
-# ``(species_coordinates, properties)``, where ``species_coordinates`` is the
-# input and ``properties`` is the output.
-#
-# ``species_coordinates`` is a list of species-coordinate pairs, with shape
-# ``(N, Na)`` and ``(N, Na, 3)``. The reason for getting this type is, when
-# loading the dataset and generating minibatches, the whole dataset are
-# shuffled and each minibatch contains structures of molecules with a wide
-# range of number of atoms. Molecules of different number of atoms are batched
-# into single by padding. The way padding works is: adding ghost atoms, with
-# species 'X', and do computations as if they were normal atoms. But when
-# computing AEVs, atoms with species `X` would be ignored. To avoid computation
-# wasting on padding atoms, minibatches are further splitted into chunks. Each
-# chunk contains structures of molecules of similar size, which minimize the
-# total number of padding atoms required to add. The input list
-# ``species_coordinates`` contains chunks of that minibatch we are getting. The
-# batching and chunking happens automatically, so the user does not need to
-# worry how to construct chunks, but the user need to compute the energies for
-# each chunk and concat them into single tensor.
-#
-# The output, i.e. ``properties`` is a dictionary holding each property. This
-# allows us to extend TorchANI in the future to training forces and properties.
+# When iterating the dataset, we will get a dict of name->property mapping
 #
 ###############################################################################
 # Now let's define atomic neural networks.
@@ -279,16 +259,11 @@ def validate():
     mse_sum = torch.nn.MSELoss(reduction='sum')
     total_mse = 0.0
     count = 0
-    for batch_x, batch_y in validation:
-        true_energies = batch_y['energies']
-        predicted_energies = []
-        atomic_properties = []
-        for chunk_species, chunk_coordinates in batch_x:
-            atomic_chunk = {'species': chunk_species, 'coordinates': chunk_coordinates}
-            atomic_properties.append(atomic_chunk)
-        atomic_properties = torchani.utils.pad_atomic_properties(atomic_properties)
-        predicted_energies = model((atomic_properties['species'], atomic_properties['coordinates'])).energies
+    for properties in validation:
+        species = properties['species'].to(device)
+        coordinates = properties['coordinates'].to(device).float()
+        true_energies = properties['energies'].to(device).float()
+        _, predicted_energies = model((species, coordinates))
         total_mse += mse_sum(predicted_energies, true_energies).item()
         count += predicted_energies.shape[0]
     return hartree2kcalmol(math.sqrt(total_mse / count))
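For reference, the accumulation pattern of the new validate() body, isolated with dummy tensors so it runs standalone; the constant stands in for the hartree2kcalmol helper used above and is an approximate conversion factor:

```python
import math
import torch

HARTREE2KCALMOL = 627.509  # approximate Hartree -> kcal/mol factor

mse_sum = torch.nn.MSELoss(reduction='sum')
total_mse, count = 0.0, 0
for _ in range(3):  # stand-in for iterating the validation set
    predicted_energies = torch.randn(8)
    true_energies = torch.randn(8)
    total_mse += mse_sum(predicted_energies, true_energies).item()
    count += predicted_energies.shape[0]
rmse = HARTREE2KCALMOL * math.sqrt(total_mse / count)
```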
@@ -331,26 +306,17 @@ for _ in range(AdamW_scheduler.last_epoch + 1, max_epochs):
     tensorboard.add_scalar('best_validation_rmse', AdamW_scheduler.best, AdamW_scheduler.last_epoch)
     tensorboard.add_scalar('learning_rate', learning_rate, AdamW_scheduler.last_epoch)
-    for i, (batch_x, batch_y) in tqdm.tqdm(
+    for i, properties in tqdm.tqdm(
         enumerate(training),
         total=len(training),
         desc="epoch {}".format(AdamW_scheduler.last_epoch)
     ):
-        true_energies = batch_y['energies']
-        predicted_energies = []
-        num_atoms = []
-        atomic_properties = []
-        for chunk_species, chunk_coordinates in batch_x:
-            atomic_chunk = {'species': chunk_species, 'coordinates': chunk_coordinates}
-            atomic_properties.append(atomic_chunk)
-            num_atoms.append((chunk_species >= 0).to(true_energies.dtype).sum(dim=1))
-        atomic_properties = torchani.utils.pad_atomic_properties(atomic_properties)
-        predicted_energies = model((atomic_properties['species'], atomic_properties['coordinates'])).energies
-        num_atoms = torch.cat(num_atoms)
+        species = properties['species'].to(device)
+        coordinates = properties['coordinates'].to(device).float()
+        true_energies = properties['energies'].to(device).float()
+        num_atoms = (species >= 0).sum(dim=1, dtype=true_energies.dtype)
+        _, predicted_energies = model((species, coordinates))
         loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()
         AdamW.zero_grad()
...
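The per-molecule atom counts now come straight from the padded species tensor instead of per-chunk bookkeeping. A small self-contained illustration of that line and of the sqrt-scaled energy loss (the energy tensors are made up):

```python
import torch

species = torch.tensor([[0, 1, 1, -1, -1],
                        [2, 0, 3, 1, 0]])  # two molecules padded to 5 atoms
num_atoms = (species >= 0).sum(dim=1, dtype=torch.float)  # tensor([3., 5.])

mse = torch.nn.MSELoss(reduction='none')
predicted_energies = torch.randn(2)
true_energies = torch.randn(2)
# dividing by sqrt(num_atoms) keeps large molecules from dominating the loss
loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()
```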
@@ -49,34 +49,14 @@ dspath = os.path.join(path, '../dataset/ani-1x/sample.h5')
 batch_size = 2560
-###############################################################################
-# The code to create the dataset is a bit different: we need to manually
-# specify that ``atomic_properties=['forces']`` so that forces will be read
-# from hdf5 files.
-training, validation = torchani.data.load_ani_dataset(
-    dspath, species_to_tensor, batch_size, rm_outlier=True,
-    device=device, atomic_properties=['forces'],
-    transform=[energy_shifter.subtract_from_dataset], split=[0.8, None])
+dataset = torchani.data.load(dspath).subtract_self_energies(energy_shifter).species_to_indices().shuffle()
+size = len(dataset)
+training, validation = dataset.split(int(0.8 * size), None)
+training = training.collate(batch_size).cache()
+validation = validation.collate(batch_size).cache()
 print('Self atomic energies: ', energy_shifter.self_energies)
-###############################################################################
-# When iterating the dataset, we will get pairs of input and output
-# ``(species_coordinates, properties)``, in this case, ``properties`` would
-# contain a key ``'atomic'`` where ``properties['atomic']`` is a list of dict
-# containing forces:
-data = training[0]
-properties = data[1]
-atomic_properties = properties['atomic']
-print(type(atomic_properties))
-print(list(atomic_properties[0].keys()))
-###############################################################################
-# Due to padding, part of the forces might be 0
-print(atomic_properties[0]['forces'][0])
 ###############################################################################
 # The code to define networks, optimizers, are mostly the same
@@ -225,13 +205,11 @@ def validate():
     mse_sum = torch.nn.MSELoss(reduction='sum')
     total_mse = 0.0
     count = 0
-    for batch_x, batch_y in validation:
-        true_energies = batch_y['energies']
-        predicted_energies = []
-        for chunk_species, chunk_coordinates in batch_x:
-            chunk_energies = model((chunk_species, chunk_coordinates)).energies
-            predicted_energies.append(chunk_energies)
-        predicted_energies = torch.cat(predicted_energies)
+    for properties in validation:
+        species = properties['species'].to(device)
+        coordinates = properties['coordinates'].to(device).float()
+        true_energies = properties['energies'].to(device).float()
+        _, predicted_energies = model((species, coordinates))
         total_mse += mse_sum(predicted_energies, true_energies).item()
         count += predicted_energies.shape[0]
     return hartree2kcalmol(math.sqrt(total_mse / count))
@@ -275,49 +253,28 @@ for _ in range(AdamW_scheduler.last_epoch + 1, max_epochs):
-    # Besides being stored in x, species and coordinates are also stored in y.
-    # So here, for simplicity, we just ignore the x and use y for everything.
-    for i, (_, batch_y) in tqdm.tqdm(
+    for i, properties in tqdm.tqdm(
         enumerate(training),
         total=len(training),
         desc="epoch {}".format(AdamW_scheduler.last_epoch)
     ):
-        true_energies = batch_y['energies']
-        predicted_energies = []
-        num_atoms = []
-        force_loss = []
-        for chunk in batch_y['atomic']:
-            chunk_species = chunk['species']
-            chunk_coordinates = chunk['coordinates']
-            chunk_true_forces = chunk['forces']
-            chunk_num_atoms = (chunk_species >= 0).to(true_energies.dtype).sum(dim=1)
-            num_atoms.append(chunk_num_atoms)
-            # We must set `chunk_coordinates` to make it requires grad, so
-            # that we could compute force from it
-            chunk_coordinates.requires_grad_(True)
-            chunk_energies = model((chunk_species, chunk_coordinates)).energies
-            # We can use torch.autograd.grad to compute force. Remember to
-            # create graph so that the loss of the force can contribute to
-            # the gradient of parameters, and also to retain graph so that
-            # we can backward through it a second time when computing gradient
-            # w.r.t. parameters.
-            chunk_forces = -torch.autograd.grad(chunk_energies.sum(), chunk_coordinates, create_graph=True, retain_graph=True)[0]
-            # Now let's compute loss for force of this chunk
-            chunk_force_loss = mse(chunk_true_forces, chunk_forces).sum(dim=(1, 2)) / chunk_num_atoms
-            predicted_energies.append(chunk_energies)
-            force_loss.append(chunk_force_loss)
-        num_atoms = torch.cat(num_atoms)
-        predicted_energies = torch.cat(predicted_energies)
+        species = properties['species'].to(device)
+        coordinates = properties['coordinates'].to(device).float().requires_grad_(True)
+        true_energies = properties['energies'].to(device).float()
+        true_forces = properties['forces'].to(device).float()
+        num_atoms = (species >= 0).sum(dim=1, dtype=true_energies.dtype)
+        _, predicted_energies = model((species, coordinates))
+        # We can use torch.autograd.grad to compute force. Remember to
+        # create graph so that the loss of the force can contribute to
+        # the gradient of parameters, and also to retain graph so that
+        # we can backward through it a second time when computing gradient
+        # w.r.t. parameters.
+        forces = -torch.autograd.grad(predicted_energies.sum(), coordinates, create_graph=True, retain_graph=True)[0]
         # Now the total loss has two parts, energy loss and force loss
         energy_loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()
-        force_loss = torch.cat(force_loss).mean()
+        force_loss = (mse(true_forces, forces).sum(dim=(1, 2)) / num_atoms).mean()
         loss = energy_loss + force_coefficient * force_loss
         AdamW.zero_grad()
...
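The force computation now happens once per batch instead of once per chunk. A toy version with a differentiable stand-in for the model, so the autograd mechanics can be run in isolation:

```python
import torch

coordinates = torch.randn(2, 5, 3, requires_grad=True)
# toy stand-in for model energies: any differentiable function of coordinates
energies = (coordinates ** 2).sum(dim=(1, 2))
# create_graph=True keeps the graph so the force loss can later
# backpropagate into the network parameters
forces = -torch.autograd.grad(energies.sum(), coordinates, create_graph=True)[0]
assert forces.shape == coordinates.shape
```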
@@ -118,7 +118,8 @@ class TestAEV(_TestAEVBase):
         species = self.transform(species)
         radial = self.transform(radial)
         angular = self.transform(angular)
-        species_coordinates.append({'species': species, 'coordinates': coordinates})
+        species_coordinates.append(torchani.utils.broadcast_first_dim(
+            {'species': species, 'coordinates': coordinates}))
         radial_angular.append((radial, angular))
     species_coordinates = torchani.utils.pad_atomic_properties(
         species_coordinates)
...
 import os
-import torch
 import torchani
 import unittest
 path = os.path.dirname(os.path.realpath(__file__))
-dataset_path = os.path.join(path, '../dataset/ani1-up_to_gdb4')
-dataset_path2 = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')
+dataset_path = os.path.join(path, 'dataset/ani-1x/sample.h5')
 batch_size = 256
 ani1x = torchani.models.ANI1x()
 consts = ani1x.consts
+sae_dict = ani1x.sae_dict
 aev_computer = ani1x.aev_computer
 class TestData(unittest.TestCase):
-    def setUp(self):
-        self.ds = torchani.data.load_ani_dataset(dataset_path,
-                                                 consts.species_to_tensor,
-                                                 batch_size)
-    def _assertTensorEqual(self, t1, t2):
-        self.assertLess((t1 - t2).abs().max().item(), 1e-6)
-    def testSplitBatch(self):
-        species1 = torch.randint(4, (5, 4), dtype=torch.long)
-        coordinates1 = torch.randn(5, 4, 3)
-        species2 = torch.randint(4, (2, 8), dtype=torch.long)
-        coordinates2 = torch.randn(2, 8, 3)
-        species3 = torch.randint(4, (10, 20), dtype=torch.long)
-        coordinates3 = torch.randn(10, 20, 3)
-        species_coordinates = torchani.utils.pad_atomic_properties([
-            {'species': species1, 'coordinates': coordinates1},
-            {'species': species2, 'coordinates': coordinates2},
-            {'species': species3, 'coordinates': coordinates3},
-        ])
-        species = species_coordinates['species']
-        coordinates = species_coordinates['coordinates']
-        natoms = (species >= 0).to(torch.long).sum(1)
-        chunks = torchani.data.split_batch(natoms, species_coordinates)
-        start = 0
-        last = None
-        for chunk in chunks:
-            s = chunk['species']
-            c = chunk['coordinates']
-            n = (s >= 0).to(torch.long).sum(1)
-            if last is not None:
-                self.assertNotEqual(last[-1], n[0])
-            conformations = s.shape[0]
-            self.assertGreater(conformations, 0)
-            s_ = species[start:(start + conformations), ...]
-            c_ = coordinates[start:(start + conformations), ...]
-            sc = torchani.utils.strip_redundant_padding({'species': s_, 'coordinates': c_})
-            s_ = sc['species']
-            c_ = sc['coordinates']
-            self._assertTensorEqual(s, s_)
-            self._assertTensorEqual(c, c_)
-            start += conformations
-        sc = torchani.utils.pad_atomic_properties(chunks)
-        s = sc['species']
-        c = sc['coordinates']
-        self._assertTensorEqual(s, species)
-        self._assertTensorEqual(c, coordinates)
     def testTensorShape(self):
-        for i in self.ds:
-            input_, output = i
-            input_ = [{'species': x[0], 'coordinates': x[1]} for x in input_]
-            species_coordinates = torchani.utils.pad_atomic_properties(input_)
-            species = species_coordinates['species']
-            coordinates = species_coordinates['coordinates']
-            energies = output['energies']
+        ds = torchani.data.load(dataset_path).subtract_self_energies(sae_dict).species_to_indices().shuffle().collate(batch_size).cache()
+        for d in ds:
+            species = d['species']
+            coordinates = d['coordinates']
+            energies = d['energies']
             self.assertEqual(len(species.shape), 2)
             self.assertLessEqual(species.shape[0], batch_size)
             self.assertEqual(len(coordinates.shape), 3)
@@ -80,11 +28,11 @@ class TestData(unittest.TestCase):
             self.assertEqual(coordinates.shape[0], energies.shape[0])
     def testNoUnnecessaryPadding(self):
-        for i in self.ds:
-            for input_ in i[0]:
-                species, _ = input_
-                non_padding = (species >= 0)[:, -1].nonzero()
-                self.assertGreater(non_padding.numel(), 0)
+        ds = torchani.data.load(dataset_path).subtract_self_energies(sae_dict).species_to_indices().shuffle().collate(batch_size).cache()
+        for d in ds:
+            species = d['species']
+            non_padding = (species >= 0)[:, -1].nonzero()
+            self.assertGreater(non_padding.numel(), 0)
 if __name__ == '__main__':
...
-import torchani
-import unittest
-import pkbar
-import torch
-import os
-path = os.path.dirname(os.path.realpath(__file__))
-dspath = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s03.h5')
-batch_size = 2560
-chunk_threshold = 5
-other_properties = {'properties': ['dipoles', 'forces', 'energies'],
-                    'padding_values': [None, 0, None],
-                    'padded_shapes': [(batch_size, 3), (batch_size, -1, 3), (batch_size, )],
-                    'dtypes': [torch.float32, torch.float32, torch.float64],
-                    }
-other_properties = {'properties': ['energies'],
-                    'padding_values': [None],
-                    'padded_shapes': [(batch_size, )],
-                    'dtypes': [torch.float64],
-                    }
-class TestFindThreshold(unittest.TestCase):
-    def setUp(self):
-        print('.. check find threshold to split chunks')
-    def testFindThreshould(self):
-        torchani.data.find_threshold(dspath, batch_size=batch_size, threshold_max=10)
-class TestShuffledData(unittest.TestCase):
-    def setUp(self):
-        print('.. setup shuffle dataset')
-        self.ds = torchani.data.ShuffledDataset(dspath, batch_size=batch_size,
-                                                chunk_threshold=chunk_threshold,
-                                                num_workers=2,
-                                                other_properties=other_properties,
-                                                subtract_self_energies=True)
-        self.chunks, self.properties = iter(self.ds).next()
-    def testTensorShape(self):
-        print('=> checking tensor shape')
-        print('the first batch is ([chunk1, chunk2, ...], {"energies", "force", ...}) in which chunk1=(species, coordinates)')
-        batch_len = 0
-        print('1. chunks')
-        for i, chunk in enumerate(self.chunks):
-            print('chunk{}'.format(i + 1), 'species:', list(chunk[0].size()), chunk[0].dtype,
-                  'coordinates:', list(chunk[1].size()), chunk[1].dtype)
-            # check dtype
-            self.assertEqual(chunk[0].dtype, torch.int64)
-            self.assertEqual(chunk[1].dtype, torch.float32)
-            # check shape
-            self.assertEqual(chunk[1].shape[2], 3)
-            self.assertEqual(chunk[1].shape[:2], chunk[0].shape[:2])
-            batch_len += chunk[0].shape[0]
-        print('2. properties')
-        for i, key in enumerate(other_properties['properties']):
-            print(key, list(self.properties[key].size()), self.properties[key].dtype)
-            # check dtype
-            self.assertEqual(self.properties[key].dtype, other_properties['dtypes'][i])
-            # shape[0] == batch_size
-            self.assertEqual(self.properties[key].shape[0], other_properties['padded_shapes'][i][0])
-            # check len(shape)
-            self.assertEqual(len(self.properties[key].shape), len(other_properties['padded_shapes'][i]))
-    def testLoadDataset(self):
-        print('=> test loading all dataset')
-        pbar = pkbar.Pbar('loading and processing dataset into cpu memory, total '
-                          + 'batches: {}, batch_size: {}'.format(len(self.ds), batch_size),
-                          len(self.ds))
-        for i, _ in enumerate(self.ds):
-            pbar.update(i)
-    def testSplitDataset(self):
-        print('=> test splitting dataset')
-        train_ds, val_ds = torchani.data.ShuffledDataset(dspath, batch_size=batch_size, chunk_threshold=chunk_threshold, num_workers=2, validation_split=0.1)
-        frac = len(val_ds) / (len(val_ds) + len(train_ds))
-        self.assertLess(abs(frac - 0.1), 0.05)
-    def testNoUnnecessaryPadding(self):
-        print('=> checking No Unnecessary Padding')
-        for i, chunk in enumerate(self.chunks):
-            species, _ = chunk
-            non_padding = (species >= 0)[:, -1].nonzero()
-            self.assertGreater(non_padding.numel(), 0)
-class TestCachedData(unittest.TestCase):
-    def setUp(self):
-        print('.. setup cached dataset')
-        self.ds = torchani.data.CachedDataset(dspath, batch_size=batch_size, device='cpu',
-                                              chunk_threshold=chunk_threshold,
-                                              other_properties=other_properties,
-                                              subtract_self_energies=True)
-        self.chunks, self.properties = self.ds[0]
-    def testTensorShape(self):
-        print('=> checking tensor shape')
-        print('the first batch is ([chunk1, chunk2, ...], {"energies", "force", ...}) in which chunk1=(species, coordinates)')
-        batch_len = 0
-        print('1. chunks')
-        for i, chunk in enumerate(self.chunks):
-            print('chunk{}'.format(i + 1), 'species:', list(chunk[0].size()), chunk[0].dtype,
-                  'coordinates:', list(chunk[1].size()), chunk[1].dtype)
-            # check dtype
-            self.assertEqual(chunk[0].dtype, torch.int64)
-            self.assertEqual(chunk[1].dtype, torch.float32)
-            # check shape
-            self.assertEqual(chunk[1].shape[2], 3)
-            self.assertEqual(chunk[1].shape[:2], chunk[0].shape[:2])
-            batch_len += chunk[0].shape[0]
-        print('2. properties')
-        for i, key in enumerate(other_properties['properties']):
-            print(key, list(self.properties[key].size()), self.properties[key].dtype)
-            # check dtype
-            self.assertEqual(self.properties[key].dtype, other_properties['dtypes'][i])
-            # shape[0] == batch_size
-            self.assertEqual(self.properties[key].shape[0], other_properties['padded_shapes'][i][0])
-            # check len(shape)
-            self.assertEqual(len(self.properties[key].shape), len(other_properties['padded_shapes'][i]))
-    def testLoadDataset(self):
-        print('=> test loading all dataset')
-        self.ds.load()
-    def testSplitDataset(self):
-        print('=> test splitting dataset')
-        train_dataset, val_dataset = self.ds.split(0.1)
-        frac = len(val_dataset) / len(self.ds)
-        self.assertLess(abs(frac - 0.1), 0.05)
-    def testNoUnnecessaryPadding(self):
-        print('=> checking No Unnecessary Padding')
-        for i, chunk in enumerate(self.chunks):
-            species, _ = chunk
-            non_padding = (species >= 0)[:, -1].nonzero()
-            self.assertGreater(non_padding.numel(), 0)
-if __name__ == "__main__":
-    unittest.main()
@@ -54,7 +54,8 @@ class TestEnergies(unittest.TestCase):
         coordinates = self.transform(coordinates)
         species = self.transform(species)
         e = self.transform(e)
-        species_coordinates.append({'species': species, 'coordinates': coordinates})
+        species_coordinates.append(
+            torchani.utils.broadcast_first_dim({'species': species, 'coordinates': coordinates}))
         energies.append(e)
     species_coordinates = torchani.utils.pad_atomic_properties(
         species_coordinates)
...
@@ -55,7 +55,8 @@ class TestForce(unittest.TestCase):
         species = self.transform(species)
         forces = self.transform(forces)
         coordinates.requires_grad_(True)
-        species_coordinates.append({'species': species, 'coordinates': coordinates})
+        species_coordinates.append(torchani.utils.broadcast_first_dim(
+            {'species': species, 'coordinates': coordinates}))
     species_coordinates = torchani.utils.pad_atomic_properties(
         species_coordinates)
     _, energies = self.model((species_coordinates['species'], species_coordinates['coordinates']))
...
@@ -18,16 +18,14 @@ other_properties = {'properties': ['energies'],
 class TestBuiltinModelsJIT(unittest.TestCase):
     def setUp(self):
-        self.ds = torchani.data.CachedDataset(dspath, batch_size=batch_size, device='cpu',
-                                              chunk_threshold=chunk_threshold,
-                                              other_properties=other_properties,
-                                              subtract_self_energies=True)
         self.ani1ccx = torchani.models.ANI1ccx()
+        self.ds = torchani.data.load(dspath).subtract_self_energies(self.ani1ccx.sae_dict).species_to_indices().shuffle().collate(256).cache()
     def _test_model(self, model):
-        chunk = self.ds[0][0][0]
-        _, e = model(chunk)
-        _, e2 = torch.jit.script(model)(chunk)
+        properties = next(iter(self.ds))
+        input_ = (properties['species'], properties['coordinates'].float())
+        _, e = model(input_)
+        _, e2 = torch.jit.script(model)(input_)
         self.assertTrue(torch.allclose(e, e2))
     def _test_ensemble(self, ensemble):
...
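The scripted-vs-eager check in _test_model, reduced to a standalone snippet; the species indices and coordinates here are made up, and the model weights download on first use:

```python
import torch
import torchani

model = torchani.models.ANI1ccx()
species = torch.tensor([[1, 0, 0]])  # one molecule, three atoms
coordinates = torch.randn(1, 3, 3)
_, e = model((species, coordinates))
_, e2 = torch.jit.script(model)((species, coordinates))
assert torch.allclose(e, e2)
```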
@@ -3,6 +3,9 @@ import torch
 import torchani
+b = torchani.utils.broadcast_first_dim
 class TestPaddings(unittest.TestCase):
     def testVectorSpecies(self):
@@ -11,8 +14,8 @@ class TestPaddings(unittest.TestCase):
         species2 = torch.tensor([[3, 2, 0, 1, 0]])
         coordinates2 = torch.zeros(2, 5, 3)
         atomic_properties = torchani.utils.pad_atomic_properties([
-            {'species': species1, 'coordinates': coordinates1},
-            {'species': species2, 'coordinates': coordinates2},
+            b({'species': species1, 'coordinates': coordinates1}),
+            b({'species': species2, 'coordinates': coordinates2}),
         ])
         self.assertEqual(atomic_properties['species'].shape[0], 7)
         self.assertEqual(atomic_properties['species'].shape[1], 5)
@@ -34,8 +37,8 @@ class TestPaddings(unittest.TestCase):
         species2 = torch.tensor([[3, 2, 0, 1, 0]])
         coordinates2 = torch.zeros(2, 5, 3)
         atomic_properties = torchani.utils.pad_atomic_properties([
-            {'species': species1, 'coordinates': coordinates1},
-            {'species': species2, 'coordinates': coordinates2},
+            b({'species': species1, 'coordinates': coordinates1}),
+            b({'species': species2, 'coordinates': coordinates2}),
         ])
         self.assertEqual(atomic_properties['species'].shape[0], 7)
         self.assertEqual(atomic_properties['species'].shape[1], 5)
@@ -63,8 +66,8 @@ class TestPaddings(unittest.TestCase):
         species2 = torch.tensor([[3, 2, 0, 1, 0]])
         coordinates2 = torch.zeros(2, 5, 3)
         atomic_properties = torchani.utils.pad_atomic_properties([
-            {'species': species1, 'coordinates': coordinates1},
-            {'species': species2, 'coordinates': coordinates2},
+            b({'species': species1, 'coordinates': coordinates1}),
+            b({'species': species2, 'coordinates': coordinates2}),
         ])
         self.assertEqual(atomic_properties['species'].shape[0], 7)
         self.assertEqual(atomic_properties['species'].shape[1], 5)
@@ -98,28 +101,28 @@ class TestStripRedundantPadding(unittest.TestCase):
         species2 = torch.randint(4, (2, 5), dtype=torch.long)
         coordinates2 = torch.randn(2, 5, 3)
         atomic_properties12 = torchani.utils.pad_atomic_properties([
-            {'species': species1, 'coordinates': coordinates1},
-            {'species': species2, 'coordinates': coordinates2},
+            b({'species': species1, 'coordinates': coordinates1}),
+            b({'species': species2, 'coordinates': coordinates2}),
         ])
         species12 = atomic_properties12['species']
         coordinates12 = atomic_properties12['coordinates']
         species3 = torch.randint(4, (2, 10), dtype=torch.long)
         coordinates3 = torch.randn(2, 10, 3)
         atomic_properties123 = torchani.utils.pad_atomic_properties([
-            {'species': species1, 'coordinates': coordinates1},
-            {'species': species2, 'coordinates': coordinates2},
-            {'species': species3, 'coordinates': coordinates3},
+            b({'species': species1, 'coordinates': coordinates1}),
+            b({'species': species2, 'coordinates': coordinates2}),
+            b({'species': species3, 'coordinates': coordinates3}),
         ])
         species123 = atomic_properties123['species']
         coordinates123 = atomic_properties123['coordinates']
         species_coordinates1_ = torchani.utils.strip_redundant_padding(
-            {'species': species123[:5, ...], 'coordinates': coordinates123[:5, ...]})
+            b({'species': species123[:5, ...], 'coordinates': coordinates123[:5, ...]}))
         species1_ = species_coordinates1_['species']
         coordinates1_ = species_coordinates1_['coordinates']
         self._assertTensorEqual(species1_, species1)
         self._assertTensorEqual(coordinates1_, coordinates1)
         species_coordinates12_ = torchani.utils.strip_redundant_padding(
-            {'species': species123[:7, ...], 'coordinates': coordinates123[:7, ...]})
+            b({'species': species123[:7, ...], 'coordinates': coordinates123[:7, ...]}))
         species12_ = species_coordinates12_['species']
         coordinates12_ = species_coordinates12_['coordinates']
         self._assertTensorEqual(species12_, species12)
...
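What broadcast_first_dim (aliased as b above) appears to provide for these tests: expanding the entries of a property dict so they agree on the leading conformation dimension before padding, e.g. a single species row paired with two conformations of coordinates. A sketch under that assumption:

```python
import torch
import torchani

species2 = torch.tensor([[3, 2, 0, 1, 0]])  # shape (1, 5): one species row
coordinates2 = torch.zeros(2, 5, 3)         # shape (2, 5, 3): two conformations
props = torchani.utils.broadcast_first_dim(
    {'species': species2, 'coordinates': coordinates2})
print(props['species'].shape)  # expected: torch.Size([2, 5])
```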
@@ -55,24 +55,11 @@ if __name__ == "__main__":
     parser.add_argument('-b', '--batch_size',
                         help='Number of conformations of each batch',
                         default=2560, type=int)
-    parser.add_argument('-o', '--original_dataset_api',
-                        help='use original dataset api',
-                        dest='dataset',
-                        action='store_const',
-                        const='original')
-    parser.add_argument('-s', '--shuffle_dataset_api',
-                        help='use shuffle dataset api',
-                        dest='dataset',
-                        action='store_const',
-                        const='shuffle')
-    parser.add_argument('-c', '--cache_dataset_api',
-                        help='use cache dataset api',
-                        dest='dataset',
-                        action='store_const',
-                        const='cache')
-    parser.set_defaults(dataset='shuffle')
+    parser.add_argument('-d', '--dry-run',
+                        help='just run it in a CI without GPU',
+                        action='store_true')
     parser = parser.parse_args()
-    parser.device = torch.device('cuda')
+    parser.device = torch.device('cpu' if parser.dry_run else 'cuda')
     Rcr = 5.2000e+00
     Rca = 3.5000e+00
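The new flag lets the CI job added in the workflow above exercise this script on a machine without a GPU. A minimal reproduction of the wiring:

```python
import argparse
import torch

parser = argparse.ArgumentParser()
parser.add_argument('-d', '--dry-run',
                    help='just run it in a CI without GPU',
                    action='store_true')
args = parser.parse_args(['--dry-run'])
device = torch.device('cpu' if args.dry_run else 'cuda')
print(device)  # cpu
```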
@@ -90,45 +77,9 @@ if __name__ == "__main__":
     optimizer = torch.optim.Adam(model.parameters(), lr=0.000001)
     mse = torch.nn.MSELoss(reduction='none')
-    if parser.dataset == 'shuffle':
-        print('using shuffle dataset API')
-        print('=> loading dataset...')
-        dataset = torchani.data.ShuffledDataset(file_path=parser.dataset_path,
-                                                species_order=['H', 'C', 'N', 'O'],
-                                                subtract_self_energies=True,
-                                                batch_size=parser.batch_size,
-                                                num_workers=2)
-        print('=> the first batch is ([chunk1, chunk2, ...], {"energies", "force", ...}) in which chunk1=(species, coordinates)')
-        chunks, properties = iter(dataset).next()
-    elif parser.dataset == 'original':
-        print('using original dataset API')
-        print('=> loading dataset...')
-        energy_shifter = torchani.utils.EnergyShifter(None)
-        species_to_tensor = torchani.utils.ChemicalSymbolsToInts(['H', 'C', 'N', 'O'])
-        dataset = torchani.data.load_ani_dataset(parser.dataset_path, species_to_tensor,
-                                                 parser.batch_size, device=parser.device,
-                                                 transform=[energy_shifter.subtract_from_dataset])
-        print('=> the first batch is ([chunk1, chunk2, ...], {"energies", "force", ...}) in which chunk1=(species, coordinates)')
-        chunks, properties = dataset[0]
-    elif parser.dataset == 'cache':
-        print('using cache dataset API')
-        print('=> loading dataset...')
-        dataset = torchani.data.CachedDataset(file_path=parser.dataset_path,
-                                              species_order=['H', 'C', 'N', 'O'],
-                                              subtract_self_energies=True,
-                                              batch_size=parser.batch_size)
-        print('=> caching all dataset into cpu')
-        pbar = pkbar.Pbar('loading and processing dataset into cpu memory, total '
-                          + 'batches: {}, batch_size: {}'.format(len(dataset), parser.batch_size),
-                          len(dataset))
-        for i, t in enumerate(dataset):
-            pbar.update(i)
-        print('=> the first batch is ([chunk1, chunk2, ...], {"energies", "force", ...}) in which chunk1=(species, coordinates)')
-        chunks, properties = dataset[0]
-    for i, chunk in enumerate(chunks):
-        print('chunk{}'.format(i + 1), list(chunk[0].size()), list(chunk[1].size()))
-    print('energies', list(properties['energies'].size()))
+    print('=> loading dataset...')
+    shifter = torchani.EnergyShifter(None)
+    dataset = list(torchani.data.load(parser.dataset_path).subtract_self_energies(shifter).species_to_indices().shuffle().collate(parser.batch_size))
     print('=> start warming up')
     total_batch_counter = 0
@@ -137,36 +88,24 @@ if __name__ == "__main__":
         print('Epoch: %d/inf' % (epoch + 1,))
         progbar = pkbar.Kbar(target=len(dataset) - 1, width=8)
-        for i, (batch_x, batch_y) in enumerate(dataset):
-            if total_batch_counter == WARM_UP_BATCHES:
+        for i, properties in enumerate(dataset):
+            if not parser.dry_run and total_batch_counter == WARM_UP_BATCHES:
                 print('=> warm up finished, start profiling')
                 enable_timers(model)
                 torch.cuda.cudart().cudaProfilerStart()
-            PROFILING_STARTED = (total_batch_counter >= WARM_UP_BATCHES)
+            PROFILING_STARTED = not parser.dry_run and (total_batch_counter >= WARM_UP_BATCHES)
             if PROFILING_STARTED:
                 torch.cuda.nvtx.range_push("batch{}".format(total_batch_counter))
-            true_energies = batch_y['energies'].to(parser.device)
-            predicted_energies = []
-            num_atoms = []
-            for j, (chunk_species, chunk_coordinates) in enumerate(batch_x):
-                if PROFILING_STARTED:
-                    torch.cuda.nvtx.range_push("chunk{}".format(j))
-                chunk_species = chunk_species.to(parser.device)
-                chunk_coordinates = chunk_coordinates.to(parser.device)
-                num_atoms.append((chunk_species >= 0).to(true_energies.dtype).sum(dim=1))
-                with torch.autograd.profiler.emit_nvtx(enabled=PROFILING_STARTED, record_shapes=True):
-                    _, chunk_energies = model((chunk_species, chunk_coordinates))
-                predicted_energies.append(chunk_energies)
-                if PROFILING_STARTED:
-                    torch.cuda.nvtx.range_pop()
-            num_atoms = torch.cat(num_atoms)
-            predicted_energies = torch.cat(predicted_energies).to(true_energies.dtype)
+            species = properties['species'].to(parser.device)
+            coordinates = properties['coordinates'].to(parser.device).float()
+            true_energies = properties['energies'].to(parser.device).float()
+            num_atoms = (species >= 0).sum(dim=1, dtype=true_energies.dtype)
+            with torch.autograd.profiler.emit_nvtx(enabled=PROFILING_STARTED, record_shapes=True):
+                _, predicted_energies = model((species, coordinates))
             loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()
             rmse = hartree2kcalmol((mse(predicted_energies, true_energies)).mean()).detach().cpu().numpy()
...
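Both profiling switches are now gated on the same condition, so a --dry-run pass never touches CUDA. A condensed sketch of that gating with the forward pass elided:

```python
import torch

dry_run, WARM_UP_BATCHES, total_batch_counter = True, 50, 60
PROFILING_STARTED = (not dry_run) and (total_batch_counter >= WARM_UP_BATCHES)
if PROFILING_STARTED:
    torch.cuda.nvtx.range_push("batch{}".format(total_batch_counter))
with torch.autograd.profiler.emit_nvtx(enabled=PROFILING_STARTED, record_shapes=True):
    pass  # the model forward would go here
if PROFILING_STARTED:
    torch.cuda.nvtx.range_pop()
```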
@@ -49,25 +49,9 @@ if __name__ == "__main__":
     parser.add_argument('-b', '--batch_size',
                         help='Number of conformations of each batch',
                         default=2560, type=int)
-    parser.add_argument('-o', '--original_dataset_api',
-                        help='use original dataset api',
-                        dest='dataset',
-                        action='store_const',
-                        const='original')
-    parser.add_argument('-s', '--shuffle_dataset_api',
-                        help='use shuffle dataset api',
-                        dest='dataset',
-                        action='store_const',
-                        const='shuffle')
-    parser.add_argument('-c', '--cache_dataset_api',
-                        help='use cache dataset api',
-                        dest='dataset',
-                        action='store_const',
-                        const='cache')
     parser.add_argument('-y', '--synchronize',
                         action='store_true',
                         help='whether to insert torch.cuda.synchronize() at the end of each function')
-    parser.set_defaults(dataset='shuffle')
     parser.add_argument('-n', '--num_epochs',
                         help='epochs',
                         default=1, type=int)
@@ -107,48 +91,9 @@ if __name__ == "__main__":
     model[0].forward = time_func('total', model[0].forward)
     model[1].forward = time_func('forward', model[1].forward)
-    if parser.dataset == 'shuffle':
-        torchani.data.ShuffledDataset = time_func('data_loading', torchani.data.ShuffledDataset)
-        print('using shuffle dataset API')
-        print('=> loading dataset...')
-        dataset = torchani.data.ShuffledDataset(file_path=parser.dataset_path,
-                                                species_order=['H', 'C', 'N', 'O'],
-                                                subtract_self_energies=True,
-                                                batch_size=parser.batch_size,
-                                                num_workers=2)
-        print('=> the first batch is ([chunk1, chunk2, ...], {"energies", "force", ...}) in which chunk1=(species, coordinates)')
-        chunks, properties = iter(dataset).next()
-    elif parser.dataset == 'original':
-        torchani.data.load_ani_dataset = time_func('data_loading', torchani.data.load_ani_dataset)
-        print('using original dataset API')
-        print('=> loading dataset...')
-        energy_shifter = torchani.utils.EnergyShifter(None)
-        species_to_tensor = torchani.utils.ChemicalSymbolsToInts(['H', 'C', 'N', 'O'])
-        dataset = torchani.data.load_ani_dataset(parser.dataset_path, species_to_tensor,
-                                                 parser.batch_size, device=parser.device,
-                                                 transform=[energy_shifter.subtract_from_dataset])
-        print('=> the first batch is ([chunk1, chunk2, ...], {"energies", "force", ...}) in which chunk1=(species, coordinates)')
-        chunks, properties = dataset[0]
-    elif parser.dataset == 'cache':
-        torchani.data.CachedDataset = time_func('data_loading', torchani.data.CachedDataset)
-        print('using cache dataset API')
-        print('=> loading dataset...')
-        dataset = torchani.data.CachedDataset(file_path=parser.dataset_path,
-                                              species_order=['H', 'C', 'N', 'O'],
-                                              subtract_self_energies=True,
-                                              batch_size=parser.batch_size)
-        print('=> caching all dataset into cpu')
-        pbar = pkbar.Pbar('loading and processing dataset into cpu memory, total '
-                          + 'batches: {}, batch_size: {}'.format(len(dataset), parser.batch_size),
-                          len(dataset))
-        for i, t in enumerate(dataset):
-            pbar.update(i)
-        print('=> the first batch is ([chunk1, chunk2, ...], {"energies", "force", ...}) in which chunk1=(species, coordinates)')
-        chunks, properties = dataset[0]
-    for i, chunk in enumerate(chunks):
-        print('chunk{}'.format(i + 1), list(chunk[0].size()), list(chunk[1].size()))
-    print('energies', list(properties['energies'].size()))
+    print('=> loading dataset...')
+    shifter = torchani.EnergyShifter(None)
+    dataset = list(torchani.data.load(parser.dataset_path).subtract_self_energies(shifter).species_to_indices().shuffle().collate(parser.batch_size))
     print('=> start training')
     start = time.time()
@@ -158,24 +103,12 @@ if __name__ == "__main__":
         print('Epoch: %d/%d' % (epoch + 1, parser.num_epochs))
         progbar = pkbar.Kbar(target=len(dataset) - 1, width=8)
-        for i, (batch_x, batch_y) in enumerate(dataset):
-            true_energies = batch_y['energies'].to(parser.device)
-            predicted_energies = []
-            num_atoms = []
-            atomic_properties = []
-            for chunk_species, chunk_coordinates in batch_x:
-                chunk_species = chunk_species.to(parser.device)
-                chunk_coordiantes = chunk_coordinates.to(parser.device)
-                atomic_chunk = {'species': chunk_species, 'coordinates': chunk_coordinates}
-                atomic_properties.append(atomic_chunk)
-                num_atoms.append((chunk_species >= 0).to(true_energies.dtype).sum(dim=1))
-            atomic_properties = torchani.utils.pad_atomic_properties(atomic_properties)
-            predicted_energies = model((atomic_properties['species'], atomic_properties['coordinates'])).energies.to(true_energies.dtype)
-            num_atoms = torch.cat(num_atoms)
+        for i, properties in enumerate(dataset):
+            species = properties['species'].to(parser.device)
+            coordinates = properties['coordinates'].to(parser.device).float()
+            true_energies = properties['energies'].to(parser.device).float()
+            num_atoms = (species >= 0).sum(dim=1, dtype=true_energies.dtype)
+            _, predicted_energies = model((species, coordinates))
             loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()
             rmse = hartree2kcalmol((mse(predicted_energies, true_energies)).mean()).detach().cpu().numpy()
             loss.backward()
@@ -191,6 +124,5 @@ if __name__ == "__main__":
         if k.startswith('torchani.'):
             print('{} - {:.1f}s'.format(k, timers[k]))
     print('Total AEV - {:.1f}s'.format(timers['total']))
-    print('Data Loading - {:.1f}s'.format(timers['data_loading']))
     print('NN - {:.1f}s'.format(timers['forward']))
     print('Epoch time - {:.1f}s'.format(stop - start))
@@ -408,8 +408,8 @@ class AEVComputer(torch.nn.Module):
         If you don't care about periodic boundary conditions at all,
         then input can be a tuple of two tensors: species, coordinates.
-        species must have shape ``(C, A)``, coordinates must have shape
-        ``(C, A, 3)`` where ``C`` is the number of molecules in a chunk,
+        species must have shape ``(N, A)``, coordinates must have shape
+        ``(N, A, 3)`` where ``N`` is the number of molecules in a batch,
         and ``A`` is the number of atoms.
         .. warning::
@@ -437,7 +437,7 @@ class AEVComputer(torch.nn.Module):
         Returns:
             NamedTuple: Species and AEVs. species are the species from the input
-            unchanged, and AEVs is a tensor of shape ``(C, A, self.aev_length())``
+            unchanged, and AEVs is a tensor of shape ``(N, A, self.aev_length())``
         """
         species, coordinates = input_
...
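The docstring rename from C (chunk) to N (batch) reflects that callers now pass whole padded batches. A quick shape check against a built-in model's AEV computer (a sketch; weights download on first use):

```python
import torch
import torchani

ani1x = torchani.models.ANI1x()
species = torch.tensor([[1, 0, 0]])  # (N, A): one molecule, three atoms
coordinates = torch.randn(1, 3, 3)   # (N, A, 3)
_, aevs = ani1x.aev_computer((species, coordinates))
print(aevs.shape)                    # (N, A, aev_length)
```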
This diff is collapsed.
This diff is collapsed.
@@ -91,7 +91,7 @@ class BuiltinNet(torch.nn.Module):
         self.species = self.consts.species
         self.species_converter = SpeciesConverter(self.species)
         self.aev_computer = AEVComputer(**self.consts)
-        self.energy_shifter = neurochem.load_sae(self.sae_file)
+        self.energy_shifter, self.sae_dict = neurochem.load_sae(self.sae_file, return_dict=True)
         self.neural_networks = neurochem.load_model_ensemble(
             self.species, self.ensemble_prefix, self.ensemble_size)
...
@@ -70,17 +70,22 @@ class Constants(collections.abc.Mapping):
         return getattr(self, item)
-def load_sae(filename):
+def load_sae(filename, return_dict=False):
     """Returns an object of :class:`EnergyShifter` with self energies from
     NeuroChem sae file"""
     self_energies = []
+    d = {}
     with open(filename) as f:
         for i in f:
             line = [x.strip() for x in i.split('=')]
+            species = line[0].split(',')[0].strip()
             index = int(line[0].split(',')[1].strip())
             value = float(line[1])
+            d[species] = value
             self_energies.append((index, value))
     self_energies = [i for _, i in sorted(self_energies)]
+    if return_dict:
+        return EnergyShifter(self_energies), d
     return EnergyShifter(self_energies)
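The dict branch reuses the line format the parser already handled. A sample line in the 'species,index=value' shape the loop expects (the value here is illustrative):

```python
line = [x.strip() for x in 'H,0=-0.600952980000'.split('=')]
species = line[0].split(',')[0].strip()     # 'H'
index = int(line[0].split(',')[1].strip())  # 0
value = float(line[1])                      # -0.60095298
d = {species: value}                        # what return_dict=True collects
```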
@@ -285,13 +290,13 @@ if sys.version_info[0] > 2:
         def __init__(self, filename, device=torch.device('cuda'), tqdm=False,
                      tensorboard=None, checkpoint_name='model.pt'):
-            from ..data import load_ani_dataset  # noqa: E402
+            from ..data import load  # noqa: E402
             class dummy:
                 pass
             self.imports = dummy()
-            self.imports.load_ani_dataset = load_ani_dataset
+            self.imports.load = load
             self.filename = filename
             self.device = device
@@ -467,7 +472,7 @@ if sys.version_info[0] > 2:
             self.aev_computer = AEVComputer(**self.consts)
             del params['sflparamsfile']
             self.sae_file = os.path.join(dir_, params['atomEnergyFile'])
-            self.shift_energy = load_sae(self.sae_file)
+            self.shift_energy, self.sae = load_sae(self.sae_file, return_dict=True)
             del params['atomEnergyFile']
             network_dir = os.path.join(dir_, params['ntwkStoreDir'])
             if not os.path.exists(network_dir):
@@ -552,26 +557,18 @@ if sys.version_info[0] > 2:
         def load_data(self, training_path, validation_path):
             """Load training and validation dataset from file."""
-            self.training_set = self.imports.load_ani_dataset(
-                training_path, self.consts.species_to_tensor,
-                self.training_batch_size, rm_outlier=True, device=self.device,
-                transform=[self.shift_energy.subtract_from_dataset])
-            self.validation_set = self.imports.load_ani_dataset(
-                validation_path, self.consts.species_to_tensor,
-                self.validation_batch_size, rm_outlier=True, device=self.device,
-                transform=[self.shift_energy.subtract_from_dataset])
+            self.training_set = self.imports.load(training_path).subtract_self_energies(self.sae).species_to_indices().shuffle().collate(self.training_batch_size).cache()
+            self.validation_set = self.imports.load(validation_path).subtract_self_energies(self.sae).species_to_indices().shuffle().collate(self.validation_batch_size).cache()
         def evaluate(self, dataset):
             """Run the evaluation"""
             total_mse = 0.0
             count = 0
-            for batch_x, batch_y in dataset:
-                true_energies = batch_y['energies']
-                predicted_energies = []
-                for chunk_species, chunk_coordinates in batch_x:
-                    _, chunk_energies = self.model((chunk_species, chunk_coordinates))
-                    predicted_energies.append(chunk_energies)
-                predicted_energies = torch.cat(predicted_energies)
+            for properties in dataset:
+                species = properties['species'].to(self.device)
+                coordinates = properties['coordinates'].to(self.device).float()
+                true_energies = properties['energies'].to(self.device).float()
+                _, predicted_energies = self.model((species, coordinates))
                 total_mse += self.mse_sum(predicted_energies, true_energies).item()
                 count += predicted_energies.shape[0]
             return hartree2kcalmol(math.sqrt(total_mse / count))
@@ -620,21 +617,16 @@ if sys.version_info[0] > 2:
                 self.tensorboard.add_scalar('learning_rate', learning_rate, AdamW_scheduler.last_epoch)
                 self.tensorboard.add_scalar('no_improve_count_vs_epoch', no_improve_count, AdamW_scheduler.last_epoch)
-            for i, (batch_x, batch_y) in self.tqdm(
+            for i, properties in self.tqdm(
                 enumerate(self.training_set),
                 total=len(self.training_set),
                 desc='epoch {}'.format(AdamW_scheduler.last_epoch)
             ):
-                true_energies = batch_y['energies']
-                predicted_energies = []
-                num_atoms = []
-                for chunk_species, chunk_coordinates in batch_x:
-                    num_atoms.append((chunk_species >= 0).sum(dim=1))
-                    _, chunk_energies = self.model((chunk_species, chunk_coordinates))
-                    predicted_energies.append(chunk_energies)
-                num_atoms = torch.cat(num_atoms).to(true_energies.dtype)
-                predicted_energies = torch.cat(predicted_energies)
+                species = properties['species'].to(self.device)
+                coordinates = properties['coordinates'].to(self.device).float()
+                true_energies = properties['energies'].to(self.device).float()
+                num_atoms = (species >= 0).sum(dim=1, dtype=true_energies.dtype)
+                _, predicted_energies = self.model((species, coordinates))
                 loss = (self.mse_se(predicted_energies, true_energies) / num_atoms.sqrt()).mean()
                 AdamW_optim.zero_grad()
                 SGD_optim.zero_grad()
...