Unverified Commit 41aa5784 authored by Kaixi Hou's avatar Kaixi Hou Committed by GitHub
Browse files

[NVIDIA] Add Cutlass MLA backend (#17625)

parent 8d646c2e
...@@ -119,7 +119,7 @@ typename T::Fmha::Arguments args_from_options( ...@@ -119,7 +119,7 @@ typename T::Fmha::Arguments args_from_options(
{static_cast<ElementOut*>(out.data_ptr()), stride_O, {static_cast<ElementOut*>(out.data_ptr()), stride_O,
static_cast<ElementAcc*>(nullptr), stride_LSE}, static_cast<ElementAcc*>(nullptr), stride_LSE},
hw_info, hw_info,
-1, // split_kv 1, // split_kv
nullptr, // is_var_split_kv nullptr, // is_var_split_kv
}; };
// TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute // TODO(kaixih@nvidia): When split_kv=-1 and is_var_split_kv=false, we compute
......
...@@ -76,7 +76,9 @@ def test_cutlass_mla_decode(dtype: torch.dtype, mean_seq_len: int, bs: int, ...@@ -76,7 +76,9 @@ def test_cutlass_mla_decode(dtype: torch.dtype, mean_seq_len: int, bs: int,
pack_factor = 128 // block_size pack_factor = 128 // block_size
block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor block_num = ((block_num + pack_factor - 1) // pack_factor) * pack_factor
q = torch.randn(bs, h_q, d) # Amplify input values to ensure test coverage of edge cases where CUTLASS
# kernel errors occur with split_k settings.
q = torch.randn(bs, h_q, d) * 100
block_table = torch.randint(0, block_table = torch.randint(0,
bs * block_num, (bs, block_num), bs * block_num, (bs, block_num),
dtype=torch.int32) dtype=torch.int32)
......
...@@ -1395,6 +1395,7 @@ class EngineArgs: ...@@ -1395,6 +1395,7 @@ class EngineArgs:
"PALLAS_VLLM_V1", "PALLAS_VLLM_V1",
"TRITON_ATTN_VLLM_V1", "TRITON_ATTN_VLLM_V1",
"TRITON_MLA", "TRITON_MLA",
"CUTLASS_MLA_VLLM_V1",
"FLASHMLA", "FLASHMLA",
"FLASHINFER", "FLASHINFER",
"FLASHINFER_VLLM_V1", "FLASHINFER_VLLM_V1",
......
...@@ -183,6 +183,14 @@ class CudaPlatformBase(Platform): ...@@ -183,6 +183,14 @@ class CudaPlatformBase(Platform):
if use_mla: if use_mla:
# TODO(lucas): refactor to be more concise # TODO(lucas): refactor to be more concise
# we should probably consider factoring out V1 here # we should probably consider factoring out V1 here
if selected_backend == _Backend.CUTLASS_MLA_VLLM_V1:
if use_v1:
logger.info_once("Using Cutlass MLA backend on V1 engine.")
return ("vllm.v1.attention.backends.mla."
"cutlass_mla.CutlassMLABackend")
else:
logger.warning(
"Cutlass MLA backend is only supported on V1 engine")
if selected_backend == _Backend.TRITON_MLA or block_size != 64: if selected_backend == _Backend.TRITON_MLA or block_size != 64:
if use_v1: if use_v1:
logger.info_once("Using Triton MLA backend on V1 engine.") logger.info_once("Using Triton MLA backend on V1 engine.")
......
...@@ -51,6 +51,7 @@ class _Backend(enum.Enum): ...@@ -51,6 +51,7 @@ class _Backend(enum.Enum):
TRITON_MLA_VLLM_V1 = enum.auto() TRITON_MLA_VLLM_V1 = enum.auto()
FLASHMLA_VLLM_V1 = enum.auto() FLASHMLA_VLLM_V1 = enum.auto()
FLASHMLA = enum.auto() # Supported by V1 FLASHMLA = enum.auto() # Supported by V1
CUTLASS_MLA_VLLM_V1 = enum.auto()
HPU_ATTN = enum.auto() HPU_ATTN = enum.auto()
PALLAS = enum.auto() PALLAS = enum.auto()
PALLAS_VLLM_V1 = enum.auto() PALLAS_VLLM_V1 = enum.auto()
......
...@@ -350,7 +350,7 @@ class MLACommonMetadataBuilder(Generic[M]): ...@@ -350,7 +350,7 @@ class MLACommonMetadataBuilder(Generic[M]):
self.num_heads = model_config.get_num_attention_heads( self.num_heads = model_config.get_num_attention_heads(
runner.parallel_config) runner.parallel_config)
self.mla_dims = get_mla_dims(model_config) self.mla_dims = get_mla_dims(model_config)
self.aot_schedule = is_vllm_fa and (get_flash_attn_version() == 3) self.aot_schedule = current_platform.is_cuda()
self.kv_cache_spec = kv_cache_spec self.kv_cache_spec = kv_cache_spec
# Dont try to access the runner on AMD # Dont try to access the runner on AMD
......
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Optional
import torch
import vllm._custom_ops as ops
from vllm.attention.backends.abstract import (AttentionType,
is_quantized_kv_cache)
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl,
MLACommonMetadata)
logger = init_logger(__name__)
class CutlassMLABackend(MLACommonBackend):
@staticmethod
def get_name() -> str:
return "CUTLASS_MLA_VLLM_V1"
@staticmethod
def get_impl_cls() -> type["CutlassMLAImpl"]:
return CutlassMLAImpl
class CutlassMLAImpl(MLACommonImpl[MLACommonMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[list[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
# MLA Specific Arguments
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type,
**mla_args)
unsupported_features = [
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
]
if any(unsupported_features):
raise NotImplementedError(
"CutlassMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"CutlassMLAImpl")
if is_quantized_kv_cache(self.kv_cache_dtype):
raise NotImplementedError(
"CutlassMLA V1 with FP8 KV cache not yet supported")
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
assert attn_metadata.decode is not None
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 Cutlass MLA not yet supported")
B = q_nope.shape[0]
o = torch.empty((B, self.num_heads, self.kv_lora_rank),
dtype=q_nope.dtype,
device=q_nope.device)
# Run MLA
# Clone q_nope and q_pe to make sure strides computation is correct.
q_nope = q_nope.clone()
q_pe = q_pe.clone()
ops.cutlass_mla_decode(o, q_nope, q_pe, kv_c_and_k_pe_cache,
attn_metadata.decode.seq_lens,
attn_metadata.decode.block_table, self.scale)
return self._v_up_proj(o)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment