Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
aec17474
Commit
aec17474
authored
Apr 20, 2026
by
zhanghj2
Browse files
Feature/kimi nhead64 dense
parent
a45f646b
Changes
9
Show whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
1590 additions
and
159 deletions
+1590
-159
csrc/api/dense_decode.h
csrc/api/dense_decode.h
+88
-38
csrc/gfx9/decode/combine/combine.cu
csrc/gfx9/decode/combine/combine.cu
+61
-30
csrc/gfx93/decode/dense/splitkv_mla.cuh
csrc/gfx93/decode/dense/splitkv_mla.cuh
+1335
-20
csrc/gfx93/decode/dense/traits.h
csrc/gfx93/decode/dense/traits.h
+23
-0
csrc/gfx93/prefill/sparse/phase1.cuh
csrc/gfx93/prefill/sparse/phase1.cuh
+65
-63
csrc/params.h
csrc/params.h
+10
-0
csrc/softmax.h
csrc/softmax.h
+1
-1
csrc/utils.h
csrc/utils.h
+5
-5
tests/test_flash_mla_dense_decoding.py
tests/test_flash_mla_dense_decoding.py
+2
-2
No files found.
csrc/api/dense_decode.h
View file @
aec17474
...
...
@@ -92,7 +92,7 @@ dense_attn_decode_interface(
KU_CHECK_CONTIGUOUS
(
out
);
KU_CHECK_CONTIGUOUS
(
lse
);
if
(
!
tile_scheduler_metadata
.
has_value
())
{
if
(
!
tile_scheduler_metadata
.
has_value
()
&&
((
num_heads_q
<
64
&&
num_heads_k
==
1
)
||
num_heads_k
>
1
)
)
{
tile_scheduler_metadata
=
torch
::
empty
({
num_sm_parts
,
sizeof
(
DecodingSchedMeta
)
/
4
},
opts
.
dtype
(
torch
::
kInt32
));
num_splits
=
torch
::
empty
({
batch_size
+
1
},
opts
.
dtype
(
torch
::
kInt32
));
KU_CHECK_CONTIGUOUS
(
tile_scheduler_metadata
);
...
...
@@ -125,20 +125,6 @@ dense_attn_decode_interface(
if
(
const
char
*
val
=
std
::
getenv
(
"FLASH_MLA_PRINT_PARAM"
))
{
print_param
=
(
std
::
string
(
val
)
==
"1"
);
}
if
(
print_param
)
{
fprintf
(
stderr
,
"[FlashMLA] [dense_attn_decode_interface] [%s] batch_size = %d seqlen_q_ori = %d "
"num_heads_q = %d head_size_k = %d max_num_blocks_per_seq = %d num_blocks %d page_block_size = %d num_heads_k = %d
\n
"
,
arch
.
archName
.
c_str
(),
batch_size
,
seqlen_q_ori
,
num_heads_q
,
head_size_k
,
max_num_blocks_per_seq
,
num_blocks
,
page_block_size
,
num_heads_k
);
}
// Set the sizes
DenseAttnDecodeParams
params
;
params
.
b
=
batch_size
;
...
...
@@ -174,10 +160,10 @@ dense_attn_decode_interface(
params
.
block_table_batch_stride
=
block_table
.
stride
(
0
);
params
.
page_block_size
=
page_block_size
;
if
((
num_heads_q
<
64
&&
num_heads_k
==
1
)
||
num_heads_k
>
1
)
{
params
.
tile_scheduler_metadata_ptr
=
(
DecodingSchedMeta
*
)
tile_scheduler_metadata
->
data_ptr
();
params
.
num_sm_parts
=
num_sm_parts
;
params
.
num_splits_ptr
=
num_splits
->
data_ptr
<
int
>
();
const
int
total_num_splits
=
batch_size
+
params
.
num_sm_parts
;
at
::
Tensor
lse_accum
=
torch
::
empty
({
total_num_splits
,
num_heads
,
q_seq_per_hk
},
opts
.
dtype
(
at
::
kFloat
));
at
::
Tensor
out_accum
=
torch
::
empty
({
total_num_splits
,
num_heads
,
q_seq_per_hk
,
head_size_v
},
opts
.
dtype
(
at
::
kFloat
));
...
...
@@ -186,7 +172,65 @@ dense_attn_decode_interface(
params
.
total_num_splits
=
total_num_splits
;
params
.
softmax_lseaccum_ptr
=
lse_accum
.
data_ptr
<
float
>
();
params
.
oaccum_ptr
=
out_accum
.
data_ptr
<
float
>
();
params
.
use_split_kv
=
false
;
}
else
{
bool
use_split_kv
=
true
;
int
num_m_blocks
=
(
params
.
q_seq_per_hk
+
64
-
1
)
/
64
;
int
num_sms
=
arch
.
num_sms
;
int
num_splits
=
num_sms
*
3
/
(
num_m_blocks
*
params
.
b
);
if
(
max_num_blocks_per_seq
>=
32768
/
64
)
{
num_splits
=
32
;
}
else
if
(
max_num_blocks_per_seq
>=
16384
/
64
)
{
num_splits
=
32
;
}
else
if
(
max_num_blocks_per_seq
>=
8192
/
64
)
{
num_splits
=
16
;
}
else
if
(
max_num_blocks_per_seq
>=
4096
/
64
)
{
num_splits
=
8
;
}
else
if
(
max_num_blocks_per_seq
>=
2048
/
64
)
{
num_splits
=
4
;
}
else
{
num_splits
=
1
;
}
if
(
params
.
b
>=
128
)
{
num_splits
=
1
;
}
if
(
num_splits
<=
1
)
{
use_split_kv
=
false
;
}
else
{
num_splits
=
std
::
min
(
num_splits
,
240
);
params
.
partition_block_nums
=
max_num_blocks_per_seq
/
num_splits
;
}
if
(
params
.
partition_block_nums
<=
4
)
{
use_split_kv
=
false
;
}
params
.
use_split_kv
=
use_split_kv
;
params
.
total_num_splits
=
params
.
b
*
num_splits
;
at
::
Tensor
lse_accum
=
torch
::
empty
({
params
.
total_num_splits
,
num_heads
,
q_seq_per_hk
},
opts
.
dtype
(
at
::
kFloat
));
at
::
Tensor
out_accum
=
torch
::
empty
({
params
.
total_num_splits
,
num_heads
,
q_seq_per_hk
,
head_size_v
},
opts
.
dtype
(
at
::
kFloat
));
KU_CHECK_CONTIGUOUS
(
lse_accum
);
KU_CHECK_CONTIGUOUS
(
out_accum
);
params
.
softmax_lseaccum_ptr
=
lse_accum
.
data_ptr
<
float
>
();
params
.
oaccum_ptr
=
out_accum
.
data_ptr
<
float
>
();
}
if
(
print_param
)
{
fprintf
(
stderr
,
"[FlashMLA] [dense_attn_decode_interface] [%s] batch_size = %d seqlen_q_ori = %d "
"num_heads_q = %d head_size_k = %d max_num_blocks_per_seq = %d num_blocks %d page_block_size = %d num_heads_k = %d use_split_kv = %d num_splits %d params.partition_block_nums = %d
\n
"
,
arch
.
archName
.
c_str
(),
batch_size
,
seqlen_q_ori
,
num_heads_q
,
head_size_k
,
max_num_blocks_per_seq
,
num_blocks
,
page_block_size
,
num_heads_k
,
params
.
use_split_kv
,
params
.
total_num_splits
/
params
.
b
,
params
.
partition_block_nums
);
}
params
.
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
if
(
q_dtype
==
torch
::
kBFloat16
)
{
...
...
@@ -220,18 +264,24 @@ dense_attn_decode_interface(
params
.
num_sm_parts
,
nullptr
,
at
::
cuda
::
getCurrentCUDAStream
().
stream
()
at
::
cuda
::
getCurrentCUDAStream
().
stream
(),
params
.
use_split_kv
,
params
.
total_num_splits
/
params
.
b
,
params
.
seqlens_k_ptr
,
params
.
partition_block_nums
};
if
((
num_heads_q
<
64
&&
num_heads_k
==
1
)
||
num_heads_k
>
1
||
params
.
use_split_kv
)
{
if
(
q_dtype
==
torch
::
kBFloat16
)
{
gfx9
::
decode
::
run_flash_mla_combine_kernel
<
cutlass
::
bfloat16_t
>
(
combine_params
);
}
else
if
(
q_dtype
==
torch
::
kHalf
)
{
#ifndef FLASH_MLA_DISABLE_FP16
#ifndef FLASH_MLA_DISABLE_FP16
gfx9
::
decode
::
run_flash_mla_combine_kernel
<
cutlass
::
half_t
>
(
combine_params
);
#endif
#endif
}
else
{
TORCH_CHECK
(
false
,
"Unsupported tensor dtype for query"
);
}
}
out
=
out
.
view
({
batch_size
,
num_heads_k
,
seqlen_q_ori
,
num_q_heads_per_hk
,
head_size_v
}).
transpose
(
1
,
2
)
.
reshape
({
batch_size
,
seqlen_q_ori
,
num_heads_q
,
head_size_v
});
...
...
csrc/gfx9/decode/combine/combine.cu
View file @
aec17474
...
...
@@ -16,7 +16,7 @@ using namespace cute;
namespace
gfx9
::
decode
{
template
<
typename
ElementT
,
int
HEAD_DIM_V
,
int
BLOCK_SIZE_M
,
int
MAX_SPLITS
,
int
NUM_THREADS
>
template
<
typename
ElementT
,
int
HEAD_DIM_V
,
int
BLOCK_SIZE_M
,
int
MAX_SPLITS
,
int
NUM_THREADS
,
bool
USE_SPLIT_KV
=
false
>
__global__
void
__launch_bounds__
(
NUM_THREADS
,
1
)
flash_fwd_mla_combine_kernel
(
const
CombineParams
params
)
{
// grid_shape: [batch_size, s_q, h_q/BLOCK_SIZE_M]
...
...
@@ -33,12 +33,36 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
return
;
}
const
int
start_split_idx
=
__ldg
(
params
.
num_splits_ptr
+
batch_idx
);
const
int
end_split_idx
=
__ldg
(
params
.
num_splits_ptr
+
batch_idx
+
1
);
const
int
my_num_splits
=
end_split_idx
-
start_split_idx
;
int
start_split_idx
;
int
end_split_idx
;
int
my_num_splits
;
if
constexpr
(
USE_SPLIT_KV
)
{
start_split_idx
=
batch_idx
*
params
.
num_splits
;
end_split_idx
=
(
batch_idx
+
1
)
*
params
.
num_splits
;
int
seqlen_k
=
__ldg
(
params
.
seqlens_k_ptr
+
batch_idx
);
end_split_idx
=
std
::
min
(
cute
::
ceil_div
(
cute
::
ceil_div
(
seqlen_k
,
64
),
params
.
partition_block_nums
),
params
.
num_splits
)
+
start_split_idx
;
// if (lane_idx == 0 && batch_idx == 61)
// {
// printf(" batch_idx = %d start_split_idx = %d end_split_idx = %d seqlen_k = %d \n",batch_idx, start_split_idx, end_split_idx, seqlen_k);
// }
my_num_splits
=
end_split_idx
-
start_split_idx
;
if
(
my_num_splits
==
1
)
{
return
;
}
}
else
{
start_split_idx
=
__ldg
(
params
.
num_splits_ptr
+
batch_idx
);
end_split_idx
=
__ldg
(
params
.
num_splits_ptr
+
batch_idx
+
1
);
my_num_splits
=
end_split_idx
-
start_split_idx
;
if
(
my_num_splits
==
1
)
{
return
;
}
}
// FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS);
...
...
@@ -245,6 +269,9 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
} else if (NUM_SPLITS <= 160) { \
constexpr static int NAME = 160; \
return __VA_ARGS__(); \
} else if (NUM_SPLITS <= 240) { \
constexpr static int NAME = 240; \
return __VA_ARGS__(); \
} else { \
FLASH_ASSERT(false); \
} \
...
...
@@ -255,29 +282,33 @@ template<typename ElementT>
void
run_flash_mla_combine_kernel
(
CombineParams
&
params
)
{
static
constexpr
int
HEAD_DIM_V
=
512
;
// Since only this head dimension is supported by Flash MLA
FLASH_ASSERT
(
params
.
d_v
==
HEAD_DIM_V
);
if
(
params
.
use_split_kv
)
{
MLA_NUM_SPLITS_SWITCH
(
params
.
num_splits
,
NUM_SPLITS
,
[
&
]
{
constexpr
int
BLOCK_SIZE_M
=
4
;
constexpr
int
NUM_THREADS
=
BLOCK_SIZE_M
*
64
;
constexpr
size_t
smem_size
=
BLOCK_SIZE_M
*
(
NUM_SPLITS
+
1
)
*
sizeof
(
float
);
auto
combine_kernel
=
&
flash_fwd_mla_combine_kernel
<
ElementT
,
HEAD_DIM_V
,
BLOCK_SIZE_M
,
NUM_SPLITS
,
NUM_THREADS
,
true
>
;
combine_kernel
<<<
dim3
(
params
.
b
,
params
.
s_q
,
ku
::
ceil_div
(
params
.
h_q
,
BLOCK_SIZE_M
)),
NUM_THREADS
,
smem_size
,
params
.
stream
>>>
(
params
);
});
}
else
{
MLA_NUM_SPLITS_SWITCH
(
params
.
num_sm_parts
,
NUM_SPLITS
,
[
&
]
{
constexpr
int
BLOCK_SIZE_M
=
4
;
constexpr
int
NUM_THREADS
=
BLOCK_SIZE_M
*
64
;
constexpr
size_t
smem_size
=
BLOCK_SIZE_M
*
(
NUM_SPLITS
+
1
)
*
sizeof
(
float
);
auto
combine_kernel
=
&
flash_fwd_mla_combine_kernel
<
ElementT
,
HEAD_DIM_V
,
BLOCK_SIZE_M
,
NUM_SPLITS
,
NUM_THREADS
>
;
// CHECK_CUDA(cudaFuncSetAttribute(combine_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
// // Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch)
// cudaLaunchAttribute attribute[1];
// attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
// attribute[0].val.programmaticStreamSerializationAllowed = 1;
// cudaLaunchConfig_t combine_kernel_config = {
// dim3(params.b, params.s_q, ku::ceil_div(params.h_q, BLOCK_SIZE_M)),
// dim3(NUM_THREADS, 1, 1),
// 0,
// params.stream,
// attribute,
// 1
// };
combine_kernel
<<<
dim3
(
params
.
b
,
params
.
s_q
,
ku
::
ceil_div
(
params
.
h_q
,
BLOCK_SIZE_M
)),
NUM_THREADS
,
smem_size
,
params
.
stream
>>>
(
params
);
});
}
CHECK_CUDA_KERNEL_LAUNCH
();
}
...
...
csrc/gfx93/decode/dense/splitkv_mla.cuh
View file @
aec17474
...
...
@@ -10,6 +10,613 @@ using namespace cute;
namespace
gfx93
{
// template<typename T, bool use_split_kv=false>
// __device__ void
// compute_attn_1rowblock_splitkv_mla_block_m_64_gfx936(const DenseAttnDecodeParams& params,
// const int bidb, const int bidh, const int m_block,
// const int n_split_idx, const int seqlen_k,
// const int n_block_min, const int n_block_max, const bool NoSplit)
// {
// extern __shared__ char shared_memory[];
// const int tidx = threadIdx.x;
// constexpr int kBlockM = T::BLOCK_SIZE_M;
// constexpr int kBlockN = T::PAGE_BLOCK_SIZE;
// constexpr int kHeadDim = T::HEAD_DIM_K;
// constexpr int kHeadDimV = T::HEAD_DIM_V;
// using Element = T::InputT;
// using index_t = int64_t;
// const int warp_idx = __builtin_amdgcn_readfirstlane(tidx / 64);
// const int lane_idx = tidx % 64;
// Element* q_lds = (Element*)&(shared_memory);
// Element* k_lds = q_lds;
// Element* v_lds = q_lds;
// const index_t row_offset_q = bidb * params.q_batch_stride + m_block * kBlockM * params.q_row_stride + bidh * params.q_head_stride;
// Tensor gQ = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.q_ptr) + row_offset_q),
// Shape<Int<kBlockM>, Int<kHeadDim>>{},
// make_stride(params.q_row_stride, _1{}));
// // const index_t row_offset_k = (0) * params.k_head_stride;
// // Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
// // Shape<Int<kBlockN>, Int<kHeadDim>>{},
// // make_stride(params.k_row_stride, _1{}));
// typedef __bf16 __fp16x8_t __attribute__((ext_vector_type(8)));
// typedef __bf16 __fp16x4_t __attribute__((ext_vector_type(4)));
// typedef __bf16 __fp16x2_t __attribute__((ext_vector_type(2)));
// union Bf16_storage {
// __fp16x8_t data_128;
// __fp16x4_t data_64[2];
// __fp16x2_t data_32[4];
// uint16_t data_array[8];
// };
// union Bf16_storage_x4 {
// __fp16x4_t data_64;
// __fp16x2_t data_32[2];
// uint16_t data[4];
// };
// struct PtrWrapper {
// uint32_t former;
// uint32_t latter;
// };
// PtrWrapper glob_ptr_q;
// *(uint64_t*)&glob_ptr_q = reinterpret_cast<uint64_t>(gQ.data().get());
// glob_ptr_q.latter |= ((params.q_row_stride * 2) << 16);
// glob_ptr_q.latter |= 0x40000000;
// uint32x4_t global_addr_q = {0};
// global_addr_q[0] = (glob_ptr_q.former);
// global_addr_q[1] = (glob_ptr_q.latter);
// global_addr_q[2] = params.q_seq_per_hk - m_block * kBlockM;
// global_addr_q[3] = 0x00020000;
// int virtual_row_ = lane_idx / 8;//0
// int virtual_col_ = lane_idx % 8;//0
// int swizzle_col_ = virtual_row_ ^ virtual_col_;
// int row_ = lane_idx / 4;//0
// // 8->9 9->8
// // row_ = (row_ >= 8 ) ^ row_;
// int col_ = swizzle_col_ % 4;
// auto calc_row_and_col_k = [&]() -> std::tuple<int, int> {
// constexpr int elements_per_thread = 8;
// #if defined(__gfx938__)
// int row_offset = row_ + warp_idx * 16;
// int col_offset = col_ * 8;
// #else
// int row_offset = row_ * 4 + warp_idx;
// int col_offset = col_ * 8;
// #endif
// return {row_offset, col_offset};
// };
// auto buffer_load_lds_k = [&](int row_offset, int col, int k_idx, int block_idx, index_t offset_k) {
// constexpr int element_size = 2;
// PtrWrapper glob_ptr_k;
// *(uint64_t*)&glob_ptr_k = reinterpret_cast<uint64_t>(params.k_ptr) + offset_k * 2;
// glob_ptr_k.latter |= ((params.k_row_stride * 2) << 16);
// glob_ptr_k.latter |= 0x40000000;
// uint32x4_t global_addr_k = {0};
// global_addr_k[0] = __builtin_amdgcn_readfirstlane(glob_ptr_k.former);
// global_addr_k[1] = __builtin_amdgcn_readfirstlane(glob_ptr_k.latter);
// global_addr_k[2] = seqlen_k - block_idx * kBlockN;
// global_addr_k[3] = 0x00020000;
// constexpr int elements_per_thread = 8;
// int col_offset = col;
// int offset_v = col_offset * 2;
// int ldsAddrPerWave = reinterpret_cast<size_t>(k_lds) + warp_idx * 16 * 32 * 2 + (k_idx % 4) * 64 * 32 * 2;
// typedef uint32_t uint32x2_t __attribute__((ext_vector_type(2)));
// uint32x2_t index_offset = {0};
// index_offset[0] = row_offset;
// index_offset[1] = offset_v;
// const int offset_s = k_idx * 32 * 2;
// __builtin_amdgcn_sched_barrier(0);
// asm volatile(
// "s_mov_b32 m0, %1 \n\t"
// "s_nop 0 \n\t"
// "buffer_load_dwordx4 %0, %2, %3 , idxen offen offset:0, lds \n" ::"v"(index_offset),
// "s"(ldsAddrPerWave), "s"(global_addr_k), "s"(offset_s)
// :);
// __builtin_amdgcn_sched_barrier(0);
// };
// auto k_lds_read_offset = [&] () -> int {
// int row = lane_idx % 16;
// int col = lane_idx / 16;
// col = (row / 2) ^ col;
// col = col % 4;
// // row = (row >= 8) ^ row;
// const auto lds_offset = row * 32 + col * 8;
// // #endif
// return lds_offset;
// };
// auto calc_row_and_col_v = [&](int i) -> int {
// int row = lane_idx / 4;
// // int col = lane_idx % 4;
// int row_offset = row + i * 16;
// // int col_offset = col * 8 + warp_idx * 32;
// return row_offset;
// };
// const int v_lds_read_ptr = reinterpret_cast<size_t>(v_lds + lane_idx * 8);
// Element* k_lds_read_ptr = (k_lds + k_lds_read_offset());
// int col_offset_v = (lane_idx % 4) * 8 + warp_idx * 32;
// auto buffer_load_lds_v = [&](int row_offset, int col, int k_idx, int n_idx, int block_idx, index_t offset_k) {
// constexpr int element_size = 2;
// PtrWrapper glob_ptr_k;
// *(uint64_t*)&glob_ptr_k = reinterpret_cast<uint64_t>(params.k_ptr) + offset_k * 2;
// glob_ptr_k.latter |= ((params.k_row_stride * 2) << 16);
// glob_ptr_k.latter |= 0x40000000;
// uint32x4_t global_addr_k = {0};
// global_addr_k[0] = __builtin_amdgcn_readfirstlane(glob_ptr_k.former);
// global_addr_k[1] = __builtin_amdgcn_readfirstlane(glob_ptr_k.latter);
// global_addr_k[2] = seqlen_k - block_idx * kBlockN;
// global_addr_k[3] = 0x00020000;
// constexpr int elements_per_thread = 8;
// int col_offset = col;
// // int v_idx = row_offset;
// int offset_v = col_offset * 2;
// int ldsAddrPerWave = reinterpret_cast<size_t>(v_lds) + warp_idx * 16 * 32 * 2 + (k_idx) * 128 * 16 * 2;
// typedef uint32_t uint32x2_t __attribute__((ext_vector_type(2)));
// uint32x2_t index_offset = {0};
// index_offset[0] = row_offset;
// index_offset[1] = offset_v;
// const int offset_s = n_idx * 128 * 2;
// __builtin_amdgcn_sched_barrier(0);
// asm volatile(
// "s_mov_b32 m0, %1 \n\t"
// "s_nop 0 \n\t"
// "buffer_load_dwordx4 %0, %2, %3 , idxen offen offset:0, lds \n" ::"v"(index_offset),
// "s"(ldsAddrPerWave), "s"(global_addr_k), "s"(offset_s)
// :);
// __builtin_amdgcn_sched_barrier(0);
// };
// Bf16_storage q_reg[18];
// for (int i = 0; i < 18; i++)
// {
// constexpr int elements_per_thread = 8;
// int row = lane_idx % 16;
// int col = lane_idx / 16;
// int row_offset = row + warp_idx * 16;
// int col_offset = col * 8;
// int offset_v = col_offset * 2 + i * 32 * 2;
// q_reg[i].data_128 = __builtin_amdgcn_buffer_load_dwordx4(global_addr_q, row_offset, offset_v, false, false);
// }
// __syncthreads();
// v4f acco_f32[32];
// for (int i = 0; i < 32; i++)
// {
// acco_f32[i].x = 0.0f;
// acco_f32[i].y = 0.0f;
// acco_f32[i].z = 0.0f;
// acco_f32[i].w = 0.0f;
// }
// const int *block_table = params.block_table + bidb * params.block_table_batch_stride;
// auto float2bf16 = [] (float s) -> uint16_t {
// uint32_t x32 = reinterpret_cast<uint32_t const &>(s);
// #ifndef FLASH_MLA_BF16_TYPE
// #define FLASH_MLA_BF16_TYPE 0
// #endif
// #if FLASH_MLA_BF16_TYPE == 1
// x32 += 0x8000u;
// #endif
// return uint16_t(x32 >> 16);
// };
// // int block_idx = 0;
// // int cur_block_table = block_table[block_idx];
// // index_t offset_k = (index_t)(cur_block_table) * params.k_batch_stride;
// // auto [row_offset, col] = calc_row_and_col_k(block_idx);
// // buffer_load_lds_k(row_offset, col, 0, offset_k);
// // __syncthreads();
// {
// // if (thread0())
// // {
// // int k = 0;
// // for (int i = 0; i < 64; i++)
// // {
// // for (int j = 0; j < 32; j++)
// // {
// // printf(" %.3f ", float(k_lds[k]));
// // k++;
// // }
// // printf("\n");
// // }
// // }
// // if (block0() && threadIdx.x < 64)
// // {
// // cutlass::bfloat16_t q[8];
// // for (int i = 0; i < 8; i++)
// // {
// // q[i].storage = v_reg[0].data_array[i];
// // // q[i].storage = q_reg[0].data_array[i];
// // }
// // printf("tidx %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %d\n ", threadIdx.x,
// // float(q[0]),
// // float(q[1]),
// // float(q[2]),
// // float(q[3]),
// // float(q[4]),
// // float(q[5]),
// // float(q[6]),
// // float(q[7]),
// // v_lds_read_ptr
// // );
// // }
// }
// struct IsMaskBlock {};
// struct IsFirstMaskBlock {};
// struct IsNoMaskBlock {};
// flash::Softmax<1> softmax;
// auto process_one_block = [&] (int block_idx, auto is_mask_block_t) {
// static constexpr bool IS_MASK_BLOCK = std::is_same_v<decltype(is_mask_block_t), IsNoMaskBlock>;
// static constexpr bool IS_FIRST_MASK_BLOCK = std::is_same_v<decltype(is_mask_block_t), IsFirstMaskBlock>;
// static constexpr bool IS_NO_MASK_BLOCK = std::is_same_v<decltype(is_mask_block_t), IsNoMaskBlock>;
// int cur_block_table = block_table[block_idx];
// v4f accs_f32[4];
// for (int i = 0; i < 4; i++)
// {
// accs_f32[i].x = 0.0f;
// accs_f32[i].y = 0.0f;
// accs_f32[i].z = 0.0f;
// accs_f32[i].w = 0.0f;
// }
// index_t offset_k = (index_t)(cur_block_table) * params.k_batch_stride;
// auto [row_offset, col] = calc_row_and_col_k();
// #define LOAD_K_AND_QK_GEMM(k) \
// { \
// constexpr int k_val = (k); \
// buffer_load_lds_k(row_offset, col, k_val - 3, block_idx, offset_k); \
// flash::qk_gemm<Element, k_val>(q_reg[k_val].data_128, k_lds_read_ptr, accs_f32); \
// __builtin_amdgcn_sched_barrier(0); \
// asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
// __builtin_amdgcn_sched_barrier(0); \
// }
// {
// constexpr int k_val = (17);
// buffer_load_lds_k(row_offset, col, k_val, block_idx, offset_k);
// buffer_load_lds_k(row_offset, col, k_val - 1, block_idx, offset_k);
// buffer_load_lds_k(row_offset, col, k_val - 2, block_idx, offset_k);
// buffer_load_lds_k(row_offset, col, k_val - 3, block_idx, offset_k);
// __builtin_amdgcn_sched_barrier(0);
// asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
// __builtin_amdgcn_sched_barrier(0);
// flash::qk_gemm<Element, k_val>(q_reg[k_val].data_128, k_lds_read_ptr, accs_f32);
// __builtin_amdgcn_sched_barrier(0);
// asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
// __builtin_amdgcn_sched_barrier(0);
// LOAD_K_AND_QK_GEMM(16);
// LOAD_K_AND_QK_GEMM(15);
// LOAD_K_AND_QK_GEMM(14);
// LOAD_K_AND_QK_GEMM(13);
// LOAD_K_AND_QK_GEMM(12);
// LOAD_K_AND_QK_GEMM(11);
// LOAD_K_AND_QK_GEMM(10);
// LOAD_K_AND_QK_GEMM(9);
// LOAD_K_AND_QK_GEMM(8);
// LOAD_K_AND_QK_GEMM(7);
// LOAD_K_AND_QK_GEMM(6);
// LOAD_K_AND_QK_GEMM(5);
// LOAD_K_AND_QK_GEMM(4);
// LOAD_K_AND_QK_GEMM(3);
// flash::qk_gemm<Element, k_val - 15>(q_reg[k_val - 15].data_128, k_lds_read_ptr, accs_f32);
// __builtin_amdgcn_sched_barrier(0);
// asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
// __builtin_amdgcn_sched_barrier(0);
// flash::qk_gemm<Element, k_val - 16>(q_reg[k_val - 16].data_128, k_lds_read_ptr, accs_f32);
// __builtin_amdgcn_sched_barrier(0);
// asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
// __builtin_amdgcn_sched_barrier(0);
// flash::qk_gemm<Element, k_val - 17>(q_reg[k_val - 17].data_128, k_lds_read_ptr, accs_f32);
// __builtin_amdgcn_sched_barrier(0);
// asm volatile("s_barrier\n\t");
// __builtin_amdgcn_sched_barrier(0);
// }
// // if (block0() && tidx < 64)
// // {
// // printf(" %.3f %.3f \n", accs_f32[0][0], accs_f32[0][1]);
// // }
// if constexpr (!IS_NO_MASK_BLOCK) {
// for (int i = 0; i < 16; ++i) {
// int idx = i;
// if constexpr (!T::Is_causal) {
// if ((lane_idx / 16) * 4 + (idx % 4) + (idx / 4) * 16 >= int(seqlen_k - block_idx * kBlockN))
// {
// #if defined(__gfx938__)
// accs_f32[i/4][i%4] = -INFINITY;
// #else
// accs_f32[i%4][i/4] = -INFINITY;
// #endif
// }
// } else {
// // Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
// // col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
// int row = (lane_idx % 16) + warp_idx * 16;;
// int col_limit_right = seqlen_k - 1 - block_idx * kBlockN - (params.q_seq_per_hk - 1 - (m_block * kBlockM + row)) / params.q_head_per_hk;
// if ((lane_idx / 16) * 4 + (idx % 4) + (idx / 4) * 16 > col_limit_right) {
// #if defined(__gfx938__)
// accs_f32[i/4][i%4] = -INFINITY;
// #else
// accs_f32[i%4][i/4] = -INFINITY;
// #endif
// }
// }
// }
// }
// Tensor scores = make_tensor<float>(Shape<_1, _16>{});
// for (int i = 0; i < 16; i++) {
// #if defined(__gfx938__)
// scores(0, i) = accs_f32[i/4][i%4];
// #else
// scores(0, i) = accs_f32[i%4][i/4];
// #endif
// }
// softmax.template softmax_rescale_o_prefill_4x1</*Is_first=*/IS_FIRST_MASK_BLOCK, /*Check_inf=*//*Is_local=*/T::Is_causal>(scores, acco_f32, params.scale_softmax_log2);
// Bf16_storage_x4 p[4];
// for (int i = 0; i < 4; i++)
// {
// #if defined(__gfx938__)
// p[i].data_32[0] = __builtin_hcu_cvt_pk_bf16_f32(0, scores(0, i * 4), 0, scores(0, i * 4 + 1), 0);
// p[i].data_32[1] = __builtin_hcu_cvt_pk_bf16_f32(0, scores(0, i * 4 + 2), 0, scores(0, i * 4 + 3), 0);
// #else
// p[i].data[0] = float2bf16(scores(0, i * 4));
// p[i].data[1] = float2bf16(scores(0, i * 4 + 1));
// p[i].data[2] = float2bf16(scores(0, i * 4 + 2));
// p[i].data[3] = float2bf16(scores(0, i * 4 + 3));
// #endif
// }
// int row_offset_v[4];
// for (int i = 0; i < 4; i++)
// {
// row_offset_v[i] = calc_row_and_col_v(i);
// }
// // __syncthreads();
// // #if 1
// {
// constexpr int k_val = (0);
// buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 0, block_idx, offset_k);
// buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 0, block_idx, offset_k);
// buffer_load_lds_v(row_offset_v[k_val + 2], col_offset_v, k_val + 2, 0, block_idx, offset_k);
// buffer_load_lds_v(row_offset_v[k_val + 3], col_offset_v, k_val + 3, 0, block_idx, offset_k);
// __builtin_amdgcn_sched_barrier(0);
// asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t");
// __builtin_amdgcn_sched_barrier(0);
// flash::pv_gemm<k_val, 0>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<k_val, 1>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<k_val, 2>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<k_val, 3>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32);
// __builtin_amdgcn_sched_barrier(0);
// asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
// __builtin_amdgcn_sched_barrier(0);
// buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 1, block_idx, offset_k);
// flash::pv_gemm<k_val + 1, 0>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<k_val + 1, 1>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<k_val + 1, 2>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<k_val + 1, 3>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32);
// __builtin_amdgcn_sched_barrier(0);
// asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
// __builtin_amdgcn_sched_barrier(0);
// buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 1, block_idx, offset_k);
// flash::pv_gemm<k_val + 2, 0>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<k_val + 2, 1>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<k_val + 2, 2>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<k_val + 2, 3>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32);
// __builtin_amdgcn_sched_barrier(0);
// asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
// __builtin_amdgcn_sched_barrier(0);
// buffer_load_lds_v(row_offset_v[k_val + 2], col_offset_v, k_val + 2, 1, block_idx, offset_k);
// flash::pv_gemm<k_val + 3, 0>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<k_val + 3, 1>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<k_val + 3, 2>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<k_val + 3, 3>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32);
// __builtin_amdgcn_sched_barrier(0);
// asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
// __builtin_amdgcn_sched_barrier(0);
// buffer_load_lds_v(row_offset_v[k_val + 3], col_offset_v, k_val + 3, 1, block_idx, offset_k);
// }
// #define LOAD_V_AND_PV_GEMM(n) \
// { \
// constexpr int k_val = (0); \
// constexpr int n_val = (n); \
// flash::pv_gemm<k_val, n_val * 4>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
// flash::pv_gemm<k_val, n_val * 4 + 1>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
// flash::pv_gemm<k_val, n_val * 4 + 2>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
// flash::pv_gemm<k_val, n_val * 4 + 3>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
// __builtin_amdgcn_sched_barrier(0); \
// asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
// __builtin_amdgcn_sched_barrier(0); \
// buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, n_val + 1, block_idx, offset_k); \
// flash::pv_gemm<k_val + 1, n_val * 4>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); \
// flash::pv_gemm<k_val + 1, n_val * 4 + 1>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); \
// flash::pv_gemm<k_val + 1, n_val * 4 + 2>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); \
// flash::pv_gemm<k_val + 1, n_val * 4 + 3>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); \
// __builtin_amdgcn_sched_barrier(0); \
// asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
// __builtin_amdgcn_sched_barrier(0); \
// buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, n_val + 1, block_idx, offset_k); \
// flash::pv_gemm<k_val + 2, n_val * 4>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); \
// flash::pv_gemm<k_val + 2, n_val * 4 + 1>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); \
// flash::pv_gemm<k_val + 2, n_val * 4 + 2>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); \
// flash::pv_gemm<k_val + 2, n_val * 4 + 3>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); \
// __builtin_amdgcn_sched_barrier(0); \
// asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
// __builtin_amdgcn_sched_barrier(0); \
// buffer_load_lds_v(row_offset_v[k_val + 2], col_offset_v, k_val + 2, n_val + 1, block_idx, offset_k); \
// flash::pv_gemm<k_val + 3, n_val * 4>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); \
// flash::pv_gemm<k_val + 3, n_val * 4 + 1>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); \
// flash::pv_gemm<k_val + 3, n_val * 4 + 2>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); \
// flash::pv_gemm<k_val + 3, n_val * 4 + 3>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); \
// __builtin_amdgcn_sched_barrier(0); \
// asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
// __builtin_amdgcn_sched_barrier(0); \
// buffer_load_lds_v(row_offset_v[k_val + 3], col_offset_v, k_val + 3, n_val + 1, block_idx, offset_k); \
// }
// LOAD_V_AND_PV_GEMM(1);
// LOAD_V_AND_PV_GEMM(2);
// {
// constexpr int n_val = (3);
// flash::pv_gemm<0, 12>(p[0].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<0, 13>(p[0].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<0, 14>(p[0].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<0, 15>(p[0].data_64, v_lds_read_ptr, acco_f32);
// __builtin_amdgcn_sched_barrier(0);
// asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t");
// __builtin_amdgcn_sched_barrier(0);
// flash::pv_gemm<1, 12>(p[1].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<1, 13>(p[1].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<1, 14>(p[1].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<1, 15>(p[1].data_64, v_lds_read_ptr, acco_f32);
// __builtin_amdgcn_sched_barrier(0);
// asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t");
// __builtin_amdgcn_sched_barrier(0);
// flash::pv_gemm<2, 12>(p[2].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<2, 13>(p[2].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<2, 14>(p[2].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<2, 15>(p[2].data_64, v_lds_read_ptr, acco_f32);
// __builtin_amdgcn_sched_barrier(0);
// asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
// __builtin_amdgcn_sched_barrier(0);
// flash::pv_gemm<3, 12>(p[3].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<3, 13>(p[3].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<3, 14>(p[3].data_64, v_lds_read_ptr, acco_f32);
// flash::pv_gemm<3, 15>(p[3].data_64, v_lds_read_ptr, acco_f32);
// __builtin_amdgcn_sched_barrier(0);
// asm volatile("s_barrier\n\t");
// __builtin_amdgcn_sched_barrier(0);
// }
// };
// constexpr int n_masking_steps = !T::Is_causal ? 1 : cute::ceil_div(kBlockM, kBlockN) + 1;
// int n_block = n_block_max - 1;
// if constexpr (n_masking_steps == 1) {
// if (n_block >= n_block_min) {
// process_one_block(n_block, IsFirstMaskBlock{});
// }
// n_block--;
// } else {
// int masking_step = 1;
// if (n_block >= n_block_min) {
// process_one_block(n_block, IsFirstMaskBlock{});
// }
// n_block--;
// for (; n_block >= n_block_min && masking_step < n_masking_steps; ++masking_step, --n_block) {
// process_one_block(n_block, IsMaskBlock{});
// }
// }
// for(; n_block >= n_block_min; --n_block) {
// process_one_block(n_block, IsNoMaskBlock{});
// }
// using ElementAccum = float;
// if constexpr (true)
// {
// using ElementO = Element;
// const index_t row_offset_o = bidb * params.o_batch_stride + m_block * kBlockM * params.o_row_stride + bidh * params.o_head_stride;
// const index_t row_offset_lse = (bidb * params.h_k + bidh) * params.q_seq_per_hk + m_block * kBlockM;
// constexpr bool Split = false;
// Tensor gOaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementO *>(Split ? params.oaccum_ptr : params.o_ptr) + ( row_offset_o)),
// Shape<Int<kBlockM>, Int<kHeadDimV>>{},
// make_stride(Split ? kHeadDimV : params.o_row_stride, _1{}));
// Tensor lse = softmax.template normalize_softmax_lse_prefill_4x1</*Is_dropout=*/false, Split>(acco_f32, params.scale_softmax);
// // if (block0() && tidx < 64)
// // {
// // printf(" %.3f %.3f \n", float(acc_o(0)), float(acc_o(1)));
// // }
// Tensor gLSEaccum = make_tensor(make_gmem_ptr(reinterpret_cast<ElementAccum *>(Split ? params.softmax_lseaccum_ptr : params.softmax_lse_ptr) + (row_offset_lse)),
// Shape<Int<kBlockM>>{}, Stride<_1>{});
// {
// // using result_type = cutlass::Array<bfloat16_t, 2>;
// // int tidx = threadIdx.x;
// int row, col;
// // int warpid = tidx / 64;
// for (int mi = 0; mi < 1; mi++) {
// row = mi * kBlockM + lane_idx % 16 + warp_idx * 16;
// if (row < params.q_seq_per_hk - m_block * kBlockM) {
// for (int ni = 0; ni < 16; ++ni) {
// #if defined(__gfx938__)
// Bf16_storage res;
// col = (lane_idx / 16) * 8 + ni * 32 ;
// res.data_32[0] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][0], 0, acco_f32[ni * 2 + 1][0], 0);
// res.data_32[1] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][1], 0, acco_f32[ni * 2 + 1][1], 0);
// res.data_32[2] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][2], 0, acco_f32[ni * 2 + 1][2], 0);
// res.data_32[3] = __builtin_hcu_cvt_pk_bf16_f32(0, acco_f32[ni * 2][3], 0, acco_f32[ni * 2 + 1][3], 0);
// *(__fp16x8_t*)(&gOaccum(row, col)) = res.data_128;
// #else
// col = (lane_idx / 16) * 2 + ni * 32 ;
// using result_type = cutlass::Array<Element, 2>;
// for (int ei = 0; ei < 4; ei++)
// {
// result_type res;
// Element e0, e1;
// e0.storage = float2bf16(acco_f32[ni * 2][ei]);
// e1.storage = float2bf16(acco_f32[ni * 2 + 1][ei]);
// res[0] = e0;
// res[1] = e1;
// // gO(row, col) = res[0];
// // gO(row, col + 1) = res[1];
// *(result_type*)(&gOaccum(row, col)) = res;
// col += 8;
// }
// #endif
// }
// // for (int n = 0; n < 1; n++) {
// // col = (tidx % 64 / 16) + warpid * 32 + n * 128;
// // for (int ei = 0; ei < 8; ei ++) {
// // gOaccum(row, col) = rO(ei, m, n);
// // col += 4;
// // }
// // }
// gLSEaccum(row) = lse(mi);
// }
// }
// }
// }
// }
template
<
typename
T
>
__device__
void
compute_attn_1rowblock_splitkv_mla_gfx936
(
const
DenseAttnDecodeParams
&
params
,
...
...
@@ -599,34 +1206,742 @@ flash_fwd_splitkv_mla_kernel(const DenseAttnDecodeParams params) {
}
}
template
<
typename
T
,
bool
use_split_kv
=
false
>
__global__
void
__launch_bounds__
(
T
::
NUM_THREADS
,
1
)
flash_fwd_splitkv_mla_block_m_64_kernel
(
const
DenseAttnDecodeParams
params
)
{
constexpr
int
kBlockN
=
T
::
PAGE_BLOCK_SIZE
;
const
int
m_block
=
blockIdx
.
x
;
const
int
bidh
=
blockIdx
.
y
;
int
bidb
;
int
seqlen_k
;
int
n_block_min
;
int
n_block_max
;
const
int
tidx
=
threadIdx
.
x
;
const
int
lane_idx
=
tidx
%
64
;
bool
is_split
=
use_split_kv
;
if
constexpr
(
use_split_kv
)
{
int
num_splits
=
params
.
total_num_splits
/
params
.
b
;
bidb
=
blockIdx
.
z
%
params
.
b
;
// bidb = blockIdx.z / num_splits;
seqlen_k
=
__ldg
(
params
.
seqlens_k_ptr
+
bidb
);
int
split_id
=
blockIdx
.
z
/
params
.
b
;
n_block_min
=
split_id
*
params
.
partition_block_nums
;
n_block_max
=
split_id
==
(
num_splits
-
1
)
?
cute
::
ceil_div
(
seqlen_k
,
kBlockN
)
:
std
::
min
((
split_id
+
1
)
*
params
.
partition_block_nums
,
cute
::
ceil_div
(
seqlen_k
,
kBlockN
));
if
(
split_id
==
0
&&
n_block_max
==
cute
::
ceil_div
(
seqlen_k
,
kBlockN
))
{
is_split
=
false
;
}
// if (tidx == 0 && bidb == 61)
// {
// printf("bidb = %d split_id = %d n_block_min = %d n_block_max = %d num_splits = %d params.partition_block_nums %d is_split = %d \n", bidb, split_id, n_block_min, n_block_max, num_splits, params.partition_block_nums, is_split);
// }
if
(
n_block_max
<=
n_block_min
)
return
;
template
<
typename
InputT
>
void
run_flash_splitkv_mla_kernel
(
DenseAttnDecodeParams
&
params
)
{
FLASH_ASSERT
(
params
.
d
==
Config
::
HEAD_DIM_K
);
FLASH_ASSERT
(
params
.
d_v
==
Config
::
HEAD_DIM_V
);
}
else
{
bidb
=
blockIdx
.
z
;
seqlen_k
=
__ldg
(
params
.
seqlens_k_ptr
+
bidb
);
n_block_min
=
0
;
n_block_max
=
cute
::
ceil_div
(
seqlen_k
,
kBlockN
);
}
constexpr
size_t
smem_size
=
65536
;
extern
__shared__
char
shared_memory
[];
constexpr
int
kBlockM
=
T
::
BLOCK_SIZE_M
;
// constexpr int kBlockN = T::PAGE_BLOCK_SIZE;
constexpr
int
kHeadDim
=
T
::
HEAD_DIM_K
;
constexpr
int
kHeadDimV
=
T
::
HEAD_DIM_V
;
using
Element
=
T
::
InputT
;
using
index_t
=
int64_t
;
const
int
warp_idx
=
__builtin_amdgcn_readfirstlane
(
tidx
/
64
);
Element
*
q_lds
=
(
Element
*
)
&
(
shared_memory
);
Element
*
k_lds
=
q_lds
;
Element
*
v_lds
=
q_lds
;
// Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch)
const
index_t
row_offset_q
=
bidb
*
params
.
q_batch_stride
+
m_block
*
kBlockM
*
params
.
q_row_stride
+
bidh
*
params
.
q_head_stride
;
Tensor
gQ
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
q_ptr
)
+
row_offset_q
),
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDim
>>
{},
make_stride
(
params
.
q_row_stride
,
_1
{}));
// const index_t row_offset_k = (0) * params.k_head_stride;
// Tensor gK = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.k_ptr) + row_offset_k),
// Shape<Int<kBlockN>, Int<kHeadDim>>{},
// make_stride(params.k_row_stride, _1{}));
typedef
__bf16
__fp16x8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
__bf16
__fp16x4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
__bf16
__fp16x2_t
__attribute__
((
ext_vector_type
(
2
)));
union
Bf16_storage
{
__fp16x8_t
data_128
;
__fp16x4_t
data_64
[
2
];
__fp16x2_t
data_32
[
4
];
uint16_t
data_array
[
8
];
};
union
Bf16_storage_x4
{
__fp16x4_t
data_64
;
__fp16x2_t
data_32
[
2
];
uint16_t
data
[
4
];
};
struct
PtrWrapper
{
uint32_t
former
;
uint32_t
latter
;
};
PtrWrapper
glob_ptr_q
;
*
(
uint64_t
*
)
&
glob_ptr_q
=
reinterpret_cast
<
uint64_t
>
(
gQ
.
data
().
get
());
glob_ptr_q
.
latter
|=
((
params
.
q_row_stride
*
2
)
<<
16
);
glob_ptr_q
.
latter
|=
0x40000000
;
uint32x4_t
global_addr_q
=
{
0
};
global_addr_q
[
0
]
=
(
glob_ptr_q
.
former
);
global_addr_q
[
1
]
=
(
glob_ptr_q
.
latter
);
global_addr_q
[
2
]
=
params
.
q_seq_per_hk
-
m_block
*
kBlockM
;
global_addr_q
[
3
]
=
0x00020000
;
int
virtual_row_
=
lane_idx
/
8
;
//0
int
virtual_col_
=
lane_idx
%
8
;
//0
int
swizzle_col_
=
virtual_row_
^
virtual_col_
;
int
row_
=
lane_idx
/
4
;
//0
// 8->9 9->8
// row_ = (row_ >= 8 ) ^ row_;
int
col_
=
swizzle_col_
%
4
;
auto
calc_row_and_col_k
=
[
&
]()
->
std
::
tuple
<
int
,
int
>
{
constexpr
int
elements_per_thread
=
8
;
#if defined(__gfx938__)
int
row_offset
=
row_
+
warp_idx
*
16
;
int
col_offset
=
col_
*
8
;
#else
int
row_offset
=
row_
*
4
+
warp_idx
;
int
col_offset
=
col_
*
8
;
#endif
return
{
row_offset
,
col_offset
};
};
auto
buffer_load_lds_k
=
[
&
](
int
row_offset
,
int
col
,
int
k_idx
,
int
block_idx
,
index_t
offset_k
)
{
constexpr
int
element_size
=
2
;
PtrWrapper
glob_ptr_k
;
*
(
uint64_t
*
)
&
glob_ptr_k
=
reinterpret_cast
<
uint64_t
>
(
params
.
k_ptr
)
+
offset_k
*
2
;
glob_ptr_k
.
latter
|=
((
params
.
k_row_stride
*
2
)
<<
16
);
glob_ptr_k
.
latter
|=
0x40000000
;
uint32x4_t
global_addr_k
=
{
0
};
global_addr_k
[
0
]
=
__builtin_amdgcn_readfirstlane
(
glob_ptr_k
.
former
);
global_addr_k
[
1
]
=
__builtin_amdgcn_readfirstlane
(
glob_ptr_k
.
latter
);
global_addr_k
[
2
]
=
seqlen_k
-
block_idx
*
kBlockN
;
global_addr_k
[
3
]
=
0x00020000
;
constexpr
int
elements_per_thread
=
8
;
int
col_offset
=
col
;
int
offset_v
=
col_offset
*
2
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
k_lds
)
+
warp_idx
*
16
*
32
*
2
+
(
k_idx
%
5
)
*
64
*
32
*
2
;
typedef
uint32_t
uint32x2_t
__attribute__
((
ext_vector_type
(
2
)));
uint32x2_t
index_offset
=
{
0
};
index_offset
[
0
]
=
row_offset
;
index_offset
[
1
]
=
offset_v
;
const
int
offset_s
=
k_idx
*
32
*
2
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"s_nop 0
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 , idxen offen offset:0, lds
\n
"
::
"v"
(
index_offset
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr_k
),
"s"
(
offset_s
)
:
);
__builtin_amdgcn_sched_barrier
(
0
);
};
auto
k_lds_read_offset
=
[
&
]
()
->
int
{
int
row
=
lane_idx
%
16
;
int
col
=
lane_idx
/
16
;
col
=
(
row
/
2
)
^
col
;
col
=
col
%
4
;
// row = (row >= 8) ^ row;
const
auto
lds_offset
=
row
*
32
+
col
*
8
;
// #endif
return
lds_offset
;
};
auto
calc_row_and_col_v
=
[
&
](
int
i
)
->
int
{
int
row
=
lane_idx
/
4
;
// int col = lane_idx % 4;
int
row_offset
=
row
+
i
*
16
;
// int col_offset = col * 8 + warp_idx * 32;
return
row_offset
;
};
const
int
v_lds_read_ptr
=
reinterpret_cast
<
size_t
>
(
v_lds
+
lane_idx
*
8
);
Element
*
k_lds_read_ptr
=
(
k_lds
+
k_lds_read_offset
());
int
col_offset_v
=
(
lane_idx
%
4
)
*
8
+
warp_idx
*
32
;
auto
buffer_load_lds_v
=
[
&
](
int
row_offset
,
int
col
,
int
k_idx
,
int
n_idx
,
int
block_idx
,
index_t
offset_k
)
{
constexpr
int
element_size
=
2
;
PtrWrapper
glob_ptr_k
;
*
(
uint64_t
*
)
&
glob_ptr_k
=
reinterpret_cast
<
uint64_t
>
(
params
.
k_ptr
)
+
offset_k
*
2
;
glob_ptr_k
.
latter
|=
((
params
.
k_row_stride
*
2
)
<<
16
);
glob_ptr_k
.
latter
|=
0x40000000
;
uint32x4_t
global_addr_k
=
{
0
};
global_addr_k
[
0
]
=
__builtin_amdgcn_readfirstlane
(
glob_ptr_k
.
former
);
global_addr_k
[
1
]
=
__builtin_amdgcn_readfirstlane
(
glob_ptr_k
.
latter
);
global_addr_k
[
2
]
=
seqlen_k
-
block_idx
*
kBlockN
;
global_addr_k
[
3
]
=
0x00020000
;
constexpr
int
elements_per_thread
=
8
;
int
col_offset
=
col
;
// int v_idx = row_offset;
int
offset_v
=
col_offset
*
2
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
v_lds
)
+
warp_idx
*
16
*
32
*
2
+
(
k_idx
)
*
128
*
16
*
2
;
typedef
uint32_t
uint32x2_t
__attribute__
((
ext_vector_type
(
2
)));
uint32x2_t
index_offset
=
{
0
};
index_offset
[
0
]
=
row_offset
;
index_offset
[
1
]
=
offset_v
;
const
int
offset_s
=
n_idx
*
128
*
2
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"s_nop 0
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 , idxen offen offset:0, lds
\n
"
::
"v"
(
index_offset
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr_k
),
"s"
(
offset_s
)
:
);
__builtin_amdgcn_sched_barrier
(
0
);
};
Bf16_storage
q_reg
[
18
];
for
(
int
i
=
0
;
i
<
18
;
i
++
)
{
constexpr
int
elements_per_thread
=
8
;
int
row
=
lane_idx
%
16
;
int
col
=
lane_idx
/
16
;
int
row_offset
=
row
+
warp_idx
*
16
;
int
col_offset
=
col
*
8
;
int
offset_v
=
col_offset
*
2
+
i
*
32
*
2
;
q_reg
[
i
].
data_128
=
__builtin_amdgcn_buffer_load_dwordx4
(
global_addr_q
,
row_offset
,
offset_v
,
false
,
false
);
}
__syncthreads
();
v4f
acco_f32
[
32
];
for
(
int
i
=
0
;
i
<
32
;
i
++
)
{
acco_f32
[
i
].
x
=
0.0
f
;
acco_f32
[
i
].
y
=
0.0
f
;
acco_f32
[
i
].
z
=
0.0
f
;
acco_f32
[
i
].
w
=
0.0
f
;
}
const
int
*
block_table
=
params
.
block_table
+
bidb
*
params
.
block_table_batch_stride
;
auto
float2bf16
=
[]
(
float
s
)
->
uint16_t
{
uint32_t
x32
=
reinterpret_cast
<
uint32_t
const
&>
(
s
);
#ifndef FLASH_MLA_BF16_TYPE
#define FLASH_MLA_BF16_TYPE 0
#endif
#if FLASH_MLA_BF16_TYPE == 1
x32
+=
0x8000u
;
#endif
return
uint16_t
(
x32
>>
16
);
};
// int block_idx = 0;
// int cur_block_table = block_table[block_idx];
// index_t offset_k = (index_t)(cur_block_table) * params.k_batch_stride;
// auto [row_offset, col] = calc_row_and_col_k(block_idx);
// buffer_load_lds_k(row_offset, col, 0, offset_k);
// __syncthreads();
{
// if (thread0())
// {
// int k = 0;
// for (int i = 0; i < 64; i++)
// {
// for (int j = 0; j < 32; j++)
// {
// printf(" %.3f ", float(k_lds[k]));
// k++;
// }
// printf("\n");
// }
// }
// if (block0() && threadIdx.x < 64)
// {
// cutlass::bfloat16_t q[8];
// for (int i = 0; i < 8; i++)
// {
// q[i].storage = v_reg[0].data_array[i];
// // q[i].storage = q_reg[0].data_array[i];
// }
// printf("tidx %d %.3f %.3f %.3f %.3f %.3f %.3f %.3f %.3f %d\n ", threadIdx.x,
// float(q[0]),
// float(q[1]),
// float(q[2]),
// float(q[3]),
// float(q[4]),
// float(q[5]),
// float(q[6]),
// float(q[7]),
// v_lds_read_ptr
// );
// }
}
struct
IsMaskBlock
{};
struct
IsFirstMaskBlock
{};
struct
IsNoMaskBlock
{};
flash
::
Softmax
<
1
>
softmax
;
auto
process_one_block
=
[
&
]
(
int
block_idx
,
auto
is_mask_block_t
)
{
static
constexpr
bool
IS_MASK_BLOCK
=
std
::
is_same_v
<
decltype
(
is_mask_block_t
),
IsNoMaskBlock
>
;
static
constexpr
bool
IS_FIRST_MASK_BLOCK
=
std
::
is_same_v
<
decltype
(
is_mask_block_t
),
IsFirstMaskBlock
>
;
static
constexpr
bool
IS_NO_MASK_BLOCK
=
std
::
is_same_v
<
decltype
(
is_mask_block_t
),
IsNoMaskBlock
>
;
int
cur_block_table
=
block_table
[
block_idx
];
v4f
accs_f32
[
4
];
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
accs_f32
[
i
].
x
=
0.0
f
;
accs_f32
[
i
].
y
=
0.0
f
;
accs_f32
[
i
].
z
=
0.0
f
;
accs_f32
[
i
].
w
=
0.0
f
;
}
index_t
offset_k
=
(
index_t
)(
cur_block_table
)
*
params
.
k_batch_stride
;
auto
[
row_offset
,
col
]
=
calc_row_and_col_k
();
#define LOAD_K_AND_QK_GEMM(k) \
{ \
constexpr int k_val = (k); \
buffer_load_lds_k(row_offset, col, k_val - 4, block_idx, offset_k); \
flash::qk_gemm<Element, k_val, 5>(q_reg[k_val].data_128, k_lds_read_ptr, accs_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier
(
0
);
\
}
{
constexpr
int
k_val
=
(
17
);
buffer_load_lds_k
(
row_offset
,
col
,
k_val
,
block_idx
,
offset_k
);
buffer_load_lds_k
(
row_offset
,
col
,
k_val
-
1
,
block_idx
,
offset_k
);
buffer_load_lds_k
(
row_offset
,
col
,
k_val
-
2
,
block_idx
,
offset_k
);
buffer_load_lds_k
(
row_offset
,
col
,
k_val
-
3
,
block_idx
,
offset_k
);
buffer_load_lds_k
(
row_offset
,
col
,
k_val
-
4
,
block_idx
,
offset_k
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(4)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
qk_gemm
<
Element
,
k_val
,
5
>
(
q_reg
[
k_val
].
data_128
,
k_lds_read_ptr
,
accs_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
LOAD_K_AND_QK_GEMM
(
16
);
LOAD_K_AND_QK_GEMM
(
15
);
LOAD_K_AND_QK_GEMM
(
14
);
LOAD_K_AND_QK_GEMM
(
13
);
LOAD_K_AND_QK_GEMM
(
12
);
LOAD_K_AND_QK_GEMM
(
11
);
LOAD_K_AND_QK_GEMM
(
10
);
LOAD_K_AND_QK_GEMM
(
9
);
LOAD_K_AND_QK_GEMM
(
8
);
LOAD_K_AND_QK_GEMM
(
7
);
LOAD_K_AND_QK_GEMM
(
6
);
LOAD_K_AND_QK_GEMM
(
5
);
LOAD_K_AND_QK_GEMM
(
4
);
// LOAD_K_AND_QK_GEMM(3);
flash
::
qk_gemm
<
Element
,
k_val
-
14
,
5
>
(
q_reg
[
k_val
-
14
].
data_128
,
k_lds_read_ptr
,
accs_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
qk_gemm
<
Element
,
k_val
-
15
,
5
>
(
q_reg
[
k_val
-
15
].
data_128
,
k_lds_read_ptr
,
accs_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
qk_gemm
<
Element
,
k_val
-
16
,
5
>
(
q_reg
[
k_val
-
16
].
data_128
,
k_lds_read_ptr
,
accs_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
qk_gemm
<
Element
,
k_val
-
17
,
5
>
(
q_reg
[
k_val
-
17
].
data_128
,
k_lds_read_ptr
,
accs_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
}
// if (block0() && tidx < 64)
// {
// printf(" %.3f %.3f \n", accs_f32[0][0], accs_f32[0][1]);
// }
if
constexpr
(
!
IS_NO_MASK_BLOCK
)
{
for
(
int
i
=
0
;
i
<
16
;
++
i
)
{
int
idx
=
i
;
if
constexpr
(
!
T
::
Is_causal
)
{
if
((
lane_idx
/
16
)
*
4
+
(
idx
%
4
)
+
(
idx
/
4
)
*
16
>=
int
(
seqlen_k
-
block_idx
*
kBlockN
))
{
#if defined(__gfx938__)
accs_f32
[
i
/
4
][
i
%
4
]
=
-
INFINITY
;
#else
accs_f32
[
i
%
4
][
i
/
4
]
=
-
INFINITY
;
#endif
}
}
else
{
// Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
// col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
int
row
=
(
lane_idx
%
16
)
+
warp_idx
*
16
;;
int
col_limit_right
=
seqlen_k
-
1
-
block_idx
*
kBlockN
-
(
params
.
q_seq_per_hk
-
1
-
(
m_block
*
kBlockM
+
row
))
/
params
.
q_head_per_hk
;
if
((
lane_idx
/
16
)
*
4
+
(
idx
%
4
)
+
(
idx
/
4
)
*
16
>
col_limit_right
)
{
#if defined(__gfx938__)
accs_f32
[
i
/
4
][
i
%
4
]
=
-
INFINITY
;
#else
accs_f32
[
i
%
4
][
i
/
4
]
=
-
INFINITY
;
#endif
}
}
}
}
Tensor
scores
=
make_tensor
<
float
>
(
Shape
<
_1
,
_16
>
{});
for
(
int
i
=
0
;
i
<
16
;
i
++
)
{
#if defined(__gfx938__)
scores
(
0
,
i
)
=
accs_f32
[
i
/
4
][
i
%
4
];
#else
scores
(
0
,
i
)
=
accs_f32
[
i
%
4
][
i
/
4
];
#endif
}
softmax
.
template
softmax_rescale_o_prefill_4x1
<
/*Is_first=*/
IS_FIRST_MASK_BLOCK
,
/*Check_inf=*//*Is_local=*/
T
::
Is_causal
>(
scores
,
acco_f32
,
params
.
scale_softmax_log2
);
Bf16_storage_x4
p
[
4
];
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
#if defined(__gfx938__)
p
[
i
].
data_32
[
0
]
=
__builtin_hcu_cvt_pk_bf16_f32
(
0
,
scores
(
0
,
i
*
4
),
0
,
scores
(
0
,
i
*
4
+
1
),
0
);
p
[
i
].
data_32
[
1
]
=
__builtin_hcu_cvt_pk_bf16_f32
(
0
,
scores
(
0
,
i
*
4
+
2
),
0
,
scores
(
0
,
i
*
4
+
3
),
0
);
#else
p
[
i
].
data
[
0
]
=
float2bf16
(
scores
(
0
,
i
*
4
));
p
[
i
].
data
[
1
]
=
float2bf16
(
scores
(
0
,
i
*
4
+
1
));
p
[
i
].
data
[
2
]
=
float2bf16
(
scores
(
0
,
i
*
4
+
2
));
p
[
i
].
data
[
3
]
=
float2bf16
(
scores
(
0
,
i
*
4
+
3
));
#endif
}
int
row_offset_v
[
4
];
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
row_offset_v
[
i
]
=
calc_row_and_col_v
(
i
);
}
// __syncthreads();
// #if 1
{
constexpr
int
k_val
=
(
0
);
buffer_load_lds_v
(
row_offset_v
[
k_val
],
col_offset_v
,
k_val
,
0
,
block_idx
,
offset_k
);
buffer_load_lds_v
(
row_offset_v
[
k_val
+
1
],
col_offset_v
,
k_val
+
1
,
0
,
block_idx
,
offset_k
);
buffer_load_lds_v
(
row_offset_v
[
k_val
+
2
],
col_offset_v
,
k_val
+
2
,
0
,
block_idx
,
offset_k
);
buffer_load_lds_v
(
row_offset_v
[
k_val
+
3
],
col_offset_v
,
k_val
+
3
,
0
,
block_idx
,
offset_k
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
pv_gemm
<
k_val
,
0
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
1
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
2
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
3
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
buffer_load_lds_v
(
row_offset_v
[
k_val
],
col_offset_v
,
k_val
,
1
,
block_idx
,
offset_k
);
flash
::
pv_gemm
<
k_val
+
1
,
0
>
(
p
[
k_val
+
1
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
+
1
,
1
>
(
p
[
k_val
+
1
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
+
1
,
2
>
(
p
[
k_val
+
1
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
+
1
,
3
>
(
p
[
k_val
+
1
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
buffer_load_lds_v
(
row_offset_v
[
k_val
+
1
],
col_offset_v
,
k_val
+
1
,
1
,
block_idx
,
offset_k
);
flash
::
pv_gemm
<
k_val
+
2
,
0
>
(
p
[
k_val
+
2
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
+
2
,
1
>
(
p
[
k_val
+
2
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
+
2
,
2
>
(
p
[
k_val
+
2
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
+
2
,
3
>
(
p
[
k_val
+
2
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
buffer_load_lds_v
(
row_offset_v
[
k_val
+
2
],
col_offset_v
,
k_val
+
2
,
1
,
block_idx
,
offset_k
);
flash
::
pv_gemm
<
k_val
+
3
,
0
>
(
p
[
k_val
+
3
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
+
3
,
1
>
(
p
[
k_val
+
3
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
+
3
,
2
>
(
p
[
k_val
+
3
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
+
3
,
3
>
(
p
[
k_val
+
3
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
buffer_load_lds_v
(
row_offset_v
[
k_val
+
3
],
col_offset_v
,
k_val
+
3
,
1
,
block_idx
,
offset_k
);
}
#define LOAD_V_AND_PV_GEMM(n) \
{ \
constexpr int k_val = (0); \
constexpr int n_val = (n); \
flash::pv_gemm<k_val, n_val * 4>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, n_val * 4 + 1>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, n_val * 4 + 2>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, n_val * 4 + 3>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, n_val + 1, block_idx, offset_k); \
flash::pv_gemm<k_val + 1, n_val * 4>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val + 1, n_val * 4 + 1>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val + 1, n_val * 4 + 2>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val + 1, n_val * 4 + 3>(p[k_val + 1].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, n_val + 1, block_idx, offset_k); \
flash::pv_gemm<k_val + 2, n_val * 4>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val + 2, n_val * 4 + 1>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val + 2, n_val * 4 + 2>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val + 2, n_val * 4 + 3>(p[k_val + 2].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
buffer_load_lds_v(row_offset_v[k_val + 2], col_offset_v, k_val + 2, n_val + 1, block_idx, offset_k); \
flash::pv_gemm<k_val + 3, n_val * 4>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val + 3, n_val * 4 + 1>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val + 3, n_val * 4 + 2>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val + 3, n_val * 4 + 3>(p[k_val + 3].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
buffer_load_lds_v(row_offset_v[k_val + 3], col_offset_v, k_val + 3, n_val + 1, block_idx, offset_k); \
}
LOAD_V_AND_PV_GEMM
(
1
);
LOAD_V_AND_PV_GEMM
(
2
);
{
constexpr
int
n_val
=
(
3
);
flash
::
pv_gemm
<
0
,
12
>
(
p
[
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
0
,
13
>
(
p
[
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
0
,
14
>
(
p
[
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
0
,
15
>
(
p
[
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
pv_gemm
<
1
,
12
>
(
p
[
1
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
1
,
13
>
(
p
[
1
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
1
,
14
>
(
p
[
1
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
1
,
15
>
(
p
[
1
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
pv_gemm
<
2
,
12
>
(
p
[
2
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
2
,
13
>
(
p
[
2
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
2
,
14
>
(
p
[
2
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
2
,
15
>
(
p
[
2
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
pv_gemm
<
3
,
12
>
(
p
[
3
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
3
,
13
>
(
p
[
3
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
3
,
14
>
(
p
[
3
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
3
,
15
>
(
p
[
3
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
}
};
constexpr
int
n_masking_steps
=
!
T
::
Is_causal
?
1
:
cute
::
ceil_div
(
kBlockM
,
kBlockN
)
+
1
;
int
n_block
=
n_block_max
-
1
;
if
constexpr
(
n_masking_steps
==
1
)
{
if
(
n_block
>=
n_block_min
)
{
process_one_block
(
n_block
,
IsFirstMaskBlock
{});
}
n_block
--
;
}
else
{
int
masking_step
=
1
;
if
(
n_block
>=
n_block_min
)
{
process_one_block
(
n_block
,
IsFirstMaskBlock
{});
}
n_block
--
;
for
(;
n_block
>=
n_block_min
&&
masking_step
<
n_masking_steps
;
++
masking_step
,
--
n_block
)
{
process_one_block
(
n_block
,
IsMaskBlock
{});
}
}
for
(;
n_block
>=
n_block_min
;
--
n_block
)
{
process_one_block
(
n_block
,
IsNoMaskBlock
{});
}
using
ElementAccum
=
float
;
// if constexpr (!use_split_kv)
if
(
!
is_split
)
{
using
ElementO
=
Element
;
const
index_t
row_offset_o
=
bidb
*
params
.
o_batch_stride
+
m_block
*
kBlockM
*
params
.
o_row_stride
+
bidh
*
params
.
o_head_stride
;
const
index_t
row_offset_lse
=
(
bidb
*
params
.
h_k
+
bidh
)
*
params
.
q_seq_per_hk
+
m_block
*
kBlockM
;
constexpr
bool
Split
=
false
;
Tensor
gOaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementO
*>
(
Split
?
params
.
oaccum_ptr
:
params
.
o_ptr
)
+
(
row_offset_o
)),
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDimV
>>
{},
make_stride
(
Split
?
kHeadDimV
:
params
.
o_row_stride
,
_1
{}));
Tensor
lse
=
softmax
.
template
normalize_softmax_lse_prefill_4x1
<
/*Is_dropout=*/
false
,
Split
>(
acco_f32
,
params
.
scale_softmax
);
// if (block0() && tidx < 64)
// {
// printf(" %.3f %.3f \n", float(acc_o(0)), float(acc_o(1)));
// }
Tensor
gLSEaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementAccum
*>
(
Split
?
params
.
softmax_lseaccum_ptr
:
params
.
softmax_lse_ptr
)
+
(
row_offset_lse
)),
Shape
<
Int
<
kBlockM
>>
{},
Stride
<
_1
>
{});
{
// using result_type = cutlass::Array<bfloat16_t, 2>;
// int tidx = threadIdx.x;
int
row
,
col
;
// int warpid = tidx / 64;
for
(
int
mi
=
0
;
mi
<
1
;
mi
++
)
{
row
=
mi
*
kBlockM
+
lane_idx
%
16
+
warp_idx
*
16
;
if
(
row
<
params
.
q_seq_per_hk
-
m_block
*
kBlockM
)
{
for
(
int
ni
=
0
;
ni
<
16
;
++
ni
)
{
#if defined(__gfx938__)
Bf16_storage
res
;
col
=
(
lane_idx
/
16
)
*
8
+
ni
*
32
;
res
.
data_32
[
0
]
=
__builtin_hcu_cvt_pk_bf16_f32
(
0
,
acco_f32
[
ni
*
2
][
0
],
0
,
acco_f32
[
ni
*
2
+
1
][
0
],
0
);
res
.
data_32
[
1
]
=
__builtin_hcu_cvt_pk_bf16_f32
(
0
,
acco_f32
[
ni
*
2
][
1
],
0
,
acco_f32
[
ni
*
2
+
1
][
1
],
0
);
res
.
data_32
[
2
]
=
__builtin_hcu_cvt_pk_bf16_f32
(
0
,
acco_f32
[
ni
*
2
][
2
],
0
,
acco_f32
[
ni
*
2
+
1
][
2
],
0
);
res
.
data_32
[
3
]
=
__builtin_hcu_cvt_pk_bf16_f32
(
0
,
acco_f32
[
ni
*
2
][
3
],
0
,
acco_f32
[
ni
*
2
+
1
][
3
],
0
);
*
(
__fp16x8_t
*
)(
&
gOaccum
(
row
,
col
))
=
res
.
data_128
;
#else
col
=
(
lane_idx
/
16
)
*
2
+
ni
*
32
;
using
result_type
=
cutlass
::
Array
<
Element
,
2
>
;
for
(
int
ei
=
0
;
ei
<
4
;
ei
++
)
{
result_type
res
;
Element
e0
,
e1
;
e0
.
storage
=
float2bf16
(
acco_f32
[
ni
*
2
][
ei
]);
e1
.
storage
=
float2bf16
(
acco_f32
[
ni
*
2
+
1
][
ei
]);
res
[
0
]
=
e0
;
res
[
1
]
=
e1
;
// gO(row, col) = res[0];
// gO(row, col + 1) = res[1];
*
(
result_type
*
)(
&
gOaccum
(
row
,
col
))
=
res
;
col
+=
8
;
}
#endif
}
// for (int n = 0; n < 1; n++) {
// col = (tidx % 64 / 16) + warpid * 32 + n * 128;
// for (int ei = 0; ei < 8; ei ++) {
// gOaccum(row, col) = rO(ei, m, n);
// col += 4;
// }
// }
gLSEaccum
(
row
)
=
lse
(
mi
);
}
}
}
}
else
{
using
ElementO
=
float
;
int
num_splits
=
params
.
total_num_splits
/
params
.
b
;
int
split_idx
=
(
blockIdx
.
z
/
params
.
b
)
+
bidb
*
num_splits
;
constexpr
bool
Split
=
true
;
const
index_t
row_offset_oaccum
=
((
split_idx
*
params
.
h_k
+
bidh
)
*
params
.
q_seq_per_hk
+
m_block
*
kBlockM
)
*
T
::
HEAD_DIM_V
;
// (BLOCK_SIZE_M, HEAD_DIM_V) : (HEAD_DIM_V, 1)
const
index_t
row_offset_lseaccum
=
(
split_idx
*
params
.
h_k
+
bidh
)
*
params
.
q_seq_per_hk
+
m_block
*
kBlockM
;
Tensor
gOaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementO
*>
(
Split
?
params
.
oaccum_ptr
:
params
.
o_ptr
)
+
(
row_offset_oaccum
)),
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDimV
>>
{},
make_stride
(
Split
?
kHeadDimV
:
params
.
o_row_stride
,
_1
{}));
Tensor
gLSEaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementAccum
*>
(
Split
?
params
.
softmax_lseaccum_ptr
:
params
.
softmax_lse_ptr
)
+
(
row_offset_lseaccum
)),
Shape
<
Int
<
kBlockM
>>
{},
Stride
<
_1
>
{});
Tensor
lse
=
softmax
.
template
normalize_softmax_lse_prefill_4x1
<
/*Is_dropout=*/
false
,
Split
>(
acco_f32
,
params
.
scale_softmax
);
{
// using result_type = cutlass::Array<bfloat16_t, 2>;
// int tidx = threadIdx.x;
int
row
,
col
;
// int warpid = tidx / 64;
for
(
int
mi
=
0
;
mi
<
1
;
mi
++
)
{
row
=
mi
*
kBlockM
+
lane_idx
%
16
+
warp_idx
*
16
;
if
(
row
<
params
.
q_seq_per_hk
-
m_block
*
kBlockM
)
{
for
(
int
ni
=
0
;
ni
<
16
;
++
ni
)
{
#if defined(__gfx938__)
col
=
(
lane_idx
/
16
)
*
8
+
ni
*
32
;
for
(
int
ei
=
0
;
ei
<
4
;
ei
++
)
{
gOaccum
(
row
,
col
)
=
acco_f32
[
ni
*
2
][
ei
];
gOaccum
(
row
,
col
+
1
)
=
acco_f32
[
ni
*
2
+
1
][
ei
];
col
+=
2
;
}
#else
col
=
(
lane_idx
/
16
)
*
2
+
ni
*
32
;
for
(
int
ei
=
0
;
ei
<
4
;
ei
++
)
{
gOaccum
(
row
,
col
)
=
acco_f32
[
ni
*
2
][
ei
];
gOaccum
(
row
,
col
+
1
)
=
acco_f32
[
ni
*
2
+
1
][
ei
];
col
+=
8
;
}
#endif
}
// for (int n = 0; n < 1; n++) {
// col = (tidx % 64 / 16) + warpid * 32 + n * 128;
// for (int ei = 0; ei < 8; ei ++) {
// gOaccum(row, col) = rO(ei, m, n);
// col += 4;
// }
// }
gLSEaccum
(
row
)
=
lse
(
mi
);
}
}
}
}
}
template
<
typename
InputT
>
void
run_flash_splitkv_mla_kernel
(
DenseAttnDecodeParams
&
params
)
{
FLASH_ASSERT
(
params
.
d
==
Config
::
HEAD_DIM_K
);
FLASH_ASSERT
(
params
.
d_v
==
Config
::
HEAD_DIM_V
);
BOOL_SWITCH
(
params
.
is_causal
,
Is_causal
,
[
&
]
{
if
(
params
.
h_q
>=
64
&&
params
.
h_k
==
1
)
{
using
T
=
Traits_Block_M_64
<
InputT
,
Is_causal
>
;
constexpr
size_t
smem_size
=
16384
+
4096
;
if
(
params
.
use_split_kv
)
{
auto
mla_kernel
=
&
flash_fwd_splitkv_mla_block_m_64_kernel
<
T
,
true
>
;
const
int
num_m_block
=
cute
::
ceil_div
(
params
.
q_seq_per_hk
,
T
::
BLOCK_SIZE_M
);
mla_kernel
<<<
dim3
(
num_m_block
,
params
.
h_k
,
params
.
total_num_splits
),
T
::
NUM_THREADS
,
smem_size
,
params
.
stream
>>>
(
params
);
}
else
{
auto
mla_kernel
=
&
flash_fwd_splitkv_mla_block_m_64_kernel
<
T
>
;
const
int
num_m_block
=
cute
::
ceil_div
(
params
.
q_seq_per_hk
,
T
::
BLOCK_SIZE_M
);
mla_kernel
<<<
dim3
(
num_m_block
,
params
.
h_k
,
params
.
b
),
T
::
NUM_THREADS
,
smem_size
,
params
.
stream
>>>
(
params
);
}
}
else
{
constexpr
size_t
smem_size
=
65536
;
using
T
=
Traits
<
InputT
,
Is_causal
>
;
const
int
num_m_block
=
cute
::
ceil_div
(
params
.
q_seq_per_hk
,
T
::
BLOCK_SIZE_M
);
auto
mla_kernel
=
&
flash_fwd_splitkv_mla_kernel
<
T
>
;
mla_kernel
<<<
dim3
(
num_m_block
,
params
.
h_k
,
params
.
num_sm_parts
),
T
::
NUM_THREADS
,
smem_size
,
params
.
stream
>>>
(
params
);
}
});
// cudaLaunchConfig_t mla_kernel_config = {
// dim3(num_m_block, params.h_k, params.num_sm_parts),
// dim3(T::NUM_THREADS, 1, 1),
// smem_size,
// params.stream,
// mla_kernel_attributes,
// 1
// };
// cudaLaunchKernelEx(&mla_kernel_config, mla_kernel, params, tma_params);
CHECK_CUDA_KERNEL_LAUNCH
();
}
...
...
csrc/gfx93/decode/dense/traits.h
View file @
aec17474
...
...
@@ -127,3 +127,26 @@ struct Traits {
template
<
typename
InputT_
,
bool
Is_causal_
>
struct
Traits_Block_M_64
{
using
InputT
=
InputT_
;
static
constexpr
bool
Is_causal
=
Is_causal_
;
static
constexpr
int
BLOCK_SIZE_M
=
64
;
static
constexpr
int
PAGE_BLOCK_SIZE
=
64
;
static
constexpr
int
HEAD_DIM_K
=
576
;
static
constexpr
int
HEAD_DIM_V
=
512
;
static
constexpr
int
NUM_THREADS
=
256
;
static_assert
(
std
::
is_same_v
<
InputT
,
cutlass
::
bfloat16_t
>
||
std
::
is_same_v
<
InputT
,
cutlass
::
half_t
>
);
static
constexpr
int
kBlockM
=
BLOCK_SIZE_M
;
static
constexpr
int
kBlockN
=
PAGE_BLOCK_SIZE
;
static
constexpr
int
kHeadDim
=
HEAD_DIM_K
;
static
constexpr
int
kHeadDimV
=
HEAD_DIM_V
;
static
constexpr
int
kNWarps
=
4
;
using
Element
=
InputT
;
using
elem_type
=
Element
;
using
ElementAccum
=
float
;
};
\ No newline at end of file
csrc/gfx93/prefill/sparse/phase1.cuh
View file @
aec17474
...
...
@@ -236,7 +236,7 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
// int v_idx = row_offset;
int
offset_v
=
col_offset
*
2
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
v_lds
)
+
warp_idx
*
16
*
32
*
2
+
(
k_idx
%
1
)
*
512
*
16
*
2
+
n_idx
*
128
*
16
*
2
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
v_lds
)
+
warp_idx
*
16
*
32
*
2
+
(
k_idx
)
*
128
*
16
*
2
;
typedef
uint32_t
uint32x2_t
__attribute__
((
ext_vector_type
(
2
)));
uint32x2_t
index_offset
=
{
0
};
index_offset
[
0
]
=
row_offset
;
...
...
@@ -474,7 +474,7 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
#endif
}
softmax
.
template
softmax_rescale_o_prefill_4x1
<
/*Is_first=*/
IS_FIRST_BLOCK
,
/*Check_inf=*//*Is_local=*/
fals
e
>(
scores
,
acco_f32
,
params
.
sm_scale_div_log2
);
softmax
.
template
softmax_rescale_o_prefill_4x1
<
/*Is_first=*/
IS_FIRST_BLOCK
,
/*Check_inf=*//*Is_local=*/
tru
e
>(
scores
,
acco_f32
,
params
.
sm_scale_div_log2
);
Bf16_storage_x4
p
[
4
];
for
(
int
i
=
0
;
i
<
4
;
i
++
)
...
...
@@ -500,9 +500,9 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
{
constexpr
int
k_val
=
(
0
);
buffer_load_lds_v
(
row_offset_v
[
k_val
],
col_offset_v
,
k_val
,
0
);
buffer_load_lds_v
(
row_offset_v
[
k_val
],
col_offset_v
,
k_val
,
1
);
buffer_load_lds_v
(
row_offset_v
[
k_val
],
col_offset_v
,
k_val
,
2
);
buffer_load_lds_v
(
row_offset_v
[
k_val
],
col_offset_v
,
k_val
,
3
);
buffer_load_lds_v
(
row_offset_v
[
k_val
+
1
],
col_offset_v
,
k_val
+
1
,
0
);
buffer_load_lds_v
(
row_offset_v
[
k_val
+
2
],
col_offset_v
,
k_val
+
2
,
0
);
buffer_load_lds_v
(
row_offset_v
[
k_val
+
3
],
col_offset_v
,
k_val
+
3
,
0
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
...
...
@@ -513,109 +513,111 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
buffer_load_lds_v
(
row_offset_v
[
k_val
+
1
],
col_offset_v
,
k_val
+
1
,
0
);
buffer_load_lds_v
(
row_offset_v
[
k_val
],
col_offset_v
,
k_val
,
1
);
flash
::
pv_gemm
<
k_val
,
4
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
5
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
6
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
7
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
+
1
,
0
>
(
p
[
k_val
+
1
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
+
1
,
1
>
(
p
[
k_val
+
1
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
+
1
,
2
>
(
p
[
k_val
+
1
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
+
1
,
3
>
(
p
[
k_val
+
1
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
buffer_load_lds_v
(
row_offset_v
[
k_val
+
1
],
col_offset_v
,
k_val
+
1
,
1
);
flash
::
pv_gemm
<
k_val
,
8
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
9
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
10
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
11
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
+
2
,
0
>
(
p
[
k_val
+
2
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
+
2
,
1
>
(
p
[
k_val
+
2
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
+
2
,
2
>
(
p
[
k_val
+
2
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
+
2
,
3
>
(
p
[
k_val
+
2
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
buffer_load_lds_v
(
row_offset_v
[
k_val
+
1
],
col_offset_v
,
k_val
+
1
,
2
);
buffer_load_lds_v
(
row_offset_v
[
k_val
+
2
],
col_offset_v
,
k_val
+
2
,
1
);
flash
::
pv_gemm
<
k_val
,
12
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
1
3
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
14
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
15
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
+
3
,
0
>
(
p
[
k_val
+
3
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
+
3
,
1
>
(
p
[
k_val
+
3
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
+
3
,
2
>
(
p
[
k_val
+
3
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
+
3
,
3
>
(
p
[
k_val
+
3
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
buffer_load_lds_v
(
row_offset_v
[
k_val
+
1
],
col_offset_v
,
k_val
+
1
,
3
);
}
buffer_load_lds_v
(
row_offset_v
[
k_val
+
3
],
col_offset_v
,
k_val
+
3
,
1
);
#define LOAD_V_AND_PV_GEMM(k) \
}
#define LOAD_V_AND_PV_GEMM(n) \
{ \
constexpr int k_val = (k); \
flash::pv_gemm<k_val, 0>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 1>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 2>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 3>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
constexpr int k_val = (0); \
constexpr int n_val = (n); \
flash::pv_gemm<k_val, n_val * 4>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, n_val * 4 + 1>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, n_val * 4 + 2>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, n_val * 4 + 3>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
buffer_load_lds_v(row_offset_v[k_val
+ 1
], col_offset_v, k_val + 1
, 0
); \
flash::pv_gemm<k_val
,
4>(p[k_val +
0
].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val
, 5
>(p[k_val +
0
].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val
, 6
>(p[k_val +
0
].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val
, 7
>(p[k_val +
0
].data_64, v_lds_read_ptr, acco_f32); \
buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val
, n_val
+ 1); \
flash::pv_gemm<k_val
+ 1, n_val *
4>(p[k_val +
1
].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val
+ 1, n_val * 4 + 1
>(p[k_val +
1
].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val
+ 1, n_val * 4 + 2
>(p[k_val +
1
].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val
+ 1, n_val * 4 + 3
>(p[k_val +
1
].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 1); \
flash::pv_gemm<k_val
, 8
>(p[k_val +
0
].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val
, 9
>(p[k_val +
0
].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val
, 10
>(p[k_val +
0
].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val
, 11
>(p[k_val +
0
].data_64, v_lds_read_ptr, acco_f32); \
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1,
n_val +
1); \
flash::pv_gemm<k_val
+ 2, n_val * 4
>(p[k_val +
2
].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val
+ 2, n_val * 4 + 1
>(p[k_val +
2
].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val
+ 2, n_val * 4 + 2
>(p[k_val +
2
].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val
+ 2, n_val * 4 + 3
>(p[k_val +
2
].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
buffer_load_lds_v(row_offset_v[k_val +
1
], col_offset_v, k_val +
1
,
2
); \
flash::pv_gemm<k_val
, 12
>(p[k_val +
0
].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val
, 13
>(p[k_val +
0
].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val
, 14
>(p[k_val +
0
].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val
, 15
>(p[k_val +
0
].data_64, v_lds_read_ptr, acco_f32); \
buffer_load_lds_v(row_offset_v[k_val +
2
], col_offset_v, k_val +
2
,
n_val + 1
); \
flash::pv_gemm<k_val
+ 3, n_val * 4
>(p[k_val +
3
].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val
+ 3, n_val * 4 + 1
>(p[k_val +
3
].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val
+ 3, n_val * 4 + 2
>(p[k_val +
3
].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val
+ 3, n_val * 4 + 3
>(p[k_val +
3
].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
buffer_load_lds_v(row_offset_v[k_val +
1
], col_offset_v, k_val +
1
,
3
); \
buffer_load_lds_v(row_offset_v[k_val +
3
], col_offset_v, k_val +
3
,
n_val + 1
); \
}
LOAD_V_AND_PV_GEMM
(
1
);
LOAD_V_AND_PV_GEMM
(
2
);
{
constexpr
int
k
_val
=
(
3
);
flash
::
pv_gemm
<
k_val
,
0
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
1
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
2
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
3
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
constexpr
int
n
_val
=
(
3
);
flash
::
pv_gemm
<
0
,
12
>
(
p
[
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
0
,
1
3
>
(
p
[
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
0
,
14
>
(
p
[
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
0
,
15
>
(
p
[
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
pv_gemm
<
k_val
,
4
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
5
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
6
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
7
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
1
,
12
>
(
p
[
1
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
1
,
13
>
(
p
[
1
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
1
,
14
>
(
p
[
1
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
1
,
15
>
(
p
[
1
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
pv_gemm
<
k_val
,
8
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
9
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
1
0
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
1
1
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
2
,
12
>
(
p
[
2
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
2
,
13
>
(
p
[
2
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
2
,
1
4
>
(
p
[
2
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
2
,
1
5
>
(
p
[
2
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
pv_gemm
<
k_val
,
12
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
13
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
14
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
15
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
3
,
12
>
(
p
[
3
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
3
,
13
>
(
p
[
3
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
3
,
14
>
(
p
[
3
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
3
,
15
>
(
p
[
3
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
}
#else
#define LOAD_V_AND_PV_GEMM(k) \
{ \
...
...
csrc/params.h
View file @
aec17474
...
...
@@ -58,6 +58,10 @@ struct DenseAttnDecodeParams { // TODO Change name to DenseAttnDecodeParams
float
*
__restrict__
oaccum_ptr
;
cudaStream_t
stream
;
bool
use_split_kv
;
int
partition_block_nums
;
};
struct
DenseAttnDecodeParams_fp8
:
public
DenseAttnDecodeParams
{
...
...
@@ -127,6 +131,12 @@ struct CombineParams {
float
*
attn_sink
;
// [h_q], may be nullptr
cudaStream_t
stream
;
bool
use_split_kv
;
int
num_splits
;
int
*
__restrict__
seqlens_k_ptr
;
int
partition_block_nums
;
};
struct
GetDecodeSchedMetaParams
{
...
...
csrc/softmax.h
View file @
aec17474
...
...
@@ -621,7 +621,7 @@ struct Softmax {
// static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for
(
int
mi
=
0
;
mi
<
size
(
row_max
);
++
mi
)
{
float
scores_max_cur
=
!
true
float
scores_max_cur
=
!
Check_inf
?
row_max
(
mi
)
:
(
row_max
(
mi
)
==
-
INFINITY
?
0.0
f
:
row_max
(
mi
));
...
...
csrc/utils.h
View file @
aec17474
...
...
@@ -1553,7 +1553,7 @@ __forceinline__ __device__ void gemm1_rs_fp8(Tensor0 &acc, Tensor1 &tCrA, Tensor
#endif
typedef
__bf16
__fp16x8_t
__attribute__
((
ext_vector_type
(
8
)));
template
<
typename
Element
,
int
k_idx
>
template
<
typename
Element
,
int
k_idx
,
int
k_mod
=
4
>
__forceinline__
__device__
void
qk_gemm
(
const
__fp16x8_t
&
q_data
,
Element
*
k_lds_read_ptr
,
v4f
*
accs_f32
)
{
typedef
__bf16
__fp16x8_t
__attribute__
((
ext_vector_type
(
8
)));
...
...
@@ -1563,7 +1563,7 @@ __forceinline__ __device__ void qk_gemm(const __fp16x8_t& q_data, Element* k_lds
__fp16x4_t
data_64
[
2
];
uint16_t
data_array
[
8
];
};
constexpr
int
k_idx_even
=
k_idx
%
4
;
constexpr
int
k_idx_even
=
k_idx
%
k_mod
;
constexpr
int
n_offset
=
16
*
32
;
constexpr
int
k_offset
=
k_idx_even
*
64
*
32
;
Bf16_storage
q_reg
;
...
...
@@ -1616,7 +1616,7 @@ typedef __bf16 __fp16x4_t __attribute__((ext_vector_type(4)));
template
<
int
k_idx
,
int
n_idx_val
>
__forceinline__
__device__
void
pv_gemm
(
const
__fp16x4_t
&
p
,
int
v_lds_read_ptr
,
v4f
*
acco_f32
)
{
constexpr
int
k_idx_even
=
k_idx
%
1
;
constexpr
int
k_idx_even
=
k_idx
;
constexpr
int
n_offset
=
16
*
32
*
2
;
typedef
__bf16
__fp16x8_t
__attribute__
((
ext_vector_type
(
8
)));
union
Bf16_storage
{
...
...
@@ -1624,11 +1624,11 @@ __forceinline__ __device__ void pv_gemm(const __fp16x4_t& p, int v_lds_read_ptr,
__fp16x4_t
data_64
[
2
];
uint16_t
data_array
[
8
];
};
constexpr
int
k_offset
=
k_idx_even
*
16
*
5
12
*
2
;
constexpr
int
k_offset
=
k_idx_even
*
16
*
12
8
*
2
;
// #if 1
Bf16_storage
v_reg
;
v_reg
.
data_128
=
__builtin_amdgcn_ds_read_m32x16f16_alt
((
__attribute__
((
address_space
(
3
)))
__fp16
*
)(
v_lds_read_ptr
),
k_offset
+
n_idx_val
*
n_offset
);
v_reg
.
data_128
=
__builtin_amdgcn_ds_read_m32x16f16_alt
((
__attribute__
((
address_space
(
3
)))
__fp16
*
)(
v_lds_read_ptr
),
k_offset
+
(
n_idx_val
%
4
)
*
n_offset
);
#if defined(__gfx938__)
acco_f32
[
n_idx_val
*
2
]
=
__builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts
(
p
,
v_reg
.
data_64
[
0
],
acco_f32
[
n_idx_val
*
2
],
true
,
false
);
acco_f32
[
n_idx_val
*
2
+
1
]
=
__builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts
(
p
,
v_reg
.
data_64
[
1
],
acco_f32
[
n_idx_val
*
2
+
1
],
true
,
false
);
...
...
tests/test_flash_mla_dense_decoding.py
View file @
aec17474
...
...
@@ -172,7 +172,7 @@ def test_flash_mla(t: TestParam):
assert
is_correct
if
t
.
test_performance
:
time_usage
=
kk
.
bench_kineto
(
run_flash_mla
,
10
).
get_kernel_time
(
"flash_fwd_splitkv_mla
_kernel
"
)
time_usage
=
kk
.
bench_kineto
(
run_flash_mla
,
10
).
get_kernel_time
(
"flash_fwd_splitkv_mla"
)
mean_attended_seqlens
=
cache_seqlens
.
float
().
mean
().
item
()
compute_volume_flop
=
t
.
b
*
t
.
h_q
*
t
.
s_q
*
sum
([
...
...
@@ -226,7 +226,7 @@ def main(torch_dtype):
TestParam
(
128
,
s_q
,
s_k
,
is_varlen
=
True
,
is_causal
=
is_causal
,
h_q
=
h_q
,
test_performance
=
True
)
for
is_causal
in
[
False
,
True
]
for
s_q
in
[
1
,
2
]
for
h_q
in
[
16
,
128
]
for
h_q
in
[
16
,
64
,
128
]
for
s_k
in
[
4096
,
8192
,
16384
,
32768
]
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment