Commit 9663a03f authored by zhuwenwen
Browse files

update flash_mla_with_kvcache

parent d0e16bf5
...@@ -226,16 +226,14 @@ def flash_mla_with_kvcache( ...@@ -226,16 +226,14 @@ def flash_mla_with_kvcache(
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla( out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
q, q,
k_cache, k_cache,
block_table, None,
cache_seqlens,
head_dim_v, head_dim_v,
tile_scheduler_metadata, cache_seqlens,
num_splits, block_table,
softmax_scale, softmax_scale,
causal, causal,
is_fp8_kvcache, tile_scheduler_metadata,
indices, num_splits)
)
else: else:
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla( out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
q, q,
......
...@@ -34,6 +34,7 @@ from vllm.v1.attention.backends.utils import ( ...@@ -34,6 +34,7 @@ from vllm.v1.attention.backends.utils import (
) )
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm import envs from vllm import envs
from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -310,6 +311,19 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -310,6 +311,19 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
# zeros of length B+1 # zeros of length B+1
num_splits = torch.zeros((B + 1,), dtype=dtype, device=device) num_splits = torch.zeros((B + 1,), dtype=dtype, device=device)
if current_platform.is_rocm():
o, lse = flash_mla_with_kvcache(
q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=tile_scheduler_metadata,
num_splits=num_splits,
softmax_scale=self.scale,
causal=True,
)
else:
o, lse = flash_mla_with_kvcache( o, lse = flash_mla_with_kvcache(
q=q, q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment