Commit dcaabcf7 authored by gaoqiong's avatar gaoqiong
Browse files

增添lmslim gptq支持

parent b8c88ed3
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention高效管理kv内存,Continuous batching传入请求,支持很多Hugging Face模型,如LLaMA & LLaMA-2、Qwen、Chatglm2 & Chatglm3等。 vLLM是一个快速且易于使用的LLM推理和服务库,使用PageAttention高效管理kv内存,Continuous batching传入请求,支持很多Hugging Face模型,如LLaMA & LLaMA-2、Qwen、Chatglm2 & Chatglm3等。
## 暂不支持的官方功能 ## 暂不支持的官方功能
- **量化推理**:目前支持fp16的推理和gptq推理,awq-int4mralin的权重量化、kv-cache fp8推理方案暂不支持 - **量化推理**:目前支持fp16的推理和gptq,awq-int4推理,mralin的权重量化、kv-cache fp8推理方案暂不支持
- **模块支持**:目前不支持Sliding window attention、 moe kernel和lora模块 - **模块支持**:目前不支持Sliding window attention、 moe kernel和lora模块
...@@ -62,14 +62,11 @@ pip install -r requirements-rocm.txt ...@@ -62,14 +62,11 @@ pip install -r requirements-rocm.txt
``` ```
1. 编译whl包并安装 1. 编译whl包并安装
VLLM_INSTALL_PUNICA_KERNELS=1 python setup.py bdist_wheel VLLM_INSTALL_PUNICA_KERNELS=1 python setup.py bdist_wheel
python csrc/quantization/gptq/setup.py bdist_wheel
cd dist cd dist
pip install vllm* pip install vllm*
pip install gptq_kernel
2. 源码编译安装 2. 源码编译安装
VLLM_INSTALL_PUNICA_KERNELS=1 python3 setup.py install VLLM_INSTALL_PUNICA_KERNELS=1 python3 setup.py install
python csrc/quantization/gptq/setup.py install
``` ```
#### 运行基础环境准备 #### 运行基础环境准备
...@@ -79,7 +76,7 @@ python csrc/quantization/gptq/setup.py install ...@@ -79,7 +76,7 @@ python csrc/quantization/gptq/setup.py install
- triton:[https://cancon.hpccube.com:65024/4/main/triton](https://cancon.hpccube.com:65024/4/main/triton/) - triton:[https://cancon.hpccube.com:65024/4/main/triton](https://cancon.hpccube.com:65024/4/main/triton/)
- xformers:[https://cancon.hpccube.com:65024/4/main/xformers](https://cancon.hpccube.com:65024/4/main/xformers) - xformers:[https://cancon.hpccube.com:65024/4/main/xformers](https://cancon.hpccube.com:65024/4/main/xformers)
- flash_attn: [https://cancon.hpccube.com:65024/4/main/flash_attn](https://cancon.hpccube.com:65024/4/main/flash_attn) - flash_attn: [https://cancon.hpccube.com:65024/4/main/flash_attn](https://cancon.hpccube.com:65024/4/main/flash_attn)
- lmslim: [https://cancon.hpccube.com:65024/4/main/lmslim](https://cancon.hpccube.com:65024/4/main/lmslim)
#### 注意事项 #### 注意事项
+ 若使用 pip install 下载安装过慢,可添加源:-i https://pypi.tuna.tsinghua.edu.cn/simple/ + 若使用 pip install 下载安装过慢,可添加源:-i https://pypi.tuna.tsinghua.edu.cn/simple/
......
...@@ -18,7 +18,7 @@ from typing import Optional, Union ...@@ -18,7 +18,7 @@ from typing import Optional, Union
import subprocess import subprocess
from pathlib import Path from pathlib import Path
add_git_version = False
if int(os.environ.get('ADD_GIT_VERSION', '0')) == 1: if int(os.environ.get('ADD_GIT_VERSION', '0')) == 1:
add_git_version = True add_git_version = True
......
import contextlib import contextlib
from typing import List, Optional, Tuple, Type from typing import List, Optional, Tuple, Type
import torch import torch
try: try:
import gptq_kernels from lmslim import quant_ops
except Exception: except Exception:
print("INFO: Please install gptq_kernels if you want to infer gptq model.\n") print("INFO: Please install lmslim if you want to infer gptq or awq model.\n")
try: try:
import vllm._C import vllm._C
...@@ -150,17 +150,47 @@ def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, ...@@ -150,17 +150,47 @@ def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
thx, thy) thx, thy)
def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor, # def awq_gemm(input: torch.Tensor, qweight: torch.Tensor, qzeros: torch.Tensor,
scales: torch.Tensor, split_k_iters: int) -> torch.Tensor: # scales: torch.Tensor, split_k_iters: int) -> torch.Tensor:
return torch.ops._C.awq_gemm(input, qweight, qzeros, scales, split_k_iters) # return quant_ops.awq_gemm(input, qweight, qzeros, scales, split_k_iters)
def awq_gemm(input: torch.Tensor, weight: torch.Tensor,
zeros_and_scales:torch.Tensor,
m:int,n:int,k:int,
group_size:int,padding_group:int,splikspace:torch.Tensor,
splikspacesize:int) -> torch.Tensor:
return quant_ops.awq_gemm(input,
weight,
zeros_and_scales,
m,
n,
k,
group_size,
padding_group,
splikspace,
splikspacesize)
def convert_s4(qw: torch.Tensor, qz: torch.Tensor, s: torch.Tensor,
group_size: int):
return quant_ops.convert_s4(qw,qz,s,group_size)
def sz_permute(sz:torch.Tensor)-> torch.Tensor:
return quant_ops.sz_permute(sz)
def dequant_w4_gemm_colmajor(qweight:torch.Tensor,
zeros_and_scale:torch.Tensor,
k:int,
n:int,
group_size:int
)->torch.Tensor:
return quant_ops.dequant_w4_gemm_colmajor(qweight,zeros_and_scale,k,n,group_size)
# gptq # gptq
def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor, b_gptq_qzeros: torch.Tensor, b_gptq_scales: torch.Tensor,
b_g_idx: torch.Tensor, use_exllama: bool, b_g_idx: torch.Tensor, use_exllama: bool,
bit: int) -> torch.Tensor: bit: int) -> torch.Tensor:
return gptq_kernels.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, return quant_ops.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
b_g_idx, use_exllama, bit) b_g_idx, use_exllama, bit)
# return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales, # return torch.ops._C.gptq_gemm(a, b_q_weight, b_gptq_qzeros, b_gptq_scales,
# b_g_idx, use_exllama, bit) # b_g_idx, use_exllama, bit)
...@@ -168,7 +198,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor, ...@@ -168,7 +198,7 @@ def gptq_gemm(a: torch.Tensor, b_q_weight: torch.Tensor,
def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor, def gptq_shuffle(q_weight: torch.Tensor, q_perm: torch.Tensor,
bit: int) -> None: bit: int) -> None:
gptq_kernels.gptq_shuffle(q_weight, q_perm, bit) quant_ops.gptq_shuffle(q_weight, q_perm, bit)
# torch.ops._C.gptq_shuffle(q_weight, q_perm, bit) # torch.ops._C.gptq_shuffle(q_weight, q_perm, bit)
# trans_w16 # trans_w16
......
...@@ -8,10 +8,7 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase ...@@ -8,10 +8,7 @@ from vllm.model_executor.layers.linear import LinearBase, LinearMethodBase
from vllm.model_executor.layers.quantization.base_config import ( from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
try:
from lmslim import quant_ops as _ops
except Exception:
print("INFO: You need install lmslim if you want infer awq model.\n")
class AWQShareWorkSpace(): class AWQShareWorkSpace():
awqworkshapcesize=2<<29 # awqworkshapcesize=2<<29 #
...@@ -192,7 +189,7 @@ class AWQLinearMethod(LinearMethodBase): ...@@ -192,7 +189,7 @@ class AWQLinearMethod(LinearMethodBase):
else: else:
padding_group=0 padding_group=0
out = _ops.awq_gemm(reshaped_x, out = ops.awq_gemm(reshaped_x,
qweight, qweight,
zeros_and_scales, zeros_and_scales,
m, m,
...@@ -203,7 +200,7 @@ class AWQLinearMethod(LinearMethodBase): ...@@ -203,7 +200,7 @@ class AWQLinearMethod(LinearMethodBase):
AWQShareWorkSpace.awqworkshapce, AWQShareWorkSpace.awqworkshapce,
AWQShareWorkSpace.awqworkshapcesize) AWQShareWorkSpace.awqworkshapcesize)
#下面是采用rocblas的做法 #下面是采用rocblas的做法
# deqweight=_ops.dequant_w4_gemm_colmajor( #shape[n,k/8]--->[n,k] # deqweight=ops.dequant_w4_gemm_colmajor( #shape[n,k/8]--->[n,k]
# qweight, # qweight,
# zeros_and_scales, # zeros_and_scales,
# k, # k,
......
...@@ -259,7 +259,7 @@ class DefaultModelLoader(BaseModelLoader): ...@@ -259,7 +259,7 @@ class DefaultModelLoader(BaseModelLoader):
for _, module in model.named_modules(): for _, module in model.named_modules():
quant_method = getattr(module, "quant_method", None) quant_method = getattr(module, "quant_method", None)
if quant_method is not None and quant_method!="awq": if quant_method is not None and quant_method!="awq" and quant_method!="gptq":
quant_method.process_weights_after_loading(module) quant_method.process_weights_after_loading(module)
# FIXME: Remove this after Mixtral is updated # FIXME: Remove this after Mixtral is updated
# to use quant_method. # to use quant_method.
......
...@@ -373,12 +373,6 @@ class LlamaForCausalLM(nn.Module): ...@@ -373,12 +373,6 @@ class LlamaForCausalLM(nn.Module):
self.quant_method=quant_config.get_name() self.quant_method=quant_config.get_name()
self.quant_config=quant_config self.quant_config=quant_config
if self.quant_method=="awq":
try:
import lmslim
except ValueError as e:
raise RuntimeError("please install lmslim first for awq\n") from e
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
...@@ -490,7 +484,6 @@ class LlamaForCausalLM(nn.Module): ...@@ -490,7 +484,6 @@ class LlamaForCausalLM(nn.Module):
weight.data=weight.data.reshape(ori_shape[1], -1) weight.data=weight.data.reshape(ori_shape[1], -1)
if self.quant_method == "awq": if self.quant_method == "awq":
from lmslim import quant_ops as _ops
lay_key_words = [ lay_key_words = [
"self_attn.qkv_proj.qweight", "self_attn.qkv_proj.qweight",
"self_attn.o_proj.qweight", "self_attn.o_proj.qweight",
...@@ -514,9 +507,9 @@ class LlamaForCausalLM(nn.Module): ...@@ -514,9 +507,9 @@ class LlamaForCausalLM(nn.Module):
dim_k = qweight.data.shape[0] dim_k = qweight.data.shape[0]
pad_group=2 pad_group=2
_qw, _sz=_ops.convert_s4(qweight,qzeros,scales,int(group_size)) _qw, _sz=ops.convert_s4(qweight,qzeros,scales,int(group_size))
sz = _ops.sz_permute(_sz).reshape(-1,dim_n) sz = ops.sz_permute(_sz).reshape(-1,dim_n)
zeros_and_scalse.data.copy_(sz) zeros_and_scalse.data.copy_(sz)
qweight.data.copy_(_qw) qweight.data.copy_(_qw)
......
...@@ -251,11 +251,6 @@ class QWenLMHeadModel(nn.Module): ...@@ -251,11 +251,6 @@ class QWenLMHeadModel(nn.Module):
self.quant_method=quant_config.get_name() self.quant_method=quant_config.get_name()
self.quant_config=quant_config self.quant_config=quant_config
if self.quant_method=="awq":
try:
import lmslim
except ValueError as e:
raise RuntimeError("please install lmslim first for awq\n") from e
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
...@@ -351,7 +346,6 @@ class QWenLMHeadModel(nn.Module): ...@@ -351,7 +346,6 @@ class QWenLMHeadModel(nn.Module):
weight.data=weight.data.reshape(ori_shape[1],-1) weight.data=weight.data.reshape(ori_shape[1],-1)
if self.quant_method == "awq": if self.quant_method == "awq":
from lmslim import quant_ops as _ops
lay_key_words = [ lay_key_words = [
"attn.c_attn.qweight", "attn.c_attn.qweight",
"attn.c_proj.qweight", "attn.c_proj.qweight",
...@@ -375,9 +369,9 @@ class QWenLMHeadModel(nn.Module): ...@@ -375,9 +369,9 @@ class QWenLMHeadModel(nn.Module):
dim_k = qweight.data.shape[0] dim_k = qweight.data.shape[0]
pad_group=2 pad_group=2
_qw, _sz=_ops.convert_s4(qweight,qzeros,scales,int(group_size)) _qw, _sz=ops.convert_s4(qweight,qzeros,scales,int(group_size))
sz = _ops.sz_permute(_sz).reshape(-1,dim_n) sz = ops.sz_permute(_sz).reshape(-1,dim_n)
zeros_and_scalse.data.copy_(sz) zeros_and_scalse.data.copy_(sz)
qweight.data.copy_(_qw) qweight.data.copy_(_qw)
......
...@@ -332,12 +332,6 @@ class Qwen2ForCausalLM(nn.Module): ...@@ -332,12 +332,6 @@ class Qwen2ForCausalLM(nn.Module):
self.quant_method=quant_config.get_name() self.quant_method=quant_config.get_name()
self.quant_config=quant_config self.quant_config=quant_config
if self.quant_method=="awq":
try:
import lmslim
except ValueError as e:
raise RuntimeError("please install lmslim first for awq\n") from e
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
...@@ -439,8 +433,6 @@ class Qwen2ForCausalLM(nn.Module): ...@@ -439,8 +433,6 @@ class Qwen2ForCausalLM(nn.Module):
weight.data=weight.data.reshape(ori_shape[1],-1) weight.data=weight.data.reshape(ori_shape[1],-1)
if self.quant_method == "awq": if self.quant_method == "awq":
from lmslim import quant_ops as _ops
lay_key_words = [ lay_key_words = [
"self_attn.qkv_proj.qweight", "self_attn.qkv_proj.qweight",
"self_attn.o_proj.qweight", "self_attn.o_proj.qweight",
...@@ -464,9 +456,9 @@ class Qwen2ForCausalLM(nn.Module): ...@@ -464,9 +456,9 @@ class Qwen2ForCausalLM(nn.Module):
dim_k = qweight.data.shape[0] dim_k = qweight.data.shape[0]
pad_group=2 pad_group=2
_qw, _sz=_ops.convert_s4(qweight,qzeros,scales,int(group_size)) _qw, _sz=ops.convert_s4(qweight,qzeros,scales,int(group_size))
sz = _ops.sz_permute(_sz).reshape(-1,dim_n) sz = ops.sz_permute(_sz).reshape(-1,dim_n)
zeros_and_scalse.data.copy_(sz) zeros_and_scalse.data.copy_(sz)
qweight.data.copy_(_qw) qweight.data.copy_(_qw)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment