Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
ac5e78a6
Commit
ac5e78a6
authored
Feb 09, 2024
by
skrider
Browse files
add print statements for debugging
parent
8efeb7f5
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
92 additions
and
51 deletions
+92
-51
csrc/flash_attn/src/debug.h
csrc/flash_attn/src/debug.h
+36
-25
csrc/flash_attn/src/flash_fwd_kernel.h
csrc/flash_attn/src/flash_fwd_kernel.h
+56
-26
No files found.
csrc/flash_attn/src/debug.h
View file @
ac5e78a6
#pragma once
// Debug-print helpers for the flash-attention kernels.
// NOTE: as a header, the include guard (#pragma once) must precede everything,
// including the #includes — the original had it after them.

#include <cute/util/debug.hpp>
#include "block_info.h"

// Run `statement` on thread 0 only, bracketed by greppable
// [kin:start:<tag>] / [kin:end:<tag>] markers so interleaved kernel output can
// be separated per tag. Deliberately a bare if-block (no do/while wrapper):
// call sites in this commit invoke the macro without a trailing semicolon.
#define KIN_PRINT(tag, statement) \
    if (cute::thread0()) { \
        printf("\n[kin:start:%s]\n", tag); \
        statement; \
        printf("\n[kin:end:%s]\n", tag); \
    }

// Same tagging scheme, but prints a boolean as "true"/"false".
// Uses the qualified cute::thread0() for consistency with KIN_PRINT, and
// parenthesizes BOOL so expression arguments expand safely.
#define KIN_PRINT_BOOL(tag, BOOL) \
    if (cute::thread0()) { \
        printf("\n[kin:start:%s]\n", tag); \
        printf("%s", (BOOL) ? "true" : "false"); \
        printf("\n[kin:end:%s]\n", tag); \
    }
template
<
typename
Kernel_traits
>
void
__forceinline__
__device__
void
print_traits
()
{
// bool
printf
(
"Kernel_traits::Share_Q_K_smem : %s
\n
"
,
Kernel_traits
::
Share_Q_K_smem
);
printf
(
"Kernel_traits::Is_Q_in_regs : %s
\n
"
,
Kernel_traits
::
Is_Q_in_regs
);
printf
(
"Kernel_traits::Share_Q_K_smem
: %s
\n
"
,
Kernel_traits
::
Share_Q_K_smem
?
"true"
:
"false"
);
printf
(
"Kernel_traits::Is_Q_in_regs
: %s
\n
"
,
Kernel_traits
::
Is_Q_in_regs
?
"true"
:
"false"
);
// int
printf
(
"Kernel_traits::kNWarps : %s
\n
"
,
Kernel_traits
::
kNWarps
);
printf
(
"Kernel_traits::kNThreads : %s
\n
"
,
Kernel_traits
::
kNThreads
);
printf
(
"Kernel_traits::kBlockM : %s
\n
"
,
Kernel_traits
::
kBlockM
);
printf
(
"Kernel_traits::kBlockN : %s
\n
"
,
Kernel_traits
::
kBlockN
);
printf
(
"Kernel_traits::kHeadDim : %s
\n
"
,
Kernel_traits
::
kHeadDim
);
printf
(
"Kernel_traits::kBlockKSmem : %s
\n
"
,
Kernel_traits
::
kBlockKSmem
);
printf
(
"Kernel_traits::kBlockKGmem : %s
\n
"
,
Kernel_traits
::
kBlockKGmem
);
printf
(
"Kernel_traits::kSwizzle : %s
\n
"
,
Kernel_traits
::
kSwizzle
);
printf
(
"Kernel_traits::kSmemQSize : %s
\n
"
,
Kernel_traits
::
kSmemQSize
);
printf
(
"Kernel_traits::kSmemKVSize : %s
\n
"
,
Kernel_traits
::
kSmemKVSize
);
printf
(
"Kernel_traits::kSmemSize : %s
\n
"
,
Kernel_traits
::
kSmemSize
);
printf
(
"Kernel_traits::kGmemElemsPerLoad : %s
\n
"
,
Kernel_traits
::
kGmemElemsPerLoad
);
// cute object
printf
(
"Kernel_traits::GmemLayoutAtom : "
);
print
(
Kernel_traits
::
GmemLayoutAtom
);
printf
(
"
\n
"
);
printf
(
"Kernel_traits::GmemTiledCopyQKV : "
);
print
(
Kernel_traits
::
GmemTiledCopyQKV
);
printf
(
"
\n
"
);
printf
(
"Kernel_traits::GmemTiledCopyO : "
);
print
(
Kernel_traits
::
GmemTiledCopyO
);
printf
(
"
\n
"
);
printf
(
"Kernel_traits::SmemCopyAtom : "
);
print
(
Kernel_traits
::
SmemCopyAtom
);
printf
(
"
\n
"
);
printf
(
"Kernel_traits::SmemCopyAtomTransposed : "
);
print
(
Kernel_traits
::
SmemCopyAtomTransposed
);
printf
(
"
\n
"
);
printf
(
"Kernel_traits::MMA_Atom_Arch : "
);
print
(
Kernel_traits
::
MMA_Atom_Arch
);
printf
(
"
\n
"
);
printf
(
"Kernel_traits::kNWarps : %d
\n
"
,
Kernel_traits
::
kNWarps
);
printf
(
"Kernel_traits::kNThreads : %d
\n
"
,
Kernel_traits
::
kNThreads
);
printf
(
"Kernel_traits::kBlockM : %d
\n
"
,
Kernel_traits
::
kBlockM
);
printf
(
"Kernel_traits::kBlockN : %d
\n
"
,
Kernel_traits
::
kBlockN
);
printf
(
"Kernel_traits::kHeadDim : %d
\n
"
,
Kernel_traits
::
kHeadDim
);
printf
(
"Kernel_traits::kBlockKSmem : %d
\n
"
,
Kernel_traits
::
kBlockKSmem
);
printf
(
"Kernel_traits::kBlockKGmem : %d
\n
"
,
Kernel_traits
::
kBlockKGmem
);
printf
(
"Kernel_traits::kSwizzle : %d
\n
"
,
Kernel_traits
::
kSwizzle
);
printf
(
"Kernel_traits::kSmemQSize : %d
\n
"
,
Kernel_traits
::
kSmemQSize
);
printf
(
"Kernel_traits::kSmemKVSize : %d
\n
"
,
Kernel_traits
::
kSmemKVSize
);
printf
(
"Kernel_traits::kSmemSize : %d
\n
"
,
Kernel_traits
::
kSmemSize
);
printf
(
"Kernel_traits::kGmemElemsPerLoad : %d
\n
"
,
Kernel_traits
::
kGmemElemsPerLoad
);
}
// Dumps the per-block sequence-length bookkeeping carried by a BlockInfo
// instance, one "field : value" line each, for device-side debugging.
// Intended to be wrapped in KIN_PRINT so only thread 0 emits output.
template <typename BlockInfo>
__forceinline__ __device__ void print_binfo(const BlockInfo &binfo) {
    printf("binfo.sum_s_q : %d\n", binfo.sum_s_q);
    printf("binfo.sum_s_k : %d\n", binfo.sum_s_k);
    printf("binfo.actual_seqlen_q : %d\n", binfo.actual_seqlen_q);
    printf("binfo.seqlen_k_cache : %d\n", binfo.seqlen_k_cache);
    printf("binfo.actual_seqlen_k : %d\n", binfo.actual_seqlen_k);
}
csrc/flash_attn/src/flash_fwd_kernel.h
View file @
ac5e78a6
...
...
@@ -43,7 +43,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
constexpr
int
kBlockN
=
Kernel_traits
::
kBlockN
;
constexpr
int
kHeadDim
=
Kernel_traits
::
kHeadDim
;
constexpr
int
kNWarps
=
Kernel_traits
::
kNWarps
;
#if
0
#if
1
KIN_PRINT
(
"Kernel_traits"
,
print_traits
<
Kernel_traits
>
());
#endif
...
...
@@ -60,17 +60,8 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
const
BlockInfo
<
/*Varlen=*/
!
Is_even_MN
>
binfo
(
params
,
bidb
);
if
(
m_block
*
kBlockM
>=
binfo
.
actual_seqlen_q
)
return
;
#if 0
// const int sum_s_q;
// const int sum_s_k;
// const int actual_seqlen_q;
// const int seqlen_k_cache;
// const int actual_seqlen_k;
KIN_PRINT("binfo.sum_s_q", printf("%d", binfo.sum_s_q))
KIN_PRINT("binfo.sum_s_k", printf("%d", binfo.sum_s_k))
KIN_PRINT("binfo.actual_seqlen_q", printf("%d", binfo.actual_seqlen_q))
KIN_PRINT("binfo.seqlen_k_cache", printf("%d", binfo.seqlen_k_cache))
KIN_PRINT("binfo.actual_seqlen_k", printf("%d", binfo.actual_seqlen_k))
#if 1
KIN_PRINT
(
"binfo"
,
print_binfo
(
binfo
))
#endif
const
int
n_block_min
=
!
Is_local
?
0
:
std
::
max
(
0
,
(
m_block
*
kBlockM
+
binfo
.
actual_seqlen_k
-
binfo
.
actual_seqlen_q
-
params
.
window_size_left
)
/
kBlockN
);
...
...
@@ -153,22 +144,18 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Careful we're using the same smem for sQ and sK | sV if Share_Q_K_smem;
Tensor
sK
=
make_tensor
(
sQ
.
data
()
+
(
Kernel_traits
::
Share_Q_K_smem
?
0
:
size
(
sQ
)),
typename
Kernel_traits
::
SmemLayoutKV
{});
#if 1
KIN_PRINT
(
"sK.layout()"
,
print
(
sK
.
layout
()))
KIN_PRINT
(
"gK.layout()"
,
print
(
gK
.
layout
()))
KIN_PRINT
(
"Share_Q_K_smem"
,
printf
(
"%d"
,
Kernel_traits
::
Share_Q_K_smem
))
#endif
Tensor
sV
=
make_tensor
(
sK
.
data
()
+
size
(
sK
),
typename
Kernel_traits
::
SmemLayoutKV
{});
Tensor
sVt
=
make_tensor
(
sV
.
data
(),
typename
Kernel_traits
::
SmemLayoutVtransposed
{});
Tensor
sVtNoSwizzle
=
make_tensor
(
sV
.
data
(),
typename
Kernel_traits
::
SmemLayoutVtransposedNoSwizzle
{});
#if 1
KIN_PRINT
(
"sV.layout()"
,
print
(
sV
.
layout
()))
KIN_PRINT
(
"sVt.layout()"
,
print
(
sVt
.
layout
()))
KIN_PRINT
(
"sVtNoSwizzle.layout()"
,
print
(
sVtNoSwizzle
.
layout
()))
KIN_PRINT
(
"Share_Q_K_smem"
,
printf
(
"%d"
,
Kernel_traits
::
Share_Q_K_smem
))
#endif
typename
Kernel_traits
::
GmemTiledCopyQKV
gmem_tiled_copy_QKV
;
...
...
@@ -180,7 +167,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor
tKsK
=
gmem_thr_copy_QKV
.
partition_D
(
sK
);
Tensor
tVgV
=
gmem_thr_copy_QKV
.
partition_S
(
gV
);
// (VCPY, VCPY_N, VCPY_K)
Tensor
tVsV
=
gmem_thr_copy_QKV
.
partition_D
(
sV
);
#if 1
KIN_PRINT
(
"tKgK.layout()"
,
print
(
tKgK
.
layout
()))
KIN_PRINT
(
"tKsK.layout()"
,
print
(
tKsK
.
layout
()))
...
...
@@ -191,7 +177,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor
tSrQ
=
thr_mma
.
partition_fragment_A
(
sQ
);
// (MMA,MMA_M,MMA_K)
Tensor
tSrK
=
thr_mma
.
partition_fragment_B
(
sK
);
// (MMA,MMA_N,MMA_K)
Tensor
tOrVt
=
thr_mma
.
partition_fragment_B
(
sVtNoSwizzle
);
// (MMA, MMA_K,MMA_N)
#if 1
KIN_PRINT
(
"tSrQ.layout()"
,
print
(
tSrQ
.
layout
()))
KIN_PRINT
(
"tSrK.layout()"
,
print
(
tSrK
.
layout
()))
...
...
@@ -200,7 +185,6 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
Tensor
tSgS
=
thr_mma
.
partition_C
(
gP
);
Tensor
acc_o
=
partition_fragment_C
(
tiled_mma
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDim
>>
{});
// MMA, MMA_M, MMA_K
#if 1
KIN_PRINT
(
"acc_o.layout()"
,
print
(
acc_o
.
layout
()))
#endif
...
...
@@ -211,10 +195,12 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
auto
smem_tiled_copy_Q
=
make_tiled_copy_A
(
typename
Kernel_traits
::
SmemCopyAtom
{},
tiled_mma
);
auto
smem_thr_copy_Q
=
smem_tiled_copy_Q
.
get_thread_slice
(
tidx
);
#if 0
KIN_PRINT("fail", smem_thr_copy_Q.print_all());
#endif
// if (cute::thread0()) {smem_thr_copy_Q.print_all();}
Tensor
tSsQ
=
smem_thr_copy_Q
.
partition_S
(
sQ
);
#if 1
KIN_PRINT
(
"smem_thr_copy_Q.print_all()"
,
smem_thr_copy_Q
.
print_all
())
KIN_PRINT
(
"tSsQ.layout()"
,
print
(
tSsQ
.
layout
()))
#endif
// if (cute::thread0()) {print(tSsQ.layout()); printf("\n");}
...
...
@@ -222,7 +208,7 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
auto
smem_tiled_copy_K
=
make_tiled_copy_B
(
typename
Kernel_traits
::
SmemCopyAtom
{},
tiled_mma
);
auto
smem_thr_copy_K
=
smem_tiled_copy_K
.
get_thread_slice
(
tidx
);
Tensor
tSsK
=
smem_thr_copy_K
.
partition_S
(
sK
);
#
if 1
#if 1
KIN_PRINT
(
"tSsK.layout()"
,
print
(
tSsK
.
layout
()))
#endif
...
...
@@ -261,15 +247,13 @@ inline __device__ void compute_attn_1rowblock(const Params &params, const int bi
// Repeat the partitioning with identity layouts
Tensor
tQcQ
=
gmem_thr_copy_QKV
.
partition_S
(
cQ
);
// (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
Tensor
tKVcKV
=
gmem_thr_copy_QKV
.
partition_S
(
cKV
);
// (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
#if 1
KIN_PRINT
(
"tQcQ.layout()"
,
print
(
tQcQ
.
layout
()))
KIN_PRINT
(
"tKVcKV.layout()"
,
print
(
tKVcKV
.
layout
()))
#endif
// Allocate predicate tensors for k
Tensor
tQpQ
=
make_tensor
<
bool
>
(
make_shape
(
size
<
2
>
(
tQsQ
)));
Tensor
tKVpKV
=
make_tensor
<
bool
>
(
make_shape
(
size
<
2
>
(
tKsK
)));
#if 1
KIN_PRINT
(
"tQcQ.layout()"
,
print
(
tQcQ
.
layout
()))
KIN_PRINT
(
"tKVcKV.layout()"
,
print
(
tKVcKV
.
layout
()))
KIN_PRINT
(
"tQpQ.layout()"
,
print
(
tQpQ
.
layout
()))
KIN_PRINT
(
"tKVpKV.layout()"
,
print
(
tKVpKV
.
layout
()))
#endif
...
...
@@ -552,6 +536,16 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
constexpr
int
kBlockN
=
Kernel_traits
::
kBlockN
;
constexpr
int
kHeadDim
=
Kernel_traits
::
kHeadDim
;
constexpr
int
kNWarps
=
Kernel_traits
::
kNWarps
;
#if 1
KIN_PRINT
(
"Kernel_traits"
,
print_traits
<
Kernel_traits
>
())
KIN_PRINT_BOOL
(
"Is_causal"
,
Is_causal
)
KIN_PRINT_BOOL
(
"Is_local"
,
Is_local
)
KIN_PRINT_BOOL
(
"Has_alibi"
,
Has_alibi
)
KIN_PRINT_BOOL
(
"Is_even_MN"
,
Is_even_MN
)
KIN_PRINT_BOOL
(
"Is_even_K"
,
Is_even_K
)
KIN_PRINT_BOOL
(
"Split"
,
Split
)
KIN_PRINT_BOOL
(
"Append_KV"
,
Append_KV
)
#endif
using
GmemTiledCopyO
=
std
::
conditional_t
<
!
Split
,
...
...
@@ -564,6 +558,9 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// if (threadIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0) { printf("Is_even_MN = %d, is_cumulativ = %d, seqlen_k_cache = %d, actual_seqlen_k = %d\n", Is_even_MN, params.is_seqlens_k_cumulative, binfo.seqlen_k_cache, binfo.actual_seqlen_k); }
// if (threadIdx.x == 0 && blockIdx.y == 1 && blockIdx.z == 0) { printf("params.knew_ptr = %p, seqlen_k_cache + seqlen_knew = %d\n", params.knew_ptr, binfo.seqlen_k_cache + (params.knew_ptr == nullptr ? 0 : params.seqlen_knew)); }
if
(
m_block
*
kBlockM
>=
binfo
.
actual_seqlen_q
)
return
;
#if 1
KIN_PRINT
(
"binfo"
,
print_binfo
(
binfo
))
#endif
const
int
n_blocks_per_split
=
((
params
.
seqlen_k
+
kBlockN
-
1
)
/
kBlockN
+
num_n_splits
-
1
)
/
num_n_splits
;
const
int
n_block_min
=
!
Is_local
...
...
@@ -645,13 +642,19 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor
gV
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
v_ptr
)
+
row_offset_v
),
Shape
<
Int
<
kBlockN
>
,
Int
<
kHeadDim
>>
{},
make_stride
(
params
.
v_row_stride
,
_1
{}));
Tensor
sQ
=
make_tensor
(
make_smem_ptr
(
reinterpret_cast
<
Element
*>
(
smem_
)),
typename
Kernel_traits
::
SmemLayoutQ
{});
Tensor
sK
=
make_tensor
(
sQ
.
data
()
+
size
(
sQ
),
typename
Kernel_traits
::
SmemLayoutKV
{});
Tensor
sV
=
make_tensor
(
sK
.
data
()
+
size
(
sK
),
typename
Kernel_traits
::
SmemLayoutKV
{});
Tensor
sVt
=
make_tensor
(
sV
.
data
(),
typename
Kernel_traits
::
SmemLayoutVtransposed
{});
Tensor
sVtNoSwizzle
=
make_tensor
(
sV
.
data
(),
typename
Kernel_traits
::
SmemLayoutVtransposedNoSwizzle
{});
#if 1
KIN_PRINT
(
"sK.layout()"
,
print
(
sK
.
layout
()))
KIN_PRINT
(
"gK.layout()"
,
print
(
gK
.
layout
()))
KIN_PRINT
(
"sV.layout()"
,
print
(
sV
.
layout
()))
KIN_PRINT
(
"sVt.layout()"
,
print
(
sVt
.
layout
()))
KIN_PRINT
(
"sVtNoSwizzle.layout()"
,
print
(
sVtNoSwizzle
.
layout
()))
#endif
typename
Kernel_traits
::
GmemTiledCopyQKV
gmem_tiled_copy_QKV
;
auto
gmem_thr_copy_QKV
=
gmem_tiled_copy_QKV
.
get_thread_slice
(
tidx
);
...
...
@@ -662,14 +665,25 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
Tensor
tKsK
=
gmem_thr_copy_QKV
.
partition_D
(
sK
);
Tensor
tVgV
=
gmem_thr_copy_QKV
.
partition_S
(
gV
);
// (VCPY, VCPY_N, VCPY_K)
Tensor
tVsV
=
gmem_thr_copy_QKV
.
partition_D
(
sV
);
#if 1
KIN_PRINT
(
"tKgK.layout()"
,
print
(
tKgK
.
layout
()))
KIN_PRINT
(
"tKsK.layout()"
,
print
(
tKsK
.
layout
()))
#endif
typename
Kernel_traits
::
TiledMma
tiled_mma
;
auto
thr_mma
=
tiled_mma
.
get_thread_slice
(
tidx
);
Tensor
tSrQ
=
thr_mma
.
partition_fragment_A
(
sQ
);
// (MMA,MMA_M,MMA_K)
Tensor
tSrK
=
thr_mma
.
partition_fragment_B
(
sK
);
// (MMA,MMA_N,MMA_K)
Tensor
tOrVt
=
thr_mma
.
partition_fragment_B
(
sVtNoSwizzle
);
// (MMA, MMA_K,MMA_N)
#if 1
KIN_PRINT
(
"tSrQ.layout()"
,
print
(
tSrQ
.
layout
()))
KIN_PRINT
(
"tSrK.layout()"
,
print
(
tSrK
.
layout
()))
#endif
Tensor
acc_o
=
partition_fragment_C
(
tiled_mma
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDim
>>
{});
// MMA, MMA_M, MMA_K
#if 1
KIN_PRINT
(
"acc_o.layout()"
,
print
(
acc_o
.
layout
()))
#endif
//
// Copy Atom retiling
...
...
@@ -678,10 +692,16 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
auto
smem_tiled_copy_Q
=
make_tiled_copy_A
(
typename
Kernel_traits
::
SmemCopyAtom
{},
tiled_mma
);
auto
smem_thr_copy_Q
=
smem_tiled_copy_Q
.
get_thread_slice
(
tidx
);
Tensor
tSsQ
=
smem_thr_copy_Q
.
partition_S
(
sQ
);
#if 1
KIN_PRINT
(
"tSsQ.layout()"
,
print
(
tSsQ
.
layout
()))
#endif
auto
smem_tiled_copy_K
=
make_tiled_copy_B
(
typename
Kernel_traits
::
SmemCopyAtom
{},
tiled_mma
);
auto
smem_thr_copy_K
=
smem_tiled_copy_K
.
get_thread_slice
(
tidx
);
Tensor
tSsK
=
smem_thr_copy_K
.
partition_S
(
sK
);
#if 1
KIN_PRINT
(
"tSsK.layout()"
,
print
(
tSsK
.
layout
()))
#endif
auto
smem_tiled_copy_V
=
make_tiled_copy_B
(
typename
Kernel_traits
::
SmemCopyAtomTransposed
{},
tiled_mma
);
auto
smem_thr_copy_V
=
smem_tiled_copy_V
.
get_thread_slice
(
tidx
);
...
...
@@ -697,6 +717,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// Construct identity layout for sQ and sK
Tensor
cQ
=
make_identity_tensor
(
make_shape
(
size
<
0
>
(
sQ
),
size
<
1
>
(
sQ
)));
// (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor
cKV
=
make_identity_tensor
(
make_shape
(
size
<
0
>
(
sK
),
size
<
1
>
(
sK
)));
// (BLK_N,BLK_K) -> (blk_n,blk_k)
#if 1
KIN_PRINT
(
"cQ.layout()"
,
print
(
cQ
.
layout
()))
KIN_PRINT
(
"cKV.layout()"
,
print
(
cKV
.
layout
()))
#endif
// Repeat the partitioning with identity layouts
Tensor
tQcQ
=
gmem_thr_copy_QKV
.
partition_S
(
cQ
);
// (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
...
...
@@ -705,6 +729,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
// Allocate predicate tensors for k
Tensor
tQpQ
=
make_tensor
<
bool
>
(
make_shape
(
size
<
2
>
(
tQsQ
)));
Tensor
tKVpKV
=
make_tensor
<
bool
>
(
make_shape
(
size
<
2
>
(
tKsK
)));
#if 1
KIN_PRINT
(
"tQcQ.layout()"
,
print
(
tQcQ
.
layout
()))
KIN_PRINT
(
"tKVcKV.layout()"
,
print
(
tKVcKV
.
layout
()))
KIN_PRINT
(
"tQpQ.layout()"
,
print
(
tQpQ
.
layout
()))
KIN_PRINT
(
"tKVpKV.layout()"
,
print
(
tKVpKV
.
layout
()))
#endif
// Set predicates for k bounds
if
(
!
Is_even_K
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment