Commit 4b00d1ba authored by zhuwenwen's avatar zhuwenwen
Browse files

Add VLLM_USE_CAT_MLA to enable the fused concat + MLA kernel

parent be22412f
...@@ -7,6 +7,7 @@ import torch ...@@ -7,6 +7,7 @@ import torch
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm import envs
logger = init_logger(__name__) logger = init_logger(__name__)
...@@ -144,6 +145,57 @@ def flash_mla_with_kvcache( ...@@ -144,6 +145,57 @@ def flash_mla_with_kvcache(
return out, softmax_lse return out, softmax_lse
def flash_mla_with_kvcache_q_nope_pe(
    q_nope: torch.Tensor,
    q_pe: torch.Tensor,
    k_cache: torch.Tensor,
    block_table: torch.Tensor,
    cache_seqlens: torch.Tensor,
    head_dim_v: int,
    tile_scheduler_metadata: torch.Tensor,
    num_splits: torch.Tensor,
    softmax_scale: Optional[float] = None,
    causal: bool = False,
    k_scale = None,
    kv_cache_dtype = "auto",
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Run the FlashMLA KV-cache kernel with the query given as two tensors.

    Instead of a single pre-concatenated query, ``q_nope`` and ``q_pe`` are
    handed to the kernel separately.

    Args:
        q_nope: First query component, forwarded to the kernel unchanged.
        q_pe: Second query component, forwarded to the kernel unchanged.
        k_cache: KV cache tensor forwarded to the kernel.
        block_table: Block table forwarded to the kernel.
        cache_seqlens: Per-sequence cache lengths forwarded to the kernel.
        head_dim_v: Value head dimension forwarded to the kernel.
        tile_scheduler_metadata: Scheduler metadata forwarded to the kernel.
        num_splits: Split information forwarded to the kernel.
        softmax_scale: Softmax scaling factor. When ``None`` it defaults to
            ``(q_nope.shape[-1] + q_pe.shape[-1]) ** -0.5``.
        causal: Causal flag forwarded to the kernel.
        k_scale: Only forwarded to the quantized kernel.
        kv_cache_dtype: KV cache dtype name. On ROCm, ``"fp8"``,
            ``"fp8_e4m3"`` and ``"fp8_e5m2"`` select the quantized kernel,
            with ``"fp8"`` treated as ``"fp8_e4m3"``.

    Returns:
        The ``(out, softmax_lse)`` pair produced by the selected kernel.
    """
    if softmax_scale is None:
        combined_head_dim = q_nope.shape[-1] + q_pe.shape[-1]
        softmax_scale = combined_head_dim ** (-0.5)

    fp8_dtype_names = ("fp8", "fp8_e4m3", "fp8_e5m2")
    if current_platform.is_rocm() and kv_cache_dtype in fp8_dtype_names:
        # The bare "fp8" alias is resolved to the e4m3 variant.
        kv_dtype = kv_cache_dtype if kv_cache_dtype != "fp8" else "fp8_e4m3"
        attn_out, lse = flash_mla_cuda.fwd_kvcache_quantization_q_nope_pe_mla(
            q_nope,
            q_pe,
            k_cache,
            None,  # presumably the separate V-cache slot, unused -- TODO confirm
            head_dim_v,
            cache_seqlens,
            block_table,
            softmax_scale,
            causal,
            tile_scheduler_metadata,
            num_splits,
            k_scale,
            kv_dtype,
        )
        return attn_out, lse

    # NOTE(review): block structure was reconstructed from a listing without
    # indentation; confirm this unquantized call is meant to be the
    # fall-through for every non-(ROCm + fp8) case.
    attn_out, lse = flash_mla_cuda.fwd_kvcache_mla_nope_pe(
        q_nope,
        q_pe,
        k_cache,
        None,  # presumably the separate V-cache slot, unused -- TODO confirm
        head_dim_v,
        cache_seqlens,
        block_table,
        softmax_scale,
        causal,
        tile_scheduler_metadata,
        num_splits,
    )
    return attn_out, lse
# #
# TODO: Add fake functions # TODO: Add fake functions
# #
......
...@@ -184,6 +184,7 @@ if TYPE_CHECKING: ...@@ -184,6 +184,7 @@ if TYPE_CHECKING:
VLLM_USE_PP_BALANCE: bool = False VLLM_USE_PP_BALANCE: bool = False
VLLM_USE_ZERO_MTP: bool = False VLLM_USE_ZERO_MTP: bool = False
VLLM_USE_CUDA_GRAPH_SIZES: bool = False VLLM_USE_CUDA_GRAPH_SIZES: bool = False
VLLM_USE_CAT_MLA: bool = False
def get_default_cache_root(): def get_default_cache_root():
return os.getenv( return os.getenv(
...@@ -1194,6 +1195,11 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -1194,6 +1195,11 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_USE_CUDA_GRAPH_SIZES": "VLLM_USE_CUDA_GRAPH_SIZES":
lambda: (os.getenv('VLLM_USE_CUDA_GRAPH_SIZES', 'True').lower() in lambda: (os.getenv('VLLM_USE_CUDA_GRAPH_SIZES', 'True').lower() in
("true", "1")), ("true", "1")),
# If set, vLLM skips the separate torch.cat of q_nope/q_pe and calls the fused concat + MLA kernel
"VLLM_USE_CAT_MLA":
lambda: (os.getenv('VLLM_USE_CAT_MLA', 'True').lower() in
("true", "1")),
} }
# --8<-- [end:env-vars-definition] # --8<-- [end:env-vars-definition]
......
...@@ -9,6 +9,7 @@ import torch ...@@ -9,6 +9,7 @@ import torch
from vllm.attention.backends.abstract import (AttentionType, from vllm.attention.backends.abstract import (AttentionType,
is_quantized_kv_cache) is_quantized_kv_cache)
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache, from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
flash_mla_with_kvcache_q_nope_pe,
get_mla_metadata, get_mla_metadata,
is_flashmla_supported) is_flashmla_supported)
from vllm.logger import init_logger from vllm.logger import init_logger
...@@ -167,6 +168,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -167,6 +168,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
assert kv_c_and_k_pe_cache.numel() > 0 assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None assert attn_metadata.decode is not None
if not envs.VLLM_USE_CAT_MLA:
if envs.VLLM_USE_OPT_CAT: if envs.VLLM_USE_OPT_CAT:
if q_nope.shape[0] < 1024: if q_nope.shape[0] < 1024:
from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode from vllm.v1.attention.backends.mla.test_concat import concat_helper_decode
...@@ -179,6 +181,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -179,6 +181,7 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
q = torch.cat([q_nope, q_pe], dim=-1)\ q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode) .unsqueeze(1) # Add seqlen dim of 1 (decode)
if not envs.VLLM_USE_CAT_MLA:
o, _ = flash_mla_with_kvcache( o, _ = flash_mla_with_kvcache(
q=q, q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1 k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
...@@ -193,5 +196,21 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]): ...@@ -193,5 +196,21 @@ class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
k_scale = k_scale, k_scale = k_scale,
kv_cache_dtype = kv_cache_dtype, kv_cache_dtype = kv_cache_dtype,
) )
else:
o, _ = flash_mla_with_kvcache_q_nope_pe(
q_nope=q_nope.unsqueeze(1),
q_pe=q_pe.unsqueeze(1),
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.decode.block_table,
cache_seqlens=attn_metadata.decode.seq_lens,
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=attn_metadata.decode.
tile_scheduler_metadata,
num_splits=attn_metadata.decode.num_splits,
softmax_scale=self.scale,
causal=True,
k_scale = k_scale,
kv_cache_dtype = kv_cache_dtype,
)
return self._v_up_proj(o) return self._v_up_proj(o)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment