Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
79c06a56
Commit
79c06a56
authored
Mar 07, 2026
by
zhanghj2
Browse files
优化nmz buffer load提升fp8性能
parent
276a1fb7
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
56 additions
and
53 deletions
+56
-53
csrc/extension/utils.h
csrc/extension/utils.h
+56
-53
No files found.
csrc/extension/utils.h
View file @
79c06a56
...
@@ -591,11 +591,11 @@ buffer_load_copy_qkvfp8(
...
@@ -591,11 +591,11 @@ buffer_load_copy_qkvfp8(
PtrWrapper
glob_ptr
;
PtrWrapper
glob_ptr
;
*
(
uint64_t
*
)
&
glob_ptr
=
reinterpret_cast
<
uint64_t
>
(
src
.
data
().
get
());
*
(
uint64_t
*
)
&
glob_ptr
=
reinterpret_cast
<
uint64_t
>
(
src
.
data
().
get
());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
glob_ptr
.
latter
|=
((
row_stride
)
<<
16
);
uint32x4_t
global_addr
=
{
0
};
uint32x4_t
global_addr
=
{
0
};
global_addr
[
0
]
=
__builtin_amdgcn_readfirstlane
(
glob_ptr
.
former
);
global_addr
[
0
]
=
__builtin_amdgcn_readfirstlane
(
glob_ptr
.
former
);
global_addr
[
1
]
=
__builtin_amdgcn_readfirstlane
(
glob_ptr
.
latter
);
global_addr
[
1
]
=
__builtin_amdgcn_readfirstlane
(
glob_ptr
.
latter
);
global_addr
[
2
]
=
0x80000000
;
global_addr
[
2
]
=
!
Is_even_MN
?
max_MN
:
0x80000000
;
global_addr
[
3
]
=
0x00020000
;
global_addr
[
3
]
=
0x00020000
;
int
mma_k
=
32
*
64
;
int
mma_k
=
32
*
64
;
...
@@ -603,19 +603,20 @@ buffer_load_copy_qkvfp8(
...
@@ -603,19 +603,20 @@ buffer_load_copy_qkvfp8(
int
col
=
lane
/
16
;
int
col
=
lane
/
16
;
int
row_offset
=
row
+
(
warp_id
*
16
)
;
int
row_offset
=
row
+
(
warp_id
*
16
)
;
int
col_offset
=
col
*
elements_per_thread
+
k_idx
*
64
;
int
col_offset
=
col
*
elements_per_thread
+
k_idx
*
64
;
int
offset_v
=
(
row_offset
*
row_stride
+
col_offset
)
*
element_size
;
// bytes
// int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if
(
!
Is_even_MN
&&
row_offset
>=
max_MN
)
offset_v
=
-
1
;
// if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
// uint32x2_t index_offset = {0};
// index_offset[0] = row_offset;
// index_offset[1] = col_offset;
if
constexpr
(
use_asm
)
{
if
constexpr
(
use_asm
)
{
asm
volatile
(
//
asm volatile(
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0
\n
"
//
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0 \n"
"
\n\t
"
:
"=v"
(
dst
),
//
" \n\t" :"=v"(dst),
"+v"
(
offset_v
),
"+s"
(
global_addr
)
//
"+v"(offset_v), "+s"(global_addr)
);
//
);
}
}
else
{
else
{
dst
=
__builtin_amdgcn_buffer_load_dwordx4
(
global_addr
,
0
,
offset
_v
,
false
,
false
);
dst
=
__builtin_amdgcn_buffer_load_dwordx4
(
global_addr
,
row_offset
,
col_
offset
,
false
,
false
);
}
}
}
else
{
}
else
{
...
@@ -634,11 +635,13 @@ buffer_load_copy_qkvfp8(
...
@@ -634,11 +635,13 @@ buffer_load_copy_qkvfp8(
PtrWrapper
glob_ptr
;
PtrWrapper
glob_ptr
;
*
(
uint64_t
*
)
&
glob_ptr
=
reinterpret_cast
<
uint64_t
>
(
src
.
data
().
get
());
*
(
uint64_t
*
)
&
glob_ptr
=
reinterpret_cast
<
uint64_t
>
(
src
.
data
().
get
());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
glob_ptr
.
latter
|=
((
row_stride
)
<<
16
);
constexpr
int
elements_per_thread
=
16
;
constexpr
int
elements_per_thread
=
16
;
uint32x4_t
global_addr
=
{
0
};
uint32x4_t
global_addr
=
{
0
};
global_addr
[
0
]
=
__builtin_amdgcn_readfirstlane
(
glob_ptr
.
former
);
global_addr
[
0
]
=
__builtin_amdgcn_readfirstlane
(
glob_ptr
.
former
);
global_addr
[
1
]
=
__builtin_amdgcn_readfirstlane
(
glob_ptr
.
latter
);
global_addr
[
1
]
=
__builtin_amdgcn_readfirstlane
(
glob_ptr
.
latter
);
global_addr
[
2
]
=
0x80000000
;
global_addr
[
2
]
=
!
Is_even_MN
?
max_MN
:
0x80000000
;
global_addr
[
3
]
=
0x00020000
;
global_addr
[
3
]
=
0x00020000
;
int
mma_k
=
32
*
64
;
int
mma_k
=
32
*
64
;
...
@@ -646,10 +649,10 @@ buffer_load_copy_qkvfp8(
...
@@ -646,10 +649,10 @@ buffer_load_copy_qkvfp8(
int
col
=
lane
%
4
;
int
col
=
lane
%
4
;
int
row_offset
=
row
+
(
warp_id
*
16
);
int
row_offset
=
row
+
(
warp_id
*
16
);
int
col_offset
=
col
*
elements_per_thread
+
k_idx
*
64
;
int
col_offset
=
col
*
elements_per_thread
+
k_idx
*
64
;
int
offset_v
=
(
row_offset
*
row_stride
+
col_offset
)
*
element_size
;
// bytes
//
int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
if
(
!
Is_even_MN
&&
row_offset
>=
max_MN
)
offset_v
=
-
1
;
//
if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;
dst
=
__builtin_amdgcn_buffer_load_dwordx4
(
global_addr
,
0
,
offset
_v
,
false
,
false
);
dst
=
__builtin_amdgcn_buffer_load_dwordx4
(
global_addr
,
row_offset
,
col_
offset
,
false
,
false
);
}
}
}
}
...
@@ -737,48 +740,48 @@ lds_direct_copy_qkvfp8(
...
@@ -737,48 +740,48 @@ lds_direct_copy_qkvfp8(
if
constexpr
(
Is_load_Q
)
{
if
constexpr
(
Is_load_Q
)
{
constexpr
int
warp_size
=
64
;
constexpr
int
warp_size
=
64
;
int
tidx
=
threadIdx
.
x
;
int
tidx
=
threadIdx
.
x
;
int
warp_id
=
__builtin_amdgcn_readfirstlane
(
tidx
/
warp_size
);
int
warp_id
=
__builtin_amdgcn_readfirstlane
(
tidx
/
warp_size
);
int
lane
=
tidx
%
warp_size
;
int
lane
=
tidx
%
warp_size
;
constexpr
int
element_size
=
1
;
constexpr
int
element_size
=
1
;
int
k_idx
=
__builtin_amdgcn_readfirstlane
(
k_idx_
);
int
k_idx
=
__builtin_amdgcn_readfirstlane
(
k_idx_
);
const
int
offset_s
=
0
;
const
int
offset_s
=
0
;
struct
PtrWrapper
{
struct
PtrWrapper
{
uint32_t
former
;
uint32_t
former
;
uint32_t
latter
;
uint32_t
latter
;
};
};
PtrWrapper
glob_ptr
;
PtrWrapper
glob_ptr
;
*
(
uint64_t
*
)
&
glob_ptr
=
reinterpret_cast
<
uint64_t
>
(
src
.
data
().
get
());
*
(
uint64_t
*
)
&
glob_ptr
=
reinterpret_cast
<
uint64_t
>
(
src
.
data
().
get
());
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
// glob_ptr.latter |= 0x40000000; // 62 bit: cache swizzle; 48~61: Stride
uint32x4_t
global_addr
=
{
0
};
global_addr
[
0
]
=
(
glob_ptr
.
former
);
global_addr
[
1
]
=
(
glob_ptr
.
latter
);
global_addr
[
2
]
=
0x80000000
;
global_addr
[
3
]
=
0x00020000
;
constexpr
int
elements_per_thread
=
16
;
constexpr
int
bytes_per_warp
=
warp_size
*
elements_per_thread
*
element_size
;
int
mma_k
=
16
*
256
;
uint32x4_t
global_addr
=
{
0
};
int
row
=
lane
%
16
;
global_addr
[
0
]
=
(
glob_ptr
.
former
);
int
col
=
lane
/
16
;
global_addr
[
1
]
=
(
glob_ptr
.
latter
);
global_addr
[
2
]
=
0x80000000
;
int
row_offset
=
row
;
global_addr
[
3
]
=
0x00020000
;
int
col_offset
=
(
col
+
warp_id
*
4
)
*
elements_per_thread
+
k_idx
*
256
;
int
offset_v
=
(
row_offset
*
row_stride
+
col_offset
)
*
element_size
;
// bytes
if
(
!
Is_even_MN
&&
row_offset
>=
max_MN
)
offset_v
=
-
1
;
constexpr
int
elements_per_thread
=
16
;
constexpr
int
bytes_per_warp
=
warp_size
*
elements_per_thread
*
element_size
;
int
mma_k
=
16
*
256
;
int
row
=
lane
%
16
;
int
col
=
lane
/
16
;
int
row_offset
=
row
;
int
col_offset
=
(
col
+
warp_id
*
4
)
*
elements_per_thread
+
k_idx
*
256
;
int
offset_v
=
(
row_offset
*
row_stride
+
col_offset
)
*
element_size
;
// bytes
if
(
!
Is_even_MN
&&
row_offset
>=
max_MN
)
offset_v
=
-
1
;
if
(
!
Is_even_K
&&
col_offset
>=
576
)
offset_v
=
-
1
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_idx
*
mma_k
*
element_size
;
if
(
!
Is_even_K
&&
col_offset
>=
576
)
offset_v
=
-
1
;
asm
volatile
(
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_idx
*
mma_k
*
element_size
;
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
asm
volatile
(
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
"s_mov_b32 m0, %1
\n\t
"
:
);
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment