OpenDAS / AutoAWQ

Commit f713b888 (unverified)
Authored Mar 02, 2024 by Oscar Savolainen; committed by GitHub on Mar 02, 2024

x_max -> x_mean and w_max -> w_mean name changes and some comments (#378)

Parent: d9dc8e56
Showing 1 changed file with 17 additions and 11 deletions.

awq/quantize/quantizer.py (+17, -11)
@@ -244,17 +244,23 @@ class AwqQuantizer:
         # Put x on the right device
         inp = inp.to(next(module2inspect.parameters()).device)

-        # [STEP 1]: Compute maximum of weight
+        # [STEP 1]: Compute per-channel mean of normalised weights
+        # All layer weights are concatted together
         weight = torch.cat([_m.weight for _m in layers], dim=0)
         org_shape = weight.shape
+        # The weights are reshaped to be organised by quantization group
         weight = weight.view(-1, self.group_size)
+        # Calculates the relative magnitude of the weights within each of the quantization groups,
+        # and rescales each group individually so that each group has weights on a 0-1 scale.
         w_scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
+        # Resizes the rescaled weight matrix back up to its original dimensions
         w_scale = w_scale.view(org_shape)
-        w_max = w_scale.mean(0)
+        # Gets the average rescaled magnitude for each output channel
+        w_mean = w_scale.mean(0)
         clear_memory(weight)

-        # [STEP 2]: Compute maximum of x
-        x_max = inp.abs().view(-1, inp.shape[-1]).mean(0)
+        # [STEP 2]: Compute per-channel mean of the input activation
+        x_mean = inp.abs().view(-1, inp.shape[-1]).mean(0)

         # [STEP 3]: Compute output of module
         with torch.no_grad():
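For readers unfamiliar with the code, the two statistics above can be reproduced in isolation. The following is a minimal sketch, not part of the commit: the shapes, group size, and random tensors are assumptions chosen purely for illustration.

import torch

# Assumed, illustrative shapes: a 4096x4096 concatenated weight matrix and
# 64 flattened activation rows; group_size is a typical AWQ setting.
group_size = 128
weight = torch.randn(4096, 4096)  # cat of [out_features, in_features] weights
inp = torch.randn(64, 4096)       # activations, [tokens, in_features]

org_shape = weight.shape
# Organise the weights by quantization group and rescale each group so its
# entries lie on a 0-1 scale (relative magnitude within the group).
grouped = weight.view(-1, group_size)
w_scale = grouped.abs() / grouped.abs().amax(dim=1, keepdim=True)
w_scale = w_scale.view(org_shape)
# Mean rescaled magnitude per channel (one value per column of the matrix).
w_mean = w_scale.mean(0)

# Per-channel mean of the absolute input activation.
x_mean = inp.abs().view(-1, inp.shape[-1]).mean(0)

print(w_mean.shape, x_mean.shape)  # both torch.Size([4096])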
@@ -266,7 +272,7 @@ class AwqQuantizer:

         # [STEP 4]: Compute loss
         best_scales = self._compute_best_scale(
-            inp, w_max, x_max, module2inspect, layers, fp16_output, module_kwargs
+            inp, w_mean, x_mean, module2inspect, layers, fp16_output, module_kwargs
         )

         return (
@@ -278,8 +284,8 @@ class AwqQuantizer:
     def _compute_best_scale(
         self,
         x,
-        w_max,
-        x_max,
+        w_mean,
+        x_mean,
         module2inspect,
         linears2scale: List[nn.Linear],
         fp16_output,
@@ -303,8 +309,8 @@ class AwqQuantizer:
         org_sd = {k: v.cpu() for k, v in module2inspect.state_dict().items()}

         device = x.device
-        x_max = x_max.view(-1).to(device)
-        w_max = w_max.view(-1).to(device)
+        x_mean = x_mean.view(-1).to(device)
+        w_mean = w_mean.view(-1).to(device)

         for ratio in range(n_grid):
             # create new scales
@@ -312,9 +318,9 @@ class AwqQuantizer:
             # NOTE: s^-1 * x is fused here, according to paper
             if self.duo_scaling:
-                scales = (x_max.pow(ratio) / w_max.pow(1 - ratio)).clamp(min=1e-4)
+                scales = (x_mean.pow(ratio) / w_mean.pow(1 - ratio)).clamp(min=1e-4)
             else:
-                scales = x_max.pow(ratio).clamp(min=1e-4).view(-1)
+                scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1)
             scales = scales / (scales.max() * scales.min()).sqrt()
             scales_view = scales.view(1, -1).to(device)
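Downstream, the two means drive the grid search shown above. The sketch below is a self-contained restatement of the visible scale computation, not the commit's code: candidate_scales and n_grid are placeholder names, and the mapping of the loop index to a ratio in [0, 1) is an assumption, since that line is collapsed out of the excerpt.

import torch

def candidate_scales(x_mean, w_mean, n_grid=20, duo_scaling=True):
    """Yield (ratio, scales) pairs mirroring the loop in the diff above."""
    for step in range(n_grid):
        ratio = step / n_grid  # assumed normalisation; not visible in the excerpt
        if duo_scaling:
            # Balance activation vs. weight statistics: s = x_mean^r / w_mean^(1-r)
            scales = (x_mean.pow(ratio) / w_mean.pow(1 - ratio)).clamp(min=1e-4)
        else:
            scales = x_mean.pow(ratio).clamp(min=1e-4).view(-1)
        # Normalise so that scales.max() * scales.min() == 1 (geometric centring).
        scales = scales / (scales.max() * scales.min()).sqrt()
        yield ratio, scales

# Illustrative usage with dummy statistics:
x_mean = torch.rand(4096).clamp(min=1e-4)
w_mean = torch.rand(4096).clamp(min=1e-4)
for ratio, s in candidate_scales(x_mean, w_mean):
    pass  # the real method applies s, quantizes the layers, and measures output error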