# aev-benchmark-size.py
import time
import torch
import torchani
import pynvml
import gc
import os
from ase.io import read
import argparse

summary = '\n'        # accumulated report table, printed once at the end of the run
runcounter = 0        # index of the next benchmark run (used to number summary rows)
N = 200               # repetitions per benchmark; overridden by --N, forced to 3 under nsight
last_py_speed = None  # per-iteration time of the last pure-python run, used to compute speed-up

def checkgpu(device=None):
    """Print and return the current GPU memory usage.

    Reports both the memory PyTorch has reserved (cached) and the total used
    memory as seen by NVML (the same figure nvidia-smi shows).

    Args:
        device: CUDA device index to query; defaults to the current device.

    Returns:
        The nvidia-smi used-memory figure formatted as e.g. '1234.5MB'.
    """
    # `device` may legitimately be 0, so compare against None instead of truthiness.
    i = device if device is not None else torch.cuda.current_device()

    t = torch.cuda.get_device_properties(i).total_memory
    c = torch.cuda.memory_reserved(i)
    name = torch.cuda.get_device_properties(i).name
    print('   GPU Memory Cached (pytorch) : {:7.1f}MB / {:.1f}MB ({})'.format(c / 1024 / 1024, t / 1024 / 1024, name))

    # Map the logical CUDA device index back to the physical index NVML uses:
    # CUDA_VISIBLE_DEVICES reorders/hides devices, so parse the comma-separated
    # list and pick entry i.  (The original read only the first character, which
    # breaks for multi-digit physical indices or i > 0.)
    if 'CUDA_VISIBLE_DEVICES' in os.environ:
        real_i = int(os.environ['CUDA_VISIBLE_DEVICES'].split(',')[i])
    else:
        real_i = i
    pynvml.nvmlInit()
    h = pynvml.nvmlDeviceGetHandleByIndex(real_i)
    info = pynvml.nvmlDeviceGetMemoryInfo(h)
    name = pynvml.nvmlDeviceGetName(h)
    # Older pynvml returns bytes for the device name; newer versions return str.
    if isinstance(name, bytes):
        name = name.decode()
    print('   GPU Memory Used (nvidia-smi): {:7.1f}MB / {:.1f}MB ({})'.format(info.used / 1024 / 1024, info.total / 1024 / 1024, name))
    return f'{(info.used / 1024 / 1024):.1f}MB'


def alert(text):
    """Print *text* highlighted in red (ANSI escape codes)."""
    red, reset = '\033[91m', '\33[0m'
    print(f'{red}{text}{reset}')


def info(text):
    """Print *text* highlighted in green (ANSI escape codes)."""
    green, reset = '\033[32m', '\33[0m'
    print(f'{green}{text}{reset}')


def format_time(t):
    """Render a duration *t* (seconds) as 'X.X ms' below one second, else 'X.XXX sec'."""
    return f'{t * 1000:.1f} ms' if t < 1 else f'{t:.3f} sec'


def addSummaryLine(items=None, init=False):
    """Append one formatted row (10 columns) to the global summary table.

    When ``init`` is true, a separator row is emitted first and ``items`` is
    replaced by the table header.
    """
    global summary
    if init:
        addSummaryEmptyLine()
        items = ['RUN', 'PDB', 'Size', 'forward', 'backward', 'Others', 'Total', f'Total({N})', 'Speedup', 'GPU']
    # First column is 20 chars wide, the remaining nine are 13 chars each.
    widths = [20] + [13] * 9
    summary += ''.join(col.ljust(w) for col, w in zip(items, widths)) + '\n'


def addSummaryEmptyLine():
    """Append a dashed separator row to the global summary table."""
    global summary
    # One 20-char dash column followed by nine 13-char dash columns.
    summary += '-' * 20 + ('-' * 13) * 9 + '\n'


def benchmark(speciesPositions, aev_comp, runbackward=False, mol_info=None, verbose=True):
    """Time ``N`` forward (and optionally backward) passes of an AEV computer.

    Args:
        speciesPositions: (species, coordinates) tensor pair fed to ``aev_comp``.
        aev_comp: AEV computer module; its ``use_cuda_extension`` flag selects
            the 'cu'/'py' label and whether a speed-up ratio is computed.
        runbackward: when True, also time the force (gradient) computation.
        mol_info: dict with keys 'name' and 'atoms', used for the summary row.
        verbose: print per-run stats and append a row to the global summary.

    Returns:
        Tuple ``(aev, total_time, force)``; ``force`` is None unless
        ``runbackward``.  Returns ``(None, None, None)`` if any pass raises.
    """
    global runcounter
    global last_py_speed

    runname = f"{'cu' if aev_comp.use_cuda_extension else 'py'} aev fd{'+bd' if runbackward else ''}"
    # Placeholder row (all '-') recorded in the summary if the run fails.
    items = [f'{(runcounter+1):02} {runname}', f"{mol_info['name']}", f"{mol_info['atoms']}", '-', '-', '-', '-', '-', '-', '-']

    forward_time = 0
    force_time = 0
    # Start from a clean, synchronized GPU state so timings are comparable.
    torch.cuda.empty_cache()
    gc.collect()
    torch.cuda.synchronize()
    start = time.time()

    aev = None
    force = None
    gpumem = None
    for i in range(N):
        species, coordinates = speciesPositions
        coordinates = coordinates.requires_grad_(runbackward)

        torch.cuda.synchronize()
        forward_start = time.time()
        try:
            _, aev = aev_comp((species, coordinates))
        except Exception as e:
            # Record the failure (typo 'faild' fixed) and bail out.
            alert(f"  AEV failed: {str(e)[:50]}...")
            addSummaryLine(items)
            runcounter += 1
            return None, None, None
        torch.cuda.synchronize()
        forward_time += time.time() - forward_start

        if runbackward:  # backward
            force_start = time.time()
            try:
                # Force is the negative gradient of the summed AEV w.r.t. coordinates.
                force = -torch.autograd.grad(aev.sum(), coordinates, create_graph=True, retain_graph=True)[0]
            except Exception as e:
                alert(f" Force failed: {str(e)[:50]}...")
                addSummaryLine(items)
                runcounter += 1
                return None, None, None
            torch.cuda.synchronize()
            force_time += time.time() - force_start

        # Sample GPU memory once, after a couple of warm-up iterations.
        if i == 2 and verbose:
            gpumem = checkgpu()

    torch.cuda.synchronize()
    # Per-iteration averages; "others" is whatever the loop spent outside the
    # explicitly timed forward/backward sections (data prep, sync overhead).
    total_time = (time.time() - start) / N
    force_time = force_time / N
    forward_time = forward_time / N
    others_time = total_time - force_time - forward_time

    speed_up = '-'
    if verbose:
        # A cuda-extension run is paired with the immediately preceding
        # pure-python run: that run stashes its time in `last_py_speed`.
        if aev_comp.use_cuda_extension:
            if last_py_speed is not None:
                speed_up = f'{last_py_speed / total_time:.2f}'
            last_py_speed = None
        else:
            last_py_speed = total_time

        print(f'  Duration: {total_time * N:.2f} s')
        print(f'  Speed: {total_time*1000:.2f} ms/it')
        if runcounter == 0:
            addSummaryLine(init=True)
            addSummaryEmptyLine()
        # (The original also guarded on `runcounter >= 0`, which is always true.)
        items = [f'{(runcounter+1):02} {runname}',
                 f"{mol_info['name']}",
                 f"{mol_info['atoms']}",
                 f'{format_time(forward_time)}',
                 f'{format_time(force_time)}',
                 f'{format_time(others_time)}',
                 f'{format_time(total_time)}',
                 f'{format_time(total_time * N)}',
                 f'{speed_up}',
                 f'{gpumem}']
        addSummaryLine(items)
        runcounter += 1

    return aev, total_time, force


def check_speedup_error(aev, aev_ref, speed, speed_ref):
    """Compare a cuaev run against the reference python run.

    Prints the speed-up factor and asserts that the two AEV tensors agree
    within an absolute tolerance of 0.02.  Does nothing if any input is
    None (i.e. one of the runs failed).
    """
    if aev is None or aev_ref is None or speed is None or speed_ref is None:
        return

    ratio = speed_ref / speed
    if ratio > 1:
        info(f'  Speed up: {ratio:.2f} X\n')
    else:
        alert(f'  Speed up (slower): {ratio:.2f} X\n')

    max_err = torch.max(torch.abs(aev - aev_ref))
    assert max_err < 0.02, f'  Error: {max_err:.1e}\n'


def run(file, nnp_ref, nnp_cuaev, runbackward, maxatoms=10000):
    """Benchmark one PDB file with both the python and cuaev AEV computers.

    Loads the molecule, truncates it to at most ``maxatoms`` atoms, benchmarks
    the reference (python) AEV computer and then the cuda-extension one, and
    checks the speed-up and numerical agreement between the two.

    Args:
        file: PDB file name, resolved relative to ``{path}/../dataset/pdb/``.
        nnp_ref: reference ANI model (python AEV computer).
        nnp_cuaev: ANI model with ``use_cuda_extension`` enabled.
        runbackward: forwarded to :func:`benchmark` to also time forces.
        maxatoms: keep only the first ``maxatoms`` atoms of the molecule.
    """
    filepath = os.path.join(path, f'../dataset/pdb/{file}')
    mol = read(filepath)
    species = torch.tensor([mol.get_atomic_numbers()], device=device)
    positions = torch.tensor([mol.get_positions()], dtype=torch.float32, requires_grad=False, device=device)
    spelist = list(torch.unique(species.flatten()).cpu().numpy())
    # Truncate large structures so GPU memory stays bounded.
    species = species[:, :maxatoms]
    positions = positions[:, :maxatoms, :]
    speciesPositions = nnp_ref.species_converter((species, positions))
    print(f'File: {file}, Molecule size: {species.shape[-1]}, Species: {spelist}\n')

    if args.nsight:
        torch.cuda.nvtx.range_push(file)

    print('Original TorchANI:')
    mol_info = {'name': file, 'atoms': species.shape[-1]}
    # Forces are not compared here, only the AEVs and timings (hence `_`).
    aev_ref, delta_ref, _ = benchmark(speciesPositions, nnp_ref.aev_computer, runbackward, mol_info)
    print()

    print('CUaev:')
    # warm up (not recorded in the summary)
    _, _, _ = benchmark(speciesPositions, nnp_cuaev.aev_computer, runbackward, mol_info, verbose=False)
    # run
    aev, delta, _ = benchmark(speciesPositions, nnp_cuaev.aev_computer, runbackward, mol_info)

    if args.nsight:
        torch.cuda.nvtx.range_pop()

    check_speedup_error(aev, aev_ref, delta, delta_ref)
    print('-' * 70 + '\n')


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('-s', '--nsight',
                        action='store_true',
                        help='use nsight profile')
    # NOTE(review): this flag is parsed but never read below — both the
    # forward-only and the forward+backward phases always run.  (The original
    # also re-set its default via set_defaults(backward=0), which store_true
    # already handles; that dead call is removed.)
    parser.add_argument('-b', '--backward',
                        action='store_true',
                        help='benchmark double backward')
    parser.add_argument('-n', '--N',
                        help='Number of Repeat',
                        default=200, type=int)
    args = parser.parse_args()
    path = os.path.dirname(os.path.realpath(__file__))
    N = args.N

    device = torch.device('cuda')
    files = ['small.pdb', '1hz5.pdb', '6W8H.pdb']
    # files = ['small.pdb']
    # Two identical ANI2x models; the second is switched to the cuda AEV extension.
    nnp_ref = torchani.models.ANI2x(periodic_table_index=True, model_index=None).to(device)
    nnp_cuaev = torchani.models.ANI2x(periodic_table_index=True, model_index=None).to(device)
    nnp_cuaev.aev_computer.use_cuda_extension = True

    if args.nsight:
        # Keep nsight traces short: only a few iterations per benchmark.
        N = 3
        torch.cuda.profiler.start()

    # Phase 1: forward pass only.
    for file in files:
        run(file, nnp_ref, nnp_cuaev, runbackward=False)
    for maxatom in [6000, 10000]:
        run('1C17.pdb', nnp_ref, nnp_cuaev, runbackward=False, maxatoms=maxatom)

    addSummaryEmptyLine()
    info('Add Backward\n')

    # Phase 2: forward + backward (force) timing.
    for file in files:
        run(file, nnp_ref, nnp_cuaev, runbackward=True)
    for maxatom in [6000, 10000]:
        run('1C17.pdb', nnp_ref, nnp_cuaev, runbackward=True, maxatoms=maxatom)
    addSummaryEmptyLine()

    print(summary)
    if args.nsight:
        torch.cuda.profiler.stop()