Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
2033d805
Commit
2033d805
authored
Feb 03, 2026
by
zhanghj2
Browse files
支持纯bf16
parent
58b43d4a
Changes
6
Hide whitespace changes
Inline
Side-by-side
Showing
6 changed files
with
867 additions
and
68 deletions
+867
-68
csrc/api/dense_decode.h
csrc/api/dense_decode.h
+1
-1
csrc/sm90/decode/dense/config.h
csrc/sm90/decode/dense/config.h
+1
-1
csrc/sm90/decode/dense/splitkv_mla.cuh
csrc/sm90/decode/dense/splitkv_mla.cuh
+591
-11
csrc/sm90/decode/dense/traits.h
csrc/sm90/decode/dense/traits.h
+95
-54
csrc/utils.h
csrc/utils.h
+177
-0
tests/test_flash_mla_dense_decoding.py
tests/test_flash_mla_dense_decoding.py
+2
-1
No files found.
csrc/api/dense_decode.h
View file @
2033d805
...
...
@@ -75,7 +75,7 @@ dense_attn_decode_interface(
const
int
num_heads
=
num_heads_k
;
q
=
q
.
view
({
batch_size
,
seqlen_q_ori
,
num_heads_k
,
num_q_heads_per_hk
,
head_size_k
}).
transpose
(
2
,
3
)
.
reshape
({
batch_size
,
q_seq_per_hk
,
num_heads
,
head_size_k
});
int
num_sm_parts
=
std
::
max
(
arch
.
num_sms
/
num_heads_k
/
cutlass
::
ceil_div
(
seqlen_q_ori
*
num_heads_q
/
num_heads_k
,
6
4
),
1
);
int
num_sm_parts
=
std
::
max
(
arch
.
num_sms
/
num_heads_k
/
cutlass
::
ceil_div
(
seqlen_q_ori
*
num_heads_q
/
num_heads_k
,
1
6
),
1
);
KU_CHECK_SHAPE
(
q
,
batch_size
,
q_seq_per_hk
,
num_heads
,
head_size_k
);
KU_CHECK_SHAPE
(
kcache
,
num_blocks
,
page_block_size
,
num_heads_k
,
head_size_k
);
...
...
csrc/sm90/decode/dense/config.h
View file @
2033d805
...
...
@@ -2,7 +2,7 @@
namespace
Config
{
static
constexpr
int
BLOCK_SIZE_M
=
6
4
;
static
constexpr
int
BLOCK_SIZE_M
=
1
6
;
static
constexpr
int
PAGE_BLOCK_SIZE
=
64
;
static
constexpr
int
HEAD_DIM_K
=
576
;
...
...
csrc/sm90/decode/dense/splitkv_mla.cuh
View file @
2033d805
...
...
@@ -5,22 +5,598 @@
#include "params.h"
#include "config.h"
#include "traits.h"
#include "softmax.h"
using
namespace
cute
;
namespace
sm90
{
// Sentinel "minus infinity" values used by the online softmax.
//
// MAX_INIT_VAL_SM initialises the running row maximum (sM); MAX_INIT_VAL is the
// value written into masked-out scores. The new maximum is computed as
//     new_max = max(sM(row_idx), cur_max * scale_softmax_log2)
// so a fully masked row must still satisfy
//     MAX_INIT_VAL * scale_softmax_log2 < MAX_INIT_VAL_SM.
// With the values below this holds whenever scale_softmax_log2 > 1e-3.
static constexpr float MAX_INIT_VAL_SM = -1e30f;
static constexpr float MAX_INIT_VAL = -1e33f;
template
<
typename
T
>
__device__
void
compute_attn_1rowblock_splitkv_mla_gfx936
(
const
DenseAttnDecodeParams
params
,
const
int
bidb
,
const
int
bidh
,
const
int
m_block
,
const
int
n_split_idx
,
const
int
seqlen_k
,
const
int
n_block_min
,
const
int
n_block_max
,
const
bool
NoSplit
)
{
extern
__shared__
char
shared_memory
[];
using
SharedMemoryPlan
=
typename
T
::
SharedMemoryPlan
;
SharedMemoryPlan
&
plan
=
*
reinterpret_cast
<
SharedMemoryPlan
*>
(
shared_memory
);
const
int
tidx
=
threadIdx
.
x
;
constexpr
int
kBlockM
=
T
::
BLOCK_SIZE_M
;
constexpr
int
kBlockN
=
T
::
PAGE_BLOCK_SIZE
;
constexpr
int
kHeadDim
=
T
::
HEAD_DIM_K
;
constexpr
int
kHeadDimV
=
T
::
HEAD_DIM_V
;
using
Element
=
T
::
InputT
;
using
index_t
=
int64_t
;
const
index_t
row_offset_q
=
bidb
*
params
.
q_batch_stride
+
m_block
*
kBlockM
*
params
.
q_row_stride
+
bidh
*
params
.
q_head_stride
;
Tensor
gQ
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
q_ptr
)
+
row_offset_q
),
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDim
>>
{},
make_stride
(
params
.
q_row_stride
,
_1
{}));
const
index_t
row_offset_k
=
(
bidh
)
*
params
.
k_head_stride
;
Tensor
gK
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
k_ptr
)
+
row_offset_k
),
Shape
<
Int
<
kBlockN
>
,
Int
<
kHeadDim
>>
{},
make_stride
(
params
.
k_row_stride
,
_1
{}));
Tensor
gV
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
k_ptr
)
+
row_offset_k
),
Shape
<
Int
<
kBlockN
>
,
Int
<
kHeadDimV
>>
{},
make_stride
(
params
.
k_row_stride
,
_1
{}));
Tensor
sQ
=
make_tensor
(
make_smem_ptr
(
plan
.
smem_q
.
data
()),
typename
T
::
SmemLayoutQ
{});
Tensor
sV
=
make_tensor
(
make_smem_ptr
(
plan
.
smem_v
.
data
()),
typename
T
::
SmemLayoutV
{});
Tensor
sK
=
make_tensor
(
make_smem_ptr
(
plan
.
smem_v
.
data
()),
typename
T
::
SmemLayoutK
{});
Tensor
sP
=
make_tensor
(
make_smem_ptr
(
plan
.
smem_p
.
data
()),
typename
T
::
SmemLayoutP
{});
Tensor
sVt
=
make_tensor
(
sV
.
data
(),
typename
T
::
SmemLayoutVtransposed
{});
Tensor
sVtNoSwizzle
=
make_tensor
(
sV
.
data
(),
typename
T
::
SmemLayoutVtransposedNoSwizzle
{});
Tensor
sRow_max_reduce_buffer
=
make_tensor
(
make_smem_ptr
(
plan
.
smem_row_max
.
data
()),
typename
T
::
SmemLayoutRow
{});
Tensor
sRow_sum_reduce_buffer
=
make_tensor
(
make_smem_ptr
(
plan
.
smem_row_sum
.
data
()),
typename
T
::
SmemLayoutRow
{});
typename
T
::
TiledMma
tiled_mma
;
auto
thr_mma
=
tiled_mma
.
get_thread_slice
(
tidx
);
typename
T
::
TiledMma_O
tiled_mma_o
;
auto
thr_mma_o
=
tiled_mma_o
.
get_thread_slice
(
tidx
);
#if 1
typename
T
::
GmemTiledCopyQ
gmem_tiled_copy_Q
;
auto
gmem_thr_copy_Q
=
gmem_tiled_copy_Q
.
get_thread_slice
(
tidx
);
Tensor
tQgQ
=
gmem_thr_copy_Q
.
partition_S
(
gQ
);
Tensor
tQsQ
=
gmem_thr_copy_Q
.
partition_D
(
sQ
);
Tensor
cQ
=
make_identity_tensor
(
make_shape
(
size
<
0
>
(
gQ
),
size
<
1
>
(
gQ
)));
Tensor
tQcQ
=
gmem_thr_copy_Q
.
partition_S
(
cQ
);
Tensor
tQpQ
=
make_tensor
<
bool
>
(
make_shape
(
size
<
2
>
(
tQgQ
)));
if
(
tidx
<
128
)
flash
::
copy
<
/*Is_even_MN=*/
false
,
/*Is_even_K=*/
true
,
false
>
(
gmem_tiled_copy_Q
,
tQgQ
,
tQsQ
,
tQcQ
,
tQpQ
,
params
.
q_seq_per_hk
-
m_block
*
kBlockM
);
__syncthreads
();
auto
smem_tiled_copy_Q
=
make_tiled_copy_A
(
Copy_Atom
<
DefaultCopy
,
Element
>
{},
tiled_mma
);
auto
smem_thr_copy_Q
=
smem_tiled_copy_Q
.
get_thread_slice
(
tidx
);
Tensor
tSsQ
=
smem_thr_copy_Q
.
partition_S
(
sQ
);
Tensor
tSrQ
=
thr_mma
.
partition_fragment_A
(
sQ
);
Tensor
tSrQ_copy_view
=
smem_thr_copy_Q
.
retile_D
(
tSrQ
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
,
tSrQ_copy_view
);
__syncthreads
();
#else
auto
gmem_tiled_copy_Q
=
make_tiled_copy_A
(
Copy_Atom
<
DefaultCopy
,
Element
>
{},
tiled_mma
);
auto
gmem_thr_copy_Q
=
gmem_tiled_copy_Q
.
get_thread_slice
(
tidx
);
Tensor
tSgQ
=
gmem_thr_copy_Q
.
partition_S
(
gQ
);
Tensor
tSrQ
=
thr_mma
.
partition_fragment_A
(
gQ
);
Tensor
cQ
=
make_identity_tensor
(
make_shape
(
size
<
0
>
(
gQ
),
size
<
1
>
(
gQ
)));
Tensor
tQcQ
=
gmem_thr_copy_Q
.
partition_S
(
cQ
);
Tensor
tQpQ
=
make_tensor
<
bool
>
(
make_shape
(
size
<
2
>
(
tSgQ
)));
flash
::
copy
<
/*Is_even_MN=*/
false
,
/*Is_even_K=*/
true
>
(
gmem_tiled_copy_Q
,
tSgQ
,
tSrQ
,
tQcQ
,
tQpQ
,
params
.
q_seq_per_hk
-
m_block
*
kBlockM
);
__syncthreads
();
#endif
// if (block0() && tidx < 64)
// {
// printf(" %.3f %.3f \n", float(tSrQ(0)), float(tSrQ(1)));
// }
auto
smem_tiled_copy_K
=
make_tiled_copy_B
(
Copy_Atom
<
DefaultCopy
,
Element
>
{},
tiled_mma
);
auto
smem_thr_copy_K
=
smem_tiled_copy_K
.
get_thread_slice
(
tidx
);
Tensor
tSgK
=
smem_thr_copy_K
.
partition_S
(
gK
);
Tensor
tSsK
=
smem_thr_copy_K
.
partition_S
(
sK
);
Tensor
tSrK
=
thr_mma
.
partition_fragment_B
(
sK
);
auto
smem_tiled_copy_V
=
make_tiled_copy_B
(
Copy_Atom
<
GFX928_DS_READ_DS_M32x16_B16
,
Element
>
{},
tiled_mma_o
);
auto
smem_thr_copy_V
=
smem_tiled_copy_V
.
get_thread_slice
(
tidx
);
Tensor
tOsVt
=
smem_thr_copy_V
.
partition_S
(
sVt
);
Tensor
tOrVt
=
thr_mma_o
.
partition_fragment_B
(
sVtNoSwizzle
);
constexpr
int
n_masking_steps
=
!
T
::
Is_causal
?
1
:
cute
::
ceil_div
(
kBlockM
,
kBlockN
)
+
1
;
const
int
*
block_table
=
params
.
block_table
+
bidb
*
params
.
block_table_batch_stride
;
int
n_block
=
n_block_max
-
1
;
Tensor
acc_o
=
partition_fragment_C
(
tiled_mma_o
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDimV
>>
{});
clear
(
acc_o
);
flash
::
Softmax
<
size
<
1
>
(
acc_o
)
>
softmax
;
Tensor
tOrVt_copy_view
=
smem_thr_copy_V
.
retile_D
(
tOrVt
);
Tensor
tSrK_copy_view
=
smem_thr_copy_K
.
retile_D
(
tSrK
);
// Tensor tOrVt_copy_view = smem_thr_copy_V.retile_D(tOrVt);
// Tensor tSrK_copy_view = smem_thr_copy_K.retile_D(tSrK);
// Tensor tKcK_smem = smem_thr_copy_K.partition_S(cK);
Tensor
tKpK_smem
=
make_tensor
<
bool
>
(
make_shape
(
size
<
2
>
(
tSgK
)));
Tensor
tSrK_smem
=
thr_mma
.
partition_fragment_B
(
gK
);
constexpr
static
int
k0_lds_loops
=
15
;
constexpr
static
int
k0_loops
=
size
<
2
>
(
tSrK_smem
);
constexpr
static
int
k1_loops
=
size
<
2
>
(
tOrVt
);
constexpr
static
int
STAGE
=
15
;
for
(
int
masking_step
=
0
;
masking_step
<
n_masking_steps
&&
n_block
>=
n_block_min
;
++
masking_step
,
--
n_block
)
{
Tensor
acc_s
=
partition_fragment_C
(
tiled_mma
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
clear
(
acc_s
);
// asm volatile("s_barrier\n\t");
// 这个也做过循环2类似的修改,但是性能不如现在的好,所以保持不变
int
cur_block_table
;
const
int
*
cur_block_table_ptr
=
block_table
+
n_block
;
// cur_block_table = block_table[n_block - 1];
asm
volatile
(
"s_load_dword %1, %0, 0x0
\n\t
"
"s_waitcnt lgkmcnt(0)
\n\t
"
:
"+s"
(
cur_block_table_ptr
),
"=s"
(
cur_block_table
));
index_t
offset_k
=
cur_block_table
*
params
.
k_batch_stride
;
gK
.
data
()
=
gK
.
data
()
+
(
offset_k
);
#pragma unroll
for
(
int
i
=
0
;
i
<
STAGE
;
i
++
)
{
flash
::
lds_direct_copy
<
false
,
true
>
(
gK
,
sK
,
i
,
params
.
k_row_stride
,
seqlen_k
-
n_block
*
kBlockN
);
}
constexpr
static
int
BUFFER_SIZE
=
3
;
uint128_t
buffer
[
BUFFER_SIZE
];
flash
::
buffer_load_copy
<
false
,
true
,
true
,
true
>
(
gK
,
buffer
[
0
],
15
,
params
.
k_row_stride
,
offset_k
,
seqlen_k
-
n_block
*
kBlockN
);
flash
::
buffer_load_copy
<
false
,
true
,
true
,
true
>
(
gK
,
buffer
[
1
],
16
,
params
.
k_row_stride
,
offset_k
,
seqlen_k
-
n_block
*
kBlockN
);
flash
::
buffer_load_copy
<
false
,
true
,
true
,
true
>
(
gK
,
buffer
[
2
],
17
,
params
.
k_row_stride
,
offset_k
,
seqlen_k
-
n_block
*
kBlockN
);
// if constexpr (STAGE == 15)
{
int
k_idx
=
0
;
// k_idx++;
asm
volatile
(
"s_waitcnt vmcnt(14 + 3)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
asm
volatile
(
"s_waitcnt vmcnt(13 + 3)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
asm
volatile
(
"s_waitcnt vmcnt(12 + 3)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
asm
volatile
(
"s_waitcnt vmcnt(11 + 3)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
asm
volatile
(
"s_waitcnt vmcnt(10 + 3)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
asm
volatile
(
"s_waitcnt vmcnt(9+ 3)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
asm
volatile
(
"s_waitcnt vmcnt(8+ 3)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
asm
volatile
(
"s_waitcnt vmcnt(7+ 3)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
// asm volatile("s_barrier\n\t");
asm
volatile
(
"s_waitcnt vmcnt(6+ 3)
\n\t
s_barrier
\n\t
"
);
k_idx
++
;
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
// asm volatile("s_barrier\n\t");
asm
volatile
(
"s_waitcnt vmcnt(5 + 3)
\n\t
s_barrier
\n\t
"
);
k_idx
++
;
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
// asm volatile("s_barrier\n\t");
asm
volatile
(
"s_waitcnt vmcnt(4 + 3)
\n\t
s_barrier
\n\t
"
);
k_idx
++
;
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
asm
volatile
(
"s_waitcnt vmcnt(3 + 3)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
asm
volatile
(
"s_waitcnt vmcnt(2 + 3)
\n\t
s_barrier
\n\t
"
);
k_idx
++
;
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
// asm volatile("s_barrier\n\t");
asm
volatile
(
"s_waitcnt vmcnt(1 + 3)
\n\t
s_barrier
\n\t
"
);
k_idx
++
;
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
// asm volatile("s_barrier\n\t");
asm
volatile
(
"s_waitcnt vmcnt(0 + 3)
\n\t
s_barrier
\n\t
"
);
k_idx
++
;
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
// asm volatile("s_barrier\n\t");
}
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
\n\t
"
);
flash
::
buffer_to_tensor
(
buffer
[
0
],
tSrK_smem
,
15
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
15
),
tSrK_smem
(
_
,
_
,
15
),
acc_s
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
\n\t
"
);
flash
::
buffer_to_tensor
(
buffer
[
1
],
tSrK_smem
,
16
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
16
),
tSrK_smem
(
_
,
_
,
16
),
acc_s
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
\n\t
"
);
flash
::
buffer_to_tensor
(
buffer
[
2
],
tSrK_smem
,
17
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
17
),
tSrK_smem
(
_
,
_
,
17
),
acc_s
);
// asm volatile("s_barrier\n\t");
// if (block0() && tidx < 64)
// {
// printf(" %.3f %.3f \n", acc_s(0), acc_s(1));
// }
Tensor
cS
=
make_identity_tensor
(
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
Tensor
tScS
=
thr_mma
.
partition_C
(
cS
);
for
(
int
i
=
0
;
i
<
size
(
acc_s
);
++
i
)
{
if
constexpr
(
!
T
::
Is_causal
)
{
if
(
int
(
get
<
1
>
(
tScS
(
i
)))
>=
int
(
seqlen_k
-
n_block
*
kBlockN
))
acc_s
(
i
)
=
-
INFINITY
;
}
else
{
// Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
// col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
int
row
=
int
(
get
<
0
>
(
tScS
(
i
)));
int
col_limit_right
=
seqlen_k
-
1
-
n_block
*
kBlockN
-
(
params
.
q_seq_per_hk
-
1
-
(
m_block
*
kBlockM
+
row
))
/
params
.
q_head_per_hk
;
if
(
int
(
get
<
1
>
(
tScS
(
i
)))
>
col_limit_right
)
acc_s
(
i
)
=
-
INFINITY
;
}
}
// We have key_padding_mask so we'll need to Check_inf
if
constexpr
(
n_masking_steps
==
1
)
{
softmax
.
template
softmax_rescale_o
<
/*Is_first=*/
true
,
/*Check_inf=*/
T
::
Is_causal
>(
acc_s
,
acc_o
,
sRow_max_reduce_buffer
,
params
.
scale_softmax_log2
);
}
else
{
const
bool
is_first_masking_step
=
masking_step
==
0
;
is_first_masking_step
?
softmax
.
template
softmax_rescale_o
<
/*Is_first=*/
true
,
/*Check_inf=*/
T
::
Is_causal
>(
acc_s
,
acc_o
,
sRow_max_reduce_buffer
,
params
.
scale_softmax_log2
)
:
softmax
.
template
softmax_rescale_o
<
/*Is_first=*/
false
,
/*Check_inf=*/
T
::
Is_causal
>(
acc_s
,
acc_o
,
sRow_max_reduce_buffer
,
params
.
scale_softmax_log2
);
}
Tensor
rP
=
flash
::
convert_type
<
Element
>
(
acc_s
);
// Tensor tOrP = convert_layout_acc_Aregs(tiled_mma_o, rP, sP);
Tensor
tOrP
=
flash
::
convert_layout_acc_Aregs_dense
(
tiled_mma
,
tiled_mma_o
,
rP
,
sP
);
__syncthreads
();
flash
::
lds_direct_copy
<
false
,
true
>
(
gK
,
sK
,
15
,
params
.
k_row_stride
,
seqlen_k
-
n_block
*
kBlockN
);
// asm_ds_write(buffer[0], tVsV, 15);
// asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
gK
.
data
()
=
gK
.
data
()
+
(
-
offset_k
);
#pragma unroll
for
(
int
i
=
0
;
i
<
k1_loops
;
i
++
)
{
cute
::
copy
(
smem_tiled_copy_V
,
tOsVt
(
_
,
_
,
i
),
tOrVt_copy_view
(
_
,
_
,
i
));
cute
::
gemm
(
tiled_mma_o
,
tOrP
(
_
,
_
,
i
),
tOrVt
(
_
,
_
,
i
),
acc_o
);
}
// asm volatile("s_barrier\n\t");
}
for
(;
n_block
>=
n_block_min
;
--
n_block
)
{
Tensor
acc_s
=
partition_fragment_C
(
tiled_mma
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
clear
(
acc_s
);
// asm volatile("s_barrier\n\t");
int
cur_block_table
;
const
int
*
cur_block_table_ptr
=
block_table
+
n_block
;
// cur_block_table = block_table[n_block - 1];
asm
volatile
(
"s_load_dword %1, %0, 0x0
\n\t
"
"s_waitcnt lgkmcnt(0)
\n\t
"
:
"+s"
(
cur_block_table_ptr
),
"=s"
(
cur_block_table
));
index_t
offset_k
=
cur_block_table
*
params
.
k_batch_stride
;
gK
.
data
()
=
gK
.
data
()
+
(
offset_k
);
#pragma unroll
for
(
int
i
=
0
;
i
<
16
;
i
++
)
{
flash
::
lds_direct_copy
<
true
,
true
>
(
gK
,
sK
,
i
,
params
.
k_row_stride
,
seqlen_k
-
n_block
*
kBlockN
);
}
constexpr
static
int
BUFFER_SIZE
=
2
;
uint128_t
buffer
[
BUFFER_SIZE
];
// buffer_load_copy<true, true, true, true>(gK, buffer[0], 15, params.k_row_stride, offset_k, seqlen_k - n_block * kBlockN);
flash
::
buffer_load_copy
<
true
,
true
,
true
,
true
>
(
gK
,
buffer
[
0
],
16
,
params
.
k_row_stride
,
offset_k
,
seqlen_k
-
n_block
*
kBlockN
);
flash
::
buffer_load_copy
<
true
,
true
,
true
,
true
>
(
gK
,
buffer
[
1
],
17
,
params
.
k_row_stride
,
offset_k
,
seqlen_k
-
n_block
*
kBlockN
);
// if constexpr (STAGE == 15)
{
int
k_idx
=
0
;
// k_idx++;
asm
volatile
(
"s_waitcnt vmcnt(14 + 3)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
asm
volatile
(
"s_waitcnt vmcnt(13 + 3)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
asm
volatile
(
"s_waitcnt vmcnt(12 + 3)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
asm
volatile
(
"s_waitcnt vmcnt(11 + 3)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
asm
volatile
(
"s_waitcnt vmcnt(10 + 3)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
asm
volatile
(
"s_waitcnt vmcnt(9+ 3)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
asm
volatile
(
"s_waitcnt vmcnt(8+ 3)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
asm
volatile
(
"s_waitcnt vmcnt(7+ 3)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
// asm volatile("s_barrier\n\t");
asm
volatile
(
"s_waitcnt vmcnt(6+ 3)
\n\t
s_barrier
\n\t
"
);
k_idx
++
;
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
// asm volatile("s_barrier\n\t");
asm
volatile
(
"s_waitcnt vmcnt(5 + 3)
\n\t
s_barrier
\n\t
"
);
k_idx
++
;
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
// asm volatile("s_barrier\n\t");
asm
volatile
(
"s_waitcnt vmcnt(4 + 3)
\n\t
s_barrier
\n\t
"
);
k_idx
++
;
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
asm
volatile
(
"s_waitcnt vmcnt(3 + 3)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
asm
volatile
(
"s_waitcnt vmcnt(2 + 3)
\n\t
s_barrier
\n\t
"
);
k_idx
++
;
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
// asm volatile("s_barrier\n\t");
asm
volatile
(
"s_waitcnt vmcnt(1 + 3)
\n\t
s_barrier
\n\t
"
);
k_idx
++
;
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
// asm volatile("s_barrier\n\t");
asm
volatile
(
"s_waitcnt vmcnt(0 + 3)
\n\t
s_barrier
\n\t
"
);
k_idx
++
;
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
asm
volatile
(
"s_waitcnt vmcnt(0 + 2)
\n\t
s_barrier
\n\t
"
);
k_idx
++
;
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
__ds_read_m32x16_row_col
<
3
,
0
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col
<
3
,
1
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col
<
3
,
2
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col
<
3
,
3
>
(
tOsVt
,
tOrVt_copy_view
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
// asm volatile("s_barrier\n\t");
}
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
\n\t
"
);
flash
::
buffer_to_tensor
(
buffer
[
0
],
tSrK_smem
,
16
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
16
),
tSrK_smem
(
_
,
_
,
16
),
acc_s
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
\n\t
"
);
flash
::
buffer_to_tensor
(
buffer
[
1
],
tSrK_smem
,
17
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
17
),
tSrK_smem
(
_
,
_
,
17
),
acc_s
);
// asm volatile("s_barrier\n\t");
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
gK
.
data
()
=
gK
.
data
()
+
(
-
offset_k
);
// We have key_padding_mask so we'll need to Check_inf
softmax
.
template
softmax_rescale_o
<
/*Is_first=*/
false
,
/*Check_inf=*//*Is_local=*/
false
>(
acc_s
,
acc_o
,
sRow_max_reduce_buffer
,
params
.
scale_softmax_log2
);
Tensor
rP
=
flash
::
convert_type
<
Element
>
(
acc_s
);
// Tensor tOrP = convert_layout_acc_Aregs(tiled_mma_o, rP, sP);
Tensor
tOrP
=
flash
::
convert_layout_acc_Aregs_dense
(
tiled_mma
,
tiled_mma_o
,
rP
,
sP
);
flash
::
__ds_read_m32x16_row_col
<
0
,
0
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col
<
1
,
0
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col
<
2
,
0
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col
<
0
,
1
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col
<
1
,
1
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col
<
2
,
1
>
(
tOsVt
,
tOrVt_copy_view
);
cute
::
gemm
(
tiled_mma_o
,
tOrP
(
_
,
_
,
0
),
tOrVt
(
_
,
_
,
0
),
acc_o
);
cute
::
gemm
(
tiled_mma_o
,
tOrP
(
_
,
_
,
1
),
tOrVt
(
_
,
_
,
1
),
acc_o
);
flash
::
__ds_read_m32x16_row_col
<
0
,
2
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col
<
1
,
2
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col
<
2
,
2
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col
<
0
,
3
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col
<
1
,
3
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col
<
2
,
3
>
(
tOsVt
,
tOrVt_copy_view
);
cute
::
gemm
(
tiled_mma_o
,
tOrP
(
_
,
_
,
2
),
tOrVt
(
_
,
_
,
2
),
acc_o
);
cute
::
gemm
(
tiled_mma_o
,
tOrP
(
_
,
_
,
3
),
tOrVt
(
_
,
_
,
3
),
acc_o
);
// asm volatile("s_barrier\n\t");
}
using
ElementAccum
=
float
;
if
(
NoSplit
)
{
using
ElementO
=
Element
;
const
index_t
row_offset_o
=
bidb
*
params
.
o_batch_stride
+
m_block
*
kBlockM
*
params
.
o_row_stride
+
bidh
*
params
.
o_head_stride
;
const
index_t
row_offset_lse
=
(
bidb
*
params
.
h_k
+
bidh
)
*
params
.
q_seq_per_hk
+
m_block
*
kBlockM
;
constexpr
bool
Split
=
false
;
Tensor
gOaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementO
*>
(
Split
?
params
.
oaccum_ptr
:
params
.
o_ptr
)
+
(
row_offset_o
)),
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDimV
>>
{},
make_stride
(
Split
?
kHeadDimV
:
params
.
o_row_stride
,
_1
{}));
Tensor
lse
=
softmax
.
template
normalize_softmax_lse
<
/*Is_dropout=*/
false
,
Split
>(
acc_o
,
sRow_sum_reduce_buffer
,
params
.
scale_softmax
);
// if (block0() && tidx < 64)
// {
// printf(" %.3f %.3f \n", float(acc_o(0)), float(acc_o(1)));
// }
Tensor
gLSEaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementAccum
*>
(
Split
?
params
.
softmax_lseaccum_ptr
:
params
.
softmax_lse_ptr
)
+
(
row_offset_lse
)),
Shape
<
Int
<
kBlockM
>>
{},
Stride
<
_1
>
{});
Tensor
caccO
=
make_identity_tensor
(
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDimV
>>
{});
// (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor
taccOcO
=
thr_mma_o
.
partition_C
(
caccO
);
Tensor
rO
=
flash
::
convert_type
<
ElementO
>
(
acc_o
);
if
(
get
<
1
>
(
taccOcO
(
0
))
==
0
)
{
#pragma unroll
for
(
int
mi
=
0
;
mi
<
size
(
lse
);
++
mi
)
{
const
int
row
=
get
<
0
>
(
taccOcO
(
0
,
mi
,
0
));
if
(
row
<
params
.
q_seq_per_hk
-
m_block
*
kBlockM
)
{
gLSEaccum
(
row
)
=
lse
(
mi
);
}
}
}
{
// using result_type = cutlass::Array<bfloat16_t, 2>;
// int tidx = threadIdx.x;
int
col
=
0
;
int
warpid
=
tidx
/
64
;
for
(
int
m
=
0
;
m
<
1
;
m
++
)
{
const
int
row
=
tidx
%
16
;
if
(
row
<
params
.
q_seq_per_hk
-
m_block
*
kBlockM
)
{
for
(
int
n
=
0
;
n
<
size
<
2
>
(
acc_o
);
n
++
)
{
col
=
(
tidx
%
64
/
16
)
+
warpid
*
32
+
n
*
128
;
for
(
int
ei
=
0
;
ei
<
8
;
ei
++
)
{
gOaccum
(
row
,
col
)
=
rO
(
ei
,
m
,
n
);
col
+=
4
;
}
}
}
}
}
}
else
{
using
ElementO
=
float
;
int
split_idx
=
params
.
num_splits_ptr
[
bidb
]
+
n_split_idx
;
constexpr
bool
Split
=
true
;
const
index_t
row_offset_oaccum
=
((
split_idx
*
params
.
h_k
+
bidh
)
*
params
.
q_seq_per_hk
+
m_block
*
kBlockM
)
*
T
::
HEAD_DIM_V
;
// (BLOCK_SIZE_M, HEAD_DIM_V) : (HEAD_DIM_V, 1)
const
index_t
row_offset_lseaccum
=
(
split_idx
*
params
.
h_k
+
bidh
)
*
params
.
q_seq_per_hk
+
m_block
*
kBlockM
;
Tensor
gOaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementO
*>
(
Split
?
params
.
oaccum_ptr
:
params
.
o_ptr
)
+
(
row_offset_oaccum
)),
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDimV
>>
{},
make_stride
(
Split
?
kHeadDimV
:
params
.
o_row_stride
,
_1
{}));
Tensor
gLSEaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementAccum
*>
(
Split
?
params
.
softmax_lseaccum_ptr
:
params
.
softmax_lse_ptr
)
+
(
row_offset_lseaccum
)),
Shape
<
Int
<
kBlockM
>>
{},
Stride
<
_1
>
{});
Tensor
lse
=
softmax
.
template
normalize_softmax_lse
<
/*Is_dropout=*/
false
,
Split
>(
acc_o
,
sRow_sum_reduce_buffer
,
params
.
scale_softmax
);
Tensor
caccO
=
make_identity_tensor
(
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDimV
>>
{});
// (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor
taccOcO
=
thr_mma_o
.
partition_C
(
caccO
);
if
(
get
<
1
>
(
taccOcO
(
0
))
==
0
)
{
#pragma unroll
for
(
int
mi
=
0
;
mi
<
size
(
lse
);
++
mi
)
{
const
int
row
=
get
<
0
>
(
taccOcO
(
0
,
mi
,
0
));
if
(
row
<
params
.
q_seq_per_hk
-
m_block
*
kBlockM
)
{
gLSEaccum
(
row
)
=
lse
(
mi
);
}
}
}
{
// using result_type = cutlass::Array<bfloat16_t, 2>;
// int tidx = threadIdx.x;
int
col
=
0
;
int
warpid
=
tidx
/
64
;
for
(
int
m
=
0
;
m
<
1
;
m
++
)
{
const
int
row
=
tidx
%
16
;
if
(
row
<
params
.
q_seq_per_hk
-
m_block
*
kBlockM
)
{
for
(
int
n
=
0
;
n
<
size
<
2
>
(
acc_o
);
n
++
)
{
col
=
(
tidx
%
64
/
16
)
+
warpid
*
32
+
n
*
128
;
for
(
int
ei
=
0
;
ei
<
8
;
ei
++
)
{
gOaccum
(
row
,
col
)
=
acc_o
(
ei
,
m
,
n
);
col
+=
4
;
}
}
}
}
}
}
}
/**
 * Kernel entry point for split-KV MLA decoding.
 *
 * Grid mapping: blockIdx.x selects the query row block, blockIdx.y the KV
 * head and blockIdx.z the scheduler partition. Each partition reads its
 * DecodingSchedMeta entry and processes the requests
 * [begin_req_idx, end_req_idx]; the first and last request may be covered
 * only partially (block range and split flag come from the metadata), while
 * requests in between are processed in full and treated as not split.
 *
 * @tparam T      Traits type (tile sizes, thread count, MMAs).
 * @param params  Kernel parameters.
 */
template <typename T>
__global__ void __launch_bounds__(T::NUM_THREADS, 1)
flash_fwd_splitkv_mla_kernel(const DenseAttnDecodeParams params)
{
    const int m_block = blockIdx.x;
    const int bidh = blockIdx.y;
    const int partition_idx = blockIdx.z;

    DecodingSchedMeta sched_meta = params.tile_scheduler_metadata_ptr[partition_idx];
    // Partitions whose first request lies past the batch have nothing to do.
    if (sched_meta.begin_req_idx >= params.b)
    {
        return;
    }

    for (int batch_idx = sched_meta.begin_req_idx; batch_idx <= sched_meta.end_req_idx; ++batch_idx)
    {
        constexpr int kBlockN = T::PAGE_BLOCK_SIZE;
        const bool at_first_req = (batch_idx == sched_meta.begin_req_idx);
        const bool at_last_req = (batch_idx == sched_meta.end_req_idx);

        // Split index and first block only differ from 0 on the first request.
        int n_split_idx = 0;
        if (at_first_req)
        {
            n_split_idx = sched_meta.begin_split_idx;
        }

        int seqlen_k = __ldg(params.seqlens_k_ptr + batch_idx);

        int start_block_idx = 0;
        if (at_first_req)
        {
            start_block_idx = sched_meta.begin_block_idx;
        }

        // The last request stops where the scheduler says; others run to the
        // end of their key sequence.
        int end_block_idx;
        if (at_last_req)
        {
            end_block_idx = sched_meta.end_block_idx;
        }
        else
        {
            end_block_idx = cute::ceil_div(seqlen_k, kBlockN);
        }

        // The first-request flag takes precedence when a request is both the
        // first and the last of this partition.
        bool is_no_split = true;
        if (at_first_req)
        {
            is_no_split = !sched_meta.is_first_req_splitted;
        }
        else if (at_last_req)
        {
            is_no_split = !sched_meta.is_last_req_splitted;
        }

        // Barrier before every request after the first, so the previous
        // request's work in the block is finished before starting the next.
        if (!at_first_req)
        {
            __syncthreads();
        }

        compute_attn_1rowblock_splitkv_mla_gfx936<T>(params, batch_idx, bidh, m_block, n_split_idx, seqlen_k,
                                                     start_block_idx, end_block_idx, is_no_split);
    }
}
...
...
@@ -29,15 +605,19 @@ void run_flash_splitkv_mla_kernel(DenseAttnDecodeParams &params) {
FLASH_ASSERT
(
params
.
d
==
Config
::
HEAD_DIM_K
);
FLASH_ASSERT
(
params
.
d_v
==
Config
::
HEAD_DIM_V
);
using
T
=
Traits
<
InputT
>
;
auto
shape_Q
=
make_shape
(
params
.
q_seq_per_hk
,
params
.
d
,
params
.
h_k
,
params
.
b
);
auto
mla_kernel
=
&
flash_fwd_splitkv_mla_kernel
<
T
>
;
constexpr
size_t
smem_size
=
sizeof
(
typename
T
::
SharedMemoryPlan
);
constexpr
size_t
smem_size
=
65536
;
// Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch)
const
int
num_m_block
=
cute
::
ceil_div
(
params
.
q_seq_per_hk
,
T
::
BLOCK_SIZE_M
);
BOOL_SWITCH
(
params
.
is_causal
,
Is_causal
,
[
&
]
{
using
T
=
Traits
<
InputT
,
Is_causal
>
;
const
int
num_m_block
=
cute
::
ceil_div
(
params
.
q_seq_per_hk
,
T
::
BLOCK_SIZE_M
);
auto
mla_kernel
=
&
flash_fwd_splitkv_mla_kernel
<
T
>
;
mla_kernel
<<<
dim3
(
num_m_block
,
params
.
h_k
,
params
.
num_sm_parts
),
T
::
NUM_THREADS
,
smem_size
,
params
.
stream
>>>
(
params
);
});
// cudaLaunchConfig_t mla_kernel_config = {
// dim3(num_m_block, params.h_k, params.num_sm_parts),
// dim3(T::NUM_THREADS, 1, 1),
...
...
csrc/sm90/decode/dense/traits.h
View file @
2033d805
...
...
@@ -7,13 +7,12 @@
#include "config.h"
using
TMABarrier
=
cutlass
::
arch
::
ClusterTransactionBarrier
;
using
namespace
cute
;
template
<
typename
InputT_
>
template
<
typename
InputT_
,
bool
Is_causal_
>
struct
Traits
{
using
InputT
=
InputT_
;
static
constexpr
bool
Is_causal
=
Is_causal_
;
static
constexpr
int
BLOCK_SIZE_M
=
Config
::
BLOCK_SIZE_M
;
static
constexpr
int
PAGE_BLOCK_SIZE
=
Config
::
PAGE_BLOCK_SIZE
;
static
constexpr
int
HEAD_DIM_K
=
Config
::
HEAD_DIM_K
;
...
...
@@ -23,63 +22,105 @@ struct Traits {
static_assert
(
std
::
is_same_v
<
InputT
,
cutlass
::
bfloat16_t
>
||
std
::
is_same_v
<
InputT
,
cutlass
::
half_t
>
);
using
TiledMMA_QK_sQ
=
decltype
(
make_tiled_mma
(
GMMA
::
ss_op_selector
<
InputT
,
InputT
,
float
,
Shape
<
Int
<
BLOCK_SIZE_M
>
,
Int
<
PAGE_BLOCK_SIZE
>
,
Int
<
HEAD_DIM_K
>>
,
GMMA
::
Major
::
K
,
GMMA
::
Major
::
K
>
(),
Layout
<
Shape
<
_1
,
_1
,
_1
>>
{}
));
using
TiledMMA_QK_rQ
=
decltype
(
make_tiled_mma
(
GMMA
::
rs_op_selector
<
InputT
,
InputT
,
float
,
Shape
<
Int
<
BLOCK_SIZE_M
>
,
Int
<
PAGE_BLOCK_SIZE
>
,
Int
<
HEAD_DIM_K
>>
,
GMMA
::
Major
::
K
,
GMMA
::
Major
::
K
>
(),
Layout
<
Shape
<
_1
,
_1
,
_1
>>
{}
));
using
TiledMMA_PV_LocalP
=
decltype
(
make_tiled_mma
(
GMMA
::
rs_op_selector
<
InputT
,
InputT
,
float
,
Shape
<
Int
<
BLOCK_SIZE_M
>
,
Int
<
HEAD_DIM_V
/
2
>
,
Int
<
PAGE_BLOCK_SIZE
>>
,
GMMA
::
Major
::
K
,
GMMA
::
Major
::
MN
>
(),
Layout
<
Shape
<
_1
,
_1
,
_1
>>
{}
));
using
TiledMMA_PV_RemoteP
=
decltype
(
make_tiled_mma
(
GMMA
::
ss_op_selector
<
InputT
,
InputT
,
float
,
Shape
<
Int
<
BLOCK_SIZE_M
>
,
Int
<
HEAD_DIM_V
/
2
>
,
Int
<
PAGE_BLOCK_SIZE
>>
,
GMMA
::
Major
::
K
,
GMMA
::
Major
::
MN
>
(),
Layout
<
Shape
<
_1
,
_1
,
_1
>>
{}
));
using
SmemLayoutQ
=
decltype
(
tile_to_shape
(
GMMA
::
Layout_K_SW128_Atom
<
InputT
>
{},
Shape
<
Int
<
BLOCK_SIZE_M
>
,
Int
<
HEAD_DIM_K
>>
{}
));
static
constexpr
int
kBlockM
=
BLOCK_SIZE_M
;
static
constexpr
int
kBlockN
=
PAGE_BLOCK_SIZE
;
static
constexpr
int
kHeadDim
=
HEAD_DIM_K
;
static
constexpr
int
kHeadDimV
=
HEAD_DIM_V
;
static
constexpr
int
kNWarps
=
4
;
using
Element
=
InputT
;
using
elem_type
=
Element
;
using
ElementAccum
=
float
;
using
SmemLayoutRow
=
Layout
<
Shape
<
_128
>
,
Stride
<
_1
>>
;
using
SmemLayoutAtomK
=
decltype
(
composition
(
Swizzle
<
3
,
3
,
3
>
{},
Layout
<
Shape
<
Int
<
8
>
,
Int
<
32
>>
,
Stride
<
Int
<
32
>
,
_1
>>
{}));
using
SmemLayoutK
=
decltype
(
tile_to_shape
(
GMMA
::
Layout_K_SW128_Atom
<
InputT
>
{},
Shape
<
Int
<
PAGE_BLOCK_SIZE
>
,
Int
<
HEAD_DIM_K
>>
{}
));
using
SmemLayoutV
=
decltype
(
composition
(
SmemLayoutK
{},
make_layout
(
Shape
<
Int
<
HEAD_DIM_V
>
,
Int
<
PAGE_BLOCK_SIZE
>>
{},
GenRowMajor
{})
));
// A transposed version of SmemLayoutK
SmemLayoutAtomK
{},
Shape
<
Int
<
kBlockN
>
,
Int
<
16
*
32
>>
{}));
using
SmemLayoutK_place_holder
=
decltype
(
tile_to_shape
(
SmemLayoutAtomK
{},
Shape
<
Int
<
kBlockN
>
,
Int
<
15
*
32
>>
{}));
using
SmemLayoutAtomV
=
SmemLayoutAtomK
;
using
SmemLayoutV
=
decltype
(
tile_to_shape
(
SmemLayoutAtomV
{},
Shape
<
Int
<
kBlockN
>
,
Int
<
kHeadDimV
>>
{}));
using
SmemLayoutAtomP
=
Layout
<
Shape
<
Int
<
4
*
16
*
16
>>
,
Stride
<
Int
<
1
>>>
;
using
SmemLayoutP
=
decltype
(
tile_to_shape
(
SmemLayoutAtomP
{},
Shape
<
Int
<
4
*
16
*
16
>>
{}));
using
SmemLayoutVtransposed
=
decltype
(
composition
(
SmemLayoutV
{},
make_layout
(
Shape
<
Int
<
kHeadDimV
>
,
Int
<
kBlockN
>>
{},
GenRowMajor
{})));
using
SmemLayoutVtransposedNoSwizzle
=
decltype
(
get_nonswizzle_portion
(
SmemLayoutVtransposed
{}));
using
SmemLayoutAtomQ
=
decltype
(
composition
(
Swizzle
<
3
,
3
,
3
>
{},
Layout
<
Shape
<
Int
<
8
>
,
Int
<
64
>>
,
Stride
<
Int
<
64
>
,
_1
>>
{}));
using
SmemLayoutQ
=
decltype
(
tile_to_shape
(
SmemLayoutAtomQ
{},
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDim
>>
{}));
using
ValLayoutMNK
=
Layout
<
Shape
<
_1
,
_1
,
_1
>>
;
// #if defined(__gfx936__) || defined(__gfx938__)
using
MMA_Atom_Arch
=
std
::
conditional_t
<
std
::
is_same_v
<
elem_type
,
cutlass
::
half_t
>
,
MMA_Atom
<
GFX928_16x16x32_F32F16F16F32_NT
>
,
MMA_Atom
<
GFX928_16x16x32_F32BF16BF16F32_NT
>
>
;
using
TiledMma
=
TiledMMA
<
MMA_Atom_Arch
,
Layout
<
Shape
<
_1
,
Int
<
kNWarps
>
,
_1
>>
,
// 1x4x1 or 1x8x1 thread group
ValLayoutMNK
>
;
// #elif defined(__gfx928__)
// using MMA_Atom_Arch = std::conditional_t<
// std::is_same_v<elem_type, cutlass::half_t>,
// MMA_Atom<GFX928_16x16x32_F32F16F16F32_NT>,
// MMA_Atom<GFX928_16x16x32_F32BF16BF16F32_NT>
// >;
// using TiledMma = TiledMMA<
// MMA_Atom_Arch,
// Layout<Shape<_1, Int<kNWarps>, _1>>, // 1x4x1 or 1x8x1 thread group
// ValLayoutMNK>;
// #endif
using
MMA_Atom_Arch_16x32
=
std
::
conditional_t
<
std
::
is_same_v
<
elem_type
,
cutlass
::
half_t
>
,
MMA_Atom
<
GFX928_16x32x16_F32F16F16F32_NT
>
,
MMA_Atom
<
GFX928_16x32x16_F32BF16BF16F32_NT
>
>
;
using
TiledMma_O
=
TiledMMA
<
MMA_Atom_Arch_16x32
,
Layout
<
Shape
<
_1
,
Int
<
kNWarps
>
,
_1
>>
,
// 1x4x1 or 1x8x1 thread group
ValLayoutMNK
>
;
using
GmemLayoutAtomQ
=
Layout
<
Shape
<
_32
,
_8
>
,
Stride
<
_8
,
_1
>>
;
using
GmemTiledCopyQ
=
decltype
(
make_tiled_copy
(
Copy_Atom
<
DefaultCopy
,
Element
>
{},
GmemLayoutAtomQ
{},
Layout
<
Shape
<
_1
,
_8
>>
{}));
using
SmemLayoutP0
=
decltype
(
tile_to_shape
(
GMMA
::
Layout_K_SW128_Atom
<
InputT
>
{},
Shape
<
Int
<
BLOCK_SIZE_M
>
,
Int
<
PAGE_BLOCK_SIZE
>>
{}
));
using
rP0Layout
=
decltype
(
layout
(
partition_fragment_C
(
TiledMMA_QK_sQ
{},
Shape
<
Int
<
BLOCK_SIZE_M
>
,
Int
<
PAGE_BLOCK_SIZE
>>
{}
)));
struct
SharedMemoryPlan
{
cute
::
array_aligned
<
InputT
,
cosize_v
<
SmemLayoutQ
>>
smem_sQ
;
cute
::
array_aligned
<
InputT
,
cosize_v
<
SmemLayoutK
>>
smem_sK0
;
cute
::
array_aligned
<
InputT
,
cosize_v
<
SmemLayoutK
>>
smem_sK1
;
cute
::
array_aligned
<
InputT
,
cosize_v
<
SmemLayoutP0
>>
smem_sP0
;
cute
::
array_aligned
<
float
,
BLOCK_SIZE_M
>
smem_sM
;
cute
::
array_aligned
<
float
,
2
*
BLOCK_SIZE_M
>
sL_reduction_wksp
;
cute
::
array_aligned
<
float
,
BLOCK_SIZE_M
>
smem_sScale0
;
cute
::
array_aligned
<
float
,
BLOCK_SIZE_M
>
smem_sScale1
;
TMABarrier
barriers_K0
[
HEAD_DIM_K
/
64
];
TMABarrier
barriers_K1
[
HEAD_DIM_K
/
64
];
TMABarrier
barrier_Q
;
union
{
struct
{
cute
::
array_aligned
<
Element
,
cute
::
cosize_v
<
SmemLayoutV
>>
smem_v
;
// Double buffer
};
struct
{
cute
::
array_aligned
<
Element
,
cute
::
cosize_v
<
SmemLayoutK_place_holder
>>
smem_temp
;
// Double buffer
cute
::
array_aligned
<
Element
,
cute
::
cosize_v
<
SmemLayoutP
>>
smem_p
;
cute
::
array_aligned
<
ElementAccum
,
cute
::
cosize_v
<
SmemLayoutRow
>>
smem_row_sum
;
cute
::
array_aligned
<
ElementAccum
,
cute
::
cosize_v
<
SmemLayoutRow
>>
smem_row_max
;
};
struct
{
cute
::
array_aligned
<
Element
,
cute
::
cosize_v
<
SmemLayoutQ
>>
smem_q
;
};
};
};
};
...
...
csrc/utils.h
View file @
2033d805
...
...
@@ -88,6 +88,18 @@ struct RingBufferState {
}
};
// BOOL_SWITCH(COND, CONST_NAME, callable)
//
// Turns the runtime boolean COND into a compile-time constant. The macro
// expands to an immediately-invoked lambda that declares CONST_NAME as a
// `constexpr static bool` (true or false, depending on COND) and then calls
// the trailing callable argument. Because the callable text is pasted into
// both paths, it is compiled once with CONST_NAME == true and once with
// CONST_NAME == false, so it may use CONST_NAME as a template argument.
// The value of the whole expression is whatever the callable returns.
#define BOOL_SWITCH(COND, CONST_NAME, ...)            \
    [&] {                                             \
        if (COND) {                                   \
            constexpr static bool CONST_NAME = true;  \
            return __VA_ARGS__();                     \
        }                                             \
        constexpr static bool CONST_NAME = false;     \
        return __VA_ARGS__();                         \
    }()
namespace
flash
{
using
namespace
cute
;
...
...
@@ -559,5 +571,170 @@ lds_direct_copy_for_prefill_sparse_mla(
}
// Loads one 128-bit chunk (8 two-byte elements) per thread from the memory
// behind `src` into `dst` using an AMD buffer load.
//
// Template parameters:
//   Is_even_MN  when false, threads whose row is >= max_MN get a byte offset
//               of -1 instead of a real address.
//               NOTE(review): presumably this pushes the access out of the
//               descriptor's range so the hardware returns zeros - confirm.
//   Is_even_K   not referenced by this function.
//   mma_layout  selects the thread -> (row, column) mapping:
//                 true : row = tidx % 16 + warp_id * 16, column group = lane / 16
//                 false: row = tidx / 4,                 column group = lane % 4
//   use_asm     true issues the load through inline assembly, false through
//               __builtin_amdgcn_buffer_load_dwordx4.
//
// Parameters:
//   src         tensor whose data pointer becomes the buffer base address.
//   dst         receives the 16 loaded bytes.
//   k_idx_      index of the 32-element column tile; made wave-uniform with
//               readfirstlane before use.
//   row_stride  distance between consecutive rows, in elements.
//   offset_k    not referenced by this function (kept for interface stability).
//   max_MN      row bound used only when Is_even_MN is false.
template <bool Is_even_MN = true, bool Is_even_K = true, bool mma_layout = false, bool use_asm = false,
          class SrcEngine, class SrcLayout>
CUTE_HOST_DEVICE
void buffer_load_copy(Tensor<SrcEngine, SrcLayout> const &src, uint128_t &dst, int k_idx_,
                      const int row_stride, int offset_k, const int max_MN = 0) {
    constexpr int warp_size = 64;
    int tidx = threadIdx.x;
    int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
    int lane = tidx % warp_size;
    constexpr int element_size = 2;          // bytes per element
    int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
    constexpr int elements_per_thread = 8;   // 8 * 2 bytes = one 128-bit load

    if constexpr (mma_layout) {
        // Split the 64-bit base pointer into two 32-bit halves so that each
        // half can be passed through readfirstlane individually.
        struct PtrWrapper {
            uint32_t former;
            uint32_t latter;
        };
        PtrWrapper glob_ptr;
        *(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(src.data().get());
        // glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride

        // Four-word buffer descriptor: words 0/1 hold the base address.
        // NOTE(review): words 2/3 are presumably the record count and the
        // format/config word of the buffer resource - confirm against the ISA.
        uint32x4_t global_addr = {0};
        global_addr[0] = __builtin_amdgcn_readfirstlane(glob_ptr.former);
        global_addr[1] = __builtin_amdgcn_readfirstlane(glob_ptr.latter);
        global_addr[2] = 0x80000000;
        global_addr[3] = 0x00020000;

        int row = tidx % 16;
        int col = lane / 16;
        int row_offset = row + (warp_id * 16);
        int col_offset = col * elements_per_thread + k_idx * 32;
        int offset_v = (row_offset * row_stride + col_offset) * element_size;  // bytes
        if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;

        if constexpr (use_asm) {
            asm volatile("buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0\n"
                         "\n\t"
                         : "=v"(dst), "+v"(offset_v), "+s"(global_addr));
        } else {
            auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false);
            dst = *reinterpret_cast<uint128_t*>(&res);
        }
    } else {
        // Same descriptor shape as above, but the base address is written
        // directly into words 0/1 without readfirstlane.
        uint32x4_t global_addr = {0};
        *(uint64_t*)&global_addr = reinterpret_cast<uint64_t>(src.data().get());
        // global_addr[1] += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride
        global_addr[2] = 0x80000000;
        global_addr[3] = 0x00020000;

        int row = tidx / 4;
        int col = lane % 4;
        int row_offset = row;
        int col_offset = col * elements_per_thread + k_idx * 32;
        int offset_v = (row_offset * row_stride + col_offset) * element_size;  // bytes
        if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;

        if constexpr (use_asm) {
            asm volatile("buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0\n"
                         "\n\t"
                         : "=v"(dst), "+v"(offset_v), "+s"(global_addr));
        } else {
            auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, 0, offset_v, false, false);
            dst = *reinterpret_cast<uint128_t*>(&res);
        }
    }
}
// Writes the 128-bit value `src` into `dst`, starting at the element with
// coordinate (0, 0, k_idx). The element's address is reinterpreted as a
// uint128_t slot, so all 16 bytes are stored with a single assignment.
// NOTE(review): assumes 16 suitably aligned bytes are available from that
// element onward - confirm against the fragment layout used by callers.
template <class SrcEngine, class SrcLayout>
CUTE_HOST_DEVICE
void buffer_to_tensor(const uint128_t &src, Tensor<SrcEngine, SrcLayout> &dst, int k_idx) {
    *reinterpret_cast<uint128_t*>(&dst(0, 0, k_idx)) = src;
}
// Re-lays out four values per thread from `tOrP` into an A-operand register
// fragment for `tiled_mma_o`, using `sAcc` as a staging buffer.
//
// Steps:
//   1. Each thread scatters tOrP(0..3, 0, 0) into sAcc at
//        (lane % 16) * 8 + (lane / 16) + i * 16 * 8
//        + (warp_id % 2) * 4 + (warp_id / 2) * 16 * 32
//   2. __syncthreads() makes the scattered values visible block-wide.
//   3. Each thread gathers 16 values into the fragment:
//        frag(i, 0, k) = sAcc(lane * 8 + i + (k % 2) * 4 + (k / 2) * 16 * 32)
//      for i, k in [0, 4).
//
// Parameters:
//   tiled_mma    not referenced by this function (kept for interface stability).
//   tiled_mma_o  MMA whose thread slice defines the returned fragment's layout.
//   tOrP         source values; only coordinates (0..3, 0, 0) are read.
//   sAcc         staging buffer indexed linearly; written and then read here.
//                NOTE(review): looks like callers must make sure earlier reads
//                of this buffer have finished before calling - confirm.
//
// Returns the register fragment produced by partition_fragment_A.
template <class TiledMma, class TiledMma_O, typename Engine0, typename Layout0, typename Engine1, typename Layout1>
__forceinline__ __device__ auto convert_layout_acc_Aregs_dense(
    const TiledMma &tiled_mma,
    const TiledMma_O &tiled_mma_o,
    Tensor<Engine0, Layout0> const &tOrP,
    Tensor<Engine1, Layout1> const &sAcc) {
    constexpr int kValuesPerThread = 4;    // values read from tOrP / written per k-slot
    constexpr int kScatterStep = 16 * 8;   // distance between consecutive scattered values
    constexpr int kPairShift = 4;          // offset applied for odd warp / odd k
    constexpr int kSlabStride = 16 * 32;   // offset applied for warp_id / 2 and k / 2

    int tid = threadIdx.x % 64;
    int warp_id = threadIdx.x / 64;

    // Scatter this thread's values into the staging buffer.
    const int scatter_base =
        (tid % 16) * 8 + (tid / 16) + (warp_id % 2) * kPairShift + (warp_id / 2) * kSlabStride;
    #pragma unroll
    for (int i = 0; i < kValuesPerThread; ++i) {
        sAcc(scatter_base + i * kScatterStep) = tOrP(i, 0, 0);
    }
    __syncthreads();

    using SmemLayoutAtomP = Layout<Shape<Int<16>, Int<64>>, Stride<Int<64>, _1>>;
    using SmemLayoutP = decltype(tile_to_shape(SmemLayoutAtomP{}, Shape<Int<16>, Int<64>>{}));
    Tensor sP_tmp = make_tensor(sAcc.data(), SmemLayoutP{});
    auto thr_mma = tiled_mma_o.get_thread_slice(tid);
    Tensor tSrACC = thr_mma.partition_fragment_A(sP_tmp);

    // Gather 4 x 4 values back, k-slot by k-slot, in the same order as before.
    #pragma unroll
    for (int k = 0; k < 4; ++k) {
        const int gather_base = tid * 8 + (k % 2) * kPairShift + (k / 2) * kSlabStride;
        #pragma unroll
        for (int i = 0; i < kValuesPerThread; ++i) {
            tSrACC(i, 0, k) = sAcc(gather_base + i);
        }
    }
    return tSrACC;
}
}
\ No newline at end of file
tests/test_flash_mla_dense_decoding.py
View file @
2033d805
...
...
@@ -223,9 +223,10 @@ def main(torch_dtype):
]
performance_cases
=
[
TestParam
(
128
,
s_q
,
s_k
,
is_varlen
=
True
,
is_causal
=
is_causal
,
test_performance
=
True
)
TestParam
(
128
,
s_q
,
s_k
,
is_varlen
=
True
,
is_causal
=
is_causal
,
h_q
=
h_q
,
test_performance
=
True
)
for
is_causal
in
[
False
,
True
]
for
s_q
in
[
1
,
2
]
for
h_q
in
[
16
,
128
]
for
s_k
in
[
4096
,
8192
,
16384
,
32768
]
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment