"vscode:/vscode.git/clone" did not exist on "e0613702ade9ace874feabb7b6f080cdfd181f4b"
Commit 02b6f735 authored by zhuwenwen's avatar zhuwenwen
Browse files

Merge branch 'lmslim_awq' into 'v0.5.0-dtk24.04.1'

fix awq ckworkspace bug

See merge request dcutoolkit/deeplearing/vllm!9
parents 3e2c63a7 2d250236
...@@ -9,10 +9,22 @@ from vllm.model_executor.layers.quantization.base_config import ( ...@@ -9,10 +9,22 @@ from vllm.model_executor.layers.quantization.base_config import (
QuantizationConfig) QuantizationConfig)
from vllm.model_executor.utils import set_weight_attrs from vllm.model_executor.utils import set_weight_attrs
class AWQShareWorkSpace:
_instance = None
def __new__(cls, *args, **kwargs):
if cls._instance is None:
cls._instance = super(AWQShareWorkSpace, cls).__new__(cls, *args, **kwargs)
# 执行初始化
cls._instance._initialize()
return cls._instance
def _initialize(self):
self.awqworkshapcesize = 2 << 29
self.awqworkshapce = torch.zeros(self.awqworkshapcesize // 2 + 1, dtype=torch.float16).cuda()
#print("AWQShareWorkSpace _initialize\n")
#print("self.awqworkshapce.device:",self.awqworkshapce.device)
class AWQShareWorkSpace():
awqworkshapcesize=2<<29 #
awqworkshapce=torch.zeros(awqworkshapcesize//2+1,dtype=torch.float16).cuda()
class AWQConfig(QuantizationConfig): class AWQConfig(QuantizationConfig):
...@@ -86,6 +98,7 @@ class AWQLinearMethod(LinearMethodBase): ...@@ -86,6 +98,7 @@ class AWQLinearMethod(LinearMethodBase):
def __init__(self, quant_config: AWQConfig): def __init__(self, quant_config: AWQConfig):
self.quant_config = quant_config self.quant_config = quant_config
self.awqsingleton= AWQShareWorkSpace()
def create_weights(self, layer: torch.nn.Module, def create_weights(self, layer: torch.nn.Module,
input_size_per_partition: int, input_size_per_partition: int,
...@@ -197,8 +210,8 @@ class AWQLinearMethod(LinearMethodBase): ...@@ -197,8 +210,8 @@ class AWQLinearMethod(LinearMethodBase):
k, k,
self.quant_config.group_size, self.quant_config.group_size,
padding_group, padding_group,
AWQShareWorkSpace.awqworkshapce, self.awqsingleton.awqworkshapce,
AWQShareWorkSpace.awqworkshapcesize) self.awqsingleton.awqworkshapcesize)
#下面是采用rocblas的做法 #下面是采用rocblas的做法
# deqweight=ops.dequant_w4_gemm_colmajor( #shape[n,k/8]--->[n,k] # deqweight=ops.dequant_w4_gemm_colmajor( #shape[n,k/8]--->[n,k]
# qweight, # qweight,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment