OpenDAS / AutoAWQ / Commits

Commit ba01560f, authored Jul 11, 2023 by Abhinav Kulkarni
Parent: d2a10bd9

Memory optimization
Showing 6 changed files with 105 additions and 16 deletions (+105 -16)
- awq/entry.py (+26 -7)
- awq/quantize/auto_scale.py (+30 -9)
- awq/quantize/pre_quant.py (+1 -0)
- awq/quantize/qmodule.py (+4 -0)
- awq/quantize/quantizer.py (+1 -0)
- awq/utils/utils.py (+43 -0)
awq/entry.py (+26 -7)

```diff
@@ -4,11 +4,12 @@ import torch
 import argparse
 import os
 import json
-from accelerate import init_empty_weights, infer_auto_device_map, dispatch_model, load_checkpoint_and_dispatch
+from accelerate import init_empty_weights, infer_auto_device_map, dispatch_model, load_checkpoint_in_model
 from awq.utils.parallel import auto_parallel
 from awq.quantize.pre_quant import run_awq, apply_awq
 from awq.quantize.quantizer import pseudo_quantize_model_weight, real_quantize_model_weight
 from awq.utils.lm_eval_adaptor import LMEvalAdaptor
+from awq.utils.utils import simple_dispatch_model
 
 parser = argparse.ArgumentParser()
@@ -80,16 +81,32 @@ def build_model_and_enc(model_path):
     if args.load_quant:  # directly load quantized weights
         print("Loading pre-computed quantized weights...")
         with init_empty_weights():
-            model = AutoModelForCausalLM.from_pretrained(model_path, config=config,
-                                                         torch_dtype=torch.float16, trust_remote_code=True)
+            model = AutoModelForCausalLM.from_config(config=config,
+                                                     torch_dtype=torch.float16, trust_remote_code=True)
         real_quantize_model_weight(
             model, w_bit=args.w_bit, q_config=q_config, init_only=True)
-        model = load_checkpoint_and_dispatch(
-            model, args.load_quant, device_map="balanced",
-            no_split_module_classes=[
-                "OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"]
-        )
+
+        model.tie_weights()  # TODO: can we remove this?
+
+        # Infer device map
+        kwargs = {"max_memory": max_memory} if len(max_memory) else {}
+        device_map = infer_auto_device_map(
+            model,
+            no_split_module_classes=[
+                "OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"],
+            **kwargs
+        )
+        # Load checkpoint in the model
+        load_checkpoint_in_model(
+            model,
+            checkpoint=args.load_quant,
+            device_map=device_map,
+            offload_state_dict=True,
+        )
+        # Dispatch model
+        model = simple_dispatch_model(model, device_map=device_map)
+
         model.eval()
     else:  # fp16 to quantized
         args.run_awq &= not args.load_awq  # if load_awq, no need to run awq
         # Init model on CPU:
@@ -97,6 +114,8 @@ def build_model_and_enc(model_path):
         model = AutoModelForCausalLM.from_pretrained(
             model_path, config=config, trust_remote_code=True, **kwargs)
 
+        model.eval()
+
         if args.run_awq:
             assert args.dump_awq, "Please save the awq results with --dump_awq"
```
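For context, this is the flow the quantized-weights path switches to: instead of `load_checkpoint_and_dispatch` with a "balanced" map, the model is first built on the meta device, a device map is planned, the checkpoint is streamed onto that map, and only then are dispatch hooks attached. Below is a minimal sketch of that general accelerate pattern; the model name, memory caps, and checkpoint path are placeholders, not values from this commit.

```python
# Sketch of the meta-init -> plan -> load -> dispatch flow (placeholders marked).
import torch
from accelerate import init_empty_weights, infer_auto_device_map, load_checkpoint_in_model
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("facebook/opt-1.3b")          # placeholder model

# 1) Build the module tree on the meta device: no real weight memory is allocated.
with init_empty_weights():
    model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch.float16)
model.tie_weights()

# 2) Plan where each decoder block goes, optionally capping per-device memory.
device_map = infer_auto_device_map(
    model,
    no_split_module_classes=["OPTDecoderLayer"],
    max_memory={0: "8GiB", "cpu": "32GiB"},                        # placeholder limits
)

# 3) Stream the checkpoint tensors straight onto the planned devices.
load_checkpoint_in_model(
    model,
    checkpoint="/path/to/quantized_checkpoint.pt",                 # placeholder path
    device_map=device_map,
    offload_state_dict=True,
)

# 4) Attach hooks so each block runs on its assigned device
#    (the commit uses its own simple_dispatch_model for this step).
```

Switching from `from_pretrained` to `from_config` also means the original fp16 checkpoint is never read when only the pre-quantized weights are going to be loaded, which is where most of the memory saving appears to come from.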
awq/quantize/auto_scale.py (+30 -9)

```diff
+import gc
 import torch
 import torch.nn as nn
@@ -112,6 +113,7 @@ def auto_scale_block(module, module_kwargs,
             weight, q_group_size=q_config.get("q_group_size", -1))
         # Clear GPU memory
         del weight
+        gc.collect()
         torch.cuda.empty_cache()
 
         x = x.to(next(block.parameters()).device)
@@ -129,8 +131,6 @@ def auto_scale_block(module, module_kwargs,
         n_grid = 20
         history = []
-        # Clear GPU memory
-        torch.cuda.empty_cache()
 
         org_sd = {k: v.cpu() for k, v in block.state_dict().items()}
         for ratio in range(n_grid):
             ratio = ratio * 1 / n_grid
@@ -169,6 +169,7 @@ def auto_scale_block(module, module_kwargs,
             module2inspect = layers[0]
 
         scales = _search_module_scale(module2inspect, layers, inp, kwargs)
+        scales = scales.detach().cpu()
         # prev_op_name, [layer_name], scale
         return (get_op_name(module, prev_op), tuple([get_op_name(module, m) for m in layers]), scales)
@@ -302,13 +303,31 @@ def auto_scale_block(module, module_kwargs,
         ))
         """
         # fc1, as long as it is scaled, everything is screwed up
-        scales_list.append(_auto_get_scale(
-            prev_op=module.input_layernorm,
-            layers=[module.mlp.dense_h_to_4h, module.self_attention.query_key_value],
-            inp=input_feat['self_attention.query_key_value'],
-            module2inspect=module,
-            kwargs=module_kwargs,
-        ))
+        if "falcon-7b" in str(module.__class__).lower():
+            scales_list.append(_auto_get_scale(
+                prev_op=module.input_layernorm,
+                layers=[module.mlp.dense_h_to_4h, module.self_attention.query_key_value],
+                inp=input_feat['self_attention.query_key_value'],
+                module2inspect=module,
+                kwargs=module_kwargs,
+            ))
+        elif "falcon-40b" in str(module.__class__).lower():
+            scales_list.append(_auto_get_scale(
+                prev_op=module.ln_attn,
+                layers=[module.self_attention.query_key_value],
+                inp=input_feat['self_attention.query_key_value'],
+                module2inspect=module,
+                kwargs=module_kwargs,
+            ))
+            scales_list.append(_auto_get_scale(
+                prev_op=module.ln_mlp,
+                layers=[module.mlp.dense_h_to_4h],
+                inp=input_feat['mlp.dense_h_to_4h'],
+                module2inspect=module,
+                kwargs=module_kwargs,
+            ))
+        else:
+            raise NotImplementedError("Unknown Falcon architecture, currently only falcon-7b and falcon-40b are supported")
         # fc2
         scales_list.append(_auto_get_scale(
             prev_op=module.mlp.act,
@@ -329,6 +348,7 @@ def apply_scale(module, scales_list, input_feat_dict=None):
         prev_op.cuda()
         for layer in layers:
             layer.cuda()
+        scales.cuda()
 
         if isinstance(prev_op, nn.Linear):
             assert len(layers) == 1
@@ -352,3 +372,4 @@ def apply_scale(module, scales_list, input_feat_dict=None):
         prev_op.cpu()
         for layer in layers:
             layer.cpu()
+        scales.cpu()
```
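The memory-related hunks here all follow the same pattern: drop the last Python reference, run `gc.collect()`, then ask the CUDA caching allocator to release unused blocks with `torch.cuda.empty_cache()`, and keep intermediate results (the searched `scales`) on CPU until they are actually applied. A small, self-contained illustration of that pattern, not part of the commit:

```python
# Standalone illustration of the release pattern used above; safe to run anywhere.
import gc
import torch

if torch.cuda.is_available():
    weight = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
    # ... the tensor is used here ...
    del weight                   # drop the last Python reference
    gc.collect()                 # collect anything still reachable only via cycles
    torch.cuda.empty_cache()     # return cached, now-unused blocks to the driver
    print(torch.cuda.memory_allocated(), torch.cuda.memory_reserved())

    # Keeping results on CPU mirrors `scales.detach().cpu()` in the diff above:
    scales = torch.rand(4096, device="cuda").detach().cpu()
```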
awq/quantize/pre_quant.py (+1 -0)

```diff
@@ -95,6 +95,7 @@ def run_awq(
         model(samples.to(next(model.parameters()).device))
     except ValueError:  # work with early exit
         pass
+    del samples
     layers[0] = layers[0].module  # restore
     inps = inps[0]
```
awq/quantize/qmodule.py (+4 -0)

```diff
@@ -93,3 +93,7 @@ class WQLinear(nn.Module):
         out = out + self.bias if self.bias is not None else out
         return out.reshape(out_shape)
+
+    def extra_repr(self) -> str:
+        return 'in_features={}, out_features={}, bias={}, w_bit={}, group_size={}'.format(
+            self.in_features, self.out_features, self.bias is not None, self.w_bit, self.group_size)
```
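`extra_repr` is the standard `nn.Module` hook that `print(model)` and `repr()` use for a module's one-line summary, so this addition makes quantized layers report their bit-width and group size when the model is printed. A toy sketch of the mechanism (a hypothetical stand-in, not the real `WQLinear`):

```python
import torch.nn as nn

class ToyQuantLinear(nn.Module):
    """Hypothetical stand-in used only to show how extra_repr surfaces in print()."""
    def __init__(self, in_features, out_features, w_bit=4, group_size=128):
        super().__init__()
        self.in_features, self.out_features = in_features, out_features
        self.w_bit, self.group_size = w_bit, group_size

    def extra_repr(self) -> str:
        return 'in_features={}, out_features={}, w_bit={}, group_size={}'.format(
            self.in_features, self.out_features, self.w_bit, self.group_size)

print(ToyQuantLinear(4096, 11008))
# -> ToyQuantLinear(in_features=4096, out_features=11008, w_bit=4, group_size=128)
```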
awq/quantize/quantizer.py (+1 -0)

```diff
@@ -130,6 +130,7 @@ def real_quantize_model_weight(
                 q_linear = WQLinear.from_linear(
                     module, w_bit, q_config['q_group_size'], False, scales, zeros)
                 module.cpu()
+                q_linear.to(next(layer.parameters()).device)
                 set_op_by_name(layer, name, q_linear)
                 torch.cuda.empty_cache()
                 gc.collect()
```
awq/utils/utils.py (new file, 0 → 100644, +43 -0)

```python
import torch
import accelerate


def get_module_by_name_suffix(model, module_name: str):
    for name, module in model.named_modules():
        if name.endswith(module_name):
            return module


def simple_dispatch_model(model, device_map):
    from accelerate.hooks import add_hook_to_module, AlignDevicesHook

    if "" in device_map:
        d = device_map[""]
        model = model.to(torch.device(d))
        model.hf_device_map = device_map
        return model

    tied_params = accelerate.utils.modeling.find_tied_parameters(model)
    if set(device_map.values()) == {"cpu"} or set(device_map.values()) == {"cpu", "disk"}:
        main_device = "cpu"
    else:
        main_device = [d for d in device_map.values() if d not in ["cpu", "disk"]][0]

    cpu_offload_group = [(n, d) for n, d in device_map.items() if d == "cpu"]
    prev_hook = None
    for idx, (n, d) in enumerate(cpu_offload_group):
        m = get_module_by_name_suffix(model, n)
        _, prev_hook = accelerate.cpu_offload_with_hook(m, execution_device=main_device, prev_module_hook=prev_hook)
    # set first cpu offload module's prev_module_hook to the last cpu offload module's hook
    if len(cpu_offload_group) > 1:
        get_module_by_name_suffix(model, cpu_offload_group[0][0])._hf_hook.prev_module_hook = prev_hook

    for n, d in device_map.items():
        m = get_module_by_name_suffix(model, n)
        if d != "cpu":
            d = torch.device(d)
            hook = AlignDevicesHook(d, io_same_device=True, place_submodules=True)
            add_hook_to_module(m, hook)
    accelerate.utils.modeling.retie_parameters(model, tied_params)
    model.hf_device_map = device_map

    return model
```
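The main difference from accelerate's stock `dispatch_model` is how CPU-mapped modules are handled: they are chained with `cpu_offload_with_hook`, so each block is copied to the main device just before its forward pass while the previously executed block is offloaded, and the first block's `prev_module_hook` is pointed at the last hook so the chain also resets across forward passes. A minimal sketch of that chaining with toy modules (not this repository's code; requires a GPU):

```python
# Toy demonstration of accelerate.cpu_offload_with_hook chaining.
import torch
import torch.nn as nn
import accelerate

if torch.cuda.is_available():
    blocks = [nn.Linear(1024, 1024) for _ in range(3)]   # stand-ins for decoder layers
    prev_hook = None
    for i, blk in enumerate(blocks):
        # Each hook moves its module to cuda:0 before forward and offloads the
        # module guarded by prev_module_hook back to CPU.
        blocks[i], prev_hook = accelerate.cpu_offload_with_hook(
            blk, execution_device="cuda:0", prev_module_hook=prev_hook)

    x = torch.randn(1, 1024)
    for blk in blocks:
        x = blk(x)   # roughly one block at a time is resident on the GPU
    print(x.shape, x.device)
```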