Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
aec17474
Commit
aec17474
authored
Apr 20, 2026
by
zhanghj2
Browse files
Feature/kimi nhead64 dense
parent
a45f646b
Changes
9
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
9 changed files
with
1590 additions
and
159 deletions
+1590
-159
csrc/api/dense_decode.h
csrc/api/dense_decode.h
+88
-38
csrc/gfx9/decode/combine/combine.cu
csrc/gfx9/decode/combine/combine.cu
+61
-30
csrc/gfx93/decode/dense/splitkv_mla.cuh
csrc/gfx93/decode/dense/splitkv_mla.cuh
+1335
-20
csrc/gfx93/decode/dense/traits.h
csrc/gfx93/decode/dense/traits.h
+23
-0
csrc/gfx93/prefill/sparse/phase1.cuh
csrc/gfx93/prefill/sparse/phase1.cuh
+65
-63
csrc/params.h
csrc/params.h
+10
-0
csrc/softmax.h
csrc/softmax.h
+1
-1
csrc/utils.h
csrc/utils.h
+5
-5
tests/test_flash_mla_dense_decoding.py
tests/test_flash_mla_dense_decoding.py
+2
-2
No files found.
csrc/api/dense_decode.h
View file @
aec17474
...
...
@@ -92,7 +92,7 @@ dense_attn_decode_interface(
KU_CHECK_CONTIGUOUS
(
out
);
KU_CHECK_CONTIGUOUS
(
lse
);
if
(
!
tile_scheduler_metadata
.
has_value
())
{
if
(
!
tile_scheduler_metadata
.
has_value
()
&&
((
num_heads_q
<
64
&&
num_heads_k
==
1
)
||
num_heads_k
>
1
)
)
{
tile_scheduler_metadata
=
torch
::
empty
({
num_sm_parts
,
sizeof
(
DecodingSchedMeta
)
/
4
},
opts
.
dtype
(
torch
::
kInt32
));
num_splits
=
torch
::
empty
({
batch_size
+
1
},
opts
.
dtype
(
torch
::
kInt32
));
KU_CHECK_CONTIGUOUS
(
tile_scheduler_metadata
);
...
...
@@ -125,20 +125,6 @@ dense_attn_decode_interface(
if
(
const
char
*
val
=
std
::
getenv
(
"FLASH_MLA_PRINT_PARAM"
))
{
print_param
=
(
std
::
string
(
val
)
==
"1"
);
}
if
(
print_param
)
{
fprintf
(
stderr
,
"[FlashMLA] [dense_attn_decode_interface] [%s] batch_size = %d seqlen_q_ori = %d "
"num_heads_q = %d head_size_k = %d max_num_blocks_per_seq = %d num_blocks %d page_block_size = %d num_heads_k = %d
\n
"
,
arch
.
archName
.
c_str
(),
batch_size
,
seqlen_q_ori
,
num_heads_q
,
head_size_k
,
max_num_blocks_per_seq
,
num_blocks
,
page_block_size
,
num_heads_k
);
}
// Set the sizes
DenseAttnDecodeParams
params
;
params
.
b
=
batch_size
;
...
...
@@ -173,20 +159,78 @@ dense_attn_decode_interface(
params
.
block_table
=
block_table
.
data_ptr
<
int
>
();
params
.
block_table_batch_stride
=
block_table
.
stride
(
0
);
params
.
page_block_size
=
page_block_size
;
params
.
tile_scheduler_metadata_ptr
=
(
DecodingSchedMeta
*
)
tile_scheduler_metadata
->
data_ptr
();
params
.
num_sm_parts
=
num_sm_parts
;
params
.
num_splits_ptr
=
num_splits
->
data_ptr
<
int
>
();
const
int
total_num_splits
=
batch_size
+
params
.
num_sm_parts
;
at
::
Tensor
lse_accum
=
torch
::
empty
({
total_num_splits
,
num_heads
,
q_seq_per_hk
},
opts
.
dtype
(
at
::
kFloat
));
at
::
Tensor
out_accum
=
torch
::
empty
({
total_num_splits
,
num_heads
,
q_seq_per_hk
,
head_size_v
},
opts
.
dtype
(
at
::
kFloat
));
KU_CHECK_CONTIGUOUS
(
lse_accum
);
KU_CHECK_CONTIGUOUS
(
out_accum
);
params
.
total_num_splits
=
total_num_splits
;
params
.
softmax_lseaccum_ptr
=
lse_accum
.
data_ptr
<
float
>
();
params
.
oaccum_ptr
=
out_accum
.
data_ptr
<
float
>
();
if
((
num_heads_q
<
64
&&
num_heads_k
==
1
)
||
num_heads_k
>
1
)
{
params
.
tile_scheduler_metadata_ptr
=
(
DecodingSchedMeta
*
)
tile_scheduler_metadata
->
data_ptr
();
params
.
num_sm_parts
=
num_sm_parts
;
params
.
num_splits_ptr
=
num_splits
->
data_ptr
<
int
>
();
const
int
total_num_splits
=
batch_size
+
params
.
num_sm_parts
;
at
::
Tensor
lse_accum
=
torch
::
empty
({
total_num_splits
,
num_heads
,
q_seq_per_hk
},
opts
.
dtype
(
at
::
kFloat
));
at
::
Tensor
out_accum
=
torch
::
empty
({
total_num_splits
,
num_heads
,
q_seq_per_hk
,
head_size_v
},
opts
.
dtype
(
at
::
kFloat
));
KU_CHECK_CONTIGUOUS
(
lse_accum
);
KU_CHECK_CONTIGUOUS
(
out_accum
);
params
.
total_num_splits
=
total_num_splits
;
params
.
softmax_lseaccum_ptr
=
lse_accum
.
data_ptr
<
float
>
();
params
.
oaccum_ptr
=
out_accum
.
data_ptr
<
float
>
();
params
.
use_split_kv
=
false
;
}
else
{
bool
use_split_kv
=
true
;
int
num_m_blocks
=
(
params
.
q_seq_per_hk
+
64
-
1
)
/
64
;
int
num_sms
=
arch
.
num_sms
;
int
num_splits
=
num_sms
*
3
/
(
num_m_blocks
*
params
.
b
);
if
(
max_num_blocks_per_seq
>=
32768
/
64
)
{
num_splits
=
32
;
}
else
if
(
max_num_blocks_per_seq
>=
16384
/
64
)
{
num_splits
=
32
;
}
else
if
(
max_num_blocks_per_seq
>=
8192
/
64
)
{
num_splits
=
16
;
}
else
if
(
max_num_blocks_per_seq
>=
4096
/
64
)
{
num_splits
=
8
;
}
else
if
(
max_num_blocks_per_seq
>=
2048
/
64
)
{
num_splits
=
4
;
}
else
{
num_splits
=
1
;
}
if
(
params
.
b
>=
128
)
{
num_splits
=
1
;
}
if
(
num_splits
<=
1
)
{
use_split_kv
=
false
;
}
else
{
num_splits
=
std
::
min
(
num_splits
,
240
);
params
.
partition_block_nums
=
max_num_blocks_per_seq
/
num_splits
;
}
if
(
params
.
partition_block_nums
<=
4
)
{
use_split_kv
=
false
;
}
params
.
use_split_kv
=
use_split_kv
;
params
.
total_num_splits
=
params
.
b
*
num_splits
;
at
::
Tensor
lse_accum
=
torch
::
empty
({
params
.
total_num_splits
,
num_heads
,
q_seq_per_hk
},
opts
.
dtype
(
at
::
kFloat
));
at
::
Tensor
out_accum
=
torch
::
empty
({
params
.
total_num_splits
,
num_heads
,
q_seq_per_hk
,
head_size_v
},
opts
.
dtype
(
at
::
kFloat
));
KU_CHECK_CONTIGUOUS
(
lse_accum
);
KU_CHECK_CONTIGUOUS
(
out_accum
);
params
.
softmax_lseaccum_ptr
=
lse_accum
.
data_ptr
<
float
>
();
params
.
oaccum_ptr
=
out_accum
.
data_ptr
<
float
>
();
}
if
(
print_param
)
{
fprintf
(
stderr
,
"[FlashMLA] [dense_attn_decode_interface] [%s] batch_size = %d seqlen_q_ori = %d "
"num_heads_q = %d head_size_k = %d max_num_blocks_per_seq = %d num_blocks %d page_block_size = %d num_heads_k = %d use_split_kv = %d num_splits %d params.partition_block_nums = %d
\n
"
,
arch
.
archName
.
c_str
(),
batch_size
,
seqlen_q_ori
,
num_heads_q
,
head_size_k
,
max_num_blocks_per_seq
,
num_blocks
,
page_block_size
,
num_heads_k
,
params
.
use_split_kv
,
params
.
total_num_splits
/
params
.
b
,
params
.
partition_block_nums
);
}
params
.
stream
=
at
::
cuda
::
getCurrentCUDAStream
().
stream
();
if
(
q_dtype
==
torch
::
kBFloat16
)
{
...
...
@@ -220,17 +264,23 @@ dense_attn_decode_interface(
params
.
num_sm_parts
,
nullptr
,
at
::
cuda
::
getCurrentCUDAStream
().
stream
()
at
::
cuda
::
getCurrentCUDAStream
().
stream
(),
params
.
use_split_kv
,
params
.
total_num_splits
/
params
.
b
,
params
.
seqlens_k_ptr
,
params
.
partition_block_nums
};
if
(
q_dtype
==
torch
::
kBFloat16
)
{
gfx9
::
decode
::
run_flash_mla_combine_kernel
<
cutlass
::
bfloat16_t
>
(
combine_params
);
}
else
if
(
q_dtype
==
torch
::
kHalf
)
{
#ifndef FLASH_MLA_DISABLE_FP16
gfx9
::
decode
::
run_flash_mla_combine_kernel
<
cutlass
::
half_t
>
(
combine_params
);
#endif
}
else
{
TORCH_CHECK
(
false
,
"Unsupported tensor dtype for query"
);
if
((
num_heads_q
<
64
&&
num_heads_k
==
1
)
||
num_heads_k
>
1
||
params
.
use_split_kv
)
{
if
(
q_dtype
==
torch
::
kBFloat16
)
{
gfx9
::
decode
::
run_flash_mla_combine_kernel
<
cutlass
::
bfloat16_t
>
(
combine_params
);
}
else
if
(
q_dtype
==
torch
::
kHalf
)
{
#ifndef FLASH_MLA_DISABLE_FP16
gfx9
::
decode
::
run_flash_mla_combine_kernel
<
cutlass
::
half_t
>
(
combine_params
);
#endif
}
else
{
TORCH_CHECK
(
false
,
"Unsupported tensor dtype for query"
);
}
}
out
=
out
.
view
({
batch_size
,
num_heads_k
,
seqlen_q_ori
,
num_q_heads_per_hk
,
head_size_v
}).
transpose
(
1
,
2
)
...
...
csrc/gfx9/decode/combine/combine.cu
View file @
aec17474
...
...
@@ -16,7 +16,7 @@ using namespace cute;
namespace
gfx9
::
decode
{
template
<
typename
ElementT
,
int
HEAD_DIM_V
,
int
BLOCK_SIZE_M
,
int
MAX_SPLITS
,
int
NUM_THREADS
>
template
<
typename
ElementT
,
int
HEAD_DIM_V
,
int
BLOCK_SIZE_M
,
int
MAX_SPLITS
,
int
NUM_THREADS
,
bool
USE_SPLIT_KV
=
false
>
__global__
void
__launch_bounds__
(
NUM_THREADS
,
1
)
flash_fwd_mla_combine_kernel
(
const
CombineParams
params
)
{
// grid_shape: [batch_size, s_q, h_q/BLOCK_SIZE_M]
...
...
@@ -32,14 +32,38 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
if
(
warp_idx
>=
num_valid_heads
)
{
return
;
}
const
int
start_split_idx
=
__ldg
(
params
.
num_splits_ptr
+
batch_idx
);
const
int
end_split_idx
=
__ldg
(
params
.
num_splits_ptr
+
batch_idx
+
1
);
const
int
my_num_splits
=
end_split_idx
-
start_split_idx
;
if
(
my_num_splits
==
1
)
{
return
;
int
start_split_idx
;
int
end_split_idx
;
int
my_num_splits
;
if
constexpr
(
USE_SPLIT_KV
)
{
start_split_idx
=
batch_idx
*
params
.
num_splits
;
end_split_idx
=
(
batch_idx
+
1
)
*
params
.
num_splits
;
int
seqlen_k
=
__ldg
(
params
.
seqlens_k_ptr
+
batch_idx
);
end_split_idx
=
std
::
min
(
cute
::
ceil_div
(
cute
::
ceil_div
(
seqlen_k
,
64
),
params
.
partition_block_nums
),
params
.
num_splits
)
+
start_split_idx
;
// if (lane_idx == 0 && batch_idx == 61)
// {
// printf(" batch_idx = %d start_split_idx = %d end_split_idx = %d seqlen_k = %d \n",batch_idx, start_split_idx, end_split_idx, seqlen_k);
// }
my_num_splits
=
end_split_idx
-
start_split_idx
;
if
(
my_num_splits
==
1
)
{
return
;
}
}
else
{
start_split_idx
=
__ldg
(
params
.
num_splits_ptr
+
batch_idx
);
end_split_idx
=
__ldg
(
params
.
num_splits_ptr
+
batch_idx
+
1
);
my_num_splits
=
end_split_idx
-
start_split_idx
;
if
(
my_num_splits
==
1
)
{
return
;
}
}
// FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS);
Tensor
gLseAccum
=
make_tensor
(
...
...
@@ -245,6 +269,9 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
} else if (NUM_SPLITS <= 160) { \
constexpr static int NAME = 160; \
return __VA_ARGS__(); \
} else if (NUM_SPLITS <= 240) { \
constexpr static int NAME = 240; \
return __VA_ARGS__(); \
} else { \
FLASH_ASSERT(false); \
} \
...
...
@@ -255,29 +282,33 @@ template<typename ElementT>
void
run_flash_mla_combine_kernel
(
CombineParams
&
params
)
{
static
constexpr
int
HEAD_DIM_V
=
512
;
// Since only this head dimension is supported by Flash MLA
FLASH_ASSERT
(
params
.
d_v
==
HEAD_DIM_V
);
MLA_NUM_SPLITS_SWITCH
(
params
.
num_sm_parts
,
NUM_SPLITS
,
[
&
]
{
constexpr
int
BLOCK_SIZE_M
=
4
;
constexpr
int
NUM_THREADS
=
BLOCK_SIZE_M
*
64
;
constexpr
size_t
smem_size
=
BLOCK_SIZE_M
*
(
NUM_SPLITS
+
1
)
*
sizeof
(
float
);
auto
combine_kernel
=
&
flash_fwd_mla_combine_kernel
<
ElementT
,
HEAD_DIM_V
,
BLOCK_SIZE_M
,
NUM_SPLITS
,
NUM_THREADS
>
;
// CHECK_CUDA(cudaFuncSetAttribute(combine_kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size));
// // Use cudaLaunchKernelEx to enable PDL (Programmatic Dependent Launch)
// cudaLaunchAttribute attribute[1];
// attribute[0].id = cudaLaunchAttributeProgrammaticStreamSerialization;
// attribute[0].val.programmaticStreamSerializationAllowed = 1;
// cudaLaunchConfig_t combine_kernel_config = {
// dim3(params.b, params.s_q, ku::ceil_div(params.h_q, BLOCK_SIZE_M)),
// dim3(NUM_THREADS, 1, 1),
// 0,
// params.stream,
// attribute,
// 1
// };
combine_kernel
<<<
dim3
(
params
.
b
,
params
.
s_q
,
ku
::
ceil_div
(
params
.
h_q
,
BLOCK_SIZE_M
)),
NUM_THREADS
,
smem_size
,
params
.
stream
>>>
(
params
);
});
if
(
params
.
use_split_kv
)
{
MLA_NUM_SPLITS_SWITCH
(
params
.
num_splits
,
NUM_SPLITS
,
[
&
]
{
constexpr
int
BLOCK_SIZE_M
=
4
;
constexpr
int
NUM_THREADS
=
BLOCK_SIZE_M
*
64
;
constexpr
size_t
smem_size
=
BLOCK_SIZE_M
*
(
NUM_SPLITS
+
1
)
*
sizeof
(
float
);
auto
combine_kernel
=
&
flash_fwd_mla_combine_kernel
<
ElementT
,
HEAD_DIM_V
,
BLOCK_SIZE_M
,
NUM_SPLITS
,
NUM_THREADS
,
true
>
;
combine_kernel
<<<
dim3
(
params
.
b
,
params
.
s_q
,
ku
::
ceil_div
(
params
.
h_q
,
BLOCK_SIZE_M
)),
NUM_THREADS
,
smem_size
,
params
.
stream
>>>
(
params
);
});
}
else
{
MLA_NUM_SPLITS_SWITCH
(
params
.
num_sm_parts
,
NUM_SPLITS
,
[
&
]
{
constexpr
int
BLOCK_SIZE_M
=
4
;
constexpr
int
NUM_THREADS
=
BLOCK_SIZE_M
*
64
;
constexpr
size_t
smem_size
=
BLOCK_SIZE_M
*
(
NUM_SPLITS
+
1
)
*
sizeof
(
float
);
auto
combine_kernel
=
&
flash_fwd_mla_combine_kernel
<
ElementT
,
HEAD_DIM_V
,
BLOCK_SIZE_M
,
NUM_SPLITS
,
NUM_THREADS
>
;
combine_kernel
<<<
dim3
(
params
.
b
,
params
.
s_q
,
ku
::
ceil_div
(
params
.
h_q
,
BLOCK_SIZE_M
)),
NUM_THREADS
,
smem_size
,
params
.
stream
>>>
(
params
);
});
}
CHECK_CUDA_KERNEL_LAUNCH
();
}
...
...
csrc/gfx93/decode/dense/splitkv_mla.cuh
View file @
aec17474
This diff is collapsed.
Click to expand it.
csrc/gfx93/decode/dense/traits.h
View file @
aec17474
...
...
@@ -127,3 +127,26 @@ struct Traits {
template
<
typename
InputT_
,
bool
Is_causal_
>
struct
Traits_Block_M_64
{
using
InputT
=
InputT_
;
static
constexpr
bool
Is_causal
=
Is_causal_
;
static
constexpr
int
BLOCK_SIZE_M
=
64
;
static
constexpr
int
PAGE_BLOCK_SIZE
=
64
;
static
constexpr
int
HEAD_DIM_K
=
576
;
static
constexpr
int
HEAD_DIM_V
=
512
;
static
constexpr
int
NUM_THREADS
=
256
;
static_assert
(
std
::
is_same_v
<
InputT
,
cutlass
::
bfloat16_t
>
||
std
::
is_same_v
<
InputT
,
cutlass
::
half_t
>
);
static
constexpr
int
kBlockM
=
BLOCK_SIZE_M
;
static
constexpr
int
kBlockN
=
PAGE_BLOCK_SIZE
;
static
constexpr
int
kHeadDim
=
HEAD_DIM_K
;
static
constexpr
int
kHeadDimV
=
HEAD_DIM_V
;
static
constexpr
int
kNWarps
=
4
;
using
Element
=
InputT
;
using
elem_type
=
Element
;
using
ElementAccum
=
float
;
};
\ No newline at end of file
csrc/gfx93/prefill/sparse/phase1.cuh
View file @
aec17474
...
...
@@ -236,7 +236,7 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
// int v_idx = row_offset;
int
offset_v
=
col_offset
*
2
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
v_lds
)
+
warp_idx
*
16
*
32
*
2
+
(
k_idx
%
1
)
*
512
*
16
*
2
+
n_idx
*
128
*
16
*
2
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
v_lds
)
+
warp_idx
*
16
*
32
*
2
+
(
k_idx
)
*
128
*
16
*
2
;
typedef
uint32_t
uint32x2_t
__attribute__
((
ext_vector_type
(
2
)));
uint32x2_t
index_offset
=
{
0
};
index_offset
[
0
]
=
row_offset
;
...
...
@@ -474,7 +474,7 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
#endif
}
softmax
.
template
softmax_rescale_o_prefill_4x1
<
/*Is_first=*/
IS_FIRST_BLOCK
,
/*Check_inf=*//*Is_local=*/
fals
e
>(
scores
,
acco_f32
,
params
.
sm_scale_div_log2
);
softmax
.
template
softmax_rescale_o_prefill_4x1
<
/*Is_first=*/
IS_FIRST_BLOCK
,
/*Check_inf=*//*Is_local=*/
tru
e
>(
scores
,
acco_f32
,
params
.
sm_scale_div_log2
);
Bf16_storage_x4
p
[
4
];
for
(
int
i
=
0
;
i
<
4
;
i
++
)
...
...
@@ -500,121 +500,123 @@ __device__ void KernelTemplate_B_H_64<D_QK, HAVE_TOPK_LENGTH, IS_TOPK_2048>::dev
{
constexpr
int
k_val
=
(
0
);
buffer_load_lds_v
(
row_offset_v
[
k_val
],
col_offset_v
,
k_val
,
0
);
buffer_load_lds_v
(
row_offset_v
[
k_val
],
col_offset_v
,
k_val
,
1
);
buffer_load_lds_v
(
row_offset_v
[
k_val
],
col_offset_v
,
k_val
,
2
);
buffer_load_lds_v
(
row_offset_v
[
k_val
],
col_offset_v
,
k_val
,
3
);
__builtin_amdgcn_sched_barrier
(
0
);
buffer_load_lds_v
(
row_offset_v
[
k_val
+
1
],
col_offset_v
,
k_val
+
1
,
0
);
buffer_load_lds_v
(
row_offset_v
[
k_val
+
2
],
col_offset_v
,
k_val
+
2
,
0
);
buffer_load_lds_v
(
row_offset_v
[
k_val
+
3
],
col_offset_v
,
k_val
+
3
,
0
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
pv_gemm
<
k_val
,
0
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
1
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
2
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
3
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
buffer_load_lds_v
(
row_offset_v
[
k_val
+
1
],
col_offset_v
,
k_val
+
1
,
0
);
flash
::
pv_gemm
<
k_val
,
4
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
5
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
6
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
7
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
buffer_load_lds_v
(
row_offset_v
[
k_val
],
col_offset_v
,
k_val
,
1
);
flash
::
pv_gemm
<
k_val
+
1
,
0
>
(
p
[
k_val
+
1
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
+
1
,
1
>
(
p
[
k_val
+
1
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
+
1
,
2
>
(
p
[
k_val
+
1
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
+
1
,
3
>
(
p
[
k_val
+
1
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
buffer_load_lds_v
(
row_offset_v
[
k_val
+
1
],
col_offset_v
,
k_val
+
1
,
1
);
flash
::
pv_gemm
<
k_val
,
8
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
9
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
10
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
11
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
+
2
,
0
>
(
p
[
k_val
+
2
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
+
2
,
1
>
(
p
[
k_val
+
2
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
+
2
,
2
>
(
p
[
k_val
+
2
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
+
2
,
3
>
(
p
[
k_val
+
2
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
buffer_load_lds_v
(
row_offset_v
[
k_val
+
1
],
col_offset_v
,
k_val
+
1
,
2
);
buffer_load_lds_v
(
row_offset_v
[
k_val
+
2
],
col_offset_v
,
k_val
+
2
,
1
);
flash
::
pv_gemm
<
k_val
,
12
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
1
3
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
14
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
15
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
+
3
,
0
>
(
p
[
k_val
+
3
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
+
3
,
1
>
(
p
[
k_val
+
3
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
+
3
,
2
>
(
p
[
k_val
+
3
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
+
3
,
3
>
(
p
[
k_val
+
3
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
buffer_load_lds_v
(
row_offset_v
[
k_val
+
1
],
col_offset_v
,
k_val
+
1
,
3
);
buffer_load_lds_v
(
row_offset_v
[
k_val
+
3
],
col_offset_v
,
k_val
+
3
,
1
);
}
#define LOAD_V_AND_PV_GEMM(k) \
#define LOAD_V_AND_PV_GEMM(n) \
{ \
constexpr int k_val = (k); \
flash::pv_gemm<k_val, 0>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 1>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 2>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 3>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
constexpr int k_val = (0); \
constexpr int n_val = (n); \
flash::pv_gemm<k_val, n_val * 4>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, n_val * 4 + 1>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, n_val * 4 + 2>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, n_val * 4 + 3>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
buffer_load_lds_v(row_offset_v[k_val
+ 1
], col_offset_v, k_val + 1
, 0
); \
flash::pv_gemm<k_val
,
4>(p[k_val +
0
].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val
, 5
>(p[k_val +
0
].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val
, 6
>(p[k_val +
0
].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val
, 7
>(p[k_val +
0
].data_64, v_lds_read_ptr, acco_f32); \
buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val
, n_val
+ 1); \
flash::pv_gemm<k_val
+ 1, n_val *
4>(p[k_val +
1
].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val
+ 1, n_val * 4 + 1
>(p[k_val +
1
].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val
+ 1, n_val * 4 + 2
>(p[k_val +
1
].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val
+ 1, n_val * 4 + 3
>(p[k_val +
1
].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 1); \
flash::pv_gemm<k_val
, 8
>(p[k_val +
0
].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val
, 9
>(p[k_val +
0
].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val
, 10
>(p[k_val +
0
].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val
, 11
>(p[k_val +
0
].data_64, v_lds_read_ptr, acco_f32); \
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1,
n_val +
1); \
flash::pv_gemm<k_val
+ 2, n_val * 4
>(p[k_val +
2
].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val
+ 2, n_val * 4 + 1
>(p[k_val +
2
].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val
+ 2, n_val * 4 + 2
>(p[k_val +
2
].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val
+ 2, n_val * 4 + 3
>(p[k_val +
2
].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
buffer_load_lds_v(row_offset_v[k_val +
1
], col_offset_v, k_val +
1
,
2
); \
flash::pv_gemm<k_val
, 12
>(p[k_val +
0
].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val
, 13
>(p[k_val +
0
].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val
, 14
>(p[k_val +
0
].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val
, 15
>(p[k_val +
0
].data_64, v_lds_read_ptr, acco_f32); \
buffer_load_lds_v(row_offset_v[k_val +
2
], col_offset_v, k_val +
2
,
n_val + 1
); \
flash::pv_gemm<k_val
+ 3, n_val * 4
>(p[k_val +
3
].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val
+ 3, n_val * 4 + 1
>(p[k_val +
3
].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val
+ 3, n_val * 4 + 2
>(p[k_val +
3
].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val
+ 3, n_val * 4 + 3
>(p[k_val +
3
].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
buffer_load_lds_v(row_offset_v[k_val +
1
], col_offset_v, k_val +
1
,
3
); \
buffer_load_lds_v(row_offset_v[k_val +
3
], col_offset_v, k_val +
3
,
n_val + 1
); \
}
LOAD_V_AND_PV_GEMM
(
1
);
LOAD_V_AND_PV_GEMM
(
2
);
{
constexpr
int
k
_val
=
(
3
);
flash
::
pv_gemm
<
k_val
,
0
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
1
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
2
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
3
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
constexpr
int
n
_val
=
(
3
);
flash
::
pv_gemm
<
0
,
12
>
(
p
[
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
0
,
1
3
>
(
p
[
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
0
,
14
>
(
p
[
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
0
,
15
>
(
p
[
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
pv_gemm
<
k_val
,
4
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
5
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
6
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
7
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
1
,
12
>
(
p
[
1
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
1
,
13
>
(
p
[
1
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
1
,
14
>
(
p
[
1
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
1
,
15
>
(
p
[
1
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
pv_gemm
<
k_val
,
8
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
9
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
1
0
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
1
1
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
2
,
12
>
(
p
[
2
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
2
,
13
>
(
p
[
2
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
2
,
1
4
>
(
p
[
2
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
2
,
1
5
>
(
p
[
2
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
pv_gemm
<
k_val
,
12
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
13
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
14
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
15
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
3
,
12
>
(
p
[
3
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
3
,
13
>
(
p
[
3
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
3
,
14
>
(
p
[
3
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
3
,
15
>
(
p
[
3
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
}
#else
#define LOAD_V_AND_PV_GEMM(k) \
...
...
csrc/params.h
View file @
aec17474
...
...
@@ -58,6 +58,10 @@ struct DenseAttnDecodeParams { // TODO Change name to DenseAttnDecodeParams
float
*
__restrict__
oaccum_ptr
;
cudaStream_t
stream
;
bool
use_split_kv
;
int
partition_block_nums
;
};
struct
DenseAttnDecodeParams_fp8
:
public
DenseAttnDecodeParams
{
...
...
@@ -127,6 +131,12 @@ struct CombineParams {
float
*
attn_sink
;
// [h_q], may be nullptr
cudaStream_t
stream
;
bool
use_split_kv
;
int
num_splits
;
int
*
__restrict__
seqlens_k_ptr
;
int
partition_block_nums
;
};
struct
GetDecodeSchedMetaParams
{
...
...
csrc/softmax.h
View file @
aec17474
...
...
@@ -621,7 +621,7 @@ struct Softmax {
// static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for
(
int
mi
=
0
;
mi
<
size
(
row_max
);
++
mi
)
{
float
scores_max_cur
=
!
true
float
scores_max_cur
=
!
Check_inf
?
row_max
(
mi
)
:
(
row_max
(
mi
)
==
-
INFINITY
?
0.0
f
:
row_max
(
mi
));
...
...
csrc/utils.h
View file @
aec17474
...
...
@@ -1553,7 +1553,7 @@ __forceinline__ __device__ void gemm1_rs_fp8(Tensor0 &acc, Tensor1 &tCrA, Tensor
#endif
typedef
__bf16
__fp16x8_t
__attribute__
((
ext_vector_type
(
8
)));
template
<
typename
Element
,
int
k_idx
>
template
<
typename
Element
,
int
k_idx
,
int
k_mod
=
4
>
__forceinline__
__device__
void
qk_gemm
(
const
__fp16x8_t
&
q_data
,
Element
*
k_lds_read_ptr
,
v4f
*
accs_f32
)
{
typedef
__bf16
__fp16x8_t
__attribute__
((
ext_vector_type
(
8
)));
...
...
@@ -1563,7 +1563,7 @@ __forceinline__ __device__ void qk_gemm(const __fp16x8_t& q_data, Element* k_lds
__fp16x4_t
data_64
[
2
];
uint16_t
data_array
[
8
];
};
constexpr
int
k_idx_even
=
k_idx
%
4
;
constexpr
int
k_idx_even
=
k_idx
%
k_mod
;
constexpr
int
n_offset
=
16
*
32
;
constexpr
int
k_offset
=
k_idx_even
*
64
*
32
;
Bf16_storage
q_reg
;
...
...
@@ -1616,7 +1616,7 @@ typedef __bf16 __fp16x4_t __attribute__((ext_vector_type(4)));
template
<
int
k_idx
,
int
n_idx_val
>
__forceinline__
__device__
void
pv_gemm
(
const
__fp16x4_t
&
p
,
int
v_lds_read_ptr
,
v4f
*
acco_f32
)
{
constexpr
int
k_idx_even
=
k_idx
%
1
;
constexpr
int
k_idx_even
=
k_idx
;
constexpr
int
n_offset
=
16
*
32
*
2
;
typedef
__bf16
__fp16x8_t
__attribute__
((
ext_vector_type
(
8
)));
union
Bf16_storage
{
...
...
@@ -1624,11 +1624,11 @@ __forceinline__ __device__ void pv_gemm(const __fp16x4_t& p, int v_lds_read_ptr,
__fp16x4_t
data_64
[
2
];
uint16_t
data_array
[
8
];
};
constexpr
int
k_offset
=
k_idx_even
*
16
*
5
12
*
2
;
constexpr
int
k_offset
=
k_idx_even
*
16
*
12
8
*
2
;
// #if 1
Bf16_storage
v_reg
;
v_reg
.
data_128
=
__builtin_amdgcn_ds_read_m32x16f16_alt
((
__attribute__
((
address_space
(
3
)))
__fp16
*
)(
v_lds_read_ptr
),
k_offset
+
n_idx_val
*
n_offset
);
v_reg
.
data_128
=
__builtin_amdgcn_ds_read_m32x16f16_alt
((
__attribute__
((
address_space
(
3
)))
__fp16
*
)(
v_lds_read_ptr
),
k_offset
+
(
n_idx_val
%
4
)
*
n_offset
);
#if defined(__gfx938__)
acco_f32
[
n_idx_val
*
2
]
=
__builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts
(
p
,
v_reg
.
data_64
[
0
],
acco_f32
[
n_idx_val
*
2
],
true
,
false
);
acco_f32
[
n_idx_val
*
2
+
1
]
=
__builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts
(
p
,
v_reg
.
data_64
[
1
],
acco_f32
[
n_idx_val
*
2
+
1
],
true
,
false
);
...
...
tests/test_flash_mla_dense_decoding.py
View file @
aec17474
...
...
@@ -172,7 +172,7 @@ def test_flash_mla(t: TestParam):
assert
is_correct
if
t
.
test_performance
:
time_usage
=
kk
.
bench_kineto
(
run_flash_mla
,
10
).
get_kernel_time
(
"flash_fwd_splitkv_mla
_kernel
"
)
time_usage
=
kk
.
bench_kineto
(
run_flash_mla
,
10
).
get_kernel_time
(
"flash_fwd_splitkv_mla"
)
mean_attended_seqlens
=
cache_seqlens
.
float
().
mean
().
item
()
compute_volume_flop
=
t
.
b
*
t
.
h_q
*
t
.
s_q
*
sum
([
...
...
@@ -226,7 +226,7 @@ def main(torch_dtype):
TestParam
(
128
,
s_q
,
s_k
,
is_varlen
=
True
,
is_causal
=
is_causal
,
h_q
=
h_q
,
test_performance
=
True
)
for
is_causal
in
[
False
,
True
]
for
s_q
in
[
1
,
2
]
for
h_q
in
[
16
,
128
]
for
h_q
in
[
16
,
64
,
128
]
for
s_k
in
[
4096
,
8192
,
16384
,
32768
]
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment