OpenDAS / AutoAWQ · Commits · d3550fec
"git@developer.sourcefind.cn:OpenDAS/ollama.git" did not exist on "6fd04ca922e5da7ef8c52d86118fc58b798a7e4a"
Commit d3550fec, authored Sep 20, 2023 by Casper Hansen (parent 724bda58)

    Remove old quantization code
Showing 4 changed files with 6 additions and 548 deletions (+6, -548)
awq/models/base.py           +1   -180
awq/quantize/auto_clip.py    +0   -98
awq/quantize/auto_scale.py   +0   -267
awq/quantize/quantizer.py    +5   -3
awq/models/base.py · view file @ d3550fec
...
@@ -2,25 +2,18 @@ import os
 import gc
 import json
 import torch
-import logging
-import functools
 import torch.nn as nn
 from tqdm import tqdm
 from typing import List, Union
-from collections import defaultdict
 from safetensors.torch import save_file
 from awq.modules.act import ScaledActivation
 from huggingface_hub import snapshot_download
 from awq.utils.utils import simple_dispatch_model
-from awq.utils.calib_data import get_calib_dataset
 from transformers.modeling_utils import shard_checkpoint
 from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
-from awq.quantize.auto_clip import auto_clip_block, apply_clip
+from awq.utils.module import get_named_linears, set_op_by_name
-from awq.quantize.auto_scale import auto_scale_block, apply_scale
 from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
 from accelerate import init_empty_weights, load_checkpoint_in_model, infer_auto_device_map
-from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name

 class BaseAWQForCausalLM(nn.Module):
     def __init__(self, model, model_type, is_quantized, quant_config):
...
@@ -55,183 +48,11 @@ class BaseAWQForCausalLM(nn.Module):
             quant_config["version"], calib_data, split, text_column
         )
         quantizer.quantize()

         self.is_quantized = True
-        # if run_search:
-        #     self.search_result = self._awq_search(
-        #         tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen,
-        #         auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data,
-        #         split=split, text_column=text_column
-        #     )
-        # if run_quant:
-        #     self._awq_quant()
-        #     self.is_quantized = True

     @staticmethod
     def fuse_layers(model, quant_config):
         pass

-    def _awq_quant(self):
-        assert self.quant_config["zero_point"], "We only support zero_point quantization now."
-        layers = self.get_model_layers(self.model)
-
-        # Run AWQ quantization
-        for i in tqdm(range(len(layers)), desc="AWQ Quantization"):
-            layer = layers[i]
-            named_linears = get_named_linears(layer)
-            self._scale_activations(self, layer)
-
-            for name, module in named_linears.items():
-                module.cuda()
-
-                module.weight.data, scales, zeros = pseudo_quantize_tensor(
-                    module.weight.data,
-                    get_scale_zp=True,
-                    w_bit=self.quant_config["w_bit"],
-                    q_group_size=self.quant_config["q_group_size"]
-                )
-
-                if self.quant_config["version"] == 'GEMM':
-                    scales = scales.t().contiguous()
-                    zeros = zeros.t().contiguous()
-                    q_linear_module = WQLinear_GEMM
-                elif self.quant_config["version"] == 'GEMV':
-                    q_linear_module = WQLinear_GEMV
-
-                q_linear = q_linear_module.from_linear(
-                    module,
-                    self.quant_config['w_bit'],
-                    self.quant_config['q_group_size'],
-                    False,
-                    scales,
-                    zeros
-                )
-
-                module.cpu()
-                q_linear.to(next(layer.parameters()).device)
-                set_op_by_name(layer, name, q_linear)
-                torch.cuda.empty_cache()
-                gc.collect()
-
-            torch.cuda.empty_cache()
-            gc.collect()
-
-    def _awq_search(self, tokenizer, quant_config, n_samples=128, seqlen=512,
-                    auto_scale=True, mse_range=True, calib_data: Union[str, List[str]] = "pileval",
-                    split="train", text_column="text"):
-        layers = self.get_model_layers(self.model)
-
-        samples = get_calib_dataset(
-            data=calib_data, tokenizer=tokenizer, n_samples=n_samples,
-            block_size=seqlen, split=split, text_column=text_column
-        )
-        samples = torch.cat(samples, dim=0)
-
-        inps = []
-        layer_kwargs = {}
-
-        layers[0] = layers[0].cuda()
-        self.move_embed(self.model, "cuda")
-
-        # get input and kwargs to layer 0
-        # with_kwargs is only supported in PyTorch 2.0
-        # use this Catcher hack for now
-        class Catcher(nn.Module):
-            def __init__(self, module):
-                super().__init__()
-                self.module = module
-
-            def forward(self, hijacked_inputs, **kwargs):
-                inps.append(hijacked_inputs)
-                layer_kwargs.update(kwargs)
-                raise ValueError  # early exit to break later inference
-
-        # patch layer 0 to catch input and kwargs
-        layers[0] = Catcher(layers[0])
-        try:
-            self.model(samples.to(next(self.model.parameters()).device))
-        except ValueError:  # work with early exit
-            pass
-        del samples
-        layers[0] = layers[0].module  # restore
-        inps = inps[0]
-
-        layers[0] = layers[0].cpu()
-        self.move_embed(self.model, "cpu")
-
-        gc.collect()
-        torch.cuda.empty_cache()
-
-        awq_results = {
-            "scale": [],
-            "clip": [],
-        }
-
-        # Run AWQ search layer by layer
-        for i in tqdm(range(len(layers)), desc="AWQ Search"):
-            layer = layers[i]
-            layer = layer.cuda()
-            named_linears = get_named_linears(layer)
-
-            # firstly, get input features of all linear layers
-            def cache_input_hook(m, x, y, name, feat_dict):
-                x = x[0]
-                x = x.detach().cpu()
-                feat_dict[name].append(x)
-
-            input_feat = defaultdict(list)
-            handles = []
-            for name in named_linears:
-                handles.append(named_linears[name].register_forward_hook(
-                    functools.partial(cache_input_hook, name=name, feat_dict=input_feat)))
-            inps = inps.to(next(layer.parameters()).device)  # in case multi-gpu
-            # get output as next layer's input
-            inps = layer(inps, **layer_kwargs)[0]
-            for h in handles:
-                h.remove()
-            # now solve for scaling and clipping
-            input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}
-
-            # Clear GPU memory
-            torch.cuda.empty_cache()
-
-            if auto_scale:  # if it applies, we should also modify the input_feat with scales
-                scales_list = auto_scale_block(
-                    self,
-                    layer,
-                    layer_kwargs,
-                    quant_config=quant_config,
-                    input_feat=input_feat,
-                )
-
-                apply_scale(layers[i], scales_list, input_feat_dict=input_feat)
-
-                # append prefix to make names global
-                awq_results["scale"] += append_str_prefix(scales_list, get_op_name(self.model, layer) + ".")
-
-            # Clear GPU memory
-            torch.cuda.empty_cache()
-
-            if mse_range:
-                clip_list = auto_clip_block(layer, quant_config=quant_config, input_feat=input_feat)
-                apply_clip(layer, clip_list)
-
-                # append prefix to make names global
-                awq_results["clip"] += append_str_prefix(clip_list, get_op_name(self.model, layer) + ".")
-
-            layer = layer.cpu()
-            # Haotian: check activation replacement
-            del input_feat
-            gc.collect()
-            torch.cuda.empty_cache()
-
-        return awq_results

     def save_quantized(self, save_dir, safetensors=False, shard_size="10GB"):
         def _save_files(save_dir, model_name='', search_result=None):
...
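The removed `_awq_search` captured the inputs to the first decoder block with the `Catcher` hack shown above: patch block 0, run one calibration batch, and abort the forward pass once the inputs and kwargs are recorded. The sketch below is a minimal, self-contained illustration of that pattern; the toy model is hypothetical, and only the patch / raise / restore flow mirrors the deleted code.

    import torch
    import torch.nn as nn

    # Hypothetical stand-in for a transformer: embedding followed by two "blocks".
    toy = nn.Sequential(nn.Embedding(100, 16), nn.Linear(16, 16), nn.Linear(16, 4))

    captured_inputs, captured_kwargs = [], {}

    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, hidden_states, **kwargs):
            captured_inputs.append(hidden_states)   # record what reaches the first block
            captured_kwargs.update(kwargs)          # e.g. attention mask, position ids
            raise ValueError                        # early exit: only the inputs are needed

    toy[1] = Catcher(toy[1])                        # patch the first "layer"
    try:
        toy(torch.randint(0, 100, (2, 8)))          # one calibration batch
    except ValueError:
        pass                                        # expected: the Catcher aborted the pass
    toy[1] = toy[1].module                          # restore the original layer

    print(captured_inputs[0].shape)                 # hidden states now usable for the search

The same capture is what the new AwqQuantizer path has to provide before scaling and clipping can be searched per layer.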
awq/quantize/auto_clip.py · deleted (100644 → 0) · view file @ 724bda58
import torch
import torch.nn as nn
import gc

__all__ = ["auto_clip_block"]


# weight quantization
@torch.no_grad()
def auto_clip_layer(w, input_feat, quant_config,
                    n_grid=20, max_shrink=0.5, n_sample_token=512):
    assert w.dim() == 2
    org_w_shape = w.shape
    # w           [co, ci]      -> [co, 1, n_group, group size]
    # input_feat  [n_token, ci] -> [1, n_token, n_group, group size]
    group_size = quant_config["q_group_size"] if quant_config["q_group_size"] > 0 else w.shape[1]
    input_feat = input_feat.view(-1, input_feat.shape[-1])
    input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
    input_feat = input_feat[:, 0::input_feat.shape[1] // n_sample_token]
    w = w.reshape(w.shape[0], 1, -1, group_size)

    oc_batch_size = 256 if w.shape[0] % 256 == 0 else 64  # prevent OOM
    assert w.shape[0] % oc_batch_size == 0
    w_all = w
    best_max_val_all = []

    for i_b in range(w.shape[0] // oc_batch_size):
        w = w_all[i_b * oc_batch_size: (i_b + 1) * oc_batch_size]

        org_max_val = w.abs().amax(dim=-1, keepdim=True)  # co, 1, n_group, 1

        best_max_val = org_max_val.clone()
        min_errs = torch.ones_like(org_max_val) * 1e9
        input_feat = input_feat.to(w.device)
        org_out = (input_feat * w).sum(dim=-1)  # co, n_token, n_group

        for i_s in range(int(max_shrink * n_grid)):
            max_val = org_max_val * (1 - i_s / n_grid)
            min_val = -max_val
            cur_w = torch.clamp(w, min_val, max_val)
            q_w = pseudo_quantize_tensor(cur_w, w_bit=quant_config["w_bit"],
                                         q_group_size=quant_config["q_group_size"])
            cur_out = (input_feat * q_w).sum(dim=-1)

            # co, 1, n_group, 1
            err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
            del cur_w
            del cur_out
            cur_best_idx = err < min_errs
            min_errs[cur_best_idx] = err[cur_best_idx]
            best_max_val[cur_best_idx] = max_val[cur_best_idx]
        best_max_val_all.append(best_max_val)

    best_max_val = torch.cat(best_max_val_all, dim=0)

    del input_feat
    del org_out
    gc.collect()
    torch.cuda.empty_cache()
    return best_max_val.squeeze(1)


@torch.no_grad()
def auto_clip_block(module, quant_config, input_feat):
    named_linears = {name: m for name, m in module.named_modules() if isinstance(m, nn.Linear)}

    clip_list = []
    for name in named_linears:
        # due to qk bmm, it is hard to clip precisely
        if any([_ in name for _ in ["q_", "k_", "query", "key", "Wqkv"]]):
            continue
        named_linears[name].cuda()
        max_val = auto_clip_layer(
            named_linears[name].weight, input_feat[name], quant_config=quant_config)
        clip_list.append((name, max_val))
        named_linears[name].cpu()
    return clip_list


@torch.no_grad()
def apply_clip(module, clip_list):
    from ..utils.module import get_op_by_name
    for name, max_val in clip_list:
        layer = get_op_by_name(module, name)
        layer.cuda()
        max_val = max_val.to(layer.weight.device)
        org_shape = layer.weight.shape
        layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
        layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val)
        layer.weight.data = layer.weight.data.reshape(org_shape)
        layer.cpu()
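For reference, the search that `auto_clip_layer` performed can be illustrated on a single weight group: shrink the per-row max |w|, clamp, fake-quantize, and keep whichever clip value minimizes the output MSE against the unclipped weights. The sketch below substitutes a simplified symmetric 4-bit fake quantizer for the deleted `pseudo_quantize_tensor`, so treat it as an illustration of the idea rather than the removed implementation.

    import torch

    def fake_quant_sym(w, n_bits=4):
        # simplified symmetric per-row fake quantization (round trip through the integer grid)
        max_int = 2 ** (n_bits - 1) - 1
        scales = w.abs().amax(dim=-1, keepdim=True).clamp(min=1e-5) / max_int
        return (w / scales).round().clamp(-max_int - 1, max_int) * scales

    torch.manual_seed(0)
    w = torch.randn(8, 64)            # one weight group per row: [out_features, group_size]
    x = torch.randn(512, 64)          # calibration activations for this group

    org_out = x @ w.t()               # reference output with full-precision weights
    org_max = w.abs().amax(dim=-1, keepdim=True)

    best_max, best_err = org_max.clone(), torch.full_like(org_max, float("inf"))
    for i_s in range(10):             # grid of shrink ratios, as in the deleted code
        max_val = org_max * (1 - i_s / 20)
        q_w = fake_quant_sym(w.clamp(-max_val, max_val))
        err = (x @ q_w.t() - org_out).pow(2).mean(dim=0, keepdim=True).t()
        better = err < best_err
        best_err[better], best_max[better] = err[better], max_val[better]

    print(best_max.squeeze() / org_max.squeeze())   # learned shrink factor per output row

The resulting per-row clip values correspond to the `(name, max_val)` pairs that `apply_clip` wrote back into the weights.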
awq/quantize/auto_scale.py · deleted (100644 → 0) · view file @ 724bda58
import gc
import torch
import torch.nn as nn
import logging

from transformers.models.bloom.modeling_bloom import BloomBlock, BloomGelu
from transformers.models.opt.modeling_opt import OPTDecoderLayer
from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm
from transformers.activations import NewGELUActivation

from awq.modules.act import ScaledActivation
from awq.utils.module import get_op_by_name, get_op_name, set_op_by_name

__all__ = ["auto_scale_block", "apply_scale"]


@torch.no_grad()
def get_weight_scale(weight, q_group_size=-1):
    org_shape = weight.shape
    if q_group_size > 0:
        weight = weight.view(-1, q_group_size)
    scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
    scale = scale.view(org_shape)
    scale = scale.mean(0)
    return scale


@torch.no_grad()
def get_act_scale(x):
    return x.abs().view(-1, x.shape[-1]).mean(0)


@torch.no_grad()
def scale_ln_fcs(ln, fcs, scales):
    if not isinstance(fcs, list):
        fcs = [fcs]

    scales = scales.to(ln.weight.device)

    # debugging start even scales = 1 does not work?
    """
    scales = scales * 0
    scales = scales + 1
    """
    # debugging end

    ln.weight.div_(scales)
    if hasattr(ln, 'bias') and ln.bias is not None:
        ln.bias.div_(scales)

    for fc in fcs:
        fc.weight.mul_(scales.view(1, -1))

    for p in ln.parameters():
        assert torch.isnan(p).sum() == 0
    for fc in fcs:
        for p in fc.parameters():
            assert torch.isnan(p).sum() == 0


@torch.no_grad()
def scale_fc_fc(fc1, fc2, scales):
    assert isinstance(fc1, nn.Linear)
    assert isinstance(fc2, nn.Linear)
    # assert fc1.out_features == fc2.in_features

    scales = scales.to(fc1.weight.device)

    # fc1.weight.div_(scales.view(-1, 1))
    fc1.weight[-scales.size(0):].div_(scales.view(-1, 1))
    if fc1.bias is not None:
        fc1.bias.div_(scales.view(-1))

    fc2.weight.mul_(scales.view(1, -1))

    for p in fc1.parameters():
        assert torch.isnan(p).sum() == 0
    for p in fc2.parameters():
        assert torch.isnan(p).sum() == 0


@torch.no_grad()
def scale_gelu_fc(gelu, fc, scales):
    assert any(isinstance(gelu, t) for t in [nn.GELU, BloomGelu, NewGELUActivation])
    assert isinstance(fc, nn.Linear)

    fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))

    for p in fc.parameters():
        assert torch.isnan(p).sum() == 0


def pseudo_quantize_tensor(w, w_bit=4, zero_point=True, q_group_size=-1,
                           inplace=False, get_scale_zp=False):
    org_w_shape = w.shape
    if q_group_size > 0:
        assert org_w_shape[-1] % q_group_size == 0
        w = w.reshape(-1, q_group_size)
    assert w.dim() == 2
    if zero_point:
        max_val = w.amax(dim=1, keepdim=True)
        min_val = w.amin(dim=1, keepdim=True)
        max_int = 2 ** w_bit - 1
        min_int = 0
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
    else:
        # we actually never used this
        assert min_val is None
        max_val = w.abs().amax(dim=1, keepdim=True)
        max_val = max_val.clamp(min=1e-5)
        max_int = 2 ** (w_bit - 1) - 1
        min_int = -2 ** (w_bit - 1)
        scales = max_val / max_int
        zeros = 0

    assert torch.isnan(scales).sum() == 0
    assert torch.isnan(w).sum() == 0

    if inplace:
        ((w.div_(scales).round_().add_(zeros)).clamp_(min_int, max_int).sub_(zeros)).mul_(scales)
    else:
        w = (torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros) * scales

    assert torch.isnan(w).sum() == 0

    w = w.reshape(org_w_shape)

    if get_scale_zp:
        return w, scales.view(w.shape[0], -1), zeros.view(w.shape[0], -1)
    else:
        return w


@torch.no_grad()
def auto_scale_block(awq_model, module, module_kwargs, quant_config, input_feat):
    # from .quantizer import pseudo_quantize_tensor
    # firstly, get the weight quantize function
    if quant_config['w_bit'] is not None:
        def w_quantize_func(p):
            return pseudo_quantize_tensor(
                p, w_bit=quant_config["w_bit"], q_group_size=quant_config["q_group_size"]).detach()
    else:
        def w_quantize_func(p):
            return p

    if "use_cache" in module_kwargs:
        module_kwargs.pop("use_cache")

    # find the best scale ratio
    def _search_module_scale(module2inspect, layers: list, inp, kwargs={}):
        if module2inspect is None:
            assert len(layers) == 1
            module2inspect = layers[0]

        # w: co, ci
        # x: n, ci
        weight = torch.cat([_m.weight for _m in layers], dim=0)
        org_shape = weight.shape
        weight = weight.view(-1, quant_config.get("q_group_size"))
        w_scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
        w_scale = w_scale.view(org_shape)
        w_max = w_scale.mean(0)

        # Clear GPU memory
        del weight
        gc.collect()
        torch.cuda.empty_cache()

        inp = inp.to(next(module2inspect.parameters()).device)
        with torch.no_grad():
            org_out = module2inspect(inp, **kwargs)
            if isinstance(org_out, tuple):
                org_out = org_out[0]

        x_max = get_act_scale(inp)

        best_error = float('inf')
        best_ratio = -1
        best_scales = None

        n_grid = 20
        history = []

        org_sd = {k: v.cpu() for k, v in module2inspect.state_dict().items()}
        for ratio in range(n_grid):
            ratio = ratio * 1 / n_grid
            scales = (x_max.pow(ratio) / w_max.pow(1 - ratio)).clamp(min=1e-4).view(-1)
            scales = scales / (scales.max() * scales.min()).sqrt()
            for fc in layers:
                fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))
                fc.weight.data = w_quantize_func(fc.weight.data) / (scales.view(1, -1))
            out = module2inspect(inp, **kwargs)
            if isinstance(out, tuple):
                out = out[0]

            loss = (org_out - out).float().pow(2).mean().item()  # float prevents overflow
            history.append(loss)
            is_best = loss < best_error
            if is_best:
                best_error = loss
                best_ratio = ratio
                best_scales = scales
            module2inspect.load_state_dict(org_sd)
        if best_ratio == -1:
            logging.debug(history)
            raise Exception
        best_scales = best_scales.view(-1)

        assert torch.isnan(best_scales).sum() == 0, best_scales
        return best_scales.detach()

    def _auto_get_scale(prev_op, layers, inp, module2inspect=None, kwargs={}):
        scales = _search_module_scale(module2inspect, layers, inp, kwargs)
        scales = scales.detach().cpu()
        # prev_op_name, [layer_name], scale
        return (get_op_name(module, prev_op),
                tuple([get_op_name(module, m) for m in layers]), scales)

    layers: list[dict] = awq_model.get_layers_for_scaling(
        module, input_feat, module_kwargs
    )
    scales_list = [_auto_get_scale(**layer) for layer in layers]

    return scales_list


def apply_scale(module, scales_list, input_feat_dict=None):
    for prev_op_name, layer_names, scales in scales_list:
        prev_op = get_op_by_name(module, prev_op_name)
        layers = [get_op_by_name(module, name) for name in layer_names]

        prev_op.cuda()
        for layer in layers:
            layer.cuda()
        scales.cuda()

        if isinstance(prev_op, nn.Linear):
            assert len(layers) == 1
            scale_fc_fc(prev_op, layers[0], scales)
        elif any(isinstance(prev_op, t) for t in [nn.LayerNorm, LlamaRMSNorm]) \
                or 'rmsnorm' in str(prev_op.__class__).lower():
            scale_ln_fcs(prev_op, layers, scales)
        elif any(isinstance(prev_op, t) for t in [nn.GELU, BloomGelu, NewGELUActivation]):
            new_module = ScaledActivation(prev_op, scales)
            set_op_by_name(module, prev_op_name, new_module)
            scale_gelu_fc(prev_op, layers[0], scales)
        else:
            raise NotImplementedError(f"prev_op {type(prev_op)} not supported yet!")

        # apply the scaling to input feat if given; prepare it for clipping
        if input_feat_dict is not None:
            for layer_name in layer_names:
                inp = input_feat_dict[layer_name]
                inp.div_(scales.view(1, -1).to(inp.device))

        prev_op.cpu()
        for layer in layers:
            layer.cpu()
        scales.cpu()
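For orientation, the zero-point branch of `pseudo_quantize_tensor` above is the fake-quantization step the scale and clip searches both optimize against. The snippet below reuses that same scale/zero-point arithmetic on a toy tensor; the shapes and group size are arbitrary choices for illustration, not values taken from the repository.

    import torch

    def pseudo_quantize(w, w_bit=4, q_group_size=128):
        # same grouped, asymmetric round trip as the deleted function's zero_point branch
        org_shape = w.shape
        if q_group_size > 0:
            assert org_shape[-1] % q_group_size == 0
            w = w.reshape(-1, q_group_size)
        max_val = w.amax(dim=1, keepdim=True)
        min_val = w.amin(dim=1, keepdim=True)
        max_int = 2 ** w_bit - 1
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(0, max_int)
        w = (torch.clamp(torch.round(w / scales) + zeros, 0, max_int) - zeros) * scales
        return w.reshape(org_shape), scales, zeros

    torch.manual_seed(0)
    w = torch.randn(4, 256)
    w_q, scales, zeros = pseudo_quantize(w, w_bit=4, q_group_size=128)
    print((w - w_q).abs().max())   # reconstruction error, roughly bounded by scales / 2 per group

The `scales` and `zeros` returned here are the same quantities that `WQLinear_GEMM.from_linear` packs alongside the integer weights.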
awq/quantize/quantizer.py · view file @ d3550fec
...
@@ -69,9 +69,9 @@ class AwqQuantizer:
             scales_list = append_str_prefix(scales_list, get_op_name(self.model, self.modules[i]) + ".")

             # [STEP 3]: Compute and apply clipping list
-            # clip_list = self._search_best_clip(self.modules[i], named_linears, input_feat)
-            # apply_clip(self.modules[i], clip_list)
-            # clip_list = append_str_prefix(clip_list, get_op_name(self.model, self.modules[i]) + ".")
+            clip_list = self._search_best_clip(self.modules[i], named_linears, input_feat)
+            apply_clip(self.modules[i], clip_list)
+            clip_list = append_str_prefix(clip_list, get_op_name(self.model, self.modules[i]) + ".")

             # [STEP 4]: Quantize weights
             for name, linear_layer in named_linears.items():
...
@@ -211,6 +211,8 @@ class AwqQuantizer:
             clip_list.append((name, max_val))
             named_linears[name].cpu()
+        return clip_list
+
     @torch.no_grad()
     def _compute_best_clip(self, w: torch.Tensor, input_feat: torch.Tensor,
                            n_grid=20, max_shrink=0.5, n_sample_token=512):
...