Unverified Commit baeded25 authored by Lucas Wilkinson's avatar Lucas Wilkinson Committed by GitHub
Browse files

[Attention] Deepseek v3 MLA support with FP8 compute (#12601)



This PR implements the Deepseek V3 support by performing matrix absorption the fp8 weights 

---------
Signed-off-by: default avatarLucas Wilkinson <lwilkinson@neuralmagic.com>
Co-authored-by: default avatarWoosuk Kwon <woosuk.kwon@berkeley.edu>
Co-authored-by: default avatarsimon-mo <simon.mo@hey.com>
Co-authored-by: default avatarMichael Goin <mgoin64@gmail.com>
Co-authored-by: default avatarZhuohan Li <zhuohan123@gmail.com>
Co-authored-by: default avatarTyler Michael Smith <tysmith@redhat.com>
Co-authored-by: default avatarAlexander Matveev <59768536+alexm-neuralmagic@users.noreply.github.com>
parent 3e1c76cf
from abc import abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, Generic, List, Optional
from typing import Any, Dict, Generic, List, Optional, Tuple
import torch
from compressed_tensors.quantization import QuantizationStrategy
from vllm import _custom_ops as ops
from vllm import envs
from vllm.attention.backends.abstract import (AttentionLayer,
AttentionMetadata,
MLAAttentionImpl, T)
from vllm.distributed import get_tensor_model_parallel_world_size
from vllm.distributed import (get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
from vllm.model_executor.layers.linear import (ColumnParallelLinear,
RowParallelLinear)
LinearBase, RowParallelLinear,
UnquantizedLinearMethod)
from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import ( # noqa: E501
CompressedTensorsLinearMethod)
from vllm.model_executor.layers.quantization.compressed_tensors.schemes import (
CompressedTensorsW8A8Fp8)
from vllm.model_executor.layers.quantization.fp8 import Fp8LinearMethod
from vllm.model_executor.layers.quantization.utils.fp8_utils import (
apply_fp8_linear_generic, current_platform_fp8_dtype, is_fp8)
from vllm.model_executor.layers.quantization.utils.quant_utils import (
scaled_dequantize, scaled_quantize)
from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
from vllm.vllm_flash_attn import flash_attn_varlen_func
......@@ -25,11 +37,11 @@ class MLACommonMetadata(AttentionMetadata):
class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
"""
Common class for implementing repeated parts
Common class for implementing repeated parts
Main reference: DeepseekV2 paper, and FlashInfer Implementation
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
Deepseek's MLA attention works the following way:
* Use a single latent vector to represent the entire KV cache.
* The attention "simulates" a multi-head attention, while the compute is
......@@ -46,7 +58,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
* V: V head dim.
* kv_c: latent/compressed KV
* q_c: latent/compressed Q
#
# Outside the MLA attention backend
#
......@@ -55,21 +67,21 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
kv_c_k_pe (B, Lkv+R).
2. The kv_c_k_pe is split into kv_c (B, Lkv) and k_pe (B, R). cq
and kv_c are normalized.
#
# Inside the MLA attention backend
#
* if prefill:
3. The q_c is then projected up into the multi-head version.
* q_c goes from (B, Lq) to (B, N, (P+R)), which is split into q_nope
(B, N, P) and q_pe (B, N, R).
3. The q_c is then projected up into the multi-head version.
* q_c goes from (B, Lq) to (B, N, (P+R)), which is split into q_nope
(B, N, P) and q_pe (B, N, R).
4. q_pe, k_pe are then passed through rotary embeddings.
5. kv_c and k_pe are concatenated and inserted into the cache
6. The kv_c is then projected up into the multi-head version.
* kv_c goes from (B, Lkv) to (B, N, (P+V)) which has the nope
dimensions for K and V, which is split into k_nope (B, N, P)
6. The kv_c is then projected up into the multi-head version.
* kv_c goes from (B, Lkv) to (B, N, (P+V)) which has the nope
dimensions for K and V, which is split into k_nope (B, N, P)
and v (B, N, V).
7. q (B, N, (P+R)) and k (B, N, (P+R)) matrices are assembled from
q_nope, q_pe, k_nope, k_pe.
......@@ -112,7 +124,7 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
From @tsu-bin's calculation, we only want to use the absorption technique
for decode. The prefill algorithm should still use the up-projected MHA
for less flops and memory usage.
"""
def __init__(
......@@ -162,8 +174,19 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
def _v_up_proj_and_o_proj(self, x):
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
return self.o_proj_absorbed(
x.reshape(-1, self.num_heads * self.kv_lora_rank))[0]
if is_fp8(self.W_UV_O):
output_parallel = apply_fp8_linear_generic(
x.flatten(start_dim=1), self.W_UV_O, self.W_UV_O_scales,
self.reqaunt_input_group_shape,
self.reqaunt_weight_group_shape)
else:
output_parallel = torch.matmul(x.flatten(start_dim=1),
self.W_UV_O)
if self.tp_size > 1:
output = tensor_model_parallel_all_reduce(output_parallel)
else:
output = output_parallel
return output
else:
x = torch.einsum("bnl,lnv->bnv", x, self.W_UV)
return self.o_proj(x.reshape(-1,
......@@ -171,6 +194,12 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
def _q_proj_and_k_up_proj(self, x):
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
if is_fp8(self.W_Q_UK):
return apply_fp8_linear_generic(
x, self.W_Q_UK, self.W_Q_UK_scales,
self.reqaunt_input_group_shape,
self.reqaunt_weight_group_shape).view(
-1, self.num_heads, self.kv_lora_rank)
return torch.matmul(x, self.W_Q_UK)\
.view(-1, self.num_heads, self.kv_lora_rank)
else:
......@@ -179,8 +208,91 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
return torch.einsum("bnp,lnp->bnl", x, self.W_UK)\
.view(-1, self.num_heads, self.kv_lora_rank)
def process_weights_after_loading(self):
kv_b_proj_weight = self.kv_b_proj.weight.T
def process_weights_after_loading(self, act_dtype: torch.dtype):
def is_layer_fp8(layer: LinearBase) -> bool:
return isinstance(layer.quant_method, Fp8LinearMethod) or\
(isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
and isinstance(layer.scheme, CompressedTensorsW8A8Fp8))
def quantization_scheme_supported(layer: LinearBase) -> bool:
return isinstance(layer.quant_method, UnquantizedLinearMethod) or \
is_layer_fp8(layer)
# TODO(lucas) This is very gross, we need a more wide scale refactor of
# all the FP8 code with a more standard way of
# defining schemes/group-shapes, we should also potentially force
# quant_methods to support a decompress function
#
# returns input_group_shape, weight_group_shape
def get_scale_group_shapes_for_fp8(layer: LinearBase) -> \
Tuple[Tuple[int, int], Tuple[int, int]]:
if isinstance(layer.quant_method, Fp8LinearMethod):
if layer.quant_method.block_quant is not None:
weight_block_size = \
layer.quant_method.quant_config.weight_block_size
# per-token-group (1, X), block-quantized (X, Y)
return (1, weight_block_size[-1]), weight_block_size
else:
return (-1, -1), (-1, -1) # per-tensor, per-tensor
elif isinstance(layer.quant_method, CompressedTensorsLinearMethod)\
and isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
# this is hacky but we always assume the for
# CompressedTensorsW8A8Fp8 the input is dynamic per-token
# we ignore if it is static-per-tensor since we are going to
# requantize after later anyways
strategy = layer.scheme.strategy
if strategy == QuantizationStrategy.TENSOR:
return (1, -1), (-1, -1) # per-token, per-tensor
elif strategy == QuantizationStrategy.CHANNEL:
return (1, -1), (-1, 1) # per-token, per-channel
else:
raise NotImplementedError(
f"QuantizationStrategy.{strategy} is not supported for "
"fp8 MLA, please run with VLLM_MLA_DISABLE=1")
else:
raise NotImplementedError(
"Can't determine scale group shapes for "
f"{layer.quant_method}, please run with VLLM_MLA_DISABLE=1"
)
def get_scales(layer: LinearBase) -> torch.Tensor:
if hasattr(layer, "weight_scale_inv"):
return layer.weight_scale_inv
return layer.weight_scale
def get_and_maybe_dequant_weights(layer: LinearBase):
if is_layer_fp8(layer):
if isinstance(layer.quant_method, \
CompressedTensorsLinearMethod) and \
isinstance(layer.scheme, CompressedTensorsW8A8Fp8):
# NOTE(lucas): note sure why but `CompressedTensorsW8A8Fp8`
# seems to store weights as (input, output) instead of
# (output, input) so we need to transpose
weight = layer.weight.T # standardize to (output, input)
else:
weight = layer.weight
_, weight_scale_group_shape = \
get_scale_group_shapes_for_fp8(layer)
scales = get_scales(layer)
return scaled_dequantize(weight, scales,
weight_scale_group_shape)
else:
return layer.weight
if not (quantization_scheme_supported(self.kv_b_proj) and\
quantization_scheme_supported(self.q_proj) and\
quantization_scheme_supported(self.o_proj)):
raise NotImplementedError(
"Only FP8 and UnquantizedLinearMethod are supported for MLA"
", please run with VLLM_MLA_DISABLE=1")
weight_dtype = self.kv_b_proj.weight.dtype
assert self.o_proj.weight.dtype == weight_dtype
assert self.q_proj.weight.dtype == weight_dtype
kv_b_proj_weight = get_and_maybe_dequant_weights(self.kv_b_proj).T
assert kv_b_proj_weight.shape == (
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim)), (
......@@ -198,18 +310,35 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
W_UK, W_UV = kv_b_proj_weight.split(
[self.qk_nope_head_dim, self.v_head_dim], dim=-1)
q_proj = self.q_proj.weight.T\
q_proj_weight = get_and_maybe_dequant_weights(self.q_proj).T\
.view(-1, self.num_heads, self.qk_head_dim)
# can be W_Q or W_UQ depending q_lora_rank, the former if
# q_lora_rank is None, the latter otherwise. From the Attention backend
# perspective though we call these both W_Q and rely on the layer
# to pass in the correct matrix
W_Q = q_proj[..., :self.qk_nope_head_dim]
self.W_QR = q_proj[..., self.qk_nope_head_dim:]\
W_Q = q_proj_weight[..., :self.qk_nope_head_dim]
self.W_QR = q_proj_weight[..., self.qk_nope_head_dim:]\
.flatten(start_dim=1).contiguous()
# W_QR is small so for simplicity we dont bother requantizing it
self.W_QR = self.W_QR.to(act_dtype)
if envs.VLLM_MLA_PERFORM_MATRIX_ABSORPTION:
requantization_enabled = not envs.VLLM_MLA_DISABLE_REQUANTIZATION
if is_fp8(weight_dtype) and requantization_enabled:
# This assumes it wise to requantize using the same group shapes
# (i.e. strategy, per-tensor, per-channel, block etc.) that the
# weights were originally quantized
requant_input_group_shape, requant_weight_group_shape = \
get_scale_group_shapes_for_fp8(self.q_proj)
assert (requant_input_group_shape, requant_weight_group_shape)\
== get_scale_group_shapes_for_fp8(self.kv_b_proj)
assert (requant_input_group_shape, requant_weight_group_shape)\
== get_scale_group_shapes_for_fp8(self.o_proj)
self.reqaunt_input_group_shape = requant_input_group_shape
self.reqaunt_weight_group_shape = requant_weight_group_shape
#
# Perform matrix-absorption following
# https://github.com/flashinfer-ai/flashinfer/pull/551
......@@ -223,25 +352,44 @@ class MLACommonImpl(MLAAttentionImpl[T], Generic[T]):
# latter otherwise
# basically if q_lora_rank is none we are absorbing into q_proj
# instead of UQ
self.W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
W_Q_UK = torch.einsum("qnd,lnd -> qnl", W_Q, W_UK)\
.flatten(start_dim=1).contiguous()
W_O = self.o_proj.weight\
if is_fp8(weight_dtype) and requantization_enabled:
W_Q_UK, W_Q_UK_scales = scaled_quantize(
W_Q_UK,
self.reqaunt_weight_group_shape,
quant_dtype=current_platform_fp8_dtype)
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
self.W_Q_UK = W_Q_UK.T.contiguous()
self.W_Q_UK_scales = W_Q_UK_scales.T.contiguous()
else:
self.W_Q_UK = W_Q_UK.to(act_dtype)
W_O = get_and_maybe_dequant_weights(self.o_proj)\
.view(-1, self.num_heads, self.v_head_dim)
self.W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
W_UV_O = torch.einsum("lnd,hnd -> nlh", W_UV, W_O)\
.flatten(start_dim=0, end_dim=1).contiguous()
tp_size = get_tensor_model_parallel_world_size()
self.o_proj_absorbed = RowParallelLinear(
self.W_UV_O.shape[0] * tp_size,
self.W_UV_O.shape[1],
bias=False,
# TODO(lucas) figure out how to properly forward quant_method
#quant_config=self.o_proj.quant_method,
)
self.o_proj_absorbed.weight = torch.nn.Parameter(self.W_UV_O.T)
if is_fp8(weight_dtype) and requantization_enabled:
W_UV_O, W_UV_O_scales = scaled_quantize(
W_UV_O,
self.reqaunt_weight_group_shape,
quant_dtype=current_platform_fp8_dtype)
# For FP8 save the transpose so we can use
# `apply_w8a8_block_fp8_linear` directly
self.W_UV_O = W_UV_O.T.contiguous()
self.W_UV_O_scales = W_UV_O_scales.T.contiguous()
else:
self.W_UV_O = W_UV_O.to(act_dtype)
self.tp_size = get_tensor_model_parallel_world_size()
else:
if is_fp8(weight_dtype):
raise NotImplementedError(
"Currently fp8 requires matrix absorption")
self.W_UV = W_UV
self.W_UK = W_UK
self.W_Q = W_Q.flatten(start_dim=1)
......
......@@ -57,14 +57,12 @@ class TritonMLABackend(AttentionBackend):
@staticmethod
def get_kv_cache_shape(
num_blocks: int,
block_size: int,
num_kv_heads: int, # assumed to be 1 for MLA
kv_lora_rank: int, # passed via head_size
num_blocks: int,
block_size: int,
num_kv_heads: int, # assumed to be 1 for MLA
head_size: int,
) -> Tuple[int, ...]:
# TODO(lucas): remove hardcoding k_pe size as 1/8th of kv_lora_rank
k_pe_size = kv_lora_rank // 8
return (num_blocks, block_size, kv_lora_rank + k_pe_size)
return (num_blocks, block_size, head_size)
@staticmethod
def swap_blocks(
......@@ -83,7 +81,7 @@ class TritonMLABackend(AttentionBackend):
@staticmethod
def get_supported_head_sizes() -> List[int]:
return [512]
return [576]
class TritonMLAState(AttentionState):
......@@ -624,8 +622,6 @@ class TritonMLAMetadataBuilder(AttentionMetadataBuilder[TritonMLAMetadata]):
self.multimodal_placeholder_maps.items()
}
num_kv_splits = 8
return TritonMLAMetadata(
num_prefills=self.num_prefills,
slot_mapping=slot_mapping_tensor,
......@@ -645,7 +641,7 @@ class TritonMLAMetadataBuilder(AttentionMetadataBuilder[TritonMLAMetadata]):
context_lens_tensor=context_lens_tensor,
block_tables=block_tables,
use_cuda_graph=use_captured_graph,
num_kv_splits=num_kv_splits,
num_kv_splits=4, # TODO(lucas) add heuristic
head_dim=self.runner.model_config.get_head_size(),
)
......
......@@ -200,9 +200,9 @@ class Attention(nn.Module):
s += f", backend={self.impl.__class__.__name__}"
return s
def process_weights_after_loading(self):
def process_weights_after_loading(self, act_dtype: torch.dtype):
if hasattr(self.impl, "process_weights_after_loading"):
self.impl.process_weights_after_loading()
self.impl.process_weights_after_loading(act_dtype)
class MultiHeadAttention(nn.Module):
......
......@@ -739,18 +739,19 @@ class ModelConfig:
@property
def is_deepseek_mla(self) -> bool:
# TODO add deepseek_v3
return hasattr(self.hf_text_config,
"model_type") and (self.hf_text_config.model_type
in ('deepseek_v2'))
return (hasattr(self.hf_text_config, "model_type")) \
and (self.hf_text_config.model_type in \
('deepseek_v2', 'deepseek_v3'))\
and (self.hf_text_config.kv_lora_rank is not None)
def get_head_size(self) -> int:
# TODO remove hard code
if self.is_deepseek_mla:
qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim",
0)
if self.use_mla:
return self.hf_text_config.kv_lora_rank
return self.hf_text_config.kv_lora_rank + qk_rope_head_dim
else:
qk_rope_head_dim = getattr(self.hf_text_config,
"qk_rope_head_dim", 0)
qk_nope_head_dim = getattr(self.hf_text_config,
"qk_nope_head_dim", 0)
if qk_rope_head_dim and qk_nope_head_dim:
......@@ -969,6 +970,32 @@ class ModelConfig:
@property
def use_mla(self) -> bool:
if self.quantization is not None and self.quantization not in [\
"fp8", "compressed-tensors"]:
logger.warning(
"MLA is not supported with %s quantization. "
"Disabling MLA.", self.quantization)
return False
# If using a "compressed-tensors" checkpoint, check that all groups
# have fp8 for both weights and activations.
if self.quantization == "compressed-tensors":
quant_config = self._parse_quant_hf_config()
for group_name, cfg in quant_config.get("config_groups",
("", {})).items():
act_cfg = cfg.get("input_activations", {})
act_type = None if act_cfg is None else act_cfg.get("type", "")
w_cfg = cfg.get("weights", {})
w_type = None if w_cfg is None else w_cfg.get("type", "")
if act_type != "fp8" or w_type != "fp8":
logger.warning(
"compressed-tensors MLA support requires fp8 "
"activations and weights in group '%s', but got "
"activations type '%s' and weights type '%s'.\n "
"Full config: %s", group_name, act_type, w_type,
quant_config)
return False
use_mla = (self.is_deepseek_mla and not envs.VLLM_MLA_DISABLE)
return use_mla
......
......@@ -79,6 +79,7 @@ if TYPE_CHECKING:
VLLM_V1_OUTPUT_PROC_CHUNK_SIZE: int = 128
VLLM_MLA_DISABLE: bool = False
VLLM_MLA_PERFORM_MATRIX_ABSORPTION: bool = True
VLLM_MLA_DISABLE_REQUANTIZATION: bool = False
def get_default_cache_root():
......@@ -519,7 +520,16 @@ environment_variables: Dict[str, Callable[[], Any]] = {
# storing more weights, W_Q_UK and W_UV_O, so can increase memory usage,
# the is enabled by default
"VLLM_MLA_PERFORM_MATRIX_ABSORPTION":
lambda: bool(int(os.getenv("VLLM_MLA_PERFORM_MATRIX_ABSORPTION", "1")))
lambda: bool(int(os.getenv("VLLM_MLA_PERFORM_MATRIX_ABSORPTION", "1"))),
# When running MLA with matrix-absorption enabled and fp8 quantized weights
# we perform the matrix-absorption in float32 precision, after the matrices
# are absorbed we requantize the weights back to fp8, this flag can be used
# to disable the requantization step, and instead convert the absorbed
# matrices to match the activation type. This can lead to higher memory and
# compute usage but better preserves the accuracy of the original model.
"VLLM_MLA_DISABLE_REQUANTIZATION":
lambda: bool(int(os.getenv("VLLM_MLA_DISABLE_REQUANTIZATION", "0")))
}
# end-env-vars-definition
......
......@@ -2,7 +2,7 @@
import functools
import json
import os
from typing import Any, Dict, List, Optional, Tuple
from typing import Any, Dict, List, Optional, Tuple, Union
import torch
import triton
......@@ -10,10 +10,24 @@ import triton.language as tl
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.quantization.utils.quant_utils import (
_normalize_quant_group_shape, scaled_dequantize)
from vllm.model_executor.layers.quantization.utils.w8a8_utils import (
apply_fp8_linear)
from vllm.platforms import current_platform
logger = init_logger(__name__)
current_platform_fp8_dtype = (torch.float8_e4m3fnuz
if current_platform.is_rocm() else
torch.float8_e4m3fn)
def is_fp8(x: Union[torch.dtype, torch.Tensor]) -> bool:
if isinstance(x, torch.Tensor):
x = x.dtype
return x == torch.float8_e4m3fn or x == torch.float8_e4m3fnuz
def apply_w8a8_block_fp8_linear(
input: torch.Tensor,
......@@ -55,6 +69,42 @@ def apply_w8a8_block_fp8_linear(
return output.to(dtype=input.dtype).view(*output_shape)
# Unify the interface between `apply_w8a8_block_fp8_linear` and
# `apply_fp8_linear`
# NOTE(lucas): this is quite messy, we should think through this more formally
def apply_fp8_linear_generic(
input: torch.Tensor,
weight: torch.Tensor,
weight_scale: torch.Tensor,
input_group_shape: Tuple[int, int],
weight_group_shape: Tuple[int, int],
input_scale: Optional[torch.Tensor] = None, # static scale if one
) -> torch.Tensor:
# View input as 2D matrix for fp8 methods
input = input.view(-1, input.shape[-1])
weight_group_shape = _normalize_quant_group_shape(\
weight, weight_group_shape)
input_group_shape = _normalize_quant_group_shape(input, input_group_shape)
def is_dim_blocked(dim, shape, group_shape):
return group_shape < shape[dim] and group_shape > 1
if is_dim_blocked(0, weight.shape, weight_group_shape[0])\
and is_dim_blocked(1, weight.shape, weight_group_shape[1]) and\
input_group_shape == (1, weight_group_shape[1]):
return apply_w8a8_block_fp8_linear(input, weight,
list(weight_group_shape),
weight_scale)
else:
# Despite having linear in the it doesn't conform to
# `torch.nn.functional.linear` which is defined as `input @ weight.T`
# so we explicitly transpose the weight matrix here
return apply_fp8_linear(input, weight.T, weight_scale.T,
use_per_token_if_dynamic=\
(input_group_shape == (1, input.shape[1])))
def input_to_float8(
x: torch.Tensor,
dtype: Optional[torch.dtype] = None
......@@ -75,7 +125,6 @@ def input_to_float8(
def block_quant_to_tensor_quant(
x_q_block: torch.Tensor,
x_s: torch.Tensor,
block_size: List[int],
) -> Tuple[torch.Tensor, torch.Tensor]:
"""This function converts block-wise quantization to tensor-wise
quantization. The inputs are block-wise quantization tensor `x_q_block`,
......@@ -83,26 +132,7 @@ def block_quant_to_tensor_quant(
The outputs are tensor-wise quantization tensor and tensor-wise
quantization scale. Note only float8 is supported for now.
"""
block_n, block_k = block_size[0], block_size[1]
n, k = x_q_block.shape
n_tiles = (n + block_n - 1) // block_n
k_tiles = (k + block_k - 1) // block_k
assert n_tiles == x_s.shape[0]
assert k_tiles == x_s.shape[1]
x_dq_block = x_q_block.to(torch.float32)
x_dq_block_tiles = [[
x_dq_block[
j * block_n:min((j + 1) * block_n, n),
i * block_k:min((i + 1) * block_k, k),
] for i in range(k_tiles)
] for j in range(n_tiles)]
for i in range(k_tiles):
for j in range(n_tiles):
x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]
x_dq_block = scaled_dequantize(x_q_block, x_s)
x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype)
return x_q_tensor, scale
......
"""This file is used for /tests and /benchmarks"""
from typing import List, Optional
from typing import List, Optional, Tuple
import numpy
import torch
......@@ -20,6 +20,120 @@ FUSED_LAYER_NAME_MAPPING = {
}
# Normalize the group_shape to the full extent for any dims that are -1
def _normalize_quant_group_shape(x: torch.Tensor, group_shape: Tuple[int,
int]):
# -1 means full extent
return (group_shape[0] if group_shape[0] > 0 else x.shape[-2],
group_shape[1] if group_shape[1] > 0 else x.shape[-1])
# Useful when treating N-dimensional group scaling as extended numpy-style
# broadcasting in numpy simply stretches dimensions with an extent of 1 to match
# the target shape by repeating the data along that dimension (broadcasting)
# , we extend these semantics to say if the extent of a dimension in the
# source shape is not 1 and does not match the target shape we repeat each
# element along that dimension src_shape[dim] // target_shape[dim] times
# example if we have:
# a = [[1, 2], and target_shape = (2, 4)
# [3, 4]]
# then we would expand a to:
# a = [[1, 1, 2, 2],
# [3, 3, 4, 4]]
# NOTE this function this function does not explicitly broadcast dimensions
# with an extent of 1, since this can be done implicitly by pytorch
def group_broadcast(t, shape):
for i, s in enumerate(shape):
if t.shape[i] != s and t.shape[i] != 1:
assert s % t.shape[i] == 0
t = t.unsqueeze(i + 1)\
.expand(*t.shape[:i+1], s // t.shape[i], *t.shape[i+1:])\
.flatten(i, i + 1)
return t
# Quantize assuming once scale per group of elements with shape group_shape,
# example group shapes:
# * (-1, -1) for per-tensor quantization
# * (1, -1) for per-row quantization
# * (-1, 1) for per-column quantization
# * (128, 128) for 128x128 deepseek style block quantization
# * (1, 128) for deepseek style activation quantization
# (i.e. per-token-per-group)
def scaled_quantize(
x: torch.Tensor,
group_shape: Tuple[int, int],
quant_dtype: torch.dtype,
) -> Tuple[torch.Tensor, torch.Tensor]:
group_shape = _normalize_quant_group_shape(x, group_shape)
assert quant_dtype.is_floating_point, \
"currently `scaled_quantize` only supports floating point dtypes " \
"but could be extended to support other dtypes"
finfo = torch.finfo(quant_dtype)
# Reshape (M, N) into (BLK_M, BLOCK_SIZE_M, BLK_N, BLOCK_SIZE_N)
assert x.ndim == 2
assert x.shape[0] % group_shape[0] == 0 and x.shape[1] % group_shape[1] == 0
blk_m, blk_n = x.shape[0] // group_shape[0], x.shape[1] // group_shape[1]
x_blkd = x.reshape(blk_m, group_shape[0], blk_n, group_shape[1])
# Permute to (BLK_M, BLK_N, BLOCK_SIZE_M, BLOCK_SIZE_N)
x_blkd_permd = x_blkd.permute(0, 2, 1, 3)
# Flatten to (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N)
x_blkd_permd = x_blkd_permd.flatten(start_dim=2)
# Compute scales
min_val, max_val = x_blkd_permd.aminmax(dim=-1)
amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
scale = finfo.max / amax
# Apply scale and convert form:
# (BLK_M, BLK_N, BLOCK_SIZE_M * BLOCK_SIZE_N) to (M, N)
x_scl_sat = (x_blkd_permd * scale.unsqueeze(-1))\
.clamp(min=finfo.min, max=finfo.max)\
.reshape(blk_m, blk_n, group_shape[0], group_shape[1])\
.permute(0, 2, 1, 3)\
.reshape(x.shape)
return x_scl_sat.to(quant_dtype).contiguous(), scale.float().reciprocal()
# inverses `scaled_quantize`
def scaled_dequantize(
x_q: torch.Tensor,
x_s: torch.Tensor,
group_shape: Optional[Tuple[int, int]] = None,
out_dtype: torch.dtype = torch.float32,
) -> Tuple[torch.Tensor, torch.Tensor]:
if group_shape is not None:
group_shape = _normalize_quant_group_shape(x_q, group_shape)
if x_s.ndim == 0: # scalar
x_s = x_s.unsqueeze(-1).unsqueeze(-1) # convert to (1, 1) tensor
if x_s.ndim == 1:
if group_shape is None:
raise AssertionError(
"if x_s is 1D tensor, group_shape must be provided otherwise "
"its ambiguous which dimension to broadcast x_s to")
# unsqueeze the scales for the dimension where we want to broadcast
# across the full extent
if group_shape[0] == x_q.shape[-2]:
x_s = x_s.unsqueeze(-2)
elif group_shape[1] == x_q.shape[-1]:
x_s = x_s.unsqueeze(-1)
else:
raise AssertionError(
"if x_s is a vector we should be broadcasting it to the full "
"extent of one of the dimensions")
if group_shape is not None:
assert x_s.shape[-1] == x_q.shape[-1] // group_shape[1]
assert x_s.shape[-2] == x_q.shape[-2] // group_shape[0]
x_s = group_broadcast(x_s.to(torch.float32), x_q.shape)
return (x_q.to(torch.float32) * x_s).to(out_dtype)
def pack_quantized_values_into_int32(w_q: torch.Tensor,
wtype: ScalarType,
packed_dim: int = 0):
......
......@@ -398,11 +398,13 @@ class DefaultModelLoader(BaseModelLoader):
# parameters onto device for processing and back off after.
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
elif isinstance(module, Attention) and \
if isinstance(module, Attention) and \
hasattr(module, "process_weights_after_loading"):
# When attention modules need to process weights after
# currently only used by MLA
module.process_weights_after_loading()
# TODO(lucas): see if there is a way to unify the signatures
# of process_weights_after_loading
module.process_weights_after_loading(model_config.dtype)
return model.eval()
......@@ -439,6 +441,11 @@ class DummyModelLoader(BaseModelLoader):
with device_loading_context(
module, torch.device(device_config.device)):
quant_method.process_weights_after_loading(module)
if isinstance(module, Attention) and \
hasattr(module, "process_weights_after_loading"):
# When attention modules need to process weights after
# currently only used by MLA
module.process_weights_after_loading(model_config.dtype)
return model.eval()
......@@ -633,6 +640,12 @@ class ShardedStateLoader(BaseModelLoader):
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
quant_method.process_weights_after_loading(module)
if isinstance(module, Attention) and \
hasattr(module, "process_weights_after_loading"):
# When attention modules need to process weights after
# currently only used by MLA
module.process_weights_after_loading(
model_config.dtype)
rank = get_tensor_model_parallel_rank()
pattern = os.path.join(
local_model_path,
......@@ -1272,7 +1285,7 @@ class GGUFModelLoader(BaseModelLoader):
class RunaiModelStreamerLoader(BaseModelLoader):
"""
Model loader that can load safetensors
Model loader that can load safetensors
files from local FS or S3 bucket.
"""
......@@ -1369,6 +1382,11 @@ class RunaiModelStreamerLoader(BaseModelLoader):
if quant_method is not None:
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
if isinstance(module, Attention) and \
hasattr(module, "process_weights_after_loading"):
# When attention modules need to process weights after
# currently only used by MLA
module.process_weights_after_loading(model_config.dtype)
return model.eval()
......
......@@ -27,7 +27,7 @@ from torch import nn
from transformers import PretrainedConfig
from vllm.attention import Attention, AttentionMetadata
from vllm.config import CacheConfig, VllmConfig
from vllm.config import CacheConfig, ModelConfig, VllmConfig
from vllm.distributed import (get_pp_group,
get_tensor_model_parallel_world_size,
tensor_model_parallel_all_reduce)
......@@ -333,12 +333,156 @@ class DeepseekV3Attention(nn.Module):
return output
class DeepseekV3MLAAttention(nn.Module):
"""
Main reference: DeepseekV2 paper, and FlashInfer Implementation
(https://arxiv.org/abs/2405.04434 and https://github.com/flashinfer-ai/flashinfer/pull/551).
For more info see MLACommonImpl in: vllm/attention/backends/mla/utils.py
"""
def __init__(
self,
config: PretrainedConfig,
hidden_size: int,
num_heads: int,
qk_nope_head_dim: int,
qk_rope_head_dim: int,
v_head_dim: int,
q_lora_rank: Optional[int],
kv_lora_rank: int,
rope_theta: float = 10000,
rope_scaling: Optional[Dict[str, Any]] = None,
max_position_embeddings: int = 8192,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
prefix: str = "",
) -> None:
super().__init__()
self.hidden_size = hidden_size
self.qk_nope_head_dim = qk_nope_head_dim
self.qk_rope_head_dim = qk_rope_head_dim
self.qk_head_dim = qk_nope_head_dim + qk_rope_head_dim
self.v_head_dim = v_head_dim
self.q_lora_rank = q_lora_rank
self.kv_lora_rank = kv_lora_rank
self.num_heads = num_heads
tp_size = get_tensor_model_parallel_world_size()
assert num_heads % tp_size == 0
self.num_local_heads = num_heads // tp_size
self.scaling = self.qk_head_dim**-0.5
self.rope_theta = rope_theta
self.max_position_embeddings = max_position_embeddings
if self.q_lora_rank is not None:
self.q_a_proj = ReplicatedLinear(self.hidden_size,
self.q_lora_rank,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_a_proj")
self.q_a_layernorm = RMSNorm(self.q_lora_rank,
eps=config.rms_norm_eps)
self.q_b_proj = ColumnParallelLinear(q_lora_rank,
self.num_heads *
self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_b_proj")
else:
self.q_proj = ColumnParallelLinear(self.hidden_size,
self.num_heads *
self.qk_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.q_proj")
self.kv_a_proj_with_mqa = ReplicatedLinear(
self.hidden_size,
self.kv_lora_rank + self.qk_rope_head_dim,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_a_proj_with_mqa")
self.kv_a_layernorm = RMSNorm(self.kv_lora_rank,
eps=config.rms_norm_eps)
self.kv_b_proj = ColumnParallelLinear(
self.kv_lora_rank,
self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.kv_b_proj")
self.o_proj = RowParallelLinear(self.num_heads * self.v_head_dim,
self.hidden_size,
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.o_proj")
rope_scaling["rope_type"] = 'deepseek_yarn'
self.rotary_emb = get_rope(qk_rope_head_dim,
rotary_dim=qk_rope_head_dim,
max_position=max_position_embeddings,
base=rope_theta,
rope_scaling=rope_scaling,
is_neox_style=False)
if rope_scaling:
mscale_all_dim = rope_scaling.get("mscale_all_dim", False)
scaling_factor = rope_scaling["factor"]
mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
self.scaling = self.scaling * mscale * mscale
self.mla_attn = Attention(
num_heads=self.num_local_heads,
head_size=self.kv_lora_rank,
scale=self.scaling,
num_kv_heads=1,
cache_config=cache_config,
quant_config=quant_config,
prefix=f"{prefix}.attn",
use_mla=True,
# MLA Args
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.kv_lora_rank,
qk_nope_head_dim=self.qk_nope_head_dim,
qk_rope_head_dim=self.qk_rope_head_dim,
qk_head_dim=self.qk_head_dim,
v_head_dim=self.v_head_dim,
rotary_emb=self.rotary_emb,
q_proj=self.q_proj if self.q_lora_rank is None else self.q_b_proj,
kv_b_proj=self.kv_b_proj,
o_proj=self.o_proj,
)
self.prefix = prefix
self.debug_layer_idx = int(self.prefix.split(".")[-2])
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
kv_cache: torch.Tensor,
attn_metadata: AttentionMetadata,
) -> torch.Tensor:
if self.q_lora_rank is not None:
ckq = self.q_a_proj(hidden_states)[0]
hidden_states_or_q_c = self.q_a_layernorm(ckq)
else:
hidden_states_or_q_c = hidden_states
kv_c, k_pe = self.kv_a_proj_with_mqa(hidden_states)[0].split(
[self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
kv_c_normed = self.kv_a_layernorm(kv_c.contiguous())
return self.mla_attn(hidden_states_or_q_c, kv_c_normed, k_pe, kv_cache,
attn_metadata)
class DeepseekV3DecoderLayer(nn.Module):
def __init__(
self,
config: PretrainedConfig,
prefix: str,
model_config: ModelConfig,
cache_config: Optional[CacheConfig] = None,
quant_config: Optional[QuantizationConfig] = None,
) -> None:
......@@ -351,7 +495,11 @@ class DeepseekV3DecoderLayer(nn.Module):
# DecoderLayers are created with `make_layers` which passes the prefix
# with the layer's index.
layer_idx = int(prefix.split(sep='.')[-1])
self.self_attn = DeepseekV3Attention(
if model_config.use_mla:
attn_cls = DeepseekV3MLAAttention
else:
attn_cls = DeepseekV3Attention
self.self_attn = attn_cls(
config=config,
hidden_size=self.hidden_size,
num_heads=config.num_attention_heads,
......@@ -428,6 +576,7 @@ class DeepseekV3Model(nn.Module):
super().__init__()
config = vllm_config.model_config.hf_config
model_config = vllm_config.model_config
cache_config = vllm_config.cache_config
quant_config = vllm_config.quant_config
......@@ -447,6 +596,7 @@ class DeepseekV3Model(nn.Module):
lambda prefix: DeepseekV3DecoderLayer(
config,
prefix,
model_config=model_config,
cache_config=cache_config,
quant_config=quant_config,
),
......
......@@ -110,7 +110,9 @@ class CacheEngine:
parallel_config, LayerBlockType.attention)
key_cache_block = cache_config.block_size * num_heads * head_size
value_cache_block = key_cache_block
# For MLA there is no value cache, since the latent vector
# is joint keys and values.
value_cache_block = key_cache_block if not model_config.use_mla else 0
total = num_attention_layers * (key_cache_block + value_cache_block)
if cache_config.cache_dtype == "auto":
dtype = model_config.dtype
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment