OpenDAS / AutoAWQ · Commits

Commit 290f45e7, authored Aug 17, 2023 by Casper Hansen
Refactored pre_quant into base
Parent: d5be2115

Showing 7 changed files with 152 additions and 339 deletions.
awq/entry.py                 +16   -10
awq/models/base.py           +121  -2
awq/models/mpt.py            +4    -4
awq/quantize/auto_scale.py   +7    -146
awq/quantize/pre_quant.py    +0    -175
awq/quantize/quantizer.py    +1    -2
awq/utils/module.py          +3    -0
awq/entry.py  (+16, -10)

@@ -7,11 +7,11 @@ import json
 from accelerate import init_empty_weights, infer_auto_device_map, dispatch_model, load_checkpoint_in_model
 from accelerate.utils.modeling import get_balanced_memory
 from awq.utils.parallel import auto_parallel
-from awq.quantize.pre_quant import run_awq, apply_awq
 from awq.quantize.quantizer import pseudo_quantize_model_weight, real_quantize_model_weight
 from awq.utils.lm_eval_adaptor import LMEvalAdaptor
 from awq.utils.utils import simple_dispatch_model
+from awq.quantize.auto_clip import apply_clip
+from awq.quantize.auto_scale import apply_scale

 parser = argparse.ArgumentParser()
 parser.add_argument('--model_path', type=str, help='path of the hf model')

@@ -65,7 +65,13 @@ q_config = {
 }
 print("Quantization config:", q_config)

 # build model and tokenizer
+def get_awq_model(model):
+    from awq.models import MptAWQForCausalLM
+
+    if "mpt" in str(model.__class__).lower():
+        return MptAWQForCausalLM()
+    else:
+        raise NotImplementedError(type(model))

 def build_model_and_enc(model_path):
     if not os.path.exists(model_path):  # look into ssd

@@ -119,12 +125,10 @@ def build_model_and_enc(model_path):
     if args.run_awq:
         assert args.dump_awq, "Please save the awq results with --dump_awq"

-        awq_results = run_awq(
-            model, enc,
-            w_bit=args.w_bit, q_config=q_config,
-            n_samples=128, seqlen=512,
-        )
+        awq_model = get_awq_model(model)
+        awq_results = awq_model.quantize(model, enc, args.w_bit, q_config)

         if args.dump_awq:
             dirpath = os.path.dirname(args.dump_awq)
             os.makedirs(dirpath, exist_ok=True)

@@ -137,7 +141,9 @@ def build_model_and_enc(model_path):
     if args.load_awq:
         print("Loading pre-computed AWQ results from", args.load_awq)
         awq_results = torch.load(args.load_awq, map_location="cpu")
-        apply_awq(model, awq_results)
+        apply_scale(model, awq_results["scale"])
+        apply_clip(model, awq_results["clip"])

     # weight quantization
     if args.w_bit is not None:
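
Note: the net effect on the entry point is that the monolithic run_awq()/apply_awq() pair from pre_quant.py is replaced by an architecture-specific wrapper whose quantize() method performs the search, while saved results are re-applied with apply_scale()/apply_clip() directly. Below is a minimal sketch of that flow, assuming the imports and get_awq_model() from the diff above are in scope; quantize_and_dump and load_and_apply are illustrative helper names, not functions from this commit.

import os
import torch

def quantize_and_dump(model, enc, w_bit, q_config, dump_awq_path):
    # resolve the architecture-specific AWQ wrapper (only MPT is routed at this commit)
    awq_model = get_awq_model(model)

    # the scale/clip search now lives on the model class instead of run_awq()
    awq_results = awq_model.quantize(model, enc, w_bit, q_config)

    # persist the search results so they can be re-applied without re-running calibration
    os.makedirs(os.path.dirname(dump_awq_path), exist_ok=True)
    torch.save(awq_results, dump_awq_path)

def load_and_apply(model, load_awq_path):
    # apply_awq() was deleted along with pre_quant.py; its two steps are now called directly
    awq_results = torch.load(load_awq_path, map_location="cpu")
    apply_scale(model, awq_results["scale"])
    apply_clip(model, awq_results["clip"])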
awq/models/base.py  (+121, -2)

+import gc
+import tqdm
+import torch
+import functools
+import torch.nn as nn
+from collections import defaultdict
+from awq.utils.calib_data import get_calib_dataset
+from awq.quantize.auto_clip import auto_clip_block, apply_clip
+from awq.quantize.auto_scale import auto_scale_block, apply_scale
+from awq.utils.module import append_str_prefix, get_op_name, get_named_linears
+
 class BaseAWQForCausalLM:
-    def quantize():
-        pass
+    @torch.no_grad()
+    def quantize(self, model, tokenizer, w_bit, q_config,
+                 n_samples=128, seqlen=512,
+                 auto_scale=True, mse_range=True,
+                 calib_data="pileval"):
+        layers = self.get_model_layers(model)
+
+        samples = get_calib_dataset(
+            data=calib_data, tokenizer=tokenizer, n_samples=n_samples, block_size=seqlen)
+        samples = torch.cat(samples, dim=0)
+
+        inps = []
+        layer_kwargs = {}
+
+        layers[0] = layers[0].cuda()
+        self.move_embed(model, "cuda")
+
+        # get input and kwargs to layer 0
+        # with_kwargs is only supported in PyTorch 2.0
+        # use this Catcher hack for now
+        class Catcher(nn.Module):
+            def __init__(self, module):
+                super().__init__()
+                self.module = module
+
+            def forward(self, inp, **kwargs):
+                inps.append(inp)
+                layer_kwargs.update(kwargs)
+                raise ValueError  # early exit to break later inference
+
+        # patch layer 0 to catch input and kwargs
+        layers[0] = Catcher(layers[0])
+        try:
+            model(samples.to(next(model.parameters()).device))
+        except ValueError:  # work with early exit
+            pass
+        del samples
+        layers[0] = layers[0].module  # restore
+        inps = inps[0]
+
+        layers[0] = layers[0].cpu()
+        self.move_embed(model, "cpu")
+
+        gc.collect()
+        torch.cuda.empty_cache()
+
+        awq_results = {
+            "scale": [],
+            "clip": [],
+        }
+
+        # solve layer by layer
+        for i in tqdm.tqdm(range(len(layers)), desc="Running AWQ..."):
+            layer = layers[i]
+            layer = layer.cuda()
+            named_linears = get_named_linears(layer)
+
+            # firstly, get input features of all linear layers
+            def cache_input_hook(m, x, y, name, feat_dict):
+                x = x[0]
+                x = x.detach().cpu()
+                feat_dict[name].append(x)
+
+            input_feat = defaultdict(list)
+            handles = []
+            for name in named_linears:
+                handles.append(named_linears[name].register_forward_hook(
+                    functools.partial(cache_input_hook, name=name, feat_dict=input_feat)))
+            inps = inps.to(next(layer.parameters()).device)  # in case multi-gpu
+            # get output as next layer's input
+            inps = layer(inps, **layer_kwargs)[0]
+            for h in handles:
+                h.remove()
+            # now solve for scaling and clipping
+            input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}
+
+            # Clear GPU memory
+            torch.cuda.empty_cache()
+
+            if auto_scale:  # if it applies, we should also modify the input_feat with scales
+                scales_list = auto_scale_block(
+                    self,
+                    layer, layer_kwargs,
+                    w_bit=w_bit, q_config=q_config,
+                    input_feat=input_feat,
+                )
+                # apply_scale(layer, scales_list, input_feat_dict=input_feat)
+                apply_scale(layers[i], scales_list, input_feat_dict=input_feat)
+                # append prefix to make names global
+                awq_results["scale"] += append_str_prefix(scales_list, get_op_name(model, layer) + ".")
+
+            # Clear GPU memory
+            torch.cuda.empty_cache()
+
+            if mse_range:
+                clip_list = auto_clip_block(layer,
+                                            w_bit=w_bit, q_config=q_config,
+                                            input_feat=input_feat,)
+                apply_clip(layer, clip_list)
+                # append prefix to make names global
+                awq_results["clip"] += append_str_prefix(clip_list, get_op_name(model, layer) + ".")
+
+            layer = layer.cpu()
+            # Haotian: check activation replacement
+            del input_feat
+            gc.collect()
+            torch.cuda.empty_cache()
+
+        return awq_results

     def save_quantized():
         pass
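
Note: quantize() gathers calibration inputs for the first transformer block by temporarily wrapping it in the Catcher module, which records the block's input and keyword arguments and then raises to abort the rest of the forward pass. Below is a standalone sketch of that pattern on a toy model; the ToyLM class is illustrative and not from the repository.

import torch
import torch.nn as nn

class ToyLM(nn.Module):
    def __init__(self):
        super().__init__()
        self.embed = nn.Embedding(100, 16)
        self.blocks = nn.ModuleList([nn.Linear(16, 16) for _ in range(4)])

    def forward(self, ids):
        x = self.embed(ids)
        for block in self.blocks:
            x = block(x)
        return x

inps, layer_kwargs = [], {}

class Catcher(nn.Module):
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, inp, **kwargs):
        inps.append(inp)             # record the hidden states entering block 0
        layer_kwargs.update(kwargs)  # record forward kwargs (masks, caches, ...)
        raise ValueError             # early exit: no need to run the remaining blocks

model = ToyLM()
model.blocks[0] = Catcher(model.blocks[0])
try:
    model(torch.randint(0, 100, (2, 8)))
except ValueError:
    pass
model.blocks[0] = model.blocks[0].module  # restore the original block
print(inps[0].shape)  # calibration input for block 0, e.g. torch.Size([2, 8, 16])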
awq/models/mpt.py  (+4, -4)

@@ -3,10 +3,10 @@ from .base import BaseAWQForCausalLM
 class MptAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "MPTBlock"

-    def get_model_layers(model):
+    def get_model_layers(self, model):
         return model.transformer.blocks

-    def get_layers_for_scaling(module, input_feat, module_kwargs):
+    def get_layers_for_scaling(self, module, input_feat, module_kwargs):
         layers = []

         # attention input

@@ -42,13 +42,13 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
         return layers

-    def get_act_for_scaling(module):
+    def get_act_for_scaling(self, module):
         return dict(
             scale_name="ffn.act",
             scale_layer=module.ffn.act,
             scale_shape=module.ffn.up_proj.out_features
         )

-    def move_embed(model, device):
+    def move_embed(self, model, device):
         model.transformer.wte = model.transformer.wte.to(device)
         model.transformer.emb_drop = model.transformer.emb_drop.to(device)
\ No newline at end of file
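
Note: with the four hooks now taking self, the base class can drive any architecture that provides them: get_model_layers() for the block list, move_embed() for moving the embeddings between devices, get_act_for_scaling() for an activation to scale, and get_layers_for_scaling() for the scale groups. Below is a hypothetical sketch of what a second subclass could look like if added next to mpt.py; OptAWQForCausalLM does not exist in this commit, and the groupings are copied from the OPT branch deleted from auto_scale.py (next file).

from .base import BaseAWQForCausalLM

class OptAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "OPTDecoderLayer"

    def get_model_layers(self, model):
        # the list of transformer blocks that BaseAWQForCausalLM.quantize iterates over
        return model.model.decoder.layers

    def get_act_for_scaling(self, module):
        # the deleted per-model code did no extra activation scaling for OPT,
        # so a real implementation would have to decide what to return here
        raise NotImplementedError

    def get_layers_for_scaling(self, module, input_feat, module_kwargs):
        # each dict becomes one _auto_get_scale(**layer) call inside auto_scale_block
        return [
            # attention input
            dict(prev_op=module.self_attn_layer_norm,
                 layers=[module.self_attn.q_proj,
                         module.self_attn.k_proj,
                         module.self_attn.v_proj],
                 inp=input_feat['self_attn.q_proj'],
                 module2inspect=module.self_attn,
                 kwargs=module_kwargs),
            # attention output
            dict(prev_op=module.self_attn.v_proj,
                 layers=[module.self_attn.out_proj],
                 inp=input_feat['self_attn.out_proj']),
            # fc1
            dict(prev_op=module.final_layer_norm,
                 layers=[module.fc1],
                 inp=input_feat['fc1']),
            # fc2
            dict(prev_op=module.fc1,
                 layers=[module.fc2],
                 inp=input_feat['fc2']),
        ]

    def move_embed(self, model, device):
        model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(device)
        model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(device)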
awq/quantize/auto_scale.py  (+7, -146)

@@ -7,8 +7,7 @@ from transformers.models.opt.modeling_opt import OPTDecoderLayer
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMSNorm

 from .qmodule import ScaledActivation
-from ..utils.module import get_op_by_name, get_op_name, set_op_by_name
-from ..models import MptAWQForCausalLM
+from awq.utils.module import get_op_by_name, get_op_name, set_op_by_name

 __all__ = ["auto_scale_block", "apply_scale"]

@@ -90,7 +89,8 @@ def scale_gelu_fc(gelu, fc, scales):
 @torch.no_grad()
-def auto_scale_block(module, module_kwargs,
+def auto_scale_block(awq_model, module, module_kwargs,
                      w_bit, q_config,
                      input_feat):
     from .quantizer import pseudo_quantize_tensor

@@ -174,149 +174,10 @@ def auto_scale_block(module, module_kwargs,
         # prev_op_name, [layer_name], scale
         return (get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), scales)

-    scales_list = []  # return the searched scales
-
-    if isinstance(module, OPTDecoderLayer):
-        # attention input
-        scales_list.append(_auto_get_scale(
-            prev_op=module.self_attn_layer_norm,
-            layers=[module.self_attn.q_proj,
-                    module.self_attn.k_proj, module.self_attn.v_proj],
-            inp=input_feat['self_attn.q_proj'],
-            module2inspect=module.self_attn, kwargs=module_kwargs,
-        ))
-        # attn out
-        scales_list.append(_auto_get_scale(
-            prev_op=module.self_attn.v_proj,
-            layers=[module.self_attn.out_proj],
-            inp=input_feat['self_attn.out_proj'],
-        ))
-        # fc1
-        scales_list.append(_auto_get_scale(
-            prev_op=module.final_layer_norm,
-            layers=[module.fc1],
-            inp=input_feat['fc1'],
-        ))
-        # fc2
-        scales_list.append(_auto_get_scale(
-            prev_op=module.fc1,
-            layers=[module.fc2],
-            inp=input_feat['fc2'],
-        ))
-    elif isinstance(module, LlamaDecoderLayer):
-        # attention input
-        scales_list.append(_auto_get_scale(
-            prev_op=module.input_layernorm,
-            layers=[module.self_attn.q_proj,
-                    module.self_attn.k_proj, module.self_attn.v_proj],
-            inp=input_feat['self_attn.q_proj'],
-            module2inspect=module.self_attn, kwargs=module_kwargs,
-        ))
-        # attn out
-        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
-        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
-            scales_list.append(_auto_get_scale(
-                prev_op=module.self_attn.v_proj,
-                layers=[module.self_attn.o_proj],
-                inp=input_feat['self_attn.o_proj'],
-            ))
-        # fc1
-        scales_list.append(_auto_get_scale(
-            prev_op=module.post_attention_layernorm,
-            layers=[module.mlp.gate_proj, module.mlp.up_proj],
-            inp=input_feat['mlp.gate_proj'],
-            module2inspect=module.mlp,
-        ))
-        # fc2
-        scales_list.append(_auto_get_scale(
-            prev_op=module.mlp.up_proj,
-            layers=[module.mlp.down_proj],
-            inp=input_feat['mlp.down_proj'],
-        ))
-    elif isinstance(module, BloomBlock):
-        # attention input
-        scales_list.append(_auto_get_scale(
-            prev_op=module.input_layernorm,
-            layers=[module.self_attention.query_key_value],
-            inp=input_feat['self_attention.query_key_value'],
-            module2inspect=module, kwargs=module_kwargs,
-        ))
-        # attn out
-        # Please refer to https://github.com/mit-han-lab/llm-awq/issues/2#issuecomment-1606297469
-        """
-        scales_list.append(_auto_get_scale(
-            prev_op=module.self_attention.query_key_value,
-            layers=[module.self_attention.dense],
-            inp=input_feat['self_attention.dense'],
-        ))
-        """
-        # fc1
-        scales_list.append(_auto_get_scale(
-            prev_op=module.post_attention_layernorm,
-            layers=[module.mlp.dense_h_to_4h],
-            inp=input_feat['mlp.dense_h_to_4h'],
-            module2inspect=module, kwargs=module_kwargs,
-        ))
-        # fc2
-        scales_list.append(_auto_get_scale(
-            prev_op=module.mlp.gelu_impl,
-            layers=[module.mlp.dense_4h_to_h],
-            inp=input_feat['mlp.dense_4h_to_h'],
-        ))
-    elif "mpt" in str(module.__class__).lower():
-        layers: list[dict] = MptAWQForCausalLM.get_layers_for_scaling(module, input_feat, module_kwargs)
-        layers_scaled = [_auto_get_scale(**layer) for layer in layers]
-        scales_list.extend(layers_scaled)
-    elif "falcon" in str(module.__class__).lower():
-        # attn out
-        # Haotian: TBD: need to handle repeated scales for MQ
-        """
-        scales_list.append(_auto_get_scale(
-            prev_op=module.self_attention.query_key_value,
-            layers=[module.self_attention.dense],
-            inp=input_feat['self_attention.dense'],
-        ))
-        """
-        # fc1, as long as it is scaled, everything is screwed up
-        if "falcon-7b" in str(module.__class__).lower():
-            scales_list.append(_auto_get_scale(
-                prev_op=module.input_layernorm,
-                layers=[module.mlp.dense_h_to_4h, module.self_attention.query_key_value],
-                inp=input_feat['self_attention.query_key_value'],
-                module2inspect=module,
-                kwargs=module_kwargs,
-            ))
-        elif "falcon-40b" in str(module.__class__).lower():
-            scales_list.append(_auto_get_scale(
-                prev_op=module.ln_attn,
-                layers=[module.self_attention.query_key_value],
-                inp=input_feat['self_attention.query_key_value'],
-                module2inspect=module,
-                kwargs=module_kwargs,
-            ))
-            scales_list.append(_auto_get_scale(
-                prev_op=module.ln_mlp,
-                layers=[module.mlp.dense_h_to_4h],
-                inp=input_feat['mlp.dense_h_to_4h'],
-                module2inspect=module,
-                kwargs=module_kwargs,
-            ))
-        else:
-            raise NotImplementedError(
-                "Unknown Falcon architecture, currently only falcon-7b and falcon-40b are supported")
-        # fc2
-        scales_list.append(_auto_get_scale(
-            prev_op=module.mlp.act,
-            layers=[module.mlp.dense_4h_to_h],
-            inp=input_feat['mlp.dense_4h_to_h'],
-        ))
-    else:
-        raise NotImplementedError(f"{type(module)} not supported yet!")
+    layers: list[dict] = awq_model.get_layers_for_scaling(
+        module, input_feat, module_kwargs
+    )
+    scales_list = [_auto_get_scale(**layer) for layer in layers]

     return scales_list
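
Note: after this change auto_scale_block() no longer branches on the module class; it asks the awq_model wrapper for a list of keyword-argument dicts and maps _auto_get_scale over them. The following self-contained sketch uses fake stand-ins (FakeAWQModel, placeholder strings instead of nn.Module references, and a dummy _auto_get_scale) purely to show the shape of that data flow; nothing in it is from the commit.

def _auto_get_scale(prev_op, layers, inp, module2inspect=None, kwargs=None):
    # stand-in for the real search; the real function returns
    # (prev_op_name, (layer_names, ...), scales) as noted in the hunk above
    return (prev_op, tuple(layers), "scales-tensor")

class FakeAWQModel:
    def get_layers_for_scaling(self, module, input_feat, module_kwargs):
        # each dict carries the kwargs of one _auto_get_scale() call
        return [
            dict(prev_op="norm_1", layers=["attn.Wqkv"], inp=input_feat["attn.Wqkv"],
                 module2inspect="attn", kwargs=module_kwargs),
            dict(prev_op="ffn.up_proj", layers=["ffn.down_proj"], inp=input_feat["ffn.down_proj"]),
        ]

input_feat = {"attn.Wqkv": [0.1, 0.2], "ffn.down_proj": [0.3]}
layers = FakeAWQModel().get_layers_for_scaling(module=None, input_feat=input_feat, module_kwargs={})
scales_list = [_auto_get_scale(**layer) for layer in layers]
print(scales_list)
# [('norm_1', ('attn.Wqkv',), 'scales-tensor'), ('ffn.up_proj', ('ffn.down_proj',), 'scales-tensor')]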
awq/quantize/pre_quant.py  (deleted, 100644 → 0; +0, -175)

Former contents, removed by this commit:

import torch
import torch.nn as nn
import tqdm
import gc
import functools
from collections import defaultdict

from transformers.models.bloom.modeling_bloom import BloomForCausalLM
from transformers.models.opt.modeling_opt import OPTForCausalLM
from transformers.models.llama.modeling_llama import LlamaForCausalLM

from .auto_scale import auto_scale_block, apply_scale
from .auto_clip import auto_clip_block, apply_clip
from ..models import MptAWQForCausalLM

__all__ = ["run_awq"]


def get_named_linears(module):
    return {name: m for name, m in module.named_modules() if isinstance(m, nn.Linear)}


def get_blocks(model):
    if isinstance(model, LlamaForCausalLM):
        layers = model.model.layers
    elif isinstance(model, OPTForCausalLM):
        layers = model.model.decoder.layers
    elif isinstance(model, BloomForCausalLM):
        layers = model.transformer.h
    elif "mpt" in str(model.__class__).lower():
        layers = MptAWQForCausalLM.get_model_layers(model)
    elif "falcon" in str(model.__class__).lower():
        layers = model.transformer.h
    else:
        raise NotImplementedError(type(model))
    return layers


def move_embed(model, device):
    if isinstance(model, LlamaForCausalLM):
        model.model.embed_tokens = model.model.embed_tokens.to(device)
    elif isinstance(model, OPTForCausalLM):
        model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(device)
        model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(device)
    elif isinstance(model, BloomForCausalLM):
        model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
        model.transformer.word_embeddings_layernorm = model.transformer.word_embeddings_layernorm.to(device)
    elif "mpt" in str(model.__class__).lower():
        MptAWQForCausalLM.move_embed(model, device)
    elif "falcon" in str(model.__class__).lower():
        model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
    else:
        raise NotImplementedError(type(model))


@torch.no_grad()
def run_awq(
        model, enc,
        w_bit, q_config,
        n_samples=512, seqlen=512,
        auto_scale=True, mse_range=True,
        # some configs for ablation study
        calib_data="pileval",
):
    from ..utils.calib_data import get_calib_dataset
    from ..utils.module import append_str_prefix, get_op_name

    layers = get_blocks(model)

    samples = get_calib_dataset(
        data=calib_data, tokenizer=enc, n_samples=n_samples, block_size=seqlen)
    samples = torch.cat(samples, dim=0)

    inps = []
    layer_kwargs = {}

    layers[0] = layers[0].cuda()
    move_embed(model, "cuda")

    # get input and kwargs to layer 0
    # with_kwargs is only supported in PyTorch 2.0
    # use this Catcher hack for now
    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps.append(inp)
            layer_kwargs.update(kwargs)
            raise ValueError  # early exit to break later inference

    # patch layer 0 to catch input and kwargs
    layers[0] = Catcher(layers[0])
    try:
        model(samples.to(next(model.parameters()).device))
    except ValueError:  # work with early exit
        pass
    del samples
    layers[0] = layers[0].module  # restore
    inps = inps[0]

    layers[0] = layers[0].cpu()
    move_embed(model, "cpu")

    gc.collect()
    torch.cuda.empty_cache()

    awq_results = {
        "scale": [],
        "clip": [],
    }

    # solve layer by layer
    for i in tqdm.tqdm(range(len(layers)), desc="Running AWQ..."):
        layer = layers[i]
        layer = layer.cuda()
        named_linears = get_named_linears(layer)

        # firstly, get input features of all linear layers
        def cache_input_hook(m, x, y, name, feat_dict):
            x = x[0]
            x = x.detach().cpu()
            feat_dict[name].append(x)

        input_feat = defaultdict(list)
        handles = []
        for name in named_linears:
            handles.append(named_linears[name].register_forward_hook(
                functools.partial(cache_input_hook, name=name, feat_dict=input_feat)))
        inps = inps.to(next(layer.parameters()).device)  # in case multi-gpu
        # get output as next layer's input
        inps = layer(inps, **layer_kwargs)[0]
        for h in handles:
            h.remove()
        # now solve for scaling and clipping
        input_feat = {k: torch.cat(v, dim=0) for k, v in input_feat.items()}

        # Clear GPU memory
        torch.cuda.empty_cache()

        if auto_scale:  # if it applies, we should also modify the input_feat with scales
            scales_list = auto_scale_block(
                layer, layer_kwargs,
                w_bit=w_bit, q_config=q_config,
                input_feat=input_feat,
            )
            # apply_scale(layer, scales_list, input_feat_dict=input_feat)
            apply_scale(layers[i], scales_list, input_feat_dict=input_feat)
            # append prefix to make names global
            awq_results["scale"] += append_str_prefix(scales_list, get_op_name(model, layer) + ".")

        # Clear GPU memory
        torch.cuda.empty_cache()

        if mse_range:
            clip_list = auto_clip_block(layer,
                                        w_bit=w_bit, q_config=q_config,
                                        input_feat=input_feat,)
            apply_clip(layer, clip_list)
            # append prefix to make names global
            awq_results["clip"] += append_str_prefix(clip_list, get_op_name(model, layer) + ".")

        layer = layer.cpu()
        # Haotian: check activation replacement
        del input_feat
        gc.collect()
        torch.cuda.empty_cache()

    return awq_results


def apply_awq(model, awq_results):
    apply_scale(model, awq_results["scale"])
    apply_clip(model, awq_results["clip"])
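
Note: for callers that imported this module, the migration path implied by the rest of the commit is sketched below; it assumes an MPT model and its tokenizer enc are already loaded, and the argument values are illustrative.

# before this commit
from awq.quantize.pre_quant import run_awq, apply_awq

awq_results = run_awq(model, enc, w_bit=4, q_config=q_config,
                      n_samples=128, seqlen=512)
apply_awq(model, awq_results)

# after this commit: the search lives on the model class, and apply_awq's two
# steps are called directly (as awq/entry.py now does)
from awq.models import MptAWQForCausalLM
from awq.quantize.auto_scale import apply_scale
from awq.quantize.auto_clip import apply_clip

awq_results = MptAWQForCausalLM().quantize(model, enc, 4, q_config)
apply_scale(model, awq_results["scale"])
apply_clip(model, awq_results["clip"])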
awq/quantize/quantizer.py  (+1, -2)

@@ -4,7 +4,6 @@ from tqdm import tqdm
 import gc
 from .qmodule import ScaledActivation
 from ..utils.module import set_op_by_name
-from ..models import MptAWQForCausalLM
 from transformers.models.bloom.modeling_bloom import BloomBlock

@@ -30,7 +29,7 @@ def scale_activations(module):
         return

     # get activation scale
-    scale_dict = MptAWQForCausalLM.get_act_for_scaling(module)
+    scale_dict = MptAWQForCausalLM().get_act_for_scaling(module)
     scale_like = torch.ones(scale_dict['scale_shape'], dtype=dtype, device=device)

     # scale activation
awq/utils/module.py  (+3, -0)

 import torch.nn as nn

+def get_named_linears(module):
+    return {name: m for name, m in module.named_modules() if isinstance(m, nn.Linear)}
+
 def get_op_by_name(module, op_name):
     # get the op by its name relative to the module
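
Note: get_named_linears() previously lived in pre_quant.py and is now shared through awq.utils.module (base.py imports it from here). A quick illustration on a throwaway block:

import torch.nn as nn
from awq.utils.module import get_named_linears

block = nn.Sequential(nn.Linear(8, 32), nn.GELU(), nn.Linear(32, 8))
print(get_named_linears(block))
# {'0': Linear(in_features=8, out_features=32, bias=True),
#  '2': Linear(in_features=32, out_features=8, bias=True)}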