OpenDAS / AutoAWQ · Commits · 934ad336

Commit 934ad336 · authored Aug 17, 2023 by Casper Hansen

Implement argparse and perplexity

parent 5d4ab5dc
Showing 3 changed files with 57 additions and 17 deletions (+57 −17):

- awq/entry.py (+49 −12)
- awq/models/base.py (+6 −3)
- awq/utils/lm_eval_adaptor.py (+2 −2)
awq/entry.py (view file @ 934ad336)

```diff
 import os
 import torch
 import argparse
 from lm_eval import evaluator
 from awq.quantize.auto_clip import apply_clip
 from awq.quantize.auto_scale import apply_scale
 from awq.utils.lm_eval_adaptor import LMEvalAdaptor
 from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig

 max_memory = [v.split(':') for v in (None or [])]
 max_memory = {(int(k) if k.isdigit() else k): v for k, v in max_memory}

 def get_awq_model(model):
     from awq.models import MptAWQForCausalLM
...
@@ -27,33 +27,70 @@ def load_unquantized(model_path):
     return model, tokenizer

-def load_quantized(model_path):
-    awq_model = get_awq_model(model)
+def load_search_result_into_memory(model, search_path):
+    awq_results = torch.load(search_path, map_location="cpu")
+
+    apply_scale(model, awq_results["scale"])
+    apply_clip(model, awq_results["clip"])

-def run_search(model, dump_path):
+def run_search(model_path, dump_path, w_bit, q_config):
     model, tokenizer = load_unquantized(model_path)
     awq_model = get_awq_model(model)
-    awq_results = awq_model.quantize(model, tokenizer, w_bit=4, q_config=q_config, run_search=True, run_quant=False)
+    awq_results = awq_model.quantize(model, tokenizer, w_bit=w_bit, q_config=q_config, run_search=True, run_quant=False)

     dirpath = os.path.dirname(dump_path)
     os.makedirs(dirpath, exist_ok=True)

     torch.save(awq_results, dump_path)

-def run_quant(model, search_path, dump_path):
-    model, tokenizer = load_unquantized(model_path)
+def run_quant(model_path, search_path, dump_path, w_bit, q_config, device):
+    model, tokenizer = load_unquantized(model_path, device)
+    load_search_result_into_memory(model, search_path)

     awq_model = get_awq_model(model)
-    awq_model.quantize(model, w_bit=4, q_config=q_config, run_search=False, run_quant=True)
+    awq_model.quantize(model, w_bit=w_bit, q_config=q_config, run_search=False, run_quant=True)

     dirpath = os.path.dirname(dump_path)
     os.makedirs(dirpath, exist_ok=True)

     torch.save(model.cpu().state_dict(), dump_path)

-model_path = "./mpt-7b-8k-chat"
-search_path = "./mpt-7b-8k-chat/mpt-7b-8k-chat-awq-search.pt"
-quant_path = "./mpt-7b-8k-chat/mpt-7b-8k-chat-w4-g128.pt"
-q_config = { "zero_point": True, "q_group_size": 128 }
+def run_perplexity(model_path, device):
+    model, tokenizer = load_unquantized(model_path)
+    lm_eval_model = LMEvalAdaptor(model_path, model, tokenizer, device, batch_size=1)
+    results = evaluator.simple_evaluate(
+        model=lm_eval_model,
+        tasks=['wikitext'],
+        batch_size=1,
+        no_cache=True,
+        num_fewshot=0,
+    )
+
+    print(evaluator.make_table(results))
+
+if __name__ == '__main__':
+    parser = argparse.ArgumentParser()
+    parser.add_argument('--entry_type', type=str, help='The type of task to run (search|quant|perplexity)')
+    parser.add_argument('--model_path', type=str, help='Path to hf model')
+    parser.add_argument('--search_path', type=str, help='Path to save/load AWQ search results')
+    parser.add_argument('--quant_path', type=str, help='Path to save/load AWQ quant model')
+    parser.add_argument('--device', type=str, default='cuda:0', help='Device to load model to')
+    parser.add_argument('--w_bit', type=int, default=4)
+    parser.add_argument('--q_group_size', type=int, default=128)
+    args = parser.parse_args()
+
+    args.model_path = "./mpt-7b-8k-chat"
+    args.search_path = "./mpt-7b-8k-chat/mpt-7b-8k-chat-awq-search.pt"
+    args.quant_path = "./mpt-7b-8k-chat/mpt-7b-8k-chat-w4-g128.pt"
+
+    q_config = { "zero_point": True, "q_group_size": args.q_group_size }
+
+    if args.entry_type == 'search':
+        run_search(args.model_path, args.search_path, args.w_bit, q_config)
+    elif args.entry_type == 'quant':
+        run_quant(args.model_path, args.search_path, args.quant_path, args.w_bit, q_config)
+    elif args.entry_type == 'perplexity':
+        run_perplexity(args.model_path, args.device)
+    else:
+        raise Exception('--entry_type must be one of (search|quant|perplexity)')
\ No newline at end of file
```
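For orientation, here is a minimal sketch of how the three entry points added above fit together when driven programmatically instead of through the CLI. It assumes awq/entry.py is importable as `awq.entry`; the checkpoint directory and output paths simply mirror the hard-coded values in the diff and are placeholders.

```python
# Sketch only: calls the functions defined in awq/entry.py at this commit.
from awq.entry import run_search, run_quant, run_perplexity

model_path  = "./mpt-7b-8k-chat"                                # local HF checkpoint (placeholder)
search_path = "./mpt-7b-8k-chat/mpt-7b-8k-chat-awq-search.pt"   # AWQ search results
quant_path  = "./mpt-7b-8k-chat/mpt-7b-8k-chat-w4-g128.pt"      # quantized state_dict
q_config    = {"zero_point": True, "q_group_size": 128}

# 1) --entry_type search: run the AWQ search on the FP16 model and save scales/clips.
run_search(model_path, search_path, w_bit=4, q_config=q_config)

# 2) --entry_type quant: reload the FP16 model, apply the saved search results,
#    quantize the weights, and dump the state_dict. run_quant() also takes a device
#    argument, which the __main__ dispatch above does not pass.
run_quant(model_path, search_path, quant_path, w_bit=4, q_config=q_config, device="cuda:0")

# 3) --entry_type perplexity: evaluate wikitext perplexity through lm_eval.
run_perplexity(model_path, device="cuda:0")
```

Note that, as committed, the `__main__` block overwrites `--model_path`, `--search_path`, and `--quant_path` with the hard-coded MPT paths right after `parse_args()`, so only `--entry_type`, `--device`, `--w_bit`, and `--q_group_size` actually take effect from the command line.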
awq/models/base.py (view file @ 934ad336)

```diff
...
@@ -19,13 +19,16 @@ class BaseAWQForCausalLM:
     def quantize(self, model, tokenizer=None, w_bit=4, q_config={}, n_samples=128, seqlen=512,
                  auto_scale=True, mse_range=True, run_search=False, run_quant=True,
                  calib_data="pileval", init_only=False):
+        search_result = None
+
         if run_search:
-            self._awq_search(model, tokenizer, w_bit, q_config, n_samples=n_samples, seqlen=seqlen,
-                             auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data)
+            search_result = self._awq_search(model, tokenizer, w_bit, q_config, n_samples=n_samples, seqlen=seqlen,
+                                             auto_scale=auto_scale, mse_range=mse_range, calib_data=calib_data)

         if run_quant:
             self._awq_quant(model, w_bit, q_config, init_only)

+        return search_result
+
     def _awq_quant(self, model, w_bit, q_config, init_only):
...
@@ -118,7 +121,7 @@ class BaseAWQForCausalLM:
         }

         # Run AWQ search layer by layer
-        for i in tqdm(range(len(layers)), desc="AWQ Search:"):
+        for i in tqdm(range(len(layers)), desc="AWQ Search"):
             layer = layers[i]
             layer = layer.cuda()
             named_linears = get_named_linears(layer)
...
```
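A small sketch of what the new return value enables, reusing the loading helpers from awq/entry.py in this same commit (the checkpoint path is a placeholder): `quantize()` now hands back the search result, whose "scale" and "clip" entries can be saved and re-applied later with apply_scale/apply_clip.

```python
# Sketch: persist the search result returned by quantize(), then re-apply it.
import torch
from awq.entry import load_unquantized, get_awq_model
from awq.quantize.auto_scale import apply_scale
from awq.quantize.auto_clip import apply_clip

model, tokenizer = load_unquantized("./mpt-7b-8k-chat")   # placeholder checkpoint path
awq_model = get_awq_model(model)
q_config = {"zero_point": True, "q_group_size": 128}

# Search pass only: with this change, the result is returned instead of discarded.
search_result = awq_model.quantize(model, tokenizer, w_bit=4, q_config=q_config,
                                   run_search=True, run_quant=False)
torch.save(search_result, "awq-search.pt")

# Later (as run_quant in awq/entry.py does), the saved scales and clips are applied
# back onto a freshly loaded FP16 model before the weight-quantization pass.
saved = torch.load("awq-search.pt", map_location="cpu")
apply_scale(model, saved["scale"])
apply_clip(model, saved["clip"])
```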
awq/utils/lm_eval_adaptor.py (view file @ 934ad336)

```diff
...
@@ -6,13 +6,13 @@ import fnmatch
 class LMEvalAdaptor(BaseLM):

-    def __init__(self, model_name, model, tokenizer, batch_size=1, max_length=-1):
+    def __init__(self, model_name, model, tokenizer, device, batch_size=1, max_length=-1):
         super().__init__()

         assert isinstance(batch_size, int)

         self.model_name = model_name
-        self.model = model
+        self.model = model.to(device)
         self.model.eval()
         self.tokenizer = tokenizer
...
```
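For completeness, a short sketch of the updated constructor, which now takes the target device between `tokenizer` and `batch_size` and moves the model itself. The checkpoint path is a placeholder, and the `trust_remote_code` flag is an assumption (MPT-style checkpoints typically require it).

```python
# Sketch: wrap an already-loaded model for lm_eval using the new signature.
from transformers import AutoModelForCausalLM, AutoTokenizer
from awq.utils.lm_eval_adaptor import LMEvalAdaptor

model_path = "./mpt-7b-8k-chat"   # placeholder path
model = AutoModelForCausalLM.from_pretrained(model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# The adaptor now runs `self.model = model.to(device)` internally, so the caller
# no longer has to place the model on the GPU before wrapping it.
lm_eval_model = LMEvalAdaptor(model_path, model, tokenizer, "cuda:0", batch_size=1)
```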