Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
4e0bdf6e
Commit
4e0bdf6e
authored
Jun 03, 2026
by
shenzhe
Committed by
zhanghj2
Jun 06, 2026
Browse files
Support no-split sparse decode for large batch
parent
97ab7511
Changes
3
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
with
115 additions
and
32 deletions
+115
-32
csrc/api/sparse_decode.h
csrc/api/sparse_decode.h
+58
-31
csrc/gfx9/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu
...decode/get_decoding_sched_meta/get_decoding_sched_meta.cu
+56
-1
csrc/gfx9/decode/get_decoding_sched_meta/get_decoding_sched_meta.h
.../decode/get_decoding_sched_meta/get_decoding_sched_meta.h
+1
-0
No files found.
csrc/api/sparse_decode.h
100644 → 100755
View file @
4e0bdf6e
...
@@ -289,6 +289,17 @@ sparse_attn_decode_interface(
...
@@ -289,6 +289,17 @@ sparse_attn_decode_interface(
}
}
DecodeImplMeta
impl_meta
=
impl
->
get_meta
(
h_q
,
s_q
);
DecodeImplMeta
impl_meta
=
impl
->
get_meta
(
h_q
,
s_q
);
bool
force_no_split_kv
=
false
;
if
(
const
char
*
val
=
std
::
getenv
(
"FLASH_MLA_SPARSE_DECODE_DISABLE_SPLITKV"
))
{
force_no_split_kv
=
(
std
::
string
(
val
)
==
"1"
);
}
constexpr
int
max_sched_meta_smem_size
=
64
*
1024
;
bool
sched_meta_smem_overflow
=
sizeof
(
int
)
*
(
static_cast
<
int64_t
>
(
b
)
*
5
+
1
)
>
max_sched_meta_smem_size
;
bool
use_no_split_kv
=
force_no_split_kv
||
sched_meta_smem_overflow
;
if
(
use_no_split_kv
)
{
constexpr
int
max_grid_z
=
65535
;
impl_meta
.
num_sm_parts
=
std
::
min
(
b
,
max_grid_z
);
}
SparseAttnDecodeParams
params
=
{
SparseAttnDecodeParams
params
=
{
b
,
s_q
,
h_q
,
h_kv
,
d_qk
,
d_v
,
b
,
s_q
,
h_q
,
h_kv
,
d_qk
,
d_v
,
...
@@ -344,7 +355,11 @@ sparse_attn_decode_interface(
...
@@ -344,7 +355,11 @@ sparse_attn_decode_interface(
impl_meta
.
num_sm_parts
,
impl_meta
.
num_sm_parts
,
at
::
cuda
::
getCurrentCUDAStream
().
stream
()
at
::
cuda
::
getCurrentCUDAStream
().
stream
()
};
};
gfx9
::
decode
::
run_get_decoding_sched_meta_kernel
(
get_sched_meta_params
);
if
(
use_no_split_kv
)
{
gfx9
::
decode
::
run_get_decoding_sched_meta_no_split_kernel
(
get_sched_meta_params
);
}
else
{
gfx9
::
decode
::
run_get_decoding_sched_meta_kernel
(
get_sched_meta_params
);
}
}
}
// Stick the metadata pointers to `params`
// Stick the metadata pointers to `params`
KU_CHECK_DEVICE
(
tile_scheduler_metadata
);
KU_CHECK_DEVICE
(
tile_scheduler_metadata
);
...
@@ -359,43 +374,55 @@ sparse_attn_decode_interface(
...
@@ -359,43 +374,55 @@ sparse_attn_decode_interface(
params
.
num_splits_ptr
=
num_splits
->
data_ptr
<
int
>
();
params
.
num_splits_ptr
=
num_splits
->
data_ptr
<
int
>
();
params
.
num_sm_parts
=
impl_meta
.
num_sm_parts
;
params
.
num_sm_parts
=
impl_meta
.
num_sm_parts
;
// Allocate intermediate buffers for split-KV
if
(
!
use_no_split_kv
)
{
const
int
total_num_splits
=
b
+
impl_meta
.
num_sm_parts
;
// Allocate intermediate buffers for split-KV
lse_accum
=
torch
::
empty
({
total_num_splits
,
s_q
,
h_q
},
opts
.
dtype
(
at
::
kFloat
));
const
int
total_num_splits
=
b
+
impl_meta
.
num_sm_parts
;
o_accum
=
torch
::
empty
({
total_num_splits
,
s_q
,
h_q
,
d_v
},
opts
.
dtype
(
at
::
kFloat
));
lse_accum
=
torch
::
empty
({
total_num_splits
,
s_q
,
h_q
},
opts
.
dtype
(
at
::
kFloat
));
KU_CHECK_CONTIGUOUS
(
lse_accum
);
o_accum
=
torch
::
empty
({
total_num_splits
,
s_q
,
h_q
,
d_v
},
opts
.
dtype
(
at
::
kFloat
));
KU_CHECK_CONTIGUOUS
(
o_accum
);
KU_CHECK_CONTIGUOUS
(
lse_accum
);
params
.
lse_accum
=
lse_accum
.
data_ptr
<
float
>
();
KU_CHECK_CONTIGUOUS
(
o_accum
);
params
.
o_accum
=
o_accum
.
data_ptr
<
float
>
();
params
.
lse_accum
=
lse_accum
.
data_ptr
<
float
>
();
params
.
stride_lse_accum_split
=
int64_stride_to_int
(
lse_accum
.
stride
(
0
));
params
.
o_accum
=
o_accum
.
data_ptr
<
float
>
();
params
.
stride_lse_accum_s_q
=
int64_stride_to_int
(
lse_accum
.
stride
(
1
));
params
.
stride_lse_accum_split
=
int64_stride_to_int
(
lse_accum
.
stride
(
0
));
params
.
stride_o_accum_split
=
int64_stride_to_int
(
o_accum
.
stride
(
0
));
params
.
stride_lse_accum_s_q
=
int64_stride_to_int
(
lse_accum
.
stride
(
1
));
params
.
stride_o_accum_s_q
=
int64_stride_to_int
(
o_accum
.
stride
(
1
));
params
.
stride_o_accum_split
=
int64_stride_to_int
(
o_accum
.
stride
(
0
));
params
.
stride_o_accum_h_q
=
int64_stride_to_int
(
o_accum
.
stride
(
2
));
params
.
stride_o_accum_s_q
=
int64_stride_to_int
(
o_accum
.
stride
(
1
));
params
.
stride_o_accum_h_q
=
int64_stride_to_int
(
o_accum
.
stride
(
2
));
}
else
{
params
.
lse_accum
=
nullptr
;
params
.
o_accum
=
nullptr
;
params
.
stride_lse_accum_split
=
0
;
params
.
stride_lse_accum_s_q
=
0
;
params
.
stride_o_accum_split
=
0
;
params
.
stride_o_accum_s_q
=
0
;
params
.
stride_o_accum_h_q
=
0
;
}
impl
->
run
(
params
,
features
);
impl
->
run
(
params
,
features
);
CombineParams
combine_params
=
{
if
(
!
use_no_split_kv
)
{
b
,
s_q
,
h_q
,
d_v
,
CombineParams
combine_params
=
{
b
,
s_q
,
h_q
,
d_v
,
params
.
lse
,
params
.
lse
,
params
.
out
,
params
.
out
,
params
.
stride_lse_b
,
params
.
stride_lse_s_q
,
params
.
stride_lse_b
,
params
.
stride_lse_s_q
,
params
.
stride_o_b
,
params
.
stride_o_s_q
,
params
.
stride_o_h_q
,
params
.
stride_o_b
,
params
.
stride_o_s_q
,
params
.
stride_o_h_q
,
params
.
lse_accum
,
params
.
lse_accum
,
params
.
o_accum
,
params
.
o_accum
,
params
.
stride_lse_accum_split
,
params
.
stride_lse_accum_s_q
,
params
.
stride_lse_accum_split
,
params
.
stride_lse_accum_s_q
,
params
.
stride_o_accum_split
,
params
.
stride_o_accum_s_q
,
params
.
stride_o_accum_h_q
,
params
.
stride_o_accum_split
,
params
.
stride_o_accum_s_q
,
params
.
stride_o_accum_h_q
,
params
.
tile_scheduler_metadata_ptr
,
params
.
tile_scheduler_metadata_ptr
,
params
.
num_splits_ptr
,
params
.
num_splits_ptr
,
params
.
num_sm_parts
,
params
.
num_sm_parts
,
ku
::
get_optional_tensor_ptr
<
float
>
(
attn_sink
),
ku
::
get_optional_tensor_ptr
<
float
>
(
attn_sink
),
at
::
cuda
::
getCurrentCUDAStream
().
stream
()
at
::
cuda
::
getCurrentCUDAStream
().
stream
()
};
};
gfx9
::
decode
::
run_flash_mla_combine_kernel
<
bf16
>
(
combine_params
);
gfx9
::
decode
::
run_flash_mla_combine_kernel
<
bf16
>
(
combine_params
);
}
delete
impl
;
delete
impl
;
...
...
csrc/gfx9/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu
100644 → 100755
View file @
4e0bdf6e
#include "get_decoding_sched_meta.h"
#include "get_decoding_sched_meta.h"
#include <algorithm>
#include <cuda_runtime_api.h>
#include <cuda_runtime_api.h>
#include <cutlass/fast_math.h>
#include <cutlass/fast_math.h>
#include <kerutils/kerutils.cuh>
#include <kerutils/kerutils.cuh>
...
@@ -105,10 +106,64 @@ get_mla_metadata_kernel(const GetDecodeSchedMetaParams params) {
...
@@ -105,10 +106,64 @@ get_mla_metadata_kernel(const GetDecodeSchedMetaParams params) {
}
}
}
}
/**
 * Fills the decode tile-scheduler metadata without splitting any request.
 *
 * The `params.b` requests are divided into `params.num_sm_parts` contiguous
 * ranges; part `p` covers requests
 * `[p * b / num_sm_parts, (p + 1) * b / num_sm_parts)`. One `DecodingSchedMeta`
 * record is written per part, and every "split" related field is zeroed.
 * Afterwards `num_splits_ptr[i]` is set to `i` for every `i` in `[0, b]`.
 *
 * Both passes use a grid-stride loop, so any 1-D launch shape covers all
 * entries.
 *
 * @param params Scheduling parameters. Read here: `b`, `block_size_n`,
 *               `num_sm_parts`, `topk`, `extra_topk`, `seqlens_k_ptr`,
 *               `topk_length`, `extra_topk_length`. Written here:
 *               `tile_scheduler_metadata_ptr[0 .. num_sm_parts)` and
 *               `num_splits_ptr[0 .. b]`.
 */
__global__ void __launch_bounds__(256, 1)
get_mla_metadata_no_split_kernel(const GetDecodeSchedMetaParams params) {
    DecodingSchedMeta *sched_meta_out = params.tile_scheduler_metadata_ptr;
    int *splits_out = params.num_splits_ptr;
    const int num_reqs = params.b;
    const int kv_block = params.block_size_n;
    const int num_parts = params.num_sm_parts;

    // Pass 1: one metadata record per part.
    for (int part = blockIdx.x * blockDim.x + threadIdx.x; part < num_parts;
         part += blockDim.x * gridDim.x) {
        // 64-bit products so that part * num_reqs cannot overflow `int`.
        const int req_begin = (static_cast<int64_t>(part) * num_reqs) / num_parts;
        const int req_end = (static_cast<int64_t>(part + 1) * num_reqs) / num_parts;  // exclusive

        DecodingSchedMeta meta;
        meta.begin_req_idx = req_begin;
        meta.end_req_idx = req_end - 1;  // stored as an inclusive index
        meta.begin_block_idx = 0;
        meta.begin_split_idx = 0;
        meta.is_first_req_splitted = 0;
        meta.is_last_req_splitted = 0;
        meta._pad[0] = 0;

        // KV length of the LAST request in this part; stays 0 for an empty part.
        int kv_len = 0;
        const bool part_has_requests = req_begin < req_end;
        if (part_has_requests && params.topk == -1) {
            // topk == -1: take the length from seqlens_k_ptr.
            kv_len = __ldg(params.seqlens_k_ptr + meta.end_req_idx);
        } else if (part_has_requests) {
            // Per-request top-k length when the array is provided, else the
            // uniform `topk` value.
            const int topk_len = params.topk_length
                                     ? __ldg(params.topk_length + meta.end_req_idx)
                                     : params.topk;
            kv_len = topk_len;
            if (kv_len == 0) {
                kv_len = 1;  // a zero length is bumped to 1
            }
            if (params.extra_topk) {
                // NOTE(review): ku::ceil presumably rounds kv_len up to a
                // multiple of the block size before the extra entries are
                // appended -- confirm against kerutils.
                kv_len = ku::ceil(kv_len, kv_block);
                kv_len += params.extra_topk_length
                              ? __ldg(params.extra_topk_length + meta.end_req_idx)
                              : params.extra_topk;
            }
        }
        meta.end_block_idx = cutlass::ceil_div(kv_len, kv_block);

        sched_meta_out[part] = meta;
    }

    // Pass 2: num_splits[i] = i for i in [0, num_reqs] (num_reqs + 1 entries).
    for (int req = blockIdx.x * blockDim.x + threadIdx.x; req <= num_reqs;
         req += blockDim.x * gridDim.x) {
        splits_out[req] = req;
    }
}
void
run_get_decoding_sched_meta_kernel
(
GetDecodeSchedMetaParams
&
params
)
{
void
run_get_decoding_sched_meta_kernel
(
GetDecodeSchedMetaParams
&
params
)
{
int
smem_size
=
sizeof
(
int
)
*
(
params
.
b
*
5
+
1
);
int
smem_size
=
sizeof
(
int
)
*
(
static_cast
<
int64_t
>
(
params
.
b
)
*
5
+
1
);
get_mla_metadata_kernel
<<<
1
,
64
,
smem_size
,
params
.
stream
>>>
(
params
);
get_mla_metadata_kernel
<<<
1
,
64
,
smem_size
,
params
.
stream
>>>
(
params
);
CHECK_CUDA_KERNEL_LAUNCH
();
CHECK_CUDA_KERNEL_LAUNCH
();
}
}
/**
 * Launches `get_mla_metadata_no_split_kernel` on `params.stream`.
 *
 * The grid is sized so that one thread exists per work item, where the number
 * of work items is the larger of `params.num_sm_parts` (metadata records) and
 * `params.b + 1` (entries written to `num_splits_ptr`). The grid is capped at
 * 1024 blocks; the kernel's grid-stride loops pick up whatever the capped
 * grid does not cover directly.
 *
 * @param params Scheduling parameters forwarded to the kernel by value.
 */
void run_get_decoding_sched_meta_no_split_kernel(GetDecodeSchedMetaParams &params) {
    // Must stay in sync with the kernel's __launch_bounds__(256, 1).
    constexpr int kThreadsPerBlock = 256;
    constexpr int kMaxBlocks = 1024;

    const int work_items = std::max(params.num_sm_parts, params.b + 1);
    int num_blocks = cutlass::ceil_div(work_items, kThreadsPerBlock);
    num_blocks = std::min(num_blocks, kMaxBlocks);

    get_mla_metadata_no_split_kernel<<<num_blocks, kThreadsPerBlock, 0, params.stream>>>(params);
    CHECK_CUDA_KERNEL_LAUNCH();
}
}
}
csrc/gfx9/decode/get_decoding_sched_meta/get_decoding_sched_meta.h
100644 → 100755
View file @
4e0bdf6e
...
@@ -5,5 +5,6 @@
...
@@ -5,5 +5,6 @@
namespace
gfx9
::
decode
{
namespace
gfx9
::
decode
{
void
run_get_decoding_sched_meta_kernel
(
GetDecodeSchedMetaParams
&
params
);
void
run_get_decoding_sched_meta_kernel
(
GetDecodeSchedMetaParams
&
params
);
void
run_get_decoding_sched_meta_no_split_kernel
(
GetDecodeSchedMetaParams
&
params
);
}
}
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment