gaoqiong / flash-attention

Commit 446204c7
Authored Feb 11, 2024 by skrider; committed by Woosuk Kwon on Mar 28, 2024
Parent: a3e06cd5

tests passing for single page k

Showing 3 changed files with 119 additions and 108 deletions:

csrc/flash_attn/src/debug.h             +84  -3
csrc/flash_attn/src/flash_fwd_kernel.h  +24  -97
csrc/flash_attn/src/utils.h             +11  -8
csrc/flash_attn/src/debug.h
@@ -17,9 +17,89 @@
        printf("\n[kin:end:%s]\n", #BOOL); \
    }

__forceinline__ __device__ void print_qkv_params(const Qkv_params &params) {
    // LLM generated
    printf("Qkv_params:\n");
    printf("q_ptr: %p\n", params.q_ptr);
    printf("k_ptr: %p\n", params.k_ptr);
    printf("v_ptr: %p\n", params.v_ptr);
    printf("q_batch_stride: %" PRId64 "\n", params.q_batch_stride);
    printf("k_batch_stride: %" PRId64 "\n", params.k_batch_stride);
    printf("v_batch_stride: %" PRId64 "\n", params.v_batch_stride);
    printf("q_row_stride: %" PRId64 "\n", params.q_row_stride);
    printf("k_row_stride: %" PRId64 "\n", params.k_row_stride);
    printf("v_row_stride: %" PRId64 "\n", params.v_row_stride);
    printf("q_head_stride: %" PRId64 "\n", params.q_head_stride);
    printf("k_head_stride: %" PRId64 "\n", params.k_head_stride);
    printf("v_head_stride: %" PRId64 "\n", params.v_head_stride);
    printf("h: %d\n", params.h);
    printf("h_k: %d\n", params.h_k);
    printf("h_h_k_ratio: %d\n", params.h_h_k_ratio);
}

__forceinline__ __device__ void print_flash_fwd_params(const Flash_fwd_params &params) {
    print_qkv_params(params);
    // LLM generated
    printf("struct Flash_fwd_params:\n");
    printf("o_ptr: %p\n", params.o_ptr);
    printf("oaccum_ptr: %p\n", params.oaccum_ptr);
    printf("o_batch_stride: %ld\n", params.o_batch_stride);
    printf("o_row_stride: %ld\n", params.o_row_stride);
    printf("o_head_stride: %ld\n", params.o_head_stride);
    printf("p_ptr: %p\n", params.p_ptr);
    printf("softmax_lse_ptr: %p\n", params.softmax_lse_ptr);
    printf("softmax_lseaccum_ptr: %p\n", params.softmax_lseaccum_ptr);
    printf("b: %d\n", params.b);
    printf("seqlen_q: %d\n", params.seqlen_q);
    printf("seqlen_k: %d\n", params.seqlen_k);
    printf("seqlen_knew: %d\n", params.seqlen_knew);
    printf("d: %d\n", params.d);
    printf("seqlen_q_rounded: %d\n", params.seqlen_q_rounded);
    printf("seqlen_k_rounded: %d\n", params.seqlen_k_rounded);
    printf("d_rounded: %d\n", params.d_rounded);
    printf("rotary_dim: %d\n", params.rotary_dim);
    printf("scale_softmax: %f\n", params.scale_softmax);
    printf("scale_softmax_log2: %f\n", params.scale_softmax_log2);
    printf("cu_seqlens_q: %p\n", params.cu_seqlens_q);
    printf("cu_seqlens_k: %p\n", params.cu_seqlens_k);
    printf("seqused_k: %p\n", params.seqused_k);
    printf("blockmask: %p\n", params.blockmask);
    printf("knew_ptr: %p\n", params.knew_ptr);
    printf("vnew_ptr: %p\n", params.vnew_ptr);
    printf("knew_batch_stride: %ld\n", params.knew_batch_stride);
    printf("vnew_batch_stride: %ld\n", params.vnew_batch_stride);
    printf("knew_row_stride: %ld\n", params.knew_row_stride);
    printf("vnew_row_stride: %ld\n", params.vnew_row_stride);
    printf("knew_head_stride: %ld\n", params.knew_head_stride);
    printf("vnew_head_stride: %ld\n", params.vnew_head_stride);
    printf("rotary_cos_ptr: %p\n", params.rotary_cos_ptr);
    printf("rotary_sin_ptr: %p\n", params.rotary_sin_ptr);
    printf("cache_batch_idx: %p\n", params.cache_batch_idx);
    printf("block_table: %p\n", params.block_table);
    printf("block_table_batch_stride: %ld\n", params.block_table_batch_stride);
    printf("page_block_size: %d\n", params.page_block_size);
    printf("p_dropout: %f\n", params.p_dropout);
    printf("p_dropout_in_uint8_t: %u\n", params.p_dropout_in_uint8_t);
    printf("rp_dropout: %f\n", params.rp_dropout);
    printf("scale_softmax_rp_dropout: %f\n", params.scale_softmax_rp_dropout);
    printf("window_size_left: %d\n", params.window_size_left);
    printf("window_size_right: %d\n", params.window_size_right);
    printf("philox_args: %p\n", &(params.philox_args));
    printf("rng_state: %p\n", params.rng_state);
    printf("is_bf16: %d\n", params.is_bf16);
    printf("is_causal: %d\n", params.is_causal);
    printf("is_seqlens_k_cumulative: %d\n", params.is_seqlens_k_cumulative);
    printf("is_rotary_interleaved: %d\n", params.is_rotary_interleaved);
    printf("num_splits: %d\n", params.num_splits);
    printf("alibi_slopes_ptr: %p\n", params.alibi_slopes_ptr);
    printf("alibi_slopes_batch_stride: %ld\n", params.alibi_slopes_batch_stride);
}

template <typename Kernel_traits>
__forceinline__ __device__ void print_traits() {
    // bool
    printf("Kernel_traits::Share_Q_K_smem : %s\n", Kernel_traits::Share_Q_K_smem ? "true" : "false");
    printf("Kernel_traits::Is_Q_in_regs : %s\n", Kernel_traits::Is_Q_in_regs ? "true" : "false");
@@ -36,7 +116,8 @@ print_traits() {
    printf("Kernel_traits::kSmemQSize : %d\n", Kernel_traits::kSmemQSize);
    printf("Kernel_traits::kSmemKVSize : %d\n", Kernel_traits::kSmemKVSize);
    printf("Kernel_traits::kSmemSize : %d\n", Kernel_traits::kSmemSize);
    printf("Kernel_traits::kGmemRowsPerThread: %d\n", Kernel_traits::kGmemRowsPerThread);
    printf("Kernel_traits::kGmemThreadsPerRow: %d\n", Kernel_traits::kGmemThreadsPerRow);
    printf("Kernel_traits::kGmemElemsPerLoad : %d\n", Kernel_traits::kGmemElemsPerLoad);
    // cute object
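print_qkv_params and print_flash_fwd_params use device-side printf, so every thread that calls them emits the full dump. In the kernel hunks below they are wrapped in KIN_PRINT; outside such a macro, a call of this kind is normally guarded to a single thread. A minimal sketch of that pattern follows, using a hypothetical MyParams struct and print_my_params helper rather than the real Qkv_params (which is defined elsewhere in the tree and not shown in this diff):

    // Sketch only: a stand-in params struct and a thread-0 guard for device-side
    // printf debugging. MyParams is hypothetical; the real structs are the ones
    // whose fields are printed in debug.h above.
    #include <cstdio>

    struct MyParams {                  // hypothetical stand-in for Qkv_params
        void *q_ptr; void *k_ptr; void *v_ptr;
        int h, h_k;
    };

    __device__ void print_my_params(const MyParams &p) {
        printf("q_ptr: %p\n", p.q_ptr);
        printf("k_ptr: %p\n", p.k_ptr);
        printf("v_ptr: %p\n", p.v_ptr);
        printf("h: %d  h_k: %d\n", p.h, p.h_k);
    }

    __global__ void kernel(MyParams p) {
        // Guard to one thread so the dump is printed once, not once per thread.
        if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) {
            print_my_params(p);
        }
        // ... real work ...
    }

    int main() {
        MyParams p{nullptr, nullptr, nullptr, 32, 8};
        kernel<<<1, 128>>>(p);
        cudaDeviceSynchronize();
        return 0;
    }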
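KIN_PRINT and KIN_PRINT_BOOL also come from debug.h, but their definitions sit above the hunk shown here; only the closing printf of KIN_PRINT_BOOL is visible at the top of the first hunk. A plausible shape for such macros, written purely as an assumption rather than the actual definitions, is a single-thread guard with begin/end markers around the wrapped expression:

    // Hypothetical reconstruction of KIN_PRINT-style macros (the real definitions
    // live in csrc/flash_attn/src/debug.h and are not part of this diff).
    #include <cstdio>

    #define MY_KIN_PRINT(expr)                                                    \
        do {                                                                      \
            if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 &&         \
                blockIdx.z == 0) {                                                \
                printf("\n[kin:begin:%s]\n", #expr);                              \
                expr;                                                             \
                printf("\n[kin:end:%s]\n", #expr);                                \
            }                                                                     \
        } while (0)

    #define MY_KIN_PRINT_BOOL(BOOL)                                               \
        do {                                                                      \
            if (threadIdx.x == 0 && blockIdx.x == 0 && blockIdx.y == 0 &&         \
                blockIdx.z == 0) {                                                \
                printf("\n[kin:begin:%s]\n%s", #BOOL, (BOOL) ? "true" : "false"); \
                printf("\n[kin:end:%s]\n", #BOOL);                                \
            }                                                                     \
        } while (0)

    __global__ void demo_kernel(int value) {
        MY_KIN_PRINT(printf("value = %d", value));
        MY_KIN_PRINT_BOOL(value > 0);
    }

    int main() {
        demo_kernel<<<2, 64>>>(42);
        cudaDeviceSynchronize();
        return 0;
    }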
csrc/flash_attn/src/flash_fwd_kernel.h
@@ -43,9 +43,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
    constexpr int kBlockN = Kernel_traits::kBlockN;
    constexpr int kHeadDim = Kernel_traits::kHeadDim;
    constexpr int kNWarps = Kernel_traits::kNWarps;
#if 1
    KIN_PRINT(print_traits<Kernel_traits>());
#endif
    auto seed_offset = at::cuda::philox::unpack(params.philox_args);
    flash::Dropout dropout(std::get<0>(seed_offset), std::get<1>(seed_offset), params.p_dropout_in_uint8_t,
@@ -60,9 +57,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
    const BlockInfo</*Varlen=*/!Is_even_MN> binfo(params, bidb);
    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
#if 1
    KIN_PRINT(print_binfo(binfo))
#endif
    const int n_block_min = !Is_local ? 0 : std::max(0, (m_block * kBlockM + binfo.actual_seqlen_k - binfo.actual_seqlen_q - params.window_size_left) / kBlockN);
    int n_block_max = cute::ceil_div(binfo.actual_seqlen_k, kBlockN);
@@ -144,19 +138,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
    // Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem;
    Tensor sK = make_tensor(sQ.data() + (Kernel_traits::Share_Q_K_smem ? 0 : size(sQ)),
                            typename Kernel_traits::SmemLayoutKV{});
#if 1
    KIN_PRINT(print(sK.layout()))
    KIN_PRINT(print(gK.layout()))
#endif
    Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
    Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
    Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
#if 1
    KIN_PRINT(print(sV.layout()))
    KIN_PRINT(print(sVt.layout()))
    KIN_PRINT(print(sVtNoSwizzle.layout()))
#endif
    typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_QKV;
    auto gmem_thr_copy_QKV = gmem_tiled_copy_QKV.get_thread_slice(tidx);
@@ -167,27 +152,16 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
    Tensor tKsK = gmem_thr_copy_QKV.partition_D(sK);
    Tensor tVgV = gmem_thr_copy_QKV.partition_S(gV);  // (VCPY, VCPY_N, VCPY_K)
    Tensor tVsV = gmem_thr_copy_QKV.partition_D(sV);
#if 1
    KIN_PRINT(print(tKgK.layout()))
    KIN_PRINT(print(tKsK.layout()))
#endif
    typename Kernel_traits::TiledMma tiled_mma;
    auto thr_mma = tiled_mma.get_thread_slice(tidx);
    Tensor tSrQ = thr_mma.partition_fragment_A(sQ);            // (MMA,MMA_M,MMA_K)
    Tensor tSrK = thr_mma.partition_fragment_B(sK);            // (MMA,MMA_N,MMA_K)
    Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N)
#if 1
    KIN_PRINT(print(tSrQ.layout()))
    KIN_PRINT(print(tSrK.layout()))
#endif
    Tensor tSgS = thr_mma.partition_C(gP);
    Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});  // MMA, MMA_M, MMA_K
#if 1
    KIN_PRINT(print(acc_o.layout()))
#endif
    //
    // Copy Atom retiling
@@ -195,22 +169,13 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
    auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
    auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
#if 0
    KIN_PRINT(smem_thr_copy_Q.print_all());
#endif
    // if (cute::thread0()) {smem_thr_copy_Q.print_all();}
    Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
#if 1
    KIN_PRINT(print(tSsQ.layout()))
#endif
    // if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}
    auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
    auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
    Tensor tSsK = smem_thr_copy_K.partition_S(sK);
#if 1
    KIN_PRINT(print(tSsK.layout()))
#endif
    auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
    auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
@@ -227,10 +192,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
    // Construct identity layout for sQ and sK
    Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));   // (BLK_N,BLK_K) -> (blk_n,blk_k)
#if 1
    KIN_PRINT(print(cQ.layout()))
    KIN_PRINT(print(cKV.layout()))
#endif
    // Tensor tScQ = thr_mma.partition_A(cQ);                                  // (MMA,MMA_M,MMA_K)
    // if (cute::thread0()) {
    //     print(tScQ.layout()); printf("\n");
@@ -251,12 +212,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
    // Allocate predicate tensors for k
    Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
    Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
#if 1
    KIN_PRINT(print(tQcQ.layout()))
    KIN_PRINT(print(tKVcKV.layout()))
    KIN_PRINT(print(tQpQ.layout()))
    KIN_PRINT(print(tKVpKV.layout()))
#endif
    // Set predicates for k bounds
    if (!Is_even_K) {
@@ -538,13 +493,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
    constexpr int kNWarps = Kernel_traits::kNWarps;
#if 1
    KIN_PRINT(print_traits<Kernel_traits>())
    KIN_PRINT_BOOL(Is_causal)
    KIN_PRINT_BOOL(Is_local)
    KIN_PRINT_BOOL(Has_alibi)
    KIN_PRINT_BOOL(Is_even_MN)
    KIN_PRINT_BOOL(Is_even_K)
    KIN_PRINT_BOOL(Split)
    KIN_PRINT_BOOL(Append_KV)
    KIN_PRINT(print_flash_fwd_params(params))
#endif
    using GmemTiledCopyO = std::conditional_t<
@@ -558,9 +507,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); }
    // if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); }
    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;
#if 1
    KIN_PRINT(print_binfo(binfo))
#endif
    const int n_blocks_per_split = ((params.seqlen_k + kBlockN - 1) / kBlockN + num_n_splits - 1) / num_n_splits;
    const int n_block_min = !Is_local
@@ -625,17 +571,24 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
        ? binfo.k_offset(params.k_batch_stride, params.k_row_stride, bidb_cache)
            + (n_block_max - 1) * kBlockN * params.k_row_stride + (bidh / params.h_h_k_ratio) * params.k_head_stride
        : (bidh / params.h_h_k_ratio) * params.k_head_stride;  // block addresses are later resolved per-thread
    const index_t row_offset_k__shadow = block_table[(n_block_max - 1) * kBlockN / params.page_block_size] * params.k_batch_stride
        + (((n_block_max - 1) * kBlockN) % params.page_block_size) * params.k_row_stride
        + (bidh / params.h_h_k_ratio) * params.k_head_stride;
    const index_t row_offset_v = block_table == nullptr
        ? binfo.k_offset(params.v_batch_stride, params.v_row_stride, bidb_cache)
            + (n_block_max - 1) * kBlockN * params.v_row_stride + (bidh / params.h_h_k_ratio) * params.v_head_stride
        : (bidh / params.h_h_k_ratio) * params.v_head_stride;
    Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
                            Shape<Int<kBlockM>, Int<kHeadDim>>{},
                            make_stride(params.q_row_stride, _1{}));
    Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
                            Shape<Int<kBlockN>, Int<kHeadDim>>{},
                            make_stride(params.k_row_stride, _1{}));
    Tensor gK__shadow = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k__shadow),
                            Shape<Int<kBlockN>, Int<kHeadDim>>{},
                            make_stride(params.k_row_stride, _1{}));
    // if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("k_ptr = %p, row_offset_k = %d, gK_ptr = %p\n", params.k_ptr, row_offset_k, gK.data()); }
    Tensor gV = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.v_ptr) + row_offset_v),
                            Shape<Int<kBlockN>, Int<kHeadDim>>{},
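row_offset_k__shadow is the paged form of row_offset_k: the first row of the last K tile is mapped through the block table, with block_table[row / page_block_size] selecting the physical page and row % page_block_size giving the row inside that page (the per-head stride term is the same in both forms). A host-side illustration of that arithmetic follows; the tile size, page size, strides, and block-table contents are made-up example values, and the per-head term is dropped for brevity:

    // Host-side illustration of the paged-KV offset math used above; all
    // concrete numbers are assumptions for the example, not values from this diff.
    #include <cstdio>
    #include <cstdint>

    int main() {
        const int kBlockN         = 128;                 // assumed K/V tile height
        const int page_block_size = 256;                 // assumed rows per physical page
        const int n_block_max     = 5;                   // tiles to process for this query block
        const int block_table[]   = {7, 3, 9};           // logical page -> physical page
        const int64_t k_batch_stride = 256 * 64;         // elements per physical page (rows * head_dim)
        const int64_t k_row_stride   = 64;               // elements per row

        // First row of the last K tile, as a logical row index into this sequence.
        const int logical_row      = (n_block_max - 1) * kBlockN;         // 512
        const int virtual_page_idx = logical_row / page_block_size;       // 2
        const int page_offset      = logical_row % page_block_size;       // 0

        const int64_t row_offset_k = block_table[virtual_page_idx] * k_batch_stride
                                   + page_offset * k_row_stride;          // 9 * 16384 + 0

        printf("logical_row=%d -> physical page %d, row %d, element offset %lld\n",
               logical_row, block_table[virtual_page_idx], page_offset,
               (long long)row_offset_k);
        return 0;
    }

With these numbers, logical row 512 lands at row 0 of physical page 9, i.e. element offset 147456 into the K cache.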
@@ -646,13 +599,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
    Tensor sV = make_tensor(sK.data() + size(sK), typename Kernel_traits::SmemLayoutKV{});
    Tensor sVt = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposed{});
    Tensor sVtNoSwizzle = make_tensor(sV.data(), typename Kernel_traits::SmemLayoutVtransposedNoSwizzle{});
#if 1
    KIN_PRINT(print(sK.layout()))
    KIN_PRINT(print(gK.layout()))
    KIN_PRINT(print(sV.layout()))
    KIN_PRINT(print(sVt.layout()))
    KIN_PRINT(print(sVtNoSwizzle.layout()))
#endif
    typename Kernel_traits::GmemTiledCopyQKV gmem_tiled_copy_Q;
    auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
@@ -662,27 +608,31 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
    Tensor tQgQ = gmem_thr_copy_Q.partition_S(gQ);
    Tensor tQsQ = gmem_thr_copy_Q.partition_D(sQ);
    Tensor tKgK = gmem_thr_copy_KV.partition_S(gK);                  // (KCPY, KCPY_N, KCPY_K)
    Tensor tKgK__shadow = gmem_thr_copy_KV.partition_S(gK__shadow);  // (KCPY, KCPY_N, KCPY_K)
    Tensor tKsK = gmem_thr_copy_KV.partition_D(sK);
    Tensor tVgV = gmem_thr_copy_KV.partition_S(gV);                  // (VCPY, VCPY_N, VCPY_K)
    Tensor tVsV = gmem_thr_copy_KV.partition_D(sV);
    if (block_table != nullptr) {
        tKgK.data() = gV.data() + flash::init_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size,
        tKgK.data() = gK.data() + flash::init_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size,
                                                                                         block_table, params.k_batch_stride, params.k_row_stride);
        tVgV.data() = gV.data() + flash::init_thread_kv_page_slice_offset<Kernel_traits>(tidx, n_block_max, params.page_block_size,
                                                                                         block_table, params.v_batch_stride, params.v_row_stride);
    }
#if 1
    KIN_PRINT(print(tKgK.layout()))
    KIN_PRINT(print(tKsK.layout()))
#endif
#if 1
    fill(tVgV, 1.f * ((Element)tidx));
    __syncthreads();
    KIN_PRINT(print_tensor(gV))
    KIN_PRINT([&]() { for (int i = 0; i < n_block_max; i++) { printf("%d ", block_table[i]); } }())
    // if (tidx == 8) fill(tKgK, 1.f * tidx);
    // if (thread0()) {
    //     gK.data() = tKgK.data();
    // }
    KIN_PRINT(print_tensor(tKgK))
    KIN_PRINT(print_tensor(gK))
    KIN_PRINT(print_tensor(tKgK__shadow))
    KIN_PRINT(print_tensor(gK__shadow))
#endif
    typename Kernel_traits::TiledMma tiled_mma;
@@ -690,15 +640,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
    Tensor tSrQ = thr_mma.partition_fragment_A(sQ);            // (MMA,MMA_M,MMA_K)
    Tensor tSrK = thr_mma.partition_fragment_B(sK);            // (MMA,MMA_N,MMA_K)
    Tensor tOrVt = thr_mma.partition_fragment_B(sVtNoSwizzle); // (MMA, MMA_K,MMA_N)
#if 1
    KIN_PRINT(print(tSrQ.layout()))
    KIN_PRINT(print(tSrK.layout()))
#endif
    Tensor acc_o = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kHeadDim>>{});  // MMA, MMA_M, MMA_K
#if 1
    KIN_PRINT(print(acc_o.layout()))
#endif
    //
    // Copy Atom retiling
@@ -707,16 +650,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
    auto smem_tiled_copy_Q = make_tiled_copy_A(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
    auto smem_thr_copy_Q = smem_tiled_copy_Q.get_thread_slice(tidx);
    Tensor tSsQ = smem_thr_copy_Q.partition_S(sQ);
#if 1
    KIN_PRINT(print(tSsQ.layout()))
#endif
    auto smem_tiled_copy_K = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtom{}, tiled_mma);
    auto smem_thr_copy_K = smem_tiled_copy_K.get_thread_slice(tidx);
    Tensor tSsK = smem_thr_copy_K.partition_S(sK);
#if 1
    KIN_PRINT(print(tSsK.layout()))
#endif
    auto smem_tiled_copy_V = make_tiled_copy_B(typename Kernel_traits::SmemCopyAtomTransposed{}, tiled_mma);
    auto smem_thr_copy_V = smem_tiled_copy_V.get_thread_slice(tidx);
@@ -732,10 +669,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
    // Construct identity layout for sQ and sK
    Tensor cQ = make_identity_tensor(make_shape(size<0>(sQ), size<1>(sQ)));    // (BLK_M,BLK_K) -> (blk_m,blk_k)
    Tensor cKV = make_identity_tensor(make_shape(size<0>(sK), size<1>(sK)));   // (BLK_N,BLK_K) -> (blk_n,blk_k)
#if 1
    KIN_PRINT(print(cQ.layout()))
    KIN_PRINT(print(cKV.layout()))
#endif
    // Repeat the partitioning with identity layouts
    Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);                             // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
@@ -744,12 +677,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
    // Allocate predicate tensors for k
    Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
    Tensor tKVpKV = make_tensor<bool>(make_shape(size<2>(tKsK)));
#if 1
    KIN_PRINT(print(tQcQ.layout()))
    KIN_PRINT(print(tKVcKV.layout()))
    KIN_PRINT(print(tQpQ.layout()))
    KIN_PRINT(print(tKVpKV.layout()))
#endif
    // Set predicates for k bounds
    if (!Is_even_K) {
csrc/flash_attn/src/utils.h
@@ -4,6 +4,8 @@
#pragma once
#include "debug.h"
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
@@ -298,16 +300,17 @@ template <typename Kernel_traits>
__forceinline__ __device__ int init_thread_kv_page_slice_offset(const int tidx, const int n_block_max, const int page_block_size,
                                                                 const int *block_table, const int page_stride, const int row_stride) {
    // base col of thread's slice relative to the block
    const int col_offset = tidx % Kernel_traits::kGmemThreadsPerRow * Kernel_traits::kGmemElemsPerLoad;
    // base row of thread's slice relative to the block
    const int block_row_offset = tidx / Kernel_traits::kGmemThreadsPerRow * Kernel_traits::kGmemRowsPerThread;
    // base col of thread's slice relative to the entire tensor
    const int global_row_offset = block_row_offset + (n_block_max - 1) * Kernel_traits::kBlockN;
    // base row of thread's slice relative to the page
    constexpr int kGmemThreadsPerRow = Kernel_traits::kGmemThreadsPerRow;
    constexpr int kGmemRowsPerThread = Kernel_traits::kGmemRowsPerThread;
    constexpr int kGmemElemsPerLoad = Kernel_traits::kGmemElemsPerLoad;
    constexpr int kBlockN = Kernel_traits::kBlockN;
    const int col_offset = tidx % kGmemThreadsPerRow * kGmemElemsPerLoad;
    const int block_row_offset = tidx / kGmemThreadsPerRow * kGmemRowsPerThread;
    const int global_row_offset = block_row_offset + (n_block_max - 1) * kBlockN;
    const int page_offset = global_row_offset % page_block_size;
    const int virtual_page_idx = global_row_offset / page_block_size;
    KIN_PRINT(printf("%d", virtual_page_idx))
    return block_table[virtual_page_idx] * page_stride
        + page_offset * row_stride
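init_thread_kv_page_slice_offset applies the same block-table translation per thread: the thread's starting row inside the last K/V tile comes from the gmem copy layout (kGmemThreadsPerRow, kGmemRowsPerThread, kGmemElemsPerLoad), is converted to a row index over the whole sequence, and is then split into a physical page plus a row within that page. A standalone sketch of that arithmetic with assumed trait values follows; the real values depend on the Kernel_traits instantiation and are not taken from this diff:

    // Standalone sketch of the per-thread paged-KV offset computed above; the
    // trait values, strides, and block table are illustrative assumptions.
    #include <cstdio>

    int main() {
        const int kGmemThreadsPerRow = 8;    // assumed: threads cooperating on one row
        const int kGmemRowsPerThread = 8;    // assumed: rows owned by each thread
        const int kGmemElemsPerLoad  = 8;    // assumed: elements per vectorized load
        const int kBlockN            = 128;  // assumed tile height
        const int page_block_size    = 256;
        const int n_block_max        = 5;
        const int block_table[]      = {7, 3, 9};
        const int page_stride        = 256 * 64;   // elements per physical page
        const int row_stride         = 64;         // elements per row

        const int tidx = 42;                                                          // example thread id
        const int col_offset        = tidx % kGmemThreadsPerRow * kGmemElemsPerLoad;  // 2 * 8 = 16
        const int block_row_offset  = tidx / kGmemThreadsPerRow * kGmemRowsPerThread; // 5 * 8 = 40
        const int global_row_offset = block_row_offset + (n_block_max - 1) * kBlockN; // 40 + 512 = 552
        const int page_offset       = global_row_offset % page_block_size;            // 40
        const int virtual_page_idx  = global_row_offset / page_block_size;            // 2

        const int offset = block_table[virtual_page_idx] * page_stride
                         + page_offset * row_stride;                                  // 9 * 16384 + 40 * 64

        printf("tidx=%d: col %d, page %d, row-in-page %d, offset %d\n",
               tidx, col_offset, virtual_page_idx, page_offset, offset);
        return 0;
    }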