Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
2ff340aa
Commit
2ff340aa
authored
Feb 28, 2026
by
zhanghj2
Browse files
优化tp8 nmz 代码
parent
4d897ed1
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
56 additions
and
29 deletions
+56
-29
csrc/extension/flash_fwd_mla_kernel_fp8.h
csrc/extension/flash_fwd_mla_kernel_fp8.h
+56
-29
No files found.
csrc/extension/flash_fwd_mla_kernel_fp8.h
View file @
2ff340aa
...
@@ -547,6 +547,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
...
@@ -547,6 +547,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
const
int
n_split_idx
,
const
int
seqlen_k
,
const
int
n_split_idx
,
const
int
seqlen_k
,
const
int
n_block_min
,
const
int
n_block_max
,
const
bool
NoSplit
,
const
int
n_block_min
,
const
int
n_block_max
,
const
bool
NoSplit
,
SharedStorage
&
shared_storage
,
const
float
descale_k
,
const
float
scale_softmax
,
const
float
scale_softmax_log2
)
{
SharedStorage
&
shared_storage
,
const
float
descale_k
,
const
float
scale_softmax
,
const
float
scale_softmax_log2
)
{
if
(
n_block_max
<=
n_block_min
)
{
return
;
}
constexpr
int
kBlockM
=
Kernel_traits
::
kBlockM
;
constexpr
int
kBlockM
=
Kernel_traits
::
kBlockM
;
constexpr
int
kBlockN
=
Kernel_traits
::
kBlockN
;
constexpr
int
kBlockN
=
Kernel_traits
::
kBlockN
;
constexpr
int
kHeadDim
=
Kernel_traits
::
kHeadDim
;
constexpr
int
kHeadDim
=
Kernel_traits
::
kHeadDim
;
...
@@ -872,16 +875,23 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
...
@@ -872,16 +875,23 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
c3_0
.
x
=
0.0
f
;
c3_0
.
y
=
0.0
f
;
c3_0
.
z
=
0.0
f
;
c3_0
.
w
=
0.0
f
;
c3_0
.
x
=
0.0
f
;
c3_0
.
y
=
0.0
f
;
c3_0
.
z
=
0.0
f
;
c3_0
.
w
=
0.0
f
;
c3_1
.
x
=
0.0
f
;
c3_1
.
y
=
0.0
f
;
c3_1
.
z
=
0.0
f
;
c3_1
.
w
=
0.0
f
;
c3_1
.
x
=
0.0
f
;
c3_1
.
y
=
0.0
f
;
c3_1
.
z
=
0.0
f
;
c3_1
.
w
=
0.0
f
;
// #pragma unroll
extern
__shared__
char
shared_memory
[];
for
(
int
masking_step
=
0
;
n_block
>=
n_block_min
;
++
masking_step
,
--
n_block
)
{
struct
IsMaskBlock
{};
struct
IsFirstMaskBlock
{};
struct
IsNoMaskBlock
{};
auto
process_one_block
=
[
&
]
(
int
block_idx
,
auto
is_mask_block_t
)
{
static
constexpr
bool
IS_MASK_BLOCK
=
std
::
is_same_v
<
decltype
(
is_mask_block_t
),
IsNoMaskBlock
>
;
static
constexpr
bool
IS_FIRST_MASK_BLOCK
=
std
::
is_same_v
<
decltype
(
is_mask_block_t
),
IsFirstMaskBlock
>
;
static
constexpr
bool
IS_NO_MASK_BLOCK
=
std
::
is_same_v
<
decltype
(
is_mask_block_t
),
IsNoMaskBlock
>
;
Tensor
acc_s
=
partition_fragment_C
(
tiled_mma
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
Tensor
acc_s
=
partition_fragment_C
(
tiled_mma
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
Tensor
tOrVt_copy_view
=
smem_thr_copy_V
.
retile_D
(
tOrVt
);
Tensor
tOrVt_copy_view
=
smem_thr_copy_V
.
retile_D
(
tOrVt
);
clear
(
acc_s
);
clear
(
acc_s
);
// asm volatile("s_barrier\n\t");
// asm volatile("s_barrier\n\t");
Tensor
tSrK_copy_view
=
smem_thr_copy_K
.
retile_D
(
tSrK
);
Tensor
tSrK_copy_view
=
smem_thr_copy_K
.
retile_D
(
tSrK
);
int
cur_block_table
;
int
cur_block_table
;
const
int
*
cur_block_table_ptr
=
block_table
+
n_
block
;
const
int
*
cur_block_table_ptr
=
block_table
+
block
_idx
;
// cur_block_table = block_table[
n_
block - 1];
// cur_block_table = block_table[block
_idx
- 1];
asm
volatile
(
"s_load_dword %1, %0, 0x0
\n\t
"
asm
volatile
(
"s_load_dword %1, %0, 0x0
\n\t
"
"s_waitcnt lgkmcnt(0)
\n\t
"
:
"s_waitcnt lgkmcnt(0)
\n\t
"
:
"+s"
(
cur_block_table_ptr
),
"+s"
(
cur_block_table_ptr
),
...
@@ -889,17 +899,17 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
...
@@ -889,17 +899,17 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
index_t
offset_k
=
cur_block_table
*
params
.
k_batch_stride
;
index_t
offset_k
=
cur_block_table
*
params
.
k_batch_stride
;
gK
.
data
()
=
gK
.
data
()
+
(
offset_k
);
gK
.
data
()
=
gK
.
data
()
+
(
offset_k
);
#if 1
#if 1
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
0
,
params
.
k_row_stride
,
seqlen_k
-
n_
block
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
0
,
params
.
k_row_stride
,
seqlen_k
-
block
_idx
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
1
,
params
.
k_row_stride
,
seqlen_k
-
n_
block
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
1
,
params
.
k_row_stride
,
seqlen_k
-
block
_idx
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
2
,
params
.
k_row_stride
,
seqlen_k
-
n_
block
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
2
,
params
.
k_row_stride
,
seqlen_k
-
block
_idx
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
3
,
params
.
k_row_stride
,
seqlen_k
-
n_
block
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
3
,
params
.
k_row_stride
,
seqlen_k
-
block
_idx
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
4
,
params
.
k_row_stride
,
seqlen_k
-
n_
block
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
4
,
params
.
k_row_stride
,
seqlen_k
-
block
_idx
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
5
,
params
.
k_row_stride
,
seqlen_k
-
n_
block
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
5
,
params
.
k_row_stride
,
seqlen_k
-
block
_idx
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
6
,
params
.
k_row_stride
,
seqlen_k
-
n_
block
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
6
,
params
.
k_row_stride
,
seqlen_k
-
block
_idx
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
7
,
params
.
k_row_stride
,
seqlen_k
-
n_
block
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
7
,
params
.
k_row_stride
,
seqlen_k
-
block
_idx
*
kBlockN
);
constexpr
static
int
BUFFER_SIZE
=
1
;
constexpr
static
int
BUFFER_SIZE
=
1
;
uint128_t
buffer
[
BUFFER_SIZE
];
uint128_t
buffer
[
BUFFER_SIZE
];
buffer_load_copy_qkvfp8
<
false
,
true
,
true
,
true
>
(
gK
,
buffer
[
0
],
8
,
params
.
k_row_stride
,
offset_k
,
seqlen_k
-
n_
block
*
kBlockN
);
buffer_load_copy_qkvfp8
<
false
,
true
,
true
,
true
>
(
gK
,
buffer
[
0
],
8
,
params
.
k_row_stride
,
offset_k
,
seqlen_k
-
block
_idx
*
kBlockN
);
asm
volatile
(
"s_waitcnt vmcnt(8)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(8)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
0
),
tSrK_copy_view
(
_
,
_
,
0
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
0
),
tSrK_copy_view
(
_
,
_
,
0
));
...
@@ -937,21 +947,23 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
...
@@ -937,21 +947,23 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
#else
#else
#endif
#endif
gK
.
data
()
=
gK
.
data
()
+
(
-
offset_k
);
gK
.
data
()
=
gK
.
data
()
+
(
-
offset_k
);
if
constexpr
(
!
IS_NO_MASK_BLOCK
)
{
Tensor
cS
=
make_identity_tensor
(
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
Tensor
cS
=
make_identity_tensor
(
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
Tensor
tScS
=
thr_mma
.
partition_C
(
cS
);
Tensor
tScS
=
thr_mma
.
partition_C
(
cS
);
for
(
int
i
=
0
;
i
<
size
(
acc_s
);
++
i
)
{
for
(
int
i
=
0
;
i
<
size
(
acc_s
);
++
i
)
{
if
constexpr
(
!
Is_causal
)
{
if
constexpr
(
!
Is_causal
)
{
if
(
int
(
get
<
1
>
(
tScS
(
i
)))
>=
int
(
seqlen_k
-
n_block
*
kBlockN
))
acc_s
(
i
)
=
-
INFINITY
;
if
(
int
(
get
<
1
>
(
tScS
(
i
)))
>=
int
(
seqlen_k
-
block_idx
*
kBlockN
))
acc_s
(
i
)
=
-
INFINITY
;
}
else
{
}
else
{
// Ensure seqlen_k - 1 - (n_block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
// Ensure seqlen_k - 1 - (block_idx * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
// col <= seqlen_k - 1 - n_block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
// col <= seqlen_k - 1 - block_idx * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
int
row
=
int
(
get
<
0
>
(
tScS
(
i
)));
int
row
=
int
(
get
<
0
>
(
tScS
(
i
)));
int
col_limit_right
=
seqlen_k
-
1
-
n_block
*
kBlockN
-
(
params
.
seqlen_q
-
1
-
(
m_block
*
kBlockM
+
row
))
/
params
.
ngroups
;
int
col_limit_right
=
seqlen_k
-
1
-
block_idx
*
kBlockN
-
(
params
.
seqlen_q
-
1
-
(
m_block
*
kBlockM
+
row
))
/
params
.
ngroups
;
if
(
int
(
get
<
1
>
(
tScS
(
i
)))
>
col_limit_right
)
acc_s
(
i
)
=
-
INFINITY
;
if
(
int
(
get
<
1
>
(
tScS
(
i
)))
>
col_limit_right
)
acc_s
(
i
)
=
-
INFINITY
;
}
}
}
}
}
// We have key_padding_mask so we'll need to Check_inf
// We have key_padding_mask so we'll need to Check_inf
// if constexpr (n_masking_steps == 1)
// if constexpr (n_masking_steps == 1)
// {
// {
...
@@ -959,10 +971,8 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
...
@@ -959,10 +971,8 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
// }
// }
// else
// else
{
{
const
bool
is_first_masking_step
=
masking_step
==
0
;
is_first_masking_step
softmax
.
template
softmax_rescale_o_fp8
<
/*Is_first=*/
IS_FIRST_MASK_BLOCK
,
/*Check_inf=*/
Is_causal
>(
acc_s
,
sRow_max_reduce_buffer
,
scale_softmax_log2
,
c0_0
,
c0_1
,
c1_0
,
c1_1
,
c2_0
,
c2_1
,
c3_0
,
c3_1
);
?
softmax
.
template
softmax_rescale_o_fp8
<
/*Is_first=*/
true
,
/*Check_inf=*/
Is_causal
>(
acc_s
,
sRow_max_reduce_buffer
,
scale_softmax_log2
,
c0_0
,
c0_1
,
c1_0
,
c1_1
,
c2_0
,
c2_1
,
c3_0
,
c3_1
)
:
softmax
.
template
softmax_rescale_o_fp8
<
/*Is_first=*/
false
,
/*Check_inf=*/
Is_causal
>(
acc_s
,
sRow_max_reduce_buffer
,
scale_softmax_log2
,
c0_0
,
c0_1
,
c1_0
,
c1_1
,
c2_0
,
c2_1
,
c3_0
,
c3_1
);
}
}
...
@@ -1025,7 +1035,24 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
...
@@ -1025,7 +1035,24 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
};
if
constexpr
(
n_masking_steps
==
1
)
{
process_one_block
(
n_block
,
IsFirstMaskBlock
{});
n_block
--
;
}
else
{
int
masking_step
=
1
;
process_one_block
(
n_block
,
IsFirstMaskBlock
{});
n_block
--
;
for
(;
n_block
>=
n_block_min
&&
masking_step
<
n_masking_steps
;
++
masking_step
,
--
n_block
)
{
process_one_block
(
n_block
,
IsMaskBlock
{});
}
}
for
(;
n_block
>=
n_block_min
;
--
n_block
)
{
process_one_block
(
n_block
,
IsNoMaskBlock
{});
}
}
#endif
#endif
Tensor
acc_o
=
partition_fragment_C
(
tiled_mma_o
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDimV
>>
{});
Tensor
acc_o
=
partition_fragment_C
(
tiled_mma_o
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDimV
>>
{});
acc_o
(
0
,
0
,
0
)
=
c0_0
.
x
;
acc_o
(
1
,
0
,
0
)
=
c0_0
.
y
;
acc_o
(
2
,
0
,
0
)
=
c0_0
.
z
;
acc_o
(
3
,
0
,
0
)
=
c0_0
.
w
;
acc_o
(
0
,
0
,
0
)
=
c0_0
.
x
;
acc_o
(
1
,
0
,
0
)
=
c0_0
.
y
;
acc_o
(
2
,
0
,
0
)
=
c0_0
.
z
;
acc_o
(
3
,
0
,
0
)
=
c0_0
.
w
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment