gaoqiong / flash-attention · Commits
"vscode:/vscode.git/clone" did not exist on "53014d34810779994c035388a3de93a78e16a804"
Commit 3aae9c18, authored Jul 25, 2024 by Tri Dao

Revert "Changes For FP8 (#1075)"

This reverts commit 1899c970.

Parent: 1899c970
Showing 14 changed files with 89 additions and 986 deletions (+89 / −986):

    hopper/benchmark_flash_attention.py        +0   −281
    hopper/benchmark_flash_attention_fp8.py    +0   −339
    hopper/epilogue_fwd_sm90_tma.hpp           +1   −2
    hopper/flash_api.cpp                       +11  −34
    hopper/flash_attn_interface.py             +2   −2
    hopper/flash_fwd_hdim128_fp8_sm90.cu       +0   −9
    hopper/flash_fwd_hdim256_fp8_sm90.cu       +0   −9
    hopper/flash_fwd_hdim64_fp8_sm90.cu        +0   −9
    hopper/flash_fwd_launch_template.h         +8   −14
    hopper/kernel_traits.h                     +7   −25
    hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp   +19  −98
    hopper/setup.py                            +0   −3
    hopper/test_flash_attn.py                  +41  −79
    hopper/utils.h                             +0   −82
hopper/benchmark_flash_attention.py — deleted (100644 → 0), view file @ 1899c970

# Install the newest triton version with
# pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python"
import pickle
import math
import time
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat

from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined

from flash_attn import flash_attn_qkvpacked_func
from flash_attn_interface import flash_attn_func

try:
    from triton.ops.flash_attention import attention as attention_triton
except ImportError:
    attention_triton = None

try:
    import xformers.ops as xops
except ImportError:
    xops = None

try:
    import cudnn
except ImportError:
    cudnn = None


def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
    assert mode in ["fwd", "bwd", "fwd_bwd"]
    f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)

def efficiency(flop, time):
    return (flop / time / 10**12) if not math.isnan(time) else 0.0


def convert_to_cudnn_type(torch_type):
    if torch_type == torch.float16:
        return cudnn.data_type.HALF
    elif torch_type == torch.bfloat16:
        return cudnn.data_type.BFLOAT16
    elif torch_type == torch.float32:
        return cudnn.data_type.FLOAT
    elif torch_type == torch.int32:
        return cudnn.data_type.INT32
    elif torch_type == torch.int64:
        return cudnn.data_type.INT64
    else:
        raise ValueError("Unsupported tensor data type.")


def cudnn_spda_setup(q, k, v, causal=False):
    b, nheads, seqlen_q, headdim = q.shape
    _, _, seqlen_k, _ = k.shape
    assert v.shape == (b, nheads, seqlen_k, headdim)
    assert cudnn is not None, 'CUDNN is not available'
    q_gpu, k_gpu, v_gpu = q, k, v
    o_gpu = torch.empty_like(q_gpu)
    stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=q.device)
    graph = cudnn.pygraph(
        io_data_type=convert_to_cudnn_type(q.dtype),
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
    )
    q = graph.tensor_like(q_gpu.detach())
    k = graph.tensor_like(k_gpu.detach())
    v = graph.tensor_like(v_gpu.detach())

    o, stats = graph.sdpa(
        name="sdpa",
        q=q,
        k=k,
        v=v,
        is_inference=False,
        attn_scale=1.0 / math.sqrt(headdim),
        use_causal_mask=causal,
    )

    o.set_output(True).set_dim(o_gpu.shape).set_stride(o_gpu.stride())
    stats.set_output(True).set_data_type(cudnn.data_type.FLOAT)

    graph.validate()
    graph.build_operation_graph()
    graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
    graph.check_support()
    graph.build_plans()

    variant_pack = {
        q: q_gpu,
        k: k_gpu,
        v: v_gpu,
        o: o_gpu,
        stats: stats_gpu,
    }

    workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)

    def run(*args, **kwargs):
        graph.execute(variant_pack, workspace)
        return o_gpu

    return run


def attention_pytorch(qkv, dropout_p=0.0, causal=True):
    """
    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, head_dim)
        dropout_p: float
    Output:
        output: (batch_size, seqlen, nheads, head_dim)
    """
    batch_size, seqlen, _, nheads, d = qkv.shape
    q, k, v = qkv.unbind(dim=2)
    q = rearrange(q, 'b t h d -> (b h) t d')
    k = rearrange(k, 'b s h d -> (b h) d s')
    softmax_scale = 1.0 / math.sqrt(d)
    # Preallocate attn_weights for `baddbmm`
    scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
    scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
                       '(b h) t s -> b h t s', h=nheads)
    if causal:
        # "triu_tril_cuda_template" not implemented for 'BFloat16'
        # So we have to construct the mask in float
        causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
        # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
        scores = scores + causal_mask.to(dtype=scores.dtype)
    attention = torch.softmax(scores, dim=-1)
    attention_drop = F.dropout(attention, dropout_p)
    output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
    return output.to(dtype=qkv.dtype)


def time_fwd_bwd(func, *args, **kwargs):
    time.sleep(1)  # Sleep to avoid residual power throttling from the previous benchmark
    time_f, time_b = benchmark_fwd_bwd(func, *args, **kwargs)
    return time_f[1].mean, time_b[1].mean


repeats = 30
device = 'cuda'
dtype = torch.float16

# Ideally, seq-len should be divisible by 132 to avoid wave quantization.
# However, the existing Triton implementation doesn't support seq-len like 8448.
bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192)]
# bs_seqlen_vals = [(2, 8192)]
causal_vals = [False]
# headdim_vals = [64, 128]
headdim_vals = [128]
dim = 128
dropout_p = 0.0

methods = (["Flash2", "Pytorch", "Flash3"]
           + (["Triton"] if attention_triton is not None else [])
           + (["xformers.c"] if xops is not None else [])
           + (["xformers.f"] if xops is not None else [])
           + (["cudnn"] if cudnn is not None else []))

time_f = {}
time_b = {}
time_f_b = {}
speed_f = {}
speed_b = {}
speed_f_b = {}
for causal in causal_vals:
    for headdim in headdim_vals:
        for batch_size, seqlen in bs_seqlen_vals:
            config = (causal, headdim, batch_size, seqlen)
            nheads = dim // headdim
            qkv = torch.randn(batch_size, seqlen, 3, nheads, headdim, device=device, dtype=dtype, requires_grad=True)
            f, b = time_fwd_bwd(flash_attn_qkvpacked_func, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False)
            time_f[config, "Flash2"] = f
            time_b[config, "Flash2"] = b

            try:
                qkv = qkv.detach().requires_grad_(True)
                f, b = time_fwd_bwd(attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False)
                res_baseline = attention_pytorch(qkv, dropout_p, causal=causal)
            except:  # Skip if OOM
                f, b = float('nan'), float('nan')
            time_f[config, "Pytorch"] = f
            time_b[config, "Pytorch"] = b

            q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True) for _ in range(3)]
            f, b = time_fwd_bwd(flash_attn_func, q, k, v, causal=causal, repeats=repeats, verbose=False)
            res = flash_attn_func(q, k, v, causal=causal)
            time_f[config, "Flash3"] = f
            time_b[config, "Flash3"] = b

            if cudnn is not None:
                time.sleep(1)  # Sleep to avoid residual power throttling from the previous benchmark
                res = benchmark_forward(
                    cudnn_spda_setup(q.transpose(1, 2), k.transpose(1, 2), v.transpose(1, 2), causal=causal),
                    repeats=repeats, verbose=False
                )
                f = res[1].mean
                time_f[config, "cudnn"] = f
                time_b[config, "cudnn"] = math.inf

            if attention_triton is not None:
                q, k, v = [torch.randn(batch_size, nheads, seqlen, headdim, device=device, dtype=dtype, requires_grad=True) for _ in range(3)]
                # Try both values of sequence_parallel and pick the faster one
                try:
                    f, b = time_fwd_bwd(attention_triton, q, k, v, causal, headdim**(-0.5), False, repeats=repeats, verbose=False)
                except:
                    f, b = float('nan'), float('inf')
                try:
                    _, b0 = time_fwd_bwd(attention_triton, q, k, v, causal, headdim**(-0.5), True, repeats=repeats, verbose=False)
                except:
                    b0 = float('inf')
                time_f[config, "Triton"] = f
                time_b[config, "Triton"] = min(b, b0) if min(b, b0) < float('inf') else float('nan')

            if xops is not None:
                q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True) for _ in range(3)]
                f, b = time_fwd_bwd(
                    xops.memory_efficient_attention, q, k, v,
                    attn_bias=xops.LowerTriangularMask() if causal else None,
                    op=(xops.fmha.cutlass.FwOp, xops.fmha.cutlass.BwOp)
                )
                time_f[config, "xformers.c"] = f
                time_b[config, "xformers.c"] = b

            if xops is not None:
                q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=dtype, requires_grad=True) for _ in range(3)]
                f, b = time_fwd_bwd(
                    xops.memory_efficient_attention, q, k, v,
                    attn_bias=xops.LowerTriangularMask() if causal else None,
                    op=(xops.fmha.flash.FwOp, xops.fmha.flash.BwOp)
                )
                time_f[config, "xformers.f"] = f
                time_b[config, "xformers.f"] = b

            print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
            for method in methods:
                time_f_b[config, method] = time_f[config, method] + time_b[config, method]
                speed_f[config, method] = efficiency(
                    flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
                    time_f[config, method]
                )
                speed_b[config, method] = efficiency(
                    flops(batch_size, seqlen, headdim, nheads, causal, mode="bwd"),
                    time_b[config, method]
                )
                speed_f_b[config, method] = efficiency(
                    flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd_bwd"),
                    time_f_b[config, method]
                )
                #print (time_f[config,method])
                print(
                    f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, "
                    f"bwd: {speed_b[config, method]:.2f} TFLOPs/s, "
                    f"fwd + bwd: {speed_f_b[config, method]:.2f} TFLOPs/s"
                )

# with open('flash2_attn_time.plk', 'wb') as fp:
#     pickle.dump((speed_f, speed_b, speed_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL)
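The deleted benchmark above reports throughput rather than raw latency: `flops` counts roughly 4 · batch · seqlen² · nheads · headdim floating-point operations for a non-causal forward pass (two matmuls at 2·M·N·K each), halves that when causal masking skips half the score matrix, and scales it by 2.5x or 3.5x for the backward and combined passes; `efficiency` then divides by the measured time to get TFLOPs/s. A minimal sanity check of that arithmetic, assuming the two helpers are defined exactly as in the deleted file above (the 1 ms timing is illustrative, not a measured result):

    # Hypothetical numbers, only to exercise flops()/efficiency() from the file above.
    batch, seqlen, headdim, nheads = 2, 8192, 128, 16
    fwd = 4 * batch * seqlen**2 * nheads * headdim      # non-causal forward FLOPs
    assert fwd == flops(batch, seqlen, headdim, nheads, causal=False, mode="fwd")
    print(efficiency(fwd, 1e-3))   # a 1.0 ms forward would be ~1099.5 TFLOPs/s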
hopper/benchmark_flash_attention_fp8.py — deleted (100644 → 0), view file @ 1899c970

# Install the newest triton version with
# pip install "git+https://github.com/openai/triton.git#egg=triton&subdirectory=python"
import pickle
import math
import time
import torch
import torch.nn as nn
import torch.nn.functional as F

from einops import rearrange, repeat

from flash_attn.utils.benchmark import benchmark_all, benchmark_forward, benchmark_backward
from flash_attn.utils.benchmark import benchmark_fwd_bwd, benchmark_combined

from flash_attn import flash_attn_qkvpacked_func
from flash_attn_interface import flash_attn_func

try:
    from triton_fused_attention import attention as attention_triton
except ImportError:
    attention_triton = None

try:
    import xformers.ops as xops
except ImportError:
    xops = None

try:
    import cudnn
except ImportError:
    cudnn = None


def convert_to_cudnn_type(torch_type):
    if torch_type == torch.float16:
        return cudnn.data_type.HALF
    elif torch_type == torch.bfloat16:
        return cudnn.data_type.BFLOAT16
    elif torch_type == torch.float32:
        return cudnn.data_type.FLOAT
    elif torch_type == torch.int32:
        return cudnn.data_type.INT32
    elif torch_type == torch.int64:
        return cudnn.data_type.INT64
    elif torch_type == torch.float8_e4m3fn:
        return cudnn.data_type.FP8_E4M3
    elif torch_type == torch.float8_e4m3fn:
        return cudnn.data_type.FP8_E5M2
    else:
        raise ValueError("Unsupported tensor data type.")


def cudnn_spda_setup(qkv, seqlen_q, seqlen_k, causal=False):
    b, _, _, nheads, headdim = qkv.shape
    assert cudnn is not None, 'CUDNN is not available'
    o_gpu = torch.zeros(b, seqlen_q, nheads, headdim, dtype=qkv.dtype, device=qkv.device)
    o_gpu_transposed = torch.as_strided(
        o_gpu,
        [b, nheads, seqlen_q, headdim],
        [nheads * seqlen_q * headdim, headdim, nheads * headdim, 1],
    )
    stats_gpu = torch.empty(b, nheads, seqlen_q, 1, dtype=torch.float32, device=qkv.device)
    amax_s_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=qkv.device)
    amax_o_gpu = torch.empty(1, 1, 1, 1, dtype=torch.float32, device=qkv.device)
    graph = cudnn.pygraph(
        io_data_type=convert_to_cudnn_type(qkv.dtype),
        intermediate_data_type=cudnn.data_type.FLOAT,
        compute_data_type=cudnn.data_type.FLOAT,
    )
    new_q = torch.as_strided(
        qkv,
        [b, nheads, seqlen_q, headdim],
        [seqlen_q * nheads * headdim * 3, headdim, headdim * nheads * 3, 1],
        storage_offset=0,
    )
    q = graph.tensor(
        name="Q",
        dim=list(new_q.shape),
        stride=list(new_q.stride()),
        data_type=convert_to_cudnn_type(qkv.dtype)
    )
    new_k = torch.as_strided(
        qkv,
        [b, nheads, seqlen_k, headdim],
        [seqlen_k * nheads * headdim * 3, headdim, headdim * nheads * 3, 1],
        storage_offset=nheads * headdim,
    )
    k = graph.tensor(
        name="K",
        dim=list(new_k.shape),
        stride=list(new_k.stride()),
        data_type=convert_to_cudnn_type(qkv.dtype)
    )
    new_v = torch.as_strided(
        qkv,
        [b, nheads, seqlen_k, headdim],
        [seqlen_k * nheads * headdim * 3, headdim, headdim * nheads * 3, 1],
        storage_offset=nheads * headdim * 2,
    )
    v = graph.tensor(
        name="V",
        dim=list(new_v.shape),
        stride=list(new_v.stride()),
        data_type=convert_to_cudnn_type(qkv.dtype)
    )

    def get_default_scale_tensor():
        return graph.tensor(
            dim=[1, 1, 1, 1],
            stride=[1, 1, 1, 1],
            data_type=cudnn.data_type.FLOAT
        )

    default_scale_gpu = torch.ones(1, 1, 1, 1, dtype=torch.float32, device="cuda")
    descale_q = get_default_scale_tensor()
    descale_k = get_default_scale_tensor()
    descale_v = get_default_scale_tensor()
    descale_s = get_default_scale_tensor()
    scale_s = get_default_scale_tensor()
    scale_o = get_default_scale_tensor()

    o, _, amax_s, amax_o = graph.sdpa_fp8(
        q=q,
        k=k,
        v=v,
        descale_q=descale_q,
        descale_k=descale_k,
        descale_v=descale_v,
        descale_s=descale_s,
        scale_s=scale_s,
        scale_o=scale_o,
        is_inference=True,
        attn_scale=1.0 / math.sqrt(headdim),
        use_causal_mask=causal,
        name="sdpa",
    )

    o.set_output(True).set_dim(o_gpu_transposed.shape).set_stride(o_gpu_transposed.stride())

    amax_s.set_output(False).set_dim(amax_s_gpu.shape).set_stride(amax_s_gpu.stride())
    amax_o.set_output(False).set_dim(amax_o_gpu.shape).set_stride(amax_o_gpu.stride())
    # stats.set_output(True).set_data_type(cudnn.data_type.FLOAT)

    graph.validate()
    graph.build_operation_graph()
    graph.create_execution_plans([cudnn.heur_mode.A, cudnn.heur_mode.FALLBACK])
    graph.check_support()
    graph.build_plans()

    variant_pack = {
        q: new_q,
        k: new_k,
        v: new_v,
        descale_q: default_scale_gpu,
        descale_k: default_scale_gpu,
        descale_v: default_scale_gpu,
        descale_s: default_scale_gpu,
        scale_s: default_scale_gpu,
        scale_o: default_scale_gpu,
        o: o_gpu_transposed,
        amax_s: amax_s_gpu,
        amax_o: amax_o_gpu,
    }

    workspace = torch.empty(graph.get_workspace_size(), device="cuda", dtype=torch.uint8)

    def run(*args, **kwargs):
        graph.execute(variant_pack, workspace)
        return o_gpu, amax_o_gpu

    return run


def attention_pytorch(qkv, dropout_p=0.0, causal=True):
    """
    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, head_dim)
        dropout_p: float
    Output:
        output: (batch_size, seqlen, nheads, head_dim)
    """
    batch_size, seqlen, _, nheads, d = qkv.shape
    q, k, v = qkv.unbind(dim=2)
    q = rearrange(q, 'b t h d -> (b h) t d')
    k = rearrange(k, 'b s h d -> (b h) d s')
    softmax_scale = 1.0 / math.sqrt(d)
    # Preallocate attn_weights for `baddbmm`
    scores = torch.empty(batch_size * nheads, seqlen, seqlen, dtype=qkv.dtype, device=qkv.device)
    scores = rearrange(torch.baddbmm(scores, q, k, beta=0, alpha=softmax_scale),
                       '(b h) t s -> b h t s', h=nheads)
    if causal:
        # "triu_tril_cuda_template" not implemented for 'BFloat16'
        # So we have to construct the mask in float
        causal_mask = torch.triu(torch.full((seqlen, seqlen), -10000.0, device=scores.device), 1)
        # TD [2022-09-30]: Adding is faster than masked_fill_ (idk why, just better kernel I guess)
        scores = scores + causal_mask.to(dtype=scores.dtype)
    attention = torch.softmax(scores, dim=-1)
    attention_drop = F.dropout(attention, dropout_p)
    output = torch.einsum('bhts,bshd->bthd', attention_drop, v)
    return output.to(dtype=qkv.dtype)


def flops(batch, seqlen, headdim, nheads, causal, mode="fwd"):
    assert mode in ["fwd", "bwd", "fwd_bwd"]
    f = 4 * batch * seqlen**2 * nheads * headdim // (2 if causal else 1)
    return f if mode == "fwd" else (2.5 * f if mode == "bwd" else 3.5 * f)

def efficiency(flop, time):
    return (flop / time / 10**12) if not math.isnan(time) else 0.0


def time_fwd(func, *args, **kwargs):
    time.sleep(1)  # Sleep to avoid residual power throttling from the previous benchmark
    time_f = benchmark_forward(func, *args, **kwargs)
    return time_f[1].mean


torch.manual_seed(0)

repeats = 30
device = 'cuda'
# dtype = torch.float16
dtype = torch.float8_e4m3fn

#bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4224), (2, 8448), (1, 8448 * 2)]
bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 8192 * 2)]
#bs_seqlen_vals = [(4, 4224), (2, 8448), (1, 8448 * 2)]
# bs_seqlen_vals = [(32, 512), (16, 1024), (8, 2048), (4, 4096), (2, 8192), (1, 8192 * 2)]
# bs_seqlen_vals = [(4, 8448)]
causal_vals = [False, True]
#headdim_vals = [64, 128, 256]
headdim_vals = [128, 256]
dim = 2048
# dim = 128
dropout_p = 0.0

methods = (["Pytorch", "Flash3", "cuDNN"]
           + (["Triton"] if attention_triton is not None else [])
           # + (["xformers.c"] if xops is not None else [])
           # + (["xformers.f"] if xops is not None else [])
           )

time_f = {}
time_b = {}
time_f_b = {}
speed_f = {}
speed_b = {}
speed_f_b = {}
for causal in causal_vals:
    for headdim in headdim_vals:
        for batch_size, seqlen in bs_seqlen_vals:
            torch.cuda.empty_cache()
            config = (causal, headdim, batch_size, seqlen)
            nheads = dim // headdim
            q, k, v = [torch.randn(batch_size, seqlen, nheads, headdim, device=device, dtype=torch.float16, requires_grad=False) for _ in range(3)]
            qkv = torch.stack([q, k, v], dim=2)
            qkv = qkv.to(torch.float16)
            f = time_fwd(attention_pytorch, qkv, dropout_p, causal=causal, repeats=repeats, verbose=False)
            time_f[config, "Pytorch"] = f
            res_baseline = attention_pytorch(qkv, dropout_p, causal=causal)

            if attention_triton is not None:
                q_transposed = q.transpose(1, 2).contiguous().to(torch.float8_e4m3fn)
                k_transposed = k.transpose(1, 2).contiguous().to(torch.float8_e4m3fn)
                v_transposed = v.transpose(1, 2).contiguous().permute(0, 1, 3, 2).to(torch.float8_e4m3fn)
                scale = 1 / math.sqrt(headdim)
                f = time_fwd(attention_triton, q_transposed, k_transposed, v_transposed, causal, scale, repeats=5, verbose=False, desc='Triton')
                f = time_fwd(attention_triton, q_transposed, k_transposed, v_transposed, causal, scale, repeats=repeats, verbose=False, desc='Triton')
                time_f[config, "Triton"] = f
                res = attention_triton(q_transposed, k_transposed, v_transposed.permute(0, 1, 3, 2), causal, scale).half().transpose(1, 2)
                torch.testing.assert_close(res, res_baseline, atol=0.5, rtol=0.5)

            out = torch.empty_like(q)
            q, k, v = q.to(dtype), k.to(dtype), v.to(dtype)
            v_transposed = v.transpose(1, 3).contiguous().clone()
            #v_transposed = v.transpose(1,3).clone()
            time.sleep(1)
            f = time_fwd(flash_attn_func, q, k, v_transposed, causal=causal, repeats=repeats, verbose=False)

            # res = flash_attn_func(q, k, v, causal=causal, is_fp16_acc=False)
            # torch.testing.assert_close(res.half(), res_baseline, atol=0.05, rtol=0.05)

            time_f[config, "Flash3"] = f

            if cudnn is not None:
                qkv_fp8 = qkv.to(dtype)
                time.sleep(1)  # Sleep to avoid residual power throttling from the previous benchmark
                f = time_fwd(
                    cudnn_spda_setup(qkv_fp8, seqlen, seqlen, causal=causal),
                    repeats=repeats, verbose=False
                )
                time_f[config, "cuDNN"] = f
                # res, amax_o = cudnn_spda_setup(
                #     qkv_fp8, seqlen, seqlen,
                #     causal=causal
                # )()
                # res = res.half()
                # TODO: CUDNN has numerics issues when
                # num_heads=16, dim=128, seq_len=1024, batch_size=2
                # or larger sizes.
                # res_cpu = res.cpu().reshape(-1)
                # res_baseline_cpu = res_baseline.cpu().reshape(-1)
                # print(amax_o)
                # print(res)
                # print(res_baseline)
                # for i in range(len(res_cpu)):
                #     item = res_cpu[i]
                #     item_baseline = res_baseline_cpu[i]
                #     if abs(item - item_baseline) > 0.5:
                #         print(i)
                #         print(item)
                #         print(item_baseline)
                # torch.testing.assert_close(res, res_baseline, atol=0.05, rtol=0.05)

            print(f"### causal={causal}, headdim={headdim}, batch_size={batch_size}, seqlen={seqlen} ###")
            for method in methods:
                speed_f[config, method] = efficiency(
                    flops(batch_size, seqlen, headdim, nheads, causal, mode="fwd"),
                    time_f[config, method]
                )
                #print (time_f[config,method])
                print(
                    f"{method} fwd: {speed_f[config, method]:.2f} TFLOPs/s, {time_f[config, method] * 1e3} ms, "
                )

# with open('flash3_attn_time.plk', 'wb') as fp:
#     pickle.dump((time_f, time_b, time_f_b), fp, protocol=pickle.HIGHEST_PROTOCOL)
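One detail of the deleted FP8 benchmark worth spelling out: `cudnn_spda_setup` never copies Q, K and V out of the packed `qkv` tensor of shape (batch, seqlen, 3, nheads, headdim). It builds (batch, nheads, seqlen, headdim) views directly on the packed storage with `torch.as_strided`, using element strides [seqlen·nheads·headdim·3, headdim, nheads·headdim·3, 1] and storage offsets 0, nheads·headdim and 2·nheads·headdim for Q, K and V respectively. A small, self-contained check of that stride arithmetic (the shapes here are arbitrary and unrelated to the benchmark configs):

    import torch

    b, s, h, d = 2, 8, 3, 4
    qkv = torch.randn(b, s, 3, h, d)
    strides = [s * h * d * 3, d, h * d * 3, 1]
    q_view = torch.as_strided(qkv, [b, h, s, d], strides, storage_offset=0)
    k_view = torch.as_strided(qkv, [b, h, s, d], strides, storage_offset=h * d)
    v_view = torch.as_strided(qkv, [b, h, s, d], strides, storage_offset=2 * h * d)
    # The strided views should match a plain unbind followed by a (b, h, s, d) permute.
    q, k, v = qkv.unbind(dim=2)
    assert torch.equal(q_view, q.permute(0, 2, 1, 3))
    assert torch.equal(k_view, k.permute(0, 2, 1, 3))
    assert torch.equal(v_view, v.permute(0, 2, 1, 3))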
hopper/epilogue_fwd_sm90_tma.hpp — view file @ 3aae9c18

...
@@ -20,8 +20,7 @@ using namespace cute;

template <typename Ktraits, typename Seqlen_traits>
struct CollectiveEpilogueFwd {

    using PrecType = typename Ktraits::Element;
    using Element = decltype(cute::conditional_return<is_same_v<PrecType, cutlass::float_e4m3_t>>(cutlass::half_t{}, PrecType{}));
    using Element = typename Ktraits::Element;

    static constexpr int kBlockM = Ktraits::kBlockM;
    static constexpr int kBlockN = Ktraits::kBlockN;
    static constexpr int kHeadDim = Ktraits::kHeadDim;
...
hopper/flash_api.cpp — view file @ 3aae9c18

...
@@ -249,13 +249,7 @@ void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split
            }
        }
    } else {
        if (params.d == 64) {
            run_mha_fwd_<cutlass::float_e4m3_t, 64>(params, stream);
        } else if (params.d == 128) {
            run_mha_fwd_<cutlass::float_e4m3_t, 128>(params, stream);
        } else {
            run_mha_fwd_<cutlass::float_e4m3_t, 256>(params, stream);
        }
        // run_mha_fwd_<cutlass::float_e4m3_t, 128>(params, stream);
    }
}
...
@@ -272,8 +266,12 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
    TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer.");

    auto q_dtype = q.dtype();
    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16 || q_dtype == torch::kFloat8_e4m3fn,
                "FlashAttention only support fp16, bf16 and fp8 (e4m3) data type for now");
    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
                "FlashAttention only support fp16 and bf16 data type for now");
    // TODO: will add e4m3 later
    // TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kFloat8_e4m3fn,
    //             "FlashAttention only support fp16 and bf16 data type");
    //             "FlashAttention only support fp16 and fp8 (e4m3) data type for now");
    TORCH_CHECK(k.dtype() == q_dtype, "query and key must have the same dtype");
    TORCH_CHECK(v.dtype() == q_dtype, "query and value must have the same dtype");
...
@@ -303,50 +301,29 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
    CHECK_SHAPE(q, batch_size, seqlen_q, num_heads, head_size_og);
    CHECK_SHAPE(k, batch_size, seqlen_k, num_heads_k, head_size_og);
    if (q_dtype == torch::kFloat8_e4m3fn) {
        CHECK_SHAPE(v, batch_size, head_size_og, num_heads_k, seqlen_k);
    } else {
        CHECK_SHAPE(v, batch_size, seqlen_k, num_heads_k, head_size_og);
    }

    at::Tensor q_padded, k_padded, v_padded;
    if (q_dtype == torch::kFloat8_e4m3fn) {
        if (head_size_og % 16 != 0) {
            q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 16 - head_size_og % 16}));
            k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 16 - head_size_og % 16}));
        } else {
            q_padded = q;
            k_padded = k;
        }
        if (seqlen_k % 16 != 0) {
            v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 16 - seqlen_k % 16}));
        } else {
            v_padded = v;
        }
    } else {
    if (head_size_og % 8 != 0) {
        if (head_size_og % 8 != 0) {
            q_padded = torch::nn::functional::pad(q, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
            k_padded = torch::nn::functional::pad(k, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
            v_padded = torch::nn::functional::pad(v, torch::nn::functional::PadFuncOptions({0, 8 - head_size_og % 8}));
        } else {
    } else {
            q_padded = q;
            k_padded = k;
            v_padded = v;
        }
    }

    at::Tensor out;
    if (out_.has_value()) {
        out = out_.value();
        // TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
        TORCH_CHECK(out.dtype() == q_dtype, "Output must have the same dtype as inputs");
        CHECK_DEVICE(out);
        TORCH_CHECK(out.stride(-1) == 1, "Output tensor must have contiguous last dimension");
        CHECK_SHAPE(out, batch_size, seqlen_q, num_heads, head_size_og);
        if (head_size_og % 8 != 0) { out = torch::empty_like(q_padded); }
    } else {
        out = q_dtype == torch::kFloat8_e4m3fn
            ? torch::empty_like(q_padded, at::kHalf)
            : torch::empty_like(q_padded);
        out = torch::empty_like(q_padded);
    }

    auto round_multiple = [](int x, int m) { return (x + m - 1) / m * m; };
...
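For readers following the `mha_fwd` changes above: the API only handles head sizes that are a multiple of 8 (the reverted FP8 path required a multiple of 16), so it zero-pads the last dimension of q, k and v up to the next multiple and rounds with the `round_multiple` lambda, `(x + m - 1) / m * m` in integer arithmetic. A rough Python equivalent of that padding step, for illustration only (the real work is done in C++ inside the extension, and the head size here is hypothetical):

    import torch
    import torch.nn.functional as F

    def round_multiple(x: int, m: int) -> int:
        return (x + m - 1) // m * m

    head_size_og = 100                        # hypothetical head dimension
    pad_to = round_multiple(head_size_og, 8)  # -> 104
    q = torch.randn(2, 512, 16, head_size_og)
    if head_size_og % 8 != 0:
        q = F.pad(q, (0, pad_to - head_size_og))   # zero-pad the last dim on the right
    assert q.shape[-1] == pad_to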
hopper/flash_attn_interface.py — view file @ 3aae9c18

...
@@ -15,7 +15,7 @@ def maybe_contiguous(x):
    return x.contiguous() if x is not None and x.stride(-1) != 1 else x


def _flash_attn_forward(q, k, v, softmax_scale, causal):
    # q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    q, k, v = [maybe_contiguous(x) for x in (q, k, v)]
    out, q, k, v, out_padded, softmax_lse, S_dmask = flashattn_hopper_cuda.fwd(
        q,
        k,
...
@@ -41,7 +41,7 @@ def _flash_attn_backward(
    causal
):
    # dq, dk, dv are allocated by us so they should already be contiguous
    # dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    dout, q, k, v, out = [maybe_contiguous(x) for x in (dout, q, k, v, out)]
    dq, dk, dv, softmax_d, = flashattn_hopper_cuda.bwd(
        dout,
        q,
...
hopper/flash_fwd_hdim128_fp8_sm90.cu — deleted (100644 → 0), view file @ 1899c970

// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_launch_template.h"

template<>
void run_mha_fwd_<cutlass::float_e4m3_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim128<cutlass::float_e4m3_t>(params, stream);
}
hopper/flash_fwd_hdim256_fp8_sm90.cu — deleted (100644 → 0), view file @ 1899c970

// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_launch_template.h"

template<>
void run_mha_fwd_<cutlass::float_e4m3_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim256<cutlass::float_e4m3_t>(params, stream);
}
hopper/flash_fwd_hdim64_fp8_sm90.cu — deleted (100644 → 0), view file @ 1899c970

// Copyright (c) 2024, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_launch_template.h"

template<>
void run_mha_fwd_<cutlass::float_e4m3_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim64<cutlass::float_e4m3_t>(params, stream);
}
hopper/flash_fwd_launch_template.h — view file @ 3aae9c18

...
@@ -21,7 +21,6 @@
template<typename Kernel_traits, bool Is_causal, typename Seqlen_traits>
void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
    using Element = typename Kernel_traits::Element;
    using ElementO = decltype(cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(cutlass::half_t{}, Element{}));
    using TileShape_MNK = typename Kernel_traits::TileShape_MNK;
    using ClusterShape = typename Kernel_traits::ClusterShape_MNK;
...
@@ -128,14 +127,10 @@ void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
    SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
        // Only use Cluster if number of tiles along seqlen_q is even and not Is_causal
        BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
            if constexpr (is_same_v<T, cutlass::float_e4m3_t>) {
                //run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 192, 128, 16, 3, false, !Is_causal && UseCluster ? 2 : 1, T>, Is_causal>(params, stream);
                //run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 192, 128, 16, 2, false, !Is_causal && UseCluster ? 2 : 1, T>, Is_causal>(params, stream);
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 12, 4, false, !Is_causal && UseCluster ? 2 : 1, T>, Is_causal, Seqlen_traits>(params, stream);
                //run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 12, 4, false, !Is_causal && UseCluster ? 2 : 1, T>, Is_causal>(params, stream);
            } else {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, Is_causal ? 128 : 176, 12, 2, false, UseCluster ? 2 : 1, T>, Is_causal, Seqlen_traits>(params, stream);
            }
            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, Is_causal ? 128 : 176, 12, 2, false, UseCluster ? 2 : 1, T>, Is_causal, Seqlen_traits>(params, stream);
        });
    });
});
...
@@ -148,11 +143,10 @@ void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
    SEQLEN_SWITCH(params.cu_seqlens_q, Seqlen_traits, [&] {
        // Only use Cluster if number of tiles along seqlen_q is even
        BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0 && !Is_causal && !Seqlen_traits::kUseVarSeqLen, UseCluster, [&] {
            if constexpr (is_same_v<T, cutlass::float_e4m3_t>) {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 128, 12, 3, false, !Is_causal && UseCluster ? 2 : 1, T>, Is_causal, Seqlen_traits>(params, stream);
            } else {
                run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 80, 12, 2, false, UseCluster ? 2 : 1, T>, Is_causal, Seqlen_traits>(params, stream);
            }
            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 80, 12, 2, false, UseCluster ? 2 : 1, T>, Is_causal, Seqlen_traits>(params, stream);
        });
    });
});
...
hopper/kernel_traits.h — view file @ 3aae9c18

...
@@ -25,7 +25,6 @@ struct SharedStorageQKVO {
        cute::array_aligned<OutputType, cute::cosize_v<SmemLayoutO>> smem_o;
    };
    struct {
        cute::uint64_t tma_load_mbar[4]; // 4 TMA barriers pre-allocated for usage.
        cutlass::arch::ClusterTransactionBarrier barrier_Q;
        cutlass::arch::ClusterBarrier barrier_O;
        typename cutlass::PipelineTmaAsync<kStages>::SharedStorage pipeline_k;
...
@@ -41,7 +40,6 @@ struct Flash_fwd_kernel_traits {
    using Element = elem_type;
    using ElementAccum = float;
    using index_t = int64_t;
    using ElementO = decltype(cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(cutlass::half_t{}, Element{}));

    // The number of threads.
    static constexpr int kNWarps = kNWarps_;
...
@@ -71,11 +69,9 @@ struct Flash_fwd_kernel_traits {
        decltype(cute::GMMA::ss_op_selector<Element, Element, ElementAccum, TileShape_MNK>())>{}, AtomLayoutMNK{}));

    using TiledMma1 = decltype(cute::make_tiled_mma(
        cute::GMMA::rs_op_selector<Element, Element, ElementAccum, decltype(select<0, 2, 1>(TileShape_MNK{})),
                                   GMMA::Major::K,
                                   cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(GMMA::Major::K, GMMA::Major::MN)>(),
                                   GMMA::Major::K, GMMA::Major::MN>(),
        AtomLayoutMNK{}));

    using SmemLayoutAtomQ = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
...
@@ -88,33 +84,19 @@ struct Flash_fwd_kernel_traits {
        decltype(tile_to_shape(SmemLayoutAtomK{},
                 make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));

    using SmemLayoutAtomVFp16 = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
    using SmemLayoutAtomV = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutVFp16 = decltype(tile_to_shape(SmemLayoutAtomVFp16{},
    using SmemLayoutV = decltype(tile_to_shape(SmemLayoutAtomV{},
                 make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));

    using SmemLayoutAtomVFp8 = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<2>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
    using SmemLayoutVFp8 = decltype(tile_to_shape(SmemLayoutAtomVFp8{},
                 make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int<kStages>{})));
    using SmemLayoutV = decltype(cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(SmemLayoutVFp8{}, SmemLayoutVFp16{}));

    // Note this is the transpose in terms of the view, not in terms of memory.
    using SmemLayoutVtFp16 =
        decltype(cute::composition(SmemLayoutVFp16{},
                    make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
                                make_stride(get<1>(TileShape_MNK{}), _1{}, Int<size(SmemLayoutVFp16{}(_, _, _0{}))>{}))));
    using SmemLayoutVt = decltype(cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(SmemLayoutVFp8{}, SmemLayoutVtFp16{}));

    using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, ElementO,
    using SmemLayoutAtomO = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>());
    using SmemLayoutO = decltype(tile_to_shape(SmemLayoutAtomO{}, select<0, 2>(TileShape_MNK{})));

    using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, ElementO>;
    using SmemCopyAtomQ = Copy_Atom<cute::SM75_U32x4_LDSM_N, Element>;

    using SharedStorage = SharedStorageQKVO<kStages, Element, Element, ElementO, SmemLayoutQ,
    using SharedStorage = SharedStorageQKVO<kStages, Element, Element, Element, SmemLayoutQ,
                                            SmemLayoutK, SmemLayoutV, SmemLayoutO>;

    using MainloopPipeline = typename cutlass::PipelineTmaAsync<kStages>;
...
hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp — view file @ 3aae9c18

...
@@ -43,30 +43,12 @@ struct CollectiveMainloopFwd {
    using SmemLayoutK =
        decltype(tile_to_shape(SmemLayoutAtomK{},
                 make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int<kStages>{})));

    using SmemLayoutAtomVFp8 = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<2>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
    using SmemLayoutVFp8 = decltype(tile_to_shape(SmemLayoutAtomVFp8{},
                 make_shape(shape<2>(TileShape_MNK{}), shape<1>(TileShape_MNK{}), Int<kStages>{})));
    using SmemLayoutVFp16 = SmemLayoutK;
    using SmemLayoutV = SmemLayoutK;
    // Note this is the transpose in terms of the view, not in terms of memory.
    using SmemLayoutVtFp16 =
        decltype(cute::composition(SmemLayoutVFp16{},
    using SmemLayoutVt =
        decltype(cute::composition(SmemLayoutV{},
                    make_layout(make_shape(get<2>(TileShape_MNK{}), get<1>(TileShape_MNK{}), Int<kStages>{}),
                                make_stride(get<1>(TileShape_MNK{}), _1{}, Int<size(SmemLayoutVFp16{}(_, _, _0{}))>{}))));
    using SmemLayoutV = decltype(cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(SmemLayoutVFp8{}, SmemLayoutVFp16{}));
    using SmemLayoutVt = decltype(cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(SmemLayoutVFp8{}, SmemLayoutVtFp16{}));

    // Dummy S layout for getting the shape for GEMM-II.
    using SmemLayoutAtomS = decltype(cutlass::gemm::collective::detail::ss_smem_selector<GMMA::Major::K, Element,
        decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<1>(TileShape_MNK{}))>());
    using SmemLayoutS = decltype(tile_to_shape(SmemLayoutAtomS{}, make_shape(shape<0>(TileShape_MNK{}), shape<1>(TileShape_MNK{}))));
                                make_stride(get<1>(TileShape_MNK{}), _1{}, Int<size(SmemLayoutV{}(_, _, _0{}))>{}))));
    // using SmemLayoutAtomVt = cute::GMMA::Layout_MN_SW128_Atom<Element>;
    // using SmemLayoutVt =
    //     decltype(tile_to_shape(SmemLayoutAtomVt{},
...
@@ -103,19 +85,6 @@ struct CollectiveMainloopFwd {
            take<0, 2>(SmemLayoutK{}),
            select<1, 2>(TileShape_MNK{}),
            size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
    //
    using TileShapeVFP8 = decltype(make_shape(cute::get<2>(TileShape_MNK{}), cute::get<1>(TileShape_MNK{})));
    using TileShapeVFP16 = decltype(make_shape(cute::get<1>(TileShape_MNK{}), cute::get<2>(TileShape_MNK{})));
    using TileShapeV = decltype(cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(TileShapeVFP8{}, TileShapeVFP16{}));
    using TMA_VFP8 = decltype(make_tma_copy(
        GmemTiledCopyKV{},
        make_tensor(make_gmem_ptr(static_cast<Element const*>(nullptr)), repeat_like(StrideQKV{}, int32_t(0)), StrideQKV{}),
        take<0, 2>(SmemLayoutV{}),
        TileShapeV{},
        size<0>(ClusterShape{}))); // mcast along M mode for this N load, if any
    using TMA_V = decltype(cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(TMA_VFP8{}, TMA_KV{}));

    static constexpr int NumMmaThreads = size(typename Ktraits::TiledMma0{});
    using MainloopPipeline = typename Ktraits::MainloopPipeline;
...
@@ -128,7 +97,6 @@ struct CollectiveMainloopFwd {
    static constexpr bool UseSchedulerBarrier = kHeadDim <= 128;

    // Host side kernel arguments
    struct Arguments {
        Element const* ptr_Q;
...
@@ -147,8 +115,7 @@ struct CollectiveMainloopFwd {
        typename Seqlen_traits::LayoutT layout_V;
        cutlass::FastDivmod qhead_per_khead_divmod;
        TMA_Q tma_load_Q;
        TMA_KV tma_load_K;
        TMA_V tma_load_V;
        TMA_KV tma_load_K, tma_load_V;
        float const softmax_scale_log2;
    };
...
@@ -169,15 +136,12 @@ struct CollectiveMainloopFwd {
            SmemLayoutK{}(_, _, _0{}),
            select<1, 2>(TileShape_MNK{}),
            size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
        auto gmemLayoutVFp16 = args.shape_K;
        auto gmemLayoutVFp8 = select<1, 0, 2, 3>(gmemLayoutVFp16);
        auto gmemLayoutV = cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(gmemLayoutVFp8, gmemLayoutVFp16);
        Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), gmemLayoutV, args.layout_V.stride());
        TMA_V tma_load_V = make_tma_copy(
        Tensor mV = make_tensor(make_gmem_ptr(args.ptr_V), args.layout_V);
        TMA_KV tma_load_V = make_tma_copy(
            GmemTiledCopyKV{},
            mV,
            SmemLayoutV{}(_, _, _0{}),
            cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(select<2, 1>(TileShape_MNK{}), select<1, 2>(TileShape_MNK{})),
            select<1, 2>(TileShape_MNK{}),
            size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
        return {args.layout_Q, args.layout_K, args.layout_V,
                cutlass::FastDivmod(cute::ceil_div(get<2>(args.layout_Q.shape()), get<2>(args.layout_K.shape()))),
...
@@ -234,10 +198,7 @@ struct CollectiveMainloopFwd {
        Tensor mQ = mainloop_params.tma_load_Q.get_tma_tensor(mainloop_params.layout_Q.shape());
        Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.layout_K.shape());
        auto gmemLayoutVFp16 = mainloop_params.shape_K;
        auto gmemLayoutVFp8 = select<1, 0, 2, 3>(gmemLayoutVFp16);
        auto gmemLayoutV = cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(gmemLayoutVFp8, gmemLayoutVFp16);
        Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(gmemLayoutV);
        Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.layout_V.shape());

        auto [m_block, bidh, bidb] = block_coord;
        int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh);
...
@@ -246,34 +207,12 @@ struct CollectiveMainloopFwd {
        uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
        constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
        uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
        Tensor gQ = local_tile(mQ(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{}));  // (M, K)
        Tensor gK = local_tile(mK(_, _, bidh_kv, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (N, K, _)
        Tensor gV = local_tile(mV(_, _, bidh_kv, bidb), TileShapeV{},
            cute::conditional_return<is_same_v<Element, cutlass::float_e4m3_t>>(make_coord(_0{}, _), make_coord(_, _0{})));  // (N, K, _)
#if 0
        if (threadIdx.x == 0 && blockIdx.x == 0) {
            print ("\n");
            print (gV);
            print ("\n");
            print (gK);
            print ("\n");
            print ("\n");
            print (sV);
            print ("\n");
            print (sK);
            print ("\n");
            print (gmemLayoutVFp8);
            print ("\n");
            print (gmemLayoutVFp16);
        }
        // Tensor gQ = seqlen_traits_q.get_local_tile_tensor(
        //     mQ, select<0, 2>(TileShape_MNK{}), bidh, bidb)(_, _, m_block);  // (M, K)
        // Tensor gK = seqlen_traits_k.get_local_tile_tensor(
        //     mK, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb);  // (N, K, _)
        // Tensor gV = seqlen_traits_k.get_local_tile_tensor(
        //     mV, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb);  // (N, K, _)
        Tensor gQ = seqlen_traits_q.get_local_tile_tensor(
            mQ, select<0, 2>(TileShape_MNK{}), bidh, bidb)(_, _, m_block);  // (M, K)
        Tensor gK = seqlen_traits_k.get_local_tile_tensor(
            mK, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb);  // (N, K, _)
        Tensor gV = seqlen_traits_k.get_local_tile_tensor(
            mV, select<1, 2>(TileShape_MNK{}), bidh_kv, bidb);  // (N, K, _)

        Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
        Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
...
@@ -430,13 +369,6 @@ struct CollectiveMainloopFwd {
        // Note: S becomes P.
        Tensor tOrV = threadMma1.partition_fragment_B(sVt);
        // Dummy sS to just get the shape correctly for GEMM-II.
        Tensor sS = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutS{});
        Tensor tOrS = threadMma1.partition_fragment_A(sS);
        Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
        ReorgCFp8toAFp8 reg2reg;
        auto tOrPLayout = ReshapeTStoTP()(tSrS, tOrS);

        auto consumer_wait = [](auto& pipeline, auto& smem_pipe_read) {
            auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read);
            pipeline.consumer_wait(smem_pipe_read, barrier_token);
...
@@ -450,6 +382,7 @@ struct CollectiveMainloopFwd {
        cutlass::ConsumerToken barrier_token = static_cast<cutlass::BarrierStatus>(shared_storage.barrier_Q.try_wait(work_idx % 2));
        if (barrier_token == cutlass::BarrierStatus::WaitAgain) { shared_storage.barrier_Q.wait(work_idx % 2); }

        Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
        consumer_wait(pipeline_k, smem_pipe_read_k);
        warp_scheduler_barrier_sync();
        flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
...
@@ -491,11 +424,7 @@ struct CollectiveMainloopFwd {
        }
        softmax.template online_softmax</*Is_first=*/true>(tSrS, mainloop_params.softmax_scale_log2);

        auto tSrSPrec = convert_type<Element>(tSrS);
        if constexpr (is_same_v<Element, cutlass::float_e4m3_t>) {
            reg2reg(tSrSPrec);
        }
        Tensor tOrP = make_tensor(tSrSPrec.data(), tOrPLayout);
        Tensor tOrP = make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout()));
        Tensor scores_scale = make_fragment_like(softmax.row_max);
        clear(scores_scale);
...
@@ -527,11 +456,7 @@ struct CollectiveMainloopFwd {
            pipeline_v.consumer_release(smem_pipe_read_v);  // release V
            ++smem_pipe_read_k;
            ++smem_pipe_read_v;

            auto tSrSPrec = convert_type<Element>(tSrS);
            if constexpr (is_same_v<Element, cutlass::float_e4m3_t>) {
                reg2reg(tSrSPrec);
            }
            cute::copy(make_tensor(tSrSPrec.data(), tOrPLayout), tOrP);
            cute::copy(make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout())), tOrP);
        }

        #pragma unroll 1
...
@@ -554,11 +479,7 @@ struct CollectiveMainloopFwd {
            ++smem_pipe_read_k;
            ++smem_pipe_read_v;
            // softmax.rescale_o(tOrO, scores_scale);

            auto tSrSPrec = convert_type<Element>(tSrS);
            if constexpr (is_same_v<Element, cutlass::float_e4m3_t>) {
                reg2reg(tSrSPrec);
            }
            cute::copy(make_tensor(tSrSPrec.data(), tOrPLayout), tOrP);
            cute::copy(make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout())), tOrP);
        }

        // Tell warp 0 that smem_q is ready
        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
...
hopper/setup.py — view file @ 3aae9c18

...
@@ -116,9 +116,6 @@ if not SKIP_CUDA_BUILD:
                    "flash_fwd_hdim128_bf16_sm90.cu",
                    "flash_fwd_hdim256_fp16_sm90.cu",
                    "flash_fwd_hdim256_bf16_sm90.cu",
                    "flash_fwd_hdim64_fp8_sm90.cu",
                    "flash_fwd_hdim128_fp8_sm90.cu",
                    "flash_fwd_hdim256_fp8_sm90.cu",
                    "flash_bwd_hdim64_fp16_sm90.cu",
                    "flash_bwd_hdim128_fp16_sm90.cu",
                    "flash_bwd_hdim256_fp16_sm90.cu",
...
hopper/test_flash_attn.py — view file @ 3aae9c18

...
@@ -170,7 +170,7 @@ def test_flash_attn_output(
        (113, 211),
        (108, 256),
        (256, 512),
        (384, 256),
        (384, 256),
        (512, 256),
        (640, 128),
        (1024, 1024),
...
@@ -261,87 +261,49 @@ def test_flash_attn_varlen_output(
        reorder_ops=True,
    )
    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")


@pytest.mark.parametrize("dtype", [torch.float8_e4m3fn])
# @pytest.mark.parametrize("dtype", [torch.bfloat16])
@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
# @pytest.mark.parametrize("mha_type", ["gqa"])
@pytest.mark.parametrize("causal", [False, True])
# @pytest.mark.parametrize("causal", [True])
# @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
# @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
# @pytest.mark.parametrize('d', [56, 80])
@pytest.mark.parametrize("d", [64, 128, 256])
#@pytest.mark.parametrize("d", [128])
# @pytest.mark.parametrize("d", [256])
@pytest.mark.parametrize(
    "seqlen_q,seqlen_k",
    [
        (64, 128),
        (128, 128),
        (256, 256),
        (113, 203),
        (128, 217),
        (113, 211),
        (108, 256),
        (256, 512),
        (384, 256),
        (640, 128),
        (512, 256),
        (1024, 1024),
        (1023, 1024),
        (1024, 1023),
        (2048, 2048),
    ],
)
def test_flash_attn_output_fp8(seqlen_q, seqlen_k, d, causal, mha_type, dtype):
    device = "cuda"
    # set seed
    torch.random.manual_seed(0)
    # batch_size = 40
    # nheads = 16
    batch_size = 9
    nheads = 6
    nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
    # batch_size = 1
    # nheads = 1
    q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=torch.float16, requires_grad=True)
    k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=torch.float16, requires_grad=True)
    v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=torch.float16, requires_grad=True)
    out, lse = flash_attn_func(q.to(dtype), k.to(dtype), v.to(dtype).transpose(1, 3).contiguous().clone(), causal=causal)
    q = q.to(dtype).to(torch.float16)
    k = k.to(dtype).to(torch.float16)
    v = v.to(dtype).to(torch.float16)
    out_ref, attn_ref = attention_ref(
        q,
        k,
        v,
        None,
        None,
        causal=causal,
    )
    out_pt, attn_pt = attention_ref(
        q,
        k,
        v,
        None,
        None,
        causal=causal,
        upcast=False,
        reorder_ops=True,
    )

    print(f"Output max diff: {(out - out_ref).abs().max().item()}")
    print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
    print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
    print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")

    # g = torch.randn_like(out)
    # if d <= 128:
    #     (
    #         dq_unpad,
    #         dk_unpad,
    #         dv_unpad,
    #     ) = torch.autograd.grad(out, (q_unpad, k_unpad, v_unpad), g)
    #     dk = dk_pad_fn(dk_unpad)
    #     dv = dk_pad_fn(dv_unpad)
    #     (
    #         dq_ref,
    #         dk_ref,
    #         dv_ref,
    #     ) = torch.autograd.grad(out_ref, (q, k, v), g)
    #     (
    #         dq_pt,
    #         dk_pt,
    #         dv_pt,
    #     ) = torch.autograd.grad(out_pt, (q, k, v), g)
    #     dq = dq_pad_fn(dq_unpad)
    #     print(f"dQ max diff: {(dq - dq_ref).abs().max().item()}")
    #     print(f"dK max diff: {(dk - dk_ref).abs().max().item()}")
    #     print(f"dV max diff: {(dv - dv_ref).abs().max().item()}")
    #     print(f"dQ mean diff: {(dq - dq_ref).abs().mean().item()}")
    #     print(f"dK mean diff: {(dk - dk_ref).abs().mean().item()}")
    #     print(f"dV mean diff: {(dv - dv_ref).abs().mean().item()}")
    #     print(f"dQ Pytorch max diff: {(dq_pt - dq_ref).abs().max().item()}")
    #     print(f"dK Pytorch max diff: {(dk_pt - dk_ref).abs().max().item()}")
    #     print(f"dV Pytorch max diff: {(dv_pt - dv_ref).abs().max().item()}")
    #     print(f"dQ Pytorch mean diff: {(dq_pt - dq_ref).abs().mean().item()}")
    #     print(f"dK Pytorch mean diff: {(dk_pt - dk_ref).abs().mean().item()}")
    #     print(f"dV Pytorch mean diff: {(dv_pt - dv_ref).abs().mean().item()}")

    # Check that FlashAttention's numerical error is at most twice the numerical error
    # of a Pytorch implementation.
    assert (out - out_ref).abs().max().item() <= 2 * (out_pt - out_ref).abs().max().item()

    # if d <= 128:
    #     assert (dq - dq_ref).abs().max().item() < 1e-4 or (dq - dq_ref).abs().max().item() <= 3 * (dq_pt - dq_ref).abs().max().item()
    #     assert (dk - dk_ref).abs().max().item() < 1e-4 or (dk - dk_ref).abs().max().item() <= 3 * (dk_pt - dk_ref).abs().max().item()
    #     assert (dk - dk_ref).abs().max().item() < 1e-4 or (dv - dv_ref).abs().max().item() <= 3 * (dv_pt - dv_ref).abs().max().item()
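The removed FP8 test keeps the acceptance rule used throughout this test file: rather than a fixed tolerance, the kernel's maximum error against an upcast (fp32) reference must be no more than twice the maximum error of a plain low-precision PyTorch implementation against the same reference. A stripped-down sketch of that criterion, with stand-in tensors instead of real attention outputs:

    import torch

    def within_flash_tolerance(out, out_pt, out_ref, factor=2.0):
        # out: kernel under test, out_pt: low-precision PyTorch baseline,
        # out_ref: high-precision reference; all three have the same shape.
        err = (out - out_ref).abs().max().item()
        err_pt = (out_pt - out_ref).abs().max().item()
        return err <= factor * err_pt

    out_ref = torch.randn(4, 128, 8, 64)
    out_pt = out_ref + 1e-3 * torch.randn_like(out_ref)   # pretend baseline error
    out = out_ref + 5e-4 * torch.randn_like(out_ref)      # pretend kernel error
    assert within_flash_tolerance(out, out_pt, out_ref)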
hopper/utils.h — view file @ 3aae9c18

...
@@ -228,88 +228,6 @@ __forceinline__ __device__ void copy(TiledCopy tiled_copy, Tensor<Engine0, Layou
}

////////////////////////////////////////////////////////////////////////////////////////////////////
//
//
// Need this register byte permute/shuffle to match register layout of
// (FP8 downcasted) accumulator of GEMM-I to FP8 operand A of GEMM-II.
struct ReorgCFp8toAFp8 {
    int selectorEx0;
    int selectorEx1;
    int selectorEx4;
    int selectorEx5;
    int upper_map[4] = {0, 3, 1, 2};
    int lower_map[4] = {1, 2, 0, 3};

    CUTLASS_DEVICE ReorgCFp8toAFp8() {
        int laneId = cutlass::canonical_lane_idx();
        if (laneId % 4 == 0 || laneId % 4 == 3) {
            selectorEx0 = 0x3210;
            selectorEx1 = 0x7654;
            selectorEx4 = 0x5410;
            selectorEx5 = 0x7632;
        } else {
            selectorEx0 = 0x7654;
            selectorEx1 = 0x3210;
            selectorEx4 = 0x1054;
            selectorEx5 = 0x3276;
        }
    }

    template <typename Fragment>
    CUTLASS_DEVICE auto operator()(Fragment &accum) {
        using namespace cute;
        // First update `mi` to the max per-row
        //
        auto VT = shape<0>(accum); // number of vector elements per tile.
        auto MT = shape<1>(accum); // number of tiles along M.
        auto NT = shape<2>(accum); // number of tiles along N.
        auto data = accum.data();
        int n = 0;
        #pragma unroll
        for (int i = 0; i < MT; ++i) {
            // Traverse 2-rows + 2-cols (2x2) simultaneously.
            #pragma unroll
            for (int k = 0; k < NT * size<2>(VT) / 2; ++k) {
                auto upper = *reinterpret_cast<uint32_t*>(&data[n]);
                auto lower = *reinterpret_cast<uint32_t*>(&data[n + 4]);
                auto upper0 = __byte_perm(upper, lower, selectorEx0);
                auto lower0 = __byte_perm(upper, lower, selectorEx1);
                upper0 = __shfl_sync(uint32_t(-1), upper0, upper_map[threadIdx.x % 4], 4);
                lower0 = __shfl_sync(uint32_t(-1), lower0, lower_map[threadIdx.x % 4], 4);
                uint32_t* data_32bit = reinterpret_cast<uint32_t*>(&data[n]);
                data_32bit[0] = __byte_perm(upper0, lower0, selectorEx4);
                data_32bit[1] = __byte_perm(upper0, lower0, selectorEx5);
                n += 8;
            }
        }
    }
};

// Reshape Utility for converting the layout from accumulator of GEMM-I
// to Operand A of GEMM-II.
struct ReshapeTStoTP {
    template <class FragmentC, class FragmentQ>
    CUTLASS_DEVICE auto operator()(FragmentC &&tC, FragmentQ &&tQ) {
        // get the layout of one row of Q.
        auto layoutQRow = make_layout_like(tQ(_, 0, _).layout());
        // get the layout of M dimension of C.
        auto layoutCM = get<1>(tC.layout());
        return make_layout(get<0>(layoutQRow), layoutCM, get<1>(layoutQRow));
    }
};

template <int NumCopyThreads, typename ElemO, typename TMACopyO, typename LayoutO,
          typename TileShapeO, typename SMemO, typename SeqLenTraits>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment