Unverified commit 18712d00, authored by Younes Belkada and committed by GitHub

[`core`] Support fp32 / bf16 inference (#121)

parent 26e94d2e
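
The patch wraps both quantized matmul paths in a dtype round trip: the AWQ CUDA kernels only accept fp16 activations, so non-fp16 inputs are cast to half before the kernel call and the result is cast back to the caller's original dtype. The helper below is a minimal sketch of that pattern, not part of the patch itself (the patch inlines the same logic directly in each forward); fp16_kernel is a stand-in for gemm_forward_cuda / gemv_forward_cuda.

import torch

def run_in_fp16(fp16_kernel, x, *args):
    # Remember the caller's dtype (e.g. torch.float32 or torch.bfloat16).
    input_dtype = x.dtype
    # The AWQ kernels only support fp16 activations, so cast down if needed.
    if input_dtype != torch.float16:
        x = x.half()
    out = fp16_kernel(x, *args)
    # Hand the result back in the dtype the caller supplied.
    if input_dtype != torch.float16:
        out = out.to(dtype=input_dtype)
    return out
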
@@ -97,7 +97,16 @@ class WQLinear_GEMM(nn.Module):
     @torch.no_grad()
     def forward(self, x):
         out_shape = x.shape[:-1] + (self.out_features, )
+        input_dtype = x.dtype
+        if input_dtype != torch.float16:
+            x = x.half()
         out = awq_inference_engine.gemm_forward_cuda(x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8)
+        if input_dtype != torch.float16:
+            out = out.to(dtype=input_dtype)
         out = out + self.bias if self.bias is not None else out
         return out.reshape(out_shape)
@@ -196,11 +205,18 @@ class WQLinear_GEMV(nn.Module):
         out_shape = x.shape[:-1] + (self.out_features, )
         inputs = x.reshape(-1, x.shape[-1])
+        input_dtype = inputs.dtype
+        if input_dtype != torch.float16:
+            inputs = inputs.half()
         if inputs.shape[0] > 8:
             out = awq_inference_engine.gemmv2_forward_cuda(inputs, self.qweight, self.scales, self.qzeros, self.group_size, self.split_k_iters)
         else:
             out = awq_inference_engine.gemv_forward_cuda(inputs, self.qweight, self.scales, self.qzeros, self.group_size)
+        if input_dtype != torch.float16:
+            out = out.to(dtype=input_dtype)
         out = out + self.bias if self.bias is not None else out
         return out.reshape(out_shape)
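
As a quick illustration of what the change means for callers, the stand-in module below mimics the patched forward: a kernel that insists on fp16 inputs, wrapped in the same cast / cast-back logic. Inputs in fp32 or bf16 now round-trip cleanly, with the caveat (visible in the diff) that the actual computation still happens in fp16. fp16_only_kernel and patched_forward are hypothetical names used only for this demonstration, not library code.

import torch

def fp16_only_kernel(x):
    # Stand-in for the AWQ CUDA kernels: rejects anything that is not fp16.
    assert x.dtype == torch.float16, "kernel only supports fp16"
    return x * 2.0

def patched_forward(x):
    # Same structure as the patched WQLinear forwards above.
    input_dtype = x.dtype
    if input_dtype != torch.float16:
        x = x.half()
    out = fp16_only_kernel(x)
    if input_dtype != torch.float16:
        out = out.to(dtype=input_dtype)
    return out

for dtype in (torch.float16, torch.bfloat16, torch.float32):
    x = torch.randn(4, 8, dtype=dtype)
    out = patched_forward(x)
    assert out.dtype == dtype  # output comes back in the caller's dtype
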