Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
6fb681fc
Commit
6fb681fc
authored
Jan 27, 2026
by
zhanghj2
Browse files
lambda函数优化代码结构
parent
75f8262c
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
10 additions
and
3 deletions
+10
-3
csrc/sm90/decode/sparse_fp8/splitkv_mla.cuh
csrc/sm90/decode/sparse_fp8/splitkv_mla.cuh
+10
-3
No files found.
csrc/sm90/decode/sparse_fp8/splitkv_mla.cuh
View file @
6fb681fc
...
@@ -136,10 +136,12 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
...
@@ -136,10 +136,12 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
Tensor
acc_o
=
partition_fragment_C
(
tiled_mma_o
,
Shape
<
Int
<
BLOCK_M
>
,
Int
<
HEAD_DIM_V
>>
{});
Tensor
acc_o
=
partition_fragment_C
(
tiled_mma_o
,
Shape
<
Int
<
BLOCK_M
>
,
Int
<
HEAD_DIM_V
>>
{});
clear
(
acc_o
);
clear
(
acc_o
);
flash
::
Softmax
<
size
<
1
>
(
acc_o
)
>
softmax
;
flash
::
Softmax
<
size
<
1
>
(
acc_o
)
>
softmax
;
MainloopArgs
args
=
get_cur_req_info
(
batch_idx
);
for
(
int
block_idx
=
args
.
start_block_idx
;
block_idx
<
args
.
end_block_idx
;
block_idx
++
)
{
struct
IsOrigBlock
{};
struct
IsExtraBlock
{};
auto
process_one_block
=
[
&
](
int
block_idx
,
auto
is_extra_block_t
)
{
static
constexpr
bool
IS_EXTRA_BLOCK
=
std
::
is_same_v
<
decltype
(
is_extra_block_t
),
IsExtraBlock
>
;
Tensor
acc_s
=
partition_fragment_C
(
tiled_mma
,
Shape
<
Int
<
BLOCK_M
>
,
Int
<
TOPK_BLOCK_SIZE
>>
{});
Tensor
acc_s
=
partition_fragment_C
(
tiled_mma
,
Shape
<
Int
<
BLOCK_M
>
,
Int
<
TOPK_BLOCK_SIZE
>>
{});
clear
(
acc_s
);
clear
(
acc_s
);
int
col_idx
=
lane_idx
/
16
;
int
col_idx
=
lane_idx
/
16
;
...
@@ -395,6 +397,11 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
...
@@ -395,6 +397,11 @@ __device__ void KernelTemplate<MODEL_TYPE, NUM_HEADS>::compute_attn_1rowblock_sp
cute
::
gemm
(
tiled_mma_o
,
tOrP
(
_
,
_
,
2
),
tOrVt
(
_
,
_
,
2
),
acc_o
);
cute
::
gemm
(
tiled_mma_o
,
tOrP
(
_
,
_
,
2
),
tOrVt
(
_
,
_
,
2
),
acc_o
);
cute
::
gemm
(
tiled_mma_o
,
tOrP
(
_
,
_
,
3
),
tOrVt
(
_
,
_
,
3
),
acc_o
);
cute
::
gemm
(
tiled_mma_o
,
tOrP
(
_
,
_
,
3
),
tOrVt
(
_
,
_
,
3
),
acc_o
);
}
}
};
MainloopArgs
args
=
get_cur_req_info
(
batch_idx
);
for
(
int
block_idx
=
args
.
start_block_idx
;
block_idx
<
args
.
end_block_idx
;
block_idx
++
)
{
process_one_block
(
block_idx
,
IsOrigBlock
{});
}
}
if
(
args
.
is_no_split
)
{
if
(
args
.
is_no_split
)
{
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment