Commit 9663a03f authored by zhuwenwen's avatar zhuwenwen
Browse files

update flash_mla_with_kvcache

parent d0e16bf5
......@@ -224,18 +224,16 @@ def flash_mla_with_kvcache(
else:
if current_platform.is_rocm():
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
q,
k_cache,
block_table,
cache_seqlens,
head_dim_v,
tile_scheduler_metadata,
num_splits,
q,
k_cache,
None,
head_dim_v,
cache_seqlens,
block_table,
softmax_scale,
causal,
is_fp8_kvcache,
indices,
)
causal,
tile_scheduler_metadata,
num_splits)
else:
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
q,
......
......@@ -34,6 +34,7 @@ from vllm.v1.attention.backends.utils import (
)
from vllm.v1.kv_cache_interface import AttentionSpec
from vllm import envs
from vllm.platforms import current_platform
logger = init_logger(__name__)
......@@ -310,19 +311,32 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
# zeros of length B+1
num_splits = torch.zeros((B + 1,), dtype=dtype, device=device)
o, lse = flash_mla_with_kvcache(
q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=tile_scheduler_metadata,
num_splits=num_splits,
softmax_scale=self.scale,
causal=True,
descale_q=layer._q_scale.reshape(1),
descale_k=layer._k_scale.reshape(1),
)
if current_platform.is_rocm():
o, lse = flash_mla_with_kvcache(
q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=tile_scheduler_metadata,
num_splits=num_splits,
softmax_scale=self.scale,
causal=True,
)
else:
o, lse = flash_mla_with_kvcache(
q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=tile_scheduler_metadata,
num_splits=num_splits,
softmax_scale=self.scale,
causal=True,
descale_q=layer._q_scale.reshape(1),
descale_k=layer._k_scale.reshape(1),
)
o = reshape_attn_output_for_spec_decode(o)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment