Unverified Commit 861580e3 authored by Gao, Xiang's avatar Gao, Xiang Committed by GitHub
Browse files

add argparse to training-benchmark.py (#45)

parent d3ae0788
import sys
import torch import torch
import ignite import ignite
import torchani import torchani
import timeit import timeit
import model import model
import tqdm import tqdm
import argparse
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') # parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument('dataset_path',
help='Path of the dataset, can a hdf5 file \
or a directory containing hdf5 files')
parser.add_argument('-d', '--device',
help='Device of modules and tensors',
default=('cuda' if torch.cuda.is_available() else 'cpu'))
parser.add_argument('--chunk_size',
help='Number of conformations of each chunk',
default=256, type=int)
parser.add_argument('--batch_chunks',
help='Number of chunks in each minibatch',
default=4, type=int)
parser = parser.parse_args()
chunk_size = 256 # set up benchmark
batch_chunks = 4 device = torch.device(parser.device)
dataset_path = sys.argv[1]
shift_energy = torchani.EnergyShifter() shift_energy = torchani.EnergyShifter()
dataset = torchani.data.ANIDataset( dataset = torchani.data.ANIDataset(
dataset_path, chunk_size, device=device, parser.dataset_path, parser.chunk_size, device=device,
transform=[shift_energy.dataset_subtract_sae]) transform=[shift_energy.dataset_subtract_sae])
dataloader = torchani.data.dataloader(dataset, batch_chunks) dataloader = torchani.data.dataloader(dataset, parser.batch_chunks)
nnp = model.get_or_create_model('/tmp/model.pt', True, device=device) nnp = model.get_or_create_model('/tmp/model.pt', True, device=device)
batch_nnp = torchani.models.BatchModel(nnp) batch_nnp = torchani.models.BatchModel(nnp)
container = torchani.ignite.Container({'energies': batch_nnp}) container = torchani.ignite.Container({'energies': batch_nnp})
...@@ -40,6 +53,7 @@ def finalize_tqdm(trainer): ...@@ -40,6 +53,7 @@ def finalize_tqdm(trainer):
trainer.state.tqdm.close() trainer.state.tqdm.close()
# run it!
start = timeit.default_timer() start = timeit.default_timer()
trainer.run(dataloader, max_epochs=1) trainer.run(dataloader, max_epochs=1)
elapsed = round(timeit.default_timer() - start, 2) elapsed = round(timeit.default_timer() - start, 2)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment