OpenDAS / AutoAWQ · Commits
Commit 969b290e (unverified)
Authored Feb 16, 2024 by Younes Belkada; committed by GitHub on Feb 16, 2024.

ENH / FIX: Few enhancements and fix for mixed-precision training (#348)
Parent: 2de6092a

Showing 1 changed file with 8 additions and 2 deletions:

awq/modules/linear/gemm.py (+8, -2)
--- a/awq/modules/linear/gemm.py
+++ b/awq/modules/linear/gemm.py
@@ -63,9 +63,16 @@ class WQLinearMMFunction(Function):
     def backward(ctx, grad_output):
         input, qweight, qzeros, scales, bias = ctx.saved_tensors
 
+        if awq_ext is None:
+            raise ValueError(
+                "auto-awq kernels is needed to be installed to use `.backward()`. Make sure to install the auto-awq kernels"
+                " by following the installation guides in https://github.com/casper-hansen/AutoAWQ_kernels"
+            )
+
+        # Cast to correct dtype for mixed precision training
         weights = awq_ext.dequantize_weights_cuda(
             qweight, scales, qzeros, 1, 0, 0, False
-        )
+        ).to(grad_output.dtype)
 
         if ctx.needs_input_grad[0]:
             # 3D matmul using torch.bmm: https://pytorch.org/docs/stable/generated/torch.bmm.html#torch.bmm
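Why the .to(grad_output.dtype) cast in the hunk above matters: under mixed-precision training the upstream gradient can arrive as fp32 while dequantize_weights_cuda returns fp16 weights, and torch.bmm refuses operands of different dtypes. A minimal sketch of that failure mode, with assumed shapes and dtypes that are illustrative rather than taken from this file:

import torch

grad_output = torch.randn(1, 4, 16)                # fp32 upstream gradient (assumed shape)
weights = torch.randn(8, 16, dtype=torch.float16)  # stand-in for the fp16 dequantized weights

try:
    # Pre-fix behaviour: mismatched operand dtypes make torch.bmm raise.
    torch.bmm(grad_output, weights.t().unsqueeze(0))
except RuntimeError as err:
    print(err)  # e.g. "expected scalar type Half but found Float"

# Post-fix behaviour: cast the weights to grad_output's dtype before the matmul,
# which is what the hunk above does.
grad_input = torch.bmm(grad_output, weights.to(grad_output.dtype).t().unsqueeze(0))
print(grad_input.shape)  # torch.Size([1, 4, 8])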
@@ -75,7 +82,6 @@ class WQLinearMMFunction(Function):
         return grad_input, None, None, None, None, None, None, None
 
 
-
 class WQLinear_GEMM(nn.Module):
     def __init__(
         self, w_bit, group_size, in_features, out_features, bias, dev, training=False
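For completeness, a hedged sketch of constructing the layer with the training flag visible in the signature above; the argument values are illustrative, and real modules are normally built via AutoAWQ's own helpers (e.g. WQLinear_GEMM.from_linear) rather than by hand:

from awq.modules.linear.gemm import WQLinear_GEMM

layer = WQLinear_GEMM(
    w_bit=4,            # 4-bit quantized weights (illustrative)
    group_size=128,     # quantization group size (illustrative)
    in_features=4096,   # assumed dimensions
    out_features=4096,
    bias=False,
    dev="cuda:0",
    training=True,      # presumably selects the autograd path whose backward() is patched above
)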