Unverified Commit 1d8bba37 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

add argparser for nnp_training.py (#46)

parent 861580e3
...@@ -15,4 +15,5 @@ a.out ...@@ -15,4 +15,5 @@ a.out
benchmark_xyz benchmark_xyz
*.pyc *.pyc
*checkpoint* *checkpoint*
*.pt
/runs /runs
\ No newline at end of file
import torch
import ignite
import torchani
import timeit
import tensorboardX
import math
import argparse
import json

# ---------------------------------------------------------------------------
# Command-line interface.
# Every former hard-coded constant (device, chunk size, checkpoints, ...) is
# exposed as an option so runs can be configured without editing the script.
# ---------------------------------------------------------------------------
parser = argparse.ArgumentParser()
parser.add_argument('dataset_path',
                    # NOTE: implicit string concatenation (not a backslash
                    # continuation) so the help text has no embedded runs of
                    # indentation whitespace.
                    help='Path of the dataset; can be an hdf5 file '
                         'or a directory containing hdf5 files')
parser.add_argument('--dataset_checkpoint',
                    help='Checkpoint file for datasets',
                    default='dataset-checkpoint.dat')
parser.add_argument('--model_checkpoint',
                    help='Checkpoint file for model',
                    default='model.pt')
parser.add_argument('-m', '--max_epochs',
                    help='Maximum number of epochs',
                    default=10, type=int)
parser.add_argument('-d', '--device',
                    help='Device of modules and tensors',
                    default=('cuda' if torch.cuda.is_available() else 'cpu'))
parser.add_argument('--chunk_size',
                    help='Number of conformations of each chunk',
                    default=256, type=int)
parser.add_argument('--batch_chunks',
                    help='Number of chunks in each minibatch',
                    default=4, type=int)
parser.add_argument('--log',
                    # None lets tensorboardX pick its default ./runs/... dir.
                    help='Log directory for tensorboardX',
                    default=None)
parser.add_argument('--optimizer',
                    # Looked up by name on torch.optim below, so any
                    # optimizer class shipped with torch is accepted.
                    help='Optimizer used to train the model',
                    default='Adam')
parser.add_argument('--optim_args',
                    help='Arguments to optimizers, in the format of json',
                    default='{}')
# The script's original convention: rebind `parser` to the parsed Namespace
# and read options as `parser.<name>` from here on (the tail of the file,
# outside this span, also reads `parser.max_epochs`).
parser = parser.parse_args()

# ---------------------------------------------------------------------------
# Set up the training run from the parsed options.
# ---------------------------------------------------------------------------
device = torch.device(parser.device)
writer = tensorboardX.SummaryWriter(log_dir=parser.log)
start = timeit.default_timer()

# Subtract self-atomic energies so the network learns only the residual.
shift_energy = torchani.EnergyShifter()
training, validation, testing = torchani.data.load_or_create(
    parser.dataset_checkpoint, parser.dataset_path, parser.chunk_size,
    device=device, transform=[shift_energy.dataset_subtract_sae])
training = torchani.data.dataloader(training, parser.batch_chunks)
validation = torchani.data.dataloader(validation, parser.batch_chunks)

# `model` is a project-local module imported above the visible hunk of this
# file (TODO confirm) — it restores or builds the neural-network potential.
nnp = model.get_or_create_model(parser.model_checkpoint, device=device)
batch_nnp = torchani.models.BatchModel(nnp)
container = torchani.ignite.Container({'energies': batch_nnp})

# Resolve the optimizer class by name and construct it with the JSON-decoded
# keyword arguments (e.g. --optim_args '{"lr": 1e-3}').
optim_args = json.loads(parser.optim_args)
optimizer_cls = getattr(torch.optim, parser.optimizer)
optimizer = optimizer_cls(nnp.parameters(), **optim_args)

trainer = ignite.engine.create_supervised_trainer(
    container, optimizer, torchani.ignite.energy_mse_loss)
...@@ -78,4 +109,4 @@ def log_loss_and_time(trainer): ...@@ -78,4 +109,4 @@ def log_loss_and_time(trainer):
writer.add_scalar('training_rmse_vs_iteration', rmse, iteration) writer.add_scalar('training_rmse_vs_iteration', rmse, iteration)
# Kick off the ignite training loop for the configured number of epochs.
trainer.run(training, max_epochs=parser.max_epochs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment