gaoqiong / flash-attention · Commits

Commit fb88e5e4
Authored Oct 23, 2022 by Tri Dao

    Move benchmark utils, support AMP

Parent: a5a8806d

Showing 2 changed files with 53 additions and 36 deletions (+53 −36)
benchmarks/benchmark_flash_attention.py   +1 −1
flash_attn/utils/benchmark.py             +52 −35
benchmarks/benchmark_flash_attention.py  (view file @ fb88e5e4)
@@ -6,7 +6,7 @@ import torch.nn.functional as F
 from einops import rearrange, repeat
-from benchmarks.utils import benchmark_all, benchmark_forward, benchmark_backward, benchmark_combined
+from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward, benchmark_combined
 from flash_attn.bert_padding import unpad_input, pad_input
 from flash_attn.flash_attn_interface import flash_attn_unpadded_qkvpacked_func
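Only the import path changes in this file, reflecting the move of the benchmarking helpers from benchmarks/utils.py into the installed flash_attn package. A minimal sketch of importing from the new location (the toy tensor and callable below are illustrative, not taken from the benchmark script; a CUDA device is assumed):

    import torch
    from flash_attn.utils.benchmark import benchmark_forward

    # Toy stand-in workload; the real script benchmarks flash_attn_unpadded_qkvpacked_func.
    x = torch.randn(1024, 1024, device='cuda', requires_grad=True)
    t, m = benchmark_forward(torch.nn.functional.gelu, x, repeats=10, desc='toy gelu')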
benchmarks/utils.py → flash_attn/utils/benchmark.py  (view file @ fb88e5e4)
 # Adapted from https://github.com/HazyResearch/hippo/blob/datasets/benchmark/utils.py
 # Copyright (c) 2022, Tri Dao.
 """ Useful functions for writing test code. """

 import torch
 import torch.utils.benchmark as benchmark

-def benchmark_forward(fn, *inputs, min_run_time = 0.2, repeats = 10, desc='', verbose=True, **kwinputs):
+def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, amp=False,
+                      amp_dtype=torch.float16, **kwinputs):
     """ Use Pytorch Benchmark on the forward pass of an arbitrary function. """
     if verbose:
         print(desc, '- Forward pass')
+    def fn_amp(*inputs, **kwinputs):
+        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
+            fn(*inputs, **kwinputs)
+    for _ in range(repeats):  # warmup
+        fn_amp(*inputs, **kwinputs)
     t = benchmark.Timer(
-            stmt='fn(*inputs, **kwinputs)',
-            globals={'fn': fn, 'inputs': inputs, 'kwinputs': kwinputs},
+            stmt='fn_amp(*inputs, **kwinputs)',
+            globals={'fn_amp': fn_amp, 'inputs': inputs, 'kwinputs': kwinputs},
             num_threads=torch.get_num_threads(),
             )
     m = t.timeit(repeats)
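The reworked benchmark_forward wraps the measured callable in fn_amp, which executes it under torch.autocast, so the same helper times both full-precision and mixed-precision runs; the warmup loop and the Timer statement both go through fn_amp. A hedged usage sketch (toy matmul, CUDA device assumed):

    import torch
    from flash_attn.utils.benchmark import benchmark_forward

    a = torch.randn(2048, 2048, device='cuda')
    b = torch.randn(2048, 2048, device='cuda')
    # Same callable, timed in fp32 and then under autocast with fp16.
    t_fp32, m_fp32 = benchmark_forward(torch.matmul, a, b, desc='matmul fp32', amp=False)
    t_amp, m_amp = benchmark_forward(torch.matmul, a, b, desc='matmul amp',
                                     amp=True, amp_dtype=torch.float16)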
@@ -20,50 +26,51 @@ def benchmark_forward(fn, *inputs, min_run_time = 0.2, repeats = 10, desc='', ve
     return t, m

-def benchmark_backward(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, **kwinputs):
+def benchmark_backward(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
+                       amp_dtype=torch.float16, **kwinputs):
     """ Use Pytorch Benchmark on the backward pass of an arbitrary function. """
     if verbose:
         print(desc, '- Backward pass')
-    y = fn(*inputs, **kwinputs)
-    if type(y) is tuple:
-        y = y[0]
+    with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
+        y = fn(*inputs, **kwinputs)
+        if type(y) is tuple:
+            y = y[0]
     if grad is None:
         grad = torch.randn_like(y)
     else:
         if grad.shape != y.shape:
             raise RuntimeError('Grad shape does not match output shape')
+    for _ in range(repeats):  # warmup
+        y.backward(grad, retain_graph=True)
     t = benchmark.Timer(
             stmt='y.backward(grad, retain_graph=True)',
             globals={'y': y, 'grad': grad},
             num_threads=torch.get_num_threads(),
             )
     m = t.timeit(repeats)
     if verbose:
         print(m)
     return t, m
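benchmark_backward now builds the autograd graph by calling fn under torch.autocast, so with amp=True the output y and the grad created by torch.randn_like(y) are in the autocast dtype, while y.backward itself runs outside the autocast region, which matches the usual PyTorch AMP recipe. A standalone illustration of that dtype behaviour (not part of the diff; CUDA device assumed):

    import torch

    x = torch.randn(128, 128, device='cuda', requires_grad=True)  # fp32 leaf
    w = torch.randn(128, 128, device='cuda', requires_grad=True)
    with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=True):
        y = x @ w                    # computed and stored in fp16 under autocast
    grad = torch.randn_like(y)       # fp16, matching y, as in benchmark_backward
    y.backward(grad)                 # backward runs outside autocast
    print(y.dtype, x.grad.dtype)     # torch.float16, torch.float32 (grads accumulate in the leaf dtype)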
-def benchmark_combined(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, **kwinputs):
+def benchmark_combined(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
+                       amp_dtype=torch.float16, **kwinputs):
     """ Use Pytorch Benchmark on the forward+backward pass of an arbitrary function. """
     if verbose:
         print(desc, '- Forward + Backward pass')
-    # y = fn(*inputs, **kwinputs)
-    # if grad is None:
-    #     grad = torch.randn_like(y)
-    # else:
-    #     if grad.shape != y.shape:
-    #         raise RuntimeError('Grad shape does not match output shape')
-    # del y
     def f(grad, *inputs, **kwinputs):
-        y = fn(*inputs, **kwinputs)
-        if type(y) is tuple:
-            y = y[0]
+        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
+            y = fn(*inputs, **kwinputs)
+            if type(y) is tuple:
+                y = y[0]
         if grad is None:
             grad = torch.randn_like(y)
         else:
             if grad.shape != y.shape:
                 raise RuntimeError('Grad shape does not match output shape')
         y.backward(grad, retain_graph=True)
+    for _ in range(repeats):  # warmup
+        f(grad, *inputs, **kwinputs)
     t = benchmark.Timer(
             stmt='f(grad, *inputs, **kwinputs)',
             globals={'f': f, 'fn': fn, 'inputs': inputs, 'grad': grad, 'kwinputs': kwinputs},
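benchmark_combined times a closure f that runs the forward under autocast and then the backward, so a single Measurement covers the full training-style step. A hedged sketch of reading the result (the .mean field of torch.utils.benchmark's Measurement is assumed here; the toy function and shape are illustrative):

    import torch
    from flash_attn.utils.benchmark import benchmark_combined

    x = torch.randn(1024, 1024, device='cuda', requires_grad=True)
    t, m = benchmark_combined(lambda x: (x @ x).sum(), x, desc='toy fwd+bwd', amp=True)
    print(f'mean fwd+bwd time: {m.mean * 1e3:.3f} ms')  # Measurement.mean is reported in seconds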
@@ -75,43 +82,53 @@ def benchmark_combined(fn, *inputs, grad=None, repeats=10, desc='', verbose=True
     return t, m

-def benchmark_all(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, **kwinputs):
+def benchmark_all(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
+                  amp_dtype=torch.float16, **kwinputs):
     """ Use Pytorch Benchmark on the forward+backward pass of an arbitrary function. """
     return (
-        benchmark_forward(fn, *inputs, repeats=repeats, desc=desc, verbose=verbose, **kwinputs),
+        benchmark_forward(fn, *inputs, repeats=repeats, desc=desc, verbose=verbose,
+                          amp=amp, amp_dtype=amp_dtype, **kwinputs),
-        benchmark_backward(fn, *inputs, grad=grad, repeats=repeats, desc=desc, verbose=verbose,
-                           **kwinputs),
+        benchmark_backward(fn, *inputs, grad=grad, repeats=repeats, desc=desc, verbose=verbose,
+                           amp=amp, amp_dtype=amp_dtype, **kwinputs),
-        benchmark_combined(fn, *inputs, grad=grad, repeats=repeats, desc=desc, verbose=verbose,
-                           **kwinputs),
+        benchmark_combined(fn, *inputs, grad=grad, repeats=repeats, desc=desc, verbose=verbose,
+                           amp=amp, amp_dtype=amp_dtype, **kwinputs),
     )
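benchmark_all simply forwards the new amp and amp_dtype arguments to the three helpers, so one call produces forward, backward, and combined measurements with matching settings. A usage sketch with a toy function (shape and dtype choice are illustrative):

    import torch
    from flash_attn.utils.benchmark import benchmark_all

    x = torch.randn(1024, 1024, device='cuda', requires_grad=True)
    (fwd_t, fwd_m), (bwd_t, bwd_m), (both_t, both_m) = benchmark_all(
        lambda x: (x @ x).sum(), x, repeats=10, desc='toy', amp=True, amp_dtype=torch.bfloat16)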
-def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False, verbose=True):
+def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False,
+                     amp_dtype=torch.float16, cpu=False, verbose=True, **kwinputs):
     """ Wrap benchmark functions in Pytorch profiler to see CUDA information. """
     if backward:
-        g = torch.randn_like(fn(*inputs))
-    for _ in range(10): # Warm up
-        with torch.autocast(device_type='cuda', enabled=amp):
+        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
+            g = torch.randn_like(fn(*inputs, **kwinputs))
+    for _ in range(30): # Warm up
+        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
             if backward:
                 for x in inputs:
                     if isinstance(x, torch.Tensor):
                         x.grad = None
-            fn(*inputs) if not backward else fn(*inputs).backward(g)
+            # fn(*inputs, **kwinputs) if not backward else fn(*inputs, **kwinputs).backward(g)
+            out = fn(*inputs, **kwinputs)
+        # Backward should be done outside autocast
+        if backward:
+            out.backward(g)
+    activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [torch.profiler.ProfilerActivity.CUDA]
     with torch.profiler.profile(
-        activities=[torch.profiler.ProfilerActivity.CUDA,],
+        # activities=[torch.profiler.ProfilerActivity.CPU, torch.profiler.ProfilerActivity.CUDA,],
+        activities=activities,
         record_shapes=True,
         # profile_memory=True,
         with_stack=True,
     ) as prof:
-        with torch.autocast(device_type='cuda', enabled=amp):
+        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
             if backward:
                 for x in inputs:
                     if isinstance(x, torch.Tensor):
                         x.grad = None
-            fn(*inputs) if not backward else fn(*inputs).backward(g)
+            out = fn(*inputs, **kwinputs)
+        if backward:
+            out.backward(g)
     if verbose:
-        print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50))
+        # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50))
+        print(prof.key_averages().table(row_limit=50))
     if trace_filename is not None:
         prof.export_chrome_trace(trace_filename)
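The profiler wrapper now accepts extra keyword inputs, an amp_dtype, and a cpu flag that adds ProfilerActivity.CPU to the profiled activities, and it can still export a Chrome trace. A hedged usage sketch (the trace filename and toy workload are illustrative):

    import torch
    from flash_attn.utils.benchmark import pytorch_profiler

    x = torch.randn(1024, 1024, device='cuda', requires_grad=True)
    pytorch_profiler(lambda x: (x @ x).sum(), x,
                     trace_filename='toy_trace.json',  # open in chrome://tracing or Perfetto
                     backward=True, amp=True, cpu=True)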
@@ -124,6 +141,6 @@ def benchmark_memory(fn, *inputs, desc='', verbose=True, **kwinputs):
     torch.cuda.synchronize()
     mem = torch.cuda.max_memory_allocated() / ((2 ** 20) * 1000)
     if verbose:
-        print(f'{desc} max memory: ', mem)
+        print(f'{desc} max memory: {mem}GB')
     torch.cuda.empty_cache()
     return mem
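The benchmark_memory change only touches the report string: torch.cuda.max_memory_allocated() returns bytes, and dividing by (2 ** 20) * 1000 expresses it in units of 1000 MiB, which the new message labels GB. A small check of that conversion (the byte count below is made up):

    bytes_allocated = 8 * 1024 ** 3       # pretend the allocator peaked at 8 GiB
    mem = bytes_allocated / ((2 ** 20) * 1000)
    print(f'max memory: {mem}GB')         # prints 8.192GB (8 GiB in 1000-MiB units)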