Commit cc4b902f authored by zhuwenwen's avatar zhuwenwen
Browse files

增加AWQ相关环境变量控制

parent 35a8304d
from typing import Any, Dict, List, Optional
import torch
import os
import torch.nn.functional as F
from vllm import _custom_ops as ops
......@@ -98,6 +99,8 @@ class AWQLinearMethod(LinearMethodBase):
def __init__(self, quant_config: AWQConfig):
self.quant_config = quant_config
self.awqsingleton= AWQShareWorkSpace()
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.AWQ_CK_GEMMBS =int(os.getenv('AWQ_CK_GEMMBS', '20000'))
def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int,
......@@ -190,12 +193,15 @@ class AWQLinearMethod(LinearMethodBase):
k = reshaped_x.shape[-1]
n = qweight.shape[0]
if k % 4096==0:
padding_group=2
if self.use_awq_pad:
if k % 4096 == 0:
padding_group=2
else:
padding_group=0
else:
padding_group=0
if m<20000:
if m <= self.AWQ_CK_GEMMBS:
out = ops.awq_gemm(reshaped_x,
qweight,
zeros_and_scales,
......@@ -208,7 +214,7 @@ class AWQLinearMethod(LinearMethodBase):
self.awqsingleton.awqworkshapcesize)
else:
#下面是采用rocblas的做法
deqweight=ops.dequant_w4_gemm_colmajor( #shape[n,k/8]--->[n,k]
deqweight=ops.dequant_w4_gemm_colmajor( # shape[n, k/8] ---> [n,k]
qweight,
zeros_and_scales,
k,
......
......@@ -30,10 +30,22 @@ def get_model_architecture(
os.environ['GEMM_PAD'] = '0'
if os.getenv('FA_PAD') != '1':
os.environ['FA_PAD'] = '0'
try:
if os.getenv('AWQ_PAD') == '0' or ((torch.cuda.isCurrentDeviceEco(torch.cuda.current_device())) and os.getenv('AWQ_PAD') == None):
os.environ['AWQ_PAD'] = '0'
else:
os.environ['AWQ_PAD'] = '1'
except Exception as e:
print("Info: this version torch cannot get eco device info.\n")
if os.getenv('AWQ_PAD') != '0':
os.environ['AWQ_PAD'] = '1'
else:
os.environ['AWQ_PAD'] = '0'
else:
os.environ['LLAMA_NN'] = '0'
os.environ['GEMM_PAD'] = '0'
os.environ['FA_PAD'] = '0'
os.environ['AWQ_PAD'] = '0'
# Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack.
......
......@@ -457,6 +457,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
def forward(
self,
......@@ -633,7 +634,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
zeros_and_scalse.data=zeros_and_scalse.reshape(dim_n,-1) #[k/greop_size,n]------>[n,k/group_size]
qweight.data=qweight.data.reshape(dim_n,-1) #[k,n/8]---->[n,k/8]
if dim_k % 4096==0:
if dim_k % 4096==0 and self.use_awq_pad:
zeros_and_scalse_pad= torch.zeros(dim_n,pad_group,dtype=torch.int32).cuda()
zeros_and_scalse.data=torch.cat((zeros_and_scalse.data,zeros_and_scalse_pad),dim=1).contiguous()
qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda()
......
......@@ -903,6 +903,7 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal):
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
def _get_image_input_type(
self,
......@@ -1085,7 +1086,7 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal):
zeros_and_scalse.data=zeros_and_scalse.reshape(dim_n,-1) #[k/greop_size,n]------>[n,k/group_size]
qweight.data=qweight.data.reshape(dim_n,-1) #[k,n/8]---->[n,k/8]
if dim_k % 4096==0:
if dim_k % 4096==0 and self.use_awq_pad:
zeros_and_scalse_pad= torch.zeros(dim_n,pad_group,dtype=torch.int32).cuda()
zeros_and_scalse.data=torch.cat((zeros_and_scalse.data,zeros_and_scalse_pad),dim=1).contiguous()
qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda()
......
......@@ -378,6 +378,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
def forward(
self,
......@@ -537,7 +538,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
zeros_and_scalse.data=zeros_and_scalse.reshape(dim_n,-1) #[k/greop_size,n]------>[n,k/group_size]
qweight.data=qweight.data.reshape(dim_n,-1) #[k,n/8]---->[n,k/8]
if dim_k % 4096==0:
if dim_k % 4096==0 and self.use_awq_pad:
zeros_and_scalse_pad= torch.zeros(dim_n,pad_group,dtype=torch.int32).cuda()
zeros_and_scalse.data=torch.cat((zeros_and_scalse.data,zeros_and_scalse_pad),dim=1).contiguous()
qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment