Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
14b190bc
Commit
14b190bc
authored
Feb 11, 2024
by
skrider
Browse files
reshape gmem copy
parent
ac5e78a6
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
103 additions
and
73 deletions
+103
-73
csrc/flash_attn/src/debug.h
csrc/flash_attn/src/debug.h
+16
-6
csrc/flash_attn/src/flash_fwd_kernel.h
csrc/flash_attn/src/flash_fwd_kernel.h
+76
-67
csrc/flash_attn/src/kernel_traits.h
csrc/flash_attn/src/kernel_traits.h
+11
-0
No files found.
csrc/flash_attn/src/debug.h
View file @
14b190bc
...
@@ -3,18 +3,18 @@
...
@@ -3,18 +3,18 @@
#pragma once
#pragma once
#define KIN_PRINT(
tag,
statement) \
#define KIN_PRINT(statement) \
if (thread0()) { \
if (thread0()) { \
printf("\n[kin:start:%s]\n",
tag
); \
printf("\n[kin:start:%s]\n",
#statement
); \
statement; \
statement; \
printf("\n[kin:end:%s]\n",
tag
); \
printf("\n[kin:end:%s]\n",
#statement
); \
}
}
#define KIN_PRINT_BOOL(
tag,
BOOL) \
#define KIN_PRINT_BOOL(BOOL) \
if (thread0()) { \
if (thread0()) { \
printf("\n[kin:start:%s]\n",
tag
); \
printf("\n[kin:start:%s]\n",
#BOOL
); \
printf("%s", BOOL ? "true" : "false"); \
printf("%s", BOOL ? "true" : "false"); \
printf("\n[kin:end:%s]\n",
tag
); \
printf("\n[kin:end:%s]\n",
#BOOL
); \
}
}
template
<
typename
Kernel_traits
>
template
<
typename
Kernel_traits
>
...
@@ -36,7 +36,17 @@ print_traits() {
...
@@ -36,7 +36,17 @@ print_traits() {
printf
(
"Kernel_traits::kSmemQSize : %d
\n
"
,
Kernel_traits
::
kSmemQSize
);
printf
(
"Kernel_traits::kSmemQSize : %d
\n
"
,
Kernel_traits
::
kSmemQSize
);
printf
(
"Kernel_traits::kSmemKVSize : %d
\n
"
,
Kernel_traits
::
kSmemKVSize
);
printf
(
"Kernel_traits::kSmemKVSize : %d
\n
"
,
Kernel_traits
::
kSmemKVSize
);
printf
(
"Kernel_traits::kSmemSize : %d
\n
"
,
Kernel_traits
::
kSmemSize
);
printf
(
"Kernel_traits::kSmemSize : %d
\n
"
,
Kernel_traits
::
kSmemSize
);
printf
(
"Kernel_traits::kGmemRowsPerThread: %d
\n
"
,
Kernel_traits
::
kGmemRowsPerThread
);
printf
(
"Kernel_traits::kGmemElemsPerLoad : %d
\n
"
,
Kernel_traits
::
kGmemElemsPerLoad
);
printf
(
"Kernel_traits::kGmemElemsPerLoad : %d
\n
"
,
Kernel_traits
::
kGmemElemsPerLoad
);
// cute object
printf
(
"Kernel_traits::GmemLayoutAtom : "
);
cute
::
print
(
Kernel_traits
::
GmemLayoutAtom
());
printf
(
"
\n
"
);
printf
(
"Kernel_traits::GmemTiledCopyQKV :
\n
"
);
cute
::
print
(
Kernel_traits
::
GmemTiledCopyQKV
());
printf
(
"
\n
"
);
}
}
template
<
typename
BlockInfo
>
template
<
typename
BlockInfo
>
...
...
csrc/flash_attn/src/flash_fwd_kernel.h
View file @
14b190bc
...
@@ -44,7 +44,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
...
@@ -44,7 +44,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
constexpr
int
kHeadDim
=
Kernel_traits
::
kHeadDim
;
constexpr
int
kHeadDim
=
Kernel_traits
::
kHeadDim
;
constexpr
int
kNWarps
=
Kernel_traits
::
kNWarps
;
constexpr
int
kNWarps
=
Kernel_traits
::
kNWarps
;
#if 1
#if 1
KIN_PRINT
(
"Kernel_traits"
,
print_traits
<
Kernel_traits
>
());
KIN_PRINT
(
print_traits
<
Kernel_traits
>
());
#endif
#endif
auto
seed_offset
=
at
::
cuda
::
philox
::
unpack
(
params
.
philox_args
);
auto
seed_offset
=
at
::
cuda
::
philox
::
unpack
(
params
.
philox_args
);
...
@@ -61,7 +61,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
...
@@ -61,7 +61,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
const
BlockInfo
<
/*Varlen=*/
!
Is_even_MN
>
binfo
(
params
,
bidb
);
const
BlockInfo
<
/*Varlen=*/
!
Is_even_MN
>
binfo
(
params
,
bidb
);
if
(
m_block
*
kBlockM
>=
binfo
.
actual_seqlen_q
)
return
;
if
(
m_block
*
kBlockM
>=
binfo
.
actual_seqlen_q
)
return
;
#if 1
#if 1
KIN_PRINT
(
"binfo"
,
print_binfo
(
binfo
))
KIN_PRINT
(
print_binfo
(
binfo
))
#endif
#endif
const
int
n_block_min
=
!
Is_local
?
0
:
std
::
max
(
0
,
(
m_block
*
kBlockM
+
binfo
.
actual_seqlen_k
-
binfo
.
actual_seqlen_q
-
params
.
window_size_left
)
/
kBlockN
);
const
int
n_block_min
=
!
Is_local
?
0
:
std
::
max
(
0
,
(
m_block
*
kBlockM
+
binfo
.
actual_seqlen_k
-
binfo
.
actual_seqlen_q
-
params
.
window_size_left
)
/
kBlockN
);
...
@@ -145,17 +145,17 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
...
@@ -145,17 +145,17 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor
sK
=
make_tensor
(
sQ
.
data
()
+
(
Kernel_traits
::
Share_Q_K_smem
?
0
:
size
(
sQ
)),
Tensor
sK
=
make_tensor
(
sQ
.
data
()
+
(
Kernel_traits
::
Share_Q_K_smem
?
0
:
size
(
sQ
)),
typename
Kernel_traits
::
SmemLayoutKV
{});
typename
Kernel_traits
::
SmemLayoutKV
{});
#if 1
#if 1
KIN_PRINT
(
"sK.layout()"
,
print
(
sK
.
layout
()))
KIN_PRINT
(
print
(
sK
.
layout
()))
KIN_PRINT
(
"gK.layout()"
,
print
(
gK
.
layout
()))
KIN_PRINT
(
print
(
gK
.
layout
()))
#endif
#endif
Tensor
sV
=
make_tensor
(
sK
.
data
()
+
size
(
sK
),
typename
Kernel_traits
::
SmemLayoutKV
{});
Tensor
sV
=
make_tensor
(
sK
.
data
()
+
size
(
sK
),
typename
Kernel_traits
::
SmemLayoutKV
{});
Tensor
sVt
=
make_tensor
(
sV
.
data
(),
typename
Kernel_traits
::
SmemLayoutVtransposed
{});
Tensor
sVt
=
make_tensor
(
sV
.
data
(),
typename
Kernel_traits
::
SmemLayoutVtransposed
{});
Tensor
sVtNoSwizzle
=
make_tensor
(
sV
.
data
(),
typename
Kernel_traits
::
SmemLayoutVtransposedNoSwizzle
{});
Tensor
sVtNoSwizzle
=
make_tensor
(
sV
.
data
(),
typename
Kernel_traits
::
SmemLayoutVtransposedNoSwizzle
{});
#if 1
#if 1
KIN_PRINT
(
"sV.layout()"
,
print
(
sV
.
layout
()))
KIN_PRINT
(
print
(
sV
.
layout
()))
KIN_PRINT
(
"sVt.layout()"
,
print
(
sVt
.
layout
()))
KIN_PRINT
(
print
(
sVt
.
layout
()))
KIN_PRINT
(
"sVtNoSwizzle.layout()"
,
print
(
sVtNoSwizzle
.
layout
()))
KIN_PRINT
(
print
(
sVtNoSwizzle
.
layout
()))
#endif
#endif
typename
Kernel_traits
::
GmemTiledCopyQKV
gmem_tiled_copy_QKV
;
typename
Kernel_traits
::
GmemTiledCopyQKV
gmem_tiled_copy_QKV
;
...
@@ -168,8 +168,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
...
@@ -168,8 +168,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor
tVgV
=
gmem_thr_copy_QKV
.
partition_S
(
gV
);
// (VCPY, VCPY_N, VCPY_K)
Tensor
tVgV
=
gmem_thr_copy_QKV
.
partition_S
(
gV
);
// (VCPY, VCPY_N, VCPY_K)
Tensor
tVsV
=
gmem_thr_copy_QKV
.
partition_D
(
sV
);
Tensor
tVsV
=
gmem_thr_copy_QKV
.
partition_D
(
sV
);
#if 1
#if 1
KIN_PRINT
(
"tKgK.layout()"
,
print
(
tKgK
.
layout
()))
KIN_PRINT
(
print
(
tKgK
.
layout
()))
KIN_PRINT
(
"tKsK.layout()"
,
print
(
tKsK
.
layout
()))
KIN_PRINT
(
print
(
tKsK
.
layout
()))
#endif
#endif
typename
Kernel_traits
::
TiledMma
tiled_mma
;
typename
Kernel_traits
::
TiledMma
tiled_mma
;
...
@@ -178,15 +178,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
...
@@ -178,15 +178,15 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor
tSrK
=
thr_mma
.
partition_fragment_B
(
sK
);
// (MMA,MMA_N,MMA_K)
Tensor
tSrK
=
thr_mma
.
partition_fragment_B
(
sK
);
// (MMA,MMA_N,MMA_K)
Tensor
tOrVt
=
thr_mma
.
partition_fragment_B
(
sVtNoSwizzle
);
// (MMA, MMA_K,MMA_N)
Tensor
tOrVt
=
thr_mma
.
partition_fragment_B
(
sVtNoSwizzle
);
// (MMA, MMA_K,MMA_N)
#if 1
#if 1
KIN_PRINT
(
"tSrQ.layout()"
,
print
(
tSrQ
.
layout
()))
KIN_PRINT
(
print
(
tSrQ
.
layout
()))
KIN_PRINT
(
"tSrK.layout()"
,
print
(
tSrK
.
layout
()))
KIN_PRINT
(
print
(
tSrK
.
layout
()))
#endif
#endif
Tensor
tSgS
=
thr_mma
.
partition_C
(
gP
);
Tensor
tSgS
=
thr_mma
.
partition_C
(
gP
);
Tensor
acc_o
=
partition_fragment_C
(
tiled_mma
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDim
>>
{});
// MMA, MMA_M, MMA_K
Tensor
acc_o
=
partition_fragment_C
(
tiled_mma
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDim
>>
{});
// MMA, MMA_M, MMA_K
#if 1
#if 1
KIN_PRINT
(
"acc_o.layout()"
,
print
(
acc_o
.
layout
()))
KIN_PRINT
(
print
(
acc_o
.
layout
()))
#endif
#endif
//
//
...
@@ -196,12 +196,12 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
...
@@ -196,12 +196,12 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
auto
smem_tiled_copy_Q
=
make_tiled_copy_A
(
typename
Kernel_traits
::
SmemCopyAtom
{},
tiled_mma
);
auto
smem_tiled_copy_Q
=
make_tiled_copy_A
(
typename
Kernel_traits
::
SmemCopyAtom
{},
tiled_mma
);
auto
smem_thr_copy_Q
=
smem_tiled_copy_Q
.
get_thread_slice
(
tidx
);
auto
smem_thr_copy_Q
=
smem_tiled_copy_Q
.
get_thread_slice
(
tidx
);
#if 0
#if 0
KIN_PRINT(
"fail",
smem_thr_copy_Q.print_all());
KIN_PRINT(smem_thr_copy_Q.print_all());
#endif
#endif
// if (cute::thread0()) {smem_thr_copy_Q.print_all();}
// if (cute::thread0()) {smem_thr_copy_Q.print_all();}
Tensor
tSsQ
=
smem_thr_copy_Q
.
partition_S
(
sQ
);
Tensor
tSsQ
=
smem_thr_copy_Q
.
partition_S
(
sQ
);
#if 1
#if 1
KIN_PRINT
(
"tSsQ.layout()"
,
print
(
tSsQ
.
layout
()))
KIN_PRINT
(
print
(
tSsQ
.
layout
()))
#endif
#endif
// if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}
// if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}
...
@@ -209,7 +209,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
...
@@ -209,7 +209,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
auto
smem_thr_copy_K
=
smem_tiled_copy_K
.
get_thread_slice
(
tidx
);
auto
smem_thr_copy_K
=
smem_tiled_copy_K
.
get_thread_slice
(
tidx
);
Tensor
tSsK
=
smem_thr_copy_K
.
partition_S
(
sK
);
Tensor
tSsK
=
smem_thr_copy_K
.
partition_S
(
sK
);
#if 1
#if 1
KIN_PRINT
(
"tSsK.layout()"
,
print
(
tSsK
.
layout
()))
KIN_PRINT
(
print
(
tSsK
.
layout
()))
#endif
#endif
auto
smem_tiled_copy_V
=
make_tiled_copy_B
(
typename
Kernel_traits
::
SmemCopyAtomTransposed
{},
tiled_mma
);
auto
smem_tiled_copy_V
=
make_tiled_copy_B
(
typename
Kernel_traits
::
SmemCopyAtomTransposed
{},
tiled_mma
);
...
@@ -228,8 +228,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
...
@@ -228,8 +228,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor
cQ
=
make_identity_tensor
(
make_shape
(
size
<
0
>
(
sQ
),
size
<
1
>
(
sQ
)));
// (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor
cQ
=
make_identity_tensor
(
make_shape
(
size
<
0
>
(
sQ
),
size
<
1
>
(
sQ
)));
// (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor
cKV
=
make_identity_tensor
(
make_shape
(
size
<
0
>
(
sK
),
size
<
1
>
(
sK
)));
// (BLK_N,BLK_K) -> (blk_n,blk_k)
Tensor
cKV
=
make_identity_tensor
(
make_shape
(
size
<
0
>
(
sK
),
size
<
1
>
(
sK
)));
// (BLK_N,BLK_K) -> (blk_n,blk_k)
#if 1
#if 1
KIN_PRINT
(
"cQ.layout()"
,
print
(
cQ
.
layout
()))
KIN_PRINT
(
print
(
cQ
.
layout
()))
KIN_PRINT
(
"cKV.layout()"
,
print
(
cKV
.
layout
()))
KIN_PRINT
(
print
(
cKV
.
layout
()))
#endif
#endif
// Tensor tScQ = thr_mma.partition_A(cQ); // (MMA,MMA_M,MMA_K)
// Tensor tScQ = thr_mma.partition_A(cQ); // (MMA,MMA_M,MMA_K)
// if (cute::thread0()) {
// if (cute::thread0()) {
...
@@ -252,10 +252,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
...
@@ -252,10 +252,10 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor
tQpQ
=
make_tensor
<
bool
>
(
make_shape
(
size
<
2
>
(
tQsQ
)));
Tensor
tQpQ
=
make_tensor
<
bool
>
(
make_shape
(
size
<
2
>
(
tQsQ
)));
Tensor
tKVpKV
=
make_tensor
<
bool
>
(
make_shape
(
size
<
2
>
(
tKsK
)));
Tensor
tKVpKV
=
make_tensor
<
bool
>
(
make_shape
(
size
<
2
>
(
tKsK
)));
#if 1
#if 1
KIN_PRINT
(
"tQcQ.layout()"
,
print
(
tQcQ
.
layout
()))
KIN_PRINT
(
print
(
tQcQ
.
layout
()))
KIN_PRINT
(
"tKVcKV.layout()"
,
print
(
tKVcKV
.
layout
()))
KIN_PRINT
(
print
(
tKVcKV
.
layout
()))
KIN_PRINT
(
"tQpQ.layout()"
,
print
(
tQpQ
.
layout
()))
KIN_PRINT
(
print
(
tQpQ
.
layout
()))
KIN_PRINT
(
"tKVpKV.layout()"
,
print
(
tKVpKV
.
layout
()))
KIN_PRINT
(
print
(
tKVpKV
.
layout
()))
#endif
#endif
// Set predicates for k bounds
// Set predicates for k bounds
...
@@ -537,14 +537,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
...
@@ -537,14 +537,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
constexpr
int
kHeadDim
=
Kernel_traits
::
kHeadDim
;
constexpr
int
kHeadDim
=
Kernel_traits
::
kHeadDim
;
constexpr
int
kNWarps
=
Kernel_traits
::
kNWarps
;
constexpr
int
kNWarps
=
Kernel_traits
::
kNWarps
;
#if 1
#if 1
KIN_PRINT
(
"Kernel_traits"
,
print_traits
<
Kernel_traits
>
())
KIN_PRINT
(
print_traits
<
Kernel_traits
>
())
KIN_PRINT_BOOL
(
"Is_causal"
,
Is_causal
)
KIN_PRINT_BOOL
(
Is_causal
)
KIN_PRINT_BOOL
(
"Is_local"
,
Is_local
)
KIN_PRINT_BOOL
(
Is_local
)
KIN_PRINT_BOOL
(
"Has_alibi"
,
Has_alibi
)
KIN_PRINT_BOOL
(
Has_alibi
)
KIN_PRINT_BOOL
(
"Is_even_MN"
,
Is_even_MN
)
KIN_PRINT_BOOL
(
Is_even_MN
)
KIN_PRINT_BOOL
(
"Is_even_K"
,
Is_even_K
)
KIN_PRINT_BOOL
(
Is_even_K
)
KIN_PRINT_BOOL
(
"Split"
,
Split
)
KIN_PRINT_BOOL
(
Split
)
KIN_PRINT_BOOL
(
"Append_KV"
,
Append_KV
)
KIN_PRINT_BOOL
(
Append_KV
)
#endif
#endif
using
GmemTiledCopyO
=
std
::
conditional_t
<
using
GmemTiledCopyO
=
std
::
conditional_t
<
...
@@ -559,7 +559,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
...
@@ -559,7 +559,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); }
// if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); }
if
(
m_block
*
kBlockM
>=
binfo
.
actual_seqlen_q
)
return
;
if
(
m_block
*
kBlockM
>=
binfo
.
actual_seqlen_q
)
return
;
#if 1
#if 1
KIN_PRINT
(
"binfo"
,
print_binfo
(
binfo
))
KIN_PRINT
(
print_binfo
(
binfo
))
#endif
#endif
const
int
n_blocks_per_split
=
((
params
.
seqlen_k
+
kBlockN
-
1
)
/
kBlockN
+
num_n_splits
-
1
)
/
num_n_splits
;
const
int
n_blocks_per_split
=
((
params
.
seqlen_k
+
kBlockN
-
1
)
/
kBlockN
+
num_n_splits
-
1
)
/
num_n_splits
;
...
@@ -649,25 +649,34 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
...
@@ -649,25 +649,34 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor
sVt
=
make_tensor
(
sV
.
data
(),
typename
Kernel_traits
::
SmemLayoutVtransposed
{});
Tensor
sVt
=
make_tensor
(
sV
.
data
(),
typename
Kernel_traits
::
SmemLayoutVtransposed
{});
Tensor
sVtNoSwizzle
=
make_tensor
(
sV
.
data
(),
typename
Kernel_traits
::
SmemLayoutVtransposedNoSwizzle
{});
Tensor
sVtNoSwizzle
=
make_tensor
(
sV
.
data
(),
typename
Kernel_traits
::
SmemLayoutVtransposedNoSwizzle
{});
#if 1
#if 1
KIN_PRINT
(
"sK.layout()"
,
print
(
sK
.
layout
()))
KIN_PRINT
(
print
(
sK
.
layout
()))
KIN_PRINT
(
"gK.layout()"
,
print
(
gK
.
layout
()))
KIN_PRINT
(
print
(
gK
.
layout
()))
KIN_PRINT
(
"sV.layout()"
,
print
(
sV
.
layout
()))
KIN_PRINT
(
print
(
sV
.
layout
()))
KIN_PRINT
(
"sVt.layout()"
,
print
(
sVt
.
layout
()))
KIN_PRINT
(
print
(
sVt
.
layout
()))
KIN_PRINT
(
"sVtNoSwizzle.layout()"
,
print
(
sVtNoSwizzle
.
layout
()))
KIN_PRINT
(
print
(
sVtNoSwizzle
.
layout
()))
#endif
#endif
typename
Kernel_traits
::
GmemTiledCopyQKV
gmem_tiled_copy_QKV
;
typename
Kernel_traits
::
GmemTiledCopyQKV
gmem_tiled_copy_Q
;
auto
gmem_thr_copy_QKV
=
gmem_tiled_copy_QKV
.
get_thread_slice
(
tidx
);
auto
gmem_thr_copy_Q
=
gmem_tiled_copy_Q
.
get_thread_slice
(
tidx
);
typename
Kernel_traits
::
GmemTiledCopyQKVPaged
gmem_tiled_copy_KV
;
auto
gmem_thr_copy_KV
=
gmem_tiled_copy_KV
.
get_thread_slice
(
tidx
);
Tensor
tQgQ
=
gmem_thr_copy_Q
.
partition_S
(
gQ
);
Tensor
tQsQ
=
gmem_thr_copy_Q
.
partition_D
(
sQ
);
Tensor
tKgK
=
gmem_thr_copy_KV
.
partition_S
(
gK
);
// (KCPY, KCPY_N, KCPY_K)
Tensor
tKsK
=
gmem_thr_copy_KV
.
partition_D
(
sK
);
Tensor
tVgV
=
gmem_thr_copy_KV
.
partition_S
(
gV
);
// (VCPY, VCPY_N, VCPY_K)
Tensor
tVsV
=
gmem_thr_copy_KV
.
partition_D
(
sV
);
#if 1
KIN_PRINT
(
print
(
tKgK
.
layout
()))
KIN_PRINT
(
print
(
tKsK
.
layout
()))
#endif
Tensor
tQgQ
=
gmem_thr_copy_QKV
.
partition_S
(
gQ
);
Tensor
tQsQ
=
gmem_thr_copy_QKV
.
partition_D
(
sQ
);
Tensor
tKgK
=
gmem_thr_copy_QKV
.
partition_S
(
gK
);
// (KCPY, KCPY_N, KCPY_K)
Tensor
tKsK
=
gmem_thr_copy_QKV
.
partition_D
(
sK
);
Tensor
tVgV
=
gmem_thr_copy_QKV
.
partition_S
(
gV
);
// (VCPY, VCPY_N, VCPY_K)
Tensor
tVsV
=
gmem_thr_copy_QKV
.
partition_D
(
sV
);
#if 1
#if 1
KIN_PRINT
(
"tKgK.layout()"
,
print
(
tKgK
.
layout
()))
fill
(
tVgV
,
1.
f
*
((
Element
)
tidx
));
KIN_PRINT
(
"tKsK.layout()"
,
print
(
tKsK
.
layout
()))
__syncthreads
();
KIN_PRINT
(
print_tensor
(
gV
))
#endif
#endif
typename
Kernel_traits
::
TiledMma
tiled_mma
;
typename
Kernel_traits
::
TiledMma
tiled_mma
;
...
@@ -676,13 +685,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
...
@@ -676,13 +685,13 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor
tSrK
=
thr_mma
.
partition_fragment_B
(
sK
);
// (MMA,MMA_N,MMA_K)
Tensor
tSrK
=
thr_mma
.
partition_fragment_B
(
sK
);
// (MMA,MMA_N,MMA_K)
Tensor
tOrVt
=
thr_mma
.
partition_fragment_B
(
sVtNoSwizzle
);
// (MMA, MMA_K,MMA_N)
Tensor
tOrVt
=
thr_mma
.
partition_fragment_B
(
sVtNoSwizzle
);
// (MMA, MMA_K,MMA_N)
#if 1
#if 1
KIN_PRINT
(
"tSrQ.layout()"
,
print
(
tSrQ
.
layout
()))
KIN_PRINT
(
print
(
tSrQ
.
layout
()))
KIN_PRINT
(
"tSrK.layout()"
,
print
(
tSrK
.
layout
()))
KIN_PRINT
(
print
(
tSrK
.
layout
()))
#endif
#endif
Tensor
acc_o
=
partition_fragment_C
(
tiled_mma
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDim
>>
{});
// MMA, MMA_M, MMA_K
Tensor
acc_o
=
partition_fragment_C
(
tiled_mma
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDim
>>
{});
// MMA, MMA_M, MMA_K
#if 1
#if 1
KIN_PRINT
(
"acc_o.layout()"
,
print
(
acc_o
.
layout
()))
KIN_PRINT
(
print
(
acc_o
.
layout
()))
#endif
#endif
//
//
...
@@ -693,14 +702,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
...
@@ -693,14 +702,14 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
auto
smem_thr_copy_Q
=
smem_tiled_copy_Q
.
get_thread_slice
(
tidx
);
auto
smem_thr_copy_Q
=
smem_tiled_copy_Q
.
get_thread_slice
(
tidx
);
Tensor
tSsQ
=
smem_thr_copy_Q
.
partition_S
(
sQ
);
Tensor
tSsQ
=
smem_thr_copy_Q
.
partition_S
(
sQ
);
#if 1
#if 1
KIN_PRINT
(
"tSsQ.layout()"
,
print
(
tSsQ
.
layout
()))
KIN_PRINT
(
print
(
tSsQ
.
layout
()))
#endif
#endif
auto
smem_tiled_copy_K
=
make_tiled_copy_B
(
typename
Kernel_traits
::
SmemCopyAtom
{},
tiled_mma
);
auto
smem_tiled_copy_K
=
make_tiled_copy_B
(
typename
Kernel_traits
::
SmemCopyAtom
{},
tiled_mma
);
auto
smem_thr_copy_K
=
smem_tiled_copy_K
.
get_thread_slice
(
tidx
);
auto
smem_thr_copy_K
=
smem_tiled_copy_K
.
get_thread_slice
(
tidx
);
Tensor
tSsK
=
smem_thr_copy_K
.
partition_S
(
sK
);
Tensor
tSsK
=
smem_thr_copy_K
.
partition_S
(
sK
);
#if 1
#if 1
KIN_PRINT
(
"tSsK.layout()"
,
print
(
tSsK
.
layout
()))
KIN_PRINT
(
print
(
tSsK
.
layout
()))
#endif
#endif
auto
smem_tiled_copy_V
=
make_tiled_copy_B
(
typename
Kernel_traits
::
SmemCopyAtomTransposed
{},
tiled_mma
);
auto
smem_tiled_copy_V
=
make_tiled_copy_B
(
typename
Kernel_traits
::
SmemCopyAtomTransposed
{},
tiled_mma
);
...
@@ -718,22 +727,22 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
...
@@ -718,22 +727,22 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor
cQ
=
make_identity_tensor
(
make_shape
(
size
<
0
>
(
sQ
),
size
<
1
>
(
sQ
)));
// (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor
cQ
=
make_identity_tensor
(
make_shape
(
size
<
0
>
(
sQ
),
size
<
1
>
(
sQ
)));
// (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor
cKV
=
make_identity_tensor
(
make_shape
(
size
<
0
>
(
sK
),
size
<
1
>
(
sK
)));
// (BLK_N,BLK_K) -> (blk_n,blk_k)
Tensor
cKV
=
make_identity_tensor
(
make_shape
(
size
<
0
>
(
sK
),
size
<
1
>
(
sK
)));
// (BLK_N,BLK_K) -> (blk_n,blk_k)
#if 1
#if 1
KIN_PRINT
(
"cQ.layout()"
,
print
(
cQ
.
layout
()))
KIN_PRINT
(
print
(
cQ
.
layout
()))
KIN_PRINT
(
"cKV.layout()"
,
print
(
cKV
.
layout
()))
KIN_PRINT
(
print
(
cKV
.
layout
()))
#endif
#endif
// Repeat the partitioning with identity layouts
// Repeat the partitioning with identity layouts
Tensor
tQcQ
=
gmem_thr_copy_Q
KV
.
partition_S
(
cQ
);
// (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
Tensor
tQcQ
=
gmem_thr_copy_Q
.
partition_S
(
cQ
);
// (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
Tensor
tKVcKV
=
gmem_thr_copy_
Q
KV
.
partition_S
(
cKV
);
// (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
Tensor
tKVcKV
=
gmem_thr_copy_KV
.
partition_S
(
cKV
);
// (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
// Allocate predicate tensors for k
// Allocate predicate tensors for k
Tensor
tQpQ
=
make_tensor
<
bool
>
(
make_shape
(
size
<
2
>
(
tQsQ
)));
Tensor
tQpQ
=
make_tensor
<
bool
>
(
make_shape
(
size
<
2
>
(
tQsQ
)));
Tensor
tKVpKV
=
make_tensor
<
bool
>
(
make_shape
(
size
<
2
>
(
tKsK
)));
Tensor
tKVpKV
=
make_tensor
<
bool
>
(
make_shape
(
size
<
2
>
(
tKsK
)));
#if 1
#if 1
KIN_PRINT
(
"tQcQ.layout()"
,
print
(
tQcQ
.
layout
()))
KIN_PRINT
(
print
(
tQcQ
.
layout
()))
KIN_PRINT
(
"tKVcKV.layout()"
,
print
(
tKVcKV
.
layout
()))
KIN_PRINT
(
print
(
tKVcKV
.
layout
()))
KIN_PRINT
(
"tQpQ.layout()"
,
print
(
tQpQ
.
layout
()))
KIN_PRINT
(
print
(
tQpQ
.
layout
()))
KIN_PRINT
(
"tKVpKV.layout()"
,
print
(
tKVpKV
.
layout
()))
KIN_PRINT
(
print
(
tKVpKV
.
layout
()))
#endif
#endif
// Set predicates for k bounds
// Set predicates for k bounds
...
@@ -792,8 +801,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
...
@@ -792,8 +801,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
+
row_offset_vnew
-
binfo
.
seqlen_k_cache
*
params
.
vnew_row_stride
),
+
row_offset_vnew
-
binfo
.
seqlen_k_cache
*
params
.
vnew_row_stride
),
Shape
<
Int
<
kBlockN
>
,
Int
<
kHeadDim
>>
{},
Shape
<
Int
<
kBlockN
>
,
Int
<
kHeadDim
>>
{},
make_stride
(
params
.
vnew_row_stride
,
_1
{}));
make_stride
(
params
.
vnew_row_stride
,
_1
{}));
Tensor
tKgKnew
=
gmem_thr_copy_
Q
KV
.
partition_S
(
gKnew
);
// (KCPY, KCPY_N, KCPY_K)
Tensor
tKgKnew
=
gmem_thr_copy_KV
.
partition_S
(
gKnew
);
// (KCPY, KCPY_N, KCPY_K)
Tensor
tVgVnew
=
gmem_thr_copy_
Q
KV
.
partition_S
(
gVnew
);
// (VCPY, VCPY_N, VCPY_K)
Tensor
tVgVnew
=
gmem_thr_copy_KV
.
partition_S
(
gVnew
);
// (VCPY, VCPY_N, VCPY_K)
const
int
n_block_copy_min
=
std
::
max
(
n_block_min
,
binfo
.
seqlen_k_cache
/
kBlockN
);
const
int
n_block_copy_min
=
std
::
max
(
n_block_min
,
binfo
.
seqlen_k_cache
/
kBlockN
);
auto
tKgK_data
=
tKgK
.
data
();
auto
tKgK_data
=
tKgK
.
data
();
...
@@ -853,7 +862,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
...
@@ -853,7 +862,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// Read Q from gmem to smem, optionally apply rotary embedding.
// Read Q from gmem to smem, optionally apply rotary embedding.
if
(
!
Append_KV
||
params
.
rotary_dim
==
0
)
{
if
(
!
Append_KV
||
params
.
rotary_dim
==
0
)
{
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
// We don't need to clear the sQ smem tiles since we'll only write out the valid outputs
flash
::
copy
<
Is_even_MN
,
Is_even_K
>
(
gmem_tiled_copy_Q
KV
,
tQgQ
,
tQsQ
,
tQcQ
,
tQpQ
,
flash
::
copy
<
Is_even_MN
,
Is_even_K
>
(
gmem_tiled_copy_Q
,
tQgQ
,
tQsQ
,
tQcQ
,
tQpQ
,
binfo
.
actual_seqlen_q
-
m_block
*
kBlockM
);
binfo
.
actual_seqlen_q
-
m_block
*
kBlockM
);
}
else
{
}
else
{
const
index_t
row_offset_cossin
=
(
binfo
.
seqlen_k_cache
+
(
Is_causal
||
Is_local
?
m_block
*
kBlockM
:
0
))
*
(
params
.
rotary_dim
/
2
);
const
index_t
row_offset_cossin
=
(
binfo
.
seqlen_k_cache
+
(
Is_causal
||
Is_local
?
m_block
*
kBlockM
:
0
))
*
(
params
.
rotary_dim
/
2
);
...
@@ -890,7 +899,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
...
@@ -890,7 +899,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
int
n_block
=
n_block_max
-
1
;
int
n_block
=
n_block_max
-
1
;
// We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
// We don't need to clear the sK smem tiles since we'll mask out the scores anyway.
flash
::
copy
<
Is_even_MN
,
Is_even_K
>
(
gmem_tiled_copy_
Q
KV
,
tKgK
,
tKsK
,
tKVcKV
,
tKVpKV
,
flash
::
copy
<
Is_even_MN
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tKgK
,
tKsK
,
tKVcKV
,
tKVpKV
,
binfo
.
actual_seqlen_k
-
n_block
*
kBlockN
);
binfo
.
actual_seqlen_k
-
n_block
*
kBlockN
);
cute
::
cp_async_fence
();
cute
::
cp_async_fence
();
...
@@ -935,11 +944,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
...
@@ -935,11 +944,11 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
const
int
block_table_offset_next
=
n_block
*
kBlockN
-
block_table_idx_next
*
params
.
page_block_size
;
const
int
block_table_offset_next
=
n_block
*
kBlockN
-
block_table_idx_next
*
params
.
page_block_size
;
tVgV
.
data
()
=
tVgV
.
data
()
+
(
block_table
[
block_table_idx_next
]
-
block_table
[
block_table_idx_cur
])
*
params
.
v_batch_stride
+
(
block_table_offset_next
-
block_table_offset_cur
)
*
params
.
v_row_stride
;
tVgV
.
data
()
=
tVgV
.
data
()
+
(
block_table
[
block_table_idx_next
]
-
block_table
[
block_table_idx_cur
])
*
params
.
v_batch_stride
+
(
block_table_offset_next
-
block_table_offset_cur
)
*
params
.
v_row_stride
;
}
}
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_
Q
KV
,
tVgV
,
tVsV
,
tKVcKV
,
tKVpKV
);
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tVgV
,
tVsV
,
tKVcKV
,
tKVpKV
);
}
else
{
}
else
{
// Clear the smem tiles to account for predicated off loads
// Clear the smem tiles to account for predicated off loads
flash
::
copy
<
Is_even_MN
,
Is_even_K
,
/*Clear_OOB_MN=*/
true
>
(
flash
::
copy
<
Is_even_MN
,
Is_even_K
,
/*Clear_OOB_MN=*/
true
>
(
gmem_tiled_copy_
Q
KV
,
tVgV
,
tVsV
,
tKVcKV
,
tKVpKV
,
binfo
.
actual_seqlen_k
-
n_block
*
kBlockN
gmem_tiled_copy_KV
,
tVgV
,
tVsV
,
tKVcKV
,
tKVpKV
,
binfo
.
actual_seqlen_k
-
n_block
*
kBlockN
);
);
}
}
cute
::
cp_async_fence
();
cute
::
cp_async_fence
();
...
@@ -970,7 +979,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
...
@@ -970,7 +979,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
const
int
block_table_offset_next
=
(
n_block
-
1
)
*
kBlockN
-
block_table_idx_next
*
params
.
page_block_size
;
const
int
block_table_offset_next
=
(
n_block
-
1
)
*
kBlockN
-
block_table_idx_next
*
params
.
page_block_size
;
tKgK
.
data
()
=
tKgK
.
data
()
+
(
block_table
[
block_table_idx_next
]
-
block_table
[
block_table_idx_cur
])
*
params
.
k_batch_stride
+
(
block_table_offset_next
-
block_table_offset_cur
)
*
params
.
k_row_stride
;
tKgK
.
data
()
=
tKgK
.
data
()
+
(
block_table
[
block_table_idx_next
]
-
block_table
[
block_table_idx_cur
])
*
params
.
k_batch_stride
+
(
block_table_offset_next
-
block_table_offset_cur
)
*
params
.
k_row_stride
;
}
}
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_
Q
KV
,
tKgK
,
tKsK
,
tKVcKV
,
tKVpKV
);
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tKgK
,
tKsK
,
tKVcKV
,
tKVpKV
);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
// isn't right and we get race conditions.
cute
::
cp_async_fence
();
cute
::
cp_async_fence
();
...
@@ -1013,7 +1022,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
...
@@ -1013,7 +1022,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
const
int
block_table_offset_next
=
n_block
*
kBlockN
-
block_table_idx_next
*
params
.
page_block_size
;
const
int
block_table_offset_next
=
n_block
*
kBlockN
-
block_table_idx_next
*
params
.
page_block_size
;
tVgV
.
data
()
=
tVgV
.
data
()
+
(
block_table
[
block_table_idx_next
]
-
block_table
[
block_table_idx_cur
])
*
params
.
v_batch_stride
+
(
block_table_offset_next
-
block_table_offset_cur
)
*
params
.
v_row_stride
;
tVgV
.
data
()
=
tVgV
.
data
()
+
(
block_table
[
block_table_idx_next
]
-
block_table
[
block_table_idx_cur
])
*
params
.
v_batch_stride
+
(
block_table_offset_next
-
block_table_offset_cur
)
*
params
.
v_row_stride
;
}
}
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_
Q
KV
,
tVgV
,
tVsV
,
tKVcKV
,
tKVpKV
);
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tVgV
,
tVsV
,
tKVcKV
,
tKVpKV
);
cute
::
cp_async_fence
();
cute
::
cp_async_fence
();
flash
::
gemm
(
flash
::
gemm
(
...
@@ -1034,7 +1043,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
...
@@ -1034,7 +1043,7 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
const
int
block_table_offset_next
=
(
n_block
-
1
)
*
kBlockN
-
block_table_idx_next
*
params
.
page_block_size
;
const
int
block_table_offset_next
=
(
n_block
-
1
)
*
kBlockN
-
block_table_idx_next
*
params
.
page_block_size
;
tKgK
.
data
()
=
tKgK
.
data
()
+
(
block_table
[
block_table_idx_next
]
-
block_table
[
block_table_idx_cur
])
*
params
.
k_batch_stride
+
(
block_table_offset_next
-
block_table_offset_cur
)
*
params
.
k_row_stride
;
tKgK
.
data
()
=
tKgK
.
data
()
+
(
block_table
[
block_table_idx_next
]
-
block_table
[
block_table_idx_cur
])
*
params
.
k_batch_stride
+
(
block_table_offset_next
-
block_table_offset_cur
)
*
params
.
k_row_stride
;
}
}
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_
Q
KV
,
tKgK
,
tKsK
,
tKVcKV
,
tKVpKV
);
flash
::
copy
<
/*Is_even_MN=*/
true
,
Is_even_K
>
(
gmem_tiled_copy_KV
,
tKgK
,
tKsK
,
tKVcKV
,
tKVpKV
);
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// This cp_async_fence needs to be in the if block, otherwise the synchronization
// isn't right and we get race conditions.
// isn't right and we get race conditions.
cute
::
cp_async_fence
();
cute
::
cp_async_fence
();
...
...
csrc/flash_attn/src/kernel_traits.h
View file @
14b190bc
...
@@ -131,6 +131,17 @@ struct Flash_fwd_kernel_traits : public Base {
...
@@ -131,6 +131,17 @@ struct Flash_fwd_kernel_traits : public Base {
make_tiled_copy
(
Copy_Atom
<
Gmem_copy_struct
,
Element
>
{},
make_tiled_copy
(
Copy_Atom
<
Gmem_copy_struct
,
Element
>
{},
GmemLayoutAtom
{},
GmemLayoutAtom
{},
Layout
<
Shape
<
_1
,
_8
>>
{}));
// Val layout, 8 vals per read
Layout
<
Shape
<
_1
,
_8
>>
{}));
// Val layout, 8 vals per read
// from how many rows does each thread have to fetch
static
constexpr
int
kGmemRowsPerThread
=
kBlockN
/
(
kNThreads
/
kGmemThreadsPerRow
);
// Here we assign a contiguous tile to each thread, rather than a 1x8 row every
// (kNThreads / kGmemThreadsPerRow) rows, ensuring that the elements assigned to each thread
// do not cross a page boundary. This way, each thread need only fetch 1 page index per
// mainloop iteration. Rudimentary testing shows no slowdown.
using
GmemTiledCopyQKVPaged
=
decltype
(
make_tiled_copy
(
Copy_Atom
<
Gmem_copy_struct
,
Element
>
{},
GmemLayoutAtom
{},
Layout
<
Shape
<
Int
<
kGmemRowsPerThread
>
,
_8
>
,
Stride
<
_8
,
_1
>>
{}));
using
GmemTiledCopyO
=
decltype
(
using
GmemTiledCopyO
=
decltype
(
make_tiled_copy
(
Copy_Atom
<
DefaultCopy
,
Element
>
{},
make_tiled_copy
(
Copy_Atom
<
DefaultCopy
,
Element
>
{},
GmemLayoutAtom
{},
GmemLayoutAtom
{},
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment