Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
flash-attention
Commits
518a5f4d
Commit
518a5f4d
authored
Jun 09, 2026
by
hly
Browse files
import aicc-master-dev
parent
c2a1b310
Changes
131
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
705 additions
and
120 deletions
+705
-120
csrc/flash_attn_hg/include/kvcache/kvcache_qk_gemm_prefetch_v_3stage.h
...tn_hg/include/kvcache/kvcache_qk_gemm_prefetch_v_3stage.h
+1
-1
csrc/flash_attn_hg/include/kvcache/kvcache_qk_gemm_prefetch_v_tile16x32.h
...hg/include/kvcache/kvcache_qk_gemm_prefetch_v_tile16x32.h
+1
-1
csrc/flash_attn_hg/include/kvcache/kvcache_softmax.h
csrc/flash_attn_hg/include/kvcache/kvcache_softmax.h
+13
-14
csrc/flash_attn_hg/include/mla/gfx938/f16_mla_tp8_pv_gemm_gfx938.h
...h_attn_hg/include/mla/gfx938/f16_mla_tp8_pv_gemm_gfx938.h
+210
-0
csrc/flash_attn_hg/include/mla/gfx938/f16_mla_tp8_pv_gemm_utils_gfx938.h
..._hg/include/mla/gfx938/f16_mla_tp8_pv_gemm_utils_gfx938.h
+53
-0
csrc/flash_attn_hg/include/mla/gfx938/f16_mla_tp8_qk_gemm_gfx938.h
...h_attn_hg/include/mla/gfx938/f16_mla_tp8_qk_gemm_gfx938.h
+270
-0
csrc/flash_attn_hg/include/mla/gfx938/f16_mla_tp8_qk_gemm_utils_gfx938.h
..._hg/include/mla/gfx938/f16_mla_tp8_qk_gemm_utils_gfx938.h
+52
-0
csrc/flash_attn_hg/include/mla/gfx938/fp8_mla_acco_reduce_gfx938.h
...h_attn_hg/include/mla/gfx938/fp8_mla_acco_reduce_gfx938.h
+2
-2
csrc/flash_attn_hg/include/mla/gfx938/fp8_mla_epilogue_gfx938.h
...lash_attn_hg/include/mla/gfx938/fp8_mla_epilogue_gfx938.h
+2
-2
csrc/flash_attn_hg/include/mla/gfx938/fp8_mla_softmax_gfx938.h
...flash_attn_hg/include/mla/gfx938/fp8_mla_softmax_gfx938.h
+2
-2
csrc/flash_attn_hg/include/mla/gfx938/fp8_mla_tp8_pv_gemm_prefetch_k_gfx938.h
...nclude/mla/gfx938/fp8_mla_tp8_pv_gemm_prefetch_k_gfx938.h
+16
-16
csrc/flash_attn_hg/include/mla/gfx938/mla_epilogue_tile16x32_lit.h
...h_attn_hg/include/mla/gfx938/mla_epilogue_tile16x32_lit.h
+50
-13
csrc/flash_attn_hg/include/mla/gfx938/mla_pv_gemm_prefetch_k_mls_ds.h
...ttn_hg/include/mla/gfx938/mla_pv_gemm_prefetch_k_mls_ds.h
+4
-13
csrc/flash_attn_hg/include/mla/gfx938/mla_pv_gemm_utils_mls_ds.h
...ash_attn_hg/include/mla/gfx938/mla_pv_gemm_utils_mls_ds.h
+1
-4
csrc/flash_attn_hg/include/mla/gfx938/mla_qk_gemm_prefetch_v_mls_ds.h
...ttn_hg/include/mla/gfx938/mla_qk_gemm_prefetch_v_mls_ds.h
+4
-10
csrc/flash_attn_hg/include/mla/gfx938/mla_qk_gemm_utils_mls_ds.h
...ash_attn_hg/include/mla/gfx938/mla_qk_gemm_utils_mls_ds.h
+3
-9
csrc/flash_attn_hg/include/mla/gfx938/mla_softmax_gfx938.h
csrc/flash_attn_hg/include/mla/gfx938/mla_softmax_gfx938.h
+15
-16
csrc/flash_attn_hg/include/mla/gfx938/mla_tp8_qk_gemm_utils_gfx938.h
...attn_hg/include/mla/gfx938/mla_tp8_qk_gemm_utils_gfx938.h
+3
-9
csrc/flash_attn_hg/include/mla/mla_epilogue.h
csrc/flash_attn_hg/include/mla/mla_epilogue.h
+2
-2
csrc/flash_attn_hg/include/mla/mla_epilogue_tile16x32.h
csrc/flash_attn_hg/include/mla/mla_epilogue_tile16x32.h
+1
-6
No files found.
csrc/flash_attn_hg/include/kvcache/kvcache_qk_gemm_prefetch_v_3stage.h
View file @
518a5f4d
...
...
@@ -36,7 +36,7 @@ __forceinline__ __device__ void kvcache_qk_gemm_prefetch_v_3stage(
constexpr
int
ELEMENT_BYTES
=
sizeof
(
Element
);
// load 指令发下去之后, 先做一些初始化运算
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
if
constexpr
(
M_MMAC_COUNT
==
1
)
{
inline_vgpr4_init_zero_1x2x4
(
s_reg
);
}
else
{
...
...
csrc/flash_attn_hg/include/kvcache/kvcache_qk_gemm_prefetch_v_tile16x32.h
View file @
518a5f4d
...
...
@@ -28,7 +28,7 @@ __forceinline__ __device__ void kvcache_qk_gemm_prefetch_v_tile16x32(
int
laneid_shfl_4
=
lane_id
>>
4
;
int
laneid_and_15
=
lane_id
&
15
;
#if defined(__gfx936__) || defined(__gfx938__)
// >= bmz
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
int
qk_lane_m_idx
=
lane_id
>>
2
;
int
qk_lane_head_dim_idx
=
(
lane_id
&
3
)
<<
2
;
auto
BUFFER_LOAD_FUNC
=
&
inline_buffer_load_dwordx4_lds
<
Element
,
2
>
;
...
...
csrc/flash_attn_hg/include/kvcache/kvcache_softmax.h
View file @
518a5f4d
...
...
@@ -218,7 +218,7 @@ __device__ inline void kvcache_thread_reduce_sum(const DataType0 tensor[(WARP_M
if
(
zero_init
==
true
)
{
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
(
WARP_M
/
32
);
++
m_idx
)
{
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
summary
[
m_idx
*
2
].
u64
=
0x0
;
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
(
WARP_N
/
32
);
++
n_idx
)
{
...
...
@@ -227,7 +227,7 @@ __device__ inline void kvcache_thread_reduce_sum(const DataType0 tensor[(WARP_M
#pragma unroll
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
++
vec_idx
)
{
__float2
additem_pair
=
{
tensor
[
m_idx
+
n_idx
*
(
WARP_M
/
32
)][
min_tile_n
*
2
].
f32
[
vec_idx
],
tensor
[
m_idx
+
n_idx
*
(
WARP_M
/
32
)][
min_tile_n
*
2
+
1
].
f32
[
vec_idx
]};
summary
[
m_idx
*
2
].
u64
=
hcu_pk_add_f32
(
summary
[
m_idx
*
2
].
u64
=
__builtin_
hcu_pk_add_f32
(
summary
[
m_idx
*
2
].
u64
,
additem_pair
);
...
...
@@ -254,7 +254,7 @@ __device__ inline void kvcache_thread_reduce_sum(const DataType0 tensor[(WARP_M
}
else
{
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
(
WARP_M
/
32
);
++
m_idx
)
{
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
summary_cur
[
m_idx
*
2
].
u64
=
summary
[
m_idx
*
2
].
u64
;
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
(
WARP_N
/
32
);
++
n_idx
)
{
...
...
@@ -262,7 +262,7 @@ __device__ inline void kvcache_thread_reduce_sum(const DataType0 tensor[(WARP_M
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
++
vec_idx
)
{
__float2
additem_pair
=
{
tensor
[
m_idx
+
n_idx
*
(
WARP_M
/
32
)][
min_tile_n
*
2
].
f32
[
vec_idx
],
tensor
[
m_idx
+
n_idx
*
(
WARP_M
/
32
)][
min_tile_n
*
2
+
1
].
f32
[
vec_idx
]};
summary_cur
[
m_idx
*
2
].
u64
=
hcu_pk_add_f32
(
summary_cur
[
m_idx
*
2
].
u64
=
__builtin_
hcu_pk_add_f32
(
summary_cur
[
m_idx
*
2
].
u64
,
additem_pair
);
...
...
@@ -362,16 +362,15 @@ inline __device__ void kvcache_scale_apply_exp2(DataType0 tensor[(WARP_M / 32) *
// min tile is 32 * 32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
#pragma unroll
for
(
int
vec_idx
=
0
;
vec_idx
<
2
;
++
vec_idx
)
{
tensor
[
mi
+
ni
*
(
WARP_M
/
32
)][
min_tile_n
*
2
+
min_tile_m
].
u64
[
vec_idx
]
=
hcu_pk_fma_f32
(
tensor
[
mi
+
ni
*
(
WARP_M
/
32
)][
min_tile_n
*
2
+
min_tile_m
].
u64
[
vec_idx
]
=
__builtin_
hcu_pk_fma_f32
(
tensor
[
mi
+
ni
*
(
WARP_M
/
32
)][
min_tile_n
*
2
+
min_tile_m
].
u64
[
vec_idx
],
scale_pair
,
neg_max_scaled_pair
);
}
asm
volatile
(
"s_nop 0"
:::
"memory"
);
#pragma unroll
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
++
vec_idx
)
{
tensor
[
mi
+
ni
*
(
WARP_M
/
32
)][
min_tile_n
*
2
+
min_tile_m
].
f32
[
vec_idx
]
=
__llvm_exp2_f32
(
tensor
[
mi
+
ni
*
(
WARP_M
/
32
)][
min_tile_n
*
2
+
min_tile_m
].
f32
[
vec_idx
]);
...
...
@@ -448,10 +447,10 @@ inline __device__ void kvcache_softmax_rescale_o(DataType0 scores[(WARP_N / 32)
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
// 936 及之后的架构有 pk_mul 指令
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
#pragma unroll
for
(
int
vec_idx
=
0
;
vec_idx
<
2
;
++
vec_idx
)
{
acc_o
[
pv_n_loop
*
(
WARP_M
/
32
)
*
(
kBlockK
/
32
)
+
(
mi
+
ni
*
(
WARP_M
/
32
))][
min_tile_n
*
2
+
min_tile_m
].
u64
[
vec_idx
]
=
hcu_pk_mul_f32
(
acc_o
[
pv_n_loop
*
(
WARP_M
/
32
)
*
(
kBlockK
/
32
)
+
(
mi
+
ni
*
(
WARP_M
/
32
))][
min_tile_n
*
2
+
min_tile_m
].
u64
[
vec_idx
]
=
__builtin_
hcu_pk_mul_f32
(
acc_o
[
pv_n_loop
*
(
WARP_M
/
32
)
*
(
kBlockK
/
32
)
+
(
mi
+
ni
*
(
WARP_M
/
32
))][
min_tile_n
*
2
+
min_tile_m
].
u64
[
vec_idx
],
scores_scale_pair
);
...
...
@@ -503,8 +502,8 @@ inline __device__ void kvcache_softmax_rescale_o(DataType0 scores[(WARP_N / 32)
#pragma unroll
for
(
int
warp_loop
=
1
;
warp_loop
<
WARP_NUM
;
warp_loop
++
)
{
__float2
other_warp_sum
=
*
(
__float2
*
)(
sum_lds
+
warp_loop
*
WARP_M
+
mi
*
32
+
lane_id
*
2
);
#if defined(__gfx936__) || defined(__gfx938__)
cur_wave_sum
=
hcu_pk_add_f32
(
cur_wave_sum
,
other_warp_sum
);
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
cur_wave_sum
=
__builtin_
hcu_pk_add_f32
(
cur_wave_sum
,
other_warp_sum
);
#else
cur_wave_sum
[
0
]
+=
other_warp_sum
[
0
];
cur_wave_sum
[
1
]
+=
other_warp_sum
[
1
];
...
...
@@ -528,8 +527,8 @@ inline __device__ void kvcache_softmax_rescale_o(DataType0 scores[(WARP_N / 32)
}
for
(
int
mi
=
0
;
mi
<
(
WARP_M
/
32
);
++
mi
)
{
#if defined(__gfx936__) || defined(__gfx938__)
scores_sum
[
mi
].
u64
=
hcu_pk_add_f32
(
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
scores_sum
[
mi
].
u64
=
__builtin_
hcu_pk_add_f32
(
scores_sum
[
mi
].
u64
,
scores_sum_cur
[
mi
].
u64
);
...
...
@@ -558,7 +557,7 @@ inline __device__ void kvcache_convert_pk_type(union_vec2_f16x2<Element> p_reg[(
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
++
min_tile_m
)
{
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
#if defined(__gfx938__)
#if defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__) || defined(__gfx92a__)
p_reg
[
n_idx
*
(
WARP_M
/
32
)
+
m_idx
][
0
*
2
+
min_tile_m
].
f16x2
[
min_tile_k
]
=
DownCastPair
<
float
,
Element
>
(
s_reg
[
n_idx
*
(
WARP_M
/
32
)
+
m_idx
][
0
*
2
+
min_tile_m
].
f32x2
[
min_tile_k
]);
p_reg
[
n_idx
*
(
WARP_M
/
32
)
+
m_idx
][
1
*
2
+
min_tile_m
].
f16x2
[
min_tile_k
]
=
DownCastPair
<
float
,
Element
>
(
...
...
csrc/flash_attn_hg/include/mla/gfx938/f16_mla_tp8_pv_gemm_gfx938.h
0 → 100644
View file @
518a5f4d
#pragma once
#include "f16_mla_tp8_pv_gemm_utils_gfx938.h"
/**
 * P*V GEMM stage for the f16 MLA (tp8) kernel on gfx938.
 *
 * For each 32-wide block along the N direction, a V tile is brought into LDS
 * with inline_matrix_load_32x32_b16_lds, read back into registers with
 * DS_READ_MATRIX_32X32_B16_ALT2, and multiplied with the P fragments in
 * `p_reg`, accumulating into `pv_reg` through flash::mmac_4interleave.
 * Blocks are walked from the highest index down to 0.
 *
 * With STAGES == 2 two blocks are loaded per iteration while the two blocks
 * loaded previously are consumed; the final pair is consumed after the loop.
 * NOTE(review): the first pair consumed by the loop must already be in flight
 * when this function is entered (presumably issued by
 * f16_mla_tp8_prefetch_v_to_lds_gfx938 with stage_id 0) -- confirm with caller.
 *
 * The function calls __syncthreads() on entry and on exit, so every thread of
 * the block has to reach the call.
 *
 * Template parameters kBlockM and WARP_NUM, and the arguments `k_addr` and
 * `k_lds`, are not referenced in this body.
 *
 * @param v_addr                 Base descriptor of V; word 1 is copied into the
 *                               load descriptor, and the low 64 bits are used
 *                               as the base address.
 * @param v_lds                  LDS area receiving the V tiles.
 * @param p_reg                  P fragments (left operand of the mmac).
 * @param pv_reg                 Accumulators, one row per N block.
 * @param warp_id                Index of the calling wave; selects the LDS slot
 *                               and the row range in V.
 * @param kvcache_seqlen_stride  Written to descriptor word 2 and used as the
 *                               row stride (in elements) of the global offset.
 * @param max_seq_kv_offset      Used to derive the row filter placed in
 *                               descriptor word 3 (see below).
 */
template <int K_LOOP_COUNT, int kBlockM, int kBlockN, int kBlockK, int M_WARP_COUNT, int PV_N_WARP_COUNT, int PV_K_WARP_COUNT, int STAGES, int WARP_NUM, int M_MMAC_COUNT, typename Element, typename ElementAccum>
__forceinline__ __device__ void f16_mla_tp8_pv_gemm_gfx938(
    vec4_uint v_addr,
    vec4_uint k_addr,
    Element *v_lds,
    Element *k_lds,
    union_vec2_f16x2<Element> p_reg[M_WARP_COUNT * PV_K_WARP_COUNT][4],
    vec4_Accum<ElementAccum> pv_reg[K_LOOP_COUNT * M_WARP_COUNT * (kBlockN / 32)][4],
    int warp_id,
    int kvcache_seqlen_stride,
    int max_seq_kv_offset = 0)
{
    constexpr int WARP_K = PV_K_WARP_COUNT * 32;
    static_assert(kBlockK >= 32, "Error: pv gemm kBlockK must be equal or greater than 32");
    static_assert(kBlockN == PV_N_WARP_COUNT * 32, "Error: kBlockN in kvcache_pv_gemm_prefetch_k must be WARP_N * 32");
    static_assert(M_WARP_COUNT == 1, "for gfx938, only WARP_M = 32 is supported yet!");
    static_assert(PV_N_WARP_COUNT == 1, "for gfx938, only WARP_N = 32 is supported yet!");
    static_assert(PV_K_WARP_COUNT == 1, "for gfx938, only WARP_K = 32 is supported yet!");
    constexpr int V_LOAD_REQUESTS = (WARP_K * kBlockN) / (32 * 32);

    // Registers for one 32x32 tile of halves used by the mmac: each thread holds
    // 16 halves, i.e. 8 * 2 (8 halves per column, two columns).
    union_vec4_f16x2<Element> v_reg[1 * PV_K_WARP_COUNT * PV_N_WARP_COUNT * 2];

    // Prepare the MLS resource descriptor.
    vec4_uint v_srsrc;
    v_srsrc[1] = v_addr[1];
    v_srsrc[2] = kvcache_seqlen_stride; // stride

    // Avoid clashing with the LDS used by the multi-wave reduce-max.
    __syncthreads();

    // With two stages the loads issued here go to stage 1 first.
    int stage_id = (STAGES == 2) ? 1 : 0;

    // Number of blocks loaded per iteration.
    constexpr int N_LOOP_STEP = (STAGES == 2) ? 2 : 1;
    constexpr int N_LOOP_START = (STAGES == 2) ? K_LOOP_COUNT - N_LOOP_STEP * 2 : K_LOOP_COUNT - 1;
    constexpr int N_LOOP_END = 0;

    for (int n_loop = N_LOOP_START; n_loop >= N_LOOP_END; n_loop -= N_LOOP_STEP)
    {
        #pragma unroll
        for (int prefetch_id = 0; prefetch_id < N_LOOP_STEP; ++prefetch_id)
        {
            // Byte offset of the 32x32 block this wave loads in this step.
            int v_mls_warp_global_offset = (n_loop + prefetch_id) * kBlockN * sizeof(Element);
            // Byte offset inside LDS this wave writes to (relative to v_lds, not to the start of smem).
            int v_mls_lds_warp_offset = (warp_id * STAGES * 2 + stage_id * 2 + prefetch_id) * (V_LOAD_REQUESTS * 32 * 32) * sizeof(Element);
            // Byte offset of the first row this wave reads.
            int v_mls_loop_global_offset; // = warp_id * WARP_K * kvcache_seqlen_stride * sizeof(Element);

            // Compute the global address for the MLS load and handle the boundary.
            if constexpr (true)
            {
                // Detects whether this wave would fetch only empty data.
                int nm_filter_max = warp_id * WARP_K + 32 - max_seq_kv_offset;
                // Fetching empty data is not supported on 938: fall back to wave 0's rows.
                int real_mls_warp_id = nm_filter_max >= 32 ? 0 : warp_id;
                v_mls_loop_global_offset = real_mls_warp_id * WARP_K * kvcache_seqlen_stride * sizeof(Element);
                // In the fallback case this is wave 0's nm_filter value.
                int nm_filter = inline_min_max<0, 32>(real_mls_warp_id * WARP_K + 32 - max_seq_kv_offset);
                // The filter is only applied when max_seq_kv_offset is not a multiple of kBlockN.
                v_srsrc[3] = max_seq_kv_offset % kBlockN == 0 ? 0 : nm_filter << 8;
                // NOTE(review): meaning of the 0x20000 descriptor bit is not visible here -- confirm.
                v_srsrc[3] += 0x20000;
            }

            // v_srsrc[0] = v_addr[0] + v_mls_loop_global_offset + v_mls_warp_global_offset;
            // Overwrites descriptor words 0 and 1 with the 64-bit address.
            *(uint64_t *)&v_srsrc = VA_LIMIT_BITS(*(uint64_t *)&v_addr + v_mls_loop_global_offset + v_mls_warp_global_offset);

            __builtin_amdgcn_sched_barrier(0);
            inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, v_mls_lds_warp_offset, 0);
            __builtin_amdgcn_sched_barrier(0);
        }

        // Wait for the MLS data to arrive.
        // NOTE(review): the template argument is presumably the number of loads
        // allowed to remain outstanding -- confirm against the helper.
        if constexpr (N_LOOP_STEP == 2)
        {
            buffer_load_lds_dwordx1_wait_nosync<3 * V_LOAD_REQUESTS>();
        }
        // NOTE(review): this branch cannot be taken, STAGES == 2 implies N_LOOP_STEP == 2.
        else if constexpr (N_LOOP_STEP == 1 and STAGES == 2)
        {
            buffer_load_lds_dwordx1_wait_nosync<V_LOAD_REQUESTS>();
        }
        // NOTE(review): for STAGES other than 1 or 2 no wait is emitted at all.
        else if constexpr (N_LOOP_STEP == 1 and STAGES == 1)
        {
            buffer_load_lds_dwordx1_wait_nosync<0>();
        }
        __builtin_amdgcn_sched_barrier(0);

        // Switch to the stage that is consumed in this round.
        if constexpr (STAGES == 2)
        {
            stage_id ^= 1;
        }

        // NOTE(review): the trailing "* 2" assumes sizeof(Element) == 2; the write side uses sizeof(Element).
        int lds_load_offset = reinterpret_cast<size_t>(v_lds) + (warp_id * STAGES * 2 + stage_id * 2) * (V_LOAD_REQUESTS * 32 * 32) * 2 /*bytes*/;
        DS_READ_MATRIX_32X32_B16_ALT2(lds_load_offset, v_reg[0].f16, v_reg[1].f16, false);

        // hint: multiple prefetching can be applied here
        // Split in two halves: 8 ds_read instructions are issued in total; the first 4
        // are permuted and fed to the mmac, then the remaining 4 are waited for, and so on.
        #pragma unroll
        for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
        {
            __builtin_amdgcn_sched_barrier(0);
            // lgkmcnt(1) for the first half, lgkmcnt(0) for the second.
            asm volatile("s_waitcnt lgkmcnt(%0)" ::"B"(2 - min_tile_k - 1));
            __builtin_amdgcn_sched_barrier(0);
            asm volatile("s_setprio 2");

            // With two stages the tile consumed now was loaded one iteration earlier.
            int pv_tile_id = (STAGES == 2) ? n_loop + 2 : n_loop;
            #pragma unroll
            for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n)
            {
                #pragma unroll
                for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m)
                {
                    pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = flash::mmac_4interleave<Element, ElementAccum>(
                        p_reg[0][min_tile_k * 2 + min_tile_m].f16x4,
                        v_reg[min_tile_k].f16x4[min_tile_n],
                        pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
                }
            }
            asm volatile("s_setprio 0");
        }

        // ============================================================================================================
        // Process the second prefetched block.
        if constexpr (N_LOOP_STEP == 2)
        {
            __builtin_amdgcn_sched_barrier(0);
            buffer_load_lds_dwordx1_wait_nosync<2 * V_LOAD_REQUESTS>();
            __builtin_amdgcn_sched_barrier(0);

            int lds_load_offset = reinterpret_cast<size_t>(v_lds) + (warp_id * STAGES * 2 + stage_id * 2 + 1 /*prefetch_id*/) * (V_LOAD_REQUESTS * 32 * 32) * 2 /*bytes*/;
            __builtin_amdgcn_sched_barrier(0);
            DS_READ_MATRIX_32X32_B16_ALT2(lds_load_offset, v_reg[0].f16, v_reg[1].f16, false);

            // Same two-half ds_read / mmac interleaving as above.
            #pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
            {
                __builtin_amdgcn_sched_barrier(0);
                asm volatile("s_waitcnt lgkmcnt(%0)" ::"B"(2 - min_tile_k - 1));
                __builtin_amdgcn_sched_barrier(0);
                asm volatile("s_setprio 2");

                int pv_tile_id = (STAGES == 2) ? n_loop + 3 : n_loop;
                #pragma unroll
                for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n)
                {
                    #pragma unroll
                    for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m)
                    {
                        pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = flash::mmac_4interleave<Element, ElementAccum>(
                            p_reg[0][min_tile_k * 2 + min_tile_m].f16x4,
                            v_reg[min_tile_k].f16x4[min_tile_n],
                            pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
                    }
                }
                asm volatile("s_setprio 0");
            }
        }
    }

    // Drain: consume the last pair of blocks (indices 0 and 1) loaded by the final iteration.
    if constexpr (STAGES == 2)
    {
        // Wait for the MLS data to arrive.
        __builtin_amdgcn_sched_barrier(0);
        buffer_load_lds_dwordx1_wait_nosync<1 * V_LOAD_REQUESTS>();
        __builtin_amdgcn_sched_barrier(0);

        // Equals -2 here, so n_loop + N_LOOP_STEP below is block 0.
        int n_loop = N_LOOP_END - N_LOOP_STEP;
        // Switch stage.
        stage_id ^= 1;

        int lds_load_offset = reinterpret_cast<size_t>(v_lds) + (warp_id * STAGES * 2 + stage_id * 2) * (V_LOAD_REQUESTS * 32 * 32) * 2 /*bytes*/;
        DS_READ_MATRIX_32X32_B16_ALT2(lds_load_offset, v_reg[0].f16, v_reg[1].f16, false);

        // Same two-half ds_read / mmac interleaving as above.
        #pragma unroll
        for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
        {
            __builtin_amdgcn_sched_barrier(0);
            asm volatile("s_waitcnt lgkmcnt(%0)" ::"B"(2 - min_tile_k - 1));
            __builtin_amdgcn_sched_barrier(0);
            asm volatile("s_setprio 2");

            int pv_tile_id = n_loop + N_LOOP_STEP;
            #pragma unroll
            for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n)
            {
                #pragma unroll
                for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m)
                {
                    pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = flash::mmac_4interleave<Element, ElementAccum>(
                        p_reg[0][min_tile_k * 2 + min_tile_m].f16x4,
                        v_reg[min_tile_k].f16x4[min_tile_n],
                        pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
                }
            }
            asm volatile("s_setprio 0");
        }

        // ============================================================================================================
        // Process the second prefetched block.
        if constexpr (N_LOOP_STEP == 2)
        {
            __builtin_amdgcn_sched_barrier(0);
            buffer_load_lds_dwordx1_wait_nosync<0 * V_LOAD_REQUESTS>();
            __builtin_amdgcn_sched_barrier(0);

            int lds_load_offset = reinterpret_cast<size_t>(v_lds) + (warp_id * STAGES * 2 + stage_id * 2 + 1 /*prefetch_id*/) * (V_LOAD_REQUESTS * 32 * 32) * 2 /*bytes*/;
            __builtin_amdgcn_sched_barrier(0);
            DS_READ_MATRIX_32X32_B16_ALT2(lds_load_offset, v_reg[0].f16, v_reg[1].f16, false);

            // Same two-half ds_read / mmac interleaving as above.
            #pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
            {
                __builtin_amdgcn_sched_barrier(0);
                asm volatile("s_waitcnt lgkmcnt(%0)" ::"B"(2 - min_tile_k - 1));
                __builtin_amdgcn_sched_barrier(0);
                asm volatile("s_setprio 2");

                int pv_tile_id = n_loop + N_LOOP_STEP + 1;
                #pragma unroll
                for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n)
                {
                    #pragma unroll
                    for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m)
                    {
                        pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32 = flash::mmac_4interleave<Element, ElementAccum>(
                            p_reg[0][min_tile_k * 2 + min_tile_m].f16x4,
                            v_reg[min_tile_k].f16x4[min_tile_n],
                            pv_reg[pv_tile_id][min_tile_n * 2 + min_tile_m].f32);
                    }
                }
                asm volatile("s_setprio 0");
            }
        }
    }

    __syncthreads(); // K and V need more LDS here and share the same region, so a barrier is required
}
csrc/flash_attn_hg/include/mla/gfx938/f16_mla_tp8_pv_gemm_utils_gfx938.h
0 → 100644
View file @
518a5f4d
#pragma once // prepare for prefetch V in qk gemm
#include "intrinsic.h"
#include "fwd/utils.h"
#include "intrinsic_mls_ds.h"
/**
 * Issues the first two V tile loads (global memory -> LDS) for the PV GEMM.
 *
 * The two blocks loaded are the last two 32-wide blocks along the head
 * dimension, i.e. indices kHeadDim / kBlockN - 2 and kHeadDim / kBlockN - 1.
 * They are written to the LDS slots selected by `warp_id`, the compile-time
 * `stage_id` and the prefetch index. The function only issues the loads; it
 * does not wait for them.
 *
 * Template parameters kBlockM, kBlockK, WARP_M, WARP_N and WARP_NUM are not
 * referenced in this body.
 *
 * @param v_addr                 Base descriptor of V (word 1 is copied, the low
 *                               64 bits are used as base address).
 * @param v_lds                  LDS area receiving the tiles.
 * @param warp_id                Index of the calling wave.
 * @param kvcache_seqlen_stride  Written to descriptor word 2 and used as the row
 *                               stride (in elements) of the global offset.
 * @param max_seq_kv_offset      Used to derive the row filter in descriptor word 3.
 */
template <int kHeadDim, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int WARP_K, int stage_id, int WARP_NUM, typename Element, int STAGES>
__forceinline__ __device__ void f16_mla_tp8_prefetch_v_to_lds_gfx938(
    vec4_uint v_addr,
    Element *v_lds,
    int warp_id,
    int kvcache_seqlen_stride,
    int max_seq_kv_offset = 0)
{
    constexpr int kTilesPerCall = 2;
    constexpr int kRequestsPerTile = (WARP_K * kBlockN) / (32 * 32);

    // Start from the second-to-last block.
    const int first_block = kHeadDim / kBlockN - kTilesPerCall;

    // MLS resource descriptor; words 1 and 2 stay fixed for both loads.
    vec4_uint desc;
    desc[1] = v_addr[1];
    desc[2] = kvcache_seqlen_stride; // stride

    #pragma unroll
    for (int tile = 0; tile < kTilesPerCall; ++tile)
    {
        // Byte offset of the 32x32 block loaded in this step.
        int block_bytes = (first_block + tile) * kBlockN * sizeof(Element);
        // Destination byte offset inside LDS (relative to v_lds).
        int lds_bytes = (warp_id * STAGES * 2 + stage_id * 2 + tile) * (kRequestsPerTile * 32 * 32) * sizeof(Element);

        // A wave whose rows would all be empty cannot issue such a load on 938;
        // it falls back to wave 0's rows and wave 0's filter value.
        const int filter_probe = warp_id * WARP_K + 32 - max_seq_kv_offset;
        const int source_warp = (filter_probe >= 32) ? 0 : warp_id;
        int row_bytes = source_warp * WARP_K * kvcache_seqlen_stride * sizeof(Element);
        const int row_filter = inline_min_max<0, 32>(source_warp * WARP_K + 32 - max_seq_kv_offset);

        desc[3] = (max_seq_kv_offset % kBlockN == 0) ? 0 : (row_filter << 8);
        desc[3] += 0x20000;

        // Descriptor words 0/1 <- 64-bit address of the tile.
        *(uint64_t *)&desc = VA_LIMIT_BITS(*(uint64_t *)&v_addr + row_bytes + block_bytes);

        __builtin_amdgcn_sched_barrier(0);
        inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, desc, lds_bytes, 0);
        __builtin_amdgcn_sched_barrier(0);
    }
}
\ No newline at end of file
csrc/flash_attn_hg/include/mla/gfx938/f16_mla_tp8_qk_gemm_gfx938.h
0 → 100644
View file @
518a5f4d
#pragma once
#include "f16_mla_tp8_pv_gemm_utils_gfx938.h"
/**
 * Q*K^T GEMM stage for the f16 MLA (tp8) kernel on gfx938.
 *
 * Walks the head dimension two 32-wide blocks at a time. Each K block is
 * loaded into LDS with inline_matrix_load_32x32_b16_lds_trans, read back with
 * DS_READ_MATRIX_32X32_B16 and multiplied with the matching Q fragment from
 * `q_reg`, accumulating into `s_reg` (zeroed on entry) via mmac_4interleave.
 *
 * With STAGES == 2 the loop starts at block 2 and consumes the pair loaded one
 * iteration earlier; the last pair is consumed after the loop.
 * NOTE(review): blocks 0 and 1 must already be in flight on entry (presumably
 * issued by f16_mla_tp8_prefetch_k_to_lds_gfx938) -- confirm with the caller.
 *
 * After the GEMM the function executes __syncthreads() (every thread of the
 * block must reach it) and, when STAGES > 1, issues the first V loads through
 * f16_mla_tp8_prefetch_v_to_lds_gfx938.
 *
 * `q_addr` and `q_lds` are not referenced in this body.
 *
 * @param k_addr                Base descriptor of K.
 * @param v_addr                Base descriptor of V, forwarded to the V prefetch.
 * @param k_lds                 LDS area receiving the K tiles.
 * @param v_lds                 LDS area forwarded to the V prefetch.
 * @param q_reg                 Q fragments, indexed by head-dimension block.
 * @param s_reg                 Score accumulators (output).
 * @param warp_id               Index of the calling wave.
 * @param kcache_seqlen_stride  Row stride of K (descriptor word 2 / global offset).
 * @param vcache_seqlen_stride  Row stride of V, forwarded to the V prefetch.
 * @param max_seq_k_offset      Used to derive the row filter; also forwarded to
 *                              the V prefetch as its max_seq_kv_offset.
 */
template <int kHeadDim, int kHeadDimV, int kBlockM, int kBlockN, int kBlockK, int WARP_M, int WARP_N, int WARP_NUM, int STAGES, int M_MMAC_COUNT, typename Element, typename ElementAccum>
__forceinline__ __device__ void f16_mla_tp8_qk_gemm_gfx938(
    vec4_uint q_addr,
    vec4_uint k_addr,
    vec4_uint v_addr,
    Element *q_lds,
    Element *k_lds,
    Element *v_lds,
    union_vec4_f16x2<Element> q_reg[(kHeadDim / kBlockK) * (WARP_M * kBlockK) / (32 * 32) * 2],
    vec4_Accum<ElementAccum> s_reg[(WARP_M / 32) * (WARP_N / 32)][4],
    int warp_id,
    int kcache_seqlen_stride,
    int vcache_seqlen_stride,
    int max_seq_k_offset = 0)
{
    // Fix: use the two-argument form so the text is reported as the diagnostic
    // message instead of being folded into the condition.
    static_assert(kBlockK == 32, "To simplify, only kBlockK = 32 is supported! otherwise, restore q_warp_buffer_load_k_id and so on");
    constexpr int K_LOAD_REQUESTS = (WARP_N / 32) * (kBlockK / 32);

    // Registers for the K operand of the mmac: one load brings 32x32 halves,
    // each thread holds 16 of them.
    union_vec4_f16x2<Element> k_reg[1 * (WARP_N * kBlockK) / (32 * 32) * 2];

    // Zero the score accumulators.
    // Fix: the bound was (WARP_N / WARP_N) * (WARP_M / 32), which skips rows of
    // s_reg whenever WARP_N > 32; the array has (WARP_N / 32) * (WARP_M / 32) rows.
    uint64_t pk_zero = 0;
    #pragma unroll
    for (int i = 0; i < (WARP_N / 32) * (WARP_M / 32); ++i)
    {
        #pragma unroll
        for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m)
        {
            #pragma unroll
            for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n)
            {
                s_reg[i][min_tile_n * 2 + min_tile_m].u64[0] = pk_zero;
                s_reg[i][min_tile_n * 2 + min_tile_m].u64[1] = pk_zero;
            }
        }
    }

    // Prepare the MLS resource descriptor.
    vec4_uint k_srsrc;
    k_srsrc[1] = k_addr[1];
    k_srsrc[2] = kcache_seqlen_stride;

    int stage_id = 0;
    constexpr int K_LOOP_START = (STAGES == 2) ? 2 : 0;
    // With two stages the loads issued by the loop go to stage 1 first.
    if constexpr (STAGES == 2)
        stage_id ^= 1;

    for (int k_loop = K_LOOP_START; k_loop < (kHeadDim / kBlockK); k_loop += 2)
    {
        #pragma unroll
        for (int prefetch_id = 0; prefetch_id < 2; ++prefetch_id)
        {
            // LDS start (in elements) this wave writes to.
            int k_lds_stage_offset = (warp_id * STAGES * 2 + stage_id * 2 + prefetch_id) * K_LOAD_REQUESTS * (32 * 32);
            // Byte offset along kHeadDim: which 32x32 block is being read.
            int k_mls_loop_global_offset = (k_loop + prefetch_id) * kBlockK * sizeof(Element);
            // Byte offset of the first row this wave reads from global memory.
            int k_mls_warp_global_offset; // = warp_id * WARP_N * kcache_seqlen_stride * sizeof(Element);
            if constexpr (true)
            {
                // Detects whether this wave would fetch only empty data.
                int nm_filter_max = warp_id * WARP_N + 32 - max_seq_k_offset;
                // Fetching empty data is not supported on 938: fall back to wave 0's rows.
                int real_mls_warp_id = nm_filter_max >= 32 ? 0 : warp_id;
                k_mls_warp_global_offset = real_mls_warp_id * WARP_N * kcache_seqlen_stride * sizeof(Element);
                // In the fallback case this is wave 0's nm_filter value.
                int nm_filter = inline_min_max<0, 32>(real_mls_warp_id * WARP_N + 32 - max_seq_k_offset);
                k_srsrc[3] = nm_filter << 8;
            }
            // Byte address of the global load, derived from the offsets above.
            // k_srsrc[0] = k_addr[0] + k_mls_loop_global_offset + k_mls_warp_global_offset;
            *(uint64_t *)&k_srsrc = VA_LIMIT_BITS(*(uint64_t *)&k_addr + k_mls_loop_global_offset + k_mls_warp_global_offset);
            // NOTE(review): hard-coded 2 assumes sizeof(Element) == 2.
            int lds_offset_bytes = k_lds_stage_offset * 2 /*half -> bytes*/;
            inline_matrix_load_32x32_b16_lds_trans<0, 0>(k_lds, k_srsrc, lds_offset_bytes, 0);
        }

        // Wait for the MLS data to arrive.
        __builtin_amdgcn_sched_barrier(0);
        if constexpr (STAGES == 2)
        {
            buffer_load_lds_dwordx1_wait_nosync<3 * K_LOAD_REQUESTS>();
        }
        else
        {
            buffer_load_lds_dwordx1_wait_nosync<0>();
        }
        __builtin_amdgcn_sched_barrier(0);
        // __builtin_amdgcn_sched_barrier(0);
        if constexpr (STAGES == 2)
            stage_id ^= 1;

        // Read the tile written to LDS by the earlier MLS load into registers.
        int lds_load_offset = reinterpret_cast<size_t>(k_lds) + (warp_id * STAGES * 2 + stage_id * 2) * K_LOAD_REQUESTS * (32 * 32) * sizeof(Element) /*half -> bytes*/;
        DS_READ_MATRIX_32X32_B16(lds_load_offset, k_reg[0].f16, k_reg[1].f16, true);

        #pragma unroll
        for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n)
        {
            __builtin_amdgcn_sched_barrier(0);
            // lgkmcnt(1) for the first half, lgkmcnt(0) for the second.
            asm volatile("s_waitcnt lgkmcnt(%0)\n" ::"B"(2 - min_tile_n - 1));
            __builtin_amdgcn_sched_barrier(0);
            asm volatile("s_setprio 1");
            #pragma unroll
            for (int m_idx = 0; m_idx < WARP_M / 32; ++m_idx)
            {
                #pragma unroll
                for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx)
                {
                    #pragma unroll
                    for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx)
                    {
                        #pragma unroll
                        for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
                        {
                            #pragma unroll
                            for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m)
                            {
                                // Two-stage mode consumes the block loaded one iteration earlier.
                                int k_loop_idx = (STAGES == 2) ? k_loop - 2 : k_loop;
                                int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
                                int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
                                s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
                                    q_reg[q_tile_id].f16x4[min_tile_k],
                                    k_reg[k_tile_id].f16x4[min_tile_k],
                                    s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
                            }
                        }
                    }
                }
            }
            __builtin_amdgcn_sched_barrier(0);
            asm volatile("s_setprio 0");
        }

        // =============================================================================================================
        // Second block of the pair.
        {
            __builtin_amdgcn_sched_barrier(0);
            if constexpr (STAGES == 2)
            {
                buffer_load_lds_dwordx1_wait_nosync<2 * K_LOAD_REQUESTS>();
            }
            else
            {
                buffer_load_lds_dwordx1_wait_nosync<0>();
            }
            __builtin_amdgcn_sched_barrier(0);

            // Read the second tile (LDS slot +1) into registers.
            int lds_load_offset = reinterpret_cast<size_t>(k_lds) + (warp_id * STAGES * 2 + stage_id * 2 + 1) * K_LOAD_REQUESTS * (32 * 32) * sizeof(Element) /*half -> bytes*/;
            DS_READ_MATRIX_32X32_B16(lds_load_offset, k_reg[0].f16, k_reg[1].f16, true);

            #pragma unroll
            for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n)
            {
                __builtin_amdgcn_sched_barrier(0);
                asm volatile("s_waitcnt lgkmcnt(%0)\n" ::"B"(2 - min_tile_n - 1));
                __builtin_amdgcn_sched_barrier(0);
                asm volatile("s_setprio 1");
                #pragma unroll
                for (int m_idx = 0; m_idx < WARP_M / 32; ++m_idx)
                {
                    #pragma unroll
                    for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx)
                    {
                        #pragma unroll
                        for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx)
                        {
                            #pragma unroll
                            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
                            {
                                #pragma unroll
                                for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m)
                                {
                                    // Fix: in the non-2-stage path the tile read from LDS slot 1
                                    // is block k_loop + 1, so the Q fragment must be k_loop + 1
                                    // as well (it used to be k_loop). The 2-stage path is unchanged.
                                    int k_loop_idx = (STAGES == 2) ? k_loop - 1 : k_loop + 1;
                                    int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
                                    int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
                                    s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
                                        q_reg[q_tile_id].f16x4[min_tile_k],
                                        k_reg[k_tile_id].f16x4[min_tile_k],
                                        s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
                                }
                            }
                        }
                    }
                }
                __builtin_amdgcn_sched_barrier(0);
                asm volatile("s_setprio 0");
            }
        }
    }

    // Drain: consume the last pair of K blocks loaded by the final iteration.
    if constexpr (STAGES == 2)
    {
        // Wait for the MLS data to arrive.
        __builtin_amdgcn_sched_barrier(0);
        buffer_load_lds_dwordx1_wait_nosync<1>();
        __builtin_amdgcn_sched_barrier(0);

        // Switch to the stage that was written last.
        stage_id ^= 1;

        // Read the remaining data from LDS.
        int lds_load_offset = reinterpret_cast<size_t>(k_lds) + (warp_id * STAGES * 2 + stage_id * 2) * K_LOAD_REQUESTS * (32 * 32) * sizeof(Element) /*half -> bytes*/;
        DS_READ_MATRIX_32X32_B16(lds_load_offset, k_reg[0].f16, k_reg[1].f16, true);

        #pragma unroll
        for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n)
        {
            __builtin_amdgcn_sched_barrier(0);
            asm volatile("s_waitcnt lgkmcnt(%0)\n" ::"B"(2 - min_tile_n - 1));
            __builtin_amdgcn_sched_barrier(0);
            asm volatile("s_setprio 1");
            #pragma unroll
            for (int m_idx = 0; m_idx < WARP_M / 32; ++m_idx)
            {
                #pragma unroll
                for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx)
                {
                    #pragma unroll
                    for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx)
                    {
                        #pragma unroll
                        for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
                        {
                            #pragma unroll
                            for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m)
                            {
                                int k_loop_idx = kHeadDim / kBlockK - 2;
                                int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
                                int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
                                s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
                                    q_reg[q_tile_id].f16x4[min_tile_k],
                                    k_reg[k_tile_id].f16x4[min_tile_k],
                                    s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
                            }
                        }
                    }
                }
            }
            __builtin_amdgcn_sched_barrier(0);
            asm volatile("s_setprio 0");
        }

        // ==========================================================================
        // Last block.
        {
            __builtin_amdgcn_sched_barrier(0);
            buffer_load_lds_dwordx1_wait_nosync<0>();
            __builtin_amdgcn_sched_barrier(0);

            // Read the remaining data from LDS (slot +1).
            int lds_load_offset = reinterpret_cast<size_t>(k_lds) + (warp_id * STAGES * 2 + stage_id * 2 + 1) * K_LOAD_REQUESTS * (32 * 32) * sizeof(Element) /*half -> bytes*/;
            DS_READ_MATRIX_32X32_B16(lds_load_offset, k_reg[0].f16, k_reg[1].f16, true);

            #pragma unroll
            for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n)
            {
                __builtin_amdgcn_sched_barrier(0);
                asm volatile("s_waitcnt lgkmcnt(%0)\n" ::"B"(2 - min_tile_n - 1));
                __builtin_amdgcn_sched_barrier(0);
                asm volatile("s_setprio 1");
                #pragma unroll
                for (int m_idx = 0; m_idx < WARP_M / 32; ++m_idx)
                {
                    #pragma unroll
                    for (int n_idx = 0; n_idx < WARP_N / 32; ++n_idx)
                    {
                        #pragma unroll
                        for (int head_dim_idx = 0; head_dim_idx < kBlockK / 32; ++head_dim_idx)
                        {
                            #pragma unroll
                            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k)
                            {
                                #pragma unroll
                                for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m)
                                {
                                    int k_loop_idx = kHeadDim / kBlockK - 1;
                                    int q_tile_id = k_loop_idx * (WARP_M * kBlockK) / (32 * 32) * 2 + (head_dim_idx * (WARP_M / 32) + m_idx) * 2 + min_tile_m;
                                    int k_tile_id = (head_dim_idx * (WARP_N / 32) + n_idx) * 2 + min_tile_n;
                                    s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
                                        q_reg[q_tile_id].f16x4[min_tile_k],
                                        k_reg[k_tile_id].f16x4[min_tile_k],
                                        s_reg[n_idx * (WARP_M / 32) + m_idx][min_tile_n * 2 + min_tile_m].f32);
                                }
                            }
                        }
                    }
                }
                __builtin_amdgcn_sched_barrier(0);
                asm volatile("s_setprio 0");
            }
        }
    }

    // need to reduce results on scores_max and prefetch V, and thus sync
    __syncthreads();

    // The V loads needed next can be issued now, ahead of their use in the PV GEMM.
    // The PV stage's N/K block sizes are this stage's kBlockK/kBlockN, hence the swapped order.
    if constexpr (STAGES > 1)
    {
        f16_mla_tp8_prefetch_v_to_lds_gfx938<kHeadDimV, kBlockM, kBlockK, kBlockN, WARP_M, kBlockK, 32 /*WARP_K*/, 0, WARP_NUM, Element, STAGES>(
            v_addr, v_lds, warp_id, vcache_seqlen_stride, max_seq_k_offset);
    }
}
// qk_gemm
csrc/flash_attn_hg/include/mla/gfx938/f16_mla_tp8_qk_gemm_utils_gfx938.h
0 → 100644
View file @
518a5f4d
#pragma once
#include "intrinsic.h"
#include "fwd/utils.h"
#include "intrinsic_mls_ds.h"
/**
 * Issues the two initial K prefetches (global memory -> LDS) for the calling wave.
 *
 * For each of the two prefetch slots of stage 0 the function
 *   1. computes where in LDS the slot of this wave starts,
 *   2. computes the byte offset of the source data in global memory
 *      (offset along the head dimension plus the per-warp row offset),
 *   3. fills in the resource descriptor `k_srsrc`, and
 *   4. issues one `inline_matrix_load_32x32_b16_lds_trans<0, 0>` followed by a
 *      scheduling barrier.
 *
 * @tparam kBlockK   Block extent along the head dimension, in elements.
 * @tparam WARP_N    Number of K rows handled by one warp.
 * @tparam Element   Element type; the LDS byte offset below assumes 2 bytes per element.
 * @tparam STAGES    Number of pipeline stages reserved per warp in LDS.
 * @tparam WARP_NUM  Not referenced inside this function.
 *
 * @param k_addr                 Source descriptor; its low 64 bits are used as the base
 *                               address and word 1 is copied into the load descriptor.
 * @param k_lds                  LDS destination base pointer.
 * @param warp_id                Index of the calling warp.
 * @param kvcache_seqlen_stride  Row stride of the K cache, in elements (it is multiplied
 *                               by sizeof(Element) when forming the byte offset).
 * @param max_seq_k_offset       Sequence bound used to derive the row filter.
 *                               NOTE(review): with the default value 0 the fallback
 *                               condition below holds for every non-negative warp_id,
 *                               so all warps read warp 0's rows - confirm this is intended.
 */
template <int kBlockK, int WARP_N, typename Element, int STAGES, int WARP_NUM>
__forceinline__ __device__ void f16_mla_tp8_prefetch_k_to_lds_gfx938(
    vec4_uint k_addr,
    Element* k_lds,
    int warp_id,
    int kvcache_seqlen_stride,
    int max_seq_k_offset = 0)
{
    // Prepare the MLS descriptor registers.
    vec4_uint k_srsrc;
    k_srsrc[1] = k_addr[1];
    k_srsrc[2] = kvcache_seqlen_stride;

    // First stage of the ping-pong buffer.
    constexpr int kFirstStage = 0;
    // Index of the first 32x32 block along the kHeadDim direction.
    constexpr int kFirstKLoop = 0;

#pragma unroll
    for (int slot = 0; slot < 2; ++slot)
    {
        // Start of this wave's destination region in LDS, in elements.
        const int lds_slot_index = warp_id * STAGES * 2 + kFirstStage * 2 + slot;
        const int lds_elem_offset = lds_slot_index * (WARP_N / 32) * (kBlockK / 32) * (32 * 32);

        // Byte offset of the block this wave reads along the kHeadDim direction.
        const int head_dim_byte_offset = (kFirstKLoop + slot) * kBlockK * sizeof(Element);

        // Decide whether this warp would fetch empty data.
        const int filter_upper = warp_id * WARP_N + 32 - max_seq_k_offset;
        int source_warp = warp_id;
        if (filter_upper >= 32)
        {
            // Fetching empty data is not supported on 938: fall back to warp 0's data.
            source_warp = 0;
        }

        // Byte offset of the rows this wave reads from global memory.
        const int warp_byte_offset = source_warp * WARP_N * kvcache_seqlen_stride * sizeof(Element);

        // Row filter; when falling back, warp 0's filter value is used.
        const int row_filter = inline_min_max<0, 32>(source_warp * WARP_N + 32 - max_seq_k_offset);
        k_srsrc[3] = row_filter << 8;

        // Low 64 bits of the descriptor: base address plus both byte offsets.
        // This overwrites words 0 and 1 of k_srsrc.
        *(uint64_t*)&k_srsrc =
            VA_LIMIT_BITS(*(uint64_t*)&k_addr + head_dim_byte_offset + warp_byte_offset);

        const int lds_byte_offset = lds_elem_offset * 2 /*half -> bytes*/;
        inline_matrix_load_32x32_b16_lds_trans<0, 0>(k_lds, k_srsrc, lds_byte_offset, 0);
        __builtin_amdgcn_sched_barrier(0);
    }
}
\ No newline at end of file
csrc/flash_attn_hg/include/mla/gfx938/fp8_mla_acco_reduce_gfx938.h
View file @
518a5f4d
...
...
@@ -104,8 +104,8 @@ __forceinline__ __device__ void fp8_mla_acco_reduce_tile16x32(
data
.
f32
[
1
]
=
acc_o_lds
[
neighbor
*
2048
+
warp_id
*
2
*
16
*
16
+
min_tile_n
*
16
*
16
+
lane_id
+
1
*
64
];
data
.
f32
[
2
]
=
acc_o_lds
[
neighbor
*
2048
+
warp_id
*
2
*
16
*
16
+
min_tile_n
*
16
*
16
+
lane_id
+
2
*
64
];
data
.
f32
[
3
]
=
acc_o_lds
[
neighbor
*
2048
+
warp_id
*
2
*
16
*
16
+
min_tile_n
*
16
*
16
+
lane_id
+
3
*
64
];
acc_o
[
k_loop
+
0
][
min_tile_n
*
2
].
u64
[
0
]
=
hcu_pk_add_f32
(
acc_o
[
k_loop
+
0
][
min_tile_n
*
2
].
u64
[
0
],
data
.
u64
[
0
]);
acc_o
[
k_loop
+
0
][
min_tile_n
*
2
].
u64
[
1
]
=
hcu_pk_add_f32
(
acc_o
[
k_loop
+
0
][
min_tile_n
*
2
].
u64
[
1
],
data
.
u64
[
1
]);
acc_o
[
k_loop
+
0
][
min_tile_n
*
2
].
u64
[
0
]
=
__builtin_
hcu_pk_add_f32
(
acc_o
[
k_loop
+
0
][
min_tile_n
*
2
].
u64
[
0
],
data
.
u64
[
0
]);
acc_o
[
k_loop
+
0
][
min_tile_n
*
2
].
u64
[
1
]
=
__builtin_
hcu_pk_add_f32
(
acc_o
[
k_loop
+
0
][
min_tile_n
*
2
].
u64
[
1
],
data
.
u64
[
1
]);
}
}
__syncthreads
();
...
...
csrc/flash_attn_hg/include/mla/gfx938/fp8_mla_epilogue_gfx938.h
View file @
518a5f4d
...
...
@@ -22,9 +22,9 @@ __forceinline__ __device__ void fp8_mla_epilugue_rescale_acco_gfx938(
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
int
mmac_id
=
min_tile_n
*
2
+
min_tile_m
;
int
tile_32x32_id
=
pv_n_loop
*
M_WARP_COUNT
*
K_WARP_COUNT
+
(
ni
*
M_WARP_COUNT
+
mi
);
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
for
(
int
vec_id
=
0
;
vec_id
<
2
;
++
vec_id
)
{
acc_o
[
tile_32x32_id
][
mmac_id
].
u64
[
vec_id
]
=
hcu_pk_mul_f32
(
acc_o
[
tile_32x32_id
][
mmac_id
].
u64
[
vec_id
]
=
__builtin_
hcu_pk_mul_f32
(
acc_o
[
tile_32x32_id
][
mmac_id
].
u64
[
vec_id
],
scale_pair
);
...
...
csrc/flash_attn_hg/include/mla/gfx938/fp8_mla_softmax_gfx938.h
View file @
518a5f4d
...
...
@@ -75,8 +75,8 @@ inline __device__ void fp8_mla_apply_descale_gfx938(DataType tensor[M_WARP_COUNT
for
(
int
min_tile_m
=
0
;
min_tile_m
<
M_MMAC_COUNT
;
++
min_tile_m
)
{
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
tensor
[
i
][
min_tile_n
*
2
+
min_tile_m
].
u64
[
0
]
=
hcu_pk_mul_f32
(
tensor
[
i
][
min_tile_n
*
2
+
min_tile_m
].
u64
[
0
],
qk_descale
);
tensor
[
i
][
min_tile_n
*
2
+
min_tile_m
].
u64
[
1
]
=
hcu_pk_mul_f32
(
tensor
[
i
][
min_tile_n
*
2
+
min_tile_m
].
u64
[
1
],
qk_descale
);
tensor
[
i
][
min_tile_n
*
2
+
min_tile_m
].
u64
[
0
]
=
__builtin_
hcu_pk_mul_f32
(
tensor
[
i
][
min_tile_n
*
2
+
min_tile_m
].
u64
[
0
],
qk_descale
);
tensor
[
i
][
min_tile_n
*
2
+
min_tile_m
].
u64
[
1
]
=
__builtin_
hcu_pk_mul_f32
(
tensor
[
i
][
min_tile_n
*
2
+
min_tile_m
].
u64
[
1
],
qk_descale
);
}
}
}
...
...
csrc/flash_attn_hg/include/mla/gfx938/fp8_mla_tp8_pv_gemm_prefetch_k_gfx938.h
View file @
518a5f4d
...
...
@@ -88,16 +88,16 @@ __forceinline__ __device__ void fp8_mla_tp8_pv_gemm_prefetch_k_gfx938(
for
(
int
min_tile_dim
=
0
;
min_tile_dim
<
2
;
++
min_tile_dim
)
{
// fp8 -> f32
vec2_fp32
v_f32x2
[
4
];
// 8 fp8 -> 8 f32, for 1 mmac
v_f32x2
[
0
]
=
hcu_cvt_pk_f32_fp8
<
0
>
(
v_regs
[
tile32x32_id
].
i32
[
min_tile_dim
*
2
+
0
]);
v_f32x2
[
1
]
=
hcu_cvt_pk_f32_fp8
<
2
>
(
v_regs
[
tile32x32_id
].
i32
[
min_tile_dim
*
2
+
0
]);
v_f32x2
[
2
]
=
hcu_cvt_pk_f32_fp8
<
0
>
(
v_regs
[
tile32x32_id
].
i32
[
min_tile_dim
*
2
+
1
]);
v_f32x2
[
3
]
=
hcu_cvt_pk_f32_fp8
<
2
>
(
v_regs
[
tile32x32_id
].
i32
[
min_tile_dim
*
2
+
1
]);
v_f32x2
[
0
]
=
__builtin_
hcu_cvt_pk_f32_fp8
(
v_regs
[
tile32x32_id
].
i32
[
min_tile_dim
*
2
+
0
]
,
false
/*word_sel*/
);
v_f32x2
[
1
]
=
__builtin_
hcu_cvt_pk_f32_fp8
(
v_regs
[
tile32x32_id
].
i32
[
min_tile_dim
*
2
+
0
]
,
true
/*word_sel*/
);
v_f32x2
[
2
]
=
__builtin_
hcu_cvt_pk_f32_fp8
(
v_regs
[
tile32x32_id
].
i32
[
min_tile_dim
*
2
+
1
]
,
false
/*word_sel*/
);
v_f32x2
[
3
]
=
__builtin_
hcu_cvt_pk_f32_fp8
(
v_regs
[
tile32x32_id
].
i32
[
min_tile_dim
*
2
+
1
]
,
true
/*word_sel*/
);
// f32 -> fp16
union_vec4_f16x2
<
P_Element
>
v_f16x8
;
v_f16x8
.
f16x2
[
0
]
=
hcu_cvt_pk_f16_f32
<
false
,
0
>
(
v_f32x2
[
0
][
0
],
v_f32x2
[
0
][
1
]);
v_f16x8
.
f16x2
[
1
]
=
hcu_cvt_pk_f16_f32
<
false
,
0
>
(
v_f32x2
[
1
][
0
],
v_f32x2
[
1
][
1
]);
v_f16x8
.
f16x2
[
2
]
=
hcu_cvt_pk_f16_f32
<
false
,
0
>
(
v_f32x2
[
2
][
0
],
v_f32x2
[
2
][
1
]);
v_f16x8
.
f16x2
[
3
]
=
hcu_cvt_pk_f16_f32
<
false
,
0
>
(
v_f32x2
[
3
][
0
],
v_f32x2
[
3
][
1
]);
v_f16x8
.
f16x2
[
0
]
=
__builtin_
hcu_cvt_pk_f16_f32
(
v_f32x2
[
0
][
0
],
v_f32x2
[
0
][
1
]
,
false
/*clamp*/
,
0
/*o_modifier*/
);
v_f16x8
.
f16x2
[
1
]
=
__builtin_
hcu_cvt_pk_f16_f32
(
v_f32x2
[
1
][
0
],
v_f32x2
[
1
][
1
]
,
false
/*clamp*/
,
0
/*o_modifier*/
);
v_f16x8
.
f16x2
[
2
]
=
__builtin_
hcu_cvt_pk_f16_f32
(
v_f32x2
[
2
][
0
],
v_f32x2
[
2
][
1
]
,
false
/*clamp*/
,
0
/*o_modifier*/
);
v_f16x8
.
f16x2
[
3
]
=
__builtin_
hcu_cvt_pk_f16_f32
(
v_f32x2
[
3
][
0
],
v_f32x2
[
3
][
1
]
,
false
/*clamp*/
,
0
/*o_modifier*/
);
// mmac_16x16x16, 4 fp16
#pragma unroll
for
(
int
mmac_id
=
0
;
mmac_id
<
2
;
++
mmac_id
)
{
...
...
@@ -151,16 +151,16 @@ __forceinline__ __device__ void fp8_mla_tp8_pv_gemm_prefetch_k_gfx938(
for
(
int
min_tile_dim
=
0
;
min_tile_dim
<
2
;
++
min_tile_dim
)
{
// fp8 -> f32
vec2_fp32
v_f32x2
[
4
];
// 8 fp8 -> 8 f32, for 1 mmac
v_f32x2
[
0
]
=
hcu_cvt_pk_f32_fp8
<
0
>
(
v_regs
[
tile32x32_id
].
i32
[
min_tile_dim
*
2
+
0
]);
v_f32x2
[
1
]
=
hcu_cvt_pk_f32_fp8
<
2
>
(
v_regs
[
tile32x32_id
].
i32
[
min_tile_dim
*
2
+
0
]);
v_f32x2
[
2
]
=
hcu_cvt_pk_f32_fp8
<
0
>
(
v_regs
[
tile32x32_id
].
i32
[
min_tile_dim
*
2
+
1
]);
v_f32x2
[
3
]
=
hcu_cvt_pk_f32_fp8
<
2
>
(
v_regs
[
tile32x32_id
].
i32
[
min_tile_dim
*
2
+
1
]);
v_f32x2
[
0
]
=
__builtin_
hcu_cvt_pk_f32_fp8
(
v_regs
[
tile32x32_id
].
i32
[
min_tile_dim
*
2
+
0
]
,
false
/*word_sel*/
);
v_f32x2
[
1
]
=
__builtin_
hcu_cvt_pk_f32_fp8
(
v_regs
[
tile32x32_id
].
i32
[
min_tile_dim
*
2
+
0
]
,
true
/*word_sel*/
);
v_f32x2
[
2
]
=
__builtin_
hcu_cvt_pk_f32_fp8
(
v_regs
[
tile32x32_id
].
i32
[
min_tile_dim
*
2
+
1
]
,
false
/*word_sel*/
);
v_f32x2
[
3
]
=
__builtin_
hcu_cvt_pk_f32_fp8
(
v_regs
[
tile32x32_id
].
i32
[
min_tile_dim
*
2
+
1
]
,
true
/*word_sel*/
);
// f32 -> fp16
union_vec4_f16x2
<
P_Element
>
v_f16x8
;
v_f16x8
.
f16x2
[
0
]
=
hcu_cvt_pk_f16_f32
<
false
,
0
>
(
v_f32x2
[
0
][
0
],
v_f32x2
[
0
][
1
]);
v_f16x8
.
f16x2
[
1
]
=
hcu_cvt_pk_f16_f32
<
false
,
0
>
(
v_f32x2
[
1
][
0
],
v_f32x2
[
1
][
1
]);
v_f16x8
.
f16x2
[
2
]
=
hcu_cvt_pk_f16_f32
<
false
,
0
>
(
v_f32x2
[
2
][
0
],
v_f32x2
[
2
][
1
]);
v_f16x8
.
f16x2
[
3
]
=
hcu_cvt_pk_f16_f32
<
false
,
0
>
(
v_f32x2
[
3
][
0
],
v_f32x2
[
3
][
1
]);
v_f16x8
.
f16x2
[
0
]
=
__builtin_
hcu_cvt_pk_f16_f32
(
v_f32x2
[
0
][
0
],
v_f32x2
[
0
][
1
]
,
false
/*clamp*/
,
0
/*o_modifier*/
);
v_f16x8
.
f16x2
[
1
]
=
__builtin_
hcu_cvt_pk_f16_f32
(
v_f32x2
[
1
][
0
],
v_f32x2
[
1
][
1
]
,
false
/*clamp*/
,
0
/*o_modifier*/
);
v_f16x8
.
f16x2
[
2
]
=
__builtin_
hcu_cvt_pk_f16_f32
(
v_f32x2
[
2
][
0
],
v_f32x2
[
2
][
1
]
,
false
/*clamp*/
,
0
/*o_modifier*/
);
v_f16x8
.
f16x2
[
3
]
=
__builtin_
hcu_cvt_pk_f16_f32
(
v_f32x2
[
3
][
0
],
v_f32x2
[
3
][
1
]
,
false
/*clamp*/
,
0
/*o_modifier*/
);
// mmac_16x16x16, 4 fp16
#pragma unroll
for
(
int
mmac_id
=
0
;
mmac_id
<
2
;
++
mmac_id
)
{
...
...
csrc/flash_attn_hg/include/mla/gfx938/mla_epilogue_tile16x32_lit.h
View file @
518a5f4d
...
...
@@ -34,9 +34,9 @@ __forceinline__ __device__ void prefill_mla_epilugue_rescale_acco(
#pragma unroll
for
(
int
pv_n_loop
=
0
;
pv_n_loop
<
(
kHeadDimV
/
kBlockK
);
++
pv_n_loop
)
{
const
int
pv_tile_id
=
pv_n_loop
*
(
WARP_M
/
(
16
*
M_MMAC_COUNT
))
*
(
kBlockK
/
32
)
+
ni
*
(
WARP_M
/
(
16
*
M_MMAC_COUNT
))
+
mi
;
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
for
(
int
vec_id
=
0
;
vec_id
<
2
;
++
vec_id
)
{
acc_o
[
pv_tile_id
][
mmac_id
].
u64
[
vec_id
]
=
hcu_pk_mul_f32
(
acc_o
[
pv_tile_id
][
mmac_id
].
u64
[
vec_id
]
=
__builtin_
hcu_pk_mul_f32
(
acc_o
[
pv_tile_id
][
mmac_id
].
u64
[
vec_id
],
scale_pair
);
...
...
@@ -115,8 +115,35 @@ __forceinline__ __device__ void prefill_mla_epilogue_store_output(
int
pv_lane_head_dim_idx
=
lane_id
>>
4
;
if
constexpr
(
Is_Interleaved
)
{
#if defined(__gfx92a__) && defined(YY_USE_MPERMUTE)
union_vec2_f16x2
<
Element
>
acc_o_fp16
[(
kHeadDimV
/
kBlockK
)
*
(
WARP_M
/
(
16
*
M_MMAC_COUNT
))
*
(
kBlockK
/
32
)][
2
*
M_MMAC_COUNT
];
#pragma unroll
for
(
int
k_loop
=
0
;
k_loop
<
(
kHeadDimV
/
kBlockK
);
++
k_loop
)
{
#pragma unroll 2
for
(
int
min_tile_m
=
0
;
min_tile_m
<
M_MMAC_COUNT
;
++
min_tile_m
)
{
#pragma unroll 2
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
int
mmac_id
;
if
constexpr
(
M_MMAC_COUNT
==
2
)
mmac_id
=
min_tile_m
+
min_tile_n
*
2
;
else
mmac_id
=
min_tile_n
;
#pragma unroll
for
(
int
vec_index
=
0
;
vec_index
<
2
;
++
vec_index
)
{
// convert float -> bf16/fp16
acc_o_fp16
[
k_loop
][
mmac_id
].
f16x2
[
vec_index
]
=
DownCastPair
<
ElementAccum
,
Element
>
(
acc_o
[
k_loop
][
mmac_id
].
f32x2
[
vec_index
]);
}
ds_mpermute_kdim_for_mmac
(
acc_o_fp16
[
k_loop
][
mmac_id
].
f32
);
}
}
}
#endif
#pragma unroll
for
(
int
k_loop
=
0
;
k_loop
<
(
kHeadDimV
/
kBlockK
);
++
k_loop
)
{
#if defined(__gfx92a__) && defined(YY_USE_MPERMUTE)
flash
::
wait_lds_data_arrived
<
false
>
((
kHeadDimV
/
kBlockK
-
k_loop
-
1
)
*
2
*
2
);
#endif
#pragma unroll
for
(
int
warp_m_idx
=
0
;
warp_m_idx
<
(
WARP_M
/
(
16
*
M_MMAC_COUNT
));
++
warp_m_idx
)
{
#pragma unroll
...
...
@@ -137,6 +164,15 @@ __forceinline__ __device__ void prefill_mla_epilogue_store_output(
// prepare for store
int
s_offset
=
k_tile_idx
*
32
+
min_tile_n
*
16
;
int
v_offset
=
seqlen_q_offset
*
seqlen_o_stride
+
k_loop
*
kBlockK
+
pv_lane_head_dim_idx
*
4
;
#if defined(__gfx92a__) && defined(YY_USE_MPERMUTE)
if
constexpr
(
not
Is_even_MN
)
{
if
(
m_block
*
kBlockM
+
seqlen_q_offset
<
seqlen_q_limit
)
{
*
(
union_vec2_f16x2
<
Element
>*
)(
o_ptr
+
v_offset
+
s_offset
)
=
acc_o_fp16
[
k_loop
][
mmac_id
];
}
}
else
{
*
(
union_vec2_f16x2
<
Element
>*
)(
o_ptr
+
v_offset
+
s_offset
)
=
acc_o_fp16
[
k_loop
][
mmac_id
];
}
#else
union_vec2_f16x2
<
Element
>
v_data
;
#pragma unroll
for
(
int
vec_index
=
0
;
vec_index
<
2
;
++
vec_index
)
{
...
...
@@ -150,6 +186,7 @@ __forceinline__ __device__ void prefill_mla_epilogue_store_output(
}
else
{
*
(
union_vec2_f16x2
<
Element
>*
)(
o_ptr
+
v_offset
+
s_offset
)
=
v_data
;
}
#endif
}
}
}
...
...
csrc/flash_attn_hg/include/mla/gfx938/mla_pv_gemm_prefetch_k_mls_ds.h
View file @
518a5f4d
...
...
@@ -59,10 +59,7 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512(
}
int
lds_offset
=
(
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
;
flash
::
wait_all_warp_arrived
();
union
union_vec4_uint
v_rsrc_bits
;
v_rsrc_bits
.
v32
=
v_srsrc
;
size_t
lds_addr_warp
=
reinterpret_cast
<
size_t
>
(
v_lds
)
+
lds_offset
;
matrix_load_b16_lds_builtin
<
32
,
32
,
1
,
0
>
(
lds_addr_warp
,
v_rsrc_bits
.
i32
,
0
);
inline_matrix_load_32x32_b16_lds
<
0
,
1
>
(
v_lds
,
v_srsrc
,
lds_offset
,
0
);
}
// DS
...
...
@@ -136,10 +133,7 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512(
}
int
lds_offset
=
(
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
;
flash
::
wait_all_warp_arrived
();
// 预防有 warp 还没算完7,还在读 v lds, 若是此时写 v lds,则 data cover
union
union_vec4_uint
v_rsrc_bits
;
v_rsrc_bits
.
v32
=
v_srsrc
;
size_t
lds_addr_warp
=
reinterpret_cast
<
size_t
>
(
v_lds
)
+
lds_offset
;
matrix_load_b16_lds_builtin
<
32
,
32
,
1
,
0
>
(
lds_addr_warp
,
v_rsrc_bits
.
i32
,
0
);
inline_matrix_load_32x32_b16_lds
<
0
,
1
>
(
v_lds
,
v_srsrc
,
lds_offset
,
0
);
}
}
stage_id
^=
1
;
...
...
@@ -200,10 +194,7 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds_576_512(
}
int
lds_offset
=
(
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
;
flash
::
wait_all_warp_arrived
();
union
union_vec4_uint
v_rsrc_bits
;
v_rsrc_bits
.
v32
=
v_srsrc
;
size_t
lds_addr_warp
=
reinterpret_cast
<
size_t
>
(
v_lds
)
+
lds_offset
;
matrix_load_b16_lds_builtin
<
32
,
32
,
1
,
0
>
(
lds_addr_warp
,
v_rsrc_bits
.
i32
,
0
);
inline_matrix_load_32x32_b16_lds
<
0
,
1
>
(
v_lds
,
v_srsrc
,
lds_offset
,
0
);
}
lds_stage_id
^=
1
;
...
...
csrc/flash_attn_hg/include/mla/gfx938/mla_pv_gemm_utils_mls_ds.h
View file @
518a5f4d
...
...
@@ -29,9 +29,6 @@ __forceinline__ __device__ void prefetch_v_to_lds_mls_ds_576_512(
int
lds_offset
=
(
lds_stage_id
*
WARP_K
*
kHeadDim_OPT
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
;
flash
::
wait_all_warp_arrived
();
// 防止写 v lds 和读 q lds k lds 冲突, qk 可能有的 warp 没结束
union
union_vec4_uint
v_rsrc_bits
;
v_rsrc_bits
.
v32
=
v_srsrc
;
size_t
lds_addr_warp
=
reinterpret_cast
<
size_t
>
(
v_lds
)
+
lds_offset
;
matrix_load_b16_lds_builtin
<
32
,
32
,
1
,
0
>
(
lds_addr_warp
,
v_rsrc_bits
.
i32
,
0
);
inline_matrix_load_32x32_b16_lds
<
0
,
1
>
(
v_lds
,
v_srsrc
,
lds_offset
,
0
);
__builtin_amdgcn_sched_barrier
(
0
);
}
csrc/flash_attn_hg/include/mla/gfx938/mla_qk_gemm_prefetch_v_mls_ds.h
View file @
518a5f4d
...
...
@@ -106,10 +106,7 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512(
q_srsrc
[
3
]
=
max_seq_q_offset
%
kBlockM
==
0
?
0
:
nm_filter
<<
8
;
int
lds_offset
=
(
q_stage_id
*
kBlockM
*
kBlockK
+
warp_id
*
16
*
32
)
*
ELEMENT_BYTES
;
flash
::
wait_all_warp_arrived
();
union
union_vec4_uint
q_rsrc_bits
;
q_rsrc_bits
.
v32
=
q_srsrc
;
size_t
lds_addr_warp
=
reinterpret_cast
<
size_t
>
(
q_lds
)
+
lds_offset
;
matrix_load_b16_lds_trans_builtin
<
32
,
16
,
1
,
0
>
(
lds_addr_warp
,
q_rsrc_bits
.
i32
,
0
);
inline_matrix_load_32x16_b16_lds_trans
<
0
,
1
>
(
q_lds
,
q_srsrc
,
lds_offset
,
0
);
if
(
k_even
)
{
k_stage_id
^=
1
;
...
...
@@ -122,10 +119,7 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512(
k_srsrc
[
3
]
=
(
max_seq_k_offset
%
kBlockN
==
0x0
?
0
:
nm_filter
)
<<
8
;
int
lds_offset
=
(
k_stage_id
*
WARP_N
*
kHeadDim_OPT
+
warp_id
*
32
*
16
)
*
ELEMENT_BYTES
;
flash
::
wait_all_warp_arrived
();
union
union_vec4_uint
k_rsrc_bits
;
k_rsrc_bits
.
v32
=
k_srsrc
;
size_t
lds_addr_warp
=
reinterpret_cast
<
size_t
>
(
k_lds
)
+
lds_offset
;
matrix_load_b16_lds_trans_builtin
<
32
,
16
,
0
,
0
>
(
lds_addr_warp
,
k_rsrc_bits
.
i32
,
0
);
inline_matrix_load_32x16_b16_lds_trans
<
0
,
0
>
(
k_lds
,
k_srsrc
,
lds_offset
,
0
);
}
}
...
...
@@ -317,7 +311,7 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds_576_512(
}
if
constexpr
(
STAGES
==
2
)
{
#if defined(__gfx938__)
#if defined(__gfx938__)
|| defined(__gfx946__) || (defined(__gfx92a__) && defined(YY_USE_MPERMUTE))
prefetch_v_to_lds_mls_ds_576_512
<
kHeadDimV
,
kBlockM
,
kBlockK
,
kBlockN
,
WARP_M
,
kBlockK
,
Element
,
Is_even_MN
>
(
v_ptr
,
v_lds
,
warp_id
,
seqlen_v_stride
,
max_seq_k_offset
);
#else
...
...
csrc/flash_attn_hg/include/mla/gfx938/mla_qk_gemm_utils_mls_ds.h
View file @
518a5f4d
...
...
@@ -36,10 +36,7 @@ __forceinline__ __device__ void prefetch_q_to_lds_mls_ds_576_512(
int
lds_offset
=
(
stage_id
*
kBlockM
*
kBlockK
+
warp_id
*
16
*
32
)
*
ELEMENT_BYTES
;
flash
::
wait_all_warp_arrived
();
// pvgemm 完成后会发射q,k的预取,避免有的warp还没完成,即规避读V写Q/K,造成数据覆盖
union
union_vec4_uint
q_rsrc_bits
;
q_rsrc_bits
.
v32
=
q_srsrc
;
size_t
lds_addr_warp
=
reinterpret_cast
<
size_t
>
(
q_lds
)
+
lds_offset
;
matrix_load_b16_lds_trans_builtin
<
32
,
16
,
1
,
0
>
(
lds_addr_warp
,
q_rsrc_bits
.
i32
,
0
);
inline_matrix_load_32x16_b16_lds_trans
<
0
,
1
>
(
q_lds
,
q_srsrc
,
lds_offset
,
0
);
}
}
...
...
@@ -71,9 +68,6 @@ __forceinline__ __device__ void prefetch_k_to_lds_mls_ds_576_512(
}
int
lds_offset
=
(
stage_id
*
WARP_N
*
kHeadDim_OPT
+
warp_id
*
32
*
16
)
*
ELEMENT_BYTES
;
flash
::
wait_all_warp_arrived
();
union
union_vec4_uint
k_rsrc_bits
;
k_rsrc_bits
.
v32
=
k_srsrc
;
size_t
lds_addr_warp
=
reinterpret_cast
<
size_t
>
(
k_lds
)
+
lds_offset
;
matrix_load_b16_lds_trans_builtin
<
32
,
16
,
0
,
0
>
(
lds_addr_warp
,
k_rsrc_bits
.
i32
,
0
);
inline_matrix_load_32x16_b16_lds_trans
<
0
,
0
>
(
k_lds
,
k_srsrc
,
lds_offset
,
0
);
}
\ No newline at end of file
csrc/flash_attn_hg/include/mla/gfx938/mla_softmax_gfx938.h
View file @
518a5f4d
...
...
@@ -13,13 +13,13 @@ struct PrefillMlaAllreduce {
DataType
res
;
if
constexpr
(
std
::
is_same
<
DataType
,
union_vec2_fp32
>::
value
)
{
if
constexpr
(
std
::
is_same
<
Operator
,
SumOp
<
float
>
>::
value
)
{
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
res
.
f32
[
0
]
=
__shfl_xor_tmp
(
x
.
f32
[
0
],
32
);
res
.
f32
[
1
]
=
__shfl_xor_tmp
(
x
.
f32
[
1
],
32
);
x
.
u64
=
hcu_pk_add_f32
(
x
.
u64
,
res
.
u64
);
x
.
u64
=
__builtin_
hcu_pk_add_f32
(
x
.
u64
,
res
.
u64
);
res
.
f32
[
0
]
=
__shfl_xor_tmp
(
x
.
f32
[
0
],
16
);
res
.
f32
[
1
]
=
__shfl_xor_tmp
(
x
.
f32
[
1
],
16
);
res
.
u64
=
hcu_pk_add_f32
(
res
.
u64
,
x
.
u64
);
res
.
u64
=
__builtin_
hcu_pk_add_f32
(
res
.
u64
,
x
.
u64
);
#else
x
.
f32
[
0
]
=
x
.
f32
[
0
]
+
__shfl_xor_tmp
(
x
.
f32
[
0
],
32
);
x
.
f32
[
1
]
=
x
.
f32
[
1
]
+
__shfl_xor_tmp
(
x
.
f32
[
1
],
32
);
...
...
@@ -100,7 +100,7 @@ __device__ inline void prefill_mla_thread_reduce_sum(const DataType0 tensor[(WAR
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
(
WARP_M
/
(
16
*
M_MMAC_COUNT
));
++
m_idx
)
{
// 对于 gfx936 及以上的架构, 可以使用 v_pk_add_f32
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
if
constexpr
(
M_MMAC_COUNT
==
2
)
{
summary
[
m_idx
*
2
].
u64
=
0x0
;
}
else
{
...
...
@@ -113,7 +113,7 @@ __device__ inline void prefill_mla_thread_reduce_sum(const DataType0 tensor[(WAR
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
++
vec_idx
)
{
if
constexpr
(
M_MMAC_COUNT
==
2
){
__float2
additem_pair
=
{
tensor
[
m_idx
+
n_idx
*
(
WARP_M
/
32
)][
min_tile_n
*
2
].
f32
[
vec_idx
],
tensor
[
m_idx
+
n_idx
*
(
WARP_M
/
32
)][
min_tile_n
*
2
+
1
].
f32
[
vec_idx
]};
summary
[
m_idx
*
2
].
u64
=
hcu_pk_add_f32
(
summary
[
m_idx
*
2
].
u64
=
__builtin_
hcu_pk_add_f32
(
summary
[
m_idx
*
2
].
u64
,
additem_pair
);
...
...
@@ -146,7 +146,7 @@ __device__ inline void prefill_mla_thread_reduce_sum(const DataType0 tensor[(WAR
}
else
{
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
(
WARP_M
/
(
16
*
M_MMAC_COUNT
));
++
m_idx
)
{
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
if
constexpr
(
M_MMAC_COUNT
==
2
)
{
summary_cur
[
m_idx
*
2
].
u64
=
summary
[
m_idx
*
2
].
u64
;
}
else
{
...
...
@@ -159,7 +159,7 @@ __device__ inline void prefill_mla_thread_reduce_sum(const DataType0 tensor[(WAR
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
++
vec_idx
)
{
// mmac min_tile is 16*16, a warp is 64 thread
if
constexpr
(
M_MMAC_COUNT
==
2
)
{
__float2
additem_pair
=
{
tensor
[
m_idx
+
n_idx
*
(
WARP_M
/
32
)][
min_tile_n
*
2
].
f32
[
vec_idx
],
tensor
[
m_idx
+
n_idx
*
(
WARP_M
/
32
)][
min_tile_n
*
2
+
1
].
f32
[
vec_idx
]};
summary_cur
[
m_idx
*
2
].
u64
=
hcu_pk_add_f32
(
summary_cur
[
m_idx
*
2
].
u64
=
__builtin_
hcu_pk_add_f32
(
summary_cur
[
m_idx
*
2
].
u64
,
additem_pair
);
...
...
@@ -273,15 +273,14 @@ inline __device__ void scale_apply_exp2(DataType0 tensor[(WARP_M / (16 * M_MMAC_
mmac_id
=
min_tile_n
;
}
int
qk_tile_id
=
mi
+
ni
*
(
WARP_M
/
(
16
*
M_MMAC_COUNT
));
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
for
(
int
vec_idx
=
0
;
vec_idx
<
2
;
++
vec_idx
)
{
tensor
[
qk_tile_id
][
mmac_id
].
u64
[
vec_idx
]
=
hcu_pk_fma_f32
(
tensor
[
qk_tile_id
][
mmac_id
].
u64
[
vec_idx
]
=
__builtin_
hcu_pk_fma_f32
(
tensor
[
qk_tile_id
][
mmac_id
].
u64
[
vec_idx
],
scale_pair
,
neg_max_scaled_pair
);
}
asm
volatile
(
"s_nop 0"
:::
"memory"
);
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
++
vec_idx
)
{
tensor
[
qk_tile_id
][
mmac_id
].
f32
[
vec_idx
]
=
__llvm_exp2_f32
(
tensor
[
qk_tile_id
][
mmac_id
].
f32
[
vec_idx
]);
}
...
...
@@ -340,10 +339,10 @@ inline __device__ void prefill_mla_softmax_rescale_o(DataType0 scores[(WARP_N /
}
else
{
mmac_id
=
min_tile_n
;
}
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
#pragma unroll
for
(
int
vec_idx
=
0
;
vec_idx
<
2
;
++
vec_idx
)
{
acc_o
[
pv_tile_id
][
mmac_id
].
u64
[
vec_idx
]
=
hcu_pk_mul_f32
(
acc_o
[
pv_tile_id
][
mmac_id
].
u64
[
vec_idx
]
=
__builtin_
hcu_pk_mul_f32
(
acc_o
[
pv_tile_id
][
mmac_id
].
u64
[
vec_idx
],
scores_scale_pair
);
...
...
@@ -372,9 +371,9 @@ inline __device__ void prefill_mla_softmax_rescale_o(DataType0 scores[(WARP_N /
reduce_sum
<
true
,
DataType0
,
DataType1
,
WARP_M
,
WARP_N
,
M_MMAC_COUNT
>
(
scores
,
scores_sum_cur
);
for
(
int
mi
=
0
;
mi
<
(
WARP_M
/
(
16
*
M_MMAC_COUNT
));
++
mi
)
{
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
if
constexpr
(
M_MMAC_COUNT
==
2
)
{
scores_sum
[
mi
].
u64
=
hcu_pk_add_f32
(
scores_sum
[
mi
].
u64
=
__builtin_
hcu_pk_add_f32
(
scores_sum
[
mi
].
u64
,
scores_sum_cur
[
mi
].
u64
);
...
...
@@ -390,7 +389,7 @@ inline __device__ void prefill_mla_softmax_rescale_o(DataType0 scores[(WARP_N /
}
#endif
#if defined(USE_V_MOV_B64) && (defined(__gfx936__) || defined(__gfx938__))
#if defined(USE_V_MOV_B64) && (defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
)
if
constexpr
(
M_MMAC_COUNT
==
2
)
{
inlineasm_fa_v_mov_b64
(
scores_max
[
mi
].
u64
,
...
...
@@ -423,7 +422,7 @@ inline __device__ void prefill_mla_convert_pk_type(union_vec2_f16x2<Element> p_r
for
(
int
min_tile_m
=
0
;
min_tile_m
<
M_MMAC_COUNT
;
++
min_tile_m
)
{
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
#if defined(__gfx938__)
#if defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
if
constexpr
(
M_MMAC_COUNT
==
2
)
{
p_reg
[
n_idx
*
(
WARP_M
/
32
)
+
m_idx
][
0
*
2
+
min_tile_m
].
f16x2
[
min_tile_k
]
=
DownCastPair
<
float
,
Element
>
(
s_reg
[
n_idx
*
(
WARP_M
/
32
)
+
m_idx
][
0
*
2
+
min_tile_m
].
f32x2
[
min_tile_k
]);
...
...
csrc/flash_attn_hg/include/mla/gfx938/mla_tp8_qk_gemm_utils_gfx938.h
View file @
518a5f4d
...
...
@@ -33,10 +33,7 @@ __forceinline__ __device__ void mla_prefetch_q_to_vgpr_gfx938_with_initializatio
*
(
uint64_t
*
)
&
q_srsrc
=
VA_LIMIT_BITS
(
*
(
uint64_t
*
)
&
q_addr
+
q_warp_offset
*
sizeof
(
Element
));
// matrix load
__builtin_amdgcn_sched_barrier
(
0
);
union
union_vec4_uint
q_rsrc_bits
;
q_rsrc_bits
.
v32
=
q_srsrc
;
size_t
lds_addr_warp
=
reinterpret_cast
<
size_t
>
(
q_lds
)
+
lds_offset_bytes
;
matrix_load_b16_lds_trans_builtin
<
32
,
16
,
1
,
0
>
(
lds_addr_warp
,
q_rsrc_bits
.
i32
,
0
);
inline_matrix_load_32x16_b16_lds_trans
<
0
,
1
>
(
q_lds
,
q_srsrc
,
lds_offset_bytes
,
0
);
__builtin_amdgcn_sched_barrier
(
0
);
}
...
...
@@ -63,10 +60,7 @@ __forceinline__ __device__ void mla_prefetch_q_to_vgpr_gfx938_with_initializatio
int
q_warp_offset
=
(
LOAD
*
WARP_NUM
+
real_warp_id
)
*
32
;
*
(
uint64_t
*
)
&
q_srsrc
=
VA_LIMIT_BITS
(
*
(
uint64_t
*
)
&
q_addr
+
q_warp_offset
*
sizeof
(
Element
));
__builtin_amdgcn_sched_barrier
(
0
);
union
union_vec4_uint
q_rsrc_bits
;
q_rsrc_bits
.
v32
=
q_srsrc
;
size_t
lds_addr_warp
=
reinterpret_cast
<
size_t
>
(
q_lds
)
+
lds_offset_bytes
;
matrix_load_b16_lds_trans_builtin
<
32
,
16
,
1
,
0
>
(
lds_addr_warp
,
q_rsrc_bits
.
i32
,
0
);
inline_matrix_load_32x16_b16_lds_trans
<
0
,
1
>
(
q_lds
,
q_srsrc
,
lds_offset_bytes
,
0
);
__builtin_amdgcn_sched_barrier
(
0
);
// continue from MID
...
...
csrc/flash_attn_hg/include/mla/mla_epilogue.h
View file @
518a5f4d
...
...
@@ -20,9 +20,9 @@ __forceinline__ __device__ void mla_epilugue_rescale_acco(
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
int
mmac_id
=
min_tile_n
*
2
+
min_tile_m
;
int
tile_32x32_id
=
pv_n_loop
*
M_WARP_COUNT
*
K_WARP_COUNT
+
(
ni
*
M_WARP_COUNT
+
mi
);
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
for
(
int
vec_id
=
0
;
vec_id
<
2
;
++
vec_id
)
{
acc_o
[
tile_32x32_id
][
mmac_id
].
u64
[
vec_id
]
=
hcu_pk_mul_f32
(
acc_o
[
tile_32x32_id
][
mmac_id
].
u64
[
vec_id
]
=
__builtin_
hcu_pk_mul_f32
(
acc_o
[
tile_32x32_id
][
mmac_id
].
u64
[
vec_id
],
scale_pair
);
...
...
csrc/flash_attn_hg/include/mla/mla_epilogue_tile16x32.h
View file @
518a5f4d
...
...
@@ -15,12 +15,7 @@ __forceinline__ __device__ void mla_epilogue_store_max_sum_tile16x32(
int
headdim_split_id
,
int
seqlen_q_limit
)
{
#ifdef FA_DEBUG_SUM_MAX
constexpr
bool
ALLOW_WRITE_SUM_MAX
=
true
;
#else
constexpr
bool
ALLOW_WRITE_SUM_MAX
=
false
;
#endif
if
constexpr
(
Split
or
ALLOW_WRITE_SUM_MAX
)
{
if
constexpr
(
Split
)
{
if
(
headdim_split_id
==
0
)
{
// 因为 split-D 使用同样的 QK, 计算得到同样的 scores_sum/scores_max 会写多遍, 可能会有数据冲突, 所以强制只写一遍
if
(
thread_id
<
16
)
{
// 0-15 号线程储存有 max/sum 的数据, 16~31/32~47/48~63 号线程也含有, 但只需要写一次即可
#pragma unroll
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment