OpenDAS / AutoAWQ · Commits

Commit 14d198c6, authored Aug 17, 2023 by Casper Hansen

    Implement from_pretrained. Fix static methods and classmethods

Parent: 35ac58c7
Showing 4 changed files with 54 additions and 52 deletions (+54 −52)
awq/entry.py        +13 −38
awq/models/auto.py   +8  −4
awq/models/base.py  +23  −5
awq/models/mpt.py   +10  −5
awq/entry.py
@@ -2,37 +2,12 @@ import os
 import torch
 import argparse
 from lm_eval import evaluator
+from transformers import AutoTokenizer
+from awq.models.auto import AutoAWQForCausalLM
 from awq.quantize.auto_clip import apply_clip
 from awq.quantize.auto_scale import apply_scale
 from awq.utils.lm_eval_adaptor import LMEvalAdaptor
-from transformers import AutoModelForCausalLM, AutoTokenizer, AutoConfig
-
-def get_awq_model(model):
-    from awq.models import MptAWQForCausalLM
-
-    if "mpt" in str(model.__class__).lower():
-        return MptAWQForCausalLM()
-    else:
-        raise NotImplementedError(type(model))
-
-def load_unquantized(model_path):
-    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
-    tokenizer = AutoTokenizer.from_pretrained(config.tokenizer_name, trust_remote_code=True)
-    kwargs = {"torch_dtype": torch.float16, "low_cpu_mem_usage": True}
-    model = AutoModelForCausalLM.from_pretrained(model_path, config=config, trust_remote_code=True, **kwargs)
-    model.eval()
-    return model, tokenizer
-
-def load_quantized(model_path, quant_path, w_bit, q_config, device):
-    from awq.models.auto import AutoAWQForCausalLM
-    model = AutoAWQForCausalLM.from_quantized(model_path, quant_path, w_bit, q_config, device)
-    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
-    return model, tokenizer
 
 def load_search_result_into_memory(model, search_path):
     awq_results = torch.load(search_path, map_location="cpu")
@@ -41,27 +16,27 @@ def load_search_result_into_memory(model, search_path):
     apply_clip(model, awq_results["clip"])
 
 def run_search(model_path, dump_path, w_bit, q_config):
-    model, tokenizer = load_unquantized(model_path)
-    awq_model = get_awq_model(model)
+    model = AutoAWQForCausalLM.from_pretrained(model_path)
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
-    awq_results = awq_model.quantize(model, tokenizer, w_bit=w_bit, q_config=q_config, run_search=True, run_quant=False)
+    awq_results = model.quantize(model.model, tokenizer, w_bit=w_bit, q_config=q_config, run_search=True, run_quant=False)
 
     dirpath = os.path.dirname(dump_path)
     os.makedirs(dirpath, exist_ok=True)
 
     torch.save(awq_results, dump_path)
 
-def run_quant(model_path, search_path, dump_path, w_bit, q_config, device):
-    model, tokenizer = load_unquantized(model_path, device)
-    load_search_result_into_memory(model, search_path)
-    awq_model = get_awq_model(model)
+def run_quant(model_path, search_path, dump_path, w_bit, q_config):
+    model = AutoAWQForCausalLM.from_pretrained(model_path)
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
+    load_search_result_into_memory(model.model, search_path)
 
-    awq_model.quantize(model, w_bit=w_bit, q_config=q_config, run_search=False, run_quant=True)
+    model.quantize(model.model, w_bit=w_bit, q_config=q_config, run_search=False, run_quant=True)
 
     dirpath = os.path.dirname(dump_path)
     os.makedirs(dirpath, exist_ok=True)
 
-    torch.save(model.cpu().state_dict(), dump_path)
+    torch.save(model.model.cpu().state_dict(), dump_path)
 
 def run_perplexity(model_path, quant_path, w_bit, q_config, device):
-    model, tokenizer = load_quantized(model_path, quant_path, w_bit, q_config, device)
+    model = AutoAWQForCausalLM.from_quantized(model_path, quant_path, w_bit, q_config, device)
+    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
 
     lm_eval_model = LMEvalAdaptor(model_path, model, tokenizer, device, batch_size=1)
 
     results = evaluator.simple_evaluate(
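Taken together, these changes drop entry.py's ad-hoc loaders (get_awq_model, load_unquantized, load_quantized) and route the search, quantization, and perplexity entry points through AutoAWQForCausalLM. A minimal sketch of the reworked search flow; the checkpoint path and q_config keys below are assumptions for illustration, not values from the commit:

# Illustrative sketch only: driving the new run_search-style flow after this commit.
from transformers import AutoTokenizer
from awq.models.auto import AutoAWQForCausalLM

model_path = "mosaicml/mpt-7b"                           # placeholder checkpoint
q_config = {"zero_point": True, "q_group_size": 128}     # assumed AWQ config keys

model = AutoAWQForCausalLM.from_pretrained(model_path)   # AWQ wrapper around the HF model
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Search pass only: collect scale/clip results without writing quantized weights.
awq_results = model.quantize(model.model, tokenizer, w_bit=4, q_config=q_config,
                             run_search=True, run_quant=False)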
awq/models/auto.py
@@ -18,13 +18,17 @@ class AutoAWQForCausalLM:
             'AutoAWQForCausalLM.from_quantized or AutoAWQForCausalLM.from_pretrained'
         )
 
     @classmethod
-    def from_pretrained():
-        pass
+    def from_pretrained(self, model_path, trust_remote_code=True):
+        model_type = check_and_get_model_type(model_path, trust_remote_code)
+
+        return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_pretrained(model_path, model_type, trust_remote_code=trust_remote_code)
 
     @classmethod
     def from_quantized(self, model_path, quant_path, w_bit, q_config, device, trust_remote_code=True):
         model_type = check_and_get_model_type(model_path, trust_remote_code)
 
-        return AWQ_CAUSAL_LM_MODEL_MAP[model_type]().from_quantized(
-            model_path, quant_path, w_bit, q_config, device
+        return AWQ_CAUSAL_LM_MODEL_MAP[model_type].from_quantized(
+            model_path, quant_path, w_bit, q_config, device, trust_remote_code
         )
\ No newline at end of file
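Both factory methods now dispatch directly on the class stored in AWQ_CAUSAL_LM_MODEL_MAP rather than instantiating it first, and from_quantized forwards trust_remote_code. A hedged usage sketch; the paths and q_config values are placeholders, not from the diff:

# Illustrative only: how the two entry points are meant to be called after this commit.
from awq.models.auto import AutoAWQForCausalLM

# Unquantized checkpoint -> AWQ wrapper, ready for search/quantization.
model = AutoAWQForCausalLM.from_pretrained("mosaicml/mpt-7b")        # placeholder path

# Already-quantized weights -> model ready for evaluation.
quant_model = AutoAWQForCausalLM.from_quantized(
    "mosaicml/mpt-7b",      # placeholder model path (config/tokenizer source)
    "mpt-7b-w4-awq.pt",     # placeholder path to the quantized state dict
    w_bit=4,
    q_config={"zero_point": True, "q_group_size": 128},  # assumed config keys
    device="cuda:0",
)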
awq/models/base.py
@@ -15,6 +15,11 @@ from accelerate import init_empty_weights, load_checkpoint_and_dispatch
 from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name
 
 class BaseAWQForCausalLM:
+    def __init__(self, model, model_type, is_quantized):
+        self.model = model
+        self.model_type = model_type
+        self.is_quantized = is_quantized
+
     @torch.no_grad()
     def quantize(self, model, tokenizer=None, w_bit=4, q_config={}, n_samples=128, seqlen=512,
                        auto_scale=True, mse_range=True, run_search=False, run_quant=True,
@@ -39,7 +44,7 @@ class BaseAWQForCausalLM:
         for i in tqdm(range(len(layers)), desc="AWQ Quantization"):
             layer = layers[i]
             named_linears = get_named_linears(layer)
-            self._scale_activations(layer)
+            self._scale_activations(self, layer)
 
             for name, module in named_linears.items():
                 module.cuda()
@@ -167,9 +172,21 @@ class BaseAWQForCausalLM:
     def save_quantized():
         pass
 
-    def from_pretrained():
-        pass
+    @classmethod
+    def from_pretrained(self, model_path, model_type, torch_dtype: torch.dtype = torch.float16, trust_remote_code=True):
+        # Load config
+        config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
+
+        # Load empty weights
+        with init_empty_weights():
+            model = AutoModelForCausalLM.from_config(config=config, torch_dtype=torch.float16, trust_remote_code=True)
+
+        # Load model weights
+        model = load_checkpoint_and_dispatch(model, model_path, device_map="balanced", no_split_module_classes=[self.layer_type])
+
+        return self(model, model_type, is_quantized=False)
 
+    @classmethod
     def from_quantized(self, model_path, quant_path, w_bit, q_config, device, trust_remote_code=True):
         # Load config
         config = AutoConfig.from_pretrained(model_path, trust_remote_code=trust_remote_code)
@@ -183,7 +200,7 @@ class BaseAWQForCausalLM:
         for i in tqdm(range(len(layers)), desc="Replacing layers..."):
             layer = layers[i]
             named_linears = get_named_linears(layer)
-            self._scale_activations(layer)
+            self._scale_activations(self, layer)
 
             for name, module in named_linears.items():
                 q_linear = WQLinear.from_linear(
@@ -196,10 +213,11 @@ class BaseAWQForCausalLM:
         model.tie_weights()
 
-        model = load_checkpoint_and_dispatch(model, quant_path, device_map="balanced")
+        model = load_checkpoint_and_dispatch(model, quant_path, device_map="balanced", no_split_module_classes=[self.layer_type])
 
         return model
 
+    @staticmethod
     def _scale_activations(self, layer):
         act_function = self.get_act_from_layer(layer)
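The new __init__ plus the classmethod from_pretrained turn BaseAWQForCausalLM into a wrapper that owns the underlying Hugging Face model: the checkpoint is built on empty weights, dispatched with accelerate, and returned via self(model, model_type, is_quantized=False). The call sites pass the instance explicitly (self._scale_activations(self, layer)) because _scale_activations is now a @staticmethod. A simplified sketch of the construction pattern, not the library code:

# Toy stand-in for BaseAWQForCausalLM showing the wrapper-construction pattern.
class CausalLMWrapper:
    def __init__(self, model, model_type, is_quantized):
        self.model = model              # the underlying Hugging Face model
        self.model_type = model_type
        self.is_quantized = is_quantized

    @classmethod
    def from_pretrained(cls, model, model_type):
        # The real diff builds `model` with init_empty_weights() and
        # load_checkpoint_and_dispatch(); here it is simply passed in.
        return cls(model, model_type, is_quantized=False)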
awq/models/mpt.py
@@ -3,10 +3,12 @@ from .base import BaseAWQForCausalLM
 class MptAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "MPTBlock"
 
-    def get_model_layers(self, model):
+    @staticmethod
+    def get_model_layers(model):
         return model.transformer.blocks
 
-    def get_layers_for_scaling(self, module, input_feat, module_kwargs):
+    @staticmethod
+    def get_layers_for_scaling(module, input_feat, module_kwargs):
         layers = []
 
         # attention input
@@ -42,16 +44,19 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
 
         return layers
 
-    def get_act_from_layer(self, layer):
+    @staticmethod
+    def get_act_from_layer(layer):
         return layer.ffn.act
 
-    def get_act_for_scaling(self, module):
+    @staticmethod
+    def get_act_for_scaling(module):
         return dict(
             scale_name="ffn.act",
             scale_layer=module.ffn.act,
             scale_shape=module.ffn.up_proj.out_features
         )
 
-    def move_embed(self, model, device):
+    @staticmethod
+    def move_embed(model, device):
         model.transformer.wte = model.transformer.wte.to(device)
         model.transformer.emb_drop = model.transformer.emb_drop.to(device)
\ No newline at end of file
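With every hook on MptAWQForCausalLM now a @staticmethod, the base class can call them through either an instance or the class without binding issues. A hypothetical adapter for another architecture would follow the same shape; all names below are illustrative and not part of this commit:

# Hypothetical adapter skeleton following the same @staticmethod pattern (illustrative names).
from .base import BaseAWQForCausalLM

class ExampleAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "ExampleDecoderLayer"   # used as the no_split_module_classes entry during dispatch

    @staticmethod
    def get_model_layers(model):
        return model.model.layers        # assumed attribute path for this example model

    @staticmethod
    def move_embed(model, device):
        model.model.embed_tokens = model.model.embed_tokens.to(device)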