OpenDAS / AutoAWQ

Commit 45c22ee5
Authored Sep 20, 2023 by Casper
Initial quantization refactoring
Parent: a5e8b048

Showing 3 changed files with 434 additions and 40 deletions.
awq/quantize/apply_quantized.py    +112  -0
awq/quantize/quantizer.py          +315  -40
awq/utils/utils.py                 +7    -0
awq/quantize/apply_quantized.py (new file, mode 100644; +112 -0)
import torch
import torch.nn as nn
from typing import Tuple
from awq.modules.act import ScaledActivation
from transformers.activations import NewGELUActivation
from transformers.models.bloom.modeling_bloom import BloomGelu
from transformers.models.llama.modeling_llama import LlamaRMSNorm
from awq.utils.module import get_op_by_name, get_op_name, set_op_by_name

allowed_norms = [nn.LayerNorm, LlamaRMSNorm]
allowed_act_fns = [nn.GELU, BloomGelu, NewGELUActivation]


@torch.no_grad()
def apply_clip(module, clip_list: Tuple[str, torch.Tensor]):
    for name, max_val in clip_list:
        layer: nn.Linear = get_op_by_name(module, name)
        layer.cuda()
        max_val = max_val.to(layer.weight.device)
        org_shape = layer.weight.shape
        layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
        layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val)
        layer.weight.data = layer.weight.data.reshape(org_shape)
        layer.cpu()


def apply_scale(module, scales_list, input_feat_dict=None):
    for prev_op_name, layer_names, scales in scales_list:
        prev_op = get_op_by_name(module, prev_op_name)
        layers = [get_op_by_name(module, name) for name in layer_names]

        prev_op.cuda()
        for layer in layers:
            layer.cuda()
        scales.cuda()

        if isinstance(prev_op, nn.Linear):
            assert len(layers) == 1
            scale_fc_fc(prev_op, layers[0], scales)

        elif any(isinstance(prev_op, t) for t in allowed_norms) \
                or 'rmsnorm' in str(prev_op.__class__).lower():
            scale_ln_fcs(prev_op, layers, scales)

        elif any(isinstance(prev_op, t) for t in allowed_act_fns):
            new_module = ScaledActivation(prev_op, scales)
            set_op_by_name(module, prev_op_name, new_module)
            scale_gelu_fc(prev_op, layers[0], scales)

        else:
            raise NotImplementedError(f"prev_op {type(prev_op)} not supported yet!")

        # apply the scaling to input feat if given; prepare it for clipping
        if input_feat_dict is not None:
            for layer_name in layer_names:
                inp = input_feat_dict[layer_name]
                inp.div_(scales.view(1, -1).to(inp.device))

        prev_op.cpu()
        for layer in layers:
            layer.cpu()
        scales.cpu()


@torch.no_grad()
def scale_ln_fcs(ln, fcs, scales):
    if not isinstance(fcs, list):
        fcs = [fcs]

    scales = scales.to(ln.weight.device)

    ln.weight.div_(scales)
    if hasattr(ln, 'bias') and ln.bias is not None:
        ln.bias.div_(scales)

    for fc in fcs:
        fc.weight.mul_(scales.view(1, -1))

    for p in ln.parameters():
        assert torch.isnan(p).sum() == 0
    for fc in fcs:
        for p in fc.parameters():
            assert torch.isnan(p).sum() == 0


@torch.no_grad()
def scale_fc_fc(fc1, fc2, scales):
    assert isinstance(fc1, nn.Linear)
    assert isinstance(fc2, nn.Linear)

    scales = scales.to(fc1.weight.device)

    fc1.weight[-scales.size(0):].div_(scales.view(-1, 1))
    if fc1.bias is not None:
        fc1.bias.div_(scales.view(-1))

    fc2.weight.mul_(scales.view(1, -1))

    for p in fc1.parameters():
        assert torch.isnan(p).sum() == 0
    for p in fc2.parameters():
        assert torch.isnan(p).sum() == 0


@torch.no_grad()
def scale_gelu_fc(gelu, fc, scales):
    assert any(isinstance(gelu, t) for t in allowed_act_fns)
    assert isinstance(fc, nn.Linear)

    fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))

    for p in fc.parameters():
        assert torch.isnan(p).sum() == 0
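For orientation only (not part of this commit): the scaling helpers above rely on the fact that dividing a norm's weight and bias by a per-channel scale while multiplying the following linear's input channels by the same scale leaves the composed FP16 output unchanged. A minimal self-contained sketch of that identity, with made-up dimensions:

    # Illustrative sketch -- shows the no-op identity that scale_ln_fcs exploits.
    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    d = 8
    ln, fc = nn.LayerNorm(d), nn.Linear(d, d)
    x = torch.randn(2, d)
    s = torch.rand(d) + 0.5              # per-channel scales, kept away from zero

    ref = fc(ln(x))
    with torch.no_grad():
        ln.weight.div_(s)                # what scale_ln_fcs does to the norm weight
        ln.bias.div_(s)                  # ...and to its bias
        fc.weight.mul_(s.view(1, -1))    # ...while scaling the next linear's input channels
    assert torch.allclose(fc(ln(x)), ref, atol=1e-5)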
awq/quantize/quantizer.py (+315 -40)
import torch
import logging
import functools
import torch.nn as nn
from tqdm import tqdm
from collections import defaultdict
from awq.utils.utils import clear_memory
from awq.utils.calib_data import get_calib_dataset
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.quantize.apply_quantized import apply_scale, apply_clip
from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name


class AwqQuantizer:
    def __init__(self, model, tokenizer, w_bit, group_size, version,
                 calib_data, split, text_column) -> None:
        self.model = model
        self.tokenizer = tokenizer
        self.w_bit = w_bit
        self.group_size = group_size
        self.version = version
        self.calib_data = calib_data
        self.split = split
        self.text_column = text_column
        self.modules, self.module_kwargs = self.init_quant()

    # core quantization method (simulated quantization)
    def pseudo_quantize_tensor(self, w: torch.Tensor, get_scale_zp=False):
        org_w_shape = w.shape
        if self.group_size > 0:
            assert org_w_shape[-1] % self.group_size == 0
            w = w.reshape(-1, self.group_size)
        assert w.dim() == 2

        # zero point quantization
        max_val = w.amax(dim=1, keepdim=True)
        min_val = w.amin(dim=1, keepdim=True)
        max_int = 2 ** self.w_bit - 1
        min_int = 0
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)

        assert torch.isnan(scales).sum() == 0
        assert torch.isnan(w).sum() == 0

        w = (torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros) * scales
        assert torch.isnan(w).sum() == 0

        w = w.reshape(org_w_shape)

        if get_scale_zp:
            return w, scales.view(w.shape[0], -1), zeros.view(w.shape[0], -1)
        else:
            return w

    def quantize(self, get_layers_for_scaling: function):
        for i in tqdm(range(len(self.modules)), desc=""):
            # [STEP 1]: Get layer, extract linear modules, extract input features
            self.modules[i] = self.modules[i].cuda()
            named_linears = get_named_linears(self.modules[i])
            input_feat = self._get_input_feat(self.modules[i], named_linears)
            clear_memory()

            # [STEP 2]: Compute and apply scale list
            module_config: list[dict] = get_layers_for_scaling(
                self.modules[i], input_feat, self.module_kwargs
            )
            scales_list = [self._search_best_scale(**layer) for layer in module_config]
            apply_scale(self.modules[i], scales_list, input_feat_dict=input_feat)
            scales_list = append_str_prefix(scales_list, get_op_name(self.model, self.modules[i]) + ".")

            # [STEP 3]: Compute and apply clipping list
            clip_list = self._search_best_clip(named_linears, input_feat)
            apply_clip(self.modules[i], clip_list)
            clip_list = append_str_prefix(clip_list, get_op_name(self.model, self.modules[i]) + ".")

            # [STEP 4]: Quantize weights
            for name, linear_layer in named_linears.items():
                linear_layer.weight.data, scales, zeros = self.pseudo_quantize_tensor(
                    linear_layer.weight.data, get_scale_zp=True
                )

                if self.version == 'GEMM':
                    scales = scales.t().contiguous()
                    zeros = zeros.t().contiguous()
                    q_linear_module = WQLinear_GEMM
                elif self.version == 'GEMV':
                    q_linear_module = WQLinear_GEMV

                q_linear = q_linear_module.from_linear(
                    linear=linear_layer,
                    w_bit=self.w_bit,
                    group_size=self.group_size,
                    init_only=False,
                    scales=scales,
                    zeros=zeros
                )

                linear_layer.cpu()
                q_linear.to(next(self.modules[i].parameters()).device)
                set_op_by_name(self.modules[i], name, q_linear)
                clear_memory()

            clear_memory()

        return self.model

    @torch.no_grad()
    def _search_best_clip(self, layer, named_linears, input_feat):
        clip_list = []
        avoid_clipping = ["q_", "k_", "query", "key", "Wqkv"]

        for name in named_linears:
            # due to qk bmm, it is hard to clip precisely
            if any([_ in name for _ in avoid_clipping]):
                continue

            named_linears[name].cuda()
            max_val = self._compute_best_clip(named_linears[name].weight, input_feat[name])
            clip_list.append((name, max_val))
            named_linears[name].cpu()

    @torch.no_grad()
    def _compute_best_clip(self, w: torch.Tensor, input_feat: torch.Tensor,
                           n_grid=20, max_shrink=0.5, n_sample_token=512):
        assert w.dim() == 2
        org_w_shape = w.shape
        # w          [co, ci]      -> [co, 1, n_group, group size]
        # input_feat [n_token, ci] -> [1, n_token, n_group, group size]
        group_size = self.group_size if self.group_size > 0 else w.shape[1]
        input_feat = input_feat.view(-1, input_feat.shape[-1])
        input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
        input_feat = input_feat[:, 0::input_feat.shape[1] // n_sample_token]
        w = w.reshape(w.shape[0], 1, -1, group_size)

        oc_batch_size = 256 if w.shape[0] % 256 == 0 else 64  # prevent OOM
        assert w.shape[0] % oc_batch_size == 0
        w_all = w
        best_max_val_all = []

        for i_b in range(w.shape[0] // oc_batch_size):
            w = w_all[i_b * oc_batch_size: (i_b + 1) * oc_batch_size]

            org_max_val = w.abs().amax(dim=-1, keepdim=True)  # co, 1, n_group, 1

            best_max_val = org_max_val.clone()
            min_errs = torch.ones_like(org_max_val) * 1e9
            input_feat = input_feat.to(w.device)
            org_out = (input_feat * w).sum(dim=-1)  # co, n_token, n_group

            for i_s in range(int(max_shrink * n_grid)):
                max_val = org_max_val * (1 - i_s / n_grid)
                min_val = -max_val
                cur_w = torch.clamp(w, min_val, max_val)
                q_w = self.pseudo_quantize_tensor(cur_w)
                cur_out = (input_feat * q_w).sum(dim=-1)

                # co, 1, n_group, 1
                err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
                del cur_w
                del cur_out
                cur_best_idx = err < min_errs
                min_errs[cur_best_idx] = err[cur_best_idx]
                best_max_val[cur_best_idx] = max_val[cur_best_idx]

            best_max_val_all.append(best_max_val)

        best_max_val = torch.cat(best_max_val_all, dim=0)

        clear_memory(input_feat)
        clear_memory(org_out)

        return best_max_val.squeeze(1)

    @torch.no_grad()
    def _search_best_scale(self, previous_layer, linears2scale: list[nn.Linear],
                           x: torch.Tensor, kwargs={}):
        # Put x on the right device
        x = x.to(next(previous_layer.parameters()).device)

        # [STEP 1]: Compute maximum of weight
        weight = torch.cat([_m.weight for _m in linears2scale], dim=0)
        weight = weight.view(-1, self.group_size)
        w_max = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
        w_max = w_max.view(weight.shape)
        w_max = w_max.mean(0)
        clear_memory(weight)

        # [STEP 2]: Compute maximum of x
        x_max = x.abs().view(-1, x.shape[-1]).mean(0)

        # [STEP 3]: Compute output of previous layer
        with torch.no_grad():
            org_out = previous_layer(x, **kwargs)
            if isinstance(org_out, tuple):
                org_out = org_out[0]

        # [STEP 4]: Compute loss
        best_scales = self._compute_best_scale(x, w_max, x_max, previous_layer,
                                               linears2scale, org_out, kwargs)

        return best_scales

    def _compute_best_scale(self, x, w_max, x_max, previous_layer,
                            linears2scale: list[nn.Linear], org_out, kwargs={}):
        """
        Compute loss and select best scales

        L(s) = || Q(W * s) (s^-1 * X) - W * X ||
        Q: weight quantization function | pseudo_quantize_tensor(W * s)
        X: inputs from calib dataset    | X
        W: original weights in FP16     | layer
        s: per channel scaling factor   | s^-1 * X
        """
        n_grid = 20
        history = []
        best_ratio = -1
        best_scales = None
        best_error = float('inf')

        org_sd = {k: v.cpu() for k, v in previous_layer.state_dict().items()}

        for ratio in range(n_grid):
            # create new scales
            ratio = ratio * 1 / n_grid
            scales = (x_max.pow(ratio) / w_max.pow(1 - ratio)).clamp(min=1e-4).view(-1)
            scales = scales / (scales.max() * scales.min()).sqrt()

            # multiply scale and quantize
            for fc in linears2scale:
                fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
                fc.weight.data = self.pseudo_quantize_tensor(fc.weight.data) / (scales.view(1, -1))

            out = previous_layer(x, **kwargs)
            if isinstance(out, tuple):
                out = out[0]

            # measure loss and check if better than best
            loss = (org_out - out).float().pow(2).mean().item()  # NOTE: float prevents overflow
            history.append(loss)
            is_best = loss < best_error
            if is_best:
                best_error = loss
                best_ratio = ratio
                best_scales = scales

            previous_layer.load_state_dict(org_sd)

        if best_ratio == -1:
            logging.debug(history)
            raise Exception

        best_scales = best_scales.view(-1)

        assert torch.isnan(best_scales).sum() == 0, best_scales
        return best_scales.detach()

    def init_quant(self, n_samples=128, seqlen=512):
        layers = self.get_model_layers(self.model)
        samples = get_calib_dataset(
            data=self.calib_data, tokenizer=self.tokenizer, n_samples=n_samples,
            block_size=seqlen, split=self.split, text_column=self.text_column
        )
        samples = torch.cat(samples, dim=0)

        inps = []
        layer_kwargs = {}

        layers[0] = layers[0].cuda()
        self.move_embed(self.model, "cuda")

        # get input and kwargs to layer 0
        # with_kwargs is only supported in PyTorch 2.0
        # use this Catcher hack for now
        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, hijacked_inputs, **kwargs):
                inps.append(hijacked_inputs)
                layer_kwargs.update(kwargs)
                raise ValueError  # early exit to break later inference

        # patch layer 0 to catch input and kwargs
        layers[0] = Catcher(layers[0])
        try:
            self.model(samples.to(next(self.model.parameters()).device))
        except ValueError:  # work with early exit
            pass
        del samples
        layers[0] = layers[0].module  # restore
        inps = inps[0]

        layers[0] = layers[0].cpu()
        self.move_embed(self.model, "cpu")

        clear_memory()

        return layers, layer_kwargs

    def _get_input_feat(self, layer, named_linears):
        # firstly, get input features of all linear layers
        def cache_input_hook(m, x, y, name, feat_dict):
            x = x[0]
            x = x.detach().cpu()
            feat_dict[name].append(x)

        input_feat = defaultdict(list)
        handles = []
        for name in named_linears:
            handles.append(named_linears[name].register_forward_hook(
                functools.partial(cache_input_hook, name=name, feat_dict=input_feat)))

        inps = inps.to(next(layer.parameters()).device)  # in case multi-gpu
        # get output as next layer's input
        inps = layer(inps, **self.module_kwargs)[0]

        for h in handles:
            h.remove()

        # now solve for scaling and clipping
        input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}

        return input_feat
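As a reading aid (not part of the commit), here is the group-wise zero-point quantize/dequantize arithmetic from pseudo_quantize_tensor written as a standalone sketch; the helper name pseudo_quantize and the tensor shapes are made up for illustration:

    # Illustrative sketch of the "simulated" quantization step used above.
    import torch

    def pseudo_quantize(w: torch.Tensor, w_bit: int = 4, group_size: int = 128) -> torch.Tensor:
        org_shape = w.shape
        w = w.reshape(-1, group_size)                    # one row per quantization group
        max_val = w.amax(dim=1, keepdim=True)
        min_val = w.amin(dim=1, keepdim=True)
        max_int = 2 ** w_bit - 1                         # 15 for 4-bit
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(0, max_int)
        # round to integers, clamp to range, then dequantize back to float
        w = (torch.clamp(torch.round(w / scales) + zeros, 0, max_int) - zeros) * scales
        return w.reshape(org_shape)

    w = torch.randn(64, 512)
    print((w - pseudo_quantize(w)).abs().max())          # per-element error is on the order of one step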
awq/utils/utils.py (+7 -0)
import gc
import torch
import accelerate

@@ -53,3 +54,9 @@ def set_module_name(model, name, value):
        child_name = name

    setattr(parent, child_name, value)


def clear_memory(weight=None):
    if weight is not None:
        del weight
    gc.collect()
    torch.cuda.empty_cache()
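For completeness, a tiny sketch (not part of the commit) of the two ways the quantizer above calls the new helper:

    import torch
    from awq.utils.utils import clear_memory

    weight = torch.empty(1024, 1024)
    clear_memory(weight)   # hand over a tensor: the helper del's its reference, then gc.collect() + empty_cache()
    clear_memory()         # plain cleanup between layers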