"vscode:/vscode.git/clone" did not exist on "987db374fdd134dd6acf62ab05acbf08adc1c37d"
Unverified commit 7891bac1, authored by jianan-gu, committed by GitHub

[Quantization][w8a8_int8] Fix weight loading issue for w8a8_int8 path with "ignore" layer list in quantization config (#7820)
parent 48c1fa7b
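The bug: when a w8a8_int8 checkpoint's quantization config carries an "ignore" list (layers left unquantized at export time, commonly the lm_head), the old code still routed those layers through the int8 linear method, so their weights failed to load. A sketch of such a config as a Python dict; the key names match the ones consumed by the fix below, the concrete values are illustrative only, not taken from the commit:

# Illustrative w8a8_int8 quantization config; "ignore" and
# "packed_modules_mapping" are the keys read by the fix, the
# values here are example data.
quant_config = {
    "is_dynamic": True,
    "ignore": ["lm_head"],  # layers to keep unquantized
    "packed_modules_mapping": {
        # fused modules and the per-shard names used in the checkpoint
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    },
}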
@@ -347,8 +347,8 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         if use_intel_amx_backend(layer):
             from sglang.srt.layers.moe.topk import (
-                select_experts,
                 apply_topk_weights_cpu,
+                select_experts,
             )

             topk_weights, topk_ids = select_experts(
......
@@ -3,7 +3,7 @@ from __future__ import annotations

 import importlib
 import sys
 from types import MappingProxyType
-from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
+from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union, cast

 import torch
 from torch.nn.parameter import Parameter
@@ -24,6 +24,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.compressed_tensors.utils import should_ignore_layer
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.utils import (
@@ -178,17 +179,18 @@ class W8A8Int8Config(QuantizationConfig):
     - Activation: dynamic, per-token, symmetric
     """

-    def __init__(self, quant_config: Dict[str, Any]):
+    def __init__(self, quant_config: Dict[str, Any] = {}):
         super().__init__()
         self.quant_description = quant_config
         self.is_dynamic = quant_config.get("is_dynamic", False)
-        if _is_npu:
-            if (
-                "packed_modules_mapping" in quant_config
-                and quant_config["packed_modules_mapping"] is not None
-            ):
-                self.packed_modules_mapping = quant_config["packed_modules_mapping"]
+        ignore = cast(List[str], quant_config.get("ignore", []))
+        self.ignore = ignore if ignore is not None else []
+        packed_modules_mapping = quant_config.get("packed_modules_mapping", {})
+        self.packed_modules_mapping = (
+            packed_modules_mapping if packed_modules_mapping is not None else {}
+        )
+        if _is_npu:
             # Ascend w8a8_int8 quantization with bias, use wrappers to isolate the effects between models
             for name in self.quant_description.keys():
                 if "norm.bias" in name:
@@ -237,7 +239,7 @@
         layer: torch.nn.Module,
         prefix: str,
     ) -> Optional[QuantizeMethodBase]:
-        from sglang.srt.layers.linear import LinearBase
+        from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

         if _is_npu:
@@ -262,12 +264,16 @@
             elif isinstance(layer, FusedMoE):
                 return NPU_W8A8MoEMethod(self)
             return None
-        else:
-            if isinstance(layer, LinearBase):
-                return W8A8Int8LinearMethod(self)
-            elif isinstance(layer, FusedMoE):
-                return W8A8Int8MoEMethod(self)
-            return None
+
+        if should_ignore_layer(
+            prefix, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
+        ):
+            return UnquantizedLinearMethod()
+        if isinstance(layer, LinearBase):
+            return W8A8Int8LinearMethod(self)
+        elif isinstance(layer, FusedMoE):
+            return W8A8Int8MoEMethod(self)
+        return None

     def is_layer_skipped(
         self, prefix: str, fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
......
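With those attributes in place, get_quant_method short-circuits ignored layers to UnquantizedLinearMethod before any int8 method is chosen. The real matcher is should_ignore_layer from sglang.srt.layers.quantization.compressed_tensors.utils; the sketch below is a simplified stand-in that only does exact-name matching (the real helper is more flexible), meant to show how the ignore list and the fused-module mapping interact:

from typing import List, Mapping

def should_ignore_layer_sketch(
    prefix: str,
    ignore: List[str],
    fused_mapping: Mapping[str, List[str]] = {},
) -> bool:
    # Simplified stand-in, not the sglang implementation: a layer is
    # skipped when its full name is ignored, or, for a fused module
    # such as "qkv_proj", when every constituent shard is ignored.
    parent, _, name = prefix.rpartition(".")
    if name in fused_mapping:
        shards = [
            f"{parent}.{shard}" if parent else shard
            for shard in fused_mapping[name]
        ]
        return all(shard in ignore for shard in shards)
    return prefix in ignore

print(should_ignore_layer_sketch("lm_head", ignore=["lm_head"]))  # True
print(should_ignore_layer_sketch(
    "model.layers.0.self_attn.qkv_proj",
    ignore=[
        "model.layers.0.self_attn.q_proj",
        "model.layers.0.self_attn.k_proj",
        "model.layers.0.self_attn.v_proj",
    ],
    fused_mapping={"qkv_proj": ["q_proj", "k_proj", "v_proj"]},
))  # True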