OpenDAS / AutoAWQ / Commits / cef9f113

Add Baichuan2 Support (#247)

Unverified commit cef9f113, authored Dec 23, 2023 by Aoyu, committed by GitHub on Dec 23, 2023.
Co-authored-by: Casper <casperbh.96@gmail.com>
Parent: 9e8e28b2
Changes: 7 files changed, 163 additions (+163) and 9 deletions (-9).

    awq/models/__init__.py      +1    -0
    awq/models/auto.py          +1    -0
    awq/models/baichuan.py      +137  -0
    awq/models/base.py          +2    -0
    awq/modules/fused/block.py  +3    -3
    awq/utils/calib_data.py     +18   -6
    examples/benchmark.py       +1    -0
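With these changes, Baichuan2 checkpoints can be quantized through the usual AutoAWQ flow. The following is a minimal sketch, assuming the standard AutoAWQForCausalLM.quantize() entry point; the model path, output directory, and quantization settings are illustrative, not part of this commit:

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "baichuan-inc/Baichuan2-13B-Chat"   # illustrative checkpoint
quant_path = "baichuan2-13b-chat-awq"            # illustrative output directory
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

# Baichuan models ship custom modeling code, hence trust_remote_code=True.
model = AutoAWQForCausalLM.from_pretrained(model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

model.quantize(tokenizer, quant_config=quant_config)
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)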
awq/models/__init__.py

@@ -10,5 +10,6 @@ from .gpt_neox import GPTNeoXAWQForCausalLM
 from .aquila import AquilaAWQForCausalLM
 from .yi import YiAWQForCausalLM
 from .qwen import QwenAWQForCausalLM
+from .baichuan import BaichuanAWQForCausalLM
 from .llava import LlavaAWQForCausalLM
 from .mixtral import MixtralAWQForCausalLM
awq/models/auto.py

@@ -19,6 +19,7 @@ AWQ_CAUSAL_LM_MODEL_MAP = {
     "aquila": AquilaAWQForCausalLM,
     "Yi": YiAWQForCausalLM,
     "qwen": QwenAWQForCausalLM,
+    "baichuan": BaichuanAWQForCausalLM,
     "llava": LlavaAWQForCausalLM,
 }
awq/models/baichuan.py · new file (100644)

import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer as OldLlamaDecoderLayer,
    LlamaForCausalLM as OldLlamaForCausalLM
)
from awq.modules.fused.mlp import QuantFusedMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm


class BaichuanAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "BaichuanLayer"
    max_new_tokens_key = "model_max_length"

    @staticmethod
    def fuse_layers(model):
        fuser = BaichuanFuser(model)
        fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model):
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module):
        return dict(
            is_scalable=False
        )

    @staticmethod
    def move_embed(model, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)

    @staticmethod
    # def get_layers_for_scaling(module: OldLlamaDecoderLayer, input_feat, module_kwargs):
    def get_layers_for_scaling(module, input_feat, module_kwargs):
        layers = []

        # attention input
        layers.append(dict(
            prev_op=module.input_layernorm,
            layers=[module.self_attn.W_pack],
            inp=input_feat['self_attn.W_pack'],
            module2inspect=module.self_attn,
            kwargs=module_kwargs,
        ))

        # # attention out
        # # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        # if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
        #     layers.append(dict(
        #         prev_op=module.self_attn.v_proj,
        #         layers=[module.self_attn.o_proj],
        #         inp=input_feat['self_attn.o_proj'],
        #     ))

        # attention out
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        layers.append(dict(
            prev_op=module.self_attn.W_pack,
            layers=[module.self_attn.o_proj],
            inp=input_feat['self_attn.o_proj'],
        ))

        # linear 1
        layers.append(dict(
            prev_op=module.post_attention_layernorm,
            layers=[module.mlp.gate_proj, module.mlp.up_proj],
            inp=input_feat['mlp.gate_proj'],
            module2inspect=module.mlp,
        ))

        # linear 2
        layers.append(dict(
            prev_op=module.mlp.up_proj,
            layers=[module.mlp.down_proj],
            inp=input_feat['mlp.down_proj'],
        ))

        return layers


class BaichuanFuser:
    def __init__(self, model):
        self.model = model

        self.llama_blocks: List[Tuple[str, OldLlamaDecoderLayer]] = [
            (name, module) for name, module in self.model.named_modules()
            if 'LlamaDecoderLayer'.lower() in module.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
        blocks = []

        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
            device = next(iter(module.state_dict().values())).device
            # qkv = fuse_qkv(
            #     module,
            #     module.self_attn.q_proj,
            #     module.self_attn.k_proj,
            #     module.self_attn.v_proj
            # )
            qkv = module.self_attn.W_pack
            mlp = QuantFusedMLP(
                module.mlp.gate_proj,
                module.mlp.down_proj,
                module.mlp.up_proj
            )
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight,
                module.input_layernorm.epsilon
            )
            norm_2 = FasterTransformerRMSNorm(
                module.post_attention_layernorm.weight,
                module.post_attention_layernorm.epsilon
            )
            blocks.append(LlamaLikeBlock(
                hidden_size=self.model.config.hidden_size,
                n_heads=self.model.config.num_attention_heads,
                n_kv_heads=self.model.config.num_attention_heads,
                qkv_layer=qkv,
                o_proj=module.self_attn.o_proj,
                mlp=mlp,
                norm_1=norm_1,
                norm_2=norm_2,
                dev=device,
                max_seq_len=self.model.config.max_new_tokens,
                use_alibi=True
            ))

        self.model.model = LlamaLikeModel(
            self.model.config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            self.model.model.norm,
        )
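BaichuanFuser is what runs when a quantized Baichuan2 model is loaded back with layer fusion enabled. A hedged sketch of that inference path, assuming the standard AutoAWQForCausalLM.from_quantized() loader; the checkpoint path and prompt are illustrative:

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

quant_path = "baichuan2-13b-chat-awq"  # illustrative path to an AWQ-quantized Baichuan2 checkpoint

# fuse_layers=True invokes BaichuanAWQForCausalLM.fuse_layers(), which replaces each
# transformer block with a LlamaLikeBlock built around the fused W_pack QKV projection
# and ALiBi attention (use_alibi=True above).
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(quant_path, trust_remote_code=True)

tokens = tokenizer("Baichuan2 is", return_tensors="pt").input_ids.cuda()
out = model.generate(tokens, max_new_tokens=32)
print(tokenizer.decode(out[0], skip_special_tokens=True))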
awq/models/base.py

@@ -55,6 +55,7 @@ TRANSFORMERS_AUTO_MAPPING_DICT = {
     "aquila": "AutoModelForCausalLM",
     "Yi": "AutoModelForCausalLM",
     "qwen": "AutoModelForCausalLM",
+    "baichuan": "AutoModelForCausalLM",
     "llava": "AutoModelForVision2Seq",
 }

@@ -90,6 +91,7 @@ class BaseAWQForCausalLM(nn.Module):
             self.quant_config.version, calib_data, split, text_column,
             duo_scaling, modules_to_not_convert=modules_to_not_convert
         )
         quantizer.quantize()

         self.is_quantized = True

     @staticmethod
awq/modules/fused/block.py

@@ -43,7 +43,7 @@ class LlamaLikeBlock(nn.Module):
     """
     def __init__(self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj,
-                 mlp, norm_1, norm_2, dev, max_seq_len, rope_theta
+                 mlp, norm_1, norm_2, dev, max_seq_len, rope_theta, use_alibi=False
     ):
         super().__init__()
         self.n_heads = n_heads

@@ -52,7 +52,7 @@ class LlamaLikeBlock(nn.Module):
         self.norm_1 = norm_1.to(dev)
         self.attn = QuantAttentionFused(
             self.hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
-            dev=dev, max_seq_len=max_seq_len, use_alibi=False, rope_theta=rope_theta
+            dev=dev, max_seq_len=max_seq_len, use_alibi=use_alibi, rope_theta=rope_theta
         ).to(dev)
         self.norm_2 = norm_2.to(dev)
         self.mlp = mlp.to(dev)
awq/utils/calib_data.py

@@ -3,7 +3,7 @@ import logging
 from typing import List, Union
 from datasets import load_dataset

-def get_calib_dataset(data: Union[str, List[str]] = "pileval",
+def get_calib_dataset(data: Union[str, List[str], List[List[int]]] = "pileval",
                       tokenizer=None, n_samples=512, block_size=512,
                       split="train", text_column="text"):
     if isinstance(data, str):

@@ -15,15 +15,27 @@ def get_calib_dataset(data: Union[str, List[str]] = "pileval",
         dataset = dataset.shuffle(seed=42)
     elif isinstance(data, list):
-        dataset = [{text_column: text} for text in data]
-    else:
-        raise NotImplementedError(
-            "Either pass a string to a huggingface dataset or a list"
-            "that is preprocessed with one sample of text per element."
-        )
+        if isinstance(data[0], str):
+            dataset = [{text_column: text} for text in data]
+        elif isinstance(data[0][0], int):
+            dataset = data
+        else:
+            raise NotImplementedError(
+                "Either pass a string to a huggingface dataset or a list"
+                "that is preprocessed with one sample of text per element"
+                " or a list of list of int for tokenized words."
+            )
+    else:
+        raise NotImplementedError(
+            "Either pass a string to a huggingface dataset or a list"
+            "that is preprocessed with one sample of text per element"
+            " or a list of list of int for tokenized words."
+        )

     samples = []
     n_run = 0
     for data in dataset:
-        line = data[text_column]
-        line = line.strip()
-        line_encoded = tokenizer.encode(line)
+        if isinstance(data, list):
+            line_encoded = data
+        else:
+            line = data[text_column]
+            line = line.strip()
+            line_encoded = tokenizer.encode(line)
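The widened signature means get_calib_dataset now accepts pre-tokenized samples (a list of token-id lists) in addition to a dataset name or raw strings. A minimal sketch of how a caller might exercise both forms, assuming only the signature shown in this diff; the model name and calibration texts are illustrative:

from transformers import AutoTokenizer
from awq.utils.calib_data import get_calib_dataset

tok = AutoTokenizer.from_pretrained("baichuan-inc/Baichuan2-13B-Chat", trust_remote_code=True)

texts = ["Baichuan2 is a bilingual large language model."] * 64   # illustrative corpus
pretokenized = [tok.encode(t) for t in texts]                     # List[List[int]]

# Raw strings still work as before; token-id lists now pass straight through.
blocks_from_text = get_calib_dataset(data=texts, tokenizer=tok, n_samples=64, block_size=64)
blocks_from_ids = get_calib_dataset(data=pretokenized, tokenizer=tok, n_samples=64, block_size=64)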
examples/benchmark.py

@@ -156,6 +156,7 @@ def main(args):
         {"context": 512,  "n_generate": 512},
         {"context": 1024, "n_generate": 1024},
         {"context": 2048, "n_generate": 2048},
+        {"context": 4096, "n_generate": 4096},
     ]

     if args.generator == "torch":