Unverified commit 7891bac1, authored by jianan-gu, committed by GitHub

[Quantization][w8a8_int8] Fix weight loading issue for w8a8_int8 path with "ignore" layer list in quantization config (#7820)
parent 48c1fa7b
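
For context, w8a8_int8 checkpoints exported in the compressed-tensors style can carry an "ignore" list naming layers (for example lm_head) whose weights stay in the original precision. Before this change the non-NPU w8a8_int8 path did not consult that list and tried to load INT8 weights and scales for those layers, which broke weight loading. The sketch below is illustrative only: the helper matches_ignore_list and the sample config are hypothetical stand-ins (the real code uses should_ignore_layer from compressed_tensors.utils); it shows the routing idea — prefixes matching the ignore list stay unquantized, everything else goes down the INT8 path.

    # Illustrative sketch only -- matches_ignore_list and sample_quant_config are
    # hypothetical; the actual implementation calls should_ignore_layer() from
    # sglang.srt.layers.quantization.compressed_tensors.utils.
    from typing import List, Mapping

    sample_quant_config = {
        "quant_method": "w8a8_int8",
        "ignore": ["lm_head", "model.layers.0.mlp.gate"],  # layers kept in high precision
        "packed_modules_mapping": {"qkv_proj": ["q_proj", "k_proj", "v_proj"]},
    }

    def matches_ignore_list(
        prefix: str,
        ignore: List[str],
        fused_mapping: Mapping[str, List[str]],
    ) -> bool:
        """Return True if a layer prefix (or every shard of a fused layer) is ignored."""
        module = prefix.split(".")[-1]
        if module in fused_mapping:
            # A fused layer is skipped only if all of its original shards are ignored.
            shards = [prefix.replace(module, shard) for shard in fused_mapping[module]]
            return all(any(s == i or s.endswith(i) for i in ignore) for s in shards)
        return any(prefix == i or prefix.endswith(i) for i in ignore)

    # Ignored prefixes keep unquantized weights; others load INT8 weights and scales.
    for name in ["lm_head", "model.layers.0.self_attn.qkv_proj", "model.layers.0.mlp.gate"]:
        quantized = not matches_ignore_list(
            name, sample_quant_config["ignore"], sample_quant_config["packed_modules_mapping"]
        )
        print(f"{name}: {'int8' if quantized else 'unquantized'}")
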
@@ -347,8 +347,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         if use_intel_amx_backend(layer):
             from sglang.srt.layers.moe.topk import (
-                select_experts,
                 apply_topk_weights_cpu,
+                select_experts,
             )
 
             topk_weights, topk_ids = select_experts(
...
@@ -3,7 +3,7 @@ from __future__ import annotations
 import importlib
 import sys
 from types import MappingProxyType
-from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union, cast
 
 import torch
 from torch.nn.parameter import Parameter
@@ -24,6 +24,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.compressed_tensors.utils import should_ignore_layer
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.utils import (
@@ -178,17 +179,18 @@ class W8A8Int8Config(QuantizationConfig):
     - Activation: dynamic, per-token, symmetric
     """
 
-    def __init__(self, quant_config: Dict[str, Any]):
+    def __init__(self, quant_config: Dict[str, Any] = {}):
         super().__init__()
         self.quant_description = quant_config
         self.is_dynamic = quant_config.get("is_dynamic", False)
 
-        if _is_npu:
-            if (
-                "packed_modules_mapping" in quant_config
-                and quant_config["packed_modules_mapping"] is not None
-            ):
-                self.packed_modules_mapping = quant_config["packed_modules_mapping"]
+        ignore = cast(List[str], quant_config.get("ignore", []))
+        self.ignore = ignore if ignore is not None else []
+        packed_modules_mapping = quant_config.get("packed_modules_mapping", {})
+        self.packed_modules_mapping = (
+            packed_modules_mapping if packed_modules_mapping is not None else {}
+        )
+        if _is_npu:
             # Ascend w8a8_int8 quantization with bias, use wrappers to isolate the effects between models
             for name in self.quant_description.keys():
                 if "norm.bias" in name:
@@ -237,7 +239,7 @@ class W8A8Int8Config(QuantizationConfig):
         layer: torch.nn.Module,
         prefix: str,
     ) -> Optional[QuantizeMethodBase]:
-        from sglang.srt.layers.linear import LinearBase
+        from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 
         if _is_npu:
@@ -262,12 +264,16 @@ class W8A8Int8Config(QuantizationConfig):
             elif isinstance(layer, FusedMoE):
                 return NPU_W8A8MoEMethod(self)
             return None
-        else:
-            if isinstance(layer, LinearBase):
-                return W8A8Int8LinearMethod(self)
-            elif isinstance(layer, FusedMoE):
-                return W8A8Int8MoEMethod(self)
-            return None
+
+        if should_ignore_layer(
+            prefix, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
+        ):
+            return UnquantizedLinearMethod()
+        if isinstance(layer, LinearBase):
+            return W8A8Int8LinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return W8A8Int8MoEMethod(self)
+        return None
 
     def is_layer_skipped(
         self, prefix: str, fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
...
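
A minimal usage sketch of the new behavior, assuming an sglang checkout that includes this commit. The import path for W8A8Int8Config and the example layer prefixes are assumptions (they are not shown in this diff); only the call shape of should_ignore_layer and the config fields come from the change above. With the ignore list wired through, an ignored prefix now resolves to UnquantizedLinearMethod in get_quant_method, while other linear layers keep W8A8Int8LinearMethod.

    # Hedged usage sketch: the w8a8_int8 module path and the example prefixes are
    # assumptions; should_ignore_layer's signature is taken from the diff above.
    from sglang.srt.layers.quantization.compressed_tensors.utils import should_ignore_layer
    from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config

    cfg = W8A8Int8Config(
        {
            "ignore": ["lm_head"],
            "packed_modules_mapping": {"qkv_proj": ["q_proj", "k_proj", "v_proj"]},
        }
    )

    # Prefixes on the ignore list fall back to unquantized loading; the rest stay INT8.
    for prefix in ["lm_head", "model.layers.0.self_attn.qkv_proj"]:
        skipped = should_ignore_layer(
            prefix, ignore=cfg.ignore, fused_mapping=cfg.packed_modules_mapping
        )
        method = "UnquantizedLinearMethod" if skipped else "W8A8Int8LinearMethod"
        print(f"{prefix} -> {method}")
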