Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
flash-attention
Commits
518a5f4d
Commit
518a5f4d
authored
Jun 09, 2026
by
hly
Browse files
import aicc-master-dev
parent
c2a1b310
Changes
131
Show whitespace changes
Inline
Side-by-side
Showing
20 changed files
with
1928 additions
and
397 deletions
+1928
-397
csrc/flash_attn_hg/include/fwd/utils.h
csrc/flash_attn_hg/include/fwd/utils.h
+28
-8
csrc/flash_attn_hg/include/intrinsic.h
csrc/flash_attn_hg/include/intrinsic.h
+74
-168
csrc/flash_attn_hg/include/intrinsic_mls_ds.h
csrc/flash_attn_hg/include/intrinsic_mls_ds.h
+213
-26
csrc/flash_attn_hg/include/intrinsic_mls_ds_b8.h
csrc/flash_attn_hg/include/intrinsic_mls_ds_b8.h
+34
-44
csrc/flash_attn_hg/include/kernel_traits.h
csrc/flash_attn_hg/include/kernel_traits.h
+0
-5
csrc/flash_attn_hg/include/kvcache/gfx92a/f16_kvcache_gfx92a.h
...flash_attn_hg/include/kvcache/gfx92a/f16_kvcache_gfx92a.h
+760
-0
csrc/flash_attn_hg/include/kvcache/gfx938/kvcache_epilogue_gfx938.h
..._attn_hg/include/kvcache/gfx938/kvcache_epilogue_gfx938.h
+102
-5
csrc/flash_attn_hg/include/kvcache/gfx938/kvcache_pv_gemm_prefetch_k_gfx938.h
...nclude/kvcache/gfx938/kvcache_pv_gemm_prefetch_k_gfx938.h
+8
-4
csrc/flash_attn_hg/include/kvcache/gfx938/kvcache_pv_gemm_utils_gfx938.h
..._hg/include/kvcache/gfx938/kvcache_pv_gemm_utils_gfx938.h
+371
-4
csrc/flash_attn_hg/include/kvcache/gfx938/kvcache_qk_gemm_prefetch_v_gfx938.h
...nclude/kvcache/gfx938/kvcache_qk_gemm_prefetch_v_gfx938.h
+2
-4
csrc/flash_attn_hg/include/kvcache/gfx938/kvcache_qk_gemm_utils_gfx938.h
..._hg/include/kvcache/gfx938/kvcache_qk_gemm_utils_gfx938.h
+170
-8
csrc/flash_attn_hg/include/kvcache/gfx938/kvcache_softmax_gfx938.h
...h_attn_hg/include/kvcache/gfx938/kvcache_softmax_gfx938.h
+61
-0
csrc/flash_attn_hg/include/kvcache/int8_kvcache_acco_reduce.h
.../flash_attn_hg/include/kvcache/int8_kvcache_acco_reduce.h
+30
-33
csrc/flash_attn_hg/include/kvcache/int8_kvcache_qk_gemm_prefetch_v_3stage.h
.../include/kvcache/int8_kvcache_qk_gemm_prefetch_v_3stage.h
+1
-1
csrc/flash_attn_hg/include/kvcache/int8_kvcache_softmax.h
csrc/flash_attn_hg/include/kvcache/int8_kvcache_softmax.h
+12
-13
csrc/flash_attn_hg/include/kvcache/kvcache_acco_reduce.h
csrc/flash_attn_hg/include/kvcache/kvcache_acco_reduce.h
+5
-8
csrc/flash_attn_hg/include/kvcache/kvcache_acco_reduce_tile16x32.h
...h_attn_hg/include/kvcache/kvcache_acco_reduce_tile16x32.h
+46
-45
csrc/flash_attn_hg/include/kvcache/kvcache_epilogue.h
csrc/flash_attn_hg/include/kvcache/kvcache_epilogue.h
+6
-16
csrc/flash_attn_hg/include/kvcache/kvcache_pv_gemm_prefetch_k_tile16x32.h
...hg/include/kvcache/kvcache_pv_gemm_prefetch_k_tile16x32.h
+4
-4
csrc/flash_attn_hg/include/kvcache/kvcache_qk_gemm_prefetch_v.h
...lash_attn_hg/include/kvcache/kvcache_qk_gemm_prefetch_v.h
+1
-1
No files found.
csrc/flash_attn_hg/include/fwd/utils.h
View file @
518a5f4d
...
...
@@ -109,13 +109,13 @@ struct Allreduce {
static
__device__
inline
union_vec2_fp32
run
(
union_vec2_fp32
x
,
Operator
&
op
)
{
union_vec2_fp32
res
;
if
constexpr
(
std
::
is_same
<
Operator
,
SumOp
<
float
>
>::
value
)
{
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
res
.
f32
[
0
]
=
__shfl_xor_tmp
(
x
.
f32
[
0
],
32
);
res
.
f32
[
1
]
=
__shfl_xor_tmp
(
x
.
f32
[
1
],
32
);
x
.
u64
=
hcu_pk_add_f32
(
x
.
u64
,
res
.
u64
);
x
.
u64
=
__builtin_
hcu_pk_add_f32
(
x
.
u64
,
res
.
u64
);
res
.
f32
[
0
]
=
__shfl_swap16
(
x
.
f32
[
0
]);
// __shfl_xor_tmp(x.f32[0], 16);
res
.
f32
[
1
]
=
__shfl_swap16
(
x
.
f32
[
1
]);
// __shfl_xor_tmp(x.f32[1], 16);
res
.
u64
=
hcu_pk_add_f32
(
res
.
u64
,
x
.
u64
);
res
.
u64
=
__builtin_
hcu_pk_add_f32
(
res
.
u64
,
x
.
u64
);
#else
x
.
f32
[
0
]
=
x
.
f32
[
0
]
+
__shfl_xor_tmp
(
x
.
f32
[
0
],
32
);
x
.
f32
[
1
]
=
x
.
f32
[
1
]
+
__shfl_xor_tmp
(
x
.
f32
[
1
],
32
);
...
...
@@ -141,10 +141,7 @@ struct Allreduce {
template
<
const
int
kHeadDim
,
typename
T
,
bool
Do_CacheSwizzle
=
true
>
__device__
__forceinline__
vec4_uint
prepare_for_buffer_load
(
T
*
ptr
)
{
vec4_uint
res
;
struct
{
uint32_t
lo
,
hi
;
}
parts
;
*
(
uint64_t
*
)
&
parts
=
reinterpret_cast
<
uint64_t
>
(
ptr
);
res
[
0
]
=
__builtin_amdgcn_readfirstlane
(
parts
.
lo
);
res
[
1
]
=
__builtin_amdgcn_readfirstlane
(
parts
.
hi
);
*
(
uint64_t
*
)
&
res
=
reinterpret_cast
<
uint64_t
>
(
ptr
);
if
constexpr
(
Do_CacheSwizzle
)
{
if
constexpr
(
kHeadDim
==
128
)
{
res
[
1
]
+=
0x41000000
;
// 62 bit: cache swizzle; 48~61: Stride
...
...
@@ -194,7 +191,7 @@ __forceinline__ __device__ void attention_initialize(
#if defined(__gfx936__)
acc_o
[
i
][
min_tile_n
*
2
+
min_tile_m
].
u64
[
0
]
=
__builtin_hcu_mov_b64
(
pk_zero
);
acc_o
[
i
][
min_tile_n
*
2
+
min_tile_m
].
u64
[
1
]
=
__builtin_hcu_mov_b64
(
pk_zero
);
#elif defined(__gfx938__)
#elif defined(__gfx938__)
|| defined(__gfx946__)
asm
volatile
(
"v_mov_b64 %0, 0x0"
:
"=v"
(
acc_o
[
i
][
min_tile_n
*
2
+
min_tile_m
].
u64
[
0
])
:
);
...
...
@@ -213,4 +210,27 @@ __forceinline__ __device__ void attention_initialize(
}
template
<
int
kHeadDim
,
int
WARP_M
,
int
WARP_N
,
typename
ElementAccum
>
__forceinline__
__device__
void
fp8_attention_initialize
(
ElementAccum
scores_max
[
WARP_M
/
16
],
ElementAccum
scores_sum
[
WARP_M
/
16
],
vec4_Accum
<
ElementAccum
>
acc_o
[
kHeadDim
/
32
][
WARP_M
/
16
][
WARP_N
/
16
]
)
{
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
WARP_M
/
16
;
++
m_idx
)
{
scores_max
[
m_idx
]
=
-
INFINITY
;
scores_sum
[
m_idx
]
=
0
;
}
#pragma unroll
for
(
int
pv_loop
=
0
;
pv_loop
<
kHeadDim
/
32
;
++
pv_loop
)
{
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
WARP_M
/
16
;
++
m_idx
)
{
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
WARP_N
/
16
;
++
n_idx
)
{
inline_vgpr4_init_zero
(
acc_o
[
pv_loop
][
m_idx
][
n_idx
]);
}
}
}
}
}
// namespace flash
csrc/flash_attn_hg/include/intrinsic.h
100644 → 100755
View file @
518a5f4d
...
...
@@ -4,28 +4,43 @@
#include "hip/hip_fp16.h"
#include "numeric_types.h"
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
#define USE_BUFFER_LOAD_DWORDX4
// #define USE_BUFFER_LOAD_DWORDX2
#endif
// DTK ds_read_matrix builtins (DS_READ_MATRIX_FORMAT / _TRANS_FORMAT): para1 is LDS base
// typed per element kind — e.g. *_f16 → half*3, *_bf16 / *_u16 / *_i16 → short*3, *_f32 → float*3,
// 4/8-bit and tf32/u32/i32 variants → int*3 (vendor builtin table).
// HIP may use __half for fp16 LDS while builtins expect __fp16*3; use f16 helper below.
// Probe: FA_PROBE_FAMILY_DS (lds_f16_as3, lds_bf16_as3).
template
<
typename
T
>
__forceinline__
__device__
__attribute__
((
address_space
(
3
)))
__fp16
*
hcu_ds_read_matrix_f16_lds_base
(
T
*
const
p
)
{
return
(
__attribute__
((
address_space
(
3
)))
__fp16
*
)(
p
);
template
<
typename
VEC
,
typename
pointerType
>
__forceinline__
__device__
void
inline_global_load_dwordx1
(
VEC
&
v_data
,
const
int
v_offset
,
const
pointerType
*
s_addr
)
{
const
int
v_offset_bytes
=
v_offset
*
sizeof
(
pointerType
);
asm
volatile
(
"global_load_dword %0, %1, %2
\n
"
:
"=v"
(
v_data
)
:
"v"
(
v_offset_bytes
),
"s"
(
s_addr
)
:
);
}
template
<
typename
T
>
__forceinline__
__device__
__attribute__
((
address_space
(
3
)))
short
*
hcu_ds_read_matrix_bf16_lds_base
(
T
*
const
p
)
{
return
(
__attribute__
((
address_space
(
3
)))
short
*
)(
p
);
template
<
const
int
shfl_count
,
bool
bypass
,
class
DataType
>
__forceinline__
__device__
void
inline_buffer_load_dwordx1
(
DataType
&
v_data
,
const
int
v_offset
,
const
vec4_uint
global_addr
)
{
int
v_offset_bytes
=
v_offset
<<
shfl_count
;
if
constexpr
(
bypass
)
{
asm
volatile
(
"buffer_load_dword %0, %1, %2, 0, offen offset:0 glc slc
\n
"
:
"=v"
(
v_data
)
:
"v"
(
v_offset_bytes
),
"s"
(
global_addr
)
:
);
}
else
{
asm
volatile
(
"buffer_load_dword %0, %1, %2, 0, offen offset:0
\n
"
:
"=v"
(
v_data
)
:
"v"
(
v_offset_bytes
),
"s"
(
global_addr
)
:
);
}
}
template
<
class
DataType
>
__forceinline__
__device__
void
inline_utcl2_warmup_dword
(
DataType
buffer_resource
)
{
int
container
;
...
...
@@ -34,6 +49,7 @@ __forceinline__ __device__ void inline_utcl2_warmup_dword(DataType buffer_resour
asm
volatile
(
"s_nop 4
\n\t
"
"buffer_load_dword %0, %1, %2, 0, offen offset:0 glc slc
\n\t
"
"s_waitcnt vmcnt(0)
\n
"
:
"=v"
(
container
)
:
"v"
(
offset
),
"s"
(
buffer_resource
)
);
...
...
@@ -44,177 +60,99 @@ __forceinline__ __device__ void inline_utcl2_warmup_dword(DataType buffer_resour
template
<
class
DataType
,
const
int
shfl_count
=
2
>
__forceinline__
__device__
void
inline_buffer_load_dword_lds
(
DataType
*
const
shared_addr
,
const
vec4_uint
global_addr
,
const
int
&
lds_offset
,
const
int
&
gvOffset_s
,
const
int
&
gvOffset_v
)
{
int
ldsAddrPerWave
=
__builtin_amdgcn_readfirstlane
(
(
int
)(
reinterpret_cast
<
size_t
>
(
shared_addr
)
+
(
lds_offset
<<
shfl_count
)));
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
shared_addr
)
+
(
lds_offset
<<
shfl_count
);
int
offset_s
=
gvOffset_s
<<
shfl_count
;
int
offset_v
=
gvOffset_v
<<
shfl_count
;
vec4_uint
scalar_rsrc
;
scalar_rsrc
[
0
]
=
__builtin_amdgcn_readfirstlane
(
global_addr
[
0
]);
scalar_rsrc
[
1
]
=
__builtin_amdgcn_readfirstlane
(
global_addr
[
1
]);
scalar_rsrc
[
2
]
=
__builtin_amdgcn_readfirstlane
(
global_addr
[
2
]);
scalar_rsrc
[
3
]
=
__builtin_amdgcn_readfirstlane
(
global_addr
[
3
]);
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"s_nop 0
\n\t
"
"buffer_load_dword %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
scalar_rsrc
),
"s"
(
offset_s
)
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
}
template
<
class
DataType
,
const
int
shfl_count
=
2
>
__forceinline__
__device__
void
inline_buffer_load_dwordx2_lds
(
DataType
*
const
shared_addr
,
const
vec4_uint
global_addr
,
const
int
&
lds_offset
,
const
int
&
gvOffset_s
,
const
int
&
gvOffset_v
)
{
int
ldsAddrPerWave
=
__builtin_amdgcn_readfirstlane
(
(
int
)(
reinterpret_cast
<
size_t
>
(
shared_addr
)
+
(
lds_offset
<<
shfl_count
)));
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
shared_addr
)
+
(
lds_offset
<<
shfl_count
);
int
offset_s
=
gvOffset_s
<<
shfl_count
;
int
offset_v
=
gvOffset_v
<<
shfl_count
;
vec4_uint
scalar_rsrc
;
scalar_rsrc
[
0
]
=
__builtin_amdgcn_readfirstlane
(
global_addr
[
0
]);
scalar_rsrc
[
1
]
=
__builtin_amdgcn_readfirstlane
(
global_addr
[
1
]);
scalar_rsrc
[
2
]
=
__builtin_amdgcn_readfirstlane
(
global_addr
[
2
]);
scalar_rsrc
[
3
]
=
__builtin_amdgcn_readfirstlane
(
global_addr
[
3
]);
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"s_nop 0
\n\t
"
"buffer_load_dwordx2 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
scalar_rsrc
),
"s"
(
offset_s
)
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
}
template
<
class
DataType
,
const
int
shfl_count
=
2
>
__forceinline__
__device__
void
inline_buffer_load_dwordx4_lds
(
DataType
*
const
shared_addr
,
const
vec4_uint
global_addr
,
const
int
&
lds_offset
,
const
int
&
gvOffset_s
,
const
int
&
gvOffset_v
)
{
int
ldsAddrPerWave
=
__builtin_amdgcn_readfirstlane
(
(
int
)(
reinterpret_cast
<
size_t
>
(
shared_addr
)
+
(
lds_offset
<<
shfl_count
)));
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
shared_addr
)
+
(
lds_offset
<<
shfl_count
);
int
offset_s
=
gvOffset_s
<<
shfl_count
;
int
offset_v
=
gvOffset_v
<<
shfl_count
;
vec4_uint
scalar_rsrc
;
scalar_rsrc
[
0
]
=
__builtin_amdgcn_readfirstlane
(
global_addr
[
0
]);
scalar_rsrc
[
1
]
=
__builtin_amdgcn_readfirstlane
(
global_addr
[
1
]);
scalar_rsrc
[
2
]
=
__builtin_amdgcn_readfirstlane
(
global_addr
[
2
]);
scalar_rsrc
[
3
]
=
__builtin_amdgcn_readfirstlane
(
global_addr
[
3
]);
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"s_nop 0
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
scalar_rsrc
),
"s"
(
offset_s
)
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
}
template
<
class
DataType
,
const
int
shfl_count
=
2
>
__forceinline__
__device__
void
safe_inline_buffer_load_dwordx4_lds
(
DataType
*
const
shared_addr
,
const
vec4_uint
global_addr
,
const
int
&
lds_offset
,
const
int
&
offset_s
,
const
int
&
offset_v
)
{
int
lds_addr_per_wave
=
__builtin_amdgcn_readfirstlane
(
(
int
)(
reinterpret_cast
<
size_t
>
(
shared_addr
)
+
(
lds_offset
<<
shfl_count
)));
int
lds_addr_per_wave
=
reinterpret_cast
<
size_t
>
(
shared_addr
)
+
(
lds_offset
<<
shfl_count
);
int
__offset_s
=
offset_s
<<
shfl_count
;
int
__offset_v
=
offset_v
<<
shfl_count
;
vec4_uint
scalar_rsrc
;
scalar_rsrc
[
0
]
=
__builtin_amdgcn_readfirstlane
(
global_addr
[
0
]);
scalar_rsrc
[
1
]
=
__builtin_amdgcn_readfirstlane
(
global_addr
[
1
]);
scalar_rsrc
[
2
]
=
__builtin_amdgcn_readfirstlane
(
global_addr
[
2
]);
scalar_rsrc
[
3
]
=
__builtin_amdgcn_readfirstlane
(
global_addr
[
3
]);
asm
volatile
(
"s_nop 3
\n\t
"
"s_mov_b32 m0, %1
\n\t
"
"s_nop 0
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
__offset_v
),
"s"
(
lds_addr_per_wave
),
"s"
(
scalar_rsrc
),
"s"
(
__offset_s
)
::
"v"
(
__offset_v
),
"s"
(
lds_addr_per_wave
),
"s"
(
global_addr
),
"s"
(
__offset_s
)
:
);
}
template
<
class
DataType
,
const
int
shfl_count
=
2
>
__forceinline__
__device__
void
inline_buffer_load_dword_lds_bypass_glc_slc
(
DataType
*
const
shared_addr
,
const
vec4_uint
global_addr
,
const
int
&
lds_offset
,
const
int
&
gvOffset_s
,
const
int
&
gvOffset_v
)
{
int
ldsAddrPerWave
=
__builtin_amdgcn_readfirstlane
(
(
int
)(
reinterpret_cast
<
size_t
>
(
shared_addr
)
+
(
lds_offset
<<
shfl_count
)));
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
shared_addr
)
+
(
lds_offset
<<
shfl_count
);
int
offset_s
=
gvOffset_s
<<
shfl_count
;
int
offset_v
=
gvOffset_v
<<
shfl_count
;
vec4_uint
scalar_rsrc
;
scalar_rsrc
[
0
]
=
__builtin_amdgcn_readfirstlane
(
global_addr
[
0
]);
scalar_rsrc
[
1
]
=
__builtin_amdgcn_readfirstlane
(
global_addr
[
1
]);
scalar_rsrc
[
2
]
=
__builtin_amdgcn_readfirstlane
(
global_addr
[
2
]);
scalar_rsrc
[
3
]
=
__builtin_amdgcn_readfirstlane
(
global_addr
[
3
]);
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"s_nop 0
\n\t
"
"buffer_load_dword %0, %2, %3 ,offen offset:0 glc slc lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
scalar_rsrc
),
"s"
(
offset_s
)
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
}
template
<
class
DataType
,
const
int
shfl_count
=
2
>
__forceinline__
__device__
void
inline_buffer_load_dword_lds_bypass_l1_glc
(
DataType
*
const
shared_addr
,
const
vec4_uint
global_addr
,
const
int
&
lds_offset
,
const
int
&
gvOffset_s
,
const
int
&
gvOffset_v
)
{
int
ldsAddrPerWave
=
__builtin_amdgcn_readfirstlane
(
(
int
)(
reinterpret_cast
<
size_t
>
(
shared_addr
)
+
(
lds_offset
<<
shfl_count
)));
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
shared_addr
)
+
(
lds_offset
<<
shfl_count
);
int
offset_s
=
gvOffset_s
<<
shfl_count
;
int
offset_v
=
gvOffset_v
<<
shfl_count
;
vec4_uint
scalar_rsrc
;
scalar_rsrc
[
0
]
=
__builtin_amdgcn_readfirstlane
(
global_addr
[
0
]);
scalar_rsrc
[
1
]
=
__builtin_amdgcn_readfirstlane
(
global_addr
[
1
]);
scalar_rsrc
[
2
]
=
__builtin_amdgcn_readfirstlane
(
global_addr
[
2
]);
scalar_rsrc
[
3
]
=
__builtin_amdgcn_readfirstlane
(
global_addr
[
3
]);
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"s_nop 0
\n\t
"
"buffer_load_dword %0, %2, %3 ,offen offset:0 glc lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
scalar_rsrc
),
"s"
(
offset_s
)
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
}
template
<
class
DataType
,
const
int
shfl_count
=
2
>
__forceinline__
__device__
void
inline_buffer_load_dword_lds_bypass_l2_slc
(
DataType
*
const
shared_addr
,
const
vec4_uint
global_addr
,
const
int
&
lds_offset
,
const
int
&
gvOffset_s
,
const
int
&
gvOffset_v
)
{
int
ldsAddrPerWave
=
__builtin_amdgcn_readfirstlane
(
(
int
)(
reinterpret_cast
<
size_t
>
(
shared_addr
)
+
(
lds_offset
<<
shfl_count
)));
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
shared_addr
)
+
(
lds_offset
<<
shfl_count
);
int
offset_s
=
gvOffset_s
<<
shfl_count
;
int
offset_v
=
gvOffset_v
<<
shfl_count
;
vec4_uint
scalar_rsrc
;
scalar_rsrc
[
0
]
=
__builtin_amdgcn_readfirstlane
(
global_addr
[
0
]);
scalar_rsrc
[
1
]
=
__builtin_amdgcn_readfirstlane
(
global_addr
[
1
]);
scalar_rsrc
[
2
]
=
__builtin_amdgcn_readfirstlane
(
global_addr
[
2
]);
scalar_rsrc
[
3
]
=
__builtin_amdgcn_readfirstlane
(
global_addr
[
3
]);
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"s_nop 0
\n\t
"
"buffer_load_dword %0, %2, %3 ,offen offset:0 slc lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
scalar_rsrc
),
"s"
(
offset_s
)
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
}
template
<
typename
src_type
=
half_t
,
typename
dst_type
=
float
,
const
int
dword_count
=
1
,
const
int
auxilariy
=
0
>
__forceinline__
__device__
void
builtin_buffer_load_dword_lds
(
src_type
*
const
shared_addr
,
const
vec4_uint
rsrc
,
const
int
&
lds_offset
,
const
int
gvOffset_s
,
const
int
&
gvOffset_v
)
{
#if defined(__gfx936__) || defined(__gfx938__)
static_assert
(
dword_count
==
1
||
dword_count
==
2
||
dword_count
==
4
,
"unsupported buffer_load_dword LDS width"
);
// DTK currently accepts the mature asm buffer_load_* -> lds shape more reliably than
// the raw_buffer_load_lds wrapper instantiated through generic LDS pointers.
if
constexpr
(
auxilariy
==
0
)
{
if
constexpr
(
dword_count
==
1
)
{
inline_buffer_load_dword_lds
<
src_type
,
2
>
(
shared_addr
,
rsrc
,
lds_offset
,
gvOffset_s
,
gvOffset_v
);
}
else
if
constexpr
(
dword_count
==
2
)
{
inline_buffer_load_dwordx2_lds
<
src_type
,
2
>
(
shared_addr
,
rsrc
,
lds_offset
,
gvOffset_s
,
gvOffset_v
);
}
else
{
inline_buffer_load_dwordx4_lds
<
src_type
,
2
>
(
shared_addr
,
rsrc
,
lds_offset
,
gvOffset_s
,
gvOffset_v
);
}
}
else
if
constexpr
(
auxilariy
==
11
&&
dword_count
==
1
)
{
inline_buffer_load_dword_lds_bypass_glc_slc
<
src_type
,
2
>
(
shared_addr
,
rsrc
,
lds_offset
,
gvOffset_s
,
gvOffset_v
);
}
else
{
constexpr
int
bytes_per_element
=
sizeof
(
dst_type
);
auto
*
ptr
=
(
__attribute__
((
address_space
(
3
)))
int
*
)(
reinterpret_cast
<
size_t
>
(
shared_addr
)
+
static_cast
<
size_t
>
(
lds_offset
)
*
bytes_per_element
);
__builtin_hcu_raw_buffer_load_lds
(
rsrc
,
ptr
,
dword_count
*
4
,
gvOffset_v
*
bytes_per_element
,
gvOffset_s
*
bytes_per_element
,
0
,
/* immediate offset, instruction offset */
auxilariy
/* auxilariy data| bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
);
}
#else
constexpr
int
bytes_per_element
=
sizeof
(
dst_type
);
auto
*
ptr
=
(
__attribute__
((
address_space
(
3
)))
int
*
)(
reinterpret_cast
<
size_t
>
(
shared_addr
)
+
static_cast
<
size_t
>
(
lds_offset
)
*
bytes_per_element
)
;
dst_type
*
ptr
=
reinterpret_cast
<
dst_type
*
>
(
shared_addr
)
+
lds_offset
;
__builtin_hcu_raw_buffer_load_lds
(
rsrc
,
ptr
,
...
...
@@ -224,16 +162,12 @@ __forceinline__ __device__ void builtin_buffer_load_dword_lds(src_type *const sh
0
,
/* immediate offset, instruction offset */
auxilariy
/* auxilariy data| bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
);
#endif
}
template
<
typename
src_type
=
half_t
,
typename
dst_type
=
float
>
__forceinline__
__device__
void
builtin_buffer_load_dword_lds_bypass_glc_slc
(
src_type
*
const
shared_addr
,
const
vec4_uint
rsrc
,
const
int
&
lds_offset
,
const
int
gvOffset_s
,
const
int
&
gvOffset_v
)
{
#if defined(__gfx936__) || defined(__gfx938__)
inline_buffer_load_dword_lds_bypass_glc_slc
<
src_type
,
2
>
(
shared_addr
,
rsrc
,
lds_offset
,
gvOffset_s
,
gvOffset_v
);
#else
constexpr
int
bytes_per_element
=
sizeof
(
dst_type
);
auto
*
ptr
=
(
__attribute__
((
address_space
(
3
)))
int
*
)(
reinterpret_cast
<
size_t
>
(
shared_addr
)
+
static_cast
<
size_t
>
(
lds_offset
)
*
bytes_per_element
)
;
dst_type
*
ptr
=
reinterpret_cast
<
dst_type
*
>
(
shared_addr
)
+
lds_offset
;
__builtin_hcu_raw_buffer_load_lds
(
rsrc
,
ptr
,
...
...
@@ -243,7 +177,6 @@ __forceinline__ __device__ void builtin_buffer_load_dword_lds_bypass_glc_slc(src
0
,
/* immediate offset, instruction offset */
11
/* auxilariy data| bit 0: glc, bit 1: slc, bit 2: dlc, bit 3: cache swizzle */
);
#endif
}
template
<
class
DataType
,
const
int
shfl_count
>
...
...
@@ -335,6 +268,16 @@ __forceinline__ __device__ void inline_ds_read2_b32_no_wait_bytes(const int &l
}
template
<
typename
DataType
>
__forceinline__
__device__
void
inline_ds_read2_b64
(
const
int
lds_offset
,
DataType
&
reg_val
,
const
int
offset0
,
const
int
offset1
)
{
asm
volatile
(
"ds_read2_b64 %0, %1, offset0:%2, offset1:%3
\n
"
:
"=v"
(
reg_val
)
:
"s"
(
lds_offset
),
"B"
(
offset0
),
"B"
(
offset1
)
:
);
}
template
<
typename
dwordx2
>
__forceinline__
__device__
void
inlineasm_fa_ds_read2_b32
(
float
*
shared_addr
,
const
int
&
lds_offset
,
dwordx2
&
reg_val
,
const
int
offset0
,
const
int
offset1
)
{
int
lds_addr
=
reinterpret_cast
<
size_t
>
(
shared_addr
)
+
lds_offset
*
4
;
...
...
@@ -364,14 +307,14 @@ template<typename VEC>
__forceinline__
__device__
void
inlineasm_ds_read_b128
(
int
lds_offset
,
VEC
&
data
)
{
asm
volatile
(
"ds_read_b128 %0, %1
\n
"
:
"=v"
(
data
)
:
"
s
"
(
lds_offset
)
:
"
v
"
(
lds_offset
)
:
);
}
template
<
typename
VEC
>
__forceinline__
__device__
void
inlineasm_ds_write_b128
(
int
lds_offset
,
VEC
&
data
)
{
asm
volatile
(
"ds_write_b128 %0, %1
\n
"
::
"
s
"
(
lds_offset
),
"v"
(
data
)
::
"
v
"
(
lds_offset
),
"v"
(
data
)
:
);
}
...
...
@@ -385,7 +328,7 @@ __forceinline__ __device__ void inline_vgpr_init_zero(VEC &dst, const int idx)
template
<
typename
VEC
>
__forceinline__
__device__
void
inline_vgpr2_init_zero
(
VEC
&
dst
)
{
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
asm
(
"v_mov_b64 %0, 0x0"
:
"=v"
(
dst
)
:
);
...
...
@@ -396,7 +339,7 @@ __forceinline__ __device__ void inline_vgpr2_init_zero(VEC &dst) {
template
<
typename
VEC
>
__forceinline__
__device__
void
inline_vgpr4_init_zero
(
VEC
&
dst
)
{
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
asm
(
"v_mov_b64 %0, 0x0
\n\t
"
"v_mov_b64 %1, 0x0
\n\t
"
:
"=v"
(
dst
.
u64
[
0
]),
"=v"
(
dst
.
u64
[
1
])
...
...
@@ -413,7 +356,7 @@ __forceinline__ __device__ void inline_vgpr4_init_zero(VEC &dst) {
template
<
typename
VEC
>
__forceinline__
__device__
void
inline_vgpr4_init_zero_4x4x4
(
VEC
s_reg
[
4
][
4
])
{
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
asm
(
"v_mov_b64 %0, 0x0
\n\t
"
"v_mov_b64 %1, 0x0
\n\t
"
"v_mov_b64 %2, 0x0
\n\t
"
...
...
@@ -463,7 +406,7 @@ __forceinline__ __device__ void inline_vgpr4_init_zero_4x4x4(VEC s_reg[4][4]) {
template
<
typename
VEC
>
__forceinline__
__device__
void
inline_vgpr4_init_zero_4x2x4
(
VEC
s_reg
[
4
][
4
])
{
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
asm
(
"v_mov_b64 %0, 0x0
\n\t
"
"v_mov_b64 %1, 0x0
\n\t
"
"v_mov_b64 %2, 0x0
\n\t
"
...
...
@@ -498,7 +441,7 @@ __forceinline__ __device__ void inline_vgpr4_init_zero_4x2x4(VEC s_reg[4][4]) {
template
<
typename
VEC
>
__forceinline__
__device__
void
inline_vgpr4_init_zero_1x4x4
(
VEC
s_reg
[
1
][
4
])
{
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
asm
(
"v_mov_b64 %0, 0x0
\n\t
"
"v_mov_b64 %1, 0x0
\n\t
"
"v_mov_b64 %2, 0x0
\n\t
"
...
...
@@ -514,7 +457,7 @@ __forceinline__ __device__ void inline_vgpr4_init_zero_1x4x4(VEC s_reg[1][4]) {
template
<
typename
VEC
>
__forceinline__
__device__
void
inline_vgpr4_init_zero_1x2x4
(
VEC
s_reg
[
1
][
4
])
{
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
asm
(
"v_mov_b64 %0, 0x0
\n\t
"
"v_mov_b64 %1, 0x0
\n\t
"
"v_mov_b64 %2, 0x0
\n\t
"
...
...
@@ -570,43 +513,6 @@ inline __HOST_DEVICE__ unsigned short inlineasm_float2bfloat16_ushort_nonan(cons
// DTK-compatible pk helpers (replace __builtin_hcu_pk_*_f32)
inline
__device__
__float2
hcu_pk_add_f32
(
__float2
a
,
__float2
b
)
{
__float2
o
;
asm
volatile
(
"v_pk_add_f32 %0, %1, %2"
:
"=v"
(
o
)
:
"v"
(
a
),
"v"
(
b
));
return
o
;
}
inline
__device__
__float2
hcu_pk_mul_f32
(
__float2
a
,
__float2
b
)
{
__float2
o
;
asm
volatile
(
"v_pk_mul_f32 %0, %1, %2"
:
"=v"
(
o
)
:
"v"
(
a
),
"v"
(
b
));
return
o
;
}
inline
__device__
__float2
hcu_pk_fma_f32
(
__float2
x
,
__float2
m
,
__float2
a
)
{
__float2
d
;
asm
volatile
(
"v_pk_fma_f32 %0, %1, %2, %3"
:
"=v"
(
d
)
:
"v"
(
x
),
"v"
(
m
),
"v"
(
a
));
return
d
;
}
// DTK requires these control operands to remain compile-time constants.
template
<
bool
Clamp
=
false
,
int
OModifier
=
0
>
inline
__device__
auto
hcu_cvt_pk_f16_f32
(
float
src0
,
float
src1
)
{
static_assert
(
OModifier
==
0
,
"Only o_modifier=0 is currently validated in HG DTK migration"
);
return
__builtin_hcu_cvt_pk_f16_f32
(
0
,
src0
,
0
,
src1
,
Clamp
,
OModifier
);
}
template
<
bool
Clamp
=
false
>
inline
__device__
auto
hcu_cvt_pk_bf16_f32
(
float
src0
,
float
src1
)
{
return
__builtin_hcu_cvt_pk_bf16_f32
(
0
,
src0
,
0
,
src1
,
Clamp
);
}
template
<
int
ByteSel
>
inline
__device__
vec2_fp32
hcu_cvt_pk_f32_fp8
(
int
src0
)
{
static_assert
(
ByteSel
==
0
||
ByteSel
==
2
,
"ByteSel must select the low or high packed fp8 pair"
);
return
__builtin_hcu_cvt_pk_f32_fp8
(
src0
,
false
,
0
,
ByteSel
);
}
// d = a * b + c
inline
__device__
__float2
inlineasm_fa_v_pk_fma_f32
(
__float2
&
a
,
const
__float2
&
b
,
const
__float2
&
c
)
{
__float2
d
;
...
...
@@ -637,7 +543,7 @@ inline __device__ void inlineasm_fa_v_pk_mul_f32(__float2 &dst, const __float2 &
// c = a + b
inline
__device__
void
inline_v_pk_add_f32
(
__float2
&
c
,
const
__float2
&
a
,
const
__float2
&
b
)
{
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
asm
volatile
(
"v_pk_add_f32 %0, %1, %2 ; inline_v_pk_add_f32"
:
"=v"
(
c
)
:
"v"
(
a
),
"v"
(
b
)
...
...
@@ -873,8 +779,8 @@ inline __host__ __device__ auto DownCastPair(const vec2_Element<FromType>& sourc
template
<
>
inline
__host__
__device__
auto
DownCastPair
<
float
,
half_t
>
(
const
vec2_Element
<
float
>&
source
)
{
#if defined(__gfx938__)
auto
result
=
hcu_cvt_pk_f16_f32
<
false
,
0
>
(
source
[
0
],
source
[
1
]);
#if defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
auto
result
=
__builtin_
hcu_cvt_pk_f16_f32
(
source
[
0
],
source
[
1
]
,
false
/*clamp*/
,
0
/*o_modifier*/
);
return
*
(
vec2_Element
<
half_t
>*
)(
&
result
);
#else
return
__builtin_amdgcn_cvt_pkrtz
(
source
[
0
],
source
[
1
]);
...
...
@@ -883,8 +789,8 @@ inline __host__ __device__ auto DownCastPair<float, half_t>(const vec2_Element<f
template
<
>
inline
__host__
__device__
auto
DownCastPair
<
float
,
bhalf_t
>
(
const
vec2_Element
<
float
>&
source
)
{
#if defined(__gfx938__)
auto
result
=
hcu_cvt_pk_bf16_f32
<
false
>
(
source
[
0
],
source
[
1
]);
#if defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
auto
result
=
__builtin_
hcu_cvt_pk_bf16_f32
(
source
[
0
],
source
[
1
]
,
false
/*clamp*/
);
return
*
(
vec2_Element
<
bhalf_t
>*
)(
&
result
);
#else
vec2_Element
<
bhalf_t
>
result
;
...
...
@@ -903,8 +809,8 @@ inline __host__ __device__ auto DownCastPairNoPack(const FromType src0, const Fr
template
<
>
inline
__host__
__device__
auto
DownCastPairNoPack
<
float
,
half_t
>
(
const
float
src0
,
const
float
src1
)
{
#if defined(__gfx938__)
auto
result
=
hcu_cvt_pk_f16_f32
<
false
,
0
>
(
src0
,
src1
);
#if defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
auto
result
=
__builtin_
hcu_cvt_pk_f16_f32
(
src0
,
src1
,
false
/*clamp*/
,
0
/*o_modifier*/
);
return
*
(
vec2_Element
<
half_t
>*
)(
&
result
);
#else
return
__builtin_amdgcn_cvt_pkrtz
(
src0
,
src1
);
...
...
@@ -913,8 +819,8 @@ inline __host__ __device__ auto DownCastPairNoPack<float, half_t>(const float sr
template
<
>
inline
__host__
__device__
auto
DownCastPairNoPack
<
float
,
bhalf_t
>
(
const
float
src0
,
const
float
src1
)
{
#if defined(__gfx938__)
auto
result
=
hcu_cvt_pk_bf16_f32
<
false
>
(
src0
,
src1
);
#if defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
auto
result
=
__builtin_
hcu_cvt_pk_bf16_f32
(
src0
,
src1
,
false
/*clamp*/
);
return
*
(
vec2_Element
<
bhalf_t
>*
)(
&
result
);
#else
vec2_Element
<
bhalf_t
>
result
;
...
...
@@ -954,7 +860,7 @@ __host__ __device__ float splitkv_upcast_to_f32(const FromType &from_var) {
template
<
typename
output_dtype
>
__forceinline__
__device__
void
__builtin_hcu_cvt_pk4_fp8_f32
(
const
vec4_fp32
&
source
,
int32_t
&
container
)
{
#if defined(__gfx938__)
#if defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
if
constexpr
(
std
::
is_same
<
output_dtype
,
fp8_e4m3
>::
value
)
{
container
=
__builtin_hcu_cvt_pk_fp8_f32
(
source
[
0
],
source
[
1
],
container
,
false
/*op_sel:[0,0,0,0]*/
);
container
=
__builtin_hcu_cvt_pk_fp8_f32
(
source
[
2
],
source
[
3
],
container
,
true
/*op_sel:[0,0,0,1]*/
);
...
...
csrc/flash_attn_hg/include/intrinsic_mls_ds.h
100644 → 100755
View file @
518a5f4d
...
...
@@ -7,35 +7,196 @@
#define VA_LIMIT_BITS(x) (0xffffffffffff & x)
template
<
int
INSTM
,
int
INSTNM
,
int
T
,
int
R
>
__forceinline__
__device__
void
matrix_load_b16_lds_trans_builtin
(
size_t
lds_addr_warp
,
vec4_int
rsrc
,
int
moffset
)
{
#define MATRIX_LOAD_32X32_B16_LDS_TRANS(LDSADDR, SRSRC, R, T) \
int soffset = LDSADDR + 0x80000000; \
asm volatile("s_nop 0\n\t" \
"matrix_load_32x32_b16 %0, %1 moffset:%2 "#R #T" lds\n" \
:: "s"(SRSRC), "s"(soffset), "n"(0) \
:);
#define MATRIX_LOAD_32X32_B16_LDS_TRANS_GFX946(LDSADDR, SRSRC, R, T) \
asm volatile("s_nop 0\n\t" \
"matrix_load_32x32_b16 %0, %1 moffset:%2 "#R #T" lds\n" \
:: "s"(SRSRC), "s"(LDSADDR), "n"(0) \
:);
template
<
int
r
,
int
t
,
class
DataType
>
__forceinline__
__device__
void
inline_matrix_load_32x32_b16_lds_trans
(
DataType
*
shared_addr
,
vec4_uint
srsrc
,
int
&
lds_offset
,
const
int
offset
)
{
#if defined(__gfx938__)
int
soffset
=
lds_addr_warp
+
0x80000000
;
if
constexpr
(
INSTM
==
32
&&
INSTNM
==
16
)
{
__builtin_hcu_matrix_load_32x16_b16
(
rsrc
,
(
__attribute__
((
address_space
(
3
)))
short
*
)(
soffset
),
0
,
T
,
R
,
0
,
0
);
}
else
if
constexpr
(
INSTM
==
32
&&
INSTNM
==
32
)
{
__builtin_hcu_matrix_load_32x32_b16
(
rsrc
,
(
__attribute__
((
address_space
(
3
)))
short
*
)(
soffset
),
0
,
T
,
R
,
0
,
0
);
}
else
if
constexpr
(
INSTM
==
64
&&
INSTNM
==
16
)
{
__builtin_hcu_matrix_load_64x16_b16
(
rsrc
,
(
__attribute__
((
address_space
(
3
)))
short
*
)(
soffset
),
0
,
T
,
R
,
0
,
0
);
int
lds_addr_per_wave
=
reinterpret_cast
<
size_t
>
(
shared_addr
)
+
(
lds_offset
);
/*
matrix_load_32x32_b16 VDATA, SRSRC, m0 moffset:8 r:1 t:1 lds:1 glc:1 slc:1
VDATA:DST
SRSRC: {sgpr[SRSRC+1], sgpr[SRSRC]}: global base address
sgpr[SRSRC+2]: stride
sgpr[SRSRC+3]: m/nm_filter, cache swizzle, interleave
*/
if
constexpr
(
r
&&
t
)
{
MATRIX_LOAD_32X32_B16_LDS_TRANS
(
lds_addr_per_wave
,
srsrc
,
r
,
t
);
}
else
if
constexpr
(
r
&&
!
t
)
{
MATRIX_LOAD_32X32_B16_LDS_TRANS
(
lds_addr_per_wave
,
srsrc
,
r
,);
}
else
if
constexpr
(
!
r
&&
t
)
{
MATRIX_LOAD_32X32_B16_LDS_TRANS
(
lds_addr_per_wave
,
srsrc
,,
t
);
}
else
{
MATRIX_LOAD_32X32_B16_LDS_TRANS
(
lds_addr_per_wave
,
srsrc
,,);
}
#elif defined(__gfx946__) || defined(__gfx92a__)
int
lds_addr_per_wave
=
reinterpret_cast
<
size_t
>
(
shared_addr
)
+
(
lds_offset
);
if
constexpr
(
r
&&
t
)
{
MATRIX_LOAD_32X32_B16_LDS_TRANS_GFX946
(
lds_addr_per_wave
,
srsrc
,
r
,
t
);
}
else
if
constexpr
(
r
&&
!
t
)
{
MATRIX_LOAD_32X32_B16_LDS_TRANS_GFX946
(
lds_addr_per_wave
,
srsrc
,
r
,);
}
else
if
constexpr
(
!
r
&&
t
)
{
MATRIX_LOAD_32X32_B16_LDS_TRANS_GFX946
(
lds_addr_per_wave
,
srsrc
,,
t
);
}
else
{
MATRIX_LOAD_32X32_B16_LDS_TRANS_GFX946
(
lds_addr_per_wave
,
srsrc
,,);
}
#endif
}
// Emits one `matrix_load_32x32_b16 ... lds` instruction through inline asm.
//   LDSADDR : integer LDS address; copied into a local `soffset` with 0x00000000 added
//             (i.e. unchanged) and passed as the second "s" operand.
//   SRSRC   : resource descriptor, passed as the first "s" operand.
//   R, T    : stringized straight into the instruction text; passing an empty macro
//             argument omits that modifier.
// moffset is always the immediate 0. The expansion declares a local variable, so it
// must be expanded at most once per scope (call sites wrap it in braces).
#define MATRIX_LOAD_32X32_B16_LDS(LDSADDR, SRSRC, R, T) \
int soffset = LDSADDR + 0x00000000; \
asm volatile("s_nop 0\n\t" \
"matrix_load_32x32_b16 %0, %1 moffset:%2 "#R #T" lds\n" \
:: "s"(SRSRC), "s"(soffset), "n"(0) \
:);
// Same instruction as MATRIX_LOAD_32X32_B16_LDS, but LDSADDR is handed to the asm
// operand directly (no intermediate `soffset`). Selected by callers in their
// __gfx946__ / __gfx92a__ branch.
#define MATRIX_LOAD_32X32_B16_LDS_GFX946(LDSADDR, SRSRC, R, T) \
asm volatile("s_nop 0\n\t" \
"matrix_load_32x32_b16 %0, %1 moffset:%2 "#R #T" lds\n" \
:: "s"(SRSRC), "s"(LDSADDR), "n"(0) \
:);
// Issues a `matrix_load_32x32_b16 ... lds` instruction whose LDS destination is
// `shared_addr + lds_offset` (lds_offset is added to the pointer value as a byte count).
//
// Template parameters:
//   r, t     : non-zero values make the corresponding modifier text ("r" / "t") appear
//              in the emitted instruction; zero leaves it out.
//   DataType : element type of the LDS pointer (only its address is used).
// Parameters:
//   shared_addr : base LDS pointer.
//   srsrc       : resource descriptor forwarded to the instruction.
//   lds_offset  : offset added to the numeric value of shared_addr.
//   offset      : unused here.
// On targets other than gfx938 / gfx946 / gfx92a the function body is empty.
template <int r, int t, class DataType>
__forceinline__ __device__ void inline_matrix_load_32x32_b16_lds(DataType* shared_addr,
                                                                  vec4_uint srsrc,
                                                                  int& lds_offset,
                                                                  const int offset) {
#if defined(__gfx938__)
    /*
      matrix_load_32x32_b16 VDATA, SRSRC, m0 moffset:8 r:1 t:1 lds:1 glc:1 slc:1
      VDATA:DST
      SRSRC: {sgpr[SRSRC+1], sgpr[SRSRC]}: global base address
             sgpr[SRSRC+2]: stride
             sgpr[SRSRC+3]: m/nm_filter, cache swizzle, interleave
    */
    int lds_dst = reinterpret_cast<size_t>(shared_addr) + (lds_offset);
    // Each macro expansion declares a local, hence the dedicated braces per branch.
    if constexpr (r != 0) {
        if constexpr (t != 0) {
            MATRIX_LOAD_32X32_B16_LDS(lds_dst, srsrc, r, t);
        } else {
            MATRIX_LOAD_32X32_B16_LDS(lds_dst, srsrc, r,);
        }
    } else {
        if constexpr (t != 0) {
            MATRIX_LOAD_32X32_B16_LDS(lds_dst, srsrc,, t);
        } else {
            MATRIX_LOAD_32X32_B16_LDS(lds_dst, srsrc,,);
        }
    }
#elif defined(__gfx946__) || defined(__gfx92a__)
    int lds_dst = reinterpret_cast<size_t>(shared_addr) + (lds_offset);
    if constexpr (r != 0) {
        if constexpr (t != 0) {
            MATRIX_LOAD_32X32_B16_LDS_GFX946(lds_dst, srsrc, r, t);
        } else {
            MATRIX_LOAD_32X32_B16_LDS_GFX946(lds_dst, srsrc, r,);
        }
    } else {
        if constexpr (t != 0) {
            MATRIX_LOAD_32X32_B16_LDS_GFX946(lds_dst, srsrc,, t);
        } else {
            MATRIX_LOAD_32X32_B16_LDS_GFX946(lds_dst, srsrc,,);
        }
    }
#endif
}
// ======================================================= MLS32x16 ===========================================================
// Emits one `matrix_load_32x16_b16 ... lds` instruction through inline asm ("trans" flavour).
//   LDSADDR : integer LDS address; 0x80000000 is added to it and the sum is stored in a
//             local `soffset`, which becomes the second "s" operand.
//   SRSRC   : resource descriptor, passed as the first "s" operand.
//   R, T    : stringized straight into the instruction text; an empty macro argument
//             omits that modifier.
// moffset is always the immediate 0. The expansion declares a local variable, so it
// must be expanded at most once per scope.
#define MATRIX_LOAD_32X16_B16_LDS_TRANS(LDSADDR, SRSRC, R, T) \
int soffset = LDSADDR + 0x80000000; \
asm volatile("s_nop 0\n\t" \
"matrix_load_32x16_b16 %0, %1 moffset:%2 "#R #T" lds\n" \
:: "s"(SRSRC), "s"(soffset), "n"(0) \
:);
// Variant used by callers in their __gfx946__ / __gfx92a__ branch: LDSADDR is passed to
// the instruction unchanged (no 0x80000000 is added in this macro).
#define MATRIX_LOAD_32X16_B16_LDS_TRANS_GFX946(LDSADDR, SRSRC, R, T) \
asm volatile("s_nop 0\n\t" \
"matrix_load_32x16_b16 %0, %1 moffset:%2 "#R #T" lds\n" \
:: "s"(SRSRC), "s"(LDSADDR), "n"(0) \
:);
// Issues a `matrix_load_32x16_b16 ... lds` instruction ("trans" flavour) whose LDS
// destination is `shared_addr + lds_offset` (lds_offset is added to the pointer value).
//
// Template parameters:
//   r, t     : non-zero values make the corresponding modifier text ("r" / "t") appear
//              in the emitted instruction; zero leaves it out.
//   DataType : element type of the LDS pointer (only its address is used).
// Parameters:
//   shared_addr : base LDS pointer.
//   srsrc       : resource descriptor forwarded to the instruction.
//   lds_offset  : offset added to the numeric value of shared_addr.
//   offset      : unused; kept for interface compatibility.
// On targets other than gfx938 / gfx946 / gfx92a nothing is emitted.
template <int r, int t, class DataType>
__forceinline__ __device__ void inline_matrix_load_32x16_b16_lds_trans(DataType* shared_addr,
                                                                        vec4_uint srsrc,
                                                                        int& lds_offset,
                                                                        const int offset) {
    // The parameter is named `offset`; silence the unused-parameter warning on every target.
    (void)offset;
#if defined(__gfx938__)
    int lds_addr_per_wave = reinterpret_cast<size_t>(shared_addr) + (lds_offset);
    /*
      matrix_load_32x32_b16 VDATA, SRSRC, m0 moffset:8 r:1 t:1 lds:1 glc:1 slc:1
      VDATA:DST
      SRSRC: {sgpr[SRSRC+1], sgpr[SRSRC]}: global base address
             sgpr[SRSRC+2]: stride
             sgpr[SRSRC+3]: m/nm_filter, cache swizzle, interleave
    */
    // Each macro expansion declares a local, hence the dedicated braces per branch.
    if constexpr (r && t) {
        MATRIX_LOAD_32X16_B16_LDS_TRANS(lds_addr_per_wave, srsrc, r, t);
    } else if constexpr (r && !t) {
        MATRIX_LOAD_32X16_B16_LDS_TRANS(lds_addr_per_wave, srsrc, r,);
    } else if constexpr (!r && t) {
        MATRIX_LOAD_32X16_B16_LDS_TRANS(lds_addr_per_wave, srsrc,, t);
    } else {
        MATRIX_LOAD_32X16_B16_LDS_TRANS(lds_addr_per_wave, srsrc,,);
    }
#elif defined(__gfx946__) || defined(__gfx92a__)
    int lds_addr_per_wave = reinterpret_cast<size_t>(shared_addr) + (lds_offset);
    if constexpr (r && t) {
        MATRIX_LOAD_32X16_B16_LDS_TRANS_GFX946(lds_addr_per_wave, srsrc, r, t);
    } else if constexpr (r && !t) {
        MATRIX_LOAD_32X16_B16_LDS_TRANS_GFX946(lds_addr_per_wave, srsrc, r,);
    } else if constexpr (!r && t) {
        MATRIX_LOAD_32X16_B16_LDS_TRANS_GFX946(lds_addr_per_wave, srsrc,, t);
    } else {
        MATRIX_LOAD_32X16_B16_LDS_TRANS_GFX946(lds_addr_per_wave, srsrc,,);
    }
#endif
}
template
<
int
INSTM
,
int
INSTNM
,
int
T
,
int
R
>
__forceinline__
__device__
void
matrix_load_b16_lds_builtin
(
size_t
lds_addr_warp
,
vec4_int
rsrc
,
int
moffset
)
{
#define MATRIX_LOAD_32X16_B16_LDS(LDSADDR, SRSRC, R, T) \
int soffset = LDSADDR + 0x00000000; \
asm volatile("s_nop 0\n\t" \
"matrix_load_32x16_b16 %0, %1 moffset:%2 "#R #T" lds\n" \
:: "s"(SRSRC), "s"(soffset), "n"(0) \
:);
#define MATRIX_LOAD_32X16_B16_LDS_GFX946(LDSADDR, SRSRC, R, T) \
asm volatile("s_nop 0\n\t" \
"matrix_load_32x16_b16 %0, %1 moffset:%2 "#R #T" lds\n" \
:: "s"(SRSRC), "s"(LDSADDR), "n"(0) \
:);
template
<
int
r
,
int
t
,
class
DataType
>
__forceinline__
__device__
void
inline_matrix_load_32x16_b16_lds
(
DataType
*
shared_addr
,
vec4_uint
srsrc
,
int
&
lds_offset
,
const
int
offset
)
{
#if defined(__gfx938__)
int
soffset
=
lds_addr_warp
+
0x00000000
;
if
constexpr
(
INSTM
==
32
&&
INSTNM
==
16
)
{
__builtin_hcu_matrix_load_32x16_b16
(
rsrc
,
(
__attribute__
((
address_space
(
3
)))
short
*
)(
soffset
),
0
,
T
,
R
,
0
,
0
);
}
else
if
constexpr
(
INSTM
==
32
&&
INSTNM
==
32
)
{
__builtin_hcu_matrix_load_32x32_b16
(
rsrc
,
(
__attribute__
((
address_space
(
3
)))
short
*
)(
soffset
),
0
,
T
,
R
,
0
,
0
);
}
else
if
constexpr
(
INSTM
==
64
&&
INSTNM
==
16
)
{
__builtin_hcu_matrix_load_64x16_b16
(
rsrc
,
(
__attribute__
((
address_space
(
3
)))
short
*
)(
soffset
),
0
,
T
,
R
,
0
,
0
);
int
lds_addr_per_wave
=
reinterpret_cast
<
size_t
>
(
shared_addr
)
+
(
lds_offset
);
/*
matrix_load_32x32_b16 VDATA, SRSRC, m0 moffset:8 r:1 t:1 lds:1 glc:1 slc:1
VDATA:DST
SRSRC: {sgpr[SRSRC+1], sgpr[SRSRC]}: global base address
sgpr[SRSRC+2]: stride
sgpr[SRSRC+3]: m/nm_filter, cache swizzle, interleave
*/
if
constexpr
(
r
&&
t
)
{
MATRIX_LOAD_32X16_B16_LDS
(
lds_addr_per_wave
,
srsrc
,
r
,
t
);
}
else
if
constexpr
(
r
&&
!
t
)
{
MATRIX_LOAD_32X16_B16_LDS
(
lds_addr_per_wave
,
srsrc
,
r
,);
}
else
if
constexpr
(
!
r
&&
t
)
{
MATRIX_LOAD_32X16_B16_LDS
(
lds_addr_per_wave
,
srsrc
,,
t
);
}
else
{
MATRIX_LOAD_32X16_B16_LDS
(
lds_addr_per_wave
,
srsrc
,,);
}
#elif defined(__gfx946__) || defined(__gfx92a__)
int
lds_addr_per_wave
=
reinterpret_cast
<
size_t
>
(
shared_addr
)
+
(
lds_offset
);
if
constexpr
(
r
&&
t
)
{
MATRIX_LOAD_32X16_B16_LDS_GFX946
(
lds_addr_per_wave
,
srsrc
,
r
,
t
);
}
else
if
constexpr
(
r
&&
!
t
)
{
MATRIX_LOAD_32X16_B16_LDS_GFX946
(
lds_addr_per_wave
,
srsrc
,
r
,);
}
else
if
constexpr
(
!
r
&&
t
)
{
MATRIX_LOAD_32X16_B16_LDS_GFX946
(
lds_addr_per_wave
,
srsrc
,,
t
);
}
else
{
MATRIX_LOAD_32X16_B16_LDS_GFX946
(
lds_addr_per_wave
,
srsrc
,,);
}
(
void
)
moffset
;
#endif
}
...
...
@@ -62,6 +223,25 @@ __forceinline__ __device__ void matrix_load_b16_lds_builtin(size_t lds_addr_warp
:); \
}
// Reads two matrix fragments from LDS into REG and REG1 with a pair of
// ds_read_matrix_format instructions (or ds_read_matrix_trans_format when TRANS is true).
//   OFFSET : LDS address operand, passed with an "s" constraint.
//   REG    : written by the first instruction  (instruction offset 0).
//   REG1   : written by the second instruction (instruction offset 1024).
//   TRANS  : compile-time flag evaluated by `if constexpr`.
// Both instructions use element:0x2 row:0x2 col:0x1 alt:0x0. No wait is issued here, so
// the caller presumably has to wait for the LDS reads before using REG / REG1 -- confirm
// against call sites.
#define DS_READ_MATRIX_32X32_B16_GFX946(OFFSET, REG, REG1, TRANS) \
if constexpr (TRANS) { \
asm volatile( \
"s_nop 0\n\t" \
"ds_read_matrix_trans_format %0, %2 offset:0 element:0x2 row:0x2 col:0x1 alt:0x0\n\t" \
"ds_read_matrix_trans_format %1, %2 offset:1024 element:0x2 row:0x2 col:0x1 alt:0x0\n" \
: "=v"(REG), "=v"(REG1) \
: "s"(OFFSET) \
:); \
} else { \
asm volatile( \
"s_nop 0\n\t" \
"ds_read_matrix_format %0, %2 offset:0 element:0x2 row:0x2 col:0x1 alt:0x0\n\t" \
"ds_read_matrix_format %1, %2 offset:1024 element:0x2 row:0x2 col:0x1 alt:0x0\n" \
: "=v"(REG), "=v"(REG1) \
: "s"(OFFSET) \
:); \
}
#define DS_READ_MATRIX_32X16_B16(OFFSET, REG, TRANS) \
if constexpr (TRANS) { \
asm volatile( \
...
...
@@ -141,15 +321,22 @@ __forceinline__ __device__ int inline_min_max(int source) {
}
// ======================================================= def ===========================================================
#define YY_USE_MPERMUTE
// Applies `ds_mpermute_dwordx2 ... offset:6` to `data` in place.
//
// The register is used as both source and destination, so it is declared as a
// read-write operand ("+v"); this tells the compiler that `data` is modified.
// The instruction is issued exactly once and no wait is performed here; use
// ds_mpermute_kdim_for_mmac_wait when the result is needed immediately.
//
//   VEC  : register-resident vector type accepted by the "v" constraint.
//   data : value permuted in place.
template <typename VEC>
__forceinline__ __device__ void ds_mpermute_kdim_for_mmac(VEC& data) {
    asm volatile("ds_mpermute_dwordx2 %0, %0 offset:6\n"
                 : "+v"(data)
                 :);
}
// Applies `ds_mpermute_dwordx2 ... offset:6` to `data` in place and then waits with
// `s_waitcnt lgkmcnt(0)` before returning.
//
// The register is used as both source and destination, so it is declared as a
// read-write operand ("+v"); this tells the compiler that `data` is modified.
// The instruction is issued exactly once.
//
//   VEC  : register-resident vector type accepted by the "v" constraint.
//   data : value permuted in place.
template <typename VEC>
__forceinline__ __device__ void ds_mpermute_kdim_for_mmac_wait(VEC& data) {
    asm volatile("ds_mpermute_dwordx2 %0, %0 offset:6\n\t"
                 "s_waitcnt lgkmcnt(0)\n"
                 : "+v"(data)
                 :);
}
...
...
@@ -163,7 +350,7 @@ inline __device__ vec4_fp32 mmac_4interleave(const vec4_Element<T> &v1, const ve
template
<
>
inline
__device__
vec4_fp32
mmac_4interleave
<
half_t
,
float
>
(
const
vec4_fp16
&
v1
,
const
vec4_fp16
&
v2
,
const
vec4_fp32
&
v3
)
{
#if defined(__gfx938__)
#if defined(__gfx938__)
|| defined(__gfx946__)
return
__builtin_hcu_mmac_f32_16x16x16_f16_lit_lts
(
v1
,
v2
,
v3
,
1
,
0
);
#else
return
__builtin_hcu_mmac_f32_16x16x16_f16
(
v1
,
v2
,
v3
);
...
...
@@ -173,7 +360,7 @@ inline __device__ vec4_fp32 mmac_4interleave<half_t, float>(const vec4_fp16 &v1,
template
<
>
inline
__device__
vec4_fp32
mmac_4interleave
<
bhalf_t
,
float
>
(
const
vec4_bf16
&
v1
,
const
vec4_bf16
&
v2
,
const
vec4_fp32
&
v3
)
{
#if defined(__gfx938__)
#if defined(__gfx938__)
|| defined(__gfx946__)
return
__builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts
(
v1
,
v2
,
v3
,
1
,
0
);
#else
return
__builtin_hcu_mmac_f32_16x16x16_bf16
(
v1
,
v2
,
v3
);
...
...
csrc/flash_attn_hg/include/intrinsic_mls_ds_b8.h
View file @
518a5f4d
#pragma once
#include "numeric_types.h"
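// DTK: the second argument of __builtin_hcu_matrix_load_*_b8 is an addrspace(3) char* (the local clang reports a short* vs. char* mismatch otherwise); the b16 variants use short*, see intrinsic_mls_ds.h.
// For the migration approach, soffset (trans adds 0x80000000), call conventions and verification, see §4 of ROCM指令迁移到DTK.md in the repository root.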
// Inline asm with "s"(vec4_uint) can lower srsrc to VGPR and fail with invalid operand; builtins keep srsrc in the correct class.
// Builtin-based 128x16 b8 matrix load into LDS ("trans" flavour).
//
// On gfx938 the LDS address is truncated to int, 0x80000000 is added, and the result is
// reinterpreted as an address_space(3) char* for __builtin_hcu_matrix_load_128x16_b8.
// The builtin receives (rsrc, lds pointer, 0, t, r, 0, 0). On other targets this is a no-op.
//
//   r, t          : template flags forwarded to the builtin (note the t-then-r order).
//   lds_addr_warp : numeric LDS destination address.
//   rsrc          : resource descriptor.
//   (third param) : matrix offset, intentionally unnamed and ignored.
template <int r, int t>
__forceinline__ __device__ void matrix_load_128x16_b8_lds_trans_builtin(size_t lds_addr_warp,
                                                                         vec4_int rsrc,
                                                                         int /*matrix_offset*/) {
#if defined(__gfx938__)
    using lds_byte_ptr = __attribute__((address_space(3))) char*;
    const int trans_soffset = static_cast<int>(lds_addr_warp) + 0x80000000;
    // Third arg must be compile-time constant (same pattern as matrix_load_b16); call sites use matrix_offset==0.
    __builtin_hcu_matrix_load_128x16_b8(rsrc, (lds_byte_ptr)(trans_soffset), 0, t, r, 0, 0);
#endif
}
#define MATRIX_LOAD_128X16_B8_LDS_TRANS(LDSADDR, SRSRC, MATRIX_OFFSET, R, T) \
int soffset = LDSADDR + 0x80000000; \
asm volatile("s_nop 4\n\t" \
"matrix_load_128x16_b8 %0, %1, moffset:%2 "#R #T" lds\n" \
:: "s"(SRSRC), "s"(soffset), "n"(MATRIX_OFFSET) \
:);
template
<
int
r
,
int
t
,
class
DataType
>
__forceinline__
__device__
void
inline_matrix_load_128x16_b8_lds_trans
(
DataType
*
shared_addr
,
vec4_uint
srsrc
,
int
lds_offset
,
const
int
matrix_offset
)
{
#if defined(__gfx938__)
union
union_vec4_uint
u
;
u
.
v32
=
srsrc
;
size_t
lds_addr_warp
=
reinterpret_cast
<
size_t
>
(
shared_addr
)
+
static_cast
<
size_t
>
(
lds_offset
);
matrix_load_128x16_b8_lds_trans_builtin
<
r
,
t
>
(
lds_addr_warp
,
u
.
i32
,
matrix_offset
);
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
int
lds_addr_per_wave
=
reinterpret_cast
<
size_t
>
(
shared_addr
)
+
(
lds_offset
);
if
constexpr
(
r
&&
t
)
{
MATRIX_LOAD_128X16_B8_LDS_TRANS
(
lds_addr_per_wave
,
srsrc
,
matrix_offset
,
r
,
t
);
}
else
if
constexpr
(
r
&&
!
t
)
{
MATRIX_LOAD_128X16_B8_LDS_TRANS
(
lds_addr_per_wave
,
srsrc
,
matrix_offset
,
r
,);
}
else
if
constexpr
(
!
r
&&
t
)
{
MATRIX_LOAD_128X16_B8_LDS_TRANS
(
lds_addr_per_wave
,
srsrc
,
matrix_offset
,,
t
);
}
else
{
MATRIX_LOAD_128X16_B8_LDS_TRANS
(
lds_addr_per_wave
,
srsrc
,
matrix_offset
,,);
}
#endif
}
...
...
@@ -50,28 +43,25 @@ __forceinline__ __device__ void inline_matrix_load_128x16_b8_lds_trans(DataType
}
// Builtin-based 64x32 b8 matrix load into LDS ("rearrange" flavour).
//
// On gfx938 the LDS address is truncated to int and reinterpreted as an address_space(3)
// char* for __builtin_hcu_matrix_load_64x32_b8, which receives
// (rsrc, lds pointer, 0, t, r, 0, 0). On other targets this is a no-op.
//
//   r, t          : template flags forwarded to the builtin (note the t-then-r order).
//   lds_addr_warp : numeric LDS destination address.
//   rsrc          : resource descriptor.
//   (third param) : matrix offset, intentionally unnamed and ignored.
template <int r, int t>
__forceinline__ __device__ void matrix_load_64x32_b8_lds_rearrange_builtin(size_t lds_addr_warp,
                                                                            vec4_int rsrc,
                                                                            int /*matrix_offset*/) {
#if defined(__gfx938__)
    using lds_byte_ptr = __attribute__((address_space(3))) char*;
    const int lds_soffset = static_cast<int>(lds_addr_warp);
    __builtin_hcu_matrix_load_64x32_b8(rsrc, (lds_byte_ptr)(lds_soffset), 0, t, r, 0, 0);
#endif
}
#define MATRIX_LOAD_64x32_B8_LDS_REARRANGE(LDSADDR, SRSRC, MATRIX_OFFSET, R, T) \
asm volatile("s_nop 4\n\t" \
"matrix_load_64x32_b8 %0, %1, moffset:%2 "#R #T" lds\n" \
:: "s"(SRSRC), "s"(LDSADDR), "n"(MATRIX_OFFSET) \
:);
template
<
int
r
,
int
t
,
class
DataType
>
__forceinline__
__device__
void
inline_matrix_load_64x32_b8_lds_rearrange
(
DataType
*
shared_addr
,
vec4_uint
srsrc
,
int
lds_offset
,
const
int
matrix_offset
)
{
#if defined(__gfx938__)
union
union_vec4_uint
u
;
u
.
v32
=
srsrc
;
size_t
lds_addr_warp
=
reinterpret_cast
<
size_t
>
(
shared_addr
)
+
static_cast
<
size_t
>
(
lds_offset
);
matrix_load_64x32_b8_lds_rearrange_builtin
<
r
,
t
>
(
lds_addr_warp
,
u
.
i32
,
matrix_offset
);
#if defined(__gfx938__) || defined(__gfx946__) || defined(__gfx92a__)
int
lds_addr_per_wave
=
reinterpret_cast
<
size_t
>
(
shared_addr
)
+
(
lds_offset
);
if
constexpr
(
r
&&
t
)
{
MATRIX_LOAD_64x32_B8_LDS_REARRANGE
(
lds_addr_per_wave
,
srsrc
,
matrix_offset
,
r
,
t
);
}
else
if
constexpr
(
r
&&
!
t
)
{
MATRIX_LOAD_64x32_B8_LDS_REARRANGE
(
lds_addr_per_wave
,
srsrc
,
matrix_offset
,
r
,);
}
else
if
constexpr
(
!
r
&&
t
)
{
MATRIX_LOAD_64x32_B8_LDS_REARRANGE
(
lds_addr_per_wave
,
srsrc
,
matrix_offset
,,
t
);
}
else
{
MATRIX_LOAD_64x32_B8_LDS_REARRANGE
(
lds_addr_per_wave
,
srsrc
,
matrix_offset
,,);
}
#endif
}
...
...
csrc/flash_attn_hg/include/kernel_traits.h
View file @
518a5f4d
...
...
@@ -60,11 +60,6 @@ struct Flash_fwd_kernel_traits : public Base {
static
constexpr
size_t
k_smem_size
=
(
STAGES
*
(
kWaveN
/
32
)
*
(
kBlockK
/
32
)
*
(
32
*
34
))
*
sizeof
(
Element
);
static
constexpr
size_t
v_smem_size
=
(
STAGES
*
kBlockK
*
32
/*WARP_K*/
)
*
sizeof
(
Element
);
#if (TARGET == 928)
static
constexpr
int
kSmemSize
=
std
::
max
(
q_smem_size
,
v_smem_size
)
+
k_smem_size
*
2
;
#else
static
constexpr
int
kSmemSize
=
std
::
max
(
std
::
max
(
q_smem_size
,
v_smem_size
),
k_smem_size
*
2
);
#endif
};
// Is_V_in_regs is an option to reduce smem usage, but will increase register pressure.
...
...
csrc/flash_attn_hg/include/kvcache/gfx92a/f16_kvcache_gfx92a.h
0 → 100755
View file @
518a5f4d
#pragma once
#include "numeric_types.h"
namespace
gfx92a
{
// Loads the Q tile for this warp into the `q_reg` register array.
//
// Paths (all selected at compile time):
//   * Is_Varlen, kHeadDim == 128 and WARP_NUM == 4:
//       every lane stages Q through LDS with inline_buffer_load_dwordx4_lds, then reads
//       it back with inline_ds_read2_b64.
//   * Is_Varlen, any other configuration:
//       every lane reads two 8-byte values per (neighbor, load) directly from q_ptr.
//   * !Is_Varlen, kHeadDim == 128 and WARP_NUM == 4:
//       one inline_matrix_load_32x32_b16_lds_trans per warp into LDS, followed by
//       DS_READ_MATRIX_* reads into q_reg.
//   * !Is_Varlen, any other configuration: nothing is loaded (still TODO).
//
// Parameters:
//   q_ptr               : Q source pointer (also used to build the buffer resource).
//   q_lds               : LDS staging pointer.
//   q_reg               : destination registers.
//   warp_id             : index of the calling warp.
//   query_seqlen_stride : stride (in elements) applied per sequence position.
//   query_ngroup_stride : stride (in elements) applied per group index.
//   ngroups             : divisor used to split a row into (sequence, group).
//   max_seq_q_offset    : rows are clamped to max_seq_q_offset - 1 in the varlen paths and
//                         drive the nm_filter field otherwise.
// NOTE(review): the DS_READ_MATRIX_* calls use absolute offsets (k * 32 * 32 * 2) that do
// not include q_lds; this presumably relies on q_lds sitting at LDS address 0 -- confirm.
template <bool Is_Varlen, int kHeadDim, int kBlockK, int WARP_M, int WARP_NUM, int M_MMAC_COUNT, typename Element>
__forceinline__ __device__ void kvcache_prefetch_q_to_vgpr(
    Element* q_ptr,
    Element* q_lds,
    union_vec4_f16x2<Element> q_reg[(kHeadDim / kBlockK) * (WARP_M * kBlockK) / (32 * 32) * 2],
    int warp_id,
    int query_seqlen_stride,
    int query_ngroup_stride,
    int ngroups,
    int max_seq_q_offset = 0) {
    constexpr int elementBytes = sizeof(Element);
    constexpr bool kLdsStagedConfig = (kHeadDim == 128 and WARP_NUM == 4);

    // resource regs
    auto q_addr = prepare_for_buffer_load<kHeadDim, Element, false>(q_ptr);

    if constexpr (Is_Varlen) {
        const int lane = int(threadIdx.x) & 63;
        const int last_valid_row = max_seq_q_offset - 1;
        flash::wait_buffer_data_arrived<true>(0);

        if constexpr (kLdsStagedConfig) {
            // global -> LDS: one dwordx4 request per lane and per M tile
            for (int m_tile = 0; m_tile < M_MMAC_COUNT; ++m_tile) {
                const int row = min(m_tile * 16 + (lane >> 2), last_valid_row);
                const int col = warp_id * 32 + (lane & 3) * 8;
                const int seq_idx = row / ngroups;
                const int group_idx = row - seq_idx * ngroups;
                int src_offset = seq_idx * ngroups * query_seqlen_stride + group_idx * query_ngroup_stride + col;
                int dst_offset = (m_tile * 4 + warp_id) * 16 * 32;
                inline_buffer_load_dwordx4_lds<Element, 1>(q_lds, q_addr, dst_offset, 0, src_offset);
            }
            flash::wait_buffer_data_arrived<true>(0);

            // LDS -> VGPRs
            for (int m_tile = 0; m_tile < M_MMAC_COUNT; ++m_tile) {
                for (int neighbor = 0; neighbor < WARP_NUM; ++neighbor) {
                    const int elem_offset = (m_tile * 4 + neighbor) * 16 * 32 + (lane & 15) * 32 + (lane >> 4) * 4;
                    int lds_bytes = reinterpret_cast<size_t>(q_lds + elem_offset);
                    inline_ds_read2_b64(lds_bytes, q_reg[neighbor * 2 + m_tile].f32, 0, 4);
                }
            }
            flash::wait_lds_data_arrived<true>(0);
        } else {
            // direct global -> VGPR loads, two 8-byte reads per register slot
#pragma unroll
            for (int neighbor = 0; neighbor < WARP_NUM; ++neighbor) {
#pragma unroll
                for (int m_tile = 0; m_tile < M_MMAC_COUNT; ++m_tile) {
                    const int row = min(m_tile * 16 + (lane & 15), last_valid_row);
                    const int col = neighbor * 32 + (lane >> 4) * 4;
                    const int seq_idx = row / ngroups;
                    const int group_idx = row - seq_idx * ngroups;
                    const int src_offset = seq_idx * ngroups * query_seqlen_stride + group_idx * query_ngroup_stride + col;
                    q_reg[neighbor * 2 + m_tile].data[0] = *(double*)(q_ptr + src_offset);
                    q_reg[neighbor * 2 + m_tile].data[1] = *(double*)(q_ptr + src_offset + 16);
                }
            }
        }
    } else {
        if constexpr (kLdsStagedConfig) {
            // global offsets: sequence offset is always 0 here, headdim offset depends on the warp
            const int q_seq_offset = 0 * kBlockK;
            const int q_dim_offset = warp_id * kBlockK;

            // prepare mls resource regs
            vec4_uint q_srsrc;
            q_srsrc[0] = q_addr[0] + (q_seq_offset + q_dim_offset) * elementBytes;
            q_srsrc[1] = q_addr[1];
            q_srsrc[2] = query_seqlen_stride;
            {
                const int nm_filter = inline_min_max<0, 32>(32 - max_seq_q_offset);
                q_srsrc[3] = (max_seq_q_offset % 32 == 0) ? 0 : (nm_filter << 8);
            }

            // compute lds write offset, each warp occupy 32 * 32 * sizeof(f16) = 2KB
            int q_lds_offset_bytes = warp_id * (WARP_M / 32) * (kBlockK / 32) * (32 * 32) * elementBytes;

            // flash::wait_lds_data_arrived<true>(0);
            inline_matrix_load_32x32_b16_lds_trans<0, 0>(q_lds, q_srsrc, q_lds_offset_bytes, 0);

            // wait q data arrived
            flash::wait_buffer_data_arrived<true>(0);

            // lds -> vgprs
            if constexpr (M_MMAC_COUNT == 1) {
                DS_READ_MATRIX_32X16_B16(0 * 32 * 32 * 2, q_reg[0 * 2].f16, true);
                DS_READ_MATRIX_32X16_B16(1 * 32 * 32 * 2, q_reg[1 * 2].f16, true);
                DS_READ_MATRIX_32X16_B16(2 * 32 * 32 * 2, q_reg[2 * 2].f16, true);
                DS_READ_MATRIX_32X16_B16(3 * 32 * 32 * 2, q_reg[3 * 2].f16, true);
            } else {
                DS_READ_MATRIX_32X32_B16(0 * 32 * 32 * 2, q_reg[0 * 2].f16, q_reg[0 * 2 + 1].f16, true);
                DS_READ_MATRIX_32X32_B16(1 * 32 * 32 * 2, q_reg[1 * 2].f16, q_reg[1 * 2 + 1].f16, true);
                DS_READ_MATRIX_32X32_B16(2 * 32 * 32 * 2, q_reg[2 * 2].f16, q_reg[2 * 2 + 1].f16, true);
                DS_READ_MATRIX_32X32_B16(3 * 32 * 32 * 2, q_reg[3 * 2].f16, q_reg[3 * 2 + 1].f16, true);
            }
            flash::wait_lds_data_arrived<true>(0);
        } else {
            // TODO
        }
    }
}
// Issues `prefetchKLevel` matrix loads that copy K tiles from global memory into LDS.
//
// For prefetch step p the global address is k_addr + p * kBlockK * sizeof(Element)
// (headdim direction) + a per-warp sequence offset, and the LDS destination slot is
// (warp_id * prefetchKLevel + p) tiles of (WARP_N/32) * (kBlockK/32) * 32 * 32 elements.
//
// Parameters:
//   k_addr           : resource descriptor whose low 64 bits hold the K base address.
//   k_lds            : LDS base pointer for K.
//   warp_id          : index of the calling warp.
//   k_seq_stride     : written to descriptor word 2 and used to scale the per-warp offset.
//   max_seq_k_offset : used to derive the nm_filter field (descriptor word 3, bits 8+).
//
// When warp_id * WARP_N + 32 - max_seq_k_offset >= 32 the warp falls back to warp 0's
// sequence offset. inline_min_max<0, 32> presumably clamps its argument to [0, 32] --
// TODO confirm against its definition.
// The function only issues the loads; it does not wait for them.
template <int kBlockK, int WARP_N, int prefetchKLevel, typename Element>
__forceinline__ __device__ void kvcache_prefetch_k_to_lds(vec4_uint k_addr,
                                                           Element* k_lds,
                                                           int warp_id,
                                                           int k_seq_stride,
                                                           int max_seq_k_offset = 0) {
    constexpr int elementBytes = sizeof(Element);
    // elements occupied by one prefetch slot in LDS
    constexpr int kSlotElems = (WARP_N / 32) * (kBlockK / 32) * (32 * 32);

    // prepare mls resource regs
    vec4_uint k_srsrc;
    k_srsrc[1] = k_addr[1];
    k_srsrc[2] = k_seq_stride;

    // Sequence offset and nm_filter do not depend on the prefetch step: compute them once.
    const int nm_filter_max = warp_id * WARP_N + 32 - max_seq_k_offset;
    const int real_mls_warp_id = nm_filter_max >= 32 ? 0 : warp_id;
    // global bytes along seqlen
    const int k_seq_bytes = real_mls_warp_id * WARP_N * k_seq_stride * elementBytes;
    const int nm_filter = inline_min_max<0, 32>(real_mls_warp_id * WARP_N + 32 - max_seq_k_offset);
    k_srsrc[3] = nm_filter << 8;

    // occupy 4 * 2 * 2 * 32 * 32 * sizeof(f16) = 32 KB, in total
#pragma unroll
    for (int prefetch_id = 0; prefetch_id < prefetchKLevel; ++prefetch_id) {
        // global bytes along headdim
        const int k_dim_bytes = prefetch_id * kBlockK * elementBytes;

        // acquire buffer address (overwrites descriptor words 0 and 1)
        *(uint64_t*)&k_srsrc = VA_LIMIT_BITS(*(uint64_t*)&k_addr + k_dim_bytes + k_seq_bytes);

        // LDS destination in bytes; passed by reference, so it must be a mutable lvalue
        int lds_offset_bytes = (warp_id * prefetchKLevel + prefetch_id) * kSlotElems * elementBytes;
        inline_matrix_load_32x32_b16_lds_trans<0, 0>(k_lds, k_srsrc, lds_offset_bytes, 0);
    }
    __builtin_amdgcn_sched_barrier(0);
}
// Issues `prefetchVLevel` matrix loads that copy V tiles from global memory into LDS.
// Only prefetchVLevel == 2 and prefetchVLevel == 4 issue loads; any other value just
// executes the trailing scheduling barrier.
//
// For prefetch step p the global address is v_addr + per-warp sequence offset
// + p * kBlockN * 2 bytes, and the LDS slot index is
//   prefetchVLevel == 2 : warp_id * STAGES * prefetchVLevel + p   (stage 0)
//   prefetchVLevel == 4 : warp_id * prefetchVLevel + p
// with each slot holding V_LOAD_REQUESTS * 32 * 32 elements of 2 bytes.
//
// Parameters:
//   v_addr            : resource descriptor whose low 64 bits hold the V base address.
//   v_lds             : LDS base pointer for V.
//   warp_id           : index of the calling warp.
//   v_seq_stride      : written to descriptor word 2 and used to scale the per-warp offset.
//   max_seq_kv_offset : used to derive the nm_filter field; the field is left at 0 when
//                       max_seq_kv_offset is a multiple of kBlockN.
//
// Descriptor word 3 always gets 0x20000 added (meaning not visible here -- TODO document).
// kHeadDim is not used by this function. The function only issues the loads; it does not
// wait for them.
template <int kHeadDim, int kBlockN, int WARP_K, int STAGES, int prefetchVLevel, typename Element>
__forceinline__ __device__ void kvcache_prefetch_v_to_lds(vec4_uint v_addr,
                                                           Element* v_lds,
                                                           int warp_id,
                                                           int v_seq_stride,
                                                           int max_seq_kv_offset = 0) {
    constexpr int V_LOAD_REQUESTS = (WARP_K * kBlockN) / (32 * 32);
    constexpr int elementBytes = 2;

    // prepare mls resource regs
    vec4_uint v_srsrc;
    v_srsrc[1] = v_addr[1];
    v_srsrc[2] = v_seq_stride;

    if constexpr (prefetchVLevel == 2 || prefetchVLevel == 4) {
        // Number of LDS slots reserved per warp: the 2-level layout is double-buffered by STAGES.
        constexpr int kSlotsPerWarp = (prefetchVLevel == 2) ? STAGES * prefetchVLevel : prefetchVLevel;

        // Sequence offset and nm_filter do not depend on the prefetch step: compute them once.
        const int nm_filter_max = warp_id * WARP_K + 32 - max_seq_kv_offset;
        const int real_mls_warp_id = nm_filter_max >= 32 ? 0 : warp_id;
        // global bytes along seq dimension
        const int v_seq_bytes = real_mls_warp_id * WARP_K * v_seq_stride * elementBytes;
        const int nm_filter = inline_min_max<0, 32>(real_mls_warp_id * WARP_K + 32 - max_seq_kv_offset);
        v_srsrc[3] = (max_seq_kv_offset % kBlockN == 0) ? 0 : (nm_filter << 8);
        v_srsrc[3] += 0x20000;

#pragma unroll
        for (int prefetch_id = 0; prefetch_id < prefetchVLevel; ++prefetch_id) {
            // global bytes along headdim dimension
            const int v_dim_bytes = prefetch_id * kBlockN * elementBytes;

            // acquire buffer address (overwrites descriptor words 0 and 1)
            *(uint64_t*)&v_srsrc = VA_LIMIT_BITS(*(uint64_t*)&v_addr + v_seq_bytes + v_dim_bytes);

            // lds bytes; passed by reference, so it must be a mutable lvalue
            const int v_lds_write_offset = (warp_id * kSlotsPerWarp + prefetch_id) * (V_LOAD_REQUESTS * 32 * 32);
            int v_lds_write_bytes = v_lds_write_offset * elementBytes;
            inline_matrix_load_32x32_b16_lds<0, 1>(v_lds, v_srsrc, v_lds_write_bytes, 0);
        }
    }
    __builtin_amdgcn_sched_barrier(0);
}
// Computes S = Q * K for one warp tile: zeroes the accumulators in s_reg, then for each
// kBlockK-wide slice of the head dimension reads the K tile from LDS into registers and
// accumulates with mmac_4interleave.
//
// Compile-time restrictions (enforced below): WARP_M == WARP_N == kBlockK == 32 and
// prefetchKLevel == 4.
//
// Parameters:
//   k_lds   : LDS base pointer holding the K tiles; slice k_loop of warp `warp_id` is read
//             from slot (warp_id * prefetchKLevel + k_loop).
//   q_reg   : Q registers; tile (k_loop * 2 + min_tile_m) is used for slice k_loop.
//   s_reg   : accumulators; only s_reg[0][min_tile_n * 2 + min_tile_m] is written.
//   warp_id : index of the calling warp.
//   k_addr, v_addr, v_lds, k_seq_stride, v_seq_stride, max_seq_kv_offset:
//             not used at the moment (the V prefetch call at the end is commented out).
//
// Before reading slice k_loop the function waits until at most
// (kHeadDim / kBlockK) - 1 - k_loop buffer loads are still outstanding, so the caller is
// expected to have issued the K prefetches beforehand. A final wait on LDS is performed
// before returning.
template <int kHeadDim, int kHeadDimV, int kBlockN, int kBlockK, int WARP_M, int WARP_N,
          int prefetchKLevel, int prefetchVLevel, int M_MMAC_COUNT, typename Element, typename ElementAccum>
__forceinline__ __device__ void kvcache_qk_gemm_prefetch_v(
    vec4_uint k_addr,
    vec4_uint v_addr,
    Element* k_lds,
    Element* v_lds,
    union_vec4_f16x2<Element> q_reg[(kHeadDim / kBlockK) * (WARP_M * kBlockK) / (32 * 32) * 2],
    vec4_Accum<ElementAccum> s_reg[(WARP_M / 32) * (WARP_N / 32)][4],
    int warp_id,
    int k_seq_stride,
    int v_seq_stride,
    int max_seq_kv_offset = 0) {
    // The text is passed as the static_assert message (not and-ed into the condition) so
    // that it is reported when the assertion fails.
    static_assert(WARP_M == 32 and WARP_N == 32 and kBlockK == 32,
                  "To simplify, only WARP_M = WARP_N = kBlockK = 32 is supported!");
    static_assert(prefetchKLevel == 4, "To simplify, only prefetchKLevel = 4 is supported");

    constexpr int K_LOAD_REQUESTS = (WARP_N / 32) * (kBlockK / 32);
    constexpr int elementBytes = 2;

    // alloc k_regs, 32x32 f16 per warp, and thus 16 f16 for each threads
    union_vec4_f16x2<Element> k_reg[1 * (WARP_N * kBlockK) / (32 * 32) * 2];

    // s_reg initialize
#pragma unroll
    for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
        for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
            s_reg[0][min_tile_n * 2 + min_tile_m].b64[0] = __builtin_hcu_mov_b64(0x0);
            s_reg[0][min_tile_n * 2 + min_tile_m].b64[1] = __builtin_hcu_mov_b64(0x0);
        }
    }

    // qk gemm main loop, along kheaddim dimension
    for (int k_loop = 0; k_loop < (kHeadDim / kBlockK); k_loop += 1) {
        flash::wait_buffer_data_arrived<false>((kHeadDim / kBlockK) - 1 - k_loop);

        // lds -> vgprs
        int k_lds_load_bytes = reinterpret_cast<size_t>(k_lds)
                             + (warp_id * prefetchKLevel + k_loop) * K_LOAD_REQUESTS * (32 * 32) * elementBytes;
        DS_READ_MATRIX_32X32_B16(k_lds_load_bytes, k_reg[0].f16, k_reg[1].f16, true);

        // mmac flow
#pragma unroll
        for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
            // wait until k_reg[min_tile_n] has landed (the two LDS reads complete in order)
            flash::wait_lds_data_arrived<false>(2 - 1 - min_tile_n);
#pragma unroll
            for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
#pragma unroll
                for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
                    int q_tile_id = k_loop * 2 + min_tile_m;
                    s_reg[0][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave<Element, ElementAccum>(
                        q_reg[q_tile_id].f16x4[min_tile_k],
                        k_reg[min_tile_n].f16x4[min_tile_k],
                        s_reg[0][min_tile_n * 2 + min_tile_m].f32);
                }
            }
        }
    }

    // need to reduce results on scores_max and prefetch V, and thus sync
    // can be simplified as flash::wait_all_warp_arrived()
    flash::wait_lds_data_arrived<true>(0);

    // prefetch v
    // can be rearranged while qk doing mmac
    // gfx92a::kvcache_prefetch_v_to_lds<kHeadDimV, kBlockK, kBlockK, 2/*STAGES*/, prefetchVLevel, Element>(v_addr, v_lds, warp_id, v_seq_stride, max_seq_kv_offset);
}
// Masks out-of-range key columns: every accumulator element whose column index is
// >= max_seqlen_k is overwritten with -INFINITY.
//
// Column index of element f32[v] in tile (ni, min_tile_n) for the calling lane:
//   col_idx_offset_ + (lane >> 4) + ni * 32 + min_tile_n * 16 + v * 4
// where lane = threadIdx.x & 63.
//
// Parameters:
//   tensor          : accumulator array indexed as [mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].
//   max_seqlen_k    : first invalid column index.
//   col_idx_offset_ : column index of the tile origin.
template <int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT, typename DataType>
__forceinline__ __device__ void kvcache_apply_mask(DataType tensor[M_WARP_COUNT * N_WARP_COUNT][4],
                                                    const int max_seqlen_k,
                                                    const int col_idx_offset_ = 0) {
    const int lane = threadIdx.x & 63;
    const int lane_first_col = col_idx_offset_ + (lane >> 4);

#pragma unroll
    for (int n_warp = 0; n_warp < N_WARP_COUNT; ++n_warp) {
#pragma unroll
        for (int n_tile = 0; n_tile < 2; ++n_tile) {
            const int tile_first_col = lane_first_col + n_warp * 32 + n_tile * 16;
#pragma unroll
            for (int v = 0; v < 4; ++v) {
                const int col = tile_first_col + v * 4;
                if (col < max_seqlen_k) {
                    continue;  // column is valid, keep the scores
                }
                // column is past the end: mask it for every row tile
#pragma unroll
                for (int m_warp = 0; m_warp < M_WARP_COUNT; ++m_warp) {
#pragma unroll
                    for (int m_tile = 0; m_tile < M_MMAC_COUNT; ++m_tile) {
                        tensor[m_warp + n_warp * M_WARP_COUNT][n_tile * 2 + m_tile].f32[v] = -INFINITY;
                    }
                }
            }
        }
    }
}
// Applies a causal mask: for every row, accumulator elements whose column index is
// strictly greater than that row's right limit are overwritten with -INFINITY.
//
// Indices for the calling lane (lane = threadIdx.x & 63):
//   row = row_idx_offset_ + (lane & 15) + mi * 32 + min_tile_m * 16
//   col = col_idx_offset_ + (lane >> 4) + ni * 32 + min_tile_n * 16 + v * 4
// Right limit:
//   Is_Varlen : min(max_seqlen_k, row / ngroups + max_seqlen_k - max_seqlen_q / ngroups)
//   otherwise : min(max_seqlen_k, r + max_seqlen_k - mtp) with
//               r = row % mtp when layout == 0, else row / ngroups
//
// Parameters:
//   tensor          : accumulator array indexed as [mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m].
//   col_idx_offset_ : column index of the tile origin.
//   max_seqlen_k    : key length used in the limit.
//   row_idx_offset_ : row index of the tile origin.
//   max_seqlen_q    : query length (varlen path only).
//   ngroups, mtp, layout : see the limit formulas above; ngroups (and mtp when
//                     layout == 0) are used as divisors and must be non-zero.
// NOTE(review): the comparison is `col > limit`, so when the min() selects max_seqlen_k
// the column equal to max_seqlen_k is not masked here -- confirm that this is covered by
// the non-causal mask or cannot occur for valid rows.
template <int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT, bool Is_Varlen, typename DataType>
__forceinline__ __device__ void kvcache_apply_mask_causal(DataType tensor[M_WARP_COUNT * N_WARP_COUNT][4],
                                                           const int col_idx_offset_,
                                                           const int max_seqlen_k,
                                                           const int row_idx_offset_,
                                                           const int max_seqlen_q,
                                                           const int ngroups,
                                                           const int mtp,
                                                           const int layout) {
    const int lane = threadIdx.x & 63;
    const int lane_first_row = row_idx_offset_ + (lane & 15);
    const int lane_first_col = col_idx_offset_ + (lane >> 4);

#pragma unroll
    for (int m_warp = 0; m_warp < M_WARP_COUNT; ++m_warp) {
#pragma unroll
        for (int m_tile = 0; m_tile < M_MMAC_COUNT; ++m_tile) {
            const int row = lane_first_row + m_warp * 32 + m_tile * 16;

            // last column this row may attend to
            int right_limit;
            if constexpr (Is_Varlen) {
                right_limit = std::min(max_seqlen_k,
                                       (row / ngroups) /*only for layout 1: bshd*/ + max_seqlen_k - (max_seqlen_q / ngroups));
            } else {
                int row_in_mtp;
                if (layout == 0) {
                    row_in_mtp = row % mtp;
                } else {
                    row_in_mtp = row / ngroups;
                }
                right_limit = std::min(max_seqlen_k, row_in_mtp + max_seqlen_k - mtp);
            }

#pragma unroll
            for (int n_warp = 0; n_warp < N_WARP_COUNT; ++n_warp) {
#pragma unroll
                for (int n_tile = 0; n_tile < 2; ++n_tile) {
                    const int tile_first_col = lane_first_col + n_warp * 32 + n_tile * 16;
#pragma unroll
                    for (int v = 0; v < 4; ++v) {
                        const int col = tile_first_col + v * 4;
                        if (col > right_limit) {
                            tensor[m_warp + n_warp * M_WARP_COUNT][n_tile * 2 + m_tile].f32[v] = -INFINITY;
                        }
                    }
                }
            }
        }
    }
}
// Converts attention scores from fp32 pairs to Element pairs.
//
// For every M tile (min_tile_m < M_MMAC_COUNT), both N tiles (0 and 1) and both pair
// slots (0 and 1), s_reg[0][n * 2 + m].f32x2[k] is down-converted with
// DownCastPair<float, Element> and stored in p_reg[0][n * 2 + m].f16x2[k].
//
// Parameters:
//   s_reg : fp32 source registers.
//   p_reg : destination registers.
// ElementAccum is not used by this function.
template <int M_MMAC_COUNT, typename Element, typename ElementAccum>
__forceinline__ __device__ void convert_attn_f32_to_f16(union_vec4_fp32 s_reg[1][4],
                                                         union_vec2_f16x2<Element> p_reg[1][4]) {
#pragma unroll
    for (int m_tile = 0; m_tile < M_MMAC_COUNT; ++m_tile) {
#pragma unroll
        for (int pair = 0; pair < 2; ++pair) {
#pragma unroll
            for (int n_tile = 0; n_tile < 2; ++n_tile) {
                const int slot = n_tile * 2 + m_tile;
                p_reg[0][slot].f16x2[pair] = DownCastPair<float, Element>(s_reg[0][slot].f32x2[pair]);
            }
        }
    }
}
template
<
bool
prefetchK
,
int
K_LOOP_COUNT
,
int
kBlockN
,
int
kBlockK
,
int
M_WARP_COUNT
,
int
PV_N_WARP_COUNT
,
int
PV_K_WARP_COUNT
,
int
STAGES
,
int
prefetchKLevel
,
int
prefetchVLevel
,
int
M_MMAC_COUNT
,
typename
Element
,
typename
ElementAccum
>
__forceinline__
__device__
void
kvcache_pv_gemm_prefetch_k
(
vec4_uint
v_addr
,
vec4_uint
k_addr
,
Element
*
v_lds
,
Element
*
k_lds
,
union_vec2_f16x2
<
Element
>
p_reg
[
M_WARP_COUNT
*
PV_K_WARP_COUNT
][
4
],
vec4_Accum
<
ElementAccum
>
pv_reg
[
K_LOOP_COUNT
*
M_WARP_COUNT
*
(
kBlockN
/
32
)][
4
],
int
warp_id
,
int
v_seq_stride
,
int
k_seq_stride
,
int
max_seq_kv_offset
=
0
)
{
constexpr
int
WARP_K
=
PV_K_WARP_COUNT
*
32
;
static_assert
(
kBlockK
>=
32
,
"Error: pv gemm kBlockK must be equal or greater than 32"
);
static_assert
(
kBlockN
==
PV_N_WARP_COUNT
*
32
,
"Error: kBlockN in kvcache_pv_gemm_prefetch_k must be WARP_N * 32"
);
static_assert
(
M_WARP_COUNT
==
1
,
"for gfx938, only WARP_M = 32 is supported yet!"
);
static_assert
(
PV_N_WARP_COUNT
==
1
,
"for gfx938, only WARP_N = 32 is supported yet!"
);
static_assert
(
PV_K_WARP_COUNT
==
1
,
"for gfx938, only WARP_K = 32 is supported yet!"
);
constexpr
int
V_LOAD_REQUESTS
=
(
WARP_K
*
kBlockN
)
/
(
32
*
32
);
constexpr
int
elementBytes
=
2
;
// sync lds usage for reducing max/sum
flash
::
wait_lds_data_arrived
<
true
>
(
0
);
// __syncthreads();
if
constexpr
(
prefetchVLevel
==
2
)
{
// hold v regs
union_vec4_f16x2
<
Element
>
v_reg
[
1
*
PV_K_WARP_COUNT
*
PV_N_WARP_COUNT
*
2
];
// prepare v resource regs
vec4_uint
v_srsrc
;
v_srsrc
[
1
]
=
v_addr
[
1
];
v_srsrc
[
2
]
=
v_seq_stride
;
// pingpong stage
int
stage_id
=
(
STAGES
==
2
)
?
1
:
0
;
// make p 4-interleave layout for pv gemm
// strange: delete wait, results are wrong even if flash::wait_lds_data_arrived<false>(0);
ds_mpermute_kdim_for_mmac
(
p_reg
[
0
][
0
].
f16x4
);
ds_mpermute_kdim_for_mmac
(
p_reg
[
0
][
1
].
f16x4
);
ds_mpermute_kdim_for_mmac
(
p_reg
[
0
][
2
].
f16x4
);
ds_mpermute_kdim_for_mmac
(
p_reg
[
0
][
3
].
f16x4
);
// pv gemm main loop
constexpr
int
N_LOOP_STEP
=
(
STAGES
==
2
)
?
prefetchVLevel
:
1
;
constexpr
int
N_LOOP_START
=
(
STAGES
==
2
)
?
N_LOOP_STEP
:
1
;
for
(
int
n_loop
=
N_LOOP_START
;
n_loop
<
K_LOOP_COUNT
;
n_loop
+=
N_LOOP_STEP
)
{
#pragma unroll
for
(
int
prefetch_id
=
0
;
prefetch_id
<
prefetchVLevel
;
++
prefetch_id
)
{
// global bytes along headdim dimension
int
v_dim_bytes
=
(
n_loop
+
prefetch_id
)
*
kBlockN
*
elementBytes
;
// global bytes along seq dimension
int
v_seq_bytes
;
if
constexpr
(
true
)
{
int
nm_filter_max
=
warp_id
*
WARP_K
+
32
-
max_seq_kv_offset
;
int
real_mls_warp_id
=
nm_filter_max
>=
32
?
0
:
warp_id
;
// can be simplified after gfx938
v_seq_bytes
=
real_mls_warp_id
*
WARP_K
*
v_seq_stride
*
elementBytes
;
int
nm_filter
=
inline_min_max
<
0
,
32
>
(
real_mls_warp_id
*
WARP_K
+
32
-
max_seq_kv_offset
);
v_srsrc
[
3
]
=
max_seq_kv_offset
%
kBlockN
==
0
?
0
:
nm_filter
<<
8
;
v_srsrc
[
3
]
+=
0x20000
;
}
*
(
uint64_t
*
)
&
v_srsrc
=
VA_LIMIT_BITS
(
*
(
uint64_t
*
)
&
v_addr
+
v_seq_bytes
+
v_dim_bytes
);
// lds write bytes
int
v_lds_write_offset
=
(
warp_id
*
STAGES
*
prefetchVLevel
+
stage_id
*
prefetchVLevel
+
prefetch_id
)
*
(
V_LOAD_REQUESTS
*
32
*
32
);
int
v_lds_write_bytes
=
v_lds_write_offset
*
elementBytes
;
inline_matrix_load_32x32_b16_lds
<
0
,
1
>
(
v_lds
,
v_srsrc
,
v_lds_write_bytes
,
0
);
}
// wait v data stored in lds
if
constexpr
(
N_LOOP_STEP
==
2
)
{
flash
::
wait_buffer_data_arrived
<
false
>
((
prefetchVLevel
+
prefetchVLevel
-
1
)
*
V_LOAD_REQUESTS
);
}
else
if
constexpr
(
N_LOOP_STEP
==
1
and
STAGES
==
2
)
{
flash
::
wait_buffer_data_arrived
<
false
>
(
1
*
V_LOAD_REQUESTS
);
}
else
if
constexpr
(
N_LOOP_STEP
==
1
and
STAGES
==
1
)
{
flash
::
wait_buffer_data_arrived
<
false
>
(
0
);
}
// roll stage
if
constexpr
(
STAGES
==
2
)
{
stage_id
^=
1
;
}
// lds -> vgprs
int
v_lds_load_bytes
=
reinterpret_cast
<
size_t
>
(
v_lds
)
+
(
warp_id
*
STAGES
*
prefetchVLevel
+
stage_id
*
prefetchVLevel
+
0
)
*
(
V_LOAD_REQUESTS
*
32
*
32
)
*
elementBytes
;
DS_READ_MATRIX_32X32_B16_ALT2
(
v_lds_load_bytes
,
v_reg
[
0
].
f16
,
v_reg
[
1
].
f16
,
false
);
// pv mmac flow
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
flash
::
wait_lds_data_arrived
<
false
>
(
2
-
1
-
min_tile_k
);
int
pv_tile_id
=
(
STAGES
==
2
)
?
n_loop
-
2
:
n_loop
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
M_MMAC_COUNT
;
++
min_tile_m
)
{
pv_reg
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
=
flash
::
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
0
][
min_tile_k
*
2
+
min_tile_m
].
f16x4
,
v_reg
[
min_tile_k
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
);
}
}
}
// process second tile of pv gemm
if
constexpr
(
prefetchVLevel
==
2
)
{
flash
::
wait_buffer_data_arrived
<
false
>
(
prefetchVLevel
*
V_LOAD_REQUESTS
);
// lds -> vgprs
int
v_lds_load_bytes
=
reinterpret_cast
<
size_t
>
(
v_lds
)
+
(
warp_id
*
STAGES
*
prefetchVLevel
+
stage_id
*
prefetchVLevel
+
1
/*prefetch_id*/
)
*
(
V_LOAD_REQUESTS
*
32
*
32
)
*
elementBytes
;
DS_READ_MATRIX_32X32_B16_ALT2
(
v_lds_load_bytes
,
v_reg
[
0
].
f16
,
v_reg
[
1
].
f16
,
false
);
// pv gemm mmac flow
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
flash
::
wait_lds_data_arrived
<
false
>
(
2
-
1
-
min_tile_k
);
int
pv_tile_id
=
(
STAGES
==
2
)
?
n_loop
-
1
:
n_loop
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
M_MMAC_COUNT
;
++
min_tile_m
)
{
pv_reg
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
=
flash
::
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
0
][
min_tile_k
*
2
+
min_tile_m
].
f16x4
,
v_reg
[
min_tile_k
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
);
}
}
}
}
}
if
constexpr
(
STAGES
==
2
)
{
int
n_loop
=
K_LOOP_COUNT
;
// wait v stored in lds
flash
::
wait_buffer_data_arrived
<
false
>
((
prefetchVLevel
-
1
)
*
V_LOAD_REQUESTS
);
// roll stage
stage_id
^=
1
;
// lds -> vgprs
int
v_lds_load_bytes
=
reinterpret_cast
<
size_t
>
(
v_lds
)
+
(
warp_id
*
STAGES
*
prefetchVLevel
+
stage_id
*
prefetchVLevel
)
*
(
V_LOAD_REQUESTS
*
32
*
32
)
*
elementBytes
;
DS_READ_MATRIX_32X32_B16_ALT2
(
v_lds_load_bytes
,
v_reg
[
0
].
f16
,
v_reg
[
1
].
f16
,
false
);
// pv gemm mmac flow
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
flash
::
wait_lds_data_arrived
<
false
>
(
2
-
1
-
min_tile_k
);
int
pv_tile_id
=
n_loop
-
2
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
M_MMAC_COUNT
;
++
min_tile_m
)
{
pv_reg
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
=
flash
::
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
0
][
min_tile_k
*
2
+
min_tile_m
].
f16x4
,
v_reg
[
min_tile_k
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
);
}
}
}
// process second tile of pv gemm
if
constexpr
(
N_LOOP_STEP
==
2
)
{
flash
::
wait_buffer_data_arrived
<
false
>
(
0
);
// lds -> vgprs
int
v_lds_load_bytes
=
reinterpret_cast
<
size_t
>
(
v_lds
)
+
(
warp_id
*
STAGES
*
prefetchVLevel
+
stage_id
*
prefetchVLevel
+
1
/*prefetch_id*/
)
*
(
V_LOAD_REQUESTS
*
32
*
32
)
*
elementBytes
;
DS_READ_MATRIX_32X32_B16_ALT2
(
v_lds_load_bytes
,
v_reg
[
0
].
f16
,
v_reg
[
1
].
f16
,
false
);
// pv gemm mmac flow
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
flash
::
wait_lds_data_arrived
<
false
>
(
2
-
1
-
min_tile_k
);
int
pv_tile_id
=
n_loop
-
1
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
M_MMAC_COUNT
;
++
min_tile_m
)
{
pv_reg
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
=
flash
::
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
0
][
min_tile_k
*
2
+
min_tile_m
].
f16x4
,
v_reg
[
min_tile_k
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
);
}
}
}
}
}
}
else
if
constexpr
(
prefetchVLevel
==
4
)
{
bool
can_prefetch_k
=
max_seq_kv_offset
>
kBlockK
;
if
constexpr
(
prefetchK
)
{
if
(
can_prefetch_k
)
{
gfx92a
::
kvcache_prefetch_k_to_lds
<
kBlockN
,
PV_N_WARP_COUNT
*
32
,
prefetchKLevel
,
Element
>
(
k_addr
,
k_lds
,
warp_id
,
k_seq_stride
,
max_seq_kv_offset
-
kBlockK
);
}
}
// hold v regs
union_vec4_f16x2
<
Element
>
v_reg
[
1
*
PV_K_WARP_COUNT
*
PV_N_WARP_COUNT
*
2
];
// prepare v resource regs
vec4_uint
v_srsrc
;
v_srsrc
[
1
]
=
v_addr
[
1
];
v_srsrc
[
2
]
=
v_seq_stride
;
// make p 4-interleave layout for pv gemm
// strange: delete wait, results are wrong even if flash::wait_lds_data_arrived<false>(0);
ds_mpermute_kdim_for_mmac
(
p_reg
[
0
][
0
].
f16x4
);
ds_mpermute_kdim_for_mmac
(
p_reg
[
0
][
1
].
f16x4
);
ds_mpermute_kdim_for_mmac
(
p_reg
[
0
][
2
].
f16x4
);
ds_mpermute_kdim_for_mmac
(
p_reg
[
0
][
3
].
f16x4
);
// wait v data stored in lds
if
constexpr
(
prefetchK
)
{
if
(
can_prefetch_k
)
{
flash
::
wait_buffer_data_arrived
<
false
>
(
prefetchKLevel
/*4 for hdim 128*/
);
}
else
{
flash
::
wait_buffer_data_arrived
<
false
>
(
0
);
}
}
else
{
flash
::
wait_buffer_data_arrived
<
false
>
(
0
);
}
// pv gemm main loop
for
(
int
n_loop
=
0
;
n_loop
<
K_LOOP_COUNT
;
n_loop
+=
1
)
{
// lds -> vgprs
int
v_lds_load_bytes
=
reinterpret_cast
<
size_t
>
(
v_lds
)
+
(
warp_id
*
prefetchVLevel
+
n_loop
)
*
(
V_LOAD_REQUESTS
*
32
*
32
)
*
elementBytes
;
DS_READ_MATRIX_32X32_B16_ALT2
(
v_lds_load_bytes
,
v_reg
[
0
].
f16
,
v_reg
[
1
].
f16
,
false
);
// pv mmac flow
#pragma unroll
for
(
int
min_tile_k
=
0
;
min_tile_k
<
2
;
++
min_tile_k
)
{
flash
::
wait_lds_data_arrived
<
false
>
(
2
-
1
-
min_tile_k
);
int
pv_tile_id
=
n_loop
;
#pragma unroll
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
#pragma unroll
for
(
int
min_tile_m
=
0
;
min_tile_m
<
M_MMAC_COUNT
;
++
min_tile_m
)
{
pv_reg
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
=
flash
::
mmac_4interleave
<
Element
,
ElementAccum
>
(
p_reg
[
0
][
min_tile_k
*
2
+
min_tile_m
].
f16x4
,
v_reg
[
min_tile_k
].
f16x4
[
min_tile_n
],
pv_reg
[
pv_tile_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
);
}
}
}
}
}
// sync lds usage
flash
::
wait_lds_data_arrived
<
true
>
(
0
);
}
// Sums the output accumulators of all warps through the scratch buffer
// `acc_o_lds`.
//
// For each of the two `min_tile_n` halves the function
//   1. writes every one of the caller's K_LOOP_COUNT tiles into the caller
//      warp's own region of the scratch buffer,
//   2. waits on a block barrier,
//   3. lets warp `w` read tile `w` back from the regions of warps 0..3 and
//      add them together into acc_o[0][...],
//   4. waits on a second barrier so the scratch buffer can be reused.
//
// Only K_LOOP_COUNT == 4, WARP_NUM == 4, K_WARP_COUNT == 1 is implemented;
// any other instantiation is an empty function.
//
// Parameters:
//   acc_o     accumulator fragments of the calling thread; acc_o[0][*]
//             receives the reduced values.
//   acc_o_lds scratch buffer, indexed in units of ElementAccum.
//             NOTE(review): with lane_id < 64 and M_MMAC_COUNT <= 2 the index
//             formula reaches up to WARP_NUM * 2048 elements - confirm the
//             allocation matches.
//   seqlen_q  not read by this function.
//   warp_id   index of the calling warp (selects the scratch region and the
//             tile this warp reduces).
//   lane_id   index of the calling lane (each lane owns 4 consecutive
//             elements inside a 16x32 tile).
//
// Both barriers are __syncthreads(), so every thread of the block has to call
// this function.
template <int K_LOOP_COUNT, int K_WARP_COUNT, int M_WARP_COUNT, int M_MMAC_COUNT, int WARP_NUM, typename ElementAccum>
__forceinline__ __device__ void kvcache_acco_reduce_tile16x32(
    vec4_Accum<ElementAccum> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
    ElementAccum *acc_o_lds,
    int seqlen_q,
    int warp_id,
    int lane_id)
{
    constexpr bool kImplemented = (K_LOOP_COUNT == 4) and (WARP_NUM == 4) and (K_WARP_COUNT == 1);

    if constexpr (kImplemented)
    {
        // Scratch layout, in ElementAccum elements.
        constexpr int kElemsPerLane = 4;
        constexpr int kElemsPerTile16x32 = 64 * kElemsPerLane;
        constexpr int kElemsPerTile32x32 = 2 * kElemsPerTile16x32;
        constexpr int kElemsPerWarp = WARP_NUM * kElemsPerTile32x32;

#pragma unroll
        for (int half_n = 0; half_n < 2; ++half_n)
        {
            // Step 1: publish this warp's partial tiles.
#pragma unroll
            for (int tile = 0; tile < K_LOOP_COUNT; ++tile)
            {
#pragma unroll
                for (int half_m = 0; half_m < M_MMAC_COUNT; ++half_m)
                {
                    int store_at = warp_id * kElemsPerWarp + tile * kElemsPerTile32x32 + half_m * kElemsPerTile16x32 + lane_id * kElemsPerLane;
                    *(vec4_fp32 *)(acc_o_lds + store_at) = acc_o[tile][half_n * 2 + half_m].f32;
                }
            }
            __syncthreads();

            // Step 2: warp `warp_id` gathers tile `warp_id` from every warp.
#pragma unroll
            for (int half_m = 0; half_m < M_MMAC_COUNT; ++half_m)
            {
                // Address of this lane's 4 elements inside warp 0's region.
                ElementAccum *gather_from = acc_o_lds + 0 * kElemsPerWarp + warp_id /*tile index*/ * kElemsPerTile32x32 + half_m * kElemsPerTile16x32 + lane_id * kElemsPerLane;

                // Warp 0's contribution is the starting value.
                acc_o[0][half_n * 2 + half_m].f32 = *(vec4_fp32 *)(gather_from + 0 * kElemsPerWarp);

                // Contributions of warps 1, 2 and 3.
                auto from_warp1 = *(union_vec4_fp32 *)(gather_from + 1 * kElemsPerWarp);
                auto from_warp2 = *(union_vec4_fp32 *)(gather_from + 2 * kElemsPerWarp);
                auto from_warp3 = *(union_vec4_fp32 *)(gather_from + 3 * kElemsPerWarp);

                // Packed add: each u64 carries two fp32 values.
#pragma unroll
                for (int pair = 0; pair < 2; ++pair)
                {
                    acc_o[0][half_n * 2 + half_m].u64[pair] = __builtin_hcu_pk_add_f32(acc_o[0][half_n * 2 + half_m].u64[pair], from_warp1.u64[pair]);
                    acc_o[0][half_n * 2 + half_m].u64[pair] = __builtin_hcu_pk_add_f32(acc_o[0][half_n * 2 + half_m].u64[pair], from_warp2.u64[pair]);
                    acc_o[0][half_n * 2 + half_m].u64[pair] = __builtin_hcu_pk_add_f32(acc_o[0][half_n * 2 + half_m].u64[pair], from_warp3.u64[pair]);
                }
            }
            // Scratch is overwritten by the next half_n iteration.
            __syncthreads();
        }
    }
    else
    {
        // To be implemented
    }
}
// Converts the reduced fp32 accumulators of tile 0 to SplitkvAccumType and
// stores them to the output tensor, 16 bytes per lane and row.
//
// Only K_LOOP_COUNT == 4 and WARP_NUM == 4 is implemented; other
// instantiations store nothing.
//
// Parameters:
//   acc_o          accumulators; only acc_o[0][0..3] is read.
//   params         kernel parameters; reads o_row_stride, o_ptr / oaccum_ptr
//                  and, for Is_Varlen, ngroups and o_head_stride.
//   row_offset_o   element offset added to the output base pointer.
//   seqlen_q_limit rows with index >= this value are not stored.
//   warp_id        selects the column block `warp_id * kBlockK`.
//   lane_id        low 4 bits select the row inside a 16-row tile, the
//                  remaining bits select an 8-element column group.
//
// Output pointer: `params.oaccum_ptr` when Split is true, otherwise
// `params.o_ptr`, reinterpreted as SplitkvAccumType*.
template <bool Is_Varlen, bool Split, int kBlockK, int WARP_NUM, int K_LOOP_COUNT, int M_MMAC_COUNT,
          typename SplitkvAccumType, typename ElementAccum, typename Params>
__forceinline__ __device__ void kvcache_varlen_epilogue_store_output(
    vec4_Accum<ElementAccum> acc_o[K_LOOP_COUNT][4],
    Params params,
    int64_t row_offset_o,
    int seqlen_q_limit,
    int warp_id,
    int lane_id)
{
    int lane_row = lane_id & 15;
    int lane_col = lane_id >> 4;
    int row_stride = params.o_row_stride;
    SplitkvAccumType *out_base =
        reinterpret_cast<SplitkvAccumType *>(Split ? params.oaccum_ptr : params.o_ptr) + row_offset_o;

    if constexpr (K_LOOP_COUNT == 4 and WARP_NUM == 4)
    {
        // With K_LOOP_COUNT == WARP_NUM this loop body runs exactly once.
#pragma unroll
        for (int pass = 0; pass < K_LOOP_COUNT; pass += WARP_NUM /*1*/)
        {
            int src_tile = 0 /*pass*/;

            // fp32 -> SplitkvAccumType, two values per packed element.
            union_vec4_f16x2<SplitkvAccumType> packed[M_MMAC_COUNT];
#pragma unroll
            for (int half_m = 0; half_m < M_MMAC_COUNT; ++half_m)
            {
                // 2-interleave: `half_n` picks the accumulator, `pair` picks
                // which two of its four fp32 values are converted.
#pragma unroll
                for (int half_n = 0; half_n < 2; ++half_n)
                {
#pragma unroll
                    for (int pair = 0; pair < 2; ++pair)
                    {
                        packed[half_m].f16x2[pair + half_n * 2] = DownCastPairNoPack<ElementAccum, SplitkvAccumType>(
                            acc_o[src_tile][half_m + half_n * 2].f32[pair * 2],
                            acc_o[src_tile][half_m + half_n * 2].f32[pair * 2 + 1]);
                    }
                }
                // make 4-interleave
                ds_mpermute_kdim_for_mmac(packed[half_m].f16x4[0]);
                ds_mpermute_kdim_for_mmac(packed[half_m].f16x4[1]);
            }

            // Re-order the eight 16-bit values so that elements i and i + 4
            // of the permuted data become neighbours.
            union_vec4_f16x2<SplitkvAccumType> row_chunk[M_MMAC_COUNT];
#pragma unroll
            for (int half_m = 0; half_m < M_MMAC_COUNT; ++half_m)
            {
                // Two permutes were issued per half_m above; presumably this
                // waits until the ones belonging to `half_m` have completed.
                flash::wait_lds_data_arrived<false>((M_MMAC_COUNT - 1 - half_m) * 2);
#pragma unroll
                for (int i = 0; i < 4; ++i)
                {
                    row_chunk[half_m].f16[2 * i] = packed[half_m].f16[i];
                    row_chunk[half_m].f16[2 * i + 1] = packed[half_m].f16[i + 4];
                }
            }

            // store 4 dwords into global memory
#pragma unroll
            for (int half_m = 0; half_m < M_MMAC_COUNT; ++half_m)
            {
                int q_row = lane_row + half_m * 16;
                if (q_row < seqlen_q_limit)
                {
                    int elem_offset;
                    if constexpr (Is_Varlen)
                    {
                        // Rows are laid out as (query position, group).
                        int q_pos = q_row / params.ngroups;
                        int group = q_row % params.ngroups;
                        elem_offset = q_pos * params.ngroups * row_stride + group * params.o_head_stride + (warp_id + 0) * kBlockK + lane_col * 8;
                    }
                    else
                    {
                        elem_offset = q_row * row_stride + (warp_id + 0) * kBlockK + lane_col * 8;
                    }
                    *(vec4_fp32 *)(out_base + elem_offset) = row_chunk[half_m].f32;
                }
            }
        }
    }
    else
    {
        // To be implemented
    }
}
}
// end of namespace gfx92a
\ No newline at end of file
csrc/flash_attn_hg/include/kvcache/gfx938/kvcache_epilogue_gfx938.h
View file @
518a5f4d
#pragma once
#include "numeric_types.h"
#include "intrinsic.h"
// Reads the attention-sink value of one head and returns it as float.
//
// s_aux_ptr  : untyped pointer to the per-head sink array.
// s_aux_type : element type selector - 1 reads float, 2 reads half_t,
//              any other value reads BFloat16.
// head_idx   : index into the array.
__forceinline__ __device__ float fp8_kvcache_attention_sink_load(const void *s_aux_ptr, int s_aux_type, int head_idx)
{
    switch (s_aux_type)
    {
    case 1:
        return reinterpret_cast<const float *>(s_aux_ptr)[head_idx];
    case 2:
        return UpCast<half_t, float>(reinterpret_cast<const half_t *>(s_aux_ptr)[head_idx]);
    default:
        return UpCast<BFloat16, float>(reinterpret_cast<const BFloat16 *>(s_aux_ptr)[head_idx]);
    }
}
// Folds the per-head attention-sink value into the softmax statistics and
// rescales the output accumulators accordingly.
//
// For every row owned by the calling lane:
//   new_max   = max(old_max * scale_softmax, sink)
//   rescale   = exp(old_max * scale_softmax - new_max)
//   sum       = sum * rescale + exp(sink - new_max)
//   max       = new_max / scale_softmax
//   acc_o    *= rescale            (all tiles belonging to that row)
//
// Parameters:
//   acc_o             output accumulators, rescaled in place.
//   scores_max        running row maxima (stored unscaled), updated in place.
//   scores_sum        running row sums, updated in place.
//   s_aux_ptr/type    sink array and its element type, forwarded to
//                     fp8_kvcache_attention_sink_load.
//   bidh, ngroups     the sink index is bidh * ngroups + (row % ngroups).
//   reduced_num_heads not read by this function.
//   m_block, kBlockM  row of a lane is m_block * kBlockM + mi * 32 +
//                     (lane_id & 15) + min_tile_m * 16.
//   lane_id           lane index; only its low 4 bits are used.
//   scale_softmax     factor applied to the stored maxima.
template <int K_LOOP_COUNT, int M_WARP_COUNT, int K_WARP_COUNT, int M_MMAC_COUNT, typename ElementAccum>
__forceinline__ __device__ void fp8_kvcache_apply_attention_sink_gfx938(
    vec4_Accum<ElementAccum> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
    vec2_Accum<ElementAccum> scores_max[M_WARP_COUNT],
    vec2_Accum<ElementAccum> scores_sum[M_WARP_COUNT],
    const void *s_aux_ptr,
    int s_aux_type,
    int bidh,
    int reduced_num_heads,
    int ngroups,
    int m_block,
    int kBlockM,
    int lane_id,
    ElementAccum scale_softmax)
{
#pragma unroll
    for (int mi = 0; mi < M_WARP_COUNT; ++mi)
    {
#pragma unroll
        for (int half_m = 0; half_m < M_MMAC_COUNT; ++half_m)
        {
            // Which head's sink applies to this row.
            const int q_row = m_block * kBlockM + mi * 32 + (lane_id & 15) + half_m * 16;
            const int head = bidh * ngroups + (q_row % ngroups);
            const ElementAccum sink = fp8_kvcache_attention_sink_load(s_aux_ptr, s_aux_type, head);

            // Online-softmax update with the sink as one extra logit.
            const ElementAccum prev_max = scores_max[mi].f32[half_m] * scale_softmax;
            const ElementAccum next_max = max(prev_max, sink);
            const ElementAccum shrink = __expf(prev_max - next_max);

            scores_sum[mi].f32[half_m] = scores_sum[mi].f32[half_m] * shrink + __expf(sink - next_max);
            scores_max[mi].f32[half_m] = next_max / scale_softmax;

            // Same factor in both lanes of the packed multiply.
            __float2 shrink_x2 = {shrink, shrink};

#pragma unroll
            for (int n_step = 0; n_step < K_LOOP_COUNT; ++n_step)
            {
#pragma unroll
                for (int ni = 0; ni < K_WARP_COUNT; ++ni)
                {
                    const int tile = n_step * M_WARP_COUNT * K_WARP_COUNT + (ni * M_WARP_COUNT + mi);
#pragma unroll
                    for (int half_n = 0; half_n < 2; ++half_n)
                    {
                        const int frag = half_n * 2 + half_m;
#pragma unroll
                        for (int pair = 0; pair < 2; ++pair)
                        {
                            acc_o[tile][frag].u64[pair] = __builtin_hcu_pk_mul_f32(acc_o[tile][frag].u64[pair], shrink_x2);
                        }
                    }
                }
            }
        }
    }
}
template
<
typename
Params
,
int
kHeadDimV
,
int
kHeadDimVSplit
,
bool
Interleave2
,
bool
Split
,
typename
SplitkvAccumType
,
typename
ElementAccum
,
int
kBlockM
,
int
kBlockK
,
int
WARP_NUM
,
int
K_LOOP_COUNT
,
int
M_WARP_COUNT
,
int
K_WARP_COUNT
,
int
M_MMAC_COUNT
>
__forceinline__
__device__
void
kvcache_epilogue_store_output_gfx938
(
...
...
@@ -20,8 +79,9 @@ __forceinline__ __device__ void kvcache_epilogue_store_output_gfx938(
:
reinterpret_cast
<
SplitkvAccumType
*>
(
params
.
o_ptr
)
+
row_offset_o
;
int
pv_lane_seq_idx
=
lane_id
&
15
;
int
pv_lane_head_dim_idx
=
lane_id
>>
4
;
// Specialized optimizatio for headdim 128
constexpr
int
OPT_FOR_HDIM128
=
bool
(
WARP_NUM
==
4
and
M_MMAC_COUNT
==
1
);
// Specialized optimization for headdim 128. Dim256 is split into two
// 128-column stores so it can use the same layout per split.
constexpr
int
OPT_FOR_HDIM128
=
bool
(
WARP_NUM
==
4
and
M_MMAC_COUNT
==
1
and
K_LOOP_COUNT
==
WARP_NUM
);
if
constexpr
(
not
OPT_FOR_HDIM128
)
{
if
(
warp_id
>
0
)
return
;
}
...
...
@@ -90,8 +150,9 @@ __forceinline__ __device__ void kvcache_varlen_epilogue_store_output_gfx938(
auto
gO
=
prepare_for_buffer_load
<
kHeadDimV
,
SplitkvAccumType
,
false
/*USE_CACHE_SWIZZLE*/
>
(
o_ptr
);
int
pv_lane_seq_idx
=
lane_id
&
15
;
int
pv_lane_head_dim_idx
=
lane_id
>>
4
;
// Specialized optimizatio for headdim 128
constexpr
int
OPT_FOR_HDIM128
=
bool
(
WARP_NUM
==
4
and
M_MMAC_COUNT
==
1
);
// Specialized optimization for headdim 128. Dim256 is split into two
// 128-column stores so it can use the same layout per split.
constexpr
int
OPT_FOR_HDIM128
=
bool
(
WARP_NUM
==
4
and
M_MMAC_COUNT
==
1
and
K_LOOP_COUNT
==
WARP_NUM
);
if
constexpr
(
not
OPT_FOR_HDIM128
)
{
if
(
warp_id
>
0
)
return
;
}
...
...
@@ -124,3 +185,39 @@ __forceinline__ __device__ void kvcache_varlen_epilogue_store_output_gfx938(
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// FP8 MLS Paged Attention epilogue helpers, >= gfx938
////////////////////////////////////////////////////////////////////////////////////////////////////
// Final normalisation of the output accumulators: every element is multiplied
// by v_descale / row_sum. When the row sum is zero or NaN the division is
// skipped and the element is multiplied by v_descale only.
//
// Parameters:
//   acc_o      output accumulators, scaled in place.
//   scores_sum softmax row sums; scores_sum[mi].f32[min_tile_m] is the sum of
//              the row that fragments (mi, min_tile_m) belong to.
//   v_descale  scalar factor applied to every element.
template <int K_LOOP_COUNT, int M_WARP_COUNT, int K_WARP_COUNT, int M_MMAC_COUNT, typename ElementAccum>
__forceinline__ __device__ void fp8_kvcache_epilogue_rescale_acco_gfx938(
    vec4_Accum<ElementAccum> acc_o[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
    vec2_Accum<ElementAccum> scores_sum[M_WARP_COUNT],
    ElementAccum v_descale)
{
#pragma unroll
    for (int n_step = 0; n_step < K_LOOP_COUNT; ++n_step)
    {
#pragma unroll
        for (int mi = 0; mi < M_WARP_COUNT; ++mi)
        {
#pragma unroll
            for (int ni = 0; ni < K_WARP_COUNT; ++ni)
            {
                const int tile = n_step * M_WARP_COUNT * K_WARP_COUNT + (ni * M_WARP_COUNT + mi);
#pragma unroll
                for (int half_m = 0; half_m < M_MMAC_COUNT; ++half_m)
                {
                    ElementAccum row_sum = scores_sum[mi].f32[half_m];

                    // `row_sum != row_sum` is true only for NaN.
                    ElementAccum factor;
                    if (row_sum == 0.f || row_sum != row_sum)
                    {
                        factor = v_descale;
                    }
                    else
                    {
                        factor = v_descale / row_sum;
                    }
                    __float2 factor_x2 = {factor, factor};

#pragma unroll
                    for (int half_n = 0; half_n < 2; ++half_n)
                    {
                        const int frag = half_n * 2 + half_m;
#pragma unroll
                        for (int pair = 0; pair < 2; ++pair)
                        {
                            acc_o[tile][frag].u64[pair] = __builtin_hcu_pk_mul_f32(acc_o[tile][frag].u64[pair], factor_x2);
                        }
                    }
                }
            }
        }
    }
}
csrc/flash_attn_hg/include/kvcache/gfx938/kvcache_pv_gemm_prefetch_k_gfx938.h
View file @
518a5f4d
...
...
@@ -42,6 +42,13 @@ __forceinline__ __device__ void kvcache_pv_gemm_prefetch_k_gfx938(
constexpr
int
N_LOOP_END
=
0
;
for
(
int
n_loop
=
N_LOOP_START
;
n_loop
>=
N_LOOP_END
;
n_loop
-=
N_LOOP_STEP
)
{
#if defined(__gfx92a__)
ds_mpermute_kdim_for_mmac_wait
(
p_reg
[
0
][
0
*
2
+
0
].
f16x4
);
ds_mpermute_kdim_for_mmac_wait
(
p_reg
[
0
][
0
*
2
+
1
].
f16x4
);
ds_mpermute_kdim_for_mmac_wait
(
p_reg
[
0
][
1
*
2
+
0
].
f16x4
);
ds_mpermute_kdim_for_mmac_wait
(
p_reg
[
0
][
1
*
2
+
1
].
f16x4
);
#endif
#pragma unroll
for
(
int
prefetch_id
=
0
;
prefetch_id
<
N_LOOP_STEP
;
++
prefetch_id
)
{
...
...
@@ -66,10 +73,7 @@ __forceinline__ __device__ void kvcache_pv_gemm_prefetch_k_gfx938(
// v_srsrc[0] = v_addr[0] + v_mls_loop_global_offset + v_mls_warp_global_offset;
*
(
uint64_t
*
)
&
v_srsrc
=
VA_LIMIT_BITS
(
*
(
uint64_t
*
)
&
v_addr
+
v_mls_loop_global_offset
+
v_mls_warp_global_offset
);
__builtin_amdgcn_sched_barrier
(
0
);
union
union_vec4_uint
v_rsrc_bits
;
v_rsrc_bits
.
v32
=
v_srsrc
;
size_t
lds_addr_warp
=
reinterpret_cast
<
size_t
>
(
v_lds
)
+
v_mls_lds_warp_offset
;
matrix_load_b16_lds_builtin
<
32
,
32
,
1
,
0
>
(
lds_addr_warp
,
v_rsrc_bits
.
i32
,
0
);
inline_matrix_load_32x32_b16_lds
<
0
,
1
>
(
v_lds
,
v_srsrc
,
v_mls_lds_warp_offset
,
0
);
__builtin_amdgcn_sched_barrier
(
0
);
}
...
...
csrc/flash_attn_hg/include/kvcache/gfx938/kvcache_pv_gemm_utils_gfx938.h
View file @
518a5f4d
...
...
@@ -2,6 +2,7 @@
#include "intrinsic.h"
#include "fwd/utils.h"
#include "intrinsic_mls_ds.h"
#include "intrinsic_mls_ds_b8.h"
template
<
int
kHeadDim
,
int
kBlockM
,
int
kBlockN
,
int
kBlockK
,
int
WARP_M
,
int
WARP_N
,
int
WARP_K
,
int
stage_id
,
int
WARP_NUM
,
typename
Element
,
int
STAGES
>
...
...
@@ -47,10 +48,376 @@ __forceinline__ __device__ void kvcache_prefetch_v_to_lds_gfx938(
// v_srsrc[0] = v_addr[0] + v_mls_loop_global_offset + v_mls_warp_global_offset;
*
(
uint64_t
*
)
&
v_srsrc
=
VA_LIMIT_BITS
(
*
(
uint64_t
*
)
&
v_addr
+
v_mls_loop_global_offset
+
v_mls_warp_global_offset
);
__builtin_amdgcn_sched_barrier
(
0
);
union
union_vec4_uint
v_rsrc_bits
;
v_rsrc_bits
.
v32
=
v_srsrc
;
size_t
lds_addr_warp
=
reinterpret_cast
<
size_t
>
(
v_lds
)
+
v_mls_lds_warp_offset
;
matrix_load_b16_lds_builtin
<
32
,
32
,
1
,
0
>
(
lds_addr_warp
,
v_rsrc_bits
.
i32
,
0
);
inline_matrix_load_32x32_b16_lds
<
0
,
1
>
(
v_lds
,
v_srsrc
,
v_mls_lds_warp_offset
,
0
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// FP8 MLS Paged Attention PV helpers, >= gfx938
////////////////////////////////////////////////////////////////////////////////////////////////////
// Forward declaration only; the definition is not part of this section.
// It is declared here so that fp8_kvcache_pv_gemm_prefetch_k_gfx938 (defined
// further down) can call it to start the K prefetch while the PV GEMM is
// still running. That caller passes the advanced K descriptor, the K LDS
// pointer, the warp index, the K row stride and `max_seq_v_offset - kBlockN`
// as `max_seq_k_offset`.
template
<
int
WARP_NUM
,
typename
Element
>
__forceinline__
__device__
void
fp8_kvcache_prefetch_k_gfx938
(
vec4_uint
k_addr
,
Element
*
k_lds
,
int
warp_id
,
int
k_row_stride
,
int
max_seq_k_offset
);
// Issues the first PREFETCH (= 2) V tile loads of a PV GEMM into stage 0 of
// the V buffer `v_lds`. The loads are only issued here; nothing in this
// function waits for them.
//
// The two loads fetch column blocks K_LOOP_COUNT - 1 and K_LOOP_COUNT - 2
// (64 elements each) of the 32 rows assigned to the calling warp.
//
// Parameters:
//   v_addr           source descriptor; words 0..1 hold the 64-bit base
//                    address that the per-load offsets are added to.
//   v_lds            destination buffer passed to the load helper.
//   warp_id          selects the 32-row slice and the destination slot.
//   v_row_stride     row stride in elements; also written to descriptor
//                    word 2.
//   max_seq_v_offset number of valid rows; a warp whose slice starts at or
//                    beyond it falls back to warp 0's slice for the address
//                    and row-filter computation.
//
// kBlockK is not used by this function.
template <int K_LOOP_COUNT, int kBlockK, int WARP_NUM, typename Element>
__forceinline__ __device__ void fp8_kvcache_prefetch_v_gfx938(
    vec4_uint v_addr,
    Element *v_lds,
    int warp_id,
    int v_row_stride,
    int max_seq_v_offset)
{
    static_assert(K_LOOP_COUNT % 2 == 0);
    constexpr int PREFETCH = 2;
    constexpr int kStageBytes = 16384;
    constexpr int kFirstColBlock = K_LOOP_COUNT - 1;

    vec4_uint v_rsrc;
    v_rsrc[1] = v_addr[1];
    v_rsrc[2] = v_row_stride;

    int stage = 0;

#pragma unroll
    for (int load_id = 0; load_id < PREFETCH; ++load_id)
    {
        // Destination: one 32x64 slot per (load_id, warp) inside the stage.
        int lds_dst_bytes = stage * kStageBytes + (WARP_NUM * load_id + warp_id) * 32 * 64 * sizeof(Element);

        // Source column offset of this load.
        int col_bytes = (kFirstColBlock - load_id) * 64 * sizeof(Element);

        // Source row offset; out-of-range warps are redirected to warp 0.
        int rows_past_end = warp_id * 32 + 32 - max_seq_v_offset;
        int src_warp = rows_past_end >= 32 ? 0 : warp_id;
        int row_bytes = src_warp * 32 * v_row_stride * sizeof(Element);

        // Descriptor word 3: clamped row count in bits 8 and up, plus the
        // constant 0x20000 (meaning of the constant is not visible here).
        int row_filter = inline_min_max<0, 32>(src_warp * 32 + 32 - max_seq_v_offset);
        v_rsrc[3] = (row_filter << 8) + 0x20000;

        // Descriptor words 0..1: adjusted 64-bit base address.
        *(uint64_t *)&v_rsrc = VA_LIMIT_BITS(*(uint64_t *)&v_addr + row_bytes + col_bytes);

        inline_matrix_load_64x32_b8_lds_rearrange<0, 1>(v_lds, v_rsrc, lds_dst_bytes, 0);
    }
}
// Issues a single V tile load (selected at compile time by `load_id`) into
// stage 0 of the V buffer `v_lds`. It performs the same address and
// descriptor computation as one iteration of fp8_kvcache_prefetch_v_gfx938,
// restricted to K_LOOP_COUNT == 2.
//
// Parameters:
//   v_addr           source descriptor; words 0..1 hold the 64-bit base.
//   v_lds            destination buffer passed to the load helper.
//   warp_id          selects the 32-row slice and the destination slot.
//   v_row_stride     row stride in elements; also written to descriptor
//                    word 2.
//   max_seq_v_offset number of valid rows; a warp whose slice starts at or
//                    beyond it uses warp 0's slice instead.
//
// kBlockK is not used by this function.
template <int K_LOOP_COUNT, int kBlockK, int WARP_NUM, typename Element, int load_id>
__forceinline__ __device__ void fp8_kvcache_prefetch_v_one_gfx938(
    vec4_uint v_addr,
    Element *v_lds,
    int warp_id,
    int v_row_stride,
    int max_seq_v_offset)
{
    static_assert(K_LOOP_COUNT == 2);
    static_assert(load_id == 0 || load_id == 1);

    constexpr int kStage = 0;
    constexpr int kStageBytes = 16384;
    constexpr int kFirstColBlock = K_LOOP_COUNT - 1;

    // Destination slot and source column offset.
    const int lds_dst_bytes = kStage * kStageBytes + (WARP_NUM * load_id + warp_id) * 32 * 64 * sizeof(Element);
    const int col_bytes = (kFirstColBlock - load_id) * 64 * sizeof(Element);

    // Source row offset; out-of-range warps are redirected to warp 0.
    const int rows_past_end = warp_id * 32 + 32 - max_seq_v_offset;
    const int src_warp = rows_past_end >= 32 ? 0 : warp_id;
    const int row_bytes = src_warp * 32 * v_row_stride * sizeof(Element);
    const int row_filter = inline_min_max<0, 32>(src_warp * 32 + 32 - max_seq_v_offset);

    // Build the descriptor. Word 1 is written first and then overwritten
    // together with word 0 by the 64-bit store below.
    vec4_uint v_rsrc;
    v_rsrc[1] = v_addr[1];
    v_rsrc[2] = v_row_stride;
    v_rsrc[3] = (row_filter << 8) + 0x20000;
    *(uint64_t *)&v_rsrc = VA_LIMIT_BITS(*(uint64_t *)&v_addr + row_bytes + col_bytes);

    inline_matrix_load_64x32_b8_lds_rearrange<0, 1>(v_lds, v_rsrc, lds_dst_bytes, 0);
}
// PV GEMM with fp8 V: accumulates P * V into `pv_reg`, walking the V column
// blocks from high to low in pairs, double-buffered over two 16384-byte
// stages of `v_lds`. Optionally starts the K prefetch for the next iteration.
//
// Structure:
//   * entry barrier (wait_lds_data_arrived<true>(0)),
//   * main loop, one iteration per pair of column blocks except the last:
//       - issue the loads of pair (k_loop, k_loop - 1) into stage `stage`,
//       - wait_buffer_data_arrived(PREFETCH)  (presumably: all but the two
//         loads just issued have landed - confirm against the helper),
//       - flip the stage and consume the pair that was loaded previously,
//   * wait_buffer_data_arrived(0), flip the stage, consume the final pair
//     (column blocks 1 and 0),
//   * exit barrier.
//   The loads of the very first pair are not issued here; the caller is
//   expected to have issued them into stage 0 beforehand (see
//   fp8_kvcache_prefetch_v_gfx938).
//
// Consuming one loaded slot: two 32x32 fp8 tiles are read from LDS; each is
// converted fp8 -> fp32 -> P_Element and fed to mmac_4interleave, two MMACs
// (mmac_id 0/1) accumulating into the same pv_reg fragment.
//
// K prefetch (only if PrefetchK): `k_addr` is advanced by `k_addr_offset` and
// fp8_kvcache_prefetch_k_gfx938 is called
//   - inside the final pair (load_id 0, second tile) when K_LOOP_COUNT == 2,
//   - otherwise after the final pair.
//
// Parameters:
//   v_addr / k_addr  V descriptor (by value) / K descriptor (by reference,
//                    advanced in place when PrefetchK).
//   v_lds / k_lds    V and K buffers.
//   p_reg            P fragments; only p_reg[0][0..3] is read.
//   pv_reg           output accumulators, updated in place.
//   warp_id          selects the 32-row V slice and the LDS slot.
//   k_row_stride, v_row_stride  row strides in elements.
//   max_seq_v_offset number of valid V rows.
//   k_addr_offset    signed byte/address delta added to k_addr.
//
// NOTE(review): the LDS read offsets below are built without the `v_lds`
// base, whereas the load helper receives `v_lds` plus the same offset.
// This is only consistent if the read macro adds the base itself or `v_lds`
// sits at LDS offset 0 - confirm.
//
// kBlockK is not used by this function.
template <bool PrefetchK, int K_LOOP_COUNT, int kBlockK, int kBlockN, int M_WARP_COUNT, int K_WARP_COUNT, int WARP_NUM,
          int M_MMAC_COUNT, typename V_Element, typename P_Element, typename ElementAccum>
__forceinline__ __device__ void fp8_kvcache_pv_gemm_prefetch_k_gfx938(
    vec4_uint v_addr,
    vec4_uint &k_addr,
    V_Element *v_lds,
    V_Element *k_lds,
    union_vec2_f16x2<P_Element> p_reg[M_WARP_COUNT * K_WARP_COUNT][4],
    vec4_Accum<ElementAccum> pv_reg[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
    int warp_id,
    int k_row_stride,
    int v_row_stride,
    int max_seq_v_offset,
    int64_t k_addr_offset)
{
    static_assert(K_LOOP_COUNT % 2 == 0);
    constexpr int PREFETCH = 2;
    constexpr int kStageBytes = 16384;

    flash::wait_lds_data_arrived<true /*sync*/>(0);

    vec4_uint v_rsrc;
    v_rsrc[1] = v_addr[1];
    v_rsrc[2] = v_row_stride;

    // Stage that the next loads are written to.
    int stage = 1;

#pragma unroll
    for (int k_loop = K_LOOP_COUNT - 1 - PREFETCH; k_loop >= 1; k_loop -= PREFETCH)
    {
        // ---- issue loads of column blocks k_loop and k_loop - 1 ----
#pragma unroll
        for (int load_id = 0; load_id < PREFETCH; ++load_id)
        {
            int lds_dst_bytes = stage * kStageBytes + (WARP_NUM * load_id + warp_id) * 32 * 64 * sizeof(V_Element);
            int col_bytes = (k_loop - load_id) * 64 * sizeof(V_Element);

            // Warps whose rows start past the valid range read warp 0's rows.
            int rows_past_end = warp_id * 32 + 32 - max_seq_v_offset;
            int src_warp = rows_past_end >= 32 ? 0 : warp_id;
            int row_bytes = src_warp * 32 * v_row_stride * sizeof(V_Element);

            int row_filter = inline_min_max<0, 32>(src_warp * 32 + 32 - max_seq_v_offset);
            v_rsrc[3] = (row_filter << 8) + 0x20000;
            *(uint64_t *)&v_rsrc = VA_LIMIT_BITS(*(uint64_t *)&v_addr + row_bytes + col_bytes);

            inline_matrix_load_64x32_b8_lds_rearrange<0, 1>(v_lds, v_rsrc, lds_dst_bytes, 0);
        }

        flash::wait_buffer_data_arrived<false /*sync*/>(PREFETCH);
        stage ^= 1;

        // ---- consume the pair stored in the other stage ----
#pragma unroll
        for (int load_id = 0; load_id < PREFETCH; ++load_id)
        {
            union_vec16_fp8 v_fp8[2];
            int lds_src_bytes = stage * kStageBytes + (WARP_NUM * load_id + warp_id) * 32 * 64 * sizeof(V_Element);
            DS_READ_MATRIX_32x32_B8_ALT2(lds_src_bytes, v_fp8[0].i32x4, false /*transpose*/)
            DS_READ_MATRIX_32x32_B8_ALT2(lds_src_bytes + 32, v_fp8[1].i32x4, false /*transpose*/)

            // Column block that the data in this slot belongs to.
            int col_block = k_loop - load_id + PREFETCH;

#pragma unroll
            for (int tile = 0; tile < 2; ++tile)
            {
                // Tile 0 needs the first read only, tile 1 needs both.
                flash::wait_lds_data_arrived<false /*sync*/>(1 - tile);
                __builtin_amdgcn_sched_barrier(0);
                asm volatile("s_setprio 2");
#pragma unroll
                for (int dim_half = 0; dim_half < 2; ++dim_half)
                {
                    // fp8 -> fp32: each 32-bit word yields two pairs.
                    vec2_fp32 v_f32[4];
                    v_f32[0] = __builtin_hcu_cvt_pk_f32_fp8(v_fp8[tile].i32[dim_half * 2 + 0], false /*word_sel*/);
                    v_f32[1] = __builtin_hcu_cvt_pk_f32_fp8(v_fp8[tile].i32[dim_half * 2 + 0], true /*word_sel*/);
                    v_f32[2] = __builtin_hcu_cvt_pk_f32_fp8(v_fp8[tile].i32[dim_half * 2 + 1], false /*word_sel*/);
                    v_f32[3] = __builtin_hcu_cvt_pk_f32_fp8(v_fp8[tile].i32[dim_half * 2 + 1], true /*word_sel*/);

                    // fp32 -> packed P_Element.
                    union_vec4_f16x2<P_Element> v_half;
#pragma unroll
                    for (int q = 0; q < 4; ++q)
                    {
                        v_half.f16x2[q] = __builtin_hcu_cvt_pk_f16_f32(v_f32[q][0], v_f32[q][1], false /*clamp*/, 0 /*o_modifier*/);
                    }

#pragma unroll
                    for (int half_m = 0; half_m < M_MMAC_COUNT; ++half_m)
                    {
#pragma unroll
                        for (int mmac_id = 0; mmac_id < 2; ++mmac_id)
                        {
                            pv_reg[col_block * 2 + tile][dim_half * 2 + half_m].f32 = mmac_4interleave<P_Element, ElementAccum>(
                                p_reg[0][mmac_id * 2 + half_m].f16x4,
                                v_half.f16x4[mmac_id],
                                pv_reg[col_block * 2 + tile][dim_half * 2 + half_m].f32);
                        }
                    }
                }
                __builtin_amdgcn_sched_barrier(0);
                asm volatile("s_setprio 0");
            }
        }
    }

    // All outstanding V loads must have landed before the final pair is read.
    flash::wait_buffer_data_arrived<false /*sync*/>(0);

    // With a single pair of column blocks the K prefetch is issued from
    // inside the final pair instead of after it.
    constexpr bool PrefetchKInPV = PrefetchK && K_LOOP_COUNT == 2;

    // ---- final pair: column blocks 1 and 0 ----
    {
        constexpr int kTailLoop = 1 - PREFETCH;
        stage ^= 1;
#pragma unroll
        for (int load_id = 0; load_id < PREFETCH; ++load_id)
        {
            union_vec16_fp8 v_fp8[2];
            int lds_src_bytes = stage * kStageBytes + (WARP_NUM * load_id + warp_id) * 32 * 64 * sizeof(V_Element);
            DS_READ_MATRIX_32x32_B8_ALT2(lds_src_bytes, v_fp8[0].i32x4, false /*transpose*/)
            DS_READ_MATRIX_32x32_B8_ALT2(lds_src_bytes + 32, v_fp8[1].i32x4, false /*transpose*/)

            int col_block = kTailLoop - load_id + PREFETCH;

#pragma unroll
            for (int tile = 0; tile < 2; ++tile)
            {
                flash::wait_lds_data_arrived<false /*sync*/>(1 - tile);

                if constexpr (PrefetchKInPV)
                {
                    if (load_id == 0 && tile == 1)
                    {
                        *(int64_t *)&k_addr += k_addr_offset;
                        fp8_kvcache_prefetch_k_gfx938<WARP_NUM, V_Element>(k_addr, k_lds, warp_id, k_row_stride, max_seq_v_offset - kBlockN);
                    }
                }

                __builtin_amdgcn_sched_barrier(0);
                asm volatile("s_setprio 2");
#pragma unroll
                for (int dim_half = 0; dim_half < 2; ++dim_half)
                {
                    vec2_fp32 v_f32[4];
                    v_f32[0] = __builtin_hcu_cvt_pk_f32_fp8(v_fp8[tile].i32[dim_half * 2 + 0], false /*word_sel*/);
                    v_f32[1] = __builtin_hcu_cvt_pk_f32_fp8(v_fp8[tile].i32[dim_half * 2 + 0], true /*word_sel*/);
                    v_f32[2] = __builtin_hcu_cvt_pk_f32_fp8(v_fp8[tile].i32[dim_half * 2 + 1], false /*word_sel*/);
                    v_f32[3] = __builtin_hcu_cvt_pk_f32_fp8(v_fp8[tile].i32[dim_half * 2 + 1], true /*word_sel*/);

                    union_vec4_f16x2<P_Element> v_half;
#pragma unroll
                    for (int q = 0; q < 4; ++q)
                    {
                        v_half.f16x2[q] = __builtin_hcu_cvt_pk_f16_f32(v_f32[q][0], v_f32[q][1], false /*clamp*/, 0 /*o_modifier*/);
                    }

#pragma unroll
                    for (int half_m = 0; half_m < M_MMAC_COUNT; ++half_m)
                    {
#pragma unroll
                        for (int mmac_id = 0; mmac_id < 2; ++mmac_id)
                        {
                            pv_reg[col_block * 2 + tile][dim_half * 2 + half_m].f32 = mmac_4interleave<P_Element, ElementAccum>(
                                p_reg[0][mmac_id * 2 + half_m].f16x4,
                                v_half.f16x4[mmac_id],
                                pv_reg[col_block * 2 + tile][dim_half * 2 + half_m].f32);
                        }
                    }
                }
                __builtin_amdgcn_sched_barrier(0);
                asm volatile("s_setprio 0");
            }
        }
    }

    if constexpr (PrefetchK && !PrefetchKInPV)
    {
        *(int64_t *)&k_addr += k_addr_offset;
        fp8_kvcache_prefetch_k_gfx938<WARP_NUM, V_Element>(k_addr, k_lds, warp_id, k_row_stride, max_seq_v_offset - kBlockN);
    }

    flash::wait_lds_data_arrived<true /*sync*/>(0);
}
// P*V GEMM on FP8 operands for the gfx938 KV-cache kernels.
//
// V tiles are streamed from global memory into a two-stage LDS buffer (stage stride
// 16384 bytes) and multiplied against the FP8 probabilities held in p_reg; the result is
// accumulated into pv_reg. Optionally the K tile for the next iteration is prefetched.
//
// Structure of the body:
//   1. Steady state: for k_loop = K_LOOP_COUNT-1-PREFETCH down to 1 (step PREFETCH),
//      issue PREFETCH V loads into the current write stage, flip the stage, then run the
//      MMACs on the tiles of the other stage (accumulator rows k_loop_inner*2 + {0,1}).
//   2. Drain: wait for every outstanding buffer load, flip the stage and run the MMACs
//      for k_loop_inner = 1 and 0.
//   3. K prefetch: when PrefetchK is set, k_addr is advanced by k_addr_offset and
//      fp8_kvcache_prefetch_k_gfx938 is called, either in the middle of the drain stage
//      (K_LOOP_COUNT == 2) or right after it (otherwise).
//
// NOTE(review): the first compute pass reads LDS stage 0 without this function having
// written it, so the caller presumably pre-loaded the first PREFETCH V tiles there - confirm.
//
// Template parameters:
//   PrefetchK     - issue the K prefetch described above.
//   K_LOOP_COUNT  - number of 64-wide V column tiles; must be even.
//   kBlockK       - not referenced in this body.
//   kBlockN       - subtracted from max_seq_v_offset when forwarding to the K prefetch.
//   M_WARP_COUNT  - must be 1.  K_WARP_COUNT - must be 2.
//   WARP_NUM      - used to place each warp's tile inside an LDS stage.
//   M_MMAC_COUNT  - number of p_reg fragments consumed per V fragment.
//
// Parameters:
//   v_addr, k_addr   - 4-dword descriptors whose first 64 bits are read as the base address;
//                      k_addr is updated in place when the K prefetch is issued.
//   v_lds, k_lds     - LDS destinations handed to the matrix-load helpers.
//   p_reg            - FP8 left operand; only i8x8[0] of each entry is read here.
//   pv_reg           - accumulators, updated in place.
//   warp_id          - selects this warp's rows in global memory and its LDS slot.
//   k_row_stride,
//   v_row_stride     - written into dword 2 of the respective descriptor.
//   max_seq_v_offset - used to derive the row filter value (nm_filter) placed in dword 3.
//   k_addr_offset    - byte delta added to k_addr before the K prefetch.
template <bool PrefetchK,
          int K_LOOP_COUNT,
          int kBlockK,
          int kBlockN,
          int M_WARP_COUNT,
          int K_WARP_COUNT,
          int WARP_NUM,
          int M_MMAC_COUNT,
          typename V_Element,
          typename ElementAccum>
__forceinline__ __device__ void fp8_kvcache_pv_gemm_fp8_prefetch_k_gfx938(
    vec4_uint v_addr,
    vec4_uint &k_addr,
    V_Element *v_lds,
    V_Element *k_lds,
    union_vec32_fp8 p_reg[M_MMAC_COUNT],
    vec4_Accum<ElementAccum> pv_reg[K_LOOP_COUNT * M_WARP_COUNT * K_WARP_COUNT][4],
    int warp_id,
    int k_row_stride,
    int v_row_stride,
    int max_seq_v_offset,
    int64_t k_addr_offset)
{
    static_assert(K_LOOP_COUNT % 2 == 0);
    static_assert(M_WARP_COUNT == 1);
    static_assert(K_WARP_COUNT == 2);

    constexpr int PREFETCH = 2;
    // Byte distance between the two LDS stages.
    constexpr int kStageStrideBytes = 16384;

    flash::wait_lds_data_arrived<true /*sync*/>(0);

    // Descriptor for the V loads: dwords 0/1 are rewritten per load with the tile address,
    // dword 2 holds the row stride, dword 3 the row filter plus a constant.
    vec4_uint v_srsrc;
    v_srsrc[1] = v_addr[1];
    v_srsrc[2] = v_row_stride;

    // Stage that the next batch of loads writes to; compute reads from the other one.
    int stage_id = 1;

#pragma unroll
    for (int k_loop = K_LOOP_COUNT - 1 - PREFETCH; k_loop >= 1; k_loop -= PREFETCH) {
        // ---- issue PREFETCH V-tile loads into the write stage ----
#pragma unroll
        for (int load_id = 0; load_id < PREFETCH; ++load_id) {
            int warp_lds_write_bytes =
                stage_id * kStageStrideBytes + (WARP_NUM * load_id + warp_id) * 32 * 64 * sizeof(V_Element);
            int v_loop_global_bytes = (k_loop - load_id) * 64 * sizeof(V_Element);

            // When this warp's 32 rows start at or beyond max_seq_v_offset, fall back to
            // warp 0's rows for the address and filter computation.
            int nm_filter_max = warp_id * 32 + 32 - max_seq_v_offset;
            int real_mls_warp_id = nm_filter_max >= 32 ? 0 : warp_id;
            int warp_global_bytes = real_mls_warp_id * 32 * v_row_stride * sizeof(V_Element);
            // inline_min_max<0, 32> presumably clamps to [0, 32] - confirm against its definition.
            int nm_filter = inline_min_max<0, 32>(real_mls_warp_id * 32 + 32 - max_seq_v_offset);

            v_srsrc[3] = (nm_filter << 8) + 0x20000;
            *(uint64_t *)&v_srsrc =
                VA_LIMIT_BITS(*(uint64_t *)&v_addr + warp_global_bytes + v_loop_global_bytes);
            inline_matrix_load_64x32_b8_lds_rearrange<0, 1>(v_lds, v_srsrc, warp_lds_write_bytes, 0);
        }

        // Leave the PREFETCH loads just issued in flight; everything older must have landed.
        flash::wait_buffer_data_arrived<false /*sync*/>(PREFETCH);
        stage_id ^= 1;

        // ---- MMACs on the tiles of the other stage ----
#pragma unroll
        for (int load_id = 0; load_id < PREFETCH; ++load_id) {
            union_vec16_fp8 v_regs[2];
            int lds_load_bytes =
                stage_id * kStageStrideBytes + (WARP_NUM * load_id + warp_id) * 32 * 64 * sizeof(V_Element);
            DS_READ_MATRIX_32x32_B8_ALT2(lds_load_bytes, v_regs[0].i32x4, false /*transpose*/)
            DS_READ_MATRIX_32x32_B8_ALT2(lds_load_bytes + 32, v_regs[1].i32x4, false /*transpose*/)

            int k_loop_inner = k_loop - load_id + PREFETCH;
#pragma unroll
            for (int tile32x32_id = 0; tile32x32_id < 2; ++tile32x32_id) {
                const int acc_row = k_loop_inner * 2 + tile32x32_id;

                // Tile 0 may proceed with one LDS read still pending; tile 1 needs both.
                flash::wait_lds_data_arrived<false /*sync*/>(1 - tile32x32_id);

                __builtin_amdgcn_sched_barrier(0);
                asm volatile("s_setprio 2");
#pragma unroll
                for (int min_tile_dim = 0; min_tile_dim < 2; ++min_tile_dim) {
#pragma unroll
                    for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
                        const int acc_col = min_tile_dim * 2 + min_tile_m;
                        pv_reg[acc_row][acc_col].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(
                            p_reg[min_tile_m].i8x8[0],
                            v_regs[tile32x32_id].i8x8[min_tile_dim],
                            pv_reg[acc_row][acc_col].f32);
                    }
                }
                __builtin_amdgcn_sched_barrier(0);
                asm volatile("s_setprio 0");
            }
        }
    }

    // ---- drain stage: no further V loads are issued, consume the last PREFETCH tiles ----
    flash::wait_buffer_data_arrived<false /*sync*/>(0);

    // With only two V tiles there is no steady-state loop, so the K prefetch is
    // issued in the middle of the drain stage instead of after it.
    constexpr bool PrefetchKInPV = PrefetchK && K_LOOP_COUNT == 2;
    {
        stage_id ^= 1;
#pragma unroll
        for (int load_id = 0; load_id < PREFETCH; ++load_id) {
            union_vec16_fp8 v_regs[2];
            int lds_load_bytes =
                stage_id * kStageStrideBytes + (WARP_NUM * load_id + warp_id) * 32 * 64 * sizeof(V_Element);
            DS_READ_MATRIX_32x32_B8_ALT2(lds_load_bytes, v_regs[0].i32x4, false /*transpose*/)
            DS_READ_MATRIX_32x32_B8_ALT2(lds_load_bytes + 32, v_regs[1].i32x4, false /*transpose*/)

            // Same formula as the steady state with k_loop = 1 - PREFETCH, i.e. tiles 1 and 0.
            int k_loop_inner = 1 - load_id;
#pragma unroll
            for (int tile32x32_id = 0; tile32x32_id < 2; ++tile32x32_id) {
                const int acc_row = k_loop_inner * 2 + tile32x32_id;

                flash::wait_lds_data_arrived<false /*sync*/>(1 - tile32x32_id);

                if constexpr (PrefetchKInPV) {
                    if (load_id == 0 && tile32x32_id == 1) {
                        *(int64_t *)&k_addr += k_addr_offset;
                        fp8_kvcache_prefetch_k_gfx938<WARP_NUM, V_Element>(
                            k_addr, k_lds, warp_id, k_row_stride, max_seq_v_offset - kBlockN);
                    }
                }

                __builtin_amdgcn_sched_barrier(0);
                asm volatile("s_setprio 2");
#pragma unroll
                for (int min_tile_dim = 0; min_tile_dim < 2; ++min_tile_dim) {
#pragma unroll
                    for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
                        const int acc_col = min_tile_dim * 2 + min_tile_m;
                        pv_reg[acc_row][acc_col].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(
                            p_reg[min_tile_m].i8x8[0],
                            v_regs[tile32x32_id].i8x8[min_tile_dim],
                            pv_reg[acc_row][acc_col].f32);
                    }
                }
                __builtin_amdgcn_sched_barrier(0);
                asm volatile("s_setprio 0");
            }
        }
    }

    // K prefetch after the GEMM when it was not already issued inside the drain stage.
    if constexpr (PrefetchK && !PrefetchKInPV) {
        *(int64_t *)&k_addr += k_addr_offset;
        fp8_kvcache_prefetch_k_gfx938<WARP_NUM, V_Element>(
            k_addr, k_lds, warp_id, k_row_stride, max_seq_v_offset - kBlockN);
    }

    flash::wait_lds_data_arrived<true /*sync*/>(0);
}
// Packs FP32 score fragments into FP8 for use as the left operand of the P*V GEMM.
//
// For every min_tile_m in [0, M_MMAC_COUNT):
//   s_reg[0][0*2 + min_tile_m].f32  ->  p_reg[min_tile_m].i32[0]
//   s_reg[0][1*2 + min_tile_m].f32  ->  p_reg[min_tile_m].i32[1]
// The conversion itself is done by __builtin_hcu_cvt_pk4_fp8_f32<Element>; the "pk4" in
// its name suggests four floats are packed into one 32-bit word - confirm against its definition.
//
// Template parameters:
//   M_MMAC_COUNT - number of fragments along M; must be at most 2 (see static_assert).
//   Element      - FP8 element type forwarded to the conversion helper.
//   ElementAccum - accumulator element type of s_reg.
//
// Parameters:
//   p_reg - destination; only i32[0] and i32[1] of each entry are written.
//   s_reg - source scores; only row 0 is read.
template <int M_MMAC_COUNT, typename Element, typename ElementAccum>
inline __device__ void fp8_kvcache_cvt_f32_to_fp8_gfx938(
    union_vec32_fp8 p_reg[M_MMAC_COUNT],
    vec4_Accum<ElementAccum> s_reg[1][4])
{
    // The inner extent of s_reg is fixed at 4 and the second conversion reads index
    // 2 + min_tile_m, so more than two M fragments would index past the row.
    static_assert(M_MMAC_COUNT <= 2, "s_reg[0] holds 4 fragments; M_MMAC_COUNT must be <= 2");

#pragma unroll
    for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
        __builtin_hcu_cvt_pk4_fp8_f32<Element>(s_reg[0][0 * 2 + min_tile_m].f32, p_reg[min_tile_m].i32[0]);
        __builtin_hcu_cvt_pk4_fp8_f32<Element>(s_reg[0][1 * 2 + min_tile_m].f32, p_reg[min_tile_m].i32[1]);
    }
}
csrc/flash_attn_hg/include/kvcache/gfx938/kvcache_qk_gemm_prefetch_v_gfx938.h
View file @
518a5f4d
...
...
@@ -73,10 +73,7 @@ __forceinline__ __device__ void kvcache_qk_gemm_prefetch_v_gfx938(
// k_srsrc[0] = k_addr[0] + k_mls_loop_global_offset + k_mls_warp_global_offset;
*
(
uint64_t
*
)
&
k_srsrc
=
VA_LIMIT_BITS
(
*
(
uint64_t
*
)
&
k_addr
+
k_mls_loop_global_offset
+
k_mls_warp_global_offset
);
int
lds_offset_bytes
=
k_lds_stage_offset
*
2
/*half -> bytes*/
;
union
union_vec4_uint
k_rsrc_bits
;
k_rsrc_bits
.
v32
=
k_srsrc
;
size_t
lds_addr_warp
=
reinterpret_cast
<
size_t
>
(
k_lds
)
+
lds_offset_bytes
;
matrix_load_b16_lds_trans_builtin
<
32
,
32
,
0
,
0
>
(
lds_addr_warp
,
k_rsrc_bits
.
i32
,
0
);
inline_matrix_load_32x32_b16_lds_trans
<
0
,
0
>
(
k_lds
,
k_srsrc
,
lds_offset_bytes
,
0
);
}
// 等待 MLS 数据回来
...
...
@@ -272,3 +269,4 @@ __forceinline__ __device__ void kvcache_qk_gemm_prefetch_v_gfx938(
}
}
// qk_gemm
csrc/flash_attn_hg/include/kvcache/gfx938/kvcache_qk_gemm_utils_gfx938.h
View file @
518a5f4d
...
...
@@ -3,6 +3,7 @@
#include "hip/hip_fp16.h"
#include "static_switch.h"
#include "kvcache_pv_gemm_utils_gfx938.h"
#include "intrinsic_mls_ds_b8.h"
template
<
int
kHeadDim
,
int
kBlockM
,
int
kBlockK
,
int
WARP_M
,
int
WARP_NUM
,
typename
Element
,
int
STAGES
,
int
M_MMAC_COUNT
>
...
...
@@ -40,10 +41,7 @@ __forceinline__ __device__ void kvcache_prefetch_q_to_vgpr_gfx938(
}
int
lds_offset_bytes
=
k_lds_stage_offset
*
2
/*half -> bytes*/
;
flash
::
wait_lds_data_arrived
<
true
>
(
0
);
union
union_vec4_uint
q_rsrc_bits
;
q_rsrc_bits
.
v32
=
q_srsrc
;
size_t
lds_addr_warp
=
reinterpret_cast
<
size_t
>
(
q_lds
)
+
lds_offset_bytes
;
matrix_load_b16_lds_trans_builtin
<
32
,
32
,
0
,
0
>
(
lds_addr_warp
,
q_rsrc_bits
.
i32
,
0
);
inline_matrix_load_32x32_b16_lds_trans
<
0
,
0
>
(
q_lds
,
q_srsrc
,
lds_offset_bytes
,
0
);
flash
::
wait_buffer_data_arrived
<
true
>
(
0
);
...
...
@@ -199,11 +197,175 @@ __forceinline__ __device__ void kvcache_prefetch_k_to_lds_gfx938(
// k_srsrc[0] = k_addr[0] + k_mls_loop_global_offset + k_mls_warp_global_offset;
*
(
uint64_t
*
)
&
k_srsrc
=
VA_LIMIT_BITS
(
*
(
uint64_t
*
)
&
k_addr
+
k_mls_loop_global_offset
+
k_mls_warp_global_offset
);
int
lds_offset_bytes
=
k_lds_stage_offset
*
2
/*half -> bytes*/
;
union
union_vec4_uint
k_rsrc_bits
;
k_rsrc_bits
.
v32
=
k_srsrc
;
size_t
lds_addr_warp
=
reinterpret_cast
<
size_t
>
(
k_lds
)
+
lds_offset_bytes
;
matrix_load_b16_lds_trans_builtin
<
32
,
32
,
0
,
0
>
(
lds_addr_warp
,
k_rsrc_bits
.
i32
,
0
);
inline_matrix_load_32x32_b16_lds_trans
<
0
,
0
>
(
k_lds
,
k_srsrc
,
lds_offset_bytes
,
0
);
__builtin_amdgcn_sched_barrier
(
0
);
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// FP8 MLS Paged Attention helpers, >= gfx938
////////////////////////////////////////////////////////////////////////////////////////////////////
// Issues one asynchronous global -> LDS load of this warp's FP8 K tile.
//
// The tile is written to LDS stage 0 at slot warp_id (each slot is 32*64 elements) and
// is taken from column tile 0 of K. The function only issues the load and a scheduling
// barrier; it does not wait for the data to arrive.
//
// Template parameters:
//   WARP_NUM - number of warp slots per LDS stage.
//   Element  - element type; only its size is used for byte offsets.
//
// Parameters:
//   k_addr           - 4-dword descriptor; its first 64 bits are read as the K base address.
//   k_lds            - LDS destination handed to the matrix-load helper.
//   warp_id          - selects the 32-row band of K and the LDS slot.
//   k_row_stride     - written into dword 2 of the load descriptor.
//   max_seq_k_offset - used to derive the row filter value (nm_filter) placed in dword 3.
template <int WARP_NUM, typename Element>
__forceinline__ __device__ void fp8_kvcache_prefetch_k_gfx938(
    vec4_uint k_addr,
    Element *k_lds,
    int warp_id,
    int k_row_stride,
    int max_seq_k_offset)
{
    // This helper always targets LDS stage 0 and K column tile 0.
    constexpr int kStage = 0;
    constexpr int kColumnTile = 0;

    const int lds_dst_bytes = (kStage * WARP_NUM + warp_id) * 32 * 64 * sizeof(Element);
    const int column_tile_bytes = kColumnTile * 64 * sizeof(Element);

    // When this warp's 32 rows start at or beyond max_seq_k_offset, fall back to
    // warp 0's rows for the address and filter computation.
    int src_warp = warp_id;
    if (warp_id * 32 + 32 - max_seq_k_offset >= 32) {
        src_warp = 0;
    }
    const int row_band_bytes = src_warp * 32 * k_row_stride * sizeof(Element);
    // inline_min_max<0, 32> presumably clamps to [0, 32] - confirm against its definition.
    const int row_filter = inline_min_max<0, 32>(src_warp * 32 + 32 - max_seq_k_offset);

    // Build the descriptor: dword 1 and 2 first, then dword 3, and finally the 64-bit
    // address which overwrites dwords 0 and 1.
    vec4_uint k_srsrc;
    k_srsrc[1] = k_addr[1];
    k_srsrc[2] = k_row_stride;
    k_srsrc[3] = (row_filter << 8) + 0x40000;
    *(uint64_t *)&k_srsrc = VA_LIMIT_BITS(*(uint64_t *)&k_addr + row_band_bytes + column_tile_bytes);

    inline_matrix_load_64x32_b8_lds_rearrange<0, 1>(k_lds, k_srsrc, lds_dst_bytes, 0);
    __builtin_amdgcn_sched_barrier(0);
}
// Forward declaration only; the definition is not part of this section of the file.
// It is declared here because fp8_kvcache_qk_gemm_gfx938 below calls it with
// load_id = 0 and load_id = 1 when its PrefetchVInQK template flag is set.
// NOTE(review): judging by the name and arguments this presumably issues a single
// V-tile global -> LDS load selected by load_id - confirm against the definition.
template
<
int
K_LOOP_COUNT
,
int
kBlockK
,
int
WARP_NUM
,
typename
Element
,
int
load_id
>
__forceinline__
__device__
void
fp8_kvcache_prefetch_v_one_gfx938
(
vec4_uint
v_addr
,
Element
*
v_lds
,
int
warp_id
,
int
v_row_stride
,
int
max_seq_v_offset
);
// Q*K^T GEMM on FP8 operands for the gfx938 KV-cache kernels.
//
// Q is already in registers (q_reg). K is consumed in kHeadDim/64 column tiles through a
// two-stage LDS buffer: every loop iteration issues the load of tile k_loop into one
// stage and multiplies tile k_loop-1 from the other; the trailing block multiplies the
// last tile. Results are accumulated into s_reg[0], which is zeroed first.
//
// NOTE(review): K tile 0 is read from LDS stage 0 without being loaded here, so the
// caller presumably prefetched it (e.g. via fp8_kvcache_prefetch_k_gfx938) - confirm.
// NOTE(review): only k_regs[0..1] and s_reg[0] are used, which matches WARP_M == 32 and
// WARP_N == 32; other values are not exercised by this body.
//
// Template parameters:
//   PrefetchVInQK - also issue the two V prefetches (load_id 0 and 1) between the MMAC
//                   groups; requires kHeadDim == 128 and K_LOOP_COUNT == 2.
//   K_LOOP_COUNT, kBlockK - forwarded to fp8_kvcache_prefetch_v_one_gfx938.
//   kHeadDim      - head dimension; kHeadDim/64 K tiles are processed.
//   WARP_M, WARP_N - size the s_reg and k_regs arrays.
//   WARP_NUM      - number of warp slots per LDS stage.
//   M_MMAC_COUNT  - number of Q fragments along M.
//
// Parameters:
//   k_addr, v_addr   - 4-dword descriptors; the first 64 bits of k_addr are read as base address.
//   k_lds, v_lds     - LDS destinations for the K loads / V prefetches.
//   q_reg            - FP8 Q fragments, indexed [min_tile_m][k tile].
//   s_reg            - score accumulators; cleared, then updated in place.
//   warp_id          - selects the 32-row band of K and the LDS slot.
//   k_row_stride     - written into dword 2 of the K load descriptor.
//   v_row_stride     - forwarded to the V prefetch.
//   max_seq_k_offset - used to derive the row filter (nm_filter); also forwarded to the
//                      V prefetch as its max_seq_v_offset argument.
template <bool PrefetchVInQK,
          int K_LOOP_COUNT,
          int kHeadDim,
          int kBlockK,
          int WARP_M,
          int WARP_N,
          int WARP_NUM,
          int M_MMAC_COUNT,
          typename Element,
          typename ElementAccum>
__forceinline__ __device__ void fp8_kvcache_qk_gemm_gfx938(
    vec4_uint k_addr,
    vec4_uint v_addr,
    Element *k_lds,
    Element *v_lds,
    union_vec16_fp8 q_reg[M_MMAC_COUNT][kHeadDim / 64],
    vec4_Accum<ElementAccum> s_reg[(WARP_M / 32) * (WARP_N / 32)][4],
    int warp_id,
    int k_row_stride,
    int v_row_stride,
    int max_seq_k_offset = 0)
{
    static_assert(!PrefetchVInQK || (kHeadDim == 128 && K_LOOP_COUNT == 2));

    int stage_id = 0;
    vec4_uint k_srsrc;
    k_srsrc[1] = k_addr[1];
    k_srsrc[2] = k_row_stride;

    // Clear every accumulator row declared in s_reg. The bound used to be
    // (WARP_N / WARP_N) * (WARP_M / 32), which is always WARP_M / 32 and therefore did
    // not match the declared extent (WARP_M / 32) * (WARP_N / 32) once WARP_N != 32.
#pragma unroll
    for (int i = 0; i < (WARP_N / 32) * (WARP_M / 32); ++i) {
#pragma unroll
        for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
            for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                asm volatile("v_mov_b64 %0, 0x0\n\t"
                             "v_mov_b64 %1, 0x0\n\t"
                             : "=v"(s_reg[i][min_tile_n * 2 + min_tile_m].u64[0]),
                               "=v"(s_reg[i][min_tile_n * 2 + min_tile_m].u64[1])
                             :);
            }
        }
    }

    // From here on stage_id names the stage the next K load writes to.
    stage_id ^= 1;

#pragma unroll
    for (int k_loop = 1; k_loop < kHeadDim / 64; ++k_loop) {
        // ---- issue the load of K tile k_loop into the write stage ----
        int warp_lds_write_bytes = (stage_id * WARP_NUM + warp_id) * 32 * 64 * sizeof(Element);
        int warp_global_bytes;
        int k_loop_global_bytes = k_loop * 64 * sizeof(Element);
        // When this warp's 32 rows start at or beyond max_seq_k_offset, fall back to
        // warp 0's rows for the address and filter computation.
        int nm_filter_max = warp_id * 32 + 32 - max_seq_k_offset;
        int real_mls_warp_id = nm_filter_max >= 32 ? 0 : warp_id;
        warp_global_bytes = real_mls_warp_id * 32 * k_row_stride * sizeof(Element);
        // inline_min_max<0, 32> presumably clamps to [0, 32] - confirm against its definition.
        int nm_filter = inline_min_max<0, 32>(real_mls_warp_id * 32 + 32 - max_seq_k_offset);
        k_srsrc[3] = (nm_filter << 8) + 0x40000;
        *(uint64_t *)&k_srsrc = VA_LIMIT_BITS(*(uint64_t *)&k_addr + warp_global_bytes + k_loop_global_bytes);
        inline_matrix_load_64x32_b8_lds_rearrange<0, 1>(k_lds, k_srsrc, warp_lds_write_bytes, 0);

        // Leave the load just issued in flight; the previous tile must have landed.
        flash::wait_buffer_data_arrived<false /*sync*/>(1);
        stage_id ^= 1;

        // ---- MMACs on K tile k_loop-1 from the other stage ----
        union_vec16_fp8 k_regs[WARP_N / 16];
        int lds_load_bytes = (stage_id * WARP_NUM + warp_id) * 32 * 64 * sizeof(Element);
        DS_READ_MATRIX_64x16_B8(lds_load_bytes, k_regs[0].i32x4, true /*transpose*/)
        DS_READ_MATRIX_64x16_B8(lds_load_bytes + 1024, k_regs[1].i32x4, true /*transpose*/)

#pragma unroll
        for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
            // Fragment 0 may proceed with one LDS read still pending; fragment 1 needs both.
            flash::wait_lds_data_arrived<false /*sync*/>(1 - min_tile_n);

            if constexpr (PrefetchVInQK) {
                if (min_tile_n == 1) {
                    fp8_kvcache_prefetch_v_one_gfx938<K_LOOP_COUNT, kBlockK, WARP_NUM, Element, 0>(
                        v_addr, v_lds, warp_id, v_row_stride, max_seq_k_offset);
                }
            }

            __builtin_amdgcn_sched_barrier(0);
            asm volatile("s_setprio 1");
#pragma unroll
            for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
                for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                    s_reg[0][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(
                        q_reg[min_tile_m][k_loop - 1].i8x8[min_tile_k],
                        k_regs[min_tile_n].i8x8[min_tile_k],
                        s_reg[0][min_tile_n * 2 + min_tile_m].f32);
                }
            }
            __builtin_amdgcn_sched_barrier(0);
            asm volatile("s_setprio 0");
        }
    }

    // ---- last K tile: nothing left to load, only compute ----
    {
        constexpr int k_loop = kHeadDim / 64;

        // With PrefetchVInQK one V load (load_id 0) is still allowed to be in flight.
        if constexpr (PrefetchVInQK) {
            flash::wait_buffer_data_arrived<false /*sync*/>(1);
        } else {
            flash::wait_buffer_data_arrived<false /*sync*/>(0);
        }
        stage_id ^= 1;

        union_vec16_fp8 k_regs[WARP_N / 16];
        int lds_load_bytes = (stage_id * WARP_NUM + warp_id) * 32 * 64 * sizeof(Element);
        DS_READ_MATRIX_64x16_B8(lds_load_bytes, k_regs[0].i32x4, true /*transpose*/)
        DS_READ_MATRIX_64x16_B8(lds_load_bytes + 1024, k_regs[1].i32x4, true /*transpose*/)

#pragma unroll
        for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
            flash::wait_lds_data_arrived<false /*sync*/>(1 - min_tile_n);

            if constexpr (PrefetchVInQK) {
                if (min_tile_n == 1) {
                    fp8_kvcache_prefetch_v_one_gfx938<K_LOOP_COUNT, kBlockK, WARP_NUM, Element, 1>(
                        v_addr, v_lds, warp_id, v_row_stride, max_seq_k_offset);
                }
            }

            __builtin_amdgcn_sched_barrier(0);
            asm volatile("s_setprio 1");
#pragma unroll
            for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
                for (int min_tile_k = 0; min_tile_k < 2; ++min_tile_k) {
                    s_reg[0][min_tile_n * 2 + min_tile_m].f32 = mmac_4interleave_b8<int8_t, ElementAccum>(
                        q_reg[min_tile_m][k_loop - 1].i8x8[min_tile_k],
                        k_regs[min_tile_n].i8x8[min_tile_k],
                        s_reg[0][min_tile_n * 2 + min_tile_m].f32);
                }
            }
            __builtin_amdgcn_sched_barrier(0);
            asm volatile("s_setprio 0");
        }
    }
}
csrc/flash_attn_hg/include/kvcache/gfx938/kvcache_softmax_gfx938.h
View file @
518a5f4d
...
...
@@ -65,6 +65,44 @@ inline __device__ void kvcache_apply_mask_causal_gfx938(DataType tensor[M_WARP_C
}
// Applies a sliding-window (local) mask to a per-lane score fragment by overwriting
// every element whose column lies outside the row's window with -INFINITY.
//
// Lane mapping used below (lane = threadIdx.x & 63):
//   row    = row_idx_offset_ + (lane & 15) + mi*32 + min_tile_m*16
//   column = col_idx_offset_ + (lane >> 4)*4 + ni*32 + min_tile_n*16 + vec_idx
// The element lives at tensor[mi + ni*M_WARP_COUNT][min_tile_n*2 + min_tile_m].f32[vec_idx].
//
// Window for a row, with logical_row = row / ngroups and logical_q = max_seqlen_q / ngroups:
//   left  = max(0,            logical_row + max_seqlen_k - logical_q - window_size_left)
//   right = min(max_seqlen_k, logical_row + max_seqlen_k - logical_q + window_size_right)
// Columns with col < left or col > right are masked.
//
// NOTE(review): the right bound is inclusive, so when it is clamped to max_seqlen_k the
// column col == max_seqlen_k is left unmasked - confirm this matches the caller's convention.
// NOTE(review): ngroups is used as a divisor and is not checked for zero here.
//
// Parameters:
//   tensor            - score fragments, modified in place.
//   col_idx_offset_   - column of the tile's first element before the lane offset is added.
//   max_seqlen_k      - upper clamp of the right bound.
//   row_idx_offset_   - row of the tile's first element before the lane offset is added.
//   max_seqlen_q      - divided by ngroups to form logical_q.
//   ngroups           - divisor applied to both the row index and max_seqlen_q.
//   window_size_left,
//   window_size_right - window extents on either side of the shifted diagonal.
template <typename DataType, int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT>
inline __device__ void kvcache_apply_mask_local_causal_gfx938(
    DataType tensor[M_WARP_COUNT * N_WARP_COUNT][4],
    const int col_idx_offset_,
    const int max_seqlen_k,
    const int row_idx_offset_,
    const int max_seqlen_q,
    const int ngroups,
    const int window_size_left,
    const int window_size_right)
{
    const int lane = threadIdx.x & 63;
    const int lane_first_row = row_idx_offset_ + (lane & 15);
    const int lane_first_col = col_idx_offset_ + (lane >> 4) * 4;

#pragma unroll
    for (int mi = 0; mi < M_WARP_COUNT; ++mi) {
#pragma unroll
        for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
            const int row_idx = lane_first_row + mi * 32 + min_tile_m * 16;

            // Column of the diagonal for this row after aligning the ends of Q and K.
            const int diag_col = row_idx / ngroups + max_seqlen_k - max_seqlen_q / ngroups;
            const int first_visible_col = max(0, diag_col - window_size_left);
            const int last_visible_col = min(max_seqlen_k, diag_col + window_size_right);

#pragma unroll
            for (int ni = 0; ni < N_WARP_COUNT; ++ni) {
#pragma unroll
                for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                    DataType &frag = tensor[mi + ni * M_WARP_COUNT][min_tile_n * 2 + min_tile_m];
                    const int frag_first_col = lane_first_col + ni * 32 + min_tile_n * 16;
#pragma unroll
                    for (int vec_idx = 0; vec_idx < 4; ++vec_idx) {
                        const int col_idx = frag_first_col + vec_idx;
                        if (col_idx < first_visible_col || col_idx > last_visible_col) {
                            frag.f32[vec_idx] = -INFINITY;
                        }
                    }
                }
            }
        }
    }
}
template
<
typename
DataType
,
int
M_WARP_COUNT
,
int
N_WARP_COUNT
,
int
M_MMAC_COUNT
>
inline
__device__
void
kvcache_apply_mask_causal_gfx938_mtp
(
DataType
tensor
[
M_WARP_COUNT
*
N_WARP_COUNT
][
4
],
const
int
col_idx_offset_
,
const
int
max_seqlen_k
,
const
int
row_idx_offset_
,
...
...
@@ -96,3 +134,26 @@ inline __device__ void kvcache_apply_mask_causal_gfx938_mtp(DataType tensor[M_WA
}
}
}
////////////////////////////////////////////////////////////////////////////////////////////////////
// FP8 MLS Paged Attention score helpers, >= gfx938
////////////////////////////////////////////////////////////////////////////////////////////////////
// Multiplies the score fragments by the packed descale factor qk_descale, in place.
//
// For every tile i in [0, M_WARP_COUNT*N_WARP_COUNT), every min_tile_m in
// [0, M_MMAC_COUNT) and both min_tile_n values, the two 64-bit halves u64[0] and u64[1]
// of tensor[i][min_tile_n*2 + min_tile_m] are each passed through
// __builtin_hcu_pk_mul_f32 together with qk_descale and written back.
// NOTE(review): each 64-bit half presumably holds two packed floats that are scaled by
// the two lanes of qk_descale - confirm against the builtin's definition.
//
// Parameters:
//   tensor     - score fragments to scale.
//   qk_descale - packed factor handed unchanged to every multiply.
template <typename DataType, int M_WARP_COUNT, int N_WARP_COUNT, int M_MMAC_COUNT>
inline __device__ void fp8_kvcache_apply_descale_gfx938(
    DataType tensor[M_WARP_COUNT * N_WARP_COUNT][4],
    const __float2 qk_descale)
{
#pragma unroll
    for (int tile = 0; tile < M_WARP_COUNT * N_WARP_COUNT; ++tile) {
#pragma unroll
        for (int min_tile_m = 0; min_tile_m < M_MMAC_COUNT; ++min_tile_m) {
#pragma unroll
            for (int min_tile_n = 0; min_tile_n < 2; ++min_tile_n) {
                const int frag = min_tile_n * 2 + min_tile_m;
                // Lower half first, then upper half.
#pragma unroll
                for (int half = 0; half < 2; ++half) {
                    tensor[tile][frag].u64[half] =
                        __builtin_hcu_pk_mul_f32(tensor[tile][frag].u64[half], qk_descale);
                }
            }
        }
    }
}
csrc/flash_attn_hg/include/kvcache/int8_kvcache_acco_reduce.h
View file @
518a5f4d
...
...
@@ -32,7 +32,7 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce(
// ####################################################################################################################################################
// 4 个 wave 共同参与 acc_o 在 LDS 中的相加
// 判断当前架构是否支持 pk_f32 指令
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
constexpr
bool
SUPPORT_PK_F32
=
true
;
#else
constexpr
bool
SUPPORT_PK_F32
=
false
;
...
...
@@ -73,13 +73,13 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce(
asm
volatile
(
"s_waitcnt lgkmcnt(6)
\n
"
);
__builtin_amdgcn_sched_barrier
(
0
);
acc_tmp_wave0
[
loop_id
].
u64
=
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave1
[
0
].
u64
);
acc_tmp_wave0
[
loop_id
].
u64
=
__builtin_
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave1
[
0
].
u64
);
asm
volatile
(
"s_waitcnt lgkmcnt(5)
\n
"
);
__builtin_amdgcn_sched_barrier
(
0
);
acc_tmp_wave0
[
loop_id
].
u64
=
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave2
[
0
].
u64
);
acc_tmp_wave0
[
loop_id
].
u64
=
__builtin_
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave2
[
0
].
u64
);
asm
volatile
(
"s_waitcnt lgkmcnt(4)
\n
"
);
__builtin_amdgcn_sched_barrier
(
0
);
acc_tmp_wave0
[
loop_id
].
u64
=
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave3
[
0
].
u64
);
acc_tmp_wave0
[
loop_id
].
u64
=
__builtin_
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave3
[
0
].
u64
);
}
// asm volatile("s_nop 8\n");
{
...
...
@@ -92,13 +92,13 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce(
asm
volatile
(
"s_waitcnt lgkmcnt(6)
\n
"
);
__builtin_amdgcn_sched_barrier
(
0
);
acc_tmp_wave0
[
loop_id
].
u64
=
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave1
[
1
].
u64
);
acc_tmp_wave0
[
loop_id
].
u64
=
__builtin_
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave1
[
1
].
u64
);
asm
volatile
(
"s_waitcnt lgkmcnt(5)
\n
"
);
__builtin_amdgcn_sched_barrier
(
0
);
acc_tmp_wave0
[
loop_id
].
u64
=
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave2
[
1
].
u64
);
acc_tmp_wave0
[
loop_id
].
u64
=
__builtin_
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave2
[
1
].
u64
);
asm
volatile
(
"s_waitcnt lgkmcnt(4)
\n
"
);
__builtin_amdgcn_sched_barrier
(
0
);
acc_tmp_wave0
[
loop_id
].
u64
=
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave3
[
1
].
u64
);
acc_tmp_wave0
[
loop_id
].
u64
=
__builtin_
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave3
[
1
].
u64
);
}
// asm volatile("s_nop 8\n");
{
...
...
@@ -111,13 +111,13 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce(
asm
volatile
(
"s_waitcnt lgkmcnt(6)
\n
"
);
__builtin_amdgcn_sched_barrier
(
0
);
acc_tmp_wave0
[
loop_id
].
u64
=
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave1
[
0
].
u64
);
acc_tmp_wave0
[
loop_id
].
u64
=
__builtin_
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave1
[
0
].
u64
);
asm
volatile
(
"s_waitcnt lgkmcnt(5)
\n
"
);
__builtin_amdgcn_sched_barrier
(
0
);
acc_tmp_wave0
[
loop_id
].
u64
=
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave2
[
0
].
u64
);
acc_tmp_wave0
[
loop_id
].
u64
=
__builtin_
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave2
[
0
].
u64
);
asm
volatile
(
"s_waitcnt lgkmcnt(4)
\n
"
);
__builtin_amdgcn_sched_barrier
(
0
);
acc_tmp_wave0
[
loop_id
].
u64
=
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave3
[
0
].
u64
);
acc_tmp_wave0
[
loop_id
].
u64
=
__builtin_
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave3
[
0
].
u64
);
}
// asm volatile("s_nop 8\n");
{
...
...
@@ -130,13 +130,13 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce(
asm
volatile
(
"s_waitcnt lgkmcnt(6)
\n
"
);
__builtin_amdgcn_sched_barrier
(
0
);
acc_tmp_wave0
[
loop_id
].
u64
=
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave1
[
1
].
u64
);
acc_tmp_wave0
[
loop_id
].
u64
=
__builtin_
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave1
[
1
].
u64
);
asm
volatile
(
"s_waitcnt lgkmcnt(5)
\n
"
);
__builtin_amdgcn_sched_barrier
(
0
);
acc_tmp_wave0
[
loop_id
].
u64
=
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave2
[
1
].
u64
);
acc_tmp_wave0
[
loop_id
].
u64
=
__builtin_
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave2
[
1
].
u64
);
asm
volatile
(
"s_waitcnt lgkmcnt(4)
\n
"
);
__builtin_amdgcn_sched_barrier
(
0
);
acc_tmp_wave0
[
loop_id
].
u64
=
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave3
[
1
].
u64
);
acc_tmp_wave0
[
loop_id
].
u64
=
__builtin_
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave3
[
1
].
u64
);
}
// asm volatile("s_nop 8\n");
{
...
...
@@ -149,13 +149,13 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce(
asm
volatile
(
"s_waitcnt lgkmcnt(6)
\n
"
);
__builtin_amdgcn_sched_barrier
(
0
);
acc_tmp_wave0
[
loop_id
].
u64
=
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave1
[
0
].
u64
);
acc_tmp_wave0
[
loop_id
].
u64
=
__builtin_
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave1
[
0
].
u64
);
asm
volatile
(
"s_waitcnt lgkmcnt(5)
\n
"
);
__builtin_amdgcn_sched_barrier
(
0
);
acc_tmp_wave0
[
loop_id
].
u64
=
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave2
[
0
].
u64
);
acc_tmp_wave0
[
loop_id
].
u64
=
__builtin_
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave2
[
0
].
u64
);
asm
volatile
(
"s_waitcnt lgkmcnt(4)
\n
"
);
__builtin_amdgcn_sched_barrier
(
0
);
acc_tmp_wave0
[
loop_id
].
u64
=
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave3
[
0
].
u64
);
acc_tmp_wave0
[
loop_id
].
u64
=
__builtin_
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave3
[
0
].
u64
);
}
// asm volatile("s_nop 8\n");
{
...
...
@@ -168,13 +168,13 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce(
asm
volatile
(
"s_waitcnt lgkmcnt(6)
\n
"
);
__builtin_amdgcn_sched_barrier
(
0
);
acc_tmp_wave0
[
loop_id
].
u64
=
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave1
[
1
].
u64
);
acc_tmp_wave0
[
loop_id
].
u64
=
__builtin_
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave1
[
1
].
u64
);
asm
volatile
(
"s_waitcnt lgkmcnt(5)
\n
"
);
__builtin_amdgcn_sched_barrier
(
0
);
acc_tmp_wave0
[
loop_id
].
u64
=
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave2
[
1
].
u64
);
acc_tmp_wave0
[
loop_id
].
u64
=
__builtin_
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave2
[
1
].
u64
);
asm
volatile
(
"s_waitcnt lgkmcnt(4)
\n
"
);
__builtin_amdgcn_sched_barrier
(
0
);
acc_tmp_wave0
[
loop_id
].
u64
=
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave3
[
1
].
u64
);
acc_tmp_wave0
[
loop_id
].
u64
=
__builtin_
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave3
[
1
].
u64
);
}
// asm volatile("s_nop 8\n");
{
...
...
@@ -187,13 +187,13 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce(
asm
volatile
(
"s_waitcnt lgkmcnt(6)
\n
"
);
__builtin_amdgcn_sched_barrier
(
0
);
acc_tmp_wave0
[
loop_id
].
u64
=
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave1
[
0
].
u64
);
acc_tmp_wave0
[
loop_id
].
u64
=
__builtin_
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave1
[
0
].
u64
);
asm
volatile
(
"s_waitcnt lgkmcnt(5)
\n
"
);
__builtin_amdgcn_sched_barrier
(
0
);
acc_tmp_wave0
[
loop_id
].
u64
=
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave2
[
0
].
u64
);
acc_tmp_wave0
[
loop_id
].
u64
=
__builtin_
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave2
[
0
].
u64
);
asm
volatile
(
"s_waitcnt lgkmcnt(4)
\n
"
);
__builtin_amdgcn_sched_barrier
(
0
);
acc_tmp_wave0
[
loop_id
].
u64
=
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave3
[
0
].
u64
);
acc_tmp_wave0
[
loop_id
].
u64
=
__builtin_
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave3
[
0
].
u64
);
}
// 先写一部分数据到 lds
for
(
int
loop_id
=
0
;
loop_id
<
7
;
++
loop_id
)
{
...
...
@@ -206,13 +206,13 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce(
asm
volatile
(
"s_waitcnt lgkmcnt(2)
\n
"
);
__builtin_amdgcn_sched_barrier
(
0
);
acc_tmp_wave0
[
loop_id
].
u64
=
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave1
[
1
].
u64
);
acc_tmp_wave0
[
loop_id
].
u64
=
__builtin_
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave1
[
1
].
u64
);
asm
volatile
(
"s_waitcnt lgkmcnt(1)
\n
"
);
__builtin_amdgcn_sched_barrier
(
0
);
acc_tmp_wave0
[
loop_id
].
u64
=
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave2
[
1
].
u64
);
acc_tmp_wave0
[
loop_id
].
u64
=
__builtin_
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave2
[
1
].
u64
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n
"
);
__builtin_amdgcn_sched_barrier
(
0
);
acc_tmp_wave0
[
loop_id
].
u64
=
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave3
[
1
].
u64
);
acc_tmp_wave0
[
loop_id
].
u64
=
__builtin_
hcu_pk_add_f32
(
acc_tmp_wave0
[
loop_id
].
u64
,
acc_tmp_wave3
[
1
].
u64
);
__builtin_amdgcn_sched_barrier
(
0
);
acc_o_lds
[
lds_offset
[
loop_id
]]
=
acc_tmp_wave0
[
loop_id
].
f32
[
0
];
acc_o_lds
[
lds_offset
[
loop_id
]
+
16
]
=
acc_tmp_wave0
[
loop_id
].
f32
[
1
];
...
...
@@ -233,26 +233,23 @@ __forceinline__ __device__ void int8_kvcache_acco_reduce(
union_vec2_fp32
acc_tmp
;
int
lds_offset0
=
min_tile_m
*
__kHeadDim
+
q_seq_idx
*
2
*
__kHeadDim
+
h_idx
*
kBlockK
+
k_idx
*
32
+
0
*
16
+
(
lane_id
>>
4
)
*
4
+
WARP_ID
;
int
lds_offset1
=
min_tile_m
*
__kHeadDim
+
q_seq_idx
*
2
*
__kHeadDim
+
h_idx
*
kBlockK
+
k_idx
*
32
+
1
*
16
+
(
lane_id
>>
4
)
*
4
+
WARP_ID
;
inlineasm_fa
_ds_read2_
b
32
(
acc_o_lds
,
lds_offset0
,
acc_tmp
.
u64
,
0
,
16
);
acc_tmp
.
u64
=
__builtin_hcu
_ds_read2_
f
32
(
(
__attribute__
((
address_space
(
3
)))
float
*
)
acc_o_lds
+
lds_offset0
,
0
,
16
,
false
);
// acc_tmp.f32[0] = acc_o_lds[lds_offset0];
// acc_tmp.f32[1] = acc_o_lds[lds_offset1];
union_vec2_fp32
acc_tmp_wave1
;
inlineasm_fa_ds_read2_b32
(
acc_o_lds
,
lds_offset0
+
1
*
EVEN_REUSE_KV_TIMES
*
__kHeadDim
,
acc_tmp_wave1
.
u64
,
0
,
16
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n
"
);
acc_tmp_wave1
.
u64
=
__builtin_hcu_ds_read2_f32
((
__attribute__
((
address_space
(
3
)))
float
*
)
acc_o_lds
+
lds_offset0
+
1
*
EVEN_REUSE_KV_TIMES
*
__kHeadDim
,
0
,
16
,
false
);
// acc_tmp_wave1.f32[0] = acc_o_lds[lds_offset0 + 1*EVEN_REUSE_KV_TIMES*__kHeadDim];
// acc_tmp_wave1.f32[1] = acc_o_lds[lds_offset1 + 1*EVEN_REUSE_KV_TIMES*__kHeadDim];
acc_tmp
.
f32
[
0
]
+=
acc_tmp_wave1
.
f32
[
0
];
acc_tmp
.
f32
[
1
]
+=
acc_tmp_wave1
.
f32
[
1
];
union_vec2_fp32
acc_tmp_wave2
;
inlineasm_fa_ds_read2_b32
(
acc_o_lds
,
lds_offset0
+
2
*
EVEN_REUSE_KV_TIMES
*
__kHeadDim
,
acc_tmp_wave2
.
u64
,
0
,
16
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n
"
);
acc_tmp_wave2
.
u64
=
__builtin_hcu_ds_read2_f32
((
__attribute__
((
address_space
(
3
)))
float
*
)
acc_o_lds
+
lds_offset0
+
2
*
EVEN_REUSE_KV_TIMES
*
__kHeadDim
,
0
,
16
,
false
);
// acc_tmp_wave2.f32[0] = acc_o_lds[lds_offset0 + 2*EVEN_REUSE_KV_TIMES*__kHeadDim];
// acc_tmp_wave2.f32[1] = acc_o_lds[lds_offset1 + 2*EVEN_REUSE_KV_TIMES*__kHeadDim];
acc_tmp
.
f32
[
0
]
+=
acc_tmp_wave2
.
f32
[
0
];
acc_tmp
.
f32
[
1
]
+=
acc_tmp_wave2
.
f32
[
1
];
union_vec2_fp32
acc_tmp_wave3
;
inlineasm_fa_ds_read2_b32
(
acc_o_lds
,
lds_offset0
+
3
*
EVEN_REUSE_KV_TIMES
*
__kHeadDim
,
acc_tmp_wave3
.
u64
,
0
,
16
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n
"
);
acc_tmp_wave3
.
u64
=
__builtin_hcu_ds_read2_f32
((
__attribute__
((
address_space
(
3
)))
float
*
)
acc_o_lds
+
lds_offset0
+
3
*
EVEN_REUSE_KV_TIMES
*
__kHeadDim
,
0
,
16
,
false
);
// acc_tmp_wave3.f32[0] = acc_o_lds[lds_offset0 + 3*EVEN_REUSE_KV_TIMES*__kHeadDim];
// acc_tmp_wave3.f32[1] = acc_o_lds[lds_offset1 + 3*EVEN_REUSE_KV_TIMES*__kHeadDim];
acc_tmp
.
f32
[
0
]
+=
acc_tmp_wave3
.
f32
[
0
];
...
...
csrc/flash_attn_hg/include/kvcache/int8_kvcache_qk_gemm_prefetch_v_3stage.h
View file @
518a5f4d
...
...
@@ -69,7 +69,7 @@ __forceinline__ __device__ void int8_kvcache_qk_gemm_prefetch_v_3stage(
auto
BUFFER_LOAD_FUNC
=
&
inline_buffer_load_dword_lds
<
Element_q
,
2
>
;
// load 指令发下去之后, 先做一些初始化运算
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
if
constexpr
(
M_MMAC_COUNT
==
1
)
{
inline_vgpr4_init_zero_1x2x4
(
s_reg
);
}
else
{
...
...
csrc/flash_attn_hg/include/kvcache/int8_kvcache_softmax.h
View file @
518a5f4d
...
...
@@ -228,7 +228,7 @@ __device__ inline void int8_kvcache_thread_reduce_sum(const DataType0 tensor[(WA
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
(
WARP_M
/
32
);
m_idx
++
)
{
// 对于 gfx936 及以上的架构, 可以使用 v_pk_add_f32
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
summary
[
m_idx
*
2
].
u64
=
0x0
;
// 可以更狠一点, 直接初始化成第一个 additem_pair, 但是貌似容易导致编译器出问题, 影响不大, 可以不加
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
(
WARP_N
/
32
);
n_idx
++
)
{
...
...
@@ -236,7 +236,7 @@ __device__ inline void int8_kvcache_thread_reduce_sum(const DataType0 tensor[(WA
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
min_tile_n
++
)
{
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
vec_idx
++
)
{
__float2
additem_pair
=
{
tensor
[
m_idx
+
n_idx
*
(
WARP_M
/
32
)][
min_tile_n
*
2
].
f32
[
vec_idx
],
tensor
[
m_idx
+
n_idx
*
(
WARP_M
/
32
)][
min_tile_n
*
2
+
1
].
f32
[
vec_idx
]};
summary
[
m_idx
*
2
].
u64
=
hcu_pk_add_f32
(
summary
[
m_idx
*
2
].
u64
=
__builtin_
hcu_pk_add_f32
(
summary
[
m_idx
*
2
].
u64
,
additem_pair
);
...
...
@@ -262,7 +262,7 @@ __device__ inline void int8_kvcache_thread_reduce_sum(const DataType0 tensor[(WA
}
else
{
#pragma unroll
for
(
int
m_idx
=
0
;
m_idx
<
(
WARP_M
/
32
);
m_idx
++
)
{
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
summary_cur
[
m_idx
*
2
].
u64
=
summary
[
m_idx
*
2
].
u64
;
#pragma unroll
for
(
int
n_idx
=
0
;
n_idx
<
(
WARP_N
/
32
);
n_idx
++
)
{
...
...
@@ -270,7 +270,7 @@ __device__ inline void int8_kvcache_thread_reduce_sum(const DataType0 tensor[(WA
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
min_tile_n
++
)
{
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
vec_idx
++
)
{
// mmac min_tile is 16*16, a warp is 64 thread
__float2
additem_pair
=
{
tensor
[
m_idx
+
n_idx
*
(
WARP_M
/
32
)][
min_tile_n
*
2
].
f32
[
vec_idx
],
tensor
[
m_idx
+
n_idx
*
(
WARP_M
/
32
)][
min_tile_n
*
2
+
1
].
f32
[
vec_idx
]};
summary_cur
[
m_idx
*
2
].
u64
=
hcu_pk_add_f32
(
summary_cur
[
m_idx
*
2
].
u64
=
__builtin_
hcu_pk_add_f32
(
summary_cur
[
m_idx
*
2
].
u64
,
additem_pair
);
...
...
@@ -370,15 +370,14 @@ inline __device__ void int8_kvcache_scale_apply_exp2(DataType0 tensor[(WARP_M/32
// instruction instead of fadd and fmul separately.
// min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
min_tile_n
++
)
{
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
for
(
int
vec_idx
=
0
;
vec_idx
<
2
;
vec_idx
++
)
{
tensor
[
mi
+
ni
*
(
WARP_M
/
32
)][
min_tile_n
*
2
+
min_tile_m
].
u64
[
vec_idx
]
=
hcu_pk_fma_f32
(
tensor
[
mi
+
ni
*
(
WARP_M
/
32
)][
min_tile_n
*
2
+
min_tile_m
].
u64
[
vec_idx
]
=
__builtin_
hcu_pk_fma_f32
(
tensor
[
mi
+
ni
*
(
WARP_M
/
32
)][
min_tile_n
*
2
+
min_tile_m
].
u64
[
vec_idx
],
scale_pair
,
neg_max_scaled_pair
);
}
asm
volatile
(
"s_nop 0"
:::
"memory"
);
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
vec_idx
++
)
{
tensor
[
mi
+
ni
*
(
WARP_M
/
32
)][
min_tile_n
*
2
+
min_tile_m
].
f32
[
vec_idx
]
=
__llvm_exp2_f32
(
tensor
[
mi
+
ni
*
(
WARP_M
/
32
)][
min_tile_n
*
2
+
min_tile_m
].
f32
[
vec_idx
]);
}
...
...
@@ -479,10 +478,10 @@ inline __device__ void int8_kvcache_softmax_rescale_o(DataType0 scores[(WARP_N/3
// min tile is 32*32, mmac size is 16x16x16,so min_tile_n=32/16, min_tile_m=32/16
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
min_tile_n
++
)
{
// 936 及之后的架构有 pk_mul 指令
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
#pragma unroll
for
(
int
vec_idx
=
0
;
vec_idx
<
2
;
vec_idx
++
)
{
acc_o
[
pv_n_loop
*
((
WARP_M
/
32
)
*
(
kBlockK
/
32
))
+
(
mi
+
ni
*
(
WARP_M
/
32
))][
min_tile_n
*
2
+
min_tile_m
].
u64
[
vec_idx
]
=
hcu_pk_mul_f32
(
acc_o
[
pv_n_loop
*
((
WARP_M
/
32
)
*
(
kBlockK
/
32
))
+
(
mi
+
ni
*
(
WARP_M
/
32
))][
min_tile_n
*
2
+
min_tile_m
].
u64
[
vec_idx
]
=
__builtin_
hcu_pk_mul_f32
(
acc_o
[
pv_n_loop
*
((
WARP_M
/
32
)
*
(
kBlockK
/
32
))
+
(
mi
+
ni
*
(
WARP_M
/
32
))][
min_tile_n
*
2
+
min_tile_m
].
u64
[
vec_idx
],
scores_scale_pair
);
...
...
@@ -534,8 +533,8 @@ inline __device__ void int8_kvcache_softmax_rescale_o(DataType0 scores[(WARP_N/3
#pragma unroll
for
(
int
warp_loop
=
1
;
warp_loop
<
WARP_NUM
;
warp_loop
++
)
{
__float2
other_warp_sum
=
*
(
__float2
*
)(
sum_lds
+
warp_loop
*
WARP_M
+
mi
*
32
+
lane_id
*
2
);
#if defined(__gfx936__) || defined(__gfx938__)
cur_wave_sum
=
hcu_pk_add_f32
(
cur_wave_sum
,
other_warp_sum
);
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
cur_wave_sum
=
__builtin_
hcu_pk_add_f32
(
cur_wave_sum
,
other_warp_sum
);
#else
cur_wave_sum
[
0
]
+=
other_warp_sum
[
0
];
cur_wave_sum
[
1
]
+=
other_warp_sum
[
1
];
...
...
@@ -559,8 +558,8 @@ inline __device__ void int8_kvcache_softmax_rescale_o(DataType0 scores[(WARP_N/3
}
for
(
int
mi
=
0
;
mi
<
(
WARP_M
/
32
);
++
mi
)
{
#if defined(__gfx936__) || defined(__gfx938__)
scores_sum
[
mi
].
u64
=
hcu_pk_add_f32
(
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
scores_sum
[
mi
].
u64
=
__builtin_
hcu_pk_add_f32
(
scores_sum
[
mi
].
u64
,
scores_sum_cur
[
mi
].
u64
);
...
...
csrc/flash_attn_hg/include/kvcache/kvcache_acco_reduce.h
View file @
518a5f4d
...
...
@@ -38,20 +38,17 @@ __forceinline__ __device__ void kvcache_acco_reduce(
union_vec2_fp32
acc_tmp
;
int
lds_offset0
=
min_tile_m
*
__kHeadDim
+
q_seq_idx
*
2
*
__kHeadDim
+
h_idx
*
kBlockK
+
k_idx
*
32
+
0
*
16
+
(
lane_id
>>
4
)
*
4
+
WARP_ID
;
int
lds_offset1
=
min_tile_m
*
__kHeadDim
+
q_seq_idx
*
2
*
__kHeadDim
+
h_idx
*
kBlockK
+
k_idx
*
32
+
1
*
16
+
(
lane_id
>>
4
)
*
4
+
WARP_ID
;
inlineasm_fa
_ds_read2_
b
32
(
acc_o_lds
,
lds_offset0
,
acc_tmp
.
u64
,
0
,
16
);
acc_tmp
.
u64
=
__builtin_hcu
_ds_read2_
f
32
(
(
__attribute__
((
address_space
(
3
)))
float
*
)
acc_o_lds
+
lds_offset0
,
0
,
16
,
false
);
union_vec2_fp32
acc_tmp_wave1
;
inlineasm_fa_ds_read2_b32
(
acc_o_lds
,
lds_offset0
+
1
*
EVEN_REUSE_KV_TIMES
*
__kHeadDim
,
acc_tmp_wave1
.
u64
,
0
,
16
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n
"
);
acc_tmp_wave1
.
u64
=
__builtin_hcu_ds_read2_f32
((
__attribute__
((
address_space
(
3
)))
float
*
)
acc_o_lds
+
lds_offset0
+
1
*
EVEN_REUSE_KV_TIMES
*
__kHeadDim
,
0
,
16
,
false
);
acc_tmp
.
f32
[
0
]
+=
acc_tmp_wave1
.
f32
[
0
];
acc_tmp
.
f32
[
1
]
+=
acc_tmp_wave1
.
f32
[
1
];
union_vec2_fp32
acc_tmp_wave2
;
inlineasm_fa_ds_read2_b32
(
acc_o_lds
,
lds_offset0
+
2
*
EVEN_REUSE_KV_TIMES
*
__kHeadDim
,
acc_tmp_wave2
.
u64
,
0
,
16
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n
"
);
acc_tmp_wave2
.
u64
=
__builtin_hcu_ds_read2_f32
((
__attribute__
((
address_space
(
3
)))
float
*
)
acc_o_lds
+
lds_offset0
+
2
*
EVEN_REUSE_KV_TIMES
*
__kHeadDim
,
0
,
16
,
false
);
acc_tmp
.
f32
[
0
]
+=
acc_tmp_wave2
.
f32
[
0
];
acc_tmp
.
f32
[
1
]
+=
acc_tmp_wave2
.
f32
[
1
];
union_vec2_fp32
acc_tmp_wave3
;
inlineasm_fa_ds_read2_b32
(
acc_o_lds
,
lds_offset0
+
3
*
EVEN_REUSE_KV_TIMES
*
__kHeadDim
,
acc_tmp_wave3
.
u64
,
0
,
16
);
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n
"
);
acc_tmp_wave3
.
u64
=
__builtin_hcu_ds_read2_f32
((
__attribute__
((
address_space
(
3
)))
float
*
)
acc_o_lds
+
lds_offset0
+
3
*
EVEN_REUSE_KV_TIMES
*
__kHeadDim
,
0
,
16
,
false
);
acc_tmp
.
f32
[
0
]
+=
acc_tmp_wave3
.
f32
[
0
];
acc_tmp
.
f32
[
1
]
+=
acc_tmp_wave3
.
f32
[
1
];
// ds_write2_b32
...
...
csrc/flash_attn_hg/include/kvcache/kvcache_acco_reduce_tile16x32.h
View file @
518a5f4d
...
...
@@ -2,15 +2,15 @@
#include "numeric_types.h"
template
<
int
REUSE_KV_TIMES
,
int
K_LOOP_COUNT
,
int
K_WARP_COUNT
,
int
M_WARP_COUNT
,
int
M_MMAC_COUNT
,
int
WARP_NUM
,
int
Padding
,
typename
ElementAccum
>
template
<
int
K_LOOP_COUNT
,
int
K_WARP_COUNT
,
int
M_WARP_COUNT
,
int
M_MMAC_COUNT
,
int
WARP_NUM
,
int
Padding
,
typename
ElementAccum
>
__forceinline__
__device__
void
kvcache_acco_reduce_tile16x32
(
vec4_Accum
<
ElementAccum
>
acc_o
[
K_LOOP_COUNT
*
M_WARP_COUNT
*
K_WARP_COUNT
][
4
],
ElementAccum
*
acc_o_lds
,
int
seqlen_q
,
int
warp_id
,
int
lane_id
)
{
#if defined(__gfx938__)
constexpr
int
OPT_FOR_HDIM128
=
bool
(
WARP_NUM
==
4
and
M_MMAC_COUNT
==
1
and
Padding
==
0
);
// Specialized optimizatio for headdim 128
#if defined(__gfx938__)
|| defined(__gfx946__)
constexpr
int
OPT_FOR_HDIM128
=
bool
(
WARP_NUM
==
4
and
M_MMAC_COUNT
==
1
and
Padding
==
0
and
K_LOOP_COUNT
==
WARP_NUM
and
K_WARP_COUNT
==
1
and
M_WARP_COUNT
==
1
);
// Specialized optimization for headdim 128
#else
constexpr
int
OPT_FOR_HDIM128
=
false
;
// keep same as the original for archs <= gfx936
#endif
...
...
@@ -78,8 +78,7 @@ __forceinline__ __device__ void kvcache_acco_reduce_tile16x32(
}
else
{
constexpr
int
kBlockK
=
K_WARP_COUNT
*
32
+
Padding
;
// when REUSE_KV is not templated, compute max reuse times
int
EVEN_REUSE_KV_TIMES
=
(
REUSE_KV_TIMES
>
0
)
?
((
REUSE_KV_TIMES
+
1
)
/
2
)
*
2
:
((
seqlen_q
+
1
)
/
2
)
*
2
;
int
EVEN_REUSE_KV_TIMES
=
((
seqlen_q
+
1
)
/
2
)
*
2
;
int
q_seq_idx
=
(
lane_id
&
15
);
if
(
q_seq_idx
<
EVEN_REUSE_KV_TIMES
)
{
...
...
@@ -90,7 +89,8 @@ __forceinline__ __device__ void kvcache_acco_reduce_tile16x32(
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
// 一个 wave 共同持有 seqlen_q x kHeadDim 个 Half, 但为了节省 lds 用量, 每次只 reduce seqlen_q x kBlockK 个 Half
int
lds_offset
=
(
warp_id
*
EVEN_REUSE_KV_TIMES
*
M_MMAC_COUNT
+
q_seq_idx
+
min_tile_m
*
16
)
*
kBlockK
+
k_idx
*
32
+
min_tile_n
*
16
+
(
lane_id
>>
4
/*0~3*/
)
*
4
/*0~15*/
;
*
(
vec4_fp32
*
)(
acc_o_lds
+
lds_offset
)
=
acc_o
[
h_idx
*
(
K_WARP_COUNT
+
k_idx
)
*
M_WARP_COUNT
][
min_tile_n
*
2
+
min_tile_m
].
f32
;
int
tile_32x32_id
=
h_idx
*
M_WARP_COUNT
*
K_WARP_COUNT
+
k_idx
*
M_WARP_COUNT
;
*
(
vec4_fp32
*
)(
acc_o_lds
+
lds_offset
)
=
acc_o
[
tile_32x32_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
;
}
}
}
...
...
@@ -135,7 +135,8 @@ __forceinline__ __device__ void kvcache_acco_reduce_tile16x32(
for
(
int
min_tile_m
=
0
;
min_tile_m
<
M_MMAC_COUNT
;
++
min_tile_m
)
{
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
int
lds_offset
=
(
q_seq_idx
+
min_tile_m
*
16
)
*
kBlockK
+
k_idx
*
32
+
min_tile_n
*
16
+
(
lane_id
>>
4
)
*
4
;
acc_o
[
h_idx
*
(
K_WARP_COUNT
+
k_idx
)
*
M_WARP_COUNT
][
min_tile_n
*
2
+
min_tile_m
].
f32
=
*
(
vec4_fp32
*
)(
acc_o_lds
+
lds_offset
);
int
tile_32x32_id
=
h_idx
*
M_WARP_COUNT
*
K_WARP_COUNT
+
k_idx
*
M_WARP_COUNT
;
acc_o
[
tile_32x32_id
][
min_tile_n
*
2
+
min_tile_m
].
f32
=
*
(
vec4_fp32
*
)(
acc_o_lds
+
lds_offset
);
}
}
}
...
...
csrc/flash_attn_hg/include/kvcache/kvcache_epilogue.h
View file @
518a5f4d
...
...
@@ -21,9 +21,9 @@ __forceinline__ __device__ void kvcache_epilugue_rescale_acco(
for
(
int
min_tile_n
=
0
;
min_tile_n
<
2
;
++
min_tile_n
)
{
int
mmac_id
=
min_tile_n
*
2
+
min_tile_m
;
int
tile_32x32_id
=
pv_n_loop
*
M_WARP_COUNT
*
K_WARP_COUNT
+
(
ni
*
M_WARP_COUNT
+
mi
);
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
for
(
int
vec_id
=
0
;
vec_id
<
2
;
++
vec_id
)
{
acc_o
[
tile_32x32_id
][
mmac_id
].
u64
[
vec_id
]
=
hcu_pk_mul_f32
(
acc_o
[
tile_32x32_id
][
mmac_id
].
u64
[
vec_id
]
=
__builtin_
hcu_pk_mul_f32
(
acc_o
[
tile_32x32_id
][
mmac_id
].
u64
[
vec_id
],
scale_pair
);
...
...
@@ -54,12 +54,7 @@ __forceinline__ __device__ void kvcache_epilogue_store_max_sum(
int
headdim_split_id
,
int
seqlen_q_limit
)
{
#ifdef FA_DEBUG_SUM_MAX
constexpr
bool
ALLOW_WRITE_SUM_MAX
=
true
;
#else
constexpr
bool
ALLOW_WRITE_SUM_MAX
=
false
;
#endif
if
constexpr
(
Split
or
ALLOW_WRITE_SUM_MAX
)
{
if
constexpr
(
Split
)
{
bool
write_ok
=
Is_16x32
?
(
thread_id
<
16
and
headdim_split_id
==
0
)
:
thread_id
<
16
;
if
(
write_ok
)
{
// 0-15 号线程储存有 max/sum 的数据, 16~31/32~47/48~63 号线程也含有, 但只需要写一次即可
#pragma unroll
...
...
@@ -96,12 +91,7 @@ __forceinline__ __device__ void kvcache_varlen_epilogue_store_max_sum(
int
total_q
,
int
ngroups
)
{
#ifdef FA_DEBUG_SUM_MAX
constexpr
bool
ALLOW_WRITE_SUM_MAX
=
true
;
#else
constexpr
bool
ALLOW_WRITE_SUM_MAX
=
false
;
#endif
if
constexpr
(
Split
or
ALLOW_WRITE_SUM_MAX
)
{
if
constexpr
(
Split
)
{
bool
write_ok
=
Is_16x32
?
(
thread_id
<
16
and
headdim_split_id
==
0
)
:
thread_id
<
16
;
if
(
write_ok
)
{
// 0-15 号线程储存有 max/sum 的数据, 16~31/32~47/48~63 号线程也含有, 但只需要写一次即可
#pragma unroll
...
...
@@ -191,7 +181,7 @@ __forceinline__ __device__ void kvcache_epilogue_store_output(
int
seqlen_q_idx
=
m_block
*
kBlockM
+
warp_m_idx
*
32
+
(
Is_16x32
?
pv_lane_seq_idx
+
min_tile_m
*
16
:
pv_lane_seq_idx
*
2
+
min_tile_m
);
if
(
seqlen_q_idx
<
params
.
seqlen_q
)
{
if
constexpr
(
WARP_NUM
==
4
)
{
// for 4 waves, the stores can be done together (about 4% performance gain)
#if defined(__gfx938__)
#if defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
int
vec_index
=
warp_id
;
int64_t
pv_global_addr
=
seqlen_q_idx
*
output_seqlen_stride
+
k_loop
*
kBlockK
+
k_tile_idx
*
32
+
vec_index
*
8
+
pv_lane_head_dim_idx
*
2
;
vec2_Element
<
SplitkvAccumType
>
result
=
DownCastPairNoPack
<
ElementAccum
,
SplitkvAccumType
>
(
acc_o
[
tile_32x32_id
][
min_tile_m
+
0
*
2
].
f32
[
vec_index
],
acc_o
[
tile_32x32_id
][
min_tile_m
+
1
*
2
].
f32
[
vec_index
]);
...
...
@@ -264,7 +254,7 @@ __forceinline__ __device__ void kvcache_varlen_epilogue_store_output(
int
seqlen_q_idx
=
m_block
*
kBlockM
+
warp_m_idx
*
32
+
(
Is_16x32
?
pv_lane_seq_idx
+
min_tile_m
*
16
:
pv_lane_seq_idx
*
2
+
min_tile_m
);
if
(
seqlen_q_idx
<
actual_seqlen_q
)
{
if
constexpr
(
WARP_NUM
==
4
)
{
// for 4 waves, the stores can be done together (about 4% performance gain)
#if defined(__gfx938__)
#if defined(__gfx938__)
|| defined(__gfx946__) || defined(__gfx92a__)
int
vec_index
=
warp_id
;
int
true_seqlen_q
=
seqlen_q_idx
/
params
.
ngroups
;
int
true_group_id
=
seqlen_q_idx
%
params
.
ngroups
;
...
...
csrc/flash_attn_hg/include/kvcache/kvcache_pv_gemm_prefetch_k_tile16x32.h
View file @
518a5f4d
...
...
@@ -74,7 +74,7 @@ __forceinline__ __device__ void kvcache_pv_gemm_prefetch_k_tile16x32(
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
++
vec_idx
)
{
for
(
int
seq_idx
=
0
;
seq_idx
<
PV_K_WARP_COUNT
;
++
seq_idx
)
{
for
(
int
head_dim_idx
=
0
;
head_dim_idx
<
PV_N_WARP_COUNT
;
++
head_dim_idx
)
{
precompute_v_lds_offset
[
vec_idx
]
=
reinterpret_cast
<
size_t
>
(
v_lds_v2fp16
)
+
(
(
stage_id
*
WARP_K
*
kBlockN
+
seq_idx
*
32
*
kBlockN
+
head_dim_idx
*
32
*
32
+
vec_idx
*
8
*
32
+
v_ds_read_offset
)
/
2
)
*
4
;
precompute_v_lds_offset
[
vec_idx
]
=
(
stage_id
*
WARP_K
*
kBlockN
+
seq_idx
*
32
*
kBlockN
+
head_dim_idx
*
32
*
32
+
vec_idx
*
8
*
32
+
v_ds_read_offset
)
/
2
;
}
}
}
...
...
@@ -97,7 +97,7 @@ __forceinline__ __device__ void kvcache_pv_gemm_prefetch_k_tile16x32(
for
(
int
seq_idx
=
0
;
seq_idx
<
PV_K_WARP_COUNT
;
++
seq_idx
)
{
#pragma unroll
for
(
int
head_dim_idx
=
0
;
head_dim_idx
<
PV_N_WARP_COUNT
;
++
head_dim_idx
)
{
inline_ds_read2_b32_no_wait_bytes
(
precompute_v_lds_offset
[
vec_idx
],
v_reg
[
stage_id
*
PV_K_WARP_COUNT
*
PV_N_WARP_COUNT
+
(
head_dim_idx
*
PV_K_WARP_COUNT
+
seq_idx
)][
vec_idx
].
u64
,
NEXT_DWORD_OFFSET
);
v_reg
[
stage_id
*
PV_K_WARP_COUNT
*
PV_N_WARP_COUNT
+
(
head_dim_idx
*
PV_K_WARP_COUNT
+
seq_idx
)][
vec_idx
].
u64
=
__builtin_hcu_ds_read2_f32
((
__attribute__
((
address_space
(
3
)))
float
*
)
v_lds_v2fp16
+
precompute_v_lds_offset
[
vec_idx
],
0
,
NEXT_DWORD_OFFSET
,
false
);
}
}
}
...
...
@@ -168,7 +168,7 @@ __forceinline__ __device__ void kvcache_pv_gemm_prefetch_k_tile16x32(
for
(
int
vec_idx
=
0
;
vec_idx
<
4
;
++
vec_idx
)
{
for
(
int
seq_idx
=
0
;
seq_idx
<
PV_K_WARP_COUNT
;
++
seq_idx
)
{
for
(
int
head_dim_idx
=
0
;
head_dim_idx
<
PV_N_WARP_COUNT
;
++
head_dim_idx
)
{
precompute_v_lds_offset
[
vec_idx
]
=
reinterpret_cast
<
size_t
>
(
v_lds_v2fp16
)
+
(
(
stage_id
*
WARP_K
*
kBlockN
+
(
seq_idx
*
32
*
kBlockN
)
+
head_dim_idx
*
32
*
32
+
vec_idx
*
8
*
32
+
v_ds_read_offset
)
/
2
)
*
4
;
precompute_v_lds_offset
[
vec_idx
]
=
(
stage_id
*
WARP_K
*
kBlockN
+
(
seq_idx
*
32
*
kBlockN
)
+
head_dim_idx
*
32
*
32
+
vec_idx
*
8
*
32
+
v_ds_read_offset
)
/
2
;
}
}
}
...
...
@@ -191,7 +191,7 @@ __forceinline__ __device__ void kvcache_pv_gemm_prefetch_k_tile16x32(
for
(
int
seq_idx
=
0
;
seq_idx
<
PV_K_WARP_COUNT
;
++
seq_idx
)
{
#pragma unroll
for
(
int
head_dim_idx
=
0
;
head_dim_idx
<
PV_N_WARP_COUNT
;
++
head_dim_idx
)
{
inline_ds_read2_b32_no_wait_bytes
(
precompute_v_lds_offset
[
vec_idx
],
v_reg
[
stage_id
*
PV_K_WARP_COUNT
*
PV_N_WARP_COUNT
+
(
head_dim_idx
*
PV_K_WARP_COUNT
+
seq_idx
)][
vec_idx
].
u64
,
NEXT_DWORD_OFFSET
);
v_reg
[
stage_id
*
PV_K_WARP_COUNT
*
PV_N_WARP_COUNT
+
(
head_dim_idx
*
PV_K_WARP_COUNT
+
seq_idx
)][
vec_idx
].
u64
=
__builtin_hcu_ds_read2_f32
((
__attribute__
((
address_space
(
3
)))
float
*
)
v_lds_v2fp16
+
precompute_v_lds_offset
[
vec_idx
],
0
,
NEXT_DWORD_OFFSET
,
false
);
}
}
}
...
...
csrc/flash_attn_hg/include/kvcache/kvcache_qk_gemm_prefetch_v.h
View file @
518a5f4d
...
...
@@ -39,7 +39,7 @@ __forceinline__ __device__ void kvcache_qk_gemm_prefetch_v(
int
stage_id
=
0
;
// load 指令发下去之后, 先做一些初始化运算
#if defined(__gfx936__) || defined(__gfx938__)
#if defined(__gfx936__) || defined(__gfx938__)
|| defined(__gfx946__)
if
constexpr
(
M_MMAC_COUNT
==
1
)
{
inline_vgpr4_init_zero_1x2x4
(
s_reg
);
}
else
{
...
...
Prev
1
2
3
4
5
6
7
Next
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment