OpenDAS / AutoAWQ

Commit d430694b, authored Aug 18, 2023 by Casper Hansen
Parent: af4e0622

    save_quantized working in all cases, from_quantized adapted.
Showing 3 changed files with 49 additions and 26 deletions:

    awq/entry.py         +13  -9
    awq/models/auto.py    +3  -3
    awq/models/base.py   +33 -14
awq/entry.py

@@ -29,6 +29,9 @@ def run_search(model_path, dump_path, w_bit, q_config):
     # Save search results
     model.save_quantized(dump_path)
 
+    # Save tokenizer
+    tokenizer.save_pretrained(dump_path)
+
 def run_quant(model_path, search_path, dump_path, w_bit, q_config):
     """
     Step 2/2: Use the search results to quantize model weights
@@ -43,16 +46,16 @@ def run_quant(model_path, search_path, dump_path, w_bit, q_config):
     # Save quantized model
     model.save_quantized(dump_path)
 
-def run_perplexity(model_path, quant_path, w_bit, q_config, device):
+def run_perplexity(quant_path, quant_file, w_bit, q_config, device):
     """
     Post quantization: Evaluate perplexity on wikitext with EleutherAI Evaluation Harness
     """
     # Load model
-    model = AutoAWQForCausalLM.from_quantized(model_path, quant_path, w_bit, q_config, device)
-    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file, w_bit, q_config)
+    tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)
 
     # Load adapter
-    lm_eval_model = LMEvalAdaptor(model_path, model, tokenizer, device, batch_size=1)
+    lm_eval_model = LMEvalAdaptor(quant_path, model, tokenizer, device, batch_size=1)
 
     # Evaluate perplexity of quantized model
     results = evaluator.simple_evaluate(
@@ -68,15 +71,16 @@ def run_perplexity(model_path, quant_path, w_bit, q_config, device):
 if __name__ == '__main__':
     """
     python -m awq.entry --entry_type search --model_path mosaicml/mpt-7b-8k-chat --search_path mpt-7b-8k-chat-awq
-    python -m awq.entry --entry_type quant --model_path mosaicml/mpt-7b-8k-chat --search_path mpt-7b-8k-chat-awq/pytorch_model.bin --quant_path mpt-7b-8k-chat-awq
-    python -m awq.entry --entry_type perplexity --model_path mosaicml/mpt-7b-8k-chat --quant_path mpt-7b-8k-chat-awq
+    python -m awq.entry --entry_type quant --model_path mosaicml/mpt-7b-8k-chat --search_path mpt-7b-8k-chat-awq/awq_model_search_result.pt --quant_path mpt-7b-8k-chat-awq
+    python -m awq.entry --entry_type perplexity --quant_path mpt-7b-8k-chat-awq --quant_file awq_model_w4_g128.pt
     """
     parser = argparse.ArgumentParser()
     parser.add_argument('--entry_type', type=str, help='The type of task to run (search|quant|perplexity)')
     parser.add_argument('--model_path', type=str, help='Path to hf model')
     parser.add_argument('--search_path', type=str, help='Path to save/load AWQ search results')
-    parser.add_argument('--quant_path', type=str, help='Path to save/load AWQ quant model')
-    parser.add_argument('--device', type=str, default='balanced', help='Device to load model to')
+    parser.add_argument('--quant_path', type=str, help='Path to AWQ model directory')
+    parser.add_argument('--quant_file', type=str, help='Path to quantized AWQ model file')
+    parser.add_argument('--device', type=str, default='cuda:0', help='Device to load model to')
     parser.add_argument('--w_bit', type=int, default=4)
     parser.add_argument('--q_group_size', type=int, default=128)
     args = parser.parse_args()

@@ -88,6 +92,6 @@ if __name__ == '__main__':
     elif args.entry_type == 'quant':
         run_quant(args.model_path, args.search_path, args.quant_path, args.w_bit, q_config)
     elif args.entry_type == 'perplexity':
-        run_perplexity(args.model_path, args.quant_path, args.w_bit, q_config, args.device)
+        run_perplexity(args.quant_path, args.quant_file, args.w_bit, q_config, args.device)
     else:
         raise Exception('--entry_type must be one of (search|quant|perplexity)')
\ No newline at end of file
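With this change the perplexity entry point no longer needs the original HF model path: the model, tokenizer, and evaluation adaptor all read from the quantized directory. A minimal sketch of the new flow (the import paths and the LMEvalAdaptor module location are assumptions based on this repo's layout; quant_path/quant_file values are the defaults written by save_quantized):

    from transformers import AutoTokenizer

    from awq.models.auto import AutoAWQForCausalLM       # import path assumed from this repo
    from awq.utils.lm_eval_adaptor import LMEvalAdaptor  # hypothetical module path

    quant_path = 'mpt-7b-8k-chat-awq'    # directory produced by save_quantized
    quant_file = 'awq_model_w4_g128.pt'  # default checkpoint name written by save_quantized

    # Model and tokenizer both come from the quantized directory now
    model = AutoAWQForCausalLM.from_quantized(quant_path, quant_file, w_bit=4, device='cuda:0')
    tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)

    # Wrap for the EleutherAI evaluation harness, mirroring run_perplexity
    lm_eval_model = LMEvalAdaptor(quant_path, model, tokenizer, 'cuda:0', batch_size=1)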
awq/models/auto.py

@@ -29,11 +29,11 @@ class AutoAWQForCausalLM:
         )
 
     @classmethod
-    def from_quantized(self, model_path, quant_file, w_bit=4, q_config={},
+    def from_quantized(self, quant_path, quant_filename, w_bit=4, q_config={},
                        device='balanced', trust_remote_code=True) -> BaseAWQForCausalLM:
-        model_type = check_and_get_model_type(model_path, trust_remote_code)
+        model_type = check_and_get_model_type(quant_path, trust_remote_code)
         q_config = q_config if q_config else self.default_q_config
 
         return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
-            model_path, model_type, quant_file, w_bit, q_config, device, trust_remote_code=trust_remote_code
+            quant_path, model_type, quant_filename, w_bit, q_config, device, trust_remote_code=trust_remote_code
         )
\ No newline at end of file
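For downstream callers, the renamed parameters mean the first two arguments now describe the quantized artifact: a directory plus a checkpoint file name inside it, rather than the original HF model path. An illustrative call with the new shape (values are examples, not part of the commit):

    model = AutoAWQForCausalLM.from_quantized(
        'mpt-7b-8k-chat-awq',      # quant_path: directory produced by save_quantized
        'awq_model_w4_g128.pt',    # quant_filename: checkpoint file inside that directory
        w_bit=4,                   # q_config omitted, so default_q_config is used
        device='cuda:0',
    )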
awq/models/base.py

@@ -22,6 +22,12 @@ class BaseAWQForCausalLM:
         self.model_type: str = model_type
         self.is_quantized: bool = is_quantized
         self.search_result = None
 
+    def to(self, device: str):
+        return self.model.to(device)
+
+    def forward(self, *args, **kwargs):
+        return self.model(*args, **kwargs)
+
     @torch.no_grad()
     def quantize(self, tokenizer=None, w_bit=4, q_config={}, n_samples=128, seqlen=512,
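The two new pass-through methods let the AWQ wrapper be moved and invoked like the underlying HF model. A hypothetical usage snippet (assumes model is a loaded BaseAWQForCausalLM and input_ids is an already tokenized batch):

    model.to('cuda:0')                  # delegates to self.model.to(device)
    outputs = model.forward(input_ids)  # delegates to self.model(*args, **kwargs)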
@@ -170,27 +176,37 @@ class BaseAWQForCausalLM:
         return awq_results
 
     def save_quantized(self, save_dir):
+        def _save_files(save_dir, model_name, model):
+            class EmptyModule(nn.Module):
+                def __init__(self): super(EmptyModule, self).__init__()
+                def forward(self, x): return x
+
+            # Save model files without search results
+            self.model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict())
+
+            # Remove empty module
+            os.remove(f'{save_dir}/pytorch_model.bin')
+
+            # Save search results
+            torch.save(model, f'{save_dir}/{model_name}')
+
         save_dir = save_dir[:-1] if save_dir[-1] == '/' else save_dir
 
         # Save model
         if self.search_result is None:
-            self.model.save_pretrained(save_dir, state_dict=self.model.state_dict())
+            model_name = 'awq_model_w4_g128.pt'
+            _save_files(save_dir, model_name, self.model.state_dict())
         else:
-            self.model.save_pretrained(save_dir, state_dict=self.search_result)
-
-        # TODO: Rename model name & save quant_config
-        if self.search_result is not None:
             model_name = 'awq_model_search_result.pt'
-        else:
-            model_name = 'awq_model_w4_g128.pt'
+            _save_files(save_dir, model_name, self.search_result)
 
     @classmethod
     def from_pretrained(self, model_path, model_type, torch_dtype: torch.dtype = torch.float16,
                         trust_remote_code=True):
         return self.from_quantized(
             model_path,
             model_type,
-            quant_file='',
+            model_filename='',
             device='balanced',
             torch_dtype=torch_dtype,
             trust_remote_code=trust_remote_code,
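The new _save_files helper relies on a small trick: save_pretrained is called with an empty state_dict so that only the config files are written, the placeholder weight file is then deleted, and the real tensors are stored via torch.save under the AWQ file name. A standalone sketch of the same pattern using a small stand-in model (the model id, output paths, and the safe_serialization flag are assumptions; exact placeholder file names may differ across transformers versions):

    import os
    import torch
    import torch.nn as nn
    from transformers import AutoModelForCausalLM

    class EmptyModule(nn.Module):
        def forward(self, x): return x

    model = AutoModelForCausalLM.from_pretrained('gpt2')  # small stand-in model
    save_dir = 'gpt2-awq-sketch'

    # Writes config.json etc. plus a tiny pytorch_model.bin holding the empty state_dict
    model.save_pretrained(save_dir, state_dict=EmptyModule().state_dict(), safe_serialization=False)
    os.remove(f'{save_dir}/pytorch_model.bin')             # drop the placeholder weights

    # Persist the real tensors under the AWQ naming convention
    torch.save(model.state_dict(), f'{save_dir}/awq_model_w4_g128.pt')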
@@ -198,11 +214,14 @@ class BaseAWQForCausalLM:
         )
 
     @classmethod
-    def from_quantized(self, model_path, model_type, quant_file, w_bit=4, q_config={},
+    def from_quantized(self, model_path, model_type, model_filename, w_bit=4, q_config={},
                        device='balanced', torch_dtype=torch.float16, trust_remote_code=True, is_quantized=True):
-        # Download model
-        model_path = snapshot_download(model_path)
-        quant_path = model_path + f'/{quant_file}' if is_quantized else model_path
+        # Download model if path is not a directory
+        if not os.path.isdir(model_path):
+            model_path = snapshot_download(model_path)
+
+        # TODO: Better naming, model_filename becomes a directory
+        model_filename = model_path + f'/{model_filename}'
 
         # Load config
         config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
@@ -219,7 +238,7 @@ class BaseAWQForCausalLM:
         model.tie_weights()
 
         # Load model weights
-        model = load_checkpoint_and_dispatch(model, quant_path, device_map=device, no_split_module_classes=[self.layer_type])
+        model = load_checkpoint_and_dispatch(model, model_filename, device_map=device, no_split_module_classes=[self.layer_type])
 
         return self(model, model_type, is_quantized=is_quantized)
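The loader now distinguishes local directories from Hub IDs before resolving the checkpoint path that is handed to load_checkpoint_and_dispatch. A minimal sketch of just that resolution step (the function name is hypothetical; the logic mirrors the diff above):

    import os
    from huggingface_hub import snapshot_download

    def resolve_quant_checkpoint(model_path: str, model_filename: str) -> str:
        # Download from the Hub only when model_path is not already a local directory
        if not os.path.isdir(model_path):
            model_path = snapshot_download(model_path)
        # load_checkpoint_and_dispatch receives <directory>/<file>
        return model_path + f'/{model_filename}'

    print(resolve_quant_checkpoint('mpt-7b-8k-chat-awq', 'awq_model_w4_g128.pt'))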