Unverified Commit 1d64bbae authored by Jinze Xue's avatar Jinze Xue Committed by GitHub
Browse files

pdb benchmark tool (#582)

parent f87ba402
This source diff could not be displayed because it is too large. You can view the blob instead.
......@@ -7,15 +7,25 @@ import os
from ase.io import read
import argparse
# Text of the summary table; addSummaryLine / addSummaryEmptyLine append to it
# and the script prints it once all runs are done.
summary = '\n'
# Number of benchmark() calls finished so far; used to number the summary rows.
runcounter = 0
# Number of timed iterations per benchmark() call (replaced by the -n/--N
# option, or by 3 under --nsight, when run as a script).
N = 200
# Mean seconds per iteration of the latest run without the CUDA extension;
# benchmark() reads it to report a speedup for a following CUDA-extension run.
last_py_speed = None
def checkgpu(device=None):
    """Print GPU memory usage as seen by PyTorch and by NVML.

    Args:
        device: CUDA device ordinal to inspect. ``None`` selects
            ``torch.cuda.current_device()``.

    Returns:
        The device memory in use according to NVML, formatted like
        ``'123.4MB'``.
    """
    # `is not None` so an explicit device 0 is honoured rather than being
    # treated as "not given".
    i = device if device is not None else torch.cuda.current_device()
    props = torch.cuda.get_device_properties(i)
    reserved = torch.cuda.memory_reserved(i)
    mb = 1024 * 1024
    print(' GPU Memory Cached (pytorch) : {:7.1f}MB / {:.1f}MB ({})'.format(reserved / mb, props.total_memory / mb, props.name))

    # NVML numbers physical devices, while PyTorch numbers the devices left
    # visible by CUDA_VISIBLE_DEVICES, so translate ordinal i through that list.
    real_i = i
    visible = os.environ.get('CUDA_VISIBLE_DEVICES')
    if visible:
        try:
            real_i = int(visible.split(',')[i])
        except (IndexError, ValueError, TypeError):
            # UUID-style or malformed entries cannot serve as an NVML index;
            # fall back to the PyTorch ordinal.
            real_i = i

    pynvml.nvmlInit()
    try:
        handle = pynvml.nvmlDeviceGetHandleByIndex(real_i)
        mem = pynvml.nvmlDeviceGetMemoryInfo(handle)
        name = pynvml.nvmlDeviceGetName(handle)
    finally:
        # Balance nvmlInit() so repeated calls do not leave NVML initialised.
        pynvml.nvmlShutdown()
    # Depending on the pynvml version the name is returned as bytes or as str.
    if isinstance(name, bytes):
        name = name.decode()
    print(' GPU Memory Used (nvidia-smi): {:7.1f}MB / {:.1f}MB ({})'.format(mem.used / mb, mem.total / mb, name))
    return f'{(mem.used / mb):.1f}MB'
def alert(text):
......@@ -26,107 +36,206 @@ def info(text):
print('\033[32m{}\33[0m'.format(text)) # green
def benchmark(speciesPositions, aev_comp, N, check_gpu_mem, nn=None, verbose=True):
def format_time(t):
    """Format a duration given in seconds for display.

    Durations below one second are shown in milliseconds with one decimal
    (``'12.3 ms'``); anything else is shown in seconds with three decimals
    (``'1.500 sec'``).
    """
    if t < 1:
        return f'{t * 1000:.1f} ms'
    return f'{t:.3f} sec'
def addSummaryLine(items=None, init=False):
    """Append one row to the module-level ``summary`` table.

    Args:
        items: Sequence of at least ten cell values (only the first ten are
            used). Values are converted with ``str()``. Ignored when ``init``
            is true.
        init: When true, append a separator line followed by the header row
            instead of ``items``.

    Raises:
        TypeError: if ``items`` is ``None`` and ``init`` is false.
        IndexError: if ``items`` holds fewer than ten values.
    """
    global summary
    if init:
        addSummaryEmptyLine()
        items = ['RUN', 'PDB', 'Size', 'forward', 'backward', 'Others', 'Total', f'Total({N})', 'Speedup', 'GPU']
    if items is None:
        raise TypeError('addSummaryLine() needs `items` unless init=True')
    # First column is 20 characters wide, the remaining nine are 13 wide.
    widths = (20,) + (13,) * 9
    cells = [str(items[col]).ljust(width) for col, width in enumerate(widths)]
    summary += ''.join(cells) + '\n'
def addSummaryEmptyLine():
    """Append a separator row of dashes to the module-level ``summary`` table."""
    global summary
    # One 20-character column followed by nine 13-character columns, each
    # filled completely with dashes.
    column_widths = (20,) + (13,) * 9
    summary += ''.join('-' * width for width in column_widths) + '\n'
def benchmark(speciesPositions, aev_comp, runbackward=False, mol_info=None, verbose=True):
    """Time ``N`` AEV evaluations and add the result to the summary table.

    Args:
        speciesPositions: ``(species, coordinates)`` pair handed to
            ``aev_comp``.
        aev_comp: Callable AEV computer. Its ``use_cuda_extension`` attribute
            selects the row label ('cu' / 'py') and the speedup bookkeeping.
        runbackward: When true, also time ``torch.autograd.grad`` of the
            summed AEV with respect to the coordinates.
        mol_info: Mapping with ``'name'`` and ``'atoms'`` entries shown in the
            summary row; missing entries are shown as ``'-'``.
        verbose: When true, print duration and speed and sample GPU memory on
            the third iteration.

    Returns:
        ``(aev, total_time, force)`` where ``total_time`` is the mean
        wall-clock seconds per iteration and ``force`` is ``None`` unless
        ``runbackward`` is true. ``(None, None, None)`` when the forward or
        backward computation raises.
    """
    global runcounter
    global last_py_speed

    mol_info = mol_info if mol_info is not None else {}
    backend = 'cu' if aev_comp.use_cuda_extension else 'py'
    suffix = '+bd' if runbackward else ''
    runname = f"{backend} aev fd{suffix}"
    # Placeholder row; emitted unchanged if the run fails part-way.
    items = [f'{(runcounter+1):02} {runname}', f"{mol_info.get('name', '-')}", f"{mol_info.get('atoms', '-')}",
             '-', '-', '-', '-', '-', '-', '-']

    forward_time = 0
    force_time = 0
    torch.cuda.empty_cache()
    gc.collect()
    torch.cuda.synchronize()
    start = time.time()

    aev = None
    force = None
    gpumem = None
    for i in range(N):
        species, coordinates = speciesPositions
        coordinates = coordinates.requires_grad_(runbackward)

        # Synchronize around each timed section so pending GPU work is
        # included in the wall-clock measurement.
        torch.cuda.synchronize()
        forward_start = time.time()
        try:
            _, aev = aev_comp((species, coordinates))
        except Exception as e:
            # Best effort: record the failed run and let the caller carry on
            # with the next molecule.
            alert(f" AEV faild: {str(e)[:50]}...")
            addSummaryLine(items)
            runcounter += 1
            return None, None, None
        torch.cuda.synchronize()
        forward_time += time.time() - forward_start

        if runbackward:  # backward
            force_start = time.time()
            try:
                force = -torch.autograd.grad(aev.sum(), coordinates, create_graph=True, retain_graph=True)[0]
            except Exception as e:
                alert(f" Force faild: {str(e)[:50]}...")
                addSummaryLine(items)
                runcounter += 1
                return None, None, None
            torch.cuda.synchronize()
            force_time += time.time() - force_start

        if i == 2 and verbose:
            gpumem = checkgpu()

    torch.cuda.synchronize()
    total_time = (time.time() - start) / N
    force_time = force_time / N
    forward_time = forward_time / N
    # Whatever the loop spent outside the two timed sections.
    others_time = total_time - force_time - forward_time

    # A CUDA-extension run reports its speedup over the preceding 'py' run and
    # then clears the reference.
    # NOTE(review): a verbose=False warm-up run also consumes the reference,
    # so the timed run that follows it shows '-' -- confirm this is intended.
    if aev_comp.use_cuda_extension:
        if last_py_speed is not None:
            speed_up = f'{last_py_speed / total_time:.2f}'
        else:
            speed_up = '-'
        last_py_speed = None
    else:
        last_py_speed = total_time
        speed_up = '-'

    if verbose:
        print(f' Duration: {total_time * N:.2f} s')
        print(f' Speed: {total_time*1000:.2f} ms/it')

    if runcounter == 0:
        # First completed run: emit the table header before the first row.
        addSummaryLine(init=True)
        addSummaryEmptyLine()
    items[3:] = [
        f'{format_time(forward_time)}',
        f'{format_time(force_time)}',
        f'{format_time(others_time)}',
        f'{format_time(total_time)}',
        f'{format_time(total_time * N)}',
        f'{speed_up}',
        # No memory sample exists for quiet runs or when N < 3.
        '-' if gpumem is None else f'{gpumem}',
    ]
    addSummaryLine(items)
    runcounter += 1

    return aev, total_time, force
def check_speedup_error(aev, aev_ref, speed, speed_ref):
    """Report the speedup over the reference run and check that the AEVs agree.

    Nothing is reported or checked when any argument is ``None``, which is
    what benchmark() returns for a failed run.

    Args:
        aev: AEV tensor of the run under test.
        aev_ref: AEV tensor of the reference run.
        speed: Time per iteration of the run under test.
        speed_ref: Time per iteration of the reference run.

    Raises:
        AssertionError: if the largest absolute element-wise difference
            between ``aev`` and ``aev_ref`` is not below 0.02.
    """
    if speed_ref is None or speed is None or aev is None or aev_ref is None:
        return

    speedUP = speed_ref / speed
    if speedUP > 1:
        info(f' Speed up: {speedUP:.2f} X\n')
    else:
        alert(f' Speed up (slower): {speedUP:.2f} X\n')

    aev_error = torch.max(torch.abs(aev - aev_ref))
    # Raised explicitly (rather than via `assert`) so the check still runs
    # under `python -O`; `not (x < tol)` also rejects NaN.
    if not aev_error < 0.02:
        raise AssertionError(f' Error: {aev_error:.1e}\n')
def run(file, nnp_ref, nnp_cuaev, runbackward, maxatoms=10000):
    """Benchmark one PDB file with the reference and the CUDA-extension AEV computer.

    Relies on the module-level names ``path``, ``device`` and ``args`` that
    the ``__main__`` section sets up.

    Args:
        file: File name, resolved relative to ``../dataset/pdb/`` under
            ``path``.
        nnp_ref: Model whose ``species_converter`` and ``aev_computer``
            provide the reference run.
        nnp_cuaev: Model whose ``aev_computer`` is benchmarked against the
            reference.
        runbackward: Forwarded to benchmark().
        maxatoms: Only the first ``maxatoms`` atoms of the file are used.
    """
    filepath = os.path.join(path, f'../dataset/pdb/{file}')
    mol = read(filepath)
    species = torch.tensor([mol.get_atomic_numbers()], device=device)
    positions = torch.tensor([mol.get_positions()], dtype=torch.float32, requires_grad=False, device=device)
    # Collected before truncation, so this lists every species in the file.
    spelist = list(torch.unique(species.flatten()).cpu().numpy())
    species = species[:, :maxatoms]
    positions = positions[:, :maxatoms, :]
    speciesPositions = nnp_ref.species_converter((species, positions))
    print(f'File: {file}, Molecule size: {species.shape[-1]}, Species: {spelist}\n')

    mol_info = {'name': file, 'atoms': species.shape[-1]}

    if args.nsight:
        torch.cuda.nvtx.range_push(file)
    try:
        print('Original TorchANI:')
        aev_ref, delta_ref, _ = benchmark(speciesPositions, nnp_ref.aev_computer, runbackward, mol_info)
        print()

        print('CUaev:')
        # warm up
        benchmark(speciesPositions, nnp_cuaev.aev_computer, runbackward, mol_info, verbose=False)
        # run
        aev, delta, _ = benchmark(speciesPositions, nnp_cuaev.aev_computer, runbackward, mol_info)
    finally:
        # Close the NVTX range even if a benchmark raises, so later ranges
        # are not nested inside this one.
        if args.nsight:
            torch.cuda.nvtx.range_pop()

    check_speedup_error(aev, aev_ref, delta, delta_ref)
    print('-' * 70 + '\n')
if __name__ == "__main__":
    # Command-line entry point: benchmark the reference and CUDA-extension AEV
    # computers on a fixed set of PDB files, forward-only first and then with
    # the backward pass, and print the collected summary table.
    parser = argparse.ArgumentParser()
    parser.add_argument('-s', '--nsight',
                        action='store_true',
                        help='use nsight profile')
    # NOTE(review): --backward is still accepted but nothing below reads it;
    # both the forward-only and the backward passes always run.
    parser.add_argument('-b', '--backward',
                        action='store_true',
                        help='benchmark double backward')
    parser.add_argument('-n', '--N',
                        help='Number of Repeat',
                        default=200, type=int)
    parser.set_defaults(backward=0)
    args = parser.parse_args()

    path = os.path.dirname(os.path.realpath(__file__))
    # Module-level iteration count read by benchmark().
    N = args.N

    device = torch.device('cuda')
    files = ['small.pdb', '1hz5.pdb', '6W8H.pdb']
    # files = ['small.pdb']
    nnp_ref = torchani.models.ANI2x(periodic_table_index=True, model_index=None).to(device)
    nnp_cuaev = torchani.models.ANI2x(periodic_table_index=True, model_index=None).to(device)
    nnp_cuaev.aev_computer.use_cuda_extension = True

    if args.nsight:
        # Only a few iterations are run while profiling.
        N = 3
        torch.cuda.profiler.start()

    # Forward pass only.
    for file in files:
        run(file, nnp_ref, nnp_cuaev, runbackward=False)
    for maxatom in [6000, 10000]:
        file = '1C17.pdb'
        run(file, nnp_ref, nnp_cuaev, runbackward=False, maxatoms=maxatom)

    addSummaryEmptyLine()
    info('Add Backward\n')

    # Forward and backward pass.
    for file in files:
        run(file, nnp_ref, nnp_cuaev, runbackward=True)
    for maxatom in [6000, 10000]:
        file = '1C17.pdb'
        run(file, nnp_ref, nnp_cuaev, runbackward=True, maxatoms=maxatom)
    addSummaryEmptyLine()

    print(summary)
    if args.nsight:
        torch.cuda.profiler.stop()
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment