"vscode:/vscode.git/clone" did not exist on "b6087a6beead9165f4c77ceba592b3651bb37de9"
Commit 2461ea9d authored by zhuwenwen
Browse files

update flash_mla_with_kvcache

parent 8bfd0bde
...@@ -33,6 +33,7 @@ from vllm.v1.attention.backends.utils import ( ...@@ -33,6 +33,7 @@ from vllm.v1.attention.backends.utils import (
reshape_query_for_spec_decode, reshape_query_for_spec_decode,
) )
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -298,6 +299,19 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -298,6 +299,19 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
# zeros of length B+1 # zeros of length B+1
num_splits = torch.zeros((B + 1,), dtype=dtype, device=device) num_splits = torch.zeros((B + 1,), dtype=dtype, device=device)
if current_platform.is_rocm():
o, lse = flash_mla_with_kvcache(
q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=tile_scheduler_metadata,
num_splits=num_splits,
softmax_scale=self.scale,
causal=True,
)
else:
o, lse = flash_mla_with_kvcache( o, lse = flash_mla_with_kvcache(
q=q, q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment