Unverified commit 57dd26bf authored by Ignacio Pickering, committed by GitHub

Add code snippet to save validation and training sets in nnp_training (#548)

parent 12422dd1
--- a/.gitignore
+++ b/.gitignore
@@ -39,3 +39,4 @@ Untitled.ipynb
 htmlcov/
 /include
 training_outputs/
+examples/dataset.pkl
--- a/examples/nnp_training.py
+++ b/examples/nnp_training.py
@@ -39,6 +39,7 @@ import os
 import math
 import torch.utils.tensorboard
 import tqdm
+import pickle
 
 # helper function to convert energy unit from Hartree to kcal/mol
 from torchani.units import hartree2kcalmol
@@ -95,9 +96,31 @@ except NameError:
 dspath = os.path.join(path, '../dataset/ani1-up_to_gdb4/ani_gdb_s01.h5')
 batch_size = 2560
-training, validation = torchani.data.load(dspath).subtract_self_energies(energy_shifter, species_order).species_to_indices(species_order).shuffle().split(0.8, None)
-training = training.collate(batch_size).cache()
-validation = validation.collate(batch_size).cache()
+pickled_dataset_path = 'dataset.pkl'
+
+# We pickle the dataset after loading to ensure we use the same validation set
+# each time we restart training, otherwise we risk mixing the validation and
+# training sets on each restart.
+if os.path.isfile(pickled_dataset_path):
+    print(f'Unpickling preprocessed dataset found in {pickled_dataset_path}')
+    with open(pickled_dataset_path, 'rb') as f:
+        dataset = pickle.load(f)
+    training = dataset['training'].collate(batch_size).cache()
+    validation = dataset['validation'].collate(batch_size).cache()
+    energy_shifter.self_energies = dataset['self_energies'].to(device)
+else:
+    print(f'Processing dataset in {dspath}')
+    training, validation = torchani.data.load(dspath)\
+                                        .subtract_self_energies(energy_shifter, species_order)\
+                                        .species_to_indices(species_order)\
+                                        .shuffle()\
+                                        .split(0.8, None)
+    with open(pickled_dataset_path, 'wb') as f:
+        pickle.dump({'training': training,
+                     'validation': validation,
+                     'self_energies': energy_shifter.self_energies.cpu()}, f)
+    training = training.collate(batch_size).cache()
+    validation = validation.collate(batch_size).cache()
 
 print('Self atomic energies: ', energy_shifter.self_energies)
 
 ###############################################################################
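The change is a standard load-or-process cache: do the expensive preprocessing once, pickle the result, and unpickle it on every later run so the training/validation split never changes between restarts. Below is a minimal, self-contained sketch of the same pattern; CACHE_PATH, preprocess, and the fake samples are hypothetical stand-ins for the torchani pipeline, not part of the commit.

import os
import pickle

CACHE_PATH = 'dataset.pkl'  # plays the role of pickled_dataset_path above


def preprocess():
    # Stand-in for the expensive load/shuffle/split that torchani.data
    # performs; here we just split 100 fake samples 80/20.
    samples = [f'sample_{i}' for i in range(100)]
    split = int(0.8 * len(samples))
    return samples[:split], samples[split:]


if os.path.isfile(CACHE_PATH):
    # Later runs reuse the exact split that the first run saved.
    with open(CACHE_PATH, 'rb') as f:
        cached = pickle.load(f)
    training, validation = cached['training'], cached['validation']
else:
    # First run: do the work once, then persist it.
    training, validation = preprocess()
    with open(CACHE_PATH, 'wb') as f:
        pickle.dump({'training': training, 'validation': validation}, f)

print(len(training), len(validation))  # 80 20 on every run

Note that the commit also pickles energy_shifter.self_energies (moved to CPU first) alongside the split and restores it with .to(device) on load: the cached dataset already has the fitted self energies subtracted, so the shifter must be restored to match rather than refitted on a reshuffled dataset.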