Commit 39ae4102 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'v0.5.0-dtk24.04.1' into v0.5.3.post1-dtk24.04.1

parents 75011627 880b2e41
...@@ -331,6 +331,12 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA): ...@@ -331,6 +331,12 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
self.logits_processor = LogitsProcessor(config.vocab_size) self.logits_processor = LogitsProcessor(config.vocab_size)
self.sampler = Sampler() self.sampler = Sampler()
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
...@@ -438,4 +444,44 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA): ...@@ -438,4 +444,44 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
weight.data=weight.data.reshape(ori_shape[1],-1) weight.data=weight.data.reshape(ori_shape[1],-1)
if self.quant_method == "awq":
\ No newline at end of file lay_key_words = [
"self_attn.qkv_proj.qweight",
"self_attn.o_proj.qweight",
"mlp.gate_up_proj.qweight",
"mlp.down_proj.qweight"
]
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches:
qweight =params_dict[layername]
qzeros=params_dict[layername.replace("qweight", "qzeros")]
scales=params_dict[layername.replace("qweight", "scales")]
zeros_and_scalse =params_dict[layername.replace("qweight", "zeros_and_scales")]
group_size= self.quant_config.group_size
dim_n = scales.data.shape[1]
dim_k = qweight.data.shape[0]
pad_group=2
_qw, _sz=ops.convert_s4(qweight,qzeros,scales,int(group_size))
sz = ops.sz_permute(_sz).reshape(-1,dim_n)
zeros_and_scalse.data.copy_(sz)
qweight.data.copy_(_qw)
#reshape
zeros_and_scalse.data=zeros_and_scalse.reshape(dim_n,-1) #[k/greop_size,n]------>[n,k/group_size]
qweight.data=qweight.data.reshape(dim_n,-1) #[k,n/8]---->[n,k/8]
if dim_k % 4096==0:
zeros_and_scalse_pad= torch.zeros(dim_n,pad_group,dtype=torch.int32).cuda()
zeros_and_scalse.data=torch.cat((zeros_and_scalse.data,zeros_and_scalse_pad),dim=1).contiguous()
qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda()
qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous()
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment