Unverified Commit d5cb0be2 authored by pppppM, committed by GitHub

[Fix] Fix llama2 70b & qwen quantization error (#273)

* fix llama2 70b

* fix qwen quantization

* remove pdb

* add faq
parent e5bfd387
@@ -210,7 +210,7 @@ python3 -m lmdeploy.lite.apis.calibrate \
 #### Weight INT4 Quantization
-LMDeploy uses AWQ algorithm for model weight quantization
+LMDeploy uses [AWQ](https://arxiv.org/abs/2306.00978) algorithm for model weight quantization
 > Requires input from the $WORK_DIR of step 1, and the quantized weights will also be stored in this folder.
...
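The weight quantization step can also be driven directly from Python. A minimal sketch, assuming the fire-based `auto_awq` entry point modified later in this commit lives at `lmdeploy.lite.apis.auto_awq`; the model path is a placeholder and `./work_dir` must already contain the calibration statistics from step 1:

```python
# Hedged sketch: the keyword arguments mirror the auto_awq signature
# changed in this commit.
from lmdeploy.lite.apis.auto_awq import auto_awq  # module path assumed

auto_awq('meta-llama/Llama-2-70b-hf',  # placeholder HF model id or local path
         w_bits=4,                     # 4-bit weights
         w_group_size=128,             # per-group quantization
         work_dir='./work_dir')        # must hold inputs_stats.pth from step 1
```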
@@ -48,3 +48,9 @@ export LD_LIBRARY_PATH={Location}/nvidia/nccl/lib:$LD_LIBRARY_PATH
 ## Serving
 ## Quantization
+### RuntimeError: \[enforce fail at inline_container.cc:337\] . unexpected pos 4566829760 vs 4566829656
+Please check your disk space.
+This error is caused by running out of disk space while saving the weights; it may occur when quantizing a 70B model.
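Not part of this commit, but a quick way to confirm there is enough room before quantizing a 70B model is to check the free space of the work directory with the standard library (the path below is illustrative):

```python
import shutil

free_gb = shutil.disk_usage('./work_dir').free / 1024**3
# A 70B model in fp16 alone takes roughly 140 GB on disk, before any
# intermediate statistics or quantized outputs are written.
print(f'{free_gb:.1f} GB free')
```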
@@ -4,8 +4,10 @@ from pathlib import Path
 import fire
 import torch
+from accelerate import (infer_auto_device_map, init_empty_weights,
+                        load_checkpoint_in_model)
 from torch import nn
-from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 from lmdeploy.lite.quantization.awq import (FC_FCS_MAP, NORM_FCS_MAP,
                                             quant_weights, smooth_layers)
@@ -26,24 +28,42 @@ NORM_TYPE_MAP = {
 def auto_awq(model: str,
-             work_dir: str,
              w_bits: int = 4,
              w_sym: bool = False,
              w_group_size: int = 128,
+             work_dir: str = './work_dir',
              device: str = 'cuda'):
+    # Load tokenizer and configuration
     tokenizer = AutoTokenizer.from_pretrained(model,
                                               use_fast=False,
                                               trust_remote_code=True)
-    model = AutoModelForCausalLM.from_pretrained(model,
-                                                 torch_dtype=torch.float16,
-                                                 trust_remote_code=True)
+    hf_config = AutoConfig.from_pretrained(model, trust_remote_code=True)
+    checkpoint = hf_config._name_or_path
+
+    with init_empty_weights():
+        # Load model
+        model = AutoModelForCausalLM.from_pretrained(model,
+                                                     torch_dtype=torch.float16,
+                                                     trust_remote_code=True)
+        model.config.use_cache = False

     layer_type = LAYER_TYPE_MAP[type(model).__name__]
     fc2fcs = FC_FCS_MAP[layer_type]
     norm2fcs = NORM_FCS_MAP[layer_type]
+    decoder_layers = collect_target_modules(model, layer_type)
+
+    # Infer device map
+    device_map = infer_auto_device_map(model,
+                                       no_split_module_classes=[layer_type])
+    for name in device_map.keys():
+        if name in decoder_layers or 'lm_head' in name:
+            device_map[name] = 'cpu'
+        else:
+            device_map[name] = 0
+    load_checkpoint_in_model(model, checkpoint, device_map)

     work_dir = Path(work_dir)
     act_scales = torch.load(work_dir / 'inputs_stats.pth')['absmean']
...
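The fix above builds the model skeleton with empty weights, pins every decoder layer (and `lm_head`) to CPU in the device map, and only then streams the checkpoint in, so a 70B model no longer has to fit in GPU memory at load time. A minimal, self-contained sketch of that accelerate pattern with a toy model (`TinyBlock`/`TinyModel` are stand-ins, not lmdeploy classes):

```python
import torch
from torch import nn
from accelerate import (infer_auto_device_map, init_empty_weights,
                        load_checkpoint_in_model)


class TinyBlock(nn.Module):
    """Stand-in for a decoder layer."""

    def __init__(self, dim=64):
        super().__init__()
        self.fc = nn.Linear(dim, dim)


class TinyModel(nn.Module):
    """Stand-in for a causal LM with an lm_head."""

    def __init__(self, dim=64, n_layers=4):
        super().__init__()
        self.layers = nn.ModuleList(TinyBlock(dim) for _ in range(n_layers))
        self.lm_head = nn.Linear(dim, dim)


# Save a real checkpoint once so the example is self-contained.
torch.save(TinyModel().state_dict(), 'tiny.pth')

# 1. Build the skeleton without allocating any weight memory.
with init_empty_weights():
    model = TinyModel()

# 2. Compute a device map that never splits a block, then pin every module
#    to CPU, mirroring the 70B fix above (decoder layers + lm_head on CPU).
device_map = infer_auto_device_map(model, no_split_module_classes=['TinyBlock'])
for name in device_map:
    device_map[name] = 'cpu'

# 3. Stream the checkpoint weights directly onto the mapped devices.
load_checkpoint_in_model(model, 'tiny.pth', device_map=device_map)
```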
@@ -74,7 +74,7 @@ def calibrate(model: str,
     device_map = infer_auto_device_map(model,
                                        no_split_module_classes=[layer_type])
     for name in device_map.keys():
-        if name in decoder_layers:
+        if name in decoder_layers or 'lm_head' in name:
             device_map[name] = 'cpu'
         else:
             device_map[name] = 0
...
@@ -37,9 +37,15 @@ class KVCacheObserver(GlobalAvailMixin):
         x : Input tensor
         """
         assert len(x.shape) == 4
-        x = x.transpose(1, 2)
-        assert x.size(2) == self.num_head
-        assert x.size(3) == self.head_dim
+        if x.size(2) == self.num_head and x.size(3) == self.head_dim:
+            # layout: (bs, seqlen, heads, dims)
+            x = x
+        elif x.size(1) == self.num_head and x.size(3) == self.head_dim:
+            # layout: (bs, heads, seqlen, dims)
+            x = x.transpose(1, 2)
+        else:
+            raise RuntimeError
         cur_max = x.flatten(0, 1).max(0)[0].cpu()
         cur_min = x.flatten(0, 1).min(0)[0].cpu()
...
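For reference, a small sketch of the two key/value cache layouts the observer now accepts; the head counts are illustrative and the helper simply mirrors the branch added above:

```python
import torch

bs, seqlen, heads, dims = 2, 16, 8, 128
k_llama_style = torch.randn(bs, seqlen, heads, dims)  # (bs, seqlen, heads, dims)
k_qwen_style = torch.randn(bs, heads, seqlen, dims)   # (bs, heads, seqlen, dims)


def to_bs_seqlen_heads_dims(x, num_head, head_dim):
    """Normalize either layout to (bs, seqlen, heads, dims)."""
    if x.size(2) == num_head and x.size(3) == head_dim:
        return x
    if x.size(1) == num_head and x.size(3) == head_dim:
        return x.transpose(1, 2)
    raise RuntimeError('unexpected KV cache layout')


assert to_bs_seqlen_heads_dims(k_llama_style, heads, dims).shape == (bs, seqlen, heads, dims)
assert to_bs_seqlen_heads_dims(k_qwen_style, heads, dims).shape == (bs, seqlen, heads, dims)
```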
@@ -14,6 +14,10 @@ NORM_FCS_MAP = {
         'input_layernorm':
         ['self_attn.k_proj', 'self_attn.q_proj', 'self_attn.v_proj'],
         'post_attention_layernorm': ['mlp.gate_proj', 'mlp.up_proj']
+    },
+    'QWenBlock': {
+        'ln_1': ['attn.c_attn'],
+        'ln_2': ['mlp.w1', 'mlp.w2']
     }
 }
@@ -25,6 +29,10 @@ FC_FCS_MAP = {
     'InternLMDecoderLayer': {
         'self_attn.v_proj': ['self_attn.o_proj'],
         'mlp.up_proj': ['mlp.down_proj']
+    },
+    'QWenBlock': {
+        'attn.c_attn': ['attn.c_proj'],
+        'mlp.w1': ['mlp.c_proj']
     }
 }
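The new `QWenBlock` entries are dotted submodule paths, resolved with `nn.Module.get_submodule` the same way `smooth_layers` does below. A tiny stand-in module (not the real QWenBlock) shows how the lookup is expected to work:

```python
from torch import nn


class FakeQWenBlock(nn.Module):
    """Stand-in with QWen-style attribute names (attn.c_attn, mlp.w1, ...)."""

    def __init__(self, dim=32):
        super().__init__()
        self.ln_1 = nn.LayerNorm(dim)
        self.attn = nn.Module()
        self.attn.c_attn = nn.Linear(dim, 3 * dim)  # packed QKV projection
        self.attn.c_proj = nn.Linear(dim, dim)
        self.ln_2 = nn.LayerNorm(dim)
        self.mlp = nn.Module()
        self.mlp.w1 = nn.Linear(dim, 2 * dim)
        self.mlp.w2 = nn.Linear(dim, 2 * dim)
        self.mlp.c_proj = nn.Linear(2 * dim, dim)


block = FakeQWenBlock()
norm_fcs = {'ln_1': ['attn.c_attn'], 'ln_2': ['mlp.w1', 'mlp.w2']}
for norm_name, fc_names in norm_fcs.items():
    norm = block.get_submodule(norm_name)
    fcs = [block.get_submodule(n) for n in fc_names]
    assert all(isinstance(f, nn.Linear) for f in fcs)
```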
@@ -94,6 +102,14 @@ def smooth_fc_fcs(pre_fc: torch.nn.Module,
     :return: Scales
     """
     device, dtype = pre_fc.weight.device, pre_fc.weight.dtype
+    size_a = act_scales.size(0)
+    size_pre_fc = pre_fc.weight.size(0)
+
+    # (for llama2) use group query attention, pre_fc is v_proj, fc is o_proj
+    if size_pre_fc < size_a and size_a % size_pre_fc == 0:
+        return
+
     act_scales = act_scales.to(device=device, dtype=dtype)
     concat_w = torch.cat([fc.weight for fc in fcs], dim=0)
@@ -103,6 +119,15 @@ def smooth_fc_fcs(pre_fc: torch.nn.Module,
               w_scales.pow(1 - alpha)).clamp(min=1e-4).to(device).to(dtype)
     scales = scales / (scales.max() * scales.min()).sqrt()
+
+    # (for qwen) pre_fc is packed QKV, only V needs to scale
+    if size_pre_fc > size_a and size_pre_fc % size_a == 0 \
+            and size_pre_fc // size_a == 3:
+        pre_fc.weight[-size_a:].div_(scales.view(-1, 1))
+        if getattr(pre_fc, 'bias', None) is not None:
+            pre_fc.bias[-size_a:].div_(scales)
+    else:
-    pre_fc.weight.div_(scales.view(-1, 1))
-    if getattr(pre_fc, 'bias', None) is not None:
+        pre_fc.weight.div_(scales.view(-1, 1))
+        if getattr(pre_fc, 'bias', None) is not None:
@@ -186,6 +211,7 @@ def smooth_layers(layers,
             fc = layer.get_submodule(f_name)
             fcs = [layer.get_submodule(n) for n in fc_names]
             smooth_fc_fcs(fc, fcs, a_scales[a_name], group_size)

         layer.to('cpu')
...
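The early return added to `smooth_fc_fcs` fires exactly in the grouped-query-attention case that broke Llama-2 70B: `v_proj` has `num_kv_heads * head_dim` output channels, while the `o_proj` input activations (and hence `act_scales`) have `num_attn_heads * head_dim` channels. A worked check using the published Llama-2 70B head counts:

```python
# Llama-2 70B attention geometry (from its public config): 64 attention
# heads, 8 KV heads, head_dim 128.
num_attn_heads, num_kv_heads, head_dim = 64, 8, 128

size_a = num_attn_heads * head_dim      # len(act_scales) for the o_proj input: 8192
size_pre_fc = num_kv_heads * head_dim   # rows of v_proj.weight: 1024

# The condition added in this commit holds, so smoothing this pair is skipped.
assert size_pre_fc < size_a and size_a % size_pre_fc == 0
```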
@@ -49,9 +49,12 @@ class CalibrationContext():
         self.layer_type = layer_type
         self.norm_type = norm_type
-        self.num_head = self._guess_num_heads(model)
-        self.head_dim = model.config.hidden_size // self.num_head
+        num_kv_heads, num_attn_heads = self._guess_num_heads(model)
+        self.num_kv_heads = num_kv_heads
+        self.head_dim = model.config.hidden_size // num_attn_heads
         self.model = model
+        del self.model.lm_head
         self.tokenizer = tokenizer

         # Collect modules to observe
@@ -74,12 +77,15 @@ class CalibrationContext():
         self.device = device

     def _guess_num_heads(self, model):
-        if hasattr(model.config, 'num_attention_heads'):
-            return model.config.num_attention_heads
-        elif hasattr(model.config, 'num_key_value_heads'):
-            return model.config.num_key_value_heads
+        if hasattr(model.config, 'num_key_value_heads'):
+            num_kv_heads = model.config.num_key_value_heads
         else:
-            raise KeyError
+            num_kv_heads = model.config.num_attention_heads
+
+        num_attn_heads = model.config.num_attention_heads
+        return num_kv_heads, num_attn_heads

     def _init_input_observers(self, name2mod):
         """Initialize input observers for given modules."""
@@ -96,8 +102,8 @@ class CalibrationContext():
     def _init_kv_observers(self, name2mod):
         """Initialize KV observers for given modules."""
         for name in name2mod.keys():
-            k_obs = KVCacheObserver(self.num_head, self.head_dim)
-            v_obs = KVCacheObserver(self.num_head, self.head_dim)
+            k_obs = KVCacheObserver(self.num_kv_heads, self.head_dim)
+            v_obs = KVCacheObserver(self.num_kv_heads, self.head_dim)
             k_obs.global_available(name, group=self.key_obs_group)
             v_obs.global_available(name, group=self.value_obs_group)
@@ -270,8 +276,13 @@ class CalibrationContext():
     def calibrate(self, data):
         """Forward pass through the model in inference mode with given data."""
+        if type(self.model).__name__ == 'QWenLMHeadModel':
+            model = self.model.transformer
+        else:
+            model = self.model.model
+
         with torch.inference_mode():
-            _ = self.model.model(data.to(self.device))
+            _ = model(data.to(self.device))

     def __enter__(self):
         """Prepares the Calibration object for a 'with' statement by
...
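A short sketch of what `_guess_num_heads` now returns for a grouped-query-attention config, using `transformers.LlamaConfig` (the numbers mirror Llama-2 70B; the fallback branch covers configs without `num_key_value_heads`):

```python
from transformers import LlamaConfig

cfg = LlamaConfig(hidden_size=8192, num_attention_heads=64, num_key_value_heads=8)

if hasattr(cfg, 'num_key_value_heads'):
    num_kv_heads = cfg.num_key_value_heads       # 8  -> used for the KV observers
else:
    num_kv_heads = cfg.num_attention_heads
num_attn_heads = cfg.num_attention_heads          # 64

head_dim = cfg.hidden_size // num_attn_heads      # 128, not hidden_size // num_kv_heads
```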