Unverified Commit 9532c498 authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[Attention] MLA get rid of materialization (#14770)


Signed-off-by: default avatarLucas Wilkinson <lwilkins@redhat.com>
parent 0c2af17c
...@@ -54,9 +54,9 @@ W_DQ project h_t to q_c shape [H, Lq] ...@@ -54,9 +54,9 @@ W_DQ project h_t to q_c shape [H, Lq]
W_UQ project q_c to q_nope shape [Lq, N * P] W_UQ project q_c to q_nope shape [Lq, N * P]
W_QR project q_c to q_pe shape [Lq, N * R] W_QR project q_c to q_pe shape [Lq, N * R]
W_DKV project h_t to kv_c shape [H, Lkv] W_DKV project h_t to kv_c shape [H, Lkv]
W_UK project kv_c to k_nope shape [Lkv, N * P] W_UK project kv_c to k_nope shape [Lkv, N, P]
W_KR project h_t to k_pe shape [H, N * R] W_KR project h_t to k_pe shape [H, R]
W_UV project kv_c to v shape [Lkv, N * V] W_UV project kv_c to v shape [Lkv, N, V]
W_O project v to h_t shape [N * V, H] W_O project v to h_t shape [N * V, H]
...@@ -69,8 +69,8 @@ new_kv_c = h_t @ W_DKV ...@@ -69,8 +69,8 @@ new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR) new_k_pe = RoPE(h_t @ W_KR)
kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0) kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0)
k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0) k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
k_nope = (kv_c @ W_UK).view(Skv, N, P) k_nope = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P)
v = (kv_c @ W_UV).view(Skv, N, V) v = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V)
// MHA with QK headdim = P + R // MHA with QK headdim = P + R
// V headdim = V // V headdim = V
...@@ -90,20 +90,10 @@ NOTE: in the actual code, ...@@ -90,20 +90,10 @@ NOTE: in the actual code,
## Data-Movement Friendly Approach (i.e. "_forward_decode"): ## Data-Movement Friendly Approach (i.e. "_forward_decode"):
Ahead of time, compute:
% this projects from q_c to [Sq, N * Lkv]
W_UQ_UK = einsum("qnp,knp -> qnk"
W_UQ.view(Lq, N, P), W_UK.view(Lkv, N, P)
).view(Lkv, N * Lkv)
% this projects from attn output [Sq, N * Lkv] to [Sq, H]
W_UV_O = einsum("knv,nvh -> nkh"
W_UV.view(Lkv, N, V), W_O.view(N, V, H)
).view(N * Lkv, H)
Runtime Runtime
q_c = h_t @ W_DQ q_c = h_t @ W_DQ
q_latent = q_c @ W_UQ_UK.view(Sq, N, Lkv) q_nope = (q_c @ W_UQ).view(-1, N, P)
ql_nope = einsum("snh,lnh->snl", q, W_UK)
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR) new_k_pe = RoPE(h_t @ W_KR)
...@@ -116,11 +106,13 @@ k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0) ...@@ -116,11 +106,13 @@ k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
// NOTE: this is less compute-friendly since Lkv > P // NOTE: this is less compute-friendly since Lkv > P
// but is more data-movement friendly since its MQA vs MHA // but is more data-movement friendly since its MQA vs MHA
spda_o = scaled_dot_product_attention( spda_o = scaled_dot_product_attention(
torch.cat([q_latent, q_pe], dim=-1), torch.cat([ql_nope, q_pe], dim=-1),
torch.cat([kv_c, k_pe], dim=-1), torch.cat([kv_c, k_pe], dim=-1),
kv_c kv_c
) )
return spda_o.reshape(-1, N * Lkv) @ W_UV_O
o = einsum("snl,lnv->snv", spda_o.reshape(-1, N, Lkv), W_UV)
return o.view(-1, N * V) @ self.num_heads @ W_O
## Chunked Prefill ## Chunked Prefill
...@@ -146,8 +138,8 @@ q_nope = (q_c @ W_UQ).view(Sq, N, P) ...@@ -146,8 +138,8 @@ q_nope = (q_c @ W_UQ).view(Sq, N, P)
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR) new_k_pe = RoPE(h_t @ W_KR)
new_k_nope = (new_kv_c @ W_UK).view(Sq, N, P) new_k_nope = (new_kv_c @ W_UK.view(Lkv, N * P)).view(Sq, N, P)
new_v = (new_kv_c @ W_UV).view(Sq, N, V) new_v = (new_kv_c @ W_UV.view(Lkv, N * V)).view(Sq, N, V)
// MHA between queries and new KV // MHA between queries and new KV
// with QK headdim = P + R // with QK headdim = P + R
...@@ -202,7 +194,6 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple, ...@@ -202,7 +194,6 @@ from typing import (TYPE_CHECKING, Any, Dict, Generic, List, Optional, Tuple,
Type, TypeVar) Type, TypeVar)
import torch import torch
from compressed_tensors.quantization import QuantizationStrategy
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm import envs from vllm import envs
...@@ -215,20 +206,9 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping, ...@@ -215,20 +206,9 @@ from vllm.attention.backends.utils import (PAD_SLOT_ID, compute_slot_mapping,
get_flash_attn_version, get_flash_attn_version,
is_block_tables_empty) is_block_tables_empty)
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
from vllm.distributed import (get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, RowParallelLinear, LinearBase, RowParallelLinear,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsLinearMethod)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsW8A8Fp8)
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
Fp8LinearGenericOp, is_fp8)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
scaled_quantize)
from vllm.model_executor.layers.rotary_embedding import ( from vllm.model_executor.layers.rotary_embedding import (
DeepseekScalingRotaryEmbedding, RotaryEmbedding) DeepseekScalingRotaryEmbedding, RotaryEmbedding)
from vllm.multimodal import MultiModalPlaceholderMap from vllm.multimodal import MultiModalPlaceholderMap
...@@ -1057,7 +1037,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1057,7 +1037,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
self.kv_b_proj = kv_b_proj self.kv_b_proj = kv_b_proj
self.o_proj = o_proj self.o_proj = o_proj
self.triton_fa_func = triton_attention self.triton_fa_func = triton_attention
self.fp8_linear_generic = Fp8LinearGenericOp()
# Handle the differences between the flash_attn_varlen from flash_attn # Handle the differences between the flash_attn_varlen from flash_attn
# and the one from vllm_flash_attn. The former is used on RoCM and the # and the one from vllm_flash_attn. The former is used on RoCM and the
...@@ -1070,79 +1049,28 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1070,79 +1049,28 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
fa_version=self.vllm_flash_attn_version) fa_version=self.vllm_flash_attn_version)
def _v_up_proj_and_o_proj(self, x): def _v_up_proj_and_o_proj(self, x):
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: # Convert from (B, N, L) to (N, B, L)
if is_fp8(self.W_UV_O): x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
output_parallel = self.fp8_linear_generic.apply( # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales, x = torch.bmm(x, self.W_UV)
self.reqaunt_input_group_shape, # Convert from (N, B, V) to (B, N * V)
self.reqaunt_weight_group_shape) x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
else: return self.o_proj(x)[0]
output_parallel = torch.matmul(x.flatten(start_dim=1),
self.W_UV_O) # Return `ql_nope`, `q_pe`
if self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
return output
else:
x = torch.einsum("bnl,lnv->bnv", x, self.W_UV)
return self.o_proj(x.reshape(-1,
self.num_heads * self.v_head_dim))[0]
def _q_proj_and_k_up_proj(self, x): def _q_proj_and_k_up_proj(self, x):
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: q_nope, q_pe = self.q_proj(x)[0]\
if is_fp8(self.W_Q_UK): .view(-1, self.num_heads, self.qk_head_dim)\
return self.fp8_linear_generic.apply( .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
x, self.W_Q_UK, self.W_Q_UK_scales,
self.reqaunt_input_group_shape,
self.reqaunt_weight_group_shape).view(
-1, self.num_heads, self.kv_lora_rank)
return torch.matmul(x, self.W_Q_UK)\
.view(-1, self.num_heads, self.kv_lora_rank)
else:
x = torch.matmul(x, self.W_Q)\
.view(-1, self.num_heads, self.qk_nope_head_dim)
return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\
.view(-1, self.num_heads, self.kv_lora_rank)
def process_weights_after_loading(self, act_dtype: torch.dtype): # Convert from (B, N, P) to (N, B, P)
q_nope = q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
ql_nope = torch.bmm(q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
return ql_nope.transpose(0, 1), q_pe
# TODO(lucas) This is very gross, we need a more wide scale refactor of def process_weights_after_loading(self, act_dtype: torch.dtype):
# all the FP8 code with a more standard way of
# defining schemes/group-shapes, we should also potentially force
# quant_methods to support a decompress function
#
# returns input_group_shape, weight_group_shape
def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \
Tuple[Tuple[int, int], Tuple[int, int]]:
if isinstance(layer.quant_method, Fp8LinearMethod):
if layer.quant_method.block_quant:
weight_block_size = \
layer.quant_method.quant_config.weight_block_size
# per-token-group (1, X), block-quantized (X, Y)
return (1, weight_block_size[-1]), weight_block_size
else:
return (-1, -1), (-1, -1) # per-tensor, per-tensor
elif isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
and isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
# this is hacky but we always assume the for
# CompressedTensorsW8A8Fp8 the input is dynamic per-token
# we ignore if it is static-per-tensor since we are going to
# requantize after later anyways
strategy = layer.scheme.strategy
if strategy == QuantizationStrategy.TENSOR:
return (1, -1), (-1, -1) # per-token, per-tensor
elif strategy == QuantizationStrategy.CHANNEL:
return (1, -1), (-1, 1) # per-token, per-channel
else:
raise NotImplementedError(
f"QuantizationStrategy.{strategy} is not supported for "
"fp8 MLA, please run with VLLM_MLA_DISABLE=1")
else:
raise NotImplementedError(
"Can't determine scale group shapes for "
f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1"
)
def get_layer_weight(layer): def get_layer_weight(layer):
WEIGHT_NAMES = ("weight", "qweight", "weight_packed") WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
...@@ -1167,10 +1095,9 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1167,10 +1095,9 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
return dequant_weights.T return dequant_weights.T
return layer.weight return layer.weight
weight_dtype = get_layer_weight(self.kv_b_proj).dtype # we currently do not have quantized bmm's which are needed for
assert get_layer_weight(self.o_proj).dtype == weight_dtype # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform
assert get_layer_weight(self.q_proj).dtype == weight_dtype # the bmm's in 16-bit, the extra memory overhead of this is fairly low
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
assert kv_b_proj_weight.shape == ( assert kv_b_proj_weight.shape == (
self.kv_lora_rank, self.kv_lora_rank,
...@@ -1189,89 +1116,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1189,89 +1116,10 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
W_UK, W_UV = kv_b_proj_weight.split( W_UK, W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1) [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\ # Convert from (L, N, V) to (N, L, V)
.view(-1, self.num_heads, self.qk_head_dim) self.W_UV = W_UV.transpose(0, 1)
# Convert from (L, N, P) to (N, P, L)
# can be W_Q or W_UQ depending q_lora_rank, the former if self.W_UK_T = W_UK.permute(1, 2, 0)
# q_lora_rank is None, the latter otherwise. From the Attention backend
# perspective though we call these both W_Q and rely on the layer
# to pass in the correct matrix
W_Q = q_proj_weight[..., :self.qk_nope_head_dim]
self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\
.flatten(start_dim=1).contiguous()
# W_QR is small so for simplicity we dont bother requantizing it
self.W_QR = self.W_QR.to(act_dtype)
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION
if is_fp8(weight_dtype) and requantization_enabled:
# This assumes it wise to requantize using the same group shapes
# (i.e. strategy, per-tensor, per-channel, block etc.) that the
# weights were originally quantized
requant_input_group_shape, requant_weight_group_shape = \
get_scale_group_shapes_for_fp8(self.q_proj)
assert (requant_input_group_shape, requant_weight_group_shape)\
== get_scale_group_shapes_for_fp8(self.kv_b_proj)
assert (requant_input_group_shape, requant_weight_group_shape)\
== get_scale_group_shapes_for_fp8(self.o_proj)
self.reqaunt_input_group_shape = requant_input_group_shape
self.reqaunt_weight_group_shape = requant_weight_group_shape
#
# Perform matrix-absorption following
# https://github.com/flashinfer-ai/flashinfer/pull/551
# for decode, as a result we end up with absorbed weights for decode
# and another copy of raw weights for prefill.
#
self.W_UK, self.W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
# We absorb `W_UK` into `W_Q` resulting in either W_Q_UK or W_UQ_UK
# depending q_lora_rank, the former if q_lora_rank is None, the
# latter otherwise
# basically if q_lora_rank is none we are absorbing into q_proj
# instead of UQ
W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
.flatten(start_dim=1).contiguous()
if is_fp8(weight_dtype) and requantization_enabled:
W_Q_UK, W_Q_UK_scales = scaled_quantize(
W_Q_UK,
self.reqaunt_weight_group_shape,
quant_dtype=current_platform.fp8_dtype())
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
self.W_Q_UK = W_Q_UK.T.contiguous()
self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous()
else:
self.W_Q_UK = W_Q_UK.to(act_dtype)
W_O = get_and_maybe_dequant_weights(self.o_proj)\
.view(-1, self.num_heads, self.v_head_dim)
W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
.flatten(start_dim=0, end_dim=1).contiguous()
if is_fp8(weight_dtype) and requantization_enabled:
W_UV_O, W_UV_O_scales = scaled_quantize(
W_UV_O,
self.reqaunt_weight_group_shape,
quant_dtype=current_platform.fp8_dtype())
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
self.W_UV_O = W_UV_O.T.contiguous()
self.W_UV_O_scales = W_UV_O_scales.T.contiguous()
else:
self.W_UV_O = W_UV_O.to(act_dtype)
self.tp_size = get_tensor_model_parallel_world_size()
else:
if is_fp8(weight_dtype):
raise NotImplementedError(
"Currently fp8 requires matrix absorption")
self.W_UV = W_UV
self.W_UK = W_UK
self.W_Q = W_Q.flatten(start_dim=1)
def _compute_prefill_context( def _compute_prefill_context(
self, self,
...@@ -1471,7 +1319,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1471,7 +1319,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
@abstractmethod @abstractmethod
def _forward_decode( def _forward_decode(
self, self,
q_nope: torch.Tensor, ql_nope: torch.Tensor,
q_pe: torch.Tensor, q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: T, attn_metadata: T,
...@@ -1525,9 +1373,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1525,9 +1373,8 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
prefill_k_c_normed = k_c_normed[:num_prefill_tokens] prefill_k_c_normed = k_c_normed[:num_prefill_tokens]
if has_decode: if has_decode:
decode_q_nope = self._q_proj_and_k_up_proj(decode_hs_or_q_c) decode_ql_nope, decode_q_pe = \
decode_q_pe = torch.matmul(decode_hs_or_q_c, self.W_QR)\ self._q_proj_and_k_up_proj(decode_hs_or_q_c)
.view(-1, self.num_heads, self.qk_rope_head_dim)
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
decode_input_positions, decode_q_pe, decode_k_pe) decode_input_positions, decode_q_pe, decode_k_pe)
...@@ -1561,6 +1408,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]): ...@@ -1561,6 +1408,6 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
if has_decode: if has_decode:
output[num_prefill_tokens:] = self._forward_decode( output[num_prefill_tokens:] = self._forward_decode(
decode_q_nope, decode_q_pe, kv_cache, attn_metadata) decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
return output return output
...@@ -84,8 +84,6 @@ if TYPE_CHECKING: ...@@ -84,8 +84,6 @@ if TYPE_CHECKING:
VLLM_SERVER_DEV_MODE: bool = False VLLM_SERVER_DEV_MODE: bool = False
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128 VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
VLLM_MLA_DISABLE: bool = False VLLM_MLA_DISABLE: bool = False
VLLM_MLA_PERFORM_MATRIX_ABSORPTION: bool = True
VLLM_MLA_DISABLE_REQUANTIZATION: bool = False
VLLM_MLA_CUDA_MEM_ALIGN_KV_CACHE: bool = True VLLM_MLA_CUDA_MEM_ALIGN_KV_CACHE: bool = True
VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON: bool = False
VLLM_RAY_PER_WORKER_GPUS: float = 1.0 VLLM_RAY_PER_WORKER_GPUS: float = 1.0
...@@ -563,23 +561,6 @@ environment_variables: dict[str, Callable[[], Any]] = { ...@@ -563,23 +561,6 @@ environment_variables: dict[str, Callable[[], Any]] = {
"VLLM_MLA_DISABLE": "VLLM_MLA_DISABLE":
lambda: bool(int(os.getenv("VLLM_MLA_DISABLE", "0"))), lambda: bool(int(os.getenv("VLLM_MLA_DISABLE", "0"))),
# Flag that can control whether or not we perform matrix-absorption for MLA
# decode, i.e. absorb W_UK into W_Q/W_UK and W_UV into W_O, absorbing the
# matrices reduces the runtime FLOPs needed to compute MLA but requires
# storing more weights, W_Q_UK and W_UV_O, so can increase memory usage,
# the is enabled by default
"VLLM_MLA_PERFORM_MATRIX_ABSORPTION":
lambda: bool(int(os.getenv("VLLM_MLA_PERFORM_MATRIX_ABSORPTION", "1"))),
# When running MLA with matrix-absorption enabled and fp8 quantized weights
# we perform the matrix-absorption in float32 precision, after the matrices
# are absorbed we requantize the weights back to fp8, this flag can be used
# to disable the requantization step, and instead convert the absorbed
# matrices to match the activation type. This can lead to higher memory and
# compute usage but better preserves the accuracy of the original model.
"VLLM_MLA_DISABLE_REQUANTIZATION":
lambda: bool(int(os.getenv("VLLM_MLA_DISABLE_REQUANTIZATION", "0"))),
# If set, vLLM will use the Triton implementation of moe_align_block_size, # If set, vLLM will use the Triton implementation of moe_align_block_size,
# i.e. moe_align_block_size_triton in fused_moe.py. # i.e. moe_align_block_size_triton in fused_moe.py.
"VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON": "VLLM_ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON":
......
...@@ -13,10 +13,9 @@ import triton.language as tl ...@@ -13,10 +13,9 @@ import triton.language as tl
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import ( from vllm.model_executor.layers.quantization.utils.quant_utils import (
_normalize_quant_group_shape, scaled_dequantize) scaled_dequantize)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import ( from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
CUTLASS_BLOCK_FP8_SUPPORTED, Fp8LinearOp, cutlass_block_fp8_supported, CUTLASS_BLOCK_FP8_SUPPORTED)
cutlass_fp8_supported)
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import direct_register_custom_op from vllm.utils import direct_register_custom_op
...@@ -101,60 +100,6 @@ direct_register_custom_op( ...@@ -101,60 +100,6 @@ direct_register_custom_op(
) )
# Unify the interface between `apply_w8a8_block_fp8_linear` and
# `apply_fp8_linear`
# NOTE(lucas): this is quite messy, we should think through this more formally
# TODO(luka): unify this better
# https://github.com/vllm-project/vllm/issues/14397
class Fp8LinearGenericOp:
def __init__(
self,
cutlass_fp8_supported: bool = cutlass_fp8_supported(),
cutlass_block_fp8_supported: bool = cutlass_block_fp8_supported(),
):
self.cutlass_block_fp8_supported = cutlass_block_fp8_supported
self.fp8_linear = Fp8LinearOp(
cutlass_fp8_supported=cutlass_fp8_supported)
def apply(
self,
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_group_shape: Tuple[int, int],
weight_group_shape: Tuple[int, int],
input_scale: Optional[torch.Tensor] = None, # static scale if one
) -> torch.Tensor:
# View input as 2D matrix for fp8 methods
input = input.view(-1, input.shape[-1])
weight_group_shape = _normalize_quant_group_shape( \
weight, weight_group_shape)
input_group_shape = _normalize_quant_group_shape(
input, input_group_shape)
def is_dim_blocked(dim, shape, group_shape):
return group_shape < shape[dim] and group_shape > 1
if is_dim_blocked(0, weight.shape, weight_group_shape[0])\
and is_dim_blocked(1, weight.shape, weight_group_shape[1]) and\
input_group_shape == (1, weight_group_shape[1]):
return apply_w8a8_block_fp8_linear(
input,
weight,
list(weight_group_shape),
weight_scale,
cutlass_block_fp8_supported=self.cutlass_block_fp8_supported)
else:
# Despite having linear in the name it doesn't conform to
# `torch.nn.functional.linear` which is defined as
# `input @ weight.T` so we explicitly transpose the weight matrix
return self.fp8_linear.apply(input, weight.T, weight_scale.T,
use_per_token_if_dynamic=\
(input_group_shape == (1, input.shape[1])))
def input_to_float8( def input_to_float8(
x: torch.Tensor, x: torch.Tensor,
dtype: Optional[torch.dtype] = None dtype: Optional[torch.dtype] = None
......
...@@ -54,9 +54,9 @@ W_DQ project h_t to q_c shape [H, Lq] ...@@ -54,9 +54,9 @@ W_DQ project h_t to q_c shape [H, Lq]
W_UQ project q_c to q_nope shape [Lq, N * P] W_UQ project q_c to q_nope shape [Lq, N * P]
W_QR project q_c to q_pe shape [Lq, N * R] W_QR project q_c to q_pe shape [Lq, N * R]
W_DKV project h_t to kv_c shape [H, Lkv] W_DKV project h_t to kv_c shape [H, Lkv]
W_UK project kv_c to k_nope shape [Lkv, N * P] W_UK project kv_c to k_nope shape [Lkv, N, P]
W_KR project h_t to k_pe shape [H, N * R] W_KR project h_t to k_pe shape [H, R]
W_UV project kv_c to v shape [Lkv, N * V] W_UV project kv_c to v shape [Lkv, N, V]
W_O project v to h_t shape [N * V, H] W_O project v to h_t shape [N * V, H]
...@@ -69,8 +69,8 @@ new_kv_c = h_t @ W_DKV ...@@ -69,8 +69,8 @@ new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR) new_k_pe = RoPE(h_t @ W_KR)
kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0) kv_c = torch.cat([new_kv_c, cache_kv_c], dim=0)
k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0) k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
k_nope = (kv_c @ W_UK).view(Skv, N, P) k_nope = (kv_c @ W_UK.view(Lkv, N * P)).view(Skv, N, P)
v = (kv_c @ W_UV).view(Skv, N, V) v = (kv_c @ W_UV.view(Lkv, N * V)).view(Skv, N, V)
// MHA with QK headdim = P + R // MHA with QK headdim = P + R
// V headdim = V // V headdim = V
...@@ -90,20 +90,10 @@ NOTE: in the actual code, ...@@ -90,20 +90,10 @@ NOTE: in the actual code,
## Data-Movement Friendly Approach (i.e. "_forward_decode"): ## Data-Movement Friendly Approach (i.e. "_forward_decode"):
Ahead of time, compute:
% this projects from q_c to [Sq, N * Lkv]
W_UQ_UK = einsum("qnp,knp -> qnk"
W_UQ.view(Lq, N, P), W_UK.view(Lkv, N, P)
).view(Lkv, N * Lkv)
% this projects from attn output [Sq, N * Lkv] to [Sq, H]
W_UV_O = einsum("knv,nvh -> nkh"
W_UV.view(Lkv, N, V), W_O.view(N, V, H)
).view(N * Lkv, H)
Runtime Runtime
q_c = h_t @ W_DQ q_c = h_t @ W_DQ
q_latent = q_c @ W_UQ_UK.view(Sq, N, Lkv) q_nope = (q_c @ W_UQ).view(-1, N, P)
ql_nope = einsum("snh,lnh->snl", q, W_UK)
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR) new_k_pe = RoPE(h_t @ W_KR)
...@@ -116,11 +106,13 @@ k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0) ...@@ -116,11 +106,13 @@ k_pe = torch.cat([new_k_pe, cache_k_pe], dim=0)
// NOTE: this is less compute-friendly since Lkv > P // NOTE: this is less compute-friendly since Lkv > P
// but is more data-movement friendly since its MQA vs MHA // but is more data-movement friendly since its MQA vs MHA
spda_o = scaled_dot_product_attention( spda_o = scaled_dot_product_attention(
torch.cat([q_latent, q_pe], dim=-1), torch.cat([ql_nope, q_pe], dim=-1),
torch.cat([kv_c, k_pe], dim=-1), torch.cat([kv_c, k_pe], dim=-1),
kv_c kv_c
) )
return spda_o.reshape(-1, N * Lkv) @ W_UV_O
o = einsum("snl,lnv->snv", spda_o.reshape(-1, N, Lkv), W_UV)
return o.view(-1, N * V) @ self.num_heads @ W_O
## Chunked Prefill ## Chunked Prefill
...@@ -146,8 +138,8 @@ q_nope = (q_c @ W_UQ).view(Sq, N, P) ...@@ -146,8 +138,8 @@ q_nope = (q_c @ W_UQ).view(Sq, N, P)
q_pe = RoPE(q_c @ W_QR).view(Sq, N, R) q_pe = RoPE(q_c @ W_QR).view(Sq, N, R)
new_kv_c = h_t @ W_DKV new_kv_c = h_t @ W_DKV
new_k_pe = RoPE(h_t @ W_KR) new_k_pe = RoPE(h_t @ W_KR)
new_k_nope = (new_kv_c @ W_UK).view(Sq, N, P) new_k_nope = (new_kv_c @ W_UK.view(Lkv, N * P)).view(Sq, N, P)
new_v = (new_kv_c @ W_UV).view(Sq, N, V) new_v = (new_kv_c @ W_UV.view(Lkv, N * V)).view(Sq, N, V)
// MHA between queries and new KV // MHA between queries and new KV
// with QK headdim = P + R // with QK headdim = P + R
...@@ -198,30 +190,17 @@ from dataclasses import dataclass ...@@ -198,30 +190,17 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar from typing import TYPE_CHECKING, Any, Generic, Optional, TypeVar
import torch import torch
from compressed_tensors.quantization import QuantizationStrategy
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
from vllm import envs
from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer, from vllm.attention.backends.abstract import (AttentionBackend, AttentionLayer,
AttentionMetadata, AttentionMetadata,
MLAAttentionImpl) MLAAttentionImpl)
from vllm.attention.backends.utils import get_flash_attn_version from vllm.attention.backends.utils import get_flash_attn_version
from vllm.attention.ops.triton_merge_attn_states import merge_attn_states from vllm.attention.ops.triton_merge_attn_states import merge_attn_states
from vllm.distributed import (get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.logger import init_logger from vllm.logger import init_logger
from vllm.model_executor.layers.linear import (ColumnParallelLinear, from vllm.model_executor.layers.linear import (ColumnParallelLinear,
LinearBase, RowParallelLinear, LinearBase, RowParallelLinear,
UnquantizedLinearMethod) UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsLinearMethod)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsW8A8Fp8)
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
Fp8LinearGenericOp, is_fp8)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
scaled_quantize)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.platforms import current_platform from vllm.platforms import current_platform
from vllm.utils import cdiv, round_down from vllm.utils import cdiv, round_down
...@@ -646,7 +625,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -646,7 +625,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
self.kv_b_proj = kv_b_proj self.kv_b_proj = kv_b_proj
self.o_proj = o_proj self.o_proj = o_proj
self.vllm_flash_attn_version = get_flash_attn_version() self.vllm_flash_attn_version = get_flash_attn_version()
self.fp8_linear_generic = Fp8LinearGenericOp()
# Handle the differences between the flash_attn_varlen from flash_attn # Handle the differences between the flash_attn_varlen from flash_attn
# and the one from vllm_flash_attn. The former is used on RoCM and the # and the one from vllm_flash_attn. The former is used on RoCM and the
...@@ -658,88 +636,37 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -658,88 +636,37 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
fa_version=self.vllm_flash_attn_version) fa_version=self.vllm_flash_attn_version)
def _v_up_proj_and_o_proj(self, x): def _v_up_proj_and_o_proj(self, x):
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: # Convert from (B, N, L) to (N, B, L)
if is_fp8(self.W_UV_O): x = x.view(-1, self.num_heads, self.kv_lora_rank).transpose(0, 1)
output_parallel = self.fp8_linear_generic.apply( # Multiply (N, B, L) x (N, L, V) -> (N, B, V)
x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales, x = torch.bmm(x, self.W_UV)
self.reqaunt_input_group_shape, # Convert from (N, B, V) to (B, N * V)
self.reqaunt_weight_group_shape) x = x.transpose(0, 1).reshape(-1, self.num_heads * self.v_head_dim)
else: return self.o_proj(x)[0]
output_parallel = torch.matmul(x.flatten(start_dim=1),
self.W_UV_O) # Return `ql_nope`, `q_pe`
if self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
return output
else:
x = torch.einsum("bnl,lnv->bnv", x, self.W_UV)
return self.o_proj(x.reshape(-1,
self.num_heads * self.v_head_dim))[0]
def _q_proj_and_k_up_proj(self, x): def _q_proj_and_k_up_proj(self, x):
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION: q_nope, q_pe = self.q_proj(x)[0]\
if is_fp8(self.W_Q_UK): .view(-1, self.num_heads, self.qk_head_dim)\
return self.fp8_linear_generic.apply( .split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
x, self.W_Q_UK, self.W_Q_UK_scales,
self.reqaunt_input_group_shape,
self.reqaunt_weight_group_shape).view(
-1, self.num_heads, self.kv_lora_rank)
return torch.matmul(x, self.W_Q_UK)\
.view(-1, self.num_heads, self.kv_lora_rank)
else:
x = torch.matmul(x, self.W_Q)\
.view(-1, self.num_heads, self.qk_nope_head_dim)
return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\
.view(-1, self.num_heads, self.kv_lora_rank)
def process_weights_after_loading(self, act_dtype: torch.dtype): # Convert from (B, N, P) to (N, B, P)
q_nope = q_nope.transpose(0, 1)
# Multiply (N, B, P) x (N, P, L) -> (N, B, L)
ql_nope = torch.bmm(q_nope, self.W_UK_T)
# Convert from (N, B, L) to (B, N, L)
return ql_nope.transpose(0, 1), q_pe
# TODO(lucas) This is very gross, we need a more wide scale refactor of def process_weights_after_loading(self, act_dtype: torch.dtype):
# all the FP8 code with a more standard way of
# defining schemes/group-shapes, we should also potentially force
# quant_methods to support a decompress function
#
# returns input_group_shape, weight_group_shape
def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \
tuple[tuple[int, int], tuple[int, int]]:
if isinstance(layer.quant_method, Fp8LinearMethod):
if layer.quant_method.block_quant:
weight_block_size = \
layer.quant_method.quant_config.weight_block_size
# per-token-group (1, X), block-quantized (X, Y)
return (1, weight_block_size[-1]), weight_block_size
else:
return (-1, -1), (-1, -1) # per-tensor, per-tensor
elif isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
and isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
# this is hacky but we always assume the for
# CompressedTensorsW8A8Fp8 the input is dynamic per-token
# we ignore if it is static-per-tensor since we are going to
# requantize after later anyways
strategy = layer.scheme.strategy
if strategy == QuantizationStrategy.TENSOR:
return (1, -1), (-1, -1) # per-token, per-tensor
elif strategy == QuantizationStrategy.CHANNEL:
return (1, -1), (-1, 1) # per-token, per-channel
else:
raise NotImplementedError(
f"QuantizationStrategy.{strategy} is not supported for "
"fp8 MLA, please run with VLLM_MLA_DISABLE=1")
else:
raise NotImplementedError(
"Can't determine scale group shapes for "
f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1"
)
def get_layer_weight(layer): def get_layer_weight(layer):
if hasattr(layer, "weight"): WEIGHT_NAMES = ("weight", "qweight", "weight_packed")
return layer.weight for attr in WEIGHT_NAMES:
elif hasattr(layer, "qweight"): if hasattr(layer, attr):
return layer.qweight return getattr(layer, attr)
else:
raise AttributeError( raise AttributeError(
f"Layer '{layer}' has neither weight nor qweight") f"Layer '{layer}' has no recognized weight attribute:"
f" {WEIGHT_NAMES}.")
def get_and_maybe_dequant_weights(layer: LinearBase): def get_and_maybe_dequant_weights(layer: LinearBase):
if not isinstance(layer.quant_method, UnquantizedLinearMethod): if not isinstance(layer.quant_method, UnquantizedLinearMethod):
...@@ -755,10 +682,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -755,10 +682,9 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
return dequant_weights.T return dequant_weights.T
return layer.weight return layer.weight
weight_dtype = get_layer_weight(self.kv_b_proj).dtype # we currently do not have quantized bmm's which are needed for
assert get_layer_weight(self.o_proj).dtype == weight_dtype # `W_UV` and `W_UK_T`, we we just store fp16/bf16 copies and perform
assert get_layer_weight(self.q_proj).dtype == weight_dtype # the bmm's in 16-bit, the extra memory overhead of this is fairly low
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
assert kv_b_proj_weight.shape == ( assert kv_b_proj_weight.shape == (
self.kv_lora_rank, self.kv_lora_rank,
...@@ -777,89 +703,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -777,89 +703,10 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
W_UK, W_UV = kv_b_proj_weight.split( W_UK, W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1) [self.qk_nope_head_dim, self.v_head_dim], dim=-1)
q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\ # Convert from (L, N, V) to (N, L, V)
.view(-1, self.num_heads, self.qk_head_dim) self.W_UV = W_UV.transpose(0, 1)
# Convert from (L, N, P) to (N, P, L)
# can be W_Q or W_UQ depending q_lora_rank, the former if self.W_UK_T = W_UK.permute(1, 2, 0)
# q_lora_rank is None, the latter otherwise. From the Attention backend
# perspective though we call these both W_Q and rely on the layer
# to pass in the correct matrix
W_Q = q_proj_weight[..., :self.qk_nope_head_dim]
self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\
.flatten(start_dim=1).contiguous()
# W_QR is small so for simplicity we dont bother requantizing it
self.W_QR = self.W_QR.to(act_dtype)
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION
if is_fp8(weight_dtype) and requantization_enabled:
# This assumes it wise to requantize using the same group shapes
# (i.e. strategy, per-tensor, per-channel, block etc.) that the
# weights were originally quantized
requant_input_group_shape, requant_weight_group_shape = \
get_scale_group_shapes_for_fp8(self.q_proj)
assert (requant_input_group_shape, requant_weight_group_shape)\
== get_scale_group_shapes_for_fp8(self.kv_b_proj)
assert (requant_input_group_shape, requant_weight_group_shape)\
== get_scale_group_shapes_for_fp8(self.o_proj)
self.reqaunt_input_group_shape = requant_input_group_shape
self.reqaunt_weight_group_shape = requant_weight_group_shape
#
# Perform matrix-absorption following
# https://github.com/flashinfer-ai/flashinfer/pull/551
# for decode, as a result we end up with absorbed weights for decode
# and another copy of raw weights for prefill.
#
self.W_UK, self.W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
# We absorb `W_UK` into `W_Q` resulting in either W_Q_UK or W_UQ_UK
# depending q_lora_rank, the former if q_lora_rank is None, the
# latter otherwise
# basically if q_lora_rank is none we are absorbing into q_proj
# instead of UQ
W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
.flatten(start_dim=1).contiguous()
if is_fp8(weight_dtype) and requantization_enabled:
W_Q_UK, W_Q_UK_scales = scaled_quantize(
W_Q_UK,
self.reqaunt_weight_group_shape,
quant_dtype=current_platform.fp8_dtype())
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
self.W_Q_UK = W_Q_UK.T.contiguous()
self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous()
else:
self.W_Q_UK = W_Q_UK.to(act_dtype)
W_O = get_and_maybe_dequant_weights(self.o_proj)\
.view(-1, self.num_heads, self.v_head_dim)
W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
.flatten(start_dim=0, end_dim=1).contiguous()
if is_fp8(weight_dtype) and requantization_enabled:
W_UV_O, W_UV_O_scales = scaled_quantize(
W_UV_O,
self.reqaunt_weight_group_shape,
quant_dtype=current_platform.fp8_dtype())
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
self.W_UV_O = W_UV_O.T.contiguous()
self.W_UV_O_scales = W_UV_O_scales.T.contiguous()
else:
self.W_UV_O = W_UV_O.to(act_dtype)
self.tp_size = get_tensor_model_parallel_world_size()
else:
if is_fp8(weight_dtype):
raise NotImplementedError(
"Currently fp8 requires matrix absorption")
self.W_UV = W_UV
self.W_UK = W_UK
self.W_Q = W_Q.flatten(start_dim=1)
def _compute_prefill_context( def _compute_prefill_context(
self, self,
...@@ -998,7 +845,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -998,7 +845,7 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
@abstractmethod @abstractmethod
def _forward_decode( def _forward_decode(
self, self,
q_nope: torch.Tensor, ql_nope: torch.Tensor,
q_pe: torch.Tensor, q_pe: torch.Tensor,
kv_c_and_k_pe_cache: torch.Tensor, kv_c_and_k_pe_cache: torch.Tensor,
attn_metadata: M, attn_metadata: M,
...@@ -1051,10 +898,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1051,10 +898,8 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if has_decode: if has_decode:
assert attn_metadata.decode is not None assert attn_metadata.decode is not None
decode_q_nope = self._q_proj_and_k_up_proj(decode_hs_or_q_c) decode_ql_nope, decode_q_pe = \
decode_q_pe = torch.matmul(decode_hs_or_q_c, self.W_QR)\ self._q_proj_and_k_up_proj(decode_hs_or_q_c)
.view(-1, self.num_heads, self.qk_rope_head_dim)
decode_q_pe[...], decode_k_pe[...] = self.rotary_emb( decode_q_pe[...], decode_k_pe[...] = self.rotary_emb(
attn_metadata.decode.input_positions, decode_q_pe.contiguous(), attn_metadata.decode.input_positions, decode_q_pe.contiguous(),
decode_k_pe) decode_k_pe)
...@@ -1087,6 +932,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]): ...@@ -1087,6 +932,6 @@ class MLACommonImpl(MLAAttentionImpl[M], Generic[M]):
if has_decode: if has_decode:
output[:num_decode_tokens] = self._forward_decode( output[:num_decode_tokens] = self._forward_decode(
decode_q_nope, decode_q_pe, kv_cache, attn_metadata) decode_ql_nope, decode_q_pe, kv_cache, attn_metadata)
return output_padded return output_padded
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment