"...git@developer.sourcefind.cn:renzhc/diffusers_dcu.git" did not exist on "76b7d86a9a5c0c2186efa09c4a67b5f5666ac9e3"
Unverified Commit caa4819b authored by Weiwei's avatar Weiwei Committed by GitHub
Browse files

Add support for AutoRound quantized models (#10153)

parent a88b006e
......@@ -40,6 +40,81 @@ python3 -m sglang.launch_server \
### Examples of Offline Model Quantization
#### Using [auto-round](https://github.com/intel/auto-round)
```bash
# Install
pip install auto-round
```
- LLM quantization
```py
# for LLM
from auto_round import AutoRound
model_id = "meta-llama/Llama-3.2-1B-Instruct"
quant_path = "Llama-3.2-1B-Instruct-autoround-4bit"
# Scheme examples: "W2A16", "W3A16", "W4A16", "W8A16", "NVFP4", "MXFP4" (no real kernels), "GGUF:Q4_K_M", etc.
scheme = "W4A16"
format = "auto_round"
autoround = AutoRound(model_id, scheme=scheme)
autoround.quantize_and_save(quant_path, format=format) # quantize and save
```
- VLM quantization
```py
# for VLMs
from auto_round import AutoRoundMLLM
model_name = "Qwen/Qwen2-VL-2B-Instruct"
quant_path = "Qwen2-VL-2B-Instruct-autoround-4bit"
scheme = "W4A16"
format = "auto_round"
autoround = AutoRoundMLLM(model_name, scheme)
autoround.quantize_and_save(quant_path, format=format) # quantize and save
```
- Command Line Usage (Gaudi/CPU/Intel GPU/CUDA)
```bash
auto-round \
--model meta-llama/Llama-3.2-1B-Instruct \
--bits 4 \
--group_size 128 \
--format "auto_round" \
--output_dir ./tmp_autoround
```
- known issues
Several limitations currently affect offline quantized model loading in sglang, These issues might be resolved in future updates of sglang. If you experience any problems, consider using Hugging Face Transformers as an alternative.
1. Mixed-bit Quantization Limitations
Mixed-bit quantization is not fully supported. Due to vLLM's layer fusion (e.g., QKV fusion), applying different bit-widths to components within the same fused layer can lead to compatibility issues.
2. Limited Support for Quantized MoE Models
Quantized MoE models may encounter inference issues due to kernel limitations (e.g., lack of support for mlp.gate layer quantization). please try to skip quantizing these layers to avoid such errors.
3. Limited Support for Quantized VLMs
<details>
<summary>VLM failure cases</summary>
Qwen2.5-VL-7B
auto_round:auto_gptq format: Accuracy is close to zero.
GPTQ format: Fails with:
```
The output size is not aligned with the quantized weight shape
```
auto_round:auto_awq and AWQ format: These work as expected.
</details>
#### Using [GPTQModel](https://github.com/ModelCloud/GPTQModel)
```bash
......@@ -302,3 +377,4 @@ python3 -m sglang.launch_server \
- [NVIDIA Model Optimizer (ModelOpt)](https://github.com/NVIDIA/TensorRT-Model-Optimizer)
- [Torchao: PyTorch Architecture Optimization](https://github.com/pytorch/ao)
- [vLLM Quantization](https://docs.vllm.ai/en/latest/quantization/)
- [auto-round](https://github.com/intel/auto-round)
......@@ -613,6 +613,7 @@ class ModelConfig:
"petit_nvfp4",
"quark",
"mxfp4",
"auto-round",
]
optimized_quantization_methods = [
"fp8",
......
......@@ -33,7 +33,7 @@ except ImportError as e:
ExpertsInt8Config
) = GPTQMarlin24Config = MarlinConfig = QQQConfig = Int8TpuConfig = DummyConfig
from sglang.srt.layers.quantization.auto_round import AutoRoundConfig
from sglang.srt.layers.quantization.awq import AWQConfig, AWQMarlinConfig
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
......@@ -82,6 +82,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
"w4afp8": W4AFp8Config,
"petit_nvfp4": PetitNvFp4Config,
"fbgemm_fp8": FBGEMMFp8Config,
"auto-round": AutoRoundConfig,
}
......
# SPDX-License-Identifier: Apache-2.0
import logging
import re
from fractions import Fraction
from typing import Any, Optional, Union
import torch
logger = logging.getLogger(__name__)
from sglang.srt.layers.quantization.utils import get_scalar_types
ScalarType, scalar_types = get_scalar_types()
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
from sglang.srt.layers.quantization.base_config import QuantizationConfig
class AutoRoundConfig(QuantizationConfig):
"""Config class for AutoRound.
Reference: https://arxiv.org/pdf/2309.05516
"""
SUPPORTED_BITS = {2, 3, 4, 8}
SUPPORTED_DTYPES = {"int"}
SUPPORTED_FORMATS = {"auto_round:auto_gptq", "auto_round:auto_awq"}
SUPPORTED_BACKENDS = {"auto", "gptq", "gptq:marlin", "awq", "awq:marlin", "marlin"}
def __init__(
self,
weight_bits: int,
group_size: int,
sym: bool = True,
packing_format: str = "auto_round:auto_gptq",
block_name_to_quantize: Optional[Union[str, list[str]]] = None,
extra_config: Optional[dict[str, Any]] = None,
data_type: str = "int",
backend: str = "auto",
) -> None:
super().__init__()
if weight_bits not in self.SUPPORTED_BITS:
raise ValueError(
f"Unsupported weight_bits: {weight_bits}, "
f"currently only support {self.SUPPORTED_BITS}"
)
if data_type not in self.SUPPORTED_DTYPES:
raise ValueError(
f"Unsupported data_type: {data_type},"
f" currently only support {self.SUPPORTED_DTYPES}"
)
if packing_format not in self.SUPPORTED_FORMATS:
raise ValueError(
f"Unsupported packing_format: {packing_format}, "
f"currently only support {self.SUPPORTED_FORMATS}"
)
if backend not in self.SUPPORTED_BACKENDS:
raise ValueError(
f"Unsupported backend: {backend}, "
f"currently only support {self.SUPPORTED_BACKENDS}"
)
self.weight_bits = weight_bits
self.group_size = group_size
self.sym = sym
self.packing_format = packing_format
self.block_name_to_quantize = (
block_name_to_quantize.split(",")
if isinstance(block_name_to_quantize, str)
else block_name_to_quantize
)
self.extra_config = extra_config
self.data_type = data_type
self.backend = backend
self.pack_factor = Fraction(32, weight_bits)
def __repr__(self) -> str:
return (
f"AutoRoundConfig(weight_bits={self.weight_bits}, "
f"group_size={self.group_size}, sym={self.sym})"
)
@classmethod
def get_name(cls):
return "auto-round"
@classmethod
def get_supported_act_dtypes(cls) -> list[torch.dtype]:
return [torch.half, torch.bfloat16]
@classmethod
def get_min_capability(cls) -> int:
return 60
@classmethod
def get_config_filenames(cls) -> list[str]:
return ["quantization_config.json"]
@classmethod
def from_config(cls, config: dict[str, Any]) -> "AutoRoundConfig":
return cls(
weight_bits=cls.get_from_keys(config, ["bits"]),
group_size=cls.get_from_keys(config, ["group_size"]),
sym=cls.get_from_keys(config, ["sym"]),
packing_format=cls.get_from_keys_or(
config,
["packing_format"],
"auto_round:auto_gptq",
),
block_name_to_quantize=cls.get_from_keys_or(
config, ["block_name_to_quantize", "to_quant_block_names"], None
),
extra_config=cls.get_from_keys_or(config, ["extra_config"], None),
data_type=cls.get_from_keys_or(config, ["data_type"], "int"),
backend=cls.get_from_keys_or(
config, ["backend", "vllm_backend", "sglang_backend"], "auto"
),
)
def get_scaled_act_names(self) -> list[str]:
"""Returns the activation function names that should be post-scaled.
For now, this is only used by AWQ.
"""
raise NotImplementedError
def get_layer_config(self, layer, layer_name: str):
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
def get_config(name: str, quantized: bool = True):
if not self.extra_config:
return (
self.weight_bits if quantized else 16,
self.group_size if quantized else -1,
self.sym if quantized else True,
)
# Exact match first
if name in self.extra_config:
cfg = self.extra_config[name]
return (
cfg.get("bits", self.weight_bits if quantized else 16),
cfg.get("group_size", self.group_size if quantized else -1),
cfg.get("sym", self.sym if quantized else True),
)
REGEX_SPECIAL_CHARS = set(r"*+?^$()[]{}|\\")
for pattern, cfg in self.extra_config.items():
if not isinstance(pattern, str) or not any(
c in REGEX_SPECIAL_CHARS for c in pattern
):
continue
try:
if re.fullmatch(pattern, name):
return (
cfg.get("bits", self.weight_bits if quantized else 16),
cfg.get("group_size", self.group_size if quantized else -1),
cfg.get("sym", self.sym if quantized else True),
)
except re.error:
# Invalid regex, ignore.
continue
return (
self.weight_bits if quantized else 16,
self.group_size if quantized else -1,
self.sym if quantized else True,
)
# 1. Exact match from config
if self.extra_config and layer_name in self.extra_config:
return get_config(layer_name)
# 2. Determine whether layer should be quantized
quantized = not isinstance(layer, ParallelLMHead)
if self.block_name_to_quantize:
quantized = any(
layer_name.startswith(name) for name in self.block_name_to_quantize
)
# 3. Handle fused MoE
if self.extra_config and "fusedmoe" in layer.__class__.__name__.lower():
moe_configs = [
get_config(name, quantized)
for name in self.extra_config
if name.startswith(layer_name)
]
if moe_configs:
if len(set(moe_configs)) == 1:
return moe_configs[0]
raise ValueError(
f"Fused MoE layer '{layer_name}' requires "
f"consistent quant config for all sub-layers"
)
# 4. Handle fused QKV or other patterns
if self.extra_config:
for fusion_key, sub_keys in self.packed_modules_mapping.items():
if fusion_key in layer_name and layer_name.count(fusion_key) == 1:
sub_names = [
layer_name.replace(fusion_key, sub_key) for sub_key in sub_keys
]
sub_configs = [get_config(name, quantized) for name in sub_names]
if len(set(sub_configs)) == 1:
return sub_configs[0]
raise ValueError(
f"Fused module '{layer_name}' requires "
f"consistent quant config for {sub_names}"
)
# 5. Fallback or try a regular expression match
return get_config(layer_name, quantized)
def check_quantized(self, weight_bits: int) -> bool:
return weight_bits < 16
def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.marlin_utils import (
check_marlin_supported,
check_moe_marlin_supports_layer,
)
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
if not self.check_quantized(weight_bits):
if isinstance(layer, (LinearBase, ParallelLMHead)):
return UnquantizedLinearMethod()
else:
return None
logger.debug(
"[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
prefix,
layer.__class__.__name__,
weight_bits,
group_size,
sym,
)
if backend == "auto" or "marlin" in backend:
AWQ_TYPE_MAP = {
4: scalar_types.uint4,
8: scalar_types.uint8,
}
use_marlin = (weight_bits in AWQ_TYPE_MAP) and check_marlin_supported(
AWQ_TYPE_MAP[weight_bits], group_size, not sym
)
if isinstance(layer, FusedMoE):
use_marlin = use_marlin and check_moe_marlin_supports_layer(
layer, group_size
)
else:
use_marlin = False
if use_marlin:
from sglang.srt.layers.quantization.awq import (
AWQMarlinConfig,
AWQMarlinLinearMethod,
AWQMoEMethod,
)
quant_args_marlin = AWQMarlinConfig(
weight_bits=weight_bits,
group_size=group_size,
zero_point=not sym,
lm_head_quantized=False,
full_config={},
modules_to_not_convert=[],
)
else:
from sglang.srt.layers.quantization.awq import AWQConfig, AWQLinearMethod
quant_args = AWQConfig(
weight_bits=weight_bits,
group_size=group_size,
zero_point=not sym,
)
if isinstance(layer, FusedMoE):
if use_marlin:
return AWQMoEMethod(quant_args_marlin)
from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
config = {
"quant_method": "awq",
"bits": weight_bits,
"group_size": group_size,
"zero_point": not sym,
"lm_head": False,
}
return MoeWNA16Config.from_config(config).get_quant_method(layer, prefix)
if isinstance(layer, (LinearBase, ParallelLMHead)):
if use_marlin:
return AWQMarlinLinearMethod(quant_args_marlin)
else:
return AWQLinearMethod(quant_args)
return None
def apply_gptq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.marlin_utils import (
check_marlin_supported,
check_moe_marlin_supports_layer,
)
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
if not self.check_quantized(weight_bits):
if isinstance(layer, (LinearBase, ParallelLMHead)):
return UnquantizedLinearMethod()
else:
return None
logger.debug(
"[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
prefix,
layer.__class__.__name__,
weight_bits,
group_size,
sym,
)
if backend == "auto" or "marlin" in backend:
GPTQ_TYPE_MAP = {
(4, True): scalar_types.uint4b8,
(8, True): scalar_types.uint8b128,
}
use_marlin = (weight_bits, sym) in GPTQ_TYPE_MAP and check_marlin_supported(
GPTQ_TYPE_MAP[(weight_bits, sym)], group_size, has_zp=not sym
)
if isinstance(layer, FusedMoE):
use_marlin = use_marlin and check_moe_marlin_supports_layer(
layer, group_size
)
else:
use_marlin = False
if use_marlin:
from sglang.srt.layers.quantization.gptq import (
GPTQMarlinConfig,
GPTQMarlinLinearMethod,
GPTQMarlinMoEMethod,
)
quant_args_marlin = GPTQMarlinConfig(
weight_bits=weight_bits,
group_size=group_size,
is_sym=sym,
lm_head_quantized=False,
desc_act=False,
dynamic={},
full_config={},
)
else:
from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQLinearMethod
quant_args = GPTQConfig(
weight_bits=weight_bits,
group_size=group_size,
lm_head_quantized=False,
desc_act=False,
dynamic={},
)
if isinstance(layer, FusedMoE):
if use_marlin:
from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
config = {
"quant_method": "gptq",
"bits": weight_bits,
"group_size": group_size,
"sym": sym,
"lm_head": False,
}
return MoeWNA16Config.from_config(config).get_quant_method(
layer, prefix
)
return GPTQMarlinMoEMethod(quant_args_marlin)
if isinstance(layer, (LinearBase, ParallelLMHead)):
if use_marlin:
return GPTQMarlinLinearMethod(quant_args_marlin)
else:
return GPTQLinearMethod(quant_args)
return None
def get_quant_method(self, layer: torch.nn.Module, prefix: str):
# TODO enable CPU quant method later
if "gptq" in self.packing_format or "gptq" in self.backend:
return self.apply_gptq_quant_layer(layer, prefix)
if "awq" in self.packing_format or "awq" in self.backend:
return self.apply_awq_quant_layer(layer, prefix)
......@@ -98,6 +98,7 @@ QUANTIZATION_CHOICES = [
"qoq",
"w4afp8",
"mxfp4",
"auto-round",
"compressed-tensors", # for Ktransformers
]
......
......@@ -92,6 +92,10 @@ DEFAULT_STANDALONE_SPECULATIVE_DRAFT_MODEL_FOR_TEST = "meta-llama/Llama-3.2-1B-I
DEFAULT_NGRAM_SPECULATIVE_TARGET_MODEL_FOR_TEST = "Qwen/Qwen2.5-Coder-7B-Instruct"
# Other use cases
DEFAULT_AUTOROUND_MODEL_NAME_FOR_TEST = (
"OPEA/Qwen2.5-0.5B-Instruct-int4-sym-inc", # auto_round:auto_gptq
"Intel/Qwen2-0.5B-Instruct-int4-sym-AutoRound", # auto_round:auto_awq
)
DEFAULT_MODEL_NAME_FOR_TEST_LOCAL_ATTENTION = (
"meta-llama/Llama-4-Scout-17B-16E-Instruct"
)
......
"""
Usage:
python3 -m unittest test_autoround.TestAutoRound.test_mmlu
"""
import unittest
from types import SimpleNamespace
from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
DEFAULT_AUTOROUND_MODEL_NAME_FOR_TEST,
DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
DEFAULT_URL_FOR_TEST,
CustomTestCase,
popen_launch_server,
)
class TestAutoRound(CustomTestCase):
@classmethod
def setUpClass(cls):
cls.base_url = DEFAULT_URL_FOR_TEST
@classmethod
def tearDownClass(cls):
pass
def test_mmlu(self):
device = "auto"
for model in DEFAULT_AUTOROUND_MODEL_NAME_FOR_TEST:
with self.subTest(model=model):
print(f"\n[INFO] Launching server for model: {model}")
process = popen_launch_server(
model,
self.base_url,
timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
other_args=["--trust-remote-code", "--quantization", "auto-round"],
device=device,
)
try:
args = SimpleNamespace(
base_url=self.base_url,
model=model,
eval_name="mmlu",
num_examples=32,
num_threads=32,
device=device,
)
metrics = run_eval(args)
if "Llama" in model:
self.assertGreaterEqual(metrics["score"], 0.6)
else:
self.assertGreaterEqual(metrics["score"], 0.26)
finally:
kill_process_tree(process.pid)
print(f"[INFO] Server for {model} stopped.")
if __name__ == "__main__":
unittest.main()
......@@ -67,6 +67,7 @@ suites = {
TestFile("quant/test_int8_kernel.py", 8),
TestFile("quant/test_triton_scaled_mm.py", 8),
TestFile("quant/test_w8a8_quantization.py", 46),
TestFile("quant/test_autoround.py", 60),
TestFile("rl/test_fp32_lm_head.py", 30),
TestFile("rl/test_update_weights_from_disk.py", 114),
TestFile("rl/test_update_weights_from_tensor.py", 48),
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment