Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
91691124
Commit
91691124
authored
Feb 06, 2026
by
zhanghj2
Browse files
优化combine
parent
c4412432
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
48 additions
and
19 deletions
+48
-19
csrc/smxx/decode/combine/combine.cu
csrc/smxx/decode/combine/combine.cu
+47
-18
csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu
...decode/get_decoding_sched_meta/get_decoding_sched_meta.cu
+1
-1
No files found.
csrc/smxx/decode/combine/combine.cu
View file @
91691124
...
...
@@ -40,7 +40,7 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
return
;
}
FLASH_DEVICE_ASSERT
(
my_num_splits
<=
MAX_SPLITS
);
//
FLASH_DEVICE_ASSERT(my_num_splits <= MAX_SPLITS);
Tensor
gLseAccum
=
make_tensor
(
make_gmem_ptr
((
float
*
)
params
.
lse_accum
+
start_split_idx
*
params
.
stride_lse_accum_split
+
s_q_idx
*
params
.
stride_lse_accum_s_q
+
h_block_idx
*
BLOCK_SIZE_M
),
...
...
@@ -127,6 +127,7 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
__syncthreads
();
static_assert
(
HEAD_DIM_V
%
(
64
*
4
)
==
0
);
constexpr
int
ELEMS_PER_THREAD
=
HEAD_DIM_V
/
(
64
*
4
);
static_assert
(
ELEMS_PER_THREAD
==
2
);
float
*
oaccum_ptr
=
params
.
o_accum
+
start_split_idx
*
params
.
stride_o_accum_split
+
s_q_idx
*
params
.
stride_o_accum_s_q
+
(
h_block_idx
*
BLOCK_SIZE_M
+
warp_idx
)
*
params
.
stride_o_accum_h_q
;
float4
datas
[
ELEMS_PER_THREAD
];
CUTLASS_PRAGMA_UNROLL
...
...
@@ -165,24 +166,52 @@ flash_fwd_mla_combine_kernel(const CombineParams params) {
// printf(" %.3f \n", result[0].x);
// }
const
int
h_q_idx
=
h_block_idx
*
BLOCK_SIZE_M
+
warp_idx
;
ElementT
*
o_ptr
=
(
ElementT
*
)
params
.
out
+
batch_idx
*
params
.
stride_o_b
+
s_q_idx
*
params
.
stride_o_s_q
+
h_q_idx
*
params
.
stride_o_h_q
;
CUTLASS_PRAGMA_UNROLL
ElementT
*
o_ptr
=
(
ElementT
*
)
params
.
out
+
batch_idx
*
params
.
stride_o_b
+
s_q_idx
*
params
.
stride_o_s_q
+
h_q_idx
*
params
.
stride_o_h_q
+
lane_idx
*
8
;
ElementT
data_converted
[
8
];
using
result_type
=
cutlass
::
Array
<
ElementT
,
2
>
;
for
(
int
i
=
0
;
i
<
ELEMS_PER_THREAD
;
++
i
)
{
float4
data
=
result
[
i
];
ElementT
data_converted
[
4
];
// auto res = __builtin_hcu_cvt_pk_bf16_f32(0, data.x, 0, data.y, 0);
// data_converted[0].storage = res[0];
// data_converted[1].storage = res[1];
// res = __builtin_hcu_cvt_pk_bf16_f32(0, data.z, 0, data.w, 0);
// data_converted[2].storage = res[0];
// data_converted[3].storage = res[1];
data_converted
[
0
]
=
(
ElementT
)(
data
.
x
);
data_converted
[
1
]
=
(
ElementT
)(
data
.
y
);
data_converted
[
2
]
=
(
ElementT
)(
data
.
z
);
data_converted
[
3
]
=
(
ElementT
)(
data
.
w
);
static_assert
(
sizeof
(
ElementT
)
==
2
);
*
(
uint64_t
*
)(
o_ptr
+
lane_idx
*
8
+
i
*
4
)
=
*
(
uint64_t
*
)
data_converted
;
if
constexpr
(
std
::
is_same_v
<
cutlass
::
bfloat16_t
,
ElementT
>
)
{
#if defined(__gfx938__)
auto
d0
=
__builtin_hcu_cvt_pk_bf16_f32
(
0
,
result
[
i
].
x
,
0
,
result
[
i
].
y
,
0
);
auto
d1
=
__builtin_hcu_cvt_pk_bf16_f32
(
0
,
result
[
i
].
z
,
0
,
result
[
i
].
w
,
0
);
auto
res0
=
reinterpret_cast
<
result_type
const
&>
(
d0
);
auto
res1
=
reinterpret_cast
<
result_type
const
&>
(
d1
);
o_ptr
[
i
*
4
]
=
res0
[
0
];
o_ptr
[
i
*
4
+
1
]
=
res0
[
1
];
o_ptr
[
i
*
4
+
2
]
=
res1
[
0
];
o_ptr
[
i
*
4
+
3
]
=
res1
[
1
];
#else
// auto float32_to_bfloat16 = [&](float v) -> ElementT {
// union {
// float fp32;
// uint32_t int32;
// } u = {v};
// ElementT res;
// res.storage = (u.int32 >> 16);
// return res;
// };
// float4 data = result[i];
// o_ptr[i * 4] = float32_to_bfloat16((data.x));
// o_ptr[i * 4 + 1] = float32_to_bfloat16((data.y));
// o_ptr[i * 4 + 2] = float32_to_bfloat16((data.z));
// o_ptr[i * 4 + 3] = float32_to_bfloat16((data.w));
data_converted
[
i
*
4
]
=
(
ElementT
)(
data
.
x
);
data_converted
[
i
*
4
+
1
]
=
(
ElementT
)(
data
.
y
);
data_converted
[
i
*
4
+
2
]
=
(
ElementT
)(
data
.
z
);
data_converted
[
i
*
4
+
3
]
=
(
ElementT
)(
data
.
w
);
#endif
}
else
{
auto
d0
=
__builtin_hcu_cvt_pkrtz
(
result
[
i
].
x
,
result
[
i
].
y
);
auto
d1
=
__builtin_hcu_cvt_pkrtz
(
result
[
i
].
z
,
result
[
i
].
w
);
auto
res0
=
reinterpret_cast
<
result_type
const
&>
(
d0
);
auto
res1
=
reinterpret_cast
<
result_type
const
&>
(
d1
);
o_ptr
[
i
*
4
]
=
res0
[
0
];
o_ptr
[
i
*
4
+
1
]
=
res0
[
1
];
o_ptr
[
i
*
4
+
2
]
=
res1
[
0
];
o_ptr
[
i
*
4
+
3
]
=
res1
[
1
];
}
}
}
}
...
...
csrc/smxx/decode/get_decoding_sched_meta/get_decoding_sched_meta.cu
View file @
91691124
...
...
@@ -96,7 +96,7 @@ get_mla_metadata_kernel(const GetDecodeSchedMetaParams params) {
}
tile_scheduler_metadata_ptr
[
i
]
=
cur_meta
;
}
FLASH_DEVICE_ASSERT
(
now_req_idx
==
batch_size
&&
now_block
==
0
&&
now_n_split_idx
==
0
);
//
FLASH_DEVICE_ASSERT(now_req_idx == batch_size && now_block == 0 && now_n_split_idx == 0);
}
__syncthreads
();
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment