Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
flash-attention
Commits
518a5f4d
Commit
518a5f4d
authored
Jun 09, 2026
by
hly
Browse files
import aicc-master-dev
parent
c2a1b310
Changes
131
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
6299 additions
and
4928 deletions
+6299
-4928
csrc/flash_attn/src/flash_fwd_unified_hdim128_bf16_causal_sm80.cu
...sh_attn/src/flash_fwd_unified_hdim128_bf16_causal_sm80.cu
+0
-1
csrc/flash_attn/src/flash_fwd_unified_hdim128_bf16_sm80.cu
csrc/flash_attn/src/flash_fwd_unified_hdim128_bf16_sm80.cu
+0
-1
csrc/flash_attn/src/flash_fwd_unified_hdim128_fp16_causal_sm80.cu
...sh_attn/src/flash_fwd_unified_hdim128_fp16_causal_sm80.cu
+0
-1
csrc/flash_attn/src/flash_fwd_unified_hdim128_fp16_sm80.cu
csrc/flash_attn/src/flash_fwd_unified_hdim128_fp16_sm80.cu
+0
-1
csrc/flash_attn/src/paged_attention.cu
csrc/flash_attn/src/paged_attention.cu
+32
-36
csrc/flash_attn/src/paged_attention_938.cu
csrc/flash_attn/src/paged_attention_938.cu
+30
-27
csrc/flash_attn/src/softmax.h
csrc/flash_attn/src/softmax.h
+58
-0
csrc/flash_attn/src/utils.h
csrc/flash_attn/src/utils.h
+219
-26
csrc/flash_attn_hg/flash_api.cpp
csrc/flash_attn_hg/flash_api.cpp
+4458
-4764
csrc/flash_attn_hg/flash_c_api.h
csrc/flash_attn_hg/flash_c_api.h
+32
-11
csrc/flash_attn_hg/include/bwd/dot_do_o_gfx946.h
csrc/flash_attn_hg/include/bwd/dot_do_o_gfx946.h
+116
-0
csrc/flash_attn_hg/include/bwd/flash_attention_bwd.h
csrc/flash_attn_hg/include/bwd/flash_attention_bwd.h
+3
-0
csrc/flash_attn_hg/include/bwd/flash_attention_dq_bwd_gfx946.h
...flash_attn_hg/include/bwd/flash_attention_dq_bwd_gfx946.h
+336
-0
csrc/flash_attn_hg/include/bwd/flash_attention_dv_dk_bwd_gfx938.h
...sh_attn_hg/include/bwd/flash_attention_dv_dk_bwd_gfx938.h
+7
-4
csrc/flash_attn_hg/include/bwd/flash_attention_dv_dk_bwd_gfx946.h
...sh_attn_hg/include/bwd/flash_attention_dv_dk_bwd_gfx946.h
+781
-0
csrc/flash_attn_hg/include/bwd/gpu_gemm_nn.h
csrc/flash_attn_hg/include/bwd/gpu_gemm_nn.h
+176
-6
csrc/flash_attn_hg/include/bwd/gpu_gemm_tt.h
csrc/flash_attn_hg/include/bwd/gpu_gemm_tt.h
+4
-6
csrc/flash_attn_hg/include/bwd/prefetch.h
csrc/flash_attn_hg/include/bwd/prefetch.h
+16
-40
csrc/flash_attn_hg/include/bwd/softmax_tiling.h
csrc/flash_attn_hg/include/bwd/softmax_tiling.h
+21
-3
csrc/flash_attn_hg/include/flash.h
csrc/flash_attn_hg/include/flash.h
+10
-1
No files found.
csrc/flash_attn/src/flash_fwd_unified_hdim128_bf16_causal_sm80.cu
View file @
518a5f4d
...
...
@@ -5,4 +5,3 @@
#include "flash_fwd_launch_template.h"
template
void
run_mha_fwd_unified_dispatch
<
cutlass
::
bfloat16_t
,
128
,
true
>(
Flash_fwd_params
&
params
,
cudaStream_t
stream
);
csrc/flash_attn/src/flash_fwd_unified_hdim128_bf16_sm80.cu
View file @
518a5f4d
...
...
@@ -5,4 +5,3 @@
#include "flash_fwd_launch_template.h"
template
void
run_mha_fwd_unified_dispatch
<
cutlass
::
bfloat16_t
,
128
,
false
>(
Flash_fwd_params
&
params
,
cudaStream_t
stream
);
csrc/flash_attn/src/flash_fwd_unified_hdim128_fp16_causal_sm80.cu
View file @
518a5f4d
...
...
@@ -5,4 +5,3 @@
#include "flash_fwd_launch_template.h"
template
void
run_mha_fwd_unified_dispatch
<
cutlass
::
half_t
,
128
,
true
>(
Flash_fwd_params
&
params
,
cudaStream_t
stream
);
csrc/flash_attn/src/flash_fwd_unified_hdim128_fp16_sm80.cu
View file @
518a5f4d
...
...
@@ -5,4 +5,3 @@
#include "flash_fwd_launch_template.h"
template
void
run_mha_fwd_unified_dispatch
<
cutlass
::
half_t
,
128
,
false
>(
Flash_fwd_params
&
params
,
cudaStream_t
stream
);
csrc/flash_attn/src/paged_attention.cu
View file @
518a5f4d
...
...
@@ -48,7 +48,7 @@ static __device__ inline float to_float(scalar_t in){
inline
__device__
float
uint82float
(
const
uint8_t
&
input
)
{
#if (defined(__gfx938__) )
#if (defined(__gfx938__)
||defined(__gfx92a__)
)
return
__builtin_hcu_cvt_f32_fp8
(
input
,
false
,
0
,
0
);
#else
const
uint32_t
w
=
(
uint32_t
)
input
<<
24
;
...
...
@@ -137,11 +137,11 @@ __forceinline__ __device__ scalar_t uint82half(const uint8_t& input) {
#define REUSEKV_SWITCH(reusekv,...) \
[&] { \
if (reusekv==4
8
){ \
constexpr static int REUSE_KV_TIMES = 4
8
; \
if (reusekv==
6
4){ \
constexpr static int REUSE_KV_TIMES =
6
4; \
return __VA_ARGS__(); \
}else if (reusekv==
36
){ \
constexpr static int REUSE_KV_TIMES =
36
; \
}else if (reusekv==
48
){ \
constexpr static int REUSE_KV_TIMES =
48
; \
return __VA_ARGS__(); \
}else if (reusekv==32){ \
constexpr static int REUSE_KV_TIMES = 32; \
...
...
@@ -257,18 +257,11 @@ inline __device__ float block_sum(float* red_smem, float sum) {
template
<
bool
is_half
>
inline
__device__
void
builtin_amdgcn_mmac
(
const
half4_t
&
reg_a
,
const
half4_t
&
reg_b
,
float4_t
&
reg_c
)
{
#if (defined(__gfx938__) )
if
constexpr
(
is_half
){
reg_c
=
__builtin_hcu_mmac_f32_16x16x16_f16_lit_lts
(
reg_a
,
reg_b
,
reg_c
,
false
,
false
);}
else
{
reg_c
=
__builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts
(
*
(
v4bh
*
)
&
reg_a
,
*
(
v4bh
*
)
&
reg_b
,
reg_c
,
false
,
false
);
}
#else
if
constexpr
(
is_half
){
reg_c
=
__builtin_amdgcn_mmac_f32_16x16x16f16
(
reg_a
,
reg_b
,
reg_c
);}
else
{
reg_c
=
__builtin_amdgcn_mmac_f32_16x16x16bf16
(
*
(
v4bh
*
)
&
reg_a
,
*
(
v4bh
*
)
&
reg_b
,
reg_c
);
if
constexpr
(
is_half
){
reg_c
=
__builtin_hcu_mmac_f32_16x16x16_f16
(
reg_a
,
reg_b
,
reg_c
);
}
else
{
reg_c
=
__builtin_hcu_mmac_f32_16x16x16_bf16
(
*
(
v4bh
*
)
&
reg_a
,
*
(
v4bh
*
)
&
reg_b
,
reg_c
);
}
#endif
}
...
...
@@ -390,23 +383,28 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
for
(
int
m
=
0
;
m
<
Mloop
;
m
++
){
qk_vec
[
m
]
=
{
0
,
0
,
0
,
0
};
}
half4x2
k_vec
[
HEAD_SIZE
/
32
];
__builtin_amdgcn_sched_barrier
(
0
);
#pragma unroll
for
(
int
i
=
0
;
i
<
HEAD_SIZE
/
32
;
i
++
){
half4x2
k_vec
;
if
constexpr
(
is_fp8
){
uint8x4x2
k_vec_u8
=*
reinterpret_cast
<
const
uint8x4x2
*>
(
k_ptr
+
i
*
32
+
rowid
*
HEAD_SIZE
+
rows
*
8
);
scalar_t
*
p1
=
(
scalar_t
*
)
&
k_vec
;
scalar_t
*
p1
=
(
scalar_t
*
)
(
k_vec
+
i
)
;
uint8_t
*
p2
=
(
uint8_t
*
)
&
k_vec_u8
;
for
(
int
ii
=
0
;
ii
<
8
;
ii
++
){
p1
[
ii
]
=
uint82half
<
scalar_t
,
is_e4m3
>
(
p2
[
ii
]);
}
}
else
{
k_vec
=*
reinterpret_cast
<
const
half4x2
*>
(
k_ptr
+
i
*
32
+
rowid
*
HEAD_SIZE
+
rows
*
8
);
k_vec
[
i
]
=*
reinterpret_cast
<
const
half4x2
*>
(
k_ptr
+
i
*
32
+
rowid
*
HEAD_SIZE
+
rows
*
8
);
}
}
__builtin_amdgcn_sched_barrier
(
0
);
#pragma unroll
for
(
int
i
=
0
;
i
<
HEAD_SIZE
/
32
;
i
++
){
for
(
int
m
=
0
;
m
<
Mloop
;
m
++
){
builtin_amdgcn_mmac
<
is_half
>
(
k_vec
.
data
[
0
],
q_vec
[
m
][
i
].
data
[
0
],
qk_vec
[
m
]);
builtin_amdgcn_mmac
<
is_half
>
(
k_vec
.
data
[
1
],
q_vec
[
m
][
i
].
data
[
1
],
qk_vec
[
m
]);
builtin_amdgcn_mmac
<
is_half
>
(
k_vec
[
i
]
.
data
[
0
],
q_vec
[
m
][
i
].
data
[
0
],
qk_vec
[
m
]);
builtin_amdgcn_mmac
<
is_half
>
(
k_vec
[
i
]
.
data
[
1
],
q_vec
[
m
][
i
].
data
[
1
],
qk_vec
[
m
]);
}
}
#pragma unroll
...
...
@@ -597,7 +595,7 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
}
const
cache_t
*
v_ptr
=
v_cache
+
physical_block_number
*
kv_block_stride
+
kv_head_idx
*
kv_head_stride
;
if
(
partition_idx
<
num_partitions
-
1
){
if
(
partition_idx
<
num_partitions
-
1
||
block_idx
<
num_seq_blocks
-
1
){
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
int
offset
=
i
*
BLOCK_SIZE
*
HEAD_SIZE
/
NUM_ROWS_PER_THREAD
+
warp_idx
*
BLOCK_SIZE
*
HEAD_SIZE
/
NUM_ROWS_PER_THREAD
/
NUM_WARPS
+
rows
*
vecsize
*
4
+
rowid
*
BLOCK_SIZE
;
...
...
@@ -635,13 +633,11 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
v_vec
=*
reinterpret_cast
<
const
half4_vec
*>
(
v_ptr
+
offset
);
}
//这里的if判断会影响一定的性能,因此只有最后一个patition才判断
if
(
block_idx
==
num_seq_blocks
-
1
)
{
scalar_t
*
v_vec_ptr
=
reinterpret_cast
<
scalar_t
*>
(
&
v_vec
);
#pragma unroll
for
(
int
j
=
0
;
j
<
4
*
vecsize
;
j
++
)
{
v_vec_ptr
[
j
]
=
token_idx
+
j
<
seq_len
?
v_vec_ptr
[
j
]
:
0
;
}
}
for
(
int
ii
=
0
;
ii
<
vecsize
;
ii
++
){
for
(
int
m
=
0
;
m
<
Mloop
;
m
++
){
builtin_amdgcn_mmac
<
is_half
>
(
v_vec
.
data
[
ii
],
logits_vec
[
m
].
data
[
ii
],
accs
[
m
][
i
]);
...
...
@@ -756,8 +752,8 @@ __global__ __launch_bounds__(NUM_THREADS, 1) void paged_attention_combine(
}
static
int
get_reusekv
(
int
qhead
,
int
kv_head
){
if
(
qhead
>
kv_head
*
36
)
return
4
8
;
if
(
qhead
>
kv_head
*
32
)
return
36
;
//glm4.7 mtp 3
if
(
qhead
>
kv_head
*
48
)
return
6
4
;
if
(
qhead
>
kv_head
*
32
)
return
48
;
if
(
qhead
>
kv_head
*
24
)
return
32
;
if
(
qhead
>
kv_head
*
16
)
return
24
;
if
(
qhead
>
kv_head
*
8
)
return
16
;
...
...
@@ -831,7 +827,7 @@ void paged_attention(
hipMalloc
(
&
tmp_out_ptr
,
temp_out_size
);
// 100m
hipMemset
(
tmp_out_ptr
,
0
,
temp_out_size
);
}
if
(
device_name
==
"gfx938"
&&
(
key_cache
.
dtype
()
==
torch
::
kFloat8_e5m2
||
key_cache
.
dtype
()
==
torch
::
kFloat8_e4m3fn
)){
if
(
(
device_name
==
"gfx938"
||
device_name
==
"gfx92a"
)
&&
(
key_cache
.
dtype
()
==
torch
::
kFloat8_e5m2
||
key_cache
.
dtype
()
==
torch
::
kFloat8_e4m3fn
)){
paged_attention_938
(
out
,
query
,
key_cache
,
value_cache
,
block_tables
,
seq_lens
,
alibi_slopes
,
q_scale
,
k_scale
,
v_scale
,
max_seq_len
,
s_aux_
,
tmp_out_ptr
,
PARTITION_SIZE
);
return
;
}
...
...
@@ -876,16 +872,16 @@ void paged_attention(
grid
.
x
=
num_kv_heads
;
grid
.
y
=
num_seqs
;
AT_ASSERTM
(
headsize
%
64
==
0
&&
headsize
<=
256
,
"Page Attention head size must be 64, 128, 192 or 256"
);
AT_ASSERTM
(
num_heads
<=
num_kv_heads
*
4
8
,
"Page Attention qheads*mtp/kvheads must be smaller than 48"
);
AT_ASSERTM
(
num_heads
<=
num_kv_heads
*
6
4
,
"Page Attention qheads*mtp/kvheads must be smaller than 48"
);
HEADSIZE_SWITCH
(
headsize
,[
&
]{
Input_Type_SWITCH
(
query
.
dtype
(),[
&
]{
Cache_Type_SWITCH
(
scalar_t
,
key_cache
.
dtype
(),[
&
]
{
REUSEKV_SWITCH
(
reusekv
,[
&
]
{
BOOL_SWITCH
(
block_size
==
64
,
is_block64
,[
&
]{
constexpr
int
BLOCK_SIZE
=
(
is_block64
?
64
:
128
);
//
constexpr int BLOCK_SIZE=
128
;
//
constexpr int BLOCK_SIZE = (is_block64?64:128);
constexpr
int
BLOCK_SIZE
=
64
;
// constexpr int HEAD_SIZE=128;
// using scalar_t=
_Floa
t16;
// using scalar_t=
uin
t16
_t
;
// using cache_t = scalar_t;
constexpr
bool
is_e4m3
=
false
;
// constexpr static int REUSE_KV_TIMES = 4;
...
...
csrc/flash_attn/src/paged_attention_938.cu
View file @
518a5f4d
...
...
@@ -58,7 +58,7 @@ static __device__ inline float to_float(scalar_t in){
inline
__device__
float
uint82float
(
const
uint8_t
&
input
)
{
#if (defined(__gfx938__) )
#if (defined(__gfx938__)
||defined(__gfx92a__)
)
return
__builtin_hcu_cvt_f32_fp8
(
input
,
false
,
0
,
0
);
#else
const
uint32_t
w
=
(
uint32_t
)
input
<<
24
;
...
...
@@ -106,7 +106,7 @@ __forceinline__ __device__ scalar_t uint82half(const uint8_t& input) {
template
<
bool
is_e4m3
>
static
__device__
int
to_f8_from_f32
(
float
v1
,
float
v2
,
float
v3
,
float
v4
)
{
int
val
=
0
;
#if (defined(__gfx938__) )
#if (defined(__gfx938__)
|| defined(__gfx92a__)
)
if
constexpr
(
is_e4m3
){
val
=
__builtin_hcu_cvt_pk_fp8_f32
(
v1
,
v2
,
val
,
false
);
val
=
__builtin_hcu_cvt_pk_fp8_f32
(
v3
,
v4
,
val
,
true
);
...
...
@@ -122,7 +122,7 @@ static __device__ int to_f8_from_f32(float v1,float v2,float v3,float v4) {
template
<
bool
is_e4m3
>
static
__device__
float4_t
to_fp32_from_fp8
(
int
val
)
{
float4_t
ret
;
#if (defined(__gfx938__) )
#if (defined(__gfx938__)
|| defined(__gfx92a__)
)
if
constexpr
(
is_e4m3
){
ret
[
0
]
=
__builtin_hcu_cvt_f32_fp8
(
val
,
false
,
0
,
0
);
ret
[
1
]
=
__builtin_hcu_cvt_f32_fp8
(
val
,
false
,
0
,
1
);
...
...
@@ -184,11 +184,11 @@ static __device__ float4_t to_fp32_from_fp8(int val) {
#define REUSEKV_SWITCH(reusekv,...) \
[&] { \
if (reusekv==4
8
){ \
constexpr static int REUSE_KV_TIMES = 4
8
; \
if (reusekv==
6
4){ \
constexpr static int REUSE_KV_TIMES =
6
4; \
return __VA_ARGS__(); \
}else if (reusekv==
36
){ \
constexpr static int REUSE_KV_TIMES =
36
; \
}else if (reusekv==
48
){ \
constexpr static int REUSE_KV_TIMES =
48
; \
return __VA_ARGS__(); \
}else if (reusekv==32){ \
constexpr static int REUSE_KV_TIMES = 32; \
...
...
@@ -303,13 +303,11 @@ inline __device__ float block_sum(float* red_smem, float sum) {
template
<
bool
is_e4m3
>
inline
__device__
void
builtin_amdgcn_mmac
(
const
intx2
&
reg_a
,
const
intx2
&
reg_b
,
float4_t
&
reg_c
)
{
#if (defined(__gfx938__) )
if
constexpr
(
is_e4m3
){
reg_c
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8
_lit_lts
(
reg_a
,
reg_b
,
reg_c
,
false
,
false
);
reg_c
=
__builtin_hcu_mmac_f32_16x16x32_fp8_fp8
(
reg_a
,
reg_b
,
reg_c
);
}
else
{
reg_c
=
__builtin_hcu_mmac_f32_16x16x32_bf8_bf8
_lit_lts
(
reg_a
,
reg_b
,
reg_c
,
false
,
false
);
reg_c
=
__builtin_hcu_mmac_f32_16x16x32_bf8_bf8
(
reg_a
,
reg_b
,
reg_c
);
}
#endif
}
template
<
typename
scalar_t
,
typename
q_type
,
bool
is_e4m3
,
int
HEAD_SIZE
,
int
BLOCK_SIZE
,
...
...
@@ -332,7 +330,7 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
const
float
*
q_scale_ptr
,
const
float
*
k_scale_ptr
,
const
float
*
v_scale_ptr
,
int
max_num_partitions
,
int
PARTITION_SIZE
,
const
scalar_t
*
__restrict__
s_aux_ptr
,
int
mtp
,
bool
has_abili
)
{
// ★ Attention Sinks: [num_heads] scalar_t ★
#if (defined(__gfx938__) )
#if (defined(__gfx938__)
||defined(__gfx92a__)
)
const
int
seq_idx
=
blockIdx
.
y
;
const
int
partition_idx
=
blockIdx
.
z
;
constexpr
int
kv_head_stride
=
BLOCK_SIZE
*
HEAD_SIZE
;
...
...
@@ -453,10 +451,16 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
for
(
int
m
=
0
;
m
<
Mloop
;
m
++
){
qk_vec
[
m
]
=
{
0
,
0
,
0
,
0
};
}
intx4
k_vec
[
HEAD_SIZE
/
64
];
__builtin_amdgcn_sched_barrier
(
0
);
#pragma unroll
for
(
int
i
=
0
;
i
<
HEAD_SIZE
/
64
;
i
++
){
intx4
k_vec
=*
reinterpret_cast
<
const
intx4
*>
(
k_ptr
+
i
*
64
+
rowid
*
HEAD_SIZE
+
rows
*
16
);
intx2
*
k_vec_2
=
(
intx2
*
)
&
k_vec
;
k_vec
[
i
]
=*
reinterpret_cast
<
const
intx4
*>
(
k_ptr
+
i
*
64
+
rowid
*
HEAD_SIZE
+
rows
*
16
);
}
__builtin_amdgcn_sched_barrier
(
0
);
#pragma unroll
for
(
int
i
=
0
;
i
<
HEAD_SIZE
/
64
;
i
++
){
intx2
*
k_vec_2
=
(
intx2
*
)(
k_vec
+
i
);
for
(
int
m
=
0
;
m
<
Mloop
;
m
++
){
intx2
*
q_vec_2
=
(
intx2
*
)(
&
q_vec
[
m
][
i
]);
builtin_amdgcn_mmac
<
is_e4m3
>
(
k_vec_2
[
0
],
q_vec_2
[
0
],
qk_vec
[
m
]);
...
...
@@ -655,7 +659,7 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
}
const
uint8_t
*
v_ptr
=
v_cache
+
physical_block_number
*
kv_block_stride
+
kv_head_idx
*
kv_head_stride
;
if
(
partition_idx
<
num_partitions
-
1
){
if
(
partition_idx
<
num_partitions
-
1
||
block_idx
<
num_seq_blocks
-
1
){
#pragma unroll
for
(
int
i
=
0
;
i
<
NUM_ROWS_PER_THREAD
;
i
++
)
{
int
offset
=
i
*
BLOCK_SIZE
*
HEAD_SIZE
/
NUM_ROWS_PER_THREAD
+
warp_idx
*
BLOCK_SIZE
*
HEAD_SIZE
/
NUM_ROWS_PER_THREAD
/
NUM_WARPS
+
rows
*
16
+
rowid
*
BLOCK_SIZE
;
...
...
@@ -673,13 +677,11 @@ __launch_bounds__(NUM_THREADS) __global__ void paged_attention_kernel(
int
offset
=
i
*
BLOCK_SIZE
*
HEAD_SIZE
/
NUM_ROWS_PER_THREAD
+
warp_idx
*
BLOCK_SIZE
*
HEAD_SIZE
/
NUM_ROWS_PER_THREAD
/
NUM_WARPS
+
rows
*
16
+
rowid
*
BLOCK_SIZE
;
int_vec
v_vec
=
*
reinterpret_cast
<
const
int_vec
*>
(
v_ptr
+
offset
);
//这里的if判断会影响一定的性能,因此只有最后一个patition才判断
if
(
block_idx
==
num_seq_blocks
-
1
)
{
uint8_t
*
v_vec_ptr
=
reinterpret_cast
<
uint8_t
*>
(
&
v_vec
);
#pragma unroll
for
(
int
j
=
0
;
j
<
16
;
j
++
)
{
v_vec_ptr
[
j
]
=
token_idx
+
j
<
seq_len
?
v_vec_ptr
[
j
]
:
0
;
}
}
for
(
int
ii
=
0
;
ii
<
vecsize
;
ii
++
){
for
(
int
m
=
0
;
m
<
Mloop
;
m
++
){
builtin_amdgcn_mmac
<
is_e4m3
>
(
v_vec
.
data
[
ii
],
logits_vec
[
m
][
ii
],
accs
[
m
][
i
]);
...
...
@@ -795,8 +797,8 @@ __global__ __launch_bounds__(NUM_THREADS, 1) void paged_attention_combine(
}
static
int
get_reusekv
(
int
qhead
,
int
kv_head
){
if
(
qhead
>
kv_head
*
36
)
return
4
8
;
if
(
qhead
>
kv_head
*
32
)
return
36
;
//glm4.7 mtp 3
if
(
qhead
>
kv_head
*
48
)
return
6
4
;
if
(
qhead
>
kv_head
*
32
)
return
48
;
if
(
qhead
>
kv_head
*
24
)
return
32
;
if
(
qhead
>
kv_head
*
16
)
return
24
;
if
(
qhead
>
kv_head
*
8
)
return
16
;
...
...
@@ -870,17 +872,18 @@ void paged_attention_938(
int
reusekv
=
get_reusekv
(
num_heads
,
num_kv_heads
);
int
headsize
=
query
.
size
(
3
);
AT_ASSERTM
(
headsize
%
64
==
0
&&
headsize
<=
256
,
"Page Attention head size must be 64, 128, 192 or 256"
);
AT_ASSERTM
(
num_heads
<=
num_kv_heads
*
4
8
,
"Page Attention qheads*mtp/kvheads must be smaller than 48"
);
AT_ASSERTM
(
num_heads
<=
num_kv_heads
*
6
4
,
"Page Attention qheads*mtp/kvheads must be smaller than 48"
);
HEADSIZE_SWITCH
(
headsize
,[
&
]{
Output_Type_SWITCH
(
out
.
dtype
(),[
&
]{
Input_Type_SWITCH
(
scalar_t
,
query
.
dtype
(),
key_cache
.
dtype
(),[
&
]
{
REUSEKV_SWITCH
(
reusekv
,[
&
]
{
BOOL_SWITCH
(
block_size
==
64
,
is_block64
,[
&
]{
constexpr
int
BLOCK_SIZE
=
(
is_block64
?
64
:
128
);
// constexpr int HEAD_SIZE=128;
// constexpr int BLOCK_SIZE = (is_block64?64:128);
constexpr
int
BLOCK_SIZE
=
64
;
// constexpr int HEAD_SIZE=256;
// using scalar_t=uint16_t;
// constexpr bool is_e4m3=true;
// constexpr static int REUSE_KV_TIMES = 4;
// constexpr static int REUSE_KV_TIMES =
6
4;
// constexpr bool has_abili=false;
// constexpr bool use_mtp=false;
constexpr
static
int
NUM_THREADS
=
256
;
...
...
csrc/flash_attn/src/softmax.h
View file @
518a5f4d
...
...
@@ -26,6 +26,7 @@ __device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &t
CUTE_STATIC_ASSERT_V
(
size
<
0
>
(
summary
)
==
size
<
0
>
(
tensor
));
#pragma unroll
for
(
int
mi
=
0
;
mi
<
size
<
0
>
(
tensor
);
mi
++
)
{
#if defined(__gfx928__)
summary
(
mi
)
=
zero_init
?
tensor
(
mi
,
0
)
:
op
(
summary
(
mi
),
tensor
(
mi
,
0
));
#pragma unroll
for
(
int
ni
=
1
;
ni
<
size
<
1
>
(
tensor
);
ni
++
)
{
...
...
@@ -36,6 +37,29 @@ __device__ __forceinline__ void thread_reduce_(Tensor<Engine0, Layout0> const &t
// printf("thread_reduce_ mi:%d ni:%d %7.4f %7.4f %7.4f\n", mi, ni, ori, tensor(mi, ni), summary(mi));
// }
}
#else
if
constexpr
(
std
::
is_same_v
<
Operator
,
SumOp
<
float
>>
)
{
using
__float2
=
__attribute__
((
ext_vector_type
(
2
)))
float
;
__float2
sum_v
=
{
zero_init
?
0.0
f
:
summary
(
mi
),
0.0
f
};
for
(
int
ni
=
0
;
ni
<
size
<
1
>
(
tensor
);
ni
+=
2
)
{
__float2
vx2
=
{
tensor
(
mi
,
ni
),
tensor
(
mi
,
ni
+
1
)};
sum_v
=
__builtin_hcu_pk_add_f32
(
sum_v
,
vx2
);
}
summary
(
mi
)
=
sum_v
.
x
+
sum_v
.
y
;
}
else
{
summary
(
mi
)
=
zero_init
?
tensor
(
mi
,
0
)
:
op
(
summary
(
mi
),
tensor
(
mi
,
0
));
#pragma unroll
for
(
int
ni
=
1
;
ni
<
size
<
1
>
(
tensor
);
ni
++
)
{
// float ori = summary(mi);
summary
(
mi
)
=
op
(
summary
(
mi
),
tensor
(
mi
,
ni
));
// wangaq debug
// if (thread0()) {
// printf("thread_reduce_ mi:%d ni:%d %7.4f %7.4f %7.4f\n", mi, ni, ori, tensor(mi, ni), summary(mi));
// }
}
}
#endif
}
}
...
...
@@ -131,6 +155,7 @@ __forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tenso
// We don't want (-inf - (-inf)) since that would give NaN.
// If we don't have float around M_LOG2E the multiplication is done in fp64.
const
float
max_scaled
=
max
(
mi
)
==
-
INFINITY
?
0.
f
:
max
(
mi
)
*
(
Scale_max
?
scale
:
float
(
M_LOG2E
));
#if defined(__gfx928__)
#pragma unroll
for
(
int
ni
=
0
;
ni
<
size
<
1
>
(
tensor
);
++
ni
)
{
// Instead of computing exp(x - max), we compute exp2(x * log_2(e) -
...
...
@@ -141,6 +166,17 @@ __forceinline__ __device__ void scale_apply_exp2(Tensor<Engine0, Layout0> &tenso
// This macro is set in PyTorch and not FlashAttention
tensor
(
mi
,
ni
)
=
custom_exp2f
(
tensor
(
mi
,
ni
)
*
scale
-
max_scaled
);
}
#else
using
__float2
=
__attribute__
((
ext_vector_type
(
2
)))
float
;
__float2
scalex2
=
{
scale
,
scale
};
__float2
max_scaledx2
=
{
-
max_scaled
,
-
max_scaled
};
for
(
int
ni
=
0
;
ni
<
size
<
1
>
(
tensor
);
ni
+=
2
)
{
__float2
vx2
=
{
tensor
(
mi
,
ni
),
tensor
(
mi
,
ni
+
1
)};
__float2
res
=
__builtin_hcu_pk_fma_f32
(
vx2
,
scalex2
,
max_scaledx2
);
tensor
(
mi
,
ni
)
=
custom_exp2f
(
res
.
x
);
tensor
(
mi
,
ni
+
1
)
=
custom_exp2f
(
res
.
y
);
}
#endif
}
}
...
...
@@ -229,8 +265,19 @@ struct Softmax {
:
(
row_max
(
mi
)
==
-
INFINITY
?
0.0
f
:
row_max
(
mi
));
float
scores_scale
=
custom_exp2f
((
scores_max_prev
(
mi
)
-
scores_max_cur
)
*
softmax_scale_log2
);
row_sum
(
mi
)
*=
scores_scale
;
#if defined(__gfx928__)
#pragma unroll
for
(
int
ni
=
0
;
ni
<
size
<
1
>
(
acc_o_rowcol
);
++
ni
)
{
acc_o_rowcol
(
mi
,
ni
)
*=
scores_scale
;
}
#else
using
__float2
=
__attribute__
((
ext_vector_type
(
2
)))
float
;
__float2
scores_scalex2
=
{
scores_scale
,
scores_scale
};
for
(
int
ni
=
0
;
ni
<
size
<
1
>
(
acc_o_rowcol
);
ni
+=
2
)
{
__float2
vx2
=
{
acc_o_rowcol
(
mi
,
ni
),
acc_o_rowcol
(
mi
,
ni
+
1
)};
__float2
res
=
__builtin_hcu_pk_mul_f32
(
vx2
,
scores_scalex2
);
acc_o_rowcol
(
mi
,
ni
)
=
res
.
x
;
acc_o_rowcol
(
mi
,
ni
+
1
)
=
res
.
y
;
}
#endif
}
flash
::
scale_apply_exp2
(
scores
,
row_max
,
softmax_scale_log2
);
// We don't do the reduce across threads here since we don't need to use the row_sum.
...
...
@@ -584,8 +631,19 @@ struct Softmax {
float
inv_sum
=
(
sum
==
0.
f
||
sum
!=
sum
)
?
1.
f
:
1.
f
/
sum
;
lse
(
mi
)
=
(
sum
==
0.
f
||
sum
!=
sum
)
?
(
Split
?
-
INFINITY
:
INFINITY
)
:
row_max
(
mi
)
*
softmax_scale
+
__logf
(
sum
);
float
scale
=
!
Is_dropout
?
inv_sum
:
inv_sum
*
rp_dropout
;
#if defined(__gfx928__)
#pragma unroll
for
(
int
ni
=
0
;
ni
<
size
<
1
>
(
acc_o_rowcol
);
++
ni
)
{
acc_o_rowcol
(
mi
,
ni
)
*=
scale
;
}
#else
using
__float2
=
__attribute__
((
ext_vector_type
(
2
)))
float
;
__float2
scores_scalex2
=
{
scale
,
scale
};
for
(
int
ni
=
0
;
ni
<
size
<
1
>
(
acc_o_rowcol
);
ni
+=
2
)
{
__float2
vx2
=
{
acc_o_rowcol
(
mi
,
ni
),
acc_o_rowcol
(
mi
,
ni
+
1
)};
__float2
res
=
__builtin_hcu_pk_mul_f32
(
vx2
,
scores_scalex2
);
acc_o_rowcol
(
mi
,
ni
)
=
res
.
x
;
acc_o_rowcol
(
mi
,
ni
+
1
)
=
res
.
y
;
}
#endif
}
return
lse
;
};
...
...
csrc/flash_attn/src/utils.h
View file @
518a5f4d
...
...
@@ -33,16 +33,20 @@ __forceinline__ __device__ void s_nop() {
}
__forceinline__
__device__
void
s_barrier
()
{
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_barrier"
);
__builtin_amdgcn_sched_barrier
(
0
);
}
template
<
const
int
COUNT
>
__forceinline__
__device__
void
s_waitcnt
()
{
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(%0)
\n\t
"
"s_barrier
\n
"
::
"B"
(
COUNT
)
:
);
__builtin_amdgcn_sched_barrier
(
0
);
}
template
<
const
int
COUNT
>
...
...
@@ -1392,12 +1396,14 @@ lds_direct_copy(
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_slide
*
mma_k
*
element_size
;
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
__builtin_amdgcn_sched_barrier
(
0
);
#endif
}
else
if
constexpr
(
mma_layout
==
_64x16
)
{
constexpr
int
elements_per_thread
=
4
;
...
...
@@ -1413,12 +1419,14 @@ lds_direct_copy(
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_slide
*
mma_k
*
element_size
;
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
__builtin_amdgcn_sched_barrier
(
0
);
#endif
}
else
if
constexpr
(
mma_layout
==
_16x128
)
{
constexpr
int
elements_per_thread
=
8
;
...
...
@@ -1435,12 +1443,14 @@ lds_direct_copy(
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_slide
*
mma_k
*
element_size
;
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
__builtin_amdgcn_sched_barrier
(
0
);
#endif
}
else
if
constexpr
(
mma_layout
==
_16x192
)
{
constexpr
int
elements_per_thread
=
8
;
...
...
@@ -1457,12 +1467,14 @@ lds_direct_copy(
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_slide
*
mma_k
*
element_size
;
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
__builtin_amdgcn_sched_barrier
(
0
);
#endif
constexpr
int
elements_per_thread_tail
=
4
;
...
...
@@ -1481,12 +1493,14 @@ lds_direct_copy(
// if (thread0()) printf("tid:%d offset_v:%d ldsAddrPerWave:%d\n", tidx, offset_v, ldsAddrPerWave);
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
__builtin_amdgcn_sched_barrier
(
0
);
#endif
}
else
if
constexpr
(
mma_layout
==
_16x64_128
)
{
constexpr
int
elements_per_thread
=
4
;
...
...
@@ -1505,12 +1519,14 @@ lds_direct_copy(
// if (thread0()) printf("tid:%d offset_v:%d ldsAddrPerWave:%d\n", tidx, offset_v, ldsAddrPerWave);
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
__builtin_amdgcn_sched_barrier
(
0
);
#endif
}
else
if
constexpr
(
mma_layout
==
_16x64_64
)
{
constexpr
int
elements_per_thread
=
4
;
...
...
@@ -1529,12 +1545,14 @@ lds_direct_copy(
// if (tidx < 64) printf("tid:%d offset_v:%d ldsAddrPerWave:%d\n", tidx, offset_v, ldsAddrPerWave);
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
__builtin_amdgcn_sched_barrier
(
0
);
#endif
}
else
if
constexpr
(
mma_layout
==
_16x96
)
{
constexpr
int
elements_per_thread
=
8
;
...
...
@@ -1552,12 +1570,14 @@ lds_direct_copy(
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_slide
*
mma_k
*
element_size
;
if
(
warp_id
<
3
)
{
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
__builtin_amdgcn_sched_barrier
(
0
);
#endif
}
}
else
if
constexpr
(
mma_layout
==
_16x96_multi_ins
)
{
...
...
@@ -1575,12 +1595,14 @@ lds_direct_copy(
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_slide
*
mma_k
*
element_size
;
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
__builtin_amdgcn_sched_barrier
(
0
);
#endif
constexpr
int
elements_per_thread_tail
=
2
;
...
...
@@ -1599,17 +1621,188 @@ lds_direct_copy(
// if (thread0()) printf("tid:%d offset_v:%d ldsAddrPerWave:%d\n", tidx, offset_v, ldsAddrPerWave);
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dword %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
__builtin_amdgcn_sched_barrier
(
0
);
#endif
}
}
template
<
int
k_idx
,
int
K_BUFF_SIZE
=
0
,
bool
Is_even_MN
=
true
,
MMA_LAYOUT
mma_layout
=
_64x32
,
int
n_idx
=
0
,
bool
Use_cache_swizzle
=
true
,
class
SrcEngine
,
class
SrcLayout
,
class
DstEngine
,
class
DstLayout
>
CUTE_HOST_DEVICE
void
lds_direct_copy_even_k_dim256
(
Tensor
<
SrcEngine
,
SrcLayout
>
const
&
src
,
Tensor
<
DstEngine
,
DstLayout
>
&
dst
,
const
int
row_stride
,
const
int
max_MN
=
0
)
{
constexpr
int
warp_size
=
64
;
int
tidx
=
threadIdx
.
x
;
int
warp_id
=
__builtin_amdgcn_readfirstlane
(
tidx
/
warp_size
);
int
lane
=
tidx
%
warp_size
;
constexpr
int
element_size
=
2
;
struct
PtrWrapper
{
uint32_t
former
;
uint32_t
latter
;
};
PtrWrapper
glob_ptr
;
*
(
uint64_t
*
)
&
glob_ptr
=
reinterpret_cast
<
uint64_t
>
(
src
.
data
().
get
());
if
constexpr
(
Use_cache_swizzle
)
{
glob_ptr
.
latter
+=
0x41000000
;
// 62 bit: cache swizzle; 48~61: Stride
}
uint32x4_t
global_addr
=
{
0
};
global_addr
[
0
]
=
__builtin_amdgcn_readfirstlane
(
glob_ptr
.
former
);
global_addr
[
1
]
=
__builtin_amdgcn_readfirstlane
(
glob_ptr
.
latter
);
global_addr
[
2
]
=
0x80000000
;
global_addr
[
3
]
=
0x00020000
;
if
constexpr
(
mma_layout
==
_64x32
)
{
constexpr
int
elements_per_thread
=
8
;
constexpr
int
bytes_per_warp
=
warp_size
*
elements_per_thread
*
element_size
;
int
mma_k
=
32
*
64
;
int
row
=
tidx
%
16
;
int
col
=
lane
/
16
;
int
row_offset
=
row
*
4
+
warp_id
;
int
col_offset
=
col
*
elements_per_thread
;
int
offset_v
=
(
row_offset
*
row_stride
+
col_offset
)
*
element_size
;
// bytes
if
(
!
Is_even_MN
&&
row_offset
>=
max_MN
)
offset_v
=
-
1
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
(
k_idx
%
K_BUFF_SIZE
)
*
mma_k
*
element_size
;
const
int
offset_s
=
k_idx
*
32
*
element_size
;
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
__builtin_amdgcn_sched_barrier
(
0
);
#endif
}
else
if
constexpr
(
mma_layout
==
_16x256
)
{
constexpr
int
elements_per_thread
=
8
;
constexpr
int
bytes_per_warp
=
warp_size
*
elements_per_thread
*
element_size
;
int
mma_k
=
16
*
128
;
int
row
=
lane
/
4
;
int
col
=
tidx
%
4
;
int
row_offset
=
row
+
k_idx
*
16
;
int
col_offset
=
col
*
elements_per_thread
;
int
offset_v
=
(
row_offset
*
row_stride
+
col_offset
)
*
element_size
;
// bytes
if
(
!
Is_even_MN
&&
row_offset
>=
max_MN
)
offset_v
=
-
1
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_idx
*
mma_k
*
element_size
;
const
int
offset_s
=
(
warp_id
*
32
+
n_idx
*
128
)
*
element_size
;
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
__builtin_amdgcn_sched_barrier
(
0
);
#endif
}
}
// Direct global-memory -> LDS copy of one K-slice, issued as a single
// `buffer_load_dwordx{4,2} ... lds` instruction per thread.
//
// Template parameters:
//   k_idx             Compile-time K-slice index. It selects the LDS slot
//                     (k_idx * mma_k * element_size bytes past the per-warp base)
//                     and, depending on the layout, the scalar or row offset.
//   Is_even_MN        When false, threads whose row_offset >= max_MN get
//                     offset_v = -1 instead of a real byte offset.
//                     NOTE(review): presumably this relies on the buffer
//                     descriptor's range check to suppress the load - confirm.
//   mma_layout        Only _64x32 and _16x64_64 are handled here; any other value
//                     compiles to an empty body after the descriptor setup.
//   Use_cache_swizzle Adds 0x41000000 to the upper address word of the descriptor.
//
// Parameters:
//   src        Source tensor; only its base pointer (src.data().get()) is used.
//   dst        Destination tensor; only its base pointer is used, as the LDS base.
//   row_stride Row stride of the source, in elements (multiplied by element_size
//              to obtain bytes).
//   max_MN     Row bound used only when Is_even_MN is false.
//
// The load itself is only emitted when compiling for gfx936 / gfx938 / gfx92a;
// on every other target the function computes its operands and does nothing.
// Elements are assumed to be 2 bytes wide (element_size is hard-coded to 2).
// NOTE(review): the inline asm overwrites m0 and writes LDS without declaring
// clobbers; it is fenced with sched_barrier(0) instead - confirm this is intended.
template<int k_idx, bool Is_even_MN=true, MMA_LAYOUT mma_layout=_64x32, bool Use_cache_swizzle=true,
         class SrcEngine, class SrcLayout, class DstEngine, class DstLayout>
CUTE_HOST_DEVICE
void lds_direct_copy_even_k(Tensor<SrcEngine, SrcLayout> const &src,
                            Tensor<DstEngine, DstLayout> &dst,
                            const int row_stride, const int max_MN=0) {
    constexpr int warp_size = 64;
    constexpr int element_size = 2;

    int tidx = threadIdx.x;
    // readfirstlane makes the warp index wave-uniform so it can feed scalar operands.
    int warp_id = __builtin_amdgcn_readfirstlane(tidx / warp_size);
    int lane = tidx % warp_size;

    // Split the 64-bit source address into two 32-bit words with plain integer
    // arithmetic. (The previous code wrote a uint64_t through a pointer cast onto a
    // struct of two uint32_t, which violates strict aliasing; the values produced
    // here are the same on a little-endian target.)
    const uint64_t src_addr = reinterpret_cast<uint64_t>(src.data().get());
    uint32_t addr_lo = static_cast<uint32_t>(src_addr);
    uint32_t addr_hi = static_cast<uint32_t>(src_addr >> 32);
    if constexpr (Use_cache_swizzle) {
        addr_hi += 0x41000000; // 62 bit: cache swizzle; 48~61: Stride
    }

    // Four-dword buffer descriptor handed to the buffer_load instruction.
    // NOTE(review): words 2 and 3 are fixed constants; presumably the record count
    // and format/flag bits of the descriptor - confirm against the ISA manual.
    uint32x4_t global_addr = {0};
    global_addr[0] = __builtin_amdgcn_readfirstlane(addr_lo);
    global_addr[1] = __builtin_amdgcn_readfirstlane(addr_hi);
    global_addr[2] = 0x80000000;
    global_addr[3] = 0x00020000;

    if constexpr (mma_layout == _64x32) {
        constexpr int elements_per_thread = 8;
        constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
        constexpr int mma_k = 32 * 64;

        int row = tidx % 16;
        int col = lane / 16;
        // Rows are interleaved across warps: row r of warp w maps to source row r * 4 + w.
        int row_offset = row * 4 + warp_id;
        int col_offset = col * elements_per_thread;
        int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
        if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;

        // LDS byte address for this wave; explicitly narrowed to the 32 bits moved into m0.
        int ldsAddrPerWave = static_cast<int>(reinterpret_cast<size_t>(dst.data().get())
                                              + warp_id * bytes_per_warp
                                              + k_idx * mma_k * element_size);
        const int offset_s = k_idx * 32 * element_size;
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
        __builtin_amdgcn_sched_barrier(0);
        asm volatile("s_mov_b32 m0, %1\n\t"
                     "buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds\n"
                     ::"v"(offset_v), "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
                     :);
        __builtin_amdgcn_sched_barrier(0);
#endif
    } else if constexpr (mma_layout == _16x64_64) {
        constexpr int elements_per_thread = 4;
        constexpr int bytes_per_warp = warp_size * elements_per_thread * element_size;
        constexpr int mma_k = 16 * 64;

        int row = (tidx / 8) % 16;
        int col = tidx % 8;
        // In this layout the K-slice index advances the source row, 16 rows per slice.
        int row_offset = row + k_idx * 16;
        int col_offset = col * elements_per_thread;
        int offset_v = (row_offset * row_stride + col_offset) * element_size; // bytes
        if (!Is_even_MN && row_offset >= max_MN) offset_v = -1;

        // LDS byte address for this wave; explicitly narrowed to the 32 bits moved into m0.
        int ldsAddrPerWave = static_cast<int>(reinterpret_cast<size_t>(dst.data().get())
                                              + warp_id * bytes_per_warp
                                              + k_idx * mma_k * element_size);
        const int offset_s = warp_id / 2 * 32 * element_size;
#if (defined(__gfx936__) || defined(__gfx938__) ||defined(__gfx92a__))
        __builtin_amdgcn_sched_barrier(0);
        asm volatile("s_mov_b32 m0, %1\n\t"
                     "buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds\n"
                     ::"v"(offset_v), "s"(ldsAddrPerWave), "s"(global_addr), "s"(offset_s)
                     :);
        __builtin_amdgcn_sched_barrier(0);
#endif
    }
}
#define fp8 unsigned char
__forceinline__
__device__
float
fp8e5m2_to_fp32
(
const
fp8
&
input
)
{
union
uf16
{
...
...
@@ -1769,7 +1962,7 @@ lds_direct_copy(int k_slide,
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_slide
*
mma_k
*
element_size
;
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__)
||defined(__gfx92a__)
)
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
...
...
@@ -1791,7 +1984,7 @@ lds_direct_copy(int k_slide,
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_slide
*
mma_k
*
element_size
;
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__)
||defined(__gfx92a__)
)
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
...
...
@@ -1813,7 +2006,7 @@ lds_direct_copy(int k_slide,
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_slide
*
mma_k
*
element_size
;
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__)
||defined(__gfx92a__)
)
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
...
...
@@ -1837,7 +2030,7 @@ lds_direct_copy(int k_slide,
// if (thread0()) printf("tid:%d offset_v:%d ldsAddrPerWave:%d\n", tidx, offset_v, ldsAddrPerWave);
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__)
||defined(__gfx92a__)
)
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
...
...
@@ -1861,7 +2054,7 @@ lds_direct_copy(int k_slide,
// if (thread0()) printf("tid:%d offset_v:%d ldsAddrPerWave:%d\n", tidx, offset_v, ldsAddrPerWave);
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__)
||defined(__gfx92a__)
)
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
...
...
@@ -1885,7 +2078,7 @@ lds_direct_copy(int k_slide,
// if (tidx < 64) printf("tid:%d offset_v:%d ldsAddrPerWave:%d\n", tidx, offset_v, ldsAddrPerWave);
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__)
||defined(__gfx92a__)
)
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
...
...
@@ -1908,7 +2101,7 @@ lds_direct_copy(int k_slide,
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_slide
*
mma_k
*
element_size
;
if
(
warp_id
<
3
)
{
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__)
||defined(__gfx92a__)
)
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
...
...
@@ -1931,7 +2124,7 @@ lds_direct_copy(int k_slide,
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_slide
*
mma_k
*
element_size
;
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__)
||defined(__gfx92a__)
)
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
...
...
@@ -1955,7 +2148,7 @@ lds_direct_copy(int k_slide,
// if (thread0()) printf("tid:%d offset_v:%d ldsAddrPerWave:%d\n", tidx, offset_v, ldsAddrPerWave);
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__)
||defined(__gfx92a__)
)
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dword %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
...
...
@@ -2021,7 +2214,7 @@ lds_direct_copy(int n_idx, int k_slide,
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_slide
*
mma_k
*
element_size
;
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__)
||defined(__gfx92a__)
)
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
...
...
@@ -2101,7 +2294,7 @@ lds_direct_copy_fp8(
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_slide
*
mma_k
*
element_size
;
#if defined(__gfx938__)
#if defined(__gfx938__)
||defined(__gfx92a__)
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
...
...
@@ -2123,7 +2316,7 @@ lds_direct_copy_fp8(
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_slide
*
mma_k
*
element_size
;
#if defined(__gfx938__)
#if defined(__gfx938__)
||defined(__gfx92a__)
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
...
...
@@ -2144,7 +2337,7 @@ lds_direct_copy_fp8(
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_slide
*
mma_k
*
element_size
;
#if defined(__gfx938__)
#if defined(__gfx938__)
||defined(__gfx92a__)
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
...
...
@@ -2167,7 +2360,7 @@ lds_direct_copy_fp8(
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_slide
*
mma_k
*
element_size
;
#if defined(__gfx938__)
#if defined(__gfx938__)
||defined(__gfx92a__)
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
...
...
@@ -2244,7 +2437,7 @@ lds_direct_copy_for_vertical_sparse(
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_slide
*
mma_k
*
element_size
;
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__)
||defined(__gfx92a__)
)
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,idxen offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
...
...
@@ -2289,7 +2482,7 @@ lds_direct_copy_for_vertical_sparse(
// int index_v = offset_v;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
dst
.
data
().
get
())
+
warp_id
*
bytes_per_warp
+
k_slide
*
mma_k
*
element_size
;
#if (defined(__gfx936__) || defined(__gfx938__) )
#if (defined(__gfx936__) || defined(__gfx938__)
||defined(__gfx92a__)
)
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,idxen offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
...
...
csrc/flash_attn_hg/flash_api.cpp
View file @
518a5f4d
This source diff could not be displayed because it is too large. You can
view the blob
instead.
csrc/flash_attn_hg/flash_c_api.h
View file @
518a5f4d
...
...
@@ -20,7 +20,20 @@ void run_mha_fwd(Flash_fwd_params ¶ms, hipStream_t stream, bool force_split_
}
if
(
params
.
seqused_k
!=
nullptr
)
{
// Prefix prefill attention
if
(
!
params
.
is_int8
){
if
(
params
.
is_e4m3
)
{
// FP8 prefix prefill
FP16_SWITCH
(
!
params
.
is_bf16
,
[
&
]
{
if
(
params
.
d
==
128
and
params
.
d_value
==
128
)
{
run_fp8_mha_fwd_prefix_prefill_
<
elem_type
,
128
,
128
>
(
params
,
stream
);
}
else
if
(
params
.
d
==
192
and
params
.
d_value
==
128
)
{
run_fp8_mha_fwd_prefix_prefill_
<
elem_type
,
192
,
128
>
(
params
,
stream
);
}
else
if
(
params
.
d
==
256
and
params
.
d_value
==
256
)
{
run_fp8_mha_fwd_prefix_prefill_
<
elem_type
,
256
,
256
>
(
params
,
stream
);
}
else
{
assert
(
false
&&
"FP8 prefix prefill only supports head_dim=128/128, 192/128, or 256/256"
);
}
});
}
else
if
(
!
params
.
is_int8
){
FP16_SWITCH
(
!
params
.
is_bf16
,
[
&
]
{
if
(
params
.
d
==
128
and
params
.
d_value
==
128
)
{
run_mha_fwd_prefix_prefill_
<
elem_type
,
128
,
128
>
(
params
,
stream
);
...
...
@@ -65,6 +78,13 @@ void run_mha_fwd(Flash_fwd_params ¶ms, hipStream_t stream, bool force_split_
else
{
// Decoder-only attention
FP16_SWITCH
(
!
params
.
is_bf16
,
[
&
]
{
if
(
params
.
is_e4m3
)
{
if
(
params
.
d
==
128
and
params
.
d_value
==
128
)
{
run_fp8_mha_fwd_
<
elem_type
,
128
,
128
>
(
params
,
stream
);
}
else
{
assert
(
false
&&
"FP8 forward only supports head_dim=128"
);
}
}
else
{
#if defined(HEADDIM_128_ONLY)
run_mha_fwd_
<
elem_type
,
128
,
128
>
(
params
,
stream
);
#elif defined(HEADDIM_192_128_ONLY)
...
...
@@ -74,6 +94,7 @@ void run_mha_fwd(Flash_fwd_params ¶ms, hipStream_t stream, bool force_split_
run_mha_fwd_
<
elem_type
,
kHeadDimQ
,
kHeadDimV
>
(
params
,
stream
);
});
#endif
}
});
}
#endif
...
...
csrc/flash_attn_hg/include/bwd/dot_do_o_gfx946.h
0 → 100644
View file @
518a5f4d
#pragma once
#include <block_info.h>
#include "utils.h"
#include "prefetch.h"
// Just compute dot(do, o) and write the result (softmax_d) to global memory as a separate kernel.
// This is used in the case where we want to parallelize the backward across seqlen_k.
//
// Launch layout read by this function: blockIdx.x = query block (m_block),
// blockIdx.y = head, blockIdx.z = batch. Wave size is taken to be 64 threads.
//
// For every kBlockN-wide slice of the head dimension (K / kBlockN slices) the block
// stages a dO tile and an O tile in LDS, reads them back through the transposing
// matrix-read builtin, and accumulates dO * O products in fp32. The per-row partial
// sums are then combined with flash::Allreduce<64>, multiplied by params.p_dropout,
// and stored to params.dsoftmax_sum by the first 16 threads of the block.
//
// Template parameters Clear_dQaccum, WARP_N and STAGES are not referenced in the body.
// When Is_even_MN is false, rows at or beyond binfo.actual_seqlen_q are skipped during
// accumulation, and blocks that start past the end of the sequence return immediately.
//
// NOTE(review): the row offsets are stored in `int`; if binfo.q_offset1/q_offset2
// return a wider index type this narrows it - confirm the tensors stay below 2^31 elements.
template<bool Clear_dQaccum=true, bool Is_even_MN, class Element, class ElementAccum,
         int kBlockM, int kBlockN, int WARP_M, int WARP_N, int K, int STAGES,
         bool USE_BSHD_LAYOUT, typename Params>
inline __device__ void compute_dot_do_o_gfx946(const Params &params) {
    Element *do_ptr = static_cast<Element *>(params.do_ptr);
    Element *o_ptr = static_cast<Element *>(params.o_ptr);
    ElementAccum *dsoftmax_sum = static_cast<ElementAccum *>(params.dsoftmax_sum);

    const int m_block = blockIdx.x;
    // The block index for the batch.
    const int bidb = blockIdx.z;
    // The block index for the head.
    const int bidh = blockIdx.y;

    // wave size should be defined in launch file. Here use 64 threads
    int warp_id_vec = threadIdx.x / 64; // warp id in a block (per-lane value)

    __shared__ Element dO_lds[kBlockM * kBlockN];
    __shared__ Element O_lds[kBlockM * kBlockN];

    // One fp32 partial sum per 16-row sub-tile handled by this thread.
    float dP_sum_cur[(kBlockM / 16)] = {0.0f};

    const flash::BlockInfo</*Varlen=*/!Is_even_MN, false, USE_BSHD_LAYOUT> binfo(params, bidb);
    // Uniform across the block, so the barriers below are never reached divergently.
    if (m_block * kBlockM >= binfo.actual_seqlen_q) return;

    int seqlen_do_stride = params.do_row_stride;
    int seqlen_o_stride = params.o_row_stride;

    const int row_offset_do = binfo.q_offset1(params.do_batch_stride, params.do_row_stride, bidb)
        + binfo.q_offset2(params.do_head_stride, bidh)
        + m_block * kBlockM * seqlen_do_stride;
    const int row_offset_o = binfo.q_offset1(params.o_batch_stride, params.o_row_stride, bidb)
        + binfo.q_offset2(params.o_head_stride, bidh)
        + m_block * kBlockM * seqlen_o_stride;
    const int row_offset_dpsum = (bidb * params.h + bidh) * params.seqlen_q_rounded + m_block * kBlockM;

    auto gdO = prepare_for_matrix_load_gfx938<Element>(do_ptr + row_offset_do, seqlen_do_stride);
    auto gO = prepare_for_matrix_load_gfx938<Element>(o_ptr + row_offset_o, seqlen_o_stride);
    ElementAccum *dP_sum = dsoftmax_sum + row_offset_dpsum;

    // Make the warp index wave-uniform. The builtin replaces a hand-written
    // v_readfirstlane_b32 and matches the other kernels in this code base.
    int warp_id = __builtin_amdgcn_readfirstlane(warp_id_vec);

    union_vec4_f16x2<Element> dO_reg[((WARP_M * kBlockN) / (32 * 32)) * 2];
    union_vec4_f16x2<Element> O_reg[((WARP_M * kBlockN) / (32 * 32)) * 2];

    // Rows of this query block that lie inside the sequence (loop invariant).
    const int rows_left = binfo.actual_seqlen_q - m_block * kBlockM;

    for (int k_loop = 0; k_loop < K / kBlockN; k_loop++) {
        // Drain outstanding memory operations and sync before the LDS tiles are overwritten.
        __builtin_amdgcn_sched_barrier(0);
        __builtin_amdgcn_s_waitcnt(0);
        __syncthreads();
        __builtin_amdgcn_sched_barrier(0);

        int do_block_buffer_load_global_offset = k_loop * kBlockN; // read 32 * 128
        prefetch_to_lds_gfx938<true, kBlockM, kBlockN, Element, ElementAccum, Is_even_MN, 1>(
            gdO, do_block_buffer_load_global_offset, dO_lds, rows_left, warp_id);
        prefetch_to_lds_gfx938<true, kBlockM, kBlockN, Element, ElementAccum, Is_even_MN, 1>(
            gO, do_block_buffer_load_global_offset, O_lds, rows_left, warp_id);
        __builtin_amdgcn_s_waitcnt(0);
        __syncthreads();

        // Pull each 32x32 sub-tile out of LDS as two register fragments
        // (second fragment requested with offset argument 1024).
        for (int i = 0; i < kBlockN / 32; ++i) {
            if constexpr (std::is_same_v<Element, half_t>) {
                dO_reg[i * 2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(dO_lds + i * 32 * 32, 0, 2, 1, 0);
                dO_reg[i * 2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(dO_lds + i * 32 * 32, 1024, 2, 1, 0);
                O_reg[i * 2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(O_lds + i * 32 * 32, 0, 2, 1, 0);
                O_reg[i * 2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_f16(O_lds + i * 32 * 32, 1024, 2, 1, 0);
            } else {
                dO_reg[i * 2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(dO_lds + i * 32 * 32, 0, 2, 1, 0);
                dO_reg[i * 2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(dO_lds + i * 32 * 32, 1024, 2, 1, 0);
                O_reg[i * 2 + 0].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(O_lds + i * 32 * 32, 0, 2, 1, 0);
                O_reg[i * 2 + 1].f16x8 = __builtin_hcu_ds_read_matrix_trans_format_bf16(O_lds + i * 32 * 32, 1024, 2, 1, 0);
            }
        }
        // Wait for the LDS reads to land before the registers are consumed.
        asm volatile("s_waitcnt lgkmcnt(0)");
        __builtin_amdgcn_sched_barrier(0);

        // Accumulate dO * O in fp32, one running sum per 16-row sub-tile.
        #pragma unroll
        for (int min_tile_m = 0; min_tile_m < 2; min_tile_m++) {
            #pragma unroll
            for (int head_dim_idx = 0; head_dim_idx < (kBlockN / 32); ++head_dim_idx) {
                #pragma unroll
                for (int vec_id = 0; vec_id < 4; vec_id++) {
                    #pragma unroll
                    for (int min_tile_n = 0; min_tile_n < 2; min_tile_n++) {
                        // Skip rows past the end of the sequence when it is ragged.
                        if (Is_even_MN || (m_block * kBlockM + min_tile_m * 16 + (threadIdx.x & 15)) < binfo.actual_seqlen_q) {
                            dP_sum_cur[min_tile_m] +=
                                UpCast<Element, float, false>(dO_reg[head_dim_idx * 2 + min_tile_m].f16[vec_id * 2 + min_tile_n])
                                * UpCast<Element, float, false>(O_reg[head_dim_idx * 2 + min_tile_m].f16[vec_id * 2 + min_tile_n]);
                        }
                    }
                }
            }
        }
    }

    flash::SumOp<float> sum_op;
    #pragma unroll
    for (int mi = 0; mi < (WARP_M / 32); ++mi) {
        #pragma unroll
        for (int min_tile_m = 0; min_tile_m < 2; min_tile_m++) {
            dP_sum_cur[mi * 2 + min_tile_m] =
                flash::Allreduce<64>::run(dP_sum_cur[mi * 2 + min_tile_m], sum_op) * params.p_dropout;
            // Only threads 0..15 of the block store, one row each.
            if ((threadIdx.x >> 4) == 0) {
                dP_sum[mi * 32 + min_tile_m * 16 + (threadIdx.x & 15)] = dP_sum_cur[mi * 2 + min_tile_m];
            }
        }
    }
}
\ No newline at end of file
csrc/flash_attn_hg/include/bwd/flash_attention_bwd.h
View file @
518a5f4d
...
...
@@ -24,12 +24,15 @@
#include "static_switch.h"
#include "dot_do_o.h"
#include "dot_do_o_gfx938.h"
#include "dot_do_o_gfx946.h"
#include "prefetch.h"
#include "flash_singleton.h"
#include "flash_attention_dv_dk_bwd.h"
#include "flash_attention_dv_dk_bwd_gfx938.h"
#include "flash_attention_dv_dk_bwd_gfx946.h"
#include "flash_attention_dq_bwd.h"
#include "flash_attention_dq_bwd_gfx938.h"
#include "flash_attention_dq_bwd_gfx946.h"
using
std
::
make_shared
;
using
std
::
shared_ptr
;
...
...
csrc/flash_attn_hg/include/bwd/flash_attention_dq_bwd_gfx946.h
0 → 100644
View file @
518a5f4d
// Debug-only dump helpers, available only when DEBUGING is defined.
// Despite the "print_" prefix, none of these macros print anything: each one stores
// the f32 lanes of a per-thread register tile into a global debug buffer.
// All four share the same addressing: a base offset built from bidb, bidh, n_block,
// block_id_m and warp_id over a (seqlen_q x seqlen_k) plane per (batch, head), plus a
// per-lane offset derived from lane_id, the 16x16 sub-tile indices and vec_idx.
// They expand to a braced block and rely on these names being in scope at the
// expansion site: params, warp_id, lane_id, n_block, WARP_M_, WARP_N_, kBlockM_,
// kBlockN_, plus the destination pointer and register array named in each macro.
#ifdef DEBUGING
// print_qk: stores s_reg[...].f32[...] into kq_ptr.
#define print_qk(block_id_m, bidb, bidh) {\
    int qk_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
    int qk_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
    + block_id_m*kBlockM_*params.seqlen_k + qk_warp_n_offset; \
    for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
        for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
            for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
                for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
                    int offset = qk_global_offset + block_n_idx * WARP_N_ + lane_id%16 * params.seqlen_k + warp_m_idx * params.seqlen_k * 16 + warp_n_idx*16 + lane_id/16*4 + vec_idx; \
                    kq_ptr[offset] = s_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
                } \
            } \
        } \
    }\
}
// print_softmax_rescale_o: same addressing as print_qk, but stores s_reg into s_ptr.
#define print_softmax_rescale_o(block_id_m, bidb, bidh) { \
    int qk_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
    int qk_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
    + block_id_m*kBlockM_*params.seqlen_k + qk_warp_n_offset; \
    for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
        for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
            for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
                for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
                    int offset = qk_global_offset + block_n_idx * WARP_N_ + lane_id%16 * params.seqlen_k + warp_m_idx * params.seqlen_k * 16 + warp_n_idx*16 + lane_id/16*4 + vec_idx; \
                    s_ptr[offset] = s_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
                } \
            } \
        } \
    }\
}
// print_dp: stores dp_reg[...].f32[...] into dp_ptr.
#define print_dp(block_id_m, bidb, bidh) {\
    int dp_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
    int dp_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
    + block_id_m*kBlockM_*params.seqlen_k + dp_warp_n_offset; \
    for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
        for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
            for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
                for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
                    int offset = dp_global_offset + block_n_idx * WARP_N_ + lane_id%16 * params.seqlen_k + warp_m_idx * params.seqlen_k * 16 + warp_n_idx*16 + lane_id/16*4 + vec_idx; \
                    dp_ptr[offset] = dp_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
                } \
            } \
        } \
    }\
}
// print_ds: stores dS_reg[...].f32[...] into ds_ptr.
#define print_ds(block_id_m, bidb, bidh) {\
    int ds_warp_n_offset = warp_id * WARP_M_ * params.seqlen_k; \
    int ds_global_offset = bidb*(params.h* params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
    + block_id_m*kBlockM_*params.seqlen_k + ds_warp_n_offset; \
    for(int block_n_idx = 0; block_n_idx < kBlockN_/WARP_N_; ++block_n_idx){ \
        for(int warp_n_idx = 0; warp_n_idx < WARP_N_/16; ++warp_n_idx){ \
            for(int warp_m_idx = 0; warp_m_idx < WARP_M_/16; ++warp_m_idx){ \
                for(int vec_idx = 0; vec_idx < 4 ; ++vec_idx) {\
                    int offset = ds_global_offset + block_n_idx * WARP_N_ + lane_id%16 * params.seqlen_k + warp_m_idx * params.seqlen_k * 16 + warp_n_idx*16 + lane_id/16*4 + vec_idx; \
                    ds_ptr[offset] = dS_reg[block_n_idx][warp_n_idx*(WARP_M_/16) + warp_m_idx].f32[vec_idx]; \
                } \
            } \
        } \
    }\
}
#endif
template
<
class
Element
,
class
ElementAccum
,
bool
Is_dropout
,
bool
Is_causal
,
bool
Is_local
,
bool
Is_even_MN
,
bool
Is_even_K
,
bool
Is_first
,
bool
Is_last
,
bool
Seq_parallel
,
int
kBlockM_
,
int
kBlockN_
,
int
K
,
int
K_v
,
int
kBlockK_
,
int
WARP_M_
,
int
WARP_N_
,
int
STAGES
,
int
USE_BSHD_LAYOUT
,
typename
Params
>
__forceinline__
__device__
void
compute_dq_1colblock_gfx946
(
Params
&
params
,
int
bidb
,
int
bidh
,
int
m_block
)
{
#ifdef DEBUGING
ElementAccum
*
kq_ptr
=
static_cast
<
ElementAccum
*>
(
params
.
kq_ptr
);
ElementAccum
*
s_ptr
=
static_cast
<
ElementAccum
*>
(
params
.
s_ptr
);
ElementAccum
*
dp_ptr
=
static_cast
<
ElementAccum
*>
(
params
.
dp_ptr
);
ElementAccum
*
ds_ptr
=
static_cast
<
ElementAccum
*>
(
params
.
ds_ptr
);
#endif
Element
*
q_ptr
=
static_cast
<
Element
*>
(
params
.
q_ptr
);
Element
*
k_ptr
=
static_cast
<
Element
*>
(
params
.
k_ptr
);
Element
*
v_ptr
=
static_cast
<
Element
*>
(
params
.
v_ptr
);
Element
*
o_ptr
=
static_cast
<
Element
*>
(
params
.
o_ptr
);
Element
*
dq_ptr
=
static_cast
<
Element
*>
(
params
.
dq_ptr
);
Element
*
dk_ptr
=
static_cast
<
Element
*>
(
params
.
dk_ptr
);
Element
*
dv_ptr
=
static_cast
<
Element
*>
(
params
.
dv_ptr
);
Element
*
do_ptr
=
static_cast
<
Element
*>
(
params
.
do_ptr
);
ElementAccum
*
softmax_lse_ptr
=
static_cast
<
ElementAccum
*>
(
params
.
softmax_lse_ptr
);
ElementAccum
*
dsoftmax_sum
=
static_cast
<
ElementAccum
*>
(
params
.
dsoftmax_sum
);
//flash-attention QK, kBlockN_==WARP_N_;
const
int
M_BLOCK_NUM
=
params
.
seqlen_q
/
kBlockM_
;
const
int
N_BLOCK_NUM
=
params
.
seqlen_k
/
kBlockN_
;
extern
__shared__
Element
smem
[];
#if 1//defined(__gfx936__)
const
bool
Is_store_K
=
true
;
const
bool
Is_preload_K
=
true
;
const
bool
Is_preload_V
=
true
;
#else
const
bool
Is_store_K
=
false
;
const
bool
Is_preload_K
=
false
;
const
bool
Is_preload_V
=
false
;
#endif
const
int
K_prefetch_level
=
Is_preload_K
?
1
:
0
;
const
int
V_prefetch_level
=
Is_preload_V
?
1
:
0
;
const
int
Q_prefetch_level
=
3
;
Element
*
K_lds
=
(
Element
*
)
&
(
smem
);
Element
*
Q_lds
=
(
Element
*
)
&
(
smem
);
Element
*
dO_lds
=
(
Element
*
)
&
(
smem
);
Element
*
V_lds
=
(
Element
*
)
&
(
smem
)
+
kBlockN_
*
K
;
int
tidx
=
threadIdx
.
x
;
int
lane_id
=
threadIdx
.
x
&
63
;
//lane id, 0-63
const
flash
::
BlockInfo
<
/*Varlen=*/
!
Is_even_MN
,
false
,
USE_BSHD_LAYOUT
>
binfo
(
params
,
bidb
);
if
(
m_block
<
0
||
m_block
*
kBlockM_
>=
binfo
.
actual_seqlen_q
||
binfo
.
actual_seqlen_k
==
0
)
return
;
const
int
n_block_min
=
!
Is_local
?
0
:
std
::
max
(
0
,
(
m_block
*
kBlockM_
-
params
.
window_size_left
)
/
kBlockN_
);
const
int
n_block_max
=
(
!
Is_causal
&&
!
Is_local
)
?
ceil_div
(
binfo
.
actual_seqlen_k
,
kBlockN_
)
:
std
::
min
(
ceil_div
(
binfo
.
actual_seqlen_k
,
kBlockN_
),
flash
::
ceil_div
((
m_block
+
1
)
*
kBlockM_
+
params
.
window_size_right
,
kBlockN_
));
int
seqlen_q_stride
=
params
.
q_row_stride
;
int
seqlen_k_stride
=
params
.
k_row_stride
;
int
seqlen_v_stride
=
params
.
v_row_stride
;
int
seqlen_do_stride
=
params
.
do_row_stride
;
int
seqlen_o_stride
=
params
.
o_row_stride
;
int
seqlen_dq_stride
=
params
.
dq_row_stride
;
// We move K and V to the last block.
const
int
row_offset_q
=
binfo
.
q_offset1
(
params
.
q_batch_stride
,
params
.
q_row_stride
,
bidb
)
+
binfo
.
q_offset2
(
params
.
q_head_stride
,
bidh
)
+
m_block
*
kBlockM_
*
seqlen_q_stride
;
const
int
row_offset_k
=
binfo
.
k_offset1
(
params
.
k_batch_stride
,
params
.
k_row_stride
,
bidb
)
+
binfo
.
k_offset2
(
params
.
k_head_stride
,
bidh
/
params
.
h_h_k_ratio
)
+
(
n_block_max
-
1
)
*
kBlockN_
*
seqlen_k_stride
;
const
int
row_offset_v
=
binfo
.
k_offset1
(
params
.
v_batch_stride
,
params
.
v_row_stride
,
bidb
)
+
binfo
.
k_offset2
(
params
.
v_head_stride
,
bidh
/
params
.
h_h_k_ratio
)
+
(
n_block_max
-
1
)
*
kBlockN_
*
seqlen_v_stride
;
const
int
row_offset_dO
=
binfo
.
q_offset1
(
params
.
do_batch_stride
,
params
.
do_row_stride
,
bidb
)
+
binfo
.
q_offset2
(
params
.
do_head_stride
,
bidh
)
+
m_block
*
kBlockM_
*
seqlen_do_stride
;
const
int
row_offset_o
=
binfo
.
q_offset1
(
params
.
o_batch_stride
,
params
.
o_row_stride
,
bidb
)
+
binfo
.
q_offset2
(
params
.
o_head_stride
,
bidh
)
+
m_block
*
kBlockM_
*
seqlen_o_stride
;
const
int
row_offset_dq
=
binfo
.
q_offset1
(
params
.
dq_batch_stride
,
params
.
dq_row_stride
,
bidb
)
+
binfo
.
q_offset2
(
params
.
dq_head_stride
,
bidh
)
+
m_block
*
kBlockM_
*
seqlen_dq_stride
;
const
int
row_offset_lse
=
params
.
cu_seqlens_q
==
nullptr
?
(
bidb
*
params
.
h
+
bidh
)
*
params
.
seqlen_q
+
m_block
*
kBlockM_
:
bidh
*
params
.
total_q
+
binfo
.
sum_s_q
+
m_block
*
kBlockM_
;
const
int
row_offset_dpsum
=
(
bidb
*
params
.
h
+
bidh
)
*
params
.
seqlen_q_rounded
+
m_block
*
kBlockM_
;
auto
gQ
=
prepare_for_matrix_load_gfx938
<
Element
>
(
reinterpret_cast
<
Element
*>
(
q_ptr
)
+
row_offset_q
,
seqlen_q_stride
);
auto
gK
=
prepare_for_matrix_load_gfx938
<
Element
>
(
reinterpret_cast
<
Element
*>
(
k_ptr
)
+
row_offset_k
,
seqlen_k_stride
);
auto
gV
=
prepare_for_matrix_load_gfx938
<
Element
>
(
reinterpret_cast
<
Element
*>
(
v_ptr
)
+
row_offset_v
,
seqlen_v_stride
);
auto
gdO
=
prepare_for_matrix_load_gfx938
<
Element
>
(
reinterpret_cast
<
Element
*>
(
do_ptr
)
+
row_offset_dO
,
seqlen_do_stride
);
Element
*
gO
=
reinterpret_cast
<
Element
*>
(
o_ptr
)
+
row_offset_o
;
dq_ptr
=
reinterpret_cast
<
Element
*>
(
dq_ptr
)
+
row_offset_dq
;
ElementAccum
*
gLSE
=
reinterpret_cast
<
ElementAccum
*>
(
softmax_lse_ptr
)
+
row_offset_lse
;
ElementAccum
*
gdPsum
=
reinterpret_cast
<
ElementAccum
*>
(
dsoftmax_sum
)
+
row_offset_dpsum
;
constexpr
int
n_masking_steps
=
(
!
Is_causal
&&
!
Is_local
)
?
1
:
((
Is_even_MN
&&
Is_causal
)
?
flash
::
ceil_div
(
kBlockM_
,
kBlockN_
)
:
flash
::
ceil_div
(
kBlockM_
,
kBlockN_
)
+
1
);
// int warp_id =0;
int
warp_id_vec
=
threadIdx
.
x
/
64
;
//warp id in a block
int
warp_id
=
__builtin_amdgcn_readfirstlane
(
warp_id_vec
);
union_vec4_f16x2
<
Element
>
q_reg
[(
K
/
kBlockK_
)
*
((
WARP_M_
*
kBlockK_
)
/
(
32
*
32
))
*
2
];
union_vec4_f16x2
<
Element
>
dO_reg
[(
K_v
/
kBlockK_
)
*
((
WARP_M_
*
kBlockK_
)
/
(
32
*
32
))
*
2
];
union_vec4_fp32
acc_dq
[(
K
/
kBlockK_
)
*
((
WARP_M_
/
32
)
*
(
kBlockK_
/
32
))][
4
]
=
{
0
};
float
lse
[
WARP_M_
/
16
];
#pragma unroll
for
(
int
mi
=
0
;
mi
<
(
WARP_M_
/
32
);
++
mi
)
{
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
min_tile_m
++
)
{
int
lse_idx
=
warp_id
*
WARP_M_
+
mi
*
32
+
(
lane_id
&
15
)
+
min_tile_m
*
16
;
lse
[
mi
*
2
+
min_tile_m
]
=
(
Is_even_MN
||
lse_idx
<
binfo
.
actual_seqlen_q
-
m_block
*
kBlockM_
)
?
gLSE
[
lse_idx
]
:
INFINITY
;
}
}
float
dP_sum_reg
[
WARP_M_
/
16
];
#pragma unroll
for
(
int
mi
=
0
;
mi
<
(
WARP_M_
/
32
);
++
mi
)
{
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
min_tile_m
++
)
{
int
dP_sum_idx
=
warp_id
*
WARP_M_
+
mi
*
32
+
(
lane_id
&
15
)
+
min_tile_m
*
16
;
dP_sum_reg
[
mi
*
2
+
min_tile_m
]
=
gdPsum
[
dP_sum_idx
];
}
}
prefetch_to_vgpr_gfx938
<
true
,
kBlockM_
,
K
,
Element
,
ElementAccum
,
Is_even_MN
>
(
gQ
,
Q_lds
,
q_reg
,
(
binfo
.
actual_seqlen_q
-
m_block
*
kBlockM_
),
warp_id
);
__builtin_amdgcn_s_waitcnt
(
0
);
__syncthreads
();
prefetch_to_vgpr_gfx938
<
true
,
kBlockM_
,
K
,
Element
,
ElementAccum
,
Is_even_MN
>
(
gdO
,
dO_lds
,
dO_reg
,
(
binfo
.
actual_seqlen_q
-
m_block
*
kBlockM_
),
warp_id
);
__builtin_amdgcn_s_waitcnt
(
0
);
__syncthreads
();
if
constexpr
(
Is_preload_V
){
prefetch_to_lds_gfx938
<
true
,
kBlockN_
,
K_v
,
Element
,
ElementAccum
,
Is_even_MN
>
(
gV
,
0
,
V_lds
,
(
binfo
.
actual_seqlen_k
-
(
n_block_max
-
1
)
*
kBlockN_
),
warp_id
);
}
if
constexpr
(
Is_preload_K
){
prefetch_to_lds_gfx938
<
true
,
kBlockN_
,
K
,
Element
,
ElementAccum
,
Is_even_MN
>
(
gK
,
0
,
K_lds
,
(
binfo
.
actual_seqlen_k
-
(
n_block_max
-
1
)
*
kBlockN_
),
warp_id
);
}
__builtin_amdgcn_s_waitcnt
(
0
);
__syncthreads
();
for
(
int
n_block
=
n_block_max
-
1
;
n_block
>=
n_block_min
;
--
n_block
)
{
union_vec4_f16x2
<
Element
>
v_reg
[((
WARP_N_
*
kBlockK_
)
/
(
32
*
32
))
*
2
];
union_vec4_fp32
dp_reg
[(
WARP_M_
/
32
)
*
(
kBlockN_
/
32
)][
4
]
=
{
0
};
//dP gemm
gemm_tt_kq_gfx938
<
false
,
Is_preload_K
,
Is_even_MN
,
3
,
V_prefetch_level
,
K_v
,
kBlockM_
,
kBlockN_
,
kBlockK_
,
WARP_N_
,
WARP_N_
,
STAGES
,
Element
>
(
gdO
,
gV
,
dO_lds
,
V_lds
,
(
binfo
.
actual_seqlen_q
-
m_block
*
kBlockM_
),
(
binfo
.
actual_seqlen_k
-
n_block
*
kBlockN_
),
dO_reg
,
v_reg
,
dp_reg
,
warp_id
,
seqlen_do_stride
,
seqlen_v_stride
);
#ifdef DEBUGING
print_dp
(
m_block
,
bidb
,
bidh
);
#endif
union_vec4_f16x2
<
Element
>
k_reg
[((
WARP_M_
*
kBlockK_
)
/
(
32
*
32
))
*
2
];
//c mini tile is 32*32
union_vec4_fp32
s_reg
[(
WARP_N_
/
32
)
*
(
kBlockM_
/
32
)][
4
]
=
{
0
};
//qk gemm
gemm_tt_kq_gfx938
<
Is_store_K
,
false
,
Is_even_MN
,
Q_prefetch_level
,
K_prefetch_level
,
K
,
kBlockM_
,
kBlockN_
,
kBlockK_
,
WARP_N_
,
WARP_N_
,
STAGES
,
Element
>
(
gQ
,
gK
,
Q_lds
,
K_lds
,
(
binfo
.
actual_seqlen_q
-
m_block
*
kBlockM_
),
(
binfo
.
actual_seqlen_k
-
n_block
*
kBlockN_
),
q_reg
,
k_reg
,
s_reg
,
warp_id
,
seqlen_q_stride
,
seqlen_k_stride
);
*
(
uint64_t
*
)
&
gV
-=
((
kBlockN_
*
seqlen_v_stride
)
*
sizeof
(
Element
));
if
(
Is_preload_V
&&
n_block
>
n_block_min
){
prefetch_to_lds_gfx938
<
true
,
kBlockN_
,
K_v
,
Element
,
ElementAccum
,
Is_even_MN
>
(
gV
,
0
,
V_lds
,
(
binfo
.
actual_seqlen_k
-
(
n_block
-
1
)
*
kBlockN_
),
warp_id
);
}
apply_mask_bwd_gfx938
<
Is_even_MN
,
Is_local
?
3
:
(
Is_causal
?
1
:
0
)
>
(
s_reg
,
binfo
.
actual_seqlen_q
-
m_block
*
kBlockM_
-
warp_id
*
32
,
binfo
.
actual_seqlen_k
-
n_block
*
kBlockN_
,
(
m_block
*
kBlockM_
+
warp_id
*
32
)
-
(
n_block
*
kBlockN_
),
params
.
window_size_left
,
params
.
window_size_right
);
#ifdef DEBUGING
print_qk
(
m_block
,
bidb
,
bidh
);
#endif
scale_apply_exp2_bwd_seq_q_major
<
/*scale_max=*/
false
,
WARP_M_
,
kBlockN_
,
union_vec4_fp32
,
ElementAccum
>
(
s_reg
,
lse
,
params
.
scale_softmax_log2
);
#ifdef DEBUGING
print_softmax_rescale_o
(
m_block
,
bidb
,
bidh
)
#endif
auto
pointwise_mult
=
[](
float
p
,
float
dp
,
float
d
)
{
return
p
*
(
!
Is_dropout
||
p
>=
0
?
dp
-
d
:
d
);
// return p * (dp - d);
};
union_vec4_fp32
dS_reg
[(
WARP_M_
/
32
)
*
(
kBlockN_
/
32
)][
4
];
#pragma unroll
for
(
int
ni
=
0
;
ni
<
(
kBlockN_
/
32
);
++
ni
)
{
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
min_tile_n
++
)
{
#pragma unroll
for
(
int
mi
=
0
;
mi
<
(
WARP_M_
/
32
);
++
mi
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
min_tile_m
++
)
{
#pragma unroll
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
vec_idx
++
)
{
// result register ds_reg reuse dp_reg
dS_reg
[
ni
*
(
WARP_M_
/
32
)
+
mi
][
min_tile_n
*
2
+
min_tile_m
].
f32
[
vec_idx
]
=
pointwise_mult
(
s_reg
[
ni
*
(
WARP_M_
/
32
)
+
mi
][
min_tile_n
*
2
+
min_tile_m
].
f32
[
vec_idx
],
dp_reg
[
ni
*
(
WARP_M_
/
32
)
+
mi
][
min_tile_n
*
2
+
min_tile_m
].
f32
[
vec_idx
],
dP_sum_reg
[
min_tile_m
+
mi
*
2
]);
// dS_reg[ni*(WARP_M_/32) + mi][min_tile_n*2 + min_tile_m].f32[vec_idx] = s_reg[ni*(WARP_M_/32) + mi][min_tile_n*2 + min_tile_m].f32[vec_idx];
}
}
}
}
}
#ifdef DEBUGING
print_ds
(
m_block
,
bidb
,
bidh
);
#endif
union_vec4_f16x2
<
Element
>
dS_reg_fp16
[(
WARP_M_
/
32
)
*
(
kBlockN_
/
32
)
*
2
];
convert_pk_type_gfx938
<
WARP_M_
,
kBlockN_
,
Element
>
(
dS_reg_fp16
,
dS_reg
);
{
//dq gemm, K*dS
gpu_gemm_B_in_reg_gfx946
<
Is_store_K
,
false
,
Is_even_MN
,
K
,
kBlockK_
,
kBlockM_
,
kBlockN_
,
kBlockK_
,
WARP_M_
,
2
,
Element
>
(
gK
,
gK
,
K_lds
,
dS_reg_fp16
,
acc_dq
,
(
binfo
.
actual_seqlen_k
-
n_block
*
kBlockN_
),
(
binfo
.
actual_seqlen_q
-
m_block
*
kBlockM_
),
warp_id
,
seqlen_k_stride
);
}
*
(
uint64_t
*
)
&
gK
-=
((
kBlockN_
*
seqlen_k_stride
)
*
sizeof
(
Element
));
// if(blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0){
// printf("(binfo.actual_seqlen_k - n_block * kBlockN_) = %d\n", (binfo.actual_seqlen_k - n_block * kBlockN_));
// }
#if 1//defined(__gfx936__)
{
__syncthreads
();
if
(
Is_preload_K
&&
n_block
>
n_block_min
){
prefetch_to_lds_gfx938
<
true
,
kBlockN_
,
K
,
Element
,
ElementAccum
,
Is_even_MN
>
(
gK
,
0
,
K_lds
,
(
binfo
.
actual_seqlen_k
-
(
n_block
-
1
)
*
kBlockN_
),
warp_id
);
}
}
#else
__builtin_amdgcn_sched_barrier
(
0
);
__builtin_amdgcn_s_waitcnt
(
0
);
__syncthreads
();
__builtin_amdgcn_sched_barrier
(
0
);
#endif
}
#if 1
//这是正常的MLS+ds_read_matrix的layout
{
dq_ptr
=
dq_ptr
+
binfo
.
q_offset1
(
params
.
dq_batch_stride
,
params
.
dq_row_stride
,
bidb
)
+
binfo
.
q_offset2
(
params
.
dq_head_stride
,
bidh
);
auto
gdQ
=
tcp_cache_swizzle_func
<
K_v
,
Element
>
(
dq_ptr
);
int
dq_lane_seq_idx
=
(
lane_id
>>
4
);
int
dq_lane_head_dim_idx
=
(
lane_id
&
15
);
int
dq_global_addr_offset
=
0
;
#pragma unroll
for
(
int
k_loop
=
0
;
k_loop
<
(
K_v
/
kBlockK_
);
k_loop
++
)
{
#pragma unroll
for
(
int
warp_m_idx
=
0
;
warp_m_idx
<
(
WARP_M_
/
32
);
warp_m_idx
++
)
{
#pragma unroll
for
(
int
k_tile_idx
=
0
;
k_tile_idx
<
(
kBlockK_
/
32
);
k_tile_idx
++
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
min_tile_m
++
)
{
#pragma unroll
for
(
int
vec_index
=
0
;
vec_index
<
4
;
vec_index
++
)
{
int
v_offset
=
dq_lane_head_dim_idx
*
seqlen_dq_stride
+
dq_lane_seq_idx
*
4
;
int
s_offset
=
(
min_tile_m
*
seqlen_dq_stride
*
16
+
vec_index
%
2
*
2
+
vec_index
/
2
*
16
)
+
(
k_tile_idx
*
32
)
+
((
warp_id
*
WARP_M_
+
warp_m_idx
*
32
)
*
seqlen_dq_stride
)
+
(
k_loop
*
kBlockK_
+
m_block
*
kBlockM_
*
seqlen_dq_stride
);
int
known_offset
=
0
;
vec2_Element
<
Element
>
v_data
;
v_data
[
0
]
=
DownCast
<
float
,
Element
,
true
>
(
acc_dq
[
k_loop
*
((
WARP_M_
/
32
)
*
(
kBlockK_
/
32
))
+
(
warp_m_idx
*
(
kBlockK_
/
32
)
+
k_tile_idx
)][
min_tile_m
*
2
+
vec_index
/
2
].
f32
[
vec_index
%
2
*
2
]
*
params
.
scale_softmax_rp_dropout
);
v_data
[
1
]
=
DownCast
<
float
,
Element
,
true
>
(
acc_dq
[
k_loop
*
((
WARP_M_
/
32
)
*
(
kBlockK_
/
32
))
+
(
warp_m_idx
*
(
kBlockK_
/
32
)
+
k_tile_idx
)][
min_tile_m
*
2
+
vec_index
/
2
].
f32
[
vec_index
%
2
*
2
+
1
]
*
params
.
scale_softmax_rp_dropout
);
if
(
Is_even_MN
||
min_tile_m
*
16
+
(
warp_id
*
WARP_M_
+
warp_m_idx
*
32
)
+
m_block
*
kBlockM_
+
dq_lane_head_dim_idx
<
binfo
.
actual_seqlen_q
){
inline_buffer_store_dword_glc_slc
<
vec2_Element
<
Element
>
,
1
>
(
v_data
,
v_offset
,
gdQ
,
s_offset
,
/* immediate integer */
known_offset
);
}
}
}
}
}
}
}
#endif
}
#undef print_qk
#undef print_softmax_rescale_o
#undef print_dp
#undef print_ds
csrc/flash_attn_hg/include/bwd/flash_attention_dv_dk_bwd_gfx938.h
View file @
518a5f4d
...
...
@@ -234,7 +234,7 @@ __forceinline__ __device__ void compute_dk_dv_1colblock_gfx938(Params ¶ms, i
auto d1 = (!Is_dropout || p[1] >= 0 ? dp[1] - d[1] : d[1]);
// return vec2_fp32{p[0]*d0,p[1]*d1};
// return __builtin_hcu_pk_mul_f32(p, vec2_fp32{d0, d1});
return hcu_pk_mul_f32(p, vec2_fp32{d0, d1});
return
__builtin_
hcu_pk_mul_f32(p, vec2_fp32{d0, d1});
};
#else
auto
pointwise_mult
=
[](
float
p
,
float
dp
,
float
d
)
{
...
...
@@ -295,9 +295,12 @@ __forceinline__ __device__ void compute_dk_dv_1colblock_gfx938(Params ¶ms, i
//提前读取V到vgpr
prefetch_to_vgpr_gfx938
<
true
,
kBlockN_
,
K
,
Element
,
ElementAccum
,
Is_even_MN
>
(
gV
,
V_lds
,
v_reg
,
(
binfo
.
actual_seqlen_k
-
n_block
*
kBlockN_
),
warp_id
);
__builtin_amdgcn_s_waitcnt
(
0
);
__syncthreads
();
//提前读取K到vgpr
prefetch_to_vgpr_gfx938
<
true
,
kBlockN_
,
K
,
Element
,
ElementAccum
,
Is_even_MN
>
(
gK
,
K_lds
,
k_reg
,
(
binfo
.
actual_seqlen_k
-
n_block
*
kBlockN_
),
warp_id
);
__builtin_amdgcn_s_waitcnt
(
0
);
__syncthreads
();
//提前读取Q到lds
if
constexpr
(
Is_preload_Q
){
...
...
@@ -307,8 +310,8 @@ __forceinline__ __device__ void compute_dk_dv_1colblock_gfx938(Params ¶ms, i
if
constexpr
(
Is_preload_dO
){
prefetch_to_lds_gfx938
<
true
,
kBlockM_
,
K_v
,
Element
,
ElementAccum
,
Is_even_MN
>
(
gdO
,
0
,
dO_lds
,
(
binfo
.
actual_seqlen_q
-
(
m_block_max
-
1
)
*
kBlockM_
),
warp_id
);
}
//
__builtin_amdgcn_s_waitcnt(0);
//
__syncthreads();
__builtin_amdgcn_s_waitcnt
(
0
);
__syncthreads
();
union_vec4_fp32
acc_dv
[(
K_v
/
kBlockK_
)
*
((
WARP_N_
/
32
)
*
(
kBlockK_
/
32
))][
4
]
=
{
0
};
...
...
csrc/flash_attn_hg/include/bwd/flash_attention_dv_dk_bwd_gfx946.h
0 → 100644
View file @
518a5f4d
/*
 * print_kq(block_id_m, bidb, bidh) -- debug dump of the fp32 accumulator `s_reg`.
 *
 * Copies this lane's elements of `s_reg` into the global buffer `kq_ptr`. The offset
 * arithmetic below indexes that buffer as [bidb][bidh][row][col] with extents
 * params.h x params.seqlen_q x params.seqlen_k, i.e. the row stride is params.seqlen_k.
 * Inside one 32x32 tile a lane writes
 *     col = (lane_id & 15) + min_tile_n*16
 *     row = (lane_id / 16)*4 + reg_id + min_tile_m*16
 * and elements that fall outside params.seqlen_q / params.seqlen_k are skipped.
 *
 * Names that must be in scope at the expansion site: warp_id, lane_id, n_block, params,
 * kq_ptr, s_reg, kBlockM_, kBlockN_, WARP_M_, WARP_N_.
 *
 * The dump is bracketed by sched_barrier / s_waitcnt(0) / __syncthreads() on both sides,
 * so the expansion must be reached by every thread of the block.
 *
 * Changes versus the previous version:
 *   - The row guard now contains m_block_idx*WARP_M_. The store offset already added
 *     m_block_idx*WARP_M_*params.seqlen_k, so for kBlockM_ > WARP_M_ the guard tested a
 *     different row than the one being written and could let rows >= seqlen_q through.
 *     For kBlockM_ == WARP_M_ the extra term is always 0 and nothing changes.
 *   - Macro arguments are parenthesised so expressions can be passed safely.
 *
 * NOTE(review): `warp_id & (kBlockN_/WARP_N_ - 1)` equals a modulo only when
 * kBlockN_/WARP_N_ is a power of two -- confirm against the launch configurations.
 * NOTE(review): all offsets are computed in `int`; very large batch*heads*seqlen_q*seqlen_k
 * products would overflow -- acceptable for debugging only if sizes stay small.
 */
#define print_kq(block_id_m, bidb, bidh) { \
    __builtin_amdgcn_sched_barrier(0); \
    __builtin_amdgcn_s_waitcnt(0); \
    __syncthreads(); \
    __builtin_amdgcn_sched_barrier(0); \
    int qk_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
    int qk_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
    int qk_global_offset = (bidb)*(params.h * params.seqlen_q * params.seqlen_k) + (bidh) * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ \
        + (block_id_m)*kBlockM_*params.seqlen_k + qk_warp_n_id*WARP_N_ + qk_warp_m_id*WARP_M_*params.seqlen_k; \
    for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) { \
        for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
            for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
                for(int reg_id=0; reg_id<4; reg_id++) { \
                    for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
                        for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
                            if(((n_block*kBlockN_ + qk_warp_n_id*WARP_N_ + n_idx * 32 + (lane_id & 15) + min_tile_n*16) < params.seqlen_k) && \
                               (((block_id_m)*kBlockM_ + qk_warp_m_id*WARP_M_ + m_block_idx*WARP_M_ + m_idx*32 + reg_id + min_tile_m*16 + ((lane_id / 16) * 4)) < params.seqlen_q)) { \
                                int offset = qk_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k + (lane_id & 15) + min_tile_m*params.seqlen_k*16 + ((lane_id / 16) * 4) * params.seqlen_k + min_tile_n*16; \
                                kq_ptr[offset + reg_id * params.seqlen_k] = (s_reg[m_block_idx * (WARP_M_/32) * (WARP_N_/32) + m_idx*(WARP_N_/32) + n_idx][min_tile_m*2 + min_tile_n].f32[reg_id]); \
                            } \
                        } \
                    } \
                } \
            } \
        } \
    } \
    __builtin_amdgcn_sched_barrier(0); \
    __builtin_amdgcn_s_waitcnt(0); \
    __syncthreads(); \
    __builtin_amdgcn_sched_barrier(0); \
}
/*
 * print_softmax_rescale_o(block_id_m, bidb, bidh) -- debug dump of the fp32 accumulator
 * `s_reg` into the global buffer `s_ptr`.
 *
 * The offset arithmetic below indexes `s_ptr` as [bidb][bidh][row][col] with extents
 * params.h x params.seqlen_q x params.seqlen_k (row stride = params.seqlen_k).
 * Inside one 32x32 tile a lane writes
 *     col = (lane_id & 15) + min_tile_n*16
 *     row = (lane_id / 16)*4 + reg_id + min_tile_m*16
 * and elements outside params.seqlen_q / params.seqlen_k are skipped.
 *
 * Names that must be in scope at the expansion site: warp_id, lane_id, n_block, params,
 * s_ptr, s_reg, kBlockM_, kBlockN_, WARP_M_, WARP_N_.
 *
 * The dump is bracketed by sched_barrier / s_waitcnt(0) / __syncthreads() on both sides,
 * so the expansion must be reached by every thread of the block.
 *
 * Changes versus the previous version:
 *   - The row guard now contains m_block_idx*WARP_M_, matching the row that the store
 *     offset actually addresses (it already added m_block_idx*WARP_M_*params.seqlen_k).
 *     Without it, kBlockM_ > WARP_M_ could write rows >= seqlen_q. For
 *     kBlockM_ == WARP_M_ the term is always 0 and nothing changes.
 *   - Macro arguments are parenthesised so expressions can be passed safely.
 *
 * NOTE(review): `warp_id & (kBlockN_/WARP_N_ - 1)` equals a modulo only when
 * kBlockN_/WARP_N_ is a power of two -- confirm against the launch configurations.
 */
#define print_softmax_rescale_o(block_id_m, bidb, bidh) { \
    __builtin_amdgcn_sched_barrier(0); \
    __builtin_amdgcn_s_waitcnt(0); \
    __syncthreads(); \
    __builtin_amdgcn_sched_barrier(0); \
    int s_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
    int s_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
    int s_global_offset = (bidb)*(params.h * params.seqlen_q * params.seqlen_k) + (bidh) * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ + (block_id_m)*kBlockM_*params.seqlen_k + s_warp_n_id*WARP_N_ + s_warp_m_id*WARP_M_*params.seqlen_k; \
    for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) { \
        for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
            for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
                for(int reg_id=0; reg_id<4; reg_id++) { \
                    for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
                        for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
                            if(((n_block*kBlockN_ + s_warp_n_id*WARP_N_ + n_idx * 32 + (lane_id & 15) + min_tile_n*16) < params.seqlen_k) && \
                               (((block_id_m)*kBlockM_ + s_warp_m_id*WARP_M_ + m_block_idx*WARP_M_ + m_idx*32 + reg_id + min_tile_m*16 + ((lane_id / 16) * 4)) < params.seqlen_q)) { \
                                int offset = s_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k + (lane_id & 15) + min_tile_m*params.seqlen_k*16 + ((lane_id / 16) * 4) * params.seqlen_k + min_tile_n*16; \
                                s_ptr[offset + reg_id * params.seqlen_k] = (s_reg[m_block_idx * (WARP_M_/32) * (WARP_N_/32) + m_idx*(WARP_N_/32) + n_idx][min_tile_n + min_tile_m*2].f32[reg_id]); \
                            } \
                        } \
                    } \
                } \
            } \
        } \
    } \
    __builtin_amdgcn_sched_barrier(0); \
    __builtin_amdgcn_s_waitcnt(0); \
    __syncthreads(); \
    __builtin_amdgcn_sched_barrier(0); \
}
/*
 * print_ds(block_id_m, bidb, bidh) -- debug dump of the fp32 accumulator `dS_reg` into the
 * global buffer `ds_ptr`.
 *
 * The offset arithmetic below indexes `ds_ptr` as [bidb][bidh][row][col] with extents
 * params.h x params.seqlen_q x params.seqlen_k (row stride = params.seqlen_k).
 * Inside one 32x32 tile a lane writes
 *     col = (lane_id & 15) + min_tile_n*16
 *     row = (lane_id / 16)*4 + reg_id + min_tile_m*16
 * and elements outside params.seqlen_q / params.seqlen_k are skipped.
 *
 * Names that must be in scope at the expansion site: warp_id, lane_id, n_block, params,
 * ds_ptr, dS_reg, kBlockM_, kBlockN_, WARP_M_, WARP_N_.
 *
 * The dump is bracketed by sched_barrier / s_waitcnt(0) / __syncthreads() on both sides,
 * so the expansion must be reached by every thread of the block.
 *
 * Changes versus the previous version:
 *   - The row guard now contains m_block_idx*WARP_M_, matching the row that the store
 *     offset actually addresses (it already added m_block_idx*WARP_M_*params.seqlen_k).
 *     Without it, kBlockM_ > WARP_M_ could write rows >= seqlen_q. For
 *     kBlockM_ == WARP_M_ the term is always 0 and nothing changes.
 *   - Macro arguments are parenthesised so expressions can be passed safely.
 *
 * NOTE(review): `warp_id & (kBlockN_/WARP_N_ - 1)` equals a modulo only when
 * kBlockN_/WARP_N_ is a power of two -- confirm against the launch configurations.
 */
#define print_ds(block_id_m, bidb, bidh) { \
    __builtin_amdgcn_sched_barrier(0); \
    __builtin_amdgcn_s_waitcnt(0); \
    __syncthreads(); \
    __builtin_amdgcn_sched_barrier(0); \
    int ds_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
    int ds_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
    int ds_global_offset = (bidb)*(params.h * params.seqlen_q * params.seqlen_k) + (bidh) * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ + (block_id_m)*kBlockM_*params.seqlen_k + ds_warp_n_id*WARP_N_ + ds_warp_m_id*WARP_M_*params.seqlen_k; \
    for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) { \
        for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
            for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
                for(int reg_id=0; reg_id<4; reg_id++) { \
                    for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
                        for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
                            if(((n_block*kBlockN_ + ds_warp_n_id*WARP_N_ + n_idx * 32 + (lane_id & 15) + min_tile_n*16) < params.seqlen_k) && \
                               (((block_id_m)*kBlockM_ + ds_warp_m_id*WARP_M_ + m_block_idx*WARP_M_ + m_idx*32 + reg_id + min_tile_m*16 + ((lane_id / 16) * 4)) < params.seqlen_q)) { \
                                int offset = ds_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k + (lane_id & 15) + min_tile_m*params.seqlen_k*16 + ((lane_id / 16) * 4) * params.seqlen_k + min_tile_n*16; \
                                ds_ptr[offset + reg_id * params.seqlen_k] = (dS_reg[m_block_idx * (WARP_M_/32) * (WARP_N_/32) + m_idx*(WARP_N_/32) + n_idx][min_tile_n + min_tile_m*2].f32[reg_id]); \
                            } \
                        } \
                    } \
                } \
            } \
        } \
    } \
    __builtin_amdgcn_sched_barrier(0); \
    __builtin_amdgcn_s_waitcnt(0); \
    __syncthreads(); \
    __builtin_amdgcn_sched_barrier(0); \
}
/*
 * print_ds_fp16(block_id_m, bidb, bidh) -- debug dump of the packed low-precision tile
 * `dS_reg_fp16` into the global buffer `ds_ptr`, widening each element with
 * UpCast<Element,float,true>.
 *
 * The offset arithmetic below indexes `ds_ptr` as [bidb][bidh][row][col] with extents
 * params.h x params.seqlen_q x params.seqlen_k (row stride = params.seqlen_k).
 * Inside one 32x32 tile a lane writes
 *     col = (lane_id & 15) + min_tile_n*16
 *     row = (lane_id / 16)*4 + reg_id + min_tile_m*16
 * reading element .f16[min_tile_n*4 + reg_id] of register (tile*2 + min_tile_m).
 * Elements outside params.seqlen_q / params.seqlen_k are skipped.
 *
 * Names that must be in scope at the expansion site: warp_id, lane_id, n_block, params,
 * ds_ptr, dS_reg_fp16, Element, UpCast, kBlockM_, kBlockN_, WARP_M_, WARP_N_.
 *
 * The dump is bracketed by sched_barrier / s_waitcnt(0) / __syncthreads() on both sides,
 * so the expansion must be reached by every thread of the block.
 *
 * Changes versus the previous version:
 *   - The row guard now contains m_block_idx*WARP_M_, matching the row that the store
 *     offset actually addresses (it already added m_block_idx*WARP_M_*params.seqlen_k).
 *     Without it, kBlockM_ > WARP_M_ could write rows >= seqlen_q. For
 *     kBlockM_ == WARP_M_ the term is always 0 and nothing changes.
 *   - Macro arguments are parenthesised so expressions can be passed safely.
 *
 * NOTE(review): `warp_id & (kBlockN_/WARP_N_ - 1)` equals a modulo only when
 * kBlockN_/WARP_N_ is a power of two -- confirm against the launch configurations.
 */
#define print_ds_fp16(block_id_m, bidb, bidh) { \
    __builtin_amdgcn_sched_barrier(0); \
    __builtin_amdgcn_s_waitcnt(0); \
    __syncthreads(); \
    __builtin_amdgcn_sched_barrier(0); \
    int ds_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
    int ds_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
    int ds_global_offset = (bidb)*(params.h * params.seqlen_q * params.seqlen_k) + (bidh) * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ + (block_id_m)*kBlockM_*params.seqlen_k + ds_warp_n_id*WARP_N_ + ds_warp_m_id*WARP_M_*params.seqlen_k; \
    for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) { \
        for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
            for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
                for(int reg_id=0; reg_id<4; reg_id++) { \
                    for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
                        for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
                            if(((n_block*kBlockN_ + ds_warp_n_id*WARP_N_ + n_idx * 32 + (lane_id & 15) + min_tile_n*16) < params.seqlen_k) && \
                               (((block_id_m)*kBlockM_ + ds_warp_m_id*WARP_M_ + m_block_idx*WARP_M_ + m_idx*32 + reg_id + min_tile_m*16 + ((lane_id / 16) * 4)) < params.seqlen_q)) { \
                                int offset = ds_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k + (lane_id & 15) + min_tile_m*params.seqlen_k*16 + ((lane_id / 16) * 4) * params.seqlen_k + min_tile_n*16; \
                                ds_ptr[offset + reg_id * params.seqlen_k] = UpCast<Element,float,true>(dS_reg_fp16[(m_block_idx * (WARP_M_/32) * (WARP_N_/32) + m_idx*(WARP_N_/32) + n_idx)*2 + min_tile_m].f16[min_tile_n*4 + reg_id]); \
                            } \
                        } \
                    } \
                } \
            } \
        } \
    } \
    __builtin_amdgcn_sched_barrier(0); \
    __builtin_amdgcn_s_waitcnt(0); \
    __syncthreads(); \
    __builtin_amdgcn_sched_barrier(0); \
}
// #define print_ds_fp16(block_id_m, bidb, bidh) { \
// __builtin_amdgcn_sched_barrier(0);\
// __builtin_amdgcn_s_waitcnt(0);\
// __syncthreads();\
// __builtin_amdgcn_sched_barrier(0);\
// int ds_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
// int ds_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
// int ds_global_offset = bidb*(params.h * params.seqlen_q * params.seqlen_k) + bidh * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ + block_id_m*kBlockM_*params.seqlen_k + ds_warp_n_id*WARP_N_ + ds_warp_m_id*WARP_M_*params.seqlen_k; \
// for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) { \
// for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
// for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
// for(int reg_id=0; reg_id<4; reg_id++) { \
// for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
// for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
// if(((n_block*kBlockN_ + ds_warp_n_id*WARP_N_ + n_idx * 32 + (lane_id & 15) + min_tile_n*16) < params.seqlen_k) && \
// ((block_id_m*kBlockM_ + ds_warp_m_id*WARP_M_ + m_idx*32 + reg_id + min_tile_m*16 + ((lane_id / 16) * 4)) < params.seqlen_q)) { \
// int offset = ds_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k+ (lane_id & 15) + min_tile_m*params.seqlen_k*16 + ((lane_id / 16) * 4) *params.seqlen_k + min_tile_n*16 ; \
// ds_ptr[offset + reg_id * 8 * params.seqlen_k] = UpCast<Element,float>(dS_reg_fp16[m_block_idx * (WARP_M_/32) *(WARP_N_/32)+ m_idx*(WARP_N_/32) + n_idx][min_tile_n + min_tile_m*2].f16[reg_id]);\
// } \
// } \
// } \
// } \
// } \
// } \
// } \
// __builtin_amdgcn_sched_barrier(0);\
// __builtin_amdgcn_s_waitcnt(0);\
// __syncthreads();\
// __builtin_amdgcn_sched_barrier(0);\
// }
/*
 * print_dp(block_id_m, bidb, bidh) -- debug dump of the fp32 accumulator `dp_reg` into the
 * global buffer `dp_ptr`.
 *
 * The offset arithmetic below indexes `dp_ptr` as [bidb][bidh][row][col] with extents
 * params.h x params.seqlen_q x params.seqlen_k (row stride = params.seqlen_k).
 * Inside one 32x32 tile a lane writes
 *     col = (lane_id & 15) + min_tile_n*16
 *     row = (lane_id / 16)*4 + reg_id + min_tile_m*16
 * and elements outside params.seqlen_q / params.seqlen_k are skipped.
 *
 * Names that must be in scope at the expansion site: warp_id, lane_id, n_block, params,
 * dp_ptr, dp_reg, kBlockM_, kBlockN_, WARP_M_, WARP_N_.
 *
 * The dump is bracketed by sched_barrier / s_waitcnt(0) / __syncthreads() on both sides,
 * so the expansion must be reached by every thread of the block.
 *
 * Changes versus the previous version:
 *   - The row guard now contains m_block_idx*WARP_M_, matching the row that the store
 *     offset actually addresses (it already added m_block_idx*WARP_M_*params.seqlen_k).
 *     Without it, kBlockM_ > WARP_M_ could write rows >= seqlen_q. For
 *     kBlockM_ == WARP_M_ the term is always 0 and nothing changes.
 *   - Macro arguments are parenthesised so expressions can be passed safely.
 *
 * NOTE(review): `warp_id & (kBlockN_/WARP_N_ - 1)` equals a modulo only when
 * kBlockN_/WARP_N_ is a power of two -- confirm against the launch configurations.
 */
#define print_dp(block_id_m, bidb, bidh) { \
    __builtin_amdgcn_sched_barrier(0); \
    __builtin_amdgcn_s_waitcnt(0); \
    __syncthreads(); \
    __builtin_amdgcn_sched_barrier(0); \
    int dp_warp_m_id = (warp_id / (kBlockN_/WARP_N_)); \
    int dp_warp_n_id = (warp_id & (kBlockN_/WARP_N_ - 1)); \
    int dp_global_offset = (bidb)*(params.h * params.seqlen_q * params.seqlen_k) + (bidh) * (params.seqlen_q * params.seqlen_k) + n_block*kBlockN_ + (block_id_m)*kBlockM_*params.seqlen_k + dp_warp_n_id*WARP_N_ + dp_warp_m_id*WARP_M_*params.seqlen_k; \
    for(int n_idx=0; n_idx<(WARP_N_/32); n_idx++) { \
        for(int m_block_idx=0; m_block_idx<kBlockM_/WARP_M_; m_block_idx++) { \
            for(int m_idx=0; m_idx<(WARP_M_/32); m_idx++) { \
                for(int reg_id=0; reg_id<4; reg_id++) { \
                    for(int min_tile_n=0; min_tile_n<2; min_tile_n++) { \
                        for(int min_tile_m=0; min_tile_m<2; min_tile_m++) { \
                            if(((n_block*kBlockN_ + dp_warp_n_id*WARP_N_ + n_idx * 32 + (lane_id & 15) + min_tile_n*16) < params.seqlen_k) && \
                               (((block_id_m)*kBlockM_ + dp_warp_m_id*WARP_M_ + m_block_idx*WARP_M_ + m_idx*32 + reg_id + min_tile_m*16 + ((lane_id / 16) * 4)) < params.seqlen_q)) { \
                                int offset = dp_global_offset + n_idx * 32 + m_block_idx * WARP_M_ * params.seqlen_k + m_idx*32*params.seqlen_k + (lane_id & 15) + min_tile_m*params.seqlen_k*16 + ((lane_id / 16) * 4) * params.seqlen_k + min_tile_n*16; \
                                dp_ptr[offset + reg_id * params.seqlen_k] = (dp_reg[m_block_idx * (WARP_M_/32) * (WARP_N_/32) + m_idx*(WARP_N_/32) + n_idx][min_tile_m*2 + min_tile_n].f32[reg_id]); \
                            } \
                        } \
                    } \
                } \
            } \
        } \
    } \
    __builtin_amdgcn_sched_barrier(0); \
    __builtin_amdgcn_s_waitcnt(0); \
    __syncthreads(); \
    __builtin_amdgcn_sched_barrier(0); \
}
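/*
load q/k: the accumulation direction is the major-order direction
ps: when offset is passed as 0, the values of T and R appear to have no effect!?
calls matrix_load_32x32_b16:
R=0: offset in column direction
load Q: T=1: row major
load K: T=0: column major
m_ab=1: per-thread data is concatenated along the major-order direction
calls ds_read_matrix_trans_format (kept consistent with m_ab):
element:0x2 row:0x2 col:0x1 alt:0x0
load v: the accumulation direction is the non-major-order direction
calls matrix_load_32x32_b16:
R=0: offset in column direction
T=1: row major
m_ab=0: per-thread data is concatenated along the non-major-order direction
calls ds_read_matrix_format (kept consistent with m_ab)
*/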
template
<
class
Element
,
class
ElementAccum
,
bool
Is_dropout
,
bool
Is_causal
,
bool
Is_local
,
bool
Is_even_MN
,
bool
Is_even_K
,
bool
Is_first
,
bool
Is_last
,
bool
Seq_parallel
=
false
,
int
kBlockM_
,
int
kBlockN_
,
int
K
,
int
K_v
,
int
kBlockK_
,
int
WARP_M_
,
int
WARP_N_
,
bool
USE_BSHD_LAYOUT
,
typename
Params
>
__forceinline__
__device__
void
compute_dk_dv_1colblock_gfx946
(
Params
&
params
,
int
bidb
,
int
bidh
,
int
n_block
)
{
#ifdef DEBUGING
ElementAccum
*
kq_ptr
=
static_cast
<
ElementAccum
*>
(
params
.
kq_ptr
);
ElementAccum
*
s_ptr
=
static_cast
<
ElementAccum
*>
(
params
.
s_ptr
);
ElementAccum
*
dp_ptr
=
static_cast
<
ElementAccum
*>
(
params
.
dp_ptr
);
ElementAccum
*
ds_ptr
=
static_cast
<
ElementAccum
*>
(
params
.
ds_ptr
);
#endif
Element
*
q_ptr
=
static_cast
<
Element
*>
(
params
.
q_ptr
);
Element
*
k_ptr
=
static_cast
<
Element
*>
(
params
.
k_ptr
);
Element
*
v_ptr
=
static_cast
<
Element
*>
(
params
.
v_ptr
);
Element
*
o_ptr
=
static_cast
<
Element
*>
(
params
.
o_ptr
);
Element
*
p_ptr
=
static_cast
<
Element
*>
(
params
.
p_ptr
);
// Element* dq_ptr = static_cast<Element*>(params.dq_ptr);
Element
*
dk_ptr
=
static_cast
<
Element
*>
(
params
.
dk_ptr
);
Element
*
dv_ptr
=
static_cast
<
Element
*>
(
params
.
dv_ptr
);
Element
*
do_ptr
=
static_cast
<
Element
*>
(
params
.
do_ptr
);
ElementAccum
*
softmax_lse_ptr
=
static_cast
<
ElementAccum
*>
(
params
.
softmax_lse_ptr
);
ElementAccum
*
dsoftmax_sum
=
static_cast
<
ElementAccum
*>
(
params
.
dsoftmax_sum
);
//flash-attention QK, kBlockN_==WARP_N_;
// static_assert(kBlockM_=WARP_M_,"Error: kBlockM_ not equal WARP_M_!");
const
int
WARP_NUM
=
(
kBlockM_
*
kBlockN_
)
/
(
WARP_M_
*
WARP_N_
);
const
int
M_BLOCK_NUM
=
params
.
seqlen_q
/
kBlockM_
;
const
int
N_BLOCK_NUM
=
params
.
seqlen_k
/
kBlockN_
;
extern
__shared__
Element
smem
[];
int
K_lds_ratio
;
// 0表示k不预取;1表示k预取一半到寄存器;2表示一半到寄存器,一半到LDS;3表示全部预取到寄存器
const
int
K_prefetch_level
=
3
;
const
int
STAGES
=
2
;
const
bool
Is_store_Q
=
true
;
const
bool
Is_store_dO
=
true
;
const
bool
Is_preload_Q
=
true
;
const
bool
Is_preload_dO
=
true
;
const
int
dP_dO_prefetch_level
=
Is_store_dO
?
1
:
0
;
const
int
Q_prefetech_level
=
Is_preload_Q
?
1
:
0
;
if
constexpr
(
K_prefetch_level
==
2
){
K_lds_ratio
=
(
K
/
kBlockK_
)
/
2
;
}
else
{
K_lds_ratio
=
(
K_prefetch_level
==
3
)
?
0
:
STAGES
;
}
Element
*
K_lds
=
(
Element
*
)
&
(
smem
);
Element
*
dO_lds
=
K_lds
+
kBlockN_
*
kBlockK_
*
K_lds_ratio
;
Element
*
V_lds
=
K_prefetch_level
==
2
?
dO_lds
:
K_lds
;
Element
*
Q_lds
=
Is_store_Q
?
dO_lds
+
kBlockM_
*
K_v
:
dO_lds
;
#if 0//defined(__gfx936__)
auto pointwise_mult = [](vec2_fp32 p, vec2_fp32 dp, vec2_fp32 d) {
auto d0 = (!Is_dropout || p[0] >= 0 ? dp[0] - d[0] : d[0]);
auto d1 = (!Is_dropout || p[1] >= 0 ? dp[1] - d[1] : d[1]);
// return vec2_fp32{p[0]*d0,p[1]*d1};
// return __builtin_hcu_pk_mul_f32(p, vec2_fp32{d0, d1});
return __builtin_hcu_v_pk_mul_f32(p, vec2_fp32{d0, d1});
};
#else
auto
pointwise_mult
=
[](
float
p
,
float
dp
,
float
d
)
{
return
p
*
(
!
Is_dropout
||
p
>=
0
?
dp
-
d
:
d
);
};
#endif
int
tidx
=
threadIdx
.
x
;
int
lane_id
=
threadIdx
.
x
&
63
;
//lane id, 0-63
const
flash
::
BlockInfo
<
/*Varlen=*/
!
Is_even_MN
,
false
,
USE_BSHD_LAYOUT
>
binfo
(
params
,
bidb
);
if
(
n_block
*
kBlockN_
>=
binfo
.
actual_seqlen_k
||
binfo
.
actual_seqlen_q
==
0
)
return
;
const
int
m_block_min
=
(
!
Is_causal
&&
!
Is_local
)
?
0
:
std
::
max
(
0
,
(
n_block
*
kBlockN_
-
params
.
window_size_right
)
/
kBlockM_
);
const
int
m_block_max
=
!
Is_local
?
ceil_div
(
binfo
.
actual_seqlen_q
,
kBlockM_
)
:
std
::
min
(
ceil_div
(
binfo
.
actual_seqlen_q
,
kBlockM_
),
ceil_div
((
n_block
+
1
)
*
kBlockN_
+
params
.
window_size_left
,
kBlockM_
));
int
seqlen_q_stride
=
params
.
q_row_stride
;
int
seqlen_k_stride
=
params
.
k_row_stride
;
int
seqlen_v_stride
=
params
.
v_row_stride
;
int
seqlen_do_stride
=
params
.
do_row_stride
;
int
seqlen_o_stride
=
params
.
o_row_stride
;
int
seqlen_dk_stride
=
params
.
dk_row_stride
;
int
seqlen_dv_stride
=
params
.
dv_row_stride
;
// We move K and V to the last block.
const
int
row_offset_q
=
binfo
.
q_offset1
(
params
.
q_batch_stride
,
params
.
q_row_stride
,
bidb
)
+
binfo
.
q_offset2
(
params
.
q_head_stride
,
bidh
)
+
(
m_block_max
-
1
)
*
kBlockM_
*
seqlen_q_stride
;
const
int
row_offset_k
=
binfo
.
k_offset1
(
params
.
k_batch_stride
,
params
.
k_row_stride
,
bidb
)
+
binfo
.
k_offset2
(
params
.
k_head_stride
,
bidh
/
params
.
h_h_k_ratio
)
+
n_block
*
kBlockN_
*
seqlen_k_stride
;
const
int
row_offset_v
=
binfo
.
k_offset1
(
params
.
v_batch_stride
,
params
.
v_row_stride
,
bidb
)
+
binfo
.
k_offset2
(
params
.
v_head_stride
,
bidh
/
params
.
h_h_k_ratio
)
+
n_block
*
kBlockN_
*
seqlen_v_stride
;
const
int
row_offset_dO
=
binfo
.
q_offset1
(
params
.
do_batch_stride
,
params
.
do_row_stride
,
bidb
)
+
binfo
.
q_offset2
(
params
.
do_head_stride
,
bidh
)
+
(
m_block_max
-
1
)
*
kBlockM_
*
seqlen_do_stride
;
const
int
row_offset_o
=
binfo
.
q_offset1
(
params
.
o_batch_stride
,
params
.
o_row_stride
,
bidb
)
+
binfo
.
q_offset2
(
params
.
o_head_stride
,
bidh
)
+
(
m_block_max
-
1
)
*
kBlockM_
*
seqlen_o_stride
;
// const int row_offset_lse = (bidb * params.h + bidh) * params.seqlen_q + (m_block_max - 1) * kBlockM_;
const
int
row_offset_lse
=
params
.
cu_seqlens_q
==
nullptr
?
(
bidb
*
params
.
h
+
bidh
)
*
params
.
seqlen_q
+
(
m_block_max
-
1
)
*
kBlockM_
:
bidh
*
params
.
total_q
+
binfo
.
sum_s_q
+
(
m_block_max
-
1
)
*
kBlockM_
;
const
int
row_offset_dpsum
=
(
bidb
*
params
.
h
+
bidh
)
*
params
.
seqlen_q_rounded
+
(
m_block_max
-
1
)
*
kBlockM_
;
auto
gQ
=
prepare_for_matrix_load_gfx938
<
Element
>
(
reinterpret_cast
<
Element
*>
(
q_ptr
)
+
row_offset_q
,
seqlen_q_stride
);
auto
gK
=
prepare_for_matrix_load_gfx938
<
Element
>
(
reinterpret_cast
<
Element
*>
(
k_ptr
)
+
row_offset_k
,
seqlen_k_stride
);
auto
gV
=
prepare_for_matrix_load_gfx938
<
Element
>
(
reinterpret_cast
<
Element
*>
(
v_ptr
)
+
row_offset_v
,
seqlen_v_stride
);
auto
gdO
=
prepare_for_matrix_load_gfx938
<
Element
>
(
reinterpret_cast
<
Element
*>
(
do_ptr
)
+
row_offset_dO
,
seqlen_do_stride
);
Element
*
gO
=
reinterpret_cast
<
Element
*>
(
o_ptr
)
+
row_offset_o
;
ElementAccum
*
gLSE
=
reinterpret_cast
<
ElementAccum
*>
(
softmax_lse_ptr
)
+
row_offset_lse
;
ElementAccum
*
gdPsum
=
reinterpret_cast
<
ElementAccum
*>
(
dsoftmax_sum
)
+
row_offset_dpsum
;
constexpr
int
m_masking_steps
=
(
!
Is_causal
&&
!
Is_local
)
?
0
:
flash
::
ceil_div
(
kBlockN_
,
kBlockM_
);
/***************************************************************************************************************************/
// int warp_id =0;
int
warp_id_vec
=
threadIdx
.
x
/
64
;
//warp id in a block
int
warp_id
=
__builtin_amdgcn_readfirstlane
(
warp_id_vec
);
union_vec4_f16x2
<
Element
>
k_reg
[(
K
/
kBlockK_
)
*
((
WARP_N_
*
kBlockK_
)
/
(
32
*
32
))
*
2
/
((
K_prefetch_level
==
3
)
?
1
:
2
)];
//ds_read mini size is 32*32,2 is seq, 4 is head dim
union_vec4_f16x2
<
Element
>
v_reg
[(
K_v
/
kBlockK_
)
*
((
WARP_N_
*
kBlockK_
)
/
(
32
*
32
))
*
2
];
//提前读取V到vgpr
prefetch_to_vgpr_gfx938
<
true
,
kBlockN_
,
K
,
Element
,
ElementAccum
,
Is_even_MN
>
(
gV
,
V_lds
,
v_reg
,
(
binfo
.
actual_seqlen_k
-
n_block
*
kBlockN_
),
warp_id
);
//提前读取K到vgpr
prefetch_to_vgpr_gfx938
<
true
,
kBlockN_
,
K
,
Element
,
ElementAccum
,
Is_even_MN
>
(
gK
,
K_lds
,
k_reg
,
(
binfo
.
actual_seqlen_k
-
n_block
*
kBlockN_
),
warp_id
);
//提前读取Q到lds
if
constexpr
(
Is_preload_Q
){
prefetch_to_lds_gfx938
<
true
,
kBlockM_
,
K
,
Element
,
ElementAccum
,
Is_even_MN
>
(
gQ
,
0
,
Q_lds
,
(
binfo
.
actual_seqlen_q
-
(
m_block_max
-
1
)
*
kBlockM_
),
warp_id
);
}
//提前读取dO到lds
if
constexpr
(
Is_preload_dO
){
prefetch_to_lds_gfx938
<
true
,
kBlockM_
,
K_v
,
Element
,
ElementAccum
,
Is_even_MN
>
(
gdO
,
0
,
dO_lds
,
(
binfo
.
actual_seqlen_q
-
(
m_block_max
-
1
)
*
kBlockM_
),
warp_id
);
}
// __builtin_amdgcn_s_waitcnt(0);
// __syncthreads();
union_vec4_fp32
acc_dv
[(
K_v
/
kBlockK_
)
*
((
WARP_N_
/
32
)
*
(
kBlockK_
/
32
))][
4
]
=
{
0
};
union_vec4_fp32
acc_dk
[(
K
/
kBlockK_
)
*
((
WARP_N_
/
32
)
*
(
kBlockK_
/
32
))][
4
]
=
{
0
};
for
(
int
m_block
=
m_block_max
-
1
;
m_block
>=
m_block_min
;
--
m_block
)
{
union_vec4_f16x2
<
Element
>
q_reg
[((
WARP_M_
*
kBlockK_
)
/
(
32
*
32
))
*
2
];
//c mini tile is 32*32
union_vec4_fp32
s_reg
[(
WARP_N_
/
32
)
*
(
kBlockM_
/
32
)][
4
]
=
{
0
};
/*
qk gemm
结果矩阵layout:
0 0 0 0 16 16 16 16 32 32 32 32 48 48 48 48 0 0 0 0 16 16 16 16 32 32 32 32 48 48 48 48
1 1 1 1 17 17 17 17 33 33 33 33 49 49 49 49 1 1 1 1 17 17 17 17 33 33 33 33 49 49 49 49
...
0 0 0 0 16 16 16 16 32 32 32 32 48 48 48 48 0 0 0 0 16 16 16 16 32 32 32 32 48 48 48 48
1 1 1 1 17 17 17 17 33 33 33 33 49 49 49 49 1 1 1 1 17 17 17 17 33 33 33 33 49 49 49 49
*/
gemm_tt_kq_gfx938
<
Is_store_Q
,
Is_preload_dO
,
Is_even_MN
,
K_prefetch_level
,
Q_prefetech_level
,
K
,
kBlockN_
,
kBlockM_
,
kBlockK_
,
WARP_N_
,
WARP_M_
,
STAGES
,
Element
>
(
gK
,
gQ
,
K_lds
,
Q_lds
,
(
binfo
.
actual_seqlen_k
-
n_block
*
kBlockN_
),
(
binfo
.
actual_seqlen_q
-
m_block
*
kBlockM_
),
k_reg
,
q_reg
,
s_reg
,
warp_id
,
seqlen_k_stride
,
seqlen_q_stride
);
/*
lse layout:
4 warp:
32
32
32
32
因为warp在seqlen_k维度,所以不区分warp
每16个thread持有相同的lse,所以需要/4
*/
float
lse
[
kBlockM_
/
4
];
#pragma unroll
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
vec_idx
++
)
{
#pragma unroll
for
(
int
mi
=
0
;
mi
<
(
kBlockM_
/
32
);
++
mi
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
min_tile_m
++
)
{
const
int
lse_idx
=
mi
*
32
+
min_tile_m
*
16
+
(
lane_id
>>
4
)
*
4
+
vec_idx
;
lse
[(
mi
*
2
+
min_tile_m
)
*
4
+
vec_idx
]
=
Is_even_MN
||
lse_idx
<
binfo
.
actual_seqlen_q
-
m_block
*
kBlockM_
?
gLSE
[
lse_idx
]
:
INFINITY
;
}
}
}
apply_mask_bwd_gfx938
<
Is_even_MN
,
Is_local
?
3
:
(
Is_causal
?
2
:
0
)
>
(
s_reg
,
binfo
.
actual_seqlen_k
-
n_block
*
kBlockN_
-
warp_id
*
32
,
binfo
.
actual_seqlen_q
-
m_block
*
kBlockM_
,
(
n_block
*
kBlockN_
+
warp_id
*
32
)
-
m_block
*
kBlockM_
,
params
.
window_size_right
,
params
.
window_size_left
);
#ifdef DEBUGING
print_kq
(
m_block
,
bidb
,
bidh
);
#endif
//do . o后在headdim维度reduce求和,读取方式和lse一样,因为pad了,所以无需边界判断
float
dP_sum_reg
[
kBlockM_
/
4
];
#pragma unroll
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
vec_idx
++
)
{
#pragma unroll
for
(
int
mi
=
0
;
mi
<
(
kBlockM_
/
32
);
++
mi
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
min_tile_m
++
)
{
const
int
dPsum_idx
=
mi
*
32
+
min_tile_m
*
16
+
(
lane_id
>>
4
)
*
4
+
vec_idx
;
dP_sum_reg
[(
mi
*
2
+
min_tile_m
)
*
4
+
vec_idx
]
=
gdPsum
[
dPsum_idx
];
}
}
}
{
scale_apply_exp2_bwd
<
/*scale_max=*/
false
,
kBlockM_
,
WARP_N_
>
(
s_reg
,
lse
,
params
.
scale_softmax_log2
);
}
#ifdef DEBUGING
print_softmax_rescale_o
(
m_block
,
bidb
,
bidh
);
#endif
// //TODO:drop
union_vec4_f16x2
<
Element
>
p_reg
[(
kBlockM_
/
32
)
*
(
WARP_N_
/
32
)
*
2
];
// convert_pk_type<kBlockM_, WARP_N_, Element>(p_reg, s_reg);
convert_pk_type_gfx938
<
kBlockM_
,
WARP_N_
,
Element
>
(
p_reg
,
s_reg
);
//QK(seq_q, seq_kv), seq_q is continuous, seq_kv is not continuous
// __builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_s_waitcnt(0);
// __syncthreads();
// __builtin_amdgcn_sched_barrier(0);
{
//dv gemm,dO*P
gpu_gemm_B_in_reg_gfx946
<
Is_preload_dO
,
Is_store_dO
,
Is_even_MN
,
K_v
,
kBlockK_
,
kBlockN_
,
kBlockM_
,
kBlockK_
,
WARP_N_
,
2
,
Element
>
(
gdO
,
gQ
,
dO_lds
,
p_reg
,
acc_dv
,
(
binfo
.
actual_seqlen_k
-
n_block
*
kBlockN_
),
(
binfo
.
actual_seqlen_q
-
m_block
*
kBlockM_
),
warp_id
,
seqlen_do_stride
);
}
// __builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_s_waitcnt(0);
// __syncthreads();
// __builtin_amdgcn_sched_barrier(0);
union_vec4_f16x2
<
Element
>
dO_reg
[((
WARP_M_
*
kBlockK_
)
/
(
32
*
32
))
*
2
];
union_vec4_fp32
dp_reg
[(
WARP_N_
/
32
)
*
(
kBlockM_
/
32
)][
4
]
=
{
0
};
{
// dP gemm dO * V
gemm_tt_kq_gfx938
<
Is_store_dO
,
false
,
Is_even_MN
,
3
,
dP_dO_prefetch_level
,
K_v
,
kBlockN_
,
kBlockM_
,
kBlockK_
,
WARP_N_
,
WARP_M_
,
STAGES
,
Element
>
(
gV
,
gdO
,
V_lds
,
dO_lds
,
(
binfo
.
actual_seqlen_k
-
n_block
*
kBlockN_
),
(
binfo
.
actual_seqlen_q
-
m_block
*
kBlockM_
),
v_reg
,
dO_reg
,
dp_reg
,
warp_id
,
seqlen_v_stride
,
seqlen_do_stride
);
}
#ifdef DEBUGING
print_dp
(
m_block
,
bidb
,
bidh
);
#endif
union_vec4_fp32
dS_reg
[(
WARP_N_
/
32
)
*
(
kBlockM_
/
32
)][
4
];
#pragma unroll
for
(
int
mi
=
0
;
mi
<
(
kBlockM_
/
32
);
++
mi
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
min_tile_m
++
)
{
#pragma unroll
for
(
int
ni
=
0
;
ni
<
(
WARP_N_
/
32
);
++
ni
)
{
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
min_tile_n
++
)
{
#if 0//defined(__gfx936__)
#pragma unroll
for(int vec_idx=0; vec_idx<2; vec_idx++) {
// result register ds_reg reuse dp_reg
dS_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].u64[vec_idx] = pointwise_mult(
s_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].u64[vec_idx],
dp_reg[ni + mi*(WARP_N_/32)][min_tile_m*2 + min_tile_n].u64[vec_idx],
vec2_fp32{gdPsum[vec_idx*16 + mi*8*4 + ((lane_id >> 4)*2) + min_tile_m], gdPsum[vec_idx*16 + mi*8*4 + ((lane_id >> 4)*2) + min_tile_m + 8]});
}
#else
#pragma unroll
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
vec_idx
++
)
{
// result register ds_reg reuse dp_reg
dS_reg
[
ni
+
mi
*
(
WARP_N_
/
32
)][
min_tile_m
*
2
+
min_tile_n
].
f32
[
vec_idx
]
=
pointwise_mult
(
s_reg
[
ni
+
mi
*
(
WARP_N_
/
32
)][
min_tile_m
*
2
+
min_tile_n
].
f32
[
vec_idx
],
dp_reg
[
ni
+
mi
*
(
WARP_N_
/
32
)][
min_tile_m
*
2
+
min_tile_n
].
f32
[
vec_idx
],
dP_sum_reg
[
min_tile_m
*
4
+
vec_idx
]);
}
#endif
}
}
}
}
// #ifdef DEBUGING
// print_ds(m_block, bidb, bidh);
// #endif
union_vec4_f16x2
<
Element
>
dS_reg_fp16
[(
WARP_N_
/
32
)
*
(
kBlockM_
/
32
)
*
2
];
convert_pk_type_gfx938
<
kBlockM_
,
WARP_N_
,
Element
>
(
dS_reg_fp16
,
dS_reg
);
// #ifdef DEBUGING
// print_ds_fp16(m_block, bidb, bidh);
// #endif
// __builtin_amdgcn_sched_barrier(0);
// __builtin_amdgcn_s_waitcnt(0);
// __syncthreads();
// __builtin_amdgcn_sched_barrier(0);
{
//dk gemm, Q*dS
gpu_gemm_B_in_reg_gfx946
<
Is_store_Q
,
false
,
Is_even_MN
,
K
,
kBlockK_
,
kBlockN_
,
kBlockM_
,
kBlockK_
,
WARP_N_
,
2
,
Element
>
(
gQ
,
gdO
,
Q_lds
,
dS_reg_fp16
,
acc_dk
,
(
binfo
.
actual_seqlen_k
-
n_block
*
kBlockN_
),
(
binfo
.
actual_seqlen_q
-
m_block
*
kBlockM_
),
warp_id
,
seqlen_q_stride
);
}
gLSE
=
gLSE
+
(
-
int
(
kBlockM_
));
gdPsum
=
gdPsum
-
kBlockM_
;
*
(
uint64_t
*
)
&
gQ
-=
((
kBlockM_
*
seqlen_q_stride
)
*
sizeof
(
Element
));
*
(
uint64_t
*
)
&
gdO
-=
((
kBlockM_
*
seqlen_do_stride
)
*
sizeof
(
Element
));
{
__syncthreads
();
if
(
Is_preload_Q
&&
m_block
>
m_block_min
){
prefetch_to_lds_gfx938
<
true
,
kBlockM_
,
K
,
Element
,
ElementAccum
,
Is_even_MN
>
(
gQ
,
0
,
Q_lds
,
(
binfo
.
actual_seqlen_q
-
(
m_block
-
1
)
*
kBlockM_
),
warp_id
);
}
// __syncthreads();
if
(
Is_preload_dO
&&
m_block
>
m_block_min
){
prefetch_to_lds_gfx938
<
true
,
kBlockM_
,
K_v
,
Element
,
ElementAccum
,
Is_even_MN
>
(
gdO
,
0
,
dO_lds
,
(
binfo
.
actual_seqlen_q
-
(
m_block
-
1
)
*
kBlockM_
),
warp_id
);
}
}
}
#if 1
//这是正常的MLS+ds_read_matrix的layout
{
// dv_ptr = dv_ptr + binfo.k_offset1(params.dv_batch_stride, params.dv_row_stride, bidb)*params.h_h_k_ratio + binfo.k_offset2(params.dv_head_stride,bidh);
dv_ptr
=
dv_ptr
+
binfo
.
k_offset1_write
(
params
.
dv_batch_stride
,
params
.
dv_row_stride
,
bidb
)
+
binfo
.
k_offset2
(
params
.
dv_head_stride
,
bidh
);
auto
gdV
=
tcp_cache_swizzle_func
<
K_v
,
Element
>
(
dv_ptr
);
int
dv_lane_seq_idx
=
(
lane_id
>>
4
);
int
dv_lane_head_dim_idx
=
(
lane_id
&
15
);
int
dv_global_addr_offset
=
0
;
#pragma unroll
for
(
int
k_loop
=
0
;
k_loop
<
(
K_v
/
kBlockK_
);
k_loop
++
)
{
#pragma unroll
for
(
int
warp_n_idx
=
0
;
warp_n_idx
<
(
WARP_N_
/
32
);
warp_n_idx
++
)
{
#pragma unroll
for
(
int
k_tile_idx
=
0
;
k_tile_idx
<
(
kBlockK_
/
32
);
k_tile_idx
++
)
{
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
min_tile_n
++
)
{
#pragma unroll
for
(
int
vec_index
=
0
;
vec_index
<
4
;
vec_index
++
)
{
int
v_offset
=
dv_lane_head_dim_idx
*
seqlen_dv_stride
+
dv_lane_seq_idx
*
4
;
int
s_offset
=
(
min_tile_n
*
seqlen_dv_stride
*
16
+
vec_index
%
2
*
2
+
vec_index
/
2
*
16
)
+
(
k_tile_idx
*
32
)
+
((
warp_id
*
WARP_N_
+
warp_n_idx
*
32
)
*
seqlen_dv_stride
)
+
(
k_loop
*
kBlockK_
+
n_block
*
kBlockN_
*
seqlen_dv_stride
);
int
known_offset
=
0
;
vec2_Element
<
Element
>
v_data
;
v_data
[
0
]
=
DownCast
<
float
,
Element
,
true
>
(
acc_dv
[
k_loop
*
((
WARP_N_
/
32
)
*
(
kBlockK_
/
32
))
+
(
warp_n_idx
*
(
kBlockK_
/
32
)
+
k_tile_idx
)][
min_tile_n
*
2
+
vec_index
/
2
].
f32
[
vec_index
%
2
*
2
]);
v_data
[
1
]
=
DownCast
<
float
,
Element
,
true
>
(
acc_dv
[
k_loop
*
((
WARP_N_
/
32
)
*
(
kBlockK_
/
32
))
+
(
warp_n_idx
*
(
kBlockK_
/
32
)
+
k_tile_idx
)][
min_tile_n
*
2
+
vec_index
/
2
].
f32
[
vec_index
%
2
*
2
+
1
]);
if
(
Is_even_MN
||
min_tile_n
*
16
+
(
warp_id
*
WARP_N_
+
warp_n_idx
*
32
)
+
n_block
*
kBlockN_
+
dv_lane_head_dim_idx
<
binfo
.
actual_seqlen_k
){
inline_buffer_store_dword_glc_slc
<
vec2_Element
<
Element
>
,
1
>
(
v_data
,
v_offset
,
gdV
,
s_offset
,
/* immediate integer */
known_offset
);
}
}
}
}
}
}
}
#endif
#if 1
//这是正常的MLS+ds_read_matrix的layout
{
dk_ptr
=
dk_ptr
+
binfo
.
k_offset1_write
(
params
.
dk_batch_stride
,
params
.
dk_row_stride
,
bidb
)
+
binfo
.
k_offset2
(
params
.
dk_head_stride
,
bidh
);
auto
gdK
=
tcp_cache_swizzle_func
<
K_v
,
Element
>
(
dk_ptr
);
int
dk_lane_seq_idx
=
(
lane_id
>>
4
);
int
dk_lane_head_dim_idx
=
(
lane_id
&
15
);
int
dk_global_addr_offset
=
0
;
#pragma unroll
for
(
int
k_loop
=
0
;
k_loop
<
(
K_v
/
kBlockK_
);
k_loop
++
)
{
#pragma unroll
for
(
int
warp_n_idx
=
0
;
warp_n_idx
<
(
WARP_N_
/
32
);
warp_n_idx
++
)
{
#pragma unroll
for
(
int
k_tile_idx
=
0
;
k_tile_idx
<
(
kBlockK_
/
32
);
k_tile_idx
++
)
{
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
min_tile_n
++
)
{
#pragma unroll
for
(
int
vec_index
=
0
;
vec_index
<
4
;
vec_index
++
)
{
int
v_offset
=
dk_lane_head_dim_idx
*
seqlen_dk_stride
+
dk_lane_seq_idx
*
4
;
int
s_offset
=
(
min_tile_n
*
seqlen_dk_stride
*
16
+
vec_index
%
2
*
2
+
vec_index
/
2
*
16
)
+
(
k_tile_idx
*
32
)
+
((
warp_id
*
WARP_N_
+
warp_n_idx
*
32
)
*
seqlen_dk_stride
)
+
(
k_loop
*
kBlockK_
+
n_block
*
kBlockN_
*
seqlen_dk_stride
);
int
known_offset
=
0
;
vec2_Element
<
Element
>
v_data
;
v_data
[
0
]
=
DownCast
<
float
,
Element
,
true
>
(
acc_dk
[
k_loop
*
((
WARP_N_
/
32
)
*
(
kBlockK_
/
32
))
+
(
warp_n_idx
*
(
kBlockK_
/
32
)
+
k_tile_idx
)][
min_tile_n
*
2
+
vec_index
/
2
].
f32
[
vec_index
%
2
*
2
]
*
params
.
scale_softmax_rp_dropout
);
v_data
[
1
]
=
DownCast
<
float
,
Element
,
true
>
(
acc_dk
[
k_loop
*
((
WARP_N_
/
32
)
*
(
kBlockK_
/
32
))
+
(
warp_n_idx
*
(
kBlockK_
/
32
)
+
k_tile_idx
)][
min_tile_n
*
2
+
vec_index
/
2
].
f32
[
vec_index
%
2
*
2
+
1
]
*
params
.
scale_softmax_rp_dropout
);
if
(
Is_even_MN
||
min_tile_n
*
16
+
(
warp_id
*
WARP_N_
+
warp_n_idx
*
32
)
+
n_block
*
kBlockN_
+
dk_lane_head_dim_idx
<
binfo
.
actual_seqlen_k
){
inline_buffer_store_dword_glc_slc
<
vec2_Element
<
Element
>
,
1
>
(
v_data
,
v_offset
,
gdK
,
s_offset
,
/* immediate integer */
known_offset
);
}
}
}
}
}
}
}
#endif
// #if 1
// {
// // dv_ptr = dv_ptr + binfo.k_offset1(params.dv_batch_stride, params.dv_row_stride, bidb)*params.h_h_k_ratio + binfo.k_offset2(params.dv_head_stride,bidh);
// dv_ptr = dv_ptr + binfo.k_offset1_write(params.dv_batch_stride, params.dv_row_stride, bidb) + binfo.k_offset2(params.dv_head_stride,bidh);
// auto gdV = tcp_cache_swizzle_func<K_v, Element>(dv_ptr);
// int dv_lane_seq_idx = (lane_id >> 4);
// int dv_lane_head_dim_idx = (lane_id & 15);
// int dv_global_addr_offset=0;
// #pragma unroll
// for(int k_loop = 0; k_loop<(K_v/kBlockK_); k_loop++) {
// #pragma unroll
// for(int warp_n_idx=0; warp_n_idx<(WARP_N_/32); warp_n_idx++) {
// #pragma unroll
// for(int k_tile_idx=0; k_tile_idx<(kBlockK_/32); k_tile_idx++) {
// #pragma unroll
// for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
// #pragma unroll
// for(int vec_index=0; vec_index<4; vec_index++) {
// int v_offset = dv_lane_head_dim_idx * seqlen_dv_stride + dv_lane_seq_idx * 8;
// int s_offset = (min_tile_n * seqlen_dv_stride * 16 + vec_index * 2) + (k_tile_idx*32) + ((warp_id*WARP_N_ + warp_n_idx*32) * seqlen_dv_stride) + (k_loop * kBlockK_ + n_block * kBlockN_ * seqlen_dv_stride);
// int known_offset = 0;
// vec2_Element<Element> v_data;
// v_data[0] = DownCast<float,Element,true>(acc_dv[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + 0].f32[vec_index]);
// v_data[1] = DownCast<float,Element,true>(acc_dv[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + 1].f32[vec_index]);
// if (Is_even_MN || min_tile_n*16 + (warp_id*WARP_N_ + warp_n_idx*32) + n_block * kBlockN_ + dv_lane_head_dim_idx < binfo.actual_seqlen_k){
// inline_buffer_store_dword_glc_slc<vec2_Element<Element>, 1>(v_data, v_offset, gdV, s_offset, /* immediate integer */known_offset);
// }
// }
// }
// }
// }
// }
// }
// #endif
// // //test only
// // {
// // // dv_ptr = dv_ptr + binfo.k_offset1(params.dv_batch_stride, params.dv_row_stride, bidb)*params.h_h_k_ratio + binfo.k_offset2(params.dv_head_stride,bidh);
// // dv_ptr = dv_ptr + binfo.k_offset1_write(params.dv_batch_stride, params.dv_row_stride, bidb) + binfo.k_offset2(params.dv_head_stride,bidh);
// // auto gdV = tcp_cache_swizzle_func<K_v, Element>(dv_ptr);
// // int dv_lane_seq_idx = (lane_id >> 4);
// // int dv_lane_head_dim_idx = (lane_id & 15);
// // int dv_global_addr_offset=0;
// // #pragma unroll
// // for(int k_loop = 0; k_loop<(K_v/kBlockK_); k_loop++) {
// // #pragma unroll
// // for(int warp_n_idx=0; warp_n_idx<(WARP_N_/32); warp_n_idx++) {
// // #pragma unroll
// // for(int k_tile_idx=0; k_tile_idx<(kBlockK_/32); k_tile_idx++) {
// // #pragma unroll
// // for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
// // #pragma unroll
// // for(int vec_index=0; vec_index<4; vec_index++) {
// // // int v_offset = dv_lane_head_dim_idx * seqlen_dv_stride + dv_lane_seq_idx * 8;
// // // int s_offset = (min_tile_n * seqlen_dv_stride * 16 + vec_index * 2) + (k_tile_idx*32) + ((warp_id*WARP_N_ + warp_n_idx*32) * seqlen_dv_stride) + (k_loop * kBlockK_ + n_block * kBlockN_ * seqlen_dv_stride);
// // int v_offset = dv_lane_head_dim_idx * 2 + dv_lane_seq_idx * 4 * seqlen_dv_stride;
// // int s_offset = (min_tile_n * seqlen_dv_stride * 16 + vec_index * seqlen_dv_stride) + (k_tile_idx*32) + ((warp_id*WARP_N_ + warp_n_idx*32) * seqlen_dv_stride) + (k_loop * kBlockK_ + n_block * kBlockN_ * seqlen_dv_stride);
// // int known_offset = 0;
// // vec2_Element<Element> v_data;
// // v_data[0] = DownCast<float,Element,true>(acc_dv[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + vec_index / 2].f32[vec_index % 2 * 2]);
// // v_data[1] = DownCast<float,Element,true>(acc_dv[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + vec_index / 2].f32[vec_index % 2 * 2 + 1]);
// // inline_buffer_store_dword_glc_slc<vec2_Element<Element>, 1>(v_data, v_offset, gdV, s_offset, /* immediate integer */known_offset);
// // }
// // }
// // }
// // }
// // }
// // }
// {
// // dk_ptr = dk_ptr + binfo.k_offset1(params.dk_batch_stride, params.dk_row_stride, bidb)*params.h_h_k_ratio + binfo.k_offset2(params.dk_head_stride,bidh);
// dk_ptr = dk_ptr + binfo.k_offset1_write(params.dk_batch_stride, params.dk_row_stride, bidb) + binfo.k_offset2(params.dk_head_stride,bidh);
// auto gdK = tcp_cache_swizzle_func<K, Element>(dk_ptr);
// int dk_lane_seq_idx = (lane_id >> 4);
// int dk_lane_head_dim_idx = (lane_id & 15);
// int dk_global_addr_offset=0;
// #pragma unroll
// for(int k_loop = 0; k_loop<(K/kBlockK_); k_loop++) {
// #pragma unroll
// for(int warp_n_idx=0; warp_n_idx<(WARP_N_/32); warp_n_idx++) {
// #pragma unroll
// for(int k_tile_idx=0; k_tile_idx<(kBlockK_/32); k_tile_idx++) {
// #pragma unroll
// for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
// #pragma unroll
// for(int vec_index=0; vec_index<4; vec_index++) {
// vec2_Element<Element> v_data;
// int v_offset = dk_lane_head_dim_idx * seqlen_dk_stride + dk_lane_seq_idx * 8;
// int s_offset = (min_tile_n * seqlen_dk_stride * 16 + vec_index * 2) + (k_tile_idx*32) + ((warp_id*WARP_N_ + warp_n_idx*32) * seqlen_dk_stride) + (k_loop * kBlockK_ + n_block * kBlockN_ * seqlen_dk_stride);
// int known_offset = 0;
// v_data[0] = DownCast<float,Element,true>(acc_dk[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + 0].f32[vec_index] * params.scale_softmax_rp_dropout);
// v_data[1] = DownCast<float,Element,true>(acc_dk[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + 1].f32[vec_index] * params.scale_softmax_rp_dropout);
// if (Is_even_MN || min_tile_n*16 + (warp_id*WARP_N_ + warp_n_idx*32) + n_block * kBlockN_ + dk_lane_head_dim_idx < binfo.actual_seqlen_k){
// inline_buffer_store_dword_glc_slc<vec2_Element<Element>, 1>(v_data, v_offset, gdK, s_offset, /* immediate integer */known_offset);
// }
// }
// }
// }
// }
// }
// }
// {
// // dv_ptr = dv_ptr + binfo.k_offset1(params.dv_batch_stride, params.dv_row_stride, bidb)*params.h_h_k_ratio + binfo.k_offset2(params.dv_head_stride,bidh);
// dv_ptr = dv_ptr + binfo.k_offset1_write(params.dv_batch_stride, params.dv_row_stride, bidb) + binfo.k_offset2(params.dv_head_stride,bidh);
// auto gdV = tcp_cache_swizzle_func<K_v, Element>(dv_ptr);
// int dv_lane_seq_idx = (lane_id >> 4);
// int dv_lane_head_dim_idx = (lane_id & 15);
// int dv_global_addr_offset=0;
// #pragma unroll
// for(int k_loop = 0; k_loop<(K_v/kBlockK_); k_loop++) {
// #pragma unroll
// for(int warp_n_idx=0; warp_n_idx<(WARP_N_/32); warp_n_idx++) {
// #pragma unroll
// for(int k_tile_idx=0; k_tile_idx<(kBlockK_/32); k_tile_idx++) {
// #pragma unroll
// for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
// #pragma unroll
// for(int vec_index=0; vec_index<4; vec_index++) {
// int v_offset = dv_lane_head_dim_idx*2 + dv_lane_seq_idx * seqlen_dv_stride;
// int s_offset = (min_tile_n*seqlen_dv_stride*16 + vec_index * 4 * seqlen_dv_stride) + (k_tile_idx*32) + ((warp_id*WARP_N_ + warp_n_idx*32) * seqlen_dv_stride) + (k_loop * kBlockK_ + n_block * kBlockN_ * seqlen_dv_stride);
// int known_offset = 0;
// vec2_Element<Element> v_data;
// v_data[0] = DownCast<float,Element,true>(acc_dv[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2].f32[vec_index]);
// v_data[1] = DownCast<float,Element,true>(acc_dv[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + 1].f32[vec_index]);
// if (Is_even_MN || n_block * kBlockN_ + warp_id*WARP_N_ + warp_n_idx*32 + dv_lane_seq_idx + min_tile_n*16 + vec_index * 4 < binfo.actual_seqlen_k){
// inline_buffer_store_dword_glc_slc<vec2_Element<Element>, 1>(v_data, v_offset, gdV, s_offset, /* immediate integer */known_offset);
// }
// }
// }
// }
// }
// }
// }
// {
// // dk_ptr = dk_ptr + binfo.k_offset1(params.dk_batch_stride, params.dk_row_stride, bidb)*params.h_h_k_ratio + binfo.k_offset2(params.dk_head_stride,bidh);
// dk_ptr = dk_ptr + binfo.k_offset1_write(params.dk_batch_stride, params.dk_row_stride, bidb) + binfo.k_offset2(params.dk_head_stride,bidh);
// auto gdK = tcp_cache_swizzle_func<K, Element>(dk_ptr);
// int dk_lane_seq_idx = (lane_id >> 4);
// int dk_lane_head_dim_idx = (lane_id & 15);
// int dk_global_addr_offset=0;
// #pragma unroll
// for(int k_loop = 0; k_loop<(K/kBlockK_); k_loop++) {
// #pragma unroll
// for(int warp_n_idx=0; warp_n_idx<(WARP_N_/32); warp_n_idx++) {
// #pragma unroll
// for(int k_tile_idx=0; k_tile_idx<(kBlockK_/32); k_tile_idx++) {
// #pragma unroll
// for(int min_tile_n=0; min_tile_n<2; min_tile_n++) {
// #pragma unroll
// for(int vec_index=0; vec_index<4; vec_index++) {
// vec2_Element<Element> v_data;
// int v_offset = dk_lane_head_dim_idx*2 + dk_lane_seq_idx * seqlen_dk_stride;
// int s_offset = n_block * kBlockN_ * seqlen_dk_stride + (warp_id*WARP_N_) * seqlen_dk_stride + (min_tile_n*seqlen_dk_stride*16 + vec_index * 4 * seqlen_dk_stride + k_tile_idx*32 + k_loop * kBlockK_ + warp_n_idx*32);
// int known_offset = 0;
// v_data[0] = DownCast<float,Element,true>(acc_dk[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2].f32[vec_index] * params.scale_softmax_rp_dropout);
// v_data[1] = DownCast<float,Element,true>(acc_dk[k_loop * ((WARP_N_/32)*(kBlockK_/32)) + (warp_n_idx*(kBlockK_/32) + k_tile_idx)][min_tile_n*2 + 1].f32[vec_index] * params.scale_softmax_rp_dropout);
// if (Is_even_MN || n_block * kBlockN_ + warp_id*WARP_N_ + dk_lane_seq_idx + min_tile_n*16 + vec_index * 4 < binfo.actual_seqlen_k){
// inline_buffer_store_dword_glc_slc<vec2_Element<Element>, 1>(v_data, v_offset, gdK, s_offset, /* immediate integer */known_offset);
// }
// }
// }
// }
// }
// }
// }
}
// Remove the print_* macro names so they are no longer defined past this point.
// NOTE(review): presumably these are the debug-print helpers referenced by the
// commented-out "#ifdef DEBUGING" blocks in the kernel; their definitions are
// not visible here - confirm.
#undef print_dq
#undef print_softmax_rescale_o
#undef print_ds
#undef print_ds_fp16
#undef print_dp
csrc/flash_attn_hg/include/bwd/gpu_gemm_nn.h
View file @
518a5f4d
...
...
@@ -401,13 +401,11 @@ __forceinline__ __device__ void gpu_gemm_B_in_reg_gfx938(
int
A_lds_stage_offset
=
stage_id
*
BLOCK_K
*
BLOCK_M
;
// DS_READ_MATRIX_32X32_B16(ds_offset_cast(A_lds + A_lds_stage_offset), A_reg[0].f16, A_reg[1].f16, false);
if
constexpr
(
std
::
is_same_v
<
Element
,
half_t
>
)
{
auto
*
const
f16_lds
=
hcu_ds_read_matrix_f16_lds_base
(
A_lds
+
A_lds_stage_offset
);
A_reg
[
0
].
f16x8
=
__builtin_hcu_ds_read_matrix_format_f16
(
f16_lds
,
0
,
2
,
1
,
0
);
A_reg
[
1
].
f16x8
=
__builtin_hcu_ds_read_matrix_format_f16
(
f16_lds
,
1024
,
2
,
1
,
0
);
A_reg
[
0
].
f16x8
=
__builtin_hcu_ds_read_matrix_format_f16
(
A_lds
+
A_lds_stage_offset
,
0
,
2
,
1
,
0
);
A_reg
[
1
].
f16x8
=
__builtin_hcu_ds_read_matrix_format_f16
(
A_lds
+
A_lds_stage_offset
,
1024
,
2
,
1
,
0
);
}
else
{
auto
*
const
bf16_lds
=
hcu_ds_read_matrix_bf16_lds_base
(
A_lds
+
A_lds_stage_offset
);
A_reg
[
0
].
f16x8
=
__builtin_hcu_ds_read_matrix_format_bf16
(
bf16_lds
,
0
,
2
,
1
,
0
);
A_reg
[
1
].
f16x8
=
__builtin_hcu_ds_read_matrix_format_bf16
(
bf16_lds
,
1024
,
2
,
1
,
0
);
A_reg
[
0
].
f16x8
=
__builtin_hcu_ds_read_matrix_format_bf16
(
A_lds
+
A_lds_stage_offset
,
0
,
2
,
1
,
0
);
A_reg
[
1
].
f16x8
=
__builtin_hcu_ds_read_matrix_format_bf16
(
A_lds
+
A_lds_stage_offset
,
1024
,
2
,
1
,
0
);
}
}
else
{
// gfx938 m_ab = 0的gemm想要复用m_ab = 1的LDS数据
...
...
@@ -485,3 +483,175 @@ __forceinline__ __device__ void gpu_gemm_B_in_reg_gfx938(
#endif
#endif
}
// gpu_gemm_B_in_reg_gfx946
//
// Tiled GEMM step whose B operand is already held in registers (B_reg). The A
// operand is read from LDS (A_lds) with the ds_read_matrix builtins and the
// products are accumulated in place into C_reg via mmac_4interleave.
//
// Template parameters:
//   Is_preload_A - when true, this function issues no prefetch_to_lds_gfx938
//                  call and no vmcnt wait for A (presumably the caller has
//                  already staged A in LDS - TODO confirm).
//   Is_store_A   - together with Is_preload_A, selects how stage_id advances
//                  (++/-- instead of ^= 1) and whether the end-of-iteration
//                  wait + __syncthreads() is executed when STAGES > 1.
//   Is_even_MN   - forwarded unchanged to prefetch_to_lds_gfx938.
//   M            - head_dim; the main loop runs M / BLOCK_M iterations.
//   BLOCK_M, BLOCK_N, BLOCK_K, WARP_M, WARP_N - tile sizes, constrained by the
//                  static_asserts at the top of the body.
//   STAGES       - number of LDS stages; 1 = single buffer, > 1 = the next
//                  tile is prefetched while the current one is consumed.
//   Element      - half_t selects the *_f16 read builtins, any other type the
//                  *_bf16 ones; Float8_e4m3_t skips the mmac entirely.
//   ElementAccum - accumulator type forwarded to the helpers (default float).
//
// Runtime parameters:
//   A_ptr           - forwarded to prefetch_to_lds_gfx938 (presumably a buffer
//                     resource descriptor for A - TODO confirm).
//   C_ptr           - not referenced in this body.
//   A_lds           - LDS base for A; stage s lives at s * BLOCK_K * BLOCK_M.
//   B_reg           - register-resident B operand.
//   C_reg           - accumulators, indexed [m_loop - 1][min_tile_n*2 + min_tile_m].
//   N, K            - seq_kv / seq_q; not referenced in this body.
//   warp_id         - forwarded to the prefetch; also gates the vmcnt wait.
//   seqlen_A_stride - forwarded to prefetch_to_lds_gfx938.
//
// Original note, presumably the call-site values bound to
// M, BLOCK_M, BLOCK_N, BLOCK_K, WARP_M, WARP_N (TODO confirm):
// K BLOCK_K BLOCK_N BLOCK_M BLOCK_K WARP_N
template
<
bool
Is_preload_A
,
bool
Is_store_A
,
bool
Is_even_MN
,
int
M
/*head_dim*/
,
int
BLOCK_M
,
int
BLOCK_N
,
int
BLOCK_K
,
int
WARP_M
,
int
WARP_N
,
int
STAGES
,
typename
Element
,
typename
ElementAccum
=
float
>
__forceinline__
__device__
void
gpu_gemm_B_in_reg_gfx946
(
vec4_uint
A_ptr
,
vec4_uint
C_ptr
,
Element
*
A_lds
,
union_vec4_f16x2
<
Element
>
B_reg
[(
WARP_M
/
32
)
*
(
BLOCK_K
/
32
)
*
2
],
union_vec4_fp32
C_reg
[(
M
/
BLOCK_M
)
*
(
WARP_M
/
32
)
*
(
WARP_N
/
32
)][
4
],
int
N
/*seq_kv*/
,
int
K
/*seq_q*/
,
int
warp_id
,
int
seqlen_A_stride
)
{
#if 1
// NOTE(review): WARP_NUM and A_lds_load_num are computed but never read in this body.
const
int
WARP_NUM
=
(
BLOCK_M
*
BLOCK_N
)
/
(
WARP_M
*
WARP_N
);
const
int
A_lds_load_num
=
(
BLOCK_M
*
BLOCK_K
)
/
(
4
*
32
);
// Compile-time shape constraints for this GEMM variant.
static_assert
(
BLOCK_K
>=
32
,
"Error: gpu_gemm_B_in_reg gemm BLOCK_K must be equal or greater than 32"
);
static_assert
(
BLOCK_N
>=
WARP_N
,
"Error: gpu_gemm_B_in_reg gemm BLOCK_N must be equal or greater than WARP_N"
);
static_assert
(
BLOCK_M
==
WARP_M
,
"Error: gpu_gemm_B_in_reg gemm BLOCK_M must be equal to WARP_M"
);
// Register staging for the A tile read back from LDS.
union_vec4_f16x2
<
Element
>
A_reg
[((
WARP_M
*
BLOCK_K
)
/
(
32
*
32
))
*
2
];
//c mini tile is 32*32
// vec4_fp32 o[(WARP_M/32)*(WARP_N/32)][4]={0};
// __shared__ Element A_lds[STAGES*BLOCK_N * BLOCK_K];
//wave size should be defined in launch file. Here use 64 threads
int
lane_id
=
threadIdx
.
x
&
63
;
//lane id, 0-63
// NOTE(review): row and col are only referenced by the commented-out
// lds_offset line further down; they are otherwise unused.
int
row
=
lane_id
%
4
;
int
col
=
lane_id
/
4
;
// Index of the LDS stage currently being consumed.
int
stage_id
=
0
;
// Prologue for the multi-stage case: prefetch tile 0 into stage 0 before the
// main loop (skipped when Is_preload_A is set).
if
(
STAGES
>
1
&&
(
!
Is_preload_A
))
{
int
m_loop
=
0
;
int
A_block_buffer_load_global_offset
=
m_loop
*
BLOCK_M
;
int
A_lds_stage_offset
=
stage_id
*
BLOCK_M
*
BLOCK_K
;
prefetch_to_lds_gfx938
<
false
,
BLOCK_M
,
BLOCK_K
,
Element
,
ElementAccum
,
Is_even_MN
>
(
A_ptr
,
A_block_buffer_load_global_offset
,
A_lds
+
A_lds_stage_offset
,
seqlen_A_stride
,
warp_id
);
}
#if 1
// int lds_offset = row * 8 + col * 32;
// Main loop: m_loop is 1-based, so iteration m_loop consumes tile m_loop - 1
// and (when STAGES > 1) prefetches tile m_loop for the next iteration.
for
(
int
m_loop
=
1
;
m_loop
<
(
M
/
BLOCK_M
)
+
1
;
m_loop
++
)
{
// Temporarily advance stage_id so the prefetch below targets the next stage.
if
(
STAGES
>
1
)
{
if
constexpr
(
Is_preload_A
||
Is_store_A
){
stage_id
++
;
}
else
{
stage_id
=
stage_id
^
1
;
}
}
// Single-stage case: step m_loop back so the prefetch below loads the tile
// consumed in this same iteration; it is restored after the LDS read.
if
(
STAGES
==
1
)
{
m_loop
--
;
}
// Issue the global -> LDS prefetch unless A was preloaded or no tile is left.
if
((
!
Is_preload_A
)
&&
m_loop
<
(
M
/
BLOCK_M
))
{
int
A_block_buffer_load_global_offset
=
m_loop
*
BLOCK_M
;
int
A_lds_stage_offset
=
(
stage_id
)
*
BLOCK_K
*
BLOCK_M
;
prefetch_to_lds_gfx938
<
false
,
BLOCK_M
,
BLOCK_K
,
Element
,
ElementAccum
,
Is_even_MN
>
(
A_ptr
,
A_block_buffer_load_global_offset
,
A_lds
+
A_lds_stage_offset
,
seqlen_A_stride
,
warp_id
);
}
//BM = 32, BK = 32
// Only warp 0 waits on vmcnt here: vmcnt_wait(1) while a prefetch for the
// next tile was just issued, vmcnt_wait(0) otherwise.
// NOTE(review): presumably only warp 0 issues the loads inside
// prefetch_to_lds_gfx938 - confirm against that helper.
if
(
warp_id
==
0
)
{
if
(
!
Is_preload_A
){
if
(
STAGES
>
1
&&
m_loop
<
(
M
/
BLOCK_M
))
{
vmcnt_wait
(
1
);
}
else
{
vmcnt_wait
(
0
);
}
}
}
// Undo the temporary stage advance so the LDS read below uses the current stage.
if
constexpr
(
STAGES
>
1
)
{
if
constexpr
(
Is_preload_A
||
Is_store_A
){
stage_id
--
;
}
else
{
stage_id
=
stage_id
^
1
;
}
}
//lds -> vgpr use ds_read_m; left matrix
// Because the ds_read method changed, the layout of the mmac result matrix
// changes, so the offsets must be adjusted when storing.
{
int
A_lds_stage_offset
=
stage_id
*
BLOCK_K
*
BLOCK_M
;
// DS_READ_MATRIX_32X32_B16(ds_offset_cast(A_lds + A_lds_stage_offset), A_reg[0].f16, A_reg[1].f16, false);
// Two reads from the current stage fill A_reg[0] and A_reg[1]; they differ
// only in the second builtin argument (0 vs 1024).
if
constexpr
(
std
::
is_same_v
<
Element
,
half_t
>
)
{
A_reg
[
0
].
f16x8
=
__builtin_hcu_ds_read_matrix_format_f16
(
A_lds
+
A_lds_stage_offset
,
0
,
2
,
1
,
0
);
A_reg
[
1
].
f16x8
=
__builtin_hcu_ds_read_matrix_format_f16
(
A_lds
+
A_lds_stage_offset
,
1024
,
2
,
1
,
0
);
}
else
{
A_reg
[
0
].
f16x8
=
__builtin_hcu_ds_read_matrix_format_bf16
(
A_lds
+
A_lds_stage_offset
,
0
,
2
,
1
,
0
);
A_reg
[
1
].
f16x8
=
__builtin_hcu_ds_read_matrix_format_bf16
(
A_lds
+
A_lds_stage_offset
,
1024
,
2
,
1
,
0
);
}
}
// Wait for the LDS reads to land before the mmac chain; the sched_barrier
// keeps the compiler from moving instructions across this point.
asm
volatile
(
"s_waitcnt lgkmcnt(0)"
);
__builtin_amdgcn_sched_barrier
(
0
);
// Single-stage case: restore m_loop after the temporary decrement above.
if
constexpr
(
STAGES
==
1
){
m_loop
++
;
}
// Raise the wave priority for the duration of the mmac chain.
asm
volatile
(
"s_setprio 1"
);
// NOTE(review): n_idx, m_idx and k_idx do not appear in any index below, so
// each extra trip of these three loops repeats the same accumulation. This
// looks correct only when WARP_N, WARP_M and BLOCK_K are all 32 - confirm.
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
(
WARP_N
/
32
);
n_idx
++
)
{
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
(
WARP_M
/
32
);
m_idx
++
)
{
#pragma unroll
for
(
int
k_idx
=
0
;
k_idx
<
(
BLOCK_K
/
32
);
k_idx
++
)
{
//BLOCK_K mini size is 32
//min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
min_tile_n
++
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
min_tile_m
++
)
{
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
min_tile_k
++
)
{
// The Float8_e4m3_t branch is intentionally empty: no mmac is issued for it here.
if
constexpr
(
std
::
is_same
<
Element
,
Float8_e4m3_t
>::
value
){
}
else
{
// mmac variant that matches A having been loaded with ds_read.
C_reg
[
m_loop
-
1
][
min_tile_n
*
2
+
min_tile_m
].
f32
=
mmac_4interleave
<
Element
,
ElementAccum
>
(
//BN = 32, BK = 32
// vec4_Element<Element>{A_reg[min_tile_k].f16[0*2 + min_tile_m], A_reg[min_tile_k].f16[1*2 + min_tile_m], A_reg[min_tile_k].f16[2*2 + min_tile_m], A_reg[min_tile_k].f16[3*2 + min_tile_m]},
B_reg
[
min_tile_k
].
f16x4
[
min_tile_n
],
A_reg
[
min_tile_k
].
f16x4
[
min_tile_m
],
C_reg
[
m_loop
-
1
][
min_tile_n
*
2
+
min_tile_m
].
f32
);
}
}
}
}
}
}
}
// //test only
// for(int min_tile_n = 0; min_tile_n < 2; ++ min_tile_n) {
// for(int min_tile_m = 0; min_tile_m < 2; ++ min_tile_m) {
// C_reg[m_loop-1][min_tile_n*2 + min_tile_m].f32[0] = UpCast<Element,float, true>(B_reg[min_tile_m].f16x4[min_tile_n][0]);
// C_reg[m_loop-1][min_tile_n*2 + min_tile_m].f32[1] = UpCast<Element,float, true>(B_reg[min_tile_m].f16x4[min_tile_n][1]);
// C_reg[m_loop-1][min_tile_n*2 + min_tile_m].f32[2] = UpCast<Element,float, true>(B_reg[min_tile_m].f16x4[min_tile_n][2]);
// C_reg[m_loop-1][min_tile_n*2 + min_tile_m].f32[3] = UpCast<Element,float, true>(B_reg[min_tile_m].f16x4[min_tile_n][3]);
// }
// }
// Drop back to the default wave priority after the mmac chain.
asm
volatile
(
"s_setprio 0"
);
// End of iteration: advance to the next stage. Except in the
// Is_preload_A/Is_store_A multi-stage path, all threads also wait for LDS
// traffic and meet at a block barrier before the next iteration.
if
(
STAGES
>
1
)
{
if
constexpr
(
Is_preload_A
||
Is_store_A
){
stage_id
++
;
}
else
{
stage_id
^=
1
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)"
);
__syncthreads
();
__builtin_amdgcn_sched_barrier
(
0
);
}
}
else
{
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)"
);
__syncthreads
();
__builtin_amdgcn_sched_barrier
(
0
);
}
}
#endif
#endif
}
\ No newline at end of file
csrc/flash_attn_hg/include/bwd/gpu_gemm_tt.h
View file @
518a5f4d
...
...
@@ -410,13 +410,11 @@ __forceinline__ __device__ void gemm_tt_kq_gfx938(
int
A_lds_stage_offset
=
stage_id
*
BLOCK_M
*
BLOCK_K
;
// DS_READ_MATRIX_32X32_B16(ds_offset_cast(A_lds + A_lds_stage_offset), A_reg_tmp[0].f16, A_reg_tmp[1].f16, true);
if
constexpr
(
std
::
is_same_v
<
Element
,
half_t
>
)
{
auto
*
const
f16_lds
=
hcu_ds_read_matrix_f16_lds_base
(
A_lds
+
A_lds_stage_offset
);
A_reg_tmp
[
0
].
f16x8
=
__builtin_hcu_ds_read_matrix_trans_format_f16
(
f16_lds
,
0
,
2
,
1
,
0
);
A_reg_tmp
[
1
].
f16x8
=
__builtin_hcu_ds_read_matrix_trans_format_f16
(
f16_lds
,
1024
,
2
,
1
,
0
);
A_reg_tmp
[
0
].
f16x8
=
__builtin_hcu_ds_read_matrix_trans_format_f16
(
A_lds
+
A_lds_stage_offset
,
0
,
2
,
1
,
0
);
A_reg_tmp
[
1
].
f16x8
=
__builtin_hcu_ds_read_matrix_trans_format_f16
(
A_lds
+
A_lds_stage_offset
,
1024
,
2
,
1
,
0
);
}
else
{
auto
*
const
bf16_lds
=
hcu_ds_read_matrix_bf16_lds_base
(
A_lds
+
A_lds_stage_offset
);
A_reg_tmp
[
0
].
f16x8
=
__builtin_hcu_ds_read_matrix_trans_format_bf16
(
bf16_lds
,
0
,
2
,
1
,
0
);
A_reg_tmp
[
1
].
f16x8
=
__builtin_hcu_ds_read_matrix_trans_format_bf16
(
bf16_lds
,
1024
,
2
,
1
,
0
);
A_reg_tmp
[
0
].
f16x8
=
__builtin_hcu_ds_read_matrix_trans_format_bf16
(
A_lds
+
A_lds_stage_offset
,
0
,
2
,
1
,
0
);
A_reg_tmp
[
1
].
f16x8
=
__builtin_hcu_ds_read_matrix_trans_format_bf16
(
A_lds
+
A_lds_stage_offset
,
1024
,
2
,
1
,
0
);
}
}
int
B_lds_stage_offset
=
stage_id
*
WARP_N
*
BLOCK_K
;
...
...
csrc/flash_attn_hg/include/bwd/prefetch.h
View file @
518a5f4d
...
...
@@ -117,13 +117,10 @@ inline __device__ void prefetch_to_vgpr_gfx938(
srsrc
[
3
]
=
nm_filter
<<
8
;
// set only once
}
*
(
uint64_t
*
)
&
srsrc
=
VA_LIMIT_BITS
(
*
(
uint64_t
*
)
&
ptr
+
global_offset
*
ELEMENT_BYTES
);
union
union_vec4_uint
rsrc_bits
;
rsrc_bits
.
v32
=
srsrc
;
size_t
lds_addr_warp
=
reinterpret_cast
<
size_t
>
(
lds
)
+
lds_offset_stage
;
if
(
trans
)
{
matrix_load_b16_lds_trans
_builtin
<
32
,
32
,
0
,
0
>
(
lds_addr_warp
,
rsrc
_bits
.
i32
,
0
);
inline_
matrix_load_
32x32_
b16_lds_trans
<
0
,
0
>
(
lds
,
s
rsrc
,
lds_offset_stage
,
0
);
}
else
{
matrix_load_b16_lds
_builtin
<
32
,
32
,
0
,
0
>
(
lds_addr_warp
,
rsrc
_bits
.
i32
,
0
);
inline_
matrix_load_
32x32_
b16_lds
<
0
,
0
>
(
lds
,
s
rsrc
,
lds_offset_stage
,
0
);
}
}
for
(
int
m_loop
=
0
;
m_loop
<
M
/
128
;
++
m_loop
)
{
...
...
@@ -147,13 +144,10 @@ inline __device__ void prefetch_to_vgpr_gfx938(
}
*
(
uint64_t
*
)
&
srsrc
=
VA_LIMIT_BITS
(
*
(
uint64_t
*
)
&
ptr
+
global_offset
*
ELEMENT_BYTES
);
if
(
n_loop
<
N
/
32
)
{
union
union_vec4_uint
rsrc_bits
;
rsrc_bits
.
v32
=
srsrc
;
size_t
lds_addr_warp
=
reinterpret_cast
<
size_t
>
(
lds
)
+
lds_offset_stage
;
if
(
trans
)
{
matrix_load_b16_lds_trans
_builtin
<
32
,
32
,
0
,
0
>
(
lds_addr_warp
,
rsrc
_bits
.
i32
,
0
);
inline_
matrix_load_
32x32_
b16_lds_trans
<
0
,
0
>
(
lds
,
s
rsrc
,
lds_offset_stage
,
0
);
}
else
{
matrix_load_b16_lds
_builtin
<
32
,
32
,
0
,
0
>
(
lds_addr_warp
,
rsrc
_bits
.
i32
,
0
);
inline_
matrix_load_
32x32_
b16_lds
<
0
,
0
>
(
lds
,
s
rsrc
,
lds_offset_stage
,
0
);
}
}
...
...
@@ -167,36 +161,20 @@ inline __device__ void prefetch_to_vgpr_gfx938(
if
(
trans
){
// DS_READ_MATRIX_32X32_B16(ds_offset_cast(lds_load_offset_stage), reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2].f16, reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2 + 1].f16, true);
if
constexpr
(
std
::
is_same_v
<
Element
,
half_t
>
)
{
auto
*
const
f16_lds
=
hcu_ds_read_matrix_f16_lds_base
(
lds
+
(
stages
==
2
?
(
stages_id
^
1
)
:
stages_id
)
*
(
WARP_NUM
*
32
*
32
)
+
lds_offset
);
reg
[(
stages
==
2
?
(
n_loop
-
1
)
:
n_loop
)
*
2
].
f16x8
=
__builtin_hcu_ds_read_matrix_trans_format_f16
(
f16_lds
,
0
,
2
,
1
,
0
);
reg
[(
stages
==
2
?
(
n_loop
-
1
)
:
n_loop
)
*
2
+
1
].
f16x8
=
__builtin_hcu_ds_read_matrix_trans_format_f16
(
f16_lds
,
1024
,
2
,
1
,
0
);
reg
[(
stages
==
2
?
(
n_loop
-
1
)
:
n_loop
)
*
2
].
f16x8
=
__builtin_hcu_ds_read_matrix_trans_format_f16
(
lds
+
(
stages
==
2
?
(
stages_id
^
1
)
:
stages_id
)
*
(
WARP_NUM
*
32
*
32
)
+
lds_offset
,
0
,
2
,
1
,
0
);
reg
[(
stages
==
2
?
(
n_loop
-
1
)
:
n_loop
)
*
2
+
1
].
f16x8
=
__builtin_hcu_ds_read_matrix_trans_format_f16
(
lds
+
(
stages
==
2
?
(
stages_id
^
1
)
:
stages_id
)
*
(
WARP_NUM
*
32
*
32
)
+
lds_offset
,
1024
,
2
,
1
,
0
);
}
else
{
auto
*
const
bf16_lds
=
hcu_ds_read_matrix_bf16_lds_base
(
lds
+
(
stages
==
2
?
(
stages_id
^
1
)
:
stages_id
)
*
(
WARP_NUM
*
32
*
32
)
+
lds_offset
);
reg
[(
stages
==
2
?
(
n_loop
-
1
)
:
n_loop
)
*
2
].
f16x8
=
__builtin_hcu_ds_read_matrix_trans_format_bf16
(
bf16_lds
,
0
,
2
,
1
,
0
);
reg
[(
stages
==
2
?
(
n_loop
-
1
)
:
n_loop
)
*
2
+
1
].
f16x8
=
__builtin_hcu_ds_read_matrix_trans_format_bf16
(
bf16_lds
,
1024
,
2
,
1
,
0
);
reg
[(
stages
==
2
?
(
n_loop
-
1
)
:
n_loop
)
*
2
].
f16x8
=
__builtin_hcu_ds_read_matrix_trans_format_bf16
(
lds
+
(
stages
==
2
?
(
stages_id
^
1
)
:
stages_id
)
*
(
WARP_NUM
*
32
*
32
)
+
lds_offset
,
0
,
2
,
1
,
0
);
reg
[(
stages
==
2
?
(
n_loop
-
1
)
:
n_loop
)
*
2
+
1
].
f16x8
=
__builtin_hcu_ds_read_matrix_trans_format_bf16
(
lds
+
(
stages
==
2
?
(
stages_id
^
1
)
:
stages_id
)
*
(
WARP_NUM
*
32
*
32
)
+
lds_offset
,
1024
,
2
,
1
,
0
);
}
}
else
{
// DS_READ_MATRIX_32X32_B16(ds_offset_cast(lds_load_offset_stage), reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2].f16, reg[(stages == 2 ? (n_loop - 1) : n_loop) * 2 + 1].f16, false);
if
constexpr
(
std
::
is_same_v
<
Element
,
half_t
>
)
{
auto
*
const
f16_lds
=
hcu_ds_read_matrix_f16_lds_base
(
lds
+
(
stages
==
2
?
(
stages_id
^
1
)
:
stages_id
)
*
(
WARP_NUM
*
32
*
32
)
+
lds_offset
);
reg
[(
stages
==
2
?
(
n_loop
-
1
)
:
n_loop
)
*
2
].
f16x8
=
__builtin_hcu_ds_read_matrix_format_f16
(
f16_lds
,
0
,
2
,
1
,
0
);
reg
[(
stages
==
2
?
(
n_loop
-
1
)
:
n_loop
)
*
2
+
1
].
f16x8
=
__builtin_hcu_ds_read_matrix_format_f16
(
f16_lds
,
1024
,
2
,
1
,
0
);
reg
[(
stages
==
2
?
(
n_loop
-
1
)
:
n_loop
)
*
2
].
f16x8
=
__builtin_hcu_ds_read_matrix_format_f16
(
lds
+
(
stages
==
2
?
(
stages_id
^
1
)
:
stages_id
)
*
(
WARP_NUM
*
32
*
32
)
+
lds_offset
,
0
,
2
,
1
,
0
);
reg
[(
stages
==
2
?
(
n_loop
-
1
)
:
n_loop
)
*
2
+
1
].
f16x8
=
__builtin_hcu_ds_read_matrix_format_f16
(
lds
+
(
stages
==
2
?
(
stages_id
^
1
)
:
stages_id
)
*
(
WARP_NUM
*
32
*
32
)
+
lds_offset
,
1024
,
2
,
1
,
0
);
}
else
{
auto
*
const
bf16_lds
=
hcu_ds_read_matrix_bf16_lds_base
(
lds
+
(
stages
==
2
?
(
stages_id
^
1
)
:
stages_id
)
*
(
WARP_NUM
*
32
*
32
)
+
lds_offset
);
reg
[(
stages
==
2
?
(
n_loop
-
1
)
:
n_loop
)
*
2
].
f16x8
=
__builtin_hcu_ds_read_matrix_format_bf16
(
bf16_lds
,
0
,
2
,
1
,
0
);
reg
[(
stages
==
2
?
(
n_loop
-
1
)
:
n_loop
)
*
2
+
1
].
f16x8
=
__builtin_hcu_ds_read_matrix_format_bf16
(
bf16_lds
,
1024
,
2
,
1
,
0
);
reg
[(
stages
==
2
?
(
n_loop
-
1
)
:
n_loop
)
*
2
].
f16x8
=
__builtin_hcu_ds_read_matrix_format_bf16
(
lds
+
(
stages
==
2
?
(
stages_id
^
1
)
:
stages_id
)
*
(
WARP_NUM
*
32
*
32
)
+
lds_offset
,
0
,
2
,
1
,
0
);
reg
[(
stages
==
2
?
(
n_loop
-
1
)
:
n_loop
)
*
2
+
1
].
f16x8
=
__builtin_hcu_ds_read_matrix_format_bf16
(
lds
+
(
stages
==
2
?
(
stages_id
^
1
)
:
stages_id
)
*
(
WARP_NUM
*
32
*
32
)
+
lds_offset
,
1024
,
2
,
1
,
0
);
}
}
lgkmcnt_wait
(
0
);
...
...
@@ -246,13 +224,11 @@ inline __device__ void prefetch_to_lds_gfx938(
*
(
uint64_t
*
)
&
srsrc
=
VA_LIMIT_BITS
(
*
(
uint64_t
*
)
&
ptr
+
global_offset
);
//计算LDS地址,每个warp使用一个32*32;下一个loop重复利用
int
lds_offset
=
(
loop_warp
*
32
*
32
)
*
ELEMENT_BYTES
;
union
union_vec4_uint
rsrc_bits
;
rsrc_bits
.
v32
=
srsrc
;
size_t
lds_addr_warp
=
reinterpret_cast
<
size_t
>
(
lds
)
+
lds_offset
;
int
lds_load_offset
=
reinterpret_cast
<
size_t
>
(
lds
)
+
lds_offset
;
if
(
trans
)
{
matrix_load_b16_lds_trans
_builtin
<
32
,
32
,
0
,
0
>
(
lds_addr_warp
,
rsrc
_bits
.
i32
,
0
);
inline_
matrix_load_
32x32_
b16_lds_trans
<
0
,
0
>
(
lds
,
s
rsrc
,
lds_offset
,
0
);
}
else
{
matrix_load_b16_lds
_builtin
<
32
,
32
,
0
,
0
>
(
lds_addr_warp
,
rsrc
_bits
.
i32
,
0
);
inline_
matrix_load_
32x32_
b16_lds
<
0
,
0
>
(
lds
,
s
rsrc
,
lds_offset
,
0
);
}
}
}
...
...
csrc/flash_attn_hg/include/bwd/softmax_tiling.h
View file @
518a5f4d
...
...
@@ -57,6 +57,7 @@ inline __device__ void apply_mask_bwd(union_vec4_fp32 tensor[1][4], int M, int N
}
}
}
//local mask
if
(
mask_type
==
3
)
{
// && (!Is_even_MN || Is_even_MN && (std::abs(M_minus_N - window_size_left) < 128 || std::abs(M_minus_N + window_size_right) < 128))
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
min_tile_m
++
)
{
...
...
@@ -112,21 +113,38 @@ inline __device__ void apply_mask_bwd_gfx938(union_vec4_fp32 tensor[1][4], int M
}
}
}
// //mask左下角
// if (mask_type == 2 && (!Is_even_MN || Is_even_MN && std::abs(M_minus_N) < 128)) {
// for(int min_tile_m = 0; min_tile_m < 2; min_tile_m ++) {
// int M_offset = min_tile_m * 16 + lane_m_idx;
// for(int min_tile_n = 0; min_tile_n < 2; min_tile_n ++) {
// for(int vec_idx = 0; vec_idx < 4; vec_idx ++) {
// int N_offset = min_tile_n * 16 + lane_n_idx * 4 + vec_idx;
// int N_limit = (M_offset + M_minus_N);
// if((!Is_even_MN && N_offset > N - 1) || N_offset < N_limit){
// tensor[0][min_tile_n * 2 + min_tile_m].f32[vec_idx] = -INFINITY;
// }
// }
// }
// }
// }
//mask左下角
if
(
mask_type
==
2
&&
(
!
Is_even_MN
||
Is_even_MN
&&
std
::
abs
(
M_minus_N
)
<
128
)
)
{
if
(
mask_type
==
2
)
{
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
min_tile_m
++
)
{
int
M_offset
=
min_tile_m
*
16
+
lane_m_idx
;
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
min_tile_n
++
)
{
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
vec_idx
++
)
{
int
N_offset
=
min_tile_n
*
16
+
lane_n_idx
*
4
+
vec_idx
;
int
N_limit
=
(
M_offset
+
M_minus_N
);
if
(
(
!
Is_even_MN
&&
N_offset
>
N
-
1
)
||
N_offset
<
N_limit
){
if
(
N_offset
<
N_limit
){
tensor
[
0
][
min_tile_n
*
2
+
min_tile_m
].
f32
[
vec_idx
]
=
-
INFINITY
;
}
}
}
}
}
//local mask
if
(
mask_type
==
3
)
{
// && (!Is_even_MN || Is_even_MN && (std::abs(M_minus_N - window_size_left) < 128 || std::abs(M_minus_N + window_size_right) < 128))
for
(
int
min_tile_m
=
0
;
min_tile_m
<
2
;
min_tile_m
++
)
{
...
...
@@ -327,7 +345,7 @@ inline __device__ void scale_apply_exp2_bwd(DataType0 tensor[(BLOCK_M/32)*(WARP_
auto vec2_scale = vec2_fp32{scale, scale};
auto vec2_max_scaled = vec2_fp32{-max_scaled, -max_scaled};
auto tensor_tmp =
hcu_pk_fma_f32(
__builtin_
hcu_pk_fma_f32(
vec2_tensor,
vec2_scale,
vec2_max_scaled);
...
...
csrc/flash_attn_hg/include/flash.h
View file @
518a5f4d
...
...
@@ -75,6 +75,11 @@ struct Flash_fwd_params : public Qkv_params {
void
*
__restrict__
softmax_lse_ptr
;
void
*
__restrict__
softmax_lseaccum_ptr
;
// Attention sink values, one scalar per original query head.
// s_aux_type: 0 none, 1 fp32, 2 fp16, 3 bf16.
void
*
__restrict__
s_aux_ptr
;
int
s_aux_type
;
// For FP8 scaling
float
*
__restrict__
q_descale_ptr
;
float
*
__restrict__
k_descale_ptr
;
...
...
@@ -366,6 +371,8 @@ struct Flash_fwd_mla_reduce_params {
template
<
typename
T
,
int
Headdim
,
int
HeaddimV
>
void
run_mha_fwd_
(
Flash_fwd_params
&
params
,
hipStream_t
stream
);
template
<
typename
T
,
int
Headdim
,
int
HeaddimV
>
void
run_fp8_mha_fwd_
(
Flash_fwd_params
&
params
,
hipStream_t
stream
);
template
<
typename
T
,
int
Headdim
,
int
HeaddimV
>
void
run_mha_fwd_splitkv_dispatch
(
Flash_fwd_params
&
params
,
hipStream_t
stream
);
template
<
typename
T
,
int
Headdim
,
int
HeaddimV
>
void
run_int8_fwd_splitkv_dispatch
(
Flash_fwd_params
&
params
,
hipStream_t
stream
);
...
...
@@ -386,6 +393,8 @@ template<typename T, int Headdim, int HeaddimV> void run_mha_fwd_prefix_prefill_
template
<
typename
T
,
int
Headdim
,
int
HeaddimV
>
void
run_int8_mha_fwd_prefix_prefill_
(
Flash_fwd_params
&
params
,
hipStream_t
stream
);
template
<
typename
T
,
int
Headdim
,
int
HeaddimV
>
void
run_fp8_mha_fwd_prefix_prefill_
(
Flash_fwd_params
&
params
,
hipStream_t
stream
);
template
<
typename
T
,
int
Headdim
,
int
HeaddimV
>
void
run_mla_fwd_prefix_prefill_dispatch_
(
Flash_fwd_mla_params
&
params
,
hipStream_t
stream
);
template
<
typename
T
,
int
Headdim
,
int
HeaddimV
>
void
run_mla_fwd_dispatch
(
Flash_fwd_mla_params
&
params
,
hipStream_t
stream
);
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment