OpenDAS / AutoAWQ · Commit d3550fec

Remove old quantization code
Authored Sep 20, 2023 by Casper Hansen
Parent: 724bda58

Showing 4 changed files with 6 additions and 548 deletions:
  awq/models/base.py          +1  -180
  awq/quantize/auto_clip.py   +0  -98
  awq/quantize/auto_scale.py  +0  -267
  awq/quantize/quantizer.py   +5  -3
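For orientation, a minimal usage sketch of the path this diff consolidates: base.py now builds a quantizer object and calls quantizer.quantize() instead of the removed _awq_quant/_awq_search methods. The entry points (AutoAWQForCausalLM.from_pretrained, model.quantize, model.save_quantized) and the example model path are assumptions based on the surrounding AutoAWQ release rather than something shown in this commit; the quant_config keys are the ones read in the code below.

    from awq import AutoAWQForCausalLM
    from transformers import AutoTokenizer

    model_path = "facebook/opt-125m"  # illustrative model; any supported causal LM works

    # Load the FP16 model and its tokenizer.
    model = AutoAWQForCausalLM.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path)

    # quant_config keys mirror those read in base.py and quantizer.py below.
    quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

    # Runs the AWQ scale/clip search and weight quantization (now handled by AwqQuantizer).
    model.quantize(tokenizer, quant_config=quant_config)

    # Persist the quantized checkpoint (save_quantized is defined in awq/models/base.py).
    model.save_quantized("opt-125m-awq")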
awq/models/base.py (view file @ d3550fec)

...
@@ -2,25 +2,18 @@ import os
import gc
import json
import torch
import logging
import functools
import torch.nn as nn
from tqdm import tqdm
from typing import List, Union
from collections import defaultdict
from safetensors.torch import save_file
from awq.modules.act import ScaledActivation
from huggingface_hub import snapshot_download
from awq.utils.utils import simple_dispatch_model
from awq.utils.calib_data import get_calib_dataset
from transformers.modeling_utils import shard_checkpoint
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.quantize.auto_clip import auto_clip_block, apply_clip
from awq.quantize.auto_scale import auto_scale_block, apply_scale
from awq.utils.module import get_named_linears, set_op_by_name
from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
from accelerate import init_empty_weights, load_checkpoint_in_model, infer_auto_device_map
from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name


class BaseAWQForCausalLM(nn.Module):
    def __init__(self, model, model_type, is_quantized, quant_config):
...
@@ -55,183 +48,11 @@ class BaseAWQForCausalLM(nn.Module):
            quant_config["version"], calib_data, split, text_column
        )
        quantizer.quantize()

        self.is_quantized = True

        # if run_search:
        #     self.search_result = self._awq_search(
        #         tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen,
        #         auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data,
        #         split=split, text_column=text_column
        #     )
        # if run_quant:
        #     self._awq_quant()
        #     self.is_quantized = True

    @staticmethod
    def fuse_layers(model, quant_config):
        pass
    def _awq_quant(self):
        assert self.quant_config["zero_point"], "We only support zero_point quantization now."
        layers = self.get_model_layers(self.model)

        # Run AWQ quantization
        for i in tqdm(range(len(layers)), desc="AWQ Quantization"):
            layer = layers[i]
            named_linears = get_named_linears(layer)
            self._scale_activations(self, layer)

            for name, module in named_linears.items():
                module.cuda()

                module.weight.data, scales, zeros = pseudo_quantize_tensor(
                    module.weight.data,
                    get_scale_zp=True,
                    w_bit=self.quant_config["w_bit"],
                    q_group_size=self.quant_config["q_group_size"]
                )

                if self.quant_config["version"] == 'GEMM':
                    scales = scales.t().contiguous()
                    zeros = zeros.t().contiguous()
                    q_linear_module = WQLinear_GEMM
                elif self.quant_config["version"] == 'GEMV':
                    q_linear_module = WQLinear_GEMV

                q_linear = q_linear_module.from_linear(
                    module,
                    self.quant_config['w_bit'],
                    self.quant_config['q_group_size'],
                    False,
                    scales,
                    zeros
                )

                module.cpu()
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
                torch.cuda.empty_cache()
                gc.collect()

            torch.cuda.empty_cache()
            gc.collect()
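The loop above follows a general pattern: pseudo-quantize each linear layer's weight, build a replacement module, and splice it back in by name. A minimal CPU-only sketch of that pattern follows; FakeQuantLinear is a hypothetical stand-in for the real WQLinear_GEMM/WQLinear_GEMV modules, which require packed CUDA weights.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class FakeQuantLinear(nn.Module):
        """Hypothetical stand-in: keeps a round-to-grid copy of the weight."""
        def __init__(self, linear: nn.Linear, n_levels: int = 16):
            super().__init__()
            w = linear.weight.data
            scale = w.abs().amax(dim=1, keepdim=True).clamp(min=1e-5) / (n_levels // 2 - 1)
            self.register_buffer("weight", (w / scale).round() * scale)  # de-quantized copy
            self.bias = linear.bias

        def forward(self, x):
            return F.linear(x, self.weight, self.bias)

    # Swap every nn.Linear in a toy block for its quantized stand-in,
    # mirroring the set_op_by_name(layer, name, q_linear) call above.
    block = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 8))
    for idx, module in enumerate(block):
        if isinstance(module, nn.Linear):
            block[idx] = FakeQuantLinear(module)

    print(block)  # both Linear layers replaced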
    def _awq_search(self, tokenizer, quant_config, n_samples=128, seqlen=512,
                    auto_scale=True, mse_range=True,
                    calib_data: Union[str, List[str]] = "pileval",
                    split="train", text_column="text"):
        layers = self.get_model_layers(self.model)

        samples = get_calib_dataset(
            data=calib_data, tokenizer=tokenizer, n_samples=n_samples,
            block_size=seqlen, split=split, text_column=text_column
        )
        samples = torch.cat(samples, dim=0)

        inps = []
        layer_kwargs = {}

        layers[0] = layers[0].cuda()
        self.move_embed(self.model, "cuda")

        # get input and kwargs to layer 0
        # with_kwargs is only supported in PyTorch 2.0
        # use this Catcher hack for now
        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, hijacked_inputs, **kwargs):
                inps.append(hijacked_inputs)
                layer_kwargs.update(kwargs)
                raise ValueError  # early exit to break later inference

        # patch layer 0 to catch input and kwargs
        layers[0] = Catcher(layers[0])
        try:
            self.model(samples.to(next(self.model.parameters()).device))
        except ValueError:  # work with early exit
            pass
        del samples
        layers[0] = layers[0].module  # restore
        inps = inps[0]

        layers[0] = layers[0].cpu()
        self.move_embed(self.model, "cpu")

        gc.collect()
        torch.cuda.empty_cache()
        awq_results = {
            "scale": [],
            "clip": [],
        }

        # Run AWQ search layer by layer
        for i in tqdm(range(len(layers)), desc="AWQ Search"):
            layer = layers[i]
            layer = layer.cuda()
            named_linears = get_named_linears(layer)

            # firstly, get input features of all linear layers
            def cache_input_hook(m, x, y, name, feat_dict):
                x = x[0]
                x = x.detach().cpu()
                feat_dict[name].append(x)

            input_feat = defaultdict(list)
            handles = []
            for name in named_linears:
                handles.append(named_linears[name].register_forward_hook(
                    functools.partial(cache_input_hook, name=name, feat_dict=input_feat)))
            inps = inps.to(next(layer.parameters()).device)  # in case multi-gpu
            # get output as next layer's input
            inps = layer(inps, **layer_kwargs)[0]
            for h in handles:
                h.remove()
            # now solve for scaling and clipping
            input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}

            # Clear GPU memory
            torch.cuda.empty_cache()

            if auto_scale:  # if it applies, we should also modify the input_feat with scales
                scales_list = auto_scale_block(
                    self,
                    layer,
                    layer_kwargs,
                    quant_config=quant_config,
                    input_feat=input_feat,
                )
                apply_scale(layers[i], scales_list, input_feat_dict=input_feat)
                # append prefix to make names global
                awq_results["scale"] += append_str_prefix(scales_list, get_op_name(self.model, layer) + ".")

            # Clear GPU memory
            torch.cuda.empty_cache()

            if mse_range:
                clip_list = auto_clip_block(layer,
                                            quant_config=quant_config,
                                            input_feat=input_feat)
                apply_clip(layer, clip_list)
                # append prefix to make names global
                awq_results["clip"] += append_str_prefix(clip_list, get_op_name(self.model, layer) + ".")

            layer = layer.cpu()
            # Haotian: check activation replacement
            del input_feat
            gc.collect()
            torch.cuda.empty_cache()

        return awq_results
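The Catcher trick used at the start of _awq_search is self-contained enough to demonstrate in isolation: wrap the first block, run one forward pass, and abort with an exception once the block's inputs and kwargs have been recorded. A toy sketch, independent of any AWQ internals (ToyModel and its sizes are made up for illustration):

    import torch
    import torch.nn as nn

    inps, layer_kwargs = [], {}

    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, hidden_states, **kwargs):
            inps.append(hidden_states)   # record the first block's input...
            layer_kwargs.update(kwargs)  # ...and any keyword arguments
            raise ValueError             # abort the rest of the forward pass

    class ToyModel(nn.Module):
        def __init__(self):
            super().__init__()
            self.blocks = nn.ModuleList([nn.Linear(4, 4) for _ in range(3)])

        def forward(self, x):
            for block in self.blocks:
                x = block(x)
            return x

    model = ToyModel()
    model.blocks[0] = Catcher(model.blocks[0])   # patch block 0
    try:
        model(torch.randn(2, 4))
    except ValueError:
        pass                                     # expected early exit
    model.blocks[0] = model.blocks[0].module     # restore the original block

    print(inps[0].shape)  # torch.Size([2, 4]): the captured calibration input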
    def save_quantized(self, save_dir, safetensors=False, shard_size="10GB"):
        def _save_files(save_dir, model_name='', search_result=None):
...
awq/quantize/auto_clip.py (deleted, 100644 → 0; view file @ 724bda58)
import torch
import torch.nn as nn
import gc

__all__ = ["auto_clip_block"]


# weight quantization
@torch.no_grad()
def auto_clip_layer(w, input_feat, quant_config,
                    n_grid=20,
                    max_shrink=0.5,
                    n_sample_token=512):
    assert w.dim() == 2
    org_w_shape = w.shape
    # w           [co, ci]      -> [co, 1, n_group, group size]
    # input_feat  [n_token, ci] -> [1, n_token, n_group, group size]
    group_size = quant_config["q_group_size"] if quant_config["q_group_size"] > 0 else w.shape[1]
    input_feat = input_feat.view(-1, input_feat.shape[-1])
    input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
    input_feat = input_feat[:, 0::input_feat.shape[1] // n_sample_token]
    w = w.reshape(w.shape[0], 1, -1, group_size)

    oc_batch_size = 256 if w.shape[0] % 256 == 0 else 64  # prevent OOM
    assert w.shape[0] % oc_batch_size == 0
    w_all = w
    best_max_val_all = []

    for i_b in range(w.shape[0] // oc_batch_size):
        w = w_all[i_b * oc_batch_size: (i_b + 1) * oc_batch_size]

        org_max_val = w.abs().amax(dim=-1, keepdim=True)  # co, 1, n_group, 1

        best_max_val = org_max_val.clone()
        min_errs = torch.ones_like(org_max_val) * 1e9
        input_feat = input_feat.to(w.device)
        org_out = (input_feat * w).sum(dim=-1)  # co, n_token, n_group

        for i_s in range(int(max_shrink * n_grid)):
            max_val = org_max_val * (1 - i_s / n_grid)
            min_val = -max_val
            cur_w = torch.clamp(w, min_val, max_val)
            q_w = pseudo_quantize_tensor(cur_w, w_bit=quant_config["w_bit"],
                                         q_group_size=quant_config["q_group_size"])
            cur_out = (input_feat * q_w).sum(dim=-1)

            # co, 1, n_group, 1
            err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
            del cur_w
            del cur_out
            cur_best_idx = err < min_errs
            min_errs[cur_best_idx] = err[cur_best_idx]
            best_max_val[cur_best_idx] = max_val[cur_best_idx]
        best_max_val_all.append(best_max_val)

    best_max_val = torch.cat(best_max_val_all, dim=0)

    del input_feat
    del org_out
    gc.collect()
    torch.cuda.empty_cache()
    return best_max_val.squeeze(1)
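As rough intuition for the grid search above: shrinking max_val clips outliers but also shrinks the quantization step, so the best clipping threshold is the one that minimizes output error rather than weight error. A tiny sketch of that trade-off, using a plain round-to-grid quantizer instead of the deleted helpers (the outlier value and grid sizes are illustrative):

    import torch

    torch.manual_seed(0)
    w = torch.randn(1, 128)
    w[0, 0] = 10.0                     # inject a single outlier weight
    x = torch.randn(512, 128)          # stand-in calibration activations
    org_out = x @ w.t()                # reference output

    def output_mse(max_val, n_levels=16):
        """Output MSE after clipping to [-max_val, max_val] and rounding to an n_levels grid."""
        w_c = w.clamp(-max_val, max_val)
        scale = max_val / (n_levels // 2 - 1)
        w_q = (w_c / scale).round().clamp(-(n_levels // 2), n_levels // 2 - 1) * scale
        return ((x @ w_q.t()) - org_out).pow(2).mean().item()

    for shrink in (1.0, 0.75, 0.5, 0.25):
        max_val = (w.abs().max() * shrink).item()
        print(f"shrink={shrink:.2f}  output MSE={output_mse(max_val):.4f}")
    # Moderate clipping can beat no clipping when a single outlier inflates the range.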
@torch.no_grad()
def auto_clip_block(module, quant_config, input_feat):
    named_linears = {name: m for name, m in module.named_modules() if isinstance(m, nn.Linear)}

    clip_list = []
    for name in named_linears:
        # due to qk bmm, it is hard to clip precisely
        if any([_ in name for _ in ["q_", "k_", "query", "key", "Wqkv"]]):
            continue
        named_linears[name].cuda()
        max_val = auto_clip_layer(named_linears[name].weight, input_feat[name],
                                  quant_config=quant_config)
        clip_list.append((name, max_val))
        named_linears[name].cpu()
    return clip_list
@torch.no_grad()
def apply_clip(module, clip_list):
    from ..utils.module import get_op_by_name
    for name, max_val in clip_list:
        layer = get_op_by_name(module, name)
        layer.cuda()
        max_val = max_val.to(layer.weight.device)
        org_shape = layer.weight.shape
        layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
        layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val)
        layer.weight.data = layer.weight.data.reshape(org_shape)
        layer.cpu()
awq/quantize/auto_scale.py (deleted, 100644 → 0; view file @ 724bda58)
import gc
import torch
import torch.nn as nn
import logging

from transformers.models.bloom.modeling_bloom import BloomBlock, BloomGelu
from transformers.models.opt.modeling_opt import OPTDecoderLayer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
from transformers.activations import NewGELUActivation
from awq.modules.act import ScaledActivation
from awq.utils.module import get_op_by_name, get_op_name, set_op_by_name

__all__ = ["auto_scale_block", "apply_scale"]


@torch.no_grad()
def get_weight_scale(weight, q_group_size=-1):
    org_shape = weight.shape
    if q_group_size > 0:
        weight = weight.view(-1, q_group_size)
    scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
    scale = scale.view(org_shape)
    scale = scale.mean(0)
    return scale


@torch.no_grad()
def get_act_scale(x):
    return x.abs().view(-1, x.shape[-1]).mean(0)
@torch.no_grad()
def scale_ln_fcs(ln, fcs, scales):
    if not isinstance(fcs, list):
        fcs = [fcs]

    scales = scales.to(ln.weight.device)

    # debugging start even scales = 1 does not work?
    """
    scales = scales * 0
    scales = scales + 1
    """
    # debugging end

    ln.weight.div_(scales)
    if hasattr(ln, 'bias') and ln.bias is not None:
        ln.bias.div_(scales)

    for fc in fcs:
        fc.weight.mul_(scales.view(1, -1))

    for p in ln.parameters():
        assert torch.isnan(p).sum() == 0
    for fc in fcs:
        for p in fc.parameters():
            assert torch.isnan(p).sum() == 0
@torch.no_grad()
def scale_fc_fc(fc1, fc2, scales):
    assert isinstance(fc1, nn.Linear)
    assert isinstance(fc2, nn.Linear)
    # assert fc1.out_features == fc2.in_features

    scales = scales.to(fc1.weight.device)

    # fc1.weight.div_(scales.view(-1, 1))
    fc1.weight[-scales.size(0):].div_(scales.view(-1, 1))
    if fc1.bias is not None:
        fc1.bias.div_(scales.view(-1))

    fc2.weight.mul_(scales.view(1, -1))

    for p in fc1.parameters():
        assert torch.isnan(p).sum() == 0
    for p in fc2.parameters():
        assert torch.isnan(p).sum() == 0
@torch.no_grad()
def scale_gelu_fc(gelu, fc, scales):
    assert any(isinstance(gelu, t) for t in [nn.GELU, BloomGelu, NewGELUActivation])
    assert isinstance(fc, nn.Linear)

    fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))

    for p in fc.parameters():
        assert torch.isnan(p).sum() == 0
def pseudo_quantize_tensor(w, w_bit=4, zero_point=True, q_group_size=-1,
                           inplace=False, get_scale_zp=False):
    org_w_shape = w.shape
    if q_group_size > 0:
        assert org_w_shape[-1] % q_group_size == 0
        w = w.reshape(-1, q_group_size)
    assert w.dim() == 2
    if zero_point:
        max_val = w.amax(dim=1, keepdim=True)
        min_val = w.amin(dim=1, keepdim=True)
        max_int = 2 ** w_bit - 1
        min_int = 0
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
    else:  # we actually never used this
        assert min_val is None
        max_val = w.abs().amax(dim=1, keepdim=True)
        max_val = max_val.clamp(min=1e-5)
        max_int = 2 ** (w_bit - 1) - 1
        min_int = -2 ** (w_bit - 1)
        scales = max_val / max_int
        zeros = 0

    assert torch.isnan(scales).sum() == 0
    assert torch.isnan(w).sum() == 0

    if inplace:
        ((w.div_(scales).round_().add_(zeros)).clamp_(min_int, max_int).sub_(zeros)).mul_(scales)
    else:
        w = (torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros) * scales
    assert torch.isnan(w).sum() == 0

    w = w.reshape(org_w_shape)

    if get_scale_zp:
        return w, scales.view(w.shape[0], -1), zeros.view(w.shape[0], -1)
    else:
        return w
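The zero-point branch above reduces to a few lines of arithmetic on a single group. The snippet below inlines those same formulas (it does not import anything from this deleted module) and prints the resulting scale, zero point, integer codes, and de-quantized weights for a toy group:

    import torch

    w = torch.tensor([[-0.2, 0.1, 0.5]])       # one group of three weights
    w_bit = 4
    max_int, min_int = 2 ** w_bit - 1, 0        # 15 and 0

    max_val = w.amax(dim=1, keepdim=True)
    min_val = w.amin(dim=1, keepdim=True)
    scales = (max_val - min_val).clamp(min=1e-5) / max_int             # (0.5 - (-0.2)) / 15
    zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)  # round(0.2 / scale) = 4

    q = torch.clamp(torch.round(w / scales) + zeros, min_int, max_int)  # integer codes in [0, 15]
    w_dq = (q - zeros) * scales                                         # de-quantized weights

    print(scales, zeros)   # roughly 0.0467 and 4
    print(q)               # tensor([[ 0.,  6., 15.]])
    print(w_dq)            # roughly [-0.187, 0.093, 0.513]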
@torch.no_grad()
def auto_scale_block(awq_model, module, module_kwargs, quant_config, input_feat):
    # from .quantizer import pseudo_quantize_tensor
    # firstly, get the weight quantize function
    if quant_config['w_bit'] is not None:
        def w_quantize_func(p):
            return pseudo_quantize_tensor(p, w_bit=quant_config["w_bit"],
                                          q_group_size=quant_config["q_group_size"]).detach()
    else:
        def w_quantize_func(p):
            return p

    if "use_cache" in module_kwargs:
        module_kwargs.pop("use_cache")

    # find the best scale ratio
    def _search_module_scale(module2inspect, layers: list, inp, kwargs={}):
        if module2inspect is None:
            assert len(layers) == 1
            module2inspect = layers[0]

        # w: co, ci
        # x: n, ci
        weight = torch.cat([_m.weight for _m in layers], dim=0)
        org_shape = weight.shape
        weight = weight.view(-1, quant_config.get("q_group_size"))
        w_scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
        w_scale = w_scale.view(org_shape)
        w_max = w_scale.mean(0)

        # Clear GPU memory
        del weight
        gc.collect()
        torch.cuda.empty_cache()

        inp = inp.to(next(module2inspect.parameters()).device)
        with torch.no_grad():
            org_out = module2inspect(inp, **kwargs)
            if isinstance(org_out, tuple):
                org_out = org_out[0]

        x_max = get_act_scale(inp)

        best_error = float('inf')
        best_ratio = -1
        best_scales = None

        n_grid = 20
        history = []

        org_sd = {k: v.cpu() for k, v in module2inspect.state_dict().items()}
        for ratio in range(n_grid):
            ratio = ratio * 1 / n_grid
            scales = (x_max.pow(ratio) / w_max.pow(1 - ratio)).clamp(min=1e-4).view(-1)
            scales = scales / (scales.max() * scales.min()).sqrt()
            for fc in layers:
                fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
                fc.weight.data = w_quantize_func(fc.weight.data) / (scales.view(1, -1))
            out = module2inspect(inp, **kwargs)
            if isinstance(out, tuple):
                out = out[0]

            loss = (org_out - out).float().pow(2).mean().item()  # float prevents overflow
            history.append(loss)
            is_best = loss < best_error
            if is_best:
                best_error = loss
                best_ratio = ratio
                best_scales = scales
            module2inspect.load_state_dict(org_sd)
        if best_ratio == -1:
            logging.debug(history)
            raise Exception
        best_scales = best_scales.view(-1)

        assert torch.isnan(best_scales).sum() == 0, best_scales
        return best_scales.detach()

    def _auto_get_scale(prev_op, layers, inp, module2inspect=None, kwargs={}):
        scales = _search_module_scale(module2inspect, layers, inp, kwargs)
        scales = scales.detach().cpu()
        # prev_op_name, [layer_name], scale
        return (get_op_name(module, prev_op),
                tuple([get_op_name(module, m) for m in layers]), scales)

    layers: list[dict] = awq_model.get_layers_for_scaling(
        module, input_feat, module_kwargs
    )
    scales_list = [_auto_get_scale(**layer) for layer in layers]

    return scales_list
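The grid search in _search_module_scale varies only a single exponent: for each ratio it proposes per-channel scales x_max**ratio / w_max**(1 - ratio), then rescales them so the geometric mean of the largest and smallest entries is 1. A compact sketch of just that candidate generation, with toy statistics standing in for the measured x_max and w_max:

    import torch

    x_max = torch.tensor([4.0, 0.5, 2.0])   # per-channel activation magnitudes (toy values)
    w_max = torch.tensor([0.2, 0.8, 0.4])   # per-channel mean weight magnitudes (toy values)

    n_grid = 20
    for i in range(n_grid):
        ratio = i / n_grid                  # the exponent swept in [0, 1)
        scales = (x_max.pow(ratio) / w_max.pow(1 - ratio)).clamp(min=1e-4)
        scales = scales / (scales.max() * scales.min()).sqrt()   # normalize the spread
        # ratio=0 ignores activations entirely; larger ratios shift more of the
        # activation-outlier range into the (quantized) weights.
        if i in (0, 10, 19):
            print(f"ratio={ratio:.2f}  scales={scales}")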
def apply_scale(module, scales_list, input_feat_dict=None):
    for prev_op_name, layer_names, scales in scales_list:
        prev_op = get_op_by_name(module, prev_op_name)
        layers = [get_op_by_name(module, name) for name in layer_names]

        prev_op.cuda()
        for layer in layers:
            layer.cuda()
        scales.cuda()

        if isinstance(prev_op, nn.Linear):
            assert len(layers) == 1
            scale_fc_fc(prev_op, layers[0], scales)
        elif any(isinstance(prev_op, t) for t in [nn.LayerNorm, LlamaRMSNorm]) \
                or 'rmsnorm' in str(prev_op.__class__).lower():
            scale_ln_fcs(prev_op, layers, scales)
        elif any(isinstance(prev_op, t) for t in [nn.GELU, BloomGelu, NewGELUActivation]):
            new_module = ScaledActivation(prev_op, scales)
            set_op_by_name(module, prev_op_name, new_module)
            scale_gelu_fc(prev_op, layers[0], scales)
        else:
            raise NotImplementedError(
                f"prev_op {type(prev_op)} not supported yet!")

        # apply the scaling to input feat if given; prepare it for clipping
        if input_feat_dict is not None:
            for layer_name in layer_names:
                inp = input_feat_dict[layer_name]
                inp.div_(scales.view(1, -1).to(inp.device))

        prev_op.cpu()
        for layer in layers:
            layer.cpu()
        scales.cpu()
awq/quantize/quantizer.py (view file @ d3550fec)

...
@@ -69,9 +69,9 @@ class AwqQuantizer:
            scales_list = append_str_prefix(scales_list, get_op_name(self.model, self.modules[i]) + ".")

            # [STEP 3]: Compute and apply clipping list
            # clip_list = self._search_best_clip(self.modules[i], named_linears, input_feat)
            # apply_clip(self.modules[i], clip_list)
            # clip_list = append_str_prefix(clip_list, get_op_name(self.model, self.modules[i]) + ".")
            clip_list = self._search_best_clip(self.modules[i], named_linears, input_feat)
            apply_clip(self.modules[i], clip_list)
            clip_list = append_str_prefix(clip_list, get_op_name(self.model, self.modules[i]) + ".")

            # [STEP 4]: Quantize weights
            for name, linear_layer in named_linears.items():
...
@@ -211,6 +211,8 @@ class AwqQuantizer:
            clip_list.append((name, max_val))
            named_linears[name].cpu()
        return clip_list

    @torch.no_grad()
    def _compute_best_clip(self, w: torch.Tensor, input_feat: torch.Tensor,
                           n_grid=20,
                           max_shrink=0.5,
                           n_sample_token=512):
...