Unverified Commit 861580e3 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

add argparse to training-benchmark.py (#45)

parent d3ae0788
import sys
import torch
import ignite
import torchani
import timeit
import model
import tqdm
import argparse
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
# parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument('dataset_path',
help='Path of the dataset, can a hdf5 file \
or a directory containing hdf5 files')
parser.add_argument('-d', '--device',
help='Device of modules and tensors',
default=('cuda' if torch.cuda.is_available() else 'cpu'))
parser.add_argument('--chunk_size',
help='Number of conformations of each chunk',
default=256, type=int)
parser.add_argument('--batch_chunks',
help='Number of chunks in each minibatch',
default=4, type=int)
parser = parser.parse_args()
chunk_size = 256
batch_chunks = 4
dataset_path = sys.argv[1]
# set up benchmark
device = torch.device(parser.device)
shift_energy = torchani.EnergyShifter()
dataset = torchani.data.ANIDataset(
dataset_path, chunk_size, device=device,
parser.dataset_path, parser.chunk_size, device=device,
transform=[shift_energy.dataset_subtract_sae])
dataloader = torchani.data.dataloader(dataset, batch_chunks)
dataloader = torchani.data.dataloader(dataset, parser.batch_chunks)
nnp = model.get_or_create_model('/tmp/model.pt', True, device=device)
batch_nnp = torchani.models.BatchModel(nnp)
container = torchani.ignite.Container({'energies': batch_nnp})
......@@ -40,6 +53,7 @@ def finalize_tqdm(trainer):
trainer.state.tqdm.close()
# run it!
start = timeit.default_timer()
trainer.run(dataloader, max_epochs=1)
elapsed = round(timeit.default_timer() - start, 2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment