gaoqiong / flash-attention · Commits · 60499abc
"...git@developer.sourcefind.cn:modelzoo/yolo11_pytorch.git" did not exist on "a74dc9a0390d8903281065c5a1a578c44ca0cb68"
Commit 60499abc, authored Jul 28, 2023 by Tri Dao

[Benchmark] Add script to benchmark FlashAttention

Parent: 32a953f4
Showing 2 changed files with 187 additions and 71 deletions.
benchmarks/benchmark_flash_attention.py    +142 -44
flash_attn/utils/benchmark.py              +45 -27
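The new script has no command-line interface: the sweep it runs (batch size and sequence length pairs from (32, 512) to (1, 16384), head dimensions 64 and 128, causal and non-causal) is hard-coded at module level, so on a machine with a CUDA GPU it would presumably be launched simply as python benchmarks/benchmark_flash_attention.py. The Triton and xformers baselines are benchmarked only when those packages import successfully.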
benchmarks/benchmark_flash_attention.py  (+142 -44, view file @ 60499abc)

-from functools import partial
+# Install the newest triton version with
+# pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python"
+import pickle
 import math
 import torch
 import torch.nn as nn
@@ -6,65 +8,161 @@ import torch.nn.functional as F
 from einops import rearrange, repeat
 
-from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward, benchmark_combined
-from flash_attn.bert_padding import unpad_input, pad_input
-from flash_attn.flash_attn_interface import flash_attn_varlen_qkvpacked_func
+from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
+from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined
+from flash_attn import flash_attn_qkvpacked_func
 
-def attention_ref(qkv, attn_mask, dropout_p, upcast=False, causal=False):
+try:
+    from triton.ops.flash_attention import attention as attention_triton
+except ImportError:
+    attention_triton = None
+
+try:
+    import xformers.ops as xops
+except ImportError:
+    xops = None
+
+
+def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
+    assert mode in ["fwd", "bwd", "fwd_bwd"]
+    f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
+    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)
+
+def efficiency(flop, time):
+    return (flop / time / 10**12) if not math.isnan(time) else 0.0
+
+
+def attention_pytorch(qkv, dropout_p=0.0, causal=True):
     """
     Arguments:
         qkv: (batch_size, seqlen, 3, nheads, head_dim)
-        attn_mask: (batch_size, seqlen)
         dropout_p: float
     Output:
         output: (batch_size, seqlen, nheads, head_dim)
-        attention: softmax after dropout
     """
-    q, k, v = (qkv.float() if upcast else qkv).unbind(dim=2)
-    seqlen = qkv.shape[1]
-    d = qkv.shape[-1]
-    scores = torch.einsum('bthd,bshd->bhts', q, k / math.sqrt(d))
-    scores.masked_fill_(rearrange(~attn_mask, 'b s -> b 1 1 s'), float('-inf'))
+    batch_size, seqlen, _, nheads, d = qkv.shape
+    q, k, v = qkv.unbind(dim=2)
+    q = rearrange(q, 'b t h d -> (b h) t d')
+    k = rearrange(k, 'b s h d -> (b h) d s')
+    softmax_scale = 1.0 / math.sqrt(d)
+    # Preallocate attn_weights for `baddbmm`
+    scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
+    scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
+                       '(b h) t s -> b h t s', h=nheads)
     if causal:
-        causal_mask = torch.triu(torch.ones(seqlen, seqlen, dtype=torch.bool, device=qkv.device), 1)
-        scores.masked_fill_(causal_mask, float('-inf'))
+        # "triu_tril_cuda_template" not implemented for 'BFloat16'
+        # So we have to construct the mask in float
+        causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
+        # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
+        scores = scores + causal_mask.to(dtype=scores.dtype)
     attention = torch.softmax(scores, dim=-1)
     attention_drop = F.dropout(attention, dropout_p)
     output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
-    # return output.to(dtype=qkv.dtype), attention.to(dtype=qkv.dtype)
     return output.to(dtype=qkv.dtype)
 
-torch.manual_seed(0)
+
+def time_fwd_bwd(func, *args, **kwargs):
+    time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
+    return time_f[1].mean, time_b[1].mean
+
+
 repeats = 30
-batch_size = 64
-nheads = 16
-seqlen = 1024
-n = 1024
-d = n // nheads
-dropout_p = 0.1
-causal = False
-dtype = torch.float16
 device = 'cuda'
+dtype = torch.float16
+
+bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 16384)]
+causal_vals = [False, True]
+headdim_vals = [64, 128]
+dim = 2048
+dropout_p = 0.0
+
+methods = (["Flash2", "Pytorch"]
+           + (["Triton"] if attention_triton is not None else [])
+           + (["xformers"] if xops is not None else []))
+
+time_f = {}
+time_b = {}
+time_f_b = {}
+speed_f = {}
+speed_b = {}
+speed_f_b = {}
+for causal in causal_vals:
+    for headdim in headdim_vals:
+        for batch_size, seqlen in bs_seqlen_vals:
+            config = (causal, headdim, batch_size, seqlen)
+            nheads = dim // headdim
+            qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype,
+                              requires_grad=True)
+            f, b = time_fwd_bwd(
+                flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats,
+                verbose=False
+            )
+            time_f[config, "Flash2"] = f
+            time_b[config, "Flash2"] = b
+
+            try:
+                qkv = qkv.detach().requires_grad_(True)
+                f, b = time_fwd_bwd(
+                    attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats,
+                    verbose=False
+                )
+            except:  # Skip if OOM
+                f, b = float('nan'), float('nan')
+            time_f[config, "Pytorch"] = f
+            time_b[config, "Pytorch"] = b
+
+            if attention_triton is not None:
+                q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device,
+                                       dtype=dtype, requires_grad=True) for _ in range(3)]
+                # Try both values of sequence_parallel and pick the faster one
+                try:
+                    f, b = time_fwd_bwd(
+                        attention_triton, q, k, v, causal, headdim**(-0.5), False,
+                        repeats=repeats, verbose=False
+                    )
+                except:
+                    f, b = float('nan'), float('inf')
+                try:
+                    _, b0 = time_fwd_bwd(
+                        attention_triton, q, k, v, causal, headdim**(-0.5), True,
+                        repeats=repeats, verbose=False
+                    )
+                except:
+                    b0 = float('inf')
+                time_f[config, "Triton"] = f
+                time_b[config, "Triton"] = min(b, b0) if min(b, b0) < float('inf') else float('nan')
+
+            if xops is not None:
+                q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device,
+                                       dtype=dtype, requires_grad=True) for _ in range(3)]
+                f, b = time_fwd_bwd(
+                    xops.memory_efficient_attention, q, k, v,
+                    attn_bias=xops.LowerTriangularMask() if causal else None,
+                    op=(xops.fmha.cutlass.FwOp, xops.fmha.cutlass.BwOp)
+                )
+                time_f[config, "xformers"] = f
+                time_b[config, "xformers"] = b
+
+            print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
+            for method in methods:
+                time_f_b[config, method] = time_f[config, method] + time_b[config, method]
+                speed_f[config, method] = efficiency(
+                    flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
+                    time_f[config, method]
+                )
+                speed_b[config, method] = efficiency(
+                    flops(batch_size, seqlen, headdim, nheads, causal, mode="bwd"),
+                    time_b[config, method]
+                )
+                speed_f_b[config, method] = efficiency(
+                    flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd_bwd"),
+                    time_f_b[config, method]
+                )
+                print(
+                    f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, "
+                    f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, "
+                    f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s"
+                )
 
-x = torch.randn(batch_size, seqlen, n, device='cuda', dtype=dtype, requires_grad=True)
-Wqkv = torch.nn.Linear(nheads * d, 3 * nheads * d, device=device, dtype=dtype)
-lengths = torch.randint(seqlen - 20, seqlen, (batch_size, 1), device='cuda')
-attention_mask_bool = repeat(torch.arange(seqlen, device='cuda'), 's -> b s', b=batch_size) < lengths
-attention_mask = torch.zeros(batch_size, seqlen, device='cuda', dtype=dtype)
-attention_mask[~attention_mask_bool] = -10000.0
-attention_mask = rearrange(attention_mask, 'b s -> b 1 1 s')
-x_unpad, indices, cu_seqlens, max_seqlen_in_batch = unpad_input(x, attention_mask_bool)
-qkv_unpad = rearrange(Wqkv(x_unpad), 'nnz (t h d) -> nnz t h d', t=3, h=nheads).detach().requires_grad_()
-qkv = rearrange(Wqkv(x), 'b s (t h d) -> b s t h d', t=3, h=nheads).detach().requires_grad_()
-fn = lambda qkv_unpad: flash_attn_varlen_qkvpacked_func(qkv_unpad, cu_seqlens, max_seqlen_in_batch, dropout_p, causal=causal)
-benchmark_all(fn, qkv_unpad, repeats=repeats, desc='FlashAttention')
-fn = lambda qkv: attention_ref(qkv, attention_mask_bool, dropout_p, causal=causal)
-benchmark_all(fn, qkv, repeats=repeats, desc='PyTorch Standard Attention')
+
+# with open('flash2_attn_time.plk', 'wb') as fp:
+#     pickle.dump((speed_f, speed_b, speed_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL)
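As an aside on the numbers the script prints: the flops() helper above counts only the two attention matmuls (QK^T and PV) at 2 FLOPs per multiply-add, halves the count under a causal mask, and scales the backward pass by 2.5x (roughly five matmuls in the backward versus two in the forward). A small self-contained sketch of that accounting, using an example configuration of my own choosing rather than anything measured in the commit:

# Illustration only: reproduce the FLOP accounting used by flops()/efficiency()
# for one hypothetical configuration (batch=32, seqlen=512, headdim=64, dim=2048).
batch, seqlen, headdim = 32, 512, 64
nheads = 2048 // headdim
fwd = 4 * batch * seqlen**2 * nheads * headdim   # non-causal forward matmul FLOPs
bwd = 2.5 * fwd                                  # backward pass, ~2.5x the forward matmul work
# efficiency() turns measured seconds into TFLOPs/s; e.g. a hypothetical 1 ms forward pass:
print(f"{fwd / 1e-3 / 10**12:.1f} TFLOPs/s")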
flash_attn/utils/benchmark.py  (+45 -27, view file @ 60499abc)

-# Copyright (c) 2022, Tri Dao.
+# Copyright (c) 2023, Tri Dao.
 """ Useful functions for writing test code. """
 
 import torch
@@ -10,14 +10,12 @@ def benchmark_forward(fn, *inputs, repeats=10, desc='', verbose=True, amp=False,
     """ Use Pytorch Benchmark on the forward pass of an arbitrary function. """
     if verbose:
         print(desc, '- Forward pass')
-    def fn_amp(*inputs, **kwinputs):
+    def amp_wrapper(*inputs, **kwinputs):
         with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
             fn(*inputs, **kwinputs)
-    for _ in range(repeats):  # warmup
-        fn_amp(*inputs, **kwinputs)
     t = benchmark.Timer(
             stmt='fn_amp(*inputs, **kwinputs)',
-            globals={'fn_amp': fn_amp, 'inputs': inputs, 'kwinputs': kwinputs},
+            globals={'fn_amp': amp_wrapper, 'inputs': inputs, 'kwinputs': kwinputs},
             num_threads=torch.get_num_threads(),
             )
     m = t.timeit(repeats)
@@ -40,13 +38,18 @@ def benchmark_backward(fn, *inputs, grad=None, repeats=10, desc='', verbose=True
     else:
         if grad.shape != y.shape:
             raise RuntimeError('Grad shape does not match output shape')
-    for _ in range(repeats):  # warmup
+    def f(*inputs, y, grad):
+        # Set .grad to None to avoid extra operation of gradient accumulation
+        for x in inputs:
+            if isinstance(x, torch.Tensor):
+                x.grad = None
         y.backward(grad, retain_graph=True)
     t = benchmark.Timer(
-            stmt='y.backward(grad, retain_graph=True)',
-            globals={'y': y, 'grad': grad},
+            stmt='f(*inputs, y=y, grad=grad)',
+            globals={'f': f, 'inputs': inputs, 'y': y, 'grad': grad},
             num_threads=torch.get_num_threads(),
             )
     m = t.timeit(repeats)
     if verbose:
         print(m)
@@ -58,19 +61,24 @@ def benchmark_combined(fn, *inputs, grad=None, repeats=10, desc='', verbose=True
     """ Use Pytorch Benchmark on the forward+backward pass of an arbitrary function. """
     if verbose:
         print(desc, '- Forward + Backward pass')
-    with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
-        y = fn(*inputs, **kwinputs)
-        if type(y) is tuple:
-            y = y[0]
-    if grad is None:
-        grad = torch.randn_like(y)
-    else:
-        if grad.shape != y.shape:
-            raise RuntimeError('Grad shape does not match output shape')
     def f(grad, *inputs, **kwinputs):
+        for x in inputs:
+            if isinstance(x, torch.Tensor):
+                x.grad = None
         with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
             y = fn(*inputs, **kwinputs)
             if type(y) is tuple:
                 y = y[0]
+        if grad is None:
+            grad = torch.randn_like(y)
+        else:
+            if grad.shape != y.shape:
+                raise RuntimeError('Grad shape does not match output shape')
         y.backward(grad, retain_graph=True)
+    for _ in range(repeats):  # warmup
+        f(grad, *inputs, **kwinputs)
     t = benchmark.Timer(
             stmt='f(grad, *inputs, **kwinputs)',
             globals={'f': f, 'fn': fn, 'inputs': inputs, 'grad': grad, 'kwinputs': kwinputs},
@@ -82,6 +90,17 @@ def benchmark_combined(fn, *inputs, grad=None, repeats=10, desc='', verbose=True
     return t, m
 
 
+def benchmark_fwd_bwd(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
+                      amp_dtype=torch.float16, **kwinputs):
+    """ Use Pytorch Benchmark on the forward+backward pass of an arbitrary function. """
+    return (
+        benchmark_forward(fn, *inputs, repeats=repeats, desc=desc, verbose=verbose,
+                          amp=amp, amp_dtype=amp_dtype, **kwinputs),
+        benchmark_backward(fn, *inputs, grad=grad, repeats=repeats, desc=desc, verbose=verbose,
+                           amp=amp, amp_dtype=amp_dtype, **kwinputs),
+    )
+
+
 def benchmark_all(fn, *inputs, grad=None, repeats=10, desc='', verbose=True, amp=False,
                   amp_dtype=torch.float16, **kwinputs):
     """ Use Pytorch Benchmark on the forward+backward pass of an arbitrary function. """
@@ -102,16 +121,15 @@ def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False
     with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
         g = torch.randn_like(fn(*inputs, **kwinputs))
     for _ in range(30):   # Warm up
+        if backward:
+            for x in inputs:
+                if isinstance(x, torch.Tensor):
+                    x.grad = None
        with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
-            if backward:
-                for x in inputs:
-                    if isinstance(x, torch.Tensor):
-                        x.grad = None
-            # fn(*inputs, **kwinputs) if not backward else fn(*inputs, **kwinputs).backward(g)
             out = fn(*inputs, **kwinputs)
         # Backward should be done outside autocast
         if backward:
-            out.backward(g)
+            out.backward(g, retain_graph=True)
     activities = ([torch.profiler.ProfilerActivity.CPU] if cpu else []) + [torch.profiler.ProfilerActivity.CUDA]
     with torch.profiler.profile(
         activities=activities,
@@ -119,13 +137,13 @@ def pytorch_profiler(fn, *inputs, trace_filename=None, backward=False, amp=False
         # profile_memory=True,
         with_stack=True,
     ) as prof:
+        if backward:
+            for x in inputs:
+                if isinstance(x, torch.Tensor):
+                    x.grad = None
         with torch.autocast(device_type='cuda', dtype=amp_dtype, enabled=amp):
-            if backward:
-                for x in inputs:
-                    if isinstance(x, torch.Tensor):
-                        x.grad = None
             out = fn(*inputs, **kwinputs)
         if backward:
-            out.backward(g)
+            out.backward(g, retain_graph=True)
     if verbose:
         # print(prof.key_averages().table(sort_by="self_cuda_time_total", row_limit=50))
         print(prof.key_averages().table(row_limit=50))
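To show how the new benchmark_fwd_bwd helper is consumed (it is what time_fwd_bwd wraps in the benchmark script), here is a minimal sketch, assuming flash-attn is installed and a CUDA GPU is available; the tensor shapes are my own choice, not from the commit:

# Minimal usage sketch of benchmark_fwd_bwd (not part of the commit).
import torch
from flash_attn import flash_attn_qkvpacked_func
from flash_attn.utils.benchmark import benchmark_fwd_bwd

qkv = torch.randn(8, 2048, 3, 16, 64, device='cuda', dtype=torch.float16, requires_grad=True)
# Returns ((timer, measurement) for the forward pass, (timer, measurement) for the backward);
# measurement.mean is the average time per call in seconds.
(_, fwd_m), (_, bwd_m) = benchmark_fwd_bwd(
    flash_attn_qkvpacked_func, qkv, 0.0, causal=False, repeats=30, verbose=False
)
print(f"fwd {fwd_m.mean * 1e3:.3f} ms, bwd {bwd_m.mean * 1e3:.3f} ms")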