Unverified Commit 18712d00 authored by Younes Belkada, committed by GitHub
Browse files

[`core`] Support fp32 / bf16 inference (#121)

parent 26e94d2e
...@@ -97,7 +97,16 @@ class WQLinear_GEMM(nn.Module): ...@@ -97,7 +97,16 @@ class WQLinear_GEMM(nn.Module):
@torch.no_grad()
def forward(self, x):
    """Run the quantized GEMM kernel on ``x`` and return the result in ``x``'s dtype.

    The input is cast to float16 before the kernel call; when the caller
    passed a different dtype (e.g. float32 / bfloat16) the kernel output is
    cast back to that dtype before the optional bias is added.

    Args:
        x: Input tensor; every leading dimension is flattened into one
            before the kernel call and restored on return.

    Returns:
        Tensor of shape ``x.shape[:-1] + (self.out_features,)``.
    """
    out_shape = x.shape[:-1] + (self.out_features, )

    # Remember the caller's dtype so the result can be handed back in it.
    input_dtype = x.dtype
    if input_dtype != torch.float16:
        x = x.half()

    # NOTE(review): the trailing 8 is passed through unchanged from the
    # original call; presumably the kernel's split-k iteration count — confirm.
    out = awq_inference_engine.gemm_forward_cuda(
        x.reshape(-1, x.shape[-1]), self.qweight, self.scales, self.qzeros, 8
    )

    if input_dtype != torch.float16:
        out = out.to(dtype=input_dtype)

    if self.bias is not None:
        # Cast the bias to the output dtype before adding: mixing dtypes
        # (e.g. a bfloat16 output with a float16 bias) would otherwise be
        # promoted to float32 and the result would no longer match the
        # caller's dtype. The cast is a no-op when the dtypes already agree.
        out = out + self.bias.to(dtype=out.dtype)

    return out.reshape(out_shape)
...@@ -195,11 +204,18 @@ class WQLinear_GEMV(nn.Module): ...@@ -195,11 +204,18 @@ class WQLinear_GEMV(nn.Module):
def forward(self, x):
    """Run the quantized GEMV/GEMM kernel on ``x`` and return the result in ``x``'s dtype.

    The input is flattened to 2-D and cast to float16 before the kernel
    call. More than 8 flattened rows go through ``gemmv2_forward_cuda``;
    otherwise ``gemv_forward_cuda`` is used. When the caller passed a dtype
    other than float16, the kernel output is cast back to that dtype before
    the optional bias is added.

    Args:
        x: Input tensor; every leading dimension is flattened into one
            before the kernel call and restored on return.

    Returns:
        Tensor of shape ``x.shape[:-1] + (self.out_features,)``.
    """
    out_shape = x.shape[:-1] + (self.out_features, )
    inputs = x.reshape(-1, x.shape[-1])

    # Remember the caller's dtype so the result can be handed back in it.
    input_dtype = inputs.dtype
    if input_dtype != torch.float16:
        inputs = inputs.half()

    # Kernel choice depends only on the number of flattened rows.
    if inputs.shape[0] > 8:
        out = awq_inference_engine.gemmv2_forward_cuda(
            inputs, self.qweight, self.scales, self.qzeros,
            self.group_size, self.split_k_iters,
        )
    else:
        out = awq_inference_engine.gemv_forward_cuda(
            inputs, self.qweight, self.scales, self.qzeros, self.group_size
        )

    if input_dtype != torch.float16:
        out = out.to(dtype=input_dtype)

    if self.bias is not None:
        # Cast the bias to the output dtype before adding: mixing dtypes
        # (e.g. a bfloat16 output with a float16 bias) would otherwise be
        # promoted to float32 and the result would no longer match the
        # caller's dtype. The cast is a no-op when the dtypes already agree.
        out = out + self.bias.to(dtype=out.dtype)

    return out.reshape(out_shape)
...@@ -207,4 +223,4 @@ class WQLinear_GEMV(nn.Module): ...@@ -207,4 +223,4 @@ class WQLinear_GEMV(nn.Module):
def extra_repr(self) -> str:
    """Return the layer summary: feature sizes, bias presence, bit width and group size."""
    has_bias = self.bias is not None
    return (
        f'in_features={self.in_features}, out_features={self.out_features}, '
        f'bias={has_bias}, w_bit={self.w_bit}, group_size={self.group_size}'
    )
\ No newline at end of file
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment