"docs/vscode:/vscode.git/clone" did not exist on "610a3efcafed501ea90ee87e114f22ffe9348cd9"
Commit 8376cc41 authored by gaoqiong's avatar gaoqiong
Browse files

修改awq workspace 申请

parent 3b2b3046
...@@ -143,6 +143,12 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor, ...@@ -143,6 +143,12 @@ def fused_add_rms_norm(input: torch.Tensor, residual: torch.Tensor,
# quantization ops # quantization ops
# awq # awq
def GetAWQShareWorkspaceSize()->int:
return quant_ops.GetAWQShareWorkspaceSize()
def GetAWQShareWorkspace()->torch.Tensor:
return quant_ops.GetAWQShareWorkspace()
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: int, thx: int, zeros: torch.Tensor, split_k_iters: int, thx: int,
thy: int) -> torch.Tensor: thy: int) -> torch.Tensor:
......
...@@ -20,12 +20,8 @@ class AWQShareWorkSpace: ...@@ -20,12 +20,8 @@ class AWQShareWorkSpace:
return cls._instance return cls._instance
def _initialize(self): def _initialize(self):
self.awqworkshapcesize = 2 << 29 self.awqworkshapcesize = ops.GetAWQShareWorkspaceSize()
self.awqworkshapce = torch.zeros(self.awqworkshapcesize // 2 + 1, dtype=torch.float16).cuda() self.awqworkshapce = ops.GetAWQShareWorkspace()
#print("AWQShareWorkSpace _initialize\n")
#print("self.awqworkshapce.device:",self.awqworkshapce.device)
class AWQConfig(QuantizationConfig): class AWQConfig(QuantizationConfig):
"""Config class for AWQ. """Config class for AWQ.
...@@ -201,8 +197,9 @@ class AWQLinearMethod(LinearMethodBase): ...@@ -201,8 +197,9 @@ class AWQLinearMethod(LinearMethodBase):
padding_group=2 padding_group=2
else: else:
padding_group=0 padding_group=0
out = ops.awq_gemm(reshaped_x, if m<4096:
out = ops.awq_gemm(reshaped_x,
qweight, qweight,
zeros_and_scales, zeros_and_scales,
m, m,
...@@ -212,14 +209,15 @@ class AWQLinearMethod(LinearMethodBase): ...@@ -212,14 +209,15 @@ class AWQLinearMethod(LinearMethodBase):
padding_group, padding_group,
self.awqsingleton.awqworkshapce, self.awqsingleton.awqworkshapce,
self.awqsingleton.awqworkshapcesize) self.awqsingleton.awqworkshapcesize)
#下面是采用rocblas的做法 else:
# deqweight=ops.dequant_w4_gemm_colmajor( #shape[n,k/8]--->[n,k] #下面是采用rocblas的做法
# qweight, deqweight=ops.dequant_w4_gemm_colmajor( #shape[n,k/8]--->[n,k]
# zeros_and_scales, qweight,
# k, zeros_and_scales,
# n, k,
# self.quant_config.group_size) n,
# output=F.linear(reshaped_x, deqweight) self.quant_config.group_size)
output=F.linear(reshaped_x, deqweight[:,0:k])
if bias is not None: if bias is not None:
out.add_(bias) out.add_(bias)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment