gaoqiong / flash-attention · Commits

Commit fb88e5e4
Authored Oct 23, 2022 by Tri Dao

Move benchmark utils, support AMP

Parent: a5a8806d
Changes: 2 changed files with 53 additions and 36 deletions (+53 −36)

  benchmarks/benchmark_flash_attention.py    +1  −1
  flash_attn/utils/benchmark.py              +52 −35
benchmarks/benchmark_flash_attention.py

@@ -6,7 +6,7 @@ import torch.nn.functional as F
 from einops import rearrange, repeat

-from benchmarks.utils import benchmark_all, benchmark_forward, benchmark_backward, benchmark_combined
+from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward, benchmark_combined
 from flash_attn.bert_padding import unpad_input, pad_input
 from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
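
For orientation, a minimal sketch of how the relocated helpers might be driven after this commit; the einsum workload, tensor shapes, and desc string are illustrative assumptions, not part of the diff, and a CUDA device is assumed:

    import torch
    from flash_attn.utils.benchmark import benchmark_all

    q = torch.randn(2, 8, 512, 64, device='cuda', requires_grad=True)
    k = torch.randn(2, 8, 512, 64, device='cuda', requires_grad=True)

    # Illustrative workload: raw attention scores via einsum.
    def attn_scores(q, k):
        return torch.einsum('bhid,bhjd->bhij', q, k)

    # amp / amp_dtype are the new keyword arguments threaded through below.
    benchmark_all(attn_scores, q, k, desc='scores', amp=True, amp_dtype=torch.float16)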
benchmarks/utils.py → flash_attn/utils/benchmark.py

 # Adapted from https://github.com/HazyResearch/hippo/blob/datasets/benchmark/utils.py
 # Copyright (c) 2022, Tri Dao.

 """ Useful functions for writing test code. """

 import torch
 import torch.utils.benchmark as benchmark


-def benchmark_forward(fn, *inputs, min_run_time=0.2, repeats=10, desc='', verbose=True, **kwinputs):
+def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, amp=False,
+                      amp_dtype=torch.float16, **kwinputs):
     """ Use Pytorch Benchmark on the forward pass of an arbitrary function. """
     if verbose:
         print(desc, '- Forward pass')
+    def fn_amp(*inputs, **kwinputs):
+        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
+            fn(*inputs, **kwinputs)
+    for _ in range(repeats):  # warmup
+        fn_amp(*inputs, **kwinputs)
     t = benchmark.Timer(
-            stmt='fn(*inputs, **kwinputs)',
-            globals={'fn': fn, 'inputs': inputs, 'kwinputs': kwinputs},
+            stmt='fn_amp(*inputs, **kwinputs)',
+            globals={'fn_amp': fn_amp, 'inputs': inputs, 'kwinputs': kwinputs},
             num_threads=torch.get_num_threads(),
             )
     m = t.timeit(repeats)
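
The Timer now measures the fn_amp wrapper, so the autocast context is part of the timed region. A hedged usage sketch (the matmul and shape are invented for illustration; CUDA assumed):

    import torch
    from flash_attn.utils.benchmark import benchmark_forward

    x = torch.randn(4096, 4096, device='cuda')
    # With amp=True the forward runs under torch.autocast('cuda', torch.float16).
    t, m = benchmark_forward(torch.matmul, x, x, repeats=10, desc='matmul', amp=True)
    print(m.mean)  # mean seconds per call, from torch.utils.benchmark.Measurement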
@@ -20,10 +26,12 @@ def benchmark_forward(fn, *inputs, min_run_time=0.2, repeats=10, desc='', ve
     return t, m


-def benchmark_backward(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, **kwinputs):
+def benchmark_backward(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
+                       amp_dtype=torch.float16, **kwinputs):
     """ Use Pytorch Benchmark on the backward pass of an arbitrary function. """
     if verbose:
         print(desc, '- Backward pass')
-    y = fn(*inputs, **kwinputs)
-    if type(y) is tuple:
-        y = y[0]
+    with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
+        y = fn(*inputs, **kwinputs)
+        if type(y) is tuple:
+            y = y[0]

@@ -32,6 +40,8 @@ def benchmark_backward(fn, *inputs, grad=None, repeats=10, desc='', verbose=True
     else:
         if grad.shape != y.shape:
             raise RuntimeError('Grad shape does not match output shape')
+    for _ in range(repeats):  # warmup
+        y.backward(grad, retain_graph=True)
     t = benchmark.Timer(
             stmt='y.backward(grad, retain_graph=True)',
             globals={'y': y, 'grad': grad},
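
A hedged sketch of driving benchmark_backward as extended here; the workload is an invented example (CUDA assumed), and grad=None lets the helper draw a random gradient matching the output:

    import torch
    from flash_attn.utils.benchmark import benchmark_backward

    w = torch.randn(1024, 1024, device='cuda', requires_grad=True)
    # The forward runs once under autocast; only y.backward(...) is timed.
    t, m = benchmark_backward(lambda a: a @ a, w, repeats=10, desc='square', amp=True)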
@@ -43,18 +53,13 @@ def benchmark_backward(fn, *inputs, grad=None, repeats=10, desc='', verbose=True
     return t, m


-def benchmark_combined(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, **kwinputs):
+def benchmark_combined(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
+                       amp_dtype=torch.float16, **kwinputs):
     """ Use Pytorch Benchmark on the forward+backward pass of an arbitrary function. """
     if verbose:
         print(desc, '- Forward + Backward pass')
-    # y = fn(*inputs, **kwinputs)
-    # if grad is None:
-    #     grad = torch.randn_like(y)
-    # else:
-    #     if grad.shape != y.shape:
-    #         raise RuntimeError('Grad shape does not match output shape')
-    # del y
     def f(grad, *inputs, **kwinputs):
-        y = fn(*inputs, **kwinputs)
-        if type(y) is tuple:
-            y = y[0]
+        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
+            y = fn(*inputs, **kwinputs)
+            if type(y) is tuple:
+                y = y[0]

@@ -64,6 +69,8 @@ def benchmark_combined(fn, *inputs, grad=None, repeats=10, desc='', verbose=True
         if grad.shape != y.shape:
             raise RuntimeError('Grad shape does not match output shape')
         y.backward(grad, retain_graph=True)
+    for _ in range(repeats):  # warmup
+        f(grad, *inputs, **kwinputs)
     t = benchmark.Timer(
             stmt='f(grad, *inputs, **kwinputs)',
             globals={'f': f, 'fn': fn, 'inputs': inputs, 'grad': grad, 'kwinputs': kwinputs},
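
And a matching sketch for benchmark_combined, where each timed call of the closure f re-runs the forward under autocast plus the backward (again an invented workload, CUDA assumed):

    import torch
    from flash_attn.utils.benchmark import benchmark_combined

    w = torch.randn(1024, 1024, device='cuda', requires_grad=True)
    t, m = benchmark_combined(torch.tanh, w, repeats=10, desc='tanh', amp=True)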
@@ -75,43 +82,53 @@ def benchmark_combined(fn, *inputs, grad=None, repeats=10, desc='', verbose=True
     return t, m


-def benchmark_all(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, **kwinputs):
+def benchmark_all(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
+                  amp_dtype=torch.float16, **kwinputs):
     """ Use Pytorch Benchmark on the forward+backward pass of an arbitrary function. """
     return (
-        benchmark_forward(fn, *inputs, repeats=repeats, desc=desc, verbose=verbose, **kwinputs),
+        benchmark_forward(fn, *inputs, repeats=repeats, desc=desc, verbose=verbose,
+                          amp=amp, amp_dtype=amp_dtype, **kwinputs),
         benchmark_backward(fn, *inputs, grad=grad, repeats=repeats, desc=desc, verbose=verbose,
-                           **kwinputs),
+                           amp=amp, amp_dtype=amp_dtype, **kwinputs),
         benchmark_combined(fn, *inputs, grad=grad, repeats=repeats, desc=desc, verbose=verbose,
-                           **kwinputs),
+                           amp=amp, amp_dtype=amp_dtype, **kwinputs),
     )


-def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False, verbose=True):
+def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False,
+                     amp_dtype=torch.float16, cpu=False, verbose=True, **kwinputs):
     """ Wrap benchmark functions in Pytorch profiler to see CUDA information. """
     if backward:
-        g = torch.randn_like(fn(*inputs))
-    for _ in range(10):  # Warm up
-        with torch.autocast(device_type='cuda', enabled=amp):
+        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
+            g = torch.randn_like(fn(*inputs, **kwinputs))
+    for _ in range(30):  # Warm up
+        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
             if backward:
                 for x in inputs:
                     if isinstance(x, torch.Tensor):
                         x.grad = None
-            fn(*inputs) if not backward else fn(*inputs).backward(g)
+            # fn(*inputs, **kwinputs) if not backward else fn(*inputs, **kwinputs).backward(g)
+            out = fn(*inputs, **kwinputs)
+        # Backward should be done outside autocast
+        if backward:
+            out.backward(g)
+    activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [torch.profiler.ProfilerActivity.CUDA]
     with torch.profiler.profile(
-        activities=[torch.profiler.ProfilerActivity.CUDA,],
+        # activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA,],
+        activities=activities,
         record_shapes=True,
         # profile_memory=True,
         with_stack=True,
     ) as prof:
-        with torch.autocast(device_type='cuda', enabled=amp):
+        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
             if backward:
                 for x in inputs:
                     if isinstance(x, torch.Tensor):
                         x.grad = None
-            fn(*inputs) if not backward else fn(*inputs).backward(g)
+            out = fn(*inputs, **kwinputs)
+        if backward:
+            out.backward(g)
     if verbose:
-        print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50))
+        # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50))
+        print(prof.key_averages().table(row_limit=50))
     if trace_filename is not None:
         prof.export_chrome_trace(trace_filename)
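
A hedged sketch of the extended profiler entry point; the workload and trace filename are illustrative assumptions (CUDA assumed), and the exported trace can be opened at chrome://tracing:

    import torch
    from flash_attn.utils.benchmark import pytorch_profiler

    x = torch.randn(2048, 2048, device='cuda', requires_grad=True)
    # cpu=True adds ProfilerActivity.CPU alongside CUDA; backward=True also
    # profiles out.backward(g), with the random grad g built during warmup.
    pytorch_profiler(torch.matmul, x, x, trace_filename='matmul_trace.json',
                     backward=True, amp=True, cpu=True)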
@@ -124,6 +141,6 @@ def benchmark_memory(fn, *inputs, desc='', verbose=True, **kwinputs):
     torch.cuda.synchronize()
     mem = torch.cuda.max_memory_allocated() / ((2 ** 20) * 1000)
     if verbose:
-        print(f'{desc} max memory: ', mem)
+        print(f'{desc} max memory: {mem}GB')
     torch.cuda.empty_cache()
     return mem
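
Finally, a usage sketch for benchmark_memory (workload invented, CUDA assumed); the divisor (2 ** 20) * 1000 scales bytes to roughly gigabytes, which the updated message now labels explicitly:

    import torch
    from flash_attn.utils.benchmark import benchmark_memory

    x = torch.randn(8192, 8192, device='cuda')
    mem = benchmark_memory(torch.matmul, x, x, desc='matmul')  # prints 'matmul max memory: <mem>GB'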