Commit fb88e5e4 authored by Tri Dao

Move benchmark utils, support AMP

parent a5a8806d
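
For orientation, here is a minimal sketch of how the relocated utilities are called once this commit lands. The Linear workload, shapes, and tensor names are invented for illustration (they are not part of the commit), and a CUDA device is assumed:

import torch
from flash_attn.utils.benchmark import benchmark_all

# Hypothetical workload: any callable whose output supports .backward() will do.
x = torch.randn(8, 2048, 1024, device='cuda', requires_grad=True)
layer = torch.nn.Linear(1024, 1024, device='cuda')

# With amp=True every timed call runs under torch.autocast with amp_dtype
# (float16 by default), so timings reflect mixed-precision execution.
benchmark_all(layer, x, repeats=10, desc='linear', amp=True, amp_dtype=torch.bfloat16)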
@@ -6,7 +6,7 @@ import torch.nn.functional as F
 from einops import rearrange, repeat
-from benchmarks.utils import benchmark_all, benchmark_forward, benchmark_backward, benchmark_combined
+from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward, benchmark_combined
 from flash_attn.bert_padding import unpad_input, pad_input
 from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
benchmarks/utils.py → flash_attn/utils/benchmark.py
-# Adapted from https://github.com/HazyResearch/hippo/blob/datasets/benchmark/utils.py
+# Copyright (c) 2022, Tri Dao.
 """ Useful functions for writing test code. """
 import torch
 import torch.utils.benchmark as benchmark
-def benchmark_forward(fn, *inputs, min_run_time = 0.2, repeats = 10, desc='', verbose=True, **kwinputs):
+def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, amp=False,
+                      amp_dtype=torch.float16, **kwinputs):
     """ Use Pytorch Benchmark on the forward pass of an arbitrary function. """
     if verbose:
         print(desc, '- Forward pass')
+    def fn_amp(*inputs, **kwinputs):
+        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
+            fn(*inputs, **kwinputs)
+    for _ in range(repeats): # warmup
+        fn_amp(*inputs, **kwinputs)
     t = benchmark.Timer(
-        stmt='fn(*inputs, **kwinputs)',
-        globals={'fn': fn, 'inputs': inputs, 'kwinputs': kwinputs},
+        stmt='fn_amp(*inputs, **kwinputs)',
+        globals={'fn_amp': fn_amp, 'inputs': inputs, 'kwinputs': kwinputs},
         num_threads=torch.get_num_threads(),
     )
     m = t.timeit(repeats)
@@ -20,50 +26,51 @@ def benchmark_forward(fn, *inputs, min_run_time = 0.2, repeats = 10, desc='', ve
     return t, m
-def benchmark_backward(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, **kwinputs):
+def benchmark_backward(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
+                       amp_dtype=torch.float16, **kwinputs):
     """ Use Pytorch Benchmark on the backward pass of an arbitrary function. """
     if verbose:
         print(desc, '- Backward pass')
-    y = fn(*inputs, **kwinputs)
-    if type(y) is tuple:
-        y = y[0]
+    with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
+        y = fn(*inputs, **kwinputs)
+        if type(y) is tuple:
+            y = y[0]
     if grad is None:
         grad = torch.randn_like(y)
     else:
         if grad.shape != y.shape:
             raise RuntimeError('Grad shape does not match output shape')
+    for _ in range(repeats): # warmup
+        y.backward(grad, retain_graph=True)
     t = benchmark.Timer(
         stmt='y.backward(grad, retain_graph=True)',
         globals={'y': y, 'grad': grad},
         num_threads=torch.get_num_threads(),
     )
     m = t.timeit(repeats)
     if verbose:
         print(m)
     return t, m
-def benchmark_combined(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, **kwinputs):
+def benchmark_combined(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
+                       amp_dtype=torch.float16, **kwinputs):
     """ Use Pytorch Benchmark on the forward+backward pass of an arbitrary function. """
     if verbose:
         print(desc, '- Forward + Backward pass')
-    # y = fn(*inputs, **kwinputs)
-    # if grad is None:
-    #     grad = torch.randn_like(y)
-    # else:
-    #     if grad.shape != y.shape:
-    #         raise RuntimeError('Grad shape does not match output shape')
-    # del y
     def f(grad, *inputs, **kwinputs):
-        y = fn(*inputs, **kwinputs)
-        if type(y) is tuple:
-            y = y[0]
+        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
+            y = fn(*inputs, **kwinputs)
+            if type(y) is tuple:
+                y = y[0]
         if grad is None:
             grad = torch.randn_like(y)
         else:
             if grad.shape != y.shape:
                 raise RuntimeError('Grad shape does not match output shape')
         y.backward(grad, retain_graph=True)
+    for _ in range(repeats): # warmup
+        f(grad, *inputs, **kwinputs)
     t = benchmark.Timer(
         stmt='f(grad, *inputs, **kwinputs)',
         globals={'f': f, 'fn': fn, 'inputs': inputs, 'grad': grad, 'kwinputs': kwinputs},
@@ -75,43 +82,53 @@ def benchmark_combined(fn, *inputs, grad=None, repeats=10, desc='', verbose=True
     return t, m
-def benchmark_all(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, **kwinputs):
+def benchmark_all(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
+                  amp_dtype=torch.float16, **kwinputs):
     """ Use Pytorch Benchmark on the forward+backward pass of an arbitrary function. """
     return (
-        benchmark_forward(fn, *inputs, repeats=repeats, desc=desc, verbose=verbose, **kwinputs),
+        benchmark_forward(fn, *inputs, repeats=repeats, desc=desc, verbose=verbose,
+                          amp=amp, amp_dtype=amp_dtype, **kwinputs),
         benchmark_backward(fn, *inputs, grad=grad, repeats=repeats, desc=desc, verbose=verbose,
-                           **kwinputs),
+                           amp=amp, amp_dtype=amp_dtype, **kwinputs),
         benchmark_combined(fn, *inputs, grad=grad, repeats=repeats, desc=desc, verbose=verbose,
-                           **kwinputs),
+                           amp=amp, amp_dtype=amp_dtype, **kwinputs),
     )
-def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False, verbose=True):
+def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False,
+                     amp_dtype=torch.float16, cpu=False, verbose=True, **kwinputs):
     """ Wrap benchmark functions in Pytorch profiler to see CUDA information. """
     if backward:
-        g = torch.randn_like(fn(*inputs))
-    for _ in range(10): # Warm up
-        with torch.autocast(device_type='cuda', enabled=amp):
+        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
+            g = torch.randn_like(fn(*inputs, **kwinputs))
+    for _ in range(30): # Warm up
+        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
             if backward:
                 for x in inputs:
                     if isinstance(x, torch.Tensor):
                         x.grad = None
-            fn(*inputs) if not backward else fn(*inputs).backward(g)
+            # fn(*inputs, **kwinputs) if not backward else fn(*inputs, **kwinputs).backward(g)
+            out = fn(*inputs, **kwinputs)
+        # Backward should be done outside autocast
+        if backward:
+            out.backward(g)
+    activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [torch.profiler.ProfilerActivity.CUDA]
     with torch.profiler.profile(
-        # activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA,],
-        activities=[torch.profiler.ProfilerActivity.CUDA,],
+        activities=activities,
         record_shapes=True,
         # profile_memory=True,
         with_stack=True,
     ) as prof:
-        with torch.autocast(device_type='cuda', enabled=amp):
+        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
             if backward:
                 for x in inputs:
                     if isinstance(x, torch.Tensor):
                         x.grad = None
-            fn(*inputs) if not backward else fn(*inputs).backward(g)
+            out = fn(*inputs, **kwinputs)
+        if backward: out.backward(g)
     if verbose:
-        print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50))
+        # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50))
+        print(prof.key_averages().table(row_limit=50))
     if trace_filename is not None:
         prof.export_chrome_trace(trace_filename)
@@ -124,6 +141,6 @@ def benchmark_memory(fn, *inputs, desc='', verbose=True, **kwinputs):
     torch.cuda.synchronize()
     mem = torch.cuda.max_memory_allocated() / ((2 ** 20) * 1000)
     if verbose:
-        print(f'{desc} max memory: ', mem)
+        print(f'{desc} max memory: {mem}GB')
     torch.cuda.empty_cache()
     return mem
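
The profiler and memory helpers follow the same calling convention; a minimal sketch under the same assumptions as above (hypothetical Linear workload, CUDA device; the trace filename is arbitrary):

import torch
from flash_attn.utils.benchmark import pytorch_profiler, benchmark_memory

x = torch.randn(8, 2048, 1024, device='cuda', requires_grad=True)
layer = torch.nn.Linear(1024, 1024, device='cuda')

# cpu=True adds CPU events next to the CUDA ones; the exported trace can be
# inspected in chrome://tracing or Perfetto.
pytorch_profiler(layer, x, backward=True, amp=True, cpu=True,
                 trace_filename='linear_trace.json')

mem = benchmark_memory(layer, x, desc='linear')  # peak allocated memory, in GB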