OpenDAS / AutoAWQ, commit cef9f113 (unverified)
Authored Dec 23, 2023 by Aoyu; committed via GitHub on Dec 23, 2023
Parent: 9e8e28b2

Add Baichuan2 Support (#247)

Co-authored-by: Casper <casperbh.96@gmail.com>

Showing 7 changed files with 163 additions and 9 deletions (+163 -9):
awq/models/__init__.py      +1   -0
awq/models/auto.py          +1   -0
awq/models/baichuan.py      +137 -0
awq/models/base.py          +2   -0
awq/modules/fused/block.py  +3   -3
awq/utils/calib_data.py     +18  -6
examples/benchmark.py       +1   -0
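For orientation, a minimal usage sketch of what this commit enables, following the project's usual quantization flow. The checkpoint name, output directory, and quant_config values are illustrative assumptions, not taken from the diff.

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "baichuan-inc/Baichuan2-7B-Chat"   # assumed checkpoint
quant_path = "baichuan2-7b-awq"                 # assumed output directory
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

# Load the FP16 model and its tokenizer (Baichuan uses custom modeling code on the Hub).
model = AutoAWQForCausalLM.from_pretrained(model_path, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

# Run AWQ calibration and quantization, then persist the quantized weights.
model.quantize(tokenizer, quant_config=quant_config)
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)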
awq/models/__init__.py

@@ -10,5 +10,6 @@ from .gpt_neox import GPTNeoXAWQForCausalLM
 from .aquila import AquilaAWQForCausalLM
 from .yi import YiAWQForCausalLM
 from .qwen import QwenAWQForCausalLM
+from .baichuan import BaichuanAWQForCausalLM
 from .llava import LlavaAWQForCausalLM
 from .mixtral import MixtralAWQForCausalLM
awq/models/auto.py

@@ -19,6 +19,7 @@ AWQ_CAUSAL_LM_MODEL_MAP = {
     "aquila": AquilaAWQForCausalLM,
     "Yi": YiAWQForCausalLM,
     "qwen": QwenAWQForCausalLM,
+    "baichuan": BaichuanAWQForCausalLM,
     "llava": LlavaAWQForCausalLM,
 }
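A small sketch of how the new map entry is reached, assuming AutoAWQ's usual dispatch: the model_type read from the checkpoint's config selects the wrapper class from AWQ_CAUSAL_LM_MODEL_MAP. The checkpoint name is illustrative.

from transformers import AutoConfig

config = AutoConfig.from_pretrained("baichuan-inc/Baichuan2-7B-Chat", trust_remote_code=True)
print(config.model_type)  # "baichuan" -> AWQ_CAUSAL_LM_MODEL_MAP["baichuan"] is BaichuanAWQForCausalLM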
awq/models/baichuan.py (new file, mode 100644)

import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.llama.modeling_llama import (
    LlamaDecoderLayer as OldLlamaDecoderLayer,
    LlamaForCausalLM as OldLlamaForCausalLM
)
from awq.modules.fused.mlp import QuantFusedMLP
from awq.modules.fused.norm import FasterTransformerRMSNorm


class BaichuanAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "BaichuanLayer"
    max_new_tokens_key = "model_max_length"

    @staticmethod
    def fuse_layers(model):
        fuser = BaichuanFuser(model)
        fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model):
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module):
        return dict(is_scalable=False)

    @staticmethod
    def move_embed(model, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)

    @staticmethod
    def get_layers_for_scaling(module, input_feat, module_kwargs):
        layers = []

        # attention input: Baichuan fuses Q, K and V into a single W_pack projection
        layers.append(dict(
            prev_op=module.input_layernorm,
            layers=[module.self_attn.W_pack],
            inp=input_feat['self_attn.W_pack'],
            module2inspect=module.self_attn,
            kwargs=module_kwargs,
        ))

        # attention out
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        layers.append(dict(
            prev_op=module.self_attn.W_pack,
            layers=[module.self_attn.o_proj],
            inp=input_feat['self_attn.o_proj'],
        ))

        # linear 1
        layers.append(dict(
            prev_op=module.post_attention_layernorm,
            layers=[module.mlp.gate_proj, module.mlp.up_proj],
            inp=input_feat['mlp.gate_proj'],
            module2inspect=module.mlp,
        ))

        # linear 2
        layers.append(dict(
            prev_op=module.mlp.up_proj,
            layers=[module.mlp.down_proj],
            inp=input_feat['mlp.down_proj'],
        ))

        return layers


class BaichuanFuser:
    def __init__(self, model):
        self.model = model

        self.llama_blocks: List[Tuple[str, OldLlamaDecoderLayer]] = [
            (name, module) for name, module in self.model.named_modules()
            if 'LlamaDecoderLayer'.lower() in module.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
        blocks = []

        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
            device = next(iter(module.state_dict().values())).device
            # W_pack is already a fused QKV projection, so no fuse_qkv() step is needed
            qkv = module.self_attn.W_pack
            mlp = QuantFusedMLP(
                module.mlp.gate_proj,
                module.mlp.down_proj,
                module.mlp.up_proj
            )
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight,
                module.input_layernorm.epsilon
            )
            norm_2 = FasterTransformerRMSNorm(
                module.post_attention_layernorm.weight,
                module.post_attention_layernorm.epsilon
            )
            blocks.append(LlamaLikeBlock(
                hidden_size=self.model.config.hidden_size,
                n_heads=self.model.config.num_attention_heads,
                n_kv_heads=self.model.config.num_attention_heads,
                qkv_layer=qkv,
                o_proj=module.self_attn.o_proj,
                mlp=mlp,
                norm_1=norm_1,
                norm_2=norm_2,
                dev=device,
                max_seq_len=self.model.config.max_new_tokens,
                use_alibi=True
            ))

        self.model.model = LlamaLikeModel(
            self.model.config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            self.model.model.norm,
        )
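The key difference from the Llama-style mappings is that Baichuan ships Q, K and V as one packed projection (W_pack), so it is passed to LlamaLikeBlock directly as the fused qkv_layer and no separate fuse_qkv step is required. A standalone sketch of that layout, with assumed (illustrative) dimensions:

import torch

hidden_size = 4096                                                   # e.g. Baichuan2-7B; illustrative value
w_pack = torch.nn.Linear(hidden_size, 3 * hidden_size, bias=False)   # weight shape (3 * hidden, hidden)

x = torch.randn(1, 8, hidden_size)   # (batch, seq, hidden)
qkv = w_pack(x)                      # (1, 8, 3 * hidden_size)
q, k, v = qkv.chunk(3, dim=-1)       # each (1, 8, hidden_size); n_kv_heads == n_heads, so no GQA handling needed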
awq/models/base.py

@@ -55,6 +55,7 @@ TRANSFORMERS_AUTO_MAPPING_DICT = {
     "aquila": "AutoModelForCausalLM",
     "Yi": "AutoModelForCausalLM",
     "qwen": "AutoModelForCausalLM",
+    "baichuan": "AutoModelForCausalLM",
     "llava": "AutoModelForVision2Seq",
 }

@@ -90,6 +91,7 @@ class BaseAWQForCausalLM(nn.Module):
             self.quant_config.version, calib_data, split, text_column, duo_scaling, modules_to_not_convert=modules_to_not_convert
         )
         quantizer.quantize()

         self.is_quantized = True

     @staticmethod
awq/modules/fused/block.py

@@ -43,7 +43,7 @@ class LlamaLikeBlock(nn.Module):
     """
     def __init__(
         self, hidden_size, n_heads, n_kv_heads, qkv_layer, o_proj,
-        mlp, norm_1, norm_2, dev, max_seq_len, rope_theta
+        mlp, norm_1, norm_2, dev, max_seq_len, rope_theta, use_alibi=False
     ):
         super().__init__()
         self.n_heads = n_heads

@@ -52,7 +52,7 @@ class LlamaLikeBlock(nn.Module):
         self.norm_1 = norm_1.to(dev)
         self.attn = QuantAttentionFused(
             self.hidden_size, self.n_heads, self.n_kv_heads, qkv_layer, o_proj,
-            dev=dev, max_seq_len=max_seq_len, use_alibi=False, rope_theta=rope_theta
+            dev=dev, max_seq_len=max_seq_len, use_alibi=use_alibi, rope_theta=rope_theta
         ).to(dev)
         self.norm_2 = norm_2.to(dev)
         self.mlp = mlp.to(dev)

@@ -185,4 +185,4 @@ class FalconDecoderLayer(nn.Module):
         out = h_attn + h_mlp

-        return out, None, past_key_value
+        return out, None, past_key_value
\ No newline at end of file
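The new use_alibi flag is what lets BaichuanFuser request ALiBi position handling instead of rotary embeddings when it builds each LlamaLikeBlock. As a rough, self-contained illustration of the ALiBi idea (not the implementation inside QuantAttentionFused): each head gets a fixed slope, and attention scores receive a linear penalty that grows with key distance.

import torch

def alibi_slopes(n_heads: int) -> torch.Tensor:
    # standard geometric slopes for a power-of-two number of heads
    start = 2 ** (-8 / n_heads)
    return torch.tensor([start ** (i + 1) for i in range(n_heads)])

n_heads, seq_len = 8, 16
dist = torch.arange(seq_len)[None, :] - torch.arange(seq_len)[:, None]  # entry (i, j) = j - i
bias = alibi_slopes(n_heads)[:, None, None] * dist.clamp(max=0)         # (heads, seq, seq), added to attention scores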
awq/utils/calib_data.py

@@ -3,7 +3,7 @@ import logging
 from typing import List, Union
 from datasets import load_dataset

-def get_calib_dataset(data: Union[str, List[str]] = "pileval",
+def get_calib_dataset(data: Union[str, List[str], List[List[int]]] = "pileval",
                       tokenizer=None, n_samples=512, block_size=512,
                       split="train", text_column="text"):
     if isinstance(data, str):

@@ -15,18 +15,30 @@ def get_calib_dataset(data: Union[str, List[str]] = "pileval",
         dataset = dataset.shuffle(seed=42)
     elif isinstance(data, list):
-        dataset = [{text_column: text} for text in data]
+        if isinstance(data[0], str):
+            dataset = [{text_column: text} for text in data]
+        elif isinstance(data[0][0], int):
+            dataset = data
+        else:
+            raise NotImplementedError(
+                "Either pass a string to a huggingface dataset or a list"
+                "that is preprocessed with one sample of text per element"
+                " or a list of list of int for tokenized words."
+            )
     else:
         raise NotImplementedError(
             "Either pass a string to a huggingface dataset or a list"
-            "that is preprocessed with one sample of text per element."
-        )
+            "that is preprocessed with one sample of text per element"
+            " or a list of list of int for tokenized words."
+        )

     samples = []
     n_run = 0
     for data in dataset:
-        line = data[text_column]
-        line = line.strip()
-        line_encoded = tokenizer.encode(line)
+        if isinstance(data, list):
+            line_encoded = data
+        else:
+            line = data[text_column]
+            line = line.strip()
+            line_encoded = tokenizer.encode(line)
         if len(line_encoded) > 512:
             continue
         sample = torch.tensor([line_encoded])
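A caller-side sketch of what the widened signature enables (hypothetical usage, not part of the diff): calibration data may now be passed either as raw strings or as already tokenized samples, i.e. a list of lists of token ids.

from transformers import AutoTokenizer
from awq.utils.calib_data import get_calib_dataset

tokenizer = AutoTokenizer.from_pretrained("baichuan-inc/Baichuan2-7B-Chat", trust_remote_code=True)

texts = ["AWQ calibrates on a few hundred short samples.",
         "Baichuan2 packs Q, K and V into a single W_pack projection."]
token_ids = [tokenizer.encode(t) for t in texts]   # List[List[int]]

blocks_from_text = get_calib_dataset(data=texts, tokenizer=tokenizer)         # previous behaviour, still supported
blocks_from_tokens = get_calib_dataset(data=token_ids, tokenizer=tokenizer)   # newly supported pre-tokenized path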
examples/benchmark.py

@@ -156,6 +156,7 @@ def main(args):
         { "context": 512, "n_generate": 512 },
         { "context": 1024, "n_generate": 1024 },
         { "context": 2048, "n_generate": 2048 },
+        { "context": 4096, "n_generate": 4096 },
     ]

     if args.generator == "torch":