Unverified Commit 39c237f0 authored by ErvinXie's avatar ErvinXie Committed by GitHub
Browse files

Add AWQ quantization support for NPU. (#10158)


Co-authored-by: default avatarAlisehen <814073252@qq.com>
Co-authored-by: default avatarYaochen Han <48639761+Alisehen@users.noreply.github.com>
Co-authored-by: default avatarZhengda Qin <zhengdqin@gmail.com>
parent 28b8a406
......@@ -51,6 +51,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
"CompressedTensorsLinearMethod",
"AWQMarlinLinearMethod",
"AWQLinearMethod",
"AWQLinearAscendMethod",
"GPTQMarlinLinearMethod",
"Fp8LinearMethod",
"BlockInt8LinearMethod",
......
......@@ -31,6 +31,7 @@ from sglang.srt.layers.quantization.marlin_utils import (
)
from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
from sglang.srt.layers.quantization.utils import get_scalar_types, replace_parameter
from sglang.srt.layers.quantization.w8a8_int8 import npu_fused_experts
if TYPE_CHECKING:
from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
......@@ -39,11 +40,16 @@ if TYPE_CHECKING:
CombineInput,
)
from sglang.srt.utils import is_cuda, is_hip, is_xpu
from sglang.srt.utils import is_cuda, is_hip, is_npu, is_xpu
_is_cuda = is_cuda()
_is_hip = is_hip()
_is_xpu = is_xpu()
_is_npu = is_npu()
if _is_npu:
import torch_npu
if _is_cuda:
from sgl_kernel import (
awq_dequantize,
......@@ -117,12 +123,17 @@ class AWQConfig(QuantizationConfig):
return "awq"
def get_supported_act_dtypes(self) -> List[torch.dtype]:
return [torch.half]
return [torch.float16] if not _is_npu else [torch.float16, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
# The AWQ kernel only supports Turing or newer GPUs.
return 75
if _is_npu:
raise NotImplementedError(
'NPU hardware does not support "get_min_capability" feature.'
)
else:
return 75
@staticmethod
def get_config_filenames() -> List[str]:
......@@ -146,6 +157,16 @@ class AWQConfig(QuantizationConfig):
self, layer: torch.nn.Module, prefix: str
) -> Optional[LinearMethodBase]:
from sglang.srt.layers.linear import LinearBase
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
if _is_npu:
if isinstance(layer, LinearBase):
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
return UnquantizedLinearMethod()
return AWQLinearAscendMethod(self)
elif isinstance(layer, FusedMoE):
return AWQMoEAscendMethod(self)
return None
if isinstance(layer, LinearBase):
if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
......@@ -575,6 +596,64 @@ class AWQMarlinLinearMethod(LinearMethodBase):
)
class AWQLinearAscendMethod(AWQLinearMethod):
"""Linear method for AWQ on Ascend.
Args:
quant_config: The AWQ quantization config.
"""
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
qweight_tmp = torch.zeros_like(layer.qweight.data)
qzeros_tmp = layer.qzeros.data
qzeros_list = []
shifts = [0, 4, 1, 5, 2, 6, 3, 7]
for i in range(0, self.quant_config.pack_factor):
shift_num = shifts[i] * 4
qzeros_list.append((qzeros_tmp.reshape(-1, 1) >> shift_num) & 0xF)
qweight_tmp.bitwise_or_(
((layer.qweight.data >> shift_num) * (2 ** (4 * i))) & (0xF << (4 * i))
)
qweight_tmp.bitwise_xor_(0x88888888)
qzeros_tmp = torch.cat(qzeros_list, dim=-1).reshape(qzeros_tmp.shape[0], -1)
qzeros_tmp = -(qzeros_tmp - 8)
qzeros_tmp = qzeros_tmp.to(layer.scales.data.dtype)
layer.qzeros = torch.nn.Parameter(qzeros_tmp, requires_grad=False)
layer.qweight = torch.nn.Parameter(qweight_tmp, requires_grad=False)
def apply(
self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None,
) -> torch.Tensor:
qweight = layer.qweight
scales = layer.scales
qzeros = layer.qzeros
pack_factor = self.quant_config.pack_factor
out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
reshaped_x = x.reshape(-1, x.shape[-1])
if bias is not None and bias.dtype == torch.bfloat16:
bias = bias.float()
out = torch_npu.npu_weight_quant_batchmatmul(
reshaped_x,
qweight,
antiquant_scale=scales,
antiquant_offset=qzeros,
antiquant_group_size=self.quant_config.group_size,
bias=bias,
)
return out.reshape(out_shape)
class AWQMoEMethod(FusedMoEMethodBase):
def __init__(self, quant_config: AWQMarlinConfig):
......@@ -677,7 +756,8 @@ class AWQMoEMethod(FusedMoEMethodBase):
set_weight_attrs(w2_qzeros, extra_weight_attrs)
device = layer.w13_qweight.device
layer.workspace = marlin_make_workspace(device, 4)
if not _is_npu:
layer.workspace = marlin_make_workspace(device, 4)
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
num_experts = layer.w13_qweight.shape[0]
......@@ -785,3 +865,95 @@ class AWQMoEMethod(FusedMoEMethodBase):
num_bits=self.quant_config.weight_bits,
).to(orig_dtype)
return StandardCombineInput(hidden_states=output)
class AWQMoEAscendMethod(AWQMoEMethod):
def __init__(self, quant_config: AWQConfig):
self.quant_config = quant_config
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
w13_qweight_tmp = torch.zeros_like(layer.w13_qweight.data)
w2_qweight_tmp = torch.zeros_like(layer.w2_qweight.data)
w13_qzeros_list = []
w2_qzeros_list = []
shifts = [0, 4, 1, 5, 2, 6, 3, 7]
for i in range(0, self.quant_config.pack_factor):
shift_num = shifts[i] * 4
w13_qzeros_list.append(
(layer.w13_qzeros.data.reshape(-1, 1) >> shift_num) & 0xF
)
w2_qzeros_list.append(
(layer.w2_qzeros.data.reshape(-1, 1) >> shift_num) & 0xF
)
w13_qweight_tmp.bitwise_or_(
((layer.w13_qweight.data >> shift_num) * (2 ** (4 * i)))
& (0xF << (4 * i))
)
w2_qweight_tmp.bitwise_or_(
((layer.w2_qweight.data >> shift_num) * (2 ** (4 * i)))
& (0xF << (4 * i))
)
w13_qweight_tmp.bitwise_xor_(0x88888888)
w2_qweight_tmp.bitwise_xor_(0x88888888)
w13_qzeros_tmp = torch.cat(w13_qzeros_list, dim=-1).reshape(
layer.w13_qzeros.shape[0], layer.w13_qzeros.shape[1], -1
)
w13_qzeros_tmp = -(w13_qzeros_tmp - 8)
w13_qzeros_tmp = w13_qzeros_tmp.to(layer.w13_scales.data.dtype)
w2_qzeros_tmp = torch.cat(w2_qzeros_list, dim=-1).reshape(
layer.w2_qzeros.shape[0], layer.w2_qzeros.shape[1], -1
)
w2_qzeros_tmp = -(w2_qzeros_tmp - 8)
w2_qzeros_tmp = w2_qzeros_tmp.to(layer.w2_scales.data.dtype)
layer.register_parameter(
"w13_qzeros", torch.nn.Parameter(w13_qzeros_tmp, requires_grad=False)
)
layer.register_parameter(
"w13_qweight", torch.nn.Parameter(w13_qweight_tmp, requires_grad=False)
)
layer.register_parameter(
"w2_qzeros", torch.nn.Parameter(w2_qzeros_tmp, requires_grad=False)
)
layer.register_parameter(
"w2_qweight", torch.nn.Parameter(w2_qweight_tmp, requires_grad=False)
)
def create_moe_runner(
self, layer: torch.nn.Module, moe_runner_config: MoeRunnerConfig
):
self.moe_runner_config = moe_runner_config
def apply(
self,
layer: torch.nn.Module,
dispatch_output: StandardDispatchOutput,
) -> torch.Tensor:
from sglang.srt.layers.moe.token_dispatcher import StandardCombineInput
assert (
self.moe_runner_config.activation == "silu"
), "Only SiLU activation is supported."
x = dispatch_output.hidden_states
topk_output = dispatch_output.topk_output
topk_weights, topk_ids, _ = topk_output
topk_ids = topk_ids.to(torch.int32)
topk_weights = topk_weights.to(x.dtype)
output = npu_fused_experts(
hidden_states=x,
w13=layer.w13_qweight,
w13_scale=layer.w13_scales,
w13_offset=layer.w13_qzeros,
w2=layer.w2_qweight,
w2_scale=layer.w2_scales,
w2_offset=layer.w2_qzeros,
topk_weights=topk_weights,
topk_ids=topk_ids,
top_k=topk_ids.shape[1],
use_wna16=True,
)
return StandardCombineInput(hidden_states=output)
......@@ -337,3 +337,32 @@ def awq_gemm_triton(
result = result.sum(0)
return result
def awq_dequantize_decomposition(
qweight: torch.Tensor,
scales: torch.Tensor,
zeros: torch.Tensor,
) -> torch.Tensor:
qweight_tmp = qweight
qzeros_tmp = zeros
qweight_list = []
qzeros_list = []
shifts = [0, 4, 1, 5, 2, 6, 3, 7]
for i in range(0, 8):
shift_num = shifts[i] * 4
qzeros_list.append((qzeros_tmp.reshape(-1, 1) >> shift_num) & 0xF)
qweight_list.append((qweight_tmp.reshape(-1, 1) >> shift_num) & 0xF)
qzeros_tmp = (
torch.cat(qzeros_list, dim=-1).reshape(qzeros_tmp.shape[0], -1).to(scales.dtype)
)
qweight_tmp = (
torch.cat(qweight_list, dim=-1)
.reshape(qweight_tmp.shape[0], -1)
.to(scales.dtype)
)
res = (
qweight_tmp.reshape(qzeros_tmp.shape[0], -1, qzeros_tmp.shape[1])
- qzeros_tmp.unsqueeze(1)
) * scales.unsqueeze(1)
return res.reshape(qweight_tmp.shape[0], -1)
......@@ -102,7 +102,12 @@ def npu_fused_experts(
topk_weights: torch.Tensor,
topk_ids: torch.Tensor,
top_k: int,
**kwargs,
):
w13_offset = kwargs.get("w13_offset", None)
w2_offset = kwargs.get("w2_offset", None)
use_wna16 = kwargs.get("use_wna16", False)
original_shape = hidden_states.shape
original_dtype = hidden_states.dtype
scale_dtype = original_dtype if original_dtype == torch.bfloat16 else torch.float32
......@@ -127,12 +132,22 @@ def npu_fused_experts(
)
expert_tokens = expert_tokens.to(torch.int64)
# gmm1: gate_up_proj
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
if not use_wna16:
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
scale_args13 = {
"scale": [w13_scale.to(scale_dtype)],
"per_token_scale": [pertoken_scale],
}
else:
scale_args13 = {
"antiquant_scale": [w13_scale],
"antiquant_offset": [w13_offset],
}
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w13],
scale=[w13_scale.to(scale_dtype)],
per_token_scale=[pertoken_scale],
**scale_args13,
split_item=2,
group_list_type=0,
group_type=0,
......@@ -141,13 +156,20 @@ def npu_fused_experts(
)[0]
# act_fn: swiglu
hidden_states = torch_npu.npu_swiglu(hidden_states)
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
if not use_wna16:
hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
scale_args2 = {
"scale": [w2_scale.to(scale_dtype)],
"per_token_scale": [pertoken_scale],
}
else:
scale_args2 = {"antiquant_scale": [w2_scale], "antiquant_offset": [w2_offset]}
# gmm2: down_proj
hidden_states = torch_npu.npu_grouped_matmul(
x=[hidden_states],
weight=[w2],
scale=[w2_scale.to(scale_dtype)],
per_token_scale=[pertoken_scale],
**scale_args2,
split_item=2,
group_list_type=0,
group_type=0,
......
......@@ -612,6 +612,8 @@ class DefaultModelLoader(BaseModelLoader):
# parameters onto device for processing and back off after.
with device_loading_context(module, target_device):
quant_method.process_weights_after_loading(module)
if _is_npu:
torch.npu.empty_cache()
class LayeredModelLoader(DefaultModelLoader):
......
......@@ -189,6 +189,10 @@ elif _is_npu:
import custom_ops # noqa: F401
import sgl_kernel_npu # noqa: F401
import torch_npu # noqa: F401
from sglang.srt.layers.quantization.awq_triton import (
awq_dequantize_decomposition as awq_dequantize,
)
else:
pass
......@@ -2965,7 +2969,7 @@ class DeepseekV2ForCausalLM(nn.Module):
)
if hasattr(self_attn.kv_b_proj, "qweight"):
# AWQ compatible
if _is_cuda or _is_hip:
if _is_cuda or _is_hip or _is_npu:
w = awq_dequantize(
self_attn.kv_b_proj.qweight,
self_attn.kv_b_proj.scales,
......
......@@ -510,6 +510,8 @@ def get_available_gpu_memory(
f"WARNING: current device is not {gpu_id}, but {torch.npu.current_device()}, ",
"which may cause useless memory allocation for torch NPU context.",
)
if empty_cache:
torch.npu.empty_cache()
free_gpu_memory, total_gpu_memory = torch.npu.mem_get_info()
if distributed:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment