OpenDAS / AutoAWQ · Commits

Commit de8e7ffc, authored Sep 21, 2023 by Casper
Merge branch 'main' into bigcode
Parents: 19568c52, 72f954ce

Showing 8 changed files with 559 additions and 556 deletions (+559 -556)
awq/entry.py               +0    -191
awq/models/base.py         +42   -223
awq/models/opt.py          +5    -3
awq/quantize/auto_clip.py  +0    -99
awq/quantize/quantizer.py  +337  -40
awq/quantize/scale.py      +112  -0
awq/utils/utils.py         +7    -0
examples/eval.py           +56   -0
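For orientation, a minimal sketch of the one-step workflow this merge moves to (quantize() now wraps the new AwqQuantizer internally, replacing the two-phase search/quant entry point deleted below). The model path, output directory, and quant_config values are illustrative, not part of the diff:

    from awq import AutoAWQForCausalLM
    from transformers import AutoTokenizer

    model_path = "lmsys/vicuna-7b-v1.5"   # example model, as in the docstrings below
    quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4}

    # Load the FP16 model and its tokenizer
    model = AutoAWQForCausalLM.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Search for scales/clipping values and quantize the weights in one pass
    model.quantize(tokenizer, quant_config=quant_config)

    # Persist the quantized weights, quant_config.json, and tokenizer
    model.save_quantized("vicuna-7b-v1.5-awq")
    tokenizer.save_pretrained("vicuna-7b-v1.5-awq")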
awq/entry.py  (deleted, 100644 → 0)

import os
import time
import torch
import argparse
from lm_eval import evaluator
from awq import AutoAWQForCausalLM
from awq.quantize.auto_clip import apply_clip
from awq.quantize.auto_scale import apply_scale
from awq.utils.lm_eval_adaptor import LMEvalAdaptor
from transformers import AutoTokenizer, GenerationConfig

def load_search_result_into_memory(model, search_path):
    awq_results = torch.load(search_path, map_location="cpu")

    apply_scale(model, awq_results["scale"])
    apply_clip(model, awq_results["clip"])

def run_search(model_path, dump_path, quant_config):
    """
    Step 1/2: Search the pile for an optimal scaling factor.
    """
    # Load model
    model = AutoAWQForCausalLM.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Quantize
    model.quantize(tokenizer, quant_config=quant_config, run_search=True, run_quant=False)

    # Save search results
    model.save_quantized(dump_path)

    # Save tokenizer
    tokenizer.save_pretrained(dump_path)

def run_quant(model_path, search_path, dump_path, quant_config):
    """
    Step 2/2: Use the search results to quantize model weights
    """
    # Load model and search results
    model = AutoAWQForCausalLM.from_pretrained(model_path)
    load_search_result_into_memory(model.model, search_path)

    # Run actual weight quantization
    model.quantize(quant_config=quant_config, run_search=False, run_quant=True)

    # Save quantized model
    model.save_quantized(dump_path)

def run_eval(model_path, quant_file, device, tasks, task_batch_size, task_n_shot, task_use_pretrained):
    """
    Post quantization: Evaluate perplexity on wikitext with EleutherAI Evaluation Harness
    """
    # Load model
    if task_use_pretrained:
        model = AutoAWQForCausalLM.from_pretrained(model_path)
    else:
        model = AutoAWQForCausalLM.from_quantized(model_path, quant_file)

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Load adapter
    lm_eval_model = LMEvalAdaptor(model_path, model, tokenizer, device, batch_size=task_batch_size)

    # Evaluate perplexity of quantized model
    results = evaluator.simple_evaluate(
        model=lm_eval_model,
        tasks=tasks.split(','),
        batch_size=task_batch_size,
        no_cache=True,
        num_fewshot=task_n_shot,
    )

    print(evaluator.make_table(results))

@torch.inference_mode()
def run_speed(model_path, quant_file, device, n_generate=128, n_context=256, batch_size=1, disable_fused_layers=False):
    def _timer(func):
        start = time.time()
        out = func()
        return out, time.time() - start

    def _warmup(device: str):
        warm_up = torch.randn((4096, 4096)).to(device)
        torch.mm(warm_up, warm_up)

    if quant_file:
        fuse_layers = False if disable_fused_layers else True
        model, load_time = _timer(lambda: AutoAWQForCausalLM.from_quantized(model_path, quant_file, fuse_layers=fuse_layers))
    else:
        model, load_time = _timer(lambda: AutoAWQForCausalLM.from_pretrained(model_path))

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
    _warmup(device)

    # Generate random inputs
    n_context = n_context - n_generate
    ids = torch.randint(0, tokenizer.vocab_size, (batch_size, n_context)).cuda()

    # Context stage
    _, context_time = _timer(lambda: model.generate(ids,
        generation_config=GenerationConfig(max_new_tokens=0, min_new_tokens=0, use_cache=True)
    ))

    # Generation stage
    _, generation_time = _timer(lambda: model.generate(ids,
        generation_config=GenerationConfig(
            max_new_tokens=n_context,
            min_new_tokens=n_context,
            forced_eos_token_id=-100,
            pad_token_id=tokenizer.pad_token_id,
            eos_token_id=-100,
            use_cache=True
        )
    ))

    # Prints
    memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 2)
    context_tokens_per_second = n_context / context_time * batch_size
    context_ms_per_token = (context_time * 1000) / n_context / batch_size
    inference_tokens_per_second = n_generate / generation_time * batch_size
    inference_ms_per_token = (generation_time * 1000) / n_generate / batch_size

    print(f"[=] Model summary: {model_path} [=]")
    print(f"[*] Load time: {load_time:.2f} seconds")
    print(f"[*] Context speed: {context_tokens_per_second:.2f} tokens/second ({context_ms_per_token:.2f} ms/token)")
    print(f"[*] Generation speed: {inference_tokens_per_second:.2f} tokens/second ({inference_ms_per_token:.2f} ms/token)")
    print(f"[*] VRAM: {memory_used:.2f} MB")

if __name__ == '__main__':
    """
    - Run AWQ search and save result:
    python -m awq.entry --entry_type search --model_path lmsys/vicuna-7b-v1.5 --search_path vicuna-7b-v1.5-awq

    - Run AWQ to save the real quantized weights at the quant_path:
    python -m awq.entry --entry_type quant --model_path lmsys/vicuna-7b-v1.5 --search_path vicuna-7b-v1.5-awq/awq_model_search_result.pt --quant_path vicuna-7b-v1.5-awq

    - Run perplexity of quantized model:
    python -m awq.entry --entry_type eval --model_path vicuna-7b-v1.5-awq --quant_file awq_model_w4_g128.pt

    - Run perplexity unquantized FP16 model:
    python -m awq.entry --entry_type eval --model_path lmsys/vicuna-7b-v1.5 --task_use_pretrained

    - Run a speedtest to benchmark the quantized model:
    python -m awq.entry --entry_type speed --model_path vicuna-7b-v1.5-awq --quant_file awq_model_w4_g128.pt --n_generate 128 --n_context 256

    - Run a speedtest to benchmark the unquantized FP16 model:
    python -m awq.entry --entry_type speed --model_path lmsys/vicuna-7b-v1.5 --n_generate 128 --n_context 256
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--entry_type', type=str, help='The type of task to run (search|quant|eval|speed)')
    parser.add_argument('--model_path', type=str, help='Path to hf model')
    parser.add_argument('--search_path', type=str, help='Path to save/load AWQ search results')
    parser.add_argument('--quant_path', type=str, help='Path to save AWQ model to directory')
    parser.add_argument('--quant_file', type=str, help='Path to quantized AWQ model file')
    parser.add_argument('--device', type=str, default='cuda:0', help='Device to load model to')
    parser.add_argument('--w_bit', type=int, default=4)
    parser.add_argument('--q_group_size', type=int, default=128)
    parser.add_argument('--tasks', type=str, default='wikitext', help='Tasks to evaluate. '
                        'Separate tasks by comma for multiple tasks.'
                        'https://github.com/EleutherAI/lm-evaluation-harness/blob/master/docs/task_table.md')
    parser.add_argument("--task_use_pretrained", default=False, action='store_true',
                        help="Pass '--task_use_pretrained' to use a pretrained model running FP16")
    parser.add_argument('--task_batch_size', type=int, default=1)
    parser.add_argument('--task_n_shot', type=int, default=0)
    parser.add_argument('--n_generate', type=int, default=128)
    parser.add_argument('--n_context', type=int, default=256)
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument("--disable_fused_layers", default=False, action='store_true',
                        help="Pass '--disable_fused_layers' to disable fused layers")
    args = parser.parse_args()

    quant_config = {"zero_point": True, "q_group_size": args.q_group_size, "w_bit": args.w_bit}

    if args.entry_type == 'search':
        run_search(args.model_path, args.search_path, quant_config)
    elif args.entry_type == 'quant':
        run_quant(args.model_path, args.search_path, args.quant_path, quant_config)
    elif args.entry_type == 'eval':
        run_eval(args.model_path, args.quant_file, args.device, args.tasks,
                 args.task_batch_size, args.task_n_shot, args.task_use_pretrained)
    elif args.entry_type == 'speed':
        run_speed(args.model_path, args.quant_file, args.device, args.n_generate,
                  args.n_context, args.batch_size, args.disable_fused_layers)
    else:
        raise Exception('--entry_type must be one of (search|quant|eval|speed)')
awq/models/base.py

...
@@ -2,26 +2,19 @@ import os
Imports: the merge removes base.py's direct quantization helpers (logging, functools,
collections.defaultdict, awq.utils.calib_data.get_calib_dataset, awq.quantize.quantizer.pseudo_quantize_tensor,
awq.quantize.auto_clip auto_clip_block/apply_clip, awq.quantize.auto_scale auto_scale_block/apply_scale,
and the wide awq.utils.module import of append_str_prefix/get_op_name/get_named_linears/set_op_by_name)
and instead imports AwqQuantizer from awq.quantize.quantizer plus a narrower awq.utils.module import
(get_named_linears, set_op_by_name). The remaining imports (gc, json, torch, torch.nn, tqdm, typing,
safetensors.torch.save_file, ScaledActivation, snapshot_download, simple_dispatch_model,
shard_checkpoint, WQLinear_GEMM/WQLinear_GEMV, transformers, accelerate) are unchanged.

...
@@ -43,238 +36,64 @@ class BaseAWQForCausalLM(nn.Module):
New version (de8e7ffc): quantize() delegates to AwqQuantizer and save_quantized() no longer handles
intermediate search results.

        return self.model.generate(*args, **kwargs)

    @torch.no_grad()
    def quantize(self, tokenizer=None, quant_config={},
                       calib_data: Union[str, List[str]] = "pileval",
                       split="train", text_column="text"):
        self.quant_config = quant_config
        quant_config["version"] = "GEMM" if 'version' not in quant_config.keys() else quant_config["version"]

        quantizer = AwqQuantizer(
            self, self.model, tokenizer, quant_config["w_bit"], quant_config["q_group_size"],
            quant_config["version"], calib_data, split, text_column
        )
        quantizer.quantize()
        self.is_quantized = True

    @staticmethod
    def fuse_layers(model, quant_config):
        pass

    def save_quantized(self, save_dir, safetensors=False, shard_size="10GB"):
        save_dir = save_dir[:-1] if save_dir[-1] == '/' else save_dir

        # Save model
        class EmptyModule(nn.Module):
            def __init__(self):
                super(EmptyModule, self).__init__()

            def forward(self, x):
                return x

        # Save model files with empty state dict
        self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())

        # Remove empty state dict
        os.remove(f'{save_dir}/pytorch_model.bin')

        # model_name has no extension, add it when saving state_dict
        model_name = 'model.safetensors' if safetensors else 'pytorch_model.bin'

        # shard checkpoint into chunks (10GB default)
        shards, index = shard_checkpoint(
            self.model.state_dict(),
            max_shard_size=shard_size,
            weights_name=model_name
        )

        for shard_file, shard in shards.items():
            if safetensors:
                # safetensors must be in the same memory, so we duplicate and use contiguous memory
                shard = {k: v.clone().contiguous() for k, v in shard.items()}
                save_file(shard, os.path.join(save_dir, shard_file), metadata={"format": "pt"})
            else:
                torch.save(shard, os.path.join(save_dir, shard_file))

        # save shard index
        if index is not None:
            with open(f'{save_dir}/{model_name}.index.json', 'w+') as file:
                file.write(json.dumps(index, indent=4))

        # Save config
        with open(f'{save_dir}/quant_config.json', 'w+') as file:
            file.write(json.dumps(self.quant_config, indent=4))

Removed (old version at 19568c52): the two-phase API

    def quantize(self, tokenizer=None, quant_config={}, n_samples=128, seqlen=512,
                 auto_scale=True, mse_range=True, run_search=True, run_quant=True,
                 calib_data: Union[str, List[str]] = "pileval", split="train", text_column="text"):

which called _awq_search() (calibration samples from get_calib_dataset, a Catcher module wrapped
around layer 0 to capture its inputs and kwargs via an early-exit ValueError, per-layer forward hooks
(cache_input_hook) to collect input features, auto_scale_block/apply_scale and
auto_clip_block/apply_clip per layer, results accumulated into awq_results["scale"]/["clip"] with
append_str_prefix) and _awq_quant() (per-layer pseudo_quantize_tensor followed by replacement with
WQLinear_GEMM/WQLinear_GEMV.from_linear), together with a save_quantized() variant whose nested
_save_files() could also torch.save() the search result as awq_model_search_result.pt when
self.search_result was set and the model was not yet quantized. The same search/quantize logic now
lives in AwqQuantizer (awq/quantize/quantizer.py, below).

...
    @classmethod
    def from_pretrained(self, model_path, model_type, torch_dtype: torch.dtype = torch.float16,
...
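Both the removed _awq_search() and the new AwqQuantizer.init_quant() (further down) capture the
inputs to the first decoder layer by temporarily wrapping it in a "Catcher" module that records its
arguments and then aborts the forward pass. A minimal standalone sketch of that trick, using a
hypothetical two-layer toy model rather than anything from this repository:

    import torch
    import torch.nn as nn

    class Catcher(nn.Module):
        """Wraps a module, records the inputs it receives, then aborts the forward pass."""
        def __init__(self, module, inps, kwargs_store):
            super().__init__()
            self.module = module
            self.inps = inps
            self.kwargs_store = kwargs_store

        def forward(self, hidden_states, **kwargs):
            self.inps.append(hidden_states)
            self.kwargs_store.update(kwargs)
            raise ValueError  # early exit: only the layer-0 inputs are needed

    # Hypothetical toy "model": two Linear blocks standing in for decoder layers.
    layers = nn.ModuleList([nn.Linear(16, 16), nn.Linear(16, 16)])

    inps, layer_kwargs = [], {}
    layers[0] = Catcher(layers[0], inps, layer_kwargs)   # patch layer 0
    model = nn.Sequential(*layers)

    try:
        model(torch.randn(4, 16))    # runs up to layer 0, then aborts
    except ValueError:
        pass

    layers[0] = layers[0].module     # restore the original layer
    first_layer_inputs = inps[0]     # calibration inputs for layer 0
    print(first_layer_inputs.shape)  # torch.Size([4, 16])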
awq/models/opt.py

...
@@ -27,10 +27,12 @@ class OptAWQForCausalLM(BaseAWQForCausalLM):
The only change in this hunk is the line-wrapping of the q/k/v projection list; the set of scaled
layers is unchanged:

        # attention input
        layers.append(dict(
            prev_op=module.self_attn_layer_norm,
            layers=[module.self_attn.q_proj,
                    module.self_attn.k_proj, module.self_attn.v_proj],
            inp=input_feat['self_attn.q_proj'],
            module2inspect=module.self_attn, kwargs=module_kwargs,
        ))

        # attention out
...
awq/quantize/auto_clip.py  (deleted, 100644 → 0)

import torch
import torch.nn as nn
from .quantizer import pseudo_quantize_tensor
import gc

__all__ = ["auto_clip_block"]

# weight quantization
@torch.no_grad()
def auto_clip_layer(w, input_feat, quant_config,
                    n_grid=20, max_shrink=0.5, n_sample_token=512):
    assert w.dim() == 2
    org_w_shape = w.shape
    # w           [co, ci]      -> [co, 1, n_group, group size]
    # input_feat  [n_token, ci] -> [1, n_token, n_group, group size]
    group_size = quant_config["q_group_size"] if quant_config["q_group_size"] > 0 else w.shape[1]
    input_feat = input_feat.view(-1, input_feat.shape[-1])
    input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
    input_feat = input_feat[:, 0::input_feat.shape[1] // n_sample_token]
    w = w.reshape(w.shape[0], 1, -1, group_size)

    oc_batch_size = 256 if w.shape[0] % 256 == 0 else 64  # prevent OOM
    assert w.shape[0] % oc_batch_size == 0
    w_all = w
    best_max_val_all = []

    for i_b in range(w.shape[0] // oc_batch_size):
        w = w_all[i_b * oc_batch_size: (i_b + 1) * oc_batch_size]

        org_max_val = w.abs().amax(dim=-1, keepdim=True)  # co, 1, n_group, 1

        best_max_val = org_max_val.clone()
        min_errs = torch.ones_like(org_max_val) * 1e9
        input_feat = input_feat.to(w.device)
        org_out = (input_feat * w).sum(dim=-1)  # co, n_token, n_group

        for i_s in range(int(max_shrink * n_grid)):
            max_val = org_max_val * (1 - i_s / n_grid)
            min_val = -max_val
            cur_w = torch.clamp(w, min_val, max_val)
            q_w = pseudo_quantize_tensor(cur_w, w_bit=quant_config["w_bit"], q_group_size=quant_config["q_group_size"])
            cur_out = (input_feat * q_w).sum(dim=-1)

            # co, 1, n_group, 1
            err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
            del cur_w
            del cur_out
            cur_best_idx = err < min_errs
            min_errs[cur_best_idx] = err[cur_best_idx]
            best_max_val[cur_best_idx] = max_val[cur_best_idx]
        best_max_val_all.append(best_max_val)

    best_max_val = torch.cat(best_max_val_all, dim=0)

    del input_feat
    del org_out
    gc.collect()
    torch.cuda.empty_cache()
    return best_max_val.squeeze(1)

@torch.no_grad()
def auto_clip_block(module, quant_config, input_feat):
    named_linears = {name: m for name, m in module.named_modules() if isinstance(m, nn.Linear)}

    clip_list = []
    for name in named_linears:
        # due to qk bmm, it is hard to clip precisely
        if any([_ in name for _ in ["q_", "k_", "query", "key", "Wqkv"]]):
            continue
        named_linears[name].cuda()
        max_val = auto_clip_layer(
            named_linears[name].weight, input_feat[name], quant_config=quant_config)
        clip_list.append((name, max_val))
        named_linears[name].cpu()
    return clip_list

@torch.no_grad()
def apply_clip(module, clip_list):
    from ..utils.module import get_op_by_name
    for name, max_val in clip_list:
        layer = get_op_by_name(module, name)
        layer.cuda()
        max_val = max_val.to(layer.weight.device)
        org_shape = layer.weight.shape
        layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
        layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val)
        layer.weight.data = layer.weight.data.reshape(org_shape)
        layer.cpu()
awq/quantize/quantizer.py

This file is almost entirely rewritten. The old 40-line version exposed a single module-level
function, pseudo_quantize_tensor(w, w_bit=4, zero_point=True, q_group_size=-1, inplace=False,
get_scale_zp=False), which additionally carried a never-used absmax branch (zero_point=False,
symmetric min_int/max_int) and an in-place variant; that function is folded into the
AwqQuantizer.pseudo_quantize_tensor method below. New version:

import torch
import logging
import functools
import torch.nn as nn
from tqdm import tqdm
from collections import defaultdict
from awq.utils.utils import clear_memory
from awq.utils.calib_data import get_calib_dataset
from awq.quantize.scale import apply_scale, apply_clip
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name

class AwqQuantizer:
    def __init__(self, awq_model, model, tokenizer, w_bit, group_size, version,
                 calib_data, split, text_column) -> None:
        self.awq_model = awq_model
        self.model = model
        self.tokenizer = tokenizer
        self.w_bit = w_bit
        self.group_size = group_size
        self.version = version
        self.calib_data = calib_data
        self.split = split
        self.text_column = text_column
        self.modules, self.module_kwargs, self.inps = self.init_quant()

    # core quantization method (simulated quantization)
    def pseudo_quantize_tensor(self, w: torch.Tensor, get_scale_zp=False):
        org_w_shape = w.shape
        if self.group_size > 0:
            assert org_w_shape[-1] % self.group_size == 0
            w = w.reshape(-1, self.group_size)
        assert w.dim() == 2

        # zero point quantization
        max_val = w.amax(dim=1, keepdim=True)
        min_val = w.amin(dim=1, keepdim=True)
        max_int = 2 ** self.w_bit - 1
        min_int = 0
        scales = (max_val - min_val).clamp(min=1e-5) / max_int
        zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)

        assert torch.isnan(scales).sum() == 0
        assert torch.isnan(w).sum() == 0

        w = (torch.clamp(torch.round(w / scales) + zeros, min_int, max_int) - zeros) * scales
        assert torch.isnan(w).sum() == 0

        w = w.reshape(org_w_shape)

        if get_scale_zp:
            return w, scales.view(w.shape[0], -1), zeros.view(w.shape[0], -1)
        else:
            return w

    def quantize(self):
        for i in tqdm(range(len(self.modules)), desc="AWQ"):
            # [STEP 1]: Get layer, extract linear modules, extract input features
            self.modules[i] = self.modules[i].cuda()
            named_linears = get_named_linears(self.modules[i])
            input_feat = self._get_input_feat(self.modules[i], named_linears)
            clear_memory()

            # [STEP 2]: Compute and apply scale list
            module_config: list[dict] = self.awq_model.get_layers_for_scaling(
                self.modules[i], input_feat, self.module_kwargs
            )
            scales_list = [self._search_best_scale(self.modules[i], **layer) for layer in module_config]
            apply_scale(self.modules[i], scales_list, input_feat_dict=input_feat)
            scales_list = append_str_prefix(scales_list, get_op_name(self.model, self.modules[i]) + ".")

            # [STEP 3]: Compute and apply clipping list
            clip_list = self._search_best_clip(self.modules[i], named_linears, input_feat)
            apply_clip(self.modules[i], clip_list)
            clip_list = append_str_prefix(clip_list, get_op_name(self.model, self.modules[i]) + ".")

            # [STEP 4]: Quantize weights
            self._apply_quant(self.modules[i], named_linears)
            clear_memory()

    def _apply_quant(self, module, named_linears: dict[str, nn.Linear]):
        for name, linear_layer in named_linears.items():
            # NOTE: small regression in perplexity if linear layer uses .cpu().float()
            linear_layer = linear_layer.cuda().half()

            linear_layer.weight.data, scales, zeros = self.pseudo_quantize_tensor(
                linear_layer.weight.data,
                get_scale_zp=True
            )

            if self.version == 'GEMM':
                scales = scales.t().contiguous()
                zeros = zeros.t().contiguous()
                q_linear_module = WQLinear_GEMM
            elif self.version == 'GEMV':
                q_linear_module = WQLinear_GEMV

            q_linear = q_linear_module.from_linear(
                linear=linear_layer,
                w_bit=self.w_bit,
                group_size=self.group_size,
                init_only=False,
                scales=scales,
                zeros=zeros
            )

            linear_layer.cpu()
            q_linear.to(next(module.parameters()).device)
            set_op_by_name(module, name, q_linear)
            clear_memory()

    @torch.no_grad()
    def _search_best_scale(self, module, prev_op, layers: list[nn.Linear], inp: torch.Tensor,
                           module2inspect=None, kwargs={}):
        if module2inspect is None:
            assert len(layers) == 1
            module2inspect = layers[0]

        if "use_cache" in kwargs:
            kwargs.pop("use_cache")

        # Put x on the right device
        inp = inp.to(next(module2inspect.parameters()).device)

        # [STEP 1]: Compute maximum of weight
        weight = torch.cat([_m.weight for _m in layers], dim=0)
        org_shape = weight.shape
        weight = weight.view(-1, self.group_size)
        w_scale = weight.abs() / weight.abs().amax(dim=1, keepdim=True)
        w_scale = w_scale.view(org_shape)
        w_max = w_scale.mean(0)
        clear_memory(weight)

        # [STEP 2]: Compute maximum of x
        x_max = inp.abs().view(-1, inp.shape[-1]).mean(0)

        # [STEP 3]: Compute output of module
        with torch.no_grad():
            fp16_output = module2inspect(inp, **kwargs)
            if isinstance(fp16_output, tuple):
                fp16_output = fp16_output[0]

        # [STEP 4]: Compute loss
        best_scales = self._compute_best_scale(
            inp, w_max, x_max, module2inspect,
            layers, fp16_output, kwargs
        )

        return (get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), best_scales)

    def _compute_best_scale(self, x, w_max, x_max, module2inspect,
                            linears2scale: list[nn.Linear], fp16_output, kwargs={}):
        """
        Compute loss and select best scales

        L(s) = || Q(W * s) (s^-1 * X) - W * X ||
        Q: weight quantization function | pseudo_quantize_tensor(W * s)
        X: inputs from calib dataset    | X
        W: original weights in FP16     | layer
        s: per channel scaling factor   | s^-1 * X
        """
        n_grid = 20
        history = []
        best_ratio = -1
        best_scales = None
        best_error = float('inf')

        org_sd = {k: v.cpu() for k, v in module2inspect.state_dict().items()}

        device = x.device
        x_max = x_max.view(-1).to(device)
        w_max = w_max.view(-1).to(device)

        for ratio in range(n_grid):
            # create new scales
            ratio = ratio / n_grid

            # NOTE: s^-1 * x is fused here, according to paper
            scales = (x_max.pow(ratio) / w_max.pow(1 - ratio)).clamp(min=1e-4)
            scales = scales / (scales.max() * scales.min()).sqrt()
            scales_view = scales.view(1, -1).to(device)

            # Q(W * s)
            for fc in linears2scale:
                fc.weight.mul_(scales_view)
                fc.weight.data = self.pseudo_quantize_tensor(fc.weight.data) / scales_view

            # W * X
            int_w_output = module2inspect(x, **kwargs)
            if isinstance(int_w_output, tuple):
                int_w_output = int_w_output[0]

            # compute mean squared error (L2 norm)
            loss = (fp16_output - int_w_output).float().pow(2).mean().item()  # NOTE: float prevents overflow

            history.append(loss)
            if loss < best_error:
                best_error = loss
                best_ratio = ratio
                best_scales = scales.clone()
            module2inspect.load_state_dict(org_sd)

        if best_ratio == -1:
            logging.debug(history)
            raise Exception

        assert torch.isnan(best_scales).sum() == 0, best_scales

        return best_scales.detach().cpu()

    @torch.no_grad()
    def _search_best_clip(self, layer, named_linears, input_feat):
        clip_list = []
        avoid_clipping = ["q_", "k_", "query", "key", "Wqkv"]

        for name in named_linears:
            # due to qk bmm, it is hard to clip precisely
            if any([_ in name for _ in avoid_clipping]):
                continue

            named_linears[name].cuda()
            max_val = self._compute_best_clip(named_linears[name].weight, input_feat[name])
            clip_list.append((name, max_val))
            named_linears[name].cpu()

        return clip_list

    @torch.no_grad()
    def _compute_best_clip(self, w: torch.Tensor, input_feat: torch.Tensor,
                           n_grid=20, max_shrink=0.5, n_sample_token=512):
        assert w.dim() == 2
        org_w_shape = w.shape
        # w           [co, ci]      -> [co, 1, n_group, group size]
        # input_feat  [n_token, ci] -> [1, n_token, n_group, group size]
        group_size = self.group_size if self.group_size > 0 else w.shape[1]
        input_feat = input_feat.view(-1, input_feat.shape[-1])
        input_feat = input_feat.reshape(1, input_feat.shape[0], -1, group_size)
        input_feat = input_feat[:, 0::input_feat.shape[1] // n_sample_token]
        w = w.reshape(w.shape[0], 1, -1, group_size)

        oc_batch_size = 256 if w.shape[0] % 256 == 0 else 64  # prevent OOM
        assert w.shape[0] % oc_batch_size == 0
        w_all = w
        best_max_val_all = []

        for i_b in range(w.shape[0] // oc_batch_size):
            w = w_all[i_b * oc_batch_size: (i_b + 1) * oc_batch_size]

            org_max_val = w.abs().amax(dim=-1, keepdim=True)  # co, 1, n_group, 1

            best_max_val = org_max_val.clone()
            min_errs = torch.ones_like(org_max_val) * 1e9
            input_feat = input_feat.to(w.device)
            org_out = (input_feat * w).sum(dim=-1)  # co, n_token, n_group

            for i_s in range(int(max_shrink * n_grid)):
                max_val = org_max_val * (1 - i_s / n_grid)
                min_val = -max_val
                cur_w = torch.clamp(w, min_val, max_val)
                q_w = self.pseudo_quantize_tensor(cur_w)
                cur_out = (input_feat * q_w).sum(dim=-1)

                # co, 1, n_group, 1
                err = (cur_out - org_out).pow(2).mean(dim=1).view(min_errs.shape)
                del cur_w
                del cur_out
                cur_best_idx = err < min_errs
                min_errs[cur_best_idx] = err[cur_best_idx]
                best_max_val[cur_best_idx] = max_val[cur_best_idx]
            best_max_val_all.append(best_max_val)

        best_max_val = torch.cat(best_max_val_all, dim=0)

        clear_memory(input_feat)
        clear_memory(org_out)

        return best_max_val.squeeze(1)

    def init_quant(self, n_samples=128, seqlen=512):
        modules = self.awq_model.get_model_layers(self.model)
        samples = get_calib_dataset(
            data=self.calib_data, tokenizer=self.tokenizer, n_samples=n_samples,
            block_size=seqlen, split=self.split, text_column=self.text_column
        )
        samples = torch.cat(samples, dim=0)

        inps = []
        layer_kwargs = {}

        modules[0] = modules[0].cuda()
        self.awq_model.move_embed(self.model, "cuda")

        # get input and kwargs to layer 0
        # with_kwargs is only supported in PyTorch 2.0
        # use this Catcher hack for now
        class Catcher(nn.Module):
            def __init__(self, module):
                super().__init__()
                self.module = module

            def forward(self, hijacked_inputs, **kwargs):
                inps.append(hijacked_inputs)
                layer_kwargs.update(kwargs)
                raise ValueError  # early exit to break later inference

        # patch layer 0 to catch input and kwargs
        modules[0] = Catcher(modules[0])
        try:
            self.model(samples.to(next(self.model.parameters()).device))
        except ValueError:  # work with early exit
            pass
        del samples
        modules[0] = modules[0].module  # restore
        inps = inps[0]

        modules[0] = modules[0].cpu()
        self.awq_model.move_embed(self.model, "cpu")

        clear_memory()

        return modules, layer_kwargs, inps

    def _get_input_feat(self, layer, named_linears):
        # firstly, get input features of all linear layers
        def cache_input_hook(m, x, y, name, feat_dict):
            x = x[0]
            x = x.detach().cpu()
            feat_dict[name].append(x)

        input_feat = defaultdict(list)
        handles = []
        for name in named_linears:
            handles.append(named_linears[name].register_forward_hook(
                functools.partial(cache_input_hook, name=name, feat_dict=input_feat)))
        self.inps = self.inps.to(next(layer.parameters()).device)  # in case multi-gpu
        # get output as next layer's input
        self.inps = layer(self.inps, **self.module_kwargs)[0]
        for h in handles:
            h.remove()
        # now solve for scaling and clipping
        input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}

        return input_feat
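To make the arithmetic in pseudo_quantize_tensor concrete, here is a small self-contained sketch of
the same zero-point scheme (4-bit, a single group) on a toy weight row; the tensor values are
illustrative only and the printed results are approximate:

    import torch

    w = torch.tensor([[-0.50, -0.10, 0.20, 0.40]])   # one output channel, group_size = 4
    w_bit = 4

    # per-group range -> scale and zero point, as in pseudo_quantize_tensor
    max_val = w.amax(dim=1, keepdim=True)                               #  0.40
    min_val = w.amin(dim=1, keepdim=True)                               # -0.50
    max_int, min_int = 2 ** w_bit - 1, 0                                # 15, 0
    scales = (max_val - min_val).clamp(min=1e-5) / max_int              # 0.9 / 15 = 0.06
    zeros = (-torch.round(min_val / scales)).clamp_(min_int, max_int)   # round(0.5 / 0.06) = 8

    # quantize to integers, then dequantize back to "fake-quantized" FP weights
    q = torch.clamp(torch.round(w / scales) + zeros, min_int, max_int)
    w_dq = (q - zeros) * scales

    print(q)      # tensor([[ 0.,  6., 11., 15.]])
    print(w_dq)   # roughly tensor([[-0.48, -0.12,  0.18,  0.42]])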
awq/quantize/auto_scale.py → awq/quantize/scale.py  (renamed and rewritten)

New version (de8e7ffc): the file now only holds the scale/clip application helpers; the grid-search
machinery moves to AwqQuantizer (see quantizer.py above).

import torch
import torch.nn as nn
from typing import Tuple
from awq.modules.act import ScaledActivation
from awq.utils.module import get_op_by_name, get_op_name, set_op_by_name
from transformers.activations import NewGELUActivation
from transformers.models.bloom.modeling_bloom import BloomGelu
from transformers.models.llama.modeling_llama import LlamaRMSNorm

allowed_norms = [nn.LayerNorm, LlamaRMSNorm]
allowed_act_fns = [nn.GELU, BloomGelu, NewGELUActivation]

@torch.no_grad()
def apply_clip(module, clip_list: Tuple[str, torch.Tensor]):
    for name, max_val in clip_list:
        layer: nn.Linear = get_op_by_name(module, name)
        layer.cuda()
        max_val = max_val.to(layer.weight.device)
        org_shape = layer.weight.shape
        layer.weight.data = layer.weight.data.reshape(*max_val.shape[:2], -1)
        layer.weight.data = torch.clamp(layer.weight.data, -max_val, max_val)
        layer.weight.data = layer.weight.data.reshape(org_shape)
        layer.cpu()

@torch.no_grad()
def apply_scale(module, scales_list, input_feat_dict=None):
    for prev_op_name, layer_names, scales in scales_list:
        prev_op = get_op_by_name(module, prev_op_name)
        layers = [get_op_by_name(module, name) for name in layer_names]

        prev_op.cuda()
        for layer in layers:
            layer.cuda()
        scales.cuda()

        if isinstance(prev_op, nn.Linear):
            assert len(layers) == 1
            scale_fc_fc(prev_op, layers[0], scales)

        elif any(isinstance(prev_op, t) for t in allowed_norms) \
                or 'rmsnorm' in str(prev_op.__class__).lower():
            scale_ln_fcs(prev_op, layers, scales)

        elif any(isinstance(prev_op, t) for t in allowed_act_fns):
            new_module = ScaledActivation(prev_op, scales)
            set_op_by_name(module, prev_op_name, new_module)
            scale_gelu_fc(prev_op, layers[0], scales)

        else:
            raise NotImplementedError(f"prev_op {type(prev_op)} not supported yet!")

        # apply the scaling to input feat if given; prepare it for clipping
        if input_feat_dict is not None:
            for layer_name in layer_names:
                inp = input_feat_dict[layer_name]
                inp.div_(scales.view(1, -1).to(inp.device))

        prev_op.cpu()
        for layer in layers:
            layer.cpu()
        scales.cpu()

@torch.no_grad()
def scale_ln_fcs(ln: nn.Linear, fcs: list[nn.Linear], scales: torch.Tensor):
    if not isinstance(fcs, list):
        fcs = [fcs]

    scales = scales.to(ln.weight.device)

    ln.weight.div_(scales)
    if hasattr(ln, 'bias') and ln.bias is not None:
        ln.bias.div_(scales)
...
@@ -58,16 +82,13 @@ def scale_ln_fcs(ln, fcs, scales):
    for p in fc.parameters():
        assert torch.isnan(p).sum() == 0

@torch.no_grad()
def scale_fc_fc(fc1: nn.Linear, fc2: nn.Linear, scales: torch.Tensor):
    assert isinstance(fc1, nn.Linear)
    assert isinstance(fc2, nn.Linear)

    scales = scales.to(fc1.weight.device)

    fc1.weight[-scales.size(0):].div_(scales.view(-1, 1))
    if fc1.bias is not None:
        fc1.bias.div_(scales.view(-1))
...
@@ -81,141 +102,11 @@ def scale_fc_fc(fc1, fc2, scales):

@torch.no_grad()
def scale_gelu_fc(gelu: allowed_act_fns, fc: nn.Linear, scales: torch.Tensor):
    assert any(isinstance(gelu, t) for t in allowed_act_fns)
    assert isinstance(fc, nn.Linear)

    fc.weight.mul_(scales.view(1, -1).to(fc.weight.device))

    for p in fc.parameters():
        assert torch.isnan(p).sum() == 0

Removed from the old auto_scale.py: the gc, logging, BloomBlock, OPTDecoderLayer, LlamaDecoderLayer,
and PytorchGELUTanh imports; the __all__ = ["auto_scale_block", "apply_scale"] export list; the
norms/act_functions lists (which additionally allowed PytorchGELUTanh); the get_weight_scale() and
get_act_scale() helpers; the grid-search functions auto_scale_block(), _search_module_scale(), and
_auto_get_scale() (their logic now lives in AwqQuantizer._search_best_scale/_compute_best_scale); a
commented-out "even scales = 1 does not work?" debugging block inside scale_ln_fcs(); the
"# assert fc1.out_features == fc2.in_features" and "# fc1.weight.div_(scales.view(-1, 1))" comments
in scale_fc_fc(); and the old apply_scale(), which was otherwise identical to the new one above but
dispatched on norms/act_functions.
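A quick numerical check of the invariance scale_ln_fcs relies on: dividing the norm's weight (and
bias) by the per-channel scales while multiplying the following linear layer's input columns by the
same scales leaves the composed output unchanged, so the scaling only reshapes the quantization
problem. This is a standalone sketch with made-up sizes, not code from the repository:

    import torch
    import torch.nn as nn

    torch.manual_seed(0)
    ln = nn.LayerNorm(8)
    fc = nn.Linear(8, 8)
    x = torch.randn(2, 8)

    ref = fc(ln(x))

    # per-channel scales, standing in for what the AWQ search would produce
    s = torch.rand(8) + 0.5

    with torch.no_grad():
        ln.weight.div_(s)
        ln.bias.div_(s)
        fc.weight.mul_(s.view(1, -1))   # scale the input columns of fc

    out = fc(ln(x))
    print(torch.allclose(ref, out, atol=1e-5))   # True, up to floating-point error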
awq/utils/utils.py

Adds `import gc` next to the existing torch/accelerate imports and a new helper at the end of the
file:

...
@@ -53,3 +54,9 @@ def set_module_name(model, name, value):
        child_name = name

    setattr(parent, child_name, value)

def clear_memory(weight=None):
    if weight is not None:
        del weight
    gc.collect()
    torch.cuda.empty_cache()
examples/eval.py  (new file, 0 → 100644)

import argparse
from lm_eval import evaluator
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer
from awq.utils.lm_eval_adaptor import LMEvalAdaptor

def run_eval(model_path, quant_file, device, tasks, task_batch_size, task_n_shot, task_use_pretrained):
    """
    Post quantization: Evaluate perplexity on wikitext with EleutherAI Evaluation Harness
    """
    # Load model
    if task_use_pretrained:
        model = AutoAWQForCausalLM.from_pretrained(model_path)
    else:
        model = AutoAWQForCausalLM.from_quantized(model_path, quant_file, fuse_layers=False)

    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    # Load adapter
    lm_eval_model = LMEvalAdaptor(model_path, model, tokenizer, device, batch_size=task_batch_size)

    # Evaluate perplexity of quantized model
    results = evaluator.simple_evaluate(
        model=lm_eval_model,
        tasks=tasks.split(','),
        batch_size=task_batch_size,
        no_cache=True,
        num_fewshot=task_n_shot,
    )

    print(evaluator.make_table(results))

if __name__ == '__main__':
    """
    - Run perplexity of quantized model:
    python examples/eval.py --model_path vicuna-7b-v1.5-awq --quant_file awq_model_w4_g128.pt

    - Run perplexity unquantized FP16 model:
    python examples/eval.py --use_pretrained --model_path lmsys/vicuna-7b-v1.5
    """
    parser = argparse.ArgumentParser()
    parser.add_argument('--model_path', type=str, help='Path to hf model')
    parser.add_argument('--quant_file', default='', type=str, help='Path to quantized AWQ model file')
    parser.add_argument('--device', type=str, default='cuda:0', help='Device to load model to')
    parser.add_argument("--use_pretrained", default=False, action='store_true',
                        help="Pass '--use_pretrained' to use a pretrained model running FP16")
    parser.add_argument('--tasks', type=str, default='wikitext', help='Tasks to evaluate. '
                        'Separate tasks by comma for multiple tasks.'
                        'https://github.com/EleutherAI/lm-evaluation-harness/blob/master/docs/task_table.md')
    parser.add_argument('--batch_size', type=int, default=1)
    parser.add_argument('--n_shot', type=int, default=0)
    args = parser.parse_args()

    run_eval(args.model_path, args.quant_file, args.device, args.tasks,
             args.batch_size, args.n_shot, args.use_pretrained)