# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import typing from collections.abc import Callable, Iterable from itertools import islice import regex as re import torch import torch.nn as nn import torch.nn.functional as F from vllm.compilation.decorators import support_torch_compile from vllm.config import VllmConfig, get_current_vllm_config from vllm.distributed import ( get_ep_group, get_tensor_model_parallel_rank, get_tensor_model_parallel_world_size, ) from vllm.forward_context import get_forward_context from vllm.model_executor.layers.activation import SiluAndMul, SiluAndMulWithClamp from vllm.model_executor.layers.deepseek_v4_attention import ( DeepseekV4Indexer, DeepseekV4MLAModules, DeepseekV4MultiHeadLatentAttentionWrapper, ) from vllm.model_executor.layers.fused_moe import FusedMoE, GateLinear from vllm.model_executor.layers.fused_moe.layer import UnquantizedFusedMoEMethod from vllm.model_executor.layers.fused_moe.router.fused_topk_bias_router import ( fused_topk_bias, ) from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( ColumnParallelLinear, MergedColumnParallelLinear, RowParallelLinear, ) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import ( QuantizationConfig, QuantizationMethods, ) from vllm.model_executor.layers.quantization.fp8 import Fp8Config from vllm.model_executor.layers.quantization.mxfp4 import Mxfp4MoEMethod from vllm.model_executor.layers.quantization.utils.quant_utils import ( is_layer_skipped, ) from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) from vllm.model_executor.model_loader.weight_utils import default_weight_loader from vllm.model_executor.utils import set_weight_attrs from vllm.platforms import current_platform from vllm.sequence import IntermediateTensors from vllm.triton_utils import tl, triton from vllm.utils.multi_stream_utils import AuxStreamType from vllm.utils.torch_utils import direct_register_custom_op from .utils import ( AutoWeightsLoader, WeightsMapper, extract_layer_index, make_layers, maybe_prefix, ) _DEEPSEEK_V4_EXPERT_DTYPES = ("fp4", "fp8") class DeepseekV4MLP(nn.Module): def __init__( self, hidden_size: int, intermediate_size: int, hidden_act: str, swiglu_limit: float | None = None, quant_config: QuantizationConfig | None = None, reduce_results: bool = True, is_sequence_parallel: bool = False, prefix: str = "", ) -> None: super().__init__() # If is_sequence_parallel, the input and output tensors are sharded # across the ranks within the tp_group. In this case the weights are # replicated and no collective ops are needed. # Otherwise we use standard TP with an allreduce at the end. self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config, disable_tp=is_sequence_parallel, prefix=f"{prefix}.gate_up_proj", ) self.down_proj = RowParallelLinear( intermediate_size, hidden_size, bias=False, quant_config=quant_config, reduce_results=reduce_results, disable_tp=is_sequence_parallel, prefix=f"{prefix}.down_proj", ) if hidden_act != "silu": raise ValueError( f"Unsupported activation: {hidden_act}. Only silu is supported for now." ) if swiglu_limit is not None: self.act_fn = SiluAndMulWithClamp(swiglu_limit) else: self.act_fn = SiluAndMul() def forward(self, x): gate_up, _ = self.gate_up_proj(x) x = self.act_fn(gate_up) x, _ = self.down_proj(x) return x class DeepseekV4FP8Config(Fp8Config): """FP8 config for DeepSeek V4 with expert-dtype-aware MoE dispatch. DeepSeek V4 checkpoints always use FP8 block quantization for linear/attention layers. The MoE expert weights vary by checkpoint: - ``expert_dtype="fp4"`` (e.g. DeepSeek-V4-Flash): MXFP4 experts with ue8m0 (e8m0fnu) FP8 linear scales. - ``expert_dtype="fp8"`` (e.g. DeepSeek-V4-Flash-Base): FP8 block experts with float32 FP8 linear scales. The dispatch and the linear scale dtype are both keyed off ``expert_dtype`` from the model's hf_config; missing values default to ``"fp4"`` so existing FP4 checkpoints stay unchanged. NOTE: ``expert_dtype`` is resolved lazily because this config is constructed during VllmConfig setup, before ``set_current_vllm_config`` is active. Reading hf_config eagerly in ``__init__`` would always see the default ``"fp4"`` and silently misroute Flash-Base checkpoints. """ def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self._resolved_expert_dtype: str | None = None # ``is_scale_e8m0`` is a property that resolves on first read, # by which time the current vllm_config has been set. @property def expert_dtype(self) -> str: if self._resolved_expert_dtype is None: try: hf_config = get_current_vllm_config().model_config.hf_config except Exception: # vllm_config not yet set; defer the decision until a # later call lands inside set_current_vllm_config. return "fp4" expert_dtype = getattr(hf_config, "expert_dtype", "fp4") if expert_dtype not in _DEEPSEEK_V4_EXPERT_DTYPES: raise ValueError( f"Unsupported DeepSeek V4 expert_dtype={expert_dtype!r}; " f"expected one of {_DEEPSEEK_V4_EXPERT_DTYPES}." ) self._resolved_expert_dtype = expert_dtype from vllm.logger import init_logger init_logger(__name__).info_once( "DeepSeek V4 expert_dtype resolved to %r", expert_dtype ) return self._resolved_expert_dtype @property def is_scale_e8m0(self) -> bool: # FP4 checkpoints store FP8 linear scales as e8m0fnu; FP8 expert # checkpoints (Flash-Base) store them as float32. return self.expert_dtype == "fp4" @classmethod def get_name(cls) -> QuantizationMethods: return "deepseek_v4_fp8" @classmethod def override_quantization_method( cls, hf_quant_cfg, user_quant, hf_config=None ) -> QuantizationMethods | None: if not ( isinstance(hf_quant_cfg, dict) and hf_quant_cfg.get("quant_method") in ("fp8", "deepseek_v4_fp8") ): return None model_type = getattr(hf_config, "model_type", None) if model_type == "deepseek_v4" or user_quant == "deepseek_v4_fp8": return "deepseek_v4_fp8" return None def get_quant_method(self, layer, prefix): if isinstance(layer, FusedMoE): if is_layer_skipped( prefix=prefix, ignored_layers=self.ignored_layers, fused_mapping=self.packed_modules_mapping, ): return UnquantizedFusedMoEMethod(layer.moe_config) if self.expert_dtype == "fp4": return Mxfp4MoEMethod(layer.moe_config) # expert_dtype == "fp8": fall through to Fp8Config which # returns Fp8MoEMethod with block-wise float32 scales. return super().get_quant_method(layer, prefix) def is_mxfp4_quant(self, prefix, layer): return isinstance(layer, FusedMoE) and self.expert_dtype == "fp4" @triton.jit def _deepseek_v4_stage_mega_moe_inputs_kernel( hidden_states, x_fp8, x_sf, topk_ids, topk_weights, topk_idx_out, topk_weights_out, hidden_stride_m: tl.constexpr, hidden_stride_k: tl.constexpr, x_stride_m: tl.constexpr, x_stride_k: tl.constexpr, x_sf_stride_m: tl.constexpr, x_sf_stride_k: tl.constexpr, topk_ids_stride_m: tl.constexpr, topk_ids_stride_k: tl.constexpr, topk_weights_stride_m: tl.constexpr, topk_weights_stride_k: tl.constexpr, topk_idx_stride_m: tl.constexpr, topk_idx_stride_k: tl.constexpr, topk_weights_out_stride_m: tl.constexpr, topk_weights_out_stride_k: tl.constexpr, hidden_size: tl.constexpr, top_k: tl.constexpr, BLOCK_K: tl.constexpr, GROUP_K: tl.constexpr, BLOCK_TOPK: tl.constexpr, ) -> None: token_id = tl.program_id(0) k_block_id = tl.program_id(1) k_offsets = k_block_id * BLOCK_K + tl.arange(0, BLOCK_K) k_mask = k_offsets < hidden_size hidden = tl.load( hidden_states + token_id * hidden_stride_m + k_offsets * hidden_stride_k, mask=k_mask, other=0.0, ).to(tl.float32) num_groups: tl.constexpr = BLOCK_K // GROUP_K hidden_groups = tl.reshape(tl.abs(hidden), [num_groups, GROUP_K]) amax = tl.max(hidden_groups, axis=1) amax = tl.maximum(amax, 1.0e-4) scale = amax / 448.0 scale_bits = scale.to(tl.uint32, bitcast=True) scale_exp = ((scale_bits >> 23) & 0xFF) + ((scale_bits & 0x7FFFFF) != 0).to( tl.uint32 ) scale_exp = tl.minimum(tl.maximum(scale_exp, 1), 254) rounded_scale = (scale_exp << 23).to(tl.float32, bitcast=True) hidden_groups = tl.reshape(hidden, [num_groups, GROUP_K]) scaled = hidden_groups * (1.0 / rounded_scale)[:, None] scaled = tl.reshape(scaled, [BLOCK_K]) fp8 = scaled.to(tl.float8e4nv) tl.store( x_fp8 + token_id * x_stride_m + k_offsets * x_stride_k, fp8, mask=k_mask, ) scale_offsets = tl.arange(0, num_groups) packed_scale = tl.sum(scale_exp << (scale_offsets * 8), axis=0).to(tl.int32) tl.store( x_sf + token_id * x_sf_stride_m + k_block_id * x_sf_stride_k, packed_scale, ) if k_block_id == 0: topk_offsets = tl.arange(0, BLOCK_TOPK) topk_mask = topk_offsets < top_k ids = tl.load( topk_ids + token_id * topk_ids_stride_m + topk_offsets * topk_ids_stride_k, mask=topk_mask, other=0, ).to(tl.int64) tl.store( topk_idx_out + token_id * topk_idx_stride_m + topk_offsets * topk_idx_stride_k, ids, mask=topk_mask, ) weights = tl.load( topk_weights + token_id * topk_weights_stride_m + topk_offsets * topk_weights_stride_k, mask=topk_mask, other=0.0, ) tl.store( topk_weights_out + token_id * topk_weights_out_stride_m + topk_offsets * topk_weights_out_stride_k, weights, mask=topk_mask, ) def _stage_deepseek_v4_mega_moe_inputs( hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, x_fp8: torch.Tensor, x_sf: torch.Tensor, topk_idx_out: torch.Tensor, topk_weights_out: torch.Tensor, ) -> None: num_tokens, hidden_size = hidden_states.shape if num_tokens == 0: return if hidden_size % 128 != 0: raise ValueError( "DeepSeek V4 MegaMoE input staging requires hidden_size to be " "a multiple of 128." ) top_k = topk_ids.shape[1] if topk_weights.shape != topk_ids.shape: raise ValueError( "DeepSeek V4 MegaMoE input staging requires topk_weights and " "topk_ids to have the same shape." ) block_k = 128 grid = (num_tokens, triton.cdiv(hidden_size, block_k)) block_topk = triton.next_power_of_2(top_k) _deepseek_v4_stage_mega_moe_inputs_kernel[grid]( hidden_states, x_fp8, x_sf, topk_ids, topk_weights, topk_idx_out, topk_weights_out, hidden_states.stride(0), hidden_states.stride(1), x_fp8.stride(0), x_fp8.stride(1), x_sf.stride(0), x_sf.stride(1), topk_ids.stride(0), topk_ids.stride(1), topk_weights.stride(0), topk_weights.stride(1), topk_idx_out.stride(0), topk_idx_out.stride(1), topk_weights_out.stride(0), topk_weights_out.stride(1), hidden_size, top_k, BLOCK_K=block_k, GROUP_K=32, BLOCK_TOPK=block_topk, num_warps=4, ) def make_deepseek_v4_expert_params_mapping( num_experts: int, ) -> list[tuple[str, str, int, str]]: return [ ( "experts.w13_" if shard_id in ("w1", "w3") else "experts.w2_", f"experts.{expert_id}.{weight_name}.", expert_id, shard_id, ) for expert_id in range(num_experts) for shard_id, weight_name in [ ("w1", "w1"), ("w2", "w2"), ("w3", "w3"), ] ] class DeepseekV4MegaMoEExperts(nn.Module): _symm_buffer_cache: dict[tuple[int, int, int, int, int, int, int], object] = {} def __init__( self, vllm_config: VllmConfig, *, num_experts: int, num_local_experts: int, experts_start_idx: int, top_k: int, hidden_size: int, intermediate_size: int, prefix: str = "", ): super().__init__() self.prefix = prefix self.num_experts = num_experts self.num_local_experts = num_local_experts self.experts_start_idx = experts_start_idx self.experts_end_idx = experts_start_idx + num_local_experts self.top_k = top_k self.hidden_size = hidden_size self.intermediate_size = intermediate_size self.max_num_tokens = vllm_config.scheduler_config.max_num_batched_tokens weight_attrs = {"weight_loader": self.weight_loader} self.w13_weight = nn.Parameter( torch.zeros( num_local_experts, 2 * intermediate_size, hidden_size // 2, dtype=torch.uint8, ), requires_grad=False, ) set_weight_attrs(self.w13_weight, weight_attrs) self.w13_weight_scale = nn.Parameter( torch.zeros( num_local_experts, 2 * intermediate_size, hidden_size // 32, dtype=torch.uint8, ), requires_grad=False, ) set_weight_attrs(self.w13_weight_scale, weight_attrs) self.w13_weight_scale.quant_method = "block" self.w2_weight = nn.Parameter( torch.zeros( num_local_experts, hidden_size, intermediate_size // 2, dtype=torch.uint8, ), requires_grad=False, ) set_weight_attrs(self.w2_weight, weight_attrs) self.w2_weight_scale = nn.Parameter( torch.zeros( num_local_experts, hidden_size, intermediate_size // 32, dtype=torch.uint8, ), requires_grad=False, ) set_weight_attrs(self.w2_weight_scale, weight_attrs) self.w2_weight_scale.quant_method = "block" self._transformed_l1_weights: tuple[torch.Tensor, torch.Tensor] | None = None self._transformed_l2_weights: tuple[torch.Tensor, torch.Tensor] | None = None # Register in the static forward context so the custom-op wrapper # can look up this module by name from within a torch.compile graph. compilation_config = vllm_config.compilation_config if prefix in compilation_config.static_forward_context: raise ValueError(f"Duplicate layer name: {prefix}") compilation_config.static_forward_context[prefix] = self def _map_global_expert_id(self, expert_id: int) -> int: if expert_id < self.experts_start_idx or expert_id >= self.experts_end_idx: return -1 return expert_id - self.experts_start_idx def weight_loader( self, param: nn.Parameter, loaded_weight: torch.Tensor, weight_name: str, shard_id: str, expert_id: int, return_success: bool = False, ) -> bool | None: local_expert_id = self._map_global_expert_id(expert_id) if local_expert_id == -1: return False if return_success else None expert_data = param.data[local_expert_id] if shard_id in ("w1", "w3"): if "w13_" not in weight_name: return False if return_success else None shard_offset = 0 if shard_id == "w1" else self.intermediate_size expert_data = expert_data.narrow(0, shard_offset, self.intermediate_size) elif shard_id == "w2": if "w2_" not in weight_name: return False if return_success else None else: raise ValueError(f"Unsupported expert shard id: {shard_id}") if expert_data.shape != loaded_weight.shape: raise ValueError( f"DeepSeek V4 MegaMoE expert weight shape mismatch for " f"{weight_name}: parameter shard {tuple(expert_data.shape)} " f"vs checkpoint {tuple(loaded_weight.shape)}" ) expert_data.copy_(loaded_weight) return True if return_success else None @staticmethod def _ue8m0_uint8_to_float(sf: torch.Tensor) -> torch.Tensor: return (sf.to(torch.int32) << 23).view(torch.float32) def _check_runtime_supported(self) -> None: if not torch.cuda.is_available(): raise NotImplementedError("DeepSeek V4 MegaMoE requires CUDA.") device = self.w13_weight.device if device.type != "cuda": raise NotImplementedError( "DeepSeek V4 MegaMoE expert weights must be loaded on CUDA." ) if torch.cuda.get_device_capability(device)[0] != 10: raise NotImplementedError("DeepGEMM MegaMoE requires SM100 GPUs.") if self.hidden_size % 128 != 0 or self.intermediate_size % 128 != 0: raise ValueError( "DeepGEMM MegaMoE requires hidden and intermediate sizes " "to be multiples of 128." ) def finalize_weights(self) -> None: if self._transformed_l1_weights is not None: return self._check_runtime_supported() import vllm.third_party.deep_gemm as deep_gemm w13_scale = deep_gemm.transform_sf_into_required_layout( self._ue8m0_uint8_to_float(self.w13_weight_scale.data).contiguous(), 2 * self.intermediate_size, self.hidden_size, (1, 32), self.num_local_experts, ) w2_scale = deep_gemm.transform_sf_into_required_layout( self._ue8m0_uint8_to_float(self.w2_weight_scale.data).contiguous(), self.hidden_size, self.intermediate_size, (1, 32), self.num_local_experts, ) self._transformed_l1_weights, self._transformed_l2_weights = ( deep_gemm.transform_weights_for_mega_moe( (self.w13_weight.data.view(torch.int8).contiguous(), w13_scale), (self.w2_weight.data.view(torch.int8).contiguous(), w2_scale), ) ) # Drop the original loader-side parameters: the MegaMoE kernels only # consume the transformed views above. transform_weights_for_mega_moe # allocates a fresh tensor for the L1 weight (see _interleave_l1_weights) # and fresh SF tensors for L1/L2; the L2 weight is the only tensor that # aliases the original storage, and _transformed_l2_weights still holds # it, so the storage stays live after we drop the Parameter. self.w13_weight = None self.w13_weight_scale = None self.w2_weight = None self.w2_weight_scale = None def get_symm_buffer(self): import vllm.third_party.deep_gemm as deep_gemm group = get_ep_group().device_group device = torch.accelerator.current_device_index() key = ( id(group), device, self.num_experts, self.max_num_tokens, self.top_k, self.hidden_size, self.intermediate_size, ) symm_buffer = self._symm_buffer_cache.get(key) if symm_buffer is None: symm_buffer = deep_gemm.get_symm_buffer_for_mega_moe( group, self.num_experts, self.max_num_tokens, self.top_k, self.hidden_size, self.intermediate_size, ) self._symm_buffer_cache[key] = symm_buffer return symm_buffer def forward( self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, *, activation_clamp: float | None, fast_math: bool = True, ) -> torch.Tensor: if hidden_states.shape[0] > self.max_num_tokens: raise ValueError( f"DeepSeek V4 MegaMoE got {hidden_states.shape[0]} tokens, " f"but the symmetric buffer was sized for {self.max_num_tokens}." ) y = torch.empty_like(hidden_states, dtype=torch.bfloat16) torch.ops.vllm.deepseek_v4_mega_moe_experts( hidden_states, topk_weights, topk_ids, y, self.prefix, activation_clamp, fast_math, ) return y def _run_mega_moe( self, hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, y: torch.Tensor, activation_clamp: float | None, fast_math: bool, ) -> None: import vllm.third_party.deep_gemm as deep_gemm symm_buffer = self.get_symm_buffer() num_tokens = hidden_states.shape[0] _stage_deepseek_v4_mega_moe_inputs( hidden_states, topk_weights, topk_ids, symm_buffer.x[:num_tokens], symm_buffer.x_sf[:num_tokens], symm_buffer.topk_idx[:num_tokens], symm_buffer.topk_weights[:num_tokens], ) # This method must have been already called during the weight loading phase. # We call it again here to cover the dummy weight loading case. self.finalize_weights() assert self._transformed_l1_weights is not None assert self._transformed_l2_weights is not None deep_gemm.fp8_fp4_mega_moe( y, self._transformed_l1_weights, self._transformed_l2_weights, symm_buffer, activation_clamp=activation_clamp, fast_math=fast_math, ) DeepseekV4MegaMoEExperts.weight_loader.supports_moe_loading = True # type: ignore[attr-defined] def _deepseek_v4_mega_moe_experts_op( hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, out: torch.Tensor, layer_name: str, activation_clamp: float | None, fast_math: bool, ) -> None: self = get_forward_context().no_compile_layers[layer_name] self._run_mega_moe( hidden_states, topk_weights, topk_ids, out, activation_clamp, fast_math, ) def _deepseek_v4_mega_moe_experts_op_fake( hidden_states: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, out: torch.Tensor, layer_name: str, activation_clamp: float | None, fast_math: bool, ) -> None: return None direct_register_custom_op( op_name="deepseek_v4_mega_moe_experts", op_func=_deepseek_v4_mega_moe_experts_op, mutates_args=["out"], fake_impl=_deepseek_v4_mega_moe_experts_op_fake, ) class DeepseekV4MoE(nn.Module): def __init__( self, vllm_config: VllmConfig, prefix: str = "", ): super().__init__() self.tp_size = get_tensor_model_parallel_world_size() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.prefix = prefix if vllm_config.parallel_config.enable_expert_parallel: self.use_mega_moe = ( vllm_config.kernel_config.moe_backend == "deep_gemm_mega_moe" ) else: self.use_mega_moe = False self.routed_scaling_factor = getattr(config, "routed_scaling_factor", 1.0) self.hidden_size = config.hidden_size self.n_routed_experts = config.n_routed_experts self.n_activated_experts = config.num_experts_per_tok self.moe_intermediate_size = config.moe_intermediate_size self.swiglu_limit = config.swiglu_limit self.renormalize = config.norm_topk_prob self.scoring_func = getattr(config, "scoring_func", "sqrtsoftplus") if self.use_mega_moe and self.scoring_func != "sqrtsoftplus": raise NotImplementedError( "DeepSeek V4 MegaMoE currently supports sqrtsoftplus routing only." ) if self.use_mega_moe and getattr(config, "expert_dtype", "fp4") != "fp4": raise NotImplementedError( "DeepSeek V4 MegaMoE only supports fp4 experts; got expert_dtype=" f"{config.expert_dtype!r}. Drop --kernel-config moe_backend=" "deep_gemm_mega_moe for this checkpoint." ) self.gate = GateLinear( config.hidden_size, config.n_routed_experts, out_dtype=torch.float32, bias=False, prefix=f"{prefix}.gate", ) self.gate.e_score_correction_bias = None self.gate.tid2eid = None is_hash_moe = extract_layer_index(prefix) < config.num_hash_layers self.hash_indices_dtype = torch.int64 if self.use_mega_moe else torch.int32 if is_hash_moe: # hash MoE doesn't use e_score_correction_bias # Use randint instead of empty to avoid garbage values causing # invalid memory access in dummy mode (--load-format="dummy") self.gate.tid2eid = nn.Parameter( torch.randint( 0, config.n_routed_experts, (config.vocab_size, config.num_experts_per_tok), dtype=self.hash_indices_dtype, ), requires_grad=False, ) elif getattr(config, "topk_method", None) == "noaux_tc": self.gate.e_score_correction_bias = nn.Parameter( torch.empty(config.n_routed_experts, dtype=torch.float32), requires_grad=False, ) if config.n_shared_experts is None: self.shared_experts = None else: intermediate_size = config.moe_intermediate_size * config.n_shared_experts self.shared_experts = DeepseekV4MLP( hidden_size=config.hidden_size, intermediate_size=intermediate_size, hidden_act=config.hidden_act, swiglu_limit=self.swiglu_limit, quant_config=quant_config, reduce_results=self.use_mega_moe, prefix=f"{prefix}.shared_experts", ) if self.use_mega_moe: self._init_mega_moe_experts(vllm_config, config, prefix) else: self._init_fused_moe_experts(config, quant_config, prefix) def _init_mega_moe_experts( self, vllm_config: VllmConfig, config, prefix: str, ) -> None: self.ep_group = get_ep_group() self.ep_size = self.ep_group.world_size self.ep_rank = self.ep_group.rank_in_group assert config.n_routed_experts % self.ep_size == 0 self.n_local_experts = config.n_routed_experts // self.ep_size self.experts_start_idx = self.ep_rank * self.n_local_experts self.experts_end_idx = self.experts_start_idx + self.n_local_experts self.experts = DeepseekV4MegaMoEExperts( vllm_config, num_experts=config.n_routed_experts, num_local_experts=self.n_local_experts, experts_start_idx=self.experts_start_idx, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, prefix=f"{prefix}.experts", ) def _init_fused_moe_experts( self, config, quant_config, prefix: str, ) -> None: self.tp_rank = get_tensor_model_parallel_rank() assert config.n_routed_experts % self.tp_size == 0 self.n_local_experts = config.n_routed_experts // self.tp_size self.experts_start_idx = self.tp_rank * self.n_local_experts self.experts_end_idx = self.experts_start_idx + self.n_local_experts self.experts = FusedMoE( shared_experts=self.shared_experts, gate=self.gate, num_experts=config.n_routed_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, renormalize=config.norm_topk_prob, quant_config=quant_config, prefix=f"{prefix}.experts", scoring_func=self.scoring_func, routed_scaling_factor=self.routed_scaling_factor, e_score_correction_bias=self.gate.e_score_correction_bias, hash_indices_table=self.gate.tid2eid, swiglu_limit=self.swiglu_limit, router_logits_dtype=torch.float32, ) def forward( self, hidden_states: torch.Tensor, input_ids: torch.Tensor | None = None ) -> torch.Tensor: if self.gate.tid2eid is not None: if input_ids is None: raise ValueError("DeepSeek V4 hash MoE routing requires input_ids.") input_ids = input_ids.to(dtype=self.hash_indices_dtype) if not self.use_mega_moe: return self._forward_fused_moe(hidden_states, input_ids) org_shape = hidden_states.shape router_logits, _ = self.gate(hidden_states) topk_weights, topk_ids = fused_topk_bias( hidden_states=hidden_states, gating_output=router_logits, scoring_func=self.scoring_func, e_score_correction_bias=self.gate.e_score_correction_bias.data if self.gate.e_score_correction_bias is not None else None, topk=self.n_activated_experts, renormalize=self.renormalize, indices_type=self.hash_indices_dtype, input_tokens=input_ids, hash_indices_table=self.gate.tid2eid, routed_scaling_factor=self.routed_scaling_factor, ) activation_clamp = ( float(self.swiglu_limit) if self.swiglu_limit is not None else None ) final_hidden_states = self.experts( hidden_states, topk_weights, topk_ids, activation_clamp=activation_clamp, ) if self.shared_experts is not None: shared_output = self.shared_experts(hidden_states) final_hidden_states += shared_output return final_hidden_states.view(org_shape) def _forward_fused_moe( self, hidden_states: torch.Tensor, input_ids: torch.Tensor | None = None ) -> torch.Tensor: org_shape = hidden_states.shape if self.experts.is_internal_router: # In this case, the gate/router runs inside the FusedMoE class final_hidden_states = self.experts( hidden_states=hidden_states, router_logits=hidden_states, input_ids=input_ids, ) else: router_logits, _ = self.gate(hidden_states) final_hidden_states = self.experts( hidden_states=hidden_states, router_logits=router_logits, input_ids=input_ids, ) return final_hidden_states.view(org_shape) def finalize_mega_moe_weights(self) -> None: if self.use_mega_moe: self.experts.finalize_weights() class DeepseekV4Attention(nn.Module): def __init__( self, vllm_config: VllmConfig, prefix: str, topk_indices_buffer: torch.Tensor | None = None, aux_stream: torch.cuda.Stream | None = None, ): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config layer_id = extract_layer_index(prefix) self.layer_id = layer_id self.hidden_size = config.hidden_size self.n_heads = config.num_attention_heads tp_size = get_tensor_model_parallel_world_size() assert self.n_heads % tp_size == 0 self.n_local_heads = self.n_heads // tp_size self.q_lora_rank = config.q_lora_rank self.o_lora_rank = config.o_lora_rank self.head_dim = config.head_dim self.rope_head_dim = config.qk_rope_head_dim self.nope_head_dim = self.head_dim - self.rope_head_dim self.n_groups = config.o_groups self.n_local_groups = self.n_groups // tp_size self.window_size = config.sliding_window # NOTE(zyongye) Compress ratio can't be 0 # we do this for because MTP layer is not included # in the compress ratio list if layer_id < config.num_hidden_layers: self.compress_ratio = max(1, config.compress_ratios[layer_id]) else: self.compress_ratio = 1 self.eps = config.rms_norm_eps self.max_position_embeddings = config.max_position_embeddings # Padded to min 64 heads for FlashMLA, initialized to -inf # (no sink effect). Weight loading fills the first n_local_heads slots. padded_heads = max(self.n_local_heads, 64) self.attn_sink = nn.Parameter( torch.full((padded_heads,), -float("inf"), dtype=torch.float32), requires_grad=False, ) self.fused_wqa_wkv = MergedColumnParallelLinear( self.hidden_size, [self.q_lora_rank, self.head_dim], bias=False, quant_config=quant_config, prefix=f"{prefix}.fused_wqa_wkv", disable_tp=True, # fused ReplicatedLinear ) self.q_norm = RMSNorm(self.q_lora_rank, self.eps) self.wq_b = ColumnParallelLinear( self.q_lora_rank, self.n_heads * self.head_dim, bias=False, quant_config=quant_config, return_bias=False, prefix=f"{prefix}.wq_b", ) self.kv_norm = RMSNorm(self.head_dim, self.eps) self.wo_a = ColumnParallelLinear( self.n_heads * self.head_dim // self.n_groups, self.n_groups * self.o_lora_rank, bias=False, quant_config=quant_config, return_bias=False, prefix=f"{prefix}.wo_a", ) self.wo_a.is_bmm = True self.wo_a.bmm_batch_size = self.n_local_groups self.wo_b = RowParallelLinear( self.n_groups * self.o_lora_rank, self.hidden_size, bias=False, quant_config=quant_config, return_bias=False, prefix=f"{prefix}.wo_b", ) self.softmax_scale = self.head_dim**-0.5 self.scale_fmt = config.quantization_config["scale_fmt"] self.rope_parameters = config.rope_scaling # Initialize rotary embedding BEFORE DeepseekV4MLAModules (which needs it) rope_parameters = config.rope_parameters rope_parameters["rope_theta"] = ( config.compress_rope_theta if self.compress_ratio > 1 else config.rope_theta ) if config.rope_parameters["rope_type"] != "default": config.rope_parameters["rope_type"] = ( "deepseek_yarn" if config.rope_parameters.get("apply_yarn_scaling", True) else "deepseek_llama_scaling" ) rope_parameters["mscale"] = 0 # Disable mscale rope_parameters["mscale_all_dim"] = 0 # Disable mscale rope_parameters["is_deepseek_v4"] = True rope_parameters["rope_dim"] = self.rope_head_dim self.rotary_emb = get_rope( self.head_dim, max_position=self.max_position_embeddings, rope_parameters=rope_parameters, is_neox_style=False, dtype=config.torch_dtype, ) self.indexer = None if self.compress_ratio == 4: # Only C4A uses sparse attention and hence has indexer. self.indexer = DeepseekV4Indexer( vllm_config, config=config, hidden_size=self.hidden_size, q_lora_rank=self.q_lora_rank, quant_config=quant_config, cache_config=vllm_config.cache_config, topk_indices_buffer=topk_indices_buffer, compress_ratio=self.compress_ratio, prefix=f"{prefix}.indexer", ) mla_modules = DeepseekV4MLAModules( vllm_config=vllm_config, fused_wqa_wkv=self.fused_wqa_wkv, q_norm=self.q_norm, wq_b=self.wq_b, kv_norm=self.kv_norm, wo_a=self.wo_a, wo_b=self.wo_b, attn_sink=self.attn_sink, rotary_emb=self.rotary_emb, indexer=self.indexer, indexer_rotary_emb=self.rotary_emb, topk_indices_buffer=topk_indices_buffer, aux_stream=aux_stream, ) self.mla_attn = DeepseekV4MultiHeadLatentAttentionWrapper( hidden_size=self.hidden_size, num_heads=self.n_local_heads, head_dim=self.head_dim, scale=self.softmax_scale, qk_nope_head_dim=self.nope_head_dim, qk_rope_head_dim=self.rope_head_dim, v_head_dim=self.head_dim, q_lora_rank=self.q_lora_rank, kv_lora_rank=self.head_dim, o_lora_rank=self.o_lora_rank, mla_modules=mla_modules, window_size=self.window_size, compress_ratio=self.compress_ratio, cache_config=vllm_config.cache_config, quant_config=quant_config, prefix=prefix, ) def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, llama_4_scaling: torch.Tensor | None, ): return self.mla_attn(positions, hidden_states, llama_4_scaling) class DeepseekV4DecoderLayer(nn.Module): def __init__( self, vllm_config, prefix, topk_indices_buffer: torch.Tensor | None = None, aux_stream_dict: dict[AuxStreamType, torch.cuda.Stream] | None = None, ): super().__init__() config = vllm_config.model_config.hf_config self.hidden_size = config.hidden_size self.rms_norm_eps = config.rms_norm_eps self.attn = DeepseekV4Attention( vllm_config, prefix=f"{prefix}.attn", topk_indices_buffer=topk_indices_buffer, aux_stream=aux_stream_dict.get(AuxStreamType.Attention) if aux_stream_dict is not None else None, ) self.ffn = DeepseekV4MoE(vllm_config, prefix=f"{prefix}.ffn") self.attn_norm = RMSNorm(self.hidden_size, self.rms_norm_eps) self.ffn_norm = RMSNorm(self.hidden_size, self.rms_norm_eps) self.hc_mult = config.hc_mult self.hc_sinkhorn_iters = config.hc_sinkhorn_iters self.hc_eps = config.hc_eps self.hc_post_alpha = 2.0 mix_hc = (2 + self.hc_mult) * self.hc_mult hc_dim = self.hc_mult * self.hidden_size self.hc_attn_fn = nn.Parameter( torch.empty( (mix_hc, hc_dim), dtype=torch.float32, ), requires_grad=False, ) self.hc_ffn_fn = nn.Parameter( torch.empty( (mix_hc, hc_dim), dtype=torch.float32, ), requires_grad=False, ) self.hc_attn_base = nn.Parameter( torch.empty( mix_hc, dtype=torch.float32, ), requires_grad=False, ) self.hc_ffn_base = nn.Parameter( torch.empty( mix_hc, dtype=torch.float32, ), requires_grad=False, ) self.hc_attn_scale = nn.Parameter( torch.empty( 3, dtype=torch.float32, ), requires_grad=False, ) self.hc_ffn_scale = nn.Parameter( torch.empty( 3, dtype=torch.float32, ), requires_grad=False, ) def hc_pre( self, x: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor, ): # Lazy import to avoid top-level tilelang dependency. # Registers both torch.ops.vllm.mhc_pre and mhc_post, # so hc_post() doesn't need its own import. import vllm.model_executor.layers.mhc # noqa: F401 post_mix, res_mix, layer_input = torch.ops.vllm.mhc_pre( residual=x, fn=hc_fn, hc_scale=hc_scale, hc_base=hc_base, rms_eps=self.rms_norm_eps, hc_pre_eps=self.hc_eps, hc_sinkhorn_eps=self.hc_eps, hc_post_mult_value=self.hc_post_alpha, sinkhorn_repeat=self.hc_sinkhorn_iters, ) return layer_input, post_mix, res_mix def hc_post( self, x: torch.Tensor, residual: torch.Tensor, post: torch.Tensor, comb: torch.Tensor, ): return torch.ops.vllm.mhc_post(x, residual, post, comb) def forward( self, x: torch.Tensor, positions: torch.Tensor, input_ids: torch.Tensor | None, ) -> torch.Tensor: residual = x x, post, comb = self.hc_pre( x, self.hc_attn_fn, self.hc_attn_scale, self.hc_attn_base ) x = self.attn_norm(x) x = self.attn(positions, x, None) x = self.hc_post(x, residual, post, comb) residual = x x, post, comb = self.hc_pre( x, self.hc_ffn_fn, self.hc_ffn_scale, self.hc_ffn_base ) x = self.ffn_norm(x) x = self.ffn(x, input_ids) x = self.hc_post(x, residual, post, comb) return x @support_torch_compile class DeepseekV4Model(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config quant_config = vllm_config.quant_config self.config = config self.vocab_size = config.vocab_size self.hc_eps = config.hc_eps self.hc_mult = config.hc_mult self.hc_dim = self.hc_mult * config.hidden_size self.rms_norm_eps = config.rms_norm_eps aux_stream_list = [torch.cuda.Stream() for _ in range(1)] self.aux_stream_dict = { AuxStreamType.Attention: aux_stream_list[0], } self.device = current_platform.device_type # Reserved topk indices buffer for all Indexer layers to reuse. self.topk_indices_buffer = torch.empty( vllm_config.scheduler_config.max_num_batched_tokens, config.index_topk, dtype=torch.int32, device=self.device, ) self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, quant_config=quant_config, prefix=f"{prefix}.embed_tokens", ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: DeepseekV4DecoderLayer( vllm_config, prefix=prefix, topk_indices_buffer=self.topk_indices_buffer, aux_stream_dict=self.aux_stream_dict, ), prefix=f"{prefix}.layers", ) self.norm = RMSNorm(config.hidden_size, self.rms_norm_eps) self.hc_head_fn = nn.Parameter( torch.empty( self.hc_mult, self.hc_dim, dtype=torch.float32, ), requires_grad=False, ) self.hc_head_base = nn.Parameter( torch.empty( self.hc_mult, dtype=torch.float32, ), requires_grad=False, ) self.hc_head_scale = nn.Parameter( torch.empty(1, dtype=torch.float32), requires_grad=False, ) # Pre-hc_head residual stream buffer for the MTP draft. Stable # address (outside the cudagraph pool) so the copy_ in forward() # refreshes it correctly across captured shapes. self._mtp_hidden_buffer = torch.empty( vllm_config.scheduler_config.max_num_batched_tokens, self.hc_dim, dtype=vllm_config.model_config.dtype, device=self.device, ) def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, intermediate_tensors: IntermediateTensors | None, inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor | IntermediateTensors: hidden_states = self.embed_input_ids(input_ids) hidden_states = hidden_states.unsqueeze(-2).repeat(1, self.hc_mult, 1) for layer in islice(self.layers, self.start_layer, self.end_layer): hidden_states = layer( hidden_states, positions, input_ids, ) # Stash pre-hc_head residual for the MTP draft (captured copy_). num_tokens = hidden_states.shape[0] self._mtp_hidden_buffer[:num_tokens].copy_(hidden_states.flatten(1)) hidden_states = hc_head( hidden_states, self.hc_head_fn, self.hc_head_scale, self.hc_head_base, self.rms_norm_eps, self.hc_eps, ) hidden_states = self.norm(hidden_states) return hidden_states def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("gate_up_proj", "w1", 0), ("gate_up_proj", "w3", 1), ("attn.fused_wqa_wkv", "attn.wq_a", 0), ("attn.fused_wqa_wkv", "attn.wkv", 1), ("compressor.fused_wkv_wgate", "compressor.wkv", 0), ("compressor.fused_wkv_wgate", "compressor.wgate", 1), ] params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() # TP for attention tp_size = get_tensor_model_parallel_world_size() tp_rank = get_tensor_model_parallel_rank() n_head = self.config.num_attention_heads n_local_head = n_head // tp_size head_rank_start = n_local_head * tp_rank head_rank_end = n_local_head * (tp_rank + 1) # Pre-compute expert mapping ONCE. expert_mapping = self.get_expert_mapping() for name, loaded_weight in weights: for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if ".experts." in name: continue if weight_name not in name: continue name = name.replace(weight_name, param_name) param = params_dict[name] weight_loader = param.weight_loader weight_loader(param, loaded_weight, shard_id) loaded_params.add(name) break else: if ".experts." in name: # E8M0 scales are stored as float8_e8m0fnu in # checkpoints but the MoE param is uint8. copy_() # would do a numeric conversion (e.g. 2^-7 → 0), # destroying the raw exponent bytes. if ( "weight_scale" in name and loaded_weight.dtype == torch.float8_e8m0fnu ): loaded_weight = loaded_weight.view(torch.uint8) for mapping in expert_mapping: param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue name_mapped = name.replace(weight_name, param_name) param = params_dict[name_mapped] # We should ask the weight loader to return success or not # here since otherwise we may skip experts with other # available replicas. weight_loader = typing.cast( Callable[..., bool], param.weight_loader ) success = weight_loader( param, loaded_weight, name_mapped, shard_id=shard_id, expert_id=expert_id, return_success=True, ) if success: name = name_mapped break loaded_params.add(name_mapped) continue elif "attn_sink" in name: narrow_weight = loaded_weight[head_rank_start:head_rank_end] n = narrow_weight.shape[0] params_dict[name][:n].copy_(narrow_weight) loaded_params.add(name) continue else: param = params_dict[name] weight_loader = getattr( param, "weight_loader", default_weight_loader ) weight_loader(param, loaded_weight) loaded_params.add(name) continue return loaded_params def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: first_layer = next(iter(islice(self.layers, self.start_layer, self.end_layer))) if first_layer.ffn.use_mega_moe: return make_deepseek_v4_expert_params_mapping(self.config.n_routed_experts) # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) return FusedMoE.make_expert_params_mapping( self, ckpt_gate_proj_name="w1", ckpt_down_proj_name="w2", ckpt_up_proj_name="w3", num_experts=self.config.n_routed_experts, ) def finalize_mega_moe_weights(self) -> None: for layer in islice(self.layers, self.start_layer, self.end_layer): layer.ffn.finalize_mega_moe_weights() @torch.compile(backend=current_platform.simple_compile_backend) def hc_head( hidden_states: torch.Tensor, hc_fn: torch.Tensor, hc_scale: torch.Tensor, hc_base: torch.Tensor, rms_norm_eps: float, hc_eps: float, ) -> torch.Tensor: x = hidden_states shape, dtype = x.size(), x.dtype x = x.flatten(1).float() rsqrt = torch.rsqrt(x.square().mean(-1, keepdim=True) + rms_norm_eps) mixes = F.linear(x, hc_fn) * rsqrt pre = torch.sigmoid(mixes * hc_scale + hc_base) + hc_eps y = torch.sum(pre.unsqueeze(-1) * x.view(shape), dim=1) return y.to(dtype) def _make_deepseek_v4_weights_mapper(expert_dtype: str) -> WeightsMapper: if expert_dtype == "fp4": # MXFP4 experts use Mxfp4MoEMethod, which registers scales as # ``w{1,2,3}_weight_scale`` (no _inv suffix). FP8 linear and # shared experts use Fp8LinearMethod's block scales, which # register as ``weight_scale_inv``. scale_regex = { re.compile(r"(\.experts\.\d+\.w[123])\.scale$"): r"\1.weight_scale", re.compile(r"\.scale$"): ".weight_scale_inv", } else: # FP8 experts use Fp8MoEMethod (block_quant=True), which registers # scales as ``w{13,2}_weight_scale_inv``. Map all ``.scale`` keys # there. scale_regex = { re.compile(r"\.scale$"): ".weight_scale_inv", } return WeightsMapper( orig_to_new_prefix={ "layers.": "model.layers.", "embed.": "model.embed.", "norm.": "model.norm.", "hc_head": "model.hc_head", "mtp.": "model.mtp.", }, orig_to_new_regex=scale_regex, orig_to_new_suffix={ "head.weight": "lm_head.weight", "embed.weight": "embed_tokens.weight", ".ffn.gate.bias": ".ffn.gate.e_score_correction_bias", }, orig_to_new_substr={ ".attn.compressor.": ".attn.mla_attn.compressor.", ".shared_experts.w2": ".shared_experts.down_proj", }, ) class DeepseekV4ForCausalLM(nn.Module): model_cls = DeepseekV4Model # Default mapper assumes the original FP4-expert checkpoint layout. # Overridden per-instance in __init__ when expert_dtype != "fp4". hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper("fp4") def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_config self.config = config expert_dtype = getattr(config, "expert_dtype", "fp4") if expert_dtype != "fp4": self.hf_to_vllm_mapper = _make_deepseek_v4_weights_mapper(expert_dtype) self.model = self.model_cls( vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model") ) self.lm_head = ParallelLMHead( config.vocab_size, config.hidden_size, prefix=maybe_prefix(prefix, "lm_head"), ) self.logits_processor = LogitsProcessor(config.vocab_size) def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.model.embed_input_ids(input_ids) def compute_logits( self, hidden_states: torch.Tensor, ) -> torch.Tensor | None: logits = self.logits_processor(self.lm_head, hidden_states) return logits def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, intermediate_tensors: IntermediateTensors | None = None, inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor | IntermediateTensors: hidden_states = self.model( input_ids, positions, intermediate_tensors, inputs_embeds ) return hidden_states def get_mtp_target_hidden_states(self) -> torch.Tensor | None: """Pre-hc_head residual stream buffer (max_num_batched_tokens, hc_mult * hidden_size) for the MTP draft model. Populated by forward(); valid after each target step.""" return getattr(self.model, "_mtp_hidden_buffer", None) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: loader = AutoWeightsLoader(self, skip_substrs=["mtp."]) loaded_params = loader.load_weights(weights, mapper=self.hf_to_vllm_mapper) self.model.finalize_mega_moe_weights() return loaded_params def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: return self.model.get_expert_mapping()