"vscode:/vscode.git/clone" did not exist on "e387ee74fb9666a6a2338901792d45cb7ee8b263"
Commit fd2b2d8f authored by zhanghj2
Browse files

Fix the fp8 e5m2 fusion issue

parent 2ff5a773
...@@ -1666,17 +1666,17 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_nope_pe_g ...@@ -1666,17 +1666,17 @@ __forceinline__ __device__ void compute_attn_1rowblock_splitkv_mla_fp8_nope_pe_g
__fp16 fp16_v; __fp16 fp16_v;
uint16_t tmp; uint16_t tmp;
}; };
if constexpr (std::is_same_v<Element_O, cutlass::bfloat16_t>) { // if constexpr (std::is_same_v<Element_O, cutlass::bfloat16_t>) {
for (int i = 0; i < size(tSrQ_copy_view); i++) // for (int i = 0; i < size(tSrQ_copy_view); i++)
{ // {
uint16_t tmp = tSrQ_copy_view(i).storage; // uint16_t tmp = tSrQ_copy_view(i).storage;
Fp32_storage fp32; // Fp32_storage fp32;
fp32.u32 = tmp << 16; // fp32.u32 = tmp << 16;
Fp16_storage fp16_t; // Fp16_storage fp16_t;
fp16_t.fp16_v = static_cast<__fp16>(fp32.fp32); // fp16_t.fp16_v = static_cast<__fp16>(fp32.fp32);
tSrQ_copy_view(i) = cutlass::half_t::bitcast(fp16_t.tmp); // tSrQ_copy_view(i) = cutlass::half_t::bitcast(fp16_t.tmp);
} // }
} // }
#else #else
...@@ -4303,7 +4303,7 @@ void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params &params, const std::string& kv ...@@ -4303,7 +4303,7 @@ void run_mha_fwd_splitkv_mla(Flash_fwd_mla_params &params, const std::string& kv
run_flash_splitkv_fwd_mla_q_nope_pe<Kernel_traits, flash::SharedStorageMLA<Kernel_traits>, Fp8KVCacheDataType::kAuto>(params, stream); run_flash_splitkv_fwd_mla_q_nope_pe<Kernel_traits, flash::SharedStorageMLA<Kernel_traits>, Fp8KVCacheDataType::kAuto>(params, stream);
} }
else if (kv_cache_dtype == "fp8_e5m2") { else if (kv_cache_dtype == "fp8_e5m2") {
using Kernel_traits = Flash_fwd_kernel_traits_mla_kvfp8<576, 16, 64, 4, cutlass::half_t, 512, T>; using Kernel_traits = Flash_fwd_kernel_traits_mla_kvfp8<576, 16, 64, 4, T, 512, T>;
run_flash_splitkv_fwd_mla_q_nope_pe<Kernel_traits, flash::SharedStorageMLAFp8<Kernel_traits>, Fp8KVCacheDataType::kFp8E5M2>(params, stream); run_flash_splitkv_fwd_mla_q_nope_pe<Kernel_traits, flash::SharedStorageMLAFp8<Kernel_traits>, Fp8KVCacheDataType::kFp8E5M2>(params, stream);
} else { } else {
printf("is_q_nope_pe = %d Unsupported kv cache dtype \n", is_q_nope_pe); printf("is_q_nope_pe = %d Unsupported kv cache dtype \n", is_q_nope_pe);
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment