OpenDAS / AutoAWQ · Commits

Commit 5b9f3c47 (unverified)
Authored Dec 22, 2023 by Casper; committed via GitHub on Dec 22, 2023
Mixtral: Mixture of Experts quantization (#251)
Parent: 2350a4d0

Showing 12 changed files with 323 additions and 18 deletions (+323, -18)
awq/models/__init__.py        +1    -0
awq/models/auto.py            +1    -0
awq/models/base.py            +9    -3
awq/models/mixtral.py         +137  -0
awq/modules/fused/block.py    +34   -0
awq/modules/fused/mlp.py      +4    -1
awq/modules/fused/model.py    +57   -2
awq/quantize/quantizer.py     +13   -9
awq/quantize/scale.py         +23   -1
awq/utils/module.py           +11   -1
examples/basic_quant.py       +3    -1
examples/mixtral_quant.py     +30   -0
awq/models/__init__.py

@@ -10,3 +10,4 @@ from .gpt_neox import GPTNeoXAWQForCausalLM
 from .aquila import AquilaAWQForCausalLM
 from .yi import YiAWQForCausalLM
 from .qwen import QwenAWQForCausalLM
+from .mixtral import MixtralAWQForCausalLM
\ No newline at end of file
awq/models/auto.py

@@ -14,6 +14,7 @@ AWQ_CAUSAL_LM_MODEL_MAP = {
     "gptj": GPTJAWQForCausalLM,
     "gpt_bigcode": GptBigCodeAWQForCausalLM,
     "mistral": MistralAWQForCausalLM,
+    "mixtral": MixtralAWQForCausalLM,
     "gpt_neox": GPTNeoXAWQForCausalLM,
     "aquila": AquilaAWQForCausalLM,
     "Yi": YiAWQForCausalLM,
awq/models/base.py

@@ -12,7 +12,11 @@ from huggingface_hub import snapshot_download
 from awq.quantize.quantizer import AwqQuantizer
 from transformers.modeling_utils import shard_checkpoint
 from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
-from awq.utils.module import get_named_linears, set_op_by_name
+from awq.utils.module import (
+    get_named_linears,
+    set_op_by_name,
+    exclude_layers_to_not_quantize,
+)
 from transformers import (
     AutoModelForCausalLM,
     AutoConfig,

@@ -24,7 +28,6 @@ from accelerate.big_modeling import (
     infer_auto_device_map,
     load_checkpoint_and_dispatch,
 )
-from accelerate.utils import get_balanced_memory

 class BaseAWQForCausalLM(nn.Module):
     def __init__(self, model, model_type, is_quantized, config, quant_config):

@@ -176,7 +179,7 @@ class BaseAWQForCausalLM(nn.Module):
         if not os.path.isdir(model_path):
             ignore_patterns = ["*msgpack*", "*h5*", "optimizer.pt"]
             if safetensors:
-                ignore_patterns.extend(["*.pt*", "*.bin*"])
+                ignore_patterns.extend(["*.pt*", "*.bin*", "consolidated*"])
             else:
                 ignore_patterns.append("*.safetensors*")

@@ -215,6 +218,9 @@ class BaseAWQForCausalLM(nn.Module):
             # Get every linear layer in a block
             named_linears = get_named_linears(layer)
+
+            # Filter out the linear layers we don't want to exclude
+            named_linears = exclude_layers_to_not_quantize(named_linears, quant_config.modules_to_not_convert)
+
             # Replace activation functions
             self._scale_activations(self, layer)
awq/models/mixtral.py (new file, mode 100644)

import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import MixtralBlock
from awq.modules.fused.model import MixtralModel
from transformers.models.mixtral.modeling_mixtral import (
    MixtralDecoderLayer as OldMixtralDecoderLayer,
    MixtralForCausalLM as OldMixtralForCausalLM
)
from awq.modules.fused.mlp import QuantFusedMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm

class MixtralAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "MixtralDecoderLayer"
    max_new_tokens_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(model: OldMixtralForCausalLM):
        fuser = MixtralFuser(model)
        # TODO: Fix perplexity on fusing Mixtral
        #fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model: OldMixtralForCausalLM):
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module):
        return dict(
            is_scalable=False
        )

    @staticmethod
    def move_embed(model: OldMixtralForCausalLM, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)

    @staticmethod
    def get_layers_for_scaling(module: OldMixtralDecoderLayer, input_feat, module_kwargs):
        layers = []

        # attention input
        layers.append(dict(
            prev_op=module.input_layernorm,
            layers=[module.self_attn.q_proj,
                    module.self_attn.k_proj, module.self_attn.v_proj],
            inp=input_feat['self_attn.q_proj'],
            module2inspect=module.self_attn, kwargs=module_kwargs,
        ))

        # attention out
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            layers.append(dict(
                prev_op=module.self_attn.v_proj,
                layers=[module.self_attn.o_proj],
                inp=input_feat['self_attn.o_proj'],
            ))

        # linear in
        layers.append(dict(
            prev_op=module.post_attention_layernorm,
            layers=[
                w for expert in module.block_sparse_moe.experts
                for w in [expert.w1, expert.w3]
            ],
            inp=input_feat['block_sparse_moe'],
            module2inspect=module.block_sparse_moe,
        ))

        # linear out
        for i, expert in enumerate(module.block_sparse_moe.experts):
            layers.append(dict(
                prev_op=expert.w3,
                layers=[expert.w2],
                inp=input_feat[f'block_sparse_moe.experts.{i}.w2'],
            ))

        return layers

class MixtralFuser:
    def __init__(self, model: OldMixtralForCausalLM):
        self.model = model

        self.mixtral_blocks: List[Tuple[str, OldMixtralDecoderLayer]] = [
            (name, module) for name, module in self.model.named_modules()
            if 'MixtralDecoderLayer'.lower() in module.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
        blocks = []

        module: OldMixtralDecoderLayer
        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
            device = next(iter(module.state_dict().values())).device
            qkv = fuse_qkv(
                module,
                module.self_attn.q_proj,
                module.self_attn.k_proj,
                module.self_attn.v_proj
            )
            # Adapt to mixture of experts
            for i in range(len(module.block_sparse_moe.experts)):
                mlp = QuantFusedMLP(
                    gate_proj=module.block_sparse_moe.experts[i].w1,
                    down_proj=module.block_sparse_moe.experts[i].w2,
                    up_proj=module.block_sparse_moe.experts[i].w3
                )
                module.block_sparse_moe.experts[i] = mlp
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight,
                module.input_layernorm.variance_epsilon
            )
            norm_2 = FasterTransformerRMSNorm(
                module.post_attention_layernorm.weight,
                module.post_attention_layernorm.variance_epsilon
            )
            blocks.append(MixtralBlock(
                hidden_size=self.model.config.hidden_size,
                n_heads=self.model.config.num_attention_heads,
                n_kv_heads=self.model.config.num_key_value_heads,
                qkv_layer=qkv,
                o_proj=module.self_attn.o_proj,
                moe=module.block_sparse_moe,
                norm_1=norm_1,
                norm_2=norm_2,
                dev=device,
                max_seq_len=self.model.config.max_new_tokens
            ))

        self.model.model = MixtralModel(
            self.model.config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            self.model.model.norm,
        )
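The notable piece of get_layers_for_scaling above is the "linear in" group: every expert's w1 and w3 is collected into a single list, so one scale, searched against the input of the whole block_sparse_moe block (input_feat['block_sparse_moe']), is shared across all experts. A standalone sketch of that flattening, using stand-in strings rather than the real modules:

# Sketch (stand-in strings only, not part of the commit): how the nested
# comprehension flattens every expert's w1/w3 into one scaling group.
experts = ["experts.0", "experts.1"]  # stands in for module.block_sparse_moe.experts
linear_in = [w for expert in experts for w in [f"{expert}.w1", f"{expert}.w3"]]
print(linear_in)  # ['experts.0.w1', 'experts.0.w3', 'experts.1.w1', 'experts.1.w3']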
awq/modules/fused/block.py

@@ -2,6 +2,40 @@ import os
 import torch.nn as nn
 from awq.modules.fused.attn import QuantAttentionFused

+class MixtralBlock(nn.Module):
+    def __init__(
+        self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj,
+        moe, norm_1, norm_2, dev, max_seq_len
+    ):
+        super().__init__()
+        self.n_heads = n_heads
+        self.n_kv_heads = n_kv_heads
+        self.hidden_size = hidden_size
+        self.norm_1 = norm_1.to(dev)
+        self.attn = QuantAttentionFused(
+            self.hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
+            dev=dev, max_seq_len=max_seq_len, use_alibi=False
+        ).to(dev)
+        self.norm_2 = norm_2.to(dev)
+        self.moe = moe
+        self.device = dev
+
+    def forward(
+        self, hidden_states, past_key_value, attn_bias=None,
+        attention_mask=None, is_causal=None
+    ):
+        norm_out = self.norm_1(hidden_states)
+        attn_output, _, past_key_value = self.attn.forward(
+            hidden_states=norm_out,
+            past_key_value=past_key_value,
+            attention_mask=attention_mask
+        )
+
+        h = hidden_states.to(attn_output.device) + attn_output
+        out, _ = self.moe.forward(self.norm_2(h))
+        out = h + out
+
+        return out, None, past_key_value
+
 class LlamaLikeBlock(nn.Module):
     """
     LlamaLikeBlock is intended to be reused across blocks that have
awq/modules/fused/mlp.py

@@ -36,7 +36,7 @@ class QuantFusedMLP(nn.Module):
         self.activation = activation

-    def forward(self, x):
+    def forward(self, x, routing_weights=None):
         out_shape = x.shape[:-1] + (self.intermediate_size,)
         x = x.reshape(-1, x.shape[-1])
         gate_output = self.linear(

@@ -57,6 +57,9 @@ class QuantFusedMLP(nn.Module):
         x = x.reshape(out_shape)
         x = self.down_proj(x)
+
+        if routing_weights is not None:
+            x = routing_weights * x
+
         return x
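The only behavioral change to QuantFusedMLP is the optional routing_weights argument: it scales the expert output after down_proj, so mlp(x, routing_weights=w) equals w * mlp(x). A toy check with plain nn.Linear stand-ins (not the quantized kernels) that mirrors the pattern:

# Toy sketch (assumed stand-in module, not part of the commit).
import torch
import torch.nn as nn

class ToyFusedMLP(nn.Module):
    def __init__(self, dim=8, inter=16):
        super().__init__()
        self.gate_proj = nn.Linear(dim, inter, bias=False)
        self.up_proj = nn.Linear(dim, inter, bias=False)
        self.down_proj = nn.Linear(inter, dim, bias=False)
        self.act = nn.SiLU()

    def forward(self, x, routing_weights=None):
        x = self.down_proj(self.act(self.gate_proj(x)) * self.up_proj(x))
        if routing_weights is not None:
            # Same placement as the diff: scale the expert output by its router weight.
            x = routing_weights * x
        return x

mlp = ToyFusedMLP()
x, w = torch.randn(4, 8), torch.rand(4, 1)
assert torch.allclose(mlp(x, routing_weights=w), w * mlp(x))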
awq/modules/fused/model.py

@@ -2,8 +2,63 @@ import torch
 import torch.nn as nn
 from typing import List
 from awq.utils import fused_utils
-from transformers.modeling_outputs import BaseModelOutputWithPast
-from awq.modules.fused.block import MPTBlock, FalconDecoderLayer, LlamaLikeBlock
+from transformers.modeling_outputs import BaseModelOutputWithPast, MoeModelOutputWithPast
+from awq.modules.fused.block import MPTBlock, FalconDecoderLayer, LlamaLikeBlock, MixtralBlock
+
+class MixtralModel(nn.Module):
+    def __init__(self, vocab_size, blocks, embedding, norm):
+        super().__init__()
+        self.vocab_size = vocab_size
+        self.embedding = embedding
+        self.blocks: List[MixtralBlock] = nn.ModuleList(blocks)
+        self.norm = norm
+        self.last_forward_num_tokens = 0
+
+    @torch.inference_mode()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        attn_bias=None,
+        attention_mask=None,
+        is_causal=None,
+        *args,
+        **kwargs,
+    ):
+        input_ids, self.last_forward_num_tokens = fused_utils.prepare_input_ids(
+            input_ids,
+            self.last_forward_num_tokens
+        )
+        _bsz, seqlen = input_ids.shape
+
+        fused_utils.prepare_cache(self.blocks, seqlen)
+
+        h = self.embedding(input_ids)
+
+        mask = fused_utils.prepare_attention_mask(
+            seqlen=seqlen,
+            start_pos=self.blocks[0].attn.start_pos,
+            device=input_ids.device,
+            type_as=h,
+        )
+
+        for layer in self.blocks:
+            h, mask = fused_utils.prepare_correct_devices(
+                layer,
+                h,
+                mask,
+            )
+            h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal)
+        h = self.norm(h)
+
+        return MoeModelOutputWithPast(
+            last_hidden_state=h,
+            past_key_values=past_key_value,
+            hidden_states=(),
+            attentions=(),
+            router_logits=(),
+        )
+
 class LlamaLikeModel(nn.Module):
     """
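The fused MixtralModel reports its result through the MoeModelOutputWithPast class from transformers rather than BaseModelOutputWithPast, filling only last_hidden_state and past_key_values and leaving the MoE-specific fields empty. A minimal sketch, assuming a transformers version recent enough to ship Mixtral (and therefore this output dataclass):

# Sketch (illustrative values, not part of the commit).
import torch
from transformers.modeling_outputs import MoeModelOutputWithPast

out = MoeModelOutputWithPast(
    last_hidden_state=torch.zeros(1, 4, 8),  # only this and the cache carry real data
    past_key_values=None,
    hidden_states=(),
    attentions=(),
    router_logits=(),
)
print(out.last_hidden_state.shape)  # torch.Size([1, 4, 8])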
awq/quantize/quantizer.py

@@ -10,7 +10,13 @@ from awq.utils.utils import clear_memory
 from awq.utils.calib_data import get_calib_dataset
 from awq.quantize.scale import apply_scale, apply_clip
 from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV
-from awq.utils.module import append_str_prefix, get_op_name, get_named_linears, set_op_by_name
+from awq.utils.module import (
+    append_str_prefix,
+    get_op_name,
+    get_named_linears,
+    set_op_by_name,
+    exclude_layers_to_not_quantize
+)

 class AwqQuantizer:

@@ -70,13 +76,6 @@ class AwqQuantizer:
         return w

-    def _exclude_layers_to_not_quantize(self, linear_layers):
-        filtered_layers = {}
-        for name, linear_layer in linear_layers.items():
-            if not any(key in name for key in self.modules_to_not_convert):
-                filtered_layers[name] = linear_layer
-        return filtered_layers
-
     def quantize(self):
         for i in tqdm(range(len(self.modules)), desc="AWQ"):
             # Move module and inputs to correct device

@@ -91,7 +90,7 @@ class AwqQuantizer:
             named_linears = get_named_linears(self.modules[i])

             # Filter out the linear layers we don't want to exclude
-            named_linears = self._exclude_layers_to_not_quantize(named_linears)
+            named_linears = exclude_layers_to_not_quantize(named_linears, self.modules_to_not_convert)

             input_feat = self._get_input_feat(self.modules[i], named_linears)
             clear_memory()

@@ -387,6 +386,11 @@ class AwqQuantizer:
         input_feat = defaultdict(list)
         handles = []

+        # FIXME: Workaround for Mixtral to use block_sparse_moe input features
+        if self.awq_model.model_type == "mixtral":
+            named_linears = {**named_linears, "block_sparse_moe": layer.block_sparse_moe}
+
         for name in named_linears:
             handles.append(named_linears[name].register_forward_hook(
                 functools.partial(cache_input_hook, name=name,
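The FIXME works because _get_input_feat gathers calibration inputs with forward hooks, and adding block_sparse_moe to named_linears places a hook on the whole MoE block rather than only its linears; the captured input is what input_feat['block_sparse_moe'] refers to in get_layers_for_scaling. A self-contained sketch of that capture pattern, using a toy module and a placeholder hook (names are assumptions, not the quantizer's exact helpers):

# Sketch (toy module, not part of the commit).
import functools
from collections import defaultdict

import torch
import torch.nn as nn

input_feat = defaultdict(list)

def cache_input_hook(module, inputs, output, name, feat_dict):
    # Stash the module's input under its name, mirroring the quantizer's idea.
    feat_dict[name].append(inputs[0].detach())

moe = nn.Linear(8, 8)  # stand-in for layer.block_sparse_moe
handle = moe.register_forward_hook(
    functools.partial(cache_input_hook, name="block_sparse_moe", feat_dict=input_feat)
)
moe(torch.randn(2, 8))
handle.remove()
print(input_feat["block_sparse_moe"][0].shape)  # torch.Size([2, 8])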
awq/quantize/scale.py

@@ -33,7 +33,10 @@ def apply_scale(module, scales_list, input_feat_dict=None):
             layer.cuda()
         scales.cuda()

-        if isinstance(prev_op, nn.Linear):
+        if isinstance(prev_op, nn.Linear) and type(layers) == list and isinstance(layers[0], nn.Linear):
+            scale_fc_fcs(prev_op, layers, scales)
+
+        elif isinstance(prev_op, nn.Linear):
             assert len(layers) == 1
             scale_fc_fc(prev_op, layers[0], scales)

@@ -101,6 +104,25 @@ def scale_fc_fc(fc1: nn.Linear, fc2: nn.Linear, scales: torch.Tensor):
     for p in fc2.parameters():
         assert torch.isnan(p).sum() == 0

+@torch.no_grad()
+def scale_fc_fcs(fc1: nn.Linear, fcs: List[nn.Linear], scales: torch.Tensor):
+    if not isinstance(fcs, list):
+        fcs = [fcs]
+
+    scales = scales.to(fc1.weight.device)
+
+    fc1.weight[-scales.size(0):].div_(scales.view(-1, 1))
+    if fc1.bias is not None:
+        fc1.bias.div_(scales.view(-1))
+
+    for fc in fcs:
+        fc.weight.mul_(scales.view(1, -1))
+
+    for p in fc1.parameters():
+        assert torch.isnan(p).sum() == 0
+    for fc in fcs:
+        for p in fc.parameters():
+            assert torch.isnan(p).sum() == 0
+
 @torch.no_grad()
 def scale_gelu_fc(gelu: allowed_act_fns, fc: nn.Linear, scales: torch.Tensor):
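scale_fc_fcs generalizes scale_fc_fc to one producer feeding one or more consumer linears: the producer's output channels are divided by the scales and each consumer's matching input channels are multiplied by them. For the plain linear-to-linear case this is exactly output-preserving, which a small numerical check with toy weights (not AWQ's actual scale search) makes concrete:

# Numerical sketch (toy weights, not part of the commit).
import torch
import torch.nn as nn

torch.manual_seed(0)
fc1 = nn.Linear(8, 16, bias=False)   # producer
fc2 = nn.Linear(16, 8, bias=False)   # consumer
x = torch.randn(4, 8)
ref = fc2(fc1(x))

s = torch.rand(16) + 0.5
with torch.no_grad():
    fc1.weight[-s.size(0):].div_(s.view(-1, 1))  # same slicing as scale_fc_fcs
    fc2.weight.mul_(s.view(1, -1))

assert torch.allclose(fc2(fc1(x)), ref, atol=1e-5)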
awq/utils/module.py

@@ -42,3 +42,13 @@ def append_str_prefix(x, prefix):
         return [append_str_prefix(y, prefix) for y in x]
     else:
         return x
+
+
+def exclude_layers_to_not_quantize(linear_layers, modules_to_not_convert):
+    if modules_to_not_convert is None:
+        return linear_layers
+
+    filtered_layers = {}
+    for name, linear_layer in linear_layers.items():
+        if not any(key in name for key in modules_to_not_convert):
+            filtered_layers[name] = linear_layer
+    return filtered_layers
\ No newline at end of file
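For orientation, a quick usage sketch of the new helper with a toy dict (the values are just labels): with the Mixtral example's modules_to_not_convert=["gate"], the router linear is dropped before quantization while the expert projections are kept.

# Usage sketch (toy dict, not part of the commit).
from awq.utils.module import exclude_layers_to_not_quantize

named_linears = {
    "block_sparse_moe.gate": "router linear (left in half precision)",
    "block_sparse_moe.experts.0.w1": "expert projection (quantized)",
}
print(exclude_layers_to_not_quantize(named_linears, ["gate"]))
# {'block_sparse_moe.experts.0.w1': 'expert projection (quantized)'}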
examples/basic_quant.py

@@ -7,7 +7,9 @@ quant_config = { "zero_point": True, "q_group_size": 128, "w_bit": 4, "version":
 # Load model
 # NOTE: pass safetensors=True to load safetensors
-model = AutoAWQForCausalLM.from_pretrained(model_path, **{"low_cpu_mem_usage": True})
+model = AutoAWQForCausalLM.from_pretrained(
+    model_path, **{"low_cpu_mem_usage": True, "use_cache": False}
+)
 tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

 # Quantize
examples/mixtral_quant.py (new file, mode 100644)

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = 'mistralai/Mixtral-8x7B-Instruct-v0.1'
quant_path = 'mixtral-instruct-awq'
modules_to_not_convert = ["gate"]
quant_config = {
    "zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM",
    "modules_to_not_convert": modules_to_not_convert
}

# Load model
# NOTE: pass safetensors=True to load safetensors
model = AutoAWQForCausalLM.from_pretrained(
    model_path, safetensors=True, **{"low_cpu_mem_usage": True}
)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Quantize
model.quantize(
    tokenizer,
    quant_config=quant_config,
    modules_to_not_convert=modules_to_not_convert
)

# Save quantized model
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)

print(f'Model is quantized and saved at "{quant_path}"')
\ No newline at end of file
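After this script runs, the checkpoint at quant_path would typically be reloaded for inference. A minimal follow-up sketch (not part of the commit), assuming the usual AutoAWQForCausalLM.from_quantized entry point; note the commit leaves Mixtral layer fusion disabled (see the TODO in fuse_layers), so no fusion-specific arguments are shown:

# Follow-up sketch (assumed reload path, not part of the commit).
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

quant_path = 'mixtral-instruct-awq'
model = AutoAWQForCausalLM.from_quantized(quant_path)
tokenizer = AutoTokenizer.from_pretrained(quant_path)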