Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
79c06a56
Commit
79c06a56
authored
Mar 07, 2026
by
zhanghj2
Browse files
优化nmz buffer load提升fp8性能
parent
276a1fb7
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
56 additions
and
53 deletions
+56
-53
csrc/extension/utils.h
csrc/extension/utils.h
+56
-53
No files found.
csrc/extension/utils.h
View file @
79c06a56
...
...
@@ -591,11 +591,11 @@ buffer_load_copy_qkvfp8(
PtrWrapper
glob_ptr
;
*
(
uint64_t
*
)
&
glob_ptr
=
reinterpret_cast
<
uint64_t
>
(
src
.
data
().
get
());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
glob_ptr
.
latter
|=
((
row_stride
)
<<
16
);
uint32x4_t
global_addr
=
{
0
};
global_addr
[
0
]
=
__builtin_amdgcn_readfirstlane
(
glob_ptr
.
former
);
global_addr
[
1
]
=
__builtin_amdgcn_readfirstlane
(
glob_ptr
.
latter
);
global_addr
[
2
]
=
0x80000000
;
global_addr
[
2
]
=
!
Is_even_MN
?
max_MN
:
0x80000000
;
global_addr
[
3
]
=
0x00020000
;
int
mma_k
=
32
*
64
;
...
...
@@ -603,19 +603,20 @@ buffer_load_copy_qkvfp8(
int
col
=
lane
/
16
;
int
row_offset
=
row
+
(
warp_id
*
16
)
;
int
col_offset
=
col
*
elements_per_thread
+
k_idx
*
64
;
int
offset_v
=
(
row_offset
*
row_stride
+
col_offset
)
*
element_size
;
// bytes
if
(
!
Is_even_MN
&&
row_offset
>=
max_MN
)
offset_v
=
-
1
;
// int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
// if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
// uint32x2_t index_offset = {0};
// index_offset[0] = row_offset;
// index_offset[1] = col_offset;
if
constexpr
(
use_asm
)
{
asm
volatile
(
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0
\n
"
"
\n\t
"
:
"=v"
(
dst
),
"+v"
(
offset_v
),
"+s"
(
global_addr
)
);
//
asm volatile(
//
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n"
//
" \n\t" :"=v"(dst),
//
"+v"(offset_v), "+s"(global_addr)
//
);
}
else
{
dst
=
__builtin_amdgcn_buffer_load_dwordx4
(
global_addr
,
0
,
offset
_v
,
false
,
false
);
dst
=
__builtin_amdgcn_buffer_load_dwordx4
(
global_addr
,
row_offset
,
col_
offset
,
false
,
false
);
}
}
else
{
...
...
@@ -634,11 +635,13 @@ buffer_load_copy_qkvfp8(
PtrWrapper
glob_ptr
;
*
(
uint64_t
*
)
&
glob_ptr
=
reinterpret_cast
<
uint64_t
>
(
src
.
data
().
get
());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
glob_ptr
.
latter
|=
((
row_stride
)
<<
16
);
constexpr
int
elements_per_thread
=
16
;
uint32x4_t
global_addr
=
{
0
};
global_addr
[
0
]
=
__builtin_amdgcn_readfirstlane
(
glob_ptr
.
former
);
global_addr
[
1
]
=
__builtin_amdgcn_readfirstlane
(
glob_ptr
.
latter
);
global_addr
[
2
]
=
0x80000000
;
global_addr
[
2
]
=
!
Is_even_MN
?
max_MN
:
0x80000000
;
global_addr
[
3
]
=
0x00020000
;
int
mma_k
=
32
*
64
;
...
...
@@ -646,10 +649,10 @@ buffer_load_copy_qkvfp8(
int
col
=
lane
%
4
;
int
row_offset
=
row
+
(
warp_id
*
16
);
int
col_offset
=
col
*
elements_per_thread
+
k_idx
*
64
;
int
offset_v
=
(
row_offset
*
row_stride
+
col_offset
)
*
element_size
;
// bytes
//
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if
(
!
Is_even_MN
&&
row_offset
>=
max_MN
)
offset_v
=
-
1
;
dst
=
__builtin_amdgcn_buffer_load_dwordx4
(
global_addr
,
0
,
offset
_v
,
false
,
false
);
//
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
dst
=
__builtin_amdgcn_buffer_load_dwordx4
(
global_addr
,
row_offset
,
col_
offset
,
false
,
false
);
}
}
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment