gaoqiong / flash-attention · Commits

Commit 74b0761f — [FA3] BF16 forward
Authored Jul 14, 2024 by Tri Dao
Parent: 898dd4bb

Showing 14 changed files with 279 additions and 270 deletions.
csrc/cutlass                               +1    -1
hopper/epilogue_fwd_sm90_tma.hpp           +2    -1
hopper/flash.h                             +0    -2
hopper/flash_api.cpp                       +19   -13
hopper/flash_fwd_hdim128_bf16_sm90.cu      +9    -0
hopper/flash_fwd_hdim256_bf16_sm90.cu      +9    -0
hopper/flash_fwd_hdim64_bf16_sm90.cu       +9    -0
hopper/flash_fwd_kernel.h                  +24   -40
hopper/flash_fwd_launch_template.h         +27   -15
hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp   +38   -62
hopper/named_barrier.hpp                   +23   -0
hopper/setup.py                            +3    -0
hopper/test_flash_attn.py                  +17   -17
hopper/tile_scheduler.hpp                  +98   -119
csrc/cutlass @ 756c351b (submodule updated: fa4f6359 -> 756c351b)

-Subproject commit fa4f6359069bd4dd6fabd0cda2476dd8e72b3837
+Subproject commit 756c351b4994854b2f8c6dded3821ebbb580876b
hopper/epilogue_fwd_sm90_tma.hpp

@@ -9,6 +9,7 @@
 #include "cutlass/gemm/collective/collective_builder.hpp"

+#include "named_barrier.hpp"
 #include "utils.h"

 namespace flash {
@@ -127,7 +128,7 @@ struct CollectiveEpilogueFwd {
         Tensor taccOsO = smem_thr_copy_O.partition_D(sO);     // ((Atom,AtomNum),PIPE_M,PIPE_N)

         // Make sure all WGs have finished reading V
-        cutlass::arch::NamedBarrier::sync(NumMmaThreads, 0 /*id*/);
+        cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(FwdNamedBarriers::ValueEmpty) /*id*/);
         cute::copy(smem_tiled_copy_O, taccOrO, taccOsO);
         cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA
         cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp,
hopper/flash.h

@@ -66,8 +66,6 @@ struct Flash_fwd_params : public Qkv_params {

     // The dimensions.
     int b, seqlen_q, seqlen_k, seqlen_knew, d, seqlen_q_rounded, seqlen_k_rounded, d_rounded, rotary_dim;
-    cutlass::FastDivmod head_divmod, m_block_divmod;
-    int total_blocks;

     // The scaling factors for the kernel.
     float scale_softmax;
hopper/flash_api.cpp

@@ -99,8 +99,6 @@ void set_params_fprop(Flash_fwd_params &params,
     params.d = d;
     params.d_rounded = d_rounded;
-    params.head_divmod = cutlass::FastDivmod(int(h));

     // Set the different scale values.
     params.scale_softmax = softmax_scale;
     params.scale_softmax_log2 = softmax_scale * M_LOG2E;
@@ -225,12 +223,22 @@ void run_mha_fwd(Flash_fwd_params &params, cudaStream_t stream, bool force_split
     //         run_mha_fwd_<cutlass::half_t, kHeadSize>(params, stream);
     //     });
     if (!params.is_e4m3) {
-        if (params.d == 64) {
-            run_mha_fwd_<cutlass::half_t, 64>(params, stream);
-        } else if (params.d == 128) {
-            run_mha_fwd_<cutlass::half_t, 128>(params, stream);
-        } else {
-            run_mha_fwd_<cutlass::half_t, 256>(params, stream);
+        if (params.is_bf16) {
+            if (params.d == 64) {
+                run_mha_fwd_<cutlass::bfloat16_t, 64>(params, stream);
+            } else if (params.d == 128) {
+                run_mha_fwd_<cutlass::bfloat16_t, 128>(params, stream);
+            } else {
+                run_mha_fwd_<cutlass::bfloat16_t, 256>(params, stream);
+            }
+        } else {
+            if (params.d == 64) {
+                run_mha_fwd_<cutlass::half_t, 64>(params, stream);
+            } else if (params.d == 128) {
+                run_mha_fwd_<cutlass::half_t, 128>(params, stream);
+            } else {
+                run_mha_fwd_<cutlass::half_t, 256>(params, stream);
+            }
         }
     } else {
         // run_mha_fwd_<cutlass::float_e4m3_t, 128>(params, stream);
@@ -250,9 +258,8 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
     TORCH_CHECK(is_sm90, "FlashAttention only supports Hopper GPUs or newer.");

     auto q_dtype = q.dtype();
-    // TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
-    TORCH_CHECK(q_dtype == torch::kFloat16,
-                "FlashAttention only support fp16 data type for now");
+    TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kBFloat16,
+                "FlashAttention only support fp16 and bf16 data type for now");
     // TODO: will add e4m3 later
     // TORCH_CHECK(q_dtype == torch::kFloat16 || q_dtype == torch::kFloat8_e4m3fn,
     //             "FlashAttention only support fp16 and bf16 data type");
@@ -278,10 +285,9 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
     const int head_size_og = sizes[3];
     const int seqlen_k = k.size(1);
     const int num_heads_k = k.size(2);
-    TORCH_CHECK(batch_size > 0, "batch size must be postive");
+    TORCH_CHECK(batch_size > 0, "batch size must be positive");
     TORCH_CHECK(head_size_og <= 256, "FlashAttention forward only supports head dimension at most 256");
     TORCH_CHECK(num_heads % num_heads_k == 0, "Number of heads in key/value must divide number of heads in query");
-    TORCH_CHECK(num_heads == num_heads_k, "We do not support MQA/GQA yet");

     TORCH_CHECK(head_size_og == 64 || head_size_og == 128 || head_size_og == 256, "Only support head size 64, 128, and 256 for now");
@@ -345,7 +351,7 @@ mha_fwd(at::Tensor &q, // batch_size x seqlen_q x num_heads x head_size
                      /*window_size_left=*/-1,
                      /*window_size_right=*/is_causal ? 0 : -1);

-    auto tile_count_semaphore = is_causal ? torch::full({1}, 132, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32));
+    auto tile_count_semaphore = is_causal ? torch::zeros({1}, opts.dtype(torch::kInt32)) : torch::empty({1}, opts.dtype(torch::kInt32));
     params.tile_count_semaphore = tile_count_semaphore.data_ptr<int>();

     if (seqlen_k > 0) {
hopper/flash_fwd_hdim128_bf16_sm90.cu  (new file, mode 100644)

// Copyright (c) 2024, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_launch_template.h"

template<>
void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim128<cutlass::bfloat16_t>(params, stream);
}
hopper/flash_fwd_hdim256_bf16_sm90.cu  (new file, mode 100644)

// Copyright (c) 2024, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_launch_template.h"

template<>
void run_mha_fwd_<cutlass::bfloat16_t, 256>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim256<cutlass::bfloat16_t>(params, stream);
}
hopper/flash_fwd_hdim64_bf16_sm90.cu  (new file, mode 100644)

// Copyright (c) 2024, Tri Dao.

// Splitting the different head dimensions to different files to speed up compilation.

#include "flash_fwd_launch_template.h"

template<>
void run_mha_fwd_<cutlass::bfloat16_t, 64>(Flash_fwd_params &params, cudaStream_t stream) {
    run_mha_fwd_hdim64<cutlass::bfloat16_t>(params, stream);
}
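The three new translation units above all follow the same pattern: each file holds exactly one explicit specialization of run_mha_fwd_ and forwards to the launch template, so every (dtype, head-dimension) kernel is instantiated and compiled in its own file — which is what the "speed up compilation" comment refers to. A minimal sketch of the pattern, assuming the primary template is declared once in a shared header (that declaration is not part of this diff):

    // Shared header (assumed): declare the primary template, define nothing here.
    template <typename T, int kHeadDim>
    void run_mha_fwd_(Flash_fwd_params &params, cudaStream_t stream);

    // One .cu file per (dtype, head-dim) pair: the only place this
    // specialization -- and the heavy kernel template behind it -- is instantiated.
    template <>
    void run_mha_fwd_<cutlass::bfloat16_t, 128>(Flash_fwd_params &params, cudaStream_t stream) {
        run_mha_fwd_hdim128<cutlass::bfloat16_t>(params, stream);
    }

flash_api.cpp then only needs the declaration in order to dispatch on params.is_bf16 and params.d, as shown in the run_mha_fwd hunk above.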
hopper/flash_fwd_kernel.h

@@ -26,8 +26,7 @@ using namespace cute;
 template <typename Ktraits, bool Is_causal, typename TileScheduler>
 __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp, 1)
-    compute_attn_ws(CUTE_GRID_CONSTANT Flash_fwd_params const params,
-                    CUTE_GRID_CONSTANT typename CollectiveMainloopFwd<Ktraits, Is_causal>::Params const mainloop_params,
+    compute_attn_ws(CUTE_GRID_CONSTANT typename CollectiveMainloopFwd<Ktraits, Is_causal>::Params const mainloop_params,
                     CUTE_GRID_CONSTANT typename CollectiveEpilogueFwd<Ktraits>::Params const epilogue_params,
                     CUTE_GRID_CONSTANT typename TileScheduler::Params const scheduler_params
                     ) {
@@ -101,9 +100,6 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
     if (warp_group_idx == 0) {  // Producer
         cutlass::arch::warpgroup_reg_dealloc<Ktraits::kNWarps == 12 ? 24 : 32>();
         // cutlass::arch::warpgroup_reg_dealloc<56>();
-        // StaticPersistentTileScheduler scheduler{params.m_block_divmod, params.head_divmod, params.total_blocks};
-        // auto work_tile_info = scheduler.get_current_work();
-        TileScheduler scheduler;

         int warp_idx_in_warpgroup = __shfl_sync(0xffffffff, (threadIdx.x / 32) % 4, 0);
         if (warp_idx_in_warpgroup == 0) {  // Load Q, K, V
@@ -112,20 +108,22 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
             int work_idx = 0;

-            // auto get_tile_count = [&] () {
-            //     cutlass::arch::NamedBarrier::sync(NumMmaThreads + 2 * cutlass::NumThreadsPerWarp, 10 /*id*/);
-            //     return shared_storage.tile_count_semaphore;
-            // };
-            // while (work_tile_info.is_valid()) {
-            // for (int tile_count = blockIdx.x; tile_count < params.total_blocks; tile_count = get_tile_count()) {
-            // for (int tile_count_semaphore = blockIdx.x; tile_count_semaphore < params.total_blocks; tile_count_semaphore = __shfl_sync(0xffffffff, tile_count_semaphore, 0)) {
-            for (auto work_tile_info = scheduler.get_initial_work(); work_tile_info.is_valid(scheduler_params); work_tile_info = scheduler.get_next_work(scheduler_params, work_tile_info)) {
-                int tile_count_semaphore = 0;
-                collective_mainloop.load(params, mainloop_params, scheduler_params, pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v,
-                                         shared_storage, work_tile_info, work_idx, tile_count_semaphore);
-                // ++work_idx;
-                // work_tile_info = scheduler.fetch_next_work();
+            TileScheduler scheduler(&shared_storage.tile_count_semaphore);
+            for (auto work_tile_info = scheduler.get_initial_work();
+                 work_tile_info.is_valid(scheduler_params);
+                 work_tile_info = scheduler.template get_next_work</*IsProducer=*/true>(scheduler_params, work_tile_info)) {
+                auto block_coord = work_tile_info.get_block_coord(scheduler_params);
+                auto [m_block, bidh, bidb] = block_coord;
+
+                int n_block_max = collective_mainloop.get_n_block_max(mainloop_params, m_block);
+                if (Is_causal && n_block_max <= 0) {
+                    scheduler.prefetch_next_work(scheduler_params, work_tile_info);
+                    scheduler.broadcast_next_work(work_tile_info);
+                    continue;
+                }
+                collective_mainloop.load(mainloop_params, pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v,
+                                         shared_storage, scheduler, scheduler_params, work_tile_info, block_coord, work_idx);
+                ++work_idx;
             }
             collective_mainloop.load_tail(pipeline_k, pipeline_v, smem_pipe_write_k, smem_pipe_write_v);
         }
@@ -133,44 +131,31 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
         cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 12 ? 240 : 160>();
         // cutlass::arch::warpgroup_reg_alloc<Ktraits::kNWarps == 12 ? 224 : 160>();

+        TileScheduler scheduler(&shared_storage.tile_count_semaphore);
         // Initialize matmul objects.
         typename Ktraits::TiledMma1 tiled_mma1;
-        TileScheduler scheduler{};

         PipelineState smem_pipe_read_k, smem_pipe_read_v;
-        // We don't need separate variables smem_pip_release_k and smem_pipe_release_v
+        // We don't need separate variables smem_pipe_release_k and smem_pipe_release_v
         // (like in Cutlass's gemm) because the read and release pipeline states are always the same.

-        auto get_tile_count = [&] () {
-            // cutlass::arch::NamedBarrier::sync(NumMmaThreads + 2 * cutlass::NumThreadsPerWarp, 10 /*id*/);
-            cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, 10 /*id*/);
-            return shared_storage.tile_count_semaphore;
-        };
-
         collective_mainloop.mma_init();
+        scheduler.init_consumer();

         int work_idx = 0;

         CUTLASS_PRAGMA_NO_UNROLL
-        // for (int work_idx = 0; work_idx * gridDim.x + blockIdx.x < params.total_blocks; ++work_idx) {
-        // for (int tile_count_semaphore = blockIdx.x, work_idx = 0; tile_count_semaphore < params.total_blocks; tile_count_semaphore = get_tile_count()) {
-        for (auto work_tile_info = scheduler.get_initial_work(); work_tile_info.is_valid(scheduler_params); work_tile_info = scheduler.get_next_work(scheduler_params, work_tile_info)) {
+        for (auto work_tile_info = scheduler.get_initial_work();
+             work_tile_info.is_valid(scheduler_params);
+             work_tile_info = scheduler.template get_next_work</*IsProducer=*/false>(scheduler_params, work_tile_info)) {
             // Attention output (GEMM-II) accumulator.
             Tensor tOrO = partition_fragment_C(tiled_mma1, select<0, 2>(TileShape_MNK{}));
             flash::Softmax<2 * (2 * kBlockM / NumMmaThreads)> softmax;

-            // int m_block;
-            // int bidh, bidb;
-            // // bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, work_idx * gridDim.x + blockIdx.x));
-            // bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_count_semaphore));
-            // cute::tuple<int32_t, int32_t, int32_t> block_coord = {m_block, bidh, bidb};
             auto block_coord = work_tile_info.get_block_coord(scheduler_params);
             auto [m_block, bidh, bidb] = block_coord;

             int n_block_max = collective_mainloop.get_n_block_max(mainloop_params, m_block);
             if (Is_causal && n_block_max <= 0) {  // We exit early and write 0 to gO and -inf to gLSE.
-                // Need sync to avoid the case where the producer issues 2 arrives before the consumer can issue 1 wait
-                cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, 7 /*id*/);
                 collective_epilogue.store_zero(epilogue_params, threadIdx.x - NumCopyThreads, block_coord);
                 continue;
             }
@@ -178,15 +163,14 @@ __global__ void __launch_bounds__(Ktraits::kNWarps * cutlass::NumThreadsPerWarp,
             collective_mainloop.mma(mainloop_params, pipeline_k, pipeline_v, smem_pipe_read_k, smem_pipe_read_v,
                                     tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads, work_idx, m_block, shared_storage);
             // tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads + (work_idx >> 30), work_idx, shared_storage);
-            // tOrO, softmax, n_block_max, threadIdx.x - NumCopyThreads, 0, shared_storage);
             collective_epilogue.store(epilogue_params, tOrO, softmax.row_sum, shared_storage, tiled_mma1,
                                       threadIdx.x - NumCopyThreads, block_coord);

             ++work_idx;
-            // work_tile_info = scheduler.fetch_next_work();
         }
         collective_epilogue.store_tail();
     }
 }

 } // namespace flash
hopper/flash_fwd_launch_template.h

@@ -8,6 +8,7 @@
 #include "cute/tensor.hpp"

+#include "cutlass/cutlass.h"
 #include "cutlass/cluster_launch.hpp"

 #include "static_switch.h"
@@ -26,8 +27,10 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     // print(typename Kernel_traits::SmemLayoutVt{}); printf("\n"); print(typename Kernel_traits::SmemLayoutVt_tmp{});
     using CollectiveMainloop = flash::CollectiveMainloopFwd<Kernel_traits, Is_causal>;
     using CollectiveEpilogue = flash::CollectiveEpilogueFwd<Kernel_traits>;
-    using Scheduler = flash::StaticPersistentTileScheduler;
+    // using Scheduler = flash::SingleTileScheduler;
+    using Scheduler = std::conditional_t<!Is_causal,
+        flash::StaticPersistentTileScheduler,
+        flash::DynamicPersistentTileScheduler<Kernel_traits::kNThreads - cutlass::NumThreadsPerWarpGroup>>;
+        // flash::SingleTileScheduler>;
     typename CollectiveMainloop::Params mainloop_params =
         CollectiveMainloop::to_underlying_arguments({
             static_cast<Element const*>(params.q_ptr),
@@ -51,32 +54,35 @@ void run_flash_fwd(Flash_fwd_params &params, cudaStream_t stream) {
     int num_blocks_m = cutlass::ceil_div(params.seqlen_q, Kernel_traits::kBlockM);
     num_blocks_m = cutlass::ceil_div(num_blocks_m, size<0>(ClusterShape{})) * size<0>(ClusterShape{});
-    typename Scheduler::Arguments scheduler_args = {num_blocks_m, params.h, params.b};
+    typename Scheduler::Arguments scheduler_args = {num_blocks_m, params.h, params.b, params.tile_count_semaphore};
     typename Scheduler::Params scheduler_params = Scheduler::to_underlying_arguments(scheduler_args);

     // Get the ptr to kernel function.
     void *kernel;
     kernel = (void *)flash::compute_attn_ws<Kernel_traits, Is_causal, Scheduler>;
     int smem_size = sizeof(typename Kernel_traits::SharedStorage);
-    int smem_size_q = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_q));
-    int smem_size_k = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_k));
-    int smem_size_v = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_v));
+    // int smem_size_q = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_q));
+    // int smem_size_k = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_k));
+    // int smem_size_v = sizeof(decltype((typename Kernel_traits::SharedStorage{}).smem_v));
     // printf("smem_size = %d, q = %d, k = %d, v = %d\n", smem_size, smem_size_q, smem_size_k, smem_size_v);
     if (smem_size >= 48 * 1024) {
        C10_CUDA_CHECK(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
     }

+    int device;
+    cudaGetDevice(&device);
+    int multiprocessor_count;
+    cudaError status_ = cudaDeviceGetAttribute(&multiprocessor_count, cudaDevAttrMultiProcessorCount, device);
+    if (status_ != cudaSuccess) {
+      C10_CUDA_CHECK(status_);
+    }
+    dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, multiprocessor_count);
     static constexpr int ctaSize = Kernel_traits::kNWarps * 32;
-    params.m_block_divmod = cutlass::FastDivmod(num_blocks_m);
-    params.total_blocks = num_blocks_m * params.h * params.b;
-    // dim3 grid_dims(num_blocks_m, params.h, params.b);
-    // dim3 grid_dims(132);
-    dim3 grid_dims = Scheduler::get_grid_dim(scheduler_args, 132);
     dim3 block_dims(ctaSize);
     dim3 cluster_dims(size<0>(ClusterShape{}), size<1>(ClusterShape{}), size<2>(ClusterShape{}));
     cutlass::ClusterLaunchParams launch_params{grid_dims, block_dims, cluster_dims, smem_size, stream};
-    cutlass::launch_kernel_on_cluster(launch_params, kernel, params, mainloop_params, epilogue_params, scheduler_params);
-    // kernel<<<grid_dims, block_dims, smem_size, stream>>>(params, tma_load_Q, tma_load_K, tma_load_V, tma_store_O);
+    cutlass::launch_kernel_on_cluster(launch_params, kernel, mainloop_params, epilogue_params, scheduler_params);
     C10_CUDA_KERNEL_LAUNCH_CHECK();
 }
@@ -92,7 +98,10 @@ template<typename T>
 void run_mha_fwd_hdim128(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 128;
     BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, Is_causal ? 128 : 176, 12, 2, false, !Is_causal ? 2 : 1, T>, Is_causal>(params, stream);
+        // Only use Cluster if number of tiles along seqlen_q is even
+        BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0, UseCluster, [&] {
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, Is_causal ? 128 : 176, 12, 2, false, !Is_causal && UseCluster ? 2 : 1, T>, Is_causal>(params, stream);
+        });
     });
 }
@@ -100,6 +109,9 @@ template<typename T>
 void run_mha_fwd_hdim256(Flash_fwd_params &params, cudaStream_t stream) {
     constexpr static int Headdim = 256;
     BOOL_SWITCH(params.is_causal, Is_causal, [&] {
-        run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 80, 12, 2, false, !Is_causal ? 2 : 1, T>, Is_causal>(params, stream);
+        // Only use Cluster if number of tiles along seqlen_q is even
+        BOOL_SWITCH(cutlass::ceil_div(params.seqlen_q, 128) % 2 == 0, UseCluster, [&] {
+            run_flash_fwd<Flash_fwd_kernel_traits<Headdim, 128, 80, 12, 2, false, !Is_causal && UseCluster ? 2 : 1, T>, Is_causal>(params, stream);
+        });
     });
 }
hopper/mainloop_fwd_sm90_tma_gmma_ws.hpp

@@ -14,6 +14,7 @@
 #include "cutlass/gemm/collective/collective_builder.hpp"

+#include "named_barrier.hpp"
 #include "utils.h"

 namespace flash {
@@ -108,6 +109,7 @@ struct CollectiveMainloopFwd {
     struct Params {
         ShapeQKV const shape_Q;
         ShapeQKV const shape_K;
+        cutlass::FastDivmod qhead_per_khead_divmod;
         TMA_Q tma_load_Q;
         TMA_KV tma_load_K, tma_load_V;
         float const softmax_scale_log2;
@@ -137,7 +139,10 @@ struct CollectiveMainloopFwd {
             SmemLayoutV{}(_, _, _0{}),
             select<1, 2>(TileShape_MNK{}),
             size<0>(ClusterShape{})); // mcast along M mode for this N load, if any
-        return {args.shape_Q, args.shape_K, tma_load_Q, tma_load_K, tma_load_V, args.softmax_scale_log2};
+        return {args.shape_Q, args.shape_K,
+                cutlass::FastDivmod(cute::ceil_div(get<2>(args.shape_Q), get<2>(args.shape_K))),
+                tma_load_Q, tma_load_K, tma_load_V,
+                args.softmax_scale_log2};
     }

     /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance
@@ -162,46 +167,21 @@ struct CollectiveMainloopFwd {
         return n_block_max;
     }

-    template <typename FullParams, typename SchedulerParams, typename SharedStorage, typename WorkTileInfo>
+    template <typename Scheduler, typename SharedStorage>
     CUTLASS_DEVICE void
-    load(FullParams const& params,
-         Params const& mainloop_params,
-         SchedulerParams const& scheduler_params,
+    load(Params const& mainloop_params,
         MainloopPipeline pipeline_k,
         MainloopPipeline pipeline_v,
         PipelineState& smem_pipe_write_k,
         PipelineState& smem_pipe_write_v,
         SharedStorage &shared_storage,
-         WorkTileInfo work_tile_info,
-         int& work_idx,
-         int& tile_count_semaphore
+         Scheduler& scheduler,
+         typename Scheduler::Params const& scheduler_params,
+         typename Scheduler::WorkTileInfo& work_tile_info,
+         cute::tuple<int32_t, int32_t, int32_t> block_coord,
+         int work_idx
         ) {

-        static constexpr int kBlockM = get<0>(TileShape_MNK{});
-        static constexpr int kBlockN = get<1>(TileShape_MNK{});
-
-        // int const m_block = work_tile_info.M_idx;
-        // int const bidh = work_tile_info.H_idx;
-        // int const bidb = work_tile_info.B_idx;
-        // int m_block;
-        // int bidh, bidb;
-        // bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_count_semaphore));
-        auto [m_block, bidh, bidb] = work_tile_info.get_block_coord(scheduler_params);
-        // if (threadIdx.x == 0) { printf("producer, blockIdx.x = %d, bidb = %d, bidh = %d, m_block = %d\n", blockIdx.x, bidb, bidh, m_block); }
-        int n_block_max = get_n_block_max(mainloop_params, m_block);
-        if (Is_causal && n_block_max <= 0) {
-            // Need sync to avoid the case where the producer issues 2 arrives before the consumer can issue 1 wait
-            cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, 7 /*id*/);
-            // if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
-            //     tile_count_semaphore = atomicAdd(params.tile_count_semaphore, 1);
-            //     shared_storage.tile_count_semaphore = tile_count_semaphore;
-            // }
-            // cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 2 * cutlass::NumThreadsPerWarp, 10 /*id*/);
-            cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, 10 /*id*/);
-            return;
-        }
-
         Tensor sQ = make_tensor(make_smem_ptr(shared_storage.smem_q.data()), SmemLayoutQ{});
         Tensor sK = make_tensor(make_smem_ptr(shared_storage.smem_k.data()), SmemLayoutK{});
         Tensor sV = make_tensor(make_smem_ptr(shared_storage.smem_v.data()), SmemLayoutV{});
@@ -210,13 +190,16 @@ struct CollectiveMainloopFwd {
         Tensor mK = mainloop_params.tma_load_K.get_tma_tensor(mainloop_params.shape_K);
         Tensor mV = mainloop_params.tma_load_V.get_tma_tensor(mainloop_params.shape_K);

+        auto [m_block, bidh, bidb] = block_coord;
+        int bidh_kv = mainloop_params.qhead_per_khead_divmod.divide(bidh);
+
         // Prepare the TMA loads
         uint32_t block_rank_in_cluster = cute::block_rank_in_cluster();
         constexpr uint32_t cluster_shape_x = get<0>(ClusterShape());
         uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x};
         Tensor gQ = local_tile(mQ(_, _, bidh, bidb), select<0, 2>(TileShape_MNK{}), make_coord(m_block, _0{}));  // (M, K)
-        Tensor gK = local_tile(mK(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (N, K, _)
-        Tensor gV = local_tile(mV(_, _, bidh, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (N, K, _)
+        Tensor gK = local_tile(mK(_, _, bidh_kv, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (N, K, _)
+        Tensor gV = local_tile(mV(_, _, bidh_kv, bidb), select<1, 2>(TileShape_MNK{}), make_coord(_, _0{}));  // (N, K, _)

         Tensor sQ_x = make_tensor(sQ.data(), make_layout(sQ.layout(), Layout<_1>{}));
         Tensor gQ_x = make_tensor(gQ.data(), make_layout(gQ.layout(), Layout<_1>{}));
@@ -235,6 +218,7 @@ struct CollectiveMainloopFwd {
             }
         }

+        int n_block_max = get_n_block_max(mainloop_params, m_block);
         int n_block = n_block_max - 1;

         int lane_predicate = cute::elect_one_sync();
@@ -246,7 +230,7 @@ struct CollectiveMainloopFwd {
         }
         // Wait for the MMA warpgroups to say that smem_q is ready
-        cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, 1 /*id*/);
+        cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);

         if (lane_predicate) {
             shared_storage.barrier_Q.arrive_and_expect_tx(TmaTransactionBytesQ);
@@ -272,22 +256,14 @@ struct CollectiveMainloopFwd {
                 ++smem_pipe_write_v;
             }
         }
+        if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
+            scheduler.prefetch_next_work(scheduler_params, work_tile_info);
+        }
         if (lane_predicate) {
             pipeline_v.producer_acquire(smem_pipe_write_v);
             copy(mainloop_params.tma_load_V.with(*pipeline_v.producer_get_barrier(smem_pipe_write_v), mcast_mask_kv),
                  tVgV(_, n_block), tVsV(_, smem_pipe_write_v.index()));
             ++smem_pipe_write_v;
         }
-        if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
-            // tile_count_semaphore = atomicAdd(params.tile_count_semaphore, 1);
-            // printf("blockIdx.x = %d, tile_count_semaphore: %d\n", blockIdx.x, tile_count_semaphore);
-            // shared_storage.tile_count_semaphore = tile_count_semaphore;
-            // cutlass::arch::NamedBarrier::arrive(NumMmaThreads + 2 * cutlass::NumThreadsPerWarp, 10 /*id*/);
-            cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, 10 /*id*/);
-        }
-        ++work_idx;
+        if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
+            scheduler.broadcast_next_work(work_tile_info);
+        }
     }

     /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster
@@ -307,36 +283,36 @@ struct CollectiveMainloopFwd {
     }

     CUTLASS_DEVICE void
-    scheduler_barrier_sync() {
+    warp_scheduler_barrier_sync() {
         if constexpr (UseSchedulerBarrier) {
-            cutlass::arch::NamedBarrier::sync(NumMmaThreads, 3 + cutlass::canonical_warp_group_idx() /*id*/);
+            cutlass::arch::NamedBarrier::sync(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + cutlass::canonical_warp_group_idx() /*id*/);
         }
     }

     CUTLASS_DEVICE void
-    scheduler_barrier_arrive() {
+    warp_scheduler_barrier_arrive() {
         if constexpr (!UseSchedulerBarrier) { return; }
         static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
         if constexpr (NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup) {
-            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 3 + (3 - cutlass::canonical_warp_group_idx()) /*id*/);
+            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (3 - cutlass::canonical_warp_group_idx()) /*id*/);
         } else {
-            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, cutlass::canonical_warp_group_idx() <= 2 ? 3 + cutlass::canonical_warp_group_idx() + 1 : 3 + cutlass::canonical_warp_group_idx() + 1 - 3 /*id*/);
-            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, cutlass::canonical_warp_group_idx() <= 1 ? 3 + cutlass::canonical_warp_group_idx() + 2 : 3 + cutlass::canonical_warp_group_idx() + 2 - 3 /*id*/);
+            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 2 ? cutlass::canonical_warp_group_idx() + 1 : cutlass::canonical_warp_group_idx() + 1 - 3) /*id*/);
+            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + (cutlass::canonical_warp_group_idx() <= 1 ? cutlass::canonical_warp_group_idx() + 2 : cutlass::canonical_warp_group_idx() + 2 - 3) /*id*/);
         }
     }

     CUTLASS_DEVICE void
     mma_init() {
         // Tell producer (warp 0) that smem_q is ready
-        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, 1 /*id*/);
+        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
         if constexpr (!UseSchedulerBarrier) { return; }
         static_assert(NumMmaThreads == 2 * cutlass::NumThreadsPerWarpGroup || NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup);
         if (cutlass::canonical_warp_group_idx() > 1) {
-            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 3 + 1 /*id*/);
+            cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + 1 /*id*/);
         }
         if constexpr (NumMmaThreads == 3 * cutlass::NumThreadsPerWarpGroup) {
             if (cutlass::canonical_warp_group_idx() > 2) {
-                cutlass::arch::NamedBarrier::arrive(NumMmaThreads, 3 + 2 /*id*/);
+                cutlass::arch::NamedBarrier::arrive(NumMmaThreads, static_cast<int>(FwdNamedBarriers::WarpSchedulerWG1) - 1 + 2 /*id*/);
             }
         }
@@ -393,9 +369,9 @@ struct CollectiveMainloopFwd {
         Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
         consumer_wait(pipeline_k, smem_pipe_read_k);
-        scheduler_barrier_sync();
+        warp_scheduler_barrier_sync();
         flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
-        scheduler_barrier_arrive();
+        warp_scheduler_barrier_arrive();

         if (work_idx != 0) {
             int lane_predicate = cute::elect_one_sync();
             if (cutlass::canonical_warp_idx_sync() == Ktraits::kNWarps - 1 && lane_predicate) {
@@ -443,12 +419,12 @@ struct CollectiveMainloopFwd {
         for (int masking_step = 0; masking_step < n_masking_steps - 1 && n_block > 0; ++masking_step, --n_block) {
             Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
             consumer_wait(pipeline_k, smem_pipe_read_k);
-            scheduler_barrier_sync();
+            warp_scheduler_barrier_sync();
             flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
             if (masking_step > 0) { softmax.rescale_o(tOrO, scores_scale); }
             consumer_wait(pipeline_v, smem_pipe_read_v);
             flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
-            scheduler_barrier_arrive();
+            warp_scheduler_barrier_arrive();
             warpgroup_wait<1>();
             pipeline_k.consumer_release(smem_pipe_read_k);  // release K
             Tensor cS = cute::make_identity_tensor(select<0, 1>(TileShape_MNK{}));
@@ -472,12 +448,12 @@ struct CollectiveMainloopFwd {
         for (; n_block > 0; --n_block) {
             Tensor tSrS = partition_fragment_C(tiled_mma0, select<0, 1>(TileShape_MNK{}));
             consumer_wait(pipeline_k, smem_pipe_read_k);
-            scheduler_barrier_sync();
+            warp_scheduler_barrier_sync();
             flash::gemm</*zero_init=*/true, /*wg_wait=*/-1>(tiled_mma0, tSrQ, tSrK(_, _, _, smem_pipe_read_k.index()), tSrS);
             softmax.rescale_o(tOrO, scores_scale);
             consumer_wait(pipeline_v, smem_pipe_read_v);
             flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
-            scheduler_barrier_arrive();
+            warp_scheduler_barrier_arrive();
             warpgroup_wait<1>();
             pipeline_k.consumer_release(smem_pipe_read_k);  // release K
             // auto scores_scale = softmax.template max</*Is_first=*/false>(tSrS);
@@ -491,7 +467,7 @@ struct CollectiveMainloopFwd {
             cute::copy(make_tensor(convert_type<Element>(tSrS).data(), convert_layout_acc_Aregs<typename Ktraits::TiledMma1>(tSrS.layout())), tOrP);
         }
         // Tell warp 0 that smem_q is ready
-        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, 1 /*id*/);
+        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::QueryEmpty) /*id*/);
         softmax.rescale_o(tOrO, scores_scale);
         consumer_wait(pipeline_v, smem_pipe_read_v);
         flash::gemm</*zero_init=*/false, /*wg_wait=*/-1>(tiled_mma1, tOrP, tOrV(_, _, _, smem_pipe_read_v.index()), tOrO);
hopper/named_barrier.hpp  (new file, mode 100644)

/******************************************************************************
 * Copyright (c) 2024, Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani, Tri Dao.
 ******************************************************************************/

#pragma once

#include "cutlass/arch/barrier.h"

namespace flash {

////////////////////////////////////////////////////////////////////////////////////////////////////

// Enumerates the reserved named barriers to avoid potential conflicts
enum class FwdNamedBarriers {
    QueryEmpty = 0,
    ValueEmpty = 1,
    TileCountSmemEmpty = 2,
    TileCountSmemFull = 3,
    WarpSchedulerWG1 = 4,
    WarpSchedulerWG2 = 5,
    WarpSchedulerWG3 = 6,
};

} // flash
\ No newline at end of file
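The FwdNamedBarriers enum above centralizes the hardware named-barrier slot assignments that the rest of this commit switches to: the epilogue and mainloop now pass static_cast<int>(FwdNamedBarriers::...) where they previously hard-coded integer IDs (0, 1, 3 + warp-group index, and so on). A minimal sketch of the idiom, with the thread counts chosen only for illustration (in the kernels they are derived from the kernel traits):

    #include "cutlass/arch/barrier.h"
    #include "named_barrier.hpp"

    // Assumed example values; not taken from this diff.
    constexpr int NumMmaThreads   = 256;   // e.g. two MMA warpgroups
    constexpr int kThreadsPerWarp = 32;    // == cutlass::NumThreadsPerWarp

    __device__ void wait_for_v_readers() {
        // Consumer/epilogue side: block until every MMA thread has finished
        // reading V, using the named slot instead of the old raw id 0.
        cutlass::arch::NamedBarrier::sync(
            NumMmaThreads, static_cast<int>(flash::FwdNamedBarriers::ValueEmpty) /*id*/);
    }

    __device__ void signal_q_is_free() {
        // Producer handshake: the extra TMA warp also participates, hence the
        // +kThreadsPerWarp, mirroring the QueryEmpty arrivals in mma_init().
        cutlass::arch::NamedBarrier::arrive(
            NumMmaThreads + kThreadsPerWarp,
            static_cast<int>(flash::FwdNamedBarriers::QueryEmpty) /*id*/);
    }

Because the IDs are now enumerators, uses of the same slot (for example the warp-scheduler barriers WG1..WG3) are visible in one place rather than scattered as magic numbers across the kernel sources.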
hopper/setup.py

@@ -111,8 +111,11 @@ if not SKIP_CUDA_BUILD:
     sources = [
         "flash_api.cpp",
         "flash_fwd_hdim64_fp16_sm90.cu",
+        "flash_fwd_hdim64_bf16_sm90.cu",
         "flash_fwd_hdim128_fp16_sm90.cu",
+        "flash_fwd_hdim128_bf16_sm90.cu",
         "flash_fwd_hdim256_fp16_sm90.cu",
+        "flash_fwd_hdim256_bf16_sm90.cu",
         "flash_bwd_hdim64_fp16_sm90.cu",
         "flash_bwd_hdim128_fp16_sm90.cu",
         "flash_bwd_hdim256_fp16_sm90.cu",
hopper/test_flash_attn.py

@@ -131,15 +131,18 @@ def attention_ref(
-@pytest.mark.parametrize("dtype", [torch.float16])
+@pytest.mark.parametrize("dtype", [torch.float16, torch.bfloat16])
+# @pytest.mark.parametrize("dtype", [torch.bfloat16])
+@pytest.mark.parametrize("mha_type", ["mha", "mqa", "gqa"])
+# @pytest.mark.parametrize("mha_type", ["gqa"])
 @pytest.mark.parametrize("causal", [False, True])
-# @pytest.mark.parametrize("causal", [False])
+# @pytest.mark.parametrize("causal", [True])
 # @pytest.mark.parametrize("d", [32, 64, 96, 128, 160, 192, 224, 256])
 # @pytest.mark.parametrize('d', [32, 40, 64, 80, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [32, 64, 96, 128, 160, 192])
 # @pytest.mark.parametrize('d', [56, 80])
 @pytest.mark.parametrize("d", [64, 128, 256])
-# @pytest.mark.parametrize("d", [128])
+# @pytest.mark.parametrize("d", [256])
 @pytest.mark.parametrize(
     "seqlen_q,seqlen_k",
     [
@@ -151,6 +154,8 @@ def attention_ref(
         (113, 211),
         (108, 256),
         (256, 512),
+        (384, 256),
+        (640, 128),
         (512, 256),
         (1024, 1024),
         (1023, 1024),
@@ -160,7 +165,7 @@ def attention_ref(
 )
 # @pytest.mark.parametrize('seqlen_q,seqlen_k', [(128, 128)])
 def test_flash_attn_output(
-    seqlen_q, seqlen_k, d, causal, dtype
+    seqlen_q, seqlen_k, d, causal, mha_type, dtype
 ):
     device = "cuda"
     # set seed
@@ -168,16 +173,13 @@ def test_flash_attn_output(
     # batch_size = 40
     # nheads = 16
     batch_size = 9
-    nheads = 4
+    nheads = 6
+    nheads_kv = 6 if mha_type == "mha" else (2 if mha_type == "gqa" else 1)
     # batch_size = 1
     # nheads = 1
     q = torch.randn(batch_size, seqlen_q, nheads, d, device=device, dtype=dtype, requires_grad=True)
-    k = torch.randn(
-        batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
-    )
-    v = torch.randn(
-        batch_size, seqlen_k, nheads, d, device=device, dtype=dtype, requires_grad=True
-    )
+    k = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True)
+    v = torch.randn(batch_size, seqlen_k, nheads_kv, d, device=device, dtype=dtype, requires_grad=True)
     out, lse = flash_attn_func(q, k, v, causal=causal)
     out_ref, attn_ref = attention_ref(
         q,
@@ -202,15 +204,15 @@ def test_flash_attn_output(
     # m = qk.amax(-1, keepdim=True)
     # s_tmp = torch.exp((qk - m) / math.sqrt(d))
     # exp_sum = s_tmp.sum(-1)
-    qk = torch.einsum('bthd,bshd->bhts', q.float() / math.sqrt(d), k.float())
-    lse_ref = torch.logsumexp(qk, dim=-1)
+    # qk = torch.einsum('bthd,bshd->bhts', q.float() / math.sqrt(d), k.float())
+    # lse_ref = torch.logsumexp(qk, dim=-1)
     print(f"Output max diff: {(out - out_ref).abs().max().item()}")
     print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
     print(f"Pytorch max diff: {(out_pt - out_ref).abs().max().item()}")
     print(f"Pytorch mean diff: {(out_pt - out_ref).abs().mean().item()}")
-    if not causal:
-        print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
+    # if not causal:
+    #     print(f"LSE max diff: {(lse - lse_ref).abs().max().item()}")
     # breakpoint()

     # if d <= 128:
@@ -248,5 +250,3 @@ def test_flash_attn_output(
     # assert (dq - dq_ref).abs().max().item() <= 2 * (dq_pt - dq_ref).abs().max().item()
     # assert (dk - dk_ref).abs().max().item() <= 2 * (dk_pt - dk_ref).abs().max().item()
     # assert (dv - dv_ref).abs().max().item() <= 2 * (dv_pt - dv_ref).abs().max().item()
hopper/tile_scheduler.hpp
View file @
74b0761f
/******************************************************************************
/******************************************************************************
* Copyright (c) 2024, Tri Dao.
* Copyright (c) 2024,
Jay Shah, Ganesh Bikshandi, Ying Zhang, Vijay Thakkar, Pradeep Ramani,
Tri Dao.
******************************************************************************/
******************************************************************************/
#pragma once
#pragma once
#include "cutlass/fast_math.h"
#include "cutlass/fast_math.h"
#include "cutlass/arch/barrier.h"
namespace
flash
{
#include "named_barrier.hpp"
///////////////////////////////////////////////////////////////////////////////
class StaticPersistentTileSchedulerOld {
    //
    // Data members
    //

private:
    int current_work_linear_idx_;
    cutlass::FastDivmod const &m_block_divmod, &head_divmod;
    int const total_blocks;

public:

    struct WorkTileInfo {
        int M_idx = 0;
        int H_idx = 0;
        int B_idx = 0;
        bool is_valid_tile = false;

        CUTLASS_HOST_DEVICE
        bool is_valid() const {
            return is_valid_tile;
        }

        CUTLASS_HOST_DEVICE
        static WorkTileInfo invalid_work_tile() {
            return {-1, -1, -1, false};
        }
    };

public:

    CUTLASS_DEVICE
    explicit StaticPersistentTileSchedulerOld(cutlass::FastDivmod const& m_block_divmod_,
                                              cutlass::FastDivmod const& head_divmod_,
                                              int const total_blocks_)
        : m_block_divmod(m_block_divmod_), head_divmod(head_divmod_), total_blocks(total_blocks_) {
        // MSVC requires protecting use of CUDA-specific nonstandard syntax,
        // like blockIdx and gridDim, with __CUDA_ARCH__.
#if defined(__CUDA_ARCH__)
        // current_work_linear_idx_ = blockIdx.x + blockIdx.y * gridDim.x + blockIdx.z * gridDim.x * gridDim.y;
        current_work_linear_idx_ = blockIdx.x;
#else
        CUTLASS_ASSERT(false && "This line should never be reached");
#endif
    }

    CUTLASS_DEVICE
    WorkTileInfo get_current_work() const {
        return get_current_work_for_linear_idx(current_work_linear_idx_);
    }

    CUTLASS_DEVICE
    WorkTileInfo get_current_work_for_linear_idx(int linear_idx) const {
        if (linear_idx >= total_blocks) {
            return WorkTileInfo::invalid_work_tile();
        }
        // Map worker's linear index into the CTA tiled problem shape to the corresponding MHB indices
        int M_idx, H_idx, B_idx;
        int quotient = m_block_divmod.divmod(M_idx, linear_idx);
        B_idx = head_divmod.divmod(H_idx, quotient);
        return {M_idx, H_idx, B_idx, true};
    }

    CUTLASS_DEVICE
    void
    // advance_to_next_work(int advance_count = 1) {
    advance_to_next_work() {
        // current_work_linear_idx_ += int(gridDim.x * gridDim.y * gridDim.z);
        current_work_linear_idx_ += int(gridDim.x);
    }

    CUTLASS_DEVICE
    WorkTileInfo fetch_next_work() {
        WorkTileInfo new_work_tile_info;
        advance_to_next_work();
        new_work_tile_info = get_current_work();
        return new_work_tile_info;
    }
};
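As a reference point (this sketch is not part of the commit), the index math in get_current_work_for_linear_idx is two chained divmods: the M block varies fastest, then the head, then the batch. The helper name decode_tile and the problem shape below are made up for illustration; plain '/' and '%' stand in for cutlass::FastDivmod. The snippet compiles as a standalone .cu file with nvcc.

#include <cassert>
#include <cstdio>

struct TileCoord { int m_block, head, batch; };

// Same decomposition as get_current_work_for_linear_idx: M block fastest,
// then head, then batch.
__host__ __device__ inline TileCoord decode_tile(int linear_idx, int num_m_blocks, int num_heads) {
    int m_block  = linear_idx % num_m_blocks;
    int quotient = linear_idx / num_m_blocks;   // == head + batch * num_heads
    return {m_block, quotient % num_heads, quotient / num_heads};
}

int main() {
    // Hypothetical shape: 4 M-blocks x 3 heads x 2 batches = 24 tiles.
    int num_m_blocks = 4, num_heads = 3, num_batch = 2;
    for (int idx = 0; idx < num_m_blocks * num_heads * num_batch; ++idx) {
        TileCoord t = decode_tile(idx, num_m_blocks, num_heads);
        // Round-trip: re-encoding the coordinates recovers the linear index.
        assert(idx == t.m_block + num_m_blocks * (t.head + num_heads * t.batch));
        std::printf("tile %2d -> (m_block=%d, head=%d, batch=%d)\n", idx, t.m_block, t.head, t.batch);
    }
    return 0;
}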
///////////////////////////////////////////////////////////////////////////////

class SingleTileScheduler {
struct SingleTileScheduler {

public:

    // Host side kernel arguments
    struct Arguments {
        int const num_blocks_m, num_head, num_batch;
        int const* tile_count_semaphore = nullptr;
        int* const tile_count_semaphore = nullptr;
    };

    // Device side kernel params
    ...

@@ -140,20 +54,30 @@ public:
            return {M_idx, H_idx, B_idx};
        }

        CUTLASS_DEVICE
        WorkTileInfo
        get_next_work(Params const& params) const {
            return {-1, -1, -1, false};
        }

    };

    CUTLASS_DEVICE
    SingleTileScheduler(int* tile_count_smem_) { }

    CUTLASS_DEVICE
    WorkTileInfo
    get_initial_work() const {
        return {int(blockIdx.x), int(blockIdx.y), int(blockIdx.z), true};
    }

    CUTLASS_DEVICE
    void init_consumer() const {}

    CUTLASS_DEVICE
    void prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}

    CUTLASS_DEVICE
    void broadcast_next_work(WorkTileInfo& current_work) const {}

    template<bool IsProducer=false>
    CUTLASS_DEVICE
    WorkTileInfo
    get_next_work(Params const& params, WorkTileInfo const& current_work) const {
    ...

@@ -171,7 +95,7 @@ public:
    // Host side kernel arguments
    struct Arguments {
        int const num_blocks_m, num_head, num_batch;
        int const* tile_count_semaphore = nullptr;
        int* const tile_count_semaphore = nullptr;
    };

    // Device side kernel params
    ...

@@ -210,12 +134,28 @@ public:
    };

    CUTLASS_DEVICE
    StaticPersistentTileScheduler(int* tile_count_smem_) {};

    CUTLASS_DEVICE
    WorkTileInfo
    get_initial_work() const {
        return {int(blockIdx.x)};
    }

    CUTLASS_DEVICE
    void init_consumer() const {}

    CUTLASS_DEVICE
    void prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {}

    CUTLASS_DEVICE
    void broadcast_next_work(WorkTileInfo& current_work) const {}

    template<bool IsProducer=false>
    CUTLASS_DEVICE
    WorkTileInfo
    get_next_work(Params const& params, WorkTileInfo const& current_work) const {
    ...

@@ -224,21 +164,25 @@ public:
};
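To see how a scheduler like StaticPersistentTileScheduler is meant to be driven, here is a hedged standalone sketch (not the FA3 kernel and not part of this diff; kernel name and counters are hypothetical): a persistent CTA starts at tile blockIdx.x, strides by gridDim.x the way get_next_work does, and decodes each linear tile index the way get_block_coord does with FastDivmod.

#include <cstdio>
#include <cuda_runtime.h>

// Persistent-CTA sketch: each CTA processes tiles blockIdx.x, blockIdx.x + gridDim.x, ...
__global__ void static_persistent_sketch(int total_blocks, int num_m_blocks, int num_heads,
                                         int* tiles_done) {
    for (int tile_idx = blockIdx.x; tile_idx < total_blocks; tile_idx += gridDim.x) {
        // Decode the tile the way WorkTileInfo::get_block_coord does.
        int m_block = tile_idx % num_m_blocks;
        int hb      = tile_idx / num_m_blocks;
        int bidh    = hb % num_heads;
        int bidb    = hb / num_heads;
        (void)m_block; (void)bidh; (void)bidb;   // the real kernel would run the attention mainloop here
        if (threadIdx.x == 0) { atomicAdd(tiles_done, 1); }
    }
}

int main() {
    int num_m_blocks = 8, num_heads = 4, num_batch = 2;
    int total_blocks = num_m_blocks * num_heads * num_batch;
    int* d_done = nullptr;
    cudaMalloc(&d_done, sizeof(int));
    cudaMemset(d_done, 0, sizeof(int));
    static_persistent_sketch<<<16, 128>>>(total_blocks, num_m_blocks, num_heads, d_done);
    int done = 0;
    cudaMemcpy(&done, d_done, sizeof(int), cudaMemcpyDeviceToHost);
    std::printf("processed %d of %d tiles\n", done, total_blocks);
    cudaFree(d_done);
    return 0;
}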
template<int NumMmaThreads = 2 * cutlass::NumThreadsPerWarpGroup>
class DynamicPersistentTileScheduler {

protected:
    int* const tile_count_smem;

public:

    // Host side kernel arguments
    struct Arguments {
        int const num_blocks_m, num_head, num_batch;
        int const* tile_count_semaphore;
        int* const tile_count_semaphore;
    };

    // Device side kernel params
    struct Params {
        int const total_blocks;
        cutlass::FastDivmod const m_block_divmod, head_divmod;
        int const* tile_count_semaphore;
        int* const tile_count_semaphore;
    };

    static Params
    ...

@@ -253,25 +197,27 @@ public:
        return {uint32_t(num_sm)};
    }

    using WorkTileInfo = StaticPersistentTileScheduler::WorkTileInfo;
    // struct WorkTileInfo {
    //     int tile_idx;
    //     CUTLASS_DEVICE
    //     bool is_valid(Params const& params) const {
    //         return tile_idx < params.total_blocks;
    //     }
    //     CUTLASS_DEVICE
    //     cute::tuple<int32_t, int32_t, int32_t> get_block_coord(Params const& params) const {
    //         int m_block, bidh, bidb;
    //         bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_idx));
    //         return {m_block, bidh, bidb};
    //     }
    // };

    struct WorkTileInfo {
        int tile_idx;

        CUTLASS_DEVICE
        bool is_valid(Params const& params) const {
            return tile_idx < params.total_blocks;
        }

        CUTLASS_DEVICE
        cute::tuple<int32_t, int32_t, int32_t> get_block_coord(Params const& params) const {
            int m_block, bidh, bidb;
            bidb = params.head_divmod.divmod(bidh, params.m_block_divmod.divmod(m_block, tile_idx));
            return {m_block, bidh, bidb};
        }
    };
    CUTLASS_DEVICE
    DynamicPersistentTileScheduler(int* tile_count_smem_) : tile_count_smem(tile_count_smem_) {};

    CUTLASS_DEVICE
    WorkTileInfo
    ...

@@ -279,12 +225,45 @@ public:
        return {int(blockIdx.x)};
    }

    CUTLASS_DEVICE
    void init_consumer() const {
        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
    }

    CUTLASS_DEVICE
    void prefetch_next_work(Params const& params, WorkTileInfo& current_work) const {
        if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
            current_work.tile_idx = atomicAdd(params.tile_count_semaphore, 1) + int(gridDim.x);
        }
    }

    CUTLASS_DEVICE
    void broadcast_next_work(WorkTileInfo& current_work) const {
        cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
        if (threadIdx.x % cutlass::NumThreadsPerWarp == 0) {
            *tile_count_smem = current_work.tile_idx;
        }
        cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
    }
    template<bool IsProducer=false>
    CUTLASS_DEVICE
    WorkTileInfo
    get_next_work(Params const& params, WorkTileInfo const& current_work) const {
        return {current_work.tile_idx + int(gridDim.x)};
        if constexpr (IsProducer) {
            // thread 0 already has the right tile_idx, just need to broadcast to the rest of warp 0
            return {__shfl_sync(0xffffffff, current_work.tile_idx, 0 /*lane*/)};
        } else {
            cutlass::arch::NamedBarrier::sync(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::TileCountSmemFull) /*id*/);
            int tile_idx = *tile_count_smem;
            cutlass::arch::NamedBarrier::arrive(NumMmaThreads + cutlass::NumThreadsPerWarp, static_cast<int>(FwdNamedBarriers::TileCountSmemEmpty) /*id*/);
            return {tile_idx};
        }
    }

};

} // flash
\ No newline at end of file
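The dynamic scheduler above overlaps tile claiming with tile consumption: a producer warp grabs the next linear tile index with atomicAdd in prefetch_next_work, then hands it to the MMA warpgroups through shared memory guarded by the TileCountSmemFull / TileCountSmemEmpty named barriers. A much-simplified standalone sketch of the same idea follows (not part of this commit; no separate producer warp, __syncthreads in place of named barriers, and all names are hypothetical).

#include <cstdio>
#include <cuda_runtime.h>

// One global counter hands out tiles; thread 0 of each CTA claims the next tile
// with atomicAdd and publishes it to the rest of the CTA through shared memory.
__global__ void dynamic_persistent_sketch(int total_tiles, int* tile_counter, int* tiles_done) {
    __shared__ int next_tile;
    int tile_idx = blockIdx.x;              // the first gridDim.x tiles are pre-assigned
    while (tile_idx < total_tiles) {
        if (threadIdx.x == 0) {
            atomicAdd(tiles_done, 1);                 // stand-in for the per-tile attention work
            next_tile = atomicAdd(tile_counter, 1);   // claim the next tile for this CTA
        }
        __syncthreads();                    // "TileCountSmemFull": next_tile is now visible to all threads
        tile_idx = next_tile;
        __syncthreads();                    // "TileCountSmemEmpty": everyone has read before the next overwrite
    }
}

int main() {
    int total_tiles = 1000, num_ctas = 8;
    int *d_counter = nullptr, *d_done = nullptr;
    cudaMalloc(&d_counter, sizeof(int));
    cudaMalloc(&d_done, sizeof(int));
    int counter_init = num_ctas;            // counter starts past the statically assigned tiles
    cudaMemcpy(d_counter, &counter_init, sizeof(int), cudaMemcpyHostToDevice);
    cudaMemset(d_done, 0, sizeof(int));
    dynamic_persistent_sketch<<<num_ctas, 128>>>(total_tiles, d_counter, d_done);
    int done = 0;
    cudaMemcpy(&done, d_done, sizeof(int), cudaMemcpyDeviceToHost);
    std::printf("processed %d of %d tiles\n", done, total_tiles);
    cudaFree(d_counter);
    cudaFree(d_done);
    return 0;
}

Compared with the static grid-stride loop, an atomic tile counter keeps all SMs busy when per-tile cost varies (for example under causal masking), which appears to be the motivation for DynamicPersistentTileScheduler.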