# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project # Copyright 2024 The Qwen team. # Copyright 2023 The vLLM team. # Copyright 2022 EleutherAI and the HuggingFace Inc. team. All rights reserved. # # This code is based on EleutherAI's GPT-NeoX library and the GPT-NeoX # and OPT implementations in this library. It has been modified from its # original forms to accommodate minor architectural differences compared # to GPT-NeoX and OPT used by the Meta AI team that trained the model. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. 
"""Inference-only Qwen3MoE model compatible with HuggingFace weights.""" import typing from collections.abc import Callable, Iterable from itertools import islice from typing import Any import os import re import torch import torch.nn.functional as F from torch import nn from vllm.attention.layer import Attention from vllm.compilation.decorators import support_torch_compile from vllm.config import CacheConfig, VllmConfig, get_current_vllm_config from vllm.distributed import ( get_ep_group, get_pp_group, get_tensor_model_parallel_world_size, tensor_model_parallel_all_gather, ) from vllm.logger import init_logger from vllm.model_executor.layers.activation import SiluAndMul from vllm.model_executor.layers.fused_moe import SharedFusedMoE from vllm.model_executor.layers.layernorm import RMSNorm from vllm.model_executor.layers.linear import ( MergedColumnParallelLinear, QKVParallelLinear, ReplicatedLinear, RowParallelLinear, ) from vllm.model_executor.layers.logits_processor import LogitsProcessor from vllm.model_executor.layers.quantization import QuantizationConfig from vllm.model_executor.layers.rotary_embedding import get_rope from vllm.model_executor.layers.vocab_parallel_embedding import ( ParallelLMHead, VocabParallelEmbedding, ) from vllm.model_executor.model_loader.weight_utils import ( default_weight_loader, maybe_remap_kv_scale_name, ) from vllm.model_executor.models.utils import sequence_parallel_chunk from vllm.sequence import IntermediateTensors from .interfaces import MixtureOfExperts, SupportsEagle3, SupportsLoRA, SupportsPP from .utils import ( AutoWeightsLoader, PPMissingLayer, extract_layer_index, is_pp_missing_parameter, make_empty_intermediate_tensors_factory, make_layers, maybe_prefix, ) from vllm import envs from vllm import _custom_ops as ops from vllm.model_executor.utils import pad_weight, gemm_bank_conf from vllm.utils import W8a8GetCacheJSON from vllm.utils.torch_utils import direct_register_custom_op logger = init_logger(__name__) class 
Qwen3MoeMLP(nn.Module): def __init__( self, hidden_size: int, intermediate_size: int, hidden_act: str, quant_config: QuantizationConfig | None = None, reduce_results: bool = True, expert_gate: torch.nn.Linear | None = None, prefix: str = "", ) -> None: super().__init__() self.gate_up_proj = MergedColumnParallelLinear( hidden_size, [intermediate_size] * 2, bias=False, quant_config=quant_config, prefix=f"{prefix}.gate_up_proj", ) self.down_proj = RowParallelLinear( intermediate_size, hidden_size, bias=False, quant_config=quant_config, reduce_results=reduce_results, prefix=f"{prefix}.down_proj", ) if hidden_act != "silu": raise ValueError( f"Unsupported activation: {hidden_act}. Only silu is supported for now." ) self.act_fn = SiluAndMul() self.expert_gate = expert_gate def forward(self, x): gate_up, _ = self.gate_up_proj(x) out = self.act_fn(gate_up) out, _ = self.down_proj(out) if self.expert_gate is not None: out = F.sigmoid(self.expert_gate(x)[0]) * out return out class Qwen3MoeSparseMoeBlock(nn.Module): def __init__( self, vllm_config: VllmConfig, prefix: str = "", ): super().__init__() config = vllm_config.model_config.hf_text_config parallel_config = vllm_config.parallel_config quant_config = vllm_config.quant_config self.tp_size = get_tensor_model_parallel_world_size() self.ep_group = get_ep_group().device_group self.ep_rank = get_ep_group().rank_in_group self.ep_size = self.ep_group.size() self.n_routed_experts = config.num_experts self.is_sequence_parallel = parallel_config.use_sequence_parallel_moe if self.tp_size > config.num_experts: raise ValueError( f"Tensor parallel size {self.tp_size} is greater than " f"the number of experts {config.num_experts}." ) # Load balancing settings. 
vllm_config = get_current_vllm_config() eplb_config = vllm_config.parallel_config.eplb_config self.enable_eplb = parallel_config.enable_eplb self.n_logical_experts = self.n_routed_experts self.n_redundant_experts = eplb_config.num_redundant_experts self.n_physical_experts = self.n_logical_experts + self.n_redundant_experts self.n_local_physical_experts = self.n_physical_experts // self.ep_size self.physical_expert_start = self.ep_rank * self.n_local_physical_experts self.physical_expert_end = ( self.physical_expert_start + self.n_local_physical_experts ) self.gate = ReplicatedLinear( config.hidden_size, config.num_experts, bias=False, quant_config=quant_config, prefix=f"{prefix}.gate", ) shared_expert_intermediate_size = getattr( config, "shared_expert_intermediate_size", 0 ) if shared_expert_intermediate_size > 0: self.shared_expert_gate = ReplicatedLinear( config.hidden_size, 1, bias=False, quant_config=None, prefix=f"{prefix}.shared_expert_gate", ) self.shared_expert = Qwen3MoeMLP( hidden_size=config.hidden_size, intermediate_size=shared_expert_intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, reduce_results=False, expert_gate=self.shared_expert_gate, prefix=f"{prefix}.shared_expert", ) else: self.shared_expert_gate = None self.shared_expert = None self.experts = SharedFusedMoE( shared_experts=self.shared_expert, gate=self.gate, num_experts=self.n_routed_experts, top_k=config.num_experts_per_tok, hidden_size=config.hidden_size, intermediate_size=config.moe_intermediate_size, reduce_results=False, renormalize=config.norm_topk_prob, quant_config=quant_config, prefix=f"{prefix}.experts", enable_eplb=self.enable_eplb, num_redundant_experts=self.n_redundant_experts, is_sequence_parallel=self.is_sequence_parallel, ) def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: assert hidden_states.dim() <= 2, ( "Qwen3MoeSparseMoeBlock only supports 1D or 2D inputs" ) is_input_1d = hidden_states.dim() == 1 num_tokens, hidden_dim = 
hidden_states.shape hidden_states = hidden_states.view(-1, hidden_dim) if self.is_sequence_parallel: hidden_states = sequence_parallel_chunk(hidden_states) # router_logits: (num_tokens, n_experts) router_logits, _ = self.gate(hidden_states) shared_out, fused_out = self.experts( hidden_states=hidden_states, router_logits=router_logits ) final_hidden_states = ( shared_out + fused_out if shared_out is not None else fused_out ) if self.is_sequence_parallel: final_hidden_states = tensor_model_parallel_all_gather( final_hidden_states, 0 ) final_hidden_states = final_hidden_states[:num_tokens] elif self.tp_size > 1: final_hidden_states = self.experts.maybe_all_reduce_tensor_model_parallel( # noqa E501 final_hidden_states ) # return to 1d if input is 1d return final_hidden_states.squeeze(0) if is_input_1d else final_hidden_states class Qwen3MoeAttention(nn.Module): def __init__( self, hidden_size: int, num_heads: int, num_kv_heads: int, rope_parameters: dict[str, Any], max_position_embeddings: int = 8192, head_dim: int | None = None, rms_norm_eps: float = 1e-06, qkv_bias: bool = False, cache_config: CacheConfig | None = None, quant_config: QuantizationConfig | None = None, prefix: str = "", dual_chunk_attention_config: dict[str, Any] | None = None, ) -> None: super().__init__() self.hidden_size = hidden_size tp_size = get_tensor_model_parallel_world_size() self.total_num_heads = num_heads assert self.total_num_heads % tp_size == 0 self.num_heads = self.total_num_heads // tp_size self.total_num_kv_heads = num_kv_heads if self.total_num_kv_heads >= tp_size: # Number of KV heads is greater than TP size, so we partition # the KV heads across multiple tensor parallel GPUs. assert self.total_num_kv_heads % tp_size == 0 else: # Number of KV heads is less than TP size, so we replicate # the KV heads across multiple tensor parallel GPUs. 
assert tp_size % self.total_num_kv_heads == 0 self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size) self.head_dim = head_dim or (hidden_size // self.total_num_heads) self.q_size = self.num_heads * self.head_dim self.kv_size = self.num_kv_heads * self.head_dim self.scaling = self.head_dim**-0.5 self.max_position_embeddings = max_position_embeddings self.dual_chunk_attention_config = dual_chunk_attention_config self.qkv_proj = QKVParallelLinear( hidden_size, self.head_dim, self.total_num_heads, self.total_num_kv_heads, bias=qkv_bias, quant_config=quant_config, prefix=f"{prefix}.qkv_proj", ) self.o_proj = RowParallelLinear( self.total_num_heads * self.head_dim, hidden_size, bias=False, quant_config=quant_config, prefix=f"{prefix}.o_proj", ) self.rotary_emb = get_rope( self.head_dim, max_position=max_position_embeddings, rope_parameters=rope_parameters, dual_chunk_attention_config=dual_chunk_attention_config, ) self.attn = Attention( self.num_heads, self.head_dim, self.scaling, num_kv_heads=self.num_kv_heads, cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.attn", **{ "layer_idx": extract_layer_index(prefix), "dual_chunk_attention_config": dual_chunk_attention_config, } if dual_chunk_attention_config else {}, ) self.q_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) self.k_norm = RMSNorm(self.head_dim, eps=rms_norm_eps) def rms_rotary_embedding_fuse( positions: torch.Tensor, query: torch.Tensor, key: torch.Tensor | None, head_size: int, cos_sin_cache: torch.Tensor, is_neox_style: bool, q_weight: torch.Tensor, k_weight: torch.Tensor, epsilon: float, q_bias: torch.Tensor | None = None, k_bias: torch.Tensor | None = None, ) -> None: from lightop import rms_rotary_embedding_fuse as fused_kernel fused_kernel( positions, query, key, head_size, cos_sin_cache, is_neox_style, q_weight, k_weight, q_bias, k_bias, epsilon, ) def rms_rotary_embedding_fuse_fake( # q_out:torch.Tensor, # k_out:torch.Tensor, positions: torch.Tensor, query: torch.Tensor, 
key: torch.Tensor | None, head_size: int, cos_sin_cache: torch.Tensor, is_neox_style: bool, q_weight: torch.Tensor, k_weight: torch.Tensor, epsilon: float, q_bias: torch.Tensor | None = None, k_bias: torch.Tensor | None = None, ) -> None: # Fake impl intentionally left as no-op for graph tracing modes. pass direct_register_custom_op( op_name="rms_rotary_embedding_fuse", op_func=rms_rotary_embedding_fuse, mutates_args=["query", "key"], fake_impl=rms_rotary_embedding_fuse_fake, ) def rms_mrope_fuse( query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, head_size: int, rotary_dim: int, mrope_section_t: int, mrope_section_h: int, mrope_section_w: int, mrope_interleaved: bool, q_weight: torch.Tensor, k_weight: torch.Tensor, epsilon: float, q_residual: torch.Tensor | None = None, k_residual: torch.Tensor | None = None, ) -> None: from lightop import op as lightop_ops lightop_ops.fuse_rms_mrope_cuda( query, key, cos, sin, [mrope_section_t, mrope_section_h, mrope_section_w], head_size, rotary_dim, mrope_interleaved, q_weight, k_weight, epsilon, q_residual, k_residual, ) def rms_mrope_fuse_fake( query: torch.Tensor, key: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, head_size: int, rotary_dim: int, mrope_section_t: int, mrope_section_h: int, mrope_section_w: int, mrope_interleaved: bool, q_weight: torch.Tensor, k_weight: torch.Tensor, epsilon: float, q_residual: torch.Tensor | None = None, k_residual: torch.Tensor | None = None, ) -> None: # Fake impl intentionally left as no-op for graph tracing modes. 
pass direct_register_custom_op( op_name="rms_mrope_fuse", op_func=rms_mrope_fuse, mutates_args=["query", "key"], fake_impl=rms_mrope_fuse_fake, ) def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, ) -> torch.Tensor: qkv, _ = self.qkv_proj(hidden_states) q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1) # Add qk-norm if envs.VLLM_USE_FUSED_RMS_ROPE and positions.ndim == 1: # Fused RMSNorm + RoPE path through custom op. cos_sin_cache = self.rotary_emb.cos_sin_cache if (cos_sin_cache.device != q.device or cos_sin_cache.dtype != q.dtype): cos_sin_cache = cos_sin_cache.to(q.device, dtype=q.dtype, non_blocking=True) # Persist the converted cache so we don't re-copy/re-allocate # on every forward when the original buffer starts on CPU. self.rotary_emb.cos_sin_cache = cos_sin_cache # # q, k 使用 continuous q = q.contiguous() k = k.contiguous() torch.ops.vllm.rms_rotary_embedding_fuse( positions, q, k, self.head_dim, cos_sin_cache, self.rotary_emb.is_neox_style, self.q_norm.weight, self.k_norm.weight, self.q_norm.variance_epsilon, None, None, ) elif envs.VLLM_USE_FUSED_RMS_ROPE and positions.ndim == 2 and getattr( self.rotary_emb, "mrope_section", None) is not None: # Fused RMSNorm + M-RoPE path through custom op. 
cos_sin_cache = self.rotary_emb.cos_sin_cache if (cos_sin_cache.device != q.device or cos_sin_cache.dtype != q.dtype): cos_sin_cache = cos_sin_cache.to(q.device, dtype=q.dtype, non_blocking=True) self.rotary_emb.cos_sin_cache = cos_sin_cache cos_sin = cos_sin_cache[positions] cos, sin = cos_sin.chunk(2, dim=-1) q = q.contiguous() k = k.contiguous() cos = cos.contiguous() sin = sin.contiguous() mrope_section = self.rotary_emb.mrope_section assert mrope_section is not None and len(mrope_section) == 3 torch.ops.vllm.rms_mrope_fuse( q, k, cos, sin, self.head_dim, self.rotary_emb.rotary_dim, mrope_section[0], mrope_section[1], mrope_section[2], self.rotary_emb.mrope_interleaved, self.q_norm.weight, self.k_norm.weight, self.q_norm.variance_epsilon, None, None, ) else: q_by_head = q.view(*q.shape[:-1], q.shape[-1] // self.head_dim, self.head_dim) if envs.VLLM_USE_APEX_RN: q_by_head = self.q_norm.forward_apex(q_by_head) else: q_by_head = self.q_norm.forward_cuda(q_by_head) q = q_by_head.view(q.shape) k_by_head = k.view(*k.shape[:-1], k.shape[-1] // self.head_dim, self.head_dim) if envs.VLLM_USE_APEX_RN: k_by_head = self.k_norm.forward_apex(k_by_head) else: k_by_head = self.k_norm.forward_cuda(k_by_head) k = k_by_head.view(k.shape) q, k = self.rotary_emb(positions, q, k) attn_output = self.attn(q, k, v) output, _ = self.o_proj(attn_output) return output class Qwen3MoeDecoderLayer(nn.Module): def __init__(self, vllm_config: VllmConfig, prefix: str = "") -> None: super().__init__() config = vllm_config.model_config.hf_text_config cache_config = vllm_config.cache_config quant_config = vllm_config.quant_config self.hidden_size = config.hidden_size max_position_embeddings = getattr(config, "max_position_embeddings", 8192) dual_chunk_attention_config = getattr( config, "dual_chunk_attention_config", None ) self.self_attn = Qwen3MoeAttention( hidden_size=self.hidden_size, num_heads=config.num_attention_heads, num_kv_heads=config.num_key_value_heads, 
rope_parameters=config.rope_parameters, max_position_embeddings=max_position_embeddings, rms_norm_eps=config.rms_norm_eps, qkv_bias=getattr(config, "attention_bias", False), head_dim=getattr(config, "head_dim", None), cache_config=cache_config, quant_config=quant_config, prefix=f"{prefix}.self_attn", dual_chunk_attention_config=dual_chunk_attention_config, ) # `mlp_only_layers` in the config. layer_idx = extract_layer_index(prefix) mlp_only_layers = ( [] if not hasattr(config, "mlp_only_layers") else config.mlp_only_layers ) if (layer_idx not in mlp_only_layers) and ( config.num_experts > 0 and (layer_idx + 1) % config.decoder_sparse_step == 0 ): self.mlp = Qwen3MoeSparseMoeBlock( vllm_config=vllm_config, prefix=f"{prefix}.mlp" ) else: self.mlp = Qwen3MoeMLP( hidden_size=config.hidden_size, intermediate_size=config.intermediate_size, hidden_act=config.hidden_act, quant_config=quant_config, prefix=f"{prefix}.mlp", ) self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.post_attention_layernorm = RMSNorm( config.hidden_size, eps=config.rms_norm_eps ) def forward( self, positions: torch.Tensor, hidden_states: torch.Tensor, residual: torch.Tensor | None, ) -> tuple[torch.Tensor, torch.Tensor]: # Self Attention if residual is None: residual = hidden_states hidden_states = self.input_layernorm(hidden_states) else: hidden_states, residual = self.input_layernorm(hidden_states, residual) hidden_states = self.self_attn( positions=positions, hidden_states=hidden_states, ) # Fully Connected hidden_states, residual = self.post_attention_layernorm(hidden_states, residual) hidden_states = self.mlp(hidden_states) return hidden_states, residual @support_torch_compile( dynamic_arg_dims={ "input_ids": 0, # positions is of shape (3, seq_len) if mrope is enabled, # otherwise (seq_len, ). 
"positions": -1, "intermediate_tensors": 0, "inputs_embeds": 0, }) class Qwen3MoeModel(nn.Module): def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""): super().__init__() config = vllm_config.model_config.hf_text_config quant_config = vllm_config.quant_config parallel_config = vllm_config.parallel_config eplb_config = parallel_config.eplb_config self.num_redundant_experts = eplb_config.num_redundant_experts self.padding_idx = config.pad_token_id self.vocab_size = config.vocab_size self.config = config self.quant_config = quant_config self.quant_method = None if quant_config is not None: self.quant_method=quant_config.get_name() self.quant_config=quant_config # if self.config.quantization_config["bits"] == 4: os.environ['LLAMA_NN'] = '0' os.environ['LM_NN'] = '0' self.embed_tokens = VocabParallelEmbedding( config.vocab_size, config.hidden_size, quant_config=quant_config, prefix=f"{prefix}.embed_tokens", ) self.start_layer, self.end_layer, self.layers = make_layers( config.num_hidden_layers, lambda prefix: Qwen3MoeDecoderLayer(vllm_config=vllm_config, prefix=prefix), prefix=f"{prefix}.layers", ) self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps) self.make_empty_intermediate_tensors = make_empty_intermediate_tensors_factory( ["hidden_states", "residual"], config.hidden_size ) # Track layers for auxiliary hidden state outputs (EAGLE3) self.aux_hidden_state_layers: tuple[int, ...] 
= () self.tritonsingleton= W8a8GetCacheJSON() self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_awq_pad = os.environ.get('AWQ_PAD') == '1' self.w8a8_strategy = envs.VLLM_W8A8_BACKEND def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor: return self.embed_tokens(input_ids) def forward( self, input_ids: torch.Tensor, positions: torch.Tensor, intermediate_tensors: IntermediateTensors | None = None, inputs_embeds: torch.Tensor | None = None, ) -> torch.Tensor | IntermediateTensors | tuple[torch.Tensor, list[torch.Tensor]]: if get_pp_group().is_first_rank: if inputs_embeds is not None: hidden_states = inputs_embeds else: hidden_states = self.embed_input_ids(input_ids) residual = None else: assert intermediate_tensors is not None hidden_states = intermediate_tensors["hidden_states"] residual = intermediate_tensors["residual"] aux_hidden_states = [] for layer_idx, layer in enumerate( islice(self.layers, self.start_layer, self.end_layer), start=self.start_layer, ): # Collect auxiliary hidden states if specified if layer_idx in self.aux_hidden_state_layers: aux_hidden_state = ( hidden_states + residual if residual is not None else hidden_states ) aux_hidden_states.append(aux_hidden_state) hidden_states, residual = layer(positions, hidden_states, residual) if not get_pp_group().is_last_rank: return IntermediateTensors( {"hidden_states": hidden_states, "residual": residual} ) hidden_states, _ = self.norm(hidden_states, residual) # Return auxiliary hidden states if collected if len(aux_hidden_states) > 0: return hidden_states, aux_hidden_states return hidden_states def get_expert_mapping(self) -> list[tuple[str, str, int, str]]: # Params for weights, fp8 weight scales, fp8 activation scales # (param_name, weight_name, expert_id, shard_id) return SharedFusedMoE.make_expert_params_mapping( self, ckpt_gate_proj_name="gate_proj", 
ckpt_down_proj_name="down_proj", ckpt_up_proj_name="up_proj", num_experts=self.config.num_experts, num_redundant_experts=self.num_redundant_experts, ) def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]: stacked_params_mapping = [ # (param_name, shard_name, shard_id) ("qkv_proj", "q_proj", "q"), ("qkv_proj", "k_proj", "k"), ("qkv_proj", "v_proj", "v"), ("gate_up_proj", "gate_proj", 0), ("gate_up_proj", "up_proj", 1), ] # Skip loading extra parameters for GPTQ/modelopt models. ignore_suffixes = ( ".bias", "_bias", ".k_scale", "_k_scale", ".v_scale", "_v_scale", ".weight_scale", "_weight_scale", ".input_scale", "_input_scale", ) params_dict = dict(self.named_parameters()) loaded_params: set[str] = set() expert_params_mapping = self.get_expert_mapping() for name, loaded_weight in weights: if self.use_llama_nn: current_count = loaded_weight.current_count total_count = loaded_weight.total_count if self.quant_config is not None and ( scale_name := self.quant_config.get_cache_scale(name) ): # Loading kv cache quantization scales param = params_dict[scale_name] weight_loader = getattr(param, "weight_loader", default_weight_loader) assert loaded_weight.numel() == 1, ( f"KV scale numel {loaded_weight.numel()} != 1" ) loaded_weight = loaded_weight.squeeze() weight_loader(param, loaded_weight) loaded_params.add(scale_name) continue for param_name, weight_name, shard_id in stacked_params_mapping: # Skip non-stacked layers and experts (experts handled below). if weight_name not in name: continue # We have mlp.experts[0].gate_proj in the checkpoint. # Since we handle the experts below in expert_params_mapping, # we need to skip here BEFORE we update the name, otherwise # name will be updated to mlp.experts[0].gate_up_proj, which # will then be updated below in expert_params_mapping # for mlp.experts[0].gate_gate_up_proj, which breaks load. 
if "mlp.experts" in name: continue name = name.replace(weight_name, param_name) # Skip loading extra parameters for GPTQ/modelopt models. if name.endswith(ignore_suffixes) and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue if name.endswith("scale"): # Remapping the name of FP8 kv-scale. name = maybe_remap_kv_scale_name(name, params_dict) if name is None: continue if name not in params_dict: continue param = params_dict[name] weight_loader = getattr(param, "weight_loader", default_weight_loader) if weight_loader == default_weight_loader: weight_loader(param, loaded_weight) else: weight_loader(param, loaded_weight, shard_id) break else: is_expert_weight = False for mapping in expert_params_mapping: param_name, weight_name, expert_id, shard_id = mapping if weight_name not in name: continue # Anyway, this is an expert weight and should not be # attempted to load as other weights later is_expert_weight = True # Do not modify `name` since the loop may continue here # Instead, create a new variable name_mapped = name.replace(weight_name, param_name) if is_pp_missing_parameter(name_mapped, self): continue # Skip loading extra parameters for GPTQ/modelopt models. if ( name_mapped.endswith(ignore_suffixes) and name_mapped not in params_dict ): continue param = params_dict[name_mapped] # We should ask the weight loader to return success or not # here since otherwise we may skip experts with other # available replicas. weight_loader = typing.cast( Callable[..., bool], param.weight_loader ) success = weight_loader( param, loaded_weight, name_mapped, shard_id=shard_id, expert_id=expert_id, return_success=True, ) if success: name = name_mapped break else: if is_expert_weight: # We've checked that this is an expert weight # However it's not mapped locally to this rank # So we simply skip it continue # Skip loading extra parameters for GPTQ/modelopt models. 
if name.endswith(ignore_suffixes) and name not in params_dict: continue # Skip layers on other devices. if is_pp_missing_parameter(name, self): continue # Remapping the name of FP8 kv-scale. if name.endswith("kv_scale"): remapped_kv_scale_name = name.replace( ".kv_scale", ".attn.kv_scale" ) if remapped_kv_scale_name not in params_dict: logger.warning_once( "Found kv scale in the checkpoint (e.g. %s), but not found the expected name in the model (e.g. %s). kv-scale is not loaded.", # noqa: E501 name, remapped_kv_scale_name, ) continue else: name = remapped_kv_scale_name param = params_dict[name] weight_loader = getattr( param, "weight_loader", default_weight_loader ) weight_loader(param, loaded_weight) loaded_params.add(name) if self.use_llama_nn and self.quant_method is None and current_count==total_count: lay_key_words = [ "gate_up_proj.weight", "down_proj.weight", "mlp.gate.weight", "self_attn.qkv_proj.weight", "self_attn.o_proj.weight", "lm_head.weight", ] combined_words = "|".join(lay_key_words) # lay_qkv_words = ["self_attn.qkv_proj.weight"] # qkv_words = "|".join(lay_qkv_words) # lay_qkv_bias_words = ["self_attn.qkv_proj.bias"] # qkv_bias_words = "|".join(lay_qkv_bias_words) # for layername in loaded_params: for layername in params_dict.keys(): weight = params_dict[layername] os.environ['LM_NN'] = '0' # if self.use_fa_pad and (re.findall(qkv_bias_words, layername)): # weight.data = pad_weight(weight.data, 32) matches = re.findall(combined_words, layername) if matches: # if self.use_gemm_pad and gemm_bank_conf(weight.data.shape[0]): # weight.data = pad_weight(weight.data, 32) # if self.use_fa_pad and (re.findall(qkv_words, layername)): # if not gemm_bank_conf(weight.data.shape[0]): # weight.data = pad_weight(weight.data, 32) _weight = torch.zeros_like(weight.data) ori_shape =_weight.shape ops.trans_w16_gemm(_weight, weight.data, _weight.shape[0], _weight.shape[1]) weight.data.copy_(_weight) weight.data=weight.data.reshape(ori_shape[1],-1) return loaded_params 
class Qwen3MoeForCausalLM(
    nn.Module, SupportsPP, SupportsLoRA, SupportsEagle3, MixtureOfExperts
):
    """Qwen3-MoE causal LM: the ``Qwen3MoeModel`` backbone plus an LM head."""

    packed_modules_mapping = {
        "qkv_proj": [
            "q_proj",
            "k_proj",
            "v_proj",
        ]
    }

    fall_back_to_pt_during_load = False

    def __init__(self, *, vllm_config: VllmConfig, prefix: str = ""):
        """Build the backbone, LM head and MoE bookkeeping attributes.

        Raises:
            RuntimeError: if no layer held by this rank has a sparse-MoE MLP.
        """
        super().__init__()
        config = vllm_config.model_config.hf_text_config
        quant_config = vllm_config.quant_config
        self.config = config
        self.quant_config = quant_config
        # Only perform the following mapping when Qwen3MoeMLP exists
        # NOTE(review): this writes into the class-level dict, so the entry
        # is visible to every instance of the class.
        if getattr(config, "mlp_only_layers", []):
            self.packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"]
        self.model = Qwen3MoeModel(
            vllm_config=vllm_config, prefix=maybe_prefix(prefix, "model")
        )
        self.lm_head = ParallelLMHead(
            config.vocab_size,
            config.hidden_size,
            quant_config=quant_config,
            prefix=maybe_prefix(prefix, "lm_head"),
        )
        if self.config.tie_word_embeddings:
            self.lm_head.weight = self.model.embed_tokens.weight
        self.logits_processor = LogitsProcessor(config.vocab_size)
        self.make_empty_intermediate_tensors = (
            self.model.make_empty_intermediate_tensors
        )

        # Set MoE hyperparameters
        self.expert_weights = []

        self.moe_layers = []
        example_layer = None
        for layer in self.model.layers:
            # Layers owned by other pipeline ranks are placeholders.
            if isinstance(layer, PPMissingLayer):
                continue

            assert isinstance(layer, Qwen3MoeDecoderLayer)
            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
                example_layer = layer.mlp
                self.moe_layers.append(layer.mlp.experts)

        if example_layer is None:
            raise RuntimeError("No Qwen3MoE layer found in the model.layers.")

        # The expert counts are copied from the last sparse block found.
        self.num_moe_layers = len(self.moe_layers)
        self.num_expert_groups = 1
        self.num_shared_experts = 0
        self.num_logical_experts = example_layer.n_logical_experts
        self.num_physical_experts = example_layer.n_physical_experts
        self.num_local_physical_experts = example_layer.n_local_physical_experts
        self.num_routed_experts = example_layer.n_routed_experts
        self.num_redundant_experts = example_layer.n_redundant_experts

    def update_physical_experts_metadata(
        self,
        num_physical_experts: int,
        num_local_physical_experts: int,
    ) -> None:
        """Propagate new physical-expert counts to every local sparse block.

        The local physical expert count must be unchanged (asserted).
        """
        assert self.num_local_physical_experts == num_local_physical_experts
        self.num_physical_experts = num_physical_experts
        self.num_local_physical_experts = num_local_physical_experts
        self.num_redundant_experts = num_physical_experts - self.num_logical_experts
        for layer in self.model.layers:
            # Placeholder layers have no ``mlp``; skip them the same way
            # ``__init__`` does, so pipeline-parallel ranks do not fail here.
            if isinstance(layer, PPMissingLayer):
                continue
            if isinstance(layer.mlp, Qwen3MoeSparseMoeBlock):
                moe = layer.mlp
                moe.n_local_physical_experts = num_local_physical_experts
                moe.n_physical_experts = num_physical_experts
                moe.n_redundant_experts = self.num_redundant_experts
                moe.experts.update_expert_map()

    def set_aux_hidden_state_layers(self, layers: tuple[int, ...]) -> None:
        """Select the layers whose inputs are collected as auxiliary states."""
        self.model.aux_hidden_state_layers = layers

    def get_eagle3_aux_hidden_state_layers(self) -> tuple[int, ...]:
        """Return the default auxiliary layer indices: 2, the middle, and n-3."""
        num_layers = len(self.model.layers)
        return (2, num_layers // 2, num_layers - 3)

    def embed_input_ids(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Look up token embeddings through the backbone."""
        return self.model.embed_input_ids(input_ids)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        intermediate_tensors: IntermediateTensors | None = None,
        inputs_embeds: torch.Tensor | None = None,
    ) -> torch.Tensor | IntermediateTensors:
        """Run the backbone and return its output unchanged."""
        hidden_states = self.model(
            input_ids, positions, intermediate_tensors, inputs_embeds
        )
        return hidden_states

    def compute_logits(
        self,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Project ``hidden_states`` to logits via the logits processor."""
        logits = self.logits_processor(self.lm_head, hidden_states)
        return logits

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> set[str]:
        """Load checkpoint weights through ``AutoWeightsLoader``."""
        loader = AutoWeightsLoader(self)
        return loader.load_weights(weights)

    def get_expert_mapping(self) -> list[tuple[str, str, int, str]]:
        """Return the backbone's expert parameter mapping."""
        return self.model.get_expert_mapping()