Unverified commit c15fbf47 authored by pppppM, committed by GitHub

[Fix] Qwen's quantization results are abnormal & Baichuan cannot be quantized (#605)

* fix awq

* adapt new qwen code

* adapt qwen 14b and baichuan2 7b

* add docstring

* add runtime error for qwen
parent 15d1cc2e
......@@ -15,13 +15,15 @@ from lmdeploy.lite.utils import collect_target_modules
LAYER_TYPE_MAP = {
'InternLMForCausalLM': 'InternLMDecoderLayer',
'QWenLMHeadModel': 'QWenBlock',
'BaiChuanForCausalLM': 'DecoderLayer',
'BaiChuanForCausalLM': 'DecoderLayer', # Baichuan 7B
'BaichuanForCausalLM': 'DecoderLayer', # Baichuan2 7B
'LlamaForCausalLM': 'LlamaDecoderLayer',
}
NORM_TYPE_MAP = {
'InternLMForCausalLM': 'InternLMRMSNorm',
'QWenLMHeadModel': 'RMSNorm',
'BaiChuanForCausalLM': 'RMSNorm',
'BaiChuanForCausalLM': 'RMSNorm', # Baichuan 7B
'BaichuanForCausalLM': 'RMSNorm', # Baichuan2 7B
'LlamaForCausalLM': 'LlamaRMSNorm',
}
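
For context, a minimal sketch (not part of the diff) of how these maps are keyed by the model's architecture name; the checkpoint path is only an example, and we assume the config lists its architecture under `architectures`:

```python
from transformers import AutoConfig

# Resolve the per-architecture class names from a checkpoint's config.
hf_config = AutoConfig.from_pretrained('baichuan-inc/Baichuan2-7B-Base',
                                       trust_remote_code=True)
model_type = hf_config.architectures[0]   # 'BaichuanForCausalLM'
layer_type = LAYER_TYPE_MAP[model_type]   # 'DecoderLayer'
norm_type = NORM_TYPE_MAP[model_type]     # 'RMSNorm'
```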
......@@ -40,6 +42,9 @@ def auto_awq(model: str,
hf_config = AutoConfig.from_pretrained(model, trust_remote_code=True)
checkpoint = hf_config._name_or_path
# hard code for qwen, other configs do not have the `fp16` attribute.
hf_config.fp16 = True
with init_empty_weights():
# Load model
model = AutoModelForCausalLM.from_pretrained(model,
......@@ -61,11 +66,14 @@ def auto_awq(model: str,
device_map[name] = 'cpu'
else:
device_map[name] = 0
load_checkpoint_in_model(model, checkpoint, device_map)
load_checkpoint_in_model(model,
checkpoint,
device_map,
dtype=torch.float16)
work_dir = Path(work_dir)
act_scales = torch.load(work_dir / 'inputs_stats.pth')['absmean']
act_scales = torch.load(work_dir / 'inputs_stats.pth')['absmax']
layers = collect_target_modules(model, layer_type)
fcs = {}
for l_name, layer in layers.items():
......
# Copyright (c) OpenMMLab. All rights reserved.
from pathlib import Path
from typing import Union
import torch
from accelerate import (infer_auto_device_map, init_empty_weights,
load_checkpoint_in_model)
from torch import nn
from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
from lmdeploy.lite.quantization import CalibrationContext
......@@ -13,17 +15,90 @@ from lmdeploy.lite.utils import collect_target_modules, get_calib_loaders
LAYER_TYPE_MAP = {
'InternLMForCausalLM': 'InternLMDecoderLayer',
'QWenLMHeadModel': 'QWenBlock',
'BaiChuanForCausalLM': 'DecoderLayer',
'BaiChuanForCausalLM': 'DecoderLayer', # Baichuan 7B
'BaichuanForCausalLM': 'DecoderLayer', # Baichuan2 7B
'LlamaForCausalLM': 'LlamaDecoderLayer',
}
NORM_TYPE_MAP = {
'InternLMForCausalLM': 'InternLMRMSNorm',
'QWenLMHeadModel': 'RMSNorm',
'BaiChuanForCausalLM': 'RMSNorm',
'BaiChuanForCausalLM': 'RMSNorm', # Baichuan 7B
'BaichuanForCausalLM': 'RMSNorm', # Baichuan2 7B
'LlamaForCausalLM': 'LlamaRMSNorm',
}
def _prepare_for_calibrate(model: nn.Module,
layer_type: Union[str, type],
head_name: str = 'lm_head',
device: str = 'cuda',
prefix: str = '') -> None:
"""Prepare the model for calibration by moving specific modules to CPU.
This function walks the children of the given model and checks whether
each one is an instance of the target layer type or is named `head_name`.
If so, the module is moved to the CPU; otherwise it is moved to the
specified device (CUDA by default).
If a child merely contains the target layer type in its sub-modules, the
function recurses into it.
Parameters
----------
model : nn.Module
The PyTorch model to prepare for calibration.
layer_type : Union[str, Type]
The type of the layer to be moved to CPU. Can be either a string
(the class name) or the class type itself.
head_name : str, optional
The name of the module to be moved to CPU. Default is 'lm_head'.
device : str, optional
The device to which modules not matching the `layer_type` or
`head_name` will be moved. Default is 'cuda'.
prefix : str, optional
The prefix used when printing the names of the moved modules.
Default is ''.
Raises
------
TypeError
If `layer_type` is neither a string nor a type.
"""
for name, child in model.named_children():
# Check if the child is an instance of the given layer type
if isinstance(layer_type, str):
is_layer = type(child).__name__ == layer_type
elif isinstance(layer_type, type):
is_layer = isinstance(child, layer_type)
else:
raise TypeError(
'layer_type should be a string (class name) or a type')
# Check if the child contains the target module type
contain_layer = len(
collect_target_modules(child, layer_type, [head_name]).keys()) > 0
# Check if the child matches the head name
is_head = name == head_name
mod_name = f'{prefix}.{name}' if prefix else name
# If the child is either an instance of the layer type or has the
# head name, move it to CPU, otherwise move it to the specified device
if is_layer or is_head:
child.to('cpu')
print(f'Move {mod_name} to CPU.')
elif contain_layer:
_prepare_for_calibrate(child, layer_type, head_name, device,
mod_name)
else:
child.to(device)
print(f'Move {mod_name} to GPU.')
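
A hedged usage sketch of the new helper: keep each decoder block and the `lm_head` on the CPU and move everything else to the GPU before calibration. `model` is assumed to be an already-loaded QWenLMHeadModel, and the argument values are only an example.

```python
_prepare_for_calibrate(model, layer_type='QWenBlock',
                       head_name='lm_head', device='cuda')
```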
def calibrate(model: str,
calib_dataset: str = 'c4',
calib_samples: int = 128,
......@@ -54,16 +129,38 @@ def calibrate(model: str,
tokenizer = AutoTokenizer.from_pretrained(model,
use_fast=False,
trust_remote_code=True)
hf_config = AutoConfig.from_pretrained(model, trust_remote_code=True)
hf_config = AutoConfig.from_pretrained(model,
torch_dtype=torch.float16,
trust_remote_code=True)
checkpoint = hf_config._name_or_path
# hard code for qwen, other configs do not have the `fp16` attribute.
hf_config.fp16 = True
with init_empty_weights():
# Load model
model = AutoModelForCausalLM.from_pretrained(model,
config=hf_config,
torch_dtype=torch.float16,
trust_remote_code=True)
model.config.use_cache = False
model_type = type(model).__name__
if model_type not in LAYER_TYPE_MAP or model_type not in NORM_TYPE_MAP:
raise RuntimeError(
f'Currently, quantization and calibration of {model_type} are '
f'not supported. The supported model types are '
f"{', '.join(LAYER_TYPE_MAP.keys())}.")
if model_type == 'QWenLMHeadModel':
try:
import flash_attn # noqa: F401
except ImportError:
raise RuntimeError(
'When using Qwen, you need to `pip install flash-attn` first, '
'otherwise calibration and quantization will not work '
'properly.')
layer_type = LAYER_TYPE_MAP[type(model).__name__]
norm_type = NORM_TYPE_MAP[type(model).__name__]
......@@ -77,7 +174,12 @@ def calibrate(model: str,
device_map[name] = 'cpu'
else:
device_map[name] = 0
load_checkpoint_in_model(model, checkpoint, device_map)
load_checkpoint_in_model(model,
checkpoint,
device_map,
dtype=torch.float16)
_prepare_for_calibrate(model, layer_type, 'lm_head', device)
print('Loading calibrate dataset ...')
calib_loader, _ = get_calib_loaders(calib_dataset,
......
......@@ -18,6 +18,10 @@ NORM_FCS_MAP = {
'QWenBlock': {
'ln_1': ['attn.c_attn'],
'ln_2': ['mlp.w1', 'mlp.w2']
},
'DecoderLayer': {
'input_layernorm': ['self_attn.W_pack'],
'post_attention_layernorm': ['mlp.gate_proj', 'mlp.up_proj']
}
}
......@@ -33,6 +37,10 @@ FC_FCS_MAP = {
'QWenBlock': {
'attn.c_attn': ['attn.c_proj'],
'mlp.w1': ['mlp.c_proj']
},
'DecoderLayer': {
'self_attn.W_pack': ['self_attn.o_proj'],
'mlp.up_proj': ['mlp.down_proj']
}
}
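
An illustrative sketch, not the library's exact loop, of how the new 'DecoderLayer' entries are consumed: resolve the named modules inside one Baichuan block. `layer` is assumed to be a single DecoderLayer instance from a loaded Baichuan model.

```python
for ln_name, fc_names in NORM_FCS_MAP['DecoderLayer'].items():
    ln = layer.get_submodule(ln_name)                 # e.g. input_layernorm
    fcs = [layer.get_submodule(n) for n in fc_names]  # e.g. [self_attn.W_pack]
    # These (ln, fcs) pairs are what smooth_ln_fcs operates on; FC_FCS_MAP
    # plays the same role for smooth_fc_fcs.
```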
......@@ -69,7 +77,7 @@ def smooth_ln_fcs(ln: torch.nn.Module,
w_scales = get_weight_scale(concat_w, group_size)
scales = (act_scales.pow(alpha) /
w_scales.pow(1 - alpha)).clamp(min=1e-4).to(device).to(dtype)
w_scales.pow(1 - alpha)).to(device).to(dtype)
scales = scales / (scales.max() * scales.min()).sqrt()
ln.weight.div_(scales)
......@@ -116,10 +124,10 @@ def smooth_fc_fcs(pre_fc: torch.nn.Module,
w_scales = get_weight_scale(concat_w, group_size)
scales = (act_scales.pow(alpha) /
w_scales.pow(1 - alpha)).clamp(min=1e-4).to(device).to(dtype)
w_scales.pow(1 - alpha)).to(device).to(dtype)
scales = scales / (scales.max() * scales.min()).sqrt()
# (for qwen) pre_fc is packed QKV, only V needs to scale
# (for qwen&baichuan) pre_fc is packed QKV, only V needs to scale
if size_pre_fc > size_a and size_pre_fc % size_a == 0 \
and size_pre_fc // size_a == 3:
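
The balanced-scale formula itself is unchanged by the diff (only the clamp is dropped); a tiny numeric sketch with toy values, not from the repo, of what both hunks compute:

```python
import torch

# s_j = a_j**alpha / w_j**(1 - alpha), then normalised so that
# sqrt(max(s) * min(s)) == 1, exactly as in the hunks above.
act_scales = torch.tensor([2.0, 8.0, 0.5])  # per-channel activation statistics
w_scales = torch.tensor([1.0, 0.25, 4.0])   # per-channel weight scales
alpha = 0.5
scales = act_scales.pow(alpha) / w_scales.pow(1 - alpha)
scales = scales / (scales.max() * scales.min()).sqrt()
print(scales)  # the channel with large activations gets a scale > 1
```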
......
......@@ -8,7 +8,7 @@ from lmdeploy.lite.utils import (QParams, cal_qparams_per_channel_absmax,
cal_qparams_per_group_absmax,
cal_qparams_per_group_minmax,
cal_qparams_per_tensor_absmax,
cal_qparams_per_tensor_minmax)
cal_qparams_per_tensor_minmax, precise_round)
from lmdeploy.lite.utils.global_avail import GlobalAvailMixin
......@@ -119,8 +119,10 @@ class WeightQuantizer(GlobalAvailMixin):
torch.Tensor: The fake quantized weight tensor.
"""
float_w = weight.float()
if qparams is None:
qparams = self.calculate_qparams(weight)
qparams = self.calculate_qparams(float_w)
scales = qparams.scales
zero_points = qparams.zero_points
......@@ -133,17 +135,18 @@ class WeightQuantizer(GlobalAvailMixin):
# per group scales shape: [out_c, in_c//group_size, 1]
if len(scales.shape) > 2:
# scales shape: [out_c, in_c//group_size, 1]
weight = weight.reshape(out_c, scales.shape[1], -1)
float_w = float_w.reshape(out_c, scales.shape[1], -1)
if zero_points is None:
assert self.symmetry
real_qweight = (weight / scales).round()
real_qweight = (float_w / scales).round()
fake_qweight = real_qweight * scales
else:
assert not self.symmetry
real_qweight = (weight / scales).round() + zero_points
real_qweight = precise_round(
(float_w - float_w.min(-1, keepdim=True)[0]) / scales)
fake_qweight = (real_qweight - zero_points) * scales
if len(scales.shape) > 2:
......@@ -153,4 +156,4 @@ class WeightQuantizer(GlobalAvailMixin):
if real:
return real_qweight.to(torch.int32)
else:
return fake_qweight
return fake_qweight.to(weight.dtype)
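
For intuition, a toy round trip (not from the repo) mirroring the asymmetric branch above, including the float32 upcast; it assumes `precise_round` is importable from `lmdeploy.lite.utils` as exported later in this commit.

```python
import torch

from lmdeploy.lite.utils import precise_round

# Toy 4-bit, per-channel round trip.
w = torch.randn(8, 32, dtype=torch.float16)
float_w = w.float()
w_min = float_w.min(-1, keepdim=True)[0]
w_max = float_w.max(-1, keepdim=True)[0]
scales = (w_max - w_min) / (2**4 - 1)
zero_points = precise_round(-w_min / scales)
real_q = precise_round((float_w - w_min) / scales)      # integers in [0, 15]
fake_q = ((real_q - zero_points) * scales).to(w.dtype)  # dequantised approximation
print((w - fake_q).abs().max())                         # small quantisation error
```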
......@@ -6,7 +6,7 @@ from .cal_qparams import (QParams, cal_qparams_per_channel_absmax,
cal_qparams_per_group_absmax,
cal_qparams_per_group_minmax,
cal_qparams_per_tensor_absmax,
cal_qparams_per_tensor_minmax)
cal_qparams_per_tensor_minmax, precise_round)
from .calib_dataloader import get_calib_loaders
from .collect import (bimap_name_mod, collect_target_modules,
collect_target_weights)
......@@ -16,7 +16,7 @@ __all__ = [
'cal_qparams_per_channel_absmax', 'cal_qparams_per_channel_minmax',
'cal_qparams_per_group_absmax', 'cal_qparams_per_group_minmax',
'cal_qparams_per_tensor_absmax', 'cal_qparams_per_tensor_minmax',
'QParams', 'get_calib_loaders', 'collect_target_modules',
'QParams', 'get_calib_loaders', 'collect_target_modules', 'precise_round',
'collect_target_weights', 'GlobalAvailMixin', 'split_decoder_layer_inputs',
'bimap_name_mod', 'concat_decoder_layer_outputs'
]
......@@ -11,16 +11,22 @@ class QParams(NamedTuple):
zero_points: Optional[torch.Tensor]
@torch.no_grad()
def precise_round(x):
return x.sign() * (x.abs() + 0.5).floor()
@torch.no_grad()
def cal_qparams_per_channel_absmax(w: torch.Tensor,
n_bits: int,
return_stats: bool = False) -> QParams:
"""Calculate quantization parameters for each channel using absolute max
value."""
float_w = w.float()
absmax = w.abs().max(dim=-1, keepdim=True)[0]
absmax = float_w.abs().max(dim=-1, keepdim=True)[0]
q_max = 2**(n_bits - 1) - 1
scales = absmax.clamp(min=1e-5).div(q_max)
scales = absmax.div(q_max)
if return_stats:
return QParams(scales=scales, zero_points=None), absmax
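
The switch from `.round()` to the `precise_round` helper introduced above matters at exact .5 boundaries, where `torch.round` rounds half to even; a quick illustration with values chosen for the example:

```python
import torch

from lmdeploy.lite.utils import precise_round

x = torch.tensor([0.5, 1.5, 2.5, -1.5])
print(torch.round(x))    # tensor([0., 2., 2., -2.])  -- half rounds to even
print(precise_round(x))  # tensor([1., 2., 3., -2.])  -- half rounds away from zero
```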
......@@ -35,14 +41,16 @@ def cal_qparams_per_channel_minmax(w: torch.Tensor,
"""Calculate quantization parameters for each channel using min and max
values."""
w_min = w.min(dim=-1, keepdim=True)[0]
w_max = w.max(dim=-1, keepdim=True)[0]
float_w = w.float()
w_min = float_w.min(dim=-1, keepdim=True)[0]
w_max = float_w.max(dim=-1, keepdim=True)[0]
q_max = 2**n_bits - 1
scales = (w_max - w_min)
scales = scales.clamp_(min=1e-5).div_(q_max)
scales = scales.div_(q_max)
zero_points = (-w_min / scales).round()
zero_points = precise_round(-w_min / scales)
if return_stats:
return QParams(scales=scales, zero_points=zero_points), (w_min, w_max)
......@@ -63,9 +71,12 @@ def cal_qparams_per_group_absmax(w: torch.Tensor,
'Input channels should be greater than or equal to group_size.'
assert inc % group_size == 0, \
'Input channels should be divisible by group_size.'
absmax = w.abs().reshape(outc, -1, group_size).max(dim=-1, keepdim=True)[0]
float_w = w.float()
absmax = float_w.abs().reshape(outc, -1, group_size).max(dim=-1,
keepdim=True)[0]
q_max = 2**(n_bits - 1) - 1
scales = absmax.clamp(min=1e-5).div(q_max)
scales = absmax.div(q_max)
if return_stats:
return QParams(scales=scales, zero_points=None), absmax
else:
......@@ -85,14 +96,16 @@ def cal_qparams_per_group_minmax(w: torch.Tensor,
'Input channels should be greater than or equal to group_size.'
assert inc % group_size == 0, \
'Input channels should be divisible by group_size.'
w_group_wise = w.reshape(outc, -1, group_size)
float_w = w.float()
w_group_wise = float_w.reshape(outc, -1, group_size)
w_min = w_group_wise.min(dim=-1, keepdim=True)[0]
w_max = w_group_wise.max(dim=-1, keepdim=True)[0]
q_max = 2**n_bits - 1
scales = (w_max - w_min)
scales = scales.clamp_(min=1e-5).div_(q_max)
zero_points = (-w_min / scales).round()
scales = scales.div_(q_max)
zero_points = precise_round(-w_min / scales)
if return_stats:
return QParams(scales=scales, zero_points=zero_points), (w_min, w_max)
else:
......@@ -106,13 +119,15 @@ def cal_qparams_per_tensor_minmax(w: torch.Tensor,
"""Calculate quantization parameters for the entire tensor using min and
max values."""
w_min = w.min()
w_max = w.max()
float_w = w.float()
w_min = float_w.min()
w_max = float_w.max()
q_max = 2**n_bits - 1
scales = (w_max - w_min)
scales = scales.clamp_(min=1e-5).div_(q_max)
zero_points = (-w_min / scales).round()
zero_points = precise_round(-w_min / scales)
if return_stats:
return QParams(scales=scales, zero_points=zero_points), (w_min, w_max)
else:
......@@ -125,9 +140,10 @@ def cal_qparams_per_tensor_absmax(w: torch.Tensor,
return_stats: bool = False) -> QParams:
"""Calculate quantization parameters for the entire tensor using absolute
max value."""
absmax = w.abs().max()
float_w = w.float()
absmax = float_w.abs().max()
q_max = 2**(n_bits - 1) - 1
scales = absmax.clamp(min=1e-5).div(q_max)
scales = absmax.div(q_max)
if return_stats:
return QParams(scales=scales, zero_points=None), absmax
......
# Copyright (c) OpenMMLab. All rights reserved.
from typing import Dict, List, Tuple, Union
from mmengine.config.lazy import LazyAttr
from torch import nn
......@@ -22,9 +21,6 @@ def collect_target_modules(model: nn.Module,
A dictionary mapping from module names to module instances.
"""
if isinstance(target, LazyAttr):
target = target.build()
if not isinstance(target, (type, str)):
raise TypeError('Target must be a string (name of the module) '
'or a type (class of the module)')
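
A hedged example of calling `collect_target_modules` the same way the calibration script does, passing the skip list positionally; `model` is assumed to be a loaded QWenLMHeadModel.

```python
# Gather every QWenBlock by class name while skipping lm_head.
blocks = collect_target_modules(model, 'QWenBlock', ['lm_head'])
for name, block in blocks.items():
    print(name, type(block).__name__)   # e.g. transformer.h.0 QWenBlock
```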
......
......@@ -4,6 +4,11 @@ from typing import Optional, Type, TypeVar
import torch
from torch import nn
try:
import awq_inference_engine
except ModuleNotFoundError:
awq_inference_engine = None
class WeightOnlyQLinear(nn.Module):
"""This class implements weight only quantization linear.
......@@ -18,13 +23,15 @@ class WeightOnlyQLinear(nn.Module):
bias (Tensor, optional): Defaults to None.
"""
def __init__(self,
w_bit: int,
symmetry: bool,
group_size: int,
def __init__(
self,
in_features: int,
out_features: int,
bias: Optional[torch.Tensor] = None) -> None:
bias: Optional[torch.Tensor] = True,
w_bit: int = 4,
symmetry: bool = False,
group_size: int = 128,
) -> None:
super().__init__()
if w_bit not in [2, 4, 8]:
......@@ -92,8 +99,8 @@ class WeightOnlyQLinear(nn.Module):
out_features = linear.out_features
bias = False if linear.bias is None else True
qlinear = cls(w_bit, symmetry, group_size, in_features, out_features,
bias)
qlinear = cls(in_features, out_features, bias, w_bit, symmetry,
group_size)
qlinear.bias = linear.bias
qparams = quantizer.calculate_qparams(linear.weight)
......@@ -124,3 +131,24 @@ class WeightOnlyQLinear(nn.Module):
qlinear.to('cpu')
return qlinear
@torch.no_grad()
def forward(self, x):
if awq_inference_engine is None:
raise RuntimeError(
'Run the following command to install '
'the kernel for 4bit inference\n\n'
'git clone https://github.com/mit-han-lab/llm-awq.git\n'
'cd llm-awq/awq/kernels\n'
'python setup.py install\n')
out_shape = x.shape[:-1] + (self.out_features, )
inputs = x.reshape(-1, x.shape[-1])
out = awq_inference_engine.gemm_forward_cuda(inputs.half(),
self.qweight,
self.scales.half(),
self.qzeros,
self.group_size)
out = out + self.bias if self.bias is not None else out
return out.reshape(out_shape)
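
A rough usage sketch of the reordered constructor and the new `forward`; the dimensions are arbitrary, and real usage populates `qweight`/`scales`/`qzeros` via `from_linear()` before the output is meaningful.

```python
import torch

qlinear = WeightOnlyQLinear(in_features=4096, out_features=11008, bias=False,
                            w_bit=4, symmetry=False, group_size=128).to('cuda')
x = torch.randn(2, 4096, dtype=torch.float16, device='cuda')
y = qlinear(x)  # shape (2, 11008); raises RuntimeError if the AWQ kernel is absent
```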