from __future__ import annotations

import importlib
import sys
from types import MappingProxyType
from typing import (
    TYPE_CHECKING,
    Any,
    Callable,
    Dict,
    List,
    Mapping,
    Optional,
    Tuple,
    Union,
    cast,
)

import torch
from torch.nn.parameter import Parameter

from sglang.srt.distributed import (
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
from sglang.srt.layers.parameter import (
    ChannelQuantScaleParameter,
    ModelWeightParameter,
    PerTensorScaleParameter,
)
from sglang.srt.layers.quantization.base_config import (
    FusedMoEMethodBase,
    LinearMethodBase,
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.compressed_tensors.utils import should_ignore_layer
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.utils import (
    apply_module_patch,
    cpu_has_amx_support,
    is_cpu,
    is_cuda,
    is_npu,
    set_weight_attrs,
    use_intel_amx_backend,
)

if TYPE_CHECKING:
    from sglang.srt.layers.moe.topk import TopKOutput

_is_cuda = is_cuda()
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
if _is_cuda:
    from sgl_kernel import int8_scaled_mm

_is_npu = is_npu()
if _is_npu:
    import torch_npu

    try:
        from mindie_turbo import _ops as ops
        from mindie_turbo.quantize.quant_utils import quant_per_tensor
    except ImportError:
        useMindIETurbo = False
    else:
        useMindIETurbo = True


# func refers to RMSNorm.__init__
def npu_wrapper_rmsnorm_init(func):
    def init(self, hidden_size: int, **extra_args) -> None:
        func(self, hidden_size, **extra_args)
        self.ignore_anti = True
        # The Ascend w8a8_int8 quantization requires adding a bias in rmsnorm
        self.bias = torch.nn.Parameter(torch.zeros(hidden_size), requires_grad=False)

    return init


# func refers to RMSNorm.forward_oot
def npu_wrapper_rmsnorm_forward(func):
    def _rmsnorm_forward_oot(
        self,
        x: torch.Tensor,
        residual: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
        if not x.is_contiguous():
            x = x.contiguous()
        if residual is not None:
            out, _, residual_out = torch_npu.npu_add_rms_norm(
                residual, x, self.weight.data, self.variance_epsilon
            )
            out = out + self.bias
            return out.to(x.dtype), residual_out

        out = torch_npu.npu_rms_norm(x, self.weight.data, self.variance_epsilon)[0]
        out = out + self.bias
        return out.to(x.dtype)

    return _rmsnorm_forward_oot


def npu_fused_experts(
    hidden_states: torch.Tensor,
    w13: torch.Tensor,
    w13_scale: torch.Tensor,
    w2: torch.Tensor,
    w2_scale: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    top_k: int,
):
    original_shape = hidden_states.shape
    original_dtype = hidden_states.dtype
    scale_dtype = original_dtype if original_dtype == torch.bfloat16 else torch.float32
    if len(original_shape) == 3:
        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
    num_tokens = hidden_states.shape[0]
    num_experts = w13.shape[0]
    row_idx_len = num_tokens * top_k
    row_idx = (
        torch.arange(0, row_idx_len, dtype=torch.int32, device=topk_weights.device)
        .view(top_k, -1)
        .permute(1, 0)
        .contiguous()
    )
    hidden_states, expanded_row_idx, expanded_expert_idx = (
        torch_npu.npu_moe_init_routing(
            hidden_states, row_idx=row_idx, expert_idx=topk_ids, active_num=num_tokens
        )
    )
    expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
        expanded_expert_idx, num_experts
    )
    expert_tokens = expert_tokens.to(torch.int64)
    # gmm1: gate_up_proj
    hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
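    # npu_dynamic_quant returns int8 activations together with a per-token scale;
    # the grouped matmuls below consume that scale alongside the per-channel
    # weight scale and emit the result in the original dtype.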
    hidden_states = torch_npu.npu_grouped_matmul(
        x=[hidden_states],
        weight=[w13],
        scale=[w13_scale.to(scale_dtype)],
        per_token_scale=[pertoken_scale],
        split_item=2,
        group_list_type=0,
        group_type=0,
        group_list=expert_tokens,
        output_dtype=original_dtype,
    )[0]

    # act_fn: swiglu
    hidden_states = torch_npu.npu_swiglu(hidden_states)
    hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)

    # gmm2: down_proj
    hidden_states = torch_npu.npu_grouped_matmul(
        x=[hidden_states],
        weight=[w2],
        scale=[w2_scale.to(scale_dtype)],
        per_token_scale=[pertoken_scale],
        split_item=2,
        group_list_type=0,
        group_type=0,
        group_list=expert_tokens,
        output_dtype=original_dtype,
    )[0]

    final_hidden_states = torch_npu.npu_moe_finalize_routing(
        hidden_states,
        skip1=None,
        skip2=None,
        bias=None,
        scales=topk_weights,
        expanded_src_to_dst_row=expanded_row_idx,
        export_for_source_row=topk_ids,
    )
    if len(original_shape) == 3:
        final_hidden_states = final_hidden_states.view(original_shape)
    return final_hidden_states


class W8A8Int8Config(QuantizationConfig):
    """Config class for W8A8 Int8 Quantization.

    - Weight: static, per-channel, symmetric
    - Activation: dynamic, per-token, symmetric
    """

    def __init__(self, quant_config: Dict[str, Any] = {}):
        super().__init__()
        self.quant_description = quant_config
        self.is_dynamic = quant_config.get("is_dynamic", False)
        ignore = cast(List[str], quant_config.get("ignore", []))
        self.ignore = ignore if ignore is not None else []
        packed_modules_mapping = quant_config.get("packed_modules_mapping", {})
        self.packed_modules_mapping = (
            packed_modules_mapping if packed_modules_mapping is not None else {}
        )
        if _is_npu:
            # Ascend w8a8_int8 quantization adds a bias to RMSNorm; use wrappers
            # so the patch only takes effect for models that need it.
            for name in self.quant_description.keys():
                if "norm.bias" in name:
                    apply_module_patch(
                        "sglang.srt.layers.layernorm.RMSNorm",
                        "__init__",
                        [npu_wrapper_rmsnorm_init],
                    )
                    apply_module_patch(
                        "sglang.srt.layers.layernorm.RMSNorm",
                        "forward_npu",
                        [npu_wrapper_rmsnorm_forward],
                    )
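
    # A minimal sketch of the quant_description mapping read above and in
    # get_quant_method / is_layer_skipped (hypothetical layer names; on NPU this
    # comes from quant_model_description.json). Keys such as "is_dynamic",
    # "ignore" and "packed_modules_mapping" may also be present.
    #
    #   {
    #       # dynamic int8 linear layer
    #       "model.layers.0.self_attn.qkv_proj.weight": "W8A8_DYNAMIC",
    #       # static int8 linear layer (any value other than "W8A8_DYNAMIC")
    #       "model.layers.0.mlp.down_proj.weight": "W8A8",
    #       # "FLOAT" marks a layer that stays unquantized
    #       "lm_head.weight": "FLOAT",
    #       # a "norm.bias" key triggers the RMSNorm bias patch in __init__
    #       "model.layers.0.input_layernorm.bias": "...",
    #   }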

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return (
            [torch.float16, torch.bfloat16]
            if not _is_npu
            else [torch.int8, torch.float16, torch.bfloat16]
        )

    @classmethod
    def get_min_capability(cls) -> int:
        if _is_npu:
            raise NotImplementedError(
                'NPU hardware does not support "get_min_capability" feature.'
            )
        else:
            return 75

    @classmethod
    def get_name(cls) -> str:
        return "w8a8_int8"

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        filenames = []
        if _is_npu:
            filenames.append("quant_model_description.json")
        return filenames

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> W8A8Int8Config:
        return cls(config)

    def get_quant_method(
        self,
        layer: torch.nn.Module,
        prefix: str,
    ) -> Optional[QuantizeMethodBase]:
        from sglang.srt.layers.linear import LinearBase
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

        if _is_npu:
            if isinstance(layer, LinearBase):
                key = "model"
                if "vision_model" in prefix:
                    key = "vision_model"
                elif "visual" in prefix:
                    key = "visual"
                packed_modules_mapping_subset = self.packed_modules_mapping.get(key, {})
                prefix_in_quant_config = prefix
                proj_name = prefix.split(".")[-1]
                if proj_name in packed_modules_mapping_subset:
                    prefix_in_quant_config = prefix.replace(
                        proj_name, packed_modules_mapping_subset[proj_name][0]
                    )
                self.is_dynamic = (
                    self.quant_description[prefix_in_quant_config + ".weight"]
                    == "W8A8_DYNAMIC"
                )
                if self.is_layer_skipped(prefix, packed_modules_mapping_subset):
                    return UnquantizedLinearMethod()
                return (
                    NPU_W8A8DynamicLinearMethod(self)
                    if self.is_dynamic
                    else NPU_W8A8LinearMethod(self)
                )
            elif isinstance(layer, FusedMoE):
                return NPU_W8A8MoEMethod(self)
            return None

        if should_ignore_layer(
            prefix, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
        ):
            return UnquantizedLinearMethod()
        if isinstance(layer, LinearBase):
            return W8A8Int8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return W8A8Int8MoEMethod(self)
        return None

    def is_layer_skipped(
        self, prefix: str, fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
    ):
        # adapted from vllm.model_executor.layers.quantization.utils.quant_utils.is_layer_skipped
        proj_name = prefix.split(".")[-1]
        if proj_name in fused_mapping:
            shard_prefixes = [
                prefix.replace(proj_name, shard_proj_name)
                for shard_proj_name in fused_mapping[proj_name]
            ]
            is_skipped = None
            for shard_prefix in shard_prefixes:
                is_shard_skipped = (
                    self.quant_description[shard_prefix + ".weight"] == "FLOAT"
                )
                if is_skipped is None:
                    is_skipped = is_shard_skipped
                elif is_shard_skipped != is_skipped:
                    raise ValueError(
                        f"Detected some but not all shards of {prefix} "
                        "are quantized. All shards of fused layers "
                        "must have the same precision."
                    )
        else:
            is_skipped = self.quant_description[prefix + ".weight"] == "FLOAT"

        assert is_skipped is not None
        return is_skipped

    def get_scaled_act_names(self) -> List[str]:
        return []


class W8A8Int8LinearMethod(LinearMethodBase):

    def __init__(self, quantization_config: W8A8Int8Config):
        self.quantization_config = quantization_config

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        if _is_cpu:
            assert (
                _is_cpu_amx_available
            ), "W8A8Int8LinearMethod on CPU requires that CPU has AMX support"
            _amx_process_weight_after_loading(layer, ["weight"])
            return

        layer.weight = Parameter(layer.weight.t(), requires_grad=False)
        layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        weight_loader = extra_weight_attrs.get("weight_loader")
        self.logical_widths = output_partition_sizes

        weight = ModelWeightParameter(
            data=torch.empty(
                sum(output_partition_sizes), input_size_per_partition, dtype=torch.int8
            ),
            input_dim=1,
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight", weight)

        weight_scale = ChannelQuantScaleParameter(
            data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
            output_dim=0,
            weight_loader=weight_loader,
        )
        layer.register_parameter("weight_scale", weight_scale)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ):
        if use_intel_amx_backend(layer):
            return torch.ops.sgl_kernel.int8_scaled_mm_with_quant(
                x,
                layer.weight,
                layer.weight_scale,
                bias,
                x.dtype,
                True,  # is_vnni
            )

        x_q, x_scale = per_token_quant_int8(x)

        return int8_scaled_mm(
            x_q, layer.weight, x_scale, layer.weight_scale, out_dtype=x.dtype, bias=bias
        )


class W8A8Int8MoEMethod(FusedMoEMethodBase):
    """MoE method for INT8.

    Supports loading INT8 checkpoints with static weight scale and
    dynamic/static activation scale.
    Also supports loading quantized FP16/BF16 model checkpoints with dynamic
    activation scaling. The weight scaling factor will be initialized after
    the model weights are loaded.

    Args:
        quant_config: The quantization config.
    """
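
    # w13_weight holds the gate and up projections concatenated along the output
    # dimension (hence 2 * intermediate_size rows per expert); weight scales are
    # per output channel, matching FusedMoeWeightScaleSupported.CHANNEL below.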
""" def __init__(self, quant_config: W8A8Int8Config): self.quant_config = quant_config def create_weights( self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size: int, params_dtype: torch.dtype, **extra_weight_attrs, ): from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported tp_size = get_tensor_model_parallel_world_size() # WEIGHTS w13_weight = torch.nn.Parameter( torch.empty( num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8 ), requires_grad=False, ) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) w2_weight = torch.nn.Parameter( torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8), requires_grad=False, ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) w13_weight_scale = torch.nn.Parameter( torch.ones(num_experts, 2 * intermediate_size, 1, dtype=torch.float32), requires_grad=False, ) w2_weight_scale = torch.nn.Parameter( torch.ones(num_experts, hidden_size, 1, dtype=torch.float32), requires_grad=False, ) layer.register_parameter("w13_weight_scale", w13_weight_scale) layer.register_parameter("w2_weight_scale", w2_weight_scale) extra_weight_attrs.update( {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} ) set_weight_attrs(w13_weight_scale, extra_weight_attrs) set_weight_attrs(w2_weight_scale, extra_weight_attrs) w13_input_scale = None layer.register_parameter("w13_input_scale", w13_input_scale) w2_input_scale = None layer.register_parameter("w2_input_scale", w2_input_scale) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if _is_cpu: assert ( _is_cpu_amx_available ), "W8A8Int8MoEMethod on CPU requires that CPU has AMX support" _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"]) return layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False) layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False) layer.w13_weight_scale = Parameter( layer.w13_weight_scale.data, requires_grad=False ) layer.w2_weight_scale = Parameter( layer.w2_weight_scale.data, requires_grad=False ) def apply( self, layer: torch.nn.Module, x: torch.Tensor, topk_output: TopKOutput, *, activation: str = "silu", apply_router_weight_on_input: bool = False, inplace: bool = True, no_combine: bool = False, routed_scaling_factor: Optional[float] = None, ) -> torch.Tensor: from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts if use_intel_amx_backend(layer): from sglang.srt.layers.moe.topk import apply_topk_weights_cpu topk_weights, topk_ids, _ = topk_output x, topk_weights = apply_topk_weights_cpu( apply_router_weight_on_input, topk_weights, x ) return torch.ops.sgl_kernel.fused_experts_cpu( x, layer.w13_weight, layer.w2_weight, topk_weights, topk_ids, False, # inplace See [Note] inplace should be False in fused_experts. 
            return torch.ops.sgl_kernel.fused_experts_cpu(
                x,
                layer.w13_weight,
                layer.w2_weight,
                topk_weights,
                topk_ids,
                False,  # inplace; see [Note] inplace should be False in fused_experts.
                True,  # use_int8_w8a8
                False,  # use_fp8_w8a16
                layer.w13_weight_scale,  # w1_scale
                layer.w2_weight_scale,  # w2_scale
                None,  # block_size
                layer.w13_input_scale,  # a1_scale
                layer.w2_input_scale,  # a2_scale
                True,  # is_vnni
            )

        return fused_experts(
            x,
            layer.w13_weight,
            layer.w2_weight,
            topk_output=topk_output,
            inplace=inplace,
            activation=activation,
            apply_router_weight_on_input=apply_router_weight_on_input,
            use_int8_w8a8=True,
            per_channel_quant=True,
            w1_scale=(layer.w13_weight_scale),
            w2_scale=(layer.w2_weight_scale),
            a1_scale=layer.w13_input_scale,
            a2_scale=layer.w2_input_scale,
            no_combine=no_combine,
            routed_scaling_factor=routed_scaling_factor,
        )


class NPU_W8A8LinearMethodImpl:
    """Linear method for NPU W8A8."""

    def __init__(self) -> None:
        # aclnn quant matmul requires matrix B to be transposed; enabled by default.
        self.transpose_weight = True

    @staticmethod
    def get_weight(
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype = torch.bfloat16,
    ) -> Dict[str, Any]:
        params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}
        return params_dict

    @staticmethod
    def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
        params_dict = {}
        params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
        params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
        return params_dict

    @staticmethod
    def get_perchannel_param(
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        params_dict = {}
        params_dict["quant_bias"] = torch.empty(output_size, dtype=torch.int32)
        if params_dtype == torch.bfloat16:
            params_dict["deq_scale"] = torch.empty(output_size, dtype=torch.float32)
        elif params_dtype == torch.float16:
            params_dict["deq_scale"] = torch.empty(output_size, dtype=torch.int64)
        params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype)
        params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype)
        return params_dict

    @staticmethod
    def apply(
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # To prevent import loops
        from sglang.srt.layers.linear import RowParallelLinear

        original_dtype = x.dtype
        if original_dtype != torch.int8:
            x = torch_npu.npu_quantize(
                x,
                layer.aclnn_input_scale,
                layer.aclnn_input_offset,
                torch.qint8,
                -1,
                True,
            )

        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in Attention TP>1 case)
        if isinstance(layer, RowParallelLinear) and layer.tp_rank > 0:
            quant_bias = None
        else:
            quant_bias = layer.quant_bias

        return torch_npu.npu_quant_matmul(
            x,
            layer.weight,
            layer.deq_scale,
            bias=quant_bias,
            output_dtype=original_dtype,
        )

    def process_weights_after_loading(self, layer):
        expanding_factor = layer.weight.data.shape[1]
        layer.aclnn_input_scale = torch.nn.Parameter(
            layer.input_scale.data.repeat(expanding_factor).to(device="npu"),
            requires_grad=False,
        )
        layer.aclnn_input_offset = torch.nn.Parameter(
            layer.input_offset.data.repeat(expanding_factor).to(device="npu"),
            requires_grad=False,
        )
        if self.transpose_weight:
            layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
        layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
        layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
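

# The MindIE Turbo variant below mirrors the aclnn implementation above but routes
# per-tensor activation quantization and the matmul through mindie_turbo's
# quant_per_tensor / ops.quant_matmul instead of torch_npu.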
class NPU_W8A8LinearMethodMTImpl:
    """Linear method for NPU W8A8, using MindIE Turbo kernels."""

    def __init__(self) -> None:
        self.transpose_weight = True

    @staticmethod
    def get_weight(
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype = torch.bfloat16,
    ) -> Dict[str, Any]:
        params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}
        return params_dict

    @staticmethod
    def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
        params_dict = {}
        params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
        params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
        return params_dict

    @staticmethod
    def get_perchannel_param(
        output_size: int,
        params_dtype: torch.dtype,
    ) -> Dict[str, Any]:
        params_dict = {}
        params_dict["quant_bias"] = torch.empty(output_size, dtype=torch.int32)
        if params_dtype == torch.bfloat16:
            params_dict["deq_scale"] = torch.empty(output_size, dtype=torch.float32)
        elif params_dtype == torch.float16:
            params_dict["deq_scale"] = torch.empty(output_size, dtype=torch.int64)
        params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype)
        params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype)
        return params_dict

    @staticmethod
    def apply(
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # To prevent import loops
        from sglang.srt.layers.linear import RowParallelLinear

        original_dtype = x.dtype
        if original_dtype != torch.int8:
            x = quant_per_tensor(x, layer.input_scale, layer.input_offset)

        # Only fuse bias add into GEMM for rank 0 (this ensures that
        # bias will not get added more than once in Attention TP>1 case)
        if isinstance(layer, RowParallelLinear) and layer.tp_rank > 0:
            quant_bias = None
        else:
            quant_bias = layer.quant_bias

        return ops.quant_matmul(
            x=x, weight=layer.weight, deq_scale=layer.deq_scale, deq_bias=quant_bias
        )

    def process_weights_after_loading(self, layer):
        layer.aclnn_deq_scale = torch.nn.Parameter(
            torch_npu.npu_trans_quant_param(layer.deq_scale.npu()).to(device="npu"),
            requires_grad=False,
        )


class NPU_W8A8LinearMethod(LinearMethodBase):
    """Linear method for NPU quantization.

    This class searches for the specific quantization implementation supported
    on NPU hardware for linear methods.

    Args:
        quant_config: The NPU quantization config.
    """
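
    # The concrete implementation is chosen once at construction time: MindIE
    # Turbo kernels when the mindie_turbo package is importable, otherwise the
    # native aclnn (torch_npu) implementation.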
""" def __init__(self, quantization_config: W8A8Int8Config) -> None: self.quantization_config = quantization_config self.quant_method = ( NPU_W8A8LinearMethodMTImpl() if useMindIETurbo else NPU_W8A8LinearMethodImpl() ) def create_weights( self, layer: torch.nn.Module, input_size_per_partition: int, output_partition_sizes: List[int], input_size: int, output_size: int, params_dtype: torch.dtype, **extra_weight_attrs, ) -> None: output_size_per_partition = sum(output_partition_sizes) weight_loader = extra_weight_attrs.get("weight_loader") weight_dict = self.quant_method.get_weight( input_size_per_partition, output_size_per_partition, params_dtype ) for weight_name, weight_param in weight_dict.items(): param = torch.nn.Parameter(weight_param, requires_grad=False) set_weight_attrs(param, {"input_dim": 1, "output_dim": 0}) layer.register_parameter(weight_name, param) set_weight_attrs(param, extra_weight_attrs) pertensor_dict = self.quant_method.get_pertensor_param(params_dtype) for pertensor_name, pertensor_param in pertensor_dict.items(): param = PerTensorScaleParameter( data=pertensor_param, weight_loader=weight_loader ) # disable warning param.ignore_warning = True layer.register_parameter(pertensor_name, param) perchannel_dict = self.quant_method.get_perchannel_param( output_size_per_partition, params_dtype ) for perchannel_name, perchannel_param in perchannel_dict.items(): param = torch.nn.Parameter(perchannel_param, requires_grad=False) set_weight_attrs(param, {"output_dim": 0}) layer.register_parameter(perchannel_name, param) set_weight_attrs(param, extra_weight_attrs) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: if hasattr(self.quant_method, "process_weights_after_loading"): self.quant_method.process_weights_after_loading(layer) def apply( self, layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None, ) -> torch.Tensor: return self.quant_method.apply(layer, x, bias) class NPU_W8A8DynamicLinearMethodImpl: """Linear method for NPU W8A8_DYNAMIC.""" def __init__(self): self.transpose_weight = True @staticmethod def get_weight( input_size: int, output_size: int, params_dtype: torch.dtype ) -> Dict[str, Any]: params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)} return params_dict @staticmethod def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]: return {} @staticmethod def get_perchannel_param( output_size: int, params_dtype: torch.dtype, ) -> Dict[str, Any]: params_dict = {} params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype) params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype) return params_dict @staticmethod def apply( layer: torch.nn.Module, x: torch.Tensor, bias: Optional[torch.Tensor] = None, tp_rank: Optional[int] = 0, ) -> torch.Tensor: original_dtype = x.dtype quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x) return torch_npu.npu_quant_matmul( quant_out, layer.weight, layer.weight_scale, pertoken_scale=dynamic_scale, bias=bias, output_dtype=original_dtype, ) def process_weights_after_loading(self, layer): if self.transpose_weight: layer.weight.data = layer.weight.data.transpose(0, 1).contiguous() layer.weight_scale.data = layer.weight_scale.data.flatten() layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32) layer.weight_offset.data = layer.weight_offset.data.flatten() class NPU_W8A8DynamicLinearMethod(LinearMethodBase): """Linear method for NPU quantization. 
class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
    """Linear method for NPU quantization.

    This class searches for the specific quantization implementations supported
    on NPU hardware for linear methods.

    Args:
        quant_config: The NPU quantization config.
    """

    def __init__(self, quantization_config: W8A8Int8Config) -> None:
        self.quantization_config = quantization_config
        self.quant_method = NPU_W8A8DynamicLinearMethodImpl()

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        weight_dict = self.quant_method.get_weight(
            input_size_per_partition, output_size_per_partition, params_dtype
        )
        for weight_name, weight_param in weight_dict.items():
            param = torch.nn.Parameter(weight_param, requires_grad=False)
            set_weight_attrs(param, {"input_dim": 1, "output_dim": 0})
            layer.register_parameter(weight_name, param)
            set_weight_attrs(param, extra_weight_attrs)

        pertensor_dict = self.quant_method.get_pertensor_param(params_dtype)
        for pertensor_name, pertensor_param in pertensor_dict.items():
            param = PerTensorScaleParameter(
                data=pertensor_param, weight_loader=weight_loader
            )
            # disable warning
            param.ignore_warning = True
            layer.register_parameter(pertensor_name, param)

        perchannel_dict = self.quant_method.get_perchannel_param(
            output_size_per_partition, params_dtype
        )
        for perchannel_name, perchannel_param in perchannel_dict.items():
            param = torch.nn.Parameter(perchannel_param, requires_grad=False)
            set_weight_attrs(param, {"output_dim": 0})
            layer.register_parameter(perchannel_name, param)
            set_weight_attrs(param, extra_weight_attrs)

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        if hasattr(self.quant_method, "process_weights_after_loading"):
            self.quant_method.process_weights_after_loading(layer)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return self.quant_method.apply(layer, x, bias)


class NPU_W8A8MoEMethod(FusedMoEMethodBase):
    """MoE method for NPU quantization.

    This class searches for the specific quantization implementations supported
    on NPU hardware for MoE methods.

    Args:
        quant_config: The NPU quantization config.
    """
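
    # Expert weights are created as [num_experts, out_features, in_features] and
    # transposed to [num_experts, in_features, out_features] in
    # process_weights_after_loading, the layout consumed by the grouped matmuls
    # inside npu_fused_experts.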
""" def __init__(self, quantization_config: W8A8Int8Config) -> None: self.quantization_config = quantization_config self.quant_method = self def create_weights( self, layer: torch.nn.Module, num_experts: int, hidden_size: int, intermediate_size: int, params_dtype: torch.dtype, **extra_weight_attrs, ) -> None: from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported self.num_experts = num_experts extra_weight_attrs.update( {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value} ) # weight w13_weight = torch.nn.Parameter( torch.empty( num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8 ), requires_grad=False, ) layer.register_parameter("w13_weight", w13_weight) set_weight_attrs(w13_weight, extra_weight_attrs) w2_weight = torch.nn.Parameter( torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8), requires_grad=False, ) layer.register_parameter("w2_weight", w2_weight) set_weight_attrs(w2_weight, extra_weight_attrs) # scale w13_weight_scale = torch.nn.Parameter( torch.empty(num_experts, 2 * intermediate_size, 1, dtype=torch.float32), requires_grad=False, ) layer.register_parameter("w13_weight_scale", w13_weight_scale) set_weight_attrs(w13_weight_scale, extra_weight_attrs) w2_weight_scale = torch.nn.Parameter( torch.empty(num_experts, hidden_size, 1, dtype=torch.float32), requires_grad=False, ) layer.register_parameter("w2_weight_scale", w2_weight_scale) set_weight_attrs(w2_weight_scale, extra_weight_attrs) # offset w13_weight_offset = torch.nn.Parameter( torch.empty(num_experts, 2 * intermediate_size, 1, dtype=torch.float32), requires_grad=False, ) layer.register_parameter("w13_weight_offset", w13_weight_offset) set_weight_attrs(w13_weight_offset, extra_weight_attrs) w2_weight_offset = torch.nn.Parameter( torch.empty(num_experts, hidden_size, 1, dtype=torch.float32), requires_grad=False, ) layer.register_parameter("w2_weight_offset", w2_weight_offset) set_weight_attrs(w2_weight_offset, extra_weight_attrs) def process_weights_after_loading(self, layer: torch.nn.Module) -> None: layer.w13_weight = Parameter( layer.w13_weight.data.transpose(1, 2).contiguous(), requires_grad=False ) layer.w2_weight = Parameter( layer.w2_weight.data.transpose(1, 2).contiguous(), requires_grad=False ) layer.w13_weight_scale = Parameter( layer.w13_weight_scale.data.squeeze(-1).contiguous(), requires_grad=False ) layer.w2_weight_scale = Parameter( layer.w2_weight_scale.data.squeeze(-1).contiguous(), requires_grad=False ) layer.w13_weight_offset = Parameter( layer.w13_weight_offset.data.squeeze(-1).contiguous(), requires_grad=False ) layer.w2_weight_offset = Parameter( layer.w2_weight_offset.data.squeeze(-1).contiguous(), requires_grad=False ) def apply( self, layer, x, topk_output: TopKOutput, **kwargs, ) -> torch.Tensor: topk_weights, topk_ids, _ = topk_output topk_ids = topk_ids.to(torch.int32) topk_weights = topk_weights.to(x.dtype) return npu_fused_experts( hidden_states=x, w13=layer.w13_weight, w13_scale=layer.w13_weight_scale, w2=layer.w2_weight, w2_scale=layer.w2_weight_scale, topk_weights=topk_weights, topk_ids=topk_ids, top_k=topk_ids.shape[1], )