Commit dcaabcf7 authored by gaoqiong's avatar gaoqiong
Browse files

增添lmslim gptq支持

parent b8c88ed3
......@@ -3,7 +3,7 @@
vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention高效管理kv内存,Continuous batching传入请求,支持很多Hugging Face模型,如LLaMA & LLaMA-2、Qwen、Chatglm2 & Chatglm3等。
## 暂不支持的官方功能
- **量化推理**:目前支持fp16的推理和gptq推理,awq-int4mralin的权重量化、kv-cache fp8推理方案暂不支持
- **量化推理**:目前支持fp16的推理和gptq,awq-int4推理,mralin的权重量化、kv-cache fp8推理方案暂不支持
- **模块支持**:目前不支持Sliding window attention、 moe kernel和lora模块
......@@ -62,14 +62,11 @@ pip install -r requirements-rocm.txt
```
1. 编译whl包并安装
VLLM_INSTALL_PUNICA_KERNELS=1 python setup.py bdist_wheel
python csrc/quantization/gptq/setup.py bdist_wheel
cd dist
pip install vllm*
pip install gptq_kernel
2. 源码编译安装
VLLM_INSTALL_PUNICA_KERNELS=1 python3 setup.py install
python csrc/quantization/gptq/setup.py install
```
#### 运行基础环境准备
......@@ -79,7 +76,7 @@ python csrc/quantization/gptq/setup.py install
- triton:[https://cancon.hpccube.com:65024/4/main/triton](https://cancon.hpccube.com:65024/4/main/triton/)
- xformers:[https://cancon.hpccube.com:65024/4/main/xformers](https://cancon.hpccube.com:65024/4/main/xformers)
- flash_attn: [https://cancon.hpccube.com:65024/4/main/flash_attn](https://cancon.hpccube.com:65024/4/main/flash_attn)
- lmslim: [https://cancon.hpccube.com:65024/4/main/lmslim](https://cancon.hpccube.com:65024/4/main/lmslim)
#### 注意事项
+ 若使用 pip install 下载安装过慢,可添加源:-i https://pypi.tuna.tsinghua.edu.cn/simple/
......
......@@ -18,7 +18,7 @@ from typing import Optional, Union
import subprocess
from pathlib import Path
add_git_version = False
if int(os.environ.get('ADD_GIT_VERSION', '0')) == 1:
add_git_version = True
......
import contextlib
from typing import List, Optional, Tuple, Type
import torch
try:
import gptq_kernels
from lmslim import quant_ops
except Exception:
print("INFO: Please install gptq_kernels if you want to infer gptq model.\n")
print("INFO: Please install lmslim if you want to infer gptq or awq model.\n")
try:
import vllm._C
......@@ -150,17 +150,47 @@ def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
thx, thy)
def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters)
# def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
# scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
# return quant_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters)
def awq_gemm(input: torch.Tensor, weight: torch.Tensor,
zeros_and_scales:torch.Tensor,
m:int,n:int,k:int,
group_size:int,padding_group:int,splikspace:torch.Tensor,
splikspacesize:int) -> torch.Tensor:
return quant_ops.awq_gemm(input,
weight,
zeros_and_scales,
m,
n,
k,
group_size,
padding_group,
splikspace,
splikspacesize)
def convert_s4(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor,
group_size: int):
return quant_ops.convert_s4(qw,qz,s,group_size)
def sz_permute(sz:torch.Tensor)-> torch.Tensor:
return quant_ops.sz_permute(sz)
def dequant_w4_gemm_colmajor(qweight:torch.Tensor,
zeros_and_scale:torch.Tensor,
k:int,
n:int,
group_size:int
)->torch.Tensor:
return quant_ops.dequant_w4_gemm_colmajor(qweight,zeros_and_scale,k,n,group_size)
# gptq
def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
b_g_idx: torch.Tensor, use_exllama: bool,
bit: int) -> torch.Tensor:
return gptq_kernels.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
return quant_ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
b_g_idx, use_exllama, bit)
# return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
# b_g_idx, use_exllama, bit)
......@@ -168,7 +198,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
bit: int) -> None:
gptq_kernels.gptq_shuffle(q_weight, q_perm, bit)
quant_ops.gptq_shuffle(q_weight, q_perm, bit)
# torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
# trans_w16
......
......@@ -8,10 +8,7 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs
try:
from lmslim import quant_ops as _ops
except Exception:
print("INFO: You need install lmslim if you want infer awq model.\n")
class AWQShareWorkSpace():
awqworkshapcesize=2<<29 #
......@@ -192,7 +189,7 @@ class AWQLinearMethod(LinearMethodBase):
else:
padding_group=0
out = _ops.awq_gemm(reshaped_x,
out = ops.awq_gemm(reshaped_x,
qweight,
zeros_and_scales,
m,
......@@ -203,7 +200,7 @@ class AWQLinearMethod(LinearMethodBase):
AWQShareWorkSpace.awqworkshapce,
AWQShareWorkSpace.awqworkshapcesize)
#下面是采用rocblas的做法
# deqweight=_ops.dequant_w4_gemm_colmajor( #shape[n,k/8]--->[n,k]
# deqweight=ops.dequant_w4_gemm_colmajor( #shape[n,k/8]--->[n,k]
# qweight,
# zeros_and_scales,
# k,
......
......@@ -259,7 +259,7 @@ class DefaultModelLoader(BaseModelLoader):
for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None)
if quant_method is not None and quant_method!="awq":
if quant_method is not None and quant_method!="awq" and quant_method!="gptq":
quant_method.process_weights_after_loading(module)
# FIXME: Remove this after Mixtral is updated
# to use quant_method.
......
......@@ -373,12 +373,6 @@ class LlamaForCausalLM(nn.Module):
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
if self.quant_method=="awq":
try:
import lmslim
except ValueError as e:
raise RuntimeError("please install lmslim first for awq\n") from e
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
......@@ -490,7 +484,6 @@ class LlamaForCausalLM(nn.Module):
weight.data=weight.data.reshape(ori_shape[1], -1)
if self.quant_method == "awq":
from lmslim import quant_ops as _ops
lay_key_words = [
"self_attn.qkv_proj.qweight",
"self_attn.o_proj.qweight",
......@@ -514,9 +507,9 @@ class LlamaForCausalLM(nn.Module):
dim_k = qweight.data.shape[0]
pad_group=2
_qw, _sz=_ops.convert_s4(qweight,qzeros,scales,int(group_size))
_qw, _sz=ops.convert_s4(qweight,qzeros,scales,int(group_size))
sz = _ops.sz_permute(_sz).reshape(-1,dim_n)
sz = ops.sz_permute(_sz).reshape(-1,dim_n)
zeros_and_scalse.data.copy_(sz)
qweight.data.copy_(_qw)
......
......@@ -251,11 +251,6 @@ class QWenLMHeadModel(nn.Module):
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
if self.quant_method=="awq":
try:
import lmslim
except ValueError as e:
raise RuntimeError("please install lmslim first for awq\n") from e
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
......@@ -351,7 +346,6 @@ class QWenLMHeadModel(nn.Module):
weight.data=weight.data.reshape(ori_shape[1],-1)
if self.quant_method == "awq":
from lmslim import quant_ops as _ops
lay_key_words = [
"attn.c_attn.qweight",
"attn.c_proj.qweight",
......@@ -375,9 +369,9 @@ class QWenLMHeadModel(nn.Module):
dim_k = qweight.data.shape[0]
pad_group=2
_qw, _sz=_ops.convert_s4(qweight,qzeros,scales,int(group_size))
_qw, _sz=ops.convert_s4(qweight,qzeros,scales,int(group_size))
sz = _ops.sz_permute(_sz).reshape(-1,dim_n)
sz = ops.sz_permute(_sz).reshape(-1,dim_n)
zeros_and_scalse.data.copy_(sz)
qweight.data.copy_(_qw)
......
......@@ -331,13 +331,7 @@ class Qwen2ForCausalLM(nn.Module):
if quant_config is not None:
self.quant_method=quant_config.get_name()
self.quant_config=quant_config
if self.quant_method=="awq":
try:
import lmslim
except ValueError as e:
raise RuntimeError("please install lmslim first for awq\n") from e
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
......@@ -439,8 +433,6 @@ class Qwen2ForCausalLM(nn.Module):
weight.data=weight.data.reshape(ori_shape[1],-1)
if self.quant_method == "awq":
from lmslim import quant_ops as _ops
lay_key_words = [
"self_attn.qkv_proj.qweight",
"self_attn.o_proj.qweight",
......@@ -464,9 +456,9 @@ class Qwen2ForCausalLM(nn.Module):
dim_k = qweight.data.shape[0]
pad_group=2
_qw, _sz=_ops.convert_s4(qweight,qzeros,scales,int(group_size))
_qw, _sz=ops.convert_s4(qweight,qzeros,scales,int(group_size))
sz = _ops.sz_permute(_sz).reshape(-1,dim_n)
sz = ops.sz_permute(_sz).reshape(-1,dim_n)
zeros_and_scalse.data.copy_(sz)
qweight.data.copy_(_qw)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment