Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
8b581a54
Commit
8b581a54
authored
Apr 29, 2026
by
zhanghj2
Browse files
支持kme dense bf16
parent
c85c787e
Changes
16
Show whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
455 additions
and
18 deletions
+455
-18
csrc/api/dense_decode.h
csrc/api/dense_decode.h
+7
-7
csrc/gfx93/decode/dense/splitkv_mla.cuh
csrc/gfx93/decode/dense/splitkv_mla.cuh
+367
-2
csrc/gfx93/decode/dense/traits.h
csrc/gfx93/decode/dense/traits.h
+8
-0
csrc/gfx93/decode/sparse_fp8/splitkv_mla.cuh
csrc/gfx93/decode/sparse_fp8/splitkv_mla.cuh
+2
-0
csrc/gfx93/prefill/sparse/phase1.cuh
csrc/gfx93/prefill/sparse/phase1.cuh
+2
-2
csrc/params.h
csrc/params.h
+1
-1
csrc/utils.h
csrc/utils.h
+10
-0
setup.py
setup.py
+1
-1
tests/test_flash_mla_dense_decoding.py
tests/test_flash_mla_dense_decoding.py
+13
-4
tests/test_flash_mla_fp8.py
tests/test_flash_mla_fp8.py
+7
-0
tests/test_flash_mla_qkvfp8.py
tests/test_flash_mla_qkvfp8.py
+6
-1
tests/test_flash_mla_qkvfp8_with_cat.py
tests/test_flash_mla_qkvfp8_with_cat.py
+6
-0
tests/test_flash_mla_sparse_decoding.py
tests/test_flash_mla_sparse_decoding.py
+6
-0
tests/test_flash_mla_sparse_prefill.py
tests/test_flash_mla_sparse_prefill.py
+6
-0
tests/test_flash_mla_with_q_concat.py
tests/test_flash_mla_with_q_concat.py
+6
-0
tests/test_flash_mla_with_q_concat_fp8.py
tests/test_flash_mla_with_q_concat_fp8.py
+7
-0
No files found.
csrc/api/dense_decode.h
View file @
8b581a54
...
...
@@ -24,9 +24,9 @@ dense_attn_decode_interface(
)
{
// Check arch
Arch
arch
=
Arch
();
if
(
!
arch
.
is_gfx93x
())
{
TORCH_CHECK
(
false
,
"Dense decode MLA is only supported on gfx936 or gfx938 architecture"
);
}
//
if (!arch.is_gfx93x()) {
//
TORCH_CHECK(false, "Dense decode MLA is only supported on gfx936 or gfx938 architecture");
//
}
// Check data types
auto
q_dtype
=
q
.
dtype
();
...
...
@@ -92,7 +92,7 @@ dense_attn_decode_interface(
KU_CHECK_CONTIGUOUS
(
out
);
KU_CHECK_CONTIGUOUS
(
lse
);
if
(
!
tile_scheduler_metadata
.
has_value
()
&&
((
num_heads_q
<
64
&&
num_heads_k
==
1
)
||
num_heads_k
>
1
))
{
if
(
!
tile_scheduler_metadata
.
has_value
()
&&
(
arch
.
is_gfx928
()
||
(
num_heads_q
<
64
&&
num_heads_k
==
1
)
||
num_heads_k
>
1
))
{
tile_scheduler_metadata
=
torch
::
empty
({
num_sm_parts
,
sizeof
(
DecodingSchedMeta
)
/
4
},
opts
.
dtype
(
torch
::
kInt32
));
num_splits
=
torch
::
empty
({
batch_size
+
1
},
opts
.
dtype
(
torch
::
kInt32
));
KU_CHECK_CONTIGUOUS
(
tile_scheduler_metadata
);
...
...
@@ -159,8 +159,8 @@ dense_attn_decode_interface(
params
.
block_table
=
block_table
.
data_ptr
<
int
>
();
params
.
block_table_batch_stride
=
block_table
.
stride
(
0
);
params
.
page_block_size
=
page_block_size
;
if
((
num_heads_q
<
64
&&
num_heads_k
==
1
)
||
num_heads_k
>
1
)
{
params
.
is_gfx928
=
arch
.
is_gfx928
();
if
((
num_heads_q
<
64
&&
num_heads_k
==
1
)
||
num_heads_k
>
1
||
arch
.
is_gfx928
()
)
{
params
.
tile_scheduler_metadata_ptr
=
(
DecodingSchedMeta
*
)
tile_scheduler_metadata
->
data_ptr
();
params
.
num_sm_parts
=
num_sm_parts
;
params
.
num_splits_ptr
=
num_splits
->
data_ptr
<
int
>
();
...
...
@@ -271,7 +271,7 @@ dense_attn_decode_interface(
params
.
partition_block_nums
};
if
((
num_heads_q
<
64
&&
num_heads_k
==
1
)
||
num_heads_k
>
1
||
params
.
use_split_kv
)
{
if
((
num_heads_q
<
64
&&
num_heads_k
==
1
)
||
num_heads_k
>
1
||
params
.
use_split_kv
||
arch
.
is_gfx928
()
)
{
if
(
q_dtype
==
torch
::
kBFloat16
)
{
gfx9
::
decode
::
run_flash_mla_combine_kernel
<
cutlass
::
bfloat16_t
>
(
combine_params
);
}
else
if
(
q_dtype
==
torch
::
kHalf
)
{
...
...
csrc/gfx93/decode/dense/splitkv_mla.cuh
View file @
8b581a54
...
...
@@ -1178,6 +1178,364 @@ compute_attn_1rowblock_splitkv_mla_gfx936(const DenseAttnDecodeParams& params,
}
}
template
<
typename
T
>
__device__
void
compute_attn_1rowblock_splitkv_mla_gfx928
(
const
DenseAttnDecodeParams
&
params
,
const
int
bidb
,
const
int
bidh
,
const
int
m_block
,
const
int
n_split_idx
,
const
int
seqlen_k
,
const
int
n_block_min
,
const
int
n_block_max
,
const
bool
NoSplit
)
{
extern
__shared__
char
shared_memory
[];
using
SharedMemoryPlan
=
typename
T
::
SharedMemoryPlan
;
SharedMemoryPlan
&
plan
=
*
reinterpret_cast
<
SharedMemoryPlan
*>
(
shared_memory
);
const
int
tidx
=
threadIdx
.
x
;
constexpr
int
kBlockM
=
T
::
BLOCK_SIZE_M
;
constexpr
int
kBlockN
=
T
::
PAGE_BLOCK_SIZE
;
constexpr
int
kHeadDim
=
T
::
HEAD_DIM_K
;
constexpr
int
kHeadDimV
=
T
::
HEAD_DIM_V
;
using
Element
=
T
::
InputT
;
using
index_t
=
int64_t
;
const
index_t
row_offset_q
=
bidb
*
params
.
q_batch_stride
+
m_block
*
kBlockM
*
params
.
q_row_stride
+
bidh
*
params
.
q_head_stride
;
Tensor
gQ
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
q_ptr
)
+
row_offset_q
),
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDim
>>
{},
make_stride
(
params
.
q_row_stride
,
_1
{}));
const
index_t
row_offset_k
=
(
bidh
)
*
params
.
k_head_stride
;
Tensor
gK
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
k_ptr
)
+
row_offset_k
),
Shape
<
Int
<
kBlockN
>
,
Int
<
kHeadDim
>>
{},
make_stride
(
params
.
k_row_stride
,
_1
{}));
Tensor
gV
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
k_ptr
)
+
row_offset_k
),
Shape
<
Int
<
kBlockN
>
,
Int
<
kHeadDimV
>>
{},
make_stride
(
params
.
k_row_stride
,
_1
{}));
Tensor
sQ
=
make_tensor
(
make_smem_ptr
(
plan
.
smem_q
.
data
()),
typename
T
::
SmemLayoutQ
{});
Tensor
sV
=
make_tensor
(
make_smem_ptr
(
plan
.
smem_v
.
data
()),
typename
T
::
SmemLayoutV
{});
Tensor
sK
=
make_tensor
(
make_smem_ptr
(
plan
.
smem_v
.
data
()),
typename
T
::
SmemLayoutK
{});
Tensor
sP
=
make_tensor
(
make_smem_ptr
(
plan
.
smem_p
.
data
()),
typename
T
::
SmemLayoutP
{});
Tensor
sVt
=
make_tensor
(
sV
.
data
(),
typename
T
::
SmemLayoutVtransposed
{});
Tensor
sVtNoSwizzle
=
make_tensor
(
sV
.
data
(),
typename
T
::
SmemLayoutVtransposedNoSwizzle
{});
Tensor
sRow_max_reduce_buffer
=
make_tensor
(
make_smem_ptr
(
plan
.
smem_row_max
.
data
()),
typename
T
::
SmemLayoutRow
{});
Tensor
sRow_sum_reduce_buffer
=
make_tensor
(
make_smem_ptr
(
plan
.
smem_row_sum
.
data
()),
typename
T
::
SmemLayoutRow
{});
using
MMA_Atom_Arch
=
std
::
conditional_t
<
std
::
is_same_v
<
Element
,
cutlass
::
half_t
>
,
MMA_Atom
<
GFX928_16x16x32_F32F16F16F32_NT
>
,
MMA_Atom
<
GFX928_16x16x32_F32BF16BF16F32_NT
>
>
;
using
ValLayoutMNK
=
Layout
<
Shape
<
_1
,
_1
,
_1
>>
;
using
TiledMma
=
TiledMMA
<
MMA_Atom_Arch
,
Layout
<
Shape
<
_1
,
Int
<
4
>
,
_1
>>
,
// 1x4x1 or 1x8x1 thread group
ValLayoutMNK
>
;
TiledMma
tiled_mma
;
auto
thr_mma
=
tiled_mma
.
get_thread_slice
(
tidx
);
typename
T
::
TiledMma_O
tiled_mma_o
;
auto
thr_mma_o
=
tiled_mma_o
.
get_thread_slice
(
tidx
);
#if 1
typename
T
::
GmemTiledCopyQ
gmem_tiled_copy_Q
;
auto
gmem_thr_copy_Q
=
gmem_tiled_copy_Q
.
get_thread_slice
(
tidx
);
Tensor
tQgQ
=
gmem_thr_copy_Q
.
partition_S
(
gQ
);
Tensor
tQsQ
=
gmem_thr_copy_Q
.
partition_D
(
sQ
);
Tensor
cQ
=
make_identity_tensor
(
make_shape
(
size
<
0
>
(
gQ
),
size
<
1
>
(
gQ
)));
Tensor
tQcQ
=
gmem_thr_copy_Q
.
partition_S
(
cQ
);
Tensor
tQpQ
=
make_tensor
<
bool
>
(
make_shape
(
size
<
2
>
(
tQgQ
)));
if
(
tidx
<
128
)
flash
::
copy
<
/*Is_even_MN=*/
false
,
/*Is_even_K=*/
true
,
false
>
(
gmem_tiled_copy_Q
,
tQgQ
,
tQsQ
,
tQcQ
,
tQpQ
,
params
.
q_seq_per_hk
-
m_block
*
kBlockM
);
__syncthreads
();
auto
smem_tiled_copy_Q
=
make_tiled_copy_A
(
Copy_Atom
<
DefaultCopy
,
Element
>
{},
tiled_mma
);
auto
smem_thr_copy_Q
=
smem_tiled_copy_Q
.
get_thread_slice
(
tidx
);
Tensor
tSsQ
=
smem_thr_copy_Q
.
partition_S
(
sQ
);
Tensor
tSrQ
=
thr_mma
.
partition_fragment_A
(
sQ
);
Tensor
tSrQ_copy_view
=
smem_thr_copy_Q
.
retile_D
(
tSrQ
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
,
tSrQ_copy_view
);
__syncthreads
();
#else
#endif
typename
T
::
GmemTiledCopyK
gmem_tiled_copy_K
;
auto
gmem_thr_copy_K
=
gmem_tiled_copy_K
.
get_thread_slice
(
tidx
);
Tensor
tKgK
=
gmem_thr_copy_K
.
partition_S
(
gK
);
Tensor
tKsK
=
gmem_thr_copy_K
.
partition_D
(
sK
);
Tensor
cK
=
make_identity_tensor
(
make_shape
(
size
<
0
>
(
gK
),
size
<
1
>
(
gK
)));
Tensor
tKcK
=
gmem_thr_copy_K
.
partition_S
(
cK
);
Tensor
tKpK
=
make_tensor
<
bool
>
(
make_shape
(
size
<
2
>
(
tKgK
)));
auto
smem_tiled_copy_K
=
make_tiled_copy_B
(
Copy_Atom
<
DefaultCopy
,
Element
>
{},
tiled_mma
);
auto
smem_thr_copy_K
=
smem_tiled_copy_K
.
get_thread_slice
(
tidx
);
Tensor
tSgK
=
smem_thr_copy_K
.
partition_S
(
gK
);
Tensor
tSsK
=
smem_thr_copy_K
.
partition_S
(
sK
);
Tensor
tSrK
=
thr_mma
.
partition_fragment_B
(
sK
);
Tensor
tKcK_smem
=
smem_thr_copy_K
.
partition_S
(
cK
);
Tensor
tKpK_smem
=
make_tensor
<
bool
>
(
make_shape
(
size
<
2
>
(
tSgK
)));
Tensor
tSrK_smem
=
thr_mma
.
partition_fragment_B
(
gK
);
typename
T
::
GmemTiledCopyV
gmem_tiled_copy_V
;
auto
gmem_thr_copy_V
=
gmem_tiled_copy_V
.
get_thread_slice
(
tidx
);
Tensor
tVgV
=
gmem_thr_copy_V
.
partition_S
(
gV
);
Tensor
tVsV
=
gmem_thr_copy_V
.
partition_D
(
sV
);
Tensor
cV
=
make_identity_tensor
(
make_shape
(
size
<
0
>
(
gV
),
size
<
1
>
(
gV
)));
Tensor
tVcV
=
gmem_thr_copy_V
.
partition_S
(
cV
);
Tensor
tVpV
=
make_tensor
<
bool
>
(
make_shape
(
size
<
2
>
(
tVgV
)));
auto
smem_tiled_copy_V
=
make_tiled_copy_B
(
Copy_Atom
<
GFX928_DS_READ_DS_M32x16_B16
,
Element
>
{},
tiled_mma_o
);
auto
smem_thr_copy_V
=
smem_tiled_copy_V
.
get_thread_slice
(
tidx
);
Tensor
tOsVt
=
smem_thr_copy_V
.
partition_S
(
sVt
);
Tensor
tOrVt
=
thr_mma_o
.
partition_fragment_B
(
sVtNoSwizzle
);
constexpr
int
n_masking_steps
=
!
T
::
Is_causal
?
1
:
cute
::
ceil_div
(
kBlockM
,
kBlockN
)
+
1
;
const
int
*
block_table
=
params
.
block_table
+
bidb
*
params
.
block_table_batch_stride
;
int
n_block
=
n_block_max
-
1
;
// constexpr static int k0_lds_loops = 0;
constexpr
static
int
k0_lds_loops
=
16
;
constexpr
static
int
k0_loops
=
size
<
2
>
(
tSrK_smem
);
constexpr
static
int
k1_loops
=
size
<
2
>
(
tOrVt
);
Tensor
acc_o
=
partition_fragment_C
(
tiled_mma_o
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDimV
>>
{});
clear
(
acc_o
);
flash
::
Softmax
<
size
<
1
>
(
acc_o
)
>
softmax
;
int
cur_block_table
;
index_t
offset_k
;
constexpr
static
int
BUFFER_SIZE
=
4
;
uint128_t
buffer
[
BUFFER_SIZE
];
if
(
n_block
>=
n_block_min
)
{
cur_block_table
=
block_table
[
n_block
];
offset_k
=
cur_block_table
*
params
.
k_batch_stride
;
gK
.
data
()
=
gK
.
data
()
+
(
offset_k
);
flash
::
buffer_load_copy
<
false
,
true
,
false
>
(
gK
,
buffer
[
0
],
0
,
params
.
k_row_stride
,
offset_k
,
seqlen_k
-
n_block
*
kBlockN
);
flash
::
buffer_load_copy
<
false
,
true
,
false
>
(
gK
,
buffer
[
1
],
1
,
params
.
k_row_stride
,
offset_k
,
seqlen_k
-
n_block
*
kBlockN
);
flash
::
buffer_load_copy
<
false
,
true
,
false
>
(
gK
,
buffer
[
2
],
2
,
params
.
k_row_stride
,
offset_k
,
seqlen_k
-
n_block
*
kBlockN
);
}
#if 1
#pragma unroll
for
(
int
masking_step
=
n_masking_steps
;
n_block
>=
n_block_min
;
--
masking_step
,
--
n_block
)
{
asm
volatile
(
"s_barrier
\n\t
"
);
Tensor
acc_s
=
partition_fragment_C
(
tiled_mma
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
clear
(
acc_s
);
Tensor
tSrK_copy_view
=
smem_thr_copy_K
.
retile_D
(
tSrK
);
// 计算0~11
#if 1
#pragma unroll
for
(
int
i
=
0
;
i
<
k0_lds_loops
-
BUFFER_SIZE
+
1
;
i
++
)
{
// asm volatile("s_waitcnt vmcnt(3) \n\t \n\t");
flash
::
asm_ds_write
(
buffer
[
i
%
BUFFER_SIZE
],
tKsK
,
i
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
i
),
tSrK_copy_view
(
_
,
_
,
i
));
flash
::
buffer_load_copy
<
false
,
true
,
false
>
(
gK
,
buffer
[(
i
+
BUFFER_SIZE
-
1
)
%
BUFFER_SIZE
],
i
+
BUFFER_SIZE
-
1
,
params
.
k_row_stride
,
offset_k
,
seqlen_k
-
n_block
*
kBlockN
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
i
),
tSrK
(
_
,
_
,
i
),
acc_s
);
// asm volatile("s_barrier\n\t");
}
// asm volatile("s_barrier\n\t");
#endif
#if 0
#else
// 计算 13-15
const
int
k_idx
=
k0_lds_loops
-
BUFFER_SIZE
+
1
;
flash
::
asm_ds_write
(
buffer
[
k_idx
%
BUFFER_SIZE
],
tKsK
,
k_idx
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
flash
::
asm_ds_write
(
buffer
[(
k_idx
+
1
)
%
BUFFER_SIZE
],
tKsK
,
k_idx
+
1
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
+
1
),
tSrK_copy_view
(
_
,
_
,
k_idx
+
1
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
+
1
),
tSrK
(
_
,
_
,
k_idx
+
1
),
acc_s
);
flash
::
asm_ds_write
(
buffer
[(
k_idx
+
2
)
%
BUFFER_SIZE
],
tKsK
,
k_idx
+
2
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
+
2
),
tSrK_copy_view
(
_
,
_
,
k_idx
+
2
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
+
2
),
tSrK
(
_
,
_
,
k_idx
+
2
),
acc_s
);
// asm volatile("s_barrier\n\t");
// 读取16-17
flash
::
buffer_load_copy
<
false
,
true
,
true
>
(
gK
,
buffer
[
1
],
16
,
params
.
k_row_stride
,
offset_k
,
seqlen_k
-
n_block
*
kBlockN
);
flash
::
buffer_load_copy
<
false
,
true
,
true
>
(
gK
,
buffer
[
2
],
17
,
params
.
k_row_stride
,
offset_k
,
seqlen_k
-
n_block
*
kBlockN
);
flash
::
buffer_to_tensor
(
buffer
[
1
],
tSrK_smem
,
16
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
16
),
tSrK_smem
(
_
,
_
,
16
),
acc_s
);
flash
::
buffer_to_tensor
(
buffer
[
2
],
tSrK_smem
,
17
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
17
),
tSrK_smem
(
_
,
_
,
17
),
acc_s
);
asm
volatile
(
"s_barrier
\n\t
"
);
#endif
const
bool
is_masking_step
=
masking_step
>
0
;
const
bool
is_first_masking_step
=
masking_step
==
n_masking_steps
;
if
(
is_masking_step
)
{
Tensor
cS
=
make_identity_tensor
(
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
Tensor
tScS
=
thr_mma
.
partition_C
(
cS
);
for
(
int
i
=
0
;
i
<
size
(
acc_s
);
++
i
)
{
if
constexpr
(
!
T
::
Is_causal
)
{
if
(
int
(
get
<
1
>
(
tScS
(
i
)))
>=
int
(
seqlen_k
-
n_block
*
kBlockN
))
acc_s
(
i
)
=
-
INFINITY
;
}
else
{
// Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
// col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
int
row
=
int
(
get
<
0
>
(
tScS
(
i
)));
int
col_limit_right
=
seqlen_k
-
1
-
n_block
*
kBlockN
-
(
params
.
q_seq_per_hk
-
1
-
(
m_block
*
kBlockM
+
row
))
/
params
.
q_head_per_hk
;
if
(
int
(
get
<
1
>
(
tScS
(
i
)))
>
col_limit_right
)
acc_s
(
i
)
=
-
INFINITY
;
}
}
}
// We have key_padding_mask so we'll need to Check_inf
is_first_masking_step
?
softmax
.
template
softmax_rescale_o
<
/*Is_first=*/
true
,
/*Check_inf=*/
T
::
Is_causal
>(
acc_s
,
acc_o
,
sRow_max_reduce_buffer
,
params
.
scale_softmax_log2
)
:
is_masking_step
?
softmax
.
template
softmax_rescale_o
<
/*Is_first=*/
false
,
/*Check_inf=*/
T
::
Is_causal
>(
acc_s
,
acc_o
,
sRow_max_reduce_buffer
,
params
.
scale_softmax_log2
)
:
softmax
.
template
softmax_rescale_o
<
/*Is_first=*/
false
,
/*Check_inf=*//*Is_local=*/
false
>(
acc_s
,
acc_o
,
sRow_max_reduce_buffer
,
params
.
scale_softmax_log2
);
Tensor
rP
=
flash
::
convert_type
<
Element
>
(
acc_s
);
Tensor
tOrP
=
flash
::
convert_layout_acc_Aregs
(
tiled_mma
,
tiled_mma_o
,
rP
,
sP
);
__syncthreads
();
#if 1
// 第15块已经读取到了buffer[3]中
flash
::
asm_ds_write
(
buffer
[
3
],
tVsV
,
15
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
#endif
gK
.
data
()
=
gK
.
data
()
+
(
-
offset_k
);
if
(
n_block
>
n_block_min
)
{
cur_block_table
=
block_table
[
n_block
-
1
];
offset_k
=
cur_block_table
*
params
.
k_batch_stride
;
gK
.
data
()
=
gK
.
data
()
+
(
offset_k
);
flash
::
buffer_load_copy
<
true
,
true
,
false
>
(
gK
,
buffer
[
0
],
0
,
params
.
k_row_stride
,
offset_k
);
flash
::
buffer_load_copy
<
true
,
true
,
false
>
(
gK
,
buffer
[
1
],
1
,
params
.
k_row_stride
,
offset_k
);
flash
::
buffer_load_copy
<
true
,
true
,
false
>
(
gK
,
buffer
[
2
],
2
,
params
.
k_row_stride
,
offset_k
);
}
Tensor
tOrVt_copy_view
=
smem_thr_copy_V
.
retile_D
(
tOrVt
);
#pragma unroll
for
(
int
i
=
0
;
i
<
k1_loops
;
i
++
)
{
cute
::
copy
(
smem_tiled_copy_V
,
tOsVt
(
_
,
_
,
i
),
tOrVt_copy_view
(
_
,
_
,
i
));
cute
::
gemm
(
tiled_mma_o
,
tOrP
(
_
,
_
,
i
),
tOrVt
(
_
,
_
,
i
),
acc_o
);
}
asm
volatile
(
" s_barrier
\n\t
"
);
}
#endif
using
ElementAccum
=
float
;
if
(
NoSplit
)
{
using
ElementO
=
Element
;
const
index_t
row_offset_o
=
bidb
*
params
.
o_batch_stride
+
m_block
*
kBlockM
*
params
.
o_row_stride
+
bidh
*
params
.
o_head_stride
;
const
index_t
row_offset_lse
=
(
bidb
*
params
.
h_k
+
bidh
)
*
params
.
q_seq_per_hk
+
m_block
*
kBlockM
;
constexpr
bool
Split
=
false
;
Tensor
gOaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementO
*>
(
Split
?
params
.
oaccum_ptr
:
params
.
o_ptr
)
+
(
row_offset_o
)),
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDimV
>>
{},
make_stride
(
Split
?
kHeadDimV
:
params
.
o_row_stride
,
_1
{}));
Tensor
lse
=
softmax
.
template
normalize_softmax_lse
<
/*Is_dropout=*/
false
,
Split
>(
acc_o
,
sRow_sum_reduce_buffer
,
params
.
scale_softmax
);
// if (block0() && tidx < 64)
// {
// printf(" %.3f %.3f \n", float(acc_o(0)), float(acc_o(1)));
// }
Tensor
gLSEaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementAccum
*>
(
Split
?
params
.
softmax_lseaccum_ptr
:
params
.
softmax_lse_ptr
)
+
(
row_offset_lse
)),
Shape
<
Int
<
kBlockM
>>
{},
Stride
<
_1
>
{});
Tensor
caccO
=
make_identity_tensor
(
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDimV
>>
{});
// (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor
taccOcO
=
thr_mma_o
.
partition_C
(
caccO
);
Tensor
rO
=
flash
::
convert_type
<
ElementO
>
(
acc_o
);
if
(
get
<
1
>
(
taccOcO
(
0
))
==
0
)
{
#pragma unroll
for
(
int
mi
=
0
;
mi
<
size
(
lse
);
++
mi
)
{
const
int
row
=
get
<
0
>
(
taccOcO
(
0
,
mi
,
0
));
if
(
row
<
params
.
q_seq_per_hk
-
m_block
*
kBlockM
)
{
gLSEaccum
(
row
)
=
lse
(
mi
);
}
}
}
{
// using result_type = cutlass::Array<bfloat16_t, 2>;
// int tidx = threadIdx.x;
int
col
=
0
;
int
warpid
=
tidx
/
64
;
for
(
int
m
=
0
;
m
<
1
;
m
++
)
{
const
int
row
=
tidx
%
16
;
if
(
row
<
params
.
q_seq_per_hk
-
m_block
*
kBlockM
)
{
for
(
int
n
=
0
;
n
<
size
<
2
>
(
acc_o
);
n
++
)
{
col
=
(
tidx
%
64
/
16
)
+
warpid
*
32
+
n
*
128
;
for
(
int
ei
=
0
;
ei
<
8
;
ei
++
)
{
gOaccum
(
row
,
col
)
=
rO
(
ei
,
m
,
n
);
col
+=
4
;
}
}
}
}
}
}
else
{
using
ElementO
=
float
;
int
split_idx
=
params
.
num_splits_ptr
[
bidb
]
+
n_split_idx
;
constexpr
bool
Split
=
true
;
const
index_t
row_offset_oaccum
=
((
split_idx
*
params
.
h_k
+
bidh
)
*
params
.
q_seq_per_hk
+
m_block
*
kBlockM
)
*
T
::
HEAD_DIM_V
;
// (BLOCK_SIZE_M, HEAD_DIM_V) : (HEAD_DIM_V, 1)
const
index_t
row_offset_lseaccum
=
(
split_idx
*
params
.
h_k
+
bidh
)
*
params
.
q_seq_per_hk
+
m_block
*
kBlockM
;
Tensor
gOaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementO
*>
(
Split
?
params
.
oaccum_ptr
:
params
.
o_ptr
)
+
(
row_offset_oaccum
)),
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDimV
>>
{},
make_stride
(
Split
?
kHeadDimV
:
params
.
o_row_stride
,
_1
{}));
Tensor
gLSEaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementAccum
*>
(
Split
?
params
.
softmax_lseaccum_ptr
:
params
.
softmax_lse_ptr
)
+
(
row_offset_lseaccum
)),
Shape
<
Int
<
kBlockM
>>
{},
Stride
<
_1
>
{});
Tensor
lse
=
softmax
.
template
normalize_softmax_lse
<
/*Is_dropout=*/
false
,
Split
>(
acc_o
,
sRow_sum_reduce_buffer
,
params
.
scale_softmax
);
Tensor
caccO
=
make_identity_tensor
(
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDimV
>>
{});
// (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor
taccOcO
=
thr_mma_o
.
partition_C
(
caccO
);
if
(
get
<
1
>
(
taccOcO
(
0
))
==
0
)
{
#pragma unroll
for
(
int
mi
=
0
;
mi
<
size
(
lse
);
++
mi
)
{
const
int
row
=
get
<
0
>
(
taccOcO
(
0
,
mi
,
0
));
if
(
row
<
params
.
q_seq_per_hk
-
m_block
*
kBlockM
)
{
gLSEaccum
(
row
)
=
lse
(
mi
);
}
}
}
{
// using result_type = cutlass::Array<bfloat16_t, 2>;
// int tidx = threadIdx.x;
int
col
=
0
;
int
warpid
=
tidx
/
64
;
for
(
int
m
=
0
;
m
<
1
;
m
++
)
{
const
int
row
=
tidx
%
16
;
if
(
row
<
params
.
q_seq_per_hk
-
m_block
*
kBlockM
)
{
for
(
int
n
=
0
;
n
<
size
<
2
>
(
acc_o
);
n
++
)
{
col
=
(
tidx
%
64
/
16
)
+
warpid
*
32
+
n
*
128
;
for
(
int
ei
=
0
;
ei
<
8
;
ei
++
)
{
gOaccum
(
row
,
col
)
=
acc_o
(
ei
,
m
,
n
);
col
+=
4
;
}
}
}
}
}
}
}
template
<
typename
T
>
__global__
void
__launch_bounds__
(
T
::
NUM_THREADS
,
1
)
...
...
@@ -1199,9 +1557,15 @@ flash_fwd_splitkv_mla_kernel(const DenseAttnDecodeParams params) {
if
(
batch_idx
>
sched_meta
.
begin_req_idx
)
{
__syncthreads
();
}
#if defined(__gfx928__)
compute_attn_1rowblock_splitkv_mla_gfx928
<
T
>
(
params
,
batch_idx
,
bidh
,
m_block
,
n_split_idx
,
seqlen_k
,
start_block_idx
,
end_block_idx
,
is_no_split
);
#else
compute_attn_1rowblock_splitkv_mla_gfx936
<
T
>
(
params
,
batch_idx
,
bidh
,
m_block
,
n_split_idx
,
seqlen_k
,
start_block_idx
,
end_block_idx
,
is_no_split
);
#endif
}
}
...
...
@@ -1209,6 +1573,7 @@ flash_fwd_splitkv_mla_kernel(const DenseAttnDecodeParams params) {
template
<
typename
T
,
bool
use_split_kv
=
false
>
__global__
void
__launch_bounds__
(
T
::
NUM_THREADS
,
1
)
flash_fwd_splitkv_mla_block_m_64_kernel
(
const
DenseAttnDecodeParams
params
)
{
#if defined(__gfx936__) || defined(__gfx938__)
constexpr
int
kBlockN
=
T
::
PAGE_BLOCK_SIZE
;
const
int
m_block
=
blockIdx
.
x
;
const
int
bidh
=
blockIdx
.
y
;
...
...
@@ -1908,7 +2273,7 @@ flash_fwd_splitkv_mla_block_m_64_kernel(const DenseAttnDecodeParams params) {
}
}
}
#endif
}
...
...
@@ -1918,7 +2283,7 @@ void run_flash_splitkv_mla_kernel(DenseAttnDecodeParams ¶ms) {
FLASH_ASSERT
(
params
.
d_v
==
Config
::
HEAD_DIM_V
);
BOOL_SWITCH
(
params
.
is_causal
,
Is_causal
,
[
&
]
{
if
(
params
.
h_q
>=
64
&&
params
.
h_k
==
1
)
{
if
(
params
.
h_q
>=
64
&&
params
.
h_k
==
1
&&
!
params
.
is_gfx928
)
{
using
T
=
Traits_Block_M_64
<
InputT
,
Is_causal
>
;
constexpr
size_t
smem_size
=
16384
+
4096
;
if
(
params
.
use_split_kv
)
...
...
csrc/gfx93/decode/dense/traits.h
View file @
8b581a54
...
...
@@ -102,6 +102,14 @@ struct Traits {
GmemLayoutAtomQ
{},
Layout
<
Shape
<
_1
,
_8
>>
{}));
using
GmemLayoutAtomK
=
Layout
<
Shape
<
_64
,
_4
>
,
Stride
<
_4
,
_1
>>
;
using
GmemTiledCopyK
=
decltype
(
make_tiled_copy
(
Copy_Atom
<
DefaultCopy
,
Element
>
{},
GmemLayoutAtomK
{},
Layout
<
Shape
<
_1
,
_8
>>
{}));
using
GmemTiledCopyV
=
GmemTiledCopyK
;
struct
SharedMemoryPlan
{
...
...
csrc/gfx93/decode/sparse_fp8/splitkv_mla.cuh
View file @
8b581a54
...
...
@@ -969,7 +969,9 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::devfunc(const SparseAttnD
template
<
typename
Kernel
>
__global__
void
__launch_bounds__
(
Kernel
::
NUM_THREADS
,
1
)
flash_fwd_splitkv_mla_fp8_sparse_kernel
(
const
SparseAttnDecodeParams
params
)
{
#if defined(__gfx936__) || defined(__gfx938__)
Kernel
::
devfunc
(
params
);
#endif
}
template
<
ModelType
MODEL_TYPE
,
int
NUM_HEADS
>
...
...
csrc/gfx93/prefill/sparse/phase1.cuh
View file @
8b581a54
...
...
@@ -1287,9 +1287,9 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
template
<
typename
Kernel
>
__global__
void
__launch_bounds__
(
Kernel
::
NUM_THREADS
,
1
)
sparse_attn_fwd_kernel
(
const
SparseAttnFwdParams
params
)
{
//
#if defined(__gfx936__)
#if defined(__gfx936__)
|| defined(__gfx938__)
Kernel
::
devfunc
(
params
);
//
#endif
#endif
}
template
<
int
D_QK
,
bool
HAVE_TOPK_LENGTH
>
...
...
csrc/params.h
View file @
8b581a54
...
...
@@ -61,7 +61,7 @@ struct DenseAttnDecodeParams { // TODO Change name to DenseAttnDecodeParams
bool
use_split_kv
;
int
partition_block_nums
;
bool
is_gfx928
;
};
struct
DenseAttnDecodeParams_fp8
:
public
DenseAttnDecodeParams
{
...
...
csrc/utils.h
View file @
8b581a54
...
...
@@ -621,7 +621,17 @@ lds_direct_copy_for_prefill_sparse_mla(
:
);
}
template
<
class
SrcEngine
,
class
SrcLayout
>
CUTE_HOST_DEVICE
void
asm_ds_write
(
const
uint128_t
&
src
,
Tensor
<
SrcEngine
,
SrcLayout
>
&
dst
,
int
k_idx
)
{
uint128_t
*
d
=
reinterpret_cast
<
uint128_t
*>
(
&
dst
(
0
,
0
,
k_idx
));
d
[
0
]
=
src
;
}
template
<
bool
Is_even_MN
=
true
,
bool
Is_even_K
=
true
,
...
...
setup.py
View file @
8b581a54
...
...
@@ -31,7 +31,7 @@ def get_features_args():
def
get_arch_flags
():
arch_flags
=
[]
arch_flags
.
append
(
"--offload-arch=gfx938;gfx936"
)
arch_flags
.
append
(
"--offload-arch=gfx938;gfx936
;gfx928
"
)
return
arch_flags
# def get_nvcc_thread_args():
...
...
tests/test_flash_mla_dense_decoding.py
View file @
8b581a54
...
...
@@ -140,6 +140,9 @@ def reference_torch(
out_ref
=
out_ref
.
to
(
q
.
dtype
)
return
out_ref
,
lse_ref
def
get_gcn_arch_name
()
->
str
:
GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
return
GPU_ARCH
.
split
(
':'
)[
0
]
@
torch
.
inference_mode
()
def
test_flash_mla
(
t
:
TestParam
):
...
...
@@ -166,6 +169,12 @@ def test_flash_mla(t: TestParam):
out_ans
,
lse_ans
=
run_flash_mla
()
out_ref
,
lse_ref
=
reference_torch
(
cache_seqlens
,
block_table
,
q
,
blocked_k
,
t
.
dv
,
t
.
is_causal
)
if
get_gcn_arch_name
()
==
"gfx928"
:
lse_abs_diff
=
(
lse_ans
-
lse_ref
).
max
().
abs
().
item
()
out_abs_diff
=
(
out_ref
-
out_ans
).
max
().
abs
().
item
()
print
(
"lse_abs_diff "
,
lse_abs_diff
,
out_abs_diff
)
assert
out_abs_diff
<=
4e-3
else
:
is_correct
=
True
is_correct
&=
kk
.
check_is_allclose
(
"out"
,
out_ans
,
out_ref
,
abs_tol
=
8e-4
,
rel_tol
=
2.01
/
128
,
cos_diff_tol
=
5e-6
)
is_correct
&=
kk
.
check_is_allclose
(
"lse"
,
lse_ans
,
lse_ref
,
abs_tol
=
1e-6
,
rel_tol
=
8.01
/
65536
)
...
...
tests/test_flash_mla_fp8.py
View file @
8b581a54
...
...
@@ -214,7 +214,14 @@ def main(torch_dtype, is_prof=False):
for
varlen
in
[
False
]:
test_flash_mla_fp8_e5m2
(
b
,
s_q
,
s
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
varlen
)
def
get_gcn_arch_name
()
->
str
:
GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
return
GPU_ARCH
.
split
(
':'
)[
0
]
if
__name__
==
"__main__"
:
if
get_gcn_arch_name
()
==
"gfx928"
:
print
(
"[WARNING] gfx928 architecture is not supported."
)
exit
(
0
)
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--dtype"
,
...
...
tests/test_flash_mla_qkvfp8.py
View file @
8b581a54
...
...
@@ -175,9 +175,14 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=Fa
f
"
{
t
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
t
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
t
:.
0
f
}
GB/s"
)
def
get_gcn_arch_name
()
->
str
:
GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
return
GPU_ARCH
.
split
(
':'
)[
0
]
def
main
(
torch_dtype
,
is_prof
=
False
):
if
get_gcn_arch_name
()
!=
"gfx938"
:
print
(
"[WARNING] The architecture is not supported."
)
exit
(
0
)
device
=
torch
.
device
(
"cuda:0"
)
init_dtype
=
torch
.
bfloat16
if
torch_dtype
==
torch
.
float8_e4m3fn
else
torch_dtype
torch
.
set_default_dtype
(
init_dtype
)
...
...
tests/test_flash_mla_qkvfp8_with_cat.py
View file @
8b581a54
...
...
@@ -252,8 +252,14 @@ def main(torch_dtype, is_prof=False):
# '''
def
get_gcn_arch_name
()
->
str
:
GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
return
GPU_ARCH
.
split
(
':'
)[
0
]
if
__name__
==
"__main__"
:
if
get_gcn_arch_name
()
!=
"gfx938"
:
print
(
"[WARNING] The architecture is not supported."
)
exit
(
0
)
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--dtype"
,
...
...
tests/test_flash_mla_sparse_decoding.py
View file @
8b581a54
...
...
@@ -232,8 +232,14 @@ def test_flash_mla(p: TestParam) -> Result:
performance_result
.
is_correct
=
is_correct
return
performance_result
def
get_gcn_arch_name
()
->
str
:
GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
return
GPU_ARCH
.
split
(
':'
)[
0
]
def
main
():
if
get_gcn_arch_name
()
==
"gfx928"
:
print
(
"[WARNING] gfx928 architecture is not supported."
)
exit
(
0
)
dtype
=
torch
.
bfloat16
device
=
torch
.
device
(
"cuda:0"
)
torch
.
set_default_dtype
(
dtype
)
...
...
tests/test_flash_mla_sparse_prefill.py
View file @
8b581a54
...
...
@@ -51,8 +51,14 @@ def run_test(p: TestParam) -> bool:
else
:
return
True
def
get_gcn_arch_name
()
->
str
:
GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
return
GPU_ARCH
.
split
(
':'
)[
0
]
if
__name__
==
'__main__'
:
if
get_gcn_arch_name
()
==
"gfx928"
:
print
(
"[WARNING] gfx928 architecture is not supported."
)
exit
(
0
)
device
=
torch
.
device
(
"cuda:0"
)
torch
.
set_default_dtype
(
torch
.
bfloat16
)
torch
.
set_default_device
(
device
)
...
...
tests/test_flash_mla_with_q_concat.py
View file @
8b581a54
...
...
@@ -141,8 +141,14 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=Fa
f
"
{
t
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
t
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
t
:.
0
f
}
GB/s"
)
def
get_gcn_arch_name
()
->
str
:
GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
return
GPU_ARCH
.
split
(
':'
)[
0
]
def
main
(
torch_dtype
,
is_prof
=
False
):
if
get_gcn_arch_name
()
==
"gfx928"
:
print
(
"[WARNING] gfx928 architecture is not supported."
)
exit
(
0
)
device
=
torch
.
device
(
"cuda:0"
)
torch
.
set_default_dtype
(
torch_dtype
)
torch
.
set_default_device
(
device
)
...
...
tests/test_flash_mla_with_q_concat_fp8.py
View file @
8b581a54
...
...
@@ -168,7 +168,14 @@ def test_flash_mla_fp8_e5m2(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, i
f
"
{
t
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
t
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
t
:.
0
f
}
GB/s"
)
def
get_gcn_arch_name
()
->
str
:
GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
return
GPU_ARCH
.
split
(
':'
)[
0
]
def
main
(
torch_dtype
,
is_prof
=
False
):
if
get_gcn_arch_name
()
==
"gfx928"
:
print
(
"[WARNING] gfx928 architecture is not supported."
)
exit
(
0
)
device
=
torch
.
device
(
"cuda:0"
)
torch
.
set_default_dtype
(
torch_dtype
)
torch
.
set_default_device
(
device
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment