Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
flash-attention
Commits
518a5f4d
Commit
518a5f4d
authored
Jun 09, 2026
by
hly
Browse files
import aicc-master-dev
parent
c2a1b310
Changes
131
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
2597 additions
and
233 deletions
+2597
-233
csrc/flash_attn_hg/include/mla/mla_prefix_prefill.h
csrc/flash_attn_hg/include/mla/mla_prefix_prefill.h
+12
-13
csrc/flash_attn_hg/include/mla/mla_pv_gemm_prefetch_k_tile16x32.h
...sh_attn_hg/include/mla/mla_pv_gemm_prefetch_k_tile16x32.h
+4
-4
csrc/flash_attn_hg/include/mla/mla_qk_gemm_prefetch_v_tile16x32.h
...sh_attn_hg/include/mla/mla_qk_gemm_prefetch_v_tile16x32.h
+1
-1
csrc/flash_attn_hg/include/mla/mla_qk_gemm_utils_tile16x32.h
csrc/flash_attn_hg/include/mla/mla_qk_gemm_utils_tile16x32.h
+3
-3
csrc/flash_attn_hg/include/mla/mla_softmax.h
csrc/flash_attn_hg/include/mla/mla_softmax.h
+13
-14
csrc/flash_attn_hg/include/numeric_types.h
csrc/flash_attn_hg/include/numeric_types.h
+1
-2
csrc/flash_attn_hg/src/flash_bwd_launch_template.h
csrc/flash_attn_hg/src/flash_bwd_launch_template.h
+19
-10
csrc/flash_attn_hg/src/flash_fwd_b16_fa.h
csrc/flash_attn_hg/src/flash_fwd_b16_fa.h
+634
-63
csrc/flash_attn_hg/src/flash_fwd_b16_mla.h
csrc/flash_attn_hg/src/flash_fwd_b16_mla.h
+26
-44
csrc/flash_attn_hg/src/flash_fwd_b16_pa.h
csrc/flash_attn_hg/src/flash_fwd_b16_pa.h
+287
-19
csrc/flash_attn_hg/src/flash_fwd_b8_fa.h
csrc/flash_attn_hg/src/flash_fwd_b8_fa.h
+573
-1
csrc/flash_attn_hg/src/flash_fwd_b8_mla.h
csrc/flash_attn_hg/src/flash_fwd_b8_mla.h
+2
-2
csrc/flash_attn_hg/src/flash_fwd_b8_pa.h
csrc/flash_attn_hg/src/flash_fwd_b8_pa.h
+521
-13
csrc/flash_attn_hg/src/flash_fwd_launch_template.h
csrc/flash_attn_hg/src/flash_fwd_launch_template.h
+200
-4
csrc/flash_attn_hg/src/flash_fwd_launch_template_pa.h
csrc/flash_attn_hg/src/flash_fwd_launch_template_pa.h
+249
-36
csrc/flash_attn_hg/src/flash_fwd_reduce.h
csrc/flash_attn_hg/src/flash_fwd_reduce.h
+4
-4
csrc/flash_attn_hg/src/target/flash_fp8_fwd_hdim128_bf16.cpp
csrc/flash_attn_hg/src/target/flash_fp8_fwd_hdim128_bf16.cpp
+12
-0
csrc/flash_attn_hg/src/target/flash_fp8_fwd_hdim128_fp16.cpp
csrc/flash_attn_hg/src/target/flash_fp8_fwd_hdim128_fp16.cpp
+12
-0
csrc/flash_attn_hg/src/target/flash_fp8_fwd_hdim128_prefix_prefill_bf16.cpp
.../src/target/flash_fp8_fwd_hdim128_prefix_prefill_bf16.cpp
+12
-0
csrc/flash_attn_hg/src/target/flash_fp8_fwd_hdim128_prefix_prefill_fp16.cpp
.../src/target/flash_fp8_fwd_hdim128_prefix_prefill_fp16.cpp
+12
-0
No files found.
csrc/flash_attn_hg/include/mla/mla_prefix_prefill.h
View file @
518a5f4d
...
...
@@ -428,8 +428,8 @@ __forceinline__ __device__ void mla_prefix_prefill_combine_s_reg_of_2waves(vec4_
:
((
warp_id
&
1
)
?
warp_id
-
1
:
warp_id
+
1
);
int
lds_load_offset
=
n_loop
*
WARP_NUM
*
(
64
*
4
)
+
warp_id_symmetry
*
64
*
4
+
lane_id
*
4
;
vec4_Accum
<
ElementAccum
>
symmetry_data
=
*
(
vec4_Accum
<
ElementAccum
>*
)(
s_reg_lds
+
lds_load_offset
);
s_reg
[
m_idx
][
n_loop
].
u64
[
0
]
=
hcu_pk_add_f32
(
s_reg
[
m_idx
][
n_loop
].
u64
[
0
],
symmetry_data
.
u64
[
0
]);
s_reg
[
m_idx
][
n_loop
].
u64
[
1
]
=
hcu_pk_add_f32
(
s_reg
[
m_idx
][
n_loop
].
u64
[
1
],
symmetry_data
.
u64
[
1
]);
s_reg
[
m_idx
][
n_loop
].
u64
[
0
]
=
__builtin_
hcu_pk_add_f32
(
s_reg
[
m_idx
][
n_loop
].
u64
[
0
],
symmetry_data
.
u64
[
0
]);
s_reg
[
m_idx
][
n_loop
].
u64
[
1
]
=
__builtin_
hcu_pk_add_f32
(
s_reg
[
m_idx
][
n_loop
].
u64
[
1
],
symmetry_data
.
u64
[
1
]);
}
__builtin_amdgcn_sched_barrier
(
0
);
__syncthreads
();
...
...
@@ -471,9 +471,8 @@ __forceinline__ __device__ void mla_prefix_prefill_compute_fwd_softmax(
scale_softmax_log2_pair
[
1
]
=
scale_softmax_log2
;
#pragma unroll
for
(
int
n_loop
=
0
;
n_loop
<
kBlockN
/
16
;
++
n_loop
)
{
s_reg
[
m_idx
][
n_loop
].
u64
[
0
]
=
hcu_pk_fma_f32
(
s_reg
[
m_idx
][
n_loop
].
u64
[
0
],
scale_softmax_log2_pair
,
max_scaled
);
s_reg
[
m_idx
][
n_loop
].
u64
[
1
]
=
hcu_pk_fma_f32
(
s_reg
[
m_idx
][
n_loop
].
u64
[
1
],
scale_softmax_log2_pair
,
max_scaled
);
asm
volatile
(
"s_nop 0"
:::
"memory"
);
s_reg
[
m_idx
][
n_loop
].
u64
[
0
]
=
__builtin_hcu_pk_fma_f32
(
s_reg
[
m_idx
][
n_loop
].
u64
[
0
],
scale_softmax_log2_pair
,
max_scaled
);
s_reg
[
m_idx
][
n_loop
].
u64
[
1
]
=
__builtin_hcu_pk_fma_f32
(
s_reg
[
m_idx
][
n_loop
].
u64
[
1
],
scale_softmax_log2_pair
,
max_scaled
);
s_reg
[
m_idx
][
n_loop
].
f32
[
0
]
=
__llvm_exp2_f32
(
s_reg
[
m_idx
][
n_loop
].
f32
[
0
]);
s_reg
[
m_idx
][
n_loop
].
f32
[
1
]
=
__llvm_exp2_f32
(
s_reg
[
m_idx
][
n_loop
].
f32
[
1
]);
s_reg
[
m_idx
][
n_loop
].
f32
[
2
]
=
__llvm_exp2_f32
(
s_reg
[
m_idx
][
n_loop
].
f32
[
2
]);
...
...
@@ -489,8 +488,8 @@ __forceinline__ __device__ void mla_prefix_prefill_compute_fwd_softmax(
scores_sum_pair
[
1
]
=
0
;
#pragma unroll
for
(
int
n_loop
=
0
;
n_loop
<
kBlockN
/
16
;
++
n_loop
)
{
scores_sum_pair
=
hcu_pk_add_f32
(
scores_sum_pair
,
s_reg
[
m_idx
][
n_loop
].
u64
[
0
]);
scores_sum_pair
=
hcu_pk_add_f32
(
scores_sum_pair
,
s_reg
[
m_idx
][
n_loop
].
u64
[
1
]);
scores_sum_pair
=
__builtin_
hcu_pk_add_f32
(
scores_sum_pair
,
s_reg
[
m_idx
][
n_loop
].
u64
[
0
]);
scores_sum_pair
=
__builtin_
hcu_pk_add_f32
(
scores_sum_pair
,
s_reg
[
m_idx
][
n_loop
].
u64
[
1
]);
}
scores_sum_cur
[
m_idx
]
=
scores_sum_pair
[
0
]
+
scores_sum_pair
[
1
];
scores_sum_cur
[
m_idx
]
=
scores_sum_cur
[
m_idx
]
+
__shfl_xor
(
scores_sum_cur
[
m_idx
],
32
);
...
...
@@ -505,8 +504,8 @@ __forceinline__ __device__ void mla_prefix_prefill_compute_fwd_softmax(
scores_sum
[
m_idx
]
*=
scores_scale
[
0
];
#pragma unroll
for
(
int
pv_tile
=
0
;
pv_tile
<
kHeadDimVSplit
;
++
pv_tile
)
{
acc_o
[
m_idx
][
pv_tile
].
u64
[
0
]
=
hcu_pk_mul_f32
(
acc_o
[
m_idx
][
pv_tile
].
u64
[
0
],
scores_scale
);
acc_o
[
m_idx
][
pv_tile
].
u64
[
1
]
=
hcu_pk_mul_f32
(
acc_o
[
m_idx
][
pv_tile
].
u64
[
1
],
scores_scale
);
acc_o
[
m_idx
][
pv_tile
].
u64
[
0
]
=
__builtin_
hcu_pk_mul_f32
(
acc_o
[
m_idx
][
pv_tile
].
u64
[
0
],
scores_scale
);
acc_o
[
m_idx
][
pv_tile
].
u64
[
1
]
=
__builtin_
hcu_pk_mul_f32
(
acc_o
[
m_idx
][
pv_tile
].
u64
[
1
],
scores_scale
);
}
}
// update max/sum
...
...
@@ -607,7 +606,7 @@ __forceinline__ __device__ void mla_prefix_prefill_cvt_dtype(
for
(
int
m_idx
=
0
;
m_idx
<
WARP_M
/
16
;
++
m_idx
)
{
#pragma unroll
for
(
int
n_loop
=
0
;
n_loop
<
kBlockN
/
16
;
++
n_loop
)
{
#if defined(__gfx938__)
#if defined(__gfx938__)
|| defined(__gfx946__)
#pragma unroll
for
(
int
vec_idx
=
0
;
vec_idx
<
2
;
++
vec_idx
)
{
p_reg
[
m_idx
][
n_loop
].
f16x2
[
vec_idx
]
=
DownCastPair
<
ElementAccum
,
Element
>
(
s_reg
[
m_idx
][
n_loop
].
f32x2
[
vec_idx
]);
...
...
@@ -932,8 +931,8 @@ __forceinline__ __device__ void mla_prefix_prefill_rescale_acc_o(
inv_sum
[
1
]
=
inv_sum
[
0
];
#pragma unroll
for
(
int
pv_tile
=
0
;
pv_tile
<
kHeadDimVSplit
/
16
;
++
pv_tile
)
{
acc_o
[
m_idx
][
pv_tile
].
u64
[
0
]
=
hcu_pk_mul_f32
(
acc_o
[
m_idx
][
pv_tile
].
u64
[
0
],
inv_sum
);
acc_o
[
m_idx
][
pv_tile
].
u64
[
1
]
=
hcu_pk_mul_f32
(
acc_o
[
m_idx
][
pv_tile
].
u64
[
1
],
inv_sum
);
acc_o
[
m_idx
][
pv_tile
].
u64
[
0
]
=
__builtin_
hcu_pk_mul_f32
(
acc_o
[
m_idx
][
pv_tile
].
u64
[
0
],
inv_sum
);
acc_o
[
m_idx
][
pv_tile
].
u64
[
1
]
=
__builtin_
hcu_pk_mul_f32
(
acc_o
[
m_idx
][
pv_tile
].
u64
[
1
],
inv_sum
);
}
}
}
...
...
@@ -962,7 +961,7 @@ __forceinline__ __device__ void mla_prefix_prefill_store_output(
#pragma unroll
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
++
vec_idx
)
{
vec2_Element
<
Element
>
data
;
#if defined(__gfx938__)
#if defined(__gfx938__)
|| defined(__gfx946__)
#pragma unroll
for
(
int
mmac_id
=
0
;
mmac_id
<
2
;
++
mmac_id
)
{
data
[
mmac_id
]
=
DownCast
<
ElementAccum
,
Element
,
true
>
(
acc_o
[
m_idx
][
v_tile
*
2
+
mmac_id
].
f32
[
vec_idx
]);
...
...
csrc/flash_attn_hg/include/mla/mla_pv_gemm_prefetch_k_tile16x32.h
View file @
518a5f4d
...
...
@@ -74,7 +74,7 @@ __forceinline__ __device__ void mla_pv_gemm_prefetch_k_tile16x32(
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
++
vec_idx
)
{
for
(
int
seq_idx
=
0
;
seq_idx
<
PV_K_WARP_COUNT
;
++
seq_idx
)
{
for
(
int
head_dim_idx
=
0
;
head_dim_idx
<
PV_N_WARP_COUNT
;
++
head_dim_idx
)
{
precompute_v_lds_offset
[
vec_idx
]
=
reinterpret_cast
<
size_t
>
(
v_lds_v2fp16
)
+
(
(
stage_id
*
WARP_K
*
kBlockN
+
seq_idx
*
32
*
kBlockN
+
head_dim_idx
*
32
*
32
+
vec_idx
*
8
*
32
+
v_ds_read_offset
)
/
2
)
*
4
;
precompute_v_lds_offset
[
vec_idx
]
=
(
stage_id
*
WARP_K
*
kBlockN
+
seq_idx
*
32
*
kBlockN
+
head_dim_idx
*
32
*
32
+
vec_idx
*
8
*
32
+
v_ds_read_offset
)
/
2
;
}
}
}
...
...
@@ -97,7 +97,7 @@ __forceinline__ __device__ void mla_pv_gemm_prefetch_k_tile16x32(
for
(
int
seq_idx
=
0
;
seq_idx
<
PV_K_WARP_COUNT
;
++
seq_idx
)
{
#pragma unroll
for
(
int
head_dim_idx
=
0
;
head_dim_idx
<
PV_N_WARP_COUNT
;
++
head_dim_idx
)
{
inline_ds_read2_b32_no_wait_bytes
(
precompute_v_lds_offset
[
vec_idx
],
v_reg
[
stage_id
*
PV_K_WARP_COUNT
*
PV_N_WARP_COUNT
+
(
head_dim_idx
*
PV_K_WARP_COUNT
+
seq_idx
)][
vec_idx
].
u64
,
NEXT_DWORD_OFFSET
);
v_reg
[
stage_id
*
PV_K_WARP_COUNT
*
PV_N_WARP_COUNT
+
(
head_dim_idx
*
PV_K_WARP_COUNT
+
seq_idx
)][
vec_idx
].
u64
=
__builtin_hcu_ds_read2_f32
((
__attribute__
((
address_space
(
3
)))
float
*
)
v_lds_v2fp16
+
precompute_v_lds_offset
[
vec_idx
],
0
,
NEXT_DWORD_OFFSET
,
false
);
}
}
}
...
...
@@ -220,7 +220,7 @@ __forceinline__ __device__ void mla_pv_gemm_prefetch_k_tile16x32(
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
++
vec_idx
)
{
for
(
int
seq_idx
=
0
;
seq_idx
<
PV_K_WARP_COUNT
;
++
seq_idx
)
{
for
(
int
head_dim_idx
=
0
;
head_dim_idx
<
PV_N_WARP_COUNT
;
++
head_dim_idx
)
{
precompute_v_lds_offset
[
vec_idx
]
=
reinterpret_cast
<
size_t
>
(
v_lds_v2fp16
)
+
(
(
stage_id
*
WARP_K
*
kBlockN
+
(
seq_idx
*
32
*
kBlockN
)
+
head_dim_idx
*
32
*
32
+
vec_idx
*
8
*
32
+
v_ds_read_offset
)
/
2
)
*
4
;
precompute_v_lds_offset
[
vec_idx
]
=
(
stage_id
*
WARP_K
*
kBlockN
+
(
seq_idx
*
32
*
kBlockN
)
+
head_dim_idx
*
32
*
32
+
vec_idx
*
8
*
32
+
v_ds_read_offset
)
/
2
;
}
}
}
...
...
@@ -243,7 +243,7 @@ __forceinline__ __device__ void mla_pv_gemm_prefetch_k_tile16x32(
for
(
int
seq_idx
=
0
;
seq_idx
<
PV_K_WARP_COUNT
;
++
seq_idx
)
{
#pragma unroll
for
(
int
head_dim_idx
=
0
;
head_dim_idx
<
PV_N_WARP_COUNT
;
++
head_dim_idx
)
{
inline_ds_read2_b32_no_wait_bytes
(
precompute_v_lds_offset
[
vec_idx
],
v_reg
[
stage_id
*
PV_K_WARP_COUNT
*
PV_N_WARP_COUNT
+
(
head_dim_idx
*
PV_K_WARP_COUNT
+
seq_idx
)][
vec_idx
].
u64
,
NEXT_DWORD_OFFSET
);
v_reg
[
stage_id
*
PV_K_WARP_COUNT
*
PV_N_WARP_COUNT
+
(
head_dim_idx
*
PV_K_WARP_COUNT
+
seq_idx
)][
vec_idx
].
u64
=
__builtin_hcu_ds_read2_f32
((
__attribute__
((
address_space
(
3
)))
float
*
)
v_lds_v2fp16
+
precompute_v_lds_offset
[
vec_idx
],
0
,
NEXT_DWORD_OFFSET
,
false
);
}
}
}
...
...
csrc/flash_attn_hg/include/mla/mla_qk_gemm_prefetch_v_tile16x32.h
View file @
518a5f4d
...
...
@@ -27,7 +27,7 @@ __forceinline__ __device__ void mla_qk_gemm_prefetch_v_tile16x32(
int
laneid_shfl_4
=
lane_id
>>
4
;
int
laneid_and_15
=
lane_id
&
15
;
#if defined(__gfx936__) || defined(__gfx938__) // >= bmz
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
// >= bmz
int
qk_lane_m_idx
=
lane_id
>>
2
;
int
qk_lane_head_dim_idx
=
(
lane_id
&
3
)
<<
2
;
auto
BUFFER_LOAD_FUNC
=
&
inline_buffer_load_dwordx4_lds
<
Element
,
2
>
;
...
...
csrc/flash_attn_hg/include/mla/mla_qk_gemm_utils_tile16x32.h
View file @
518a5f4d
...
...
@@ -13,7 +13,7 @@ __forceinline__ __device__ void mla_prefetch_q_to_vgpr_tile16x32(
int
warp_id
,
int
query_seqlen_stride
,
int
max_seq_q_offset
=-
1
)
{
#if defined(__gfx928__)
#if defined(__gfx928__)
|| defined(__gfx92a__)
constexpr
int
Q_LOAD_REQUESTS
=
(
kBlockM
*
kBlockK
>>
1
/*16x32 tile*/
)
*
M_MMAC_COUNT
/
(
4
*
32
*
WARP_NUM
);
constexpr
int
SEQUENCE_READ
=
M_MMAC_COUNT
;
constexpr
int
READ_ONCE_LINES
=
4
;
...
...
@@ -99,7 +99,7 @@ __forceinline__ __device__ void mla_prefetch_q_to_vgpr_tile16x32(
__builtin_amdgcn_s_waitcnt
(
0
);
__syncthreads
();
#elif defined(__gfx936__) || defined(__gfx938__)
#elif defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
int
lane_id
=
threadIdx
.
x
&
63
;
int
laneid_shfl_4
=
lane_id
>>
4
;
int
laneid_and_15
=
lane_id
&
15
;
...
...
@@ -143,7 +143,7 @@ __forceinline__ __device__ void mla_prefetch_k_to_lds_tile16x32(
// 预先计算一些表达式
int
lane_id
=
threadIdx
.
x
&
63
;
// lane id, 0-63
#if defined(__gfx936__) || defined(__gfx938__) // >= bmz
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
// >= bmz
int
qk_lane_m_idx
=
lane_id
>>
2
;
int
qk_lane_head_dim_idx
=
(
lane_id
&
3
)
<<
2
;
auto
BUFFER_LOAD_FUNC
=
&
inline_buffer_load_dwordx4_lds
<
Element
,
2
>
;
...
...
csrc/flash_attn_hg/include/mla/mla_softmax.h
View file @
518a5f4d
...
...
@@ -117,7 +117,7 @@ __device__ inline void mla_thread_reduce_sum(const DataType0 tensor[M_WARP_COUNT
if
(
zero_init
==
true
)
{
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
M_WARP_COUNT
;
++
m_idx
)
{
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
summary
[
m_idx
*
2
].
u64
=
0x0
;
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
N_WARP_COUNT
;
++
n_idx
)
{
...
...
@@ -125,7 +125,7 @@ __device__ inline void mla_thread_reduce_sum(const DataType0 tensor[M_WARP_COUNT
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
++
vec_idx
)
{
__float2
additem_pair
=
{
tensor
[
m_idx
+
n_idx
*
M_WARP_COUNT
][
min_tile_n
*
2
].
f32
[
vec_idx
],
tensor
[
m_idx
+
n_idx
*
M_WARP_COUNT
][
min_tile_n
*
2
+
1
].
f32
[
vec_idx
]};
summary
[
m_idx
*
2
].
u64
=
hcu_pk_add_f32
(
summary
[
m_idx
*
2
].
u64
=
__builtin_
hcu_pk_add_f32
(
summary
[
m_idx
*
2
].
u64
,
additem_pair
);
...
...
@@ -151,7 +151,7 @@ __device__ inline void mla_thread_reduce_sum(const DataType0 tensor[M_WARP_COUNT
}
else
{
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
M_WARP_COUNT
;
++
m_idx
)
{
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
summary_cur
[
m_idx
*
2
].
u64
=
summary
[
m_idx
*
2
].
u64
;
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
N_WARP_COUNT
;
++
n_idx
)
{
...
...
@@ -159,7 +159,7 @@ __device__ inline void mla_thread_reduce_sum(const DataType0 tensor[M_WARP_COUNT
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
++
vec_idx
)
{
__float2
additem_pair
=
{
tensor
[
m_idx
+
n_idx
*
M_WARP_COUNT
][
min_tile_n
*
2
].
f32
[
vec_idx
],
tensor
[
m_idx
+
n_idx
*
M_WARP_COUNT
][
min_tile_n
*
2
+
1
].
f32
[
vec_idx
]};
summary_cur
[
m_idx
*
2
].
u64
=
hcu_pk_add_f32
(
summary_cur
[
m_idx
*
2
].
u64
=
__builtin_
hcu_pk_add_f32
(
summary_cur
[
m_idx
*
2
].
u64
,
additem_pair
);
...
...
@@ -258,15 +258,14 @@ inline __device__ void mla_scale_apply_exp2(DataType0 tensor[M_WARP_COUNT * N_WA
// instruction instead of fadd and fmul separately.
// min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
for
(
int
vec_idx
=
0
;
vec_idx
<
2
;
vec_idx
++
)
{
tensor
[
mi
+
ni
*
M_WARP_COUNT
][
min_tile_n
*
2
+
min_tile_m
].
u64
[
vec_idx
]
=
hcu_pk_fma_f32
(
tensor
[
mi
+
ni
*
M_WARP_COUNT
][
min_tile_n
*
2
+
min_tile_m
].
u64
[
vec_idx
]
=
__builtin_
hcu_pk_fma_f32
(
tensor
[
mi
+
ni
*
M_WARP_COUNT
][
min_tile_n
*
2
+
min_tile_m
].
u64
[
vec_idx
],
scale_pair
,
neg_max_scaled_pair
);
}
asm
volatile
(
"s_nop 0"
:::
"memory"
);
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
vec_idx
++
)
{
tensor
[
mi
+
ni
*
M_WARP_COUNT
][
min_tile_n
*
2
+
min_tile_m
].
f32
[
vec_idx
]
=
__llvm_exp2_f32
(
tensor
[
mi
+
ni
*
M_WARP_COUNT
][
min_tile_n
*
2
+
min_tile_m
].
f32
[
vec_idx
]);
}
...
...
@@ -347,10 +346,10 @@ inline __device__ void mla_softmax_rescale_o(
int
loop_id
=
(
pv_n_loop
*
K_WARP_COUNT
+
ni
)
*
M_WARP_COUNT
+
mi
;
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
// 936 及之后的架构有 pk_mul 指令
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
#pragma unroll
for
(
int
vec_idx
=
0
;
vec_idx
<
2
;
vec_idx
++
)
{
acc_o
[
loop_id
][
min_tile_n
*
2
+
min_tile_m
].
u64
[
vec_idx
]
=
hcu_pk_mul_f32
(
acc_o
[
loop_id
][
min_tile_n
*
2
+
min_tile_m
].
u64
[
vec_idx
]
=
__builtin_
hcu_pk_mul_f32
(
acc_o
[
loop_id
][
min_tile_n
*
2
+
min_tile_m
].
u64
[
vec_idx
],
scores_scale_pair
);
...
...
@@ -401,8 +400,8 @@ inline __device__ void mla_softmax_rescale_o(
#pragma unroll
for
(
int
warp_loop
=
1
;
warp_loop
<
WARP_NUM
;
++
warp_loop
)
{
__float2
other_warp_sum
=
*
(
__float2
*
)(
sum_lds
+
warp_loop
*
WARP_M
+
mi
*
32
+
lane_id
*
2
);
#if defined(__gfx936__) || defined(__gfx938__)
cur_wave_sum
=
hcu_pk_add_f32
(
cur_wave_sum
,
other_warp_sum
);
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
cur_wave_sum
=
__builtin_
hcu_pk_add_f32
(
cur_wave_sum
,
other_warp_sum
);
#else
cur_wave_sum
[
0
]
+=
other_warp_sum
[
0
];
cur_wave_sum
[
1
]
+=
other_warp_sum
[
1
];
...
...
@@ -425,8 +424,8 @@ inline __device__ void mla_softmax_rescale_o(
}
for
(
int
mi
=
0
;
mi
<
M_WARP_COUNT
;
++
mi
)
{
#if defined(__gfx936__) || defined(__gfx938__)
scores_sum
[
mi
].
u64
=
hcu_pk_add_f32
(
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
scores_sum
[
mi
].
u64
=
__builtin_
hcu_pk_add_f32
(
scores_sum
[
mi
].
u64
,
scores_sum_cur
[
mi
].
u64
);
...
...
@@ -454,7 +453,7 @@ inline __device__ void mla_convert_pk_type(union_vec2_f16x2<Element> p_reg[M_WAR
for
(
int
min_tile_m
=
0
;
min_tile_m
<
M_MMAC_COUNT
;
++
min_tile_m
)
{
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
#if defined(__gfx938__)
#if defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__) || defined(__gfx92a__)
p_reg
[
n_idx
*
M_WARP_COUNT
+
m_idx
][
0
*
2
+
min_tile_m
].
f16x2
[
min_tile_k
]
=
DownCastPair
<
float
,
Element
>
(
s_reg
[
n_idx
*
M_WARP_COUNT
+
m_idx
][
0
*
2
+
min_tile_m
].
f32x2
[
min_tile_k
]);
p_reg
[
n_idx
*
M_WARP_COUNT
+
m_idx
][
1
*
2
+
min_tile_m
].
f16x2
[
min_tile_k
]
=
DownCastPair
<
float
,
Element
>
(
...
...
csrc/flash_attn_hg/include/numeric_types.h
View file @
518a5f4d
...
...
@@ -79,8 +79,6 @@ union union_vec_fp32 {
union
union_vec4_uint
{
unsigned
long
long
u64
[
2
];
// 128 bits
uint4
u32
;
vec4_int
i32
;
vec4_uint
v32
;
uint8_t
u8
[
16
];
};
...
...
@@ -261,3 +259,4 @@ __forceinline__ __device__ vec4_Element<bhalf_t> make_vec4_f16(bhalf_t a, bhalf_
// return {*(unsigned short*)(&a), *(unsigned short*)(&b), *(unsigned short*)(&c), *(unsigned short*)(&d)};
#endif
}
csrc/flash_attn_hg/src/flash_bwd_launch_template.h
View file @
518a5f4d
...
...
@@ -13,8 +13,10 @@
template
<
bool
Clear_dQaccum
=
true
,
bool
Is_even_MN
,
class
Element
,
class
ElementAccumType
,
int
kBlockM_
,
int
kBlockN_
,
int
WARP_M_
,
int
WARP_N_
,
int
kHeadDim_
,
int
STAGES_
,
bool
USE_BSHD_LAYOUT
,
typename
Params
>
__global__
void
__launch_bounds__
(
256
,
1
)
flash_bwd_dot_do_o_kernel
(
Params
params
)
{
#if defined(__gfx938__)
compute_dot_do_o_gfx938
<
true
,
Is_even_MN
,
Element
,
ElementAccumType
,
kBlockM_
,
kBlockN_
,
WARP_M_
,
WARP_N_
,
kHeadDim_
,
STAGES_
,
USE_BSHD_LAYOUT
>
(
params
);
#if defined(__gfx946__)
// compute_dot_do_o_gfx946<true, Is_even_MN, Element, ElementAccumType, kBlockM_, kBlockN_, WARP_M_, WARP_N_, kHeadDim_, STAGES_, USE_BSHD_LAYOUT>(params);
#elif defined(__gfx938__)
compute_dot_do_o
<
true
,
Is_even_MN
,
Element
,
ElementAccumType
,
kBlockM_
,
kBlockN_
,
WARP_M_
,
WARP_N_
,
kHeadDim_
,
STAGES_
,
USE_BSHD_LAYOUT
>
(
params
);
#else
compute_dot_do_o
<
true
,
Is_even_MN
,
Element
,
ElementAccumType
,
kBlockM_
,
kBlockN_
,
WARP_M_
,
WARP_N_
,
kHeadDim_
,
STAGES_
,
USE_BSHD_LAYOUT
>
(
params
);
#endif
...
...
@@ -27,8 +29,10 @@ __global__ void __launch_bounds__(256,1) flash_attention_dv_dk_bwd_kernel(Param
const
int
bidb
=
bidbh
/
params
.
h
;
const
int
bidh
=
bidbh
%
params
.
h
;
const
int
n_block
=
blockIdx
.
y
;
#if defined(__gfx938__)
compute_dk_dv_1colblock_gfx938
<
Element
,
float
,
Is_dropout
,
Is_causal
,
Is_local
,
Is_even_MN
,
Is_even_K
,
Is_first
,
Is_last
,
Seq_parallel
,
kBlockM_
,
kBlockN_
,
K
,
K_v
,
kBlockK_
,
WARP_M_
,
WARP_N_
,
USE_BSHD_LAYOUT
>
(
params
,
bidb
,
bidh
,
n_block
);
#if defined(__gfx946__)
// compute_dk_dv_1colblock_gfx946<Element, float, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, Is_first, Is_last, Seq_parallel, kBlockM_, kBlockN_, K, K_v, kBlockK_, WARP_M_, WARP_N_, USE_BSHD_LAYOUT>(params, bidb, bidh, n_block);
#elif defined(__gfx938__)
compute_dk_dv_1colblock
<
Element
,
float
,
Is_dropout
,
Is_causal
,
Is_local
,
Is_even_MN
,
Is_even_K
,
Is_first
,
Is_last
,
Seq_parallel
,
kBlockM_
,
kBlockN_
,
K
,
K_v
,
kBlockK_
,
WARP_M_
,
WARP_N_
,
USE_BSHD_LAYOUT
>
(
params
,
bidb
,
bidh
,
n_block
);
#else
compute_dk_dv_1colblock
<
Element
,
float
,
Is_dropout
,
Is_causal
,
Is_local
,
Is_even_MN
,
Is_even_K
,
Is_first
,
Is_last
,
Seq_parallel
,
kBlockM_
,
kBlockN_
,
K
,
K_v
,
kBlockK_
,
WARP_M_
,
WARP_N_
,
USE_BSHD_LAYOUT
>
(
params
,
bidb
,
bidh
,
n_block
);
#endif
...
...
@@ -42,8 +46,10 @@ __global__ void __launch_bounds__(256,1) flash_attention_dq_bwd_kernel(Params p
const
int
m_actual_block
=
(
params
.
seqlen_q
+
kBlockM_
-
1
)
/
kBlockM_
;
const
int
m_block
=
m_actual_block
-
1
-
blockIdx
.
y
;
#if defined(__gfx938__)
compute_dq_1colblock_gfx938
<
ElementType
,
float
,
Is_dropout
,
Is_causal
,
Is_local
,
Is_even_MN
,
Is_even_K
,
Is_first
,
Is_last
,
Seq_parallel
,
kBlockM_
,
kBlockN_
,
K
,
K_v
,
kBlockK_
,
WARP_M_
,
WARP_N_
,
STAGES
,
USE_BSHD_LAYOUT
>
(
params
,
bidb
,
bidh
,
m_block
);
#if defined(__gfx946__)
// compute_dq_1colblock_gfx946<ElementType, float, Is_dropout, Is_causal, Is_local, Is_even_MN, Is_even_K, Is_first, Is_last, Seq_parallel, kBlockM_, kBlockN_, K, K_v, kBlockK_, WARP_M_, WARP_N_, STAGES, USE_BSHD_LAYOUT>(params, bidb, bidh, m_block);
#elif defined(__gfx938__)
compute_dq_1colblock
<
ElementType
,
float
,
Is_dropout
,
Is_causal
,
Is_local
,
Is_even_MN
,
Is_even_K
,
Is_first
,
Is_last
,
Seq_parallel
,
kBlockM_
,
kBlockN_
,
K
,
K_v
,
kBlockK_
,
WARP_M_
,
WARP_N_
,
STAGES
,
USE_BSHD_LAYOUT
>
(
params
,
bidb
,
bidh
,
m_block
);
#else
compute_dq_1colblock
<
ElementType
,
float
,
Is_dropout
,
Is_causal
,
Is_local
,
Is_even_MN
,
Is_even_K
,
Is_first
,
Is_last
,
Seq_parallel
,
kBlockM_
,
kBlockN_
,
K
,
K_v
,
kBlockK_
,
WARP_M_
,
WARP_N_
,
STAGES
,
USE_BSHD_LAYOUT
>
(
params
,
bidb
,
bidh
,
m_block
);
#endif
...
...
@@ -80,7 +86,8 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms) {
// std::cout<<"USE_BSHD_LAYOUT="<<USE_BSHD_LAYOUT<<std::endl;
hipStream_t
stream
=
NULL
;
const
bool
is_even_MN
=
((
params
.
seqlen_k
%
kBlockN_
)
==
0
)
&&
((
params
.
seqlen_q
)
%
kBlockM_
==
0
)
&&
params
.
cu_seqlens_q
==
nullptr
;
// Even-MN must be computed with the same tile shape as each launched kernel.
const
bool
is_even_MN_dot
=
((
params
.
seqlen_k
%
kBlockN_
)
==
0
)
&&
((
params
.
seqlen_q
%
kBlockM_
)
==
0
)
&&
params
.
cu_seqlens_q
==
nullptr
;
//is_even_K指headdim是否是32的整数倍,否则需要进行边界判断
const
bool
is_even_K
=
params
.
d
==
K
;
...
...
@@ -109,7 +116,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms) {
// flash_attention_bwd: 34.9 ms
// flash_bwd_convert_dq_kernel: 0.9 ms
BOOL_SWITCH
(
is_even_MN
,
IsEvenMNConst
,
[
&
]
{
BOOL_SWITCH
(
is_even_MN
_dot
,
IsEvenMNConst
,
[
&
]
{
flash_bwd_dot_do_o_kernel
<
true
,
IsEvenMNConst
,
Element
,
ElementAccumType
,
kBlockM_
,
kBlockN_
,
WARP_M_
,
WARP_N_
,
K_v
,
STAGES
,
USE_BSHD_LAYOUT
>
<<<
grid_m
,
kMThreads
,
0
,
stream
>>>
(
params
);
});
...
...
@@ -148,6 +155,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms) {
constexpr
int
dk_dv_kBlockK
=
32
;
constexpr
int
dk_dv_WARP_M
=
32
;
constexpr
int
dk_dv_WARP_N
=
32
;
const
bool
is_even_MN_dk_dv
=
((
params
.
seqlen_k
%
dk_dv_kBlockN
)
==
0
)
&&
((
params
.
seqlen_q
%
dk_dv_kBlockM
)
==
0
)
&&
params
.
cu_seqlens_q
==
nullptr
;
dim3
dimBlock
;
int
maxBlockThreads
=
512
;
dimBlock
.
x
=
min
((
dk_dv_kBlockN
)
/
(
dk_dv_WARP_N
)
*
64
,
maxBlockThreads
);
...
...
@@ -164,7 +172,7 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms) {
// dim3 grid_n(gridDimx, params.h, params.b);
dim3
grid_n
(
params
.
se_balance_cnt
,
gridDimx
,
(
params
.
h
*
params
.
b
/
params
.
se_balance_cnt
));
// printf("flash_attention_dv_dk_bwd_kernel : grid(%d, %d, %d) | block(%d, %d, %d)\n", grid_n.x, grid_n.y, grid_n.z, dimBlock.x, dimBlock.y, dimBlock.z);
BOOL_SWITCH
(
is_even_MN
,
IsEvenMNConst
,
[
&
]
{
BOOL_SWITCH
(
is_even_MN
_dk_dv
,
IsEvenMNConst
,
[
&
]
{
flash_attention_dv_dk_bwd_kernel
<
Element
,
float
,
Is_dropout
,
Is_causal
,
Is_local
,
IsEvenMNConst
,
true
,
Is_first
,
Is_last
,
Seq_parallel
,
dk_dv_kBlockM
,
dk_dv_kBlockN
,
K
,
K_v
,
dk_dv_kBlockK
,
dk_dv_WARP_M
,
dk_dv_WARP_N
,
USE_BSHD_LAYOUT
>
<<<
grid_n
,
dimBlock
,
sharedMemSize
,
stream
>>>
(
params
);
});
...
...
@@ -176,12 +184,13 @@ void run_flash_bwd_seqk_parallel(Flash_bwd_params ¶ms) {
constexpr
int
dq_kBlockK
=
32
;
constexpr
int
dq_WARP_M
=
32
;
constexpr
int
dq_WARP_N
=
32
;
const
bool
is_even_MN_dq
=
((
params
.
seqlen_k
%
dq_kBlockN
)
==
0
)
&&
((
params
.
seqlen_q
%
dq_kBlockM
)
==
0
)
&&
params
.
cu_seqlens_q
==
nullptr
;
int
dq_kMThreads
=
(
dq_kBlockM
+
dq_WARP_M
-
1
)
/
dq_WARP_M
*
64
;
const
int
num_m_block_dq
=
(
params
.
seqlen_q
+
dq_kBlockM
-
1
)
/
dq_kBlockM
;
// dim3 grid_m(num_m_block_dq, params.h, params.b);
dim3
grid_m
(
params
.
se_balance_cnt
,
num_m_block_dq
,
(
params
.
h
*
params
.
b
/
params
.
se_balance_cnt
));
// printf("flash_attention_dq_bwd_kernel : grid(%d, %d, %d) | block(%d, %d, %d)\n", grid_m.x, grid_m.y, grid_m.z, dq_kMThreads, 1, 1);
BOOL_SWITCH
(
is_even_MN
,
IsEvenMNConst
,
[
&
]
{
BOOL_SWITCH
(
is_even_MN
_dq
,
IsEvenMNConst
,
[
&
]
{
flash_attention_dq_bwd_kernel
<
Element
,
float
,
Is_dropout
,
Is_causal
,
Is_local
,
IsEvenMNConst
,
true
,
Is_first
,
Is_last
,
Seq_parallel
,
dq_kBlockM
,
dq_kBlockN
,
K
,
K_v
,
dq_kBlockK
,
dq_WARP_M
,
dq_WARP_N
,
2
,
USE_BSHD_LAYOUT
>
<<<
grid_m
,
dq_kMThreads
,
sharedMemSize
,
stream
>>>
(
params
);
});
...
...
csrc/flash_attn_hg/src/flash_fwd_b16_fa.h
View file @
518a5f4d
...
...
@@ -247,6 +247,11 @@ inline __device__ void compute_attn_mha_1rowblock(const Params ¶ms, const in
/**********************************************************************************************************************************/
// Epilogue: rescale acco
vec2_Accum
<
ElementAccum
>
lse
[
WARP_M
/
32
];
if
(
params
.
s_aux_ptr
!=
nullptr
)
{
const
float
sink_value
=
reinterpret_cast
<
const
float
*>
(
params
.
s_aux_ptr
)[
bidh
];
fwd_apply_attention_sink
<
WARP_M
,
kBlockK
,
kHeadDimV
,
ElementAccum
>
(
acc_o
,
scores_max
,
scores_sum
,
params
.
scale_softmax
,
sink_value
);
}
fwd_epilugue_rescale_acco
<
WARP_M
,
kBlockK
,
kHeadDimV
,
Is_dropout
&&
Is_training
,
ElementAccum
>
(
acc_o
,
lse
,
scores_max
,
scores_sum
,
params
.
scale_softmax
,
params
.
rp_dropout
);
/**************************************************************************************************************************************/
int
lane_id
=
threadIdx
.
x
&
63
;
...
...
@@ -288,7 +293,7 @@ inline __device__ void compute_attn(const Params ¶ms) {
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
Kernel_traits
,
bool
Is_dropout
,
bool
Is_causal
,
bool
Is_local
,
bool
Is_even_MN
,
bool
Return_softmax
,
bool
Has_alibi
,
int
Layout
,
typename
Params
>
inline
__device__
void
compute_attn_mha_prefix_prefill_1rowblock
(
const
Params
&
params
,
const
int
bidb
,
const
int
__bidh
,
const
int
m_block
,
const
int
WARP_ID
)
{
inline
__device__
void
compute_attn_mha_prefix_prefill_1rowblock
(
const
Params
&
params
,
const
int
bidb
,
const
int
__bidh
,
const
int
m_block
,
const
int
warp_id
)
{
using
Element
=
typename
Kernel_traits
::
Element
;
using
ElementAccum
=
typename
Kernel_traits
::
ElementAccum
;
using
index_t
=
typename
Kernel_traits
::
index_t
;
...
...
@@ -330,17 +335,17 @@ inline __device__ void compute_attn_mha_prefix_prefill_1rowblock(const Params &p
}
// 计算数据跨度
int
seqlen_q_stride
=
(
Layout
==
1
)
?
params
.
q_row_stride
:
params
.
q_row_stride
;
int
seqlen_k_stride
=
(
Layout
==
1
)
?
params
.
k_row_stride
:
params
.
k_row_stride
;
int
seqlen_v_stride
=
(
Layout
==
1
)
?
params
.
v_row_stride
:
params
.
v_row_stride
;
int
seqlen_o_stride
=
(
Layout
==
1
)
?
params
.
o_row_stride
:
params
.
o_row_stride
;
int
seqlen_q_stride
=
params
.
q_row_stride
;
int
seqlen_k_stride
=
params
.
k_row_stride
;
int
seqlen_v_stride
=
params
.
v_row_stride
;
int
seqlen_o_stride
=
params
.
o_row_stride
;
int64_t
row_offset_q
,
row_offset_k
,
row_offset_v
,
row_offset_o
;
int64_t
row_offset_lse
;
// 获取页表信息
const
int
page_block_size
=
params
.
page_block_size
;
int
*
block_table
=
params
.
block_table
+
bidb
*
params
.
block_table_batch_stride
;
const
int
block_table_idx
=
n_block_min
;
const
int
block_table_offset
=
0
;
const
int
block_table_idx
=
n_block_min
*
kBlockN
/
page_block_size
;
const
int
block_table_offset
=
n_block_min
*
kBlockN
-
block_table_idx
*
page_block_size
;
if
constexpr
(
Layout
==
1
)
{
/*bshd layout, lse is num_heads, total_q*/
row_offset_q
=
(
binfo
.
sum_s_q
+
m_block
*
kBlockM
)
*
int64_t
(
seqlen_q_stride
)
+
params
.
q_head_stride
*
bidh
;
row_offset_k
=
int64_t
(
block_table
[
block_table_idx
])
*
int64_t
(
params
.
k_batch_stride
)
+
block_table_offset
*
params
.
k_row_stride
+
(
bidh
/
params
.
h_h_k_ratio
)
*
params
.
k_head_stride
;
...
...
@@ -361,9 +366,9 @@ inline __device__ void compute_attn_mha_prefix_prefill_1rowblock(const Params &p
auto
gV
=
prepare_for_buffer_load
<
kHeadDimV
>
(
reinterpret_cast
<
Element
*>
(
params
.
v_ptr
)
+
row_offset_v
);
// attention 变体: Alibi
float
g
A
libi
;
float
g
_a
libi
;
if
constexpr
(
Has_alibi
)
{
g
A
libi
=
reinterpret_cast
<
ElementAccum
*>
(
params
.
alibi_slopes_ptr
)[
bidb
*
params
.
alibi_slopes_batch_stride
+
bidh
]
/
params
.
scale_softmax
;
g
_a
libi
=
reinterpret_cast
<
ElementAccum
*>
(
params
.
alibi_slopes_ptr
)[
bidb
*
params
.
alibi_slopes_batch_stride
+
bidh
]
/
params
.
scale_softmax
;
}
// attention 插件: Dropout
...
...
@@ -372,9 +377,9 @@ inline __device__ void compute_attn_mha_prefix_prefill_1rowblock(const Params &p
union_vec2_uint
warp_idx_for_dropout
;
if
constexpr
(
Is_dropout
)
{
rand_seed
=
params
.
rand_seed
;
rand_offset
=
params
.
rand_offset
+
((
bidb
*
params
.
h
+
bidh
)
<<
6
)
+
threadIdx
.
x
&
63
;
/* 参考官方写法 offset(offset + (bid * nheads + hid) * 32 + tid % 32) */
p_dropout_in_8bits_value
=
params
.
p_dropout_in_uint8_t
&
0xffffffff
;
/*hcu 不支持 16bit 和 8bit 的比较指令*/
warp_idx_for_dropout
.
u32
.
x
=
1
*
m_block
*
(
kBlockM
/
32
)
/*前面几个 block 累积的 warp 数目, 这里不直接填 WARP_M, 参照 NV 的写法*/
+
WARP_ID
/*当前 block 内的 warp id*/
;
rand_offset
=
params
.
rand_offset
+
((
bidb
*
params
.
h
+
bidh
)
<<
6
)
+
threadIdx
.
x
&
63
;
p_dropout_in_8bits_value
=
params
.
p_dropout_in_uint8_t
&
0xffffffff
;
warp_idx_for_dropout
.
u32
.
x
=
1
*
m_block
*
(
kBlockM
/
32
)
+
warp_id
;
// Save seed and offset for backward, before any early exiting. Otherwise the 0-th thread block might exit early and no one saves the rng states.
if
(
m_block
==
0
and
bidb
==
0
and
bidh
==
0
and
threadIdx
.
x
==
0
)
{
params
.
rng_state
[
0
]
=
rand_seed
;
...
...
@@ -383,21 +388,17 @@ inline __device__ void compute_attn_mha_prefix_prefill_1rowblock(const Params &p
}
// 预取 Q 的数据到寄存器
vec2_Element
<
Element
>
q_reg
[(
kHeadDim
/
kBlockK
)
*
(
WARP_M
*
kBlockK
)
/
(
32
*
32
)
*
2
][
4
];
// ds_read mini size is 32 * 32,2 is seq, 4 is head dim
Is_even_MN
?
prefetch_q_to_vgpr
<
kHeadDim
,
kBlockM
,
kBlockK
,
WARP_M
,
Element
,
Is_even_MN
>
(
gQ
,
q_lds
,
q_reg
,
WARP_ID
,
seqlen_q_stride
)
:
prefetch_q_to_vgpr
<
kHeadDim
,
kBlockM
,
kBlockK
,
WARP_M
,
Element
,
Is_even_MN
>
(
gQ
,
q_lds
,
q_reg
,
WARP_ID
,
seqlen_q_stride
,
(
binfo
.
actual_seqlen_q
-
m_block
*
kBlockM
));
vec2_Element
<
Element
>
q_reg
[(
kHeadDim
/
kBlockK
)
*
(
WARP_M
*
kBlockK
)
/
(
32
*
32
)
*
2
][
4
];
prefetch_q_to_vgpr
<
kHeadDim
,
kBlockM
,
kBlockK
,
WARP_M
,
Element
,
Is_even_MN
>
(
gQ
,
q_lds
,
q_reg
,
warp_id
,
seqlen_q_stride
,
binfo
.
actual_seqlen_q
-
m_block
*
kBlockM
);
/***************************************************************************************************************************/
/***************************************************************************************************************************/
vec2_Accum
<
ElementAccum
>
scores_max
[
WARP_M
/
32
];
vec2_Accum
<
ElementAccum
>
scores_sum
[
WARP_M
/
32
];
vec4_Accum
<
ElementAccum
>
acc_o
[(
kHeadDimV
/
kBlockK
)
*
(
WARP_M
/
32
)
*
(
kBlockK
/
32
)][
4
];
attention_initialize
<
kHeadDimV
/
kBlockK
,
WARP_M
/
32
,
kBlockK
/
32
,
2
/*M_MMAC_COUNT*/
,
ElementAccum
>
(
scores_max
,
scores_sum
,
acc_o
);
/***************************************************************************************************************************/
/***************************************************************************************************************************/
// 是否做 prefetch K, PV 结束后, prefetch K 有风险
constexpr
bool
PREFETCH_K
=
false
;
constexpr
bool
PREFETCH_K
=
false
;
// true;
constexpr
bool
Aggressive
=
(
kHeadDim
==
128
or
kHeadDim
==
64
);
auto
QK_GEMM_FUNC
=
Aggressive
?
&
qk_gemm_prefetch_v_headdim128
<
kHeadDim
,
kHeadDimV
,
kBlockM
,
kBlockN
,
kBlockK
,
WARP_M
,
WARP_N
,
STAGES
,
Element
,
ElementAccum
,
Is_even_MN
>
...
...
@@ -406,47 +407,36 @@ inline __device__ void compute_attn_mha_prefix_prefill_1rowblock(const Params &p
auto
PV_GEMM_FUNC
=
Aggressive
?
&
pv_gemm_prefetch_k_headdim128
<
PREFETCH_K
,
kHeadDim
,
kHeadDimV
,
kBlockM
,
kBlockK
,
kBlockN
,
WARP_M
,
kBlockK
,
STAGES
,
Element
,
ElementAccum
,
Is_even_MN
>
:
&
pv_gemm_prefetch_k
<
PREFETCH_K
,
kHeadDim
,
kHeadDimV
,
kBlockM
,
kBlockK
,
kBlockN
,
WARP_M
,
kBlockK
,
STAGES
,
Element
,
ElementAccum
,
Is_even_MN
>
;
// mask 循环中不需要做 prefetch K, 因此 prefetch K 固定为 false
auto
PV_GEMM_FUNC_IN_MASK
=
Aggressive
?
&
pv_gemm_prefetch_k_headdim128
<
false
,
kHeadDim
,
kHeadDimV
,
kBlockM
,
kBlockK
,
kBlockN
,
WARP_M
,
kBlockK
,
STAGES
,
Element
,
ElementAccum
,
Is_even_MN
>
:
&
pv_gemm_prefetch_k
<
false
,
kHeadDim
,
kHeadDimV
,
kBlockM
,
kBlockK
,
kBlockN
,
WARP_M
,
kBlockK
,
STAGES
,
Element
,
ElementAccum
,
Is_even_MN
>
;
if
constexpr
(
PREFETCH_K
)
{
prefetch_k_to_lds
<
kHeadDim
,
kBlockN
,
kBlockK
,
WARP_NUM
,
WARP_N
,
Element
,
Is_even_MN
>
(
gK
,
k_lds
,
warp_id
,
seqlen_k_stride
,
binfo
.
actual_seqlen_k
-
n_block_min
*
kBlockN
);
}
// constexpr int n_masking_steps = (!Is_causal && !Is_local) ? 1 : flash::ceil_div(kBlockM, kBlockN);
// These are the iterations where we don't need masking on S
for
(
int
n_block_loop
=
n_block_min
;
n_block_loop
<
n_block_max
/*n_block_max - n_masking_steps*/
;
++
n_block_loop
)
{
const
int
seqlen_kv_limit
=
binfo
.
actual_seqlen_k
-
n_block_loop
*
kBlockN
;
// c mini tile is 32 * 32
vec4_Accum
<
ElementAccum
>
s_reg
[(
WARP_M
/
32
)
*
(
kBlockN
/
WARP_N
)][
4
];
if
constexpr
(
STAGES
>
1
)
{
Is_even_MN
?
prefetch_k_to_lds
<
kHeadDim
,
kBlockN
,
kBlockK
,
WARP_NUM
,
WARP_N
,
Element
,
Is_even_MN
>
(
gK
,
k_lds
,
WARP_ID
,
seqlen_k_stride
)
:
prefetch_k_to_lds
<
kHeadDim
,
kBlockN
,
kBlockK
,
WARP_NUM
,
WARP_N
,
Element
,
Is_even_MN
>
(
gK
,
k_lds
,
WARP_ID
,
seqlen_k_stride
,
seqlen_kv_limit
);
if
constexpr
(
not
PREFETCH_K
)
{
prefetch_k_to_lds
<
kHeadDim
,
kBlockN
,
kBlockK
,
WARP_NUM
,
WARP_N
,
Element
,
Is_even_MN
>
(
gK
,
k_lds
,
warp_id
,
seqlen_k_stride
,
seqlen_kv_limit
);
}
Is_even_MN
?
QK_GEMM_FUNC
(
gQ
,
gK
,
gV
,
q_lds
,
k_lds
,
v_lds
,
q_reg
,
s_reg
,
WARP_ID
,
seqlen_k_stride
,
seqlen_v_stride
,
0
)
:
QK_GEMM_FUNC
(
gQ
,
gK
,
gV
,
q_lds
,
k_lds
,
v_lds
,
q_reg
,
s_reg
,
WARP_ID
,
seqlen_k_stride
,
seqlen_v_stride
,
seqlen_kv_limit
);
QK_GEMM_FUNC
(
gQ
,
gK
,
gV
,
q_lds
,
k_lds
,
v_lds
,
q_reg
,
s_reg
,
warp_id
,
seqlen_k_stride
,
seqlen_v_stride
,
seqlen_kv_limit
);
if
constexpr
(
Has_alibi
)
{
apply_alibi
<
vec4_Accum
<
ElementAccum
>
,
WARP_M
,
kBlockN
>
(
s_reg
,
n_block_loop
*
kBlockN
,
binfo
.
actual_seqlen_k
,
m_block
*
kBlockM
+
WARP_ID
*
WARP_M
,
binfo
.
actual_seqlen_q
,
g
A
libi
);
apply_alibi
<
vec4_Accum
<
ElementAccum
>
,
WARP_M
,
kBlockN
>
(
s_reg
,
n_block_loop
*
kBlockN
,
binfo
.
actual_seqlen_k
,
m_block
*
kBlockM
+
warp_id
*
WARP_M
,
binfo
.
actual_seqlen_q
,
g
_a
libi
);
}
if
constexpr
(
!
Is_causal
&&
!
Is_local
)
{
if
constexpr
(
!
Is_even_MN
)
{
apply_mask
<
vec4_Accum
<
ElementAccum
>
,
WARP_M
,
kBlockN
>
(
s_reg
,
seqlen_kv_limit
);
}
}
else
{
if
constexpr
(
Is_local
)
{
apply_mask_local
<
Is_local
,
vec4_Accum
<
ElementAccum
>
,
WARP_M
,
kBlockN
>
(
s_reg
,
n_block_loop
*
kBlockN
,
binfo
.
actual_seqlen_k
,
m_block
*
kBlockM
+
WARP_ID
*
WARP_M
,
binfo
.
actual_seqlen_q
,
params
.
window_size_left
,
params
.
window_size_right
);
apply_mask_local
<
Is_local
,
vec4_Accum
<
ElementAccum
>
,
WARP_M
,
kBlockN
>
(
s_reg
,
n_block_loop
*
kBlockN
,
binfo
.
actual_seqlen_k
,
m_block
*
kBlockM
+
warp_id
*
WARP_M
,
binfo
.
actual_seqlen_q
,
params
.
window_size_left
,
params
.
window_size_right
);
}
else
if
constexpr
(
Is_causal
)
{
apply_mask_causal
<
vec4_Accum
<
ElementAccum
>
,
WARP_M
,
kBlockN
>
(
s_reg
,
n_block_loop
*
kBlockN
,
binfo
.
actual_seqlen_k
,
m_block
*
kBlockM
+
WARP_ID
*
WARP_M
,
binfo
.
actual_seqlen_q
);
}
apply_mask_causal
<
vec4_Accum
<
ElementAccum
>
,
WARP_M
,
kBlockN
>
(
s_reg
,
n_block_loop
*
kBlockN
,
binfo
.
actual_seqlen_k
,
m_block
*
kBlockM
+
warp_id
*
WARP_M
,
binfo
.
actual_seqlen_q
);
}
softmax_rescale_o
<
false
,
Is_causal
||
Is_local
,
vec4_Accum
<
ElementAccum
>
,
vec2_Accum
<
ElementAccum
>
,
kHeadDimV
,
kBlockK
,
WARP_M
,
kBlockN
>
(
s_reg
,
scores_max
,
scores_sum
,
acc_o
,
params
.
scale_softmax_log2
);
...
...
@@ -457,12 +447,11 @@ inline __device__ void compute_attn_mha_prefix_prefill_1rowblock(const Params &p
}
union_vec2_f16x2
<
Element
>
p_reg
[(
WARP_M
/
32
)
*
(
kBlockN
/
32
)][
4
];
// convertType: float2half
convert_pk_type
<
WARP_M
,
kBlockN
,
Element
,
ElementAccum
,
true
/*IsInference*/
>
(
p_reg
,
s_reg
);
Is_even_MN
?
PV_GEMM_FUNC
(
gV
,
gK
,
v_lds
,
k_lds
,
p_reg
,
acc_o
,
WARP_ID
,
seqlen_k_stride
,
seqlen_v_stride
,
0
)
:
PV_GEMM_FUNC
(
gV
,
gK
,
v_lds
,
k_lds
,
p_reg
,
acc_o
,
WARP_ID
,
seqlen_k_stride
,
seqlen_v_stride
,
seqlen_kv_limit
);
if
constexpr
(
not
PREFETCH_K
)
{
PV_GEMM_FUNC
(
gV
,
gK
,
v_lds
,
k_lds
,
p_reg
,
acc_o
,
warp_id
,
seqlen_k_stride
,
seqlen_v_stride
,
seqlen_kv_limit
);
}
const
int
block_table_idx_cur
=
n_block_loop
*
kBlockN
/
params
.
page_block_size
;
const
int
block_table_offset_cur
=
n_block_loop
*
kBlockN
-
block_table_idx_cur
*
params
.
page_block_size
;
...
...
@@ -472,21 +461,31 @@ inline __device__ void compute_attn_mha_prefix_prefill_1rowblock(const Params &p
const
int
offset_diff
=
block_table_offset_next
-
block_table_offset_cur
;
*
(
int64_t
*
)
&
gK
+=
(
int64_t
(
table_diff
)
*
int64_t
(
params
.
k_batch_stride
)
+
offset_diff
*
int64_t
(
params
.
k_row_stride
))
*
sizeof
(
Element
);
if
constexpr
(
PREFETCH_K
)
{
PV_GEMM_FUNC
(
gV
,
gK
,
v_lds
,
k_lds
,
p_reg
,
acc_o
,
warp_id
,
seqlen_k_stride
,
seqlen_v_stride
,
seqlen_kv_limit
);
}
*
(
int64_t
*
)
&
gV
+=
(
int64_t
(
table_diff
)
*
int64_t
(
params
.
v_batch_stride
)
+
offset_diff
*
int64_t
(
params
.
v_row_stride
))
*
sizeof
(
Element
);
}
/**********************************************************************************************************************************/
// Epilogue: rescale acco
vec2_Accum
<
ElementAccum
>
lse
[
WARP_M
/
32
];
if
(
params
.
s_aux_ptr
!=
nullptr
)
{
const
float
sink_value
=
reinterpret_cast
<
const
float
*>
(
params
.
s_aux_ptr
)[
bidh
];
fwd_apply_attention_sink
<
WARP_M
,
kBlockK
,
kHeadDimV
,
ElementAccum
>
(
acc_o
,
scores_max
,
scores_sum
,
params
.
scale_softmax
,
sink_value
);
}
fwd_epilugue_rescale_acco
<
WARP_M
,
kBlockK
,
kHeadDimV
,
false
/*Is_dropout*/
,
ElementAccum
>
(
acc_o
,
lse
,
scores_max
,
scores_sum
,
params
.
scale_softmax
,
params
.
rp_dropout
);
/**************************************************************************************************************************************/
int
lane_id
=
threadIdx
.
x
&
63
;
if
(
params
.
softmax_lse_ptr
!=
nullptr
)
{
fwd_epilogue_store_lse
<
WARP_M
,
Is_even_MN
,
SplitD
,
false
/*Is_Interleaved*/
,
ElementAccum
>
(
lse
,
params
.
softmax_lse_ptr
,
row_offset_lse
,
WARP_ID
,
lane_id
,
0
,
binfo
.
actual_seqlen_q
-
m_block
*
kBlockM
);
fwd_epilogue_store_lse
<
WARP_M
,
Is_even_MN
,
SplitD
,
false
/*Is_Interleaved*/
,
ElementAccum
>
(
lse
,
params
.
softmax_lse_ptr
,
row_offset_lse
,
warp_id
,
lane_id
,
0
,
binfo
.
actual_seqlen_q
-
m_block
*
kBlockM
);
}
/**************************************************************************************************************************************/
Element
*
o_ptr
=
reinterpret_cast
<
Element
*>
(
params
.
o_ptr
)
+
row_offset_o
;
fwd_epilogue_store_output
<
kHeadDimV
,
kBlockM
,
kBlockK
,
WARP_M
,
Is_even_MN
,
false
/*Is_Interleaves*/
,
false
/*TcpSwizzle*/
,
Element
,
ElementAccum
>
(
o_ptr
,
acc_o
,
m_block
,
WARP_ID
,
lane_id
,
seqlen_o_stride
,
binfo
.
actual_seqlen_q
);
fwd_epilogue_store_output
<
kHeadDimV
,
kBlockM
,
kBlockK
,
WARP_M
,
Is_even_MN
,
false
/*Is_Interleaves*/
,
false
/*TcpSwizzle*/
,
Element
,
ElementAccum
>
(
o_ptr
,
acc_o
,
m_block
,
warp_id
,
lane_id
,
seqlen_o_stride
,
binfo
.
actual_seqlen_q
);
}
...
...
@@ -674,6 +673,11 @@ inline __device__ void compute_attn_mha_padding_mask_1rowblock(const Params &par
/**********************************************************************************************************************************/
// Epilogue: rescale acco
vec2_Accum
<
ElementAccum
>
lse
[
WARP_M
/
32
];
if
(
params
.
s_aux_ptr
!=
nullptr
)
{
const
float
sink_value
=
reinterpret_cast
<
const
float
*>
(
params
.
s_aux_ptr
)[
bidh
];
fwd_apply_attention_sink
<
WARP_M
,
kBlockK
,
kHeadDimV
,
ElementAccum
>
(
acc_o
,
scores_max
,
scores_sum
,
params
.
scale_softmax
,
sink_value
);
}
fwd_epilugue_rescale_acco
<
WARP_M
,
kBlockK
,
kHeadDimV
,
false
/*Is_dropout && Is_training*/
,
ElementAccum
>
(
acc_o
,
lse
,
scores_max
,
scores_sum
,
params
.
scale_softmax
,
params
.
rp_dropout
);
/**************************************************************************************************************************************/
int
lane_id
=
threadIdx
.
x
&
63
;
...
...
@@ -840,6 +844,11 @@ inline __device__ void compute_attn_mha_attn_mask_1rowblock(const Params ¶ms
/**********************************************************************************************************************************/
// Epilogue: rescale acco
vec2_Accum
<
ElementAccum
>
lse
[
WARP_M
/
32
];
if
(
params
.
s_aux_ptr
!=
nullptr
)
{
const
float
sink_value
=
reinterpret_cast
<
const
float
*>
(
params
.
s_aux_ptr
)[
bidh
];
fwd_apply_attention_sink
<
WARP_M
,
kBlockK
,
kHeadDimV
,
ElementAccum
>
(
acc_o
,
scores_max
,
scores_sum
,
params
.
scale_softmax
,
sink_value
);
}
fwd_epilugue_rescale_acco
<
WARP_M
,
kBlockK
,
kHeadDimV
,
false
/*Is_dropout && Is_training*/
,
ElementAccum
>
(
acc_o
,
lse
,
scores_max
,
scores_sum
,
params
.
scale_softmax
,
params
.
rp_dropout
);
/**************************************************************************************************************************************/
int
lane_id
=
threadIdx
.
x
&
63
;
...
...
@@ -1107,6 +1116,11 @@ inline __device__ void compute_attn_mha_1rowblock_gfx938(const Params ¶ms, c
/**********************************************************************************************************************************/
// Epilogue: rescale acco
vec2_Accum
<
ElementAccum
>
lse
[
WARP_M
/
32
];
if
(
params
.
s_aux_ptr
!=
nullptr
)
{
const
float
sink_value
=
reinterpret_cast
<
const
float
*>
(
params
.
s_aux_ptr
)[
bidh
];
fwd_apply_attention_sink
<
WARP_M
,
kBlockK
,
kHeadDimPVCompute
,
ElementAccum
>
(
acc_o
,
scores_max
,
scores_sum
,
params
.
scale_softmax
,
sink_value
);
}
fwd_epilugue_rescale_acco
<
WARP_M
,
kBlockK
,
kHeadDimPVCompute
,
Is_dropout
&&
Is_training
,
ElementAccum
>
(
acc_o
,
lse
,
scores_max
,
scores_sum
,
params
.
scale_softmax
,
params
.
rp_dropout
);
/**************************************************************************************************************************************/
constexpr
bool
Is_Interleave
=
true
;
...
...
@@ -1123,7 +1137,7 @@ inline __device__ void compute_attn_mha_1rowblock_gfx938(const Params ¶ms, c
template
<
typename
Kernel_traits
,
bool
Is_training
,
bool
Is_dropout
,
bool
Is_causal
,
bool
Is_local
,
bool
Is_even_MN
,
bool
Is_even_K
,
bool
Is_Varlen
,
bool
Return_softmax
,
bool
Has_alibi
,
int
Layout
,
typename
Params
>
inline
__device__
void
compute_attn_gfx938
(
const
Params
&
params
)
{
#if defined(__gfx938__)
#if defined(__gfx938__)
|| defined(__gfx946__)
constexpr
bool
Do_lpt
=
Is_causal
;
const
int
bidh
=
Do_lpt
?
blockIdx
.
x
:
blockIdx
.
y
;
...
...
@@ -1194,8 +1208,8 @@ inline __device__ void compute_attn_prefix_prefill_1rowblock_gfx938(const Params
// 获取页表信息
const
int
page_block_size
=
params
.
page_block_size
;
int
*
block_table
=
params
.
block_table
+
bidb
*
params
.
block_table_batch_stride
;
const
int
block_table_idx
=
n_block_min
;
const
int
block_table_offset
=
0
;
const
int
block_table_idx
=
n_block_min
*
kBlockN
/
page_block_size
;
const
int
block_table_offset
=
n_block_min
*
kBlockN
-
block_table_idx
*
page_block_size
;
if
constexpr
(
Layout
==
1
)
{
/*bshd layout, lse is num_heads, total_q*/
row_offset_q
=
(
binfo
.
sum_s_q
+
m_block
*
kBlockM
)
*
int64_t
(
seqlen_q_stride
)
+
params
.
q_head_stride
*
bidh
;
row_offset_k
=
int64_t
(
block_table
[
block_table_idx
])
*
int64_t
(
params
.
k_batch_stride
)
+
block_table_offset
*
params
.
k_row_stride
+
(
bidh
/
params
.
h_h_k_ratio
)
*
params
.
k_head_stride
;
...
...
@@ -1351,10 +1365,12 @@ inline __device__ void compute_attn_prefix_prefill_1rowblock_gfx938(const Params
}
// Attention: mask, causal mask, local mask
if
constexpr
(
Is_local
)
{
apply_mask_
causal_
gfx938
<
vec4_Accum
<
ElementAccum
>
,
WARP_M
,
kBlockN
>
(
s_reg
,
warp_offset_in_seqkv
,
binfo
.
actual_seqlen_k
,
warp_offset_in_seq
_q
,
binfo
.
actual_seqlen_q
);
}
else
if
constexpr
(
Is_
caus
al
)
{
if
constexpr
(
!
Is_causal
&&
!
Is_local
)
{
apply_mask_gfx938
<
vec4_Accum
<
ElementAccum
>
,
WARP_M
,
kBlockN
>
(
s_reg
,
binfo
.
actual_seqlen_k
,
warp_offset_in_seq
kv
);
}
else
if
constexpr
(
Is_
loc
al
)
{
apply_mask_local_gfx938
<
/*HasWSLeft=*/
Is_local
,
vec4_Accum
<
ElementAccum
>
,
WARP_M
,
kBlockN
>
(
s_reg
,
warp_offset_in_seqkv
,
binfo
.
actual_seqlen_k
,
warp_offset_in_seq_q
,
binfo
.
actual_seqlen_q
,
params
.
window_size_left
,
params
.
window_size_right
);
}
else
if
constexpr
(
Is_causal
)
{
apply_mask_causal_gfx938
<
vec4_Accum
<
ElementAccum
>
,
WARP_M
,
kBlockN
>
(
s_reg
,
warp_offset_in_seqkv
,
binfo
.
actual_seqlen_k
,
warp_offset_in_seq_q
,
binfo
.
actual_seqlen_q
);
}
// 对 QK 输出做 softmax, 以及重放缩 acc_o/scores_sum
...
...
@@ -1388,6 +1404,11 @@ inline __device__ void compute_attn_prefix_prefill_1rowblock_gfx938(const Params
/**********************************************************************************************************************************/
// Epilogue: rescale acco
vec2_Accum
<
ElementAccum
>
lse
[
WARP_M
/
32
];
if
(
params
.
s_aux_ptr
!=
nullptr
)
{
const
float
sink_value
=
reinterpret_cast
<
const
float
*>
(
params
.
s_aux_ptr
)[
bidh
];
fwd_apply_attention_sink
<
WARP_M
,
kBlockK
,
kHeadDimV
,
ElementAccum
>
(
acc_o
,
scores_max
,
scores_sum
,
params
.
scale_softmax
,
sink_value
);
}
fwd_epilugue_rescale_acco
<
WARP_M
,
kBlockK
,
kHeadDimV
,
Is_dropout
&&
Is_training
,
ElementAccum
>
(
acc_o
,
lse
,
scores_max
,
scores_sum
,
params
.
scale_softmax
,
params
.
rp_dropout
);
/**************************************************************************************************************************************/
constexpr
bool
Is_Interleave
=
true
;
...
...
@@ -1401,19 +1422,569 @@ inline __device__ void compute_attn_prefix_prefill_1rowblock_gfx938(const Params
}
template
<
typename
Kernel_traits
,
bool
Is_training
,
bool
Is_dropout
,
bool
Is_causal
,
bool
Is_local
,
bool
Is_even_MN
,
bool
Is_even_K
,
bool
Return_softmax
,
bool
Has_alibi
,
int
Layout
,
typename
Params
>
__global__
void
__launch_bounds__
(
256
,
1
)
flash_fwd_prefix_prefill_gfx938_kernel
(
const
Params
params
)
{
////////////////////////////////////////////////////////////////////////////////////////////////////
// GFX92A kernels
////////////////////////////////////////////////////////////////////////////////////////////////////
#if defined(__gfx938__)
const
int
bidh
=
blockIdx
.
x
;
#include "fwd/gfx92a/qk_gemm_prefetch_v_mls_ds_gfx92a.h"
#include "fwd/gfx92a/pv_gemm_prefetch_k_mls_ds_gfx92a.h"
#include "fwd/gfx92a/softmax_gfx92a.h"
#include "fwd/gfx92a/fwd_epilogue_gfx92a.h"
template
<
typename
Kernel_traits
,
bool
Is_training
,
bool
Is_dropout
,
bool
Is_causal
,
bool
Is_local
,
bool
Is_even_MN
,
bool
Is_even_K
,
bool
Is_Varlen
,
bool
Return_softmax
,
bool
Has_alibi
,
int
Layout
,
typename
Params
>
inline
__device__
void
compute_attn_mha_1rowblock_gfx92a
(
const
Params
&
params
,
const
int
bidb
,
const
int
bidh
,
const
int
m_block
,
const
int
warp_id
)
{
using
Element
=
typename
Kernel_traits
::
Element
;
using
ElementAccum
=
typename
Kernel_traits
::
ElementAccum
;
using
index_t
=
std
::
conditional_t
<
Is_even_MN
,
int32_t
,
int64_t
>
;
constexpr
int
kBlockM
=
Kernel_traits
::
kBlockM
;
constexpr
int
kBlockN
=
Kernel_traits
::
kBlockN
;
constexpr
int
kBlockK
=
Kernel_traits
::
kBlockK
;
constexpr
int
kHeadDim
=
Kernel_traits
::
kHeadDim
;
constexpr
int
kHeadDimV
=
Kernel_traits
::
kHeadDimV
;
constexpr
int
kNWarps
=
Kernel_traits
::
kNWarps
;
constexpr
int
WARP_M
=
Kernel_traits
::
kWaveM
;
constexpr
int
WARP_N
=
Kernel_traits
::
kWaveN
;
constexpr
int
STAGES
=
Kernel_traits
::
STAGES
;
constexpr
int
WARP_NUM
=
kBlockM
/
WARP_M
;
constexpr
int
WARP_K
=
32
;
const
int
bidb
=
blockIdx
.
y
;
// 获取当前 TG 处理的任务大小
using
BlockInfoType
=
flash
::
BlockInfo
<
Is_Varlen
,
false
/*Is_Kvcache*/
,
false
/*USE_BSHD_LAYOUT*/
>
;
const
BlockInfoType
binfo
(
params
,
bidb
);
int
warp_id_vec
=
threadIdx
.
x
/
64
;
// warp id in a block
int
warp_id
=
__builtin_amdgcn_readfirstlane
(
warp_id_vec
);
// 处理边界
if
(
m_block
*
kBlockM
>=
binfo
.
actual_seqlen_q
||
binfo
.
actual_seqlen_k
==
0
||
bidh
>=
params
.
h
/*border judgement*/
)
return
;
int
warp_offset_in_seq_q
=
m_block
*
kBlockM
+
warp_id
*
WARP_M
;
int
m_block
=
gridDim
.
z
-
1
-
blockIdx
.
z
;
flash
::
compute_attn_prefix_prefill_1rowblock_gfx938
<
Kernel_traits
,
Is_training
,
Is_dropout
,
Is_causal
,
Is_local
,
Is_even_MN
,
Is_even_K
,
Return_softmax
,
Has_alibi
,
Layout
,
Flash_fwd_params
>
(
params
,
bidb
,
bidh
,
m_block
,
warp_id
);
// 分配 lds
extern
__shared__
Element
smem
[];
Element
*
q_lds
=
(
Element
*
)
&
(
smem
);
Element
*
k_lds
=
q_lds
;
Element
*
v_lds
=
k_lds
;
// 计算当前任务沿着 seqlenKV 方向的 block 起始位置和终止位置
const
int
n_block_min
=
!
Is_local
?
0
:
std
::
max
(
0
,
(
m_block
*
kBlockM
+
binfo
.
actual_seqlen_k
-
binfo
.
actual_seqlen_q
-
params
.
window_size_left
)
/
kBlockN
);
int
n_block_max
=
ceil_div
(
binfo
.
actual_seqlen_k
,
kBlockN
);
if
constexpr
(
Is_causal
||
Is_local
)
{
n_block_max
=
std
::
min
(
n_block_max
,
flash
::
ceil_div
((
m_block
+
1
)
*
kBlockM
+
binfo
.
actual_seqlen_k
-
binfo
.
actual_seqlen_q
+
params
.
window_size_right
,
kBlockN
));
}
// 计算数据跨度
int
seqlen_q_stride
,
seqlen_k_stride
,
seqlen_v_stride
,
seqlen_o_stride
;
index_t
row_offset_q
,
row_offset_k
,
row_offset_v
,
row_offset_o
;
int
row_offset_lse
;
int
headdim_split_id
=
0
;
fwd_prologue_compute_offset
<
Layout
,
kBlockM
,
kBlockN
,
kHeadDim
,
kHeadDimV
,
kHeadDimV
,
0
/*SplitD*/
,
Is_even_MN
,
false
/*Is_PaddingMask*/
,
Params
,
decltype
(
binfo
),
decltype
(
row_offset_q
)
>
(
seqlen_q_stride
,
seqlen_k_stride
,
seqlen_v_stride
,
seqlen_o_stride
,
row_offset_q
,
row_offset_k
,
row_offset_v
,
row_offset_o
,
row_offset_lse
,
headdim_split_id
,
bidb
,
bidh
,
bidh
,
m_block
,
n_block_min
,
binfo
,
params
);
#if 0
if (int(threadIdx.x) == 0) {
printf("bidb: %d | bidh: %d | actual_seqlen_q: %d | actual_seqlen_k: %d | n_block_max: %d | row_offset_q: %d | row_offset_k: %d | row_offset_v: %d | row_offset_o: %d | seqlen_q_stride: %d | seqlen_k_stride: %d | seqlen_v_stride: %d\n",
bidb, bidh, binfo.actual_seqlen_q, binfo.actual_seqlen_k, n_block_max, row_offset_q, row_offset_k, row_offset_v, row_offset_o, seqlen_q_stride, seqlen_k_stride, seqlen_v_stride);
}
#endif
// 根据起始数据偏移量准备 Q/K/V 的 buffer resource 寄存器
auto
q_ptr
=
prepare_for_matrix_load
<
kHeadDim
>
(
reinterpret_cast
<
Element
*>
(
params
.
q_ptr
)
+
row_offset_q
);
auto
k_ptr
=
prepare_for_matrix_load
<
kHeadDim
>
(
reinterpret_cast
<
Element
*>
(
params
.
k_ptr
)
+
row_offset_k
);
auto
v_ptr
=
prepare_for_matrix_load
<
kHeadDimV
>
(
reinterpret_cast
<
Element
*>
(
params
.
v_ptr
)
+
row_offset_v
);
// attention 插件: Alibi
float
g_alibi
;
if
constexpr
(
Has_alibi
)
{
g_alibi
=
reinterpret_cast
<
ElementAccum
*>
(
params
.
alibi_slopes_ptr
)[
bidb
*
params
.
alibi_slopes_batch_stride
+
bidh
]
/
params
.
scale_softmax
;
}
// attention 插件: Dropout
unsigned
long
long
rand_seed
,
rand_offset
;
uint32_t
p_dropout_in_8bits_value
;
union_vec2_uint
warp_idx_for_dropout
;
if
constexpr
(
Is_dropout
and
Is_training
)
{
rand_seed
=
params
.
rand_seed
;
rand_offset
=
params
.
rand_offset
+
((
bidb
*
params
.
h
+
bidh
)
<<
6
)
+
(
threadIdx
.
x
&
63
);
p_dropout_in_8bits_value
=
params
.
p_dropout_in_uint8_t
&
0xffffffff
;
warp_idx_for_dropout
.
u32
.
x
=
1
*
m_block
*
(
kBlockM
/
32
)
/* 前面几个 block 累积的 warp 数目, 这里不直接填 WARP_M, 参照 NV 的写法*/
+
warp_id
/*当前 block 内的 warp id*/
;
if
(
Is_training
and
m_block
==
0
and
bidb
==
0
and
bidh
==
0
and
threadIdx
.
x
==
0
)
{
params
.
rng_state
[
0
]
=
rand_seed
;
params
.
rng_state
[
1
]
=
rand_offset
;
}
}
// 预取 Q 的数据到寄存器
union_vec4_f16x2
<
Element
>
q_reg
[(
kHeadDim
/
kBlockK
)
*
(
WARP_M
*
kBlockK
)
/
(
32
*
32
)
*
2
];
prefetch_q_to_vgpr_mls_ds_gfx92a
<
kHeadDim
,
kBlockM
,
kBlockK
,
WARP_M
,
Element
,
Is_even_MN
>
(
q_ptr
,
q_lds
,
q_reg
,
warp_id
,
seqlen_q_stride
,
Is_even_MN
?
0
:
binfo
.
actual_seqlen_q
-
m_block
*
kBlockM
);
// apply causal mask 的步骤和 no causal mask 的步骤分开算
constexpr
int
n_masking_steps
=
(
!
Is_causal
&&
!
Is_local
)
?
1
:
flash
::
ceil_div
(
kBlockM
,
kBlockN
);
// 是否做 prefetch K, PV 结束后, prefetch K 有风险
constexpr
bool
PREFETCH_K
=
Is_even_MN
and
kHeadDim
==
128
and
kHeadDimV
==
128
;
constexpr
bool
ALLOW_PREFETCH
=
(
STAGES
>
1
);
// 客观上决定是否开启 prefetch
if
constexpr
(
PREFETCH_K
and
ALLOW_PREFETCH
)
{
if
(
n_block_min
<
n_block_max
-
n_masking_steps
)
{
prefetch_k_to_lds_mls_ds
<
kHeadDim
,
kBlockN
,
kBlockK
,
WARP_NUM
,
WARP_N
,
Element
,
Is_even_MN
>
(
k_ptr
,
k_lds
,
warp_id
,
seqlen_k_stride
,
Is_even_MN
?
0
:
binfo
.
actual_seqlen_k
-
n_block_min
*
kBlockN
);
}
}
/***************************************************************************************************************************/
vec2_Accum
<
ElementAccum
>
scores_max
[
WARP_M
/
32
];
vec2_Accum
<
ElementAccum
>
scores_sum
[
WARP_M
/
32
];
vec4_Accum
<
ElementAccum
>
acc_o
[(
kHeadDimV
/
kBlockK
)
*
(
WARP_M
/
32
)
*
(
kBlockK
/
32
)][
4
];
attention_initialize
<
kHeadDimV
/
kBlockK
,
WARP_M
/
32
,
kBlockK
/
32
,
2
/*M_MMAC_COUNT*/
,
ElementAccum
>
(
scores_max
,
scores_sum
,
acc_o
);
/***************************************************************************************************************************/
auto
QK_GEMM_FUNC
=
&
qk_gemm_prefetch_v_mls_ds_gfx92a
<
kHeadDim
,
kHeadDimV
,
kBlockM
,
kBlockN
,
kBlockK
,
WARP_M
,
WARP_N
,
STAGES
,
Element
,
ElementAccum
,
Is_even_MN
>
;
auto
PV_GEMM_FUNC
=
&
pv_gemm_prefetch_k_mls_ds_gfx92a
<
PREFETCH_K
,
kHeadDim
,
kHeadDimV
,
kBlockM
,
kBlockK
,
kBlockN
,
WARP_M
,
kBlockK
,
STAGES
,
Element
,
ElementAccum
,
Is_even_MN
>
;
auto
PV_GEMM_FUNC_IN_MASK
=
&
pv_gemm_prefetch_k_mls_ds_gfx92a
<
false
,
kHeadDim
,
kHeadDimV
,
kBlockM
,
kBlockK
,
kBlockN
,
WARP_M
,
kBlockK
,
STAGES
,
Element
,
ElementAccum
,
Is_even_MN
>
;
// Mainloop, 主循环, 不做 causal mask 的部分
for
(
int
n_block_loop
=
n_block_min
;
n_block_loop
<
n_block_max
-
n_masking_steps
;
++
n_block_loop
)
{
flash
::
raise_priority
();
// 计算当前 loop 下 seqlen_kv 的数据起始位置和边界
int
warp_offset_in_seqkv
=
n_block_loop
*
kBlockN
;
int
warp_seqkv_limit
=
Is_even_MN
?
0
:
binfo
.
actual_seqlen_k
-
warp_offset_in_seqkv
;
// 预取 K 的数据到 lds
if
constexpr
(
not
PREFETCH_K
)
{
prefetch_k_to_lds_mls_ds
<
kHeadDim
,
kBlockN
,
kBlockK
,
WARP_NUM
,
WARP_N
,
Element
,
Is_even_MN
>
(
k_ptr
,
k_lds
,
warp_id
,
seqlen_k_stride
,
warp_seqkv_limit
);
}
// 准备 QK gemm 输出的寄存器
vec4_Accum
<
ElementAccum
>
s_reg
[(
WARP_M
/
32
)
*
(
kBlockN
/
WARP_N
)][
4
];
// QK gemm
QK_GEMM_FUNC
(
k_ptr
,
v_ptr
,
k_lds
,
v_lds
,
q_reg
,
s_reg
,
warp_id
,
seqlen_k_stride
,
seqlen_v_stride
,
warp_seqkv_limit
);
// Attention 变体 alibi
if
constexpr
(
Has_alibi
)
{
apply_alibi_gfx938
<
vec4_Accum
<
ElementAccum
>
,
WARP_M
,
kBlockN
>
(
s_reg
,
warp_offset_in_seqkv
,
binfo
.
actual_seqlen_k
,
warp_offset_in_seq_q
,
binfo
.
actual_seqlen_q
,
g_alibi
);
}
// Attention 变体 local mask
if
constexpr
(
Is_local
)
{
apply_mask_local_gfx938
<
/*HasWSLeft=*/
Is_local
,
vec4_Accum
<
ElementAccum
>
,
WARP_M
,
kBlockN
>
(
s_reg
,
warp_offset_in_seqkv
,
binfo
.
actual_seqlen_k
,
warp_offset_in_seq_q
,
binfo
.
actual_seqlen_q
,
params
.
window_size_left
,
params
.
window_size_right
);
}
// 对 QK 输出做 softmax, 以及重放缩 acc_o/scores_sum
softmax_rescale_o
<
false
,
Is_causal
||
Is_local
,
vec4_Accum
<
ElementAccum
>
,
vec2_Accum
<
ElementAccum
>
,
kHeadDimV
,
kBlockK
,
WARP_M
,
kBlockN
>
(
s_reg
,
scores_max
,
scores_sum
,
acc_o
,
params
.
scale_softmax_log2
);
// Attention 变体 dropout
if
constexpr
(
Is_dropout
and
Is_training
)
{
warp_idx_for_dropout
.
u32
.
y
=
n_block_loop
*
(
kBlockN
/
WARP_N
);
apply_dropout
<
vec4_Accum
<
ElementAccum
>
,
WARP_M
,
kBlockN
,
kNWarps
,
Is_even_MN
>
(
s_reg
,
warp_seqkv_limit
,
0
,
rand_seed
,
rand_offset
,
p_dropout_in_8bits_value
,
warp_idx_for_dropout
,
params
.
dropout_debug_count
);
}
// softmax(QK) f32 -> f16
union_vec2_f16x2
<
Element
>
p_reg
[(
WARP_M
/
32
)
*
(
kBlockN
/
32
)][
4
];
convert_pk_type
<
WARP_M
,
kBlockN
,
Element
,
ElementAccum
>
(
p_reg
,
s_reg
);
// 偏移 K 指针, 提前偏移准备预取 K
*
(
uint64_t
*
)
&
k_ptr
+=
kBlockN
*
params
.
k_row_stride
*
sizeof
(
Element
);
// PV gemm
PV_GEMM_FUNC
(
v_ptr
,
k_ptr
,
v_lds
,
k_lds
,
p_reg
,
acc_o
,
warp_id
,
seqlen_k_stride
,
seqlen_v_stride
,
warp_seqkv_limit
);
// 偏移 V 指针
*
(
uint64_t
*
)
&
v_ptr
+=
kBlockN
*
params
.
v_row_stride
*
sizeof
(
Element
);
}
// prefetch K 的话, 最后一次多取了一段 K, 为了不影响后续的操作, 需要同步等待
if
constexpr
(
PREFETCH_K
)
{
buffer_load_lds_dwordx1_wait
<
0
>
();
}
/***************************************************************************************************************************/
// Rest loop, 做 causal mask 的部分
int
n_block_loop
=
max
(
n_block_max
-
n_masking_steps
,
n_block_min
);
#pragma unroll
for
(
int
masking_step
=
0
;
masking_step
<
n_masking_steps
;
++
masking_step
,
++
n_block_loop
)
{
// 计算当前 loop 下 seqlen_kv 的数据起始位置和边界
int
warp_offset_in_seqkv
=
n_block_loop
*
kBlockN
;
int
warp_seqkv_limit
=
Is_even_MN
?
0
:
binfo
.
actual_seqlen_k
-
warp_offset_in_seqkv
;
// 预取 K 的数据到 lds
prefetch_k_to_lds_mls_ds
<
kHeadDim
,
kBlockN
,
kBlockK
,
WARP_NUM
,
WARP_N
,
Element
,
Is_even_MN
>
(
k_ptr
,
k_lds
,
warp_id
,
seqlen_k_stride
,
warp_seqkv_limit
);
// 准备 QK gemm 输出的寄存器
vec4_Accum
<
ElementAccum
>
s_reg
[(
kBlockN
/
32
)
*
(
WARP_M
/
32
)][
4
];
// QK gemm
QK_GEMM_FUNC
(
k_ptr
,
v_ptr
,
k_lds
,
v_lds
,
q_reg
,
s_reg
,
warp_id
,
seqlen_k_stride
,
seqlen_v_stride
,
warp_seqkv_limit
);
// 偏移 K 指针, 提前偏移准备预取 K
*
(
uint64_t
*
)
&
k_ptr
+=
kBlockN
*
params
.
k_row_stride
*
sizeof
(
Element
);
// Attention 变体 alibi
if
constexpr
(
Has_alibi
)
{
apply_alibi_gfx938
<
vec4_Accum
<
ElementAccum
>
,
WARP_M
,
kBlockN
>
(
s_reg
,
warp_offset_in_seqkv
,
binfo
.
actual_seqlen_k
,
warp_offset_in_seq_q
,
binfo
.
actual_seqlen_q
,
g_alibi
);
}
// Attention: mask, causal mask, local mask
if
constexpr
(
!
Is_causal
&&
!
Is_local
)
{
if
constexpr
(
!
Is_even_MN
)
{
apply_mask_gfx92a
<
vec4_Accum
<
ElementAccum
>
,
WARP_M
,
kBlockN
>
(
s_reg
,
warp_seqkv_limit
);
}
}
else
{
if
constexpr
(
Is_causal
)
{
apply_mask_causal_gfx92a
<
vec4_Accum
<
ElementAccum
>
,
WARP_M
,
kBlockN
>
(
s_reg
,
warp_offset_in_seqkv
,
binfo
.
actual_seqlen_k
,
warp_offset_in_seq_q
,
binfo
.
actual_seqlen_q
);
}
else
if
constexpr
(
Is_local
)
{
apply_mask_local_gfx938
<
/*HasWSLeft=*/
Is_local
,
vec4_Accum
<
ElementAccum
>
,
WARP_M
,
kBlockN
>
(
s_reg
,
warp_offset_in_seqkv
,
binfo
.
actual_seqlen_k
,
warp_offset_in_seq_q
,
binfo
.
actual_seqlen_q
,
params
.
window_size_left
,
params
.
window_size_right
);
}
}
// 对 QK 输出做 softmax, 以及重放缩 acc_o/scores_sum
softmax_rescale_o
<
false
,
Is_causal
||
Is_local
,
vec4_Accum
<
ElementAccum
>
,
vec2_Accum
<
ElementAccum
>
,
kHeadDimV
,
kBlockK
,
WARP_M
,
kBlockN
>
(
s_reg
,
scores_max
,
scores_sum
,
acc_o
,
params
.
scale_softmax_log2
);
// Attention 变体 dropout
if
constexpr
(
Is_dropout
and
Is_training
)
{
warp_idx_for_dropout
.
u32
.
y
=
n_block_loop
*
(
kBlockN
/
WARP_N
);
apply_dropout
<
vec4_Accum
<
ElementAccum
>
,
WARP_M
,
kBlockN
,
kNWarps
,
Is_even_MN
>
(
s_reg
,
warp_seqkv_limit
,
0
,
rand_seed
,
rand_offset
,
p_dropout_in_8bits_value
,
warp_idx_for_dropout
,
params
.
dropout_debug_count
);
}
// softmax(QK) f32 -> f16
union_vec2_f16x2
<
Element
>
p_reg
[(
WARP_M
/
32
)
*
(
kBlockN
/
32
)][
4
];
convert_pk_type
<
WARP_M
,
kBlockN
,
Element
,
ElementAccum
>
(
p_reg
,
s_reg
);
// PV gemm
PV_GEMM_FUNC_IN_MASK
(
v_ptr
,
k_ptr
,
v_lds
,
k_lds
,
p_reg
,
acc_o
,
warp_id
,
seqlen_k_stride
,
seqlen_v_stride
,
warp_seqkv_limit
);
// 偏移 V 指针
*
(
uint64_t
*
)
&
v_ptr
+=
kBlockN
*
params
.
v_row_stride
*
sizeof
(
Element
);
}
/**********************************************************************************************************************************/
// Epilogue: rescale acco
vec2_Accum
<
ElementAccum
>
lse
[
WARP_M
/
32
];
fwd_epilugue_rescale_acco
<
WARP_M
,
kBlockK
,
kHeadDimV
,
Is_dropout
&&
Is_training
,
ElementAccum
>
(
acc_o
,
lse
,
scores_max
,
scores_sum
,
params
.
scale_softmax
,
params
.
rp_dropout
);
/**************************************************************************************************************************************/
constexpr
bool
Is_Interleave
=
true
;
constexpr
bool
Is_Output_Interleave
=
false
;
int
lane_id
=
threadIdx
.
x
&
63
;
if
(
params
.
softmax_lse_ptr
!=
nullptr
)
{
fwd_epilogue_store_lse
<
WARP_M
,
Is_even_MN
,
false
/*SplitD*/
,
Is_Interleave
,
ElementAccum
>
(
lse
,
params
.
softmax_lse_ptr
,
row_offset_lse
,
warp_id
,
lane_id
,
0
,
Is_even_MN
?
0
:
binfo
.
actual_seqlen_q
-
m_block
*
kBlockM
);
}
/**************************************************************************************************************************************/
Element
*
o_ptr
=
reinterpret_cast
<
Element
*>
(
params
.
o_ptr
)
+
row_offset_o
;
fwd_epilogue_store_output_mls_gfx92a
<
kHeadDimV
,
kBlockM
,
kBlockK
,
WARP_M
,
Is_even_MN
,
Is_Output_Interleave
,
false
/*TcpSwizzle*/
,
Element
,
ElementAccum
>
(
o_ptr
,
acc_o
,
m_block
,
warp_id
,
lane_id
,
seqlen_o_stride
,
binfo
.
actual_seqlen_q
);
}
template
<
typename
Kernel_traits
,
bool
Is_training
,
bool
Is_dropout
,
bool
Is_causal
,
bool
Is_local
,
bool
Is_even_MN
,
bool
Is_even_K
,
bool
Is_Varlen
,
bool
Return_softmax
,
bool
Has_alibi
,
int
Layout
,
typename
Params
>
inline
__device__
void
compute_attn_gfx92a
(
const
Params
&
params
)
{
#if defined(__gfx92a__)
constexpr
bool
Do_lpt
=
Is_causal
;
const
int
bidh
=
Do_lpt
?
blockIdx
.
x
:
blockIdx
.
y
;
const
int
bidb
=
Do_lpt
?
blockIdx
.
y
:
blockIdx
.
z
;
int
warp_id_vec
=
threadIdx
.
x
/
64
;
int
warp_id
=
__builtin_amdgcn_readfirstlane
(
warp_id_vec
);
int
m_block
=
Do_lpt
?
gridDim
.
z
-
1
-
blockIdx
.
z
:
blockIdx
.
x
;
flash
::
compute_attn_mha_1rowblock_gfx92a
<
Kernel_traits
,
Is_training
,
Is_dropout
,
Is_causal
,
Is_local
,
Is_even_MN
,
Is_even_K
,
Is_Varlen
,
Return_softmax
,
Has_alibi
,
Layout
,
Flash_fwd_params
>
(
params
,
bidb
,
bidh
,
m_block
,
warp_id
);
#endif
}
template
<
typename
Kernel_traits
,
bool
Is_training
,
bool
Is_dropout
,
bool
Is_causal
,
bool
Is_local
,
bool
Is_even_MN
,
bool
Is_even_K
,
bool
Return_softmax
,
bool
Has_alibi
,
int
Layout
,
typename
Params
>
inline
__device__
void
compute_attn_prefix_prefill_1rowblock_gfx92a
(
const
Params
&
params
,
const
int
bidb
,
const
int
__bidh
,
const
int
m_block
,
const
int
warp_id
)
{
using
Element
=
typename
Kernel_traits
::
Element
;
using
ElementAccum
=
typename
Kernel_traits
::
ElementAccum
;
using
index_t
=
typename
Kernel_traits
::
index_t
;
constexpr
int
WARP_K
=
32
;
constexpr
int
kBlockM
=
Kernel_traits
::
kBlockM
;
constexpr
int
kBlockN
=
Kernel_traits
::
kBlockN
;
constexpr
int
kBlockK
=
Kernel_traits
::
kBlockK
;
constexpr
int
kHeadDim
=
Kernel_traits
::
kHeadDim
;
constexpr
int
kHeadDimV
=
Kernel_traits
::
kHeadDimVSplit
;
constexpr
int
kNWarps
=
Kernel_traits
::
kNWarps
;
constexpr
int
WARP_M
=
Kernel_traits
::
kWaveM
;
constexpr
int
WARP_N
=
Kernel_traits
::
kWaveN
;
constexpr
int
STAGES
=
Kernel_traits
::
STAGES
;
constexpr
int
WARP_NUM
=
kBlockM
/
WARP_M
;
constexpr
int
SplitD
=
Kernel_traits
::
SplitD
;
constexpr
int
kHeadDimVOrigin
=
Kernel_traits
::
kHeadDimV
;
// 获取 splitD 结果
const
int
bidh
=
__bidh
/
SplitD
;
// 获取当前 TG 处理的任务大小
// const flash::BlockInfo</*Varlen=*/!Is_even_MN, false/*Is_kvcache*/> binfo(params, bidb);
flash
::
SafeDecodeBlockInfo
binfo
;
binfo
.
set_params
<
Params
,
/*Is_Q_varlen=*/
true
,
/*Is_K_Cumulative=*/
false
>
(
params
,
bidb
);
// 处理边界
if
(
m_block
*
kBlockM
>=
binfo
.
actual_seqlen_q
||
binfo
.
actual_seqlen_k
==
0
)
return
;
int
warp_offset_in_seq_q
=
m_block
*
kBlockM
+
warp_id
*
WARP_M
;
// 分配 lds
extern
__shared__
Element
smem
[];
Element
*
q_lds
=
(
Element
*
)
&
(
smem
);
Element
*
k_lds
=
q_lds
;
Element
*
v_lds
=
k_lds
;
// 计算当前任务沿着 seqlenKV 方向的 block 起始位置和终止位置
const
int
n_block_min
=
!
Is_local
?
0
:
std
::
max
(
0
,
(
m_block
*
kBlockM
+
binfo
.
actual_seqlen_k
-
binfo
.
actual_seqlen_q
-
params
.
window_size_left
)
/
kBlockN
);
int
n_block_max
=
ceil_div
(
binfo
.
actual_seqlen_k
,
kBlockN
);
if
constexpr
(
Is_causal
||
Is_local
)
{
n_block_max
=
std
::
min
(
n_block_max
,
flash
::
ceil_div
((
m_block
+
1
)
*
kBlockM
+
binfo
.
actual_seqlen_k
-
binfo
.
actual_seqlen_q
+
params
.
window_size_right
,
kBlockN
));
}
// 计算数据跨度
int
seqlen_q_stride
=
params
.
q_row_stride
;
int
seqlen_k_stride
=
params
.
k_row_stride
;
int
seqlen_v_stride
=
params
.
v_row_stride
;
int
seqlen_o_stride
=
params
.
o_row_stride
;
int64_t
row_offset_q
,
row_offset_k
,
row_offset_v
,
row_offset_o
;
int64_t
row_offset_lse
;
// 获取页表信息
const
int
page_block_size
=
params
.
page_block_size
;
int
*
block_table
=
params
.
block_table
+
bidb
*
params
.
block_table_batch_stride
;
const
int
kv_start
=
n_block_min
*
kBlockN
;
const
int
block_table_idx
=
kv_start
/
page_block_size
;
const
int
block_table_offset
=
kv_start
-
block_table_idx
*
page_block_size
;
if
constexpr
(
Layout
==
1
)
{
/*bshd layout, lse is num_heads, total_q*/
row_offset_q
=
(
binfo
.
sum_s_q
+
m_block
*
kBlockM
)
*
int64_t
(
seqlen_q_stride
)
+
params
.
q_head_stride
*
bidh
;
row_offset_k
=
int64_t
(
block_table
[
block_table_idx
])
*
int64_t
(
params
.
k_batch_stride
)
+
block_table_offset
*
params
.
k_row_stride
+
(
bidh
/
params
.
h_h_k_ratio
)
*
params
.
k_head_stride
;
row_offset_v
=
int64_t
(
block_table
[
block_table_idx
])
*
int64_t
(
params
.
v_batch_stride
)
+
block_table_offset
*
params
.
v_row_stride
+
(
bidh
/
params
.
h_h_k_ratio
)
*
params
.
v_head_stride
;
row_offset_lse
=
bidh
*
params
.
total_q
+
binfo
.
sum_s_q
+
m_block
*
kBlockM
;
row_offset_o
=
binfo
.
sum_s_q
*
int64_t
(
params
.
o_head_stride
)
*
params
.
h
+
params
.
o_head_stride
*
bidh
+
m_block
*
kBlockM
*
seqlen_o_stride
;
}
#if 0
if (int(threadIdx.x) == 0) {
printf("bidb: %d | bidh: %d | actual_seqlen_q: %d | actual_seqlen_k: %d | n_block_max: %d | row_offset_q: %ld | row_offset_k: %ld | row_offset_v: %ld | row_offset_o: %ld | seqlen_q_stride: %d | seqlen_k_stride: %d | seqlen_v_stride: %d\n",
bidb, bidh, binfo.actual_seqlen_q, binfo.actual_seqlen_k, n_block_max, row_offset_q, row_offset_k, row_offset_v, row_offset_o, seqlen_q_stride, seqlen_k_stride, seqlen_v_stride);
}
#endif
// 根据起始数据偏移量准备 Q/K/V 的资源寄存器
auto
q_ptr
=
prepare_for_matrix_load
<
kHeadDim
>
(
reinterpret_cast
<
Element
*>
(
params
.
q_ptr
)
+
row_offset_q
);
auto
k_ptr
=
prepare_for_matrix_load
<
kHeadDim
>
(
reinterpret_cast
<
Element
*>
(
params
.
k_ptr
)
+
row_offset_k
);
auto
v_ptr
=
prepare_for_matrix_load
<
kHeadDimV
>
(
reinterpret_cast
<
Element
*>
(
params
.
v_ptr
)
+
row_offset_v
);
// attention 变体: Alibi
float
g_alibi
;
if
constexpr
(
Has_alibi
)
{
g_alibi
=
reinterpret_cast
<
ElementAccum
*>
(
params
.
alibi_slopes_ptr
)[
bidb
*
params
.
alibi_slopes_batch_stride
+
bidh
]
/
params
.
scale_softmax
;
}
// attention 插件: Dropout
unsigned
long
long
rand_seed
,
rand_offset
;
uint32_t
p_dropout_in_8bits_value
;
union_vec2_uint
warp_idx_for_dropout
;
if
constexpr
(
Is_dropout
and
Is_training
)
{
rand_seed
=
params
.
rand_seed
;
rand_offset
=
params
.
rand_offset
+
((
bidb
*
params
.
h
+
bidh
)
<<
6
)
+
(
threadIdx
.
x
&
63
);
p_dropout_in_8bits_value
=
params
.
p_dropout_in_uint8_t
&
0xffffffff
;
warp_idx_for_dropout
.
u32
.
x
=
1
*
m_block
*
(
kBlockM
/
32
)
/* 前面几个 block 累积的 warp 数目, 这里不直接填 WARP_M, 参照 NV 的写法*/
+
warp_id
/*当前 block 内的 warp id*/
;
if
(
Is_training
and
m_block
==
0
and
bidb
==
0
and
bidh
==
0
and
threadIdx
.
x
==
0
)
{
params
.
rng_state
[
0
]
=
rand_seed
;
params
.
rng_state
[
1
]
=
rand_offset
;
}
}
// 预取 Q 的数据到寄存器
union_vec4_f16x2
<
Element
>
q_reg
[(
kHeadDim
/
kBlockK
)
*
(
WARP_M
*
kBlockK
)
/
(
32
*
32
)
*
2
];
prefetch_q_to_vgpr_mls_ds_gfx92a
<
kHeadDim
,
kBlockM
,
kBlockK
,
WARP_M
,
Element
,
Is_even_MN
>
(
q_ptr
,
q_lds
,
q_reg
,
warp_id
,
seqlen_q_stride
,
Is_even_MN
?
0
:
binfo
.
actual_seqlen_q
-
m_block
*
kBlockM
);
// apply causal mask 的步骤和 no causal mask 的步骤分开算
// prefix prefill 目前没分开算, 明确边界的情况下也可以分开算, 性能会有提升
int
n_masking_steps
=
(
!
Is_causal
&&
!
Is_local
)
?
1
:
min
(
n_block_max
,
flash
::
ceil_div
(
kBlockM
,
kBlockN
)
+
1
);
// 是否做 prefetch K, PV 结束后, prefetch K 有风险
constexpr
bool
PREFETCH_K
=
Is_even_MN
and
kHeadDim
==
128
and
kHeadDimV
==
128
;
constexpr
bool
ALLOW_PREFETCH
=
(
STAGES
>
1
);
// 客观上决定是否开启 prefetch
if
constexpr
(
PREFETCH_K
and
ALLOW_PREFETCH
)
{
if
(
n_block_min
<
n_block_max
-
n_masking_steps
)
{
prefetch_k_to_lds_mls_ds
<
kHeadDim
,
kBlockN
,
kBlockK
,
WARP_NUM
,
WARP_N
,
Element
,
Is_even_MN
>
(
k_ptr
,
k_lds
,
warp_id
,
seqlen_k_stride
,
Is_even_MN
?
0
:
binfo
.
actual_seqlen_k
-
n_block_min
*
kBlockN
);
}
}
/***************************************************************************************************************************/
vec2_Accum
<
ElementAccum
>
scores_max
[
WARP_M
/
32
];
vec2_Accum
<
ElementAccum
>
scores_sum
[
WARP_M
/
32
];
vec4_Accum
<
ElementAccum
>
acc_o
[(
kHeadDimV
/
kBlockK
)
*
(
WARP_M
/
32
)
*
(
kBlockK
/
32
)][
4
];
attention_initialize
<
kHeadDimV
/
kBlockK
,
WARP_M
/
32
,
kBlockK
/
32
,
2
/*M_MMAC_COUNT*/
,
ElementAccum
>
(
scores_max
,
scores_sum
,
acc_o
);
/***************************************************************************************************************************/
auto
QK_GEMM_FUNC
=
&
qk_gemm_prefetch_v_mls_ds_gfx92a
<
kHeadDim
,
kHeadDimV
,
kBlockM
,
kBlockN
,
kBlockK
,
WARP_M
,
WARP_N
,
STAGES
,
Element
,
ElementAccum
,
Is_even_MN
>
;
auto
PV_GEMM_FUNC
=
&
pv_gemm_prefetch_k_mls_ds_gfx92a
<
PREFETCH_K
,
kHeadDim
,
kHeadDimV
,
kBlockM
,
kBlockK
,
kBlockN
,
WARP_M
,
kBlockK
,
STAGES
,
Element
,
ElementAccum
,
Is_even_MN
>
;
auto
PV_GEMM_FUNC_IN_MASK
=
&
pv_gemm_prefetch_k_mls_ds_gfx92a
<
false
,
kHeadDim
,
kHeadDimV
,
kBlockM
,
kBlockK
,
kBlockN
,
WARP_M
,
kBlockK
,
STAGES
,
Element
,
ElementAccum
,
Is_even_MN
>
;
// Mainloop, 主循环, 不做 causal mask 的部分
for
(
int
n_block_loop
=
n_block_min
;
n_block_loop
<
n_block_max
-
n_masking_steps
;
++
n_block_loop
)
{
// 计算当前 loop 下 seqlen_kv 的数据起始位置和边界
int
warp_offset_in_seqkv
=
n_block_loop
*
kBlockN
;
int
warp_seqkv_limit
=
Is_even_MN
?
0
:
binfo
.
actual_seqlen_k
-
warp_offset_in_seqkv
;
// 预取 K 的数据到 lds
if
constexpr
(
not
PREFETCH_K
)
{
prefetch_k_to_lds_mls_ds
<
kHeadDim
,
kBlockN
,
kBlockK
,
WARP_NUM
,
WARP_N
,
Element
,
Is_even_MN
>
(
k_ptr
,
k_lds
,
warp_id
,
seqlen_k_stride
,
warp_seqkv_limit
);
}
// 准备 QK gemm 输出的寄存器
vec4_Accum
<
ElementAccum
>
s_reg
[(
WARP_M
/
32
)
*
(
kBlockN
/
WARP_N
)][
4
];
// QK gemm
QK_GEMM_FUNC
(
k_ptr
,
v_ptr
,
k_lds
,
v_lds
,
q_reg
,
s_reg
,
warp_id
,
seqlen_k_stride
,
seqlen_v_stride
,
warp_seqkv_limit
);
// Attention 变体 alibi
if
constexpr
(
Has_alibi
)
{
apply_alibi_gfx92a
<
vec4_Accum
<
ElementAccum
>
,
WARP_M
,
kBlockN
>
(
s_reg
,
warp_offset_in_seqkv
,
binfo
.
actual_seqlen_k
,
warp_offset_in_seq_q
,
binfo
.
actual_seqlen_q
,
g_alibi
);
}
// Attention: mask, causal mask, local mask
if
constexpr
(
Is_local
)
{
apply_mask_local_gfx92a
<
/*HasWSLeft=*/
Is_local
,
vec4_Accum
<
ElementAccum
>
,
WARP_M
,
kBlockN
>
(
s_reg
,
warp_offset_in_seqkv
,
binfo
.
actual_seqlen_k
,
warp_offset_in_seq_q
,
binfo
.
actual_seqlen_q
,
params
.
window_size_left
,
params
.
window_size_right
);
}
// 对 QK 输出做 softmax, 以及重放缩 acc_o/scores_sum
softmax_rescale_o
<
false
,
Is_causal
||
Is_local
,
vec4_Accum
<
ElementAccum
>
,
vec2_Accum
<
ElementAccum
>
,
kHeadDimV
,
kBlockK
,
WARP_M
,
kBlockN
>
(
s_reg
,
scores_max
,
scores_sum
,
acc_o
,
params
.
scale_softmax_log2
);
// Attention 变体 dropout
if
constexpr
(
Is_dropout
and
Is_training
)
{
warp_idx_for_dropout
.
u32
.
y
=
n_block_loop
*
(
kBlockN
/
WARP_N
);
apply_dropout
<
vec4_Accum
<
ElementAccum
>
,
WARP_M
,
kBlockN
,
kNWarps
,
Is_even_MN
>
(
s_reg
,
warp_seqkv_limit
,
0
,
rand_seed
,
rand_offset
,
p_dropout_in_8bits_value
,
warp_idx_for_dropout
,
params
.
dropout_debug_count
);
}
// softmax(QK) f32 -> f16
union_vec2_f16x2
<
Element
>
p_reg
[(
WARP_M
/
32
)
*
(
kBlockN
/
32
)][
4
];
convert_pk_type
<
WARP_M
,
kBlockN
,
Element
,
ElementAccum
>
(
p_reg
,
s_reg
);
const
int
block_table_idx_cur
=
n_block_loop
*
kBlockN
/
params
.
page_block_size
;
const
int
block_table_offset_cur
=
n_block_loop
*
kBlockN
-
block_table_idx_cur
*
params
.
page_block_size
;
const
int
block_table_idx_next
=
min
(
n_block_max
-
1
,
n_block_loop
+
1
)
*
kBlockN
/
params
.
page_block_size
;
const
int
block_table_offset_next
=
min
(
n_block_max
-
1
,
n_block_loop
+
1
)
*
kBlockN
-
block_table_idx_next
*
params
.
page_block_size
;
const
int
table_diff
=
block_table
[
block_table_idx_next
]
-
block_table
[
block_table_idx_cur
];
const
int
offset_diff
=
block_table_offset_next
-
block_table_offset_cur
;
if
constexpr
(
PREFETCH_K
)
{
*
(
int64_t
*
)
&
k_ptr
+=
(
int64_t
(
table_diff
)
*
int64_t
(
params
.
k_batch_stride
)
+
offset_diff
*
int64_t
(
params
.
k_row_stride
))
*
sizeof
(
Element
);
}
// PV gemm
PV_GEMM_FUNC
(
v_ptr
,
k_ptr
,
v_lds
,
k_lds
,
p_reg
,
acc_o
,
warp_id
,
seqlen_k_stride
,
seqlen_v_stride
,
warp_seqkv_limit
);
if
constexpr
(
not
PREFETCH_K
)
{
*
(
int64_t
*
)
&
k_ptr
+=
(
int64_t
(
table_diff
)
*
int64_t
(
params
.
k_batch_stride
)
+
offset_diff
*
int64_t
(
params
.
k_row_stride
))
*
sizeof
(
Element
);
}
*
(
int64_t
*
)
&
v_ptr
+=
(
int64_t
(
table_diff
)
*
int64_t
(
params
.
v_batch_stride
)
+
offset_diff
*
int64_t
(
params
.
v_row_stride
))
*
sizeof
(
Element
);
}
// prefetch K 的话, 最后一次多取了一段 K, 为了不影响后续的操作, 需要同步等待
if
constexpr
(
PREFETCH_K
)
{
buffer_load_lds_dwordx1_wait
<
0
>
();
}
/***************************************************************************************************************************/
// Rest loop, 做 causal mask 的部分
int
n_block_loop
=
max
(
n_block_max
-
n_masking_steps
,
n_block_min
);
// #pragma unroll
for
(
int
masking_step
=
0
;
masking_step
<
n_masking_steps
;
++
masking_step
,
++
n_block_loop
)
{
// 计算当前 loop 下 seqlen_kv 的数据起始位置和边界
int
warp_offset_in_seqkv
=
n_block_loop
*
kBlockN
;
int
warp_seqkv_limit
=
Is_even_MN
?
0
:
binfo
.
actual_seqlen_k
-
warp_offset_in_seqkv
;
// 预取 K 的数据到 lds
if
constexpr
(
true
)
{
prefetch_k_to_lds_mls_ds
<
kHeadDim
,
kBlockN
,
kBlockK
,
WARP_NUM
,
WARP_N
,
Element
,
Is_even_MN
>
(
k_ptr
,
k_lds
,
warp_id
,
seqlen_k_stride
,
warp_seqkv_limit
);
}
// 准备 QK gemm 输出的寄存器
vec4_Accum
<
ElementAccum
>
s_reg
[(
WARP_M
/
32
)
*
(
kBlockN
/
WARP_N
)][
4
];
// QK gemm
QK_GEMM_FUNC
(
k_ptr
,
v_ptr
,
k_lds
,
v_lds
,
q_reg
,
s_reg
,
warp_id
,
seqlen_k_stride
,
seqlen_v_stride
,
warp_seqkv_limit
);
// Attention 变体 alibi
if
constexpr
(
Has_alibi
)
{
apply_alibi_gfx92a
<
vec4_Accum
<
ElementAccum
>
,
WARP_M
,
kBlockN
>
(
s_reg
,
warp_offset_in_seqkv
,
binfo
.
actual_seqlen_k
,
warp_offset_in_seq_q
,
binfo
.
actual_seqlen_q
,
g_alibi
);
}
// Attention: mask, causal mask, local mask
if
constexpr
(
!
Is_causal
&&
!
Is_local
)
{
if
constexpr
(
!
Is_even_MN
)
{
apply_mask_gfx92a
<
vec4_Accum
<
ElementAccum
>
,
WARP_M
,
kBlockN
>
(
s_reg
,
warp_seqkv_limit
);
}
}
else
if
constexpr
(
Is_causal
)
{
apply_mask_causal_gfx92a
<
vec4_Accum
<
ElementAccum
>
,
WARP_M
,
kBlockN
>
(
s_reg
,
warp_offset_in_seqkv
,
binfo
.
actual_seqlen_k
,
warp_offset_in_seq_q
,
binfo
.
actual_seqlen_q
);
}
else
if
constexpr
(
Is_local
)
{
apply_mask_local_gfx92a
<
/*HasWSLeft=*/
Is_local
,
vec4_Accum
<
ElementAccum
>
,
WARP_M
,
kBlockN
>
(
s_reg
,
warp_offset_in_seqkv
,
binfo
.
actual_seqlen_k
,
warp_offset_in_seq_q
,
binfo
.
actual_seqlen_q
,
params
.
window_size_left
,
params
.
window_size_right
);
}
// 对 QK 输出做 softmax, 以及重放缩 acc_o/scores_sum
softmax_rescale_o
<
false
,
Is_causal
||
Is_local
,
vec4_Accum
<
ElementAccum
>
,
vec2_Accum
<
ElementAccum
>
,
kHeadDimV
,
kBlockK
,
WARP_M
,
kBlockN
>
(
s_reg
,
scores_max
,
scores_sum
,
acc_o
,
params
.
scale_softmax_log2
);
// Attention 变体 dropout
if
constexpr
(
Is_dropout
and
Is_training
)
{
warp_idx_for_dropout
.
u32
.
y
=
n_block_loop
*
(
kBlockN
/
WARP_N
);
apply_dropout
<
vec4_Accum
<
ElementAccum
>
,
WARP_M
,
kBlockN
,
kNWarps
,
Is_even_MN
>
(
s_reg
,
warp_seqkv_limit
,
0
,
rand_seed
,
rand_offset
,
p_dropout_in_8bits_value
,
warp_idx_for_dropout
,
params
.
dropout_debug_count
);
}
// softmax(QK) f32 -> f16
union_vec2_f16x2
<
Element
>
p_reg
[(
WARP_M
/
32
)
*
(
kBlockN
/
32
)][
4
];
convert_pk_type
<
WARP_M
,
kBlockN
,
Element
,
ElementAccum
>
(
p_reg
,
s_reg
);
const
int
block_table_idx_cur
=
n_block_loop
*
kBlockN
/
params
.
page_block_size
;
const
int
block_table_offset_cur
=
n_block_loop
*
kBlockN
-
block_table_idx_cur
*
params
.
page_block_size
;
const
int
block_table_idx_next
=
min
(
n_block_max
-
1
,
n_block_loop
+
1
)
*
kBlockN
/
params
.
page_block_size
;
const
int
block_table_offset_next
=
min
(
n_block_max
-
1
,
n_block_loop
+
1
)
*
kBlockN
-
block_table_idx_next
*
params
.
page_block_size
;
const
int
table_diff
=
block_table
[
block_table_idx_next
]
-
block_table
[
block_table_idx_cur
];
const
int
offset_diff
=
block_table_offset_next
-
block_table_offset_cur
;
// PV gemm
PV_GEMM_FUNC_IN_MASK
(
v_ptr
,
k_ptr
,
v_lds
,
k_lds
,
p_reg
,
acc_o
,
warp_id
,
seqlen_k_stride
,
seqlen_v_stride
,
warp_seqkv_limit
);
// 偏移 V 指针
*
(
int64_t
*
)
&
k_ptr
+=
(
int64_t
(
table_diff
)
*
int64_t
(
params
.
k_batch_stride
)
+
offset_diff
*
int64_t
(
params
.
k_row_stride
))
*
sizeof
(
Element
);
*
(
int64_t
*
)
&
v_ptr
+=
(
int64_t
(
table_diff
)
*
int64_t
(
params
.
v_batch_stride
)
+
offset_diff
*
int64_t
(
params
.
v_row_stride
))
*
sizeof
(
Element
);
}
/**********************************************************************************************************************************/
// Epilogue: rescale acco
vec2_Accum
<
ElementAccum
>
lse
[
WARP_M
/
32
];
fwd_epilugue_rescale_acco
<
WARP_M
,
kBlockK
,
kHeadDimV
,
Is_dropout
&&
Is_training
,
ElementAccum
>
(
acc_o
,
lse
,
scores_max
,
scores_sum
,
params
.
scale_softmax
,
params
.
rp_dropout
);
/**************************************************************************************************************************************/
constexpr
bool
Is_Interleave
=
true
;
int
lane_id
=
threadIdx
.
x
&
63
;
if
(
params
.
softmax_lse_ptr
!=
nullptr
)
{
fwd_epilogue_store_lse
<
WARP_M
,
Is_even_MN
,
false
/*SplitD*/
,
Is_Interleave
,
ElementAccum
>
(
lse
,
params
.
softmax_lse_ptr
,
row_offset_lse
,
warp_id
,
lane_id
,
0
,
binfo
.
actual_seqlen_q
-
m_block
*
kBlockM
);
}
/**************************************************************************************************************************************/
Element
*
o_ptr
=
reinterpret_cast
<
Element
*>
(
params
.
o_ptr
)
+
row_offset_o
;
fwd_epilogue_store_output_mls_gfx92a
<
kHeadDimV
,
kBlockM
,
kBlockK
,
WARP_M
,
Is_even_MN
,
false
/*Is_Interleave*/
,
false
/*TcpSwizzle*/
,
Element
,
ElementAccum
>
(
o_ptr
,
acc_o
,
m_block
,
warp_id
,
lane_id
,
seqlen_o_stride
,
binfo
.
actual_seqlen_q
);
}
template
<
typename
Kernel_traits
,
bool
Is_training
,
bool
Is_dropout
,
bool
Is_causal
,
bool
Is_local
,
bool
Is_even_MN
,
bool
Is_even_K
,
bool
Return_softmax
,
bool
Has_alibi
,
int
Layout
,
typename
Params
>
__global__
void
__launch_bounds__
(
256
,
1
)
flash_fwd_prefix_prefill_gfx938_kernel
(
const
Params
params
)
{
#if defined(__gfx938__) || defined(__gfx946__)
const
int
bidh
=
blockIdx
.
x
;
const
int
bidb
=
blockIdx
.
y
;
int
warp_id_vec
=
threadIdx
.
x
/
64
;
// warp id in a block
int
warp_id
=
__builtin_amdgcn_readfirstlane
(
warp_id_vec
);
int
m_block
=
gridDim
.
z
-
1
-
blockIdx
.
z
;
flash
::
compute_attn_prefix_prefill_1rowblock_gfx938
<
Kernel_traits
,
Is_training
,
Is_dropout
,
Is_causal
,
Is_local
,
Is_even_MN
,
Is_even_K
,
Return_softmax
,
Has_alibi
,
Layout
,
Flash_fwd_params
>
(
params
,
bidb
,
bidh
,
m_block
,
warp_id
);
#endif
}
template
<
typename
Kernel_traits
,
bool
Is_training
,
bool
Is_dropout
,
bool
Is_causal
,
bool
Is_local
,
bool
Is_even_MN
,
bool
Is_even_K
,
bool
Return_softmax
,
bool
Has_alibi
,
int
Layout
,
typename
Params
>
__global__
void
__launch_bounds__
(
256
,
1
)
flash_fwd_prefix_prefill_gfx92a_kernel
(
const
Params
params
)
{
#if defined(__gfx92a__)
const
int
bidh
=
blockIdx
.
x
;
const
int
bidb
=
blockIdx
.
y
;
int
warp_id_vec
=
threadIdx
.
x
/
64
;
// warp id in a block
int
warp_id
=
__builtin_amdgcn_readfirstlane
(
warp_id_vec
);
int
m_block
=
gridDim
.
z
-
1
-
blockIdx
.
z
;
flash
::
compute_attn_prefix_prefill_1rowblock_gfx92a
<
Kernel_traits
,
Is_training
,
Is_dropout
,
Is_causal
,
Is_local
,
Is_even_MN
,
Is_even_K
,
Return_softmax
,
Has_alibi
,
Layout
,
Flash_fwd_params
>
(
params
,
bidb
,
bidh
,
m_block
,
warp_id
);
#endif
}
...
...
csrc/flash_attn_hg/src/flash_fwd_b16_mla.h
View file @
518a5f4d
...
...
@@ -358,7 +358,7 @@ inline __device__ void compute_attn_1rowblock_splitkv_mla_tile16x32(const Params
/**********************************************************************************************************************************/
// 主循环, 沿着 seqlenKV 维度, 每次 4 个 wave 共同计算一个 kBLOCKN
const
int
n_block_min
=
0
;
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
const
int
n_block_max
=
(
Split
and
!
MLA_FIX_NUM_SPLITS
)
?
ceil_div
(
Partition_Size
,
kBlockN
)
:
ceil_div
(
binfo
.
actual_seqlen_k
,
kBlockN
);
#else
const
int
n_block_max
=
ceil_div
(
binfo
.
actual_seqlen_k
,
kBlockN
);
// temp workaround, unroll partition size for zd may lead to wrong results
...
...
@@ -451,15 +451,15 @@ inline __device__ void compute_attn_splitkv_mla(const Params ¶ms) {
////////////////////////////////////////////////////////////////////////////////////////////////////
#include "kvcache/gfx938/kvcache_qk_gemm_prefetch_v_gfx938.h"
#include "kvcache/gfx938/kvcache_pv_gemm_prefetch_k_gfx938.h"
#include "kvcache/gfx938/kvcache_softmax_gfx938.h"
#include "kvcache/gfx938/kvcache_epilogue_gfx938.h"
#include "kvcache/kvcache_acco_reduce_tile16x32.h"
#include "kvcache/kvcache_epilogue.h"
#include "mla/gfx938/fp8_mla_acco_reduce_gfx938.h"
#include "mla/gfx938/mla_tp8_qk_gemm_utils_gfx938.h"
#include "mla/gfx938/mla_tp8_epilogue_gfx938.h"
#include "mla/gfx938/f16_mla_tp8_qk_gemm_utils_gfx938.h"
#include "mla/gfx938/f16_mla_tp8_qk_gemm_gfx938.h"
#include "mla/gfx938/f16_mla_tp8_pv_gemm_gfx938.h"
// For FlashMLA, codes almostly copy codes from paged_attention with a few differences.
// Kernel codes listed below can be customized alone if neccessary.
// sgpr: 75, vgpr: 240 | base sgpr: 80, vgpr 254
...
...
@@ -546,7 +546,7 @@ inline __device__ void compute_attn_1rowblock_splitkv_mla_gfx938(const Params &p
bool
is_thread0
=
threadIdx
.
x
==
0
;
if
(
is_thread0
)
{
inline_utcl2_warmup_dword
(
k_addr
);
//
inline_utcl2_warmup_dword(k_addr);
}
// splitkv, debug 场景下需要写出一些值, 例如 scores_max/scores_sum
...
...
@@ -584,11 +584,11 @@ inline __device__ void compute_attn_1rowblock_splitkv_mla_gfx938(const Params &p
int
warp_offset_in_seqkv
=
n_block_loop
*
kBlockN
+
warp_id
*
WARP_N
;
int
warp_seqkv_limit
=
binfo
.
actual_seqlen_k
-
n_block_loop
*
kBlockN
;
kvcache
_prefetch_k_to_lds_gfx938
<
kBlockK
,
WARP_N
,
Element
,
STAGES
,
WARP_NUM
>
(
k_addr
,
k_lds
,
warp_id
,
kcache_seqlen_stride
,
warp_seqkv_limit
);
f16_mla_tp8
_prefetch_k_to_lds_gfx938
<
kBlockK
,
WARP_N
,
Element
,
STAGES
,
WARP_NUM
>
(
k_addr
,
k_lds
,
warp_id
,
kcache_seqlen_stride
,
warp_seqkv_limit
);
vec4_Accum
<
ElementAccum
>
s_reg
[
M_WARP_COUNT
*
N_WARP_COUNT
][
4
];
kvcache_qk_gemm_prefetch_v
_gfx938
<
kHeadDim
,
kHeadDimVSplit
,
kBlockM
,
WARP_N
,
kBlockK
,
WARP_M
,
WARP_N
,
WARP_NUM
,
STAGES
,
M_MMAC_COUNT
,
Element
,
ElementAccum
>
(
f16_mla_tp8_qk_gemm
_gfx938
<
kHeadDim
,
kHeadDimVSplit
,
kBlockM
,
WARP_N
,
kBlockK
,
WARP_M
,
WARP_N
,
WARP_NUM
,
STAGES
,
M_MMAC_COUNT
,
Element
,
ElementAccum
>
(
q_addr
,
k_addr
,
v_addr
,
q_lds
,
k_lds
,
v_lds
,
q_reg
,
s_reg
,
warp_id
,
kcache_seqlen_stride
,
vcache_seqlen_stride
,
warp_seqkv_limit
);
if
constexpr
(
Is_causal
)
{
...
...
@@ -603,7 +603,7 @@ inline __device__ void compute_attn_1rowblock_splitkv_mla_gfx938(const Params &p
union_vec2_f16x2
<
Element
>
p_reg
[
M_WARP_COUNT
*
N_WARP_COUNT
][
4
];
mla_convert_pk_type
<
M_WARP_COUNT
,
N_WARP_COUNT
,
M_MMAC_COUNT
,
Element
,
ElementAccum
>
(
p_reg
,
s_reg
);
kvcache_pv_gemm_prefetch_k
_gfx938
<
K_LOOP_COUNT
,
kBlockM
,
kBlockK
,
kBlockN
,
M_WARP_COUNT
,
K_WARP_COUNT
/*kBlockK*/
,
N_WARP_COUNT
/*WARP_N*/
,
STAGES
,
WARP_NUM
,
M_MMAC_COUNT
,
Element
,
ElementAccum
>
(
f16_mla_tp8_pv_gemm
_gfx938
<
K_LOOP_COUNT
,
kBlockM
,
kBlockK
,
kBlockN
,
M_WARP_COUNT
,
K_WARP_COUNT
/*kBlockK*/
,
N_WARP_COUNT
/*WARP_N*/
,
STAGES
,
WARP_NUM
,
M_MMAC_COUNT
,
Element
,
ElementAccum
>
(
v_addr
,
k_addr
,
v_lds
,
k_lds
,
p_reg
,
acc_o
,
warp_id
,
vcache_seqlen_stride
,
warp_seqkv_limit
);
const
int
block_table_idx_cur
=
n_block_loop
*
kBlockN
/
params
.
page_block_size
;
...
...
@@ -646,7 +646,7 @@ inline __device__ void compute_attn_1rowblock_splitkv_mla_gfx938(const Params &p
template
<
typename
Kernel_traits
,
bool
Is_causal
,
bool
Split
,
int
M_MMAC_COUNT
,
int
REUSE_KV_TIMES
,
int
HEADDIM_V_SPLIT
,
int
Partition_Size
,
typename
Params
>
inline
__device__
void
compute_attn_splitkv_mla_gfx938
(
const
Params
&
params
)
{
#if defined(__gfx938__)
#if defined(__gfx938__)
|| defined(__gfx946__)
// The block index for the head.
const
int
bidh
=
Split
?
blockIdx
.
z
%
params
.
h
:
blockIdx
.
y
;
// batch x num_head, num_head first
...
...
@@ -786,7 +786,7 @@ inline __device__ void flash_fwd_mla_prefix_prefill_kernel_base(const Params par
template
<
int
kHeadDim
,
int
kHeadDimV
,
int
kBlockM
,
int
kBlockN
,
bool
Is_prefix
,
bool
Is_causal
,
typename
Element
,
typename
ElementAccum
,
typename
Params
>
__global__
void
__launch_bounds__
(
512
,
1
)
flash_fwd_mla_prefix_prefill_fix_kernel
(
const
Params
params
)
{
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
int
q_blocks
=
params
.
q_blocks
;
for
(
int
loop
=
blockIdx
.
x
;
loop
<
params
.
total_blocks
;
loop
+=
params
.
cu_count
)
{
int
m_block
=
loop
%
q_blocks
;
...
...
@@ -818,7 +818,7 @@ __global__ void __launch_bounds__(512, 1) flash_fwd_mla_prefix_prefill_fix_kerne
template
<
int
kHeadDim
,
int
kHeadDimV
,
int
kBlockM
,
int
kBlockN
,
bool
Is_prefix
,
bool
Is_causal
,
typename
Element
,
typename
ElementAccum
,
typename
Params
>
__global__
void
__launch_bounds__
(
512
,
1
)
flash_fwd_mla_fix_kernel
(
const
Params
params
)
{
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
int
q_blocks
=
params
.
q_blocks
;
for
(
int
loop
=
blockIdx
.
x
;
loop
<
params
.
total_blocks
;
loop
+=
params
.
cu_count
)
{
int
m_block
=
loop
%
q_blocks
;
...
...
@@ -976,7 +976,7 @@ inline __device__ void flash_fwd_mla_fast_prefix_prefill_kernel_base(const Param
template
<
int
kHeadDim
,
int
kHeadDimV
,
int
kBlockM
,
int
kBlockN
,
bool
Is_prefix
,
bool
Is_causal
,
typename
Element
,
typename
ElementAccum
,
typename
Params
>
__global__
void
__launch_bounds__
(
512
,
1
)
flash_fwd_mla_prefix_prefill_kernel
(
const
Params
params
)
{
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
const
int
q_blocks
=
params
.
q_blocks
;
for
(
int
m_block
=
0
;
m_block
<
q_blocks
;
++
m_block
)
{
// 获取当前任务
...
...
@@ -996,7 +996,7 @@ __global__ void __launch_bounds__(512, 1) flash_fwd_mla_prefix_prefill_kernel(co
template
<
int
kHeadDim
,
int
kHeadDimV
,
int
kBlockM
,
int
kBlockN
,
bool
Is_prefix
,
bool
Is_causal
,
typename
Element
,
typename
ElementAccum
,
typename
Params
>
__global__
void
__launch_bounds__
(
512
,
1
)
flash_fwd_mla_kernel
(
const
Params
params
)
{
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
// 获取当前任务
const
int
q_blocks
=
params
.
q_blocks
;
for
(
int
m_block
=
0
;
m_block
<
q_blocks
;
++
m_block
)
{
...
...
@@ -1132,22 +1132,16 @@ __global__ void __launch_bounds__(512, 1) flash_fwd_mla_prefill_kernel_gfx938(co
scores_sum
[
i
].
f32
[
0
]
=
0
;
}
uint64_t
pk_zero
=
0
;
#pragma unroll
for
(
int
i
=
0
;
i
<
K_LOOP_COUNT
*
M_WARP_COUNT
*
K_WARP_COUNT
;
++
i
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
M_MMAC_COUNT
;
++
min_tile_m
)
{
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#if defined(__gfx936__)
acc_o
[
i
][
min_tile_n
].
u64
[
0
]
=
__builtin_hcu_mov_b64
(
0
);
acc_o
[
i
][
min_tile_n
].
u64
[
1
]
=
__builtin_hcu_mov_b64
(
0
);
#elif defined(__gfx938__)
asm
volatile
(
"v_mov_b64 %0, 0x0"
:
"=v"
(
acc_o
[
i
][
min_tile_n
].
u64
[
0
])
:
);
asm
volatile
(
"v_mov_b64 %0, 0x0"
:
"=v"
(
acc_o
[
i
][
min_tile_n
].
u64
[
1
])
:
);
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
acc_o
[
i
][
min_tile_n
].
u64
[
0
]
=
__builtin_hcu_mov_b64
(
pk_zero
);
acc_o
[
i
][
min_tile_n
].
u64
[
1
]
=
__builtin_hcu_mov_b64
(
pk_zero
);
#else
acc_o
[
i
][
min_tile_n
].
f32
[
0
]
=
0
;
acc_o
[
i
][
min_tile_n
].
f32
[
1
]
=
0
;
...
...
@@ -1391,22 +1385,16 @@ __global__ void __launch_bounds__(512, 1) flash_fwd_mla_prefill_kernel_gfx938(co
scores_sum
[
i
].
f32
[
0
]
=
0
;
}
uint64_t
pk_zero
=
0
;
#pragma unroll
for
(
int
i
=
0
;
i
<
K_LOOP_COUNT
*
M_WARP_COUNT
*
K_WARP_COUNT
;
++
i
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
M_MMAC_COUNT
;
++
min_tile_m
)
{
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#if defined(__gfx936__)
acc_o
[
i
][
min_tile_n
].
u64
[
0
]
=
__builtin_hcu_mov_b64
(
0
);
acc_o
[
i
][
min_tile_n
].
u64
[
1
]
=
__builtin_hcu_mov_b64
(
0
);
#elif defined(__gfx938__)
asm
volatile
(
"v_mov_b64 %0, 0x0"
:
"=v"
(
acc_o
[
i
][
min_tile_n
].
u64
[
0
])
:
);
asm
volatile
(
"v_mov_b64 %0, 0x0"
:
"=v"
(
acc_o
[
i
][
min_tile_n
].
u64
[
1
])
:
);
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
acc_o
[
i
][
min_tile_n
].
u64
[
0
]
=
__builtin_hcu_mov_b64
(
pk_zero
);
acc_o
[
i
][
min_tile_n
].
u64
[
1
]
=
__builtin_hcu_mov_b64
(
pk_zero
);
#else
acc_o
[
i
][
min_tile_n
].
f32
[
0
]
=
0
;
acc_o
[
i
][
min_tile_n
].
f32
[
1
]
=
0
;
...
...
@@ -1662,22 +1650,16 @@ __global__ void __launch_bounds__(512, 1) flash_fwd_mla_decode_kernel_gfx938(con
scores_sum
[
i
].
f32
[
0
]
=
0
;
}
uint64_t
pk_zero
=
0
;
#pragma unroll
for
(
int
i
=
0
;
i
<
K_LOOP_COUNT
*
M_WARP_COUNT
*
K_WARP_COUNT
;
++
i
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
M_MMAC_COUNT
;
++
min_tile_m
)
{
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#if defined(__gfx936__)
acc_o
[
i
][
min_tile_n
].
u64
[
0
]
=
__builtin_hcu_mov_b64
(
0
);
acc_o
[
i
][
min_tile_n
].
u64
[
1
]
=
__builtin_hcu_mov_b64
(
0
);
#elif defined(__gfx938__)
asm
volatile
(
"v_mov_b64 %0, 0x0"
:
"=v"
(
acc_o
[
i
][
min_tile_n
].
u64
[
0
])
:
);
asm
volatile
(
"v_mov_b64 %0, 0x0"
:
"=v"
(
acc_o
[
i
][
min_tile_n
].
u64
[
1
])
:
);
#if defined(__gfx936__) || defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
acc_o
[
i
][
min_tile_n
].
u64
[
0
]
=
__builtin_hcu_mov_b64
(
pk_zero
);
acc_o
[
i
][
min_tile_n
].
u64
[
1
]
=
__builtin_hcu_mov_b64
(
pk_zero
);
#else
acc_o
[
i
][
min_tile_n
].
f32
[
0
]
=
0
;
acc_o
[
i
][
min_tile_n
].
f32
[
1
]
=
0
;
...
...
csrc/flash_attn_hg/src/flash_fwd_b16_pa.h
100644 → 100755
View file @
518a5f4d
...
...
@@ -89,7 +89,7 @@ inline __device__ void compute_attn_mha_1rowblock_splitkv(const Params ¶ms,
const
int64_t
row_offset_q
=
Is_Varlen
?
binfo
.
sum_s_q
*
ngroups
*
query_seqlen_stride
+
bidh
*
params
.
q_head_stride
+
m_block
*
kBlockM
*
query_seqlen_stride
:
bidb
*
int64_t
(
params
.
q_batch_stride
)
+
bidh
*
params
.
q_head_stride
+
m_block
*
kBlockM
*
query_seqlen_stride
;
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
constexpr
bool
USE_CACHE_SWIZZLE
=
false
;
#else
constexpr
bool
USE_CACHE_SWIZZLE
=
true
;
// for gfx928, cache swizzle have significant influence
...
...
@@ -292,7 +292,7 @@ inline __device__ void compute_attn_splitkv(const Params ¶ms) {
#include "kvcache/kvcache_softmax_tile16x32.h"
#include "kvcache/kvcache_acco_reduce_tile16x32.h"
template
<
typename
Kernel_traits
,
bool
Is_causal
,
bool
Is_Varlen
,
bool
Split
,
bool
Is_local
,
int
M_MMAC_COUNT
,
int
REUSE_KV_TIMES
,
int
HEADDIM_V_SPLIT
,
int
Partition_Size
,
typename
Params
>
template
<
typename
Kernel_traits
,
bool
Is_causal
,
bool
Is_Varlen
,
bool
Split
,
bool
Is_local
,
int
M_MMAC_COUNT
,
int
HEADDIM_V_SPLIT
,
int
Partition_Size
,
typename
Params
>
inline
__device__
void
compute_attn_1rowblock_splitkv_tile16x32
(
const
Params
&
params
,
const
int
bidb
,
const
int
bidh
,
const
int
warp_id
)
{
using
Element
=
typename
Kernel_traits
::
Element
;
using
ElementAccum
=
typename
Kernel_traits
::
ElementAccum
;
...
...
@@ -385,7 +385,7 @@ inline __device__ void compute_attn_1rowblock_splitkv_tile16x32(const Params &pa
:
bidb
*
int64_t
(
params
.
q_batch_stride
)
+
bidh
*
params
.
q_head_stride
+
m_block
*
kBlockM
*
query_seqlen_stride
;
// 准备读取数据的 buffer resource 寄存器
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
constexpr
bool
USE_CACHE_SWIZZLE
=
false
;
#else
constexpr
bool
USE_CACHE_SWIZZLE
=
true
;
// for gfx928, cache swizzle have significant influence
...
...
@@ -497,7 +497,7 @@ inline __device__ void compute_attn_1rowblock_splitkv_tile16x32(const Params &pa
int
lane_id
=
thread_id
&
63
;
if
constexpr
(
WARP_NUM
>
1
)
{
int
reduced_q_len
=
Is_Varlen
?
params
.
seqlen_q
:
actual_seqlen_q
;
kvcache_acco_reduce_tile16x32
<
REUSE_KV_TIMES
,
K_LOOP_COUNT
,
K_WARP_COUNT
,
M_WARP_COUNT
,
M_MMAC_COUNT
,
WARP_NUM
,
4
/*Padding*/
,
ElementAccum
>
(
acc_o
,
acc_o_lds
,
reduced_q_len
,
warp_id
,
lane_id
);
kvcache_acco_reduce_tile16x32
<
K_LOOP_COUNT
,
K_WARP_COUNT
,
M_WARP_COUNT
,
M_MMAC_COUNT
,
WARP_NUM
,
4
/*Padding*/
,
ElementAccum
>
(
acc_o
,
acc_o_lds
,
reduced_q_len
,
warp_id
,
lane_id
);
}
/**********************************************************************************************************************************/
...
...
@@ -525,7 +525,7 @@ inline __device__ void compute_attn_1rowblock_splitkv_tile16x32(const Params &pa
}
template
<
typename
Kernel_traits
,
bool
Is_causal
,
bool
Is_Varlen
,
bool
Split
,
bool
Is_local
,
int
M_MMAC_COUNT
,
int
REUSE_KV_TIMES
,
int
HEADDIM_V_SPLIT
,
int
Partition_Size
,
typename
Params
>
template
<
typename
Kernel_traits
,
bool
Is_causal
,
bool
Is_Varlen
,
bool
Split
,
bool
Is_local
,
int
M_MMAC_COUNT
,
int
HEADDIM_V_SPLIT
,
int
Partition_Size
,
typename
Params
>
inline
__device__
void
compute_attn_splitkv_tile16x32
(
const
Params
&
params
)
{
// The block index for the head.
...
...
@@ -537,7 +537,7 @@ inline __device__ void compute_attn_splitkv_tile16x32(const Params ¶ms) {
int
warp_id_vec
=
threadIdx
.
x
/
64
;
// warp id in a block
int
warp_id
=
__builtin_amdgcn_readfirstlane
(
warp_id_vec
);
flash
::
compute_attn_1rowblock_splitkv_tile16x32
<
Kernel_traits
,
Is_causal
,
Is_Varlen
,
Split
,
Is_local
,
M_MMAC_COUNT
,
REUSE_KV_TIMES
,
HEADDIM_V_SPLIT
,
Partition_Size
*
128
,
Params
>
(
params
,
bidb
,
bidh
,
warp_id
);
flash
::
compute_attn_1rowblock_splitkv_tile16x32
<
Kernel_traits
,
Is_causal
,
Is_Varlen
,
Split
,
Is_local
,
M_MMAC_COUNT
,
HEADDIM_V_SPLIT
,
Partition_Size
*
128
,
Params
>
(
params
,
bidb
,
bidh
,
warp_id
);
}
...
...
@@ -550,7 +550,7 @@ inline __device__ void compute_attn_splitkv_tile16x32(const Params ¶ms) {
#include "kvcache/gfx938/kvcache_softmax_gfx938.h"
#include "kvcache/gfx938/kvcache_epilogue_gfx938.h"
template
<
typename
Kernel_traits
,
bool
Is_causal
,
bool
Is_Varlen
,
bool
Split
,
int
M_MMAC_COUNT
,
int
REUSE_KV_TIMES
,
int
HEADDIM_V_SPLIT
,
int
Partition_Size
,
typename
Params
>
template
<
typename
Kernel_traits
,
bool
Is_causal
,
bool
Is_Varlen
,
bool
Split
,
int
M_MMAC_COUNT
,
int
HEADDIM_V_SPLIT
,
int
Partition_Size
,
typename
Params
>
inline
__device__
void
compute_attn_1rowblock_splitkv_gfx938
(
const
Params
&
params
,
const
int
bidb
,
const
int
bidh
,
const
int
warp_id
)
{
using
Element
=
typename
Kernel_traits
::
Element
;
using
ElementAccum
=
typename
Kernel_traits
::
ElementAccum
;
...
...
@@ -574,9 +574,9 @@ inline __device__ void compute_attn_1rowblock_splitkv_gfx938(const Params ¶m
binfo
.
set_params
<
Params
,
/*Is_Q_varlen=*/
Is_Varlen
,
/*Is_K_Cumulative=*/
false
>
(
params
,
bidb
);
// splitKV, 根据 split id 确定当前 split 在 seqlen_kv 上处理的长度
int
split_id
;
int
split_id
=
0
;
int
original_actual_seqlen_k
=
binfo
.
actual_seqlen_k
;
int
partition_size
;
int
partition_size
=
0
;
if
constexpr
(
Split
)
{
split_id
=
blockIdx
.
y
;
if
constexpr
(
Is_Varlen
)
{
...
...
@@ -642,10 +642,12 @@ inline __device__ void compute_attn_1rowblock_splitkv_gfx938(const Params ¶m
bool
is_thread0
=
threadIdx
.
x
==
0
;
if
(
is_thread0
)
{
inline_utcl2_warmup_dword
(
q_addr
);
inline_utcl2_warmup_dword
(
k_addr
);
inline_utcl2_warmup_dword
(
v_addr
);
//
inline_utcl2_warmup_dword(q_addr);
//
inline_utcl2_warmup_dword(k_addr);
//
inline_utcl2_warmup_dword(v_addr);
}
// Keep warmup buffer loads out of the MLS vmcnt schedule below.
flash
::
wait_all_buffer_data_arrived
<
true
>
();
// splitkv, debug 场景下需要写出一些值, 例如 scores_max/scores_sum
int
row_offset_lse
;
...
...
@@ -746,12 +748,17 @@ inline __device__ void compute_attn_1rowblock_splitkv_gfx938(const Params ¶m
int
thread_id
=
threadIdx
.
x
;
int
lane_id
=
thread_id
&
63
;
if
constexpr
(
WARP_NUM
>
1
)
{
kvcache_acco_reduce_tile16x32
<
REUSE_KV_TIMES
,
K_LOOP_COUNT
,
K_WARP_COUNT
,
M_WARP_COUNT
,
M_MMAC_COUNT
,
WARP_NUM
,
0
/*Padding*/
,
ElementAccum
>
(
acc_o
,
acc_o_lds
,
params
.
seqlen_q
,
warp_id
,
lane_id
);
kvcache_acco_reduce_tile16x32
<
K_LOOP_COUNT
,
K_WARP_COUNT
,
M_WARP_COUNT
,
M_MMAC_COUNT
,
WARP_NUM
,
0
/*Padding*/
,
ElementAccum
>
(
acc_o
,
acc_o_lds
,
params
.
seqlen_q
,
warp_id
,
lane_id
);
}
/**********************************************************************************************************************************/
// Epilogue, 收尾工作
// 收尾 1: 根据最后的归一化求和, 做 rescale
if
(
params
.
s_aux_ptr
!=
nullptr
&&
split_id
==
0
)
{
fp8_kvcache_apply_attention_sink_gfx938
<
K_LOOP_COUNT
,
M_WARP_COUNT
,
K_WARP_COUNT
,
M_MMAC_COUNT
,
ElementAccum
>
(
acc_o
,
scores_max
,
scores_sum
,
params
.
s_aux_ptr
,
params
.
s_aux_type
,
bidh
,
params
.
h
,
ngroups
,
m_block
,
kBlockM
,
lane_id
,
params
.
scale_softmax
);
}
kvcache_epilugue_rescale_acco
<
K_LOOP_COUNT
,
M_WARP_COUNT
,
K_WARP_COUNT
,
M_MMAC_COUNT
,
ElementAccum
>
(
acc_o
,
scores_sum
);
// 收尾 2: splitkv, 或者开启 debug 的情况下, 写出 scores_max, scores_sum
...
...
@@ -776,10 +783,10 @@ inline __device__ void compute_attn_1rowblock_splitkv_gfx938(const Params ¶m
}
template
<
typename
Kernel_traits
,
bool
Is_causal
,
bool
Is_Varlen
,
bool
Split
,
int
M_MMAC_COUNT
,
int
REUSE_KV_TIMES
,
int
HEADDIM_V_SPLIT
,
int
Partition_Size
,
typename
Params
>
template
<
typename
Kernel_traits
,
bool
Is_causal
,
bool
Is_Varlen
,
bool
Split
,
int
M_MMAC_COUNT
,
int
HEADDIM_V_SPLIT
,
int
Partition_Size
,
typename
Params
>
inline
__device__
void
compute_attn_splitkv_gfx938
(
const
Params
&
params
)
{
#if defined(__gfx938__)
#if defined(__gfx938__)
|| defined(__gfx946__)
// The block index for the head.
const
int
bidh
=
Split
?
blockIdx
.
z
%
params
.
h
:
blockIdx
.
y
;
// batch x num_head, num_head first
...
...
@@ -789,11 +796,272 @@ inline __device__ void compute_attn_splitkv_gfx938(const Params ¶ms) {
int
warp_id_vec
=
threadIdx
.
x
/
64
;
// warp id in a block
int
warp_id
=
__builtin_amdgcn_readfirstlane
(
warp_id_vec
);
flash
::
compute_attn_1rowblock_splitkv_gfx938
<
Kernel_traits
,
Is_causal
,
Is_Varlen
,
Split
,
M_MMAC_COUNT
,
REUSE_KV_TIMES
,
HEADDIM_V_SPLIT
,
Partition_Size
*
128
,
Params
>
(
params
,
bidb
,
bidh
,
warp_id
);
flash
::
compute_attn_1rowblock_splitkv_gfx938
<
Kernel_traits
,
Is_causal
,
Is_Varlen
,
Split
,
M_MMAC_COUNT
,
HEADDIM_V_SPLIT
,
Partition_Size
*
128
,
Params
>
(
params
,
bidb
,
bidh
,
warp_id
);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// MLS-based Paged Attention, gfx92a
////////////////////////////////////////////////////////////////////////////////////////////////////
#include "kvcache/gfx92a/f16_kvcache_gfx92a.h"
template
<
typename
Kernel_traits
,
bool
Is_causal
,
bool
Is_Varlen
,
bool
Is_monopolize
,
bool
Split
,
int
M_MMAC_COUNT
,
int
HEADDIM_V_SPLIT
,
typename
Params
>
inline
__device__
void
compute_attn_1rowblock_splitkv_gfx92a
(
const
Params
&
params
,
const
int
bidb
,
const
int
bidh
,
const
int
warp_id
)
{
using
Element
=
typename
Kernel_traits
::
Element
;
using
ElementAccum
=
typename
Kernel_traits
::
ElementAccum
;
using
index_t
=
typename
Kernel_traits
::
index_t
;
using
SplitkvAccumType
=
typename
Kernel_traits
::
SplitkvAccumType
;
constexpr
int
kBlockM
=
Kernel_traits
::
kBlockM
;
constexpr
int
kBlockN
=
Kernel_traits
::
kBlockN
;
constexpr
int
kBlockK
=
Kernel_traits
::
kBlockK
;
constexpr
int
kHeadDim
=
Kernel_traits
::
kHeadDim
;
constexpr
int
kHeadDimV
=
Kernel_traits
::
kHeadDimV
;
constexpr
int
kNWarps
=
Kernel_traits
::
kNWarps
;
constexpr
int
WARP_M
=
Kernel_traits
::
kWaveM
;
constexpr
int
WARP_N
=
Kernel_traits
::
kWaveN
;
constexpr
int
STAGES
=
Kernel_traits
::
STAGES
;
constexpr
int
WARP_NUM
=
kBlockN
/
WARP_N
;
constexpr
int
kHeadDimVSplit
=
kHeadDimV
/
HEADDIM_V_SPLIT
;
// flash::BlockInfo</*Varlen=*/true, /*Is_Kvcache*/true> binfo(params, bidb);
flash
::
SafeDecodeBlockInfo
binfo
;
binfo
.
set_params
<
Params
,
/*Is_Q_varlen=*/
Is_Varlen
,
/*Is_K_Cumulative=*/
false
>
(
params
,
bidb
);
// SplitKV processing
int
split_id
;
int
original_actual_seqlen_k
=
binfo
.
actual_seqlen_k
;
int
partition_size
;
if
constexpr
(
Split
)
{
split_id
=
blockIdx
.
y
;
if
constexpr
(
Is_Varlen
)
{
partition_size
=
splitkv_get_partitionsize_of_fix_numsplits
(
binfo
.
actual_seqlen_k
,
params
.
num_splits
);
binfo
.
actual_seqlen_k
=
min
(
binfo
.
actual_seqlen_k
-
split_id
*
partition_size
,
partition_size
);
}
else
{
partition_size
=
params
.
partition_size
;
int
num_splits
=
max
(
1
,
floor_div
(
binfo
.
actual_seqlen_k
,
partition_size
));
binfo
.
actual_seqlen_k
=
(
split_id
==
num_splits
-
1
)
?
binfo
.
actual_seqlen_k
-
split_id
*
partition_size
:
partition_size
;
binfo
.
actual_seqlen_k
=
(
split_id
>=
num_splits
)
?
0
:
binfo
.
actual_seqlen_k
;
if
(
split_id
>=
num_splits
)
return
;
}
}
// acquire TG id
int
block_x
=
blockIdx
.
x
;
const
int
m_block
=
block_x
/
HEADDIM_V_SPLIT
;
const
int
headdim_split_id
=
block_x
&
(
HEADDIM_V_SPLIT
-
1
);
// Compute seqQ
int
ngroups
,
actual_seqlen_q
;
if
constexpr
(
Is_Varlen
)
{
ngroups
=
params
.
ngroups
;
actual_seqlen_q
=
binfo
.
actual_seqlen_q
*
ngroups
;
}
else
{
actual_seqlen_q
=
binfo
.
actual_seqlen_q
;
}
// Running boundaries
if
(
m_block
*
kBlockM
>=
actual_seqlen_q
||
binfo
.
actual_seqlen_k
<=
0
)
return
;
// Decide lsa usage
extern
__shared__
Element
smem
[];
Element
*
q_lds
=
reinterpret_cast
<
Element
*>
(
smem
);
Element
*
k_lds
=
reinterpret_cast
<
Element
*>
(
smem
);
Element
*
v_lds
=
Is_monopolize
?
k_lds
+
16384
:
k_lds
;
ElementAccum
*
acc_o_lds
=
reinterpret_cast
<
ElementAccum
*>
(
smem
);
ElementAccum
*
max_lds
=
acc_o_lds
+
1024
/*from 4096 bytes*/
;
// Acquire stride along seq dimension of q/k/v
int
query_seqlen_stride
=
params
.
q_row_stride
;
int
kcache_seqlen_stride
=
params
.
k_row_stride
;
int
vcache_seqlen_stride
=
params
.
v_row_stride
;
// Compute q and k/v block table address
int
page_block_size
=
params
.
page_block_size
;
int
this_split_seqlen_start
=
Split
?
split_id
*
partition_size
:
0
;
int
*
block_table
=
params
.
block_table
+
bidb
*
params
.
block_table_batch_stride
;
block_table
=
block_table
+
(
Split
?
ceil_div
(
this_split_seqlen_start
,
page_block_size
)
:
0
);
const
int
block_table_idx
=
0
;
const
int
block_table_offset
=
0
;
const
int64_t
row_offset_k
=
int64_t
(
block_table
[
block_table_idx
])
*
int64_t
(
params
.
k_batch_stride
)
+
block_table_offset
*
params
.
k_row_stride
+
(
bidh
/
params
.
h_h_k_ratio
)
*
params
.
k_head_stride
;
const
int64_t
row_offset_v
=
int64_t
(
block_table
[
block_table_idx
])
*
int64_t
(
params
.
v_batch_stride
)
+
block_table_offset
*
params
.
v_row_stride
+
(
bidh
/
params
.
h_h_k_ratio
)
*
params
.
v_head_stride
;
const
int64_t
row_offset_q
=
Is_Varlen
?
binfo
.
sum_s_q
*
ngroups
*
query_seqlen_stride
+
bidh
*
ngroups
*
params
.
q_head_stride
+
m_block
*
kBlockM
*
query_seqlen_stride
:
bidb
*
int64_t
(
params
.
q_batch_stride
)
+
bidh
*
params
.
q_head_stride
+
m_block
*
kBlockM
*
query_seqlen_stride
;
// Prepare buffer resource for q/k/v
Element
*
q_ptr
=
reinterpret_cast
<
Element
*>
(
params
.
q_ptr
)
+
row_offset_q
;
auto
q_addr
=
prepare_for_buffer_load
<
kHeadDim
,
Element
,
false
>
(
q_ptr
);
auto
k_addr
=
prepare_for_buffer_load
<
kHeadDim
,
Element
,
false
>
(
reinterpret_cast
<
Element
*>
(
params
.
k_ptr
)
+
row_offset_k
);
auto
v_addr
=
prepare_for_buffer_load
<
kHeadDimV
,
Element
,
false
>
(
reinterpret_cast
<
Element
*>
(
params
.
v_ptr
)
+
row_offset_v
+
headdim_split_id
*
kHeadDimVSplit
);
// utcl2 warmup
if
constexpr
(
false
)
{
bool
is_thread0
=
threadIdx
.
x
==
0
;
if
(
is_thread0
)
{
inline_utcl2_warmup_dword
(
q_addr
);
inline_utcl2_warmup_dword
(
k_addr
);
inline_utcl2_warmup_dword
(
v_addr
);
}
}
// Compute lse/max/sum pointers of this TG
int
row_offset_lse
;
ElementAccum
*
scores_sum_ptr
;
ElementAccum
*
scores_max_ptr
;
ElementAccum
*
softmax_lse_ptr
;
if
constexpr
(
Split
)
{
int
row_offset_scores_split
;
if
constexpr
(
Is_Varlen
)
{
row_offset_lse
=
bidh
*
ngroups
*
params
.
total_q
+
binfo
.
sum_s_q
+
m_block
*
kBlockM
;
row_offset_scores_split
=
split_id
*
(
params
.
h
*
ngroups
*
params
.
total_q
);
softmax_lse_ptr
=
reinterpret_cast
<
ElementAccum
*>
(
params
.
softmax_lseaccum_ptr
)
+
row_offset_lse
+
row_offset_scores_split
;
}
else
{
row_offset_lse
=
(
bidb
*
params
.
h
+
bidh
)
*
params
.
seqlen_q
+
m_block
*
kBlockM
;
row_offset_scores_split
=
split_id
*
(
params
.
b
*
params
.
h
*
params
.
seqlen_q
);
scores_sum_ptr
=
reinterpret_cast
<
ElementAccum
*>
(
params
.
scores_sum_ptr
)
+
row_offset_lse
+
row_offset_scores_split
;
scores_max_ptr
=
reinterpret_cast
<
ElementAccum
*>
(
params
.
scores_max_ptr
)
+
row_offset_lse
+
row_offset_scores_split
;
}
}
else
{
if
constexpr
(
Is_Varlen
)
{
row_offset_lse
=
bidh
*
ngroups
*
params
.
total_q
+
binfo
.
sum_s_q
+
m_block
*
kBlockM
;
softmax_lse_ptr
=
reinterpret_cast
<
ElementAccum
*>
(
params
.
softmax_lse_ptr
)
+
row_offset_lse
;
}
else
{
row_offset_lse
=
(
bidb
*
params
.
h
+
bidh
)
*
params
.
seqlen_q
+
m_block
*
kBlockM
;
softmax_lse_ptr
=
reinterpret_cast
<
ElementAccum
*>
(
params
.
softmax_lse_ptr
)
+
row_offset_lse
;
}
}
// hold q regs
constexpr
int
M_WARP_COUNT
=
WARP_M
/
32
;
constexpr
int
K_WARP_COUNT
=
kBlockK
/
32
;
constexpr
int
N_WARP_COUNT
=
WARP_N
/
32
;
constexpr
int
K_LOOP_COUNT
=
kHeadDimVSplit
/
kBlockK
;
constexpr
int
Q_LOAD_BLOCKS
=
STAGES
==
2
?
(
kHeadDim
/
kBlockK
)
:
1
;
union_vec4_f16x2
<
Element
>
q_reg
[
Q_LOAD_BLOCKS
*
M_WARP_COUNT
*
K_WARP_COUNT
*
2
];
// prefetch Q into vgprs, can be hide
gfx92a
::
kvcache_prefetch_q_to_vgpr
<
Is_Varlen
,
kHeadDim
,
kBlockK
,
WARP_M
,
WARP_NUM
,
M_MMAC_COUNT
,
Element
>
(
q_ptr
,
q_lds
,
q_reg
,
warp_id
,
query_seqlen_stride
,
params
.
q_head_stride
,
ngroups
,
actual_seqlen_q
-
m_block
*
kBlockM
);
// Initialize, scores_max/scores_max/acc_o
vec2_Accum
<
ElementAccum
>
scores_max
[
M_WARP_COUNT
];
vec2_Accum
<
ElementAccum
>
scores_sum
[
M_WARP_COUNT
];
vec4_Accum
<
ElementAccum
>
acc_o
[
K_LOOP_COUNT
*
M_WARP_COUNT
*
K_WARP_COUNT
][
4
];
attention_initialize
<
K_LOOP_COUNT
,
M_WARP_COUNT
,
K_WARP_COUNT
,
M_MMAC_COUNT
,
ElementAccum
>
(
scores_max
,
scores_sum
,
acc_o
);
// Mainloop, along seqlenkv dimension, 4 warps computes attention of a kBlockN
const
int
n_block_min
=
0
;
const
int
n_block_max
=
ceil_div
(
binfo
.
actual_seqlen_k
,
kBlockN
);
int
n_block_loop
=
n_block_min
;
for
(;
n_block_loop
<
n_block_max
;
++
n_block_loop
)
{
int
warp_offset_in_seqkv
=
n_block_loop
*
kBlockN
+
warp_id
*
WARP_N
;
int
warp_seqkv_limit
=
binfo
.
actual_seqlen_k
-
n_block_loop
*
kBlockN
;
constexpr
int
prefetchKLevel
=
4
;
constexpr
int
prefetchVLevel
=
Is_monopolize
?
4
:
2
;
constexpr
bool
prefetchK
=
Is_monopolize
;
if
constexpr
(
prefetchK
)
{
if
(
n_block_loop
==
n_block_min
)
gfx92a
::
kvcache_prefetch_k_to_lds
<
kBlockK
,
WARP_N
,
prefetchKLevel
,
Element
>
(
k_addr
,
k_lds
,
warp_id
,
kcache_seqlen_stride
,
warp_seqkv_limit
);
}
else
{
gfx92a
::
kvcache_prefetch_k_to_lds
<
kBlockK
,
WARP_N
,
prefetchKLevel
,
Element
>
(
k_addr
,
k_lds
,
warp_id
,
kcache_seqlen_stride
,
warp_seqkv_limit
);
}
vec4_Accum
<
ElementAccum
>
s_reg
[
M_WARP_COUNT
*
N_WARP_COUNT
][
4
];
gfx92a
::
kvcache_qk_gemm_prefetch_v
<
kHeadDim
,
kHeadDimVSplit
,
WARP_N
,
kBlockK
,
WARP_M
,
WARP_N
,
prefetchKLevel
,
prefetchVLevel
,
M_MMAC_COUNT
,
Element
,
ElementAccum
>
(
k_addr
,
v_addr
,
k_lds
,
v_lds
,
q_reg
,
s_reg
,
warp_id
,
kcache_seqlen_stride
,
vcache_seqlen_stride
,
warp_seqkv_limit
);
const
int
block_table_idx_cur
=
n_block_loop
*
kBlockN
/
params
.
page_block_size
;
const
int
block_table_offset_cur
=
n_block_loop
*
kBlockN
-
block_table_idx_cur
*
params
.
page_block_size
;
const
int
block_table_idx_next
=
min
(
n_block_max
-
1
,
n_block_loop
+
1
)
*
kBlockN
/
params
.
page_block_size
;
const
int
block_table_offset_next
=
min
(
n_block_max
-
1
,
n_block_loop
+
1
)
*
kBlockN
-
block_table_idx_next
*
params
.
page_block_size
;
const
int
offset_diff
=
block_table_offset_next
-
block_table_offset_cur
;
int
table_diff
;
int
table_cur
,
table_next
;
if
constexpr
(
prefetchK
)
{
inline_global_load_dwordx1
(
table_cur
,
block_table_idx_cur
,
block_table
);
inline_global_load_dwordx1
(
table_next
,
block_table_idx_next
,
block_table
);
}
gfx92a
::
kvcache_prefetch_v_to_lds
<
kHeadDimV
,
kBlockK
,
kBlockK
,
STAGES
,
prefetchVLevel
,
Element
>
(
v_addr
,
v_lds
,
warp_id
,
vcache_seqlen_stride
,
warp_seqkv_limit
);
if
constexpr
(
Is_causal
)
{
gfx92a
::
kvcache_apply_mask_causal
<
M_WARP_COUNT
,
N_WARP_COUNT
,
M_MMAC_COUNT
,
Is_Varlen
>
(
s_reg
,
warp_offset_in_seqkv
+
this_split_seqlen_start
,
original_actual_seqlen_k
,
m_block
*
kBlockM
,
actual_seqlen_q
,
ngroups
,
params
.
mtp
,
params
.
layout
);
}
else
{
gfx92a
::
kvcache_apply_mask
<
M_WARP_COUNT
,
N_WARP_COUNT
,
M_MMAC_COUNT
>
(
s_reg
,
warp_seqkv_limit
,
warp_id
*
WARP_N
);
}
mla_softmax_rescale_o
<
Is_causal
,
ElementAccum
,
K_LOOP_COUNT
,
K_WARP_COUNT
,
M_WARP_COUNT
,
N_WARP_COUNT
,
WARP_NUM
,
M_MMAC_COUNT
>
(
s_reg
,
scores_max
,
scores_sum
,
acc_o
,
max_lds
,
warp_id
,
params
.
scale_softmax_log2
);
union_vec2_f16x2
<
Element
>
p_reg
[
M_WARP_COUNT
*
N_WARP_COUNT
][
4
];
gfx92a
::
convert_attn_f32_to_f16
<
M_MMAC_COUNT
,
Element
,
ElementAccum
>
(
s_reg
,
p_reg
);
if
constexpr
(
prefetchK
)
{
flash
::
wait_buffer_data_arrived
<
true
/*can be false*/
>
(
prefetchVLevel
/*4 for hdim 128*/
);
table_diff
=
__builtin_amdgcn_readfirstlane
(
table_next
-
table_cur
);
}
else
{
table_diff
=
__builtin_amdgcn_readfirstlane
(
block_table
[
block_table_idx_next
]
-
block_table
[
block_table_idx_cur
]);
}
*
(
int64_t
*
)
&
k_addr
+=
(
int64_t
(
table_diff
)
*
int64_t
(
params
.
k_batch_stride
)
+
offset_diff
*
params
.
k_row_stride
)
*
sizeof
(
Element
);
gfx92a
::
kvcache_pv_gemm_prefetch_k
<
prefetchK
,
K_LOOP_COUNT
,
kBlockK
,
kBlockN
,
M_WARP_COUNT
,
K_WARP_COUNT
,
N_WARP_COUNT
,
STAGES
,
prefetchKLevel
,
prefetchVLevel
,
M_MMAC_COUNT
,
Element
,
ElementAccum
>
(
v_addr
,
k_addr
,
v_lds
,
k_lds
,
p_reg
,
acc_o
,
warp_id
,
vcache_seqlen_stride
,
kcache_seqlen_stride
,
warp_seqkv_limit
);
*
(
int64_t
*
)
&
v_addr
+=
(
int64_t
(
table_diff
)
*
int64_t
(
params
.
v_batch_stride
)
+
offset_diff
*
params
.
v_row_stride
)
*
sizeof
(
Element
);
}
// reduce pv results among 4 warps
int
thread_id
=
threadIdx
.
x
;
int
lane_id
=
thread_id
&
63
;
if
constexpr
(
WARP_NUM
>
1
)
{
gfx92a
::
kvcache_acco_reduce_tile16x32
<
K_LOOP_COUNT
,
K_WARP_COUNT
,
M_WARP_COUNT
,
M_MMAC_COUNT
,
WARP_NUM
,
ElementAccum
>
(
acc_o
,
acc_o_lds
,
params
.
seqlen_q
,
warp_id
,
lane_id
);
}
/**********************************************************************************************************************************/
// Epilogue 1: rescaling acc_o
kvcache_epilugue_rescale_acco
<
K_LOOP_COUNT
,
M_WARP_COUNT
,
K_WARP_COUNT
,
M_MMAC_COUNT
,
ElementAccum
>
(
acc_o
,
scores_sum
);
// Epilogue 2: store lse / max / sum for splitkv reduction
if
constexpr
(
Is_Varlen
)
{
kvcache_epilogue_store_softmax_lse
<
Is_Varlen
,
true
,
M_WARP_COUNT
,
M_MMAC_COUNT
,
ElementAccum
>
(
scores_max
,
scores_sum
,
softmax_lse_ptr
,
params
.
scale_softmax
,
warp_id
,
thread_id
,
lane_id
,
headdim_split_id
,
actual_seqlen_q
-
m_block
*
kBlockM
,
params
.
total_q
,
params
.
ngroups
);
}
else
{
kvcache_epilogue_store_max_sum
<
Split
,
true
/*Is_16x32*/
,
M_WARP_COUNT
,
M_MMAC_COUNT
,
ElementAccum
>
(
scores_max
,
scores_sum
,
scores_max_ptr
,
scores_sum_ptr
,
params
.
scale_softmax
,
warp_id
,
thread_id
,
lane_id
,
headdim_split_id
,
actual_seqlen_q
-
m_block
*
kBlockM
);
}
// Epilogue 3: store acc_o into global memory
int64_t
row_offset_o
=
Is_Varlen
?
binfo
.
sum_s_q
*
ngroups
*
int64_t
(
params
.
o_row_stride
)
+
bidh
*
ngroups
*
params
.
o_head_stride
+
headdim_split_id
*
kHeadDimVSplit
+
(
Split
?
split_id
*
params
.
ngroups
*
int64_t
(
params
.
total_q
)
*
params
.
o_row_stride
:
0
)
:
bidb
*
int64_t
(
params
.
o_batch_stride
)
+
bidh
*
params
.
o_head_stride
+
headdim_split_id
*
kHeadDimVSplit
+
(
Split
?
split_id
*
params
.
b
*
params
.
o_batch_stride
:
0
);
gfx92a
::
kvcache_varlen_epilogue_store_output
<
Is_Varlen
,
Split
,
kBlockK
,
WARP_NUM
,
K_LOOP_COUNT
,
M_MMAC_COUNT
,
SplitkvAccumType
,
ElementAccum
>
(
acc_o
,
params
,
row_offset_o
,
actual_seqlen_q
-
m_block
*
kBlockM
,
warp_id
,
lane_id
);
}
template
<
typename
Kernel_traits
,
bool
Is_causal
,
bool
Is_Varlen
,
bool
Is_monopolize
,
bool
Split
,
int
M_MMAC_COUNT
,
int
HEADDIM_V_SPLIT
,
typename
Params
>
inline
__device__
void
compute_attn_splitkv_gfx92a
(
const
Params
&
params
)
{
#if defined(__gfx92a__)
// The block index for the head.
const
int
bidh
=
Split
?
blockIdx
.
z
%
params
.
h
:
blockIdx
.
y
;
// batch x num_head, num_head first
// The block index for the batch.
const
int
bidb
=
Split
?
blockIdx
.
z
/
params
.
h
:
blockIdx
.
z
;
int
warp_id_vec
=
threadIdx
.
x
/
64
;
// warp id in a block
int
warp_id
=
__builtin_amdgcn_readfirstlane
(
warp_id_vec
);
flash
::
compute_attn_1rowblock_splitkv_gfx92a
<
Kernel_traits
,
Is_causal
,
Is_Varlen
,
Is_monopolize
,
Split
,
M_MMAC_COUNT
,
HEADDIM_V_SPLIT
,
Params
>
(
params
,
bidb
,
bidh
,
warp_id
);
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// FMA-based Paged Attention
////////////////////////////////////////////////////////////////////////////////////////////////////
...
...
@@ -864,7 +1132,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_mha_fma_kernel(Param
ElementAccum
scores_sum
=
0
;
// 准备必要的 lds
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
__shared__
ElementAccum
lds
[
4096
];
// 16384 bytes, allow 4 waves per simd
#else
__shared__
ElementAccum
lds
[
16384
];
// 65536 bytes, allow 1 waves per simd for zd
...
...
@@ -1095,7 +1363,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_mha_kernel(Params pa
ElementAccum
scores_sum
=
0
;
// 准备必要的 lds
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
__shared__
ElementAccum
lds
[
4096
];
// 16384 bytes, allow 4 waves per simd
#else
__shared__
ElementAccum
lds
[
16384
];
// 65536 bytes, allow 1 waves per simd for zd
...
...
csrc/flash_attn_hg/src/flash_fwd_b8_fa.h
100644 → 100755
View file @
518a5f4d
...
...
@@ -284,4 +284,576 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_int8_prefix_prefill_kernel(c
compute_attn_int8_prefix_prefill_1rowblock
<
Kernel_traits
,
Is_dropout
,
Is_causal
,
Is_local
,
Is_even_MN
,
Return_softmax
,
Has_alibi
,
Layout
,
Flash_fwd_params
>
(
params
,
bidb
,
bidh
,
m_block
,
warp_id
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// GFX938 kernels
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
bool
FP8_DEBUG
,
bool
Is_even_MN
,
int
kBlockM
,
int
kBlockN
,
int
WARP_M
,
int
WARP_N
,
typename
Element
>
__forceinline__
__device__
void
fp8_debug_p_reg
(
Element
*
p_reg_ptr
,
union_vec32_fp8
p_reg
[
WARP_M
/
16
],
int
bidb
,
int
bidh
,
int
h
,
int
actual_seqlen_q
,
int
actual_seqlen_k
,
int
max_seq_q_offset
,
int
max_seq_kv_offset
,
int
m_block
,
int
n_block_loop
,
int
warp_id
,
int
lane_id
)
{
if
constexpr
(
FP8_DEBUG
)
{
__builtin_amdgcn_sched_barrier
(
0
);
if
constexpr
(
FP8_DEBUG
)
{
Element
*
p_reg_buffer
=
p_reg_ptr
+
(
bidb
*
h
+
bidh
)
*
actual_seqlen_q
*
actual_seqlen_k
;
#pragma unroll
for
(
int
k_loop
=
0
;
k_loop
<
kBlockN
/
WARP_N
;
++
k_loop
)
{
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
WARP_M
/
16
;
++
m_idx
)
{
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
WARP_N
/
16
;
++
n_idx
)
{
int
row_pos
=
m_block
*
kBlockM
+
warp_id
*
WARP_M
+
((
lane_id
&
15
)
>>
2
)
*
8
+
m_idx
*
4
+
(
lane_id
&
3
);
int
col_pos
=
(
lane_id
>>
4
)
*
8
+
n_idx
*
4
+
k_loop
*
WARP_N
+
n_block_loop
*
kBlockN
;
*
(
int32_t
*
)(
p_reg_buffer
+
row_pos
*
actual_seqlen_k
+
col_pos
)
=
p_reg
[
m_idx
].
i32
[
k_loop
*
2
+
n_idx
];
}
}
}
}
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_waitcnt
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
#include "fwd/gfx938/fp8_qk_gemm_prefetch_v_mls_ds.h"
#include "fwd/gfx938/fp8_pv_gemm_prefetch_k_mls_ds.h"
#include "fwd/gfx938/fp8_softmax_gfx938.h"
#include "fwd/gfx938/fp8_epilogue.h"
template
<
typename
Kernel_traits
,
bool
Is_training
,
bool
Is_dropout
,
bool
Is_causal
,
bool
Is_local
,
bool
Is_even_MN
,
bool
Is_even_K
,
bool
Is_Varlen
,
bool
Return_softmax
,
bool
Has_alibi
,
int
Layout
,
typename
Params
>
inline
__device__
void
compute_fp8_attn_mha_1rowblock_gfx938
(
const
Params
&
params
,
const
int
bidb
,
const
int
bidh
,
const
int
m_block
,
const
int
warp_id
)
{
using
Element
=
typename
Kernel_traits
::
Element
;
using
Element_k
=
typename
Kernel_traits
::
Element_k
;
using
ElementAccum
=
typename
Kernel_traits
::
ElementAccum
;
using
index_t
=
typename
Kernel_traits
::
index_t
;
constexpr
int
kBlockM
=
Kernel_traits
::
kBlockM
;
constexpr
int
kBlockN
=
Kernel_traits
::
kBlockN
;
constexpr
int
kBlockK
=
Kernel_traits
::
kBlockK
;
constexpr
int
kHeadDim
=
Kernel_traits
::
kHeadDim
;
constexpr
int
kHeadDimV
=
Kernel_traits
::
kHeadDimV
;
constexpr
int
kNWarps
=
Kernel_traits
::
kNWarps
;
constexpr
int
WARP_M
=
Kernel_traits
::
kWaveM
;
constexpr
int
WARP_N
=
Kernel_traits
::
kWaveN
;
constexpr
int
WARP_K
=
32
;
constexpr
int
STAGES
=
Kernel_traits
::
STAGES
;
constexpr
int
WARP_NUM
=
kBlockM
/
WARP_M
;
// 获取当前 TG 处理的任务大小
const
flash
::
BlockInfo
<
/*Varlen=*/
Is_Varlen
>
binfo
(
params
,
bidb
);
// 判断任务边界
int
max_seq_q_offset
=
binfo
.
actual_seqlen_q
-
m_block
*
kBlockM
;
if
(
m_block
*
kBlockM
>=
binfo
.
actual_seqlen_q
||
binfo
.
actual_seqlen_k
<=
0
/* || bidh >= h*/
)
return
;
// 获取 wave id
// int __warp_id = threadIdx.x >> 6;
// int warp_id = __builtin_amdgcn_readfirstlane(__warp_id);
// 定义 lds, 128x128 个 fp8, 16384 bytes
// __shared__ int8_t lds[16384 + 4096 + 16384 + 4096];
extern
__shared__
int8_t
lds
[];
int8_t
*
q_lds
=
lds
+
0
;
int8_t
*
k_lds
=
lds
+
0
;
int8_t
*
v_lds
=
lds
+
0
;
// ========================================== 计算 offset ===========================================
int64_t
row_offset_q
,
row_offset_k
,
row_offset_v
,
row_offset_o
;
int64_t
row_offset_lse_base
;
if
constexpr
(
Is_Varlen
)
{
if
constexpr
(
Layout
==
1
)
{
/* bshd: q/o are [total_q, h, d] */
row_offset_q
=
(
int64_t
(
binfo
.
sum_s_q
)
+
m_block
*
kBlockM
)
*
int64_t
(
params
.
q_row_stride
)
+
params
.
q_head_stride
*
bidh
;
row_offset_k
=
int64_t
(
binfo
.
sum_s_k
)
*
int64_t
(
params
.
k_row_stride
)
+
int
(
bidh
/
params
.
h_h_k_ratio
)
*
params
.
k_head_stride
;
row_offset_v
=
int64_t
(
binfo
.
sum_s_k
)
*
int64_t
(
params
.
v_row_stride
)
+
int
(
bidh
/
params
.
h_h_k_ratio
)
*
params
.
v_head_stride
;
row_offset_o
=
int64_t
(
binfo
.
sum_s_q
)
*
int64_t
(
params
.
o_head_stride
)
*
params
.
h
+
params
.
o_head_stride
*
bidh
+
m_block
*
kBlockM
*
int64_t
(
params
.
o_row_stride
);
row_offset_lse_base
=
bidh
*
int64_t
(
params
.
total_q
)
+
binfo
.
sum_s_q
;
}
else
{
/* bhsd */
row_offset_q
=
int64_t
(
binfo
.
sum_s_q
)
*
int64_t
(
params
.
q_row_stride
)
+
bidh
*
params
.
q_head_stride
+
m_block
*
kBlockM
*
int64_t
(
params
.
q_row_stride
);
row_offset_k
=
int64_t
(
binfo
.
sum_s_k
)
*
int64_t
(
params
.
k_row_stride
)
+
int
(
bidh
/
params
.
h_h_k_ratio
)
*
params
.
k_head_stride
;
row_offset_v
=
int64_t
(
binfo
.
sum_s_k
)
*
int64_t
(
params
.
v_row_stride
)
+
int
(
bidh
/
params
.
h_h_k_ratio
)
*
params
.
v_head_stride
;
row_offset_o
=
int64_t
(
binfo
.
sum_s_q
)
*
int64_t
(
params
.
o_row_stride
)
+
bidh
*
params
.
o_head_stride
+
m_block
*
kBlockM
*
int64_t
(
params
.
o_row_stride
);
row_offset_lse_base
=
bidh
*
int64_t
(
params
.
total_q
)
+
binfo
.
sum_s_q
;
}
}
else
{
row_offset_q
=
bidb
*
int64_t
(
params
.
q_batch_stride
)
+
bidh
*
params
.
q_head_stride
+
m_block
*
kBlockM
*
int64_t
(
params
.
q_row_stride
);
row_offset_k
=
bidb
*
int64_t
(
params
.
k_batch_stride
)
+
int
(
bidh
/
params
.
h_h_k_ratio
)
*
params
.
k_head_stride
;
row_offset_v
=
bidb
*
int64_t
(
params
.
v_batch_stride
)
+
int
(
bidh
/
params
.
h_h_k_ratio
)
*
params
.
v_head_stride
;
row_offset_o
=
bidb
*
int64_t
(
params
.
o_batch_stride
)
+
bidh
*
params
.
o_head_stride
+
m_block
*
kBlockM
*
int64_t
(
params
.
o_row_stride
);
row_offset_lse_base
=
(
bidb
*
params
.
h
+
bidh
)
*
int64_t
(
binfo
.
actual_seqlen_q
);
}
Element_k
*
q_ptr
=
reinterpret_cast
<
Element_k
*>
(
params
.
q_ptr
)
+
row_offset_q
;
Element_k
*
k_ptr
=
reinterpret_cast
<
Element_k
*>
(
params
.
k_ptr
)
+
row_offset_k
;
Element_k
*
v_ptr
=
reinterpret_cast
<
Element_k
*>
(
params
.
v_ptr
)
+
row_offset_v
;
ElementAccum
*
q_descale_ptr
=
reinterpret_cast
<
ElementAccum
*>
(
params
.
q_descale_ptr
);
ElementAccum
*
k_descale_ptr
=
reinterpret_cast
<
ElementAccum
*>
(
params
.
k_descale_ptr
);
ElementAccum
*
v_descale_ptr
=
reinterpret_cast
<
ElementAccum
*>
(
params
.
v_descale_ptr
);
ElementAccum
q_descale
=
q_descale_ptr
[
0
];
ElementAccum
k_descale
=
k_descale_ptr
[
0
];
ElementAccum
qk_descale
=
q_descale
*
k_descale
;
ElementAccum
softmax_scale
=
params
.
scale_softmax
*
qk_descale
;
ElementAccum
softmax_scale_log2
=
params
.
scale_softmax_log2
*
qk_descale
;
ElementAccum
v_descale
=
v_descale_ptr
[
0
];
// acc_o_ptr = reinterpret_cast<ElementAccum*>(acc_o_ptr) + row_offset_o;
ElementAccum
*
softmax_lse_ptr
=
reinterpret_cast
<
ElementAccum
*>
(
params
.
softmax_lse_ptr
);
Element_k
*
p_reg_ptr
=
reinterpret_cast
<
Element_k
*>
(
params
.
p_ptr
);
Element
*
o_ptr
=
reinterpret_cast
<
Element
*>
(
params
.
o_ptr
)
+
row_offset_o
;
// ======================================================== 读取 Q ======================================================================
fp8_prefetch_q_to_lds
<
Is_even_MN
,
kHeadDim
,
WARP_M
,
Element_k
>
(
q_ptr
,
q_lds
,
warp_id
,
params
.
q_row_stride
,
max_seq_q_offset
);
// 计算解决 bank 冲突必须的一些变量
int
tx
=
threadIdx
.
x
;
int
lane_id
=
tx
&
63
;
// 准备存储最大值, 求和, acc_o 寄存器 等
ElementAccum
scores_max
[
WARP_M
/
16
];
ElementAccum
scores_sum
[
WARP_M
/
16
];
vec4_Accum
<
ElementAccum
>
acc_o
[
kHeadDimV
/
32
][
WARP_M
/
16
][
WARP_N
/
16
];
fp8_attention_initialize
<
kHeadDimV
,
WARP_M
,
WARP_N
,
ElementAccum
>
(
scores_max
,
scores_sum
,
acc_o
);
// 从 lds 读取 q 的数据, 不需要同步
union_vec16_fp8
q_regs
[
WARP_M
/
16
][
kHeadDim
/
64
];
load_q_from_lds_to_vgpr
<
kHeadDim
,
WARP_M
,
Element_k
>
(
q_regs
,
q_lds
,
warp_id
,
lane_id
);
// ======================================================== Prefetch K ======================================================================
fp8_prefetch_k_to_lds
<
Is_even_MN
,
kHeadDim
,
WARP_N
,
Element_k
>
(
k_ptr
,
k_lds
,
warp_id
,
params
.
k_row_stride
,
binfo
.
actual_seqlen_k
);
// ======================================================== Mainloop ======================================================================
// 计算当前 block 计算任务的边界,带 causal mask 的场景可以少计算一些
int
n_block_min
=
0
;
int
n_block_max
=
ceil_div
(
binfo
.
actual_seqlen_k
,
kBlockN
);
if
constexpr
(
Is_causal
)
{
n_block_max
=
std
::
min
(
n_block_max
,
ceil_div
((
m_block
+
1
)
*
kBlockM
+
binfo
.
actual_seqlen_k
-
binfo
.
actual_seqlen_q
+
0
/*params.window_size_right*/
,
kBlockN
));
}
constexpr
int
n_masking_steps
=
(
!
Is_causal
/* && !Is_local*/
)
?
1
:
ceil_div
(
kBlockM
,
kBlockN
);
// 目前的场景可能需要限制 kBlockM == kBlockN, 主要是考虑到 prefetch K 的数据正确性
constexpr
bool
Assume_valid_rows
=
!
Is_local
&&
(
!
Is_causal
||
!
Is_Varlen
);
for
(
int
n_block_loop
=
n_block_min
;
n_block_loop
<
n_block_max
-
n_masking_steps
;
++
n_block_loop
)
{
// 计算 kv 的边界
int
max_seq_kv_offset
=
binfo
.
actual_seqlen_k
-
n_block_loop
*
kBlockN
;
// ======================================================== QK gemm ======================================================================
vec4_Accum
<
ElementAccum
>
s_reg
[
kBlockN
/
WARP_N
][
WARP_M
/
16
][
WARP_N
/
16
];
fp8_qk_gemm
<
kBlockN
,
kHeadDim
,
WARP_M
,
WARP_N
,
Element_k
,
ElementAccum
>
(
s_reg
,
q_regs
,
k_lds
);
// ========================================== load V ================================================
fp8_prefetch_v_to_lds
<
Is_even_MN
,
kBlockN
,
kHeadDimV
,
WARP_N
,
Element_k
>
(
v_ptr
,
v_lds
,
warp_id
,
params
.
v_row_stride
,
max_seq_kv_offset
);
// ======================================================== s_reg ======================================================================
// fp8_debug_s_reg<FP8_DEBUG, Is_even_MN, kBlockM, kBlockN, WARP_M, WARP_N, ElementAccum>(
// s_reg_ptr, s_reg, bidb, bidh, h, binfo.actual_seqlen_q, binfo.actual_seqlen_k, max_seq_q_offset, max_seq_kv_offset, m_block, n_block_loop, warp_id, lane_id);
// ======================================================== Softmax ======================================================================
union_vec16_fp8
v_regs
[
kBlockN
/
WARP_N
][
kHeadDimV
/
32
];
fp8_softmax_and_schedule_v
<
Assume_valid_rows
,
kHeadDimV
,
kBlockN
,
WARP_M
,
WARP_N
,
WARP_K
,
Element_k
,
ElementAccum
>
(
s_reg
,
scores_max
,
scores_sum
,
acc_o
,
softmax_scale_log2
,
v_regs
,
v_lds
);
// ========================================== cvt ===============================================
union_vec32_fp8
p_reg
[
WARP_M
/
16
];
fp8_cvt_f32_to_fp8
<
kBlockN
,
WARP_M
,
WARP_N
,
Element_k
,
ElementAccum
>
(
s_reg
,
p_reg
);
// ======================================================== p_reg ======================================================================
// fp8_debug_p_reg<1, Is_even_MN, kBlockM, kBlockN, WARP_M, WARP_N, Element_k>(
// p_reg_ptr, p_reg, bidb, bidh, params.h, binfo.actual_seqlen_q, binfo.actual_seqlen_k, max_seq_q_offset, max_seq_kv_offset, m_block, n_block_loop, warp_id, lane_id);
// ========================================== PV mmac ================================================
fp8_pv_gemm_and_prefetch_k
<
true
/*PrefetchK*/
,
Is_even_MN
,
kHeadDim
,
kHeadDimV
,
kBlockN
,
WARP_M
,
WARP_N
,
Element_k
,
ElementAccum
>
(
acc_o
,
p_reg
,
v_regs
,
v_lds
,
k_ptr
,
k_lds
,
warp_id
,
params
.
k_row_stride
,
max_seq_kv_offset
-
kBlockN
);
// 计算 k, v 的偏移
v_ptr
+=
kBlockN
*
params
.
v_row_stride
;
}
// ========================================== Rest ===============================================
// 剩下的需要做 causal mask
int
n_block_loop
=
max
(
n_block_max
-
n_masking_steps
,
n_block_min
);
#pragma unroll
for
(
int
masking_step
=
0
;
masking_step
<
n_masking_steps
;
++
masking_step
,
++
n_block_loop
)
{
// 计算 kv 的边界
int
max_seq_kv_offset
=
binfo
.
actual_seqlen_k
-
n_block_loop
*
kBlockN
;
// ======================================================== QK gemm ======================================================================
vec4_Accum
<
ElementAccum
>
s_reg
[
kBlockN
/
WARP_N
][
WARP_M
/
16
][
WARP_N
/
16
];
fp8_qk_gemm
<
kBlockN
,
kHeadDim
,
WARP_M
,
WARP_N
,
Element_k
,
ElementAccum
>
(
s_reg
,
q_regs
,
k_lds
);
// ========================================== load V ================================================
fp8_prefetch_v_to_lds
<
Is_even_MN
,
kBlockN
,
kHeadDimV
,
WARP_N
,
Element_k
>
(
v_ptr
,
v_lds
,
warp_id
,
params
.
v_row_stride
,
max_seq_kv_offset
);
// ======================================================== causal mask ==================================================================
if
constexpr
(
Is_causal
)
{
fp8_apply_causal_mask
<
kBlockN
,
WARP_M
,
WARP_N
,
ElementAccum
>
(
s_reg
,
binfo
.
actual_seqlen_q
,
binfo
.
actual_seqlen_k
,
m_block
*
kBlockM
+
warp_id
*
WARP_M
,
n_block_loop
*
kBlockN
,
lane_id
);
}
// ======================================================== s_reg ======================================================================
// fp8_debug_s_reg<FP8_DEBUG, Is_even_MN, kBlockM, kBlockN, WARP_M, WARP_N, ElementAccum>(
// s_reg_ptr, s_reg, bidb, bidh, h, binfo.actual_seqlen_q, binfo.actual_seqlen_k, max_seq_q_offset, max_seq_kv_offset, m_block, n_block_loop, warp_id, lane_id);
// ======================================================== mask ==================================================================
// 对齐 fp16 fwd:非 causal 的 rest loop 要屏蔽最后一个 partial KV tile 的越界列。
if
constexpr
(
!
Is_causal
&&
!
Is_local
)
{
if
constexpr
(
!
Is_even_MN
)
{
fp8_apply_mask
<
kBlockN
,
WARP_M
,
WARP_N
,
ElementAccum
>
(
s_reg
,
max_seq_kv_offset
,
0
,
lane_id
);
}
}
// ======================================================== Softmax ======================================================================
union_vec16_fp8
v_regs
[
kBlockN
/
WARP_N
][
kHeadDimV
/
32
];
fp8_softmax_and_schedule_v
<
Assume_valid_rows
,
kHeadDimV
,
kBlockN
,
WARP_M
,
WARP_N
,
WARP_K
,
Element_k
,
ElementAccum
>
(
s_reg
,
scores_max
,
scores_sum
,
acc_o
,
softmax_scale_log2
,
v_regs
,
v_lds
);
// ========================================== cvt ===============================================
union_vec32_fp8
p_reg
[
WARP_M
/
16
];
fp8_cvt_f32_to_fp8
<
kBlockN
,
WARP_M
,
WARP_N
,
Element_k
,
ElementAccum
>
(
s_reg
,
p_reg
);
// ======================================================== p_reg ======================================================================
// fp8_debug_p_reg<0, Is_even_MN, kBlockM, kBlockN, WARP_M, WARP_N, Element_k>(
// p_reg_ptr, p_reg, bidb, bidh, params.h, binfo.actual_seqlen_q, binfo.actual_seqlen_k, max_seq_q_offset, max_seq_kv_offset, m_block, n_block_loop, warp_id, lane_id);
// ========================================== PV mmac ================================================
constexpr
bool
PrefetchK
=
n_masking_steps
>
1
;
fp8_pv_gemm_and_prefetch_k
<
PrefetchK
,
Is_even_MN
,
kHeadDim
,
kHeadDimV
,
kBlockN
,
WARP_M
,
WARP_N
,
Element_k
,
ElementAccum
>
(
acc_o
,
p_reg
,
v_regs
,
v_lds
,
k_ptr
,
k_lds
,
warp_id
,
params
.
k_row_stride
,
max_seq_kv_offset
-
kBlockN
);
// 计算 k, v 的偏移
if
(
not
PrefetchK
)
{
k_ptr
+=
kBlockN
*
params
.
k_row_stride
;
}
v_ptr
+=
kBlockN
*
params
.
v_row_stride
;
}
// ========================================== rescale by scores_sum ==========================================
// 根据 scores_sum 对 acc_o 做缩放
ElementAccum
lse
[
WARP_M
/
16
];
if
(
params
.
s_aux_ptr
!=
nullptr
)
{
const
float
sink_value
=
fp8_attention_sink_load
(
params
.
s_aux_ptr
,
params
.
s_aux_type
,
bidh
);
fp8_attention_sink_apply
<
kHeadDimV
,
WARP_M
,
WARP_N
,
ElementAccum
>
(
acc_o
,
scores_max
,
scores_sum
,
softmax_scale
,
sink_value
);
}
if
constexpr
(
Return_softmax
)
{
fp8_epilogue_rescale_acc_o
<
Assume_valid_rows
,
kHeadDimV
,
WARP_M
,
WARP_N
,
true
/*StoreLSE*/
,
ElementAccum
>
(
acc_o
,
scores_max
,
scores_sum
,
lse
,
softmax_scale
,
v_descale
);
}
else
{
fp8_epilogue_rescale_acc_o
<
Assume_valid_rows
,
kHeadDimV
,
WARP_M
,
WARP_N
,
false
/*StoreLSE*/
,
ElementAccum
>
(
acc_o
,
scores_max
,
scores_sum
,
lse
,
softmax_scale
,
v_descale
);
}
// ========================================== lse storation ==========================================
if
constexpr
(
Return_softmax
)
{
fp8_epilogue_store_lse
<
Is_even_MN
,
WARP_M
,
ElementAccum
>
(
softmax_lse_ptr
,
scores_max
,
scores_sum
,
lse
,
row_offset_lse_base
,
binfo
.
actual_seqlen_q
,
m_block
*
kBlockM
+
warp_id
*
WARP_M
,
lane_id
);
}
// ========================================== Storation =============================================
fp8_epilogue_store_output
<
Is_even_MN
,
kBlockM
,
kHeadDimV
,
WARP_M
,
WARP_N
,
Element
,
ElementAccum
>
(
o_ptr
,
acc_o
,
m_block
,
warp_id
,
lane_id
,
params
.
o_row_stride
,
binfo
.
actual_seqlen_q
);
}
template
<
typename
Kernel_traits
,
bool
Is_training
,
bool
Is_dropout
,
bool
Is_causal
,
bool
Is_local
,
bool
Is_even_MN
,
bool
Is_even_K
,
bool
Is_Varlen
,
bool
Return_softmax
,
bool
Has_alibi
,
bool
Is_GQA
,
int
Layout
,
typename
Params
>
inline
__device__
void
compute_fp8_attn_gfx938
(
const
Params
&
params
)
{
#if defined(__gfx938__) || defined(__gfx946__)
constexpr
bool
Do_lpt
=
Is_causal
and
Is_GQA
;
const
int
bidh
=
Do_lpt
?
blockIdx
.
x
:
blockIdx
.
y
;
const
int
bidb
=
Do_lpt
?
blockIdx
.
y
:
blockIdx
.
z
;
int
warp_id_vec
=
threadIdx
.
x
/
64
;
// warp id in a block
int
warp_id
=
__builtin_amdgcn_readfirstlane
(
warp_id_vec
);
int
m_block
=
Do_lpt
?
gridDim
.
z
-
1
-
blockIdx
.
z
:
blockIdx
.
x
;
flash
::
compute_fp8_attn_mha_1rowblock_gfx938
<
Kernel_traits
,
Is_training
,
Is_dropout
,
Is_causal
,
Is_local
,
Is_even_MN
,
Is_even_K
,
Is_Varlen
,
Return_softmax
,
Has_alibi
,
Layout
,
Flash_fwd_params
>
(
params
,
bidb
,
bidh
,
m_block
,
warp_id
);
if
constexpr
(
Is_causal
and
!
Is_GQA
/*MHA causal mask*/
)
{
__builtin_amdgcn_sched_barrier
(
0
);
__syncthreads
();
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
compute_fp8_attn_mha_1rowblock_gfx938
<
Kernel_traits
,
Is_training
,
Is_dropout
,
Is_causal
,
Is_local
,
Is_even_MN
,
Is_even_K
,
Is_Varlen
,
Return_softmax
,
Has_alibi
,
Layout
,
Flash_fwd_params
>
(
params
,
bidb
,
bidh
,
gridDim
.
x
*
2
-
1
-
m_block
,
warp_id
);
}
#endif
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// FP8 Prefix Prefill (paged KV cache + varlen) for GFX938
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
typename
Kernel_traits
,
bool
Is_training
,
bool
Is_dropout
,
bool
Is_causal
,
bool
Is_local
,
bool
Is_even_K
,
bool
Return_softmax
,
bool
Has_alibi
,
int
Layout
,
typename
Params
>
inline
__device__
void
compute_fp8_attn_prefix_prefill_1rowblock_gfx938
(
const
Params
&
params
,
const
int
bidb
,
const
int
bidh
,
const
int
m_block
,
const
int
warp_id
)
{
using
Element
=
typename
Kernel_traits
::
Element
;
using
Element_k
=
typename
Kernel_traits
::
Element_k
;
using
ElementAccum
=
typename
Kernel_traits
::
ElementAccum
;
constexpr
int
kBlockM
=
Kernel_traits
::
kBlockM
;
constexpr
int
kBlockN
=
Kernel_traits
::
kBlockN
;
constexpr
int
kBlockK
=
Kernel_traits
::
kBlockK
;
constexpr
int
kHeadDim
=
Kernel_traits
::
kHeadDim
;
constexpr
int
kHeadDimV
=
Kernel_traits
::
kHeadDimV
;
constexpr
int
kNWarps
=
Kernel_traits
::
kNWarps
;
constexpr
int
WARP_M
=
Kernel_traits
::
kWaveM
;
constexpr
int
WARP_N
=
Kernel_traits
::
kWaveN
;
constexpr
int
WARP_K
=
32
;
constexpr
int
STAGES
=
Kernel_traits
::
STAGES
;
constexpr
int
WARP_NUM
=
kBlockM
/
WARP_M
;
// Varlen BlockInfo
const
flash
::
BlockInfo
<
true
/*Varlen*/
,
false
/*Is_kvcache*/
>
binfo
(
params
,
bidb
);
int
max_seq_q_offset
=
binfo
.
actual_seqlen_q
-
m_block
*
kBlockM
;
if
(
m_block
*
kBlockM
>=
binfo
.
actual_seqlen_q
||
binfo
.
actual_seqlen_k
<=
0
)
return
;
// 定义 lds
extern
__shared__
int8_t
lds
[];
int8_t
*
q_lds
=
lds
+
0
;
int8_t
*
k_lds
=
lds
+
0
;
int8_t
*
v_lds
=
lds
+
0
;
// ========================================== 计算 offset (varlen + paged) ===========================================
const
int
page_block_size
=
params
.
page_block_size
;
int
*
block_table
=
params
.
block_table
+
bidb
*
params
.
block_table_batch_stride
;
int
n_block_min
=
0
;
if
constexpr
(
Is_local
)
{
n_block_min
=
max
(
0
,
(
m_block
*
kBlockM
+
binfo
.
actual_seqlen_k
-
binfo
.
actual_seqlen_q
-
params
.
window_size_left
)
/
kBlockN
);
}
int
n_block_max
=
ceil_div
(
binfo
.
actual_seqlen_k
,
kBlockN
);
if
constexpr
(
Is_causal
||
Is_local
)
{
const
int
window_size_right
=
Is_local
?
params
.
window_size_right
:
0
;
n_block_max
=
min
(
n_block_max
,
ceil_div
((
m_block
+
1
)
*
kBlockM
+
binfo
.
actual_seqlen_k
-
binfo
.
actual_seqlen_q
+
window_size_right
,
kBlockN
));
}
if
(
n_block_min
>=
n_block_max
)
return
;
const
int
first_block_table_idx
=
n_block_min
*
kBlockN
/
params
.
page_block_size
;
const
int
first_block_table_offset
=
n_block_min
*
kBlockN
-
first_block_table_idx
*
params
.
page_block_size
;
const
int
first_page
=
block_table
[
first_block_table_idx
];
int64_t
row_offset_q
,
row_offset_k
,
row_offset_v
,
row_offset_o
;
int
row_offset_lse
;
if
constexpr
(
Layout
==
1
)
{
/*bshd layout*/
row_offset_q
=
(
binfo
.
sum_s_q
+
m_block
*
kBlockM
)
*
int64_t
(
params
.
q_row_stride
)
+
params
.
q_head_stride
*
bidh
;
row_offset_k
=
int64_t
(
first_page
)
*
int64_t
(
params
.
k_batch_stride
)
+
first_block_table_offset
*
int64_t
(
params
.
k_row_stride
)
+
(
bidh
/
params
.
h_h_k_ratio
)
*
params
.
k_head_stride
;
row_offset_v
=
int64_t
(
first_page
)
*
int64_t
(
params
.
v_batch_stride
)
+
first_block_table_offset
*
int64_t
(
params
.
v_row_stride
)
+
(
bidh
/
params
.
h_h_k_ratio
)
*
params
.
v_head_stride
;
row_offset_o
=
binfo
.
sum_s_q
*
int64_t
(
params
.
o_head_stride
)
*
params
.
h
+
params
.
o_head_stride
*
bidh
+
m_block
*
kBlockM
*
params
.
o_row_stride
;
row_offset_lse
=
bidh
*
params
.
total_q
+
binfo
.
sum_s_q
;
}
else
{
/*bhsd layout*/
row_offset_q
=
binfo
.
sum_s_q
*
int64_t
(
params
.
q_row_stride
)
+
bidh
*
params
.
q_head_stride
+
m_block
*
kBlockM
*
params
.
q_row_stride
;
row_offset_k
=
int64_t
(
first_page
)
*
int64_t
(
params
.
k_batch_stride
)
+
first_block_table_offset
*
int64_t
(
params
.
k_row_stride
)
+
(
bidh
/
params
.
h_h_k_ratio
)
*
params
.
k_head_stride
;
row_offset_v
=
int64_t
(
first_page
)
*
int64_t
(
params
.
v_batch_stride
)
+
first_block_table_offset
*
int64_t
(
params
.
v_row_stride
)
+
(
bidh
/
params
.
h_h_k_ratio
)
*
params
.
v_head_stride
;
row_offset_o
=
binfo
.
sum_s_q
*
int64_t
(
params
.
o_row_stride
)
+
bidh
*
params
.
o_head_stride
+
m_block
*
kBlockM
*
params
.
o_row_stride
;
row_offset_lse
=
bidh
*
params
.
total_q
+
binfo
.
sum_s_q
;
}
// FP8 descale tensors are broadcast by taking the first scalar value.
// 使用原始指针 (FP8 prefetch函数内部会调用 prepare_for_matrix_load)
Element_k
*
q_ptr
=
reinterpret_cast
<
Element_k
*>
(
params
.
q_ptr
)
+
row_offset_q
;
Element_k
*
k_ptr
=
reinterpret_cast
<
Element_k
*>
(
params
.
k_ptr
)
+
row_offset_k
;
Element_k
*
v_ptr
=
reinterpret_cast
<
Element_k
*>
(
params
.
v_ptr
)
+
row_offset_v
;
ElementAccum
*
q_descale_ptr
=
reinterpret_cast
<
ElementAccum
*>
(
params
.
q_descale_ptr
);
ElementAccum
*
k_descale_ptr
=
reinterpret_cast
<
ElementAccum
*>
(
params
.
k_descale_ptr
);
ElementAccum
*
v_descale_ptr
=
reinterpret_cast
<
ElementAccum
*>
(
params
.
v_descale_ptr
);
ElementAccum
q_descale
=
q_descale_ptr
[
0
];
ElementAccum
k_descale
=
k_descale_ptr
[
0
];
ElementAccum
qk_descale
=
q_descale
*
k_descale
;
ElementAccum
softmax_scale
=
params
.
scale_softmax
*
qk_descale
;
ElementAccum
softmax_scale_log2
=
params
.
scale_softmax_log2
*
qk_descale
;
ElementAccum
v_descale
=
v_descale_ptr
[
0
];
ElementAccum
*
softmax_lse_ptr
=
reinterpret_cast
<
ElementAccum
*>
(
params
.
softmax_lse_ptr
);
Element
*
o_ptr
=
reinterpret_cast
<
Element
*>
(
params
.
o_ptr
)
+
row_offset_o
;
// ======================================================== 读取 Q ======================================================================
fp8_prefetch_q_to_lds
<
false
/*Is_even_MN*/
,
kHeadDim
,
WARP_M
,
Element_k
>
(
q_ptr
,
q_lds
,
warp_id
,
params
.
q_row_stride
,
max_seq_q_offset
);
int
lane_id
=
threadIdx
.
x
&
63
;
// 准备寄存器
ElementAccum
scores_max
[
WARP_M
/
16
];
ElementAccum
scores_sum
[
WARP_M
/
16
];
vec4_Accum
<
ElementAccum
>
acc_o
[
kHeadDimV
/
32
][
WARP_M
/
16
][
WARP_N
/
16
];
fp8_attention_initialize
<
kHeadDimV
,
WARP_M
,
WARP_N
,
ElementAccum
>
(
scores_max
,
scores_sum
,
acc_o
);
// 从 lds 读取 q 的数据
union_vec16_fp8
q_regs
[
WARP_M
/
16
][
kHeadDim
/
64
];
load_q_from_lds_to_vgpr
<
kHeadDim
,
WARP_M
,
Element_k
>
(
q_regs
,
q_lds
,
warp_id
,
lane_id
);
// ======================================================== Mainloop ======================================================================
int
n_masking_steps
=
1
;
if
constexpr
(
Is_causal
)
{
const
int
causal_start_col
=
m_block
*
kBlockM
+
binfo
.
actual_seqlen_k
-
binfo
.
actual_seqlen_q
;
const
int
first_mask_block
=
max
(
n_block_min
,
causal_start_col
/
kBlockN
);
n_masking_steps
=
n_block_max
-
first_mask_block
;
}
else
if
constexpr
(
Is_local
)
{
n_masking_steps
=
min
(
n_block_max
-
n_block_min
,
ceil_div
(
kBlockM
,
kBlockN
));
}
n_masking_steps
=
min
(
max
(
n_masking_steps
,
1
),
n_block_max
-
n_block_min
);
constexpr
bool
Assume_valid_rows
=
!
Is_local
;
// ======================================================== Prefetch 第一块 K ======================================================================
if
(
n_block_max
>
n_masking_steps
)
{
fp8_prefetch_k_to_lds
<
false
/*Is_even_MN*/
,
kHeadDim
,
WARP_N
,
Element_k
>
(
k_ptr
,
k_lds
,
warp_id
,
params
.
k_row_stride
,
binfo
.
actual_seqlen_k
);
}
// ======================================================== 主循环:不需要 causal mask + Prefetch K ============================================================
for
(
int
n_block_loop
=
n_block_min
;
n_block_loop
<
n_block_max
-
n_masking_steps
;
++
n_block_loop
)
{
int
max_seq_kv_offset
=
binfo
.
actual_seqlen_k
-
n_block_loop
*
kBlockN
;
// QK gemm(K 数据已在上一轮 prefetch 到 LDS)
vec4_Accum
<
ElementAccum
>
s_reg
[
kBlockN
/
WARP_N
][
WARP_M
/
16
][
WARP_N
/
16
];
fp8_qk_gemm
<
kBlockN
,
kHeadDim
,
WARP_M
,
WARP_N
,
Element_k
,
ElementAccum
>
(
s_reg
,
q_regs
,
k_lds
);
// Prefetch V
fp8_prefetch_v_to_lds
<
false
/*Is_even_MN*/
,
kBlockN
,
kHeadDimV
,
WARP_N
,
Element_k
>
(
v_ptr
,
v_lds
,
warp_id
,
params
.
v_row_stride
,
max_seq_kv_offset
);
if
constexpr
(
Is_local
)
{
fp8_apply_local_mask
<
kBlockN
,
WARP_M
,
WARP_N
,
ElementAccum
>
(
s_reg
,
binfo
.
actual_seqlen_q
,
binfo
.
actual_seqlen_k
,
m_block
*
kBlockM
+
warp_id
*
WARP_M
,
n_block_loop
*
kBlockN
,
params
.
window_size_left
,
params
.
window_size_right
,
lane_id
);
}
// Softmax + 读取 V 到寄存器
union_vec16_fp8
v_regs
[
kBlockN
/
WARP_N
][
kHeadDimV
/
32
];
fp8_softmax_and_schedule_v
<
Assume_valid_rows
,
kHeadDimV
,
kBlockN
,
WARP_M
,
WARP_N
,
WARP_K
,
Element_k
,
ElementAccum
>
(
s_reg
,
scores_max
,
scores_sum
,
acc_o
,
softmax_scale_log2
,
v_regs
,
v_lds
);
// cvt
union_vec32_fp8
p_reg
[
WARP_M
/
16
];
fp8_cvt_f32_to_fp8
<
kBlockN
,
WARP_M
,
WARP_N
,
Element_k
,
ElementAccum
>
(
s_reg
,
p_reg
);
// PV MMAC + Prefetch 下一块 K(paged KV)
const
int
next_n_block_loop
=
n_block_loop
+
1
;
const
int
block_table_idx_cur
=
n_block_loop
*
kBlockN
/
page_block_size
;
const
int
block_table_offset_cur
=
n_block_loop
*
kBlockN
-
block_table_idx_cur
*
page_block_size
;
const
int
block_table_idx_next
=
next_n_block_loop
*
kBlockN
/
page_block_size
;
const
int
block_table_offset_next
=
next_n_block_loop
*
kBlockN
-
block_table_idx_next
*
page_block_size
;
const
int64_t
table_delta
=
int64_t
(
block_table
[
block_table_idx_next
]
-
block_table
[
block_table_idx_cur
]);
const
int64_t
offset_delta
=
int64_t
(
block_table_offset_next
-
block_table_offset_cur
);
Element_k
*
k_ptr_next
=
k_ptr
+
table_delta
*
int64_t
(
params
.
k_batch_stride
)
+
offset_delta
*
int64_t
(
params
.
k_row_stride
);
const
int
max_seq_kv_offset_next
=
binfo
.
actual_seqlen_k
-
next_n_block_loop
*
kBlockN
;
fp8_pv_gemm_and_prefetch_k_paged
<
false
/*Is_even_MN*/
,
kHeadDim
,
kHeadDimV
,
kBlockN
,
WARP_M
,
WARP_N
,
Element_k
,
ElementAccum
>
(
acc_o
,
p_reg
,
v_regs
,
v_lds
,
k_ptr_next
,
k_lds
,
warp_id
,
params
.
k_row_stride
,
max_seq_kv_offset_next
);
// 更新 K/V 指针
k_ptr
=
k_ptr_next
;
v_ptr
+=
table_delta
*
int64_t
(
params
.
v_batch_stride
)
+
offset_delta
*
int64_t
(
params
.
v_row_stride
);
}
// ======================================================== Masking 循环:需要 causal mask,不 Prefetch K ============================================================
int
n_block_loop
=
max
(
n_block_max
-
n_masking_steps
,
n_block_min
);
#pragma unroll
for
(
int
masking_step
=
0
;
masking_step
<
n_masking_steps
;
++
masking_step
,
++
n_block_loop
)
{
int
max_seq_kv_offset
=
binfo
.
actual_seqlen_k
-
n_block_loop
*
kBlockN
;
// 如果主循环没有 prefetch(n_block_max <= n_masking_steps),需要在这里 prefetch K
if
(
masking_step
==
0
&&
n_block_max
<=
n_masking_steps
)
{
fp8_prefetch_k_to_lds
<
false
/*Is_even_MN*/
,
kHeadDim
,
WARP_N
,
Element_k
>
(
k_ptr
,
k_lds
,
warp_id
,
params
.
k_row_stride
,
max_seq_kv_offset
);
}
// QK gemm
vec4_Accum
<
ElementAccum
>
s_reg
[
kBlockN
/
WARP_N
][
WARP_M
/
16
][
WARP_N
/
16
];
fp8_qk_gemm
<
kBlockN
,
kHeadDim
,
WARP_M
,
WARP_N
,
Element_k
,
ElementAccum
>
(
s_reg
,
q_regs
,
k_lds
);
// Prefetch V
fp8_prefetch_v_to_lds
<
false
/*Is_even_MN*/
,
kBlockN
,
kHeadDimV
,
WARP_N
,
Element_k
>
(
v_ptr
,
v_lds
,
warp_id
,
params
.
v_row_stride
,
max_seq_kv_offset
);
// Mask
// 对齐 fp16 fwd:非 causal 的 rest loop 要屏蔽最后一个 partial KV tile 的越界列。
if
constexpr
(
!
Is_causal
&&
!
Is_local
)
{
fp8_apply_mask
<
kBlockN
,
WARP_M
,
WARP_N
,
ElementAccum
>
(
s_reg
,
max_seq_kv_offset
,
0
,
lane_id
);
}
// Causal mask
if
constexpr
(
Is_local
)
{
fp8_apply_local_mask
<
kBlockN
,
WARP_M
,
WARP_N
,
ElementAccum
>
(
s_reg
,
binfo
.
actual_seqlen_q
,
binfo
.
actual_seqlen_k
,
m_block
*
kBlockM
+
warp_id
*
WARP_M
,
n_block_loop
*
kBlockN
,
params
.
window_size_left
,
params
.
window_size_right
,
lane_id
);
}
else
if
constexpr
(
Is_causal
)
{
fp8_apply_causal_mask
<
kBlockN
,
WARP_M
,
WARP_N
,
ElementAccum
>
(
s_reg
,
binfo
.
actual_seqlen_q
,
binfo
.
actual_seqlen_k
,
m_block
*
kBlockM
+
warp_id
*
WARP_M
,
n_block_loop
*
kBlockN
,
lane_id
);
}
// Softmax + 读取 V 到寄存器
union_vec16_fp8
v_regs
[
kBlockN
/
WARP_N
][
kHeadDimV
/
32
];
fp8_softmax_and_schedule_v
<
Assume_valid_rows
,
kHeadDimV
,
kBlockN
,
WARP_M
,
WARP_N
,
WARP_K
,
Element_k
,
ElementAccum
>
(
s_reg
,
scores_max
,
scores_sum
,
acc_o
,
softmax_scale_log2
,
v_regs
,
v_lds
);
// cvt
union_vec32_fp8
p_reg
[
WARP_M
/
16
];
fp8_cvt_f32_to_fp8
<
kBlockN
,
WARP_M
,
WARP_N
,
Element_k
,
ElementAccum
>
(
s_reg
,
p_reg
);
const
int
next_n_block_loop
=
n_block_loop
+
1
;
if
(
next_n_block_loop
<
n_block_max
)
{
const
int
block_table_idx_cur
=
n_block_loop
*
kBlockN
/
page_block_size
;
const
int
block_table_offset_cur
=
n_block_loop
*
kBlockN
-
block_table_idx_cur
*
page_block_size
;
const
int
block_table_idx_next
=
next_n_block_loop
*
kBlockN
/
page_block_size
;
const
int
block_table_offset_next
=
next_n_block_loop
*
kBlockN
-
block_table_idx_next
*
page_block_size
;
const
int64_t
table_delta
=
int64_t
(
block_table
[
block_table_idx_next
]
-
block_table
[
block_table_idx_cur
]);
const
int64_t
offset_delta
=
int64_t
(
block_table_offset_next
-
block_table_offset_cur
);
Element_k
*
k_ptr_next
=
k_ptr
+
table_delta
*
int64_t
(
params
.
k_batch_stride
)
+
offset_delta
*
int64_t
(
params
.
k_row_stride
);
const
int
max_seq_kv_offset_next
=
binfo
.
actual_seqlen_k
-
next_n_block_loop
*
kBlockN
;
fp8_pv_gemm_and_prefetch_k_paged
<
false
/*Is_even_MN*/
,
kHeadDim
,
kHeadDimV
,
kBlockN
,
WARP_M
,
WARP_N
,
Element_k
,
ElementAccum
>
(
acc_o
,
p_reg
,
v_regs
,
v_lds
,
k_ptr_next
,
k_lds
,
warp_id
,
params
.
k_row_stride
,
max_seq_kv_offset_next
);
k_ptr
=
k_ptr_next
;
v_ptr
+=
table_delta
*
int64_t
(
params
.
v_batch_stride
)
+
offset_delta
*
int64_t
(
params
.
v_row_stride
);
}
else
{
fp8_pv_gemm_and_prefetch_k
<
false
/*PrefetchK*/
,
false
/*Is_even_MN*/
,
kHeadDim
,
kHeadDimV
,
kBlockN
,
WARP_M
,
WARP_N
,
Element_k
,
ElementAccum
>
(
acc_o
,
p_reg
,
v_regs
,
v_lds
,
k_ptr
,
k_lds
,
warp_id
,
params
.
k_row_stride
,
max_seq_kv_offset
);
}
}
// ========================================== rescale by scores_sum ==========================================
ElementAccum
lse
[
WARP_M
/
16
];
if
(
params
.
s_aux_ptr
!=
nullptr
)
{
const
float
sink_value
=
fp8_attention_sink_load
(
params
.
s_aux_ptr
,
params
.
s_aux_type
,
bidh
);
fp8_attention_sink_apply
<
kHeadDimV
,
WARP_M
,
WARP_N
,
ElementAccum
>
(
acc_o
,
scores_max
,
scores_sum
,
softmax_scale
,
sink_value
);
}
if
constexpr
(
Return_softmax
)
{
fp8_epilogue_rescale_acc_o
<
Assume_valid_rows
,
kHeadDimV
,
WARP_M
,
WARP_N
,
true
/*StoreLSE*/
,
ElementAccum
>
(
acc_o
,
scores_max
,
scores_sum
,
lse
,
softmax_scale
,
v_descale
);
}
else
{
fp8_epilogue_rescale_acc_o
<
Assume_valid_rows
,
kHeadDimV
,
WARP_M
,
WARP_N
,
false
/*StoreLSE*/
,
ElementAccum
>
(
acc_o
,
scores_max
,
scores_sum
,
lse
,
softmax_scale
,
v_descale
);
}
// ========================================== lse storation (varlen) ==========================================
if
constexpr
(
Return_softmax
)
{
fp8_epilogue_store_lse
<
false
/*Is_even_MN*/
,
WARP_M
,
ElementAccum
>
(
softmax_lse_ptr
,
scores_max
,
scores_sum
,
lse
,
row_offset_lse
,
binfo
.
actual_seqlen_q
,
m_block
*
kBlockM
+
warp_id
*
WARP_M
,
lane_id
);
}
// ========================================== Storation =============================================
fp8_epilogue_store_output
<
false
/*Is_even_MN*/
,
kBlockM
,
kHeadDimV
,
WARP_M
,
WARP_N
,
Element
,
ElementAccum
>
(
o_ptr
,
acc_o
,
m_block
,
warp_id
,
lane_id
,
params
.
o_row_stride
,
binfo
.
actual_seqlen_q
);
}
template
<
typename
Kernel_traits
,
bool
Is_training
,
bool
Is_dropout
,
bool
Is_causal
,
bool
Is_local
,
bool
Is_even_K
,
bool
Return_softmax
,
bool
Has_alibi
,
int
Layout
>
__global__
void
__launch_bounds__
(
256
,
1
)
flash_fp8_fwd_prefix_prefill_kernel_gfx938
(
Flash_fwd_params
params
)
{
#if defined(__gfx938__) || defined(__gfx946__)
// LPT 调度:改变 blockIdx 到 m_block/bidh/bidb 的映射
// causal 模式:blockIdx.x = bidh, blockIdx.y = bidb, blockIdx.z 倒序 = m_block
// 非 causal 模式:blockIdx.x = m_block, blockIdx.y = bidh, blockIdx.z = bidb
constexpr
bool
Do_lpt
=
Is_causal
;
const
int
bidh
=
Do_lpt
?
blockIdx
.
x
:
blockIdx
.
y
;
const
int
bidb
=
Do_lpt
?
blockIdx
.
y
:
blockIdx
.
z
;
int
warp_id
=
__builtin_amdgcn_readfirstlane
(
threadIdx
.
x
/
64
);
int
m_block
=
Do_lpt
?
gridDim
.
z
-
1
-
blockIdx
.
z
:
blockIdx
.
x
;
flash
::
compute_fp8_attn_prefix_prefill_1rowblock_gfx938
<
Kernel_traits
,
Is_training
,
Is_dropout
,
Is_causal
,
Is_local
,
Is_even_K
,
Return_softmax
,
Has_alibi
,
Layout
,
Flash_fwd_params
>
(
params
,
bidb
,
bidh
,
m_block
,
warp_id
);
#endif
}
}
// namespace flash
csrc/flash_attn_hg/src/flash_fwd_b8_mla.h
View file @
518a5f4d
...
...
@@ -210,7 +210,7 @@ inline __device__ void compute_attn_1rowblock_splitkv_fp8_mla_gfx938(const Param
template
<
typename
Kernel_traits
,
bool
Is_causal
,
bool
Split
,
int
M_MMAC_COUNT
,
int
REUSE_KV_TIMES
,
int
HEADDIM_V_SPLIT
,
int
Partition_Size
,
typename
Params
>
inline
__device__
void
compute_attn_splitkv_fp8_mla_gfx938
(
const
Params
&
params
)
{
#if defined(__gfx938__)
#if defined(__gfx938__)
|| defined(__gfx946__)
// The block index for the head.
const
int
bidh
=
Split
?
blockIdx
.
z
%
params
.
h
:
blockIdx
.
y
;
// batch x num_head, num_head first
...
...
@@ -252,7 +252,7 @@ __global__ void flash_mla_convert_query_to_fp8_kernel(
const
int
nope_row_stride
,
const
int
rope_row_stride
,
const
int
qheads
)
{
#if defined(__gfx938__)
#if defined(__gfx938__)
|| defined(__gfx946__)
if
constexpr
(
persistent
)
{
for
(
int
bidb
=
blockIdx
.
x
;
bidb
<
total_blocks
;
bidb
+=
gridDim
.
x
)
{
// --------------------- nope -------------------------
...
...
csrc/flash_attn_hg/src/flash_fwd_b8_pa.h
View file @
518a5f4d
...
...
@@ -108,7 +108,7 @@ inline __device__ void compute_attn_mha_1rowblock_splitkv_int8(const Params &par
const
int64_t
row_offset_q
=
bidb
*
params
.
q_batch_stride
+
bidh
*
params
.
q_head_stride
+
m_block
*
kBlockM
*
query_seqlen_stride
;
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
constexpr
bool
USE_CACHE_SWIZZLE
=
false
;
#else
constexpr
bool
USE_CACHE_SWIZZLE
=
true
;
// for gfx928, cache swizzle have significant influence
...
...
@@ -166,7 +166,7 @@ inline __device__ void compute_attn_mha_1rowblock_splitkv_int8(const Params &par
vec2_Accum
<
ElementAccum
>
scores_max
[
WARP_M
/
32
]
=
{
-
INFINITY
};
vec2_Accum
<
ElementAccum
>
scores_sum
[
WARP_M
/
32
]
=
{
0
};
// 由于当前编译器无法自动生成 v_mov_b64 指令, 主动用 builtin 还会被转译成 v_mov_b32, 因此用内联汇编控制
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
vec4_Accum
<
ElementAccum
>
acc_o
[(
kHeadDimV
/
kBlockK
)
*
((
WARP_M
/
32
)
*
(
kBlockK
/
32
))][
4
];
if
constexpr
(
kHeadDimV
==
128
)
{
// kHeadDim 128 是主要优化目标
if
constexpr
(
M_MMAC_COUNT
==
1
)
{
...
...
@@ -176,23 +176,15 @@ inline __device__ void compute_attn_mha_1rowblock_splitkv_int8(const Params &par
}
__builtin_amdgcn_sched_barrier
(
0
);
}
else
{
// 非 kHeaddim 128, 交给编译器后续的优化了
uint64_t
pk_zero
=
0
;
#pragma unroll
for
(
int
i
=
0
;
i
<
(
kHeadDimV
/
kBlockK
)
*
((
WARP_M
/
32
)
*
(
kBlockK
/
32
));
++
i
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
M_MMAC_COUNT
;
++
min_tile_m
)
{
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#if defined(__gfx936__)
acc_o
[
i
][
min_tile_n
*
2
+
min_tile_m
].
u64
[
0
]
=
__builtin_hcu_mov_b64
(
0
);
acc_o
[
i
][
min_tile_n
*
2
+
min_tile_m
].
u64
[
1
]
=
__builtin_hcu_mov_b64
(
0
);
#elif defined(__gfx938__)
asm
volatile
(
"v_mov_b64 %0, 0x0"
:
"=v"
(
acc_o
[
i
][
min_tile_n
*
2
+
min_tile_m
].
u64
[
0
])
:
);
asm
volatile
(
"v_mov_b64 %0, 0x0"
:
"=v"
(
acc_o
[
i
][
min_tile_n
*
2
+
min_tile_m
].
u64
[
1
])
:
);
#endif
acc_o
[
i
][
min_tile_n
*
2
+
min_tile_m
].
u64
[
0
]
=
__builtin_hcu_mov_b64
(
pk_zero
);
acc_o
[
i
][
min_tile_n
*
2
+
min_tile_m
].
u64
[
1
]
=
__builtin_hcu_mov_b64
(
pk_zero
);
}
}
}
...
...
@@ -418,4 +410,520 @@ inline __device__ void compute_attn_splitkv_int8(const Params ¶ms) {
flash
::
compute_attn_mha_1rowblock_splitkv_int8
<
Kernel_traits
,
Is_training
,
Is_dropout
,
Is_causal
,
Is_local
,
Is_even_K
,
Return_softmax
,
Has_alibi
,
Split
,
M_MMAC_COUNT
,
REUSE_KV_TIMES
,
Flash_fwd_params
>
(
params
,
bidb
,
bidh
,
warp_id
);
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// MLS-based FP8 Paged Attention, >= gfx938
////////////////////////////////////////////////////////////////////////////////////////////////////
template
<
int
REUSE_KV_TIMES
,
int
K_LOOP_COUNT
,
int
K_WARP_COUNT
,
int
M_WARP_COUNT
,
int
M_MMAC_COUNT
,
int
WARP_NUM
,
typename
ElementAccum
>
__forceinline__
__device__
void
fp8_kvcache_acco_reduce_compact_gfx938
(
vec4_Accum
<
ElementAccum
>
acc_o
[
K_LOOP_COUNT
*
M_WARP_COUNT
*
K_WARP_COUNT
][
4
],
ElementAccum
*
acc_o_lds
,
int
seqlen_q
,
int
warp_id
,
int
lane_id
)
{
constexpr
int
kReduceBlockK
=
32
;
constexpr
int
kReduceRows
=
M_WARP_COUNT
*
M_MMAC_COUNT
*
16
;
const
int
q_seq_idx
=
lane_id
&
15
;
const
int
lane_dim_offset
=
(
lane_id
>>
4
)
*
4
;
const
int
even_reuse_kv_times
=
(
REUSE_KV_TIMES
>
0
)
?
((
REUSE_KV_TIMES
+
1
)
/
2
)
*
2
:
((
seqlen_q
+
1
)
/
2
)
*
2
;
const
bool
is_valid_q_lane
=
q_seq_idx
<
even_reuse_kv_times
;
#pragma unroll
for
(
int
h_idx
=
0
;
h_idx
<
K_LOOP_COUNT
;
++
h_idx
)
{
#pragma unroll
for
(
int
k_idx
=
0
;
k_idx
<
K_WARP_COUNT
;
++
k_idx
)
{
if
(
is_valid_q_lane
)
{
#pragma unroll
for
(
int
warp_m_idx
=
0
;
warp_m_idx
<
M_WARP_COUNT
;
++
warp_m_idx
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
M_MMAC_COUNT
;
++
min_tile_m
)
{
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
const
int
row_idx
=
warp_m_idx
*
M_MMAC_COUNT
*
16
+
min_tile_m
*
16
+
q_seq_idx
;
const
int
lds_offset
=
(
warp_id
*
kReduceRows
+
row_idx
)
*
kReduceBlockK
+
min_tile_n
*
16
+
lane_dim_offset
;
const
int
tile_32x32_id
=
h_idx
*
M_WARP_COUNT
*
K_WARP_COUNT
+
k_idx
*
M_WARP_COUNT
+
warp_m_idx
;
*
(
vec4_fp32
*
)(
acc_o_lds
+
lds_offset
)
=
acc_o
[
tile_32x32_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
;
}
}
}
}
__syncthreads
();
if
constexpr
(
WARP_NUM
>
1
)
{
if
(
warp_id
==
0
)
{
if
(
is_valid_q_lane
)
{
#pragma unroll
for
(
int
warp_m_idx
=
0
;
warp_m_idx
<
M_WARP_COUNT
;
++
warp_m_idx
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
M_MMAC_COUNT
;
++
min_tile_m
)
{
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
++
vec_idx
)
{
const
int
row_idx
=
warp_m_idx
*
M_MMAC_COUNT
*
16
+
min_tile_m
*
16
+
q_seq_idx
;
const
int
lds_offset
=
row_idx
*
kReduceBlockK
+
min_tile_n
*
16
+
lane_dim_offset
+
vec_idx
;
ElementAccum
acc_tmp
=
acc_o_lds
[
lds_offset
];
#pragma unroll
for
(
int
loop
=
1
;
loop
<
WARP_NUM
;
++
loop
)
{
acc_tmp
+=
acc_o_lds
[
lds_offset
+
loop
*
kReduceRows
*
kReduceBlockK
];
}
acc_o_lds
[
lds_offset
]
=
acc_tmp
;
}
}
}
}
}
}
}
__syncthreads
();
if
(
is_valid_q_lane
)
{
#pragma unroll
for
(
int
warp_m_idx
=
0
;
warp_m_idx
<
M_WARP_COUNT
;
++
warp_m_idx
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
M_MMAC_COUNT
;
++
min_tile_m
)
{
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
const
int
row_idx
=
warp_m_idx
*
M_MMAC_COUNT
*
16
+
min_tile_m
*
16
+
q_seq_idx
;
const
int
lds_offset
=
row_idx
*
kReduceBlockK
+
min_tile_n
*
16
+
lane_dim_offset
;
const
int
tile_32x32_id
=
h_idx
*
M_WARP_COUNT
*
K_WARP_COUNT
+
k_idx
*
M_WARP_COUNT
+
warp_m_idx
;
acc_o
[
tile_32x32_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
=
*
(
vec4_fp32
*
)(
acc_o_lds
+
lds_offset
);
}
}
}
}
__syncthreads
();
}
}
}
template
<
typename
DataType
,
int
M_WARP_COUNT
,
int
N_WARP_COUNT
,
int
M_MMAC_COUNT
>
inline
__device__
void
fp8_kvcache_apply_mask_local_causal_gfx938
(
DataType
tensor
[
M_WARP_COUNT
*
N_WARP_COUNT
][
4
],
const
int
col_idx_offset_
,
const
int
max_seqlen_k
,
const
int
row_idx_offset_
,
const
int
max_seqlen_q
,
const
int
ngroups
,
const
int
window_size_left
,
const
int
window_size_right
)
{
const
int
lane_id
=
threadIdx
.
x
&
63
;
const
int
row_idx_offset
=
row_idx_offset_
+
(
lane_id
&
15
);
const
int
col_idx_offset
=
col_idx_offset_
+
(
lane_id
>>
4
)
*
8
;
#pragma unroll
for
(
int
mi
=
0
;
mi
<
M_WARP_COUNT
;
++
mi
)
{
const
int
row_idx_base
=
row_idx_offset
+
mi
*
32
;
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
M_MMAC_COUNT
;
++
min_tile_m
)
{
const
int
row_idx
=
row_idx_base
+
min_tile_m
*
16
;
const
int
logical_row
=
row_idx
/
ngroups
;
const
int
logical_q
=
max_seqlen_q
/
ngroups
;
const
int
col_idx_limit_left
=
max
(
0
,
logical_row
+
max_seqlen_k
-
logical_q
-
window_size_left
);
const
int
col_idx_limit_right
=
min
(
max_seqlen_k
,
logical_row
+
max_seqlen_k
-
logical_q
+
window_size_right
);
#pragma unroll
for
(
int
ni
=
0
;
ni
<
N_WARP_COUNT
;
++
ni
)
{
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
const
int
col_idx_base
=
col_idx_offset
+
ni
*
32
+
min_tile_n
*
4
;
#pragma unroll
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
++
vec_idx
)
{
const
int
col_idx
=
col_idx_base
+
vec_idx
;
tensor
[
mi
+
ni
*
M_WARP_COUNT
][
min_tile_n
*
2
+
min_tile_m
].
f32
[
vec_idx
]
=
(
col_idx
<
col_idx_limit_left
||
col_idx
>
col_idx_limit_right
)
?
-
INFINITY
:
tensor
[
mi
+
ni
*
M_WARP_COUNT
][
min_tile_n
*
2
+
min_tile_m
].
f32
[
vec_idx
];
}
}
}
}
}
}
template
<
int
kHeadDim
,
int
kBlockM
,
int
WARP_M
,
int
M_MMAC_COUNT
,
typename
Element
>
__forceinline__
__device__
void
fp8_mha_prefetch_q_to_vgpr_gfx938
(
vec4_uint
q_addr
,
Element
*
q_lds
,
union_vec16_fp8
q_reg
[
M_MMAC_COUNT
][
kHeadDim
/
64
],
int
warp_id
,
int
query_seqlen_stride
,
int
max_seq_q_offset
)
{
static_assert
(
kHeadDim
==
128
||
kHeadDim
==
256
);
static_assert
(
WARP_M
==
32
);
vec4_uint
q_srsrc
;
q_srsrc
[
1
]
=
q_addr
[
1
];
q_srsrc
[
2
]
=
query_seqlen_stride
;
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
M_MMAC_COUNT
;
++
min_tile_m
)
{
#pragma unroll
for
(
int
k_loop
=
0
;
k_loop
<
kHeadDim
/
128
;
++
k_loop
)
{
if
(
warp_id
==
min_tile_m
)
{
const
int
q_row_base
=
min_tile_m
*
16
;
const
int
valid_rows
=
max_seq_q_offset
-
q_row_base
;
const
int
safe_q_row_base
=
valid_rows
<=
0
?
0
:
q_row_base
;
const
int
nm_filter
=
inline_min_max
<
0
,
16
>
(
16
-
valid_rows
);
q_srsrc
[
3
]
=
valid_rows
>=
16
?
0
:
(
nm_filter
<<
8
);
const
int64_t
row_offset_bytes
=
int64_t
(
safe_q_row_base
)
*
int64_t
(
query_seqlen_stride
)
*
sizeof
(
Element
);
const
int64_t
dim_offset_bytes
=
int64_t
(
k_loop
)
*
128
*
sizeof
(
Element
);
*
(
uint64_t
*
)
&
q_srsrc
=
VA_LIMIT_BITS
(
*
(
uint64_t
*
)
&
q_addr
+
row_offset_bytes
+
dim_offset_bytes
);
const
int
lds_offset_bytes
=
(
min_tile_m
*
(
kHeadDim
/
128
)
+
k_loop
)
*
16
*
128
*
sizeof
(
Element
);
inline_matrix_load_128x16_b8_lds_trans
<
0
,
1
>
(
q_lds
,
q_srsrc
,
lds_offset_bytes
,
0
);
}
}
}
flash
::
wait_buffer_data_arrived
<
true
/*sync*/
>
(
0
);
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
M_MMAC_COUNT
;
++
min_tile_m
)
{
#pragma unroll
for
(
int
k_loop
=
0
;
k_loop
<
kHeadDim
/
128
;
++
k_loop
)
{
const
int
lds_offset_bytes
=
(
min_tile_m
*
(
kHeadDim
/
128
)
+
k_loop
)
*
16
*
128
*
sizeof
(
Element
);
const
int
q_lds_load_offset
=
reinterpret_cast
<
size_t
>
(
q_lds
)
+
lds_offset_bytes
;
DS_READ_MATRIX_64x16_B8
(
q_lds_load_offset
,
q_reg
[
min_tile_m
][
k_loop
*
2
+
0
].
i32x4
,
true
/*transpose*/
)
DS_READ_MATRIX_64x16_B8
(
q_lds_load_offset
+
1024
,
q_reg
[
min_tile_m
][
k_loop
*
2
+
1
].
i32x4
,
true
/*transpose*/
)
flash
::
wait_lds_data_arrived
<
true
/*sync*/
>
(
0
);
}
}
__builtin_amdgcn_sched_barrier
(
0
);
}
template
<
int
kBlockM
,
int
WARP_M
,
int
M_MMAC_COUNT
,
typename
Element
>
__forceinline__
__device__
void
fp8_mha_prefetch_q_to_vgpr_hdim192_gfx938
(
vec4_uint
q_addr
,
Element
*
q_lds
,
union_vec16_fp8
q_reg
[
M_MMAC_COUNT
][
3
],
int
warp_id
,
int
query_seqlen_stride
,
int
max_seq_q_offset
)
{
static_assert
(
WARP_M
==
32
);
constexpr
int
kLoadBytes
=
16
*
128
*
sizeof
(
Element
);
vec4_uint
q_srsrc
;
q_srsrc
[
1
]
=
q_addr
[
1
];
q_srsrc
[
2
]
=
query_seqlen_stride
;
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
M_MMAC_COUNT
;
++
min_tile_m
)
{
if
(
warp_id
==
min_tile_m
)
{
const
int
q_row_base
=
min_tile_m
*
16
;
const
int
valid_rows
=
max_seq_q_offset
-
q_row_base
;
const
int
safe_q_row_base
=
valid_rows
<=
0
?
0
:
q_row_base
;
const
int
nm_filter
=
inline_min_max
<
0
,
16
>
(
16
-
valid_rows
);
q_srsrc
[
3
]
=
valid_rows
>=
16
?
0
:
(
nm_filter
<<
8
);
const
int64_t
row_offset_bytes
=
int64_t
(
safe_q_row_base
)
*
int64_t
(
query_seqlen_stride
)
*
sizeof
(
Element
);
*
(
uint64_t
*
)
&
q_srsrc
=
VA_LIMIT_BITS
(
*
(
uint64_t
*
)
&
q_addr
+
row_offset_bytes
);
const
int
q_lds_first_offset
=
(
min_tile_m
*
2
+
0
)
*
kLoadBytes
;
inline_matrix_load_128x16_b8_lds_trans
<
0
,
1
>
(
q_lds
,
q_srsrc
,
q_lds_first_offset
,
0
);
*
(
uint64_t
*
)
&
q_srsrc
=
VA_LIMIT_BITS
(
*
(
uint64_t
*
)
&
q_addr
+
row_offset_bytes
+
64
*
sizeof
(
Element
));
const
int
q_lds_tail_offset
=
(
min_tile_m
*
2
+
1
)
*
kLoadBytes
;
inline_matrix_load_128x16_b8_lds_trans
<
0
,
1
>
(
q_lds
,
q_srsrc
,
q_lds_tail_offset
,
0
);
}
}
flash
::
wait_buffer_data_arrived
<
true
/*sync*/
>
(
0
);
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
M_MMAC_COUNT
;
++
min_tile_m
)
{
const
int
q_lds_first_load
=
reinterpret_cast
<
size_t
>
(
q_lds
)
+
(
min_tile_m
*
2
+
0
)
*
kLoadBytes
;
DS_READ_MATRIX_64x16_B8
(
q_lds_first_load
,
q_reg
[
min_tile_m
][
0
].
i32x4
,
true
/*transpose*/
)
DS_READ_MATRIX_64x16_B8
(
q_lds_first_load
+
1024
,
q_reg
[
min_tile_m
][
1
].
i32x4
,
true
/*transpose*/
)
flash
::
wait_lds_data_arrived
<
true
/*sync*/
>
(
0
);
const
int
q_lds_tail_load
=
reinterpret_cast
<
size_t
>
(
q_lds
)
+
(
min_tile_m
*
2
+
1
)
*
kLoadBytes
;
DS_READ_MATRIX_64x16_B8
(
q_lds_tail_load
+
1024
,
q_reg
[
min_tile_m
][
2
].
i32x4
,
true
/*transpose*/
)
flash
::
wait_lds_data_arrived
<
true
/*sync*/
>
(
0
);
}
__builtin_amdgcn_sched_barrier
(
0
);
}
template
<
typename
Kernel_traits
,
bool
Is_causal
,
bool
Is_Varlen
,
bool
Split
,
bool
Is_local
,
int
M_MMAC_COUNT
,
int
REUSE_KV_TIMES
,
int
HEADDIM_V_SPLIT
,
int
Partition_Size
,
typename
Params
>
inline
__device__
void
compute_attn_1rowblock_splitkv_fp8_gfx938
(
const
Params
&
params
,
const
int
bidb
,
const
int
bidh
,
const
int
warp_id
)
{
using
Element
=
fp8_e4m3
;
using
ElementAccum
=
typename
Kernel_traits
::
ElementAccum
;
using
SplitkvAccumType
=
typename
Kernel_traits
::
SplitkvAccumType
;
constexpr
int
kBlockM
=
Kernel_traits
::
kBlockM
;
constexpr
int
kBlockN
=
Kernel_traits
::
kBlockN
;
constexpr
int
kBlockK
=
Kernel_traits
::
kBlockK
;
constexpr
int
kHeadDim
=
Kernel_traits
::
kHeadDim
;
constexpr
int
kHeadDimV
=
Kernel_traits
::
kHeadDimV
;
constexpr
int
WARP_M
=
Kernel_traits
::
kWaveM
;
constexpr
int
WARP_N
=
Kernel_traits
::
kWaveN
;
constexpr
int
STAGES
=
Kernel_traits
::
STAGES
;
constexpr
int
WARP_NUM
=
kBlockN
/
WARP_N
;
constexpr
int
kHeadDimVSplit
=
kHeadDimV
/
HEADDIM_V_SPLIT
;
static_assert
(
kBlockK
==
64
);
static_assert
(
kHeadDim
==
128
||
kHeadDim
==
256
||
(
kHeadDim
==
192
&&
kHeadDimV
==
128
));
static_assert
(
kHeadDimVSplit
==
128
);
flash
::
SafeDecodeBlockInfo
binfo
;
binfo
.
set_params
<
Params
,
/*Is_Q_varlen=*/
Is_Varlen
,
/*Is_K_Cumulative=*/
false
>
(
params
,
bidb
);
int
split_id
=
0
;
int
original_actual_seqlen_k
=
binfo
.
actual_seqlen_k
;
int
partition_size
=
0
;
if
constexpr
(
Split
)
{
split_id
=
blockIdx
.
y
;
if
constexpr
(
Is_Varlen
)
{
partition_size
=
splitkv_get_partitionsize_of_fix_numsplits
(
binfo
.
actual_seqlen_k
,
params
.
num_splits
);
binfo
.
actual_seqlen_k
=
min
(
binfo
.
actual_seqlen_k
-
split_id
*
partition_size
,
partition_size
);
}
else
{
partition_size
=
params
.
partition_size
;
int
num_splits
=
max
(
1
,
floor_div
(
binfo
.
actual_seqlen_k
,
partition_size
));
binfo
.
actual_seqlen_k
=
(
split_id
==
num_splits
-
1
)
?
binfo
.
actual_seqlen_k
-
split_id
*
partition_size
:
partition_size
;
binfo
.
actual_seqlen_k
=
(
split_id
>=
num_splits
)
?
0
:
binfo
.
actual_seqlen_k
;
if
(
split_id
>=
num_splits
)
return
;
}
}
int
block_x
=
blockIdx
.
x
;
const
int
m_block
=
block_x
/
HEADDIM_V_SPLIT
;
const
int
headdim_split_id
=
block_x
&
(
HEADDIM_V_SPLIT
-
1
);
int
ngroups
=
1
;
int
actual_seqlen_q
=
binfo
.
actual_seqlen_q
;
if
constexpr
(
Is_Varlen
)
{
ngroups
=
params
.
ngroups
;
actual_seqlen_q
=
binfo
.
actual_seqlen_q
*
ngroups
;
}
if
(
m_block
*
kBlockM
>=
actual_seqlen_q
||
binfo
.
actual_seqlen_k
<=
0
)
return
;
extern
__shared__
Element
fp8_smem
[];
constexpr
int
q_smem_bytes
=
STAGES
*
kBlockM
*
kBlockK
*
sizeof
(
Element
);
constexpr
int
kv_smem_bytes
=
STAGES
*
kBlockK
*
WARP_N
*
sizeof
(
Element
)
*
WARP_NUM
;
constexpr
int
gemm_smem_bytes
=
q_smem_bytes
>
kv_smem_bytes
?
q_smem_bytes
:
kv_smem_bytes
;
Element
*
q_lds
=
reinterpret_cast
<
Element
*>
(
fp8_smem
);
Element
*
k_lds
=
reinterpret_cast
<
Element
*>
(
fp8_smem
);
Element
*
v_lds
=
k_lds
;
ElementAccum
*
acc_o_lds
=
reinterpret_cast
<
ElementAccum
*>
(
fp8_smem
);
ElementAccum
*
max_lds
=
reinterpret_cast
<
ElementAccum
*>
(
reinterpret_cast
<
char
*>
(
fp8_smem
)
+
gemm_smem_bytes
);
const
int
query_seqlen_stride
=
params
.
q_row_stride
;
const
int
kcache_seqlen_stride
=
params
.
k_row_stride
;
const
int
vcache_seqlen_stride
=
params
.
v_row_stride
;
int
n_block_min
=
0
;
int
n_block_max
=
ceil_div
(
binfo
.
actual_seqlen_k
,
kBlockN
);
if
constexpr
(
Is_local
)
{
const
int
q_row_start
=
m_block
*
kBlockM
;
const
int
q_row_end
=
min
(
actual_seqlen_q
,
(
m_block
+
1
)
*
kBlockM
)
-
1
;
const
int
logical_q
=
Is_Varlen
?
actual_seqlen_q
/
ngroups
:
actual_seqlen_q
;
const
int
logical_row_start
=
Is_Varlen
?
q_row_start
/
ngroups
:
q_row_start
;
const
int
logical_row_end
=
Is_Varlen
?
q_row_end
/
ngroups
:
q_row_end
;
const
int
split_seqlen_start
=
Split
?
split_id
*
partition_size
:
0
;
const
int
local_left
=
max
(
0
,
logical_row_start
+
original_actual_seqlen_k
-
logical_q
-
params
.
window_size_left
);
const
int
local_right
=
min
(
original_actual_seqlen_k
,
logical_row_end
+
original_actual_seqlen_k
-
logical_q
+
params
.
window_size_right
+
1
);
const
int
split_local_left
=
local_left
-
split_seqlen_start
;
const
int
split_local_right
=
local_right
-
split_seqlen_start
;
const
int
n_block_count
=
n_block_max
;
const
int
raw_n_block_min
=
max
(
0
,
split_local_left
/
kBlockN
);
const
int
raw_n_block_max
=
ceil_div
(
max
(
0
,
split_local_right
),
kBlockN
);
n_block_min
=
min
(
max
(
raw_n_block_min
,
0
),
max
(
0
,
n_block_count
-
1
));
n_block_max
=
min
(
max
(
raw_n_block_max
,
n_block_min
+
1
),
n_block_count
);
}
const
int
page_block_size
=
params
.
page_block_size
;
int
*
block_table
=
params
.
block_table
+
bidb
*
params
.
block_table_batch_stride
;
const
int
this_split_seqlen_start
=
Split
?
split_id
*
partition_size
:
0
;
block_table
=
block_table
+
(
Split
?
ceil_div
(
this_split_seqlen_start
,
page_block_size
)
:
0
);
const
int
block_table_idx
=
n_block_min
*
kBlockN
/
page_block_size
;
const
int
block_table_offset
=
n_block_min
*
kBlockN
-
block_table_idx
*
page_block_size
;
const
int64_t
row_offset_k
=
int64_t
(
block_table
[
block_table_idx
])
*
int64_t
(
params
.
k_batch_stride
)
+
block_table_offset
*
params
.
k_row_stride
+
(
bidh
/
params
.
h_h_k_ratio
)
*
params
.
k_head_stride
;
const
int64_t
row_offset_v
=
int64_t
(
block_table
[
block_table_idx
])
*
int64_t
(
params
.
v_batch_stride
)
+
block_table_offset
*
params
.
v_row_stride
+
(
bidh
/
params
.
h_h_k_ratio
)
*
params
.
v_head_stride
;
const
int64_t
row_offset_q
=
Is_Varlen
?
binfo
.
sum_s_q
*
ngroups
*
int64_t
(
query_seqlen_stride
)
+
bidh
*
params
.
q_head_stride
+
m_block
*
kBlockM
*
int64_t
(
query_seqlen_stride
)
:
bidb
*
int64_t
(
params
.
q_batch_stride
)
+
bidh
*
params
.
q_head_stride
+
m_block
*
kBlockM
*
int64_t
(
query_seqlen_stride
);
auto
q_addr
=
prepare_for_buffer_load
<
kHeadDim
,
Element
,
false
>
(
reinterpret_cast
<
Element
*>
(
params
.
q_ptr
)
+
row_offset_q
);
auto
k_addr
=
prepare_for_buffer_load
<
kHeadDim
,
Element
,
false
>
(
reinterpret_cast
<
Element
*>
(
params
.
k_ptr
)
+
row_offset_k
);
auto
v_addr
=
prepare_for_buffer_load
<
kHeadDimV
,
Element
,
false
>
(
reinterpret_cast
<
Element
*>
(
params
.
v_ptr
)
+
row_offset_v
+
headdim_split_id
*
kHeadDimVSplit
);
const
ElementAccum
q_descale
=
params
.
q_descale_ptr
[
0
];
const
ElementAccum
k_descale
=
params
.
k_descale_ptr
[
0
];
const
ElementAccum
v_descale
=
params
.
v_descale_ptr
[
0
];
__float2
qk_descale
=
{
q_descale
*
k_descale
,
q_descale
*
k_descale
};
int
row_offset_lse
;
ElementAccum
*
scores_sum_ptr
=
nullptr
;
ElementAccum
*
scores_max_ptr
=
nullptr
;
ElementAccum
*
softmax_lse_ptr
=
nullptr
;
if
constexpr
(
Split
)
{
int
row_offset_scores_split
;
if
constexpr
(
Is_Varlen
)
{
row_offset_lse
=
bidh
*
ngroups
*
params
.
total_q
+
binfo
.
sum_s_q
+
m_block
*
kBlockM
;
row_offset_scores_split
=
split_id
*
(
params
.
h
*
ngroups
*
params
.
total_q
);
softmax_lse_ptr
=
reinterpret_cast
<
ElementAccum
*>
(
params
.
softmax_lseaccum_ptr
)
+
row_offset_lse
+
row_offset_scores_split
;
}
else
{
row_offset_lse
=
(
bidb
*
params
.
h
+
bidh
)
*
params
.
seqlen_q
+
m_block
*
kBlockM
;
row_offset_scores_split
=
split_id
*
(
params
.
b
*
params
.
h
*
params
.
seqlen_q
);
scores_sum_ptr
=
reinterpret_cast
<
ElementAccum
*>
(
params
.
scores_sum_ptr
)
+
row_offset_lse
+
row_offset_scores_split
;
scores_max_ptr
=
reinterpret_cast
<
ElementAccum
*>
(
params
.
scores_max_ptr
)
+
row_offset_lse
+
row_offset_scores_split
;
}
}
else
{
if
constexpr
(
Is_Varlen
)
{
row_offset_lse
=
bidh
*
ngroups
*
params
.
total_q
+
binfo
.
sum_s_q
+
m_block
*
kBlockM
;
softmax_lse_ptr
=
reinterpret_cast
<
ElementAccum
*>
(
params
.
softmax_lse_ptr
)
+
row_offset_lse
;
}
else
{
row_offset_lse
=
(
bidb
*
params
.
h
+
bidh
)
*
params
.
seqlen_q
+
m_block
*
kBlockM
;
softmax_lse_ptr
=
reinterpret_cast
<
ElementAccum
*>
(
params
.
softmax_lse_ptr
)
+
row_offset_lse
;
}
}
constexpr
int
M_WARP_COUNT
=
WARP_M
/
32
;
constexpr
int
K_WARP_COUNT
=
kBlockK
/
32
;
constexpr
int
N_WARP_COUNT
=
WARP_N
/
32
;
constexpr
int
K_LOOP_COUNT
=
kHeadDimVSplit
/
kBlockK
;
vec2_Accum
<
ElementAccum
>
scores_max
[
M_WARP_COUNT
];
vec2_Accum
<
ElementAccum
>
scores_sum
[
M_WARP_COUNT
];
vec4_Accum
<
ElementAccum
>
acc_o
[
K_LOOP_COUNT
*
M_WARP_COUNT
*
K_WARP_COUNT
][
4
];
union_vec16_fp8
q_reg
[
M_MMAC_COUNT
][
kHeadDim
/
64
];
attention_initialize
<
K_LOOP_COUNT
,
M_WARP_COUNT
,
K_WARP_COUNT
,
M_MMAC_COUNT
,
ElementAccum
>
(
scores_max
,
scores_sum
,
acc_o
);
if
constexpr
(
kHeadDim
==
192
&&
kHeadDimV
==
128
)
{
fp8_mha_prefetch_q_to_vgpr_hdim192_gfx938
<
kBlockM
,
WARP_M
,
M_MMAC_COUNT
,
Element
>
(
q_addr
,
q_lds
,
q_reg
,
warp_id
,
query_seqlen_stride
,
actual_seqlen_q
-
m_block
*
kBlockM
);
}
else
{
fp8_mha_prefetch_q_to_vgpr_gfx938
<
kHeadDim
,
kBlockM
,
WARP_M
,
M_MMAC_COUNT
,
Element
>
(
q_addr
,
q_lds
,
q_reg
,
warp_id
,
query_seqlen_stride
,
actual_seqlen_q
-
m_block
*
kBlockM
);
}
int
n_block_loop
=
n_block_min
;
constexpr
bool
PrefetchK
=
true
;
if
constexpr
(
PrefetchK
)
{
int
warp_seqkv_limit
=
binfo
.
actual_seqlen_k
-
n_block_min
*
kBlockN
;
fp8_kvcache_prefetch_k_gfx938
<
WARP_NUM
,
Element
>
(
k_addr
,
k_lds
,
warp_id
,
kcache_seqlen_stride
,
warp_seqkv_limit
);
}
for
(;
n_block_loop
<
n_block_max
;
++
n_block_loop
)
{
const
int
warp_offset_in_seqkv
=
n_block_loop
*
kBlockN
+
warp_id
*
WARP_N
;
const
int
warp_seqkv_limit
=
binfo
.
actual_seqlen_k
-
n_block_loop
*
kBlockN
;
constexpr
bool
PrefetchVInQK
=
(
kHeadDim
==
128
&&
K_LOOP_COUNT
==
2
);
if
constexpr
(
!
PrefetchK
)
{
fp8_kvcache_prefetch_k_gfx938
<
WARP_NUM
,
Element
>
(
k_addr
,
k_lds
,
warp_id
,
kcache_seqlen_stride
,
warp_seqkv_limit
);
}
vec4_Accum
<
ElementAccum
>
s_reg
[
M_WARP_COUNT
*
N_WARP_COUNT
][
4
];
fp8_kvcache_qk_gemm_gfx938
<
PrefetchVInQK
,
K_LOOP_COUNT
,
kHeadDim
,
kBlockK
,
WARP_M
,
WARP_N
,
WARP_NUM
,
M_MMAC_COUNT
,
Element
,
ElementAccum
>
(
k_addr
,
v_addr
,
k_lds
,
v_lds
,
q_reg
,
s_reg
,
warp_id
,
kcache_seqlen_stride
,
vcache_seqlen_stride
,
warp_seqkv_limit
);
if
constexpr
(
!
PrefetchVInQK
)
{
fp8_kvcache_prefetch_v_gfx938
<
K_LOOP_COUNT
,
kBlockK
,
WARP_NUM
,
Element
>
(
v_addr
,
v_lds
,
warp_id
,
vcache_seqlen_stride
,
warp_seqkv_limit
);
}
fp8_kvcache_apply_descale_gfx938
<
vec4_Accum
<
ElementAccum
>
,
M_WARP_COUNT
,
N_WARP_COUNT
,
M_MMAC_COUNT
>
(
s_reg
,
qk_descale
);
if
constexpr
(
Is_causal
)
{
if
constexpr
(
Is_Varlen
)
{
if
constexpr
(
Is_local
)
{
fp8_kvcache_apply_mask_local_causal_gfx938
<
vec4_Accum
<
ElementAccum
>
,
M_WARP_COUNT
,
N_WARP_COUNT
,
M_MMAC_COUNT
>
(
s_reg
,
warp_offset_in_seqkv
+
this_split_seqlen_start
,
original_actual_seqlen_k
,
m_block
*
kBlockM
,
actual_seqlen_q
,
ngroups
,
params
.
window_size_left
,
params
.
window_size_right
);
}
else
{
kvcache_apply_mask_causal_gfx938
<
vec4_Accum
<
ElementAccum
>
,
M_WARP_COUNT
,
N_WARP_COUNT
,
M_MMAC_COUNT
>
(
s_reg
,
warp_offset_in_seqkv
+
this_split_seqlen_start
,
original_actual_seqlen_k
,
m_block
*
kBlockM
,
actual_seqlen_q
,
ngroups
);
}
}
else
{
kvcache_apply_mask_causal_gfx938_mtp
<
vec4_Accum
<
ElementAccum
>
,
M_WARP_COUNT
,
N_WARP_COUNT
,
M_MMAC_COUNT
>
(
s_reg
,
warp_offset_in_seqkv
+
this_split_seqlen_start
,
original_actual_seqlen_k
,
m_block
*
kBlockM
,
actual_seqlen_q
,
params
.
mtp
,
params
.
layout
);
}
}
else
{
kvcache_apply_mask_gfx938
<
vec4_Accum
<
ElementAccum
>
,
M_WARP_COUNT
,
N_WARP_COUNT
,
M_MMAC_COUNT
>
(
s_reg
,
warp_seqkv_limit
,
warp_id
*
WARP_N
);
}
mla_softmax_rescale_o
<
Is_causal
||
Is_local
,
ElementAccum
,
K_LOOP_COUNT
,
K_WARP_COUNT
,
M_WARP_COUNT
,
N_WARP_COUNT
,
WARP_NUM
,
M_MMAC_COUNT
>
(
s_reg
,
scores_max
,
scores_sum
,
acc_o
,
max_lds
,
warp_id
,
params
.
scale_softmax_log2
);
union_vec32_fp8
p_reg
[
M_MMAC_COUNT
];
fp8_kvcache_cvt_f32_to_fp8_gfx938
<
M_MMAC_COUNT
,
Element
,
ElementAccum
>
(
p_reg
,
s_reg
);
const
int
block_table_idx_cur
=
n_block_loop
*
kBlockN
/
params
.
page_block_size
;
const
int
block_table_offset_cur
=
n_block_loop
*
kBlockN
-
block_table_idx_cur
*
params
.
page_block_size
;
const
int
block_table_idx_next
=
min
(
n_block_max
-
1
,
n_block_loop
+
1
)
*
kBlockN
/
params
.
page_block_size
;
const
int
block_table_offset_next
=
min
(
n_block_max
-
1
,
n_block_loop
+
1
)
*
kBlockN
-
block_table_idx_next
*
params
.
page_block_size
;
const
int
table_diff
=
block_table
[
block_table_idx_next
]
-
block_table
[
block_table_idx_cur
];
const
int
offset_diff
=
block_table_offset_next
-
block_table_offset_cur
;
const
int64_t
k_addr_offset
=
(
int64_t
(
table_diff
)
*
int64_t
(
params
.
k_batch_stride
)
+
offset_diff
*
int64_t
(
params
.
k_row_stride
))
*
sizeof
(
Element
);
fp8_kvcache_pv_gemm_fp8_prefetch_k_gfx938
<
PrefetchK
,
K_LOOP_COUNT
,
kBlockK
,
kBlockN
,
M_WARP_COUNT
,
K_WARP_COUNT
,
WARP_NUM
,
M_MMAC_COUNT
,
Element
,
ElementAccum
>
(
v_addr
,
k_addr
,
v_lds
,
k_lds
,
p_reg
,
acc_o
,
warp_id
,
kcache_seqlen_stride
,
vcache_seqlen_stride
,
warp_seqkv_limit
,
k_addr_offset
);
*
(
int64_t
*
)
&
v_addr
+=
(
int64_t
(
table_diff
)
*
int64_t
(
params
.
v_batch_stride
)
+
offset_diff
*
int64_t
(
params
.
v_row_stride
))
*
sizeof
(
Element
);
}
if
constexpr
(
PrefetchK
)
{
flash
::
wait_buffer_data_arrived
<
false
/*sync*/
>
(
0
);
}
flash
::
wait_lds_data_arrived
<
true
/*sync*/
>
(
0
);
const
int
thread_id
=
threadIdx
.
x
;
const
int
lane_id
=
thread_id
&
63
;
if
constexpr
(
WARP_NUM
>
1
)
{
fp8_kvcache_acco_reduce_compact_gfx938
<
REUSE_KV_TIMES
,
K_LOOP_COUNT
,
K_WARP_COUNT
,
M_WARP_COUNT
,
M_MMAC_COUNT
,
WARP_NUM
,
ElementAccum
>
(
acc_o
,
acc_o_lds
,
params
.
seqlen_q
,
warp_id
,
lane_id
);
}
if
(
params
.
s_aux_ptr
!=
nullptr
&&
split_id
==
0
)
{
fp8_kvcache_apply_attention_sink_gfx938
<
K_LOOP_COUNT
,
M_WARP_COUNT
,
K_WARP_COUNT
,
M_MMAC_COUNT
,
ElementAccum
>
(
acc_o
,
scores_max
,
scores_sum
,
params
.
s_aux_ptr
,
params
.
s_aux_type
,
bidh
,
params
.
h
,
ngroups
,
m_block
,
kBlockM
,
lane_id
,
params
.
scale_softmax
);
}
fp8_kvcache_epilogue_rescale_acco_gfx938
<
K_LOOP_COUNT
,
M_WARP_COUNT
,
K_WARP_COUNT
,
M_MMAC_COUNT
,
ElementAccum
>
(
acc_o
,
scores_sum
,
v_descale
);
if
constexpr
(
Is_Varlen
)
{
kvcache_epilogue_store_softmax_lse
<
Is_Varlen
,
true
,
M_WARP_COUNT
,
M_MMAC_COUNT
,
ElementAccum
>
(
scores_max
,
scores_sum
,
softmax_lse_ptr
,
params
.
scale_softmax
,
warp_id
,
thread_id
,
lane_id
,
headdim_split_id
,
actual_seqlen_q
-
m_block
*
kBlockM
,
params
.
total_q
,
params
.
ngroups
);
const
int64_t
row_offset_o
=
binfo
.
sum_s_q
*
ngroups
*
int64_t
(
params
.
o_row_stride
)
+
bidh
*
ngroups
*
params
.
o_head_stride
+
headdim_split_id
*
kHeadDimVSplit
+
m_block
*
kBlockM
*
int64_t
(
params
.
o_row_stride
);
kvcache_varlen_epilogue_store_output_gfx938
<
Params
,
kHeadDimV
,
kHeadDimVSplit
,
Split
,
SplitkvAccumType
,
ElementAccum
,
kBlockM
,
kBlockK
,
WARP_NUM
,
K_LOOP_COUNT
,
M_WARP_COUNT
,
K_WARP_COUNT
,
M_MMAC_COUNT
>
(
acc_o
,
params
,
row_offset_o
,
actual_seqlen_q
-
m_block
*
kBlockM
,
bidb
,
bidh
,
m_block
,
split_id
,
headdim_split_id
,
warp_id
,
lane_id
);
}
else
{
kvcache_epilogue_store_max_sum
<
Split
,
true
/*Is_16x32*/
,
M_WARP_COUNT
,
M_MMAC_COUNT
,
ElementAccum
>
(
scores_max
,
scores_sum
,
scores_max_ptr
,
scores_sum_ptr
,
params
.
scale_softmax
,
warp_id
,
thread_id
,
lane_id
,
headdim_split_id
,
actual_seqlen_q
-
m_block
*
kBlockM
);
kvcache_epilogue_store_output_gfx938
<
Params
,
kHeadDimV
,
kHeadDimVSplit
,
true
/*alt*/
,
Split
,
SplitkvAccumType
,
ElementAccum
,
kBlockM
,
kBlockK
,
WARP_NUM
,
K_LOOP_COUNT
,
M_WARP_COUNT
,
K_WARP_COUNT
,
M_MMAC_COUNT
>
(
acc_o
,
params
,
bidb
,
bidh
,
m_block
,
split_id
,
headdim_split_id
,
warp_id
,
lane_id
);
}
}
template
<
typename
Kernel_traits
,
bool
Is_causal
,
bool
Is_Varlen
,
bool
Split
,
bool
Is_local
,
int
M_MMAC_COUNT
,
int
REUSE_KV_TIMES
,
int
HEADDIM_V_SPLIT
,
int
Partition_Size
,
typename
Params
>
inline
__device__
void
compute_attn_splitkv_fp8_gfx938
(
const
Params
&
params
)
{
#if defined(__gfx938__)
// The block index for the head.
const
int
bidh
=
Split
?
blockIdx
.
z
%
params
.
h
:
blockIdx
.
y
;
// batch x num_head, num_head first
// The block index for the batch.
const
int
bidb
=
Split
?
blockIdx
.
z
/
params
.
h
:
blockIdx
.
z
;
int
warp_id_vec
=
threadIdx
.
x
/
64
;
// warp id in a block
int
warp_id
=
__builtin_amdgcn_readfirstlane
(
warp_id_vec
);
flash
::
compute_attn_1rowblock_splitkv_fp8_gfx938
<
Kernel_traits
,
Is_causal
,
Is_Varlen
,
Split
,
Is_local
,
M_MMAC_COUNT
,
REUSE_KV_TIMES
,
HEADDIM_V_SPLIT
,
Partition_Size
*
128
,
Params
>
(
params
,
bidb
,
bidh
,
warp_id
);
#endif
}
template
<
typename
Kernel_traits
,
bool
Is_causal
,
bool
Is_Varlen
,
bool
Split
,
bool
Is_local
,
int
M_MMAC_COUNT
,
int
REUSE_KV_TIMES
,
int
HEADDIM_V_SPLIT
,
int
Partition_Size
,
typename
Params
>
inline
__device__
void
compute_attn_splitkv_fp8_gfx938_hdim192_v128
(
const
Params
&
params
)
{
#if defined(__gfx938__)
static_assert
(
Kernel_traits
::
kHeadDim
==
192
&&
Kernel_traits
::
kHeadDimV
==
128
);
static_assert
(
HEADDIM_V_SPLIT
==
1
);
const
int
bidh
=
Split
?
blockIdx
.
z
%
params
.
h
:
blockIdx
.
y
;
const
int
bidb
=
Split
?
blockIdx
.
z
/
params
.
h
:
blockIdx
.
z
;
int
warp_id_vec
=
threadIdx
.
x
/
64
;
int
warp_id
=
__builtin_amdgcn_readfirstlane
(
warp_id_vec
);
flash
::
compute_attn_1rowblock_splitkv_fp8_gfx938
<
Kernel_traits
,
Is_causal
,
Is_Varlen
,
Split
,
Is_local
,
M_MMAC_COUNT
,
REUSE_KV_TIMES
,
HEADDIM_V_SPLIT
,
Partition_Size
*
128
,
Params
>
(
params
,
bidb
,
bidh
,
warp_id
);
#endif
}
}
// namespace flash
csrc/flash_attn_hg/src/flash_fwd_launch_template.h
View file @
518a5f4d
...
...
@@ -32,6 +32,19 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_kernel_gfx938(Flash_fwd_para
flash
::
compute_attn_gfx938
<
Kernel_traits
,
true
,
Is_dropout
,
Is_causal
,
Is_local
,
Is_even_MN
,
Is_even_K
,
Is_Varlen
,
Return_softmax
,
Has_alibi
,
Layout
>
(
params
);
}
template
<
typename
Kernel_traits
,
bool
Is_dropout
,
bool
Is_causal
,
bool
Is_local
,
bool
Is_even_MN
,
bool
Is_even_K
,
bool
Is_Varlen
,
bool
Return_softmax
,
bool
Has_alibi
,
bool
Is_GQA
,
int
Layout
>
__global__
void
__launch_bounds__
(
256
,
1
)
flash_fp8_fwd_kernel_gfx938
(
Flash_fwd_params
params
)
{
static_assert
(
!
(
Is_causal
&&
Is_local
));
// If Is_local is true, Is_causal should be false
flash
::
compute_fp8_attn_gfx938
<
Kernel_traits
,
true
,
Is_dropout
,
Is_causal
,
Is_local
,
Is_even_MN
,
Is_even_K
,
Is_Varlen
,
Return_softmax
,
Has_alibi
,
Is_GQA
,
Layout
>
(
params
);
}
template
<
typename
Kernel_traits
,
bool
Is_dropout
,
bool
Is_causal
,
bool
Is_local
,
bool
Is_even_MN
,
bool
Is_even_K
,
bool
Is_Varlen
,
bool
Return_softmax
,
bool
Has_alibi
,
int
Layout
>
__global__
void
__launch_bounds__
(
256
,
1
)
flash_fwd_kernel_gfx92a
(
Flash_fwd_params
params
)
{
static_assert
(
!
(
Is_causal
&&
Is_local
));
flash
::
compute_attn_gfx92a
<
Kernel_traits
,
true
,
Is_dropout
,
Is_causal
,
Is_local
,
Is_even_MN
,
Is_even_K
,
Is_Varlen
,
Return_softmax
,
Has_alibi
,
Layout
>
(
params
);
}
template
<
typename
Kernel_traits
,
bool
Is_dropout
,
bool
Is_causal
>
void
run_flash_fwd
(
Flash_fwd_params
&
params
,
hipStream_t
stream
)
{
...
...
@@ -130,6 +143,101 @@ void run_flash_fwd_gfx938(Flash_fwd_params ¶ms, hipStream_t stream) {
});
}
template
<
typename
Kernel_traits
,
bool
Is_dropout
,
bool
Is_causal
>
void
run_flash_fp8_fwd_gfx938
(
Flash_fwd_params
&
params
,
hipStream_t
stream
)
{
auto
&
instance
=
DeviceProperties
<
Kernel_traits
,
FAFUNC
::
FORWARD
,
true
/*MLS_Enabled*/
>::
GetInstance
();
params
.
cu_count
=
instance
.
cu_count
;
const
bool
is_gqa
=
params
.
h
!=
params
.
h_k
;
const
bool
is_swa
=
((
params
.
window_size_left
>=
0
||
params
.
window_size_right
>=
0
)
&&
!
Is_causal
);
int
num_m_block
=
(
params
.
seqlen_q
+
Kernel_traits
::
kBlockM
-
1
)
/
Kernel_traits
::
kBlockM
;
dim3
grid
(
num_m_block
,
params
.
h
*
Kernel_traits
::
SplitD
,
params
.
b
);
if
(
Is_causal
)
{
if
(
is_gqa
)
{
grid
=
dim3
(
params
.
h
*
Kernel_traits
::
SplitD
,
params
.
b
,
num_m_block
);
}
else
{
grid
.
x
=
(
params
.
seqlen_q
+
2
*
Kernel_traits
::
kBlockM
-
1
)
/
(
2
*
Kernel_traits
::
kBlockM
);
}
}
const
bool
is_varlen
=
params
.
cu_seqlens_q
!=
nullptr
&&
params
.
cu_seqlens_k
!=
nullptr
;
const
bool
is_even_MN
=
params
.
seqlen_k
%
Kernel_traits
::
kBlockN
==
0
&&
params
.
seqlen_q
%
Kernel_traits
::
kBlockM
==
0
&&
(
!
is_varlen
||
params
.
b
==
1
);
const
bool
has_alibi
=
(
params
.
alibi_slopes_ptr
not_eq
nullptr
);
BOOL_SWITCH
(
is_even_MN
,
IsEvenMNConst
,
[
&
]
{
BOOL_SWITCH
(
is_varlen
,
Is_Varlen
,
[
&
]
{
BOOL_SWITCH
(
is_gqa
,
Is_GQA
,
[
&
]
{
constexpr
int
IsEvenKConst
=
true
;
BOOL_SWITCH
(
is_swa
,
Is_local
,
[
&
]
{
BOOL_SWITCH
(
has_alibi
,
Has_Alibi
,
[
&
]{
constexpr
bool
ReturnSoftmaxConst
=
false
;
LAYOUT_SWITCH
(
params
.
layout
,
[
&
]{
auto
kernel
=
&
flash_fp8_fwd_kernel_gfx938
<
Kernel_traits
,
Is_dropout
,
Is_causal
,
Is_local
&&!
Is_causal
,
IsEvenMNConst
&&
IsEvenKConst
&&
!
Is_local
&&
!
ReturnSoftmaxConst
&&
Kernel_traits
::
kHeadDim
<=
256
,
true
/*Is_even_K*/
,
Is_Varlen
,
ReturnSoftmaxConst
&&
Is_dropout
,
Has_Alibi
,
Is_GQA
,
Layout
>
;
kernel
<<<
grid
,
Kernel_traits
::
kNThreads
,
16
*
1024
,
stream
>>>
(
params
);
});
});
});
});
});
});
}
template
<
typename
Kernel_traits
,
bool
Is_dropout
,
bool
Is_causal
>
void
run_flash_fwd_gfx92a
(
Flash_fwd_params
&
params
,
hipStream_t
stream
)
{
auto
&
instance
=
DeviceProperties
<
Kernel_traits
,
FAFUNC
::
FORWARD
,
true
/*MLS_Enabled*/
>::
GetInstance
();
params
.
cu_count
=
instance
.
cu_count
;
size_t
smem_size
=
instance
.
lds_size
;
const
char
*
fa_debug
=
std
::
getenv
(
"FA_DEBUG"
);
const
bool
do_fa_debug
=
fa_debug
!=
nullptr
;
if
(
do_fa_debug
)
{
printf
(
"[gfx92a launch] gcn_arch=%d cu_count=%d smem_size=%zu q_smem=%zu k_smem=%zu v_smem=%zu seqlen_q=%d seqlen_k=%d h=%d b=%d causal=%d layout=%d
\n
"
,
instance
.
gcn_arch
,
params
.
cu_count
,
smem_size
,
Kernel_traits
::
q_smem_size
,
Kernel_traits
::
k_smem_size
,
Kernel_traits
::
v_smem_size
,
params
.
seqlen_q
,
params
.
seqlen_k
,
params
.
h
,
params
.
b
,
static_cast
<
int
>
(
Is_causal
),
params
.
layout
);
}
const
bool
is_swa
=
((
params
.
window_size_left
>=
0
||
params
.
window_size_right
>=
0
)
&&
!
Is_causal
);
int
num_m_block
=
(
params
.
seqlen_q
+
Kernel_traits
::
kBlockM
-
1
)
/
Kernel_traits
::
kBlockM
;
dim3
grid
(
num_m_block
,
params
.
h
*
Kernel_traits
::
SplitD
,
params
.
b
);
if
constexpr
(
Is_causal
)
{
grid
=
dim3
(
params
.
h
*
Kernel_traits
::
SplitD
,
params
.
b
,
num_m_block
);
}
const
bool
is_varlen
=
params
.
cu_seqlens_q
!=
nullptr
&&
params
.
cu_seqlens_k
!=
nullptr
;
const
bool
is_even_MN
=
!
is_varlen
&&
params
.
seqlen_k
%
Kernel_traits
::
kBlockN
==
0
&&
params
.
seqlen_q
%
Kernel_traits
::
kBlockM
==
0
;
const
bool
is_even_K
=
params
.
d
==
Kernel_traits
::
kHeadDim
;
const
bool
has_alibi
=
(
params
.
alibi_slopes_ptr
not_eq
nullptr
);
BOOL_SWITCH
(
is_even_MN
,
IsEvenMNConst
,
[
&
]
{
constexpr
int
IsEvenKConst
=
true
;
BOOL_SWITCH
(
is_varlen
,
Is_Varlen
,
[
&
]
{
BOOL_SWITCH
(
is_swa
,
Is_local
,
[
&
]
{
BOOL_SWITCH
(
has_alibi
,
Has_Alibi
,
[
&
]{
constexpr
bool
ReturnSoftmaxConst
=
false
;
LAYOUT_SWITCH
(
params
.
layout
,
[
&
]{
auto
kernel
=
&
flash_fwd_kernel_gfx92a
<
Kernel_traits
,
Is_dropout
,
Is_causal
,
Is_local
&&!
Is_causal
,
IsEvenMNConst
&&
IsEvenKConst
&&
!
Is_local
&&
!
ReturnSoftmaxConst
&&
Kernel_traits
::
kHeadDim
<=
256
,
true
/*Is_even_K*/
,
Is_Varlen
,
ReturnSoftmaxConst
&&
Is_dropout
,
Has_Alibi
,
Layout
>
;
kernel
<<<
grid
,
Kernel_traits
::
kNThreads
,
smem_size
,
stream
>>>
(
params
);
});
});
});
});
});
}
template
<
typename
T
>
...
...
@@ -210,7 +318,9 @@ void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, hipStream_t stream) {
BOOL_SWITCH
(
params
.
is_causal
,
Is_causal
,
[
&
]
{
// if arch >= 938, new MLS is allowed
int
gcn_arch
=
getArch
();
if
(
gcn_arch
>=
938
and
std
::
getenv
(
"FA_FWD_NO_MLS"
)
==
nullptr
)
{
if
(
gcn_arch
==
930
and
std
::
getenv
(
"FA_FWD_NO_MLS"
)
==
nullptr
)
{
run_flash_fwd_gfx92a
<
Flash_fwd_kernel_traits
<
Headdim
,
Headdim
,
128
,
128
,
32
,
32
,
32
,
2
,
false
,
false
,
T
>
,
Is_dropout
,
Is_causal
>
(
params
,
stream
);
}
else
if
(
gcn_arch
>=
938
and
std
::
getenv
(
"FA_FWD_NO_MLS"
)
==
nullptr
)
{
if
(
params
.
qkvheaddim_compute
==
96
)
{
if
(
params
.
qkvheaddim_tail_tile16
==
1
)
run_flash_fwd_gfx938
<
Flash_fwd_kernel_traits
<
Headdim
,
Headdim
,
128
,
128
,
32
,
32
,
32
,
2
,
false
,
false
,
T
,
T
,
T
,
64
,
96
,
96
,
1
>
,
Is_dropout
,
Is_causal
>
(
params
,
stream
);
...
...
@@ -234,6 +344,20 @@ void run_mha_fwd_hdim128(Flash_fwd_params ¶ms, hipStream_t stream) {
}
}
template
<
typename
T
>
void
run_fp8_mha_fwd_hdim128
(
Flash_fwd_params
&
params
,
hipStream_t
stream
)
{
constexpr
int
Headdim
=
128
;
constexpr
bool
Is_dropout
=
false
;
BOOL_SWITCH
(
params
.
is_causal
,
Is_causal
,
[
&
]
{
int
gcn_arch
=
getArch
();
if
(
gcn_arch
>=
938
)
{
run_flash_fp8_fwd_gfx938
<
Flash_fwd_kernel_traits
<
Headdim
,
Headdim
,
128
,
128
,
32
,
32
,
32
,
2
,
false
,
false
,
T
,
Float16
,
fp8_e4m3
>
,
Is_dropout
,
Is_causal
>
(
params
,
stream
);
}
else
{
printf
(
"
\x1b
[31mfp8 is not supported in this arch!
\033
[0m
\n
"
);
}
});
}
template
<
typename
T
>
void
run_mha_fwd_hdim160
(
Flash_fwd_params
&
params
,
hipStream_t
stream
)
{
...
...
@@ -412,7 +536,7 @@ void run_flash_fwd_prefix_prefill_launcher(Flash_fwd_params ¶ms, hipStream_t
void
(
*
kernel
)(
Flash_fwd_params
);
constexpr
bool
IsEvenMNConst
=
false
;
BOOL_SWITCH
(
params
.
window_size_left
>
0
and
params
.
window_size_right
>=
0
,
Is_local
,
[
&
]{
BOOL_SWITCH
(
!
Is_causal
&&
params
.
window_size_left
>
0
and
params
.
window_size_right
>=
0
,
Is_local
,
[
&
]{
kernel
=
&
flash_fwd_prefix_prefill_kernel
<
Kernel_traits
,
false
/*dropout*/
,
Is_causal
,
Is_local
,
IsEvenMNConst
,
false
/*return softmax*/
,
false
/*Has_Alibi*/
,
false
/*Is_GQA*/
,
1
/*layout*/
,
Flash_fwd_params
>
;
});
kernel
<<<
grid
,
Kernel_traits
::
kNThreads
,
smem_size
,
stream
>>>
(
params
);
...
...
@@ -428,7 +552,24 @@ void run_flash_fwd_prefix_prefill_gfx938_launcher(Flash_fwd_params ¶ms, hipS
dim3
grid
(
params
.
h
,
params
.
b
,
num_m_block
);
constexpr
bool
IsEvenMNConst
=
false
;
auto
kernel
=
&
flash_fwd_prefix_prefill_gfx938_kernel
<
Kernel_traits
,
false
,
false
/*dropout*/
,
Is_causal
,
false
/*Is_local*/
,
IsEvenMNConst
,
true
,
false
/*return softmax*/
,
false
/*Has_Alibi*/
,
1
/*layout*/
,
Flash_fwd_params
>
;
const
bool
is_local
=
!
Is_causal
&&
params
.
window_size_left
>
0
&&
params
.
window_size_right
>=
0
;
BOOL_SWITCH
(
is_local
,
Is_local
,
[
&
]
{
auto
kernel
=
&
flash_fwd_prefix_prefill_gfx938_kernel
<
Kernel_traits
,
false
,
false
/*dropout*/
,
Is_causal
,
Is_local
,
IsEvenMNConst
,
true
,
false
/*return softmax*/
,
false
/*Has_Alibi*/
,
1
/*layout*/
,
Flash_fwd_params
>
;
kernel
<<<
grid
,
Kernel_traits
::
kNThreads
,
smem_size
,
stream
>>>
(
params
);
});
}
template
<
typename
Kernel_traits
,
bool
Is_causal
>
void
run_flash_fwd_prefix_prefill_gfx92a_launcher
(
Flash_fwd_params
&
params
,
hipStream_t
stream
)
{
auto
&
instance
=
DeviceProperties
<
Kernel_traits
,
FAFUNC
::
FORWARD
,
true
/*MLS_enabled*/
>::
GetInstance
();
size_t
smem_size
=
instance
.
lds_size
;
int
num_m_block
=
(
params
.
seqlen_q
+
Kernel_traits
::
kBlockM
-
1
)
/
Kernel_traits
::
kBlockM
;
dim3
grid
(
params
.
h
,
params
.
b
,
num_m_block
);
constexpr
bool
IsEvenMNConst
=
false
;
auto
kernel
=
&
flash_fwd_prefix_prefill_gfx92a_kernel
<
Kernel_traits
,
false
,
false
/*dropout*/
,
Is_causal
,
false
/*Is_local*/
,
IsEvenMNConst
,
true
,
false
/*return softmax*/
,
false
/*Has_Alibi*/
,
1
/*layout*/
,
Flash_fwd_params
>
;
kernel
<<<
grid
,
Kernel_traits
::
kNThreads
,
smem_size
,
stream
>>>
(
params
);
}
...
...
@@ -437,9 +578,18 @@ template<typename T, int Headdim, int HeaddimV>
void
run_flash_fwd_prefix_prefill
(
Flash_fwd_params
&
params
,
hipStream_t
stream
)
{
// is_causal = false, used in cascade attention
BOOL_SWITCH
(
params
.
is_causal
,
Is_causal
,
[
&
]
{
if
(
getArch
()
>=
938
and
std
::
getenv
(
"FA_FWD_NO_MLS"
)
==
nullptr
and
((
Headdim
==
128
and
HeaddimV
==
128
)
or
(
Headdim
==
192
and
HeaddimV
==
128
)))
{
const
bool
use_mls_prefix
=
std
::
getenv
(
"FA_FWD_NO_MLS"
)
==
nullptr
&&
((
Headdim
==
128
and
HeaddimV
==
128
)
or
(
Headdim
==
192
and
HeaddimV
==
128
));
const
int
gcn_arch
=
getArch
();
if
(
gcn_arch
==
930
and
use_mls_prefix
)
{
if
constexpr
(
Headdim
==
192
)
run_flash_fwd_prefix_prefill_gfx92a_launcher
<
Flash_fwd_kernel_traits
<
Headdim
,
HeaddimV
,
128
,
64
,
32
,
32
,
32
,
2
,
false
,
false
,
T
>
,
Is_causal
>
(
params
,
stream
);
else
run_flash_fwd_prefix_prefill_gfx92a_launcher
<
Flash_fwd_kernel_traits
<
Headdim
,
HeaddimV
,
128
,
128
,
32
,
32
,
32
,
2
,
false
,
false
,
T
>
,
Is_causal
>
(
params
,
stream
);
}
else
if
(
gcn_arch
>=
938
and
use_mls_prefix
)
{
if
constexpr
(
Headdim
==
192
)
run_flash_fwd_prefix_prefill_gfx938_launcher
<
Flash_fwd_kernel_traits
<
Headdim
,
HeaddimV
,
128
,
64
,
32
,
32
,
32
,
2
,
false
,
false
,
T
>
,
Is_causal
>
(
params
,
stream
);
else
if
(
params
.
page_block_size
==
64
)
run_flash_fwd_prefix_prefill_gfx938_launcher
<
Flash_fwd_kernel_traits
<
Headdim
,
HeaddimV
,
128
,
64
,
32
,
32
,
32
,
2
,
false
,
false
,
T
>
,
Is_causal
>
(
params
,
stream
);
else
run_flash_fwd_prefix_prefill_gfx938_launcher
<
Flash_fwd_kernel_traits
<
Headdim
,
HeaddimV
,
128
,
128
,
32
,
32
,
32
,
2
,
false
,
false
,
T
>
,
Is_causal
>
(
params
,
stream
);
}
else
{
...
...
@@ -486,3 +636,49 @@ void run_int8_flash_fwd_prefix_prefill(Flash_fwd_params ¶ms, hipStream_t str
run_int8_flash_fwd_prefix_prefill_launcher
<
Flash_fwd_kernel_traits
<
Headdim
,
HeaddimV
,
128
,
128
,
32
,
32
,
32
,
2
,
false
,
false
,
T
,
Float16
,
int8_t
>
,
Is_causal
>
(
params
,
stream
);
});
}
template
<
typename
Kernel_traits
,
bool
Is_causal
>
void
run_flash_fp8_fwd_prefix_prefill_launcher_gfx938
(
Flash_fwd_params
&
params
,
hipStream_t
stream
)
{
constexpr
bool
NeedsWideFp8MlsLds
=
Kernel_traits
::
kHeadDim
>
128
||
Kernel_traits
::
kHeadDimV
>
128
;
size_t
smem_size
=
NeedsWideFp8MlsLds
?
32
*
1024
:
16
*
1024
;
int
num_m_block
=
(
params
.
seqlen_q
+
Kernel_traits
::
kBlockM
-
1
)
/
Kernel_traits
::
kBlockM
;
dim3
grid
=
Is_causal
?
dim3
(
params
.
h
,
params
.
b
,
num_m_block
)
:
dim3
(
num_m_block
,
params
.
h
,
params
.
b
);
constexpr
bool
Has_Alibi
=
false
;
const
bool
is_local
=
!
Is_causal
&&
params
.
window_size_left
>
0
&&
params
.
window_size_right
>=
0
;
BOOL_SWITCH
(
params
.
softmax_lse_ptr
!=
nullptr
,
ReturnSoftmaxConst
,
[
&
]
{
BOOL_SWITCH
(
is_local
,
IsLocalConst
,
[
&
]
{
LAYOUT_SWITCH
(
params
.
layout
,
[
&
]{
auto
kernel
=
&
flash_fp8_fwd_prefix_prefill_kernel_gfx938
<
Kernel_traits
,
true
/*Is_training*/
,
false
/*Is_dropout*/
,
Is_causal
,
IsLocalConst
,
true
/*Is_even_K*/
,
ReturnSoftmaxConst
,
Has_Alibi
,
Layout
>
;
kernel
<<<
grid
,
Kernel_traits
::
kNThreads
,
smem_size
,
stream
>>>
(
params
);
});
});
});
}
template
<
typename
T
,
int
Headdim
,
int
HeaddimV
>
void
run_fp8_flash_fwd_prefix_prefill
(
Flash_fwd_params
&
params
,
hipStream_t
stream
)
{
int
gcn_arch
=
getArch
();
if
(
gcn_arch
>=
938
)
{
BOOL_SWITCH
(
params
.
is_causal
,
Is_causal
,
[
&
]
{
if
(
params
.
page_block_size
==
64
)
{
if
(
params
.
seqlen_q
<=
128
)
{
run_flash_fp8_fwd_prefix_prefill_launcher_gfx938
<
Flash_fwd_kernel_traits
<
Headdim
,
HeaddimV
,
64
,
64
,
32
,
32
,
32
,
2
,
false
,
false
,
T
,
Float16
,
fp8_e4m3
>
,
Is_causal
>
(
params
,
stream
);
}
else
{
run_flash_fp8_fwd_prefix_prefill_launcher_gfx938
<
Flash_fwd_kernel_traits
<
Headdim
,
HeaddimV
,
128
,
64
,
32
,
32
,
32
,
2
,
false
,
false
,
T
,
Float16
,
fp8_e4m3
>
,
Is_causal
>
(
params
,
stream
);
}
}
else
{
run_flash_fp8_fwd_prefix_prefill_launcher_gfx938
<
Flash_fwd_kernel_traits
<
Headdim
,
HeaddimV
,
128
,
128
,
32
,
32
,
32
,
2
,
false
,
false
,
T
,
Float16
,
fp8_e4m3
>
,
Is_causal
>
(
params
,
stream
);
}
});
}
else
{
printf
(
"
\x1b
[31mfp8 prefix_prefill is not supported in this arch!
\033
[0m
\n
"
);
}
}
csrc/flash_attn_hg/src/flash_fwd_launch_template_pa.h
View file @
518a5f4d
...
...
@@ -9,6 +9,15 @@
#include "flash_fwd_kernel.h"
#include "flash_singleton.h"
#include "assert.h"
#include <string>
static
inline
bool
hg_pa_is_gfx92a
(
const
std
::
string
&
gcn_arch_name
)
{
return
gcn_arch_name
.
rfind
(
"gfx92a"
,
0
)
==
0
;
}
static
inline
int
hg_pa_runtime_gfx_arch_id
(
const
std
::
string
&
gcn_arch_name
)
{
return
hg_pa_is_gfx92a
(
gcn_arch_name
)
?
930
:
std
::
stoi
(
gcn_arch_name
.
substr
(
3
,
3
));
}
template
<
typename
Kernel_traits
,
bool
Is_dropout
,
bool
Is_causal
,
bool
Is_Varlen
,
bool
Is_local
,
bool
Is_even_MN
,
bool
Is_even_K
,
bool
Return_softmax
,
bool
Has_alibi
,
bool
Is_GQA
,
bool
Is_softcap
,
bool
Split
,
int
M_MMAC_COUNT
,
int
REUSE_KV_TIMES
,
bool
Append_KV
>
...
...
@@ -23,18 +32,32 @@ __global__ void __launch_bounds__(256,1) flash_fwd_splitkv_int8_kernel(Flash_fwd
}
template
<
typename
Kernel_traits
,
bool
Is_causal
,
bool
Is_Varlen
,
bool
Split
,
bool
Is_local
,
int
M_MMAC_COUNT
,
int
REUSE_KV_TIMES
,
int
HEADDIM_V_SPLIT
,
int
Partition_Size
,
typename
Params
>
template
<
typename
Kernel_traits
,
bool
Is_causal
,
bool
Is_Varlen
,
bool
Split
,
bool
Is_local
,
int
M_MMAC_COUNT
,
int
HEADDIM_V_SPLIT
,
int
Partition_Size
,
typename
Params
>
__global__
void
__launch_bounds__
(
256
,
1
)
flash_fwd_splitkv_tile16x32_kernel
(
Params
params
)
{
flash
::
compute_attn_splitkv_tile16x32
<
Kernel_traits
,
Is_causal
,
Is_Varlen
,
Split
,
Is_local
,
M_MMAC_COUNT
,
REUSE_KV_TIMES
,
HEADDIM_V_SPLIT
,
Partition_Size
>
(
params
);
flash
::
compute_attn_splitkv_tile16x32
<
Kernel_traits
,
Is_causal
,
Is_Varlen
,
Split
,
Is_local
,
M_MMAC_COUNT
,
HEADDIM_V_SPLIT
,
Partition_Size
>
(
params
);
}
template
<
typename
Kernel_traits
,
bool
Is_causal
,
bool
Is_Varlen
,
bool
Split
,
int
M_MMAC_COUNT
,
int
REUSE_KV_TIMES
,
int
HEADDIM_V_SPLIT
,
int
Partition_Size
,
typename
Params
>
template
<
typename
Kernel_traits
,
bool
Is_causal
,
bool
Is_Varlen
,
bool
Split
,
int
M_MMAC_COUNT
,
int
HEADDIM_V_SPLIT
,
int
Partition_Size
,
typename
Params
>
__global__
void
__launch_bounds__
(
256
,
1
)
flash_fwd_splitkv_gfx938_kernel
(
Params
params
)
{
flash
::
compute_attn_splitkv_gfx938
<
Kernel_traits
,
Is_causal
,
Is_Varlen
,
Split
,
M_MMAC_COUNT
,
REUSE_KV_TIMES
,
HEADDIM_V_SPLIT
,
Partition_Size
>
(
params
);
flash
::
compute_attn_splitkv_gfx938
<
Kernel_traits
,
Is_causal
,
Is_Varlen
,
Split
,
M_MMAC_COUNT
,
HEADDIM_V_SPLIT
,
Partition_Size
>
(
params
);
}
template
<
typename
Kernel_traits
,
bool
Is_causal
,
bool
Is_Varlen
,
bool
Split
,
bool
Is_local
,
int
M_MMAC_COUNT
,
int
REUSE_KV_TIMES
,
int
HEADDIM_V_SPLIT
,
int
Partition_Size
,
typename
Params
>
__global__
void
__launch_bounds__
(
256
,
1
)
flash_fwd_splitkv_fp8_gfx938_kernel
(
Params
params
)
{
flash
::
compute_attn_splitkv_fp8_gfx938
<
Kernel_traits
,
Is_causal
,
Is_Varlen
,
Split
,
Is_local
,
M_MMAC_COUNT
,
REUSE_KV_TIMES
,
HEADDIM_V_SPLIT
,
Partition_Size
>
(
params
);
}
template
<
typename
Kernel_traits
,
bool
Is_causal
,
bool
Is_Varlen
,
bool
Split
,
bool
Is_local
,
int
M_MMAC_COUNT
,
int
REUSE_KV_TIMES
,
int
HEADDIM_V_SPLIT
,
int
Partition_Size
,
typename
Params
>
__global__
void
__launch_bounds__
(
256
,
1
)
flash_fwd_splitkv_fp8_gfx938_hdim192_v128_kernel
(
Params
params
)
{
flash
::
compute_attn_splitkv_fp8_gfx938_hdim192_v128
<
Kernel_traits
,
Is_causal
,
Is_Varlen
,
Split
,
Is_local
,
M_MMAC_COUNT
,
REUSE_KV_TIMES
,
HEADDIM_V_SPLIT
,
Partition_Size
>
(
params
);
}
template
<
typename
Kernel_traits
,
bool
Is_causal
,
bool
Is_Varlen
,
bool
Is_monopolize
,
bool
Split
,
int
M_MMAC_COUNT
,
int
HEADDIM_V_SPLIT
,
typename
Params
>
__global__
void
__launch_bounds__
(
256
,
1
)
flash_fwd_splitkv_gfx92a_kernel
(
Params
params
)
{
flash
::
compute_attn_splitkv_gfx92a
<
Kernel_traits
,
Is_causal
,
Is_Varlen
,
Is_monopolize
,
Split
,
M_MMAC_COUNT
,
HEADDIM_V_SPLIT
>
(
params
);
}
template
<
typename
Kernel_traits
,
const
bool
Tail
,
typename
Params
>
void
run_splitkv_reduce
(
Params
&
params
,
hipStream_t
stream
)
{
...
...
@@ -157,7 +180,7 @@ void run_flash_splitkv_fwd(Flash_fwd_params ¶ms, hipStream_t stream) {
int
gcn_arch
=
props
.
gcnArch
;
#else
std
::
string
gcn_arch_name
(
props
.
gcnArchName
);
int
gcn_arch
=
std
::
stoi
(
gcn_arch_name
.
substr
(
3
,
3
)
);
int
gcn_arch
=
hg_pa_runtime_gfx_arch_id
(
gcn_arch_name
);
#endif
const
size_t
smem_size
=
gcn_arch
>
928
?
required_smem_size
:
size_t
(
64
*
1024
);
if
(
std
::
getenv
(
"FA_DEBUG"
)
!=
nullptr
)
{
...
...
@@ -224,7 +247,7 @@ void run_flash_splitkv_fwd_tile16x32(Flash_fwd_params ¶ms, hipStream_t strea
int
gcn_arch
=
props
.
gcnArch
;
#else
std
::
string
gcn_arch_name
(
props
.
gcnArchName
);
int
gcn_arch
=
std
::
stoi
(
gcn_arch_name
.
substr
(
3
,
3
)
);
int
gcn_arch
=
hg_pa_runtime_gfx_arch_id
(
gcn_arch_name
);
#endif
const
size_t
smem_size
=
gcn_arch
>
928
?
size_t
(
std
::
max
<
size_t
>
(
32
*
1024
,
required_smem_size
))
:
size_t
(
64
*
1024
);
if
(
std
::
getenv
(
"FA_DEBUG"
)
!=
nullptr
)
{
...
...
@@ -245,14 +268,16 @@ void run_flash_splitkv_fwd_tile16x32(Flash_fwd_params ¶ms, hipStream_t strea
BOOL_SWITCH
(
params
.
q_batch_stride
==
0
,
Is_Varlen
,
[
&
]
{
if
(
params
.
window_size_left
>
0
and
params
.
window_size_right
>=
0
)
{
M_MMAC_COUNT_SWITCH
(
params
.
seqlen_q
>
16
,
M_MMAC_COUNT
,
[
&
]
{
kernel
=
&
flash_fwd_splitkv_tile16x32_kernel
<
Kernel_traits
,
true
/*Is_causal*/
,
Is_Varlen
,
false
,
true
/*Is_local*/
,
M_MMAC_COUNT
,
0
,
HEADDIM_V_SPLIT
,
0
>
;
BOOL_SWITCH
(
params
.
num_splits
>
1
,
Split
,
[
&
]
{
kernel
=
&
flash_fwd_splitkv_tile16x32_kernel
<
Kernel_traits
,
true
/*Is_causal*/
,
Is_Varlen
,
Split
,
true
/*Is_local*/
,
M_MMAC_COUNT
,
HEADDIM_V_SPLIT
,
0
>
;
});
});
}
else
if
(
params
.
mtp
==
1
)
{
M_MMAC_COUNT_SWITCH
(
params
.
seqlen_q
>
16
,
M_MMAC_COUNT
,
[
&
]
{
BOOL_SWITCH
(
params
.
num_splits
>
1
,
Split
,
[
&
]
{
constexpr
int
Partition_Size
=
0
;
// pa adopt floor in splitkv, need to process tails, and thus cannot use partition_size unroll
REUSEKV_SWITCH
(
params
.
seqlen_q
,
[
&
]
{
kernel
=
&
flash_fwd_splitkv_tile16x32_kernel
<
Kernel_traits
,
false
/*Is_causal*/
,
Is_Varlen
,
Split
,
false
/*Is_local*/
,
M_MMAC_COUNT
,
REUSE_KV_TIMES
,
HEADDIM_V_SPLIT
,
Partition_Size
>
;
kernel
=
&
flash_fwd_splitkv_tile16x32_kernel
<
Kernel_traits
,
false
/*Is_causal*/
,
Is_Varlen
,
Split
,
false
/*Is_local*/
,
M_MMAC_COUNT
,
HEADDIM_V_SPLIT
,
Partition_Size
>
;
});
});
});
...
...
@@ -261,7 +286,9 @@ void run_flash_splitkv_fwd_tile16x32(Flash_fwd_params ¶ms, hipStream_t strea
BOOL_SWITCH
(
params
.
num_splits
>
1
,
Split
,
[
&
]
{
constexpr
int
Partition_Size
=
0
;
// pa adopt floor in splitkv, need to process tails, and thus cannot use partition_size unroll
PA_MTP_REUSEKV_SWITCH
(
params
.
seqlen_q
,
[
&
]
{
kernel
=
&
flash_fwd_splitkv_tile16x32_kernel
<
Kernel_traits
,
true
/*Is_causal*/
,
Is_Varlen
,
Split
,
false
/*Is_local*/
,
M_MMAC_COUNT
,
REUSE_KV_TIMES
,
HEADDIM_V_SPLIT
,
Partition_Size
>
;
BOOL_SWITCH
(
params
.
is_causal
,
Is_causal
,
[
&
]
{
kernel
=
&
flash_fwd_splitkv_tile16x32_kernel
<
Kernel_traits
,
Is_causal
,
Is_Varlen
,
Split
,
false
/*Is_local*/
,
M_MMAC_COUNT
,
HEADDIM_V_SPLIT
,
Partition_Size
>
;
});
});
});
});
...
...
@@ -311,29 +338,157 @@ void run_flash_splitkv_fwd_gfx938(Flash_fwd_params ¶ms, hipStream_t stream)
const
size_t
required_smem_size
=
std
::
max
(
smem_for_acc
,
std
::
max
(
smem_for_gemm
,
smem_for_max
));
const
size_t
smem_size
=
size_t
(
std
::
max
<
size_t
>
(
32
*
1024
,
required_smem_size
));
// compute block partition along seqlen_q direction
const
int
num_m_block
=
(
params
.
seqlen_q
+
Kernel_traits
::
kBlockM
-
1
)
/
Kernel_traits
::
kBlockM
;
// decide task dispatch logic
dim3
grid
(
num_m_block
,
params
.
num_splits
>
1
?
params
.
num_splits
:
params
.
h
,
params
.
num_splits
>
1
?
params
.
b
*
params
.
h
:
params
.
b
);
if
(
std
::
getenv
(
"FA_DEBUG"
)
!=
nullptr
)
{
printf
(
"smem_for_max: %ld | smem_for_acc: %ld | q_smem: %ld k_smem: %ld v_smem: %ld | smem_for_gemm: %ld | needed required_smem_size: %ld | smem_size: %ld
\n
"
,
smem_for_max
,
smem_for_acc
,
q_smem_size
,
k_smem_size
,
v_smem_size
,
smem_for_gemm
,
required_smem_size
,
smem_size
);
printf
(
"grid: (%d, %d, %d)
\n
"
,
grid
.
x
,
grid
.
y
,
grid
.
z
);
}
// acquire kernel fuction
void
(
*
kernel
)(
Flash_fwd_params
);
constexpr
int
HEADDIM_V_SPLIT
=
Kernel_traits
::
kHeadDimV
==
256
?
2
:
1
;
grid
.
x
=
num_m_block
*
HEADDIM_V_SPLIT
;
BOOL_SWITCH
(
params
.
q_batch_stride
==
0
,
Is_Varlen
,
[
&
]
{
BOOL_SWITCH
(
params
.
mtp
!=
1
,
Is_causal
,
[
&
]{
M_MMAC_COUNT_SWITCH
(
params
.
seqlen_q
>
16
,
M_MMAC_COUNT
,
[
&
]
{
BOOL_SWITCH
(
params
.
num_splits
>
1
,
Split
,
[
&
]
{
constexpr
int
Partition_Size
=
0
;
// pa adopt floor in splitkv, need to process tails, and thus cannot use partition_size unroll
kernel
=
&
flash_fwd_splitkv_gfx938_kernel
<
Kernel_traits
,
Is_causal
,
Is_Varlen
,
Split
,
M_MMAC_COUNT
,
HEADDIM_V_SPLIT
,
Partition_Size
>
;
});
});
});
});
// Kernel execution
const
int
nthread
=
Kernel_traits
::
kBlockN
/
Kernel_traits
::
kWaveN
*
64
;
kernel
<<<
grid
,
nthread
,
smem_size
,
stream
>>>
(
params
);
// reduce PA v2
if
(
params
.
q_batch_stride
==
0
)
{
run_splitkv_reduce_varlen
<
Kernel_traits
,
false
/*Tail*/
>
(
params
,
stream
);
}
else
{
run_splitkv_reduce
<
Kernel_traits
,
true
/*Tail*/
>
(
params
,
stream
);
}
}
template
<
typename
Kernel_traits
>
void
run_flash_splitkv_fwd_gfx92a
(
Flash_fwd_params
&
params
,
hipStream_t
stream
)
{
// compute block partition along seqlen_q direction
const
int
num_m_block
=
(
params
.
seqlen_q
+
Kernel_traits
::
kBlockM
-
1
)
/
Kernel_traits
::
kBlockM
;
// decide task dispatch logic
dim3
grid
(
num_m_block
,
params
.
num_splits
>
1
?
params
.
num_splits
:
params
.
h
,
params
.
num_splits
>
1
?
params
.
b
*
params
.
h
:
params
.
b
);
// decide shared memory
size_t
smem_size
=
32768
;
if
(
grid
.
x
*
grid
.
y
*
grid
.
z
<=
params
.
cu_count
)
{
smem_size
=
65536
;
}
if
(
std
::
getenv
(
"PA_NO_ALL_LDS"
)
!=
nullptr
)
{
smem_size
=
32768
;
}
if
(
std
::
getenv
(
"PA_USE_ALL_LDS"
)
!=
nullptr
)
{
smem_size
=
65536
;
}
// output some details
if
(
std
::
getenv
(
"FA_DEBUG"
)
!=
nullptr
)
{
printf
(
"smem_size: %ld
\n
"
,
smem_size
);
printf
(
"grid: (%d, %d, %d)
\n
"
,
grid
.
x
,
grid
.
y
,
grid
.
z
);
}
// acquire kernel fuction
void
(
*
kernel
)(
Flash_fwd_params
);
constexpr
int
HEADDIM_V_SPLIT
=
1
;
// no need to split-D
constexpr
int
HEADDIM_V_SPLIT
=
Kernel_traits
::
kHeadDimV
==
256
?
2
:
1
;
grid
.
x
=
num_m_block
*
HEADDIM_V_SPLIT
;
BOOL_SWITCH
(
params
.
q_batch_stride
==
0
,
Is_Varlen
,
[
&
]
{
if
(
params
.
mtp
==
1
)
{
BOOL_SWITCH
(
params
.
mtp
!=
1
,
Is_causal
,
[
&
]
{
M_MMAC_COUNT_SWITCH
(
params
.
seqlen_q
>
16
,
M_MMAC_COUNT
,
[
&
]
{
BOOL_SWITCH
(
params
.
num_splits
>
1
,
Split
,
[
&
]
{
BOOL_SWITCH
(
smem_size
==
65536
,
Is_monopolize
,
[
&
]{
kernel
=
&
flash_fwd_splitkv_gfx92a_kernel
<
Kernel_traits
,
Is_causal
,
Is_Varlen
,
Is_monopolize
,
Split
,
M_MMAC_COUNT
,
HEADDIM_V_SPLIT
>
;
});
});
});
});
});
// Kernel execution
const
int
nthread
=
Kernel_traits
::
kBlockN
/
Kernel_traits
::
kWaveN
*
64
;
kernel
<<<
grid
,
nthread
,
smem_size
,
stream
>>>
(
params
);
// reduce PA v2
if
(
params
.
q_batch_stride
==
0
)
{
run_splitkv_reduce_varlen
<
Kernel_traits
,
false
/*Tail*/
>
(
params
,
stream
);
}
else
{
run_splitkv_reduce
<
Kernel_traits
,
true
/*Tail*/
>
(
params
,
stream
);
}
}
template
<
typename
Kernel_traits
>
void
run_fp8_flash_splitkv_fwd_gfx938
(
Flash_fwd_params
&
params
,
hipStream_t
stream
)
{
constexpr
int
WARP_NUM
=
Kernel_traits
::
kBlockN
/
Kernel_traits
::
kWaveN
;
constexpr
int
kReduceBlockK
=
32
;
const
size_t
smem_for_max
=
std
::
max
(
WARP_NUM
*
Kernel_traits
::
kWaveM
*
sizeof
(
float
),
size_t
(
1024
));
const
size_t
smem_for_acc
=
Kernel_traits
::
kBlockM
*
WARP_NUM
*
kReduceBlockK
*
sizeof
(
float
);
const
size_t
q_smem_size
=
Kernel_traits
::
STAGES
*
Kernel_traits
::
kBlockM
*
Kernel_traits
::
kBlockK
*
sizeof
(
Float8_e4m3_t
);
const
size_t
k_smem_size
=
Kernel_traits
::
STAGES
*
Kernel_traits
::
kBlockK
*
Kernel_traits
::
kWaveN
*
sizeof
(
Float8_e4m3_t
)
*
WARP_NUM
;
const
size_t
v_smem_size
=
k_smem_size
;
const
size_t
smem_for_gemm
=
std
::
max
(
q_smem_size
,
std
::
max
(
k_smem_size
,
v_smem_size
));
constexpr
bool
IsFp8PA192x128
=
Kernel_traits
::
kHeadDim
==
192
&&
Kernel_traits
::
kHeadDimV
==
128
;
const
size_t
required_smem_size
=
IsFp8PA192x128
?
std
::
max
(
smem_for_acc
,
std
::
max
(
smem_for_gemm
,
smem_for_max
))
:
std
::
max
(
smem_for_acc
,
smem_for_gemm
+
smem_for_max
);
const
size_t
smem_size_floor
=
IsFp8PA192x128
?
size_t
(
32
*
1024
)
:
size_t
(
17
*
1024
);
const
size_t
smem_size
=
size_t
(
std
::
max
(
smem_size_floor
,
required_smem_size
));
if
(
std
::
getenv
(
"FA_DEBUG"
)
!=
nullptr
)
{
printf
(
"smem_for_max: %ld | smem_for_acc: %ld | q_smem: %ld k_smem: %ld v_smem: %ld | smem_for_gemm: %ld | needed required_smem_size: %ld | smem_size: %ld
\n
"
,
smem_for_max
,
smem_for_acc
,
q_smem_size
,
k_smem_size
,
v_smem_size
,
smem_for_gemm
,
required_smem_size
,
smem_size
);
}
// compute block partition along seqlen_q direction
const
int
num_m_block
=
(
params
.
seqlen_q
+
Kernel_traits
::
kBlockM
-
1
)
/
Kernel_traits
::
kBlockM
;
// decide task dispatch logic
dim3
grid
(
num_m_block
,
params
.
num_splits
>
1
?
params
.
num_splits
:
params
.
h
,
params
.
num_splits
>
1
?
params
.
b
*
params
.
h
:
params
.
b
);
// acquire kernel fuction
void
(
*
kernel
)(
Flash_fwd_params
);
constexpr
int
HEADDIM_V_SPLIT
=
Kernel_traits
::
kHeadDimV
==
256
?
2
:
1
;
grid
.
x
=
num_m_block
*
HEADDIM_V_SPLIT
;
BOOL_SWITCH
(
params
.
q_batch_stride
==
0
,
Is_Varlen
,
[
&
]
{
if
(
params
.
window_size_left
>
0
&&
params
.
window_size_right
>=
0
)
{
M_MMAC_COUNT_SWITCH
(
params
.
seqlen_q
>
16
,
M_MMAC_COUNT
,
[
&
]
{
BOOL_SWITCH
(
params
.
num_splits
>
1
,
Split
,
[
&
]
{
constexpr
int
Partition_Size
=
0
;
// pa adopt floor in splitkv, need to process tails, and thus cannot use partition_size unroll
PA_MTP_REUSEKV_SWITCH
(
params
.
seqlen_q
,
[
&
]
{
if
constexpr
(
IsFp8PA192x128
)
{
kernel
=
&
flash_fwd_splitkv_fp8_gfx938_hdim192_v128_kernel
<
Kernel_traits
,
true
/*Is_causal*/
,
Is_Varlen
,
Split
,
true
/*Is_local*/
,
M_MMAC_COUNT
,
REUSE_KV_TIMES
,
HEADDIM_V_SPLIT
,
Partition_Size
>
;
}
else
{
kernel
=
&
flash_fwd_splitkv_fp8_gfx938_kernel
<
Kernel_traits
,
true
/*Is_causal*/
,
Is_Varlen
,
Split
,
true
/*Is_local*/
,
M_MMAC_COUNT
,
REUSE_KV_TIMES
,
HEADDIM_V_SPLIT
,
Partition_Size
>
;
}
});
});
});
}
else
if
(
params
.
mtp
==
1
)
{
M_MMAC_COUNT_SWITCH
(
params
.
seqlen_q
>
16
,
M_MMAC_COUNT
,
[
&
]
{
BOOL_SWITCH
(
params
.
num_splits
>
1
,
Split
,
[
&
]
{
constexpr
int
Partition_Size
=
0
;
// pa adopt floor in splitkv, need to process tails, and thus cannot use partition_size unroll
REUSEKV_SWITCH
(
params
.
seqlen_q
,
[
&
]
{
kernel
=
&
flash_fwd_splitkv_gfx938_kernel
<
Kernel_traits
,
false
/*Is_causal*/
,
Is_Varlen
,
Split
,
M_MMAC_COUNT
,
REUSE_KV_TIMES
,
HEADDIM_V_SPLIT
,
Partition_Size
>
;
constexpr
bool
Is_local
=
false
;
if
constexpr
(
IsFp8PA192x128
)
{
kernel
=
&
flash_fwd_splitkv_fp8_gfx938_hdim192_v128_kernel
<
Kernel_traits
,
false
/*Is_causal*/
,
Is_Varlen
,
Split
,
Is_local
,
M_MMAC_COUNT
,
REUSE_KV_TIMES
,
HEADDIM_V_SPLIT
,
Partition_Size
>
;
}
else
{
kernel
=
&
flash_fwd_splitkv_fp8_gfx938_kernel
<
Kernel_traits
,
false
/*Is_causal*/
,
Is_Varlen
,
Split
,
Is_local
,
M_MMAC_COUNT
,
REUSE_KV_TIMES
,
HEADDIM_V_SPLIT
,
Partition_Size
>
;
}
});
});
});
...
...
@@ -342,7 +497,11 @@ void run_flash_splitkv_fwd_gfx938(Flash_fwd_params ¶ms, hipStream_t stream)
BOOL_SWITCH
(
params
.
num_splits
>
1
,
Split
,
[
&
]
{
constexpr
int
Partition_Size
=
0
;
// pa adopt floor in splitkv, need to process tails, and thus cannot use partition_size unroll
PA_MTP_REUSEKV_SWITCH
(
params
.
seqlen_q
,
[
&
]
{
kernel
=
&
flash_fwd_splitkv_gfx938_kernel
<
Kernel_traits
,
true
/*Is_causal*/
,
Is_Varlen
,
Split
,
M_MMAC_COUNT
,
REUSE_KV_TIMES
,
HEADDIM_V_SPLIT
,
Partition_Size
>
;
if
constexpr
(
IsFp8PA192x128
)
{
kernel
=
&
flash_fwd_splitkv_fp8_gfx938_hdim192_v128_kernel
<
Kernel_traits
,
true
/*Is_causal*/
,
Is_Varlen
,
Split
,
false
/*Is_local*/
,
M_MMAC_COUNT
,
REUSE_KV_TIMES
,
HEADDIM_V_SPLIT
,
Partition_Size
>
;
}
else
{
kernel
=
&
flash_fwd_splitkv_fp8_gfx938_kernel
<
Kernel_traits
,
true
/*Is_causal*/
,
Is_Varlen
,
Split
,
false
/*Is_local*/
,
M_MMAC_COUNT
,
REUSE_KV_TIMES
,
HEADDIM_V_SPLIT
,
Partition_Size
>
;
}
});
});
});
...
...
@@ -366,8 +525,34 @@ template<typename T, int Headdim, int HeaddimV>
void
run_mha_fwd_splitkv_dispatch
(
Flash_fwd_params
&
params
,
hipStream_t
stream
)
{
// decide whether commonly used headdims
const
bool
is_commonly_used
=
params
.
d
%
64
==
0
and
params
.
d_value
%
64
==
0
/*prefetch 2 32x32 blocks along headdim*/
;
// For latest archs, mls can be applied for headdim 128
if
((
getArch
()
>=
938
)
and
std
::
getenv
(
"PA_NO_MLS"
)
==
nullptr
and
is_commonly_used
)
{
// For latest archs, MLS can be applied for the common decode head dims.
int
arch_id
=
getArch
();
constexpr
bool
use_gfx938_mls
=
(
Headdim
==
128
and
HeaddimV
==
128
)
or
(
Headdim
==
192
and
HeaddimV
==
128
)
or
(
Headdim
==
256
and
HeaddimV
==
256
);
if
constexpr
(
use_gfx938_mls
)
{
const
bool
is_local
=
params
.
window_size_left
>
0
&&
params
.
window_size_right
>=
0
;
const
bool
use_mls_mask
=
params
.
is_e4m3
?
true
:
params
.
is_causal
;
if
((
arch_id
>=
938
)
and
std
::
getenv
(
"PA_NO_MLS"
)
==
nullptr
and
is_commonly_used
and
use_mls_mask
)
{
if
(
params
.
is_e4m3
)
{
// Decide whether compile all page block sizes
#ifdef PA_PAGE_BLOCK_SIZE
if
(
params
.
page_block_size
%
32
!=
0
)
{
printf
(
"
\x1b
[31mPage block size %d is not supported yet!
\033
[0m
\n
"
,
params
.
page_block_size
);
return
;
}
PA_PAGEBLOCKSIZE_SWITCH
(
params
.
page_block_size
,
[
&
]{
run_fp8_flash_splitkv_fwd_gfx938
<
Flash_fwd_kernel_traits
<
Headdim
,
HeaddimV
,
32
,
kBlockN
,
64
,
32
,
32
,
2
/*STAGES*/
,
false
,
false
,
T
,
T
>
>
(
params
,
stream
);
});
#else
if
(
params
.
page_block_size
==
64
)
{
constexpr
int
kBlockN
=
64
;
run_fp8_flash_splitkv_fwd_gfx938
<
Flash_fwd_kernel_traits
<
Headdim
,
HeaddimV
,
32
,
kBlockN
,
64
,
32
,
32
,
2
/*STAGES*/
,
false
,
false
,
T
,
T
>
>
(
params
,
stream
);
}
else
{
constexpr
int
kBlockN
=
128
;
if
(
params
.
page_block_size
%
kBlockN
!=
0
)
{
printf
(
"
\x1b
[31mPage block size %d is not supported yet!
\033
[0m
\n
"
,
params
.
page_block_size
);
return
;
}
run_fp8_flash_splitkv_fwd_gfx938
<
Flash_fwd_kernel_traits
<
Headdim
,
HeaddimV
,
32
,
kBlockN
,
64
,
32
,
32
,
2
/*STAGES*/
,
false
,
false
,
T
,
T
>
>
(
params
,
stream
);
}
#endif
}
else
{
// Decide whether compile all page block sizes
#ifdef PA_PAGE_BLOCK_SIZE
if
(
params
.
page_block_size
%
32
!=
0
)
{
printf
(
"
\x1b
[31mPage block size %d is not supported yet!
\033
[0m
\n
"
,
params
.
page_block_size
);
return
;
}
...
...
@@ -375,13 +560,26 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, hipStream_t stream)
run_flash_splitkv_fwd_gfx938
<
Flash_fwd_kernel_traits
<
Headdim
,
HeaddimV
,
32
,
kBlockN
,
32
,
32
,
32
,
2
/*STAGES*/
,
false
,
false
,
T
,
T
>
>
(
params
,
stream
);
});
#else
if
(
params
.
page_block_size
==
64
)
{
constexpr
int
kBlockN
=
64
;
run_flash_splitkv_fwd_gfx938
<
Flash_fwd_kernel_traits
<
Headdim
,
HeaddimV
,
32
,
kBlockN
,
32
,
32
,
32
,
2
/*STAGES*/
,
false
,
false
,
T
,
T
>
>
(
params
,
stream
);
}
else
{
constexpr
int
kBlockN
=
128
;
if
(
params
.
page_block_size
%
kBlockN
!=
0
)
{
printf
(
"
\x1b
[31mPage block size %d is not supported yet!
\033
[0m
\n
"
,
params
.
page_block_size
);
return
;
}
run_flash_splitkv_fwd_gfx938
<
Flash_fwd_kernel_traits
<
Headdim
,
HeaddimV
,
32
,
kBlockN
,
32
,
32
,
32
,
2
/*STAGES*/
,
false
,
false
,
T
,
T
>
>
(
params
,
stream
);
}
#endif
}
return
;
}
else
if
(
arch_id
==
930
and
std
::
getenv
(
"PA_NO_MLS"
)
==
nullptr
and
is_commonly_used
)
{
constexpr
int
kBlockN
=
128
;
if
(
params
.
page_block_size
%
kBlockN
!=
0
)
{
printf
(
"
\x1b
[31mPage block size %d is not supported yet!
\033
[0m
\n
"
,
params
.
page_block_size
);
return
;
}
run_flash_splitkv_fwd_gfx92a
<
Flash_fwd_kernel_traits
<
Headdim
,
HeaddimV
,
32
,
kBlockN
,
32
,
32
,
32
,
2
/*STAGES*/
,
false
,
false
,
T
,
T
>
>
(
params
,
stream
);
return
;
}
}
// For MHA-fma, headdim = 128
else
if
(
params
.
seqlen_q
==
1
and
!
params
.
seqlenq_ngroups_swapped
and
Headdim
==
128
and
HeaddimV
==
128
and
std
::
getenv
(
"PA_USE_FMA"
)
!=
nullptr
)
{
if
(
params
.
seqlen_q
==
1
and
!
params
.
seqlenq_ngroups_swapped
and
Headdim
==
128
and
HeaddimV
==
128
and
std
::
getenv
(
"PA_USE_FMA"
)
!=
nullptr
)
{
constexpr
int
kBlockN
=
128
;
run_flash_splitkv_fwd_mha
<
Flash_fwd_kernel_traits
<
Headdim
,
HeaddimV
,
32
/*kBlockM*/
,
kBlockN
,
32
/*kBlockK*/
,
32
,
32
,
2
/*STAGES*/
,
false
,
false
,
T
,
float
>
>
(
params
,
stream
);
}
...
...
@@ -393,9 +591,14 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, hipStream_t stream)
run_flash_splitkv_fwd_tile16x32
<
Flash_fwd_kernel_traits
<
Headdim
,
HeaddimV
,
32
,
kBlockN
,
32
,
32
,
32
,
2
/*STAGES*/
,
false
,
false
,
T
,
T
>
>
(
params
,
stream
);
});
#else
if
(
params
.
page_block_size
==
64
)
{
constexpr
int
kBlockN
=
64
;
run_flash_splitkv_fwd_tile16x32
<
Flash_fwd_kernel_traits
<
Headdim
,
HeaddimV
,
32
,
kBlockN
,
32
,
32
,
32
,
2
/*STAGES*/
,
false
,
false
,
T
,
T
>
>
(
params
,
stream
);
}
else
{
constexpr
int
kBlockN
=
128
;
if
(
params
.
page_block_size
%
kBlockN
!=
0
)
{
printf
(
"
\x1b
[31mPage block size %d is not supported yet!
\033
[0m
\n
"
,
params
.
page_block_size
);
return
;
}
run_flash_splitkv_fwd_tile16x32
<
Flash_fwd_kernel_traits
<
Headdim
,
HeaddimV
,
32
,
kBlockN
,
32
,
32
,
32
,
2
/*STAGES*/
,
false
,
false
,
T
,
T
>
>
(
params
,
stream
);
}
#endif
}
else
{
// Decide whether compile all page block sizes
...
...
@@ -407,6 +610,15 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, hipStream_t stream)
run_flash_splitkv_fwd
<
Flash_fwd_kernel_traits
<
Headdim
,
HeaddimV
,
32
,
kBlockN
,
32
,
32
,
32
,
STAGES
,
false
,
false
,
T
,
T
>
>
(
params
,
stream
);
});
#else
if
(
params
.
page_block_size
==
64
)
{
constexpr
int
kBlockN
=
64
;
constexpr
int
STAGES
=
(
Headdim
==
128
)
?
3
:
(
Headdim
==
32
?
1
:
2
);
if
(
params
.
splitkv_use_fp32_as_accum
)
{
run_flash_splitkv_fwd
<
Flash_fwd_kernel_traits
<
Headdim
,
HeaddimV
,
32
,
kBlockN
,
32
,
32
,
32
,
STAGES
,
false
,
false
,
T
,
float
>
>
(
params
,
stream
);
}
else
{
run_flash_splitkv_fwd
<
Flash_fwd_kernel_traits
<
Headdim
,
HeaddimV
,
32
,
kBlockN
,
32
,
32
,
32
,
STAGES
,
false
,
false
,
T
,
T
>
>
(
params
,
stream
);
}
}
else
{
constexpr
int
kBlockN
=
128
;
if
(
params
.
page_block_size
%
kBlockN
!=
0
)
{
printf
(
"
\x1b
[31mPage block size %d is not supported yet!
\033
[0m
\n
"
,
params
.
page_block_size
);
return
;
}
constexpr
int
STAGES
=
(
Headdim
==
128
)
?
3
:
(
Headdim
==
32
?
1
:
2
);
...
...
@@ -415,6 +627,7 @@ void run_mha_fwd_splitkv_dispatch(Flash_fwd_params ¶ms, hipStream_t stream)
}
else
{
run_flash_splitkv_fwd
<
Flash_fwd_kernel_traits
<
Headdim
,
HeaddimV
,
32
,
kBlockN
,
32
,
32
,
32
,
STAGES
,
false
,
false
,
T
,
T
>
>
(
params
,
stream
);
}
}
#endif
}
}
...
...
@@ -446,7 +659,7 @@ void run_int8_flash_splitkv_fwd(Flash_fwd_params ¶ms, hipStream_t stream) {
int
gcn_arch
=
props
.
gcnArch
;
#else
std
::
string
gcn_arch_name
(
props
.
gcnArchName
);
int
gcn_arch
=
std
::
stoi
(
gcn_arch_name
.
substr
(
3
,
3
)
);
int
gcn_arch
=
hg_pa_runtime_gfx_arch_id
(
gcn_arch_name
);
#endif
const
size_t
smem_size
=
gcn_arch
>
928
?
required_smem_size
:
size_t
(
48
*
1024
);
// printf("smem_for_max: %ld | smem_for_acc: %ld | smem_for_gemm: %ld | needed smem_size: %ld | smem_size: %ld\n", smem_for_max, smem_for_acc, smem_for_gemm, required_smem_size, smem_size);
...
...
csrc/flash_attn_hg/src/flash_fwd_reduce.h
View file @
518a5f4d
...
...
@@ -102,7 +102,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel(
for
(
int
t
=
0
;
t
<
tx_float_count
;
t
+=
2
)
{
if
constexpr
(
kHeadDim
%
128
==
0
)
{
vec2_Element
<
reduceType
>
accum_result
;
#if defined(__gfx938__)
#if defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
accum_result
=
DownCastPairNoPack
<
float
,
reduceType
>
(
tx_accum
[
t
],
tx_accum
[
t
+
1
]);
#else
accum_result
[
0
]
=
DownCast
<
float
,
reduceType
,
true
>
(
tx_accum
[
t
]);
...
...
@@ -289,7 +289,7 @@ __global__ void flash_fwd_splitkv_reduce_kernel_split128(Params params) {
tx_accum
[
t
]
=
lds
[
tx
*
tx_float_count
+
t
];
tx_accum
[
t
+
1
]
=
lds
[
tx
*
tx_float_count
+
t
+
1
];
vec2_Element
<
reduceType
>
accum_result
;
#if defined(__gfx938__)
#if defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
accum_result
=
DownCastPairNoPack
<
float
,
reduceType
>
(
tx_accum
[
t
],
tx_accum
[
t
+
1
]);
#else
accum_result
[
0
]
=
DownCast
<
float
,
reduceType
,
true
>
(
tx_accum
[
t
]);
...
...
@@ -445,7 +445,7 @@ __global__ void __launch_bounds__(256, 1) flash_fwd_splitkv_reduce_varlen_kernel
for
(
int
t
=
0
;
t
<
tx_float_count
;
t
+=
2
)
{
if
constexpr
(
kHeadDim
%
128
==
0
)
{
vec2_Element
<
reduceType
>
accum_result
;
#if defined(__gfx938__)
#if defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
accum_result
=
DownCastPairNoPack
<
float
,
reduceType
>
(
tx_accum
[
t
],
tx_accum
[
t
+
1
]);
#else
accum_result
[
0
]
=
DownCast
<
float
,
reduceType
,
true
>
(
tx_accum
[
t
]);
...
...
@@ -588,7 +588,7 @@ __global__ void __launch_bounds__(256, 1) flash_mla_splitkv_reduce_kernel(
}
// cvt
vec2_Element
<
reduceType
>
accum_result
;
#if defined(__gfx938__) || defined(__gfx__)
#if defined(__gfx938__) || defined(__gfx
946__) || defined(__gfx92a
__)
accum_result
=
DownCastPairNoPack
<
float
,
reduceType
>
(
tx_accum
[
t
],
tx_accum
[
t
+
1
]);
#else
accum_result
[
0
]
=
DownCast
<
float
,
reduceType
,
true
>
(
tx_accum
[
t
]);
...
...
csrc/flash_attn_hg/src/target/flash_fp8_fwd_hdim128_bf16.cpp
0 → 100644
View file @
518a5f4d
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "../flash_fwd_launch_template.h"
template
<
>
void
run_fp8_mha_fwd_
<
BFloat16
,
128
,
128
>
(
Flash_fwd_params
&
params
,
hipStream_t
stream
)
{
#ifdef BUILD_FA_FWD
run_fp8_mha_fwd_hdim128
<
BFloat16
>
(
params
,
stream
);
#endif
}
csrc/flash_attn_hg/src/target/flash_fp8_fwd_hdim128_fp16.cpp
0 → 100644
View file @
518a5f4d
// Copyright (c) 2023, Tri Dao.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "../flash_fwd_launch_template.h"
template
<
>
void
run_fp8_mha_fwd_
<
Float16
,
128
,
128
>
(
Flash_fwd_params
&
params
,
hipStream_t
stream
)
{
#ifdef BUILD_FA_FWD
run_fp8_mha_fwd_hdim128
<
Float16
>
(
params
,
stream
);
#endif
}
csrc/flash_attn_hg/src/target/flash_fp8_fwd_hdim128_prefix_prefill_bf16.cpp
0 → 100644
View file @
518a5f4d
// Copyright (c) 2025, Wenjian Zhang.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "../flash_fwd_launch_template.h"
template
<
>
void
run_fp8_mha_fwd_prefix_prefill_
<
BFloat16
,
128
,
128
>
(
Flash_fwd_params
&
params
,
hipStream_t
stream
)
{
#ifdef BUILD_FA_FWD
run_fp8_flash_fwd_prefix_prefill
<
BFloat16
,
128
,
128
>
(
params
,
stream
);
#endif
}
\ No newline at end of file
csrc/flash_attn_hg/src/target/flash_fp8_fwd_hdim128_prefix_prefill_fp16.cpp
0 → 100644
View file @
518a5f4d
// Copyright (c) 2025, Wenjian Zhang.
// Splitting the different head dimensions to different files to speed up compilation.
// This file is auto-generated. See "generate_kernels.py"
#include "../flash_fwd_launch_template.h"
template
<
>
void
run_fp8_mha_fwd_prefix_prefill_
<
Float16
,
128
,
128
>
(
Flash_fwd_params
&
params
,
hipStream_t
stream
)
{
#ifdef BUILD_FA_FWD
run_fp8_flash_fwd_prefix_prefill
<
Float16
,
128
,
128
>
(
params
,
stream
);
#endif
}
\ No newline at end of file
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment