Commit 5f5ddc3d authored by gaoqiong
Browse files

add llama model awq support

parent bdac8f06
......@@ -172,7 +172,7 @@ class ModelConfig:
def _verify_quantization(self) -> None:
supported_quantization = [*QUANTIZATION_METHODS]
rocm_supported_quantization = ["gptq", "squeezellm"]
rocm_supported_quantization = ["gptq", "squeezellm","awq"]
if self.quantization is not None:
self.quantization = self.quantization.lower()
......
......@@ -8,6 +8,14 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
# lmslim is an optional dependency that supplies `quant_ops` (bound here as
# `_ops`), which the AWQ code paths in this file call into. When it cannot be
# imported the module must still load: a hint is printed and `_ops` is bound to
# None so the "not available" state is explicit instead of leaving the name
# undefined.
try:
    from lmslim import quant_ops as _ops
except Exception:  # deliberately broad: best-effort optional import
    _ops = None
    print("INFO:you need install lmslim if you want infer awq model.\n")
class AWQShareWorkSpace():
    """Scratch workspace shared by all AWQ GEMM calls in this module.

    Both attributes are class-level, so the buffer is allocated once, when
    this class body is executed at import time, and is then passed to
    `_ops.awq_gemm` together with its size on every call.

    The attribute names (including their spelling) are kept as-is because
    other code in this file refers to them by name.
    """

    # Workspace size passed to the kernel: 2 << 29 == 2**30.
    # NOTE(review): presumably a byte count, since the fp16 buffer below holds
    # size // 2 + 1 two-byte elements -- confirm against the lmslim kernel.
    awqworkshapcesize = 2 << 29
    # Zero-filled fp16 buffer created directly on the current CUDA device, so
    # no equally large temporary is first allocated in host memory and copied.
    awqworkshapce = torch.zeros(awqworkshapcesize // 2 + 1,
                                dtype=torch.float16,
                                device="cuda")
class AWQConfig(QuantizationConfig):
......@@ -142,6 +150,19 @@ class AWQLinearMethod(LinearMethodBase):
"input_dim": 0,
"output_dim": 1,
})
zeros_and_scales=Parameter(
torch.empty(
(input_size_per_partition // self.quant_config.group_size),
output_size_per_partition,
dtype=torch.int32,
),
requires_grad=False,
)
set_weight_attrs(zeros_and_scales, {
"input_dim": 0,
"output_dim": 1,
})
layer.register_parameter("qweight", qweight)
set_weight_attrs(qweight, extra_weight_attrs)
......@@ -149,27 +170,47 @@ class AWQLinearMethod(LinearMethodBase):
set_weight_attrs(qzeros, extra_weight_attrs)
layer.register_parameter("scales", scales)
set_weight_attrs(scales, extra_weight_attrs)
layer.register_parameter("zeros_and_scales", zeros_and_scales)
set_weight_attrs(zeros_and_scales, extra_weight_attrs)
def apply(self,
layer: torch.nn.Module,
x: torch.Tensor,
bias: Optional[torch.Tensor] = None) -> torch.Tensor:
qweight = layer.qweight
scales = layer.scales
qzeros = layer.qzeros
pack_factor = self.quant_config.pack_factor
out_shape = (x.shape[:-1] + (qweight.shape[-1] * pack_factor, ))
zeros_and_scales = layer.zeros_and_scales
out_shape = (x.shape[:-1] + (qweight.shape[0] * 1, ))
reshaped_x = x.reshape(-1, x.shape[-1])
# num_tokens >= threshold
FP16_MATMUL_HEURISTIC_CONDITION = x.shape[:-1].numel() >= 256
if FP16_MATMUL_HEURISTIC_CONDITION:
out = ops.awq_dequantize(qweight, scales, qzeros, 0, 0, 0)
out = torch.matmul(reshaped_x, out)
m = reshaped_x.shape[0]
k = reshaped_x.shape[-1]
n = qweight.shape[0]
if k % 4096==0:
padding_group=2
else:
out = ops.awq_gemm(reshaped_x, qweight, scales, qzeros,
pack_factor)
padding_group=0
out = _ops.awq_gemm(reshaped_x,
qweight,
zeros_and_scales,
m,
n,
k,
self.quant_config.group_size,
padding_group,
AWQShareWorkSpace.awqworkshapce,
AWQShareWorkSpace.awqworkshapcesize)
# Below is the alternative approach that uses rocBLAS:
# deqweight=_ops.dequant_w4_gemm_colmajor( #shape[n,k/8]--->[n,k]
# qweight,
# zeros_and_scales,
# k,
# n,
# self.quant_config.group_size)
# output=F.linear(reshaped_x, deqweight)
if bias is not None:
out.add_(bias)
return out.reshape(out_shape)
......@@ -259,7 +259,7 @@ class DefaultModelLoader(BaseModelLoader):
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None:
if quant_method is not None and quant_method!="awq"::
quant_method.process_weights_after_loading(module)
# FIXME: Remove this after Mixtral is updated
# to use quant_method.
......
......@@ -367,6 +367,18 @@ class LlamaForCausalLM(nn.Module):
self.logits_processor = LogitsProcessor(self.unpadded_vocab_size,
config.vocab_size, logit_scale)
self.sampler = Sampler()
self.quant_method = None
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
if self.quant_method=="awq":
try:
import lmslim
except ValueError as e:
raise RuntimeError("please install lmslim first for awq\n") from e
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
......@@ -476,7 +488,49 @@ class LlamaForCausalLM(nn.Module):
weight.data.copy_(_weight)
weight.data=weight.data.reshape(ori_shape[1], -1)
if self.quant_method == "awq":
from lmslim import quant_ops as _ops
lay_key_words = [
"self_attn.qkv_proj.qweight",
"self_attn.o_proj.qweight",
"mlp.gate_up_proj.qweight",
"mlp.down_proj.qweight"
]
combined_words = "|".join(lay_key_words)
for layername, weight in params_dict.items():
matches = re.findall(combined_words, layername)
if matches:
qweight =params_dict[layername]
qzeros=params_dict[layername.replace("qweight", "qzeros")]
scales=params_dict[layername.replace("qweight", "scales")]
zeros_and_scalse =params_dict[layername.replace("qweight", "zeros_and_scales")]
group_size= self.quant_config.group_size
dim_n = scales.data.shape[1]
dim_k = qweight.data.shape[0]
pad_group=2
_qw, _sz=_ops.convert_s4(qweight,qzeros,scales,int(group_size))
sz = _ops.sz_permute(_sz).reshape(-1,dim_n)
zeros_and_scalse.data.copy_(sz)
qweight.data.copy_(_qw)
#reshape
zeros_and_scalse.data=zeros_and_scalse.reshape(dim_n,-1) #[k/greop_size,n]------>[n,k/group_size]
qweight.data=qweight.data.reshape(dim_n,-1) #[k,n/8]---->[n,k/8]
if dim_k % 4096==0:
zeros_and_scalse_pad= torch.zeros(dim_n,pad_group,dtype=torch.int32).cuda()
zeros_and_scalse.data=torch.cat((zeros_and_scalse.data,zeros_and_scalse_pad),dim=1).contiguous()
qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda()
qweight.data=torch.cat((qweight.data,qweight_pad),dim=1).contiguous()
# If this function is called, it should always initialize KV cache scale
# factors (or else raise an exception). Thus, handled exceptions should
# make sure to leave KV cache scale factors in a known good (dummy) state
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment