aev-benchmark-size.py
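"""Benchmark TorchANI AEV computation on PDB structures of different sizes.

For each structure the AEVs are computed N times with the reference PyTorch
AEVComputer and then again with the cuaev CUDA extension enabled, and the two
runs are compared for speed and numerical agreement.
"""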
import time
import torch
import torchani
import pynvml
import gc
import os
from ase.io import read
import argparse


def checkgpu(device=None):
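    """Print current GPU memory usage via NVML (the same counters nvidia-smi reports)."""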
    i = device if device is not None else torch.cuda.current_device()
    # Map the PyTorch device index back to the physical index NVML expects
    # when CUDA_VISIBLE_DEVICES remaps the visible devices.
    real_i = int(os.environ['CUDA_VISIBLE_DEVICES'].split(',')[i]) if 'CUDA_VISIBLE_DEVICES' in os.environ else i
    pynvml.nvmlInit()
    h = pynvml.nvmlDeviceGetHandleByIndex(real_i)
    info = pynvml.nvmlDeviceGetMemoryInfo(h)
    name = pynvml.nvmlDeviceGetName(h)
    # pynvml may return the device name as bytes (older versions) or as str.
    name = name.decode() if isinstance(name, bytes) else name
    print('  GPU Memory Used (nvidia-smi): {:7.1f}MB / {:.1f}MB ({})'.format(info.used / 1024 / 1024, info.total / 1024 / 1024, name))


def alert(text):
    print('\033[91m{}\033[0m'.format(text))  # red


def info(text):
    print('\033[32m{}\033[0m'.format(text))  # green


def benchmark(speciesPositions, aev_comp, N, check_gpu_mem):
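    """Time N forward passes of aev_comp on speciesPositions.

    Optionally prints GPU memory usage early in the loop, and returns the last
    computed AEV tensor together with the elapsed wall-clock time in seconds.
    """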
    torch.cuda.empty_cache()
    gc.collect()
    torch.cuda.synchronize()
    start = time.time()

    aev = None
    for i in range(N):
        aev = aev_comp(speciesPositions).aevs
        # Report GPU memory once, after a couple of warm-up iterations.
        if i == 2 and check_gpu_mem:
            checkgpu()

    torch.cuda.synchronize()
    delta = time.time() - start
    print(f'  Duration: {delta:.2f} s')
    print(f'  Speed: {delta/N*1000:.2f} ms/it')
    return aev, delta


def check_speedup_error(aev, aev_ref, speed, speed_ref):
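    """Report how much faster the timed run (speed) is than the reference (speed_ref) and assert the AEVs agree within tolerance."""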
    speedup = speed_ref / speed
    if speedup > 1:
        info(f'  Speed up: {speedup:.2f} X\n')
    else:
        alert(f'  Speed up (slower): {speedup:.2f} X\n')

    aev_error = torch.max(torch.abs(aev - aev_ref))
    assert aev_error < 0.02, f'  Error: {aev_error:.1e}\n'


if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument('-m', '--check_gpu_mem',
                        dest='check_gpu_mem',
                        action='store_const',
                        const=1,
                        help='print GPU memory usage (via pynvml) during the benchmark')
    parser.add_argument('--nsight',
                        action='store_true',
                        help='run a short NVTX-annotated pass for Nsight profiling')
    parser.set_defaults(check_gpu_mem=0)
    args = parser.parse_args()
    path = os.path.dirname(os.path.realpath(__file__))

    check_gpu_mem = args.check_gpu_mem
    device = torch.device('cuda')
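    # PDB test structures used for the benchmark (loaded from ../dataset/pdb/).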
    files = ['small.pdb', '1hz5.pdb', '6W8H.pdb']

    N = 500
    if args.nsight:
        # Run only a few iterations when profiling under Nsight.
        N = 3
        torch.cuda.profiler.start()

    for file in files:
        datafile = os.path.join(path, f'../dataset/pdb/{file}')
        mol = read(datafile)
        species = torch.tensor([mol.get_atomic_numbers()], device=device)
        positions = torch.tensor([mol.get_positions()], dtype=torch.float32, requires_grad=False, device=device)
        print(f'File: {file}, Molecule size: {species.shape[-1]}\n')

        # Build the ANI-2x model for this structure and grab its AEVComputer;
        # it is benchmarked first with the reference PyTorch path and then with
        # the cuaev CUDA extension.
        nnp = torchani.models.ANI2x(periodic_table_index=True, model_index=None).to(device)
        speciesPositions = nnp.species_converter((species, positions))
        aev_computer = nnp.aev_computer

        if args.nsight:
            torch.cuda.nvtx.range_push(file)
        print('Original TorchANI:')
        aev_ref, delta_ref = benchmark(speciesPositions, aev_computer, N, check_gpu_mem)
        print()

        print('CUaev:')
        # Reuse the same AEVComputer, switched to the cuaev CUDA extension.
        nnp.aev_computer.use_cuda_extension = True
        cuaev_computer = nnp.aev_computer
        aev, delta = benchmark(speciesPositions, cuaev_computer, N, check_gpu_mem)
        if args.nsight:
            torch.cuda.nvtx.range_pop()

        check_speedup_error(aev, aev_ref, delta, delta_ref)
        print('-' * 70 + '\n')

    if args.nsight:
        torch.cuda.profiler.stop()