Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
92a05388
Commit
92a05388
authored
Apr 29, 2026
by
zhanghj2
Browse files
Merge branch 'feature/kme-bf16' into 'master'
支持kme dense bf16 See merge request dcutoolkit/deeplearing/flashmla!11
parents
c85c787e
8b581a54
Changes
16
Show whitespace changes
Inline
Side-by-side
Showing
16 changed files
with
455 additions
and
18 deletions
+455
-18
csrc/api/dense_decode.h
csrc/api/dense_decode.h
+7
-7
csrc/gfx93/decode/dense/splitkv_mla.cuh
csrc/gfx93/decode/dense/splitkv_mla.cuh
+367
-2
csrc/gfx93/decode/dense/traits.h
csrc/gfx93/decode/dense/traits.h
+8
-0
csrc/gfx93/decode/sparse_fp8/splitkv_mla.cuh
csrc/gfx93/decode/sparse_fp8/splitkv_mla.cuh
+2
-0
csrc/gfx93/prefill/sparse/phase1.cuh
csrc/gfx93/prefill/sparse/phase1.cuh
+2
-2
csrc/params.h
csrc/params.h
+1
-1
csrc/utils.h
csrc/utils.h
+10
-0
setup.py
setup.py
+1
-1
tests/test_flash_mla_dense_decoding.py
tests/test_flash_mla_dense_decoding.py
+13
-4
tests/test_flash_mla_fp8.py
tests/test_flash_mla_fp8.py
+7
-0
tests/test_flash_mla_qkvfp8.py
tests/test_flash_mla_qkvfp8.py
+6
-1
tests/test_flash_mla_qkvfp8_with_cat.py
tests/test_flash_mla_qkvfp8_with_cat.py
+6
-0
tests/test_flash_mla_sparse_decoding.py
tests/test_flash_mla_sparse_decoding.py
+6
-0
tests/test_flash_mla_sparse_prefill.py
tests/test_flash_mla_sparse_prefill.py
+6
-0
tests/test_flash_mla_with_q_concat.py
tests/test_flash_mla_with_q_concat.py
+6
-0
tests/test_flash_mla_with_q_concat_fp8.py
tests/test_flash_mla_with_q_concat_fp8.py
+7
-0
No files found.
csrc/api/dense_decode.h
View file @
92a05388
...
...
@@ -24,9 +24,9 @@ dense_attn_decode_interface(
)
{
// Check arch
Arch
arch
=
Arch
();
if
(
!
arch
.
is_gfx93x
())
{
TORCH_CHECK
(
false
,
"Dense decode MLA is only supported on gfx936 or gfx938 architecture"
);
}
//
if (!arch.is_gfx93x()) {
//
TORCH_CHECK(false, "Dense decode MLA is only supported on gfx936 or gfx938 architecture");
//
}
// Check data types
auto
q_dtype
=
q
.
dtype
();
...
...
@@ -92,7 +92,7 @@ dense_attn_decode_interface(
KU_CHECK_CONTIGUOUS
(
out
);
KU_CHECK_CONTIGUOUS
(
lse
);
if
(
!
tile_scheduler_metadata
.
has_value
()
&&
((
num_heads_q
<
64
&&
num_heads_k
==
1
)
||
num_heads_k
>
1
))
{
if
(
!
tile_scheduler_metadata
.
has_value
()
&&
(
arch
.
is_gfx928
()
||
(
num_heads_q
<
64
&&
num_heads_k
==
1
)
||
num_heads_k
>
1
))
{
tile_scheduler_metadata
=
torch
::
empty
({
num_sm_parts
,
sizeof
(
DecodingSchedMeta
)
/
4
},
opts
.
dtype
(
torch
::
kInt32
));
num_splits
=
torch
::
empty
({
batch_size
+
1
},
opts
.
dtype
(
torch
::
kInt32
));
KU_CHECK_CONTIGUOUS
(
tile_scheduler_metadata
);
...
...
@@ -159,8 +159,8 @@ dense_attn_decode_interface(
params
.
block_table
=
block_table
.
data_ptr
<
int
>
();
params
.
block_table_batch_stride
=
block_table
.
stride
(
0
);
params
.
page_block_size
=
page_block_size
;
if
((
num_heads_q
<
64
&&
num_heads_k
==
1
)
||
num_heads_k
>
1
)
{
params
.
is_gfx928
=
arch
.
is_gfx928
();
if
((
num_heads_q
<
64
&&
num_heads_k
==
1
)
||
num_heads_k
>
1
||
arch
.
is_gfx928
()
)
{
params
.
tile_scheduler_metadata_ptr
=
(
DecodingSchedMeta
*
)
tile_scheduler_metadata
->
data_ptr
();
params
.
num_sm_parts
=
num_sm_parts
;
params
.
num_splits_ptr
=
num_splits
->
data_ptr
<
int
>
();
...
...
@@ -271,7 +271,7 @@ dense_attn_decode_interface(
params
.
partition_block_nums
};
if
((
num_heads_q
<
64
&&
num_heads_k
==
1
)
||
num_heads_k
>
1
||
params
.
use_split_kv
)
{
if
((
num_heads_q
<
64
&&
num_heads_k
==
1
)
||
num_heads_k
>
1
||
params
.
use_split_kv
||
arch
.
is_gfx928
()
)
{
if
(
q_dtype
==
torch
::
kBFloat16
)
{
gfx9
::
decode
::
run_flash_mla_combine_kernel
<
cutlass
::
bfloat16_t
>
(
combine_params
);
}
else
if
(
q_dtype
==
torch
::
kHalf
)
{
...
...
csrc/gfx93/decode/dense/splitkv_mla.cuh
View file @
92a05388
...
...
@@ -1178,6 +1178,364 @@ compute_attn_1rowblock_splitkv_mla_gfx936(const DenseAttnDecodeParams& params,
}
}
template
<
typename
T
>
__device__
void
compute_attn_1rowblock_splitkv_mla_gfx928
(
const
DenseAttnDecodeParams
&
params
,
const
int
bidb
,
const
int
bidh
,
const
int
m_block
,
const
int
n_split_idx
,
const
int
seqlen_k
,
const
int
n_block_min
,
const
int
n_block_max
,
const
bool
NoSplit
)
{
extern
__shared__
char
shared_memory
[];
using
SharedMemoryPlan
=
typename
T
::
SharedMemoryPlan
;
SharedMemoryPlan
&
plan
=
*
reinterpret_cast
<
SharedMemoryPlan
*>
(
shared_memory
);
const
int
tidx
=
threadIdx
.
x
;
constexpr
int
kBlockM
=
T
::
BLOCK_SIZE_M
;
constexpr
int
kBlockN
=
T
::
PAGE_BLOCK_SIZE
;
constexpr
int
kHeadDim
=
T
::
HEAD_DIM_K
;
constexpr
int
kHeadDimV
=
T
::
HEAD_DIM_V
;
using
Element
=
T
::
InputT
;
using
index_t
=
int64_t
;
const
index_t
row_offset_q
=
bidb
*
params
.
q_batch_stride
+
m_block
*
kBlockM
*
params
.
q_row_stride
+
bidh
*
params
.
q_head_stride
;
Tensor
gQ
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
q_ptr
)
+
row_offset_q
),
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDim
>>
{},
make_stride
(
params
.
q_row_stride
,
_1
{}));
const
index_t
row_offset_k
=
(
bidh
)
*
params
.
k_head_stride
;
Tensor
gK
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
k_ptr
)
+
row_offset_k
),
Shape
<
Int
<
kBlockN
>
,
Int
<
kHeadDim
>>
{},
make_stride
(
params
.
k_row_stride
,
_1
{}));
Tensor
gV
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
k_ptr
)
+
row_offset_k
),
Shape
<
Int
<
kBlockN
>
,
Int
<
kHeadDimV
>>
{},
make_stride
(
params
.
k_row_stride
,
_1
{}));
Tensor
sQ
=
make_tensor
(
make_smem_ptr
(
plan
.
smem_q
.
data
()),
typename
T
::
SmemLayoutQ
{});
Tensor
sV
=
make_tensor
(
make_smem_ptr
(
plan
.
smem_v
.
data
()),
typename
T
::
SmemLayoutV
{});
Tensor
sK
=
make_tensor
(
make_smem_ptr
(
plan
.
smem_v
.
data
()),
typename
T
::
SmemLayoutK
{});
Tensor
sP
=
make_tensor
(
make_smem_ptr
(
plan
.
smem_p
.
data
()),
typename
T
::
SmemLayoutP
{});
Tensor
sVt
=
make_tensor
(
sV
.
data
(),
typename
T
::
SmemLayoutVtransposed
{});
Tensor
sVtNoSwizzle
=
make_tensor
(
sV
.
data
(),
typename
T
::
SmemLayoutVtransposedNoSwizzle
{});
Tensor
sRow_max_reduce_buffer
=
make_tensor
(
make_smem_ptr
(
plan
.
smem_row_max
.
data
()),
typename
T
::
SmemLayoutRow
{});
Tensor
sRow_sum_reduce_buffer
=
make_tensor
(
make_smem_ptr
(
plan
.
smem_row_sum
.
data
()),
typename
T
::
SmemLayoutRow
{});
using
MMA_Atom_Arch
=
std
::
conditional_t
<
std
::
is_same_v
<
Element
,
cutlass
::
half_t
>
,
MMA_Atom
<
GFX928_16x16x32_F32F16F16F32_NT
>
,
MMA_Atom
<
GFX928_16x16x32_F32BF16BF16F32_NT
>
>
;
using
ValLayoutMNK
=
Layout
<
Shape
<
_1
,
_1
,
_1
>>
;
using
TiledMma
=
TiledMMA
<
MMA_Atom_Arch
,
Layout
<
Shape
<
_1
,
Int
<
4
>
,
_1
>>
,
// 1x4x1 or 1x8x1 thread group
ValLayoutMNK
>
;
TiledMma
tiled_mma
;
auto
thr_mma
=
tiled_mma
.
get_thread_slice
(
tidx
);
typename
T
::
TiledMma_O
tiled_mma_o
;
auto
thr_mma_o
=
tiled_mma_o
.
get_thread_slice
(
tidx
);
#if 1
typename
T
::
GmemTiledCopyQ
gmem_tiled_copy_Q
;
auto
gmem_thr_copy_Q
=
gmem_tiled_copy_Q
.
get_thread_slice
(
tidx
);
Tensor
tQgQ
=
gmem_thr_copy_Q
.
partition_S
(
gQ
);
Tensor
tQsQ
=
gmem_thr_copy_Q
.
partition_D
(
sQ
);
Tensor
cQ
=
make_identity_tensor
(
make_shape
(
size
<
0
>
(
gQ
),
size
<
1
>
(
gQ
)));
Tensor
tQcQ
=
gmem_thr_copy_Q
.
partition_S
(
cQ
);
Tensor
tQpQ
=
make_tensor
<
bool
>
(
make_shape
(
size
<
2
>
(
tQgQ
)));
if
(
tidx
<
128
)
flash
::
copy
<
/*Is_even_MN=*/
false
,
/*Is_even_K=*/
true
,
false
>
(
gmem_tiled_copy_Q
,
tQgQ
,
tQsQ
,
tQcQ
,
tQpQ
,
params
.
q_seq_per_hk
-
m_block
*
kBlockM
);
__syncthreads
();
auto
smem_tiled_copy_Q
=
make_tiled_copy_A
(
Copy_Atom
<
DefaultCopy
,
Element
>
{},
tiled_mma
);
auto
smem_thr_copy_Q
=
smem_tiled_copy_Q
.
get_thread_slice
(
tidx
);
Tensor
tSsQ
=
smem_thr_copy_Q
.
partition_S
(
sQ
);
Tensor
tSrQ
=
thr_mma
.
partition_fragment_A
(
sQ
);
Tensor
tSrQ_copy_view
=
smem_thr_copy_Q
.
retile_D
(
tSrQ
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
,
tSrQ_copy_view
);
__syncthreads
();
#else
#endif
typename
T
::
GmemTiledCopyK
gmem_tiled_copy_K
;
auto
gmem_thr_copy_K
=
gmem_tiled_copy_K
.
get_thread_slice
(
tidx
);
Tensor
tKgK
=
gmem_thr_copy_K
.
partition_S
(
gK
);
Tensor
tKsK
=
gmem_thr_copy_K
.
partition_D
(
sK
);
Tensor
cK
=
make_identity_tensor
(
make_shape
(
size
<
0
>
(
gK
),
size
<
1
>
(
gK
)));
Tensor
tKcK
=
gmem_thr_copy_K
.
partition_S
(
cK
);
Tensor
tKpK
=
make_tensor
<
bool
>
(
make_shape
(
size
<
2
>
(
tKgK
)));
auto
smem_tiled_copy_K
=
make_tiled_copy_B
(
Copy_Atom
<
DefaultCopy
,
Element
>
{},
tiled_mma
);
auto
smem_thr_copy_K
=
smem_tiled_copy_K
.
get_thread_slice
(
tidx
);
Tensor
tSgK
=
smem_thr_copy_K
.
partition_S
(
gK
);
Tensor
tSsK
=
smem_thr_copy_K
.
partition_S
(
sK
);
Tensor
tSrK
=
thr_mma
.
partition_fragment_B
(
sK
);
Tensor
tKcK_smem
=
smem_thr_copy_K
.
partition_S
(
cK
);
Tensor
tKpK_smem
=
make_tensor
<
bool
>
(
make_shape
(
size
<
2
>
(
tSgK
)));
Tensor
tSrK_smem
=
thr_mma
.
partition_fragment_B
(
gK
);
typename
T
::
GmemTiledCopyV
gmem_tiled_copy_V
;
auto
gmem_thr_copy_V
=
gmem_tiled_copy_V
.
get_thread_slice
(
tidx
);
Tensor
tVgV
=
gmem_thr_copy_V
.
partition_S
(
gV
);
Tensor
tVsV
=
gmem_thr_copy_V
.
partition_D
(
sV
);
Tensor
cV
=
make_identity_tensor
(
make_shape
(
size
<
0
>
(
gV
),
size
<
1
>
(
gV
)));
Tensor
tVcV
=
gmem_thr_copy_V
.
partition_S
(
cV
);
Tensor
tVpV
=
make_tensor
<
bool
>
(
make_shape
(
size
<
2
>
(
tVgV
)));
auto
smem_tiled_copy_V
=
make_tiled_copy_B
(
Copy_Atom
<
GFX928_DS_READ_DS_M32x16_B16
,
Element
>
{},
tiled_mma_o
);
auto
smem_thr_copy_V
=
smem_tiled_copy_V
.
get_thread_slice
(
tidx
);
Tensor
tOsVt
=
smem_thr_copy_V
.
partition_S
(
sVt
);
Tensor
tOrVt
=
thr_mma_o
.
partition_fragment_B
(
sVtNoSwizzle
);
constexpr
int
n_masking_steps
=
!
T
::
Is_causal
?
1
:
cute
::
ceil_div
(
kBlockM
,
kBlockN
)
+
1
;
const
int
*
block_table
=
params
.
block_table
+
bidb
*
params
.
block_table_batch_stride
;
int
n_block
=
n_block_max
-
1
;
// constexpr static int k0_lds_loops = 0;
constexpr
static
int
k0_lds_loops
=
16
;
constexpr
static
int
k0_loops
=
size
<
2
>
(
tSrK_smem
);
constexpr
static
int
k1_loops
=
size
<
2
>
(
tOrVt
);
Tensor
acc_o
=
partition_fragment_C
(
tiled_mma_o
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDimV
>>
{});
clear
(
acc_o
);
flash
::
Softmax
<
size
<
1
>
(
acc_o
)
>
softmax
;
int
cur_block_table
;
index_t
offset_k
;
constexpr
static
int
BUFFER_SIZE
=
4
;
uint128_t
buffer
[
BUFFER_SIZE
];
if
(
n_block
>=
n_block_min
)
{
cur_block_table
=
block_table
[
n_block
];
offset_k
=
cur_block_table
*
params
.
k_batch_stride
;
gK
.
data
()
=
gK
.
data
()
+
(
offset_k
);
flash
::
buffer_load_copy
<
false
,
true
,
false
>
(
gK
,
buffer
[
0
],
0
,
params
.
k_row_stride
,
offset_k
,
seqlen_k
-
n_block
*
kBlockN
);
flash
::
buffer_load_copy
<
false
,
true
,
false
>
(
gK
,
buffer
[
1
],
1
,
params
.
k_row_stride
,
offset_k
,
seqlen_k
-
n_block
*
kBlockN
);
flash
::
buffer_load_copy
<
false
,
true
,
false
>
(
gK
,
buffer
[
2
],
2
,
params
.
k_row_stride
,
offset_k
,
seqlen_k
-
n_block
*
kBlockN
);
}
#if 1
#pragma unroll
for
(
int
masking_step
=
n_masking_steps
;
n_block
>=
n_block_min
;
--
masking_step
,
--
n_block
)
{
asm
volatile
(
"s_barrier
\n\t
"
);
Tensor
acc_s
=
partition_fragment_C
(
tiled_mma
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
clear
(
acc_s
);
Tensor
tSrK_copy_view
=
smem_thr_copy_K
.
retile_D
(
tSrK
);
// 计算0~11
#if 1
#pragma unroll
for
(
int
i
=
0
;
i
<
k0_lds_loops
-
BUFFER_SIZE
+
1
;
i
++
)
{
// asm volatile("s_waitcnt vmcnt(3) \n\t \n\t");
flash
::
asm_ds_write
(
buffer
[
i
%
BUFFER_SIZE
],
tKsK
,
i
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
i
),
tSrK_copy_view
(
_
,
_
,
i
));
flash
::
buffer_load_copy
<
false
,
true
,
false
>
(
gK
,
buffer
[(
i
+
BUFFER_SIZE
-
1
)
%
BUFFER_SIZE
],
i
+
BUFFER_SIZE
-
1
,
params
.
k_row_stride
,
offset_k
,
seqlen_k
-
n_block
*
kBlockN
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
i
),
tSrK
(
_
,
_
,
i
),
acc_s
);
// asm volatile("s_barrier\n\t");
}
// asm volatile("s_barrier\n\t");
#endif
#if 0
#else
// 计算 13-15
const
int
k_idx
=
k0_lds_loops
-
BUFFER_SIZE
+
1
;
flash
::
asm_ds_write
(
buffer
[
k_idx
%
BUFFER_SIZE
],
tKsK
,
k_idx
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
flash
::
asm_ds_write
(
buffer
[(
k_idx
+
1
)
%
BUFFER_SIZE
],
tKsK
,
k_idx
+
1
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
+
1
),
tSrK_copy_view
(
_
,
_
,
k_idx
+
1
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
+
1
),
tSrK
(
_
,
_
,
k_idx
+
1
),
acc_s
);
flash
::
asm_ds_write
(
buffer
[(
k_idx
+
2
)
%
BUFFER_SIZE
],
tKsK
,
k_idx
+
2
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
+
2
),
tSrK_copy_view
(
_
,
_
,
k_idx
+
2
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
+
2
),
tSrK
(
_
,
_
,
k_idx
+
2
),
acc_s
);
// asm volatile("s_barrier\n\t");
// 读取16-17
flash
::
buffer_load_copy
<
false
,
true
,
true
>
(
gK
,
buffer
[
1
],
16
,
params
.
k_row_stride
,
offset_k
,
seqlen_k
-
n_block
*
kBlockN
);
flash
::
buffer_load_copy
<
false
,
true
,
true
>
(
gK
,
buffer
[
2
],
17
,
params
.
k_row_stride
,
offset_k
,
seqlen_k
-
n_block
*
kBlockN
);
flash
::
buffer_to_tensor
(
buffer
[
1
],
tSrK_smem
,
16
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
16
),
tSrK_smem
(
_
,
_
,
16
),
acc_s
);
flash
::
buffer_to_tensor
(
buffer
[
2
],
tSrK_smem
,
17
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
17
),
tSrK_smem
(
_
,
_
,
17
),
acc_s
);
asm
volatile
(
"s_barrier
\n\t
"
);
#endif
const
bool
is_masking_step
=
masking_step
>
0
;
const
bool
is_first_masking_step
=
masking_step
==
n_masking_steps
;
if
(
is_masking_step
)
{
Tensor
cS
=
make_identity_tensor
(
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
Tensor
tScS
=
thr_mma
.
partition_C
(
cS
);
for
(
int
i
=
0
;
i
<
size
(
acc_s
);
++
i
)
{
if
constexpr
(
!
T
::
Is_causal
)
{
if
(
int
(
get
<
1
>
(
tScS
(
i
)))
>=
int
(
seqlen_k
-
n_block
*
kBlockN
))
acc_s
(
i
)
=
-
INFINITY
;
}
else
{
// Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
// col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
int
row
=
int
(
get
<
0
>
(
tScS
(
i
)));
int
col_limit_right
=
seqlen_k
-
1
-
n_block
*
kBlockN
-
(
params
.
q_seq_per_hk
-
1
-
(
m_block
*
kBlockM
+
row
))
/
params
.
q_head_per_hk
;
if
(
int
(
get
<
1
>
(
tScS
(
i
)))
>
col_limit_right
)
acc_s
(
i
)
=
-
INFINITY
;
}
}
}
// We have key_padding_mask so we'll need to Check_inf
is_first_masking_step
?
softmax
.
template
softmax_rescale_o
<
/*Is_first=*/
true
,
/*Check_inf=*/
T
::
Is_causal
>(
acc_s
,
acc_o
,
sRow_max_reduce_buffer
,
params
.
scale_softmax_log2
)
:
is_masking_step
?
softmax
.
template
softmax_rescale_o
<
/*Is_first=*/
false
,
/*Check_inf=*/
T
::
Is_causal
>(
acc_s
,
acc_o
,
sRow_max_reduce_buffer
,
params
.
scale_softmax_log2
)
:
softmax
.
template
softmax_rescale_o
<
/*Is_first=*/
false
,
/*Check_inf=*//*Is_local=*/
false
>(
acc_s
,
acc_o
,
sRow_max_reduce_buffer
,
params
.
scale_softmax_log2
);
Tensor
rP
=
flash
::
convert_type
<
Element
>
(
acc_s
);
Tensor
tOrP
=
flash
::
convert_layout_acc_Aregs
(
tiled_mma
,
tiled_mma_o
,
rP
,
sP
);
__syncthreads
();
#if 1
// 第15块已经读取到了buffer[3]中
flash
::
asm_ds_write
(
buffer
[
3
],
tVsV
,
15
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
#endif
gK
.
data
()
=
gK
.
data
()
+
(
-
offset_k
);
if
(
n_block
>
n_block_min
)
{
cur_block_table
=
block_table
[
n_block
-
1
];
offset_k
=
cur_block_table
*
params
.
k_batch_stride
;
gK
.
data
()
=
gK
.
data
()
+
(
offset_k
);
flash
::
buffer_load_copy
<
true
,
true
,
false
>
(
gK
,
buffer
[
0
],
0
,
params
.
k_row_stride
,
offset_k
);
flash
::
buffer_load_copy
<
true
,
true
,
false
>
(
gK
,
buffer
[
1
],
1
,
params
.
k_row_stride
,
offset_k
);
flash
::
buffer_load_copy
<
true
,
true
,
false
>
(
gK
,
buffer
[
2
],
2
,
params
.
k_row_stride
,
offset_k
);
}
Tensor
tOrVt_copy_view
=
smem_thr_copy_V
.
retile_D
(
tOrVt
);
#pragma unroll
for
(
int
i
=
0
;
i
<
k1_loops
;
i
++
)
{
cute
::
copy
(
smem_tiled_copy_V
,
tOsVt
(
_
,
_
,
i
),
tOrVt_copy_view
(
_
,
_
,
i
));
cute
::
gemm
(
tiled_mma_o
,
tOrP
(
_
,
_
,
i
),
tOrVt
(
_
,
_
,
i
),
acc_o
);
}
asm
volatile
(
" s_barrier
\n\t
"
);
}
#endif
using
ElementAccum
=
float
;
if
(
NoSplit
)
{
using
ElementO
=
Element
;
const
index_t
row_offset_o
=
bidb
*
params
.
o_batch_stride
+
m_block
*
kBlockM
*
params
.
o_row_stride
+
bidh
*
params
.
o_head_stride
;
const
index_t
row_offset_lse
=
(
bidb
*
params
.
h_k
+
bidh
)
*
params
.
q_seq_per_hk
+
m_block
*
kBlockM
;
constexpr
bool
Split
=
false
;
Tensor
gOaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementO
*>
(
Split
?
params
.
oaccum_ptr
:
params
.
o_ptr
)
+
(
row_offset_o
)),
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDimV
>>
{},
make_stride
(
Split
?
kHeadDimV
:
params
.
o_row_stride
,
_1
{}));
Tensor
lse
=
softmax
.
template
normalize_softmax_lse
<
/*Is_dropout=*/
false
,
Split
>(
acc_o
,
sRow_sum_reduce_buffer
,
params
.
scale_softmax
);
// if (block0() && tidx < 64)
// {
// printf(" %.3f %.3f \n", float(acc_o(0)), float(acc_o(1)));
// }
Tensor
gLSEaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementAccum
*>
(
Split
?
params
.
softmax_lseaccum_ptr
:
params
.
softmax_lse_ptr
)
+
(
row_offset_lse
)),
Shape
<
Int
<
kBlockM
>>
{},
Stride
<
_1
>
{});
Tensor
caccO
=
make_identity_tensor
(
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDimV
>>
{});
// (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor
taccOcO
=
thr_mma_o
.
partition_C
(
caccO
);
Tensor
rO
=
flash
::
convert_type
<
ElementO
>
(
acc_o
);
if
(
get
<
1
>
(
taccOcO
(
0
))
==
0
)
{
#pragma unroll
for
(
int
mi
=
0
;
mi
<
size
(
lse
);
++
mi
)
{
const
int
row
=
get
<
0
>
(
taccOcO
(
0
,
mi
,
0
));
if
(
row
<
params
.
q_seq_per_hk
-
m_block
*
kBlockM
)
{
gLSEaccum
(
row
)
=
lse
(
mi
);
}
}
}
{
// using result_type = cutlass::Array<bfloat16_t, 2>;
// int tidx = threadIdx.x;
int
col
=
0
;
int
warpid
=
tidx
/
64
;
for
(
int
m
=
0
;
m
<
1
;
m
++
)
{
const
int
row
=
tidx
%
16
;
if
(
row
<
params
.
q_seq_per_hk
-
m_block
*
kBlockM
)
{
for
(
int
n
=
0
;
n
<
size
<
2
>
(
acc_o
);
n
++
)
{
col
=
(
tidx
%
64
/
16
)
+
warpid
*
32
+
n
*
128
;
for
(
int
ei
=
0
;
ei
<
8
;
ei
++
)
{
gOaccum
(
row
,
col
)
=
rO
(
ei
,
m
,
n
);
col
+=
4
;
}
}
}
}
}
}
else
{
using
ElementO
=
float
;
int
split_idx
=
params
.
num_splits_ptr
[
bidb
]
+
n_split_idx
;
constexpr
bool
Split
=
true
;
const
index_t
row_offset_oaccum
=
((
split_idx
*
params
.
h_k
+
bidh
)
*
params
.
q_seq_per_hk
+
m_block
*
kBlockM
)
*
T
::
HEAD_DIM_V
;
// (BLOCK_SIZE_M, HEAD_DIM_V) : (HEAD_DIM_V, 1)
const
index_t
row_offset_lseaccum
=
(
split_idx
*
params
.
h_k
+
bidh
)
*
params
.
q_seq_per_hk
+
m_block
*
kBlockM
;
Tensor
gOaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementO
*>
(
Split
?
params
.
oaccum_ptr
:
params
.
o_ptr
)
+
(
row_offset_oaccum
)),
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDimV
>>
{},
make_stride
(
Split
?
kHeadDimV
:
params
.
o_row_stride
,
_1
{}));
Tensor
gLSEaccum
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
ElementAccum
*>
(
Split
?
params
.
softmax_lseaccum_ptr
:
params
.
softmax_lse_ptr
)
+
(
row_offset_lseaccum
)),
Shape
<
Int
<
kBlockM
>>
{},
Stride
<
_1
>
{});
Tensor
lse
=
softmax
.
template
normalize_softmax_lse
<
/*Is_dropout=*/
false
,
Split
>(
acc_o
,
sRow_sum_reduce_buffer
,
params
.
scale_softmax
);
Tensor
caccO
=
make_identity_tensor
(
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDimV
>>
{});
// (BLK_M,BLK_K) -> (blk_m,blk_k)
Tensor
taccOcO
=
thr_mma_o
.
partition_C
(
caccO
);
if
(
get
<
1
>
(
taccOcO
(
0
))
==
0
)
{
#pragma unroll
for
(
int
mi
=
0
;
mi
<
size
(
lse
);
++
mi
)
{
const
int
row
=
get
<
0
>
(
taccOcO
(
0
,
mi
,
0
));
if
(
row
<
params
.
q_seq_per_hk
-
m_block
*
kBlockM
)
{
gLSEaccum
(
row
)
=
lse
(
mi
);
}
}
}
{
// using result_type = cutlass::Array<bfloat16_t, 2>;
// int tidx = threadIdx.x;
int
col
=
0
;
int
warpid
=
tidx
/
64
;
for
(
int
m
=
0
;
m
<
1
;
m
++
)
{
const
int
row
=
tidx
%
16
;
if
(
row
<
params
.
q_seq_per_hk
-
m_block
*
kBlockM
)
{
for
(
int
n
=
0
;
n
<
size
<
2
>
(
acc_o
);
n
++
)
{
col
=
(
tidx
%
64
/
16
)
+
warpid
*
32
+
n
*
128
;
for
(
int
ei
=
0
;
ei
<
8
;
ei
++
)
{
gOaccum
(
row
,
col
)
=
acc_o
(
ei
,
m
,
n
);
col
+=
4
;
}
}
}
}
}
}
}
template
<
typename
T
>
__global__
void
__launch_bounds__
(
T
::
NUM_THREADS
,
1
)
...
...
@@ -1199,9 +1557,15 @@ flash_fwd_splitkv_mla_kernel(const DenseAttnDecodeParams params) {
if
(
batch_idx
>
sched_meta
.
begin_req_idx
)
{
__syncthreads
();
}
#if defined(__gfx928__)
compute_attn_1rowblock_splitkv_mla_gfx928
<
T
>
(
params
,
batch_idx
,
bidh
,
m_block
,
n_split_idx
,
seqlen_k
,
start_block_idx
,
end_block_idx
,
is_no_split
);
#else
compute_attn_1rowblock_splitkv_mla_gfx936
<
T
>
(
params
,
batch_idx
,
bidh
,
m_block
,
n_split_idx
,
seqlen_k
,
start_block_idx
,
end_block_idx
,
is_no_split
);
#endif
}
}
...
...
@@ -1209,6 +1573,7 @@ flash_fwd_splitkv_mla_kernel(const DenseAttnDecodeParams params) {
template
<
typename
T
,
bool
use_split_kv
=
false
>
__global__
void
__launch_bounds__
(
T
::
NUM_THREADS
,
1
)
flash_fwd_splitkv_mla_block_m_64_kernel
(
const
DenseAttnDecodeParams
params
)
{
#if defined(__gfx936__) || defined(__gfx938__)
constexpr
int
kBlockN
=
T
::
PAGE_BLOCK_SIZE
;
const
int
m_block
=
blockIdx
.
x
;
const
int
bidh
=
blockIdx
.
y
;
...
...
@@ -1908,7 +2273,7 @@ flash_fwd_splitkv_mla_block_m_64_kernel(const DenseAttnDecodeParams params) {
}
}
}
#endif
}
...
...
@@ -1918,7 +2283,7 @@ void run_flash_splitkv_mla_kernel(DenseAttnDecodeParams ¶ms) {
FLASH_ASSERT
(
params
.
d_v
==
Config
::
HEAD_DIM_V
);
BOOL_SWITCH
(
params
.
is_causal
,
Is_causal
,
[
&
]
{
if
(
params
.
h_q
>=
64
&&
params
.
h_k
==
1
)
{
if
(
params
.
h_q
>=
64
&&
params
.
h_k
==
1
&&
!
params
.
is_gfx928
)
{
using
T
=
Traits_Block_M_64
<
InputT
,
Is_causal
>
;
constexpr
size_t
smem_size
=
16384
+
4096
;
if
(
params
.
use_split_kv
)
...
...
csrc/gfx93/decode/dense/traits.h
View file @
92a05388
...
...
@@ -102,6 +102,14 @@ struct Traits {
GmemLayoutAtomQ
{},
Layout
<
Shape
<
_1
,
_8
>>
{}));
using
GmemLayoutAtomK
=
Layout
<
Shape
<
_64
,
_4
>
,
Stride
<
_4
,
_1
>>
;
using
GmemTiledCopyK
=
decltype
(
make_tiled_copy
(
Copy_Atom
<
DefaultCopy
,
Element
>
{},
GmemLayoutAtomK
{},
Layout
<
Shape
<
_1
,
_8
>>
{}));
using
GmemTiledCopyV
=
GmemTiledCopyK
;
struct
SharedMemoryPlan
{
...
...
csrc/gfx93/decode/sparse_fp8/splitkv_mla.cuh
View file @
92a05388
...
...
@@ -969,7 +969,9 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::devfunc(const SparseAttnD
template
<
typename
Kernel
>
__global__
void
__launch_bounds__
(
Kernel
::
NUM_THREADS
,
1
)
flash_fwd_splitkv_mla_fp8_sparse_kernel
(
const
SparseAttnDecodeParams
params
)
{
#if defined(__gfx936__) || defined(__gfx938__)
Kernel
::
devfunc
(
params
);
#endif
}
template
<
ModelType
MODEL_TYPE
,
int
NUM_HEADS
>
...
...
csrc/gfx93/prefill/sparse/phase1.cuh
View file @
92a05388
...
...
@@ -1287,9 +1287,9 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
template
<
typename
Kernel
>
__global__
void
__launch_bounds__
(
Kernel
::
NUM_THREADS
,
1
)
sparse_attn_fwd_kernel
(
const
SparseAttnFwdParams
params
)
{
//
#if defined(__gfx936__)
#if defined(__gfx936__)
|| defined(__gfx938__)
Kernel
::
devfunc
(
params
);
//
#endif
#endif
}
template
<
int
D_QK
,
bool
HAVE_TOPK_LENGTH
>
...
...
csrc/params.h
View file @
92a05388
...
...
@@ -61,7 +61,7 @@ struct DenseAttnDecodeParams { // TODO Change name to DenseAttnDecodeParams
bool
use_split_kv
;
int
partition_block_nums
;
bool
is_gfx928
;
};
struct
DenseAttnDecodeParams_fp8
:
public
DenseAttnDecodeParams
{
...
...
csrc/utils.h
View file @
92a05388
...
...
@@ -621,7 +621,17 @@ lds_direct_copy_for_prefill_sparse_mla(
:
);
}
template
<
class
SrcEngine
,
class
SrcLayout
>
CUTE_HOST_DEVICE
void
asm_ds_write
(
const
uint128_t
&
src
,
Tensor
<
SrcEngine
,
SrcLayout
>
&
dst
,
int
k_idx
)
{
uint128_t
*
d
=
reinterpret_cast
<
uint128_t
*>
(
&
dst
(
0
,
0
,
k_idx
));
d
[
0
]
=
src
;
}
template
<
bool
Is_even_MN
=
true
,
bool
Is_even_K
=
true
,
...
...
setup.py
View file @
92a05388
...
...
@@ -31,7 +31,7 @@ def get_features_args():
def
get_arch_flags
():
arch_flags
=
[]
arch_flags
.
append
(
"--offload-arch=gfx938;gfx936"
)
arch_flags
.
append
(
"--offload-arch=gfx938;gfx936
;gfx928
"
)
return
arch_flags
# def get_nvcc_thread_args():
...
...
tests/test_flash_mla_dense_decoding.py
View file @
92a05388
...
...
@@ -140,6 +140,9 @@ def reference_torch(
out_ref
=
out_ref
.
to
(
q
.
dtype
)
return
out_ref
,
lse_ref
def
get_gcn_arch_name
()
->
str
:
GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
return
GPU_ARCH
.
split
(
':'
)[
0
]
@
torch
.
inference_mode
()
def
test_flash_mla
(
t
:
TestParam
):
...
...
@@ -166,6 +169,12 @@ def test_flash_mla(t: TestParam):
out_ans
,
lse_ans
=
run_flash_mla
()
out_ref
,
lse_ref
=
reference_torch
(
cache_seqlens
,
block_table
,
q
,
blocked_k
,
t
.
dv
,
t
.
is_causal
)
if
get_gcn_arch_name
()
==
"gfx928"
:
lse_abs_diff
=
(
lse_ans
-
lse_ref
).
max
().
abs
().
item
()
out_abs_diff
=
(
out_ref
-
out_ans
).
max
().
abs
().
item
()
print
(
"lse_abs_diff "
,
lse_abs_diff
,
out_abs_diff
)
assert
out_abs_diff
<=
4e-3
else
:
is_correct
=
True
is_correct
&=
kk
.
check_is_allclose
(
"out"
,
out_ans
,
out_ref
,
abs_tol
=
8e-4
,
rel_tol
=
2.01
/
128
,
cos_diff_tol
=
5e-6
)
is_correct
&=
kk
.
check_is_allclose
(
"lse"
,
lse_ans
,
lse_ref
,
abs_tol
=
1e-6
,
rel_tol
=
8.01
/
65536
)
...
...
tests/test_flash_mla_fp8.py
View file @
92a05388
...
...
@@ -214,7 +214,14 @@ def main(torch_dtype, is_prof=False):
for
varlen
in
[
False
]:
test_flash_mla_fp8_e5m2
(
b
,
s_q
,
s
,
h_q
,
h_kv
,
d
,
dv
,
causal
,
varlen
)
def
get_gcn_arch_name
()
->
str
:
GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
return
GPU_ARCH
.
split
(
':'
)[
0
]
if
__name__
==
"__main__"
:
if
get_gcn_arch_name
()
==
"gfx928"
:
print
(
"[WARNING] gfx928 architecture is not supported."
)
exit
(
0
)
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--dtype"
,
...
...
tests/test_flash_mla_qkvfp8.py
View file @
92a05388
...
...
@@ -175,9 +175,14 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=Fa
f
"
{
t
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
t
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
t
:.
0
f
}
GB/s"
)
def
get_gcn_arch_name
()
->
str
:
GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
return
GPU_ARCH
.
split
(
':'
)[
0
]
def
main
(
torch_dtype
,
is_prof
=
False
):
if
get_gcn_arch_name
()
!=
"gfx938"
:
print
(
"[WARNING] The architecture is not supported."
)
exit
(
0
)
device
=
torch
.
device
(
"cuda:0"
)
init_dtype
=
torch
.
bfloat16
if
torch_dtype
==
torch
.
float8_e4m3fn
else
torch_dtype
torch
.
set_default_dtype
(
init_dtype
)
...
...
tests/test_flash_mla_qkvfp8_with_cat.py
View file @
92a05388
...
...
@@ -252,8 +252,14 @@ def main(torch_dtype, is_prof=False):
# '''
def
get_gcn_arch_name
()
->
str
:
GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
return
GPU_ARCH
.
split
(
':'
)[
0
]
if
__name__
==
"__main__"
:
if
get_gcn_arch_name
()
!=
"gfx938"
:
print
(
"[WARNING] The architecture is not supported."
)
exit
(
0
)
parser
=
argparse
.
ArgumentParser
()
parser
.
add_argument
(
"--dtype"
,
...
...
tests/test_flash_mla_sparse_decoding.py
View file @
92a05388
...
...
@@ -232,8 +232,14 @@ def test_flash_mla(p: TestParam) -> Result:
performance_result
.
is_correct
=
is_correct
return
performance_result
def
get_gcn_arch_name
()
->
str
:
GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
return
GPU_ARCH
.
split
(
':'
)[
0
]
def
main
():
if
get_gcn_arch_name
()
==
"gfx928"
:
print
(
"[WARNING] gfx928 architecture is not supported."
)
exit
(
0
)
dtype
=
torch
.
bfloat16
device
=
torch
.
device
(
"cuda:0"
)
torch
.
set_default_dtype
(
dtype
)
...
...
tests/test_flash_mla_sparse_prefill.py
View file @
92a05388
...
...
@@ -51,8 +51,14 @@ def run_test(p: TestParam) -> bool:
else
:
return
True
def
get_gcn_arch_name
()
->
str
:
GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
return
GPU_ARCH
.
split
(
':'
)[
0
]
if
__name__
==
'__main__'
:
if
get_gcn_arch_name
()
==
"gfx928"
:
print
(
"[WARNING] gfx928 architecture is not supported."
)
exit
(
0
)
device
=
torch
.
device
(
"cuda:0"
)
torch
.
set_default_dtype
(
torch
.
bfloat16
)
torch
.
set_default_device
(
device
)
...
...
tests/test_flash_mla_with_q_concat.py
View file @
92a05388
...
...
@@ -141,8 +141,14 @@ def test_flash_mla(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, is_prof=Fa
f
"
{
t
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
t
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
t
:.
0
f
}
GB/s"
)
def
get_gcn_arch_name
()
->
str
:
GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
return
GPU_ARCH
.
split
(
':'
)[
0
]
def
main
(
torch_dtype
,
is_prof
=
False
):
if
get_gcn_arch_name
()
==
"gfx928"
:
print
(
"[WARNING] gfx928 architecture is not supported."
)
exit
(
0
)
device
=
torch
.
device
(
"cuda:0"
)
torch
.
set_default_dtype
(
torch_dtype
)
torch
.
set_default_device
(
device
)
...
...
tests/test_flash_mla_with_q_concat_fp8.py
View file @
92a05388
...
...
@@ -168,7 +168,14 @@ def test_flash_mla_fp8_e5m2(b, s_q, mean_sk, h_q, h_kv, d, dv, causal, varlen, i
f
"
{
t
:.
3
f
}
ms,
{
FLOPS
/
10
**
9
/
t
:.
0
f
}
TFLOPS,
{
bytes
/
10
**
6
/
t
:.
0
f
}
GB/s"
)
def
get_gcn_arch_name
()
->
str
:
GPU_ARCH
=
torch
.
cuda
.
get_device_properties
(
"cuda"
).
gcnArchName
return
GPU_ARCH
.
split
(
':'
)[
0
]
def
main
(
torch_dtype
,
is_prof
=
False
):
if
get_gcn_arch_name
()
==
"gfx928"
:
print
(
"[WARNING] gfx928 architecture is not supported."
)
exit
(
0
)
device
=
torch
.
device
(
"cuda:0"
)
torch
.
set_default_dtype
(
torch_dtype
)
torch
.
set_default_device
(
device
)
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment