Unverified Commit 9532c498 authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[Attention] MLA get rid of materialization (#14770)


Signed-off-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
parent 0c2af17c
...@@ -7,22 +7,22 @@ First we define: ...@@ -7,22 +7,22 @@ First we define:
Sq as Q sequence length Sq as Q sequence length
Skv as KV sequence length Skv as KV sequence length
MLA has two possible ways of computing, a data-movement friendly approach and a MLA has two possible ways of computing, a data-movement friendly approach and a
compute friendly approach, we generally want to use the compute friendly compute friendly approach, we generally want to use the compute friendly
approach for "prefill" (i.e. the ratio Sq / Skv is "small", is near 1) approach for "prefill" (i.e. the ratio Sq / Skv is "small", is near 1)
and the data-movement friendly approach for "decode" (i.e. the ratio and the data-movement friendly approach for "decode" (i.e. the ratio
Sq / Skv is "large"). Sq / Skv is "large").
NOTE what we deem small and large is currently determined by if its labelled NOTE what we deem small and large is currently determined by if its labelled
prefill or decode by the scheduler, but this is something we should probably prefill or decode by the scheduler, but this is something we should probably
tune. tune.
Main reference: DeepseekV2 paper, and FlashInfer Implementation Main reference: DeepseekV2 paper, and FlashInfer Implementation
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
Deepseek's MLA attention works the following way: Deepseek's MLA attention works the following way:
* Use a single latent vector to represent the per-token entry of the KV cache. * Use a single latent vector to represent the per-token entry of the KV cache.
* For decode (i.e. the memory friendly approach) the attention "simulates" a * For decode (i.e. the memory friendly approach) the attention "simulates" a
multi-head attention, while the compute is similar to multi-query attention. multi-head attention, while the compute is similar to multi-query attention.
Below is example of both paths assuming batchsize = 1 Below is example of both paths assuming batchsize = 1
...@@ -54,9 +54,9 @@ W_DQ project h_t to q_c shape [H, Lq] ...@@ -54,9 +54,9 @@ W_DQ project h_t to q_c shape [H, Lq]
W_UQ project q_c to q_nope shape [Lq, N * P] W_UQ project q_c to q_nope shape [Lq, N * P]
W_QR project q_c to q_pe shape [Lq, N * R] W_QR project q_c to q_pe shape [Lq, N * R]
W_DKV project h_t to kv_c shape [H, Lkv] W_DKV project h_t to kv_c shape [H, Lkv]
W_UK project kv_c to k_nope shape [Lkv, N * P] W_UK project kv_c to k_nope shape [Lkv, N, P]
W_KR project h_t to k_pe shape [H, N * R] W_KR project h_t to k_pe shape [H, R]
W_UV project kv_c to v shape [Lkv, N * V] W_UV project kv_c to v shape [Lkv, N, V]
W_O project v to h_t shape [N * V, H] W_O project v to h_t shape [N * V, H]
...@@ -69,8 +69,8 @@ new_kv_c = h_t @ W_DKV ...@@ -69,8 +69,8 @@ new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR) new_k_pe = RoPE(h_t @ W_KR)
kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0) kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0)
k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0) k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
k_nope = (kv_c @ W_UK).view(Skv, N, P) k_nope = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P)
v = (kv_c @ W_UV).view(Skv, N, V) v = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V)
// MHA with QK headdim = P + R // MHA with QK headdim = P + R
// V headdim = V // V headdim = V
...@@ -90,20 +90,10 @@ NOTE: in the actual code, ...@@ -90,20 +90,10 @@ NOTE: in the actual code,
## Data-Movement Friendly Approach (i.e. "_forward_decode"): ## Data-Movement Friendly Approach (i.e. "_forward_decode"):
Ahead of time, compute:
% this projects from q_c to [Sq, N * Lkv]
W_UQ_UK = einsum("qnp,knp -> qnk"
W_UQ.view(Lq, N, P), W_UK.view(Lkv, N, P)
).view(Lkv, N * Lkv)
% this projects from attn output [Sq, N * Lkv] to [Sq, H]
W_UV_O = einsum("knv,nvh -> nkh"
W_UV.view(Lkv, N, V), W_O.view(N, V, H)
).view(N * Lkv, H)
Runtime Runtime
q_c = h_t @ W_DQ q_c = h_t @ W_DQ
q_latent = q_c @ W_UQ_UK.view(Sq, N, Lkv) q_nope = (q_c @ W_UQ).view(-1, N, P)
ql_nope = einsum("snh,lnh->snl", q, W_UK)
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR) new_k_pe = RoPE(h_t @ W_KR)
...@@ -116,11 +106,13 @@ k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0) ...@@ -116,11 +106,13 @@ k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
// NOTE: this is less compute-friendly since Lkv > P // NOTE: this is less compute-friendly since Lkv > P
// but is more data-movement friendly since its MQA vs MHA // but is more data-movement friendly since its MQA vs MHA
spda_o = scaled_dot_product_attention( spda_o = scaled_dot_product_attention(
torch.cat([q_latent, q_pe], dim=-1), torch.cat([ql_nope, q_pe], dim=-1),
torch.cat([kv_c, k_pe], dim=-1), torch.cat([kv_c, k_pe], dim=-1),
kv_c kv_c
) )
return spda_o.reshape(-1, N * Lkv) @ W_UV_O
o = einsum("snl,lnv->snv", spda_o.reshape(-1, N, Lkv), W_UV)
return o.view(-1, N * V) @ self.num_heads @ W_O
## Chunked Prefill ## Chunked Prefill
...@@ -146,8 +138,8 @@ q_nope = (q_c @ W_UQ).view(Sq, N, P) ...@@ -146,8 +138,8 @@ q_nope = (q_c @ W_UQ).view(Sq, N, P)
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR) new_k_pe = RoPE(h_t @ W_KR)
new_k_nope = (new_kv_c @ W_UK).view(Sq, N, P) new_k_nope = (new_kv_c @ W_UK.view(Lkv, N * P)).view(Sq, N, P)
new_v = (new_kv_c @ W_UV).view(Sq, N, V) new_v = (new_kv_c @ W_UV.view(Lkv, N * V)).view(Sq, N, V)
// MHA between queries and new KV // MHA between queries and new KV
// with QK headdim = P + R // with QK headdim = P + R
...@@ -171,17 +163,17 @@ for chunk_idx in range(cdiv(C, MCC)): ...@@ -171,17 +163,17 @@ for chunk_idx in range(cdiv(C, MCC)):
cache_k_pe_chunk = cache_k_pe[chunk_start:chunk_end] cache_k_pe_chunk = cache_k_pe[chunk_start:chunk_end]
cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P) cache_k_nope_chunk = (cache_kv_c_chunk @ W_UK).view(-1, N, P)
cache_v_chunk = (cache_kv_c_chunk @ W_UV).view(-1, N, V) cache_v_chunk = (cache_kv_c_chunk @ W_UV).view(-1, N, V)
chunk_o, chunk_lse = scaled_dot_product_attention( chunk_o, chunk_lse = scaled_dot_product_attention(
torch.cat([q_nope, q_pe], dim=-1), torch.cat([q_nope, q_pe], dim=-1),
torch.cat([cache_k_nope_chunk, torch.cat([cache_k_nope_chunk,
cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)], cache_k_pe_chunk.unsqueeze(1).expand(-1, N, -1)],
dim=-1), dim=-1),
cache_v_chunk, cache_v_chunk,
casual=False, casual=False,
return_softmax_lse=True return_softmax_lse=True
) )
curr_o, curr_lse = merge_attn_states( curr_o, curr_lse = merge_attn_states(
suffix_output=curr_o, suffix_output=curr_o,
suffix_lse=curr_lse, suffix_lse=curr_lse,
...@@ -202,7 +194,6 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple, ...@@ -202,7 +194,6 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple,
Type, TypeVar) Type, TypeVar)
import torch import torch
from compressed_tensors.quantization import QuantizationStrategy
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm import envs from vllm import envs
...@@ -215,20 +206,9 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, ...@@ -215,20 +206,9 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
get_flash_attn_version, get_flash_attn_version,
is_block_tables_empty) is_block_tables_empty)
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
from vllm.distributed import (get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, RowParallelLinear, LinearBase, RowParallelLinear,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsLinearMethod)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsW8A8Fp8)
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
Fp8LinearGenericOp, is_fp8)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
scaled_quantize)
from vllm.model_executor.layers.rotary_embedding import ( from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding) DeepseekScalingRotaryEmbedding, RotaryEmbedding)
from vllm.multimodal import MultiModalPlaceholderMap from vllm.multimodal import MultiModalPlaceholderMap
...@@ -1057,7 +1037,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1057,7 +1037,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
self.kv_b_proj = kv_b_proj self.kv_b_proj = kv_b_proj
self.o_proj = o_proj self.o_proj = o_proj
self.triton_fa_func = triton_attention self.triton_fa_func = triton_attention
self.fp8_linear_generic = Fp8LinearGenericOp()
# Handle the differences between the flash_attn_varlen from flash_attn # Handle the differences between the flash_attn_varlen from flash_attn
# and the one from vllm_flash_attn. The former is used on RoCM and the # and the one from vllm_flash_attn. The former is used on RoCM and the
...@@ -1070,79 +1049,28 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1070,79 +1049,28 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
fa_version=self.vllm_flash_attn_version) fa_version=self.vllm_flash_attn_version)
def _v_up_proj_and_o_proj(self, x): def _v_up_proj_and_o_proj(self, x):
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: # Convert from (B, N, L) to (N, B, L)
if is_fp8(self.W_UV_O): x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
output_parallel = self.fp8_linear_generic.apply( # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales, x = torch.bmm(x, self.W_UV)
self.reqaunt_input_group_shape, # Convert from (N, B, V) to (B, N * V)
self.reqaunt_weight_group_shape) x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
else: return self.o_proj(x)[0]
output_parallel = torch.matmul(x.flatten(start_dim=1),
self.W_UV_O) # Return `ql_nope`, `q_pe`
if self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
return output
else:
x = torch.einsum("bnl,lnv->bnv", x, self.W_UV)
return self.o_proj(x.reshape(-1,
self.num_heads * self.v_head_dim))[0]
def _q_proj_and_k_up_proj(self, x): def _q_proj_and_k_up_proj(self, x):
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: q_nope, q_pe = self.q_proj(x)[0]\
if is_fp8(self.W_Q_UK): .view(-1, self.num_heads, self.qk_head_dim)\
return self.fp8_linear_generic.apply( .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
x, self.W_Q_UK, self.W_Q_UK_scales,
self.reqaunt_input_group_shape,
self.reqaunt_weight_group_shape).view(
-1, self.num_heads, self.kv_lora_rank)
return torch.matmul(x, self.W_Q_UK)\
.view(-1, self.num_heads, self.kv_lora_rank)
else:
x = torch.matmul(x, self.W_Q)\
.view(-1, self.num_heads, self.qk_nope_head_dim)
return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\
.view(-1, self.num_heads, self.kv_lora_rank)
def process_weights_after_loading(self, act_dtype: torch.dtype): # Convert from (B, N, P) to (N, B, P)
q_nope = q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
ql_nope = torch.bmm(q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
return ql_nope.transpose(0, 1), q_pe
# TODO(lucas) This is very gross, we need a more wide scale refactor of def process_weights_after_loading(self, act_dtype: torch.dtype):
# all the FP8 code with a more standard way of
# defining schemes/group-shapes, we should also potentially force
# quant_methods to support a decompress function
#
# returns input_group_shape, weight_group_shape
def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \
Tuple[Tuple[int, int], Tuple[int, int]]:
if isinstance(layer.quant_method, Fp8LinearMethod):
if layer.quant_method.block_quant:
weight_block_size = \
layer.quant_method.quant_config.weight_block_size
# per-token-group (1, X), block-quantized (X, Y)
return (1, weight_block_size[-1]), weight_block_size
else:
return (-1, -1), (-1, -1) # per-tensor, per-tensor
elif isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
and isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
# this is hacky but we always assume the for
# CompressedTensorsW8A8Fp8 the input is dynamic per-token
# we ignore if it is static-per-tensor since we are going to
# requantize after later anyways
strategy = layer.scheme.strategy
if strategy == QuantizationStrategy.TENSOR:
return (1, -1), (-1, -1) # per-token, per-tensor
elif strategy == QuantizationStrategy.CHANNEL:
return (1, -1), (-1, 1) # per-token, per-channel
else:
raise NotImplementedError(
f"QuantizationStrategy.{strategy} is not supported for "
"fp8 MLA, please run with VLLM_MLA_DISABLE=1")
else:
raise NotImplementedError(
"Can't determine scale group shapes for "
f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1"
)
def get_layer_weight(layer): def get_layer_weight(layer):
WEIGHT_NAMES = ("weight", "qweight", "weight_packed") WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
...@@ -1167,10 +1095,9 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1167,10 +1095,9 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
return dequant_weights.T return dequant_weights.T
return layer.weight return layer.weight
weight_dtype = get_layer_weight(self.kv_b_proj).dtype # we currently do not have quantized bmm's which are needed for
assert get_layer_weight(self.o_proj).dtype == weight_dtype # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform
assert get_layer_weight(self.q_proj).dtype == weight_dtype # the bmm's in 16-bit, the extra memory overhead of this is fairly low
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
assert kv_b_proj_weight.shape == ( assert kv_b_proj_weight.shape == (
self.kv_lora_rank, self.kv_lora_rank,
...@@ -1189,89 +1116,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1189,89 +1116,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
W_UK, W_UV = kv_b_proj_weight.split( W_UK, W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1) [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\ # Convert from (L, N, V) to (N, L, V)
.view(-1, self.num_heads, self.qk_head_dim) self.W_UV = W_UV.transpose(0, 1)
# Convert from (L, N, P) to (N, P, L)
# can be W_Q or W_UQ depending q_lora_rank, the former if self.W_UK_T = W_UK.permute(1, 2, 0)
# q_lora_rank is None, the latter otherwise. From the Attention backend
# perspective though we call these both W_Q and rely on the layer
# to pass in the correct matrix
W_Q = q_proj_weight[..., :self.qk_nope_head_dim]
self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\
.flatten(start_dim=1).contiguous()
# W_QR is small so for simplicity we dont bother requantizing it
self.W_QR = self.W_QR.to(act_dtype)
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION
if is_fp8(weight_dtype) and requantization_enabled:
# This assumes it wise to requantize using the same group shapes
# (i.e. strategy, per-tensor, per-channel, block etc.) that the
# weights were originally quantized
requant_input_group_shape, requant_weight_group_shape = \
get_scale_group_shapes_for_fp8(self.q_proj)
assert (requant_input_group_shape, requant_weight_group_shape)\
== get_scale_group_shapes_for_fp8(self.kv_b_proj)
assert (requant_input_group_shape, requant_weight_group_shape)\
== get_scale_group_shapes_for_fp8(self.o_proj)
self.reqaunt_input_group_shape = requant_input_group_shape
self.reqaunt_weight_group_shape = requant_weight_group_shape
#
# Perform matrix-absorption following
# https://github.com/flashinfer-ai/flashinfer/pull/551
# for decode, as a result we end up with absorbed weights for decode
# and another copy of raw weights for prefill.
#
self.W_UK, self.W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
# We absorb `W_UK` into `W_Q` resulting in either W_Q_UK or W_UQ_UK
# depending q_lora_rank, the former if q_lora_rank is None, the
# latter otherwise
# basically if q_lora_rank is none we are absorbing into q_proj
# instead of UQ
W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
.flatten(start_dim=1).contiguous()
if is_fp8(weight_dtype) and requantization_enabled:
W_Q_UK, W_Q_UK_scales = scaled_quantize(
W_Q_UK,
self.reqaunt_weight_group_shape,
quant_dtype=current_platform.fp8_dtype())
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
self.W_Q_UK = W_Q_UK.T.contiguous()
self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous()
else:
self.W_Q_UK = W_Q_UK.to(act_dtype)
W_O = get_and_maybe_dequant_weights(self.o_proj)\
.view(-1, self.num_heads, self.v_head_dim)
W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
.flatten(start_dim=0, end_dim=1).contiguous()
if is_fp8(weight_dtype) and requantization_enabled:
W_UV_O, W_UV_O_scales = scaled_quantize(
W_UV_O,
self.reqaunt_weight_group_shape,
quant_dtype=current_platform.fp8_dtype())
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
self.W_UV_O = W_UV_O.T.contiguous()
self.W_UV_O_scales = W_UV_O_scales.T.contiguous()
else:
self.W_UV_O = W_UV_O.to(act_dtype)
self.tp_size = get_tensor_model_parallel_world_size()
else:
if is_fp8(weight_dtype):
raise NotImplementedError(
"Currently fp8 requires matrix absorption")
self.W_UV = W_UV
self.W_UK = W_UK
self.W_Q = W_Q.flatten(start_dim=1)
def _compute_prefill_context( def _compute_prefill_context(
self, self,
...@@ -1471,7 +1319,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1471,7 +1319,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
@abstractmethod @abstractmethod
def _forward_decode( def _forward_decode(
self, self,
q_nope: torch.Tensor, ql_nope: torch.Tensor,
q_pe: torch.Tensor, q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: T, attn_metadata: T,
...@@ -1525,9 +1373,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1525,9 +1373,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
prefill_k_c_normed = k_c_normed[:num_prefill_tokens] prefill_k_c_normed = k_c_normed[:num_prefill_tokens]
if has_decode: if has_decode:
decode_q_nope = self._q_proj_and_k_up_proj(decode_hs_or_q_c) decode_ql_nope, decode_q_pe = \
decode_q_pe = torch.matmul(decode_hs_or_q_c, self.W_QR)\ self._q_proj_and_k_up_proj(decode_hs_or_q_c)
.view(-1, self.num_heads, self.qk_rope_head_dim)
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
decode_input_positions, decode_q_pe, decode_k_pe) decode_input_positions, decode_q_pe, decode_k_pe)
...@@ -1561,6 +1408,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1561,6 +1408,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
if has_decode: if has_decode:
output[num_prefill_tokens:] = self._forward_decode( output[num_prefill_tokens:] = self._forward_decode(
decode_q_nope, decode_q_pe, kv_cache, attn_metadata) decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
return output return output
...@@ -84,8 +84,6 @@ if TYPE_CHECKING: ...@@ -84,8 +84,6 @@ if TYPE_CHECKING:
VLLM_SERVER_DEV_MODE: bool = False VLLM_SERVER_DEV_MODE: bool = False
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128 VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
VLLM_MLA_DISABLE: bool = False VLLM_MLA_DISABLE: bool = False
VLLM_MLA_PERFORM_MATRIX_ABSORPTION: bool = True
VLLM_MLA_DISABLE_REQUANTIZATION: bool = False
VLLM_MLA_CUDA_MEM_ALIGN_KV_CACHE: bool = True VLLM_MLA_CUDA_MEM_ALIGN_KV_CACHE: bool = True
VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False
VLLM_RAY_PER_WORKER_GPUS: float = 1.0 VLLM_RAY_PER_WORKER_GPUS: float = 1.0
...@@ -563,23 +561,6 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -563,23 +561,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_MLA_DISABLE": "VLLM_MLA_DISABLE":
lambda: bool(int(os.getenv("VLLM_MLA_DISABLE", "0"))), lambda: bool(int(os.getenv("VLLM_MLA_DISABLE", "0"))),
# Flag that can control whether or not we perform matrix-absorption for MLA
# decode, i.e. absorb W_UK into W_Q/W_UK and W_UV into W_O, absorbing the
# matrices reduces the runtime FLOPs needed to compute MLA but requires
# storing more weights, W_Q_UK and W_UV_O, so can increase memory usage,
# the is enabled by default
"VLLM_MLA_PERFORM_MATRIX_ABSORPTION":
lambda: bool(int(os.getenv("VLLM_MLA_PERFORM_MATRIX_ABSORPTION", "1"))),
# When running MLA with matrix-absorption enabled and fp8 quantized weights
# we perform the matrix-absorption in float32 precision, after the matrices
# are absorbed we requantize the weights back to fp8, this flag can be used
# to disable the requantization step, and instead convert the absorbed
# matrices to match the activation type. This can lead to higher memory and
# compute usage but better preserves the accuracy of the original model.
"VLLM_MLA_DISABLE_REQUANTIZATION":
lambda: bool(int(os.getenv("VLLM_MLA_DISABLE_REQUANTIZATION", "0"))),
# If set, vLLM will use the Triton implementation of moe_align_block_size, # If set, vLLM will use the Triton implementation of moe_align_block_size,
# i.e. moe_align_block_size_triton in fused_moe.py. # i.e. moe_align_block_size_triton in fused_moe.py.
"VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON": "VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON":
......
...@@ -13,10 +13,9 @@ import triton.language as tl ...@@ -13,10 +13,9 @@ import triton.language as tl
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
_normalize_quant_group_shape, scaled_dequantize) scaled_dequantize)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_BLOCK_FP8_SUPPORTED, Fp8LinearOp, cutlass_block_fp8_supported, CUTLASS_BLOCK_FP8_SUPPORTED)
cutlass_fp8_supported)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
...@@ -101,60 +100,6 @@ direct_register_custom_op( ...@@ -101,60 +100,6 @@ direct_register_custom_op(
) )
# Unify the interface between `apply_w8a8_block_fp8_linear` and
# `apply_fp8_linear`
# NOTE(lucas): this is quite messy, we should think through this more formally
# TODO(luka): unify this better
# https://github.com/vllm-project/vllm/issues/14397
class Fp8LinearGenericOp:
def __init__(
self,
cutlass_fp8_supported: bool = cutlass_fp8_supported(),
cutlass_block_fp8_supported: bool = cutlass_block_fp8_supported(),
):
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported
self.fp8_linear = Fp8LinearOp(
cutlass_fp8_supported=cutlass_fp8_supported)
def apply(
self,
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_group_shape: Tuple[int, int],
weight_group_shape: Tuple[int, int],
input_scale: Optional[torch.Tensor] = None, # static scale if one
) -> torch.Tensor:
# View input as 2D matrix for fp8 methods
input = input.view(-1, input.shape[-1])
weight_group_shape = _normalize_quant_group_shape( \
weight, weight_group_shape)
input_group_shape = _normalize_quant_group_shape(
input, input_group_shape)
def is_dim_blocked(dim, shape, group_shape):
return group_shape < shape[dim] and group_shape > 1
if is_dim_blocked(0, weight.shape, weight_group_shape[0])\
and is_dim_blocked(1, weight.shape, weight_group_shape[1]) and\
input_group_shape == (1, weight_group_shape[1]):
return apply_w8a8_block_fp8_linear(
input,
weight,
list(weight_group_shape),
weight_scale,
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported)
else:
# Despite having linear in the name it doesn't conform to
# `torch.nn.functional.linear` which is defined as
# `input @ weight.T` so we explicitly transpose the weight matrix
return self.fp8_linear.apply(input, weight.T, weight_scale.T,
use_per_token_if_dynamic=\
(input_group_shape == (1, input.shape[1])))
def input_to_float8( def input_to_float8(
x: torch.Tensor, x: torch.Tensor,
dtype: Optional[torch.dtype] = None dtype: Optional[torch.dtype] = None
......
...@@ -21,7 +21,7 @@ Main reference: DeepseekV2 paper, and FlashInfer Implementation ...@@ -21,7 +21,7 @@ Main reference: DeepseekV2 paper, and FlashInfer Implementation
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551). (https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
Deepseek's MLA attention works the following way: Deepseek's MLA attention works the following way:
* Use a single latent vector to represent the per-token entry of the KV cache. * Use a single latent vector to represent the per-token entry of the KV cache.
* For decode (i.e. the memory friendly approach) the attention "simulates" a * For decode (i.e. the memory friendly approach) the attention "simulates" a
multi-head attention, while the compute is similar to multi-query attention. multi-head attention, while the compute is similar to multi-query attention.
...@@ -54,9 +54,9 @@ W_DQ project h_t to q_c shape [H, Lq] ...@@ -54,9 +54,9 @@ W_DQ project h_t to q_c shape [H, Lq]
W_UQ project q_c to q_nope shape [Lq, N * P] W_UQ project q_c to q_nope shape [Lq, N * P]
W_QR project q_c to q_pe shape [Lq, N * R] W_QR project q_c to q_pe shape [Lq, N * R]
W_DKV project h_t to kv_c shape [H, Lkv] W_DKV project h_t to kv_c shape [H, Lkv]
W_UK project kv_c to k_nope shape [Lkv, N * P] W_UK project kv_c to k_nope shape [Lkv, N, P]
W_KR project h_t to k_pe shape [H, N * R] W_KR project h_t to k_pe shape [H, R]
W_UV project kv_c to v shape [Lkv, N * V] W_UV project kv_c to v shape [Lkv, N, V]
W_O project v to h_t shape [N * V, H] W_O project v to h_t shape [N * V, H]
...@@ -69,8 +69,8 @@ new_kv_c = h_t @ W_DKV ...@@ -69,8 +69,8 @@ new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR) new_k_pe = RoPE(h_t @ W_KR)
kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0) kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0)
k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0) k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
k_nope = (kv_c @ W_UK).view(Skv, N, P) k_nope = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P)
v = (kv_c @ W_UV).view(Skv, N, V) v = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V)
// MHA with QK headdim = P + R // MHA with QK headdim = P + R
// V headdim = V // V headdim = V
...@@ -79,7 +79,7 @@ spda_o = scaled_dot_product_attention( ...@@ -79,7 +79,7 @@ spda_o = scaled_dot_product_attention(
torch.cat([q_nope, q_pe], dim=-1), torch.cat([q_nope, q_pe], dim=-1),
torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1), torch.cat([k_nope, k_pe.unsqueeze(1).expand(-1, N, -1)], dim=-1),
v v
) )
return spda_o @ W_O return spda_o @ W_O
NOTE: in the actual code, NOTE: in the actual code,
...@@ -90,20 +90,10 @@ NOTE: in the actual code, ...@@ -90,20 +90,10 @@ NOTE: in the actual code,
## Data-Movement Friendly Approach (i.e. "_forward_decode"): ## Data-Movement Friendly Approach (i.e. "_forward_decode"):
Ahead of time, compute:
% this projects from q_c to [Sq, N * Lkv]
W_UQ_UK = einsum("qnp,knp -> qnk"
W_UQ.view(Lq, N, P), W_UK.view(Lkv, N, P)
).view(Lkv, N * Lkv)
% this projects from attn output [Sq, N * Lkv] to [Sq, H]
W_UV_O = einsum("knv,nvh -> nkh"
W_UV.view(Lkv, N, V), W_O.view(N, V, H)
).view(N * Lkv, H)
Runtime Runtime
q_c = h_t @ W_DQ q_c = h_t @ W_DQ
q_latent = q_c @ W_UQ_UK.view(Sq, N, Lkv) q_nope = (q_c @ W_UQ).view(-1, N, P)
ql_nope = einsum("snh,lnh->snl", q, W_UK)
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR) new_k_pe = RoPE(h_t @ W_KR)
...@@ -116,29 +106,31 @@ k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0) ...@@ -116,29 +106,31 @@ k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
// NOTE: this is less compute-friendly since Lkv > P // NOTE: this is less compute-friendly since Lkv > P
// but is more data-movement friendly since its MQA vs MHA // but is more data-movement friendly since its MQA vs MHA
spda_o = scaled_dot_product_attention( spda_o = scaled_dot_product_attention(
torch.cat([q_latent, q_pe], dim=-1), torch.cat([ql_nope, q_pe], dim=-1),
torch.cat([kv_c, k_pe], dim=-1), torch.cat([kv_c, k_pe], dim=-1),
kv_c kv_c
) )
return spda_o.reshape(-1, N * Lkv) @ W_UV_O
o = einsum("snl,lnv->snv", spda_o.reshape(-1, N, Lkv), W_UV)
return o.view(-1, N * V) @ self.num_heads @ W_O
## Chunked Prefill ## Chunked Prefill
For chunked prefill we want to use the compute friendly algorithm. We are For chunked prefill we want to use the compute friendly algorithm. We are
assuming sufficiently large Sq / Skv ratio, in the future may want to switch to assuming sufficiently large Sq / Skv ratio, in the future may want to switch to
the data-movement friendly approach if the chunk (i.e. `Sq`) is small. the data-movement friendly approach if the chunk (i.e. `Sq`) is small.
However, the compute-friendly approach can potentially run out of memory if Skv However, the compute-friendly approach can potentially run out of memory if Skv
is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)` is large due to: `k_nope = (kv_c @ W_UK).view(Skv, N, P)`
To mitigate this, we chunk the computation of attention with respect to the To mitigate this, we chunk the computation of attention with respect to the
current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can used a current context (i.e. `cache_kv_c` and `cache_k_pe`) so that we can used a
fixed workspace size. fixed workspace size.
The chunked prefill approach is as follows: The chunked prefill approach is as follows:
MCC Max chunk of context to process per iter, computed dynamically, MCC Max chunk of context to process per iter, computed dynamically,
used to bound the memory usage used to bound the memory usage
q_c = h_t @ W_DQ q_c = h_t @ W_DQ
...@@ -146,8 +138,8 @@ q_nope = (q_c @ W_UQ).view(Sq, N, P) ...@@ -146,8 +138,8 @@ q_nope = (q_c @ W_UQ).view(Sq, N, P)
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR) new_k_pe = RoPE(h_t @ W_KR)
new_k_nope = (new_kv_c @ W_UK).view(Sq, N, P) new_k_nope = (new_kv_c @ W_UK.view(Lkv, N * P)).view(Sq, N, P)
new_v = (new_kv_c @ W_UV).view(Sq, N, V) new_v = (new_kv_c @ W_UV.view(Lkv, N * V)).view(Sq, N, V)
// MHA between queries and new KV // MHA between queries and new KV
// with QK headdim = P + R // with QK headdim = P + R
...@@ -160,7 +152,7 @@ curr_o, curr_lse = scaled_dot_product_attention( ...@@ -160,7 +152,7 @@ curr_o, curr_lse = scaled_dot_product_attention(
new_v, new_v,
casual=True, casual=True,
return_softmax_lse=True return_softmax_lse=True
) )
// Compute attention with the already existing context // Compute attention with the already existing context
for chunk_idx in range(cdiv(C, MCC)): for chunk_idx in range(cdiv(C, MCC)):
...@@ -198,30 +190,17 @@ from dataclasses import dataclass ...@@ -198,30 +190,17 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
import torch import torch
from compressed_tensors.quantization import QuantizationStrategy
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
AttentionMetadata, AttentionMetadata,
MLAAttentionImpl) MLAAttentionImpl)
from vllm.attention.backends.utils import get_flash_attn_version from vllm.attention.backends.utils import get_flash_attn_version
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
from vllm.distributed import (get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, RowParallelLinear, LinearBase, RowParallelLinear,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsLinearMethod)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsW8A8Fp8)
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
Fp8LinearGenericOp, is_fp8)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
scaled_quantize)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import cdiv, round_down from vllm.utils import cdiv, round_down
...@@ -646,7 +625,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -646,7 +625,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
self.kv_b_proj = kv_b_proj self.kv_b_proj = kv_b_proj
self.o_proj = o_proj self.o_proj = o_proj
self.vllm_flash_attn_version = get_flash_attn_version() self.vllm_flash_attn_version = get_flash_attn_version()
self.fp8_linear_generic = Fp8LinearGenericOp()
# Handle the differences between the flash_attn_varlen from flash_attn # Handle the differences between the flash_attn_varlen from flash_attn
# and the one from vllm_flash_attn. The former is used on RoCM and the # and the one from vllm_flash_attn. The former is used on RoCM and the
...@@ -658,88 +636,37 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -658,88 +636,37 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
fa_version=self.vllm_flash_attn_version) fa_version=self.vllm_flash_attn_version)
def _v_up_proj_and_o_proj(self, x): def _v_up_proj_and_o_proj(self, x):
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: # Convert from (B, N, L) to (N, B, L)
if is_fp8(self.W_UV_O): x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
output_parallel = self.fp8_linear_generic.apply( # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales, x = torch.bmm(x, self.W_UV)
self.reqaunt_input_group_shape, # Convert from (N, B, V) to (B, N * V)
self.reqaunt_weight_group_shape) x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
else: return self.o_proj(x)[0]
output_parallel = torch.matmul(x.flatten(start_dim=1),
self.W_UV_O) # Return `ql_nope`, `q_pe`
if self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
return output
else:
x = torch.einsum("bnl,lnv->bnv", x, self.W_UV)
return self.o_proj(x.reshape(-1,
self.num_heads * self.v_head_dim))[0]
def _q_proj_and_k_up_proj(self, x): def _q_proj_and_k_up_proj(self, x):
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: q_nope, q_pe = self.q_proj(x)[0]\
if is_fp8(self.W_Q_UK): .view(-1, self.num_heads, self.qk_head_dim)\
return self.fp8_linear_generic.apply( .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
x, self.W_Q_UK, self.W_Q_UK_scales,
self.reqaunt_input_group_shape,
self.reqaunt_weight_group_shape).view(
-1, self.num_heads, self.kv_lora_rank)
return torch.matmul(x, self.W_Q_UK)\
.view(-1, self.num_heads, self.kv_lora_rank)
else:
x = torch.matmul(x, self.W_Q)\
.view(-1, self.num_heads, self.qk_nope_head_dim)
return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\
.view(-1, self.num_heads, self.kv_lora_rank)
def process_weights_after_loading(self, act_dtype: torch.dtype): # Convert from (B, N, P) to (N, B, P)
q_nope = q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
ql_nope = torch.bmm(q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
return ql_nope.transpose(0, 1), q_pe
# TODO(lucas) This is very gross, we need a more wide scale refactor of def process_weights_after_loading(self, act_dtype: torch.dtype):
# all the FP8 code with a more standard way of
# defining schemes/group-shapes, we should also potentially force
# quant_methods to support a decompress function
#
# returns input_group_shape, weight_group_shape
def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \
tuple[tuple[int, int], tuple[int, int]]:
if isinstance(layer.quant_method, Fp8LinearMethod):
if layer.quant_method.block_quant:
weight_block_size = \
layer.quant_method.quant_config.weight_block_size
# per-token-group (1, X), block-quantized (X, Y)
return (1, weight_block_size[-1]), weight_block_size
else:
return (-1, -1), (-1, -1) # per-tensor, per-tensor
elif isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
and isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
# this is hacky but we always assume the for
# CompressedTensorsW8A8Fp8 the input is dynamic per-token
# we ignore if it is static-per-tensor since we are going to
# requantize after later anyways
strategy = layer.scheme.strategy
if strategy == QuantizationStrategy.TENSOR:
return (1, -1), (-1, -1) # per-token, per-tensor
elif strategy == QuantizationStrategy.CHANNEL:
return (1, -1), (-1, 1) # per-token, per-channel
else:
raise NotImplementedError(
f"QuantizationStrategy.{strategy} is not supported for "
"fp8 MLA, please run with VLLM_MLA_DISABLE=1")
else:
raise NotImplementedError(
"Can't determine scale group shapes for "
f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1"
)
def get_layer_weight(layer): def get_layer_weight(layer):
if hasattr(layer, "weight"): WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
return layer.weight for attr in WEIGHT_NAMES:
elif hasattr(layer, "qweight"): if hasattr(layer, attr):
return layer.qweight return getattr(layer, attr)
else: raise AttributeError(
raise AttributeError( f"Layer '{layer}' has no recognized weight attribute:"
f"Layer '{layer}' has neither weight nor qweight") f" {WEIGHT_NAMES}.")
def get_and_maybe_dequant_weights(layer: LinearBase): def get_and_maybe_dequant_weights(layer: LinearBase):
if not isinstance(layer.quant_method, UnquantizedLinearMethod): if not isinstance(layer.quant_method, UnquantizedLinearMethod):
...@@ -755,10 +682,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -755,10 +682,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
return dequant_weights.T return dequant_weights.T
return layer.weight return layer.weight
weight_dtype = get_layer_weight(self.kv_b_proj).dtype # we currently do not have quantized bmm's which are needed for
assert get_layer_weight(self.o_proj).dtype == weight_dtype # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform
assert get_layer_weight(self.q_proj).dtype == weight_dtype # the bmm's in 16-bit, the extra memory overhead of this is fairly low
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
assert kv_b_proj_weight.shape == ( assert kv_b_proj_weight.shape == (
self.kv_lora_rank, self.kv_lora_rank,
...@@ -777,89 +703,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -777,89 +703,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
W_UK, W_UV = kv_b_proj_weight.split( W_UK, W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1) [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\ # Convert from (L, N, V) to (N, L, V)
.view(-1, self.num_heads, self.qk_head_dim) self.W_UV = W_UV.transpose(0, 1)
# Convert from (L, N, P) to (N, P, L)
# can be W_Q or W_UQ depending q_lora_rank, the former if self.W_UK_T = W_UK.permute(1, 2, 0)
# q_lora_rank is None, the latter otherwise. From the Attention backend
# perspective though we call these both W_Q and rely on the layer
# to pass in the correct matrix
W_Q = q_proj_weight[..., :self.qk_nope_head_dim]
self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\
.flatten(start_dim=1).contiguous()
# W_QR is small so for simplicity we dont bother requantizing it
self.W_QR = self.W_QR.to(act_dtype)
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION
if is_fp8(weight_dtype) and requantization_enabled:
# This assumes it wise to requantize using the same group shapes
# (i.e. strategy, per-tensor, per-channel, block etc.) that the
# weights were originally quantized
requant_input_group_shape, requant_weight_group_shape = \
get_scale_group_shapes_for_fp8(self.q_proj)
assert (requant_input_group_shape, requant_weight_group_shape)\
== get_scale_group_shapes_for_fp8(self.kv_b_proj)
assert (requant_input_group_shape, requant_weight_group_shape)\
== get_scale_group_shapes_for_fp8(self.o_proj)
self.reqaunt_input_group_shape = requant_input_group_shape
self.reqaunt_weight_group_shape = requant_weight_group_shape
#
# Perform matrix-absorption following
# https://github.com/flashinfer-ai/flashinfer/pull/551
# for decode, as a result we end up with absorbed weights for decode
# and another copy of raw weights for prefill.
#
self.W_UK, self.W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
# We absorb `W_UK` into `W_Q` resulting in either W_Q_UK or W_UQ_UK
# depending q_lora_rank, the former if q_lora_rank is None, the
# latter otherwise
# basically if q_lora_rank is none we are absorbing into q_proj
# instead of UQ
W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
.flatten(start_dim=1).contiguous()
if is_fp8(weight_dtype) and requantization_enabled:
W_Q_UK, W_Q_UK_scales = scaled_quantize(
W_Q_UK,
self.reqaunt_weight_group_shape,
quant_dtype=current_platform.fp8_dtype())
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
self.W_Q_UK = W_Q_UK.T.contiguous()
self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous()
else:
self.W_Q_UK = W_Q_UK.to(act_dtype)
W_O = get_and_maybe_dequant_weights(self.o_proj)\
.view(-1, self.num_heads, self.v_head_dim)
W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
.flatten(start_dim=0, end_dim=1).contiguous()
if is_fp8(weight_dtype) and requantization_enabled:
W_UV_O, W_UV_O_scales = scaled_quantize(
W_UV_O,
self.reqaunt_weight_group_shape,
quant_dtype=current_platform.fp8_dtype())
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
self.W_UV_O = W_UV_O.T.contiguous()
self.W_UV_O_scales = W_UV_O_scales.T.contiguous()
else:
self.W_UV_O = W_UV_O.to(act_dtype)
self.tp_size = get_tensor_model_parallel_world_size()
else:
if is_fp8(weight_dtype):
raise NotImplementedError(
"Currently fp8 requires matrix absorption")
self.W_UV = W_UV
self.W_UK = W_UK
self.W_Q = W_Q.flatten(start_dim=1)
def _compute_prefill_context( def _compute_prefill_context(
self, self,
...@@ -998,7 +845,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -998,7 +845,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
@abstractmethod @abstractmethod
def _forward_decode( def _forward_decode(
self, self,
q_nope: torch.Tensor, ql_nope: torch.Tensor,
q_pe: torch.Tensor, q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: M, attn_metadata: M,
...@@ -1051,10 +898,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1051,10 +898,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if has_decode: if has_decode:
assert attn_metadata.decode is not None assert attn_metadata.decode is not None
decode_q_nope = self._q_proj_and_k_up_proj(decode_hs_or_q_c) decode_ql_nope, decode_q_pe = \
decode_q_pe = torch.matmul(decode_hs_or_q_c, self.W_QR)\ self._q_proj_and_k_up_proj(decode_hs_or_q_c)
.view(-1, self.num_heads, self.qk_rope_head_dim)
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
attn_metadata.decode.input_positions, decode_q_pe.contiguous(), attn_metadata.decode.input_positions, decode_q_pe.contiguous(),
decode_k_pe) decode_k_pe)
...@@ -1087,6 +932,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1087,6 +932,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if has_decode: if has_decode:
output[:num_decode_tokens] = self._forward_decode( output[:num_decode_tokens] = self._forward_decode(
decode_q_nope, decode_q_pe, kv_cache, attn_metadata) decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
return output_padded return output_padded
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment