OpenDAS / AutoAWQ - Commits

Commit 5b9f3c47 (unverified), authored Dec 22, 2023 by Casper, committed via GitHub on Dec 22, 2023
Parent commit: 2350a4d0

Mixtral: Mixture of Experts quantization (#251)

Showing 12 changed files with 323 additions and 18 deletions (+323, -18).
awq/models/__init__.py        +1    -0
awq/models/auto.py            +1    -0
awq/models/base.py            +9    -3
awq/models/mixtral.py         +137  -0
awq/modules/fused/block.py    +34   -0
awq/modules/fused/mlp.py      +4    -1
awq/modules/fused/model.py    +57   -2
awq/quantize/quantizer.py     +13   -9
awq/quantize/scale.py         +23   -1
awq/utils/module.py           +11   -1
examples/basic_quant.py       +3    -1
examples/mixtral_quant.py     +30   -0
awq/models/__init__.py  (+1, -0)

@@ -10,3 +10,4 @@ from .gpt_neox import GPTNeoXAWQForCausalLM
 from .aquila import AquilaAWQForCausalLM
 from .yi import YiAWQForCausalLM
 from .qwen import QwenAWQForCausalLM
+from .mixtral import MixtralAWQForCausalLM
\ No newline at end of file
awq/models/auto.py  (+1, -0)

@@ -14,6 +14,7 @@ AWQ_CAUSAL_LM_MODEL_MAP = {
     "gptj": GPTJAWQForCausalLM,
     "gpt_bigcode": GptBigCodeAWQForCausalLM,
     "mistral": MistralAWQForCausalLM,
+    "mixtral": MixtralAWQForCausalLM,
     "gpt_neox": GPTNeoXAWQForCausalLM,
     "aquila": AquilaAWQForCausalLM,
     "Yi": YiAWQForCausalLM,
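The new "mixtral" key has to match the model_type string in the Hugging Face config, since AutoAWQForCausalLM resolves the AWQ wrapper class through AWQ_CAUSAL_LM_MODEL_MAP. A quick way to confirm the key (illustrative snippet, requires network access to the Hugging Face Hub):

from transformers import AutoConfig

cfg = AutoConfig.from_pretrained("mistralai/Mixtral-8x7B-Instruct-v0.1")
print(cfg.model_type)  # "mixtral" -> dispatches to MixtralAWQForCausalLM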
awq/models/base.py  (+9, -3)

@@ -12,7 +12,11 @@ from huggingface_hub import snapshot_download
 from awq.quantize.quantizer import AwqQuantizer
 from transformers.modeling_utils import shard_checkpoint
 from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
-from awq.utils.module import get_named_linears, set_op_by_name
+from awq.utils.module import (
+    get_named_linears,
+    set_op_by_name,
+    exclude_layers_to_not_quantize,
+)
 from transformers import (
     AutoModelForCausalLM,
     AutoConfig,
@@ -24,7 +28,6 @@ from accelerate.big_modeling import (
     infer_auto_device_map,
     load_checkpoint_and_dispatch,
 )
 from accelerate.utils import get_balanced_memory
-

 class BaseAWQForCausalLM(nn.Module):
     def __init__(self, model, model_type, is_quantized, config, quant_config):
@@ -176,7 +179,7 @@ class BaseAWQForCausalLM(nn.Module):
         if not os.path.isdir(model_path):
             ignore_patterns = ["*msgpack*", "*h5*", "optimizer.pt"]
             if safetensors:
-                ignore_patterns.extend(["*.pt*", "*.bin*"])
+                ignore_patterns.extend(["*.pt*", "*.bin*", "consolidated*"])
             else:
                 ignore_patterns.append("*.safetensors*")
@@ -215,6 +218,9 @@ class BaseAWQForCausalLM(nn.Module):
             # Get every linear layer in a block
             named_linears = get_named_linears(layer)

+            # Filter out the linear layers we don't want to exclude
+            named_linears = exclude_layers_to_not_quantize(named_linears, quant_config.modules_to_not_convert)
+
             # Replace activation functions
             self._scale_activations(self, layer)
awq/models/mixtral.py  (new file, +137)

import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import MixtralBlock
from awq.modules.fused.model import MixtralModel
from transformers.models.mixtral.modeling_mixtral import (
    MixtralDecoderLayer as OldMixtralDecoderLayer,
    MixtralForCausalLM as OldMixtralForCausalLM
)
from awq.modules.fused.mlp import QuantFusedMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm

class MixtralAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "MixtralDecoderLayer"
    max_new_tokens_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(model: OldMixtralForCausalLM):
        fuser = MixtralFuser(model)
        # TODO: Fix perplexity on fusing Mixtral
        # fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model: OldMixtralForCausalLM):
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module):
        return dict(
            is_scalable=False
        )

    @staticmethod
    def move_embed(model: OldMixtralForCausalLM, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)

    @staticmethod
    def get_layers_for_scaling(module: OldMixtralDecoderLayer, input_feat, module_kwargs):
        layers = []

        # attention input
        layers.append(dict(
            prev_op=module.input_layernorm,
            layers=[module.self_attn.q_proj,
                    module.self_attn.k_proj, module.self_attn.v_proj],
            inp=input_feat['self_attn.q_proj'],
            module2inspect=module.self_attn, kwargs=module_kwargs,
        ))

        # attention out
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            layers.append(dict(
                prev_op=module.self_attn.v_proj,
                layers=[module.self_attn.o_proj],
                inp=input_feat['self_attn.o_proj'],
            ))

        # linear in
        layers.append(dict(
            prev_op=module.post_attention_layernorm,
            layers=[
                w for expert in module.block_sparse_moe.experts
                for w in [expert.w1, expert.w3]
            ],
            inp=input_feat['block_sparse_moe'],
            module2inspect=module.block_sparse_moe,
        ))

        # linear out
        for i, expert in enumerate(module.block_sparse_moe.experts):
            layers.append(dict(
                prev_op=expert.w3,
                layers=[expert.w2],
                inp=input_feat[f'block_sparse_moe.experts.{i}.w2'],
            ))

        return layers

class MixtralFuser:
    def __init__(self, model: OldMixtralForCausalLM):
        self.model = model

        self.mixtral_blocks: List[Tuple[str, OldMixtralDecoderLayer]] = [
            (name, module) for name, module in self.model.named_modules()
            if 'MixtralDecoderLayer'.lower() in module.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
        blocks = []

        module: OldMixtralDecoderLayer
        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
            device = next(iter(module.state_dict().values())).device
            qkv = fuse_qkv(
                module,
                module.self_attn.q_proj,
                module.self_attn.k_proj,
                module.self_attn.v_proj
            )
            # Adapt to mixture of experts
            for i in range(len(module.block_sparse_moe.experts)):
                mlp = QuantFusedMLP(
                    gate_proj=module.block_sparse_moe.experts[i].w1,
                    down_proj=module.block_sparse_moe.experts[i].w2,
                    up_proj=module.block_sparse_moe.experts[i].w3
                )
                module.block_sparse_moe.experts[i] = mlp
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight,
                module.input_layernorm.variance_epsilon
            )
            norm_2 = FasterTransformerRMSNorm(
                module.post_attention_layernorm.weight,
                module.post_attention_layernorm.variance_epsilon
            )
            blocks.append(MixtralBlock(
                hidden_size=self.model.config.hidden_size,
                n_heads=self.model.config.num_attention_heads,
                n_kv_heads=self.model.config.num_key_value_heads,
                qkv_layer=qkv,
                o_proj=module.self_attn.o_proj,
                moe=module.block_sparse_moe,
                norm_1=norm_1,
                norm_2=norm_2,
                dev=device,
                max_seq_len=self.model.config.max_new_tokens
            ))

        self.model.model = MixtralModel(
            self.model.config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            self.model.model.norm,
        )
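The core of the MoE support is get_layers_for_scaling: the w1/w3 projections of all experts are collected into a single scaling group anchored on post_attention_layernorm (every expert sees the same normalized hidden states), while each expert's w2 gets its own group anchored on that expert's w3. A minimal sketch of inspecting those groups on a tiny, randomly initialized decoder layer (illustrative only: the config sizes and dummy input features are made up, and it assumes a transformers release with Mixtral support plus AutoAWQ at or after this commit):

import torch
from transformers import MixtralConfig
from transformers.models.mixtral.modeling_mixtral import MixtralDecoderLayer
from awq.models.mixtral import MixtralAWQForCausalLM

# Tiny, hypothetical config so the layer is cheap to build.
cfg = MixtralConfig(
    hidden_size=64, intermediate_size=128, num_attention_heads=4,
    num_key_value_heads=4, num_local_experts=4, num_experts_per_tok=2,
)
layer = MixtralDecoderLayer(cfg, layer_idx=0)

# Dummy calibration activations, keyed the same way the quantizer keys them.
input_feat = {
    "self_attn.q_proj": torch.randn(8, cfg.hidden_size),
    "self_attn.o_proj": torch.randn(8, cfg.hidden_size),
    "block_sparse_moe": torch.randn(8, cfg.hidden_size),
}
for i in range(cfg.num_local_experts):
    input_feat[f"block_sparse_moe.experts.{i}.w2"] = torch.randn(8, cfg.intermediate_size)

groups = MixtralAWQForCausalLM.get_layers_for_scaling(layer, input_feat, {})
for g in groups:
    print(type(g["prev_op"]).__name__, "->", len(g["layers"]), "linear(s)")
# Expected: one qkv group, one o_proj group (present here because n_kv_heads == n_heads,
# so v_proj and o_proj shapes match), one group over all experts' w1/w3,
# then one w2 group per expert.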
awq/modules/fused/block.py  (+34, -0)

@@ -2,6 +2,40 @@ import os
 import torch.nn as nn
 from awq.modules.fused.attn import QuantAttentionFused

+class MixtralBlock(nn.Module):
+    def __init__(
+        self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj,
+        moe, norm_1, norm_2, dev, max_seq_len
+    ):
+        super().__init__()
+        self.n_heads = n_heads
+        self.n_kv_heads = n_kv_heads
+        self.hidden_size = hidden_size
+        self.norm_1 = norm_1.to(dev)
+        self.attn = QuantAttentionFused(
+            self.hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
+            dev=dev, max_seq_len=max_seq_len, use_alibi=False
+        ).to(dev)
+        self.norm_2 = norm_2.to(dev)
+        self.moe = moe
+        self.device = dev
+
+    def forward(
+        self, hidden_states, past_key_value, attn_bias=None,
+        attention_mask=None, is_causal=None
+    ):
+        norm_out = self.norm_1(hidden_states)
+        attn_output, _, past_key_value = self.attn.forward(
+            hidden_states=norm_out,
+            past_key_value=past_key_value,
+            attention_mask=attention_mask
+        )
+
+        h = hidden_states.to(attn_output.device) + attn_output
+        out, _ = self.moe.forward(self.norm_2(h))
+        out = h + out
+
+        return out, None, past_key_value
+
 class LlamaLikeBlock(nn.Module):
     """
     LlamaLikeBlock is intended to be reused across blocks that have
awq/modules/fused/mlp.py  (+4, -1)

@@ -36,7 +36,7 @@ class QuantFusedMLP(nn.Module):
         self.activation = activation

-    def forward(self, x):
+    def forward(self, x, routing_weights=None):
         out_shape = x.shape[:-1] + (self.intermediate_size,)
         x = x.reshape(-1, x.shape[-1])
         gate_output = self.linear(
@@ -57,6 +57,9 @@ class QuantFusedMLP(nn.Module):
         x = x.reshape(out_shape)
         x = self.down_proj(x)

+        if routing_weights is not None:
+            x = routing_weights * x
+
         return x
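Accepting routing_weights lets QuantFusedMLP stand in for a Mixtral expert: after the gate/up/down projections, each token's expert output is scaled by the router probability assigned to that expert. A minimal sketch of just that final step with plain tensors (shapes are illustrative, no AutoAWQ modules involved):

import torch

tokens, hidden = 4, 8
expert_out = torch.randn(tokens, hidden)   # what down_proj produced for this expert
routing_weights = torch.rand(tokens, 1)    # router probability of this expert, per token

# Same operation the fused MLP now applies when routing_weights is passed:
weighted = routing_weights * expert_out    # broadcasts over the hidden dimension

# The MoE output is the sum of such weighted contributions over the selected experts.
print(weighted.shape)  # torch.Size([4, 8])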
awq/modules/fused/model.py  (+57, -2)

@@ -2,8 +2,63 @@ import torch
 import torch.nn as nn
 from typing import List
 from awq.utils import fused_utils
-from transformers.modeling_outputs import BaseModelOutputWithPast
-from awq.modules.fused.block import MPTBlock, FalconDecoderLayer, LlamaLikeBlock
+from transformers.modeling_outputs import BaseModelOutputWithPast, MoeModelOutputWithPast
+from awq.modules.fused.block import MPTBlock, FalconDecoderLayer, LlamaLikeBlock, MixtralBlock
+
+class MixtralModel(nn.Module):
+    def __init__(self, vocab_size, blocks, embedding, norm):
+        super().__init__()
+        self.vocab_size = vocab_size
+        self.embedding = embedding
+        self.blocks: List[MixtralBlock] = nn.ModuleList(blocks)
+        self.norm = norm
+        self.last_forward_num_tokens = 0
+
+    @torch.inference_mode()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attn_bias=None,
+        attention_mask=None,
+        is_causal=None,
+        *args,
+        **kwargs,
+    ):
+        input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
+            input_ids,
+            self.last_forward_num_tokens
+        )
+        _bsz, seqlen = input_ids.shape
+
+        fused_utils.prepare_cache(self.blocks, seqlen)
+
+        h = self.embedding(input_ids)
+
+        mask = fused_utils.prepare_attention_mask(
+            seqlen=seqlen,
+            start_pos=self.blocks[0].attn.start_pos,
+            device=input_ids.device,
+            type_as=h,
+        )
+
+        for layer in self.blocks:
+            h, mask = fused_utils.prepare_correct_devices(
+                layer,
+                h,
+                mask,
+            )
+            h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal)
+        h = self.norm(h)
+
+        return MoeModelOutputWithPast(
+            last_hidden_state=h,
+            past_key_values=past_key_value,
+            hidden_states=(),
+            attentions=(),
+            router_logits=(),
+        )
+
 class LlamaLikeModel(nn.Module):
     """
awq/quantize/quantizer.py  (+13, -9)

@@ -10,7 +10,13 @@ from awq.utils.utils import clear_memory
 from awq.utils.calib_data import get_calib_dataset
 from awq.quantize.scale import apply_scale, apply_clip
 from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
-from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name
+from awq.utils.module import (
+    append_str_prefix,
+    get_op_name,
+    get_named_linears,
+    set_op_by_name,
+    exclude_layers_to_not_quantize
+)

 class AwqQuantizer:
@@ -70,13 +76,6 @@ class AwqQuantizer:
         return w

-    def _exclude_layers_to_not_quantize(self, linear_layers):
-        filtered_layers = {}
-        for name, linear_layer in linear_layers.items():
-            if not any(key in name for key in self.modules_to_not_convert):
-                filtered_layers[name] = linear_layer
-        return filtered_layers
-
     def quantize(self):
         for i in tqdm(range(len(self.modules)), desc="AWQ"):
             # Move module and inputs to correct device
@@ -91,7 +90,7 @@ class AwqQuantizer:
             named_linears = get_named_linears(self.modules[i])

             # Filter out the linear layers we don't want to exclude
-            named_linears = self._exclude_layers_to_not_quantize(named_linears)
+            named_linears = exclude_layers_to_not_quantize(named_linears, self.modules_to_not_convert)

             input_feat = self._get_input_feat(self.modules[i], named_linears)
             clear_memory()
@@ -387,6 +386,11 @@ class AwqQuantizer:
         input_feat = defaultdict(list)
         handles = []

+        # FIXME: Workaround for Mixtral to use block_sparse_moe input features
+        if self.awq_model.model_type == "mixtral":
+            named_linears = {
+                **named_linears,
+                "block_sparse_moe": layer.block_sparse_moe
+            }
+
         for name in named_linears:
             handles.append(named_linears[name].register_forward_hook(
                 functools.partial(cache_input_hook, name=name,
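The Mixtral workaround in _get_input_feat leans on standard PyTorch forward hooks: by adding block_sparse_moe to named_linears, cache_input_hook also records the hidden states entering the whole MoE block, which is the input the shared w1/w3 scaling group in awq/models/mixtral.py consumes. A small standalone sketch of that caching pattern (generic PyTorch; the hook below is a simplified stand-in, not the quantizer's exact cache_input_hook):

import functools
from collections import defaultdict
import torch
import torch.nn as nn

input_feat = defaultdict(list)

def cache_input_hook(module, args, output, name):
    # args is the tuple of positional inputs passed to the module's forward.
    input_feat[name].append(args[0].detach())

block = nn.Sequential(nn.Linear(8, 8), nn.ReLU())  # stand-in for block_sparse_moe
handle = block.register_forward_hook(functools.partial(cache_input_hook, name="block_sparse_moe"))

block(torch.randn(2, 8))  # a forward pass populates the cache
handle.remove()

print(input_feat["block_sparse_moe"][0].shape)  # torch.Size([2, 8])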
awq/quantize/scale.py  (+23, -1)

@@ -33,7 +33,10 @@ def apply_scale(module, scales_list, input_feat_dict=None):
         layer.cuda()
         scales.cuda()

-    if isinstance(prev_op, nn.Linear):
+    if isinstance(prev_op, nn.Linear) and type(layers) == list and isinstance(layers[0], nn.Linear):
+        scale_fc_fcs(prev_op, layers, scales)
+
+    elif isinstance(prev_op, nn.Linear):
         assert len(layers) == 1
         scale_fc_fc(prev_op, layers[0], scales)
@@ -101,6 +104,25 @@ def scale_fc_fc(fc1: nn.Linear, fc2: nn.Linear, scales: torch.Tensor):
     for p in fc2.parameters():
         assert torch.isnan(p).sum() == 0

+@torch.no_grad()
+def scale_fc_fcs(fc1: nn.Linear, fcs: List[nn.Linear], scales: torch.Tensor):
+    if not isinstance(fcs, list):
+        fcs = [fcs]
+
+    scales = scales.to(fc1.weight.device)
+
+    fc1.weight[-scales.size(0):].div_(scales.view(-1, 1))
+    if fc1.bias is not None:
+        fc1.bias.div_(scales.view(-1))
+
+    for fc in fcs:
+        fc.weight.mul_(scales.view(1, -1))
+
+    for p in fc1.parameters():
+        assert torch.isnan(p).sum() == 0
+    for fc in fcs:
+        for p in fc.parameters():
+            assert torch.isnan(p).sum() == 0
+
 @torch.no_grad()
 def scale_gelu_fc(gelu: allowed_act_fns, fc: nn.Linear, scales: torch.Tensor):
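scale_fc_fcs extends scale_fc_fc to the one-to-many case needed for MoE blocks: the previous linear's output channels are divided by the AWQ scales while every downstream linear's input channels are multiplied by them, so the composed function is numerically unchanged but the downstream weights become easier to quantize. A quick sanity-check sketch with toy layer sizes (assumes AutoAWQ at or after this commit is importable; the shapes are made up):

import torch
import torch.nn as nn
from awq.quantize.scale import scale_fc_fcs

torch.manual_seed(0)
fc1 = nn.Linear(8, 8)                       # plays the role of prev_op
fcs = [nn.Linear(8, 16), nn.Linear(8, 16)]  # plays the role of the experts' w1/w3
x = torch.randn(4, 8)

with torch.no_grad():
    before = [fc(fc1(x)) for fc in fcs]

scales = torch.rand(8) + 0.5                # arbitrary positive per-channel scales
scale_fc_fcs(fc1, fcs, scales)

with torch.no_grad():
    after = [fc(fc1(x)) for fc in fcs]

for b, a in zip(before, after):
    assert torch.allclose(b, a, atol=1e-5)  # output of each fc1 -> fc pair is preserved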
awq/utils/module.py  (+11, -1)

@@ -41,4 +41,14 @@ def append_str_prefix(x, prefix):
     elif isinstance(x, list):
         return [append_str_prefix(y, prefix) for y in x]
     else:
-        return x
\ No newline at end of file
+        return x
+
+def exclude_layers_to_not_quantize(linear_layers, modules_to_not_convert):
+    if modules_to_not_convert is None:
+        return linear_layers
+
+    filtered_layers = {}
+    for name, linear_layer in linear_layers.items():
+        if not any(key in name for key in modules_to_not_convert):
+            filtered_layers[name] = linear_layer
+    return filtered_layers
\ No newline at end of file
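The helper does a simple substring match against module names, which is what lets the Mixtral example further below skip the router: passing modules_to_not_convert=["gate"] drops block_sparse_moe.gate from the set of linears to quantize while keeping every expert projection. A minimal usage sketch (toy layer names and sizes, assuming AutoAWQ at or after this commit is importable):

import torch.nn as nn
from awq.utils.module import exclude_layers_to_not_quantize

named_linears = {
    "block_sparse_moe.gate": nn.Linear(64, 8),            # router: small and precision-sensitive
    "block_sparse_moe.experts.0.w1": nn.Linear(64, 128),
    "block_sparse_moe.experts.0.w2": nn.Linear(128, 64),
    "self_attn.q_proj": nn.Linear(64, 64),
}

kept = exclude_layers_to_not_quantize(named_linears, ["gate"])
print(sorted(kept))  # the gate is gone, everything else stays
# ['block_sparse_moe.experts.0.w1', 'block_sparse_moe.experts.0.w2', 'self_attn.q_proj']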
examples/basic_quant.py  (+3, -1)

@@ -7,7 +7,9 @@ quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version":
 # Load model
 # NOTE: pass safetensors=True to load safetensors
-model = AutoAWQForCausalLM.from_pretrained(model_path, **{"low_cpu_mem_usage": True})
+model = AutoAWQForCausalLM.from_pretrained(
+    model_path, **{"low_cpu_mem_usage": True, "use_cache": False}
+)
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

 # Quantize
examples/mixtral_quant.py  (new file, +30)

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = 'mistralai/Mixtral-8x7B-Instruct-v0.1'
quant_path = 'mixtral-instruct-awq'
modules_to_not_convert = ["gate"]
quant_config = {
    "zero_point": True,
    "q_group_size": 128,
    "w_bit": 4,
    "version": "GEMM",
    "modules_to_not_convert": modules_to_not_convert
}

# Load model
# NOTE: pass safetensors=True to load safetensors
model = AutoAWQForCausalLM.from_pretrained(
    model_path, safetensors=True, **{"low_cpu_mem_usage": True}
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Quantize
model.quantize(
    tokenizer,
    quant_config=quant_config,
    modules_to_not_convert=modules_to_not_convert
)

# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)

print(f'Model is quantized and saved at "{quant_path}"')
\ No newline at end of file
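Once the script above has run, the quantized checkpoint can be loaded back for inference in the usual AutoAWQ way. A hedged sketch (the prompt and generation settings are arbitrary; note that layer fusing is effectively a no-op for Mixtral in this commit because fuse_transformer is still commented out):

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

quant_path = 'mixtral-instruct-awq'

model = AutoAWQForCausalLM.from_quantized(quant_path)
tokenizer = AutoTokenizer.from_pretrained(quant_path)

prompt = "[INST] Explain mixture of experts in one sentence. [/INST]"
input_ids = tokenizer(prompt, return_tensors="pt").input_ids.cuda()
output = model.generate(input_ids, max_new_tokens=64)
print(tokenizer.decode(output[0], skip_special_tokens=True))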