Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
3f2484ee
Commit
3f2484ee
authored
Feb 11, 2024
by
skrider
Committed by
Woosuk Kwon
Mar 28, 2024
Browse files
remove print statements
parent
e524c2ca
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
0 additions
and
164 deletions
+0
-164
csrc/flash_attn/src/debug.h
csrc/flash_attn/src/debug.h
+0
-141
csrc/flash_attn/src/flash_fwd_kernel.h
csrc/flash_attn/src/flash_fwd_kernel.h
+0
-21
csrc/flash_attn/src/utils.h
csrc/flash_attn/src/utils.h
+0
-2
No files found.
csrc/flash_attn/src/debug.h
deleted
100644 → 0
View file @
e524c2ca
#include <cute/util/debug.hpp>
#include "block_info.h"
#pragma once
#define KIN_PRINT(statement) \
if (thread0()) { \
printf("\n[kin:start:%s]\n", #statement); \
statement; \
printf("\n[kin:end:%s]\n", #statement); \
}
#define KIN_PRINT_BOOL(BOOL) \
if (thread0()) { \
printf("\n[kin:start:%s]\n", #BOOL); \
printf("%s", BOOL ? "true" : "false"); \
printf("\n[kin:end:%s]\n", #BOOL); \
}
__forceinline__
__device__
void
print_qkv_params
(
const
Qkv_params
&
params
)
{
// LLM generated
printf
(
"Qkv_params:
\n
"
);
printf
(
"q_ptr: %p
\n
"
,
params
.
q_ptr
);
printf
(
"k_ptr: %p
\n
"
,
params
.
k_ptr
);
printf
(
"v_ptr: %p
\n
"
,
params
.
v_ptr
);
printf
(
"q_batch_stride: %"
PRId64
"
\n
"
,
params
.
q_batch_stride
);
printf
(
"k_batch_stride: %"
PRId64
"
\n
"
,
params
.
k_batch_stride
);
printf
(
"v_batch_stride: %"
PRId64
"
\n
"
,
params
.
v_batch_stride
);
printf
(
"q_row_stride: %"
PRId64
"
\n
"
,
params
.
q_row_stride
);
printf
(
"k_row_stride: %"
PRId64
"
\n
"
,
params
.
k_row_stride
);
printf
(
"v_row_stride: %"
PRId64
"
\n
"
,
params
.
v_row_stride
);
printf
(
"q_head_stride: %"
PRId64
"
\n
"
,
params
.
q_head_stride
);
printf
(
"k_head_stride: %"
PRId64
"
\n
"
,
params
.
k_head_stride
);
printf
(
"v_head_stride: %"
PRId64
"
\n
"
,
params
.
v_head_stride
);
printf
(
"h: %d
\n
"
,
params
.
h
);
printf
(
"h_k: %d
\n
"
,
params
.
h_k
);
printf
(
"h_h_k_ratio: %d
\n
"
,
params
.
h_h_k_ratio
);
}
__forceinline__
__device__
void
print_flash_fwd_params
(
const
Flash_fwd_params
&
params
)
{
print_qkv_params
(
params
);
// LLM generated
printf
(
"struct Flash_fwd_params:
\n
"
);
printf
(
"o_ptr: %p
\n
"
,
params
.
o_ptr
);
printf
(
"oaccum_ptr: %p
\n
"
,
params
.
oaccum_ptr
);
printf
(
"o_batch_stride: %ld
\n
"
,
params
.
o_batch_stride
);
printf
(
"o_row_stride: %ld
\n
"
,
params
.
o_row_stride
);
printf
(
"o_head_stride: %ld
\n
"
,
params
.
o_head_stride
);
printf
(
"p_ptr: %p
\n
"
,
params
.
p_ptr
);
printf
(
"softmax_lse_ptr: %p
\n
"
,
params
.
softmax_lse_ptr
);
printf
(
"softmax_lseaccum_ptr: %p
\n
"
,
params
.
softmax_lseaccum_ptr
);
printf
(
"b: %d
\n
"
,
params
.
b
);
printf
(
"seqlen_q: %d
\n
"
,
params
.
seqlen_q
);
printf
(
"seqlen_k: %d
\n
"
,
params
.
seqlen_k
);
printf
(
"seqlen_knew: %d
\n
"
,
params
.
seqlen_knew
);
printf
(
"d: %d
\n
"
,
params
.
d
);
printf
(
"seqlen_q_rounded: %d
\n
"
,
params
.
seqlen_q_rounded
);
printf
(
"seqlen_k_rounded: %d
\n
"
,
params
.
seqlen_k_rounded
);
printf
(
"d_rounded: %d
\n
"
,
params
.
d_rounded
);
printf
(
"rotary_dim: %d
\n
"
,
params
.
rotary_dim
);
printf
(
"scale_softmax: %f
\n
"
,
params
.
scale_softmax
);
printf
(
"scale_softmax_log2: %f
\n
"
,
params
.
scale_softmax_log2
);
printf
(
"cu_seqlens_q: %p
\n
"
,
params
.
cu_seqlens_q
);
printf
(
"cu_seqlens_k: %p
\n
"
,
params
.
cu_seqlens_k
);
printf
(
"seqused_k: %p
\n
"
,
params
.
seqused_k
);
printf
(
"blockmask: %p
\n
"
,
params
.
blockmask
);
printf
(
"knew_ptr: %p
\n
"
,
params
.
knew_ptr
);
printf
(
"vnew_ptr: %p
\n
"
,
params
.
vnew_ptr
);
printf
(
"knew_batch_stride: %ld
\n
"
,
params
.
knew_batch_stride
);
printf
(
"vnew_batch_stride: %ld
\n
"
,
params
.
vnew_batch_stride
);
printf
(
"knew_row_stride: %ld
\n
"
,
params
.
knew_row_stride
);
printf
(
"vnew_row_stride: %ld
\n
"
,
params
.
vnew_row_stride
);
printf
(
"knew_head_stride: %ld
\n
"
,
params
.
knew_head_stride
);
printf
(
"vnew_head_stride: %ld
\n
"
,
params
.
vnew_head_stride
);
printf
(
"rotary_cos_ptr: %p
\n
"
,
params
.
rotary_cos_ptr
);
printf
(
"rotary_sin_ptr: %p
\n
"
,
params
.
rotary_sin_ptr
);
printf
(
"cache_batch_idx: %p
\n
"
,
params
.
cache_batch_idx
);
printf
(
"block_table: %p
\n
"
,
params
.
block_table
);
printf
(
"block_table_batch_stride: %ld
\n
"
,
params
.
block_table_batch_stride
);
printf
(
"page_block_size: %d
\n
"
,
params
.
page_block_size
);
printf
(
"p_dropout: %f
\n
"
,
params
.
p_dropout
);
printf
(
"p_dropout_in_uint8_t: %u
\n
"
,
params
.
p_dropout_in_uint8_t
);
printf
(
"rp_dropout: %f
\n
"
,
params
.
rp_dropout
);
printf
(
"scale_softmax_rp_dropout: %f
\n
"
,
params
.
scale_softmax_rp_dropout
);
printf
(
"window_size_left: %d
\n
"
,
params
.
window_size_left
);
printf
(
"window_size_right: %d
\n
"
,
params
.
window_size_right
);
printf
(
"philox_args: %p
\n
"
,
&
(
params
.
philox_args
));
printf
(
"rng_state: %p
\n
"
,
params
.
rng_state
);
printf
(
"is_bf16: %d
\n
"
,
params
.
is_bf16
);
printf
(
"is_causal: %d
\n
"
,
params
.
is_causal
);
printf
(
"is_seqlens_k_cumulative: %d
\n
"
,
params
.
is_seqlens_k_cumulative
);
printf
(
"is_rotary_interleaved: %d
\n
"
,
params
.
is_rotary_interleaved
);
printf
(
"num_splits: %d
\n
"
,
params
.
num_splits
);
printf
(
"alibi_slopes_ptr: %p
\n
"
,
params
.
alibi_slopes_ptr
);
printf
(
"alibi_slopes_batch_stride: %ld
\n
"
,
params
.
alibi_slopes_batch_stride
);
}
template
<
typename
Kernel_traits
>
__forceinline__
__device__
void
print_traits
()
{
// bool
printf
(
"Kernel_traits::Share_Q_K_smem : %s
\n
"
,
Kernel_traits
::
Share_Q_K_smem
?
"true"
:
"false"
);
printf
(
"Kernel_traits::Is_Q_in_regs : %s
\n
"
,
Kernel_traits
::
Is_Q_in_regs
?
"true"
:
"false"
);
// int
printf
(
"Kernel_traits::kNWarps : %d
\n
"
,
Kernel_traits
::
kNWarps
);
printf
(
"Kernel_traits::kNThreads : %d
\n
"
,
Kernel_traits
::
kNThreads
);
printf
(
"Kernel_traits::kBlockM : %d
\n
"
,
Kernel_traits
::
kBlockM
);
printf
(
"Kernel_traits::kBlockN : %d
\n
"
,
Kernel_traits
::
kBlockN
);
printf
(
"Kernel_traits::kHeadDim : %d
\n
"
,
Kernel_traits
::
kHeadDim
);
printf
(
"Kernel_traits::kBlockKSmem : %d
\n
"
,
Kernel_traits
::
kBlockKSmem
);
printf
(
"Kernel_traits::kBlockKGmem : %d
\n
"
,
Kernel_traits
::
kBlockKGmem
);
printf
(
"Kernel_traits::kSwizzle : %d
\n
"
,
Kernel_traits
::
kSwizzle
);
printf
(
"Kernel_traits::kSmemQSize : %d
\n
"
,
Kernel_traits
::
kSmemQSize
);
printf
(
"Kernel_traits::kSmemKVSize : %d
\n
"
,
Kernel_traits
::
kSmemKVSize
);
printf
(
"Kernel_traits::kSmemSize : %d
\n
"
,
Kernel_traits
::
kSmemSize
);
printf
(
"Kernel_traits::kGmemRowsPerThread: %d
\n
"
,
Kernel_traits
::
kGmemRowsPerThread
);
printf
(
"Kernel_traits::kGmemThreadsPerRow: %d
\n
"
,
Kernel_traits
::
kGmemThreadsPerRow
);
printf
(
"Kernel_traits::kGmemElemsPerLoad : %d
\n
"
,
Kernel_traits
::
kGmemElemsPerLoad
);
// cute object
printf
(
"Kernel_traits::GmemLayoutAtom : "
);
cute
::
print
(
Kernel_traits
::
GmemLayoutAtom
());
printf
(
"
\n
"
);
printf
(
"Kernel_traits::GmemTiledCopyQKV :
\n
"
);
cute
::
print
(
Kernel_traits
::
GmemTiledCopyQKV
());
printf
(
"
\n
"
);
}
template
<
typename
BlockInfo
>
__forceinline__
__device__
void
print_binfo
(
const
BlockInfo
&
binfo
)
{
printf
(
"binfo.sum_s_q : %d
\n
"
,
binfo
.
sum_s_q
);
printf
(
"binfo.sum_s_k : %d
\n
"
,
binfo
.
sum_s_k
);
printf
(
"binfo.actual_seqlen_q : %d
\n
"
,
binfo
.
actual_seqlen_q
);
printf
(
"binfo.seqlen_k_cache : %d
\n
"
,
binfo
.
seqlen_k_cache
);
printf
(
"binfo.actual_seqlen_k : %d
\n
"
,
binfo
.
actual_seqlen_k
);
}
csrc/flash_attn/src/flash_fwd_kernel.h
View file @
3f2484ee
...
...
@@ -18,8 +18,6 @@
#include "dropout.h"
#include "rotary.h"
#include "debug.h"
namespace
flash
{
using
namespace
cute
;
...
...
@@ -491,10 +489,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
constexpr
int
kBlockN
=
Kernel_traits
::
kBlockN
;
constexpr
int
kHeadDim
=
Kernel_traits
::
kHeadDim
;
constexpr
int
kNWarps
=
Kernel_traits
::
kNWarps
;
#if 1
KIN_PRINT
(
print_traits
<
Kernel_traits
>
())
KIN_PRINT
(
print_flash_fwd_params
(
params
))
#endif
using
GmemTiledCopyO
=
std
::
conditional_t
<
!
Split
,
...
...
@@ -619,21 +613,6 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params ¶ms, cons
tVgV
.
data
()
=
gV
.
data
()
+
flash
::
init_thread_kv_page_slice_offset
<
Kernel_traits
>
(
tidx
,
n_block_max
,
params
.
page_block_size
,
block_table
,
params
.
v_batch_stride
,
params
.
v_row_stride
);
}
#if 1
// KIN_PRINT([&]() {
// for (int i = 0; i < n_block_max; i++) {
// printf("%d ", block_table[i]);
// }
// }())
// if (tidx == 8) fill(tKgK, 1.f * tidx);
// if (thread0()) {
// gK.data() = tKgK.data();
// }
// KIN_PRINT(print_tensor(tKgK))
// KIN_PRINT(print_tensor(gK))
// KIN_PRINT(print_tensor(tKgK__shadow))
// KIN_PRINT(print_tensor(gK__shadow))
#endif
typename
Kernel_traits
::
TiledMma
tiled_mma
;
auto
thr_mma
=
tiled_mma
.
get_thread_slice
(
tidx
);
...
...
csrc/flash_attn/src/utils.h
View file @
3f2484ee
...
...
@@ -4,8 +4,6 @@
#pragma once
#include "debug.h"
#include <assert.h>
#include <stdint.h>
#include <stdlib.h>
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment