Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
bdf0140b
Commit
bdf0140b
authored
Jan 30, 2026
by
zhanghj2
Browse files
使用buffer load lds读取q, 优化了vgpr溢出
parent
515dbd44
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
220 additions
and
31 deletions
+220
-31
csrc/sm90/decode/sparse_fp8/splitkv_mla.cuh
csrc/sm90/decode/sparse_fp8/splitkv_mla.cuh
+220
-31
No files found.
csrc/sm90/decode/sparse_fp8/splitkv_mla.cuh
View file @
bdf0140b
...
@@ -30,40 +30,11 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
...
@@ -30,40 +30,11 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
using
index_t
=
int64_t
;
using
index_t
=
int64_t
;
const
int
tidx
=
threadIdx
.
x
;
const
int
tidx
=
threadIdx
.
x
;
const
int
lane_idx
=
tidx
%
64
;
const
int
lane_idx
=
tidx
%
64
;
const
int
warp_idx
=
tidx
/
64
;
const
int
warp_idx
=
__builtin_amdgcn_readfirstlane
(
tidx
/
64
)
;
const
int
head_block_idx
=
NUM_M_BLOCKS
==
1
?
0
:
blockIdx
.
x
;
const
int
head_block_idx
=
NUM_M_BLOCKS
==
1
?
0
:
blockIdx
.
x
;
const
int
s_q_idx
=
blockIdx
.
y
;
const
int
s_q_idx
=
blockIdx
.
y
;
extern
__shared__
char
shared_memory
[];
extern
__shared__
char
shared_memory
[];
SharedMemoryPlan
&
plan
=
*
reinterpret_cast
<
SharedMemoryPlan
*>
(
shared_memory
);
SharedMemoryPlan
&
plan
=
*
reinterpret_cast
<
SharedMemoryPlan
*>
(
shared_memory
);
struct
MainloopArgs
{
int
start_block_idx
,
end_block_idx
;
bool
is_no_split
;
// The following fields are only valid for MODEL1
int
topk_length
,
extra_topk_length
,
num_orig_kv_blocks
;
};
auto
get_cur_req_info
=
[
&
](
int
batch_idx
)
->
MainloopArgs
{
MainloopArgs
args
;
int
total_topk_padded
;
if
constexpr
(
MODEL_TYPE
==
ModelType
::
V32
)
{
total_topk_padded
=
params
.
topk
;
}
else
{
int
topk_length
=
params
.
topk_length
?
__ldg
(
params
.
topk_length
+
batch_idx
)
:
params
.
topk
;
int
orig_topk_padded
=
max
(
ku
::
ceil
(
topk_length
,
(
int
)
TOPK_BLOCK_SIZE
),
(
int
)
TOPK_BLOCK_SIZE
);
int
extra_topk_length
=
params
.
extra_topk_length
?
__ldg
(
params
.
extra_topk_length
+
batch_idx
)
:
params
.
extra_topk
;
total_topk_padded
=
orig_topk_padded
+
ku
::
ceil
(
extra_topk_length
,
(
int
)
TOPK_BLOCK_SIZE
);
args
.
topk_length
=
topk_length
;
args
.
extra_topk_length
=
extra_topk_length
;
args
.
num_orig_kv_blocks
=
orig_topk_padded
/
TOPK_BLOCK_SIZE
;
}
args
.
start_block_idx
=
batch_idx
==
sched_meta
.
begin_req_idx
?
sched_meta
.
begin_block_idx
:
0
;
args
.
end_block_idx
=
batch_idx
==
sched_meta
.
end_req_idx
?
sched_meta
.
end_block_idx
:
total_topk_padded
/
TOPK_BLOCK_SIZE
;
args
.
is_no_split
=
batch_idx
==
sched_meta
.
begin_req_idx
?
!
sched_meta
.
is_first_req_splitted
:
(
batch_idx
==
sched_meta
.
end_req_idx
?
!
sched_meta
.
is_last_req_splitted
:
true
);
return
args
;
};
const
index_t
row_offset_q
=
batch_idx
*
params
.
stride_q_b
+
head_block_idx
*
BLOCK_M
*
params
.
stride_q_h_q
+
s_q_idx
*
params
.
stride_q_s_q
;
const
index_t
row_offset_q
=
batch_idx
*
params
.
stride_q_b
+
head_block_idx
*
BLOCK_M
*
params
.
stride_q_h_q
+
s_q_idx
*
params
.
stride_q_s_q
;
Tensor
gQ
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
q
)
+
row_offset_q
),
Tensor
gQ
=
make_tensor
(
make_gmem_ptr
(
reinterpret_cast
<
Element
*>
(
params
.
q
)
+
row_offset_q
),
...
@@ -90,7 +61,7 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
...
@@ -90,7 +61,7 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
auto
thr_mma_16x16x32
=
tiled_mma_16x16x32
.
get_thread_slice
(
tidx
);
auto
thr_mma_16x16x32
=
tiled_mma_16x16x32
.
get_thread_slice
(
tidx
);
TiledMMA
tiled_mma_o
=
TiledMma_O
{};
TiledMMA
tiled_mma_o
=
TiledMma_O
{};
auto
thr_mma_o
=
tiled_mma_o
.
get_thread_slice
(
tidx
);
auto
thr_mma_o
=
tiled_mma_o
.
get_thread_slice
(
tidx
);
#if 0
// load Q
// load Q
auto gmem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom<DefaultCopy, Element>{}, tiled_mma);
auto gmem_tiled_copy_Q = make_tiled_copy_A(Copy_Atom<DefaultCopy, Element>{}, tiled_mma);
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
auto gmem_thr_copy_Q = gmem_tiled_copy_Q.get_thread_slice(tidx);
...
@@ -101,6 +72,196 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
...
@@ -101,6 +72,196 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tSgQ)));
Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tSgQ)));
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true>(gmem_tiled_copy_Q, tSgQ, tSrQ, tQcQ, tQpQ, params.h_q - head_block_idx * BLOCK_M);
flash::copy</*Is_even_MN=*/false, /*Is_even_K=*/true>(gmem_tiled_copy_Q, tSgQ, tSrQ, tQcQ, tQpQ, params.h_q - head_block_idx * BLOCK_M);
__syncthreads();
__syncthreads();
#else
Tensor
tSrQ
=
thr_mma
.
partition_fragment_A
(
gQ
);
// 需要的最大空间为 16 * 576 * 2
Element
*
s_q
=
reinterpret_cast
<
Element
*>
(
shared_memory
);
auto
lds_direct_copy_q
=
[
&
](
const
int
k_idx
,
const
int
offset_k
)
{
// static_assert(offset_k == 0 || offset_k == 1);
// static_assert(k_idx < 3);
struct
PtrWrapper
{
uint32_t
former
;
uint32_t
latter
;
};
PtrWrapper
glob_ptr
;
*
(
uint64_t
*
)
&
glob_ptr
=
reinterpret_cast
<
uint64_t
>
(
gQ
.
data
().
get
());
uint32x4_t
global_addr
=
{
0
};
global_addr
[
0
]
=
(
glob_ptr
.
former
);
global_addr
[
1
]
=
(
glob_ptr
.
latter
);
global_addr
[
2
]
=
0x80000000
;
global_addr
[
3
]
=
0x00020000
;
constexpr
int
elements_per_thread
=
8
;
constexpr
int
bytes_per_warp
=
64
*
8
*
2
;
constexpr
int
bytes_per_block
=
bytes_per_warp
*
4
;
const
int
row_idx
=
lane_idx
%
16
;
const
int
col_idx
=
lane_idx
/
16
;
const
int
row_offset
=
row_idx
;
if
constexpr
(
MODEL_TYPE
==
ModelType
::
V32
)
{
int
col_offset
;
if
(
k_idx
==
2
)
{
col_offset
=
k_idx
*
256
+
warp_idx
*
8
+
col_idx
*
16
;
}
else
{
col_offset
=
k_idx
*
256
+
warp_idx
*
64
+
col_idx
*
16
+
offset_k
*
8
;
}
int
offset_v
=
(
row_offset
*
params
.
stride_q_h_q
+
col_offset
)
*
2
;
if
(
head_block_idx
*
BLOCK_M
+
row_idx
>=
params
.
h_q
)
{
offset_v
=
-
1
;
}
if
(
k_idx
==
2
&&
warp_idx
>=
2
)
{
offset_v
=
-
1
;
}
const
int
offset_s
=
0
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
s_q
)
+
warp_idx
*
bytes_per_warp
+
k_idx
*
bytes_per_block
+
offset_k
*
3
*
bytes_per_block
;
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
}
else
{
const
int
col_offset
=
k_idx
*
256
+
warp_idx
*
64
+
col_idx
*
16
+
offset_k
*
8
;
int
offset_v
=
(
row_offset
*
params
.
stride_q_h_q
+
col_offset
)
*
2
;
if
(
head_block_idx
*
BLOCK_M
+
row_idx
>=
params
.
h_q
)
{
offset_v
=
-
1
;
}
const
int
offset_s
=
0
;
int
ldsAddrPerWave
=
reinterpret_cast
<
size_t
>
(
s_q
)
+
warp_idx
*
bytes_per_warp
+
k_idx
*
bytes_per_block
+
offset_k
*
2
*
bytes_per_block
;
asm
volatile
(
"s_mov_b32 m0, %1
\n\t
"
"buffer_load_dwordx4 %0, %2, %3 ,offen offset:0, lds
\n
"
::
"v"
(
offset_v
),
"s"
(
ldsAddrPerWave
),
"s"
(
global_addr
),
"s"
(
offset_s
)
:
);
}
};
if
constexpr
(
MODEL_TYPE
==
ModelType
::
V32
)
{
// __builtin_amdgcn_sched_barrier(0);
lds_direct_copy_q
(
0
,
0
);
lds_direct_copy_q
(
1
,
0
);
lds_direct_copy_q
(
0
,
1
);
lds_direct_copy_q
(
1
,
1
);
lds_direct_copy_q
(
2
,
0
);
Element
*
s_q_read_ptr
=
s_q
+
lane_idx
*
8
;
asm
volatile
(
"s_waitcnt vmcnt(4)
\n
s_barrier"
);
for
(
int
k
=
0
;
k
<
4
;
k
++
)
{
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
tSrQ
(
i
,
0
,
k
)
=
s_q_read_ptr
[
i
];
}
s_q_read_ptr
+=
16
*
32
;
}
asm
volatile
(
"s_waitcnt vmcnt(3)
\n
s_barrier"
);
for
(
int
k
=
4
;
k
<
8
;
k
++
)
{
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
tSrQ
(
i
,
0
,
k
)
=
s_q_read_ptr
[
i
];
}
s_q_read_ptr
+=
16
*
32
;
}
asm
volatile
(
"s_waitcnt vmcnt(2)
\n
s_barrier"
);
s_q_read_ptr
=
s_q
+
lane_idx
*
8
+
3
*
4
*
16
*
4
*
8
;
for
(
int
k
=
0
;
k
<
4
;
k
++
)
{
for
(
int
i
=
8
;
i
<
16
;
i
++
)
{
tSrQ
(
i
,
0
,
k
)
=
s_q_read_ptr
[
i
-
8
];
}
s_q_read_ptr
+=
16
*
32
;
}
asm
volatile
(
"s_waitcnt vmcnt(1)
\n
s_barrier"
);
for
(
int
k
=
4
;
k
<
8
;
k
++
)
{
for
(
int
i
=
8
;
i
<
16
;
i
++
)
{
tSrQ
(
i
,
0
,
k
)
=
s_q_read_ptr
[
i
-
8
];
}
s_q_read_ptr
+=
16
*
32
;
}
asm
volatile
(
"s_waitcnt vmcnt(0)
\n
s_barrier"
);
s_q_read_ptr
=
s_q
+
lane_idx
*
8
+
2
*
4
*
16
*
4
*
8
;
for
(
int
k
=
8
;
k
<
9
;
k
++
)
{
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
tSrQ
(
i
,
0
,
k
)
=
s_q_read_ptr
[
i
];
}
s_q_read_ptr
+=
16
*
32
;
}
for
(
int
k
=
8
;
k
<
9
;
k
++
)
{
for
(
int
i
=
8
;
i
<
16
;
i
++
)
{
tSrQ
(
i
,
0
,
k
)
=
s_q_read_ptr
[
i
-
8
];
}
s_q_read_ptr
+=
16
*
32
;
}
// __syncthreads();
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n
s_barrier"
);
// __builtin_amdgcn_sched_barrier(0);
}
else
{
// __builtin_amdgcn_sched_barrier(0);
lds_direct_copy_q
(
0
,
0
);
lds_direct_copy_q
(
1
,
0
);
lds_direct_copy_q
(
0
,
1
);
lds_direct_copy_q
(
1
,
1
);
Element
*
s_q_read_ptr
=
s_q
+
lane_idx
*
8
;
asm
volatile
(
"s_waitcnt vmcnt(3)
\n
s_barrier"
);
for
(
int
k
=
0
;
k
<
4
;
k
++
)
{
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
tSrQ
(
i
,
0
,
k
)
=
s_q_read_ptr
[
i
];
}
s_q_read_ptr
+=
16
*
32
;
}
asm
volatile
(
"s_waitcnt vmcnt(2)
\n
s_barrier"
);
for
(
int
k
=
4
;
k
<
8
;
k
++
)
{
for
(
int
i
=
0
;
i
<
8
;
i
++
)
{
tSrQ
(
i
,
0
,
k
)
=
s_q_read_ptr
[
i
];
}
s_q_read_ptr
+=
16
*
32
;
}
asm
volatile
(
"s_waitcnt vmcnt(1)
\n
s_barrier"
);
for
(
int
k
=
0
;
k
<
4
;
k
++
)
{
for
(
int
i
=
8
;
i
<
16
;
i
++
)
{
tSrQ
(
i
,
0
,
k
)
=
s_q_read_ptr
[
i
-
8
];
}
s_q_read_ptr
+=
16
*
32
;
}
asm
volatile
(
"s_waitcnt vmcnt(0)
\n
s_barrier"
);
for
(
int
k
=
4
;
k
<
8
;
k
++
)
{
for
(
int
i
=
8
;
i
<
16
;
i
++
)
{
tSrQ
(
i
,
0
,
k
)
=
s_q_read_ptr
[
i
-
8
];
}
s_q_read_ptr
+=
16
*
32
;
}
asm
volatile
(
"s_waitcnt lgkmcnt(0)
\n
s_barrier"
);
// __builtin_amdgcn_sched_barrier(0);
}
#endif
// zhj debug
// zhj debug
// if (head_block_idx == 0)
// if (head_block_idx == 0)
// {
// {
...
@@ -133,7 +294,35 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
...
@@ -133,7 +294,35 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
uint32x4_t
data_128
;
uint32x4_t
data_128
;
uint16_t
data_array
[
8
];
uint16_t
data_array
[
8
];
};
};
struct
MainloopArgs
{
int
start_block_idx
,
end_block_idx
;
bool
is_no_split
;
// The following fields are only valid for MODEL1
int
topk_length
,
extra_topk_length
,
num_orig_kv_blocks
;
};
auto
get_cur_req_info
=
[
&
](
int
batch_idx
)
->
MainloopArgs
{
MainloopArgs
args
;
int
total_topk_padded
;
if
constexpr
(
MODEL_TYPE
==
ModelType
::
V32
)
{
total_topk_padded
=
params
.
topk
;
}
else
{
int
topk_length
=
params
.
topk_length
?
__ldg
(
params
.
topk_length
+
batch_idx
)
:
params
.
topk
;
int
orig_topk_padded
=
max
(
ku
::
ceil
(
topk_length
,
(
int
)
TOPK_BLOCK_SIZE
),
(
int
)
TOPK_BLOCK_SIZE
);
int
extra_topk_length
=
params
.
extra_topk_length
?
__ldg
(
params
.
extra_topk_length
+
batch_idx
)
:
params
.
extra_topk
;
total_topk_padded
=
orig_topk_padded
+
ku
::
ceil
(
extra_topk_length
,
(
int
)
TOPK_BLOCK_SIZE
);
args
.
topk_length
=
topk_length
;
args
.
extra_topk_length
=
extra_topk_length
;
args
.
num_orig_kv_blocks
=
orig_topk_padded
/
TOPK_BLOCK_SIZE
;
}
args
.
start_block_idx
=
batch_idx
==
sched_meta
.
begin_req_idx
?
sched_meta
.
begin_block_idx
:
0
;
args
.
end_block_idx
=
batch_idx
==
sched_meta
.
end_req_idx
?
sched_meta
.
end_block_idx
:
total_topk_padded
/
TOPK_BLOCK_SIZE
;
args
.
is_no_split
=
batch_idx
==
sched_meta
.
begin_req_idx
?
!
sched_meta
.
is_first_req_splitted
:
(
batch_idx
==
sched_meta
.
end_req_idx
?
!
sched_meta
.
is_last_req_splitted
:
true
);
return
args
;
};
Tensor
acc_o
=
partition_fragment_C
(
tiled_mma_o
,
Shape
<
Int
<
BLOCK_M
>
,
Int
<
HEAD_DIM_V
>>
{});
Tensor
acc_o
=
partition_fragment_C
(
tiled_mma_o
,
Shape
<
Int
<
BLOCK_M
>
,
Int
<
HEAD_DIM_V
>>
{});
clear
(
acc_o
);
clear
(
acc_o
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment