Commit 2461ea9d authored by zhuwenwen's avatar zhuwenwen
Browse files

update flash_mla_with_kvcache

parent 8bfd0bde
......@@ -33,6 +33,7 @@ from vllm.v1.attention.backends.utils import (
reshape_query_for_spec_decode,
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.platforms import current_platform
logger = init_logger(__name__)
......@@ -298,6 +299,19 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
# zeros of length B+1
num_splits = torch.zeros((B + 1,), dtype=dtype, device=device)
if current_platform.is_rocm():
o, lse = flash_mla_with_kvcache(
q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=tile_scheduler_metadata,
num_splits=num_splits,
softmax_scale=self.scale,
causal=True,
)
else:
o, lse = flash_mla_with_kvcache(
q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment