Commit 8cd246bd authored by zhuwenwen's avatar zhuwenwen
Browse files

修改awq workspace 申请

parent 17928589
......@@ -183,6 +183,12 @@ def advance_step(num_seqs: int, num_queries: int, block_size: int,
# quantization ops
# awq
def GetAWQShareWorkspaceSize()->int:
return quant_ops.GetAWQShareWorkspaceSize()
def GetAWQShareWorkspace()->torch.Tensor:
return quant_ops.GetAWQShareWorkspace()
def awq_dequantize(qweight: torch.Tensor, scales: torch.Tensor,
zeros: torch.Tensor, split_k_iters: int, thx: int,
thy: int) -> torch.Tensor:
......
......@@ -20,8 +20,8 @@ class AWQShareWorkSpace:
return cls._instance
def _initialize(self):
self.awqworkshapcesize = 2 << 29
self.awqworkshapce = torch.zeros(self.awqworkshapcesize // 2 + 1, dtype=torch.float16).cuda()
self.awqworkshapcesize = ops.GetAWQShareWorkspaceSize()
self.awqworkshapce = ops.GetAWQShareWorkspace()
class AWQConfig(QuantizationConfig):
......@@ -200,7 +200,8 @@ class AWQLinearMethod(LinearMethodBase):
else:
padding_group=0
out = ops.awq_gemm(reshaped_x,
if m<4096:
out = ops.awq_gemm(reshaped_x,
qweight,
zeros_and_scales,
m,
......@@ -210,15 +211,17 @@ class AWQLinearMethod(LinearMethodBase):
padding_group,
self.awqsingleton.awqworkshapce,
self.awqsingleton.awqworkshapcesize)
#下面是采用rocblas的做法
# deqweight=ops.dequant_w4_gemm_colmajor( #shape[n,k/8]--->[n,k]
# qweight,
# zeros_and_scales,
# k,
# n,
# self.quant_config.group_size)
# output=F.linear(reshaped_x, deqweight)
else:
#下面是采用rocblas的做法
deqweight=ops.dequant_w4_gemm_colmajor( #shape[n,k/8]--->[n,k]
qweight,
zeros_and_scales,
k,
n,
self.quant_config.group_size)
out=F.linear(reshaped_x, deqweight[:,0:k])
if bias is not None:
out.add_(bias)
return out.reshape(out_shape)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment