Unverified Commit 969b290e authored by Younes Belkada's avatar Younes Belkada Committed by GitHub
Browse files

ENH / FIX: Few enhancements and fix for mixed-precision training (#348)

parent 2de6092a
...@@ -63,9 +63,16 @@ class WQLinearMMFunction(Function): ...@@ -63,9 +63,16 @@ class WQLinearMMFunction(Function):
def backward(ctx, grad_output): def backward(ctx, grad_output):
input, qweight, qzeros, scales, bias = ctx.saved_tensors input, qweight, qzeros, scales, bias = ctx.saved_tensors
if awq_ext is None:
raise ValueError(
"auto-awq kernels is needed to be installed to use `.backward()`. Make sure to install the auto-awq kernels"
" by following the installation guides in https://github.com/casper-hansen/AutoAWQ_kernels"
)
# Cast to correct dtype for mixed precision training
weights = awq_ext.dequantize_weights_cuda( weights = awq_ext.dequantize_weights_cuda(
qweight, scales, qzeros, 1, 0, 0, False qweight, scales, qzeros, 1, 0, 0, False
) ).to(grad_output.dtype)
if ctx.needs_input_grad[0]: if ctx.needs_input_grad[0]:
# 3D matmul using torch.bmm: https://pytorch.org/docs/stable/generated/torch.bmm.html#torch.bmm # 3D matmul using torch.bmm: https://pytorch.org/docs/stable/generated/torch.bmm.html#torch.bmm
...@@ -75,7 +82,6 @@ class WQLinearMMFunction(Function): ...@@ -75,7 +82,6 @@ class WQLinearMMFunction(Function):
return grad_input, None, None, None, None, None, None, None return grad_input, None, None, None, None, None, None, None
class WQLinear_GEMM(nn.Module): class WQLinear_GEMM(nn.Module):
def __init__( def __init__(
self, w_bit, group_size, in_features, out_features, bias, dev, training=False self, w_bit, group_size, in_features, out_features, bias, dev, training=False
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment