Unverified Commit 5cc6bddb authored by Xiangyu Li's avatar Xiangyu Li Committed by GitHub
Browse files

[Kernel] Add GPTQv2 format support for low-bit or asymmetric quantization, by...

[Kernel] Add GPTQv2 format support for low-bit or asymmetric quantization, by adapting gptq_gemm (#26092)
parent 1f9460c4
......@@ -307,7 +307,7 @@ void dynamic_scaled_int8_quant(torch::Tensor& out, torch::Tensor const& input,
torch::Tensor gptq_gemm(torch::Tensor a, torch::Tensor b_q_weight,
torch::Tensor b_gptq_qzeros,
torch::Tensor b_gptq_scales, torch::Tensor b_g_idx,
bool use_exllama, int64_t bit);
bool use_exllama, bool use_v2_format, int64_t bit);
void gptq_shuffle(torch::Tensor q_weight, torch::Tensor q_perm, int64_t bit);
......
This diff is collapsed.
......@@ -557,7 +557,8 @@ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
// to prevent the meta function registry.
ops.def(
"gptq_gemm(Tensor a, Tensor b_q_weight, Tensor b_gptq_qzeros, "
"Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, int bit) "
"Tensor b_gptq_scales, Tensor b_g_idx, bool use_exllama, bool "
"use_v2_format, int bit) "
"-> Tensor",
{stride_tag});
ops.impl("gptq_gemm", torch::kCUDA, &gptq_gemm);
......
......@@ -26,4 +26,10 @@ def test_gptq_gemm_opcheck():
idx = torch.empty((0,), device="cuda", dtype=torch.int32)
use_exllama = True
bit = 4
opcheck(torch.ops._C.gptq_gemm, (a, weight, zeros, scales, idx, use_exllama, bit))
# Test both GPTQv1 and GPTQv2 format
opcheck(
torch.ops._C.gptq_gemm, (a, weight, zeros, scales, idx, use_exllama, True, bit)
)
opcheck(
torch.ops._C.gptq_gemm, (a, weight, zeros, scales, idx, use_exllama, False, bit)
)
# SPDX-License-Identifier: Apache-2.0
# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
"""Tests whether vllm correctly load and run gptq_v2 format checkpoints.
Run `pytest tests/quantization/test_gptq_v2.py --forked`.
"""
import pytest
import torch
from transformers import AutoTokenizer
from vllm import SamplingParams
from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
# A dummy small model quantized by GPTQModel, stored in GPTQ v2 format
MODELS = ["XXXXyu/Qwen3-1.7B-w2g64-gptq_v2"]
# Generate multiple sequences for testing, because a 1.7B 2-bit model
# cannot always generate normal text.
N_SEQ = 5
@pytest.mark.parametrize("model_id", MODELS)
def test_model_load(vllm_runner, model_id, monkeypatch):
    """Check that a gptq_v2 checkpoint is loaded with the GPTQ v2 settings.

    The model is loaded through ``vllm_runner`` and one quantized projection
    is inspected inside the worker via ``llm.apply_model``. The test verifies
    that the projection uses ``GPTQLinearMethod``, that its quant config
    reports ``checkpoint_format == "gptq_v2"``, and that the linear method has
    ``use_v2_format`` enabled.

    Args:
        vllm_runner: Fixture that builds a vLLM runner context manager.
        model_id: Hugging Face id of the GPTQ v2 checkpoint under test.
        monkeypatch: pytest fixture used to set environment variables.
    """
    # `LLM.apply_model` requires pickling a function.
    monkeypatch.setenv("VLLM_ALLOW_INSECURE_SERIALIZATION", "1")

    # Only check the default GPTQ linear method (used for 2/3-bit models).
    # 4/8-bit linear methods like Marlin already support gptq_v2.
    linear_method_cls = GPTQLinearMethod

    # Fully-qualified name of the module to inspect, as reported by
    # `named_modules()` on the loaded model.
    # Could check more modules if necessary.
    target_name = "model.layers.0.self_attn.qkv_proj"

    with vllm_runner(model_id, dtype=torch.float16, max_model_len=512) as llm:

        def check_model(model):
            for name, submodule in model.named_modules():
                if name == target_name:
                    assert isinstance(submodule.quant_method, linear_method_cls)

                    config = submodule.quant_method.quant_config
                    assert config.checkpoint_format == "gptq_v2"
                    assert submodule.quant_method.use_v2_format

                    # Just break since currently we only check 1 module
                    break
            else:
                # Without this guard the test would pass vacuously whenever
                # the target module name does not match anything.
                raise AssertionError(
                    f"Module {target_name!r} was not found in the loaded model"
                )

        # Check if gptq_v2 format is correctly loaded
        llm.apply_model(check_model)
@pytest.mark.parametrize("model_id", MODELS)
def test_model_inference(vllm_runner, model_id):
    """Run a short chat generation and sanity-check the sampled text.

    ``N_SEQ`` sequences are sampled for a single chat prompt. The test passes
    when the output exists, contains ``N_SEQ`` sequences, and the character
    distribution check below accepts the sequences.

    Args:
        vllm_runner: Fixture that builds a vLLM runner context manager.
        model_id: Hugging Face id of the GPTQ v2 checkpoint under test.
    """

    def has_normal_char_distribution(texts, min_len):
        """Return True once a sequence shows a plausible letter/space mix.

        Sequences are examined in order. A sequence shorter than ``min_len``
        makes the check fail immediately; otherwise the first sequence whose
        letter and whitespace ratios fall inside the accepted ranges makes
        it succeed. If no sequence qualifies, the result is False.
        """
        for candidate in texts:
            length = len(candidate)
            # Response too short
            if length < min_len:
                return False

            # Basic ratio checks
            n_letters = sum(1 for ch in candidate if ch.isalpha())
            n_spaces = sum(1 for ch in candidate if ch.isspace())

            # Normal text should be mostly letters with reasonable spacing.
            # Some magic numbers, could be adjusted.
            letters_ok = 0.5 <= n_letters / length <= 0.9
            spaces_ok = 0.01 <= n_spaces / length <= 0.3
            if letters_ok and spaces_ok:
                return True

        # No sequence contains normal text, output might be broken
        return False

    # Prepare the chat prompt used to test the model's generation result.
    question = "What is the meaning of life?"
    system_turn = {"role": "system", "content": "You are a helpful assistant."}
    user_turn = {"role": "user", "content": question}

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    chat_text = tokenizer.apply_chat_template(
        [system_turn, user_turn],
        tokenize=False,
        add_generation_prompt=True,
        enable_thinking=False,  # If thinking model, set it to false
    )

    sampling_params = SamplingParams(
        n=N_SEQ,
        max_tokens=128,
        temperature=0.7,
        top_p=0.8,
        top_k=20,
        min_p=0,
        presence_penalty=2,
    )

    with vllm_runner(model_id, dtype=torch.float16, max_model_len=512) as llm:
        # Generate a response to verify inference correctness
        results = llm.generate(chat_text, sampling_params)

        # Make sure the output exists
        assert results
        assert results[0][1]
        assert len(results[0][1]) == N_SEQ

        # Apply some simple checks for gibberish output
        # Print the output sequences if failed
        assert has_normal_char_distribution(results[0][1], 5), results[0][1]
......@@ -451,10 +451,18 @@ def gptq_gemm(
b_gptq_scales: torch.Tensor,
b_g_idx: torch.Tensor,
use_exllama: bool,
use_v2_format: bool,
bit: int,
) -> torch.Tensor:
return torch.ops._C.gptq_gemm(
a, b_q_weight, b_gptq_qzeros, b_gptq_scales, b_g_idx, use_exllama, bit
a,
b_q_weight,
b_gptq_qzeros,
b_gptq_scales,
b_g_idx,
use_exllama,
use_v2_format,
bit,
)
......@@ -468,6 +476,7 @@ if hasattr(torch.ops._C, "gptq_gemm"):
b_gptq_scales: torch.Tensor,
b_g_idx: torch.Tensor,
use_exllama: bool,
use_v2_format: bool,
bit: int,
) -> torch.Tensor:
return torch.empty(
......
......@@ -11,6 +11,7 @@ from safetensors.torch import _TYPES as _SAFETENSORS_TO_TORCH_DTYPE
from torch.nn.parameter import Parameter
from vllm import _custom_ops as ops
from vllm.logger import init_logger
from vllm.model_executor.layers.fused_moe.layer import FusedMoE
from vllm.model_executor.layers.linear import LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
......@@ -36,6 +37,8 @@ if TYPE_CHECKING:
else:
QuantizationMethods = str
logger = init_logger(__name__)
class GPTQConfig(QuantizationConfig):
"""Config class for GPTQ.
......@@ -52,6 +55,7 @@ class GPTQConfig(QuantizationConfig):
dynamic: dict[str, dict[str, int | bool]],
autoround_version: str = "",
modules_in_block_to_quantize: list[str] | None = None,
checkpoint_format: str = "",
) -> None:
# GPTQModel uses the `dynamic` config property to allow per-module
# quantization config so each module can be individually optimized.
......@@ -89,12 +93,24 @@ class GPTQConfig(QuantizationConfig):
"Currently, only 2/3/4/8-bit weight quantization is "
f"supported for GPTQ, but got {self.weight_bits} bits."
)
# Somehow gptq_gemm 4-bit is buggy, maybe fix it in the future.
# For now, show a warning, since gptq_marlin will be used by default.
if self.weight_bits == 4:
logger.warning_once(
"Currently, the 4-bit gptq_gemm kernel for GPTQ is buggy. "
"Please switch to gptq_marlin or gptq_bitblas."
)
self.modules_in_block_to_quantize = modules_in_block_to_quantize or []
# used to identify GPTQ model quantized by autoround
self.autoround_version = autoround_version
# GPTQ v1 and v2 formats deal with zero points differently.
# Currently GPTQModel stores v1 format checkpoints by default,
# but provides the option to set `format="gptq_v2"` in `QuantizeConfig`.
self.checkpoint_format = checkpoint_format
def __repr__(self) -> str:
return (
f"GPTQConfig(weight_bits={self.weight_bits}, "
......@@ -102,7 +118,8 @@ class GPTQConfig(QuantizationConfig):
f"desc_act={self.desc_act}), "
f"lm_head_quantized={self.lm_head_quantized}, "
f"dynamic={self.dynamic}, "
f"modules_in_block_to_quantize={self.modules_in_block_to_quantize})"
f"modules_in_block_to_quantize={self.modules_in_block_to_quantize}), "
f"checkpoint_format={self.checkpoint_format})"
)
@classmethod
......@@ -137,6 +154,9 @@ class GPTQConfig(QuantizationConfig):
modules_in_block_to_quantize = cls.get_from_keys_or(
config, ["modules_in_block_to_quantize"], default=None
)
checkpoint_format = cls.get_from_keys_or(
config, ["checkpoint_format"], default=""
)
return cls(
weight_bits,
group_size,
......@@ -145,6 +165,7 @@ class GPTQConfig(QuantizationConfig):
dynamic,
autoround_version,
modules_in_block_to_quantize,
checkpoint_format,
)
def get_quant_method(
......@@ -154,6 +175,7 @@ class GPTQConfig(QuantizationConfig):
# GPTQ MoE support: fall back to MoeWNA16 for broad compatibility
from .moe_wna16 import MoeWNA16Config
# TODO: maybe update this for GPTQv2 format checkpoints
config = {
"quant_method": "gptq",
"bits": self.weight_bits,
......@@ -210,6 +232,9 @@ class GPTQLinearMethod(LinearMethodBase):
def __init__(self, quant_config: GPTQConfig):
    """Store the quantization config and record the checkpoint format.

    Args:
        quant_config: GPTQ configuration; its ``checkpoint_format`` field
            decides whether the v2 format flag is set.
    """
    self.quant_config = quant_config
    # GPTQ v1 and v2 formats deal with zero points differently, so remember
    # which one this checkpoint declares.
    checkpoint_format = quant_config.checkpoint_format
    self.use_v2_format = checkpoint_format == "gptq_v2"
def create_weights(
self,
layer: torch.nn.Module,
......@@ -351,6 +376,8 @@ class GPTQLinearMethod(LinearMethodBase):
out_shape = x.shape[:-1] + (layer.qweight.shape[-1],)
reshaped_x = x.reshape(-1, x.shape[-1])
# GPTQ v1 and v2 format checkpoints deal with zero points differently,
# and require different gemm kernels.
output = ops.gptq_gemm(
reshaped_x,
layer.qweight,
......@@ -358,6 +385,7 @@ class GPTQLinearMethod(LinearMethodBase):
layer.scales,
layer.g_idx,
layer.exllama_state == ExllamaState.READY,
self.use_v2_format,
self.quant_config.weight_bits,
)
if bias is not None:
......
......@@ -145,10 +145,15 @@ class ExllamaLinearKernel(MPLinearKernel):
w_q, w_s, w_zp, w_g_idx = self._get_weight_params(layer)
# gptq_gemm supports GPTQv2 format by passing use_v2_format=True.
# However, the MPLinearLayerConfig doesn't contain format info.
# So hardcode GPTQv1 format here, to keep its behavior unchanged.
use_v2_format = False
assert w_zp is not None, "Zero points are required by Exllama"
assert w_g_idx is not None, "Group index is required by Exllama"
output = ops.gptq_gemm(
x_2d, w_q, w_zp, w_s, w_g_idx, True, c.weight_type.size_bits
x_2d, w_q, w_zp, w_s, w_g_idx, True, use_v2_format, c.weight_type.size_bits
)
if bias is not None:
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment