OpenDAS / AutoAWQ / Commits / 724bda58

Commit 724bda58, authored Sep 20, 2023 by Casper Hansen

    Working for OPT

Parent: 356cbc92

Showing 5 changed files with 204 additions and 143 deletions (+204, -143):
- awq/models/base.py (+22, -14)
- awq/models/opt.py (+5, -3)
- awq/quantize/auto_clip.py (+0, -1)
- awq/quantize/auto_scale.py (+66, -18)
- awq/quantize/quantizer.py (+111, -107)
awq/models/base.py (+22, -14)

```diff
@@ -15,7 +15,6 @@ from huggingface_hub import snapshot_download
 from awq.utils.utils import simple_dispatch_model
 from awq.utils.calib_data import get_calib_dataset
 from transformers.modeling_utils import shard_checkpoint
-from awq.quantize.quantizer import pseudo_quantize_tensor
 from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
 from awq.quantize.auto_clip import auto_clip_block, apply_clip
 from awq.quantize.auto_scale import auto_scale_block, apply_scale
```
```diff
@@ -43,24 +42,33 @@ class BaseAWQForCausalLM(nn.Module):
         return self.model.generate(*args, **kwargs)

     @torch.no_grad()
-    def quantize(self, tokenizer=None, quant_config={},
-                       n_samples=128, seqlen=512,
-                       auto_scale=True, mse_range=True,
-                       run_search=True, run_quant=True,
-                       calib_data: Union[str, List[str]]="pileval",
-                       split="train", text_column="text"):
+    def quantize(self, tokenizer=None, quant_config={},
+                       calib_data: Union[str, List[str]]="pileval",
+                       split="train", text_column="text"):
         self.quant_config = quant_config
         quant_config["version"] = "GEMM" if 'version' not in quant_config.keys() else quant_config["version"]

-        if run_search:
-            self.search_result = self._awq_search(
-                tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen,
-                auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data,
-                split=split, text_column=text_column
-            )
+        from awq.quantize.quantizer import AwqQuantizer
+        quantizer = AwqQuantizer(
+            self, self.model, tokenizer, quant_config["w_bit"], quant_config["q_group_size"],
+            quant_config["version"], calib_data, split, text_column
+        )
+        quantizer.quantize()

-        if run_quant:
-            self._awq_quant()
         self.is_quantized = True

+        # if run_search:
+        #     self.search_result = self._awq_search(
+        #         tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen,
+        #         auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data,
+        #         split=split, text_column=text_column
+        #     )
+        # if run_quant:
+        #     self._awq_quant()
+        #     self.is_quantized = True

     @staticmethod
     def fuse_layers(model, quant_config):
         pass
```
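For orientation, the refactored entry point is meant to be driven roughly as below. The `quant_config` keys and `quantize()` arguments mirror the diff; the `AutoAWQForCausalLM` loader shown is an assumed convenience import, not part of this commit.

```python
# Hedged usage sketch for the refactored quantize() entry point.
# AutoAWQForCausalLM / from_pretrained are assumed loader conventions.
from transformers import AutoTokenizer
from awq import AutoAWQForCausalLM  # assumed top-level import

model_path = "facebook/opt-1.3b"    # illustrative checkpoint

model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Search and quantization now both run inside AwqQuantizer, so the old
# n_samples/seqlen/auto_scale/mse_range/run_search/run_quant knobs are gone.
model.quantize(
    tokenizer,
    quant_config={"w_bit": 4, "q_group_size": 128, "version": "GEMM"},
    calib_data="pileval",
    split="train",
    text_column="text",
)
```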
awq/models/opt.py (+5, -3)

```diff
@@ -27,10 +27,12 @@ class OptAWQForCausalLM(BaseAWQForCausalLM):
         # attention input
         layers.append(dict(
             prev_op=module.self_attn_layer_norm,
             layers=[module.self_attn.q_proj,
                     module.self_attn.k_proj, module.self_attn.v_proj],
             inp=input_feat['self_attn.q_proj'],
             module2inspect=module.self_attn, kwargs=module_kwargs,
         ))

         # attention out
```
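As a rough illustration of what each entry in this scaling config carries, the sketch below builds one such dict around a toy attention block; the module and field names follow the diff, everything else is made up for the example.

```python
# Minimal sketch of a per-layer scaling-config entry as built by
# get_layers_for_scaling(): tie a preceding op to the linears it feeds,
# plus the cached activations and the module whose output is inspected.
import torch
import torch.nn as nn

class ToyAttention(nn.Module):
    def __init__(self, dim=16):
        super().__init__()
        self.q_proj = nn.Linear(dim, dim)
        self.k_proj = nn.Linear(dim, dim)
        self.v_proj = nn.Linear(dim, dim)

    def forward(self, x):
        return self.q_proj(x) + self.k_proj(x) + self.v_proj(x)

norm = nn.LayerNorm(16)          # plays the role of self_attn_layer_norm
attn = ToyAttention()            # plays the role of module.self_attn
x = torch.randn(4, 16)
inp = norm(x)                    # what input_feat['self_attn.q_proj'] would cache

entry = dict(
    prev_op=norm,                            # op whose output feeds the scaled linears
    layers=[attn.q_proj, attn.k_proj, attn.v_proj],
    inp=inp,                                 # cached input to those linears
    module2inspect=attn,                     # its output is compared during the search
    kwargs={},
)
# A scale search then evaluates entry["module2inspect"](entry["inp"], **entry["kwargs"]).
print(entry["module2inspect"](entry["inp"]).shape)
```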
awq/quantize/auto_clip.py (+0, -1)

```diff
 import torch
 import torch.nn as nn
-from .quantizer import pseudo_quantize_tensor
 import gc

 __all__ = ["auto_clip_block"]
```
awq/quantize/auto_scale.py (+66, -18)

```diff
@@ -88,6 +88,50 @@ def scale_gelu_fc(gelu, fc, scales):
     for p in fc.parameters():
         assert torch.isnan(p).sum() == 0

+def pseudo_quantize_tensor(w, w_bit=4, zero_point=True, q_group_size=-1, inplace=False, get_scale_zp=False):
+    org_w_shape = w.shape
+    if q_group_size > 0:
+        assert org_w_shape[-1] % q_group_size == 0
+        w = w.reshape(-1, q_group_size)
+    assert w.dim() == 2
+
+    if zero_point:
+        max_val = w.amax(dim=1, keepdim=True)
+        min_val = w.amin(dim=1, keepdim=True)
+        max_int = 2 ** w_bit - 1
+        min_int = 0
+        scales = (max_val - min_val).clamp(min=1e-5) / max_int
+        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
+    else:
+        # we actually never used this
+        assert min_val is None
+        max_val = w.abs().amax(dim=1, keepdim=True)
+        max_val = max_val.clamp(min=1e-5)
+        max_int = 2 ** (w_bit - 1) - 1
+        min_int = -(2 ** (w_bit - 1))
+        scales = max_val / max_int
+        zeros = 0
+
+    assert torch.isnan(scales).sum() == 0
+    assert torch.isnan(w).sum() == 0
+
+    if inplace:
+        ((w.div_(scales).round_().add_(zeros)).clamp_(min_int, max_int).sub_(zeros)).mul_(scales)
+    else:
+        w = (torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros) * scales
+    assert torch.isnan(w).sum() == 0
+
+    w = w.reshape(org_w_shape)
+
+    if get_scale_zp:
+        return w, scales.view(w.shape[0], -1), zeros.view(w.shape[0], -1)
+    else:
+        return w
+
 @torch.no_grad()
 def auto_scale_block(awq_model,
```
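As a quick sanity check of the zero-point scheme added above, the snippet below applies the same scales/zeros formulas to a tiny hand-picked tensor. It is a standalone illustration, not a call into the diffed function.

```python
# Standalone illustration of the 4-bit zero-point scheme used above:
# scales = (max - min) / (2**bits - 1), zeros = -round(min / scales).
import torch

w = torch.tensor([[-0.50, -0.10, 0.20, 0.40]])
w_bit = 4
max_int, min_int = 2 ** w_bit - 1, 0

max_val = w.amax(dim=1, keepdim=True)
min_val = w.amin(dim=1, keepdim=True)
scales = (max_val - min_val).clamp(min=1e-5) / max_int
zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)

q = torch.clamp(torch.round(w / scales) + zeros, min_int, max_int)
w_dq = (q - zeros) * scales

print(q)                       # integer codes in [0, 15]
print((w - w_dq).abs().max())  # round-trip error bounded by about scales / 2
```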
```diff
@@ -95,7 +139,7 @@ def auto_scale_block(awq_model,
                      module_kwargs,
                      quant_config,
                      input_feat):
-    from .quantizer import pseudo_quantize_tensor
+    # from .quantizer import pseudo_quantize_tensor
     # firstly, get the weight quantize function
     if quant_config['w_bit'] is not None:
         def w_quantize_func(p):
             return pseudo_quantize_tensor(p, w_bit=quant_config["w_bit"], q_group_size=quant_config["q_group_size"]).detach()
```
```diff
@@ -106,24 +150,32 @@ def auto_scale_block(awq_model,
         module_kwargs.pop("use_cache")

     # find the best scale ratio
-    def _search_module_scale(block, linears2scale: list, x, kwargs={}):
+    def _search_module_scale(module2inspect, layers: list, inp, kwargs={}):
+        if module2inspect is None:
+            assert len(layers) == 1
+            module2inspect = layers[0]
+
         # w: co, ci
         # x: n, ci
-        weight = torch.cat([_m.weight for _m in linears2scale], dim=0)
-        w_max = get_weight_scale(weight, q_group_size=quant_config.get("q_group_size", -1))
+        weight = torch.cat([_m.weight for _m in layers], dim=0)
+        org_shape = weight.shape
+        weight = weight.view(-1, quant_config.get("q_group_size"))
+        w_scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
+        w_scale = w_scale.view(org_shape)
+        w_max = w_scale.mean(0)

         # Clear GPU memory
         del weight
         gc.collect()
         torch.cuda.empty_cache()

-        x = x.to(next(block.parameters()).device)
+        inp = inp.to(next(module2inspect.parameters()).device)
         with torch.no_grad():
-            org_out = block(x, **kwargs)
+            org_out = module2inspect(inp, **kwargs)
             if isinstance(org_out, tuple):
                 org_out = org_out[0]

-        x_max = get_act_scale(x)
+        x_max = get_act_scale(inp)

         best_error = float('inf')
         best_ratio = -1
```
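The inlined weight statistic above replaces the old `get_weight_scale` call: normalize each `q_group_size` slice of `|W|` by its own maximum, then average over rows. The sketch below reproduces that computation on made-up shapes to show what `w_max` ends up being.

```python
# Sketch of the per-group weight statistic computed inline above. Shapes are
# illustrative: 8 output rows, 32 input channels, groups of 16.
import torch

weight = torch.randn(8, 32)          # concatenated [out_features, in_features]
q_group_size = 16

org_shape = weight.shape
w = weight.view(-1, q_group_size)                        # one row per group
w_scale = w.abs() / w.abs().amax(dim=1, keepdim=True)    # in [0, 1] within each group
w_scale = w_scale.view(org_shape)
w_max = w_scale.mean(0)                                  # averaged over output rows

print(w_max.shape)  # torch.Size([32]): one statistic per input channel
```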
```diff
@@ -132,17 +184,17 @@ def auto_scale_block(awq_model,
         n_grid = 20
         history = []

-        org_sd = {k: v.cpu() for k, v in block.state_dict().items()}
+        org_sd = {k: v.cpu() for k, v in module2inspect.state_dict().items()}
         for ratio in range(n_grid):
             ratio = ratio * 1 / n_grid
             scales = (x_max.pow(ratio) / w_max.pow(1 - ratio)
                       ).clamp(min=1e-4).view(-1)
             scales = scales / (scales.max() * scales.min()).sqrt()
-            for fc in linears2scale:
+            for fc in layers:
                 fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
                 fc.weight.data = w_quantize_func(fc.weight.data) / (scales.view(1, -1))
-            out = block(x, **kwargs)
+            out = module2inspect(inp, **kwargs)
             if isinstance(out, tuple):
                 out = out[0]
```
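The loop above sweeps a 20-point grid over the mixing ratio between activation and weight statistics. The standalone sketch below shows how the candidate scales are formed for each ratio, using invented per-channel statistics.

```python
# Minimal sketch of the ratio grid swept above: for each ratio,
# scales = x_max**ratio / w_max**(1 - ratio), clamped and normalized so that
# max(scales) * min(scales) == 1. The statistics here are made up.
import torch

x_max = torch.tensor([2.0, 0.5, 1.0, 4.0])   # per-channel activation magnitude
w_max = torch.tensor([0.2, 0.8, 0.4, 0.1])   # per-channel weight statistic
n_grid = 20

for i in range(n_grid):
    ratio = i / n_grid
    scales = (x_max.pow(ratio) / w_max.pow(1 - ratio)).clamp(min=1e-4).view(-1)
    scales = scales / (scales.max() * scales.min()).sqrt()
    # ratio=0 ignores activations entirely; larger ratios shift more of the
    # activation outliers' dynamic range into the weights before quantizing.
    if i in (0, n_grid // 2, n_grid - 1):
        print(f"ratio={ratio:.2f} scales={scales.tolist()}")
```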
```diff
@@ -153,7 +205,7 @@ def auto_scale_block(awq_model,
                 best_error = loss
                 best_ratio = ratio
                 best_scales = scales
-        block.load_state_dict(org_sd)
+        module2inspect.load_state_dict(org_sd)
         if best_ratio == -1:
             logging.debug(history)
             raise Exception
```
```diff
@@ -163,13 +215,9 @@ def auto_scale_block(awq_model,
         return best_scales.detach()

     def _auto_get_scale(prev_op, layers, inp, module2inspect=None, kwargs={}):
         # module2inspect: if given, we will check the output diff of this module instead of layers
-        if module2inspect is None:
-            assert len(layers) == 1
-            module2inspect = layers[0]
-
         scales = _search_module_scale(module2inspect, layers, inp, kwargs)
         scales = scales.detach().cpu()
         # prev_op_name, [layer_name], scale
         return (get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), scales)
```
awq/quantize/quantizer.py (+111, -107)

```diff
@@ -12,7 +12,8 @@ from awq.utils.module import append_str_prefix, get_op_name, get_named_linears,

 class AwqQuantizer:
-    def __init__(self, model, tokenizer, w_bit, group_size, version, calib_data, split, text_column) -> None:
+    def __init__(self, awq_model, model, tokenizer, w_bit, group_size, version, calib_data, split, text_column) -> None:
+        self.awq_model = awq_model
         self.model = model
         self.tokenizer = tokenizer
         self.w_bit = w_bit
```
```diff
@@ -21,7 +22,7 @@ class AwqQuantizer:
         self.calib_data = calib_data
         self.split = split
         self.text_column = text_column
-        self.modules, self.module_kwargs = self.init_quant()
+        self.modules, self.module_kwargs, self.inps = self.init_quant()

     def pseudo_quantize_tensor(self, w: torch.Tensor, get_scale_zp=False):
         org_w_shape = w.shape
```
```diff
@@ -51,8 +52,8 @@ class AwqQuantizer:
         else:
             return w

-    def quantize(self, get_layers_for_scaling: function):
-        for i in tqdm(range(len(self.modules)), desc="QUANTIZING"):
+    def quantize(self):
+        for i in tqdm(range(len(self.modules)), desc="AWQ"):
             # [STEP 1]: Get layer, extract linear modules, extract input features
             self.modules[i] = self.modules[i].cuda()
             named_linears = get_named_linears(self.modules[i])
```
```diff
@@ -60,22 +61,22 @@ class AwqQuantizer:
             clear_memory()

             # [STEP 2]: Compute and apply scale list
-            module_config: list[dict] = get_layers_for_scaling(
+            module_config: list[dict] = self.awq_model.get_layers_for_scaling(
                 self.modules[i], input_feat, self.module_kwargs
             )
-            scales_list = [self._search_best_scale(**layer) for layer in module_config]
+            scales_list = [self._search_best_scale(self.modules[i], **layer) for layer in module_config]
             apply_scale(self.modules[i], scales_list, input_feat_dict=input_feat)
             scales_list = append_str_prefix(scales_list, get_op_name(self.model, self.modules[i]) + ".")

             # [STEP 3]: Compute and apply clipping list
-            clip_list = self._search_best_clip(named_linears, input_feat)
-            apply_clip(self.modules[i], clip_list)
-            clip_list = append_str_prefix(clip_list, get_op_name(self.model, self.modules[i]) + ".")
+            # clip_list = self._search_best_clip(self.modules[i], named_linears, input_feat)
+            # apply_clip(self.modules[i], clip_list)
+            # clip_list = append_str_prefix(clip_list, get_op_name(self.model, self.modules[i]) + ".")

             # [STEP 4]: Quantize weights
             for name, linear_layer in named_linears.items():
                 linear_layer.weight.data, scales, zeros = self.pseudo_quantize_tensor(
-                    linear_layer.weight.data,
+                    linear_layer.weight.data.float(),
                     get_scale_zp=True
                 )
```
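[STEP 2] hands the searched scales to `apply_scale`, whose internals are not part of this diff. The sketch below shows the underlying reparameterization idea on a toy LayerNorm/Linear pair: fold 1/s into the producing op and s into the consuming linear so the float computation is unchanged before quantization. It is a conceptual illustration, not the library's `apply_scale`.

```python
# Conceptual illustration of applying a per-channel scale pair: the producing
# op emits activations divided by s, and the consuming linear absorbs s on its
# input channels, so the composed output stays the same.
import torch
import torch.nn as nn

torch.manual_seed(0)
ln = nn.LayerNorm(8)
fc = nn.Linear(8, 8)
x = torch.randn(2, 8)

ref = fc(ln(x))

scales = torch.rand(8) + 0.5             # one scale per channel between ln and fc
with torch.no_grad():
    ln.weight.div_(scales)                # producing op now outputs act / s
    ln.bias.div_(scales)
    fc.weight.mul_(scales.view(1, -1))    # consuming linear multiplies inputs by s

out = fc(ln(x))
print(torch.allclose(ref, out, atol=1e-6))  # True: the reparameterization is exact
```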
```diff
@@ -103,106 +104,45 @@ class AwqQuantizer:
             clear_memory()

         return self.model

-    @torch.no_grad()
-    def _search_best_clip(self, layer, named_linears, input_feat):
-        clip_list = []
-        avoid_clipping = ["q_", "k_", "query", "key", "Wqkv"]
-
-        for name in named_linears:
-            # due to qk bmm, it is hard to clip precisely
-            if any([_ in name for _ in avoid_clipping]):
-                continue
-
-            named_linears[name].cuda()
-            max_val = self._compute_best_clip(named_linears[name].weight, input_feat[name])
-            clip_list.append((name, max_val))
-            named_linears[name].cpu()
-
-    @torch.no_grad()
-    def _compute_best_clip(self, w: torch.Tensor, input_feat: torch.Tensor,
-                           n_grid=20, max_shrink=0.5, n_sample_token=512):
-        assert w.dim() == 2
-        org_w_shape = w.shape
-        # w           [co, ci]      -> [co, 1, n_group, group size]
-        # input_feat  [n_token, ci] -> [1, n_token, n_group, group size]
-        group_size = self.group_size if self.group_size > 0 else w.shape[1]
-        input_feat = input_feat.view(-1, input_feat.shape[-1])
-        input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
-        input_feat = input_feat[:, 0::input_feat.shape[1] // n_sample_token]
-        w = w.reshape(w.shape[0], 1, -1, group_size)
-
-        oc_batch_size = 256 if w.shape[0] % 256 == 0 else 64  # prevent OOM
-        assert w.shape[0] % oc_batch_size == 0
-        w_all = w
-        best_max_val_all = []
-
-        for i_b in range(w.shape[0] // oc_batch_size):
-            w = w_all[i_b * oc_batch_size: (i_b + 1) * oc_batch_size]
-            org_max_val = w.abs().amax(dim=-1, keepdim=True)  # co, 1, n_group, 1
-            best_max_val = org_max_val.clone()
-            min_errs = torch.ones_like(org_max_val) * 1e9
-            input_feat = input_feat.to(w.device)
-            org_out = (input_feat * w).sum(dim=-1)  # co, n_token, n_group
-
-            for i_s in range(int(max_shrink * n_grid)):
-                max_val = org_max_val * (1 - i_s / n_grid)
-                min_val = -max_val
-                cur_w = torch.clamp(w, min_val, max_val)
-                q_w = self.pseudo_quantize_tensor(cur_w)
-                cur_out = (input_feat * q_w).sum(dim=-1)
-                # co, 1, n_group, 1
-                err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
-                del cur_w
-                del cur_out
-                cur_best_idx = err < min_errs
-                min_errs[cur_best_idx] = err[cur_best_idx]
-                best_max_val[cur_best_idx] = max_val[cur_best_idx]
-            best_max_val_all.append(best_max_val)
-
-        best_max_val = torch.cat(best_max_val_all, dim=0)
-        clear_memory(input_feat)
-        clear_memory(org_out)
-        return best_max_val.squeeze(1)
-
-    @torch.no_grad()
-    def _search_best_scale(self, previous_layer, linears2scale: list[nn.Linear], x: torch.Tensor, kwargs={}):
-        # Put x on the right device
-        x = x.to(next(previous_layer.parameters()).device)
-
-        # [STEP 1]: Compute maximum of weight
-        weight = torch.cat([_m.weight for _m in linears2scale], dim=0)
-        org_shape = weight.shape
-        weight = weight.view(-1, self.group_size)
-        w_max = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
-        w_max = w_max.view(weight.shape)
-        w_max = w_max.mean(0)
-        clear_memory(weight)
-
-        # [STEP 2]: Compute maximum of x
-        x_max = x.abs().view(-1, x.shape[-1]).mean(0)
-
-        # [STEP 3]: Compute output of previous layer
-        with torch.no_grad():
-            org_out = previous_layer(x, **kwargs)
-            if isinstance(org_out, tuple):
-                org_out = org_out[0]
-
-        # [STEP 4]: Compute loss
-        best_scales = self._compute_best_scale(
-            x, w_max, x_max, previous_layer,
-            linears2scale, org_out, kwargs
-        )
-
-        return best_scales
-
-    def _compute_best_scale(self, x, w_max, x_max, previous_layer, linears2scale: list[nn.Linear], org_out, kwargs={}):
+    def _search_best_scale(self, module, prev_op, layers: list[nn.Linear], inp: torch.Tensor, module2inspect=None, kwargs={}):
+        if module2inspect is None:
+            assert len(layers) == 1
+            module2inspect = layers[0]
+
+        if "use_cache" in kwargs:
+            kwargs.pop("use_cache")
+
+        # Put x on the right device
+        inp = inp.to(next(module2inspect.parameters()).device)
+
+        # [STEP 1]: Compute maximum of weight
+        weight = torch.cat([_m.weight for _m in layers], dim=0)
+        org_shape = weight.shape
+        weight = weight.view(-1, self.group_size)
+        w_scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
+        w_scale = w_scale.view(org_shape)
+        w_max = w_scale.mean(0)
+        clear_memory(weight)
+
+        # [STEP 2]: Compute maximum of x
+        x_max = inp.abs().view(-1, inp.shape[-1]).mean(0)
+
+        # [STEP 3]: Compute output of module
+        with torch.no_grad():
+            org_out = module2inspect(inp, **kwargs)
+            if isinstance(org_out, tuple):
+                org_out = org_out[0]
+
+        # [STEP 4]: Compute loss
+        best_scales = self._compute_best_scale(
+            inp, w_max, x_max, module2inspect,
+            layers, org_out, kwargs
+        )
+
+        return (get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), best_scales)
+
+    def _compute_best_scale(self, x, w_max, x_max, module2inspect, linears2scale: list[nn.Linear], org_out, kwargs={}):
         """
         Compute loss and select best scales
```
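The `_compute_best_clip` routine shown above scores clipping thresholds on a per-group basis by broadcasting weights against sampled activations. The standalone sketch below reproduces that grouped-output trick on tiny tensors; all shapes are invented for illustration.

```python
# Standalone illustration of the grouped output used by the clip search:
# with w reshaped to [co, 1, n_group, group] and input_feat to
# [1, n_token, n_group, group], (input_feat * w).sum(-1) gives per-group
# partial sums for every output channel, so clipping error can be scored
# per (output channel, group).
import torch

co, ci, n_token, group = 4, 8, 5, 4
w = torch.randn(co, ci)
x = torch.randn(n_token, ci)

w_g = w.reshape(co, 1, -1, group)          # [co, 1, n_group, group]
x_g = x.reshape(1, n_token, -1, group)     # [1, n_token, n_group, group]
org_out = (x_g * w_g).sum(dim=-1)          # [co, n_token, n_group]

# Shrinking the clip threshold trades outlier error against rounding error:
max_val = w_g.abs().amax(dim=-1, keepdim=True) * 0.8
clipped = torch.clamp(w_g, -max_val, max_val)
err = ((x_g * clipped).sum(dim=-1) - org_out).pow(2).mean(dim=1)
print(err.shape)   # torch.Size([4, 2]): one error per output channel and group
```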
```diff
@@ -218,7 +158,7 @@ class AwqQuantizer:
         best_scales = None
         best_error = float('inf')

-        org_sd = {k: v.cpu() for k, v in previous_layer.state_dict().items()}
+        org_sd = {k: v.cpu() for k, v in module2inspect.state_dict().items()}

         device = x.device
         x_max = x_max.view(-1).to(device)
```
```diff
@@ -235,7 +175,7 @@ class AwqQuantizer:
                 fc.weight.mul_(scales_view)
                 fc.weight.data = self.pseudo_quantize_tensor(fc.weight.data) / scales_view

-            out = previous_layer(x, **kwargs)
+            out = module2inspect(x, **kwargs)
             if isinstance(out, tuple):
                 out = out[0]
```
```diff
@@ -246,7 +186,7 @@ class AwqQuantizer:
                 best_error = loss
                 best_ratio = ratio
                 best_scales = scales.clone()
-        previous_layer.load_state_dict(org_sd)
+        module2inspect.load_state_dict(org_sd)

         if best_ratio == -1:
             logging.debug(history)
```
```diff
@@ -254,12 +194,76 @@ class AwqQuantizer:
         assert torch.isnan(best_scales).sum() == 0, best_scales

-        return best_scales.detach()
+        return best_scales.detach().cpu()

+    @torch.no_grad()
+    def _search_best_clip(self, layer, named_linears, input_feat):
+        clip_list = []
+        avoid_clipping = ["q_", "k_", "query", "key", "Wqkv"]
+
+        for name in named_linears:
+            # due to qk bmm, it is hard to clip precisely
+            if any([_ in name for _ in avoid_clipping]):
+                continue
+
+            named_linears[name].cuda()
+            max_val = self._compute_best_clip(named_linears[name].weight, input_feat[name])
+            clip_list.append((name, max_val))
+            named_linears[name].cpu()
+
+    @torch.no_grad()
+    def _compute_best_clip(self, w: torch.Tensor, input_feat: torch.Tensor,
+                           n_grid=20, max_shrink=0.5, n_sample_token=512):
+        assert w.dim() == 2
+        org_w_shape = w.shape
+        # w           [co, ci]      -> [co, 1, n_group, group size]
+        # input_feat  [n_token, ci] -> [1, n_token, n_group, group size]
+        group_size = self.group_size if self.group_size > 0 else w.shape[1]
+        input_feat = input_feat.view(-1, input_feat.shape[-1])
+        input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
+        input_feat = input_feat[:, 0::input_feat.shape[1] // n_sample_token]
+        w = w.reshape(w.shape[0], 1, -1, group_size)
+
+        oc_batch_size = 256 if w.shape[0] % 256 == 0 else 64  # prevent OOM
+        assert w.shape[0] % oc_batch_size == 0
+        w_all = w
+        best_max_val_all = []
+
+        for i_b in range(w.shape[0] // oc_batch_size):
+            w = w_all[i_b * oc_batch_size: (i_b + 1) * oc_batch_size]
+            org_max_val = w.abs().amax(dim=-1, keepdim=True)  # co, 1, n_group, 1
+            best_max_val = org_max_val.clone()
+            min_errs = torch.ones_like(org_max_val) * 1e9
+            input_feat = input_feat.to(w.device)
+            org_out = (input_feat * w).sum(dim=-1)  # co, n_token, n_group
+
+            for i_s in range(int(max_shrink * n_grid)):
+                max_val = org_max_val * (1 - i_s / n_grid)
+                min_val = -max_val
+                cur_w = torch.clamp(w, min_val, max_val)
+                q_w = self.pseudo_quantize_tensor(cur_w)
+                cur_out = (input_feat * q_w).sum(dim=-1)
+                # co, 1, n_group, 1
+                err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
+                del cur_w
+                del cur_out
+                cur_best_idx = err < min_errs
+                min_errs[cur_best_idx] = err[cur_best_idx]
+                best_max_val[cur_best_idx] = max_val[cur_best_idx]
+            best_max_val_all.append(best_max_val)
+
+        best_max_val = torch.cat(best_max_val_all, dim=0)
+        clear_memory(input_feat)
+        clear_memory(org_out)
+        return best_max_val.squeeze(1)
+
     def init_quant(self, n_samples=128, seqlen=512):
-        layers = self.get_model_layers(self.model)
+        modules = self.awq_model.get_model_layers(self.model)
         samples = get_calib_dataset(
             data=self.calib_data, tokenizer=self.tokenizer, n_samples=n_samples,
             block_size=seqlen, split=self.split, text_column=self.text_column
```
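`init_quant` pulls calibration text through `get_calib_dataset` before capturing layer-0 inputs. The sketch below shows the general shape of that preprocessing under the assumption that the samples end up as a single `[n_blocks, seqlen]` token tensor; it does not reproduce `get_calib_dataset` itself.

```python
# Rough sketch of the calibration preprocessing assumed by init_quant:
# tokenize raw text, slice it into fixed-length blocks, and stack the blocks
# so one forward pass can capture the first layer's inputs.
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("facebook/opt-125m")
texts = ["AWQ searches per-channel scales.", "Calibration text goes here."] * 32
seqlen, n_samples = 32, 8

ids = tokenizer("\n\n".join(texts), return_tensors="pt").input_ids[0]
n_blocks = min(n_samples, ids.numel() // seqlen)
samples = ids[: n_blocks * seqlen].view(n_blocks, seqlen)
print(samples.shape)   # e.g. torch.Size([8, 32]), fed to the model in one pass
```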
```diff
@@ -269,8 +273,8 @@ class AwqQuantizer:
         inps = []
         layer_kwargs = {}

-        layers[0] = layers[0].cuda()
-        self.move_embed(self.model, "cuda")
+        modules[0] = modules[0].cuda()
+        self.awq_model.move_embed(self.model, "cuda")

         # get input and kwargs to layer 0
         # with_kwargs is only supported in PyTorch 2.0
```
```diff
@@ -286,21 +290,21 @@ class AwqQuantizer:
                 raise ValueError  # early exit to break later inference

         # patch layer 0 to catch input and kwargs
-        layers[0] = Catcher(layers[0])
+        modules[0] = Catcher(modules[0])
         try:
             self.model(samples.to(next(self.model.parameters()).device))
         except ValueError:  # work with early exit
             pass
         del samples
-        layers[0] = layers[0].module  # restore
+        modules[0] = modules[0].module  # restore
         inps = inps[0]

-        layers[0] = layers[0].cpu()
-        self.move_embed(self.model, "cpu")
+        modules[0] = modules[0].cpu()
+        self.awq_model.move_embed(self.model, "cpu")

         clear_memory()

-        return layers, layer_kwargs
+        return modules, layer_kwargs, inps

     def _get_input_feat(self, layer, named_linears):
         # firstly, get input features of all linear layers
```
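The "Catcher" wrapper used above records whatever reaches the first block and then raises to abort the rest of the forward pass, since only layer-0 inputs and kwargs are needed. The sketch below reproduces that early-exit pattern on a toy stack of blocks; the class body is a simplified stand-in, not the diffed implementation.

```python
# Minimal reproduction of the Catcher early-exit trick: wrap block 0, record
# its inputs/kwargs, and raise to stop the remaining inference.
import torch
import torch.nn as nn

class Catcher(nn.Module):
    def __init__(self, module, inps, layer_kwargs):
        super().__init__()
        self.module = module
        self.inps = inps
        self.layer_kwargs = layer_kwargs

    def forward(self, hidden_states, **kwargs):
        self.inps.append(hidden_states)
        self.layer_kwargs.update(kwargs)
        raise ValueError  # early exit to break later inference

blocks = nn.ModuleList([nn.Linear(8, 8) for _ in range(3)])
inps, layer_kwargs = [], {}
blocks[0] = Catcher(blocks[0], inps, layer_kwargs)

x = torch.randn(2, 8)
try:
    for block in blocks:
        x = block(x)
except ValueError:  # work with early exit
    pass

blocks[0] = blocks[0].module   # restore the original block
print(inps[0].shape)           # captured layer-0 input: torch.Size([2, 8])
```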
```diff
@@ -315,9 +319,9 @@ class AwqQuantizer:
             handles.append(named_linears[name].register_forward_hook(
                 functools.partial(cache_input_hook, name=name, feat_dict=input_feat)))
-        inps = inps.to(next(layer.parameters()).device)  # in case multi-gpu
+        self.inps = self.inps.to(next(layer.parameters()).device)  # in case multi-gpu
         # get output as next layer's input
-        inps = layer(inps, **self.module_kwargs)[0]
+        self.inps = layer(self.inps, **self.module_kwargs)[0]
         for h in handles:
             h.remove()
         # now solve for scaling and clipping
```
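`_get_input_feat` relies on forward hooks registered per linear so that a single pass through the block stashes each linear's input tensor by name. The standalone sketch below shows that hook pattern; the toy block and names are illustrative only.

```python
# Standalone sketch of the forward-hook caching used by _get_input_feat:
# register one hook per linear, run the block once, collect inputs by name.
import functools
import torch
import torch.nn as nn

def cache_input_hook(module, inputs, output, name, feat_dict):
    feat_dict.setdefault(name, []).append(inputs[0].detach())

block = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8))
named_linears = {n: m for n, m in block.named_modules() if isinstance(m, nn.Linear)}

input_feat, handles = {}, []
for name, linear in named_linears.items():
    handles.append(linear.register_forward_hook(
        functools.partial(cache_input_hook, name=name, feat_dict=input_feat)))

out = block(torch.randn(4, 8))   # one pass populates the cache
for h in handles:
    h.remove()

print({k: v[0].shape for k, v in input_feat.items()})
# {'0': torch.Size([4, 8]), '2': torch.Size([4, 16])}
```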