Unverified Commit 766392c6 authored by ronnie_zheng, committed by GitHub

[feature]Ascend quantization support (#7791)


Co-authored-by: ichernob <ichernobnn@gmail.com>
Co-authored-by: liupeng <liupeng374@huawei.com>
parent 4a0d1919
......@@ -413,7 +413,9 @@ class ModelConfig:
quant_cfg = self._parse_quant_hf_config()
if quant_cfg is not None:
quant_method = quant_cfg.get("quant_method", "").lower()
quant_method = quant_cfg.get(
"quant_method", "" if not self.quantization else self.quantization
).lower()
# Detect which checkpoint it is
for _, method in QUANTIZATION_METHODS.items():
......
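For illustration, a minimal standalone sketch of the fallback added above (the config values here are hypothetical): when the checkpoint's HF quantization config omits "quant_method", the user-supplied --quantization value is used instead of an empty string.

quant_cfg = {"bits": 8}          # hypothetical checkpoint config without "quant_method"
user_quantization = "w8a8_int8"  # hypothetical --quantization value, may be None

quant_method = quant_cfg.get(
    "quant_method", "" if not user_quantization else user_quantization
).lower()
assert quant_method == "w8a8_int8"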
......@@ -34,6 +34,7 @@ from sglang.srt.layers.quantization.base_config import (
from sglang.srt.utils import (
cpu_has_amx_support,
is_cpu,
is_npu,
set_weight_attrs,
use_intel_amx_backend,
)
......@@ -60,6 +61,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
_is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
_is_npu = is_npu()
def adjust_marlin_shard(param, shard_size, shard_offset):
......@@ -297,6 +299,14 @@ class ReplicatedLinear(LinearBase):
if len(loaded_weight.shape) == 0:
loaded_weight = loaded_weight.reshape(1)
# The per-tensor quant scale must be 1-dimensional
if _is_npu:
if param.size() != loaded_weight.size() and param.size(0) == 1:
if torch.allclose(loaded_weight, loaded_weight[0]):
loaded_weight = loaded_weight[:1]
else:
raise ValueError(f"{loaded_weight} are not all equal")
assert param.size() == loaded_weight.size()
param.data.copy_(loaded_weight)
......
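A self-contained sketch of the NPU branch above, with made-up sizes: when a checkpoint stores a per-tensor scale as a vector repeated per channel, it is collapsed back to a single element before being copied into the 1-element parameter.

import torch

param = torch.nn.Parameter(torch.zeros(1), requires_grad=False)  # per-tensor scale slot
loaded_weight = torch.full((128,), 0.02)                          # scale repeated per channel

if param.size() != loaded_weight.size() and param.size(0) == 1:
    if torch.allclose(loaded_weight, loaded_weight[0]):
        loaded_weight = loaded_weight[:1]
    else:
        raise ValueError("per-tensor quant scale values are not all equal")
assert param.size() == loaded_weight.size()
param.data.copy_(loaded_weight)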
......@@ -12,7 +12,6 @@ from sglang.srt.distributed import (
)
from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
from sglang.srt.layers.moe.ep_moe.kernels import (
ep_gather,
ep_scatter,
......@@ -65,6 +64,8 @@ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
if not _is_npu:
from sgl_kernel import silu_and_mul
from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
if _is_hip:
from vllm._custom_ops import scaled_fp8_quant
......
......@@ -850,7 +850,7 @@ class FusedMoE(torch.nn.Module):
return
# Case weight scales and zero_points
if "scale" in weight_name or "zero" in weight_name:
if "scale" in weight_name or "zero" in weight_name or "offset" in weight_name:
# load the weight scales and zp based on the quantization scheme
# supported weight scales/zp can be found in
# FusedMoeWeightScaleSupported
......
......@@ -308,7 +308,7 @@ def biased_grouped_topk_gpu(
renormalize: bool,
num_expert_group: int = 0,
topk_group: int = 0,
compiled: bool = True,
compiled: bool = not _is_npu,
num_fused_shared_experts: int = 0,
routed_scaling_factor: Optional[float] = None,
num_token_non_padded: Optional[torch.Tensor] = None,
......
......@@ -116,8 +116,7 @@ class MoeWNA16Config(QuantizationConfig):
@classmethod
def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg)
if can_convert and user_quant == "moe_wna16":
if user_quant == "moe_wna16" and cls.is_moe_wna16_compatible(hf_quant_cfg):
return cls.get_name()
return None
......
from typing import Any, Callable, Dict, List, Optional
import importlib
import sys
from types import MappingProxyType
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
import torch
from torch.nn.parameter import Parameter
from sglang.srt.distributed import get_tensor_model_parallel_world_size
from sglang.srt.distributed import (
get_tensor_model_parallel_rank,
get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
from sglang.srt.layers.linear import LinearMethodBase
from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
from sglang.srt.layers.linear import (
LinearMethodBase,
RowParallelLinear,
UnquantizedLinearMethod,
)
from sglang.srt.layers.parameter import (
ChannelQuantScaleParameter,
ModelWeightParameter,
PerTensorScaleParameter,
)
from sglang.srt.layers.quantization.base_config import (
QuantizationConfig,
QuantizeMethodBase,
)
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
from sglang.srt.utils import (
apply_module_patch,
cpu_has_amx_support,
is_cpu,
is_cuda,
is_npu,
set_weight_attrs,
use_intel_amx_backend,
)
......@@ -25,6 +41,134 @@ _is_cpu_amx_available = cpu_has_amx_support()
_is_cpu = is_cpu()
if _is_cuda:
from sgl_kernel import int8_scaled_mm
_is_npu = is_npu()
if _is_npu:
import torch_npu
try:
from mindie_turbo import _ops as ops
from mindie_turbo.quantize.quant_utils import quant_per_tensor
except ImportError:
useMindIETurbo = False
else:
useMindIETurbo = True
# func refers to RMSNorm.__init__
def npu_wrapper_rmsnorm_init(func):
def init(self, hidden_size: int, **extra_args) -> None:
func(self, hidden_size, **extra_args)
self.ignore_anti = True
# The Ascend w8a8_int8 quantization requires adding a bias in rmsnorm
self.bias = torch.nn.Parameter(torch.zeros(hidden_size), requires_grad=False)
return init
# func refers to RMSNorm.forward_oot
def npu_wrapper_rmsnorm_forward(func):
def _rmsnorm_forward_oot(
self,
x: torch.Tensor,
residual: Optional[torch.Tensor] = None,
) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
if not x.is_contiguous():
x = x.contiguous()
original_dtype = x.dtype
x = x.to(torch.float32)
if residual is not None:
x = x + residual.to(torch.float32)
residual = x.to(original_dtype)
x = (
torch_npu.npu_rms_norm(
x, self.weight.to(torch.float32), self.variance_epsilon
)[0]
+ self.bias
)
if residual is None:
return x.to(original_dtype)
return x.to(original_dtype), residual
return _rmsnorm_forward_oot
def npu_fused_experts(
hidden_states: torch.Tensor,
w13: torch.Tensor,
w13_scale: torch.Tensor,
w2: torch.Tensor,
w2_scale: torch.Tensor,
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
):
original_shape = hidden_states.shape
original_dtype = hidden_states.dtype
scale_dtype = original_dtype if original_dtype == torch.bfloat16 else torch.float32
if len(original_shape) == 3:
hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
num_tokens = hidden_states.shape[0]
num_experts = w13.shape[0]
row_idx_len = num_tokens * top_k
row_idx = (
torch.arange(0, row_idx_len, dtype=torch.int32, device=topk_weights.device)
.view(top_k, -1)
.permute(1, 0)
.contiguous()
)
hidden_states, expanded_row_idx, expanded_expert_idx = (
torch_npu.npu_moe_init_routing(
hidden_states, row_idx=row_idx, expert_idx=topk_ids, active_num=num_tokens
)
)
expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
expanded_expert_idx, num_experts
)
expert_tokens = expert_tokens.to(torch.int64)
# gmm1: gate_up_proj
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w13],
scale=[w13_scale.to(scale_dtype)],
per_token_scale=[pertoken_scale],
split_item=2,
group_list_type=0,
group_type=0,
group_list=expert_tokens,
output_dtype=original_dtype,
)[0]
# act_fn: swiglu
hidden_states = torch_npu.npu_swiglu(hidden_states)
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w2],
scale=[w2_scale.to(scale_dtype)],
per_token_scale=[pertoken_scale],
split_item=2,
group_list_type=0,
group_type=0,
group_list=expert_tokens,
output_dtype=original_dtype,
)[0]
final_hidden_states = torch_npu.npu_moe_finalize_routing(
hidden_states,
skip1=None,
skip2=None,
bias=None,
scales=topk_weights,
expanded_src_to_dst_row=expanded_row_idx,
export_for_source_row=topk_ids,
)
if len(original_shape) == 3:
final_hidden_states = final_hidden_states.view(original_shape)
return final_hidden_states
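As a reference for what the grouped kernels above compute, the routing/expert/combine math can be written as a dense float loop (no int8 quantization, illustrative shapes only; this is not the kernel implementation):

import torch
import torch.nn.functional as F

def fused_experts_reference(hidden_states, w13, w2, topk_weights, topk_ids):
    # hidden_states: [T, H], w13: [E, 2*I, H], w2: [E, H, I]
    out = torch.zeros_like(hidden_states)
    for t in range(hidden_states.shape[0]):
        for k in range(topk_ids.shape[1]):
            e = int(topk_ids[t, k])
            gate, up = (w13[e] @ hidden_states[t]).chunk(2)  # gate_up_proj
            act = F.silu(gate) * up                          # swiglu
            out[t] += topk_weights[t, k] * (w2[e] @ act)     # down_proj + weighted combine
    return out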
class W8A8Int8Config(QuantizationConfig):
......@@ -34,16 +178,47 @@ class W8A8Int8Config(QuantizationConfig):
- Activation: dynamic, per-token, symmetric
"""
def __init__(self):
pass
def __init__(self, quant_config: Dict[str, Any]):
super().__init__()
self.quant_description = quant_config
self.is_dynamic = quant_config.get("is_dynamic", False)
if _is_npu:
if (
"packed_modules_mapping" in quant_config
and quant_config["packed_modules_mapping"] is not None
):
self.packed_modules_mapping = quant_config["packed_modules_mapping"]
# Ascend w8a8_int8 quantization adds a bias to RMSNorm; use wrappers to isolate the effect to the models that need it
for name in self.quant_description.keys():
if "norm.bias" in name:
apply_module_patch(
"sglang.srt.layers.layernorm.RMSNorm",
"__init__",
[npu_wrapper_rmsnorm_init],
)
apply_module_patch(
"sglang.srt.layers.layernorm.RMSNorm",
"forward_npu",
[npu_wrapper_rmsnorm_forward],
)
@classmethod
def get_supported_act_dtypes(cls) -> List[torch.dtype]:
return [torch.float16, torch.bfloat16]
return (
[torch.float16, torch.bfloat16]
if not _is_npu
else [torch.int8, torch.float16, torch.bfloat16]
)
@classmethod
def get_min_capability(cls) -> int:
return 75
if _is_npu:
raise NotImplementedError(
'NPU hardware does not support the "get_min_capability" feature.'
)
else:
return 75
@classmethod
def get_name(self) -> str:
......@@ -55,7 +230,7 @@ class W8A8Int8Config(QuantizationConfig):
@classmethod
def from_config(cls, config: Dict[str, Any]) -> "W8A8Int8Config":
return cls()
return cls(config)
def get_quant_method(
self,
......@@ -65,11 +240,65 @@ class W8A8Int8Config(QuantizationConfig):
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
if isinstance(layer, LinearBase):
return W8A8Int8LinearMethod(self)
elif isinstance(layer, FusedMoE):
return W8A8Int8MoEMethod(self)
return None
if _is_npu:
if isinstance(layer, LinearBase):
prefix_in_quant_config = prefix
proj_name = prefix.split(".")[-1]
if proj_name in self.packed_modules_mapping:
prefix_in_quant_config = prefix.replace(
proj_name, self.packed_modules_mapping[proj_name][0]
)
self.is_dynamic = (
self.quant_description[prefix_in_quant_config + ".weight"]
== "W8A8_DYNAMIC"
)
if self.is_layer_skipped(prefix, self.packed_modules_mapping):
return UnquantizedLinearMethod()
return (
NPU_W8A8DynamicLinearMethod(self)
if self.is_dynamic
else NPU_W8A8LinearMethod(self)
)
elif isinstance(layer, FusedMoE):
return NPU_W8A8MoEMethod(self)
return None
else:
if isinstance(layer, LinearBase):
return W8A8Int8LinearMethod(self)
elif isinstance(layer, FusedMoE):
return W8A8Int8MoEMethod(self)
return None
def is_layer_skipped(
self, prefix: str, fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
):
# adapted from vllm.model_executor.layers.quantization.utils.quant_utils.is_layer_skipped
proj_name = prefix.split(".")[-1]
if proj_name in fused_mapping:
shard_prefixes = [
prefix.replace(proj_name, shard_proj_name)
for shard_proj_name in fused_mapping[proj_name]
]
is_skipped = None
for shard_prefix in shard_prefixes:
is_shard_skipped = (
self.quant_description[shard_prefix + ".weight"] == "FLOAT"
)
if is_skipped is None:
is_skipped = is_shard_skipped
elif is_shard_skipped != is_skipped:
raise ValueError(
f"Detected some but not all shards of {prefix} "
"are quantized. All shards of fused layers "
"to have the same precision."
)
else:
is_skipped = self.quant_description[prefix + ".weight"] == "FLOAT"
assert is_skipped is not None
return is_skipped
def get_scaled_act_names(self) -> List[str]:
return []
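To illustrate how is_layer_skipped reads the per-layer descriptions (the module names and dict below are hypothetical), all shards of a fused layer must agree on precision, and "FLOAT" layers fall back to the unquantized linear method:

quant_description = {
    "model.layers.0.self_attn.q_proj.weight": "W8A8",
    "model.layers.0.self_attn.k_proj.weight": "W8A8",
    "model.layers.0.self_attn.v_proj.weight": "W8A8",
    "model.layers.0.mlp.gate_proj.weight": "FLOAT",
    "model.layers.0.mlp.up_proj.weight": "FLOAT",
}
fused_mapping = {"qkv_proj": ["q_proj", "k_proj", "v_proj"],
                 "gate_up_proj": ["gate_proj", "up_proj"]}

def skipped(prefix):
    proj = prefix.split(".")[-1]
    shards = [prefix.replace(proj, s) for s in fused_mapping.get(proj, [proj])]
    flags = {quant_description[s + ".weight"] == "FLOAT" for s in shards}
    assert len(flags) == 1, f"shards of {prefix} have mixed precision"
    return flags.pop()

assert skipped("model.layers.0.self_attn.qkv_proj") is False  # quantized -> NPU_W8A8LinearMethod
assert skipped("model.layers.0.mlp.gate_up_proj") is True     # float -> UnquantizedLinearMethod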
......@@ -321,3 +550,498 @@ class W8A8Int8MoEMethod:
no_combine=no_combine,
routed_scaling_factor=routed_scaling_factor,
)
class NPU_W8A8LinearMethodImpl:
"""Linear method for NPU W8A8."""
def __init__(self) -> None:
# aclnn quant matmul requires matrix B to be transposed; enabled by default.
self.transpose_weight = True
@staticmethod
def get_weight(
input_size: int,
output_size: int,
params_dtype: torch.dtype = torch.bfloat16,
) -> Dict[str, Any]:
params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}
return params_dict
@staticmethod
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
params_dict = {}
params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
return params_dict
@staticmethod
def get_perchannel_param(
output_size: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
params_dict = {}
params_dict["quant_bias"] = torch.empty(output_size, dtype=torch.int32)
if params_dtype == torch.bfloat16:
params_dict["deq_scale"] = torch.empty(output_size, dtype=torch.float32)
elif params_dtype == torch.float16:
params_dict["deq_scale"] = torch.empty(output_size, dtype=torch.int64)
params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype)
params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype)
return params_dict
@staticmethod
def apply(
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = 0,
) -> torch.Tensor:
original_dtype = x.dtype
if original_dtype != torch.int8:
x = torch_npu.npu_quantize(
x,
layer.aclnn_input_scale,
layer.aclnn_input_offset,
torch.qint8,
-1,
True,
)
quant_bias = layer.quant_bias if tp_rank == 0 else None
return torch_npu.npu_quant_matmul(
x,
layer.weight,
layer.deq_scale,
bias=quant_bias,
output_dtype=original_dtype,
)
def process_weights_after_loading(self, layer):
expanding_factor = layer.weight.data.shape[1]
layer.aclnn_input_scale = torch.nn.Parameter(
layer.input_scale.data.repeat(expanding_factor).to(device="npu"),
requires_grad=False,
)
layer.aclnn_input_offset = torch.nn.Parameter(
layer.input_offset.data.repeat(expanding_factor).to(device="npu"),
requires_grad=False,
)
if self.transpose_weight:
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
class NPU_W8A8LinearMethodMTImpl:
"""Linear method for NPU W8A8."""
def __init__(self) -> None:
self.transpose_weight = True
@staticmethod
def get_weight(
input_size: int,
output_size: int,
params_dtype: torch.dtype = torch.bfloat16,
) -> Dict[str, Any]:
params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}
return params_dict
@staticmethod
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
params_dict = {}
params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
return params_dict
@staticmethod
def get_perchannel_param(
output_size: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
params_dict = {}
params_dict["quant_bias"] = torch.empty(output_size, dtype=torch.int32)
if params_dtype == torch.bfloat16:
params_dict["deq_scale"] = torch.empty(output_size, dtype=torch.float32)
elif params_dtype == torch.float16:
params_dict["deq_scale"] = torch.empty(output_size, dtype=torch.int64)
params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype)
params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype)
return params_dict
@staticmethod
def apply(
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = 0,
) -> torch.Tensor:
original_dtype = x.dtype
if original_dtype != torch.int8:
x = quant_per_tensor(x, layer.input_scale, layer.input_offset)
quant_bias = layer.quant_bias if tp_rank == 0 else None
return ops.quant_matmul(
x=x, weight=layer.weight, deq_scale=layer.deq_scale, deq_bias=quant_bias
)
def process_weights_after_loading(self, layer):
layer.aclnn_deq_scale = torch.nn.Parameter(
torch_npu.npu_trans_quant_param(layer.deq_scale.npu()).to(device="npu"),
requires_grad=False,
)
class NPU_W8A8LinearMethod(LinearMethodBase):
"""Linear method for NPU quantization.
This class searches for the specific quantization
implementation supported on NPU hardware for linear methods.
Args:
quant_config: The NPU quantization config.
"""
def __init__(self, quantization_config: W8A8Int8Config) -> None:
self.quantization_config = quantization_config
self.quant_method = (
NPU_W8A8LinearMethodMTImpl()
if useMindIETurbo
else NPU_W8A8LinearMethodImpl()
)
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
weight_dict = self.quant_method.get_weight(
input_size_per_partition, output_size_per_partition, params_dtype
)
for weight_name, weight_param in weight_dict.items():
param = torch.nn.Parameter(weight_param, requires_grad=False)
set_weight_attrs(param, {"input_dim": 1, "output_dim": 0})
layer.register_parameter(weight_name, param)
set_weight_attrs(param, extra_weight_attrs)
pertensor_dict = self.quant_method.get_pertensor_param(params_dtype)
for pertensor_name, pertensor_param in pertensor_dict.items():
param = PerTensorScaleParameter(
data=pertensor_param, weight_loader=weight_loader
)
# disable warning
param.ignore_warning = True
layer.register_parameter(pertensor_name, param)
perchannel_dict = self.quant_method.get_perchannel_param(
output_size_per_partition, params_dtype
)
for perchannel_name, perchannel_param in perchannel_dict.items():
param = torch.nn.Parameter(perchannel_param, requires_grad=False)
set_weight_attrs(param, {"output_dim": 0})
layer.register_parameter(perchannel_name, param)
set_weight_attrs(param, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if hasattr(self.quant_method, "process_weights_after_loading"):
self.quant_method.process_weights_after_loading(layer)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if isinstance(layer, RowParallelLinear):
tp_rank = get_tensor_model_parallel_rank()
return self.quant_method.apply(layer, x, bias, tp_rank)
return self.quant_method.apply(layer, x, bias)
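The static-quantization arithmetic delegated to the NPU kernel can be sketched in plain PyTorch as per-tensor activation quantization, an int8 matmul, and per-channel dequantization (shapes and scales below are invented; the real npu_quant_matmul also handles the dtype-specific deq_scale packing and runs on device):

import torch

x = torch.randn(4, 16)                                          # activations
w_int8 = torch.randint(-128, 128, (32, 16), dtype=torch.int8)   # quantized weight
input_scale = torch.tensor(0.05)                                 # per-tensor activation scale
input_offset = torch.tensor(0.0)                                 # per-tensor zero point
deq_scale = torch.rand(32) * 1e-3                                # per-output-channel dequant scale
quant_bias = torch.randint(-10, 10, (32,), dtype=torch.int32)    # added on tp_rank 0 only

x_q = torch.clamp((x / input_scale).round() + input_offset, -128, 127).to(torch.int8)
acc = x_q.float() @ w_int8.float().t() + quant_bias.float()      # emulate int32 accumulation
y = acc * deq_scale                                              # back to floating point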
class NPU_W8A8DynamicLinearMethodImpl:
"""Linear method for NPU W8A8_DYNAMIC."""
def __init__(self):
self.transpose_weight = True
@staticmethod
def get_weight(
input_size: int, output_size: int, params_dtype: torch.dtype
) -> Dict[str, Any]:
params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}
return params_dict
@staticmethod
def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
return {}
@staticmethod
def get_perchannel_param(
output_size: int,
params_dtype: torch.dtype,
) -> Dict[str, Any]:
params_dict = {}
params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype)
params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype)
return params_dict
@staticmethod
def apply(
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
tp_rank: Optional[int] = 0,
) -> torch.Tensor:
original_dtype = x.dtype
# use ATB quantize
quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)
return torch_npu.npu_quant_matmul(
quant_out,
layer.weight,
layer.weight_scale,
pertoken_scale=dynamic_scale,
bias=bias,
output_dtype=original_dtype,
)
def process_weights_after_loading(self, layer):
if self.transpose_weight:
layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
layer.weight_scale.data = layer.weight_scale.data.flatten()
layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
layer.weight_offset.data = layer.weight_offset.data.flatten()
class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
"""Linear method for NPU quantization.
This class searches for specific quantization
implementations supported on NPU hardware for linear methods.
Args:
quant_config: The NPU quantization config.
"""
def __init__(self, quantization_config: W8A8Int8Config) -> None:
self.quantization_config = quantization_config
self.quant_method = NPU_W8A8DynamicLinearMethodImpl()
def create_weights(
self,
layer: torch.nn.Module,
input_size_per_partition: int,
output_partition_sizes: List[int],
input_size: int,
output_size: int,
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
output_size_per_partition = sum(output_partition_sizes)
weight_loader = extra_weight_attrs.get("weight_loader")
weight_dict = self.quant_method.get_weight(
input_size_per_partition, output_size_per_partition, params_dtype
)
for weight_name, weight_param in weight_dict.items():
param = torch.nn.Parameter(weight_param, requires_grad=False)
set_weight_attrs(param, {"input_dim": 1, "output_dim": 0})
layer.register_parameter(weight_name, param)
set_weight_attrs(param, extra_weight_attrs)
pertensor_dict = self.quant_method.get_pertensor_param(params_dtype)
for pertensor_name, pertensor_param in pertensor_dict.items():
param = PerTensorScaleParameter(
data=pertensor_param, weight_loader=weight_loader
)
# disable warning
param.ignore_warning = True
layer.register_parameter(pertensor_name, param)
perchannel_dict = self.quant_method.get_perchannel_param(
output_size_per_partition, params_dtype
)
for perchannel_name, perchannel_param in perchannel_dict.items():
param = torch.nn.Parameter(perchannel_param, requires_grad=False)
set_weight_attrs(param, {"output_dim": 0})
layer.register_parameter(perchannel_name, param)
set_weight_attrs(param, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
if hasattr(self.quant_method, "process_weights_after_loading"):
self.quant_method.process_weights_after_loading(layer)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
if isinstance(layer, RowParallelLinear):
tp_rank = get_tensor_model_parallel_rank()
return self.quant_method.apply(layer, x, bias, tp_rank)
return self.quant_method.apply(layer, x, bias)
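The dynamic path differs only in how the activation scale is obtained: it is computed per token at runtime rather than read from the checkpoint. A plain-PyTorch sketch of that quantization step (shapes invented; roughly what npu_dynamic_quant produces on device):

import torch

x = torch.randn(4, 16)
per_token_scale = x.abs().amax(dim=-1, keepdim=True) / 127.0  # one scale per token
x_q = torch.clamp((x / per_token_scale).round(), -128, 127).to(torch.int8)
# The matmul then dequantizes with per_token_scale (rows) * weight_scale (output channels).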
class NPU_W8A8MoEMethod:
"""MoE method for NPU quantization.
This class searches for specific quantization
implementations supported on NPU hardware for MoE methods.
Args:
quant_config: The NPU quantization config.
"""
def __init__(self, quantization_config: W8A8Int8Config) -> None:
self.quantization_config = quantization_config
self.quant_method = self
def create_weights(
self,
layer: torch.nn.Module,
num_experts: int,
hidden_size: int,
intermediate_size: List[int],
params_dtype: torch.dtype,
**extra_weight_attrs,
) -> None:
from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
self.num_experts = num_experts
extra_weight_attrs.update(
{"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
)
# weight
w13_weight = torch.nn.Parameter(
torch.empty(
num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8
),
requires_grad=False,
)
layer.register_parameter("w13_weight", w13_weight)
set_weight_attrs(w13_weight, extra_weight_attrs)
w2_weight = torch.nn.Parameter(
torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8),
requires_grad=False,
)
layer.register_parameter("w2_weight", w2_weight)
set_weight_attrs(w2_weight, extra_weight_attrs)
# scale
w13_weight_scale = torch.nn.Parameter(
torch.empty(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w13_weight_scale", w13_weight_scale)
set_weight_attrs(w13_weight_scale, extra_weight_attrs)
w2_weight_scale = torch.nn.Parameter(
torch.empty(num_experts, hidden_size, 1, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w2_weight_scale", w2_weight_scale)
set_weight_attrs(w2_weight_scale, extra_weight_attrs)
# offset
w13_weight_offset = torch.nn.Parameter(
torch.empty(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w13_weight_offset", w13_weight_offset)
set_weight_attrs(w13_weight_offset, extra_weight_attrs)
w2_weight_offset = torch.nn.Parameter(
torch.empty(num_experts, hidden_size, 1, dtype=torch.float32),
requires_grad=False,
)
layer.register_parameter("w2_weight_offset", w2_weight_offset)
set_weight_attrs(w2_weight_offset, extra_weight_attrs)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.w13_weight = Parameter(
layer.w13_weight.data.transpose(1, 2).contiguous(), requires_grad=False
)
layer.w2_weight = Parameter(
layer.w2_weight.data.transpose(1, 2).contiguous(), requires_grad=False
)
layer.w13_weight_scale = Parameter(
layer.w13_weight_scale.data.squeeze(-1).contiguous(), requires_grad=False
)
layer.w2_weight_scale = Parameter(
layer.w2_weight_scale.data.squeeze(-1).contiguous(), requires_grad=False
)
layer.w13_weight_offset = Parameter(
layer.w13_weight_offset.data.squeeze(-1).contiguous(), requires_grad=False
)
layer.w2_weight_offset = Parameter(
layer.w2_weight_offset.data.squeeze(-1).contiguous(), requires_grad=False
)
def apply(
self,
layer,
x,
router_logits,
top_k,
renormalize,
use_grouped_topk,
topk_group,
num_expert_group,
num_fused_shared_experts,
custom_routing_function,
correction_bias,
activation,
apply_router_weight_on_input,
routed_scaling_factor,
**kwargs,
) -> torch.Tensor:
from sglang.srt.layers.moe.topk import select_experts
global_num_experts = router_logits.shape[-1]
# NOTE: npu_moe_gating_top_k currently only supports the `group_count=256` pattern
if global_num_experts == 256:
topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
router_logits,
k=top_k,
bias=correction_bias,
k_group=topk_group,
group_count=num_expert_group,
group_select_mode=1,
renorm=0,
norm_type=1,
routed_scaling_factor=1,
eps=float(1e-20),
)
else:
topk_weights, topk_ids = select_experts(
hidden_states=x,
router_logits=router_logits,
use_grouped_topk=use_grouped_topk,
top_k=top_k,
renormalize=renormalize,
topk_group=topk_group,
num_expert_group=num_expert_group,
num_fused_shared_experts=num_fused_shared_experts,
custom_routing_function=custom_routing_function,
correction_bias=correction_bias,
torch_native=True,
routed_scaling_factor=routed_scaling_factor,
)
topk_ids = topk_ids.to(torch.int32)
topk_weights = topk_weights.to(x.dtype)
return npu_fused_experts(
hidden_states=x,
w13=layer.w13_weight,
w13_scale=layer.w13_weight_scale,
w2=layer.w2_weight,
w2_scale=layer.w2_weight_scale,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=top_k,
)
......@@ -34,16 +34,18 @@ import torch
import torch.distributed as dist
import triton
import triton.language as tl
from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla
from sglang.srt.constants import GPU_MEMORY_TYPE_KV_CACHE
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.utils import get_bool_env_var, is_cuda, next_power_of_2
from sglang.srt.utils import get_bool_env_var, is_cuda, is_npu, next_power_of_2
logger = logging.getLogger(__name__)
GB = 1024 * 1024 * 1024
_is_cuda = is_cuda()
_is_npu = is_npu()
if not _is_npu:
from sgl_kernel.kvcacheio import transfer_kv_per_layer, transfer_kv_per_layer_mla
class ReqToTokenPool:
......
......@@ -64,10 +64,13 @@ from sglang.srt.model_loader.weight_utils import (
from sglang.srt.utils import (
get_bool_env_var,
get_device_capability,
is_npu,
is_pin_memory_available,
set_weight_attrs,
)
_is_npu = is_npu()
@contextmanager
def device_loading_context(module: torch.nn.Module, target_device: torch.device):
......@@ -127,18 +130,19 @@ def _get_quantization_config(
# (yizhang2077) workaround for nvidia/Llama-4-Maverick-17B-128E-Eagle3
if quant_config is None:
return None
major, minor = get_device_capability()
if major is not None and minor is not None:
assert 0 <= minor < 10
capability = major * 10 + minor
if capability < quant_config.get_min_capability():
raise ValueError(
f"The quantization method {model_config.quantization} "
"is not supported for the current GPU. "
f"Minimum capability: {quant_config.get_min_capability()}. "
f"Current capability: {capability}."
)
if not _is_npu:
major, minor = get_device_capability()
if major is not None and minor is not None:
assert 0 <= minor < 10
capability = major * 10 + minor
if capability < quant_config.get_min_capability():
raise ValueError(
f"The quantization method {model_config.quantization} "
"is not supported for the current GPU. "
f"Minimum capability: {quant_config.get_min_capability()}. "
f"Current capability: {capability}."
)
supported_dtypes = quant_config.get_supported_act_dtypes()
if model_config.dtype not in supported_dtypes:
raise ValueError(
......@@ -157,6 +161,13 @@ def _initialize_model(
"""Initialize a model with the given configurations."""
model_class, _ = get_model_architecture(model_config)
packed_modules_mapping = getattr(model_class, "packed_modules_mapping", {})
if _is_npu:
packed_modules_mapping["fused_qkv_a_proj_with_mqa"] = [
"q_a_proj",
"kv_a_proj_with_mqa",
]
packed_modules_mapping["qkv_proj"] = ["q_proj", "k_proj", "v_proj"]
packed_modules_mapping["gate_up_proj"] = ["gate_proj", "up_proj"]
quant_config = _get_quantization_config(
model_config, load_config, packed_modules_mapping
)
......
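The NPU-only packed_modules_mapping entries above let the quantization config translate a fused parameter name back to the per-projection names stored in Ascend checkpoints; a small example with a hypothetical layer prefix:

packed_modules_mapping = {
    "fused_qkv_a_proj_with_mqa": ["q_a_proj", "kv_a_proj_with_mqa"],
    "qkv_proj": ["q_proj", "k_proj", "v_proj"],
    "gate_up_proj": ["gate_proj", "up_proj"],
}
prefix = "model.layers.0.self_attn.fused_qkv_a_proj_with_mqa"   # hypothetical fused layer
proj_name = prefix.split(".")[-1]
# W8A8Int8Config.get_quant_method looks up the first shard's entry in quant_description:
first_shard = prefix.replace(proj_name, packed_modules_mapping[proj_name][0])
assert first_shard == "model.layers.0.self_attn.q_a_proj"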
......@@ -575,6 +575,8 @@ class LlamaForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
......
......@@ -407,6 +407,8 @@ class QuantMixtralForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
......@@ -418,6 +420,8 @@ class QuantMixtralForCausalLM(nn.Module):
# Skip experts that are not assigned to this worker.
if "block_sparse_moe.experts." in name and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = getattr(param, "weight_loader", default_weight_loader)
weight_loader(param, loaded_weight)
......
......@@ -538,6 +538,8 @@ class Qwen2ForCausalLM(nn.Module):
# Skip loading extra bias for GPTQ models.
if name.endswith(".bias") and name not in params_dict:
continue
if name not in params_dict:
continue
param = params_dict[name]
weight_loader = param.weight_loader
weight_loader(param, loaded_weight, shard_id)
......
......@@ -197,7 +197,7 @@ def get_int_env_var(name: str, default: int = 0) -> int:
def support_triton(backend: str) -> bool:
return backend not in ["torch_native", "intel_amx"]
return backend not in ["torch_native", "intel_amx", "ascend"]
try:
......@@ -2782,3 +2782,101 @@ def lru_cache_frozenset(maxsize=128):
return wrapper
return decorator
def apply_module_patch(target_module, target_function, wrappers):
original_module, original_function = parse_module_path(
target_module, target_function, False
)
original_function_id = id(original_function)
candidate = original_function
for wrapper in wrappers:
candidate = wrapper(candidate)
if target_function is not None:
setattr(original_module, target_function, candidate)
for key, value in sys.modules.copy().items():
if (
target_function is not None
and hasattr(value, target_function)
and id(getattr(value, target_function)) == original_function_id
):
setattr(value, target_function, candidate)
def parse_module_path(module_path, function_name, create_dummy):
from importlib.machinery import ModuleSpec
def create_dummy_module(full_path, parent=None):
"""Create and register a placeholder module"""
dummy = types.ModuleType(full_path)
dummy.__file__ = "vllm_ascend.dummy_module.py"
dummy.__spec__ = ModuleSpec(full_path, None)
sys.modules[full_path] = dummy
if parent:
setattr(parent, full_path.split(".")[-1], dummy)
return dummy
def create_placeholder_function(func_name):
"""Create dummy function that raises when called"""
def placeholder(*args, **kwargs):
raise NotImplementedError(f"Function {func_name} is a placeholder")
placeholder.__name__ = func_name
return placeholder
modules = module_path.split(".")
current_module = None
processed_path = []
for idx, part in enumerate(modules):
current_path = ".".join(modules[: idx + 1])
parent_path = ".".join(modules[:idx]) if idx > 0 else None
try:
current_module = importlib.import_module(current_path)
except ModuleNotFoundError:
# Handle missing module
parent = importlib.import_module(parent_path) if parent_path else None
if parent and hasattr(parent, part):
# Use existing attribute from parent
current_module = getattr(parent, part)
# Check for early function resolution
if function_name and hasattr(current_module, function_name):
return current_module, getattr(current_module, function_name)
if function_name and create_dummy:
ph_func = create_placeholder_function(function_name)
setattr(current_module, function_name, ph_func)
return current_module, ph_func
if function_name:
raise AttributeError(
f"Function {function_name} missing in {current_path}"
)
else:
if not create_dummy:
raise
# Create and register dummy module
current_module = create_dummy_module(
current_path,
parent=(
importlib.import_module(parent_path) if parent_path else None
),
)
processed_path.append(part)
# Final function handling
final_module = sys.modules[module_path]
if function_name is not None:
if not hasattr(final_module, function_name):
if create_dummy:
ph_func = create_placeholder_function(function_name)
setattr(final_module, function_name, ph_func)
else:
setattr(final_module, function_name, None)
return final_module, getattr(final_module, function_name)
return final_module, None
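A minimal usage sketch of the apply_module_patch helper defined above, patching a toy module's function with a wrapper in the same style the NPU config uses for RMSNorm (the module and function names here are made up):

import sys
import types

toy = types.ModuleType("toy_layernorm")        # stand-in for a real sglang module
toy.forward_npu = lambda x: x + 1
sys.modules["toy_layernorm"] = toy

def double_output(func):                       # wrapper in the npu_wrapper_* style
    def wrapped(x):
        return 2 * func(x)
    return wrapped

apply_module_patch("toy_layernorm", "forward_npu", [double_output])
assert toy.forward_npu(3) == 8                 # (3 + 1) * 2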