Commit ec98d390 authored by gaoqiong's avatar gaoqiong
Browse files

add qwen2 awq support

parent 5f5ddc3d
......@@ -326,6 +326,18 @@ class Qwen2ForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler()
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
if self.quant_method=="awq":
try:
import lmslim
except ValueError as e:
raise RuntimeError("please install lmslim first for awq\n") from e
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
......@@ -426,4 +438,56 @@ class Qwen2ForCausalLM(nn.Module):
weight.data=weight.data.reshape(ori_shape[1],-1)
\ No newline at end of file
if self.quant_method == "awq":
from lmslim import quant_ops as _ops
# 对weight进行处理转置处理
lay_key_words = [
"self_attn.qkv_proj.qweight",
"self_attn.o_proj.qweight",
"mlp.gate_up_proj.qweight",
"mlp.down_proj.qweight"
]
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches:
#只对.qweight做了匹配,但是对对应的scale和qzeros都做了处理
qweight =params_dict[layername]
qzeros=params_dict[layername.replace("qweight", "qzeros")]
scales=params_dict[layername.replace("qweight", "scales")]
zeros_and_scalse =params_dict[layername.replace("qweight", "zeros_and_scales")]
group_size= self.quant_config.group_size
dim_n = scales.data.shape[1]
dim_k = qweight.data.shape[0]
pad_group=2
#对qweight和qzeros以及scales进行pad
#qweight[k,n/8]--->[k+group_size*2,n/8]
#qzeros [k/group_size+2,n/8]
#scales [k/group_size+2,n]
#给weight进行转置和zeros_and_scales打包
_qw, _sz=_ops.convert_s4(qweight,qzeros,scales,int(group_size))
#给sz转置(转置之后但是暂时保留原来的shape信息)
sz = _ops.sz_permute(_sz).reshape(-1,dim_n)
#数据拷贝
zeros_and_scalse.data.copy_(sz)
qweight.data.copy_(_qw)
#reshape
zeros_and_scalse.data=zeros_and_scalse.reshape(dim_n,-1) #[k/greop_size,n]------>[n,k/group_size]
qweight.data=qweight.data.reshape(dim_n,-1) #[k,n/8]---->[n,k/8]
#对qweight 与zeros_and_scalse 进行pad
if dim_k % 4096==0:
zeros_and_scalse_pad= torch.zeros(dim_n,pad_group,dtype=torch.int32).cuda()
zeros_and_scalse.data=torch.cat((zeros_and_scalse.data,zeros_and_scalse_pad),dim=1).contiguous()
qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda()
qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment