Unverified Commit 1d8bba37 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

add argparser for nnp_training.py (#46)

parent 861580e3
......@@ -15,4 +15,5 @@ a.out
benchmark_xyz
*.pyc
*checkpoint*
*.pt
/runs
\ No newline at end of file
import sys
import torch
import ignite
import torchani
......@@ -7,29 +6,61 @@ import tqdm
import timeit
import tensorboardX
import math
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
chunk_size = 256
batch_chunks = 4
dataset_path = sys.argv[1]
dataset_checkpoint = 'dataset-checkpoint.dat'
model_checkpoint = 'checkpoint.pt'
max_epochs = 10
writer = tensorboardX.SummaryWriter()
import argparse
import json
# parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument('dataset_path',
help='Path of the dataset, can a hdf5 file \
or a directory containing hdf5 files')
parser.add_argument('--dataset_checkpoint',
help='Checkpoint file for datasets',
default='dataset-checkpoint.dat')
parser.add_argument('--model_checkpoint',
help='Checkpoint file for model',
default='model.pt')
parser.add_argument('-m', '--max_epochs',
help='Maximum number of epoches',
default=10, type=int)
parser.add_argument('-d', '--device',
help='Device of modules and tensors',
default=('cuda' if torch.cuda.is_available() else 'cpu'))
parser.add_argument('--chunk_size',
help='Number of conformations of each chunk',
default=256, type=int)
parser.add_argument('--batch_chunks',
help='Number of chunks in each minibatch',
default=4, type=int)
parser.add_argument('--log',
help='Log directory for tensorboardX',
default=None)
parser.add_argument('--optimizer',
help='Optimizer used to train the model',
default='Adam')
parser.add_argument('--optim_args',
help='Arguments to optimizers, in the format of json',
default='{}')
parser = parser.parse_args()
# set up the training
device = torch.device(parser.device)
writer = tensorboardX.SummaryWriter(log_dir=parser.log)
start = timeit.default_timer()
shift_energy = torchani.EnergyShifter()
training, validation, testing = torchani.data.load_or_create(
dataset_checkpoint, dataset_path, chunk_size, device=device,
transform=[shift_energy.dataset_subtract_sae])
training = torchani.data.dataloader(training, batch_chunks)
validation = torchani.data.dataloader(validation, batch_chunks)
nnp = model.get_or_create_model(model_checkpoint, device=device)
parser.dataset_checkpoint, parser.dataset_path, parser.chunk_size,
device=device, transform=[shift_energy.dataset_subtract_sae])
training = torchani.data.dataloader(training, parser.batch_chunks)
validation = torchani.data.dataloader(validation, parser.batch_chunks)
nnp = model.get_or_create_model(parser.model_checkpoint, device=device)
batch_nnp = torchani.models.BatchModel(nnp)
container = torchani.ignite.Container({'energies': batch_nnp})
optimizer = torch.optim.Adam(nnp.parameters())
parser.optim_args = json.loads(parser.optim_args)
optimizer = getattr(torch.optim, parser.optimizer)
optimizer = optimizer(nnp.parameters(), **parser.optim_args)
trainer = ignite.engine.create_supervised_trainer(
container, optimizer, torchani.ignite.energy_mse_loss)
......@@ -78,4 +109,4 @@ def log_loss_and_time(trainer):
writer.add_scalar('training_rmse_vs_iteration', rmse, iteration)
trainer.run(training, max_epochs=max_epochs)
trainer.run(training, max_epochs=parser.max_epochs)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment