Commit f3a71d1d authored by Casper Hansen's avatar Casper Hansen
Browse files

Use GEMM v2 kernel for context processing

parent 2fa3a5d1
......@@ -194,7 +194,13 @@ class WQLinear_GEMV(nn.Module):
@torch.no_grad()
def forward(self, x):
    """Quantized linear forward pass for the GEMV-packed weight layout.

    Flattens the input to 2-D, dispatches to one of two CUDA kernels,
    re-adds the bias, and restores the leading batch dimensions.

    NOTE(review): rows > 8 selects the GEMM v2 kernel (used for context /
    prefill per the commit message), otherwise the GEMV kernel; the
    threshold of 8 is presumably an empirical crossover — confirm against
    kernel benchmarks.
    """
    flat_in = x.reshape(-1, x.shape[-1])
    if flat_in.shape[0] > 8:
        # Larger batches: GEMM v2 kernel, which additionally takes the
        # split-K iteration count.
        result = awq_inference_engine.gemmv2_forward_cuda(
            flat_in,
            self.qweight,
            self.scales,
            self.qzeros,
            self.group_size,
            self.split_k_iters,
        )
    else:
        # Small batches (decode): GEMV kernel.
        result = awq_inference_engine.gemv_forward_cuda(
            flat_in,
            self.qweight,
            self.scales,
            self.qzeros,
            self.group_size,
        )
    if self.bias is not None:
        result = result + self.bias
    return result.reshape(x.shape[:-1] + (self.out_features,))
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment