Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
a1eef562
"tests/vscode:/vscode.git/clone" did not exist on "a45b979d9fe4e700a81256ad314a9d5fd65a2829"
Commit
a1eef562
authored
Jun 04, 2026
by
shenzhe
Committed by
zhanghj2
Jun 06, 2026
Browse files
Add DSA MLS sparse prefill dispatch
parent
4e0bdf6e
Changes
121
Hide whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
10289 additions
and
0 deletions
+10289
-0
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/gfx938/fp8_mla_epilogue_gfx938.h
...a_mls/legacy/include/mla/gfx938/fp8_mla_epilogue_gfx938.h
+42
-0
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/gfx938/fp8_mla_softmax_gfx938.h
...sa_mls/legacy/include/mla/gfx938/fp8_mla_softmax_gfx938.h
+83
-0
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/gfx938/fp8_mla_tp8_pv_gemm_prefetch_k_gfx938.h
...nclude/mla/gfx938/fp8_mla_tp8_pv_gemm_prefetch_k_gfx938.h
+180
-0
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/gfx938/fp8_mla_tp8_pv_gemm_utils_gfx938.h
...acy/include/mla/gfx938/fp8_mla_tp8_pv_gemm_utils_gfx938.h
+54
-0
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/gfx938/fp8_mla_tp8_qk_gemm_gfx938.h
...ls/legacy/include/mla/gfx938/fp8_mla_tp8_qk_gemm_gfx938.h
+135
-0
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/gfx938/fp8_mla_tp8_qk_gemm_utils_gfx938.h
...acy/include/mla/gfx938/fp8_mla_tp8_qk_gemm_utils_gfx938.h
+116
-0
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/gfx938/mla_epilogue_tile16x32_lit.h
...ls/legacy/include/mla/gfx938/mla_epilogue_tile16x32_lit.h
+387
-0
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/gfx938/mla_pv_gemm_prefetch_k_mls_ds.h
...legacy/include/mla/gfx938/mla_pv_gemm_prefetch_k_mls_ds.h
+3364
-0
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/gfx938/mla_pv_gemm_utils_mls_ds.h
..._mls/legacy/include/mla/gfx938/mla_pv_gemm_utils_mls_ds.h
+171
-0
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/gfx938/mla_qk_gemm_prefetch_v_mls_ds.h
...legacy/include/mla/gfx938/mla_qk_gemm_prefetch_v_mls_ds.h
+3044
-0
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/gfx938/mla_qk_gemm_utils_mls_ds.h
..._mls/legacy/include/mla/gfx938/mla_qk_gemm_utils_mls_ds.h
+134
-0
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/gfx938/mla_softmax_gfx938.h
...se/dsa_mls/legacy/include/mla/gfx938/mla_softmax_gfx938.h
+705
-0
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/gfx938/mla_tp8_epilogue_gfx938.h
...a_mls/legacy/include/mla/gfx938/mla_tp8_epilogue_gfx938.h
+63
-0
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/gfx938/mla_tp8_qk_gemm_utils_gfx938.h
.../legacy/include/mla/gfx938/mla_tp8_qk_gemm_utils_gfx938.h
+86
-0
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/mla_acco_reduce.h
...efill/sparse/dsa_mls/legacy/include/mla/mla_acco_reduce.h
+82
-0
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/mla_acco_reduce_tile16x32.h
...se/dsa_mls/legacy/include/mla/mla_acco_reduce_tile16x32.h
+68
-0
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/mla_epilogue.h
.../prefill/sparse/dsa_mls/legacy/include/mla/mla_epilogue.h
+134
-0
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/mla_epilogue_tile16x32.h
...parse/dsa_mls/legacy/include/mla/mla_epilogue_tile16x32.h
+89
-0
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/mla_prefix_prefill.h
...ll/sparse/dsa_mls/legacy/include/mla/mla_prefix_prefill.h
+979
-0
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/mla_pv_gemm_prefetch_k.h
...parse/dsa_mls/legacy/include/mla/mla_pv_gemm_prefetch_k.h
+373
-0
No files found.
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/gfx938/fp8_mla_epilogue_gfx938.h
0 → 100644
View file @
a1eef562
#pragma once
#include "numeric_types.h"
// Final rescale of the PV accumulator: every accumulator fragment is multiplied
// by v_descale / row_sum, where row_sum comes from scores_sum. When the row sum
// is zero or NaN the fragment is multiplied by v_descale alone.
//
//   acc_o      - accumulator fragments, updated in place
//   scores_sum - per-row softmax denominators (read only)
//   v_descale  - scalar de-quantisation factor folded into the normalisation
template <int K_LOOP_COUNT, int M_WARP_COUNT, int K_WARP_COUNT, int M_MMAC_COUNT, typename ElementAccum>
__forceinline__ __device__ void fp8_mla_epilugue_rescale_acco_gfx938(
    vec4_Accum<ElementAccum> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
    vec2_Accum<ElementAccum> scores_sum[M_WARP_COUNT],
    ElementAccum v_descale)
{
    // Number of 32x32 tiles covered by one step of the outer loop.
    constexpr int TILES_PER_STEP = M_WARP_COUNT * K_WARP_COUNT;

#pragma unroll
    for (int step = 0; step < K_LOOP_COUNT; ++step)
    {
#pragma unroll
        for (int m_tile = 0; m_tile < M_WARP_COUNT; ++m_tile)
        {
#pragma unroll
            for (int n_tile = 0; n_tile < K_WARP_COUNT; ++n_tile)
            {
                // The tile index does not depend on the inner loops.
                const int tile = step * TILES_PER_STEP + (n_tile * M_WARP_COUNT + m_tile);

#pragma unroll
                for (int row_half = 0; row_half < M_MMAC_COUNT; ++row_half)
                {
                    const ElementAccum row_sum = scores_sum[m_tile].f32[row_half];

                    // Zero or NaN denominators fall back to the bare descale factor.
                    ElementAccum factor = v_descale;
                    if (!(row_sum == 0.f || row_sum != row_sum))
                    {
                        factor = v_descale / row_sum;
                    }
                    __float2 factor_x2 = {factor, factor};

#pragma unroll
                    for (int col_half = 0; col_half < 2; ++col_half)
                    {
                        vec4_Accum<ElementAccum> &frag = acc_o[tile][col_half * 2 + row_half];
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
                        // Packed path: two fp32 lanes per multiply.
                        for (int pair = 0; pair < 2; ++pair)
                        {
                            frag.u64[pair] = __builtin_hcu_pk_mul_f32(frag.u64[pair], factor_x2);
                        }
#else
                        // Scalar fallback: one fp32 lane per multiply.
                        for (int lane = 0; lane < 4; ++lane)
                        {
                            frag.f32[lane] *= factor;
                        }
#endif
                    }
                }
            }
        }
    }
}
\ No newline at end of file
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/gfx938/fp8_mla_softmax_gfx938.h
0 → 100644
View file @
a1eef562
#pragma once
#include "fwd/utils.h"
using
namespace
flash
;
// Sequence-length mask: every score whose column index is >= max_seqlen_k is
// overwritten with -INFINITY.
//
// The column handled by a thread is derived from its lane (threadIdx.x & 63):
// lane group (lane >> 4) contributes 8 columns, each 32-wide n-tile adds 32,
// each n-half adds 4, and the vector slot adds 0..3.
//
//   tensor          - score fragments, updated in place
//   max_seqlen_k    - first column index that must be masked
//   col_idx_offset_ - column offset of this tile (defaults to 0)
template <typename DataType, int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT>
inline __device__ void fp8_mla_apply_mask_gfx938(DataType tensor[M_WARP_COUNT * N_WARP_COUNT][4],
                                                 const int max_seqlen_k,
                                                 const int col_idx_offset_ = 0)
{
    const int lane = threadIdx.x & 63;
    const int lane_col0 = col_idx_offset_ + (lane >> 4) * 8;

#pragma unroll
    for (int n_tile = 0; n_tile < N_WARP_COUNT; ++n_tile)
    {
#pragma unroll
        for (int n_half = 0; n_half < 2; ++n_half)
        {
            const int first_col = lane_col0 + n_tile * 32 + n_half * 4;
#pragma unroll
            for (int slot = 0; slot < 4; ++slot)
            {
                // Columns inside the valid range are left untouched.
                if (first_col + slot < max_seqlen_k)
                {
                    continue;
                }
#pragma unroll
                for (int m_tile = 0; m_tile < M_WARP_COUNT; ++m_tile)
                {
#pragma unroll
                    for (int m_half = 0; m_half < M_MMAC_COUNT; ++m_half)
                    {
                        tensor[m_tile + n_tile * M_WARP_COUNT][n_half * 2 + m_half].f32[slot] = -INFINITY;
                    }
                }
            }
        }
    }
}
// Causal mask for the MTP case: a score is set to -INFINITY when its column
// index is strictly greater than
//     min(max_seqlen_k, row_in_mtp + max_seqlen_k - mtp).
//
// row_in_mtp is derived from the row index of the element:
//     layout == 0 : row_idx % mtp
//     otherwise   : row_idx / (max_seqlen_q / mtp)
// Rows come from (lane & 15) + 32 per m-tile + 16 per m-half; columns from
// (lane >> 4) * 8 + 32 per n-tile + 4 per n-half + vector slot.
//
//   tensor          - score fragments, updated in place
//   col_idx_offset_ - column offset of this tile
//   max_seqlen_k    - key sequence length
//   row_idx_offset_ - row offset of this tile
//   max_seqlen_q    - query sequence length (only used to derive the regroup count)
//   mtp             - divisor / modulus used for row_in_mtp (must be non-zero)
//   layout          - selects how row_in_mtp is derived (see above)
template <typename DataType, int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT>
inline __device__ void fp8_mla_apply_mask_causal_gfx938_mtp(DataType tensor[M_WARP_COUNT * N_WARP_COUNT][4],
                                                            const int col_idx_offset_,
                                                            const int max_seqlen_k,
                                                            const int row_idx_offset_,
                                                            const int max_seqlen_q,
                                                            const int mtp,
                                                            const int layout)
{
    const int regroup_count = max_seqlen_q / mtp;
    const int lane = threadIdx.x & 63;
    const int lane_row0 = row_idx_offset_ + (lane & 15);
    const int lane_col0 = col_idx_offset_ + (lane >> 4) * 8;

#pragma unroll
    for (int m_tile = 0; m_tile < M_WARP_COUNT; ++m_tile)
    {
#pragma unroll
        for (int m_half = 0; m_half < M_MMAC_COUNT; ++m_half)
        {
            const int row = lane_row0 + m_tile * 32 + m_half * 16;

            int row_in_mtp;
            if (layout == 0)
            {
                row_in_mtp = row % mtp;
            }
            else
            {
                row_in_mtp = row / regroup_count;
            }
            // Last column this row is allowed to attend to.
            const int last_visible_col = std::min(max_seqlen_k, row_in_mtp + max_seqlen_k - mtp);

#pragma unroll
            for (int n_tile = 0; n_tile < N_WARP_COUNT; ++n_tile)
            {
#pragma unroll
                for (int n_half = 0; n_half < 2; ++n_half)
                {
                    const int first_col = lane_col0 + n_tile * 32 + n_half * 4;
#pragma unroll
                    for (int slot = 0; slot < 4; ++slot)
                    {
                        if (first_col + slot > last_visible_col)
                        {
                            tensor[m_tile + n_tile * M_WARP_COUNT][n_half * 2 + m_half].f32[slot] = -INFINITY;
                        }
                    }
                }
            }
        }
    }
}
// Multiplies every score fragment by the packed factor qk_descale, two fp32
// lanes at a time via __builtin_hcu_pk_mul_f32.
//
//   tensor     - score fragments, updated in place
//   qk_descale - packed pair of fp32 multipliers applied to each 64-bit half
template <typename DataType, int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT>
inline __device__ void fp8_mla_apply_descale_gfx938(DataType tensor[M_WARP_COUNT * N_WARP_COUNT][4],
                                                    const __float2 qk_descale)
{
    constexpr int TILE_COUNT = M_WARP_COUNT * N_WARP_COUNT;

#pragma unroll
    for (int tile = 0; tile < TILE_COUNT; ++tile)
    {
#pragma unroll
        for (int m_half = 0; m_half < M_MMAC_COUNT; ++m_half)
        {
#pragma unroll
            for (int n_half = 0; n_half < 2; ++n_half)
            {
                DataType &frag = tensor[tile][n_half * 2 + m_half];
                // Each fragment holds two 64-bit halves; scale the low one, then the high one.
#pragma unroll
                for (int half = 0; half < 2; ++half)
                {
                    frag.u64[half] = __builtin_hcu_pk_mul_f32(frag.u64[half], qk_descale);
                }
            }
        }
    }
}
\ No newline at end of file
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/gfx938/fp8_mla_tp8_pv_gemm_prefetch_k_gfx938.h
0 → 100644
View file @
a1eef562
#pragma once
#include "fp8_mla_tp8_pv_gemm_utils_gfx938.h"
#include "fp8_mla_tp8_qk_gemm_utils_gfx938.h"
// P*V GEMM step with software-pipelined V loads and an optional prefetch of the
// next K tile.
//
// Structure of the body:
//   1. Main loop: for each pipeline round, issue PREFETCH (=2) V loads into the
//      ping-pong stage `stage_id`, flip the stage, then read the other stage
//      from LDS, convert fp8 -> f32 -> f16 and accumulate into pv_reg.
//      NOTE(review): the first round reads a stage that is not written in this
//      function; presumably it was filled by fp8_mla_tp8_prefetch_v_gfx938 -
//      confirm against the caller.
//   2. k_addr is advanced by k_addr_offset; when PrefetchK is true the next K
//      tile is requested via fp8_mla_tp8_prefetch_k_gfx938.
//   3. Tail block: same LDS-read/convert/accumulate sequence as the main loop,
//      for the last PREFETCH loads (no new V loads are issued).
//   4. Final synchronising wait before LDS is reused.
//
// Parameters:
//   v_addr           - resource/address words for V (by value)
//   k_addr           - resource/address words for K; its first 64 bits are
//                      advanced by k_addr_offset (by reference, modified)
//   v_lds, k_lds     - LDS destinations passed to the matrix-load helpers
//   p_reg            - P fragments; only p_reg[0][...] is read here
//   pv_reg           - output accumulators, updated in place
//   warp_id          - warp index used for LDS and global offsets
//   k_row_stride     - forwarded to the K prefetch
//   v_row_stride     - written to v_srsrc[2] and used for the per-warp offset
//   max_seq_v_offset - used to derive the row filter written to v_srsrc[3]
//   k_addr_offset    - byte offset added to k_addr
template <bool PrefetchK,
          int K_LOOP_COUNT,
          int kBlockN,
          int kBlockK,
          int M_WARP_COUNT,
          int PV_K_WARP_COUNT,
          int WARP_NUM,
          int M_MMAC_COUNT,
          typename V_Element,
          typename P_Element,
          typename ElementAccum>
__forceinline__ __device__ void fp8_mla_tp8_pv_gemm_prefetch_k_gfx938(
    vec4_uint v_addr,
    vec4_uint &k_addr,
    V_Element *v_lds,
    V_Element *k_lds,
    union_vec2_f16x2<P_Element> p_reg[M_WARP_COUNT * PV_K_WARP_COUNT][4],
    vec4_Accum<ElementAccum> pv_reg[K_LOOP_COUNT * M_WARP_COUNT * (kBlockN / 32)][4],
    int warp_id,
    int k_row_stride,
    int v_row_stride,
    int max_seq_v_offset,
    int64_t k_addr_offset)
{
    static_assert(K_LOOP_COUNT % 2 == 0);
    // Loop count rescaled by (64 / kBlockN).
    constexpr int K_LOOP_COUNT_ = K_LOOP_COUNT / (64 / kBlockN);
    // Number of V loads issued (and consumed) per pipeline round.
    constexpr int PREFETCH = 2;
    // Avoid conflicting with the LDS used by the multi-wave max reduction
    flash::wait_lds_data_arrived<true /*sync*/>(0);
    // Prepare the MLS resource register
    vec4_uint v_srsrc;
    v_srsrc[1] = v_addr[1];
    v_srsrc[2] = v_row_stride;
    // pingpong: stage_id selects one of two LDS regions, 16384 bytes apart
    int stage_id = 1;
#pragma unroll
    for (int k_loop = K_LOOP_COUNT_ - 1 - PREFETCH; k_loop >= 1; k_loop -= PREFETCH)
    {
#pragma unroll
        for (int load_id = 0; load_id < PREFETCH; ++load_id)
        {
            // LDS write address
            int warp_lds_write_bytes = stage_id * 16384 + (WARP_NUM * load_id + warp_id) * 32 * 64 * sizeof(V_Element);
            // Per-warp global address offset
            int warp_global_bytes; // = warp_id * 32 * v_row_stride * sizeof(V_Element);
            // Global address offset that varies with k_loop
            int v_loop_global_bytes = (k_loop - load_id) * 64 * sizeof(V_Element);
            // Compute the boundary
            if constexpr (true)
            {
                int nm_filter_max = warp_id * 32 + 32 - max_seq_v_offset;
                // Check whether this warp would fetch empty data
                int real_mls_warp_id = nm_filter_max >= 32 ? 0 : warp_id;
                // If it would fetch empty data, which gfx938 does not support, fall back to fetching warp 0's data
                warp_global_bytes = real_mls_warp_id * 32 * v_row_stride * sizeof(V_Element);
                int nm_filter = inline_min_max<0, 32>(real_mls_warp_id * 32 + 32 - max_seq_v_offset);
                // If fetching empty data, use warp 0's nm_filter value
                v_srsrc[3] = nm_filter << 8;
                v_srsrc[3] += 0x20000;
            }
            // Overwrites the low 64 bits (words 0 and 1) of the resource with the offset address.
            *(uint64_t *)&v_srsrc = VA_LIMIT_BITS(*(uint64_t *)&v_addr + warp_global_bytes + v_loop_global_bytes);
            inline_matrix_load_64x32_b8_lds_rearrange<0, 1>(v_lds, v_srsrc, warp_lds_write_bytes, 0);
        }
        // Wait until the 4 warps' data has been written to LDS; data is not shared between warps, so we can try not syncing
        // NOTE(review): the argument PREFETCH presumably leaves the loads just issued in flight - confirm helper semantics.
        flash::wait_buffer_data_arrived<false /*sync*/>(PREFETCH);
        // Switch to the stage that is read in this round.
        stage_id ^= 1;
#pragma unroll
        for (int load_id = 0; load_id < PREFETCH; ++load_id)
        {
            // Allocate the registers needed for V in the mmac computation
            union_vec16_fp8 v_regs[2];
            // Read data from LDS into registers
            int lds_load_bytes = stage_id * 16384 + (WARP_NUM * load_id + warp_id) * 32 * 64 * sizeof(V_Element);
            DS_READ_MATRIX_32x32_B8_ALT2(lds_load_bytes, v_regs[0].i32x4, false /*transpose*/)
            DS_READ_MATRIX_32x32_B8_ALT2(lds_load_bytes + 32, v_regs[1].i32x4, false /*transpose*/)
            // mmac
            // P, fp16, half precision
            // V, fp8
            // Index of the V chunk consumed now: the loads were issued PREFETCH iterations earlier.
            int k_loop_inner = k_loop - load_id + PREFETCH;
#pragma unroll
            for (int tile32x32_id = 0; tile32x32_id < 2; ++tile32x32_id)
            {
                // wait data written to registers
                flash::wait_lds_data_arrived<false /*sync*/>(1 - tile32x32_id);
#pragma unroll
                for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m)
                {
                    // 16 fp8 for ds32x32_b8
#pragma unroll
                    for (int min_tile_dim = 0; min_tile_dim < 2; ++min_tile_dim)
                    {
                        // fp8 -> f32
                        vec2_fp32 v_f32x2[4];
                        // 8 fp8 -> 8 f32, for 1 mmac
                        v_f32x2[0] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 0], false /*word_sel*/);
                        v_f32x2[1] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 0], true /*word_sel*/);
                        v_f32x2[2] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 1], false /*word_sel*/);
                        v_f32x2[3] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 1], true /*word_sel*/);
                        // f32 -> fp16
                        union_vec4_f16x2<P_Element> v_f16x8;
                        v_f16x8.f16x2[0] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[0][0], v_f32x2[0][1], false /*clamp*/, 0 /*o_modifier*/);
                        v_f16x8.f16x2[1] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[1][0], v_f32x2[1][1], false /*clamp*/, 0 /*o_modifier*/);
                        v_f16x8.f16x2[2] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[2][0], v_f32x2[2][1], false /*clamp*/, 0 /*o_modifier*/);
                        v_f16x8.f16x2[3] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[3][0], v_f32x2[3][1], false /*clamp*/, 0 /*o_modifier*/);
                        // mmac_16x16x16, 4 fp16
#pragma unroll
                        for (int mmac_id = 0; mmac_id < 2; ++mmac_id)
                        {
                            // Accumulate in place: the current accumulator is passed in and overwritten with the result.
                            pv_reg[k_loop_inner * 2 + tile32x32_id][min_tile_dim * 2 + min_tile_m].f32 =
                                mmac_4interleave<P_Element, ElementAccum>(
                                    p_reg[0][mmac_id * 2 + min_tile_m].f16x4,
                                    v_f16x8.f16x4[mmac_id],
                                    pv_reg[k_loop_inner * 2 + tile32x32_id][min_tile_dim * 2 + min_tile_m].f32);
                        }
                    }
                }
            }
        }
    }
    // Handle K: advance the K address (low 64 bits of k_addr) by k_addr_offset
    *(int64_t *)&k_addr += k_addr_offset;
    if constexpr (PrefetchK)
    {
        fp8_mla_tp8_prefetch_k_gfx938<WARP_NUM, V_Element>(k_addr, k_lds, warp_id, k_row_stride, max_seq_v_offset - kBlockK);
        // NOTE(review): presumably leaves the K load just issued in flight while the V loads complete - confirm.
        flash::wait_buffer_data_arrived<false /*sync*/>(1);
    }
    else
    {
        flash::wait_buffer_data_arrived<false /*sync*/>(0);
    }
    // Pipeline tail: consume the last PREFETCH V chunks; no further V loads are issued.
    {
        constexpr int k_loop = 1 - PREFETCH;
        stage_id ^= 1;
#pragma unroll
        for (int load_id = 0; load_id < PREFETCH; ++load_id)
        {
            // Allocate the registers needed for V in the mmac computation
            union_vec16_fp8 v_regs[2];
            // Read data from LDS into registers
            int lds_load_bytes = stage_id * 16384 + (WARP_NUM * load_id + warp_id) * 32 * 64 * sizeof(V_Element);
            DS_READ_MATRIX_32x32_B8_ALT2(lds_load_bytes, v_regs[0].i32x4, false /*transpose*/)
            DS_READ_MATRIX_32x32_B8_ALT2(lds_load_bytes + 32, v_regs[1].i32x4, false /*transpose*/)
            // mmac
            // P, fp16, half precision
            // V, fp8
            // Evaluates to 1 - load_id here, i.e. chunks 1 and 0.
            int k_loop_inner = k_loop - load_id + PREFETCH;
#pragma unroll
            for (int tile32x32_id = 0; tile32x32_id < 2; ++tile32x32_id)
            {
                // wait data written to registers
                flash::wait_lds_data_arrived<false /*sync*/>(1 - tile32x32_id);
#pragma unroll
                for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m)
                {
                    // 16 fp8 for ds32x32_b8
#pragma unroll
                    for (int min_tile_dim = 0; min_tile_dim < 2; ++min_tile_dim)
                    {
                        // fp8 -> f32
                        vec2_fp32 v_f32x2[4];
                        // 8 fp8 -> 8 f32, for 1 mmac
                        v_f32x2[0] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 0], false /*word_sel*/);
                        v_f32x2[1] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 0], true /*word_sel*/);
                        v_f32x2[2] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 1], false /*word_sel*/);
                        v_f32x2[3] = __builtin_hcu_cvt_pk_f32_fp8(v_regs[tile32x32_id].i32[min_tile_dim * 2 + 1], true /*word_sel*/);
                        // f32 -> fp16
                        union_vec4_f16x2<P_Element> v_f16x8;
                        v_f16x8.f16x2[0] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[0][0], v_f32x2[0][1], false /*clamp*/, 0 /*o_modifier*/);
                        v_f16x8.f16x2[1] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[1][0], v_f32x2[1][1], false /*clamp*/, 0 /*o_modifier*/);
                        v_f16x8.f16x2[2] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[2][0], v_f32x2[2][1], false /*clamp*/, 0 /*o_modifier*/);
                        v_f16x8.f16x2[3] = __builtin_hcu_cvt_pk_f16_f32(v_f32x2[3][0], v_f32x2[3][1], false /*clamp*/, 0 /*o_modifier*/);
                        // mmac_16x16x16, 4 fp16
#pragma unroll
                        for (int mmac_id = 0; mmac_id < 2; ++mmac_id)
                        {
                            pv_reg[k_loop_inner * 2 + tile32x32_id][min_tile_dim * 2 + min_tile_m].f32 =
                                mmac_4interleave<P_Element, ElementAccum>(
                                    p_reg[0][mmac_id * 2 + min_tile_m].f16x4,
                                    v_f16x8.f16x4[mmac_id],
                                    pv_reg[k_loop_inner * 2 + tile32x32_id][min_tile_dim * 2 + min_tile_m].f32);
                        }
                    }
                }
            }
        }
    }
    flash::wait_lds_data_arrived<true /*sync*/>(0); // here, K/V use more lds, and thus reuse together, need sync
}
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/gfx938/fp8_mla_tp8_pv_gemm_utils_gfx938.h
0 → 100644
View file @
a1eef562
#pragma once
#include "fp8_mla_tp8_pv_gemm_prefetch_k_gfx938.h"
// Issues the first PREFETCH (=2) V tile loads into ping-pong stage 0 of LDS.
// The loads target chunks (K_LOOP_COUNT_ - 1) and (K_LOOP_COUNT_ - 2), where
// K_LOOP_COUNT_ = K_LOOP_COUNT / (64 / kBlockN). Nothing is waited on after the
// loads are issued.
//
//   v_addr           - resource/address words for V
//   v_lds            - LDS destination handed to the matrix-load helper
//   warp_id          - warp index used for LDS and global offsets
//   v_row_stride     - row stride written to the resource and used for the warp offset
//   max_seq_v_offset - used to derive the row filter (defaults to 0)
template <int K_LOOP_COUNT, int kBlockN, int WARP_NUM, typename V_Element>
__forceinline__ __device__ void fp8_mla_tp8_prefetch_v_gfx938(vec4_uint v_addr,
                                                              V_Element *v_lds,
                                                              int warp_id,
                                                              int v_row_stride,
                                                              int max_seq_v_offset = 0)
{
    static_assert(K_LOOP_COUNT % 2 == 0);
    constexpr int K_LOOP_COUNT_ = K_LOOP_COUNT / (64 / kBlockN);
    constexpr int PREFETCH = 2;
    // Chunk index of the first load; the second load takes the chunk below it.
    constexpr int first_chunk = K_LOOP_COUNT_ - 1;

    // Keep clear of the LDS region used by the multi-wave max reduction.
    flash::wait_lds_data_arrived<true /*sync*/>(0);

    // Resource descriptor for the MLS load.
    vec4_uint v_srsrc;
    v_srsrc[1] = v_addr[1];
    v_srsrc[2] = v_row_stride;

    // Ping-pong stage written by this prologue.
    int stage_id = 0;

#pragma unroll
    for (int load_id = 0; load_id < PREFETCH; ++load_id)
    {
        // About to fetch 32x64 fp8 values of V.
        // Destination byte offset inside LDS.
        int lds_dst_bytes = stage_id * 16384 + (WARP_NUM * load_id + warp_id) * 32 * 64 * sizeof(V_Element);
        // Source byte offset contributed by the chunk index.
        int chunk_src_bytes = (first_chunk - load_id) * 64 * sizeof(V_Element);

        // Boundary handling: a warp whose rows would all be filtered out is not
        // supported on gfx938, so it falls back to warp 0's rows and filter.
        const int unclamped_filter = warp_id * 32 + 32 - max_seq_v_offset;
        int src_warp = warp_id;
        if (unclamped_filter >= 32)
        {
            src_warp = 0;
        }
        // Source byte offset contributed by the (possibly substituted) warp.
        int warp_src_bytes = src_warp * 32 * v_row_stride * sizeof(V_Element);
        int row_filter = inline_min_max<0, 32>(src_warp * 32 + 32 - max_seq_v_offset);
        v_srsrc[3] = row_filter << 8;
        v_srsrc[3] += 0x20000;

        // The low 64 bits of the descriptor carry the final source address.
        *(uint64_t *)&v_srsrc = VA_LIMIT_BITS(*(uint64_t *)&v_addr + warp_src_bytes + chunk_src_bytes);
        inline_matrix_load_64x32_b8_lds_rearrange<0, 1>(v_lds, v_srsrc, lds_dst_bytes, 0);
    }
}
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/gfx938/fp8_mla_tp8_qk_gemm_gfx938.h
0 → 100644
View file @
a1eef562
#pragma once
#include "fp8_mla_tp8_qk_gemm_utils_gfx938.h"
// Q*K^T GEMM for one K tile, software-pipelined over the head dimension in
// 64-element chunks.
//
// Flow:
//   1. Zero the score accumulators s_reg.
//   2. For k_loop = 1 .. kHeadDim/64 - 1: issue the load of K chunk k_loop into
//      one ping-pong stage, then read the other stage from LDS and accumulate
//      q_reg[.][k_loop - 1] * K into s_reg[0].
//      NOTE(review): chunk 0 is not loaded here; presumably it was requested
//      earlier by fp8_mla_tp8_prefetch_k_gfx938 - confirm against the caller.
//   3. Tail: wait for the last chunk and accumulate it.
//   4. Synchronising wait, because the caller goes on to reduce scores_max and
//      prefetch V.
//
// Only s_reg[0] and k_regs[0..1] take part in the accumulation, i.e. the
// compute path handles a single 32-wide N tile.
//
// Parameters:
//   k_addr           - resource/address words for K
//   k_lds            - LDS destination handed to the matrix-load helper
//   q_reg            - Q fragments already resident in registers
//   s_reg            - score accumulators (zeroed, then accumulated in place)
//   warp_id          - warp index used for LDS and global offsets
//   k_row_stride     - row stride written to the resource and used for the warp offset
//   max_seq_k_offset - used to derive the row filter (defaults to 0)
template <int kHeadDim,
          int kBlockK,
          int WARP_M,
          int WARP_N,
          int WARP_NUM,
          int M_MMAC_COUNT,
          typename Element,
          typename ElementAccum>
__forceinline__ __device__ void fp8_mla_tp8_qk_gemm_gfx938(
    vec4_uint k_addr,
    Element *k_lds,
    union_vec16_fp8 q_reg[M_MMAC_COUNT][kHeadDim / 64],
    vec4_Accum<ElementAccum> s_reg[(WARP_M / 32) * (WARP_N / 32)][4],
    int warp_id,
    int k_row_stride,
    int max_seq_k_offset = 0)
{
    int stage_id = 0;
    // Prepare the MLS resource register
    vec4_uint k_srsrc;
    k_srsrc[1] = k_addr[1];
    k_srsrc[2] = k_row_stride;
    // Initialize s.
    // The bound now matches the declared extent of s_reg. It used to be
    // (WARP_N / WARP_N) * (WARP_M / 32), which is always WARP_M / 32 and left
    // part of the array uninitialised whenever WARP_N / 32 > 1. For
    // WARP_N == 32 the two bounds are identical.
#pragma unroll
    for (int i = 0; i < (WARP_N / 32) * (WARP_M / 32); ++i)
    {
#pragma unroll
        for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m)
        {
#pragma unroll
            for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n)
            {
                asm volatile("v_mov_b64 %0, 0x0\n\t"
                             "v_mov_b64 %1, 0x0\n\t"
                             : "=v"(s_reg[i][min_tile_n * 2 + min_tile_m].u64[0]),
                               "=v"(s_reg[i][min_tile_n * 2 + min_tile_m].u64[1])
                             :);
            }
        }
    }
    // round
    stage_id ^= 1;
#pragma unroll
    for (int k_loop = 1; k_loop < kHeadDim / 64; ++k_loop)
    {
        // LDS write address
        int warp_lds_write_bytes = (stage_id * WARP_NUM + warp_id) * 32 * 64 * sizeof(Element);
        // Per-warp global address offset
        int warp_global_bytes; // = warp_id * 32 * k_row_stride * sizeof(Element);
        // Global address offset that varies with k_loop
        int k_loop_global_bytes = k_loop * 64 * sizeof(Element);
        // Compute the boundary
        if constexpr (true)
        {
            int nm_filter_max = warp_id * 32 + 32 - max_seq_k_offset;
            // Check whether this warp would fetch empty data
            int real_mls_warp_id = nm_filter_max >= 32 ? 0 : warp_id;
            // If it would fetch empty data, which gfx938 does not support, fall back to fetching warp 0's data
            warp_global_bytes = real_mls_warp_id * 32 * k_row_stride * sizeof(Element);
            int nm_filter = inline_min_max<0, 32>(real_mls_warp_id * 32 + 32 - max_seq_k_offset);
            // If fetching empty data, use warp 0's nm_filter value
            k_srsrc[3] = nm_filter << 8;
            k_srsrc[3] += 0x40000;
        }
        // Overwrites the low 64 bits (words 0 and 1) of the resource with the offset address.
        *(uint64_t *)&k_srsrc = VA_LIMIT_BITS(*(uint64_t *)&k_addr + warp_global_bytes + k_loop_global_bytes);
        inline_matrix_load_64x32_b8_lds_rearrange<0, 1>(k_lds, k_srsrc, warp_lds_write_bytes, 0);
        // Wait until the 4 warps' data has been written to LDS; data is not shared between warps, so we can try not syncing
        // NOTE(review): the argument 1 presumably leaves the load just issued in flight - confirm helper semantics.
        flash::wait_buffer_data_arrived<false /*sync*/>(1);
        // round
        stage_id ^= 1;
        // Allocate the registers needed for K in the mmac computation
        union_vec16_fp8 k_regs[WARP_N / 16];
        // Read data from LDS into registers
        int lds_load_bytes = (stage_id * WARP_NUM + warp_id) * 32 * 64 * sizeof(Element);
        DS_READ_MATRIX_64x16_B8(lds_load_bytes, k_regs[0].i32x4, true /*transpose*/)
        DS_READ_MATRIX_64x16_B8(lds_load_bytes + 1024, k_regs[1].i32x4, true /*transpose*/)
        // mmac
#pragma unroll
        for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n)
        {
            // Wait for the data to be written to registers
            flash::wait_lds_data_arrived<false /*sync*/>(1 - min_tile_n);
#pragma unroll
            for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m)
            {
#pragma unroll
                for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
                {
                    // The chunk consumed is k_loop - 1: the one loaded in the previous round.
                    s_reg[0][min_tile_n * 2 + min_tile_m].f32 =
                        mmac_4interleave_b8<int8_t, ElementAccum>(
                            q_reg[min_tile_m][k_loop - 1].i8x8[min_tile_k],
                            k_regs[min_tile_n].i8x8[min_tile_k],
                            s_reg[0][min_tile_n * 2 + min_tile_m].f32);
                }
            }
        }
    }
    // Pipeline tail: consume the last chunk (index kHeadDim / 64 - 1).
    {
        constexpr int k_loop = kHeadDim / 64;
        // Wait until the 4 warps' data has been written to LDS; data is not shared between warps, so we can try not syncing
        flash::wait_buffer_data_arrived<false /*sync*/>(0);
        stage_id ^= 1;
        // Allocate the registers needed for K in the mmac computation
        union_vec16_fp8 k_regs[WARP_N / 16];
        // Read data from LDS into registers
        int lds_load_bytes = (stage_id * WARP_NUM + warp_id) * 32 * 64 * sizeof(Element);
        DS_READ_MATRIX_64x16_B8(lds_load_bytes, k_regs[0].i32x4, true /*transpose*/)
        DS_READ_MATRIX_64x16_B8(lds_load_bytes + 1024, k_regs[1].i32x4, true /*transpose*/)
        // mmac
#pragma unroll
        for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n)
        {
            // Wait for the data to be written to registers
            flash::wait_lds_data_arrived<false /*sync*/>(1 - min_tile_n);
#pragma unroll
            for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m)
            {
#pragma unroll
                for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
                {
                    s_reg[0][min_tile_n * 2 + min_tile_m].f32 =
                        mmac_4interleave_b8<int8_t, ElementAccum>(
                            q_reg[min_tile_m][k_loop - 1].i8x8[min_tile_k],
                            k_regs[min_tile_n].i8x8[min_tile_k],
                            s_reg[0][min_tile_n * 2 + min_tile_m].f32);
                }
            }
        }
    }
    // need to reduce results on scores_max and prefetch V, and thus sync
    flash::wait_lds_data_arrived<true /*sync*/>(0);
}
// qk_gemm
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/gfx938/fp8_mla_tp8_qk_gemm_utils_gfx938.h
0 → 100644
View file @
a1eef562
#pragma once
#include "intrinsic.h"
#include "fwd/utils.h"
#include "intrinsic_mls_ds_b8.h"
// Loads the Q tile into registers through LDS and, while the first load is in
// flight, initialises the softmax state via attention_initialize.
//
// Flow:
//   1. Every warp issues one 16x128-byte MLS load (columns warp_id*128 ..) into
//      its own 2048-byte slice of q_lds.
//   2. attention_initialize(scores_max, scores_sum, acc_o) runs between issuing
//      the load and waiting for it.
//   3. After a synchronising wait every warp reads all WARP_NUM slices from LDS
//      into q_reg[0][0 .. 2*WARP_NUM - 1].
//   4. Warp 0 issues one more load starting at column (WARP_NUM-1)*128 + 64
//      into q_lds + 16384; after a barrier every warp reads the upper 64
//      columns of that region into q_reg[0][2*WARP_NUM].
//
// Only q_reg[0][...] is written by this function.
//
// Parameters:
//   q_addr           - resource/address words for Q
//   q_lds            - LDS staging buffer
//   q_reg            - destination register fragments
//   warp_id          - warp index used for LDS and global offsets
//   q_row_stride     - row stride written to the resource
//   max_seq_q_offset - used to derive the row filter
//   scores_max, scores_sum, acc_o - forwarded to attention_initialize
template <int kHeadDim,
          int kHeadDimV,
          int kBlockM,
          int kBlockK,
          int WARP_M,
          int WARP_NUM,
          typename Element,
          typename ElementAccum,
          int STAGES,
          int M_MMAC_COUNT>
__forceinline__ __device__ void fp8_mla_tp8_prefetch_q_to_vgpr_gfx938_with_initialization(
    vec4_uint q_addr,
    Element *q_lds,
    union_vec16_fp8 q_reg[M_MMAC_COUNT][kHeadDim / 64],
    int warp_id,
    int q_row_stride,
    int max_seq_q_offset,
    vec2_Accum<ElementAccum> scores_max[WARP_M / 32],
    vec2_Accum<ElementAccum> scores_sum[WARP_M / 32],
    vec4_Accum<ElementAccum> acc_o[kHeadDimV / kBlockK][4])
{
    // Prepare the MLS register
    vec4_uint q_srsrc;
    q_srsrc[0] = q_addr[0];
    q_srsrc[1] = q_addr[1];
    q_srsrc[2] = q_row_stride;
    q_srsrc[3] = 0;
    // Compute the LDS write address
    int q_lds_write_bytes = warp_id * 16 * 128 * sizeof(Element);
    // Compute the global read address
    int q_mls_warp_global_offset = warp_id * 128 * sizeof(Element);
    *(uint64_t *)&q_srsrc = VA_LIMIT_BITS(*(uint64_t *)&q_addr + q_mls_warp_global_offset);
    // MLS reads 16x128 bytes
    if constexpr (true)
    {
        int nm_filter = inline_min_max<0, 16>(16 - max_seq_q_offset);
        q_srsrc[3] = nm_filter << 8;
    }
    inline_matrix_load_128x16_b8_lds_trans<0, 1>(q_lds, q_srsrc, q_lds_write_bytes, 0);
    // add alu between def-use
    attention_initialize<kHeadDimV / kBlockK, WARP_M / 32, 1, M_MMAC_COUNT, ElementAccum>(scores_max, scores_sum, acc_o);
    // Wait until the 4 warps' data has been written to LDS
    flash::wait_buffer_data_arrived<true /*sync*/>(0);
    // Read data from LDS
#pragma unroll
    for (int i = 0; i < WARP_NUM; ++i)
    {
        int q_lds_load_offset = reinterpret_cast<size_t>(q_lds) + (i * 16 * 128) * sizeof(Element);
        DS_READ_MATRIX_64x16_B8(q_lds_load_offset, q_reg[0][i * 2].i32x4, true /*transpose*/)
        DS_READ_MATRIX_64x16_B8(q_lds_load_offset + 1024, q_reg[0][i * 2 + 1].i32x4, true /*transpose*/)
    }
    __builtin_amdgcn_sched_barrier(0);
    // Next, read the remaining 16x64
    // =====================================================================================================
    if (warp_id == 0)
    {
        // [RTL bug] An MLS 128B request that uses m_filter requires both the start address and the stride to be
        // 128B aligned; otherwise, when accessing the end of the matrix's last row, an address that crosses a 64B
        // boundary may cross a page-table boundary and cause an invalid address.
        *(uint64_t *)&q_srsrc = VA_LIMIT_BITS(*(uint64_t *)&q_addr + ((WARP_NUM - 1) * 128 + 64) * sizeof(Element));
        // q_lds_write_bytes is 0 here because warp_id == 0.
        inline_matrix_load_128x16_b8_lds_trans<0, 1>(q_lds + 16384, q_srsrc, q_lds_write_bytes, 0);
        // Wait for the data to be written to LDS
        flash::wait_buffer_data_arrived<false /*sync*/>(0);
    }
    // sync
    flash::wait_all_warp_arrived();
    // Every warp reads the 16x64 content.
    // The address is that of the region written above. `q_lds + 16384` already
    // scales by sizeof(Element), so the extra `* sizeof(Element)` that used to
    // follow the cast multiplied an address by the element size and was only
    // harmless because sizeof(Element) == 1 for fp8.
    int q_lds_load_offset = reinterpret_cast<size_t>(q_lds + 16384);
    // Only the upper 64 columns (+1024 bytes) are new; the lower 64 overlap the
    // last slice read in the loop above. The destination is the slot right
    // after the 2 * WARP_NUM slots filled there (index 8 when WARP_NUM == 4).
    DS_READ_MATRIX_64x16_B8(q_lds_load_offset + 1024, q_reg[0][WARP_NUM * 2].i32x4, true /*transpose*/)
    // Sync: wait for the data to reach registers and keep LDS from being overwritten by new MLS instructions
    flash::wait_lds_data_arrived<true /*sync*/>(0);
}
// Issues the load of K chunk 0 into ping-pong stage 0 of LDS. Nothing is waited
// on here; a scheduling barrier follows the load.
//
//   k_addr           - resource/address words for K
//   k_lds            - LDS destination handed to the matrix-load helper
//   warp_id          - warp index used for LDS and global offsets
//   k_row_stride     - row stride written to the resource and used for the warp offset
//   max_seq_k_offset - used to derive the row filter (defaults to 0)
template <int WARP_NUM, typename Element>
__forceinline__ __device__ void fp8_mla_tp8_prefetch_k_gfx938(vec4_uint k_addr,
                                                              Element *k_lds,
                                                              int warp_id,
                                                              int k_row_stride,
                                                              int max_seq_k_offset = 0)
{
    // This prologue always targets stage 0 and chunk 0.
    int stage_id = 0;
    constexpr int chunk = 0;

    // Resource descriptor for the MLS load.
    vec4_uint k_srsrc;
    k_srsrc[1] = k_addr[1];
    k_srsrc[2] = k_row_stride;

    // Destination byte offset inside LDS.
    int lds_dst_bytes = (stage_id * WARP_NUM + warp_id) * 32 * 64 * sizeof(Element);
    // Source byte offset contributed by the chunk index (zero here).
    int chunk_src_bytes = chunk * 64 * sizeof(Element);

    // Boundary handling: a warp whose rows would all be filtered out is not
    // supported on gfx938, so it falls back to warp 0's rows and filter.
    const int unclamped_filter = warp_id * 32 + 32 - max_seq_k_offset;
    int src_warp = warp_id;
    if (unclamped_filter >= 32)
    {
        src_warp = 0;
    }
    // Source byte offset contributed by the (possibly substituted) warp.
    int warp_src_bytes = src_warp * 32 * k_row_stride * sizeof(Element);
    int row_filter = inline_min_max<0, 32>(src_warp * 32 + 32 - max_seq_k_offset);
    k_srsrc[3] = row_filter << 8;
    k_srsrc[3] += 0x40000;

    // The low 64 bits of the descriptor carry the final source address.
    *(uint64_t *)&k_srsrc = VA_LIMIT_BITS(*(uint64_t *)&k_addr + warp_src_bytes + chunk_src_bytes);
    inline_matrix_load_64x32_b8_lds_rearrange<0, 1>(k_lds, k_srsrc, lds_dst_bytes, 0);
    __builtin_amdgcn_sched_barrier(0);
}
\ No newline at end of file
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/gfx938/mla_epilogue_tile16x32_lit.h
0 → 100644
View file @
a1eef562
#include "numeric_types.h"
#include "intrinsic.h"
#define CUDART_L2E_F 1.442695041F
// DataType: {vec2_Accum<ElementAccum>, vec_Accum<ElementAccum>}
// Epilogue for the prefill MLA kernel: writes the log-sum-exp of every row and
// normalises the output accumulator by the row's softmax denominator.
//
// For each row:
//   - sum is zero or NaN : lse = -INFINITY, scale factor = 1
//   - otherwise          : lse = scores_max * scale_softmax + __logf(sum),
//                          scale factor = 1 / sum
// When Is_dropout is set the scale factor is additionally multiplied by
// rp_dropout.
//
//   acc_o         - accumulator fragments, scaled in place
//   lse           - output: per-row log-sum-exp
//   scores_max    - per-row running maxima (read only)
//   scores_sum    - per-row softmax denominators (read only)
//   scale_softmax - factor applied to scores_max in the lse formula
//   rp_dropout    - extra multiplier used only when Is_dropout is true
template <int WARP_M,
          int kBlockK,
          int kHeadDimV,
          bool Is_dropout,
          typename ElementAccum,
          typename DataType = union_vec2_fp32 /* vec2_Accum<ElementAccum> */,
          int M_MMAC_COUNT = 2>
__forceinline__ __device__ void prefill_mla_epilugue_rescale_acco(
    vec4_Accum<ElementAccum> acc_o[(kHeadDimV / kBlockK) * (WARP_M / (16 * M_MMAC_COUNT)) * (kBlockK / 32)][2 * M_MMAC_COUNT],
    DataType lse[WARP_M / (16 * M_MMAC_COUNT)],
    DataType scores_max[WARP_M / (16 * M_MMAC_COUNT)],
    DataType scores_sum[WARP_M / (16 * M_MMAC_COUNT)],
    const ElementAccum scale_softmax,
    const ElementAccum rp_dropout)
{
    // Tile counts along M, along N inside one kBlockK slab, and number of slabs.
    constexpr int M_TILES = WARP_M / (16 * M_MMAC_COUNT);
    constexpr int N_TILES = kBlockK / 32;
    constexpr int SLABS = kHeadDimV / kBlockK;

#pragma unroll
    for (int m_tile = 0; m_tile < M_TILES; ++m_tile)
    {
#pragma unroll
        for (int m_half = 0; m_half < M_MMAC_COUNT; ++m_half)
        {
            const ElementAccum row_sum = scores_sum[m_tile].f32[m_half];

            // Zero or NaN denominators: leave the accumulator unscaled and report -inf.
            ElementAccum inv_sum;
            if (row_sum == 0.f || row_sum != row_sum)
            {
                inv_sum = 1.f;
                lse[m_tile].f32[m_half] = -INFINITY;
            }
            else
            {
                inv_sum = 1.f / row_sum;
                lse[m_tile].f32[m_half] = scores_max[m_tile].f32[m_half] * scale_softmax + __logf(row_sum);
            }

            ElementAccum factor = inv_sum;
            if constexpr (Is_dropout)
            {
                factor = inv_sum * rp_dropout;
            }
            __float2 factor_x2 = {factor, factor};

#pragma unroll
            for (int n_tile = 0; n_tile < N_TILES; ++n_tile)
            {
#pragma unroll
                for (int n_half = 0; n_half < 2; ++n_half)
                {
                    // Fragment index inside a tile; the layout depends on M_MMAC_COUNT.
                    int frag_id = n_half;
                    if constexpr (M_MMAC_COUNT == 2)
                    {
                        frag_id = n_half * 2 + m_half;
                    }

#pragma unroll
                    for (int slab = 0; slab < SLABS; ++slab)
                    {
                        const int tile = slab * M_TILES * N_TILES + n_tile * M_TILES + m_tile;
                        vec4_Accum<ElementAccum> &frag = acc_o[tile][frag_id];
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
                        // Packed path: two fp32 lanes per multiply.
                        for (int pair = 0; pair < 2; ++pair)
                        {
                            frag.u64[pair] = __builtin_hcu_pk_mul_f32(frag.u64[pair], factor_x2);
                        }
#else
                        // Scalar fallback: one fp32 lane per multiply.
                        for (int lane = 0; lane < 4; ++lane)
                        {
                            frag.f32[lane] *= factor;
                        }
#endif
                    }
                }
            }
        }
    }
}
// Epilogue for the decode DSA kernel: writes the log-sum-exp of every row and
// normalises the output accumulator, optionally folding in an attention-sink
// weight.
//
// For each row:
//   - sum is zero or NaN : lse = -INFINITY, inv_sum = 1
//   - otherwise          : lse = scores_max * scale_softmax + __logf(sum),
//                          inv_sum = 1 / sum
// When attn_sink is non-null the accumulator is additionally multiplied by
//     exp(lse) / (exp(lse) + exp(sink))
// with the special cases sink == +INFINITY -> 0 and lse == +/-INFINITY -> 1.
//
// The weight is evaluated as 1 / (1 + exp(sink - lse)). The direct quotient
// returned NaN when exp(lse) overflowed (inf / inf) or when both exponentials
// underflowed (0 / 0). The rewritten form is mathematically identical and
// saturates to 0 or 1 instead.
//
// Is_dropout and rp_dropout are accepted for signature parity but are not used
// by this function.
//
//   acc_o         - accumulator fragments, scaled in place
//   lse           - output: per-row log-sum-exp
//   scores_max    - per-row running maxima (read only)
//   scores_sum    - per-row softmax denominators (read only)
//   scale_softmax - factor applied to scores_max in the lse formula
//   rp_dropout    - unused
//   attn_sink     - optional per-row sink logits; nullptr disables the sink weight
template <int WARP_M,
          int kBlockK,
          int kHeadDimV,
          bool Is_dropout,
          typename ElementAccum,
          typename DataType = union_vec2_fp32 /* vec2_Accum<ElementAccum> */,
          int M_MMAC_COUNT = 2>
__forceinline__ __device__ void decode_dsa_epilugue_rescale_acco(
    vec4_Accum<ElementAccum> acc_o[(kHeadDimV / kBlockK) * (WARP_M / (16 * M_MMAC_COUNT)) * (kBlockK / 32)][2 * M_MMAC_COUNT],
    DataType lse[WARP_M / (16 * M_MMAC_COUNT)],
    DataType scores_max[WARP_M / (16 * M_MMAC_COUNT)],
    DataType scores_sum[WARP_M / (16 * M_MMAC_COUNT)],
    const ElementAccum scale_softmax,
    const ElementAccum rp_dropout,
    float *attn_sink)
{
    int tid = threadIdx.x % 64;
    int warp_id = threadIdx.x / 64;
    // Epilogue
#pragma unroll
    for (int mi = 0; mi < (WARP_M / (16 * M_MMAC_COUNT)); ++mi)
    {
#pragma unroll
        for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m)
        {
            ElementAccum sum = scores_sum[mi].f32[min_tile_m];
            ElementAccum inv_sum = (sum == 0.f || sum != sum) ? 1.f : 1.f / sum;
            lse[mi].f32[min_tile_m] = (sum == 0.f || sum != sum)
                                          ? -INFINITY
                                          : scores_max[mi].f32[min_tile_m] * scale_softmax + __logf(sum);
            float attn_sink_o_scale = 1.0f;
            if (attn_sink != nullptr)
            {
                // NOTE(review): the sink index depends only on warp_id and the lane, not on mi or
                // min_tile_m - presumably each warp owns 16 rows here; confirm against the caller.
                float rAttn_sink = attn_sink[warp_id * 16 + tid % 16];
                if (rAttn_sink == INFINITY)
                {
                    attn_sink_o_scale = 0.0f;
                }
                else if ((lse[mi].f32[min_tile_m] != -INFINITY) && (lse[mi].f32[min_tile_m] != INFINITY))
                {
                    // exp(lse) / (exp(lse) + exp(sink)) == 1 / (1 + exp(sink - lse)).
                    // Taking the difference before exponentiating avoids inf/inf and 0/0.
                    float sink_over_lse = __builtin_amdgcn_exp2f((rAttn_sink - lse[mi].f32[min_tile_m]) * CUDART_L2E_F);
                    attn_sink_o_scale = 1.0f / (1.0f + sink_over_lse);
                }
            }
            ElementAccum scale = inv_sum * attn_sink_o_scale;
            __float2 scale_pair = {scale, scale};
#pragma unroll
            for (int ni = 0; ni < (kBlockK / 32); ++ni)
            {
#pragma unroll
                for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n)
                {
                    // Fragment index inside a tile; the layout depends on M_MMAC_COUNT.
                    int mmac_id;
                    if constexpr (M_MMAC_COUNT == 2)
                    {
                        mmac_id = min_tile_n * 2 + min_tile_m;
                    }
                    else
                    {
                        mmac_id = min_tile_n;
                    }
#pragma unroll
                    for (int pv_n_loop = 0; pv_n_loop < (kHeadDimV / kBlockK); ++pv_n_loop)
                    {
                        const int pv_tile_id = pv_n_loop * (WARP_M / (16 * M_MMAC_COUNT)) * (kBlockK / 32) +
                                               ni * (WARP_M / (16 * M_MMAC_COUNT)) + mi;
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
                        // Packed path: two fp32 lanes per multiply.
                        for (int vec_id = 0; vec_id < 2; ++vec_id)
                        {
                            acc_o[pv_tile_id][mmac_id].u64[vec_id] =
                                __builtin_hcu_pk_mul_f32(acc_o[pv_tile_id][mmac_id].u64[vec_id], scale_pair);
                        }
#else
                        // Scalar fallback: one fp32 lane per multiply.
                        for (int vec_id = 0; vec_id < 4; ++vec_id)
                        {
                            acc_o[pv_tile_id][mmac_id].f32[vec_id] *= scale;
                        }
#endif
                    }
                }
            }
        }
    }
}
/**
 * Epilogue rescale of the output accumulator (prefill variant).
 *
 * Per row slot (mi, min_tile_m) of the calling lane:
 *   - inv_sum = 1 / scores_sum, or 1 when the sum is 0 or NaN;
 *   - lse     = scores_max * scale_softmax + log(scores_sum), or -INFINITY
 *               when the sum is 0 or NaN (written to the lse array);
 *   - if attn_sink is non-null, an attenuation factor
 *               exp(lse) / (exp(lse) + exp(sink))
 *     is applied; it is computed as 1 / (1 + exp(sink - lse)), which is the
 *     same quantity but cannot turn into inf/inf or 0/0 (NaN) when exp(lse)
 *     alone would overflow or underflow in fp32;
 *   - every fp32 value of the matching acc_o tiles is multiplied by
 *     inv_sum * factor.
 *
 * The name keeps its historical spelling ("epilugue") because callers use it.
 *
 * @tparam Is_dropout   part of the signature only; not read in this function.
 * @param acc_o         accumulator tiles, updated in place.
 * @param lse           output log-sum-exp values.
 * @param scores_max    row maxima (input).
 * @param scores_sum    row sums (input).
 * @param scale_softmax multiplier for scores_max inside lse.
 * @param rp_dropout    part of the signature only; not read in this function.
 * @param attn_sink     optional sink logits, nullptr to skip. Indexed with
 *                      [warp_id * 16 + lane % 16].
 *                      NOTE(review): the index ignores mi / min_tile_m, so all
 *                      row slots of one lane use the same sink value - verify
 *                      against the caller's row layout.
 */
template <int WARP_M, int kBlockK, int kHeadDimV, bool Is_dropout, typename ElementAccum,
          typename DataType = union_vec2_fp32 /* vec2_Accum<ElementAccum> */,
          int M_MMAC_COUNT = 2>
__forceinline__ __device__ void prefill_dsa_epilugue_rescale_acco(
    vec4_Accum<ElementAccum> acc_o[(kHeadDimV / kBlockK) * (WARP_M / (16 * M_MMAC_COUNT)) * (kBlockK / 32)][2 * M_MMAC_COUNT],
    DataType lse[WARP_M / (16 * M_MMAC_COUNT)],
    DataType scores_max[WARP_M / (16 * M_MMAC_COUNT)],
    DataType scores_sum[WARP_M / (16 * M_MMAC_COUNT)],
    const ElementAccum scale_softmax,
    const ElementAccum rp_dropout,
    float* attn_sink)
{
    constexpr int kRowTiles  = WARP_M / (16 * M_MMAC_COUNT);
    constexpr int kColTiles  = kBlockK / 32;
    constexpr int kHeadSteps = kHeadDimV / kBlockK;

    const int tid     = threadIdx.x % 64;
    const int warp_id = threadIdx.x / 64;

    // Loop-invariant: load the sink logit once instead of per row slot.
    const bool  has_sink   = (attn_sink != nullptr);
    const float sink_logit = has_sink ? attn_sink[warp_id * 16 + tid % 16] : 0.0f;

    // Epilogue
#pragma unroll
    for (int mi = 0; mi < kRowTiles; ++mi) {
#pragma unroll
        for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
            const ElementAccum sum = scores_sum[mi].f32[min_tile_m];
            // A zero or NaN sum (sum != sum) marks a row without valid scores.
            const bool degenerate = (sum == 0.f || sum != sum);
            const ElementAccum inv_sum = degenerate ? 1.f : 1.f / sum;
            lse[mi].f32[min_tile_m] =
                degenerate ? -INFINITY
                           : scores_max[mi].f32[min_tile_m] * scale_softmax + __logf(sum);

            float attn_sink_o_scale = 1.0f;
            if (has_sink) {
                const float row_lse = lse[mi].f32[min_tile_m];
                if (sink_logit == INFINITY) {
                    attn_sink_o_scale = 0.0f;
                } else if ((row_lse != -INFINITY) && (row_lse != INFINITY)) {
                    // Overflow-safe form of exp(lse) / (exp(lse) + exp(sink)).
                    attn_sink_o_scale =
                        1.0f / (1.0f + __builtin_amdgcn_exp2f((sink_logit - row_lse) * CUDART_L2E_F));
                }
            }
            const ElementAccum scale = inv_sum * attn_sink_o_scale;

#pragma unroll
            for (int ni = 0; ni < kColTiles; ++ni) {
#pragma unroll
                for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                    int mmac_id;
                    if constexpr (M_MMAC_COUNT == 2) {
                        mmac_id = min_tile_n * 2 + min_tile_m;
                    } else {
                        mmac_id = min_tile_n;
                    }
#pragma unroll
                    for (int pv_n_loop = 0; pv_n_loop < kHeadSteps; ++pv_n_loop) {
                        const int pv_tile_id = pv_n_loop * kRowTiles * kColTiles + ni * kRowTiles + mi;
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
                        // Packed multiply handles two fp32 values at a time.
                        __float2 scale_pair = {scale, scale};
                        for (int vec_id = 0; vec_id < 2; ++vec_id) {
                            acc_o[pv_tile_id][mmac_id].u64[vec_id] =
                                __builtin_hcu_pk_mul_f32(acc_o[pv_tile_id][mmac_id].u64[vec_id], scale_pair);
                        }
#else
                        for (int vec_id = 0; vec_id < 4; ++vec_id) {
                            acc_o[pv_tile_id][mmac_id].f32[vec_id] *= scale;
                        }
#endif
                    }
                }
            }
        }
    }
}
/**
 * Writes the per-row log-sum-exp values held in registers to global memory.
 *
 * Only lanes with (lane_id >> 4) == 0 store. When SplitD is true the store is
 * additionally restricted to headdim_split_id == 0, so a row's LSE is written
 * once rather than once per head-dim split.
 *
 * Fixes relative to the previous revision:
 *   - SplitD is a bool, so the old test `SplitD > 1` was always false and the
 *     headdim_split_id restriction never took effect.
 *   - The DEBUG_LEVEL >= 1 dump was removed: it initialised scores_sum_ptr /
 *     scores_max_ptr from themselves and read scores_sum / scores_max, which
 *     are not parameters of this function, so it could not compile.
 *
 * Row mapping for slot (mi, min_tile_m):
 *   Is_Interleaved : warp_id*WARP_M + mi*16*M_MMAC_COUNT + (lane_id & 15) + min_tile_m*16
 *   otherwise      : warp_id*WARP_M + mi*16*M_MMAC_COUNT + (lane_id & 15)*2 + min_tile_m
 *
 * @tparam Is_even_MN     when true, rows are stored without a bounds check.
 * @param lse             register array of LSE values, read via .f32[min_tile_m].
 * @param softmax_lse_ptr destination base, reinterpreted as ElementAccum*.
 * @param row_offset_lse  element offset added to softmax_lse_ptr.
 * @param warp_id         warp index used in the row mapping.
 * @param lane_id         lane index within the warp.
 * @param headdim_split_id head-dim split index; only consulted when SplitD.
 * @param seqlen_q_limit  exclusive upper bound on the row index when
 *                        Is_even_MN is false.
 */
template <int WARP_M, bool Is_even_MN, bool SplitD, bool Is_Interleaved, typename ElementAccum,
          typename DataType = union_vec2_fp32 /* vec2_Accum<ElementAccum> */,
          int M_MMAC_COUNT = 2>
__forceinline__ __device__ void prefill_mla_epilogue_store_lse(
    DataType lse[WARP_M / (16 * M_MMAC_COUNT)],
    void* softmax_lse_ptr,
    int row_offset_lse,
    int warp_id,
    int lane_id,
    int headdim_split_id,
    int seqlen_q_limit)
{
    ElementAccum* gLSE = reinterpret_cast<ElementAccum*>(softmax_lse_ptr) + row_offset_lse;

    const bool first_lane_group = ((lane_id >> 4) == 0);
    const bool write_lse = SplitD ? (first_lane_group and headdim_split_id == 0)
                                  : first_lane_group;

    if (write_lse) {
#pragma unroll
        for (int mi = 0; mi < (WARP_M / (16 * M_MMAC_COUNT)); ++mi) {
#pragma unroll
            for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
                const int row_base = warp_id * WARP_M + mi * (16 * M_MMAC_COUNT);
                const int row = Is_Interleaved
                                    ? row_base + (lane_id & 15) + min_tile_m * 16
                                    : row_base + (lane_id & 15) * 2 + min_tile_m;
                if constexpr (Is_even_MN) {
                    gLSE[row] = lse[mi].f32[min_tile_m];
                } else {
                    // Rows at or beyond the limit are skipped.
                    if (row < seqlen_q_limit) {
                        gLSE[row] = lse[mi].f32[min_tile_m];
                    }
                }
            }
        }
    }
}
/**
 * Converts the fp32 output accumulator to Element (fp16 / bf16) and stores it
 * to global memory.
 *
 * Two layouts are supported, selected at compile time:
 *   - Is_Interleaved: each (tile, mmac) fragment is down-converted pairwise
 *     with DownCastPair and written through a union_vec2_f16x2<Element>
 *     pointer store. On __gfx92a__ with YY_USE_MPERMUTE the fragments are
 *     first converted and passed through ds_mpermute_kdim_for_mmac.
 *   - otherwise: when Is_even_MN is false, single elements are written with a
 *     row-bounds check; when true, two elements are packed into a
 *     vec2_Element<Element> and written with inline_buffer_store_dword, then
 *     the function waits for outstanding vector-memory operations
 *     (s_waitcnt vmcnt(0)).
 *
 * Fix relative to the previous revision: the non-interleaved Is_even_MN path
 * always indexed the accumulator with min_tile_m + 0*2 and min_tile_m + 1*2,
 * which is only valid for M_MMAC_COUNT == 2 (the second dimension has
 * 2 * M_MMAC_COUNT entries). It now uses the same mapping as the bounds-checked
 * path: {min_tile_m, min_tile_m + 2} for M_MMAC_COUNT == 2 and {0, 1} otherwise.
 *
 * @tparam TcpSwizzle   part of the signature; not read in this function.
 * @param o_ptr           output base pointer.
 * @param acc_o           fp32 accumulator tiles (read only here).
 * @param m_block         index of the kBlockM row block, used for bounds checks.
 * @param warp_id         warp index, used for the row offset.
 * @param lane_id         lane index; low 4 bits select the row, the remaining
 *                        bits select the head-dim sub-column.
 * @param seqlen_o_stride row stride of the output, in the units o_ptr /
 *                        inline_buffer_store_dword expect.
 * @param seqlen_q_limit  exclusive row bound used when Is_even_MN is false.
 */
template <int kHeadDimV, int kBlockM, int kBlockK, int WARP_M, bool Is_even_MN, bool Is_Interleaved,
          bool TcpSwizzle, typename Element, typename ElementAccum, int M_MMAC_COUNT = 2>
__forceinline__ __device__ void prefill_mla_epilogue_store_output(
    Element* o_ptr,
    vec4_Accum<ElementAccum> acc_o[(kHeadDimV / kBlockK) * (WARP_M / (16 * M_MMAC_COUNT)) * (kBlockK / 32)][2 * M_MMAC_COUNT],
    int m_block,
    int warp_id,
    int lane_id,
    int seqlen_o_stride,
    int seqlen_q_limit)
{
    int pv_lane_seq_idx = lane_id & 15;
    int pv_lane_head_dim_idx = lane_id >> 4;

    if constexpr (Is_Interleaved) {
#if defined(__gfx92a__) && defined(YY_USE_MPERMUTE)
        // NOTE(review): this path indexes acc_o / acc_o_fp16 by k_loop rather
        // than by pv_tile_id; the two coincide only when
        // (WARP_M / (16 * M_MMAC_COUNT)) * (kBlockK / 32) == 1 - confirm.
        union_vec2_f16x2<Element> acc_o_fp16[(kHeadDimV / kBlockK) * (WARP_M / (16 * M_MMAC_COUNT)) * (kBlockK / 32)][2 * M_MMAC_COUNT];
#pragma unroll
        for (int k_loop = 0; k_loop < (kHeadDimV / kBlockK); ++k_loop) {
#pragma unroll 2
            for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll 2
                for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                    int mmac_id;
                    if constexpr (M_MMAC_COUNT == 2)
                        mmac_id = min_tile_m + min_tile_n * 2;
                    else
                        mmac_id = min_tile_n;
#pragma unroll
                    for (int vec_index = 0; vec_index < 2; ++vec_index) {
                        // convert float -> bf16/fp16
                        acc_o_fp16[k_loop][mmac_id].f16x2[vec_index] =
                            DownCastPair<ElementAccum, Element>(acc_o[k_loop][mmac_id].f32x2[vec_index]);
                    }
                    ds_mpermute_kdim_for_mmac(acc_o_fp16[k_loop][mmac_id].f32);
                }
            }
        }
#endif
#pragma unroll
        for (int k_loop = 0; k_loop < (kHeadDimV / kBlockK); ++k_loop) {
#if defined(__gfx92a__) && defined(YY_USE_MPERMUTE)
            flash::wait_lds_data_arrived<false>((kHeadDimV / kBlockK - k_loop - 1) * 2 * 2);
#endif
#pragma unroll
            for (int warp_m_idx = 0; warp_m_idx < (WARP_M / (16 * M_MMAC_COUNT)); ++warp_m_idx) {
#pragma unroll
                for (int k_tile_idx = 0; k_tile_idx < (kBlockK / 32); ++k_tile_idx) {
#pragma unroll 2
                    for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll 2
                        for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                            const int pv_tile_id = k_loop * (WARP_M / (16 * M_MMAC_COUNT)) * (kBlockK / 32)
                                                   + warp_m_idx * (kBlockK / 32) + k_tile_idx;
                            int mmac_id;
                            if constexpr (M_MMAC_COUNT == 2) {
                                mmac_id = min_tile_m + min_tile_n * 2;
                            } else {
                                mmac_id = min_tile_n;
                            }
                            // Row of this lane inside the current kBlockM block.
                            int seqlen_q_offset = warp_id * WARP_M + warp_m_idx * (16 * M_MMAC_COUNT)
                                                  + min_tile_m * 16 + pv_lane_seq_idx;
                            // prepare for store
                            int s_offset = k_tile_idx * 32 + min_tile_n * 16;
                            int v_offset = seqlen_q_offset * seqlen_o_stride + k_loop * kBlockK
                                           + pv_lane_head_dim_idx * 4;
#if defined(__gfx92a__) && defined(YY_USE_MPERMUTE)
                            if constexpr (not Is_even_MN) {
                                if (m_block * kBlockM + seqlen_q_offset < seqlen_q_limit) {
                                    *(union_vec2_f16x2<Element>*)(o_ptr + v_offset + s_offset) =
                                        acc_o_fp16[k_loop][mmac_id];
                                }
                            } else {
                                *(union_vec2_f16x2<Element>*)(o_ptr + v_offset + s_offset) =
                                    acc_o_fp16[k_loop][mmac_id];
                            }
#else
                            union_vec2_f16x2<Element> v_data;
#pragma unroll
                            for (int vec_index = 0; vec_index < 2; ++vec_index) {
                                // convert float -> bf16/fp16
                                v_data.f16x2[vec_index] =
                                    DownCastPair<ElementAccum, Element>(acc_o[pv_tile_id][mmac_id].f32x2[vec_index]);
                            }
                            if constexpr (not Is_even_MN) {
                                if (m_block * kBlockM + seqlen_q_offset < seqlen_q_limit) {
                                    *(union_vec2_f16x2<Element>*)(o_ptr + v_offset + s_offset) = v_data;
                                }
                            } else {
                                *(union_vec2_f16x2<Element>*)(o_ptr + v_offset + s_offset) = v_data;
                            }
#endif
                        }
                    }
                }
            }
        }
        // brace, to control vgpr usage
    } else {
        // This branch only supports the LIT layout.
        auto gO = prepare_for_buffer_load<kHeadDimV, Element>(o_ptr);
#pragma unroll
        for (int k_loop = 0; k_loop < (kHeadDimV / kBlockK); ++k_loop) {
#pragma unroll
            for (int warp_m_idx = 0; warp_m_idx < (WARP_M / (16 * M_MMAC_COUNT)); ++warp_m_idx) {
#pragma unroll
                for (int k_tile_idx = 0; k_tile_idx < (kBlockK / 32); ++k_tile_idx) {
#pragma unroll 2
                    for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
                        for (int vec_index = 0; vec_index < 4; ++vec_index) {
                            if constexpr (not Is_even_MN) {
#pragma unroll
                                for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                                    // Position along seqlen_q inside one kBlockM block.
                                    const int seqlen_q_offset = warp_id * WARP_M + warp_m_idx * (16 * M_MMAC_COUNT)
                                                                + pv_lane_seq_idx + min_tile_m * 16;
                                    int pv_global_addr = seqlen_q_offset * seqlen_o_stride
                                                         // offset along headdim
                                                         + k_loop * kBlockK + k_tile_idx * 32 + vec_index * 8
                                                         + pv_lane_head_dim_idx * 2 + min_tile_n;
                                    if (m_block * kBlockM + seqlen_q_offset < seqlen_q_limit) {
                                        if constexpr (M_MMAC_COUNT == 2)
                                            o_ptr[pv_global_addr] = DownCast<ElementAccum, Element>(
                                                acc_o[k_loop * (WARP_M / (16 * M_MMAC_COUNT)) * (kBlockK / 32)
                                                      + warp_m_idx * (kBlockK / 32) + k_tile_idx]
                                                     [min_tile_m + min_tile_n * 2].f32[vec_index]);
                                        else
                                            o_ptr[pv_global_addr] = DownCast<ElementAccum, Element>(
                                                acc_o[k_loop * (WARP_M / (16 * M_MMAC_COUNT)) * (kBlockK / 32)
                                                      + warp_m_idx * (kBlockK / 32) + k_tile_idx]
                                                     [min_tile_n].f32[vec_index]);
                                    }
                                }
                            } else {
                                int tile32x32_id = k_loop * (WARP_M / (16 * M_MMAC_COUNT)) * (kBlockK / 32)
                                                   + warp_m_idx * (kBlockK / 32) + k_tile_idx;
                                int s_offset = k_loop * kBlockK;
                                // Original author note: "overflow for s_offset_constexpr".
                                // Presumably the immediate offset field has a limited
                                // range, hence the split between s_offset and this
                                // value - TODO confirm.
                                int s_offset_constexpr = k_tile_idx * 32 + vec_index * 8;
                                int v_offset = (warp_id * WARP_M + warp_m_idx * (16 * M_MMAC_COUNT)
                                                + pv_lane_seq_idx + min_tile_m * 16) * seqlen_o_stride
                                               + pv_lane_head_dim_idx * 2;
                                // Accumulator slots for min_tile_n == 0 and == 1, using
                                // the same mapping as the bounds-checked path above.
                                const int mmac_n0 = (M_MMAC_COUNT == 2) ? (min_tile_m + 0 * 2) : 0;
                                const int mmac_n1 = (M_MMAC_COUNT == 2) ? (min_tile_m + 1 * 2) : 1;
                                vec2_Element<Element> v_data;
                                // convert float -> bf16/fp16
                                if constexpr (std::is_same<Element, bhalf_t>::value) {
#if 1
                                    v_data[0] = DownCast<ElementAccum, Element, true>(
                                        acc_o[tile32x32_id][mmac_n0].f32[vec_index]);
                                    v_data[1] = DownCast<ElementAccum, Element, true>(
                                        acc_o[tile32x32_id][mmac_n1].f32[vec_index]);
#else
                                    v_data[0] = inlineasm_float2bfloat16_ushort_nonan(
                                        acc_o[tile32x32_id][mmac_n0].f32[vec_index]);
                                    v_data[1] = inlineasm_float2bfloat16_ushort_nonan(
                                        acc_o[tile32x32_id][mmac_n1].f32[vec_index]);
#endif
                                } else if constexpr (std::is_same<Element, half_t>::value) {
#ifdef USE_CVT_PKRTZ_FP16_FP32
                                    *(vec2_Element<Element>*)&v_data = DownCastPair<ElementAccum, Element>(
                                        acc_o[tile32x32_id][mmac_n0].f32[vec_index],
                                        acc_o[tile32x32_id][mmac_n1].f32[vec_index]);
#else
                                    v_data[0] = DownCast<ElementAccum, Element>(
                                        acc_o[tile32x32_id][mmac_n0].f32[vec_index]);
                                    v_data[1] = DownCast<ElementAccum, Element>(
                                        acc_o[tile32x32_id][mmac_n1].f32[vec_index]);
#endif
                                }
                                // write to global memory
                                inline_buffer_store_dword<vec2_Element<Element>, 1>(
                                    v_data, v_offset, gO, s_offset,
                                    /* immediate integer */ s_offset_constexpr);
                            }
                        }
                    }
                }
            }
        }
        // brace, to control vgpr usage
        // Wait for all outstanding vector-memory operations; the scheduling
        // barriers keep the compiler from moving instructions across the wait.
        __builtin_amdgcn_sched_barrier(0);
        asm volatile("s_waitcnt vmcnt(0)");
        __builtin_amdgcn_sched_barrier(0);
    }
}
\ No newline at end of file
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/gfx938/mla_pv_gemm_prefetch_k_mls_ds.h
0 → 100644
View file @
a1eef562
#include "mla_qk_gemm_utils_mls_ds.h"
#include "static_switch.h"
/**
 * P*V accumulation for one kBlockK-wide slab, with V rows gathered into LDS
 * through an index list and a block table, plus optional Q/K prefetch.
 *
 * Structure (statement order matters: loads, waits and barriers interleave):
 *   - for n_loop in [1, kBlockK / WARP_K): issue two buffer->LDS loads for the
 *     second head-dim half, then walk k_loop over kHeadDimV / kBlockN tiles,
 *     reading V fragments from LDS (DS_READ_MATRIX_32X32_B16) into a
 *     two-slot register ring (stage_id) and accumulating
 *     pv_reg += p_reg[n_loop - 1] * V via mmac_4interleave. At every k_loop
 *     that is a multiple of 8 it first waits for outstanding buffer loads and,
 *     after the MMACs, issues loads for the next n_loop.
 *   - a final block repeats the same pipeline for n_loop = kBlockK / WARP_K.
 *   - if PREFETCH_K, Q and K prefetch helpers are invoked at the end.
 *
 * Gather addressing: the row index comes from index_ptr, the page from
 * block_table[index / 128], the in-page row from index % 128.
 * NOTE(review): 128 is hard-coded and page_block_size is not read - confirm
 * callers always use 128-row pages.
 * NOTE(review): index_ptr is addressed with n_loop_real * 64, i.e. the code
 * presumably assumes kBlockK == 64 - confirm.
 *
 * An earlier matrix-load path (inline_matrix_load_32x32_b16_lds with an
 * nm_filter bounds descriptor) existed only as commented-out code and has
 * been dropped from this listing together with the locals that served it.
 *
 * Fix relative to the previous revision: three static_asserts were written as
 * `cond and "message"`, which folds the string into the condition so the
 * message is never reported; they now pass the message as the second argument.
 *
 * @param q_ptr, k_ptr    buffer descriptors forwarded to the prefetch helpers.
 * @param v_ptr           buffer descriptor used for the V gather loads.
 * @param q_lds, k_lds    LDS destinations forwarded to the prefetch helpers.
 * @param v_lds           LDS staging area for V.
 * @param p_reg           P fragments (read only).
 * @param pv_reg          output accumulator, updated in place.
 * @param warp_id         warp index; selects this warp's LDS / global slice.
 * @param seqlen_*_stride row strides for Q, K, V.
 * @param index_ptr       gather indices for V rows.
 * @param block_table     page table indexed by (gather index / 128).
 * @param batch_stride    multiplier applied to the block_table entry.
 * @param page_block_size unused here (see note above).
 * @param n_loop_real     outer slab index used to address index_ptr.
 * @param max_seq_q_offset, max_seq_kv_offset forwarded to the prefetch helpers.
 */
template <bool PREFETCH_K, int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK,
          int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum, bool Is_even_MN>
__forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512(
    vec4_uint q_ptr,
    vec4_uint k_ptr,
    vec4_uint v_ptr,
    Element* q_lds,
    Element* k_lds,
    Element* v_lds,
    union_vec2_f16x2<Element> p_reg[(WARP_M / 16) * (kBlockK / 32)][2],
    vec4_Accum<ElementAccum> pv_reg[(kHeadDimV / kBlockN) * (WARP_M / 16) * (kBlockN / 32)][2],
    int warp_id,
    int seqlen_q_stride,
    int seqlen_k_stride,
    int seqlen_v_stride,
    int* index_ptr,
    int* block_table,
    int batch_stride,
    int page_block_size,
    int n_loop_real,
    int max_seq_q_offset = 0,
    int max_seq_kv_offset = 0)
{
    constexpr int WARP_NUM = kBlockM * kBlockN / (WARP_M * WARP_N);
    constexpr int WARP_K = 32;
    constexpr int READ_ONCE_COUNT = 32 * 32;
    constexpr int kHeadDimV_OPT = 256;  // lds 32x32x8x2B == 16KB
    constexpr int V_LDS_LOAD_NUM = (kHeadDimV_OPT * WARP_K) / READ_ONCE_COUNT;
    constexpr int V_LOAD_REQUESTS = V_LDS_LOAD_NUM / WARP_NUM;
    constexpr int ELEMENT_BYTES = sizeof(Element);

    static_assert(kBlockK >= 32, "Error: pv gemm kBlockK must be equal or greater than 32");
    static_assert(kBlockM >= WARP_M, "Error: pv gemm kBlockM must be equal or greater than WARP_M");
    static_assert(kBlockN == WARP_N, "Error: pv gemm kBlockN must be equal to WARP_N");
    static_assert(WARP_K == 32, "Error: To simplify, only WARP_K = 32 is supported!");
    static_assert(WARP_M == 16, "Error: To simplify, only WARP_M = 16 is supported!");
    static_assert(WARP_N == 32, "Error: To simplify, only WARP_N = 32 is supported!");

    // Base address of the V staging area in LDS.
    int v_lds_base = reinterpret_cast<size_t>(v_lds);
    int tid = threadIdx.x % 64;

    // V fragment registers (ring of STAGES slots, two fragments each).
    union_vec4_f16x2<Element> v_reg[STAGES * (32 * WARP_N) / (32 * 32) * 2];

    int lds_stage_id = 1;

    for (int n_loop = 1; n_loop < (kBlockK / WARP_K); ++n_loop) {
        // prefetch same warpk, next 32x256 G2S
        {
            int n_load = 1;
            // NOTE(review): unlike the other two gather sites this index has
            // no (n_loop - 1) * 32 term; the results agree only while
            // n_loop == 1, i.e. kBlockK == 64 - confirm.
            int index_topk_1 = index_ptr[n_loop_real * 64 + (tid / 4)];
            int index_topk_2 = index_ptr[n_loop_real * 64 + (tid / 4) + 16];
            int lds_offset = __builtin_amdgcn_readfirstlane(
                (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
            int g_offset_v = n_load * WARP_NUM * 16 + warp_id * 16 + tid % 4 * 4
                             + block_table[index_topk_1 / 128] * batch_stride * ELEMENT_BYTES / 4
                             + (index_topk_1 % 128) * seqlen_v_stride * ELEMENT_BYTES / 4;
            int lds_offset_add = __builtin_amdgcn_readfirstlane(
                (lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
            int g_offset_v_add = n_load * WARP_NUM * 16 + warp_id * 16 + tid % 4 * 4
                                 + block_table[index_topk_2 / 128] * batch_stride * ELEMENT_BYTES / 4
                                 + (index_topk_2 % 128) * seqlen_v_stride * ELEMENT_BYTES / 4;
            flash::wait_all_warp_arrived();
            inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
            inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
        }

        // DS
        lds_stage_id ^= 1;
        int stage_id = 0;
        flash::wait_buffer_data_arrived<true>(V_LOAD_REQUESTS * 2);
        int lds_load_offset = v_lds_base
                              + (0 /*k_loop*/ * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
        DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16,
                                 false /*transpose*/);
        stage_id ^= 1;

        for (int k_loop = 1; k_loop < (kHeadDimV / kBlockN); ++k_loop) {
            // Wait for special headdim
            if ((k_loop & 7) == 0x0) {
                flash::wait_buffer_data_arrived<true>(0);
            }
            int lds_load_offset = v_lds_base
                                  + (k_loop * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
            DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16,
                                     false /*transpose*/);
            flash::wait_lds_data_arrived<false>(3);
            // MMAC (on the fragments read in the previous iteration)
            flash::raise_priority();
            stage_id ^= 1;
            {
                constexpr int min_tile_k = 0;
#pragma unroll
                for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
                    for (int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
                        int pv_tile_id = (STAGES == 2) ? k_loop - 1 : k_loop;
                        int v_tile_id = stage_id * 2 + min_tile_k;
                        pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                            p_reg[n_loop - 1][min_tile_k].f16x4,
                            v_reg[v_tile_id].f16x4[min_tile_n],
                            pv_reg[pv_tile_id][min_tile_n].f32);
                    }
                }
            }
            flash::wait_lds_data_arrived<false>(2);
            {
                constexpr int min_tile_k = 1;
#pragma unroll
                for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
                    for (int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
                        int pv_tile_id = (STAGES == 2) ? k_loop - 1 : k_loop;
                        int v_tile_id = stage_id * 2 + min_tile_k;
                        pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                            p_reg[n_loop - 1][min_tile_k].f16x4,
                            v_reg[v_tile_id].f16x4[min_tile_n],
                            pv_reg[pv_tile_id][min_tile_n].f32);
                    }
                }
            }
            flash::lower_priority();

            // MLS for special headdimV: load the first half for the next n_loop.
            if ((k_loop & 7) == 0x0) {
                int index_topk_1 = index_ptr[n_loop * 32 + n_loop_real * 64 + (tid / 4)];
                int index_topk_2 = index_ptr[n_loop * 32 + n_loop_real * 64 + (tid / 4) + 16];
                int lds_offset = __builtin_amdgcn_readfirstlane(
                    (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
                int g_offset_v = warp_id * 16 + tid % 4 * 4
                                 + block_table[index_topk_1 / 128] * batch_stride * ELEMENT_BYTES / 4
                                 + (index_topk_1 % 128) * seqlen_v_stride * ELEMENT_BYTES / 4;
                int lds_offset_add = __builtin_amdgcn_readfirstlane(
                    (lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
                int g_offset_v_add = warp_id * 16 + tid % 4 * 4
                                     + block_table[index_topk_2 / 128] * batch_stride * ELEMENT_BYTES / 4
                                     + (index_topk_2 % 128) * seqlen_v_stride * ELEMENT_BYTES / 4;
                // Original note on the disabled path: wait so that no warp is
                // still reading V from LDS when it gets overwritten.
                flash::wait_all_warp_arrived();
                inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
                inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
            }
        }

        stage_id ^= 1;
        // Wait DS
        flash::wait_lds_data_arrived<false>(1);
        // last mmac
        flash::raise_priority();
        {
            constexpr int min_tile_k = 0;
#pragma unroll
            for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
                for (int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
                    int pv_tile_id = (kHeadDimV / kBlockN) - 1;
                    int v_tile_id = stage_id * 2 + min_tile_k;
                    pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                        p_reg[n_loop - 1][min_tile_k].f16x4,
                        v_reg[v_tile_id].f16x4[min_tile_n],
                        pv_reg[pv_tile_id][min_tile_n].f32);
                }
            }
        }
        flash::wait_lds_data_arrived<false>(0);
        {
            constexpr int min_tile_k = 1;
#pragma unroll
            for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
                for (int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
                    int pv_tile_id = (kHeadDimV / kBlockN) - 1;
                    int v_tile_id = stage_id * 2 + min_tile_k;
                    pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                        p_reg[n_loop - 1][min_tile_k].f16x4,
                        v_reg[v_tile_id].f16x4[min_tile_n],
                        pv_reg[pv_tile_id][min_tile_n].f32);
                }
            }
        }
        flash::lower_priority();
    }

    {
        constexpr int n_loop = kBlockK / WARP_K;
        // MLS for special headdimV: second half for the last WARP_K rows.
        {
            constexpr int n_loop_ = n_loop - 1;
            int n_load = 1;
            lds_stage_id ^= 1;
            int index_topk_1 = index_ptr[n_loop_ * 32 + n_loop_real * 64 + (tid / 4)];
            int index_topk_2 = index_ptr[n_loop_ * 32 + n_loop_real * 64 + (tid / 4) + 16];
            int lds_offset = __builtin_amdgcn_readfirstlane(
                (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
            int g_offset_v = n_load * WARP_NUM * 16 + warp_id * 16 + tid % 4 * 4
                             + block_table[index_topk_1 / 128] * batch_stride * ELEMENT_BYTES / 4
                             + (index_topk_1 % 128) * seqlen_v_stride * ELEMENT_BYTES / 4;
            int lds_offset_add = __builtin_amdgcn_readfirstlane(
                (lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
            int g_offset_v_add = n_load * WARP_NUM * 16 + warp_id * 16 + tid % 4 * 4
                                 + block_table[index_topk_2 / 128] * batch_stride * ELEMENT_BYTES / 4
                                 + (index_topk_2 % 128) * seqlen_v_stride * ELEMENT_BYTES / 4;
            flash::wait_all_warp_arrived();
            inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
            inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
        }
        lds_stage_id ^= 1;
        int stage_id = 0;
        flash::wait_buffer_data_arrived<true>(V_LOAD_REQUESTS * 2);
        // [TODO] prefetch earlier
        // DS
        int lds_load_offset = v_lds_base
                              + (0 /*k_loop*/ * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
        DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16,
                                 false /*transpose*/);
        stage_id ^= 1;

        for (int k_loop = 1; k_loop < (kHeadDimV / kBlockN); ++k_loop) {
            // Wait for special headdim
            if ((k_loop & 7) == 0x0) {
                flash::wait_buffer_data_arrived<true>(0);
            }
            // DS
            int lds_load_offset = v_lds_base
                                  + (k_loop * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
            DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16,
                                     false /*transpose*/);
            flash::wait_lds_data_arrived<false>(3);
            // MMAC
            flash::raise_priority();
            stage_id ^= 1;
            {
                constexpr int min_tile_k = 0;
#pragma unroll
                for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
                    for (int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
                        int pv_tile_id = (STAGES == 2) ? k_loop - 1 : k_loop;
                        int v_tile_id = stage_id * 2 + min_tile_k;
                        pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                            p_reg[n_loop - 1][min_tile_k].f16x4,
                            v_reg[v_tile_id].f16x4[min_tile_n],
                            pv_reg[pv_tile_id][min_tile_n].f32);
                    }
                }
            }
            flash::wait_lds_data_arrived<false>(2);
            {
                constexpr int min_tile_k = 1;
#pragma unroll
                for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
                    for (int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
                        int pv_tile_id = (STAGES == 2) ? k_loop - 1 : k_loop;
                        int v_tile_id = stage_id * 2 + min_tile_k;
                        pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                            p_reg[n_loop - 1][min_tile_k].f16x4,
                            v_reg[v_tile_id].f16x4[min_tile_n],
                            pv_reg[pv_tile_id][min_tile_n].f32);
                    }
                }
            }
            flash::lower_priority();
        }

        stage_id ^= 1;
        flash::wait_lds_data_arrived<false>(1);
        // last mmac
        flash::raise_priority();
        {
            constexpr int min_tile_k = 0;
#pragma unroll
            for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
                for (int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
                    int pv_tile_id = (kHeadDimV / kBlockN) - 1;
                    int v_tile_id = stage_id * 2 + min_tile_k;
                    pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                        p_reg[n_loop - 1][min_tile_k].f16x4,
                        v_reg[v_tile_id].f16x4[min_tile_n],
                        pv_reg[pv_tile_id][min_tile_n].f32);
                }
            }
        }
        flash::wait_lds_data_arrived<false>(0);
        {
            constexpr int min_tile_k = 1;
#pragma unroll
            for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
                for (int min_tile_m = 0; min_tile_m < 1; ++min_tile_m) {
                    int pv_tile_id = (kHeadDimV / kBlockN) - 1;
                    int v_tile_id = stage_id * 2 + min_tile_k;
                    pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                        p_reg[n_loop - 1][min_tile_k].f16x4,
                        v_reg[v_tile_id].f16x4[min_tile_n],
                        pv_reg[pv_tile_id][min_tile_n].f32);
                }
            }
        }
        flash::lower_priority();
    }

    // Prefetch Q and K.
    if constexpr (PREFETCH_K) {
        prefetch_q_to_lds_mls_ds_576_512<kHeadDim, kBlockM, kBlockK, WARP_M, Element, Is_even_MN>(
            q_ptr, q_lds, warp_id, seqlen_q_stride, max_seq_q_offset);
        prefetch_k_to_lds_mls_ds_576_512<kHeadDim, kBlockK, kBlockN, WARP_NUM, WARP_N, Element, Is_even_MN>(
            k_ptr, k_lds, warp_id, seqlen_k_stride, max_seq_kv_offset - kBlockK);
    }
}
template
<
bool
PREFETCH_K
,
int
kHeadDim
,
int
kHeadDimV
,
int
kBlockM
,
int
kBlockN
,
int
kBlockK
,
int
WARP_M
,
int
WARP_N
,
int
STAGES
,
typename
Element
,
typename
ElementAccum
,
bool
Is_even_MN
>
__forceinline__
__device__
void
pv_gemm_prefetch_k_mls_ds_576_512_nopage
(
vec4_uint
q_ptr
,
vec4_uint
k_ptr
,
vec4_uint
v_ptr
,
Element
*
q_lds
,
Element
*
k_lds
,
Element
*
v_lds
,
union_vec2_f16x2
<
Element
>
p_reg
[(
WARP_M
/
16
)
*
(
kBlockK
/
32
)][
2
],
vec4_Accum
<
ElementAccum
>
pv_reg
[(
kHeadDimV
/
kBlockN
)
*
(
WARP_M
/
16
)
*
(
kBlockN
/
32
)][
2
],
int
warp_id
,
int
seqlen_q_stride
,
int
seqlen_k_stride
,
int
seqlen_v_stride
,
int
*
index_ptr
,
int
batch_stride
,
int
page_block_size
,
int
n_loop_real
,
int
max_seq_q_offset
=
0
,
int
max_seq_kv_offset
=
0
)
{
constexpr
int
WARP_NUM
=
kBlockM
*
kBlockN
/
(
WARP_M
*
WARP_N
);
constexpr
int
WARP_K
=
32
;
constexpr
int
READ_ONCE_COUNT
=
32
*
32
;
constexpr
int
kHeadDimV_OPT
=
256
;
// lds 32x32x8x2B == 16KB
constexpr
int
V_LDS_LOAD_NUM
=
(
kHeadDimV_OPT
*
WARP_K
)
/
READ_ONCE_COUNT
;
constexpr
int
V_LOAD_REQUESTS
=
V_LDS_LOAD_NUM
/
WARP_NUM
;
constexpr
int
ELEMENT_BYTES
=
sizeof
(
Element
);
static_assert
(
kBlockK
>=
32
,
"Error: pv gemm kBlockK must be equal or greater than 32"
);
static_assert
(
kBlockM
>=
WARP_M
,
"Error: pv gemm kBlockM must be equal or greater than WARP_M"
);
static_assert
(
kBlockN
==
WARP_N
,
"Error: pv gemm kBlockN must be equal to WARP_N"
);
static_assert
(
WARP_K
==
32
and
"Error: To simplify, only WARP_K = 32 is supported!"
);
static_assert
(
WARP_M
==
16
and
"Error: To simplify, only WARP_M = 16 is supported!"
);
static_assert
(
WARP_N
==
32
and
"Error: To simplify, only WARP_N = 32 is supported!"
);
// 计算 V lds 起始偏移量
int
v_lds_base
=
reinterpret_cast
<
size_t
>
(
v_lds
);
int
tid
=
threadIdx
.
x
%
64
;
// 准备 V 寄存器
union_vec4_f16x2
<
Element
>
v_reg
[
STAGES
*
(
32
*
WARP_N
)
/
(
32
*
32
)
*
2
];
// MLS
vec4_uint
v_srsrc
;
v_srsrc
[
0
]
=
v_ptr
[
0
];
v_srsrc
[
1
]
=
v_ptr
[
1
];
v_srsrc
[
2
]
=
seqlen_v_stride
;
// stride
v_srsrc
[
3
]
=
0
;
int
lds_stage_id
=
1
;
for
(
int
n_loop
=
1
;
n_loop
<
(
kBlockK
/
WARP_K
);
++
n_loop
)
{
// prefetch same warpk, next 32x256 G2S
{
int
n_load
=
1
;
int
n_loop_
=
n_loop
-
1
;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + n_load * WARP_NUM * 32) * ELEMENT_BYTES);
// if constexpr (true) {
// int nm_filter = inline_min_max<0, 32>(n_loop_ * WARP_K + 32 - max_seq_kv_offset);
// v_srsrc[3] = max_seq_kv_offset % kBlockK == 0 ? 0: nm_filter << 8;
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived();
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
int
index_topk_1
=
index_ptr
[
n_loop_real
*
64
+
(
tid
/
4
)];
int
index_topk_2
=
index_ptr
[
n_loop_real
*
64
+
(
tid
/
4
)
+
16
];
int
lds_offset
=
__builtin_amdgcn_readfirstlane
((
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
/
4
);
int
g_offset_v
=
n_load
*
WARP_NUM
*
16
+
warp_id
*
16
+
tid
%
4
*
4
+
index_topk_1
*
seqlen_v_stride
*
ELEMENT_BYTES
/
4
;
int
lds_offset_add
=
__builtin_amdgcn_readfirstlane
((
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
+
32
*
16
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
/
4
);
int
g_offset_v_add
=
n_load
*
WARP_NUM
*
16
+
warp_id
*
16
+
tid
%
4
*
4
+
index_topk_2
*
seqlen_v_stride
*
ELEMENT_BYTES
/
4
;
flash
::
wait_all_warp_arrived
();
inline_buffer_load_dwordx4_lds
(
v_lds
,
v_ptr
,
lds_offset
,
0
,
g_offset_v
);
inline_buffer_load_dwordx4_lds
(
v_lds
,
v_ptr
,
lds_offset_add
,
0
,
g_offset_v_add
);
}
// DS
lds_stage_id
^=
1
;
int
stage_id
=
0
;
flash
::
wait_buffer_data_arrived
<
true
>
(
V_LOAD_REQUESTS
*
2
);
int
lds_load_offset
=
v_lds_base
+
(
0
/*k_loop*/
*
32
*
32
+
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16
(
lds_load_offset
,
v_reg
[
stage_id
*
2
+
0
].
f16
,
v_reg
[
stage_id
*
2
+
1
].
f16
,
false
/*transpose*/
);
stage_id
^=
1
;
for
(
int
k_loop
=
1
;
k_loop
<
(
kHeadDimV
/
kBlockN
);
++
k_loop
)
{
// Wait for special headdim
if
((
k_loop
&
7
)
==
0x0
)
{
flash
::
wait_buffer_data_arrived
<
true
>
(
0
);
}
int
lds_load_offset
=
v_lds_base
+
(
k_loop
*
32
*
32
+
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16
(
lds_load_offset
,
v_reg
[
stage_id
*
2
+
0
].
f16
,
v_reg
[
stage_id
*
2
+
1
].
f16
,
false
/*transpose*/
);
flash
::
wait_lds_data_arrived
<
false
>
(
3
);
// MMAC
flash
::
raise_priority
();
stage_id
^=
1
;
{
constexpr
int
min_tile_k
=
0
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
STAGES
==
2
)
?
k_loop
-
1
:
k_loop
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
wait_lds_data_arrived
<
false
>
(
2
);
{
constexpr
int
min_tile_k
=
1
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
STAGES
==
2
)
?
k_loop
-
1
:
k_loop
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
lower_priority
();
// MLS for special headdimV
if
((
k_loop
&
7
)
==
0x0
)
{
int
n_loop_
=
n_loop
;
// if constexpr (Is_even_MN) {
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// } else {
// int nm_filter_max = n_loop_ * WARP_K + 32 - max_seq_kv_offset;
// int real_mls_loop = nm_filter_max >= 32 ? 0: n_loop_; // 如果全越界了, 则只访问 n_loop = 0 的那波数据
// int nm_filter = inline_min_max<0, 32>(real_mls_loop * WARP_K + 32 - max_seq_kv_offset); // 重新计算 nm_filter
// v_srsrc[3] = nm_filter << 8;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (real_mls_loop * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived(); // 预防有 warp 还没算完7,还在读 v lds, 若是此时写 v lds,则 data cover
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
int
index_topk_1
=
index_ptr
[
n_loop
*
32
+
n_loop_real
*
64
+
(
tid
/
4
)];
int
index_topk_2
=
index_ptr
[
n_loop
*
32
+
n_loop_real
*
64
+
(
tid
/
4
)
+
16
];
int
lds_offset
=
__builtin_amdgcn_readfirstlane
((
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
/
4
);
int
g_offset_v
=
warp_id
*
16
+
tid
%
4
*
4
+
index_topk_1
*
seqlen_v_stride
*
ELEMENT_BYTES
/
4
;
int
lds_offset_add
=
__builtin_amdgcn_readfirstlane
((
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
+
32
*
16
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
/
4
);
int
g_offset_v_add
=
warp_id
*
16
+
tid
%
4
*
4
+
index_topk_2
*
seqlen_v_stride
*
ELEMENT_BYTES
/
4
;
flash
::
wait_all_warp_arrived
();
inline_buffer_load_dwordx4_lds
(
v_lds
,
v_ptr
,
lds_offset
,
0
,
g_offset_v
);
inline_buffer_load_dwordx4_lds
(
v_lds
,
v_ptr
,
lds_offset_add
,
0
,
g_offset_v_add
);
}
}
stage_id
^=
1
;
// Wait DS
flash
::
wait_lds_data_arrived
<
false
>
(
1
);
// last mmac
flash
::
raise_priority
();
{
constexpr
int
min_tile_k
=
0
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
kHeadDimV
/
kBlockN
)
-
1
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
wait_lds_data_arrived
<
false
>
(
0
);
{
constexpr
int
min_tile_k
=
1
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
kHeadDimV
/
kBlockN
)
-
1
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
lower_priority
();
}
{
constexpr
int
n_loop
=
kBlockK
/
WARP_K
;
// MLS for special headdimV
{
constexpr
int
n_loop_
=
n_loop
-
1
;
int
n_load
=
1
;
lds_stage_id
^=
1
;
// if constexpr (Is_even_MN) {
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + n_load * WARP_NUM * 32) * ELEMENT_BYTES);
// } else {
// int nm_filter_max = n_loop_ * WARP_K + 32 - max_seq_kv_offset;
// int real_mls_loop = nm_filter_max >= 32 ? 0: n_loop_; // 如果全越界了, 则只访问 n_loop = 0 的那波数据
// int nm_filter = inline_min_max<0, 32>(real_mls_loop * WARP_K + 32 - max_seq_kv_offset); // 重新计算 nm_filter
// v_srsrc[3] = nm_filter << 8;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (real_mls_loop * WARP_K * seqlen_v_stride + warp_id * 32 + n_load * WARP_NUM * 32) * ELEMENT_BYTES);
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived();
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
int
index_topk_1
=
index_ptr
[
n_loop_
*
32
+
n_loop_real
*
64
+
(
tid
/
4
)];
int
index_topk_2
=
index_ptr
[
n_loop_
*
32
+
n_loop_real
*
64
+
(
tid
/
4
)
+
16
];
int
lds_offset
=
__builtin_amdgcn_readfirstlane
((
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
/
4
);
int
g_offset_v
=
n_load
*
WARP_NUM
*
16
+
warp_id
*
16
+
tid
%
4
*
4
+
index_topk_1
*
seqlen_v_stride
*
ELEMENT_BYTES
/
4
;
int
lds_offset_add
=
__builtin_amdgcn_readfirstlane
((
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
+
32
*
16
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
/
4
);
int
g_offset_v_add
=
n_load
*
WARP_NUM
*
16
+
warp_id
*
16
+
tid
%
4
*
4
+
index_topk_2
*
seqlen_v_stride
*
ELEMENT_BYTES
/
4
;
flash
::
wait_all_warp_arrived
();
inline_buffer_load_dwordx4_lds
(
v_lds
,
v_ptr
,
lds_offset
,
0
,
g_offset_v
);
inline_buffer_load_dwordx4_lds
(
v_lds
,
v_ptr
,
lds_offset_add
,
0
,
g_offset_v_add
);
}
lds_stage_id
^=
1
;
int
stage_id
=
0
;
flash
::
wait_buffer_data_arrived
<
true
>
(
V_LOAD_REQUESTS
*
2
);
// [TODO]更早的预取
// DS
int
lds_load_offset
=
v_lds_base
+
(
0
/*k_loop*/
*
32
*
32
+
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16
(
lds_load_offset
,
v_reg
[
stage_id
*
2
+
0
].
f16
,
v_reg
[
stage_id
*
2
+
1
].
f16
,
false
/*transpose*/
);
stage_id
^=
1
;
for
(
int
k_loop
=
1
;
k_loop
<
(
kHeadDimV
/
kBlockN
);
++
k_loop
)
{
// Wait for special headdim
if
((
k_loop
&
7
)
==
0x0
)
{
flash
::
wait_buffer_data_arrived
<
true
>
(
0
);
}
// DS
int
lds_load_offset
=
v_lds_base
+
(
k_loop
*
32
*
32
+
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16
(
lds_load_offset
,
v_reg
[
stage_id
*
2
+
0
].
f16
,
v_reg
[
stage_id
*
2
+
1
].
f16
,
false
/*transpose*/
);
flash
::
wait_lds_data_arrived
<
false
>
(
3
);
// MMAC
flash
::
raise_priority
();
stage_id
^=
1
;
{
constexpr
int
min_tile_k
=
0
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
STAGES
==
2
)
?
k_loop
-
1
:
k_loop
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
wait_lds_data_arrived
<
false
>
(
2
);
{
constexpr
int
min_tile_k
=
1
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
STAGES
==
2
)
?
k_loop
-
1
:
k_loop
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
lower_priority
();
}
stage_id
^=
1
;
flash
::
wait_lds_data_arrived
<
false
>
(
1
);
// last mmac
flash
::
raise_priority
();
{
constexpr
int
min_tile_k
=
0
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
kHeadDimV
/
kBlockN
)
-
1
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
wait_lds_data_arrived
<
false
>
(
0
);
{
constexpr
int
min_tile_k
=
1
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
kHeadDimV
/
kBlockN
)
-
1
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
lower_priority
();
}
// 预取Q K
if
constexpr
(
PREFETCH_K
)
{
prefetch_q_to_lds_mls_ds_576_512
<
kHeadDim
,
kBlockM
,
kBlockK
,
WARP_M
,
Element
,
Is_even_MN
>
(
q_ptr
,
q_lds
,
warp_id
,
seqlen_q_stride
,
max_seq_q_offset
);
prefetch_k_to_lds_mls_ds_576_512
<
kHeadDim
,
kBlockK
,
kBlockN
,
WARP_NUM
,
WARP_N
,
Element
,
Is_even_MN
>
(
k_ptr
,
k_lds
,
warp_id
,
seqlen_k_stride
,
max_seq_kv_offset
-
kBlockK
);
}
}
// P*V GEMM stage for the 576/512 head-dim sparse ("nopage", 64-index block)
// prefill path. V rows are gathered through `index_ptr` into LDS with
// buffer-load-to-LDS requests, read back as 32x32 tiles and accumulated into
// `pv_reg` with `mmac_4interleave`.
//
// Template parameters:
//   PREFETCH_K   - when true, the next Q/K tiles are prefetched to LDS at the end.
//   kHeadDim     - forwarded to the Q/K prefetch helpers only.
//   kHeadDimV    - V head dimension; kHeadDimV / kBlockN tiles are processed per WARP_K slice.
//   kBlockM/N/K  - block tile sizes; kBlockK / WARP_K slices of P are consumed.
//   WARP_M/N     - per-warp tile sizes.
//   STAGES       - V register staging depth (sizes v_reg, selects the accumulator index).
//   Element / ElementAccum - storage and accumulation element types.
//   Is_even_MN   - forwarded to the Q/K prefetch helpers only.
//
// Runtime parameters:
//   q_ptr, k_ptr         - buffer descriptors, forwarded to the prefetch helpers.
//   v_ptr                - buffer descriptor passed to every V buffer load.
//   q_lds, k_lds, v_lds  - LDS destinations; v_lds is also the base for the tile reads.
//   p_reg                - P operand registers, indexed [slice][min_tile_k].
//   pv_reg               - accumulators, updated in place.
//   warp_id              - warp index inside the block, used in all offsets.
//   seqlen_*_stride      - row strides (presumably in elements; confirm against caller).
//   index_ptr            - top-k row indices; 64 entries per `n_loop_real` block are read here.
//   batch_stride, page_block_size - unused in this function (kept for interface compatibility).
//   n_loop_real          - selects the 64-entry window inside index_ptr.
//   max_seq_q_offset, max_seq_kv_offset - forwarded to the prefetch helpers only.
//
// NOTE(review): the LDS stage toggling and the first prefetch (which does not add
// a per-slice term to the index) only line up for a single iteration of the main
// loop, i.e. kBlockK == 64 -- confirm before instantiating with a larger kBlockK.
// NOTE(review): the first half of V for slice 0 is presumably already in LDS
// stage 0 on entry (nothing here loads it) -- confirm against the caller.
//
// An earlier implementation issued inline_matrix_load_32x32_b16_lds through a
// per-call resource descriptor with a row-bound filter; that path was already
// disabled (commented out) and has been removed together with its dead locals.
template <bool PREFETCH_K, int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK,
          int WARP_M, int WARP_N, int STAGES, typename Element, typename ElementAccum, bool Is_even_MN>
__forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512_nopage_64(
    vec4_uint q_ptr,
    vec4_uint k_ptr,
    vec4_uint v_ptr,
    Element* q_lds,
    Element* k_lds,
    Element* v_lds,
    union_vec2_f16x2<Element> p_reg[(WARP_M / 16) * (kBlockK / 32)][2],
    vec4_Accum<ElementAccum> pv_reg[(kHeadDimV / kBlockN) * (WARP_M / 16) * (kBlockN / 32)][2],
    int warp_id,
    int seqlen_q_stride,
    int seqlen_k_stride,
    int seqlen_v_stride,
    int* index_ptr,
    int batch_stride,
    int page_block_size,
    int n_loop_real,
    int max_seq_q_offset = 0,
    int max_seq_kv_offset = 0)
{
    constexpr int WARP_NUM = kBlockM * kBlockN / (WARP_M * WARP_N);
    constexpr int WARP_K = 32;
    constexpr int READ_ONCE_COUNT = 32 * 32;
    // One LDS stage holds WARP_K x kHeadDimV_OPT elements: 32x32x8x2B == 16KB.
    constexpr int kHeadDimV_OPT = 256;
    constexpr int V_LDS_LOAD_NUM = (kHeadDimV_OPT * WARP_K) / READ_ONCE_COUNT;
    constexpr int V_LOAD_REQUESTS = V_LDS_LOAD_NUM / WARP_NUM;
    constexpr int ELEMENT_BYTES = sizeof(Element);

    static_assert(kBlockK >= 32, "Error: pv gemm kBlockK must be equal or greater than 32");
    static_assert(kBlockM >= WARP_M, "Error: pv gemm kBlockM must be equal or greater than WARP_M");
    static_assert(kBlockN == WARP_N, "Error: pv gemm kBlockN must be equal to WARP_N");
    // The message must be the second argument: `cond and "msg"` compiles but the
    // literal merely converts to true and the text is never reported.
    static_assert(WARP_K == 32, "Error: To simplify, only WARP_K = 32 is supported!");
    static_assert(WARP_M == 16, "Error: To simplify, only WARP_M = 16 is supported!");
    static_assert(WARP_N == 32, "Error: To simplify, only WARP_N = 32 is supported!");

    // Base offset of the V region in LDS (deliberately narrowed to int, as before).
    int v_lds_base = static_cast<int>(reinterpret_cast<size_t>(v_lds));
    int tid = threadIdx.x % 64;

    // V operand registers.
    union_vec4_f16x2<Element> v_reg[STAGES * (32 * WARP_N) / (32 * 32) * 2];

    int lds_stage_id = 1;
    for (int n_loop = 1; n_loop < (kBlockK / WARP_K); ++n_loop)
    {
        // Prefetch for the same WARP_K slice: next 32x256 half, global -> LDS.
        {
            int n_load = 1;
            // Each group of four lanes shares one gathered row index; the second
            // index sits 16 entries further on.
            int index_topk_1 = index_ptr[n_loop_real * 64 + (tid / 4)];
            int index_topk_2 = index_ptr[n_loop_real * 64 + (tid / 4) + 16];
            int lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
            int g_offset_v = n_load * WARP_NUM * 32 + warp_id * 16 + tid % 4 * 4 + index_topk_1 * seqlen_v_stride * ELEMENT_BYTES / 4;
            int lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
            int g_offset_v_add = n_load * WARP_NUM * 32 + warp_id * 16 + tid % 4 * 4 + index_topk_2 * seqlen_v_stride * ELEMENT_BYTES / 4;
            flash::wait_all_warp_arrived();
            inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
            inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
            // Second pair of requests: WARP_NUM tiles further in both global and LDS.
            g_offset_v += WARP_NUM * 32 * ELEMENT_BYTES / 4;
            g_offset_v_add += WARP_NUM * 32 * ELEMENT_BYTES / 4;
            lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
            lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
            inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
            inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
        }

        // DS: read the first V tile of this slice from the other LDS stage.
        lds_stage_id ^= 1;
        int stage_id = 0;
        flash::wait_buffer_data_arrived<true>(V_LOAD_REQUESTS * 2);
        int lds_load_offset = v_lds_base + (0 /*k_loop*/ * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
        DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false /*transpose*/);
        stage_id ^= 1;

        for (int k_loop = 1; k_loop < (kHeadDimV / kBlockN); ++k_loop)
        {
            // Wait for the special head dim: every 8th tile needs the prefetched half.
            if ((k_loop & 7) == 0x0)
            {
                flash::wait_buffer_data_arrived<true>(0);
            }
            int lds_load_offset = v_lds_base + (k_loop * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
            DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false /*transpose*/);
            flash::wait_lds_data_arrived<false>(3);

            // MMAC on the tile read in the previous step.
            flash::raise_priority();
            stage_id ^= 1;
            {
                constexpr int min_tile_k = 0;
#pragma unroll
                for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n)
                {
#pragma unroll
                    for (int min_tile_m = 0; min_tile_m < 1; ++min_tile_m)
                    {
                        int pv_tile_id = (STAGES == 2) ? k_loop - 1 : k_loop;
                        int v_tile_id = stage_id * 2 + min_tile_k;
                        pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                            p_reg[n_loop - 1][min_tile_k].f16x4,
                            v_reg[v_tile_id].f16x4[min_tile_n],
                            pv_reg[pv_tile_id][min_tile_n].f32);
                    }
                }
            }
            flash::wait_lds_data_arrived<false>(2);
            {
                constexpr int min_tile_k = 1;
#pragma unroll
                for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n)
                {
#pragma unroll
                    for (int min_tile_m = 0; min_tile_m < 1; ++min_tile_m)
                    {
                        int pv_tile_id = (STAGES == 2) ? k_loop - 1 : k_loop;
                        int v_tile_id = stage_id * 2 + min_tile_k;
                        pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                            p_reg[n_loop - 1][min_tile_k].f16x4,
                            v_reg[v_tile_id].f16x4[min_tile_n],
                            pv_reg[pv_tile_id][min_tile_n].f32);
                    }
                }
            }
            flash::lower_priority();

            // MLS for the special V head dim: refill the current LDS stage with
            // the first half of the next WARP_K slice.
            if ((k_loop & 7) == 0x0)
            {
                int index_topk_1 = index_ptr[n_loop * 32 + n_loop_real * 64 + (tid / 4)];
                int index_topk_2 = index_ptr[n_loop * 32 + n_loop_real * 64 + (tid / 4) + 16];
                int lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
                int g_offset_v = warp_id * 16 + tid % 4 * 4 + index_topk_1 * seqlen_v_stride * ELEMENT_BYTES / 4;
                int lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
                int g_offset_v_add = warp_id * 16 + tid % 4 * 4 + index_topk_2 * seqlen_v_stride * ELEMENT_BYTES / 4;
                // Guard against a warp that has not finished tile 7 and is still
                // reading V from LDS: writing V LDS now would overwrite its data.
                flash::wait_all_warp_arrived();
                inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
                inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
                g_offset_v += WARP_NUM * 32 * ELEMENT_BYTES / 4;
                g_offset_v_add += WARP_NUM * 32 * ELEMENT_BYTES / 4;
                lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
                lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
                inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
                inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
            }
        }

        stage_id ^= 1;
        // Wait DS
        flash::wait_lds_data_arrived<false>(1);
        // Last MMAC of this slice: consumes the final tile read in the loop above.
        flash::raise_priority();
        {
            constexpr int min_tile_k = 0;
#pragma unroll
            for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n)
            {
#pragma unroll
                for (int min_tile_m = 0; min_tile_m < 1; ++min_tile_m)
                {
                    int pv_tile_id = (kHeadDimV / kBlockN) - 1;
                    int v_tile_id = stage_id * 2 + min_tile_k;
                    pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                        p_reg[n_loop - 1][min_tile_k].f16x4,
                        v_reg[v_tile_id].f16x4[min_tile_n],
                        pv_reg[pv_tile_id][min_tile_n].f32);
                }
            }
        }
        flash::wait_lds_data_arrived<false>(0);
        {
            constexpr int min_tile_k = 1;
#pragma unroll
            for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n)
            {
#pragma unroll
                for (int min_tile_m = 0; min_tile_m < 1; ++min_tile_m)
                {
                    int pv_tile_id = (kHeadDimV / kBlockN) - 1;
                    int v_tile_id = stage_id * 2 + min_tile_k;
                    pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                        p_reg[n_loop - 1][min_tile_k].f16x4,
                        v_reg[v_tile_id].f16x4[min_tile_n],
                        pv_reg[pv_tile_id][min_tile_n].f32);
                }
            }
        }
        flash::lower_priority();
    }

    // Final WARP_K slice: same pipeline as above, but nothing is left to refill.
    {
        constexpr int n_loop = kBlockK / WARP_K;
        // MLS for the special V head dim: load the second half of the last slice.
        {
            constexpr int n_loop_ = n_loop - 1;
            int n_load = 1;
            lds_stage_id ^= 1;
            int index_topk_1 = index_ptr[n_loop_ * 32 + n_loop_real * 64 + (tid / 4)];
            int index_topk_2 = index_ptr[n_loop_ * 32 + n_loop_real * 64 + (tid / 4) + 16];
            int lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
            int g_offset_v = n_load * WARP_NUM * 32 + warp_id * 16 + tid % 4 * 4 + index_topk_1 * seqlen_v_stride * ELEMENT_BYTES / 4;
            int lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + warp_id * 32 * 32) * ELEMENT_BYTES / 4);
            int g_offset_v_add = n_load * WARP_NUM * 32 + warp_id * 16 + tid % 4 * 4 + index_topk_2 * seqlen_v_stride * ELEMENT_BYTES / 4;
            flash::wait_all_warp_arrived();
            inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
            inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
            g_offset_v += WARP_NUM * 32 * ELEMENT_BYTES / 4;
            g_offset_v_add += WARP_NUM * 32 * ELEMENT_BYTES / 4;
            lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
            lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
            inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
            inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
        }
        lds_stage_id ^= 1;
        int stage_id = 0;
        flash::wait_buffer_data_arrived<true>(V_LOAD_REQUESTS * 2);
        // TODO: issue this prefetch earlier.
        // DS
        int lds_load_offset = v_lds_base + (0 /*k_loop*/ * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
        DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false /*transpose*/);
        stage_id ^= 1;

        for (int k_loop = 1; k_loop < (kHeadDimV / kBlockN); ++k_loop)
        {
            // Wait for the special head dim.
            if ((k_loop & 7) == 0x0)
            {
                flash::wait_buffer_data_arrived<true>(0);
            }
            // DS
            int lds_load_offset = v_lds_base + (k_loop * 32 * 32 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
            DS_READ_MATRIX_32X32_B16(lds_load_offset, v_reg[stage_id * 2 + 0].f16, v_reg[stage_id * 2 + 1].f16, false /*transpose*/);
            flash::wait_lds_data_arrived<false>(3);

            // MMAC
            flash::raise_priority();
            stage_id ^= 1;
            {
                constexpr int min_tile_k = 0;
#pragma unroll
                for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n)
                {
#pragma unroll
                    for (int min_tile_m = 0; min_tile_m < 1; ++min_tile_m)
                    {
                        int pv_tile_id = (STAGES == 2) ? k_loop - 1 : k_loop;
                        int v_tile_id = stage_id * 2 + min_tile_k;
                        pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                            p_reg[n_loop - 1][min_tile_k].f16x4,
                            v_reg[v_tile_id].f16x4[min_tile_n],
                            pv_reg[pv_tile_id][min_tile_n].f32);
                    }
                }
            }
            flash::wait_lds_data_arrived<false>(2);
            {
                constexpr int min_tile_k = 1;
#pragma unroll
                for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n)
                {
#pragma unroll
                    for (int min_tile_m = 0; min_tile_m < 1; ++min_tile_m)
                    {
                        int pv_tile_id = (STAGES == 2) ? k_loop - 1 : k_loop;
                        int v_tile_id = stage_id * 2 + min_tile_k;
                        pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                            p_reg[n_loop - 1][min_tile_k].f16x4,
                            v_reg[v_tile_id].f16x4[min_tile_n],
                            pv_reg[pv_tile_id][min_tile_n].f32);
                    }
                }
            }
            flash::lower_priority();
        }

        stage_id ^= 1;
        flash::wait_lds_data_arrived<false>(1);
        // Last MMAC
        flash::raise_priority();
        {
            constexpr int min_tile_k = 0;
#pragma unroll
            for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n)
            {
#pragma unroll
                for (int min_tile_m = 0; min_tile_m < 1; ++min_tile_m)
                {
                    int pv_tile_id = (kHeadDimV / kBlockN) - 1;
                    int v_tile_id = stage_id * 2 + min_tile_k;
                    pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                        p_reg[n_loop - 1][min_tile_k].f16x4,
                        v_reg[v_tile_id].f16x4[min_tile_n],
                        pv_reg[pv_tile_id][min_tile_n].f32);
                }
            }
        }
        flash::wait_lds_data_arrived<false>(0);
        {
            constexpr int min_tile_k = 1;
#pragma unroll
            for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n)
            {
#pragma unroll
                for (int min_tile_m = 0; min_tile_m < 1; ++min_tile_m)
                {
                    int pv_tile_id = (kHeadDimV / kBlockN) - 1;
                    int v_tile_id = stage_id * 2 + min_tile_k;
                    pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                        p_reg[n_loop - 1][min_tile_k].f16x4,
                        v_reg[v_tile_id].f16x4[min_tile_n],
                        pv_reg[pv_tile_id][min_tile_n].f32);
                }
            }
        }
        flash::lower_priority();
    }

    // Prefetch the next Q and K tiles.
    if constexpr (PREFETCH_K)
    {
        prefetch_q_to_lds_mls_ds_576_512<kHeadDim, kBlockM, kBlockK, WARP_M, Element, Is_even_MN>(
            q_ptr, q_lds, warp_id, seqlen_q_stride, max_seq_q_offset);
        prefetch_k_to_lds_mls_ds_576_512<kHeadDim, kBlockK, kBlockN, WARP_NUM, WARP_N, Element, Is_even_MN>(
            k_ptr, k_lds, warp_id, seqlen_k_stride, max_seq_kv_offset - kBlockK);
    }
}
template
<
bool
PREFETCH_K
,
int
kHeadDim
,
int
kHeadDimV
,
int
kBlockM
,
int
kBlockN
,
int
kBlockK
,
int
WARP_M
,
int
WARP_N
,
int
STAGES
,
typename
Element
,
typename
ElementAccum
,
bool
Is_even_MN
>
__forceinline__
__device__
void
pv_gemm_prefetch_k_mls_ds_576_512_nopage_64_new
(
Element
*
k_faker
,
vec4_uint
k_ptr
,
vec4_uint
v_ptr
,
Element
*
q_lds
,
Element
*
k_lds
,
Element
*
v_lds
,
union_vec2_f16x2
<
Element
>
p_reg
[(
WARP_M
/
16
)
*
(
kBlockK
/
32
)][
2
],
vec4_Accum
<
ElementAccum
>
pv_reg
[(
kHeadDimV
/
kBlockN
)
*
(
WARP_M
/
16
)
*
(
kBlockN
/
32
)][
2
],
int
warp_id
,
int
seqlen_q_stride
,
int
seqlen_k_stride
,
int
seqlen_v_stride
,
int
*
index_ptr
,
int
batch_stride
,
int
page_block_size
,
int
n_loop_real
,
int
max_seq_q_offset
=
0
,
int
max_seq_kv_offset
=
0
)
{
constexpr
int
WARP_NUM
=
kBlockM
*
kBlockN
/
(
WARP_M
*
WARP_N
);
constexpr
int
WARP_K
=
32
;
constexpr
int
READ_ONCE_COUNT
=
32
*
32
;
constexpr
int
kHeadDimV_OPT
=
128
;
// lds 32x32x8x2B == 16KB
constexpr
int
V_LDS_LOAD_NUM
=
(
kHeadDimV_OPT
*
WARP_K
)
/
READ_ONCE_COUNT
;
constexpr
int
V_LOAD_REQUESTS
=
V_LDS_LOAD_NUM
/
WARP_NUM
;
constexpr
int
ELEMENT_BYTES
=
sizeof
(
Element
);
static_assert
(
kBlockK
>=
32
,
"Error: pv gemm kBlockK must be equal or greater than 32"
);
static_assert
(
kBlockM
>=
WARP_M
,
"Error: pv gemm kBlockM must be equal or greater than WARP_M"
);
static_assert
(
kBlockN
==
WARP_N
,
"Error: pv gemm kBlockN must be equal to WARP_N"
);
static_assert
(
WARP_K
==
32
and
"Error: To simplify, only WARP_K = 32 is supported!"
);
static_assert
(
WARP_M
==
16
and
"Error: To simplify, only WARP_M = 16 is supported!"
);
static_assert
(
WARP_N
==
32
and
"Error: To simplify, only WARP_N = 32 is supported!"
);
// 计算 V lds 起始偏移量
int
v_lds_base
=
reinterpret_cast
<
size_t
>
(
v_lds
);
int
tid
=
threadIdx
.
x
%
64
;
// 准备 V 寄存器
union_vec4_f16x2
<
Element
>
v_reg
[
STAGES
*
(
32
*
WARP_N
)
/
(
32
*
32
)
*
2
];
// MLS
vec4_uint
v_srsrc
;
v_srsrc
[
0
]
=
v_ptr
[
0
];
v_srsrc
[
1
]
=
v_ptr
[
1
];
v_srsrc
[
2
]
=
seqlen_v_stride
;
// stride
v_srsrc
[
3
]
=
0
;
int
index_topk_1
[
2
];
index_topk_1
[
0
]
=
index_ptr
[((
n_loop_real
*
64
))
+
(
tid
/
4
)];
index_topk_1
[
1
]
=
index_ptr
[((
n_loop_real
*
64
))
+
(
tid
/
4
)
+
32
];
int
index_topk_2
[
2
];
index_topk_2
[
0
]
=
index_ptr
[((
n_loop_real
*
64
))
+
(
tid
/
4
)
+
16
];
index_topk_2
[
1
]
=
index_ptr
[((
n_loop_real
*
64
))
+
(
tid
/
4
)
+
48
];
// int index_topk_2 = index_ptr[((n_loop_real * 64)) + (tid / 4) + 16];
int
lds_offset
=
__builtin_amdgcn_readfirstlane
((
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
/
4
);
int
g_offset_v
=
tid
%
4
*
4
+
index_topk_1
[
0
]
*
seqlen_v_stride
*
ELEMENT_BYTES
/
4
;
int
g_offset_s
=
warp_id
*
16
;
int
lds_offset_add
=
__builtin_amdgcn_readfirstlane
((
32
*
16
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
/
4
);
int
g_offset_v_add
=
tid
%
4
*
4
+
index_topk_2
[
0
]
*
seqlen_v_stride
*
ELEMENT_BYTES
/
4
;
int
g_offset_s_add
=
warp_id
*
16
;
flash
::
wait_all_warp_arrived
();
inline_buffer_load_dwordx4_lds
(
v_lds
,
v_ptr
,
lds_offset
,
g_offset_s
,
g_offset_v
);
inline_buffer_load_dwordx4_lds
(
v_lds
,
v_ptr
,
lds_offset_add
,
g_offset_s_add
,
g_offset_v_add
);
// __builtin_amdgcn_sched_barrier(0);
int
lds_stage_id
=
1
;
for
(
int
n_loop
=
1
;
n_loop
<
(
kBlockK
/
WARP_K
);
++
n_loop
)
{
// prefetch same warpk, next 32x256 G2S
{
int
n_load
=
1
;
int
n_loop_
=
n_loop
-
1
;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + n_load * WARP_NUM * 32) * ELEMENT_BYTES);
// if constexpr (true) {
// int nm_filter = inline_min_max<0, 32>(n_loop_ * WARP_K + 32 - max_seq_kv_offset);
// v_srsrc[3] = max_seq_kv_offset % kBlockK == 0 ? 0: nm_filter << 8;
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived();
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
lds_offset
=
__builtin_amdgcn_readfirstlane
((
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
/
4
);
g_offset_v
=
tid
%
4
*
4
+
index_topk_1
[
0
]
*
seqlen_v_stride
*
ELEMENT_BYTES
/
4
;
g_offset_s
=
n_load
*
WARP_NUM
*
16
+
warp_id
*
16
;
lds_offset_add
=
__builtin_amdgcn_readfirstlane
((
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
+
32
*
16
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
/
4
);
g_offset_v_add
=
tid
%
4
*
4
+
index_topk_2
[
0
]
*
seqlen_v_stride
*
ELEMENT_BYTES
/
4
;
g_offset_s_add
=
n_load
*
WARP_NUM
*
16
+
warp_id
*
16
;
flash
::
wait_all_warp_arrived
();
inline_buffer_load_dwordx4_lds
(
v_lds
,
v_ptr
,
lds_offset
,
g_offset_s
,
g_offset_v
);
inline_buffer_load_dwordx4_lds
(
v_lds
,
v_ptr
,
lds_offset_add
,
g_offset_s_add
,
g_offset_v_add
);
// g_offset_v += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// g_offset_v_add += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
}
// DS
lds_stage_id
^=
1
;
int
stage_id
=
0
;
flash
::
wait_buffer_data_arrived
<
true
>
(
V_LOAD_REQUESTS
*
2
);
int
lds_load_offset
=
v_lds_base
+
(
0
/*k_loop*/
*
32
*
32
+
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
)
*
ELEMENT_BYTES
;
// v_reg[stage_id * 2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(lds_load_offset, 0, 2, 1, 0);
// v_reg[stage_id * 2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(lds_load_offset, 1024, 2, 1, 0);
DS_READ_MATRIX_32X32_B16
(
lds_load_offset
,
v_reg
[
stage_id
*
2
+
0
].
f16
,
v_reg
[
stage_id
*
2
+
1
].
f16
,
false
/*transpose*/
);
stage_id
^=
1
;
for
(
int
k_loop
=
1
;
k_loop
<
(
kHeadDimV
/
kBlockN
);
++
k_loop
)
{
// Wait for special headdim
if
((
k_loop
&
3
)
==
0x0
)
{
flash
::
wait_buffer_data_arrived
<
true
>
(
0
);
lds_stage_id
^=
1
;
}
int
lds_load_offset
=
v_lds_base
+
((
k_loop
&
3
)
*
32
*
32
+
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
)
*
ELEMENT_BYTES
;
// v_reg[stage_id * 2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(lds_load_offset, 0, 2, 1, 0);
// v_reg[stage_id * 2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(lds_load_offset, 1024, 2, 1, 0);
DS_READ_MATRIX_32X32_B16
(
lds_load_offset
,
v_reg
[
stage_id
*
2
+
0
].
f16
,
v_reg
[
stage_id
*
2
+
1
].
f16
,
false
/*transpose*/
);
flash
::
wait_lds_data_arrived
<
false
>
(
3
);
// MMAC
flash
::
raise_priority
();
stage_id
^=
1
;
{
constexpr
int
min_tile_k
=
0
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
STAGES
==
2
)
?
k_loop
-
1
:
k_loop
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
wait_lds_data_arrived
<
false
>
(
2
);
{
constexpr
int
min_tile_k
=
1
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
STAGES
==
2
)
?
k_loop
-
1
:
k_loop
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
lower_priority
();
// MLS for special headdimV
if
((
k_loop
&
3
)
==
0x0
)
{
int
n_loop_
=
n_loop
;
lds_stage_id
^=
1
;
// if constexpr (Is_even_MN) {
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// } else {
// int nm_filter_max = n_loop_ * WARP_K + 32 - max_seq_kv_offset;
// int real_mls_loop = nm_filter_max >= 32 ? 0: n_loop_; // 如果全越界了, 则只访问 n_loop = 0 的那波数据
// int nm_filter = inline_min_max<0, 32>(real_mls_loop * WARP_K + 32 - max_seq_kv_offset); // 重新计算 nm_filter
// v_srsrc[3] = nm_filter << 8;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (real_mls_loop * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived(); // 预防有 warp 还没算完7,还在读 v lds, 若是此时写 v lds,则 data cover
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
// index_topk_1 = index_ptr[k_loop / 12 * 32 + (n_loop_real * 64) + (tid / 4)];
// index_topk_2 = index_ptr[k_loop / 12 * 32 + (n_loop_real * 64) + (tid / 4) + 16];
lds_offset
=
__builtin_amdgcn_readfirstlane
((
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
/
4
);
g_offset_v
=
tid
%
4
*
4
+
index_topk_1
[
k_loop
/
12
]
*
seqlen_v_stride
*
ELEMENT_BYTES
/
4
;
g_offset_s
=
((
k_loop
+
4
)
&
15
)
*
16
+
warp_id
*
16
;
lds_offset_add
=
__builtin_amdgcn_readfirstlane
((
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
+
32
*
16
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
/
4
);
g_offset_v_add
=
tid
%
4
*
4
+
index_topk_2
[
k_loop
/
12
]
*
seqlen_v_stride
*
ELEMENT_BYTES
/
4
;
g_offset_s_add
=
((
k_loop
+
4
)
&
15
)
*
16
+
warp_id
*
16
;
flash
::
wait_all_warp_arrived
();
inline_buffer_load_dwordx4_lds
(
v_lds
,
v_ptr
,
lds_offset
,
g_offset_s
,
g_offset_v
);
inline_buffer_load_dwordx4_lds
(
v_lds
,
v_ptr
,
lds_offset_add
,
g_offset_s_add
,
g_offset_v_add
);
// g_offset_v += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// g_offset_v_add += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
lds_stage_id
^=
1
;
}
}
stage_id
^=
1
;
// Wait DS
flash
::
wait_lds_data_arrived
<
false
>
(
1
);
// last mmac
flash
::
raise_priority
();
{
constexpr
int
min_tile_k
=
0
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
kHeadDimV
/
kBlockN
)
-
1
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
wait_lds_data_arrived
<
false
>
(
0
);
{
constexpr
int
min_tile_k
=
1
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
kHeadDimV
/
kBlockN
)
-
1
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
lower_priority
();
}
{
constexpr
int
n_loop
=
kBlockK
/
WARP_K
;
// MLS for special headdimV
{
constexpr
int
n_loop_
=
n_loop
-
1
;
int
n_load
=
1
;
// lds_stage_id ^= 1;
// if constexpr (Is_even_MN) {
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + n_load * WARP_NUM * 32) * ELEMENT_BYTES);
// } else {
// int nm_filter_max = n_loop_ * WARP_K + 32 - max_seq_kv_offset;
// int real_mls_loop = nm_filter_max >= 32 ? 0: n_loop_; // 如果全越界了, 则只访问 n_loop = 0 的那波数据
// int nm_filter = inline_min_max<0, 32>(real_mls_loop * WARP_K + 32 - max_seq_kv_offset); // 重新计算 nm_filter
// v_srsrc[3] = nm_filter << 8;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (real_mls_loop * WARP_K * seqlen_v_stride + warp_id * 32 + n_load * WARP_NUM * 32) * ELEMENT_BYTES);
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived();
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
// int index_topk_1 = index_ptr[n_loop_ * 32 + (n_loop_real * 64) + (tid / 4)];
// int index_topk_2 = index_ptr[n_loop_ * 32 + (n_loop_real * 64) + (tid / 4) + 16];
int
lds_offset
=
__builtin_amdgcn_readfirstlane
((
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
/
4
);
int
g_offset_v
=
tid
%
4
*
4
+
index_topk_1
[
1
]
*
seqlen_v_stride
*
ELEMENT_BYTES
/
4
;
int
g_offset_s
=
n_load
*
WARP_NUM
*
16
+
warp_id
*
16
;
int
lds_offset_add
=
__builtin_amdgcn_readfirstlane
((
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
+
32
*
16
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
/
4
);
int
g_offset_v_add
=
tid
%
4
*
4
+
index_topk_2
[
1
]
*
seqlen_v_stride
*
ELEMENT_BYTES
/
4
;
int
g_offset_s_add
=
n_load
*
WARP_NUM
*
16
+
warp_id
*
16
;
flash
::
wait_all_warp_arrived
();
inline_buffer_load_dwordx4_lds
(
v_lds
,
v_ptr
,
lds_offset
,
g_offset_s
,
g_offset_v
);
inline_buffer_load_dwordx4_lds
(
v_lds
,
v_ptr
,
lds_offset_add
,
g_offset_s_add
,
g_offset_v_add
);
// g_offset_v += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// g_offset_v_add += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
}
lds_stage_id
^=
1
;
int
stage_id
=
0
;
flash
::
wait_buffer_data_arrived
<
true
>
(
V_LOAD_REQUESTS
*
2
);
// [TODO]更早的预取
// DS
int
lds_load_offset
=
v_lds_base
+
(
0
/*k_loop*/
*
32
*
32
+
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
)
*
ELEMENT_BYTES
;
// v_reg[stage_id * 2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(lds_load_offset, 0, 2, 1, 0);
// v_reg[stage_id * 2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(lds_load_offset, 1024, 2, 1, 0);
DS_READ_MATRIX_32X32_B16
(
lds_load_offset
,
v_reg
[
stage_id
*
2
+
0
].
f16
,
v_reg
[
stage_id
*
2
+
1
].
f16
,
false
/*transpose*/
);
stage_id
^=
1
;
for
(
int
k_loop
=
1
;
k_loop
<
(
kHeadDimV
/
kBlockN
);
++
k_loop
)
{
// Wait for special headdim
if
((
k_loop
&
3
)
==
0x0
)
{
flash
::
wait_buffer_data_arrived
<
true
>
(
0
);
lds_stage_id
^=
1
;
}
// DS
int
lds_load_offset
=
v_lds_base
+
((
k_loop
&
3
)
*
32
*
32
+
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
)
*
ELEMENT_BYTES
;
// v_reg[stage_id * 2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(lds_load_offset, 0, 2, 1, 0);
// v_reg[stage_id * 2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_format_f16(lds_load_offset, 1024, 2, 1, 0);
DS_READ_MATRIX_32X32_B16
(
lds_load_offset
,
v_reg
[
stage_id
*
2
+
0
].
f16
,
v_reg
[
stage_id
*
2
+
1
].
f16
,
false
/*transpose*/
);
flash
::
wait_lds_data_arrived
<
false
>
(
3
);
// MMAC
flash
::
raise_priority
();
stage_id
^=
1
;
{
constexpr
int
min_tile_k
=
0
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
STAGES
==
2
)
?
k_loop
-
1
:
k_loop
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
wait_lds_data_arrived
<
false
>
(
2
);
{
constexpr
int
min_tile_k
=
1
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
STAGES
==
2
)
?
k_loop
-
1
:
k_loop
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
lower_priority
();
// MLS for special headdimV
if
(
k_loop
==
4
||
k_loop
==
8
)
{
lds_stage_id
^=
1
;
// if constexpr (Is_even_MN) {
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// } else {
// int nm_filter_max = n_loop_ * WARP_K + 32 - max_seq_kv_offset;
// int real_mls_loop = nm_filter_max >= 32 ? 0: n_loop_; // 如果全越界了, 则只访问 n_loop = 0 的那波数据
// int nm_filter = inline_min_max<0, 32>(real_mls_loop * WARP_K + 32 - max_seq_kv_offset); // 重新计算 nm_filter
// v_srsrc[3] = nm_filter << 8;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (real_mls_loop * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived(); // 预防有 warp 还没算完7,还在读 v lds, 若是此时写 v lds,则 data cover
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
// int index_topk_1 = index_ptr[32 + (n_loop_real * 64) + (tid / 4)];
// int index_topk_2 = index_ptr[32 + (n_loop_real * 64) + (tid / 4) + 16];
int
lds_offset
=
__builtin_amdgcn_readfirstlane
((
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
/
4
);
int
g_offset_v
=
tid
%
4
*
4
+
index_topk_1
[
1
]
*
seqlen_v_stride
*
ELEMENT_BYTES
/
4
;
int
g_offset_s
=
((
k_loop
+
4
)
&
15
)
*
16
+
warp_id
*
16
;
int
lds_offset_add
=
__builtin_amdgcn_readfirstlane
((
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
+
32
*
16
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
/
4
);
int
g_offset_v_add
=
tid
%
4
*
4
+
index_topk_2
[
1
]
*
seqlen_v_stride
*
ELEMENT_BYTES
/
4
;
int
g_offset_s_add
=
((
k_loop
+
4
)
&
15
)
*
16
+
warp_id
*
16
;
flash
::
wait_all_warp_arrived
();
inline_buffer_load_dwordx4_lds
(
v_lds
,
v_ptr
,
lds_offset
,
g_offset_s
,
g_offset_v
);
inline_buffer_load_dwordx4_lds
(
v_lds
,
v_ptr
,
lds_offset_add
,
g_offset_s_add
,
g_offset_v_add
);
// g_offset_v += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// g_offset_v_add += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
lds_stage_id
^=
1
;
}
}
stage_id
^=
1
;
flash
::
wait_lds_data_arrived
<
false
>
(
1
);
// last mmac
flash
::
raise_priority
();
{
constexpr
int
min_tile_k
=
0
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
kHeadDimV
/
kBlockN
)
-
1
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
lower_priority
();
flash
::
wait_lds_data_arrived
<
false
>
(
0
);
flash
::
raise_priority
();
int
abc
[
1
];
int
index_topk
=
index_ptr
[(((
n_loop_real
+
1
)
%
16
)
*
64
)
+
warp_id
*
16
];
int
offset_m
=
index_topk
*
seqlen_k_stride
;
auto
g_abc
=
(
reinterpret_cast
<
uint64_t
>
(
k_faker
+
offset_m
));
inline_s_load_dword
(
abc
[
0
],
g_abc
,
0
);
{
constexpr
int
min_tile_k
=
1
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
kHeadDimV
/
kBlockN
)
-
1
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
lower_priority
();
}
// 预取Q K
// if constexpr (PREFETCH_K) {
// prefetch_q_to_lds_mls_ds_576_512<kHeadDim, kBlockM, kBlockK, WARP_M, Element, Is_even_MN>(q_ptr, q_lds, warp_id, seqlen_q_stride, max_seq_q_offset);
// prefetch_k_to_lds_mls_ds_576_512<kHeadDim, kBlockK, kBlockN, WARP_NUM, WARP_N, Element, Is_even_MN>(k_ptr, k_lds, warp_id, seqlen_k_stride, max_seq_kv_offset - kBlockK);
// }
}
// PV GEMM step for the index-gathered ("nopage", 64-wide) sparse path.
//
// Every lane reads one row index from
//     index_ptr[n_loop_real * 64 + warp_id * 16 + tid / 4]
// and uses it (scaled by seqlen_v_stride) as its per-lane global offset for all
// V loads issued by this call. V is streamed through a two-stage LDS ping-pong
// buffer, one kBlockN-wide column tile at a time:
//   1. two inline_buffer_load_dwordx4_lds requests fetch the next tile into
//      the LDS stage that is not currently being read;
//   2. four DS_READ_MATRIX_32X32_B16 calls pull the current stage into v_reg;
//   3. mmac_4interleave combines p_reg[0] / p_reg[1] with the V fragments and
//      accumulates into pv_reg[2 * tile] and pv_reg[2 * tile + 1].
// The last tile is consumed after the loop by the same read + MMAC sequence.
//
// Parameters that are read:
//   v_ptr            buffer descriptor handed to inline_buffer_load_dwordx4_lds
//   v_lds            LDS destination of the V loads / source of the DS reads
//   p_reg            P fragments; rows 0 and 1 are consumed
//   pv_reg           accumulators, updated in place
//   warp_id          selects this warp's slice of the index list and of LDS
//   seqlen_v_stride  multiplier applied to the gathered row index
//   index_ptr        row-index list
//   n_loop_real      selects the 64-entry group inside index_ptr
// k_faker, k_ptr, q_lds, k_lds, seqlen_q_stride, seqlen_k_stride, batch_stride,
// max_seq_q_offset and max_seq_kv_offset are accepted but not used by this body.
//
// NOTE(review): unlike the _new_999 variant further down in this file, the
// gathered index is used as-is (no -1 sentinel substitution) -- confirm callers
// never pass -1 entries here.
// NOTE(review): p_reg[1] is read unconditionally, which needs
// (WARP_M / 16) * (kBlockK / 32) >= 2, i.e. kBlockK >= 64 -- the kBlockK >= 32
// assertion below is weaker than that; confirm against the instantiations.
template <bool PREFETCH_K,
          int kHeadDim,
          int kHeadDimV,
          int kBlockM,
          int kBlockN,
          int kBlockK,
          int WARP_M,
          int WARP_N,
          int STAGES,
          typename Element,
          typename ElementAccum,
          bool Is_even_MN>
__forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512_nopage_64_new_666(
    Element* k_faker,
    vec4_uint k_ptr,
    vec4_uint v_ptr,
    Element* q_lds,
    Element* k_lds,
    Element* v_lds,
    union_vec2_f16x2<Element> p_reg[(WARP_M / 16) * (kBlockK / 32)][2],
    vec4_Accum<ElementAccum> pv_reg[(kHeadDimV / kBlockN) * (WARP_M / 16) * (kBlockN / 32)][2],
    int warp_id,
    int seqlen_q_stride,
    int seqlen_k_stride,
    int seqlen_v_stride,
    int* index_ptr,
    int batch_stride,
    int n_loop_real,
    int max_seq_q_offset = 0,
    int max_seq_kv_offset = 0)
{
    constexpr int WARP_NUM = kBlockM * kBlockN / (WARP_M * WARP_N);
    constexpr int WARP_K = 64;
    constexpr int READ_ONCE_COUNT = 16 * 32;
    constexpr int kHeadDimV_OPT = 64;
    // lds 32x32x8x2B == 16KB
    constexpr int V_LDS_LOAD_NUM = (kHeadDimV_OPT * WARP_K) / READ_ONCE_COUNT;
    // Number of buffer-load requests each warp issues per V tile.
    constexpr int V_LOAD_REQUESTS = V_LDS_LOAD_NUM / WARP_NUM;
    constexpr int ELEMENT_BYTES = sizeof(Element);

    static_assert(kBlockK >= 32, "Error: pv gemm kBlockK must be equal or greater than 32");
    static_assert(kBlockM >= WARP_M, "Error: pv gemm kBlockM must be equal or greater than WARP_M");
    static_assert(kBlockN == WARP_N, "Error: pv gemm kBlockN must be equal to WARP_N");
    // The message is a separate static_assert argument (not `and`-ed into the
    // condition) so that it is actually printed when the check fails, and it
    // names the value that is really required.
    static_assert(WARP_K == 64, "Error: To simplify, only WARP_K = 64 is supported!");
    static_assert(WARP_M == 16, "Error: To simplify, only WARP_M = 16 is supported!");
    static_assert(WARP_N == 64, "Error: To simplify, only WARP_N = 64 is supported!");

    // Base LDS offset of the V buffer.
    int v_lds_base = reinterpret_cast<size_t>(v_lds);
    int tid = threadIdx.x % 64;

    // V register staging. The code below indexes entries 0..7
    // (stage_id * 2 + {0, 1} and 4 + stage_id * 2 + {0, 1}).
    // NOTE(review): with WARP_N == 64 this array has STAGES * 4 entries, so the
    // indexing is only in bounds for STAGES >= 2 -- confirm the instantiations.
    union_vec4_f16x2<Element> v_reg[STAGES * (32 * WARP_N) / (32 * 32) * 2];

    // Row index gathered by this lane; four consecutive lanes share one entry.
    int index_topk = index_ptr[(n_loop_real * 64) + warp_id * 16 + (tid / 4)];

    // LDS destinations (in dwords) for the two load requests of tile 0, stage 0.
    int lds_offset = __builtin_amdgcn_readfirstlane((warp_id * 32 * 16) * ELEMENT_BYTES / 4);
    int lds_offset_2 = __builtin_amdgcn_readfirstlane((warp_id * 32 * 16 + 32 * 64) * ELEMENT_BYTES / 4);
    // Per-lane global offset: position inside the 4-lane group plus the gathered row.
    int g_offset_v = tid % 4 * 4 + index_topk * seqlen_v_stride * ELEMENT_BYTES / 4;
    // Tile-dependent offsets; the second request is shifted by 16.
    int g_offset_s = 0 * WARP_N * ELEMENT_BYTES / 4 + 0;
    int g_offset_s_2 = 0 * WARP_N * ELEMENT_BYTES / 4 + 16;

    // Prologue: fetch tile 0 into LDS stage 0.
    flash::wait_all_warp_arrived();
    inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, g_offset_s, g_offset_v);
    inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_2, g_offset_s_2, g_offset_v);

    // Stage that the next global->LDS load will write.
    int lds_stage_id = 1;

    for (int k_loop = 1; k_loop < (kHeadDimV / kBlockN); ++k_loop) {
        // Issue the loads for tile k_loop into the idle LDS stage.
        {
            lds_offset = __builtin_amdgcn_readfirstlane(
                (lds_stage_id * WARP_K * WARP_N + warp_id * 32 * 16) * ELEMENT_BYTES / 4);
            lds_offset_2 = __builtin_amdgcn_readfirstlane(
                (lds_stage_id * WARP_K * WARP_N + warp_id * 32 * 16 + 32 * 64) * ELEMENT_BYTES / 4);
            g_offset_s = k_loop * WARP_N * ELEMENT_BYTES / 4;
            g_offset_s_2 = k_loop * WARP_N * ELEMENT_BYTES / 4 + 16;
            flash::wait_all_warp_arrived();
            inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, g_offset_s, g_offset_v);
            inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_2, g_offset_s_2, g_offset_v);
        }

        // Asymmetric MLS wait: presumably returns once at most V_LOAD_REQUESTS
        // buffer loads (the ones just issued) are still in flight, i.e. tile
        // k_loop - 1 has landed -- confirm against wait_buffer_data_arrived.
        flash::wait_buffer_data_arrived<true>(V_LOAD_REQUESTS);
        lds_stage_id ^= 1;  // switch to the stage holding tile k_loop - 1
        int stage_id = 0;

        // DS reads, first 32x32 block: v_reg[0..1] and v_reg[4..5].
        {
            int v_lds_load_offset = v_lds_base + (lds_stage_id * WARP_N * WARP_K + 0 * 32 * 32) * ELEMENT_BYTES;
            int v_lds_load_offset_2 = v_lds_base + (lds_stage_id * WARP_N * WARP_K + 0 * 32 * 32 + 32 * 64) * ELEMENT_BYTES;
            DS_READ_MATRIX_32X32_B16(v_lds_load_offset, v_reg[stage_id * 2].f16, v_reg[stage_id * 2 + 1].f16, false);
            DS_READ_MATRIX_32X32_B16(v_lds_load_offset_2, v_reg[4 + stage_id * 2].f16, v_reg[4 + stage_id * 2 + 1].f16, false);
        }
        // DS reads, second 32x32 block (prefetch): v_reg[2..3] and v_reg[6..7].
        stage_id ^= 1;
        {
            int v_lds_load_offset = v_lds_base + (lds_stage_id * WARP_N * WARP_K + 1 * 32 * 32) * ELEMENT_BYTES;
            int v_lds_load_offset_2 = v_lds_base + (lds_stage_id * WARP_N * WARP_K + 1 * 32 * 32 + 32 * 64) * ELEMENT_BYTES;
            DS_READ_MATRIX_32X32_B16(v_lds_load_offset, v_reg[stage_id * 2].f16, v_reg[stage_id * 2 + 1].f16, false);
            DS_READ_MATRIX_32X32_B16(v_lds_load_offset_2, v_reg[4 + stage_id * 2].f16, v_reg[4 + stage_id * 2 + 1].f16, false);
        }

        // Wait for the first DS read, then MMAC: p_reg[0] x v_reg[0..1].
        flash::wait_lds_data_arrived<false>(6);
        flash::raise_priority();
        stage_id ^= 1;
        #pragma unroll
        for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
            #pragma unroll
            for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                int pv_tile_id = (k_loop - 1) * 2 + 0;
                int v_tile_id = stage_id * 2 + min_tile_k;
                pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                    p_reg[0][min_tile_k].f16x4,
                    v_reg[v_tile_id].f16x4[min_tile_n],
                    pv_reg[pv_tile_id][min_tile_n].f32);
            }
        }
        flash::lower_priority();

        // Second DS read done: p_reg[0] x v_reg[4..5].
        flash::wait_lds_data_arrived<false>(4);
        flash::raise_priority();
        #pragma unroll
        for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
            #pragma unroll
            for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                int pv_tile_id = (k_loop - 1) * 2 + 1;
                int v_tile_id = 4 + stage_id * 2 + min_tile_k;
                pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                    p_reg[0][min_tile_k].f16x4,
                    v_reg[v_tile_id].f16x4[min_tile_n],
                    pv_reg[pv_tile_id][min_tile_n].f32);
            }
        }
        flash::lower_priority();

        // Third DS read done: p_reg[1] x v_reg[2..3].
        flash::wait_lds_data_arrived<false>(2);
        flash::raise_priority();
        stage_id ^= 1;
        #pragma unroll
        for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
            #pragma unroll
            for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                int pv_tile_id = (k_loop - 1) * 2 + 0;
                int v_tile_id = stage_id * 2 + min_tile_k;
                pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                    p_reg[1][min_tile_k].f16x4,
                    v_reg[v_tile_id].f16x4[min_tile_n],
                    pv_reg[pv_tile_id][min_tile_n].f32);
            }
        }
        flash::lower_priority();

        // Fourth DS read done: p_reg[1] x v_reg[6..7].
        flash::wait_lds_data_arrived<false>(0);
        flash::raise_priority();
        #pragma unroll
        for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
            #pragma unroll
            for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                int pv_tile_id = (k_loop - 1) * 2 + 1;
                int v_tile_id = 4 + stage_id * 2 + min_tile_k;
                pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                    p_reg[1][min_tile_k].f16x4,
                    v_reg[v_tile_id].f16x4[min_tile_n],
                    pv_reg[pv_tile_id][min_tile_n].f32);
            }
        }
        flash::lower_priority();
    }

    // Epilogue: wait for the last panel to arrive, then consume it with the
    // same read + MMAC sequence as one loop iteration.
    flash::wait_buffer_data_arrived<true>(0);
    int k_loop = kHeadDimV / kBlockN;
    lds_stage_id ^= 1;
    int stage_id = 0;

    // DS reads, first 32x32 block: v_reg[0..1] and v_reg[4..5].
    {
        int v_lds_load_offset = v_lds_base + (lds_stage_id * WARP_N * WARP_K + 0 * 32 * 32) * ELEMENT_BYTES;
        int v_lds_load_offset_2 = v_lds_base + (lds_stage_id * WARP_N * WARP_K + 0 * 32 * 32 + 32 * 64) * ELEMENT_BYTES;
        DS_READ_MATRIX_32X32_B16(v_lds_load_offset, v_reg[stage_id * 2].f16, v_reg[stage_id * 2 + 1].f16, false);
        DS_READ_MATRIX_32X32_B16(v_lds_load_offset_2, v_reg[4 + stage_id * 2].f16, v_reg[4 + stage_id * 2 + 1].f16, false);
    }
    // DS reads, second 32x32 block (prefetch): v_reg[2..3] and v_reg[6..7].
    stage_id ^= 1;
    {
        int v_lds_load_offset = v_lds_base + (lds_stage_id * WARP_N * WARP_K + 1 * 32 * 32) * ELEMENT_BYTES;
        int v_lds_load_offset_2 = v_lds_base + (lds_stage_id * WARP_N * WARP_K + 1 * 32 * 32 + 32 * 64) * ELEMENT_BYTES;
        DS_READ_MATRIX_32X32_B16(v_lds_load_offset, v_reg[stage_id * 2].f16, v_reg[stage_id * 2 + 1].f16, false);
        DS_READ_MATRIX_32X32_B16(v_lds_load_offset_2, v_reg[4 + stage_id * 2].f16, v_reg[4 + stage_id * 2 + 1].f16, false);
    }

    // Wait for the first DS read, then MMAC: p_reg[0] x v_reg[0..1].
    flash::wait_lds_data_arrived<false>(6);
    flash::raise_priority();
    stage_id ^= 1;
    #pragma unroll
    for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
        #pragma unroll
        for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
            int pv_tile_id = (k_loop - 1) * 2 + 0;
            int v_tile_id = stage_id * 2 + min_tile_k;
            pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                p_reg[0][min_tile_k].f16x4,
                v_reg[v_tile_id].f16x4[min_tile_n],
                pv_reg[pv_tile_id][min_tile_n].f32);
        }
    }
    flash::lower_priority();

    // Second DS read done: p_reg[0] x v_reg[4..5].
    flash::wait_lds_data_arrived<false>(4);
    flash::raise_priority();
    #pragma unroll
    for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
        #pragma unroll
        for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
            int pv_tile_id = (k_loop - 1) * 2 + 1;
            int v_tile_id = 4 + stage_id * 2 + min_tile_k;
            pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                p_reg[0][min_tile_k].f16x4,
                v_reg[v_tile_id].f16x4[min_tile_n],
                pv_reg[pv_tile_id][min_tile_n].f32);
        }
    }
    flash::lower_priority();

    // Third DS read done: p_reg[1] x v_reg[2..3].
    flash::wait_lds_data_arrived<false>(2);
    flash::raise_priority();
    stage_id ^= 1;
    #pragma unroll
    for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
        #pragma unroll
        for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
            int pv_tile_id = (k_loop - 1) * 2 + 0;
            int v_tile_id = stage_id * 2 + min_tile_k;
            pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                p_reg[1][min_tile_k].f16x4,
                v_reg[v_tile_id].f16x4[min_tile_n],
                pv_reg[pv_tile_id][min_tile_n].f32);
        }
    }
    flash::lower_priority();

    // Fourth DS read done: p_reg[1] x v_reg[6..7].
    flash::wait_lds_data_arrived<false>(0);
    flash::raise_priority();
    #pragma unroll
    for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
        #pragma unroll
        for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
            int pv_tile_id = (k_loop - 1) * 2 + 1;
            int v_tile_id = 4 + stage_id * 2 + min_tile_k;
            pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                p_reg[1][min_tile_k].f16x4,
                v_reg[v_tile_id].f16x4[min_tile_n],
                pv_reg[pv_tile_id][min_tile_n].f32);
        }
    }
    flash::lower_priority();
}
template
<
bool
PREFETCH_K
,
int
kHeadDim
,
int
kHeadDimV
,
int
kBlockM
,
int
kBlockN
,
int
kBlockK
,
int
WARP_M
,
int
WARP_N
,
int
STAGES
,
typename
Element
,
typename
ElementAccum
,
bool
Is_even_MN
>
__forceinline__
__device__
void
pv_gemm_prefetch_k_mls_ds_576_512_nopage_64_new_999
(
Element
*
k_faker
,
vec4_uint
k_ptr
,
vec4_uint
v_ptr
,
Element
*
q_lds
,
Element
*
k_lds
,
Element
*
v_lds
,
union_vec2_f16x2
<
Element
>
p_reg
[(
WARP_M
/
16
)
*
(
kBlockK
/
32
)][
2
],
vec4_Accum
<
ElementAccum
>
pv_reg
[(
kHeadDimV
/
kBlockN
)
*
(
WARP_M
/
16
)
*
(
kBlockN
/
32
)][
2
],
int
warp_id
,
int
seqlen_q_stride
,
int
seqlen_k_stride
,
int
seqlen_v_stride
,
int
*
index_ptr
,
int
batch_stride
,
int
page_block_size
,
int
n_loop_real
,
int
max_seq_q_offset
=
0
,
int
max_seq_kv_offset
=
0
)
{
constexpr
int
WARP_NUM
=
kBlockM
*
kBlockN
/
(
WARP_M
*
WARP_N
);
constexpr
int
WARP_K
=
16
;
constexpr
int
READ_ONCE_COUNT
=
16
*
32
;
constexpr
int
kHeadDimV_OPT
=
256
;
// lds 32x32x8x2B == 16KB
constexpr
int
V_LDS_LOAD_NUM
=
(
kHeadDimV_OPT
*
WARP_K
)
/
READ_ONCE_COUNT
;
constexpr
int
V_LOAD_REQUESTS
=
V_LDS_LOAD_NUM
/
WARP_NUM
;
constexpr
int
ELEMENT_BYTES
=
sizeof
(
Element
);
static_assert
(
kBlockK
>=
32
,
"Error: pv gemm kBlockK must be equal or greater than 32"
);
static_assert
(
kBlockM
>=
WARP_M
,
"Error: pv gemm kBlockM must be equal or greater than WARP_M"
);
static_assert
(
kBlockN
==
WARP_N
,
"Error: pv gemm kBlockN must be equal to WARP_N"
);
static_assert
(
WARP_K
==
16
and
"Error: To simplify, only WARP_K = 32 is supported!"
);
static_assert
(
WARP_M
==
16
and
"Error: To simplify, only WARP_M = 16 is supported!"
);
static_assert
(
WARP_N
==
256
and
"Error: To simplify, only WARP_N = 32 is supported!"
);
// 计算 V lds 起始偏移量
int
v_lds_base
=
reinterpret_cast
<
size_t
>
(
v_lds
);
int
tid
=
threadIdx
.
x
%
64
;
// 准备 V 寄存器
union_vec4_f16x2
<
Element
>
v_reg
[
WARP_N
/
32
];
// MLS
vec4_uint
v_srsrc
;
v_srsrc
[
0
]
=
v_ptr
[
0
];
v_srsrc
[
1
]
=
v_ptr
[
1
];
v_srsrc
[
2
]
=
seqlen_v_stride
;
// stride
v_srsrc
[
3
]
=
0
;
int
index_topk
[
4
];
index_topk
[
0
]
=
index_ptr
[(
n_loop_real
*
64
)
+
0
*
16
+
(
tid
/
4
)];
index_topk
[
1
]
=
index_ptr
[(
n_loop_real
*
64
)
+
1
*
16
+
(
tid
/
4
)];
index_topk
[
2
]
=
index_ptr
[(
n_loop_real
*
64
)
+
2
*
16
+
(
tid
/
4
)];
index_topk
[
3
]
=
index_ptr
[(
n_loop_real
*
64
)
+
3
*
16
+
(
tid
/
4
)];
int
fallback_index
=
index_ptr
[
0
]
==
-
1
?
0
:
index_ptr
[
0
];
index_topk
[
0
]
=
(
index_topk
[
0
]
==
-
1
)
?
fallback_index
:
index_topk
[
0
];
index_topk
[
1
]
=
(
index_topk
[
1
]
==
-
1
)
?
fallback_index
:
index_topk
[
1
];
index_topk
[
2
]
=
(
index_topk
[
2
]
==
-
1
)
?
fallback_index
:
index_topk
[
2
];
index_topk
[
3
]
=
(
index_topk
[
3
]
==
-
1
)
?
fallback_index
:
index_topk
[
3
];
int
lds_offset
=
__builtin_amdgcn_readfirstlane
((
warp_id
*
32
*
16
)
*
ELEMENT_BYTES
/
4
);
int
lds_offset_2
=
__builtin_amdgcn_readfirstlane
((
warp_id
*
32
*
16
+
128
*
16
)
*
ELEMENT_BYTES
/
4
);
int
index_block
=
index_topk
[
0
]
/
page_block_size
;
int
index_offset
=
index_topk
[
0
]
-
index_block
*
page_block_size
;
int
g_offset_v
=
tid
%
4
*
4
+
(
index_block
*
batch_stride
+
index_offset
*
seqlen_v_stride
)
*
ELEMENT_BYTES
/
4
;
int
g_offset_s
=
warp_id
*
32
*
ELEMENT_BYTES
/
4
;
int
g_offset_s_2
=
warp_id
*
32
*
ELEMENT_BYTES
/
4
+
128
*
ELEMENT_BYTES
/
4
;
flash
::
wait_all_warp_arrived
();
inline_buffer_load_dwordx4_lds
(
v_lds
,
v_ptr
,
lds_offset
,
g_offset_s
,
g_offset_v
);
inline_buffer_load_dwordx4_lds
(
v_lds
,
v_ptr
,
lds_offset_2
,
g_offset_s_2
,
g_offset_v
);
int
lds_stage_id
=
1
;
for
(
int
total_loop
=
1
;
total_loop
<
(
kHeadDimV
/
kBlockN
)
*
4
;
++
total_loop
)
{
{
index_block
=
index_topk
[
total_loop
/
2
]
/
page_block_size
;
index_offset
=
index_topk
[
total_loop
/
2
]
-
index_block
*
page_block_size
;
g_offset_v
=
tid
%
4
*
4
+
(
index_block
*
batch_stride
+
index_offset
*
seqlen_v_stride
)
*
ELEMENT_BYTES
/
4
;
lds_offset
=
__builtin_amdgcn_readfirstlane
((
lds_stage_id
*
WARP_K
*
WARP_N
+
warp_id
*
32
*
16
)
*
ELEMENT_BYTES
/
4
);
lds_offset_2
=
__builtin_amdgcn_readfirstlane
((
lds_stage_id
*
WARP_K
*
WARP_N
+
warp_id
*
32
*
16
+
128
*
16
)
*
ELEMENT_BYTES
/
4
);
g_offset_s
=
(
total_loop
%
2
)
*
WARP_N
*
ELEMENT_BYTES
/
4
+
warp_id
*
32
*
ELEMENT_BYTES
/
4
;
g_offset_s_2
=
(
total_loop
%
2
)
*
WARP_N
*
ELEMENT_BYTES
/
4
+
warp_id
*
32
*
ELEMENT_BYTES
/
4
+
128
*
ELEMENT_BYTES
/
4
;
flash
::
wait_all_warp_arrived
();
inline_buffer_load_dwordx4_lds
(
v_lds
,
v_ptr
,
lds_offset
,
g_offset_s
,
g_offset_v
);
inline_buffer_load_dwordx4_lds
(
v_lds
,
v_ptr
,
lds_offset_2
,
g_offset_s_2
,
g_offset_v
);
}
// 不对称MLS指令
flash
::
wait_buffer_data_arrived
<
true
>
(
V_LOAD_REQUESTS
);
lds_stage_id
^=
1
;
int
stage_id
=
0
;
// K DS
{
int
v_lds_load_offset
=
v_lds_base
+
(
lds_stage_id
*
WARP_N
*
WARP_K
+
0
*
16
*
64
)
*
ELEMENT_BYTES
;
int
v_lds_load_offset_2
=
v_lds_base
+
(
lds_stage_id
*
WARP_N
*
WARP_K
+
1
*
16
*
64
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16
(
v_lds_load_offset
,
v_reg
[
stage_id
*
2
].
f16
,
v_reg
[
stage_id
*
2
+
1
].
f16
,
false
);
DS_READ_MATRIX_32X32_B16
(
v_lds_load_offset_2
,
v_reg
[
4
+
stage_id
*
2
].
f16
,
v_reg
[
4
+
stage_id
*
2
+
1
].
f16
,
false
);
}
// K DS PRE
stage_id
^=
1
;
{
int
v_lds_load_offset
=
v_lds_base
+
(
lds_stage_id
*
WARP_N
*
WARP_K
+
0
*
16
*
64
+
16
*
128
)
*
ELEMENT_BYTES
;
int
v_lds_load_offset_2
=
v_lds_base
+
(
lds_stage_id
*
WARP_N
*
WARP_K
+
1
*
16
*
64
+
16
*
128
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16
(
v_lds_load_offset
,
v_reg
[
stage_id
*
2
].
f16
,
v_reg
[
stage_id
*
2
+
1
].
f16
,
false
);
DS_READ_MATRIX_32X32_B16
(
v_lds_load_offset_2
,
v_reg
[
4
+
stage_id
*
2
].
f16
,
v_reg
[
4
+
stage_id
*
2
+
1
].
f16
,
false
);
}
// Wait DS
flash
::
wait_lds_data_arrived
<
false
>
(
6
);
// flash::raise_priority();
// MMAC
stage_id
^=
1
;
{
int
min_tile_nk
=
0
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
int
pv_tile_id
=
lds_stage_id
*
8
+
0
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_nk
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[(
total_loop
-
1
)
/
4
][((
total_loop
-
1
)
/
2
)
%
2
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
{
int
min_tile_nk
=
1
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
int
pv_tile_id
=
lds_stage_id
*
8
+
1
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_nk
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[(
total_loop
-
1
)
/
4
][((
total_loop
-
1
)
/
2
)
%
2
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
// flash::lower_priority();
flash
::
wait_lds_data_arrived
<
false
>
(
4
);
// flash::raise_priority();
{
int
min_tile_nk
=
0
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
int
pv_tile_id
=
lds_stage_id
*
8
+
2
;
int
v_tile_id
=
4
+
stage_id
*
2
+
min_tile_nk
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[(
total_loop
-
1
)
/
4
][((
total_loop
-
1
)
/
2
)
%
2
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
{
int
min_tile_nk
=
1
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
int
pv_tile_id
=
lds_stage_id
*
8
+
3
;
int
v_tile_id
=
4
+
stage_id
*
2
+
min_tile_nk
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[(
total_loop
-
1
)
/
4
][((
total_loop
-
1
)
/
2
)
%
2
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
// flash::lower_priority();
flash
::
wait_lds_data_arrived
<
false
>
(
2
);
// flash::raise_priority();
stage_id
^=
1
;
{
int
min_tile_nk
=
0
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
int
pv_tile_id
=
lds_stage_id
*
8
+
4
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_nk
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[(
total_loop
-
1
)
/
4
][((
total_loop
-
1
)
/
2
)
%
2
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
{
int
min_tile_nk
=
1
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
int
pv_tile_id
=
lds_stage_id
*
8
+
5
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_nk
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[(
total_loop
-
1
)
/
4
][((
total_loop
-
1
)
/
2
)
%
2
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
// flash::lower_priority();
flash
::
wait_lds_data_arrived
<
false
>
(
0
);
// flash::raise_priority();
{
int
min_tile_nk
=
0
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
int
pv_tile_id
=
lds_stage_id
*
8
+
6
;
int
v_tile_id
=
4
+
stage_id
*
2
+
min_tile_nk
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[(
total_loop
-
1
)
/
4
][((
total_loop
-
1
)
/
2
)
%
2
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
{
int
min_tile_nk
=
1
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
int
pv_tile_id
=
lds_stage_id
*
8
+
7
;
int
v_tile_id
=
4
+
stage_id
*
2
+
min_tile_nk
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[(
total_loop
-
1
)
/
4
][((
total_loop
-
1
)
/
2
)
%
2
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
// flash::lower_priority();
}
// 等回最后的q_panel
flash
::
wait_buffer_data_arrived
<
true
>
(
0
);
lds_stage_id
^=
1
;
int
stage_id
=
0
;
// K DS
{
int
v_lds_load_offset
=
v_lds_base
+
(
lds_stage_id
*
WARP_N
*
WARP_K
+
0
*
16
*
64
)
*
ELEMENT_BYTES
;
int
v_lds_load_offset_2
=
v_lds_base
+
(
lds_stage_id
*
WARP_N
*
WARP_K
+
1
*
16
*
64
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16
(
v_lds_load_offset
,
v_reg
[
stage_id
*
2
].
f16
,
v_reg
[
stage_id
*
2
+
1
].
f16
,
false
);
DS_READ_MATRIX_32X32_B16
(
v_lds_load_offset_2
,
v_reg
[
4
+
stage_id
*
2
].
f16
,
v_reg
[
4
+
stage_id
*
2
+
1
].
f16
,
false
);
}
// K DS PRE
stage_id
^=
1
;
{
int
v_lds_load_offset
=
v_lds_base
+
(
lds_stage_id
*
WARP_N
*
WARP_K
+
0
*
16
*
64
+
16
*
128
)
*
ELEMENT_BYTES
;
int
v_lds_load_offset_2
=
v_lds_base
+
(
lds_stage_id
*
WARP_N
*
WARP_K
+
1
*
16
*
64
+
16
*
128
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16
(
v_lds_load_offset
,
v_reg
[
stage_id
*
2
].
f16
,
v_reg
[
stage_id
*
2
+
1
].
f16
,
false
);
DS_READ_MATRIX_32X32_B16
(
v_lds_load_offset_2
,
v_reg
[
4
+
stage_id
*
2
].
f16
,
v_reg
[
4
+
stage_id
*
2
+
1
].
f16
,
false
);
}
// Wait DS
flash
::
wait_lds_data_arrived
<
false
>
(
6
);
// flash::raise_priority();
// MMAC
stage_id
^=
1
;
{
int
min_tile_nk
=
0
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
int
pv_tile_id
=
lds_stage_id
*
8
+
0
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_nk
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
1
][
1
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
{
int
min_tile_nk
=
1
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
int
pv_tile_id
=
lds_stage_id
*
8
+
1
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_nk
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
1
][
1
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
// flash::lower_priority();
flash
::
wait_lds_data_arrived
<
false
>
(
4
);
// flash::raise_priority();
{
int
min_tile_nk
=
0
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
int
pv_tile_id
=
lds_stage_id
*
8
+
2
;
int
v_tile_id
=
4
+
stage_id
*
2
+
min_tile_nk
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
1
][
1
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
{
int
min_tile_nk
=
1
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
int
pv_tile_id
=
lds_stage_id
*
8
+
3
;
int
v_tile_id
=
4
+
stage_id
*
2
+
min_tile_nk
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
1
][
1
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
// flash::lower_priority();
flash
::
wait_lds_data_arrived
<
false
>
(
2
);
// flash::raise_priority();
stage_id
^=
1
;
{
int
min_tile_nk
=
0
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
int
pv_tile_id
=
lds_stage_id
*
8
+
4
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_nk
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
1
][
1
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
{
int
min_tile_nk
=
1
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
int
pv_tile_id
=
lds_stage_id
*
8
+
5
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_nk
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
1
][
1
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
// flash::lower_priority();
flash
::
wait_lds_data_arrived
<
false
>
(
0
);
// int abc[1];
// int index_topk_qk = index_ptr[(((n_loop_real+1) % 16) * 64) + warp_id * 16];
// int offset_m = index_topk_qk * seqlen_k_stride;
// auto g_abc = (reinterpret_cast<uint64_t>(k_faker + offset_m));
// inline_s_load_dword(abc[0], g_abc, 0);
// flash::raise_priority();
{
int
min_tile_nk
=
0
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
int
pv_tile_id
=
lds_stage_id
*
8
+
6
;
int
v_tile_id
=
4
+
stage_id
*
2
+
min_tile_nk
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
1
][
1
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
{
int
min_tile_nk
=
1
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
int
pv_tile_id
=
lds_stage_id
*
8
+
7
;
int
v_tile_id
=
4
+
stage_id
*
2
+
min_tile_nk
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
1
][
1
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
// flash::lower_priority();
}
// P*V GEMM step with index-driven ("sparse") V loads, WARP_K = 16 variant.
//
// For each 16-wide K step (n_loop) the routine walks kHeadDimV / kBlockN
// output tiles (k_loop). V is staged through a two-slot LDS ring
// (lds_stage_id in {0, 1}, slot stride WARP_K * kHeadDimV_OPT elements):
// while one slot is being read with DS_READ_MATRIX_32X16_B16 and consumed by
// mmac_4interleave, the other slot is refilled from global memory with
// inline_buffer_load_dwordx4_lds. The global address of every load is derived
// from an entry of index_ptr multiplied by seqlen_v_stride.
//
// Template parameters:
//   PREFETCH_K    when true, the Q/K prefetch helpers are invoked at the end.
//   kHeadDim      forwarded to the Q/K prefetch helpers only.
//   kHeadDimV     together with kBlockN fixes the k_loop trip count.
//   kBlockM/N/K   block tile sizes; kBlockN must equal WARP_N.
//   WARP_M/WARP_N warp tile sizes; only 16 / 32 are accepted.
//   STAGES        number of V register stages; selects the pv_reg tile index
//                 used inside the k_loop (k_loop - 1 when STAGES == 2).
//   Element / ElementAccum  input and accumulator element types.
//   Is_even_MN    forwarded to the Q/K prefetch helpers only.
//
// Parameters:
//   q_ptr, k_ptr, q_lds, k_lds, seqlen_q_stride, seqlen_k_stride,
//   max_seq_q_offset, max_seq_kv_offset
//                 only used by the trailing Q/K prefetch.
//   v_ptr         descriptor handed to inline_buffer_load_dwordx4_lds.
//   v_lds         LDS destination for V; also the base of every DS read.
//   p_reg         left operand of mmac_4interleave, indexed by the K step.
//   pv_reg        accumulator tiles, updated in place.
//   warp_id       selects this warp's slice of every V load.
//   seqlen_v_stride  multiplier applied to each index_ptr entry.
//   index_ptr     table of indices; entries n_loop_real * 64 + [0, 64) are read.
//   batch_stride  unused in this variant (kept for signature compatibility).
//   n_loop_real   selects the 64-entry window of index_ptr.
//
// NOTE(review): the first V slot (ring slot 0, head-dim chunk 0 of the first
// 16 indices) is read before this routine issues any load, so it is presumably
// filled by the caller - confirm.
// NOTE(review): index_ptr entries are used unchecked here, while another
// routine in this file remaps -1 entries to a fallback index before use -
// confirm that sentinel entries cannot reach this path.
// NOTE(review): the literals 64, 48, 15 and 12 below match kBlockK == 64 and
// kHeadDimV / kBlockN == 16; other instantiations should be verified.
template <bool PREFETCH_K,
          int kHeadDim,
          int kHeadDimV,
          int kBlockM,
          int kBlockN,
          int kBlockK,
          int WARP_M,
          int WARP_N,
          int STAGES,
          typename Element,
          typename ElementAccum,
          bool Is_even_MN>
__forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512_nopage_64_new_2(
    vec4_uint q_ptr,
    vec4_uint k_ptr,
    vec4_uint v_ptr,
    Element* q_lds,
    Element* k_lds,
    Element* v_lds,
    union_vec2_f16x2<Element> p_reg[(WARP_M / 16) * (kBlockK / 32)][2],
    vec4_Accum<ElementAccum> pv_reg[(kHeadDimV / kBlockN) * (WARP_M / 16) * (kBlockN / 32)][2],
    int warp_id,
    int seqlen_q_stride,
    int seqlen_k_stride,
    int seqlen_v_stride,
    int* index_ptr,
    int batch_stride,
    int n_loop_real,
    int max_seq_q_offset = 0,
    int max_seq_kv_offset = 0)
{
    constexpr int WARP_NUM        = kBlockM * kBlockN / (WARP_M * WARP_N);
    constexpr int WARP_K          = 16;
    constexpr int READ_ONCE_COUNT = 32 * 16;
    constexpr int kHeadDimV_OPT   = 128;
    // NOTE(review): the original comment here read "lds 32x32x8x2B == 16KB";
    // with WARP_K = 16 and kHeadDimV_OPT = 128 one ring slot spans
    // WARP_K * kHeadDimV_OPT elements - confirm the intended LDS budget.
    constexpr int V_LDS_LOAD_NUM  = (kHeadDimV_OPT * WARP_K) / READ_ONCE_COUNT;
    constexpr int V_LOAD_REQUESTS = V_LDS_LOAD_NUM / WARP_NUM;
    constexpr int ELEMENT_BYTES   = sizeof(Element);

    static_assert(kBlockK >= 32, "Error: pv gemm kBlockK must be equal or greater than 32");
    static_assert(kBlockM >= WARP_M, "Error: pv gemm kBlockM must be equal or greater than WARP_M");
    static_assert(kBlockN == WARP_N, "Error: pv gemm kBlockN must be equal to WARP_N");
    // The message is passed as the second argument so that it is actually
    // printed on failure (it used to be and-ed into the condition).
    static_assert(WARP_M == 16, "Error: To simplify, only WARP_M = 16 is supported!");
    static_assert(WARP_N == 32, "Error: To simplify, only WARP_N = 32 is supported!");

    // Base offset of the V tile in LDS, kept as a 32-bit value for the DS reads.
    int v_lds_base = static_cast<int>(reinterpret_cast<size_t>(v_lds));
    int tid = threadIdx.x % 64;

    // V operand registers, one entry per register stage.
    union_vec4_f16x2<Element> v_reg[STAGES * (16 * WARP_N) / (32 * 32) * 2];

    // A previous implementation issued inline_matrix_load_32x32_b16_lds through
    // a patched copy of the V descriptor; the live code below uses per-lane
    // buffer loads addressed through index_ptr instead.

    int lds_stage_id = 1;

    for (int n_loop = 1; n_loop < (kBlockK / WARP_K); ++n_loop)
    {
        // Global -> LDS: fetch the second head-dim chunk (n_load = 1) for the
        // current group of 16 indices into ring slot lds_stage_id.
        {
            int n_load  = 1;
            int n_loop_ = n_loop - 1;
            // tid / 4 (0..15) picks the index entry; tid % 4 picks a 4-dword slice.
            int index_topk_1 = index_ptr[n_loop_ * 16 + n_loop_real * 64 + (tid / 4)];
            int lds_offset = __builtin_amdgcn_readfirstlane(
                (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 16) * ELEMENT_BYTES / 4);
            int g_offset_v = n_load * WARP_NUM * 16 + warp_id * 16 + (tid % 4) * 4
                           + index_topk_1 * seqlen_v_stride * ELEMENT_BYTES / 4;
            // Presumably a workgroup-wide rendezvous so that no warp is still
            // reading the slot about to be overwritten - confirm.
            flash::wait_all_warp_arrived();
            inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
        }

        // LDS -> registers: switch to the slot that is already resident.
        lds_stage_id ^= 1;
        int stage_id = 0;
        flash::wait_buffer_data_arrived<true>(V_LOAD_REQUESTS);
        int lds_load_offset =
            v_lds_base + (0 /*k_loop*/ * 32 * 16 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
        DS_READ_MATRIX_32X16_B16(lds_load_offset, v_reg[stage_id].f16, false /*transpose*/);
        stage_id ^= 1;

        for (int k_loop = 1; k_loop < (kHeadDimV / kBlockN); ++k_loop)
        {
            // Every fourth tile crosses into the other ring slot: wait for its
            // outstanding buffer load and flip the slot.
            if ((k_loop & 3) == 0x0)
            {
                flash::wait_buffer_data_arrived<true>(0);
                lds_stage_id ^= 1;
            }

            // Read the next tile while the previous one is consumed below.
            int lds_load_offset =
                v_lds_base + ((k_loop & 3) * 32 * 16 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
            DS_READ_MATRIX_32X16_B16(lds_load_offset, v_reg[stage_id].f16, false /*transpose*/);
            flash::wait_lds_data_arrived<false>(1);

            // MMAC on the tile read in the previous step.
            flash::raise_priority();
            stage_id ^= 1;
            {
#pragma unroll
                for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n)
                {
#pragma unroll
                    for (int min_tile_m = 0; min_tile_m < 1; ++min_tile_m)
                    {
                        int pv_tile_id = (STAGES == 2) ? k_loop - 1 : k_loop;
                        int v_tile_id  = stage_id;
                        pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                            p_reg[(n_loop - 1) / 2][(n_loop - 1) % 2].f16x4,
                            v_reg[v_tile_id].f16x4[min_tile_n],
                            pv_reg[pv_tile_id][min_tile_n].f32);
                    }
                }
            }
            flash::lower_priority();

            // Refill the slot that was just drained. (k_loop + 4) & 15 selects
            // the head-dim chunk; for k_loop == 12 it wraps to chunk 0 and
            // k_loop / 12 advances to the next group of 16 indices.
            if ((k_loop & 3) == 0x0)
            {
                lds_stage_id ^= 1;
                int index_topk_1 =
                    index_ptr[(n_loop - 1) * 16 + (k_loop / 12) * 16 + n_loop_real * 64 + (tid / 4)];
                int lds_offset = __builtin_amdgcn_readfirstlane(
                    (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 16) * ELEMENT_BYTES / 4);
                int g_offset_v = ((k_loop + 4) & 15) * 16 + warp_id * 16 + (tid % 4) * 4
                               + index_topk_1 * seqlen_v_stride * ELEMENT_BYTES / 4;
                // Original author's note: guards against a warp that is still
                // reading V LDS while this load overwrites it.
                flash::wait_all_warp_arrived();
                inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
                lds_stage_id ^= 1;
            }
        }

        stage_id ^= 1;
        // Wait for the last DS read of this K step.
        flash::wait_lds_data_arrived<false>(0);

        // Last MMAC of this K step (final output tile).
        flash::raise_priority();
        {
#pragma unroll
            for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n)
            {
#pragma unroll
                for (int min_tile_m = 0; min_tile_m < 1; ++min_tile_m)
                {
                    int pv_tile_id = (kHeadDimV / kBlockN) - 1;
                    int v_tile_id  = stage_id;
                    pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                        p_reg[(n_loop - 1) / 2][(n_loop - 1) % 2].f16x4,
                        v_reg[v_tile_id].f16x4[min_tile_n],
                        pv_reg[pv_tile_id][min_tile_n].f32);
                }
            }
        }
        flash::lower_priority();
    }

    // Final K step. Same pipeline as above, except that the refill is issued
    // only for k_loop == 4 and k_loop == 8 (no wrap-around load at k_loop == 12).
    {
        constexpr int n_loop = kBlockK / WARP_K;

        // Global -> LDS: second head-dim chunk for the last group of indices.
        {
            constexpr int n_loop_ = n_loop - 1;
            int n_load = 1;
            int index_topk_1 = index_ptr[n_loop_ * 16 + n_loop_real * 64 + (tid / 4)];
            int lds_offset = __builtin_amdgcn_readfirstlane(
                (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 16) * ELEMENT_BYTES / 4);
            int g_offset_v = n_load * WARP_NUM * 16 + warp_id * 16 + (tid % 4) * 4
                           + index_topk_1 * seqlen_v_stride * ELEMENT_BYTES / 4;
            flash::wait_all_warp_arrived();
            inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
        }

        lds_stage_id ^= 1;
        int stage_id = 0;
        flash::wait_buffer_data_arrived<true>(V_LOAD_REQUESTS);

        // [TODO] prefetch earlier
        // LDS -> registers for the first tile.
        int lds_load_offset =
            v_lds_base + (0 /*k_loop*/ * 32 * 16 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
        DS_READ_MATRIX_32X16_B16(lds_load_offset, v_reg[stage_id].f16, false /*transpose*/);
        stage_id ^= 1;

        for (int k_loop = 1; k_loop < (kHeadDimV / kBlockN); ++k_loop)
        {
            // Every fourth tile crosses into the other ring slot.
            if ((k_loop & 3) == 0x0)
            {
                flash::wait_buffer_data_arrived<true>(0);
                lds_stage_id ^= 1;
            }

            // LDS -> registers for the next tile.
            int lds_load_offset =
                v_lds_base + ((k_loop & 3) * 32 * 16 + lds_stage_id * WARP_K * kHeadDimV_OPT) * ELEMENT_BYTES;
            DS_READ_MATRIX_32X16_B16(lds_load_offset, v_reg[stage_id].f16, false /*transpose*/);
            flash::wait_lds_data_arrived<false>(1);

            // MMAC on the tile read in the previous step.
            flash::raise_priority();
            stage_id ^= 1;
            {
#pragma unroll
                for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n)
                {
#pragma unroll
                    for (int min_tile_m = 0; min_tile_m < 1; ++min_tile_m)
                    {
                        int pv_tile_id = (STAGES == 2) ? k_loop - 1 : k_loop;
                        int v_tile_id  = stage_id;
                        pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                            p_reg[(n_loop - 1) / 2][(n_loop - 1) % 2].f16x4,
                            v_reg[v_tile_id].f16x4[min_tile_n],
                            pv_reg[pv_tile_id][min_tile_n].f32);
                    }
                }
            }
            flash::lower_priority();

            // Refill the drained slot with the remaining head-dim chunks of
            // the last index group. The literal 48 equals
            // (kBlockK / WARP_K - 1) * 16 when kBlockK == 64.
            if (k_loop == 4 || k_loop == 8)
            {
                lds_stage_id ^= 1;
                int index_topk_1 = index_ptr[48 + n_loop_real * 64 + (tid / 4)];
                int lds_offset = __builtin_amdgcn_readfirstlane(
                    (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 16) * ELEMENT_BYTES / 4);
                int g_offset_v = ((k_loop + 4) & 15) * 16 + warp_id * 16 + (tid % 4) * 4
                               + index_topk_1 * seqlen_v_stride * ELEMENT_BYTES / 4;
                // Original author's note: guards against a warp that is still
                // reading V LDS while this load overwrites it.
                flash::wait_all_warp_arrived();
                inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
                lds_stage_id ^= 1;
            }
        }

        stage_id ^= 1;
        flash::wait_lds_data_arrived<false>(0);

        // Last MMAC (final output tile of the final K step).
        flash::raise_priority();
        {
#pragma unroll
            for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n)
            {
#pragma unroll
                for (int min_tile_m = 0; min_tile_m < 1; ++min_tile_m)
                {
                    int pv_tile_id = (kHeadDimV / kBlockN) - 1;
                    int v_tile_id  = stage_id;
                    pv_reg[pv_tile_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                        p_reg[(n_loop - 1) / 2][(n_loop - 1) % 2].f16x4,
                        v_reg[v_tile_id].f16x4[min_tile_n],
                        pv_reg[pv_tile_id][min_tile_n].f32);
                }
            }
        }
        flash::lower_priority();
    }

    // Prefetch Q and K for the next iteration.
    if constexpr (PREFETCH_K)
    {
        prefetch_q_to_lds_mls_ds_576_512<kHeadDim, kBlockM, kBlockK, WARP_M, Element, Is_even_MN>(
            q_ptr, q_lds, warp_id, seqlen_q_stride, max_seq_q_offset);
        prefetch_k_to_lds_mls_ds_576_512<kHeadDim, kBlockK, kBlockN, WARP_NUM, WARP_N, Element, Is_even_MN>(
            k_ptr, k_lds, warp_id, seqlen_k_stride, max_seq_kv_offset - kBlockK);
    }
}
template
<
bool
PREFETCH_K
,
int
kHeadDim
,
int
kHeadDimV
,
int
kBlockM
,
int
kBlockN
,
int
kBlockK
,
int
WARP_M
,
int
WARP_N
,
int
STAGES
,
typename
Element
,
typename
ElementAccum
,
bool
Is_even_MN
>
__forceinline__
__device__
void
pv_gemm_prefetch_k_mls_ds_576_512_nopage_64_new_1
(
vec4_uint
q_ptr
,
vec4_uint
k_ptr
,
vec4_uint
v_ptr
,
Element
*
q_lds
,
Element
*
k_lds
,
Element
*
v_lds
,
union_vec2_f16x2
<
Element
>
p_reg
[(
WARP_M
/
16
)
*
(
kBlockK
/
32
)][
2
],
vec4_Accum
<
ElementAccum
>
pv_reg
[(
kHeadDimV
/
kBlockN
)
*
(
WARP_M
/
16
)
*
(
kBlockN
/
32
)][
2
],
int
warp_id
,
int
seqlen_q_stride
,
int
seqlen_k_stride
,
int
seqlen_v_stride
,
int
*
index_ptr
,
int
batch_stride
,
int
n_loop_real
,
int
max_seq_q_offset
=
0
,
int
max_seq_kv_offset
=
0
)
{
constexpr
int
WARP_NUM
=
kBlockM
*
kBlockN
/
(
WARP_M
*
WARP_N
);
constexpr
int
WARP_K
=
32
;
constexpr
int
READ_ONCE_COUNT
=
32
*
32
;
constexpr
int
kHeadDimV_OPT
=
128
;
// lds 32x32x8x2B == 16KB
constexpr
int
V_LDS_LOAD_NUM
=
(
kHeadDimV_OPT
*
WARP_K
)
/
READ_ONCE_COUNT
;
constexpr
int
V_LOAD_REQUESTS
=
V_LDS_LOAD_NUM
/
WARP_NUM
;
constexpr
int
ELEMENT_BYTES
=
sizeof
(
Element
);
static_assert
(
kBlockK
>=
32
,
"Error: pv gemm kBlockK must be equal or greater than 32"
);
static_assert
(
kBlockM
>=
WARP_M
,
"Error: pv gemm kBlockM must be equal or greater than WARP_M"
);
static_assert
(
kBlockN
==
WARP_N
,
"Error: pv gemm kBlockN must be equal to WARP_N"
);
static_assert
(
WARP_K
==
32
and
"Error: To simplify, only WARP_K = 32 is supported!"
);
static_assert
(
WARP_M
==
16
and
"Error: To simplify, only WARP_M = 16 is supported!"
);
static_assert
(
WARP_N
==
32
and
"Error: To simplify, only WARP_N = 32 is supported!"
);
// 计算 V lds 起始偏移量
int
v_lds_base
=
reinterpret_cast
<
size_t
>
(
v_lds
);
int
tid
=
threadIdx
.
x
%
64
;
// 准备 V 寄存器
union_vec4_f16x2
<
Element
>
v_reg
[
STAGES
*
(
32
*
WARP_N
)
/
(
32
*
32
)
*
2
];
// MLS
vec4_uint
v_srsrc
;
v_srsrc
[
0
]
=
v_ptr
[
0
];
v_srsrc
[
1
]
=
v_ptr
[
1
];
v_srsrc
[
2
]
=
seqlen_v_stride
;
// stride
v_srsrc
[
3
]
=
0
;
int
lds_stage_id
=
1
;
for
(
int
n_loop
=
1
;
n_loop
<
(
kBlockK
/
WARP_K
);
++
n_loop
)
{
// prefetch same warpk, next 32x256 G2S
{
int
n_load
=
1
;
int
n_loop_
=
n_loop
-
1
;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + n_load * WARP_NUM * 32) * ELEMENT_BYTES);
// if constexpr (true) {
// int nm_filter = inline_min_max<0, 32>(n_loop_ * WARP_K + 32 - max_seq_kv_offset);
// v_srsrc[3] = max_seq_kv_offset % kBlockK == 0 ? 0: nm_filter << 8;
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived();
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
int
index_topk_1
=
index_ptr
[
n_loop_real
*
64
+
(
tid
/
4
)];
int
index_topk_2
=
index_ptr
[
n_loop_real
*
64
+
(
tid
/
4
)
+
16
];
int
lds_offset
=
__builtin_amdgcn_readfirstlane
((
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
/
4
);
int
g_offset_v
=
n_load
*
WARP_NUM
*
16
+
warp_id
*
16
+
tid
%
4
*
4
+
index_topk_1
*
seqlen_v_stride
*
ELEMENT_BYTES
/
4
;
int
lds_offset_add
=
__builtin_amdgcn_readfirstlane
((
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
+
32
*
16
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
/
4
);
int
g_offset_v_add
=
n_load
*
WARP_NUM
*
16
+
warp_id
*
16
+
tid
%
4
*
4
+
index_topk_2
*
seqlen_v_stride
*
ELEMENT_BYTES
/
4
;
flash
::
wait_all_warp_arrived
();
inline_buffer_load_dwordx4_lds
(
v_lds
,
v_ptr
,
lds_offset
,
0
,
g_offset_v
);
inline_buffer_load_dwordx4_lds
(
v_lds
,
v_ptr
,
lds_offset_add
,
0
,
g_offset_v_add
);
// g_offset_v += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// g_offset_v_add += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
}
// DS
lds_stage_id
^=
1
;
int
stage_id
=
0
;
flash
::
wait_buffer_data_arrived
<
true
>
(
V_LOAD_REQUESTS
*
2
);
int
lds_load_offset
=
v_lds_base
+
(
0
/*k_loop*/
*
32
*
32
+
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16
(
lds_load_offset
,
v_reg
[
stage_id
*
2
+
0
].
f16
,
v_reg
[
stage_id
*
2
+
1
].
f16
,
false
/*transpose*/
);
int
k_loop
=
1
;
stage_id
^=
1
;
for
(;
k_loop
<
4
;
++
k_loop
)
{
int
lds_load_offset
=
v_lds_base
+
((
k_loop
&
3
)
*
32
*
32
+
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16
(
lds_load_offset
,
v_reg
[
stage_id
*
2
+
0
].
f16
,
v_reg
[
stage_id
*
2
+
1
].
f16
,
false
/*transpose*/
);
flash
::
wait_lds_data_arrived
<
false
>
(
3
);
// MMAC
flash
::
raise_priority
();
stage_id
^=
1
;
{
constexpr
int
min_tile_k
=
0
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
STAGES
==
2
)
?
k_loop
-
1
:
k_loop
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
wait_lds_data_arrived
<
false
>
(
2
);
{
constexpr
int
min_tile_k
=
1
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
STAGES
==
2
)
?
k_loop
-
1
:
k_loop
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
lower_priority
();
}
// Wait for special headdim
flash
::
wait_buffer_data_arrived
<
true
>
(
0
);
lds_stage_id
^=
1
;
lds_load_offset
=
v_lds_base
+
((
k_loop
&
3
)
*
32
*
32
+
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16
(
lds_load_offset
,
v_reg
[
stage_id
*
2
+
0
].
f16
,
v_reg
[
stage_id
*
2
+
1
].
f16
,
false
/*transpose*/
);
flash
::
wait_lds_data_arrived
<
false
>
(
3
);
// MMAC
flash
::
raise_priority
();
stage_id
^=
1
;
{
constexpr
int
min_tile_k
=
0
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
STAGES
==
2
)
?
k_loop
-
1
:
k_loop
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
wait_lds_data_arrived
<
false
>
(
2
);
{
constexpr
int
min_tile_k
=
1
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
STAGES
==
2
)
?
k_loop
-
1
:
k_loop
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
lower_priority
();
// MLS for special headdimV
int
n_loop_
=
n_loop
;
lds_stage_id
^=
1
;
// if constexpr (Is_even_MN) {
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// } else {
// int nm_filter_max = n_loop_ * WARP_K + 32 - max_seq_kv_offset;
// int real_mls_loop = nm_filter_max >= 32 ? 0: n_loop_; // 如果全越界了, 则只访问 n_loop = 0 的那波数据
// int nm_filter = inline_min_max<0, 32>(real_mls_loop * WARP_K + 32 - max_seq_kv_offset); // 重新计算 nm_filter
// v_srsrc[3] = nm_filter << 8;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (real_mls_loop * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived(); // 预防有 warp 还没算完7,还在读 v lds, 若是此时写 v lds,则 data cover
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
int
index_topk_1
=
index_ptr
[
k_loop
/
12
*
32
+
n_loop_real
*
64
+
(
tid
/
4
)];
int
index_topk_2
=
index_ptr
[
k_loop
/
12
*
32
+
n_loop_real
*
64
+
(
tid
/
4
)
+
16
];
int
lds_offset
=
__builtin_amdgcn_readfirstlane
((
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
/
4
);
int
g_offset_v
=
((
k_loop
+
4
)
&
15
)
*
16
+
warp_id
*
16
+
tid
%
4
*
4
+
index_topk_1
*
seqlen_v_stride
*
ELEMENT_BYTES
/
4
;
int
lds_offset_add
=
__builtin_amdgcn_readfirstlane
((
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
+
32
*
16
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
/
4
);
int
g_offset_v_add
=
((
k_loop
+
4
)
&
15
)
*
16
+
warp_id
*
16
+
tid
%
4
*
4
+
index_topk_2
*
seqlen_v_stride
*
ELEMENT_BYTES
/
4
;
flash
::
wait_all_warp_arrived
();
inline_buffer_load_dwordx4_lds
(
v_lds
,
v_ptr
,
lds_offset
,
0
,
g_offset_v
);
inline_buffer_load_dwordx4_lds
(
v_lds
,
v_ptr
,
lds_offset_add
,
0
,
g_offset_v_add
);
// g_offset_v += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// g_offset_v_add += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
lds_stage_id
^=
1
;
k_loop
++
;
for
(;
k_loop
<
8
;
++
k_loop
)
{
int
lds_load_offset
=
v_lds_base
+
((
k_loop
&
3
)
*
32
*
32
+
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16
(
lds_load_offset
,
v_reg
[
stage_id
*
2
+
0
].
f16
,
v_reg
[
stage_id
*
2
+
1
].
f16
,
false
/*transpose*/
);
flash
::
wait_lds_data_arrived
<
false
>
(
3
);
// MMAC
flash
::
raise_priority
();
stage_id
^=
1
;
{
constexpr
int
min_tile_k
=
0
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
STAGES
==
2
)
?
k_loop
-
1
:
k_loop
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
wait_lds_data_arrived
<
false
>
(
2
);
{
constexpr
int
min_tile_k
=
1
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
STAGES
==
2
)
?
k_loop
-
1
:
k_loop
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
lower_priority
();
}
// Wait for special headdim
flash
::
wait_buffer_data_arrived
<
true
>
(
0
);
lds_stage_id
^=
1
;
lds_load_offset
=
v_lds_base
+
((
k_loop
&
3
)
*
32
*
32
+
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16
(
lds_load_offset
,
v_reg
[
stage_id
*
2
+
0
].
f16
,
v_reg
[
stage_id
*
2
+
1
].
f16
,
false
/*transpose*/
);
flash
::
wait_lds_data_arrived
<
false
>
(
3
);
// MMAC
flash
::
raise_priority
();
stage_id
^=
1
;
{
constexpr
int
min_tile_k
=
0
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
STAGES
==
2
)
?
k_loop
-
1
:
k_loop
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
wait_lds_data_arrived
<
false
>
(
2
);
{
constexpr
int
min_tile_k
=
1
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
STAGES
==
2
)
?
k_loop
-
1
:
k_loop
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
lower_priority
();
// MLS for special headdimV
n_loop_
=
n_loop
;
lds_stage_id
^=
1
;
// if constexpr (Is_even_MN) {
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// } else {
// int nm_filter_max = n_loop_ * WARP_K + 32 - max_seq_kv_offset;
// int real_mls_loop = nm_filter_max >= 32 ? 0: n_loop_; // 如果全越界了, 则只访问 n_loop = 0 的那波数据
// int nm_filter = inline_min_max<0, 32>(real_mls_loop * WARP_K + 32 - max_seq_kv_offset); // 重新计算 nm_filter
// v_srsrc[3] = nm_filter << 8;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (real_mls_loop * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived(); // 预防有 warp 还没算完7,还在读 v lds, 若是此时写 v lds,则 data cover
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
index_topk_1
=
index_ptr
[
k_loop
/
12
*
32
+
n_loop_real
*
64
+
(
tid
/
4
)];
index_topk_2
=
index_ptr
[
k_loop
/
12
*
32
+
n_loop_real
*
64
+
(
tid
/
4
)
+
16
];
lds_offset
=
__builtin_amdgcn_readfirstlane
((
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
/
4
);
g_offset_v
=
((
k_loop
+
4
)
&
15
)
*
16
+
warp_id
*
16
+
tid
%
4
*
4
+
index_topk_1
*
seqlen_v_stride
*
ELEMENT_BYTES
/
4
;
lds_offset_add
=
__builtin_amdgcn_readfirstlane
((
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
+
32
*
16
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
/
4
);
g_offset_v_add
=
((
k_loop
+
4
)
&
15
)
*
16
+
warp_id
*
16
+
tid
%
4
*
4
+
index_topk_2
*
seqlen_v_stride
*
ELEMENT_BYTES
/
4
;
flash
::
wait_all_warp_arrived
();
inline_buffer_load_dwordx4_lds
(
v_lds
,
v_ptr
,
lds_offset
,
0
,
g_offset_v
);
inline_buffer_load_dwordx4_lds
(
v_lds
,
v_ptr
,
lds_offset_add
,
0
,
g_offset_v_add
);
// g_offset_v += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// g_offset_v_add += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
lds_stage_id
^=
1
;
k_loop
++
;
for
(;
k_loop
<
12
;
++
k_loop
)
{
int
lds_load_offset
=
v_lds_base
+
((
k_loop
&
3
)
*
32
*
32
+
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16
(
lds_load_offset
,
v_reg
[
stage_id
*
2
+
0
].
f16
,
v_reg
[
stage_id
*
2
+
1
].
f16
,
false
/*transpose*/
);
flash
::
wait_lds_data_arrived
<
false
>
(
3
);
// MMAC
flash
::
raise_priority
();
stage_id
^=
1
;
{
constexpr
int
min_tile_k
=
0
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
STAGES
==
2
)
?
k_loop
-
1
:
k_loop
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
wait_lds_data_arrived
<
false
>
(
2
);
{
constexpr
int
min_tile_k
=
1
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
STAGES
==
2
)
?
k_loop
-
1
:
k_loop
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
lower_priority
();
}
// Wait for special headdim
flash
::
wait_buffer_data_arrived
<
true
>
(
0
);
lds_stage_id
^=
1
;
lds_load_offset
=
v_lds_base
+
((
k_loop
&
3
)
*
32
*
32
+
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16
(
lds_load_offset
,
v_reg
[
stage_id
*
2
+
0
].
f16
,
v_reg
[
stage_id
*
2
+
1
].
f16
,
false
/*transpose*/
);
flash
::
wait_lds_data_arrived
<
false
>
(
3
);
// MMAC
flash
::
raise_priority
();
stage_id
^=
1
;
{
constexpr
int
min_tile_k
=
0
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
STAGES
==
2
)
?
k_loop
-
1
:
k_loop
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
wait_lds_data_arrived
<
false
>
(
2
);
{
constexpr
int
min_tile_k
=
1
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
STAGES
==
2
)
?
k_loop
-
1
:
k_loop
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
lower_priority
();
// MLS for special headdimV
n_loop_
=
n_loop
;
lds_stage_id
^=
1
;
// if constexpr (Is_even_MN) {
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// } else {
// int nm_filter_max = n_loop_ * WARP_K + 32 - max_seq_kv_offset;
// int real_mls_loop = nm_filter_max >= 32 ? 0: n_loop_; // 如果全越界了, 则只访问 n_loop = 0 的那波数据
// int nm_filter = inline_min_max<0, 32>(real_mls_loop * WARP_K + 32 - max_seq_kv_offset); // 重新计算 nm_filter
// v_srsrc[3] = nm_filter << 8;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (real_mls_loop * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived(); // 预防有 warp 还没算完7,还在读 v lds, 若是此时写 v lds,则 data cover
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
index_topk_1
=
index_ptr
[
k_loop
/
12
*
32
+
n_loop_real
*
64
+
(
tid
/
4
)];
index_topk_2
=
index_ptr
[
k_loop
/
12
*
32
+
n_loop_real
*
64
+
(
tid
/
4
)
+
16
];
lds_offset
=
__builtin_amdgcn_readfirstlane
((
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
/
4
);
g_offset_v
=
((
k_loop
+
4
)
&
15
)
*
16
+
warp_id
*
16
+
tid
%
4
*
4
+
index_topk_1
*
seqlen_v_stride
*
ELEMENT_BYTES
/
4
;
lds_offset_add
=
__builtin_amdgcn_readfirstlane
((
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
+
32
*
16
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
/
4
);
g_offset_v_add
=
((
k_loop
+
4
)
&
15
)
*
16
+
warp_id
*
16
+
tid
%
4
*
4
+
index_topk_2
*
seqlen_v_stride
*
ELEMENT_BYTES
/
4
;
flash
::
wait_all_warp_arrived
();
inline_buffer_load_dwordx4_lds
(
v_lds
,
v_ptr
,
lds_offset
,
0
,
g_offset_v
);
inline_buffer_load_dwordx4_lds
(
v_lds
,
v_ptr
,
lds_offset_add
,
0
,
g_offset_v_add
);
// g_offset_v += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// g_offset_v_add += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
lds_stage_id
^=
1
;
k_loop
++
;
for
(;
k_loop
<
16
;
++
k_loop
)
{
int
lds_load_offset
=
v_lds_base
+
((
k_loop
&
3
)
*
32
*
32
+
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16
(
lds_load_offset
,
v_reg
[
stage_id
*
2
+
0
].
f16
,
v_reg
[
stage_id
*
2
+
1
].
f16
,
false
/*transpose*/
);
flash
::
wait_lds_data_arrived
<
false
>
(
3
);
// MMAC
flash
::
raise_priority
();
stage_id
^=
1
;
{
constexpr
int
min_tile_k
=
0
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
STAGES
==
2
)
?
k_loop
-
1
:
k_loop
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
wait_lds_data_arrived
<
false
>
(
2
);
{
constexpr
int
min_tile_k
=
1
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
STAGES
==
2
)
?
k_loop
-
1
:
k_loop
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
lower_priority
();
}
// Wait for special headdim
flash
::
wait_buffer_data_arrived
<
true
>
(
0
);
lds_stage_id
^=
1
;
lds_load_offset
=
v_lds_base
+
((
k_loop
&
3
)
*
32
*
32
+
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16
(
lds_load_offset
,
v_reg
[
stage_id
*
2
+
0
].
f16
,
v_reg
[
stage_id
*
2
+
1
].
f16
,
false
/*transpose*/
);
flash
::
wait_lds_data_arrived
<
false
>
(
3
);
// MMAC
flash
::
raise_priority
();
stage_id
^=
1
;
{
constexpr
int
min_tile_k
=
0
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
STAGES
==
2
)
?
k_loop
-
1
:
k_loop
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
wait_lds_data_arrived
<
false
>
(
2
);
{
constexpr
int
min_tile_k
=
1
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
STAGES
==
2
)
?
k_loop
-
1
:
k_loop
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
lower_priority
();
// MLS for special headdimV
n_loop_
=
n_loop
;
lds_stage_id
^=
1
;
// if constexpr (Is_even_MN) {
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// } else {
// int nm_filter_max = n_loop_ * WARP_K + 32 - max_seq_kv_offset;
// int real_mls_loop = nm_filter_max >= 32 ? 0: n_loop_; // 如果全越界了, 则只访问 n_loop = 0 的那波数据
// int nm_filter = inline_min_max<0, 32>(real_mls_loop * WARP_K + 32 - max_seq_kv_offset); // 重新计算 nm_filter
// v_srsrc[3] = nm_filter << 8;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (real_mls_loop * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived(); // 预防有 warp 还没算完7,还在读 v lds, 若是此时写 v lds,则 data cover
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
index_topk_1
=
index_ptr
[
k_loop
/
12
*
32
+
n_loop_real
*
64
+
(
tid
/
4
)];
index_topk_2
=
index_ptr
[
k_loop
/
12
*
32
+
n_loop_real
*
64
+
(
tid
/
4
)
+
16
];
lds_offset
=
__builtin_amdgcn_readfirstlane
((
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
/
4
);
g_offset_v
=
((
k_loop
+
4
)
&
15
)
*
16
+
warp_id
*
16
+
tid
%
4
*
4
+
index_topk_1
*
seqlen_v_stride
*
ELEMENT_BYTES
/
4
;
lds_offset_add
=
__builtin_amdgcn_readfirstlane
((
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
+
32
*
16
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
/
4
);
g_offset_v_add
=
((
k_loop
+
4
)
&
15
)
*
16
+
warp_id
*
16
+
tid
%
4
*
4
+
index_topk_2
*
seqlen_v_stride
*
ELEMENT_BYTES
/
4
;
flash
::
wait_all_warp_arrived
();
inline_buffer_load_dwordx4_lds
(
v_lds
,
v_ptr
,
lds_offset
,
0
,
g_offset_v
);
inline_buffer_load_dwordx4_lds
(
v_lds
,
v_ptr
,
lds_offset_add
,
0
,
g_offset_v_add
);
// g_offset_v += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// g_offset_v_add += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
lds_stage_id
^=
1
;
k_loop
++
;
stage_id
^=
1
;
// Wait DS
flash
::
wait_lds_data_arrived
<
false
>
(
1
);
// last mmac
flash
::
raise_priority
();
{
constexpr
int
min_tile_k
=
0
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
kHeadDimV
/
kBlockN
)
-
1
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
wait_lds_data_arrived
<
false
>
(
0
);
{
constexpr
int
min_tile_k
=
1
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
kHeadDimV
/
kBlockN
)
-
1
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
lower_priority
();
}
{
constexpr
int
n_loop
=
kBlockK
/
WARP_K
;
// MLS for special headdimV
{
constexpr
int
n_loop_
=
n_loop
-
1
;
int
n_load
=
1
;
// lds_stage_id ^= 1;
// if constexpr (Is_even_MN) {
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + n_load * WARP_NUM * 32) * ELEMENT_BYTES);
// } else {
// int nm_filter_max = n_loop_ * WARP_K + 32 - max_seq_kv_offset;
// int real_mls_loop = nm_filter_max >= 32 ? 0: n_loop_; // 如果全越界了, 则只访问 n_loop = 0 的那波数据
// int nm_filter = inline_min_max<0, 32>(real_mls_loop * WARP_K + 32 - max_seq_kv_offset); // 重新计算 nm_filter
// v_srsrc[3] = nm_filter << 8;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (real_mls_loop * WARP_K * seqlen_v_stride + warp_id * 32 + n_load * WARP_NUM * 32) * ELEMENT_BYTES);
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived();
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
int
index_topk_1
=
index_ptr
[
n_loop_
*
32
+
n_loop_real
*
64
+
(
tid
/
4
)];
int
index_topk_2
=
index_ptr
[
n_loop_
*
32
+
n_loop_real
*
64
+
(
tid
/
4
)
+
16
];
int
lds_offset
=
__builtin_amdgcn_readfirstlane
((
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
/
4
);
int
g_offset_v
=
n_load
*
WARP_NUM
*
16
+
warp_id
*
16
+
tid
%
4
*
4
+
index_topk_1
*
seqlen_v_stride
*
ELEMENT_BYTES
/
4
;
int
lds_offset_add
=
__builtin_amdgcn_readfirstlane
((
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
+
32
*
16
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
/
4
);
int
g_offset_v_add
=
n_load
*
WARP_NUM
*
16
+
warp_id
*
16
+
tid
%
4
*
4
+
index_topk_2
*
seqlen_v_stride
*
ELEMENT_BYTES
/
4
;
flash
::
wait_all_warp_arrived
();
inline_buffer_load_dwordx4_lds
(
v_lds
,
v_ptr
,
lds_offset
,
0
,
g_offset_v
);
inline_buffer_load_dwordx4_lds
(
v_lds
,
v_ptr
,
lds_offset_add
,
0
,
g_offset_v_add
);
// g_offset_v += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// g_offset_v_add += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
}
lds_stage_id
^=
1
;
int
stage_id
=
0
;
flash
::
wait_buffer_data_arrived
<
true
>
(
V_LOAD_REQUESTS
*
2
);
// [TODO]更早的预取
// DS
int
lds_load_offset
=
v_lds_base
+
(
0
/*k_loop*/
*
32
*
32
+
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16
(
lds_load_offset
,
v_reg
[
stage_id
*
2
+
0
].
f16
,
v_reg
[
stage_id
*
2
+
1
].
f16
,
false
/*transpose*/
);
stage_id
^=
1
;
for
(
int
k_loop
=
1
;
k_loop
<
(
kHeadDimV
/
kBlockN
);
++
k_loop
)
{
// Wait for special headdim
if
((
k_loop
&
3
)
==
0x0
)
{
flash
::
wait_buffer_data_arrived
<
true
>
(
0
);
lds_stage_id
^=
1
;
}
// DS
int
lds_load_offset
=
v_lds_base
+
((
k_loop
&
3
)
*
32
*
32
+
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16
(
lds_load_offset
,
v_reg
[
stage_id
*
2
+
0
].
f16
,
v_reg
[
stage_id
*
2
+
1
].
f16
,
false
/*transpose*/
);
flash
::
wait_lds_data_arrived
<
false
>
(
3
);
// MMAC
flash
::
raise_priority
();
stage_id
^=
1
;
{
constexpr
int
min_tile_k
=
0
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
STAGES
==
2
)
?
k_loop
-
1
:
k_loop
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
wait_lds_data_arrived
<
false
>
(
2
);
{
constexpr
int
min_tile_k
=
1
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
STAGES
==
2
)
?
k_loop
-
1
:
k_loop
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
lower_priority
();
// MLS for special headdimV
if
(
k_loop
==
4
||
k_loop
==
8
)
{
lds_stage_id
^=
1
;
// if constexpr (Is_even_MN) {
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (n_loop_ * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// } else {
// int nm_filter_max = n_loop_ * WARP_K + 32 - max_seq_kv_offset;
// int real_mls_loop = nm_filter_max >= 32 ? 0: n_loop_; // 如果全越界了, 则只访问 n_loop = 0 的那波数据
// int nm_filter = inline_min_max<0, 32>(real_mls_loop * WARP_K + 32 - max_seq_kv_offset); // 重新计算 nm_filter
// v_srsrc[3] = nm_filter << 8;
// *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + (real_mls_loop * WARP_K * seqlen_v_stride + warp_id * 32 + 0 * 32 * WARP_NUM) * ELEMENT_BYTES);
// }
// int lds_offset = (lds_stage_id * WARP_K * kHeadDimV_OPT + warp_id * 32 * 32) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived(); // 预防有 warp 还没算完7,还在读 v lds, 若是此时写 v lds,则 data cover
// inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, lds_offset, 0);
int
index_topk_1
=
index_ptr
[
32
+
n_loop_real
*
64
+
(
tid
/
4
)];
int
index_topk_2
=
index_ptr
[
32
+
n_loop_real
*
64
+
(
tid
/
4
)
+
16
];
int
lds_offset
=
__builtin_amdgcn_readfirstlane
((
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
/
4
);
int
g_offset_v
=
((
k_loop
+
4
)
&
15
)
*
16
+
warp_id
*
16
+
tid
%
4
*
4
+
index_topk_1
*
seqlen_v_stride
*
ELEMENT_BYTES
/
4
;
int
lds_offset_add
=
__builtin_amdgcn_readfirstlane
((
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
+
32
*
16
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
/
4
);
int
g_offset_v_add
=
((
k_loop
+
4
)
&
15
)
*
16
+
warp_id
*
16
+
tid
%
4
*
4
+
index_topk_2
*
seqlen_v_stride
*
ELEMENT_BYTES
/
4
;
flash
::
wait_all_warp_arrived
();
inline_buffer_load_dwordx4_lds
(
v_lds
,
v_ptr
,
lds_offset
,
0
,
g_offset_v
);
inline_buffer_load_dwordx4_lds
(
v_lds
,
v_ptr
,
lds_offset_add
,
0
,
g_offset_v_add
);
// g_offset_v += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// g_offset_v_add += WARP_NUM * 32 * ELEMENT_BYTES / 4;
// lds_offset = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// lds_offset_add = __builtin_amdgcn_readfirstlane((lds_stage_id * WARP_K * kHeadDimV_OPT + 32 * 16 + (WARP_NUM + warp_id) * 32 * 32) * ELEMENT_BYTES / 4);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
// inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
lds_stage_id
^=
1
;
}
}
stage_id
^=
1
;
flash
::
wait_lds_data_arrived
<
false
>
(
1
);
// last mmac
flash
::
raise_priority
();
{
constexpr
int
min_tile_k
=
0
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
kHeadDimV
/
kBlockN
)
-
1
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
wait_lds_data_arrived
<
false
>
(
0
);
{
constexpr
int
min_tile_k
=
1
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
1
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
kHeadDimV
/
kBlockN
)
-
1
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
min_tile_k
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
].
f32
);
}
}
}
flash
::
lower_priority
();
}
// 预取Q K
if
constexpr
(
PREFETCH_K
)
{
prefetch_q_to_lds_mls_ds_576_512
<
kHeadDim
,
kBlockM
,
kBlockK
,
WARP_M
,
Element
,
Is_even_MN
>
(
q_ptr
,
q_lds
,
warp_id
,
seqlen_q_stride
,
max_seq_q_offset
);
prefetch_k_to_lds_mls_ds_576_512
<
kHeadDim
,
kBlockK
,
kBlockN
,
WARP_NUM
,
WARP_N
,
Element
,
Is_even_MN
>
(
k_ptr
,
k_lds
,
warp_id
,
seqlen_k_stride
,
max_seq_kv_offset
-
kBlockK
);
}
}
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/gfx938/mla_pv_gemm_utils_mls_ds.h
0 → 100644
View file @
a1eef562
#pragma once // prepare for prefetch V in qk gemm
#include "intrinsic_mls_ds.h"
// Starts the slot-0 / stage-0 V load into LDS for the calling warp.
//
// Template parameters:
//   kBlockM, kBlockN, WARP_M, WARP_N - only used to derive the warp count.
//   kBlockK    - max_seq_kv_offset is tested for divisibility by this value.
//   Element    - element type; sizeof(Element) scales every byte offset below.
//   kHeadDim, Is_even_MN - not referenced by this body (the filter path is
//                          always taken, regardless of Is_even_MN).
//
// Parameters:
//   v_ptr             - vec4_uint whose low 64 bits are read as the V base
//                       address (presumably a buffer resource descriptor;
//                       TODO confirm against the caller).
//   v_lds             - LDS destination forwarded to the load helper.
//   warp_id           - picks this warp's 32-element slice in global memory
//                       and its 32x32-element slot in LDS.
//   seqlen_v_stride   - stored into descriptor word 2.
//   max_seq_kv_offset - determines the value stored into descriptor word 3.
template <int kHeadDim,
          int kBlockM,
          int kBlockN,
          int kBlockK,
          int WARP_M,
          int WARP_N,
          typename Element,
          bool Is_even_MN>
__forceinline__ __device__ void prefetch_v_to_lds_mls_ds_576_512(vec4_uint v_ptr,
                                                                 Element* v_lds,
                                                                 int warp_id,
                                                                 int seqlen_v_stride,
                                                                 int max_seq_kv_offset = 0)
{
    constexpr int kElemBytes  = sizeof(Element);
    constexpr int kWarpCount  = kBlockM * kBlockN / (WARP_M * WARP_N);
    constexpr int kWarpK      = 32;
    constexpr int kHeadDimOpt = 256;
    // This entry point always targets the first load slot and the first LDS stage.
    constexpr int kLoadSlot = 0;
    constexpr int kLdsStage = 0;

    // Sizing note from the original author: 32x32 x WARP_NUM x 2B x stage == 32K.

    // Build the per-warp descriptor: base address advanced to this warp's slice,
    // then passed through VA_LIMIT_BITS (presumably masks the address to the
    // valid VA width - confirm against the macro definition).
    vec4_uint warp_desc;
    const int global_byte_offset = (kLoadSlot * kWarpCount * 32 + warp_id * 32) * kElemBytes;
    *(uint64_t*)&warp_desc = VA_LIMIT_BITS(*(uint64_t*)&v_ptr + global_byte_offset);
    warp_desc[2] = seqlen_v_stride;

    // The clamp to [0, 32] is evaluated unconditionally; its result is only
    // used (shifted into bits 8 and up) when max_seq_kv_offset is not a
    // multiple of kBlockK. Otherwise word 3 is cleared.
    const int  clamped_filter = inline_min_max<0, 32>(0 * kWarpK + 32 - max_seq_kv_offset);
    const bool block_aligned  = (max_seq_kv_offset % kBlockK == 0);
    warp_desc[3] = block_aligned ? 0 : clamped_filter << 8;

    // Byte offset of this warp's 32x32-element slot inside the selected LDS stage.
    int lds_byte_offset = (kLdsStage * kWarpK * kHeadDimOpt + warp_id * 32 * 32) * kElemBytes;

    // Avoid a conflict between writing V into LDS and warps that are still
    // reading Q/K from LDS: some warps may not have finished the QK phase yet.
    flash::wait_all_warp_arrived();
    inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, warp_desc, lds_byte_offset, 0);

    // Mask 0: the compiler may not schedule any instruction across this point.
    __builtin_amdgcn_sched_barrier(0);
}
// Paged, index-driven V prefetch: each lane looks up two top-k indices, resolves
// them through block_table, and issues two inline_buffer_load_dwordx4_lds
// requests that deposit the data into this warp's LDS slot.
//
// Template parameters:
//   Element - element type; only sizeof(Element) is used.
//   All remaining template parameters are not referenced by the body and exist
//   only to keep the call sites unchanged.
//
// Parameters:
//   v_ptr             - buffer descriptor passed straight to the buffer-load helper.
//   v_lds             - LDS destination base.
//   warp_id           - selects the LDS slot and contributes warp_id * 16 to the global offset.
//   seqlen_v_stride   - row stride multiplied by the in-page row index.
//   index_ptr         - top-k index list; 64 entries are consumed per n_loop step.
//   block_table       - maps (index / 128) to a page number.
//   batch_stride      - stride multiplied by the page number.
//   n_loop            - current KV block iteration.
//   max_seq_kv_offset - unused here.
//
// NOTE(review): the page size is hard-coded as 128 entries - confirm against the caller.
template <int kHeadDim, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, typename Element,
          bool Is_even_MN>
__forceinline__ __device__ void prefetch_v_to_lds_mls_ds_576_512_buffer_load(vec4_uint v_ptr, Element* v_lds,
                                                                             int warp_id, int seqlen_v_stride,
                                                                             int* index_ptr, int* block_table,
                                                                             int batch_stride, int n_loop,
                                                                             int max_seq_kv_offset = 0) {
    constexpr int ELEMENT_BYTES = sizeof(Element);

    const int lane = threadIdx.x % 64;

    // Four consecutive lanes share one index; the second lookup sits 16 entries further on.
    const int index_slot = n_loop * 64 + (lane / 4);
    const int index_topk_1 = index_ptr[index_slot];
    const int index_topk_2 = index_ptr[index_slot + 16];

    // Part of the global offset that does not depend on the looked-up index.
    const int lane_term = warp_id * 16 + lane % 4 * 4;

    // LDS destinations (made wave-uniform through readfirstlane): the warp's 32x32 slot,
    // and the same slot advanced by 32 * 16 elements for the second request.
    const int warp_slot = warp_id * 32 * 32;
    const int lds_offset = __builtin_amdgcn_readfirstlane(warp_slot * ELEMENT_BYTES / 4);
    const int lds_offset_add = __builtin_amdgcn_readfirstlane((32 * 16 + warp_slot) * ELEMENT_BYTES / 4);

    // Global offsets: page contribution + row-within-page contribution.
    const int g_offset_v = lane_term +
                           block_table[index_topk_1 / 128] * batch_stride * ELEMENT_BYTES / 4 +
                           (index_topk_1 % 128) * seqlen_v_stride * ELEMENT_BYTES / 4;
    const int g_offset_v_add = lane_term +
                               block_table[index_topk_2 / 128] * batch_stride * ELEMENT_BYTES / 4 +
                               (index_topk_2 % 128) * seqlen_v_stride * ELEMENT_BYTES / 4;

    // All warps must arrive before V overwrites LDS.
    flash::wait_all_warp_arrived();
    inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
    inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
    __builtin_amdgcn_sched_barrier(0);
}
// Non-paged, index-driven V prefetch: each lane looks up two top-k indices and
// uses them directly as row numbers for two inline_buffer_load_dwordx4_lds
// requests into this warp's LDS slot.
//
// Template parameters:
//   Element - element type; only sizeof(Element) is used.
//   All remaining template parameters are not referenced by the body and exist
//   only to keep the call sites unchanged.
//
// Parameters:
//   v_ptr             - buffer descriptor passed straight to the buffer-load helper.
//   v_lds             - LDS destination base.
//   warp_id           - selects the LDS slot and contributes warp_id * 16 to the global offset.
//   seqlen_v_stride   - row stride multiplied by the looked-up index.
//   index_ptr         - top-k index list; 64 entries are consumed per n_loop step.
//   batch_stride      - unused here.
//   n_loop            - current KV block iteration.
//   max_seq_kv_offset - unused here.
template <int kHeadDim, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, typename Element,
          bool Is_even_MN>
__forceinline__ __device__ void prefetch_v_to_lds_mls_ds_576_512_buffer_load_nopage(vec4_uint v_ptr, Element* v_lds,
                                                                                    int warp_id, int seqlen_v_stride,
                                                                                    int* index_ptr, int batch_stride,
                                                                                    int n_loop,
                                                                                    int max_seq_kv_offset = 0) {
    constexpr int ELEMENT_BYTES = sizeof(Element);

    const int lane = threadIdx.x % 64;

    // Four consecutive lanes share one index; the second lookup sits 16 entries further on.
    const int index_slot = n_loop * 64 + (lane / 4);
    const int index_topk_1 = index_ptr[index_slot];
    const int index_topk_2 = index_ptr[index_slot + 16];

    // Part of the global offset that does not depend on the looked-up index.
    const int lane_term = warp_id * 16 + lane % 4 * 4;

    // LDS destinations (made wave-uniform through readfirstlane): the warp's 32x32 slot,
    // and the same slot advanced by 32 * 16 elements for the second request.
    const int warp_slot = warp_id * 32 * 32;
    const int lds_offset = __builtin_amdgcn_readfirstlane(warp_slot * ELEMENT_BYTES / 4);
    const int lds_offset_add = __builtin_amdgcn_readfirstlane((32 * 16 + warp_slot) * ELEMENT_BYTES / 4);

    // Global offsets: the index addresses the row directly (no block table).
    const int g_offset_v = lane_term + index_topk_1 * seqlen_v_stride * ELEMENT_BYTES / 4;
    const int g_offset_v_add = lane_term + index_topk_2 * seqlen_v_stride * ELEMENT_BYTES / 4;

    // All warps must arrive before V overwrites LDS.
    flash::wait_all_warp_arrived();
    inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
    inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
    __builtin_amdgcn_sched_barrier(0);
}
// "_64" flavour of the non-paged, index-driven V prefetch. Each lane looks up
// two top-k indices and issues two inline_buffer_load_dwordx4_lds requests into
// this warp's LDS slot.
//
// Template parameters:
//   Element - element type; only sizeof(Element) is used.
//   All remaining template parameters are not referenced by the active code and
//   exist only to keep the call sites unchanged.
//
// Parameters:
//   v_ptr             - buffer descriptor passed straight to the buffer-load helper.
//   v_lds             - LDS destination base.
//   warp_id           - selects the LDS slot and contributes warp_id * 16 to the global offset.
//   seqlen_v_stride   - row stride multiplied by the looked-up index.
//   index_ptr         - top-k index list; 64 entries are consumed per n_loop step.
//   batch_stride      - unused here.
//   n_loop            - current KV block iteration.
//   max_seq_kv_offset - unused here.
template <int kHeadDim, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, typename Element,
          bool Is_even_MN>
__forceinline__ __device__ void prefetch_v_to_lds_mls_ds_576_512_buffer_load_nopage_64(
    vec4_uint v_ptr, Element* v_lds, int warp_id, int seqlen_v_stride, int* index_ptr, int batch_stride, int n_loop,
    int max_seq_kv_offset = 0) {
    constexpr int ELEMENT_BYTES = sizeof(Element);

    const int lane = threadIdx.x % 64;

    // Four consecutive lanes share one index; the second lookup sits 16 entries further on.
    const int index_slot = (n_loop * 64) + (lane / 4);
    const int index_topk_1 = index_ptr[index_slot];
    const int index_topk_2 = index_ptr[index_slot + 16];

    // Part of the global offset that does not depend on the looked-up index.
    const int lane_term = warp_id * 16 + lane % 4 * 4;

    // LDS destinations (made wave-uniform through readfirstlane): the warp's 32x32 slot,
    // and the same slot advanced by 32 * 16 elements for the second request.
    const int warp_slot = warp_id * 32 * 32;
    const int lds_offset = __builtin_amdgcn_readfirstlane(warp_slot * ELEMENT_BYTES / 4);
    const int lds_offset_add = __builtin_amdgcn_readfirstlane((32 * 16 + warp_slot) * ELEMENT_BYTES / 4);

    // Global offsets: the index addresses the row directly (no block table).
    const int g_offset_v = lane_term + index_topk_1 * seqlen_v_stride * ELEMENT_BYTES / 4;
    const int g_offset_v_add = lane_term + index_topk_2 * seqlen_v_stride * ELEMENT_BYTES / 4;

    // All warps must arrive before V overwrites LDS.
    flash::wait_all_warp_arrived();
    inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
    inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);

    // Disabled second pair of loads (slot warp_id + WARP_NUM), kept for reference:
    // g_offset_v += WARP_NUM * 32 * ELEMENT_BYTES / 4;
    // g_offset_v_add += WARP_NUM * 32 * ELEMENT_BYTES / 4;
    // lds_offset = __builtin_amdgcn_readfirstlane(((warp_id + WARP_NUM) * 32 * 32) * ELEMENT_BYTES / 4);
    // lds_offset_add = __builtin_amdgcn_readfirstlane((32 * 16 + (warp_id + WARP_NUM) * 32 * 32) * ELEMENT_BYTES / 4);
    // inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset, 0, g_offset_v);
    // inline_buffer_load_dwordx4_lds(v_lds, v_ptr, lds_offset_add, 0, g_offset_v_add);
    __builtin_amdgcn_sched_barrier(0);
}
\ No newline at end of file
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/gfx938/mla_qk_gemm_prefetch_v_mls_ds.h
0 → 100644
View file @
a1eef562
#pragma once
#include "mla_pv_gemm_utils_mls_ds.h"
// QK GEMM stage (paged KV) with software-pipelined Q/K loads and a trailing V prefetch.
//
// Flow, as visible in the body:
//   1. s_reg is zeroed.
//   2. For k_loop = 1 .. kHeadDim / kBlockK - 1:
//        - the Q panel for k_loop is requested into the LDS stage q_stage_id;
//        - on even k_loop, a K gather (top-k index -> block_table page -> row) is
//          requested into the LDS stage k_stage_id;
//        - the previously loaded Q/K panels are read from LDS into registers and
//          accumulated into s_reg[0..1][0..1] with mmac_4interleave.
//      The loop starts at 1, so the data for panel 0 is expected to be in LDS already
//      (presumably issued by the preceding stage - confirm against the caller).
//   3. An epilogue consumes the last requested panel.
//   4. When STAGES == 2 (and on the listed gfx targets) the V tile prefetch is issued.
//
// Template parameters:
//   kHeadDim / kHeadDimV - QK head dimension (576 selects a 64-wide K panel) / V head dimension.
//   kBlockM, kBlockN, kBlockK - tile sizes; kBlockK must be 32.
//   WARP_M, WARP_N - per-warp tile; must be 16 and 64.
//   STAGES         - 2 enables double buffering and the V prefetch.
//   Element / ElementAccum - operand and accumulator element types.
//   Is_even_MN     - forwarded to the V prefetch helper.
//   Is_FlashMLA    - true: Q always comes from q_ptr; false: panels k_loop >= 2 come from qv_ptr.
//
// Parameters:
//   qv_ptr, q_ptr, k_ptr, v_ptr - 128-bit buffer descriptors / base addresses.
//   q_lds, k_lds, v_lds         - LDS regions for Q, K and V.
//   s_reg                       - accumulator registers, zeroed then updated in place.
//   warp_id                     - warp index; split as warp_id / 4 and warp_id % 4.
//   seqlen_*_stride             - row strides for the respective tensors.
//   index_ptr                   - top-k index list; 64 entries per n_loop step.
//   block_table                 - maps (index / 128) to a page number.
//   batch_stride_k / _v         - page strides for K / V.
//   page_block_size             - unused. NOTE(review): the page size is hard-coded as 128
//                                 below; confirm callers always pass 128.
//   n_loop                      - current KV block iteration.
//   max_seq_q_offset            - valid Q rows; drives the Q row filter.
//   max_seq_k_offset            - only forwarded to the V prefetch helper.
template <int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int STAGES,
          typename Element, typename ElementAccum, bool Is_even_MN, bool Is_FlashMLA>
__forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512(
    vec4_uint qv_ptr, vec4_uint q_ptr, vec4_uint k_ptr, vec4_uint v_ptr, Element* q_lds, Element* k_lds,
    Element* v_lds, vec4_Accum<ElementAccum> s_reg[(WARP_M / 16) * (kBlockN / 32)][2], int warp_id,
    int seqlen_qv_stride, int __seqlen_q_stride, int seqlen_k_stride, int seqlen_v_stride, int* index_ptr,
    int* block_table, int batch_stride_k, int batch_stride_v, int page_block_size, int n_loop,
    int max_seq_q_offset = 0, int max_seq_k_offset = 0) {
    // Simplify: the message is now a real static_assert message. Previously it was and-ed
    // into the condition, so it was never printed on failure.
    static_assert(kBlockK == 32, "To simplify, only kBlockK = 32 is supported!");
    static_assert(WARP_M == 16, "To simplify, only WARP_M = 16 is supported!");
    static_assert(WARP_N == 64, "To simplify, only WARP_N = 64 is supported!");
    constexpr int WARP_NUM = kBlockM / WARP_M;
    constexpr int kHeadDim_OPT = (kHeadDim == 576) ? 64 : kHeadDim;
    constexpr int Q_LDS_LOAD_NUM = (kBlockM * kBlockK) / (16 * 32);
    constexpr int Q_LOAD_REQUESTS = Q_LDS_LOAD_NUM / WARP_NUM;
    constexpr int K_LDS_LOAD_NUM = (kHeadDim_OPT * WARP_N) / (32 * 16);
    constexpr int K_LOAD_REQUESTS = K_LDS_LOAD_NUM / WARP_NUM;
    constexpr int ELEMENT_BYTES = sizeof(Element);
    constexpr int WARP_NUM_M = 2;
    constexpr int WARP_NUM_N = 4;
    int warp_id_m = warp_id / WARP_NUM_N;
    int warp_id_n = warp_id % WARP_NUM_N;

    // Zero the accumulators; the scheduling barriers pin this block in place.
    __builtin_amdgcn_sched_barrier(0);
    if constexpr (kBlockN == 128) {
        inline_vgpr4_init_zero_4x4x4(s_reg);
    } else {
        for (int i = 0; i < (WARP_M / 16) * (kBlockN / 32); ++i) {
            for (int j = 0; j < 2; ++j) {
                s_reg[i][j].u64[0] = 0.0f;
                s_reg[i][j].u64[1] = 0.0f;
            }
        }
    }
    __builtin_amdgcn_sched_barrier(0);

    // Q and K operand registers.
    union_vec4_f16x2<Element> q_reg[(WARP_M * kBlockK) / (16 * 32)];
    union_vec4_f16x2<Element> k_reg[STAGES * (32 * kBlockK) / (32 * 32) * 2];

    // Starting LDS addresses of q_lds and k_lds.
    int q_lds_base = reinterpret_cast<size_t>(q_lds);
    int k_lds_base = reinterpret_cast<size_t>(k_lds);
    int tid = threadIdx.x % 64;

    // MLS descriptor for Q: words 0-1 = address, word 2 = stride, word 3 = row filter.
    // (K is gathered with buffer loads and needs no descriptor of its own.)
    vec4_uint q_srsrc;
    q_srsrc[2] = __seqlen_q_stride;
    q_srsrc[3] = 0;

    int q_stage_id = 0;
    int k_stage_id = 0;
    if constexpr (STAGES == 2) {
        q_stage_id ^= 1;
        k_stage_id ^= 1;
    }

    {
        for (int k_loop = 1; k_loop < (kHeadDim / kBlockK); ++k_loop) {
            // Flag: a new K panel is fetched on even k_loop only.
            int k_even = ((k_loop & 1) == 0x0) ? 1 : 0;
            {
                uint64_t q_base_addr;
                int seqlen_q_stride;
                int kloop_true;
                if constexpr (Is_FlashMLA) {
                    q_srsrc[2] = __seqlen_q_stride;
                    q_base_addr = *(uint64_t*)&q_ptr;
                    seqlen_q_stride = __seqlen_q_stride;
                    kloop_true = k_loop;
                } else {
                    // Panels 0-1 come from q_ptr, the rest from qv_ptr (re-based by 2 panels).
                    q_srsrc[2] = (k_loop >= 2) ? seqlen_qv_stride : __seqlen_q_stride;
                    q_base_addr = (k_loop >= 2) ? *(uint64_t*)&qv_ptr : *(uint64_t*)&q_ptr;
                    seqlen_q_stride = (k_loop >= 2) ? seqlen_qv_stride : __seqlen_q_stride;
                    kloop_true = (k_loop >= 2) ? (k_loop - 2) : (k_loop);
                }
                *(uint64_t*)&q_srsrc = VA_LIMIT_BITS(
                    q_base_addr + (kloop_true * kBlockK + warp_id * 16 * seqlen_q_stride) * ELEMENT_BYTES);
                // Rows of this warp's 16-row slice that lie beyond max_seq_q_offset, clamped to [0, 16].
                int nm_filter = inline_min_max<0, 16>(16 * warp_id + 16 - max_seq_q_offset);
                q_srsrc[3] = max_seq_q_offset % kBlockM == 0 ? 0 : nm_filter << 8;
                int lds_offset = (q_stage_id * kBlockM * kBlockK + warp_id * 16 * 32) * ELEMENT_BYTES;
                flash::wait_all_warp_arrived();
                inline_matrix_load_32x16_b16_lds_trans<0, 1>(q_lds, q_srsrc, lds_offset, 0);
                if (k_even) {
                    k_stage_id ^= 1;
                    // Four consecutive lanes share one top-k index.
                    int index_topk = index_ptr[n_loop * 64 + warp_id_n * 16 + (tid / 4)];
                    int lds_offset = __builtin_amdgcn_readfirstlane(
                        (k_stage_id * WARP_N * kHeadDim_OPT + warp_id * 32 * 16) * ELEMENT_BYTES / 4);
                    // Panel offset + lane swizzle + page offset + row-in-page offset.
                    int g_offset_v = ((k_loop) / 2) * kHeadDim_OPT * ELEMENT_BYTES / 4 + warp_id_m * 16 +
                                     ((4 - (tid / 8) % 4) * 4 + tid % 4 * 4) % 16 +
                                     block_table[index_topk / 128] * batch_stride_k * ELEMENT_BYTES / 4 +
                                     (index_topk % 128) * seqlen_k_stride * ELEMENT_BYTES / 4;
                    flash::wait_all_warp_arrived();
                    inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, 0, g_offset_v);
                }
            }
            // Asymmetric MLS requests: even iterations have both Q and K requests in flight.
            if (k_even) {
                flash::wait_buffer_data_arrived<true>(Q_LOAD_REQUESTS + K_LOAD_REQUESTS);
            } else {
                flash::wait_buffer_data_arrived<true>(Q_LOAD_REQUESTS);
            }
            q_stage_id ^= 1;
            // Q DS: read the previously loaded Q panel from LDS.
            {
                int q_lds_load_offset =
                    q_lds_base + (q_stage_id * kBlockM * kBlockK + warp_id * 16 * 32) * ELEMENT_BYTES;
                DS_READ_MATRIX_32X16_B16(q_lds_load_offset, q_reg[0].f16, true);
            }
            k_stage_id ^= 1;
            int stage_id = 0;
            // K DS: first 32x32 K tile.
            {
                int k_lds_load_offset =
                    k_lds_base +
                    (k_stage_id * WARP_N * kHeadDim_OPT + k_even * 32 * 64 + 0 * 32 * 32) * ELEMENT_BYTES;
                DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16,
                                         true);
            }
            // K DS PRE: second 32x32 K tile, issued ahead of its use.
            stage_id ^= 1;
            {
                int k_lds_load_offset =
                    k_lds_base +
                    (k_stage_id * WARP_N * kHeadDim_OPT + k_even * 32 * 64 + 1 * 32 * 32) * ELEMENT_BYTES;
                DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16,
                                         true);
            }
            // Wait DS
            flash::wait_lds_data_arrived<false>(2);
            flash::raise_priority();
            // MMAC on the first K tile.
            stage_id ^= 1;
            {
                int min_tile_n = 0;
#pragma unroll
                for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                    int k_tile_id = stage_id * 2 + min_tile_n;
                    s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                        q_reg[0].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k],
                        s_reg[stage_id][min_tile_n].f32);
                }
            }
            {
                int min_tile_n = 1;
#pragma unroll
                for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                    int k_tile_id = stage_id * 2 + min_tile_n;
                    s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                        q_reg[0].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k],
                        s_reg[stage_id][min_tile_n].f32);
                }
            }
            flash::lower_priority();
            flash::wait_lds_data_arrived<false>(0);
            flash::raise_priority();
            // MMAC on the second K tile.
            stage_id ^= 1;
            {
                int min_tile_n = 0;
#pragma unroll
                for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                    int k_tile_id = stage_id * 2 + min_tile_n;
                    s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                        q_reg[0].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k],
                        s_reg[stage_id][min_tile_n].f32);
                }
            }
            {
                int min_tile_n = 1;
#pragma unroll
                for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                    int k_tile_id = stage_id * 2 + min_tile_n;
                    s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                        q_reg[0].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k],
                        s_reg[stage_id][min_tile_n].f32);
                }
            }
            flash::lower_priority();
        }

        // Epilogue: consume the last panel requested by the loop above.
        constexpr int k_loop = kHeadDim / kBlockK;
        constexpr int k_even = ((k_loop & 1) == 0x0) ? 1 : 0;
        if constexpr (k_even) {
            k_stage_id ^= 1;
        }
        // Wait for the final q_panel to arrive.
        flash::wait_buffer_data_arrived<true>(0);
        // Q DS
        q_stage_id ^= 1;
        {
            int q_lds_load_offset =
                q_lds_base + (q_stage_id * kBlockM * kBlockK + warp_id * 16 * 32) * ELEMENT_BYTES;
            DS_READ_MATRIX_32X16_B16(q_lds_load_offset, q_reg[0].f16, true);
        }
        // K DS
        k_stage_id ^= 1;
        int stage_id = 0;
        {
            int k_lds_load_offset =
                k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + k_even * 32 * 64 + 0 * 32 * 32) * ELEMENT_BYTES;
            DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
        }
        // K DS PRE
        stage_id ^= 1;
        {
            int k_lds_load_offset =
                k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + k_even * 32 * 64 + 1 * 32 * 32) * ELEMENT_BYTES;
            DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
        }
        // Wait DS
        flash::wait_lds_data_arrived<false>(2);
        flash::raise_priority();
        // MMAC on the first K tile.
        stage_id ^= 1;
        {
            int min_tile_n = 0;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                int k_tile_id = stage_id * 2 + min_tile_n;
                s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[0].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k],
                    s_reg[stage_id][min_tile_n].f32);
            }
        }
        {
            int min_tile_n = 1;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                int k_tile_id = stage_id * 2 + min_tile_n;
                s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[0].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k],
                    s_reg[stage_id][min_tile_n].f32);
            }
        }
        flash::lower_priority();
        flash::wait_lds_data_arrived<false>(0);
        flash::raise_priority();
        // MMAC on the second K tile.
        stage_id ^= 1;
        {
            int min_tile_n = 0;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                int k_tile_id = stage_id * 2 + min_tile_n;
                s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[0].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k],
                    s_reg[stage_id][min_tile_n].f32);
            }
        }
        {
            int min_tile_n = 1;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                int k_tile_id = stage_id * 2 + min_tile_n;
                s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[0].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k],
                    s_reg[stage_id][min_tile_n].f32);
            }
        }
        flash::lower_priority();
    }

    // Kick off the V tile prefetch for the following PV stage (double-buffered builds only).
    if constexpr (STAGES == 2) {
#if defined(__gfx938__) || defined(__gfx946__) || (defined(__gfx92a__) && defined(YY_USE_MPERMUTE))
        prefetch_v_to_lds_mls_ds_576_512_buffer_load<kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, Element,
                                                     Is_even_MN>(v_ptr, v_lds, warp_id, seqlen_v_stride, index_ptr,
                                                                 block_table, batch_stride_v, n_loop,
                                                                 max_seq_k_offset);
#else
#endif
    }
}
// qk_gemm
// QK GEMM stage (non-paged KV) with software-pipelined Q/K loads and a trailing V prefetch.
//
// Flow, as visible in the body:
//   1. s_reg is zeroed.
//   2. For k_loop = 1 .. kHeadDim / kBlockK - 1:
//        - the Q panel for k_loop is requested into the LDS stage q_stage_id;
//        - on even k_loop, a K gather (top-k index used directly as the row number) is
//          requested into the LDS stage k_stage_id;
//        - the previously loaded Q/K panels are read from LDS into registers and
//          accumulated into s_reg[0..1][0..1] with mmac_4interleave.
//      The loop starts at 1, so the data for panel 0 is expected to be in LDS already
//      (presumably issued by the preceding stage - confirm against the caller).
//   3. An epilogue consumes the last requested panel.
//   4. When STAGES == 2 (and on the listed gfx targets) the V tile prefetch is issued.
//
// Template parameters:
//   kHeadDim / kHeadDimV - QK head dimension (576 selects a 64-wide K panel) / V head dimension.
//   kBlockM, kBlockN, kBlockK - tile sizes; kBlockK must be 32.
//   WARP_M, WARP_N - per-warp tile; must be 16 and 64.
//   STAGES         - 2 enables double buffering and the V prefetch.
//   Element / ElementAccum - operand and accumulator element types.
//   Is_even_MN     - forwarded to the V prefetch helper.
//   Is_FlashMLA    - true: Q always comes from q_ptr; false: panels k_loop >= 2 come from qv_ptr.
//
// Parameters:
//   qv_ptr, q_ptr, k_ptr, v_ptr - 128-bit buffer descriptors / base addresses.
//   q_lds, k_lds, v_lds         - LDS regions for Q, K and V.
//   s_reg                       - accumulator registers, zeroed then updated in place.
//   warp_id                     - warp index; split as warp_id / 4 and warp_id % 4.
//   seqlen_*_stride             - row strides for the respective tensors.
//   index_ptr                   - top-k index list; 64 entries per n_loop step.
//   batch_stride_k              - unused here.
//   batch_stride_v              - forwarded to the V prefetch helper.
//   page_block_size             - unused here.
//   n_loop                      - current KV block iteration.
//   max_seq_q_offset            - valid Q rows; drives the Q row filter.
//   max_seq_k_offset            - only forwarded to the V prefetch helper.
template <int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int STAGES,
          typename Element, typename ElementAccum, bool Is_even_MN, bool Is_FlashMLA>
__forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512_nopage(
    vec4_uint qv_ptr, vec4_uint q_ptr, vec4_uint k_ptr, vec4_uint v_ptr, Element* q_lds, Element* k_lds,
    Element* v_lds, vec4_Accum<ElementAccum> s_reg[(WARP_M / 16) * (kBlockN / 32)][2], int warp_id,
    int seqlen_qv_stride, int __seqlen_q_stride, int seqlen_k_stride, int seqlen_v_stride, int* index_ptr,
    int batch_stride_k, int batch_stride_v, int page_block_size, int n_loop, int max_seq_q_offset = 0,
    int max_seq_k_offset = 0) {
    // Simplify: the message is now a real static_assert message. Previously it was and-ed
    // into the condition, so it was never printed on failure.
    static_assert(kBlockK == 32, "To simplify, only kBlockK = 32 is supported!");
    static_assert(WARP_M == 16, "To simplify, only WARP_M = 16 is supported!");
    static_assert(WARP_N == 64, "To simplify, only WARP_N = 64 is supported!");
    constexpr int WARP_NUM = kBlockM / WARP_M;
    constexpr int kHeadDim_OPT = (kHeadDim == 576) ? 64 : kHeadDim;
    constexpr int Q_LDS_LOAD_NUM = (kBlockM * kBlockK) / (16 * 32);
    constexpr int Q_LOAD_REQUESTS = Q_LDS_LOAD_NUM / WARP_NUM;
    constexpr int K_LDS_LOAD_NUM = (kHeadDim_OPT * WARP_N) / (32 * 16);
    constexpr int K_LOAD_REQUESTS = K_LDS_LOAD_NUM / WARP_NUM;
    constexpr int ELEMENT_BYTES = sizeof(Element);
    constexpr int WARP_NUM_M = 2;
    constexpr int WARP_NUM_N = 4;
    int warp_id_m = warp_id / WARP_NUM_N;
    int warp_id_n = warp_id % WARP_NUM_N;

    // Zero the accumulators; the scheduling barriers pin this block in place.
    __builtin_amdgcn_sched_barrier(0);
    if constexpr (kBlockN == 128) {
        inline_vgpr4_init_zero_4x4x4(s_reg);
    } else {
        for (int i = 0; i < (WARP_M / 16) * (kBlockN / 32); ++i) {
            for (int j = 0; j < 2; ++j) {
                s_reg[i][j].u64[0] = 0.0f;
                s_reg[i][j].u64[1] = 0.0f;
            }
        }
    }
    __builtin_amdgcn_sched_barrier(0);

    // Q and K operand registers.
    union_vec4_f16x2<Element> q_reg[(WARP_M * kBlockK) / (16 * 32)];
    union_vec4_f16x2<Element> k_reg[STAGES * (32 * kBlockK) / (32 * 32) * 2];

    // Starting LDS addresses of q_lds and k_lds.
    int q_lds_base = reinterpret_cast<size_t>(q_lds);
    int k_lds_base = reinterpret_cast<size_t>(k_lds);
    int tid = threadIdx.x % 64;

    // MLS descriptor for Q: words 0-1 = address, word 2 = stride, word 3 = row filter.
    // (K is gathered with buffer loads and needs no descriptor of its own.)
    vec4_uint q_srsrc;
    q_srsrc[2] = __seqlen_q_stride;
    q_srsrc[3] = 0;

    int q_stage_id = 0;
    int k_stage_id = 0;
    if constexpr (STAGES == 2) {
        q_stage_id ^= 1;
        k_stage_id ^= 1;
    }

    {
        for (int k_loop = 1; k_loop < (kHeadDim / kBlockK); ++k_loop) {
            // Flag: a new K panel is fetched on even k_loop only.
            int k_even = ((k_loop & 1) == 0x0) ? 1 : 0;
            {
                uint64_t q_base_addr;
                int seqlen_q_stride;
                int kloop_true;
                if constexpr (Is_FlashMLA) {
                    q_srsrc[2] = __seqlen_q_stride;
                    q_base_addr = *(uint64_t*)&q_ptr;
                    seqlen_q_stride = __seqlen_q_stride;
                    kloop_true = k_loop;
                } else {
                    // Panels 0-1 come from q_ptr, the rest from qv_ptr (re-based by 2 panels).
                    q_srsrc[2] = (k_loop >= 2) ? seqlen_qv_stride : __seqlen_q_stride;
                    q_base_addr = (k_loop >= 2) ? *(uint64_t*)&qv_ptr : *(uint64_t*)&q_ptr;
                    seqlen_q_stride = (k_loop >= 2) ? seqlen_qv_stride : __seqlen_q_stride;
                    kloop_true = (k_loop >= 2) ? (k_loop - 2) : (k_loop);
                }
                *(uint64_t*)&q_srsrc = VA_LIMIT_BITS(
                    q_base_addr + (kloop_true * kBlockK + warp_id * 16 * seqlen_q_stride) * ELEMENT_BYTES);
                // Rows of this warp's 16-row slice that lie beyond max_seq_q_offset, clamped to [0, 16].
                int nm_filter = inline_min_max<0, 16>(16 * warp_id + 16 - max_seq_q_offset);
                q_srsrc[3] = max_seq_q_offset % kBlockM == 0 ? 0 : nm_filter << 8;
                int lds_offset = (q_stage_id * kBlockM * kBlockK + warp_id * 16 * 32) * ELEMENT_BYTES;
                flash::wait_all_warp_arrived();
                inline_matrix_load_32x16_b16_lds_trans<0, 1>(q_lds, q_srsrc, lds_offset, 0);
                if (k_even) {
                    k_stage_id ^= 1;
                    // Four consecutive lanes share one top-k index.
                    int index_topk = index_ptr[n_loop * 64 + warp_id_n * 16 + (tid / 4)];
                    int lds_offset = __builtin_amdgcn_readfirstlane(
                        (k_stage_id * WARP_N * kHeadDim_OPT + warp_id * 32 * 16) * ELEMENT_BYTES / 4);
                    // Panel offset + lane swizzle + row offset (the index addresses the row directly).
                    int g_offset_v = ((k_loop) / 2) * kHeadDim_OPT * ELEMENT_BYTES / 4 + warp_id_m * 16 +
                                     ((4 - (tid / 8) % 4) * 4 + tid % 4 * 4) % 16 +
                                     index_topk * seqlen_k_stride * ELEMENT_BYTES / 4;
                    flash::wait_all_warp_arrived();
                    inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, 0, g_offset_v);
                }
            }
            // Asymmetric MLS requests: even iterations have both Q and K requests in flight.
            if (k_even) {
                flash::wait_buffer_data_arrived<true>(Q_LOAD_REQUESTS + K_LOAD_REQUESTS);
            } else {
                flash::wait_buffer_data_arrived<true>(Q_LOAD_REQUESTS);
            }
            q_stage_id ^= 1;
            // Q DS: read the previously loaded Q panel from LDS.
            {
                int q_lds_load_offset =
                    q_lds_base + (q_stage_id * kBlockM * kBlockK + warp_id * 16 * 32) * ELEMENT_BYTES;
                DS_READ_MATRIX_32X16_B16(q_lds_load_offset, q_reg[0].f16, true);
            }
            k_stage_id ^= 1;
            int stage_id = 0;
            // K DS: first 32x32 K tile.
            {
                int k_lds_load_offset =
                    k_lds_base +
                    (k_stage_id * WARP_N * kHeadDim_OPT + k_even * 32 * 64 + 0 * 32 * 32) * ELEMENT_BYTES;
                DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16,
                                         true);
            }
            // K DS PRE: second 32x32 K tile, issued ahead of its use.
            stage_id ^= 1;
            {
                int k_lds_load_offset =
                    k_lds_base +
                    (k_stage_id * WARP_N * kHeadDim_OPT + k_even * 32 * 64 + 1 * 32 * 32) * ELEMENT_BYTES;
                DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16,
                                         true);
            }
            // Wait DS
            flash::wait_lds_data_arrived<false>(2);
            flash::raise_priority();
            // MMAC on the first K tile.
            stage_id ^= 1;
            {
                int min_tile_n = 0;
#pragma unroll
                for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                    int k_tile_id = stage_id * 2 + min_tile_n;
                    s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                        q_reg[0].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k],
                        s_reg[stage_id][min_tile_n].f32);
                }
            }
            {
                int min_tile_n = 1;
#pragma unroll
                for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                    int k_tile_id = stage_id * 2 + min_tile_n;
                    s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                        q_reg[0].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k],
                        s_reg[stage_id][min_tile_n].f32);
                }
            }
            flash::lower_priority();
            flash::wait_lds_data_arrived<false>(0);
            flash::raise_priority();
            // MMAC on the second K tile.
            stage_id ^= 1;
            {
                int min_tile_n = 0;
#pragma unroll
                for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                    int k_tile_id = stage_id * 2 + min_tile_n;
                    s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                        q_reg[0].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k],
                        s_reg[stage_id][min_tile_n].f32);
                }
            }
            {
                int min_tile_n = 1;
#pragma unroll
                for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                    int k_tile_id = stage_id * 2 + min_tile_n;
                    s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                        q_reg[0].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k],
                        s_reg[stage_id][min_tile_n].f32);
                }
            }
            flash::lower_priority();
        }

        // Epilogue: consume the last panel requested by the loop above.
        constexpr int k_loop = kHeadDim / kBlockK;
        constexpr int k_even = ((k_loop & 1) == 0x0) ? 1 : 0;
        if constexpr (k_even) {
            k_stage_id ^= 1;
        }
        // Wait for the final q_panel to arrive.
        flash::wait_buffer_data_arrived<true>(0);
        // Q DS
        q_stage_id ^= 1;
        {
            int q_lds_load_offset =
                q_lds_base + (q_stage_id * kBlockM * kBlockK + warp_id * 16 * 32) * ELEMENT_BYTES;
            DS_READ_MATRIX_32X16_B16(q_lds_load_offset, q_reg[0].f16, true);
        }
        // K DS
        k_stage_id ^= 1;
        int stage_id = 0;
        {
            int k_lds_load_offset =
                k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + k_even * 32 * 64 + 0 * 32 * 32) * ELEMENT_BYTES;
            DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
        }
        // K DS PRE
        stage_id ^= 1;
        {
            int k_lds_load_offset =
                k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + k_even * 32 * 64 + 1 * 32 * 32) * ELEMENT_BYTES;
            DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
        }
        // Wait DS
        flash::wait_lds_data_arrived<false>(2);
        flash::raise_priority();
        // MMAC on the first K tile.
        stage_id ^= 1;
        {
            int min_tile_n = 0;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                int k_tile_id = stage_id * 2 + min_tile_n;
                s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[0].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k],
                    s_reg[stage_id][min_tile_n].f32);
            }
        }
        {
            int min_tile_n = 1;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                int k_tile_id = stage_id * 2 + min_tile_n;
                s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[0].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k],
                    s_reg[stage_id][min_tile_n].f32);
            }
        }
        flash::lower_priority();
        flash::wait_lds_data_arrived<false>(0);
        flash::raise_priority();
        // MMAC on the second K tile.
        stage_id ^= 1;
        {
            int min_tile_n = 0;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                int k_tile_id = stage_id * 2 + min_tile_n;
                s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[0].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k],
                    s_reg[stage_id][min_tile_n].f32);
            }
        }
        {
            int min_tile_n = 1;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                int k_tile_id = stage_id * 2 + min_tile_n;
                s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[0].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k],
                    s_reg[stage_id][min_tile_n].f32);
            }
        }
        flash::lower_priority();
    }

    // Kick off the V tile prefetch for the following PV stage (double-buffered builds only).
    if constexpr (STAGES == 2) {
#if defined(__gfx938__) || defined(__gfx946__) || (defined(__gfx92a__) && defined(YY_USE_MPERMUTE))
        prefetch_v_to_lds_mls_ds_576_512_buffer_load_nopage<kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK,
                                                            Element, Is_even_MN>(v_ptr, v_lds, warp_id,
                                                                                 seqlen_v_stride, index_ptr,
                                                                                 batch_stride_v, n_loop,
                                                                                 max_seq_k_offset);
#else
#endif
    }
}
// qk_gemm
template
<
int
kHeadDim
,
int
kHeadDimV
,
int
kBlockM
,
int
kBlockN
,
int
kBlockK
,
int
WARP_M
,
int
WARP_N
,
int
STAGES
,
typename
Element
,
typename
ElementAccum
,
bool
Is_even_MN
,
bool
Is_FlashMLA
>
__forceinline__
__device__
void
qk_gemm_prefetch_v_mls_ds_576_512_nopage_64
(
vec4_uint
qv_ptr
,
vec4_uint
q_ptr
,
vec4_uint
k_ptr
,
vec4_uint
v_ptr
,
Element
*
q_lds
,
Element
*
k_lds
,
Element
*
v_lds
,
vec4_Accum
<
ElementAccum
>
s_reg
[(
WARP_M
/
16
)
*
(
kBlockN
/
32
)][
2
],
int
warp_id
,
int
seqlen_qv_stride
,
int
__seqlen_q_stride
,
int
seqlen_k_stride
,
int
seqlen_v_stride
,
int
*
index_ptr
,
int
batch_stride_k
,
int
batch_stride_v
,
int
page_block_size
,
int
n_loop
,
int
max_seq_q_offset
=
0
,
int
max_seq_k_offset
=
0
)
{
// Simplify
static_assert
(
kBlockK
==
32
and
"To simplify, only kBlockK = 32 is supported!"
);
static_assert
(
WARP_M
==
16
and
"To simplify, only WARP_M = 16 is supported!"
);
static_assert
(
WARP_N
==
64
and
"To simplify, only WARP_N = 64 is supported!"
);
constexpr
int
WARP_NUM
=
kBlockM
/
WARP_M
;
constexpr
int
kHeadDim_OPT
=
(
kHeadDim
==
576
)
?
32
:
kHeadDim
;
constexpr
int
Q_LDS_LOAD_NUM
=
(
kBlockM
*
kBlockK
)
/
(
16
*
32
);
constexpr
int
Q_LOAD_REQUESTS
=
Q_LDS_LOAD_NUM
/
WARP_NUM
;
constexpr
int
K_LDS_LOAD_NUM
=
(
kHeadDim_OPT
*
WARP_N
)
/
(
32
*
16
);
constexpr
int
K_LOAD_REQUESTS
=
K_LDS_LOAD_NUM
/
WARP_NUM
;
constexpr
int
ELEMENT_BYTES
=
sizeof
(
Element
);
constexpr
int
WARP_NUM_M
=
1
;
constexpr
int
WARP_NUM_N
=
4
;
int
warp_id_m
=
warp_id
/
WARP_NUM_N
;
int
warp_id_n
=
warp_id
%
WARP_NUM_N
;
__builtin_amdgcn_sched_barrier
(
0
);
if
constexpr
(
kBlockN
==
128
)
{
inline_vgpr4_init_zero_4x4x4
(
s_reg
);
}
else
{
for
(
int
i
=
0
;
i
<
(
WARP_M
/
16
)
*
(
kBlockN
/
32
);
++
i
)
{
for
(
int
j
=
0
;
j
<
2
;
++
j
)
{
s_reg
[
i
][
j
].
u64
[
0
]
=
0.0
f
;
s_reg
[
i
][
j
].
u64
[
1
]
=
0.0
f
;
}
}
}
__builtin_amdgcn_sched_barrier
(
0
);
// 准备 q,k 寄存器
union_vec4_f16x2
<
Element
>
q_reg
[(
WARP_M
*
kBlockK
)
/
(
16
*
32
)];
union_vec4_f16x2
<
Element
>
k_reg
[
STAGES
*
(
32
*
kBlockK
)
/
(
32
*
32
)
*
2
];
// 计算 q_lds,k_lds 的起始偏移量
int
q_lds_base
=
reinterpret_cast
<
size_t
>
(
q_lds
);
int
k_lds_base
=
reinterpret_cast
<
size_t
>
(
k_lds
);
int
tid
=
threadIdx
.
x
%
64
;
// MLS
vec4_uint
q_srsrc
;
vec4_uint
k_srsrc
;
q_srsrc
[
2
]
=
__seqlen_q_stride
;
if
constexpr
(
Is_FlashMLA
)
{
k_srsrc
[
2
]
=
seqlen_k_stride
;
}
else
{
k_srsrc
[
2
]
=
seqlen_v_stride
;
}
q_srsrc
[
3
]
=
0
;
k_srsrc
[
3
]
=
0
;
int
q_stage_id
=
0
;
int
k_stage_id
=
0
;
if
constexpr
(
STAGES
==
2
)
{
q_stage_id
^=
1
;
k_stage_id
^=
1
;
}
{
for
(
int
k_loop
=
1
;
k_loop
<
(
kHeadDim
/
kBlockK
);
++
k_loop
)
{
{
uint64_t
q_base_addr
;
int
seqlen_q_stride
;
int
kloop_true
;
if
constexpr
(
Is_FlashMLA
)
{
q_srsrc
[
2
]
=
__seqlen_q_stride
;
q_base_addr
=
*
(
uint64_t
*
)
&
q_ptr
;
seqlen_q_stride
=
__seqlen_q_stride
;
kloop_true
=
k_loop
;
}
else
{
q_srsrc
[
2
]
=
(
k_loop
>=
2
)
?
seqlen_qv_stride
:
__seqlen_q_stride
;
q_base_addr
=
(
k_loop
>=
2
)
?
*
(
uint64_t
*
)
&
qv_ptr
:
*
(
uint64_t
*
)
&
q_ptr
;
seqlen_q_stride
=
(
k_loop
>=
2
)
?
seqlen_qv_stride
:
__seqlen_q_stride
;
kloop_true
=
(
k_loop
>=
2
)
?
(
k_loop
-
2
)
:
(
k_loop
);
}
*
(
uint64_t
*
)
&
q_srsrc
=
VA_LIMIT_BITS
(
q_base_addr
+
(
kloop_true
*
kBlockK
+
warp_id
*
16
*
seqlen_q_stride
)
*
ELEMENT_BYTES
);
int
nm_filter
=
inline_min_max
<
0
,
16
>
(
16
*
warp_id
+
16
-
max_seq_q_offset
);
q_srsrc
[
3
]
=
max_seq_q_offset
%
kBlockM
==
0
?
0
:
nm_filter
<<
8
;
int
lds_offset
=
(
q_stage_id
*
kBlockM
*
kBlockK
+
warp_id
*
16
*
32
)
*
ELEMENT_BYTES
;
flash
::
wait_all_warp_arrived
();
inline_matrix_load_32x16_b16_lds_trans
<
0
,
1
>
(
q_lds
,
q_srsrc
,
lds_offset
,
0
);
int
index_topk
=
index_ptr
[
n_loop
*
64
+
warp_id_n
*
16
+
(
tid
/
4
)];
lds_offset
=
__builtin_amdgcn_readfirstlane
((
k_stage_id
*
WARP_N
*
kHeadDim_OPT
+
warp_id
*
32
*
16
)
*
ELEMENT_BYTES
/
4
);
int
g_offset_v
=
k_loop
*
kHeadDim_OPT
*
ELEMENT_BYTES
/
4
+
warp_id_m
*
16
+
((
4
-
(
tid
/
8
)
%
4
)
*
4
+
tid
%
4
*
4
)
%
16
+
index_topk
*
seqlen_k_stride
*
ELEMENT_BYTES
/
4
;
flash
::
wait_all_warp_arrived
();
inline_buffer_load_dwordx4_lds
(
k_lds
,
k_ptr
,
lds_offset
,
0
,
g_offset_v
);
// k_stage_id ^= 1;
// int nm_filter = inline_min_max<0,16>(16 * warp_id_n + 16 - max_seq_k_offset);
// if constexpr (Is_FlashMLA) {
// *(uint64_t*)&k_srsrc = VA_LIMIT_BITS(*(uint64_t*)&k_ptr + (warp_id_m * 32 + warp_id_n * 16 * seqlen_k_stride + ((k_loop) / 2) * kHeadDim_OPT) * ELEMENT_BYTES);
// } else {
// *(uint64_t*)&k_srsrc = VA_LIMIT_BITS(*(uint64_t*)&k_ptr + (warp_id_m * 32 + warp_id_n * 16 * seqlen_v_stride + ((k_loop - 2) / 2) * kHeadDim_OPT) * ELEMENT_BYTES);
// }
// k_srsrc[3] = (max_seq_k_offset % kBlockN == 0x0 ? 0: nm_filter) << 8;
// int lds_offset = (k_stage_id * WARP_N * kHeadDim_OPT + warp_id * 32 * 16) * ELEMENT_BYTES;
// flash::wait_all_warp_arrived();
// inline_matrix_load_32x16_b16_lds_trans<0, 0>(k_lds, k_srsrc, lds_offset, 0);
}
// 不对称MLS指令
flash
::
wait_buffer_data_arrived
<
true
>
(
Q_LOAD_REQUESTS
+
K_LOAD_REQUESTS
);
q_stage_id
^=
1
;
// Q DS
{
int
q_lds_load_offset
=
q_lds_base
+
(
q_stage_id
*
kBlockM
*
kBlockK
+
warp_id
*
16
*
32
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X16_B16
(
q_lds_load_offset
,
q_reg
[
0
].
f16
,
true
);
}
k_stage_id
^=
1
;
int
stage_id
=
0
;
// K DS
{
int
k_lds_load_offset
=
k_lds_base
+
(
k_stage_id
*
WARP_N
*
kHeadDim_OPT
+
0
*
32
*
32
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16
(
k_lds_load_offset
,
k_reg
[
stage_id
*
2
].
f16
,
k_reg
[
stage_id
*
2
+
1
].
f16
,
true
);
}
// K DS PRE
stage_id
^=
1
;
{
int
k_lds_load_offset
=
k_lds_base
+
(
k_stage_id
*
WARP_N
*
kHeadDim_OPT
+
1
*
32
*
32
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16
(
k_lds_load_offset
,
k_reg
[
stage_id
*
2
].
f16
,
k_reg
[
stage_id
*
2
+
1
].
f16
,
true
);
}
// Wait DS
flash
::
wait_lds_data_arrived
<
false
>
(
2
);
flash
::
raise_priority
();
// MMAC
stage_id
^=
1
;
{
int
min_tile_n
=
0
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
int
k_tile_id
=
stage_id
*
2
+
min_tile_n
;
s_reg
[
stage_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
0
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
stage_id
][
min_tile_n
].
f32
);
}
}
{
int
min_tile_n
=
1
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
int
k_tile_id
=
stage_id
*
2
+
min_tile_n
;
s_reg
[
stage_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
0
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
stage_id
][
min_tile_n
].
f32
);
}
}
flash
::
lower_priority
();
flash
::
wait_lds_data_arrived
<
false
>
(
0
);
flash
::
raise_priority
();
stage_id
^=
1
;
{
int
min_tile_n
=
0
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
int
k_tile_id
=
stage_id
*
2
+
min_tile_n
;
s_reg
[
stage_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
0
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
stage_id
][
min_tile_n
].
f32
);
}
}
{
int
min_tile_n
=
1
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
int
k_tile_id
=
stage_id
*
2
+
min_tile_n
;
s_reg
[
stage_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
0
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
stage_id
][
min_tile_n
].
f32
);
}
}
flash
::
lower_priority
();
}
// 等回最后的q_panel
flash
::
wait_buffer_data_arrived
<
true
>
(
0
);
// Q DS
q_stage_id
^=
1
;
{
int
q_lds_load_offset
=
q_lds_base
+
(
q_stage_id
*
kBlockM
*
kBlockK
+
warp_id
*
16
*
32
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X16_B16
(
q_lds_load_offset
,
q_reg
[
0
].
f16
,
true
);
}
// K DS
k_stage_id
^=
1
;
int
stage_id
=
0
;
{
int
k_lds_load_offset
=
k_lds_base
+
(
k_stage_id
*
WARP_N
*
kHeadDim_OPT
+
0
*
32
*
32
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16
(
k_lds_load_offset
,
k_reg
[
stage_id
*
2
].
f16
,
k_reg
[
stage_id
*
2
+
1
].
f16
,
true
);
}
// K DS PRE
stage_id
^=
1
;
{
int
k_lds_load_offset
=
k_lds_base
+
(
k_stage_id
*
WARP_N
*
kHeadDim_OPT
+
1
*
32
*
32
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16
(
k_lds_load_offset
,
k_reg
[
stage_id
*
2
].
f16
,
k_reg
[
stage_id
*
2
+
1
].
f16
,
true
);
}
// Wait DS
flash
::
wait_lds_data_arrived
<
false
>
(
2
);
flash
::
raise_priority
();
// MMAC
stage_id
^=
1
;
{
int
min_tile_n
=
0
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
int
k_tile_id
=
stage_id
*
2
+
min_tile_n
;
s_reg
[
stage_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
0
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
stage_id
][
min_tile_n
].
f32
);
}
}
{
int
min_tile_n
=
1
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
int
k_tile_id
=
stage_id
*
2
+
min_tile_n
;
s_reg
[
stage_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
0
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
stage_id
][
min_tile_n
].
f32
);
}
}
flash
::
lower_priority
();
flash
::
wait_lds_data_arrived
<
false
>
(
0
);
flash
::
raise_priority
();
stage_id
^=
1
;
{
int
min_tile_n
=
0
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
int
k_tile_id
=
stage_id
*
2
+
min_tile_n
;
s_reg
[
stage_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
0
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
stage_id
][
min_tile_n
].
f32
);
}
}
{
int
min_tile_n
=
1
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
int
k_tile_id
=
stage_id
*
2
+
min_tile_n
;
s_reg
[
stage_id
][
min_tile_n
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
0
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
stage_id
][
min_tile_n
].
f32
);
}
}
flash
::
lower_priority
();
}
if
constexpr
(
STAGES
==
2
)
{
#if defined(__gfx938__) || defined(__gfx946__) || (defined(__gfx92a__) && defined(YY_USE_MPERMUTE))
prefetch_v_to_lds_mls_ds_576_512_buffer_load_nopage_64
<
kHeadDimV
,
kBlockM
,
kBlockK
,
kBlockN
,
WARP_M
,
kBlockK
,
Element
,
Is_even_MN
>
(
v_ptr
,
v_lds
,
warp_id
,
seqlen_v_stride
,
index_ptr
,
batch_stride_v
,
n_loop
,
max_seq_k_offset
);
#else
#endif
}
}
// Accumulates the Q*K partial products for one block of 64 index-selected K rows
// into s_reg, walking the head dimension from the last kBlockK-wide panel down to
// panel 0.
//
// Visible behaviour of this routine:
//   * s_reg is zeroed first (via inline_vgpr4_init_zero_4x4x4 when kBlockN == 128,
//     otherwise by the explicit loop).
//   * One row index per lane group is read from
//     index_ptr[n_loop * 64 + (warp_id % 4) * 16 + lane / 4] and scaled by
//     seqlen_k_stride to form the per-lane offset passed to
//     inline_buffer_load_dwordx4_lds.
//   * For every panel a K tile is requested into k_lds, read back into k_reg with
//     DS_READ_MATRIX_32X32_B16 and multiplied against q_reg[panel] with
//     mmac_4interleave; results are accumulated into s_reg[0..1][0..1].
//   * Q is NOT loaded here: q_reg must already hold kHeadDim / kBlockK panels,
//     q_reg[p] being the operand for head-dim panel p.
//
// Parameters that the body reads: k_ptr, k_lds, q_reg, s_reg, warp_id,
// seqlen_k_stride, seqlen_v_stride (only for the unused k_srsrc), index_ptr, n_loop.
// All remaining parameters (qv_ptr, q_ptr, v_faker, q_lds, v_lds, the other
// strides, page_block_size, max_seq_*_offset) are not referenced by live code in
// this routine; they are kept so the signature stays compatible with its callers.
//
// NOTE(review): the semantics of the flash::wait_* helpers, DS_READ_MATRIX_32X32_B16,
// inline_buffer_load_dwordx4_lds and mmac_4interleave are defined elsewhere; the
// statement order below (loads, waits, priority changes) is preserved exactly
// because it presumably encodes the hardware pipelining contract - confirm before
// reordering anything.
// NOTE(review): index_topk * seqlen_k_stride * ELEMENT_BYTES is evaluated in int;
// presumably the caller guarantees it cannot overflow - confirm.
template <int kHeadDim,
          int kHeadDimV,
          int kBlockM,
          int kBlockN,
          int kBlockK,
          int WARP_M,
          int WARP_N,
          int STAGES,
          typename Element,
          typename ElementAccum,
          bool Is_even_MN,
          bool Is_FlashMLA>
__forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new(
    vec4_uint qv_ptr,
    vec4_uint q_ptr,
    vec4_uint k_ptr,
    Element* v_faker,
    Element* q_lds,
    Element* k_lds,
    Element* v_lds,
    // The bound is informational only (array parameters decay to pointers). It is
    // written with 16 * 32 so that it evaluates to kHeadDim / kBlockK entries for
    // the asserted WARP_M == 16 / kBlockK == 32 configuration, which is exactly the
    // range the body indexes (q_reg[0] .. q_reg[kHeadDim / kBlockK - 1]).
    union_vec4_f16x2<Element> q_reg[(WARP_M * kBlockK) / (16 * 32) * (kHeadDim / kBlockK)],
    vec4_Accum<ElementAccum> s_reg[(WARP_M / 16) * (kBlockN / 32)][2],
    int warp_id,
    int seqlen_qv_stride,
    int __seqlen_q_stride,
    int seqlen_k_stride,
    int seqlen_v_stride,
    int* index_ptr,
    int batch_stride_k,
    int batch_stride_v,
    int page_block_size,
    int n_loop,
    int max_seq_q_offset = 0,
    int max_seq_k_offset = 0)
{
    // Simplify: only one tile configuration is implemented.
    static_assert(kBlockK == 32, "To simplify, only kBlockK = 32 is supported!");
    static_assert(WARP_M == 16, "To simplify, only WARP_M = 16 is supported!");
    static_assert(WARP_N == 64, "To simplify, only WARP_N = 64 is supported!");
    constexpr int WARP_NUM = kBlockM / WARP_M;
    // For kHeadDim == 576 the K tile staged per iteration is 32 elements wide.
    constexpr int kHeadDim_OPT = (kHeadDim == 576) ? 32 : kHeadDim;
    // Q_* constants are not used by live code in this routine.
    constexpr int Q_LDS_LOAD_NUM = (kBlockM * kBlockK) / (16 * 32);
    constexpr int Q_LOAD_REQUESTS = Q_LDS_LOAD_NUM / WARP_NUM;
    constexpr int K_LDS_LOAD_NUM = (kHeadDim_OPT * WARP_N) / (32 * 16);
    constexpr int K_LOAD_REQUESTS = K_LDS_LOAD_NUM / WARP_NUM;
    constexpr int ELEMENT_BYTES = sizeof(Element);
    constexpr int WARP_NUM_M = 1;  // not used by live code in this routine
    constexpr int WARP_NUM_N = 4;
    int warp_id_m = warp_id / WARP_NUM_N;
    int warp_id_n = warp_id % WARP_NUM_N;

    // Zero the accumulators; the scheduling barriers fence this section.
    __builtin_amdgcn_sched_barrier(0);
    if constexpr (kBlockN == 128) {
        inline_vgpr4_init_zero_4x4x4(s_reg);
    } else {
        for (int i = 0; i < (WARP_M / 16) * (kBlockN / 32); ++i) {
            for (int j = 0; j < 2; ++j) {
                s_reg[i][j].u64[0] = 0.0f;
                s_reg[i][j].u64[1] = 0.0f;
            }
        }
    }
    __builtin_amdgcn_sched_barrier(0);

    // Prepare the k registers (q registers are supplied by the caller).
    union_vec4_f16x2<Element> k_reg[STAGES * (32 * kBlockK) / (32 * 32) * 2];
    // Compute the starting offset of k_lds.
    // The pointer value is narrowed to int - presumably an LDS address that fits
    // in 32 bits; confirm.
    int k_lds_base = reinterpret_cast<size_t>(k_lds);
    int tid = threadIdx.x % 64;

    // MLS
    // k_srsrc is only consumed by the disabled matrix-load path kept in the
    // comments inside the loop below; live code never reads it.
    vec4_uint k_srsrc;
    if constexpr (Is_FlashMLA) {
        k_srsrc[2] = seqlen_k_stride;
    } else {
        k_srsrc[2] = seqlen_v_stride;
    }
    k_srsrc[3] = 0;

    int k_stage_id = 0;
    // Row index for this lane: 16 indices per warp_id_n, one per group of 4 lanes.
    int index_topk = index_ptr[(n_loop * 64) + warp_id_n * 16 + (tid / 4)];
    // readfirstlane makes the LDS offset wave-uniform; offsets are expressed in
    // units of 4 bytes (note the "/ 4").
    int lds_offset = __builtin_amdgcn_readfirstlane((warp_id * 32 * 16) * ELEMENT_BYTES / 4);
    // Per-lane offset: a lane-dependent term in [0, 16) plus the selected row's
    // offset.
    int g_offset_v = ((4 - (tid / 8) % 4) * 4 + tid % 4 * 4) % 16 + index_topk * seqlen_k_stride * ELEMENT_BYTES / 4;
    // Uniform offset of the LAST head-dim panel; the loop below walks downwards.
    int g_offset_s = ((kHeadDim / kBlockK) - 1) * kHeadDim_OPT * ELEMENT_BYTES / 4 + warp_id_m * 16;
    flash::wait_all_warp_arrived();
    // Prologue: request the last panel into LDS stage 0.
    inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, g_offset_s, g_offset_v);
    if constexpr (STAGES == 2) {
        k_stage_id ^= 1;
    }
    {
        // int index_topk = index_ptr[(n_loop * 64) + warp_id_n * 16 + (tid / 4)];
        // Recomputed with the same expression as above (value is unchanged).
        g_offset_v = ((4 - (tid / 8) % 4) * 4 + tid % 4 * 4) % 16 + index_topk * seqlen_k_stride * ELEMENT_BYTES / 4;
        // Each iteration requests panel k_loop while consuming panel k_loop + 1,
        // which was requested by the previous iteration (or by the prologue).
        for (int k_loop = (kHeadDim / kBlockK) - 2; k_loop >= 0; --k_loop) {
            {
                lds_offset = __builtin_amdgcn_readfirstlane((k_stage_id * WARP_N * kHeadDim_OPT + warp_id * 32 * 16) * ELEMENT_BYTES / 4);
                g_offset_s = k_loop * kHeadDim_OPT * ELEMENT_BYTES / 4 + warp_id_m * 16;
                flash::wait_all_warp_arrived();
                inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, g_offset_s, g_offset_v);
                // k_stage_id ^= 1;
                // int nm_filter = inline_min_max<0,16>(16 * warp_id_n + 16 - max_seq_k_offset);
                // if constexpr (Is_FlashMLA) {
                // *(uint64_t*)&k_srsrc = VA_LIMIT_BITS(*(uint64_t*)&k_ptr + (warp_id_m * 32 + warp_id_n * 16 * seqlen_k_stride + ((k_loop) / 2) * kHeadDim_OPT) * ELEMENT_BYTES);
                // } else {
                // *(uint64_t*)&k_srsrc = VA_LIMIT_BITS(*(uint64_t*)&k_ptr + (warp_id_m * 32 + warp_id_n * 16 * seqlen_v_stride + ((k_loop - 2) / 2) * kHeadDim_OPT) * ELEMENT_BYTES);
                // }
                // k_srsrc[3] = (max_seq_k_offset % kBlockN == 0x0 ? 0: nm_filter) << 8;
                // int lds_offset = (k_stage_id * WARP_N * kHeadDim_OPT + warp_id * 32 * 16) * ELEMENT_BYTES;
                // flash::wait_all_warp_arrived();
                // inline_matrix_load_32x16_b16_lds_trans<0, 0>(k_lds, k_srsrc, lds_offset, 0);
            }
            // Asymmetric MLS instruction
            // Called with K_LOAD_REQUESTS - presumably the number of loads allowed
            // to remain in flight (the one just issued); confirm against the helper.
            flash::wait_buffer_data_arrived<true>(K_LOAD_REQUESTS);
            // Switch to the LDS stage that holds the previously requested panel.
            k_stage_id ^= 1;
            int stage_id = 0;
            // K DS: first 32x32 sub-tile -> k_reg[0], k_reg[1]
            {
                int k_lds_load_offset = k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + 0 * 32 * 32) * ELEMENT_BYTES;
                DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
            }
            // K DS PRE: second 32x32 sub-tile -> k_reg[2], k_reg[3]
            stage_id ^= 1;
            {
                int k_lds_load_offset = k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + 1 * 32 * 32) * ELEMENT_BYTES;
                DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
            }
            // Wait DS (presumably: until the first sub-tile has landed)
            flash::wait_lds_data_arrived<false>(2);
            flash::raise_priority();
            // MMAC on the first sub-tile (stage_id flips back to 0)
            stage_id ^= 1;
            {
                int min_tile_n = 0;
#pragma unroll
                for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                    int k_tile_id = stage_id * 2 + min_tile_n;
                    s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                        q_reg[k_loop + 1].f16x4[min_tile_k],
                        k_reg[k_tile_id].f16x4[min_tile_k],
                        s_reg[stage_id][min_tile_n].f32);
                }
            }
            {
                int min_tile_n = 1;
#pragma unroll
                for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                    int k_tile_id = stage_id * 2 + min_tile_n;
                    s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                        q_reg[k_loop + 1].f16x4[min_tile_k],
                        k_reg[k_tile_id].f16x4[min_tile_k],
                        s_reg[stage_id][min_tile_n].f32);
                }
            }
            flash::lower_priority();
            // Wait for the second sub-tile, then MMAC on it (stage_id becomes 1).
            flash::wait_lds_data_arrived<false>(0);
            flash::raise_priority();
            stage_id ^= 1;
            {
                int min_tile_n = 0;
#pragma unroll
                for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                    int k_tile_id = stage_id * 2 + min_tile_n;
                    s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                        q_reg[k_loop + 1].f16x4[min_tile_k],
                        k_reg[k_tile_id].f16x4[min_tile_k],
                        s_reg[stage_id][min_tile_n].f32);
                }
            }
            {
                int min_tile_n = 1;
#pragma unroll
                for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                    int k_tile_id = stage_id * 2 + min_tile_n;
                    s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                        q_reg[k_loop + 1].f16x4[min_tile_k],
                        k_reg[k_tile_id].f16x4[min_tile_k],
                        s_reg[stage_id][min_tile_n].f32);
                }
            }
            flash::lower_priority();
        }
        // Wait for the last panel to arrive (epilogue: consume panel 0).
        flash::wait_buffer_data_arrived<true>(0);
        // K DS
        k_stage_id ^= 1;
        int stage_id = 0;
        {
            int k_lds_load_offset = k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + 0 * 32 * 32) * ELEMENT_BYTES;
            DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
        }
        // K DS PRE
        stage_id ^= 1;
        {
            int k_lds_load_offset = k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + 1 * 32 * 32) * ELEMENT_BYTES;
            DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
        }
        // Wait DS
        flash::wait_lds_data_arrived<false>(2);
        flash::raise_priority();
        // MMAC
        stage_id ^= 1;
        {
            int min_tile_n = 0;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                int k_tile_id = stage_id * 2 + min_tile_n;
                s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[0].f16x4[min_tile_k],
                    k_reg[k_tile_id].f16x4[min_tile_k],
                    s_reg[stage_id][min_tile_n].f32);
            }
        }
        {
            int min_tile_n = 1;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                int k_tile_id = stage_id * 2 + min_tile_n;
                s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[0].f16x4[min_tile_k],
                    k_reg[k_tile_id].f16x4[min_tile_k],
                    s_reg[stage_id][min_tile_n].f32);
            }
        }
        flash::lower_priority();
        flash::wait_lds_data_arrived<false>(0);
        // int abc[1];
        // int index_topk = index_ptr[(((n_loop+1) % 16) * 64) + warp_id * 16];
        // int offset_m = index_topk * seqlen_v_stride;
        // auto g_abc = (reinterpret_cast<uint64_t>(v_faker + offset_m));
        // inline_s_load_dword(abc[0], g_abc, 0);
        flash::raise_priority();
        stage_id ^= 1;
        {
            int min_tile_n = 0;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                int k_tile_id = stage_id * 2 + min_tile_n;
                s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[0].f16x4[min_tile_k],
                    k_reg[k_tile_id].f16x4[min_tile_k],
                    s_reg[stage_id][min_tile_n].f32);
            }
        }
        {
            int min_tile_n = 1;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                int k_tile_id = stage_id * 2 + min_tile_n;
                s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[0].f16x4[min_tile_k],
                    k_reg[k_tile_id].f16x4[min_tile_k],
                    s_reg[stage_id][min_tile_n].f32);
            }
        }
        flash::lower_priority();
    }
    // The V prefetch is disabled in this variant; only the guard skeleton remains.
    if constexpr (STAGES == 2) {
#if defined(__gfx938__) || defined(__gfx946__) || (defined(__gfx92a__) && defined(YY_USE_MPERMUTE))
        // prefetch_v_to_lds_mls_ds_576_512_buffer_load_nopage_64<kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, Element, Is_even_MN>(v_ptr, v_lds, warp_id, seqlen_v_stride, index_ptr, batch_stride_v, n_loop, max_seq_k_offset);
#else
#endif
    }
}
// kBlockK == 64 variant of the index-selected Q*K accumulation: every head-dim
// panel is 64 elements wide and is staged/consumed as two 32-wide halves.
//
// Visible behaviour of this routine:
//   * s_reg is zeroed first (via inline_vgpr4_init_zero_4x4x4 when kBlockN == 128,
//     otherwise by the explicit loop).
//   * One row index per lane group is read from
//     index_ptr[n_loop * 64 + (warp_id % 4) * 16 + lane / 4] and scaled by
//     seqlen_k_stride to form the per-lane offset passed to
//     inline_buffer_load_dwordx4_lds.
//   * Panels are walked from kHeadDim / kBlockK - 1 down to 0. For each panel two
//     loads are issued (second half at uniform offset +16 and LDS offset
//     +32 * 64 elements), four DS_READ_MATRIX_32X32_B16 reads fill
//     k_reg[0..3] (first half) and k_reg[4..7] (second half), and
//     mmac_4interleave accumulates into s_reg[0..1][0..1] using
//     q_reg[2 * panel] for the first half and q_reg[2 * panel + 1] for the second.
//   * Q is NOT loaded here: q_reg must already hold 2 * (kHeadDim / kBlockK)
//     entries.
//
// Parameters that the body reads: k_ptr, k_lds, q_reg, s_reg, warp_id,
// seqlen_k_stride, seqlen_v_stride (only for the unused k_srsrc), index_ptr, n_loop.
// All remaining parameters are not referenced by live code in this routine; they
// are kept so the signature stays compatible with its callers.
//
// NOTE(review): k_reg is indexed up to 4 + 1 * 2 + 1 = 7, while its declared size is
// STAGES * (32 * kBlockK) / (32 * 32) * 2, i.e. 8 only when STAGES == 2 -
// presumably this routine is only instantiated with STAGES == 2; confirm.
// NOTE(review): the semantics of the flash::wait_* helpers, DS_READ_MATRIX_32X32_B16,
// inline_buffer_load_dwordx4_lds and mmac_4interleave are defined elsewhere; the
// statement order below is preserved exactly because it presumably encodes the
// hardware pipelining contract - confirm before reordering anything.
// NOTE(review): index_topk * seqlen_k_stride * ELEMENT_BYTES is evaluated in int;
// presumably the caller guarantees it cannot overflow - confirm.
template <int kHeadDim,
          int kHeadDimV,
          int kBlockM,
          int kBlockN,
          int kBlockK,
          int WARP_M,
          int WARP_N,
          int STAGES,
          typename Element,
          typename ElementAccum,
          bool Is_even_MN,
          bool Is_FlashMLA>
__forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_666(
    vec4_uint qv_ptr,
    vec4_uint q_ptr,
    vec4_uint k_ptr,
    Element* v_faker,
    Element* q_lds,
    Element* k_lds,
    Element* v_lds,
    union_vec4_f16x2<Element> q_reg[(WARP_M * kBlockK) / (16 * 32) * (kHeadDim / kBlockK)],
    vec4_Accum<ElementAccum> s_reg[(WARP_M / 16) * (kBlockN / 32)][2],
    int warp_id,
    int seqlen_qv_stride,
    int __seqlen_q_stride,
    int seqlen_k_stride,
    int seqlen_v_stride,
    int* index_ptr,
    int batch_stride_k,
    int batch_stride_v,
    int page_block_size,
    int n_loop,
    int max_seq_q_offset = 0,
    int max_seq_k_offset = 0)
{
    // Simplify: only one tile configuration is implemented.
    static_assert(kBlockK == 64, "To simplify, only kBlockK = 64 is supported!");
    static_assert(WARP_M == 16, "To simplify, only WARP_M = 16 is supported!");
    static_assert(WARP_N == 64, "To simplify, only WARP_N = 64 is supported!");
    constexpr int WARP_NUM = kBlockM / WARP_M;
    // For kHeadDim == 576 the K tile staged per iteration is 64 elements wide.
    constexpr int kHeadDim_OPT = (kHeadDim == 576) ? 64 : kHeadDim;
    // Q_* constants are not used by live code in this routine.
    constexpr int Q_LDS_LOAD_NUM = (kBlockM * kBlockK) / (16 * 32);
    constexpr int Q_LOAD_REQUESTS = Q_LDS_LOAD_NUM / WARP_NUM;
    constexpr int K_LDS_LOAD_NUM = (kHeadDim_OPT * WARP_N) / (32 * 16);
    constexpr int K_LOAD_REQUESTS = K_LDS_LOAD_NUM / WARP_NUM;
    constexpr int ELEMENT_BYTES = sizeof(Element);
    constexpr int WARP_NUM_M = 1;  // not used by live code in this routine
    constexpr int WARP_NUM_N = 4;
    int warp_id_m = warp_id / WARP_NUM_N;
    int warp_id_n = warp_id % WARP_NUM_N;

    // Zero the accumulators; the scheduling barriers fence this section.
    __builtin_amdgcn_sched_barrier(0);
    if constexpr (kBlockN == 128) {
        inline_vgpr4_init_zero_4x4x4(s_reg);
    } else {
        for (int i = 0; i < (WARP_M / 16) * (kBlockN / 32); ++i) {
            for (int j = 0; j < 2; ++j) {
                s_reg[i][j].u64[0] = 0.0f;
                s_reg[i][j].u64[1] = 0.0f;
            }
        }
    }
    __builtin_amdgcn_sched_barrier(0);

    // Prepare the k registers (q registers are supplied by the caller).
    union_vec4_f16x2<Element> k_reg[STAGES * (32 * kBlockK) / (32 * 32) * 2];
    // Compute the starting offset of k_lds.
    // The pointer value is narrowed to int - presumably an LDS address that fits
    // in 32 bits; confirm.
    int k_lds_base = reinterpret_cast<size_t>(k_lds);
    int tid = threadIdx.x % 64;

    // MLS
    // k_srsrc is only consumed by the disabled matrix-load path kept in the
    // comments inside the loop below; live code never reads it.
    vec4_uint k_srsrc;
    if constexpr (Is_FlashMLA) {
        k_srsrc[2] = seqlen_k_stride;
    } else {
        k_srsrc[2] = seqlen_v_stride;
    }
    k_srsrc[3] = 0;

    int k_stage_id = 0;
    // Row index for this lane: 16 indices per warp_id_n, one per group of 4 lanes.
    int index_topk = index_ptr[(n_loop * 64) + warp_id_n * 16 + (tid / 4)];
    // readfirstlane makes the LDS offsets wave-uniform; offsets are expressed in
    // units of 4 bytes (note the "/ 4"). The "_2" variants address the second
    // 32-wide half of the 64-wide panel.
    int lds_offset = __builtin_amdgcn_readfirstlane((warp_id * 32 * 16) * ELEMENT_BYTES / 4);
    int lds_offset_2 = __builtin_amdgcn_readfirstlane((warp_id * 32 * 16 + 32 * 64) * ELEMENT_BYTES / 4);
    // Per-lane offset: a lane-dependent term in [0, 16) plus the selected row's
    // offset.
    int g_offset_v = ((4 - (tid / 8) % 4) * 4 + tid % 4 * 4) % 16 + index_topk * seqlen_k_stride * ELEMENT_BYTES / 4;
    // Uniform offsets of the LAST head-dim panel; the loop below walks downwards.
    int g_offset_s = ((kHeadDim / kBlockK) - 1) * kHeadDim_OPT * ELEMENT_BYTES / 4 + warp_id_m * 16;
    int g_offset_s_2 = ((kHeadDim / kBlockK) - 1) * kHeadDim_OPT * ELEMENT_BYTES / 4 + warp_id_m * 16 + 16;
    flash::wait_all_warp_arrived();
    // Prologue: request both halves of the last panel into LDS stage 0.
    inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, g_offset_s, g_offset_v);
    inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset_2, g_offset_s_2, g_offset_v);
    if constexpr (STAGES == 2) {
        k_stage_id ^= 1;
    }
    {
        // int index_topk = index_ptr[(n_loop * 64) + warp_id_n * 16 + (tid / 4)];
        // g_offset_v = ((4 - (tid / 8) % 4) * 4 + tid % 4 * 4) % 16 + index_topk * seqlen_k_stride * ELEMENT_BYTES / 4;
        // Each iteration requests panel k_loop while consuming panel k_loop + 1,
        // which was requested by the previous iteration (or by the prologue).
        for (int k_loop = (kHeadDim / kBlockK) - 2; k_loop >= 0; --k_loop) {
            {
                lds_offset = __builtin_amdgcn_readfirstlane((k_stage_id * WARP_N * kHeadDim_OPT + warp_id * 32 * 16) * ELEMENT_BYTES / 4);
                lds_offset_2 = __builtin_amdgcn_readfirstlane((k_stage_id * WARP_N * kHeadDim_OPT + warp_id * 32 * 16 + 32 * 64) * ELEMENT_BYTES / 4);
                g_offset_s = k_loop * kHeadDim_OPT * ELEMENT_BYTES / 4 + warp_id_m * 16;
                g_offset_s_2 = k_loop * kHeadDim_OPT * ELEMENT_BYTES / 4 + warp_id_m * 16 + 16;
                flash::wait_all_warp_arrived();
                inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, g_offset_s, g_offset_v);
                inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset_2, g_offset_s_2, g_offset_v);
                // k_stage_id ^= 1;
                // int nm_filter = inline_min_max<0,16>(16 * warp_id_n + 16 - max_seq_k_offset);
                // if constexpr (Is_FlashMLA) {
                // *(uint64_t*)&k_srsrc = VA_LIMIT_BITS(*(uint64_t*)&k_ptr + (warp_id_m * 32 + warp_id_n * 16 * seqlen_k_stride + ((k_loop) / 2) * kHeadDim_OPT) * ELEMENT_BYTES);
                // } else {
                // *(uint64_t*)&k_srsrc = VA_LIMIT_BITS(*(uint64_t*)&k_ptr + (warp_id_m * 32 + warp_id_n * 16 * seqlen_v_stride + ((k_loop - 2) / 2) * kHeadDim_OPT) * ELEMENT_BYTES);
                // }
                // k_srsrc[3] = (max_seq_k_offset % kBlockN == 0x0 ? 0: nm_filter) << 8;
                // int lds_offset = (k_stage_id * WARP_N * kHeadDim_OPT + warp_id * 32 * 16) * ELEMENT_BYTES;
                // flash::wait_all_warp_arrived();
                // inline_matrix_load_32x16_b16_lds_trans<0, 0>(k_lds, k_srsrc, lds_offset, 0);
            }
            // Asymmetric MLS instruction
            // Called with K_LOAD_REQUESTS - presumably the number of loads allowed
            // to remain in flight (the ones just issued); confirm against the helper.
            flash::wait_buffer_data_arrived<true>(K_LOAD_REQUESTS);
            // Switch to the LDS stage that holds the previously requested panel.
            k_stage_id ^= 1;
            int stage_id = 0;
            // K DS: first 32x32 sub-tile of each half -> k_reg[0..1] and k_reg[4..5]
            {
                int k_lds_load_offset = k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + 0 * 32 * 32) * ELEMENT_BYTES;
                int k_lds_load_offset_2 = k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + 0 * 32 * 32 + 32 * 64) * ELEMENT_BYTES;
                DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
                DS_READ_MATRIX_32X32_B16(k_lds_load_offset_2, k_reg[4 + stage_id * 2].f16, k_reg[4 + stage_id * 2 + 1].f16, true);
            }
            // K DS PRE: second 32x32 sub-tile of each half -> k_reg[2..3] and k_reg[6..7]
            stage_id ^= 1;
            {
                int k_lds_load_offset = k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + 1 * 32 * 32) * ELEMENT_BYTES;
                int k_lds_load_offset_2 = k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + 1 * 32 * 32 + 32 * 64) * ELEMENT_BYTES;
                DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
                DS_READ_MATRIX_32X32_B16(k_lds_load_offset_2, k_reg[4 + stage_id * 2].f16, k_reg[4 + stage_id * 2 + 1].f16, true);
            }
            // Wait DS: the counter steps 6 -> 4 -> 2 -> 0 across the four MMAC
            // groups below, presumably releasing one DS read per step.
            flash::wait_lds_data_arrived<false>(6);
            // flash::raise_priority();
            // MMAC: first half, sub-tile 0 (stage_id flips back to 0)
            stage_id ^= 1;
            {
                int min_tile_n = 0;
#pragma unroll
                for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                    int k_tile_id = stage_id * 2 + min_tile_n;
                    s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                        q_reg[(k_loop + 1) * 2 + 0].f16x4[min_tile_k],
                        k_reg[k_tile_id].f16x4[min_tile_k],
                        s_reg[stage_id][min_tile_n].f32);
                }
            }
            {
                int min_tile_n = 1;
#pragma unroll
                for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                    int k_tile_id = stage_id * 2 + min_tile_n;
                    s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                        q_reg[(k_loop + 1) * 2 + 0].f16x4[min_tile_k],
                        k_reg[k_tile_id].f16x4[min_tile_k],
                        s_reg[stage_id][min_tile_n].f32);
                }
            }
            // flash::lower_priority();
            flash::wait_lds_data_arrived<false>(4);
            // flash::raise_priority();
            // MMAC: second half, sub-tile 0
            {
                int min_tile_n = 0;
#pragma unroll
                for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                    int k_tile_id = 4 + stage_id * 2 + min_tile_n;
                    s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                        q_reg[(k_loop + 1) * 2 + 1].f16x4[min_tile_k],
                        k_reg[k_tile_id].f16x4[min_tile_k],
                        s_reg[stage_id][min_tile_n].f32);
                }
            }
            {
                int min_tile_n = 1;
#pragma unroll
                for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                    int k_tile_id = 4 + stage_id * 2 + min_tile_n;
                    s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                        q_reg[(k_loop + 1) * 2 + 1].f16x4[min_tile_k],
                        k_reg[k_tile_id].f16x4[min_tile_k],
                        s_reg[stage_id][min_tile_n].f32);
                }
            }
            // flash::lower_priority();
            flash::wait_lds_data_arrived<false>(2);
            // flash::raise_priority();
            // MMAC: first half, sub-tile 1 (stage_id becomes 1)
            stage_id ^= 1;
            {
                int min_tile_n = 0;
#pragma unroll
                for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                    int k_tile_id = stage_id * 2 + min_tile_n;
                    s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                        q_reg[(k_loop + 1) * 2 + 0].f16x4[min_tile_k],
                        k_reg[k_tile_id].f16x4[min_tile_k],
                        s_reg[stage_id][min_tile_n].f32);
                }
            }
            {
                int min_tile_n = 1;
#pragma unroll
                for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                    int k_tile_id = stage_id * 2 + min_tile_n;
                    s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                        q_reg[(k_loop + 1) * 2 + 0].f16x4[min_tile_k],
                        k_reg[k_tile_id].f16x4[min_tile_k],
                        s_reg[stage_id][min_tile_n].f32);
                }
            }
            // flash::lower_priority();
            flash::wait_lds_data_arrived<false>(0);
            // flash::raise_priority();
            // MMAC: second half, sub-tile 1
            {
                int min_tile_n = 0;
#pragma unroll
                for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                    int k_tile_id = 4 + stage_id * 2 + min_tile_n;
                    s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                        q_reg[(k_loop + 1) * 2 + 1].f16x4[min_tile_k],
                        k_reg[k_tile_id].f16x4[min_tile_k],
                        s_reg[stage_id][min_tile_n].f32);
                }
            }
            {
                int min_tile_n = 1;
#pragma unroll
                for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                    int k_tile_id = 4 + stage_id * 2 + min_tile_n;
                    s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                        q_reg[(k_loop + 1) * 2 + 1].f16x4[min_tile_k],
                        k_reg[k_tile_id].f16x4[min_tile_k],
                        s_reg[stage_id][min_tile_n].f32);
                }
            }
            // flash::lower_priority();
        }
        // Wait for the last panel to arrive (epilogue: consume panel 0).
        flash::wait_buffer_data_arrived<true>(0);
        // K DS
        k_stage_id ^= 1;
        int stage_id = 0;
        {
            int k_lds_load_offset = k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + 0 * 32 * 32) * ELEMENT_BYTES;
            int k_lds_load_offset_2 = k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + 0 * 32 * 32 + 32 * 64) * ELEMENT_BYTES;
            DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
            DS_READ_MATRIX_32X32_B16(k_lds_load_offset_2, k_reg[4 + stage_id * 2].f16, k_reg[4 + stage_id * 2 + 1].f16, true);
        }
        // K DS PRE
        stage_id ^= 1;
        {
            int k_lds_load_offset = k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + 1 * 32 * 32) * ELEMENT_BYTES;
            int k_lds_load_offset_2 = k_lds_base + (k_stage_id * WARP_N * kHeadDim_OPT + 1 * 32 * 32 + 32 * 64) * ELEMENT_BYTES;
            DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
            DS_READ_MATRIX_32X32_B16(k_lds_load_offset_2, k_reg[4 + stage_id * 2].f16, k_reg[4 + stage_id * 2 + 1].f16, true);
        }
        // Wait DS
        flash::wait_lds_data_arrived<false>(6);
        // flash::raise_priority();
        // MMAC
        stage_id ^= 1;
        {
            int min_tile_n = 0;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                int k_tile_id = stage_id * 2 + min_tile_n;
                s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[0].f16x4[min_tile_k],
                    k_reg[k_tile_id].f16x4[min_tile_k],
                    s_reg[stage_id][min_tile_n].f32);
            }
        }
        {
            int min_tile_n = 1;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                int k_tile_id = stage_id * 2 + min_tile_n;
                s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[0].f16x4[min_tile_k],
                    k_reg[k_tile_id].f16x4[min_tile_k],
                    s_reg[stage_id][min_tile_n].f32);
            }
        }
        // flash::lower_priority();
        flash::wait_lds_data_arrived<false>(4);
        // flash::raise_priority();
        {
            int min_tile_n = 0;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                int k_tile_id = 4 + stage_id * 2 + min_tile_n;
                s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[1].f16x4[min_tile_k],
                    k_reg[k_tile_id].f16x4[min_tile_k],
                    s_reg[stage_id][min_tile_n].f32);
            }
        }
        {
            int min_tile_n = 1;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                int k_tile_id = 4 + stage_id * 2 + min_tile_n;
                s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[1].f16x4[min_tile_k],
                    k_reg[k_tile_id].f16x4[min_tile_k],
                    s_reg[stage_id][min_tile_n].f32);
            }
        }
        // flash::lower_priority();
        flash::wait_lds_data_arrived<false>(2);
        // flash::raise_priority();
        stage_id ^= 1;
        {
            int min_tile_n = 0;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                int k_tile_id = stage_id * 2 + min_tile_n;
                s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[0].f16x4[min_tile_k],
                    k_reg[k_tile_id].f16x4[min_tile_k],
                    s_reg[stage_id][min_tile_n].f32);
            }
        }
        {
            int min_tile_n = 1;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                int k_tile_id = stage_id * 2 + min_tile_n;
                s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[0].f16x4[min_tile_k],
                    k_reg[k_tile_id].f16x4[min_tile_k],
                    s_reg[stage_id][min_tile_n].f32);
            }
        }
        // flash::lower_priority();
        flash::wait_lds_data_arrived<false>(0);
        // int abc[2];
        // int index_topk = index_ptr[(n_loop * 64)];
        // int index_topk2 = index_ptr[(n_loop * 64) + 8];
        // int offset_m = index_topk * seqlen_k_stride;
        // int offset_m2 = index_topk2 * seqlen_k_stride;
        // auto g_abc = (reinterpret_cast<uint64_t>(v_faker + offset_m));
        // auto g_abc2 = (reinterpret_cast<uint64_t>(v_faker + offset_m2));
        // inline_s_load_dword(abc[0], g_abc, 0);
        // inline_s_load_dword(abc[1], g_abc2, 0);
        // flash::raise_priority();
        {
            int min_tile_n = 0;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                int k_tile_id = 4 + stage_id * 2 + min_tile_n;
                s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[1].f16x4[min_tile_k],
                    k_reg[k_tile_id].f16x4[min_tile_k],
                    s_reg[stage_id][min_tile_n].f32);
            }
        }
        {
            int min_tile_n = 1;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                int k_tile_id = 4 + stage_id * 2 + min_tile_n;
                s_reg[stage_id][min_tile_n].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[1].f16x4[min_tile_k],
                    k_reg[k_tile_id].f16x4[min_tile_k],
                    s_reg[stage_id][min_tile_n].f32);
            }
        }
        // flash::lower_priority();
    }
}
// Accumulates Q*K products for one group of 64 top-k selected K entries into s_reg.
//
// What the body does:
//   1. Clears s_reg.
//   2. Each lane reads four entries index_ptr[n_loop * 64 + g * 16 + tid / 4] (g = 0..3,
//      tid = threadIdx.x % 64). An entry equal to -1 is replaced by index_ptr[0], or by 0
//      when index_ptr[0] is -1 as well.
//   3. For i = 3 down to 0 the selected entry index_topk[i] is split into
//      (index_block, index_offset) with page_block_size and converted into a per-lane buffer
//      offset. Three column ranges of K are then streamed through k_lds and multiplied:
//        - scalar offset based on element 512, consumed with q_reg[16..17]
//        - scalar offset based on element 256, consumed with q_reg[8..15]
//        - scalar offset based on element 0,   consumed with q_reg[0..7]
//      Everything is accumulated into s_reg[i / 2][i % 2].
//
// k_stage_id toggles between the two 16 * 256-element regions of k_lds used as load targets.
// NOTE(review): the first toggle only happens when STAGES == 2; the indexing below looks like it
// assumes that configuration -- confirm with the callers.
// NOTE(review): the exact semantics of the flash:: wait helpers, inline_buffer_load_dwordx4_lds
// and DS_READ_MATRIX_32X32_B16 are defined elsewhere; the statement order here is kept exactly
// as in the original because the wait counts depend on it.
//
// Parameters not referenced by this body (kept for signature compatibility): qv_ptr, q_ptr,
// v_faker, q_lds, v_lds, seqlen_qv_stride, __seqlen_q_stride, seqlen_v_stride, batch_stride_v,
// max_seq_q_offset, max_seq_k_offset.
template <int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int STAGES,
          typename Element, typename ElementAccum, bool Is_even_MN, bool Is_FlashMLA>
__forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_777(
    vec4_uint qv_ptr, vec4_uint q_ptr, vec4_uint k_ptr, Element *v_faker, Element *q_lds, Element *k_lds,
    Element *v_lds, union_vec4_f16x2<Element> q_reg[(WARP_M * kBlockK) / (16 * 32) * (kHeadDim / kBlockK)],
    vec4_Accum<ElementAccum> s_reg[(WARP_M / 16) * (kBlockN / 32)][2], int warp_id, int seqlen_qv_stride,
    int __seqlen_q_stride, int seqlen_k_stride, int seqlen_v_stride, int *index_ptr, int batch_stride_k,
    int batch_stride_v, int page_block_size, int n_loop, int max_seq_q_offset = 0, int max_seq_k_offset = 0)
{
    // Supported configuration. The message is now a real static_assert message (it used to be
    // and-ed into the condition), and the kBlockK text matches the value that is checked.
    static_assert(kBlockK == 64, "To simplify, only kBlockK = 64 is supported!");
    static_assert(WARP_M == 16, "To simplify, only WARP_M = 16 is supported!");
    static_assert(WARP_N == 64, "To simplify, only WARP_N = 64 is supported!");

    constexpr int WARP_NUM = kBlockM / WARP_M;
    constexpr int kHeadDim_OPT = 64;
    constexpr int K_LDS_LOAD_NUM = (kHeadDim_OPT * WARP_N) / (32 * 16);
    constexpr int K_LOAD_REQUESTS = K_LDS_LOAD_NUM / WARP_NUM;
    constexpr int ELEMENT_BYTES = sizeof(Element);

    // Clear the score accumulators; the sched barriers keep the clearing from being moved.
    __builtin_amdgcn_sched_barrier(0);
    if constexpr (kBlockN == 128)
    {
        inline_vgpr4_init_zero_4x4x4(s_reg);
    }
    else
    {
        for (int i = 0; i < (WARP_M / 16) * (kBlockN / 32); ++i)
        {
            for (int j = 0; j < 2; ++j)
            {
                s_reg[i][j].u64[0] = 0;
                s_reg[i][j].u64[1] = 0;
            }
        }
    }
    __builtin_amdgcn_sched_barrier(0);

    int tid = threadIdx.x % 64;

    // Four selected indices per lane, one from each 16-entry quarter of the 64-entry group.
    int index_topk[4];
    index_topk[0] = index_ptr[(n_loop * 64) + 0 * 16 + (tid / 4)];
    index_topk[1] = index_ptr[(n_loop * 64) + 1 * 16 + (tid / 4)];
    index_topk[2] = index_ptr[(n_loop * 64) + 2 * 16 + (tid / 4)];
    index_topk[3] = index_ptr[(n_loop * 64) + 3 * 16 + (tid / 4)];

    // -1 marks an empty slot: substitute the first index of the list (or 0 if that is empty too).
    int fallback_index = index_ptr[0] == -1 ? 0 : index_ptr[0];
    index_topk[0] = (index_topk[0] == -1) ? fallback_index : index_topk[0];
    index_topk[1] = (index_topk[1] == -1) ? fallback_index : index_topk[1];
    index_topk[2] = (index_topk[2] == -1) ? fallback_index : index_topk[2];
    index_topk[3] = (index_topk[3] == -1) ? fallback_index : index_topk[3];

    // Prepare the K registers.
    union_vec4_f16x2<Element> k_reg[STAGES * (32 * kBlockK) / (32 * 32) * 2];
    // Compute the starting offset of k_lds.
    int k_lds_base = reinterpret_cast<size_t>(k_lds);

#pragma unroll 1
    for (int i = 3; i >= 0; i--)
    {
        int k_stage_id = 0;
        int lds_offset = __builtin_amdgcn_readfirstlane((warp_id * 32 * 16) * ELEMENT_BYTES / 4);
        int lds_offset_2;
        int index_block = index_topk[i] / page_block_size;
        int index_offset = index_topk[i] - index_block * page_block_size;
        int g_offset_v = ((4 - (tid / 8) % 4) * 4 + tid % 4 * 4) % 16 +
                         (index_block * batch_stride_k + index_offset * seqlen_k_stride) * ELEMENT_BYTES / 4;
        int g_offset_s = 512 * ELEMENT_BYTES / 4 + warp_id * 16;
        int g_offset_s_2;

        // Range based at element 512: only warps 0 and 1 issue this load.
        flash::wait_all_warp_arrived();
        if (warp_id < 2)
        {
            inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, g_offset_s, g_offset_v);
        }
        if constexpr (STAGES == 2)
        {
            k_stage_id ^= 1;
        }

        // Range based at element 256, loaded into the other LDS region.
        lds_offset =
            __builtin_amdgcn_readfirstlane((k_stage_id * 16 * 256 + warp_id * 32 * 16) * ELEMENT_BYTES / 4);
        lds_offset_2 = __builtin_amdgcn_readfirstlane(
            (k_stage_id * 16 * 256 + warp_id * 32 * 16 + 128 * 16) * ELEMENT_BYTES / 4);
        g_offset_s = 256 * ELEMENT_BYTES / 4 + warp_id * 16;
        g_offset_s_2 = 256 * ELEMENT_BYTES / 4 + warp_id * 16 + 64;
        flash::wait_all_warp_arrived();
        inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, g_offset_s, g_offset_v);
        inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset_2, g_offset_s_2, g_offset_v);
        flash::wait_buffer_data_arrived<true>(K_LOAD_REQUESTS);
        k_stage_id ^= 1;
        int stage_id = 0;
        // K DS
        {
            int k_lds_load_offset = k_lds_base + (k_stage_id * 16 * 256) * ELEMENT_BYTES;
            DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16,
                                     true);
        }
        // Wait DS
        flash::wait_lds_data_arrived<false>(0);
        // flash::raise_priority();
        // MMAC with q_reg[16..17]
        {
            int min_tile_n = 0;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
            {
                int k_tile_id = stage_id * 2 + min_tile_n;
                s_reg[i / 2][i % 2].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[16].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k], s_reg[i / 2][i % 2].f32);
            }
        }
        {
            int min_tile_n = 1;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
            {
                int k_tile_id = stage_id * 2 + min_tile_n;
                s_reg[i / 2][i % 2].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[17].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k], s_reg[i / 2][i % 2].f32);
            }
        }
        // flash::lower_priority();

        // Range based at element 0, loaded while the element-256 range is still pending use.
        lds_offset =
            __builtin_amdgcn_readfirstlane((k_stage_id * 16 * 256 + warp_id * 32 * 16) * ELEMENT_BYTES / 4);
        lds_offset_2 = __builtin_amdgcn_readfirstlane(
            (k_stage_id * 16 * 256 + warp_id * 32 * 16 + 128 * 16) * ELEMENT_BYTES / 4);
        g_offset_s = 0 * ELEMENT_BYTES / 4 + warp_id * 16;
        g_offset_s_2 = 0 * ELEMENT_BYTES / 4 + warp_id * 16 + 64;
        flash::wait_all_warp_arrived();
        inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, g_offset_s, g_offset_v);
        inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset_2, g_offset_s_2, g_offset_v);
        flash::wait_buffer_data_arrived<true>(K_LOAD_REQUESTS);
        k_stage_id ^= 1;
        stage_id = 0;
        // K DS
        {
            int k_lds_load_offset = k_lds_base + (k_stage_id * 16 * 256) * ELEMENT_BYTES;
            int k_lds_load_offset_2 = k_lds_base + (k_stage_id * 16 * 256 + 16 * 64) * ELEMENT_BYTES;
            DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16,
                                     true);
            DS_READ_MATRIX_32X32_B16(k_lds_load_offset_2, k_reg[4 + stage_id * 2].f16,
                                     k_reg[4 + stage_id * 2 + 1].f16, true);
        }
        // K DS PRE
        stage_id ^= 1;
        {
            int k_lds_load_offset = k_lds_base + (k_stage_id * 16 * 256 + 16 * 128) * ELEMENT_BYTES;
            int k_lds_load_offset_2 = k_lds_base + (k_stage_id * 16 * 256 + 16 * 192) * ELEMENT_BYTES;
            DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16,
                                     true);
            DS_READ_MATRIX_32X32_B16(k_lds_load_offset_2, k_reg[4 + stage_id * 2].f16,
                                     k_reg[4 + stage_id * 2 + 1].f16, true);
        }
        // Wait DS
        flash::wait_lds_data_arrived<false>(6);
        // flash::raise_priority();
        // MMAC with q_reg[8..15]
        stage_id ^= 1;
        {
            int min_tile_n = 0;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
            {
                int k_tile_id = stage_id * 2 + min_tile_n;
                s_reg[i / 2][i % 2].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[8].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k], s_reg[i / 2][i % 2].f32);
            }
        }
        {
            int min_tile_n = 1;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
            {
                int k_tile_id = stage_id * 2 + min_tile_n;
                s_reg[i / 2][i % 2].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[9].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k], s_reg[i / 2][i % 2].f32);
            }
        }
        // flash::lower_priority();
        flash::wait_lds_data_arrived<false>(4);
        // flash::raise_priority();
        {
            int min_tile_n = 0;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
            {
                int k_tile_id = 4 + stage_id * 2 + min_tile_n;
                s_reg[i / 2][i % 2].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[10].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k], s_reg[i / 2][i % 2].f32);
            }
        }
        {
            int min_tile_n = 1;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
            {
                int k_tile_id = 4 + stage_id * 2 + min_tile_n;
                s_reg[i / 2][i % 2].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[11].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k], s_reg[i / 2][i % 2].f32);
            }
        }
        // flash::lower_priority();
        flash::wait_lds_data_arrived<false>(2);
        // flash::raise_priority();
        stage_id ^= 1;
        {
            int min_tile_n = 0;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
            {
                int k_tile_id = stage_id * 2 + min_tile_n;
                s_reg[i / 2][i % 2].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[12].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k], s_reg[i / 2][i % 2].f32);
            }
        }
        {
            int min_tile_n = 1;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
            {
                int k_tile_id = stage_id * 2 + min_tile_n;
                s_reg[i / 2][i % 2].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[13].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k], s_reg[i / 2][i % 2].f32);
            }
        }
        // flash::lower_priority();
        flash::wait_lds_data_arrived<false>(0);
        // flash::raise_priority();
        {
            int min_tile_n = 0;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
            {
                int k_tile_id = 4 + stage_id * 2 + min_tile_n;
                s_reg[i / 2][i % 2].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[14].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k], s_reg[i / 2][i % 2].f32);
            }
        }
        {
            int min_tile_n = 1;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
            {
                int k_tile_id = 4 + stage_id * 2 + min_tile_n;
                s_reg[i / 2][i % 2].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[15].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k], s_reg[i / 2][i % 2].f32);
            }
        }
        // flash::lower_priority();

        // Consume the range based at element 0 once every outstanding buffer load has landed.
        flash::wait_buffer_data_arrived<true>(0);
        k_stage_id ^= 1;
        stage_id = 0;
        // K DS
        {
            int k_lds_load_offset = k_lds_base + (k_stage_id * 16 * 256) * ELEMENT_BYTES;
            int k_lds_load_offset_2 = k_lds_base + (k_stage_id * 16 * 256 + 16 * 64) * ELEMENT_BYTES;
            DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16,
                                     true);
            DS_READ_MATRIX_32X32_B16(k_lds_load_offset_2, k_reg[4 + stage_id * 2].f16,
                                     k_reg[4 + stage_id * 2 + 1].f16, true);
        }
        // K DS PRE
        stage_id ^= 1;
        {
            int k_lds_load_offset = k_lds_base + (k_stage_id * 16 * 256 + 16 * 128) * ELEMENT_BYTES;
            int k_lds_load_offset_2 = k_lds_base + (k_stage_id * 16 * 256 + 16 * 192) * ELEMENT_BYTES;
            DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16,
                                     true);
            DS_READ_MATRIX_32X32_B16(k_lds_load_offset_2, k_reg[4 + stage_id * 2].f16,
                                     k_reg[4 + stage_id * 2 + 1].f16, true);
        }
        // Wait DS
        flash::wait_lds_data_arrived<false>(6);
        // flash::raise_priority();
        // MMAC with q_reg[0..7]
        stage_id ^= 1;
        {
            int min_tile_n = 0;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
            {
                int k_tile_id = stage_id * 2 + min_tile_n;
                s_reg[i / 2][i % 2].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[0].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k], s_reg[i / 2][i % 2].f32);
            }
        }
        {
            int min_tile_n = 1;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
            {
                int k_tile_id = stage_id * 2 + min_tile_n;
                s_reg[i / 2][i % 2].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[1].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k], s_reg[i / 2][i % 2].f32);
            }
        }
        // flash::lower_priority();
        flash::wait_lds_data_arrived<false>(4);
        // flash::raise_priority();
        {
            int min_tile_n = 0;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
            {
                int k_tile_id = 4 + stage_id * 2 + min_tile_n;
                s_reg[i / 2][i % 2].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[2].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k], s_reg[i / 2][i % 2].f32);
            }
        }
        {
            int min_tile_n = 1;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
            {
                int k_tile_id = 4 + stage_id * 2 + min_tile_n;
                s_reg[i / 2][i % 2].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[3].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k], s_reg[i / 2][i % 2].f32);
            }
        }
        // flash::lower_priority();
        flash::wait_lds_data_arrived<false>(2);
        // flash::raise_priority();
        stage_id ^= 1;
        {
            int min_tile_n = 0;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
            {
                int k_tile_id = stage_id * 2 + min_tile_n;
                s_reg[i / 2][i % 2].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[4].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k], s_reg[i / 2][i % 2].f32);
            }
        }
        {
            int min_tile_n = 1;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
            {
                int k_tile_id = stage_id * 2 + min_tile_n;
                s_reg[i / 2][i % 2].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[5].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k], s_reg[i / 2][i % 2].f32);
            }
        }
        // flash::lower_priority();
        flash::wait_lds_data_arrived<false>(0);
        // flash::raise_priority();
        {
            int min_tile_n = 0;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
            {
                int k_tile_id = 4 + stage_id * 2 + min_tile_n;
                s_reg[i / 2][i % 2].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[6].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k], s_reg[i / 2][i % 2].f32);
            }
        }
        {
            int min_tile_n = 1;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
            {
                int k_tile_id = 4 + stage_id * 2 + min_tile_n;
                s_reg[i / 2][i % 2].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[7].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k], s_reg[i / 2][i % 2].f32);
            }
        }
        // flash::lower_priority();
    }
}
// Accumulates Q*K products for one group of 64 top-k selected K entries into s_reg, with the
// buffer loads of the next half software-pipelined against the MMACs of the current half.
//
// What the body does:
//   1. Clears s_reg.
//   2. Each lane reads four entries index_ptr[n_loop * 64 + g * 16 + tid / 4] (g = 0..3,
//      tid = threadIdx.x % 64); an entry equal to -1 is replaced by 0.
//   3. Prologue: issues the loads for index_topk[3] with the scalar offset based at element 256
//      into the first LDS region.
//   4. Loop i = 6 .. 0: issues the loads for index_topk[i / 2] with the scalar offset based at
//      element (i % 2) * 256 into the other LDS region, then consumes the half loaded one step
//      earlier: score_id = (i + 1) >> 1 selects the accumulator s_reg[score_id / 2][score_id % 2]
//      and q_reg[((i + 1) % 2) * 8 + 0..7] supplies the Q fragments.
//   5. Epilogue: consumes the last loaded half (index_topk[0], offset 0) with q_reg[0..7] into
//      s_reg[0][0].
//
// k_stage_id toggles between the two 16 * 256-element regions of k_lds used as load targets.
// NOTE(review): the exact semantics of the flash:: wait/priority helpers,
// inline_buffer_load_dwordx4_lds and DS_READ_MATRIX_32X32_B16 are defined elsewhere; the statement
// order here is kept exactly as in the original because the wait counts depend on it.
//
// Parameters not referenced by this body (kept for signature compatibility): qv_ptr, q_ptr,
// v_faker, q_lds, v_lds, seqlen_qv_stride, __seqlen_q_stride, seqlen_v_stride, batch_stride_v,
// max_seq_q_offset, max_seq_k_offset.
template <int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int STAGES,
          typename Element, typename ElementAccum, bool Is_even_MN, bool Is_FlashMLA>
__forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_888(
    vec4_uint qv_ptr, vec4_uint q_ptr, vec4_uint k_ptr, Element *v_faker, Element *q_lds, Element *k_lds,
    Element *v_lds, union_vec4_f16x2<Element> q_reg[(WARP_M * kBlockK) / (16 * 32) * (kHeadDim / kBlockK)],
    vec4_Accum<ElementAccum> s_reg[(WARP_M / 16) * (kBlockN / 32)][2], int warp_id, int seqlen_qv_stride,
    int __seqlen_q_stride, int seqlen_k_stride, int seqlen_v_stride, int *index_ptr, int batch_stride_k,
    int batch_stride_v, int page_block_size, int n_loop, int max_seq_q_offset = 0, int max_seq_k_offset = 0)
{
    // Supported configuration. The message is now a real static_assert message (it used to be
    // and-ed into the condition), and the kBlockK text matches the value that is checked.
    static_assert(kBlockK == 64, "To simplify, only kBlockK = 64 is supported!");
    static_assert(WARP_M == 16, "To simplify, only WARP_M = 16 is supported!");
    static_assert(WARP_N == 64, "To simplify, only WARP_N = 64 is supported!");

    constexpr int WARP_NUM = kBlockM / WARP_M;
    // NOTE(review): only kHeadDim == 512 is reduced to 64 here; any other kHeadDim enlarges
    // K_LOAD_REQUESTS and therefore the outstanding-load count tolerated by the wait below --
    // confirm this is intended for the instantiations in use.
    constexpr int kHeadDim_OPT = (kHeadDim == 512) ? 64 : kHeadDim;
    constexpr int K_LDS_LOAD_NUM = (kHeadDim_OPT * WARP_N) / (32 * 16);
    constexpr int K_LOAD_REQUESTS = K_LDS_LOAD_NUM / WARP_NUM;
    constexpr int ELEMENT_BYTES = sizeof(Element);

    // Clear the score accumulators; the sched barriers keep the clearing from being moved.
    __builtin_amdgcn_sched_barrier(0);
    if constexpr (kBlockN == 128)
    {
        inline_vgpr4_init_zero_4x4x4(s_reg);
    }
    else
    {
        for (int i = 0; i < (WARP_M / 16) * (kBlockN / 32); ++i)
        {
            for (int j = 0; j < 2; ++j)
            {
                s_reg[i][j].u64[0] = 0;
                s_reg[i][j].u64[1] = 0;
            }
        }
    }
    __builtin_amdgcn_sched_barrier(0);

    int tid = threadIdx.x % 64;

    // Four selected indices per lane, one from each 16-entry quarter of the 64-entry group.
    int index_topk[4];
    index_topk[0] = index_ptr[(n_loop * 64) + 0 * 16 + (tid / 4)];
    index_topk[1] = index_ptr[(n_loop * 64) + 1 * 16 + (tid / 4)];
    index_topk[2] = index_ptr[(n_loop * 64) + 2 * 16 + (tid / 4)];
    index_topk[3] = index_ptr[(n_loop * 64) + 3 * 16 + (tid / 4)];

    // -1 marks an empty slot: fall back to index 0.
    index_topk[0] = (index_topk[0] == -1) ? 0 : index_topk[0];
    index_topk[1] = (index_topk[1] == -1) ? 0 : index_topk[1];
    index_topk[2] = (index_topk[2] == -1) ? 0 : index_topk[2];
    index_topk[3] = (index_topk[3] == -1) ? 0 : index_topk[3];

    // Prepare the K registers.
    union_vec4_f16x2<Element> k_reg[STAGES * (32 * kBlockK) / (32 * 32) * 2];
    // Compute the starting offset of k_lds.
    int k_lds_base = reinterpret_cast<size_t>(k_lds);
    int k_stage_id = 0;
    int stage_id;

    // Prologue: first half to load is index_topk[3] at the offset based on element 256.
    int index_block = index_topk[3] / page_block_size;
    int index_offset = index_topk[3] - index_block * page_block_size;
    int g_offset_v = ((4 - (tid / 8) % 4) * 4 + tid % 4 * 4) % 16 +
                     (index_block * batch_stride_k + index_offset * seqlen_k_stride) * ELEMENT_BYTES / 4;
    int lds_offset =
        __builtin_amdgcn_readfirstlane((k_stage_id * 16 * 256 + warp_id * 32 * 16) * ELEMENT_BYTES / 4);
    int lds_offset_2 = __builtin_amdgcn_readfirstlane(
        (k_stage_id * 16 * 256 + warp_id * 32 * 16 + 128 * 16) * ELEMENT_BYTES / 4);
    int g_offset_s = 256 * ELEMENT_BYTES / 4 + warp_id * 16;
    int g_offset_s_2 = 256 * ELEMENT_BYTES / 4 + warp_id * 16 + 64;
    flash::wait_all_warp_arrived();
    inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, g_offset_s, g_offset_v);
    inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset_2, g_offset_s_2, g_offset_v);
    k_stage_id ^= 1;

    // #pragma unroll 1
    for (int i = 6; i >= 0; i--)
    {
        // Accumulator slot of the half that is consumed in this step (loaded one step earlier).
        int score_id = (i + 1) >> 1;

        // Issue the loads for the half that will be consumed in the next step.
        index_block = index_topk[i / 2] / page_block_size;
        index_offset = index_topk[i / 2] - index_block * page_block_size;
        g_offset_v = ((4 - (tid / 8) % 4) * 4 + tid % 4 * 4) % 16 +
                     (index_block * batch_stride_k + index_offset * seqlen_k_stride) * ELEMENT_BYTES / 4;
        lds_offset =
            __builtin_amdgcn_readfirstlane((k_stage_id * 16 * 256 + warp_id * 32 * 16) * ELEMENT_BYTES / 4);
        lds_offset_2 = __builtin_amdgcn_readfirstlane(
            (k_stage_id * 16 * 256 + warp_id * 32 * 16 + 128 * 16) * ELEMENT_BYTES / 4);
        g_offset_s = (i % 2) * 256 * ELEMENT_BYTES / 4 + warp_id * 16;
        g_offset_s_2 = (i % 2) * 256 * ELEMENT_BYTES / 4 + warp_id * 16 + 64;
        flash::wait_all_warp_arrived();
        inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, g_offset_s, g_offset_v);
        inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset_2, g_offset_s_2, g_offset_v);
        flash::wait_buffer_data_arrived<true>(K_LOAD_REQUESTS);
        k_stage_id ^= 1;
        stage_id = 0;
        // K DS
        {
            int k_lds_load_offset = k_lds_base + (k_stage_id * 16 * 256) * ELEMENT_BYTES;
            int k_lds_load_offset_2 = k_lds_base + (k_stage_id * 16 * 256 + 16 * 64) * ELEMENT_BYTES;
            DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16,
                                     true);
            DS_READ_MATRIX_32X32_B16(k_lds_load_offset_2, k_reg[4 + stage_id * 2].f16,
                                     k_reg[4 + stage_id * 2 + 1].f16, true);
        }
        // K DS PRE
        stage_id ^= 1;
        {
            int k_lds_load_offset = k_lds_base + (k_stage_id * 16 * 256 + 16 * 128) * ELEMENT_BYTES;
            int k_lds_load_offset_2 = k_lds_base + (k_stage_id * 16 * 256 + 16 * 192) * ELEMENT_BYTES;
            DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16,
                                     true);
            DS_READ_MATRIX_32X32_B16(k_lds_load_offset_2, k_reg[4 + stage_id * 2].f16,
                                     k_reg[4 + stage_id * 2 + 1].f16, true);
        }
        // Wait DS
        flash::wait_lds_data_arrived<false>(6);
        flash::raise_priority();
        // MMAC
        stage_id ^= 1;
        {
            int min_tile_n = 0;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
            {
                int k_tile_id = stage_id * 2 + min_tile_n;
                s_reg[score_id / 2][score_id % 2].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[((i + 1) % 2) * 8 + 0].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k],
                    s_reg[score_id / 2][score_id % 2].f32);
            }
        }
        {
            int min_tile_n = 1;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
            {
                int k_tile_id = stage_id * 2 + min_tile_n;
                s_reg[score_id / 2][score_id % 2].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[((i + 1) % 2) * 8 + 1].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k],
                    s_reg[score_id / 2][score_id % 2].f32);
            }
        }
        flash::lower_priority();
        flash::wait_lds_data_arrived<false>(4);
        flash::raise_priority();
        {
            int min_tile_n = 0;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
            {
                int k_tile_id = 4 + stage_id * 2 + min_tile_n;
                s_reg[score_id / 2][score_id % 2].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[((i + 1) % 2) * 8 + 2].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k],
                    s_reg[score_id / 2][score_id % 2].f32);
            }
        }
        {
            int min_tile_n = 1;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
            {
                int k_tile_id = 4 + stage_id * 2 + min_tile_n;
                s_reg[score_id / 2][score_id % 2].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[((i + 1) % 2) * 8 + 3].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k],
                    s_reg[score_id / 2][score_id % 2].f32);
            }
        }
        flash::lower_priority();
        flash::wait_lds_data_arrived<false>(2);
        flash::raise_priority();
        stage_id ^= 1;
        {
            int min_tile_n = 0;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
            {
                int k_tile_id = stage_id * 2 + min_tile_n;
                s_reg[score_id / 2][score_id % 2].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[((i + 1) % 2) * 8 + 4].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k],
                    s_reg[score_id / 2][score_id % 2].f32);
            }
        }
        {
            int min_tile_n = 1;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
            {
                int k_tile_id = stage_id * 2 + min_tile_n;
                s_reg[score_id / 2][score_id % 2].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[((i + 1) % 2) * 8 + 5].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k],
                    s_reg[score_id / 2][score_id % 2].f32);
            }
        }
        flash::lower_priority();
        flash::wait_lds_data_arrived<false>(0);
        flash::raise_priority();
        {
            int min_tile_n = 0;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
            {
                int k_tile_id = 4 + stage_id * 2 + min_tile_n;
                s_reg[score_id / 2][score_id % 2].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[((i + 1) % 2) * 8 + 6].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k],
                    s_reg[score_id / 2][score_id % 2].f32);
            }
        }
        {
            int min_tile_n = 1;
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
            {
                int k_tile_id = 4 + stage_id * 2 + min_tile_n;
                s_reg[score_id / 2][score_id % 2].f32 = mmac_4interleave<Element, ElementAccum>(
                    q_reg[((i + 1) % 2) * 8 + 7].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k],
                    s_reg[score_id / 2][score_id % 2].f32);
            }
        }
        flash::lower_priority();
    }

    // Epilogue: consume the half issued in the last loop step (index_topk[0], offset 0).
    flash::wait_buffer_data_arrived<true>(0);
    k_stage_id ^= 1;
    stage_id = 0;
    // K DS
    {
        int k_lds_load_offset = k_lds_base + (k_stage_id * 16 * 256) * ELEMENT_BYTES;
        int k_lds_load_offset_2 = k_lds_base + (k_stage_id * 16 * 256 + 16 * 64) * ELEMENT_BYTES;
        DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
        DS_READ_MATRIX_32X32_B16(k_lds_load_offset_2, k_reg[4 + stage_id * 2].f16,
                                 k_reg[4 + stage_id * 2 + 1].f16, true);
    }
    // K DS PRE
    stage_id ^= 1;
    {
        int k_lds_load_offset = k_lds_base + (k_stage_id * 16 * 256 + 16 * 128) * ELEMENT_BYTES;
        int k_lds_load_offset_2 = k_lds_base + (k_stage_id * 16 * 256 + 16 * 192) * ELEMENT_BYTES;
        DS_READ_MATRIX_32X32_B16(k_lds_load_offset, k_reg[stage_id * 2].f16, k_reg[stage_id * 2 + 1].f16, true);
        DS_READ_MATRIX_32X32_B16(k_lds_load_offset_2, k_reg[4 + stage_id * 2].f16,
                                 k_reg[4 + stage_id * 2 + 1].f16, true);
    }
    // Wait DS
    flash::wait_lds_data_arrived<false>(6);
    flash::raise_priority();
    // MMAC
    stage_id ^= 1;
    {
        int min_tile_n = 0;
#pragma unroll
        for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
        {
            int k_tile_id = stage_id * 2 + min_tile_n;
            s_reg[0][0].f32 = mmac_4interleave<Element, ElementAccum>(
                q_reg[0].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k], s_reg[0][0].f32);
        }
    }
    {
        int min_tile_n = 1;
#pragma unroll
        for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
        {
            int k_tile_id = stage_id * 2 + min_tile_n;
            s_reg[0][0].f32 = mmac_4interleave<Element, ElementAccum>(
                q_reg[1].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k], s_reg[0][0].f32);
        }
    }
    flash::lower_priority();
    flash::wait_lds_data_arrived<false>(4);
    flash::raise_priority();
    {
        int min_tile_n = 0;
#pragma unroll
        for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
        {
            int k_tile_id = 4 + stage_id * 2 + min_tile_n;
            s_reg[0][0].f32 = mmac_4interleave<Element, ElementAccum>(
                q_reg[2].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k], s_reg[0][0].f32);
        }
    }
    {
        int min_tile_n = 1;
#pragma unroll
        for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
        {
            int k_tile_id = 4 + stage_id * 2 + min_tile_n;
            s_reg[0][0].f32 = mmac_4interleave<Element, ElementAccum>(
                q_reg[3].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k], s_reg[0][0].f32);
        }
    }
    flash::lower_priority();
    flash::wait_lds_data_arrived<false>(2);
    flash::raise_priority();
    stage_id ^= 1;
    {
        int min_tile_n = 0;
#pragma unroll
        for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
        {
            int k_tile_id = stage_id * 2 + min_tile_n;
            s_reg[0][0].f32 = mmac_4interleave<Element, ElementAccum>(
                q_reg[4].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k], s_reg[0][0].f32);
        }
    }
    {
        int min_tile_n = 1;
#pragma unroll
        for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
        {
            int k_tile_id = stage_id * 2 + min_tile_n;
            s_reg[0][0].f32 = mmac_4interleave<Element, ElementAccum>(
                q_reg[5].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k], s_reg[0][0].f32);
        }
    }
    flash::lower_priority();
    flash::wait_lds_data_arrived<false>(0);
    flash::raise_priority();
    {
        int min_tile_n = 0;
#pragma unroll
        for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
        {
            int k_tile_id = 4 + stage_id * 2 + min_tile_n;
            s_reg[0][0].f32 = mmac_4interleave<Element, ElementAccum>(
                q_reg[6].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k], s_reg[0][0].f32);
        }
    }
    {
        int min_tile_n = 1;
#pragma unroll
        for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
        {
            int k_tile_id = 4 + stage_id * 2 + min_tile_n;
            s_reg[0][0].f32 = mmac_4interleave<Element, ElementAccum>(
                q_reg[7].f16x4[min_tile_k], k_reg[k_tile_id].f16x4[min_tile_k], s_reg[0][0].f32);
        }
    }
    flash::lower_priority();
}
template
<
int
kHeadDim
,
int
kHeadDimV
,
int
kBlockM
,
int
kBlockN
,
int
kBlockK
,
int
WARP_M
,
int
WARP_N
,
int
STAGES
,
typename
Element
,
typename
ElementAccum
,
bool
Is_even_MN
,
bool
Is_FlashMLA
>
__forceinline__
__device__
void
qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_999
(
vec4_uint
qv_ptr
,
vec4_uint
q_ptr
,
vec4_uint
k_ptr
,
Element
*
v_faker
,
Element
*
q_lds
,
Element
*
k_lds
,
Element
*
v_lds
,
union_vec4_f16x2
<
Element
>
q_reg
[(
WARP_M
*
kBlockK
)
/
(
16
*
32
)
*
(
kHeadDim
/
kBlockK
)],
vec4_Accum
<
ElementAccum
>
s_reg
[(
WARP_M
/
16
)
*
(
kBlockN
/
32
)][
2
],
int
warp_id
,
int
seqlen_qv_stride
,
int
__seqlen_q_stride
,
int
seqlen_k_stride
,
int
seqlen_v_stride
,
int
*
index_ptr
,
int
batch_stride_k
,
int
batch_stride_v
,
int
page_block_size
,
int
n_loop
,
int
max_seq_q_offset
=
0
,
int
max_seq_k_offset
=
0
)
{
// Simplify
static_assert
(
kBlockK
==
64
and
"To simplify, only kBlockK = 32 is supported!"
);
static_assert
(
WARP_M
==
16
and
"To simplify, only WARP_M = 16 is supported!"
);
static_assert
(
WARP_N
==
64
and
"To simplify, only WARP_N = 64 is supported!"
);
constexpr
int
WARP_NUM
=
kBlockM
/
WARP_M
;
constexpr
int
kHeadDim_OPT
=
(
kHeadDim
==
576
)
?
64
:
kHeadDim
;
constexpr
int
Q_LDS_LOAD_NUM
=
(
kBlockM
*
kBlockK
)
/
(
16
*
32
);
constexpr
int
Q_LOAD_REQUESTS
=
Q_LDS_LOAD_NUM
/
WARP_NUM
;
constexpr
int
K_LDS_LOAD_NUM
=
(
kHeadDim_OPT
*
WARP_N
)
/
(
32
*
16
);
constexpr
int
K_LOAD_REQUESTS
=
K_LDS_LOAD_NUM
/
WARP_NUM
;
constexpr
int
ELEMENT_BYTES
=
sizeof
(
Element
);
__builtin_amdgcn_sched_barrier
(
0
);
if
constexpr
(
kBlockN
==
128
)
{
inline_vgpr4_init_zero_4x4x4
(
s_reg
);
}
else
{
for
(
int
i
=
0
;
i
<
(
WARP_M
/
16
)
*
(
kBlockN
/
32
);
++
i
)
{
for
(
int
j
=
0
;
j
<
2
;
++
j
)
{
s_reg
[
i
][
j
].
u64
[
0
]
=
0.0
f
;
s_reg
[
i
][
j
].
u64
[
1
]
=
0.0
f
;
}
}
}
__builtin_amdgcn_sched_barrier
(
0
);
int
tid
=
threadIdx
.
x
%
64
;
int
index_topk
[
4
];
index_topk
[
0
]
=
index_ptr
[(
n_loop
*
64
)
+
0
*
16
+
(
tid
/
4
)];
index_topk
[
1
]
=
index_ptr
[(
n_loop
*
64
)
+
1
*
16
+
(
tid
/
4
)];
index_topk
[
2
]
=
index_ptr
[(
n_loop
*
64
)
+
2
*
16
+
(
tid
/
4
)];
index_topk
[
3
]
=
index_ptr
[(
n_loop
*
64
)
+
3
*
16
+
(
tid
/
4
)];
index_topk
[
0
]
=
(
index_topk
[
0
]
==
-
1
)
?
0
:
index_topk
[
0
];
index_topk
[
1
]
=
(
index_topk
[
1
]
==
-
1
)
?
0
:
index_topk
[
1
];
index_topk
[
2
]
=
(
index_topk
[
2
]
==
-
1
)
?
0
:
index_topk
[
2
];
index_topk
[
3
]
=
(
index_topk
[
3
]
==
-
1
)
?
0
:
index_topk
[
3
];
// 准备 q,k 寄存器
union_vec4_f16x2
<
Element
>
k_reg
[
STAGES
*
(
32
*
kBlockK
)
/
(
32
*
32
)
*
2
];
// 计算 q_lds,k_lds 的起始偏移量
int
k_lds_base
=
reinterpret_cast
<
size_t
>
(
k_lds
);
int
k_stage_id
=
0
;
int
lds_offset
=
__builtin_amdgcn_readfirstlane
((
warp_id
*
32
*
16
)
*
ELEMENT_BYTES
/
4
);
int
lds_offset_2
;
int
index_block
=
index_topk
[
3
]
/
page_block_size
;
int
index_offset
=
index_topk
[
3
]
-
index_block
*
page_block_size
;
int
g_offset_v
=
((
4
-
(
tid
/
8
)
%
4
)
*
4
+
tid
%
4
*
4
)
%
16
+
(
index_block
*
batch_stride_k
+
index_offset
*
seqlen_k_stride
)
*
ELEMENT_BYTES
/
4
;
int
g_offset_s
=
512
*
ELEMENT_BYTES
/
4
+
warp_id
*
16
;
int
g_offset_s_2
;
flash
::
wait_all_warp_arrived
();
if
(
warp_id
<
2
)
{
inline_buffer_load_dwordx4_lds
(
k_lds
,
k_ptr
,
lds_offset
,
g_offset_s
,
g_offset_v
);
}
if
constexpr
(
STAGES
==
2
)
{
k_stage_id
^=
1
;
}
{
#pragma unroll 1
for
(
int
i
=
3
;
i
>=
0
;
i
--
)
{
lds_offset
=
__builtin_amdgcn_readfirstlane
((
k_stage_id
*
16
*
256
+
warp_id
*
32
*
16
)
*
ELEMENT_BYTES
/
4
);
lds_offset_2
=
__builtin_amdgcn_readfirstlane
((
k_stage_id
*
16
*
256
+
warp_id
*
32
*
16
+
128
*
16
)
*
ELEMENT_BYTES
/
4
);
g_offset_s
=
256
*
ELEMENT_BYTES
/
4
+
warp_id
*
16
;
g_offset_s_2
=
256
*
ELEMENT_BYTES
/
4
+
warp_id
*
16
+
64
;
flash
::
wait_all_warp_arrived
();
inline_buffer_load_dwordx4_lds
(
k_lds
,
k_ptr
,
lds_offset
,
g_offset_s
,
g_offset_v
);
inline_buffer_load_dwordx4_lds
(
k_lds
,
k_ptr
,
lds_offset_2
,
g_offset_s_2
,
g_offset_v
);
flash
::
wait_buffer_data_arrived
<
true
>
(
K_LOAD_REQUESTS
);
k_stage_id
^=
1
;
int
stage_id
=
0
;
// K DS
{
int
k_lds_load_offset
=
k_lds_base
+
(
k_stage_id
*
16
*
256
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16
(
k_lds_load_offset
,
k_reg
[
stage_id
*
2
].
f16
,
k_reg
[
stage_id
*
2
+
1
].
f16
,
true
);
}
// Wait DS
flash
::
wait_lds_data_arrived
<
false
>
(
0
);
flash
::
raise_priority
();
{
int
min_tile_n
=
0
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
int
k_tile_id
=
stage_id
*
2
+
min_tile_n
;
s_reg
[
i
/
2
][
i
%
2
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
16
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
i
/
2
][
i
%
2
].
f32
);
}
}
{
int
min_tile_n
=
1
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
int
k_tile_id
=
stage_id
*
2
+
min_tile_n
;
s_reg
[
i
/
2
][
i
%
2
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
17
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
i
/
2
][
i
%
2
].
f32
);
}
}
flash
::
lower_priority
();
lds_offset
=
__builtin_amdgcn_readfirstlane
((
k_stage_id
*
16
*
256
+
warp_id
*
32
*
16
)
*
ELEMENT_BYTES
/
4
);
lds_offset_2
=
__builtin_amdgcn_readfirstlane
((
k_stage_id
*
16
*
256
+
warp_id
*
32
*
16
+
128
*
16
)
*
ELEMENT_BYTES
/
4
);
g_offset_s
=
0
*
ELEMENT_BYTES
/
4
+
warp_id
*
16
;
g_offset_s_2
=
0
*
ELEMENT_BYTES
/
4
+
warp_id
*
16
+
64
;
flash
::
wait_all_warp_arrived
();
inline_buffer_load_dwordx4_lds
(
k_lds
,
k_ptr
,
lds_offset
,
g_offset_s
,
g_offset_v
);
inline_buffer_load_dwordx4_lds
(
k_lds
,
k_ptr
,
lds_offset_2
,
g_offset_s_2
,
g_offset_v
);
flash
::
wait_buffer_data_arrived
<
true
>
(
K_LOAD_REQUESTS
);
k_stage_id
^=
1
;
stage_id
=
0
;
// K DS
{
int
k_lds_load_offset
=
k_lds_base
+
(
k_stage_id
*
16
*
256
)
*
ELEMENT_BYTES
;
int
k_lds_load_offset_2
=
k_lds_base
+
(
k_stage_id
*
16
*
256
+
16
*
64
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16
(
k_lds_load_offset
,
k_reg
[
stage_id
*
2
].
f16
,
k_reg
[
stage_id
*
2
+
1
].
f16
,
true
);
DS_READ_MATRIX_32X32_B16
(
k_lds_load_offset_2
,
k_reg
[
4
+
stage_id
*
2
].
f16
,
k_reg
[
4
+
stage_id
*
2
+
1
].
f16
,
true
);
}
// K DS PRE
stage_id
^=
1
;
{
int
k_lds_load_offset
=
k_lds_base
+
(
k_stage_id
*
16
*
256
+
16
*
128
)
*
ELEMENT_BYTES
;
int
k_lds_load_offset_2
=
k_lds_base
+
(
k_stage_id
*
16
*
256
+
16
*
192
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16
(
k_lds_load_offset
,
k_reg
[
stage_id
*
2
].
f16
,
k_reg
[
stage_id
*
2
+
1
].
f16
,
true
);
DS_READ_MATRIX_32X32_B16
(
k_lds_load_offset_2
,
k_reg
[
4
+
stage_id
*
2
].
f16
,
k_reg
[
4
+
stage_id
*
2
+
1
].
f16
,
true
);
}
// Wait DS
flash
::
wait_lds_data_arrived
<
false
>
(
6
);
flash
::
raise_priority
();
// MMAC
stage_id
^=
1
;
{
int
min_tile_n
=
0
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
int
k_tile_id
=
stage_id
*
2
+
min_tile_n
;
s_reg
[
i
/
2
][
i
%
2
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
8
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
i
/
2
][
i
%
2
].
f32
);
}
}
{
int
min_tile_n
=
1
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
int
k_tile_id
=
stage_id
*
2
+
min_tile_n
;
s_reg
[
i
/
2
][
i
%
2
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
9
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
i
/
2
][
i
%
2
].
f32
);
}
}
flash
::
lower_priority
();
flash
::
wait_lds_data_arrived
<
false
>
(
4
);
flash
::
raise_priority
();
{
int
min_tile_n
=
0
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
int
k_tile_id
=
4
+
stage_id
*
2
+
min_tile_n
;
s_reg
[
i
/
2
][
i
%
2
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
10
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
i
/
2
][
i
%
2
].
f32
);
}
}
{
int
min_tile_n
=
1
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
int
k_tile_id
=
4
+
stage_id
*
2
+
min_tile_n
;
s_reg
[
i
/
2
][
i
%
2
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
11
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
i
/
2
][
i
%
2
].
f32
);
}
}
flash
::
lower_priority
();
flash
::
wait_lds_data_arrived
<
false
>
(
2
);
flash
::
raise_priority
();
stage_id
^=
1
;
{
int
min_tile_n
=
0
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
int
k_tile_id
=
stage_id
*
2
+
min_tile_n
;
s_reg
[
i
/
2
][
i
%
2
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
12
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
i
/
2
][
i
%
2
].
f32
);
}
}
{
int
min_tile_n
=
1
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
int
k_tile_id
=
stage_id
*
2
+
min_tile_n
;
s_reg
[
i
/
2
][
i
%
2
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
13
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
i
/
2
][
i
%
2
].
f32
);
}
}
flash
::
lower_priority
();
flash
::
wait_lds_data_arrived
<
false
>
(
0
);
flash
::
raise_priority
();
{
int
min_tile_n
=
0
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
int
k_tile_id
=
4
+
stage_id
*
2
+
min_tile_n
;
s_reg
[
i
/
2
][
i
%
2
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
14
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
i
/
2
][
i
%
2
].
f32
);
}
}
{
int
min_tile_n
=
1
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
int
k_tile_id
=
4
+
stage_id
*
2
+
min_tile_n
;
s_reg
[
i
/
2
][
i
%
2
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
15
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
i
/
2
][
i
%
2
].
f32
);
}
}
flash
::
lower_priority
();
if
(
i
!=
0
){
lds_offset
=
__builtin_amdgcn_readfirstlane
((
k_stage_id
*
16
*
256
+
warp_id
*
32
*
16
)
*
ELEMENT_BYTES
/
4
);
index_block
=
index_topk
[
i
-
1
]
/
page_block_size
;
index_offset
=
index_topk
[
i
-
1
]
-
index_block
*
page_block_size
;
g_offset_v
=
((
4
-
(
tid
/
8
)
%
4
)
*
4
+
tid
%
4
*
4
)
%
16
+
(
index_block
*
batch_stride_k
+
index_offset
*
seqlen_k_stride
)
*
ELEMENT_BYTES
/
4
;
g_offset_s
=
512
*
ELEMENT_BYTES
/
4
+
warp_id
*
16
;
flash
::
wait_all_warp_arrived
();
if
(
warp_id
<
2
){
inline_buffer_load_dwordx4_lds
(
k_lds
,
k_ptr
,
lds_offset
,
g_offset_s
,
g_offset_v
);
flash
::
wait_buffer_data_arrived
<
true
>
(
1
);
}
else
{
flash
::
wait_buffer_data_arrived
<
true
>
(
0
);
}
}
else
{
flash
::
wait_buffer_data_arrived
<
true
>
(
0
);
}
k_stage_id
^=
1
;
stage_id
=
0
;
// K DS
{
int
k_lds_load_offset
=
k_lds_base
+
(
k_stage_id
*
16
*
256
)
*
ELEMENT_BYTES
;
int
k_lds_load_offset_2
=
k_lds_base
+
(
k_stage_id
*
16
*
256
+
16
*
64
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16
(
k_lds_load_offset
,
k_reg
[
stage_id
*
2
].
f16
,
k_reg
[
stage_id
*
2
+
1
].
f16
,
true
);
DS_READ_MATRIX_32X32_B16
(
k_lds_load_offset_2
,
k_reg
[
4
+
stage_id
*
2
].
f16
,
k_reg
[
4
+
stage_id
*
2
+
1
].
f16
,
true
);
}
// K DS PRE
stage_id
^=
1
;
{
int
k_lds_load_offset
=
k_lds_base
+
(
k_stage_id
*
16
*
256
+
16
*
128
)
*
ELEMENT_BYTES
;
int
k_lds_load_offset_2
=
k_lds_base
+
(
k_stage_id
*
16
*
256
+
16
*
192
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16
(
k_lds_load_offset
,
k_reg
[
stage_id
*
2
].
f16
,
k_reg
[
stage_id
*
2
+
1
].
f16
,
true
);
DS_READ_MATRIX_32X32_B16
(
k_lds_load_offset_2
,
k_reg
[
4
+
stage_id
*
2
].
f16
,
k_reg
[
4
+
stage_id
*
2
+
1
].
f16
,
true
);
}
// Wait DS
flash
::
wait_lds_data_arrived
<
false
>
(
6
);
flash
::
raise_priority
();
// MMAC
stage_id
^=
1
;
{
int
min_tile_n
=
0
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
int
k_tile_id
=
stage_id
*
2
+
min_tile_n
;
s_reg
[
i
/
2
][
i
%
2
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
0
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
i
/
2
][
i
%
2
].
f32
);
}
}
{
int
min_tile_n
=
1
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
int
k_tile_id
=
stage_id
*
2
+
min_tile_n
;
s_reg
[
i
/
2
][
i
%
2
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
1
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
i
/
2
][
i
%
2
].
f32
);
}
}
flash
::
lower_priority
();
flash
::
wait_lds_data_arrived
<
false
>
(
4
);
flash
::
raise_priority
();
{
int
min_tile_n
=
0
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
int
k_tile_id
=
4
+
stage_id
*
2
+
min_tile_n
;
s_reg
[
i
/
2
][
i
%
2
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
2
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
i
/
2
][
i
%
2
].
f32
);
}
}
{
int
min_tile_n
=
1
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
int
k_tile_id
=
4
+
stage_id
*
2
+
min_tile_n
;
s_reg
[
i
/
2
][
i
%
2
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
3
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
i
/
2
][
i
%
2
].
f32
);
}
}
flash
::
lower_priority
();
flash
::
wait_lds_data_arrived
<
false
>
(
2
);
flash
::
raise_priority
();
stage_id
^=
1
;
{
int
min_tile_n
=
0
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
int
k_tile_id
=
stage_id
*
2
+
min_tile_n
;
s_reg
[
i
/
2
][
i
%
2
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
4
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
i
/
2
][
i
%
2
].
f32
);
}
}
{
int
min_tile_n
=
1
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
int
k_tile_id
=
stage_id
*
2
+
min_tile_n
;
s_reg
[
i
/
2
][
i
%
2
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
5
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
i
/
2
][
i
%
2
].
f32
);
}
}
flash
::
lower_priority
();
flash
::
wait_lds_data_arrived
<
false
>
(
0
);
flash
::
raise_priority
();
{
int
min_tile_n
=
0
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
int
k_tile_id
=
4
+
stage_id
*
2
+
min_tile_n
;
s_reg
[
i
/
2
][
i
%
2
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
6
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
i
/
2
][
i
%
2
].
f32
);
}
}
{
int
min_tile_n
=
1
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
int
k_tile_id
=
4
+
stage_id
*
2
+
min_tile_n
;
s_reg
[
i
/
2
][
i
%
2
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
7
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
i
/
2
][
i
%
2
].
f32
);
}
}
flash
::
lower_priority
();
}
}
}
/**
 * Loads Q panels 1 .. (kHeadDim / kBlockK - 1) into LDS and copies every panel
 * (including panel 0) from LDS into q_reg, alternating between two LDS stages.
 *
 * Per loop iteration k_loop:
 *   1. issue the two 32-column half-panel loads for panel k_loop into stage q_stage_id;
 *   2. wait until at most Q_LOAD_REQUESTS buffer requests are outstanding;
 *   3. flip the stage and read panel (k_loop - 1) from LDS into
 *      q_reg[(k_loop - 1) * 2 + {0, 1}].
 * After the loop the last panel is drained the same way.
 *
 * NOTE(review): panel 0 is never loaded here; it is presumably issued beforehand
 * (e.g. by prefetch_q_to_lds_mls_ds_576_512) into stage 0 — confirm against the caller.
 *
 * NOTE(review): this function writes q_reg[0 .. 2 * (kHeadDim / kBlockK) - 1], which is
 * more entries than the bound spelled in the parameter declaration; the caller's array
 * must be large enough.
 *
 * @param qv_ptr             Base address used for panels k_loop >= 2 when !Is_FlashMLA.
 * @param q_ptr              Base address used otherwise.
 * @param q_lds              LDS staging buffer for Q.
 * @param q_reg              Destination registers (see note above).
 * @param warp_id            Warp index inside the block; selects a 16-row slice.
 * @param seqlen_qv_stride   Row stride (elements) that goes with qv_ptr.
 * @param __seqlen_q_stride  Row stride (elements) that goes with q_ptr.
 *
 * k_ptr, v_ptr, k_lds, v_lds, seqlen_k_stride, seqlen_v_stride, index_ptr,
 * batch_stride_k, batch_stride_v and max_seq_q_offset are accepted for signature
 * compatibility and are not read.
 */
template <int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int STAGES,
          typename Element, typename ElementAccum, bool Is_even_MN, bool Is_FlashMLA>
__forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512_nopage_64_new_q(
    vec4_uint qv_ptr,
    vec4_uint q_ptr,
    vec4_uint k_ptr,
    vec4_uint v_ptr,
    Element* q_lds,
    Element* k_lds,
    Element* v_lds,
    union_vec4_f16x2<Element> q_reg[(WARP_M * kBlockK) / (16 * 64) * (kHeadDim / kBlockK)],
    int warp_id,
    int seqlen_qv_stride,
    int __seqlen_q_stride,
    int seqlen_k_stride,
    int seqlen_v_stride,
    int* index_ptr,
    int batch_stride_k,
    int batch_stride_v,
    int max_seq_q_offset = 0)
{
    // Simplifying assumptions the hard-coded 16/32/64 factors below rely on.
    static_assert(kBlockK == 64, "To simplify, only kBlockK = 64 is supported!");
    static_assert(WARP_M == 16, "To simplify, only WARP_M = 16 is supported!");
    static_assert(WARP_N == 64, "To simplify, only WARP_N = 64 is supported!");

    constexpr int WARP_NUM = kBlockM / WARP_M;
    constexpr int Q_LDS_LOAD_NUM = (kBlockM * kBlockK) / (16 * 32);
    constexpr int Q_LOAD_REQUESTS = Q_LDS_LOAD_NUM / WARP_NUM;
    constexpr int ELEMENT_BYTES = sizeof(Element);

    // Starting offset of q_lds (narrowed to int exactly as before).
    int q_lds_base = reinterpret_cast<size_t>(q_lds);

    // MLS descriptors for the two 32-column halves of a 64-column Q panel.
    // Words 0-1 hold the address, word 2 the row stride, word 3 is kept at 0.
    vec4_uint q_srsrc;
    vec4_uint q_srsrc2;
    q_srsrc[2] = __seqlen_q_stride;
    q_srsrc2[2] = __seqlen_q_stride;
    q_srsrc[3] = 0;
    q_srsrc2[3] = 0;

    int q_stage_id = 0;
    if constexpr (STAGES == 2) {
        q_stage_id ^= 1;
    }

    for (int k_loop = 1; k_loop < (kHeadDim / kBlockK); ++k_loop) {
        {
            uint64_t q_base_addr;
            int seqlen_q_stride;
            int kloop_true;
            if constexpr (Is_FlashMLA) {
                q_base_addr = *(uint64_t*)&q_ptr;
                seqlen_q_stride = __seqlen_q_stride;
                kloop_true = k_loop;
            } else {
                // Panels 0-1 come from q_ptr, panels >= 2 from qv_ptr (re-based to 0).
                q_base_addr = (k_loop >= 2) ? *(uint64_t*)&qv_ptr : *(uint64_t*)&q_ptr;
                seqlen_q_stride = (k_loop >= 2) ? seqlen_qv_stride : __seqlen_q_stride;
                kloop_true = (k_loop >= 2) ? (k_loop - 2) : (k_loop);
            }
            // Both half-panel descriptors address the same rows, so both must carry
            // the stride of the tensor actually being read. Previously only q_srsrc
            // was updated and q_srsrc2 kept __seqlen_q_stride even for qv_ptr panels.
            q_srsrc[2] = seqlen_q_stride;
            q_srsrc2[2] = seqlen_q_stride;

            *(uint64_t*)&q_srsrc = VA_LIMIT_BITS(
                q_base_addr + (kloop_true * kBlockK + warp_id * 16 * seqlen_q_stride) * ELEMENT_BYTES);
            *(uint64_t*)&q_srsrc2 = VA_LIMIT_BITS(
                q_base_addr + (kloop_true * kBlockK + warp_id * 16 * seqlen_q_stride + 32) * ELEMENT_BYTES);
            // int nm_filter = inline_min_max<0,16>(16 * warp_id + 16 - max_seq_q_offset);
            // q_srsrc[3] = max_seq_q_offset % kBlockM == 0 ? 0: nm_filter << 8;

            int lds_offset = (q_stage_id * kBlockM * kBlockK + warp_id * 16 * 64) * ELEMENT_BYTES;
            int lds_offset2 = (q_stage_id * kBlockM * kBlockK + warp_id * 16 * 64 + 16 * 32) * ELEMENT_BYTES;

            flash::wait_all_warp_arrived();
            inline_matrix_load_32x16_b16_lds_trans<0, 1>(q_lds, q_srsrc, lds_offset, 0);
            inline_matrix_load_32x16_b16_lds_trans<0, 1>(q_lds, q_srsrc2, lds_offset2, 0);
        }

        // Asymmetric MLS instruction: leave the requests just issued in flight and
        // wait only for the previous panel.
        flash::wait_buffer_data_arrived<true>(Q_LOAD_REQUESTS);
        q_stage_id ^= 1;

        // Q DS: read the previous panel out of the other stage.
        {
            int q_lds_load_offset =
                q_lds_base + (q_stage_id * kBlockM * kBlockK + warp_id * 16 * 64) * ELEMENT_BYTES;
            int q_lds_load_offset2 =
                q_lds_base + (q_stage_id * kBlockM * kBlockK + warp_id * 16 * 64 + 16 * 32) * ELEMENT_BYTES;
            DS_READ_MATRIX_32X16_B16(q_lds_load_offset, q_reg[(k_loop - 1) * 2 + 0].f16, true);
            DS_READ_MATRIX_32X16_B16(q_lds_load_offset2, q_reg[(k_loop - 1) * 2 + 1].f16, true);
        }
        // Wait DS
        flash::wait_lds_data_arrived<false>(0);
    }

    // Wait for the last q_panel to come back.
    flash::wait_buffer_data_arrived<true>(0);

    // Q DS for the last panel.
    q_stage_id ^= 1;
    {
        int q_lds_load_offset = q_lds_base + (q_stage_id * kBlockM * kBlockK + warp_id * 16 * 64) * ELEMENT_BYTES;
        int q_lds_load_offset2 =
            q_lds_base + (q_stage_id * kBlockM * kBlockK + warp_id * 16 * 64 + 16 * 32) * ELEMENT_BYTES;
        DS_READ_MATRIX_32X16_B16(q_lds_load_offset, q_reg[(kHeadDim / kBlockK - 1) * 2 + 0].f16, true);
        DS_READ_MATRIX_32X16_B16(q_lds_load_offset2, q_reg[(kHeadDim / kBlockK - 1) * 2 + 1].f16, true);
    }
    // Wait DS
    flash::wait_lds_data_arrived<false>(0);
}
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/gfx938/mla_qk_gemm_utils_mls_ds.h
0 → 100644
View file @
a1eef562
#pragma once
#include "intrinsic_mls_ds.h"
/**
 * Issues the load of the first Q panel (panel 0, LDS stage 0) for this warp.
 *
 * Two descriptors are built, one per 32-column half of the panel, each starting at
 * row warp_id * 16 of Q, and two inline_matrix_load_32x16_b16_lds_trans requests
 * are issued after a block-wide wait. The function does not wait for the data.
 *
 * @param q_ptr            Holds the 64-bit base address of Q in its first two words.
 * @param q_lds            LDS destination buffer.
 * @param warp_id          Warp index inside the block.
 * @param seqlen_q_stride  Row stride of Q in elements.
 * @param max_seq_q_offset Not read (the masking code that used it is commented out).
 */
template <int kHeadDim, int kBlockM, int kBlockK, int WARP_M, typename Element, bool Is_even_MN>
__forceinline__ __device__ void prefetch_q_to_lds_mls_ds_576_512(vec4_uint q_ptr,
                                                                 Element* q_lds,
                                                                 int warp_id,
                                                                 int seqlen_q_stride,
                                                                 int max_seq_q_offset = 0)
{
    // Compile-time-known values.
    constexpr int kElemBytes = sizeof(Element);
    constexpr int kPanel = 0;  // only the first K panel is prefetched here
    constexpr int kStage = 0;  // ... into the first LDS stage

    // Element offsets shared by both half-panel requests.
    const int src_elems = kPanel * kBlockK + warp_id * 16 * seqlen_q_stride;
    const int dst_elems = kStage * kBlockM * kBlockK + warp_id * 16 * 64;

    // MLS descriptors: words 0-1 address, word 2 row stride, word 3 zero.
    vec4_uint desc_first_half;
    vec4_uint desc_second_half;
    desc_first_half[2] = seqlen_q_stride;
    desc_second_half[2] = seqlen_q_stride;
    desc_first_half[3] = 0;
    desc_second_half[3] = 0;

    const uint64_t q_base = *(uint64_t*)&q_ptr;
    *(uint64_t*)&desc_first_half = VA_LIMIT_BITS(q_base + (src_elems) * kElemBytes);
    *(uint64_t*)&desc_second_half = VA_LIMIT_BITS(q_base + (src_elems + 32) * kElemBytes);
    // if constexpr (true) {
    //     int nm_filter = inline_min_max<0,16>(16 * warp_id + 16 - max_seq_q_offset);
    //     q_srsrc[3] = max_seq_q_offset % kBlockM == 0 ? 0: nm_filter << 8;
    // }

    int lds_first_half = (dst_elems) * kElemBytes;
    int lds_second_half = (dst_elems + 16 * 32) * kElemBytes;

    // Q/K prefetch is launched after pvgemm completes; the barrier keeps a warp that
    // is still reading V from having its data overwritten by these Q/K writes.
    flash::wait_all_warp_arrived();
    inline_matrix_load_32x16_b16_lds_trans<0, 1>(q_lds, desc_first_half, lds_first_half, 0);
    inline_matrix_load_32x16_b16_lds_trans<0, 1>(q_lds, desc_second_half, lds_second_half, 0);
}
/**
 * Issues the load of this warp's slice of the first K tile into LDS.
 *
 * Warps are viewed as a (warp_id / 4, warp_id % 4) grid: the first coordinate picks
 * a 32-element column offset, the second a 16-row group of K. Word 3 of the
 * descriptor receives nm_filter << 8 unless max_seq_k_offset is a multiple of
 * kBlockN, in which case it is 0. The function does not wait for the data.
 *
 * @param k_ptr            Holds the 64-bit base address of K in its first two words.
 * @param k_lds            LDS destination buffer.
 * @param warp_id          Warp index inside the block.
 * @param seqlen_k_stride  Row stride of K in elements.
 * @param max_seq_k_offset Feeds the clamp to [0, 16] that produces nm_filter.
 */
template <int kHeadDim, int kBlockN, int kBlockK, int WARP_NUM, int WARP_N, typename Element, bool Is_even_MN>
__forceinline__ __device__ void prefetch_k_to_lds_mls_ds_576_512(vec4_uint k_ptr,
                                                                 Element* k_lds,
                                                                 int warp_id,
                                                                 int seqlen_k_stride,
                                                                 int max_seq_k_offset = 0)
{
    constexpr int kElemBytes = sizeof(Element);
    constexpr int kWarpsPerRowGroup = 4;

    const int warp_col_block = warp_id / kWarpsPerRowGroup;  // selects a 32-element column offset
    const int warp_row_group = warp_id % kWarpsPerRowGroup;  // selects a 16-row group

    // Only the first tile (n_loop = 0, k_loop = 0, stage 0) is handled, so the
    // corresponding terms of the general address formula vanish.
    const int nm_filter = inline_min_max<0, 16>(16 * warp_row_group + 16 - max_seq_k_offset);
    const int src_elems = warp_col_block * 32 + warp_row_group * 16 * seqlen_k_stride;

    // MLS descriptor: words 0-1 address, word 2 row stride, word 3 filter field.
    vec4_uint k_desc;
    k_desc[2] = seqlen_k_stride;
    *(uint64_t*)&k_desc = VA_LIMIT_BITS(*(uint64_t*)&k_ptr + (src_elems) * kElemBytes);
    if (max_seq_k_offset % kBlockN == 0x0) {
        k_desc[3] = 0;
    } else {
        k_desc[3] = nm_filter << 8;
    }

    int lds_offset = (warp_id * 32 * 16) * kElemBytes;

    flash::wait_all_warp_arrived();
    inline_matrix_load_32x16_b16_lds_trans<0, 0>(k_lds, k_desc, lds_offset, 0);
}
/**
 * Paged, index-driven K prefetch: each lane looks up one top-k token index,
 * translates it through block_table and issues one inline_buffer_load_dwordx4_lds.
 *
 * The token index is split as (index / 128, index % 128): the quotient selects a
 * block_table entry (scaled by batch_stride), the remainder a row inside that page
 * (scaled by seqlen_k_stride). Offsets are expressed in 4-byte units.
 *
 * @param k_ptr            Buffer descriptor passed through to the load.
 * @param k_lds            LDS destination buffer.
 * @param warp_id          Warp index inside the block.
 * @param seqlen_k_stride  Row stride of K in elements.
 * @param index_ptr        Top-k token indices, 64 consumed per n_loop.
 * @param block_table      Maps logical page number to physical page number.
 * @param batch_stride     Stride between pages in elements.
 * @param n_loop           Which group of 64 indices to fetch.
 * @param max_seq_k_offset Not read.
 */
template <int kHeadDim, int kBlockN, int kBlockK, int WARP_NUM, int WARP_N, typename Element, bool Is_even_MN>
__forceinline__ __device__ void prefetch_k_to_lds_mls_ds_576_512_buffer_load(vec4_uint k_ptr,
                                                                             Element* k_lds,
                                                                             int warp_id,
                                                                             int seqlen_k_stride,
                                                                             int* index_ptr,
                                                                             int* block_table,
                                                                             int batch_stride,
                                                                             int n_loop,
                                                                             int max_seq_k_offset = 0)
{
    constexpr int kElemBytes = sizeof(Element);
    constexpr int kWarpsPerRowGroup = 4;

    const int warp_col_block = warp_id / kWarpsPerRowGroup;
    const int warp_row_group = warp_id % kWarpsPerRowGroup;
    const int lane = threadIdx.x % 64;

    // Four consecutive lanes share one token; each warp_row_group owns 16 tokens.
    const int token = index_ptr[n_loop * 64 + warp_row_group * 16 + (lane / 4)];

    // readfirstlane: the LDS offset is the same for every lane of the warp.
    int lds_offset = __builtin_amdgcn_readfirstlane((warp_id * 32 * 16) * kElemBytes / 4);

    // Per-lane rotation inside a 16-dword window (kept exactly as in the original formula).
    const int lane_rotation = ((4 - (lane / 8) % 4) * 4 + lane % 4 * 4) % 16;
    const int page_dwords = block_table[token / 128] * batch_stride * kElemBytes / 4;
    const int row_dwords = (token % 128) * seqlen_k_stride * kElemBytes / 4;
    int g_offset_v = warp_col_block * 16 + lane_rotation + page_dwords + row_dwords;

    flash::wait_all_warp_arrived();
    inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, 0, g_offset_v);
}
/**
 * Non-paged, index-driven K prefetch: each lane looks up one top-k token index and
 * issues one inline_buffer_load_dwordx4_lds for the matching K row.
 *
 * The group offset n_loop * 64 is wrapped with `& 1023` before the per-warp and
 * per-lane offsets are added.
 * NOTE(review): the wrap suggests index_ptr addresses a 1024-entry window — confirm
 * against the caller.
 *
 * @param k_ptr            Buffer descriptor passed through to the load.
 * @param k_lds            LDS destination buffer.
 * @param warp_id          Warp index inside the block.
 * @param seqlen_k_stride  Row stride of K in elements.
 * @param index_ptr        Top-k token indices, 64 consumed per n_loop.
 * @param batch_stride     Not read.
 * @param n_loop           Which group of 64 indices to fetch.
 * @param max_seq_k_offset Not read.
 */
template <int kHeadDim, int kBlockN, int kBlockK, int WARP_NUM, int WARP_N, typename Element, bool Is_even_MN>
__forceinline__ __device__ void prefetch_k_to_lds_mls_ds_576_512_buffer_load_nopage(vec4_uint k_ptr,
                                                                                    Element* k_lds,
                                                                                    int warp_id,
                                                                                    int seqlen_k_stride,
                                                                                    int* index_ptr,
                                                                                    int batch_stride,
                                                                                    int n_loop,
                                                                                    int max_seq_k_offset = 0)
{
    constexpr int ELEMENT_BYTES = sizeof(Element);
    constexpr int WARP_NUM_N = 4;

    int warp_id_m = warp_id / WARP_NUM_N;
    int warp_id_n = warp_id % WARP_NUM_N;
    int tid = threadIdx.x % 64;

    // `+` binds tighter than `&`, so the mask must be parenthesised. Without the
    // parentheses the expression was (n_loop * 64) & (1023 + warp_id_n * 16 + tid / 4):
    // the low six bits of n_loop * 64 are zero, so the per-warp / per-lane offset
    // was lost and lanes read the wrong index entries.
    int index_topk = index_ptr[((n_loop * 64) & 1023) + warp_id_n * 16 + (tid / 4)];

    // readfirstlane: the LDS offset is the same for every lane of the warp.
    int lds_offset = __builtin_amdgcn_readfirstlane((warp_id * 32 * 16) * ELEMENT_BYTES / 4);

    // Per-lane offset in 4-byte units: column block, lane rotation inside a
    // 16-dword window, then the start of the selected K row.
    int g_offset_v = warp_id_m * 16 + ((4 - (tid / 8) % 4) * 4 + tid % 4 * 4) % 16 +
                     index_topk * seqlen_k_stride * ELEMENT_BYTES / 4;

    flash::wait_all_warp_arrived();
    inline_buffer_load_dwordx4_lds(k_lds, k_ptr, lds_offset, 0, g_offset_v);
}
\ No newline at end of file
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/gfx938/mla_softmax_gfx938.h
0 → 100644
View file @
a1eef562
#pragma once
#include "philox.cuh"
#include "fwd/utils.h"
using
namespace
flash
;
/**
 * Cross-lane reduction over a 64-lane wavefront using two xor-shuffle steps
 * (lane distances 32 and 16).
 *
 * When DataType is union_vec2_fp32 both f32 slots are reduced; for any other
 * DataType only f32[0] is. SumOp<float> and MaxOp<float> are the handled
 * operators; for any other Operator no branch assigns the returned value.
 */
template <int THREADS, typename DataType = union_vec2_fp32>
struct PrefillMlaAllreduce {
    static_assert(THREADS == 64);

    /**
     * @param x   This lane's partial value(s).
     * @param op  Binary functor; invoked only for the max reduction.
     * @return    The value(s) combined with lanes at xor distance 32 and 16.
     */
    template <typename Operator>
    static __device__ inline DataType run(DataType x, Operator& op)
    {
        constexpr bool kTwoSlots = std::is_same<DataType, union_vec2_fp32>::value;
        constexpr bool kIsSum = std::is_same<Operator, SumOp<float>>::value;
        constexpr bool kIsMax = std::is_same<Operator, MaxOp<float>>::value;

        DataType reduced;
        if constexpr (kTwoSlots && kIsSum) {
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
            // Packed path: shuffle each slot, then add both slots with one packed add.
            reduced.f32[0] = __shfl_xor_tmp(x.f32[0], 32);
            reduced.f32[1] = __shfl_xor_tmp(x.f32[1], 32);
            x.u64 = __builtin_hcu_pk_add_f32(x.u64, reduced.u64);
            reduced.f32[0] = __shfl_xor_tmp(x.f32[0], 16);
            reduced.f32[1] = __shfl_xor_tmp(x.f32[1], 16);
            reduced.u64 = __builtin_hcu_pk_add_f32(reduced.u64, x.u64);
#else
            x.f32[0] = x.f32[0] + __shfl_xor_tmp(x.f32[0], 32);
            x.f32[1] = x.f32[1] + __shfl_xor_tmp(x.f32[1], 32);
            reduced.f32[0] = x.f32[0] + __shfl_xor_tmp(x.f32[0], 16);
            reduced.f32[1] = x.f32[1] + __shfl_xor_tmp(x.f32[1], 16);
#endif
        } else if constexpr (kTwoSlots && kIsMax) {
            x.f32[0] = op(x.f32[0], __shfl_xor_tmp(x.f32[0], 32));
            x.f32[1] = op(x.f32[1], __shfl_xor_tmp(x.f32[1], 32));
            reduced.f32[0] = op(x.f32[0], __shfl_xor_tmp(x.f32[0], 16));
            reduced.f32[1] = op(x.f32[1], __shfl_xor_tmp(x.f32[1], 16));
        } else if constexpr (!kTwoSlots && kIsSum) {
            // Single-slot data type: only f32[0] takes part.
            x.f32[0] = x.f32[0] + __shfl_xor_tmp(x.f32[0], 32);
            reduced.f32[0] = x.f32[0] + __shfl_xor_tmp(x.f32[0], 16);
        } else if constexpr (!kTwoSlots && kIsMax) {
            x.f32[0] = op(x.f32[0], __shfl_xor_tmp(x.f32[0], 32));
            reduced.f32[0] = op(x.f32[0], __shfl_xor_tmp(x.f32[0], 16));
        }
        return reduced;
    }
};
/**
 * Per-thread reduction with `op` over the score fragments this thread holds.
 *
 * zero_init == true : the result is written to `summary`, seeded with -INFINITY.
 * zero_init == false: the result is written to `summary_cur`, seeded with the
 *                     value already in `summary` (which is left untouched).
 *
 * Results land in entry [m * 2].f32[sub] of the target array, where m indexes the
 * WARP_M / (16 * M_MMAC_COUNT) row blocks and sub runs over M_MMAC_COUNT.
 *
 * @param tensor       Score fragments, indexed [row block + col block * stride][mmac id].
 * @param summary      Output (zero_init) or previous value (otherwise).
 * @param op           Binary functor applied as op(accumulator, element).
 * @param summary_cur  Output when zero_init is false; not read otherwise.
 */
template <bool zero_init = true, typename Operator, int OpType, typename DataType0, typename DataType1, int WARP_M,
          int WARP_N, int M_MMAC_COUNT = 2>
__device__ inline void prefill_mla_thread_reduce_max(
    const DataType0 tensor[(WARP_M / (16 * M_MMAC_COUNT)) * (WARP_N / 32)][2 * M_MMAC_COUNT],
    DataType1* summary,
    Operator& op,
    DataType1* summary_cur = nullptr)
{
    // Both modes run the same loops; they differ only in target and seed.
    DataType1* const acc = zero_init ? summary : summary_cur;

#pragma unroll
    for (int row_blk = 0; row_blk < (WARP_M / (16 * M_MMAC_COUNT)); ++row_blk) {
#pragma unroll
        for (int sub_m = 0; sub_m < M_MMAC_COUNT; ++sub_m) {
            if constexpr (zero_init) {
                acc[row_blk * 2].f32[sub_m] = -INFINITY;
            } else {
                acc[row_blk * 2].f32[sub_m] = summary[row_blk * 2].f32[sub_m];
            }
            // OpType: 0 is the sum operator, 1 is the max operator.
#pragma unroll
            for (int col_blk = 0; col_blk < (WARP_N / 32); ++col_blk) {
#pragma unroll
                for (int sub_n = 0; sub_n < 2; ++sub_n) {
                    for (int elem = 0; elem < 4; ++elem) {
                        // The mmac min tile is 16x16 and a warp has 64 threads.
                        if constexpr (M_MMAC_COUNT == 2) {
                            acc[row_blk * 2].f32[sub_m] =
                                op(acc[row_blk * 2].f32[sub_m],
                                   tensor[row_blk + col_blk * (WARP_M / 32)][sub_n * 2 + sub_m].f32[elem]);
                        } else {
                            acc[row_blk * 2].f32[sub_m] =
                                op(acc[row_blk * 2].f32[sub_m],
                                   tensor[row_blk + col_blk * (WARP_M / 16)][sub_n].f32[elem]);
                        }
                    }
                }
            }
        }
    }
}
/**
 * Per-thread sum over the score fragments this thread holds.
 *
 * zero_init == true : the result is written to `summary`, seeded with zero.
 * zero_init == false: the result is written to `summary_cur`, seeded with the
 *                     value already in `summary` (which is left untouched).
 *
 * On gfx936 / gfx938 / gfx946 / gfx92a builds the accumulation uses
 * __builtin_hcu_pk_add_f32 when M_MMAC_COUNT == 2 and a plain add into f32[0]
 * otherwise; `op` is not invoked there. On other targets every element is
 * folded in with `op`.
 *
 * @param tensor       Score fragments, indexed [row block + col block * stride][mmac id].
 * @param summary      Output (zero_init) or previous value (otherwise).
 * @param op           Binary functor; used only on the non-packed path.
 * @param summary_cur  Output when zero_init is false; not read otherwise.
 */
template <bool zero_init = true, typename Operator, int OpType, typename DataType0, typename DataType1, int WARP_M,
          int WARP_N, int M_MMAC_COUNT = 2>
__device__ inline void prefill_mla_thread_reduce_sum(
    const DataType0 tensor[(WARP_M / (16 * M_MMAC_COUNT)) * (WARP_N / 32)][2 * M_MMAC_COUNT],
    DataType1* summary,
    Operator& op,
    DataType1* summary_cur = nullptr)
{
    // Both modes run the same loops; they differ only in target and seed.
    DataType1* const acc = zero_init ? summary : summary_cur;

#pragma unroll
    for (int row_blk = 0; row_blk < (WARP_M / (16 * M_MMAC_COUNT)); ++row_blk) {
        // On gfx936 and newer architectures v_pk_add_f32 can be used.
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
        if constexpr (M_MMAC_COUNT == 2) {
            if constexpr (zero_init) {
                acc[row_blk * 2].u64 = 0x0;
            } else {
                acc[row_blk * 2].u64 = summary[row_blk * 2].u64;
            }
        } else {
            if constexpr (zero_init) {
                acc[row_blk * 2].f32[0] = 0x0;
            } else {
                acc[row_blk * 2].f32[0] = summary[row_blk * 2].f32[0];
            }
        }
#pragma unroll
        for (int col_blk = 0; col_blk < (WARP_N / 32); ++col_blk) {
#pragma unroll
            for (int sub_n = 0; sub_n < 2; ++sub_n) {
                for (int elem = 0; elem < 4; ++elem) {
                    // The mmac min tile is 16x16 and a warp has 64 threads.
                    if constexpr (M_MMAC_COUNT == 2) {
                        // One packed add updates both f32 slots.
                        __float2 additem_pair = {
                            tensor[row_blk + col_blk * (WARP_M / 32)][sub_n * 2].f32[elem],
                            tensor[row_blk + col_blk * (WARP_M / 32)][sub_n * 2 + 1].f32[elem]};
                        acc[row_blk * 2].u64 = __builtin_hcu_pk_add_f32(acc[row_blk * 2].u64, additem_pair);
                    } else {
                        acc[row_blk * 2].f32[0] =
                            acc[row_blk * 2].f32[0] + tensor[row_blk + col_blk * (WARP_M / 16)][sub_n].f32[elem];
                    }
                }
            }
        }
#else
#pragma unroll
        for (int sub_m = 0; sub_m < M_MMAC_COUNT; ++sub_m) {
            if constexpr (zero_init) {
                acc[row_blk * 2].f32[sub_m] = 0;
            } else {
                acc[row_blk * 2].f32[sub_m] = summary[row_blk * 2].f32[sub_m];
            }
            // OpType: 0 is the sum operator, 1 is the max operator.
#pragma unroll
            for (int col_blk = 0; col_blk < (WARP_N / 32); ++col_blk) {
#pragma unroll
                for (int sub_n = 0; sub_n < 2; ++sub_n) {
                    for (int elem = 0; elem < 4; ++elem) {
                        // The mmac min tile is 16x16 and a warp has 64 threads.
                        if constexpr (M_MMAC_COUNT == 2) {
                            acc[row_blk * 2].f32[sub_m] =
                                op(acc[row_blk * 2].f32[sub_m],
                                   tensor[row_blk + col_blk * (WARP_M / 32)][sub_n * 2 + sub_m].f32[elem]);
                        } else {
                            acc[row_blk * 2].f32[sub_m] =
                                op(acc[row_blk * 2].f32[sub_m],
                                   tensor[row_blk + col_blk * (WARP_M / 16)][sub_n].f32[elem]);
                        }
                    }
                }
            }
        }
#endif
    }
}
/**
 * Runs the cross-lane PrefillMlaAllreduce on every per-row-block summary entry.
 *
 * The per-thread reduction functions in this file store the result for row block
 * m in entry [m * 2] of the summary array, so the same stride is used here.
 * Indexing with [mi] agreed with them only when the loop ran exactly once
 * (mi == 0); for larger WARP_M it would reduce an entry that is never written and
 * skip one that is.
 * NOTE(review): any out-of-file caller that passes a densely packed array with
 * more than one row block should be checked against this stride.
 *
 * @param dst  Receives the reduced entries (may alias src).
 * @param src  Per-thread partial results.
 * @param op   Operator forwarded to PrefillMlaAllreduce::run.
 */
template <typename Operator, typename DataType, int WARP_M, int M_MMAC_COUNT = 2>
__device__ inline void prefill_mla_quad_allreduce_(DataType* dst, DataType* src, Operator& op)
{
#pragma unroll
    for (int mi = 0; mi < (WARP_M / (16 * M_MMAC_COUNT)); mi++) {
        dst[mi * 2] = PrefillMlaAllreduce<64, DataType>::run(src[mi * 2], op);
    }
}
/**
 * Full reduction of the score fragments: a per-thread pass followed by the
 * cross-lane all-reduce, dispatched at compile time on OpType.
 *
 * OpType == 0 selects the sum pipeline, OpType == 1 the max pipeline; any other
 * value does nothing. With zero_init the result is produced in `summary`;
 * otherwise it is produced in `summary_cur`, starting from `summary`.
 *
 * @param tensor       Score fragments held by this thread.
 * @param summary      Output (zero_init) or previous value (otherwise).
 * @param op           Operator instance forwarded to both stages.
 * @param summary_cur  Output when zero_init is false.
 */
template <bool zero_init = true, typename Operator, int OpType, typename DataType0, typename DataType1, int WARP_M,
          int WARP_N, int M_MMAC_COUNT = 2>
__device__ inline void prefill_mla_reduce_(
    const DataType0 tensor[(WARP_M / (16 * M_MMAC_COUNT)) * (WARP_N / 32)][2 * M_MMAC_COUNT],
    DataType1* summary,
    Operator& op,
    DataType1* summary_cur = nullptr)
{
    // Buffer that holds the per-thread partials and is then reduced in place.
    DataType1* const reduced = zero_init ? summary : summary_cur;
    // With zero_init the per-thread stage is called without a current buffer.
    DataType1* const cur_arg = zero_init ? static_cast<DataType1*>(nullptr) : summary_cur;

    if constexpr (OpType == 0) {
        // sum
        prefill_mla_thread_reduce_sum<zero_init, Operator, 0, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(
            tensor, summary, op, cur_arg);
        prefill_mla_quad_allreduce_<Operator, DataType1, WARP_M, M_MMAC_COUNT>(reduced, reduced, op);
    } else if constexpr (OpType == 1) {
        // max
        prefill_mla_thread_reduce_max<zero_init, Operator, 1, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(
            tensor, summary, op, cur_arg);
        prefill_mla_quad_allreduce_<Operator, DataType1, WARP_M, M_MMAC_COUNT>(reduced, reduced, op);
    }
}
/**
 * Row-wise maximum of the score fragments, reduced across lanes.
 *
 * zero_init == true : `max` receives the current maximum; max_cur stays nullptr.
 * zero_init == false: `max` holds the previous maximum and `max_cur` (non-null)
 *                     receives the maximum of that and the current scores.
 */
template <bool zero_init = true, typename DataType0, typename DataType1, int WARP_M, int WARP_N,
          int M_MMAC_COUNT = 2>
__device__ inline void reduce_max(
    const DataType0 tensor[(WARP_M / (16 * M_MMAC_COUNT)) * (WARP_N / 32)][2 * M_MMAC_COUNT],
    DataType1* max,
    DataType1* max_cur = nullptr)
{
    MaxOp<float> max_op;
    // With zero_init the reduction is invoked without a current buffer.
    DataType1* const cur_arg = zero_init ? static_cast<DataType1*>(nullptr) : max_cur;
    prefill_mla_reduce_<zero_init, MaxOp<float>, 1, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(
        tensor, max, max_op, cur_arg);
}
/**
 * Row-wise sum of the score fragments, reduced across lanes.
 *
 * zero_init == true : `sum` receives the sum of the current scores.
 * zero_init == false: `sum` holds the previous sum and `sum_cur` receives
 *                     previous sum + current scores.
 */
template <bool zero_init = true, typename DataType0, typename DataType1, int WARP_M, int WARP_N,
          int M_MMAC_COUNT = 2>
__device__ inline void reduce_sum(
    DataType0 tensor[(WARP_M / (16 * M_MMAC_COUNT)) * (WARP_N / 32)][2 * M_MMAC_COUNT],
    DataType1* sum,
    DataType1* sum_cur = nullptr)
{
    SumOp<float> sum_op;
    // With zero_init the reduction is invoked without a current buffer.
    DataType1* const cur_arg = zero_init ? static_cast<DataType1*>(nullptr) : sum_cur;
    prefill_mla_reduce_<zero_init, SumOp<float>, 0, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(
        tensor, sum, sum_op, cur_arg);
}
// Apply the exp to all the elements: tensor = exp2(tensor * scale - max_scaled), in place.
//
//   tensor : score fragments, indexed [mi + ni * M_TILES][fragment].
//   max    : per-M-tile statistics, one entry per mi with lanes .f32[min_tile_m].
//   scale  : multiplier applied to every score; also applied to the max when Scale_max is true
//            (otherwise the max is multiplied by M_LOG2E).
//
// The statistics entry is addressed as max[mi]. The previous max[mi * 2] only matched for
// mi == 0 and indexed past the (WARP_M / (16 * M_MMAC_COUNT))-element array otherwise.
template <bool Scale_max = true,
          typename DataType0,
          typename DataType1,
          int WARP_M,
          int WARP_N,
          int M_MMAC_COUNT = 2>
inline __device__ void scale_apply_exp2(
    DataType0 tensor[(WARP_M / (16 * M_MMAC_COUNT)) * (WARP_N / 32)][2 * M_MMAC_COUNT],
    const DataType1* max,
    const float scale)
{
#pragma unroll
    for (int mi = 0; mi < (WARP_M / (16 * M_MMAC_COUNT)); ++mi) {
        // If max is -inf, then all elements must have been -inf (possibly due to masking).
        // We don't want (-inf - (-inf)) since that would give NaN.
        // If we don't have float around M_LOG2E the multiplication is done in fp64.
#pragma unroll
        for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
            const float row_max = max[mi].f32[min_tile_m];
            const float max_scaled =
                (row_max == -INFINITY) ? 0.f : (row_max * (Scale_max ? scale : float(M_LOG2E)));
            __float2 neg_max_scaled_pair = {-max_scaled, -max_scaled};
            __float2 scale_pair = {scale, scale};
#pragma unroll
            for (int ni = 0; ni < (WARP_N / 32); ++ni) {
                // Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
                // max * log_2(e)). This allows the compiler to use the ffma
                // instruction instead of fadd and fmul separately.
                // min tile is 32*32, mmac size is 16x16x16, so min_tile_n=32/16, min_tile_m=32/16
                for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                    int mmac_id;
                    if constexpr (M_MMAC_COUNT == 2) {
                        mmac_id = min_tile_n * 2 + min_tile_m;
                    } else {
                        mmac_id = min_tile_n;
                    }
                    int qk_tile_id = mi + ni * (WARP_M / (16 * M_MMAC_COUNT));
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
                    // Packed path: two fused multiply-adds on float pairs, then exp2 per element.
                    for (int vec_idx = 0; vec_idx < 2; ++vec_idx) {
                        tensor[qk_tile_id][mmac_id].u64[vec_idx] = __builtin_hcu_pk_fma_f32(
                            tensor[qk_tile_id][mmac_id].u64[vec_idx], scale_pair, neg_max_scaled_pair);
                    }
                    for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
                        tensor[qk_tile_id][mmac_id].f32[vec_idx] =
                            __llvm_exp2_f32(tensor[qk_tile_id][mmac_id].f32[vec_idx]);
                    }
#else
                    for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
                        tensor[qk_tile_id][mmac_id].f32[vec_idx] = __llvm_exp2_f32(
                            tensor[qk_tile_id][mmac_id].f32[vec_idx] * scale - max_scaled);
                    }
#endif
                }
            }
        }
    }
}
// One online-softmax step for a tile of scores, including rescaling of the output accumulator.
//
//   Is_first == true : scores_max <- max(scores); scores <- exp2(scores * scale - max * scale);
//                      scores_sum <- sum(scores).
//   Is_first == false: computes the new max into a local buffer, rescales scores_sum and acc_o by
//                      exp2((old_max - new_max) * scale), exponentiates scores against the new max,
//                      adds the tile's sum to scores_sum and finally stores the new max.
//   Check_inf        : when true a new max of -inf is replaced by 0 before the rescale factor is
//                      computed (avoids -inf - (-inf)).
//
// scores_max / scores_sum hold one entry per M tile (mi), with lanes .f32[min_tile_m]. Every
// access uses index [mi]. The earlier code rescaled at [mi * 2] but accumulated at [mi], which
// only agreed for mi == 0 and read the local scores_max_cur out of bounds otherwise.
template <bool Is_first,
          bool Check_inf = false,
          typename DataType0,
          typename DataType1,
          int K /*head_dim_v*/,
          int kBlockK,
          int WARP_M,
          int WARP_N,
          int M_MMAC_COUNT = 2>
inline __device__ void prefill_mla_softmax_rescale_o(
    DataType0 scores[(WARP_N / 32) * (WARP_M / (16 * M_MMAC_COUNT))][2 * M_MMAC_COUNT],
    DataType1* scores_max,
    DataType1* scores_sum,
    DataType0 acc_o[(K / kBlockK) * (WARP_M / (16 * M_MMAC_COUNT)) * (kBlockK / 32)][2 * M_MMAC_COUNT],
    float softmax_scale_log2)
{
    constexpr int kMTiles = WARP_M / (16 * M_MMAC_COUNT);

    if constexpr (Is_first) {
        reduce_max</*zero_init=*/true, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(scores, scores_max);
        scale_apply_exp2<true, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(scores, scores_max, softmax_scale_log2);
        reduce_sum<true, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(scores, scores_sum);
    } else {
        DataType1 scores_max_cur[kMTiles];
        reduce_max</*zero_init=*/false, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(scores, scores_max, scores_max_cur);
        // scores_max is prev scores max
        for (int mi = 0; mi < kMTiles; ++mi) {
            // If max is -inf, then all elements must have been -inf (possibly due to masking).
            // We don't want (-inf - (-inf)) since that would give NaN.
#pragma unroll
            for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
                const float cur_max = scores_max_cur[mi].f32[min_tile_m];
                float scores_max_cur_reg = !Check_inf ? cur_max : (cur_max == -INFINITY ? 0.0f : cur_max);
                // Correction factor that moves previously accumulated values onto the new max.
                float scores_scale = __llvm_exp2_f32(
                    (scores_max[mi].f32[min_tile_m] - scores_max_cur_reg) * softmax_scale_log2);
                scores_sum[mi].f32[min_tile_m] *= scores_scale;
                __float2 scores_scale_pair = {scores_scale, scores_scale};
#pragma unroll
                for (int pv_n_loop = 0; pv_n_loop < (K / kBlockK); pv_n_loop++) {
#pragma unroll
                    for (int ni = 0; ni < (kBlockK / 32); ++ni) {
                        // min tile is 32*32, mmac size is 16x16x16, so min_tile_n=32/16, min_tile_m=32/16
                        for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                            int pv_tile_id = pv_n_loop * kMTiles * (kBlockK / 32) + mi + ni * kMTiles;
                            int mmac_id;
                            if constexpr (M_MMAC_COUNT == 2) {
                                mmac_id = min_tile_n * 2 + min_tile_m;
                            } else {
                                mmac_id = min_tile_n;
                            }
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
#pragma unroll
                            for (int vec_idx = 0; vec_idx < 2; ++vec_idx) {
                                acc_o[pv_tile_id][mmac_id].u64[vec_idx] = __builtin_hcu_pk_mul_f32(
                                    acc_o[pv_tile_id][mmac_id].u64[vec_idx], scores_scale_pair);
                            }
#else
#pragma unroll
                            for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
                                acc_o[pv_tile_id][mmac_id].f32[vec_idx] *= scores_scale;
                            }
#endif
                        }
                    }
                }
            }
        }

        scale_apply_exp2<true, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(scores, scores_max_cur, softmax_scale_log2);

        // Sum of the current tile only; zeroed before the reduction.
        DataType1 scores_sum_cur[kMTiles];
        for (int mi = 0; mi < kMTiles; ++mi) {
            if constexpr (M_MMAC_COUNT == 2) {
                scores_sum_cur[mi].u64 = 0x0;
            } else {
                scores_sum_cur[mi].f32[0] = 0x0;
            }
        }
        reduce_sum<true, DataType0, DataType1, WARP_M, WARP_N, M_MMAC_COUNT>(scores, scores_sum_cur);

        for (int mi = 0; mi < kMTiles; ++mi) {
            // running_sum += tile_sum
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
            if constexpr (M_MMAC_COUNT == 2) {
                scores_sum[mi].u64 = __builtin_hcu_pk_add_f32(scores_sum[mi].u64, scores_sum_cur[mi].u64);
            } else {
                scores_sum[mi].f32[0] = scores_sum[mi].f32[0] + scores_sum_cur[mi].f32[0];
            }
#else // for perf-model, add listed below will be optimized as v_fmac_f32, leading to incorrect results
            if constexpr (M_MMAC_COUNT == 2) {
                scores_sum[mi].f32[0] += scores_sum_cur[mi].f32[0];
                scores_sum[mi].f32[1] += scores_sum_cur[mi].f32[1];
            } else {
                scores_sum[mi].f32[0] += scores_sum_cur[mi].f32[0];
            }
#endif
            // running_max = new_max
#if defined(USE_V_MOV_B64) && (defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__))
            if constexpr (M_MMAC_COUNT == 2) {
                inlineasm_fa_v_mov_b64(scores_max[mi].u64, scores_max_cur[mi].u64);
            } else {
                scores_max[mi].f32[0] = scores_max_cur[mi].f32[0];
            }
#else
            if constexpr (M_MMAC_COUNT == 2) {
                scores_max[mi].f32[0] = scores_max_cur[mi].f32[0];
                scores_max[mi].f32[1] = scores_max_cur[mi].f32[1];
            } else {
                scores_max[mi].f32[0] = scores_max_cur[mi].f32[0];
            }
#endif
        }
    }
}
// #define USE_CVT_PKRTZ_FP16_FP32
// Down-converts the fp32 fragments in s_reg into the packed 16-bit fragments of p_reg,
// element by element and at identical [tile][fragment] positions.
//
// Both arrays are indexed [n_idx * M_TILES + m_idx][min_tile_n * M_MMAC_COUNT + min_tile_m] with
// M_TILES = WARP_M / (16 * M_MMAC_COUNT). For M_MMAC_COUNT == 2 and == 1 this reproduces the
// indices previously spelled out per branch. The non-gfx fallback used to hard-code the
// M_MMAC_COUNT == 2 layout and therefore addressed fragment [2]/[3] of a 2-wide row when
// M_MMAC_COUNT == 1; it now follows the same formula.
//
// ElementAccum is not used by the body; it is kept for interface compatibility.
template <int WARP_M, int WARP_N, typename Element, typename ElementAccum, int M_MMAC_COUNT = 2>
inline __device__ void prefill_mla_convert_pk_type(
    union_vec2_f16x2<Element> p_reg[(WARP_M / (16 * M_MMAC_COUNT)) * (WARP_N / 32)][2 * M_MMAC_COUNT],
    union_vec4_fp32 s_reg[(WARP_M / (16 * M_MMAC_COUNT)) * (WARP_N / 32)][2 * M_MMAC_COUNT])
{
    constexpr int kMTiles = WARP_M / (16 * M_MMAC_COUNT);
#pragma unroll
    for (int n_idx = 0; n_idx < (WARP_N / 32); ++n_idx) {
#pragma unroll
        for (int m_idx = 0; m_idx < kMTiles; ++m_idx) {
            const int tile = n_idx * kMTiles + m_idx;
#pragma unroll
            for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
                for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
                    for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                        const int frag = min_tile_n * M_MMAC_COUNT + min_tile_m;
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
                        // Packed conversion: one float pair -> one 16-bit pair.
                        p_reg[tile][frag].f16x2[min_tile_k] =
                            DownCastPair<float, Element>(s_reg[tile][frag].f32x2[min_tile_k]);
#else
                        // Scalar fallback: convert the two floats of the pair individually.
                        p_reg[tile][frag].f16[min_tile_k * 2 + 0] =
                            DownCast<float, Element, false>(s_reg[tile][frag].f32[min_tile_k * 2 + 0]);
                        p_reg[tile][frag].f16[min_tile_k * 2 + 1] =
                            DownCast<float, Element, false>(s_reg[tile][frag].f32[min_tile_k * 2 + 1]);
#endif
                    }
                }
            }
        }
    }
}
// Sequence-length mask: sets every score whose key column is >= max_seqlen_k to -INFINITY.
//
// Column owned by a lane for element vec_idx of fragment (ni, min_tile_n):
//   col_idx_offset_ + (lane_id >> 4) * 4 + ni * 32 + min_tile_n * 16 + vec_idx,  lane_id = threadIdx.x & 63.
//
// Fix: the M_MMAC_COUNT == 2 branch used to write -INFINITY unconditionally (its guarding
// `if (col_idx >= max_seqlen_k)` had been commented out), wiping the whole tile. Both layouts
// now apply the same predicate. Indices: [mi + ni * M_TILES][min_tile_n * M_MMAC_COUNT + min_tile_m].
template <typename DataType, int WARP_M, int WARP_N, int M_MMAC_COUNT = 2>
inline __device__ void prefill_mla_apply_mask_gfx938(
    DataType tensor[(WARP_M / (16 * M_MMAC_COUNT)) * (WARP_N / 32)][2 * M_MMAC_COUNT],
    const int max_seqlen_k,
    const int col_idx_offset_ = 0)
{
    constexpr int kMTiles = WARP_M / (16 * M_MMAC_COUNT);
    const int lane_id = threadIdx.x & 63;
    const int col_idx_offset = col_idx_offset_ + (lane_id >> 4) * 4;
#pragma unroll
    for (int ni = 0; ni < (WARP_N / 32); ++ni) {
#pragma unroll
        for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
            const int col_idx_base = col_idx_offset + ni * 32 + min_tile_n * 16;
#pragma unroll
            for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
                const int col_idx = col_idx_base + vec_idx;
                // The predicate depends only on the column, so evaluate it once per column.
                const bool masked = col_idx >= max_seqlen_k;
#pragma unroll
                for (int mi = 0; mi < kMTiles; ++mi) {
#pragma unroll
                    for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
                        const int tile = mi + ni * kMTiles;
                        const int frag = min_tile_n * M_MMAC_COUNT + min_tile_m;
                        if (masked) {
                            tensor[tile][frag].f32[vec_idx] = -INFINITY;
                        }
                    }
                }
            }
        }
    }
}
// Sparse (top-k index) mask: a score is set to -INFINITY when its column is >= real_topk or
// when index_ptr[col_idx % 1024] == -1.
//
// Column owned by a lane for element vec_idx of fragment (ni, min_tile_n):
//   col_idx_offset_ + (lane_id >> 4) * 4 + ni * 32 + min_tile_n * 16 + vec_idx,  lane_id = threadIdx.x & 63.
// NOTE(review): the 1024 wrap is taken verbatim from the original; presumably the capacity of
// the index buffer — confirm against the caller.
//
// Fix: the M_MMAC_COUNT == 2 branch used to write -INFINITY unconditionally (the guarding `if`
// had been commented out), masking the whole tile. Both layouts now use the same predicate.
template <typename DataType, int WARP_M, int WARP_N, int M_MMAC_COUNT = 2>
inline __device__ void decode_dsa_apply_mask_gfx938(
    DataType tensor[(WARP_M / (16 * M_MMAC_COUNT)) * (WARP_N / 32)][2 * M_MMAC_COUNT],
    int* index_ptr,
    const int col_idx_offset_ = 0,
    int real_topk = 512)
{
    constexpr int kMTiles = WARP_M / (16 * M_MMAC_COUNT);
    const int lane_id = threadIdx.x & 63;
    const int col_idx_offset = col_idx_offset_ + (lane_id >> 4) * 4;
#pragma unroll
    for (int ni = 0; ni < (WARP_N / 32); ++ni) {
#pragma unroll
        for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
            const int col_idx_base = col_idx_offset + ni * 32 + min_tile_n * 16;
#pragma unroll
            for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
                const int col_idx = col_idx_base + vec_idx;
                // Short-circuit keeps the index lookup from running for columns past real_topk.
                const bool masked = (col_idx >= real_topk) || (index_ptr[col_idx % 1024] == -1);
#pragma unroll
                for (int mi = 0; mi < kMTiles; ++mi) {
#pragma unroll
                    for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
                        const int tile = mi + ni * kMTiles;
                        const int frag = min_tile_n * M_MMAC_COUNT + min_tile_m;
                        if (masked) {
                            tensor[tile][frag].f32[vec_idx] = -INFINITY;
                        }
                    }
                }
            }
        }
    }
}
// Sparse (top-k index) mask for the prefill path: a score is set to -INFINITY when its column
// is >= real_topk or when index_ptr[col_idx % 1024] == -1.
//
// Column owned by a lane for element vec_idx of fragment (ni, min_tile_n):
//   col_idx_offset_ + (lane_id >> 4) * 4 + ni * 32 + min_tile_n * 16 + vec_idx,  lane_id = threadIdx.x & 63.
// NOTE(review): the 1024 wrap is taken verbatim from the original; presumably the capacity of
// the index buffer — confirm against the caller.
//
// Fix: the M_MMAC_COUNT == 2 branch used to write -INFINITY unconditionally (the guarding `if`
// had been commented out), masking the whole tile. Both layouts now use the same predicate.
template <typename DataType, int WARP_M, int WARP_N, int M_MMAC_COUNT = 2>
inline __device__ void prefill_dsa_apply_mask_gfx938(
    DataType tensor[(WARP_M / (16 * M_MMAC_COUNT)) * (WARP_N / 32)][2 * M_MMAC_COUNT],
    int* index_ptr,
    const int col_idx_offset_ = 0,
    int real_topk = 512)
{
    constexpr int kMTiles = WARP_M / (16 * M_MMAC_COUNT);
    const int lane_id = threadIdx.x & 63;
    const int col_idx_offset = col_idx_offset_ + (lane_id >> 4) * 4;
#pragma unroll
    for (int ni = 0; ni < (WARP_N / 32); ++ni) {
#pragma unroll
        for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
            const int col_idx_base = col_idx_offset + ni * 32 + min_tile_n * 16;
#pragma unroll
            for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
                const int col_idx = col_idx_base + vec_idx;
                // Short-circuit keeps the index lookup from running for columns past real_topk.
                const bool masked = (col_idx >= real_topk) || (index_ptr[col_idx % 1024] == -1);
#pragma unroll
                for (int mi = 0; mi < kMTiles; ++mi) {
#pragma unroll
                    for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
                        const int tile = mi + ni * kMTiles;
                        const int frag = min_tile_n * M_MMAC_COUNT + min_tile_m;
                        if (masked) {
                            tensor[tile][frag].f32[vec_idx] = -INFINITY;
                        }
                    }
                }
            }
        }
    }
}
// Causal mask. For the row a lane owns, every score whose column is strictly greater than
//   min(max_seqlen_k, row_idx + max_seqlen_k - max_seqlen_q)
// is overwritten with -INFINITY; all other scores are left untouched.
//
// Lane coordinates (lane = threadIdx.x & 63):
//   row_idx = row_idx_offset_ + (lane & 15) + mi * 16 * M_MMAC_COUNT + min_tile_m * 16
//   col_idx = col_idx_offset_ + (lane >> 4) * 4 + ni * 32 + min_tile_n * 16 + vec_idx
template <typename DataType, int WARP_M, int WARP_N, int M_MMAC_COUNT = 2>
__forceinline__ __device__ void prefill_mla_apply_mask_causal_gfx938(
    DataType tensor[(WARP_M / (16 * M_MMAC_COUNT)) * (WARP_N / 32)][2 * M_MMAC_COUNT],
    const int col_idx_offset_,
    const int max_seqlen_k,
    const int row_idx_offset_,
    const int max_seqlen_q)
{
    const int lane = threadIdx.x & 63;
    const int lane_row = row_idx_offset_ + (lane & 15);
    const int lane_col = col_idx_offset_ + (lane >> 4) * 4;
#pragma unroll
    for (int mi = 0; mi < (WARP_M / (16 * M_MMAC_COUNT)); ++mi) {
        const int tile_row = lane_row + mi * (16 * M_MMAC_COUNT);
#pragma unroll
        for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
            const int row_idx = tile_row + min_tile_m * 16;
            // Last column this row is allowed to see.
            const int col_idx_limit_right = std::min(max_seqlen_k, row_idx + max_seqlen_k - max_seqlen_q);
#pragma unroll
            for (int ni = 0; ni < (WARP_N / 32); ++ni) {
#pragma unroll
                for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                    // Fragment position depends on how many M-direction MMACs share a tile.
                    int tile;
                    int frag;
                    if constexpr (M_MMAC_COUNT == 2) {
                        tile = mi + ni * (WARP_M / 32);
                        frag = min_tile_n * 2 + min_tile_m;
                    } else {
                        tile = mi + ni * (WARP_M / 16);
                        frag = min_tile_n;
                    }
                    const int frag_col = lane_col + ni * 32 + min_tile_n * 16;
#pragma unroll
                    for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
                        const int col_idx = frag_col + vec_idx;
                        if (col_idx > col_idx_limit_right) {
                            tensor[tile][frag].f32[vec_idx] = -INFINITY;
                        }
                    }
                }
            }
        }
    }
}
// Causal mask for the MTP path. The lane's row index is shifted by (max_seqlen_k - max_seqlen_q)
// and every score whose column is strictly greater than that shifted row becomes -INFINITY.
//
// Lane coordinates (lane = threadIdx.x & 63):
//   row_idx = row_idx_offset_ + (lane & 15) + max_seqlen_k - max_seqlen_q
//             + mi * 16 * M_MMAC_COUNT + min_tile_m * 16
//   col_idx = col_idx_offset_ + (lane >> 4) * 4 + ni * 32 + min_tile_n * 16 + vec_idx
template <typename DataType, int WARP_M, int WARP_N, int M_MMAC_COUNT = 2>
__forceinline__ __device__ void prefill_mla_apply_mtp_mask_causal_gfx938(
    DataType tensor[(WARP_M / (16 * M_MMAC_COUNT)) * (WARP_N / 32)][2 * M_MMAC_COUNT],
    const int col_idx_offset_,
    const int max_seqlen_k,
    const int row_idx_offset_,
    const int max_seqlen_q)
{
    const int lane = threadIdx.x & 63;
    const int lane_row = row_idx_offset_ + (lane & 15) + max_seqlen_k - max_seqlen_q;
    const int lane_col = col_idx_offset_ + (lane >> 4) * 4;
#pragma unroll
    for (int mi = 0; mi < (WARP_M / (16 * M_MMAC_COUNT)); ++mi) {
        const int tile_row = lane_row + mi * (16 * M_MMAC_COUNT);
#pragma unroll
        for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
            const int row_idx = tile_row + min_tile_m * 16;
#pragma unroll
            for (int ni = 0; ni < (WARP_N / 32); ++ni) {
#pragma unroll
                for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                    // Fragment position depends on how many M-direction MMACs share a tile.
                    int tile;
                    int frag;
                    if constexpr (M_MMAC_COUNT == 2) {
                        tile = mi + ni * (WARP_M / 32);
                        frag = min_tile_n * 2 + min_tile_m;
                    } else {
                        tile = mi + ni * (WARP_M / 16);
                        frag = min_tile_n;
                    }
                    const int frag_col = lane_col + ni * 32 + min_tile_n * 16;
#pragma unroll
                    for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
                        const int col_idx = frag_col + vec_idx;
                        if (col_idx > row_idx) {
                            tensor[tile][frag].f32[vec_idx] = -INFINITY;
                        }
                    }
                }
            }
        }
    }
}
// Causal mask for the FlashMLA MTP path, where `ngroups` consecutive rows share one query
// position: row_in_mtp = row_idx / ngroups, and every score whose column is strictly greater
// than min(max_seqlen_k, row_in_mtp + max_seqlen_k - mtp) becomes -INFINITY.
//
// Lane coordinates (lane_id = threadIdx.x & 63):
//   row_idx = row_idx_offset_ + m_idx * 16 + (lane_id & 15)
//   col_idx = col_idx_offset_ + ni * 32 + min_tile_n * 16 + (lane_id >> 4) * 4 + vec_idx
//
// Fix: the loop over m_idx previously ignored its own index (tile index hard-wired to 0, row
// independent of m_idx), so for WARP_M > 16 only the first 16-row tile was masked, repeatedly.
// Each iteration now addresses its own tile and row; for WARP_M == 16 nothing changes.
// NOTE(review): the 16-row step per m_idx mirrors the sibling causal masks in this file —
// confirm it matches the s_reg row layout used by the caller.
//
// max_seqlen_q and WARP_NUM are not used by the body; kept for interface compatibility.
template <typename DataType, int kBlockN, int WARP_M, int WARP_NUM>
__forceinline__ __device__ void flashmla_apply_mtp_mask_causal_gfx938(
    DataType s_reg[(WARP_M / 16) * (kBlockN / 32)][2],
    const int col_idx_offset_,
    const int max_seqlen_k,
    const int row_idx_offset_,
    const int max_seqlen_q,
    const int ngroups,
    const int mtp)
{
    const int lane_id = threadIdx.x & 63;
#pragma unroll
    for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
        int row_idx = row_idx_offset_ + m_idx * 16 + (lane_id & 15);
        int row_in_mtp = row_idx / ngroups;
        int col_idx_limit_right = min(max_seqlen_k, row_in_mtp + max_seqlen_k - mtp);
#pragma unroll
        for (int ni = 0; ni < kBlockN / 32; ++ni) {
            const int tile = m_idx + ni * (WARP_M / 16);
#pragma unroll
            for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
#pragma unroll
                for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
                    const int col_idx =
                        col_idx_offset_ + ni * 32 + min_tile_n * 16 + (lane_id >> 4) * 4 + vec_idx;
                    s_reg[tile][min_tile_n].f32[vec_idx] =
                        (col_idx > col_idx_limit_right) ? -INFINITY : s_reg[tile][min_tile_n].f32[vec_idx];
                }
            }
        }
    }
}
// Sliding-window (local) mask. For the row a lane owns, a score is set to -INFINITY when its
// column lies right of the window / past the end of the keys, or (only if HasWSLeft) left of
// the window:
//   masked  <=>  col_idx >= min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right)
//             || (HasWSLeft && col_idx < col_idx_limit_left - 1)
//   col_idx_limit_left = max(0, row_idx + 1 + max_seqlen_k - max_seqlen_q - window_size_left)
//
// Fix: the right bound was min(max_seqlen_k, row_idx + max_seqlen_k - max_seqlen_q + window_size_right)
// compared with `>`. Once the window reached the end of the sequence the bound clamped to
// max_seqlen_k and column max_seqlen_k itself (one past the last valid key) stayed unmasked.
// Using an exclusive bound with `>=` gives the same result inside the window and masks every
// column >= max_seqlen_k.
//
// Lane coordinates (lane_id = threadIdx.x & 63):
//   row_idx = row_idx_offset_ + (lane_id & 15) + mi * 32 + min_tile_m * 16
//   col_idx = col_idx_offset_ + (lane_id >> 4) * 4 + ni * 32 + min_tile_n * 16 + vec_idx
template <bool HasWSLeft = true, typename DataType, int WARP_M, int WARP_N>
inline __device__ void prefill_mla_apply_mask_local_gfx938(
    DataType tensor[(WARP_M / 32) * (WARP_N / 32)][4],
    const int col_idx_offset_,
    const int max_seqlen_k,
    const int row_idx_offset_,
    const int max_seqlen_q,
    const int window_size_left,
    const int window_size_right)
{
    const int lane_id = threadIdx.x & 63;
    const int row_idx_offset = row_idx_offset_ + (lane_id & 15);
    const int col_idx_offset = col_idx_offset_ + (lane_id >> 4) * 4;
#pragma unroll
    for (int mi = 0; mi < (WARP_M / 32); ++mi) {
        const int row_idx_base = row_idx_offset + mi * 32;
#pragma unroll
        for (int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
            const int row_idx = row_idx_base + min_tile_m * 16;
            const int col_idx_limit_left =
                std::max(0, row_idx + 1 + max_seqlen_k - max_seqlen_q - window_size_left);
            // Exclusive right bound: first column that must be masked.
            const int col_idx_limit_right =
                std::min(max_seqlen_k, row_idx + 1 + max_seqlen_k - max_seqlen_q + window_size_right);
#pragma unroll
            for (int ni = 0; ni < (WARP_N / 32); ++ni) {
#pragma unroll
                for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                    const int col_idx_base = col_idx_offset + ni * 32 + min_tile_n * 16;
                    const int tile = mi + ni * (WARP_M / 32);
                    const int frag = min_tile_n * 2 + min_tile_m;
#pragma unroll
                    for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
                        const int col_idx = col_idx_base + vec_idx;
                        const bool masked = col_idx >= col_idx_limit_right ||
                                            (HasWSLeft && col_idx < (col_idx_limit_left - 1));
                        tensor[tile][frag].f32[vec_idx] = masked ? -INFINITY : tensor[tile][frag].f32[vec_idx];
                    }
                }
            }
        }
    }
}
// ALiBi bias: adds g_alibi * (col_idx - row_idx) to every score held by the lane.
//
// Lane coordinates (lane = threadIdx.x & 63):
//   row_idx = row_idx_offset_ + (lane & 15) + mi * 32 + min_tile_m * 16
//   col_idx = col_idx_offset_ + (lane >> 4) * 4 + ni * 32 + min_tile_n * 16 + vec_idx
// max_seqlen_k and max_seqlen_q are accepted but not used by the body.
template <typename DataType, int WARP_M, int WARP_N>
inline __device__ void prefill_mla_apply_alibi_gfx938(
    DataType tensor[(WARP_M / 32) * (WARP_N / 32)][4],
    const int col_idx_offset_,
    const int max_seqlen_k,
    const int row_idx_offset_,
    const int max_seqlen_q,
    float g_alibi)
{
    const int lane = threadIdx.x & 63;
    const int lane_row = row_idx_offset_ + (lane & 15);
    const int lane_col = col_idx_offset_ + (lane >> 4) * 4;
#pragma unroll
    for (int mi = 0; mi < (WARP_M / 32); ++mi) {
#pragma unroll
        for (int min_tile_m = 0; min_tile_m < 2; ++min_tile_m) {
            const int row_idx = lane_row + mi * 32 + min_tile_m * 16;
#pragma unroll
            for (int ni = 0; ni < (WARP_N / 32); ++ni) {
                const int tile = mi + ni * (WARP_M / 32);
#pragma unroll
                for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                    const int frag = min_tile_n * 2 + min_tile_m;
                    const int frag_col = lane_col + ni * 32 + min_tile_n * 16;
#pragma unroll
                    for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
                        // Signed distance between key column and query row.
                        const int distance = (frag_col + vec_idx) - row_idx;
                        tensor[tile][frag].f32[vec_idx] += g_alibi * distance;
                    }
                }
            }
        }
    }
}
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/gfx938/mla_tp8_epilogue_gfx938.h
0 → 100644
View file @
a1eef562
#pragma once
#include "numeric_types.h"
// Epilogue: converts the per-lane output accumulator fragments from ElementAccum to
// SplitkvAccumType and stores them to global memory.
//
//   Destination : params.oaccum_ptr (offset by split_id * params.b * params.o_batch_stride) when
//                 Split is true, otherwise params.o_ptr; both are further offset by batch, head
//                 and headdim_split_id * kHeadDimVSplit.
//   Row         : m_block * kBlockM + warp_m_idx * 32 + (lane_id & 15) + min_tile_m * 16;
//                 rows >= params.seqlen_q are skipped.
//   Column      : (k_loop + warp_id) * kBlockK + k_tile_idx * 32 + lane/fragment offset.
//   Interleave2 : pairs element j of fragment min_tile_m with element j of fragment
//                 min_tile_m + 2 and issues one vector store per lane; otherwise each of the two
//                 N fragments is converted and stored on its own.
//
// Fix: all element offsets are now computed in int64_t. The previous 32-bit
// `seqlen_q_idx * o_row_stride` (and the head / split terms) overflow once the output tensor
// exceeds 2^31 elements along those strides.
//
// NOTE(review): the k_loop step is hard-coded to 4 in the original (presumably WARP_NUM) and is
// left unchanged; confirm against the accumulator layout. kHeadDimV is unused by the body.
template <typename Params,
          int kHeadDimV,
          int kHeadDimVSplit,
          bool Interleave2,
          bool Split,
          typename SplitkvAccumType,
          typename ElementAccum,
          int kBlockM,
          int kBlockK,
          int WARP_NUM,
          int K_LOOP_COUNT,
          int M_WARP_COUNT,
          int K_WARP_COUNT,
          int M_MMAC_COUNT>
__forceinline__ __device__ void mla_tp8_epilogue_store_output_gfx938(
    vec4_Accum<ElementAccum> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
    Params params,
    int bidb,
    int bidh,
    int m_block,
    int split_id,
    int headdim_split_id,
    int warp_id,
    int lane_id)
{
    int o_row_stride = params.o_row_stride;
    const int64_t row_offset_o = bidb * int64_t(params.o_batch_stride) +
                                 int64_t(bidh) * params.o_head_stride +
                                 headdim_split_id * kHeadDimVSplit;
    SplitkvAccumType* o_ptr =
        Split ? reinterpret_cast<SplitkvAccumType*>(params.oaccum_ptr) + row_offset_o +
                    /*which split*/ int64_t(split_id) * params.b * params.o_batch_stride
              : reinterpret_cast<SplitkvAccumType*>(params.o_ptr) + row_offset_o;

    int pv_lane_seq_idx = lane_id & 15;       // row of the lane inside a 16-row fragment
    int pv_lane_head_dim_idx = lane_id >> 4;  // column group of the lane

#pragma unroll
    for (int k_loop = 0; k_loop < K_LOOP_COUNT; k_loop += 4) {
#pragma unroll
        for (int warp_m_idx = 0; warp_m_idx < M_WARP_COUNT; ++warp_m_idx) {
#pragma unroll
            for (int k_tile_idx = 0; k_tile_idx < K_WARP_COUNT; ++k_tile_idx) {
                // which 32x32 tile
                int tile_32x32_id = k_loop * M_WARP_COUNT * K_WARP_COUNT + warp_m_idx * K_WARP_COUNT + k_tile_idx;
#pragma unroll
                for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
                    // index along seqlen_q dimension
                    int seqlen_q_idx = m_block * kBlockM + warp_m_idx * 32 + pv_lane_seq_idx + min_tile_m * 16;
                    if (seqlen_q_idx < params.seqlen_q) {
                        // Row/column base offset shared by both store variants, in 64-bit.
                        const int64_t tile_base = int64_t(seqlen_q_idx) * o_row_stride +
                                                  (k_loop + warp_id) * kBlockK + k_tile_idx * 32;
                        if constexpr (Interleave2) {
                            // Interleave the two N fragments so the lane's data is contiguous,
                            // then write it with a single vector store.
                            union_vec4_f16x2<SplitkvAccumType> v_data;
#pragma unroll
                            for (int elem = 0; elem < 4; ++elem) {
                                v_data.f16x2[elem] = DownCastPairNoPack<ElementAccum, SplitkvAccumType>(
                                    acc_o[tile_32x32_id][min_tile_m + 0 * 2].f32[elem],
                                    acc_o[tile_32x32_id][min_tile_m + 1 * 2].f32[elem]);
                            }
                            const int64_t pv_global_addr = tile_base + pv_lane_head_dim_idx * 8;
                            *(vec4_fp32*)(o_ptr + pv_global_addr) = v_data.f32;
                        } else {
#pragma unroll
                            for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                                union_vec2_f16x2<SplitkvAccumType> data;
                                int mmac_id = min_tile_m + min_tile_n * 2;
#pragma unroll
                                for (int vec_index = 0; vec_index < 2; ++vec_index) {
                                    data.f16x2[vec_index] = DownCastPair<ElementAccum, SplitkvAccumType>(
                                        acc_o[tile_32x32_id][mmac_id].f32x2[vec_index]);
                                }
                                const int64_t pv_global_addr =
                                    tile_base + pv_lane_head_dim_idx * 4 + min_tile_n * 16;
                                *(union_vec2_f16x2<SplitkvAccumType>*)(o_ptr + pv_global_addr) = data;
                            }
                        }
                    }
                }
            }
        }
    }
}
\ No newline at end of file
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/gfx938/mla_tp8_qk_gemm_utils_gfx938.h
0 → 100644
View file @
a1eef562
#pragma once
#include "intrinsic_mls_ds.h"
// Loads the Q tile from global memory into LDS with buffer matrix loads, reads it back into
// the q_reg register fragments, and initialises the softmax state (scores_max, scores_sum,
// acc_o) via attention_initialize while the loads are in flight.
//
//   q_addr              : first two dwords are reinterpreted as the 64-bit base address of Q.
//   q_lds               : LDS staging buffer, addressed in blocks of block32x16_bytes.
//   q_reg               : destination register fragments.
//   warp_id             : selects which 32x16 block each wave loads.
//   query_seqlen_stride : written to dword 2 of the buffer resource descriptor.
//   max_seq_q_offset, kBlockM : not used by the body.
//
// NOTE(review): statement order matters here (sched barriers and wait counters); the register
// indices `load_id * 4 + i` and `16 + n` look like WARP_NUM == 4 and LOAD * WARP_NUM == 16
// baked in — confirm before instantiating with another WARP_NUM.
template <int kHeadDim,
          int kHeadDimV,
          int kBlockM,
          int kBlockK,
          int WARP_M,
          int WARP_NUM,
          int M_MMAC_COUNT,
          typename Element,
          typename ElementAccum>
__forceinline__ __device__ void mla_prefetch_q_to_vgpr_gfx938_with_initialization(
    vec4_uint q_addr,
    Element* q_lds,
    union_vec4_f16x2<Element> q_reg[(kHeadDim / kBlockK) * (WARP_M * kBlockK) / (32 * 32) * 2],
    int warp_id,
    int query_seqlen_stride,
    int max_seq_q_offset,
    vec2_Accum<ElementAccum> scores_max[WARP_M / 32],
    vec2_Accum<ElementAccum> scores_sum[WARP_M / 32],
    vec4_Accum<ElementAccum> acc_o[kHeadDimV / kBlockK][4])
{
    flash::wait_all_warp_arrived();
    // prepare mls buffer resource registers (dwords 0-1 are filled per load below)
    vec4_uint q_srsrc;
    q_srsrc[2] = query_seqlen_stride;
    q_srsrc[3] = 0;
    // total 16x576 f16s
    // 16x128 f16s per wave first
    constexpr int LOAD = 4;
    constexpr int block32x16_bytes = 32 * 16 * sizeof(Element);
    // Phase 1: every wave issues LOAD matrix loads, one 32x16 block each.
#pragma unroll
    for (int load_id = 0; load_id < LOAD; ++load_id) {
        // lds address
        int lds_offset_bytes = (load_id * WARP_NUM + warp_id) * block32x16_bytes;
        // global offset (in elements)
        int q_warp_offset = (load_id * WARP_NUM + warp_id) * 32;
        // compute global address and store it in the low 64 bits of the descriptor
        *(uint64_t*)&q_srsrc = VA_LIMIT_BITS(*(uint64_t*)&q_addr + q_warp_offset * sizeof(Element));
        // matrix load, fenced so the scheduler does not move code across it
        __builtin_amdgcn_sched_barrier(0);
        inline_matrix_load_32x16_b16_lds_trans<0, 1>(q_lds, q_srsrc, lds_offset_bytes, 0);
        __builtin_amdgcn_sched_barrier(0);
    }
    // insert values in def-use: initialise the softmax state while the loads are outstanding
    attention_initialize<kHeadDimV / kBlockK, WARP_M / 32, kBlockK / 32, M_MMAC_COUNT, ElementAccum>(scores_max, scores_sum, acc_o);
    // fetch data from lds for the first MID load groups
    const int MID = 1;
#pragma unroll
    for (int load_id = 0; load_id < MID; ++load_id) {
        // wait global data written to lds
        // (argument presumably = number of loads still allowed in flight — TODO confirm)
        flash::wait_buffer_data_arrived<true /*sync*/>(LOAD - load_id - 1);
#pragma unroll
        for (int i = 0; i < WARP_NUM; ++i) {
            DS_READ_MATRIX_32X16_B16(load_id * WARP_NUM * block32x16_bytes + i * block32x16_bytes, q_reg[(load_id * 4 + i) * 2].f16, true);
        }
    }
    // -------------------------------------------------------------------
    // prefetch rest 16x64 loads
    // 16x32 f16s 0-1 wave later
    int lds_offset_bytes = (LOAD * WARP_NUM + warp_id) * block32x16_bytes;
    // Waves >= 2 re-read wave 0's source block; each wave still writes to its own LDS slot.
    int real_warp_id = warp_id >= 2 ? 0 : warp_id;
    int q_warp_offset = (LOAD * WARP_NUM + real_warp_id) * 32;
    *(uint64_t*)&q_srsrc = VA_LIMIT_BITS(*(uint64_t*)&q_addr + q_warp_offset * sizeof(Element));
    __builtin_amdgcn_sched_barrier(0);
    inline_matrix_load_32x16_b16_lds_trans<0, 1>(q_lds, q_srsrc, lds_offset_bytes, 0);
    __builtin_amdgcn_sched_barrier(0);
    // continue from MID; the wait count is raised by MID to account for the extra load above
#pragma unroll
    for (int load_id = MID; load_id < LOAD; ++load_id) {
        // wait global data written to lds
        flash::wait_buffer_data_arrived<true /*sync*/>(LOAD - load_id - 1 + MID);
#pragma unroll
        for (int i = 0; i < WARP_NUM; ++i) {
            DS_READ_MATRIX_32X16_B16(load_id * WARP_NUM * block32x16_bytes + i * block32x16_bytes, q_reg[(load_id * 4 + i) * 2].f16, true);
        }
    }
    // wait global data written to lds
    flash::wait_buffer_data_arrived<true /*sync*/>(0);
    // write last data into registers (the two blocks loaded by waves 0 and 1 in the tail load)
    DS_READ_MATRIX_32X16_B16((LOAD * WARP_NUM + 0) * block32x16_bytes, q_reg[(16 + 0) * 2].f16, true);
    DS_READ_MATRIX_32X16_B16((LOAD * WARP_NUM + 1) * block32x16_bytes, q_reg[(16 + 1) * 2].f16, true);
    // wait all data written to registers
    flash::wait_lds_data_arrived<true /*sync*/>(0);
}
\ No newline at end of file
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/mla_acco_reduce.h
0 → 100644
View file @
a1eef562
#include "numeric_types.h"
// Cross-wave reduction of the output accumulator through LDS.
//
// For every h_idx (one kBlockK-wide slice of the head dimension) the function
//   1. lets each wave write its acc_o fragments to its own LDS region,
//   2. sums the WARP_NUM regions into region 0 (all four waves cooperate when WARP_NUM == 4,
//      otherwise wave 0 does it alone),
//   3. lets each wave read the summed values back into acc_o.
//
//   acc_o_lds : LDS scratch; the offsets used below reach WARP_NUM * EVEN_REUSE_KV_TIMES * kBlockK elements.
//   seqlen_q  : only used to derive the row count when REUSE_KV_TIMES <= 0.
//
// Fixes relative to the previous version:
//   * The accumulator tile was addressed as h_idx * (K_WARP_COUNT + k_idx) * M_WARP_COUNT, which
//     maps every k_idx of h_idx == 0 to tile 0 (and to wrong tiles afterwards) as soon as
//     K_WARP_COUNT > 1. It is now (h_idx * K_WARP_COUNT + k_idx) * M_WARP_COUNT; for
//     K_WARP_COUNT == 1 the index is unchanged.
//   * __syncthreads() was called inside `if (q_seq_idx < HALF_REUSE_KV_TIMES)`, i.e. under a
//     lane-dependent condition. The barriers are now executed unconditionally by every thread
//     and only the LDS accesses are predicated.
//
// Must be called by all threads of the block.
template <int REUSE_KV_TIMES,
          int K_LOOP_COUNT,
          int K_WARP_COUNT,
          int M_WARP_COUNT,
          int M_MMAC_COUNT,
          int WARP_NUM,
          typename ElementAccum>
__forceinline__ __device__ void mla_acco_reduce(
    vec4_Accum<ElementAccum> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
    ElementAccum* acc_o_lds,
    int seqlen_q,
    int warp_id,
    int lane_id)
{
    constexpr int kBlockK = K_WARP_COUNT * 32;
    // when REUSE_KV not in templated, compute max reuse times (rounded up to an even count)
    int EVEN_REUSE_KV_TIMES = (REUSE_KV_TIMES > 0) ? ((REUSE_KV_TIMES + 1) / 2) * 2
                                                   : ((seqlen_q + 1) / 2) * 2;
    int HALF_REUSE_KV_TIMES = EVEN_REUSE_KV_TIMES >> 1;
    int q_seq_idx = (lane_id & 15);
    // Halved because every thread stores two rows of data; along the seq direction the rows
    // are laid out as 0,0,1,1,2,2,...,15,15.
    const bool lane_active = q_seq_idx < HALF_REUSE_KV_TIMES;

    for (int h_idx = 0; h_idx < K_LOOP_COUNT; ++h_idx) {
        // ------------------------------------------------------------------------------------
        // Step 1: each wave writes the acc_o results it is responsible for into LDS.
        if (lane_active) {
            for (int k_idx = 0; k_idx < K_WARP_COUNT; ++k_idx) {
                const int tile = (h_idx * K_WARP_COUNT + k_idx) * M_WARP_COUNT;
                for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
                    for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                        // A wave collectively holds seqlen_q x kHeadDim values, but to save LDS
                        // only seqlen_q x kBlockK values are reduced per iteration.
                        int lds_offset = (warp_id * EVEN_REUSE_KV_TIMES + q_seq_idx * 2 + min_tile_m) * kBlockK +
                                         k_idx * 32 + min_tile_n * 16 + (lane_id >> 4 /*0~3*/) * 4 /*0~15*/;
                        *(vec4_fp32*)(acc_o_lds + lds_offset) = acc_o[tile][min_tile_n * 2 + min_tile_m].f32;
                    }
                }
            }
        }
        __syncthreads();

        // ------------------------------------------------------------------------------------
        // Step 2: sum in LDS, adding up the acc_o data written by all waves.
        if constexpr (WARP_NUM == 4) {
            // Exactly 4 waves: all of them take part, each wave handling one of the 4
            // consecutive elements that a lane wrote in step 1.
            if (lane_active) {
                for (int k_idx = 0; k_idx < K_WARP_COUNT; ++k_idx) {
                    for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
                        for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                            int lds_offset = (q_seq_idx * 2 + min_tile_m) * kBlockK + k_idx * 32 +
                                             min_tile_n * 16 + (lane_id >> 4) * 4 + warp_id;
                            float acc_tmp_wave0 = acc_o_lds[lds_offset];
                            for (int loop = 1; loop < WARP_NUM; ++loop) {
                                acc_tmp_wave0 += acc_o_lds[lds_offset + loop * EVEN_REUSE_KV_TIMES * kBlockK];
                            }
                            acc_o_lds[lds_offset] = acc_tmp_wave0;
                        }
                    }
                }
            }
        }
        // Not exactly 4 waves: wave 0 alone performs the LDS reduction.
        else if constexpr (WARP_NUM > 1) {
            if (lane_active && warp_id == 0) {
                for (int k_idx = 0; k_idx < K_WARP_COUNT; ++k_idx) {
                    for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
                        for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                            for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
                                int lds_offset = (q_seq_idx * 2 + min_tile_m) * kBlockK + k_idx * 32 +
                                                 min_tile_n * 16 + (lane_id >> 4) * 4 + vec_idx;
                                float acc_tmp_wave0 = acc_o_lds[lds_offset];
                                for (int loop = 1; loop < WARP_NUM; ++loop) {
                                    acc_tmp_wave0 += acc_o_lds[lds_offset + loop * EVEN_REUSE_KV_TIMES * kBlockK];
                                }
                                acc_o_lds[lds_offset] = acc_tmp_wave0;
                            }
                        }
                    }
                }
            }
        }
        __syncthreads();

        // ------------------------------------------------------------------------------------
        // Step 3: every wave fetches the final sums from LDS.
        if (lane_active) {
            for (int k_idx = 0; k_idx < K_WARP_COUNT; ++k_idx) {
                const int tile = (h_idx * K_WARP_COUNT + k_idx) * M_WARP_COUNT;
                for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
                    for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                        int lds_offset = (q_seq_idx * 2 + min_tile_m) * kBlockK + k_idx * 32 +
                                         min_tile_n * 16 + (lane_id >> 4) * 4;
                        acc_o[tile][min_tile_n * 2 + min_tile_m].f32 = *(vec4_fp32*)(acc_o_lds + lds_offset);
                    }
                }
            }
        }
        // Keep the next iteration's step-1 writes from overtaking this iteration's reads.
        __syncthreads();
    }
}
\ No newline at end of file
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/mla_acco_reduce_tile16x32.h
0 → 100644
View file @
a1eef562
#include "numeric_types.h"
// Cross-wave reduction of the partial output accumulators through LDS (16x32 tile path).
//
// K_LOOP_COUNT is walked in groups of PREFETCH (= WARP_NUM) slots. For every group:
//   1. each wave writes fragments acc_o[k_loop + p][0] and acc_o[k_loop + p][2]
//      (p = 0 .. PREFETCH-1) into its own LDS region,
//   2. each wave reads back, from every wave's region ("neighbor"), the slot whose
//      index equals its own warp_id,
//   3. acc_o[k_loop + 0][0] and acc_o[k_loop + 0][2] are zeroed and the fragments
//      read in step 2 are summed into them.
// Only fragment indices min_tile_n * 2 (i.e. 0 and 2) are touched.
//
// Parameters:
//   acc_o     - per-thread accumulator fragments; read, then overwritten at slot k_loop + 0.
//   acc_o_lds - LDS scratch buffer used for the exchange.
//   seqlen_q  - not referenced in this function.
//   warp_id   - selects this wave's LDS write region and the slot it reads back.
//   lane_id   - each lane moves 4 consecutive elements at lane_id * 4.
// The template parameters REUSE_KV_TIMES, M_MMAC_COUNT and Padding are not referenced in the body.
//
// NOTE(review): the per-wave LDS region stride is the literal 2048 elements, which equals
// PREFETCH * 2 * 16 * 16 only when WARP_NUM == 4 — confirm this function is only
// instantiated with 4 waves.
// NOTE(review): acc_o[k_loop + prefetch] assumes K_LOOP_COUNT is a multiple of WARP_NUM — confirm.
template <int REUSE_KV_TIMES,
          int K_LOOP_COUNT,
          int K_WARP_COUNT,
          int M_WARP_COUNT,
          int M_MMAC_COUNT,
          int WARP_NUM,
          int Padding,
          typename ElementAccum>
__forceinline__ __device__ void mla_acco_reduce_tile16x32(
    vec4_Accum<ElementAccum> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
    ElementAccum* acc_o_lds,
    int seqlen_q,
    int warp_id,
    int lane_id)
{
    // One LDS slot per wave is exchanged in every pass.
    constexpr int PREFETCH = WARP_NUM;
#pragma unroll
    for (int k_loop = 0; k_loop < K_LOOP_COUNT; k_loop += PREFETCH)
    {
        // Step 1: issue 2 * PREFETCH LDS writes (min_tile_n outer, prefetch inner).
#pragma unroll
        for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n)
        {
#pragma unroll
            for (int prefetch = 0; prefetch < PREFETCH; ++prefetch)
            {
                vec4_fp32 f32x4 = acc_o[k_loop + prefetch][min_tile_n * 2].f32;
                // Element offset inside this wave's region: slot `prefetch`, half `min_tile_n`.
                int lds_write_offset = warp_id * 2048 + prefetch * 2 * 16 * 16 + min_tile_n * 16 * 16;
                // Converted to the LDS address that is handed to the inline-asm write.
                lds_write_offset = reinterpret_cast<size_t>(acc_o_lds + lds_write_offset + lane_id * 4);
                inlineasm_ds_write_b128(lds_write_offset, f32x4);
            }
        }
        // Fragments read back from every wave, indexed [min_tile_n][neighbor].
        union_vec4_fp32 data[2][WARP_NUM];
        constexpr int ds_bursts = PREFETCH;
        // Step 2a: read the min_tile_n == 0 half from every wave's region.
        {
            constexpr int min_tile_n = 0;
            // presumably waits until at most PREFETCH LDS operations are outstanding (the
            // min_tile_n == 0 writes were issued first) and, with <true>, also synchronises
            // the waves — confirm against wait.h
            flash::wait_lds_data_arrived<true>((1 - min_tile_n) * PREFETCH);
#pragma unroll
            for (int neighbor = 0; neighbor < PREFETCH; ++neighbor)
            {
                // Region of wave `neighbor`, slot `warp_id`.
                int lds_read_offset = reinterpret_cast<size_t>(acc_o_lds + neighbor * 2048 + warp_id * 2 * 16 * 16 + min_tile_n * 16 * 16 + lane_id * 4);
                inlineasm_ds_read_b128(lds_read_offset, data[min_tile_n][neighbor].f32);
            }
            // Clear the destination fragment before accumulating into it.
            inline_vgpr2_init_zero(acc_o[k_loop + 0][min_tile_n * 2].b64[0]);
            inline_vgpr2_init_zero(acc_o[k_loop + 0][min_tile_n * 2].b64[1]);
        }
        // Step 2b: same for the min_tile_n == 1 half.
        {
            constexpr int min_tile_n = 1;
            // The allowed outstanding count is ds_bursts here, matching the PREFETCH reads
            // issued just above.
            flash::wait_lds_data_arrived<true>((1 - min_tile_n) * PREFETCH + ds_bursts);
#pragma unroll
            for (int neighbor = 0; neighbor < PREFETCH; ++neighbor)
            {
                int lds_read_offset = reinterpret_cast<size_t>(acc_o_lds + neighbor * 2048 + warp_id * 2 * 16 * 16 + min_tile_n * 16 * 16 + lane_id * 4);
                inlineasm_ds_read_b128(lds_read_offset, data[min_tile_n][neighbor].f32);
            }
            inline_vgpr2_init_zero(acc_o[k_loop + 0][min_tile_n * 2].b64[0]);
            inline_vgpr2_init_zero(acc_o[k_loop + 0][min_tile_n * 2].b64[1]);
        }
        // Step 3a: accumulate the min_tile_n == 0 fragments as they arrive.
        {
            constexpr int min_tile_n = 0;
#pragma unroll
            for (int neighbor = 0; neighbor < PREFETCH; ++neighbor)
            {
                // The extra "+ ds_bursts" accounts for the min_tile_n == 1 reads issued afterwards.
                flash::wait_lds_data_arrived<false>(ds_bursts - 1 - neighbor + ds_bursts);
                inline_v_pk_add_f32(acc_o[k_loop + 0][min_tile_n * 2].u64[0], acc_o[k_loop + 0][min_tile_n * 2].u64[0], data[min_tile_n][neighbor].u64[0]);
                inline_v_pk_add_f32(acc_o[k_loop + 0][min_tile_n * 2].u64[1], acc_o[k_loop + 0][min_tile_n * 2].u64[1], data[min_tile_n][neighbor].u64[1]);
            }
        }
        // Step 3b: accumulate the min_tile_n == 1 fragments.
        {
            constexpr int min_tile_n = 1;
#pragma unroll
            for (int neighbor = 0; neighbor < PREFETCH; ++neighbor)
            {
                flash::wait_lds_data_arrived<false>(ds_bursts - 1 - neighbor);
                inline_v_pk_add_f32(acc_o[k_loop + 0][min_tile_n * 2].u64[0], acc_o[k_loop + 0][min_tile_n * 2].u64[0], data[min_tile_n][neighbor].u64[0]);
                inline_v_pk_add_f32(acc_o[k_loop + 0][min_tile_n * 2].u64[1], acc_o[k_loop + 0][min_tile_n * 2].u64[1], data[min_tile_n][neighbor].u64[1]);
            }
        }
        // presumably a workgroup barrier so the LDS regions can be reused by the next
        // pass — confirm against wait.h
        flash::wait_all_warp_arrived();
    }
}
\ No newline at end of file
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/mla_epilogue.h
0 → 100644
View file @
a1eef562
#include "numeric_types.h"
// Scales every output accumulator fragment by the reciprocal of its row's softmax sum.
//
// For each tile and each row half (min_tile_m) the factor is 1 / scores_sum[mi].f32[row half];
// when that sum is zero or NaN the factor is 1, so the fragment is left unscaled.
//
// Parameters:
//   acc_o      - accumulator fragments, scaled in place.
//   scores_sum - per-row softmax sums, indexed [M warp tile].f32[row half].
template <int K_LOOP_COUNT,
          int M_WARP_COUNT,
          int K_WARP_COUNT,
          int M_MMAC_COUNT,
          typename ElementAccum>
__forceinline__ __device__ void mla_epilugue_rescale_acco(
    vec4_Accum<ElementAccum> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
    vec2_Accum<ElementAccum> scores_sum[M_WARP_COUNT])
{
    constexpr int kTilesPerKStep = M_WARP_COUNT * K_WARP_COUNT;
#pragma unroll
    for (int k_step = 0; k_step < K_LOOP_COUNT; ++k_step)
    {
#pragma unroll
        for (int m_warp = 0; m_warp < M_WARP_COUNT; ++m_warp)
        {
#pragma unroll
            for (int k_warp = 0; k_warp < K_WARP_COUNT; ++k_warp)
            {
                // Index of the 32x32 tile this (k_step, k_warp, m_warp) triple addresses.
                const int tile = k_step * kTilesPerKStep + (k_warp * M_WARP_COUNT + m_warp);
#pragma unroll
                for (int row_half = 0; row_half < M_MMAC_COUNT; ++row_half)
                {
                    ElementAccum row_sum = scores_sum[m_warp].f32[row_half];
                    // A usable sum is non-zero and not NaN (NaN compares unequal to itself).
                    ElementAccum reciprocal = (row_sum != 0.f && row_sum == row_sum) ? 1.f / row_sum : 1.f;
                    __float2 packed_scale = {reciprocal, reciprocal};
#pragma unroll
                    for (int col_half = 0; col_half < 2; ++col_half)
                    {
                        const int frag = col_half * 2 + row_half;
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__)
                        // Packed path: two fp32 lanes per multiply.
                        for (int pair = 0; pair < 2; ++pair)
                        {
                            acc_o[tile][frag].u64[pair] = __builtin_hcu_pk_mul_f32(acc_o[tile][frag].u64[pair], packed_scale);
                        }
#else
                        // Scalar fallback: one fp32 lane per multiply.
                        for (int lane = 0; lane < 4; ++lane)
                        {
                            acc_o[tile][frag].f32[lane] *= reciprocal;
                        }
#endif
                    }
                }
            }
        }
    }
}
// Writes the per-row log-sum-exp, scores_max * scale_softmax + log(scores_sum), to softmax_lse_ptr.
//
// Nothing is written unless Split is true. Only threads with thread_id < 16 write; on the
// Is_16x32 path the write is additionally restricted to headdim_split_id == 0.
// Rows at or beyond seqlen_q_limit are skipped.
//
// Parameters:
//   scores_max / scores_sum - per-row running maximum and sum, indexed [mi].f32[min_tile_m].
//   softmax_lse_ptr         - destination, indexed by row.
//   scale_softmax           - factor applied to the maximum before adding the log of the sum.
//   warp_id, thread_id, lane_id - used to derive the destination row (see below).
//   headdim_split_id        - head-dim split index; only consulted when Is_16x32 is true.
//   seqlen_q_limit          - exclusive upper bound on rows that may be written.
template <bool Split,
          bool Is_16x32,
          int M_WARP_COUNT,
          int M_MMAC_COUNT,
          typename ElementAccum>
__forceinline__ __device__ void mla_tp8_epilogue_store_softmax_lse(
    vec2_Accum<ElementAccum> scores_max[M_WARP_COUNT],
    vec2_Accum<ElementAccum> scores_sum[M_WARP_COUNT],
    ElementAccum* softmax_lse_ptr,
    ElementAccum scale_softmax,
    int warp_id,
    int thread_id,
    int lane_id,
    int headdim_split_id,
    int seqlen_q_limit)
{
    if constexpr (Split)
    {
        const bool is_writer = (thread_id < 16) && (!Is_16x32 || headdim_split_id == 0);
        if (!is_writer)
        {
            return;
        }
#pragma unroll
        for (int m_warp = 0; m_warp < M_WARP_COUNT; ++m_warp)
        {
#pragma unroll
            for (int row_half = 0; row_half < M_MMAC_COUNT; ++row_half)
            {
                int dst_row;
                if constexpr (Is_16x32)
                {
                    // lane_id equals lane_id & 15 here; the two row halves are 16 rows apart.
                    dst_row = m_warp * 32 + lane_id + row_half * 16;
                }
                else
                {
                    // Each thread owns two adjacent rows of its wave's M_WARP_COUNT * 32 rows.
                    dst_row = warp_id * M_WARP_COUNT * 32 + m_warp * 32 + thread_id * 2 + row_half;
                }
                if (dst_row >= seqlen_q_limit)
                {
                    continue;
                }
                softmax_lse_ptr[dst_row] = scores_max[m_warp].f32[row_half] * scale_softmax + __logf(scores_sum[m_warp].f32[row_half]);
            }
        }
    }
}
// Stores the output accumulator fragments to global memory.
//
// Destination: params.oaccum_ptr (offset by split_id * params.b * params.o_batch_stride) when
// Split is true, otherwise params.o_ptr; both are offset by the batch / head / head-dim-split
// base. Values are converted with DownCast<ElementAccum, SplitkvAccumType>. Rows with
// seqlen_q_idx >= params.seqlen_q are not written.
//
// When WARP_NUM == 4 each wave stores only the fragment lane f32[warp_id]; otherwise every
// wave stores all four lanes.
//
// All element offsets are computed in 64-bit arithmetic: the previous code multiplied
// seqlen_q_idx by the row stride (and bidh / split_id by their strides) in 32-bit int before
// widening, which can overflow for large outputs.
//
// Parameters:
//   acc_o            - accumulator fragments, indexed [32x32 tile][mmac id].f32[lane].
//   params           - kernel parameters; reads o_row_stride, o_batch_stride, o_head_stride,
//                      b, seqlen_q, o_ptr and oaccum_ptr.
//   bidb, bidh       - batch and head index.
//   m_block          - index of the kBlockM-row block along seqlen_q.
//   split_id         - KV split index (used only when Split is true).
//   headdim_split_id - head-dim split index; shifts the base by kHeadDimVSplit elements.
//   warp_id, lane_id - wave and lane index of the calling thread.
template <typename Params,
          int kHeadDimV,
          int kHeadDimVSplit,
          bool Split,
          typename SplitkvAccumType,
          typename ElementAccum,
          int kBlockM,
          int kBlockK,
          int WARP_NUM,
          int K_LOOP_COUNT,
          int M_WARP_COUNT,
          int K_WARP_COUNT,
          int M_MMAC_COUNT>
__forceinline__ __device__ void mla_epilogue_store_output(
    vec4_Accum<ElementAccum> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
    Params params,
    int bidb,
    int bidh,
    int m_block,
    int split_id,
    int headdim_split_id,
    int warp_id,
    int lane_id)
{
    // Kept as 64-bit so the row * stride product below cannot wrap.
    const int64_t output_seqlen_stride = params.o_row_stride;
    const int64_t row_offset_o = bidb * int64_t(params.o_batch_stride) + int64_t(bidh) * params.o_head_stride + headdim_split_id * kHeadDimVSplit;
    SplitkvAccumType* o_ptr = Split
        ? reinterpret_cast<SplitkvAccumType*>(params.oaccum_ptr) + row_offset_o + /*which split*/ int64_t(split_id) * params.b * params.o_batch_stride
        : reinterpret_cast<SplitkvAccumType*>(params.o_ptr) + row_offset_o;
    // NOTE(review): gO is not referenced below; it is kept in case the helper has side
    // effects — confirm and remove if it does not.
    auto gO = prepare_for_buffer_load<kHeadDimV, SplitkvAccumType, false /*USE_CACHE_SWIZZLE*/>(o_ptr);
    int pv_lane_seq_idx = (lane_id & 15);
    int pv_lane_head_dim_idx = (lane_id >> 4);
#pragma unroll
    for (int k_loop = 0; k_loop < K_LOOP_COUNT; ++k_loop)
    {
#pragma unroll
        for (int warp_m_idx = 0; warp_m_idx < M_WARP_COUNT; ++warp_m_idx)
        {
#pragma unroll
            for (int k_tile_idx = 0; k_tile_idx < K_WARP_COUNT; ++k_tile_idx)
            {
                // Index of the 32x32 tile being stored.
                int tile_32x32_id = k_loop * M_WARP_COUNT * K_WARP_COUNT + warp_m_idx * K_WARP_COUNT + k_tile_idx;
#pragma unroll 2
                for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n)
                {
#pragma unroll
                    for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m)
                    {
                        // Index of the mmac fragment inside the current 32x32 tile.
                        int mmac_id = min_tile_m + min_tile_n * 2;
                        // Coordinate along seqlen_q.
                        int seqlen_q_idx = m_block * kBlockM + warp_m_idx * 32 + pv_lane_seq_idx * 2 + min_tile_m;
                        // 64-bit row base shared by every lane stored for this fragment.
                        const int64_t row_base = int64_t(seqlen_q_idx) * output_seqlen_stride;
                        if constexpr (WARP_NUM == 4)
                        {
                            // for 4 waves, the stores can be issued together (about 4% faster)
                            int vec_index = warp_id;
                            int64_t pv_global_addr = row_base + k_loop * kBlockK + k_tile_idx * 32 + vec_index * 8 + pv_lane_head_dim_idx * 2 + min_tile_n;
                            ElementAccum data = acc_o[tile_32x32_id][mmac_id].f32[vec_index];
                            if (seqlen_q_idx < params.seqlen_q)
                            {
                                o_ptr[pv_global_addr] = DownCast<ElementAccum, SplitkvAccumType>(data);
                            }
                        }
                        else
                        {
                            // non-4-wave configurations use this path; it is slower when there are 4 waves per SIMD
#pragma unroll
                            for (int vec_index = 0; vec_index < 4; ++vec_index)
                            {
                                int64_t pv_global_addr = row_base + k_loop * kBlockK + k_tile_idx * 32 + vec_index * 8 + pv_lane_head_dim_idx * 2 + min_tile_n;
                                ElementAccum data = acc_o[tile_32x32_id][mmac_id].f32[vec_index];
                                if (seqlen_q_idx < params.seqlen_q)
                                {
                                    o_ptr[pv_global_addr] = DownCast<ElementAccum, SplitkvAccumType>(data);
                                }
                            }
                        }
                    }
                }
            }
        }
    }
    // Keep the scheduler from moving instructions across this point, then drain all
    // outstanding vector-memory operations before returning.
    __builtin_amdgcn_sched_barrier(0);
    asm volatile("s_waitcnt vmcnt(0)");
}
\ No newline at end of file
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/mla_epilogue_tile16x32.h
0 → 100644
View file @
a1eef562
#include "numeric_types.h"
// Stores the per-row softmax sum and the scaled per-row maximum for the 16x32 tile path.
//
// Active only when Split is true or FA_DEBUG_SUM_MAX is defined. Within that, only threads
// with headdim_split_id == 0 and thread_id < 16 write, and only rows below seqlen_q_limit.
//
// Parameters:
//   scores_max / scores_sum         - per-row values, indexed [mi].f32[min_tile_m].
//   scores_max_ptr / scores_sum_ptr - destinations, indexed by row.
//   scale_softmax                   - factor applied to the maximum before it is stored.
//   warp_id                         - not referenced in this function.
//   thread_id, lane_id              - writer selection and row index.
//   headdim_split_id                - head-dim split index; only split 0 writes.
//   seqlen_q_limit                  - exclusive upper bound on rows that may be written.
template <bool Split,
          int M_WARP_COUNT,
          int M_MMAC_COUNT,
          typename ElementAccum>
__forceinline__ __device__ void mla_epilogue_store_max_sum_tile16x32(
    vec2_Accum<ElementAccum> scores_max[M_WARP_COUNT],
    vec2_Accum<ElementAccum> scores_sum[M_WARP_COUNT],
    ElementAccum* scores_max_ptr,
    ElementAccum* scores_sum_ptr,
    ElementAccum scale_softmax,
    int warp_id,
    int thread_id,
    int lane_id,
    int headdim_split_id,
    int seqlen_q_limit)
{
#ifdef FA_DEBUG_SUM_MAX
    constexpr bool ALLOW_WRITE_SUM_MAX = true;
#else
    constexpr bool ALLOW_WRITE_SUM_MAX = false;
#endif
    if constexpr (Split or ALLOW_WRITE_SUM_MAX)
    {
        // Every head-dim split uses the same QK and so computes identical sum/max values;
        // writing them more than once could conflict, so only split 0 writes.
        // Threads 0-15 hold the values (16-31 / 32-47 / 48-63 hold copies), so one group suffices.
        const bool is_writer = (headdim_split_id == 0) && (thread_id < 16);
        if (!is_writer)
        {
            return;
        }
#pragma unroll
        for (int m_warp = 0; m_warp < M_WARP_COUNT; ++m_warp)
        {
#pragma unroll
            for (int row_half = 0; row_half < M_MMAC_COUNT; ++row_half)
            {
                // lane_id equals lane_id & 15 here; the two row halves are 16 rows apart.
                const int dst_row = m_warp * 32 + lane_id + row_half * 16;
                if (dst_row >= seqlen_q_limit)
                {
                    continue;
                }
                scores_sum_ptr[dst_row] = scores_sum[m_warp].f32[row_half];
                scores_max_ptr[dst_row] = scores_max[m_warp].f32[row_half] * scale_softmax;
            }
        }
    }
}
// Stores the output accumulator fragments to global memory for the 16x32 tile path.
//
// Destination: params.oaccum_ptr (offset by split_id * params.b * params.o_batch_stride) when
// Split is true, otherwise params.o_ptr; both are offset by the batch / head / head-dim-split
// base. K_LOOP_COUNT is walked in steps of WARP_NUM and each wave stores the column block
// (k_loop + warp_id) from the fragments held at tile index k_loop. Two adjacent elements
// (min_tile_n = 0, 1) are converted with DownCast<..., true> and written as one
// vec2_Element store. Rows with seqlen_q_idx >= params.seqlen_q are not written.
//
// All element offsets are computed in 64-bit arithmetic: the previous code multiplied
// seqlen_q_idx by the row stride (and bidh / split_id by their strides) in 32-bit int before
// widening, which can overflow for large outputs.
//
// Parameters:
//   acc_o            - accumulator fragments, indexed [32x32 tile][mmac id].f32[lane].
//   params           - kernel parameters; reads o_row_stride, o_batch_stride, o_head_stride,
//                      b, seqlen_q, o_ptr and oaccum_ptr.
//   bidb, bidh       - batch and head index.
//   m_block          - index of the kBlockM-row block along seqlen_q.
//   split_id         - KV split index (used only when Split is true).
//   headdim_split_id - head-dim split index; shifts the base by kHeadDimVSplit elements.
//   warp_id, lane_id - wave and lane index of the calling thread.
//
// NOTE(review): the vec2 store assumes o_ptr + pv_global_addr is aligned for
// vec2_Element<SplitkvAccumType> — confirm the strides guarantee this.
template <typename Params,
          int kHeadDimV,
          int kHeadDimVSplit,
          bool Split,
          typename SplitkvAccumType,
          typename ElementAccum,
          int kBlockM,
          int kBlockK,
          int WARP_NUM,
          int K_LOOP_COUNT,
          int M_WARP_COUNT,
          int K_WARP_COUNT,
          int M_MMAC_COUNT>
__forceinline__ __device__ void mla_epilogue_store_output_tile16x32(
    vec4_Accum<ElementAccum> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
    Params params,
    int bidb,
    int bidh,
    int m_block,
    int split_id,
    int headdim_split_id,
    int warp_id,
    int lane_id)
{
    // Kept as 64-bit so the row * stride product below cannot wrap.
    const int64_t output_seqlen_stride = params.o_row_stride;
    const int64_t row_offset_o = bidb * int64_t(params.o_batch_stride) + int64_t(bidh) * params.o_head_stride + headdim_split_id * kHeadDimVSplit;
    SplitkvAccumType* o_ptr = Split
        ? reinterpret_cast<SplitkvAccumType*>(params.oaccum_ptr) + row_offset_o + /*which split*/ int64_t(split_id) * params.b * params.o_batch_stride
        : reinterpret_cast<SplitkvAccumType*>(params.o_ptr) + row_offset_o;
    int pv_lane_seq_idx = (lane_id & 15);
    int pv_lane_head_dim_idx = (lane_id >> 4);
#pragma unroll
    for (int k_loop = 0; k_loop < K_LOOP_COUNT; k_loop += WARP_NUM)
    {
#pragma unroll
        for (int warp_m_idx = 0; warp_m_idx < M_WARP_COUNT; ++warp_m_idx)
        {
#pragma unroll
            for (int k_tile_idx = 0; k_tile_idx < K_WARP_COUNT; ++k_tile_idx)
            {
                int tile_32x32_id = k_loop * M_WARP_COUNT * K_WARP_COUNT + warp_m_idx * K_WARP_COUNT + k_tile_idx;
#pragma unroll
                for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m)
                {
                    // Coordinate along seqlen_q; the two row halves are 16 rows apart.
                    int seqlen_q_idx = m_block * kBlockM + warp_m_idx * 32 + pv_lane_seq_idx + min_tile_m * 16;
                    if (seqlen_q_idx < params.seqlen_q)
                    {
                        // 64-bit row base shared by the four stores of this row.
                        const int64_t row_base = int64_t(seqlen_q_idx) * output_seqlen_stride;
#pragma unroll
                        for (int vec_index = 0; vec_index < 4; ++vec_index)
                        {
                            // Gather the two horizontally adjacent elements into one 2-wide store.
                            vec2_Element<SplitkvAccumType> data;
#pragma unroll
                            for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n)
                            {
                                int mmac_id = min_tile_m + min_tile_n * 2;
                                data[min_tile_n] = DownCast<ElementAccum, SplitkvAccumType, true>(acc_o[tile_32x32_id][mmac_id].f32[vec_index]);
                            }
                            // This wave owns column block (k_loop + warp_id).
                            int64_t pv_global_addr = row_base + (k_loop + warp_id) * kBlockK + k_tile_idx * 32 + vec_index * 8 + pv_lane_head_dim_idx * 2;
                            *(vec2_Element<SplitkvAccumType>*)(o_ptr + pv_global_addr) = data;
                        }
                    }
                }
            }
        }
    }
}
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/mla_prefix_prefill.h
0 → 100644
View file @
a1eef562
#include "numeric_types.h"
#include "intrinsic.h"
#include "wait.h"
#include "flash.h"
using
namespace
flash
;
// Resets the per-thread softmax state and output accumulators before the KV loop.
//
// For every 16-row tile: scores_max is set to -INFINITY, scores_sum to 0, and both 64-bit
// halves of every acc_o fragment are cleared.
//
// Parameters:
//   scores_max - running row maxima, one per 16-row tile.
//   scores_sum - running row sums, one per 16-row tile.
//   acc_o      - output accumulator fragments, indexed [row tile][head-dim tile].
template <int WARP_M,
          int kHeadDimVSplit,
          typename ElementAccum>
__forceinline__ __device__ void mla_prefix_prefill_initialize(
    ElementAccum scores_max[WARP_M / 16],
    ElementAccum scores_sum[WARP_M / 16],
    vec4_Accum<ElementAccum> acc_o[WARP_M / 16][kHeadDimVSplit / 16])
{
    constexpr int kRowTiles = WARP_M / 16;
    constexpr int kColTiles = kHeadDimVSplit / 16;
#pragma unroll
    for (int row_tile = 0; row_tile < kRowTiles; ++row_tile)
    {
        scores_max[row_tile] = -INFINITY;
        scores_sum[row_tile] = 0.f;
#pragma unroll
        for (int col_tile = 0; col_tile < kColTiles; ++col_tile)
        {
            acc_o[row_tile][col_tile].b64[0] = 0x0;
            acc_o[row_tile][col_tile].b64[1] = 0x0;
        }
    }
}
// Loads this thread's Q fragments from global memory into registers.
//
// For every 16-row tile the thread loads eight fragments from qv_ptr (column chunks 64
// elements apart) and one fragment from q_ptr. The row is clamped to the last valid row of
// the current block (actual_seqlen_q - 1 - m_block * kBlockM), so out-of-range rows re-read
// that row instead of reading past the sequence.
//
// Parameters:
//   qv_regs         - destination fragments for qv_ptr, indexed [row tile][column chunk].
//   q_regs          - destination fragments for q_ptr, indexed [row tile].
//   qv_ptr, q_ptr   - source base pointers (presumably already offset to this block — confirm in caller).
//   m_block         - index of the kBlockM-row block along seqlen_q.
//   warp_id_row     - wave row index; contributes warp_id_row * 16 rows.
//   warp_id_col     - wave column index; contributes warp_id_col * 32 columns.
//   lane_id         - lane index; low 4 bits select the row, the remaining bits an 8-element column group.
//   qv_row_stride, q_row_stride - row strides in elements.
//   actual_seqlen_q - number of valid query rows.
template <int kBlockM,
          int WARP_M,
          int WARP_NUM,
          typename Element>
__forceinline__ __device__ void mla_prefix_prefill_fetch_q_to_vgpr(
    union_vec4_f16x2<Element> qv_regs[WARP_M / 16][8],
    union_vec4_f16x2<Element> q_regs[WARP_M / 16],
    Element* qv_ptr,
    Element* q_ptr,
    int m_block,
    int warp_id_row,
    int warp_id_col,
    int lane_id,
    int qv_row_stride,
    int q_row_stride,
    int actual_seqlen_q)
{
    // Row distance between consecutive 16-row tiles of one wave: 64 for 8 waves, else WARP_M.
    constexpr int kRowTileStride = (WARP_NUM == 8) ? 64 : WARP_M;
    using Fragment = union_vec4_f16x2<Element>;

    const int last_valid_row = actual_seqlen_q - 1 - m_block * kBlockM;
    const int row_in_tile = warp_id_row * 16 + (lane_id & 15);
    const int col_base = (lane_id >> 4) * 8 + warp_id_col * 32;

#pragma unroll
    for (int row_tile = 0; row_tile < WARP_M / 16; ++row_tile)
    {
        const int src_row = min(last_valid_row, row_tile * kRowTileStride + row_in_tile);
#pragma unroll
        for (int chunk = 0; chunk < 8; ++chunk)
        {
            const int src_offset = src_row * qv_row_stride + (col_base + chunk * 64);
            qv_regs[row_tile][chunk] = *reinterpret_cast<Fragment*>(qv_ptr + src_offset);
        }
    }
#pragma unroll
    for (int row_tile = 0; row_tile < WARP_M / 16; ++row_tile)
    {
        const int src_row = min(last_valid_row, row_tile * kRowTileStride + row_in_tile);
        const int src_offset = src_row * q_row_stride + col_base;
        q_regs[row_tile] = *reinterpret_cast<Fragment*>(q_ptr + src_offset);
    }
}
// Issues the buffer loads that bring the rope part of K into LDS.
//
// Each wave loads a 16-row x 32-column sub-tile per pass (lane >> 2 selects the row,
// (lane & 3) * 8 the column group) and stores it at
// pass * WARP_NUM * 16 * 32 + warp_id * 16 * 32 elements into k_rope_lds.
// Rows are clamped to seqlen_kv_limit - 1.
//   WARP_NUM == 8: waves form a 4 (rows) x 2 (columns) grid, 2 passes of 64 rows.
//   WARP_NUM == 4: waves form a 2 (rows) x 2 (columns) grid, kBlockN / 32 passes of 32 rows.
// Other wave counts do nothing. Completion is not awaited here.
//
// Parameters:
//   k_rope_lds      - LDS destination base.
//   k_buffer        - buffer descriptor of the source.
//   warp_id, lane_id - wave and lane index.
//   k_row_stride    - source row stride in elements.
//   seqlen_kv_limit - number of valid KV rows.
template <int kBlockN,
          int WARP_NUM,
          typename Element>
__forceinline__ __device__ void mla_prefix_prefill_prefetch_k_rope_to_lds(
    Element* k_rope_lds,
    vec4_uint k_buffer,
    int warp_id,
    int lane_id,
    int k_row_stride,
    int seqlen_kv_limit)
{
    if constexpr (WARP_NUM == 8 || WARP_NUM == 4)
    {
        constexpr bool kEightWaves = (WARP_NUM == 8);
        // Number of load passes and rows covered per pass for the selected wave layout.
        constexpr int kPasses = kEightWaves ? 2 : kBlockN / (16 * 2);
        constexpr int kRowsPerPass = kEightWaves ? 64 : 32;

        const int wave_row = kEightWaves ? (warp_id & 3) : (warp_id >> 1);
        const int wave_col = kEightWaves ? (warp_id >> 2) : (warp_id & 1);
        const int last_valid_row = seqlen_kv_limit - 1;
        const int src_col = wave_col * 32 + (lane_id & 3) * 8;

#pragma unroll
        for (int pass = 0; pass < kPasses; ++pass)
        {
            const int src_row = min(last_valid_row, pass * kRowsPerPass + wave_row * 16 + (lane_id >> 2));
            const int src_offset = src_row * k_row_stride + src_col;
            const int lds_offset = pass * WARP_NUM * 16 * 32 + warp_id * 16 * 32;
            safe_inline_buffer_load_dwordx4_lds<Element, 1>(k_rope_lds, k_buffer, lds_offset, 0, src_offset);
        }
    }
}
// Accumulates the rope part of Q*K^T into s_reg from the K-rope tile staged in LDS.
//
// Sequence (only for WARP_NUM == 8 or 4; other wave counts do nothing):
//   1. wait for all outstanding buffer loads,
//   2. read one K fragment per 16-column tile from k_rope_lds,
//   3. as each fragment arrives, issue two mmac operations per 16-row tile
//      (one per f16x4 half) that accumulate into s_reg,
//   4. __syncthreads().
// s_reg is accumulated into, not cleared.
//
// Parameters:
//   s_reg      - score accumulators, indexed [row tile][column tile].
//   q_regs     - Q rope fragments, one per row tile.
//   k_rope_lds - LDS base of the staged K-rope tile.
//   warp_id    - wave index; its column part selects which 32-column half is read.
//   lane_id    - lane index; low 4 bits select the row, the remaining bits an 8-element column group.
//   k_buffer, k_row_stride, seqlen_kv_limit - not referenced in this function.
template <int kBlockN,
          int WARP_M,
          int WARP_N,
          int WARP_NUM,
          typename Element,
          typename ElementAccum>
__forceinline__ __device__ void mla_prefix_prefill_compute_fwd_qk_rope(
    vec4_Accum<ElementAccum> s_reg[WARP_M / 16][(kBlockN / 16)],
    union_vec4_f16x2<Element> q_regs[WARP_M / 16],
    vec4_uint k_buffer,
    Element* k_rope_lds,
    int warp_id,
    int lane_id,
    int k_row_stride,
    int seqlen_kv_limit)
{
    if constexpr (WARP_NUM == 8 || WARP_NUM == 4)
    {
        constexpr bool kEightWaves = (WARP_NUM == 8);
        constexpr int kColTiles = kBlockN / 16;

        // The K-rope loads are expected to have been issued by the caller beforehand.
        wait_buffer_data_arrived<true>(0);

        const int wave_col = kEightWaves ? (warp_id >> 2) : (warp_id & 1);
        union_vec4_f16x2<Element> k_frag[kColTiles];
#pragma unroll
        for (int col_tile = 0; col_tile < kColTiles; ++col_tile)
        {
            // Start of the 16x32 sub-tile holding rows of `col_tile` for this wave's column half.
            const int tile_offset = kEightWaves
                ? (col_tile >> 2) * 8 * 16 * 32 + (col_tile & 3) * 16 * 32 + wave_col * 4 * 16 * 32
                : col_tile * 2 * 16 * 32 + wave_col * 16 * 32;
            const int lane_offset = (lane_id & 15) * 32 + (lane_id >> 4) * 8;
            inlineasm_ds_read_b128(reinterpret_cast<size_t>(k_rope_lds + tile_offset + lane_offset), k_frag[col_tile]);
        }
#pragma unroll
        for (int col_tile = 0; col_tile < kColTiles; ++col_tile)
        {
            // Allow the reads issued after this fragment to remain in flight.
            wait_lds_data_arrived<false /*sync*/>(kColTiles - col_tile - 1);
#pragma unroll
            for (int row_tile = 0; row_tile < WARP_M / 16; ++row_tile)
            {
                s_reg[row_tile][col_tile].f32 = mmac<Element, ElementAccum>(q_regs[row_tile].f16x4[0], k_frag[col_tile].f16x4[0], s_reg[row_tile][col_tile].f32);
                s_reg[row_tile][col_tile].f32 = mmac<Element, ElementAccum>(q_regs[row_tile].f16x4[1], k_frag[col_tile].f16x4[1], s_reg[row_tile][col_tile].f32);
            }
        }
        __syncthreads();
    }
}
// Issues the initial buffer loads that bring the first non-rope K column block(s) into LDS.
//
// Each wave loads a 16-row x 32-column sub-tile per request (lane >> 2 selects the row,
// (lane & 3) * 8 the column group); rows are clamped to seqlen_kv_limit - 1.
//   WARP_NUM == 8: column blocks 0 and 1 (64 columns each) go to LDS stages 0 and 1,
//                  kBlockN / 64 requests per block.
//   WARP_NUM == 4: after a __syncthreads(), column block 0 goes to LDS stage 0,
//                  kBlockN / 32 requests.
// Other wave counts do nothing. Completion is not awaited here.
//
// Parameters:
//   v_lds           - LDS destination base.
//   v_buffer        - buffer descriptor of the source.
//   warp_id, lane_id - wave and lane index.
//   v_row_stride    - source row stride in elements.
//   seqlen_kv_limit - number of valid KV rows.
template <int kBlockN,
          int WARP_NUM,
          typename Element>
__forceinline__ __device__ void mla_prefix_prefill_prefetch_k_nope_to_lds(
    Element* v_lds,
    vec4_uint v_buffer,
    int warp_id,
    int lane_id,
    int v_row_stride,
    int seqlen_kv_limit)
{
    const int last_valid_row = seqlen_kv_limit - 1;
    if constexpr (WARP_NUM == 8)
    {
        constexpr int kPrefetchBlocks = 2;
        constexpr int kRequests = kBlockN / (16 * 4);  // 16 * 4 = 64 rows per request
        constexpr int kStageElems = kRequests * WARP_NUM * 16 * 32;
        const int wave_row = warp_id & 3;
        const int wave_col = warp_id >> 2;
#pragma unroll
        for (int first_block = 0; first_block < kPrefetchBlocks; first_block += 2)
        {
#pragma unroll
            for (int stage = 0; stage < 2; ++stage)
            {
#pragma unroll
                for (int req = 0; req < kRequests; ++req)
                {
                    const int src_row = min(last_valid_row, req * 64 + wave_row * 16 + (lane_id >> 2));
                    const int src_col = (first_block + stage) * 64 + wave_col * 32 + (lane_id & 3) * 8;
                    const int src_offset = src_row * v_row_stride + src_col;
                    const int lds_offset = stage * kStageElems + req * WARP_NUM * 16 * 32 + warp_id * 16 * 32;
                    safe_inline_buffer_load_dwordx4_lds<Element, 1>(v_lds, v_buffer, lds_offset, 0, src_offset);
                }
            }
        }
    }
    else if constexpr (WARP_NUM == 4)
    {
        __syncthreads();
        constexpr int kRequests = kBlockN / (16 * 2);
        constexpr int kFirstBlock = 0;
        const int wave_row = warp_id >> 1;
        const int wave_col = warp_id & 1;
        const int stage = 0;
#pragma unroll
        for (int req = 0; req < kRequests; ++req)
        {
            const int src_row = min(last_valid_row, req * 32 + wave_row * 16 + (lane_id >> 2));
            const int src_col = kFirstBlock * 64 + wave_col * 32 + (lane_id & 3) * 8;
            const int src_offset = src_row * v_row_stride + src_col;
            const int lds_offset = stage * kRequests * WARP_NUM * 16 * 32 + req * WARP_NUM * 16 * 32 + warp_id * 16 * 32;
            safe_inline_buffer_load_dwordx4_lds<Element, 1>(v_lds, v_buffer, lds_offset, 0, src_offset);
        }
    }
}
// Computes the non-rope part of Q*K^T into s_reg with a two-stage LDS pipeline.
//
// s_reg is zeroed first. K is then consumed as eight column blocks of 64 elements
// (k_col = block * 64); block b is multiplied with qv_regs[m_idx][b]. While one LDS stage is
// being read and fed to mmac, the buffer loads for the next block are written to the other
// stage (stage_id toggles with ^= 1). The K-rope prefetch into k_rope_lds is issued part-way
// through so that it overlaps with the remaining work.
//
// Preconditions visible from the code:
//   - WARP_NUM == 8: the loads for blocks 0 and 1 must already be in flight into stages
//     0 and 1 (this function only waits for them).
//   - WARP_NUM == 4: the loads for block 0 must already be in flight into stage 0.
//   presumably both are issued by mla_prefix_prefill_prefetch_k_nope_to_lds — confirm in caller.
// Other wave counts do nothing.
//
// Parameters:
//   s_reg           - score accumulators, indexed [row tile][column tile]; overwritten.
//   qv_regs         - Q fragments, indexed [row tile][column block 0..7].
//   v_buffer, v_lds - buffer descriptor and LDS base for the non-rope K data.
//   k_buffer, k_rope_lds - buffer descriptor and LDS base for the K-rope prefetch.
//   warp_id, lane_id - wave and lane index.
//   v_row_stride, k_row_stride - source row strides in elements.
//   seqlen_kv_limit - number of valid KV rows; source rows are clamped to seqlen_kv_limit - 1.
template <int kBlockN,
          int WARP_M,
          int WARP_N,
          int WARP_NUM,
          typename Element,
          typename ElementAccum>
__forceinline__ __device__ void mla_prefix_prefill_compute_fwd_qk_nope(
    vec4_Accum<ElementAccum> s_reg[WARP_M / 16][(kBlockN / 16)],
    union_vec4_f16x2<Element> qv_regs[WARP_M / 16][8],
    vec4_uint v_buffer,
    Element* v_lds,
    vec4_uint k_buffer,
    Element* k_rope_lds,
    int warp_id,
    int lane_id,
    int v_row_stride,
    int k_row_stride,
    int seqlen_kv_limit)
{
    if constexpr (WARP_NUM == 8)
    {
        // Number of column blocks whose loads were issued before this function.
        constexpr int PREFETCH_K_BLOCKS = 2;
        // Buffer-load requests per wave for one column block (64 rows per request).
        constexpr int K_LOAD_REQUESTS = kBlockN / (16 * 4);
        // Waves form a 4 (rows) x 2 (columns) grid.
        int warp_id_row = warp_id & 3;
        int warp_id_col = warp_id >> 2;
        // prefetch_k_nope_to_lds<Element>(v_lds, v_buffer, warp_id, lane_id, v_row_stride, seqlen_kv_limit);
#pragma unroll
        for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx)
        {
#pragma unroll
            for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop)
            {
                inline_vgpr4_init_zero(s_reg[m_idx][n_loop]);
            }
        }
        // Consume the two pre-loaded blocks (depth 0 and 1 are LDS stages 0 and 1).
#pragma unroll
        for (int load_id = 0; load_id < PREFETCH_K_BLOCKS; load_id += 2)
        {
#pragma unroll
            for (int depth = 0; depth < 2; ++depth)
            {
                // For depth 0 the loads of the second block may still be outstanding.
                wait_buffer_data_arrived<true>((2 - depth - 1) * K_LOAD_REQUESTS);
                union_vec4_f16x2<Element> k_regs[kBlockN / 16];
#pragma unroll
                for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop)
                {
                    int lds_wave_offset = depth * K_LOAD_REQUESTS * WARP_NUM * 16 * 32 + (n_loop >> 2) * WARP_NUM * 16 * 32 + (n_loop & 3) * 16 * 32 + warp_id_col * 4 * 16 * 32;
                    int lds_tx_offset = (lane_id & 15) * 32 + (lane_id >> 4) * 8;
                    inlineasm_ds_read_b128(reinterpret_cast<size_t>(v_lds + lds_wave_offset + lds_tx_offset), k_regs[n_loop]);
                }
#pragma unroll
                for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop)
                {
                    // Allow the reads issued after this fragment to remain in flight.
                    wait_lds_data_arrived<false>(kBlockN / 16 - n_loop - 1);
#pragma unroll
                    for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx)
                    {
                        s_reg[m_idx][n_loop].f32 = mmac<Element, ElementAccum>(qv_regs[m_idx][load_id + depth].f16x4[0], k_regs[n_loop].f16x4[0], s_reg[m_idx][n_loop].f32);
                        s_reg[m_idx][n_loop].f32 = mmac<Element, ElementAccum>(qv_regs[m_idx][load_id + depth].f16x4[1], k_regs[n_loop].f16x4[1], s_reg[m_idx][n_loop].f32);
                    }
                }
            }
        }
        asm volatile("s_barrier\n");
        // LDS is read above and written below, so there is a potential data hazard.
        // Prefetch the k_rope data early; note that the LDS regions partially overlap.
        mla_prefix_prefill_prefetch_k_rope_to_lds<kBlockN, WARP_NUM, Element>(k_rope_lds, k_buffer, warp_id, lane_id, k_row_stride, seqlen_kv_limit);
        // Continue with the remaining column blocks.
        if constexpr (true)
        {
            int stage_id = 0;
            {
                // Load block PREFETCH_K_BLOCKS (= 2) into stage 0.
#pragma unroll
                for (int load_loop = 0; load_loop < K_LOAD_REQUESTS; ++load_loop)
                {
                    int k_row = min(seqlen_kv_limit - 1, load_loop * 64 + warp_id_row * 16 + (lane_id >> 2));
                    int k_col = PREFETCH_K_BLOCKS * 64 + warp_id_col * 32 + (lane_id & 3) * 8;
                    int k_buffer_offset = k_row * v_row_stride + k_col;
                    int lds_write_offset = stage_id * K_LOAD_REQUESTS * WARP_NUM * 16 * 32 + load_loop * WARP_NUM * 16 * 32 + warp_id * 16 * 32;
                    safe_inline_buffer_load_dwordx4_lds<Element, 1>(v_lds, v_buffer, lds_write_offset, 0, k_buffer_offset);
                }
            }
            stage_id ^= 1;
            // Steady state: load block load_id into one stage, compute block load_id - 1 from the other.
#pragma unroll
            for (int load_id = PREFETCH_K_BLOCKS + 1; load_id < 8; load_id += 1)
            {
#pragma unroll
                for (int load_loop = 0; load_loop < K_LOAD_REQUESTS; ++load_loop)
                {
                    int k_row = min(seqlen_kv_limit - 1, load_loop * 64 + warp_id_row * 16 + (lane_id >> 2));
                    int k_col = load_id * 64 + warp_id_col * 32 + (lane_id & 3) * 8;
                    int k_buffer_offset = k_row * v_row_stride + k_col;
                    int lds_write_offset = stage_id * K_LOAD_REQUESTS * WARP_NUM * 16 * 32 + load_loop * WARP_NUM * 16 * 32 + warp_id * 16 * 32;
                    safe_inline_buffer_load_dwordx4_lds<Element, 1>(v_lds, v_buffer, lds_write_offset, 0, k_buffer_offset);
                }
                // Only the loads just issued may remain outstanding; the previous block has landed.
                wait_buffer_data_arrived<true>(K_LOAD_REQUESTS);
                // Switch to the stage that holds block load_id - 1.
                stage_id ^= 1;
                union_vec4_f16x2<Element> k_regs[kBlockN / 16];
#pragma unroll
                for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop)
                {
                    int lds_wave_offset = stage_id * K_LOAD_REQUESTS * WARP_NUM * 16 * 32 + (n_loop >> 2) * WARP_NUM * 16 * 32 + (n_loop & 3) * 16 * 32 + warp_id_col * 4 * 16 * 32;
                    int lds_tx_offset = (lane_id & 15) * 32 + (lane_id >> 4) * 8;
                    inlineasm_ds_read_b128(reinterpret_cast<size_t>(v_lds + lds_wave_offset + lds_tx_offset), k_regs[n_loop]);
                }
#pragma unroll
                for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop)
                {
                    wait_lds_data_arrived<false>(kBlockN / 16 - n_loop - 1);
#pragma unroll
                    for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx)
                    {
                        s_reg[m_idx][n_loop].f32 = mmac<Element, ElementAccum>(qv_regs[m_idx][load_id - 1].f16x4[0], k_regs[n_loop].f16x4[0], s_reg[m_idx][n_loop].f32);
                        s_reg[m_idx][n_loop].f32 = mmac<Element, ElementAccum>(qv_regs[m_idx][load_id - 1].f16x4[1], k_regs[n_loop].f16x4[1], s_reg[m_idx][n_loop].f32);
                    }
                }
                // Barrier before the stage just read is overwritten by the next iteration's
                // loads; the sched barriers keep the compiler from moving code across it.
                __builtin_amdgcn_sched_barrier(0);
                __syncthreads();
                __builtin_amdgcn_sched_barrier(0);
            }
            // rest: compute the last block (index 7); nothing is left to load.
            {
                constexpr int load_id = 8;
                wait_buffer_data_arrived<true>(0);
                stage_id ^= 1;
                union_vec4_f16x2<Element> k_regs[kBlockN / 16];
#pragma unroll
                for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop)
                {
                    // The literal 8 here equals WARP_NUM in this branch.
                    int lds_wave_offset = stage_id * K_LOAD_REQUESTS * 8 * 16 * 32 + (n_loop >> 2) * 8 * 16 * 32 + (n_loop & 3) * 16 * 32 + warp_id_col * 4 * 16 * 32;
                    int lds_tx_offset = (lane_id & 15) * 32 + (lane_id >> 4) * 8;
                    inlineasm_ds_read_b128(reinterpret_cast<size_t>(v_lds + lds_wave_offset + lds_tx_offset), k_regs[n_loop]);
                }
#pragma unroll
                for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop)
                {
                    wait_lds_data_arrived<false>(kBlockN / 16 - n_loop - 1);
#pragma unroll
                    for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx)
                    {
                        s_reg[m_idx][n_loop].f32 = mmac<Element, ElementAccum>(qv_regs[m_idx][load_id - 1].f16x4[0], k_regs[n_loop].f16x4[0], s_reg[m_idx][n_loop].f32);
                        s_reg[m_idx][n_loop].f32 = mmac<Element, ElementAccum>(qv_regs[m_idx][load_id - 1].f16x4[1], k_regs[n_loop].f16x4[1], s_reg[m_idx][n_loop].f32);
                    }
                }
                __builtin_amdgcn_sched_barrier(0);
                __syncthreads();
                __builtin_amdgcn_sched_barrier(0);
            }
        }
    }
    else if constexpr (WARP_NUM == 4)
    {
        // Buffer-load requests per wave for one column block (32 rows per request).
        constexpr int K_LOAD_REQUESTS = kBlockN / (16 * 2);
        // Waves form a 2 (rows) x 2 (columns) grid.
        int warp_id_row = warp_id >> 1;
        int warp_id_col = warp_id & 1;
        int stage_id = 0;
        // mla_prefetch_k_nope_to_lds<kBlockN, Element>(v_lds, v_buffer, warp_id, lane_id, v_row_stride, seqlen_kv_limit);
#pragma unroll
        for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx)
        {
#pragma unroll
            for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop)
            {
                inline_vgpr4_init_zero(s_reg[m_idx][n_loop]);
            }
        }
        // Block 0 is expected in stage 0, so the first load below targets stage 1.
        stage_id ^= 1;
        // Steady state: load block load_id into one stage, compute block load_id - 1 from the other.
#pragma unroll
        for (int load_id = 1; load_id < 8; ++load_id)
        {
#pragma unroll
            for (int load_loop = 0; load_loop < K_LOAD_REQUESTS; ++load_loop)
            {
                int k_row = min(seqlen_kv_limit - 1, load_loop * 32 + warp_id_row * 16 + (lane_id >> 2));
                int k_col = load_id * 64 + warp_id_col * 32 + (lane_id & 3) * 8;
                int k_buffer_offset = k_row * v_row_stride + k_col;
                int lds_write_offset = stage_id * K_LOAD_REQUESTS * WARP_NUM * 16 * 32 + load_loop * WARP_NUM * 16 * 32 + warp_id * 16 * 32;
                // 4 * 4 * 16 * 32 * sizeof(fp16) = 16KB
                safe_inline_buffer_load_dwordx4_lds<Element, 1>(v_lds, v_buffer, lds_write_offset, 0, k_buffer_offset);
            }
            // Only the loads just issued may remain outstanding; the previous block has landed.
            wait_buffer_data_arrived<true>(K_LOAD_REQUESTS);
            // Switch to the stage that holds block load_id - 1.
            stage_id ^= 1;
            union_vec4_f16x2<Element> k_regs[kBlockN / 16];
#pragma unroll
            for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop)
            {
                int lds_wave_offset = stage_id * K_LOAD_REQUESTS * WARP_NUM * 16 * 32 + n_loop * 2 * 16 * 32 + warp_id_col * 16 * 32;
                int lds_tx_offset = (lane_id & 15) * 32 + (lane_id >> 4) * 8;
                inlineasm_ds_read_b128(reinterpret_cast<size_t>(v_lds + lds_wave_offset + lds_tx_offset), k_regs[n_loop]);
            }
#pragma unroll
            for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop)
            {
                // Allow the reads issued after this fragment to remain in flight.
                wait_lds_data_arrived<false>(kBlockN / 16 - n_loop - 1);
#pragma unroll
                for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx)
                {
                    s_reg[m_idx][n_loop].f32 = mmac<Element, ElementAccum>(qv_regs[m_idx][load_id - 1].f16x4[0], k_regs[n_loop].f16x4[0], s_reg[m_idx][n_loop].f32);
                    s_reg[m_idx][n_loop].f32 = mmac<Element, ElementAccum>(qv_regs[m_idx][load_id - 1].f16x4[1], k_regs[n_loop].f16x4[1], s_reg[m_idx][n_loop].f32);
                }
            }
            __syncthreads();
        }
        // Prefetch the rope part of K; mind the overlap between k_rope_lds and k_lds.
        mla_prefix_prefill_prefetch_k_rope_to_lds<kBlockN, WARP_NUM, Element>(k_rope_lds, k_buffer, warp_id, lane_id, k_row_stride, seqlen_kv_limit);
        // Compute the last block (index 7).
        {
            int load_id = 8;
            // The K-rope prefetch just issued kBlockN / 32 (= K_LOAD_REQUESTS) loads in the
            // 4-wave case; those may stay outstanding while block 7 must have landed.
            wait_buffer_data_arrived<true>(K_LOAD_REQUESTS);
            stage_id ^= 1;
            union_vec4_f16x2<Element> k_regs[kBlockN / 16];
#pragma unroll
            for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop)
            {
                int lds_wave_offset = stage_id * K_LOAD_REQUESTS * WARP_NUM * 16 * 32 + n_loop * 2 * 16 * 32 + warp_id_col * 16 * 32;
                int lds_tx_offset = (lane_id & 15) * 32 + (lane_id >> 4) * 8;
                inlineasm_ds_read_b128(reinterpret_cast<size_t>(v_lds + lds_wave_offset + lds_tx_offset), k_regs[n_loop]);
            }
#pragma unroll
            for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop)
            {
                wait_lds_data_arrived<false>(kBlockN / 16 - n_loop - 1);
                // Ready to issue the mmac operations.
#pragma unroll
                for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx)
                {
                    s_reg[m_idx][n_loop].f32 = mmac<Element, ElementAccum>(qv_regs[m_idx][load_id - 1].f16x4[0], k_regs[n_loop].f16x4[0], s_reg[m_idx][n_loop].f32);
                    s_reg[m_idx][n_loop].f32 = mmac<Element, ElementAccum>(qv_regs[m_idx][load_id - 1].f16x4[1], k_regs[n_loop].f16x4[1], s_reg[m_idx][n_loop].f32);
                }
            }
            __syncthreads();
        }
    }
}
// Adds, element-wise, the score accumulators of this wave and of its partner
// wave, exchanging the data through the scratch buffer `s_reg_lds`.
//
//   s_reg      in/out: per-wave score accumulators; on return each tile holds
//              own + partner values.
//   s_reg_lds  scratch buffer (presumably LDS, going by its name); every wave
//              owns a 64-lane x 4-element slot per column tile.
//   warp_id    wave index inside the block; the partner is warp_id +/- 4 when
//              WARP_NUM == 8, otherwise the odd/even neighbour (warp_id +/- 1).
//   lane_id    lane index inside the wave.
//
// All threads of the block must call this together: it contains block barriers.
template <int kBlockN, int WARP_M, int WARP_NUM, typename ElementAccum>
__forceinline__ __device__ void mla_prefix_prefill_combine_s_reg_of_2waves(
    vec4_Accum<ElementAccum> s_reg[WARP_M / 16][(kBlockN / 16)],
    ElementAccum* s_reg_lds,
    int warp_id,
    int lane_id)
{
    constexpr int WAVE_SLOT   = 64 * 4;               // elements written by one wave per column tile
    constexpr int TILE_STRIDE = WARP_NUM * WAVE_SLOT; // elements between consecutive column tiles

    // The partner wave does not depend on the tile being processed.
    int partner_warp;
    if constexpr (WARP_NUM == 8)
    {
        partner_warp = (warp_id >= 4) ? warp_id - 4 : warp_id + 4;
    }
    else
    {
        partner_warp = (warp_id & 1) ? warp_id - 1 : warp_id + 1;
    }

    ElementAccum* const own_slot     = s_reg_lds + warp_id * WAVE_SLOT + lane_id * 4;
    ElementAccum* const partner_slot = s_reg_lds + partner_warp * WAVE_SLOT + lane_id * 4;

    #pragma unroll
    for (int row_tile = 0; row_tile < WARP_M / 16; ++row_tile)
    {
        // Make sure nobody is still using the scratch buffer before it is overwritten.
        __builtin_amdgcn_sched_barrier(0);
        __syncthreads();
        __builtin_amdgcn_sched_barrier(0);

        // Publish this wave's accumulators.
        #pragma unroll
        for (int col_tile = 0; col_tile < kBlockN / 16; ++col_tile)
        {
            *(vec4_fp32*)(own_slot + col_tile * TILE_STRIDE) = s_reg[row_tile][col_tile].f32;
        }

        __syncthreads();

        // Fetch the partner's accumulators and add them in, two floats at a time.
        #pragma unroll
        for (int col_tile = 0; col_tile < kBlockN / 16; ++col_tile)
        {
            vec4_Accum<ElementAccum> partner_data =
                *(vec4_Accum<ElementAccum>*)(partner_slot + col_tile * TILE_STRIDE);
            s_reg[row_tile][col_tile].u64[0] =
                __builtin_hcu_pk_add_f32(s_reg[row_tile][col_tile].u64[0], partner_data.u64[0]);
            s_reg[row_tile][col_tile].u64[1] =
                __builtin_hcu_pk_add_f32(s_reg[row_tile][col_tile].u64[1], partner_data.u64[1]);
        }

        __builtin_amdgcn_sched_barrier(0);
        __syncthreads();
        __builtin_amdgcn_sched_barrier(0);
    }
}
// One online-softmax step over a tile of attention scores.
//
//   s_reg               in : raw scores of the current tile.
//                       out: exp2(score * scale_softmax_log2 - row_max * scale_softmax_log2).
//   scores_max          in/out: running per-row maximum of the raw scores.
//   scores_sum          in/out: running per-row sum of the exponentials.
//   scale_softmax_log2  multiplier applied to the scores before exp2
//                       (presumably softmax_scale * log2(e) -- confirm with the caller).
//   acc_o               in/out: output accumulator; rescaled so that it stays
//                       consistent with the updated running maximum.
//
// Row statistics are combined across the lanes whose ids differ in bits 5 and 4
// (xor-shuffles with 32 and 16).
template <int kBlockN, int WARP_M, int kHeadDimVSplit, typename ElementAccum>
__forceinline__ __device__ void mla_prefix_prefill_compute_fwd_softmax(
    vec4_Accum<ElementAccum> s_reg[WARP_M / 16][(kBlockN / 16)],
    ElementAccum scores_max[WARP_M / 16],
    ElementAccum scores_sum[WARP_M / 16],
    ElementAccum scale_softmax_log2,
    vec4_Accum<ElementAccum> acc_o[WARP_M / 16][kHeadDimVSplit / 16])
{
    // New running maximum: previous maximum merged with the maximum of this tile.
    ElementAccum scores_max_cur[WARP_M / 16];
    #pragma unroll
    for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx)
    {
        scores_max_cur[m_idx] = scores_max[m_idx];
        #pragma unroll
        for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop)
        {
            #pragma unroll
            for (int vec_idx = 0; vec_idx < 4; ++vec_idx)
            {
                scores_max_cur[m_idx] = max(scores_max_cur[m_idx], s_reg[m_idx][n_loop].f32[vec_idx]);
            }
        }
        scores_max_cur[m_idx] = max(scores_max_cur[m_idx], __shfl_xor_tmp(scores_max_cur[m_idx], 32));
        scores_max_cur[m_idx] = max(scores_max_cur[m_idx], __shfl_xor_tmp(scores_max_cur[m_idx], 16));
    }

    // Exponentiate: p = exp2(s * scale - max * scale).
    #pragma unroll
    for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx)
    {
        __float2 max_scaled;
        // A row that is still fully masked has max == -inf; use 0 so that the
        // fma below yields -inf (and exp2 yields 0) instead of (-inf) - (-inf) = NaN.
        max_scaled[0] = scores_max_cur[m_idx] == -INFINITY ? 0.f : -scores_max_cur[m_idx] * scale_softmax_log2;
        max_scaled[1] = max_scaled[0];
        __float2 scale_softmax_log2_pair;
        scale_softmax_log2_pair[0] = scale_softmax_log2;
        scale_softmax_log2_pair[1] = scale_softmax_log2;
        #pragma unroll
        for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop)
        {
            s_reg[m_idx][n_loop].u64[0] = __builtin_hcu_pk_fma_f32(s_reg[m_idx][n_loop].u64[0], scale_softmax_log2_pair, max_scaled);
            s_reg[m_idx][n_loop].u64[1] = __builtin_hcu_pk_fma_f32(s_reg[m_idx][n_loop].u64[1], scale_softmax_log2_pair, max_scaled);
            s_reg[m_idx][n_loop].f32[0] = __llvm_exp2_f32(s_reg[m_idx][n_loop].f32[0]);
            s_reg[m_idx][n_loop].f32[1] = __llvm_exp2_f32(s_reg[m_idx][n_loop].f32[1]);
            s_reg[m_idx][n_loop].f32[2] = __llvm_exp2_f32(s_reg[m_idx][n_loop].f32[2]);
            s_reg[m_idx][n_loop].f32[3] = __llvm_exp2_f32(s_reg[m_idx][n_loop].f32[3]);
        }
    }

    // Row sums of the exponentials of this tile.
    ElementAccum scores_sum_cur[WARP_M / 16];
    #pragma unroll
    for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx)
    {
        __float2 scores_sum_pair;
        scores_sum_pair[0] = 0;
        scores_sum_pair[1] = 0;
        #pragma unroll
        for (int n_loop = 0; n_loop < kBlockN / 16; ++n_loop)
        {
            scores_sum_pair = __builtin_hcu_pk_add_f32(scores_sum_pair, s_reg[m_idx][n_loop].u64[0]);
            scores_sum_pair = __builtin_hcu_pk_add_f32(scores_sum_pair, s_reg[m_idx][n_loop].u64[1]);
        }
        scores_sum_cur[m_idx] = scores_sum_pair[0] + scores_sum_pair[1];
        scores_sum_cur[m_idx] = scores_sum_cur[m_idx] + __shfl_xor(scores_sum_cur[m_idx], 32);
        scores_sum_cur[m_idx] = scores_sum_cur[m_idx] + __shfl_xor(scores_sum_cur[m_idx], 16);
    }

    // Rescale the previous running sum and the output accumulator by
    // exp2((old_max - new_max) * scale).
    #pragma unroll
    for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx)
    {
        // Same -inf guard as in the exponentiation step: without it a row whose
        // maximum is still -inf would compute (-inf) + (+inf) = NaN and poison
        // scores_sum / acc_o for every later tile.
        const ElementAccum neg_max_cur_scaled =
            scores_max_cur[m_idx] == -INFINITY ? 0.f : -scores_max_cur[m_idx] * scale_softmax_log2;
        __float2 scores_scale;
        scores_scale[0] = __llvm_exp2_f32(__llvm_fma_f32(scores_max[m_idx], scale_softmax_log2, neg_max_cur_scaled));
        scores_scale[1] = scores_scale[0];
        scores_sum[m_idx] *= scores_scale[0];
        // acc_o holds kHeadDimVSplit / 16 tiles per row tile; iterating up to
        // kHeadDimVSplit would index past the end of the array.
        #pragma unroll
        for (int pv_tile = 0; pv_tile < kHeadDimVSplit / 16; ++pv_tile)
        {
            acc_o[m_idx][pv_tile].u64[0] = __builtin_hcu_pk_mul_f32(acc_o[m_idx][pv_tile].u64[0], scores_scale);
            acc_o[m_idx][pv_tile].u64[1] = __builtin_hcu_pk_mul_f32(acc_o[m_idx][pv_tile].u64[1], scores_scale);
        }
    }

    // Commit the updated running statistics.
    #pragma unroll
    for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx)
    {
        scores_sum[m_idx] += scores_sum_cur[m_idx];
        scores_max[m_idx] = scores_max_cur[m_idx];
    }
}
// Masks out score columns that lie at or beyond `max_seqlen_k` by overwriting
// them with -INFINITY. No row-dependent limit is applied here;
// `row_idx_offset_` and `max_seqlen_q` are accepted but not read.
//
// The column of element `e` of column tile `t` for this lane is
//   col_idx_offset_ + t * 16 + e * 4 + (lane_id >> 4).
template <int kBlockN, int WARP_M, int WARP_NUM, typename ElementAccum>
__forceinline__ __device__ void mla_prefix_prefill_apply_mask(
    vec4_Accum<ElementAccum> s_reg[WARP_M / 16][(kBlockN / 16)],
    int lane_id,
    const int col_idx_offset_,
    const int max_seqlen_k,
    const int row_idx_offset_,
    const int max_seqlen_q)
{
    // Part of the column index that is the same for every element of this lane.
    const int lane_col_base = col_idx_offset_ + (lane_id >> 4);

    #pragma unroll
    for (int row_tile = 0; row_tile < WARP_M / 16; ++row_tile)
    {
        #pragma unroll
        for (int col_tile = 0; col_tile < kBlockN / 16; ++col_tile)
        {
            #pragma unroll
            for (int elem = 0; elem < 4; ++elem)
            {
                const int col = lane_col_base + col_tile * 16 + elem * 4;
                if (col >= max_seqlen_k)
                {
                    s_reg[row_tile][col_tile].f32[elem] = -INFINITY;
                }
            }
        }
    }
}
// Applies a per-row right-hand column limit to the score tile: every element
// whose column is greater than
//   min(max_seqlen_k, row / ngroups + max_seqlen_k - mtp)
// is overwritten with -INFINITY. `max_seqlen_q` is accepted but not read.
//
// Row of this lane for row tile `r`:
//   row_idx_offset_ + r * (WARP_NUM == 8 ? 64 : WARP_M) + (lane_id & 15)
// Column of element `e` of column tile `t`:
//   col_idx_offset_ + t * 16 + e * 4 + (lane_id >> 4)
template <int kBlockN, int WARP_M, int WARP_NUM, typename ElementAccum>
__forceinline__ __device__ void mla_prefix_prefill_apply_mtp_mask(
    vec4_Accum<ElementAccum> s_reg[WARP_M / 16][(kBlockN / 16)],
    int lane_id,
    const int col_idx_offset_,
    const int max_seqlen_k,
    const int row_idx_offset_,
    const int max_seqlen_q,
    const int ngroups,
    const int mtp)
{
    constexpr int ROW_TILE_STRIDE = (WARP_NUM == 8) ? 64 : WARP_M;

    const int lane_row_base = row_idx_offset_ + (lane_id & 15);
    const int lane_col_base = col_idx_offset_ + (lane_id >> 4);

    #pragma unroll
    for (int row_tile = 0; row_tile < WARP_M / 16; ++row_tile)
    {
        const int row       = lane_row_base + row_tile * ROW_TILE_STRIDE;
        const int group_row = row / ngroups;
        const int last_col  = min(max_seqlen_k, group_row + max_seqlen_k - mtp);

        #pragma unroll
        for (int col_tile = 0; col_tile < kBlockN / 16; ++col_tile)
        {
            #pragma unroll
            for (int elem = 0; elem < 4; ++elem)
            {
                const int col = lane_col_base + col_tile * 16 + elem * 4;
                if (col > last_col)
                {
                    s_reg[row_tile][col_tile].f32[elem] = -INFINITY;
                }
            }
        }
    }
}
// Applies a causal-style mask to the score tile: every element whose column is
// greater than
//   min(max_seqlen_k, row + max_seqlen_k - max_seqlen_q)
// is overwritten with -INFINITY.
//
// Row of this lane for row tile `r`:
//   row_idx_offset_ + r * (WARP_NUM == 8 ? 64 : WARP_M) + (lane_id & 15)
// Column of element `e` of column tile `t`:
//   col_idx_offset_ + t * 16 + e * 4 + (lane_id >> 4)
template <int kBlockN, int WARP_M, int WARP_NUM, typename ElementAccum>
__forceinline__ __device__ void mla_prefix_prefill_apply_causal_mask(
    vec4_Accum<ElementAccum> s_reg[WARP_M / 16][(kBlockN / 16)],
    int lane_id,
    const int col_idx_offset_,
    const int max_seqlen_k,
    const int row_idx_offset_,
    const int max_seqlen_q)
{
    constexpr int ROW_TILE_STRIDE = (WARP_NUM == 8) ? 64 : WARP_M;

    const int lane_row_base = row_idx_offset_ + (lane_id & 15);
    const int lane_col_base = col_idx_offset_ + (lane_id >> 4);

    #pragma unroll
    for (int row_tile = 0; row_tile < WARP_M / 16; ++row_tile)
    {
        const int row      = lane_row_base + row_tile * ROW_TILE_STRIDE;
        const int last_col = min(max_seqlen_k, row + max_seqlen_k - max_seqlen_q);

        #pragma unroll
        for (int col_tile = 0; col_tile < kBlockN / 16; ++col_tile)
        {
            #pragma unroll
            for (int elem = 0; elem < 4; ++elem)
            {
                const int col = lane_col_base + col_tile * 16 + elem * 4;
                if (col > last_col)
                {
                    s_reg[row_tile][col_tile].f32[elem] = -INFINITY;
                }
            }
        }
    }
}
// Converts every accumulator tile in `s_reg` (ElementAccum) to the narrower
// `Element` type and stores the result in the matching tile of `p_reg`.
// On gfx938 / gfx946 the conversion is done two values at a time through
// DownCastPair; on other targets each of the four values is converted
// individually through DownCast<..., false>.
template <int kBlockN, int WARP_M, typename ElementAccum, typename Element>
__forceinline__ __device__ void mla_prefix_prefill_cvt_dtype(
    vec4_Accum<ElementAccum> s_reg[WARP_M / 16][kBlockN / 16],
    union_vec2_f16x2<Element> p_reg[WARP_M / 16][kBlockN / 16])
{
    #pragma unroll
    for (int row_tile = 0; row_tile < WARP_M / 16; ++row_tile)
    {
        #pragma unroll
        for (int col_tile = 0; col_tile < kBlockN / 16; ++col_tile)
        {
            vec4_Accum<ElementAccum>&  src = s_reg[row_tile][col_tile];
            union_vec2_f16x2<Element>& dst = p_reg[row_tile][col_tile];
#if defined(__gfx938__) || defined(__gfx946__)
            // Packed path: one conversion per pair of values.
            #pragma unroll
            for (int pair = 0; pair < 2; ++pair)
            {
                dst.f16x2[pair] = DownCastPair<ElementAccum, Element>(src.f32x2[pair]);
            }
#else
            // Scalar path: one conversion per value.
            #pragma unroll
            for (int elem = 0; elem < 4; ++elem)
            {
                dst.f16[elem] = DownCast<ElementAccum, Element, false>(src.f32[elem]);
            }
#endif
        }
    }
}
// Issues the buffer loads that stage the first V row-blocks into `v_lds`.
//
//   v_buffer         buffer descriptor of V.
//   v_lds            destination staging area; each request of a wave targets
//                    a 512-element slot.
//   v_row_stride     element stride between consecutive V rows.
//   seqlen_kv_limit  rows at or beyond this limit are clamped to the last
//                    valid row (seqlen_kv_limit - 1).
//
// Per lane, the source row inside a 16-row block is (lane_id >> 2) and the
// source column is request * WARP_NUM * 32 + warp_id * 32 + (lane_id & 3) * 8.
//
// WARP_NUM == 8: for every pair of row-blocks, two requests per block are
//                issued; the LDS slot depends on the block's position in the
//                pair (depth), the request and the wave.
// WARP_NUM == 4: a block barrier is executed first, then PREFETCH_V_BLOCKS
//                requests are issued for row-block 0 into stage 0.
// This function only issues the loads; it does not wait for them.
template <int PREFETCH_V_BLOCKS, int WARP_NUM, typename Element>
__forceinline__ __device__ void mla_prefix_prefill_prefetch_v_to_lds(
    vec4_uint v_buffer,
    Element* v_lds,
    int v_row_stride,
    int warp_id,
    int lane_id,
    int seqlen_kv_limit)
{
    constexpr int SLOT = 512;                        // elements per (wave, request) LDS slot

    const int lane_row = lane_id >> 2;               // row inside a 16-row block
    const int lane_col = warp_id * 32 + (lane_id & 3) * 8;
    const int last_row = seqlen_kv_limit - 1;

    if constexpr (WARP_NUM == 8)
    {
        #pragma unroll
        for (int block_pair = 0; block_pair < PREFETCH_V_BLOCKS; block_pair += 2)
        {
            #pragma unroll
            for (int depth = 0; depth < 2; ++depth)
            {
                const int src_row = min(last_row, (block_pair + depth) * 16 + lane_row);
                #pragma unroll
                for (int request = 0; request < 2; ++request)
                {
                    const int src_col    = request * 8 * 32 + lane_col;
                    const int src_offset = src_row * v_row_stride + src_col;
                    // 2 * 2 * 8 * 512 * sizeof(half) = 32KB
                    const int dst_offset = depth * 2 * WARP_NUM * SLOT + request * WARP_NUM * SLOT + warp_id * SLOT;
                    safe_inline_buffer_load_dwordx4_lds<Element, 1>(v_lds, v_buffer, dst_offset, 0, src_offset);
                }
            }
        }
    }
    else if constexpr (WARP_NUM == 4)
    {
        __syncthreads();

        // Only row-block 0 is staged, always into stage 0.
        const int src_row = min(last_row, lane_row);
        #pragma unroll
        for (int request = 0; request < PREFETCH_V_BLOCKS; ++request)
        {
            const int src_col    = request * 4 * 32 + lane_col;
            const int src_offset = src_row * v_row_stride + src_col;
            // 4 * 4 * 512 * sizeof(fp16) = 16KB
            const int dst_offset = request * WARP_NUM * SLOT + warp_id * SLOT;
            safe_inline_buffer_load_dwordx4_lds<Element, 1>(v_lds, v_buffer, dst_offset, 0, src_offset);
        }
    }
}
// Accumulates P * V into `acc_o` for one K/V block of kBlockN rows.
//
//   acc_o            in/out: output accumulator, kHeadDimVSplit / 16 tiles per row tile.
//   p_reg            probabilities already converted to `Element`, one tile per
//                    16 K/V rows.
//   v_buffer         buffer descriptor used to load V.
//   v_lds            staging area for V; double-buffered through `stage_id`.
//   v_row_stride     element stride between consecutive V rows.
//   seqlen_kv_limit  row indices are clamped to seqlen_kv_limit - 1 when loading.
//   v_buffer_offset  only read in the WARP_NUM == 4 / PREFETCH_K path, where it
//                    is added to the base address of a copy of `v_buffer`.
//                    NOTE: the load loops below declare a local with the same
//                    name that shadows this parameter.
//
// WARP_NUM == 8: the first PREFETCH_V_BLOCKS row-blocks are consumed straight
//   from LDS (NOTE(review): presumably staged by
//   mla_prefix_prefill_prefetch_v_to_lds -- confirm with the caller); the
//   remaining row-blocks are loaded and consumed in a two-stage pipeline.
// WARP_NUM == 4: row-block n_loop is loaded while row-block n_loop - 1 is
//   consumed; `stage_id` starts at 1, so the first load goes to stage 1 and
//   row-block 0 is read from stage 0 (NOTE(review): presumably pre-staged by
//   the caller -- confirm).
//
// All threads of the block must call this together: it contains block barriers.
template <bool PREFETCH_K, int PREFETCH_V_BLOCKS, int kBlockN, int WARP_M, int WARP_NUM, int kHeadDimVSplit, typename Element, typename ElementAccum>
__forceinline__ __device__ void mla_prefix_prefill_compute_fwd_pv(
    vec4_Accum<ElementAccum> acc_o[WARP_M / 16][kHeadDimVSplit / 16],
    union_vec2_f16x2<Element> p_reg[WARP_M / 16][kBlockN / 16],
    vec4_uint v_buffer,
    Element* v_lds,
    int warp_id,
    int lane_id,
    int v_row_stride,
    int seqlen_kv_limit,
    int v_buffer_offset)
{
    if constexpr (WARP_NUM == 8)
    {
        // Wait until no buffer loads are outstanding before reading the staged data.
        wait_buffer_data_arrived<true>(0);
        constexpr int V_LOAD_REQUESTS = 2;
        int warp_id_col = warp_id >> 2;
        // Phase 1: consume the row-blocks that are already resident in LDS.
        #pragma unroll
        for (int n_loop = 0; n_loop < PREFETCH_V_BLOCKS; n_loop += 2)
        {
            #pragma unroll
            for (int depth = 0; depth < 2; ++depth)
            {
                // lds -> vgprs
                union_vec4_f16x2<Element> v_regs[kHeadDimVSplit / 32];
                #pragma unroll
                for (int v_tile = 0; v_tile < kHeadDimVSplit / 32; ++v_tile)
                {
                    int v_load_base_offset = depth * V_LOAD_REQUESTS * WARP_NUM * 512 + warp_id_col * 8 * 512 + v_tile * 512;
                    #pragma unroll
                    for (int i = 0; i < 2; ++i)
                    {
                        int v_load_offset = v_load_base_offset + i * 8 * 32 + (lane_id >> 4) * 32 + (lane_id & 15) * 2;
                        inline_ds_read2_b32_no_wait_bytes(reinterpret_cast<size_t>(v_lds + v_load_offset), v_regs[v_tile].f16x4[i], 64);
                    }
                }
                // pv mmac
                #pragma unroll
                for (int v_tile = 0; v_tile < kHeadDimVSplit / 32; ++v_tile)
                {
                    // Two LDS reads were issued per tile; allow the reads of the
                    // tiles after this one to remain in flight.
                    wait_lds_data_arrived<false>((kHeadDimVSplit / 32 - v_tile - 1) * 2);
                    // v interleave into vgprs
                    union_vec4_f16x2<Element> v_composed;
                    v_composed.f16x4[0] = make_vec4_f16<Element>(v_regs[v_tile].f16[0 * 2 + 0], v_regs[v_tile].f16[1 * 2 + 0], v_regs[v_tile].f16[2 * 2 + 0], v_regs[v_tile].f16[3 * 2 + 0]);
                    v_composed.f16x4[1] = make_vec4_f16<Element>(v_regs[v_tile].f16[0 * 2 + 1], v_regs[v_tile].f16[1 * 2 + 1], v_regs[v_tile].f16[2 * 2 + 1], v_regs[v_tile].f16[3 * 2 + 1]);
                    // pv mmac
                    #pragma unroll
                    for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx)
                    {
                        acc_o[m_idx][v_tile * 2 + 0].f32 = mmac<Element, ElementAccum>(p_reg[m_idx][n_loop + depth].f16x4, v_composed.f16x4[0], acc_o[m_idx][v_tile * 2 + 0].f32);
                        acc_o[m_idx][v_tile * 2 + 1].f32 = mmac<Element, ElementAccum>(p_reg[m_idx][n_loop + depth].f16x4, v_composed.f16x4[1], acc_o[m_idx][v_tile * 2 + 1].f32);
                    }
                }
            }
        }
        asm volatile("s_barrier\n");
        // The code above reads LDS and the code below writes LDS, so there is a potential data hazard.
        // Process the part that was not prefetched; its data still has to be loaded.
        if constexpr (true)
        {
            int stage_id = 0;
            {
                // Prime the pipeline: load row-block PREFETCH_V_BLOCKS into stage 0.
                #pragma unroll
                for (int load_loop = 0; load_loop < V_LOAD_REQUESTS; ++load_loop)
                {
                    int v_row = min(seqlen_kv_limit - 1, PREFETCH_V_BLOCKS * 16 + (lane_id >> 2));
                    int v_col = load_loop * 8 * 32 + warp_id * 32 + (lane_id & 3) * 8;
                    int v_buffer_offset = v_row * v_row_stride + v_col;
                    int lds_write_offset = stage_id * V_LOAD_REQUESTS * WARP_NUM * 512 + load_loop * WARP_NUM * 512 + warp_id * 512;
                    // 2 * 2 * 8 * 512 * sizeof(half) = 32KB
                    safe_inline_buffer_load_dwordx4_lds<Element, 1>(v_lds, v_buffer, lds_write_offset, 0, v_buffer_offset);
                }
            }
            stage_id ^= 1;
            // Steady state: issue the loads for row-block n_loop, then consume
            // row-block n_loop - 1 from the other stage.
            #pragma unroll
            for (int n_loop = PREFETCH_V_BLOCKS + 1; n_loop < kBlockN / 16; n_loop += 1)
            {
                #pragma unroll
                for (int load_loop = 0; load_loop < V_LOAD_REQUESTS; ++load_loop)
                {
                    int v_row = min(seqlen_kv_limit - 1, n_loop * 16 + (lane_id >> 2));
                    int v_col = load_loop * 8 * 32 + warp_id * 32 + (lane_id & 3) * 8;
                    int v_buffer_offset = v_row * v_row_stride + v_col;
                    int lds_write_offset = stage_id * V_LOAD_REQUESTS * WARP_NUM * 512 + load_loop * WARP_NUM * 512 + warp_id * 512;
                    // 2 * 2 * 8 * 512 * sizeof(half) = 32KB
                    safe_inline_buffer_load_dwordx4_lds<Element, 1>(v_lds, v_buffer, lds_write_offset, 0, v_buffer_offset);
                }
                stage_id ^= 1;
                // Leave the V_LOAD_REQUESTS loads just issued in flight; everything
                // issued before them (the stage read below) must have arrived.
                wait_buffer_data_arrived<true>(V_LOAD_REQUESTS);
                // lds -> vgprs
                union_vec4_f16x2<Element> v_regs[kHeadDimVSplit / 32];
                #pragma unroll
                for (int v_tile = 0; v_tile < kHeadDimVSplit / 32; ++v_tile)
                {
                    int v_load_base_offset = stage_id * V_LOAD_REQUESTS * WARP_NUM * 512 + warp_id_col * 8 * 512 + v_tile * 512;
                    #pragma unroll
                    for (int i = 0; i < 2; ++i)
                    {
                        int v_load_offset = v_load_base_offset + i * 8 * 32 + (lane_id >> 4) * 32 + (lane_id & 15) * 2;
                        inline_ds_read2_b32_no_wait_bytes(reinterpret_cast<size_t>(v_lds + v_load_offset), v_regs[v_tile].f16x4[i], 64);
                    }
                }
                // pv mmac
                #pragma unroll
                for (int v_tile = 0; v_tile < kHeadDimVSplit / 32; ++v_tile)
                {
                    wait_lds_data_arrived<false>((kHeadDimVSplit / 32 - v_tile - 1) * 2);
                    // v interleave into vgprs
                    union_vec4_f16x2<Element> v_composed;
                    v_composed.f16x4[0] = make_vec4_f16<Element>(v_regs[v_tile].f16[0 * 2 + 0], v_regs[v_tile].f16[1 * 2 + 0], v_regs[v_tile].f16[2 * 2 + 0], v_regs[v_tile].f16[3 * 2 + 0]);
                    v_composed.f16x4[1] = make_vec4_f16<Element>(v_regs[v_tile].f16[0 * 2 + 1], v_regs[v_tile].f16[1 * 2 + 1], v_regs[v_tile].f16[2 * 2 + 1], v_regs[v_tile].f16[3 * 2 + 1]);
                    // pv mmac
                    #pragma unroll
                    for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx)
                    {
                        acc_o[m_idx][v_tile * 2 + 0].f32 = mmac<Element, ElementAccum>(p_reg[m_idx][n_loop - 1].f16x4, v_composed.f16x4[0], acc_o[m_idx][v_tile * 2 + 0].f32);
                        acc_o[m_idx][v_tile * 2 + 1].f32 = mmac<Element, ElementAccum>(p_reg[m_idx][n_loop - 1].f16x4, v_composed.f16x4[1], acc_o[m_idx][v_tile * 2 + 1].f32);
                    }
                }
                // Block barrier before the stage just read is overwritten by the
                // next iteration's loads; the sched barriers keep the compiler
                // from moving instructions across it.
                __builtin_amdgcn_sched_barrier(0);
                __syncthreads();
                __builtin_amdgcn_sched_barrier(0);
            }
            // rest
            // Drain the pipeline: consume the last row-block (kBlockN / 16 - 1).
            {
                constexpr int n_loop = kBlockN / 16;
                stage_id ^= 1;
                wait_buffer_data_arrived<true>(0);
                // lds -> vgprs
                union_vec4_f16x2<Element> v_regs[kHeadDimVSplit / 32];
                #pragma unroll
                for (int v_tile = 0; v_tile < kHeadDimVSplit / 32; ++v_tile)
                {
                    int v_load_base_offset = stage_id * V_LOAD_REQUESTS * WARP_NUM * 512 + warp_id_col * 8 * 512 + v_tile * 512;
                    #pragma unroll
                    for (int i = 0; i < 2; ++i)
                    {
                        int v_load_offset = v_load_base_offset + i * 8 * 32 + (lane_id >> 4) * 32 + (lane_id & 15) * 2;
                        inline_ds_read2_b32_no_wait_bytes(reinterpret_cast<size_t>(v_lds + v_load_offset), v_regs[v_tile].f16x4[i], 64);
                    }
                }
                // pv mmac
                #pragma unroll
                for (int v_tile = 0; v_tile < kHeadDimVSplit / 32; ++v_tile)
                {
                    wait_lds_data_arrived<false>((kHeadDimVSplit / 32 - v_tile - 1) * 2);
                    // v interleave into vgprs
                    union_vec4_f16x2<Element> v_composed;
                    v_composed.f16x4[0] = make_vec4_f16<Element>(v_regs[v_tile].f16[0 * 2 + 0], v_regs[v_tile].f16[1 * 2 + 0], v_regs[v_tile].f16[2 * 2 + 0], v_regs[v_tile].f16[3 * 2 + 0]);
                    v_composed.f16x4[1] = make_vec4_f16<Element>(v_regs[v_tile].f16[0 * 2 + 1], v_regs[v_tile].f16[1 * 2 + 1], v_regs[v_tile].f16[2 * 2 + 1], v_regs[v_tile].f16[3 * 2 + 1]);
                    // pv mmac
                    #pragma unroll
                    for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx)
                    {
                        acc_o[m_idx][v_tile * 2 + 0].f32 = mmac<Element, ElementAccum>(p_reg[m_idx][n_loop - 1].f16x4, v_composed.f16x4[0], acc_o[m_idx][v_tile * 2 + 0].f32);
                        acc_o[m_idx][v_tile * 2 + 1].f32 = mmac<Element, ElementAccum>(p_reg[m_idx][n_loop - 1].f16x4, v_composed.f16x4[1], acc_o[m_idx][v_tile * 2 + 1].f32);
                    }
                }
                __builtin_amdgcn_sched_barrier(0);
                __syncthreads();
                __builtin_amdgcn_sched_barrier(0);
            }
        }
    }
    else if constexpr (WARP_NUM == 4)
    {
        constexpr int V_LOAD_REQUESTS = 4;
        // mla_prefetch_v_to_lds<V_LOAD_REQUESTS, Element>(v_buffer, v_lds, v_row_stride, warp_id, lane_id, seqlen_kv_limit);
        // The first load of the loop below targets stage 1.
        int stage_id = 1;
        int warp_id_col = warp_id & 1;
        // Steady state: issue the loads for row-block n_loop, then consume
        // row-block n_loop - 1 from the other stage.
        #pragma unroll
        for (int n_loop = 1; n_loop < kBlockN / 16; ++n_loop)
        {
            #pragma unroll
            for (int load_loop = 0; load_loop < V_LOAD_REQUESTS; ++load_loop)
            {
                int v_row = min(seqlen_kv_limit - 1, n_loop * 16 + (lane_id >> 2));
                int v_col = load_loop * 4 * 32 + warp_id * 32 + (lane_id & 3) * 8;
                int v_buffer_offset = v_row * v_row_stride + v_col;
                int lds_write_offset = stage_id * 4 * 4 * 512 + load_loop * 4 * 512 + warp_id * 512;
                // 4 * 4 * 512 * sizeof(fp16) = 16KB
                safe_inline_buffer_load_dwordx4_lds<Element, 1>(v_lds, v_buffer, lds_write_offset, 0, v_buffer_offset);
            }
            // Leave the loads just issued in flight; older ones must have arrived.
            wait_buffer_data_arrived<true>(V_LOAD_REQUESTS);
            stage_id ^= 1;
            #pragma unroll
            for (int v_tile = 0; v_tile < kHeadDimVSplit / 32; ++v_tile)
            {
                // lds -> vgprs
                union_vec4_f16x2<Element> v_regs;
                int v_load_base_offset = stage_id * 4 * 4 * 512 + warp_id_col * 8 * 512 + v_tile * 512;
                #pragma unroll
                for (int i = 0; i < 4; ++i)
                {
                    int v_load_offset = v_load_base_offset + i * 4 * 32 + (lane_id >> 4) * 32 + (lane_id & 15) * 2;
                    v_regs.f16x2[i] = *(vec2_Element<Element>*)(v_lds + v_load_offset);
                }
                // v regs interleave
                union_vec4_f16x2<Element> v_composed;
                v_composed.f16x4[0] = make_vec4_f16<Element>(v_regs.f16[0 * 2 + 0], v_regs.f16[1 * 2 + 0], v_regs.f16[2 * 2 + 0], v_regs.f16[3 * 2 + 0]);
                v_composed.f16x4[1] = make_vec4_f16<Element>(v_regs.f16[0 * 2 + 1], v_regs.f16[1 * 2 + 1], v_regs.f16[2 * 2 + 1], v_regs.f16[3 * 2 + 1]);
                // pv mmac
                #pragma unroll
                for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx)
                {
                    acc_o[m_idx][v_tile * 2 + 0].f32 = mmac<Element, ElementAccum>(p_reg[m_idx][n_loop - 1].f16x4, v_composed.f16x4[0], acc_o[m_idx][v_tile * 2 + 0].f32);
                    acc_o[m_idx][v_tile * 2 + 1].f32 = mmac<Element, ElementAccum>(p_reg[m_idx][n_loop - 1].f16x4, v_composed.f16x4[1], acc_o[m_idx][v_tile * 2 + 1].f32);
                }
            }
            __syncthreads();
        }
        // Drain the pipeline: consume the last row-block (kBlockN / 16 - 1).
        {
            if constexpr (PREFETCH_K)
            {
                // Start loading K for a following step into the same LDS area.
                // The descriptor's base address is advanced by v_buffer_offset.
                // NOTE(review): the variable is named k_rope_buffer but is passed
                // to the *_k_nope_* helper -- confirm which part is intended.
                vec4_uint k_rope_buffer = v_buffer;
                *(int64_t*)&k_rope_buffer += v_buffer_offset;
                mla_prefix_prefill_prefetch_k_nope_to_lds<kBlockN, WARP_NUM, Element>(v_lds, k_rope_buffer, warp_id, lane_id, v_row_stride, seqlen_kv_limit - kBlockN);
                // Allow kBlockN / 32 loads to remain outstanding (presumably the
                // K requests just issued -- confirm against the helper).
                wait_buffer_data_arrived<true>(kBlockN / (16 * 2));
            }
            else
            {
                wait_buffer_data_arrived<true>(0);
            }
            constexpr int n_loop = kBlockN / 16;
            stage_id ^= 1;
            #pragma unroll
            for (int v_tile = 0; v_tile < kHeadDimVSplit / 32; ++v_tile)
            {
                // lds -> vgprs
                union_vec4_f16x2<Element> v_regs;
                int v_load_base_offset = stage_id * 4 * 4 * 512 + warp_id_col * 8 * 512 + v_tile * 512;
                #pragma unroll
                for (int i = 0; i < 4; ++i)
                {
                    int v_load_offset = v_load_base_offset + i * 4 * 32 + (lane_id >> 4) * 32 + (lane_id & 15) * 2;
                    v_regs.f16x2[i] = *(vec2_Element<Element>*)(v_lds + v_load_offset);
                }
                // v vgpr interleave
                union_vec4_f16x2<Element> v_composed;
                v_composed.f16x4[0] = make_vec4_f16<Element>(v_regs.f16[0 * 2 + 0], v_regs.f16[1 * 2 + 0], v_regs.f16[2 * 2 + 0], v_regs.f16[3 * 2 + 0]);
                v_composed.f16x4[1] = make_vec4_f16<Element>(v_regs.f16[0 * 2 + 1], v_regs.f16[1 * 2 + 1], v_regs.f16[2 * 2 + 1], v_regs.f16[3 * 2 + 1]);
                // pv mmac
                #pragma unroll
                for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx)
                {
                    acc_o[m_idx][v_tile * 2 + 0].f32 = mmac<Element, ElementAccum>(p_reg[m_idx][n_loop - 1].f16x4, v_composed.f16x4[0], acc_o[m_idx][v_tile * 2 + 0].f32);
                    acc_o[m_idx][v_tile * 2 + 1].f32 = mmac<Element, ElementAccum>(p_reg[m_idx][n_loop - 1].f16x4, v_composed.f16x4[1], acc_o[m_idx][v_tile * 2 + 1].f32);
                }
            }
            __syncthreads();
        }
    }
}
// Finalises the softmax normalisation for one query tile:
//   * writes the per-row statistics (scaled max, sum, log-sum-exp) to memory,
//   * multiplies the output accumulator by 1 / sum.
//
//   acc_o            in/out: output accumulator, normalised in place.
//   scores_max_ptr   out: receives scores_max * scale_softmax per row.
//   scores_sum_ptr   out: receives scores_sum per row.
//   softmax_lse_ptr  out: receives scores_max * scale_softmax + log(sum), or
//                    INFINITY when the sum is zero or NaN.
//   scores_max/sum   per-row running statistics held in registers.
//   row_offset_lse   element offset of this tile's first row in the three
//                    statistics arrays.
//   m_block, actual_seqlen_q
//                    rows at or beyond actual_seqlen_q - m_block * kBlockM are
//                    not written.
//
// Only lanes 0..15 of the designated waves write statistics: waves 0..3 when
// WARP_NUM == 8, even waves when WARP_NUM == 4.
template <int kBlockM, int WARP_M, int WARP_NUM, int kHeadDimVSplit, typename ElementAccum>
__forceinline__ __device__ void mla_prefix_prefill_rescale_acc_o(
    vec4_Accum<ElementAccum> acc_o[WARP_M / 16][kHeadDimVSplit / 16],
    ElementAccum* scores_max_ptr,
    ElementAccum* scores_sum_ptr,
    ElementAccum* softmax_lse_ptr,
    ElementAccum scores_max[WARP_M / 16],
    ElementAccum scores_sum[WARP_M / 16],
    ElementAccum scale_softmax,
    int64_t row_offset_lse,
    int m_block,
    int warp_id,
    int warp_id_row,
    int lane_id,
    int actual_seqlen_q)
{
    #pragma unroll
    for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx)
    {
        ElementAccum sum = scores_sum[m_idx];
        // A zero or NaN sum means the row has no usable contribution.
        const bool degenerate_sum = (sum == 0.f || sum != sum);
        ElementAccum lse = degenerate_sum ? INFINITY : __llvm_fma_f32(scores_max[m_idx], scale_softmax, __logf(sum));

        // Decide whether this lane owns a row of statistics and which one.
        bool writes_stats = false;
        int lse_offset = 0;
        if constexpr (WARP_NUM == 8)
        {
            writes_stats = (lane_id < 16 and warp_id < 4);
            // Row tiles are 64 rows apart in the 8-wave layout (same mapping as
            // the mask and store helpers). Without the m_idx term every row
            // tile would overwrite the statistics of rows 0..63.
            lse_offset = m_idx * 64 + warp_id * 16 + lane_id;
        }
        else if constexpr (WARP_NUM == 4)
        {
            writes_stats = (lane_id < 16 and ((warp_id & 1) == 0));
            lse_offset = m_idx * WARP_M + warp_id_row * 16 + lane_id;
        }

        if (writes_stats and lse_offset < actual_seqlen_q - m_block * kBlockM)
        {
            scores_max_ptr[row_offset_lse + lse_offset] = scores_max[m_idx] * scale_softmax;
            scores_sum_ptr[row_offset_lse + lse_offset] = scores_sum[m_idx];
            softmax_lse_ptr[row_offset_lse + lse_offset] = lse;
        }

        // Normalise acc_o. For a degenerate sum keep acc_o unscaled instead of
        // multiplying by inf/NaN (0 * inf would turn an all-zero row into NaN).
        __float2 inv_sum;
        inv_sum[0] = degenerate_sum ? 1.0f : 1.0f / sum;
        inv_sum[1] = inv_sum[0];
        #pragma unroll
        for (int pv_tile = 0; pv_tile < kHeadDimVSplit / 16; ++pv_tile)
        {
            acc_o[m_idx][pv_tile].u64[0] = __builtin_hcu_pk_mul_f32(acc_o[m_idx][pv_tile].u64[0], inv_sum);
            acc_o[m_idx][pv_tile].u64[1] = __builtin_hcu_pk_mul_f32(acc_o[m_idx][pv_tile].u64[1], inv_sum);
        }
    }
}
// Converts the output accumulator to `Element` and writes it to memory.
//
//   acc_o            accumulator tiles of this wave.
//   o_raw_ptr        base pointer of the output, reinterpreted as Element*.
//   row_offset_o     element offset added to the base pointer.
//   m_block, actual_seqlen_q
//                    rows with m_block * kBlockM + row >= actual_seqlen_q are skipped.
//   o_row_stride     element stride between consecutive output rows.
//
// Row of this lane for row tile `r`:
//   r * (WARP_NUM == 8 ? 64 : WARP_M) + warp_id_row * 16 + (lane_id & 15)
// Column of the two-element store for (v_tile, vec_idx):
//   warp_id_col * 256 + v_tile * 32 + vec_idx * 8 + (lane_id >> 4) * 2
// Each store packs element `vec_idx` of accumulator tiles 2*v_tile and 2*v_tile+1.
template <int kBlockM, int WARP_M, int WARP_NUM, int kHeadDimVSplit, typename Element, typename ElementAccum>
__forceinline__ __device__ void mla_prefix_prefill_store_output(
    vec4_Accum<ElementAccum> acc_o[WARP_M / 16][kHeadDimVSplit / 16],
    void* __restrict__ o_raw_ptr,
    int64_t row_offset_o,
    int m_block,
    int warp_id_row,
    int warp_id_col,
    int lane_id,
    int o_row_stride,
    int actual_seqlen_q)
{
    Element* o_ptr = reinterpret_cast<Element*>(o_raw_ptr) + row_offset_o;

    const int lane_row = warp_id_row * 16 + (lane_id & 15);
    const int lane_col = warp_id_col * 256 + (lane_id >> 4) * 2;

    #pragma unroll
    for (int row_tile = 0; row_tile < WARP_M / 16; ++row_tile)
    {
        const int row = (WARP_NUM == 8 ? row_tile * 64 : row_tile * WARP_M) + lane_row;
        // Rows past the end of the sequence are never written.
        if (!(m_block * kBlockM + row < actual_seqlen_q))
        {
            continue;
        }
        const int64_t row_base = row * int64_t(o_row_stride);

        #pragma unroll
        for (int v_tile = 0; v_tile < kHeadDimVSplit / 32; ++v_tile)
        {
            #pragma unroll
            for (int vec_idx = 0; vec_idx < 4; ++vec_idx)
            {
                vec2_Element<Element> packed;
#if defined(__gfx938__) || defined(__gfx946__)
                #pragma unroll
                for (int half_tile = 0; half_tile < 2; ++half_tile)
                {
                    packed[half_tile] = DownCast<ElementAccum, Element, true>(acc_o[row_tile][v_tile * 2 + half_tile].f32[vec_idx]);
                }
#else
                packed = DownCastPairNoPack<ElementAccum, Element>(acc_o[row_tile][v_tile * 2 + 0].f32[vec_idx], acc_o[row_tile][v_tile * 2 + 1].f32[vec_idx]);
#endif
                const int col = lane_col + v_tile * 32 + vec_idx * 8;
                const int64_t write_offset = row_base + col;
                *(vec2_Element<Element>*)(o_ptr + write_offset) = packed;
            }
        }
    }
}
\ No newline at end of file
csrc/gfx93/prefill/sparse/dsa_mls/legacy/include/mla/mla_pv_gemm_prefetch_k.h
0 → 100644
View file @
a1eef562
#include "mla_pv_gemm_utils.h"
template
<
int
K_LOOP_COUNT
,
int
kBlockM
,
int
kBlockN
,
int
kBlockK
,
int
M_WARP_COUNT
,
int
PV_N_WARP_COUNT
,
int
PV_K_WARP_COUNT
,
int
STAGES
,
int
WARP_NUM
,
int
M_MMAC_COUNT
,
typename
Element
,
typename
ElementAccum
>
__forceinline__
__device__
void
mla_pv_gemm_prefetch_k
(
vec4_uint
v_addr
,
vec4_uint
k_addr
,
Element
*
v_lds
,
Element
*
k_lds
,
union_vec2_f16x2
<
Element
>
p_reg
[
M_WARP_COUNT
*
PV_K_WARP_COUNT
][
4
],
vec4_Accum
<
ElementAccum
>
pv_reg
[
K_LOOP_COUNT
*
M_WARP_COUNT
*
(
kBlockN
/
32
)][
4
],
int
warp_id
,
int
kvcache_seqlen_stride
,
int
max_seq_kv_offset
=-
1
)
{
constexpr
int
WARP_K
=
PV_K_WARP_COUNT
*
32
;
static_assert
(
kBlockK
>=
32
,
"Error: pv gemm kBlockK must be equal or greater than 32"
);
static_assert
(
kBlockN
==
PV_N_WARP_COUNT
*
32
,
"Error: kBlockN in mla_pv_gemm_prefetch_k must be WARP_N * 32"
);
union_vec2_f16x2
<
Element
>
v_reg
[
STAGES
*
PV_K_WARP_COUNT
*
PV_N_WARP_COUNT
][
4
];
// 预先计算一些公共表达式
int
lane_id
=
threadIdx
.
x
&
63
;
int
laneid_shfl_2
=
lane_id
>>
2
;
// 0 ~ 15, 4 个线程读取一行
int
laneid_shfl_3
=
lane_id
>>
3
;
// 0 ~ 7, 8 个线程读取一行
int
laneid_shfl_4
=
lane_id
>>
4
;
// 0 ~ 3, 16 个线程读取一行
int
laneid_shfl_5
=
lane_id
>>
5
;
// 0 ~ 1, lds 读取时, 8x32的数据按照线程 [0, 16, 0, 16, 32, 48, 32, 48] 来读取, 每 32 个线程读取一个 4x32
constexpr
int
NEXT_DWORD_OFFSET
=
32
;
// 8x32 的数据, 一个 wave 每个线程读 4 个 half, 即 2 个 dword, 使用 ds_read2_b32 指令, 按照上面的读取方式, 第二个 dword 偏移 32 个 dword
#if defined(USE_BUFFER_LOAD_DWORDX4)
constexpr
int
READ_ONCE_LINES
=
16
;
// 一个 warp 每次读几行数据, loadx4, 每个线程读取 8 个 Half, 每行 32 个 Half 需要 32 / 8 = 4 个线程, 所以一个 wave 64 线程会读取 16 行
constexpr
int
READ_ONCE_COUNT
=
READ_ONCE_LINES
*
32
;
// 一个 warp 每次 load 多少数据, 16x32
constexpr
int
V_LDS_LOAD_NUM
=
kBlockN
*
WARP_K
/
READ_ONCE_COUNT
;
// 一个 warp 一共要发几次读取请求
constexpr
int
V_LOAD_REQUESTS
=
V_LDS_LOAD_NUM
;
// 一个 warp 一共要发几次读取请求
constexpr
int
READ_ELEMENT_COUNT
=
8
;
// 每个线程一次读取几个 Half
int
v_lane_headdim_n_idx
=
lane_id
&
3
;
// 当前 lane 负责这个 warp 的第几个 dwordx2
int
base
=
(
laneid_shfl_2
&
0xc
);
// 第几个 4 线程组的最小id
int
tail
=
(
laneid_shfl_2
&
0x3
);
// 4 线程组中的第几个线程
int
v_lane_seq_k_idx
=
base
+
(
tail
&
1
)
*
2
+
(
tail
>>
1
);
// global -> lds, seqlen 方向的坐标
int
v_ds_read_offset
=
(
laneid_shfl_5
*
4
+
(
laneid_shfl_4
&
1
))
*
32
+
(
lane_id
&
15
)
*
2
;
// 一次读写 8x32, 0-31 线程读取前面 4x32, 32-63 读取后面 4x32, 4x32按照线程 [0, 16, 0, 16] 这种方式来读取
auto
BUFFER_LOAD_FUNC
=
&
inline_buffer_load_dwordx4_lds
<
Element
,
2
>
;
#else
constexpr
int
READ_ONCE_LINES
=
4
;
constexpr
int
READ_ONCE_COUNT
=
READ_ONCE_LINES
*
32
;
constexpr
int
V_LDS_LOAD_NUM
=
(
kBlockN
*
WARP_K
)
/
READ_ONCE_COUNT
;
constexpr
int
V_LOAD_REQUESTS
=
V_LDS_LOAD_NUM
;
constexpr
int
READ_ELEMENT_COUNT
=
2
;
int
v_lane_headdim_n_idx
=
lane_id
&
15
;
int
v_lane_seq_k_idx
=
(
laneid_shfl_4
&
1
)
*
2
+
laneid_shfl_5
;
// 0, 1, 2, 3 ---> 0, 2, 1, 3
int
v_ds_read_offset
=
(
laneid_shfl_5
*
4
+
(
laneid_shfl_4
&
1
))
*
32
+
(
lane_id
&
15
)
*
2
;
// 一次读写 8x32, 0-31 线程读取前面 4x32, 32-63 读取后面 4x32, 4x32按照线程 [0, 16, 0, 16] 这种方式来读取
auto
BUFFER_LOAD_FUNC
=
&
inline_buffer_load_dword_lds
<
Element
,
2
>
;
#endif
// each wave need 2 32x32 lds space
v_lds
=
v_lds
+
warp_id
*
STAGES
*
WARP_K
*
kBlockN
;
int
stage_id
=
(
STAGES
==
2
)
?
1
:
0
;
constexpr
int
N_LOOP_START
=
(
STAGES
==
2
)
?
1
:
0
;
for
(
int
n_loop
=
N_LOOP_START
;
n_loop
<
K_LOOP_COUNT
;
++
n_loop
)
{
int
v_block_buffer_load_global_offset
=
warp_id
*
WARP_K
*
kvcache_seqlen_stride
+
n_loop
*
kBlockN
;
for
(
int
load
=
0
;
load
<
V_LOAD_REQUESTS
;
++
load
)
{
int
v_warp_buffer_load_k_id
=
(
load
+
warp_id
)
%
V_LOAD_REQUESTS
;
int
v_warp_buffer_load_lds_offset
=
load
*
READ_ONCE_COUNT
;
int
v_gvoffset_s
=
v_block_buffer_load_global_offset
/
2
;
int
v_gvoffset_v
=
(
v_lane_headdim_n_idx
*
READ_ELEMENT_COUNT
+
min
(
v_lane_seq_k_idx
+
load
*
READ_ONCE_LINES
,
max_seq_kv_offset
-
1
)
*
kvcache_seqlen_stride
)
/
2
;
int
v_lds_offset
=
v_warp_buffer_load_lds_offset
/
2
;
BUFFER_LOAD_FUNC
(
v_lds
+
stage_id
*
WARP_K
*
kBlockN
,
v_addr
,
v_lds_offset
,
v_gvoffset_s
,
v_gvoffset_v
);
}
if
constexpr
(
STAGES
==
2
)
stage_id
^=
1
;
// 把 ds_read 之前的一些计算挪到 wait 之前, 等待数据返回
int
precompute_v_lds_offset
[
4
];
vec2_Element
<
Element
>
*
v_lds_v2fp16
=
(
vec2_Element
<
Element
>
*
)(
v_lds
);
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
++
vec_idx
)
{
for
(
int
seq_idx
=
0
;
seq_idx
<
PV_K_WARP_COUNT
;
++
seq_idx
)
{
for
(
int
head_dim_idx
=
0
;
head_dim_idx
<
PV_N_WARP_COUNT
;
++
head_dim_idx
)
{
precompute_v_lds_offset
[
vec_idx
]
=
reinterpret_cast
<
size_t
>
(
v_lds_v2fp16
)
+
((
stage_id
*
WARP_K
*
kBlockN
+
seq_idx
*
32
*
kBlockN
+
head_dim_idx
*
32
*
32
+
vec_idx
*
8
*
32
+
v_ds_read_offset
)
/
2
)
*
4
/*4 bytes per dword*/
;
}
}
}
__builtin_amdgcn_sched_barrier
(
0
);
#ifdef USE_PINGPANG_BUFFER
if
constexpr
(
STAGES
==
2
)
{
buffer_load_lds_dwordx1_wait_nosync
<
V_LOAD_REQUESTS
>
();
}
else
if
constexpr
(
STAGES
==
1
)
{
buffer_load_lds_dwordx1_wait_nosync
<
0
>
();
}
#else
buffer_load_lds_dwordx1_wait_nosync
<
0
>
();
#endif
__builtin_amdgcn_sched_barrier
(
0
);
#pragma unroll
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
++
vec_idx
)
{
#pragma unroll
for
(
int
seq_idx
=
0
;
seq_idx
<
PV_K_WARP_COUNT
;
++
seq_idx
)
{
#pragma unroll
for
(
int
head_dim_idx
=
0
;
head_dim_idx
<
PV_N_WARP_COUNT
;
++
head_dim_idx
)
{
inline_ds_read2_b32_no_wait_bytes
(
precompute_v_lds_offset
[
vec_idx
],
v_reg
[
stage_id
*
PV_K_WARP_COUNT
*
PV_N_WARP_COUNT
+
(
head_dim_idx
*
PV_K_WARP_COUNT
+
seq_idx
)][
vec_idx
].
u64
,
NEXT_DWORD_OFFSET
);
}
}
}
// 拆成两段, 一共发出去 8 条 ds_read 指令, 前 4 条指令 perm, 用来做 mmac,然后再等待后面 4 条指令的返回, 类似
vec4_Element
<
Element
>
v_vgprs
[
PV_K_WARP_COUNT
*
PV_N_WARP_COUNT
][
2
];
{
constexpr
int
min_tile_k
=
0
;
// First pack (v_pack) together the data needed from the p registers
vec4_Element
<
Element
>
p_vgprs
[
2
];
for
(
int
min_tile_m
=
0
;
min_tile_m
<
M_MMAC_COUNT
;
++
min_tile_m
)
{
p_vgprs
[
min_tile_m
]
=
make_vec4_f16
(
p_reg
[
0
][
0
*
2
+
min_tile_m
].
f16
[
min_tile_k
*
2
+
0
/*vec_idx*/
],
p_reg
[
0
][
1
*
2
+
min_tile_m
].
f16
[
min_tile_k
*
2
+
0
],
p_reg
[
0
][
0
*
2
+
min_tile_m
].
f16
[
min_tile_k
*
2
+
1
],
p_reg
[
0
][
1
*
2
+
min_tile_m
].
f16
[
min_tile_k
*
2
+
1
]
);
}
asm
volatile
(
"s_setprio 1"
);
asm
volatile
(
"s_waitcnt lgkmcnt(2)"
);
__builtin_amdgcn_sched_barrier
(
0
);
#pragma unroll
for
(
int
k_idx
=
0
;
k_idx
<
PV_K_WARP_COUNT
;
++
k_idx
)
{
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
PV_N_WARP_COUNT
;
++
n_idx
)
{
int
v_tile_id
=
stage_id
*
PV_K_WARP_COUNT
*
PV_N_WARP_COUNT
+
n_idx
*
PV_K_WARP_COUNT
+
k_idx
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
v_vgprs
[
n_idx
*
PV_K_WARP_COUNT
+
k_idx
][
/*min_tile_k * 2 + */
min_tile_n
]
=
vec4_Element
<
Element
>
{
v_reg
[
v_tile_id
][
0
+
min_tile_k
*
2
/*vec_idx*/
].
f16x2
[
0
][
min_tile_n
],
v_reg
[
v_tile_id
][
0
+
min_tile_k
*
2
].
f16x2
[
1
][
min_tile_n
],
v_reg
[
v_tile_id
][
1
+
min_tile_k
*
2
].
f16x2
[
0
][
min_tile_n
],
v_reg
[
v_tile_id
][
1
+
min_tile_k
*
2
].
f16x2
[
1
][
min_tile_n
]
};
}
}
}
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
M_WARP_COUNT
;
++
m_idx
)
{
#pragma unroll
for
(
int
k_idx
=
0
;
k_idx
<
PV_K_WARP_COUNT
;
++
k_idx
)
{
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
PV_N_WARP_COUNT
;
++
n_idx
)
{
int
n_loop_idx
=
(
STAGES
==
2
)
?
n_loop
-
1
:
n_loop
;
int
pv_tile_id
=
n_loop_idx
*
M_WARP_COUNT
*
PV_N_WARP_COUNT
+
n_idx
*
M_WARP_COUNT
+
m_idx
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
M_MMAC_COUNT
;
++
min_tile_m
)
{
pv_reg
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
=
flash
::
mmac
<
Element
,
ElementAccum
>
(
p_vgprs
[
min_tile_m
],
v_vgprs
[
n_idx
*
PV_K_WARP_COUNT
+
k_idx
][
/*min_tile_k * 2 + */
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
);
}
}
}
}
}
asm
volatile
(
"s_setprio 0"
);
}
// Ping-pong buffer between ds and vgpr
{
constexpr
int
min_tile_k
=
1
;
vec4_Element
<
Element
>
p_vgprs
[
2
];
for
(
int
min_tile_m
=
0
;
min_tile_m
<
M_MMAC_COUNT
;
++
min_tile_m
)
{
p_vgprs
[
min_tile_m
]
=
make_vec4_f16
(
p_reg
[
0
][
0
*
2
+
min_tile_m
].
f16
[
min_tile_k
*
2
+
0
],
p_reg
[
0
][
1
*
2
+
min_tile_m
].
f16
[
min_tile_k
*
2
+
0
],
p_reg
[
0
][
0
*
2
+
min_tile_m
].
f16
[
min_tile_k
*
2
+
1
],
p_reg
[
0
][
1
*
2
+
min_tile_m
].
f16
[
min_tile_k
*
2
+
1
]
);
}
asm
volatile
(
"s_setprio 1"
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)"
);
__builtin_amdgcn_sched_barrier
(
0
);
#pragma unroll
for
(
int
k_idx
=
0
;
k_idx
<
PV_K_WARP_COUNT
;
++
k_idx
)
{
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
PV_N_WARP_COUNT
;
++
n_idx
)
{
int
v_tile_id
=
stage_id
*
PV_K_WARP_COUNT
*
PV_N_WARP_COUNT
+
n_idx
*
PV_K_WARP_COUNT
+
k_idx
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
v_vgprs
[
n_idx
*
PV_K_WARP_COUNT
+
k_idx
][
/*min_tile_k * 2 + */
min_tile_n
]
=
vec4_Element
<
Element
>
{
v_reg
[
v_tile_id
][
0
+
min_tile_k
*
2
].
f16x2
[
0
][
min_tile_n
],
v_reg
[
v_tile_id
][
0
+
min_tile_k
*
2
].
f16x2
[
1
][
min_tile_n
],
v_reg
[
v_tile_id
][
1
+
min_tile_k
*
2
].
f16x2
[
0
][
min_tile_n
],
v_reg
[
v_tile_id
][
1
+
min_tile_k
*
2
].
f16x2
[
1
][
min_tile_n
]
};
}
}
}
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
M_WARP_COUNT
;
++
m_idx
)
{
#pragma unroll
for
(
int
k_idx
=
0
;
k_idx
<
PV_K_WARP_COUNT
;
++
k_idx
)
{
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
PV_N_WARP_COUNT
;
++
n_idx
)
{
int
n_loop_idx
=
(
STAGES
==
2
)
?
n_loop
-
1
:
n_loop
;
int
pv_tile_id
=
n_loop_idx
*
M_WARP_COUNT
*
PV_N_WARP_COUNT
+
n_idx
*
M_WARP_COUNT
+
m_idx
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
M_MMAC_COUNT
;
++
min_tile_m
)
{
pv_reg
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
=
flash
::
mmac
<
Element
,
ElementAccum
>
(
p_vgprs
[
min_tile_m
],
v_vgprs
[
n_idx
*
PV_K_WARP_COUNT
+
k_idx
][
/*min_tile_k * 2 + */
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
);
}
}
}
}
}
asm
volatile
(
"s_setprio 0"
);
// asm volatile("s_barrier ; sync before load in the coming round");
}
}
if
constexpr
(
STAGES
==
2
)
{
int
n_loop
=
K_LOOP_COUNT
-
1
;
stage_id
^=
1
;
// Hoist the computations that precede ds_read to before the wait, so they run while waiting for the data to return
int
precompute_v_lds_offset
[
4
];
vec2_Element
<
Element
>
*
v_lds_v2fp16
=
(
vec2_Element
<
Element
>
*
)(
v_lds
);
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
++
vec_idx
)
{
for
(
int
seq_idx
=
0
;
seq_idx
<
PV_K_WARP_COUNT
;
++
seq_idx
)
{
for
(
int
head_dim_idx
=
0
;
head_dim_idx
<
PV_N_WARP_COUNT
;
++
head_dim_idx
)
{
precompute_v_lds_offset
[
vec_idx
]
=
reinterpret_cast
<
size_t
>
(
v_lds_v2fp16
)
+
((
stage_id
*
WARP_K
*
kBlockN
+
(
seq_idx
*
32
*
kBlockN
)
+
head_dim_idx
*
32
*
32
+
vec_idx
*
8
*
32
+
v_ds_read_offset
)
/
2
)
*
4
;
}
}
}
__builtin_amdgcn_sched_barrier
(
0
);
#ifdef USE_PINGPANG_BUFFER
if
constexpr
(
STAGES
==
2
)
{
buffer_load_lds_dwordx1_wait_nosync
<
0
>
();
}
else
if
constexpr
(
STAGES
==
1
)
{
buffer_load_lds_dwordx1_wait_nosync
<
0
>
();
}
#else
buffer_load_lds_dwordx1_wait_nosync
<
0
>
();
#endif
__builtin_amdgcn_sched_barrier
(
0
);
#pragma unroll
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
++
vec_idx
)
{
#pragma unroll
for
(
int
seq_idx
=
0
;
seq_idx
<
PV_K_WARP_COUNT
;
++
seq_idx
)
{
#pragma unroll
for
(
int
head_dim_idx
=
0
;
head_dim_idx
<
PV_N_WARP_COUNT
;
++
head_dim_idx
)
{
inline_ds_read2_b32_no_wait_bytes
(
precompute_v_lds_offset
[
vec_idx
],
v_reg
[
stage_id
*
PV_K_WARP_COUNT
*
PV_N_WARP_COUNT
+
(
head_dim_idx
*
PV_K_WARP_COUNT
+
seq_idx
)][
vec_idx
].
u64
,
NEXT_DWORD_OFFSET
);
}
}
}
// Split into two phases: 8 ds_read instructions are issued in total; the first 4 are permuted and used for mmac, then we wait for the remaining 4 to return
vec4_Element
<
Element
>
v_vgprs
[
PV_K_WARP_COUNT
*
PV_N_WARP_COUNT
][
2
];
{
constexpr
int
min_tile_k
=
0
;
// First pack (v_pack) together the data needed from the p registers
vec4_Element
<
Element
>
p_vgprs
[
2
];
for
(
int
min_tile_m
=
0
;
min_tile_m
<
M_MMAC_COUNT
;
++
min_tile_m
)
{
p_vgprs
[
min_tile_m
]
=
make_vec4_f16
(
p_reg
[
0
][
0
*
2
+
min_tile_m
].
f16
[
min_tile_k
*
2
+
0
/*vec_idx*/
],
p_reg
[
0
][
1
*
2
+
min_tile_m
].
f16
[
min_tile_k
*
2
+
0
],
p_reg
[
0
][
0
*
2
+
min_tile_m
].
f16
[
min_tile_k
*
2
+
1
],
p_reg
[
0
][
1
*
2
+
min_tile_m
].
f16
[
min_tile_k
*
2
+
1
]
);
}
asm
volatile
(
"s_setprio 1"
);
asm
volatile
(
"s_waitcnt lgkmcnt(2)"
);
__builtin_amdgcn_sched_barrier
(
0
);
#pragma unroll
for
(
int
k_idx
=
0
;
k_idx
<
PV_K_WARP_COUNT
;
++
k_idx
)
{
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
PV_N_WARP_COUNT
;
++
n_idx
)
{
int
v_tile_id
=
stage_id
*
PV_K_WARP_COUNT
*
PV_N_WARP_COUNT
+
n_idx
*
PV_K_WARP_COUNT
+
k_idx
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
v_vgprs
[
n_idx
*
PV_K_WARP_COUNT
+
k_idx
][
/*min_tile_k * 2 + */
min_tile_n
]
=
vec4_Element
<
Element
>
{
v_reg
[
v_tile_id
][
0
+
min_tile_k
*
2
/*vec_idx*/
].
f16x2
[
0
][
min_tile_n
],
v_reg
[
v_tile_id
][
0
+
min_tile_k
*
2
].
f16x2
[
1
][
min_tile_n
],
v_reg
[
v_tile_id
][
1
+
min_tile_k
*
2
].
f16x2
[
0
][
min_tile_n
],
v_reg
[
v_tile_id
][
1
+
min_tile_k
*
2
].
f16x2
[
1
][
min_tile_n
]
};
}
}
}
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
M_WARP_COUNT
;
++
m_idx
)
{
#pragma unroll
for
(
int
k_idx
=
0
;
k_idx
<
PV_K_WARP_COUNT
;
++
k_idx
)
{
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
PV_N_WARP_COUNT
;
++
n_idx
)
{
int
n_loop_idx
=
n_loop
;
int
pv_tile_id
=
n_loop_idx
*
M_WARP_COUNT
*
PV_N_WARP_COUNT
+
n_idx
*
M_WARP_COUNT
+
m_idx
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
M_MMAC_COUNT
;
++
min_tile_m
)
{
pv_reg
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
=
flash
::
mmac
<
Element
,
ElementAccum
>
(
p_vgprs
[
min_tile_m
],
v_vgprs
[
n_idx
*
PV_K_WARP_COUNT
+
k_idx
][
/*min_tile_k * 2 + */
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
);
}
}
}
}
}
asm
volatile
(
"s_setprio 0"
);
}
// Ping-pong buffer between ds and vgpr
{
constexpr
int
min_tile_k
=
1
;
vec4_Element
<
Element
>
p_vgprs
[
2
];
for
(
int
min_tile_m
=
0
;
min_tile_m
<
M_MMAC_COUNT
;
++
min_tile_m
)
{
p_vgprs
[
min_tile_m
]
=
make_vec4_f16
(
p_reg
[
0
][
0
*
2
+
min_tile_m
].
f16
[
min_tile_k
*
2
+
0
/*vec_idx*/
],
p_reg
[
0
][
1
*
2
+
min_tile_m
].
f16
[
min_tile_k
*
2
+
0
],
p_reg
[
0
][
0
*
2
+
min_tile_m
].
f16
[
min_tile_k
*
2
+
1
],
p_reg
[
0
][
1
*
2
+
min_tile_m
].
f16
[
min_tile_k
*
2
+
1
]
);
}
asm
volatile
(
"s_setprio 1"
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)"
);
__builtin_amdgcn_sched_barrier
(
0
);
#pragma unroll
for
(
int
k_idx
=
0
;
k_idx
<
PV_K_WARP_COUNT
;
++
k_idx
)
{
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
PV_N_WARP_COUNT
;
++
n_idx
)
{
int
v_tile_id
=
stage_id
*
PV_K_WARP_COUNT
*
PV_N_WARP_COUNT
+
n_idx
*
PV_K_WARP_COUNT
+
k_idx
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
v_vgprs
[
n_idx
*
PV_K_WARP_COUNT
+
k_idx
][
/*min_tile_k * 2 + */
min_tile_n
]
=
vec4_Element
<
Element
>
{
v_reg
[
v_tile_id
][
0
+
min_tile_k
*
2
/*vec_idx*/
].
f16x2
[
0
][
min_tile_n
],
v_reg
[
v_tile_id
][
0
+
min_tile_k
*
2
].
f16x2
[
1
][
min_tile_n
],
v_reg
[
v_tile_id
][
1
+
min_tile_k
*
2
].
f16x2
[
0
][
min_tile_n
],
v_reg
[
v_tile_id
][
1
+
min_tile_k
*
2
].
f16x2
[
1
][
min_tile_n
]
};
}
}
}
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
M_WARP_COUNT
;
++
m_idx
)
{
#pragma unroll
for
(
int
k_idx
=
0
;
k_idx
<
PV_K_WARP_COUNT
;
++
k_idx
)
{
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
PV_N_WARP_COUNT
;
++
n_idx
)
{
int
n_loop_idx
=
n_loop
;
int
pv_tile_id
=
n_loop_idx
*
M_WARP_COUNT
*
PV_N_WARP_COUNT
+
n_idx
*
M_WARP_COUNT
+
m_idx
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
M_MMAC_COUNT
;
++
min_tile_m
)
{
pv_reg
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
=
flash
::
mmac
<
Element
,
ElementAccum
>
(
p_vgprs
[
min_tile_m
],
v_vgprs
[
n_idx
*
PV_K_WARP_COUNT
+
k_idx
][
/*min_tile_k * 2 + */
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
);
}
}
}
}
}
asm
volatile
(
"s_setprio 0"
);
// asm volatile("s_barrier ; sync before load in the coming round");
}
}
__syncthreads
();
// Here K/V use more LDS and therefore reuse it together, so a sync is needed
}
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment