Unverified Commit 149236e4 authored by Casper, committed by GitHub

Up to 60% faster context processing (#316)

parent c6c7b065
@@ -153,6 +153,20 @@ class WQLinear_GEMM(nn.Module):
            x = x.half()

        if AWQ_INSTALLED:
            FP16_MATMUL_HEURISTIC_CONDITION = x.shape[0] * x.shape[1] >= 1024

            if FP16_MATMUL_HEURISTIC_CONDITION:
                # Many tokens (context processing): dequantize the packed
                # weights to FP16 once and use a plain FP16 matmul.
                out = awq_ext.dequantize_weights_cuda(
                    self.qweight,
                    self.scales,
                    self.qzeros,
                    0,
                    0,
                    0,
                    False
                )
                out = torch.matmul(x, out)
            else:
                # Few tokens (decoding): keep the fused INT4 GEMM kernel.
                out = awq_ext.gemm_forward_cuda(
                    x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
                )