OpenDAS / AutoAWQ · Commits

Commit 34f1faff, authored Jun 27, 2023 by Jiaming Tang
Parent: a293e16f

[Major] Add CPU offloading support for run_awq
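In short: rather than loading the full fp16 model onto GPUs before the AWQ search, the entry script now initializes the model on CPU, and run_awq streams decoder blocks through the GPU one at a time, so peak GPU memory stays close to a single block plus its activations.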
Showing 2 changed files with 52 additions and 10 deletions:

  awq/entry.py               +28 -10
  awq/quantize/pre_quant.py  +24 -0
awq/entry.py
@@ -83,17 +83,22 @@ def build_model_and_enc(model_path):
                 "OPTDecoderLayer", "LlamaDecoderLayer", "BloomBlock", "MPTBlock", "DecoderLayer"]
         )
     else:  # fp16 to quantized
-        kwargs = {"device_map": "balanced", "torch_dtype": torch.float16}
-        model = AutoModelForCausalLM.from_pretrained(
-            model_path, config=config, trust_remote_code=True, **kwargs)
-
-        if args.load_awq:
-            print("Loading pre-computed AWQ results from", args.load_awq)
-            awq_results = torch.load(args.load_awq, map_location="cpu")
-            apply_awq(model, awq_results)
-        elif args.run_awq:
+        args.run_awq &= not args.load_awq  # if load_awq, no need to run awq
+
+        if args.run_awq:
+            assert args.dump_awq, "Please save the awq results with --dump_awq"
+            # Init model on CPU:
+            def skip(*args, **kwargs):
+                pass
+            torch.nn.init.kaiming_normal_ = skip
+            torch.nn.init.kaiming_uniform_ = skip
+            torch.nn.init.uniform_ = skip
+            torch.nn.init.normal_ = skip
+            model = AutoModelForCausalLM.from_pretrained(
+                model_path, config=config, trust_remote_code=True, torch_dtype=torch.float16)
+
             awq_results = run_awq(
                 model, enc,
                 w_bit=args.w_bit, q_config=q_config,
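Note on the skip-init trick above: from_pretrained overwrites every parameter with checkpoint weights, so the random initialization that module constructors run on CPU is pure wasted time. Stubbing out the torch.nn.init functions makes the CPU load much faster. A minimal standalone sketch of the same idea (illustrative, not part of the diff):

    import torch
    import torch.nn as nn

    def skip(*args, **kwargs):
        pass  # leave freshly allocated tensors as-is; real weights arrive later

    # Monkey-patch the initializers that module constructors call:
    torch.nn.init.kaiming_normal_ = skip
    torch.nn.init.kaiming_uniform_ = skip
    torch.nn.init.uniform_ = skip
    torch.nn.init.normal_ = skip

    # Modules created after the patch skip their expensive default init:
    layer = nn.Linear(4096, 4096)  # allocated, but not randomly initialized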
@@ -102,6 +107,19 @@ def build_model_and_enc(model_path):
             if args.dump_awq:
                 torch.save(awq_results, args.dump_awq)
                 print("AWQ results saved at", args.dump_awq)
+            exit(0)
+        else:
+            # Inference with fake quant
+            # Init model on GPUs:
+            kwargs = {"device_map": "balanced", "torch_dtype": torch.float16}
+            model = AutoModelForCausalLM.from_pretrained(
+                model_path, config=config, trust_remote_code=True, **kwargs)
+
+            if args.load_awq:
+                print("Loading pre-computed AWQ results from", args.load_awq)
+                awq_results = torch.load(args.load_awq, map_location="cpu")
+                apply_awq(model, awq_results)
+
         # weight quantization
         if args.w_bit is not None:
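Taken together, the two paths give a workflow along these lines. The --run_awq, --dump_awq, --load_awq, and --w_bit flags appear in this diff; the module invocation and a --model_path flag are assumptions based on the entry script's name and the model_path argument seen above, and the output path is a placeholder:

    # 1) Search on a CPU-initialized model, streaming layers to GPU, then save:
    python -m awq.entry --model_path /PATH/TO/MODEL --w_bit 4 \
        --run_awq --dump_awq awq_results.pt

    # 2) Later runs: init the model on GPUs and reuse the saved search results:
    python -m awq.entry --model_path /PATH/TO/MODEL --w_bit 4 \
        --load_awq awq_results.pt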
awq/quantize/pre_quant.py
@@ -34,6 +34,22 @@ def get_blocks(model):
         raise NotImplementedError(type(model))
     return layers


+def move_embed(model, device):
+    if isinstance(model, LlamaForCausalLM):
+        model.model.embed_tokens = model.model.embed_tokens.to(device)
+    elif isinstance(model, OPTForCausalLM):
+        model.model.decoder.embed_tokens = model.model.decoder.embed_tokens.to(device)
+        model.model.decoder.embed_positions = model.model.decoder.embed_positions.to(device)
+    elif isinstance(model, BloomForCausalLM):
+        model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
+        model.transformer.word_embeddings_layernorm = model.transformer.word_embeddings_layernorm.to(device)
+    elif "mpt" in str(model.__class__).lower():
+        model.transformer.wte = model.transformer.wte.to(device)
+        model.transformer.emb_drop = model.transformer.emb_drop.to(device)
+    elif "falcon" in str(model.__class__).lower():
+        model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
+    else:
+        raise NotImplementedError(type(model))
+

 @torch.no_grad()
 def run_awq(
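move_embed exists because run_awq (below) moves only layers[0] to the GPU while catching its inputs; the embedding modules that feed layer 0 must sit on the same device, or the very first lookup/matmul fails with a cross-device error. In miniature (illustrative, assumes a CUDA device):

    import torch
    import torch.nn as nn

    emb = nn.Embedding(10, 8)        # on CPU, like an offloaded model
    block = nn.Linear(8, 8).cuda()   # like layers[0] moved to GPU
    ids = torch.tensor([[1, 2, 3]])
    # block(emb(ids)) would raise a device-mismatch RuntimeError here.
    emb.cuda()                       # what move_embed does for real models
    out = block(emb(ids.cuda()))     # ok: inputs and weights now colocated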
@@ -57,6 +73,9 @@ def run_awq(
     inps = []
     layer_kwargs = {}

+    layers[0] = layers[0].cuda()
+    move_embed(model, "cuda")
+
     # get input and kwargs to layer 0
     # with_kwargs is only supported in PyTorch 2.0
     # use this Catcher hack for now
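The "Catcher hack" named in the context comments is defined elsewhere in this file; reconstructed as a sketch (names mirror the surrounding code, not verbatim from this diff), the pattern is a wrapper whose forward records whatever reaches layer 0 and then aborts, since only the inputs are wanted:

    import torch.nn as nn

    inps, layer_kwargs = [], {}  # as initialized above in run_awq

    class Catcher(nn.Module):
        def __init__(self, module):
            super().__init__()
            self.module = module

        def forward(self, inp, **kwargs):
            inps.append(inp)             # capture layer 0's hidden states
            layer_kwargs.update(kwargs)  # capture attention mask, etc.
            raise ValueError             # abort the forward pass early

    # Wrap layers[0], run one calibration batch, then unwrap; the next
    # hunk's layers[0] = layers[0].module is the "restore" step.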
@@ -79,6 +98,9 @@ def run_awq(
     layers[0] = layers[0].module  # restore
     inps = inps[0]

+    layers[0] = layers[0].cpu()
+    move_embed(model, "cpu")
+
     gc.collect()
     torch.cuda.empty_cache()
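With layer 0 and the embeddings back on CPU, the GPU should be essentially empty again before the per-layer loop starts; gc.collect() drops the last Python references and torch.cuda.empty_cache() releases the cached blocks back to the device. A quick way to verify the offload actually freed memory (illustrative):

    import gc, torch

    gc.collect()
    torch.cuda.empty_cache()
    print(f"{torch.cuda.memory_allocated() / 2**20:.1f} MiB still allocated")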
@@ -90,6 +112,7 @@ def run_awq(
     # solve layer by layer
     for i in tqdm.tqdm(range(len(layers)), desc="Running AWQ..."):
         layer = layers[i]
+        layer = layer.cuda()
         named_linears = get_named_linears(layer)

         # firstly, get input features of all linear layers
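Together with the layer = layer.cpu() added in the next hunk, this turns the solve loop into a simple streaming pattern; distilled (illustrative only, with process() standing in for the real per-layer scale/clip search):

    import gc
    import torch

    def stream_layers(layers, process):
        for i in range(len(layers)):
            layer = layers[i].cuda()  # one decoder block on GPU at a time
            process(layer)
            layer.cpu()               # off the GPU before the next block
            gc.collect()
            torch.cuda.empty_cache()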
@@ -131,6 +154,7 @@ def run_awq(
         # append prefix to make names global
         awq_results["clip"] += append_str_prefix(clip_list, get_op_name(model, layer) + ".")

+        layer = layer.cpu()
         # Haotian: check activation replacement
         del input_feat
         gc.collect()