# training-benchmark.py — time one training epoch and its AEV sub-steps
# (author: Xiang Gao)
import argparse
import functools
import timeit

import ignite
import torch
import torchani
import tqdm

import model

# parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument('dataset_path',
                    help='Path of the dataset, can be an hdf5 file '
                         'or a directory containing hdf5 files')
parser.add_argument('-d', '--device',
                    help='Device of modules and tensors',
                    default=('cuda' if torch.cuda.is_available() else 'cpu'))
parser.add_argument('--batch_size',
                    help='Number of conformations of each batch',
                    default=1024, type=int)
# NOTE: the parsed namespace is rebound to ``parser``; the rest of the script
# reads ``parser.dataset_path``, ``parser.device`` and ``parser.batch_size``.
parser = parser.parse_args()

# set up benchmark
device = torch.device(parser.device)
# Load (or create and cache) the model; it returns the network pipeline and
# an energy-shift helper. NOTE(review): the cache path is hard-coded.
nnp, shift_energy = model.get_or_create_model('/tmp/model.pt', device=device)

# nnp[0] is presumably the AEV computer (it owns ``species``); every batch is
# passed through ``shift_energy.subtract_from_dataset`` before training.
dataset = torchani.training.BatchedANIDataset(
    parser.dataset_path, nnp[0].species, parser.batch_size, device=device,
    transform=[shift_energy.subtract_from_dataset])
# Expose the model output under the 'energies' key, which the loss below reads.
container = torchani.training.Container({'energies': nnp})
optimizer = torch.optim.Adam(nnp.parameters())

trainer = ignite.engine.create_supervised_trainer(
    container, optimizer, torchani.training.MSELoss('energies'))


@trainer.on(ignite.engine.Events.EPOCH_STARTED)
def init_tqdm(trainer):
    """Create a fresh progress bar at the start of each epoch.

    The bar is stored on ``trainer.state`` so the iteration and epoch-end
    handlers can advance and close it. Its total is ``len(dataset)``,
    matching one ``update(1)`` per completed iteration.
    """
    trainer.state.tqdm = tqdm.tqdm(total=len(dataset), desc='epoch')


@trainer.on(ignite.engine.Events.ITERATION_COMPLETED)
def update_tqdm(trainer):
    """Advance the epoch progress bar by one step after every iteration."""
    progress_bar = trainer.state.tqdm
    progress_bar.update(1)


@trainer.on(ignite.engine.Events.EPOCH_COMPLETED)
def finalize_tqdm(trainer):
    """Close the epoch progress bar once the epoch has finished."""
    progress_bar = trainer.state.tqdm
    progress_bar.close()


# Accumulated wall-clock seconds per instrumented function, keyed by label.
timers = {}


def time_func(key, func):
    """Wrap ``func`` so every call adds its wall-clock time to ``timers[key]``.

    ``timers[key]`` is reset to 0 when the wrapper is created. The wrapper
    forwards all arguments and returns ``func``'s result unchanged; the
    elapsed time is recorded even if ``func`` raises.

    NOTE(review): on CUDA devices kernels are launched asynchronously, so
    without a device synchronization (e.g. ``torch.cuda.synchronize()``)
    these timers likely measure dispatch time rather than kernel time —
    confirm before comparing GPU numbers.
    """
    timers[key] = 0

    # functools.wraps keeps the wrapped method's name/doc for debugging.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        start = timeit.default_timer()
        try:
            return func(*args, **kwargs)
        finally:
            timers[key] += timeit.default_timer() - start

    return wrapper


# enable timers: wrap the AEV computer's (nnp[0]) internal steps and its
# forward ("total"), then the second stage's forward ("forward"), so each
# call accumulates into ``timers`` under the given label.
aev_computer = nnp[0]
for attr, label in (('radial_subaev_terms', 'radial terms'),
                    ('angular_subaev_terms', 'angular terms'),
                    ('terms_and_indices', 'terms and indices'),
                    ('combinations', 'combinations'),
                    ('compute_mask_r', 'mask_r'),
                    ('compute_mask_a', 'mask_a'),
                    ('assemble', 'assemble'),
                    ('forward', 'total')):
    setattr(aev_computer, attr, time_func(label, getattr(aev_computer, attr)))
nnp[1].forward = time_func('forward', nnp[1].forward)

# run it!
# Time one full training epoch end to end.
start = timeit.default_timer()
trainer.run(dataset, max_epochs=1)
elapsed = round(timeit.default_timer() - start, 2)

# Report accumulated per-step timers (seconds, unrounded) and the epoch time
# (seconds, rounded to 2 decimals).
print('Radial terms:', timers['radial terms'])
print('Angular terms:', timers['angular terms'])
print('Terms and indices:', timers['terms and indices'])
print('Combinations:', timers['combinations'])
print('Mask R:', timers['mask_r'])
print('Mask A:', timers['mask_a'])
print('Assemble:', timers['assemble'])
print('Total AEV:', timers['total'])
print('NN:', timers['forward'])
print('Epoch time:', elapsed)