OpenDAS / AutoAWQ / Commits / 94e73f0b

Unverified commit 94e73f0b, authored Mar 11, 2024 by TechxGenus, committed by GitHub on Mar 11, 2024.

Add Gemma Support (#393)

Parent: d8ca1e2f
Showing 8 changed files with 182 additions and 8 deletions.
awq/models/__init__.py       +1    -0
awq/models/auto.py           +1    -0
awq/models/base.py           +1    -0
awq/models/gemma.py          +149  -0
awq/modules/fused/attn.py    +9    -4
awq/modules/fused/block.py   +8    -0
awq/modules/fused/model.py   +2    -2
awq/quantize/scale.py        +11   -2
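For orientation (not part of this diff): once the "gemma" entry is registered below, a Gemma checkpoint can be quantized through the usual AutoAWQ flow. The sketch assumes the standard AutoAWQ quantization API; the checkpoint path, output path, and quant_config values are illustrative examples, not anything mandated by this commit.

# Illustrative usage sketch, assuming the standard AutoAWQ quantization workflow.
from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "google/gemma-7b"      # example checkpoint
quant_path = "gemma-7b-awq"         # example output directory
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

# "gemma" in the checkpoint's model_type now resolves to GemmaAWQForCausalLM
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

model.quantize(tokenizer, quant_config=quant_config)
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)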
awq/models/__init__.py

@@ -14,3 +14,4 @@ from .baichuan import BaichuanAWQForCausalLM
 from .llava import LlavaAWQForCausalLM
 from .mixtral import MixtralAWQForCausalLM
 from .qwen2 import Qwen2AWQForCausalLM
+from .gemma import GemmaAWQForCausalLM
awq/models/auto.py

@@ -23,6 +23,7 @@ AWQ_CAUSAL_LM_MODEL_MAP = {
     "baichuan": BaichuanAWQForCausalLM,
     "llava": LlavaAWQForCausalLM,
     "qwen2": Qwen2AWQForCausalLM,
+    "gemma": GemmaAWQForCausalLM,
 }
awq/models/base.py

@@ -67,6 +67,7 @@ TRANSFORMERS_AUTO_MAPPING_DICT = {
     "baichuan": "AutoModelForCausalLM",
     "llava": "AutoModelForVision2Seq",
     "qwen2": "AutoModelForCausalLM",
+    "gemma": "AutoModelForCausalLM",
 }
awq/models/gemma.py (new file, mode 100644)

import tqdm
import torch
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.gemma.modeling_gemma import (
    GemmaDecoderLayer as OldGemmaDecoderLayer,
    GemmaForCausalLM as OldGemmaForCausalLM,
)
from awq.modules.fused.norm import FasterTransformerRMSNorm


class GemmaAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "GemmaDecoderLayer"
    max_new_tokens_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(model: OldGemmaDecoderLayer):
        fuser = GemmaFuser(model)
        fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model: OldGemmaForCausalLM):
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module: OldGemmaDecoderLayer):
        return dict(is_scalable=False)

    @staticmethod
    def move_embed(model: OldGemmaForCausalLM, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)

    @staticmethod
    def get_layers_for_scaling(module: OldGemmaDecoderLayer, input_feat, module_kwargs):
        layers = []

        # attention input
        layers.append(
            dict(
                prev_op=module.input_layernorm,
                layers=[
                    module.self_attn.q_proj,
                    module.self_attn.k_proj,
                    module.self_attn.v_proj,
                ],
                inp=input_feat["self_attn.q_proj"],
                module2inspect=module.self_attn,
                kwargs=module_kwargs,
            )
        )

        # attention out
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            layers.append(
                dict(
                    prev_op=module.self_attn.v_proj,
                    layers=[module.self_attn.o_proj],
                    inp=input_feat["self_attn.o_proj"],
                )
            )

        # linear 1
        layers.append(
            dict(
                prev_op=module.post_attention_layernorm,
                layers=[module.mlp.gate_proj, module.mlp.up_proj],
                inp=input_feat["mlp.gate_proj"],
                module2inspect=module.mlp,
            )
        )

        # linear 2
        layers.append(
            dict(
                prev_op=module.mlp.up_proj,
                layers=[module.mlp.down_proj],
                inp=input_feat["mlp.down_proj"],
            )
        )

        return layers


class GemmaFuser:
    def __init__(self, model: OldGemmaForCausalLM):
        self.model = model

        self.Gemma_blocks: List[Tuple[str, OldGemmaDecoderLayer]] = [
            (name, module)
            for name, module in self.model.named_modules()
            if "GemmaDecoderLayer".lower() in module.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
        blocks = []

        module: OldGemmaDecoderLayer
        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
            device = next(iter(module.state_dict().values())).device
            qkv = fuse_qkv(
                module,
                module.self_attn.q_proj,
                module.self_attn.k_proj,
                module.self_attn.v_proj,
            )
            with torch.no_grad():
                # GemmaRMSNorm is different from Llama's in that it multiplies
                # (1 + weight) to the output, instead of just weight.
                module.input_layernorm.weight += 1
                module.post_attention_layernorm.weight += 1
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight, module.input_layernorm.eps
            )
            norm_2 = FasterTransformerRMSNorm(
                module.post_attention_layernorm.weight,
                module.post_attention_layernorm.eps,
            )
            blocks.append(
                LlamaLikeBlock(
                    hidden_size=self.model.config.hidden_size,
                    n_heads=self.model.config.num_attention_heads,
                    n_kv_heads=self.model.config.num_key_value_heads,
                    qkv_layer=qkv,
                    o_proj=module.self_attn.o_proj,
                    mlp=module.mlp,
                    norm_1=norm_1,
                    norm_2=norm_2,
                    dev=device,
                    max_seq_len=self.model.config.max_seq_len,
                    rope_theta=self.model.config.rope_theta,
                    head_dim=self.model.config.head_dim,
                )
            )

        with torch.no_grad():
            # Normalize Gemma's embedding layer
            self.model.model.embed_tokens.weight *= self.model.config.hidden_size**0.5

        self.model.model = LlamaLikeModel(
            self.model.config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            self.model.model.norm,
        )
        setattr(self.model.model, "blocks", self.model.model.blocks)
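Two weight-folding steps in GemmaFuser above deserve a note. GemmaRMSNorm multiplies the normalized activations by (1 + weight), whereas the Llama-style FasterTransformerRMSNorm multiplies by weight alone, so the fuser adds 1 to each layernorm weight before handing it over. Gemma also scales its embeddings by sqrt(hidden_size) inside the forward pass, which LlamaLikeModel does not do, so that factor is folded into embed_tokens.weight. A minimal sketch of the first identity, using a toy rms_norm helper (illustrative, not the library code):

# Toy check of the (1 + weight) folding used by GemmaFuser above.
import torch

def rms_norm(x, eps=1e-6):
    # stand-in for the RMSNorm kernel; not the AutoAWQ implementation
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)

x = torch.randn(4, 8)
w = torch.randn(8) * 0.1

gemma_out = rms_norm(x) * (1 + w)     # GemmaRMSNorm convention
folded_out = rms_norm(x) * (w + 1)    # Llama-style norm fed the folded weight
assert torch.allclose(gemma_out, folded_out)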
awq/modules/fused/attn.py

@@ -25,12 +25,12 @@ if HF_NEW_CACHE_FORMAT:

 class RoPE(nn.Module):
-    def __init__(self, hidden_size, n_heads, max_seq_len, device, rope_theta):
+    def __init__(self, head_dim, max_seq_len, device, rope_theta):
         super(RoPE, self).__init__()

         self.freqs_cis = nn.Parameter(
             self.precompute_freqs_cis(
-                hidden_size // n_heads, max_seq_len * 2, rope_theta
+                head_dim, max_seq_len * 2, rope_theta
             ).to(device),
             requires_grad=False,
         )

@@ -118,6 +118,7 @@ class QuantAttentionFused(nn.Module):
         use_alibi=False,
         attention_shapes=None,
         rope_theta=10000,
+        head_dim=None,
         **kwargs
     ):
         super().__init__()

@@ -125,7 +126,11 @@ class QuantAttentionFused(nn.Module):
         self.n_heads = n_heads
         self.n_kv_heads = n_kv_heads
         self.n_kv_groups = n_heads // n_kv_heads if n_kv_heads != 0 else 0
-        self.head_dim = self.hidden_size // n_heads
+        self.head_dim = head_dim
+
+        if head_dim is None:
+            self.head_dim = hidden_size // n_heads
+
         self.qkv_proj = qkv_layer
         self.o_proj = o_proj
         self.start_pos = 0

@@ -162,7 +167,7 @@ class QuantAttentionFused(nn.Module):
             self.is_neox = False
         else:
             self.alibi = None
-            self.rope = RoPE(hidden_size, n_heads, max_seq_len, dev, rope_theta)
+            self.rope = RoPE(self.head_dim, max_seq_len, dev, rope_theta)
             self.rotary_dim = self.head_dim
             self.is_neox = True
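The reason head_dim is now threaded through explicitly: Gemma-7B's config sets head_dim to 256 while hidden_size // num_attention_heads is 3072 // 16 = 192, so deriving the head dimension from hidden_size would size the RoPE tables incorrectly. A small sketch of the mismatch and the fallback the new code uses (config numbers taken from the published google/gemma-7b config; treat them as illustrative):

# Illustrative check of the head_dim mismatch behind the attn.py change above.
hidden_size = 3072            # google/gemma-7b
num_attention_heads = 16
config_head_dim = 256

derived = hidden_size // num_attention_heads   # 192, what the old RoPE setup used
assert config_head_dim != derived

# The fallback mirrors the new QuantAttentionFused logic:
head_dim = config_head_dim if config_head_dim is not None else derived
print(head_dim)  # 256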
awq/modules/fused/block.py

@@ -80,10 +80,17 @@ class LlamaLikeBlock(nn.Module):
         max_seq_len,
         rope_theta=10000,
         use_alibi=False,
+        head_dim=None,
     ):
         super().__init__()
         self.n_heads = n_heads
         self.n_kv_heads = n_kv_heads
         self.head_dim = hidden_size // n_heads
+
+        # To support gemma-7b, its head_dim is separate
+        if head_dim:
+            self.head_dim = head_dim
+
         self.hidden_size = hidden_size
         self.norm_1 = norm_1.to(dev)
         self.attn = QuantAttentionFused(

@@ -96,6 +103,7 @@ class LlamaLikeBlock(nn.Module):
             max_seq_len=max_seq_len,
             use_alibi=use_alibi,
             rope_theta=rope_theta,
+            head_dim=head_dim,
         ).to(dev)
         self.norm_2 = norm_2.to(dev)
         self.mlp = mlp.to(dev)
awq/modules/fused/model.py

@@ -116,14 +116,14 @@ class LlamaLikeModel(nn.Module):
                 h,
                 mask,
             )
-            h, _, past_key_value = layer(
+            h, _, _ = layer(
                 h, None, attention_mask=mask, is_causal=is_causal
             )
         h = self.norm(h)

         return BaseModelOutputWithPast(
             last_hidden_state=h,
-            past_key_values=past_key_value,
+            past_key_values=None,
             hidden_states=(),
             attentions=(),
         )
awq/quantize/scale.py

@@ -6,9 +6,10 @@ from awq.modules.act import ScaledActivation
 from awq.utils.module import get_op_by_name, set_op_by_name
 from transformers.models.bloom.modeling_bloom import BloomGelu
 from transformers.models.llama.modeling_llama import LlamaRMSNorm
+from transformers.models.gemma.modeling_gemma import GemmaRMSNorm
 from transformers.activations import NewGELUActivation, PytorchGELUTanh, GELUActivation

-allowed_norms = [nn.LayerNorm, LlamaRMSNorm]
+allowed_norms = [nn.LayerNorm, LlamaRMSNorm, GemmaRMSNorm]
 allowed_act_fns = [
     nn.GELU,
     BloomGelu,

@@ -88,7 +89,15 @@ def scale_ln_fcs(ln: nn.Linear, fcs: List[nn.Linear], scales: torch.Tensor):
     scales = scales.to(ln.weight.device)

-    ln.weight.div_(scales)
+    # GemmaRMSNorm is different from Llama's in that it multiplies
+    # (1 + weight) to the output, instead of just weight.
+    if isinstance(ln, GemmaRMSNorm):
+        ln.weight += 1
+        ln.weight.div_(scales)
+        ln.weight -= 1
+    else:
+        ln.weight.div_(scales)
+
     if hasattr(ln, "bias") and ln.bias is not None:
         ln.bias.div_(scales)
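The Gemma branch in scale_ln_fcs keeps the layer's output consistent with the downstream scaling because it divides the effective multiplier (1 + weight), not the raw weight, by the scales. A quick illustrative check (not part of the commit):

# Illustrative check of the in-place (+= 1, div_, -= 1) sequence in scale_ln_fcs.
import torch

w = torch.randn(8) * 0.1
scales = torch.rand(8) + 0.5

w_new = (w + 1) / scales - 1
assert torch.allclose(1 + w_new, (1 + w) / scales)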