Unverified Commit cabaf4ef authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[Attention] MLA decode optimizations (#12528)


Signed-off-by: default avatarLucas Wilkinson <lwilkinson@neuralmagic.com>
Signed-off-by: default avatarsimon-mo <xmo@berkeley.edu>
Co-authored-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: default avatarsimon-mo <simon.mo@hey.com>
Co-authored-by: default avatarMichael Goin <mgoin64@gmail.com>
Co-authored-by: default avatarZhuohan Li <zhuohan123@gmail.com>
Co-authored-by: default avatarTyler Michael Smith <tysmith@redhat.com>
Co-authored-by: default avatarAlexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com>
Co-authored-by: default avatarsimon-mo <xmo@berkeley.edu>
parent a1fc18c0
......@@ -28,7 +28,7 @@ from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.compilation.decorators import support_torch_compile
from vllm.config import CacheConfig, VllmConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
......@@ -326,12 +326,156 @@ class DeepseekV2Attention(nn.Module):
return output
class DeepseekV2MLAAttention(nn.Module):
"""
Main reference: DeepseekV2 paper, and FlashInfer Implementation
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
For more info see MLACommonImpl in: vllm/attention/backends/mla/utils.py
"""
def __init__(
self,
config: PretrainedConfig,
hidden_size: int,
num_heads: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: Optional[int],
kv_lora_rank: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.num_heads = num_heads
tp_size = get_tensor_model_parallel_world_size()
assert num_heads % tp_size == 0
self.num_local_heads = num_heads // tp_size
self.scaling = self.qk_head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
if self.q_lora_rank is not None:
self.q_a_proj = ReplicatedLinear(self.hidden_size,
self.q_lora_rank,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_a_proj")
self.q_a_layernorm = RMSNorm(self.q_lora_rank,
eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(q_lora_rank,
self.num_heads *
self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_b_proj")
else:
self.q_proj = ColumnParallelLinear(self.hidden_size,
self.num_heads *
self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_proj")
self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_a_proj_with_mqa")
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_b_proj")
self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj")
rope_scaling["rope_type"] = 'deepseek_yarn'
self.rotary_emb = get_rope(qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=False)
if rope_scaling:
mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
scaling_factor = rope_scaling["factor"]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale
self.mla_attn = Attention(
num_heads=self.num_local_heads,
head_size=self.kv_lora_rank,
scale=self.scaling,
num_kv_heads=1,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_mla=True,
# MLA Args
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
qk_head_dim=self.qk_head_dim,
v_head_dim=self.v_head_dim,
rotary_emb=self.rotary_emb,
q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
kv_b_proj=self.kv_b_proj,
o_proj=self.o_proj,
)
self.prefix = prefix
self.debug_layer_idx = int(self.prefix.split(".")[-2])
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
if self.q_lora_rank is not None:
ckq = self.q_a_proj(hidden_states)[0]
hidden_states_or_q_c = self.q_a_layernorm(ckq)
else:
hidden_states_or_q_c = hidden_states
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, kv_cache,
attn_metadata)
class DeepseekV2DecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
prefix: str,
model_config: ModelConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
......@@ -344,7 +488,11 @@ class DeepseekV2DecoderLayer(nn.Module):
# DecoderLayers are created with `make_layers` which passes the prefix
# with the layer's index.
layer_idx = int(prefix.split(sep='.')[-1])
self.self_attn = DeepseekV2Attention(
if model_config.use_mla:
attn_cls = DeepseekV2MLAAttention
else:
attn_cls = DeepseekV2Attention
self.self_attn = attn_cls(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
......@@ -421,6 +569,7 @@ class DeepseekV2Model(nn.Module):
super().__init__()
config = vllm_config.model_config.hf_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
......@@ -440,6 +589,7 @@ class DeepseekV2Model(nn.Module):
lambda prefix: DeepseekV2DecoderLayer(
config,
prefix,
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
),
......
......@@ -31,7 +31,8 @@ class CpuPlatform(Platform):
@classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool) -> str:
block_size: int, use_v1: bool,
use_mla: bool) -> str:
if selected_backend != _Backend.TORCH_SDPA:
logger.info("Cannot use %s backend on CPU.", selected_backend)
logger.info("Using Torch SDPA backend.")
......
......@@ -157,10 +157,14 @@ class CudaPlatformBase(Platform):
@classmethod
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
kv_cache_dtype, block_size, use_v1) -> str:
kv_cache_dtype, block_size, use_v1,
use_mla) -> str:
if use_v1:
logger.info("Using Flash Attention backend on V1 engine.")
return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend"
if use_mla:
logger.info("Using Triton MLA backend.")
return "vllm.attention.backends.triton_mla.TritonMLABackend"
if selected_backend == _Backend.FLASHINFER:
logger.info("Using FlashInfer backend.")
return "vllm.attention.backends.flashinfer.FlashInferBackend"
......@@ -171,7 +175,8 @@ class CudaPlatformBase(Platform):
pass
elif selected_backend:
raise ValueError(
f"Invalid attention backend for {cls.device_name}")
f"Invalid attention backend for {cls.device_name}, "
f"with use_v1: {use_v1} use_mla: {use_mla}")
target_backend = _Backend.FLASH_ATTN
if not cls.has_device_capability(80):
......
......@@ -27,7 +27,8 @@ class HpuPlatform(Platform):
@classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool) -> str:
block_size: int, use_v1: bool,
use_mla: bool) -> str:
logger.info("Using HPUAttention backend.")
return "vllm.attention.backends.hpu_attn.HPUAttentionBackend"
......
......@@ -30,6 +30,7 @@ class _Backend(enum.Enum):
TORCH_SDPA = enum.auto()
OPENVINO = enum.auto()
FLASHINFER = enum.auto()
TRITON_MLA = enum.auto()
HPU_ATTN = enum.auto()
PALLAS = enum.auto()
IPEX = enum.auto()
......@@ -139,7 +140,8 @@ class Platform:
@classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool) -> str:
block_size: int, use_v1: bool,
use_mla: bool) -> str:
"""Get the attention backend class of a device."""
return ""
......
......@@ -30,7 +30,8 @@ class OpenVinoPlatform(Platform):
@classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool) -> str:
block_size: int, use_v1: bool,
use_mla: bool) -> str:
if selected_backend != _Backend.OPENVINO:
logger.info("Cannot use %s backend on OpenVINO.", selected_backend)
logger.info("Using OpenVINO Attention backend.")
......
......@@ -75,7 +75,8 @@ class RocmPlatform(Platform):
@classmethod
def get_attn_backend_cls(cls, selected_backend, head_size, dtype,
kv_cache_dtype, block_size, use_v1) -> str:
kv_cache_dtype, block_size, use_v1,
use_mla) -> str:
selected_backend = (_Backend.ROCM_FLASH if selected_backend
== _Backend.FLASH_ATTN else selected_backend)
if selected_backend == _Backend.ROCM_FLASH:
......
......@@ -29,7 +29,8 @@ class TpuPlatform(Platform):
@classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool) -> str:
block_size: int, use_v1: bool,
use_mla: bool) -> str:
if selected_backend != _Backend.PALLAS:
logger.info("Cannot use %s backend on TPU.", selected_backend)
logger.info("Using Pallas backend.")
......
......@@ -27,7 +27,8 @@ class XPUPlatform(Platform):
@classmethod
def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,
dtype: torch.dtype, kv_cache_dtype: Optional[str],
block_size: int, use_v1: bool) -> str:
block_size: int, use_v1: bool,
use_mla: bool) -> str:
if selected_backend != _Backend.IPEX:
logger.info("Cannot use %s backend on XPU.", selected_backend)
logger.info("Using IPEX attention backend.")
......
......@@ -56,7 +56,8 @@ class CacheEngine:
model_config.dtype,
cache_config.cache_dtype,
self.block_size,
model_config.is_attention_free)
model_config.is_attention_free,
use_mla=model_config.use_mla)
# Initialize the cache.
self.gpu_cache = self._allocate_kv_cache(
......
......@@ -1066,6 +1066,7 @@ class GPUModelRunnerBase(ModelRunnerBase[TModelInputForGPU]):
self.kv_cache_dtype,
self.block_size,
self.model_config.is_attention_free,
use_mla=self.model_config.use_mla,
) if needs_attn_backend else None
if self.attn_backend:
self.attn_state = self.attn_backend.get_state_cls()(
......@@ -1973,7 +1974,8 @@ class CUDAGraphRunner(nn.Module):
# Copy the input tensors to the input buffers.
self.input_buffers["input_ids"].copy_(input_ids, non_blocking=True)
self.input_buffers["positions"].copy_(positions, non_blocking=True)
if positions is not None:
self.input_buffers["positions"].copy_(positions, non_blocking=True)
if self.backend_name != "NO_ATTENTION":
self.input_buffers["slot_mapping"].copy_(
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment