Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
2ff340aa
"wrappers/python/setup.py" did not exist on "abb19052327c3fbf8eb9bd41be1d027d30314d07"
Commit
2ff340aa
authored
Feb 28, 2026
by
zhanghj2
Browse files
优化tp8 nmz 代码
parent
4d897ed1
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
56 additions
and
29 deletions
+56
-29
csrc/extension/flash_fwd_mla_kernel_fp8.h
csrc/extension/flash_fwd_mla_kernel_fp8.h
+56
-29
No files found.
csrc/extension/flash_fwd_mla_kernel_fp8.h
View file @
2ff340aa
...
...
@@ -547,6 +547,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
const
int
n_split_idx
,
const
int
seqlen_k
,
const
int
n_block_min
,
const
int
n_block_max
,
const
bool
NoSplit
,
SharedStorage
&
shared_storage
,
const
float
descale_k
,
const
float
scale_softmax
,
const
float
scale_softmax_log2
)
{
if
(
n_block_max
<=
n_block_min
)
{
return
;
}
constexpr
int
kBlockM
=
Kernel_traits
::
kBlockM
;
constexpr
int
kBlockN
=
Kernel_traits
::
kBlockN
;
constexpr
int
kHeadDim
=
Kernel_traits
::
kHeadDim
;
...
...
@@ -872,16 +875,23 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
c3_0
.
x
=
0.0
f
;
c3_0
.
y
=
0.0
f
;
c3_0
.
z
=
0.0
f
;
c3_0
.
w
=
0.0
f
;
c3_1
.
x
=
0.0
f
;
c3_1
.
y
=
0.0
f
;
c3_1
.
z
=
0.0
f
;
c3_1
.
w
=
0.0
f
;
// #pragma unroll
for
(
int
masking_step
=
0
;
n_block
>=
n_block_min
;
++
masking_step
,
--
n_block
)
{
extern
__shared__
char
shared_memory
[];
struct
IsMaskBlock
{};
struct
IsFirstMaskBlock
{};
struct
IsNoMaskBlock
{};
auto
process_one_block
=
[
&
]
(
int
block_idx
,
auto
is_mask_block_t
)
{
static
constexpr
bool
IS_MASK_BLOCK
=
std
::
is_same_v
<
decltype
(
is_mask_block_t
),
IsNoMaskBlock
>
;
static
constexpr
bool
IS_FIRST_MASK_BLOCK
=
std
::
is_same_v
<
decltype
(
is_mask_block_t
),
IsFirstMaskBlock
>
;
static
constexpr
bool
IS_NO_MASK_BLOCK
=
std
::
is_same_v
<
decltype
(
is_mask_block_t
),
IsNoMaskBlock
>
;
Tensor
acc_s
=
partition_fragment_C
(
tiled_mma
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
Tensor
tOrVt_copy_view
=
smem_thr_copy_V
.
retile_D
(
tOrVt
);
clear
(
acc_s
);
// asm volatile("s_barrier\n\t");
Tensor
tSrK_copy_view
=
smem_thr_copy_K
.
retile_D
(
tSrK
);
int
cur_block_table
;
const
int
*
cur_block_table_ptr
=
block_table
+
n_
block
;
// cur_block_table = block_table[
n_
block - 1];
const
int
*
cur_block_table_ptr
=
block_table
+
block
_idx
;
// cur_block_table = block_table[block
_idx
- 1];
asm
volatile
(
"s_load_dword %1, %0, 0x0
\n\t
"
"s_waitcnt lgkmcnt(0)
\n\t
"
:
"+s"
(
cur_block_table_ptr
),
...
...
@@ -889,17 +899,17 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
index_t
offset_k
=
cur_block_table
*
params
.
k_batch_stride
;
gK
.
data
()
=
gK
.
data
()
+
(
offset_k
);
#if 1
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
0
,
params
.
k_row_stride
,
seqlen_k
-
n_
block
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
1
,
params
.
k_row_stride
,
seqlen_k
-
n_
block
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
2
,
params
.
k_row_stride
,
seqlen_k
-
n_
block
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
3
,
params
.
k_row_stride
,
seqlen_k
-
n_
block
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
4
,
params
.
k_row_stride
,
seqlen_k
-
n_
block
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
5
,
params
.
k_row_stride
,
seqlen_k
-
n_
block
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
6
,
params
.
k_row_stride
,
seqlen_k
-
n_
block
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
7
,
params
.
k_row_stride
,
seqlen_k
-
n_
block
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
0
,
params
.
k_row_stride
,
seqlen_k
-
block
_idx
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
1
,
params
.
k_row_stride
,
seqlen_k
-
block
_idx
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
2
,
params
.
k_row_stride
,
seqlen_k
-
block
_idx
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
3
,
params
.
k_row_stride
,
seqlen_k
-
block
_idx
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
4
,
params
.
k_row_stride
,
seqlen_k
-
block
_idx
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
5
,
params
.
k_row_stride
,
seqlen_k
-
block
_idx
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
6
,
params
.
k_row_stride
,
seqlen_k
-
block
_idx
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
7
,
params
.
k_row_stride
,
seqlen_k
-
block
_idx
*
kBlockN
);
constexpr
static
int
BUFFER_SIZE
=
1
;
uint128_t
buffer
[
BUFFER_SIZE
];
buffer_load_copy_qkvfp8
<
false
,
true
,
true
,
true
>
(
gK
,
buffer
[
0
],
8
,
params
.
k_row_stride
,
offset_k
,
seqlen_k
-
n_
block
*
kBlockN
);
buffer_load_copy_qkvfp8
<
false
,
true
,
true
,
true
>
(
gK
,
buffer
[
0
],
8
,
params
.
k_row_stride
,
offset_k
,
seqlen_k
-
block
_idx
*
kBlockN
);
asm
volatile
(
"s_waitcnt vmcnt(8)
\n\t
s_barrier
\n\t
"
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
0
),
tSrK_copy_view
(
_
,
_
,
0
));
...
...
@@ -937,20 +947,22 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
#else
#endif
gK
.
data
()
=
gK
.
data
()
+
(
-
offset_k
);
if
constexpr
(
!
IS_NO_MASK_BLOCK
)
{
Tensor
cS
=
make_identity_tensor
(
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
Tensor
tScS
=
thr_mma
.
partition_C
(
cS
);
for
(
int
i
=
0
;
i
<
size
(
acc_s
);
++
i
)
{
if
constexpr
(
!
Is_causal
)
{
if
(
int
(
get
<
1
>
(
tScS
(
i
)))
>=
int
(
seqlen_k
-
n_
block
*
kBlockN
))
acc_s
(
i
)
=
-
INFINITY
;
if
(
int
(
get
<
1
>
(
tScS
(
i
)))
>=
int
(
seqlen_k
-
block
_idx
*
kBlockN
))
acc_s
(
i
)
=
-
INFINITY
;
}
else
{
// Ensure seqlen_k - 1 - (
n_
block * kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
// col <= seqlen_k - 1 -
n_
block * kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
// Ensure seqlen_k - 1 - (block
_idx
* kBlockN + col) >= (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
// col <= seqlen_k - 1 - block
_idx
* kBlockN - (seqlen_q - 1 - (m_block * kBlockM + row)) / ngroups
int
row
=
int
(
get
<
0
>
(
tScS
(
i
)));
int
col_limit_right
=
seqlen_k
-
1
-
n_
block
*
kBlockN
-
(
params
.
seqlen_q
-
1
-
(
m_block
*
kBlockM
+
row
))
/
params
.
ngroups
;
int
col_limit_right
=
seqlen_k
-
1
-
block
_idx
*
kBlockN
-
(
params
.
seqlen_q
-
1
-
(
m_block
*
kBlockM
+
row
))
/
params
.
ngroups
;
if
(
int
(
get
<
1
>
(
tScS
(
i
)))
>
col_limit_right
)
acc_s
(
i
)
=
-
INFINITY
;
}
}
}
// We have key_padding_mask so we'll need to Check_inf
// if constexpr (n_masking_steps == 1)
...
...
@@ -959,10 +971,8 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
// }
// else
{
const
bool
is_first_masking_step
=
masking_step
==
0
;
is_first_masking_step
?
softmax
.
template
softmax_rescale_o_fp8
<
/*Is_first=*/
true
,
/*Check_inf=*/
Is_causal
>(
acc_s
,
sRow_max_reduce_buffer
,
scale_softmax_log2
,
c0_0
,
c0_1
,
c1_0
,
c1_1
,
c2_0
,
c2_1
,
c3_0
,
c3_1
)
:
softmax
.
template
softmax_rescale_o_fp8
<
/*Is_first=*/
false
,
/*Check_inf=*/
Is_causal
>(
acc_s
,
sRow_max_reduce_buffer
,
scale_softmax_log2
,
c0_0
,
c0_1
,
c1_0
,
c1_1
,
c2_0
,
c2_1
,
c3_0
,
c3_1
);
softmax
.
template
softmax_rescale_o_fp8
<
/*Is_first=*/
IS_FIRST_MASK_BLOCK
,
/*Check_inf=*/
Is_causal
>(
acc_s
,
sRow_max_reduce_buffer
,
scale_softmax_log2
,
c0_0
,
c0_1
,
c1_0
,
c1_1
,
c2_0
,
c2_1
,
c3_0
,
c3_1
);
}
...
...
@@ -1025,7 +1035,24 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
__builtin_amdgcn_sched_barrier
(
0
);
}
};
if
constexpr
(
n_masking_steps
==
1
)
{
process_one_block
(
n_block
,
IsFirstMaskBlock
{});
n_block
--
;
}
else
{
int
masking_step
=
1
;
process_one_block
(
n_block
,
IsFirstMaskBlock
{});
n_block
--
;
for
(;
n_block
>=
n_block_min
&&
masking_step
<
n_masking_steps
;
++
masking_step
,
--
n_block
)
{
process_one_block
(
n_block
,
IsMaskBlock
{});
}
}
for
(;
n_block
>=
n_block_min
;
--
n_block
)
{
process_one_block
(
n_block
,
IsNoMaskBlock
{});
}
#endif
Tensor
acc_o
=
partition_fragment_C
(
tiled_mma_o
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDimV
>>
{});
acc_o
(
0
,
0
,
0
)
=
c0_0
.
x
;
acc_o
(
1
,
0
,
0
)
=
c0_0
.
y
;
acc_o
(
2
,
0
,
0
)
=
c0_0
.
z
;
acc_o
(
3
,
0
,
0
)
=
c0_0
.
w
;
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment