Commit 4d51588e (unverified), authored by Yifan Qiao, committed by GitHub
Browse files

[Feat] DeepSeek V4 Rebased (#40860)


Signed-off-by: Yifan Qiao <yifanqiao@inferact.ai>
Signed-off-by: Woosuk Kwon <woosuk@inferact.ai>
Signed-off-by: qizixi <zixi@inferact.ai>
Signed-off-by: Jee Jee Li <pandaleefree@gmail.com>
Signed-off-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Yongye Zhu <zyy1102000@gmail.com>
Co-authored-by: Yongye Zhu <yongye@inferact.ai>
Co-authored-by: Simon Mo <simon@inferact.ai>
Co-authored-by: Bugen Zhao <i@bugenzhao.com>
Co-authored-by: Giancarlo Delfin <gdelfin@inferact.ai>
Co-authored-by: Jee Jee Li <pandaleefree@gmail.com>
Co-authored-by: Nick Hill <nickhill123@gmail.com>
Co-authored-by: Roger Wang <hey@rogerw.io>
Co-authored-by: Roy Wang <yasong.wang@inferact.ai>
Co-authored-by: Woosuk Kwon <woosuk@inferact.ai>
Co-authored-by: youkaichao <youkaichao@gmail.com>
Co-authored-by: Zhewen Li <jerven.vllm@gmail.com>
Co-authored-by: Zijing Liu <liuzijing2014@gmail.com>
Co-authored-by: khluu <khluu000@gmail.com>
Co-authored-by: qizixi <zixi@inferact.ai>
Co-authored-by: Zhewen Li <zhewenli@inferact.ai>
parent 32e45636
......@@ -132,10 +132,8 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
]
cos, sin = cos_sin.chunk(2, dim=-1)
if self.is_neox_style:
# NOTE(woosuk): Here we assume that the positions tensor has the
# shape [batch_size, seq_len].
cos = cos.repeat(1, 1, 2).unsqueeze(-2)
sin = sin.repeat(1, 1, 2).unsqueeze(-2)
cos = torch.cat((cos, cos), dim=-1).unsqueeze(-2)
sin = torch.cat((sin, sin), dim=-1).unsqueeze(-2)
else:
cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
......@@ -197,3 +195,118 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbeddingBase):
return query, key
else:
return self.forward_native(positions, query, key, offsets)
class DeepseekV4ScalingRotaryEmbedding(DeepseekScalingRotaryEmbedding):
    """RotaryEmbedding extended with YaRN method.

    Credits to Peng et al. github.com/jquesnelle/yarn

    Compared to DeepseekScalingRotaryEmbedding:
    - Applies RoPE to the last rotary_dim
    - The forward method requires an inverse parameter to indicate
      whether to negate the sin
    - Supports applying RoPE to query only (without key)
    - cos_sin_cache stored as fp32 for higher precision RoPE
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Register the cache produced by this class's builder under the
        # name "cos_sin_cache" (non-persistent, so it is not part of the
        # state dict). The builder below creates it in float32.
        cache_fp32 = self._compute_cos_sin_cache()
        self.register_buffer("cos_sin_cache", cache_fp32, persistent=False)

    def _compute_cos_sin_cache(self) -> torch.Tensor:
        """Build the ``[num_positions, 2 * len(inv_freq)]`` fp32 table.

        The first half of the last dimension holds ``cos * mscale`` and the
        second half holds ``sin * mscale``.
        """
        inv_freq = self._compute_inv_freq(self.scaling_factor)
        t = torch.arange(
            self.max_position_embeddings * self.scaling_factor,
            device=current_platform.device_type,
            dtype=torch.float32,
        )
        # Outer product: one row of angles per position.
        freqs = torch.einsum("i,j -> ij", t, inv_freq)
        cos = freqs.cos() * self.mscale
        sin = freqs.sin() * self.mscale
        cache = torch.cat((cos, sin), dim=-1)
        return cache

    def forward_native(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None = None,
        offsets: torch.Tensor | None = None,
        inverse: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """PyTorch-native implementation equivalent to forward().

        Args:
            positions: Indices into ``cos_sin_cache``.
            query: Tensor whose last ``rotary_dim`` features are rotated.
            key: Optional tensor rotated the same way; ``None`` is returned
                in its place when it is not given.
            offsets: Optional tensor added to ``positions`` before lookup.
            inverse: When True, ``sin`` is negated before it is applied.

        Returns:
            ``(query, key)`` as new tensors in the dtype of the input query.
        """
        head_size = query.size(-1)
        # RoPE is applied to the trailing rotary_dim features; anything in
        # front of them is passed through untouched.
        query_rot = query[..., -self.rotary_dim :]
        key_rot = key[..., -self.rotary_dim :] if key is not None else None
        if self.rotary_dim < head_size:
            query_pass = query[..., : -self.rotary_dim]
            key_pass = key[..., : -self.rotary_dim] if key is not None else None

        cos_sin = self.cos_sin_cache[
            torch.add(positions, offsets) if offsets is not None else positions
        ]
        cos, sin = cos_sin.chunk(2, dim=-1)
        if self.is_neox_style:
            cos = torch.cat((cos, cos), dim=-1).unsqueeze(-2)
            sin = torch.cat((sin, sin), dim=-1).unsqueeze(-2)
        else:
            cos = cos.repeat_interleave(2, dim=-1).unsqueeze(-2)
            sin = sin.repeat_interleave(2, dim=-1).unsqueeze(-2)
        if inverse:
            sin = -sin

        rotate_fn = rotate_neox if self.is_neox_style else rotate_gptj
        # The arithmetic runs against the cache's dtype; cast the result back
        # to the dtype the caller passed in.
        orig_dtype = query.dtype
        query_rot = (query_rot * cos + rotate_fn(query_rot) * sin).to(orig_dtype)
        if key_rot is not None:
            key_rot = (key_rot * cos + rotate_fn(key_rot) * sin).to(orig_dtype)

        if self.rotary_dim < head_size:
            query = torch.cat((query_pass, query_rot), dim=-1)
            key = torch.cat((key_pass, key_rot), dim=-1) if key is not None else None
        else:
            query = query_rot
            key = key_rot
        return query, key

    def forward_hip(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None = None,
        offsets: torch.Tensor | None = None,
        inverse: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """ROCm entry point; delegates to the PyTorch-native implementation."""
        # `inverse` must be forwarded: dropping it would silently apply the
        # forward rotation when the caller asked for the inverse one.
        return self.forward_native(positions, query, key, offsets, inverse=inverse)

    def forward_cuda(
        self,
        positions: torch.Tensor,
        query: torch.Tensor,
        key: torch.Tensor | None = None,
        offsets: torch.Tensor | None = None,
        inverse: bool = False,
    ) -> tuple[torch.Tensor, torch.Tensor | None]:
        """CUDA entry point backed by the ``rotary_embedding`` custom op.

        Unlike ``forward_native``, the returned ``query``/``key`` are the
        input tensors themselves, updated in place.
        """
        from vllm import _custom_ops as ops

        # The indexer and attention have different head_dim,
        # we obtain the corresponding head_dim via the query.
        head_size = query.size(-1)
        rope_dim_offset = head_size - self.rotary_dim
        # ops.rotary_embedding() is an in-place operation
        # that updates the query and key tensors.
        ops.rotary_embedding(
            torch.add(positions, offsets) if offsets is not None else positions,
            query,
            key,
            head_size,
            self.cos_sin_cache,
            self.is_neox_style,
            rope_dim_offset=rope_dim_offset,
            inverse=inverse,
        )
        return query, key
......@@ -10,7 +10,11 @@ from vllm.forward_context import get_forward_context
from vllm.logger import init_logger
from vllm.model_executor.custom_op import CustomOp
from vllm.platforms import current_platform
from vllm.utils.deep_gemm import fp8_mqa_logits, fp8_paged_mqa_logits, has_deep_gemm
from vllm.utils.deep_gemm import (
fp8_fp4_mqa_logits,
fp8_fp4_paged_mqa_logits,
has_deep_gemm,
)
from vllm.utils.torch_utils import (
LayerNameType,
_encode_layer_name,
......@@ -32,12 +36,57 @@ logger = init_logger(__name__)
RADIX_TOPK_WORKSPACE_SIZE = 1024 * 1024
# MXFP4 layout: 2 values packed per byte, ue8m0 (1-byte) scale per block of 32.
MXFP4_BLOCK_SIZE = 32
def _gather_workspace_shapes(
total_seq_lens: int,
head_dim: int,
fp8_dtype: torch.dtype,
use_fp4_cache: bool,
) -> tuple[tuple[tuple[int, int], torch.dtype], tuple[tuple[int, int], torch.dtype]]:
"""Return ((values_shape, values_dtype), (scales_shape, scales_dtype)) for
the K-gather workspace. FP8 path: (T, head_dim) fp8 + (T, 4) uint8 fp32
scales. MXFP4 path: (T, head_dim // 2) uint8 packed mxfp4 +
(T, head_dim // MXFP4_BLOCK_SIZE) uint8 ue8m0 scales."""
if use_fp4_cache:
return (
((total_seq_lens, head_dim // 2), torch.uint8),
((total_seq_lens, head_dim // MXFP4_BLOCK_SIZE), torch.uint8),
)
return (
((total_seq_lens, head_dim), fp8_dtype),
((total_seq_lens, 4), torch.uint8),
)
def kv_cache_as_quant_view(
    kv_cache: torch.Tensor,
    head_dim: int,
    use_fp4_cache: bool,
) -> torch.Tensor:
    """Return the 4D ``[num_blocks, block_size, 1, head_width]`` view that
    DeepGEMM expects, built from the 3D indexer kv-cache allocation.

    Without the FP4 cache this is just a singleton axis inserted before the
    last dimension. With the FP4 cache the input must be a 3D uint8 tensor,
    and the result is a strided view whose per-token row width is
    ``head_dim // 2 + head_dim // MXFP4_BLOCK_SIZE`` bytes while the stride
    between blocks is taken from the original allocation.
    """
    if not use_fp4_cache:
        return kv_cache.unsqueeze(-2)

    assert kv_cache.ndim == 3 and kv_cache.dtype == torch.uint8
    n_blocks, tokens_per_block, _ = kv_cache.shape
    row_bytes = head_dim // 2 + head_dim // MXFP4_BLOCK_SIZE
    block_stride = int(kv_cache.stride(0))
    view_size = (n_blocks, tokens_per_block, 1, row_bytes)
    view_stride = (block_stride, row_bytes, row_bytes, 1)
    return torch.as_strided(kv_cache, size=view_size, stride=view_stride)
def sparse_attn_indexer(
hidden_states: torch.Tensor,
k_cache_prefix: LayerNameType,
kv_cache: torch.Tensor,
q_fp8: torch.Tensor,
q_quant: torch.Tensor,
q_scale: torch.Tensor | None,
k: torch.Tensor,
weights: torch.Tensor,
quant_block_size: int,
......@@ -47,6 +96,8 @@ def sparse_attn_indexer(
max_model_len: int,
total_seq_lens: int,
topk_indices_buffer: torch.Tensor,
skip_k_cache_insert: bool,
use_fp4_cache: bool = False,
) -> torch.Tensor:
# careful! this will be None in dummy run
attn_metadata = get_forward_context().attn_metadata
......@@ -56,9 +107,12 @@ def sparse_attn_indexer(
# assert isinstance(attn_metadata, dict)
if not isinstance(attn_metadata, dict):
# Reserve workspace for indexer during profiling run
values_spec, scales_spec = _gather_workspace_shapes(
total_seq_lens, head_dim, fp8_dtype, use_fp4_cache
)
current_workspace_manager().get_simultaneous(
((total_seq_lens, head_dim), torch.float8_e4m3fn),
((total_seq_lens, 4), torch.uint8),
values_spec,
scales_spec,
((RADIX_TOPK_WORKSPACE_SIZE,), torch.uint8),
)
......@@ -73,7 +127,8 @@ def sparse_attn_indexer(
hidden_states,
k_cache_prefix,
kv_cache,
q_fp8,
q_quant,
q_scale,
k,
weights,
quant_block_size,
......@@ -83,6 +138,8 @@ def sparse_attn_indexer(
max_model_len,
total_seq_lens,
topk_indices_buffer,
skip_k_cache_insert,
use_fp4_cache,
)
attn_metadata_narrowed = attn_metadata[k_cache_prefix]
assert isinstance(attn_metadata_narrowed, DeepseekV32IndexerMetadata)
......@@ -91,49 +148,81 @@ def sparse_attn_indexer(
has_prefill = attn_metadata_narrowed.num_prefills > 0
num_decode_tokens = attn_metadata_narrowed.num_decode_tokens
# q_scale is required iff the FP4 cache path is enabled; the FP8 path
# folds the Q scale into `weights` inside fused_indexer_q_rope_quant.
if use_fp4_cache:
assert q_scale is not None, "use_fp4_cache=True requires q_scale"
else:
assert q_scale is None, "q_scale must be None when use_fp4_cache=False"
# During speculative decoding, k may be padded to the CUDA graph batch
# size while slot_mapping only covers actual tokens. Truncate k to avoid
# out-of-bounds reads in the kernel.
num_tokens = slot_mapping.shape[0]
k = k[:num_tokens]
# scale_fmt can be None, but the function expects str
assert scale_fmt is not None
ops.indexer_k_quant_and_cache(
k,
kv_cache,
slot_mapping,
quant_block_size,
scale_fmt,
)
if k is not None:
k = k[:num_tokens]
if not skip_k_cache_insert:
# scale_fmt can be None, but the function expects str
assert scale_fmt is not None
assert not use_fp4_cache, "Unfused FP4 Insert is not supported yet"
ops.indexer_k_quant_and_cache(
k,
kv_cache,
slot_mapping,
quant_block_size,
scale_fmt,
)
topk_indices_buffer[: hidden_states.shape[0]] = -1
if has_prefill:
prefill_metadata = attn_metadata_narrowed.prefill
assert prefill_metadata is not None
# Get the full shared workspace buffers once (will allocate on first use)
# Get the full shared workspace buffers once (will allocate on first use).
# Layout switches between FP8 (head_dim bytes + 4-byte fp32 scale) and
# MXFP4 (head_dim/2 bytes packed + head_dim/MXFP4_BLOCK_SIZE ue8m0
# scales) based on use_fp4_cache.
workspace_manager = current_workspace_manager()
k_fp8_full, k_scale_full = workspace_manager.get_simultaneous(
((total_seq_lens, head_dim), fp8_dtype),
((total_seq_lens, 4), torch.uint8),
values_spec, scales_spec = _gather_workspace_shapes(
total_seq_lens, head_dim, fp8_dtype, use_fp4_cache
)
k_quant_full, k_scale_full = workspace_manager.get_simultaneous(
values_spec,
scales_spec,
)
for chunk in prefill_metadata.chunks:
k_fp8 = k_fp8_full[: chunk.total_seq_lens]
k_quant = k_quant_full[: chunk.total_seq_lens]
k_scale = k_scale_full[: chunk.total_seq_lens]
if not chunk.skip_kv_gather:
ops.cp_gather_indexer_k_quant_cache(
kv_cache,
k_fp8,
k_quant,
k_scale,
chunk.block_table,
chunk.cu_seq_lens,
)
logits = fp8_mqa_logits(
q_fp8[chunk.token_start : chunk.token_end],
(k_fp8, k_scale.view(torch.float32).flatten()),
q_slice = q_quant[chunk.token_start : chunk.token_end]
q_scale_slice = (
q_scale[chunk.token_start : chunk.token_end]
if q_scale is not None
else None
)
# DeepGEMM scalar-type tags (zero-copy): MXFP4 values → int8
# (kPackedFP4), scales → int32 squeezed to 1-D kv_sf / 2-D q_sf.
if use_fp4_cache:
q_slice_cast = q_slice.view(torch.int8)
k_quant_cast = k_quant.view(torch.int8)
k_scale_cast = k_scale.view(torch.int32).squeeze(-1)
else:
q_slice_cast = q_slice
k_quant_cast = k_quant
k_scale_cast = k_scale.view(torch.float32).squeeze(-1)
logits = fp8_fp4_mqa_logits(
(q_slice_cast, q_scale_slice),
(k_quant_cast, k_scale_cast),
weights[chunk.token_start : chunk.token_end],
chunk.cu_seqlen_ks,
chunk.cu_seqlen_ke,
......@@ -171,32 +260,55 @@ def sparse_attn_indexer(
if has_decode:
decode_metadata = attn_metadata_narrowed.decode
assert decode_metadata is not None
# kv_cache shape [
# kv_cache size requirement [num_block, block_size, n_head, head_dim],
# we only have [num_block, block_size, head_dim],
kv_cache = kv_cache.unsqueeze(-2)
kv_cache = kv_cache_as_quant_view(kv_cache, head_dim, use_fp4_cache)
decode_lens = decode_metadata.decode_lens
if decode_metadata.requires_padding:
# pad in edge case where we have short chunked prefill length <
# decode_threshold since we unstrictly split
# prefill and decode by decode_threshold
# (currently set to 1 + speculative tokens)
padded_q_fp8_decode_tokens = pack_seq_triton(
q_fp8[:num_decode_tokens], decode_lens
)
# (currently set to 1 + speculative tokens).
# FP8 Q is float8_e4m3fn (pack_seq_triton's fp32 pad path is OK —
# downstream context_lens masks stale slots). MXFP4 Q is two
# uint8 tensors (values + ue8m0 scales) — use the dedicated uint8
# packer with pad_byte=0 so padded slots dequantize to 0 and
# can't produce NaN/Inf in the logits kernel.
if q_scale is not None:
padded_q_quant_decode_tokens = pack_seq_triton(
q_quant[:num_decode_tokens], decode_lens, pad_value=0
)
padded_q_scale = pack_seq_triton(
q_scale[:num_decode_tokens], decode_lens, pad_value=0
)
else:
padded_q_quant_decode_tokens = pack_seq_triton(
q_quant[:num_decode_tokens], decode_lens
)
padded_q_scale = None
else:
padded_q_fp8_decode_tokens = q_fp8[:num_decode_tokens].reshape(
decode_lens.shape[0], -1, *q_fp8.shape[1:]
padded_q_quant_decode_tokens = q_quant[:num_decode_tokens].reshape(
decode_lens.shape[0], -1, *q_quant.shape[1:]
)
if q_scale is not None:
padded_q_scale = q_scale[:num_decode_tokens].reshape(
decode_lens.shape[0], -1, *q_scale.shape[1:]
)
else:
padded_q_scale = None
# TODO: move and optimize below logic with triton kernels
batch_size = padded_q_fp8_decode_tokens.shape[0]
next_n = padded_q_fp8_decode_tokens.shape[1]
batch_size = padded_q_quant_decode_tokens.shape[0]
next_n = padded_q_quant_decode_tokens.shape[1]
num_padded_tokens = batch_size * next_n
seq_lens = decode_metadata.seq_lens[:batch_size]
# seq_lens is (B, next_n) for native spec decode, (B,) otherwise.
# fp8_paged_mqa_logits and all topk kernels accept both shapes.
logits = fp8_paged_mqa_logits(
padded_q_fp8_decode_tokens,
# seq_lens is always 2D: (B, next_n) for native spec decode, (B, 1)
# otherwise. deep_gemm fp8_fp4_paged_mqa_logits requires 2D context_lens;
# the downstream topk kernels accept both 1D and 2D.
padded_q_quant_cast = (
padded_q_quant_decode_tokens.view(torch.int8)
if use_fp4_cache
else padded_q_quant_decode_tokens
)
logits = fp8_fp4_paged_mqa_logits(
(padded_q_quant_cast, padded_q_scale),
kv_cache,
weights[:num_padded_tokens],
seq_lens,
......@@ -208,7 +320,7 @@ def sparse_attn_indexer(
num_rows = logits.shape[0]
topk_indices = topk_indices_buffer[:num_padded_tokens, :topk_tokens]
if current_platform.is_cuda():
if current_platform.is_cuda() and topk_tokens in (512, 1024, 2048):
workspace_manager = current_workspace_manager()
(topk_workspace,) = workspace_manager.get_simultaneous(
((RADIX_TOPK_WORKSPACE_SIZE,), torch.uint8),
......@@ -263,7 +375,8 @@ def sparse_attn_indexer_fake(
hidden_states: torch.Tensor,
k_cache_prefix: LayerNameType,
kv_cache: torch.Tensor,
q_fp8: torch.Tensor,
q_quant: torch.Tensor,
q_scale: torch.Tensor | None,
k: torch.Tensor,
weights: torch.Tensor,
quant_block_size: int,
......@@ -273,6 +386,8 @@ def sparse_attn_indexer_fake(
max_model_len: int,
total_seq_lens: int,
topk_indices_buffer: torch.Tensor | None,
skip_k_cache_insert: bool,
use_fp4_cache: bool = False,
) -> torch.Tensor:
return topk_indices_buffer
......@@ -309,6 +424,8 @@ class SparseAttnIndexer(CustomOp):
max_model_len: int,
max_total_seq_len: int,
topk_indices_buffer: torch.Tensor,
skip_k_cache_insert: bool = False,
use_fp4_cache: bool = False,
):
super().__init__()
self.k_cache = k_cache
......@@ -319,6 +436,8 @@ class SparseAttnIndexer(CustomOp):
self.max_model_len = max_model_len
self.max_total_seq_len = max_total_seq_len
self.topk_indices_buffer = topk_indices_buffer
self.skip_k_cache_insert = skip_k_cache_insert
self.use_fp4_cache = use_fp4_cache
if current_platform.is_cuda() and not has_deep_gemm():
raise RuntimeError(
"Sparse Attention Indexer CUDA op requires DeepGEMM to be installed."
......@@ -327,14 +446,14 @@ class SparseAttnIndexer(CustomOp):
def forward_native(
self,
hidden_states: torch.Tensor,
q_fp8: torch.Tensor,
q_quant: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
k: torch.Tensor,
weights: torch.Tensor,
):
if current_platform.is_cuda() or current_platform.is_xpu():
return self.forward_cuda(hidden_states, q_fp8, k, weights)
return self.forward_cuda(hidden_states, q_quant, k, weights)
elif current_platform.is_rocm():
return self.forward_hip(hidden_states, q_fp8, k, weights)
return self.forward_hip(hidden_states, q_quant, k, weights)
else:
raise NotImplementedError(
"SparseAttnIndexer native forward is only implemented for "
......@@ -344,15 +463,22 @@ class SparseAttnIndexer(CustomOp):
def forward_cuda(
self,
hidden_states: torch.Tensor,
q_fp8: torch.Tensor,
q_quant: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
k: torch.Tensor,
weights: torch.Tensor,
):
# FP8 path: single tensor (per-token scale is folded into `weights`).
# FP4 path: (values, scales) tuple with scales required by the kernel.
if isinstance(q_quant, tuple):
q_values, q_scale = q_quant
else:
q_values, q_scale = q_quant, None
return torch.ops.vllm.sparse_attn_indexer(
hidden_states,
_encode_layer_name(self.k_cache.prefix),
self.k_cache.kv_cache,
q_fp8,
q_values,
q_scale,
k,
weights,
self.quant_block_size,
......@@ -362,21 +488,30 @@ class SparseAttnIndexer(CustomOp):
self.max_model_len,
self.max_total_seq_len,
self.topk_indices_buffer,
self.skip_k_cache_insert,
self.use_fp4_cache,
)
def forward_hip(
self,
hidden_states: torch.Tensor,
q_fp8: torch.Tensor,
q_quant: torch.Tensor | tuple[torch.Tensor, torch.Tensor],
k: torch.Tensor,
weights: torch.Tensor,
):
assert not self.skip_k_cache_insert, (
"AMD platform doesn't support skip cache insert yet"
)
assert not self.use_fp4_cache, "AMD platform doesn't support fp4 cache yet"
assert isinstance(q_quant, torch.Tensor), (
"AMD sparse_attn_indexer expects a single FP8 q_quant tensor"
)
if rocm_aiter_ops.is_enabled():
return torch.ops.vllm.rocm_aiter_sparse_attn_indexer(
hidden_states,
_encode_layer_name(self.k_cache.prefix),
self.k_cache.kv_cache,
q_fp8,
q_quant,
k,
weights,
self.quant_block_size,
......
......@@ -299,6 +299,13 @@ def cpu_unquantized_gemm(
return layer.cpu_linear(x, weight, bias)
def cublas_gemm_bf16_bf16_fp32(
    x: torch.Tensor,
    weight: torch.Tensor,
):
    """Forward ``x`` and ``weight`` to ``ops.router_gemm_bf16_fp32``.

    NOTE(review): going by the names, this is a bf16 x bf16 GEMM producing
    fp32 output; the op is defined outside this view, so confirm there.
    """
    result = ops.router_gemm_bf16_fp32(x, weight)
    return result
def dispatch_unquantized_gemm() -> Callable[..., torch.Tensor]:
if current_platform.is_rocm():
return rocm_unquantized_gemm
......
......@@ -107,6 +107,31 @@ class Gemma4Config(VerifyAndUpdateConfig):
)
class DeepseekV4ForCausalLMConfig(VerifyAndUpdateConfig):
    """Config hook that renames the quantization method for DeepSeek V4."""

    @staticmethod
    def verify_and_update_model_config(model_config: "ModelConfig") -> None:
        """Rewrite ``quant_method`` from "fp8" to "deepseek_v4_fp8".

        The rewrite is applied, in this order, to ``hf_config`` and then to
        ``hf_text_config``. A config is touched only when it carries a
        ``quantization_config`` whose ``quant_method`` is "fp8" and its
        ``model_type`` is "deepseek_v4".
        """
        for config_attr in ("hf_config", "hf_text_config"):
            hf_cfg = getattr(model_config, config_attr)
            quant_cfg = getattr(hf_cfg, "quantization_config", None)
            if quant_cfg is None or quant_cfg.get("quant_method") != "fp8":
                continue
            if getattr(hf_cfg, "model_type", None) != "deepseek_v4":
                continue
            hf_cfg.quantization_config["quant_method"] = "deepseek_v4_fp8"
class GptOssForCausalLMConfig(VerifyAndUpdateConfig):
@staticmethod
def verify_and_update_model_config(model_config: "ModelConfig") -> None:
......@@ -635,6 +660,7 @@ class VoyageQwen3BidirectionalEmbedModelConfig(VerifyAndUpdateConfig):
MODELS_CONFIG_MAP: dict[str, type[VerifyAndUpdateConfig]] = {
"ColBERTJinaRobertaModel": JinaRobertaModelConfig,
"ColQwen3_5": Qwen3_5ForConditionalGenerationConfig,
"DeepseekV4ForCausalLM": DeepseekV4ForCausalLMConfig,
"DeepseekV32ForCausalLM": DeepseekV32ForCausalLM,
"Ernie4_5_VLMoeForConditionalGeneration": Ernie4_5_VLMoeForConditionalGenerationConfig, # noqa: E501
"FalconMambaForCausalLM": MambaModelConfig,
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import typing
from collections.abc import Callable, Iterable
from itertools import islice
import regex as re
import torch
import torch.nn as nn
import torch.nn.functional as F
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import (
get_ep_group,
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.forward_context import get_forward_context
from vllm.model_executor.layers.deepseek_v4_attention import (
DeepseekV4Indexer,
DeepseekV4MLAModules,
DeepseekV4MultiHeadLatentAttentionWrapper,
)
from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear
from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod
from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import (
fused_topk_bias,
)
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import (
ColumnParallelLinear,
MergedColumnParallelLinear,
RowParallelLinear,
)
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.quantization import QuantizationMethods
from vllm.model_executor.layers.quantization.fp8 import Fp8Config
from vllm.model_executor.layers.quantization.mxfp4 import Mxfp4MoEMethod
from vllm.model_executor.layers.quantization.utils.quant_utils import (
is_layer_skipped,
)
from vllm.model_executor.layers.rotary_embedding import get_rope
from vllm.model_executor.layers.vocab_parallel_embedding import (
ParallelLMHead,
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.model_executor.models.deepseek_v2 import DeepseekV2MLP
from vllm.model_executor.utils import set_weight_attrs
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.triton_utils import tl, triton
from vllm.utils.multi_stream_utils import AuxStreamType
from vllm.utils.torch_utils import direct_register_custom_op
from .utils import (
AutoWeightsLoader,
WeightsMapper,
extract_layer_index,
make_layers,
maybe_prefix,
)
class DeepseekV4FP8Config(Fp8Config):
    """FP8 config that routes MoE layers to MXFP4 quantization.

    DeepSeek V4 checkpoints use FP8 for linear/attention layers but
    MXFP4 for MoE expert weights. This config inherits standard FP8
    behavior and overrides only the MoE dispatch.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_scale_e8m0: bool = True

    @classmethod
    def get_name(cls) -> QuantizationMethods:
        """Return the registry name of this quantization method."""
        return "deepseek_v4_fp8"

    @classmethod
    def override_quantization_method(
        cls, hf_quant_cfg, user_quant, hf_config=None
    ) -> QuantizationMethods | None:
        """Claim the checkpoint for "deepseek_v4_fp8" when applicable.

        Returns "deepseek_v4_fp8" when ``hf_quant_cfg`` is a dict whose
        ``quant_method`` is "fp8" or "deepseek_v4_fp8" and either the HF
        config's ``model_type`` is "deepseek_v4" or the user explicitly
        asked for "deepseek_v4_fp8". Returns ``None`` otherwise.
        """
        if not isinstance(hf_quant_cfg, dict):
            return None
        if hf_quant_cfg.get("quant_method") not in ("fp8", "deepseek_v4_fp8"):
            return None

        is_v4_model = getattr(hf_config, "model_type", None) == "deepseek_v4"
        if is_v4_model or user_quant == "deepseek_v4_fp8":
            return "deepseek_v4_fp8"
        return None

    def get_quant_method(self, layer, prefix):
        """Pick the quant method: MoE layers are special-cased, the rest
        defer to the parent FP8 config."""
        if not isinstance(layer, FusedMoE):
            return super().get_quant_method(layer, prefix)

        skipped = is_layer_skipped(
            prefix=prefix,
            ignored_layers=self.ignored_layers,
            fused_mapping=self.packed_modules_mapping,
        )
        if skipped:
            return UnquantizedFusedMoEMethod(layer.moe_config)
        return Mxfp4MoEMethod(layer.moe_config)

    def is_mxfp4_quant(self, prefix, layer):
        """Report MXFP4 quantization for fused-MoE layers only."""
        return isinstance(layer, FusedMoE)
# Triton kernel that stages MegaMoE inputs. One program handles one
# (token, BLOCK_K-wide slice of the hidden dimension) pair and:
#   1. quantizes that slice of `hidden_states` into `x_fp8`, using one
#      power-of-two scale per GROUP_K consecutive elements;
#   2. writes the slice's scale exponents, 8 bits each, packed into a
#      single int32 element of `x_sf`;
#   3. for the first slice of each token only, copies the token's top-k
#      ids (cast to int64) and top-k weights into the output buffers.
@triton.jit
def _deepseek_v4_stage_mega_moe_inputs_kernel(
    hidden_states,
    x_fp8,
    x_sf,
    topk_ids,
    topk_weights,
    topk_idx_out,
    topk_weights_out,
    hidden_stride_m: tl.constexpr,
    hidden_stride_k: tl.constexpr,
    x_stride_m: tl.constexpr,
    x_stride_k: tl.constexpr,
    x_sf_stride_m: tl.constexpr,
    x_sf_stride_k: tl.constexpr,
    topk_ids_stride_m: tl.constexpr,
    topk_ids_stride_k: tl.constexpr,
    topk_weights_stride_m: tl.constexpr,
    topk_weights_stride_k: tl.constexpr,
    topk_idx_stride_m: tl.constexpr,
    topk_idx_stride_k: tl.constexpr,
    topk_weights_out_stride_m: tl.constexpr,
    topk_weights_out_stride_k: tl.constexpr,
    hidden_size: tl.constexpr,
    top_k: tl.constexpr,
    BLOCK_K: tl.constexpr,
    GROUP_K: tl.constexpr,
    BLOCK_TOPK: tl.constexpr,
) -> None:
    # Grid axis 0 indexes tokens, axis 1 indexes BLOCK_K-wide hidden slices.
    token_id = tl.program_id(0)
    k_block_id = tl.program_id(1)

    k_offsets = k_block_id * BLOCK_K + tl.arange(0, BLOCK_K)
    k_mask = k_offsets < hidden_size
    # Out-of-range lanes load 0.0; the math below is done in float32.
    hidden = tl.load(
        hidden_states + token_id * hidden_stride_m + k_offsets * hidden_stride_k,
        mask=k_mask,
        other=0.0,
    ).to(tl.float32)

    # Per-group absolute maximum, floored at 1.0e-4 so the scale is never 0.
    num_groups: tl.constexpr = BLOCK_K // GROUP_K
    hidden_groups = tl.reshape(tl.abs(hidden), [num_groups, GROUP_K])
    amax = tl.max(hidden_groups, axis=1)
    amax = tl.maximum(amax, 1.0e-4)
    # NOTE(review): 448.0 is presumably the largest finite float8 e4m3
    # magnitude — confirm.
    scale = amax / 448.0
    # Round the scale up to a power of two by working on its float32 bit
    # pattern: take the 8-bit exponent field and add one whenever any of
    # the 23 mantissa bits is set.
    scale_bits = scale.to(tl.uint32, bitcast=True)
    scale_exp = ((scale_bits >> 23) & 0xFF) + ((scale_bits & 0x7FFFFF) != 0).to(
        tl.uint32
    )
    # Keep the exponent field within [1, 254].
    scale_exp = tl.minimum(tl.maximum(scale_exp, 1), 254)
    # Rebuild a float32 with that exponent and an all-zero mantissa.
    rounded_scale = (scale_exp << 23).to(tl.float32, bitcast=True)

    # Divide every element by its group's rounded scale, then cast to fp8.
    hidden_groups = tl.reshape(hidden, [num_groups, GROUP_K])
    scaled = hidden_groups * (1.0 / rounded_scale)[:, None]
    scaled = tl.reshape(scaled, [BLOCK_K])
    fp8 = scaled.to(tl.float8e4nv)
    tl.store(
        x_fp8 + token_id * x_stride_m + k_offsets * x_stride_k,
        fp8,
        mask=k_mask,
    )

    # Pack the per-group exponents into one int32: group i occupies bits
    # [8 * i, 8 * i + 8). One packed value is stored per (token, k-block).
    scale_offsets = tl.arange(0, num_groups)
    packed_scale = tl.sum(scale_exp << (scale_offsets * 8), axis=0).to(tl.int32)
    tl.store(
        x_sf + token_id * x_sf_stride_m + k_block_id * x_sf_stride_k,
        packed_scale,
    )

    # Only the first k-block program of each token copies the routing data,
    # so every token's top-k row is written exactly once.
    if k_block_id == 0:
        topk_offsets = tl.arange(0, BLOCK_TOPK)
        topk_mask = topk_offsets < top_k
        ids = tl.load(
            topk_ids + token_id * topk_ids_stride_m + topk_offsets * topk_ids_stride_k,
            mask=topk_mask,
            other=0,
        ).to(tl.int64)
        tl.store(
            topk_idx_out
            + token_id * topk_idx_stride_m
            + topk_offsets * topk_idx_stride_k,
            ids,
            mask=topk_mask,
        )
        weights = tl.load(
            topk_weights
            + token_id * topk_weights_stride_m
            + topk_offsets * topk_weights_stride_k,
            mask=topk_mask,
            other=0.0,
        )
        tl.store(
            topk_weights_out
            + token_id * topk_weights_out_stride_m
            + topk_offsets * topk_weights_out_stride_k,
            weights,
            mask=topk_mask,
        )
def _stage_deepseek_v4_mega_moe_inputs(
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
x_fp8: torch.Tensor,
x_sf: torch.Tensor,
topk_idx_out: torch.Tensor,
topk_weights_out: torch.Tensor,
) -> None:
num_tokens, hidden_size = hidden_states.shape
if num_tokens == 0:
return
if hidden_size % 128 != 0:
raise ValueError(
"DeepSeek V4 MegaMoE input staging requires hidden_size to be "
"a multiple of 128."
)
top_k = topk_ids.shape[1]
if topk_weights.shape != topk_ids.shape:
raise ValueError(
"DeepSeek V4 MegaMoE input staging requires topk_weights and "
"topk_ids to have the same shape."
)
block_k = 128
grid = (num_tokens, triton.cdiv(hidden_size, block_k))
block_topk = triton.next_power_of_2(top_k)
_deepseek_v4_stage_mega_moe_inputs_kernel[grid](
hidden_states,
x_fp8,
x_sf,
topk_ids,
topk_weights,
topk_idx_out,
topk_weights_out,
hidden_states.stride(0),
hidden_states.stride(1),
x_fp8.stride(0),
x_fp8.stride(1),
x_sf.stride(0),
x_sf.stride(1),
topk_ids.stride(0),
topk_ids.stride(1),
topk_weights.stride(0),
topk_weights.stride(1),
topk_idx_out.stride(0),
topk_idx_out.stride(1),
topk_weights_out.stride(0),
topk_weights_out.stride(1),
hidden_size,
top_k,
BLOCK_K=block_k,
GROUP_K=32,
BLOCK_TOPK=block_topk,
num_warps=4,
)
def make_deepseek_v4_expert_params_mapping(
    num_experts: int,
) -> list[tuple[str, str, int, str]]:
    """Build the checkpoint-to-parameter mapping for the expert weights.

    For every expert id in ``range(num_experts)`` and every shard in
    ``w1``, ``w2``, ``w3`` (in that order) one tuple is emitted:
    ``(param_prefix, checkpoint_prefix, expert_id, shard_id)``.
    ``param_prefix`` is ``"experts.w13_"`` for ``w1``/``w3`` and
    ``"experts.w2_"`` for ``w2``; ``checkpoint_prefix`` is
    ``"experts.<expert_id>.<shard_id>."``.
    """
    mapping: list[tuple[str, str, int, str]] = []
    for expert_id in range(num_experts):
        for shard_id in ("w1", "w2", "w3"):
            param_prefix = "experts.w2_" if shard_id == "w2" else "experts.w13_"
            ckpt_prefix = f"experts.{expert_id}.{shard_id}."
            mapping.append((param_prefix, ckpt_prefix, expert_id, shard_id))
    return mapping
class DeepseekV4MegaMoEExperts(nn.Module):
_symm_buffer_cache: dict[tuple[int, int, int, int, int, int, int], object] = {}
def __init__(
    self,
    vllm_config: VllmConfig,
    *,
    num_experts: int,
    num_local_experts: int,
    experts_start_idx: int,
    top_k: int,
    hidden_size: int,
    intermediate_size: int,
    prefix: str = "",
):
    """Allocate the zero-initialized uint8 expert weights and scales.

    The local expert range is
    ``[experts_start_idx, experts_start_idx + num_local_experts)``. Weight
    tensors have a last dimension of ``in_features // 2`` and scale tensors
    of ``in_features // 32``. The module also registers itself in the
    compilation config's static forward context under ``prefix``.

    Raises:
        ValueError: if ``prefix`` is already registered in the static
            forward context.
    """
    super().__init__()
    self.prefix = prefix
    self.num_experts = num_experts
    self.num_local_experts = num_local_experts
    self.experts_start_idx = experts_start_idx
    self.experts_end_idx = experts_start_idx + num_local_experts
    self.top_k = top_k
    self.hidden_size = hidden_size
    self.intermediate_size = intermediate_size
    self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens

    loader_attrs = {"weight_loader": self.weight_loader}

    def _new_expert_param(out_features: int, last_dim: int) -> nn.Parameter:
        # Frozen uint8 tensor with one slab per local expert, tagged with
        # this module's weight loader.
        param = nn.Parameter(
            torch.zeros(
                num_local_experts,
                out_features,
                last_dim,
                dtype=torch.uint8,
            ),
            requires_grad=False,
        )
        set_weight_attrs(param, loader_attrs)
        return param

    # Fused w1/w3 projection: 2 * intermediate_size output rows.
    self.w13_weight = _new_expert_param(2 * intermediate_size, hidden_size // 2)
    self.w13_weight_scale = _new_expert_param(
        2 * intermediate_size, hidden_size // 32
    )
    self.w13_weight_scale.quant_method = "block"

    # w2 projection: hidden_size output rows.
    self.w2_weight = _new_expert_param(hidden_size, intermediate_size // 2)
    self.w2_weight_scale = _new_expert_param(hidden_size, intermediate_size // 32)
    self.w2_weight_scale.quant_method = "block"

    # Filled in by finalize_weights().
    self._transformed_l1_weights: tuple[torch.Tensor, torch.Tensor] | None = None
    self._transformed_l2_weights: tuple[torch.Tensor, torch.Tensor] | None = None

    # Register in the static forward context so the custom-op wrapper
    # can look up this module by name from within a torch.compile graph.
    static_ctx = vllm_config.compilation_config.static_forward_context
    if prefix in static_ctx:
        raise ValueError(f"Duplicate layer name: {prefix}")
    static_ctx[prefix] = self
def _map_global_expert_id(self, expert_id: int) -> int:
if expert_id < self.experts_start_idx or expert_id >= self.experts_end_idx:
return -1
return expert_id - self.experts_start_idx
def weight_loader(
self,
param: nn.Parameter,
loaded_weight: torch.Tensor,
weight_name: str,
shard_id: str,
expert_id: int,
return_success: bool = False,
) -> bool | None:
local_expert_id = self._map_global_expert_id(expert_id)
if local_expert_id == -1:
return False if return_success else None
expert_data = param.data[local_expert_id]
if shard_id in ("w1", "w3"):
if "w13_" not in weight_name:
return False if return_success else None
shard_offset = 0 if shard_id == "w1" else self.intermediate_size
expert_data = expert_data.narrow(0, shard_offset, self.intermediate_size)
elif shard_id == "w2":
if "w2_" not in weight_name:
return False if return_success else None
else:
raise ValueError(f"Unsupported expert shard id: {shard_id}")
if expert_data.shape != loaded_weight.shape:
raise ValueError(
f"DeepSeek V4 MegaMoE expert weight shape mismatch for "
f"{weight_name}: parameter shard {tuple(expert_data.shape)} "
f"vs checkpoint {tuple(loaded_weight.shape)}"
)
expert_data.copy_(loaded_weight)
return True if return_success else None
@staticmethod
def _ue8m0_uint8_to_float(sf: torch.Tensor) -> torch.Tensor:
return (sf.to(torch.int32) << 23).view(torch.float32)
def _check_runtime_supported(self) -> None:
if not torch.cuda.is_available():
raise NotImplementedError("DeepSeek V4 MegaMoE requires CUDA.")
device = self.w13_weight.device
if device.type != "cuda":
raise NotImplementedError(
"DeepSeek V4 MegaMoE expert weights must be loaded on CUDA."
)
if torch.cuda.get_device_capability(device)[0] != 10:
raise NotImplementedError("DeepGEMM MegaMoE requires SM100 GPUs.")
if self.hidden_size % 128 != 0 or self.intermediate_size % 128 != 0:
raise ValueError(
"DeepGEMM MegaMoE requires hidden and intermediate sizes "
"to be multiples of 128."
)
def finalize_weights(self) -> None:
    """Convert the loaded expert weights into the MegaMoE kernel layout.

    The conversion runs at most once: if ``_transformed_l1_weights`` is
    already populated the method returns immediately. Otherwise it checks
    runtime support, transforms both scale tensors and both weight
    tensors through DeepGEMM helpers, stores the results in
    ``_transformed_l1_weights`` / ``_transformed_l2_weights`` and sets the
    four loader-side parameters to ``None``.
    """
    if self._transformed_l1_weights is not None:
        return
    self._check_runtime_supported()

    import vllm.third_party.deep_gemm as deep_gemm

    def _scale_in_required_layout(scale_param, rows: int, cols: int):
        # Raw uint8 exponent bytes -> float32 -> DeepGEMM's scale layout.
        as_float = self._ue8m0_uint8_to_float(scale_param.data).contiguous()
        return deep_gemm.transform_sf_into_required_layout(
            as_float,
            rows,
            cols,
            (1, 32),
            self.num_local_experts,
        )

    l1_scale = _scale_in_required_layout(
        self.w13_weight_scale, 2 * self.intermediate_size, self.hidden_size
    )
    l2_scale = _scale_in_required_layout(
        self.w2_weight_scale, self.hidden_size, self.intermediate_size
    )

    l1_inputs = (self.w13_weight.data.view(torch.int8).contiguous(), l1_scale)
    l2_inputs = (self.w2_weight.data.view(torch.int8).contiguous(), l2_scale)
    transformed = deep_gemm.transform_weights_for_mega_moe(l1_inputs, l2_inputs)
    self._transformed_l1_weights, self._transformed_l2_weights = transformed

    # The MegaMoE kernels only consume the transformed tensors, so the
    # loader-side parameters are released here. Per the original author's
    # note, transform_weights_for_mega_moe allocates a fresh L1 weight
    # (see _interleave_l1_weights) and fresh scale tensors for L1/L2; only
    # the L2 weight aliases the original storage, and
    # _transformed_l2_weights keeps that storage alive.
    for stale_name in (
        "w13_weight",
        "w13_weight_scale",
        "w2_weight",
        "w2_weight_scale",
    ):
        setattr(self, stale_name, None)
def get_symm_buffer(self):
    """Return the symmetric buffer for this configuration, creating it on demand.

    Buffers are memoised in ``self._symm_buffer_cache`` under a key built
    from the EP device group's ``id``, the current accelerator device
    index and the expert/token/shape settings of this module.
    """
    import vllm.third_party.deep_gemm as deep_gemm

    ep_device_group = get_ep_group().device_group
    device_index = torch.accelerator.current_device_index()
    shape_args = (
        self.num_experts,
        self.max_num_tokens,
        self.top_k,
        self.hidden_size,
        self.intermediate_size,
    )
    cache_key = (id(ep_device_group), device_index, *shape_args)

    cached = self._symm_buffer_cache.get(cache_key)
    if cached is not None:
        return cached

    fresh = deep_gemm.get_symm_buffer_for_mega_moe(ep_device_group, *shape_args)
    self._symm_buffer_cache[cache_key] = fresh
    return fresh
def forward(
self,
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
*,
activation_clamp: float | None,
fast_math: bool = True,
) -> torch.Tensor:
if hidden_states.shape[0] > self.max_num_tokens:
raise ValueError(
f"DeepSeek V4 MegaMoE got {hidden_states.shape[0]} tokens, "
f"but the symmetric buffer was sized for {self.max_num_tokens}."
)
y = torch.empty_like(hidden_states, dtype=torch.bfloat16)
torch.ops.vllm.deepseek_v4_mega_moe_experts(
hidden_states,
topk_weights,
topk_ids,
y,
self.prefix,
activation_clamp,
fast_math,
)
return y
def _run_mega_moe(
    self,
    hidden_states: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    y: torch.Tensor,
    activation_clamp: float | None,
    fast_math: bool,
) -> None:
    """Stage the inputs into the symmetric buffer and launch the kernel.

    The result is written into ``y``; nothing is returned.
    """
    import vllm.third_party.deep_gemm as deep_gemm

    buffer = self.get_symm_buffer()
    token_count = hidden_states.shape[0]

    # Only the first ``token_count`` rows of each buffer tensor are staged.
    staging_targets = (
        buffer.x[:token_count],
        buffer.x_sf[:token_count],
        buffer.topk_idx[:token_count],
        buffer.topk_weights[:token_count],
    )
    _stage_deepseek_v4_mega_moe_inputs(
        hidden_states,
        topk_weights,
        topk_ids,
        *staging_targets,
    )

    # Normally the weights were already finalized during weight loading;
    # the call is repeated here to cover the dummy weight loading case.
    self.finalize_weights()
    l1_weights = self._transformed_l1_weights
    l2_weights = self._transformed_l2_weights
    assert l1_weights is not None
    assert l2_weights is not None

    deep_gemm.fp8_fp4_mega_moe(
        y,
        l1_weights,
        l2_weights,
        buffer,
        activation_clamp=activation_clamp,
        fast_math=fast_math,
    )
# Tag the loader function object itself with ``supports_moe_loading = True``.
# NOTE(review): presumably read by weight-loading code elsewhere to detect
# loaders that accept the MoE-style arguments -- confirm against the consumer.
DeepseekV4MegaMoEExperts.weight_loader.supports_moe_loading = True # type: ignore[attr-defined]
def _deepseek_v4_mega_moe_experts_op(
    hidden_states: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    out: torch.Tensor,
    layer_name: str,
    activation_clamp: float | None,
    fast_math: bool,
) -> None:
    """Custom-op body: resolve the experts module by name and run it.

    The module is fetched from the forward context's ``no_compile_layers``
    mapping under ``layer_name``; its ``_run_mega_moe`` writes into ``out``.
    """
    registered_layers = get_forward_context().no_compile_layers
    experts_module = registered_layers[layer_name]
    experts_module._run_mega_moe(
        hidden_states,
        topk_weights,
        topk_ids,
        out,
        activation_clamp,
        fast_math,
    )
def _deepseek_v4_mega_moe_experts_op_fake(
hidden_states: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
out: torch.Tensor,
layer_name: str,
activation_clamp: float | None,
fast_math: bool,
) -> None:
return None
# Register the MegaMoE experts op under the name "deepseek_v4_mega_moe_experts".
# The real implementation and the fake implementation defined above are both
# supplied, and "out" is declared as the one argument the op mutates.
# DeepseekV4MegaMoEExperts.forward invokes it as
# torch.ops.vllm.deepseek_v4_mega_moe_experts.
direct_register_custom_op(
op_name="deepseek_v4_mega_moe_experts",
op_func=_deepseek_v4_mega_moe_experts_op,
mutates_args=["out"],
fake_impl=_deepseek_v4_mega_moe_experts_op_fake,
)
class DeepseekV4MoE(nn.Module):
    """Mixture-of-experts feed-forward block for DeepSeek V4.

    Two execution paths exist, selected once in ``__init__``:

    * ``use_mega_moe`` true: routing is done here with ``fused_topk_bias``
      and the routed experts run through ``DeepseekV4MegaMoEExperts``; the
      shared-expert output is added afterwards.
    * otherwise: a ``FusedMoE`` module receives the gate and the shared
      experts and handles the whole block.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        prefix: str = "",
    ):
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        hf_config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.prefix = prefix

        # The MegaMoE backend is only considered when expert parallelism
        # is enabled; without it the flag is always False.
        self.use_mega_moe = (
            vllm_config.kernel_config.moe_backend == "deep_gemm_mega_moe"
            if vllm_config.parallel_config.enable_expert_parallel
            else False
        )

        self.routed_scaling_factor = getattr(hf_config, "routed_scaling_factor", 1.0)
        self.hidden_size = hf_config.hidden_size
        self.n_routed_experts = hf_config.n_routed_experts
        self.n_activated_experts = hf_config.num_experts_per_tok
        self.moe_intermediate_size = hf_config.moe_intermediate_size
        self.swiglu_limit = hf_config.swiglu_limit
        self.renormalize = hf_config.norm_topk_prob
        self.scoring_func = getattr(hf_config, "scoring_func", "sqrtsoftplus")
        if self.use_mega_moe and self.scoring_func != "sqrtsoftplus":
            raise NotImplementedError(
                "DeepSeek V4 MegaMoE currently supports sqrtsoftplus routing only."
            )

        self.gate = GateLinear(
            hf_config.hidden_size,
            hf_config.n_routed_experts,
            out_dtype=torch.float32,
            bias=False,
            prefix=f"{prefix}.gate",
        )
        # Both optional routing tensors start out absent; at most one of
        # them is replaced by a parameter below.
        self.gate.e_score_correction_bias = None
        self.gate.tid2eid = None

        uses_hash_routing = extract_layer_index(prefix) < hf_config.num_hash_layers
        self.hash_indices_dtype = torch.int64 if self.use_mega_moe else torch.int32
        if uses_hash_routing:
            # hash MoE doesn't use e_score_correction_bias.
            # randint (rather than empty) keeps every entry a valid expert
            # id, avoiding invalid memory access from garbage values in
            # dummy mode (--load-format="dummy").
            placeholder_table = torch.randint(
                0,
                hf_config.n_routed_experts,
                (hf_config.vocab_size, hf_config.num_experts_per_tok),
                dtype=self.hash_indices_dtype,
            )
            self.gate.tid2eid = nn.Parameter(placeholder_table, requires_grad=False)
        elif getattr(hf_config, "topk_method", None) == "noaux_tc":
            bias_storage = torch.empty(hf_config.n_routed_experts, dtype=torch.float32)
            self.gate.e_score_correction_bias = nn.Parameter(
                bias_storage, requires_grad=False
            )

        if hf_config.n_shared_experts is not None:
            shared_intermediate = (
                hf_config.moe_intermediate_size * hf_config.n_shared_experts
            )
            self.shared_experts = DeepseekV2MLP(
                hidden_size=hf_config.hidden_size,
                intermediate_size=shared_intermediate,
                hidden_act=hf_config.hidden_act,
                quant_config=quant_config,
                reduce_results=self.use_mega_moe,
                prefix=f"{prefix}.shared_experts",
            )
        else:
            self.shared_experts = None

        if self.use_mega_moe:
            self._init_mega_moe_experts(vllm_config, hf_config, prefix)
        else:
            self._init_fused_moe_experts(hf_config, quant_config, prefix)

    def _init_mega_moe_experts(
        self,
        vllm_config: VllmConfig,
        config,
        prefix: str,
    ) -> None:
        """Create the MegaMoE experts, partitioned over the EP group."""
        self.ep_group = get_ep_group()
        self.ep_size = self.ep_group.world_size
        self.ep_rank = self.ep_group.rank_in_group
        assert config.n_routed_experts % self.ep_size == 0

        experts_per_rank = config.n_routed_experts // self.ep_size
        first_expert = self.ep_rank * experts_per_rank
        self.n_local_experts = experts_per_rank
        self.experts_start_idx = first_expert
        self.experts_end_idx = first_expert + experts_per_rank

        self.experts = DeepseekV4MegaMoEExperts(
            vllm_config,
            num_experts=config.n_routed_experts,
            num_local_experts=self.n_local_experts,
            experts_start_idx=self.experts_start_idx,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            prefix=f"{prefix}.experts",
        )

    def _init_fused_moe_experts(
        self,
        config,
        quant_config,
        prefix: str,
    ) -> None:
        """Create the ``FusedMoE`` experts; index bookkeeping uses the TP rank."""
        self.tp_rank = get_tensor_model_parallel_rank()
        assert config.n_routed_experts % self.tp_size == 0

        experts_per_rank = config.n_routed_experts // self.tp_size
        first_expert = self.tp_rank * experts_per_rank
        self.n_local_experts = experts_per_rank
        self.experts_start_idx = first_expert
        self.experts_end_idx = first_expert + experts_per_rank

        self.experts = FusedMoE(
            shared_experts=self.shared_experts,
            gate=self.gate,
            num_experts=config.n_routed_experts,
            top_k=config.num_experts_per_tok,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            renormalize=config.norm_topk_prob,
            quant_config=quant_config,
            prefix=f"{prefix}.experts",
            scoring_func=self.scoring_func,
            routed_scaling_factor=self.routed_scaling_factor,
            e_score_correction_bias=self.gate.e_score_correction_bias,
            hash_indices_table=self.gate.tid2eid,
            swiglu_limit=self.swiglu_limit,
            router_logits_dtype=torch.float32,
        )

    def forward(
        self, hidden_states: torch.Tensor, input_ids: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Apply the MoE block and return a tensor shaped like the input.

        Raises:
            ValueError: If this layer has a ``tid2eid`` table but no
                ``input_ids`` were supplied.
        """
        if self.gate.tid2eid is not None:
            if input_ids is None:
                raise ValueError("DeepSeek V4 hash MoE routing requires input_ids.")
            input_ids = input_ids.to(dtype=self.hash_indices_dtype)

        if not self.use_mega_moe:
            return self._forward_fused_moe(hidden_states, input_ids)

        input_shape = hidden_states.shape
        router_logits, _ = self.gate(hidden_states)

        correction_bias = self.gate.e_score_correction_bias
        topk_weights, topk_ids = fused_topk_bias(
            hidden_states=hidden_states,
            gating_output=router_logits,
            scoring_func=self.scoring_func,
            e_score_correction_bias=(
                correction_bias.data if correction_bias is not None else None
            ),
            topk=self.n_activated_experts,
            renormalize=self.renormalize,
            indices_type=self.hash_indices_dtype,
            input_tokens=input_ids,
            hash_indices_table=self.gate.tid2eid,
            routed_scaling_factor=self.routed_scaling_factor,
        )

        if self.swiglu_limit is None:
            activation_clamp = None
        else:
            activation_clamp = float(self.swiglu_limit)

        routed_output = self.experts(
            hidden_states,
            topk_weights,
            topk_ids,
            activation_clamp=activation_clamp,
        )
        if self.shared_experts is not None:
            # In-place accumulate of the shared-expert contribution.
            routed_output += self.shared_experts(hidden_states)
        return routed_output.view(input_shape)

    def _forward_fused_moe(
        self, hidden_states: torch.Tensor, input_ids: torch.Tensor | None = None
    ) -> torch.Tensor:
        """Run the ``FusedMoE`` path and restore the input shape."""
        input_shape = hidden_states.shape
        if self.experts.is_internal_router:
            # The gate/router runs inside the FusedMoE class, which is
            # therefore handed the raw hidden states as "router_logits".
            router_input = hidden_states
        else:
            router_input, _ = self.gate(hidden_states)

        fused_output = self.experts(
            hidden_states=hidden_states,
            router_logits=router_input,
            input_ids=input_ids,
        )
        return fused_output.view(input_shape)

    def finalize_mega_moe_weights(self) -> None:
        """Finalize expert weights on the MegaMoE path; no-op otherwise."""
        if not self.use_mega_moe:
            return
        self.experts.finalize_weights()
# Attention block for DeepSeek V4. __init__ builds the projection layers, the
# rotary embedding and (for compress_ratio == 4) an indexer, bundles them in
# DeepseekV4MLAModules and hands that to
# DeepseekV4MultiHeadLatentAttentionWrapper; forward() delegates to the wrapper.
class DeepseekV4Attention(nn.Module):
# Parameters:
#   vllm_config: source of the HF config, quant config and cache config.
#   prefix: module prefix; the layer index is parsed out of it.
#   topk_indices_buffer: forwarded to the indexer and to the MLA modules.
#   aux_stream: forwarded to the MLA modules.
def __init__(
self,
vllm_config: VllmConfig,
prefix: str,
topk_indices_buffer: torch.Tensor | None = None,
aux_stream: torch.cuda.Stream | None = None,
):
super().__init__()
config = vllm_config.model_config.hf_config
quant_config = vllm_config.quant_config
layer_id = extract_layer_index(prefix)
self.layer_id = layer_id
self.hidden_size = config.hidden_size
self.n_heads = config.num_attention_heads
tp_size = get_tensor_model_parallel_world_size()
assert self.n_heads % tp_size == 0
self.n_local_heads = self.n_heads // tp_size
self.q_lora_rank = config.q_lora_rank
self.o_lora_rank = config.o_lora_rank
self.head_dim = config.head_dim
self.rope_head_dim = config.qk_rope_head_dim
# The non-rotary part of each head is whatever remains of head_dim.
self.nope_head_dim = self.head_dim - self.rope_head_dim
self.n_groups = config.o_groups
self.n_local_groups = self.n_groups // tp_size
self.window_size = config.sliding_window
# NOTE(zyongye) Compress ratio can't be 0
# we do this for because MTP layer is not included
# in the compress ratio list
if layer_id < config.num_hidden_layers:
self.compress_ratio = max(1, config.compress_ratios[layer_id])
else:
self.compress_ratio = 1
self.eps = config.rms_norm_eps
self.max_position_embeddings = config.max_position_embeddings
# Padded to min 64 heads for FlashMLA, initialized to -inf
# (no sink effect). Weight loading fills the first n_local_heads slots.
padded_heads = max(self.n_local_heads, 64)
self.attn_sink = nn.Parameter(
torch.full((padded_heads,), -float("inf"), dtype=torch.float32),
requires_grad=False,
)
# One merged projection yields both outputs; their widths are
# q_lora_rank and head_dim. TP sharding is disabled for it.
self.fused_wqa_wkv = MergedColumnParallelLinear(
self.hidden_size,
[self.q_lora_rank, self.head_dim],
bias=False,
quant_config=quant_config,
prefix=f"{prefix}.fused_wqa_wkv",
disable_tp=True, # fused ReplicatedLinear
)
self.q_norm = RMSNorm(self.q_lora_rank, self.eps)
self.wq_b = ColumnParallelLinear(
self.q_lora_rank,
self.n_heads * self.head_dim,
bias=False,
quant_config=quant_config,
return_bias=False,
prefix=f"{prefix}.wq_b",
)
self.kv_norm = RMSNorm(self.head_dim, self.eps)
self.wo_a = ColumnParallelLinear(
self.n_heads * self.head_dim // self.n_groups,
self.n_groups * self.o_lora_rank,
bias=False,
quant_config=quant_config,
return_bias=False,
prefix=f"{prefix}.wo_a",
)
# Extra attributes attached to wo_a after construction.
# NOTE(review): presumably consumed by the quantization / MLA code to run
# wo_a as a batched matmul over the local groups -- confirm at the consumer.
self.wo_a.is_bmm = True
self.wo_a.bmm_batch_size = self.n_local_groups
self.wo_b = RowParallelLinear(
self.n_groups * self.o_lora_rank,
self.hidden_size,
bias=False,
quant_config=quant_config,
return_bias=False,
prefix=f"{prefix}.wo_b",
)
self.softmax_scale = self.head_dim**-0.5
self.scale_fmt = config.quantization_config["scale_fmt"]
self.rope_parameters = config.rope_scaling
# Initialize rotary embedding BEFORE DeepseekV4MLAModules (which needs it)
# NOTE: no copy is taken here -- `rope_parameters` is the object returned by
# config.rope_parameters, so the item assignments below write into it.
rope_parameters = config.rope_parameters
# Layers with compress_ratio > 1 use compress_rope_theta, others rope_theta.
rope_parameters["rope_theta"] = (
config.compress_rope_theta if self.compress_ratio > 1 else config.rope_theta
)
if config.rope_parameters["rope_type"] != "default":
config.rope_parameters["rope_type"] = (
"deepseek_yarn"
if config.rope_parameters.get("apply_yarn_scaling", True)
else "deepseek_llama_scaling"
)
rope_parameters["mscale"] = 0 # Disable mscale
rope_parameters["mscale_all_dim"] = 0 # Disable mscale
rope_parameters["is_deepseek_v4"] = True
rope_parameters["rope_dim"] = self.rope_head_dim
self.rotary_emb = get_rope(
self.head_dim,
max_position=self.max_position_embeddings,
rope_parameters=rope_parameters,
is_neox_style=False,
dtype=config.torch_dtype,
)
self.indexer = None
if self.compress_ratio == 4:
# Only C4A uses sparse attention and hence has indexer.
self.indexer = DeepseekV4Indexer(
vllm_config,
config=config,
hidden_size=self.hidden_size,
q_lora_rank=self.q_lora_rank,
quant_config=quant_config,
cache_config=vllm_config.cache_config,
topk_indices_buffer=topk_indices_buffer,
compress_ratio=self.compress_ratio,
prefix=f"{prefix}.indexer",
)
# Bundle the sub-modules for the wrapper. The indexer is given the same
# rotary embedding instance as the main attention path.
mla_modules = DeepseekV4MLAModules(
vllm_config=vllm_config,
fused_wqa_wkv=self.fused_wqa_wkv,
q_norm=self.q_norm,
wq_b=self.wq_b,
kv_norm=self.kv_norm,
wo_a=self.wo_a,
wo_b=self.wo_b,
attn_sink=self.attn_sink,
rotary_emb=self.rotary_emb,
indexer=self.indexer,
indexer_rotary_emb=self.rotary_emb,
topk_indices_buffer=topk_indices_buffer,
aux_stream=aux_stream,
)
# head_dim is passed as v_head_dim and as kv_lora_rank as well as head_dim.
self.mla_attn = DeepseekV4MultiHeadLatentAttentionWrapper(
hidden_size=self.hidden_size,
num_heads=self.n_local_heads,
head_dim=self.head_dim,
scale=self.softmax_scale,
qk_nope_head_dim=self.nope_head_dim,
qk_rope_head_dim=self.rope_head_dim,
v_head_dim=self.head_dim,
q_lora_rank=self.q_lora_rank,
kv_lora_rank=self.head_dim,
o_lora_rank=self.o_lora_rank,
mla_modules=mla_modules,
window_size=self.window_size,
compress_ratio=self.compress_ratio,
cache_config=vllm_config.cache_config,
quant_config=quant_config,
prefix=prefix,
)
# Delegates directly to the MLA wrapper; the three arguments are passed
# through unchanged and the wrapper's result is returned as-is.
def forward(
self,
positions: torch.Tensor,
hidden_states: torch.Tensor,
llama_4_scaling: torch.Tensor | None,
):
return self.mla_attn(positions, hidden_states, llama_4_scaling)
class DeepseekV4DecoderLayer(nn.Module):
    """One DeepSeek V4 decoder layer.

    The layer runs an attention sublayer followed by an MoE sublayer. Each
    sublayer is bracketed by the ``mhc_pre`` / ``mhc_post`` custom ops,
    which use the layer's ``hc_*`` parameters.
    """

    def __init__(
        self,
        vllm_config,
        prefix,
        topk_indices_buffer: torch.Tensor | None = None,
        aux_stream_dict: dict[AuxStreamType, torch.cuda.Stream] | None = None,
    ):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        self.hidden_size = hf_config.hidden_size
        self.rms_norm_eps = hf_config.rms_norm_eps

        # The attention module gets its own aux stream only when a stream
        # dictionary was supplied.
        attention_stream = None
        if aux_stream_dict is not None:
            attention_stream = aux_stream_dict.get(AuxStreamType.Attention)
        self.attn = DeepseekV4Attention(
            vllm_config,
            prefix=f"{prefix}.attn",
            topk_indices_buffer=topk_indices_buffer,
            aux_stream=attention_stream,
        )
        self.ffn = DeepseekV4MoE(vllm_config, prefix=f"{prefix}.ffn")
        self.attn_norm = RMSNorm(self.hidden_size, self.rms_norm_eps)
        self.ffn_norm = RMSNorm(self.hidden_size, self.rms_norm_eps)

        self.hc_mult = hf_config.hc_mult
        self.hc_sinkhorn_iters = hf_config.hc_sinkhorn_iters
        self.hc_eps = hf_config.hc_eps
        self.hc_post_alpha = 2.0

        mix_hc = (2 + self.hc_mult) * self.hc_mult
        hc_dim = self.hc_mult * self.hidden_size

        def _frozen_fp32(*shape: int) -> nn.Parameter:
            # Uninitialised float32 storage; filled by weight loading.
            return nn.Parameter(
                torch.empty(shape, dtype=torch.float32),
                requires_grad=False,
            )

        # Registration order is kept: fn, base, scale -- attn before ffn.
        self.hc_attn_fn = _frozen_fp32(mix_hc, hc_dim)
        self.hc_ffn_fn = _frozen_fp32(mix_hc, hc_dim)
        self.hc_attn_base = _frozen_fp32(mix_hc)
        self.hc_ffn_base = _frozen_fp32(mix_hc)
        self.hc_attn_scale = _frozen_fp32(3)
        self.hc_ffn_scale = _frozen_fp32(3)

    def hc_pre(
        self,
        x: torch.Tensor,
        hc_fn: torch.Tensor,
        hc_scale: torch.Tensor,
        hc_base: torch.Tensor,
    ):
        """Call ``torch.ops.vllm.mhc_pre`` on the residual ``x``.

        Returns:
            ``(layer_input, post_mix, res_mix)`` -- the op's outputs,
            reordered so the sublayer input comes first.
        """
        # Lazy import to avoid a top-level tilelang dependency. Importing
        # the module registers both torch.ops.vllm.mhc_pre and mhc_post,
        # so hc_post() doesn't need its own import.
        import vllm.model_executor.layers.mhc  # noqa: F401

        op_outputs = torch.ops.vllm.mhc_pre(
            residual=x,
            fn=hc_fn,
            hc_scale=hc_scale,
            hc_base=hc_base,
            rms_eps=self.rms_norm_eps,
            hc_pre_eps=self.hc_eps,
            hc_sinkhorn_eps=self.hc_eps,
            hc_post_mult_value=self.hc_post_alpha,
            sinkhorn_repeat=self.hc_sinkhorn_iters,
        )
        post_mix, res_mix, layer_input = op_outputs
        return layer_input, post_mix, res_mix

    def hc_post(
        self,
        x: torch.Tensor,
        residual: torch.Tensor,
        post: torch.Tensor,
        comb: torch.Tensor,
    ):
        """Call ``torch.ops.vllm.mhc_post`` and return its result."""
        mhc_post_op = torch.ops.vllm.mhc_post
        return mhc_post_op(x, residual, post, comb)

    def forward(
        self,
        x: torch.Tensor,
        positions: torch.Tensor,
        input_ids: torch.Tensor | None,
    ) -> torch.Tensor:
        """Run the attention sublayer, then the MoE sublayer."""
        # --- attention sublayer ---
        attn_residual = x
        attn_in, attn_post, attn_comb = self.hc_pre(
            x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base
        )
        attn_in = self.attn_norm(attn_in)
        attn_out = self.attn(positions, attn_in, None)
        x = self.hc_post(attn_out, attn_residual, attn_post, attn_comb)

        # --- MoE sublayer ---
        ffn_residual = x
        ffn_in, ffn_post, ffn_comb = self.hc_pre(
            x, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base
        )
        ffn_in = self.ffn_norm(ffn_in)
        ffn_out = self.ffn(ffn_in, input_ids)
        return self.hc_post(ffn_out, ffn_residual, ffn_post, ffn_comb)
@support_torch_compile
class DeepseekV4Model(nn.Module):
    """DeepSeek V4 decoder stack: embedding, decoder layers, ``hc_head``
    reduction and final norm, plus checkpoint weight loading."""

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        config = vllm_config.model_config.hf_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.vocab_size = config.vocab_size
        self.hc_eps = config.hc_eps
        self.hc_mult = config.hc_mult
        self.hc_dim = self.hc_mult * config.hidden_size
        self.rms_norm_eps = config.rms_norm_eps

        # A single auxiliary CUDA stream, shared by every decoder layer.
        aux_stream_list = [torch.cuda.Stream() for _ in range(1)]
        self.aux_stream_dict = {
            AuxStreamType.Attention: aux_stream_list[0],
        }

        self.device = current_platform.device_type
        # Reserved topk indices buffer for all Indexer layers to reuse.
        self.topk_indices_buffer = torch.empty(
            vllm_config.scheduler_config.max_num_batched_tokens,
            config.index_topk,
            dtype=torch.int32,
            device=self.device,
        )

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=f"{prefix}.embed_tokens",
        )

        self.start_layer, self.end_layer, self.layers = make_layers(
            config.num_hidden_layers,
            lambda prefix: DeepseekV4DecoderLayer(
                vllm_config,
                prefix=prefix,
                topk_indices_buffer=self.topk_indices_buffer,
                aux_stream_dict=self.aux_stream_dict,
            ),
            prefix=f"{prefix}.layers",
        )

        self.norm = RMSNorm(config.hidden_size, self.rms_norm_eps)

        # Parameters consumed by hc_head() in forward().
        self.hc_head_fn = nn.Parameter(
            torch.empty(
                self.hc_mult,
                self.hc_dim,
                dtype=torch.float32,
            ),
            requires_grad=False,
        )
        self.hc_head_base = nn.Parameter(
            torch.empty(
                self.hc_mult,
                dtype=torch.float32,
            ),
            requires_grad=False,
        )
        self.hc_head_scale = nn.Parameter(
            torch.empty(1, dtype=torch.float32),
            requires_grad=False,
        )

        # Pre-hc_head residual stream buffer for the MTP draft. Stable
        # address (outside the cudagraph pool) so the copy_ in forward()
        # refreshes it correctly across captured shapes.
        self._mtp_hidden_buffer = torch.empty(
            vllm_config.scheduler_config.max_num_batched_tokens,
            self.hc_dim,
            dtype=vllm_config.model_config.dtype,
            device=self.device,
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings for ``input_ids``."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the decoder stack and return the normalised hidden states.

        ``intermediate_tensors`` and ``inputs_embeds`` are accepted but not
        read by this method; embeddings are always computed from
        ``input_ids``, which is also passed to every layer.
        """
        hidden_states = self.embed_input_ids(input_ids)
        # Replicate each token embedding hc_mult times along a new
        # second-to-last dimension.
        hidden_states = hidden_states.unsqueeze(-2).repeat(1, self.hc_mult, 1)

        for layer in islice(self.layers, self.start_layer, self.end_layer):
            hidden_states = layer(
                hidden_states,
                positions,
                input_ids,
            )

        # Stash pre-hc_head residual for the MTP draft (captured copy_).
        num_tokens = hidden_states.shape[0]
        self._mtp_hidden_buffer[:num_tokens].copy_(hidden_states.flatten(1))

        hidden_states = hc_head(
            hidden_states,
            self.hc_head_fn,
            self.hc_head_scale,
            self.hc_head_base,
            self.rms_norm_eps,
            self.hc_eps,
        )
        hidden_states = self.norm(hidden_states)
        return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint tensors into this module's parameters.

        Each ``(name, tensor)`` pair is handled by the first applicable rule:

        1. non-expert weights matching ``stacked_params_mapping`` are loaded
           as one shard of a fused parameter;
        2. names containing ``".experts."`` are dispatched through the
           expert mapping;
        3. ``attn_sink`` tensors are sliced to this TP rank's heads and
           copied into the leading slots of the parameter;
        4. everything else uses the parameter's own ``weight_loader`` or
           ``default_weight_loader``.

        Returns:
            The set of parameter names that were touched.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w1", 0),
            ("gate_up_proj", "w3", 1),
            ("attn.fused_wqa_wkv", "attn.wq_a", 0),
            ("attn.fused_wqa_wkv", "attn.wkv", 1),
            ("compressor.fused_wkv_wgate", "compressor.wkv", 0),
            ("compressor.fused_wkv_wgate", "compressor.wgate", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()

        # TP for attention
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        n_head = self.config.num_attention_heads
        n_local_head = n_head // tp_size
        head_rank_start = n_local_head * tp_rank
        head_rank_end = n_local_head * (tp_rank + 1)

        # Pre-compute expert mapping ONCE.
        expert_mapping = self.get_expert_mapping()

        for name, loaded_weight in weights:
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if ".experts." in name:
                    continue
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                break
            else:
                if ".experts." in name:
                    # E8M0 scales are stored as float8_e8m0fnu in
                    # checkpoints but the MoE param is uint8. copy_()
                    # would do a numeric conversion (e.g. 2^-7 → 0),
                    # destroying the raw exponent bytes.
                    if (
                        "weight_scale" in name
                        and loaded_weight.dtype == torch.float8_e8m0fnu
                    ):
                        loaded_weight = loaded_weight.view(torch.uint8)

                    # Reset per weight: without this, a name that matches no
                    # mapping entry would read an unbound variable (first
                    # weight) or the previous weight's stale value.
                    name_mapped: str | None = None
                    for mapping in expert_mapping:
                        param_name, weight_name, expert_id, shard_id = mapping
                        if weight_name not in name:
                            continue
                        name_mapped = name.replace(weight_name, param_name)
                        param = params_dict[name_mapped]
                        # We should ask the weight loader to return success or not
                        # here since otherwise we may skip experts with other
                        # available replicas.
                        weight_loader = typing.cast(
                            Callable[..., bool], param.weight_loader
                        )
                        success = weight_loader(
                            param,
                            loaded_weight,
                            name_mapped,
                            shard_id=shard_id,
                            expert_id=expert_id,
                            return_success=True,
                        )
                        if success:
                            name = name_mapped
                            break
                    # As before, the last mapped name is recorded whether or
                    # not the loader reported success (the expert may not be
                    # held locally); only the "nothing matched" case is new.
                    if name_mapped is not None:
                        loaded_params.add(name_mapped)
                    continue
                elif "attn_sink" in name:
                    # Keep only this rank's heads; the parameter may be
                    # longer than that, so only its leading slots are written.
                    narrow_weight = loaded_weight[head_rank_start:head_rank_end]
                    n = narrow_weight.shape[0]
                    params_dict[name][:n].copy_(narrow_weight)
                    loaded_params.add(name)
                    continue
                else:
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
                    loaded_params.add(name)
                    continue

        return loaded_params

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return ``(param_name, weight_name, expert_id, shard_id)`` tuples.

        The first local decoder layer decides which mapping is used: the
        MegaMoE-specific one or the generic ``FusedMoE`` one.
        """
        first_layer = next(iter(islice(self.layers, self.start_layer, self.end_layer)))
        if first_layer.ffn.use_mega_moe:
            return make_deepseek_v4_expert_params_mapping(self.config.n_routed_experts)

        # Params for weights, fp8 weight scales, fp8 activation scales
        # (param_name, weight_name, expert_id, shard_id)
        return FusedMoE.make_expert_params_mapping(
            self,
            ckpt_gate_proj_name="w1",
            ckpt_down_proj_name="w2",
            ckpt_up_proj_name="w3",
            num_experts=self.config.n_routed_experts,
        )

    def finalize_mega_moe_weights(self) -> None:
        """Ask every local decoder layer's MoE block to finalize its weights."""
        for layer in islice(self.layers, self.start_layer, self.end_layer):
            layer.ffn.finalize_mega_moe_weights()
@torch.compile(backend=current_platform.simple_compile_backend)
def hc_head(
    hidden_states: torch.Tensor,
    hc_fn: torch.Tensor,
    hc_scale: torch.Tensor,
    hc_base: torch.Tensor,
    rms_norm_eps: float,
    hc_eps: float,
) -> torch.Tensor:
    """Collapse dimension 1 of ``hidden_states`` with learned sigmoid gates.

    The input is flattened from dimension 1 onward and cast to float32.
    Gate logits are a linear map of that flat vector, scaled by its
    reciprocal RMS; the gates are ``sigmoid(logits * hc_scale + hc_base)
    + hc_eps``. The result is the gate-weighted sum over dimension 1 of
    the (float32) input, cast back to the input dtype.

    NOTE(review): callers in this file pass a
    ``(num_tokens, hc_mult, hidden_size)`` tensor -- confirm for others.
    """
    original_shape = hidden_states.size()
    original_dtype = hidden_states.dtype

    flat = hidden_states.flatten(1).float()
    mean_square = flat.square().mean(-1, keepdim=True)
    inv_rms = torch.rsqrt(mean_square + rms_norm_eps)

    gate_logits = F.linear(flat, hc_fn) * inv_rms
    gates = torch.sigmoid(gate_logits * hc_scale + hc_base) + hc_eps

    streams = flat.view(original_shape)
    reduced = torch.sum(gates.unsqueeze(-1) * streams, dim=1)
    return reduced.to(original_dtype)
class DeepseekV4ForCausalLM(nn.Module):
    """Causal-LM wrapper: a ``DeepseekV4Model`` plus LM head and logits
    processor, with the checkpoint-name mapping used at load time."""

    model_cls = DeepseekV4Model

    # Checkpoint name -> vLLM parameter name rewriting rules.
    hf_to_vllm_mapper = WeightsMapper(
        orig_to_new_prefix={
            "layers.": "model.layers.",
            "embed.": "model.embed.",
            "norm.": "model.norm.",
            "hc_head": "model.hc_head",
            "mtp.": "model.mtp.",
        },
        orig_to_new_regex={
            # Routed MoE expert scales: experts.N.wX.scale -> .weight_scale
            re.compile(r"(\.experts\.\d+\.w[123])\.scale$"): r"\1.weight_scale",
            # Everything else (FP8 linear + shared experts): .scale -> .weight_scale_inv
            re.compile(r"\.scale$"): ".weight_scale_inv",
        },
        orig_to_new_suffix={
            "head.weight": "lm_head.weight",
            "embed.weight": "embed_tokens.weight",
            ".ffn.gate.bias": ".ffn.gate.e_score_correction_bias",
        },
        orig_to_new_substr={
            ".attn.compressor.": ".attn.mla_attn.compressor.",
            ".shared_experts.w2": ".shared_experts.down_proj",
        },
    )

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        self.config = hf_config

        model_prefix = maybe_prefix(prefix, "model")
        self.model = self.model_cls(vllm_config=vllm_config, prefix=model_prefix)

        head_prefix = maybe_prefix(prefix, "lm_head")
        self.lm_head = ParallelLMHead(
            hf_config.vocab_size,
            hf_config.hidden_size,
            prefix=head_prefix,
        )
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Delegate embedding lookup to the inner model."""
        return self.model.embed_input_ids(input_ids)

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project hidden states to logits via the logits processor."""
        return self.logits_processor(self.lm_head, hidden_states)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the inner model and return its output unchanged."""
        return self.model(input_ids, positions, intermediate_tensors, inputs_embeds)

    def get_mtp_target_hidden_states(self) -> torch.Tensor | None:
        """Return the pre-hc_head residual stream buffer for the MTP draft.

        The buffer has shape (max_num_batched_tokens, hc_mult * hidden_size)
        and is populated by forward(), so it is valid after each target
        step. ``None`` is returned if the model has no such buffer.
        """
        return getattr(self.model, "_mtp_hidden_buffer", None)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load weights (skipping ``mtp.`` entries), then finalize MegaMoE."""
        weights_loader = AutoWeightsLoader(self, skip_substrs=["mtp."])
        loaded_names = weights_loader.load_weights(
            weights, mapper=self.hf_to_vllm_mapper
        )
        self.model.finalize_mega_moe_weights()
        return loaded_names

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Delegate to the inner model's expert mapping."""
        return self.model.get_expert_mapping()
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""MTP draft model for DeepSeek V4 (internal codename: DeepseekV4).
Split from ``deepseek_mtp.py`` because the V4 architecture introduces several
pieces that have no analogue in V3/V32:
* separate ``e_proj`` / ``h_proj`` with fp8 linear quantization (instead of
the fused ``eh_proj``);
* ``hc_head`` hypercompressed vocab projection applied in ``compute_logits``;
* ``DeepseekV4DecoderLayer`` with its own aux-stream management;
* V4-specific checkpoint weight-name remapping in ``load_weights``.
"""
import typing
from collections.abc import Callable, Iterable
import regex as re
import torch
import torch.nn as nn
from vllm.compilation.decorators import support_torch_compile
from vllm.config import VllmConfig
from vllm.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe import FusedMoE
from vllm.model_executor.layers.layernorm import RMSNorm
from vllm.model_executor.layers.linear import ReplicatedLinear
from vllm.model_executor.layers.logits_processor import LogitsProcessor
from vllm.model_executor.layers.vocab_parallel_embedding import (
VocabParallelEmbedding,
)
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
from vllm.platforms import current_platform
from vllm.sequence import IntermediateTensors
from vllm.utils.multi_stream_utils import AuxStreamType
from .deepseek_mtp import SharedHead
from .deepseek_v2 import get_spec_layer_idx_from_weight_name
from .deepseek_v4 import (
DeepseekV4DecoderLayer,
hc_head,
make_deepseek_v4_expert_params_mapping,
)
from .utils import maybe_prefix
# Module-level logger named after this module.
logger = init_logger(__name__)
# MoE expert scales are fused into per-layer w13/w2 tensors; other FP8 linear
# scales use `.weight_scale_inv`. Mirrors the regex in
# DeepseekV4ForCausalLM.hf_to_vllm_mapper.
# The pattern matches names that end in `.experts.<digits>.w1.scale`,
# `.w2.scale` or `.w3.scale`.
_EXPERT_SCALE_RE = re.compile(r"\.experts\.\d+\.w[123]\.scale$")
class DeepSeekV4MultiTokenPredictorLayer(nn.Module):
    """A single MTP draft layer for DeepSeek V4.

    It normalises the token embeddings and the previous hidden states,
    projects each with its own linear layer, sums them and runs the result
    through one ``DeepseekV4DecoderLayer``.
    """

    def __init__(
        self,
        vllm_config: VllmConfig,
        topk_indices_buffer: torch.Tensor,
        prefix: str,
    ) -> None:
        super().__init__()
        draft_config = vllm_config.speculative_config.draft_model_config.hf_config
        self.config = draft_config
        quant_config = vllm_config.quant_config
        width = draft_config.hidden_size

        self.rms_norm_eps = draft_config.rms_norm_eps
        self.enorm = RMSNorm(width, eps=draft_config.rms_norm_eps)
        self.hnorm = RMSNorm(width, eps=draft_config.rms_norm_eps)

        def _square_projection() -> ReplicatedLinear:
            # hidden_size -> hidden_size, no bias, quantised like other linears.
            return ReplicatedLinear(
                width,
                width,
                bias=False,
                return_bias=False,
                quant_config=quant_config,
            )

        # V4 keeps e_ and h_ proj separate (with fp8 linear quant) rather than
        # fusing them the way V3 does with eh_proj.
        self.e_proj = _square_projection()
        self.h_proj = _square_projection()

        self.hc_eps = draft_config.hc_eps
        self.hc_mult = draft_config.hc_mult
        self.hc_dim = self.hc_mult * width

        def _frozen_fp32(*shape: int) -> nn.Parameter:
            # Uninitialised float32 storage; filled by weight loading.
            return nn.Parameter(
                torch.empty(shape, dtype=torch.float32),
                requires_grad=False,
            )

        self.hc_head_fn = _frozen_fp32(self.hc_mult, self.hc_dim)
        self.hc_head_base = _frozen_fp32(self.hc_mult)
        self.hc_head_scale = _frozen_fp32(1)

        self.shared_head = SharedHead(
            config=draft_config, prefix=prefix, quant_config=quant_config
        )

        # This layer owns its own auxiliary stream for the decoder block.
        self.aux_stream_dict = {
            AuxStreamType.Attention: torch.cuda.Stream(),
        }
        self.mtp_block = DeepseekV4DecoderLayer(
            vllm_config,
            prefix,
            topk_indices_buffer=topk_indices_buffer,
            aux_stream_dict=self.aux_stream_dict,
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        spec_step_index: int = 0,
    ) -> torch.Tensor:
        """Run one draft step and return the flat pre-hc_head residual.

        ``inputs_embeds`` is required despite its ``None`` default.
        ``input_ids`` and ``spec_step_index`` are not read here; the decoder
        block is called with ``input_ids=None``.
        """
        assert inputs_embeds is not None

        # masking inputs at position 0, as not needed by MTP
        is_first_position = positions.unsqueeze(-1) == 0
        token_features = torch.where(is_first_position, 0, inputs_embeds)
        token_features = self.enorm(token_features)

        # Target stashes pre-hc_head residual as flat (T, hc_mult * D);
        # reshape to (T, hc_mult, D) — the training-time layout.
        residual_streams = previous_hidden_states.view(
            -1, self.hc_mult, self.config.hidden_size
        )
        residual_streams = self.hnorm(residual_streams)

        projected_streams = self.h_proj(residual_streams)
        # The token projection gains a stream axis so it broadcasts over
        # the hc_mult dimension.
        projected_tokens = self.e_proj(token_features).unsqueeze(-2)
        block_input = projected_streams + projected_tokens

        block_output = self.mtp_block(
            positions=positions, x=block_input, input_ids=None
        )
        # Return the flat pre-hc_head residual so it can be re-fed as the
        # next spec step's `previous_hidden_states` when
        # num_speculative_tokens > 1. hc_head is deferred to compute_logits.
        return block_output.flatten(1)
class DeepSeekV4MultiTokenPredictor(nn.Module):
    """Container for the DeepSeek V4 MTP layers and their shared embedding.

    Layers are stored in a ``ModuleDict`` keyed by their absolute layer index
    (``num_hidden_layers`` onwards) so checkpoint names map onto them
    directly. A speculative step index selects a layer modulo the number of
    MTP layers.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Create the top-k index buffer, MTP layers, embedding and logits
        processor from ``vllm_config``."""
        super().__init__()
        hf_config = vllm_config.model_config.hf_config
        self.mtp_start_layer_idx = hf_config.num_hidden_layers
        self.num_mtp_layers = hf_config.num_nextn_predict_layers
        self.device = current_platform.device_type

        index_topk = hf_config.index_topk
        # One int32 row of `index_topk` entries per batched token; the same
        # buffer object is handed to every MTP layer.
        self.topk_indices_buffer = torch.empty(
            vllm_config.scheduler_config.max_num_batched_tokens,
            index_topk,
            dtype=torch.int32,
            device=self.device,
        )

        # to map the exact layer index from weights
        first_idx = self.mtp_start_layer_idx
        mtp_layers = {}
        for layer_idx in range(first_idx, first_idx + self.num_mtp_layers):
            mtp_layers[str(layer_idx)] = DeepSeekV4MultiTokenPredictorLayer(
                vllm_config,
                self.topk_indices_buffer,
                f"{prefix}.layers.{layer_idx}",
            )
        self.layers = torch.nn.ModuleDict(mtp_layers)

        self.embed_tokens = VocabParallelEmbedding(
            hf_config.vocab_size,
            hf_config.hidden_size,
            prefix=maybe_prefix(prefix, "embed_tokens"),
        )
        self.logits_processor = LogitsProcessor(hf_config.vocab_size)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` with the shared token embedding."""
        return self.embed_tokens(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        previous_hidden_states: torch.Tensor,
        inputs_embeds: torch.Tensor | None = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Dispatch to the MTP layer selected by ``spec_step_idx``.

        Embeddings are computed from ``input_ids`` when ``inputs_embeds`` is
        not supplied. Returns whatever the selected layer returns.
        """
        if inputs_embeds is None:
            inputs_embeds = self.embed_tokens(input_ids)

        step = spec_step_idx % self.num_mtp_layers
        layer_key = str(self.mtp_start_layer_idx + step)
        selected_layer = self.layers[layer_key]
        return selected_layer(
            input_ids,
            positions,
            previous_hidden_states,
            inputs_embeds,
            step,
        )

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Turn the flat MTP residual into logits for the selected layer."""
        step = spec_step_idx % self.num_mtp_layers
        layer_key = str(self.mtp_start_layer_idx + step)
        mtp_layer = self.layers[layer_key]

        # MTP forward returns the pre-hc_head residual (T, hc_mult * D); apply
        # hc_head here so logits are computed from the dense hidden state.
        residual = hidden_states.view(
            -1, mtp_layer.hc_mult, mtp_layer.config.hidden_size
        )
        dense_hidden = hc_head(
            residual,
            mtp_layer.hc_head_fn,
            mtp_layer.hc_head_scale,
            mtp_layer.hc_head_base,
            mtp_layer.rms_norm_eps,
            mtp_layer.hc_eps,
        )
        shared_head = mtp_layer.shared_head
        return self.logits_processor(shared_head.head, shared_head(dense_hidden))
@support_torch_compile
class DeepSeekV4MTP(nn.Module):
    """DeepSeek V4 multi-token-prediction draft model.

    Thin wrapper around ``DeepSeekV4MultiTokenPredictor`` that forwards
    embedding, forward and logits calls to it and implements checkpoint
    loading for the MTP layers.
    """

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Store the HF and quantization configs and build the predictor."""
        super().__init__()
        self.config = vllm_config.model_config.hf_config
        self.quant_config = vllm_config.quant_config
        self.model = DeepSeekV4MultiTokenPredictor(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed ``input_ids`` using the predictor's embedding."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
        spec_step_idx: int = 0,
    ) -> torch.Tensor:
        """Run the predictor for one speculative step.

        ``hidden_states`` is passed to the predictor as its
        ``previous_hidden_states``; ``intermediate_tensors`` is not read here.
        """
        hidden_states = self.model(
            input_ids, positions, hidden_states, inputs_embeds, spec_step_idx
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
        spec_step_idx: int = 0,
    ) -> torch.Tensor | None:
        """Delegate logits computation to the predictor."""
        return self.model.compute_logits(hidden_states, spec_step_idx)

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load MTP-layer weights from a checkpoint stream.

        Each checkpoint name is rewritten to the module's parameter naming
        and then dispatched to one of four paths: stacked (fused) parameters,
        routed experts, attention sinks, or a plain per-parameter load.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Returns:
            The set of parameter names that were loaded.

        Raises:
            ValueError: If any MTP layer index ends up with no loaded
                parameter.
            KeyError: If a rewritten name is not in ``named_parameters()``
                (the lookups into ``params_dict`` are unguarded).
        """
        # Weight name remapping for checkpoint compatibility.
        # Maps checkpoint weight paths to model parameter paths.
        WEIGHT_NAME_REMAPPING: dict[str, str] = {
            ".emb.tok_emb.weight": ".embed_tokens.weight",
            ".head.weight": ".shared_head.head.weight",
            ".norm.weight": ".shared_head.norm.weight",
        }

        def _remap_weight_name(name: str) -> str:
            """Remap checkpoint weight names to model parameter names."""
            for old_pattern, new_pattern in WEIGHT_NAME_REMAPPING.items():
                if old_pattern in name:
                    name = name.replace(old_pattern, new_pattern)
            return name

        def _find_mtp_layer_idx(name: str) -> int:
            """Return the first dot-separated integer in ``name``, or 0."""
            subnames = name.split(".")
            for subname in subnames:
                try:
                    # we return the first encountered integer
                    return int(subname)
                except ValueError:
                    continue
            return 0

        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("gate_up_proj", "w1", 0),
            ("gate_up_proj", "w3", 1),
            ("attn.fused_wqa_wkv", "attn.wq_a", 0),
            ("attn.fused_wqa_wkv", "attn.wkv", 1),
        ]
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        # TP for attention
        # The head range below is only used to slice `attn_sink` tensors so
        # each rank copies its own contiguous block of heads.
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        n_head = self.config.num_attention_heads
        n_local_head = n_head // tp_size
        head_rank_start = n_local_head * tp_rank
        head_rank_end = n_local_head * (tp_rank + 1)
        # Pre-compute expert mapping ONCE.
        # The first layer's `use_mega_moe` flag decides the mapping used for
        # every layer.
        first_layer = next(iter(self.model.layers.values()))
        if first_layer.mtp_block.ffn.use_mega_moe:
            expert_mapping = make_deepseek_v4_expert_params_mapping(
                self.config.n_routed_experts
            )
        else:
            expert_mapping = FusedMoE.make_expert_params_mapping(
                self,
                ckpt_gate_proj_name="w1",
                ckpt_down_proj_name="w2",
                ckpt_up_proj_name="w3",
                num_experts=self.config.n_routed_experts,
            )
        for name, loaded_weight in weights:
            mtp_layer_idx = _find_mtp_layer_idx(name)
            # V4 checkpoints store MTP weights as `mtp.{i}.*`; remap to
            # `model.layers.{num_hidden_layers + i}.*` so that
            # get_spec_layer_idx_from_weight_name can identify them.
            name = name.replace(
                f"mtp.{mtp_layer_idx}.",
                f"model.layers.{self.config.num_hidden_layers + mtp_layer_idx}.",
            )
            spec_layer = get_spec_layer_idx_from_weight_name(self.config, name)
            # Names that do not resolve to a spec layer are skipped.
            if spec_layer is None:
                continue
            name = _remap_weight_name(name)
            name = self._rewrite_spec_layer_name(spec_layer, name)
            # After rewriting, shared weights no longer contain ".layers";
            # take them only from the first MTP layer.
            if spec_layer != self.model.mtp_start_layer_idx and ".layers" not in name:
                continue
            # Checkpoint `.scale` tensors map to `.weight_scale` for routed
            # experts and `.weight_scale_inv` for everything else.
            if name.endswith(".scale"):
                suffix = (
                    ".weight_scale"
                    if _EXPERT_SCALE_RE.search(name)
                    else ".weight_scale_inv"
                )
                name = name.removesuffix(".scale") + suffix
            for param_name, weight_name, shard_id in stacked_params_mapping:
                # Skip non-stacked layers and experts (experts handled below).
                if ".experts." in name:
                    continue
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                loaded_params.add(name)
                break
            else:
                # Reached only when no stacked mapping matched (the loop above
                # finished without `break`).
                if ".experts." in name:
                    # Reinterpret E8M0 scales as uint8 to preserve raw
                    # exponent bytes; numeric copy_() would zero them.
                    # Mirrors the main DeepseekV4 loader.
                    if (
                        "weight_scale" in name
                        and loaded_weight.dtype == torch.float8_e8m0fnu
                    ):
                        loaded_weight = loaded_weight.view(torch.uint8)
                    for mapping in expert_mapping:
                        param_name, weight_name, expert_id, shard_id = mapping
                        if weight_name not in name:
                            continue
                        name_mapped = name.replace(weight_name, param_name)
                        param = params_dict[name_mapped]
                        # We should ask the weight loader to return success or not
                        # here since otherwise we may skip experts with other
                        # available replicas.
                        weight_loader = typing.cast(
                            Callable[..., bool], param.weight_loader
                        )
                        success = weight_loader(
                            param,
                            loaded_weight,
                            name_mapped,
                            shard_id=shard_id,
                            expert_id=expert_id,
                            return_success=True,
                        )
                        if success:
                            name = name_mapped
                            loaded_params.add(name_mapped)
                            break
                    continue
                elif "attn_sink" in name:
                    # Copy this rank's head slice into the front of the
                    # parameter.
                    narrow_weight = loaded_weight[head_rank_start:head_rank_end]
                    n = narrow_weight.shape[0]
                    params_dict[name][:n].copy_(narrow_weight)
                    loaded_params.add(name)
                    continue
                else:
                    if ".shared_experts.w2" in name:
                        name = name.replace(
                            ".shared_experts.w2", ".shared_experts.down_proj"
                        )
                    if name.endswith(".ffn.gate.bias"):
                        name = name.replace(".bias", ".e_score_correction_bias")
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
                    loaded_params.add(name)
                    continue
        # Verify that every MTP layer received at least one parameter.
        loaded_layers: set[int] = set()
        for param_name in loaded_params:
            spec_layer = get_spec_layer_idx_from_weight_name(self.config, param_name)
            if spec_layer is not None:
                loaded_layers.add(spec_layer)
        for layer_idx in range(
            self.model.mtp_start_layer_idx,
            self.model.mtp_start_layer_idx + self.model.num_mtp_layers,
        ):
            if layer_idx not in loaded_layers:
                raise ValueError(
                    f"MTP speculative decoding layer {layer_idx} weights "
                    f"missing from checkpoint. The checkpoint may have "
                    f"been quantized without including the MTP layers. "
                    f"Use a checkpoint that includes MTP layer weights, "
                    f"or disable speculative decoding."
                )
        self.finalize_mega_moe_weights()
        logger.info_once("MTP draft model loaded: %d params", len(loaded_params))
        return loaded_params

    def finalize_mega_moe_weights(self) -> None:
        """Call ``finalize_mega_moe_weights`` on every MTP layer's FFN."""
        for layer in self.model.layers.values():
            layer.mtp_block.ffn.finalize_mega_moe_weights()

    def _rewrite_spec_layer_name(self, spec_layer: int, name: str) -> str:
        """
        Rewrite the weight name to match the format of the original model.
        Add .mtp_block for modules in transformer layer block for spec layer
        and rename shared layer weights to be top level.

        Args:
            spec_layer: Absolute layer index the weight belongs to.
            name: Weight name of the form ``model.layers.{spec_layer}.*``.

        Returns:
            The rewritten name: unchanged for layer-level MTP weights, with
            the layer prefix replaced by ``model.`` for shared weights, and
            with ``.mtp_block.`` inserted for everything else.
        """
        # Substrings identifying weights that live directly on the MTP layer
        # (not inside its decoder block).
        spec_layer_weight_names = [
            "embed_tokens",
            "enorm",
            "hnorm",
            "h_proj",
            "e_proj",
            "shared_head",
            "hc_head_fn",
            "hc_head_base",
            "hc_head_scale",
        ]
        shared_weight_names = ["embed_tokens"]
        spec_layer_weight = False
        shared_weight = False
        # The first matching substring decides the category.
        for weight_name in spec_layer_weight_names:
            if weight_name in name:
                spec_layer_weight = True
                if weight_name in shared_weight_names:
                    shared_weight = True
                break
        if not spec_layer_weight:
            # treat rest weights as weights for transformer layer block
            name = name.replace(
                f"model.layers.{spec_layer}.", f"model.layers.{spec_layer}.mtp_block."
            )
        elif shared_weight:
            # treat shared weights as top level weights
            name = name.replace(f"model.layers.{spec_layer}.", "model.")
        return name
......@@ -96,6 +96,7 @@ _TEXT_GENERATION_MODELS = {
"DeepseekV2ForCausalLM": ("deepseek_v2", "DeepseekV2ForCausalLM"),
"DeepseekV3ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
"DeepseekV32ForCausalLM": ("deepseek_v2", "DeepseekV3ForCausalLM"),
"DeepseekV4ForCausalLM": ("deepseek_v4", "DeepseekV4ForCausalLM"),
"Dots1ForCausalLM": ("dots1", "Dots1ForCausalLM"),
"Ernie4_5ForCausalLM": ("ernie45", "Ernie4_5ForCausalLM"),
"Ernie4_5_MoeForCausalLM": ("ernie45_moe", "Ernie4_5_MoeForCausalLM"),
......@@ -586,6 +587,7 @@ _SPECULATIVE_DECODING_MODELS = {
"Eagle3DeepseekV3ForCausalLM": ("deepseek_eagle3", "Eagle3DeepseekV2ForCausalLM"),
"EagleDeepSeekMTPModel": ("deepseek_eagle", "EagleDeepseekV3ForCausalLM"),
"DeepSeekMTPModel": ("deepseek_mtp", "DeepSeekMTP"),
"DeepSeekV4MTPModel": ("deepseek_v4_mtp", "DeepSeekV4MTP"),
"ErnieMTPModel": ("ernie_mtp", "ErnieMTP"),
"ExaoneMoeMTP": ("exaone_moe_mtp", "ExaoneMoeMTP"),
"Exaone4_5_MTP": ("exaone4_5_mtp", "Exaone4_5_MTP"),
......
......@@ -28,6 +28,10 @@ _REASONING_PARSERS_TO_REGISTER = {
"deepseek_v3_reasoning_parser",
"DeepSeekV3ReasoningParser",
),
"deepseek_v4": (
"deepseek_v3_reasoning_parser",
"DeepSeekV3ReasoningParser",
),
"ernie45": (
"ernie45_reasoning_parser",
"Ernie45ReasoningParser",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.config import VllmConfig
from vllm.entrypoints.chat_utils import (
ChatCompletionMessageParam,
ConversationMessage,
parse_chat_messages,
parse_chat_messages_async,
)
from vllm.logger import init_logger
from vllm.tokenizers.deepseek_v4 import DeepseekV4Tokenizer
from vllm.utils.async_utils import make_async
from .base import BaseRenderer
from .inputs import DictPrompt
from .inputs.preprocess import parse_dec_only_prompt
from .params import ChatParams
logger = init_logger(__name__)
class DeepseekV4Renderer(BaseRenderer[DeepseekV4Tokenizer]):
    """Renderer that builds prompts with the DeepSeek V4 tokenizer's own
    chat-template implementation, in sync and async flavours."""

    def __init__(
        self,
        config: VllmConfig,
        tokenizer: DeepseekV4Tokenizer | None,
    ) -> None:
        super().__init__(config, tokenizer)
        # Async variant of the template call, created via make_async with
        # this renderer's executor.
        self._apply_chat_template_async = make_async(
            self._apply_chat_template, executor=self._executor
        )

    def _apply_chat_template(self, *args, **kwargs):
        """Forward all arguments to the tokenizer's ``apply_chat_template``."""
        tokenizer = self.get_tokenizer()
        return tokenizer.apply_chat_template(*args, **kwargs)

    def render_messages(
        self,
        messages: list[ChatCompletionMessageParam],
        params: ChatParams,
    ) -> tuple[list[ConversationMessage], DictPrompt]:
        """Parse ``messages`` and render them into a decoder-only prompt.

        Returns the parsed conversation together with the prompt dict, which
        carries multi-modal data/uuids when the parser produced any.
        """
        parsed = parse_chat_messages(
            messages,
            self.model_config,
            content_format="string",
            media_io_kwargs=params.media_io_kwargs,
            mm_processor_kwargs=params.mm_processor_kwargs,
        )
        conversation, mm_data, mm_uuids = parsed

        rendered = self._apply_chat_template(
            conversation=conversation,
            messages=messages,
            **params.get_apply_chat_template_kwargs(),
        )
        prompt = parse_dec_only_prompt(rendered)

        # Attach optional multi-modal payloads only when present.
        optional_fields = (
            ("multi_modal_data", mm_data),
            ("multi_modal_uuids", mm_uuids),
        )
        for field_name, field_value in optional_fields:
            if field_value is not None:
                prompt[field_name] = field_value
        return conversation, prompt

    async def render_messages_async(
        self,
        messages: list[ChatCompletionMessageParam],
        params: ChatParams,
    ) -> tuple[list[ConversationMessage], DictPrompt]:
        """Async counterpart of ``render_messages`` with the same result."""
        parsed = await parse_chat_messages_async(
            messages,
            self.model_config,
            content_format="string",
            media_io_kwargs=params.media_io_kwargs,
            mm_processor_kwargs=params.mm_processor_kwargs,
        )
        conversation, mm_data, mm_uuids = parsed

        rendered = await self._apply_chat_template_async(
            conversation=conversation,
            messages=messages,
            **params.get_apply_chat_template_kwargs(),
        )
        prompt = parse_dec_only_prompt(rendered)

        # Attach optional multi-modal payloads only when present.
        optional_fields = (
            ("multi_modal_data", mm_data),
            ("multi_modal_uuids", mm_uuids),
        )
        for field_name, field_value in optional_fields:
            if field_value is not None:
                prompt[field_name] = field_value
        return conversation, prompt
......@@ -21,6 +21,7 @@ logger = init_logger(__name__)
_VLLM_RENDERERS = {
"deepseek_v32": ("deepseek_v32", "DeepseekV32Renderer"),
"deepseek_v4": ("deepseek_v4", "DeepseekV4Renderer"),
"hf": ("hf", "HfRenderer"),
"grok2": ("grok2", "Grok2Renderer"),
"kimi_audio": ("hf", "HfRenderer"),
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
import copy
from typing import Any
from transformers import PreTrainedTokenizerFast
from vllm.entrypoints.chat_utils import ChatCompletionMessageParam
from .deepseek_v4_encoding import encode_messages
from .hf import HfTokenizer, get_cached_tokenizer
from .protocol import TokenizerLike
def get_deepseek_v4_tokenizer(tokenizer: HfTokenizer) -> HfTokenizer:
    """
    Return a shallow copy of ``tokenizer`` whose class is swapped for a
    dynamically created subclass that encodes chats with the DeepSeek V4
    encoder instead of a template.

    The added vocabulary and the base vocabulary size are captured once here
    and served from the closure by ``get_added_vocab`` and ``__len__``.
    """
    cached_added_vocab = tokenizer.get_added_vocab()
    num_added_tokens = len(cached_added_vocab)
    base_vocab_size = tokenizer.vocab_size
    wrapped = copy.copy(tokenizer)

    class _DeepseekV4Tokenizer(tokenizer.__class__):  # type: ignore
        def apply_chat_template(
            self,
            messages: list["ChatCompletionMessageParam"],
            tools: list[dict[str, Any]] | None = None,
            **kwargs,
        ) -> str | list[int]:
            # Either flag switches thinking mode on.
            thinking_flag = kwargs.get("thinking", False)
            enable_flag = kwargs.get("enable_thinking", False)
            if thinking_flag or enable_flag:
                thinking_mode = "thinking"
            else:
                thinking_mode = "chat"

            # Prefer the parsed conversation when supplied; work on a shallow
            # copy so the inserted system message stays local.
            source_messages = kwargs.get("conversation", messages)
            messages = source_messages.copy()
            if tools is not None and len(tools) > 0:
                system_with_tools = {"role": "system"}
                system_with_tools["tools"] = tools  # type: ignore[typeddict-unknown-key]
                messages.insert(0, system_with_tools)

            # The V4 reference currently accepts only "max", "high", or None.
            requested_effort = kwargs.get("reasoning_effort")
            reasoning_effort = (
                requested_effort if requested_effort in ("max", "high") else None
            )

            prompt_str = encode_messages(  # type: ignore
                messages,
                thinking_mode=thinking_mode,
                drop_thinking=kwargs.get("drop_thinking", True),
                reasoning_effort=reasoning_effort,
            )

            if not kwargs.get("tokenize", True):
                return prompt_str

            # Only these two tokenizer options are forwarded to encode().
            passthrough = {}
            for option in ("truncation", "max_length"):
                if option in kwargs:
                    passthrough[option] = kwargs[option]
            return self.encode(
                prompt_str,
                add_special_tokens=False,
                **passthrough,
            )

        def num_special_tokens_to_add(self) -> int:
            empty_encoding = self.encode("")
            return len(empty_encoding)

        def __len__(self) -> int:
            return tokenizer_vocab_total()

        def get_added_vocab(self) -> dict[str, int]:
            return cached_added_vocab.copy()

        def __reduce__(self):
            # Pickle by re-wrapping the original tokenizer.
            return get_deepseek_v4_tokenizer, (tokenizer,)

    def tokenizer_vocab_total() -> int:
        # Base vocabulary plus added tokens, both captured at wrap time.
        return base_vocab_size + num_added_tokens

    _DeepseekV4Tokenizer.__name__ = f"DSV4{tokenizer.__class__.__name__}"
    wrapped.__class__ = _DeepseekV4Tokenizer
    return wrapped
class DeepseekV4Tokenizer(TokenizerLike):
    """Factory for the DeepSeek V4 tokenizer wrapper."""

    @classmethod
    def from_pretrained(cls, *args, **kwargs) -> HfTokenizer:
        """Load a fast HF tokenizer, wrap it with the V4 chat encoding and
        return the cached version of the wrapper."""
        base_tokenizer = PreTrainedTokenizerFast.from_pretrained(*args, **kwargs)
        v4_tokenizer = get_deepseek_v4_tokenizer(base_tokenizer)
        return get_cached_tokenizer(v4_tokenizer)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
# ruff: noqa
# fmt: off
"""
DeepSeek-V4 Encoding
A self-contained implementation for encoding/decoding DeepSeek-V4 chat messages
with tool calling, thinking mode, and quick instruction task support.
"""
from typing import Any, Dict, List, Union, Optional, Tuple
import copy
import json
import regex as re
# ============================================================
# Special Tokens
# ============================================================
# Sequence delimiters: BOS is prepended by encode_messages, EOS terminates
# assistant turns via assistant_msg_template.
bos_token: str = "<|begin▁of▁sentence|>"
eos_token: str = "<|end▁of▁sentence|>"
# Markers that open and close a reasoning span.
thinking_start_token: str = "<think>"
thinking_end_token: str = "</think>"
# Marker embedded in the tool-call tags (invoke / parameter / tool_calls).
dsml_token: str = "|DSML|"
# Role markers emitted by render_message.
USER_SP_TOKEN = "<|User|>"
ASSISTANT_SP_TOKEN = "<|Assistant|>"
LATEST_REMINDER_SP_TOKEN = "<|latest_reminder|>"
# Task special tokens for internal classification tasks
DS_TASK_SP_TOKENS = {
    "action": "<|action|>",
    "query": "<|query|>",
    "authority": "<|authority|>",
    "domain": "<|domain|>",
    "title": "<|title|>",
    "read_url": "<|read_url|>",
}
# Task names accepted in a message's "task" field.
VALID_TASKS = set(DS_TASK_SP_TOKENS.keys())
# ============================================================
# Templates
# ============================================================
# Per-role message templates; system, user and reminder content is emitted
# as-is.
system_msg_template: str = "{content}"
user_msg_template: str = "{content}"
latest_reminder_msg_template: str = "{content}"
# Assistant turns concatenate reasoning, content and tool calls; the default
# variant is terminated with EOS, the "wo_eos" variant is not.
assistant_msg_template: str = "{reasoning}{content}{tool_calls}" + eos_token
assistant_msg_wo_eos_template: str = "{reasoning}{content}{tool_calls}"
thinking_template: str = "{reasoning}"
# Section appended to system/developer messages that carry a response_format.
response_format_template: str = (
    "## Response Format:\n\nYou MUST strictly adhere to the following schema to reply:\n{schema}"
)
# A single tool invocation, wrapped in DSML invoke tags.
tool_call_template: str = (
    "<{dsml_token}invoke name=\"{name}\">\n{arguments}\n</{dsml_token}invoke>"
)
# Wrapper around all invocations of one assistant turn.
tool_calls_template: str = (
    "<{dsml_token}{tc_block_name}>\n{tool_calls}\n</{dsml_token}{tc_block_name}>"
)
tool_calls_block_name: str = "tool_calls"
# Wrapper for a tool result embedded in a user message.
tool_output_template: str = (
    "<tool_result>{content}</tool_result>"
)
# Prefix emitted by render_message before the first message when thinking
# mode is on and reasoning_effort is "max".
REASONING_EFFORT_MAX = (
    "Reasoning Effort: Absolute maximum with no shortcuts permitted.\n"
    "You MUST be very thorough in your thinking and comprehensively decompose the problem to resolve the root cause, rigorously stress-testing your logic against all potential paths, edge cases, and adversarial scenarios.\n"
    "Explicitly write out your entire deliberation process, documenting every intermediate step, considered alternative, and rejected hypothesis to ensure absolutely no assumption is left unchecked.\n\n"
)
# Tools section rendered by render_tools. Placeholders: {dsml_token},
# {thinking_start_token}, {thinking_end_token} and {tool_schemas} (one JSON
# schema per line).
TOOLS_TEMPLATE = """## Tools
You have access to a set of tools to help answer the user's question. You can invoke tools by writing a "<{dsml_token}tool_calls>" block like the following:
<{dsml_token}tool_calls>
<{dsml_token}invoke name="$TOOL_NAME">
<{dsml_token}parameter name="$PARAMETER_NAME" string="true|false">$PARAMETER_VALUE</{dsml_token}parameter>
...
</{dsml_token}invoke>
<{dsml_token}invoke name="$TOOL_NAME2">
...
</{dsml_token}invoke>
</{dsml_token}tool_calls>
String parameters should be specified as is and set `string="true"`. For all other types (numbers, booleans, arrays, objects), pass the value in JSON format and set `string="false"`.
If thinking_mode is enabled (triggered by {thinking_start_token}), you MUST output your complete reasoning inside {thinking_start_token}...{thinking_end_token} BEFORE any tool calls or final response.
Otherwise, output directly after {thinking_end_token} with tool calls or final response.
### Available Tool Schemas
{tool_schemas}
You MUST strictly follow the above defined tool name and parameter schemas to invoke tool calls.
"""
# ============================================================
# Utility Functions
# ============================================================
def to_json(value: Any) -> str:
    """Return ``value`` as a JSON string, keeping non-ASCII characters
    verbatim and retrying with ASCII escapes if that attempt raises."""
    try:
        encoded = json.dumps(value, ensure_ascii=False)
    except Exception:
        encoded = json.dumps(value, ensure_ascii=True)
    return encoded
def tools_from_openai_format(tools):
    """Return the ``"function"`` entry of every OpenAI-format tool, in order."""
    functions = []
    for tool_entry in tools:
        functions.append(tool_entry["function"])
    return functions
def tool_calls_from_openai_format(tool_calls):
    """Flatten OpenAI-format tool calls to ``{"name", "arguments"}`` dicts."""
    internal_calls = []
    for call in tool_calls:
        internal_calls.append(
            {
                "name": call["function"]["name"],
                "arguments": call["function"]["arguments"],
            }
        )
    return internal_calls
def tool_calls_to_openai_format(tool_calls):
    """Wrap internal ``{"name", "arguments"}`` calls in OpenAI function-call
    envelopes."""
    openai_calls = []
    for call in tool_calls:
        function_payload = {
            "name": call["name"],
            "arguments": call["arguments"],
        }
        openai_calls.append({"type": "function", "function": function_payload})
    return openai_calls
def encode_arguments_to_dsml(tool_call: Dict[str, Any]) -> str:
    """
    Render a tool call's arguments as newline-separated DSML parameter tags.

    Args:
        tool_call: Dict with "name" and "arguments" keys; "arguments" may be
            a JSON string (decoded first) or an already-decoded mapping.

    Returns:
        One parameter tag per argument. String values are written verbatim
        with string="true"; everything else is JSON-encoded with
        string="false".
    """
    raw_arguments = tool_call["arguments"]
    if isinstance(raw_arguments, str):
        arguments = json.loads(raw_arguments)
    else:
        arguments = raw_arguments

    parameter_tag = '<{dsml_token}parameter name="{key}" string="{is_str}">{value}</{dsml_token}parameter>'
    rendered = []
    for arg_name, arg_value in arguments.items():
        if isinstance(arg_value, str):
            string_flag, text = "true", arg_value
        else:
            string_flag, text = "false", to_json(arg_value)
        rendered.append(
            parameter_tag.format(
                dsml_token=dsml_token,
                key=arg_name,
                is_str=string_flag,
                value=text,
            )
        )
    return "\n".join(rendered)
def decode_dsml_to_arguments(tool_name: str, tool_args: Dict[str, Tuple[str, str]]) -> Dict[str, str]:
    """
    Rebuild a tool call dict from parsed DSML parameters.

    Args:
        tool_name: Name of the tool.
        tool_args: Dict mapping param_name -> (value, is_string_flag).

    Returns:
        Dict with "name" and "arguments" keys, where "arguments" is a JSON
        object string. Values flagged "true" are JSON-quoted; all others are
        inserted verbatim.
    """
    members = []
    for param_name, (raw_value, string_flag) in tool_args.items():
        if string_flag == "true":
            raw_value = to_json(raw_value)
        members.append(f"{to_json(param_name)}: {raw_value}")
    arguments_json = "{" + ", ".join(members) + "}"
    return {"name": tool_name, "arguments": arguments_json}
def render_tools(tools: List[Dict[str, Union[str, Dict[str, Any]]]]) -> str:
    """
    Fill TOOLS_TEMPLATE with the given tool schemas.

    Args:
        tools: List of tool schema dicts (each with name, description, parameters).

    Returns:
        The tools section, with one JSON-encoded schema per line.
    """
    schema_lines = "\n".join(to_json(schema) for schema in tools)
    placeholders = {
        "tool_schemas": schema_lines,
        "dsml_token": dsml_token,
        "thinking_start_token": thinking_start_token,
        "thinking_end_token": thinking_end_token,
    }
    return TOOLS_TEMPLATE.format(**placeholders)
def find_last_user_index(messages: List[Dict[str, Any]]) -> int:
    """Return the position of the last message whose role is "user" or
    "developer", or -1 when there is none."""
    total = len(messages)
    for offset, message in enumerate(reversed(messages)):
        if message.get("role") in ["user", "developer"]:
            return total - 1 - offset
    return -1
# ============================================================
# Message Rendering
# ============================================================
def render_message(index: int, messages: List[Dict[str, Any]], thinking_mode: str, drop_thinking: bool = True, reasoning_effort: Optional[str] = None) -> str:
    """
    Render a single message at the given index into its encoded string form.

    The result is the message body for its role, followed (when the next
    message is absent, an assistant turn or a latest_reminder) by the
    transition tokens that introduce what comes next.

    Args:
        index: Index of the message to render.
        messages: Full list of messages in the conversation.
        thinking_mode: Either "chat" or "thinking".
        drop_thinking: Whether to drop reasoning content from earlier turns.
        reasoning_effort: Optional reasoning effort level ("max", "high", or None).

    Returns:
        Encoded string for this message.

    Raises:
        AssertionError: On an out-of-range index, an invalid thinking_mode,
            reasoning_effort or task, or a developer message without content.
        NotImplementedError: For role "tool" (must be merged beforehand) and
            for unknown roles.
    """
    assert 0 <= index < len(messages)
    assert thinking_mode in ["chat", "thinking"], f"Invalid thinking_mode `{thinking_mode}`"
    prompt = ""
    msg = messages[index]
    last_user_idx = find_last_user_index(messages)
    role = msg.get("role")
    content = msg.get("content")
    tools = msg.get("tools")
    response_format = msg.get("response_format")
    tool_calls = msg.get("tool_calls")
    reasoning = msg.get("reasoning")
    wo_eos = msg.get("wo_eos", False)
    # Normalize OpenAI-format tools / tool calls to the internal shape.
    if tools:
        tools = tools_from_openai_format(tools)
    if tool_calls:
        tool_calls = tool_calls_from_openai_format(tool_calls)
    # Reasoning effort prefix (only at index 0 in thinking mode with max effort)
    assert reasoning_effort in ['max', None, 'high'], f"Invalid reasoning effort: {reasoning_effort}"
    if index == 0 and thinking_mode == "thinking" and reasoning_effort == 'max':
        prompt += REASONING_EFFORT_MAX
    if role == "system":
        # System text, then optional tools and response-format sections.
        prompt += system_msg_template.format(content=content or "")
        if tools:
            prompt += "\n\n" + render_tools(tools)
        if response_format:
            prompt += "\n\n" + response_format_template.format(schema=to_json(response_format))
    elif role == "developer":
        # Developer messages are emitted behind the user marker and may also
        # carry tools / response-format sections.
        assert content, f"Invalid message for role `{role}`: {msg}"
        content_developer = USER_SP_TOKEN
        content_developer += content
        if tools:
            content_developer += "\n\n" + render_tools(tools)
        if response_format:
            content_developer += "\n\n" + response_format_template.format(schema=to_json(response_format))
        prompt += user_msg_template.format(content=content_developer)
    elif role == "user":
        prompt += USER_SP_TOKEN
        # Handle content blocks (tool results mixed with text)
        content_blocks = msg.get("content_blocks")
        if content_blocks:
            parts = []
            for block in content_blocks:
                block_type = block.get("type")
                if block_type == "text":
                    parts.append(block.get("text", ""))
                elif block_type == "tool_result":
                    tool_content = block.get("content", "")
                    # List-valued tool content: keep text parts, replace
                    # anything else with a placeholder.
                    if isinstance(tool_content, list):
                        text_parts = []
                        for b in tool_content:
                            if b.get("type") == "text":
                                text_parts.append(b.get("text", ""))
                            else:
                                text_parts.append(f"[Unsupported {b.get('type')}]")
                        tool_content = "\n\n".join(text_parts)
                    parts.append(tool_output_template.format(content=tool_content))
                else:
                    parts.append(f"[Unsupported {block_type}]")
            prompt += "\n\n".join(parts)
        else:
            prompt += content or ""
    elif role == "latest_reminder":
        prompt += LATEST_REMINDER_SP_TOKEN + latest_reminder_msg_template.format(content=content)
    elif role == "tool":
        raise NotImplementedError("deepseek_v4 merges tool messages into user; please preprocess with merge_tool_messages()")
    elif role == "assistant":
        thinking_part = ""
        tc_content = ""
        if tool_calls:
            # One invoke block per call, wrapped in a single tool_calls block.
            tc_list = [
                tool_call_template.format(
                    dsml_token=dsml_token,
                    name=tc.get("name"),
                    arguments=encode_arguments_to_dsml(tc)
                )
                for tc in tool_calls
            ]
            tc_content += '\n\n' + tool_calls_template.format(
                dsml_token=dsml_token,
                tool_calls="\n".join(tc_list),
                tc_block_name=tool_calls_block_name,
            )
        summary_content = content or ""
        reasoning = reasoning or ""
        # Check if previous message has a task - if so, this is a task output (no thinking)
        prev_has_task = index - 1 >= 0 and messages[index - 1].get("task") is not None
        if thinking_mode == "thinking" and not prev_has_task:
            # Reasoning is kept when dropping is disabled or when this turn
            # comes after the last user/developer message.
            if not drop_thinking or index > last_user_idx:
                thinking_part = thinking_template.format(reasoning=reasoning) + thinking_end_token
            else:
                thinking_part = ""
        if wo_eos:
            prompt += assistant_msg_wo_eos_template.format(
                reasoning=thinking_part,
                content=summary_content,
                tool_calls=tc_content,
            )
        else:
            prompt += assistant_msg_template.format(
                reasoning=thinking_part,
                content=summary_content,
                tool_calls=tc_content,
            )
    else:
        raise NotImplementedError(f"Unknown role: {role}")
    # Append transition tokens based on what follows
    # No transition is added when the next message is neither an assistant
    # turn nor a latest_reminder.
    if index + 1 < len(messages) and messages[index + 1].get("role") not in ["assistant", "latest_reminder"]:
        return prompt
    task = messages[index].get("task")
    if task is not None:
        # Task special token for internal classification tasks
        assert task in VALID_TASKS, f"Invalid task: '{task}'. Valid tasks are: {list(VALID_TASKS)}"
        task_sp_token = DS_TASK_SP_TOKENS[task]
        if task != "action":
            # Non-action tasks: append task sp token directly after the message
            prompt += task_sp_token
        else:
            # Action task: append Assistant + thinking token + action sp token
            prompt += ASSISTANT_SP_TOKEN
            prompt += thinking_end_token if thinking_mode != "thinking" else thinking_start_token
            prompt += task_sp_token
    elif messages[index].get("role") in ["user", "developer"]:
        # Normal generation: append Assistant + thinking token
        prompt += ASSISTANT_SP_TOKEN
        # The thinking-start token opens a reasoning span; in every other
        # case the thinking-end token is emitted instead.
        if not drop_thinking and thinking_mode == "thinking":
            prompt += thinking_start_token
        elif drop_thinking and thinking_mode == "thinking" and index >= last_user_idx:
            prompt += thinking_start_token
        else:
            prompt += thinking_end_token
    return prompt
# ============================================================
# Preprocessing
# ============================================================
def merge_tool_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Merge tool messages into user messages using the content_blocks format.

    DeepSeek-V4 does not have a standalone "tool" role; instead, tool results
    are encoded as <tool_result> blocks within user messages. This function
    converts a standard OpenAI-format conversation (with separate "tool" role
    messages) into that shape:

    - A "tool" message becomes a tool_result block. It is appended to the
      previous merged message when that is a user message with
      content_blocks, otherwise it starts a new user message.
    - A "user" message becomes a text block. It is appended to the previous
      merged user message (if that one has content_blocks and no task),
      otherwise it starts a new user message that keeps the "task", "wo_eos"
      and "mask" fields.
    - Every other message is passed through as a deep copy.

    A user message whose content is None is treated as empty text, so the
    resulting text block can always be joined as a string downstream.

    Args:
        messages: List of message dicts in OpenAI format. Not mutated.

    Returns:
        Processed message list with tool messages merged into user messages.
    """
    merged: List[Dict[str, Any]] = []
    for msg in messages:
        # Work on a private copy so the caller's messages are never mutated.
        msg = copy.deepcopy(msg)
        role = msg.get("role")
        previous = merged[-1] if merged else None
        # True when the previous output message is a user message that
        # already carries content_blocks and can therefore absorb more.
        can_extend_previous = (
            previous is not None
            and previous.get("role") == "user"
            and "content_blocks" in previous
        )
        if role == "tool":
            # Convert tool message to a tool_result block
            tool_block = {
                "type": "tool_result",
                "tool_use_id": msg.get("tool_call_id", ""),
                "content": msg.get("content", ""),
            }
            if can_extend_previous:
                previous["content_blocks"].append(tool_block)
            else:
                merged.append({
                    "role": "user",
                    "content_blocks": [tool_block],
                })
        elif role == "user":
            text = msg.get("content", "")
            if text is None:
                # An explicit `content: None` would otherwise put None into
                # the text block and break string joining when rendering.
                text = ""
            text_block = {"type": "text", "text": text}
            if can_extend_previous and previous.get("task") is None:
                previous["content_blocks"].append(text_block)
            else:
                new_msg = {
                    "role": "user",
                    "content": text,
                    "content_blocks": [text_block],
                }
                # Preserve extra fields (task, wo_eos, mask, etc.)
                for key in ("task", "wo_eos", "mask"):
                    if key in msg:
                        new_msg[key] = msg[key]
                merged.append(new_msg)
        else:
            merged.append(msg)
    return merged
def sort_tool_results_by_call_order(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Reorder tool_result blocks in user messages to follow the tool_calls
    order of the most recent assistant message that had tool calls.

    Only the tool_result blocks move; other blocks keep their slots. The
    messages are updated in place and the same list is returned.

    Args:
        messages: Preprocessed message list (after merge_tool_messages).

    Returns:
        Message list with sorted tool result blocks.
    """
    call_rank: Dict[str, int] = {}

    def _rank(block: Dict[str, Any]) -> int:
        # Results with an unknown id sort as if they were the first call.
        return call_rank.get(block.get("tool_use_id", ""), 0)

    for message in messages:
        role = message.get("role")

        if role == "assistant" and message.get("tool_calls"):
            # A new assistant turn with tool calls resets the ordering.
            call_rank = {}
            for position, call in enumerate(message["tool_calls"]):
                call_id = call.get("id") or call.get("function", {}).get("id", "")
                if call_id:
                    call_rank[call_id] = position
            continue

        if not (role == "user" and message.get("content_blocks")):
            continue

        blocks = message["content_blocks"]
        results = [blk for blk in blocks if blk.get("type") == "tool_result"]
        if len(results) <= 1 or not call_rank:
            continue

        # Hand out the sorted results into the slots that held tool results.
        ordered_results = iter(sorted(results, key=_rank))
        message["content_blocks"] = [
            next(ordered_results) if blk.get("type") == "tool_result" else blk
            for blk in blocks
        ]
    return messages
# ============================================================
# Main Encoding Function
# ============================================================
def encode_messages(
    messages: List[Dict[str, Any]],
    thinking_mode: str,
    context: Optional[List[Dict[str, Any]]] = None,
    drop_thinking: bool = True,
    add_default_bos_token: bool = True,
    reasoning_effort: Optional[str] = None,
) -> str:
    """
    Render a conversation into a DeepSeek-V4 prompt string.

    Steps performed, in order:
    - tool messages are merged into user messages and tool results are sorted
      by the order of the preceding assistant's tool calls;
    - a BOS token is emitted when requested and there is no context;
    - in thinking mode, earlier reasoning is optionally dropped;
    - every message that follows the context is rendered and concatenated.

    Args:
        messages: List of message dicts to encode.
        thinking_mode: Either "chat" or "thinking".
        context: Optional preceding context messages (already encoded prefix).
            They influence rendering but are not rendered themselves.
        drop_thinking: If True, drop reasoning from earlier assistant turns.
            Ignored (treated as False) when any message defines ``tools``.
        add_default_bos_token: Whether to prepend BOS token at conversation start.
        reasoning_effort: Optional reasoning effort level ("max", "high", or None).

    Returns:
        The encoded prompt string.
    """
    prior: List[Dict[str, Any]] = context or []

    # The new messages are sorted together with the raw context so that tool
    # results can see tool calls issued at the end of the context.
    messages = merge_tool_messages(messages)
    messages = sort_tool_results_by_call_order(prior + messages)[len(prior):]
    if prior:
        prior = sort_tool_results_by_call_order(merge_tool_messages(prior))
    full_messages = prior + messages

    pieces: List[str] = [bos_token if add_default_bos_token and not prior else ""]

    # Tool definitions anywhere in the conversation disable reasoning dropping.
    has_tools = any(m.get("tools") for m in full_messages)
    effective_drop_thinking = False if has_tools else drop_thinking

    if thinking_mode == "thinking" and effective_drop_thinking:
        full_messages = _drop_thinking_messages(full_messages)
        # Dropping may shrink the context too, so its length is recomputed.
        context_len = len(_drop_thinking_messages(prior))
        num_to_render = len(full_messages) - context_len
    else:
        context_len = len(prior)
        num_to_render = len(messages)

    for position in range(context_len, context_len + num_to_render):
        pieces.append(
            render_message(
                position,
                full_messages,
                thinking_mode=thinking_mode,
                drop_thinking=effective_drop_thinking,
                reasoning_effort=reasoning_effort,
            )
        )
    return "".join(pieces)
def _drop_thinking_messages(messages: List[Dict[str, Any]]) -> List[Dict[str, Any]]:
    """
    Strip reasoning and non-essential messages that precede the last user turn.

    Rules:
    - Roles "user", "system", "tool", "latest_reminder" and
      "direct_search_results" are always kept unchanged.
    - Every message at or after the last user index is kept unchanged.
    - Earlier assistant messages are kept as shallow copies with the
      "reasoning" key removed; the input dicts are not modified.
    - Any other earlier message (e.g. developer) is omitted.

    Args:
        messages: Conversation messages in order.

    Returns:
        A new list containing the surviving messages.
    """
    cutoff = find_last_user_index(messages)
    always_kept = {"user", "system", "tool", "latest_reminder", "direct_search_results"}

    survivors = []
    for position, message in enumerate(messages):
        role = message.get("role")
        if role in always_kept or position >= cutoff:
            survivors.append(message)
            continue
        if role != "assistant":
            # Developer and other roles before the last user turn are dropped.
            continue
        stripped = copy.copy(message)
        stripped.pop("reasoning", None)
        survivors.append(stripped)
    return survivors
# ============================================================
# Parsing (Decoding model output)
# ============================================================
def _read_until_stop(index: int, text: str, stop: List[str]) -> Tuple[int, str, Optional[str]]:
"""
Read text from index until one of the stop strings is found.
Returns:
Tuple of (new_index, content_before_stop, matched_stop_string_or_None).
"""
min_pos = len(text)
matched_stop = None
for s in stop:
pos = text.find(s, index)
if pos != -1 and pos < min_pos:
min_pos = pos
matched_stop = s
if matched_stop:
content = text[index:min_pos]
return min_pos + len(matched_stop), content, matched_stop
else:
content = text[index:]
return len(text), content, None
def parse_tool_calls(index: int, text: str) -> Tuple[int, Optional[str], List[Dict[str, str]]]:
    """
    Parse a DSML tool-call block beginning at ``index``.

    Reads ``invoke`` elements, each with zero or more ``parameter`` elements,
    until the closing tag of the tool-calls block is reached or the text is
    exhausted.

    Args:
        index: Starting position in text.
        text: The full text to parse.

    Returns:
        Tuple of (new_index, last_stop_token, list_of_tool_call_dicts).
        Each tool call is whatever ``decode_dsml_to_arguments`` returns for
        the parsed name and parameters.

    Raises:
        ValueError: If the text deviates from the expected DSML layout or a
            parameter name is repeated within one invoke.
    """
    invoke_open = f"<{dsml_token}invoke"
    invoke_close = f"</{dsml_token}invoke"
    param_open = f"<{dsml_token}parameter"
    # No leading "<": the parameter regex below consumes it as part of the value span.
    param_close = f"/{dsml_token}parameter"
    block_close = f"</{dsml_token}{tool_calls_block_name}>"

    parsed_calls: List[Dict[str, Any]] = []
    stop_token = None
    while index < len(text):
        index, gap, stop_token = _read_until_stop(index, text, [invoke_open, block_close])
        # Only the tail of the previous tag may sit between elements.
        if gap != ">\n":
            raise ValueError(f"Tool call format error: expected '>\\n' but got '{gap}'")
        if stop_token == block_close:
            break
        if stop_token is None:
            raise ValueError("Missing special token in tool calls")

        index, header, stop_token = _read_until_stop(index, text, [param_open, invoke_close])
        names = re.findall(r'^\s*name="(.*?)">\n$', header, flags=re.DOTALL)
        if len(names) != 1:
            raise ValueError(f"Tool name format error: '{header}'")

        # Maps parameter name -> (raw value, "true"/"false" string flag).
        arguments: Dict[str, Tuple[str, str]] = {}
        while stop_token == param_open:
            index, raw_param, stop_token = _read_until_stop(index, text, [param_close])
            fields = re.findall(
                r'^ name="(.*?)" string="(true|false)">(.*?)<$',
                raw_param,
                flags=re.DOTALL,
            )
            if len(fields) != 1:
                raise ValueError(f"Parameter format error: '{raw_param}'")
            key, is_string, value = fields[0]
            if key in arguments:
                raise ValueError(f"Duplicate parameter name: '{key}'")
            arguments[key] = (value, is_string)

            index, gap, stop_token = _read_until_stop(index, text, [param_open, invoke_close])
            if gap != ">\n":
                raise ValueError(f"Parameter format error: expected '>\\n' but got '{gap}'")

        parsed_calls.append(
            decode_dsml_to_arguments(tool_name=names[0], tool_args=arguments)
        )
    return index, stop_token, parsed_calls
def parse_message_from_completion_text(text: str, thinking_mode: str) -> Dict[str, Any]:
    """
    Turn one raw assistant completion into a structured message.

    The completion is consumed left to right:
    - in thinking mode, everything up to the thinking end token is reasoning;
    - the following text, up to EOS or the start of a tool-calls block, is
      the content;
    - an optional tool-calls block is parsed, after which only EOS may follow.

    NOTE: Only correctly formatted completions are supported; malformed
    output raises ValueError.

    Args:
        text: The raw completion text (including EOS token).
        thinking_mode: Either "chat" or "thinking".

    Returns:
        Dict with keys: "role", "content", "reasoning", "tool_calls".
        tool_calls are passed through ``tool_calls_to_openai_format``.

    Raises:
        ValueError: If an expected delimiter is missing, extra text follows
            the message, or a special token appears inside content or
            reasoning.
    """
    tool_block_open = f"\n\n<{dsml_token}{tool_calls_block_name}"
    cursor = 0
    reasoning = ""

    if thinking_mode == "thinking":
        cursor, reasoning, hit = _read_until_stop(
            cursor, text, [thinking_end_token, tool_block_open]
        )
        if hit != thinking_end_token:
            raise ValueError("Invalid thinking format: missing </think>")

    cursor, summary_content, hit = _read_until_stop(
        cursor, text, [eos_token, tool_block_open]
    )

    parsed_calls: List[Dict[str, str]] = []
    if hit == tool_block_open:
        cursor, hit, parsed_calls = parse_tool_calls(cursor, text)
        # Nothing but EOS is allowed after the tool-calls block.
        cursor, trailing, hit = _read_until_stop(cursor, text, [eos_token])
        if trailing:
            raise ValueError("Unexpected content after tool calls")
    elif hit != eos_token:
        raise ValueError("Invalid format: missing EOS token")

    if len(text) != cursor or hit not in [eos_token, None]:
        raise ValueError("Unexpected content at end")

    # Structural tokens must never leak into the extracted text.
    forbidden = [bos_token, eos_token, thinking_start_token, thinking_end_token, dsml_token]
    for special in forbidden:
        if special in summary_content or special in reasoning:
            raise ValueError(f"Unexpected special token '{special}' in content")

    return {
        "role": "assistant",
        "content": summary_content,
        "reasoning": reasoning,
        "tool_calls": tool_calls_to_openai_format(parsed_calls)
    }
# fmt: on
......@@ -42,6 +42,7 @@ _MODEL_TYPES_WITH_INCORRECT_TOKENIZER_CLASS: set[str] = {"step3_vl"}
_VLLM_TOKENIZERS = {
"deepseek_v32": ("deepseek_v32", "DeepseekV32Tokenizer"),
"deepseek_v4": ("deepseek_v4", "DeepseekV4Tokenizer"),
"grok2": ("grok2", "Grok2Tokenizer"),
"hf": ("hf", "CachedHfTokenizer"),
"kimi_audio": ("kimi_audio", "KimiAudioTokenizer"),
......
......@@ -34,6 +34,10 @@ _TOOL_PARSERS_TO_REGISTER = {
"deepseekv32_tool_parser",
"DeepSeekV32ToolParser",
),
"deepseek_v4": (
"deepseekv4_tool_parser",
"DeepSeekV4ToolParser",
),
"ernie45": (
"ernie45_tool_parser",
"Ernie45ToolParser",
......
......@@ -46,21 +46,24 @@ class DeepSeekV32ToolParser(ToolParser):
</|DSML|function_calls>
"""
tool_call_start_token: str = "<|DSML|function_calls>"
tool_call_end_token: str = "</|DSML|function_calls>"
def __init__(self, tokenizer: TokenizerLike, tools: list[Tool] | None = None):
super().__init__(tokenizer, tools)
self.prev_tool_call_arr: list[dict] = []
# Sentinel token
self.tool_call_start_token: str = "<|DSML|function_calls>"
# Streaming state
self.current_tool_index: int = 0
self._sent_content_idx: int = 0
# Regex patterns for complete parsing
self.tool_call_complete_regex = re.compile(
r"<|DSML|function_calls>(.*?)</|DSML|function_calls>", re.DOTALL
re.escape(self.tool_call_start_token)
+ r"(.*?)"
+ re.escape(self.tool_call_end_token),
re.DOTALL,
)
self.invoke_complete_regex = re.compile(
r'<|DSML|invoke\s+name="([^"]+)"\s*>(.*?)</|DSML|invoke>', re.DOTALL
......@@ -86,7 +89,7 @@ class DeepSeekV32ToolParser(ToolParser):
request = super().adjust_request(request)
if request.tools and request.tool_choice != "none":
# Ensure tool call tokens
# (<|DSML|function_calls>, </|DSML|function_calls>)
# (e.g. <|DSML|function_calls>, </|DSML|function_calls>)
# are not skipped during decoding.
# Even though they are not marked as special tokens,
# setting skip_special_tokens=False ensures proper handling in
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from vllm.tool_parsers.deepseekv32_tool_parser import DeepSeekV32ToolParser
class DeepSeekV4ToolParser(DeepSeekV32ToolParser):
    """
    Tool parser for DeepSeek V4 DSML output.

    All parsing behavior is inherited from ``DeepSeekV32ToolParser``; only the
    two sentinel strings that open and close a tool-call block are overridden,
    because V4 emits ``<|DSML|tool_calls>`` where V3.2 emits
    ``<|DSML|function_calls>``.
    """

    # Opening sentinel of a V4 tool-call block.
    tool_call_start_token: str = "<|DSML|tool_calls>"
    # Closing sentinel of a V4 tool-call block.
    tool_call_end_token: str = "</|DSML|tool_calls>"
......@@ -89,6 +89,7 @@ _CONFIG_REGISTRY: dict[str, type[PretrainedConfig]] = LazyConfigDict(
qwen3_vl_nemotron_embed="Qwen3VLNemotronEmbedConfig",
deepseek_vl_v2="DeepseekVLV2Config",
deepseek_v32="DeepseekV3Config",
deepseek_v4="DeepseekV4Config",
flex_olmo="FlexOlmoConfig",
fireredlid="FireRedLIDConfig",
funaudiochat="FunAudioChatConfig",
......
......@@ -26,6 +26,7 @@ _CLASS_TO_MODULE: dict[str, str] = {
"OpsColQwen3Config": "vllm.transformers_utils.configs.colqwen3",
"Qwen3VLNemotronEmbedConfig": "vllm.transformers_utils.configs.colqwen3",
"DeepseekVLV2Config": "vllm.transformers_utils.configs.deepseek_vl2",
"DeepseekV4Config": "vllm.transformers_utils.configs.deepseek_v4",
"DotsOCRConfig": "vllm.transformers_utils.configs.dotsocr",
"EAGLEConfig": "vllm.transformers_utils.configs.eagle",
"FireRedLIDConfig": "vllm.transformers_utils.configs.fireredlid",
......@@ -88,6 +89,7 @@ __all__ = [
"Qwen3VLNemotronEmbedConfig",
"DeepseekVLV2Config",
"DeepseekV3Config",
"DeepseekV4Config",
"DotsOCRConfig",
"EAGLEConfig",
"FlexOlmoConfig",
......
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
from typing import Any
from transformers import PretrainedConfig
class DeepseekV4Config(PretrainedConfig):
    """
    Configuration class registered under the ``deepseek_v4`` model type.

    Stores the position-embedding limit and the RoPE settings as attributes,
    then forwards every remaining keyword argument to ``PretrainedConfig``.

    Args:
        max_position_embeddings: Stored as ``self.max_position_embeddings``.
        rope_scaling: Stored as ``self.rope_scaling``; when truthy it is also
            used as ``self.rope_parameters``.
        rope_parameters: Used as ``self.rope_parameters`` only when
            ``rope_scaling`` is falsy (``None`` or empty).
        rope_theta: Stored as ``self.rope_theta``.
        **kwargs: Passed through to ``PretrainedConfig.__init__``.
    """

    model_type = "deepseek_v4"

    def __init__(
        self,
        max_position_embeddings: int = 1048576,
        rope_scaling: dict[str, Any] | None = None,
        rope_parameters: dict[str, Any] | None = None,
        rope_theta: float = 10000.0,
        **kwargs,
    ):
        self.max_position_embeddings = max_position_embeddings
        self.rope_scaling = rope_scaling
        self.rope_theta = rope_theta
        # A truthy ``rope_scaling`` takes precedence over ``rope_parameters``.
        self.rope_parameters = rope_scaling if rope_scaling else rope_parameters
        super().__init__(**kwargs)
......@@ -47,6 +47,9 @@ class ModelArchConfigConvertorBase:
def get_head_size(self) -> int:
if self.is_deepseek_mla():
# special case for deepseek_v4
if hasattr(self.hf_text_config, "compress_ratios"):
return self.hf_text_config.head_dim
qk_rope_head_dim = getattr(self.hf_text_config, "qk_rope_head_dim", 0)
if not envs.VLLM_MLA_DISABLE:
return self.hf_text_config.kv_lora_rank + qk_rope_head_dim
......@@ -222,6 +225,7 @@ class ModelArchConfigConvertorBase:
"deepseek_v2",
"deepseek_v3",
"deepseek_v32",
"deepseek_v4",
"deepseek_mtp",
"glm_moe_dsa",
"glm4_moe_lite",
......@@ -233,7 +237,11 @@ class ModelArchConfigConvertorBase:
"pangu_ultra_moe_mtp",
"bailing_hybrid",
):
return getattr(self.hf_text_config, "kv_lora_rank", None) is not None
# check is deepseek_v4 model
if hasattr(self.hf_text_config, "compress_ratios"):
return getattr(self.hf_text_config, "head_dim", None) is not None
else:
return getattr(self.hf_text_config, "kv_lora_rank", None) is not None
elif self.hf_text_config.model_type == "eagle":
# if the model is an EAGLE module, check for the
# underlying architecture
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment