Commit 77210184 authored by zhuwenwen's avatar zhuwenwen
Browse files

update flash_mla_with_kvcache_fp8 interface and k_cache

parent 347fc09c
...@@ -17,7 +17,6 @@ from vllm.attention.backends.mla.common import (MLACommonBackend, ...@@ -17,7 +17,6 @@ from vllm.attention.backends.mla.common import (MLACommonBackend,
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
get_mla_metadata, get_mla_metadata,
flash_mla_with_kvcache_fp8, flash_mla_with_kvcache_fp8,
get_mla_decoding_metadata_dense_fp8,
is_flashmla_supported) is_flashmla_supported)
from vllm import envs from vllm import envs
...@@ -239,7 +238,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -239,7 +238,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype == "fp8_e4m3" and envs.VLLM_USE_FLASH_MLA_FP8: if torch.cuda.get_device_properties("cuda").gcnArchName.split(':')[0] == "gfx938" and kv_cache_dtype == "fp8_e4m3" and envs.VLLM_USE_FLASH_MLA_FP8:
o, _ = flash_mla_with_kvcache_fp8( o, _ = flash_mla_with_kvcache_fp8(
q=q.to(torch.float8_e4m3fn), q=q.to(torch.float8_e4m3fn),
k_cache=kv_c_and_k_pe_cache.view(torch.float8_e4m3fn).unsqueeze(-2), # Add head dim of 1 k_cache=kv_c_and_k_pe_cache.unsqueeze(-2).to(torch.float8_e4m3fn), # Add head dim of 1
block_table=decode_meta.block_tables, block_table=decode_meta.block_tables,
cache_seqlens=decode_meta.seq_lens_tensor, cache_seqlens=decode_meta.seq_lens_tensor,
head_dim_v=self.kv_lora_rank, head_dim_v=self.kv_lora_rank,
......
...@@ -73,6 +73,7 @@ def get_mla_decoding_metadata_dense_fp8( ...@@ -73,6 +73,7 @@ def get_mla_decoding_metadata_dense_fp8(
cache_seqlens: torch.Tensor, cache_seqlens: torch.Tensor,
num_heads_per_head_k: int, num_heads_per_head_k: int,
num_heads_k: int, num_heads_k: int,
num_heads_q : int = 16,
) -> Tuple[torch.Tensor, torch.Tensor]: ) -> Tuple[torch.Tensor, torch.Tensor]:
""" """
Arguments: Arguments:
...@@ -87,7 +88,7 @@ def get_mla_decoding_metadata_dense_fp8( ...@@ -87,7 +88,7 @@ def get_mla_decoding_metadata_dense_fp8(
""" """
return flash_mla_cuda.get_mla_decoding_metadata_dense_fp8(cache_seqlens, return flash_mla_cuda.get_mla_decoding_metadata_dense_fp8(cache_seqlens,
num_heads_per_head_k, num_heads_per_head_k,
num_heads_k) num_heads_k, num_heads_q)
def flash_mla_with_kvcache( def flash_mla_with_kvcache(
......
...@@ -12,7 +12,6 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, ...@@ -12,7 +12,6 @@ from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
flash_mla_with_kvcache_q_nope_pe, flash_mla_with_kvcache_q_nope_pe,
get_mla_metadata, get_mla_metadata,
flash_mla_with_kvcache_fp8, flash_mla_with_kvcache_fp8,
get_mla_decoding_metadata_dense_fp8,
is_flashmla_supported) is_flashmla_supported)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend, from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
...@@ -183,10 +182,9 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -183,10 +182,9 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
else: else:
q = torch.cat([q_nope, q_pe], dim=-1)\ q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode) .unsqueeze(1) # Add seqlen dim of 1 (decode)
o, _ = flash_mla_with_kvcache_fp8( o, _ = flash_mla_with_kvcache_fp8(
q=q.to(torch.float8_e4m3fn), q=q.to(torch.float8_e4m3fn),
k_cache=kv_c_and_k_pe_cache.view(torch.float8_e4m3fn).unsqueeze(-2), # Add head dim of 1 k_cache=kv_c_and_k_pe_cache.unsqueeze(-2).to(torch.float8_e4m3fn), # Add head dim of 1
block_table=attn_metadata.decode.block_table, block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens, cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank, head_dim_v=self.kv_lora_rank,
...@@ -213,7 +211,6 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -213,7 +211,6 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q = torch.cat([q_nope, q_pe], dim=-1)\ q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode) .unsqueeze(1) # Add seqlen dim of 1 (decode)
if not envs.VLLM_USE_CAT_MLA or kv_cache_dtype == "fp8_e4m3":
o, _ = flash_mla_with_kvcache( o, _ = flash_mla_with_kvcache(
q=q, q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment