Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
9687a0a3
Commit
9687a0a3
authored
May 29, 2026
by
zhanghj2
Browse files
add __builtin_amdgcn_sched_barrier(0);
parent
142846b5
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
26 additions
and
8 deletions
+26
-8
csrc/utils.h
csrc/utils.h
+26
-8
No files found.
csrc/utils.h
View file @
9687a0a3
...
@@ -33,9 +33,11 @@
...
@@ -33,9 +33,11 @@
#define FLASH_DEVICE_ASSERT(cond) \
#define FLASH_DEVICE_ASSERT(cond) \
do { \
do { \
if (not (cond)) { \
if (not (cond)) {
__builtin_amdgcn_sched_barrier
(
0
);
\
printf
(
"Assertion failed (%s:%d): %s
\n
"
,
__FILE__
,
__LINE__
,
#
cond
);
\
printf
(
"Assertion failed (%s:%d): %s
\n
"
,
__FILE__
,
__LINE__
,
#
cond
);
\
asm volatile("s_trap 0 \n\t"); \
asm
volatile
(
"s_trap 0
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
\
}
\
}
\
}
while
(
0
)
}
while
(
0
)
...
@@ -477,14 +479,14 @@ lds_direct_copy(
...
@@ -477,14 +479,14 @@ lds_direct_copy(
if
(
!
Is_even_K
&&
col_offset
>=
576
)
offset_v
=
-
1
;
if
(
!
Is_even_K
&&
col_offset
>=
576
)
offset_v
=
-
1
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_idx
*
mma_k
*
element_size
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_idx
*
mma_k
*
element_size
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"s_mov_b32 m0, %1
\n\t
"
"s_nop 0
\n\t
"
"s_nop 0
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
:
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
...
@@ -541,13 +543,14 @@ lds_direct_copy(
...
@@ -541,13 +543,14 @@ lds_direct_copy(
int
offset_v
=
(
row_offset
*
row_stride
+
col_offset
)
*
element_size
;
// bytes
int
offset_v
=
(
row_offset
*
row_stride
+
col_offset
)
*
element_size
;
// bytes
if
(
!
Is_even_MN
&&
row_offset
>=
max_MN
)
offset_v
=
-
1
;
if
(
!
Is_even_MN
&&
row_offset
>=
max_MN
)
offset_v
=
-
1
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_idx
*
mma_k
*
element_size
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_idx
*
mma_k
*
element_size
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"s_mov_b32 m0, %1
\n\t
"
"s_nop 0
\n\t
"
"s_nop 0
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
:
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
}
}
...
@@ -613,12 +616,14 @@ lds_direct_copy_for_prefill_sparse_mla(
...
@@ -613,12 +616,14 @@ lds_direct_copy_for_prefill_sparse_mla(
uint32x2_t
index_offset
=
{
0
};
uint32x2_t
index_offset
=
{
0
};
index_offset
[
0
]
=
row_offset
==
-
1
?
max_MN
:
row_offset
;
index_offset
[
0
]
=
row_offset
==
-
1
?
max_MN
:
row_offset
;
index_offset
[
1
]
=
offset_v
;
index_offset
[
1
]
=
offset_v
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"s_mov_b32 m0, %1
\n\t
"
"s_nop 0
\n\t
"
"s_nop 0
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 , idxen offen offset:0, lds
\n
"
::
"v"
(
index_offset
),
"buffer_load_dwordx4 %0, %2, %3 , idxen offen offset:0, lds
\n
"
::
"v"
(
index_offset
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
:
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
template
<
template
<
...
@@ -681,11 +686,13 @@ buffer_load_copy(
...
@@ -681,11 +686,13 @@ buffer_load_copy(
if
constexpr
(
use_asm
)
{
if
constexpr
(
use_asm
)
{
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
asm
volatile
(
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0
\n
"
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0
\n
"
"
\n\t
"
:
"=v"
(
dst
),
"
\n\t
"
:
"=v"
(
dst
),
"+v"
(
offset_v
),
"+s"
(
global_addr
)
"+v"
(
offset_v
),
"+s"
(
global_addr
)
);
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
else
{
else
{
auto
res
=
__builtin_amdgcn_buffer_load_dwordx4
(
global_addr
,
0
,
offset_v
,
false
,
false
);
auto
res
=
__builtin_amdgcn_buffer_load_dwordx4
(
global_addr
,
0
,
offset_v
,
false
,
false
);
...
@@ -712,11 +719,13 @@ buffer_load_copy(
...
@@ -712,11 +719,13 @@ buffer_load_copy(
if
(
!
Is_even_MN
&&
row_offset
>=
max_MN
)
offset_v
=
-
1
;
if
(
!
Is_even_MN
&&
row_offset
>=
max_MN
)
offset_v
=
-
1
;
if
constexpr
(
use_asm
)
{
if
constexpr
(
use_asm
)
{
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
asm
volatile
(
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0
\n
"
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0
\n
"
"
\n\t
"
:
"=v"
(
dst
),
"
\n\t
"
:
"=v"
(
dst
),
"+v"
(
offset_v
),
"+s"
(
global_addr
)
"+v"
(
offset_v
),
"+s"
(
global_addr
)
);
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
else
{
else
{
auto
res
=
__builtin_amdgcn_buffer_load_dwordx4
(
global_addr
,
0
,
offset_v
,
false
,
false
);
auto
res
=
__builtin_amdgcn_buffer_load_dwordx4
(
global_addr
,
0
,
offset_v
,
false
,
false
);
...
@@ -850,13 +859,14 @@ lds_direct_copy_qkvfp8(
...
@@ -850,13 +859,14 @@ lds_direct_copy_qkvfp8(
if
(
!
Is_even_K
&&
col_offset
>=
576
)
offset_v
=
-
1
;
if
(
!
Is_even_K
&&
col_offset
>=
576
)
offset_v
=
-
1
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_idx
*
mma_k
*
element_size
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_idx
*
mma_k
*
element_size
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"s_mov_b32 m0, %1
\n\t
"
"s_nop 0
\n\t
"
"s_nop 0
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
:
);
__builtin_amdgcn_sched_barrier
(
0
);
...
@@ -916,7 +926,7 @@ lds_direct_copy_qkvfp8(
...
@@ -916,7 +926,7 @@ lds_direct_copy_qkvfp8(
//int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + (k_idx % 2) * mma_k * element_size;
//int ldsAddrPerWave = reinterpret_cast<size_t>(dst.data().get()) + warp_id * bytes_per_warp + (k_idx % 2) * mma_k * element_size;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
(
k_idx
)
*
mma_k
*
element_size
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
(
k_idx
)
*
mma_k
*
element_size
;
__builtin_amdgcn_sched_barrier
(
0
);
#if defined(__gfx938__)
#if defined(__gfx938__)
asm
volatile
(
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"s_mov_b32 m0, %1
\n\t
"
...
@@ -925,6 +935,7 @@ lds_direct_copy_qkvfp8(
...
@@ -925,6 +935,7 @@ lds_direct_copy_qkvfp8(
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
:
);
#endif
#endif
__builtin_amdgcn_sched_barrier
(
0
);
}
}
}
}
...
@@ -978,11 +989,13 @@ buffer_load_copy_qkvfp8(
...
@@ -978,11 +989,13 @@ buffer_load_copy_qkvfp8(
if
constexpr
(
use_asm
)
{
if
constexpr
(
use_asm
)
{
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
asm
volatile
(
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0
\n
"
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0
\n
"
"
\n\t
"
:
"=v"
(
dst
),
"
\n\t
"
:
"=v"
(
dst
),
"+v"
(
offset_v
),
"+s"
(
global_addr
)
"+v"
(
offset_v
),
"+s"
(
global_addr
)
);
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
else
{
else
{
auto
res
=
__builtin_amdgcn_buffer_load_dwordx4
(
global_addr
,
0
,
offset_v
,
false
,
false
);
auto
res
=
__builtin_amdgcn_buffer_load_dwordx4
(
global_addr
,
0
,
offset_v
,
false
,
false
);
...
@@ -1074,7 +1087,7 @@ lds_direct_copy_fp8(
...
@@ -1074,7 +1087,7 @@ lds_direct_copy_fp8(
// {
// {
// printf("offset_v = %d %d \n", offset_v, warp_id * bytes_per_warp + k_idx * mma_k * element_size);
// printf("offset_v = %d %d \n", offset_v, warp_id * bytes_per_warp + k_idx * mma_k * element_size);
// }
// }
__builtin_amdgcn_sched_barrier
(
0
);
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
asm
volatile
(
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"s_mov_b32 m0, %1
\n\t
"
...
@@ -1082,6 +1095,7 @@ lds_direct_copy_fp8(
...
@@ -1082,6 +1095,7 @@ lds_direct_copy_fp8(
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
:
);
__builtin_amdgcn_sched_barrier
(
0
);
#endif
#endif
}
}
}
}
...
@@ -1137,9 +1151,11 @@ __forceinline__ __device__ cutlass::half_t fp8e5m2_to_fp16(const fp8& input) {
...
@@ -1137,9 +1151,11 @@ __forceinline__ __device__ cutlass::half_t fp8e5m2_to_fp16(const fp8& input) {
template
<
int
N
>
template
<
int
N
>
CUTE_HOST_DEVICE
CUTE_HOST_DEVICE
void
wait_vmcnt
()
{
void
wait_vmcnt
()
{
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(%0) ;
\n\t
"
asm
volatile
(
"s_waitcnt vmcnt(%0) ;
\n\t
"
"s_barrier;
\n\t
"
"s_barrier;
\n\t
"
::
"n"
(
N
));
::
"n"
(
N
));
__builtin_amdgcn_sched_barrier
(
0
);
}
}
#if 0
#if 0
template<typename Element, bool is_scale_equal_one, Fp8KVCacheDataType KV_DTYPE, typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3, typename Tensor4,
template<typename Element, bool is_scale_equal_one, Fp8KVCacheDataType KV_DTYPE, typename Tensor0, typename Tensor1, typename Tensor2, typename Tensor3, typename Tensor4,
...
@@ -1412,11 +1428,13 @@ buffer_load_copy_fp8(
...
@@ -1412,11 +1428,13 @@ buffer_load_copy_fp8(
if
constexpr
(
use_asm
)
{
if
constexpr
(
use_asm
)
{
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
asm
volatile
(
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0
\n
"
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0
\n
"
"
\n\t
"
:
"=v"
(
dst
),
"
\n\t
"
:
"=v"
(
dst
),
"+v"
(
offset_v
),
"+s"
(
global_addr
)
"+v"
(
offset_v
),
"+s"
(
global_addr
)
);
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
else
{
else
{
auto
res
=
__builtin_amdgcn_buffer_load_dwordx4
(
global_addr
,
0
,
offset_v
,
false
,
false
);
auto
res
=
__builtin_amdgcn_buffer_load_dwordx4
(
global_addr
,
0
,
offset_v
,
false
,
false
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment