OpenDAS / AutoAWQ

Commit c9d01ac3, authored Sep 20, 2023 by Casper Hansen
Parent: e205548d

    Better comments. Implement cosine similarity

Showing 2 changed files, with 58 additions and 44 deletions:

- awq/models/base.py (+2, -2)
- awq/quantize/quantizer.py (+56, -42)
awq/models/base.py

```diff
@@ -38,13 +38,13 @@ class BaseAWQForCausalLM(nn.Module):
     @torch.no_grad()
     def quantize(self, tokenizer=None, quant_config={},
                        calib_data: Union[str, List[str]]="pileval",
-                       split="train", text_column="text"):
+                       split="train", text_column="text", loss_objective='mse'):
         self.quant_config = quant_config
         quant_config["version"] = "GEMM" if 'version' not in quant_config.keys() else quant_config["version"]
 
         quantizer = AwqQuantizer(
             self, self.model, tokenizer, quant_config["w_bit"], quant_config["q_group_size"],
-            quant_config["version"], calib_data, split, text_column
+            quant_config["version"], calib_data, split, text_column, loss_objective
         )
         quantizer.quantize()
 
         self.is_quantized = True
```
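The new `loss_objective` argument is threaded from `BaseAWQForCausalLM.quantize` into `AwqQuantizer`. A minimal usage sketch of the option follows; the model path and the `quant_config` values are illustrative assumptions, not part of this commit:

```python
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

# Hypothetical model path, chosen only for illustration.
model_path = "facebook/opt-125m"

model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# New in this commit: loss_objective selects how candidate scales are
# ranked during the grid search ('mse' keeps the old behavior,
# 'cosine' is the newly added alternative).
model.quantize(
    tokenizer,
    quant_config={"w_bit": 4, "q_group_size": 128, "version": "GEMM"},
    loss_objective="cosine",
)
```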
awq/quantize/quantizer.py

```diff
@@ -12,7 +12,8 @@ from awq.utils.module import append_str_prefix, get_op_name, get_named_linears,
 
 class AwqQuantizer:
-    def __init__(self, awq_model, model, tokenizer, w_bit, group_size, version, calib_data, split, text_column) -> None:
+    def __init__(self, awq_model, model, tokenizer, w_bit, group_size, version,
+                 calib_data, split, text_column, loss_objective='mse') -> None:
         self.awq_model = awq_model
         self.model = model
         self.tokenizer = tokenizer
```
```diff
@@ -23,6 +24,7 @@ class AwqQuantizer:
         self.split = split
         self.text_column = text_column
         self.modules, self.module_kwargs, self.inps = self.init_quant()
+        self.loss_objective = loss_objective
 
     def pseudo_quantize_tensor(self, w: torch.Tensor, get_scale_zp=False):
         org_w_shape = w.shape
```
```diff
@@ -74,37 +76,39 @@ class AwqQuantizer:
             clip_list = append_str_prefix(clip_list, get_op_name(self.model, self.modules[i]) + ".")
 
             # [STEP 4]: Quantize weights
-            for name, linear_layer in named_linears.items():
-                # NOTE: small regression in perplexity if linear layer uses .cpu().float()
-                linear_layer = linear_layer.cuda().half()
-
-                linear_layer.weight.data, scales, zeros = self.pseudo_quantize_tensor(
-                    linear_layer.weight.data,
-                    get_scale_zp=True
-                )
-
-                if self.version == 'GEMM':
-                    scales = scales.t().contiguous()
-                    zeros = zeros.t().contiguous()
-                    q_linear_module = WQLinear_GEMM
-                elif self.version == 'GEMV':
-                    q_linear_module = WQLinear_GEMV
-
-                q_linear = q_linear_module.from_linear(
-                    linear=linear_layer,
-                    w_bit=self.w_bit,
-                    group_size=self.group_size,
-                    init_only=False,
-                    scales=scales,
-                    zeros=zeros
-                )
-
-                linear_layer.cpu()
-                q_linear.to(next(self.modules[i].parameters()).device)
-                set_op_by_name(self.modules[i], name, q_linear)
-                clear_memory()
+            self._apply_quant(self.modules[i], named_linears)
+            clear_memory()
+
+    def _apply_quant(self, module, named_linears: dict[str, nn.Linear]):
+        for name, linear_layer in named_linears.items():
+            # NOTE: small regression in perplexity if linear layer uses .cpu().float()
+            linear_layer = linear_layer.cuda().half()
+
+            linear_layer.weight.data, scales, zeros = self.pseudo_quantize_tensor(
+                linear_layer.weight.data,
+                get_scale_zp=True
+            )
+
+            if self.version == 'GEMM':
+                scales = scales.t().contiguous()
+                zeros = zeros.t().contiguous()
+                q_linear_module = WQLinear_GEMM
+            elif self.version == 'GEMV':
+                q_linear_module = WQLinear_GEMV
+
+            q_linear = q_linear_module.from_linear(
+                linear=linear_layer,
+                w_bit=self.w_bit,
+                group_size=self.group_size,
+                init_only=False,
+                scales=scales,
+                zeros=zeros
+            )
+
+            linear_layer.cpu()
+            q_linear.to(next(module.parameters()).device)
+            set_op_by_name(module, name, q_linear)
+            clear_memory()
 
     @torch.no_grad()
```
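The per-layer quantization loop is extracted into `_apply_quant(module, named_linears)`, keyed by a `dict[str, nn.Linear]` (a built-in generic annotation, so this requires Python 3.9+). For readers unfamiliar with the collect-then-replace pattern it relies on, here is a self-contained sketch under stated assumptions: `get_named_linears` and `set_op_by_name` below are simplified stand-ins for the helpers in `awq.utils.module`, and `FakeQuantLinear` is a hypothetical placeholder for `WQLinear_GEMM`/`WQLinear_GEMV`:

```python
import torch.nn as nn

def get_named_linears(module: nn.Module) -> dict:
    # Collect nn.Linear submodules by qualified name (simplified
    # stand-in for awq.utils.module.get_named_linears).
    return {name: m for name, m in module.named_modules() if isinstance(m, nn.Linear)}

def set_op_by_name(root: nn.Module, name: str, new_module: nn.Module) -> None:
    # Walk the dotted path and swap in the replacement module
    # (simplified stand-in for awq.utils.module.set_op_by_name).
    *parents, leaf = name.split(".")
    parent = root
    for p in parents:
        parent = getattr(parent, p)
    setattr(parent, leaf, new_module)

class FakeQuantLinear(nn.Linear):
    """Hypothetical placeholder for a real quantized linear layer."""

block = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 16))
for name, linear in get_named_linears(block).items():
    q_linear = FakeQuantLinear(linear.in_features, linear.out_features)
    set_op_by_name(block, name, q_linear)
```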
```diff
@@ -133,19 +137,20 @@ class AwqQuantizer:
             # [STEP 3]: Compute output of module
             with torch.no_grad():
-                org_out = module2inspect(inp, **kwargs)
-                if isinstance(org_out, tuple):
-                    org_out = org_out[0]
+                fp16_output = module2inspect(inp, **kwargs)
+                if isinstance(fp16_output, tuple):
+                    fp16_output = fp16_output[0]
 
             # [STEP 4]: Compute loss
             best_scales = self._compute_best_scale(
                 inp, w_max, x_max, module2inspect,
-                layers, org_out, kwargs
+                layers, fp16_output, kwargs
             )
 
         return (get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), best_scales)
 
-    def _compute_best_scale(self, x, w_max, x_max, module2inspect, linears2scale: list[nn.Linear], org_out, kwargs={}):
+    def _compute_best_scale(self, x, w_max, x_max, module2inspect,
+                            linears2scale: list[nn.Linear], fp16_output, kwargs={}):
         """
         Compute loss and select best scales
```
```diff
@@ -170,20 +175,29 @@ class AwqQuantizer:
         for ratio in range(n_grid):
             # create new scales
             ratio = ratio / n_grid
+
+            # s^-1
             scales = (x_max.pow(ratio) / w_max.pow(1 - ratio)).clamp(min=1e-4)
             scales = scales / (scales.max() * scales.min()).sqrt()
             scales_view = scales.view(1, -1).to(device)
 
+            # NOTE: s^-1 * x is fused here, according to paper
             for fc in linears2scale:
+                # Q(W * s)
                 fc.weight.mul_(scales_view)
                 fc.weight.data = self.pseudo_quantize_tensor(fc.weight.data) / scales_view
 
-            out = module2inspect(x, **kwargs)
-            if isinstance(out, tuple):
-                out = out[0]
+            # W * X
+            int_w_output = module2inspect(x, **kwargs)
+            if isinstance(int_w_output, tuple):
+                int_w_output = int_w_output[0]
 
-            # measure loss and check if better than best
-            loss = (org_out - out).float().pow(2).mean().item() # NOTE: float prevents overflow
+            if self.loss_objective == 'mse':
+                # (L2 norm)
+                loss = (fp16_output - int_w_output).float().pow(2).mean().item() # NOTE: float prevents overflow
+            elif self.loss_objective == 'cosine':
+                loss = -nn.functional.cosine_similarity(fp16_output, int_w_output).mean().item()
 
             history.append(loss)
             if loss < best_error:
                 best_error = loss
```
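For intuition on the two objectives: 'mse' ranks candidate scales by absolute elementwise error between the fp16 and pseudo-quantized outputs, while 'cosine' ranks them by directional agreement, negated so that lower still means better under the `loss < best_error` comparison. A standalone sketch on dummy activations (shapes are illustrative; note that `nn.functional.cosine_similarity` reduces over `dim=1` by default):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
fp16_output = torch.randn(8, 512)                          # e.g. [tokens, hidden]
int_w_output = fp16_output + 0.01 * torch.randn(8, 512)    # simulated quantization noise

# 'mse' objective: mean squared error, computed in float32
# (mirrors the commit's "float prevents overflow" note for fp16 inputs)
mse_loss = (fp16_output - int_w_output).float().pow(2).mean().item()

# 'cosine' objective: negative mean cosine similarity over dim=1,
# so a near-perfect match yields a loss near -1.0
cos_loss = -nn.functional.cosine_similarity(fp16_output, int_w_output).mean().item()

print(f"mse={mse_loss:.6f} cosine={cos_loss:.6f}")
```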