from __future__ import annotations

import logging
from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional

import torch
from torch.nn import Module
from torch.nn.parameter import Parameter

from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.quantization.base_config import (
    FusedMoEMethodBase,
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.utils import is_layer_skipped
from sglang.srt.utils import is_npu, set_weight_attrs

if TYPE_CHECKING:
    from sglang.srt.layers.moe import MoeRunnerConfig
    from sglang.srt.layers.moe.token_dispatcher import (
        CombineInput,
        StandardDispatchOutput,
    )

ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = logging.getLogger(__name__)


class W4AFp8Config(QuantizationConfig):
    """Config class for MIXED_PRECISION W4AFp8."""

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = True,
        is_checkpoint_w4afp8_serialized: bool = True,
        linear_activation_scheme: str = "dynamic",
        moe_activation_scheme: str = "static",
        ignored_layers: Optional[List[str]] = None,
        weight_block_size: Optional[List[int]] = None,
        group_size: int = 128,
    ) -> None:
        super().__init__()
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.is_checkpoint_w4afp8_serialized = is_checkpoint_w4afp8_serialized
        if is_checkpoint_w4afp8_serialized:
            logger.warning(
                "Detected w4afp8 checkpoint. Please note that the format is "
                "experimental and subject to change."
            )
        if moe_activation_scheme not in ACTIVATION_SCHEMES:
            raise ValueError(f"Unsupported activation scheme {moe_activation_scheme}")
        self.linear_activation_scheme = linear_activation_scheme
        self.moe_activation_scheme = moe_activation_scheme
        self.ignored_layers = ignored_layers or []
        # NOTE: the block size is hard-coded to 128x128; the weight_block_size
        # constructor argument is currently unused.
        self.weight_block_size = [128, 128]
        self.group_size = group_size

    @classmethod
    def get_name(cls) -> str:
        return "w4afp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.bfloat16, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        return 90

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> W4AFp8Config:
        quant_method = cls.get_from_keys(config, ["quant_method"])
        is_checkpoint_fp8_serialized = "fp8" in quant_method
        is_checkpoint_w4afp8_serialized = "w4afp8" in quant_method
        linear_activation_scheme = "dynamic"
        moe_activation_scheme = "static"
        weight_block_size = [128, 128]
        return cls(
            is_checkpoint_fp8_serialized=is_checkpoint_fp8_serialized,
            is_checkpoint_w4afp8_serialized=is_checkpoint_w4afp8_serialized,
            linear_activation_scheme=linear_activation_scheme,
            moe_activation_scheme=moe_activation_scheme,
            weight_block_size=weight_block_size,
        )

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional[QuantizeMethodBase]:
        from sglang.srt.layers.linear import LinearBase
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, self.ignored_layers):
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return W4AFp8MoEMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        return []
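
# Shape sketch for interleave_scales below (illustrative only; assumes the
# number of K-groups is divisible by 4 so that alignment == 4). Writing the
# scale tensor as (num_experts E, output channels N, K-groups Kg):
#
#   input  : (E, N, Kg)
#   view   : (E, N, Kg // 4, 4)
#   permute: (E, Kg // 4, N, 4)
#   output : (E, Kg // 4, N * 4)
#
# i.e. the scales for 4 consecutive K-groups of the same output channel become
# adjacent in memory, matching the TRT-LLM-style layout the fused kernel reads.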
def interleave_scales(scales: torch.Tensor) -> torch.Tensor:
    """Interleave scales in groups of 4 similar to TRT-LLM implementation."""
    s_shape = scales.shape
    # Reshape to separate groups of 4
    alignment = 4 if s_shape[2] % 4 == 0 else 1
    scales_interleaved = scales.reshape(
        s_shape[0], s_shape[1], (s_shape[2] // alignment), alignment
    )
    # Permute dimensions to interleave
    scales_interleaved = scales_interleaved.permute(0, 2, 1, 3)
    # Reshape back to original dimensions but with interleaved values
    scales_interleaved = scales_interleaved.reshape(
        s_shape[0], s_shape[2] // alignment, s_shape[1] * alignment
    )
    return scales_interleaved.contiguous()


class W4AFp8MoEMethod(FusedMoEMethodBase):

    def __init__(self, quant_config: W4AFp8Config):
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: Module,
        num_experts: int,
        hidden_size: int,
        intermediate_size_per_partition: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

        assert "weight_loader" in extra_weight_attrs

        # Fused gate_up_proj (column parallel); two int4 values are packed per
        # int8 element, hence the // 2 on the last dimension.
        w13_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                intermediate_size_per_partition * 2,
                hidden_size // 2,
                dtype=torch.int8,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight", w13_weight)
        set_weight_attrs(w13_weight, extra_weight_attrs)

        # down_proj (row parallel)
        w2_weight = torch.nn.Parameter(
            torch.empty(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // 2,
                dtype=torch.int8,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight", w2_weight)
        set_weight_attrs(w2_weight, extra_weight_attrs)

        extra_weight_attrs.update(
            {"quant_method": FusedMoeWeightScaleSupported.GROUP.value}
        )

        w13_weight_scale = torch.nn.Parameter(
            torch.zeros(
                num_experts,
                2 * intermediate_size_per_partition,
                hidden_size // self.quant_config.group_size,
                dtype=torch.float32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)

        w2_weight_scale = torch.nn.Parameter(
            torch.zeros(
                num_experts,
                hidden_size,
                intermediate_size_per_partition // self.quant_config.group_size,
                dtype=torch.float32,
            ),
            requires_grad=False,
        )
        layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)

        # Input scales
        w13_input_scale = torch.nn.Parameter(
            torch.ones((num_experts, 2), dtype=torch.bfloat16),
            requires_grad=False,
        )
        layer.register_parameter("w13_input_scale", w13_input_scale)
        set_weight_attrs(w13_input_scale, extra_weight_attrs)

        w2_input_scale = torch.nn.Parameter(
            torch.ones(num_experts, dtype=torch.bfloat16),
            requires_grad=False,
        )
        layer.register_parameter("w2_input_scale", w2_input_scale)
        set_weight_attrs(w2_input_scale, extra_weight_attrs)

        # Pre-populate the strides
        device = layer.w13_weight.device
        self.a_strides1 = torch.full(
            (num_experts, 3),
            hidden_size,
            device=device,
            dtype=torch.int64,
        )
        self.c_strides1 = torch.full(
            (num_experts, 3),
            2 * intermediate_size_per_partition,
            device=device,
            dtype=torch.int64,
        )
        self.a_strides2 = torch.full(
            (num_experts, 3),
            intermediate_size_per_partition,
            device=device,
            dtype=torch.int64,
        )
        self.c_strides2 = torch.full(
            (num_experts, 3),
            hidden_size,
            device=device,
            dtype=torch.int64,
        )
        self.b_strides1 = self.a_strides1
        self.s_strides13 = self.c_strides1
        self.b_strides2 = self.a_strides2
        self.s_strides2 = self.c_strides2

        self.expert_offsets = torch.empty(
            (num_experts + 1), dtype=torch.int32, device=device
        )
        self.problem_sizes1 = torch.empty(
            (num_experts, 3), dtype=torch.int32, device=device
        )
        self.problem_sizes2 = torch.empty(
            (num_experts, 3), dtype=torch.int32, device=device
        )

        return
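
    # Illustrative shape summary for the parameters registered above
    # (hypothetical sizes: E experts, hidden size H, per-partition intermediate
    # size I, group_size 128 by default). Two int4 values share one int8 slot,
    # which is where the trailing // 2 comes from:
    #
    #   w13_weight           : (E, 2 * I, H // 2)            int8  (packed int4)
    #   w2_weight            : (E, H, I // 2)                int8  (packed int4)
    #   w13_weight_scale_inv : (E, 2 * I, H // group_size)   float32
    #   w2_weight_scale_inv  : (E, H, I // group_size)       float32
    #   w13_input_scale      : (E, 2)                        bfloat16
    #   w2_input_scale       : (E,)                          bfloat16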
    def process_weights_after_loading(self, layer: Module) -> None:
        dtype = torch.bfloat16
        device = layer.w2_weight.device

        # Interleave w13_weight_scale (gate_up_proj)
        w13_weight_scale = layer.w13_weight_scale_inv.to(dtype)
        w13_weight_scale = interleave_scales(w13_weight_scale)
        layer.w13_weight_scale_inv = Parameter(w13_weight_scale, requires_grad=False)

        # Interleave w2_weight_scale (down_proj)
        w2_weight_scale = layer.w2_weight_scale_inv.to(dtype)
        w2_weight_scale = interleave_scales(w2_weight_scale)
        layer.w2_weight_scale_inv = Parameter(w2_weight_scale, requires_grad=False)

        # Process input scales
        w13_input_scale_max = layer.w13_input_scale.max().to(dtype).item()
        new_w13_input_scale = torch.tensor(
            [w13_input_scale_max],
            dtype=dtype,
            device=device,
        )
        layer.w13_input_scale = Parameter(new_w13_input_scale, requires_grad=False)

        w2_input_scale_max = layer.w2_input_scale.max().to(dtype).item()
        new_w2_input_scale = torch.tensor(
            [w2_input_scale_max], dtype=dtype, device=device
        )
        layer.w2_input_scale = Parameter(new_w2_input_scale, requires_grad=False)

    def create_moe_runner(
        self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
    ):
        self.moe_runner_config = moe_runner_config

    def apply(
        self,
        layer: Module,
        dispatch_output: StandardDispatchOutput,
    ) -> CombineInput:
        from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
        from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput

        x = dispatch_output.hidden_states
        topk_output = dispatch_output.topk_output
        topk_weights, topk_ids, _ = topk_output

        output = cutlass_w4a8_moe(
            x,
            layer.w13_weight,
            layer.w2_weight,
            layer.w13_weight_scale_inv,
            layer.w2_weight_scale_inv,
            topk_weights,
            topk_ids,
            self.a_strides1,
            self.b_strides1,
            self.c_strides1,
            self.a_strides2,
            self.b_strides2,
            self.c_strides2,
            self.s_strides13,
            self.s_strides2,
            self.expert_offsets,
            self.problem_sizes1,
            self.problem_sizes2,
            layer.w13_input_scale,
            layer.w2_input_scale,
        )
        if self.moe_runner_config.routed_scaling_factor is not None:
            output *= self.moe_runner_config.routed_scaling_factor
        return StandardCombineInput(hidden_states=output)
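
# End-to-end sketch (illustrative only; layer construction elided). A FusedMoE
# layer built with this config resolves to W4AFp8MoEMethod via
# W4AFp8Config.get_quant_method, after which the usual call order is roughly:
#
#     method = quant_config.get_quant_method(fused_moe_layer, prefix=layer_prefix)
#     method.create_weights(
#         fused_moe_layer,
#         num_experts,
#         hidden_size,
#         intermediate_size_per_partition,
#         params_dtype,
#         weight_loader=weight_loader,
#     )
#     # ... checkpoint tensors are loaded into the registered parameters ...
#     method.process_weights_after_loading(fused_moe_layer)
#     method.create_moe_runner(fused_moe_layer, moe_runner_config)
#     combine_input = method.apply(fused_moe_layer, dispatch_output)
#
# Here `fused_moe_layer`, `layer_prefix`, `weight_loader`, `moe_runner_config`,
# and `dispatch_output` are placeholders for objects supplied by the
# surrounding MoE stack, not names defined in this module.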