Commit 3e191138 authored by zhuwenwen's avatar zhuwenwen
Browse files

maintain consistency between k_cache type and q

parent da85feb7
......@@ -239,7 +239,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype == "fp8_e4m3" and envs.VLLM_USE_FLASH_MLA_FP8:
o, _ = flash_mla_with_kvcache_fp8(
q=q.to(torch.float8_e4m3fn),
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
k_cache=kv_c_and_k_pe_cache.view(torch.float8_e4m3fn).unsqueeze(-2), # Add head dim of 1
block_table=decode_meta.block_tables,
cache_seqlens=decode_meta.seq_lens_tensor,
head_dim_v=self.kv_lora_rank,
......
......@@ -186,7 +186,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
o, _ = flash_mla_with_kvcache_fp8(
q=q.to(torch.float8_e4m3fn),
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
k_cache=kv_c_and_k_pe_cache.view(torch.float8_e4m3fn).unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment