OpenDAS / AutoAWQ · Commit 2350a4d0 (Unverified)

Authored Dec 16, 2023 by Younes Belkada, committed by GitHub on Dec 16, 2023
Fix quantization issue with transformers >= 4.36.0 (#264)
Parent: 9c3dfa07
Showing 1 changed file with 32 additions and 3 deletions (+32 −3)
awq/quantize/quantizer.py
 import torch
+import inspect
 import logging
 import functools
 import torch.nn as nn

...
@@ -170,14 +171,16 @@ class AwqQuantizer:
         # [STEP 3]: Compute output of module
         with torch.no_grad():
-            fp16_output = module2inspect(inp, **kwargs)
+            module_kwargs = self._sanitize_kwargs(kwargs, module2inspect)
+            fp16_output = module2inspect(inp, **module_kwargs)
             if isinstance(fp16_output, tuple):
                 fp16_output = fp16_output[0]

         # [STEP 4]: Compute loss
         best_scales = self._compute_best_scale(
             inp, w_max, x_max, module2inspect,
-            layers, fp16_output, kwargs
+            layers, fp16_output, module_kwargs
         )

         return (get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), best_scales)
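The pattern in this hunk, filtering the captured kwargs through `_sanitize_kwargs` before replaying them, addresses the failure mode where kwargs recorded under one transformers version contain keys that a module's `forward` does not accept. Below is a minimal, self-contained sketch of that failure and of the signature-based filter; the `TinyBlock` module and the stale `padding_mask` key are illustrative stand-ins, not taken from AutoAWQ or transformers:

```python
import inspect

import torch
import torch.nn as nn


class TinyBlock(nn.Module):
    """Stand-in for a decoder layer whose forward signature lacks some keys."""

    def forward(self, x, attention_mask=None):
        return x


captured_kwargs = {"attention_mask": None, "padding_mask": None}
block = TinyBlock()

try:
    # Replaying stale kwargs verbatim fails on the unexpected key.
    block(torch.zeros(1), **captured_kwargs)
except TypeError as err:
    print(err)  # forward() got an unexpected keyword argument 'padding_mask'

# The commit's fix: keep only keys named in the module's forward signature.
allowed = inspect.signature(block.forward).parameters
safe_kwargs = {k: v for k, v in captured_kwargs.items() if k in allowed}
print(block(torch.zeros(1), **safe_kwargs))  # tensor([0.])
```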
...
@@ -390,10 +393,36 @@ class AwqQuantizer:
                 feat_dict=input_feat)))

         self.inps = self.inps.to(next(layer.parameters()).device)  # in case multi-gpu
         # get output as next layer's input
-        self.inps = layer(self.inps, **self.module_kwargs)[0]
+
+        # Sanitize the kwargs in case we use transformers version that contains
+        # kwargs that are not handled by the module.
+        # Useful for trust_remote_code models.
+        module_kwargs = self._sanitize_kwargs(self.module_kwargs, layer)
+
+        self.inps = layer(self.inps, **module_kwargs)[0]
         for h in handles:
             h.remove()

         # now solve for scaling and clipping
         input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}

         return input_feat
+
+    def _sanitize_kwargs(self, inputs_kwargs, module):
+        """
+        Remove the arguments that are not supported in the module's
+        forward pass to avoid breaking behaviour between different versions
+        of transformers.
+
+        Args:
+            inputs_kwargs (`dict`):
+                The input dictionary to pass to the model layer
+            module (`torch.nn.Module`):
+                Target module to quantize
+        """
+        module_signature = inspect.signature(module.forward).parameters
+        sanitized_kwargs = {}
+        for k, v in inputs_kwargs.items():
+            if k in module_signature:
+                sanitized_kwargs[k] = v
+        return sanitized_kwargs
\ No newline at end of file
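One side effect of `inspect.signature(module.forward).parameters` worth noting: the filter keeps only keys that appear as named parameters. A `forward` that absorbs extras through `**kwargs` exposes a single VAR_KEYWORD entry rather than the individual key names, so those keys would be dropped by the name check even though the call would have accepted them. A small sketch of that behavior, using a plain function and a hypothetical `position_ids` key for illustration:

```python
import inspect


def forward(x, attention_mask=None, **kwargs):
    return x


params = inspect.signature(forward).parameters
print(list(params))              # ['x', 'attention_mask', 'kwargs']
print("position_ids" in params)  # False: dropped by the name check,
                                 # although forward(**kwargs) would accept it
```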