OpenDAS / AutoAWQ · Commit d35ade75
Authored Aug 19, 2023 by Casper Hansen
Rename q_config -> quant_config. Include w_bit in quant_config. Save quant_config.json.
Parent: 6f30f051
Showing 6 changed files with 66 additions and 52 deletions (+66 −52).
awq/entry.py               +10 −10
awq/models/auto.py          +4  −4
awq/models/base.py         +33 −21
awq/quantize/auto_clip.py   +7  −5
awq/quantize/auto_scale.py  +6  −7
awq/quantize/quantizer.py   +6  −5
awq/entry.py

@@ -15,7 +15,7 @@ def load_search_result_into_memory(model, search_path):
     apply_scale(model, awq_results["scale"])
     apply_clip(model, awq_results["clip"])

-def run_search(model_path, dump_path, w_bit, q_config):
+def run_search(model_path, dump_path, quant_config):
     """
     Step 1/2: Search the pile for an optimal scaling factor.
     """

@@ -24,7 +24,7 @@ def run_search(model_path, dump_path, w_bit, q_config):
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

     # Quantize
-    model.quantize(tokenizer, w_bit=w_bit, q_config=q_config, run_search=True, run_quant=False)
+    model.quantize(tokenizer, quant_config=quant_config, run_search=True, run_quant=False)

     # Save search results
     model.save_quantized(dump_path)

@@ -32,7 +32,7 @@ def run_search(model_path, dump_path, w_bit, q_config):
     # Save tokenizer
     tokenizer.save_pretrained(dump_path)

-def run_quant(model_path, search_path, dump_path, w_bit, q_config):
+def run_quant(model_path, search_path, dump_path, quant_config):
     """
     Step 2/2: Use the search results to quantize model weights
     """

@@ -41,17 +41,17 @@ def run_quant(model_path, search_path, dump_path, w_bit, q_config):
     load_search_result_into_memory(model.model, search_path)

     # Run actual weight quantization
-    model.quantize(w_bit=w_bit, q_config=q_config, run_search=False, run_quant=True)
+    model.quantize(quant_config=quant_config, run_search=False, run_quant=True)

     # Save quantized model
     model.save_quantized(dump_path)

-def run_perplexity(quant_path, quant_file, w_bit, q_config, device):
+def run_perplexity(quant_path, quant_file, quant_config, device):
     """
     Post quantization: Evaluate perplexity on wikitext with EleutherAI Evaluation Harness
     """
     # Load model
-    model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file, w_bit, q_config)
+    model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file, quant_config)
     tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)

     # Load adapter

@@ -85,13 +85,13 @@ if __name__ == '__main__':
     parser.add_argument('--q_group_size', type=int, default=128)
     args = parser.parse_args()

-    q_config = {"zero_point": True, "q_group_size": args.q_group_size}
+    quant_config = {"zero_point": True, "q_group_size": args.q_group_size, "w_bit": args.w_bit}

     if args.entry_type == 'search':
-        run_search(args.model_path, args.search_path, args.w_bit, q_config)
+        run_search(args.model_path, args.search_path, quant_config)
     elif args.entry_type == 'quant':
-        run_quant(args.model_path, args.search_path, args.quant_path, args.w_bit, q_config)
+        run_quant(args.model_path, args.search_path, args.quant_path, quant_config)
     elif args.entry_type == 'perplexity':
-        run_perplexity(args.quant_path, args.quant_file, args.w_bit, q_config, args.device)
+        run_perplexity(args.quant_path, args.quant_file, args.w_bit, quant_config, args.device)
     else:
         raise Exception('--entry_type must be one of (search|quant|perplexity)')
\ No newline at end of file
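Taken together, the entry points are now driven by a single dict end to end. A rough usage sketch with placeholder paths (in entry.py the real values come from argparse, and the functions live in awq.entry):

    from awq.entry import run_search, run_quant

    # Placeholder paths; entry.py fills these from argparse flags.
    model_path = "facebook/opt-125m"
    search_path = "awq_search_results.pt"
    quant_path = "opt-125m-awq"

    quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}

    run_search(model_path, search_path, quant_config)              # step 1/2: search scaling factors
    run_quant(model_path, search_path, quant_path, quant_config)   # step 2/2: quantize weights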
awq/models/auto.py

@@ -16,7 +16,7 @@ def check_and_get_model_type(model_dir, trust_remote_code=True):
     return model_type

 class AutoAWQForCausalLM:
-    default_q_config = {"zero_point": True, "q_group_size": 128}
+    default_quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}

     def __init__(self):
         raise EnvironmentError('You must instantiate AutoAWQForCausalLM with\n'

@@ -31,11 +31,11 @@ class AutoAWQForCausalLM:
         )

     @classmethod
-    def from_quantized(self, quant_path, quant_filename, w_bit=4, q_config={},
+    def from_quantized(self, quant_path, quant_filename, quant_config={},
                              device='balanced', trust_remote_code=True) -> BaseAWQForCausalLM:
         model_type = check_and_get_model_type(quant_path, trust_remote_code)
-        q_config = q_config if q_config else self.default_q_config
+        quant_config = quant_config if quant_config else self.default_quant_config

         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
-            quant_path, model_type, quant_filename, w_bit, q_config, device, trust_remote_code=trust_remote_code
+            quant_path, model_type, quant_filename, quant_config, device, trust_remote_code=trust_remote_code
         )
\ No newline at end of file
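On the loading side, callers drop the separate w_bit argument as well; an omitted quant_config now falls back to default_quant_config. A minimal sketch, with the checkpoint path and filename as placeholders:

    from awq.models.auto import AutoAWQForCausalLM

    quant_path = "path/to/quantized-model"   # placeholder checkpoint directory
    quant_file = "awq_model_w4_g128.pt"      # placeholder weight filename

    # Falls back to default_quant_config ({"zero_point": True, "q_group_size": 128, "w_bit": 4}).
    model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file)

    # Or pass the merged config explicitly, mirroring the new signature above.
    model = AutoAWQForCausalLM.from_quantized(
        quant_path, quant_file,
        quant_config={"zero_point": True, "q_group_size": 128, "w_bit": 4},
    )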
awq/models/base.py

 import os
 import gc
+import json
 import torch
 import functools
 import accelerate

@@ -18,11 +19,12 @@ from accelerate import init_empty_weights, load_checkpoint_and_dispatch, infer_a
 from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name

 class BaseAWQForCausalLM:
-    def __init__(self, model, model_type, is_quantized):
+    def __init__(self, model, model_type, is_quantized, quant_config):
         self.model: PreTrainedModel = model
         self.model_type: str = model_type
         self.is_quantized: bool = is_quantized
         self.search_result = None
+        self.quant_config: dict = quant_config

     def to(self, device: str):
         return self.model.to(device)

@@ -31,20 +33,21 @@ class BaseAWQForCausalLM:
         return self.model(*args, **kwargs)

     @torch.no_grad()
-    def quantize(self, tokenizer=None, w_bit=4, q_config={}, n_samples=128, seqlen=512,
+    def quantize(self, tokenizer=None, quant_config={}, n_samples=128, seqlen=512,
                        auto_scale=True, mse_range=True, run_search=False, run_quant=True,
                        calib_data="pileval"):
+        self.quant_config = quant_config

         if run_search:
-            self.search_result = self._awq_search(tokenizer, w_bit, q_config, n_samples=n_samples, seqlen=seqlen,
+            self.search_result = self._awq_search(tokenizer, quant_config, n_samples=n_samples, seqlen=seqlen,
                         auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data)

         if run_quant:
-            self._awq_quant(w_bit, q_config)
+            self._awq_quant(quant_config)

-    def _awq_quant(self, w_bit, q_config):
-        assert q_config["zero_point"], "We only support zero_point quantization now."
+    def _awq_quant(self, quant_config):
+        assert quant_config["zero_point"], "We only support zero_point quantization now."
         layers = self.get_model_layers(self.model)

         # Run AWQ quantization

@@ -55,11 +58,11 @@ class BaseAWQForCausalLM:
             for name, module in named_linears.items():
                 module.cuda()
-                module.weight.data, scales, zeros = pseudo_quantize_tensor(module.weight.data, n_bit=w_bit, get_scale_zp=True, **q_config)
+                module.weight.data, scales, zeros = pseudo_quantize_tensor(module.weight.data, w_bit=quant_config['w_bit'], get_scale_zp=True, **quant_config)
                 scales = scales.t().contiguous()
                 zeros = zeros.t().contiguous()
                 q_linear = WQLinear.from_linear(
-                    module, w_bit, q_config['q_group_size'], False, scales, zeros)
+                    module, quant_config['w_bit'], quant_config['q_group_size'], False, scales, zeros)
                 module.cpu()
                 q_linear.to(next(layer.parameters()).device)
                 set_op_by_name(layer, name, q_linear)

@@ -69,7 +72,7 @@ class BaseAWQForCausalLM:
                 torch.cuda.empty_cache()
                 gc.collect()

-    def _awq_search(self, tokenizer, w_bit, q_config, n_samples=128, seqlen=512,
+    def _awq_search(self, tokenizer, quant_config, n_samples=128, seqlen=512,
                        auto_scale=True, mse_range=True, calib_data="pileval"):
         layers = self.get_model_layers(self.model)

@@ -148,12 +151,14 @@ class BaseAWQForCausalLM:
             if auto_scale:  # if it applies, we should also modify the input_feat with scales
                 scales_list = auto_scale_block(
                     self,
-                    layer, layer_kwargs,
-                    w_bit=w_bit, q_config=q_config,
+                    layer,
+                    layer_kwargs,
+                    quant_config=quant_config,
                     input_feat=input_feat,
                 )
+                # apply_scale(layer, scales_list, input_feat_dict=input_feat)
                 apply_scale(layers[i], scales_list, input_feat_dict=input_feat)
                 # append prefix to make names global
                 awq_results["scale"] += append_str_prefix(scales_list, get_op_name(self.model, layer) + ".")

@@ -161,9 +166,12 @@ class BaseAWQForCausalLM:
             torch.cuda.empty_cache()

             if mse_range:
-                clip_list = auto_clip_block(layer,
-                                w_bit=w_bit, q_config=q_config,
-                                input_feat=input_feat,)
+                clip_list = auto_clip_block(
+                    layer,
+                    quant_config=quant_config,
+                    input_feat=input_feat)
                 apply_clip(layer, clip_list)
                 # append prefix to make names global
                 awq_results["clip"] += append_str_prefix(clip_list, get_op_name(self.model, layer) + ".")

@@ -191,6 +199,10 @@ class BaseAWQForCausalLM:
         # Save search results
         torch.save(model, f'{save_dir}/{model_name}')

+        # Save config
+        with open(f'{save_dir}/quant_config.json', 'w+') as file:
+            file.write(json.dumps(self.quant_config, indent=4))
+
         save_dir = save_dir[:-1] if save_dir[-1] == '/' else save_dir

         # Save model

@@ -215,7 +227,7 @@ class BaseAWQForCausalLM:
         )

     @classmethod
-    def from_quantized(self, model_path, model_type, model_filename, w_bit=4, q_config={},
+    def from_quantized(self, model_path, model_type, model_filename, w_bit=4, quant_config={},
                              device='balanced', torch_dtype=torch.float16, trust_remote_code=True,
                              safetensors=False, is_quantized=True):
         # Download model if path is not a directory

@@ -241,7 +253,7 @@ class BaseAWQForCausalLM:
         # Only need to replace layers if a model is AWQ quantized
         if is_quantized:
             # Prepare WQLinear layers, replace nn.Linear
-            self._load_quantized_modules(self, model, w_bit, q_config)
+            self._load_quantized_modules(self, model, w_bit, quant_config)

             model.tie_weights()

@@ -266,11 +278,11 @@ class BaseAWQForCausalLM:
         )

         model.eval()

-        return self(model, model_type, is_quantized=is_quantized)
+        return self(model, model_type, is_quantized=is_quantized, quant_config=quant_config)

-    def _load_quantized_modules(self, model, w_bit, q_config):
+    def _load_quantized_modules(self, model, w_bit, quant_config):
         # Real quantization of weights
-        assert q_config["zero_point"], "We only support zero_point quantization now."
+        assert quant_config["zero_point"], "We only support zero_point quantization now."

         # Get blocks of model
         layers = self.get_model_layers(model)

@@ -287,7 +299,7 @@ class BaseAWQForCausalLM:
             # Replace nn.Linear with WQLinear
             for name, module in named_linears.items():
                 q_linear = WQLinear.from_linear(
-                    module, w_bit, q_config['q_group_size'], True)
+                    module, w_bit, quant_config['q_group_size'], True)
                 q_linear.to(next(layer.parameters()).device)
                 set_op_by_name(layer, name, q_linear)
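The newly added save step in base.py writes self.quant_config verbatim via json.dumps(..., indent=4), so with the default settings the resulting quant_config.json would look roughly like this (key order follows whatever dict was passed in):

    import json

    quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}
    print(json.dumps(quant_config, indent=4))
    # {
    #     "zero_point": true,
    #     "q_group_size": 128,
    #     "w_bit": 4
    # }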
awq/quantize/auto_clip.py

@@ -8,7 +8,9 @@ __all__ = ["auto_clip_block"]
 # weight quantization
 @torch.no_grad()
-def auto_clip_layer(w, input_feat, n_bit, q_config,
+def auto_clip_layer(w, input_feat, quant_config,
                     n_grid=20,
                     max_shrink=0.5,
                     n_sample_token=512):

@@ -16,7 +18,7 @@ def auto_clip_layer(w, input_feat, n_bit, q_config,
     org_w_shape = w.shape
     # w           [co, ci]      -> [co, 1, n_group, group size]
     # input_feat  [n_token, ci] -> [1, n_token, n_group, group size]
-    group_size = q_config["q_group_size"] if q_config["q_group_size"] > 0 else w.shape[1]
+    group_size = quant_config["q_group_size"] if quant_config["q_group_size"] > 0 else w.shape[1]
     input_feat = input_feat.view(-1, input_feat.shape[-1])
     input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
     input_feat = input_feat[:, 0::input_feat.shape[1] // n_sample_token]

@@ -41,7 +43,7 @@ def auto_clip_layer(w, input_feat, n_bit, q_config,
             max_val = org_max_val * (1 - i_s / n_grid)
             min_val = -max_val
             cur_w = torch.clamp(w, min_val, max_val)
-            q_w = pseudo_quantize_tensor(cur_w, n_bit=n_bit, **q_config)
+            q_w = pseudo_quantize_tensor(cur_w, **quant_config)
             cur_out = (input_feat * q_w).sum(dim=-1)

             # co, 1, n_group, 1

@@ -64,7 +66,7 @@ def auto_clip_layer(w, input_feat, n_bit, q_config,
 @torch.no_grad()
 def auto_clip_block(module,
-                    w_bit, q_config,
+                    quant_config,
                     input_feat):
     named_linears = {name: m for name,

@@ -77,7 +79,7 @@ def auto_clip_block(module,
             continue
         named_linears[name].cuda()
         max_val = auto_clip_layer(
-            named_linears[name].weight, input_feat[name], n_bit=w_bit, q_config=q_config)
+            named_linears[name].weight, input_feat[name], quant_config=quant_config)
         clip_list.append((name, max_val))
         named_linears[name].cpu()
     return clip_list
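Because quant_config now carries w_bit, the clip search can hand the whole dict to the quantizer instead of threading a separate n_bit argument. A small sketch of the call as it now works (tensor shape chosen arbitrarily; import path per this repo's layout):

    import torch
    from awq.quantize.quantizer import pseudo_quantize_tensor

    quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}
    cur_w = torch.randn(256, 512)   # [out_features, in_features], divisible by q_group_size

    # Previously: pseudo_quantize_tensor(cur_w, n_bit=n_bit, **q_config)
    # Now the bit width rides along inside the unpacked dict.
    q_w = pseudo_quantize_tensor(cur_w, **quant_config)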
awq/quantize/auto_scale.py

@@ -90,15 +90,14 @@ def scale_gelu_fc(gelu, fc, scales):
 @torch.no_grad()
 def auto_scale_block(awq_model,
-                     module, module_kwargs,
-                     w_bit, q_config,
+                     module,
+                     module_kwargs,
+                     quant_config,
                      input_feat):
     from .quantizer import pseudo_quantize_tensor
     # firstly, get the weight quantize function
-    if w_bit is not None:
+    if quant_config['w_bit'] is not None:
         def w_quantize_func(p): return pseudo_quantize_tensor(
-            p, n_bit=w_bit, **q_config,
-        ).detach()
+            p, **quant_config).detach()
     else:
         def w_quantize_func(p): return p

@@ -111,7 +110,7 @@ def auto_scale_block(awq_model,
         # x: n, ci
         weight = torch.cat([_m.weight for _m in linears2scale], dim=0)
         w_max = get_weight_scale(
-            weight, q_group_size=q_config.get("q_group_size", -1))
+            weight, q_group_size=quant_config.get("q_group_size", -1))

         # Clear GPU memory
         del weight
         gc.collect()
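In auto_scale the behavioural hinge is simply whether a bit width is configured; the helper is otherwise unchanged. A standalone sketch of the gate as rewritten here (the module itself uses a relative import):

    from awq.quantize.quantizer import pseudo_quantize_tensor

    quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}

    if quant_config['w_bit'] is not None:
        # Fake-quantize weights on the fly during the scale search.
        def w_quantize_func(p):
            return pseudo_quantize_tensor(p, **quant_config).detach()
    else:
        # No bit width configured: leave weights untouched.
        def w_quantize_func(p):
            return p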
awq/quantize/quantizer.py

 import torch

 # core quantization method (simulated quantization)
-def pseudo_quantize_tensor(w, n_bit=8,
+def pseudo_quantize_tensor(w, w_bit=4,
                            zero_point=True, q_group_size=-1,
                            inplace=False,
                            get_scale_zp=False
                            ):

@@ -14,7 +15,7 @@ def pseudo_quantize_tensor(w, n_bit=8,
     if zero_point:
         max_val = w.amax(dim=1, keepdim=True)
         min_val = w.amin(dim=1, keepdim=True)
-        max_int = 2 ** n_bit - 1
+        max_int = 2 ** w_bit - 1
         min_int = 0
         scales = (max_val - min_val).clamp(min=1e-5) / max_int
         zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)

@@ -22,8 +23,8 @@ def pseudo_quantize_tensor(w, n_bit=8,
         assert min_val is None
         max_val = w.abs().amax(dim=1, keepdim=True)
         max_val = max_val.clamp(min=1e-5)
-        max_int = 2 ** (n_bit - 1) - 1
-        min_int = -2 ** (n_bit - 1)
+        max_int = 2 ** (w_bit - 1) - 1
+        min_int = -2 ** (w_bit - 1)
         scales = max_val / max_int
         zeros = 0
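To make the renamed parameter concrete: with w_bit=4 the zero-point branch above computes max_int = 2**4 - 1 = 15 and min_int = 0, then derives per-group scales and zeros. A self-contained sketch of that arithmetic (not the repo function itself, and without its grouping/reshaping logic):

    import torch

    def zero_point_quantize(w, w_bit=4):
        # Per-row asymmetric quantization, mirroring the zero_point branch above.
        max_val = w.amax(dim=1, keepdim=True)
        min_val = w.amin(dim=1, keepdim=True)
        max_int = 2 ** w_bit - 1                     # 15 when w_bit=4
        min_int = 0
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)
        q = torch.clamp(torch.round(w / scales) + zeros, min_int, max_int)
        return (q - zeros) * scales                  # dequantized ("pseudo-quantized") weights

    w = torch.randn(8, 128)
    w_q = zero_point_quantize(w)
    print((w - w_q).abs().max())                     # error is bounded by roughly half a scale step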