OpenDAS / AutoAWQ / Commits

Commit 7e091fb1 authored Aug 17, 2023 by Casper

Initial refactor [WIP].

Parent: efea69e1

Showing 6 changed files with 87 additions and 37 deletions
awq/models/__init__.py       +1   -0
awq/models/base.py           +13  -0
awq/models/mpt.py            +54  -0
awq/quantize/auto_scale.py   +6   -28
awq/quantize/pre_quant.py    +3   -3
awq/quantize/quantizer.py    +10  -6
awq/models/__init__.py (new file, 0 → 100644)

from .mpt import MptAWQForCausalLM
\ No newline at end of file
awq/models/base.py (new file, 0 → 100644)

class BaseAWQForCausalLM:
    def quantize():
        pass

    def save_quantized():
        pass

    def from_pretrained():
        pass

    def from_quantized():
        pass
\ No newline at end of file
awq/models/mpt.py (new file, 0 → 100644)

from .base import BaseAWQForCausalLM

class MptAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "MPTBlock"

    def get_model_layers(model):
        return model.transformer.blocks

    def get_layers_for_scaling(module, input_feat, module_kwargs):
        layers = []

        # attention input
        layers.append(dict(
            prev_op=module.norm_1,
            layers=[module.attn.Wqkv],
            inp=input_feat['attn.Wqkv'],
            module2inspect=module.attn,
            kwargs=module_kwargs,
        ))

        # attention output
        layers.append(dict(
            prev_op=module.attn.Wqkv,
            layers=[module.attn.out_proj],
            inp=input_feat['attn.out_proj'],
        ))

        # linear 1
        layers.append(dict(
            prev_op=module.norm_2,
            layers=[module.ffn.up_proj],
            inp=input_feat['ffn.up_proj'],
            module2inspect=module.ffn,
        ))

        # linear 2
        layers.append(dict(
            prev_op=module.ffn.act,
            layers=[module.ffn.down_proj],
            inp=input_feat['ffn.down_proj'],
        ))

        return layers

    def get_act_for_scaling(module):
        return dict(
            scale_name="ffn.act",
            scale_layer=module.ffn.act,
            scale_shape=module.ffn.up_proj.out_features,
        )

    def move_embed(model, device):
        model.transformer.wte = model.transformer.wte.to(device)
        model.transformer.emb_drop = model.transformer.emb_drop.to(device)
\ No newline at end of file
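The MPT hooks above can be exercised on their own before wiring them into the quantization pipeline. Below is a minimal sketch; the checkpoint id "mosaicml/mpt-7b", the trust_remote_code flag, and the float16 dtype are assumptions for illustration, while the attribute paths and the class-level calls mirror this commit.

# Illustrative only: the checkpoint id and loading flags are assumptions,
# not part of this commit; the hook calls mirror how pre_quant.py uses them.
import torch
from transformers import AutoModelForCausalLM
from awq.models import MptAWQForCausalLM

model = AutoModelForCausalLM.from_pretrained(
    "mosaicml/mpt-7b", torch_dtype=torch.float16, trust_remote_code=True
)

blocks = MptAWQForCausalLM.get_model_layers(model)   # model.transformer.blocks
print(f"{len(blocks)} MPTBlock layers to quantize")

# Move the embedding and embedding-dropout modules to the calibration device.
MptAWQForCausalLM.move_embed(model, "cpu")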
awq/quantize/auto_scale.py

...
@@ -8,6 +8,7 @@ from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaRMS
 from .qmodule import ScaledActivation
 from ..utils.module import get_op_by_name, get_op_name, set_op_by_name
+from ..models import MptAWQForCausalLM

 __all__ = ["auto_scale_block", "apply_scale"]
...
@@ -265,34 +266,11 @@ def auto_scale_block(module, module_kwargs,
             inp=input_feat['mlp.dense_4h_to_h'],
         ))
     elif "mpt" in str(module.__class__).lower():
-        # attention input
-        scales_list.append(_auto_get_scale(
-            prev_op=module.norm_1,
-            layers=[module.attn.Wqkv],
-            inp=input_feat['attn.Wqkv'],
-            module2inspect=module.attn,
-            kwargs=module_kwargs,
-        ))
-        # attn out
-        scales_list.append(_auto_get_scale(
-            prev_op=module.attn.Wqkv,
-            layers=[module.attn.out_proj],
-            inp=input_feat['attn.out_proj'],
-        ))
-        # fc1
-        scales_list.append(_auto_get_scale(
-            prev_op=module.norm_2,
-            layers=[module.ffn.up_proj],
-            inp=input_feat['ffn.up_proj'],
-            module2inspect=module.ffn,
-        ))
-        # fc2
-        scales_list.append(_auto_get_scale(
-            prev_op=module.ffn.act,
-            layers=[module.ffn.down_proj],
-            inp=input_feat['ffn.down_proj'],
-        ))
+        layers: list[dict] = MptAWQForCausalLM.get_layers_for_scaling(module, input_feat, module_kwargs)
+        layers_scaled = [_auto_get_scale(**layer) for layer in layers]
+        scales_list.extend(layers_scaled)
     elif "falcon" in str(module.__class__).lower():
         # attn out
...
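The added lines above work because each dict produced by get_layers_for_scaling uses the same key names as the keyword parameters of the scale-search helper, so the caller can unpack it directly with **. Here is a self-contained sketch of that pattern; the stub function and the SimpleNamespace objects are placeholders for illustration, not the real _auto_get_scale.

# Placeholder illustration; the real _auto_get_scale lives in awq/quantize/auto_scale.py.
from types import SimpleNamespace

def _auto_get_scale_stub(prev_op, layers, inp, module2inspect=None, kwargs=None):
    # The real helper searches per-channel scales; this stub only echoes its inputs.
    return (len(layers), len(inp), module2inspect is not None)

norm_1 = SimpleNamespace()   # stand-in for module.norm_1
wqkv = SimpleNamespace()     # stand-in for module.attn.Wqkv
attn = SimpleNamespace()     # stand-in for module.attn

configs = [
    dict(prev_op=norm_1, layers=[wqkv], inp=[0.1, 0.2, 0.3], module2inspect=attn),
]

scales_list = [_auto_get_scale_stub(**cfg) for cfg in configs]
print(scales_list)   # [(1, 3, True)]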
awq/quantize/pre_quant.py

...
@@ -11,6 +11,7 @@ from transformers.models.llama.modeling_llama import LlamaForCausalLM
 from .auto_scale import auto_scale_block, apply_scale
 from .auto_clip import auto_clip_block, apply_clip
+from ..models import MptAWQForCausalLM

 __all__ = ["run_awq"]
...
@@ -27,7 +28,7 @@ def get_blocks(model):
     elif isinstance(model, BloomForCausalLM):
         layers = model.transformer.h
     elif "mpt" in str(model.__class__).lower():
-        layers = model.transformer.blocks
+        layers = MptAWQForCausalLM.get_model_layers(model)
     elif "falcon" in str(model.__class__).lower():
         layers = model.transformer.h
     else:
...
@@ -44,8 +45,7 @@ def move_embed(model, device):
         model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
         model.transformer.word_embeddings_layernorm = model.transformer.word_embeddings_layernorm.to(device)
     elif "mpt" in str(model.__class__).lower():
-        model.transformer.wte = model.transformer.wte.to(device)
-        model.transformer.emb_drop = model.transformer.emb_drop.to(device)
+        MptAWQForCausalLM.move_embed(model, device)
     elif "falcon" in str(model.__class__).lower():
         model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)
     else:
...
awq/quantize/quantizer.py

...
@@ -4,6 +4,7 @@ from tqdm import tqdm
 import gc
 from .qmodule import ScaledActivation
 from ..utils.module import set_op_by_name
+from ..models import MptAWQForCausalLM
 from transformers.models.bloom.modeling_bloom import BloomBlock
...
@@ -27,12 +28,15 @@ def scale_activations(module):
     elif 'mptblock' in str(module.__class__.__name__).lower():
         if isinstance(module.ffn.act, ScaledActivation):
             return
-        c = module.ffn.up_proj.out_features
-        act = ScaledActivation(
-            module.ffn.act,
-            torch.ones(c, dtype=dtype, device=device)
-        )
-        set_op_by_name(module, "ffn.act", act)
+        # get activation scale
+        scale_dict = MptAWQForCausalLM.get_act_for_scaling(module)
+        scale_like = torch.ones(scale_dict['scale_shape'], dtype=dtype, device=device)
+
+        # scale activation
+        scaled_act = ScaledActivation(scale_dict['scale_layer'], scale_like)
+        set_op_by_name(module, scale_dict['scale_name'], scaled_act)
     elif 'falcon' in str(module.__class__).lower():
         if isinstance(module.mlp.act, ScaledActivation):
             return
...
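For orientation, the scale_layer / scale_shape pair returned by get_act_for_scaling feeds a ScaledActivation wrapper. The sketch below shows what such a wrapper typically looks like in AWQ-style code; the real class is imported from awq/quantize/qmodule.py and may differ in detail.

# Hedged sketch of a ScaledActivation-style wrapper; the actual implementation
# in awq/quantize/qmodule.py is authoritative.
import torch
import torch.nn as nn

class ScaledActivationSketch(nn.Module):
    def __init__(self, act, scales):
        super().__init__()
        self.act = act
        self.scales = nn.Parameter(scales)   # shaped like scale_dict['scale_shape']

    def forward(self, x):
        # Divide the activation output by per-channel scales; in AWQ the
        # following linear layer's weights are scaled up to compensate.
        return self.act(x) / self.scales.view(1, 1, -1).to(x.device)

# Usage mirroring scale_activations(): wrap the FFN activation with unit scales.
act = ScaledActivationSketch(nn.GELU(), torch.ones(16))
y = act(torch.randn(2, 4, 16))   # shape (batch, seq, up_proj.out_features)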