Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
2c35de66
Commit
2c35de66
authored
Jun 04, 2026
by
shenzhe
Committed by
zhanghj2
Jun 06, 2026
Browse files
Add DSA BF16 sparse decode support
parent
a1eef562
Changes
13
Expand all
Hide whitespace changes
Inline
Side-by-side
Showing
13 changed files
with
640 additions
and
158 deletions
+640
-158
csrc/api/sparse_decode.h
csrc/api/sparse_decode.h
+9
-0
csrc/gfx93/decode/sparse_bf16_dsa/fwd.cu
csrc/gfx93/decode/sparse_bf16_dsa/fwd.cu
+275
-0
csrc/gfx93/decode/sparse_bf16_dsa/fwd.h
csrc/gfx93/decode/sparse_bf16_dsa/fwd.h
+24
-0
csrc/gfx93/prefill/sparse/dsa_mls/dispatch.h
csrc/gfx93/prefill/sparse/dsa_mls/dispatch.h
+102
-14
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/flash.h
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/flash.h
+1
-0
csrc/gfx93/prefill/sparse/dsa_mls/legacy/src/flash_fwd_b16_mla.h
...x93/prefill/sparse/dsa_mls/legacy/src/flash_fwd_b16_mla.h
+73
-68
csrc/gfx93/prefill/sparse/dsa_mls/legacy/src/flash_fwd_launch_template_mla.h
...sparse/dsa_mls/legacy/src/flash_fwd_launch_template_mla.h
+8
-2
csrc/gfx93/prefill/sparse/dsa_mls/legacy/src/flash_fwd_reduce.h
...fx93/prefill/sparse/dsa_mls/legacy/src/flash_fwd_reduce.h
+60
-59
flash_mla/flash_mla_interface.py
flash_mla/flash_mla_interface.py
+5
-2
setup.py
setup.py
+1
-0
tests/lib.py
tests/lib.py
+26
-5
tests/test_flash_mla_sparse_decoding.py
tests/test_flash_mla_sparse_decoding.py
+27
-4
tests/test_flash_mla_sparse_prefill.py
tests/test_flash_mla_sparse_prefill.py
+29
-4
No files found.
csrc/api/sparse_decode.h
View file @
2c35de66
...
...
@@ -6,6 +6,7 @@
#include "params.h"
#include "gfx93/decode/sparse_fp8/splitkv_mla.h"
#include "gfx93/decode/sparse_bf16_dsa/fwd.h"
#include "gfx9/decode/get_decoding_sched_meta/get_decoding_sched_meta.h"
#include "gfx9/decode/combine/combine.h"
...
...
@@ -123,6 +124,14 @@ sparse_attn_decode_interface(
bool
have_extra_topk_length
=
extra_topk_length
.
has_value
();
bool
have_attn_sink
=
attn_sink
.
has_value
();
if
(
kv
.
dtype
()
==
torch
::
kBFloat16
)
{
return
gfx93
::
decode
::
sparse_bf16_dsa
::
run
(
q
,
kv
,
indices
,
topk_length
,
attn_sink
,
tile_scheduler_metadata
,
num_splits
,
extra_kv
,
extra_indices
,
extra_topk_length
,
d_v
,
sm_scale
);
}
int
extra_num_blocks
=
0
,
extra_page_block_size
=
0
,
extra_topk
=
0
;
if
(
have_extra_kcache
)
{
extra_num_blocks
=
extra_kv
->
size
(
0
);
...
...
csrc/gfx93/decode/sparse_bf16_dsa/fwd.cu
0 → 100644
View file @
2c35de66
#include "fwd.h"
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <torch/extension.h>
#include <algorithm>
#include <cstring>
#include <limits>
#include <optional>
#include <tuple>
#include "kerutils/supplemental/torch_tensors.h"
#include "gfx93/prefill/sparse/dsa_mls/dispatch.h"
namespace
gfx93
::
decode
::
sparse_bf16_dsa
{
static
constexpr
float
LOG_2_E
=
1.44269504
f
;
struct
LocalArch
{
int
num_sms
;
std
::
string
arch_name
;
LocalArch
()
{
auto
*
props
=
at
::
cuda
::
getCurrentDeviceProperties
();
num_sms
=
props
->
multiProcessorCount
;
arch_name
=
props
->
gcnArchName
;
}
bool
is_gfx93x
()
const
{
const
auto
base
=
arch_name
.
substr
(
0
,
arch_name
.
find
(
':'
));
return
base
==
"gfx936"
||
base
==
"gfx938"
;
}
};
static
int
int64_stride_to_int
(
int64_t
stride
)
{
TORCH_CHECK
(
stride
<=
std
::
numeric_limits
<
int
>::
max
(),
"DSA BF16 sparse decode stride exceeds int32 limit: "
,
stride
);
return
static_cast
<
int
>
(
stride
);
}
static
int
default_num_splits
(
int
topk
,
int
extra_topk
)
{
if
(
extra_topk
>
0
)
{
return
2
;
}
if
(
topk
==
1024
)
return
16
;
if
(
topk
==
512
)
return
8
;
return
1
;
}
static
void
check_optional_extra
(
const
std
::
optional
<
at
::
Tensor
>&
extra_kv
,
const
std
::
optional
<
at
::
Tensor
>&
extra_indices
,
const
std
::
optional
<
at
::
Tensor
>&
extra_topk_length
)
{
if
(
extra_kv
.
has_value
())
{
TORCH_CHECK
(
extra_indices
.
has_value
(),
"extra_indices_in_kvcache must be provided when extra_k_cache is provided"
);
}
else
{
TORCH_CHECK
(
!
extra_indices
.
has_value
(),
"extra_indices_in_kvcache must not be provided when extra_k_cache is not provided"
);
TORCH_CHECK
(
!
extra_topk_length
.
has_value
(),
"extra_topk_length must not be provided when extra_k_cache is not provided"
);
}
}
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
,
std
::
optional
<
at
::
Tensor
>
,
std
::
optional
<
at
::
Tensor
>>
run
(
const
at
::
Tensor
&
q
,
const
at
::
Tensor
&
kv
,
const
at
::
Tensor
&
indices
,
const
std
::
optional
<
at
::
Tensor
>&
topk_length
,
const
std
::
optional
<
at
::
Tensor
>&
attn_sink
,
std
::
optional
<
at
::
Tensor
>&
tile_scheduler_metadata
,
std
::
optional
<
at
::
Tensor
>&
num_splits
,
const
std
::
optional
<
at
::
Tensor
>&
extra_kv
,
const
std
::
optional
<
at
::
Tensor
>&
extra_indices
,
const
std
::
optional
<
at
::
Tensor
>&
extra_topk_length
,
int
d_v
,
float
sm_scale
)
{
LocalArch
arch
;
TORCH_CHECK
(
arch
.
is_gfx93x
(),
"DSA BF16 sparse decode is only supported on gfx936/gfx938"
);
KU_CHECK_NDIM
(
q
,
4
);
KU_CHECK_NDIM
(
kv
,
4
);
KU_CHECK_NDIM
(
indices
,
3
);
if
(
extra_kv
.
has_value
())
KU_CHECK_NDIM
(
extra_kv
,
4
);
if
(
extra_indices
.
has_value
())
KU_CHECK_NDIM
(
extra_indices
,
3
);
const
int
b
=
q
.
size
(
0
);
const
int
s_q
=
q
.
size
(
1
);
const
int
h_q
=
q
.
size
(
2
);
const
int
d_qk
=
q
.
size
(
3
);
const
int
page_block_size
=
kv
.
size
(
1
);
const
int
h_kv
=
kv
.
size
(
2
);
const
int
topk
=
indices
.
size
(
2
);
const
bool
has_extra
=
extra_kv
.
has_value
()
&&
extra_indices
.
has_value
()
&&
extra_kv
->
numel
()
>
0
&&
extra_indices
->
numel
()
>
0
&&
extra_indices
->
size
(
2
)
>
0
;
const
int
extra_topk
=
has_extra
?
extra_indices
->
size
(
2
)
:
0
;
TORCH_CHECK
(
b
>
0
&&
s_q
>
0
&&
h_q
>
0
,
"Invalid q shape for DSA BF16 sparse decode"
);
TORCH_CHECK
(
h_kv
==
1
,
"DSA BF16 sparse decode only supports h_kv == 1"
);
TORCH_CHECK
(
h_q
==
64
||
h_q
==
128
,
"DSA BF16 sparse decode only supports h_q == 64 or 128"
);
TORCH_CHECK
(
d_qk
==
512
||
d_qk
==
576
,
"DSA BF16 sparse decode only supports d_qk == 512 or 576"
);
TORCH_CHECK
(
d_v
==
512
,
"DSA BF16 sparse decode only supports d_v == 512"
);
TORCH_CHECK
(
topk
>
0
,
"topk must be positive"
);
if
(
has_extra
)
{
TORCH_CHECK
(
topk
<=
256
,
"DSA BF16 sparse decode with extra_kv supports topk <= 256"
);
TORCH_CHECK
(
extra_topk
<=
1024
,
"DSA BF16 sparse decode supports extra_topk <= 1024"
);
TORCH_CHECK
(
extra_kv
->
size
(
1
)
>
0
,
"extra page_block_size must be positive"
);
TORCH_CHECK
(
extra_kv
->
size
(
2
)
==
h_kv
,
"extra_kv h_kv must match kv h_kv"
);
TORCH_CHECK
(
extra_kv
->
size
(
3
)
==
d_qk
,
"extra_kv d_qk must match q d_qk"
);
}
else
{
TORCH_CHECK
(
topk
<=
1024
,
"DSA BF16 sparse decode supports topk <= 1024"
);
}
check_optional_extra
(
extra_kv
,
extra_indices
,
extra_topk_length
);
KU_CHECK_DEVICE
(
q
);
KU_CHECK_DEVICE
(
kv
);
KU_CHECK_DEVICE
(
indices
);
KU_CHECK_DEVICE
(
topk_length
);
KU_CHECK_DEVICE
(
attn_sink
);
KU_CHECK_DEVICE
(
tile_scheduler_metadata
);
KU_CHECK_DEVICE
(
num_splits
);
KU_CHECK_DEVICE
(
extra_kv
);
KU_CHECK_DEVICE
(
extra_indices
);
KU_CHECK_DEVICE
(
extra_topk_length
);
KU_CHECK_DTYPE
(
q
,
torch
::
kBFloat16
);
KU_CHECK_DTYPE
(
kv
,
torch
::
kBFloat16
);
KU_CHECK_DTYPE
(
indices
,
torch
::
kInt32
);
KU_CHECK_DTYPE
(
topk_length
,
torch
::
kInt32
);
KU_CHECK_DTYPE
(
attn_sink
,
torch
::
kFloat32
);
KU_CHECK_DTYPE
(
tile_scheduler_metadata
,
torch
::
kInt32
);
KU_CHECK_DTYPE
(
num_splits
,
torch
::
kInt32
);
KU_CHECK_DTYPE
(
extra_kv
,
torch
::
kBFloat16
);
KU_CHECK_DTYPE
(
extra_indices
,
torch
::
kInt32
);
KU_CHECK_DTYPE
(
extra_topk_length
,
torch
::
kInt32
);
KU_CHECK_LAST_DIM_CONTIGUOUS
(
q
);
KU_CHECK_LAST_DIM_CONTIGUOUS
(
kv
);
KU_CHECK_LAST_DIM_CONTIGUOUS
(
indices
);
KU_CHECK_CONTIGUOUS
(
topk_length
);
KU_CHECK_CONTIGUOUS
(
attn_sink
);
KU_CHECK_LAST_DIM_CONTIGUOUS
(
extra_kv
);
KU_CHECK_LAST_DIM_CONTIGUOUS
(
extra_indices
);
KU_CHECK_CONTIGUOUS
(
extra_topk_length
);
KU_CHECK_SHAPE
(
q
,
b
,
s_q
,
h_q
,
d_qk
);
KU_CHECK_SHAPE
(
kv
,
kv
.
size
(
0
),
page_block_size
,
h_kv
,
d_qk
);
KU_CHECK_SHAPE
(
indices
,
b
,
s_q
,
topk
);
KU_CHECK_SHAPE
(
topk_length
,
b
);
KU_CHECK_SHAPE
(
attn_sink
,
h_q
);
if
(
has_extra
)
{
KU_CHECK_SHAPE
(
extra_indices
,
b
,
s_q
,
extra_topk
);
KU_CHECK_SHAPE
(
extra_topk_length
,
b
);
}
at
::
Tensor
indices_for_dsa
=
indices
.
unsqueeze
(
2
);
at
::
Tensor
extra_indices_for_dsa
;
if
(
has_extra
)
{
extra_indices_for_dsa
=
extra_indices
->
unsqueeze
(
2
);
}
c10
::
cuda
::
CUDAGuard
device_guard
{
q
.
device
()};
auto
opts
=
q
.
options
();
at
::
Tensor
out
=
torch
::
empty
({
b
,
s_q
,
h_q
,
d_v
},
opts
);
at
::
Tensor
lse
=
torch
::
empty
({
b
,
h_q
,
s_q
},
opts
.
dtype
(
at
::
kFloat
));
at
::
Tensor
scores_memory
=
torch
::
empty
({
2
,
b
,
h_kv
,
s_q
*
h_q
},
opts
.
dtype
(
at
::
kFloat
));
at
::
Tensor
scores_max
=
scores_memory
.
select
(
0
,
0
);
at
::
Tensor
scores_sum
=
scores_memory
.
select
(
0
,
1
);
if
(
!
num_splits
.
has_value
())
{
const
int
split
=
default_num_splits
(
topk
,
extra_topk
);
num_splits
=
torch
::
empty
({
1
},
opts
.
dtype
(
torch
::
kInt32
));
num_splits
->
fill_
(
split
);
}
KU_CHECK_DTYPE
(
num_splits
,
torch
::
kInt32
);
KU_CHECK_DEVICE
(
num_splits
);
KU_CHECK_CONTIGUOUS
(
num_splits
);
TORCH_CHECK
(
num_splits
->
numel
()
==
1
,
"DSA BF16 sparse decode expects num_splits to be a scalar tensor"
);
const
int
requested_num_splits
=
num_splits
->
item
<
int
>
();
TORCH_CHECK
(
requested_num_splits
>=
1
&&
requested_num_splits
<=
64
,
"DSA BF16 sparse decode requires 1 <= num_splits <= 64"
);
Flash_fwd_mla_params_dsa
params
;
std
::
memset
(
&
params
,
0
,
sizeof
(
params
));
params
.
layout
=
1
;
params
.
b
=
b
;
params
.
h
=
h_kv
;
params
.
h_k
=
h_kv
;
params
.
h_h_k_ratio
=
1
;
params
.
mtp
=
1
;
params
.
ngroups
=
h_q
/
h_kv
;
params
.
topk
=
topk
;
params
.
extra_topk
=
has_extra
?
extra_topk
:
0
;
params
.
d
=
d_qk
;
params
.
d_v
=
d_v
;
params
.
scale_softmax
=
sm_scale
;
params
.
scale_softmax_log2
=
sm_scale
*
LOG_2_E
;
params
.
topk_length
=
ku
::
get_optional_tensor_ptr
<
int
>
(
topk_length
);
params
.
extra_topk_length
=
ku
::
get_optional_tensor_ptr
<
int
>
(
extra_topk_length
);
params
.
attn_sink
=
ku
::
get_optional_tensor_ptr
<
float
>
(
attn_sink
);
params
.
q_ptr
=
q
.
data_ptr
();
params
.
k_ptr
=
kv
.
data_ptr
();
params
.
v_ptr
=
kv
.
data_ptr
();
params
.
extra_k_ptr
=
has_extra
?
extra_kv
->
data_ptr
()
:
nullptr
;
params
.
extra_v_ptr
=
has_extra
?
extra_kv
->
data_ptr
()
:
nullptr
;
params
.
o_ptr
=
out
.
data_ptr
();
params
.
sparse_indices
=
reinterpret_cast
<
int
*>
(
indices_for_dsa
.
data_ptr
());
params
.
extra_sparse_indices
=
has_extra
?
reinterpret_cast
<
int
*>
(
extra_indices_for_dsa
.
data_ptr
())
:
nullptr
;
params
.
softmax_lse_ptr
=
lse
.
data_ptr
<
float
>
();
params
.
scores_max_ptr
=
scores_max
.
data_ptr
<
float
>
();
params
.
scores_sum_ptr
=
scores_sum
.
data_ptr
<
float
>
();
params
.
page_block_size
=
page_block_size
;
params
.
extra_page_block_size
=
has_extra
?
extra_kv
->
size
(
1
)
:
0
;
params
.
is_causal
=
false
;
params
.
q_batch_stride
=
int64_stride_to_int
(
q
.
stride
(
0
));
params
.
q_token_stride
=
int64_stride_to_int
(
q
.
stride
(
1
));
params
.
q_row_stride
=
int64_stride_to_int
(
q
.
stride
(
2
));
params
.
q_head_stride
=
int64_stride_to_int
(
q
.
stride
(
2
));
params
.
k_batch_stride
=
int64_stride_to_int
(
kv
.
stride
(
0
));
params
.
k_row_stride
=
int64_stride_to_int
(
kv
.
stride
(
1
));
params
.
k_head_stride
=
int64_stride_to_int
(
kv
.
stride
(
2
));
params
.
v_batch_stride
=
params
.
k_batch_stride
;
params
.
v_row_stride
=
params
.
k_row_stride
;
params
.
v_head_stride
=
params
.
k_head_stride
;
params
.
extra_k_batch_stride
=
has_extra
?
int64_stride_to_int
(
extra_kv
->
stride
(
0
))
:
0
;
params
.
extra_k_row_stride
=
has_extra
?
int64_stride_to_int
(
extra_kv
->
stride
(
1
))
:
0
;
params
.
extra_v_batch_stride
=
params
.
extra_k_batch_stride
;
params
.
extra_v_row_stride
=
params
.
extra_k_row_stride
;
params
.
sparse_indices_batch_stride
=
int64_stride_to_int
(
indices_for_dsa
.
stride
(
0
));
params
.
sparse_indices_row_stride
=
int64_stride_to_int
(
indices_for_dsa
.
stride
(
1
));
params
.
sparse_indices_head_stride
=
int64_stride_to_int
(
indices_for_dsa
.
stride
(
2
));
params
.
sparse_indices_topk_stride
=
int64_stride_to_int
(
indices_for_dsa
.
stride
(
3
));
params
.
extra_sparse_indices_batch_stride
=
has_extra
?
int64_stride_to_int
(
extra_indices_for_dsa
.
stride
(
0
))
:
0
;
params
.
extra_sparse_indices_row_stride
=
has_extra
?
int64_stride_to_int
(
extra_indices_for_dsa
.
stride
(
1
))
:
0
;
params
.
extra_sparse_indices_head_stride
=
has_extra
?
int64_stride_to_int
(
extra_indices_for_dsa
.
stride
(
2
))
:
0
;
params
.
extra_sparse_indices_topk_stride
=
has_extra
?
int64_stride_to_int
(
extra_indices_for_dsa
.
stride
(
3
))
:
0
;
params
.
o_batch_stride
=
int64_stride_to_int
(
out
.
stride
(
0
));
params
.
o_row_stride
=
int64_stride_to_int
(
out
.
stride
(
1
));
params
.
o_head_stride
=
int64_stride_to_int
(
out
.
stride
(
2
));
params
.
seqlen_q
=
s_q
*
params
.
ngroups
;
params
.
seqlen_k
=
kv
.
size
(
0
)
*
kv
.
size
(
1
);
params
.
max_seqlen
=
s_q
;
params
.
is_bf16
=
true
;
params
.
is_e4m3
=
false
;
params
.
is_int8
=
false
;
params
.
cu_count
=
arch
.
num_sms
;
params
.
seqlenq_ngroups_swapped
=
true
;
params
.
is_seqlens_k_cumulative
=
false
;
params
.
splitkv_use_fp32_as_accum
=
false
;
params
.
num_splits
=
requested_num_splits
;
params
.
partition_size
=
topk
+
params
.
extra_topk
;
if
(
params
.
num_splits
>
1
)
{
params
.
partition_size
=
std
::
max
(
64
,
(
params
.
partition_size
+
params
.
num_splits
-
1
)
/
params
.
num_splits
);
params
.
partition_size
=
((
params
.
partition_size
+
63
)
/
64
)
*
64
;
}
at
::
Tensor
out_accum
;
at
::
Tensor
lse_accum
;
if
(
params
.
num_splits
>
1
)
{
lse_accum
=
torch
::
empty
({
params
.
num_splits
,
b
,
h_kv
,
params
.
seqlen_q
},
opts
.
dtype
(
at
::
kFloat
));
out_accum
=
torch
::
empty
({
params
.
num_splits
,
b
,
s_q
,
h_q
,
d_v
},
opts
);
params
.
softmax_lse_ptr
=
lse_accum
.
data_ptr
<
float
>
();
params
.
oaccum_ptr
=
out_accum
.
data_ptr
();
}
hipStream_t
stream
=
reinterpret_cast
<
hipStream_t
>
(
at
::
cuda
::
getCurrentCUDAStream
().
stream
());
if
(
d_qk
==
512
)
{
gfx93
::
fwd
::
dsa_mls
::
run_dsa_prefill_nopage_64_dispatch
<
BFloat16
,
512
,
512
>
(
params
,
stream
);
}
else
{
gfx93
::
fwd
::
dsa_mls
::
run_dsa_prefill_nopage_64_dispatch
<
BFloat16
,
576
,
512
>
(
params
,
stream
);
}
return
{
out
,
lse
,
tile_scheduler_metadata
,
num_splits
};
}
}
// namespace gfx93::decode::sparse_bf16_dsa
csrc/gfx93/decode/sparse_bf16_dsa/fwd.h
0 → 100644
View file @
2c35de66
#pragma once
#include <ATen/core/Tensor.h>
#include <optional>
#include <tuple>
namespace
gfx93
::
decode
::
sparse_bf16_dsa
{
std
::
tuple
<
at
::
Tensor
,
at
::
Tensor
,
std
::
optional
<
at
::
Tensor
>
,
std
::
optional
<
at
::
Tensor
>>
run
(
const
at
::
Tensor
&
q
,
const
at
::
Tensor
&
kv
,
const
at
::
Tensor
&
indices
,
const
std
::
optional
<
at
::
Tensor
>&
topk_length
,
const
std
::
optional
<
at
::
Tensor
>&
attn_sink
,
std
::
optional
<
at
::
Tensor
>&
tile_scheduler_metadata
,
std
::
optional
<
at
::
Tensor
>&
num_splits
,
const
std
::
optional
<
at
::
Tensor
>&
extra_kv
,
const
std
::
optional
<
at
::
Tensor
>&
extra_indices
,
const
std
::
optional
<
at
::
Tensor
>&
extra_topk_length
,
int
d_v
,
float
sm_scale
);
}
// namespace gfx93::decode::sparse_bf16_dsa
csrc/gfx93/prefill/sparse/dsa_mls/dispatch.h
View file @
2c35de66
...
...
@@ -9,8 +9,68 @@
#include "legacy/include/static_switch.h"
#include "legacy/src/flash_fwd_b16_mla.h"
#include "legacy/src/flash_fwd_reduce.h"
namespace
gfx93
::
fwd
::
dsa_mls
{
template
<
typename
Kernel_traits
,
const
bool
Tail
,
typename
Params
>
void
run_dsa_mla_splitkv_reduce
(
Params
&
params
,
hipStream_t
stream
)
{
static_assert
(
Kernel_traits
::
kHeadDimV
==
512
,
"run_dsa_mla_splitkv_reduce only supports hdimv == 512"
);
using
Element
=
typename
Kernel_traits
::
Element
;
using
SplitkvAccumType
=
typename
Kernel_traits
::
SplitkvAccumType
;
Flash_fwd_mla_reduce_params
reduce_params
;
reduce_params
.
softmax_lse_ptr
=
params
.
softmax_lse_ptr
;
reduce_params
.
oaccum_ptr
=
params
.
oaccum_ptr
;
reduce_params
.
o_ptr
=
params
.
o_ptr
;
reduce_params
.
cu_seqlens_k
=
params
.
cu_seqlens_k
;
reduce_params
.
num_splits
=
params
.
num_splits
;
reduce_params
.
partition_size
=
params
.
partition_size
;
reduce_params
.
h
=
params
.
h
;
reduce_params
.
ngroups
=
params
.
ngroups
;
reduce_params
.
seqlen_q
=
params
.
seqlen_q
;
reduce_params
.
layout
=
params
.
layout
;
reduce_params
.
topk_length
=
params
.
topk_length
;
reduce_params
.
attn_sink
=
params
.
attn_sink
;
reduce_params
.
extra_topk_length
=
params
.
extra_topk_length
;
reduce_params
.
topk
=
params
.
topk
;
reduce_params
.
extra_topk
=
params
.
extra_topk
;
if
(
params
.
num_splits
>
1
)
{
dim3
block
(
256
);
dim3
grid
(
params
.
b
*
params
.
h
*
params
.
seqlen_q
,
4
);
constexpr
int
MAX_NUM_SPLITS
=
64
;
if
(
params
.
num_splits
>
MAX_NUM_SPLITS
)
{
printf
(
"
\x1b
[31mnum_splits %d is larger than limit %d, and thus won't execute the kernel
\033
[0m
\n
"
,
params
.
num_splits
,
MAX_NUM_SPLITS
);
return
;
}
if
(
params
.
num_splits
==
2
)
{
::
flash_mla_splitkv_reduce_kernel
<
SplitkvAccumType
,
Element
,
2
,
true
,
Tail
,
Kernel_traits
::
kHeadDimV
>
<<<
grid
,
block
,
0
,
stream
>>>
(
reduce_params
);
}
else
if
(
params
.
num_splits
==
4
)
{
::
flash_mla_splitkv_reduce_kernel
<
SplitkvAccumType
,
Element
,
4
,
true
,
Tail
,
Kernel_traits
::
kHeadDimV
>
<<<
grid
,
block
,
0
,
stream
>>>
(
reduce_params
);
}
else
if
(
params
.
num_splits
==
8
)
{
::
flash_mla_splitkv_reduce_kernel
<
SplitkvAccumType
,
Element
,
8
,
true
,
Tail
,
Kernel_traits
::
kHeadDimV
>
<<<
grid
,
block
,
0
,
stream
>>>
(
reduce_params
);
}
else
if
(
params
.
num_splits
==
16
)
{
::
flash_mla_splitkv_reduce_kernel
<
SplitkvAccumType
,
Element
,
16
,
true
,
Tail
,
Kernel_traits
::
kHeadDimV
>
<<<
grid
,
block
,
0
,
stream
>>>
(
reduce_params
);
}
else
if
(
params
.
num_splits
==
32
)
{
::
flash_mla_splitkv_reduce_kernel
<
SplitkvAccumType
,
Element
,
32
,
true
,
Tail
,
Kernel_traits
::
kHeadDimV
>
<<<
grid
,
block
,
0
,
stream
>>>
(
reduce_params
);
}
else
if
(
params
.
num_splits
==
64
)
{
::
flash_mla_splitkv_reduce_kernel
<
SplitkvAccumType
,
Element
,
64
,
true
,
Tail
,
Kernel_traits
::
kHeadDimV
>
<<<
grid
,
block
,
0
,
stream
>>>
(
reduce_params
);
}
else
{
printf
(
"
\x1b
[31mnum_splits %d is not supported yet, and thus won't execute the kernel
\033
[0m
\n
"
,
params
.
num_splits
);
}
}
}
template
<
typename
T
,
int
Headdim
,
int
HeaddimV
>
void
run_dsa_prefill_nopage_64_dispatch
(
Flash_fwd_mla_params_dsa
&
params
,
hipStream_t
stream
)
{
constexpr
int
kBlockM
=
64
;
...
...
@@ -34,21 +94,49 @@ void run_dsa_prefill_nopage_64_dispatch(Flash_fwd_mla_params_dsa& params, hipStr
constexpr
bool
Is_dropout
=
false
;
constexpr
bool
IsEvenMNConst
=
false
;
BOOL_SWITCH
(
params
.
mtp
>
1
,
Is_MTP
,
[
&
]
{
BOOL_SWITCH
(
params
.
is_causal
,
Is_causal
,
[
&
]
{
if
(
params
.
topk
==
2048
)
{
flash
::
flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64
<
Kernel_traits
,
true
,
Is_dropout
,
false
,
Is_causal
,
IsEvenMNConst
,
true
,
false
,
Is_MTP
,
0
,
Flash_fwd_mla_params_dsa
>
<<<
dimGrid
,
dimBlock
,
21
*
1024
,
stream
>>>
(
params
);
}
else
{
flash
::
flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64_topk1024
<
Kernel_traits
,
true
,
Is_dropout
,
false
,
Is_causal
,
IsEvenMNConst
,
true
,
false
,
Is_MTP
,
0
,
Flash_fwd_mla_params_dsa
>
<<<
dimGrid
,
dimBlock
,
21
*
1024
,
stream
>>>
(
params
);
}
constexpr
int
REUSE_KV
=
1
;
const
bool
has_extra
=
params
.
extra_sparse_indices
!=
nullptr
&&
params
.
extra_topk
>
0
;
if
(
params
.
num_splits
==
1
)
{
BOOL_SWITCH
(
params
.
mtp
>
1
,
Is_MTP
,
[
&
]
{
BOOL_SWITCH
(
params
.
is_causal
,
Is_causal
,
[
&
]
{
BOOL_SWITCH
(
has_extra
,
Has_extra
,
[
&
]
{
flash
::
flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64
<
Kernel_traits
,
true
,
Is_dropout
,
false
,
Is_causal
,
IsEvenMNConst
,
true
,
false
,
Is_MTP
,
0
,
Has_extra
,
Flash_fwd_mla_params_dsa
>
<<<
dimGrid
,
dimBlock
,
21
*
1024
,
stream
>>>
(
params
);
});
});
});
}
else
if
(
params
.
num_splits
!=
0
)
{
dimGrid
.
y
=
params
.
num_splits
;
BOOL_SWITCH
(
params
.
mtp
>
1
,
Is_MTP
,
[
&
]
{
BOOL_SWITCH
(
params
.
is_causal
,
Is_causal
,
[
&
]
{
BOOL_SWITCH
(
has_extra
,
Has_extra
,
[
&
]
{
flash
::
flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64_splitkv
<
Kernel_traits
,
true
,
Is_dropout
,
false
,
Is_causal
,
IsEvenMNConst
,
true
,
false
,
Is_MTP
,
0
,
Has_extra
,
Flash_fwd_mla_params_dsa
>
<<<
dimGrid
,
dimBlock
,
21
*
1024
,
stream
>>>
(
params
);
});
});
});
run_dsa_mla_splitkv_reduce
<
Kernel_traits
,
false
>
(
params
,
stream
);
}
else
{
BOOL_SWITCH
(
params
.
mtp
>
1
,
Is_MTP
,
[
&
]
{
BOOL_SWITCH
(
params
.
is_causal
,
Is_causal
,
[
&
]
{
if
(
params
.
topk
==
2048
)
{
flash
::
flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64
<
Kernel_traits
,
true
,
Is_dropout
,
false
,
Is_causal
,
IsEvenMNConst
,
true
,
false
,
Is_MTP
,
0
,
Flash_fwd_mla_params_dsa
>
<<<
dimGrid
,
dimBlock
,
21
*
1024
,
stream
>>>
(
params
);
}
else
{
flash
::
flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64_topk1024
<
Kernel_traits
,
true
,
Is_dropout
,
false
,
Is_causal
,
IsEvenMNConst
,
true
,
false
,
Is_MTP
,
0
,
Flash_fwd_mla_params_dsa
>
<<<
dimGrid
,
dimBlock
,
21
*
1024
,
stream
>>>
(
params
);
}
});
});
}
);
}
}
}
// namespace gfx93::fwd::dsa_mls
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/flash.h
View file @
2c35de66
...
...
@@ -450,6 +450,7 @@ struct Flash_fwd_mla_reduce_params {
int
num_splits
;
int
partition_size
;
int
h
;
int
ngroups
;
int
seqlen_q
;
int
layout
;
float
*
attn_sink
;
...
...
csrc/gfx93/prefill/sparse/dsa_mls/legacy/src/flash_fwd_b16_mla.h
View file @
2c35de66
This diff is collapsed.
Click to expand it.
csrc/gfx93/prefill/sparse/dsa_mls/legacy/src/flash_fwd_launch_template_mla.h
View file @
2c35de66
...
...
@@ -37,6 +37,7 @@ void run_mla_splitkv_reduce(Params ¶ms, hipStream_t stream) {
reduce_params
.
num_splits
=
params
.
num_splits
;
reduce_params
.
partition_size
=
params
.
partition_size
;
reduce_params
.
h
=
params
.
h
;
reduce_params
.
ngroups
=
params
.
ngroups
;
reduce_params
.
seqlen_q
=
params
.
seqlen_q
;
reduce_params
.
layout
=
params
.
layout
;
reduce_params
.
topk_length
=
params
.
topk_length
;
...
...
@@ -556,11 +557,14 @@ void run_mla_fwd_dispatch_dsa_prefill_nopage_64(Flash_fwd_mla_params_dsa ¶ms
dimGrid
.
z
=
params
.
b
;
constexpr
bool
IsEvenMNConst
=
false
;
const
bool
has_extra
=
params
.
extra_sparse_indices
!=
nullptr
&&
params
.
extra_topk
>
0
;
if
(
params
.
num_splits
==
1
){
BOOL_SWITCH
(
params
.
mtp
>
1
/* is_mtp */
,
Is_MTP
,
[
&
]
{
BOOL_SWITCH
(
params
.
is_causal
,
Is_causal
,
[
&
]
{
flash
::
flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64
<
Kernel_traits
,
true
/*Is_training*/
,
Is_dropout
,
false
/* Is_prefix | flashmla */
,
Is_causal
,
IsEvenMNConst
,
/*Is_even_K*/
true
,
/*Return_softmax*/
false
,
Is_MTP
,
0
,
Flash_fwd_mla_params_dsa
>
BOOL_SWITCH
(
has_extra
,
Has_extra
,
[
&
]
{
flash
::
flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64
<
Kernel_traits
,
true
/*Is_training*/
,
Is_dropout
,
false
/* Is_prefix | flashmla */
,
Is_causal
,
IsEvenMNConst
,
/*Is_even_K*/
true
,
/*Return_softmax*/
false
,
Is_MTP
,
0
,
Has_extra
,
Flash_fwd_mla_params_dsa
>
<<<
dimGrid
,
dimBlock
,
21
*
1024
,
stream
>>>
(
params
);
});
});
});
}
...
...
@@ -570,8 +574,10 @@ void run_mla_fwd_dispatch_dsa_prefill_nopage_64(Flash_fwd_mla_params_dsa ¶ms
dimGrid
.
z
=
params
.
b
;
BOOL_SWITCH
(
params
.
mtp
>
1
/* is_mtp */
,
Is_MTP
,
[
&
]
{
BOOL_SWITCH
(
params
.
is_causal
,
Is_causal
,
[
&
]
{
flash
::
flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64_splitkv
<
Kernel_traits
,
true
/*Is_training*/
,
Is_dropout
,
false
/* Is_prefix | flashmla */
,
Is_causal
,
IsEvenMNConst
,
/*Is_even_K*/
true
,
/*Return_softmax*/
false
,
Is_MTP
,
0
,
Flash_fwd_mla_params_dsa
>
BOOL_SWITCH
(
has_extra
,
Has_extra
,
[
&
]
{
flash
::
flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64_splitkv
<
Kernel_traits
,
true
/*Is_training*/
,
Is_dropout
,
false
/* Is_prefix | flashmla */
,
Is_causal
,
IsEvenMNConst
,
/*Is_even_K*/
true
,
/*Return_softmax*/
false
,
Is_MTP
,
0
,
Has_extra
,
Flash_fwd_mla_params_dsa
>
<<<
dimGrid
,
dimBlock
,
21
*
1024
,
stream
>>>
(
params
);
});
});
});
run_mla_splitkv_reduce
<
Kernel_traits
,
false
/*Tail*/
>
(
params
,
stream
);
...
...
csrc/gfx93/prefill/sparse/dsa_mls/legacy/src/flash_fwd_reduce.h
View file @
2c35de66
#pragma once
#include "numeric_types.h"
#include "splitkv.h"
#include "intrinsic.h"
////////////////////////////////////////////////////////////////////////////////////////////////////
...
...
@@ -29,7 +30,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel(
// compute partition_size when fix num_splits
int
partition_size
=
params
.
partition_size
>
MLA_MAX_SPLITS
?
splitkv_get_partitionsize_of_fix_numsplits
(
actual_seqlen_k
,
params
.
num_splits
)
:
params
.
partition_size
;
const
int
true_num_splits
=
Tail
?
max
(
1
,
floor_div
(
actual_seqlen_k
,
partition_size
))
:
ceil_div
(
actual_seqlen_k
,
partition_size
);
const
int
true_num_splits
=
Tail
?
max
(
1
,
flash
::
floor_div
(
actual_seqlen_k
,
partition_size
))
:
flash
::
ceil_div
(
actual_seqlen_k
,
partition_size
);
// const int true_num_splits = num_splits;
bool
exceed_split
=
(
tx
>=
true_num_splits
);
// process boundary
...
...
@@ -40,7 +41,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel(
float
s_max_tmp
=
s_max_load_ori
;
#pragma unroll
for
(
int
step
=
SPLIT_COUNT
>>
1
;
step
>
0
;
step
=
(
step
>>
1
))
{
s_max_tmp
=
max
(
s_max_tmp
,
__shfl_xor_tmp
(
s_max_tmp
,
step
));
s_max_tmp
=
max
(
s_max_tmp
,
flash
::
__shfl_xor_tmp
(
s_max_tmp
,
step
));
}
// compute rescale coefficient for max (numerator)
float
s_max_ratio
=
__expf
(
s_max_load_ori
-
s_max_tmp
);
...
...
@@ -50,7 +51,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel(
float
s_sum_tmp
=
s_sum_load_ori
*
s_max_ratio
;
#pragma unroll
for
(
int
step
=
SPLIT_COUNT
>>
1
;
step
>
0
;
step
=
(
step
>>
1
))
{
s_sum_tmp
=
s_sum_tmp
+
__shfl_xor_tmp
(
s_sum_tmp
,
step
);
s_sum_tmp
=
s_sum_tmp
+
flash
::
__shfl_xor_tmp
(
s_sum_tmp
,
step
);
}
// max-rescale coefficient x sum-rescale coefficient
...
...
@@ -81,18 +82,18 @@ __global__ void flash_fwd_splitkv_reduce_kernel(
// read ultimate scale value for current split
vec2_Element
<
accumType
>
load
=
*
(
vec2_Element
<
accumType
>*
)(
oaccum_ptr
+
tx_offset
+
t
);
// half -> float32, reduce precision loss
float
a_f32
=
within_splits
?
splitkv_upcast_to_f32
<
accumType
>
(
load
[
0
])
:
0.
f
;
float
b_f32
=
within_splits
?
splitkv_upcast_to_f32
<
accumType
>
(
load
[
1
])
:
0.
f
;
float
a_f32
=
within_splits
?
flash
::
splitkv_upcast_to_f32
<
accumType
>
(
load
[
0
])
:
0.
f
;
float
b_f32
=
within_splits
?
flash
::
splitkv_upcast_to_f32
<
accumType
>
(
load
[
1
])
:
0.
f
;
// do rescale and sum
tx_accum
[
t
]
=
__llvm_fma_f32
(
a_f32
,
s_scale
,
tx_accum
[
t
]);
tx_accum
[
t
+
1
]
=
__llvm_fma_f32
(
b_f32
,
s_scale
,
tx_accum
[
t
+
1
]);
tx_accum
[
t
]
=
flash
::
__llvm_fma_f32
(
a_f32
,
s_scale
,
tx_accum
[
t
]);
tx_accum
[
t
+
1
]
=
flash
::
__llvm_fma_f32
(
b_f32
,
s_scale
,
tx_accum
[
t
+
1
]);
}
else
if
constexpr
(
kHeadDim
==
64
)
{
// read ultimate scale value for current split
accumType
load
=
*
(
accumType
*
)(
oaccum_ptr
+
tx_offset
+
t
);
// half -> float32, reduce precision loss
float
load_f32
=
within_splits
?
splitkv_upcast_to_f32
<
accumType
>
(
load
)
:
0.
f
;
float
load_f32
=
within_splits
?
flash
::
splitkv_upcast_to_f32
<
accumType
>
(
load
)
:
0.
f
;
// do rescale and sum
tx_accum
[
t
]
=
__llvm_fma_f32
(
load_f32
,
s_scale
,
tx_accum
[
t
]);
tx_accum
[
t
]
=
flash
::
__llvm_fma_f32
(
load_f32
,
s_scale
,
tx_accum
[
t
]);
}
}
// switch to next split
...
...
@@ -103,14 +104,14 @@ __global__ void flash_fwd_splitkv_reduce_kernel(
if
constexpr
(
kHeadDim
%
128
==
0
)
{
vec2_Element
<
reduceType
>
accum_result
;
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
accum_result
=
DownCastPairNoPack
<
float
,
reduceType
>
(
tx_accum
[
t
],
tx_accum
[
t
+
1
]);
accum_result
=
flash
::
DownCastPairNoPack
<
float
,
reduceType
>
(
tx_accum
[
t
],
tx_accum
[
t
+
1
]);
#else
accum_result
[
0
]
=
DownCast
<
float
,
reduceType
,
true
>
(
tx_accum
[
t
]);
accum_result
[
1
]
=
DownCast
<
float
,
reduceType
,
true
>
(
tx_accum
[
t
+
1
]);
// here, v_cvt_pkrtz can be used
accum_result
[
0
]
=
flash
::
DownCast
<
float
,
reduceType
,
true
>
(
tx_accum
[
t
]);
accum_result
[
1
]
=
flash
::
DownCast
<
float
,
reduceType
,
true
>
(
tx_accum
[
t
+
1
]);
// here, v_cvt_pkrtz can be used
#endif
*
(
vec2_Element
<
reduceType
>*
)(
output_ptr
+
t
)
=
accum_result
;
}
else
if
constexpr
(
kHeadDim
==
64
)
{
reduceType
accum_result
=
DownCast
<
float
,
reduceType
,
false
>
(
tx_accum
[
t
]);
reduceType
accum_result
=
flash
::
DownCast
<
float
,
reduceType
,
false
>
(
tx_accum
[
t
]);
output_ptr
[
t
]
=
accum_result
;
}
}
...
...
@@ -146,7 +147,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel_split128(Params params) {
}
// compute partition_size when fix num_splits
int
partition_size
=
params
.
partition_size
>
MLA_MAX_SPLITS
?
splitkv_get_partitionsize_of_fix_numsplits
(
actual_seqlen_k
,
params
.
num_splits
)
:
params
.
partition_size
;
const
int
true_num_splits
=
Tail
?
max
(
1
,
floor_div
(
actual_seqlen_k
,
partition_size
))
:
ceil_div
(
actual_seqlen_k
,
partition_size
);
const
int
true_num_splits
=
Tail
?
max
(
1
,
flash
::
floor_div
(
actual_seqlen_k
,
partition_size
))
:
flash
::
ceil_div
(
actual_seqlen_k
,
partition_size
);
// const int true_num_splits = num_splits;
bool
exceed_split
=
(
tx
>=
true_num_splits
);
// process boundary
...
...
@@ -157,7 +158,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel_split128(Params params) {
float
s_max_tmp
=
s_max_load_ori
;
#pragma unroll
for
(
int
step
=
64
>>
1
;
step
>
0
;
step
=
(
step
>>
1
))
{
s_max_tmp
=
max
(
s_max_tmp
,
__shfl_xor_tmp
(
s_max_tmp
,
step
));
s_max_tmp
=
max
(
s_max_tmp
,
flash
::
__shfl_xor_tmp
(
s_max_tmp
,
step
));
}
// for multiple waves, store the reduced max value to lds individually, and recompute max across multiple waves
int
wave_id
=
(
tx
>>
6
);
...
...
@@ -181,7 +182,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel_split128(Params params) {
float
s_sum_tmp
=
s_sum_load_ori
*
s_max_ratio
;
#pragma unroll
for
(
int
step
=
64
>>
1
;
step
>
0
;
step
=
(
step
>>
1
))
{
s_sum_tmp
=
s_sum_tmp
+
__shfl_xor_tmp
(
s_sum_tmp
,
step
);
s_sum_tmp
=
s_sum_tmp
+
flash
::
__shfl_xor_tmp
(
s_sum_tmp
,
step
);
}
// for multiple wave, store the reduced sum value to lds individually, and recompute sum across multiple waves
lds
[
LDS_ACCUM
+
wave_id
]
=
s_sum_tmp
;
...
...
@@ -230,18 +231,18 @@ __global__ void flash_fwd_splitkv_reduce_kernel_split128(Params params) {
// read 2 halfs from current split of this threads
vec2_Element
<
accumType
>
load
=
*
(
vec2_Element
<
accumType
>*
)(
oaccum_ptr
+
tx_offset
+
t
);
// half -> float32, reduce precision loss
float
a_f32
=
within_splits
?
splitkv_upcast_to_f32
<
accumType
>
(
load
[
0
])
:
0.
f
;
float
b_f32
=
within_splits
?
splitkv_upcast_to_f32
<
accumType
>
(
load
[
1
])
:
0.
f
;
float
a_f32
=
within_splits
?
flash
::
splitkv_upcast_to_f32
<
accumType
>
(
load
[
0
])
:
0.
f
;
float
b_f32
=
within_splits
?
flash
::
splitkv_upcast_to_f32
<
accumType
>
(
load
[
1
])
:
0.
f
;
// do rescale and sum
tx_accum
[
t
]
=
__llvm_fma_f32
(
a_f32
,
s_scale
,
tx_accum
[
t
]);
tx_accum
[
t
+
1
]
=
__llvm_fma_f32
(
b_f32
,
s_scale
,
tx_accum
[
t
+
1
]);
tx_accum
[
t
]
=
flash
::
__llvm_fma_f32
(
a_f32
,
s_scale
,
tx_accum
[
t
]);
tx_accum
[
t
+
1
]
=
flash
::
__llvm_fma_f32
(
b_f32
,
s_scale
,
tx_accum
[
t
+
1
]);
}
else
if
constexpr
(
kHeadDim
==
64
)
{
// read 1 half from current split of this threads
accumType
load
=
*
(
accumType
*
)(
oaccum_ptr
+
tx_offset
+
t
);
// half -> float32, reduce precision loss
float
load_f32
=
within_splits
?
splitkv_upcast_to_f32
<
accumType
>
(
load
)
:
0.
f
;
float
load_f32
=
within_splits
?
flash
::
splitkv_upcast_to_f32
<
accumType
>
(
load
)
:
0.
f
;
// do rescale and sum
tx_accum
[
t
]
=
__llvm_fma_f32
(
load_f32
,
s_scale
,
tx_accum
[
t
]);
tx_accum
[
t
]
=
flash
::
__llvm_fma_f32
(
load_f32
,
s_scale
,
tx_accum
[
t
]);
}
}
// switch to next split
...
...
@@ -290,15 +291,15 @@ __global__ void flash_fwd_splitkv_reduce_kernel_split128(Params params) {
tx_accum
[
t
+
1
]
=
lds
[
tx
*
tx_float_count
+
t
+
1
];
vec2_Element
<
reduceType
>
accum_result
;
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
accum_result
=
DownCastPairNoPack
<
float
,
reduceType
>
(
tx_accum
[
t
],
tx_accum
[
t
+
1
]);
accum_result
=
flash
::
DownCastPairNoPack
<
float
,
reduceType
>
(
tx_accum
[
t
],
tx_accum
[
t
+
1
]);
#else
accum_result
[
0
]
=
DownCast
<
float
,
reduceType
,
true
>
(
tx_accum
[
t
]);
accum_result
[
1
]
=
DownCast
<
float
,
reduceType
,
true
>
(
tx_accum
[
t
+
1
]);
// here, v_cvt_pkrtz can be used
accum_result
[
0
]
=
flash
::
DownCast
<
float
,
reduceType
,
true
>
(
tx_accum
[
t
]);
accum_result
[
1
]
=
flash
::
DownCast
<
float
,
reduceType
,
true
>
(
tx_accum
[
t
+
1
]);
// here, v_cvt_pkrtz can be used
#endif
*
(
vec2_Element
<
reduceType
>*
)(
output_ptr
+
t
)
=
accum_result
;
}
else
if
constexpr
(
kHeadDim
==
64
)
{
tx_accum
[
t
]
=
lds
[
tx
*
tx_float_count
+
t
];
reduceType
accum_result
=
DownCast
<
float
,
reduceType
,
false
>
(
tx_accum
[
t
]);
reduceType
accum_result
=
flash
::
DownCast
<
float
,
reduceType
,
false
>
(
tx_accum
[
t
]);
*
(
reduceType
*
)(
output_ptr
+
t
)
=
accum_result
;
}
}
...
...
@@ -341,7 +342,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_reduce_varlen_kernel
// compute partition_size when fix num_splits
int
partition_size
=
splitkv_get_partitionsize_of_fix_numsplits
(
actual_seqlen_k
,
params
.
num_splits
);
const
int
true_num_splits
=
Tail
?
max
(
1
,
floor_div
(
actual_seqlen_k
,
partition_size
))
:
ceil_div
(
actual_seqlen_k
,
partition_size
);
const
int
true_num_splits
=
Tail
?
max
(
1
,
flash
::
floor_div
(
actual_seqlen_k
,
partition_size
))
:
flash
::
ceil_div
(
actual_seqlen_k
,
partition_size
);
// const int true_num_splits = num_splits;
bool
exceed_split
=
(
tx
>=
true_num_splits
);
// process boundary
...
...
@@ -399,13 +400,13 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_reduce_varlen_kernel
float
lse_max_local
=
lse_local
;
#pragma unroll
for
(
int
step
=
SPLIT_COUNT
>>
1
;
step
>
0
;
step
=
(
step
>>
1
))
{
lse_max_local
=
max
(
lse_max_local
,
__shfl_xor_tmp
(
lse_max_local
,
step
));
lse_max_local
=
max
(
lse_max_local
,
flash
::
__shfl_xor_tmp
(
lse_max_local
,
step
));
}
// reduce sum lse
float
lse_local_logsum
=
__expf
(
lse_local
-
lse_max_local
);
#pragma unroll
for
(
int
step
=
SPLIT_COUNT
>>
1
;
step
>
0
;
step
=
(
step
>>
1
))
{
lse_local_logsum
=
lse_local_logsum
+
__shfl_xor_tmp
(
lse_local_logsum
,
step
);
lse_local_logsum
=
lse_local_logsum
+
flash
::
__shfl_xor_tmp
(
lse_local_logsum
,
step
);
}
lse_local_logsum
=
__logf
(
lse_local_logsum
)
+
lse_max_local
;
...
...
@@ -427,16 +428,16 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_reduce_varlen_kernel
for
(
int
t
=
0
;
t
<
tx_float_count
;
t
+=
2
)
{
if
constexpr
(
kHeadDim
%
128
==
0
)
{
// half -> float32, reduce precision loss
float
a_f32
=
within_splits
?
splitkv_upcast_to_f32
<
accumType
>
(
load_vec
[
i
][
t
>>
1
][
0
])
:
0.
f
;
float
b_f32
=
within_splits
?
splitkv_upcast_to_f32
<
accumType
>
(
load_vec
[
i
][
t
>>
1
][
1
])
:
0.
f
;
float
a_f32
=
within_splits
?
flash
::
splitkv_upcast_to_f32
<
accumType
>
(
load_vec
[
i
][
t
>>
1
][
0
])
:
0.
f
;
float
b_f32
=
within_splits
?
flash
::
splitkv_upcast_to_f32
<
accumType
>
(
load_vec
[
i
][
t
>>
1
][
1
])
:
0.
f
;
// do rescale and sum
tx_accum
[
t
]
=
__llvm_fma_f32
(
a_f32
,
s_scale
,
tx_accum
[
t
]);
tx_accum
[
t
+
1
]
=
__llvm_fma_f32
(
b_f32
,
s_scale
,
tx_accum
[
t
+
1
]);
tx_accum
[
t
]
=
flash
::
__llvm_fma_f32
(
a_f32
,
s_scale
,
tx_accum
[
t
]);
tx_accum
[
t
+
1
]
=
flash
::
__llvm_fma_f32
(
b_f32
,
s_scale
,
tx_accum
[
t
+
1
]);
}
else
if
constexpr
(
kHeadDim
==
64
)
{
// half -> float32, reduce precision loss
float
load_f32
=
within_splits
?
splitkv_upcast_to_f32
<
accumType
>
(
load
[
i
][
t
>>
1
])
:
0.
f
;
float
load_f32
=
within_splits
?
flash
::
splitkv_upcast_to_f32
<
accumType
>
(
load
[
i
][
t
>>
1
])
:
0.
f
;
// do rescale and sum
tx_accum
[
t
]
=
__llvm_fma_f32
(
load_f32
,
s_scale
,
tx_accum
[
t
]);
tx_accum
[
t
]
=
flash
::
__llvm_fma_f32
(
load_f32
,
s_scale
,
tx_accum
[
t
]);
}
}
}
...
...
@@ -446,14 +447,14 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_reduce_varlen_kernel
if
constexpr
(
kHeadDim
%
128
==
0
)
{
vec2_Element
<
reduceType
>
accum_result
;
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
accum_result
=
DownCastPairNoPack
<
float
,
reduceType
>
(
tx_accum
[
t
],
tx_accum
[
t
+
1
]);
accum_result
=
flash
::
DownCastPairNoPack
<
float
,
reduceType
>
(
tx_accum
[
t
],
tx_accum
[
t
+
1
]);
#else
accum_result
[
0
]
=
DownCast
<
float
,
reduceType
,
true
>
(
tx_accum
[
t
]);
accum_result
[
1
]
=
DownCast
<
float
,
reduceType
,
true
>
(
tx_accum
[
t
+
1
]);
// here, v_cvt_pkrtz can be used
accum_result
[
0
]
=
flash
::
DownCast
<
float
,
reduceType
,
true
>
(
tx_accum
[
t
]);
accum_result
[
1
]
=
flash
::
DownCast
<
float
,
reduceType
,
true
>
(
tx_accum
[
t
+
1
]);
// here, v_cvt_pkrtz can be used
#endif
*
(
vec2_Element
<
reduceType
>*
)(
output_ptr
+
t
)
=
accum_result
;
}
else
if
constexpr
(
kHeadDim
==
64
)
{
reduceType
accum_result
=
DownCast
<
float
,
reduceType
,
false
>
(
tx_accum
[
t
]);
reduceType
accum_result
=
flash
::
DownCast
<
float
,
reduceType
,
false
>
(
tx_accum
[
t
]);
output_ptr
[
t
]
=
accum_result
;
}
}
...
...
@@ -501,15 +502,15 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel(
// int main_len = params.topk_length ? params.topk_length[row] : params.topk;
// int extra_len = params.extra_topk_length ? params.extra_topk_length[row] : params.extra_topk;
// int actual_seqlen_k = ceil_div(main_len, 64) * 64 + ceil_div(extra_len, 64) * 64;
// int actual_seqlen_k =
flash::
ceil_div(main_len, 64) * 64 +
flash::
ceil_div(extra_len, 64) * 64;
int
row
=
block_x
/
64
;
int
main_len
=
params
.
topk_length
?
params
.
topk_length
[
row
]
:
params
.
topk
;
int
extra_len
=
params
.
extra_topk_length
?
params
.
extra_topk_length
[
row
]
:
params
.
extra_topk
;
int
topk_length_row
=
h
==
1
?
bidb
:
block_x
/
64
;
int
main_len
=
params
.
topk_length
?
params
.
topk_length
[
topk_length_
row
]
:
params
.
topk
;
int
extra_len
=
params
.
extra_topk_length
?
params
.
extra_topk_length
[
topk_length_
row
]
:
params
.
extra_topk
;
int
total_blocks
=
ceil_div
(
main_len
,
64
)
+
ceil_div
(
extra_len
,
64
);
int
blocks_per_split
=
ceil_div
(
params
.
partition_size
,
64
);
int
true_num_splits
=
ceil_div
(
total_blocks
,
blocks_per_split
);
int
total_blocks
=
flash
::
ceil_div
(
main_len
,
64
)
+
flash
::
ceil_div
(
extra_len
,
64
);
int
blocks_per_split
=
flash
::
ceil_div
(
params
.
partition_size
,
64
);
int
true_num_splits
=
flash
::
ceil_div
(
total_blocks
,
blocks_per_split
);
// for flashmla, 512 elements are engaged to 4 blocks
// within each block, num_splits / WARM_NUM load transactions are engaged to each wave
...
...
@@ -539,7 +540,7 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel(
// compute partition_size when fix num_splits
// int partition_size = params_partition_size > MLA_MAX_SPLITS ? splitkv_get_partitionsize_of_fix_numsplits(actual_seqlen_k, params.num_splits): params_partition_size;
// const int true_num_splits = Tail ? max(1, floor_div(actual_seqlen_k, partition_size)): ceil_div(actual_seqlen_k, partition_size);
// const int true_num_splits = Tail ? max(1,
flash::
floor_div(actual_seqlen_k, partition_size)):
flash::
ceil_div(actual_seqlen_k, partition_size);
// const int true_num_splits = num_splits;
bool
exceed_split
=
(
tx
>=
true_num_splits
);
// process boundary
...
...
@@ -552,20 +553,20 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel(
float
lse_max_local
=
lse_local
;
#pragma unroll
for
(
int
step
=
SPLIT_COUNT
>>
1
;
step
>
0
;
step
=
(
step
>>
1
))
{
lse_max_local
=
max
(
lse_max_local
,
__shfl_xor_tmp
(
lse_max_local
,
step
));
lse_max_local
=
max
(
lse_max_local
,
flash
::
__shfl_xor_tmp
(
lse_max_local
,
step
));
}
// reduce sum lse
float
lse_local_logsum
=
__expf
(
lse_local
-
lse_max_local
);
#pragma unroll
for
(
int
step
=
SPLIT_COUNT
>>
1
;
step
>
0
;
step
=
(
step
>>
1
))
{
lse_local_logsum
=
lse_local_logsum
+
__shfl_xor_tmp
(
lse_local_logsum
,
step
);
lse_local_logsum
=
lse_local_logsum
+
flash
::
__shfl_xor_tmp
(
lse_local_logsum
,
step
);
}
lse_local_logsum
=
__logf
(
lse_local_logsum
)
+
lse_max_local
;
float
attn_sink_o_scale
=
1.0
f
;
if
(
params
.
attn_sink
!=
nullptr
)
{
// 当前 reduce kernel 的 block_x 是按 b,h,s 展开的,所以 bidh 就是 head id。
float
rAttn_sink
=
params
.
attn_sink
[
block_x
%
64
];
int
attn_sink_idx
=
h
==
1
?
in_batch_offset
%
params
.
ngroups
:
block_x
%
64
;
float
rAttn_sink
=
params
.
attn_sink
[
attn_sink_idx
];
if
(
rAttn_sink
==
INFINITY
)
{
attn_sink_o_scale
=
0.0
f
;
...
...
@@ -588,11 +589,11 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel(
#pragma unroll
for
(
int
t
=
0
;
t
<
tx_float_count
;
t
+=
2
)
{
// half -> float32, reduce precision loss
float
a_f32
=
within_splits
?
splitkv_upcast_to_f32
<
accumType
>
(
load
[
i
>>
2
][
t
>>
1
][
0
])
:
0.
f
;
float
b_f32
=
within_splits
?
splitkv_upcast_to_f32
<
accumType
>
(
load
[
i
>>
2
][
t
>>
1
][
1
])
:
0.
f
;
float
a_f32
=
within_splits
?
flash
::
splitkv_upcast_to_f32
<
accumType
>
(
load
[
i
>>
2
][
t
>>
1
][
0
])
:
0.
f
;
float
b_f32
=
within_splits
?
flash
::
splitkv_upcast_to_f32
<
accumType
>
(
load
[
i
>>
2
][
t
>>
1
][
1
])
:
0.
f
;
// do rescale and sum
tx_accum
[
t
]
=
__llvm_fma_f32
(
a_f32
,
s_scale
,
tx_accum
[
t
]);
tx_accum
[
t
+
1
]
=
__llvm_fma_f32
(
b_f32
,
s_scale
,
tx_accum
[
t
+
1
]);
tx_accum
[
t
]
=
flash
::
__llvm_fma_f32
(
a_f32
,
s_scale
,
tx_accum
[
t
]);
tx_accum
[
t
+
1
]
=
flash
::
__llvm_fma_f32
(
b_f32
,
s_scale
,
tx_accum
[
t
+
1
]);
}
}
// reduce across 4 waves
...
...
@@ -616,13 +617,13 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel(
// cvt
vec2_Element
<
reduceType
>
accum_result
;
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
accum_result
=
DownCastPairNoPack
<
float
,
reduceType
>
(
tx_accum
[
t
],
tx_accum
[
t
+
1
]);
accum_result
=
flash
::
DownCastPairNoPack
<
float
,
reduceType
>
(
tx_accum
[
t
],
tx_accum
[
t
+
1
]);
#else
accum_result
[
0
]
=
DownCast
<
float
,
reduceType
,
true
>
(
tx_accum
[
t
]);
accum_result
[
1
]
=
DownCast
<
float
,
reduceType
,
true
>
(
tx_accum
[
t
+
1
]);
accum_result
[
0
]
=
flash
::
DownCast
<
float
,
reduceType
,
true
>
(
tx_accum
[
t
]);
accum_result
[
1
]
=
flash
::
DownCast
<
float
,
reduceType
,
true
>
(
tx_accum
[
t
+
1
]);
#endif
// storation
*
(
vec2_Element
<
reduceType
>*
)(
output_ptr
+
t
)
=
accum_result
;
}
}
}
\ No newline at end of file
}
flash_mla/flash_mla_interface.py
View file @
2c35de66
...
...
@@ -151,7 +151,10 @@ def flash_mla_with_kvcache(
if
topk
is
not
None
:
# Sparse attention
assert
not
causal
,
"causal must be False when sparse attention is enabled"
assert
is_fp8_kvcache
,
"is_fp8_kvcache must be True when sparse attention is enabled"
if
not
is_fp8_kvcache
:
assert
k_cache
.
dtype
==
torch
.
bfloat16
,
"BF16 sparse attention requires k_cache dtype to be torch.bfloat16 when is_fp8_kvcache is False"
if
extra_k_cache
is
not
None
:
assert
extra_k_cache
.
dtype
==
torch
.
bfloat16
,
"BF16 sparse attention requires extra_k_cache dtype to be torch.bfloat16 when is_fp8_kvcache is False"
out
,
lse
,
new_tile_scheduler_metadata
,
new_num_splits
=
flash_mla_cuda
.
sparse_decode_fwd
(
q
,
k_cache
,
indices_in_kvcache
,
topk_length
,
attn_sink
,
sched_meta
.
tile_scheduler_metadata
,
sched_meta
.
num_splits
,
...
...
@@ -640,4 +643,4 @@ def flash_mla_with_kvcache_quantization_q_nope_pe(
# )
# sched_meta.tile_scheduler_metadata = new_tile_scheduler_metadata
# sched_meta.num_splits = new_num_splits
# return (out, lse)
\ No newline at end of file
# return (out, lse)
setup.py
View file @
2c35de66
...
...
@@ -93,6 +93,7 @@ ext_modules.append(
"csrc/gfx93/decode/sparse_fp8/instantiations/v32_persistent_h16.cu"
,
"csrc/gfx93/decode/sparse_fp8/instantiations/v32_persistent_h64.cu"
,
"csrc/gfx93/decode/sparse_fp8/instantiations/v32_persistent_h128.cu"
,
"csrc/gfx93/decode/sparse_bf16_dsa/fwd.cu"
,
# # gfx93 sparse prefill
"csrc/gfx93/prefill/sparse/fwd.cu"
,
...
...
tests/lib.py
View file @
2c35de66
...
...
@@ -14,6 +14,9 @@ class TestTarget(enum.Enum):
FWD
=
0
DECODE
=
1
def
is_decode_bf16_kvcache
()
->
bool
:
return
os
.
environ
.
get
(
"FLASH_MLA_DECODE_BF16"
,
""
).
lower
()
in
[
"1"
,
"true"
,
"yes"
,
"y"
,
"bf16"
]
@
dataclasses
.
dataclass
class
ExtraTestParamForDecode
:
b
:
int
...
...
@@ -42,6 +45,21 @@ class TestParam:
have_topk_length
:
bool
=
False
decode
:
Optional
[
ExtraTestParamForDecode
]
=
None
def
is_bf16_decode_supported_param
(
t
:
TestParam
)
->
bool
:
if
t
.
decode
is
None
:
return
False
if
t
.
is_all_indices_invalid
or
t
.
decode
.
have_zero_seqlen_k
:
return
False
if
t
.
h_kv
!=
1
or
t
.
d_v
!=
512
:
return
False
if
t
.
h_q
not
in
[
64
,
128
]:
return
False
if
t
.
d_qk
not
in
[
512
,
576
]:
return
False
if
t
.
decode
.
extra_topk
is
None
:
return
t
.
topk
<=
1024
return
t
.
topk
<=
256
and
t
.
decode
.
extra_topk
<=
1024
@
dataclasses
.
dataclass
class
RawTestParamForDecode
:
"""
...
...
@@ -289,14 +307,16 @@ def generate_testcase_for_decode(t: TestParam) -> TestcaseForDecode:
return
KVScope
(
t
,
cache_seqlens
,
block_table
,
blocked_k
,
abs_indices
,
indices_in_kvcache
,
topk_length
)
kv_scope0
=
generate_one_k_scope
(
t
.
s_kv
,
t
.
decode
.
block_size
,
t
.
topk
,
t
.
decode
.
is_varlen
,
t
.
decode
.
have_zero_seqlen_k
,
t
.
is_all_indices_invalid
,
t
.
have_topk_length
)
kv_scope0
.
quant_and_dequant_
()
if
not
is_decode_bf16_kvcache
():
kv_scope0
.
quant_and_dequant_
()
if
t
.
decode
.
extra_topk
is
not
None
:
if
t
.
decode
.
extra_s_k
is
None
:
t
.
decode
.
extra_s_k
=
t
.
decode
.
extra_topk
*
2
if
t
.
decode
.
extra_block_size
is
None
:
t
.
decode
.
extra_block_size
=
t
.
decode
.
block_size
kv_scope1
=
generate_one_k_scope
(
t
.
decode
.
extra_s_k
,
t
.
decode
.
extra_block_size
,
t
.
decode
.
extra_topk
,
t
.
decode
.
is_varlen
,
t
.
decode
.
have_zero_seqlen_k
,
t
.
is_all_indices_invalid
,
t
.
decode
.
have_extra_topk_length
)
kv_scope1
.
quant_and_dequant_
()
if
not
is_decode_bf16_kvcache
():
kv_scope1
.
quant_and_dequant_
()
else
:
assert
t
.
decode
.
extra_block_size
is
None
and
t
.
decode
.
extra_s_k
is
None
and
not
t
.
decode
.
have_extra_topk_length
kv_scope1
=
None
...
...
@@ -318,16 +338,17 @@ def run_flash_mla_sparse_fwd(p: TestParam, t: Testcase, return_p_sum: bool):
def
run_flash_mla_decode
(
p
:
TestParam
,
t
:
TestcaseForDecode
,
tile_scheduler_metadata
,
num_splits
):
assert
p
.
decode
is
not
None
is_fp8_kvcache
=
not
is_decode_bf16_kvcache
()
return
flash_mla
.
flash_mla_with_kvcache
(
t
.
q
,
t
.
kv_scope
.
get_kvcache_for_flash_mla
(),
t
.
kv_scope
.
get_kvcache_for_flash_mla
()
if
is_fp8_kvcache
else
t
.
kv_scope
.
blocked_k
,
None
,
None
,
p
.
d_v
,
tile_scheduler_metadata
,
num_splits
,
t
.
sm_scale
,
False
,
Tru
e
,
t
.
sm_scale
,
False
,
is_fp8_kvcach
e
,
t
.
kv_scope
.
indices_in_kvcache
,
t
.
attn_sink
,
t
.
extra_kv_scope
.
get_kvcache_for_flash_mla
()
if
t
.
extra_kv_scope
is
not
None
else
None
,
(
t
.
extra_kv_scope
.
get_kvcache_for_flash_mla
()
if
is_fp8_kvcache
else
t
.
extra_kv_scope
.
blocked_k
)
if
t
.
extra_kv_scope
is
not
None
else
None
,
t
.
extra_kv_scope
.
indices_in_kvcache
if
t
.
extra_kv_scope
is
not
None
else
None
,
t
.
kv_scope
.
topk_length
,
t
.
extra_kv_scope
.
topk_length
if
t
.
extra_kv_scope
is
not
None
and
t
.
extra_kv_scope
.
topk_length
is
not
None
else
None
...
...
tests/test_flash_mla_sparse_decoding.py
View file @
2c35de66
...
...
@@ -172,8 +172,12 @@ def test_flash_mla(p: TestParam) -> Result:
else
:
result
=
kk
.
bench_kineto
(
run_decode
,
p
.
num_runs
)
splitkv_kernel_name
=
"flash_fwd_splitkv_mla_fp8_sparse_kernel"
combine_kernel_name
=
"flash_fwd_mla_combine_kernel"
if
lib
.
is_decode_bf16_kvcache
():
splitkv_kernel_name
=
"flash_fwd_mla_decode_kernel_gfx938_dsa_nopage_64_splitkv"
combine_kernel_name
=
"flash_mla_splitkv_reduce_kernel"
else
:
splitkv_kernel_name
=
"flash_fwd_splitkv_mla_fp8_sparse_kernel"
combine_kernel_name
=
"flash_fwd_mla_combine_kernel"
# Get individual kernel time usages
kernel_time_usages_us
:
Dict
[
str
,
Optional
[
float
]]
=
{}
...
...
@@ -226,8 +230,11 @@ def test_flash_mla(p: TestParam) -> Result:
out_ref
,
lse_ref
=
ref
.
ref_sparse_attn_decode
(
p
,
t
)
is_out_correct
=
kk
.
check_is_allclose
(
"out"
,
out_ans
,
out_ref
,
abs_tol
=
1e-3
,
rel_tol
=
2.01
/
128
,
cos_diff_tol
=
5e-6
)
is_lse_correct
=
kk
.
check_is_allclose
(
"lse"
,
lse_ans
,
lse_ref
,
abs_tol
=
1e-6
,
rel_tol
=
8.01
/
65536
)
is_correct
&=
is_out_correct
and
is_lse_correct
if
lib
.
is_decode_bf16_kvcache
():
is_correct
&=
is_out_correct
else
:
is_lse_correct
=
kk
.
check_is_allclose
(
"lse"
,
lse_ans
,
lse_ref
,
abs_tol
=
1e-6
,
rel_tol
=
8.01
/
65536
)
is_correct
&=
is_out_correct
and
is_lse_correct
performance_result
.
is_correct
=
is_correct
return
performance_result
...
...
@@ -250,6 +257,22 @@ def main():
raw_testcases
=
gen_testcase
()
testcases
=
[
t
.
to_test_param
()
for
t
in
raw_testcases
]
if
lib
.
is_decode_bf16_kvcache
():
bf16_testcases
=
[]
seen_bf16_cases
=
set
()
for
t
in
testcases
:
if
not
lib
.
is_bf16_decode_supported_param
(
t
):
continue
if
t
.
num_runs
>
0
and
t
.
decode
.
b
>
16
:
t
=
dataclasses
.
replace
(
t
,
decode
=
dataclasses
.
replace
(
t
.
decode
,
b
=
16
))
key
=
dataclasses
.
asdict
(
t
)
key
[
"decode"
]
=
tuple
(
key
[
"decode"
].
items
())
if
key
[
"decode"
]
is
not
None
else
None
key
=
tuple
(
key
.
items
())
if
key
in
seen_bf16_cases
:
continue
seen_bf16_cases
.
add
(
key
)
bf16_testcases
.
append
(
t
)
testcases
=
bf16_testcases
print
(
f
"
{
kk
.
colors
[
'CYAN_BG'
]
}{
len
(
testcases
)
}
testcases to run
{
kk
.
colors
[
'CLEAR'
]
}
"
)
...
...
tests/test_flash_mla_sparse_prefill.py
View file @
2c35de66
...
...
@@ -10,6 +10,26 @@ import ref
_counter
=
kk
.
Counter
()
def
is_dsa_mls_prefill_case
(
p
:
TestParam
)
->
bool
:
if
p
.
d_v
!=
512
:
return
False
if
p
.
d_qk
not
in
[
512
,
576
]:
return
False
if
p
.
h_kv
!=
1
:
return
False
if
p
.
h_q
not
in
[
64
,
128
]:
return
False
if
not
(
p
.
topk
<=
1024
or
p
.
topk
==
2048
):
return
False
if
p
.
topk
==
2048
and
(
p
.
have_attn_sink
or
p
.
have_topk_length
):
return
False
if
p
.
d_qk
==
512
and
((
p
.
h_q
==
64
and
p
.
topk
==
512
)
or
(
p
.
h_q
==
128
and
p
.
topk
==
1024
)):
return
True
if
p
.
d_qk
==
576
and
p
.
h_q
==
64
and
p
.
topk
==
2048
and
p
.
s_kv
>=
32768
:
return
True
return
False
@
torch
.
inference_mode
()
def
run_test
(
p
:
TestParam
)
->
bool
:
if
p
.
seed
==
-
1
:
...
...
@@ -31,7 +51,12 @@ def run_test(p: TestParam) -> bool:
if
p
.
num_runs
>
0
:
flops_and_mem_vol
=
lib
.
count_flop_and_mem_vol
(
p
,
t
)
prefill_ans_time
=
kk
.
bench_kineto
(
run_prefill
,
num_tests
=
p
.
num_runs
).
get_kernel_time
(
"sparse_attn_fwd"
)
bench_result
=
kk
.
bench_kineto
(
run_prefill
,
num_tests
=
p
.
num_runs
)
kernel_names
=
bench_result
.
get_kernel_names
()
prefill_kernel_name
=
"sparse_attn_fwd"
if
not
any
(
prefill_kernel_name
in
name
for
name
in
kernel_names
):
prefill_kernel_name
=
"flash_fwd_mla_decode_kernel_gfx938_dsa_prefill_nopage_64"
prefill_ans_time
=
bench_result
.
get_kernel_time
(
prefill_kernel_name
)
prefill_flops
=
flops_and_mem_vol
.
fwd_flop
/
prefill_ans_time
/
1e12
prefill_mem_bw
=
flops_and_mem_vol
.
fwd_mem_vol
/
prefill_ans_time
/
1e12
print
(
f
"Prefill:
{
prefill_ans_time
*
1e6
:
4.0
f
}
us,
{
prefill_flops
:
6.1
f
}
TFlops,
{
prefill_mem_bw
:
4.2
f
}
TBps"
)
...
...
@@ -44,8 +69,9 @@ def run_test(p: TestParam) -> bool:
is_correct
=
True
is_correct
&=
kk
.
check_is_allclose
(
"out"
,
prefill_ans_out
.
float
(),
ref_out_fp32
,
abs_tol
=
8e-4
,
rel_tol
=
3.01
/
128
,
cos_diff_tol
=
7e-6
)
is_correct
&=
kk
.
check_is_allclose
(
"max_logits"
,
prefill_ans_max_logits
,
ref_max_logits
,
abs_tol
=
1e-6
,
rel_tol
=
2.01
/
65536
)
is_correct
&=
kk
.
check_is_allclose
(
"lse"
,
prefill_ans_lse
,
ref_lse
,
abs_tol
=
1e-6
,
rel_tol
=
2.01
/
65536
)
if
not
is_dsa_mls_prefill_case
(
p
):
is_correct
&=
kk
.
check_is_allclose
(
"max_logits"
,
prefill_ans_max_logits
,
ref_max_logits
,
abs_tol
=
1e-6
,
rel_tol
=
2.01
/
65536
)
is_correct
&=
kk
.
check_is_allclose
(
"lse"
,
prefill_ans_lse
,
ref_lse
,
abs_tol
=
1e-6
,
rel_tol
=
2.01
/
65536
)
return
is_correct
else
:
...
...
@@ -187,4 +213,3 @@ if __name__ == '__main__':
sys
.
exit
(
1
)
else
:
print
(
f
"
\033
[32m
\033
[1mAll
{
len
(
testcases
)
}
cases passed!
\033
[0m"
)
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment