Commit 8376cc41 authored by gaoqiong's avatar gaoqiong
Browse files

修改awq workspace 申请

parent 3b2b3046
...@@ -143,6 +143,12 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, ...@@ -143,6 +143,12 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
# quantization ops # quantization ops
# awq # awq
def GetAWQShareWorkspaceSize()->int:
return quant_ops.GetAWQShareWorkspaceSize()
def GetAWQShareWorkspace()->torch.Tensor:
return quant_ops.GetAWQShareWorkspace()
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: int, thx: int, zeros: torch.Tensor, split_k_iters: int, thx: int,
thy: int) -> torch.Tensor: thy: int) -> torch.Tensor:
......
...@@ -20,12 +20,8 @@ class AWQShareWorkSpace: ...@@ -20,12 +20,8 @@ class AWQShareWorkSpace:
return cls._instance return cls._instance
def _initialize(self): def _initialize(self):
self.awqworkshapcesize = 2 << 29 self.awqworkshapcesize = ops.GetAWQShareWorkspaceSize()
self.awqworkshapce = torch.zeros(self.awqworkshapcesize // 2 + 1, dtype=torch.float16).cuda() self.awqworkshapce = ops.GetAWQShareWorkspace()
#print("AWQShareWorkSpace _initialize\n")
#print("self.awqworkshapce.device:",self.awqworkshapce.device)
class AWQConfig(QuantizationConfig): class AWQConfig(QuantizationConfig):
"""Config class for AWQ. """Config class for AWQ.
...@@ -202,6 +198,7 @@ class AWQLinearMethod(LinearMethodBase): ...@@ -202,6 +198,7 @@ class AWQLinearMethod(LinearMethodBase):
else: else:
padding_group=0 padding_group=0
if m<4096:
out = ops.awq_gemm(reshaped_x, out = ops.awq_gemm(reshaped_x,
qweight, qweight,
zeros_and_scales, zeros_and_scales,
...@@ -212,14 +209,15 @@ class AWQLinearMethod(LinearMethodBase): ...@@ -212,14 +209,15 @@ class AWQLinearMethod(LinearMethodBase):
padding_group, padding_group,
self.awqsingleton.awqworkshapce, self.awqsingleton.awqworkshapce,
self.awqsingleton.awqworkshapcesize) self.awqsingleton.awqworkshapcesize)
else:
#下面是采用rocblas的做法 #下面是采用rocblas的做法
# deqweight=ops.dequant_w4_gemm_colmajor( #shape[n,k/8]--->[n,k] deqweight=ops.dequant_w4_gemm_colmajor( #shape[n,k/8]--->[n,k]
# qweight, qweight,
# zeros_and_scales, zeros_and_scales,
# k, k,
# n, n,
# self.quant_config.group_size) self.quant_config.group_size)
# output=F.linear(reshaped_x, deqweight) output=F.linear(reshaped_x, deqweight[:,0:k])
if bias is not None: if bias is not None:
out.add_(bias) out.add_(bias)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment