Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
fd2b2d8f
"tests_mpi/test_internode.py" did not exist on "6b17f4fa6e2fb093d7dc73563f52e3b32d088a6b"
Commit
fd2b2d8f
authored
Mar 20, 2026
by
zhanghj2
Browse files
fix fp8 e5m2融合问题
parent
2ff5a773
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
12 deletions
+12
-12
csrc/extension/flash_fwd_mla_kernel.h
csrc/extension/flash_fwd_mla_kernel.h
+12
-12
No files found.
csrc/extension/flash_fwd_mla_kernel.h
View file @
fd2b2d8f
...
...
@@ -1666,17 +1666,17 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_nope_pe_g
__fp16
fp16_v
;
uint16_t
tmp
;
};
if
constexpr
(
std
::
is_same_v
<
Element_O
,
cutlass
::
bfloat16_t
>
)
{
for
(
int
i
=
0
;
i
<
size
(
tSrQ_copy_view
);
i
++
)
{
uint16_t
tmp
=
tSrQ_copy_view
(
i
).
storage
;
Fp32_storage
fp32
;
fp32
.
u32
=
tmp
<<
16
;
Fp16_storage
fp16_t
;
fp16_t
.
fp16_v
=
static_cast
<
__fp16
>
(
fp32
.
fp32
);
tSrQ_copy_view
(
i
)
=
cutlass
::
half_t
::
bitcast
(
fp16_t
.
tmp
);
}
}
//
if constexpr (std::is_same_v<Element_O, cutlass::bfloat16_t>) {
//
for (int i = 0; i < size(tSrQ_copy_view); i++)
//
{
//
uint16_t tmp = tSrQ_copy_view(i).storage;
//
Fp32_storage fp32;
//
fp32.u32 = tmp << 16;
//
Fp16_storage fp16_t;
//
fp16_t.fp16_v = static_cast<__fp16>(fp32.fp32);
//
tSrQ_copy_view(i) = cutlass::half_t::bitcast(fp16_t.tmp);
//
}
//
}
#else
...
...
@@ -4303,7 +4303,7 @@ void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, const std::string& kv
run_flash_splitkv_fwd_mla_q_nope_pe
<
Kernel_traits
,
flash
::
SharedStorageMLA
<
Kernel_traits
>
,
Fp8KVCacheDataType
::
kAuto
>
(
params
,
stream
);
}
else
if
(
kv_cache_dtype
==
"fp8_e5m2"
)
{
using
Kernel_traits
=
Flash_fwd_kernel_traits_mla_kvfp8
<
576
,
16
,
64
,
4
,
cutlass
::
half_t
,
512
,
T
>
;
using
Kernel_traits
=
Flash_fwd_kernel_traits_mla_kvfp8
<
576
,
16
,
64
,
4
,
T
,
512
,
T
>
;
run_flash_splitkv_fwd_mla_q_nope_pe
<
Kernel_traits
,
flash
::
SharedStorageMLAFp8
<
Kernel_traits
>
,
Fp8KVCacheDataType
::
kFp8E5M2
>
(
params
,
stream
);
}
else
{
printf
(
"is_q_nope_pe = %d Unsupported kv cache dtype
\n
"
,
is_q_nope_pe
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment