OpenDAS / AutoAWQ · Commits

Commit ed347704, authored Aug 18, 2023 by Casper Hansen
"Implemented save_quantized. Generalize from_quantized. Add comments."
Parent: 14d198c6

Showing 3 changed files with 111 additions and 56 deletions:

  awq/entry.py        +31  -13
  awq/models/auto.py   +8   -3
  awq/models/base.py  +72  -40
awq/entry.py

...
@@ -16,29 +16,45 @@ def load_search_result_into_memory(model, search_path):
     apply_clip(model, awq_results["clip"])
 
 def run_search(model_path, dump_path, w_bit, q_config):
     """
     Step 1/2: Search the pile for an optimal scaling factor.
     """
     # Load model
     model = AutoAWQForCausalLM.from_pretrained(model_path)
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
-    awq_results = model.quantize(model.model, tokenizer, w_bit=w_bit, q_config=q_config, run_search=True, run_quant=False)
-    dirpath = os.path.dirname(dump_path)
-    os.makedirs(dirpath, exist_ok=True)
-    torch.save(awq_results, dump_path)
+    # Quantize
+    model.quantize(tokenizer, w_bit=w_bit, q_config=q_config, run_search=True, run_quant=False)
+
+    # Save search results
+    model.save_quantized(dump_path)
 
 def run_quant(model_path, search_path, dump_path, w_bit, q_config):
     """
     Step 2/2: Use the search results to quantize model weights
     """
     # Load model and search results
     model = AutoAWQForCausalLM.from_pretrained(model_path)
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
     load_search_result_into_memory(model.model, search_path)
 
-    model.quantize(model.model, w_bit=w_bit, q_config=q_config, run_search=False, run_quant=True)
-    dirpath = os.path.dirname(dump_path)
-    os.makedirs(dirpath, exist_ok=True)
-    torch.save(model.model.cpu().state_dict(), dump_path)
+    # Run actual weight quantization
+    model.quantize(w_bit=w_bit, q_config=q_config, run_search=False, run_quant=True)
+
+    # Save quantized model
+    model.save_quantized(dump_path)
 
 def run_perplexity(model_path, quant_path, w_bit, q_config, device):
     """
     Post quantization: Evaluate perplexity on wikitext with EleutherAI Evaluation Harness
     """
     # Load model
     model = AutoAWQForCausalLM.from_quantized(model_path, quant_path, w_bit, q_config, device)
     tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
     # Load adapter
     lm_eval_model = LMEvalAdaptor(model_path, model, tokenizer, device, batch_size=1)
 
     # Evaluate perplexity of quantized model
     results = evaluator.simple_evaluate(
         model=lm_eval_model,
         tasks=['wikitext'],
...
@@ -50,19 +66,21 @@ def run_perplexity(model_path, quant_path, w_bit, q_config, device):
     print(evaluator.make_table(results))
 
 if __name__ == '__main__':
     """
     python -m awq.entry --entry_type search --model_path mosaicml/mpt-7b-8k-chat --search_path mpt-7b-8k-chat-awq
     python -m awq.entry --entry_type quant --model_path mosaicml/mpt-7b-8k-chat --search_path mpt-7b-8k-chat-awq/pytorch_model.bin --quant_path mpt-7b-8k-chat-awq
     python -m awq.entry --entry_type perplexity --model_path mosaicml/mpt-7b-8k-chat --quant_path mpt-7b-8k-chat-awq
     """
     parser = argparse.ArgumentParser()
     parser.add_argument('--entry_type', type=str, help='The type of task to run (search|quant|perplexity)')
     parser.add_argument('--model_path', type=str, help='Path to hf model')
     parser.add_argument('--search_path', type=str, help='Path to save/load AWQ search results')
     parser.add_argument('--quant_path', type=str, help='Path to save/load AWQ quant model')
-    parser.add_argument('--device', type=str, default='cuda:0', help='Device to load model to')
+    parser.add_argument('--device', type=str, default='balanced', help='Device to load model to')
     parser.add_argument('--w_bit', type=int, default=4)
     parser.add_argument('--q_group_size', type=int, default=128)
     args = parser.parse_args()
 
     args.model_path = "./mpt-7b-8k-chat"
     args.search_path = "./mpt-7b-8k-chat/mpt-7b-8k-chat-awq-search.pt"
     args.quant_path = "./mpt-7b-8k-chat/mpt-7b-8k-chat-w4-g128.pt"
 
     q_config = { "zero_point": True, "q_group_size": args.q_group_size }
 
     if args.entry_type == 'search':
...
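For orientation, the refactored entry points now persist results through the model object instead of calling torch.save in the entry script. Below is a minimal sketch of the step-1 search run, assuming the APIs exactly as they appear in this diff; the model id and dump path are placeholders.

# Minimal sketch of the step-1 search run (model id and dump path are placeholders).
from transformers import AutoTokenizer
from awq.models.auto import AutoAWQForCausalLM

model_path = "mosaicml/mpt-7b-8k-chat"   # placeholder HF model id
dump_path = "./mpt-7b-8k-chat-awq"       # placeholder output directory

model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Search only; the actual weight quantization happens in the separate step-2 run.
model.quantize(tokenizer, w_bit=4,
               q_config={"zero_point": True, "q_group_size": 128},
               run_search=True, run_quant=False)

# New in this commit: saving goes through the model object instead of torch.save().
model.save_quantized(dump_path)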
awq/models/auto.py
 from transformers import AutoConfig
 from awq.models import MptAWQForCausalLM
+from awq.models.base import BaseAWQForCausalLM
 
 AWQ_CAUSAL_LM_MODEL_MAP = {
     "mpt": MptAWQForCausalLM,
...
@@ -13,12 +14,14 @@ def check_and_get_model_type(model_dir, trust_remote_code=True):
     return model_type
 
 class AutoAWQForCausalLM:
+    default_q_config = { "zero_point": True, "q_group_size": 128 }
+
     def __init__(self):
         raise EnvironmentError('You must instantiate AutoAWQForCausalLM with \n'
                                'AutoAWQForCausalLM.from_quantized or AutoAWQForCausalLM.from_pretrained')
 
     @classmethod
-    def from_pretrained(self, model_path, trust_remote_code=True):
+    def from_pretrained(self, model_path, trust_remote_code=True) -> BaseAWQForCausalLM:
         model_type = check_and_get_model_type(model_path, trust_remote_code)
 
         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(
...
@@ -26,9 +29,11 @@ class AutoAWQForCausalLM:
         )
 
     @classmethod
-    def from_quantized(self, model_path, quant_path, w_bit, q_config, device, trust_remote_code=True):
+    def from_quantized(self, model_path, quant_file, w_bit=4, q_config={}, device='balanced',
+                       trust_remote_code=True) -> BaseAWQForCausalLM:
         model_type = check_and_get_model_type(model_path, trust_remote_code)
+        q_config = q_config if q_config else self.default_q_config
 
         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
-            model_path, quant_path, w_bit, q_config, device, trust_remote_code
+            model_path, model_type, quant_file, w_bit, q_config, device, trust_remote_code=trust_remote_code
         )
\ No newline at end of file
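As a usage sketch of the generalized loader: w_bit, q_config and device now have defaults, and an empty q_config falls back to default_q_config. The call below is hypothetical; the repo id and weights file name are placeholders.

# Hypothetical call; repo id and weights file name are placeholders.
from awq.models.auto import AutoAWQForCausalLM

# w_bit, q_config and device now default to 4, {} and 'balanced'; an empty
# q_config falls back to AutoAWQForCausalLM.default_q_config.
model = AutoAWQForCausalLM.from_quantized(
    "your-org/mpt-7b-8k-chat-awq",   # placeholder model path or HF repo id
    "awq_model_w4_g128.pt",          # quant_file inside the downloaded snapshot
)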
awq/models/base.py
 import os
 import gc
 import torch
 import functools
...
@@ -5,8 +6,9 @@ import torch.nn as nn
 from tqdm import tqdm
 from collections import defaultdict
+from huggingface_hub import snapshot_download
 from awq.utils.calib_data import get_calib_dataset
-from transformers import AutoModelForCausalLM, AutoConfig
+from transformers import AutoModelForCausalLM, AutoConfig, PreTrainedModel
 from awq.quantize.quantizer import pseudo_quantize_tensor
 from awq.quantize.qmodule import WQLinear, ScaledActivation
 from awq.quantize.auto_clip import auto_clip_block, apply_clip
...
@@ -16,29 +18,27 @@ from awq.utils.module import append_str_prefix, get_op_name, get_named_linears,
 class BaseAWQForCausalLM:
     def __init__(self, model, model_type, is_quantized):
-        self.model = model
-        self.model_type = model_type
-        self.is_quantized = is_quantized
+        self.model: PreTrainedModel = model
+        self.model_type: str = model_type
+        self.is_quantized: bool = is_quantized
+        self.search_result = None
 
     @torch.no_grad()
-    def quantize(self, model, tokenizer=None, w_bit=4, q_config={}, n_samples=128, seqlen=512,
+    def quantize(self, tokenizer=None, w_bit=4, q_config={}, n_samples=128, seqlen=512,
                  auto_scale=True, mse_range=True, run_search=False, run_quant=True,
                  calib_data="pileval"):
-        search_result = None
-
         if run_search:
-            search_result = self._awq_search(model, tokenizer, w_bit, q_config, n_samples=n_samples, seqlen=seqlen,
+            self.search_result = self._awq_search(tokenizer, w_bit, q_config, n_samples=n_samples, seqlen=seqlen,
                                                   auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data)
 
         if run_quant:
-            self._awq_quant(model, w_bit, q_config)
-
-        return search_result
+            self._awq_quant(w_bit, q_config)
 
-    def _awq_quant(self, model, w_bit, q_config):
+    def _awq_quant(self, w_bit, q_config):
         assert q_config["zero_point"], "We only support zero_point quantization now."
-        layers = self.get_model_layers(model)
+        layers = self.get_model_layers(self.model)
 
         # Run AWQ quantization
         for i in tqdm(range(len(layers)), desc="AWQ Quantization"):
...
@@ -62,9 +62,9 @@ class BaseAWQForCausalLM:
             torch.cuda.empty_cache()
             gc.collect()
 
-    def _awq_search(self, model, tokenizer, w_bit, q_config, n_samples=128, seqlen=512,
+    def _awq_search(self, tokenizer, w_bit, q_config, n_samples=128, seqlen=512,
                     auto_scale=True, mse_range=True, calib_data="pileval"):
-        layers = self.get_model_layers(model)
+        layers = self.get_model_layers(self.model)
 
         samples = get_calib_dataset(data=calib_data, tokenizer=tokenizer, n_samples=n_samples, block_size=seqlen)
...
@@ -74,7 +74,7 @@ class BaseAWQForCausalLM:
         layer_kwargs = {}
 
         layers[0] = layers[0].cuda()
-        self.move_embed(model, "cuda")
+        self.move_embed(self.model, "cuda")
 
         # get input and kwargs to layer 0
         # with_kwargs is only supported in PyTorch 2.0
...
@@ -92,7 +92,7 @@ class BaseAWQForCausalLM:
         # patch layer 0 to catch input and kwargs
         layers[0] = Catcher(layers[0])
         try:
-            model(samples.to(next(model.parameters()).device))
+            self.model(samples.to(next(self.model.parameters()).device))
         except ValueError:  # work with early exit
             pass
         del samples
...
@@ -100,7 +100,7 @@ class BaseAWQForCausalLM:
         inps = inps[0]
 
         layers[0] = layers[0].cpu()
-        self.move_embed(model, "cpu")
+        self.move_embed(self.model, "cpu")
 
         gc.collect()
         torch.cuda.empty_cache()
...
@@ -148,7 +148,7 @@ class BaseAWQForCausalLM:
             # apply_scale(layer, scales_list, input_feat_dict=input_feat)
             apply_scale(layers[i], scales_list, input_feat_dict=input_feat)
             # append prefix to make names global
-            awq_results["scale"] += append_str_prefix(scales_list, get_op_name(model, layer) + ".")
+            awq_results["scale"] += append_str_prefix(scales_list, get_op_name(self.model, layer) + ".")
 
             # Clear GPU memory
             torch.cuda.empty_cache()
...
@@ -159,7 +159,7 @@ class BaseAWQForCausalLM:
                                          input_feat=input_feat,)
             apply_clip(layer, clip_list)
             # append prefix to make names global
-            awq_results["clip"] += append_str_prefix(clip_list, get_op_name(model, layer) + ".")
+            awq_results["clip"] += append_str_prefix(clip_list, get_op_name(self.model, layer) + ".")
 
             layer = layer.cpu()
             # Haotian: check activation replacement
...
@@ -169,39 +169,77 @@
 
         return awq_results
 
-    def save_quantized():
-        pass
+    def save_quantized(self, save_dir):
+        save_dir = save_dir[:-1] if save_dir[-1] == '/' else save_dir
+
+        # Save model
+        if self.search_result is None:
+            self.model.save_pretrained(save_dir, state_dict=self.model.state_dict())
+        else:
+            self.model.save_pretrained(save_dir, state_dict=self.search_result)
+
+        # TODO: Rename model name & save quant_config
+        if self.search_result is not None:
+            model_name = 'awq_model_search_result.pt'
+        else:
+            model_name = 'awq_model_w4_g128.pt'
 
     @classmethod
+    def from_pretrained(self, model_path, model_type, torch_dtype: torch.dtype = torch.float16, trust_remote_code=True):
+        return self.from_quantized(model_path, model_type, quant_file='', device='balanced',
+                                   torch_dtype=torch_dtype, trust_remote_code=trust_remote_code, is_quantized=False)
+
+    @classmethod
-    def from_pretrained(self, model_path, model_type, torch_dtype: torch.dtype = torch.float16, trust_remote_code=True):
+    def from_quantized(self, model_path, model_type, quant_file, w_bit=4, q_config={}, device='balanced',
+                       torch_dtype=torch.float16, trust_remote_code=True, is_quantized=True):
+        # Download model
+        model_path = snapshot_download(model_path)
+        quant_path = model_path + f'/{quant_file}' if is_quantized else model_path
+
         # Load config
         config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
 
+        # Load empty weights
         with init_empty_weights():
-            model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch.float16, trust_remote_code=True)
+            model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch_dtype, trust_remote_code=trust_remote_code)
 
+        # Only need to replace layers if a model is AWQ quantized
+        if is_quantized:
+            # Prepare WQLinear layers, replace nn.Linear
+            self._load_quantized_modules(self, model, w_bit, q_config)
+
+        model.tie_weights()
+
+        # Load model weights
-        model = load_checkpoint_and_dispatch(model, model_path, device_map="balanced", no_split_module_classes=[self.layer_type])
+        model = load_checkpoint_and_dispatch(model, quant_path, device_map=device, no_split_module_classes=[self.layer_type])
 
-        return self(model, model_type, is_quantized=False)
+        return self(model, model_type, is_quantized=is_quantized)
 
-    @classmethod
-    def from_quantized(self, model_path, quant_path, w_bit, q_config, device, trust_remote_code=True):
-        # Load config
-        config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
-
-        with init_empty_weights():
-            model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch.float16, trust_remote_code=True)
-
-        # Initialize layers
     def _load_quantized_modules(self, model, w_bit, q_config):
         # Real quantization of weights
         assert q_config["zero_point"], "We only support zero_point quantization now."
 
         # Get blocks of model
         layers = self.get_model_layers(model)
 
         for i in tqdm(range(len(layers)), desc="Replacing layers..."):
             layer = layers[i]
 
             # Get every linear layer in a block
             named_linears = get_named_linears(layer)
 
             # Replace activation functions
             self._scale_activations(self, layer)
 
             # Replace nn.Linear with WQLinear
             for name, module in named_linears.items():
                 q_linear = WQLinear.from_linear(module, w_bit, q_config['q_group_size'], True)
...
@@ -210,12 +248,6 @@
             torch.cuda.empty_cache()
             gc.collect()
 
-        model.tie_weights()
-
-        model = load_checkpoint_and_dispatch(model, quant_path, device_map="balanced", no_split_module_classes=[self.layer_type])
-
-        return model
-
     @staticmethod
     def _scale_activations(self, layer):
...
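For context on what BaseAWQForCausalLM expects from its subclasses after this refactor, here is an illustrative sketch of the hooks the base class calls in this diff (layer_type, get_model_layers, move_embed). The class name and attribute paths below are hypothetical and are not the actual MptAWQForCausalLM implementation.

# Illustrative only: the subclass contract BaseAWQForCausalLM relies on in this diff.
# Class name and attribute paths are hypothetical, not the real MptAWQForCausalLM.
from awq.models.base import BaseAWQForCausalLM

class MyAWQForCausalLM(BaseAWQForCausalLM):
    # Used as no_split_module_classes when dispatching checkpoints across devices.
    layer_type = "MyDecoderLayer"

    @staticmethod
    def get_model_layers(model):
        # Return the list of transformer blocks that _awq_search/_awq_quant iterate over.
        return model.transformer.blocks  # hypothetical attribute path

    @staticmethod
    def move_embed(model, device):
        # Move the token embedding so layer-0 inputs can be captured on `device`.
        model.transformer.wte = model.transformer.wte.to(device)  # hypothetical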