import argparse
import timeit

import ignite
import torch
import torchani
import tqdm

import model
# parse command line arguments
parser = argparse.ArgumentParser()
# Adjacent string literals are concatenated at compile time; a backslash
# continuation *inside* the literal would embed the next line's indentation
# into the help text.
parser.add_argument('dataset_path',
                    help='Path of the dataset, can be an hdf5 file '
                         'or a directory containing hdf5 files')
parser.add_argument('-d', '--device',
                    help='Device of modules and tensors',
                    default=('cuda' if torch.cuda.is_available() else 'cpu'))
parser.add_argument('--batch_size',
                    help='Number of conformations of each batch',
                    default=1024, type=int)
# NOTE: ``parser`` is deliberately rebound to the parsed Namespace; the rest
# of the script reads options as ``parser.dataset_path``, ``parser.device``
# and ``parser.batch_size``.
parser = parser.parse_args()
# set up benchmark
device = torch.device(parser.device)

# ``nnp`` is used as an indexable pipeline in this script: ``nnp[0]``
# supplies the species list for the dataset (and the AEV timers printed at
# the end), ``nnp[1]`` the network's 'forward' timer.
# NOTE(review): the meaning of the positional ``True`` argument is defined
# in the local ``model`` module -- confirm before changing it.
nnp, shift_energy = model.get_or_create_model('/tmp/model.pt',
                                              True, device=device)

# Every batch passes through ``shift_energy.subtract_from_dataset`` when it
# is loaded (presumably subtracting reference energies -- TODO confirm).
dataset = torchani.training.BatchedANIDataset(
    parser.dataset_path, nnp[0].species, parser.batch_size, device=device,
    transform=[shift_energy.subtract_from_dataset])

# Expose the model output under the 'energies' key, which is the key the
# MSE loss below reads.
container = torchani.training.Container({'energies': nnp})
optimizer = torch.optim.Adam(nnp.parameters())

trainer = ignite.engine.create_supervised_trainer(
    container, optimizer, torchani.training.MSELoss('energies'))

@trainer.on(ignite.engine.Events.EPOCH_STARTED)
def init_tqdm(trainer):
    """Create the progress bar for the epoch that is starting.

    The bar is stored on ``trainer.state`` so the iteration and
    epoch-completion handlers can update and close it. Its total is
    ``len(dataset)`` (presumably the number of batches -- confirm).
    """
    trainer.state.tqdm = tqdm.tqdm(total=len(dataset), desc='epoch')


@trainer.on(ignite.engine.Events.ITERATION_COMPLETED)
def update_tqdm(trainer):
    """Advance the epoch progress bar by one step per completed iteration."""
    progress_bar = trainer.state.tqdm
    progress_bar.update(1)


@trainer.on(ignite.engine.Events.EPOCH_COMPLETED)
def finalize_tqdm(trainer):
    """Close the epoch progress bar once the epoch has completed."""
    progress_bar = trainer.state.tqdm
    progress_bar.close()


# run it!
start = timeit.default_timer()
trainer.run(dataset, max_epochs=1)
# CUDA kernels are launched asynchronously; wait for any queued work before
# reading the clock so the measured time covers the whole epoch.
if device.type == 'cuda':
    torch.cuda.synchronize(device)
elapsed = round(timeit.default_timer() - start, 2)

# Report the timing counters exposed as ``timers`` dicts on the model
# stages: ``nnp[0]`` is the AEV stage, ``nnp[1]`` the network.
aev_timer_labels = (
    ('Radial terms:', 'radial terms'),
    ('Angular terms:', 'angular terms'),
    ('Terms and indices:', 'terms and indices'),
    ('Combinations:', 'combinations'),
    ('Mask R:', 'mask_r'),
    ('Mask A:', 'mask_a'),
    ('Assemble:', 'assemble'),
    ('Total AEV:', 'total'),
)
for label, key in aev_timer_labels:
    print(label, nnp[0].timers[key])
print('NN:', nnp[1].timers['forward'])
print('Epoch time:', elapsed)