# SPDX-License-Identifier: Apache-2.0 # SPDX-FileCopyrightText: Copyright contributors to the vLLM project import torch from vllm.config import get_current_vllm_config from vllm.logger import init_logger from vllm.model_executor.layers.attention import Attention from vllm.model_executor.layers.fused_moe import ( FusedMoE, FusedMoEConfig, FusedMoEMethodBase, ) from vllm.model_executor.layers.fused_moe import modular_kernel as mk from vllm.model_executor.layers.fused_moe.config import ( FusedMoEParallelConfig, FusedMoEQuantConfig, ) from vllm.model_executor.layers.fused_moe.oracle.mxfp4 import ( TRITON_BACKENDS, Mxfp4MoeBackend, convert_to_mxfp4_moe_kernel_format, make_mxfp4_moe_kernel, make_mxfp4_moe_quant_config, mxfp4_round_up_hidden_size_and_intermediate_size, select_mxfp4_moe_backend, ) from vllm.model_executor.layers.linear import LinearBase, UnquantizedLinearMethod from vllm.model_executor.layers.quantization import QuantizationMethods from vllm.model_executor.layers.quantization.base_config import ( QuantizationConfig, QuantizeMethodBase, ) from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped from vllm.model_executor.utils import replace_parameter, set_weight_attrs logger = init_logger(__name__) class Mxfp4Config(QuantizationConfig): def __init__(self, ignored_layers: list[str] | None = None): super().__init__() self.ignored_layers = ignored_layers @classmethod def from_config(cls, config): return cls() @classmethod def get_min_capability(cls) -> int: return 80 @classmethod def get_name(cls) -> QuantizationMethods: return "mxfp4" @classmethod def get_supported_act_dtypes(cls) -> list[torch.dtype]: return [torch.bfloat16] @classmethod def get_config_filenames(cls) -> list[str]: return [] def get_quant_method( self, layer: torch.nn.Module, prefix: str ) -> "QuantizeMethodBase | None": if isinstance(layer, LinearBase): if self.ignored_layers and is_layer_skipped( prefix=prefix, ignored_layers=self.ignored_layers, fused_mapping=self.packed_modules_mapping, ): return UnquantizedLinearMethod() logger.debug_once( "MXFP4 linear layer is not implemented - falling back to " "UnquantizedLinearMethod.", scope="local", ) return UnquantizedLinearMethod() elif isinstance(layer, FusedMoE): return Mxfp4MoEMethod(layer.moe_config) elif isinstance(layer, Attention): logger.debug_once( "MXFP4 attention layer is not implemented. " "Skipping quantization for this layer.", scope="local", ) return None def is_mxfp4_quant(self, prefix: str, layer: torch.nn.Module) -> bool: """MXFP4 config always uses MXFP4 quantization.""" return True class Mxfp4MoEMethod(FusedMoEMethodBase): """MXFP4 MoE quantization method.""" def __init__(self, moe: FusedMoEConfig): super().__init__(moe) self.weight_dtype = "mxfp4" self.mxfp4_backend, self.experts_cls = select_mxfp4_moe_backend(moe) self.max_capture_size = ( get_current_vllm_config().compilation_config.max_cudagraph_capture_size ) self._cache_permute_indices: dict[torch.Size, torch.Tensor] = {} self.moe_kernel: mk.FusedMoEKernel | None = None # Used for triton kernel precision configs self.w13_precision_config = None self.w2_precision_config = None @property def skip_forward_padding(self) -> bool: # SM100_FI_MXFP4_MXFP8_TRTLLM supports padding with mxfp8 quant # so can skip the padding in the forward before applying the moe method return self.mxfp4_backend == Mxfp4MoeBackend.FLASHINFER_TRTLLM_MXFP4_MXFP8 def maybe_roundup_sizes( self, hidden_size: int, intermediate_size_per_partition: int, act_dtype: torch.dtype, moe_parallel_config: FusedMoEParallelConfig, ) -> tuple[int, int]: hidden_size, intermediate_size_per_partition = super().maybe_roundup_sizes( hidden_size=hidden_size, intermediate_size_per_partition=intermediate_size_per_partition, act_dtype=act_dtype, moe_parallel_config=moe_parallel_config, ) return mxfp4_round_up_hidden_size_and_intermediate_size( self.mxfp4_backend, hidden_size, intermediate_size_per_partition ) def create_weights( self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size_per_partition: int, params_dtype: torch.dtype, **extra_weight_attrs, ): self.num_experts = num_experts weight_dtype = torch.uint8 scale_dtype = torch.uint8 mxfp4_block = 32 layer.params_dtype = params_dtype layer.num_experts = num_experts self.intermediate_size = intermediate_size_per_partition self.hidden_size = hidden_size # Fused gate_up_proj (column parallel) w13_weight = torch.nn.Parameter( torch.zeros( num_experts, 2 * intermediate_size_per_partition, hidden_size // 2, dtype=weight_dtype, ), requires_grad=False, ) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) w13_weight_scale = torch.nn.Parameter( torch.zeros( num_experts, 2 * intermediate_size_per_partition, hidden_size // mxfp4_block, dtype=scale_dtype, ), requires_grad=False, ) layer.register_parameter("w13_weight_scale", w13_weight_scale) set_weight_attrs(w13_weight_scale, extra_weight_attrs) # down_proj (row parallel) w2_weight = torch.nn.Parameter( torch.zeros( num_experts, hidden_size, intermediate_size_per_partition // 2, dtype=weight_dtype, ), requires_grad=False, ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) w2_weight_scale = torch.nn.Parameter( torch.zeros( num_experts, hidden_size, intermediate_size_per_partition // mxfp4_block, dtype=scale_dtype, ), requires_grad=False, ) layer.register_parameter("w2_weight_scale", w2_weight_scale) set_weight_attrs(w2_weight_scale, extra_weight_attrs) if self.moe.has_bias: w13_bias = torch.nn.Parameter( torch.zeros( num_experts, 2 * intermediate_size_per_partition, dtype=torch.bfloat16, ), requires_grad=False, ) layer.register_parameter("w13_bias", w13_bias) set_weight_attrs(w13_bias, extra_weight_attrs) w2_bias = torch.nn.Parameter( torch.zeros( num_experts, hidden_size, dtype=torch.bfloat16, ), requires_grad=False, ) layer.register_parameter("w2_bias", w2_bias) set_weight_attrs(w2_bias, extra_weight_attrs) def _setup_kernel( self, layer: FusedMoE, w13: torch.Tensor, w2: torch.Tensor, w13_scale: torch.Tensor, w2_scale: torch.Tensor, w13_bias: torch.Tensor | None = None, w2_bias: torch.Tensor | None = None, ) -> None: num_experts = self.num_experts intermediate_size = self.intermediate_size hidden_size = self.hidden_size sf_block_size = 32 # Shape assertions assert ( w13.dim() == 3 and w13.shape[0] == num_experts and w13.shape[1] == intermediate_size * 2 and w13.shape[2] == hidden_size // 2 ) assert ( w13_scale.dim() == 3 and w13_scale.shape[0] == num_experts and w13_scale.shape[1] == intermediate_size * 2 and w13_scale.shape[2] == hidden_size // sf_block_size ) assert ( w2.dim() == 3 and w2.shape[0] == num_experts and w2.shape[1] == hidden_size and w2.shape[2] == intermediate_size // 2 ) assert ( w2_scale.dim() == 3 and w2_scale.shape[1] == hidden_size and w2_scale.shape[2] == intermediate_size // sf_block_size ) if w13_bias is not None: assert ( w13_bias.dim() == 2 and w13_bias.shape[0] == num_experts and w13_bias.shape[1] == intermediate_size * 2 ) if w2_bias is not None: assert ( w2_bias.dim() == 2 and w2_bias.shape[0] == num_experts and w2_bias.shape[1] == hidden_size ) # Convert weights to kernel format w13, w2, w13_scale, w2_scale, w13_bias, w2_bias = ( convert_to_mxfp4_moe_kernel_format( mxfp4_backend=self.mxfp4_backend, layer=layer, w13_weight=w13, w2_weight=w2, w13_weight_scale=w13_scale, w2_weight_scale=w2_scale, w13_bias=w13_bias, w2_bias=w2_bias, _cache_permute_indices=self._cache_permute_indices, ) ) # For TRITON backends, weights are wrapped tensors from triton_kernels # that don't support .detach(). Manually assign parameters. if self.mxfp4_backend not in TRITON_BACKENDS: replace_parameter(layer, "w13_weight", w13) replace_parameter(layer, "w2_weight", w2) replace_parameter(layer, "w13_weight_scale", w13_scale) replace_parameter(layer, "w2_weight_scale", w2_scale) else: layer.w13_weight = w13 layer.w2_weight = w2 self.w13_precision_config = w13_scale self.w2_precision_config = w2_scale if w13_bias is not None and w2_bias is not None: replace_parameter(layer, "w13_bias", w13_bias) replace_parameter(layer, "w2_bias", w2_bias) # Build quant config self.moe_quant_config = self.get_fused_moe_quant_config(layer) # Build kernel (modular or monolithic) if self.moe_quant_config is not None and self.experts_cls is not None: self.moe_kernel = make_mxfp4_moe_kernel( moe_quant_config=self.moe_quant_config, moe_config=self.moe, mxfp4_backend=self.mxfp4_backend, experts_cls=self.experts_cls, routing_tables=layer._maybe_init_expert_routing_tables(), shared_experts=layer.shared_experts, ) def process_weights_after_loading(self, layer): w13 = layer.w13_weight w2 = layer.w2_weight w13_scale = layer.w13_weight_scale w2_scale = layer.w2_weight_scale w13_bias = getattr(layer, "w13_bias", None) w2_bias = getattr(layer, "w2_bias", None) if self.mxfp4_backend == Mxfp4MoeBackend.NONE: return self._setup_kernel(layer, w13, w2, w13_scale, w2_scale, w13_bias, w2_bias) def get_fused_moe_quant_config( self, layer: torch.nn.Module ) -> FusedMoEQuantConfig | None: w1_scale = layer.w13_weight_scale w2_scale = layer.w2_weight_scale w1_bias = getattr(layer, "w13_bias", None) w2_bias = getattr(layer, "w2_bias", None) if self.mxfp4_backend in TRITON_BACKENDS: assert self.w13_precision_config is not None assert self.w2_precision_config is not None w1_scale = self.w13_precision_config w2_scale = self.w2_precision_config return make_mxfp4_moe_quant_config( mxfp4_backend=self.mxfp4_backend, w1_scale=w1_scale, w2_scale=w2_scale, w1_bias=w1_bias, w2_bias=w2_bias, ) def select_gemm_impl( self, prepare_finalize: mk.FusedMoEPrepareAndFinalize, layer: torch.nn.Module, ) -> mk.FusedMoEExpertsModular: raise ValueError( f"{self.__class__.__name__} uses the new modular kernel " "initialization logic. This function should not be called." ) def apply( self, layer: FusedMoE, x: torch.Tensor, topk_weights: torch.Tensor, topk_ids: torch.Tensor, shared_experts_input: torch.Tensor | None, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert not self.is_monolithic assert self.moe_kernel is not None return self.moe_kernel.apply( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, topk_weights=topk_weights, topk_ids=topk_ids, activation=layer.activation, global_num_experts=layer.global_num_experts, apply_router_weight_on_input=layer.apply_router_weight_on_input, expert_map=layer.expert_map, shared_experts_input=shared_experts_input, ) def apply_monolithic( self, layer: FusedMoE, x: torch.Tensor, router_logits: torch.Tensor, ) -> torch.Tensor | tuple[torch.Tensor, torch.Tensor]: assert self.is_monolithic assert self.moe_kernel is not None return self.moe_kernel.apply_monolithic( hidden_states=x, w1=layer.w13_weight, w2=layer.w2_weight, router_logits=router_logits, activation=layer.activation, global_num_experts=layer.global_num_experts, expert_map=layer.expert_map, apply_router_weight_on_input=layer.apply_router_weight_on_input, )