training-benchmark.py 3.76 KB
Newer Older
Xiang Gao's avatar
Xiang Gao committed
1
import torch
2
import ignite
Xiang Gao's avatar
Xiang Gao committed
3
4
import torchani
import timeit
5
import tqdm
6
import argparse
Xiang Gao's avatar
Xiang Gao committed
7

8
9
10
11
12
13
14
15
# parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument('dataset_path',
                    help='Path of the dataset, can be an hdf5 file '
                         'or a directory containing hdf5 files')
parser.add_argument('-d', '--device',
                    help='Device of modules and tensors',
                    default=('cuda' if torch.cuda.is_available() else 'cpu'))
parser.add_argument('--batch_size',
                    help='Number of conformations of each batch',
                    default=1024, type=int)
# NOTE: ``parser`` is deliberately rebound to the parsed argument Namespace;
# the rest of the script reads ``parser.dataset_path``, ``parser.device`` and
# ``parser.batch_size`` from it.
parser = parser.parse_args()
20

21
22
# set up benchmark
device = torch.device(parser.device)

# Take the constants, AEV computer and energy shifter from the ANI-1x model
# object; they are shared by the pipeline assembled later in the script.
ani1x = torchani.models.ANI1x()
consts, aev_computer, shift_energy = (
    ani1x.consts,
    ani1x.aev_computer,
    ani1x.energy_shifter,
)
Gao, Xiang's avatar
Gao, Xiang committed
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51


def atomic(in_features=384, hidden_sizes=(128, 128, 64), celu_alpha=0.1):
    """Build one atomic (per-species) feed-forward network.

    With the default arguments this is exactly the original fixed network:
    Linear(384, 128) -> CELU(0.1) -> Linear(128, 128) -> CELU(0.1)
    -> Linear(128, 64) -> CELU(0.1) -> Linear(64, 1).

    Args:
        in_features: width of each input row. NOTE(review): 384 presumably
            equals the AEV length produced by the ANI-1x AEV computer —
            confirm if the AEV parameters change.
        hidden_sizes: output widths of the hidden ``Linear`` layers, in order.
            Every hidden layer is followed by a ``CELU`` activation.
        celu_alpha: ``alpha`` passed to every ``CELU`` activation.

    Returns:
        A ``torch.nn.Sequential`` mapping ``in_features`` inputs to a single
        output.
    """
    layers = []
    width = in_features
    for size in hidden_sizes:
        layers.append(torch.nn.Linear(width, size))
        layers.append(torch.nn.CELU(celu_alpha))
        width = size
    # Final scalar head, with no activation after it.
    layers.append(torch.nn.Linear(width, 1))
    return torch.nn.Sequential(*layers)


# One freshly initialised atomic network per species slot.
# NOTE(review): the count 4 presumably matches the number of species that
# ``consts.species_to_tensor`` encodes — confirm if the species set changes.
model = torchani.ANIModel([atomic() for _network_index in range(4)])


class Flatten(torch.nn.Module):
    """Return the first element of a tuple unchanged and the second flattened.

    Used as the last stage of the training pipeline so that the second
    output becomes one-dimensional. NOTE(review): assumes the incoming tuple
    is ``(species, energies)`` as produced by ``torchani.ANIModel``.
    """

    def forward(self, x):
        head = x[0]
        tail = torch.flatten(x[1])
        return head, tail


# Training pipeline: AEV computer -> per-species networks -> Flatten.
nnp = torch.nn.Sequential(aev_computer, model, Flatten())
nnp = nnp.to(device)  # Module.to moves parameters in place and returns self

# Batched dataset placed on ``device``. NOTE(review): the transform presumably
# subtracts the energy shifter's reference energies from the targets — confirm.
dataset = torchani.data.load_ani_dataset(
    parser.dataset_path,
    consts.species_to_tensor,
    parser.batch_size,
    device=device,
    transform=[shift_energy.subtract_from_dataset],
)

container = torchani.ignite.Container({'energies': nnp})
optimizer = torch.optim.Adam(nnp.parameters())

# Standard ignite supervised trainer minimising the MSE on 'energies'.
trainer = ignite.engine.create_supervised_trainer(
    container,
    optimizer,
    torchani.ignite.MSELoss('energies'),
)
Xiang Gao's avatar
Xiang Gao committed
61

62
63
64

@trainer.on(ignite.engine.Events.EPOCH_STARTED)
def init_tqdm(trainer):
    """Attach a new 'epoch' progress bar to the trainer state.

    The bar's total is ``len(dataset)``. NOTE(review): this is presumably the
    number of batches, matching one ``update(1)`` per completed iteration.
    """
    progress = tqdm.tqdm(total=len(dataset), desc='epoch')
    trainer.state.tqdm = progress
66
67
68
69
70
71
72
73
74
75
76
77


@trainer.on(ignite.engine.Events.ITERATION_COMPLETED)
def update_tqdm(trainer):
    """Advance the epoch progress bar by one step after each iteration."""
    bar = trainer.state.tqdm
    bar.update(1)


@trainer.on(ignite.engine.Events.EPOCH_COMPLETED)
def finalize_tqdm(trainer):
    """Close the epoch progress bar once the epoch has finished."""
    bar = trainer.state.tqdm
    bar.close()


78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
# Accumulated wall-clock seconds per instrumented callable, keyed by label.
timers = {}


def _sync_cuda():
    """Wait for all queued CUDA work to finish, but only if CUDA is in use.

    CUDA kernels are launched asynchronously, so reading the clock without
    waiting mostly measures launch overhead. The ``is_initialized`` check
    avoids creating a CUDA context on CPU-only runs.
    """
    if torch.cuda.is_available() and torch.cuda.is_initialized():
        torch.cuda.synchronize()


def time_func(key, func, synchronize=True):
    """Wrap ``func`` so the wall time of every call is added to ``timers[key]``.

    ``timers[key]`` is (re)set to 0 when the wrapper is created.

    Args:
        key: label under which the accumulated time is stored in ``timers``.
        func: callable to instrument. Its arguments and return value pass
            through unchanged.
        synchronize: when true (the default), wait for pending CUDA work
            before starting and before stopping the clock. Queued kernels are
            then counted in the call that launched them rather than in a later
            one. Has no effect when CUDA has not been initialised.

    Returns:
        The timing wrapper around ``func``.
    """
    timers[key] = 0

    def wrapper(*args, **kwargs):
        if synchronize:
            _sync_cuda()  # don't charge earlier queued work to this call
        start = timeit.default_timer()
        ret = func(*args, **kwargs)
        if synchronize:
            _sync_cuda()  # include this call's asynchronous kernels
        end = timeit.default_timer()
        timers[key] += end - start
        return ret

    return wrapper


# enable timers
# Replace selected torchani.aev helpers with timing wrappers that accumulate
# into timers['torchani.aev.<name>']. NOTE(review): rebinding the module
# attribute only affects callers that look the name up through the module at
# call time — confirm torchani.aev resolves these helpers dynamically.
for _aev_fn_name in ('cutoff_cosine', 'radial_terms', 'angular_terms',
                     'compute_shifts', 'neighbor_pairs', 'triu_index',
                     'convert_pair_index', 'cumsum_from_zero',
                     'triple_by_molecule', 'compute_aev'):
    _timer_key = 'torchani.aev.' + _aev_fn_name
    setattr(torchani.aev, _aev_fn_name,
            time_func(_timer_key, getattr(torchani.aev, _aev_fn_name)))

# 'total' accumulates time spent in nnp[0] (the AEV computer) and 'forward'
# the time spent in nnp[1] (the per-species network model).
nnp[0].forward = time_func('total', nnp[0].forward)
nnp[1].forward = time_func('forward', nnp[1].forward)

108
# run it!
# CUDA executes asynchronously. Wait for queued device work before starting
# and before stopping the clock, so the epoch time covers exactly the
# training run.
if device.type == 'cuda':
    torch.cuda.synchronize(device)
start = timeit.default_timer()
trainer.run(dataset, max_epochs=1)
if device.type == 'cuda':
    torch.cuda.synchronize(device)
elapsed = round(timeit.default_timer() - start, 2)

# Per-helper breakdown of the instrumented torchani.aev functions.
for k, v in timers.items():
    if k.startswith('torchani.'):
        print(k, v)
print('Total AEV:', timers['total'])
print('NN:', timers['forward'])
print('Epoch time:', elapsed)