Unverified Commit 2e94b9cf authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[Attention] Flash MLA for V1 (#13867)


Signed-off-by: default avatarYang Chen <yangche@fb.com>
Signed-off-by: default avatarLucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
Co-authored-by: default avatarYang Chen <yangche@fb.com>
parent 8294773e
...@@ -161,15 +161,9 @@ class CudaPlatformBase(Platform): ...@@ -161,15 +161,9 @@ class CudaPlatformBase(Platform):
def get_attn_backend_cls(cls, selected_backend, head_size, dtype, def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
kv_cache_dtype, block_size, use_v1, kv_cache_dtype, block_size, use_v1,
use_mla) -> str: use_mla) -> str:
if use_v1:
if use_mla:
logger.info("Using Triton MLA backend on V1 engine.")
return "vllm.v1.attention.backends.triton_mla.TritonMLABackend"
else:
logger.info("Using Flash Attention backend on V1 engine.")
return ("vllm.v1.attention.backends.flash_attn."
"FlashAttentionBackend")
if use_mla: if use_mla:
# TODO(lucas): refactor to be more concise
# we should probably consider factoring out V1 here
if selected_backend == _Backend.FLASHMLA: if selected_backend == _Backend.FLASHMLA:
from vllm.attention.backends.flashmla import ( from vllm.attention.backends.flashmla import (
is_flashmla_supported) is_flashmla_supported)
...@@ -182,12 +176,27 @@ class CudaPlatformBase(Platform): ...@@ -182,12 +176,27 @@ class CudaPlatformBase(Platform):
"FlashMLA backend is not supported for block size %d" "FlashMLA backend is not supported for block size %d"
" (currently only supports block size 64).", " (currently only supports block size 64).",
block_size) block_size)
else:
if use_v1:
logger.info("Using FlashMLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"flashmla.FlashMLABackend")
else: else:
logger.info("Using FlashMLA backend.") logger.info("Using FlashMLA backend.")
return "vllm.attention.backends.flashmla.FlashMLABackend" return ("vllm.attention.backends."
"flashmla.FlashMLABackend")
if use_v1:
logger.info("Using Triton MLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"triton_mla.TritonMLABackend")
else:
logger.info("Using Triton MLA backend.") logger.info("Using Triton MLA backend.")
return "vllm.attention.backends.triton_mla.TritonMLABackend" return "vllm.attention.backends.triton_mla.TritonMLABackend"
if use_v1:
logger.info("Using Flash Attention backend on V1 engine.")
return ("vllm.v1.attention.backends.flash_attn."
"FlashAttentionBackend")
if selected_backend == _Backend.FLASHINFER: if selected_backend == _Backend.FLASHINFER:
logger.info("Using FlashInfer backend.") logger.info("Using FlashInfer backend.")
return "vllm.attention.backends.flashinfer.FlashInferBackend" return "vllm.attention.backends.flashinfer.FlashInferBackend"
......
...@@ -34,9 +34,8 @@ class _Backend(enum.Enum): ...@@ -34,9 +34,8 @@ class _Backend(enum.Enum):
TORCH_SDPA = enum.auto() TORCH_SDPA = enum.auto()
OPENVINO = enum.auto() OPENVINO = enum.auto()
FLASHINFER = enum.auto() FLASHINFER = enum.auto()
TRITON_MLA = enum.auto() TRITON_MLA = enum.auto() # Supported by V1
TRITON_MLA_VLLM_V1 = enum.auto() FLASHMLA = enum.auto() # Supported by V1
FLASHMLA = enum.auto()
HPU_ATTN = enum.auto() HPU_ATTN = enum.auto()
PALLAS = enum.auto() PALLAS = enum.auto()
PALLAS_VLLM_V1 = enum.auto() PALLAS_VLLM_V1 = enum.auto()
......
...@@ -333,13 +333,16 @@ class MLACommonMetadata: ...@@ -333,13 +333,16 @@ class MLACommonMetadata:
T = TypeVar("T", bound=MLACommonMetadata) T = TypeVar("T", bound=MLACommonMetadata)
class MLACommonMetadataBuilder: class MLACommonMetadataBuilder(Generic[T]):
""" """
NOTE: Please read the comment at the top of the file before trying to NOTE: Please read the comment at the top of the file before trying to
understand this class understand this class
""" """
def __init__(self, runner: "GPUModelRunner"): def __init__(self,
runner: "GPUModelRunner",
cls: Optional[type[T]] = None):
self.cls = cls if cls is not None else MLACommonMetadata
self.runner = runner self.runner = runner
scheduler_config = runner.scheduler_config scheduler_config = runner.scheduler_config
model_config = runner.model_config model_config = runner.model_config
...@@ -431,7 +434,7 @@ class MLACommonMetadataBuilder: ...@@ -431,7 +434,7 @@ class MLACommonMetadataBuilder:
self._num_prefill_tokens = num_prefill_tokens self._num_prefill_tokens = num_prefill_tokens
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int, def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
common_prefix_len: int): common_prefix_len: int) -> T:
device = self.runner.device device = self.runner.device
max_seq_len = self.runner.seq_lens_np[:num_reqs].max() max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to( query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
...@@ -502,7 +505,7 @@ class MLACommonMetadataBuilder: ...@@ -502,7 +505,7 @@ class MLACommonMetadataBuilder:
assert max(context_chunk_seq_tot) <= \ assert max(context_chunk_seq_tot) <= \
self.chunked_prefill_workspace_size self.chunked_prefill_workspace_size
return MLACommonMetadata( return self.cls(
input_positions=input_positions, input_positions=input_positions,
num_actual_tokens=num_actual_tokens, num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len, max_query_len=max_query_len,
......
# SPDX-License-Identifier: Apache-2.0
from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type
import torch
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.ops.flashmla import (flash_mla_with_kvcache,
get_mla_metadata,
is_flashmla_supported)
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl,
MLACommonMetadata,
MLACommonMetadataBuilder)
logger = init_logger(__name__)
class FlashMLABackend(MLACommonBackend):
@staticmethod
def get_name() -> str:
return "FLASHMLA_VLLM_V1"
@staticmethod
def get_metadata_cls() -> Type["FlashMLAMetadata"]:
return FlashMLAMetadata
@staticmethod
def get_builder_cls() -> Type["FlashMLAMetadataBuilder"]:
return FlashMLAMetadataBuilder
@staticmethod
def get_impl_cls() -> Type["FlashMLAImpl"]:
return FlashMLAImpl
@dataclass
class FlashMLAMetadata(MLACommonMetadata):
decode_tile_scheduler_metadata: Optional[Tuple[torch.Tensor,
torch.Tensor]] = None
decode_num_splits: Optional[torch.Tensor] = None
class FlashMLAMetadataBuilder(MLACommonMetadataBuilder[FlashMLAMetadata]):
def __init__(self, runner):
super().__init__(runner, cls=FlashMLAMetadata)
self.num_q_heads = self.runner.model_config.get_num_attention_heads(
self.runner.parallel_config)
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
common_prefix_len: int):
m = super().build(num_reqs, num_actual_tokens, max_query_len,
common_prefix_len)
if m.num_decode_tokens is not None and m.num_decode_tokens > 0:
m.decode_tile_scheduler_metadata, m.decode_num_splits = \
get_mla_metadata(
m.seq_lens[:m.num_decode_tokens],
self.num_q_heads,
1, # MQA for the decode path
)
return m
class FlashMLAImpl(MLACommonImpl[FlashMLAMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
# MLA Specific Arguments
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type,
**mla_args)
assert is_flashmla_supported(), \
"FlashMLA is not supported on this device"
unsupported_features = [
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
]
if any(unsupported_features):
raise NotImplementedError(
"FlashMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"FlashMLAImpl")
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: FlashMLAMetadata,
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 FlashMLA not yet supported")
q = torch.cat([q_nope, q_pe], dim=-1)\
.unsqueeze(1) # Add seqlen dim of 1 (decode)
o, _ = flash_mla_with_kvcache(
q=q,
k_cache=kv_c_and_k_pe_cache.unsqueeze(-2), # Add head dim of 1
block_table=attn_metadata.block_table[:attn_metadata.num_decodes,
...],
cache_seqlens=attn_metadata.seq_lens[:attn_metadata.
num_decode_tokens],
head_dim_v=self.kv_lora_rank,
tile_scheduler_metadata=attn_metadata.
decode_tile_scheduler_metadata,
num_splits=attn_metadata.decode_num_splits,
softmax_scale=self.scale,
causal=True,
)
return self._v_up_proj_and_o_proj(o)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment