Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
a206ecac
Commit
a206ecac
authored
Jun 08, 2026
by
shenzhe
Browse files
Tune BF16 decode split defaults
parent
3b811287
Changes
2
Show whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
27 additions
and
21 deletions
+27
-21
csrc/gfx93/decode/sparse_bf16_dsa/fwd.cu
csrc/gfx93/decode/sparse_bf16_dsa/fwd.cu
+15
-12
csrc/gfx93/prefill/sparse/dsa_mls/legacy/src/flash_fwd_reduce.h
...fx93/prefill/sparse/dsa_mls/legacy/src/flash_fwd_reduce.h
+12
-9
No files found.
csrc/gfx93/decode/sparse_bf16_dsa/fwd.cu
View file @
a206ecac
...
@@ -48,20 +48,23 @@ static int default_num_splits(int b, int s_q, int topk, int extra_topk) {
...
@@ -48,20 +48,23 @@ static int default_num_splits(int b, int s_q, int topk, int extra_topk) {
return
2
;
return
2
;
}
}
int
split
=
1
;
const
int64_t
decode_tasks
=
static_cast
<
int64_t
>
(
b
)
*
s_q
;
if
(
topk
>
1024
)
{
if
(
topk
==
512
)
{
split
=
32
;
return
decode_tasks
<=
8
?
8
:
1
;
}
else
if
(
topk
==
1024
)
{
split
=
16
;
}
else
if
(
topk
==
512
)
{
split
=
8
;
}
}
if
(
topk
==
1024
)
{
constexpr
int64_t
kMaxDecodeTasksBeforeReducingSplit
=
2048
;
if
(
decode_tasks
<=
4
)
return
16
;
while
(
split
>
1
&&
static_cast
<
int64_t
>
(
b
)
*
s_q
*
split
>
kMaxDecodeTasksBeforeReducingSplit
)
{
if
(
decode_tasks
<=
8
)
return
8
;
split
/=
2
;
return
1
;
}
if
(
topk
>
1024
)
{
if
(
decode_tasks
<=
2
)
return
32
;
if
(
decode_tasks
<=
4
)
return
16
;
if
(
decode_tasks
<=
8
)
return
8
;
if
(
decode_tasks
<=
64
)
return
4
;
return
2
;
}
}
return
split
;
return
1
;
}
}
static
void
check_optional_extra
(
static
void
check_optional_extra
(
...
...
csrc/gfx93/prefill/sparse/dsa_mls/legacy/src/flash_fwd_reduce.h
View file @
a206ecac
...
@@ -63,13 +63,14 @@ __global__ void flash_fwd_splitkv_reduce_kernel(
...
@@ -63,13 +63,14 @@ __global__ void flash_fwd_splitkv_reduce_kernel(
constexpr
int
tx_float_count
=
kHeadDim
>>
6
;
constexpr
int
tx_float_count
=
kHeadDim
>>
6
;
float
tx_accum
[
tx_float_count
]
=
{
0.
f
};
float
tx_accum
[
tx_float_count
]
=
{
0.
f
};
// offset from the next split for output from previous kernel, split * (batch, head,seq) * headdim
// offset from the next split for output from previous kernel, split * (batch, head,seq) * headdim
int
oaccum_stride
=
s_m_split_stride
*
kHeadDim
;
int
64_t
oaccum_stride
=
static_cast
<
int64_t
>
(
s_m_split_stride
)
*
kHeadDim
;
// int tx_offset= block_x * kHeadDim + tx * tx_float_count;
// int tx_offset= block_x * kHeadDim + tx * tx_float_count;
int
in_batch_offset
=
block_x
-
bidb
*
params
.
h
*
params
.
seqlen_q
;
int
in_batch_offset
=
block_x
-
bidb
*
params
.
h
*
params
.
seqlen_q
;
int
bidh
=
in_batch_offset
/
params
.
seqlen_q
;
int
bidh
=
in_batch_offset
/
params
.
seqlen_q
;
int
bids
=
in_batch_offset
-
bidh
*
params
.
seqlen_q
;
int
bids
=
in_batch_offset
-
bidh
*
params
.
seqlen_q
;
int
real_block_x
=
params
.
layout
==
0
?
block_x
/*bhsd layout*/
:
bidb
*
params
.
seqlen_q
*
params
.
h
+
bids
*
params
.
h
+
bidh
/*bshd layout*/
;
int64_t
real_block_x
=
params
.
layout
==
0
?
static_cast
<
int64_t
>
(
block_x
)
/*bhsd layout*/
:
int
tx_offset
=
real_block_x
*
kHeadDim
+
(
tx
&
63
)
*
tx_float_count
;
static_cast
<
int64_t
>
(
bidb
)
*
params
.
seqlen_q
*
params
.
h
+
static_cast
<
int64_t
>
(
bids
)
*
params
.
h
+
bidh
/*bshd layout*/
;
int64_t
tx_offset
=
real_block_x
*
kHeadDim
+
(
tx
&
63
)
*
tx_float_count
;
reduceType
*
output_ptr
=
reinterpret_cast
<
reduceType
*>
(
params
.
o_ptr
)
+
tx_offset
;
reduceType
*
output_ptr
=
reinterpret_cast
<
reduceType
*>
(
params
.
o_ptr
)
+
tx_offset
;
accumType
*
oaccum_ptr
=
reinterpret_cast
<
accumType
*>
(
params
.
oaccum_ptr
);
accumType
*
oaccum_ptr
=
reinterpret_cast
<
accumType
*>
(
params
.
oaccum_ptr
);
// num_splits may not be 64, and thus need boundary judgement
// num_splits may not be 64, and thus need boundary judgement
...
@@ -207,14 +208,15 @@ __global__ void flash_fwd_splitkv_reduce_kernel_split128(Params params) {
...
@@ -207,14 +208,15 @@ __global__ void flash_fwd_splitkv_reduce_kernel_split128(Params params) {
float
tx_accum
[
tx_float_count
]
=
{
0.
f
};
float
tx_accum
[
tx_float_count
]
=
{
0.
f
};
static_assert
(
tx_float_count
*
128
<
LDS_SIZE
&&
"for each thread, it's not allowed to processing more than 8 half data"
);
static_assert
(
tx_float_count
*
128
<
LDS_SIZE
&&
"for each thread, it's not allowed to processing more than 8 half data"
);
// offset from the next split for output from previous kernel, split * (batch, head,seq) * headdim
// offset from the next split for output from previous kernel, split * (batch, head,seq) * headdim
int
oaccum_stride
=
s_m_split_stride
*
kHeadDim
;
int
64_t
oaccum_stride
=
static_cast
<
int64_t
>
(
s_m_split_stride
)
*
kHeadDim
;
// each wave read data from 0 in 128 halfs, and thus (tx % 64)
// each wave read data from 0 in 128 halfs, and thus (tx % 64)
// int tx_offset = block_x * kHeadDim + (tx & 63) * tx_float_count;
// int tx_offset = block_x * kHeadDim + (tx & 63) * tx_float_count;
int
in_batch_offset
=
block_x
-
bidb
*
params
.
h
*
params
.
seqlen_q
;
int
in_batch_offset
=
block_x
-
bidb
*
params
.
h
*
params
.
seqlen_q
;
int
bidh
=
in_batch_offset
/
params
.
seqlen_q
;
int
bidh
=
in_batch_offset
/
params
.
seqlen_q
;
int
bids
=
in_batch_offset
-
bidh
*
params
.
seqlen_q
;
int
bids
=
in_batch_offset
-
bidh
*
params
.
seqlen_q
;
int
real_block_x
=
params
.
layout
==
0
?
block_x
/*bhsd layout*/
:
bidb
*
params
.
seqlen_q
*
params
.
h
+
bids
*
params
.
h
+
bidh
/*bshd layout*/
;
int64_t
real_block_x
=
params
.
layout
==
0
?
static_cast
<
int64_t
>
(
block_x
)
/*bhsd layout*/
:
int
tx_offset
=
real_block_x
*
kHeadDim
+
(
tx
&
63
)
*
tx_float_count
;
static_cast
<
int64_t
>
(
bidb
)
*
params
.
seqlen_q
*
params
.
h
+
static_cast
<
int64_t
>
(
bids
)
*
params
.
h
+
bidh
/*bshd layout*/
;
int64_t
tx_offset
=
real_block_x
*
kHeadDim
+
(
tx
&
63
)
*
tx_float_count
;
int
begin
=
wave_id
<<
6
;
int
begin
=
wave_id
<<
6
;
reduceType
*
output_ptr
=
reinterpret_cast
<
reduceType
*>
(
params
.
o_ptr
)
+
tx_offset
;
reduceType
*
output_ptr
=
reinterpret_cast
<
reduceType
*>
(
params
.
o_ptr
)
+
tx_offset
;
// for wave 0, splits [0, 63]; for wave 1, splits [64, 127]; for wave 2, splits [128, 191] ......
// for wave 0, splits [0, 63]; for wave 1, splits [64, 127]; for wave 2, splits [128, 191] ......
...
@@ -518,13 +520,14 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel(
...
@@ -518,13 +520,14 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel(
constexpr
int
tx_float_count
=
(
kHeadDim
>>
2
)
>>
6
;
constexpr
int
tx_float_count
=
(
kHeadDim
>>
2
)
>>
6
;
float
tx_accum
[
tx_float_count
]
=
{
0.
f
};
float
tx_accum
[
tx_float_count
]
=
{
0.
f
};
// offset from the next split for output from previous kernel, split * (batch, head,seq) * headdim
// offset from the next split for output from previous kernel, split * (batch, head,seq) * headdim
int
oaccum_stride
=
s_m_split_stride
*
kHeadDim
;
int
64_t
oaccum_stride
=
static_cast
<
int64_t
>
(
s_m_split_stride
)
*
kHeadDim
;
// int tx_offset= block_x * kHeadDim + tx * tx_float_count;
// int tx_offset= block_x * kHeadDim + tx * tx_float_count;
int
in_batch_offset
=
block_x
-
bidb
*
h
*
seqlen_q
;
int
in_batch_offset
=
block_x
-
bidb
*
h
*
seqlen_q
;
int
bidh
=
in_batch_offset
/
seqlen_q
;
int
bidh
=
in_batch_offset
/
seqlen_q
;
int
bids
=
in_batch_offset
-
bidh
*
seqlen_q
;
int
bids
=
in_batch_offset
-
bidh
*
seqlen_q
;
int
real_block_x
=
layout
==
0
?
block_x
/*bhsd layout*/
:
bidb
*
seqlen_q
*
h
+
bids
*
h
+
bidh
/*bshd layout*/
;
int64_t
real_block_x
=
layout
==
0
?
static_cast
<
int64_t
>
(
block_x
)
/*bhsd layout*/
:
int
tx_offset
=
real_block_x
*
kHeadDim
+
tx
*
tx_float_count
+
blockIdx
.
y
*
(
kHeadDim
>>
2
)
+
min
(
wave_id
,
num_splits
-
1
)
*
oaccum_stride
;
static_cast
<
int64_t
>
(
bidb
)
*
seqlen_q
*
h
+
static_cast
<
int64_t
>
(
bids
)
*
h
+
bidh
/*bshd layout*/
;
int64_t
tx_offset
=
real_block_x
*
kHeadDim
+
tx
*
tx_float_count
+
blockIdx
.
y
*
(
kHeadDim
>>
2
)
+
min
(
wave_id
,
num_splits
-
1
)
*
oaccum_stride
;
reduceType
*
output_ptr
=
reinterpret_cast
<
reduceType
*>
(
o_ptr
)
+
tx_offset
;
reduceType
*
output_ptr
=
reinterpret_cast
<
reduceType
*>
(
o_ptr
)
+
tx_offset
;
// fetch all data into vgprs
// fetch all data into vgprs
constexpr
int
SPLITS_PER_WAVE
=
std
::
max
<
int32_t
>
(
1
,
num_splits
>>
2
);
constexpr
int
SPLITS_PER_WAVE
=
std
::
max
<
int32_t
>
(
1
,
num_splits
>>
2
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment