Unverified commit 149236e4, authored by Casper, committed by GitHub
Browse files

Up to 60% faster context processing (#316)

parent c6c7b065
@@ -153,9 +153,23 @@ class WQLinear_GEMM(nn.Module):
 x = x.half()
 if AWQ_INSTALLED:
-    out = awq_ext.gemm_forward_cuda(
-        x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
-    )
+    FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0]*x.shape[1] >= 1024
+
+    if FP16_MATMUL_HEURISTIC_CONDITION:
+        out = awq_ext.dequantize_weights_cuda(
+            self.qweight,
+            self.scales,
+            self.qzeros,
+            0,
+            0,
+            0,
+            False
+        )
+        out = torch.matmul(x, out)
+    else:
+        out = awq_ext.gemm_forward_cuda(
+            x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
+        )
 else:
     out = dequantize_gemm(
         self.qweight,
...
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment