Commit 982c1545 authored by gaoqiong's avatar gaoqiong
Browse files

add qwen awq support

parent ec98d390
...@@ -259,7 +259,7 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -259,7 +259,7 @@ class DefaultModelLoader(BaseModelLoader):
for _, module in model.named_modules(): for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None) quant_method = getattr(module, "quant_method", None)
if quant_method is not None and quant_method!="awq":: if quant_method is not None and quant_method!="awq":
quant_method.process_weights_after_loading(module) quant_method.process_weights_after_loading(module)
# FIXME: Remove this after Mixtral is updated # FIXME: Remove this after Mixtral is updated
# to use quant_method. # to use quant_method.
......
...@@ -245,6 +245,17 @@ class QWenLMHeadModel(nn.Module): ...@@ -245,6 +245,17 @@ class QWenLMHeadModel(nn.Module):
self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size) self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
if self.quant_method=="awq":
try:
import lmslim
except ValueError as e:
raise RuntimeError("please install lmslim first for awq\n") from e
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
...@@ -339,4 +350,45 @@ class QWenLMHeadModel(nn.Module): ...@@ -339,4 +350,45 @@ class QWenLMHeadModel(nn.Module):
weight.data=weight.data.reshape(ori_shape[1],-1) weight.data=weight.data.reshape(ori_shape[1],-1)
if self.quant_method == "awq":
from lmslim import quant_ops as _ops
lay_key_words = [
"attn.c_attn.qweight",
"attn.c_proj.qweight",
"mlp.gate_up_proj.qweight",
"mlp.c_proj.qweight"
]
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches:
qweight =params_dict[layername]
qzeros=params_dict[layername.replace("qweight", "qzeros")]
scales=params_dict[layername.replace("qweight", "scales")]
zeros_and_scalse =params_dict[layername.replace("qweight", "zeros_and_scales")]
group_size= self.quant_config.group_size
dim_n = scales.data.shape[1]
dim_k = qweight.data.shape[0]
pad_group=2
_qw, _sz=_ops.convert_s4(qweight,qzeros,scales,int(group_size))
sz = _ops.sz_permute(_sz).reshape(-1,dim_n)
zeros_and_scalse.data.copy_(sz)
qweight.data.copy_(_qw)
#reshape
zeros_and_scalse.data=zeros_and_scalse.reshape(dim_n,-1) #[k/greop_size,n]------>[n,k/group_size]
qweight.data=qweight.data.reshape(dim_n,-1) #[k,n/8]---->[n,k/8]
if dim_k % 4096==0:
zeros_and_scalse_pad= torch.zeros(dim_n,pad_group,dtype=torch.int32).cuda()
zeros_and_scalse.data=torch.cat((zeros_and_scalse.data,zeros_and_scalse_pad),dim=1).contiguous()
qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda()
qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous()
...@@ -326,8 +326,8 @@ class Qwen2ForCausalLM(nn.Module): ...@@ -326,8 +326,8 @@ class Qwen2ForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.quant_method = None self.quant_method = None
if quant_config is not None: if quant_config is not None:
self.quant_method=quant_config.get_name() self.quant_method=quant_config.get_name()
self.quant_config=quant_config self.quant_config=quant_config
...@@ -440,7 +440,7 @@ class Qwen2ForCausalLM(nn.Module): ...@@ -440,7 +440,7 @@ class Qwen2ForCausalLM(nn.Module):
if self.quant_method == "awq": if self.quant_method == "awq":
from lmslim import quant_ops as _ops from lmslim import quant_ops as _ops
# 对weight进行处理转置处理
lay_key_words = [ lay_key_words = [
"self_attn.qkv_proj.qweight", "self_attn.qkv_proj.qweight",
"self_attn.o_proj.qweight", "self_attn.o_proj.qweight",
...@@ -453,7 +453,6 @@ class Qwen2ForCausalLM(nn.Module): ...@@ -453,7 +453,6 @@ class Qwen2ForCausalLM(nn.Module):
matches = re.findall(combined_words, layername) matches = re.findall(combined_words, layername)
if matches: if matches:
#只对.qweight做了匹配,但是对对应的scale和qzeros都做了处理
qweight =params_dict[layername] qweight =params_dict[layername]
qzeros=params_dict[layername.replace("qweight", "qzeros")] qzeros=params_dict[layername.replace("qweight", "qzeros")]
scales=params_dict[layername.replace("qweight", "scales")] scales=params_dict[layername.replace("qweight", "scales")]
...@@ -465,18 +464,10 @@ class Qwen2ForCausalLM(nn.Module): ...@@ -465,18 +464,10 @@ class Qwen2ForCausalLM(nn.Module):
dim_k = qweight.data.shape[0] dim_k = qweight.data.shape[0]
pad_group=2 pad_group=2
#对qweight和qzeros以及scales进行pad
#qweight[k,n/8]--->[k+group_size*2,n/8]
#qzeros [k/group_size+2,n/8]
#scales [k/group_size+2,n]
#给weight进行转置和zeros_and_scales打包
_qw, _sz=_ops.convert_s4(qweight,qzeros,scales,int(group_size)) _qw, _sz=_ops.convert_s4(qweight,qzeros,scales,int(group_size))
#给sz转置(转置之后但是暂时保留原来的shape信息)
sz = _ops.sz_permute(_sz).reshape(-1,dim_n) sz = _ops.sz_permute(_sz).reshape(-1,dim_n)
#数据拷贝
zeros_and_scalse.data.copy_(sz) zeros_and_scalse.data.copy_(sz)
qweight.data.copy_(_qw) qweight.data.copy_(_qw)
...@@ -484,7 +475,6 @@ class Qwen2ForCausalLM(nn.Module): ...@@ -484,7 +475,6 @@ class Qwen2ForCausalLM(nn.Module):
zeros_and_scalse.data=zeros_and_scalse.reshape(dim_n,-1) #[k/greop_size,n]------>[n,k/group_size] zeros_and_scalse.data=zeros_and_scalse.reshape(dim_n,-1) #[k/greop_size,n]------>[n,k/group_size]
qweight.data=qweight.data.reshape(dim_n,-1) #[k,n/8]---->[n,k/8] qweight.data=qweight.data.reshape(dim_n,-1) #[k,n/8]---->[n,k/8]
#对qweight 与zeros_and_scalse 进行pad
if dim_k % 4096==0: if dim_k % 4096==0:
zeros_and_scalse_pad= torch.zeros(dim_n,pad_group,dtype=torch.int32).cuda() zeros_and_scalse_pad= torch.zeros(dim_n,pad_group,dtype=torch.int32).cuda()
zeros_and_scalse.data=torch.cat((zeros_and_scalse.data,zeros_and_scalse_pad),dim=1).contiguous() zeros_and_scalse.data=torch.cat((zeros_and_scalse.data,zeros_and_scalse_pad),dim=1).contiguous()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment