OpenDAS / AutoAWQ · Commits

Commit e09dc751, authored Aug 17, 2023 by Casper Hansen (parent: 934ad336)

Create AutoAWQForCausalLM and load quantized models with from_quantized
Showing 4 changed files with 89 additions and 11 deletions:

  awq/entry.py        +9  -5
  awq/models/auto.py  +30 -0
  awq/models/base.py  +47 -6
  awq/models/mpt.py   +3  -0
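In short, the commit replaces the ad-hoc `get_awq_model` helper with a `transformers`-style facade: `AutoAWQForCausalLM.from_quantized` reads the model config, looks the architecture up in `AWQ_CAUSAL_LM_MODEL_MAP`, and delegates to that class's `from_quantized`, which rebuilds the model with quantized linear layers and loads the saved weights. A hedged usage sketch; the model id, checkpoint path, and `q_config` values are illustrative (only the keys `zero_point` and `q_group_size` appear in the diff):

```python
# Sketch of the new entry point introduced by this commit.
# Model id, checkpoint path, and q_config values are illustrative stand-ins.
from awq.models.auto import AutoAWQForCausalLM

model = AutoAWQForCausalLM.from_quantized(
    model_path="mosaicml/mpt-7b",            # config.model_type must be "mpt" at this commit
    quant_path="quant_cache/mpt-7b-w4.pt",   # hypothetical state_dict saved by run_quant
    w_bit=4,
    q_config={"zero_point": True, "q_group_size": 128},
    device="cuda:0",
)
```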
awq/entry.py (+9 -5)

```diff
@@ -27,8 +27,12 @@ def load_unquantized(model_path):
     return model, tokenizer
 
-def load_quantized(model_path):
-    awq_model = get_awq_model(model)
+def load_quantized(model_path, quant_path, w_bit, q_config, device):
+    from awq.models.auto import AutoAWQForCausalLM
+    model = AutoAWQForCausalLM.from_quantized(model_path, quant_path, w_bit, q_config, device)
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    return model, tokenizer
 
 def load_search_result_into_memory(model, search_path):
     awq_results = torch.load(search_path, map_location="cpu")
@@ -56,8 +60,8 @@ def run_quant(model_path, search_path, dump_path, w_bit, q_config, device):
     os.makedirs(dirpath, exist_ok=True)
     torch.save(model.cpu().state_dict(), dump_path)
 
-def run_perplexity(model_path, device):
-    model, tokenizer = load_unquantized(model_path)
+def run_perplexity(model_path, quant_path, w_bit, q_config, device):
+    model, tokenizer = load_quantized(model_path, quant_path, w_bit, q_config, device)
     lm_eval_model = LMEvalAdaptor(model_path, model, tokenizer, device, batch_size=1)
     results = evaluator.simple_evaluate(
@@ -91,6 +95,6 @@ if __name__ == '__main__':
    elif args.entry_type == 'quant':
        run_quant(args.model_path, args.search_path, args.quant_path, args.w_bit, q_config)
    elif args.entry_type == 'perplexity':
-       run_perplexity(args.model_path, args.device)
+       run_perplexity(args.model_path, args.quant_path, args.w_bit, q_config, args.device)
    else:
        raise Exception('--entry_type must be one of (search|quant|perplexity)')
\ No newline at end of file
```
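For context, the evaluation half of `run_perplexity` (unchanged below the hunk) feeds the returned model into lm-eval. A minimal sketch of the full flow after this commit; the import paths, model id, and task list are assumptions, not shown in the diff:

```python
# Sketch of run_perplexity's flow after this commit. Import paths and the task
# list are assumptions; only the two calls below appear verbatim in the diff.
from lm_eval import evaluator
from awq.entry import load_quantized                  # assumes entry.py is importable
from awq.utils.lm_eval_adaptor import LMEvalAdaptor   # assumed module path

q_config = {"zero_point": True, "q_group_size": 128}  # illustrative values
model, tokenizer = load_quantized("mosaicml/mpt-7b", "mpt-7b-w4.pt", 4, q_config, "cuda:0")
lm_eval_model = LMEvalAdaptor("mosaicml/mpt-7b", model, tokenizer, "cuda:0", batch_size=1)
results = evaluator.simple_evaluate(model=lm_eval_model, tasks=["wikitext"], batch_size=1)
```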
awq/models/auto.py · new file, 0 → 100644 (+30 -0)

```python
from transformers import AutoConfig
from awq.models import MptAWQForCausalLM

AWQ_CAUSAL_LM_MODEL_MAP = {
    "mpt": MptAWQForCausalLM,
}

def check_and_get_model_type(model_dir, trust_remote_code=True):
    config = AutoConfig.from_pretrained(model_dir, trust_remote_code=trust_remote_code)
    if config.model_type not in AWQ_CAUSAL_LM_MODEL_MAP.keys():
        raise TypeError(f"{config.model_type} isn't supported yet.")
    model_type = config.model_type
    return model_type

class AutoAWQForCausalLM:
    def __init__(self):
        raise EnvironmentError('You must instantiate AutoAWQForCausalLM with\n'
                               'AutoAWQForCausalLM.from_quantized or AutoAWQForCausalLM.from_pretrained')

    @classmethod
    def from_pretrained():
        pass

    @classmethod
    def from_quantized(self, model_path, quant_path, w_bit, q_config, device, trust_remote_code=True):
        model_type = check_and_get_model_type(model_path, trust_remote_code)
        return AWQ_CAUSAL_LM_MODEL_MAP[model_type]().from_quantized(model_path, quant_path, w_bit, q_config, device)
```
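Two details worth noting: `from_quantized` is declared `@classmethod` but names its first parameter `self` (Python binds the class there regardless, so the dispatch still works), and `from_pretrained` is a parameterless stub that would raise a `TypeError` if called through the decorator. The dictionary dispatch keeps extension cheap; a hypothetical sketch (no `LlamaAWQForCausalLM` exists at this commit):

```python
# Hypothetical extension, not part of this commit: a new architecture only
# needs a BaseAWQForCausalLM subclass plus a registry entry keyed by the
# HF config.model_type string.
AWQ_CAUSAL_LM_MODEL_MAP = {
    "mpt": MptAWQForCausalLM,
    "llama": LlamaAWQForCausalLM,   # hypothetical subclass, for illustration
}
```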
awq/models/base.py (+47 -6)

```diff
@@ -6,15 +6,15 @@ from tqdm import tqdm
 from collections import defaultdict
 from awq.utils.calib_data import get_calib_dataset
+from transformers import AutoModelForCausalLM, AutoConfig
-from awq.quantize.quantizer import pseudo_quantize_tensor
-from awq.quantize.qmodule import WQLinear, ScaledActivation
 from awq.quantize.auto_clip import auto_clip_block, apply_clip
 from awq.quantize.auto_scale import auto_scale_block, apply_scale
+from accelerate import init_empty_weights, load_checkpoint_and_dispatch
 from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name
+from awq.quantize.quantizer import pseudo_quantize_tensor
+from awq.quantize.qmodule import WQLinear, ScaledActivation
 
 class BaseAWQForCausalLM:
     @torch.no_grad()
     def quantize(self, model, tokenizer=None, w_bit=4, q_config={}, n_samples=128, seqlen=512,
                  auto_scale=True, mse_range=True, run_search=False, run_quant=True,
@@ -186,5 +186,46 @@ class BaseAWQForCausalLM:
     def from_pretrained():
         pass
 
-    def from_quantized():
-        pass
\ No newline at end of file
+    def from_quantized(self, model_path, quant_path, w_bit, q_config, device, trust_remote_code=True):
+        # Load config
+        config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
+
+        with init_empty_weights():
+            model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch.float16, trust_remote_code=True)
+
+        # Initialize layers
+        assert q_config["zero_point"], "We only support zero_point quantization now."
+        layers = self.get_model_layers(model)
+
+        for i in tqdm(range(len(layers)), desc="Replacing layers..."):
+            layer = layers[i]
+            named_linears = get_named_linears(layer)
+            self._scale_activations(layer)
+
+            for name, module in named_linears.items():
+                q_linear = WQLinear.from_linear(module, w_bit, q_config['q_group_size'], True)
+                q_linear.to(next(layer.parameters()).device)
+                set_op_by_name(layer, name, q_linear)
+
+            torch.cuda.empty_cache()
+            gc.collect()
+
+        model.tie_weights()
+
+        model = load_checkpoint_and_dispatch(model, quant_path, device_map="balanced")
+
+        return model
+
+    def _scale_activations(self, layer):
+        act_function = self.get_act_from_layer(layer)
+
+        if act_function is not None and not isinstance(act_function, ScaledActivation):
+            param = next(layer.parameters())
+
+            # get activation scale
+            scale_dict = self.get_act_for_scaling(layer)
+            scale_like = torch.ones(scale_dict['scale_shape'], dtype=param.dtype, device=param.device)
+
+            # scale activation
+            scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
+            set_op_by_name(layer, scale_dict['scale_name'], scaled_act)
\ No newline at end of file
```
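The loading flow in `from_quantized` is the standard `accelerate` meta-device recipe: instantiate the architecture without allocating weights, swap every `nn.Linear` for a `WQLinear` shell so the checkpoint's packed tensors have somewhere to land, then let `load_checkpoint_and_dispatch` materialize and place them. A minimal self-contained sketch of that recipe, using plain fp16 weights and hypothetical names rather than AWQ-packed tensors:

```python
# Minimal sketch of the meta-device loading recipe used by from_quantized above.
# "facebook/opt-125m" and the checkpoint path are illustrative stand-ins, and the
# weights here are ordinary fp16 rather than AWQ-packed WQLinear tensors.
import torch
from accelerate import init_empty_weights, load_checkpoint_and_dispatch
from transformers import AutoConfig, AutoModelForCausalLM

config = AutoConfig.from_pretrained("facebook/opt-125m")
with init_empty_weights():
    # Modules are created on the meta device: shapes exist, no memory is allocated.
    model = AutoModelForCausalLM.from_config(config, torch_dtype=torch.float16)

model.tie_weights()  # re-tie shared tensors (e.g. input/output embeddings) before dispatch

# Stream real tensors in from a saved state_dict and spread layers across devices.
model = load_checkpoint_and_dispatch(model, "opt-125m-state.pt", device_map="balanced")
```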
awq/models/mpt.py (+3 -0)

```diff
@@ -42,6 +42,9 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
         return layers
 
+    def get_act_from_layer(self, layer):
+        return layer.ffn.act
+
     def get_act_for_scaling(self, module):
         return dict(
             scale_name="ffn.act",
             ...
```
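These are the two model-specific hooks that `_scale_activations` in base.py consumes: `get_act_from_layer` is the guard against wrapping an activation twice, and the dict from `get_act_for_scaling` names the module to replace. The hunk is truncated after `scale_name`; judging from the keys `_scale_activations` reads, the remaining entries plausibly look like the following (the MPT-specific values are assumptions, not from the diff):

```python
# Hedged reconstruction of the truncated return value, inferred from the keys
# _scale_activations reads in base.py; the values below are assumptions.
def get_act_for_scaling(self, module):
    return dict(
        scale_name="ffn.act",                         # shown in the diff
        scale_layer=module.ffn.act,                   # assumed: module to wrap in ScaledActivation
        scale_shape=module.ffn.up_proj.out_features,  # assumed: per-channel scale length
    )
```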