training-benchmark.py 2.38 KB
Newer Older
Xiang Gao's avatar
Xiang Gao committed
1
import torch
2
import ignite
Xiang Gao's avatar
Xiang Gao committed
3
4
import torchani
import timeit
5
import model
6
import tqdm
7
import argparse
Xiang Gao's avatar
Xiang Gao committed
8

9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
# Parse command line arguments.
# The parsed Namespace is bound to the module-level name ``parser`` because the
# rest of the script reads ``parser.dataset_path``, ``parser.device``,
# ``parser.chunk_size`` and ``parser.batch_chunks``. The ArgumentParser itself
# lives under a separate name so ``parser`` is never rebound to a different type.
arg_parser = argparse.ArgumentParser()
arg_parser.add_argument('dataset_path',
                        help='Path of the dataset, can be an hdf5 file '
                             'or a directory containing hdf5 files')
# Default to CUDA only when torch reports it is available at startup.
arg_parser.add_argument('-d', '--device',
                        help='Device of modules and tensors',
                        default=('cuda' if torch.cuda.is_available()
                                 else 'cpu'))
arg_parser.add_argument('--chunk_size',
                        help='Number of conformations of each chunk',
                        default=256, type=int)
arg_parser.add_argument('--batch_chunks',
                        help='Number of chunks in each minibatch',
                        default=4, type=int)
parser = arg_parser.parse_args()
24

25
26
# set up benchmark
device = torch.device(parser.device)

# The shifter's ``dataset_subtract_sae`` is applied as a dataset transform.
# NOTE(review): presumably this subtracts per-atom self energies from the
# target energies; confirm against torchani.EnergyShifter.
shift_energy = torchani.EnergyShifter()
dataset = torchani.data.ANIDataset(
    parser.dataset_path, parser.chunk_size, device=device,
    transform=[shift_energy.dataset_subtract_sae])
# Each minibatch groups ``parser.batch_chunks`` chunks of the dataset.
dataloader = torchani.data.dataloader(dataset, parser.batch_chunks)
# NOTE(review): the meaning of the positional ``True`` is not visible here;
# it is forwarded to model.get_or_create_model unchanged.
nnp = model.get_or_create_model('/tmp/model.pt', True, device=device)
batch_nnp = torchani.models.BatchModel(nnp)

# The supervised trainer drives ``container``, which maps the output key
# 'energies' to ``batch_nnp``.
container = torchani.ignite.Container({'energies': batch_nnp})
# Adam is built over ``nnp``'s parameters. ``batch_nnp`` is constructed from
# ``nnp``, presumably sharing the same parameters; verify in torchani.
optimizer = torch.optim.Adam(nnp.parameters())

trainer = ignite.engine.create_supervised_trainer(
    container, optimizer, torchani.ignite.energy_mse_loss)
Xiang Gao's avatar
Xiang Gao committed
39

40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55

@trainer.on(ignite.engine.Events.EPOCH_STARTED)
def init_tqdm(trainer):
    """Store a new 'epoch' progress bar on ``trainer.state.tqdm``.

    The bar's ``total`` is ``len(dataloader)``, the number of batches the
    trainer iterates over.
    """
    steps_per_epoch = len(dataloader)
    trainer.state.tqdm = tqdm.tqdm(desc='epoch', total=steps_per_epoch)


@trainer.on(ignite.engine.Events.ITERATION_COMPLETED)
def update_tqdm(trainer):
    """Advance the progress bar on ``trainer.state.tqdm`` by one step."""
    progress_bar = trainer.state.tqdm
    progress_bar.update(1)


@trainer.on(ignite.engine.Events.EPOCH_COMPLETED)
def finalize_tqdm(trainer):
    """Close the progress bar stored on ``trainer.state.tqdm``."""
    progress_bar = trainer.state.tqdm
    progress_bar.close()


56
# run it!
# Time a single epoch of training with a wall-clock timer.
start = timeit.default_timer()
trainer.run(dataloader, max_epochs=1)
elapsed = round(timeit.default_timer() - start, 2)

# Report the timers collected on the model stages.
# Labels, keys and output order match the original one-print-per-line version.
# NOTE(review): assumes nnp[1] is the AEV stage and nnp[2] the network stage,
# each exposing a ``timers`` mapping; confirm in model.py.
aev_timers = nnp[1].timers
for label, key in (
        ('Radial terms:', 'radial terms'),
        ('Angular terms:', 'angular terms'),
        ('Terms and indices:', 'terms and indices'),
        ('Combinations:', 'combinations'),
        ('Mask R:', 'mask_r'),
        ('Mask A:', 'mask_a'),
        ('Assemble:', 'assemble'),
        ('Total AEV:', 'total')):
    print(label, aev_timers[key])
print('NN:', nnp[2].timers['forward'])
print('Epoch time:', elapsed)