Unverified Commit 58d1b2aa authored by Yang Chen's avatar Yang Chen Committed by GitHub
Browse files

[Attention] MLA support for V1 (#13789)


Signed-off-by: default avatarYang Chen <yangche@fb.com>
parent f1579b22
...@@ -89,6 +89,7 @@ class Attention(nn.Module): ...@@ -89,6 +89,7 @@ class Attention(nn.Module):
self._k_scale_float = 1.0 self._k_scale_float = 1.0
self._v_scale_float = 1.0 self._v_scale_float = 1.0
self.use_mla = use_mla
self.num_heads = num_heads self.num_heads = num_heads
self.head_size = head_size self.head_size = head_size
self.num_kv_heads = num_kv_heads self.num_kv_heads = num_kv_heads
...@@ -158,6 +159,10 @@ class Attention(nn.Module): ...@@ -158,6 +159,10 @@ class Attention(nn.Module):
query: torch.Tensor, query: torch.Tensor,
key: torch.Tensor, key: torch.Tensor,
value: torch.Tensor, value: torch.Tensor,
# For some alternate attention backends like MLA the attention output
# shape does not match the query shape, so we optionally let the model
# definition specify the output tensor shape.
output_shape: Optional[torch.Size] = None,
) -> torch.Tensor: ) -> torch.Tensor:
""" """
The KV cache is stored inside this class and is accessed via The KV cache is stored inside this class and is accessed via
...@@ -173,17 +178,25 @@ class Attention(nn.Module): ...@@ -173,17 +178,25 @@ class Attention(nn.Module):
if attn_metadata.enable_kv_scales_calculation: if attn_metadata.enable_kv_scales_calculation:
self.calc_kv_scales(key, value) self.calc_kv_scales(key, value)
if self.use_output: if self.use_output:
output = torch.empty_like(query) output_shape = (output_shape
hidden_size = query.size(-1) if output_shape is not None else query.shape)
# Reshape the query, key, and value tensors. output = torch.empty(output_shape,
# NOTE(woosuk): We do this outside the custom op to minimize the dtype=query.dtype,
# CPU overheads from the non-CUDA-graph regions. device=query.device)
query = query.view(-1, self.num_heads, self.head_size) hidden_size = output_shape[-1]
output = output.view(-1, self.num_heads, self.head_size) # We skip reshaping query, key and value tensors for the MLA
if key is not None: # backend since these tensors have different semantics and are
key = key.view(-1, self.num_kv_heads, self.head_size) # processed differently.
if value is not None: if not self.use_mla:
value = value.view(-1, self.num_kv_heads, self.head_size) # Reshape the query, key, and value tensors.
# NOTE(woosuk): We do this outside the custom op to minimize the
# CPU overheads from the non-CUDA-graph regions.
query = query.view(-1, self.num_heads, self.head_size)
output = output.view(-1, self.num_heads, self.head_size)
if key is not None:
key = key.view(-1, self.num_kv_heads, self.head_size)
if value is not None:
value = value.view(-1, self.num_kv_heads, self.head_size)
if self.use_direct_call: if self.use_direct_call:
forward_context: ForwardContext = get_forward_context() forward_context: ForwardContext = get_forward_context()
attn_metadata = forward_context.attn_metadata attn_metadata = forward_context.attn_metadata
......
...@@ -420,9 +420,15 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -420,9 +420,15 @@ class DeepseekV2MLAAttention(nn.Module):
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim)) mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale self.scaling = self.scaling * mscale * mscale
# In the MLA backend, kv_cache includes both k_c and
# pe (i.e. decoupled position embeddings). In particular,
# the concat_and_cache_mla op requires
# k_c.size(1) + k_pe.size(1) == kv_cache.size(2)
# i.e.
# kv_lora_rank + qk_rope_head_dim == head_size
self.mla_attn = Attention( self.mla_attn = Attention(
num_heads=self.num_local_heads, num_heads=self.num_local_heads,
head_size=self.kv_lora_rank, head_size=self.kv_lora_rank + self.qk_rope_head_dim,
scale=self.scaling, scale=self.scaling,
num_kv_heads=1, num_kv_heads=1,
cache_config=cache_config, cache_config=cache_config,
...@@ -458,7 +464,10 @@ class DeepseekV2MLAAttention(nn.Module): ...@@ -458,7 +464,10 @@ class DeepseekV2MLAAttention(nn.Module):
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split( kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1) [self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous()) kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe) return self.mla_attn(hidden_states_or_q_c,
kv_c_normed,
k_pe,
output_shape=hidden_states.shape)
class DeepseekV2DecoderLayer(nn.Module): class DeepseekV2DecoderLayer(nn.Module):
......
...@@ -162,8 +162,13 @@ class CudaPlatformBase(Platform): ...@@ -162,8 +162,13 @@ class CudaPlatformBase(Platform):
kv_cache_dtype, block_size, use_v1, kv_cache_dtype, block_size, use_v1,
use_mla) -> str: use_mla) -> str:
if use_v1: if use_v1:
logger.info("Using Flash Attention backend on V1 engine.") if use_mla:
return "vllm.v1.attention.backends.flash_attn.FlashAttentionBackend" logger.info("Using Triton MLA backend on V1 engine.")
return "vllm.v1.attention.backends.triton_mla.TritonMLABackend"
else:
logger.info("Using Flash Attention backend on V1 engine.")
return ("vllm.v1.attention.backends.flash_attn."
"FlashAttentionBackend")
if use_mla: if use_mla:
if selected_backend == _Backend.FLASHMLA: if selected_backend == _Backend.FLASHMLA:
from vllm.attention.backends.flashmla import ( from vllm.attention.backends.flashmla import (
......
...@@ -35,6 +35,7 @@ class _Backend(enum.Enum): ...@@ -35,6 +35,7 @@ class _Backend(enum.Enum):
OPENVINO = enum.auto() OPENVINO = enum.auto()
FLASHINFER = enum.auto() FLASHINFER = enum.auto()
TRITON_MLA = enum.auto() TRITON_MLA = enum.auto()
TRITON_MLA_VLLM_V1 = enum.auto()
FLASHMLA = enum.auto() FLASHMLA = enum.auto()
HPU_ATTN = enum.auto() HPU_ATTN = enum.auto()
PALLAS = enum.auto() PALLAS = enum.auto()
......
# SPDX-License-Identifier: Apache-2.0 # SPDX-License-Identifier: Apache-2.0
"""Attention layer with FlashAttention.""" """Attention layer with FlashAttention."""
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Optional, Tuple, Type from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Type
import numpy as np import numpy as np
import torch import torch
...@@ -14,6 +14,11 @@ from vllm.logger import init_logger ...@@ -14,6 +14,11 @@ from vllm.logger import init_logger
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import cdiv from vllm.utils import cdiv
if TYPE_CHECKING:
from vllm.v1.core.scheduler_output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
if current_platform.is_cuda(): if current_platform.is_cuda():
from vllm.vllm_flash_attn import flash_attn_varlen_func from vllm.vllm_flash_attn import flash_attn_varlen_func
...@@ -40,6 +45,10 @@ class FlashAttentionBackend(AttentionBackend): ...@@ -40,6 +45,10 @@ class FlashAttentionBackend(AttentionBackend):
def get_metadata_cls() -> Type["AttentionMetadata"]: def get_metadata_cls() -> Type["AttentionMetadata"]:
return FlashAttentionMetadata return FlashAttentionMetadata
@staticmethod
def get_builder_cls() -> Type["FlashAttentionMetadataBuilder"]:
return FlashAttentionMetadataBuilder
@staticmethod @staticmethod
def get_kv_cache_shape( def get_kv_cache_shape(
num_blocks: int, num_blocks: int,
...@@ -85,6 +94,62 @@ class FlashAttentionMetadata: ...@@ -85,6 +94,62 @@ class FlashAttentionMetadata:
num_input_tokens: int = 0 # Number of tokens including padding. num_input_tokens: int = 0 # Number of tokens including padding.
class FlashAttentionMetadataBuilder:
def __init__(self, runner: "GPUModelRunner"):
self.runner = runner
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput"):
pass
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
common_prefix_len: int):
max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
self.runner.device, non_blocking=True)
seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(self.runner.device,
non_blocking=True)
block_table = (
self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
self.runner.device, non_blocking=True).long()
use_cascade = common_prefix_len > 0
if use_cascade:
# TODO: Optimize.
cu_prefix_query_lens = torch.tensor([0, num_actual_tokens],
dtype=torch.int32,
device=self.runner.device)
prefix_kv_lens = torch.tensor([common_prefix_len],
dtype=torch.int32,
device=self.runner.device)
suffix_kv_lens = (self.runner.seq_lens_np[:num_reqs] -
common_prefix_len)
suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(
self.runner.device)
else:
cu_prefix_query_lens = None
prefix_kv_lens = None
suffix_kv_lens = None
attn_metadata = FlashAttentionMetadata(
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,
query_start_loc=query_start_loc,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table=block_table,
slot_mapping=slot_mapping,
use_cascade=use_cascade,
common_prefix_len=common_prefix_len,
cu_prefix_query_lens=cu_prefix_query_lens,
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
)
return attn_metadata
class FlashAttentionImpl(AttentionImpl): class FlashAttentionImpl(AttentionImpl):
def __init__( def __init__(
...@@ -371,4 +436,4 @@ def cascade_attention( ...@@ -371,4 +436,4 @@ def cascade_attention(
# Merge prefix and suffix outputs, and store the result in output. # Merge prefix and suffix outputs, and store the result in output.
merge_attn_states(output, prefix_output, prefix_lse, suffix_output, merge_attn_states(output, prefix_output, prefix_lse, suffix_output,
suffix_lse) suffix_lse)
\ No newline at end of file
# SPDX-License-Identifier: Apache-2.0
"""
This file implements common components for MLA implementations.
First we define:
Sq as Q sequence length
Skv as KV sequence length
MLA has two possible ways of computing, a data-movement friendly approach and a
compute friendly approach, we generally want to use the compute friendly
approach for "prefill" (i.e. the ratio Sq / Skv is "small", is near 1)
and the data-movement friendly approach for "decode" (i.e. the ratio
Sq / Skv is "large").
NOTE what we deem small and large is currently determined by if its labelled
prefill or decode by the scheduler, but this is something we should probably
tune.
Main reference: DeepseekV2 paper, and FlashInfer Implementation
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
Deepseek's MLA attention works the following way:
* Use a single latent vector to represent the per-token entry of the KV cache.
* For decode (i.e. the memory friendly approach) the attention "simulates" a
multi-head attention, while the compute is similar to multi-query attention.
Below is example of both paths assuming batchsize = 1
## More Extent Definitions:
C Context length, `Skv - Sq`
H hidden size
N number of attention heads
Lq latent dimension for Q 1536 in DSV3
Lkv latent dimension for K/V 512 in DSV3
P nope dimension, no rope. 128 in DSV3
R rope dimension, goes through rope. 64 in DSV3
V V head dim. 128 in DSV3
## Vector/Matrix Definitions
h_t hidden states (input to attention) shape [Sq, H]
q_c latent/compressed Q shape [Sq, Lq]
q_nope uncompressed Q (no-rope) shape [Sq, N, P]
q_pe uncompressed Q (rope) shape [Sq, N, R]
kv_c latent/compressed KV shape [Skv, Lkv]
k_pe decoupled k position embeddings shape [Skv, R]
new_kv_c new kv_c from current iter shape [Sq, Lkv]
new_k_pe new k_pe from current iter shape [Sq, R]
cache_kv_c cached k_c from previous iters shape [C, Lkv]
cache_k_pe cached k_pe from previous iters shape [C, R]
W_DQ project h_t to q_c shape [H, Lq]
W_UQ project q_c to q_nope shape [Lq, N * P]
W_QR project q_c to q_pe shape [Lq, N * R]
W_DKV project h_t to kv_c shape [H, Lkv]
W_UK project kv_c to k_nope shape [Lkv, N * P]
W_KR project h_t to k_pe shape [H, N * R]
W_UV project kv_c to v shape [Lkv, N * V]
W_O project v to h_t shape [N * V, H]
## Compute Friendly Approach (i.e. "_forward_prefill"):
q_c = h_t @ W_DQ
q_nope = (q_c @ W_UQ).view(Sq, N, P)
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0)
k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
k_nope = (kv_c @ W_UK).view(Skv, N, P)
v = (kv_c @ W_UV).view(Skv, N, V)
// MHA with QK headdim = P + R
// V headdim = V
// spda_o shape [Sq, N, V]
spda_o = scaled_dot_product_attention(
torch.cat([q_nope, q_pe], dim=-1),
torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
v
)
return spda_o @ W_O
NOTE: in the actual code,
`kv_b_proj` is [W_UK; W_UV] concatnated per head
`q_b_proj` is [W_UQ; W_QR] concatnated per head
`out_proj` is W_O
## Data-Movement Friendly Approach (i.e. "_forward_decode"):
Ahead of time, compute:
% this projects from q_c to [Sq, N * Lkv]
W_UQ_UK = einsum("qnp,knp -> qnk"
W_UQ.view(Lq, N, P), W_UK.view(Lkv, N, P)
).view(Lkv, N * Lkv)
% this projects from attn output [Sq, N * Lkv] to [Sq, H]
W_UV_O = einsum("knv,nvh -> nkh"
W_UV.view(Lkv, N, V), W_O.view(N, V, H)
).view(N * Lkv, H)
Runtime
q_c = h_t @ W_DQ
q_latent = q_c @ W_UQ_UK.view(Sq, N, Lkv)
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0)
k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
// MQA with QK headdim = Lkv + R
// V headdim = Lkv
// spda_o shape [Sq, N, Lkv]
// NOTE: this is less compute-friendly since Lkv > P
// but is more data-movement friendly since its MQA vs MHA
spda_o = scaled_dot_product_attention(
torch.cat([q_latent, q_pe], dim=-1),
torch.cat([kv_c, k_pe], dim=-1),
kv_c
)
return spda_o.reshape(-1, N * Lkv) @ W_UV_O
## Chunked Prefill
For chunked prefill we want to use the compute friendly algorithm. We are
assuming sufficiently large Sq / Skv ratio, in the future may want to switch to
the data-movement friendly approach if the chunk (i.e. `Sq`) is small.
However, the compute-friendly approach can potentially run out of memory if Skv
is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)`
To mitigate this, we chunk the computation of attention with respect to the
current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can used a
fixed workspace size.
The chunked prefill approach is as follows:
MCC Max chunk of context to process per iter, computed dynamically,
used to bound the memory usage
q_c = h_t @ W_DQ
q_nope = (q_c @ W_UQ).view(Sq, N, P)
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR)
new_k_nope = (new_kv_c @ W_UK).view(Sq, N, P)
new_v = (new_kv_c @ W_UV).view(Sq, N, V)
// MHA between queries and new KV
// with QK headdim = P + R
// V headdim = V
// curr_o shape [Sq, N, V]
// curr_lse shape [N, Sq], this is just order FA returns
curr_o, curr_lse = scaled_dot_product_attention(
torch.cat([q_nope, q_pe], dim=-1),
torch.cat([new_k_nope, new_k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
new_v,
casual=True,
return_softmax_lse=True
)
// Compute attention with the already existing context
for chunk_idx in range(cdiv(C, MCC)):
chunk_start = chunk_idx * MCC
chunk_end = min(chunk_start + MCC, C)
Sc = chunk_end - chunk_start
cache_kv_c_chunk = cache_kv_c[chunk_start:chunk_end]
cache_k_pe_chunk = cache_k_pe[chunk_start:chunk_end]
cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P)
cache_v_chunk = (cache_kv_c_chunk @ W_UV).view(-1, N, V)
chunk_o, chunk_lse = scaled_dot_product_attention(
torch.cat([q_nope, q_pe], dim=-1),
torch.cat([cache_k_nope_chunk,
cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)],
dim=-1),
cache_v_chunk,
casual=False,
return_softmax_lse=True
)
curr_o, curr_lse = merge_attn_states(
suffix_output=curr_o,
suffix_lse=curr_lse,
prefix_output=chunk_o,
prefix_lse=chunk_lse,
)
return curr_o @ W_O
"""
import functools
from abc import abstractmethod
from dataclasses import dataclass
from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple,
Type, TypeVar)
import torch
from compressed_tensors.quantization import QuantizationStrategy
from vllm import _custom_ops as ops
from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
AttentionMetadata,
MLAAttentionImpl)
from vllm.attention.backends.utils import get_flash_attn_version
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
from vllm.distributed import (get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, RowParallelLinear,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsLinearMethod)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsW8A8Fp8)
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
scaled_quantize)
from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding)
from vllm.utils import cdiv, round_down
try:
from vllm.vllm_flash_attn import flash_attn_varlen_func
except ImportError:
# For rocm use upstream flash attention
from flash_attn import flash_attn_varlen_func
if TYPE_CHECKING:
from vllm.v1.core.scheduler_output import SchedulerOutput
from vllm.v1.worker.gpu_input_batch import InputBatch
from vllm.v1.worker.gpu_model_runner import GPUModelRunner
logger = init_logger(__name__)
class MLACommonBackend(AttentionBackend):
accept_output_buffer: bool = True
@staticmethod
def get_name() -> str:
return "TRITON_MLA_VLLM_V1"
@staticmethod
def get_metadata_cls() -> Type["AttentionMetadata"]:
return MLACommonMetadata
@staticmethod
def get_builder_cls() -> Type["MLACommonMetadataBuilder"]:
return MLACommonMetadataBuilder
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int, # assumed to be 1 for MLA
head_size: int,
) -> Tuple[int, ...]:
return (num_blocks, block_size, head_size)
@staticmethod
def get_supported_head_sizes() -> List[int]:
return [576]
@staticmethod
def use_cascade_attention(*args, **kwargs) -> bool:
return False
@dataclass
class MLACommonMetadata:
"""Metadata for MLACommon.
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
# New for MLA (compared to FlashAttention)
# Input positions for rotrary embeddings since for MLA the rotary
# position embeddings are applied inside the attention backend
input_positions: torch.Tensor
# NOTE(sang): Definition of context_len, query_len, and seq_len.
# |---------- N-1 iteration --------|
# |---------------- N iteration ---------------------|
# |- tokenA -|......................|-- newTokens ---|
# |---------- context_len ----------|
# |-------------------- seq_len ---------------------|
# |-- query_len ---|
num_actual_tokens: int # Number of tokens excluding padding.
max_query_len: int
query_start_loc: torch.Tensor
max_seq_len: int
seq_lens: torch.Tensor
block_table: torch.Tensor
slot_mapping: torch.Tensor
# For logging.
num_input_tokens: int = 0 # Number of tokens including padding.
# The dimension of the attention heads
head_dim: Optional[int] = None
# New for MLA (compared to FlashAttention)
# For chunked prefill
num_decodes: Optional[int] = None
num_decode_tokens: Optional[int] = None
num_prefills: Optional[int] = None
has_context: bool = False
context_chunk_cu_seq_lens: Optional[torch.Tensor] = None
context_chunk_starts: Optional[torch.Tensor] = None
context_chunk_seq_tot: Optional[List[int]] = None
context_chunk_max_seq_lens: Optional[List[int]] = None
chunked_prefill_workspace: Optional[torch.Tensor] = None
def __post_init__(self):
supported_head_sizes = MLACommonBackend.get_supported_head_sizes()
if self.head_dim is not None and self.head_dim \
not in supported_head_sizes:
raise ValueError(
f"Only {supported_head_sizes} are supported for head_dim,",
f"received {self.head_dim}.")
T = TypeVar("T", bound=MLACommonMetadata)
class MLACommonMetadataBuilder:
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
def __init__(self, runner: "GPUModelRunner"):
self.runner = runner
scheduler_config = runner.scheduler_config
model_config = runner.model_config
cache_config = runner.cache_config
self.chunked_prefill_enabled = scheduler_config.chunked_prefill_enabled
if self.chunked_prefill_enabled:
self.chunked_prefill_workspace_size = min(
# Max sure there is enough for 8 full length request or at least
# 4 pages of cache per request
max(
8 * model_config.max_model_len, 4 *
scheduler_config.max_num_seqs * cache_config.block_size),
# For long-context models try not to over-allocate limiting
# kv-cache space, limiting it to 64k tokens,
# which would result in the workspace being:
# 2*(576)*(64*1024) = 144mb
# (assuming 576 MLA head dim, and fp16)
# which would result in up-projected context being
# 2*(192*128)*(64*1024) = 3gb
# (assuming 192 QK head dim, 128 heads, and fp16)
128 * 1024)
assert self.chunked_prefill_workspace_size >= \
scheduler_config.max_num_seqs * cache_config.block_size
self.chunked_prefill_workspace = torch.empty(
(self.chunked_prefill_workspace_size,
model_config.get_head_size()),
dtype=model_config.dtype,
device=runner.device,
)
self.page_size = self.runner.block_size
def reorder_batch(self, input_batch: "InputBatch",
scheduler_output: "SchedulerOutput"):
# We now want to reorder the batch so that the "decode" requests are and
# the front and the "prefill" requests are at the using the least amount
# swaps possible. (NOTE for now we loosely use "decode" to mean requests
# where attention is likely memory-bound and "prefill" to mean requests
# where attention is likely compute-bound, TODO(lucas): figure out a
# better naming here)
decodes = []
prefills = []
num_decode_tokens = 0
num_prefill_tokens = 0
for i, req_id in enumerate(input_batch.req_ids):
num_tokens = scheduler_output.num_scheduled_tokens[req_id]
# for now treat 1 scheduled token as "decode" even if its not,
# we should update this to something like < 8 in the future but
# currently the TritonMLA._forward_decode only supports
# num_tokens = 1
if num_tokens == 1:
decodes.append(i)
num_decode_tokens += num_tokens
else:
prefills.append(i)
num_prefill_tokens += num_tokens
# We hope that this is fairly minimal since decodes
# should be around for a number of iterations so hopefully they are
# relatively stationary (and new request are generally appended to the
# persistent batch so already should be at the back)
# To achieve this we loop over the decodes in descending order and
# the prefills in ascending order. We swap decodes from the "back"
# i.e. past where the last decode should be in the reodorered with
# prefills from the front of the batch.
# `decodes` and `prefills` are already in ascending order just based on
# the above loop
num_decodes = len(decodes)
num_prefills = len(prefills)
first_prefill = 0
for i in range(1, min(num_decodes, num_prefills) + 1):
# If the decode is at the "back" of the batch, i, we can swap it
# with the prefill closest to the front of the batch
if decodes[num_decodes - i] >= num_decodes:
input_batch.swap_states(prefills[first_prefill],
decodes[num_decodes - i])
first_prefill += 1
else:
break
# Save for next `build` call
# TODO(lucas): this is a bit of a hack, we should probably have a
# better way of doing this
self._num_decodes = num_decodes
self._num_prefills = num_prefills
self._num_decode_tokens = num_decode_tokens
self._num_prefill_tokens = num_prefill_tokens
def build(self, num_reqs: int, num_actual_tokens: int, max_query_len: int,
common_prefix_len: int):
device = self.runner.device
max_seq_len = self.runner.seq_lens_np[:num_reqs].max()
query_start_loc = self.runner.query_start_loc_cpu[:num_reqs + 1].to(
device, non_blocking=True)
seq_lens = self.runner.seq_lens_cpu[:num_reqs].to(device,
non_blocking=True)
block_table = (
self.runner.input_batch.block_table.get_device_tensor()[:num_reqs])
slot_mapping = self.runner.slot_mapping_cpu[:num_actual_tokens].to(
device, non_blocking=True).long()
input_positions = self.runner.positions_cpu[:num_actual_tokens].to(
device, non_blocking=True).long()
context_chunk_cu_seq_lens = None
context_chunk_starts = None
context_chunk_seq_tot = None
context_chunk_max_seq_lens = None
num_computed_tokens_cpu_tensor = \
self.runner.input_batch.num_computed_tokens_cpu_tensor[:num_reqs]
context_lens_tensor = \
num_computed_tokens_cpu_tensor.to(device, non_blocking=True)
if self.chunked_prefill_enabled and self._num_prefills > 0 \
and context_lens_tensor[self._num_decodes:].max() > 0:
# NOTE: it is recommend you read the `Chunked Prefill` section in
# the comment at the top of the file before trying to understand
# the following code
self.has_context = True
num_prefills_with_context = \
(context_lens_tensor[self._num_decodes:] > 0).sum().item()
# currently we allocate an equal amount of workspace for each
# prefill in the batch, we could probably use a more advanced
# algorithm here and allocate more workspace to prefills with
# longer context lengths
max_context_chunk = \
self.chunked_prefill_workspace_size // num_prefills_with_context
# align max_context_chunk to page_size by rounding down,
# currently the `gather_cache` kernel cannot handle
# `context_chunk_starts` that are not aligned to page_size
max_context_chunk = round_down(max_context_chunk, self.page_size)
assert max_context_chunk > 0
num_chunks = cdiv(context_lens_tensor.max(), max_context_chunk)
# if `max_context_chunk = 256`, `num_chunks = 3`, and
# `num_prefills_with_context = 4`, create a tensor that looks like
# [[0, 0, 0, 0], [256, 256, 256, 256], [512, 512, 512, 512]]
context_chunk_starts = \
torch.arange(num_chunks, device=device, dtype=torch.int32) \
.unsqueeze(1).expand(-1, self._num_prefills) \
* max_context_chunk
chunk_ends = torch.min(context_lens_tensor[self._num_decodes:] \
.unsqueeze(0), context_chunk_starts + max_context_chunk)
chunk_seq_lens = (chunk_ends - context_chunk_starts).clamp(min=0)
_context_chunk_cu_seq_lens = chunk_seq_lens.cumsum(dim=1).to(
torch.int32)
zero = torch.zeros(num_chunks, dtype=torch.int32, device=device) \
.unsqueeze(-1)
context_chunk_cu_seq_lens = \
torch.cat([zero, _context_chunk_cu_seq_lens], dim=1)
context_chunk_max_seq_lens = \
chunk_seq_lens.max(dim=1).values.tolist()
context_chunk_seq_tot = chunk_seq_lens.sum(dim=1).tolist()
assert max(context_chunk_seq_tot) <= \
self.chunked_prefill_workspace_size
return MLACommonMetadata(
input_positions=input_positions,
num_actual_tokens=num_actual_tokens,
max_query_len=max_query_len,
query_start_loc=query_start_loc,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table=block_table,
slot_mapping=slot_mapping,
head_dim=self.runner.model_config.get_head_size(),
# MLACommonMetadata Chunk prefill specific
num_decodes=self._num_decodes,
num_decode_tokens=self._num_decode_tokens,
num_prefills=self._num_prefills,
context_chunk_cu_seq_lens=context_chunk_cu_seq_lens,
context_chunk_starts=context_chunk_starts,
context_chunk_seq_tot=context_chunk_seq_tot,
context_chunk_max_seq_lens=context_chunk_max_seq_lens,
)
class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
"""
NOTE: Please read the comment at the top of the file before trying to
understand this class
"""
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
# MLA Specific Arguments
q_lora_rank: Optional[int],
kv_lora_rank: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
qk_head_dim: int,
v_head_dim: int,
rotary_emb: RotaryEmbedding,
# q_proj should be q_b_proj if q_lora_rank is not None, but from an
# attention backend perspective we rely on the layer to pass in the
# correct matrix
q_proj: ColumnParallelLinear,
kv_b_proj: ColumnParallelLinear,
o_proj: RowParallelLinear,
) -> None:
self.num_heads = num_heads
self.head_size = head_size
self.scale = float(scale)
self.num_kv_heads = num_kv_heads
self.kv_cache_dtype = kv_cache_dtype
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_head_dim
self.v_head_dim = v_head_dim
self.rotary_emb = rotary_emb
self.use_yarn_rope = isinstance(rotary_emb,
DeepseekScalingRotaryEmbedding)
self.q_proj = q_proj
self.kv_b_proj = kv_b_proj
self.o_proj = o_proj
self.vllm_flash_attn_version = get_flash_attn_version()
# Handle the differences between the flash_attn_varlen from flash_attn
# and the one from vllm_flash_attn. The former is used on RoCM and the
# latter has an additional parameter to control FA2 vs FA3
self.flash_attn_varlen_func = flash_attn_varlen_func
if self.vllm_flash_attn_version is not None:
self.flash_attn_varlen_func = \
functools.partial(flash_attn_varlen_func,
fa_version=self.vllm_flash_attn_version)
def _v_up_proj_and_o_proj(self, x):
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
if is_fp8(self.W_UV_O):
output_parallel = apply_fp8_linear_generic(
x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales,
self.reqaunt_input_group_shape,
self.reqaunt_weight_group_shape)
else:
output_parallel = torch.matmul(x.flatten(start_dim=1),
self.W_UV_O)
if self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
return output
else:
x = torch.einsum("bnl,lnv->bnv", x, self.W_UV)
return self.o_proj(x.reshape(-1,
self.num_heads * self.v_head_dim))[0]
def _q_proj_and_k_up_proj(self, x):
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
if is_fp8(self.W_Q_UK):
return apply_fp8_linear_generic(
x, self.W_Q_UK, self.W_Q_UK_scales,
self.reqaunt_input_group_shape,
self.reqaunt_weight_group_shape).view(
-1, self.num_heads, self.kv_lora_rank)
return torch.matmul(x, self.W_Q_UK)\
.view(-1, self.num_heads, self.kv_lora_rank)
else:
x = torch.matmul(x, self.W_Q)\
.view(-1, self.num_heads, self.qk_nope_head_dim)
return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\
.view(-1, self.num_heads, self.kv_lora_rank)
def process_weights_after_loading(self, act_dtype: torch.dtype):
# TODO(lucas) This is very gross, we need a more wide scale refactor of
# all the FP8 code with a more standard way of
# defining schemes/group-shapes, we should also potentially force
# quant_methods to support a decompress function
#
# returns input_group_shape, weight_group_shape
def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \
Tuple[Tuple[int, int], Tuple[int, int]]:
if isinstance(layer.quant_method, Fp8LinearMethod):
if layer.quant_method.block_quant:
weight_block_size = \
layer.quant_method.quant_config.weight_block_size
# per-token-group (1, X), block-quantized (X, Y)
return (1, weight_block_size[-1]), weight_block_size
else:
return (-1, -1), (-1, -1) # per-tensor, per-tensor
elif isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
and isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
# this is hacky but we always assume the for
# CompressedTensorsW8A8Fp8 the input is dynamic per-token
# we ignore if it is static-per-tensor since we are going to
# requantize after later anyways
strategy = layer.scheme.strategy
if strategy == QuantizationStrategy.TENSOR:
return (1, -1), (-1, -1) # per-token, per-tensor
elif strategy == QuantizationStrategy.CHANNEL:
return (1, -1), (-1, 1) # per-token, per-channel
else:
raise NotImplementedError(
f"QuantizationStrategy.{strategy} is not supported for "
"fp8 MLA, please run with VLLM_MLA_DISABLE=1")
else:
raise NotImplementedError(
"Can't determine scale group shapes for "
f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1"
)
def get_layer_weight(layer):
if hasattr(layer, "weight"):
return layer.weight
elif hasattr(layer, "qweight"):
return layer.qweight
else:
raise AttributeError(
f"Layer '{layer}' has neither weight nor qweight")
def get_and_maybe_dequant_weights(layer: LinearBase):
if not isinstance(layer.quant_method, UnquantizedLinearMethod):
# NOTE: This should only be used offline, since it's O(N^3)
eye = torch.eye(layer.input_size_per_partition,
dtype=act_dtype,
device=get_layer_weight(layer).device)
dequant_weights = layer.quant_method.apply(layer,
eye,
bias=None)
del eye
# standardize to (output, input)
return dequant_weights.T
return layer.weight
weight_dtype = get_layer_weight(self.kv_b_proj).dtype
assert get_layer_weight(self.o_proj).dtype == weight_dtype
assert get_layer_weight(self.q_proj).dtype == weight_dtype
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
assert kv_b_proj_weight.shape == (
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
f"{kv_b_proj_weight.shape=}, "
f"{self.kv_lora_rank=}, "
f"{self.num_heads=}, "
f"{self.qk_nope_head_dim=}, "
f"{self.v_head_dim=}")
kv_b_proj_weight = kv_b_proj_weight.view(
self.kv_lora_rank,
self.num_heads,
self.qk_nope_head_dim + self.v_head_dim,
)
W_UK, W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\
.view(-1, self.num_heads, self.qk_head_dim)
# can be W_Q or W_UQ depending q_lora_rank, the former if
# q_lora_rank is None, the latter otherwise. From the Attention backend
# perspective though we call these both W_Q and rely on the layer
# to pass in the correct matrix
W_Q = q_proj_weight[..., :self.qk_nope_head_dim]
self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\
.flatten(start_dim=1).contiguous()
# W_QR is small so for simplicity we dont bother requantizing it
self.W_QR = self.W_QR.to(act_dtype)
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION
if is_fp8(weight_dtype) and requantization_enabled:
# This assumes it wise to requantize using the same group shapes
# (i.e. strategy, per-tensor, per-channel, block etc.) that the
# weights were originally quantized
requant_input_group_shape, requant_weight_group_shape = \
get_scale_group_shapes_for_fp8(self.q_proj)
assert (requant_input_group_shape, requant_weight_group_shape)\
== get_scale_group_shapes_for_fp8(self.kv_b_proj)
assert (requant_input_group_shape, requant_weight_group_shape)\
== get_scale_group_shapes_for_fp8(self.o_proj)
self.reqaunt_input_group_shape = requant_input_group_shape
self.reqaunt_weight_group_shape = requant_weight_group_shape
#
# Perform matrix-absorption following
# https://github.com/flashinfer-ai/flashinfer/pull/551
# for decode, as a result we end up with absorbed weights for decode
# and another copy of raw weights for prefill.
#
self.W_UK, self.W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
# We absorb `W_UK` into `W_Q` resulting in either W_Q_UK or W_UQ_UK
# depending q_lora_rank, the former if q_lora_rank is None, the
# latter otherwise
# basically if q_lora_rank is none we are absorbing into q_proj
# instead of UQ
W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
.flatten(start_dim=1).contiguous()
if is_fp8(weight_dtype) and requantization_enabled:
W_Q_UK, W_Q_UK_scales = scaled_quantize(
W_Q_UK,
self.reqaunt_weight_group_shape,
quant_dtype=current_platform_fp8_dtype)
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
self.W_Q_UK = W_Q_UK.T.contiguous()
self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous()
else:
self.W_Q_UK = W_Q_UK.to(act_dtype)
W_O = get_and_maybe_dequant_weights(self.o_proj)\
.view(-1, self.num_heads, self.v_head_dim)
W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
.flatten(start_dim=0, end_dim=1).contiguous()
if is_fp8(weight_dtype) and requantization_enabled:
W_UV_O, W_UV_O_scales = scaled_quantize(
W_UV_O,
self.reqaunt_weight_group_shape,
quant_dtype=current_platform_fp8_dtype)
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
self.W_UV_O = W_UV_O.T.contiguous()
self.W_UV_O_scales = W_UV_O_scales.T.contiguous()
else:
self.W_UV_O = W_UV_O.to(act_dtype)
self.tp_size = get_tensor_model_parallel_world_size()
else:
if is_fp8(weight_dtype):
raise NotImplementedError(
"Currently fp8 requires matrix absorption")
self.W_UV = W_UV
self.W_UK = W_UK
self.W_Q = W_Q.flatten(start_dim=1)
def _compute_prefill_context(
self,
q: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
):
assert attn_metadata.num_prefills is not None
assert attn_metadata.context_chunk_seq_tot is not None
assert attn_metadata.context_chunk_cu_seq_lens is not None
assert attn_metadata.context_chunk_starts is not None
assert attn_metadata.context_chunk_max_seq_lens is not None
output = None
iters = len(attn_metadata.context_chunk_seq_tot)
assert attn_metadata.chunked_prefill_workspace is not None
workspace = attn_metadata.chunked_prefill_workspace
for i in range(iters):
toks = attn_metadata.context_chunk_seq_tot[i]
ops.gather_cache(
src_cache=kv_c_and_k_pe_cache,
dst=workspace,
block_table=attn_metadata.block_table,
cu_seq_lens=attn_metadata.context_chunk_cu_seq_lens[i],
batch_size=attn_metadata.num_prefills,
seq_starts=attn_metadata.context_chunk_starts[i],
)
kv_c_normed = workspace[:toks]\
[..., :self.kv_lora_rank].unsqueeze(1)
k_pe = workspace[:toks]\
[..., self.kv_lora_rank:].unsqueeze(1)
kv_nope = self.kv_b_proj(kv_c_normed)[0].view( \
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))),
dim=-1)
# For MLA the v head dim is smaller than qk head dim so we pad
# out v with 0s to match the qk head dim
v_padded = torch.nn.functional.pad(v,
[0, q.shape[-1] - v.shape[-1]],
value=0)
attn_output, attn_softmax_lse = self.flash_attn_varlen_func(
q=q,
k=k,
v=v_padded,
cu_seqlens_q=attn_metadata.query_start_loc,
cu_seqlens_k=attn_metadata.context_chunk_cu_seq_lens[i],
max_seqlen_q=attn_metadata.max_query_len,
max_seqlen_k=attn_metadata.context_chunk_max_seq_lens[i],
softmax_scale=self.scale,
causal=False, # Context is unmasked
return_softmax_lse=True,
)
if output is None:
output = attn_output
output_lse = attn_softmax_lse
else:
output_tmp = torch.empty_like(output)
output_lse_tmp = torch.empty_like(output_lse)
merge_attn_states(
output=output_tmp,
output_lse=output_lse_tmp,
prefix_output=output,
prefix_lse=output_lse,
suffix_output=attn_output,
suffix_lse=attn_softmax_lse,
)
output = output_tmp
output_lse = output_lse_tmp
return output, output_lse
def _forward_prefill(
self,
q: torch.Tensor,
kv_c_normed: torch.Tensor,
k_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
) -> torch.Tensor:
has_context = attn_metadata.has_context
kv_nope = self.kv_b_proj(kv_c_normed)[0].view(\
-1, self.num_heads, self.qk_nope_head_dim + self.v_head_dim)
k_nope, v = kv_nope\
.split([self.qk_nope_head_dim, self.v_head_dim], dim=-1)
k = torch.cat((k_nope, k_pe.expand((*k_nope.shape[:-1], -1))), dim=-1)
# For MLA the v head dim is smaller than qk head dim so we pad out
# v with 0s to match the qk head dim
v_padded = torch.nn.functional.pad(v, [0, q.shape[-1] - v.shape[-1]],
value=0)
output = self.flash_attn_varlen_func(
q=q,
k=k,
v=v_padded,
cu_seqlens_q=attn_metadata.query_start_loc,
cu_seqlens_k=attn_metadata.query_start_loc,
max_seqlen_q=attn_metadata.max_query_len,
max_seqlen_k=attn_metadata.max_seq_len,
softmax_scale=self.scale,
causal=True,
return_softmax_lse=has_context,
)
if has_context:
suffix_output, suffix_lse = output
context_output, context_lse = self._compute_prefill_context( \
q, kv_c_and_k_pe_cache, attn_metadata)
output = torch.empty_like(suffix_output)
merge_attn_states(
output=output,
prefix_output=context_output,
prefix_lse=context_lse,
suffix_output=suffix_output,
suffix_lse=suffix_lse,
)
# slice by `:v.shape[-1]` in order to remove v headdim padding
output = output\
.view(-1, self.num_heads, q.shape[-1])[..., :v.shape[-1]]\
.reshape(-1, self.num_heads * v.shape[-1])
return self.o_proj(output)[0]
@abstractmethod
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: T,
) -> torch.Tensor:
raise NotImplementedError
def forward(
self,
layer: AttentionLayer,
hidden_states_or_q_c: torch.Tensor, # query in unified attn
k_c_normed: torch.Tensor, # key in unified attn
k_pe: torch.Tensor, # value in unified attn
kv_cache: torch.Tensor,
attn_metadata: T,
output: Optional[torch.Tensor] = None,
) -> torch.Tensor:
assert output is not None, "Output tensor must be provided."
if attn_metadata is None:
# Profiling run.
return output
num_actual_toks = attn_metadata.num_actual_tokens
# Inputs and outputs may be padded for CUDA graphs
output_padded = output
output = output[:num_actual_toks, ...]
hidden_states_or_q_c = hidden_states_or_q_c[:num_actual_toks, ...]
k_c_normed = k_c_normed[:num_actual_toks, ...]
k_pe = k_pe[:num_actual_toks, ...]
# Restore head dim (for rotary embedding)
k_pe = k_pe.unsqueeze(1)
assert hasattr(attn_metadata, "input_positions")
assert attn_metadata.num_decodes is not None and \
attn_metadata.num_prefills is not None and \
attn_metadata.num_decode_tokens is not None
has_decode = attn_metadata.num_decodes > 0
has_prefill = attn_metadata.num_prefills > 0
num_decode_tokens = attn_metadata.num_decode_tokens
decode_hs_or_q_c = hidden_states_or_q_c[:num_decode_tokens]
decode_k_pe = k_pe[:num_decode_tokens]
decode_input_positions = \
attn_metadata.input_positions[:num_decode_tokens]
prefill_hs_or_q_c = hidden_states_or_q_c[num_decode_tokens:]
prefill_k_pe = k_pe[num_decode_tokens:]
prefill_input_positions = \
attn_metadata.input_positions[num_decode_tokens:]
prefill_k_c_normed = k_c_normed[num_decode_tokens:]
if has_decode:
decode_q_nope = self._q_proj_and_k_up_proj(decode_hs_or_q_c)
decode_q_pe = torch.matmul(decode_hs_or_q_c, self.W_QR)\
.view(-1, self.num_heads, self.qk_rope_head_dim)
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
decode_input_positions, decode_q_pe, decode_k_pe)
if has_prefill:
prefill_q = self.q_proj(prefill_hs_or_q_c)[0]\
.view(-1, self.num_heads, self.qk_head_dim)
prefill_q_pe = prefill_q[..., self.qk_nope_head_dim:]
prefill_q_pe[...], prefill_k_pe[...] = self.rotary_emb(
prefill_input_positions, prefill_q_pe, prefill_k_pe)
# write the latent and rope to kv cache
if kv_cache.numel() > 0:
ops.concat_and_cache_mla(
k_c_normed,
k_pe.squeeze(1),
kv_cache,
attn_metadata.slot_mapping.flatten(),
kv_cache_dtype=self.kv_cache_dtype,
scale=layer._k_scale,
)
if has_prefill:
output[num_decode_tokens:] = self._forward_prefill(
prefill_q, prefill_k_c_normed, prefill_k_pe, kv_cache,
attn_metadata)
if has_decode:
output[:num_decode_tokens] = self._forward_decode(
decode_q_nope, decode_q_pe, kv_cache, attn_metadata)
return output_padded
# SPDX-License-Identifier: Apache-2.0
from typing import Any, Dict, List, Optional, Type
import torch
from vllm.attention.backends.abstract import AttentionType
from vllm.attention.ops.triton_decode_attention import decode_attention_fwd
from vllm.logger import init_logger
from vllm.v1.attention.backends.mla.common import (MLACommonBackend,
MLACommonImpl,
MLACommonMetadata)
logger = init_logger(__name__)
class TritonMLABackend(MLACommonBackend):
@staticmethod
def get_name() -> str:
return "TRITON_MLA_VLLM_V1"
@staticmethod
def get_impl_cls() -> Type["TritonMLAImpl"]:
return TritonMLAImpl
class TritonMLAImpl(MLACommonImpl[MLACommonMetadata]):
def __init__(
self,
num_heads: int,
head_size: int,
scale: float,
num_kv_heads: int,
alibi_slopes: Optional[List[float]],
sliding_window: Optional[int],
kv_cache_dtype: str,
blocksparse_params: Optional[Dict[str, Any]],
logits_soft_cap: Optional[float],
attn_type: str,
# MLA Specific Arguments
**mla_args) -> None:
super().__init__(num_heads, head_size, scale, num_kv_heads,
alibi_slopes, sliding_window, kv_cache_dtype,
blocksparse_params, logits_soft_cap, attn_type,
**mla_args)
unsupported_features = [
alibi_slopes, sliding_window, blocksparse_params, logits_soft_cap
]
if any(unsupported_features):
raise NotImplementedError(
"TritonMLAImpl does not support one of the following: "
"alibi_slopes, sliding_window, blocksparse_params, "
"logits_soft_cap")
if attn_type != AttentionType.DECODER:
raise NotImplementedError("Encoder self-attention and "
"encoder/decoder cross-attention "
"are not implemented for "
"TritonMLAImpl")
def _forward_decode(
self,
q_nope: torch.Tensor,
q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: MLACommonMetadata,
) -> torch.Tensor:
assert kv_c_and_k_pe_cache.numel() > 0
if self.kv_cache_dtype.startswith("fp8"):
raise NotImplementedError("FP8 Triton MLA not yet supported")
B = q_nope.shape[0]
q = torch.cat([q_nope, q_pe], dim=-1)
o = torch.zeros(B,
self.num_heads,
self.kv_lora_rank,
dtype=q.dtype,
device=q.device)
num_kv_splits = 4 # TODO: heuristic
# TODO(lucas) Allocate ahead of time
attn_logits = torch.empty(
(
B,
self.num_heads,
num_kv_splits,
# NOTE(lucas) idk why the +1 is here but sglang has it so we
# just mirror that
self.kv_lora_rank + 1,
),
dtype=torch.float32,
device=q.device,
)
# Add a head dim of 1
kv_c_and_k_pe_cache = kv_c_and_k_pe_cache.unsqueeze(2)
kv_c_cache = kv_c_and_k_pe_cache[..., :self.kv_lora_rank]
PAGE_SIZE = kv_c_and_k_pe_cache.size(1)
# Run MQA
decode_attention_fwd(q, kv_c_and_k_pe_cache, kv_c_cache, o,
attn_metadata.block_table, attn_metadata.seq_lens,
attn_logits, num_kv_splits, self.scale, PAGE_SIZE)
return self._v_up_proj_and_o_proj(o)
...@@ -80,7 +80,14 @@ class InputBatch: ...@@ -80,7 +80,14 @@ class InputBatch:
self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32) self.num_tokens = np.zeros(max_num_reqs, dtype=np.int32)
self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32) self.num_tokens_no_spec = np.zeros(max_num_reqs, dtype=np.int32)
self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32) self.num_prompt_tokens = np.zeros(max_num_reqs, dtype=np.int32)
self.num_computed_tokens_cpu = np.empty(max_num_reqs, dtype=np.int32) self.num_computed_tokens_cpu_tensor = torch.zeros(
(max_num_reqs, ),
device="cpu",
dtype=torch.int32,
pin_memory=pin_memory,
)
self.num_computed_tokens_cpu = \
self.num_computed_tokens_cpu_tensor.numpy()
# Block table. # Block table.
self.block_table = BlockTable( self.block_table = BlockTable(
...@@ -356,6 +363,61 @@ class InputBatch: ...@@ -356,6 +363,61 @@ class InputBatch:
self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False) self.allowed_token_ids_mask_cpu_tensor[req_index].fill_(False)
return req_index return req_index
def swap_states(self, i1: int, i2: int) -> None:
old_id_i1 = self._req_ids[i1]
old_id_i2 = self._req_ids[i2]
self._req_ids[i1], self._req_ids[i2] =\
self._req_ids[i2], self._req_ids[i1] # noqa
self.req_output_token_ids[i1], self.req_output_token_ids[i2] =\
self.req_output_token_ids[i2], self.req_output_token_ids[i1]
assert old_id_i1 is not None and old_id_i2 is not None
self.req_id_to_index[old_id_i1], self.req_id_to_index[old_id_i2] =\
self.req_id_to_index[old_id_i2], self.req_id_to_index[old_id_i1]
self.num_tokens[i1], self.num_tokens[i2] =\
self.num_tokens[i2], self.num_tokens[i1]
self.token_ids_cpu[i1, ...], self.token_ids_cpu[i2, ...], =\
self.token_ids_cpu[i2, ...], self.token_ids_cpu[i1, ...]
self.num_tokens_no_spec[i1], self.num_tokens_no_spec[i2] =\
self.num_tokens_no_spec[i2], self.num_tokens_no_spec[i1]
self.num_prompt_tokens[i1], self.num_prompt_tokens[i2] =\
self.num_prompt_tokens[i2], self.num_prompt_tokens[i1]
self.num_computed_tokens_cpu[i1], self.num_computed_tokens_cpu[i2] =\
self.num_computed_tokens_cpu[i2], self.num_computed_tokens_cpu[i1]
self.temperature_cpu[i1], self.temperature_cpu[i2] =\
self.temperature_cpu[i2], self.temperature_cpu[i1]
self.top_p_cpu[i1], self.top_p_cpu[i2] =\
self.top_p_cpu[i2], self.top_p_cpu[i1]
self.top_k_cpu[i1], self.top_k_cpu[i2] =\
self.top_k_cpu[i2], self.top_k_cpu[i1]
self.frequency_penalties_cpu[i1], self.frequency_penalties_cpu[i2] =\
self.frequency_penalties_cpu[i2], self.frequency_penalties_cpu[i1]
self.presence_penalties_cpu[i1], self.presence_penalties_cpu[i2] =\
self.presence_penalties_cpu[i2], self.presence_penalties_cpu[i1]
self.repetition_penalties_cpu[i1], self.repetition_penalties_cpu[i2] =\
self.repetition_penalties_cpu[i2], self.repetition_penalties_cpu[i1]
self.min_p_cpu[i1], self.min_p_cpu[i2] =\
self.min_p_cpu[i2], self.min_p_cpu[i1]
g1 = self.generators.get(i1)
g2 = self.generators.get(i2)
if g1 is not None:
self.generators[i2] = g1
if g2 is not None:
self.generators[i1] = g2
t1 = self.min_tokens.get(i1)
t2 = self.min_tokens.get(i2)
if t1 is not None:
self.min_tokens[i2] = t1
if t2 is not None:
self.min_tokens[i1] = t2
self.request_lora_mapping[i1], self.request_lora_mapping[i2] =\
self.request_lora_mapping[i2], self.request_lora_mapping[i1]
self.logit_bias[i1], self.logit_bias[i2] =\
self.logit_bias[i2], self.logit_bias[i1]
self.block_table.swap_row(i1, i2)
def condense(self, empty_req_indices: List[int]) -> None: def condense(self, empty_req_indices: List[int]) -> None:
num_reqs = self.num_reqs num_reqs = self.num_reqs
if num_reqs == 0: if num_reqs == 0:
......
...@@ -2,6 +2,7 @@ ...@@ -2,6 +2,7 @@
import gc import gc
import time import time
import weakref
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
...@@ -9,7 +10,7 @@ import torch ...@@ -9,7 +10,7 @@ import torch
import torch.distributed import torch.distributed
import torch.nn as nn import torch.nn as nn
from vllm.attention.backends.abstract import AttentionType from vllm.attention import AttentionType, get_attn_backend
from vllm.attention.layer import Attention from vllm.attention.layer import Attention
from vllm.config import CompilationLevel, VllmConfig from vllm.config import CompilationLevel, VllmConfig
from vllm.distributed.parallel_state import get_pp_group, graph_capture from vllm.distributed.parallel_state import get_pp_group, graph_capture
...@@ -24,8 +25,7 @@ from vllm.sampling_params import SamplingType ...@@ -24,8 +25,7 @@ from vllm.sampling_params import SamplingType
from vllm.sequence import IntermediateTensors from vllm.sequence import IntermediateTensors
from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler, from vllm.utils import (STR_DTYPE_TO_TORCH_DTYPE, DeviceMemoryProfiler,
LayerBlockType, cdiv, is_pin_memory_available) LayerBlockType, cdiv, is_pin_memory_available)
from vllm.v1.attention.backends.flash_attn import (FlashAttentionBackend, from vllm.v1.attention.backends.flash_attn import FlashAttentionMetadata
FlashAttentionMetadata)
from vllm.v1.core.encoder_cache_manager import compute_encoder_budget from vllm.v1.core.encoder_cache_manager import compute_encoder_budget
from vllm.v1.engine.mm_input_cache import MMInputCacheClient from vllm.v1.engine.mm_input_cache import MMInputCacheClient
from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig, from vllm.v1.kv_cache_interface import (FullAttentionSpec, KVCacheConfig,
...@@ -92,6 +92,27 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -92,6 +92,27 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.head_size = model_config.get_head_size() self.head_size = model_config.get_head_size()
self.hidden_size = model_config.get_hidden_size() self.hidden_size = model_config.get_hidden_size()
self.attn_backend = get_attn_backend(
self.head_size,
self.dtype,
self.kv_cache_dtype,
self.block_size,
self.model_config.is_attention_free,
use_mla=self.model_config.use_mla,
)
if self.attn_backend is None:
error_msg = (
f"Error with get_att_backend: {self.head_size=}, "
f"{self.dtype=}, {self.kv_cache_dtype=}, {self.block_size=}, "
f"{self.model_config.is_attention_free=}, "
f"{self.model_config.use_mla=}")
logger.error(error_msg)
raise NotImplementedError(
"Non-Attention backend is not supported by V1 GPUModelRunner.")
self.attn_metadata_builder = self.attn_backend.get_builder_cls()(
weakref.proxy(self))
# Multi-modal data support # Multi-modal data support
self.input_registry = INPUT_REGISTRY self.input_registry = INPUT_REGISTRY
self.mm_registry = MULTIMODAL_REGISTRY self.mm_registry = MULTIMODAL_REGISTRY
...@@ -433,6 +454,12 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -433,6 +454,12 @@ class GPUModelRunner(LoRAModelRunnerMixin):
num_reqs = self.input_batch.num_reqs num_reqs = self.input_batch.num_reqs
assert num_reqs > 0 assert num_reqs > 0
# Some attention backends (namely MLA) may want to separate requests
# based on if the attention computation will be compute-bound or
# memory-bound. This gives them a hook to do that.
self.attn_metadata_builder.reorder_batch(self.input_batch,
scheduler_output)
# OPTIMIZATION: Start copying the block table first. # OPTIMIZATION: Start copying the block table first.
# This way, we can overlap the copy with the following CPU operations. # This way, we can overlap the copy with the following CPU operations.
self.input_batch.block_table.commit(num_reqs) self.input_batch.block_table.commit(num_reqs)
...@@ -515,7 +542,6 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -515,7 +542,6 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.seq_lens_np[:num_reqs] = ( self.seq_lens_np[:num_reqs] = (
self.input_batch.num_computed_tokens_cpu[:num_reqs] + self.input_batch.num_computed_tokens_cpu[:num_reqs] +
num_scheduled_tokens) num_scheduled_tokens)
max_seq_len = self.seq_lens_np[:num_reqs].max()
# Copy the tensors to the GPU. # Copy the tensors to the GPU.
self.input_ids[:total_num_scheduled_tokens].copy_( self.input_ids[:total_num_scheduled_tokens].copy_(
...@@ -530,49 +556,17 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -530,49 +556,17 @@ class GPUModelRunner(LoRAModelRunnerMixin):
self.positions[:total_num_scheduled_tokens].copy_( self.positions[:total_num_scheduled_tokens].copy_(
self.positions_cpu[:total_num_scheduled_tokens], self.positions_cpu[:total_num_scheduled_tokens],
non_blocking=True) non_blocking=True)
query_start_loc = self.query_start_loc_cpu[:num_reqs + 1].to(
self.device, non_blocking=True)
seq_lens = self.seq_lens_cpu[:num_reqs].to(self.device,
non_blocking=True)
slot_mapping = self.slot_mapping_cpu[:total_num_scheduled_tokens].to(
self.device, non_blocking=True).long()
# Prepare for cascade attention if needed. # Prepare for cascade attention if needed.
common_prefix_len = self._compute_cascade_attn_prefix_len( common_prefix_len = self._compute_cascade_attn_prefix_len(
num_scheduled_tokens, num_scheduled_tokens,
scheduler_output.num_common_prefix_blocks, scheduler_output.num_common_prefix_blocks,
) )
use_cascade = common_prefix_len > 0 attn_metadata = self.attn_metadata_builder.build(
if use_cascade: num_reqs=num_reqs,
# TODO: Optimize.
cu_prefix_query_lens = torch.tensor(
[0, total_num_scheduled_tokens],
dtype=torch.int32,
device=self.device)
prefix_kv_lens = torch.tensor([common_prefix_len],
dtype=torch.int32,
device=self.device)
suffix_kv_lens = (self.seq_lens_np[:num_reqs] - common_prefix_len)
suffix_kv_lens = torch.from_numpy(suffix_kv_lens).to(self.device)
else:
cu_prefix_query_lens = None
prefix_kv_lens = None
suffix_kv_lens = None
attn_metadata = FlashAttentionMetadata(
num_actual_tokens=total_num_scheduled_tokens, num_actual_tokens=total_num_scheduled_tokens,
max_query_len=max_num_scheduled_tokens, max_query_len=max_num_scheduled_tokens,
query_start_loc=query_start_loc,
max_seq_len=max_seq_len,
seq_lens=seq_lens,
block_table=(
self.input_batch.block_table.get_device_tensor()[:num_reqs]),
slot_mapping=slot_mapping,
use_cascade=use_cascade,
common_prefix_len=common_prefix_len, common_prefix_len=common_prefix_len,
cu_prefix_query_lens=cu_prefix_query_lens,
prefix_kv_lens=prefix_kv_lens,
suffix_kv_lens=suffix_kv_lens,
) )
use_spec_decode = len( use_spec_decode = len(
...@@ -586,7 +580,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -586,7 +580,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# from these partial requests, we do so for simplicity. # from these partial requests, we do so for simplicity.
# We will ignore the sampled tokens from the partial requests. # We will ignore the sampled tokens from the partial requests.
# TODO: Support prompt logprobs. # TODO: Support prompt logprobs.
logits_indices = query_start_loc[1:] - 1 logits_indices = attn_metadata.query_start_loc[1:] - 1
# Hot-Swap lora model # Hot-Swap lora model
if self.lora_config: if self.lora_config:
...@@ -667,7 +661,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -667,7 +661,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
# common_prefix_len should be a multiple of the block size. # common_prefix_len should be a multiple of the block size.
common_prefix_len = (common_prefix_len // self.block_size * common_prefix_len = (common_prefix_len // self.block_size *
self.block_size) self.block_size)
use_cascade = FlashAttentionBackend.use_cascade_attention( use_cascade = self.attn_backend.use_cascade_attention(
common_prefix_len=common_prefix_len, common_prefix_len=common_prefix_len,
query_lens=num_scheduled_tokens, query_lens=num_scheduled_tokens,
num_query_heads=self.num_query_heads, num_query_heads=self.num_query_heads,
...@@ -1379,7 +1373,7 @@ class GPUModelRunner(LoRAModelRunnerMixin): ...@@ -1379,7 +1373,7 @@ class GPUModelRunner(LoRAModelRunnerMixin):
assert tensor_config.size % layer_spec.page_size_bytes == 0 assert tensor_config.size % layer_spec.page_size_bytes == 0
num_blocks = tensor_config.size // layer_spec.page_size_bytes num_blocks = tensor_config.size // layer_spec.page_size_bytes
if isinstance(layer_spec, FullAttentionSpec): if isinstance(layer_spec, FullAttentionSpec):
kv_cache_shape = FlashAttentionBackend.get_kv_cache_shape( kv_cache_shape = self.attn_backend.get_kv_cache_shape(
num_blocks, layer_spec.block_size, layer_spec.num_kv_heads, num_blocks, layer_spec.block_size, layer_spec.num_kv_heads,
layer_spec.head_size) layer_spec.head_size)
dtype = layer_spec.dtype dtype = layer_spec.dtype
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment