import torch
import ignite
import torchani
import timeit
import tqdm
import argparse
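
# example invocation (placeholder path; any ANI-format HDF5 file or a
# directory of such files works):
#   python training-benchmark.py /path/to/dataset.h5 --device cuda --batch_size 1024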

# parse command line arguments
parser = argparse.ArgumentParser()
parser.add_argument('dataset_path',
                    help='Path of the dataset, can be an HDF5 file \
                          or a directory containing HDF5 files')
parser.add_argument('-d', '--device',
                    help='Device of modules and tensors',
                    default=('cuda' if torch.cuda.is_available() else 'cpu'))
parser.add_argument('--batch_size',
                    help='Number of conformations in each batch',
                    default=1024, type=int)
args = parser.parse_args()

# set up benchmark
device = torch.device(args.device)
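
# load the builtin ANI model components: constants, AEV computer, and the
# energy shifter that subtracts atomic self energies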
builtins = torchani.neurochem.Builtins()
consts = builtins.consts
aev_computer = builtins.aev_computer
shift_energy = builtins.energy_shifter


def atomic():
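    # per-atom network: maps a 384-dimensional AEV to a single atomic energy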
    model = torch.nn.Sequential(
        torch.nn.Linear(384, 128),
        torch.nn.CELU(0.1),
        torch.nn.Linear(128, 128),
        torch.nn.CELU(0.1),
        torch.nn.Linear(128, 64),
        torch.nn.CELU(0.1),
        torch.nn.Linear(64, 1)
    )
    return model


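# one copy of the atomic network per species in the builtin model (H, C, N, O)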
model = torchani.ANIModel([atomic() for _ in range(4)])


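# the model returns (species, energies); flatten the energies into a 1-D
# tensor so the loss can compare them against the targets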
class Flatten(torch.nn.Module):
    def forward(self, x):
        return x[0], x[1].flatten()


nnp = torch.nn.Sequential(aev_computer, model, Flatten()).to(device)

dataset = torchani.data.BatchedANIDataset(
    args.dataset_path, consts.species_to_tensor,
    args.batch_size, device=device,
    transform=[shift_energy.subtract_from_dataset])
container = torchani.ignite.Container({'energies': nnp})
optimizer = torch.optim.Adam(nnp.parameters())

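# train with Adam against an MSE loss on the total energies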
trainer = ignite.engine.create_supervised_trainer(
    container, optimizer, torchani.ignite.MSELoss('energies'))


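# report progress with tqdm during the benchmark epoch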
@trainer.on(ignite.engine.Events.EPOCH_STARTED)
def init_tqdm(trainer):
    trainer.state.tqdm = tqdm.tqdm(total=len(dataset), desc='epoch')


@trainer.on(ignite.engine.Events.ITERATION_COMPLETED)
def update_tqdm(trainer):
    trainer.state.tqdm.update(1)


@trainer.on(ignite.engine.Events.EPOCH_COMPLETED)
def finalize_tqdm(trainer):
    trainer.state.tqdm.close()


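# crude per-component timing: wrap a module's forward method so that each
# call's wall-clock time is accumulated under a named key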
timers = {}


def time_func(key, func):
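    # timers[key] accumulates the elapsed time across all wrapped calls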
    timers[key] = 0

    def wrapper(*args, **kwargs):
        start = timeit.default_timer()
        ret = func(*args, **kwargs)
        end = timeit.default_timer()
        timers[key] += end - start
        return ret

    return wrapper


# enable timers: nnp[0] is the AEV computer, nnp[1] is the neural network
nnp[0].forward = time_func('aev', nnp[0].forward)
nnp[1].forward = time_func('nn', nnp[1].forward)

# run it!
start = timeit.default_timer()
trainer.run(dataset, max_epochs=1)
elapsed = round(timeit.default_timer() - start, 2)
print('AEV time:', timers['aev'])
print('NN time:', timers['nn'])
print('Epoch time:', elapsed)