"vscode:/vscode.git/clone" did not exist on "c29fb540ff90da720490daae58bb4bfe31a91125"
Commit cc4b902f authored by zhuwenwen's avatar zhuwenwen
Browse files

增加AWQ相关环境变量控制

parent 35a8304d
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
import torch import torch
import os
import torch.nn.functional as F import torch.nn.functional as F
from vllm import _custom_ops as ops from vllm import _custom_ops as ops
...@@ -98,6 +99,8 @@ class AWQLinearMethod(LinearMethodBase): ...@@ -98,6 +99,8 @@ class AWQLinearMethod(LinearMethodBase):
def __init__(self, quant_config: AWQConfig): def __init__(self, quant_config: AWQConfig):
self.quant_config = quant_config self.quant_config = quant_config
self.awqsingleton= AWQShareWorkSpace() self.awqsingleton= AWQShareWorkSpace()
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
self.AWQ_CK_GEMMBS =int(os.getenv('AWQ_CK_GEMMBS', '20000'))
def create_weights(self, layer: torch.nn.Module, def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int, input_size_per_partition: int,
...@@ -190,12 +193,15 @@ class AWQLinearMethod(LinearMethodBase): ...@@ -190,12 +193,15 @@ class AWQLinearMethod(LinearMethodBase):
k = reshaped_x.shape[-1] k = reshaped_x.shape[-1]
n = qweight.shape[0] n = qweight.shape[0]
if k % 4096==0: if self.use_awq_pad:
if k % 4096 == 0:
padding_group=2 padding_group=2
else: else:
padding_group=0 padding_group=0
else:
padding_group=0
if m<20000: if m <= self.AWQ_CK_GEMMBS:
out = ops.awq_gemm(reshaped_x, out = ops.awq_gemm(reshaped_x,
qweight, qweight,
zeros_and_scales, zeros_and_scales,
...@@ -208,7 +214,7 @@ class AWQLinearMethod(LinearMethodBase): ...@@ -208,7 +214,7 @@ class AWQLinearMethod(LinearMethodBase):
self.awqsingleton.awqworkshapcesize) self.awqsingleton.awqworkshapcesize)
else: else:
#下面是采用rocblas的做法 #下面是采用rocblas的做法
deqweight=ops.dequant_w4_gemm_colmajor( #shape[n,k/8]--->[n,k] deqweight=ops.dequant_w4_gemm_colmajor( # shape[n, k/8] ---> [n,k]
qweight, qweight,
zeros_and_scales, zeros_and_scales,
k, k,
......
...@@ -30,10 +30,22 @@ def get_model_architecture( ...@@ -30,10 +30,22 @@ def get_model_architecture(
os.environ['GEMM_PAD'] = '0' os.environ['GEMM_PAD'] = '0'
if os.getenv('FA_PAD') != '1': if os.getenv('FA_PAD') != '1':
os.environ['FA_PAD'] = '0' os.environ['FA_PAD'] = '0'
try:
if os.getenv('AWQ_PAD') == '0' or ((torch.cuda.isCurrentDeviceEco(torch.cuda.current_device())) and os.getenv('AWQ_PAD') == None):
os.environ['AWQ_PAD'] = '0'
else:
os.environ['AWQ_PAD'] = '1'
except Exception as e:
print("Info: this version torch cannot get eco device info.\n")
if os.getenv('AWQ_PAD') != '0':
os.environ['AWQ_PAD'] = '1'
else:
os.environ['AWQ_PAD'] = '0'
else: else:
os.environ['LLAMA_NN'] = '0' os.environ['LLAMA_NN'] = '0'
os.environ['GEMM_PAD'] = '0' os.environ['GEMM_PAD'] = '0'
os.environ['FA_PAD'] = '0' os.environ['FA_PAD'] = '0'
os.environ['AWQ_PAD'] = '0'
# Special handling for quantized Mixtral. # Special handling for quantized Mixtral.
# FIXME(woosuk): This is a temporary hack. # FIXME(woosuk): This is a temporary hack.
......
...@@ -457,6 +457,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -457,6 +457,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
def forward( def forward(
self, self,
...@@ -633,7 +634,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA): ...@@ -633,7 +634,7 @@ class LlamaForCausalLM(nn.Module, SupportsLoRA):
zeros_and_scalse.data=zeros_and_scalse.reshape(dim_n,-1) #[k/greop_size,n]------>[n,k/group_size] zeros_and_scalse.data=zeros_and_scalse.reshape(dim_n,-1) #[k/greop_size,n]------>[n,k/group_size]
qweight.data=qweight.data.reshape(dim_n,-1) #[k,n/8]---->[n,k/8] qweight.data=qweight.data.reshape(dim_n,-1) #[k,n/8]---->[n,k/8]
if dim_k % 4096==0: if dim_k % 4096==0 and self.use_awq_pad:
zeros_and_scalse_pad= torch.zeros(dim_n,pad_group,dtype=torch.int32).cuda() zeros_and_scalse_pad= torch.zeros(dim_n,pad_group,dtype=torch.int32).cuda()
zeros_and_scalse.data=torch.cat((zeros_and_scalse.data,zeros_and_scalse_pad),dim=1).contiguous() zeros_and_scalse.data=torch.cat((zeros_and_scalse.data,zeros_and_scalse_pad),dim=1).contiguous()
qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda() qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda()
......
...@@ -903,6 +903,7 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal): ...@@ -903,6 +903,7 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal):
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
def _get_image_input_type( def _get_image_input_type(
self, self,
...@@ -1085,7 +1086,7 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal): ...@@ -1085,7 +1086,7 @@ class QWenLMHeadModel(nn.Module, SupportsMultiModal):
zeros_and_scalse.data=zeros_and_scalse.reshape(dim_n,-1) #[k/greop_size,n]------>[n,k/group_size] zeros_and_scalse.data=zeros_and_scalse.reshape(dim_n,-1) #[k/greop_size,n]------>[n,k/group_size]
qweight.data=qweight.data.reshape(dim_n,-1) #[k,n/8]---->[n,k/8] qweight.data=qweight.data.reshape(dim_n,-1) #[k,n/8]---->[n,k/8]
if dim_k % 4096==0: if dim_k % 4096==0 and self.use_awq_pad:
zeros_and_scalse_pad= torch.zeros(dim_n,pad_group,dtype=torch.int32).cuda() zeros_and_scalse_pad= torch.zeros(dim_n,pad_group,dtype=torch.int32).cuda()
zeros_and_scalse.data=torch.cat((zeros_and_scalse.data,zeros_and_scalse_pad),dim=1).contiguous() zeros_and_scalse.data=torch.cat((zeros_and_scalse.data,zeros_and_scalse_pad),dim=1).contiguous()
qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda() qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda()
......
...@@ -378,6 +378,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA): ...@@ -378,6 +378,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
self.use_llama_nn = os.environ.get('LLAMA_NN') == '1' self.use_llama_nn = os.environ.get('LLAMA_NN') == '1'
self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1' self.use_gemm_pad = os.environ.get('GEMM_PAD') == '1'
self.use_fa_pad = os.environ.get('FA_PAD') == '1' self.use_fa_pad = os.environ.get('FA_PAD') == '1'
self.use_awq_pad = os.environ.get('AWQ_PAD') == '1'
def forward( def forward(
self, self,
...@@ -537,7 +538,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA): ...@@ -537,7 +538,7 @@ class Qwen2ForCausalLM(nn.Module, SupportsLoRA):
zeros_and_scalse.data=zeros_and_scalse.reshape(dim_n,-1) #[k/greop_size,n]------>[n,k/group_size] zeros_and_scalse.data=zeros_and_scalse.reshape(dim_n,-1) #[k/greop_size,n]------>[n,k/group_size]
qweight.data=qweight.data.reshape(dim_n,-1) #[k,n/8]---->[n,k/8] qweight.data=qweight.data.reshape(dim_n,-1) #[k,n/8]---->[n,k/8]
if dim_k % 4096==0: if dim_k % 4096==0 and self.use_awq_pad:
zeros_and_scalse_pad= torch.zeros(dim_n,pad_group,dtype=torch.int32).cuda() zeros_and_scalse_pad= torch.zeros(dim_n,pad_group,dtype=torch.int32).cuda()
zeros_and_scalse.data=torch.cat((zeros_and_scalse.data,zeros_and_scalse_pad),dim=1).contiguous() zeros_and_scalse.data=torch.cat((zeros_and_scalse.data,zeros_and_scalse_pad),dim=1).contiguous()
qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda() qweight_pad= torch.zeros(dim_n,int(group_size//4),dtype=torch.int32).cuda()
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment