Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
flash-attention
Commits
518a5f4d
Commit
518a5f4d
authored
Jun 09, 2026
by
hly
Browse files
import aicc-master-dev
parent
c2a1b310
Changes
131
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2370 additions
and
138 deletions
+2370
-138
csrc/flash_attn_hg/include/flash_singleton.h
csrc/flash_attn_hg/include/flash_singleton.h
+8
-9
csrc/flash_attn_hg/include/fwd/fwd_epilogue.h
csrc/flash_attn_hg/include/fwd/fwd_epilogue.h
+92
-40
csrc/flash_attn_hg/include/fwd/gfx92a/fwd_epilogue_gfx92a.h
csrc/flash_attn_hg/include/fwd/gfx92a/fwd_epilogue_gfx92a.h
+85
-0
csrc/flash_attn_hg/include/fwd/gfx92a/pv_gemm_prefetch_k_mls_ds_gfx92a.h
..._hg/include/fwd/gfx92a/pv_gemm_prefetch_k_mls_ds_gfx92a.h
+298
-0
csrc/flash_attn_hg/include/fwd/gfx92a/qk_gemm_prefetch_v_mls_ds_gfx92a.h
..._hg/include/fwd/gfx92a/qk_gemm_prefetch_v_mls_ds_gfx92a.h
+445
-0
csrc/flash_attn_hg/include/fwd/gfx92a/qk_gemm_utils_mls_ds_gfx92a.h
..._attn_hg/include/fwd/gfx92a/qk_gemm_utils_mls_ds_gfx92a.h
+96
-0
csrc/flash_attn_hg/include/fwd/gfx92a/softmax_gfx92a.h
csrc/flash_attn_hg/include/fwd/gfx92a/softmax_gfx92a.h
+132
-0
csrc/flash_attn_hg/include/fwd/gfx938/fp8_epilogue.h
csrc/flash_attn_hg/include/fwd/gfx938/fp8_epilogue.h
+149
-0
csrc/flash_attn_hg/include/fwd/gfx938/fp8_pv_gemm_prefetch_k_mls_ds.h
...ttn_hg/include/fwd/gfx938/fp8_pv_gemm_prefetch_k_mls_ds.h
+81
-0
csrc/flash_attn_hg/include/fwd/gfx938/fp8_pv_gemm_utils_mls_ds.h
...ash_attn_hg/include/fwd/gfx938/fp8_pv_gemm_utils_mls_ds.h
+49
-0
csrc/flash_attn_hg/include/fwd/gfx938/fp8_qk_gemm_prefetch_v_mls_ds.h
...ttn_hg/include/fwd/gfx938/fp8_qk_gemm_prefetch_v_mls_ds.h
+419
-0
csrc/flash_attn_hg/include/fwd/gfx938/fp8_qk_gemm_utils_mls_ds.h
...ash_attn_hg/include/fwd/gfx938/fp8_qk_gemm_utils_mls_ds.h
+159
-0
csrc/flash_attn_hg/include/fwd/gfx938/fp8_softmax_gfx938.h
csrc/flash_attn_hg/include/fwd/gfx938/fp8_softmax_gfx938.h
+302
-0
csrc/flash_attn_hg/include/fwd/gfx938/fwd_epilogue_gfx938.h
csrc/flash_attn_hg/include/fwd/gfx938/fwd_epilogue_gfx938.h
+6
-4
csrc/flash_attn_hg/include/fwd/gfx938/pv_gemm_prefetch_k_mls_ds.h
...sh_attn_hg/include/fwd/gfx938/pv_gemm_prefetch_k_mls_ds.h
+5
-17
csrc/flash_attn_hg/include/fwd/gfx938/pv_gemm_utils_mls_ds.h
csrc/flash_attn_hg/include/fwd/gfx938/pv_gemm_utils_mls_ds.h
+2
-4
csrc/flash_attn_hg/include/fwd/gfx938/qk_gemm_prefetch_v_mls_ds.h
...sh_attn_hg/include/fwd/gfx938/qk_gemm_prefetch_v_mls_ds.h
+6
-17
csrc/flash_attn_hg/include/fwd/gfx938/qk_gemm_utils_mls_ds.h
csrc/flash_attn_hg/include/fwd/gfx938/qk_gemm_utils_mls_ds.h
+8
-12
csrc/flash_attn_hg/include/fwd/gfx938/softmax_gfx938.h
csrc/flash_attn_hg/include/fwd/gfx938/softmax_gfx938.h
+5
-5
csrc/flash_attn_hg/include/fwd/softmax.h
csrc/flash_attn_hg/include/fwd/softmax.h
+23
-30
No files found.
csrc/flash_attn_hg/include/flash_singleton.h
View file @
518a5f4d
...
...
@@ -9,6 +9,7 @@ __attribute__((weak)) int getArch() {
auto
hipResult
=
hipGetDeviceProperties
(
&
props
,
0
);
std
::
string
gcn_arch_name
(
props
.
gcnArchName
);
gcn_arch_name
=
gcn_arch_name
.
substr
(
3
,
3
);
if
(
gcn_arch_name
==
"92a"
)
gcn_arch_name
=
"930"
;
int
gcn_arch
=
std
::
stoi
(
gcn_arch_name
);
return
gcn_arch
;
}
...
...
@@ -38,13 +39,8 @@ private:
DeviceProperties
()
{
// 可以在这里给内部变量赋初始值
hipDeviceProp_t
props
;
auto
hipResult
=
hipGetDeviceProperties
(
&
props
,
0
);
#ifdef ROCM_5_7
this
->
gcn_arch
=
props
.
gcnArch
;
#else
std
::
string
gcn_arch_name
(
props
.
gcnArchName
);
this
->
gcn_arch
=
std
::
stoi
(
gcn_arch_name
.
substr
(
3
,
3
));
#endif
this
->
cu_count
=
props
.
multiProcessorCount
;
this
->
gcn_arch
=
getArch
();
const
char
*
fa_debug
=
std
::
getenv
(
"FA_DEBUG"
);
bool
do_fa_debug
=
fa_debug
!=
nullptr
;
...
...
@@ -55,16 +51,19 @@ private:
const
size_t
q_smem_size
=
run_new_mls
?
least_required_size
:
Kernel_traits
::
q_smem_size
;
const
size_t
k_smem_size
=
run_new_mls
?
least_required_size
:
Kernel_traits
::
k_smem_size
*
2
;
const
size_t
v_smem_size
=
run_new_mls
?
least_required_size
:
Kernel_traits
::
v_smem_size
*
2
;
if
(
gcn_arch
==
928
or
gcn_arch
==
936
or
gcn_arch
==
938
)
{
if
(
gcn_arch
==
928
or
gcn_arch
==
936
or
gcn_arch
==
938
or
gcn_arch
==
946
)
{
this
->
lds_size
=
run_new_mls
?
std
::
max
(
q_smem_size
,
std
::
max
(
v_smem_size
,
k_smem_size
))
:
std
::
max
(
q_smem_size
,
v_smem_size
+
k_smem_size
);
}
else
if
(
gcn_arch
==
930
)
{
this
->
lds_size
=
32
*
1024
;
}
if
(
do_fa_debug
and
std
::
strcmp
(
fa_debug
,
"2"
))
{
printf
(
"gcn_arch: %d
\n
q_smem_size: %ld
\n
k_smem_size: %ld
\n
v_smem_size: %ld
\n
shared memory size: %ld
\n
cu count: %d
\n
"
,
this
->
gcn_arch
,
q_smem_size
,
k_smem_size
,
v_smem_size
,
this
->
lds_size
,
this
->
cu_count
);
}
}
else
if
constexpr
(
Func
==
FAFUNC
::
BACKWARD
)
{
this
->
lds_size
=
32
*
1024
;
if
(
this
->
gcn_arch
>=
936
&&
Kernel_traits
::
kHeadDim
=
=
128
){
if
(
this
->
gcn_arch
==
936
)
{
if
(
this
->
gcn_arch
>=
936
&&
Kernel_traits
::
kHeadDim
<
=
128
){
if
(
this
->
gcn_arch
==
936
||
this
->
gcn_arch
==
938
)
{
this
->
lds_size
=
21
*
1024
;
}
else
{
this
->
lds_size
=
16
*
1024
;
...
...
csrc/flash_attn_hg/include/fwd/fwd_epilogue.h
View file @
518a5f4d
#include "numeric_types.h"
#include "intrinsic.h"
template
<
int
WARP_M
,
int
kBlockK
,
int
kHeadDimV
,
typename
ElementAccum
>
__forceinline__
__device__
void
fwd_apply_attention_sink
(
vec4_Accum
<
ElementAccum
>
acc_o
[(
kHeadDimV
/
kBlockK
)
*
(
WARP_M
/
32
)
*
(
kBlockK
/
32
)][
4
],
vec2_Accum
<
ElementAccum
>
scores_max
[
WARP_M
/
32
],
vec2_Accum
<
ElementAccum
>
scores_sum
[
WARP_M
/
32
],
const
ElementAccum
scale_softmax
,
const
float
sink_value
)
{
#pragma unroll
for
(
int
mi
=
0
;
mi
<
(
WARP_M
/
32
);
++
mi
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
++
min_tile_m
)
{
const
ElementAccum
old_scaled_max
=
scores_max
[
mi
].
f32
[
min_tile_m
]
*
scale_softmax
;
const
ElementAccum
new_scaled_max
=
max
(
old_scaled_max
,
ElementAccum
(
sink_value
));
const
ElementAccum
old_rescale
=
__expf
(
old_scaled_max
-
new_scaled_max
);
scores_sum
[
mi
].
f32
[
min_tile_m
]
=
scores_sum
[
mi
].
f32
[
min_tile_m
]
*
old_rescale
+
__expf
(
ElementAccum
(
sink_value
)
-
new_scaled_max
);
scores_max
[
mi
].
f32
[
min_tile_m
]
=
new_scaled_max
/
scale_softmax
;
__float2
old_rescale_pair
=
{
old_rescale
,
old_rescale
};
#pragma unroll
for
(
int
ni
=
0
;
ni
<
(
kBlockK
/
32
);
++
ni
)
{
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
const
int
mmac_id
=
min_tile_n
*
2
+
min_tile_m
;
#pragma unroll
for
(
int
pv_n_loop
=
0
;
pv_n_loop
<
(
kHeadDimV
/
kBlockK
);
++
pv_n_loop
)
{
const
int
pv_tile_id
=
pv_n_loop
*
(
WARP_M
/
32
)
*
(
kBlockK
/
32
)
+
ni
*
(
WARP_M
/
32
)
+
mi
;
#if defined(__gfx936__) || defined(__gfx938__)
#pragma unroll
for
(
int
vec_id
=
0
;
vec_id
<
2
;
++
vec_id
)
{
acc_o
[
pv_tile_id
][
mmac_id
].
u64
[
vec_id
]
=
__builtin_hcu_pk_mul_f32
(
acc_o
[
pv_tile_id
][
mmac_id
].
u64
[
vec_id
],
old_rescale_pair
);
}
#else
#pragma unroll
for
(
int
vec_id
=
0
;
vec_id
<
4
;
++
vec_id
)
{
acc_o
[
pv_tile_id
][
mmac_id
].
f32
[
vec_id
]
*=
old_rescale
;
}
#endif
}
}
}
}
}
}
template
<
int
WARP_M
,
int
kBlockK
,
int
kHeadDimV
,
bool
Is_dropout
,
typename
ElementAccum
>
__forceinline__
__device__
void
fwd_epilugue_rescale_acco
(
...
...
@@ -28,9 +72,9 @@ __forceinline__ __device__ void fwd_epilugue_rescale_acco(
#pragma unroll
for
(
int
pv_n_loop
=
0
;
pv_n_loop
<
(
kHeadDimV
/
kBlockK
);
++
pv_n_loop
)
{
const
int
pv_tile_id
=
pv_n_loop
*
(
WARP_M
/
32
)
*
(
kBlockK
/
32
)
+
ni
*
(
WARP_M
/
32
)
+
mi
;
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
for
(
int
vec_id
=
0
;
vec_id
<
2
;
++
vec_id
)
{
acc_o
[
pv_tile_id
][
mmac_id
].
u64
[
vec_id
]
=
hcu_pk_mul_f32
(
acc_o
[
pv_tile_id
][
mmac_id
].
u64
[
vec_id
]
=
__builtin_
hcu_pk_mul_f32
(
acc_o
[
pv_tile_id
][
mmac_id
].
u64
[
vec_id
],
scale_pair
);
...
...
@@ -108,14 +152,51 @@ __forceinline__ __device__ void fwd_epilogue_store_output(
int
pv_lane_seq_idx
=
lane_id
&
15
;
int
pv_lane_head_dim_idx
=
lane_id
>>
4
;
#if defined(__gfx938__)
#if defined(__gfx938__)
|| defined(__gfx946__)
constexpr
bool
Is_Interleaved_
=
Is_Interleaved
and
kHeadDimV
==
128
;
#else
constexpr
bool
Is_Interleaved_
=
Is_Interleaved
;
#endif
if
constexpr
(
Is_Interleaved_
)
{
#if defined(__gfx938__)
#if defined(__gfx938__) || defined(__gfx946__)
#pragma unroll
for
(
int
k_loop
=
0
;
k_loop
<
(
kHeadDimV
/
kBlockK
);
++
k_loop
)
{
#pragma unroll
for
(
int
warp_m_idx
=
0
;
warp_m_idx
<
(
WARP_M
/
32
);
++
warp_m_idx
)
{
#pragma unroll
for
(
int
k_tile_idx
=
0
;
k_tile_idx
<
(
kBlockK
/
32
);
++
k_tile_idx
)
{
#pragma unroll 2
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
++
min_tile_m
)
{
#pragma unroll 2
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
const
int
pv_tile_id
=
k_loop
*
(
WARP_M
/
32
)
*
(
kBlockK
/
32
)
+
warp_m_idx
*
(
kBlockK
/
32
)
+
k_tile_idx
;
const
int
mmac_id
=
min_tile_m
+
min_tile_n
*
2
;
int
seqlen_q_offset
=
warp_id
*
WARP_M
+
warp_m_idx
*
32
+
min_tile_m
*
16
+
pv_lane_seq_idx
;
// prepare for store
int
s_offset
=
k_tile_idx
*
32
+
min_tile_n
*
16
;
int
v_offset
=
seqlen_q_offset
*
seqlen_o_stride
+
k_loop
*
kBlockK
+
pv_lane_head_dim_idx
*
4
;
union_vec2_f16x2
<
Element
>
v_data
;
#pragma unroll
for
(
int
vec_index
=
0
;
vec_index
<
2
;
++
vec_index
)
{
// convert float -> bf16/fp16
v_data
.
f16x2
[
vec_index
]
=
DownCastPair
<
ElementAccum
,
Element
>
(
acc_o
[
pv_tile_id
][
mmac_id
].
f32x2
[
vec_index
]);
}
if
constexpr
(
not
Is_even_MN
)
{
if
(
m_block
*
kBlockM
+
seqlen_q_offset
<
seqlen_q_limit
)
{
*
(
union_vec2_f16x2
<
Element
>*
)(
o_ptr
+
v_offset
+
s_offset
)
=
v_data
;
}
}
else
{
*
(
union_vec2_f16x2
<
Element
>*
)(
o_ptr
+
v_offset
+
s_offset
)
=
v_data
;
}
}
}
}
}
}
// brace, to control vgpr usage
#else
// simulate mmac-4interleave via lds
// todo: lds bank conflicts, vgpr spills
#pragma unroll
for
(
int
k_loop
=
0
;
k_loop
<
(
kHeadDimV
/
kBlockK
);
++
k_loop
)
{
#pragma unroll
...
...
@@ -128,14 +209,17 @@ __forceinline__ __device__ void fwd_epilogue_store_output(
int
s_offset
=
k_loop
*
kBlockK
;
int
seqlen_q_offset
=
(
warp_id
*
WARP_M
+
warp_m_idx
*
32
+
pv_lane_seq_idx
*
2
+
min_tile_m
);
int
v_offset
=
seqlen_q_offset
*
seqlen_o_stride
+
pv_lane_head_dim_idx
*
8
;
// prepare vgprs
union_vec4_f16x2
<
Element
>
v_data
;
#pragma unroll
for
(
int
vec_index
=
0
;
vec_index
<
4
;
++
vec_index
)
{
// convert float -> bf16/fp16
constexpr
bool
is_bf16
=
std
::
is_same
<
Element
,
bhalf_t
>::
value
;
v_data
.
f16x2
[
vec_index
][
0
]
=
DownCast
<
ElementAccum
,
Element
,
is_bf16
>
(
acc_o
[
tile32x32_id
][
min_tile_m
+
0
*
2
].
f32
[
vec_index
]);
v_data
.
f16x2
[
vec_index
][
1
]
=
DownCast
<
ElementAccum
,
Element
,
is_bf16
>
(
acc_o
[
tile32x32_id
][
min_tile_m
+
1
*
2
].
f32
[
vec_index
]);
}
// try interleave
auto
lds
=
(
__attribute__
((
address_space
(
3
)))
float
*
)(
0
);
int
lds_write_offset
=
(
warp_id
*
512
+
pv_lane_seq_idx
*
16
+
pv_lane_head_dim_idx
*
4
+
pv_lane_seq_idx
*
4
)
*
4
;
__builtin_amdgcn_sched_barrier
(
0
);
...
...
@@ -148,6 +232,7 @@ __forceinline__ __device__ void fwd_epilogue_store_output(
}
flash
::
wait_lds_data_arrived
<
false
>
(
0
);
// write to global memory
if
constexpr
(
Is_even_MN
)
{
*
(
vec4_fp32
*
)(
o_ptr
+
v_offset
+
s_offset
+
k_tile_idx
*
32
)
=
v_data
.
f32
;
}
else
{
...
...
@@ -159,42 +244,9 @@ __forceinline__ __device__ void fwd_epilogue_store_output(
}
}
}
#else
#pragma unroll
for
(
int
k_loop
=
0
;
k_loop
<
(
kHeadDimV
/
kBlockK
);
++
k_loop
)
{
#pragma unroll
for
(
int
warp_m_idx
=
0
;
warp_m_idx
<
(
WARP_M
/
32
);
++
warp_m_idx
)
{
#pragma unroll
for
(
int
k_tile_idx
=
0
;
k_tile_idx
<
(
kBlockK
/
32
);
++
k_tile_idx
)
{
#pragma unroll 2
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
++
min_tile_m
)
{
#pragma unroll 2
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
const
int
pv_tile_id
=
k_loop
*
(
WARP_M
/
32
)
*
(
kBlockK
/
32
)
+
warp_m_idx
*
(
kBlockK
/
32
)
+
k_tile_idx
;
const
int
mmac_id
=
min_tile_m
+
min_tile_n
*
2
;
int
seqlen_q_offset
=
warp_id
*
WARP_M
+
warp_m_idx
*
32
+
min_tile_m
*
16
+
pv_lane_seq_idx
;
int
s_offset
=
k_tile_idx
*
32
+
min_tile_n
*
16
;
int
v_offset
=
seqlen_q_offset
*
seqlen_o_stride
+
k_loop
*
kBlockK
+
pv_lane_head_dim_idx
*
4
;
union_vec2_f16x2
<
Element
>
v_data
;
#pragma unroll
for
(
int
vec_index
=
0
;
vec_index
<
2
;
++
vec_index
)
{
v_data
.
f16x2
[
vec_index
]
=
DownCastPair
<
ElementAccum
,
Element
>
(
acc_o
[
pv_tile_id
][
mmac_id
].
f32x2
[
vec_index
]);
}
if
constexpr
(
not
Is_even_MN
)
{
if
(
m_block
*
kBlockM
+
seqlen_q_offset
<
seqlen_q_limit
)
{
*
(
union_vec2_f16x2
<
Element
>*
)(
o_ptr
+
v_offset
+
s_offset
)
=
v_data
;
}
}
else
{
*
(
union_vec2_f16x2
<
Element
>*
)(
o_ptr
+
v_offset
+
s_offset
)
=
v_data
;
}
}
}
}
}
}
#endif
}
else
{
auto
gO
=
prepare_for_buffer_load
<
kHeadDimV
,
Element
,
TcpSwizzle
>
(
o_ptr
);
auto
o_resource
=
prepare_for_buffer_load
<
kHeadDimV
,
Element
,
TcpSwizzle
>
(
o_ptr
);
#pragma unroll
for
(
int
k_loop
=
0
;
k_loop
<
(
kHeadDimV
/
kBlockK
);
++
k_loop
)
{
#pragma unroll
...
...
@@ -234,7 +286,7 @@ __forceinline__ __device__ void fwd_epilogue_store_output(
}
// write to global memory
if
constexpr
(
Is_even_MN
)
{
inline_buffer_store_dword
<
vec2_Element
<
Element
>
,
1
>
(
v_data
,
v_offset
,
gO
,
s_offset
,
/* immediate integer */
s_offset_constexpr
);
inline_buffer_store_dword
<
vec2_Element
<
Element
>
,
1
>
(
v_data
,
v_offset
,
o_resource
,
s_offset
,
/* immediate integer */
s_offset_constexpr
);
}
else
{
if
(
m_block
*
kBlockM
+
seqlen_q_offset
<
seqlen_q_limit
)
{
*
(
vec2_Element
<
Element
>*
)(
o_ptr
+
v_offset
+
s_offset
+
s_offset_constexpr
)
=
v_data
;
...
...
csrc/flash_attn_hg/include/fwd/gfx92a/fwd_epilogue_gfx92a.h
0 → 100644
View file @
518a5f4d
#include "numeric_types.h"
#include "intrinsic.h"
template
<
int
kHeadDimV
,
int
kBlockM
,
int
kBlockK
,
int
WARP_M
,
bool
Is_even_MN
,
bool
Is_Interleaved
,
bool
TcpSwizzle
,
typename
Element
,
typename
ElementAccum
>
__forceinline__
__device__
void
fwd_epilogue_store_output_mls_gfx92a
(
Element
*
o_ptr
,
vec4_Accum
<
ElementAccum
>
acc_o
[(
kHeadDimV
/
kBlockK
)
*
(
WARP_M
/
32
)
*
(
kBlockK
/
32
)][
4
],
int
m_block
,
int
warp_id
,
int
lane_id
,
int
seqlen_o_stride
,
int
seqlen_q_limit
)
{
int
pv_lane_seq_idx
=
lane_id
&
15
;
int
pv_lane_head_dim_idx
=
lane_id
>>
4
;
// MLS gfx92a PV accumulators are laid out as 4-interleaved rows. Keep
// this store path private to the MLS gfx92a kernels so the generic fwd
// epilogue can continue to serve the legacy FA_FWD_NO_MLS path unchanged.
if
constexpr
(
false
)
{
}
else
{
auto
gO
=
prepare_for_buffer_load
<
kHeadDimV
,
Element
,
TcpSwizzle
>
(
o_ptr
);
#pragma unroll
for
(
int
k_loop
=
0
;
k_loop
<
(
kHeadDimV
/
kBlockK
);
++
k_loop
)
{
#pragma unroll
for
(
int
warp_m_idx
=
0
;
warp_m_idx
<
(
WARP_M
/
32
);
++
warp_m_idx
)
{
#pragma unroll
for
(
int
k_tile_idx
=
0
;
k_tile_idx
<
(
kBlockK
/
32
);
++
k_tile_idx
)
{
#pragma unroll 2
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
++
min_tile_m
)
{
#pragma unroll
for
(
int
vec_index
=
0
;
vec_index
<
4
;
++
vec_index
)
{
int
tile32x32_id
=
k_loop
*
(
WARP_M
/
32
)
*
(
kBlockK
/
32
)
+
warp_m_idx
*
(
kBlockK
/
32
)
+
k_tile_idx
;
int
s_offset
=
k_loop
*
kBlockK
;
int
s_offset_constexpr
=
k_tile_idx
*
32
+
vec_index
*
8
;
int
seqlen_q_offset
=
warp_id
*
WARP_M
+
warp_m_idx
*
32
+
pv_lane_seq_idx
+
min_tile_m
*
16
;
int
v_offset
=
seqlen_q_offset
*
seqlen_o_stride
+
pv_lane_head_dim_idx
*
2
;
vec2_Element
<
Element
>
v_data
;
if
constexpr
(
std
::
is_same
<
Element
,
bhalf_t
>::
value
)
{
*
(
vec2_Element
<
Element
>*
)
&
v_data
=
DownCastPairNoPack
<
ElementAccum
,
Element
>
(
acc_o
[
tile32x32_id
][
min_tile_m
+
0
*
2
].
f32
[
vec_index
],
acc_o
[
tile32x32_id
][
min_tile_m
+
1
*
2
].
f32
[
vec_index
]
);
}
else
if
constexpr
(
std
::
is_same
<
Element
,
half_t
>::
value
)
{
#ifdef USE_CVT_PKRTZ_FP16_FP32
*
(
vec2_Element
<
Element
>*
)
&
v_data
=
DownCastPair
<
ElementAccum
,
Element
>
(
acc_o
[
tile32x32_id
][
min_tile_m
+
0
*
2
].
f32
[
vec_index
],
acc_o
[
tile32x32_id
][
min_tile_m
+
1
*
2
].
f32
[
vec_index
]
);
#else
v_data
[
0
]
=
DownCast
<
ElementAccum
,
Element
>
(
acc_o
[
tile32x32_id
][
min_tile_m
+
0
*
2
].
f32
[
vec_index
]);
v_data
[
1
]
=
DownCast
<
ElementAccum
,
Element
>
(
acc_o
[
tile32x32_id
][
min_tile_m
+
1
*
2
].
f32
[
vec_index
]);
#endif
}
if
constexpr
(
Is_even_MN
)
{
inline_buffer_store_dword
<
vec2_Element
<
Element
>
,
1
>
(
v_data
,
v_offset
,
gO
,
s_offset
,
s_offset_constexpr
);
}
else
{
if
(
m_block
*
kBlockM
+
seqlen_q_offset
<
seqlen_q_limit
)
{
inline_buffer_store_dword
<
vec2_Element
<
Element
>
,
1
>
(
v_data
,
v_offset
,
gO
,
s_offset
,
s_offset_constexpr
);
}
}
}
}
}
}
}
flash
::
wait_buffer_data_arrived
<
true
>
(
0
);
}
}
template
<
int
kHeadDimV
,
int
kBlockM
,
int
kBlockK
,
int
WARP_M
,
bool
Is_even_MN
,
bool
Is_Interleaved
,
bool
TcpSwizzle
,
typename
Element
,
typename
ElementAccum
>
__forceinline__
__device__
void
fwd_epilogue_store_output_gfx92a
(
Element
*
o_ptr
,
vec4_Accum
<
ElementAccum
>
acc_o
[(
kHeadDimV
/
kBlockK
)
*
(
WARP_M
/
32
)
*
(
kBlockK
/
32
)][
4
],
int
m_block
,
int
warp_id
,
int
lane_id
,
int
seqlen_o_stride
,
int
seqlen_q_limit
)
{
fwd_epilogue_store_output_mls_gfx92a
<
kHeadDimV
,
kBlockM
,
kBlockK
,
WARP_M
,
Is_even_MN
,
Is_Interleaved
,
TcpSwizzle
,
Element
,
ElementAccum
>
(
o_ptr
,
acc_o
,
m_block
,
warp_id
,
lane_id
,
seqlen_o_stride
,
seqlen_q_limit
);
}
csrc/flash_attn_hg/include/fwd/gfx92a/pv_gemm_prefetch_k_mls_ds_gfx92a.h
0 → 100644
View file @
518a5f4d
#include "fwd/gfx92a/qk_gemm_utils_mls_ds_gfx92a.h"
#include "static_switch.h"
template
<
bool
PREFETCH_K
,
int
kHeadDim
,
int
kHeadDimV
,
int
kBlockM
,
int
kBlockN
,
int
kBlockK
,
int
WARP_M
,
int
WARP_N
,
int
STAGES
,
typename
Element
,
typename
ElementAccum
,
bool
Is_even_MN
>
__forceinline__
__device__
void
pv_gemm_prefetch_k_mls_ds_gfx92a
(
vec4_uint
v_ptr
,
vec4_uint
k_ptr
,
Element
*
v_lds
,
Element
*
k_lds
,
union_vec2_f16x2
<
Element
>
p_reg
[(
WARP_M
/
32
)
*
(
kBlockK
/
32
)][
4
],
vec4_Accum
<
ElementAccum
>
pv_reg
[(
kHeadDimV
/
kBlockN
)
*
(
WARP_M
/
32
)
*
(
kBlockN
/
32
)][
4
],
int
warp_id
,
int
seqlen_k_stride
,
int
seqlen_v_stride
,
int
max_seq_kv_offset
=
0
)
{
constexpr
int
WARP_NUM
=
kBlockM
*
kBlockN
/
(
WARP_M
*
WARP_N
);
constexpr
int
WARP_K
=
32
;
constexpr
int
READ_ONCE_COUNT
=
32
*
32
;
constexpr
int
V_LDS_LOAD_NUM
=
(
kHeadDimV
*
WARP_K
)
/
READ_ONCE_COUNT
;
constexpr
int
V_LOAD_REQUESTS
=
V_LDS_LOAD_NUM
/
WARP_NUM
;
constexpr
int
ELEMENT_BYTES
=
sizeof
(
Element
);
static_assert
(
kBlockK
>=
32
,
"Error: pv gemm kBlockK must be equal or greater than 32"
);
static_assert
(
kBlockM
>=
WARP_M
,
"Error: pv gemm kBlockM must be equal or greater than WARP_M"
);
static_assert
(
kBlockN
==
WARP_N
,
"Error: pv gemm kBlockN must be equal to WARP_N"
);
static_assert
(
WARP_K
==
32
and
"Error: To simplify, only WARP_K = 32 is supported!"
);
static_assert
(
WARP_M
==
32
and
"Error: To simplify, only WARP_M = 32 is supported!"
);
static_assert
(
WARP_N
==
32
and
"Error: To simplify, only WARP_N = 32 is supported!"
);
// Prepare V regs
union_vec4_f16x2
<
Element
>
v_reg
[
STAGES
*
(
32
*
WARP_N
)
/
(
32
*
32
)
*
2
];
// Prepare V lds offset
int
v_lds_base
=
0
;
// reinterpret_cast<size_t>(v_lds); // ===> 性能下降 ?
// Prepare MLS buffer resource sregs
vec4_uint
v_srsrc
;
v_srsrc
[
0
]
=
v_ptr
[
0
];
v_srsrc
[
1
]
=
v_ptr
[
1
];
v_srsrc
[
2
]
=
seqlen_v_stride
;
// stride
v_srsrc
[
3
]
=
0
;
int
lds_stage_id
=
1
;
// Main loop across blockN(128) among seqlenkv
for
(
int
n_loop
=
1
;
n_loop
<
(
kBlockK
/
WARP_K
);
++
n_loop
)
{
// Do k-dim interleave for next mmac
#if defined(__gfx92a__)
ds_mpermute_kdim_for_mmac
(
p_reg
[
n_loop
-
1
][
2
*
0
+
0
].
f16x4
);
ds_mpermute_kdim_for_mmac
(
p_reg
[
n_loop
-
1
][
2
*
0
+
1
].
f16x4
);
ds_mpermute_kdim_for_mmac
(
p_reg
[
n_loop
-
1
][
2
*
1
+
0
].
f16x4
);
ds_mpermute_kdim_for_mmac
(
p_reg
[
n_loop
-
1
][
2
*
1
+
1
].
f16x4
);
#endif
// MLS dispatch
if
constexpr
(
Is_even_MN
)
{
*
(
uint64_t
*
)
&
v_srsrc
=
VA_LIMIT_BITS
(
*
(
uint64_t
*
)
&
v_ptr
+
(
n_loop
*
WARP_K
*
seqlen_v_stride
+
warp_id
*
32
)
*
ELEMENT_BYTES
);
v_srsrc
[
3
]
=
0x20000
;
}
else
{
int
nm_filter_max
=
n_loop
*
WARP_K
+
32
-
max_seq_kv_offset
;
int
real_mls_loop
=
nm_filter_max
>=
32
?
0
:
n_loop
;
int
nm_filter
=
inline_min_max
<
0
,
32
>
(
real_mls_loop
*
WARP_K
+
32
-
max_seq_kv_offset
);
v_srsrc
[
3
]
=
(
nm_filter
<<
8
)
+
0x20000
;
*
(
uint64_t
*
)
&
v_srsrc
=
VA_LIMIT_BITS
(
*
(
uint64_t
*
)
&
v_ptr
+
(
real_mls_loop
*
WARP_K
*
seqlen_v_stride
+
warp_id
*
32
)
*
ELEMENT_BYTES
);
}
int
lds_write_offset
=
(
lds_stage_id
*
WARP_K
*
kHeadDimV
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
;
flash
::
wait_all_warp_arrived
();
inline_matrix_load_32x32_b16_lds
<
0
,
1
>
(
v_lds
,
v_srsrc
,
lds_write_offset
,
0
);
// Wait buffer
lds_stage_id
^=
1
;
int
stage_id
=
0
;
flash
::
wait_buffer_data_arrived
<
true
>
(
V_LOAD_REQUESTS
);
// DS dispatch
int
lds_load_offset
=
(
0
/*k_loop*/
*
32
*
32
+
lds_stage_id
*
WARP_K
*
kHeadDimV
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16_ALT2
(
lds_load_offset
,
v_reg
[
stage_id
*
2
+
0
].
f16
,
v_reg
[
stage_id
*
2
+
1
].
f16
,
false
/*transpose*/
);
// Wait ds_mpermute
#if defined(__gfx92a__)
flash
::
wait_lds_data_arrived
<
false
>
(
2
);
#endif
stage_id
^=
1
;
for
(
int
k_loop
=
1
;
k_loop
<
(
kHeadDimV
/
kBlockN
);
++
k_loop
)
{
// DS dispatch
int
lds_load_offset
=
(
k_loop
*
32
*
32
+
lds_stage_id
*
WARP_K
*
kHeadDimV
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16_ALT2
(
lds_load_offset
,
v_reg
[
stage_id
*
2
+
0
].
f16
,
v_reg
[
stage_id
*
2
+
1
].
f16
,
false
/*transpose*/
);
// Wait DS
flash
::
wait_lds_data_arrived
<
false
>
(
3
);
// MMAC
stage_id
^=
1
;
{
constexpr
int
min_tile_k
=
0
;
flash
::
raise_priority
(
1
);
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
STAGES
==
2
)
?
k_loop
-
1
:
k_loop
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
2
*
min_tile_k
+
min_tile_m
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
);
}
}
flash
::
lower_priority
();
}
flash
::
wait_lds_data_arrived
<
false
>
(
2
);
{
constexpr
int
min_tile_k
=
1
;
flash
::
raise_priority
(
1
);
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
STAGES
==
2
)
?
k_loop
-
1
:
k_loop
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
2
*
min_tile_k
+
min_tile_m
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
);
}
}
flash
::
lower_priority
();
}
}
stage_id
^=
1
;
// Wait DS
flash
::
wait_lds_data_arrived
<
false
>
(
1
);
// last mmac
{
constexpr
int
min_tile_k
=
0
;
flash
::
raise_priority
(
1
);
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
kHeadDimV
/
kBlockN
)
-
1
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
2
*
min_tile_k
+
min_tile_m
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
);
}
}
flash
::
lower_priority
();
}
flash
::
wait_lds_data_arrived
<
false
>
(
0
);
{
constexpr
int
min_tile_k
=
1
;
flash
::
raise_priority
(
1
);
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
kHeadDimV
/
kBlockN
)
-
1
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
2
*
min_tile_k
+
min_tile_m
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
);
}
}
flash
::
lower_priority
();
}
}
// Prefetch K
if
constexpr
(
PREFETCH_K
)
{
prefetch_k_to_lds_mls_ds
<
kHeadDim
,
kBlockK
,
kBlockN
,
WARP_NUM
,
WARP_N
,
Element
,
Is_even_MN
>
(
k_ptr
,
k_lds
,
warp_id
,
seqlen_k_stride
,
max_seq_kv_offset
);
}
{
constexpr
int
n_loop
=
4
;
// Do k-dim interleave for next mmac
#if defined(__gfx92a__)
ds_mpermute_kdim_for_mmac
(
p_reg
[
n_loop
-
1
][
2
*
0
+
0
].
f16x4
);
ds_mpermute_kdim_for_mmac
(
p_reg
[
n_loop
-
1
][
2
*
0
+
1
].
f16x4
);
ds_mpermute_kdim_for_mmac
(
p_reg
[
n_loop
-
1
][
2
*
1
+
0
].
f16x4
);
ds_mpermute_kdim_for_mmac
(
p_reg
[
n_loop
-
1
][
2
*
1
+
1
].
f16x4
);
#endif
lds_stage_id
^=
1
;
int
stage_id
=
0
;
// Wait buffer
if
constexpr
(
PREFETCH_K
)
{
flash
::
wait_buffer_data_arrived
<
true
>
(
V_LOAD_REQUESTS
);
}
else
{
flash
::
wait_buffer_data_arrived
<
true
>
(
0
);
}
// DS dispatch
int
lds_load_offset
=
(
0
/*k_loop*/
*
32
*
32
+
lds_stage_id
*
WARP_K
*
kHeadDimV
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16_ALT2
(
lds_load_offset
,
v_reg
[
stage_id
*
2
+
0
].
f16
,
v_reg
[
stage_id
*
2
+
1
].
f16
,
false
/*transpose*/
);
// Wait ds_mpermute
#if defined(__gfx92a__)
flash
::
wait_lds_data_arrived
<
false
>
(
2
);
#endif
stage_id
^=
1
;
for
(
int
k_loop
=
1
;
k_loop
<
(
kHeadDimV
/
kBlockN
);
++
k_loop
)
{
// DS dispatch
int
lds_load_offset
=
(
k_loop
*
32
*
32
+
lds_stage_id
*
WARP_K
*
kHeadDimV
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16_ALT2
(
lds_load_offset
,
v_reg
[
stage_id
*
2
+
0
].
f16
,
v_reg
[
stage_id
*
2
+
1
].
f16
,
false
/*transpose*/
);
// Wait DS
flash
::
wait_lds_data_arrived
<
false
>
(
3
);
// MMAC
stage_id
^=
1
;
{
constexpr
int
min_tile_k
=
0
;
flash
::
raise_priority
(
1
);
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
STAGES
==
2
)
?
k_loop
-
1
:
k_loop
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
2
*
min_tile_k
+
min_tile_m
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
);
}
}
flash
::
lower_priority
();
}
flash
::
wait_lds_data_arrived
<
false
>
(
2
);
{
constexpr
int
min_tile_k
=
1
;
flash
::
raise_priority
(
1
);
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
STAGES
==
2
)
?
k_loop
-
1
:
k_loop
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
2
*
min_tile_k
+
min_tile_m
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
);
}
}
flash
::
lower_priority
();
}
}
stage_id
^=
1
;
flash
::
wait_lds_data_arrived
<
false
>
(
1
);
// last mmac
{
constexpr
int
min_tile_k
=
0
;
flash
::
raise_priority
(
1
);
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
kHeadDimV
/
kBlockN
)
-
1
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
2
*
min_tile_k
+
min_tile_m
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
);
}
}
flash
::
lower_priority
();
}
flash
::
wait_lds_data_arrived
<
false
>
(
0
);
{
constexpr
int
min_tile_k
=
1
;
flash
::
raise_priority
(
1
);
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
++
min_tile_m
)
{
int
pv_tile_id
=
(
kHeadDimV
/
kBlockN
)
-
1
;
int
v_tile_id
=
stage_id
*
2
+
min_tile_k
;
pv_reg
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
n_loop
-
1
][
2
*
min_tile_k
+
min_tile_m
].
f16x4
,
v_reg
[
v_tile_id
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
);
}
}
flash
::
lower_priority
();
}
}
}
csrc/flash_attn_hg/include/fwd/gfx92a/qk_gemm_prefetch_v_mls_ds_gfx92a.h
0 → 100644
View file @
518a5f4d
#pragma once
#include "fwd/gfx938/pv_gemm_utils_mls_ds.h"
#include "fwd/gfx92a/qk_gemm_utils_mls_ds_gfx92a.h"
template
<
int
kHeadDim
,
int
kHeadDimV
,
int
kBlockM
,
int
kBlockN
,
int
kBlockK
,
int
WARP_M
,
int
WARP_N
,
int
STAGES
,
typename
Element
,
typename
ElementAccum
,
bool
Is_even_MN
>
__forceinline__
__device__
void
qk_gemm_prefetch_v_mls_ds_gfx92a_2TG
(
vec4_uint
k_ptr
,
vec4_uint
v_ptr
,
Element
*
k_lds
,
Element
*
v_lds
,
union_vec4_f16x2
<
Element
>
q_reg
[(
kHeadDim
/
kBlockK
)
*
(
WARP_M
*
kBlockK
)
/
(
32
*
32
)
*
2
],
vec4_Accum
<
ElementAccum
>
s_reg
[(
WARP_M
/
32
)
*
(
kBlockN
/
32
)][
4
],
int
warp_id
,
int
seqlen_k_stride
,
int
seqlen_v_stride
,
int
max_seq_k_offset
=
0
)
{
// Simplify
static_assert
(
kBlockK
==
32
and
"To simplify, only kBlockK = 32 is supported!"
);
static_assert
(
WARP_M
==
32
and
"To simplify, only WARP_M = 32 is supported!"
);
static_assert
(
WARP_N
==
32
and
"To simplify, only WARP_N = 32 is supported!"
);
constexpr
int
WARP_NUM
=
kBlockM
/
WARP_M
;
constexpr
int
k_lds_load_num
=
WARP_N
*
kHeadDim
/
(
32
*
32
);
constexpr
int
K_LOAD_REQUESTS
=
k_lds_load_num
/
WARP_NUM
;
constexpr
int
ELEMENT_BYTES
=
sizeof
(
Element
);
// 准备 K 寄存器
union_vec4_f16x2
<
Element
>
k_reg
[
STAGES
*
(
WARP_N
*
kBlockK
)
/
(
32
*
32
)
*
2
];
// 计算 K lds 起始偏移量
int
k_lds_base
=
reinterpret_cast
<
size_t
>
(
k_lds
);
// here, v_mov_b64 can be applied
__builtin_amdgcn_sched_barrier
(
0
);
if
constexpr
(
kBlockN
==
128
)
{
inline_vgpr4_init_zero_4x4x4
(
s_reg
);
}
else
{
for
(
int
i
=
0
;
i
<
(
WARP_M
/
32
)
*
(
kBlockN
/
32
);
++
i
)
{
// for kBlockN = 64, only wave 0 get the right QK gemm results
for
(
int
j
=
0
;
j
<
4
;
++
j
)
{
s_reg
[
i
][
j
].
u64
[
0
]
=
0
;
s_reg
[
i
][
j
].
u64
[
1
]
=
0
;
}
}
}
flash
::
lower_priority
();
// MLS
vec4_uint
k_srsrc
;
k_srsrc
[
2
]
=
seqlen_k_stride
;
// stride
k_srsrc
[
3
]
=
0
;
#pragma unroll
for
(
int
n_loop
=
0
;
n_loop
<
(
kBlockN
/
WARP_N
);
++
n_loop
)
{
// Wait global data
flash
::
wait_buffer_data_arrived
<
true
>
(
kBlockN
/
WARP_N
-
n_loop
-
1
);
// DS
int
stage_id
=
0
;
{
constexpr
int
k_loop
=
0
;
int
lds_load_offset
=
k_lds_base
+
(
n_loop
*
WARP_N
*
kHeadDim
+
k_loop
*
32
*
32
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16_GFX946
(
lds_load_offset
,
k_reg
[
stage_id
*
2
].
f16
,
k_reg
[
stage_id
*
2
+
1
].
f16
,
true
);
}
stage_id
^=
1
;
#pragma unroll
for
(
int
k_loop
=
1
;
k_loop
<
(
kHeadDim
/
kBlockK
);
++
k_loop
)
{
// DS
int
lds_load_offset
=
k_lds_base
+
(
n_loop
*
WARP_N
*
kHeadDim
+
k_loop
*
32
*
32
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16_GFX946
(
lds_load_offset
,
k_reg
[
stage_id
*
2
].
f16
,
k_reg
[
stage_id
*
2
+
1
].
f16
,
true
);
flash
::
wait_lds_data_arrived
<
false
>
(
3
);
flash
::
raise_priority
();
// MMAC
stage_id
^=
1
;
{
int
min_tile_n
=
0
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
++
min_tile_m
)
{
int
k_loop_idx
=
(
STAGES
==
2
)
?
k_loop
-
1
:
k_loop
;
int
q_tile_id
=
k_loop_idx
*
2
+
min_tile_m
;
int
k_tile_id
=
stage_id
*
2
+
min_tile_n
;
s_reg
[
n_loop
][
min_tile_n
*
2
+
min_tile_m
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
q_tile_id
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
n_loop
][
min_tile_n
*
2
+
min_tile_m
].
f32
);
}
}
}
flash
::
wait_lds_data_arrived
<
false
>
(
2
);
{
int
min_tile_n
=
1
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
++
min_tile_m
)
{
int
k_loop_idx
=
(
STAGES
==
2
)
?
k_loop
-
1
:
k_loop
;
int
q_tile_id
=
k_loop_idx
*
2
+
min_tile_m
;
int
k_tile_id
=
stage_id
*
2
+
min_tile_n
;
s_reg
[
n_loop
][
min_tile_n
*
2
+
min_tile_m
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
q_tile_id
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
n_loop
][
min_tile_n
*
2
+
min_tile_m
].
f32
);
}
}
}
flash
::
lower_priority
();
}
stage_id
^=
1
;
// Wait DS
flash
::
wait_lds_data_arrived
<
false
>
(
1
);
flash
::
raise_priority
();
// last mmac
{
int
min_tile_n
=
0
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
++
min_tile_m
)
{
int
k_loop_idx
=
kHeadDim
/
kBlockK
-
1
;
int
q_tile_id
=
k_loop_idx
*
2
+
min_tile_m
;
int
k_tile_id
=
stage_id
*
2
+
min_tile_n
;
s_reg
[
n_loop
][
min_tile_n
*
2
+
min_tile_m
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
q_tile_id
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
n_loop
][
min_tile_n
*
2
+
min_tile_m
].
f32
);
}
}
}
flash
::
wait_lds_data_arrived
<
false
>
(
0
);
{
int
min_tile_n
=
1
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
++
min_tile_m
)
{
int
k_loop_idx
=
kHeadDim
/
kBlockK
-
1
;
int
q_tile_id
=
k_loop_idx
*
2
+
min_tile_m
;
int
k_tile_id
=
stage_id
*
2
+
min_tile_n
;
s_reg
[
n_loop
][
min_tile_n
*
2
+
min_tile_m
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
q_tile_id
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
n_loop
][
min_tile_n
*
2
+
min_tile_m
].
f32
);
}
}
}
flash
::
lower_priority
();
}
if
constexpr
(
STAGES
==
2
)
{
#if defined(__gfx938__) || defined(__gfx946__)
prefetch_v_to_lds_mls_ds
<
kHeadDimV
,
kBlockM
,
kBlockK
,
kBlockN
,
WARP_M
,
kBlockK
,
2
,
Element
,
Is_even_MN
>
(
v_ptr
,
v_lds
,
warp_id
,
seqlen_v_stride
,
max_seq_k_offset
);
#else
#endif
}
}
// qk_gemm
template
<
int
kHeadDim
,
int
kHeadDimV
,
int
kBlockM
,
int
kBlockN
,
int
kBlockK
,
int
WARP_M
,
int
WARP_N
,
int
STAGES
,
typename
Element
,
typename
ElementAccum
,
bool
Is_even_MN
>
__forceinline__
__device__
void
qk_gemm_prefetch_v_mls_ds_gfx92a
(
vec4_uint
k_ptr
,
vec4_uint
v_ptr
,
Element
*
k_lds
,
Element
*
v_lds
,
union_vec4_f16x2
<
Element
>
q_reg
[(
kHeadDim
/
kBlockK
)
*
(
WARP_M
*
kBlockK
)
/
(
32
*
32
)
*
2
],
vec4_Accum
<
ElementAccum
>
s_reg
[(
WARP_M
/
32
)
*
(
kBlockN
/
32
)][
4
],
int
warp_id
,
int
seqlen_k_stride
,
int
seqlen_v_stride
,
int
max_seq_k_offset
=
0
)
{
// Simplify
static_assert
(
kBlockK
==
32
and
"To simplify, only kBlockK = 32 is supported!"
);
static_assert
(
WARP_M
==
32
and
"To simplify, only WARP_M = 32 is supported!"
);
static_assert
(
WARP_N
==
32
and
"To simplify, only WARP_N = 32 is supported!"
);
constexpr
int
WARP_NUM
=
kBlockM
/
WARP_M
;
constexpr
int
k_lds_load_num
=
WARP_N
*
kHeadDim
/
(
32
*
32
);
constexpr
int
K_LOAD_REQUESTS
=
k_lds_load_num
/
WARP_NUM
;
constexpr
int
ELEMENT_BYTES
=
sizeof
(
Element
);
// Prepare regs for k
union_vec4_f16x2
<
Element
>
k_reg
[
STAGES
*
(
WARP_N
*
kBlockK
)
/
(
32
*
32
)
*
2
];
// Zero-initialize s_reg
__builtin_amdgcn_sched_barrier
(
0
);
if
constexpr
(
kBlockN
==
128
)
{
inline_vgpr4_init_zero_4x4x4
(
s_reg
);
}
else
{
for
(
int
i
=
0
;
i
<
(
WARP_M
/
32
)
*
(
kBlockN
/
32
);
++
i
)
{
// for kBlockN = 64, only wave 0 get the right QK gemm results
for
(
int
j
=
0
;
j
<
4
;
++
j
)
{
s_reg
[
i
][
j
].
u64
[
0
]
=
0
;
s_reg
[
i
][
j
].
u64
[
1
]
=
0
;
}
}
}
__builtin_amdgcn_sched_barrier
(
0
);
// Prepare MLS buffer resource sregs
vec4_uint
k_srsrc
;
k_srsrc
[
2
]
=
seqlen_k_stride
;
// stride
k_srsrc
[
3
]
=
0
;
int
n_stage_id
=
1
;
#pragma unroll
for
(
int
n_loop
=
1
;
n_loop
<
(
kBlockN
/
WARP_N
);
++
n_loop
)
{
// MLS dispatch
const
bool
has_tail
=
max_seq_k_offset
%
kBlockN
!=
0
;
const
int
nm_filter_max
=
n_loop
*
WARP_N
+
32
-
max_seq_k_offset
;
const
int
k_load_loop
=
has_tail
&&
nm_filter_max
>=
32
?
0
:
n_loop
;
const
int
nm_filter
=
inline_min_max
<
0
,
31
>
(
k_load_loop
*
WARP_N
+
32
-
max_seq_k_offset
);
const
int
__nm_filter
=
__builtin_amdgcn_readfirstlane
(
nm_filter
);
*
(
uint64_t
*
)
&
k_srsrc
=
VA_LIMIT_BITS
(
*
(
uint64_t
*
)
&
k_ptr
+
(
k_load_loop
*
WARP_N
*
seqlen_k_stride
+
warp_id
*
32
)
*
ELEMENT_BYTES
);
k_srsrc
[
3
]
=
has_tail
?
__nm_filter
<<
8
:
0
;
// set only once
int
lds_write_offset
=
(
n_stage_id
*
WARP_N
*
kHeadDim
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
;
flash
::
wait_all_warp_arrived
();
// sync lds usage when ping-pong
inline_matrix_load_32x32_b16_lds_trans
<
0
,
0
>
(
k_lds
,
k_srsrc
,
lds_write_offset
,
0
);
// Wait MLS
n_stage_id
^=
1
;
int
stage_id
=
0
;
flash
::
wait_buffer_data_arrived
<
true
>
(
K_LOAD_REQUESTS
);
// DS dispatch
{
constexpr
int
k_loop
=
0
;
int
lds_load_offset
=
(
n_stage_id
*
WARP_N
*
kHeadDim
+
k_loop
*
32
*
32
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16_GFX946
(
lds_load_offset
,
k_reg
[
stage_id
*
2
].
f16
,
k_reg
[
stage_id
*
2
+
1
].
f16
,
true
);
}
stage_id
^=
1
;
for
(
int
k_loop
=
1
;
k_loop
<
(
kHeadDim
/
kBlockK
);
++
k_loop
)
{
// DS dispatch
int
lds_load_offset
=
(
n_stage_id
*
WARP_N
*
kHeadDim
+
k_loop
*
32
*
32
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16_GFX946
(
lds_load_offset
,
k_reg
[
stage_id
*
2
].
f16
,
k_reg
[
stage_id
*
2
+
1
].
f16
,
true
);
// Wait DS
flash
::
wait_lds_data_arrived
<
false
>
(
3
);
// MMAC
asm
volatile
(
"s_setprio 2"
);
stage_id
^=
1
;
{
int
min_tile_n
=
0
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
++
min_tile_m
)
{
int
k_loop_idx
=
(
STAGES
==
2
)
?
k_loop
-
1
:
k_loop
;
int
q_tile_id
=
k_loop_idx
*
2
+
min_tile_m
;
int
k_tile_id
=
stage_id
*
2
+
min_tile_n
;
s_reg
[
n_loop
-
1
][
min_tile_n
*
2
+
min_tile_m
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
q_tile_id
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
n_loop
-
1
][
min_tile_n
*
2
+
min_tile_m
].
f32
);
}
}
}
flash
::
wait_lds_data_arrived
<
false
>
(
2
);
{
int
min_tile_n
=
1
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
++
min_tile_m
)
{
int
k_loop_idx
=
(
STAGES
==
2
)
?
k_loop
-
1
:
k_loop
;
int
q_tile_id
=
k_loop_idx
*
2
+
min_tile_m
;
int
k_tile_id
=
stage_id
*
2
+
min_tile_n
;
s_reg
[
n_loop
-
1
][
min_tile_n
*
2
+
min_tile_m
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
q_tile_id
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
n_loop
-
1
][
min_tile_n
*
2
+
min_tile_m
].
f32
);
}
}
}
asm
volatile
(
"s_setprio 0"
);
}
stage_id
^=
1
;
// Wait DS
flash
::
wait_lds_data_arrived
<
false
>
(
1
);
asm
volatile
(
"s_setprio 2"
);
// MMAC
{
int
min_tile_n
=
0
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
++
min_tile_m
)
{
int
k_loop_idx
=
kHeadDim
/
kBlockK
-
1
;
int
q_tile_id
=
k_loop_idx
*
2
+
min_tile_m
;
int
k_tile_id
=
stage_id
*
2
+
min_tile_n
;
s_reg
[
n_loop
-
1
][
min_tile_n
*
2
+
min_tile_m
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
q_tile_id
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
n_loop
-
1
][
min_tile_n
*
2
+
min_tile_m
].
f32
);
}
}
}
flash
::
wait_lds_data_arrived
<
false
>
(
0
);
{
int
min_tile_n
=
1
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
++
min_tile_m
)
{
int
k_loop_idx
=
kHeadDim
/
kBlockK
-
1
;
int
q_tile_id
=
k_loop_idx
*
2
+
min_tile_m
;
int
k_tile_id
=
stage_id
*
2
+
min_tile_n
;
s_reg
[
n_loop
-
1
][
min_tile_n
*
2
+
min_tile_m
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
q_tile_id
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
n_loop
-
1
][
min_tile_n
*
2
+
min_tile_m
].
f32
);
}
}
}
asm
volatile
(
"s_setprio 0"
);
}
{
// Wait MLS
constexpr
int
n_loop
=
4
;
n_stage_id
^=
1
;
int
stage_id
=
0
;
flash
::
wait_buffer_data_arrived
<
true
>
(
0
);
// DS dispatch
{
int
k_loop
=
0
;
int
lds_load_offset
=
(
n_stage_id
*
WARP_N
*
kHeadDim
+
k_loop
*
32
*
32
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16_GFX946
(
lds_load_offset
,
k_reg
[
stage_id
*
2
].
f16
,
k_reg
[
stage_id
*
2
+
1
].
f16
,
true
);
}
stage_id
^=
1
;
for
(
int
k_loop
=
1
;
k_loop
<
(
kHeadDim
/
kBlockK
);
++
k_loop
)
{
// DS dispatch
int
lds_load_offset
=
(
n_stage_id
*
WARP_N
*
kHeadDim
+
k_loop
*
32
*
32
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16_GFX946
(
lds_load_offset
,
k_reg
[
stage_id
*
2
].
f16
,
k_reg
[
stage_id
*
2
+
1
].
f16
,
true
);
// Wait DS
flash
::
wait_lds_data_arrived
<
false
>
(
3
);
// MMAC
asm
volatile
(
"s_setprio 2"
);
stage_id
^=
1
;
{
int
min_tile_n
=
0
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
++
min_tile_m
)
{
int
k_loop_idx
=
(
STAGES
==
2
)
?
k_loop
-
1
:
k_loop
;
int
q_tile_id
=
k_loop_idx
*
2
+
min_tile_m
;
int
k_tile_id
=
stage_id
*
2
+
min_tile_n
;
s_reg
[
n_loop
-
1
][
min_tile_n
*
2
+
min_tile_m
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
q_tile_id
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
n_loop
-
1
][
min_tile_n
*
2
+
min_tile_m
].
f32
);
}
}
}
flash
::
wait_lds_data_arrived
<
false
>
(
2
);
{
int
min_tile_n
=
1
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
++
min_tile_m
)
{
int
k_loop_idx
=
(
STAGES
==
2
)
?
k_loop
-
1
:
k_loop
;
int
q_tile_id
=
k_loop_idx
*
2
+
min_tile_m
;
int
k_tile_id
=
stage_id
*
2
+
min_tile_n
;
s_reg
[
n_loop
-
1
][
min_tile_n
*
2
+
min_tile_m
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
q_tile_id
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
n_loop
-
1
][
min_tile_n
*
2
+
min_tile_m
].
f32
);
}
}
}
asm
volatile
(
"s_setprio 0"
);
}
stage_id
^=
1
;
flash
::
wait_lds_data_arrived
<
false
>
(
1
);
// MMAC
asm
volatile
(
"s_setprio 2"
);
// flash::raise_priority 性能下降严重, 157.8 -> 148.2 tflops, strange, 需要看汇编
{
// 对比汇编, 差异就在于单独的 s_setprio 会被胡乱调度到 mmac 中间, 但这样跑出来却性能更高; 强行加 scheduled barrier 跑出来性能更低;
int
min_tile_n
=
0
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
++
min_tile_m
)
{
int
k_loop_idx
=
kHeadDim
/
kBlockK
-
1
;
int
q_tile_id
=
k_loop_idx
*
2
+
min_tile_m
;
int
k_tile_id
=
stage_id
*
2
+
min_tile_n
;
s_reg
[
n_loop
-
1
][
min_tile_n
*
2
+
min_tile_m
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
q_tile_id
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
n_loop
-
1
][
min_tile_n
*
2
+
min_tile_m
].
f32
);
}
}
}
flash
::
wait_lds_data_arrived
<
false
>
(
0
);
{
int
min_tile_n
=
1
;
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
++
min_tile_m
)
{
int
k_loop_idx
=
kHeadDim
/
kBlockK
-
1
;
int
q_tile_id
=
k_loop_idx
*
2
+
min_tile_m
;
int
k_tile_id
=
stage_id
*
2
+
min_tile_n
;
s_reg
[
n_loop
-
1
][
min_tile_n
*
2
+
min_tile_m
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
q_reg
[
q_tile_id
].
f16x4
[
min_tile_k
],
k_reg
[
k_tile_id
].
f16x4
[
min_tile_k
],
s_reg
[
n_loop
-
1
][
min_tile_n
*
2
+
min_tile_m
].
f32
);
}
}
}
asm
volatile
(
"s_setprio 0"
);
// flash::lower_priority 性能下降 0.6 tflops 左右
}
if
constexpr
(
STAGES
==
2
)
{
prefetch_v_to_lds_mls_ds
<
kHeadDimV
,
kBlockM
,
kBlockK
,
kBlockN
,
WARP_M
,
kBlockK
,
2
,
Element
,
Is_even_MN
>
(
v_ptr
,
v_lds
,
warp_id
,
seqlen_v_stride
,
max_seq_k_offset
);
}
}
// qk_gemm
csrc/flash_attn_hg/include/fwd/gfx92a/qk_gemm_utils_mls_ds_gfx92a.h
0 → 100644
View file @
518a5f4d
#pragma once
#include "fwd/gfx938/qk_gemm_utils_mls_ds.h"
template
<
int
kHeadDim
,
int
kBlockM
,
int
kBlockK
,
int
WARP_M
,
typename
Element
,
bool
Is_even_MN
>
__forceinline__
__device__
void
prefetch_q_to_vgpr_mls_ds_gfx92a
(
vec4_uint
q_ptr
,
Element
*
q_lds
,
union_vec4_f16x2
<
Element
>
q_reg
[(
kHeadDim
/
kBlockK
)
*
(
WARP_M
*
kBlockK
)
/
(
32
*
32
)
*
2
],
int
warp_id
,
int
seqlen_q_stride
,
int
max_seq_q_offset
=
0
)
{
constexpr
int
WARP_NUM
=
kBlockM
/
WARP_M
;
constexpr
int
Q_LDS_LOAD_NUM
=
kBlockM
*
kBlockK
/
(
32
*
32
);
constexpr
int
Q_LOAD_REQUESTS
=
Q_LDS_LOAD_NUM
/
WARP_NUM
;
constexpr
int
ELEMENT_BYTES
=
sizeof
(
Element
);
int
q_lds_base
=
reinterpret_cast
<
size_t
>
(
q_lds
);
flash
::
wait_lds_data_arrived
<
true
>
(
0
);
vec4_uint
q_srsrc
;
q_srsrc
[
2
]
=
seqlen_q_stride
;
const
int
q_row
=
warp_id
*
32
;
const
int
nm_filter
=
q_row
+
32
-
max_seq_q_offset
;
const
bool
has_tail
=
max_seq_q_offset
%
kBlockM
!=
0
;
const
int
q_load_row
=
has_tail
&&
nm_filter
>=
32
?
0
:
q_row
;
// gfx92a has a 5-bit MLS filter field, so never encode 32.
const
int
q_filter
=
inline_min_max
<
0
,
31
>
(
q_load_row
+
32
-
max_seq_q_offset
);
q_srsrc
[
3
]
=
has_tail
?
q_filter
<<
8
:
0
;
int
stage_id
=
0
;
{
int
k_loop
=
0
;
*
(
uint64_t
*
)
&
q_srsrc
=
VA_LIMIT_BITS
(
*
(
uint64_t
*
)
&
q_ptr
+
(
k_loop
*
kBlockK
+
q_load_row
*
seqlen_q_stride
)
*
ELEMENT_BYTES
);
int
lds_offset
=
(
stage_id
*
kBlockM
*
kBlockK
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
;
inline_matrix_load_32x32_b16_lds_trans
<
0
,
1
>
(
q_lds
,
q_srsrc
,
lds_offset
,
0
);
}
stage_id
^=
1
;
#pragma unroll
for
(
int
k_loop
=
1
;
k_loop
<
(
kHeadDim
/
kBlockK
);
++
k_loop
)
{
*
(
uint64_t
*
)
&
q_srsrc
=
VA_LIMIT_BITS
(
*
(
uint64_t
*
)
&
q_ptr
+
(
k_loop
*
kBlockK
+
q_load_row
*
seqlen_q_stride
)
*
ELEMENT_BYTES
);
int
lds_offset
=
(
stage_id
*
kBlockM
*
kBlockK
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
;
inline_matrix_load_32x32_b16_lds_trans
<
0
,
1
>
(
q_lds
,
q_srsrc
,
lds_offset
,
0
);
stage_id
^=
1
;
buffer_load_lds_dwordx1_wait
<
Q_LOAD_REQUESTS
>
();
__builtin_amdgcn_sched_barrier
(
0
);
int
lds_load_offset
=
q_lds_base
+
(
stage_id
*
kBlockM
*
kBlockK
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16_GFX946
(
lds_load_offset
,
q_reg
[(
k_loop
-
1
)
*
2
].
f16
,
q_reg
[(
k_loop
-
1
)
*
2
+
1
].
f16
,
true
);
flash
::
wait_lds_data_arrived
<
true
>
(
0
);
}
{
stage_id
^=
1
;
buffer_load_lds_dwordx1_wait
<
0
>
();
__builtin_amdgcn_sched_barrier
(
0
);
constexpr
int
k_loop
=
kHeadDim
/
kBlockK
-
1
;
int
lds_load_offset
=
q_lds_base
+
(
stage_id
*
kBlockM
*
kBlockK
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
;
DS_READ_MATRIX_32X32_B16_GFX946
(
lds_load_offset
,
q_reg
[
k_loop
*
2
].
f16
,
q_reg
[
k_loop
*
2
+
1
].
f16
,
true
);
}
__builtin_amdgcn_s_waitcnt
(
0
);
flash
::
wait_lds_data_arrived
<
true
>
(
0
);
}
template
<
int
kHeadDim
,
int
kBlockN
,
int
kBlockK
,
int
WARP_NUM
,
int
WARP_N
,
typename
Element
,
bool
Is_even_MN
>
__forceinline__
__device__
void
prefetch_k_to_lds_mls_ds_gfx92a
(
vec4_uint
k_ptr
,
Element
*
k_lds
,
int
warp_id
,
int
seqlen_k_stride
,
int
max_seq_k_offset
=
0
)
{
flash
::
wait_all_warp_arrived
();
constexpr
int
ELEMENT_BYTES
=
sizeof
(
Element
);
#pragma unroll
for
(
int
n_loop
=
0
;
n_loop
<
kBlockN
/
WARP_N
;
++
n_loop
)
{
vec4_uint
k_srsrc
;
k_srsrc
[
2
]
=
seqlen_k_stride
;
const
bool
has_tail
=
max_seq_k_offset
%
kBlockN
!=
0
;
const
int
nm_filter_max
=
n_loop
*
WARP_N
+
32
-
max_seq_k_offset
;
const
int
k_load_loop
=
has_tail
&&
nm_filter_max
>=
32
?
0
:
n_loop
;
const
int
nm_filter
=
inline_min_max
<
0
,
31
>
(
k_load_loop
*
WARP_N
+
32
-
max_seq_k_offset
);
*
(
uint64_t
*
)
&
k_srsrc
=
VA_LIMIT_BITS
(
*
(
uint64_t
*
)
&
k_ptr
+
(
k_load_loop
*
WARP_N
*
seqlen_k_stride
+
warp_id
*
32
)
*
ELEMENT_BYTES
);
k_srsrc
[
3
]
=
has_tail
?
nm_filter
<<
8
:
0
;
int
lds_offset
=
(
n_loop
*
WARP_N
*
kHeadDim
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
;
inline_matrix_load_32x32_b16_lds_trans
<
0
,
0
>
(
k_lds
,
k_srsrc
,
lds_offset
,
0
);
}
__builtin_amdgcn_sched_barrier
(
0
);
}
csrc/flash_attn_hg/include/fwd/gfx92a/softmax_gfx92a.h
0 → 100644
View file @
518a5f4d
#pragma once
#include "philox.cuh"
#include "../utils.h"
using
namespace
flash
;
template
<
typename
DataType
,
int
WARP_M
,
int
WARP_N
>
inline
__device__
void
apply_mask_gfx92a
(
DataType
tensor
[(
WARP_M
/
32
)
*
(
WARP_N
/
32
)][
4
],
const
int
max_seqlen_k
,
const
int
col_idx_offset_
=
0
)
{
const
int
lane_id
=
threadIdx
.
x
&
63
;
const
int
col_idx_offset
=
col_idx_offset_
+
(
lane_id
>>
4
);
#pragma unroll
for
(
int
ni
=
0
;
ni
<
(
WARP_N
/
32
);
++
ni
)
{
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
const
int
col_idx_base
=
col_idx_offset
+
ni
*
32
+
min_tile_n
*
16
;
#pragma unroll
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
++
vec_idx
)
{
const
int
col_idx
=
col_idx_base
+
vec_idx
*
4
;
if
(
col_idx
>=
max_seqlen_k
)
{
#pragma unroll
for
(
int
mi
=
0
;
mi
<
(
WARP_M
/
32
);
++
mi
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
++
min_tile_m
)
{
tensor
[
mi
+
ni
*
(
WARP_M
/
32
)][
min_tile_n
*
2
+
min_tile_m
].
f32
[
vec_idx
]
=
-
INFINITY
;
}
}
}
}
}
}
}
template
<
typename
DataType
,
int
WARP_M
,
int
WARP_N
>
inline
__device__
void
apply_mask_causal_gfx92a
(
DataType
tensor
[(
WARP_M
/
32
)
*
(
WARP_N
/
32
)][
4
],
const
int
col_idx_offset_
,
const
int
max_seqlen_k
,
const
int
row_idx_offset_
,
const
int
max_seqlen_q
)
{
const
int
lane_id
=
threadIdx
.
x
&
63
;
const
int
row_idx_offset
=
row_idx_offset_
+
(
lane_id
&
15
);
const
int
col_idx_offset
=
col_idx_offset_
+
(
lane_id
>>
4
);
#pragma unroll
for
(
int
mi
=
0
;
mi
<
(
WARP_M
/
32
);
++
mi
)
{
const
int
row_idx_base
=
row_idx_offset
+
mi
*
32
;
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
++
min_tile_m
)
{
const
int
row_idx
=
row_idx_base
+
min_tile_m
*
16
;
const
int
col_idx_limit_right
=
std
::
min
(
max_seqlen_k
,
row_idx
+
max_seqlen_k
-
max_seqlen_q
);
#pragma unroll
for
(
int
ni
=
0
;
ni
<
(
WARP_N
/
32
);
++
ni
)
{
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
const
int
col_idx_base
=
col_idx_offset
+
ni
*
32
+
min_tile_n
*
16
;
#pragma unroll
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
++
vec_idx
)
{
const
int
col_idx
=
col_idx_base
+
vec_idx
*
4
;
tensor
[
mi
+
ni
*
(
WARP_M
/
32
)][
min_tile_n
*
2
+
min_tile_m
].
f32
[
vec_idx
]
=
(
col_idx
>
col_idx_limit_right
)
?
-
INFINITY
:
tensor
[
mi
+
ni
*
(
WARP_M
/
32
)][
min_tile_n
*
2
+
min_tile_m
].
f32
[
vec_idx
];
}
}
}
}
}
}
template
<
bool
HasWSLeft
=
true
,
typename
DataType
,
int
WARP_M
,
int
WARP_N
>
inline
__device__
void
apply_mask_local_gfx92a
(
DataType
tensor
[(
WARP_M
/
32
)
*
(
WARP_N
/
32
)][
4
],
const
int
col_idx_offset_
,
const
int
max_seqlen_k
,
const
int
row_idx_offset_
,
const
int
max_seqlen_q
,
const
int
window_size_left
,
const
int
window_size_right
)
{
const
int
lane_id
=
threadIdx
.
x
&
63
;
const
int
row_idx_offset
=
row_idx_offset_
+
(
lane_id
&
15
);
const
int
col_idx_offset
=
col_idx_offset_
+
(
lane_id
>>
4
);
#pragma unroll
for
(
int
mi
=
0
;
mi
<
(
WARP_M
/
32
);
++
mi
)
{
const
int
row_idx_base
=
row_idx_offset
+
mi
*
32
;
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
++
min_tile_m
)
{
const
int
row_idx
=
row_idx_base
+
min_tile_m
*
16
;
const
int
col_idx_limit_left
=
std
::
max
(
0
,
row_idx
+
1
+
max_seqlen_k
-
max_seqlen_q
-
window_size_left
);
const
int
col_idx_limit_right
=
std
::
min
(
max_seqlen_k
,
row_idx
+
max_seqlen_k
-
max_seqlen_q
+
window_size_right
);
#pragma unroll
for
(
int
ni
=
0
;
ni
<
(
WARP_N
/
32
);
++
ni
)
{
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
const
int
col_idx_base
=
col_idx_offset
+
ni
*
32
+
min_tile_n
*
16
;
#pragma unroll
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
++
vec_idx
)
{
const
int
col_idx
=
col_idx_base
+
vec_idx
*
4
;
tensor
[
mi
+
ni
*
(
WARP_M
/
32
)][
min_tile_n
*
2
+
min_tile_m
].
f32
[
vec_idx
]
=
(
col_idx
>
col_idx_limit_right
||
(
HasWSLeft
&&
col_idx
<
(
col_idx_limit_left
-
1
)))
?
-
INFINITY
:
tensor
[
mi
+
ni
*
(
WARP_M
/
32
)][
min_tile_n
*
2
+
min_tile_m
].
f32
[
vec_idx
];
}
}
}
}
}
}
template
<
typename
DataType
,
int
WARP_M
,
int
WARP_N
>
inline
__device__
void
apply_alibi_gfx92a
(
DataType
tensor
[(
WARP_M
/
32
)
*
(
WARP_N
/
32
)][
4
],
const
int
col_idx_offset_
,
const
int
max_seqlen_k
,
const
int
row_idx_offset_
,
const
int
max_seqlen_q
,
float
g_alibi
)
{
const
int
lane_id
=
threadIdx
.
x
&
63
;
const
int
row_idx_offset
=
row_idx_offset_
+
(
lane_id
&
15
);
const
int
col_idx_offset
=
col_idx_offset_
+
(
lane_id
>>
4
);
#pragma unroll
for
(
int
mi
=
0
;
mi
<
(
WARP_M
/
32
);
++
mi
)
{
const
int
row_idx_base
=
row_idx_offset
+
mi
*
32
;
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
++
min_tile_m
)
{
const
int
row_idx
=
row_idx_base
+
min_tile_m
*
16
;
#pragma unroll
for
(
int
ni
=
0
;
ni
<
(
WARP_N
/
32
);
++
ni
)
{
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
const
int
col_idx_base
=
col_idx_offset
+
ni
*
32
+
min_tile_n
*
16
;
#pragma unroll
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
++
vec_idx
)
{
const
int
col_idx
=
col_idx_base
+
vec_idx
*
4
;
tensor
[
mi
+
ni
*
(
WARP_M
/
32
)][
min_tile_n
*
2
+
min_tile_m
].
f32
[
vec_idx
]
+=
g_alibi
*
(
col_idx
-
row_idx
);
}
}
}
}
}
}
\ No newline at end of file
csrc/flash_attn_hg/include/fwd/gfx938/fp8_epilogue.h
0 → 100644
View file @
518a5f4d
#include "numeric_types.h"
#include "intrinsic.h"
__forceinline__
__device__
float
fp8_attention_sink_load
(
const
void
*
s_aux_ptr
,
int
s_aux_type
,
int
head_idx
)
{
if
(
s_aux_type
==
1
)
{
return
reinterpret_cast
<
const
float
*>
(
s_aux_ptr
)[
head_idx
];
}
else
if
(
s_aux_type
==
2
)
{
return
UpCast
<
half_t
,
float
>
(
reinterpret_cast
<
const
half_t
*>
(
s_aux_ptr
)[
head_idx
]);
}
else
{
return
UpCast
<
BFloat16
,
float
>
(
reinterpret_cast
<
const
BFloat16
*>
(
s_aux_ptr
)[
head_idx
]);
}
}
template
<
int
kHeadDim
,
int
WARP_M
,
int
WARP_N
,
typename
ElementAccum
>
__forceinline__
__device__
void
fp8_attention_sink_apply
(
vec4_Accum
<
ElementAccum
>
acc_o
[
kHeadDim
/
32
][
WARP_M
/
16
][
WARP_N
/
16
],
ElementAccum
scores_max
[
WARP_M
/
16
],
ElementAccum
scores_sum
[
WARP_M
/
16
],
ElementAccum
softmax_scale
,
float
sink_value
)
{
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
WARP_M
/
16
;
++
m_idx
)
{
const
ElementAccum
old_scaled_max
=
scores_max
[
m_idx
]
*
softmax_scale
;
const
ElementAccum
new_scaled_max
=
max
(
old_scaled_max
,
ElementAccum
(
sink_value
));
const
ElementAccum
old_rescale
=
__expf
(
old_scaled_max
-
new_scaled_max
);
scores_sum
[
m_idx
]
=
scores_sum
[
m_idx
]
*
old_rescale
+
__expf
(
ElementAccum
(
sink_value
)
-
new_scaled_max
);
scores_max
[
m_idx
]
=
new_scaled_max
/
softmax_scale
;
__float2
old_rescale_pair
;
old_rescale_pair
[
0
]
=
old_rescale
;
old_rescale_pair
[
1
]
=
old_rescale
;
#pragma unroll
for
(
int
k_loop
=
0
;
k_loop
<
kHeadDim
/
32
;
++
k_loop
)
{
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
WARP_N
/
16
;
++
n_idx
)
{
acc_o
[
k_loop
][
m_idx
][
n_idx
].
u64
[
0
]
=
__builtin_hcu_pk_mul_f32
(
acc_o
[
k_loop
][
m_idx
][
n_idx
].
u64
[
0
],
old_rescale_pair
);
acc_o
[
k_loop
][
m_idx
][
n_idx
].
u64
[
1
]
=
__builtin_hcu_pk_mul_f32
(
acc_o
[
k_loop
][
m_idx
][
n_idx
].
u64
[
1
],
old_rescale_pair
);
}
}
}
}
template
<
bool
AssumeValidRows
,
int
kHeadDim
,
int
WARP_M
,
int
WARP_N
,
bool
StoreLSE
,
typename
ElementAccum
>
__forceinline__
__device__
void
fp8_epilogue_rescale_acc_o
(
vec4_Accum
<
ElementAccum
>
acc_o
[
kHeadDim
/
32
][
WARP_M
/
16
][
WARP_N
/
16
],
ElementAccum
scores_max
[
WARP_M
/
16
],
ElementAccum
scores_sum
[
WARP_M
/
16
],
ElementAccum
lse
[
WARP_M
/
16
],
ElementAccum
softmax_scale
,
ElementAccum
v_descale
)
{
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
WARP_M
/
16
;
++
m_idx
)
{
ElementAccum
sum
=
scores_sum
[
m_idx
];
if
constexpr
(
StoreLSE
)
{
lse
[
m_idx
]
=
(
sum
==
0.
f
||
sum
!=
sum
)
?
INFINITY
:
__llvm_fma_f32
(
scores_max
[
m_idx
],
softmax_scale
,
__logf
(
sum
));
}
ElementAccum
total_rescale
;
if
constexpr
(
AssumeValidRows
)
{
total_rescale
=
v_descale
/
sum
;
}
else
{
total_rescale
=
(
sum
==
0.
f
||
sum
!=
sum
)
?
0.
f
:
v_descale
/
sum
;
}
__float2
total_scale_pair
;
total_scale_pair
[
0
]
=
total_rescale
;
total_scale_pair
[
1
]
=
total_rescale
;
// __float2 inv_sum_pair;
// inv_sum_pair[0] = 1.0f / sum;
// inv_sum_pair[1] = inv_sum_pair[0];
#pragma unroll
for
(
int
k_loop
=
0
;
k_loop
<
kHeadDim
/
32
;
++
k_loop
)
{
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
WARP_N
/
16
;
++
n_idx
)
{
acc_o
[
k_loop
][
m_idx
][
n_idx
].
u64
[
0
]
=
__builtin_hcu_pk_mul_f32
(
acc_o
[
k_loop
][
m_idx
][
n_idx
].
u64
[
0
],
total_scale_pair
);
acc_o
[
k_loop
][
m_idx
][
n_idx
].
u64
[
1
]
=
__builtin_hcu_pk_mul_f32
(
acc_o
[
k_loop
][
m_idx
][
n_idx
].
u64
[
1
],
total_scale_pair
);
}
}
}
}
template
<
bool
Is_even_MN
,
int
WARP_M
,
typename
ElementAccum
>
__forceinline__
__device__
void
fp8_epilogue_store_lse
(
// ElementAccum* scores_max_ptr,
// ElementAccum* scores_sum_ptr,
ElementAccum
*
softmax_lse_ptr
,
ElementAccum
scores_max
[
WARP_M
/
16
],
ElementAccum
scores_sum
[
WARP_M
/
16
],
ElementAccum
lse
[
WARP_M
/
16
],
int
row_offset_lse
,
/*(bidb * h + bidh) * actual_seqlen_q*/
int
actual_seqlen_q
,
int
wave_row_offset
,
/*m_block * kBlockM + warp_id * WARP_M*/
int
lane_id
)
{
if
(
lane_id
<
16
)
{
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
WARP_M
/
16
;
++
m_idx
)
{
int
lse_row_id
=
row_offset_lse
+
wave_row_offset
+
((
lane_id
&
15
)
>>
2
)
*
8
+
m_idx
*
4
+
(
lane_id
&
3
);;
// scores_max_ptr[lse_row_id] = scores_max[m_idx];
// scores_sum_ptr[lse_row_id] = scores_sum[m_idx];
if
(
lse_row_id
-
row_offset_lse
<
actual_seqlen_q
){
softmax_lse_ptr
[
lse_row_id
]
=
lse
[
m_idx
];
}
}
}
}
template
<
bool
Is_even_MN
,
int
kBlockM
,
int
kHeadDim
,
int
WARP_M
,
int
WARP_N
,
typename
Element
,
typename
ElementAccum
>
__forceinline__
__device__
void
fp8_epilogue_store_output
(
Element
*
acc_o_ptr
,
vec4_Accum
<
ElementAccum
>
acc_o
[
kHeadDim
/
32
][
WARP_M
/
16
][
WARP_N
/
16
],
int
m_block
,
int
warp_id
,
int
lane_id
,
int
o_row_stride
,
int
actual_seqlen_q
)
{
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
WARP_M
/
16
;
++
m_idx
)
{
#pragma unroll
for
(
int
pv_loop
=
0
;
pv_loop
<
kHeadDim
/
32
;
++
pv_loop
)
{
#pragma unroll
for
(
int
mmac_id
=
0
;
mmac_id
<
2
;
++
mmac_id
)
{
int
row_idx
=
warp_id
*
WARP_M
+
((
lane_id
&
15
)
>>
2
)
*
8
+
m_idx
*
4
+
(
lane_id
&
3
);
int
col_idx
=
pv_loop
*
32
+
mmac_id
*
16
+
(
lane_id
>>
4
)
*
4
;
int
offset
=
row_idx
*
o_row_stride
+
col_idx
;
union_vec2_f16x2
<
Element
>
v_data
;
#pragma unroll
for
(
int
vec_index
=
0
;
vec_index
<
2
;
++
vec_index
)
{
v_data
.
f16x2
[
vec_index
]
=
DownCastPair
<
ElementAccum
,
Element
>
(
acc_o
[
pv_loop
][
m_idx
][
mmac_id
].
f32x2
[
vec_index
]);
// v_data = __builtin_hcu_cvt_pk_f16_f32(acc_o[pv_loop][m_idx][mmac_id].f32[vec_index*2], acc_o[pv_loop][m_idx][mmac_id].f32[vec_index*2+1], false/*clamp*/, 0/*omod*/);
// v_data[0] = DownCast<ElementAccum, Float16>(acc_o[pv_loop][m_idx][mmac_id].f32[vec_index*2]);
// v_data[1] = DownCast<ElementAccum, Float16>(acc_o[pv_loop][m_idx][mmac_id].f32[vec_index*2+1]);
}
if
constexpr
(
Is_even_MN
)
{
*
(
union_vec2_f16x2
<
Element
>*
)(
acc_o_ptr
+
offset
)
=
v_data
;
}
else
if
(
m_block
*
kBlockM
+
row_idx
<
actual_seqlen_q
)
{
*
(
union_vec2_f16x2
<
Element
>*
)(
acc_o_ptr
+
offset
)
=
v_data
;
}
// *(vec4_fp32*)(acc_o_ptr + offset) = acc_o[pv_loop][m_idx][mmac_id].f32;
}
}
}
}
csrc/flash_attn_hg/include/fwd/gfx938/fp8_pv_gemm_prefetch_k_mls_ds.h
0 → 100644
View file @
518a5f4d
#include "fp8_qk_gemm_utils_mls_ds.h"
#include "static_switch.h"
// PrefetchK=false 版本:不 prefetch K,不需要额外参数
template
<
bool
PrefetchK
,
bool
Is_even_MN
,
int
kHeadDimQK
,
int
kHeadDimV
,
int
kBlockN
,
int
WARP_M
,
int
WARP_N
,
typename
Element
,
typename
ElementAccum
>
__forceinline__
__device__
void
fp8_pv_gemm_and_prefetch_k
(
vec4_Accum
<
ElementAccum
>
acc_o
[
kHeadDimV
/
32
][
WARP_M
/
16
][
WARP_N
/
16
],
union_vec32_fp8
p_reg
[
WARP_M
/
16
],
union_vec16_fp8
v_regs
[
kBlockN
/
WARP_N
][
kHeadDimV
/
32
],
int8_t
*
v_lds
,
Element
*&
k_ptr
,
int8_t
*
k_lds
,
int
warp_id
,
int
k_row_stride
,
int
max_seq_kv_offset
)
{
// 等待从 lds 的数据返回
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n
"
);
__builtin_amdgcn_sched_barrier
(
0
);
if
constexpr
(
PrefetchK
)
{
k_ptr
+=
kBlockN
*
k_row_stride
;
fp8_prefetch_k_to_lds
<
Is_even_MN
,
kHeadDimQK
,
WARP_N
,
Element
>
(
k_ptr
,
k_lds
,
warp_id
,
k_row_stride
,
max_seq_kv_offset
);
}
// mmac stream
#pragma unroll
for
(
int
k_loop
=
0
;
k_loop
<
kBlockN
/
WARP_N
;
k_loop
+=
1
)
{
#pragma unroll
for
(
int
pv_loop
=
0
;
pv_loop
<
kHeadDimV
/
32
;
++
pv_loop
)
{
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
2
;
++
m_idx
)
{
#pragma unroll
for
(
int
mmac_id
=
0
;
mmac_id
<
2
;
++
mmac_id
)
{
acc_o
[
pv_loop
][
m_idx
][
mmac_id
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
p_reg
[
m_idx
].
i8x8
[
k_loop
],
v_regs
[
k_loop
][
pv_loop
].
i8x8
[
mmac_id
],
acc_o
[
pv_loop
][
m_idx
][
mmac_id
].
f32
);
}
}
}
}
__builtin_amdgcn_sched_barrier
(
0
);
}
// PrefetchK=true 版本:在 PV MMAC 期间 prefetch 下一块 K(paged KV)
template
<
bool
Is_even_MN
,
int
kHeadDimQK
,
int
kHeadDimV
,
int
kBlockN
,
int
WARP_M
,
int
WARP_N
,
typename
Element
,
typename
ElementAccum
>
__forceinline__
__device__
void
fp8_pv_gemm_and_prefetch_k_paged
(
vec4_Accum
<
ElementAccum
>
acc_o
[
kHeadDimV
/
32
][
WARP_M
/
16
][
WARP_N
/
16
],
union_vec32_fp8
p_reg
[
WARP_M
/
16
],
union_vec16_fp8
v_regs
[
kBlockN
/
WARP_N
][
kHeadDimV
/
32
],
int8_t
*
v_lds
,
Element
*
k_ptr_next
,
int8_t
*
k_lds
,
int
warp_id
,
int
k_row_stride
,
int
max_seq_kv_offset_next
)
{
// 等待从 lds 的数据返回
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n
"
);
__builtin_amdgcn_sched_barrier
(
0
);
// Prefetch 下一块 K 到 LDS(与 MMAC 重叠)
__builtin_amdgcn_sched_barrier
(
0
);
fp8_prefetch_k_to_lds
<
Is_even_MN
,
kHeadDimQK
,
WARP_N
,
Element
>
(
k_ptr_next
,
k_lds
,
warp_id
,
k_row_stride
,
max_seq_kv_offset_next
);
__builtin_amdgcn_sched_barrier
(
0
);
// mmac stream
#pragma unroll
for
(
int
k_loop
=
0
;
k_loop
<
kBlockN
/
WARP_N
;
k_loop
+=
1
)
{
#pragma unroll
for
(
int
pv_loop
=
0
;
pv_loop
<
kHeadDimV
/
32
;
++
pv_loop
)
{
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
2
;
++
m_idx
)
{
#pragma unroll
for
(
int
mmac_id
=
0
;
mmac_id
<
2
;
++
mmac_id
)
{
acc_o
[
pv_loop
][
m_idx
][
mmac_id
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
p_reg
[
m_idx
].
i8x8
[
k_loop
],
v_regs
[
k_loop
][
pv_loop
].
i8x8
[
mmac_id
],
acc_o
[
pv_loop
][
m_idx
][
mmac_id
].
f32
);
}
}
}
}
__builtin_amdgcn_sched_barrier
(
0
);
}
csrc/flash_attn_hg/include/fwd/gfx938/fp8_pv_gemm_utils_mls_ds.h
0 → 100644
View file @
518a5f4d
#pragma once // prepare for prefetch V in qk gemm
#include "intrinsic_mls_ds.h"
template
<
bool
Is_even_MN
,
int
kBlockN
,
int
kHeadDim
,
int
WARP_N
,
typename
Element
>
__forceinline__
__device__
void
fp8_prefetch_v_to_lds
(
Element
*
v_ptr
,
int8_t
*
v_lds
,
int
warp_id
,
int
v_row_stride
,
int
max_seq_kv_offset
)
{
static_assert
(
kHeadDim
==
128
||
kHeadDim
==
256
);
// 准备 MLS 寄存器, 填充 stride
vec4_uint
v_root
=
prepare_for_matrix_load
<
kHeadDim
,
Element
>
(
v_ptr
);
vec4_uint
v_srsrc
;
v_srsrc
[
0
]
=
v_root
[
0
];
v_srsrc
[
1
]
=
v_root
[
1
];
v_srsrc
[
2
]
=
v_row_stride
;
// stride
v_srsrc
[
3
]
=
0x00000
;
// 4 个 wave 直接全量预取
int
v_lds_write_bytes
=
warp_id
*
WARP_N
*
kHeadDim
*
sizeof
(
Element
);
// 每次读取 32x128 的数据
// tile1: 行 [warp_id*32, warp_id*32+16)
// 整个 tile 被 filter 时保留一行合法 V,避免 0 * NaN 污染 PV。
int
nm_filter_warp0_tile1
=
inline_min_max
<
0
,
16
>
(
16
-
max_seq_kv_offset
);
int
nm_filter
=
inline_min_max
<
0
,
16
>
(
32
*
warp_id
+
16
-
max_seq_kv_offset
);
v_srsrc
[
0
]
=
(
nm_filter
==
16
)
?
v_root
[
0
]
:
v_root
[
0
]
+
(
warp_id
*
2
)
*
16
*
v_row_stride
*
sizeof
(
Element
);
nm_filter
=
(
nm_filter
==
16
)
?
min
(
nm_filter_warp0_tile1
,
15
)
:
nm_filter
;
v_srsrc
[
3
]
=
v_srsrc
[
3
]
+
((
max_seq_kv_offset
%
128
==
0
)
?
0
:
(
nm_filter
<<
8
));
flash
::
wait_all_warp_arrived
();
__builtin_hcu_matrix_load_128X16_b8
(
v_srsrc
,
v_lds
+
v_lds_write_bytes
,
0
,
true
,
false
,
false
,
false
,
false
);
if
constexpr
(
kHeadDim
==
256
)
{
v_srsrc
[
0
]
+=
128
*
sizeof
(
Element
);
__builtin_hcu_matrix_load_128X16_b8
(
v_srsrc
,
v_lds
+
v_lds_write_bytes
+
4096
,
0
,
true
,
false
,
false
,
false
,
false
);
}
// tile2: 行 [warp_id*32+16, warp_id*32+32)
int
nm_filter_warp0_tile2
=
inline_min_max
<
0
,
16
>
(
32
-
max_seq_kv_offset
);
nm_filter
=
inline_min_max
<
0
,
16
>
(
32
*
warp_id
+
32
-
max_seq_kv_offset
);
v_srsrc
[
0
]
=
(
nm_filter
==
16
)
?
v_root
[
0
]
:
v_root
[
0
]
+
(
warp_id
*
2
+
1
)
*
16
*
v_row_stride
*
sizeof
(
Element
);
nm_filter
=
(
nm_filter
==
16
)
?
min
(
nm_filter_warp0_tile2
,
15
)
:
nm_filter
;
v_srsrc
[
3
]
=
((
max_seq_kv_offset
%
128
==
0
)
?
0
:
(
nm_filter
<<
8
));
__builtin_hcu_matrix_load_128X16_b8
(
v_srsrc
,
v_lds
+
v_lds_write_bytes
+
(
128
*
16
>>
1
),
0
,
true
,
false
,
false
,
false
,
false
);
if
constexpr
(
kHeadDim
==
256
)
{
v_srsrc
[
0
]
+=
128
*
sizeof
(
Element
);
__builtin_hcu_matrix_load_128X16_b8
(
v_srsrc
,
v_lds
+
v_lds_write_bytes
+
4096
+
(
128
*
16
>>
1
),
0
,
true
,
false
,
false
,
false
,
false
);
}
}
csrc/flash_attn_hg/include/fwd/gfx938/fp8_qk_gemm_prefetch_v_mls_ds.h
0 → 100644
View file @
518a5f4d
#pragma once
#include "fp8_pv_gemm_utils_mls_ds.h"
// #define USE_DS_READ_B128_FOR_INTERLEAVE4
template
<
int
kBlockN
,
int
kHeadDim
,
int
WARP_M
,
int
WARP_N
,
typename
Element
,
typename
ElementAccum
>
__forceinline__
__device__
void
fp8_qk_gemm
(
vec4_Accum
<
ElementAccum
>
s_reg
[
kBlockN
/
WARP_N
][
WARP_M
/
16
][
WARP_N
/
16
],
union_vec16_fp8
q_regs
[
WARP_M
/
16
][
kHeadDim
/
64
],
int8_t
*
k_lds
)
{
static_assert
(
kHeadDim
==
128
||
kHeadDim
==
192
||
kHeadDim
==
256
);
constexpr
int
kLdsHeadDimStride
=
kHeadDim
==
192
?
256
:
kHeadDim
;
int
tx
=
threadIdx
.
x
;
int
lane_id
=
tx
&
63
;
int
row
=
(
lane_id
&
15
)
>>
1
;
int
col
=
lane_id
>>
4
;
int
col_swizzle
=
(
row
+
col
)
&
3
;
// 等待 K 的数据都写到 lds 了, 4 wave 同步
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n
s_barrier
\n
"
);
// hint: pv mmac 太短, 不够隐藏这一段时延, 可以考虑把 vmcnt 拆细一点, 先等一部分数据回来计算也可以
__builtin_amdgcn_sched_barrier
(
0
);
// __syncthreads();
if
constexpr
(
true
)
{
// 直接从 lds 读数据, 看看 lds 的数据排布
union_vec16_fp8
k_regs
[
kBlockN
/
WARP_N
][
WARP_N
/
16
][
kHeadDim
/
64
];
#pragma unroll
for
(
int
k_loop
=
0
;
k_loop
<
kBlockN
/
WARP_N
;
++
k_loop
)
{
// 分两次读取寄存器, 第一次是 [0, 1024), 即 16x64 的内容, 每个线程读取 16 个 fp8
#ifdef USE_DS_READ_B128_FOR_INTERLEAVE4
int
k_lds_load_offset
=
row
*
128
+
col_swizzle
*
16
+
(
lane_id
&
1
)
*
64
+
k_loop
*
WARP_N
*
kHeadDim
;
inline_ds_read_b128_no_wait_bytes
(
reinterpret_cast
<
size_t
>
(
k_lds
)
+
k_lds_load_offset
+
0
,
k_regs
[
k_loop
][
0
][
0
].
i32x4
);
inline_ds_read_b128_no_wait_bytes
(
reinterpret_cast
<
size_t
>
(
k_lds
)
+
k_lds_load_offset
+
1024
,
k_regs
[
k_loop
][
1
][
0
].
i32x4
);
inline_ds_read_b128_no_wait_bytes
(
reinterpret_cast
<
size_t
>
(
k_lds
)
+
k_lds_load_offset
+
2048
,
k_regs
[
k_loop
][
0
][
1
].
i32x4
);
inline_ds_read_b128_no_wait_bytes
(
reinterpret_cast
<
size_t
>
(
k_lds
)
+
k_lds_load_offset
+
3072
,
k_regs
[
k_loop
][
1
][
1
].
i32x4
);
#else
int
k_lds_load_offset
=
k_loop
*
WARP_N
*
kLdsHeadDimStride
;
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 0, k_regs[k_loop][0][0].i32x4, true/*transpose*/)
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 1024, k_regs[k_loop][1][0].i32x4, true/*transpose*/)
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 2048, k_regs[k_loop][0][1].i32x4, true/*transpose*/)
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 3072, k_regs[k_loop][1][1].i32x4, true/*transpose*/)
#pragma unroll
for
(
int
h_idx
=
0
;
h_idx
<
kHeadDim
/
64
;
++
h_idx
)
{
const
int
h_offset
=
(
kHeadDim
==
192
&&
h_idx
==
2
)
?
3
*
2048
:
h_idx
*
2048
;
k_regs
[
k_loop
][
0
][
h_idx
].
i32x4
=
__builtin_hcu_ds_read_matrix_trans_format_u8
(
k_lds
+
k_lds_load_offset
+
h_offset
+
0
,
0
,
3
,
1
,
0
);
k_regs
[
k_loop
][
1
][
h_idx
].
i32x4
=
__builtin_hcu_ds_read_matrix_trans_format_u8
(
k_lds
+
k_lds_load_offset
+
h_offset
+
1024
,
0
,
3
,
1
,
0
);
}
#endif
}
// init s_reg
#pragma unroll
for
(
int
k_loop
=
0
;
k_loop
<
kBlockN
/
WARP_N
;
++
k_loop
)
{
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
WARP_M
/
16
;
++
m_idx
)
{
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
WARP_N
/
16
;
++
n_idx
)
{
inline_vgpr4_init_zero
(
s_reg
[
k_loop
][
m_idx
][
n_idx
]);
}
}
}
#pragma unroll
for
(
int
k_loop
=
0
;
k_loop
<
kBlockN
/
WARP_N
;
++
k_loop
)
{
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt lgkmcnt(%0)
\n
"
::
"B"
((
kBlockN
/
WARP_N
-
k_loop
-
1
)
*
4
));
// asm volatile("s_waitcnt lgkmcnt(0)\n");
__builtin_amdgcn_sched_barrier
(
0
);
// ======================================================== QK mmac ======================================================================
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
WARP_M
/
16
;
++
m_idx
)
{
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
WARP_N
/
16
;
++
n_idx
)
{
#pragma unroll
for
(
int
h_idx
=
0
;
h_idx
<
kHeadDim
/
64
;
++
h_idx
)
{
s_reg
[
k_loop
][
m_idx
][
n_idx
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
m_idx
][
h_idx
].
i8x8
[
0
],
k_regs
[
k_loop
][
n_idx
][
h_idx
].
i8x8
[
0
],
s_reg
[
k_loop
][
m_idx
][
n_idx
].
f32
);
s_reg
[
k_loop
][
m_idx
][
n_idx
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
m_idx
][
h_idx
].
i8x8
[
1
],
k_regs
[
k_loop
][
n_idx
][
h_idx
].
i8x8
[
1
],
s_reg
[
k_loop
][
m_idx
][
n_idx
].
f32
);
}
}
}
}
}
else
if
constexpr
(
false
and
WARP_M
==
32
and
WARP_N
==
32
and
kBlockN
==
128
and
kHeadDim
==
128
)
{
union_vec16_fp8
k_regs
[
WARP_N
/
16
][
kHeadDim
/
64
];
{
constexpr
int
k_loop
=
0
;
// 分两次读取寄存器, 第一次是 [0, 1024), 即 16x64 的内容, 每个线程读取 16 个 fp8
#ifdef USE_DS_READ_B128_FOR_INTERLEAVE4
int
k_lds_load_offset
=
row
*
128
+
col_swizzle
*
16
+
(
lane_id
&
1
)
*
64
+
k_loop
*
WARP_N
*
kHeadDim
;
inline_ds_read_b128_no_wait_bytes
(
reinterpret_cast
<
size_t
>
(
k_lds
)
+
k_lds_load_offset
+
0
,
k_regs
[
0
][
0
].
i32x4
);
inline_ds_read_b128_no_wait_bytes
(
reinterpret_cast
<
size_t
>
(
k_lds
)
+
k_lds_load_offset
+
1024
,
k_regs
[
1
][
0
].
i32x4
);
inline_ds_read_b128_no_wait_bytes
(
reinterpret_cast
<
size_t
>
(
k_lds
)
+
k_lds_load_offset
+
2048
,
k_regs
[
0
][
1
].
i32x4
);
inline_ds_read_b128_no_wait_bytes
(
reinterpret_cast
<
size_t
>
(
k_lds
)
+
k_lds_load_offset
+
3072
,
k_regs
[
1
][
1
].
i32x4
);
#else
int
k_lds_load_offset
=
k_loop
*
WARP_N
*
kHeadDim
;
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 0, k_regs[0][0].i32x4, true/*transpose*/)
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 1024, k_regs[1][0].i32x4, true/*transpose*/)
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 2048, k_regs[0][1].i32x4, true/*transpose*/)
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 3072, k_regs[1][1].i32x4, true/*transpose*/)
k_regs
[
0
][
0
].
i32x4
=
__builtin_hcu_ds_read_matrix_trans_format_u8
(
k_lds
+
k_lds_load_offset
+
0
,
0
,
3
,
1
,
0
);
k_regs
[
1
][
0
].
i32x4
=
__builtin_hcu_ds_read_matrix_trans_format_u8
(
k_lds
+
k_lds_load_offset
+
1024
,
0
,
3
,
1
,
0
);
k_regs
[
0
][
1
].
i32x4
=
__builtin_hcu_ds_read_matrix_trans_format_u8
(
k_lds
+
k_lds_load_offset
+
2048
,
0
,
3
,
1
,
0
);
k_regs
[
1
][
1
].
i32x4
=
__builtin_hcu_ds_read_matrix_trans_format_u8
(
k_lds
+
k_lds_load_offset
+
3072
,
0
,
3
,
1
,
0
);
#endif
// init s_reg
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
2
;
++
m_idx
)
{
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
2
;
++
n_idx
)
{
inline_vgpr4_init_zero
(
s_reg
[
k_loop
][
m_idx
][
n_idx
]);
}
}
// ======================================================== QK mmac ======================================================================
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt lgkmcnt(2)
\n
"
);
// asm volatile("s_waitcnt lgkmcnt(0)\n");
__builtin_amdgcn_sched_barrier
(
0
);
s_reg
[
k_loop
][
0
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
0
].
i8x8
[
0
],
k_regs
[
0
][
0
].
i8x8
[
0
],
s_reg
[
k_loop
][
0
][
0
].
f32
);
s_reg
[
k_loop
][
0
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
0
].
i8x8
[
1
],
k_regs
[
0
][
0
].
i8x8
[
1
],
s_reg
[
k_loop
][
0
][
0
].
f32
);
s_reg
[
k_loop
][
1
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
0
].
i8x8
[
0
],
k_regs
[
0
][
0
].
i8x8
[
0
],
s_reg
[
k_loop
][
1
][
0
].
f32
);
s_reg
[
k_loop
][
1
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
0
].
i8x8
[
1
],
k_regs
[
0
][
0
].
i8x8
[
1
],
s_reg
[
k_loop
][
1
][
0
].
f32
);
s_reg
[
k_loop
][
0
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
0
].
i8x8
[
0
],
k_regs
[
1
][
0
].
i8x8
[
0
],
s_reg
[
k_loop
][
0
][
1
].
f32
);
s_reg
[
k_loop
][
0
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
0
].
i8x8
[
1
],
k_regs
[
1
][
0
].
i8x8
[
1
],
s_reg
[
k_loop
][
0
][
1
].
f32
);
s_reg
[
k_loop
][
1
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
0
].
i8x8
[
0
],
k_regs
[
1
][
0
].
i8x8
[
0
],
s_reg
[
k_loop
][
1
][
1
].
f32
);
s_reg
[
k_loop
][
1
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
0
].
i8x8
[
1
],
k_regs
[
1
][
0
].
i8x8
[
1
],
s_reg
[
k_loop
][
1
][
1
].
f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n
"
);
__builtin_amdgcn_sched_barrier
(
0
);
// 寄存器不够全量预取, 则预取一部分
#ifdef USE_DS_READ_B128_FOR_INTERLEAVE4
inline_ds_read_b128_no_wait_bytes
(
reinterpret_cast
<
size_t
>
(
k_lds
)
+
k_lds_load_offset
+
0
+
WARP_N
*
kHeadDim
,
k_regs
[
0
][
0
].
i32x4
);
inline_ds_read_b128_no_wait_bytes
(
reinterpret_cast
<
size_t
>
(
k_lds
)
+
k_lds_load_offset
+
1024
+
WARP_N
*
kHeadDim
,
k_regs
[
1
][
0
].
i32x4
);
#else
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 0 + WARP_N * kHeadDim, k_regs[0][0].i32x4, true/*transpose*/)
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 1024 + WARP_N * kHeadDim, k_regs[1][0].i32x4, true/*transpose*/)
k_regs
[
0
][
0
].
i32x4
=
__builtin_hcu_ds_read_matrix_trans_format_u8
(
k_lds
+
k_lds_load_offset
+
0
,
0
,
3
,
1
,
0
);
k_regs
[
1
][
0
].
i32x4
=
__builtin_hcu_ds_read_matrix_trans_format_u8
(
k_lds
+
k_lds_load_offset
+
1024
,
0
,
3
,
1
,
0
);
#endif
s_reg
[
k_loop
][
0
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
1
].
i8x8
[
0
],
k_regs
[
0
][
1
].
i8x8
[
0
],
s_reg
[
k_loop
][
0
][
0
].
f32
);
s_reg
[
k_loop
][
0
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
1
].
i8x8
[
1
],
k_regs
[
0
][
1
].
i8x8
[
1
],
s_reg
[
k_loop
][
0
][
0
].
f32
);
s_reg
[
k_loop
][
1
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
1
].
i8x8
[
0
],
k_regs
[
0
][
1
].
i8x8
[
0
],
s_reg
[
k_loop
][
1
][
0
].
f32
);
s_reg
[
k_loop
][
1
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
1
].
i8x8
[
1
],
k_regs
[
0
][
1
].
i8x8
[
1
],
s_reg
[
k_loop
][
1
][
0
].
f32
);
s_reg
[
k_loop
][
0
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
1
].
i8x8
[
0
],
k_regs
[
1
][
1
].
i8x8
[
0
],
s_reg
[
k_loop
][
0
][
1
].
f32
);
s_reg
[
k_loop
][
0
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
1
].
i8x8
[
1
],
k_regs
[
1
][
1
].
i8x8
[
1
],
s_reg
[
k_loop
][
0
][
1
].
f32
);
s_reg
[
k_loop
][
1
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
1
].
i8x8
[
0
],
k_regs
[
1
][
1
].
i8x8
[
0
],
s_reg
[
k_loop
][
1
][
1
].
f32
);
s_reg
[
k_loop
][
1
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
1
].
i8x8
[
1
],
k_regs
[
1
][
1
].
i8x8
[
1
],
s_reg
[
k_loop
][
1
][
1
].
f32
);
}
{
constexpr
int
k_loop
=
1
;
// 分两次读取寄存器, 第一次是 [0, 1024), 即 16x64 的内容, 每个线程读取 16 个 fp8
#ifdef USE_DS_READ_B128_FOR_INTERLEAVE4
int
k_lds_load_offset
=
row
*
128
+
col_swizzle
*
16
+
(
lane_id
&
1
)
*
64
+
k_loop
*
WARP_N
*
kHeadDim
;
inline_ds_read_b128_no_wait_bytes
(
reinterpret_cast
<
size_t
>
(
k_lds
)
+
k_lds_load_offset
+
2048
,
k_regs
[
0
][
1
].
i32x4
);
inline_ds_read_b128_no_wait_bytes
(
reinterpret_cast
<
size_t
>
(
k_lds
)
+
k_lds_load_offset
+
3072
,
k_regs
[
1
][
1
].
i32x4
);
#else
int
k_lds_load_offset
=
k_loop
*
WARP_N
*
kHeadDim
;
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 2048, k_regs[0][1].i32x4, true/*transpose*/)
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 3072, k_regs[1][1].i32x4, true/*transpose*/)
k_regs
[
0
][
1
].
i32x4
=
__builtin_hcu_ds_read_matrix_trans_format_u8
(
k_lds
+
k_lds_load_offset
+
2048
,
0
,
3
,
1
,
0
);
k_regs
[
1
][
1
].
i32x4
=
__builtin_hcu_ds_read_matrix_trans_format_u8
(
k_lds
+
k_lds_load_offset
+
3072
,
0
,
3
,
1
,
0
);
#endif
// init s_reg
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
2
;
++
m_idx
)
{
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
2
;
++
n_idx
)
{
inline_vgpr4_init_zero
(
s_reg
[
k_loop
][
m_idx
][
n_idx
]);
}
}
// ======================================================== QK mmac ======================================================================
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt lgkmcnt(2)
\n
"
);
// asm volatile("s_waitcnt lgkmcnt(0)\n");
__builtin_amdgcn_sched_barrier
(
0
);
s_reg
[
k_loop
][
0
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
0
].
i8x8
[
0
],
k_regs
[
0
][
0
].
i8x8
[
0
],
s_reg
[
k_loop
][
0
][
0
].
f32
);
s_reg
[
k_loop
][
0
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
0
].
i8x8
[
1
],
k_regs
[
0
][
0
].
i8x8
[
1
],
s_reg
[
k_loop
][
0
][
0
].
f32
);
s_reg
[
k_loop
][
1
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
0
].
i8x8
[
0
],
k_regs
[
0
][
0
].
i8x8
[
0
],
s_reg
[
k_loop
][
1
][
0
].
f32
);
s_reg
[
k_loop
][
1
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
0
].
i8x8
[
1
],
k_regs
[
0
][
0
].
i8x8
[
1
],
s_reg
[
k_loop
][
1
][
0
].
f32
);
s_reg
[
k_loop
][
0
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
0
].
i8x8
[
0
],
k_regs
[
1
][
0
].
i8x8
[
0
],
s_reg
[
k_loop
][
0
][
1
].
f32
);
s_reg
[
k_loop
][
0
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
0
].
i8x8
[
1
],
k_regs
[
1
][
0
].
i8x8
[
1
],
s_reg
[
k_loop
][
0
][
1
].
f32
);
s_reg
[
k_loop
][
1
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
0
].
i8x8
[
0
],
k_regs
[
1
][
0
].
i8x8
[
0
],
s_reg
[
k_loop
][
1
][
1
].
f32
);
s_reg
[
k_loop
][
1
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
0
].
i8x8
[
1
],
k_regs
[
1
][
0
].
i8x8
[
1
],
s_reg
[
k_loop
][
1
][
1
].
f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n
"
);
__builtin_amdgcn_sched_barrier
(
0
);
// 寄存器不够全量预取, 则预取一部分
#ifdef USE_DS_READ_B128_FOR_INTERLEAVE4
inline_ds_read_b128_no_wait_bytes
(
reinterpret_cast
<
size_t
>
(
k_lds
)
+
k_lds_load_offset
+
0
+
WARP_N
*
kHeadDim
,
k_regs
[
0
][
0
].
i32x4
);
inline_ds_read_b128_no_wait_bytes
(
reinterpret_cast
<
size_t
>
(
k_lds
)
+
k_lds_load_offset
+
1024
+
WARP_N
*
kHeadDim
,
k_regs
[
1
][
0
].
i32x4
);
#else
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 0 + WARP_N * kHeadDim, k_regs[0][0].i32x4, true/*transpose*/)
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 1024 + WARP_N * kHeadDim, k_regs[1][0].i32x4, true/*transpose*/)
k_regs
[
0
][
0
].
i32x4
=
__builtin_hcu_ds_read_matrix_trans_format_u8
(
k_lds
+
k_lds_load_offset
+
0
,
0
,
3
,
1
,
0
);
k_regs
[
1
][
0
].
i32x4
=
__builtin_hcu_ds_read_matrix_trans_format_u8
(
k_lds
+
k_lds_load_offset
+
1024
,
0
,
3
,
1
,
0
);
#endif
s_reg
[
k_loop
][
0
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
1
].
i8x8
[
0
],
k_regs
[
0
][
1
].
i8x8
[
0
],
s_reg
[
k_loop
][
0
][
0
].
f32
);
s_reg
[
k_loop
][
0
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
1
].
i8x8
[
1
],
k_regs
[
0
][
1
].
i8x8
[
1
],
s_reg
[
k_loop
][
0
][
0
].
f32
);
s_reg
[
k_loop
][
1
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
1
].
i8x8
[
0
],
k_regs
[
0
][
1
].
i8x8
[
0
],
s_reg
[
k_loop
][
1
][
0
].
f32
);
s_reg
[
k_loop
][
1
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
1
].
i8x8
[
1
],
k_regs
[
0
][
1
].
i8x8
[
1
],
s_reg
[
k_loop
][
1
][
0
].
f32
);
s_reg
[
k_loop
][
0
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
1
].
i8x8
[
0
],
k_regs
[
1
][
1
].
i8x8
[
0
],
s_reg
[
k_loop
][
0
][
1
].
f32
);
s_reg
[
k_loop
][
0
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
1
].
i8x8
[
1
],
k_regs
[
1
][
1
].
i8x8
[
1
],
s_reg
[
k_loop
][
0
][
1
].
f32
);
s_reg
[
k_loop
][
1
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
1
].
i8x8
[
0
],
k_regs
[
1
][
1
].
i8x8
[
0
],
s_reg
[
k_loop
][
1
][
1
].
f32
);
s_reg
[
k_loop
][
1
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
1
].
i8x8
[
1
],
k_regs
[
1
][
1
].
i8x8
[
1
],
s_reg
[
k_loop
][
1
][
1
].
f32
);
}
{
constexpr
int
k_loop
=
2
;
// 分两次读取寄存器, 第一次是 [0, 1024), 即 16x64 的内容, 每个线程读取 16 个 fp8
#ifdef USE_DS_READ_B128_FOR_INTERLEAVE4
int
k_lds_load_offset
=
row
*
128
+
col_swizzle
*
16
+
(
lane_id
&
1
)
*
64
+
k_loop
*
WARP_N
*
kHeadDim
;
inline_ds_read_b128_no_wait_bytes
(
reinterpret_cast
<
size_t
>
(
k_lds
)
+
k_lds_load_offset
+
2048
,
k_regs
[
0
][
1
].
i32x4
);
inline_ds_read_b128_no_wait_bytes
(
reinterpret_cast
<
size_t
>
(
k_lds
)
+
k_lds_load_offset
+
3072
,
k_regs
[
1
][
1
].
i32x4
);
#else
int
k_lds_load_offset
=
k_loop
*
WARP_N
*
kHeadDim
;
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 2048, k_regs[0][1].i32x4, true/*transpose*/)
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 3072, k_regs[1][1].i32x4, true/*transpose*/)
k_regs
[
0
][
1
].
i32x4
=
__builtin_hcu_ds_read_matrix_trans_format_u8
(
k_lds
+
k_lds_load_offset
+
2048
,
0
,
3
,
1
,
0
);
k_regs
[
1
][
1
].
i32x4
=
__builtin_hcu_ds_read_matrix_trans_format_u8
(
k_lds
+
k_lds_load_offset
+
3072
,
0
,
3
,
1
,
0
);
#endif
// init s_reg
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
2
;
++
m_idx
)
{
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
2
;
++
n_idx
)
{
inline_vgpr4_init_zero
(
s_reg
[
k_loop
][
m_idx
][
n_idx
]);
}
}
// ======================================================== QK mmac ======================================================================
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt lgkmcnt(2)
\n
"
);
// asm volatile("s_waitcnt lgkmcnt(0)\n");
__builtin_amdgcn_sched_barrier
(
0
);
s_reg
[
k_loop
][
0
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
0
].
i8x8
[
0
],
k_regs
[
0
][
0
].
i8x8
[
0
],
s_reg
[
k_loop
][
0
][
0
].
f32
);
s_reg
[
k_loop
][
0
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
0
].
i8x8
[
1
],
k_regs
[
0
][
0
].
i8x8
[
1
],
s_reg
[
k_loop
][
0
][
0
].
f32
);
s_reg
[
k_loop
][
1
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
0
].
i8x8
[
0
],
k_regs
[
0
][
0
].
i8x8
[
0
],
s_reg
[
k_loop
][
1
][
0
].
f32
);
s_reg
[
k_loop
][
1
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
0
].
i8x8
[
1
],
k_regs
[
0
][
0
].
i8x8
[
1
],
s_reg
[
k_loop
][
1
][
0
].
f32
);
s_reg
[
k_loop
][
0
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
0
].
i8x8
[
0
],
k_regs
[
1
][
0
].
i8x8
[
0
],
s_reg
[
k_loop
][
0
][
1
].
f32
);
s_reg
[
k_loop
][
0
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
0
].
i8x8
[
1
],
k_regs
[
1
][
0
].
i8x8
[
1
],
s_reg
[
k_loop
][
0
][
1
].
f32
);
s_reg
[
k_loop
][
1
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
0
].
i8x8
[
0
],
k_regs
[
1
][
0
].
i8x8
[
0
],
s_reg
[
k_loop
][
1
][
1
].
f32
);
s_reg
[
k_loop
][
1
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
0
].
i8x8
[
1
],
k_regs
[
1
][
0
].
i8x8
[
1
],
s_reg
[
k_loop
][
1
][
1
].
f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n
"
);
__builtin_amdgcn_sched_barrier
(
0
);
// 寄存器不够全量预取, 则预取一部分
#ifdef USE_DS_READ_B128_FOR_INTERLEAVE4
inline_ds_read_b128_no_wait_bytes
(
reinterpret_cast
<
size_t
>
(
k_lds
)
+
k_lds_load_offset
+
0
+
WARP_N
*
kHeadDim
,
k_regs
[
0
][
0
].
i32x4
);
inline_ds_read_b128_no_wait_bytes
(
reinterpret_cast
<
size_t
>
(
k_lds
)
+
k_lds_load_offset
+
1024
+
WARP_N
*
kHeadDim
,
k_regs
[
1
][
0
].
i32x4
);
#else
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 0 + WARP_N * kHeadDim, k_regs[0][0].i32x4, true/*transpose*/)
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 1024 + WARP_N * kHeadDim, k_regs[1][0].i32x4, true/*transpose*/)
k_regs
[
0
][
0
].
i32x4
=
__builtin_hcu_ds_read_matrix_trans_format_u8
(
k_lds
+
k_lds_load_offset
+
0
,
0
,
3
,
1
,
0
);
k_regs
[
1
][
0
].
i32x4
=
__builtin_hcu_ds_read_matrix_trans_format_u8
(
k_lds
+
k_lds_load_offset
+
1024
,
0
,
3
,
1
,
0
);
#endif
s_reg
[
k_loop
][
0
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
1
].
i8x8
[
0
],
k_regs
[
0
][
1
].
i8x8
[
0
],
s_reg
[
k_loop
][
0
][
0
].
f32
);
s_reg
[
k_loop
][
0
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
1
].
i8x8
[
1
],
k_regs
[
0
][
1
].
i8x8
[
1
],
s_reg
[
k_loop
][
0
][
0
].
f32
);
s_reg
[
k_loop
][
1
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
1
].
i8x8
[
0
],
k_regs
[
0
][
1
].
i8x8
[
0
],
s_reg
[
k_loop
][
1
][
0
].
f32
);
s_reg
[
k_loop
][
1
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
1
].
i8x8
[
1
],
k_regs
[
0
][
1
].
i8x8
[
1
],
s_reg
[
k_loop
][
1
][
0
].
f32
);
s_reg
[
k_loop
][
0
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
1
].
i8x8
[
0
],
k_regs
[
1
][
1
].
i8x8
[
0
],
s_reg
[
k_loop
][
0
][
1
].
f32
);
s_reg
[
k_loop
][
0
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
1
].
i8x8
[
1
],
k_regs
[
1
][
1
].
i8x8
[
1
],
s_reg
[
k_loop
][
0
][
1
].
f32
);
s_reg
[
k_loop
][
1
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
1
].
i8x8
[
0
],
k_regs
[
1
][
1
].
i8x8
[
0
],
s_reg
[
k_loop
][
1
][
1
].
f32
);
s_reg
[
k_loop
][
1
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
1
].
i8x8
[
1
],
k_regs
[
1
][
1
].
i8x8
[
1
],
s_reg
[
k_loop
][
1
][
1
].
f32
);
}
{
constexpr
int
k_loop
=
3
;
// 分两次读取寄存器, 第一次是 [0, 1024), 即 16x64 的内容, 每个线程读取 16 个 fp8
#ifdef USE_DS_READ_B128_FOR_INTERLEAVE4
int
k_lds_load_offset
=
row
*
128
+
col_swizzle
*
16
+
(
lane_id
&
1
)
*
64
+
k_loop
*
WARP_N
*
kHeadDim
;
inline_ds_read_b128_no_wait_bytes
(
reinterpret_cast
<
size_t
>
(
k_lds
)
+
k_lds_load_offset
+
2048
,
k_regs
[
0
][
1
].
i32x4
);
inline_ds_read_b128_no_wait_bytes
(
reinterpret_cast
<
size_t
>
(
k_lds
)
+
k_lds_load_offset
+
3072
,
k_regs
[
1
][
1
].
i32x4
);
#else
int
k_lds_load_offset
=
k_loop
*
WARP_N
*
kHeadDim
;
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 2048, k_regs[0][1].i32x4, true/*transpose*/)
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 3072, k_regs[1][1].i32x4, true/*transpose*/)
k_regs
[
0
][
1
].
i32x4
=
__builtin_hcu_ds_read_matrix_trans_format_u8
(
k_lds
+
k_lds_load_offset
+
2048
,
0
,
3
,
1
,
0
);
k_regs
[
1
][
1
].
i32x4
=
__builtin_hcu_ds_read_matrix_trans_format_u8
(
k_lds
+
k_lds_load_offset
+
3072
,
0
,
3
,
1
,
0
);
#endif
// init s_reg
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
2
;
++
m_idx
)
{
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
2
;
++
n_idx
)
{
inline_vgpr4_init_zero
(
s_reg
[
k_loop
][
m_idx
][
n_idx
]);
}
}
// ======================================================== QK mmac ======================================================================
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt lgkmcnt(2)
\n
"
);
// asm volatile("s_waitcnt lgkmcnt(0)\n");
__builtin_amdgcn_sched_barrier
(
0
);
s_reg
[
k_loop
][
0
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
0
].
i8x8
[
0
],
k_regs
[
0
][
0
].
i8x8
[
0
],
s_reg
[
k_loop
][
0
][
0
].
f32
);
s_reg
[
k_loop
][
0
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
0
].
i8x8
[
1
],
k_regs
[
0
][
0
].
i8x8
[
1
],
s_reg
[
k_loop
][
0
][
0
].
f32
);
s_reg
[
k_loop
][
1
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
0
].
i8x8
[
0
],
k_regs
[
0
][
0
].
i8x8
[
0
],
s_reg
[
k_loop
][
1
][
0
].
f32
);
s_reg
[
k_loop
][
1
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
0
].
i8x8
[
1
],
k_regs
[
0
][
0
].
i8x8
[
1
],
s_reg
[
k_loop
][
1
][
0
].
f32
);
s_reg
[
k_loop
][
0
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
0
].
i8x8
[
0
],
k_regs
[
1
][
0
].
i8x8
[
0
],
s_reg
[
k_loop
][
0
][
1
].
f32
);
s_reg
[
k_loop
][
0
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
0
].
i8x8
[
1
],
k_regs
[
1
][
0
].
i8x8
[
1
],
s_reg
[
k_loop
][
0
][
1
].
f32
);
s_reg
[
k_loop
][
1
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
0
].
i8x8
[
0
],
k_regs
[
1
][
0
].
i8x8
[
0
],
s_reg
[
k_loop
][
1
][
1
].
f32
);
s_reg
[
k_loop
][
1
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
0
].
i8x8
[
1
],
k_regs
[
1
][
0
].
i8x8
[
1
],
s_reg
[
k_loop
][
1
][
1
].
f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n
"
);
__builtin_amdgcn_sched_barrier
(
0
);
s_reg
[
k_loop
][
0
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
1
].
i8x8
[
0
],
k_regs
[
0
][
1
].
i8x8
[
0
],
s_reg
[
k_loop
][
0
][
0
].
f32
);
s_reg
[
k_loop
][
0
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
1
].
i8x8
[
1
],
k_regs
[
0
][
1
].
i8x8
[
1
],
s_reg
[
k_loop
][
0
][
0
].
f32
);
s_reg
[
k_loop
][
1
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
1
].
i8x8
[
0
],
k_regs
[
0
][
1
].
i8x8
[
0
],
s_reg
[
k_loop
][
1
][
0
].
f32
);
s_reg
[
k_loop
][
1
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
1
].
i8x8
[
1
],
k_regs
[
0
][
1
].
i8x8
[
1
],
s_reg
[
k_loop
][
1
][
0
].
f32
);
s_reg
[
k_loop
][
0
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
1
].
i8x8
[
0
],
k_regs
[
1
][
1
].
i8x8
[
0
],
s_reg
[
k_loop
][
0
][
1
].
f32
);
s_reg
[
k_loop
][
0
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
1
].
i8x8
[
1
],
k_regs
[
1
][
1
].
i8x8
[
1
],
s_reg
[
k_loop
][
0
][
1
].
f32
);
s_reg
[
k_loop
][
1
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
1
].
i8x8
[
0
],
k_regs
[
1
][
1
].
i8x8
[
0
],
s_reg
[
k_loop
][
1
][
1
].
f32
);
s_reg
[
k_loop
][
1
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
1
].
i8x8
[
1
],
k_regs
[
1
][
1
].
i8x8
[
1
],
s_reg
[
k_loop
][
1
][
1
].
f32
);
}
}
else
if
constexpr
(
false
)
{
#pragma unroll
for
(
int
k_loop
=
0
;
k_loop
<
kBlockN
/
WARP_N
;
++
k_loop
)
{
// 直接从 lds 读数据, 看看 lds 的数据排布
union_vec16_fp8
k_regs
[
WARP_N
/
16
][
kHeadDim
/
64
];
// 分两次读取寄存器, 第一次是 [0, 1024), 即 16x64 的内容, 每个线程读取 16 个 fp8
#ifdef USE_DS_READ_B128_FOR_INTERLEAVE4
int
k_lds_load_offset
=
row
*
128
+
col_swizzle
*
16
+
(
lane_id
&
1
)
*
64
+
k_loop
*
WARP_N
*
kHeadDim
;
inline_ds_read_b128_no_wait_bytes
(
reinterpret_cast
<
size_t
>
(
k_lds
)
+
k_lds_load_offset
+
0
,
k_regs
[
0
][
0
].
i32x4
);
inline_ds_read_b128_no_wait_bytes
(
reinterpret_cast
<
size_t
>
(
k_lds
)
+
k_lds_load_offset
+
1024
,
k_regs
[
1
][
0
].
i32x4
);
inline_ds_read_b128_no_wait_bytes
(
reinterpret_cast
<
size_t
>
(
k_lds
)
+
k_lds_load_offset
+
2048
,
k_regs
[
0
][
1
].
i32x4
);
inline_ds_read_b128_no_wait_bytes
(
reinterpret_cast
<
size_t
>
(
k_lds
)
+
k_lds_load_offset
+
3072
,
k_regs
[
1
][
1
].
i32x4
);
#else
int
k_lds_load_offset
=
k_loop
*
WARP_N
*
kHeadDim
;
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 0, k_regs[0][0].i32x4, true/*transpose*/)
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 1024, k_regs[1][0].i32x4, true/*transpose*/)
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 2048, k_regs[0][1].i32x4, true/*transpose*/)
// DS_READ_MATRIX_64x16_B8(k_lds_load_offset + 3072, k_regs[1][1].i32x4, true/*transpose*/)
k_regs
[
0
][
0
].
i32x4
=
__builtin_hcu_ds_read_matrix_trans_format_u8
(
k_lds
+
k_lds_load_offset
+
0
,
0
,
3
,
1
,
0
);
k_regs
[
1
][
0
].
i32x4
=
__builtin_hcu_ds_read_matrix_trans_format_u8
(
k_lds
+
k_lds_load_offset
+
1024
,
0
,
3
,
1
,
0
);
k_regs
[
0
][
1
].
i32x4
=
__builtin_hcu_ds_read_matrix_trans_format_u8
(
k_lds
+
k_lds_load_offset
+
2048
,
0
,
3
,
1
,
0
);
k_regs
[
1
][
1
].
i32x4
=
__builtin_hcu_ds_read_matrix_trans_format_u8
(
k_lds
+
k_lds_load_offset
+
3072
,
0
,
3
,
1
,
0
);
#endif
// init s_reg
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
2
;
++
m_idx
)
{
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
2
;
++
n_idx
)
{
inline_vgpr4_init_zero
(
s_reg
[
k_loop
][
m_idx
][
n_idx
]);
}
}
// ======================================================== QK mmac ======================================================================
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt lgkmcnt(2)
\n
"
);
// asm volatile("s_waitcnt lgkmcnt(0)\n");
__builtin_amdgcn_sched_barrier
(
0
);
s_reg
[
k_loop
][
0
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
0
].
i8x8
[
0
],
k_regs
[
0
][
0
].
i8x8
[
0
],
s_reg
[
k_loop
][
0
][
0
].
f32
);
s_reg
[
k_loop
][
0
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
0
].
i8x8
[
1
],
k_regs
[
0
][
0
].
i8x8
[
1
],
s_reg
[
k_loop
][
0
][
0
].
f32
);
s_reg
[
k_loop
][
1
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
0
].
i8x8
[
0
],
k_regs
[
0
][
0
].
i8x8
[
0
],
s_reg
[
k_loop
][
1
][
0
].
f32
);
s_reg
[
k_loop
][
1
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
0
].
i8x8
[
1
],
k_regs
[
0
][
0
].
i8x8
[
1
],
s_reg
[
k_loop
][
1
][
0
].
f32
);
s_reg
[
k_loop
][
0
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
0
].
i8x8
[
0
],
k_regs
[
1
][
0
].
i8x8
[
0
],
s_reg
[
k_loop
][
0
][
1
].
f32
);
s_reg
[
k_loop
][
0
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
0
].
i8x8
[
1
],
k_regs
[
1
][
0
].
i8x8
[
1
],
s_reg
[
k_loop
][
0
][
1
].
f32
);
s_reg
[
k_loop
][
1
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
0
].
i8x8
[
0
],
k_regs
[
1
][
0
].
i8x8
[
0
],
s_reg
[
k_loop
][
1
][
1
].
f32
);
s_reg
[
k_loop
][
1
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
0
].
i8x8
[
1
],
k_regs
[
1
][
0
].
i8x8
[
1
],
s_reg
[
k_loop
][
1
][
1
].
f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n
"
);
__builtin_amdgcn_sched_barrier
(
0
);
s_reg
[
k_loop
][
0
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
1
].
i8x8
[
0
],
k_regs
[
0
][
1
].
i8x8
[
0
],
s_reg
[
k_loop
][
0
][
0
].
f32
);
s_reg
[
k_loop
][
0
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
1
].
i8x8
[
1
],
k_regs
[
0
][
1
].
i8x8
[
1
],
s_reg
[
k_loop
][
0
][
0
].
f32
);
s_reg
[
k_loop
][
1
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
1
].
i8x8
[
0
],
k_regs
[
0
][
1
].
i8x8
[
0
],
s_reg
[
k_loop
][
1
][
0
].
f32
);
s_reg
[
k_loop
][
1
][
0
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
1
].
i8x8
[
1
],
k_regs
[
0
][
1
].
i8x8
[
1
],
s_reg
[
k_loop
][
1
][
0
].
f32
);
s_reg
[
k_loop
][
0
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
1
].
i8x8
[
0
],
k_regs
[
1
][
1
].
i8x8
[
0
],
s_reg
[
k_loop
][
0
][
1
].
f32
);
s_reg
[
k_loop
][
0
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
0
][
1
].
i8x8
[
1
],
k_regs
[
1
][
1
].
i8x8
[
1
],
s_reg
[
k_loop
][
0
][
1
].
f32
);
s_reg
[
k_loop
][
1
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
1
].
i8x8
[
0
],
k_regs
[
1
][
1
].
i8x8
[
0
],
s_reg
[
k_loop
][
1
][
1
].
f32
);
s_reg
[
k_loop
][
1
][
1
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
1
][
1
].
i8x8
[
1
],
k_regs
[
1
][
1
].
i8x8
[
1
],
s_reg
[
k_loop
][
1
][
1
].
f32
);
}
}
else
{
// 寄存器占用更少的写法
#pragma unroll
for
(
int
k_loop
=
0
;
k_loop
<
kBlockN
/
WARP_N
;
++
k_loop
)
{
// init s_reg
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
WARP_M
/
16
;
++
m_idx
)
{
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
WARP_N
/
16
;
++
n_idx
)
{
s_reg
[
k_loop
][
m_idx
][
n_idx
].
f32
[
0
]
=
0
;
s_reg
[
k_loop
][
m_idx
][
n_idx
].
f32
[
1
]
=
0
;
s_reg
[
k_loop
][
m_idx
][
n_idx
].
f32
[
2
]
=
0
;
s_reg
[
k_loop
][
m_idx
][
n_idx
].
f32
[
3
]
=
0
;
}
}
// 直接从 lds 读数据, 看看 lds 的数据排布
union_vec16_fp8
k_regs
[
WARP_N
/
16
][
kHeadDim
/
64
];
// 分两次读取寄存器, 第一次是 [0, 1024), 即 16x64 的内容, 每个线程读取 16 个 fp8
#ifdef USE_DS_READ_B128_FOR_INTERLEAVE4
int
k_lds_load_offset
=
row
*
128
+
col_swizzle
*
16
+
(
lane_id
&
1
)
*
64
+
k_loop
*
WARP_N
*
kHeadDim
;
k_regs
[
0
][
0
].
i32x4
=
*
(
vec4_int32
*
)(
k_lds
+
k_lds_load_offset
+
0
);
k_regs
[
1
][
0
].
i32x4
=
*
(
vec4_int32
*
)(
k_lds
+
k_lds_load_offset
+
1024
/*ds fmt 0, dmft1 */
);
k_regs
[
0
][
1
].
i32x4
=
*
(
vec4_int32
*
)(
k_lds
+
k_lds_load_offset
+
2048
);
k_regs
[
1
][
1
].
i32x4
=
*
(
vec4_int32
*
)(
k_lds
+
k_lds_load_offset
+
3072
);
#else
int
k_lds_load_offset
=
k_loop
*
WARP_N
*
kHeadDim
;
k_regs
[
0
][
0
].
i32x4
=
__builtin_hcu_ds_read_matrix_trans_format_u8
(
k_lds
+
k_lds_load_offset
+
0
,
0
,
3
,
1
,
0
);
k_regs
[
1
][
0
].
i32x4
=
__builtin_hcu_ds_read_matrix_trans_format_u8
(
k_lds
+
k_lds_load_offset
+
1024
,
0
,
3
,
1
,
0
);
k_regs
[
0
][
1
].
i32x4
=
__builtin_hcu_ds_read_matrix_trans_format_u8
(
k_lds
+
k_lds_load_offset
+
2048
,
0
,
3
,
1
,
0
);
k_regs
[
1
][
1
].
i32x4
=
__builtin_hcu_ds_read_matrix_trans_format_u8
(
k_lds
+
k_lds_load_offset
+
3072
,
0
,
3
,
1
,
0
);
#endif
// ======================================================== QK mmac ======================================================================
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
WARP_M
/
16
;
++
m_idx
)
{
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
WARP_N
/
16
;
++
n_idx
)
{
s_reg
[
k_loop
][
m_idx
][
n_idx
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
m_idx
][
0
].
i8x8
[
0
],
k_regs
[
n_idx
][
0
].
i8x8
[
0
],
s_reg
[
k_loop
][
m_idx
][
n_idx
].
f32
);
s_reg
[
k_loop
][
m_idx
][
n_idx
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
m_idx
][
0
].
i8x8
[
1
],
k_regs
[
n_idx
][
0
].
i8x8
[
1
],
s_reg
[
k_loop
][
m_idx
][
n_idx
].
f32
);
s_reg
[
k_loop
][
m_idx
][
n_idx
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
m_idx
][
1
].
i8x8
[
0
],
k_regs
[
n_idx
][
1
].
i8x8
[
0
],
s_reg
[
k_loop
][
m_idx
][
n_idx
].
f32
);
s_reg
[
k_loop
][
m_idx
][
n_idx
].
f32
=
mmac_4interleave_b8
<
int8_t
,
ElementAccum
>
(
q_regs
[
m_idx
][
1
].
i8x8
[
1
],
k_regs
[
n_idx
][
1
].
i8x8
[
1
],
s_reg
[
k_loop
][
m_idx
][
n_idx
].
f32
);
}
}
}
}
}
csrc/flash_attn_hg/include/fwd/gfx938/fp8_qk_gemm_utils_mls_ds.h
0 → 100644
View file @
518a5f4d
#pragma once
#include "intrinsic_mls_ds.h"
#include "intrinsic_mls_ds_b8.h"
#define USE_MLS_128B_REQUEST
template
<
bool
Is_even_MN
,
int
kHeadDim
,
int
WARP_M
,
typename
Element
>
__forceinline__
__device__
void
fp8_prefetch_q_to_lds
(
Element
*
q_ptr
,
int8_t
*
q_lds
,
int
warp_id
,
int
q_row_stride
,
int
max_seq_q_offset
)
{
static_assert
(
kHeadDim
==
128
||
kHeadDim
==
192
||
kHeadDim
==
256
);
// 准备 MLS 寄存器, 填充 stride
vec4_uint
q_root
=
prepare_for_matrix_load
<
128
,
Element
>
(
q_ptr
);
vec4_uint
q_srsrc
;
q_srsrc
[
0
]
=
q_root
[
0
];
q_srsrc
[
1
]
=
q_root
[
1
];
q_srsrc
[
2
]
=
q_row_stride
;
// stride
q_srsrc
[
3
]
=
0x40000
;
// [17: 18], interleave 4
// 计算 lds 写入地址
constexpr
int
kLdsHeadDimStride
=
kHeadDim
==
192
?
256
:
kHeadDim
;
int
q_lds_offset
=
warp_id
*
WARP_M
*
kLdsHeadDimStride
*
sizeof
(
Element
);
int
q_lds_write_bytes
=
reinterpret_cast
<
size_t
>
(
q_lds
)
+
q_lds_offset
;
// 计算 global 读取地址
q_srsrc
[
0
]
=
q_root
[
0
]
+
(
warp_id
*
WARP_M
)
*
q_row_stride
*
sizeof
(
Element
);
//边界判断
int
nm_filter
=
inline_min_max
<
0
,
16
>
(
32
*
warp_id
+
16
-
max_seq_q_offset
);
// q_srsrc[3] = q_srsrc[3] + max_seq_q_offset % 128 == 0 ? 0: nm_filter << 8; // set only once
q_srsrc
[
3
]
=
0x40000
+
((
max_seq_q_offset
%
128
==
0
)
?
0
:
(
nm_filter
<<
8
));
// set only once
// printf("nm_filter is %d, max_seq_q_osffset is %d\n", max_seq_q_offset % 128 == 0 ? 0: nm_filter << 8, max_seq_q_offset);
// 启动 mls 读取
#ifdef USE_MLS_128B_REQUEST
// inline_matrix_load_128x32_b8_lds_rearrange<0, 1>(q_lds, q_srsrc, q_lds_offset/*lds bytes*/, 0/*matrix_offset, 0 or 16*/);
__builtin_hcu_matrix_load_128X16_b8
(
q_srsrc
,
q_lds
+
q_lds_offset
,
0
,
true
,
false
,
false
,
false
,
false
);
nm_filter
=
inline_min_max
<
0
,
16
>
(
32
*
warp_id
+
32
-
max_seq_q_offset
);
q_srsrc
[
3
]
=
0x40000
+
((
max_seq_q_offset
%
128
==
0
)
?
0
:
(
nm_filter
<<
8
));
__builtin_hcu_matrix_load_128X16_b8
(
q_srsrc
,
q_lds
+
q_lds_offset
+
512
,
16
,
true
,
false
,
false
,
false
,
false
);
if
constexpr
(
kHeadDim
>
128
)
{
constexpr
int
kSecondMlsLoadHeadOffset
=
kHeadDim
==
192
?
64
:
128
;
q_srsrc
[
0
]
=
q_root
[
0
]
+
(
warp_id
*
WARP_M
)
*
q_row_stride
*
sizeof
(
Element
)
+
kSecondMlsLoadHeadOffset
*
sizeof
(
Element
);
nm_filter
=
inline_min_max
<
0
,
16
>
(
32
*
warp_id
+
16
-
max_seq_q_offset
);
q_srsrc
[
3
]
=
0x40000
+
((
max_seq_q_offset
%
128
==
0
)
?
0
:
(
nm_filter
<<
8
));
__builtin_hcu_matrix_load_128X16_b8
(
q_srsrc
,
q_lds
+
q_lds_offset
+
4096
,
0
,
true
,
false
,
false
,
false
,
false
);
nm_filter
=
inline_min_max
<
0
,
16
>
(
32
*
warp_id
+
32
-
max_seq_q_offset
);
q_srsrc
[
3
]
=
0x40000
+
((
max_seq_q_offset
%
128
==
0
)
?
0
:
(
nm_filter
<<
8
));
__builtin_hcu_matrix_load_128X16_b8
(
q_srsrc
,
q_lds
+
q_lds_offset
+
4608
,
16
,
true
,
false
,
false
,
false
,
false
);
}
#else
inline_matrix_load_64x32_b8_lds_rearrange
<
0
,
1
>
(
q_lds
,
q_srsrc
,
q_lds_write_bytes
/*lds bytes*/
,
0
/*matrix_offset, 0 or 16*/
);
q_srsrc
[
0
]
=
q_srsrc
[
0
]
+
64
*
sizeof
(
Element
);
inline_matrix_load_64x32_b8_lds_rearrange
<
0
,
1
>
(
q_lds
,
q_srsrc
,
q_lds_write_bytes
+
2048
/*lds bytes*/
,
0
/*matrix_offset, 0 or 16*/
);
// Q 部分可以考虑 128x16 或者非 4-interleave 形式
#endif
__builtin_amdgcn_sched_barrier
(
0
);
}
// #define USE_DS_READ_B128_FOR_INTERLEAVE4
template
<
int
kHeadDim
,
int
WARP_M
,
typename
Element
>
__forceinline__
__device__
void
load_q_from_lds_to_vgpr
(
union_vec16_fp8
q_regs
[
WARP_M
/
16
][
kHeadDim
/
64
],
int8_t
*
q_lds
,
int
warp_id
,
int
lane_id
)
{
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n
"
);
__builtin_amdgcn_sched_barrier
(
0
);
static_assert
(
kHeadDim
==
128
||
kHeadDim
==
192
||
kHeadDim
==
256
);
constexpr
int
kLdsHeadDimStride
=
kHeadDim
==
192
?
256
:
kHeadDim
;
// lds 写到两个地方去了, 注意是 rearrange, 所以跳 1K; transpose 跳 2K
// MLS0: [0: 512) 和 [1024, 1536)
// MLS1: [512: 1024) 和 [1536, 2048)
// 分 4 次读取寄存器, 第一次是 [0, 1024), 即 16x64 的内容, 每个线程读取 16 个 fp8
#ifdef USE_DS_READ_B128_FOR_INTERLEAVE4
int
row
=
(
lane_id
&
15
)
>>
1
;
int
col
=
lane_id
>>
4
;
int
col_swizzle
=
(
row
+
col
)
&
3
;
int
lds_load_offset
=
row
*
128
+
col_swizzle
*
16
+
(
lane_id
&
1
)
*
64
+
warp_id
*
WARP_M
*
kHeadDim
;
q_regs
[
0
][
0
].
i32x4
=
*
(
vec4_int32
*
)(
q_lds
+
lds_load_offset
+
0
);
q_regs
[
1
][
0
].
i32x4
=
*
(
vec4_int32
*
)(
q_lds
+
lds_load_offset
+
1024
/*ds fmt 0, dmft1 */
);
q_regs
[
0
][
1
].
i32x4
=
*
(
vec4_int32
*
)(
q_lds
+
lds_load_offset
+
2048
/*ds fmt 0, dmft1 */
);
q_regs
[
1
][
1
].
i32x4
=
*
(
vec4_int32
*
)(
q_lds
+
lds_load_offset
+
3072
/*ds fmt 0, dmft1 */
);
#else
#pragma unroll
for
(
int
h_idx
=
0
;
h_idx
<
kHeadDim
/
64
;
++
h_idx
)
{
const
int
h_offset
=
(
kHeadDim
==
192
&&
h_idx
==
2
)
?
3
*
2048
:
h_idx
*
2048
;
q_regs
[
0
][
h_idx
].
i32x4
=
__builtin_hcu_ds_read_matrix_trans_format_u8
(
q_lds
+
h_offset
+
0
+
warp_id
*
WARP_M
*
kLdsHeadDimStride
,
0
,
3
,
1
,
0
);
q_regs
[
1
][
h_idx
].
i32x4
=
__builtin_hcu_ds_read_matrix_trans_format_u8
(
q_lds
+
h_offset
+
1024
+
warp_id
*
WARP_M
*
kLdsHeadDimStride
,
0
,
3
,
1
,
0
);
}
#endif
__builtin_amdgcn_sched_barrier
(
0
);
__syncthreads
();
__builtin_amdgcn_sched_barrier
(
0
);
}
template
<
bool
Is_even_MN
,
int
kHeadDim
,
int
WARP_N
,
typename
Element
>
__forceinline__
__device__
void
fp8_prefetch_k_to_lds
(
Element
*
k_ptr
,
int8_t
*
k_lds
,
int
warp_id
,
int
k_row_stride
,
int
max_seq_kv_offset
)
{
static_assert
(
kHeadDim
==
128
||
kHeadDim
==
192
||
kHeadDim
==
256
);
// 准备 MLS 寄存器, 填充 stride
vec4_uint
k_root
=
prepare_for_matrix_load
<
kHeadDim
,
Element
>
(
k_ptr
);
vec4_uint
k_srsrc
;
k_srsrc
[
0
]
=
k_root
[
0
];
k_srsrc
[
1
]
=
k_root
[
1
];
k_srsrc
[
2
]
=
k_row_stride
;
// stride
k_srsrc
[
3
]
=
0x40000
;
// [17: 18], interleave 4
// 计算 lds 写入地址
constexpr
int
kLdsHeadDimStride
=
kHeadDim
==
192
?
256
:
kHeadDim
;
int
k_lds_offset
=
warp_id
*
WARP_N
*
kLdsHeadDimStride
*
sizeof
(
Element
);
int
k_lds_write_bytes
=
reinterpret_cast
<
size_t
>
(
k_lds
)
+
k_lds_offset
;
// 计算 global 读取地址
k_srsrc
[
0
]
=
k_root
[
0
]
+
warp_id
*
32
*
k_row_stride
*
sizeof
(
Element
);
//边界判断
int
nm_filter
=
inline_min_max
<
0
,
16
>
(
32
*
warp_id
+
16
-
max_seq_kv_offset
);
k_srsrc
[
3
]
=
k_srsrc
[
3
]
+
((
max_seq_kv_offset
%
128
==
0
)
?
0
:
(
nm_filter
<<
8
));
// 同步所有warp,确保srsrc参数准备完毕后再发起MLS load
flash
::
wait_all_warp_arrived
();
// 启动 mls 读取
#ifdef USE_MLS_128B_REQUEST
__builtin_hcu_matrix_load_128X16_b8
(
k_srsrc
,
k_lds
+
k_lds_offset
,
0
,
true
,
false
,
false
,
false
,
false
);
nm_filter
=
inline_min_max
<
0
,
16
>
(
32
*
warp_id
+
32
-
max_seq_kv_offset
);
k_srsrc
[
3
]
=
0x40000
+
((
max_seq_kv_offset
%
128
==
0
)
?
0
:
(
nm_filter
<<
8
));
__builtin_hcu_matrix_load_128X16_b8
(
k_srsrc
,
k_lds
+
k_lds_offset
+
512
,
16
,
true
,
false
,
false
,
false
,
false
);
if
constexpr
(
kHeadDim
>
128
)
{
constexpr
int
kSecondMlsLoadHeadOffset
=
kHeadDim
==
192
?
64
:
128
;
k_srsrc
[
0
]
=
k_root
[
0
]
+
warp_id
*
32
*
k_row_stride
*
sizeof
(
Element
)
+
kSecondMlsLoadHeadOffset
*
sizeof
(
Element
);
nm_filter
=
inline_min_max
<
0
,
16
>
(
32
*
warp_id
+
16
-
max_seq_kv_offset
);
k_srsrc
[
3
]
=
0x40000
+
((
max_seq_kv_offset
%
128
==
0
)
?
0
:
(
nm_filter
<<
8
));
__builtin_hcu_matrix_load_128X16_b8
(
k_srsrc
,
k_lds
+
k_lds_offset
+
4096
,
0
,
true
,
false
,
false
,
false
,
false
);
nm_filter
=
inline_min_max
<
0
,
16
>
(
32
*
warp_id
+
32
-
max_seq_kv_offset
);
k_srsrc
[
3
]
=
0x40000
+
((
max_seq_kv_offset
%
128
==
0
)
?
0
:
(
nm_filter
<<
8
));
__builtin_hcu_matrix_load_128X16_b8
(
k_srsrc
,
k_lds
+
k_lds_offset
+
4608
,
16
,
true
,
false
,
false
,
false
,
false
);
}
#else
inline_matrix_load_64x32_b8_lds_rearrange
<
0
,
1
>
(
k_lds
,
k_srsrc
,
k_lds_write_bytes
/*lds bytes*/
,
0
/*matrix_offset, 0 or 16*/
);
k_srsrc
[
0
]
=
k_srsrc
[
0
]
+
64
*
sizeof
(
Element
);
inline_matrix_load_64x32_b8_lds_rearrange
<
0
,
1
>
(
k_lds
,
k_srsrc
,
k_lds_write_bytes
+
2048
/*lds bytes*/
,
0
/*matrix_offset, 0 or 16*/
);
#endif
}
csrc/flash_attn_hg/include/fwd/gfx938/fp8_softmax_gfx938.h
0 → 100644
View file @
518a5f4d
#pragma once
#include "philox.cuh"
#include "../utils.h"
using
namespace
flash
;
template
<
int
kBlockN
,
int
WARP_M
,
int
WARP_N
,
typename
ElementAccum
>
__forceinline__
__device__
void
fp8_apply_mask
(
vec4_Accum
<
ElementAccum
>
s_reg
[
kBlockN
/
WARP_N
][
WARP_M
/
16
][
WARP_N
/
16
],
int
max_seq_kv_offset
,
int
wave_col_offset
,
int
lane_id
)
{
__builtin_amdgcn_sched_barrier
(
0
);
const
int
col_base
=
wave_col_offset
+
(
lane_id
>>
4
)
*
8
;
#pragma unroll
for
(
int
k_loop
=
0
;
k_loop
<
kBlockN
/
WARP_N
;
++
k_loop
)
{
const
int
k_offset
=
k_loop
*
WARP_N
;
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
WARP_N
/
16
;
++
n_idx
)
{
const
int
n_base
=
col_base
+
n_idx
*
4
;
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
WARP_M
/
16
;
++
m_idx
)
{
#pragma unroll
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
++
vec_idx
)
{
s_reg
[
k_loop
][
m_idx
][
n_idx
].
f32
[
vec_idx
]
=
(
n_base
+
k_offset
+
vec_idx
>=
max_seq_kv_offset
)
?
-
INFINITY
:
s_reg
[
k_loop
][
m_idx
][
n_idx
].
f32
[
vec_idx
];
}
}
}
}
__builtin_amdgcn_sched_barrier
(
0
);
}
template
<
int
kBlockN
,
int
WARP_M
,
int
WARP_N
,
typename
ElementAccum
>
__forceinline__
__device__
void
fp8_apply_causal_mask
(
vec4_Accum
<
ElementAccum
>
s_reg
[
kBlockN
/
WARP_N
][
WARP_M
/
16
][
WARP_N
/
16
],
int
actual_seqlen_q
,
int
actual_seqlen_k
,
int
wave_row_offset
,
int
wave_col_offset
,
int
lane_id
)
{
__builtin_amdgcn_sched_barrier
(
0
);
const
int
row_base
=
wave_row_offset
+
((
lane_id
&
15
)
>>
2
)
*
8
+
(
lane_id
&
3
);
const
int
col_base
=
wave_col_offset
+
(
lane_id
>>
4
)
*
8
;
const
int
causal_limit
=
actual_seqlen_k
-
actual_seqlen_q
;
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
WARP_M
/
16
;
++
m_idx
)
{
const
int
row_idx
=
row_base
+
m_idx
*
4
;
const
int
col_limit
=
min
(
actual_seqlen_k
,
row_idx
+
causal_limit
);
#pragma unroll
for
(
int
k_loop
=
0
;
k_loop
<
kBlockN
/
WARP_N
;
++
k_loop
)
{
const
int
k_offset
=
k_loop
*
WARP_N
;
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
WARP_N
/
16
;
++
n_idx
)
{
const
int
n_base
=
col_base
+
n_idx
*
4
;
#pragma unroll
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
++
vec_idx
)
{
s_reg
[
k_loop
][
m_idx
][
n_idx
].
f32
[
vec_idx
]
=
(
n_base
+
k_offset
+
vec_idx
>
col_limit
)
?
-
INFINITY
:
s_reg
[
k_loop
][
m_idx
][
n_idx
].
f32
[
vec_idx
];
}
}
}
}
__builtin_amdgcn_sched_barrier
(
0
);
}
template
<
int
kBlockN
,
int
WARP_M
,
int
WARP_N
,
typename
ElementAccum
>
__forceinline__
__device__
void
fp8_apply_local_mask
(
vec4_Accum
<
ElementAccum
>
s_reg
[
kBlockN
/
WARP_N
][
WARP_M
/
16
][
WARP_N
/
16
],
int
actual_seqlen_q
,
int
actual_seqlen_k
,
int
wave_row_offset
,
int
wave_col_offset
,
int
window_size_left
,
int
window_size_right
,
int
lane_id
)
{
__builtin_amdgcn_sched_barrier
(
0
);
const
int
row_base
=
wave_row_offset
+
((
lane_id
&
15
)
>>
2
)
*
8
+
(
lane_id
&
3
);
const
int
col_base
=
wave_col_offset
+
(
lane_id
>>
4
)
*
8
;
const
bool
has_ws_left
=
window_size_left
>=
0
;
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
WARP_M
/
16
;
++
m_idx
)
{
const
int
row_idx
=
row_base
+
m_idx
*
4
;
const
int
col_limit_left
=
max
(
0
,
row_idx
+
1
+
actual_seqlen_k
-
actual_seqlen_q
-
window_size_left
);
const
int
col_limit_right
=
min
(
actual_seqlen_k
,
row_idx
+
actual_seqlen_k
-
actual_seqlen_q
+
window_size_right
);
#pragma unroll
for
(
int
k_loop
=
0
;
k_loop
<
kBlockN
/
WARP_N
;
++
k_loop
)
{
const
int
k_offset
=
k_loop
*
WARP_N
;
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
WARP_N
/
16
;
++
n_idx
)
{
const
int
n_base
=
col_base
+
n_idx
*
4
;
#pragma unroll
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
++
vec_idx
)
{
const
int
col_idx
=
n_base
+
k_offset
+
vec_idx
;
s_reg
[
k_loop
][
m_idx
][
n_idx
].
f32
[
vec_idx
]
=
(
col_idx
>
col_limit_right
||
(
has_ws_left
&&
col_idx
<
col_limit_left
-
1
))
?
-
INFINITY
:
s_reg
[
k_loop
][
m_idx
][
n_idx
].
f32
[
vec_idx
];
}
}
}
}
__builtin_amdgcn_sched_barrier
(
0
);
}
template
<
int
kBlockN
,
int
WARP_M
,
int
WARP_N
,
typename
ElementAccum
>
__forceinline__
__device__
void
fp8_qk_descale
(
vec4_Accum
<
ElementAccum
>
s_reg
[
kBlockN
/
WARP_N
][
WARP_M
/
16
][
WARP_N
/
16
],
__float2
qk_descale
)
{
__builtin_amdgcn_sched_barrier
(
0
);
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
WARP_M
/
16
;
++
m_idx
)
{
#pragma unroll
for
(
int
k_loop
=
0
;
k_loop
<
kBlockN
/
WARP_N
;
++
k_loop
)
{
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
WARP_N
/
16
;
++
n_idx
)
{
#pragma unroll
for
(
int
vec_idx
=
0
;
vec_idx
<
2
;
++
vec_idx
)
{
s_reg
[
k_loop
][
m_idx
][
n_idx
].
u64
[
vec_idx
]
=
__builtin_hcu_pk_mul_f32
(
s_reg
[
k_loop
][
m_idx
][
n_idx
].
u64
[
vec_idx
],
qk_descale
);
// s_reg[k_loop][m_idx][n_idx].u64[vec_idx] = s_reg[k_loop][m_idx][n_idx].u64[vec_idx] * qk_descale;
}
}
}
}
__builtin_amdgcn_sched_barrier
(
0
);
}
template
<
bool
AssumeValidRows
,
int
kHeadDim
,
int
kBlockN
,
int
WARP_M
,
int
WARP_N
,
int
WARP_K
,
typename
Element
,
typename
ElementAccum
>
__forceinline__
__device__
void
fp8_softmax_and_schedule_v
(
/*softmax module related args*/
vec4_Accum
<
ElementAccum
>
s_reg
[
kBlockN
/
WARP_N
][
WARP_M
/
16
][
WARP_N
/
16
],
ElementAccum
scores_max
[
WARP_M
/
16
],
ElementAccum
scores_sum
[
WARP_M
/
16
],
vec4_Accum
<
ElementAccum
>
acc_o
[
kHeadDim
/
32
][
WARP_M
/
16
][
WARP_N
/
16
],
ElementAccum
softmax_scale_log2
,
/*scheduled modules related args*/
union_vec16_fp8
v_regs
[
kBlockN
/
WARP_N
][
kHeadDim
/
32
],
int8_t
*
v_lds
)
{
// ======================================================== Max ======================================================================
ElementAccum
scores_max_cur
[
WARP_M
/
16
];
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
WARP_M
/
16
;
++
m_idx
)
{
ElementAccum
max_value
=
scores_max
[
m_idx
];
// 当前线程遍历 4 个 32x32x32 mmac 输出的 f32x4
#pragma unroll
for
(
int
k_loop
=
0
;
k_loop
<
kBlockN
/
WARP_N
;
++
k_loop
)
{
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
WARP_N
/
16
;
++
n_idx
)
{
#pragma unroll
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
++
vec_idx
)
{
max_value
=
max
(
max_value
,
s_reg
[
k_loop
][
m_idx
][
n_idx
].
f32
[
vec_idx
]);
}
}
}
// 这一行比较 0, 16, 32, 48 号线程的数据
max_value
=
max
(
max_value
,
__shfl_xor_tmp
(
max_value
,
32
));
max_value
=
max
(
max_value
,
__shfl_xor_tmp
(
max_value
,
16
));
// 赋值给最终的最大值
scores_max_cur
[
m_idx
]
=
max_value
;
}
// ========================================== softmax ===============================================
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
WARP_M
/
16
;
++
m_idx
)
{
__float2
max_scaled_pair
;
if
constexpr
(
AssumeValidRows
)
{
max_scaled_pair
[
0
]
=
-
scores_max_cur
[
m_idx
]
*
softmax_scale_log2
;
}
else
{
max_scaled_pair
[
0
]
=
scores_max_cur
[
m_idx
]
==
-
INFINITY
?
0.
f
:
-
scores_max_cur
[
m_idx
]
*
softmax_scale_log2
;
}
max_scaled_pair
[
1
]
=
max_scaled_pair
[
0
];
__float2
softmax_scale_log2_pair
=
{
softmax_scale_log2
,
softmax_scale_log2
};
#pragma unroll
for
(
int
k_loop
=
0
;
k_loop
<
kBlockN
/
WARP_N
;
++
k_loop
)
{
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
WARP_N
/
16
;
++
n_idx
)
{
#pragma unroll
for
(
int
vec_idx
=
0
;
vec_idx
<
2
;
++
vec_idx
)
{
s_reg
[
k_loop
][
m_idx
][
n_idx
].
u64
[
vec_idx
]
=
__builtin_hcu_pk_fma_f32
(
s_reg
[
k_loop
][
m_idx
][
n_idx
].
u64
[
vec_idx
],
softmax_scale_log2_pair
,
max_scaled_pair
);
}
#pragma unroll
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
++
vec_idx
)
{
s_reg
[
k_loop
][
m_idx
][
n_idx
].
f32
[
vec_idx
]
=
__llvm_exp2_f32
(
s_reg
[
k_loop
][
m_idx
][
n_idx
].
f32
[
vec_idx
]);
}
}
}
}
// ========================================== Sum ===============================================
ElementAccum
scores_sum_cur
[
WARP_M
/
16
];
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
WARP_M
/
16
;
++
m_idx
)
{
vec2_Accum
<
ElementAccum
>
sum_pair
;
sum_pair
.
data
=
0.0
;
#pragma unroll
for
(
int
k_loop
=
0
;
k_loop
<
kBlockN
/
WARP_N
;
++
k_loop
)
{
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
WARP_N
/
16
;
++
n_idx
)
{
sum_pair
.
u64
=
__builtin_hcu_pk_add_f32
(
sum_pair
.
u64
,
s_reg
[
k_loop
][
m_idx
][
n_idx
].
u64
[
0
]);
sum_pair
.
u64
=
__builtin_hcu_pk_add_f32
(
sum_pair
.
u64
,
s_reg
[
k_loop
][
m_idx
][
n_idx
].
u64
[
1
]);
}
}
scores_sum_cur
[
m_idx
]
=
sum_pair
.
f32
[
0
]
+
sum_pair
.
f32
[
1
];
scores_sum_cur
[
m_idx
]
=
scores_sum_cur
[
m_idx
]
+
__shfl_xor_tmp
(
scores_sum_cur
[
m_idx
],
32
);
scores_sum_cur
[
m_idx
]
=
scores_sum_cur
[
m_idx
]
+
__shfl_xor_tmp
(
scores_sum_cur
[
m_idx
],
16
);
}
// 更新 scores_sum, scores_max
// 这段代码放在这是因为即将下发的大量 ds 指令, 会跟 __shfl_xor 抢带宽, 导致时延太高
// ElementAccum exp_rescale[WARP_M / 16];
// #pragma unroll
// for (int m_idx = 0; m_idx < WARP_M / 16; ++m_idx) {
// exp_rescale[m_idx] = __llvm_exp2_f32((scores_max[m_idx] - scores_max_cur[m_idx]) * softmax_scale_log2);
// scores_max[m_idx] = scores_max_cur[m_idx];
// scores_sum[m_idx] = __llvm_fma_f32(scores_sum[m_idx], exp_rescale[m_idx], scores_sum_cur[m_idx]);
// }
// ========================================== schedule V ===============================================
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n
"
);
__builtin_amdgcn_sched_barrier
(
0
);
#pragma unroll
for
(
int
k_loop
=
0
;
k_loop
<
kBlockN
/
WARP_N
;
k_loop
+=
1
)
{
// 用 ds_read_matrix 从 lds 读取数据到寄存器
int8_t
*
lds_load_ptr
=
v_lds
+
k_loop
*
WARP_M
*
kHeadDim
*
sizeof
(
Element
);
v_regs
[
k_loop
][
0
].
i32x4
=
__builtin_hcu_ds_read_matrix_format_u8
(
lds_load_ptr
,
0
,
2
,
2
,
0
);
v_regs
[
k_loop
][
1
].
i32x4
=
__builtin_hcu_ds_read_matrix_format_u8
(
lds_load_ptr
+
32
,
0
,
2
,
2
,
0
);
v_regs
[
k_loop
][
2
].
i32x4
=
__builtin_hcu_ds_read_matrix_format_u8
(
lds_load_ptr
+
128
*
16
,
0
,
2
,
2
,
0
);
v_regs
[
k_loop
][
3
].
i32x4
=
__builtin_hcu_ds_read_matrix_format_u8
(
lds_load_ptr
+
128
*
16
+
32
,
0
,
2
,
2
,
0
);
if
constexpr
(
kHeadDim
==
256
)
{
v_regs
[
k_loop
][
4
].
i32x4
=
__builtin_hcu_ds_read_matrix_format_u8
(
lds_load_ptr
+
4096
,
0
,
2
,
2
,
0
);
v_regs
[
k_loop
][
5
].
i32x4
=
__builtin_hcu_ds_read_matrix_format_u8
(
lds_load_ptr
+
4096
+
32
,
0
,
2
,
2
,
0
);
v_regs
[
k_loop
][
6
].
i32x4
=
__builtin_hcu_ds_read_matrix_format_u8
(
lds_load_ptr
+
4096
+
128
*
16
,
0
,
2
,
2
,
0
);
v_regs
[
k_loop
][
7
].
i32x4
=
__builtin_hcu_ds_read_matrix_format_u8
(
lds_load_ptr
+
4096
+
128
*
16
+
32
,
0
,
2
,
2
,
0
);
}
}
__builtin_amdgcn_sched_barrier
(
0
);
// hint: 这里考虑只发一部分的 ds_read_matrix 指令出去, 一面堵住
// ========================================== rescale ===============================================
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
WARP_M
/
16
;
m_idx
+=
1
)
{
if
(
scores_sum
[
m_idx
]
!=
0.
f
&&
scores_max
[
m_idx
]
<
scores_max_cur
[
m_idx
])
{
__float2
scores_scale_pair
;
float
max_diff
;
if
constexpr
(
AssumeValidRows
)
{
max_diff
=
scores_max
[
m_idx
]
-
scores_max_cur
[
m_idx
];
}
else
{
// Fix: 当 scores_max 和 scores_max_cur 都是 -INFINITY 时,(-INF) - (-INF) = NaN
// 这种情况发生在某些 query 行完全没有有效的 KV 可以 attend 时
max_diff
=
(
scores_max
[
m_idx
]
==
-
INFINITY
||
scores_max_cur
[
m_idx
]
==
-
INFINITY
)
?
0.
f
:
(
scores_max
[
m_idx
]
-
scores_max_cur
[
m_idx
]);
}
scores_scale_pair
[
0
]
=
__llvm_exp2_f32
(
max_diff
*
softmax_scale_log2
);
scores_scale_pair
[
1
]
=
scores_scale_pair
[
0
];
scores_sum
[
m_idx
]
*=
scores_scale_pair
[
0
];
// 放缩 acc_o
#pragma unroll
for
(
int
pv_loop
=
0
;
pv_loop
<
kHeadDim
/
WARP_N
;
++
pv_loop
)
{
#pragma unroll
for
(
int
mmac_id
=
0
;
mmac_id
<
WARP_K
/
16
;
++
mmac_id
)
{
acc_o
[
pv_loop
][
m_idx
][
mmac_id
].
u64
[
0
]
=
__builtin_hcu_pk_mul_f32
(
acc_o
[
pv_loop
][
m_idx
][
mmac_id
].
u64
[
0
],
scores_scale_pair
);
acc_o
[
pv_loop
][
m_idx
][
mmac_id
].
u64
[
1
]
=
__builtin_hcu_pk_mul_f32
(
acc_o
[
pv_loop
][
m_idx
][
mmac_id
].
u64
[
1
],
scores_scale_pair
);
}
}
}
}
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
WARP_M
/
16
;
++
m_idx
)
{
scores_max
[
m_idx
]
=
scores_max_cur
[
m_idx
];
scores_sum
[
m_idx
]
+=
scores_sum_cur
[
m_idx
];
}
}
template
<
int
kBlockN
,
int
WARP_M
,
int
WARP_N
,
typename
Element
,
typename
ElementAccum
>
__forceinline__
__device__
void
fp8_cvt_f32_to_fp8
(
vec4_Accum
<
ElementAccum
>
s_reg
[
kBlockN
/
WARP_N
][
WARP_M
/
16
][
WARP_N
/
16
],
union_vec32_fp8
p_reg
[
WARP_M
/
16
]
)
{
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
WARP_M
/
16
;
++
m_idx
)
{
#pragma unroll
for
(
int
k_loop
=
0
;
k_loop
<
kBlockN
/
WARP_N
;
++
k_loop
)
{
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
WARP_N
/
16
;
++
n_idx
)
{
__builtin_hcu_cvt_pk4_fp8_f32
<
Element
>
(
s_reg
[
k_loop
][
m_idx
][
n_idx
].
f32
,
p_reg
[
m_idx
].
i32
[
k_loop
*
2
+
n_idx
]);
}
}
}
}
csrc/flash_attn_hg/include/fwd/gfx938/fwd_epilogue_gfx938.h
View file @
518a5f4d
...
...
@@ -13,11 +13,11 @@ __forceinline__ __device__ void fwd_epilogue_store_output_gfx938(
int
seqlen_o_stride
,
int
seqlen_q_limit
)
{
static_assert
(
Is_Interleaved
and
"For fwd_epilogue_store_output_gfx938, mmac must be 4interleave"
);
int
pv_lane_seq_idx
=
lane_id
&
15
;
int
pv_lane_head_dim_idx
=
lane_id
>>
4
;
static_assert
(
Is_Interleaved
and
"For fwd_epilogue_store_output_gfx938, mmac must be 4interleave"
);
if
constexpr
(
TailTile16
==
2
)
{
#pragma unroll
for
(
int
k_loop
=
0
;
k_loop
<
(
kHeadDimV
/
kBlockK
);
++
k_loop
)
{
...
...
@@ -46,7 +46,7 @@ __forceinline__ __device__ void fwd_epilogue_store_output_gfx938(
}
}
}
}
}
// brace, to control vgpr usage
}
else
{
#pragma unroll
for
(
int
k_loop
=
0
;
k_loop
<
(
kHeadDimV
/
kBlockK
);
++
k_loop
)
{
...
...
@@ -61,11 +61,13 @@ __forceinline__ __device__ void fwd_epilogue_store_output_gfx938(
const
int
pv_tile_id
=
k_loop
*
(
WARP_M
/
32
)
*
(
kBlockK
/
32
)
+
warp_m_idx
*
(
kBlockK
/
32
)
+
k_tile_idx
;
const
int
mmac_id
=
min_tile_m
+
min_tile_n
*
2
;
int
seqlen_q_offset
=
warp_id
*
WARP_M
+
warp_m_idx
*
32
+
min_tile_m
*
16
+
pv_lane_seq_idx
;
// prepare for store
int
s_offset
=
k_tile_idx
*
32
+
min_tile_n
*
16
;
int
v_offset
=
seqlen_q_offset
*
seqlen_o_stride
+
k_loop
*
kBlockK
+
pv_lane_head_dim_idx
*
4
;
union_vec2_f16x2
<
Element
>
v_data
;
#pragma unroll
for
(
int
vec_index
=
0
;
vec_index
<
2
;
++
vec_index
)
{
// convert float -> bf16/fp16
v_data
.
f16x2
[
vec_index
]
=
DownCastPair
<
ElementAccum
,
Element
>
(
acc_o
[
pv_tile_id
][
mmac_id
].
f32x2
[
vec_index
]);
}
if
constexpr
(
not
Is_even_MN
)
{
...
...
@@ -79,7 +81,7 @@ __forceinline__ __device__ void fwd_epilogue_store_output_gfx938(
}
}
}
}
}
// brace, to control vgpr usage
}
__builtin_amdgcn_sched_barrier
(
0
);
}
\ No newline at end of file
csrc/flash_attn_hg/include/fwd/gfx938/pv_gemm_prefetch_k_mls_ds.h
View file @
518a5f4d
...
...
@@ -59,10 +59,7 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds(
}
int
lds_offset
=
(
lds_stage_id
*
WARP_K
*
kHeadDimV
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
;
flash
::
wait_all_warp_arrived
();
union
union_vec4_uint
v_rsrc_bits
;
v_rsrc_bits
.
v32
=
v_srsrc
;
size_t
lds_addr_warp
=
reinterpret_cast
<
size_t
>
(
v_lds
)
+
lds_offset
;
matrix_load_b16_lds_builtin
<
32
,
32
,
1
,
0
>
(
lds_addr_warp
,
v_rsrc_bits
.
i32
,
0
);
inline_matrix_load_32x32_b16_lds
<
0
,
1
>
(
v_lds
,
v_srsrc
,
lds_offset
,
0
);
}
else
if
constexpr
(
kHeadDimV
==
192
)
{
int
warp_id_m
=
warp_id
%
2
;
// w0 w2
int
warp_id_n
=
warp_id
/
2
;
// w1 w3
...
...
@@ -76,10 +73,7 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds(
}
int
lds_offset
=
(
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
+
warp_id
*
32
*
16
)
*
ELEMENT_BYTES
;
flash
::
wait_all_warp_arrived
();
union
union_vec4_uint
v_rsrc_bits
;
v_rsrc_bits
.
v32
=
v_srsrc
;
size_t
lds_addr_warp
=
reinterpret_cast
<
size_t
>
(
v_lds
)
+
lds_offset
;
matrix_load_b16_lds_builtin
<
32
,
16
,
1
,
0
>
(
lds_addr_warp
,
v_rsrc_bits
.
i32
,
0
);
inline_matrix_load_32x16_b16_lds
<
0
,
1
>
(
v_lds
,
v_srsrc
,
lds_offset
,
0
);
}
// DS
lds_stage_id
^=
1
;
...
...
@@ -165,10 +159,7 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds(
}
int
lds_offset
=
(
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
;
flash
::
wait_all_warp_arrived
();
union
union_vec4_uint
v_rsrc_bits
;
v_rsrc_bits
.
v32
=
v_srsrc
;
size_t
lds_addr_warp
=
reinterpret_cast
<
size_t
>
(
v_lds
)
+
lds_offset
;
matrix_load_b16_lds_builtin
<
32
,
32
,
1
,
0
>
(
lds_addr_warp
,
v_rsrc_bits
.
i32
,
0
);
inline_matrix_load_32x32_b16_lds
<
0
,
1
>
(
v_lds
,
v_srsrc
,
lds_offset
,
0
);
}
}
}
...
...
@@ -241,10 +232,7 @@ __forceinline__ __device__ void pv_gemm_prefetch_k_mls_ds(
}
int
lds_offset
=
(
lds_stage_id
*
WARP_K
*
kHeadDimV_OPT
+
warp_id
*
32
*
16
)
*
ELEMENT_BYTES
;
flash
::
wait_all_warp_arrived
();
union
union_vec4_uint
v_rsrc_bits
;
v_rsrc_bits
.
v32
=
v_srsrc
;
size_t
lds_addr_warp
=
reinterpret_cast
<
size_t
>
(
v_lds
)
+
lds_offset
;
matrix_load_b16_lds_builtin
<
32
,
16
,
1
,
0
>
(
lds_addr_warp
,
v_rsrc_bits
.
i32
,
0
);
inline_matrix_load_32x16_b16_lds
<
0
,
1
>
(
v_lds
,
v_srsrc
,
lds_offset
,
0
);
}
lds_stage_id
^=
1
;
...
...
csrc/flash_attn_hg/include/fwd/gfx938/pv_gemm_utils_mls_ds.h
View file @
518a5f4d
...
...
@@ -34,9 +34,7 @@ __forceinline__ __device__ void prefetch_v_to_lds_mls_ds(
int
lds_offset
=
(
lds_stage_id
*
WARP_K
*
kHeadDim_OPT
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
;
flash
::
wait_all_warp_arrived
();
// 防止写 v lds 和读 k lds 冲突, qk 可能有的 warp 没结束
union
union_vec4_uint
v_rsrc_bits
;
v_rsrc_bits
.
v32
=
v_srsrc
;
size_t
lds_addr_warp
=
reinterpret_cast
<
size_t
>
(
v_lds
)
+
lds_offset
;
matrix_load_b16_lds_builtin
<
32
,
32
,
1
,
0
>
(
lds_addr_warp
,
v_rsrc_bits
.
i32
,
0
);
inline_matrix_load_32x32_b16_lds
<
0
,
1
>
(
v_lds
,
v_srsrc
,
lds_offset
,
0
);
__builtin_amdgcn_sched_barrier
(
0
);
}
csrc/flash_attn_hg/include/fwd/gfx938/qk_gemm_prefetch_v_mls_ds.h
View file @
518a5f4d
...
...
@@ -79,10 +79,7 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds(
}
int
lds_offset
=
(
n_stage_id
*
WARP_N
*
kHeadDim_
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
;
flash
::
wait_all_warp_arrived
();
union
union_vec4_uint
k_rsrc_bits
;
k_rsrc_bits
.
v32
=
k_srsrc
;
size_t
lds_addr_warp
=
reinterpret_cast
<
size_t
>
(
k_lds
)
+
lds_offset
;
matrix_load_b16_lds_trans_builtin
<
32
,
32
,
0
,
0
>
(
lds_addr_warp
,
k_rsrc_bits
.
i32
,
0
);
inline_matrix_load_32x32_b16_lds_trans
<
0
,
0
>
(
k_lds
,
k_srsrc
,
lds_offset
,
0
);
}
else
if
constexpr
(
kHeadDim
==
192
)
{
int
warp_id_m
=
warp_id
/
2
;
int
warp_id_n
=
warp_id
%
2
;
...
...
@@ -95,10 +92,7 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds(
}
int
lds_offset
=
(
n_stage_id
*
WARP_N
*
kHeadDim_
+
warp_id
*
32
*
16
)
*
ELEMENT_BYTES
;
flash
::
wait_all_warp_arrived
();
union
union_vec4_uint
k_rsrc_bits
;
k_rsrc_bits
.
v32
=
k_srsrc
;
size_t
lds_addr_warp
=
reinterpret_cast
<
size_t
>
(
k_lds
)
+
lds_offset
;
matrix_load_b16_lds_trans_builtin
<
32
,
16
,
0
,
0
>
(
lds_addr_warp
,
k_rsrc_bits
.
i32
,
0
);
inline_matrix_load_32x16_b16_lds_trans
<
0
,
0
>
(
k_lds
,
k_srsrc
,
lds_offset
,
0
);
}
// Wait MLS
...
...
@@ -178,10 +172,7 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds(
}
int
lds_offset
=
(
n_stage_id
*
WARP_N
*
kHeadDim_
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
;
flash
::
wait_all_warp_arrived
();
union
union_vec4_uint
k_rsrc_bits
;
k_rsrc_bits
.
v32
=
k_srsrc
;
size_t
lds_addr_warp
=
reinterpret_cast
<
size_t
>
(
k_lds
)
+
lds_offset
;
matrix_load_b16_lds_trans_builtin
<
32
,
32
,
0
,
0
>
(
lds_addr_warp
,
k_rsrc_bits
.
i32
,
0
);
inline_matrix_load_32x32_b16_lds_trans
<
0
,
0
>
(
k_lds
,
k_srsrc
,
lds_offset
,
0
);
}
}
}
...
...
@@ -242,10 +233,7 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds(
}
int
lds_offset
=
(
n_stage_id
*
WARP_N
*
kHeadDim_
+
warp_id
*
32
*
16
)
*
ELEMENT_BYTES
;
flash
::
wait_all_warp_arrived
();
union
union_vec4_uint
k_rsrc_bits
;
k_rsrc_bits
.
v32
=
k_srsrc
;
size_t
lds_addr_warp
=
reinterpret_cast
<
size_t
>
(
k_lds
)
+
lds_offset
;
matrix_load_b16_lds_trans_builtin
<
32
,
16
,
0
,
0
>
(
lds_addr_warp
,
k_rsrc_bits
.
i32
,
0
);
inline_matrix_load_32x16_b16_lds_trans
<
0
,
0
>
(
k_lds
,
k_srsrc
,
lds_offset
,
0
);
}
// Wait MLS
...
...
@@ -361,7 +349,7 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds(
flash
::
wait_all_warp_arrived
();
if
constexpr
(
STAGES
==
2
)
{
#if defined(__gfx938__)
// 有的 prefetch v 写到了 mha 主 kernel 代码里
#if defined(__gfx938__)
|| defined(__gfx946__) || (defined(__gfx92a__) && defined(YY_USE_MPERMUTE))
// 有的 prefetch v 写到了 mha 主 kernel 代码里
prefetch_v_to_lds_mls_ds
<
kHeadDimV
,
kBlockM
,
kBlockK
,
kBlockN
,
WARP_M
,
kBlockK
,
TailTile16
,
Element
,
Is_even_MN
>
(
v_ptr
,
v_lds
,
warp_id
,
seqlen_v_stride
,
max_seq_k_offset
);
#else
...
...
@@ -369,3 +357,4 @@ __forceinline__ __device__ void qk_gemm_prefetch_v_mls_ds(
}
}
// qk_gemm
csrc/flash_attn_hg/include/fwd/gfx938/qk_gemm_utils_mls_ds.h
View file @
518a5f4d
...
...
@@ -35,10 +35,7 @@ __forceinline__ __device__ void prefetch_q_to_vgpr_mls_ds(
q_srsrc
[
3
]
=
max_seq_q_offset
%
kBlockM
==
0
?
0
:
nm_filter
<<
8
;
// set only once
}
int
lds_offset
=
(
stage_id
*
kBlockM
*
kBlockK
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
;
union
union_vec4_uint
q_rsrc_bits
;
q_rsrc_bits
.
v32
=
q_srsrc
;
size_t
lds_addr_warp
=
reinterpret_cast
<
size_t
>
(
q_lds
)
+
lds_offset
;
matrix_load_b16_lds_trans_builtin
<
32
,
32
,
1
,
0
>
(
lds_addr_warp
,
q_rsrc_bits
.
i32
,
0
);
inline_matrix_load_32x32_b16_lds_trans
<
0
,
1
>
(
q_lds
,
q_srsrc
,
lds_offset
,
0
);
}
stage_id
^=
1
;
...
...
@@ -50,10 +47,7 @@ __forceinline__ __device__ void prefetch_q_to_vgpr_mls_ds(
q_srsrc
[
3
]
=
max_seq_q_offset
%
kBlockM
==
0
?
0
:
nm_filter
<<
8
;
// set only once
}
int
lds_offset
=
(
stage_id
*
kBlockM
*
kBlockK
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
;
union
union_vec4_uint
q_rsrc_bits
;
q_rsrc_bits
.
v32
=
q_srsrc
;
size_t
lds_addr_warp
=
reinterpret_cast
<
size_t
>
(
q_lds
)
+
lds_offset
;
matrix_load_b16_lds_trans_builtin
<
32
,
32
,
1
,
0
>
(
lds_addr_warp
,
q_rsrc_bits
.
i32
,
0
);
inline_matrix_load_32x32_b16_lds_trans
<
0
,
1
>
(
q_lds
,
q_srsrc
,
lds_offset
,
0
);
stage_id
^=
1
;
// DS
...
...
@@ -63,6 +57,8 @@ __forceinline__ __device__ void prefetch_q_to_vgpr_mls_ds(
int
lds_load_offset
=
q_lds_base
+
(
stage_id
*
kBlockM
*
kBlockK
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
;
#ifdef __gfx938__
DS_READ_MATRIX_32X32_B16
(
lds_load_offset
,
q_reg
[(
k_loop
-
1
)
*
2
].
f16
,
q_reg
[(
k_loop
-
1
)
*
2
+
1
].
f16
,
true
);
#elif defined(__gfx946__) || defined(__gfx92a__)
DS_READ_MATRIX_32X32_B16_GFX946
(
lds_load_offset
,
q_reg
[(
k_loop
-
1
)
*
2
].
f16
,
q_reg
[(
k_loop
-
1
)
*
2
+
1
].
f16
,
true
);
#endif
// __syncthreads();
flash
::
wait_lds_data_arrived
<
true
>
(
0
);
...
...
@@ -77,6 +73,8 @@ __forceinline__ __device__ void prefetch_q_to_vgpr_mls_ds(
int
lds_load_offset
=
q_lds_base
+
(
stage_id
*
kBlockM
*
kBlockK
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
;
#ifdef __gfx938__
DS_READ_MATRIX_32X32_B16
(
lds_load_offset
,
q_reg
[
k_loop
*
2
].
f16
,
q_reg
[
k_loop
*
2
+
1
].
f16
,
true
);
#elif defined(__gfx946__) || defined(__gfx92a__)
DS_READ_MATRIX_32X32_B16_GFX946
(
lds_load_offset
,
q_reg
[
k_loop
*
2
].
f16
,
q_reg
[
k_loop
*
2
+
1
].
f16
,
true
);
#endif
}
__builtin_amdgcn_s_waitcnt
(
0
);
...
...
@@ -114,9 +112,7 @@ __forceinline__ __device__ void prefetch_k_to_lds_mls_ds(
}
int
lds_offset
=
(
stage_id
*
WARP_N
*
kHeadDim_
+
warp_id
*
32
*
32
)
*
ELEMENT_BYTES
;
flash
::
wait_all_warp_arrived
();
union
union_vec4_uint
k_rsrc_bits
;
k_rsrc_bits
.
v32
=
k_srsrc
;
size_t
lds_addr_warp
=
reinterpret_cast
<
size_t
>
(
k_lds
)
+
lds_offset
;
matrix_load_b16_lds_trans_builtin
<
32
,
32
,
0
,
0
>
(
lds_addr_warp
,
k_rsrc_bits
.
i32
,
0
);
inline_matrix_load_32x32_b16_lds_trans
<
0
,
0
>
(
k_lds
,
k_srsrc
,
lds_offset
,
0
);
}
csrc/flash_attn_hg/include/fwd/gfx938/softmax_gfx938.h
View file @
518a5f4d
...
...
@@ -171,10 +171,10 @@ inline __device__ void softmax_rescale_o_gfx938(DataType0 scores[(WARP_N / 32) *
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
int
pv_tile_id
=
pv_n_loop
*
(
WARP_M
/
32
)
*
(
kBlockK
/
32
)
+
mi
+
ni
*
(
WARP_M
/
32
);
int
mmac_id
=
min_tile_n
*
2
+
min_tile_m
;
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
#pragma unroll
for
(
int
vec_idx
=
0
;
vec_idx
<
2
;
++
vec_idx
)
{
acc_o
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
u64
[
vec_idx
]
=
hcu_pk_mul_f32
(
acc_o
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
u64
[
vec_idx
]
=
__builtin_
hcu_pk_mul_f32
(
acc_o
[
pv_tile_id
][
mmac_id
].
u64
[
vec_idx
],
scores_scale_pair
);
...
...
@@ -200,8 +200,8 @@ inline __device__ void softmax_rescale_o_gfx938(DataType0 scores[(WARP_N / 32) *
reduce_sum
<
true
,
DataType0
,
DataType1
,
WARP_M
,
WARP_N
>
(
scores
,
scores_sum_cur
);
for
(
int
mi
=
0
;
mi
<
(
WARP_M
/
32
);
++
mi
)
{
#if defined(__gfx936__) || defined(__gfx938__)
scores_sum
[
mi
].
u64
=
hcu_pk_add_f32
(
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
scores_sum
[
mi
].
u64
=
__builtin_
hcu_pk_add_f32
(
scores_sum
[
mi
].
u64
,
scores_sum_cur
[
mi
].
u64
);
...
...
@@ -210,7 +210,7 @@ inline __device__ void softmax_rescale_o_gfx938(DataType0 scores[(WARP_N / 32) *
scores_sum
[
mi
].
f32
[
1
]
+=
scores_sum_cur
[
mi
].
f32
[
1
];
#endif
#if defined(USE_V_MOV_B64) && (defined(__gfx936__) || defined(__gfx938__))
#if defined(USE_V_MOV_B64) && (defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
)
inlineasm_fa_v_mov_b64
(
scores_max
[
mi
].
u64
,
scores_max_cur
[
mi
].
u64
...
...
csrc/flash_attn_hg/include/fwd/softmax.h
View file @
518a5f4d
...
...
@@ -228,7 +228,7 @@ __device__ inline void thread_reduce_sum(const DataType0 tensor[(WARP_M / 32) *
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
(
WARP_M
/
32
);
++
m_idx
)
{
// 对于 gfx936 及以上的架构, 可以使用 v_pk_add_f32
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
summary
[
m_idx
*
2
].
u64
=
0x0
;
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
(
WARP_N
/
32
);
++
n_idx
)
{
...
...
@@ -236,7 +236,7 @@ __device__ inline void thread_reduce_sum(const DataType0 tensor[(WARP_M / 32) *
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
++
vec_idx
)
{
__float2
additem_pair
=
{
tensor
[
m_idx
+
n_idx
*
(
WARP_M
/
32
)][
min_tile_n
*
2
].
f32
[
vec_idx
],
tensor
[
m_idx
+
n_idx
*
(
WARP_M
/
32
)][
min_tile_n
*
2
+
1
].
f32
[
vec_idx
]};
summary
[
m_idx
*
2
].
u64
=
hcu_pk_add_f32
(
summary
[
m_idx
*
2
].
u64
=
__builtin_
hcu_pk_add_f32
(
summary
[
m_idx
*
2
].
u64
,
additem_pair
);
...
...
@@ -262,7 +262,7 @@ __device__ inline void thread_reduce_sum(const DataType0 tensor[(WARP_M / 32) *
}
else
{
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
(
WARP_M
/
32
);
++
m_idx
)
{
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
summary_cur
[
m_idx
*
2
].
u64
=
summary
[
m_idx
*
2
].
u64
;
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
(
WARP_N
/
32
);
++
n_idx
)
{
...
...
@@ -270,7 +270,7 @@ __device__ inline void thread_reduce_sum(const DataType0 tensor[(WARP_M / 32) *
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
++
vec_idx
)
{
// mmac min_tile is 16*16, a warp is 64 thread
__float2
additem_pair
=
{
tensor
[
m_idx
+
n_idx
*
(
WARP_M
/
32
)][
min_tile_n
*
2
].
f32
[
vec_idx
],
tensor
[
m_idx
+
n_idx
*
(
WARP_M
/
32
)][
min_tile_n
*
2
+
1
].
f32
[
vec_idx
]};
summary_cur
[
m_idx
*
2
].
u64
=
hcu_pk_add_f32
(
summary_cur
[
m_idx
*
2
].
u64
=
__builtin_
hcu_pk_add_f32
(
summary_cur
[
m_idx
*
2
].
u64
,
additem_pair
);
...
...
@@ -372,15 +372,14 @@ inline __device__ void scale_apply_exp2(DataType0 tensor[(WARP_M / 32) * (WARP_N
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
int
mmac_id
=
min_tile_n
*
2
+
min_tile_m
;
int
qk_tile_id
=
mi
+
ni
*
(
WARP_M
/
32
);
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
for
(
int
vec_idx
=
0
;
vec_idx
<
2
;
++
vec_idx
)
{
tensor
[
qk_tile_id
][
mmac_id
].
u64
[
vec_idx
]
=
hcu_pk_fma_f32
(
tensor
[
qk_tile_id
][
mmac_id
].
u64
[
vec_idx
]
=
__builtin_
hcu_pk_fma_f32
(
tensor
[
qk_tile_id
][
mmac_id
].
u64
[
vec_idx
],
scale_pair
,
neg_max_scaled_pair
);
}
asm
volatile
(
"s_nop 0"
:::
"memory"
);
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
++
vec_idx
)
{
tensor
[
qk_tile_id
][
mmac_id
].
f32
[
vec_idx
]
=
__llvm_exp2_f32
(
tensor
[
qk_tile_id
][
mmac_id
].
f32
[
vec_idx
]);
}
...
...
@@ -418,6 +417,7 @@ inline __device__ void softmax_rescale_o(DataType0 scores[(WARP_N / 32) * (WARP_
?
scores_max_cur
[
mi
*
2
].
f32
[
min_tile_m
]
:
(
scores_max_cur
[
mi
*
2
].
f32
[
min_tile_m
]
==
-
INFINITY
?
0.0
f
:
scores_max_cur
[
mi
*
2
].
f32
[
min_tile_m
]);
// optimization from flash-attention-4
if
(
IsInference
or
scores_max
[
mi
*
2
].
f32
[
min_tile_m
]
<
scores_max_cur_reg
)
{
float
scores_scale
=
__llvm_exp2_f32
((
scores_max
[
mi
*
2
].
f32
[
min_tile_m
]
-
scores_max_cur_reg
)
*
softmax_scale_log2
);
scores_sum
[
mi
*
2
].
f32
[
min_tile_m
]
*=
scores_scale
;
...
...
@@ -428,13 +428,17 @@ inline __device__ void softmax_rescale_o(DataType0 scores[(WARP_N / 32) * (WARP_
for
(
int
pv_n_loop
=
0
;
pv_n_loop
<
(
K
/
kBlockK
);
pv_n_loop
++
)
{
#pragma unroll
for
(
int
ni
=
0
;
ni
<
(
kBlockK
/
32
);
++
ni
)
{
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
// max * log_2(e)) This allows the compiler to use the ffma
// instruction instead of fadd and fmul separately.
// min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
int
pv_tile_id
=
pv_n_loop
*
(
WARP_M
/
32
)
*
(
kBlockK
/
32
)
+
mi
+
ni
*
(
WARP_M
/
32
);
int
mmac_id
=
min_tile_n
*
2
+
min_tile_m
;
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
#pragma unroll
for
(
int
vec_idx
=
0
;
vec_idx
<
2
;
++
vec_idx
)
{
acc_o
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
u64
[
vec_idx
]
=
hcu_pk_mul_f32
(
acc_o
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
u64
[
vec_idx
]
=
__builtin_
hcu_pk_mul_f32
(
acc_o
[
pv_tile_id
][
mmac_id
].
u64
[
vec_idx
],
scores_scale_pair
);
...
...
@@ -460,25 +464,17 @@ inline __device__ void softmax_rescale_o(DataType0 scores[(WARP_N / 32) * (WARP_
reduce_sum
<
true
,
DataType0
,
DataType1
,
WARP_M
,
WARP_N
>
(
scores
,
scores_sum_cur
);
for
(
int
mi
=
0
;
mi
<
(
WARP_M
/
32
);
++
mi
)
{
#if defined(__gfx936__) || defined(__gfx938__)
scores_sum
[
mi
].
u64
=
hcu_pk_add_f32
(
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
scores_sum
[
mi
].
u64
=
__builtin_
hcu_pk_add_f32
(
scores_sum
[
mi
].
u64
,
scores_sum_cur
[
mi
].
u64
);
#else
// for perf-model, add listed below will be optimized as v_fmac_f32, leading to incorrect results
#else
scores_sum
[
mi
].
f32
[
0
]
+=
scores_sum_cur
[
mi
].
f32
[
0
];
scores_sum
[
mi
].
f32
[
1
]
+=
scores_sum_cur
[
mi
].
f32
[
1
];
#endif
#if defined(USE_V_MOV_B64) && (defined(__gfx936__) || defined(__gfx938__))
inlineasm_fa_v_mov_b64
(
scores_max
[
mi
].
u64
,
scores_max_cur
[
mi
].
u64
);
#else
scores_max
[
mi
].
f32
[
0
]
=
scores_max_cur
[
mi
].
f32
[
0
];
scores_max
[
mi
].
f32
[
1
]
=
scores_max_cur
[
mi
].
f32
[
1
];
#endif
}
}
};
...
...
@@ -496,7 +492,12 @@ inline __device__ void convert_pk_type(union_vec2_f16x2<Element> p_reg[(WARP_M /
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
++
min_tile_m
)
{
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
#if defined(__gfx938__)
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
p_reg
[
n_idx
*
(
WARP_M
/
32
)
+
m_idx
][
0
*
2
+
min_tile_m
].
f16x2
[
min_tile_k
]
=
DownCastPair
<
float
,
Element
>
(
s_reg
[
n_idx
*
(
WARP_M
/
32
)
+
m_idx
][
0
*
2
+
min_tile_m
].
f32x2
[
min_tile_k
]);
p_reg
[
n_idx
*
(
WARP_M
/
32
)
+
m_idx
][
1
*
2
+
min_tile_m
].
f16x2
[
min_tile_k
]
=
DownCastPair
<
float
,
Element
>
(
s_reg
[
n_idx
*
(
WARP_M
/
32
)
+
m_idx
][
1
*
2
+
min_tile_m
].
f32x2
[
min_tile_k
]);
#else
if
constexpr
(
IsInference
)
{
p_reg
[
n_idx
*
(
WARP_M
/
32
)
+
m_idx
][
0
*
2
+
min_tile_m
].
f16x2
[
min_tile_k
]
=
DownCastPairNoPack
<
float
,
Element
>
(
s_reg
[
n_idx
*
(
WARP_M
/
32
)
+
m_idx
][
0
*
2
+
min_tile_m
].
f32
[
min_tile_k
*
2
+
0
],
...
...
@@ -507,6 +508,7 @@ inline __device__ void convert_pk_type(union_vec2_f16x2<Element> p_reg[(WARP_M /
s_reg
[
n_idx
*
(
WARP_M
/
32
)
+
m_idx
][
1
*
2
+
min_tile_m
].
f32
[
min_tile_k
*
2
+
1
]
);
}
else
{
// For training, higher precision is needed
p_reg
[
n_idx
*
(
WARP_M
/
32
)
+
m_idx
][
0
*
2
+
min_tile_m
].
f16
[
min_tile_k
*
2
+
0
]
=
DownCast
<
float
,
Element
,
false
>
(
s_reg
[
n_idx
*
(
WARP_M
/
32
)
+
m_idx
][
0
*
2
+
min_tile_m
].
f32
[
min_tile_k
*
2
+
0
]);
p_reg
[
n_idx
*
(
WARP_M
/
32
)
+
m_idx
][
1
*
2
+
min_tile_m
].
f16
[
min_tile_k
*
2
+
0
]
=
DownCast
<
float
,
Element
,
false
>
(
...
...
@@ -516,15 +518,6 @@ inline __device__ void convert_pk_type(union_vec2_f16x2<Element> p_reg[(WARP_M /
p_reg
[
n_idx
*
(
WARP_M
/
32
)
+
m_idx
][
1
*
2
+
min_tile_m
].
f16
[
min_tile_k
*
2
+
1
]
=
DownCast
<
float
,
Element
,
false
>
(
s_reg
[
n_idx
*
(
WARP_M
/
32
)
+
m_idx
][
1
*
2
+
min_tile_m
].
f32
[
min_tile_k
*
2
+
1
]);
}
#else
p_reg
[
n_idx
*
(
WARP_M
/
32
)
+
m_idx
][
0
*
2
+
min_tile_m
].
f16
[
min_tile_k
*
2
+
0
]
=
DownCast
<
float
,
Element
,
false
>
(
s_reg
[
n_idx
*
(
WARP_M
/
32
)
+
m_idx
][
0
*
2
+
min_tile_m
].
f32
[
min_tile_k
*
2
+
0
]);
p_reg
[
n_idx
*
(
WARP_M
/
32
)
+
m_idx
][
1
*
2
+
min_tile_m
].
f16
[
min_tile_k
*
2
+
0
]
=
DownCast
<
float
,
Element
,
false
>
(
s_reg
[
n_idx
*
(
WARP_M
/
32
)
+
m_idx
][
1
*
2
+
min_tile_m
].
f32
[
min_tile_k
*
2
+
0
]);
p_reg
[
n_idx
*
(
WARP_M
/
32
)
+
m_idx
][
0
*
2
+
min_tile_m
].
f16
[
min_tile_k
*
2
+
1
]
=
DownCast
<
float
,
Element
,
false
>
(
s_reg
[
n_idx
*
(
WARP_M
/
32
)
+
m_idx
][
0
*
2
+
min_tile_m
].
f32
[
min_tile_k
*
2
+
1
]);
p_reg
[
n_idx
*
(
WARP_M
/
32
)
+
m_idx
][
1
*
2
+
min_tile_m
].
f16
[
min_tile_k
*
2
+
1
]
=
DownCast
<
float
,
Element
,
false
>
(
s_reg
[
n_idx
*
(
WARP_M
/
32
)
+
m_idx
][
1
*
2
+
min_tile_m
].
f32
[
min_tile_k
*
2
+
1
]);
#endif
}
}
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment