Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
142846b5
Commit
142846b5
authored
May 27, 2026
by
zhanghj2
Browse files
fix精度问题
parent
a9e4de8d
Changes
5
Expand all
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
499 additions
and
7 deletions
+499
-7
csrc/extension/flash_fwd_mla_kernel.h
csrc/extension/flash_fwd_mla_kernel.h
+214
-0
csrc/extension/flash_fwd_mla_kernel_fp8.h
csrc/extension/flash_fwd_mla_kernel_fp8.h
+66
-0
csrc/extension/utils.h
csrc/extension/utils.h
+40
-0
csrc/gfx93/decode/dense/splitkv_mla.cuh
csrc/gfx93/decode/dense/splitkv_mla.cuh
+112
-7
csrc/gfx93/prefill/sparse/phase1.cuh
csrc/gfx93/prefill/sparse/phase1.cuh
+67
-0
No files found.
csrc/extension/flash_fwd_mla_kernel.h
View file @
142846b5
This diff is collapsed.
Click to expand it.
csrc/extension/flash_fwd_mla_kernel_fp8.h
View file @
142846b5
...
@@ -642,17 +642,23 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
...
@@ -642,17 +642,23 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
lds_direct_copy_qkvfp8
<
false
,
true
,
true
>
(
gQ
,
sQ
,
0
,
params
.
q_row_stride
,
params
.
seqlen_q
-
m_block
*
kBlockM
);
lds_direct_copy_qkvfp8
<
false
,
true
,
true
>
(
gQ
,
sQ
,
0
,
params
.
q_row_stride
,
params
.
seqlen_q
-
m_block
*
kBlockM
);
lds_direct_copy_qkvfp8
<
false
,
true
,
true
>
(
gQ
,
sQ
,
1
,
params
.
q_row_stride
,
params
.
seqlen_q
-
m_block
*
kBlockM
);
lds_direct_copy_qkvfp8
<
false
,
true
,
true
>
(
gQ
,
sQ
,
1
,
params
.
q_row_stride
,
params
.
seqlen_q
-
m_block
*
kBlockM
);
lds_direct_copy_qkvfp8
<
false
,
false
,
true
>
(
gQ
,
sQ
,
2
,
params
.
q_row_stride
,
params
.
seqlen_q
-
m_block
*
kBlockM
);
lds_direct_copy_qkvfp8
<
false
,
false
,
true
>
(
gQ
,
sQ
,
2
,
params
.
q_row_stride
,
params
.
seqlen_q
-
m_block
*
kBlockM
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
0
),
tSrQ_copy_view
(
_
,
_
,
0
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
0
),
tSrQ_copy_view
(
_
,
_
,
0
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
1
),
tSrQ_copy_view
(
_
,
_
,
1
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
1
),
tSrQ_copy_view
(
_
,
_
,
1
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
2
),
tSrQ_copy_view
(
_
,
_
,
2
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
2
),
tSrQ_copy_view
(
_
,
_
,
2
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
3
),
tSrQ_copy_view
(
_
,
_
,
3
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
3
),
tSrQ_copy_view
(
_
,
_
,
3
));
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
4
),
tSrQ_copy_view
(
_
,
_
,
4
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
4
),
tSrQ_copy_view
(
_
,
_
,
4
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
5
),
tSrQ_copy_view
(
_
,
_
,
5
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
5
),
tSrQ_copy_view
(
_
,
_
,
5
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
6
),
tSrQ_copy_view
(
_
,
_
,
6
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
6
),
tSrQ_copy_view
(
_
,
_
,
6
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
7
),
tSrQ_copy_view
(
_
,
_
,
7
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
7
),
tSrQ_copy_view
(
_
,
_
,
7
));
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
8
),
tSrQ_copy_view
(
_
,
_
,
8
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
8
),
tSrQ_copy_view
(
_
,
_
,
8
));
__syncthreads
();
__syncthreads
();
}
}
...
@@ -708,12 +714,14 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
...
@@ -708,12 +714,14 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
const
int
offset_s
=
0
;
const
int
offset_s
=
0
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
s_q
)
+
warp_idx
*
bytes_per_warp
+
k_idx
*
bytes_per_block
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
s_q
)
+
warp_idx
*
bytes_per_warp
+
k_idx
*
bytes_per_block
+
offset_k
*
3
*
bytes_per_block
;
+
offset_k
*
3
*
bytes_per_block
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"s_mov_b32 m0, %1
\n\t
"
"s_nop 0
\n\t
"
"s_nop 0
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
:
);
__builtin_amdgcn_sched_barrier
(
0
);
};
};
lds_direct_copy_q
(
0
,
0
);
lds_direct_copy_q
(
0
,
0
);
...
@@ -723,7 +731,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
...
@@ -723,7 +731,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
lds_direct_copy_q
(
2
,
0
);
lds_direct_copy_q
(
2
,
0
);
ElementQ
*
s_q_read_ptr
=
s_q
+
lane_idx
*
8
;
ElementQ
*
s_q_read_ptr
=
s_q
+
lane_idx
*
8
;
Fp8_storage
bf16_data
;
Fp8_storage
bf16_data
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(4)
\n
s_barrier"
);
asm
volatile
(
"s_waitcnt vmcnt(4)
\n
s_barrier"
);
__builtin_amdgcn_sched_barrier
(
0
);
float
fp32
[
8
];
float
fp32
[
8
];
union
Fp8_temp
{
union
Fp8_temp
{
int32_t
data
;
int32_t
data
;
...
@@ -747,7 +757,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
...
@@ -747,7 +757,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
}
}
s_q_read_ptr
+=
16
*
32
;
s_q_read_ptr
+=
16
*
32
;
}
}
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(3)
\n
s_barrier"
);
asm
volatile
(
"s_waitcnt vmcnt(3)
\n
s_barrier"
);
__builtin_amdgcn_sched_barrier
(
0
);
for
(
int
k
=
4
;
k
<
8
;
k
++
)
{
for
(
int
k
=
4
;
k
<
8
;
k
++
)
{
bf16_data
.
data
=
*
reinterpret_cast
<
intx4_t
*>
(
s_q_read_ptr
);
bf16_data
.
data
=
*
reinterpret_cast
<
intx4_t
*>
(
s_q_read_ptr
);
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
...
@@ -766,7 +778,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
...
@@ -766,7 +778,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
}
}
s_q_read_ptr
+=
16
*
32
;
s_q_read_ptr
+=
16
*
32
;
}
}
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n
s_barrier"
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n
s_barrier"
);
__builtin_amdgcn_sched_barrier
(
0
);
s_q_read_ptr
=
s_q
+
lane_idx
*
8
+
3
*
4
*
16
*
4
*
8
;
s_q_read_ptr
=
s_q
+
lane_idx
*
8
+
3
*
4
*
16
*
4
*
8
;
for
(
int
k
=
0
;
k
<
4
;
k
++
)
{
for
(
int
k
=
0
;
k
<
4
;
k
++
)
{
bf16_data
.
data
=
*
reinterpret_cast
<
intx4_t
*>
(
s_q_read_ptr
);
bf16_data
.
data
=
*
reinterpret_cast
<
intx4_t
*>
(
s_q_read_ptr
);
...
@@ -786,7 +800,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
...
@@ -786,7 +800,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
}
}
s_q_read_ptr
+=
16
*
32
;
s_q_read_ptr
+=
16
*
32
;
}
}
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n
s_barrier"
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n
s_barrier"
);
__builtin_amdgcn_sched_barrier
(
0
);
for
(
int
k
=
4
;
k
<
8
;
k
++
)
{
for
(
int
k
=
4
;
k
<
8
;
k
++
)
{
bf16_data
.
data
=
*
reinterpret_cast
<
intx4_t
*>
(
s_q_read_ptr
);
bf16_data
.
data
=
*
reinterpret_cast
<
intx4_t
*>
(
s_q_read_ptr
);
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
...
@@ -805,7 +821,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
...
@@ -805,7 +821,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
}
}
s_q_read_ptr
+=
16
*
32
;
s_q_read_ptr
+=
16
*
32
;
}
}
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n
s_barrier"
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n
s_barrier"
);
__builtin_amdgcn_sched_barrier
(
0
);
s_q_read_ptr
=
s_q
+
lane_idx
*
8
+
2
*
4
*
16
*
4
*
8
;
s_q_read_ptr
=
s_q
+
lane_idx
*
8
+
2
*
4
*
16
*
4
*
8
;
for
(
int
k
=
8
;
k
<
9
;
k
++
)
{
for
(
int
k
=
8
;
k
<
9
;
k
++
)
{
bf16_data
.
data
=
*
reinterpret_cast
<
intx4_t
*>
(
s_q_read_ptr
);
bf16_data
.
data
=
*
reinterpret_cast
<
intx4_t
*>
(
s_q_read_ptr
);
...
@@ -848,17 +866,23 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
...
@@ -848,17 +866,23 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
lds_direct_copy_qkvfp8
<
false
,
true
,
true
>
(
gQ_nope
,
sQ
,
0
,
params
.
q_nope_head_stride
,
params
.
seqlen_q
-
m_block
*
kBlockM
);
lds_direct_copy_qkvfp8
<
false
,
true
,
true
>
(
gQ_nope
,
sQ
,
0
,
params
.
q_nope_head_stride
,
params
.
seqlen_q
-
m_block
*
kBlockM
);
lds_direct_copy_qkvfp8
<
false
,
true
,
true
>
(
gQ_nope
,
sQ
,
1
,
params
.
q_nope_head_stride
,
params
.
seqlen_q
-
m_block
*
kBlockM
);
lds_direct_copy_qkvfp8
<
false
,
true
,
true
>
(
gQ_nope
,
sQ
,
1
,
params
.
q_nope_head_stride
,
params
.
seqlen_q
-
m_block
*
kBlockM
);
lds_direct_copy_qkvfp8_pe
<
false
,
false
,
true
>
(
gQ_pe
,
sQ
,
2
,
params
.
q_pe_head_stride
,
params
.
seqlen_q
-
m_block
*
kBlockM
);
lds_direct_copy_qkvfp8_pe
<
false
,
false
,
true
>
(
gQ_pe
,
sQ
,
2
,
params
.
q_pe_head_stride
,
params
.
seqlen_q
-
m_block
*
kBlockM
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
0
),
tSrQ_copy_view
(
_
,
_
,
0
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
0
),
tSrQ_copy_view
(
_
,
_
,
0
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
1
),
tSrQ_copy_view
(
_
,
_
,
1
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
1
),
tSrQ_copy_view
(
_
,
_
,
1
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
2
),
tSrQ_copy_view
(
_
,
_
,
2
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
2
),
tSrQ_copy_view
(
_
,
_
,
2
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
3
),
tSrQ_copy_view
(
_
,
_
,
3
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
3
),
tSrQ_copy_view
(
_
,
_
,
3
));
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
4
),
tSrQ_copy_view
(
_
,
_
,
4
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
4
),
tSrQ_copy_view
(
_
,
_
,
4
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
5
),
tSrQ_copy_view
(
_
,
_
,
5
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
5
),
tSrQ_copy_view
(
_
,
_
,
5
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
6
),
tSrQ_copy_view
(
_
,
_
,
6
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
6
),
tSrQ_copy_view
(
_
,
_
,
6
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
7
),
tSrQ_copy_view
(
_
,
_
,
7
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
7
),
tSrQ_copy_view
(
_
,
_
,
7
));
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
8
),
tSrQ_copy_view
(
_
,
_
,
8
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
8
),
tSrQ_copy_view
(
_
,
_
,
8
));
__syncthreads
();
__syncthreads
();
}
}
...
@@ -968,7 +992,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
...
@@ -968,7 +992,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938(co
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
i
),
tSrK
(
_
,
_
,
i
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
i
),
tSrK
(
_
,
_
,
i
),
acc_s
);
}
}
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
// if (thread0()) {
// if (thread0()) {
// printf(" %.2f %.2f %.2f %.2f \n", acc_s(0), acc_s(1), acc_s(2), acc_s(3));
// printf(" %.2f %.2f %.2f %.2f \n", acc_s(0), acc_s(1), acc_s(2), acc_s(3));
// }
// }
...
@@ -1252,26 +1278,36 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -1252,26 +1278,36 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
}
}
auto
q_lds_read_ptr
=
sQ
.
data
().
get
()
+
(
warp_id
%
4
)
*
16
*
64
;
auto
q_lds_read_ptr
=
sQ
.
data
().
get
()
+
(
warp_id
%
4
)
*
16
*
64
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(4)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(4)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
q_r
[
0
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
0
,
3
,
1
,
0
);
q_r
[
0
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
0
,
3
,
1
,
0
);
// q_lds_read_ptr += 64 * 64;
// q_lds_read_ptr += 64 * 64;
q_r
[
1
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
64
*
64
,
3
,
1
,
0
);
q_r
[
1
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
64
*
64
,
3
,
1
,
0
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
// q_lds_read_ptr += 64 * 64;
// q_lds_read_ptr += 64 * 64;
q_r
[
2
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
2
*
64
*
64
,
3
,
1
,
0
);
q_r
[
2
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
2
*
64
*
64
,
3
,
1
,
0
);
// q_lds_read_ptr += 64 * 64;
// q_lds_read_ptr += 64 * 64;
q_r
[
3
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
3
*
64
*
64
,
3
,
1
,
0
);
q_r
[
3
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
3
*
64
*
64
,
3
,
1
,
0
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
// q_lds_read_ptr += 64 * 64;
// q_lds_read_ptr += 64 * 64;
q_r
[
4
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
4
*
64
*
64
,
3
,
1
,
0
);
q_r
[
4
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
4
*
64
*
64
,
3
,
1
,
0
);
// q_lds_read_ptr += 64 * 64;
// q_lds_read_ptr += 64 * 64;
q_r
[
5
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
5
*
64
*
64
,
3
,
1
,
0
);
q_r
[
5
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
5
*
64
*
64
,
3
,
1
,
0
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
// q_lds_read_ptr += 64 * 64;
// q_lds_read_ptr += 64 * 64;
q_r
[
6
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
6
*
64
*
64
,
3
,
1
,
0
);
q_r
[
6
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
6
*
64
*
64
,
3
,
1
,
0
);
// q_lds_read_ptr += 64 * 64;
// q_lds_read_ptr += 64 * 64;
q_r
[
7
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
7
*
64
*
64
,
3
,
1
,
0
);
q_r
[
7
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
7
*
64
*
64
,
3
,
1
,
0
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
// q_lds_read_ptr += 64 * 64;
// q_lds_read_ptr += 64 * 64;
q_r
[
8
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
8
*
64
*
64
,
3
,
1
,
0
);
q_r
[
8
].
data
=
__builtin_hcu_ds_read_matrix_trans_format_u8
((
__attribute__
((
address_space
(
3
)))
int
*
)(
q_lds_read_ptr
),
8
*
64
*
64
,
3
,
1
,
0
);
__syncthreads
();
__syncthreads
();
...
@@ -1939,7 +1975,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -1939,7 +1975,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
lds_direct_copy_qkvfp8_q_tp4
<
false
,
true
>
(
gQ
,
sQ
,
2
,
params
.
q_row_stride
,
params
.
seqlen_q
-
m_block
*
kBlockM
);
lds_direct_copy_qkvfp8_q_tp4
<
false
,
true
>
(
gQ
,
sQ
,
2
,
params
.
q_row_stride
,
params
.
seqlen_q
-
m_block
*
kBlockM
);
lds_direct_copy_qkvfp8_q_tp4
<
false
,
true
>
(
gQ
,
sQ
,
3
,
params
.
q_row_stride
,
params
.
seqlen_q
-
m_block
*
kBlockM
);
lds_direct_copy_qkvfp8_q_tp4
<
false
,
true
>
(
gQ
,
sQ
,
3
,
params
.
q_row_stride
,
params
.
seqlen_q
-
m_block
*
kBlockM
);
lds_direct_copy_qkvfp8_q_tp4
<
false
,
false
>
(
gQ
,
sQ
,
4
,
params
.
q_row_stride
,
params
.
seqlen_q
-
m_block
*
kBlockM
);
lds_direct_copy_qkvfp8_q_tp4
<
false
,
false
>
(
gQ
,
sQ
,
4
,
params
.
q_row_stride
,
params
.
seqlen_q
-
m_block
*
kBlockM
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(4)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(4)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
uint8_t
*
q_lds_read_ptr
=
reinterpret_cast
<
uint8_t
*>
(
sQ
.
data
().
get
())
+
(
tidx
%
64
)
*
16
+
(
warp_id
%
2
)
*
(
16
*
64
);
uint8_t
*
q_lds_read_ptr
=
reinterpret_cast
<
uint8_t
*>
(
sQ
.
data
().
get
())
+
(
tidx
%
64
)
*
16
+
(
warp_id
%
2
)
*
(
16
*
64
);
{
{
int
k
=
0
;
int
k
=
0
;
...
@@ -1961,7 +1999,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -1961,7 +1999,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
// tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1)));
// tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
}
}
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
{
{
q_lds_read_ptr
+=
32
*
64
;
q_lds_read_ptr
+=
32
*
64
;
int
k
=
2
;
int
k
=
2
;
...
@@ -1984,7 +2024,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -1984,7 +2024,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
// tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1)));
// tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
}
}
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
{
{
q_lds_read_ptr
+=
32
*
64
;
q_lds_read_ptr
+=
32
*
64
;
int
k
=
4
;
int
k
=
4
;
...
@@ -2007,7 +2049,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -2007,7 +2049,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
// tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1)));
// tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
}
}
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
{
{
q_lds_read_ptr
+=
32
*
64
;
q_lds_read_ptr
+=
32
*
64
;
int
k
=
6
;
int
k
=
6
;
...
@@ -2030,7 +2074,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -2030,7 +2074,9 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
// tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1)));
// tSrQ_ptr = reinterpret_cast<intx4_t*>(&(tSrQ(0, 0, k + 1)));
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
// *tSrQ_ptr = *reinterpret_cast<intx4_t*>(q_lds_read_ptr);
}
}
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
{
{
q_lds_read_ptr
+=
32
*
64
;
q_lds_read_ptr
+=
32
*
64
;
int
k
=
8
;
int
k
=
8
;
...
@@ -2092,10 +2138,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -2092,10 +2138,12 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
int
cur_block_table
;
int
cur_block_table
;
const
int
*
cur_block_table_ptr
=
block_table
+
n_block
;
const
int
*
cur_block_table_ptr
=
block_table
+
n_block
;
// cur_block_table = block_table[n_block - 1];
// cur_block_table = block_table[n_block - 1];
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_load_dword %1, %0, 0x0
\n\t
"
asm
volatile
(
"s_load_dword %1, %0, 0x0
\n\t
"
"s_waitcnt lgkmcnt(0)
\n\t
"
:
"s_waitcnt lgkmcnt(0)
\n\t
"
:
"+s"
(
cur_block_table_ptr
),
"+s"
(
cur_block_table_ptr
),
"=s"
(
cur_block_table
));
"=s"
(
cur_block_table
));
__builtin_amdgcn_sched_barrier
(
0
);
index_t
offset_k
=
cur_block_table
*
params
.
k_batch_stride
;
index_t
offset_k
=
cur_block_table
*
params
.
k_batch_stride
;
gK
.
data
()
=
gK
.
data
()
+
(
offset_k
);
gK
.
data
()
=
gK
.
data
()
+
(
offset_k
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
0
,
params
.
k_row_stride
,
seqlen_k
-
n_block
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
0
,
params
.
k_row_stride
,
seqlen_k
-
n_block
*
kBlockN
);
...
@@ -2107,31 +2155,49 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
...
@@ -2107,31 +2155,49 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_gfx938_TP
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
6
,
params
.
k_row_stride
,
seqlen_k
-
n_block
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
6
,
params
.
k_row_stride
,
seqlen_k
-
n_block
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
7
,
params
.
k_row_stride
,
seqlen_k
-
n_block
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
7
,
params
.
k_row_stride
,
seqlen_k
-
n_block
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
8
,
params
.
k_row_stride
,
seqlen_k
-
n_block
*
kBlockN
);
lds_direct_copy_qkvfp8
<
false
,
true
>
(
gK
,
sK
,
8
,
params
.
k_row_stride
,
seqlen_k
-
n_block
*
kBlockN
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(8)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(8)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
0
),
tSrK_copy_view
(
_
,
_
,
0
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
0
),
tSrK_copy_view
(
_
,
_
,
0
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
0
),
tSrK
(
_
,
_
,
0
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
0
),
tSrK
(
_
,
_
,
0
),
acc_s
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(7)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(7)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
1
),
tSrK_copy_view
(
_
,
_
,
1
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
1
),
tSrK_copy_view
(
_
,
_
,
1
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
1
),
tSrK
(
_
,
_
,
1
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
1
),
tSrK
(
_
,
_
,
1
),
acc_s
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(6)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(6)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
2
),
tSrK_copy_view
(
_
,
_
,
2
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
2
),
tSrK_copy_view
(
_
,
_
,
2
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
2
),
tSrK
(
_
,
_
,
2
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
2
),
tSrK
(
_
,
_
,
2
),
acc_s
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(5)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(5)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
3
),
tSrK_copy_view
(
_
,
_
,
3
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
3
),
tSrK_copy_view
(
_
,
_
,
3
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
3
),
tSrK
(
_
,
_
,
3
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
3
),
tSrK
(
_
,
_
,
3
),
acc_s
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(4)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(4)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
4
),
tSrK_copy_view
(
_
,
_
,
4
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
4
),
tSrK_copy_view
(
_
,
_
,
4
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
4
),
tSrK
(
_
,
_
,
4
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
4
),
tSrK
(
_
,
_
,
4
),
acc_s
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
5
),
tSrK_copy_view
(
_
,
_
,
5
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
5
),
tSrK_copy_view
(
_
,
_
,
5
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
5
),
tSrK
(
_
,
_
,
5
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
5
),
tSrK
(
_
,
_
,
5
),
acc_s
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
6
),
tSrK_copy_view
(
_
,
_
,
6
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
6
),
tSrK_copy_view
(
_
,
_
,
6
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
6
),
tSrK
(
_
,
_
,
6
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
6
),
tSrK
(
_
,
_
,
6
),
acc_s
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
7
),
tSrK_copy_view
(
_
,
_
,
7
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
7
),
tSrK_copy_view
(
_
,
_
,
7
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
7
),
tSrK
(
_
,
_
,
7
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
7
),
tSrK
(
_
,
_
,
7
),
acc_s
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
8
),
tSrK_copy_view
(
_
,
_
,
8
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
8
),
tSrK_copy_view
(
_
,
_
,
8
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
8
),
tSrK
(
_
,
_
,
8
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
8
),
tSrK
(
_
,
_
,
8
),
acc_s
);
gK
.
data
()
=
gK
.
data
()
+
(
-
offset_k
);
gK
.
data
()
=
gK
.
data
()
+
(
-
offset_k
);
...
...
csrc/extension/utils.h
View file @
142846b5
...
@@ -300,9 +300,11 @@ __forceinline__ __device__ void copy_k_idx(TiledCopy tiled_copy, Tensor<Engine0,
...
@@ -300,9 +300,11 @@ __forceinline__ __device__ void copy_k_idx(TiledCopy tiled_copy, Tensor<Engine0,
template
<
int
N
>
template
<
int
N
>
CUTE_HOST_DEVICE
CUTE_HOST_DEVICE
void
wait_vmcnt
()
{
void
wait_vmcnt
()
{
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(%0) ;
\n\t
"
asm
volatile
(
"s_waitcnt vmcnt(%0) ;
\n\t
"
"s_barrier;
\n\t
"
"s_barrier;
\n\t
"
::
"n"
(
N
));
::
"n"
(
N
));
__builtin_amdgcn_sched_barrier
(
0
);
}
}
template
<
template
<
...
@@ -377,11 +379,13 @@ buffer_load_copy(
...
@@ -377,11 +379,13 @@ buffer_load_copy(
if
constexpr
(
use_asm
)
{
if
constexpr
(
use_asm
)
{
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
asm
volatile
(
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0
\n
"
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0
\n
"
"
\n\t
"
:
"=v"
(
dst
),
"
\n\t
"
:
"=v"
(
dst
),
"+v"
(
offset_v
),
"+s"
(
global_addr
)
"+v"
(
offset_v
),
"+s"
(
global_addr
)
);
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
else
{
else
{
auto
res
=
__builtin_amdgcn_buffer_load_dwordx4
(
global_addr
,
0
,
offset_v
,
false
,
false
);
auto
res
=
__builtin_amdgcn_buffer_load_dwordx4
(
global_addr
,
0
,
offset_v
,
false
,
false
);
...
@@ -408,11 +412,13 @@ buffer_load_copy(
...
@@ -408,11 +412,13 @@ buffer_load_copy(
if
(
!
Is_even_MN
&&
row_offset
>=
max_MN
)
offset_v
=
-
1
;
if
(
!
Is_even_MN
&&
row_offset
>=
max_MN
)
offset_v
=
-
1
;
if
constexpr
(
use_asm
)
{
if
constexpr
(
use_asm
)
{
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
asm
volatile
(
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0
\n
"
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0
\n
"
"
\n\t
"
:
"=v"
(
dst
),
"
\n\t
"
:
"=v"
(
dst
),
"+v"
(
offset_v
),
"+s"
(
global_addr
)
"+v"
(
offset_v
),
"+s"
(
global_addr
)
);
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
else
{
else
{
auto
res
=
__builtin_amdgcn_buffer_load_dwordx4
(
global_addr
,
0
,
offset_v
,
false
,
false
);
auto
res
=
__builtin_amdgcn_buffer_load_dwordx4
(
global_addr
,
0
,
offset_v
,
false
,
false
);
...
@@ -473,11 +479,13 @@ buffer_load_copy_fp8(
...
@@ -473,11 +479,13 @@ buffer_load_copy_fp8(
if
constexpr
(
use_asm
)
{
if
constexpr
(
use_asm
)
{
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
asm
volatile
(
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0
\n
"
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0
\n
"
"
\n\t
"
:
"=v"
(
dst
),
"
\n\t
"
:
"=v"
(
dst
),
"+v"
(
offset_v
),
"+s"
(
global_addr
)
"+v"
(
offset_v
),
"+s"
(
global_addr
)
);
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
else
{
else
{
auto
res
=
__builtin_amdgcn_buffer_load_dwordx4
(
global_addr
,
0
,
offset_v
,
false
,
false
);
auto
res
=
__builtin_amdgcn_buffer_load_dwordx4
(
global_addr
,
0
,
offset_v
,
false
,
false
);
...
@@ -542,11 +550,13 @@ buffer_load_copy_fp8x2(
...
@@ -542,11 +550,13 @@ buffer_load_copy_fp8x2(
if
constexpr
(
use_asm
)
{
if
constexpr
(
use_asm
)
{
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
asm
volatile
(
"buffer_load_dwordx2 %0, %1, %2 ,0 offen offset:0
\n
"
"buffer_load_dwordx2 %0, %1, %2 ,0 offen offset:0
\n
"
"
\n\t
"
:
"=v"
(
dst
),
"
\n\t
"
:
"=v"
(
dst
),
"+v"
(
offset_v
),
"+s"
(
global_addr
)
"+v"
(
offset_v
),
"+s"
(
global_addr
)
);
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
else
{
else
{
auto
res
=
__builtin_amdgcn_buffer_load_dwordx2
(
global_addr
,
0
,
offset_v
,
false
,
false
);
auto
res
=
__builtin_amdgcn_buffer_load_dwordx2
(
global_addr
,
0
,
offset_v
,
false
,
false
);
...
@@ -711,12 +721,14 @@ lds_direct_copy_qkvfp8_pe(
...
@@ -711,12 +721,14 @@ lds_direct_copy_qkvfp8_pe(
if
(
!
Is_even_K
&&
col_offset
>=
576
)
offset_v
=
-
1
;
if
(
!
Is_even_K
&&
col_offset
>=
576
)
offset_v
=
-
1
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_idx
*
mma_k
*
element_size
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_idx
*
mma_k
*
element_size
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"s_mov_b32 m0, %1
\n\t
"
"s_nop 0
\n\t
"
"s_nop 0
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
:
);
__builtin_amdgcn_sched_barrier
(
0
);
...
@@ -778,12 +790,14 @@ lds_direct_copy_qkvfp8(
...
@@ -778,12 +790,14 @@ lds_direct_copy_qkvfp8(
if
(
!
Is_even_K
&&
col_offset
>=
576
)
offset_v
=
-
1
;
if
(
!
Is_even_K
&&
col_offset
>=
576
)
offset_v
=
-
1
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_idx
*
mma_k
*
element_size
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_idx
*
mma_k
*
element_size
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"s_mov_b32 m0, %1
\n\t
"
"s_nop 0
\n\t
"
"s_nop 0
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
:
);
__builtin_amdgcn_sched_barrier
(
0
);
...
@@ -845,12 +859,14 @@ lds_direct_copy_qkvfp8(
...
@@ -845,12 +859,14 @@ lds_direct_copy_qkvfp8(
#if defined(__gfx938__)
#if defined(__gfx938__)
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"s_mov_b32 m0, %1
\n\t
"
"s_nop 0
\n\t
"
"s_nop 0
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
:
);
__builtin_amdgcn_sched_barrier
(
0
);
#endif
#endif
}
}
}
}
...
@@ -933,12 +949,14 @@ lds_direct_copy_fp8(
...
@@ -933,12 +949,14 @@ lds_direct_copy_fp8(
// }
// }
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"s_mov_b32 m0, %1
\n\t
"
"s_nop 0
\n\t
"
"s_nop 0
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
:
);
__builtin_amdgcn_sched_barrier
(
0
);
#endif
#endif
}
}
}
}
...
@@ -1005,12 +1023,14 @@ lds_direct_copy_tp1(
...
@@ -1005,12 +1023,14 @@ lds_direct_copy_tp1(
// {
// {
// printf(" %x \n", ldsAddrPerWave);
// printf(" %x \n", ldsAddrPerWave);
// }
// }
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"s_mov_b32 m0, %1
\n\t
"
"s_nop 0
\n\t
"
"s_nop 0
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
:
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
...
@@ -1186,12 +1206,14 @@ lds_direct_copy_sparse_k(
...
@@ -1186,12 +1206,14 @@ lds_direct_copy_sparse_k(
if
(
!
Is_even_K
&&
col_offset
>=
576
)
offset_v
=
-
1
;
if
(
!
Is_even_K
&&
col_offset
>=
576
)
offset_v
=
-
1
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
0
*
mma_k
*
element_size
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
0
*
mma_k
*
element_size
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"s_mov_b32 m0, %1
\n\t
"
"s_nop 0
\n\t
"
"s_nop 0
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
:
);
__builtin_amdgcn_sched_barrier
(
0
);
...
@@ -1255,12 +1277,14 @@ lds_direct_copy(
...
@@ -1255,12 +1277,14 @@ lds_direct_copy(
if
(
!
Is_even_K
&&
col_offset
>=
576
)
offset_v
=
-
1
;
if
(
!
Is_even_K
&&
col_offset
>=
576
)
offset_v
=
-
1
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_idx
*
mma_k
*
element_size
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_idx
*
mma_k
*
element_size
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"s_mov_b32 m0, %1
\n\t
"
"s_nop 0
\n\t
"
"s_nop 0
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
:
);
__builtin_amdgcn_sched_barrier
(
0
);
...
@@ -1319,12 +1343,14 @@ lds_direct_copy(
...
@@ -1319,12 +1343,14 @@ lds_direct_copy(
if
(
!
Is_even_MN
&&
row_offset
>=
max_MN
)
offset_v
=
-
1
;
if
(
!
Is_even_MN
&&
row_offset
>=
max_MN
)
offset_v
=
-
1
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_idx
*
mma_k
*
element_size
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_idx
*
mma_k
*
element_size
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"s_mov_b32 m0, %1
\n\t
"
"s_nop 0
\n\t
"
"s_nop 0
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
:
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
}
}
...
@@ -1409,12 +1435,14 @@ lds_direct_copy_for_prefill_sparse_mla(
...
@@ -1409,12 +1435,14 @@ lds_direct_copy_for_prefill_sparse_mla(
uint32x2_t
index_offset
=
{
0
};
uint32x2_t
index_offset
=
{
0
};
index_offset
[
0
]
=
row_offset
==
-
1
?
max_MN
:
row_offset
;
index_offset
[
0
]
=
row_offset
==
-
1
?
max_MN
:
row_offset
;
index_offset
[
1
]
=
offset_v
;
index_offset
[
1
]
=
offset_v
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"s_mov_b32 m0, %1
\n\t
"
"s_nop 0
\n\t
"
"s_nop 0
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 , idxen offen offset:0, lds
\n
"
::
"v"
(
index_offset
),
"buffer_load_dwordx4 %0, %2, %3 , idxen offen offset:0, lds
\n
"
::
"v"
(
index_offset
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
:
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
...
@@ -1474,11 +1502,13 @@ buffer_load_copy_sparse_fp8(
...
@@ -1474,11 +1502,13 @@ buffer_load_copy_sparse_fp8(
uint32x2_t
index_offset
=
{
0
};
uint32x2_t
index_offset
=
{
0
};
index_offset
[
0
]
=
(
row_offset
+
64
)
%
64
;
index_offset
[
0
]
=
(
row_offset
+
64
)
%
64
;
index_offset
[
1
]
=
offset_v
;
index_offset
[
1
]
=
offset_v
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
asm
volatile
(
"buffer_load_dwordx2 %0, %1, %2 ,0 offen offset:0
\n
"
"buffer_load_dwordx2 %0, %1, %2 ,0 offen offset:0
\n
"
"
\n\t
"
:
"=v"
(
dst
),
"
\n\t
"
:
"=v"
(
dst
),
"+v"
(
index_offset
),
"+s"
(
global_addr
)
"+v"
(
index_offset
),
"+s"
(
global_addr
)
);
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
else
{
else
{
// auto res = __builtin_amdgcn_buffer_load_dwordx2(global_addr, (row_offset + 64 ) % 64 , offset_v, false, false);
// auto res = __builtin_amdgcn_buffer_load_dwordx2(global_addr, (row_offset + 64 ) % 64 , offset_v, false, false);
...
@@ -1555,11 +1585,13 @@ buffer_load_copy_sparse_decoding(
...
@@ -1555,11 +1585,13 @@ buffer_load_copy_sparse_decoding(
uint32x2_t
index_offset
=
{
0
};
uint32x2_t
index_offset
=
{
0
};
index_offset
[
0
]
=
(
row_offset
+
64
)
%
64
;
index_offset
[
0
]
=
(
row_offset
+
64
)
%
64
;
index_offset
[
1
]
=
offset_v
;
index_offset
[
1
]
=
offset_v
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
asm
volatile
(
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0
\n
"
"buffer_load_dwordx4 %0, %1, %2 ,0 offen offset:0
\n
"
"
\n\t
"
:
"=v"
(
dst
),
"
\n\t
"
:
"=v"
(
dst
),
"+v"
(
index_offset
),
"+s"
(
global_addr
)
"+v"
(
index_offset
),
"+s"
(
global_addr
)
);
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
else
{
else
{
// auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, (row_offset + 64 ) % 64 + (batch_stride / row_stride) * block_idx , offset_v, false, false);
// auto res = __builtin_amdgcn_buffer_load_dwordx4(global_addr, (row_offset + 64 ) % 64 + (batch_stride / row_stride) * block_idx , offset_v, false, false);
...
@@ -2303,12 +2335,14 @@ lds_direct_copy_qkvfp8_q_tp1(
...
@@ -2303,12 +2335,14 @@ lds_direct_copy_qkvfp8_q_tp1(
#if defined(__gfx938__)
#if defined(__gfx938__)
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"s_mov_b32 m0, %1
\n\t
"
"s_nop 0
\n\t
"
"s_nop 0
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
:
);
__builtin_amdgcn_sched_barrier
(
0
);
#endif
#endif
}
}
...
@@ -2362,12 +2396,14 @@ lds_direct_copy_qkvfp8_q_tp4(
...
@@ -2362,12 +2396,14 @@ lds_direct_copy_qkvfp8_q_tp4(
#if defined(__gfx938__)
#if defined(__gfx938__)
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"s_mov_b32 m0, %1
\n\t
"
"s_nop 0
\n\t
"
"s_nop 0
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
:
);
__builtin_amdgcn_sched_barrier
(
0
);
#endif
#endif
}
}
...
@@ -2444,12 +2480,14 @@ lds_direct_copy_qkvfp8_tp1(
...
@@ -2444,12 +2480,14 @@ lds_direct_copy_qkvfp8_tp1(
#if defined(__gfx938__)
#if defined(__gfx938__)
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"s_mov_b32 m0, %1
\n\t
"
"s_nop 0
\n\t
"
"s_nop 0
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
:
);
__builtin_amdgcn_sched_barrier
(
0
);
#endif
#endif
}
}
...
@@ -2792,12 +2830,14 @@ lds_direct_copy_qkvfp8_zero_lds(
...
@@ -2792,12 +2830,14 @@ lds_direct_copy_qkvfp8_zero_lds(
#if defined(__gfx938__)
#if defined(__gfx938__)
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"s_mov_b32 m0, %1
\n\t
"
"s_nop 0
\n\t
"
"s_nop 0
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
:
);
__builtin_amdgcn_sched_barrier
(
0
);
#endif
#endif
}
}
...
...
csrc/gfx93/decode/dense/splitkv_mla.cuh
View file @
142846b5
...
@@ -744,10 +744,13 @@ compute_attn_1rowblock_splitkv_mla_gfx936(const DenseAttnDecodeParams& params,
...
@@ -744,10 +744,13 @@ compute_attn_1rowblock_splitkv_mla_gfx936(const DenseAttnDecodeParams& params,
int
cur_block_table
;
int
cur_block_table
;
const
int
*
cur_block_table_ptr
=
block_table
+
n_block
;
const
int
*
cur_block_table_ptr
=
block_table
+
n_block
;
// cur_block_table = block_table[n_block - 1];
// cur_block_table = block_table[n_block - 1];
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_load_dword %1, %0, 0x0
\n\t
"
asm
volatile
(
"s_load_dword %1, %0, 0x0
\n\t
"
"s_waitcnt lgkmcnt(0)
\n\t
"
:
"s_waitcnt lgkmcnt(0)
\n\t
"
:
"+s"
(
cur_block_table_ptr
),
"+s"
(
cur_block_table_ptr
),
"=s"
(
cur_block_table
));
"=s"
(
cur_block_table
));
__builtin_amdgcn_sched_barrier
(
0
);
index_t
offset_k
=
cur_block_table
*
params
.
k_batch_stride
;
index_t
offset_k
=
cur_block_table
*
params
.
k_batch_stride
;
gK
.
data
()
=
gK
.
data
()
+
(
offset_k
);
gK
.
data
()
=
gK
.
data
()
+
(
offset_k
);
...
@@ -768,81 +771,119 @@ compute_attn_1rowblock_splitkv_mla_gfx936(const DenseAttnDecodeParams& params,
...
@@ -768,81 +771,119 @@ compute_attn_1rowblock_splitkv_mla_gfx936(const DenseAttnDecodeParams& params,
int
k_idx
=
0
;
int
k_idx
=
0
;
// k_idx++;
// k_idx++;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(14 + 3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(14 + 3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
k_idx
++
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(13 + 3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(13 + 3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
k_idx
++
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(12 + 3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(12 + 3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
k_idx
++
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(11 + 3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(11 + 3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
k_idx
++
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(10 + 3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(10 + 3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
k_idx
++
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(9+ 3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(9+ 3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
k_idx
++
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(8+ 3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(8+ 3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
k_idx
++
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(7+ 3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(7+ 3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
// asm volatile("s_barrier\n\t");
// asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(6+ 3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(6+ 3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
k_idx
++
;
k_idx
++
;
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
// asm volatile("s_barrier\n\t");
// asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(5 + 3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(5 + 3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
k_idx
++
;
k_idx
++
;
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
// asm volatile("s_barrier\n\t");
// asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(4 + 3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(4 + 3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
k_idx
++
;
k_idx
++
;
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
k_idx
++
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(3 + 3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(3 + 3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2 + 3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(2 + 3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
k_idx
++
;
k_idx
++
;
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
// asm volatile("s_barrier\n\t");
// asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(1 + 3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(1 + 3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
k_idx
++
;
k_idx
++
;
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
// asm volatile("s_barrier\n\t");
// asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(0 + 3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(0 + 3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
k_idx
++
;
k_idx
++
;
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
...
@@ -850,14 +891,20 @@ compute_attn_1rowblock_splitkv_mla_gfx936(const DenseAttnDecodeParams& params,
...
@@ -850,14 +891,20 @@ compute_attn_1rowblock_splitkv_mla_gfx936(const DenseAttnDecodeParams& params,
}
}
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
buffer_to_tensor
(
buffer
[
0
],
tSrK_smem
,
15
);
flash
::
buffer_to_tensor
(
buffer
[
0
],
tSrK_smem
,
15
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
15
),
tSrK_smem
(
_
,
_
,
15
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
15
),
tSrK_smem
(
_
,
_
,
15
),
acc_s
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
buffer_to_tensor
(
buffer
[
1
],
tSrK_smem
,
16
);
flash
::
buffer_to_tensor
(
buffer
[
1
],
tSrK_smem
,
16
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
16
),
tSrK_smem
(
_
,
_
,
16
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
16
),
tSrK_smem
(
_
,
_
,
16
),
acc_s
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
buffer_to_tensor
(
buffer
[
2
],
tSrK_smem
,
17
);
flash
::
buffer_to_tensor
(
buffer
[
2
],
tSrK_smem
,
17
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
17
),
tSrK_smem
(
_
,
_
,
17
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
17
),
tSrK_smem
(
_
,
_
,
17
),
acc_s
);
// asm volatile("s_barrier\n\t");
// asm volatile("s_barrier\n\t");
...
@@ -903,7 +950,9 @@ compute_attn_1rowblock_splitkv_mla_gfx936(const DenseAttnDecodeParams& params,
...
@@ -903,7 +950,9 @@ compute_attn_1rowblock_splitkv_mla_gfx936(const DenseAttnDecodeParams& params,
flash
::
lds_direct_copy
<
false
,
true
>
(
gK
,
sK
,
15
,
params
.
k_row_stride
,
seqlen_k
-
n_block
*
kBlockN
);
flash
::
lds_direct_copy
<
false
,
true
>
(
gK
,
sK
,
15
,
params
.
k_row_stride
,
seqlen_k
-
n_block
*
kBlockN
);
// asm_ds_write(buffer[0], tVsV, 15);
// asm_ds_write(buffer[0], tVsV, 15);
// asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
// asm volatile("s_waitcnt lgkmcnt(0) \n\t s_barrier\n\t");
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
gK
.
data
()
=
gK
.
data
()
+
(
-
offset_k
);
gK
.
data
()
=
gK
.
data
()
+
(
-
offset_k
);
...
@@ -925,10 +974,12 @@ compute_attn_1rowblock_splitkv_mla_gfx936(const DenseAttnDecodeParams& params,
...
@@ -925,10 +974,12 @@ compute_attn_1rowblock_splitkv_mla_gfx936(const DenseAttnDecodeParams& params,
int
cur_block_table
;
int
cur_block_table
;
const
int
*
cur_block_table_ptr
=
block_table
+
n_block
;
const
int
*
cur_block_table_ptr
=
block_table
+
n_block
;
// cur_block_table = block_table[n_block - 1];
// cur_block_table = block_table[n_block - 1];
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_load_dword %1, %0, 0x0
\n\t
"
asm
volatile
(
"s_load_dword %1, %0, 0x0
\n\t
"
"s_waitcnt lgkmcnt(0)
\n\t
"
:
"s_waitcnt lgkmcnt(0)
\n\t
"
:
"+s"
(
cur_block_table_ptr
),
"+s"
(
cur_block_table_ptr
),
"=s"
(
cur_block_table
));
"=s"
(
cur_block_table
));
__builtin_amdgcn_sched_barrier
(
0
);
index_t
offset_k
=
cur_block_table
*
params
.
k_batch_stride
;
index_t
offset_k
=
cur_block_table
*
params
.
k_batch_stride
;
gK
.
data
()
=
gK
.
data
()
+
(
offset_k
);
gK
.
data
()
=
gK
.
data
()
+
(
offset_k
);
...
@@ -948,85 +999,117 @@ compute_attn_1rowblock_splitkv_mla_gfx936(const DenseAttnDecodeParams& params,
...
@@ -948,85 +999,117 @@ compute_attn_1rowblock_splitkv_mla_gfx936(const DenseAttnDecodeParams& params,
int
k_idx
=
0
;
int
k_idx
=
0
;
// k_idx++;
// k_idx++;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(14 + 3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(14 + 3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
k_idx
++
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(13 + 3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(13 + 3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
k_idx
++
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(12 + 3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(12 + 3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
k_idx
++
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(11 + 3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(11 + 3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
k_idx
++
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(10 + 3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(10 + 3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
k_idx
++
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(9+ 3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(9+ 3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
k_idx
++
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(8+ 3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(8+ 3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
k_idx
++
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(7+ 3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(7+ 3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
// asm volatile("s_barrier\n\t");
// asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(6+ 3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(6+ 3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
k_idx
++
;
k_idx
++
;
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
// asm volatile("s_barrier\n\t");
// asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(5 + 3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(5 + 3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
k_idx
++
;
k_idx
++
;
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
// asm volatile("s_barrier\n\t");
// asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(4 + 3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(4 + 3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
k_idx
++
;
k_idx
++
;
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
k_idx
++
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(3 + 3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(3 + 3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2 + 3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(2 + 3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
k_idx
++
;
k_idx
++
;
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
// asm volatile("s_barrier\n\t");
// asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(1 + 3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(1 + 3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
k_idx
++
;
k_idx
++
;
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
// asm volatile("s_barrier\n\t");
// asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(0 + 3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(0 + 3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
k_idx
++
;
k_idx
++
;
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(0 + 2)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(0 + 2)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
k_idx
++
;
k_idx
++
;
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
...
@@ -1039,14 +1122,20 @@ compute_attn_1rowblock_splitkv_mla_gfx936(const DenseAttnDecodeParams& params,
...
@@ -1039,14 +1122,20 @@ compute_attn_1rowblock_splitkv_mla_gfx936(const DenseAttnDecodeParams& params,
// asm volatile("s_barrier\n\t");
// asm volatile("s_barrier\n\t");
}
}
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
buffer_to_tensor
(
buffer
[
0
],
tSrK_smem
,
16
);
flash
::
buffer_to_tensor
(
buffer
[
0
],
tSrK_smem
,
16
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
16
),
tSrK_smem
(
_
,
_
,
16
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
16
),
tSrK_smem
(
_
,
_
,
16
),
acc_s
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
buffer_to_tensor
(
buffer
[
1
],
tSrK_smem
,
17
);
flash
::
buffer_to_tensor
(
buffer
[
1
],
tSrK_smem
,
17
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
17
),
tSrK_smem
(
_
,
_
,
17
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
17
),
tSrK_smem
(
_
,
_
,
17
),
acc_s
);
// asm volatile("s_barrier\n\t");
// asm volatile("s_barrier\n\t");
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
gK
.
data
()
=
gK
.
data
()
+
(
-
offset_k
);
gK
.
data
()
=
gK
.
data
()
+
(
-
offset_k
);
// We have key_padding_mask so we'll need to Check_inf
// We have key_padding_mask so we'll need to Check_inf
...
@@ -1325,7 +1414,9 @@ compute_attn_1rowblock_splitkv_mla_gfx928(const DenseAttnDecodeParams& params,
...
@@ -1325,7 +1414,9 @@ compute_attn_1rowblock_splitkv_mla_gfx928(const DenseAttnDecodeParams& params,
#if 1
#if 1
#pragma unroll
#pragma unroll
for
(
int
masking_step
=
n_masking_steps
;
n_block
>=
n_block_min
;
--
masking_step
,
--
n_block
)
{
for
(
int
masking_step
=
n_masking_steps
;
n_block
>=
n_block_min
;
--
masking_step
,
--
n_block
)
{
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_barrier
\n\t
"
);
asm
volatile
(
"s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
Tensor
acc_s
=
partition_fragment_C
(
tiled_mma
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
Tensor
acc_s
=
partition_fragment_C
(
tiled_mma
,
Shape
<
Int
<
kBlockM
>
,
Int
<
kBlockN
>>
{});
clear
(
acc_s
);
clear
(
acc_s
);
...
@@ -1337,7 +1428,9 @@ compute_attn_1rowblock_splitkv_mla_gfx928(const DenseAttnDecodeParams& params,
...
@@ -1337,7 +1428,9 @@ compute_attn_1rowblock_splitkv_mla_gfx928(const DenseAttnDecodeParams& params,
for
(
int
i
=
0
;
i
<
k0_lds_loops
-
BUFFER_SIZE
+
1
;
i
++
)
{
for
(
int
i
=
0
;
i
<
k0_lds_loops
-
BUFFER_SIZE
+
1
;
i
++
)
{
// asm volatile("s_waitcnt vmcnt(3) \n\t \n\t");
// asm volatile("s_waitcnt vmcnt(3) \n\t \n\t");
flash
::
asm_ds_write
(
buffer
[
i
%
BUFFER_SIZE
],
tKsK
,
i
);
flash
::
asm_ds_write
(
buffer
[
i
%
BUFFER_SIZE
],
tKsK
,
i
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
i
),
tSrK_copy_view
(
_
,
_
,
i
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
i
),
tSrK_copy_view
(
_
,
_
,
i
));
flash
::
buffer_load_copy
<
false
,
true
,
false
>
(
gK
,
buffer
[(
i
+
BUFFER_SIZE
-
1
)
%
BUFFER_SIZE
],
i
+
BUFFER_SIZE
-
1
,
params
.
k_row_stride
,
offset_k
,
seqlen_k
-
n_block
*
kBlockN
);
flash
::
buffer_load_copy
<
false
,
true
,
false
>
(
gK
,
buffer
[(
i
+
BUFFER_SIZE
-
1
)
%
BUFFER_SIZE
],
i
+
BUFFER_SIZE
-
1
,
params
.
k_row_stride
,
offset_k
,
seqlen_k
-
n_block
*
kBlockN
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
i
),
tSrK
(
_
,
_
,
i
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
i
),
tSrK
(
_
,
_
,
i
),
acc_s
);
...
@@ -1353,17 +1446,23 @@ compute_attn_1rowblock_splitkv_mla_gfx928(const DenseAttnDecodeParams& params,
...
@@ -1353,17 +1446,23 @@ compute_attn_1rowblock_splitkv_mla_gfx928(const DenseAttnDecodeParams& params,
// 计算 13-15
// 计算 13-15
const
int
k_idx
=
k0_lds_loops
-
BUFFER_SIZE
+
1
;
const
int
k_idx
=
k0_lds_loops
-
BUFFER_SIZE
+
1
;
flash
::
asm_ds_write
(
buffer
[
k_idx
%
BUFFER_SIZE
],
tKsK
,
k_idx
);
flash
::
asm_ds_write
(
buffer
[
k_idx
%
BUFFER_SIZE
],
tKsK
,
k_idx
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
flash
::
asm_ds_write
(
buffer
[(
k_idx
+
1
)
%
BUFFER_SIZE
],
tKsK
,
k_idx
+
1
);
flash
::
asm_ds_write
(
buffer
[(
k_idx
+
1
)
%
BUFFER_SIZE
],
tKsK
,
k_idx
+
1
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
+
1
),
tSrK_copy_view
(
_
,
_
,
k_idx
+
1
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
+
1
),
tSrK_copy_view
(
_
,
_
,
k_idx
+
1
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
+
1
),
tSrK
(
_
,
_
,
k_idx
+
1
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
+
1
),
tSrK
(
_
,
_
,
k_idx
+
1
),
acc_s
);
flash
::
asm_ds_write
(
buffer
[(
k_idx
+
2
)
%
BUFFER_SIZE
],
tKsK
,
k_idx
+
2
);
flash
::
asm_ds_write
(
buffer
[(
k_idx
+
2
)
%
BUFFER_SIZE
],
tKsK
,
k_idx
+
2
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
+
2
),
tSrK_copy_view
(
_
,
_
,
k_idx
+
2
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
+
2
),
tSrK_copy_view
(
_
,
_
,
k_idx
+
2
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
+
2
),
tSrK
(
_
,
_
,
k_idx
+
2
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
+
2
),
tSrK
(
_
,
_
,
k_idx
+
2
),
acc_s
);
...
@@ -1380,7 +1479,9 @@ compute_attn_1rowblock_splitkv_mla_gfx928(const DenseAttnDecodeParams& params,
...
@@ -1380,7 +1479,9 @@ compute_attn_1rowblock_splitkv_mla_gfx928(const DenseAttnDecodeParams& params,
flash
::
buffer_to_tensor
(
buffer
[
2
],
tSrK_smem
,
17
);
flash
::
buffer_to_tensor
(
buffer
[
2
],
tSrK_smem
,
17
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
17
),
tSrK_smem
(
_
,
_
,
17
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
17
),
tSrK_smem
(
_
,
_
,
17
),
acc_s
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_barrier
\n\t
"
);
asm
volatile
(
"s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
#endif
#endif
...
@@ -1415,7 +1516,9 @@ compute_attn_1rowblock_splitkv_mla_gfx928(const DenseAttnDecodeParams& params,
...
@@ -1415,7 +1516,9 @@ compute_attn_1rowblock_splitkv_mla_gfx928(const DenseAttnDecodeParams& params,
#if 1
#if 1
// 第15块已经读取到了buffer[3]中
// 第15块已经读取到了buffer[3]中
flash
::
asm_ds_write
(
buffer
[
3
],
tVsV
,
15
);
flash
::
asm_ds_write
(
buffer
[
3
],
tVsV
,
15
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
#endif
#endif
gK
.
data
()
=
gK
.
data
()
+
(
-
offset_k
);
gK
.
data
()
=
gK
.
data
()
+
(
-
offset_k
);
...
@@ -1434,7 +1537,9 @@ compute_attn_1rowblock_splitkv_mla_gfx928(const DenseAttnDecodeParams& params,
...
@@ -1434,7 +1537,9 @@ compute_attn_1rowblock_splitkv_mla_gfx928(const DenseAttnDecodeParams& params,
cute
::
copy
(
smem_tiled_copy_V
,
tOsVt
(
_
,
_
,
i
),
tOrVt_copy_view
(
_
,
_
,
i
));
cute
::
copy
(
smem_tiled_copy_V
,
tOsVt
(
_
,
_
,
i
),
tOrVt_copy_view
(
_
,
_
,
i
));
cute
::
gemm
(
tiled_mma_o
,
tOrP
(
_
,
_
,
i
),
tOrVt
(
_
,
_
,
i
),
acc_o
);
cute
::
gemm
(
tiled_mma_o
,
tOrP
(
_
,
_
,
i
),
tOrVt
(
_
,
_
,
i
),
acc_o
);
}
}
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
" s_barrier
\n\t
"
);
asm
volatile
(
" s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
#endif
#endif
...
...
csrc/gfx93/prefill/sparse/phase1.cuh
View file @
142846b5
...
@@ -800,48 +800,66 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
...
@@ -800,48 +800,66 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
// asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
// asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t");
if
constexpr
(
D_QK
==
576
)
if
constexpr
(
D_QK
==
576
)
{
{
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(4)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(4)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
0
),
tSrQ_copy_view
(
_
,
_
,
0
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
0
),
tSrQ_copy_view
(
_
,
_
,
0
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
1
),
tSrQ_copy_view
(
_
,
_
,
1
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
1
),
tSrQ_copy_view
(
_
,
_
,
1
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
2
),
tSrQ_copy_view
(
_
,
_
,
2
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
2
),
tSrQ_copy_view
(
_
,
_
,
2
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
3
),
tSrQ_copy_view
(
_
,
_
,
3
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
3
),
tSrQ_copy_view
(
_
,
_
,
3
));
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
4
),
tSrQ_copy_view
(
_
,
_
,
4
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
4
),
tSrQ_copy_view
(
_
,
_
,
4
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
5
),
tSrQ_copy_view
(
_
,
_
,
5
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
5
),
tSrQ_copy_view
(
_
,
_
,
5
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
6
),
tSrQ_copy_view
(
_
,
_
,
6
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
6
),
tSrQ_copy_view
(
_
,
_
,
6
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
7
),
tSrQ_copy_view
(
_
,
_
,
7
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
7
),
tSrQ_copy_view
(
_
,
_
,
7
));
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
8
),
tSrQ_copy_view
(
_
,
_
,
8
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
8
),
tSrQ_copy_view
(
_
,
_
,
8
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
9
),
tSrQ_copy_view
(
_
,
_
,
9
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
9
),
tSrQ_copy_view
(
_
,
_
,
9
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
10
),
tSrQ_copy_view
(
_
,
_
,
10
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
10
),
tSrQ_copy_view
(
_
,
_
,
10
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
11
),
tSrQ_copy_view
(
_
,
_
,
11
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
11
),
tSrQ_copy_view
(
_
,
_
,
11
));
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
12
),
tSrQ_copy_view
(
_
,
_
,
12
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
12
),
tSrQ_copy_view
(
_
,
_
,
12
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
13
),
tSrQ_copy_view
(
_
,
_
,
13
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
13
),
tSrQ_copy_view
(
_
,
_
,
13
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
14
),
tSrQ_copy_view
(
_
,
_
,
14
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
14
),
tSrQ_copy_view
(
_
,
_
,
14
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
15
),
tSrQ_copy_view
(
_
,
_
,
15
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
15
),
tSrQ_copy_view
(
_
,
_
,
15
));
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
16
),
tSrQ_copy_view
(
_
,
_
,
16
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
16
),
tSrQ_copy_view
(
_
,
_
,
16
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
17
),
tSrQ_copy_view
(
_
,
_
,
17
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
17
),
tSrQ_copy_view
(
_
,
_
,
17
));
}
}
else
else
{
{
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
0
),
tSrQ_copy_view
(
_
,
_
,
0
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
0
),
tSrQ_copy_view
(
_
,
_
,
0
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
1
),
tSrQ_copy_view
(
_
,
_
,
1
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
1
),
tSrQ_copy_view
(
_
,
_
,
1
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
2
),
tSrQ_copy_view
(
_
,
_
,
2
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
2
),
tSrQ_copy_view
(
_
,
_
,
2
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
3
),
tSrQ_copy_view
(
_
,
_
,
3
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
3
),
tSrQ_copy_view
(
_
,
_
,
3
));
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
4
),
tSrQ_copy_view
(
_
,
_
,
4
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
4
),
tSrQ_copy_view
(
_
,
_
,
4
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
5
),
tSrQ_copy_view
(
_
,
_
,
5
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
5
),
tSrQ_copy_view
(
_
,
_
,
5
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
6
),
tSrQ_copy_view
(
_
,
_
,
6
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
6
),
tSrQ_copy_view
(
_
,
_
,
6
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
7
),
tSrQ_copy_view
(
_
,
_
,
7
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
7
),
tSrQ_copy_view
(
_
,
_
,
7
));
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
8
),
tSrQ_copy_view
(
_
,
_
,
8
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
8
),
tSrQ_copy_view
(
_
,
_
,
8
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
9
),
tSrQ_copy_view
(
_
,
_
,
9
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
9
),
tSrQ_copy_view
(
_
,
_
,
9
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
10
),
tSrQ_copy_view
(
_
,
_
,
10
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
10
),
tSrQ_copy_view
(
_
,
_
,
10
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
11
),
tSrQ_copy_view
(
_
,
_
,
11
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
11
),
tSrQ_copy_view
(
_
,
_
,
11
));
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
12
),
tSrQ_copy_view
(
_
,
_
,
12
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
12
),
tSrQ_copy_view
(
_
,
_
,
12
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
13
),
tSrQ_copy_view
(
_
,
_
,
13
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
13
),
tSrQ_copy_view
(
_
,
_
,
13
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
14
),
tSrQ_copy_view
(
_
,
_
,
14
));
cute
::
copy
(
smem_tiled_copy_Q
,
tSsQ
(
_
,
_
,
14
),
tSrQ_copy_view
(
_
,
_
,
14
));
...
@@ -898,10 +916,14 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
...
@@ -898,10 +916,14 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
i
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
i
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
}
}
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n
s_barrier"
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n
s_barrier"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
0
),
tSrK_copy_view
(
_
,
_
,
0
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
0
),
tSrK_copy_view
(
_
,
_
,
0
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
0
+
16
),
tSrK
(
_
,
_
,
0
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
0
+
16
),
tSrK
(
_
,
_
,
0
),
acc_s
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n
s_barrier"
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n
s_barrier"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
0
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
0
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
1
),
tSrK_copy_view
(
_
,
_
,
1
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
1
),
tSrK_copy_view
(
_
,
_
,
1
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
1
+
16
),
tSrK
(
_
,
_
,
1
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
1
+
16
),
tSrK
(
_
,
_
,
1
),
acc_s
);
...
@@ -917,58 +939,79 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
...
@@ -917,58 +939,79 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
}
}
int
k_idx
=
0
;
int
k_idx
=
0
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
k_idx
++
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
k_idx
++
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
k_idx
++
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
flash
::
__ds_read_m32x16_row_col_rrow_alt
<
0
,
0
,
0
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow_alt
<
0
,
0
,
0
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow_alt
<
0
,
1
,
0
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow_alt
<
0
,
1
,
0
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow_alt
<
0
,
2
,
0
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow_alt
<
0
,
2
,
0
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow_alt
<
0
,
3
,
0
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow_alt
<
0
,
3
,
0
>
(
tOsVt
,
tOrVt_copy_view
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
4
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
4
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
5
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
5
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
6
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
6
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
7
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
7
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
k_idx
++
;
k_idx
++
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
k_idx
++
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
k_idx
++
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
k_idx
++
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
flash
::
__ds_read_m32x16_row_col_rrow_alt
<
0
,
0
,
1
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow_alt
<
0
,
0
,
1
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow_alt
<
0
,
1
,
1
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow_alt
<
0
,
1
,
1
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow_alt
<
0
,
2
,
1
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow_alt
<
0
,
2
,
1
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow_alt
<
0
,
3
,
1
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow_alt
<
0
,
3
,
1
>
(
tOsVt
,
tOrVt_copy_view
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
8
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
8
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
9
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
9
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
...
@@ -976,29 +1019,39 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
...
@@ -976,29 +1019,39 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
11
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
11
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
k_idx
++
;
k_idx
++
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
k_idx
++
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
k_idx
++
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
k_idx
++
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
flash
::
__ds_read_m32x16_row_col_rrow_alt
<
0
,
0
,
2
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow_alt
<
0
,
0
,
2
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow_alt
<
0
,
1
,
2
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow_alt
<
0
,
1
,
2
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow_alt
<
0
,
2
,
2
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow_alt
<
0
,
2
,
2
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow_alt
<
0
,
3
,
2
>
(
tOsVt
,
tOrVt_copy_view
);
flash
::
__ds_read_m32x16_row_col_rrow_alt
<
0
,
3
,
2
>
(
tOsVt
,
tOrVt_copy_view
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
12
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
12
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
13
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
13
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
...
@@ -1006,25 +1059,39 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
...
@@ -1006,25 +1059,39 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
15
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
flash
::
lds_direct_copy_for_prefill_sparse_mla
<
true
,
false
,
false
>
(
gK
,
sK
,
row_offset
,
col
,
15
,
params
.
stride_kv_s_kv
,
params
.
s_kv
);
k_idx
++
;
k_idx
++
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
k_idx
++
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
k_idx
++
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
k_idx
++
;
k_idx
++
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
copy
(
smem_tiled_copy_K
,
tSsK
(
_
,
_
,
k_idx
%
4
),
tSrK_copy_view
(
_
,
_
,
k_idx
));
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
cute
::
gemm
(
tiled_mma
,
tSrQ
(
_
,
_
,
k_idx
),
tSrK
(
_
,
_
,
k_idx
),
acc_s
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_barrier
\n\t
"
);
asm
volatile
(
"s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
// if (block0())
// if (block0())
// {
// {
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment