Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
14b2cfc5
Commit
14b2cfc5
authored
Apr 07, 2026
by
zhanghj2
Browse files
优化 nmz和bmz dsa prefill,nhead=64
parent
a9ef79c6
Changes
5
Show whitespace changes
Inline
Side-by-side
Showing
5 changed files
with
994 additions
and
3 deletions
+994
-3
csrc/gfx93/prefill/sparse/config.h
csrc/gfx93/prefill/sparse/config.h
+30
-0
csrc/gfx93/prefill/sparse/phase1.cuh
csrc/gfx93/prefill/sparse/phase1.cuh
+786
-1
csrc/softmax.h
csrc/softmax.h
+89
-0
csrc/utils.h
csrc/utils.h
+85
-0
tests/test_flash_mla_sparse_prefill.py
tests/test_flash_mla_sparse_prefill.py
+4
-2
No files found.
csrc/gfx93/prefill/sparse/config.h
View file @
14b2cfc5
...
@@ -124,5 +124,35 @@ static void run(const SparseAttnFwdParams ¶ms);
...
@@ -124,5 +124,35 @@ static void run(const SparseAttnFwdParams ¶ms);
};
};
template
<
int
D_QK
,
bool
HAVE_TOPK_LENGTH
,
bool
IS_TOPK_2048
>
class
KernelTemplate_B_H_64
{
public:
static
constexpr
int
D_Q
=
D_QK
;
static
constexpr
int
D_K
=
D_QK
;
static
constexpr
int
D_V
=
512
;
static
constexpr
int
kNWarps
=
4
;
static
constexpr
int
B_H
=
64
;
static
constexpr
int
B_TOPK
=
64
;
// TopK block size
static
constexpr
int
NUM_THREADS
=
kNWarps
*
64
;
static
constexpr
float
MAX_INIT_VAL
=
-
1e30
;
// We use this number as the initial value for mi (max logits)
using
Element
=
cutlass
::
bfloat16_t
;
using
elem_type
=
Element
;
using
ElementAccum
=
float
;
using
index_t
=
int64_t
;
static
constexpr
int
kBlockM
=
B_H
;
static
constexpr
int
kBlockN
=
B_TOPK
;
static
constexpr
int
kHeadDim
=
D_QK
;
static
constexpr
int
kHeadDimV
=
D_V
;
static
__device__
__forceinline__
void
devfunc
(
const
SparseAttnFwdParams
&
params
);
static
void
run
(
const
SparseAttnFwdParams
&
params
);
};
};
};
csrc/gfx93/prefill/sparse/phase1.cuh
View file @
14b2cfc5
...
@@ -10,6 +10,762 @@ namespace gfx93::fwd {
...
@@ -10,6 +10,762 @@ namespace gfx93::fwd {
using
namespace
cute
;
using
namespace
cute
;
template
<
int
D_QK
,
bool
HAVE_TOPK_LENGTH
,
bool
IS_TOPK_2048
>
__device__
void
KernelTemplate_B_H_64
<
D_QK
,
HAVE_TOPK_LENGTH
,
IS_TOPK_2048
>::
devfunc
(
const
SparseAttnFwdParams
&
params
)
{
const
int
tidx
=
threadIdx
.
x
;
static
constexpr
int
kBlockM
=
B_H
;
static
constexpr
int
kBlockN
=
B_TOPK
;
static
constexpr
int
kHeadDim
=
D_QK
;
static
constexpr
int
kHeadDimV
=
D_V
;
const
int
warp_idx
=
__builtin_amdgcn_readfirstlane
(
tidx
/
64
);
const
int
s_q_idx
=
blockIdx
.
y
;
const
int
bidh
=
blockIdx
.
x
;
const
int
lane_idx
=
tidx
%
64
;
extern
__shared__
Element
smem
[];
Element
*
q_lds
=
(
Element
*
)
&
(
smem
);
Element
*
k_lds
=
q_lds
;
Element
*
v_lds
=
q_lds
;
int
*
sIndices
=
(
int
*
)(
q_lds
+
8192
);
const
index_t
row_offset_q
=
s_q_idx
*
static_cast
<
index_t
>
(
params
.
stride_q_s_q
)
+
bidh
*
kBlockM
*
params
.
stride_q_h_q
;
Tensor
gQ
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
q
)
+
row_offset_q
),
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDim
>>
{},
make_stride
(
params
.
stride_q_h_q
,
_1
{}));
const
index_t
row_offset_k
=
0
*
params
.
stride_kv_h_kv
;
Tensor
gK
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
kv
)
+
row_offset_k
),
Shape
<
Int
<
kBlockN
>
,
Int
<
kHeadDim
>>
{},
make_stride
(
params
.
stride_kv_s_kv
,
_1
{}));
const
index_t
row_offset_topk
=
s_q_idx
*
params
.
stride_indices_s_q
;
int
*
gIndices
=
reinterpret_cast
<
int
*>
(
params
.
indices
)
+
row_offset_topk
;
typedef
__bf16
__fp16x8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
__bf16
__fp16x4_t
__attribute__
((
ext_vector_type
(
4
)));
typedef
__bf16
__fp16x2_t
__attribute__
((
ext_vector_type
(
2
)));
union
Bf16_storage
{
__fp16x8_t
data_128
;
__fp16x4_t
data_64
[
2
];
__fp16x2_t
data_32
[
4
];
uint16_t
data_array
[
8
];
};
union
Bf16_storage_x4
{
__fp16x4_t
data_64
;
__fp16x2_t
data_32
[
2
];
uint16_t
data
[
4
];
};
const
int
topk_length
=
HAVE_TOPK_LENGTH
?
__ldg
(
params
.
topk_length
+
s_q_idx
)
:
params
.
topk
;
const
int
num_topk_blocks
=
IS_TOPK_2048
?
2048
/
B_TOPK
:
HAVE_TOPK_LENGTH
?
ku
::
ceil_div
(
topk_length
,
(
int
)
B_TOPK
)
:
(
int
)((
unsigned
int
)
params
.
topk
/
(
unsigned
int
)
B_TOPK
);
// TiledMMA tiled_mma = TiledMma{};
// auto thr_mma = tiled_mma.get_thread_slice(tidx);
flash
::
Softmax
<
1
>
softmax
;
// #if 1
// #if defined(__gfx938__)
// #else
int
virtual_row_
=
lane_idx
/
8
;
//0
int
virtual_col_
=
lane_idx
%
8
;
//0
int
swizzle_col_
=
virtual_row_
^
virtual_col_
;
int
row_
=
lane_idx
/
4
;
//0
// 8->9 9->8
// row_ = (row_ >= 8 ) ^ row_;
int
col_
=
swizzle_col_
%
4
;
// #endif
auto
calc_row_and_col_k
=
[
&
](
const
int
block_idx
)
->
std
::
tuple
<
int
,
int
>
{
constexpr
int
elements_per_thread
=
8
;
// int row = lane_idx % 16;
// int col = lane_idx / 16;
// int row_offset = row * 4 + warp_idx + block_idx * kBlockN;
#if defined(__gfx938__)
// int row = lane_idx / 4;
// int col = lane_idx % 4;
// col = (col + (4 - (row / 2) % 4)) % 4;
// int row_offset = row + warp_idx * 16 + block_idx * kBlockN;
// int col_offset = col * 8;
int
row_offset
=
row_
+
warp_idx
*
16
+
block_idx
*
kBlockN
;
int
col_offset
=
col_
*
8
;
#else
int
row_offset
=
row_
*
4
+
warp_idx
+
block_idx
*
kBlockN
;
int
col_offset
=
col_
*
8
;
#endif
// int row_offset = row + warp_idx * 16 + block_idx * kBlockN;
if
constexpr
(
IS_TOPK_2048
)
{
row_offset
=
sIndices
[
row_offset
%
1024
];
}
else
{
row_offset
=
gIndices
[
row_offset
];
}
return
{
row_offset
,
col_offset
};
};
auto
calc_row_and_col_v
=
[
&
](
const
int
block_idx
,
int
i
)
->
int
{
int
row
=
lane_idx
/
4
;
// int col = lane_idx % 4;
int
row_offset
=
row
+
i
*
16
+
block_idx
*
kBlockN
;;
// int col_offset = col * 8 + warp_idx * 32;
if
constexpr
(
IS_TOPK_2048
)
{
row_offset
=
sIndices
[
row_offset
%
1024
];
}
else
{
row_offset
=
gIndices
[
row_offset
];
}
row_offset
=
row_offset
==
-
1
?
params
.
s_kv
:
row_offset
;
return
row_offset
;
};
struct
PtrWrapper
{
uint32_t
former
;
uint32_t
latter
;
};
PtrWrapper
glob_ptr_q
;
*
(
uint64_t
*
)
&
glob_ptr_q
=
reinterpret_cast
<
uint64_t
>
(
gQ
.
data
().
get
());
glob_ptr_q
.
latter
|=
((
params
.
stride_q_h_q
*
2
)
<<
16
);
glob_ptr_q
.
latter
|=
0x40000000
;
uint32x4_t
global_addr_q
=
{
0
};
global_addr_q
[
0
]
=
(
glob_ptr_q
.
former
);
global_addr_q
[
1
]
=
(
glob_ptr_q
.
latter
);
global_addr_q
[
2
]
=
64
;
global_addr_q
[
3
]
=
0x00020000
;
PtrWrapper
glob_ptr_indices
;
*
(
uint64_t
*
)
&
glob_ptr_indices
=
reinterpret_cast
<
uint64_t
>
(
gIndices
);
// glob_ptr_indices.latter |= ((params.stride_indices_s_q * 4) << 16);
// *(uint64_t*)&glob_ptr_indices = reinterpret_cast<uint64_t>(params.indices);
// glob_ptr_indices.latter |= ((params.stride_indices_s_q * 4) << 16);
glob_ptr_indices
.
latter
|=
0x40000000
;
uint32x4_t
global_addr_indices
=
{
0
};
global_addr_indices
[
0
]
=
(
glob_ptr_indices
.
former
);
global_addr_indices
[
1
]
=
(
glob_ptr_indices
.
latter
);
global_addr_indices
[
2
]
=
0x80000000
;
global_addr_indices
[
3
]
=
0x00020000
;
auto
buffer_load_lds_indices
=
[
&
]
(
int
n
)
{
constexpr
int
element_size
=
4
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
sIndices
)
+
warp_idx
*
64
*
4
*
4
;
typedef
uint32_t
uint32x2_t
__attribute__
((
ext_vector_type
(
2
)));
// uint32x2_t index_offset = {0};
// index_offset[0] = s_q_idx;
// index_offset[1] = lane_idx * 4 * 4 + warp_idx * 64 * 4 * 4;
const
int
offset_v
=
lane_idx
*
4
*
4
+
warp_idx
*
64
*
4
*
4
;
const
int
offset_s
=
n
*
1024
*
4
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"s_nop 0
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr_indices
),
"s"
(
offset_s
)
:
);
__builtin_amdgcn_sched_barrier
(
0
);
};
buffer_load_lds_indices
(
0
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
PtrWrapper
glob_ptr_k
;
*
(
uint64_t
*
)
&
glob_ptr_k
=
reinterpret_cast
<
uint64_t
>
(
gK
.
data
().
get
());
glob_ptr_k
.
latter
|=
((
params
.
stride_kv_s_kv
*
2
)
<<
16
);
glob_ptr_k
.
latter
|=
0x40000000
;
uint32x4_t
global_addr_k
=
{
0
};
global_addr_k
[
0
]
=
(
glob_ptr_k
.
former
);
global_addr_k
[
1
]
=
(
glob_ptr_k
.
latter
);
global_addr_k
[
2
]
=
params
.
s_kv
;
global_addr_k
[
3
]
=
0x00020000
;
auto
buffer_load_lds_k
=
[
&
](
int
row_offset
,
int
col
,
int
k_idx
)
{
constexpr
int
element_size
=
2
;
// int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
// struct PtrWrapper {
// uint32_t former;
// uint32_t latter;
// };
// PtrWrapper glob_ptr;
// *(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(gK.data().get());
// glob_ptr.latter |= ((row_stride * 2) << 16);
// uint32x4_t global_addr = {0};
// global_addr[0] = (glob_ptr.former);
// global_addr[1] = (glob_ptr.latter);
// global_addr[2] = max_MN;
// global_addr[3] = 0x00020000;
constexpr
int
elements_per_thread
=
8
;
int
col_offset
=
col
;
int
offset_v
=
col_offset
*
2
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
k_lds
)
+
warp_idx
*
16
*
32
*
2
+
(
k_idx
%
4
)
*
64
*
32
*
2
;
typedef
uint32_t
uint32x2_t
__attribute__
((
ext_vector_type
(
2
)));
uint32x2_t
index_offset
=
{
0
};
index_offset
[
0
]
=
row_offset
;
index_offset
[
1
]
=
offset_v
;
const
int
offset_s
=
k_idx
*
32
*
2
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"s_nop 0
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 , idxen offen offset:0, lds
\n
"
::
"v"
(
index_offset
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr_k
),
"s"
(
offset_s
)
:
);
__builtin_amdgcn_sched_barrier
(
0
);
};
auto
buffer_load_lds_v
=
[
&
](
int
row_offset
,
int
col
,
int
k_idx
,
int
n_idx
)
{
constexpr
int
element_size
=
2
;
// int k_idx = __builtin_amdgcn_readfirstlane(k_idx_);
// struct PtrWrapper {
// uint32_t former;
// uint32_t latter;
// };
// PtrWrapper glob_ptr;
// *(uint64_t*)&glob_ptr = reinterpret_cast<uint64_t>(gK.data().get());
// glob_ptr.latter |= ((row_stride * 2) << 16);
// uint32x4_t global_addr = {0};
// global_addr[0] = (glob_ptr.former);
// global_addr[1] = (glob_ptr.latter);
// global_addr[2] = max_MN;
// global_addr[3] = 0x00020000;
constexpr
int
elements_per_thread
=
8
;
int
col_offset
=
col
;
// int v_idx = row_offset;
int
offset_v
=
col_offset
*
2
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
v_lds
)
+
warp_idx
*
16
*
32
*
2
+
(
k_idx
%
1
)
*
512
*
16
*
2
+
n_idx
*
128
*
16
*
2
;
typedef
uint32_t
uint32x2_t
__attribute__
((
ext_vector_type
(
2
)));
uint32x2_t
index_offset
=
{
0
};
index_offset
[
0
]
=
row_offset
;
index_offset
[
1
]
=
offset_v
;
const
int
offset_s
=
n_idx
*
128
*
2
;
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"s_nop 0
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 , idxen offen offset:0, lds
\n
"
::
"v"
(
index_offset
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr_k
),
"s"
(
offset_s
)
:
);
__builtin_amdgcn_sched_barrier
(
0
);
};
const
int
v_lds_read_ptr
=
reinterpret_cast
<
size_t
>
(
v_lds
+
lane_idx
*
8
);
auto
k_lds_read_offset
=
[
&
]
()
->
int
{
// #if defined(__gfx938__)
// int row = lane_idx % 16;
// int col = lane_idx / 16;
// col = (col + (row / 2) % 4) % 4;
// const auto lds_offset = row * 32 + col * 8;
// #else
int
row
=
lane_idx
%
16
;
int
col
=
lane_idx
/
16
;
col
=
(
row
/
2
)
^
col
;
col
=
col
%
4
;
// row = (row >= 8) ^ row;
const
auto
lds_offset
=
row
*
32
+
col
*
8
;
// #endif
return
lds_offset
;
};
Element
*
q_lds_read_ptr
=
(
q_lds
+
warp_idx
*
16
*
32
+
lane_idx
*
8
);
Element
*
k_lds_read_ptr
=
(
k_lds
+
k_lds_read_offset
());
Bf16_storage
q_reg
[
18
];
for
(
int
i
=
0
;
i
<
18
;
i
++
)
{
constexpr
int
elements_per_thread
=
8
;
int
row
=
lane_idx
%
16
;
int
col
=
lane_idx
/
16
;
int
row_offset
=
row
+
warp_idx
*
16
;
int
col_offset
=
col
*
8
;
int
offset_v
=
col_offset
*
2
+
i
*
32
*
2
;
q_reg
[
i
].
data_128
=
__builtin_amdgcn_buffer_load_dwordx4
(
global_addr_q
,
row_offset
,
offset_v
,
false
,
false
);
}
__syncthreads
();
v4f
acco_f32
[
32
];
for
(
int
i
=
0
;
i
<
32
;
i
++
)
{
acco_f32
[
i
].
x
=
0.0
f
;
acco_f32
[
i
].
y
=
0.0
f
;
acco_f32
[
i
].
z
=
0.0
f
;
acco_f32
[
i
].
w
=
0.0
f
;
}
int
col_offset_v
=
(
lane_idx
%
4
)
*
8
+
warp_idx
*
32
;
struct
IsFirstBlock
{};
struct
IsOtherBlock
{};
auto
float2bf16
=
[]
(
float
s
)
->
uint16_t
{
uint32_t
x32
=
reinterpret_cast
<
uint32_t
const
&>
(
s
);
#ifndef FLASH_MLA_BF16_TYPE
#define FLASH_MLA_BF16_TYPE 0
#endif
#if FLASH_MLA_BF16_TYPE == 1
x32
+=
0x8000u
;
#endif
return
uint16_t
(
x32
>>
16
);
};
auto
process_one_block
=
[
&
]
(
int
block_idx
,
auto
is_block_t
)
{
static
constexpr
bool
IS_FIRST_BLOCK
=
std
::
is_same_v
<
decltype
(
is_block_t
),
IsFirstBlock
>
;
static
constexpr
bool
IS_OTHER_BLOCK
=
std
::
is_same_v
<
decltype
(
is_block_t
),
IsOtherBlock
>
;
v4f
accs_f32
[
4
];
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
accs_f32
[
i
].
x
=
0.0
f
;
accs_f32
[
i
].
y
=
0.0
f
;
accs_f32
[
i
].
z
=
0.0
f
;
accs_f32
[
i
].
w
=
0.0
f
;
}
auto
[
row_offset
,
col
]
=
calc_row_and_col_k
(
block_idx
);
row_offset
=
row_offset
==
-
1
?
params
.
s_kv
:
row_offset
;
#if 1
#define LOAD_K_AND_QK_GEMM(k) \
{ \
constexpr int k_val = (k); \
buffer_load_lds_k(row_offset, col, k_val - 3); \
flash::qk_gemm<Element, k_val>(q_reg[k_val].data_128, k_lds_read_ptr, accs_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier
(
0
);
\
}
{
constexpr
int
k_val
=
(
17
);
buffer_load_lds_k
(
row_offset
,
col
,
k_val
);
buffer_load_lds_k
(
row_offset
,
col
,
k_val
-
1
);
buffer_load_lds_k
(
row_offset
,
col
,
k_val
-
2
);
buffer_load_lds_k
(
row_offset
,
col
,
k_val
-
3
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
qk_gemm
<
Element
,
k_val
>
(
q_reg
[
k_val
].
data_128
,
k_lds_read_ptr
,
accs_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
LOAD_K_AND_QK_GEMM
(
16
);
LOAD_K_AND_QK_GEMM
(
15
);
LOAD_K_AND_QK_GEMM
(
14
);
LOAD_K_AND_QK_GEMM
(
13
);
LOAD_K_AND_QK_GEMM
(
12
);
LOAD_K_AND_QK_GEMM
(
11
);
LOAD_K_AND_QK_GEMM
(
10
);
LOAD_K_AND_QK_GEMM
(
9
);
LOAD_K_AND_QK_GEMM
(
8
);
LOAD_K_AND_QK_GEMM
(
7
);
LOAD_K_AND_QK_GEMM
(
6
);
LOAD_K_AND_QK_GEMM
(
5
);
LOAD_K_AND_QK_GEMM
(
4
);
LOAD_K_AND_QK_GEMM
(
3
);
flash
::
qk_gemm
<
Element
,
k_val
-
15
>
(
q_reg
[
k_val
-
15
].
data_128
,
k_lds_read_ptr
,
accs_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
qk_gemm
<
Element
,
k_val
-
16
>
(
q_reg
[
k_val
-
16
].
data_128
,
k_lds_read_ptr
,
accs_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
qk_gemm
<
Element
,
k_val
-
17
>
(
q_reg
[
k_val
-
17
].
data_128
,
k_lds_read_ptr
,
accs_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
}
#else
#define LOAD_K_AND_QK_GEMM(k) \
{ \
constexpr int k_val = (k); \
buffer_load_lds_k(row_offset, col, k_val); \
buffer_load_lds_k(row_offset, col, k_val + 1); \
buffer_load_lds_k(row_offset, col, k_val + 2); \
buffer_load_lds_k(row_offset, col, k_val + 3); \
buffer_load_lds_k(row_offset, col, k_val + 4); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(4) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
flash::qk_gemm<Element, k_val>(q_reg[k_val].data_128, k_lds_read_ptr, accs_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
flash::qk_gemm<Element, k_val + 1>(q_reg[k_val + 1].data_128, k_lds_read_ptr, accs_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
flash::qk_gemm<Element, k_val + 2>(q_reg[k_val + 2].data_128, k_lds_read_ptr, accs_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
flash::qk_gemm<Element, k_val + 3>(q_reg[k_val + 3].data_128, k_lds_read_ptr, accs_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
flash::qk_gemm<Element, k_val + 4>(q_reg[k_val + 4].data_128, k_lds_read_ptr, accs_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_barrier \n\t"); \
__builtin_amdgcn_sched_barrier(0); \
}
LOAD_K_AND_QK_GEMM
(
0
);
LOAD_K_AND_QK_GEMM
(
5
);
LOAD_K_AND_QK_GEMM
(
10
);
{
constexpr
int
k_val
=
(
15
);
buffer_load_lds_k
(
row_offset
,
col
,
k_val
);
buffer_load_lds_k
(
row_offset
,
col
,
k_val
+
1
);
buffer_load_lds_k
(
row_offset
,
col
,
k_val
+
2
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
qk_gemm
<
Element
,
k_val
>
(
q_reg
[
k_val
].
data_128
,
k_lds_read_ptr
,
accs_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
qk_gemm
<
Element
,
k_val
+
1
>
(
q_reg
[
k_val
+
1
].
data_128
,
k_lds_read_ptr
,
accs_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
qk_gemm
<
Element
,
k_val
+
2
>
(
q_reg
[
k_val
+
2
].
data_128
,
k_lds_read_ptr
,
accs_f32
);
\
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
}
#endif
auto
is_valid_token
=
[
&
](
const
int
idx
)
->
bool
{
const
int
n_idx
=
(
lane_idx
/
16
)
*
4
+
(
idx
%
4
)
+
(
idx
/
4
)
*
16
;
int
offs
=
n_idx
+
block_idx
*
kBlockN
;
int
t
;
if
constexpr
(
IS_TOPK_2048
)
{
t
=
sIndices
[
offs
%
1024
];
}
else
{
t
=
gIndices
[
offs
];
}
bool
is_cur_token_valid
=
t
>=
0
&&
t
<
params
.
s_kv
;
if
constexpr
(
HAVE_TOPK_LENGTH
)
{
is_cur_token_valid
=
is_cur_token_valid
&&
(
offs
<
topk_length
);
}
return
is_cur_token_valid
;
};
for
(
int
i
=
0
;
i
<
16
;
++
i
)
{
#if defined(__gfx938__)
if
(
!
is_valid_token
(
i
))
accs_f32
[
i
/
4
][
i
%
4
]
=
-
INFINITY
;
#else
if
(
!
is_valid_token
(
i
))
accs_f32
[
i
%
4
][
i
/
4
]
=
-
INFINITY
;
#endif
}
// Tensor acc_s = partition_fragment_C(tiled_mma, Shape<Int<kBlockM>, Int<kBlockN>>{});
// Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
Tensor
scores
=
make_tensor
<
float
>
(
Shape
<
_1
,
_16
>
{});
for
(
int
i
=
0
;
i
<
16
;
i
++
)
{
#if defined(__gfx938__)
scores
(
0
,
i
)
=
accs_f32
[
i
/
4
][
i
%
4
];
#else
scores
(
0
,
i
)
=
accs_f32
[
i
%
4
][
i
/
4
];
#endif
}
softmax
.
template
softmax_rescale_o_prefill_4x1
<
/*Is_first=*/
IS_FIRST_BLOCK
,
/*Check_inf=*//*Is_local=*/
false
>(
scores
,
acco_f32
,
params
.
sm_scale_div_log2
);
Bf16_storage_x4
p
[
4
];
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
#if defined(__gfx938__)
p
[
i
].
data_32
[
0
]
=
__builtin_hcu_cvt_pk_bf16_f32
(
0
,
scores
(
0
,
i
*
4
),
0
,
scores
(
0
,
i
*
4
+
1
),
0
);
p
[
i
].
data_32
[
1
]
=
__builtin_hcu_cvt_pk_bf16_f32
(
0
,
scores
(
0
,
i
*
4
+
2
),
0
,
scores
(
0
,
i
*
4
+
3
),
0
);
#else
p
[
i
].
data
[
0
]
=
float2bf16
(
scores
(
0
,
i
*
4
));
p
[
i
].
data
[
1
]
=
float2bf16
(
scores
(
0
,
i
*
4
+
1
));
p
[
i
].
data
[
2
]
=
float2bf16
(
scores
(
0
,
i
*
4
+
2
));
p
[
i
].
data
[
3
]
=
float2bf16
(
scores
(
0
,
i
*
4
+
3
));
#endif
}
int
row_offset_v
[
4
];
for
(
int
i
=
0
;
i
<
4
;
i
++
)
{
row_offset_v
[
i
]
=
calc_row_and_col_v
(
block_idx
,
i
);
}
__syncthreads
();
#if 1
{
constexpr
int
k_val
=
(
0
);
buffer_load_lds_v
(
row_offset_v
[
k_val
],
col_offset_v
,
k_val
,
0
);
buffer_load_lds_v
(
row_offset_v
[
k_val
],
col_offset_v
,
k_val
,
1
);
buffer_load_lds_v
(
row_offset_v
[
k_val
],
col_offset_v
,
k_val
,
2
);
buffer_load_lds_v
(
row_offset_v
[
k_val
],
col_offset_v
,
k_val
,
3
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(3)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
pv_gemm
<
k_val
,
0
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
1
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
2
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
3
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
buffer_load_lds_v
(
row_offset_v
[
k_val
+
1
],
col_offset_v
,
k_val
+
1
,
0
);
flash
::
pv_gemm
<
k_val
,
4
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
5
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
6
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
7
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
buffer_load_lds_v
(
row_offset_v
[
k_val
+
1
],
col_offset_v
,
k_val
+
1
,
1
);
flash
::
pv_gemm
<
k_val
,
8
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
9
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
10
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
11
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
buffer_load_lds_v
(
row_offset_v
[
k_val
+
1
],
col_offset_v
,
k_val
+
1
,
2
);
flash
::
pv_gemm
<
k_val
,
12
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
13
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
14
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
15
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
buffer_load_lds_v
(
row_offset_v
[
k_val
+
1
],
col_offset_v
,
k_val
+
1
,
3
);
}
#define LOAD_V_AND_PV_GEMM(k) \
{ \
constexpr int k_val = (k); \
flash::pv_gemm<k_val, 0>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 1>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 2>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 3>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 0); \
flash::pv_gemm<k_val, 4>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 5>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 6>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 7>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 1); \
flash::pv_gemm<k_val, 8>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 9>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 10>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 11>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 2); \
flash::pv_gemm<k_val, 12>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 13>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 14>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 15>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
buffer_load_lds_v(row_offset_v[k_val + 1], col_offset_v, k_val + 1, 3); \
}
LOAD_V_AND_PV_GEMM
(
1
);
LOAD_V_AND_PV_GEMM
(
2
);
{
constexpr
int
k_val
=
(
3
);
flash
::
pv_gemm
<
k_val
,
0
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
1
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
2
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
3
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(2)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
pv_gemm
<
k_val
,
4
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
5
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
6
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
7
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(1)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
pv_gemm
<
k_val
,
8
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
9
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
10
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
11
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
flash
::
pv_gemm
<
k_val
,
12
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
13
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
14
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
flash
::
pv_gemm
<
k_val
,
15
>
(
p
[
k_val
+
0
].
data_64
,
v_lds_read_ptr
,
acco_f32
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
}
#else
#define LOAD_V_AND_PV_GEMM(k) \
{ \
constexpr int k_val = (k); \
buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 0); \
buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 1); \
buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 2); \
buffer_load_lds_v(row_offset_v[k_val], col_offset_v, k_val, 3); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(3) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
flash::pv_gemm<k_val, 0>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 1>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 2>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 3>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(2) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
flash::pv_gemm<k_val, 4>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 5>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 6>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 7>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(1) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
flash::pv_gemm<k_val, 8>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 9>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 10>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 11>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_waitcnt vmcnt(0) \n\t s_barrier\n\t"); \
__builtin_amdgcn_sched_barrier(0); \
flash::pv_gemm<k_val, 12>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 13>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 14>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
flash::pv_gemm<k_val, 15>(p[k_val + 0].data_64, v_lds_read_ptr, acco_f32); \
__builtin_amdgcn_sched_barrier(0); \
asm volatile("s_barrier \n\t"); \
__builtin_amdgcn_sched_barrier(0); \
}
LOAD_V_AND_PV_GEMM
(
0
);
LOAD_V_AND_PV_GEMM
(
1
);
LOAD_V_AND_PV_GEMM
(
2
);
LOAD_V_AND_PV_GEMM
(
3
);
#endif
};
if
constexpr
(
IS_TOPK_2048
)
{
process_one_block
(
0
,
IsFirstBlock
{});
for
(
int
block_idx
=
1
;
block_idx
<
1024
/
B_TOPK
;
block_idx
++
)
{
process_one_block
(
block_idx
,
IsOtherBlock
{});
}
buffer_load_lds_indices
(
1
);
__builtin_amdgcn_sched_barrier
(
0
);
asm
volatile
(
"s_waitcnt vmcnt(0)
\n\t
s_barrier
\n\t
"
);
__builtin_amdgcn_sched_barrier
(
0
);
for
(
int
block_idx
=
1024
/
B_TOPK
;
block_idx
<
2048
/
B_TOPK
;
block_idx
++
)
{
process_one_block
(
block_idx
,
IsOtherBlock
{});
}
}
else
{
process_one_block
(
0
,
IsFirstBlock
{});
for
(
int
block_idx
=
1
;
block_idx
<
num_topk_blocks
;
block_idx
++
)
{
process_one_block
(
block_idx
,
IsOtherBlock
{});
}
}
Tensor
lse
=
softmax
.
template
normalize_softmax_lse_prefill_4x1
<
false
>(
acco_f32
,
params
.
sm_scale
);
// if (block0())
// {
// printf(" threadIdx.x %d %.3f %.3f %.3f %.3f \n", threadIdx.x,
// acco_f32[0].x,
// acco_f32[0].y,
// acco_f32[0].z,
// acco_f32[0].w
// );
// }
const
index_t
row_offset_o
=
s_q_idx
*
static_cast
<
index_t
>
(
params
.
h_q
*
params
.
d_v
)
+
bidh
*
kBlockM
*
params
.
d_v
;
Tensor
gO
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
out
)
+
row_offset_o
),
Shape
<
Int
<
kBlockM
>
,
Int
<
kHeadDimV
>>
{},
make_stride
(
params
.
d_v
,
_1
{}));
const
index_t
row_offset_lse
=
s_q_idx
*
params
.
h_q
+
bidh
*
kBlockM
;
float
*
gLSE
=
reinterpret_cast
<
float
*>
(
params
.
lse
)
+
row_offset_lse
;
// const index_t row_offset_lse = m_block * params.h_q;
float
*
gMax_logits
=
reinterpret_cast
<
float
*>
(
params
.
max_logits
)
+
row_offset_lse
;
{
// store O and gLSE
// auto rO = flash::convert_type<Element>(acc_o);
int
row
,
col
;
// const int warpId = tidx / 64;
// const int laneId = tidx % 64;
for
(
int
mi
=
0
;
mi
<
1
;
++
mi
)
{
row
=
mi
*
kBlockM
+
lane_idx
%
16
+
warp_idx
*
16
;
// if (row < params.h_q)
{
for
(
int
ni
=
0
;
ni
<
16
;
++
ni
)
{
#if defined(__gfx938__)
Bf16_storage
res
;
col
=
(
lane_idx
/
16
)
*
8
+
ni
*
32
;
res
.
data_32
[
0
]
=
__builtin_hcu_cvt_pk_bf16_f32
(
0
,
acco_f32
[
ni
*
2
][
0
],
0
,
acco_f32
[
ni
*
2
+
1
][
0
],
0
);
res
.
data_32
[
1
]
=
__builtin_hcu_cvt_pk_bf16_f32
(
0
,
acco_f32
[
ni
*
2
][
1
],
0
,
acco_f32
[
ni
*
2
+
1
][
1
],
0
);
res
.
data_32
[
2
]
=
__builtin_hcu_cvt_pk_bf16_f32
(
0
,
acco_f32
[
ni
*
2
][
2
],
0
,
acco_f32
[
ni
*
2
+
1
][
2
],
0
);
res
.
data_32
[
3
]
=
__builtin_hcu_cvt_pk_bf16_f32
(
0
,
acco_f32
[
ni
*
2
][
3
],
0
,
acco_f32
[
ni
*
2
+
1
][
3
],
0
);
*
(
__fp16x8_t
*
)(
&
gO
(
row
,
col
))
=
res
.
data_128
;
#else
col
=
(
lane_idx
/
16
)
*
2
+
ni
*
32
;
using
result_type
=
cutlass
::
Array
<
Element
,
2
>
;
for
(
int
ei
=
0
;
ei
<
4
;
ei
++
)
{
result_type
res
;
Element
e0
,
e1
;
e0
.
storage
=
float2bf16
(
acco_f32
[
ni
*
2
][
ei
]);
e1
.
storage
=
float2bf16
(
acco_f32
[
ni
*
2
+
1
][
ei
]);
res
[
0
]
=
e0
;
res
[
1
]
=
e1
;
// gO(row, col) = res[0];
// gO(row, col + 1) = res[1];
*
(
result_type
*
)(
&
gO
(
row
,
col
))
=
res
;
col
+=
8
;
}
#endif
}
gLSE
[
row
]
=
lse
(
mi
);
if
constexpr
(
HAVE_TOPK_LENGTH
)
{
gMax_logits
[
row
]
=
topk_length
==
0
?
-
INFINITY
:
softmax
.
row_max
(
mi
)
*
params
.
sm_scale
;
}
else
{
gMax_logits
[
row
]
=
softmax
.
row_max
(
mi
)
*
params
.
sm_scale
;
}
}
}
}
}
template
<
int
D_QK
,
bool
HAVE_TOPK_LENGTH
>
template
<
int
D_QK
,
bool
HAVE_TOPK_LENGTH
>
__device__
void
KernelTemplate
<
D_QK
,
HAVE_TOPK_LENGTH
>::
devfunc
(
const
SparseAttnFwdParams
&
params
)
{
__device__
void
KernelTemplate
<
D_QK
,
HAVE_TOPK_LENGTH
>::
devfunc
(
const
SparseAttnFwdParams
&
params
)
{
extern
__shared__
char
smem_
[];
extern
__shared__
char
smem_
[];
...
@@ -529,7 +1285,9 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
...
@@ -529,7 +1285,9 @@ __device__ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::devfunc(const SparseAttn
template
<
typename
Kernel
>
template
<
typename
Kernel
>
__global__
void
__launch_bounds__
(
Kernel
::
NUM_THREADS
,
1
)
__global__
void
__launch_bounds__
(
Kernel
::
NUM_THREADS
,
1
)
sparse_attn_fwd_kernel
(
const
SparseAttnFwdParams
params
)
{
sparse_attn_fwd_kernel
(
const
SparseAttnFwdParams
params
)
{
// #if defined(__gfx936__)
Kernel
::
devfunc
(
params
);
Kernel
::
devfunc
(
params
);
// #endif
}
}
template
<
int
D_QK
,
bool
HAVE_TOPK_LENGTH
>
template
<
int
D_QK
,
bool
HAVE_TOPK_LENGTH
>
...
@@ -545,9 +1303,36 @@ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::run(const SparseAttnFwdParams ¶
...
@@ -545,9 +1303,36 @@ void KernelTemplate<D_QK, HAVE_TOPK_LENGTH>::run(const SparseAttnFwdParams ¶
KU_CHECK_KERNEL_LAUNCH
();
KU_CHECK_KERNEL_LAUNCH
();
}
}
template
<
int
D_QK
,
bool
HAVE_TOPK_LENGTH
,
bool
IS_TOPK_2048
>
void
KernelTemplate_B_H_64
<
D_QK
,
HAVE_TOPK_LENGTH
,
IS_TOPK_2048
>::
run
(
const
SparseAttnFwdParams
&
params
)
{
KU_ASSERT
(
params
.
h_kv
==
1
);
// KU_ASSERT(params.topk % (2*B_TOPK) == 0); // To save some boundry checkings
KU_ASSERT
(
params
.
topk
>
0
);
// KU_ASSERT(params.h_q % B_H == 0);
auto
kernel
=
&
sparse_attn_fwd_kernel
<
KernelTemplate_B_H_64
<
D_QK
,
HAVE_TOPK_LENGTH
,
IS_TOPK_2048
>>
;
constexpr
size_t
smem_size
=
16384
+
4096
;
// 做了lds复用
dim3
grid
((
params
.
h_q
+
B_H
-
1
)
/
B_H
,
params
.
s_q
,
1
);
kernel
<<<
grid
,
NUM_THREADS
,
smem_size
,
params
.
stream
>>>
(
params
);
KU_CHECK_KERNEL_LAUNCH
();
}
template
<
int
D_QK
,
bool
HAVE_TOPK_LENGTH
>
template
<
int
D_QK
,
bool
HAVE_TOPK_LENGTH
>
void
run_fwd_phase1_kernel
(
const
SparseAttnFwdParams
&
params
)
{
void
run_fwd_phase1_kernel
(
const
SparseAttnFwdParams
&
params
)
{
if
(
params
.
h_q
==
64
&&
!
HAVE_TOPK_LENGTH
&&
D_QK
==
576
&&
!
params
.
attn_sink
)
{
if
(
params
.
topk
==
2048
)
{
KernelTemplate_B_H_64
<
D_QK
,
HAVE_TOPK_LENGTH
,
true
>::
run
(
params
);
}
else
{
KernelTemplate_B_H_64
<
D_QK
,
HAVE_TOPK_LENGTH
,
false
>::
run
(
params
);
}
}
else
{
KernelTemplate
<
D_QK
,
HAVE_TOPK_LENGTH
>::
run
(
params
);
KernelTemplate
<
D_QK
,
HAVE_TOPK_LENGTH
>::
run
(
params
);
}
}
}
}
}
csrc/softmax.h
View file @
14b2cfc5
...
@@ -602,6 +602,95 @@ struct Softmax {
...
@@ -602,6 +602,95 @@ struct Softmax {
}
}
return
lse
;
return
lse
;
};
};
template
<
bool
Is_first
,
bool
Check_inf
=
false
,
typename
Tensor0
>
__forceinline__
__device__
void
softmax_rescale_o_prefill_4x1
(
Tensor0
&
scores
,
v4f
*
acc_o
,
float
softmax_scale_log2
)
{
// Reshape acc_s from (MMA=4, MMA_M, MMA_N) to (nrow=(2, MMA_M), ncol=(2, MMA_N))
MaxOp
<
float
>
max_op
;
// Tensor scores = make_tensor(acc_s.data(), flash::convert_layout_acc_rowcol(acc_s.layout()));
static_assert
(
decltype
(
size
<
0
>
(
scores
))
::
value
==
kNRows
);
if
constexpr
(
Is_first
)
{
flash
::
template
reduce_max
<
/*zero_init=*/
true
>(
scores
,
row_max
);
flash
::
scale_apply_exp2
(
scores
,
row_max
,
softmax_scale_log2
);
flash
::
reduce_sum
<
/*zero_init=*/
true
>
(
scores
,
row_sum
);
}
else
{
Tensor
scores_max_prev
=
make_fragment_like
(
row_max
);
cute
::
copy
(
row_max
,
scores_max_prev
);
flash
::
template
reduce_max
<
/*zero_init=*/
false
>(
scores
,
row_max
);
// Reshape acc_o from (MMA=4, MMA_M, MMA_K) to (nrow=(2, MMA_M), ncol=(2, MMA_K))
// static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
#pragma unroll
for
(
int
mi
=
0
;
mi
<
size
(
row_max
);
++
mi
)
{
float
scores_max_cur
=
!
true
?
row_max
(
mi
)
:
(
row_max
(
mi
)
==
-
INFINITY
?
0.0
f
:
row_max
(
mi
));
#if 0
float scores_scale = exp2f((scores_max_prev(mi) - scores_max_cur) * softmax_scale_log2);
#else
float
scores_scale
=
__builtin_amdgcn_exp2f
((
scores_max_prev
(
mi
)
-
scores_max_cur
)
*
softmax_scale_log2
);
#endif
// if (blockIdx.x == 0 && threadIdx.x == 0)
// {
// printf("threadIdx.x %.2f, scores_scale = %.4f\n",row_sum(mi), scores_scale );
// }
row_sum
(
mi
)
*=
scores_scale
;
for
(
int
i
=
0
;
i
<
32
;
i
++
)
{
acc_o
[
i
].
x
*=
scores_scale
;
acc_o
[
i
].
y
*=
scores_scale
;
acc_o
[
i
].
z
*=
scores_scale
;
acc_o
[
i
].
w
*=
scores_scale
;
}
}
// if (blockIdx.x == 2)
// {
// printf("threadIdx.x %.2f \n",row_sum(mi) );
// }
flash
::
scale_apply_exp2
(
scores
,
row_max
,
softmax_scale_log2
);
// We don't do the reduce across threads here since we don't need to use the row_sum.
// We do that reduce at the end when we need to normalize the softmax.
flash
::
reduce_sum
<
/*zero_init=*/
false
>
(
scores
,
row_sum
);
}
// if (thread0())
// {
// printf("max sum %.3f %.3f \n", row_max(0), row_sum(0));
// }
};
template
<
bool
Is_dropout
=
false
,
bool
Split
=
false
>
__forceinline__
__device__
TensorT
normalize_softmax_lse_prefill_4x1
(
v4f
*
acc_o
,
float
softmax_scale
,
float
rp_dropout
=
1.0
)
{
SumOp
<
float
>
sum_op
;
quad_allreduce_
(
row_sum
,
row_sum
,
sum_op
);
// flash::template warp_allreduce_(row_sum, sRow_sum_reduce_buffer, sum_op);
TensorT
lse
=
make_fragment_like
(
row_sum
);
// Tensor acc_o_rowcol = make_tensor(acc_o.data(), flash::convert_layout_acc_rowcol(acc_o.layout()));
// static_assert(decltype(size<0>(acc_o_rowcol))::value == kNRows);
// if (thread0())
// {
// printf(" %.3f %.3f \n", row_max(0), row_sum(0));
// }
#pragma unroll
for
(
int
mi
=
0
;
mi
<
1
;
++
mi
)
{
float
sum
=
row_sum
(
mi
);
float
inv_sum
=
(
sum
==
0.
f
||
sum
!=
sum
)
?
1.
f
:
1.
f
/
sum
;
lse
(
mi
)
=
(
sum
==
0.
f
||
sum
!=
sum
)
?
(
Split
?
-
INFINITY
:
INFINITY
)
:
row_max
(
mi
)
*
softmax_scale
+
__logf
(
sum
);
float
scale
=
!
Is_dropout
?
inv_sum
:
inv_sum
*
rp_dropout
;
for
(
int
i
=
0
;
i
<
32
;
i
++
)
{
acc_o
[
i
].
x
*=
scale
;
acc_o
[
i
].
y
*=
scale
;
acc_o
[
i
].
z
*=
scale
;
acc_o
[
i
].
w
*=
scale
;
}
}
return
lse
;
};
};
};
...
...
csrc/utils.h
View file @
14b2cfc5
...
@@ -1523,6 +1523,91 @@ __forceinline__ __device__ void gemm1_rs_fp8(Tensor0 &acc, Tensor1 &tCrA, Tensor
...
@@ -1523,6 +1523,91 @@ __forceinline__ __device__ void gemm1_rs_fp8(Tensor0 &acc, Tensor1 &tCrA, Tensor
}
}
#endif
#endif
typedef
__bf16
__fp16x8_t
__attribute__
((
ext_vector_type
(
8
)));
template
<
typename
Element
,
int
k_idx
>
__forceinline__
__device__
void
qk_gemm
(
const
__fp16x8_t
&
q_data
,
Element
*
k_lds_read_ptr
,
v4f
*
accs_f32
)
{
typedef
__bf16
__fp16x8_t
__attribute__
((
ext_vector_type
(
8
)));
typedef
__bf16
__fp16x4_t
__attribute__
((
ext_vector_type
(
4
)));
union
Bf16_storage
{
__fp16x8_t
data_128
;
__fp16x4_t
data_64
[
2
];
uint16_t
data_array
[
8
];
};
constexpr
int
k_idx_even
=
k_idx
%
4
;
constexpr
int
n_offset
=
16
*
32
;
constexpr
int
k_offset
=
k_idx_even
*
64
*
32
;
Bf16_storage
q_reg
;
Bf16_storage
k_reg
;
q_reg
.
data_128
=
q_data
;
k_reg
.
data_128
=
*
reinterpret_cast
<
__fp16x8_t
*>
(
k_lds_read_ptr
+
k_offset
);
// q_reg.data_128 = __builtin_hcu_ds_read_matrix_trans_format_bf16((__attribute__((address_space(3))) short*)(q_lds_read_ptr), k_offset, 2, 1, 0);
// k_reg.data_128 = __builtin_hcu_ds_read_matrix_trans_format_bf16((__attribute__((address_space(3))) short*)(k_lds_read_ptr), 0 * n_offset + k_offset, 2, 1, 0);
#if defined(__gfx938__)
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts
(
q_reg
.
data_64
[
0
],
k_reg
.
data_64
[
0
],
accs_f32
[
0
],
true
,
false
);
accs_f32
[
0
]
=
__builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts
(
q_reg
.
data_64
[
1
],
k_reg
.
data_64
[
1
],
accs_f32
[
0
],
true
,
false
);
#else
accs_f32
[
0
]
=
__builtin_amdgcn_mmac_f32_16x16x16bf16
(
q_reg
.
data_64
[
0
],
k_reg
.
data_64
[
0
],
accs_f32
[
0
]);
accs_f32
[
0
]
=
__builtin_amdgcn_mmac_f32_16x16x16bf16
(
q_reg
.
data_64
[
1
],
k_reg
.
data_64
[
1
],
accs_f32
[
0
]);
#endif
k_reg
.
data_128
=
*
reinterpret_cast
<
__fp16x8_t
*>
(
k_lds_read_ptr
+
k_offset
+
1
*
n_offset
);
#if defined(__gfx938__)
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts
(
q_reg
.
data_64
[
0
],
k_reg
.
data_64
[
0
],
accs_f32
[
1
],
true
,
false
);
accs_f32
[
1
]
=
__builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts
(
q_reg
.
data_64
[
1
],
k_reg
.
data_64
[
1
],
accs_f32
[
1
],
true
,
false
);
#else
accs_f32
[
1
]
=
__builtin_amdgcn_mmac_f32_16x16x16bf16
(
q_reg
.
data_64
[
0
],
k_reg
.
data_64
[
0
],
accs_f32
[
1
]);
accs_f32
[
1
]
=
__builtin_amdgcn_mmac_f32_16x16x16bf16
(
q_reg
.
data_64
[
1
],
k_reg
.
data_64
[
1
],
accs_f32
[
1
]);
#endif
// k_reg.data_128 = __builtin_hcu_ds_read_matrix_trans_format_bf16((__attribute__((address_space(3))) short*)(k_lds_read_ptr), 1 * n_offset + k_offset, 2, 1, 0);
k_reg
.
data_128
=
*
reinterpret_cast
<
__fp16x8_t
*>
(
k_lds_read_ptr
+
k_offset
+
2
*
n_offset
);
// k_reg.data_128 = __builtin_hcu_ds_read_matrix_trans_format_bf16((__attribute__((address_space(3))) short*)(k_lds_read_ptr), 2 * n_offset + k_offset, 2, 1, 0);
#if defined(__gfx938__)
accs_f32
[
2
]
=
__builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts
(
q_reg
.
data_64
[
0
],
k_reg
.
data_64
[
0
],
accs_f32
[
2
],
true
,
false
);
accs_f32
[
2
]
=
__builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts
(
q_reg
.
data_64
[
1
],
k_reg
.
data_64
[
1
],
accs_f32
[
2
],
true
,
false
);
#else
accs_f32
[
2
]
=
__builtin_amdgcn_mmac_f32_16x16x16bf16
(
q_reg
.
data_64
[
0
],
k_reg
.
data_64
[
0
],
accs_f32
[
2
]);
accs_f32
[
2
]
=
__builtin_amdgcn_mmac_f32_16x16x16bf16
(
q_reg
.
data_64
[
1
],
k_reg
.
data_64
[
1
],
accs_f32
[
2
]);
#endif
k_reg
.
data_128
=
*
reinterpret_cast
<
__fp16x8_t
*>
(
k_lds_read_ptr
+
k_offset
+
3
*
n_offset
);
// k_reg.data_128 = __builtin_hcu_ds_read_matrix_trans_format_bf16((__attribute__((address_space(3))) short*)(k_lds_read_ptr), 3 * n_offset + k_offset, 2, 1, 0);
#if defined(__gfx938__)
accs_f32
[
3
]
=
__builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts
(
q_reg
.
data_64
[
0
],
k_reg
.
data_64
[
0
],
accs_f32
[
3
],
true
,
false
);
accs_f32
[
3
]
=
__builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts
(
q_reg
.
data_64
[
1
],
k_reg
.
data_64
[
1
],
accs_f32
[
3
],
true
,
false
);
#else
accs_f32
[
3
]
=
__builtin_amdgcn_mmac_f32_16x16x16bf16
(
q_reg
.
data_64
[
0
],
k_reg
.
data_64
[
0
],
accs_f32
[
3
]);
accs_f32
[
3
]
=
__builtin_amdgcn_mmac_f32_16x16x16bf16
(
q_reg
.
data_64
[
1
],
k_reg
.
data_64
[
1
],
accs_f32
[
3
]);
#endif
}
typedef
__bf16
__fp16x4_t
__attribute__
((
ext_vector_type
(
4
)));
template
<
int
k_idx
,
int
n_idx_val
>
__forceinline__
__device__
void
pv_gemm
(
const
__fp16x4_t
&
p
,
int
v_lds_read_ptr
,
v4f
*
acco_f32
)
{
constexpr
int
k_idx_even
=
k_idx
%
1
;
constexpr
int
n_offset
=
16
*
32
*
2
;
typedef
__bf16
__fp16x8_t
__attribute__
((
ext_vector_type
(
8
)));
union
Bf16_storage
{
__fp16x8_t
data_128
;
__fp16x4_t
data_64
[
2
];
uint16_t
data_array
[
8
];
};
constexpr
int
k_offset
=
k_idx_even
*
16
*
512
*
2
;
// #if 1
Bf16_storage
v_reg
;
v_reg
.
data_128
=
__builtin_amdgcn_ds_read_m32x16f16_alt
((
__attribute__
((
address_space
(
3
)))
__fp16
*
)(
v_lds_read_ptr
),
k_offset
+
n_idx_val
*
n_offset
);
#if defined(__gfx938__)
acco_f32
[
n_idx_val
*
2
]
=
__builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts
(
p
,
v_reg
.
data_64
[
0
],
acco_f32
[
n_idx_val
*
2
],
true
,
false
);
acco_f32
[
n_idx_val
*
2
+
1
]
=
__builtin_hcu_mmac_f32_16x16x16_bf16_lit_lts
(
p
,
v_reg
.
data_64
[
1
],
acco_f32
[
n_idx_val
*
2
+
1
],
true
,
false
);
#else
acco_f32
[
n_idx_val
*
2
]
=
__builtin_amdgcn_mmac_f32_16x16x16bf16
(
p
,
v_reg
.
data_64
[
0
],
acco_f32
[
n_idx_val
*
2
]);
acco_f32
[
n_idx_val
*
2
+
1
]
=
__builtin_amdgcn_mmac_f32_16x16x16bf16
(
p
,
v_reg
.
data_64
[
1
],
acco_f32
[
n_idx_val
*
2
+
1
]);
#endif
}
}
}
\ No newline at end of file
tests/test_flash_mla_sparse_prefill.py
View file @
14b2cfc5
...
@@ -77,7 +77,7 @@ if __name__ == '__main__':
...
@@ -77,7 +77,7 @@ if __name__ == '__main__':
(
1840
,
256
),
(
1840
,
256
),
(
1592
,
384
),
(
1592
,
384
),
(
1521
,
512
),
(
1521
,
512
),
(
3000
,
2048
),
# Irregular shapes with OOB TopK
# Irregular shapes with OOB TopK
(
95
,
128
),
(
95
,
128
),
(
153
,
256
),
(
153
,
256
),
...
@@ -146,6 +146,7 @@ if __name__ == '__main__':
...
@@ -146,6 +146,7 @@ if __name__ == '__main__':
performance_case_templates
=
[
performance_case_templates
=
[
# V3.2
# V3.2
(
576
,
128
,
2048
,
[
8192
,
32768
,
65536
,
98304
,
131072
]),
(
576
,
128
,
2048
,
[
8192
,
32768
,
65536
,
98304
,
131072
]),
(
576
,
64
,
2048
,
[
8192
,
32768
,
65536
,
98304
,
131072
]),
# MODEL1 CONFIG1
# MODEL1 CONFIG1
(
512
,
64
,
512
,
[
8192
,
32768
,
49152
,
65536
]),
(
512
,
64
,
512
,
[
8192
,
32768
,
49152
,
65536
]),
# MODEL1 CONFIG2
# MODEL1 CONFIG2
...
@@ -154,9 +155,10 @@ if __name__ == '__main__':
...
@@ -154,9 +155,10 @@ if __name__ == '__main__':
]
]
performance_cases
=
[
performance_cases
=
[
TestParam
(
s_q
,
s_kv
,
topk
,
h_q
=
h_q
,
d_qk
=
d_qk
,
have_attn_sink
=
True
)
TestParam
(
s_q
,
s_kv
,
topk
,
h_q
=
h_q
,
d_qk
=
d_qk
,
have_attn_sink
=
have_attn_sink
)
for
(
d_qk
,
h_q
,
topk
,
s_kv_list
)
in
performance_case_templates
for
(
d_qk
,
h_q
,
topk
,
s_kv_list
)
in
performance_case_templates
for
s_q
in
[
4096
]
for
s_q
in
[
4096
]
for
have_attn_sink
in
[
False
,
True
]
for
s_kv
in
s_kv_list
for
s_kv
in
s_kv_list
]
]
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment