OpenDAS / AutoAWQ · Commits

Commit e9f62694 (Unverified)
Authored Apr 06, 2024 by Isotr0py, committed by GitHub on Apr 06, 2024

Add StableLM support (#410)

Co-authored-by: Casper <casperbh.96@gmail.com>

Parent: 33dfb048
Showing 6 changed files with 164 additions and 9 deletions (+164 -9)
awq/models/__init__.py      +2    -1
awq/models/auto.py          +1    -0
awq/models/base.py          +1    -0
awq/models/stablelm.py      +136  -0
awq/modules/fused/attn.py   +22   -8
awq/modules/fused/block.py  +2    -0
awq/models/__init__.py

@@ -15,4 +15,5 @@ from .llava import LlavaAWQForCausalLM
 from .mixtral import MixtralAWQForCausalLM
 from .qwen2 import Qwen2AWQForCausalLM
 from .gemma import GemmaAWQForCausalLM
-from .starcoder2 import Starcoder2AWQForCausalLM
\ No newline at end of file
+from .stablelm import StableLmAWQForCausalLM
+from .starcoder2 import Starcoder2AWQForCausalLM
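With the re-export in place, the new class is importable from the models package, e.g. (a trivial usage check, not a line from this diff):

    from awq.models import StableLmAWQForCausalLM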
awq/models/auto.py

@@ -24,6 +24,7 @@ AWQ_CAUSAL_LM_MODEL_MAP = {
     "llava": LlavaAWQForCausalLM,
     "qwen2": Qwen2AWQForCausalLM,
     "gemma": GemmaAWQForCausalLM,
+    "stablelm": StableLmAWQForCausalLM,
     "starcoder2": Starcoder2AWQForCausalLM,
 }
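This map is keyed by the Hugging Face config's model_type, so StableLM checkpoints get dispatched to the new AWQ class automatically. A minimal sketch of that lookup (the checkpoint name is only an example, not part of this commit):

    from transformers import AutoConfig
    from awq.models.auto import AWQ_CAUSAL_LM_MODEL_MAP

    config = AutoConfig.from_pretrained("stabilityai/stablelm-2-zephyr-1_6b")  # example checkpoint
    assert config.model_type == "stablelm"
    awq_cls = AWQ_CAUSAL_LM_MODEL_MAP[config.model_type]  # StableLmAWQForCausalLM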
awq/models/base.py

@@ -68,6 +68,7 @@ TRANSFORMERS_AUTO_MAPPING_DICT = {
     "llava": "AutoModelForVision2Seq",
     "qwen2": "AutoModelForCausalLM",
     "gemma": "AutoModelForCausalLM",
+    "stablelm": "AutoModelForCausalLM",
     "starcoder2": "AutoModelForCausalLM",
 }
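The mapped string names a transformers auto class; base.py resolves it roughly as in the sketch below when loading the unquantized model (paraphrased, not a line from this diff; the checkpoint name is illustrative):

    import transformers
    from awq.models.base import TRANSFORMERS_AUTO_MAPPING_DICT

    target_cls_name = TRANSFORMERS_AUTO_MAPPING_DICT["stablelm"]   # "AutoModelForCausalLM"
    target_cls = getattr(transformers, target_cls_name)
    model = target_cls.from_pretrained("stabilityai/stablelm-2-zephyr-1_6b", torch_dtype="auto")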
awq/models/stablelm.py (new file, 0 → 100644)

import tqdm
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.stablelm import StableLmForCausalLM as OldStableLmForCausalLM
from transformers.models.stablelm.modeling_stablelm import (
    StableLmDecoderLayer as OldStableLmDecoderLayer,
)
from awq.modules.fused.norm import FasterTransformerRMSNorm


class StableLmAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "StableLmDecoderLayer"
    max_seq_len_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(model: OldStableLmForCausalLM):
        fuser = StableLmFuser(model)
        fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model: OldStableLmForCausalLM):
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module: OldStableLmForCausalLM):
        return dict(is_scalable=False)

    @staticmethod
    def move_embed(model: OldStableLmForCausalLM, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)

    @staticmethod
    def get_layers_for_scaling(module: OldStableLmDecoderLayer, input_feat, module_kwargs):
        layers = []

        # attention input
        layers.append(
            dict(
                prev_op=module.input_layernorm,
                layers=[
                    module.self_attn.q_proj,
                    module.self_attn.k_proj,
                    module.self_attn.v_proj,
                ],
                inp=input_feat["self_attn.q_proj"],
                module2inspect=module.self_attn,
                kwargs=module_kwargs,
            )
        )

        # attention out
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            layers.append(
                dict(
                    prev_op=module.self_attn.v_proj,
                    layers=[module.self_attn.o_proj],
                    inp=input_feat["self_attn.o_proj"],
                )
            )

        # linear 1
        layers.append(
            dict(
                prev_op=module.post_attention_layernorm,
                layers=[module.mlp.gate_proj, module.mlp.up_proj],
                inp=input_feat["mlp.gate_proj"],
                module2inspect=module.mlp,
            )
        )

        # linear 2
        layers.append(
            dict(
                prev_op=module.mlp.up_proj,
                layers=[module.mlp.down_proj],
                inp=input_feat["mlp.down_proj"],
            )
        )

        return layers


class StableLmFuser:
    def __init__(self, model: OldStableLmForCausalLM):
        self.model = model

        self.stablelm_blocks: List[Tuple[str, OldStableLmDecoderLayer]] = [
            (name, module)
            for name, module in self.model.named_modules()
            if "StableLmDecoderLayer".lower() in module.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
        blocks = []

        module: OldStableLmDecoderLayer
        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
            device = next(iter(module.state_dict().values())).device
            qkv = fuse_qkv(
                module,
                module.self_attn.q_proj,
                module.self_attn.k_proj,
                module.self_attn.v_proj,
            )
            norm_1 = module.input_layernorm
            norm_2 = module.post_attention_layernorm
            blocks.append(
                LlamaLikeBlock(
                    hidden_size=self.model.config.hidden_size,
                    n_heads=self.model.config.num_attention_heads,
                    n_kv_heads=self.model.config.num_key_value_heads,
                    qkv_layer=qkv,
                    o_proj=module.self_attn.o_proj,
                    mlp=module.mlp,
                    norm_1=norm_1,
                    norm_2=norm_2,
                    dev=device,
                    max_seq_len=self.model.config.max_seq_len,
                    rope_theta=self.model.config.rope_theta,
                    partial_rotary_factor=self.model.config.partial_rotary_factor,
                )
            )

        self.model.model = LlamaLikeModel(
            self.model.config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            self.model.model.norm,
        )
        setattr(self.model.model, "blocks", self.model.model.blocks)
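With the class registered, a StableLM checkpoint goes through the usual AutoAWQ flow. A minimal sketch (checkpoint path, output path, and quant settings are illustrative defaults, not prescribed by this commit):

    from awq import AutoAWQForCausalLM
    from transformers import AutoTokenizer

    model_path = "stabilityai/stablelm-2-zephyr-1_6b"   # example StableLM checkpoint
    quant_path = "stablelm-2-zephyr-1_6b-awq"
    quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

    model = AutoAWQForCausalLM.from_pretrained(model_path)
    tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)

    model.quantize(tokenizer, quant_config=quant_config)  # AWQ calibration + weight packing
    model.save_quantized(quant_path)
    tokenizer.save_pretrained(quant_path)

The fusion path added here (StableLmFuser) is only exercised when the quantized model is later loaded for inference, typically via AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True).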
awq/modules/fused/attn.py

@@ -29,9 +29,7 @@ class RoPE(nn.Module):
         super(RoPE, self).__init__()

         self.freqs_cis = nn.Parameter(
             self.precompute_freqs_cis(head_dim, max_seq_len * 2, rope_theta).to(device),
             requires_grad=False,
         )

@@ -118,6 +116,7 @@ class QuantAttentionFused(nn.Module):
         use_alibi=False,
         attention_shapes=None,
         rope_theta=10000,
+        partial_rotary_factor=1.0,
         head_dim=None,
         **kwargs
     ):

@@ -127,7 +126,7 @@ class QuantAttentionFused(nn.Module):
         self.n_kv_heads = n_kv_heads
         self.n_kv_groups = n_heads // n_kv_heads if n_kv_heads != 0 else 0
         self.head_dim = head_dim

         if head_dim is None:
             self.head_dim = hidden_size // n_heads

@@ -167,8 +166,9 @@ class QuantAttentionFused(nn.Module):
             self.is_neox = False
         else:
             self.alibi = None
-            self.rope = RoPE(self.head_dim, max_seq_len, dev, rope_theta)
-            self.rotary_dim = self.head_dim
+            self.partial_rotary_factor = partial_rotary_factor
+            self.rotary_dim = int(self.head_dim * self.partial_rotary_factor)
+            self.rope = RoPE(self.rotary_dim, max_seq_len, dev, rope_theta)
             self.is_neox = True

     def forward(

@@ -209,13 +209,27 @@ class QuantAttentionFused(nn.Module):
         xk = self.attention_shapes["xk_slice"](xqkv)
         xv = self.attention_shapes["xv_slice"](xqkv)

-        if seqlen > 1 or not FT_INSTALLED:
+        if seqlen > 1 or self.partial_rotary_factor < 1 or not FT_INSTALLED:
             xq = xq.view((bsz, seqlen) + self.attention_shapes["xq_view"])
             xk = xk.view((bsz, seqlen) + self.attention_shapes["xk_view"])
             xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])

             if not self.use_alibi:
-                xq, xk = self.rope.forward(xq, xk, self.start_pos, seqlen)
+                # Partial rotary embedding
+                if self.partial_rotary_factor < 1:
+                    xq_rot, xq_pass = (
+                        xq[..., : self.rotary_dim],
+                        xq[..., self.rotary_dim :],
+                    )
+                    xk_rot, xk_pass = (
+                        xk[..., : self.rotary_dim],
+                        xk[..., self.rotary_dim :],
+                    )
+                    xq_rot, xk_rot = self.rope.forward(xq_rot, xk_rot, self.start_pos, seqlen)
+                    xq = torch.cat((xq_rot, xq_pass), dim=-1)
+                    xk = torch.cat((xk_rot, xk_pass), dim=-1)
+                else:
+                    xq, xk = self.rope.forward(xq, xk, self.start_pos, seqlen)

             values_store = xv.transpose(2, 1)
             keys_store = (
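The idea behind the partial rotary path: StableLM applies rotary position embeddings to only a fraction (partial_rotary_factor) of each head's dimensions and passes the remaining dimensions through unchanged. A tiny standalone illustration with made-up shapes (not code from this diff):

    import torch

    head_dim, partial_rotary_factor = 64, 0.25
    rotary_dim = int(head_dim * partial_rotary_factor)   # only 16 of 64 dims get RoPE

    x = torch.randn(1, 8, 4, head_dim)                   # (batch, seq, heads, head_dim)
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    # x_rot would be rotated by the RoPE module; x_pass is left untouched,
    # then the two halves are re-joined along the last dimension.
    x = torch.cat((x_rot, x_pass), dim=-1)
    assert x.shape[-1] == head_dim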
awq/modules/fused/block.py

@@ -79,6 +79,7 @@ class LlamaLikeBlock(nn.Module):
         dev,
         max_seq_len,
         rope_theta=10000,
+        partial_rotary_factor=1.0,
         use_alibi=False,
         head_dim=None,
     ):

@@ -103,6 +104,7 @@ class LlamaLikeBlock(nn.Module):
             max_seq_len=max_seq_len,
             use_alibi=use_alibi,
             rope_theta=rope_theta,
+            partial_rotary_factor=partial_rotary_factor,
             head_dim=head_dim,
         ).to(dev)
         self.norm_2 = norm_2.to(dev)