Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
40f4bf39
Commit
40f4bf39
authored
Mar 12, 2026
by
zhanghj2
Browse files
优化nmz fp8 tp1
parent
5e577dee
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
11 additions
and
8 deletions
+11
-8
csrc/extension/flash_fwd_mla_kernel_fp8.h
csrc/extension/flash_fwd_mla_kernel_fp8.h
+4
-2
csrc/extension/utils.h
csrc/extension/utils.h
+7
-6
No files found.
csrc/extension/flash_fwd_mla_kernel_fp8.h
View file @
40f4bf39
...
...
@@ -1555,7 +1555,8 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
buffer_load_copy_fp8_tp1
<
true
,
true
,
1
>
(
gK
,
kv_data
[
1
].
data
,
params
.
k_row_stride
,
seqlen_k
-
block_idx
*
kBlockN
);
buffer_load_copy_fp8_tp1
<
true
,
true
,
2
>
(
gK
,
kv_data
[
2
].
data
,
params
.
k_row_stride
,
seqlen_k
-
block_idx
*
kBlockN
);
buffer_load_copy_fp8_tp1
<
true
,
true
,
3
>
(
gK
,
kv_data
[
3
].
data
,
params
.
k_row_stride
,
seqlen_k
-
block_idx
*
kBlockN
);
buffer_load_copy_fp8_tp1
<
true
,
false
,
4
>
(
gK
,
kv_data
[
4
].
data
,
params
.
k_row_stride
,
seqlen_k
-
block_idx
*
kBlockN
);
if
(
warp_id
<
4
)
buffer_load_copy_fp8_tp1
<
true
,
false
,
4
>
(
gK
,
kv_data
[
4
].
data
,
params
.
k_row_stride
,
seqlen_k
-
block_idx
*
kBlockN
);
}
for
(
int
n
=
0
;
n
<
4
;
n
++
)
...
...
@@ -1598,7 +1599,8 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
kv_lds_write_ptr
+=
64
*
128
;
*
(
reinterpret_cast
<
intx4_t
*>
(
kv_lds_write_ptr
))
=
kv_data
[
3
].
data
;
kv_lds_write_ptr
+=
64
*
128
;
*
(
reinterpret_cast
<
intx4_t
*>
(
kv_lds_write_ptr
))
=
kv_data
[
4
].
data
;
if
(
warp_id
<
4
)
*
(
reinterpret_cast
<
intx4_t
*>
(
kv_lds_write_ptr
))
=
kv_data
[
4
].
data
;
}
// asm volatile("s_barrier \n\t");
};
...
...
csrc/extension/utils.h
View file @
40f4bf39
...
...
@@ -2809,11 +2809,12 @@ buffer_load_copy_fp8_tp1(
PtrWrapper
glob_ptr
;
*
(
uint64_t
*
)
&
glob_ptr
=
reinterpret_cast
<
uint64_t
>
(
src
.
data
().
get
());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
glob_ptr
.
latter
|=
((
row_stride
)
<<
16
);
uint32x4_t
global_addr
=
{
0
};
global_addr
[
0
]
=
__builtin_amdgcn_readfirstlane
(
glob_ptr
.
former
);
global_addr
[
1
]
=
__builtin_amdgcn_readfirstlane
(
glob_ptr
.
latter
);
global_addr
[
2
]
=
0x80000000
;
global_addr
[
2
]
=
!
Is_even_MN
?
max_MN
:
0x80000000
;
global_addr
[
3
]
=
0x00020000
;
int
mma_k
=
32
*
64
;
...
...
@@ -2821,12 +2822,12 @@ buffer_load_copy_fp8_tp1(
int
col
=
lane
%
4
;
int
row_offset
=
row
+
((
warp_id
%
4
)
*
16
)
;
int
col_offset
=
col
*
elements_per_thread
+
k_idx
*
128
+
(
warp_id
/
4
)
*
64
;
int
offset_v
=
(
row_offset
*
row_stride
+
col_offset
)
*
element_size
;
// bytes
if
(
!
Is_even_K
&&
col_offset
>=
576
)
offset_v
=
-
1
;
if
(
!
Is_even_MN
&&
row_offset
>=
max_MN
)
offset_v
=
-
1
;
//
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
//
if (!Is_even_K && col_offset >=576) offset_v = -1;
//
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
{
dst
=
__builtin_amdgcn_buffer_load_dwordx4
(
global_addr
,
0
,
offset
_v
,
false
,
false
);
dst
=
__builtin_amdgcn_buffer_load_dwordx4
(
global_addr
,
row_offset
,
col_
offset
,
false
,
false
);
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment