Commit 9663a03f authored by zhuwenwen
Browse files

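Update flash_mla_with_kvcache: reorder the ROCm fwd_kvcache_mla arguments and omit descale_q/descale_k from the ROCm call in FlashMLAImpl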

parent d0e16bf5
...@@ -224,18 +224,16 @@ def flash_mla_with_kvcache( ...@@ -224,18 +224,16 @@ def flash_mla_with_kvcache(
else: else:
if current_platform.is_rocm(): if current_platform.is_rocm():
out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla( out, softmax_lse = flash_mla_cuda.fwd_kvcache_mla(
q, q,
k_cache, k_cache,
block_table, None,
cache_seqlens, head_dim_v,
head_dim_v, cache_seqlens,
tile_scheduler_metadata, block_table,
num_splits,
softmax_scale, softmax_scale,
causal, causal,
is_fp8_kvcache, tile_scheduler_metadata,
indices, num_splits)
)
else: else:
out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla( out, softmax_lse = torch.ops._flashmla_C.fwd_kvcache_mla(
q, q,
......
...@@ -34,6 +34,7 @@ from vllm.v1.attention.backends.utils import ( ...@@ -34,6 +34,7 @@ from vllm.v1.attention.backends.utils import (
) )
from vllm.v1.kv_cache_interface import AttentionSpec from vllm.v1.kv_cache_interface import AttentionSpec
from vllm import envs from vllm import envs
from vllm.platforms import current_platform
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -310,19 +311,32 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -310,19 +311,32 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
# zeros of length B+1 # zeros of length B+1
num_splits = torch.zeros((B + 1,), dtype=dtype, device=device) num_splits = torch.zeros((B + 1,), dtype=dtype, device=device)
o, lse = flash_mla_with_kvcache( if current_platform.is_rocm():
q=q, o, lse = flash_mla_with_kvcache(
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 q=q,
block_table=attn_metadata.decode.block_table, k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
cache_seqlens=attn_metadata.decode.seq_lens, block_table=attn_metadata.decode.block_table,
head_dim_v=self.kv_lora_rank, cache_seqlens=attn_metadata.decode.seq_lens,
tile_scheduler_metadata=tile_scheduler_metadata, head_dim_v=self.kv_lora_rank,
num_splits=num_splits, tile_scheduler_metadata=tile_scheduler_metadata,
softmax_scale=self.scale, num_splits=num_splits,
causal=True, softmax_scale=self.scale,
descale_q=layer._q_scale.reshape(1), causal=True,
descale_k=layer._k_scale.reshape(1), )
) else:
o, lse = flash_mla_with_kvcache(
q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=tile_scheduler_metadata,
num_splits=num_splits,
softmax_scale=self.scale,
causal=True,
descale_q=layer._q_scale.reshape(1),
descale_k=layer._k_scale.reshape(1),
)
o = reshape_attn_output_for_spec_decode(o) o = reshape_attn_output_for_spec_decode(o)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment