# training-aev-benchmark.py
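# Benchmarks ANI-style training with TorchANI: it times the AEV computation, the
# per-element network forward pass, the backward pass, force evaluation and the
# optimizer step, with and without the CUDA AEV extension.
#
# Example invocation (hypothetical path; any TorchANI-loadable HDF5 dataset that
# provides forces should work):
#     python training-aev-benchmark.py /path/to/dataset.h5 -n 3
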
import torch
import torchani
import time
import timeit
import argparse
import pkbar
import gc
import pynvml
import os
import pickle
from torchani.units import hartree2kcalmol


def build_network():
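    """Build one small MLP per element (H, C, N, O).

    Each network maps a 384-dimensional AEV to a scalar atomic energy contribution;
    the layer widths mirror the ANI-1x style models.
    """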
    H_network = torch.nn.Sequential(
        torch.nn.Linear(384, 160),
        torch.nn.CELU(0.1),
        torch.nn.Linear(160, 128),
        torch.nn.CELU(0.1),
        torch.nn.Linear(128, 96),
        torch.nn.CELU(0.1),
        torch.nn.Linear(96, 1)
    )

    C_network = torch.nn.Sequential(
        torch.nn.Linear(384, 144),
        torch.nn.CELU(0.1),
        torch.nn.Linear(144, 112),
        torch.nn.CELU(0.1),
        torch.nn.Linear(112, 96),
        torch.nn.CELU(0.1),
        torch.nn.Linear(96, 1)
    )

    N_network = torch.nn.Sequential(
        torch.nn.Linear(384, 128),
        torch.nn.CELU(0.1),
        torch.nn.Linear(128, 112),
        torch.nn.CELU(0.1),
        torch.nn.Linear(112, 96),
        torch.nn.CELU(0.1),
        torch.nn.Linear(96, 1)
    )

    O_network = torch.nn.Sequential(
        torch.nn.Linear(384, 128),
        torch.nn.CELU(0.1),
        torch.nn.Linear(128, 112),
        torch.nn.CELU(0.1),
        torch.nn.Linear(112, 96),
        torch.nn.CELU(0.1),
        torch.nn.Linear(96, 1)
    )
    return [H_network, C_network, N_network, O_network]


def checkgpu(device=None):
    i = device if device is not None else torch.cuda.current_device()
    t = torch.cuda.get_device_properties(i).total_memory
    c = torch.cuda.memory_reserved(i)
    name = torch.cuda.get_device_properties(i).name
    print('   GPU Memory Cached (pytorch) : {:7.1f}MB / {:.1f}MB ({})'.format(c / 1024 / 1024, t / 1024 / 1024, name))
    # Map the logical torch device index to the physical NVML index when
    # CUDA_VISIBLE_DEVICES restricts or reorders the visible GPUs.
    real_i = int(os.environ['CUDA_VISIBLE_DEVICES'].split(',')[i]) if 'CUDA_VISIBLE_DEVICES' in os.environ else i
    pynvml.nvmlInit()
    h = pynvml.nvmlDeviceGetHandleByIndex(real_i)
    info = pynvml.nvmlDeviceGetMemoryInfo(h)
    name = pynvml.nvmlDeviceGetName(h)
    # pynvml returns bytes in older versions and str in newer ones.
    name = name.decode() if isinstance(name, bytes) else name
    print('   GPU Memory Used (nvidia-smi): {:7.1f}MB / {:.1f}MB ({})'.format(info.used / 1024 / 1024, info.total / 1024 / 1024, name))


def alert(text):
    print('\033[91m{}\33[0m'.format(text))  # red


def sync_cuda(sync):
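    """Synchronize the device when requested.

    CUDA kernel launches are asynchronous, so host-side timers only reflect the real
    GPU work if the device is synchronized before reading the clock.
    """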
    if sync:
        torch.cuda.synchronize()


def print_timer(label, t):
    if t < 1:
        t = f'{t * 1000:.1f} ms'
    else:
        t = f'{t:.3f} sec'
    print(f'{label} - {t}')


def benchmark(parser, dataset, use_cuda_extension, force_inference=False):
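    """Run the training/inference loop and print a per-epoch timing breakdown.

    Iterates over `dataset` for parser.num_epochs epochs. With force_inference=True the
    loop evaluates forces via autograd instead of running backward() and optimizer.step().
    """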
    synchronize = True
    timers = {}

    def time_func(key, func):
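        """Wrap `func` so each call adds its (optionally CUDA-synchronized) wall-clock
        duration to timers[key]; used below to instrument torchani internals."""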
        timers[key] = 0

        def wrapper(*args, **kwargs):
            start = timeit.default_timer()
            ret = func(*args, **kwargs)
            sync_cuda(synchronize)
            end = timeit.default_timer()
            timers[key] += end - start
            return ret

        return wrapper

    Rcr = 5.2000e+00
    Rca = 3.5000e+00
    EtaR = torch.tensor([1.6000000e+01], device=parser.device)
    ShfR = torch.tensor([9.0000000e-01, 1.1687500e+00, 1.4375000e+00, 1.7062500e+00, 1.9750000e+00, 2.2437500e+00, 2.5125000e+00, 2.7812500e+00, 3.0500000e+00, 3.3187500e+00, 3.5875000e+00, 3.8562500e+00, 4.1250000e+00, 4.3937500e+00, 4.6625000e+00, 4.9312500e+00], device=parser.device)
    Zeta = torch.tensor([3.2000000e+01], device=parser.device)
    ShfZ = torch.tensor([1.9634954e-01, 5.8904862e-01, 9.8174770e-01, 1.3744468e+00, 1.7671459e+00, 2.1598449e+00, 2.5525440e+00, 2.9452431e+00], device=parser.device)
    EtaA = torch.tensor([8.0000000e+00], device=parser.device)
    ShfA = torch.tensor([9.0000000e-01, 1.5500000e+00, 2.2000000e+00, 2.8500000e+00], device=parser.device)
    num_species = 4
    aev_computer = torchani.AEVComputer(Rcr, Rca, EtaR, ShfR, EtaA, Zeta, ShfA, ShfZ, num_species, use_cuda_extension)
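    # Worked check of the AEV size: with 4 species the radial part contributes
    # 4 * len(EtaR) * len(ShfR) = 4 * 1 * 16 = 64 terms and the angular part
    # (4 * 5 / 2) * len(EtaA) * len(Zeta) * len(ShfA) * len(ShfZ) = 10 * 1 * 1 * 4 * 8 = 320,
    # i.e. 384 in total, matching the first Linear layer of every per-element network.
    # (Assumes this torchani version exposes AEVComputer.aev_length.)
    assert aev_computer.aev_length == 384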

    nn = torchani.ANIModel(build_network())
    model = torch.nn.Sequential(aev_computer, nn).to(parser.device)
    optimizer = torch.optim.Adam(model.parameters(), lr=0.000001)
    mse = torch.nn.MSELoss(reduction='none')

    # Enable timers: monkey-patch the torchani.aev helper functions, the AEV and
    # network forward passes, and optimizer.step so that each call accumulates its
    # wall time in `timers`. The patched names follow torchani's internal AEV module
    # and may differ in other torchani versions.
    torchani.aev.cutoff_cosine = time_func('torchani.aev.cutoff_cosine', torchani.aev.cutoff_cosine)
    torchani.aev.radial_terms = time_func('torchani.aev.radial_terms', torchani.aev.radial_terms)
    torchani.aev.angular_terms = time_func('torchani.aev.angular_terms', torchani.aev.angular_terms)
    torchani.aev.compute_shifts = time_func('torchani.aev.compute_shifts', torchani.aev.compute_shifts)
    torchani.aev.neighbor_pairs = time_func('torchani.aev.neighbor_pairs', torchani.aev.neighbor_pairs)
    torchani.aev.neighbor_pairs_nopbc = time_func('torchani.aev.neighbor_pairs_nopbc', torchani.aev.neighbor_pairs_nopbc)
    torchani.aev.triu_index = time_func('torchani.aev.triu_index', torchani.aev.triu_index)
    torchani.aev.cumsum_from_zero = time_func('torchani.aev.cumsum_from_zero', torchani.aev.cumsum_from_zero)
    torchani.aev.triple_by_molecule = time_func('torchani.aev.triple_by_molecule', torchani.aev.triple_by_molecule)
    torchani.aev.compute_aev = time_func('torchani.aev.compute_aev', torchani.aev.compute_aev)
    model[0].forward = time_func('total', model[0].forward)
    model[1].forward = time_func('forward', model[1].forward)
    optimizer.step = time_func('optimizer.step', optimizer.step)

    print('=> start training')
    start = time.time()
    loss_time = 0
    force_time = 0

    for epoch in range(0, parser.num_epochs):

        print('Epoch: %d/%d' % (epoch + 1, parser.num_epochs))
        progbar = pkbar.Kbar(target=len(dataset) - 1, width=8)

        for i, properties in enumerate(dataset):
            species = properties['species'].to(parser.device)
            coordinates = properties['coordinates'].to(parser.device).float().requires_grad_(force_inference)
            true_energies = properties['energies'].to(parser.device).float()
            num_atoms = (species >= 0).sum(dim=1, dtype=true_energies.dtype)
            _, predicted_energies = model((species, coordinates))
            # TODO add sync after aev is done
            sync_cuda(synchronize)
            energy_loss = (mse(predicted_energies, true_energies) / num_atoms.sqrt()).mean()
            if force_inference:
                sync_cuda(synchronize)
                force_coefficient = 0.1
                true_forces = properties['forces'].to(parser.device).float()
                force_start = time.time()
                try:
                    sync_cuda(synchronize)
                    forces = -torch.autograd.grad(predicted_energies.sum(), coordinates, create_graph=True, retain_graph=True)[0]
                    sync_cuda(synchronize)
                except Exception as e:
                    alert('Error: {}'.format(e))
                    return
                force_time += time.time() - force_start
                force_loss = (mse(true_forces, forces).sum(dim=(1, 2)) / num_atoms).mean()
                loss = energy_loss + force_coefficient * force_loss
                sync_cuda(synchronize)
            else:
                loss = energy_loss
            # RMSE in kcal/mol: take the square root of the mean squared error (Hartree) before converting units.
            rmse = hartree2kcalmol(mse(predicted_energies, true_energies).mean().sqrt()).detach().cpu().numpy()
            progbar.update(i, values=[("rmse", rmse)])
            if not force_inference:
                sync_cuda(synchronize)
                # Clear gradients accumulated by the previous batch before backpropagating.
                optimizer.zero_grad()
                loss_start = time.time()
                loss.backward()
                sync_cuda(synchronize)
                loss_stop = time.time()
                loss_time += loss_stop - loss_start
                optimizer.step()
                sync_cuda(synchronize)

        checkgpu()
    sync_cuda(synchronize)
    stop = time.time()

    print('=> More details about the benchmark PER EPOCH')
    total_time = (stop - start) / parser.num_epochs
    loss_time = loss_time / parser.num_epochs
    force_time = force_time / parser.num_epochs
    opti_time = timers['optimizer.step'] / parser.num_epochs
    forward_time = timers['forward'] / parser.num_epochs
    aev_time = timers['total'] / parser.num_epochs
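    # 'Others' below is the per-epoch wall time not captured by any timer above,
    # e.g. host-to-device transfers, loss arithmetic and progress-bar updates.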
    print_timer('   Total AEV', aev_time)
    print_timer('   Forward', forward_time)
    print_timer('   Backward', loss_time)
    print_timer('   Force', force_time)
    print_timer('   Optimizer', opti_time)
    print_timer('   Others', total_time - loss_time - aev_time - forward_time - opti_time - force_time)
    print_timer('   Epoch time', total_time)


if __name__ == "__main__":
    # parse command line arguments
    parser = argparse.ArgumentParser()
    parser.add_argument('dataset_path',
                        help='Path of the dataset; can be an HDF5 file '
                             'or a directory containing HDF5 files')
    parser.add_argument('-d', '--device',
                        help='Device of modules and tensors',
                        default=('cuda' if torch.cuda.is_available() else 'cpu'))
    parser.add_argument('-b', '--batch_size',
                        help='Number of conformations of each batch',
                        default=2560, type=int)
    parser.add_argument('-p', '--pickle',
                        action='store_true',
                        help='Load the dataset from a pickled cache file instead of HDF5')
    parser.add_argument('--nsight',
                        action='store_true',
                        help='enable NVIDIA Nsight profiling')
    parser.add_argument('-n', '--num_epochs',
                        help='number of epochs',
                        default=1, type=int)
    parser = parser.parse_args()
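    # Note: `parser` is rebound here to the parsed argparse.Namespace and is used as
    # the options object (parser.device, parser.batch_size, ...) for the rest of the script.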

    print('=> loading dataset...')
    if parser.pickle:
        # Reload the pre-shuffled, pre-batched dataset from the pickle cache.
        with open(parser.dataset_path, 'rb') as f:
            dataset_shuffled = pickle.load(f)
    else:
        shifter = torchani.EnergyShifter(None)
        dataset = torchani.data.load(parser.dataset_path, additional_properties=('forces',)).subtract_self_energies(shifter).species_to_indices()
        print('=> Caching shuffled dataset...')
        dataset_shuffled = list(dataset.shuffle().collate(parser.batch_size))
        # Cache the shuffled batches so later runs can reload them with --pickle.
        with open(f'{parser.dataset_path}.pickle', 'wb') as f:
            pickle.dump(dataset_shuffled, f)

    print("=> CUDA info:")
    devices = torch.cuda.device_count()
    print('Total devices: {}'.format(devices))
    for i in range(devices):
        d = 'cuda:{}'.format(i)
        print('{}: {}'.format(i, torch.cuda.get_device_name(d)))
        print('   {}'.format(torch.cuda.get_device_properties(i)))
        checkgpu(i)

    print("\n\n=> Test 1: USE cuda extension, Energy training")
    torch.cuda.empty_cache()
    gc.collect()
    benchmark(parser, dataset_shuffled, use_cuda_extension=True, force_inference=False)
    print("\n\n=> Test 2: NO cuda extension, Energy training")
    torch.cuda.empty_cache()
    gc.collect()
    benchmark(parser, dataset_shuffled, use_cuda_extension=False, force_inference=False)

    print("\n\n=> Test 3: USE cuda extension, Force and Energy inference")
    torch.cuda.empty_cache()
    gc.collect()
    benchmark(parser, dataset_shuffled, use_cuda_extension=True, force_inference=True)
    print("\n\n=> Test 4: NO cuda extension, Force and Energy inference")
    torch.cuda.empty_cache()
    gc.collect()
    benchmark(parser, dataset_shuffled, use_cuda_extension=False, force_inference=True)