OpenDAS / AutoAWQ

Commit 5d4ab5dc, authored Aug 17, 2023 by Casper Hansen
Refactor entry.py [WIP]
Parent: 290f45e7
Showing 3 changed files with 99 additions and 280 deletions:

awq/entry.py              (+33, -195)
awq/models/base.py        (+62, -7)
awq/quantize/quantizer.py (+4, -78)
awq/entry.py
from lm_eval import evaluator, tasks
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
import torch
import argparse
import os
import json
from accelerate import init_empty_weights, infer_auto_device_map, dispatch_model, load_checkpoint_in_model
from accelerate.utils.modeling import get_balanced_memory
from awq.utils.parallel import auto_parallel
from awq.quantize.quantizer import pseudo_quantize_model_weight, real_quantize_model_weight
from awq.utils.lm_eval_adaptor import LMEvalAdaptor
from awq.utils.utils import simple_dispatch_model
import torch
from awq.quantize.auto_clip import apply_clip
from awq.quantize.auto_scale import apply_scale
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
parser = argparse.ArgumentParser()
parser.add_argument('--model_path', type=str, help='path of the hf model')
parser.add_argument('--batch_size', type=int, default=1, help='batch size')
parser.add_argument("--tasks", default=None, type=str)
parser.add_argument("--output_path", default=None, type=str)
parser.add_argument('--num_fewshot', type=int, default=0)
# model config
parser.add_argument('--parallel', action='store_true', help="enable model parallelism")
# max memory to offload larger models to CPU
parser.add_argument('--max_memory', type=str, nargs='*',
                    help="List of device_id:max_memory pairs to be parsed into a dictionary; " \
                         + "Example: 0:10GiB 1:10GiB cpu:30GiB; " \
                         + "more details here: " \
                         + "https://huggingface.co/docs/accelerate/usage_guides/big_modeling")
parser.add_argument('--auto_parallel', action='store_true', help="automatically set parallel and batch_size")
# quantization config
parser.add_argument('--w_bit', type=int, default=None)
parser.add_argument('--q_group_size', type=int, default=-1)
parser.add_argument('--no_zero_point', action='store_true', help="disable zero_point")
parser.add_argument('--q_backend', type=str, default="fake", choices=["fake", "real"])
# save/load real quantized weights
parser.add_argument('--dump_quant', type=str, default=None, help='save quantized model')
parser.add_argument('--load_quant', type=str, default=None, help='load quantized model')
# apply/save/load awq
parser.add_argument('--run_awq', action='store_true', help="perform awq search process")
parser.add_argument('--dump_awq', type=str, default=None, help="save the awq search results")
parser.add_argument('--load_awq', type=str, default=None, help="load the awq search results")
args = parser.parse_args()
max_memory = [v.split(':') for v in (args.max_memory or [])]
max_memory = [v.split(':') for v in (None or [])]
max_memory = {(int(k) if k.isdigit() else k): v for k, v in max_memory}

if args.auto_parallel:
    gpu_list = auto_parallel(args)

# get quantization config (apart from w_bit)
q_config = {
    "zero_point": not args.no_zero_point,  # by default True
    "q_group_size": args.q_group_size,     # whether to use group quantization
}
print("Quantization config:", q_config)
def get_awq_model(model):
    from awq.models import MptAWQForCausalLM
...
...
@@ -73,149 +15,45 @@ def get_awq_model(model):
    else:
        raise NotImplementedError(type(model))


def build_model_and_enc(model_path):
    if not os.path.exists(model_path):  # look into ssd
        raise FileNotFoundError(f"{model_path} not found!")
    print(f"* Building model {model_path}")

    # all hf model
def load_unquantized(model_path):
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)

    if "mpt" in config.__class__.__name__.lower():
        enc = AutoTokenizer.from_pretrained(config.tokenizer_name, trust_remote_code=True)
    else:
        enc = AutoTokenizer.from_pretrained(model_path, use_fast=False, trust_remote_code=True)
    tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name, trust_remote_code=True)

    if args.load_quant:  # directly load quantized weights
        print("Loading pre-computed quantized weights...")
        with init_empty_weights():
            model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch.float16,
                                                     trust_remote_code=True)
        real_quantize_model_weight(model, w_bit=args.w_bit, q_config=q_config, init_only=True)

        model.tie_weights()

        # Infer device map
        kwargs = {"max_memory": max_memory} if len(max_memory) else {}
        device_map = infer_auto_device_map(
            model,
            no_split_module_classes=["OPTDecoderLayer", "LlamaDecoderLayer",
                                     "BloomBlock", "MPTBlock", "DecoderLayer"],
            **kwargs
        )
        # Load checkpoint in the model
        load_checkpoint_in_model(
            model,
            checkpoint=args.load_quant,
            device_map=device_map,
            offload_state_dict=True,
        )
        # Dispatch model
        model = simple_dispatch_model(model, device_map=device_map)

        kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
        model = AutoModelForCausalLM.from_pretrained(model_path, config=config,
                                                     trust_remote_code=True, **kwargs)
        model.eval()
    else:  # fp16 to quantized
        args.run_awq &= not args.load_awq  # if load_awq, no need to run awq

        # Init model on CPU:
        kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
        model = AutoModelForCausalLM.from_pretrained(model_path, config=config,
                                                     trust_remote_code=True, **kwargs)
        model.eval()

    model.eval()
    return model, tokenizer
        if args.run_awq:
            assert args.dump_awq, "Please save the awq results with --dump_awq"

def load_search_result_into_memory(model, search_path):
    awq_results = torch.load(search_path, map_location="cpu")
    awq_model = get_awq_model(model)
    awq_results = awq_model.quantize(model, enc, args.w_bit, q_config)

    if args.dump_awq:
        dirpath = os.path.dirname(args.dump_awq)
        os.makedirs(dirpath, exist_ok=True)
        torch.save(awq_results, args.dump_awq)
        print("AWQ results saved at", args.dump_awq)
        exit(0)

    if args.load_awq:
        print("Loading pre-computed AWQ results from", args.load_awq)
        awq_results = torch.load(args.load_awq, map_location="cpu")
    apply_scale(model, awq_results["scale"])
    apply_clip(model, awq_results["clip"])

    # weight quantization
    if args.w_bit is not None:
        if args.q_backend == "fake":
            assert args.dump_quant is None, \
                "Need to use real quantization to dump quantized weights"
            pseudo_quantize_model_weight(model, w_bit=args.w_bit, q_config=q_config)
        elif args.q_backend == "real":  # real quantization
            real_quantize_model_weight(model, w_bit=args.w_bit, q_config=q_config)
            if args.dump_quant:
                dirpath = os.path.dirname(args.dump_quant)
                os.makedirs(dirpath, exist_ok=True)
                print(f"Saving the quantized model at {args.dump_quant}...")
                torch.save(model.cpu().state_dict(), args.dump_quant)
                exit(0)
        else:
            raise NotImplementedError

    # Move the model to GPU (as much as possible) for LM evaluation
    kwargs = {"max_memory": get_balanced_memory(model, max_memory if len(max_memory) > 0 else None)}
    device_map = infer_auto_device_map(
        model,
        # TODO: can we remove this?
        no_split_module_classes=["OPTDecoderLayer", "LlamaDecoderLayer",
                                 "BloomBlock", "MPTBlock", "DecoderLayer"],
        **kwargs
    )
    model = dispatch_model(model, device_map=device_map)

    return model, enc
def main():
    if args.output_path is not None and os.path.exists(args.output_path):
        # print(f"Results {args.output_path} already generated. Exit.")
        print(f"Results {args.output_path} already generated. Overwrite.")
        # exit()
    if args.dump_awq and os.path.exists(args.dump_awq):
        print(f"Found existing AWQ results {args.dump_awq}, exit.")
        exit()

    # a hack here to auto set model group
    model, enc = build_model_and_enc(args.model_path)
    apply_scale(model, awq_results["scale"])
    apply_clip(model, awq_results["clip"])

    if args.tasks is not None:
        task_names = args.tasks.split(",")
def run_search(model, dump_path):
    model, tokenizer = load_unquantized(model_path)
    awq_model = get_awq_model(model)
    awq_results = awq_model.quantize(model, tokenizer, w_bit=4, q_config=q_config,
                                     run_search=True, run_quant=False)

        lm_eval_model = LMEvalAdaptor(args.model_path, model, enc, args.batch_size)
        results = evaluator.simple_evaluate(
            model=lm_eval_model,
            tasks=task_names,
            batch_size=args.batch_size,
            no_cache=True,
            num_fewshot=args.num_fewshot,
        )

    dirpath = os.path.dirname(dump_path)
    os.makedirs(dirpath, exist_ok=True)
    torch.save(awq_results, dump_path)

        print(evaluator.make_table(results))
def run_quant(model, search_path, dump_path):
    model, tokenizer = load_unquantized(model_path)
    load_search_result_into_memory(model, search_path)

    if args.output_path is not None:
        os.makedirs(os.path.dirname(args.output_path), exist_ok=True)
        # otherwise cannot save
        results["config"]["model"] = args.model_path
        with open(args.output_path, "w") as f:
            json.dump(results, f, indent=2)

    awq_model = get_awq_model(model)
    awq_model.quantize(model, w_bit=4, q_config=q_config, run_search=False, run_quant=True)

    dirpath = os.path.dirname(dump_path)
    os.makedirs(dirpath, exist_ok=True)
    torch.save(model.cpu().state_dict(), dump_path)


if __name__ == '__main__':
    main()

model_path = "./mpt-7b-8k-chat"
search_path = "./mpt-7b-8k-chat/mpt-7b-8k-chat-awq-search.pt"
quant_path = "./mpt-7b-8k-chat/mpt-7b-8k-chat-w4-g128.pt"
q_config = {"zero_point": True, "q_group_size": 128}
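
For reference, the --max_memory plumbing that entry.py keeps takes device:capacity pairs on the command line and turns them into the dictionary that accelerate's infer_auto_device_map / get_balanced_memory expect. Below is a minimal, self-contained sketch of that parsing, using a hypothetical flag value:

# Mirrors the max_memory parsing near the top of entry.py; the input list is only an example.
raw = ["0:10GiB", "1:10GiB", "cpu:30GiB"]           # what argparse collects for --max_memory
pairs = [v.split(':') for v in (raw or [])]          # [['0', '10GiB'], ['1', '10GiB'], ['cpu', '30GiB']]
max_memory = {(int(k) if k.isdigit() else k): v for k, v in pairs}
print(max_memory)                                    # {0: '10GiB', 1: '10GiB', 'cpu': '30GiB'}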
awq/models/base.py
import gc
import tqdm
import torch
import functools
import torch.nn as nn
from tqdm import tqdm
from collections import defaultdict
from awq.utils.calib_data import get_calib_dataset
from awq.quantize.auto_clip import auto_clip_block, apply_clip
from awq.quantize.auto_scale import auto_scale_block, apply_scale
from awq.utils.module import append_str_prefix, get_op_name, get_named_linears
from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name
from awq.quantize.quantizer import pseudo_quantize_tensor
from awq.quantize.qmodule import WQLinear, ScaledActivation
class BaseAWQForCausalLM:
    @torch.no_grad()
    def quantize(self, model, tokenizer, w_bit, q_config, n_samples=128, seqlen=512,
                 auto_scale=True, mse_range=True, calib_data="pileval"):
    def quantize(self, model, tokenizer=None, w_bit=4, q_config={}, n_samples=128, seqlen=512,
                 auto_scale=True, mse_range=True, run_search=False, run_quant=True,
                 calib_data="pileval", init_only=False):
        if run_search:
            self._awq_search(model, tokenizer, w_bit, q_config,
                             n_samples=n_samples, seqlen=seqlen,
                             auto_scale=auto_scale, mse_range=mse_range,
                             calib_data=calib_data)

        if run_quant:
            self._awq_quant(model, w_bit, q_config, init_only)
    def _awq_quant(self, model, w_bit, q_config, init_only):
        assert q_config["zero_point"], "We only support zero_point quantization now."
        layers = self.get_model_layers(model)

        # Run AWQ quantization
        for i in tqdm(range(len(layers)), desc="AWQ Quantization"):
            layer = layers[i]
            named_linears = get_named_linears(layer)

            if not isinstance(layer.ffn.act, ScaledActivation):
                param = next(layer.parameters())

                # get activation scale
                scale_dict = self.get_act_for_scaling(layer)
                scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)

                # scale activation
                scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
                set_op_by_name(layer, scale_dict['scale_name'], scaled_act)

            for name, module in named_linears.items():
                if init_only:
                    q_linear = WQLinear.from_linear(module, w_bit, q_config['q_group_size'], True)
                    q_linear.to(next(layer.parameters()).device)
                    set_op_by_name(layer, name, q_linear)
                else:
                    module.cuda()
                    module.weight.data, scales, zeros = pseudo_quantize_tensor(
                        module.weight.data, n_bit=w_bit, get_scale_zp=True, **q_config
                    )
                    scales = scales.t().contiguous()
                    zeros = zeros.t().contiguous()
                    q_linear = WQLinear.from_linear(module, w_bit, q_config['q_group_size'],
                                                    False, scales, zeros)
                    module.cpu()
                    q_linear.to(next(layer.parameters()).device)
                    set_op_by_name(layer, name, q_linear)
                    torch.cuda.empty_cache()
                    gc.collect()

            torch.cuda.empty_cache()
            gc.collect()
    def _awq_search(self, model, tokenizer, w_bit, q_config, n_samples=128, seqlen=512,
                    auto_scale=True, mse_range=True, calib_data="pileval"):
        layers = self.get_model_layers(model)
        samples = get_calib_dataset(
...
...
@@ -62,8 +117,8 @@ class BaseAWQForCausalLM:
            "clip": [],
        }

        # solve layer by layer
        for i in tqdm.tqdm(range(len(layers)), desc="Running AWQ..."):
        # Run AWQ search layer by layer
        for i in tqdm(range(len(layers)), desc="AWQ Search:"):
            layer = layers[i]
            layer = layer.cuda()
            named_linears = get_named_linears(layer)
...
...
@@ -119,7 +174,7 @@ class BaseAWQForCausalLM:
            del input_feat
            gc.collect()
            torch.cuda.empty_cache()

        return awq_results

    def save_quantized():
...
...
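
The main API change in base.py is that quantize() now takes run_search and run_quant flags and dispatches to _awq_search / _awq_quant internally. The sketch below shows the two-phase call pattern this enables; the quantize() call shapes come from this diff, while the model/tokenizer loading and the direct MptAWQForCausalLM() construction are assumptions for illustration (entry.py goes through get_awq_model() and load_unquantized() instead):

# Hedged sketch of the new two-phase quantize() API (still WIP in this commit).
from transformers import AutoModelForCausalLM, AutoTokenizer
from awq.models import MptAWQForCausalLM   # the wrapper get_awq_model() returns for MPT models

model_path = "./mpt-7b-8k-chat"            # example checkpoint, as configured in entry.py
q_config = {"zero_point": True, "q_group_size": 128}

# Assumed loading path; entry.py's load_unquantized() resolves the tokenizer differently for MPT.
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

awq_model = MptAWQForCausalLM()
# Phase 1: run only the AWQ search (per-channel scales and clipping thresholds), no weight packing.
awq_results = awq_model.quantize(model, tokenizer, w_bit=4, q_config=q_config,
                                 run_search=True, run_quant=False)
# Phase 2: pack the scaled/clipped weights into WQLinear modules.
awq_model.quantize(model, w_bit=4, q_config=q_config, run_search=False, run_quant=True)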
awq/quantize/quantizer.py
import torch
import torch.nn as nn
from tqdm import tqdm
import gc
from .qmodule import ScaledActivation
from ..utils.module import set_op_by_name
from transformers.models.bloom.modeling_bloom import BloomBlock

EMBEDDING_KEYWORDS = ["embed"]
LM_HEAD_KEYWORDS = ["lm_head", "embed_out", "output"]
def scale_activations(module):
    param = next(module.parameters())
    dtype = param.dtype
    device = param.device
    if isinstance(module, BloomBlock):
        if isinstance(module.mlp.gelu_impl, ScaledActivation):
            return
        c = module.mlp.dense_h_to_4h.out_features
        act = ScaledActivation(
            module.mlp.gelu_impl,
            torch.ones(c, dtype=dtype, device=device)
        )
        set_op_by_name(module, "mlp.gelu_impl", act)
    elif 'mptblock' in str(module.__class__.__name__).lower():
        if isinstance(module.ffn.act, ScaledActivation):
            return

        # get activation scale
        scale_dict = MptAWQForCausalLM().get_act_for_scaling(module)
        scale_like = torch.ones(scale_dict['scale_shape'], dtype=dtype, device=device)

        # scale activation
        scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
        set_op_by_name(module, scale_dict['scale_name'], scaled_act)
    elif 'falcon' in str(module.__class__).lower():
        if isinstance(module.mlp.act, ScaledActivation):
            return
        c = module.mlp.dense_h_to_4h.out_features
        act = ScaledActivation(
            module.mlp.act,
            torch.ones(c, dtype=dtype, device=device)
        )
        set_op_by_name(module, "mlp.act", act)
# core quantization method (simulated quantization)
def pseudo_quantize_tensor(w, n_bit=8,
                           zero_point=True, q_group_size=-1,
...
...
@@ -107,38 +63,8 @@ def pseudo_quantize_model_weight(
@torch.no_grad()
def real_quantize_model_weight(model, w_bit, q_config, init_only=False):
    from .qmodule import WQLinear
    from .pre_quant import get_blocks, get_named_linears
    assert q_config["zero_point"], "We only support zero_point quantization now."

    layers = get_blocks(model)
    for i in tqdm(range(len(layers)),
                  desc="real weight quantization..." + ("(init only)" if init_only else "")):
def real_quantize_model_weight(model, awq_model):
    layers = awq_model.get_model_layers(model)
    for i in tqdm(range(len(layers)), desc="real weight quantization..."):
        layer = layers[i]
        named_linears = get_named_linears(layer)
        scale_activations(layer)

        for name, module in named_linears.items():
            if init_only:
                q_linear = WQLinear.from_linear(module, w_bit, q_config['q_group_size'], True)
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
            else:
                module.cuda()
                module.weight.data, scales, zeros = pseudo_quantize_tensor(
                    module.weight.data, n_bit=w_bit, get_scale_zp=True, **q_config
                )
                scales = scales.t().contiguous()
                zeros = zeros.t().contiguous()
                q_linear = WQLinear.from_linear(module, w_bit, q_config['q_group_size'],
                                                False, scales, zeros)
                module.cpu()
                q_linear.to(next(layer.parameters()).device)
                set_op_by_name(layer, name, q_linear)
                torch.cuda.empty_cache()
                gc.collect()

        torch.cuda.empty_cache()
        gc.collect()
        del layer
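
pseudo_quantize_tensor, whose body is collapsed in this view, performs simulated group-wise quantization with a zero point and, when get_scale_zp=True, also returns the per-group scales and zero points that _awq_quant and real_quantize_model_weight feed into WQLinear.from_linear. The snippet below is an illustrative re-derivation of that group-wise scheme under those assumptions, not the project's implementation:

import torch

def groupwise_fake_quant(w: torch.Tensor, n_bit: int = 4, q_group_size: int = 128):
    # Illustrative group-wise zero-point ("fake") quantization of a 2-D weight tensor.
    orig_shape = w.shape
    w = w.reshape(-1, q_group_size)                      # one quantization group per row
    max_val = w.amax(dim=1, keepdim=True)
    min_val = w.amin(dim=1, keepdim=True)
    max_int = 2 ** n_bit - 1
    scales = (max_val - min_val).clamp(min=1e-5) / max_int
    zeros = (-torch.round(min_val / scales)).clamp(0, max_int)
    q = torch.clamp(torch.round(w / scales) + zeros, 0, max_int)
    w_fq = (q - zeros) * scales                          # dequantize back to float
    return w_fq.reshape(orig_shape), scales, zeros

w = torch.randn(16, 256)
w_fq, scales, zeros = groupwise_fake_quant(w, n_bit=4, q_group_size=128)
print((w - w_fq).abs().max())                            # per-element error stays within ~scale/2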