Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in
Toggle navigation
Menu
Open sidebar
OpenDAS
FlashMLA
Commits
fd2b2d8f
Commit
fd2b2d8f
authored
Mar 20, 2026
by
zhanghj2
Browse files
fix fp8 e5m2融合问题
parent
2ff5a773
Changes
1
Show whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
12 deletions
+12
-12
csrc/extension/flash_fwd_mla_kernel.h
csrc/extension/flash_fwd_mla_kernel.h
+12
-12
No files found.
csrc/extension/flash_fwd_mla_kernel.h
View file @
fd2b2d8f
...
@@ -1666,17 +1666,17 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_nope_pe_g
...
@@ -1666,17 +1666,17 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_nope_pe_g
__fp16
fp16_v
;
__fp16
fp16_v
;
uint16_t
tmp
;
uint16_t
tmp
;
};
};
if
constexpr
(
std
::
is_same_v
<
Element_O
,
cutlass
::
bfloat16_t
>
)
{
//
if constexpr (std::is_same_v<Element_O, cutlass::bfloat16_t>) {
for
(
int
i
=
0
;
i
<
size
(
tSrQ_copy_view
);
i
++
)
//
for (int i = 0; i < size(tSrQ_copy_view); i++)
{
//
{
uint16_t
tmp
=
tSrQ_copy_view
(
i
).
storage
;
//
uint16_t tmp = tSrQ_copy_view(i).storage;
Fp32_storage
fp32
;
//
Fp32_storage fp32;
fp32
.
u32
=
tmp
<<
16
;
//
fp32.u32 = tmp << 16;
Fp16_storage
fp16_t
;
//
Fp16_storage fp16_t;
fp16_t
.
fp16_v
=
static_cast
<
__fp16
>
(
fp32
.
fp32
);
//
fp16_t.fp16_v = static_cast<__fp16>(fp32.fp32);
tSrQ_copy_view
(
i
)
=
cutlass
::
half_t
::
bitcast
(
fp16_t
.
tmp
);
//
tSrQ_copy_view(i) = cutlass::half_t::bitcast(fp16_t.tmp);
}
//
}
}
//
}
#else
#else
...
@@ -4303,7 +4303,7 @@ void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, const std::string& kv
...
@@ -4303,7 +4303,7 @@ void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params ¶ms, const std::string& kv
run_flash_splitkv_fwd_mla_q_nope_pe
<
Kernel_traits
,
flash
::
SharedStorageMLA
<
Kernel_traits
>
,
Fp8KVCacheDataType
::
kAuto
>
(
params
,
stream
);
run_flash_splitkv_fwd_mla_q_nope_pe
<
Kernel_traits
,
flash
::
SharedStorageMLA
<
Kernel_traits
>
,
Fp8KVCacheDataType
::
kAuto
>
(
params
,
stream
);
}
}
else
if
(
kv_cache_dtype
==
"fp8_e5m2"
)
{
else
if
(
kv_cache_dtype
==
"fp8_e5m2"
)
{
using
Kernel_traits
=
Flash_fwd_kernel_traits_mla_kvfp8
<
576
,
16
,
64
,
4
,
cutlass
::
half_t
,
512
,
T
>
;
using
Kernel_traits
=
Flash_fwd_kernel_traits_mla_kvfp8
<
576
,
16
,
64
,
4
,
T
,
512
,
T
>
;
run_flash_splitkv_fwd_mla_q_nope_pe
<
Kernel_traits
,
flash
::
SharedStorageMLAFp8
<
Kernel_traits
>
,
Fp8KVCacheDataType
::
kFp8E5M2
>
(
params
,
stream
);
run_flash_splitkv_fwd_mla_q_nope_pe
<
Kernel_traits
,
flash
::
SharedStorageMLAFp8
<
Kernel_traits
>
,
Fp8KVCacheDataType
::
kFp8E5M2
>
(
params
,
stream
);
}
else
{
}
else
{
printf
(
"is_q_nope_pe = %d Unsupported kv cache dtype
\n
"
,
is_q_nope_pe
);
printf
(
"is_q_nope_pe = %d Unsupported kv cache dtype
\n
"
,
is_q_nope_pe
);
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment