Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
8efeb7f5
Commit
8efeb7f5
authored
Feb 08, 2024
by
skrider
Browse files
add print statements for debugging
parent
36587c01
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
103 additions
and
0 deletions
+103
-0
csrc/flash_attn/src/debug.h
csrc/flash_attn/src/debug.h
+39
-0
csrc/flash_attn/src/flash_fwd_kernel.h
csrc/flash_attn/src/flash_fwd_kernel.h
+64
-0
No files found.
csrc/flash_attn/src/debug.h
0 → 100644
View file @
8efeb7f5
#include <cute/util/debug.hpp>
#define KIN_PRINT(tag, statement) \
if (cute::thread0()) { \
printf("[kin:start:%s]\n", tag); \
statement; \
printf("\n[kin:end:%s]\n", tag); \
}
template
<
typename
Kernel_traits
>
void
print_traits
()
{
// bool
printf
(
"Kernel_traits::Share_Q_K_smem : %s
\n
"
,
Kernel_traits
::
Share_Q_K_smem
);
printf
(
"Kernel_traits::Is_Q_in_regs : %s
\n
"
,
Kernel_traits
::
Is_Q_in_regs
);
// int
printf
(
"Kernel_traits::kNWarps : %s
\n
"
,
Kernel_traits
::
kNWarps
);
printf
(
"Kernel_traits::kNThreads : %s
\n
"
,
Kernel_traits
::
kNThreads
);
printf
(
"Kernel_traits::kBlockM : %s
\n
"
,
Kernel_traits
::
kBlockM
);
printf
(
"Kernel_traits::kBlockN : %s
\n
"
,
Kernel_traits
::
kBlockN
);
printf
(
"Kernel_traits::kHeadDim : %s
\n
"
,
Kernel_traits
::
kHeadDim
);
printf
(
"Kernel_traits::kBlockKSmem : %s
\n
"
,
Kernel_traits
::
kBlockKSmem
);
printf
(
"Kernel_traits::kBlockKGmem : %s
\n
"
,
Kernel_traits
::
kBlockKGmem
);
printf
(
"Kernel_traits::kSwizzle : %s
\n
"
,
Kernel_traits
::
kSwizzle
);
printf
(
"Kernel_traits::kSmemQSize : %s
\n
"
,
Kernel_traits
::
kSmemQSize
);
printf
(
"Kernel_traits::kSmemKVSize : %s
\n
"
,
Kernel_traits
::
kSmemKVSize
);
printf
(
"Kernel_traits::kSmemSize : %s
\n
"
,
Kernel_traits
::
kSmemSize
);
printf
(
"Kernel_traits::kGmemElemsPerLoad : %s
\n
"
,
Kernel_traits
::
kGmemElemsPerLoad
);
// cute object
printf
(
"Kernel_traits::GmemLayoutAtom : "
);
print
(
Kernel_traits
::
GmemLayoutAtom
);
printf
(
"
\n
"
);
printf
(
"Kernel_traits::GmemTiledCopyQKV : "
);
print
(
Kernel_traits
::
GmemTiledCopyQKV
);
printf
(
"
\n
"
);
printf
(
"Kernel_traits::GmemTiledCopyO : "
);
print
(
Kernel_traits
::
GmemTiledCopyO
);
printf
(
"
\n
"
);
printf
(
"Kernel_traits::SmemCopyAtom : "
);
print
(
Kernel_traits
::
SmemCopyAtom
);
printf
(
"
\n
"
);
printf
(
"Kernel_traits::SmemCopyAtomTransposed : "
);
print
(
Kernel_traits
::
SmemCopyAtomTransposed
);
printf
(
"
\n
"
);
printf
(
"Kernel_traits::MMA_Atom_Arch : "
);
print
(
Kernel_traits
::
MMA_Atom_Arch
);
printf
(
"
\n
"
);
}
csrc/flash_attn/src/flash_fwd_kernel.h
View file @
8efeb7f5
...
...
@@ -18,6 +18,8 @@
#include "dropout.h"
#include "rotary.h"
#include "debug.h"
namespace
flash
{
using
namespace
cute
;
...
...
@@ -41,6 +43,9 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
constexpr
int
kBlockN
=
Kernel_traits
::
kBlockN
;
constexpr
int
kHeadDim
=
Kernel_traits
::
kHeadDim
;
constexpr
int
kNWarps
=
Kernel_traits
::
kNWarps
;
#if 0
KIN_PRINT("Kernel_traits", print_traits<Kernel_traits>());
#endif
auto
seed_offset
=
at
::
cuda
::
philox
::
unpack
(
params
.
philox_args
);
flash
::
Dropout
dropout
(
std
::
get
<
0
>
(
seed_offset
),
std
::
get
<
1
>
(
seed_offset
),
params
.
p_dropout_in_uint8_t
,
...
...
@@ -55,6 +60,18 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
const
BlockInfo
<
/*Varlen=*/
!
Is_even_MN
>
binfo
(
params
,
bidb
);
if
(
m_block
*
kBlockM
>=
binfo
.
actual_seqlen_q
)
return
;
#if 0
// const int sum_s_q;
// const int sum_s_k;
// const int actual_seqlen_q;
// const int seqlen_k_cache;
// const int actual_seqlen_k;
KIN_PRINT("binfo.sum_s_q", printf("%d", binfo.sum_s_q))
KIN_PRINT("binfo.sum_s_k", printf("%d", binfo.sum_s_k))
KIN_PRINT("binfo.actual_seqlen_q", printf("%d", binfo.actual_seqlen_q))
KIN_PRINT("binfo.seqlen_k_cache", printf("%d", binfo.seqlen_k_cache))
KIN_PRINT("binfo.actual_seqlen_k", printf("%d", binfo.actual_seqlen_k))
#endif
const
int
n_block_min
=
!
Is_local
?
0
:
std
::
max
(
0
,
(
m_block
*
kBlockM
+
binfo
.
actual_seqlen_k
-
binfo
.
actual_seqlen_q
-
params
.
window_size_left
)
/
kBlockN
);
int
n_block_max
=
cute
::
ceil_div
(
binfo
.
actual_seqlen_k
,
kBlockN
);
...
...
@@ -136,10 +153,24 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
// Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem;
Tensor
sK
=
make_tensor
(
sQ
.
data
()
+
(
Kernel_traits
::
Share_Q_K_smem
?
0
:
size
(
sQ
)),
typename
Kernel_traits
::
SmemLayoutKV
{});
#if 1
KIN_PRINT
(
"sK.layout()"
,
print
(
sK
.
layout
()))
KIN_PRINT
(
"gK.layout()"
,
print
(
gK
.
layout
()))
KIN_PRINT
(
"Share_Q_K_smem"
,
printf
(
"%d"
,
Kernel_traits
::
Share_Q_K_smem
))
#endif
Tensor
sV
=
make_tensor
(
sK
.
data
()
+
size
(
sK
),
typename
Kernel_traits
::
SmemLayoutKV
{});
Tensor
sVt
=
make_tensor
(
sV
.
data
(),
typename
Kernel_traits
::
SmemLayoutVtransposed
{});
Tensor
sVtNoSwizzle
=
make_tensor
(
sV
.
data
(),
typename
Kernel_traits
::
SmemLayoutVtransposedNoSwizzle
{});
#if 1
KIN_PRINT
(
"sV.layout()"
,
print
(
sV
.
layout
()))
KIN_PRINT
(
"sVt.layout()"
,
print
(
sVt
.
layout
()))
KIN_PRINT
(
"sVtNoSwizzle.layout()"
,
print
(
sVtNoSwizzle
.
layout
()))
KIN_PRINT
(
"Share_Q_K_smem"
,
printf
(
"%d"
,
Kernel_traits
::
Share_Q_K_smem
))
#endif
typename
Kernel_traits
::
GmemTiledCopyQKV
gmem_tiled_copy_QKV
;
auto
gmem_thr_copy_QKV
=
gmem_tiled_copy_QKV
.
get_thread_slice
(
tidx
);
...
...
@@ -150,16 +181,30 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
Tensor
tVgV
=
gmem_thr_copy_QKV
.
partition_S
(
gV
);
// (VCPY, VCPY_N, VCPY_K)
Tensor
tVsV
=
gmem_thr_copy_QKV
.
partition_D
(
sV
);
#if 1
KIN_PRINT
(
"tKgK.layout()"
,
print
(
tKgK
.
layout
()))
KIN_PRINT
(
"tKsK.layout()"
,
print
(
tKsK
.
layout
()))
#endif
typename
Kernel_traits
::
TiledMma
tiled_mma
;
auto
thr_mma
=
tiled_mma
.
get_thread_slice
(
tidx
);
Tensor
tSrQ
=
thr_mma
.
partition_fragment_A
(
sQ
);
// (MMA,MMA_M,MMA_K)
Tensor
tSrK
=
thr_mma
.
partition_fragment_B
(
sK
);
// (MMA,MMA_N,MMA_K)
Tensor
tOrVt
=
thr_mma
.
partition_fragment_B
(
sVtNoSwizzle
);
// (MMA, MMA_K,MMA_N)
#if 1
KIN_PRINT
(
"tSrQ.layout()"
,
print
(
tSrQ
.
layout
()))
KIN_PRINT
(
"tSrK.layout()"
,
print
(
tSrK
.
layout
()))
#endif
Tensor
tSgS
=
thr_mma
.
partition_C
(
gP
);
Tensor
acc_o
=
partition_fragment_C
(
tiled_mma
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDim
>>
{});
// MMA, MMA_M, MMA_K
#if 1
KIN_PRINT
(
"acc_o.layout()"
,
print
(
acc_o
.
layout
()))
#endif
//
// Copy Atom retiling
//
...
...
@@ -168,11 +213,18 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
auto
smem_thr_copy_Q
=
smem_tiled_copy_Q
.
get_thread_slice
(
tidx
);
// if (cute::thread0()) {smem_thr_copy_Q.print_all();}
Tensor
tSsQ
=
smem_thr_copy_Q
.
partition_S
(
sQ
);
#if 1
KIN_PRINT
(
"smem_thr_copy_Q.print_all()"
,
smem_thr_copy_Q
.
print_all
())
KIN_PRINT
(
"tSsQ.layout()"
,
print
(
tSsQ
.
layout
()))
#endif
// if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}
auto
smem_tiled_copy_K
=
make_tiled_copy_B
(
typename
Kernel_traits
::
SmemCopyAtom
{},
tiled_mma
);
auto
smem_thr_copy_K
=
smem_tiled_copy_K
.
get_thread_slice
(
tidx
);
Tensor
tSsK
=
smem_thr_copy_K
.
partition_S
(
sK
);
# if 1
KIN_PRINT
(
"tSsK.layout()"
,
print
(
tSsK
.
layout
()))
#endif
auto
smem_tiled_copy_V
=
make_tiled_copy_B
(
typename
Kernel_traits
::
SmemCopyAtomTransposed
{},
tiled_mma
);
auto
smem_thr_copy_V
=
smem_tiled_copy_V
.
get_thread_slice
(
tidx
);
...
...
@@ -189,6 +241,10 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
// Construct identity layout for sQ and sK
Tensor
cQ
=
make_identity_tensor
(
make_shape
(
size
<
0
>
(
sQ
),
size
<
1
>
(
sQ
)));
// (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor
cKV
=
make_identity_tensor
(
make_shape
(
size
<
0
>
(
sK
),
size
<
1
>
(
sK
)));
// (BLK_N,BLK_K) -> (blk_n,blk_k)
#if 1
KIN_PRINT
(
"cQ.layout()"
,
print
(
cQ
.
layout
()))
KIN_PRINT
(
"cKV.layout()"
,
print
(
cKV
.
layout
()))
#endif
// Tensor tScQ = thr_mma.partition_A(cQ); // (MMA,MMA_M,MMA_K)
// if (cute::thread0()) {
// print(tScQ.layout()); printf("\n");
...
...
@@ -205,10 +261,18 @@ inline __device__ void compute_attn_1rowblock(const Params ¶ms, const int bi
// Repeat the partitioning with identity layouts
Tensor
tQcQ
=
gmem_thr_copy_QKV
.
partition_S
(
cQ
);
// (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
Tensor
tKVcKV
=
gmem_thr_copy_QKV
.
partition_S
(
cKV
);
// (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
#if 1
KIN_PRINT
(
"tQcQ.layout()"
,
print
(
tQcQ
.
layout
()))
KIN_PRINT
(
"tKVcKV.layout()"
,
print
(
tKVcKV
.
layout
()))
#endif
// Allocate predicate tensors for k
Tensor
tQpQ
=
make_tensor
<
bool
>
(
make_shape
(
size
<
2
>
(
tQsQ
)));
Tensor
tKVpKV
=
make_tensor
<
bool
>
(
make_shape
(
size
<
2
>
(
tKsK
)));
#if 1
KIN_PRINT
(
"tQpQ.layout()"
,
print
(
tQpQ
.
layout
()))
KIN_PRINT
(
"tKVpKV.layout()"
,
print
(
tKVpKV
.
layout
()))
#endif
// Set predicates for k bounds
if
(
!
Is_even_K
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment