Unverified commit c15fbf47, authored by pppppM and committed by GitHub

[Fix] Qwen's quantization results are abnormal & Baichuan cannot be quantized (#605)

* fix awq

* adapt new qwen code

* adapt qwen 14b and baichuan2 7b

* add docstring

* add runtime error for qwen
parent 15d1cc2e
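
For context, a minimal sketch of how the two entry points touched by this fix are usually driven end to end: calibration first collects activation statistics, then AWQ consumes them. The function names and the `calib_dataset`/`calib_samples` keywords come from the hunks below; the `lmdeploy.lite.apis.*` import paths, the `work_dir` keyword and the model id are illustrative assumptions, not part of this commit.

from lmdeploy.lite.apis.auto_awq import auto_awq
from lmdeploy.lite.apis.calibrate import calibrate

# 1. Collect activation statistics; calibrate() writes inputs_stats.pth
#    (the file read by auto_awq below) into the work directory.
calibrate('Qwen/Qwen-7B-Chat',          # assumed model id
          calib_dataset='c4',
          calib_samples=128,
          work_dir='./qwen_awq')        # assumed keyword

# 2. Search AWQ smoothing scales with the 'absmax' stats and quantize.
auto_awq('Qwen/Qwen-7B-Chat', work_dir='./qwen_awq')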
@@ -15,13 +15,15 @@ from lmdeploy.lite.utils import collect_target_modules
 LAYER_TYPE_MAP = {
     'InternLMForCausalLM': 'InternLMDecoderLayer',
     'QWenLMHeadModel': 'QWenBlock',
-    'BaiChuanForCausalLM': 'DecoderLayer',
+    'BaiChuanForCausalLM': 'DecoderLayer',  # Baichuan 7B
+    'BaichuanForCausalLM': 'DecoderLayer',  # Baichuan2 7B
     'LlamaForCausalLM': 'LlamaDecoderLayer',
 }
 
 NORM_TYPE_MAP = {
     'InternLMForCausalLM': 'InternLMRMSNorm',
     'QWenLMHeadModel': 'RMSNorm',
-    'BaiChuanForCausalLM': 'RMSNorm',
+    'BaiChuanForCausalLM': 'RMSNorm',  # Baichuan 7B
+    'BaichuanForCausalLM': 'RMSNorm',  # Baichuan2 7B
     'LlamaForCausalLM': 'LlamaRMSNorm',
 }
@@ -40,6 +42,9 @@ def auto_awq(model: str,
     hf_config = AutoConfig.from_pretrained(model, trust_remote_code=True)
     checkpoint = hf_config._name_or_path
 
+    # hard code for qwen, other configs do not have the `fp16` attribute.
+    hf_config.fp16 = True
+
     with init_empty_weights():
         # Load model
         model = AutoModelForCausalLM.from_pretrained(model,
@@ -61,11 +66,14 @@ def auto_awq(model: str,
             device_map[name] = 'cpu'
         else:
             device_map[name] = 0
-    load_checkpoint_in_model(model, checkpoint, device_map)
+    load_checkpoint_in_model(model,
+                             checkpoint,
+                             device_map,
+                             dtype=torch.float16)
 
     work_dir = Path(work_dir)
 
-    act_scales = torch.load(work_dir / 'inputs_stats.pth')['absmean']
+    act_scales = torch.load(work_dir / 'inputs_stats.pth')['absmax']
     layers = collect_target_modules(model, layer_type)
     fcs = {}
     for l_name, layer in layers.items():
......
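The `hf_config.fp16 = True` line added above is an acknowledged hard-code: as the in-diff comment notes, only Qwen's remote-code config has an `fp16` attribute that matters, and other configs simply gain a harmless extra field. A narrower variant, shown here only as a sketch of an alternative and not as what the patch does, would set the flag just for Qwen:

# Sketch of a more defensive alternative to the unconditional hard-code.
# 'QWenLMHeadModel' is the architecture name already used in LAYER_TYPE_MAP;
# `architectures` is the standard field on Hugging Face config objects.
if 'QWenLMHeadModel' in (getattr(hf_config, 'architectures', None) or []):
    hf_config.fp16 = True  # flag consulted by Qwen's remote modeling code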
 # Copyright (c) OpenMMLab. All rights reserved.
 from pathlib import Path
+from typing import Union
 
 import torch
 from accelerate import (infer_auto_device_map, init_empty_weights,
                         load_checkpoint_in_model)
+from torch import nn
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 
 from lmdeploy.lite.quantization import CalibrationContext
@@ -13,17 +15,90 @@ from lmdeploy.lite.utils import collect_target_modules, get_calib_loaders
 LAYER_TYPE_MAP = {
     'InternLMForCausalLM': 'InternLMDecoderLayer',
     'QWenLMHeadModel': 'QWenBlock',
-    'BaiChuanForCausalLM': 'DecoderLayer',
+    'BaiChuanForCausalLM': 'DecoderLayer',  # Baichuan 7B
+    'BaichuanForCausalLM': 'DecoderLayer',  # Baichuan2 7B
     'LlamaForCausalLM': 'LlamaDecoderLayer',
 }
 
 NORM_TYPE_MAP = {
     'InternLMForCausalLM': 'InternLMRMSNorm',
     'QWenLMHeadModel': 'RMSNorm',
-    'BaiChuanForCausalLM': 'RMSNorm',
+    'BaiChuanForCausalLM': 'RMSNorm',  # Baichuan 7B
+    'BaichuanForCausalLM': 'RMSNorm',  # Baichuan2 7B
     'LlamaForCausalLM': 'LlamaRMSNorm',
 }
 
+
+def _prepare_for_calibrate(model: nn.Module,
+                           layer_type: Union[str, type],
+                           head_name: str = 'lm_head',
+                           device: str = 'cuda',
+                           prefix: str = '') -> None:
+    """Prepare the model for calibration by moving specific modules to CPU.
+
+    This function goes through each child of a given model and checks whether
+    it is an instance of a certain layer type or has the name equal to
+    `head_name`. If yes, it moves the module to CPU, otherwise to the
+    specified device (default is CUDA).
+
+    If the child contains the target layer type in its sub-modules, the
+    function performs the same operation recursively.
+
+    Parameters
+    ----------
+    model : nn.Module
+        The PyTorch model to prepare for calibration.
+    layer_type : Union[str, Type]
+        The type of the layer to be moved to CPU. Can be either a string of
+        class name or the class type itself.
+    head_name : str, optional
+        The name of the module to be moved to CPU. Default is 'lm_head'.
+    device : str, optional
+        The device to which modules not matching the `layer_type` or
+        `head_name` will be moved. Default is 'cuda'.
+    prefix : str, optional
+        The prefix used when printing the names of the moved modules.
+        Default is ''.
+
+    Raises
+    ------
+    TypeError
+        If `layer_type` is neither a string nor a type.
+    """
+
+    for name, child in model.named_children():
+
+        # Check if the child is an instance of the given layer type
+        if isinstance(layer_type, str):
+            is_layer = type(child).__name__ == layer_type
+        elif isinstance(layer_type, type):
+            is_layer = isinstance(child, layer_type)
+        else:
+            raise TypeError(
+                'layer_type should be a string (class name) or a type')
+
+        # Check if the child contains the target module type
+        contain_layer = len(
+            collect_target_modules(child, layer_type, [head_name]).keys()) > 0
+
+        # Check if the child matches the head name
+        is_head = name == head_name
+
+        mod_name = f'{prefix}.{name}' if prefix else name
+
+        # If the child is either an instance of the layer type or has the
+        # head name, move it to CPU, otherwise move it to the specified device
+        if is_layer or is_head:
+            child.to('cpu')
+            print(f'Move {mod_name} to CPU.')
+        elif contain_layer:
+            _prepare_for_calibrate(child, layer_type, head_name, device,
+                                   mod_name)
+        else:
+            child.to(device)
+            print(f'Move {mod_name} to GPU.')
+
+
 def calibrate(model: str,
               calib_dataset: str = 'c4',
               calib_samples: int = 128,
@@ -54,16 +129,38 @@ def calibrate(model: str,
     tokenizer = AutoTokenizer.from_pretrained(model,
                                               use_fast=False,
                                               trust_remote_code=True)
-    hf_config = AutoConfig.from_pretrained(model, trust_remote_code=True)
+    hf_config = AutoConfig.from_pretrained(model,
+                                           torch_dtype=torch.float16,
+                                           trust_remote_code=True)
     checkpoint = hf_config._name_or_path
 
+    # hard code for qwen, other configs do not have the `fp16` attribute.
+    hf_config.fp16 = True
+
     with init_empty_weights():
         # Load model
         model = AutoModelForCausalLM.from_pretrained(model,
+                                                     config=hf_config,
                                                      torch_dtype=torch.float16,
                                                      trust_remote_code=True)
         model.config.use_cache = False
 
+    model_type = type(model).__name__
+    if model_type not in LAYER_TYPE_MAP or model_type not in NORM_TYPE_MAP:
+        raise RuntimeError(
+            f'Currently, quantification and calibration of {model_type} are '
+            f'not supported. The supported model types are '
+            f"{', '.join(LAYER_TYPE_MAP.keys())}.")
+
+    if model_type == 'QWenLMHeadModel':
+        try:
+            import flash_attn  # noqa: F401
+        except ImportError:
+            raise RuntimeError(
+                'When using Qwen, you need to `pip install flash-attn` first, '
+                'otherwise calibration and quantification will not work '
+                'properly.')
+
     layer_type = LAYER_TYPE_MAP[type(model).__name__]
     norm_type = NORM_TYPE_MAP[type(model).__name__]
@@ -77,7 +174,12 @@ def calibrate(model: str,
             device_map[name] = 'cpu'
         else:
             device_map[name] = 0
-    load_checkpoint_in_model(model, checkpoint, device_map)
+    load_checkpoint_in_model(model,
+                             checkpoint,
+                             device_map,
+                             dtype=torch.float16)
+
+    _prepare_for_calibrate(model, layer_type, 'lm_head', device)
 
     print('Loading calibrate dataset ...')
     calib_loader, _ = get_calib_loaders(calib_dataset,
......
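A small sketch of what `_prepare_for_calibrate` does to a toy module tree: decoder blocks and the output head stay on CPU while everything else (embeddings, final norm) moves to the device. The toy classes and a CUDA device are assumptions for illustration; the function itself is the one added in the hunk above.

from torch import nn

class ToyBlock(nn.Module):          # stand-in for a decoder layer
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(8, 8)

class ToyLM(nn.Module):             # stand-in for a decoder-only LM
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(100, 8)
        self.layers = nn.ModuleList([ToyBlock() for _ in range(2)])
        self.norm = nn.LayerNorm(8)
        self.lm_head = nn.Linear(8, 100)

model = ToyLM()
# Prints: layers.0, layers.1 and lm_head move to CPU; embed and norm move
# to the GPU, ready for the calibration context.
_prepare_for_calibrate(model, ToyBlock, head_name='lm_head', device='cuda')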
@@ -18,6 +18,10 @@ NORM_FCS_MAP = {
     'QWenBlock': {
         'ln_1': ['attn.c_attn'],
         'ln_2': ['mlp.w1', 'mlp.w2']
+    },
+    'DecoderLayer': {
+        'input_layernorm': ['self_attn.W_pack'],
+        'post_attention_layernorm': ['mlp.gate_proj', 'mlp.up_proj']
     }
 }
@@ -33,6 +37,10 @@ FC_FCS_MAP = {
     'QWenBlock': {
         'attn.c_attn': ['attn.c_proj'],
         'mlp.w1': ['mlp.c_proj']
+    },
+    'DecoderLayer': {
+        'self_attn.W_pack': ['self_attn.o_proj'],
+        'mlp.up_proj': ['mlp.down_proj']
     }
 }
@@ -69,7 +77,7 @@ def smooth_ln_fcs(ln: torch.nn.Module,
     w_scales = get_weight_scale(concat_w, group_size)
 
     scales = (act_scales.pow(alpha) /
-              w_scales.pow(1 - alpha)).clamp(min=1e-4).to(device).to(dtype)
+              w_scales.pow(1 - alpha)).to(device).to(dtype)
     scales = scales / (scales.max() * scales.min()).sqrt()
 
     ln.weight.div_(scales)
@@ -116,10 +124,10 @@ def smooth_fc_fcs(pre_fc: torch.nn.Module,
     w_scales = get_weight_scale(concat_w, group_size)
 
     scales = (act_scales.pow(alpha) /
-              w_scales.pow(1 - alpha)).clamp(min=1e-4).to(device).to(dtype)
+              w_scales.pow(1 - alpha)).to(device).to(dtype)
     scales = scales / (scales.max() * scales.min()).sqrt()
 
-    # (for qwen) pre_fc is packed QKV, only V needs to scale
+    # (for qwen&baichuan) pre_fc is packed QKV, only V needs to scale
     if size_pre_fc > size_a and size_pre_fc % size_a == 0 \
             and size_pre_fc // size_a == 3:
......
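The `size_pre_fc // size_a == 3` branch referenced by the updated comment covers Qwen's `attn.c_attn` and Baichuan's `self_attn.W_pack`, whose single weight matrix packs Q, K and V along the output dimension; only the V slice feeds the next projection, so only that slice is rescaled. A toy illustration of the principle (the slicing below is a sketch, not the exact code hidden behind the fold):

import torch

hidden = 4096                                  # illustrative model width
w_pack = torch.randn(3 * hidden, hidden)       # rows packed as [Q | K | V]
scales = torch.rand(hidden) + 0.5              # per-channel smoothing scales

# Divide only the V block's output channels; the following o_proj / c_proj
# would have its matching input channels multiplied by the same scales.
w_pack[-hidden:, :] /= scales.view(-1, 1)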
@@ -8,7 +8,7 @@ from lmdeploy.lite.utils import (QParams, cal_qparams_per_channel_absmax,
                                  cal_qparams_per_group_absmax,
                                  cal_qparams_per_group_minmax,
                                  cal_qparams_per_tensor_absmax,
-                                 cal_qparams_per_tensor_minmax)
+                                 cal_qparams_per_tensor_minmax, precise_round)
 from lmdeploy.lite.utils.global_avail import GlobalAvailMixin
@@ -119,8 +119,10 @@ class WeightQuantizer(GlobalAvailMixin):
             torch.Tensor: The fake quantized weight tensor.
         """
 
+        float_w = weight.float()
+
         if qparams is None:
-            qparams = self.calculate_qparams(weight)
+            qparams = self.calculate_qparams(float_w)
 
         scales = qparams.scales
         zero_points = qparams.zero_points
@@ -133,17 +135,18 @@ class WeightQuantizer(GlobalAvailMixin):
         # per group scales shape: [out_c, in_c//group_size, 1]
         if len(scales.shape) > 2:
             # scales shape: [out_c, in_c//group_size, 1]
-            weight = weight.reshape(out_c, scales.shape[1], -1)
+            float_w = float_w.reshape(out_c, scales.shape[1], -1)
 
         if zero_points is None:
             assert self.symmetry
-            real_qweight = (weight / scales).round()
+            real_qweight = (float_w / scales).round()
             fake_qweight = real_qweight * scales
 
         else:
             assert not self.symmetry
-            real_qweight = (weight / scales).round() + zero_points
+            real_qweight = precise_round(
+                (float_w - float_w.min(-1, keepdim=True)[0]) / scales)
             fake_qweight = (real_qweight - zero_points) * scales
 
         if len(scales.shape) > 2:
@@ -153,4 +156,4 @@ class WeightQuantizer(GlobalAvailMixin):
         if real:
             return real_qweight.to(torch.int32)
         else:
-            return fake_qweight
+            return fake_qweight.to(weight.dtype)
......
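The key change to `fake_quant` above is doing the whole round trip in float32 (`float_w`) and, in the asymmetric branch, quantizing relative to the per-group minimum with `precise_round` rather than rounding `weight / scales` directly in half precision. A condensed, self-contained sketch of that asymmetric path for a single group:

import torch

def precise_round(x):
    # round half away from zero, matching the helper added in this commit
    return x.sign() * (x.abs() + 0.5).floor()

w = torch.tensor([-0.31, 0.07, 0.52], dtype=torch.float16)
n_bits = 4

float_w = w.float()
w_min, w_max = float_w.min(), float_w.max()
scales = (w_max - w_min) / (2**n_bits - 1)
zero_points = precise_round(-w_min / scales)

real_q = precise_round((float_w - w_min) / scales)      # integer codes in [0, 15]
fake_q = ((real_q - zero_points) * scales).to(w.dtype)  # dequantized fp16 copy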
@@ -6,7 +6,7 @@ from .cal_qparams import (QParams, cal_qparams_per_channel_absmax,
                           cal_qparams_per_group_absmax,
                           cal_qparams_per_group_minmax,
                           cal_qparams_per_tensor_absmax,
-                          cal_qparams_per_tensor_minmax)
+                          cal_qparams_per_tensor_minmax, precise_round)
 from .calib_dataloader import get_calib_loaders
 from .collect import (bimap_name_mod, collect_target_modules,
                       collect_target_weights)
@@ -16,7 +16,7 @@ __all__ = [
     'cal_qparams_per_channel_absmax', 'cal_qparams_per_channel_minmax',
     'cal_qparams_per_group_absmax', 'cal_qparams_per_group_minmax',
     'cal_qparams_per_tensor_absmax', 'cal_qparams_per_tensor_minmax',
-    'QParams', 'get_calib_loaders', 'collect_target_modules',
+    'QParams', 'get_calib_loaders', 'collect_target_modules', 'precise_round',
     'collect_target_weights', 'GlobalAvailMixin', 'split_decoder_layer_inputs',
     'bimap_name_mod', 'concat_decoder_layer_outputs'
 ]
......
@@ -11,16 +11,22 @@ class QParams(NamedTuple):
     zero_points: Optional[torch.Tensor]
 
 
+@torch.no_grad()
+def precise_round(x):
+    return x.sign() * (x.abs() + 0.5).floor()
+
+
 @torch.no_grad()
 def cal_qparams_per_channel_absmax(w: torch.Tensor,
                                    n_bits: int,
                                    return_stats: bool = False) -> QParams:
     """Calculate quantization parameters for each channel using absolute max
     value."""
-    absmax = w.abs().max(dim=-1, keepdim=True)[0]
+    float_w = w.float()
+    absmax = float_w.abs().max(dim=-1, keepdim=True)[0]
     q_max = 2**(n_bits - 1) - 1
-    scales = absmax.clamp(min=1e-5).div(q_max)
+    scales = absmax.div(q_max)
 
     if return_stats:
         return QParams(scales=scales, zero_points=None), absmax
@@ -35,14 +41,16 @@ def cal_qparams_per_channel_minmax(w: torch.Tensor,
     """Calculate quantization parameters for each channel using min and max
     values."""
-    w_min = w.min(dim=-1, keepdim=True)[0]
-    w_max = w.max(dim=-1, keepdim=True)[0]
+    float_w = w.float()
+
+    w_min = float_w.min(dim=-1, keepdim=True)[0]
+    w_max = float_w.max(dim=-1, keepdim=True)[0]
 
     q_max = 2**n_bits - 1
     scales = (w_max - w_min)
-    scales = scales.clamp_(min=1e-5).div_(q_max)
-    zero_points = (-w_min / scales).round()
+    scales = scales.div_(q_max)
+    zero_points = precise_round(-w_min / scales)
 
     if return_stats:
         return QParams(scales=scales, zero_points=zero_points), (w_min, w_max)
@@ -63,9 +71,12 @@ def cal_qparams_per_group_absmax(w: torch.Tensor,
         'Input channels should be greater than or equal to group_size.'
     assert inc % group_size == 0, \
         'Input channels should be divisible by group_size.'
-    absmax = w.abs().reshape(outc, -1, group_size).max(dim=-1, keepdim=True)[0]
+
+    float_w = w.float()
+    absmax = float_w.abs().reshape(outc, -1, group_size).max(dim=-1,
+                                                             keepdim=True)[0]
+
     q_max = 2**(n_bits - 1) - 1
-    scales = absmax.clamp(min=1e-5).div(q_max)
+    scales = absmax.div(q_max)
     if return_stats:
         return QParams(scales=scales, zero_points=None), absmax
     else:
@@ -85,14 +96,16 @@ def cal_qparams_per_group_minmax(w: torch.Tensor,
         'Input channels should be greater than or equal to group_size.'
     assert inc % group_size == 0, \
         'Input channels should be divisible by group_size.'
-    w_group_wise = w.reshape(outc, -1, group_size)
+
+    float_w = w.float()
+    w_group_wise = float_w.reshape(outc, -1, group_size)
     w_min = w_group_wise.min(dim=-1, keepdim=True)[0]
     w_max = w_group_wise.max(dim=-1, keepdim=True)[0]
 
     q_max = 2**n_bits - 1
     scales = (w_max - w_min)
-    scales = scales.clamp_(min=1e-5).div_(q_max)
-    zero_points = (-w_min / scales).round()
+    scales = scales.div_(q_max)
+    zero_points = precise_round(-w_min / scales)
     if return_stats:
         return QParams(scales=scales, zero_points=zero_points), (w_min, w_max)
     else:
@@ -106,13 +119,15 @@ def cal_qparams_per_tensor_minmax(w: torch.Tensor,
     """Calculate quantization parameters for the entire tensor using min and
     max values."""
-    w_min = w.min()
-    w_max = w.max()
+    float_w = w.float()
+
+    w_min = float_w.min()
+    w_max = float_w.max()
 
     q_max = 2**n_bits - 1
     scales = (w_max - w_min)
     scales = scales.clamp_(min=1e-5).div_(q_max)
-    zero_points = (-w_min / scales).round()
+    zero_points = precise_round(-w_min / scales)
     if return_stats:
         return QParams(scales=scales, zero_points=zero_points), (w_min, w_max)
     else:
@@ -125,9 +140,10 @@ def cal_qparams_per_tensor_absmax(w: torch.Tensor,
                                   return_stats: bool = False) -> QParams:
     """Calculate quantization parameters for the entire tensor using absolute
     max value."""
-    absmax = w.abs().max()
+    float_w = w.float()
+    absmax = float_w.abs().max()
     q_max = 2**(n_bits - 1) - 1
-    scales = absmax.clamp(min=1e-5).div(q_max)
+    scales = absmax.div(q_max)
 
     if return_stats:
         return QParams(scales=scales, zero_points=None), absmax
......
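`precise_round` exists because `torch.round` implements round-half-to-even (banker's rounding), so values landing exactly on a .5 boundary can be rounded one step differently from the round-half-away-from-zero behaviour that `x.sign() * (x.abs() + 0.5).floor()` provides. A quick comparison:

import torch

def precise_round(x):
    return x.sign() * (x.abs() + 0.5).floor()

x = torch.tensor([0.5, 1.5, 2.5, -0.5, -2.5])
print(torch.round(x))    # tensor([ 0.,  2.,  2., -0., -2.])  half to even
print(precise_round(x))  # tensor([ 1.,  2.,  3., -1., -3.])  half away from zero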
 # Copyright (c) OpenMMLab. All rights reserved.
 from typing import Dict, List, Tuple, Union
 
-from mmengine.config.lazy import LazyAttr
 from torch import nn
@@ -22,9 +21,6 @@ def collect_target_modules(model: nn.Module,
         A dictionary mapping from module names to module instances.
     """
 
-    if isinstance(target, LazyAttr):
-        target = target.build()
-
     if not isinstance(target, (type, str)):
         raise TypeError('Target must be a string (name of the module) '
                         'or a type (class of the module)')
......
@@ -4,6 +4,11 @@ from typing import Optional, Type, TypeVar
 import torch
 from torch import nn
 
+try:
+    import awq_inference_engine
+except ModuleNotFoundError:
+    awq_inference_engine = None
+
 
 class WeightOnlyQLinear(nn.Module):
     """This class implements weight only quantization linear.
@@ -18,13 +23,15 @@ class WeightOnlyQLinear(nn.Module):
         bias (Tensor, optional): Defaults to None.
     """
 
-    def __init__(self,
-                 w_bit: int,
-                 symmetry: bool,
-                 group_size: int,
-                 in_features: int,
-                 out_features: int,
-                 bias: Optional[torch.Tensor] = None) -> None:
+    def __init__(
+        self,
+        in_features: int,
+        out_features: int,
+        bias: Optional[torch.Tensor] = True,
+        w_bit: int = 4,
+        symmetry: bool = False,
+        group_size: int = 128,
+    ) -> None:
         super().__init__()
 
         if w_bit not in [2, 4, 8]:
@@ -92,8 +99,8 @@ class WeightOnlyQLinear(nn.Module):
         out_features = linear.out_features
         bias = False if linear.bias is None else True
 
-        qlinear = cls(w_bit, symmetry, group_size, in_features, out_features,
-                      bias)
+        qlinear = cls(in_features, out_features, bias, w_bit, symmetry,
+                      group_size)
         qlinear.bias = linear.bias
 
         qparams = quantizer.calculate_qparams(linear.weight)
@@ -124,3 +131,24 @@ class WeightOnlyQLinear(nn.Module):
         qlinear.to('cpu')
 
         return qlinear
+
+    @torch.no_grad()
+    def forward(self, x):
+        if awq_inference_engine is None:
+            raise RuntimeError(
+                'Run the following command to install '
+                'the kernel for 4bit inference\n\n'
+                'git clone https://github.com/mit-han-lab/llm-awq.git\n'
+                'cd awq/kernels\n'
+                'python setup.py install\n')
+
+        out_shape = x.shape[:-1] + (self.out_features, )
+        inputs = x.reshape(-1, x.shape[-1])
+
+        out = awq_inference_engine.gemm_forward_cuda(inputs.half(),
+                                                     self.qweight,
+                                                     self.scales.half(),
+                                                     self.qzeros,
+                                                     self.group_size)
+        out = out + self.bias if self.bias is not None else out
+
+        return out.reshape(out_shape)
......
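With the arguments reordered to `(in_features, out_features, bias, ...)` and given defaults, `WeightOnlyQLinear` now reads like a drop-in `nn.Linear` replacement, and the new `forward` makes the module usable at inference time once the llm-awq CUDA kernel is installed. A construction sketch under those assumptions; the import path is omitted, and the quantized buffers are normally populated by the classmethod in the hunk above, so this bare instance is not yet ready to run a forward pass:

# Sketch only: mirrors the new __init__ signature and its defaults
# (4-bit, asymmetric, group size 128).  Calling forward() additionally needs
# a CUDA device, populated qweight/qzeros/scales buffers and the
# awq_inference_engine kernel from mit-han-lab/llm-awq.
qlinear = WeightOnlyQLinear(in_features=4096,
                            out_features=4096,
                            bias=False,
                            w_bit=4,
                            symmetry=False,
                            group_size=128)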