Unverified Commit 17de02f9 authored by bjmsong, committed by GitHub

Integration of TurboMind AWQ (#2828)


Co-authored-by: root <bjmsong@126.com>
parent 51ab3ccf
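Not part of the diff: a minimal usage sketch of the option this PR adds, mirroring the launch pattern of the test included below. `MODEL_PATH` is a placeholder and is assumed to point at an AWQ-quantized checkpoint.

```python
# Minimal sketch: serve an AWQ checkpoint with the new TurboMind backend.
# MODEL_PATH is a placeholder; it must point to an AWQ-quantized model.
from sglang.test.test_utils import (
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)

MODEL_PATH = "path/to/awq-quantized-model"

process = popen_launch_server(
    MODEL_PATH,
    DEFAULT_URL_FOR_TEST,
    timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    other_args=["--quantization", "awq_turbomind"],
)
```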
......@@ -28,7 +28,7 @@ runtime_common = [
 srt = [
     "sglang[runtime_common]", "cuda-python",
     "sgl-kernel>=0.0.2.post11", "torch", "vllm>=0.6.3.post1,<=0.6.4.post1",
-    "flashinfer==0.1.6"
+    "flashinfer==0.1.6", "turbomind"
 ]
 # HIP (Heterogeneous-computing Interface for Portability) for AMD
......
......@@ -14,6 +14,7 @@
 import json
 import logging
 import sys
+from enum import IntEnum, auto
 from typing import List, Optional, Set, Union
......@@ -230,7 +231,7 @@ class ModelConfig:
         # Parse quantization method from the HF model config, if available.
         quant_cfg = self._parse_quant_hf_config()
-        if quant_cfg is not None:
+        if quant_cfg is not None and not quantization_in_turbomind(self.quantization):
             quant_method = quant_cfg.get("quant_method", "").lower()
             # Detect which checkpoint it is
......@@ -401,3 +402,10 @@ def is_multimodal_model(model_architectures: List[str]):
 def is_encoder_decoder_model(model_architectures: List[str]):
     return "MllamaForConditionalGeneration" in model_architectures
+
+
+def quantization_in_turbomind(quantization: str) -> bool:
+    if quantization in ["awq_turbomind"]:
+        return True
+    else:
+        return False
......@@ -48,6 +48,7 @@ WEIGHT_LOADER_V2_SUPPORTED = [
     "GPTQLinearMethod",
     "FBGEMMFp8LinearMethod",
     "ModelOptFp8LinearMethod",
+    "AWQTurbomindLinearMethod",
     "IPEXAWQLinearMethod",
 ]
......
......@@ -20,6 +20,7 @@ from vllm.model_executor.layers.quantization.marlin import MarlinConfig
 from vllm.model_executor.layers.quantization.qqq import QQQConfig
 from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig

+from sglang.srt.layers.quantization.awq_turbomind import AWQTurbomindConfig
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
......@@ -37,6 +38,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "gptq_marlin_24": GPTQMarlin24Config,
     "gptq_marlin": GPTQMarlinConfig,
     "awq_marlin": AWQMarlinConfig,
+    "awq_turbomind": AWQTurbomindConfig,
     "gptq": GPTQConfig,
     "compressed-tensors": CompressedTensorsConfig,
     "bitsandbytes": BitsAndBytesConfig,
......
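To illustrate how the new registry entry is consumed, here is a small sketch that is not part of the diff. It resolves the entry and builds a config from a typical AWQ checkpoint `quantization_config`; the dictionary values are illustrative, and the constructor validates device support, so a CUDA GPU with compute capability >= 7.0 is assumed. `AWQTurbomindConfig` itself is defined in the new `sglang.srt.layers.quantization.awq_turbomind` module shown next.

```python
# Illustrative only: resolve the new registry entry and build a config object.
# Assumes a CUDA GPU with compute capability >= 7.0 (the constructor checks this).
from sglang.srt.layers.quantization import QUANTIZATION_METHODS

config_cls = QUANTIZATION_METHODS["awq_turbomind"]  # -> AWQTurbomindConfig

# Typical fields of an AWQ checkpoint's quantization config (illustrative values).
hf_quant_cfg = {"bits": 4, "group_size": 128, "zero_point": True}
quant_config = config_cls.from_config(hf_quant_cfg)
print(quant_config)  # AWQTurbomindConfig(weight_bits=4, group_size=128, ...)
```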
import logging
import os
import sys
from typing import Any, Dict, List, Optional

import torch
import turbomind
from torch.nn import Parameter

# Locate TurboMind's prebuilt extension library and make it importable.
turbomind_dir = os.path.split(turbomind.__file__)[0]
sys.path.append(os.path.join(turbomind_dir, "lib"))
import _turbomind_ext

from vllm.model_executor.layers.linear import LinearBase

from sglang.srt.layers.linear import LinearMethodBase, UnquantizedLinearMethod
from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.turbomind_utils import (
    get_u4_slices,
    is_layer_skipped_awq,
    pack_u4_row,
    unpack_awq_gemm,
    verify_turbomind_supported,
)
from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
from sglang.srt.utils import is_cuda, set_weight_attrs

logger = logging.getLogger(__name__)


class AWQTurbomindConfig(QuantizationConfig):
    """Config class for AWQ Turbomind"""

    def __init__(
        self,
        weight_bits: int,
        group_size: int,
        zero_point: bool,
        lm_head_quantized: bool,
        modules_to_not_convert: Optional[List[str]] = None,
    ) -> None:
        self.pack_factor = 32 // weight_bits  # packed into int32
        self.group_size = group_size
        self.zero_point = zero_point
        self.lm_head_quantized = lm_head_quantized
        self.weight_bits = weight_bits
        self.modules_to_not_convert = modules_to_not_convert or []

        verify_turbomind_supported(self.weight_bits, self.group_size)

    def __repr__(self) -> str:
        return (
            f"AWQTurbomindConfig(weight_bits={self.weight_bits}, "
            f"group_size={self.group_size}, "
            f"zero_point={self.zero_point}, "
            f"lm_head_quantized={self.lm_head_quantized}, "
            f"modules_to_not_convert={self.modules_to_not_convert})"
        )

    @classmethod
    def get_name(cls) -> str:
        return "awq_turbomind"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.half, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 70

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return ["quantize_config.json"]

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "AWQTurbomindConfig":
        weight_bits = cls.get_from_keys(config, ["bits"])
        group_size = cls.get_from_keys(config, ["group_size"])
        zero_point = cls.get_from_keys(config, ["zero_point"])
        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
        modules_to_not_convert = cls.get_from_keys_or(
            config, ["modules_to_not_convert"], None
        )
        return cls(
            weight_bits,
            group_size,
            zero_point,
            lm_head_quantized,
            modules_to_not_convert,
        )

    @classmethod
    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
        can_convert = cls.is_awq_turbomind_compatible(hf_quant_cfg)
        is_valid_user_quant = user_quant is None or user_quant == "awq_turbomind"

        if can_convert and is_valid_user_quant:
            msg = (
                f"The model is convertible to {cls.get_name()} during runtime. "
                f"Using {cls.get_name()} kernel."
            )
            logger.info(msg)
            return cls.get_name()

        if can_convert and user_quant == "awq":
            logger.info(
                "Detected that the model can run with awq_turbomind"
                ", however you specified quantization=awq explicitly,"
                " so forcing awq. Use quantization=awq_turbomind for"
                " faster inference"
            )
        return None

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        if isinstance(layer, LinearBase) or (
            isinstance(layer, ParallelLMHead) and self.lm_head_quantized
        ):
            if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
                return UnquantizedLinearMethod()
            return AWQTurbomindLinearMethod(self)
        return None

    @classmethod
    def is_awq_turbomind_compatible(cls, quant_config: Dict[str, Any]) -> bool:
        if not is_cuda():
            return False

        # Extract data from quant config.
        quant_method = quant_config.get("quant_method", "").lower()
        num_bits = quant_config.get("bits")
        group_size = quant_config.get("group_size")
        zero_point = quant_config.get("zero_point")

        if quant_method != "awq":
            return False

        # If we cannot find the info needed in the config, cannot convert.
        if num_bits is None or group_size is None or zero_point is None:
            return False

        # verify_turbomind_supported raises NotImplementedError for unsupported
        # bit widths, group sizes, or device capabilities; treat that as
        # "not compatible" instead of letting it abort quantization detection.
        try:
            return verify_turbomind_supported(
                quant_bit=num_bits, group_size=group_size
            )
        except NotImplementedError:
            return False

    def get_scaled_act_names(self) -> List[str]:
        return []
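A brief sketch of how the auto-override behaves; it is not part of the diff. It assumes a CUDA GPU with compute capability >= 7.0 and an ordinary 4-bit, group-size-128 AWQ checkpoint config; the dictionary values are illustrative.

```python
# Illustrative only: automatic upgrade of a plain AWQ checkpoint to awq_turbomind.
# Requires a CUDA GPU with compute capability >= 7.0.
from sglang.srt.layers.quantization.awq_turbomind import AWQTurbomindConfig

hf_quant_cfg = {
    "quant_method": "awq",
    "bits": 4,
    "group_size": 128,
    "zero_point": True,
}

# No explicit user choice -> the checkpoint is upgraded to the TurboMind kernel.
assert (
    AWQTurbomindConfig.override_quantization_method(hf_quant_cfg, None)
    == "awq_turbomind"
)

# The user explicitly asked for plain awq -> no override, awq is kept.
assert AWQTurbomindConfig.override_quantization_method(hf_quant_cfg, "awq") is None
```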


class AWQTurbomindLinearMethod(LinearMethodBase):
    """Linear method for AWQ Turbomind.

    Args:
        quant_config: The AWQ Turbomind quantization config.
    """

    def __init__(self, quant_config: AWQTurbomindConfig) -> None:
        self.quant_config = quant_config

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ) -> None:
        output_size_per_partition = sum(output_partition_sizes)
        weight_loader = extra_weight_attrs.get("weight_loader")

        # Normalize group_size
        if self.quant_config.group_size != -1:
            group_size = self.quant_config.group_size
        else:
            group_size = input_size

        qweight = PackedvLLMParameter(
            data=torch.empty(
                input_size_per_partition,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader,
        )

        num_groups = input_size_per_partition // group_size
        qzeros = PackedvLLMParameter(
            data=torch.empty(
                num_groups,
                output_size_per_partition // self.quant_config.pack_factor,
                dtype=torch.int32,
            ),
            input_dim=0,
            output_dim=1,
            packed_dim=1,
            packed_factor=self.quant_config.pack_factor,
            weight_loader=weight_loader,
        )

        scales = GroupQuantScaleParameter(
            data=torch.empty(
                num_groups,
                output_size_per_partition,
                dtype=params_dtype,
            ),
            input_dim=0,
            output_dim=1,
            weight_loader=weight_loader,
        )

        layer.register_parameter("qweight", qweight)
        layer.register_parameter("qzeros", qzeros)
        layer.register_parameter("scales", scales)

        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.num_groups = num_groups
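For orientation, a short sketch of the parameter shapes `create_weights` above allocates; it is not part of the diff, and the 4096-sized partition is an assumption made only for the example.

```python
# Illustrative only: parameter shapes for a hypothetical 4096 x 4096 partition
# with 4-bit weights and group_size 128.
input_size_per_partition = 4096
output_size_per_partition = 4096
pack_factor = 32 // 4                         # 8 int4 values per int32
num_groups = input_size_per_partition // 128  # 32

qweight_shape = (input_size_per_partition, output_size_per_partition // pack_factor)
qzeros_shape = (num_groups, output_size_per_partition // pack_factor)
scales_shape = (num_groups, output_size_per_partition)

assert qweight_shape == (4096, 512)  # int32, AWQ GEMM layout
assert qzeros_shape == (32, 512)     # int32
assert scales_shape == (32, 4096)    # params_dtype (e.g. fp16)
```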

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        # Unpack the AWQ GEMM layout into one uint8 value per int4 weight/zero.
        qweight_turbomind = unpack_awq_gemm(layer.qweight.data)
        qzeros_turbomind = unpack_awq_gemm(layer.qzeros.data)
        scales_turbomind = layer.scales.data

        # Repack weights row-major into int32 and convert zero points to fp16
        # before handing them to the TurboMind extension.
        qweight_turbomind = pack_u4_row(qweight_turbomind)
        qzeros_turbomind = qzeros_turbomind.to(torch.half)

        device_id = layer.qweight.device.index
        properties = torch.cuda.get_device_properties(device_id)

        def is_16xx_series(name):
            import re

            pattern = r"GTX 16\d\d"
            return bool(re.search(pattern, name))

        # GTX 16xx cards have no tensor cores, so use the SIMT code path.
        simt = is_16xx_series(properties.name)

        qweight_turbomind = qweight_turbomind.contiguous()
        scales_turbomind = scales_turbomind.contiguous()
        qzeros_turbomind = qzeros_turbomind.contiguous()

        self.linear = _turbomind_ext.Linear(
            layer.input_size_per_partition,
            layer.output_size_per_partition,
            self.quant_config.weight_bits,
            self.quant_config.group_size,
        )
        self.linear.post_init(
            qweight_turbomind, scales_turbomind, qzeros_turbomind, simt
        )

        layer.qweight = Parameter(qweight_turbomind, requires_grad=False)
        layer.scales = Parameter(scales_turbomind, requires_grad=False)
        layer.qzeros = Parameter(qzeros_turbomind, requires_grad=False)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        # Compute the output shape before flattening so leading dims are preserved.
        out_shape = x.shape[:-1] + (layer.output_size_per_partition,)
        x = x.view(-1, x.shape[-1])

        out = torch.empty(
            (x.shape[0], layer.output_size_per_partition),
            dtype=torch.float16,
            device=x.device,
        )
        # Run the TurboMind GEMM on the current CUDA stream.
        stream = torch.cuda.current_stream()
        self.linear.forward(x, out, stream.cuda_stream)
        out = torch.from_dlpack(out)

        if bias is not None:
            out.add_(bias)

        return out.view(out_shape)
......
# Copyright (c) OpenMMLab. All rights reserved.
from typing import List

import torch

from sglang.srt.utils import get_device_capability


def get_u4_slices(x: torch.Tensor, dtype: torch.dtype) -> List[torch.Tensor]:
    assert x.dtype == torch.int32
    xs = []
    for _ in range(8):
        xs.append((x & 15).to(dtype))
        x = x >> 4
    return xs


def unpack_awq_gemm(x: torch.Tensor) -> torch.Tensor:
    """Unpack AWQ's interleaved int4 layout into one value per element.

    The int4 weights are packed into int32:
        bit:    31-28  27-24  23-20  19-16  15-12  11-8   7-4    3-0
        weight: int4_1 int4_2 int4_3 int4_4 int4_5 int4_6 int4_7 int4_8
    """
    xs = get_u4_slices(x, torch.uint8)
    order = [0, 4, 1, 5, 2, 6, 3, 7]
    ys = [xs[i] for i in order]
    return torch.stack(ys, dim=-1).view(*x.shape[:-1], -1)


def pack_u4_row(x: torch.Tensor) -> torch.Tensor:
    """Pack groups of eight uint8 nibbles into int32, lowest nibble first."""
    assert x.dtype == torch.uint8
    xs = x.view(*x.shape[:-1], -1, 8).split(1, dim=-1)
    a = torch.zeros(xs[0].shape, dtype=torch.int32, device=x.device)
    for t in reversed(xs):
        a = (a << 4) | t
    return a.squeeze(dim=-1)


def verify_turbomind_supported(quant_bit: int, group_size: int) -> bool:
    if quant_bit not in [4]:
        raise NotImplementedError(
            f"[Turbomind] Only 4-bit is supported for now, but got {quant_bit} bit"
        )
    if group_size != 128:
        raise NotImplementedError(
            f"[Turbomind] Only group_size 128 is supported for now, "
            f"but got group_size {group_size}"
        )
    major, minor = get_device_capability()
    capability = major * 10 + minor
    if capability < 70:
        raise NotImplementedError(
            f"[Turbomind] Only capability >= 70 is supported for now, but got {capability}"
        )
    return True


def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]) -> bool:
    return any(module_name in prefix for module_name in modules_to_not_convert)
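To make the nibble bookkeeping in these helpers (the new `sglang.srt.layers.quantization.turbomind_utils` module above) concrete, here is a small worked sketch; it is not part of the diff, and the input value is arbitrary, chosen only to show the reordering.

```python
# Illustrative only: convert one AWQ-packed int32 to the row-major packing
# used after process_weights_after_loading. Runs on CPU.
import torch

from sglang.srt.layers.quantization.turbomind_utils import pack_u4_row, unpack_awq_gemm

# Nibbles 0..7 stored little-endian: bits 3-0 hold 0, bits 31-28 hold 7.
x = torch.tensor([[0x76543210]], dtype=torch.int32)

# AWQ stores nibbles interleaved; unpacking with order [0, 4, 1, 5, 2, 6, 3, 7]
# yields one uint8 per int4 value.
unpacked = unpack_awq_gemm(x)
print(unpacked.tolist())  # [[0, 4, 1, 5, 2, 6, 3, 7]]

# pack_u4_row re-packs eight nibbles row-major (element 0 goes into bits 3-0).
repacked = pack_u4_row(unpacked)
print(hex(repacked.item()))  # 0x73625140
```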
......@@ -375,6 +375,7 @@ class ServerArgs:
                 "marlin",
                 "gptq_marlin",
                 "awq_marlin",
+                "awq_turbomind",
                 "bitsandbytes",
                 "gguf",
                 "modelopt",
......
import unittest
from types import SimpleNamespace

from sglang.srt.utils import kill_process_tree
from sglang.test.run_eval import run_eval
from sglang.test.test_utils import (
    DEFAULT_MODEL_NAME_FOR_TEST,
    DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
    DEFAULT_URL_FOR_TEST,
    popen_launch_server,
)


class TestAWQTurbomind(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        cls.model = DEFAULT_MODEL_NAME_FOR_TEST
        cls.base_url = DEFAULT_URL_FOR_TEST
        cls.process = popen_launch_server(
            cls.model,
            cls.base_url,
            timeout=DEFAULT_TIMEOUT_FOR_SERVER_LAUNCH,
            other_args=[
                "--quantization",
                "awq_turbomind",
            ],
        )

    @classmethod
    def tearDownClass(cls):
        kill_process_tree(cls.process.pid)

    def test_mgsm_en(self):
        args = SimpleNamespace(
            base_url=self.base_url,
            model=self.model,
            eval_name="mgsm_en",
            num_examples=None,
            num_threads=1024,
        )

        metrics = run_eval(args)
        assert metrics["score"] >= 0.5


if __name__ == "__main__":
    unittest.main()