Commit 8cd246bd authored by zhuwenwen
Browse files

修改awq workspace 申请

parent 17928589
...@@ -183,6 +183,12 @@ def advance_step(num_seqs: int, num_queries: int, block_size: int, ...@@ -183,6 +183,12 @@ def advance_step(num_seqs: int, num_queries: int, block_size: int,
# quantization ops # quantization ops
# awq # awq
def GetAWQShareWorkspaceSize() -> int:
    """Return the shared AWQ workspace size reported by ``quant_ops``.

    Thin pass-through to ``quant_ops.GetAWQShareWorkspaceSize()``.
    NOTE(review): the unit of the returned integer (bytes vs. elements)
    is defined by the backend op and is not visible here -- confirm.
    """
    workspace_size = quant_ops.GetAWQShareWorkspaceSize()
    return workspace_size
def GetAWQShareWorkspace() -> torch.Tensor:
    """Return the shared AWQ workspace tensor provided by ``quant_ops``.

    Thin pass-through to ``quant_ops.GetAWQShareWorkspace()``.
    NOTE(review): dtype, device and lifetime of the returned tensor are
    determined by the backend op and are not visible here -- confirm.
    """
    workspace = quant_ops.GetAWQShareWorkspace()
    return workspace
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor, def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: int, thx: int, zeros: torch.Tensor, split_k_iters: int, thx: int,
thy: int) -> torch.Tensor: thy: int) -> torch.Tensor:
......
...@@ -20,8 +20,8 @@ class AWQShareWorkSpace: ...@@ -20,8 +20,8 @@ class AWQShareWorkSpace:
return cls._instance return cls._instance
def _initialize(self): def _initialize(self):
self.awqworkshapcesize = 2 << 29 self.awqworkshapcesize = ops.GetAWQShareWorkspaceSize()
self.awqworkshapce = torch.zeros(self.awqworkshapcesize // 2 + 1, dtype=torch.float16).cuda() self.awqworkshapce = ops.GetAWQShareWorkspace()
class AWQConfig(QuantizationConfig): class AWQConfig(QuantizationConfig):
...@@ -200,7 +200,8 @@ class AWQLinearMethod(LinearMethodBase): ...@@ -200,7 +200,8 @@ class AWQLinearMethod(LinearMethodBase):
else: else:
padding_group=0 padding_group=0
out = ops.awq_gemm(reshaped_x, if m<4096:
out = ops.awq_gemm(reshaped_x,
qweight, qweight,
zeros_and_scales, zeros_and_scales,
m, m,
...@@ -210,15 +211,17 @@ class AWQLinearMethod(LinearMethodBase): ...@@ -210,15 +211,17 @@ class AWQLinearMethod(LinearMethodBase):
padding_group, padding_group,
self.awqsingleton.awqworkshapce, self.awqsingleton.awqworkshapce,
self.awqsingleton.awqworkshapcesize) self.awqsingleton.awqworkshapcesize)
#下面是采用rocblas的做法 else:
# deqweight=ops.dequant_w4_gemm_colmajor( #shape[n,k/8]--->[n,k] #下面是采用rocblas的做法
# qweight, deqweight=ops.dequant_w4_gemm_colmajor( #shape[n,k/8]--->[n,k]
# zeros_and_scales, qweight,
# k, zeros_and_scales,
# n, k,
# self.quant_config.group_size) n,
# output=F.linear(reshaped_x, deqweight) self.quant_config.group_size)
out=F.linear(reshaped_x, deqweight[:,0:k])
if bias is not None: if bias is not None:
out.add_(bias) out.add_(bias)
return out.reshape(out_shape) return out.reshape(out_shape)
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment