OpenDAS / AutoAWQ · Commits

Commit e9f62694 (unverified)
Authored Apr 06, 2024 by Isotr0py, committed via GitHub on Apr 06, 2024
Parent: 33dfb048

Add StableLM support (#410)

Co-authored-by: Casper <casperbh.96@gmail.com>

6 changed files with 164 additions and 9 deletions (+164, -9):
awq/models/__init__.py       +2   -1
awq/models/auto.py           +1   -0
awq/models/base.py           +1   -0
awq/models/stablelm.py       +136 -0
awq/modules/fused/attn.py    +22  -8
awq/modules/fused/block.py   +2   -0
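With the mapping entries added in this commit, a StableLM checkpoint can go through the usual AutoAWQ quantization flow. A minimal usage sketch, assuming the standard AutoAWQ API; the checkpoint and output paths below are illustrative, not part of the commit:

# Hedged sketch: quantizing a StableLM model once "stablelm" is in the model map.
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "stabilityai/stablelm-2-1_6b"   # illustrative StableLM checkpoint
quant_path = "stablelm-2-1_6b-awq"           # illustrative output directory
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

model.quantize(tokenizer, quant_config=quant_config)  # AWQ calibration + 4-bit quantization
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)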
awq/models/__init__.py

@@ -15,4 +15,5 @@ from .llava import LlavaAWQForCausalLM
 from .mixtral import MixtralAWQForCausalLM
 from .qwen2 import Qwen2AWQForCausalLM
 from .gemma import GemmaAWQForCausalLM
+from .stablelm import StableLmAWQForCausalLM
 from .starcoder2 import Starcoder2AWQForCausalLM
awq/models/auto.py

@@ -24,6 +24,7 @@ AWQ_CAUSAL_LM_MODEL_MAP = {
     "llava": LlavaAWQForCausalLM,
     "qwen2": Qwen2AWQForCausalLM,
     "gemma": GemmaAWQForCausalLM,
+    "stablelm": StableLmAWQForCausalLM,
     "starcoder2": Starcoder2AWQForCausalLM,
 }
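The keys of AWQ_CAUSAL_LM_MODEL_MAP follow the Hugging Face model_type string, so the new entry routes any config with model_type == "stablelm" to StableLmAWQForCausalLM. A rough sketch of that dispatch pattern; the helper below is illustrative, not the actual code in auto.py:

# Illustrative dispatch helper, not the real auto.py implementation.
from transformers import AutoConfig

def resolve_awq_class(model_path: str):
    config = AutoConfig.from_pretrained(model_path, trust_remote_code=True)
    model_type = config.model_type  # "stablelm" for StableLM checkpoints
    if model_type not in AWQ_CAUSAL_LM_MODEL_MAP:
        raise TypeError(f"{model_type} isn't supported yet")
    return AWQ_CAUSAL_LM_MODEL_MAP[model_type]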
awq/models/base.py

@@ -68,6 +68,7 @@ TRANSFORMERS_AUTO_MAPPING_DICT = {
     "llava": "AutoModelForVision2Seq",
     "qwen2": "AutoModelForCausalLM",
     "gemma": "AutoModelForCausalLM",
+    "stablelm": "AutoModelForCausalLM",
     "starcoder2": "AutoModelForCausalLM",
 }
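TRANSFORMERS_AUTO_MAPPING_DICT names the transformers auto class used to load the unquantized weights, so StableLM loads through AutoModelForCausalLM like most decoder-only models. A hedged sketch of resolving that string to a class object; the helper name is made up for illustration:

# Illustrative helper: resolve the auto-class name from the mapping to a class object.
import transformers

def get_hf_auto_class(model_type: str):
    auto_class_name = TRANSFORMERS_AUTO_MAPPING_DICT[model_type]  # e.g. "AutoModelForCausalLM"
    return getattr(transformers, auto_class_name)

# get_hf_auto_class("stablelm") returns transformers.AutoModelForCausalLM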
awq/models/stablelm.py (new file, 0 → 100644)

import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.stablelm import StableLmForCausalLM as OldStableLmForCausalLM
from transformers.models.stablelm.modeling_stablelm import (
    StableLmDecoderLayer as OldStableLmDecoderLayer,
)
from awq.modules.fused.norm import FasterTransformerRMSNorm


class StableLmAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "StableLmDecoderLayer"
    max_seq_len_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(model: OldStableLmForCausalLM):
        fuser = StableLmFuser(model)
        fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model: OldStableLmForCausalLM):
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module: OldStableLmForCausalLM):
        return dict(is_scalable=False)

    @staticmethod
    def move_embed(model: OldStableLmForCausalLM, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)

    @staticmethod
    def get_layers_for_scaling(module: OldStableLmDecoderLayer, input_feat, module_kwargs):
        layers = []

        # attention input
        layers.append(
            dict(
                prev_op=module.input_layernorm,
                layers=[
                    module.self_attn.q_proj,
                    module.self_attn.k_proj,
                    module.self_attn.v_proj,
                ],
                inp=input_feat["self_attn.q_proj"],
                module2inspect=module.self_attn,
                kwargs=module_kwargs,
            )
        )

        # attention out
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            layers.append(
                dict(
                    prev_op=module.self_attn.v_proj,
                    layers=[module.self_attn.o_proj],
                    inp=input_feat["self_attn.o_proj"],
                )
            )

        # linear 1
        layers.append(
            dict(
                prev_op=module.post_attention_layernorm,
                layers=[module.mlp.gate_proj, module.mlp.up_proj],
                inp=input_feat["mlp.gate_proj"],
                module2inspect=module.mlp,
            )
        )

        # linear 2
        layers.append(
            dict(
                prev_op=module.mlp.up_proj,
                layers=[module.mlp.down_proj],
                inp=input_feat["mlp.down_proj"],
            )
        )

        return layers


class StableLmFuser:
    def __init__(self, model: OldStableLmForCausalLM):
        self.model = model

        self.stablelm_blocks: List[Tuple[str, OldStableLmDecoderLayer]] = [
            (name, module)
            for name, module in self.model.named_modules()
            if "StableLmDecoderLayer".lower() in module.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
        blocks = []

        module: OldStableLmDecoderLayer
        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
            device = next(iter(module.state_dict().values())).device
            qkv = fuse_qkv(
                module,
                module.self_attn.q_proj,
                module.self_attn.k_proj,
                module.self_attn.v_proj,
            )
            norm_1 = module.input_layernorm
            norm_2 = module.post_attention_layernorm
            blocks.append(
                LlamaLikeBlock(
                    hidden_size=self.model.config.hidden_size,
                    n_heads=self.model.config.num_attention_heads,
                    n_kv_heads=self.model.config.num_key_value_heads,
                    qkv_layer=qkv,
                    o_proj=module.self_attn.o_proj,
                    mlp=module.mlp,
                    norm_1=norm_1,
                    norm_2=norm_2,
                    dev=device,
                    max_seq_len=self.model.config.max_seq_len,
                    rope_theta=self.model.config.rope_theta,
                    partial_rotary_factor=self.model.config.partial_rotary_factor,
                )
            )

        self.model.model = LlamaLikeModel(
            self.model.config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            self.model.model.norm,
        )
        setattr(self.model.model, "blocks", self.model.model.blocks)
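The fuser above runs when a quantized StableLM checkpoint is loaded with layer fusion enabled, which invokes StableLmAWQForCausalLM.fuse_layers() and hence StableLmFuser.fuse_transformer(). A minimal inference sketch, assuming a previously quantized model and a CUDA device; the path below is illustrative:

# Hedged sketch: loading a quantized StableLM with fused layers for inference.
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

quant_path = "stablelm-2-1_6b-awq"  # illustrative path to a quantized model
model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
tokenizer = AutoTokenizer.from_pretrained(quant_path)

input_ids = tokenizer("Stability AI released", return_tensors="pt").input_ids.cuda()
output = model.generate(input_ids, max_new_tokens=32)
print(tokenizer.decode(output[0], skip_special_tokens=True))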
awq/modules/fused/attn.py

@@ -29,9 +29,7 @@ class RoPE(nn.Module):
         super(RoPE, self).__init__()
         self.freqs_cis = nn.Parameter(
             self.precompute_freqs_cis(
                 head_dim, max_seq_len * 2, rope_theta).to(device),
             requires_grad=False,
         )
@@ -118,6 +116,7 @@ class QuantAttentionFused(nn.Module):
         use_alibi=False,
         attention_shapes=None,
         rope_theta=10000,
+        partial_rotary_factor=1.0,
         head_dim=None,
         **kwargs
     ):
@@ -167,8 +166,9 @@ class QuantAttentionFused(nn.Module):
             self.is_neox = False
         else:
             self.alibi = None
-            self.rope = RoPE(self.head_dim, max_seq_len, dev, rope_theta)
-            self.rotary_dim = self.head_dim
+            self.partial_rotary_factor = partial_rotary_factor
+            self.rotary_dim = int(self.head_dim * self.partial_rotary_factor)
+            self.rope = RoPE(self.rotary_dim, max_seq_len, dev, rope_theta)
             self.is_neox = True

     def forward(
@@ -209,12 +209,26 @@ class QuantAttentionFused(nn.Module):
         xk = self.attention_shapes["xk_slice"](xqkv)
         xv = self.attention_shapes["xv_slice"](xqkv)

-        if seqlen > 1 or not FT_INSTALLED:
+        if seqlen > 1 or self.partial_rotary_factor < 1 or not FT_INSTALLED:
             xq = xq.view((bsz, seqlen) + self.attention_shapes["xq_view"])
             xk = xk.view((bsz, seqlen) + self.attention_shapes["xk_view"])
             xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])

             if not self.use_alibi:
-                xq, xk = self.rope.forward(xq, xk, self.start_pos, seqlen)
+                # Partial rotary embedding
+                if self.partial_rotary_factor < 1:
+                    xq_rot, xq_pass = (
+                        xq[..., : self.rotary_dim],
+                        xq[..., self.rotary_dim :],
+                    )
+                    xk_rot, xk_pass = (
+                        xk[..., : self.rotary_dim],
+                        xk[..., self.rotary_dim :],
+                    )
+                    xq_rot, xk_rot = self.rope.forward(xq_rot, xk_rot, self.start_pos, seqlen)
+                    xq = torch.cat((xq_rot, xq_pass), dim=-1)
+                    xk = torch.cat((xk_rot, xk_pass), dim=-1)
+                else:
+                    xq, xk = self.rope.forward(xq, xk, self.start_pos, seqlen)

             values_store = xv.transpose(2, 1)
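The new branch applies rotary position embeddings to only the first rotary_dim channels of each head (StableLM configs typically set partial_rotary_factor to 0.25) and passes the remaining channels through untouched. A standalone sketch of that split-rotate-concatenate pattern, written against plain torch with the Hugging Face rotate_half convention rather than the fused RoPE module used here; shapes and names are assumptions:

# Standalone sketch of partial rotary embeddings (not the fused RoPE implementation).
import torch

def rotate_half(x):
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)

def apply_partial_rope(q, k, cos, sin, rotary_dim):
    # q, k: (batch, seqlen, n_heads, head_dim); cos, sin: (seqlen, rotary_dim)
    q_rot, q_pass = q[..., :rotary_dim], q[..., rotary_dim:]
    k_rot, k_pass = k[..., :rotary_dim], k[..., rotary_dim:]
    cos = cos[None, :, None, :]  # broadcast over batch and heads
    sin = sin[None, :, None, :]
    q_rot = q_rot * cos + rotate_half(q_rot) * sin
    k_rot = k_rot * cos + rotate_half(k_rot) * sin
    # Rotated and pass-through channels are concatenated back together.
    return torch.cat((q_rot, q_pass), dim=-1), torch.cat((k_rot, k_pass), dim=-1)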
awq/modules/fused/block.py

@@ -79,6 +79,7 @@ class LlamaLikeBlock(nn.Module):
         dev,
         max_seq_len,
         rope_theta=10000,
+        partial_rotary_factor=1.0,
         use_alibi=False,
         head_dim=None,
     ):
@@ -103,6 +104,7 @@ class LlamaLikeBlock(nn.Module):
             max_seq_len=max_seq_len,
             use_alibi=use_alibi,
             rope_theta=rope_theta,
+            partial_rotary_factor=partial_rotary_factor,
             head_dim=head_dim,
         ).to(dev)
         self.norm_2 = norm_2.to(dev)