OpenDAS / AutoAWQ, commit 94e73f0b (unverified)

Add Gemma Support (#393)

Authored Mar 11, 2024 by TechxGenus; committed by GitHub on Mar 11, 2024.
Parent: d8ca1e2f

Showing 8 changed files with 182 additions and 8 deletions (+182, -8):
- awq/models/__init__.py (+1, -0)
- awq/models/auto.py (+1, -0)
- awq/models/base.py (+1, -0)
- awq/models/gemma.py (+149, -0)
- awq/modules/fused/attn.py (+9, -4)
- awq/modules/fused/block.py (+8, -0)
- awq/modules/fused/model.py (+2, -2)
- awq/quantize/scale.py (+11, -2)
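For context before the per-file diffs: with this support in place, Gemma checkpoints should go through the usual AutoAWQ quantization flow. A minimal sketch in the style of the project's README (the model id, output path, and quant settings below are illustrative placeholders, not part of this commit):

from awq import AutoAWQForCausalLM
from transformers import AutoTokenizer

model_path = "google/gemma-7b"      # illustrative checkpoint id
quant_path = "gemma-7b-awq"         # illustrative output directory
quant_config = {"zero_point": True, "q_group_size": 128, "w_bit": 4, "version": "GEMM"}

# Load the FP16 model and its tokenizer.
model = AutoAWQForCausalLM.from_pretrained(model_path)
tokenizer = AutoTokenizer.from_pretrained(model_path)

# Run AWQ calibration/quantization, then persist the quantized weights.
model.quantize(tokenizer, quant_config=quant_config)
model.save_quantized(quant_path)
tokenizer.save_pretrained(quant_path)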
awq/models/__init__.py

@@ -14,3 +14,4 @@ from .baichuan import BaichuanAWQForCausalLM
 from .llava import LlavaAWQForCausalLM
 from .mixtral import MixtralAWQForCausalLM
 from .qwen2 import Qwen2AWQForCausalLM
+from .gemma import GemmaAWQForCausalLM
awq/models/auto.py

@@ -23,6 +23,7 @@ AWQ_CAUSAL_LM_MODEL_MAP = {
     "baichuan": BaichuanAWQForCausalLM,
     "llava": LlavaAWQForCausalLM,
     "qwen2": Qwen2AWQForCausalLM,
+    "gemma": GemmaAWQForCausalLM,
 }
awq/models/base.py

@@ -67,6 +67,7 @@ TRANSFORMERS_AUTO_MAPPING_DICT = {
     "baichuan": "AutoModelForCausalLM",
     "llava": "AutoModelForVision2Seq",
     "qwen2": "AutoModelForCausalLM",
+    "gemma": "AutoModelForCausalLM",
 }
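The one-line registrations above (plus the import in __init__.py) are what make the new model type discoverable. A simplified sketch of the dispatch they enable (the map excerpts and resolve helper below paraphrase the idea and are not the exact library code):

# Paraphrased dispatch sketch (not the exact AutoAWQ code): the checkpoint's
# config.model_type keys into both maps to select the AWQ wrapper class and the
# transformers auto class used to load the underlying weights.
AWQ_CAUSAL_LM_MODEL_MAP = {"qwen2": "Qwen2AWQForCausalLM", "gemma": "GemmaAWQForCausalLM"}           # excerpt
TRANSFORMERS_AUTO_MAPPING_DICT = {"qwen2": "AutoModelForCausalLM", "gemma": "AutoModelForCausalLM"}  # excerpt

def resolve(model_type: str):
    return AWQ_CAUSAL_LM_MODEL_MAP[model_type], TRANSFORMERS_AUTO_MAPPING_DICT[model_type]

print(resolve("gemma"))  # ('GemmaAWQForCausalLM', 'AutoModelForCausalLM')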
awq/models/gemma.py (new file, 0 → 100644)

import tqdm
import torch
from typing import List, Tuple
from .base import BaseAWQForCausalLM
from awq.utils.fused_utils import fuse_qkv
from awq.modules.fused.block import LlamaLikeBlock
from awq.modules.fused.model import LlamaLikeModel
from transformers.models.gemma.modeling_gemma import (
    GemmaDecoderLayer as OldGemmaDecoderLayer,
    GemmaForCausalLM as OldGemmaForCausalLM,
)
from awq.modules.fused.norm import FasterTransformerRMSNorm


class GemmaAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "GemmaDecoderLayer"
    max_new_tokens_key = "max_position_embeddings"

    @staticmethod
    def fuse_layers(model: OldGemmaDecoderLayer):
        fuser = GemmaFuser(model)
        fuser.fuse_transformer()

    @staticmethod
    def get_model_layers(model: OldGemmaForCausalLM):
        return model.model.layers

    @staticmethod
    def get_act_for_scaling(module: OldGemmaDecoderLayer):
        return dict(is_scalable=False)

    @staticmethod
    def move_embed(model: OldGemmaForCausalLM, device: str):
        model.model.embed_tokens = model.model.embed_tokens.to(device)

    @staticmethod
    def get_layers_for_scaling(module: OldGemmaDecoderLayer, input_feat, module_kwargs):
        layers = []

        # attention input
        layers.append(
            dict(
                prev_op=module.input_layernorm,
                layers=[
                    module.self_attn.q_proj,
                    module.self_attn.k_proj,
                    module.self_attn.v_proj,
                ],
                inp=input_feat["self_attn.q_proj"],
                module2inspect=module.self_attn,
                kwargs=module_kwargs,
            )
        )

        # attention out
        # Please refer to https://github.com/mit-han-lab/llm-awq/pull/67#issue-1850622696
        if module.self_attn.v_proj.weight.shape == module.self_attn.o_proj.weight.shape:
            layers.append(
                dict(
                    prev_op=module.self_attn.v_proj,
                    layers=[module.self_attn.o_proj],
                    inp=input_feat["self_attn.o_proj"],
                )
            )

        # linear 1
        layers.append(
            dict(
                prev_op=module.post_attention_layernorm,
                layers=[module.mlp.gate_proj, module.mlp.up_proj],
                inp=input_feat["mlp.gate_proj"],
                module2inspect=module.mlp,
            )
        )

        # linear 2
        layers.append(
            dict(
                prev_op=module.mlp.up_proj,
                layers=[module.mlp.down_proj],
                inp=input_feat["mlp.down_proj"],
            )
        )

        return layers


class GemmaFuser:
    def __init__(self, model: OldGemmaForCausalLM):
        self.model = model

        self.Gemma_blocks: List[Tuple[str, OldGemmaDecoderLayer]] = [
            (name, module)
            for name, module in self.model.named_modules()
            if "GemmaDecoderLayer".lower() in module.__class__.__name__.lower()
        ]

    def fuse_transformer(self):
        blocks = []

        module: OldGemmaDecoderLayer
        for module in tqdm.tqdm(self.model.model.layers, desc="Fusing layers..."):
            device = next(iter(module.state_dict().values())).device
            qkv = fuse_qkv(
                module,
                module.self_attn.q_proj,
                module.self_attn.k_proj,
                module.self_attn.v_proj,
            )
            with torch.no_grad():
                # GemmaRMSNorm is different from Llama's in that it multiplies
                # (1 + weight) to the output, instead of just weight.
                module.input_layernorm.weight += 1
                module.post_attention_layernorm.weight += 1
            norm_1 = FasterTransformerRMSNorm(
                module.input_layernorm.weight, module.input_layernorm.eps
            )
            norm_2 = FasterTransformerRMSNorm(
                module.post_attention_layernorm.weight,
                module.post_attention_layernorm.eps,
            )
            blocks.append(
                LlamaLikeBlock(
                    hidden_size=self.model.config.hidden_size,
                    n_heads=self.model.config.num_attention_heads,
                    n_kv_heads=self.model.config.num_key_value_heads,
                    qkv_layer=qkv,
                    o_proj=module.self_attn.o_proj,
                    mlp=module.mlp,
                    norm_1=norm_1,
                    norm_2=norm_2,
                    dev=device,
                    max_seq_len=self.model.config.max_seq_len,
                    rope_theta=self.model.config.rope_theta,
                    head_dim=self.model.config.head_dim,
                )
            )

        with torch.no_grad():
            # Normalize Gemma's embedding layer
            self.model.model.embed_tokens.weight *= self.model.config.hidden_size**0.5

        self.model.model = LlamaLikeModel(
            self.model.config.vocab_size,
            blocks,
            self.model.model.embed_tokens,
            self.model.model.norm,
        )
        setattr(self.model.model, "blocks", self.model.model.blocks)
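The two torch.no_grad() blocks above are the Gemma-specific parts of the fuser: GemmaRMSNorm scales the normalized activations by (1 + weight) rather than by weight alone, and Gemma multiplies its embeddings by hidden_size ** 0.5 inside the model, which LlamaLikeModel does not do, so that factor is baked into the embedding weights. A minimal numerical sketch of the norm-weight shift (the rms_norm helper below is illustrative, not library code):

import torch

# Illustrative RMSNorm: normalize, then scale by the given weight vector.
def rms_norm(x, weight, eps=1e-6):
    x_normed = x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps)
    return x_normed * weight

x = torch.randn(2, 8)
w = torch.randn(8)                # stored Gemma layernorm weight

gemma_out = rms_norm(x, 1 + w)    # what GemmaRMSNorm effectively computes
fused_out = rms_norm(x, w + 1)    # Llama-style kernel fed the shifted weight, as in the fuser
plain_out = rms_norm(x, w)        # forgetting the shift would change the output

assert torch.allclose(gemma_out, fused_out)
assert not torch.allclose(gemma_out, plain_out)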
awq/modules/fused/attn.py

@@ -25,12 +25,12 @@ if HF_NEW_CACHE_FORMAT:
 class RoPE(nn.Module):
-    def __init__(self, hidden_size, n_heads, max_seq_len, device, rope_theta):
+    def __init__(self, head_dim, max_seq_len, device, rope_theta):
         super(RoPE, self).__init__()

         self.freqs_cis = nn.Parameter(
             self.precompute_freqs_cis(
-                hidden_size // n_heads, max_seq_len * 2, rope_theta
+                head_dim, max_seq_len * 2, rope_theta
             ).to(device),
             requires_grad=False,
         )

@@ -118,6 +118,7 @@ class QuantAttentionFused(nn.Module):
         use_alibi=False,
         attention_shapes=None,
         rope_theta=10000,
+        head_dim=None,
         **kwargs
     ):
         super().__init__()

@@ -125,7 +126,11 @@ class QuantAttentionFused(nn.Module):
         self.n_heads = n_heads
         self.n_kv_heads = n_kv_heads
         self.n_kv_groups = n_heads // n_kv_heads if n_kv_heads != 0 else 0
-        self.head_dim = self.hidden_size // n_heads
+        self.head_dim = head_dim
+
+        if head_dim is None:
+            self.head_dim = hidden_size // n_heads
+
         self.qkv_proj = qkv_layer
         self.o_proj = o_proj
         self.start_pos = 0

@@ -162,7 +167,7 @@ class QuantAttentionFused(nn.Module):
             self.is_neox = False
         else:
             self.alibi = None
-            self.rope = RoPE(hidden_size, n_heads, max_seq_len, dev, rope_theta)
+            self.rope = RoPE(self.head_dim, max_seq_len, dev, rope_theta)
             self.rotary_dim = self.head_dim
             self.is_neox = True
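These signature changes exist because gemma-7b's head dimension is not hidden_size // n_heads, so deriving it would size the RoPE tables and attention reshapes incorrectly. A quick check against the gemma-7b config values (quoted from memory here, so treat the numbers as an assumption rather than part of the commit):

# Assumed gemma-7b config values: hidden_size=3072, num_attention_heads=16, head_dim=256.
hidden_size, n_heads, head_dim = 3072, 16, 256

derived = hidden_size // n_heads   # 192, which is what the old derivation would have used
assert derived == 192
assert derived != head_dim         # hence the explicit head_dim argument threaded through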
awq/modules/fused/block.py

@@ -80,10 +80,17 @@ class LlamaLikeBlock(nn.Module):
         max_seq_len,
         rope_theta=10000,
         use_alibi=False,
+        head_dim=None,
     ):
         super().__init__()
         self.n_heads = n_heads
         self.n_kv_heads = n_kv_heads
+        self.head_dim = hidden_size // n_heads
+
+        # To support gemma-7b, its head_dim is separate
+        if head_dim:
+            self.head_dim = head_dim
+
         self.hidden_size = hidden_size
         self.norm_1 = norm_1.to(dev)
         self.attn = QuantAttentionFused(

@@ -96,6 +103,7 @@ class LlamaLikeBlock(nn.Module):
             max_seq_len=max_seq_len,
             use_alibi=use_alibi,
             rope_theta=rope_theta,
+            head_dim=head_dim,
         ).to(dev)
         self.norm_2 = norm_2.to(dev)
         self.mlp = mlp.to(dev)
awq/modules/fused/model.py

@@ -116,14 +116,14 @@ class LlamaLikeModel(nn.Module):
                 h,
                 mask,
             )
-            h, _, past_key_value = layer(
+            h, _, _ = layer(
                 h, None, attention_mask=mask, is_causal=is_causal
             )
         h = self.norm(h)

         return BaseModelOutputWithPast(
             last_hidden_state=h,
-            past_key_values=past_key_value,
+            past_key_values=None,
             hidden_states=(),
             attentions=(),
         )
awq/quantize/scale.py

@@ -6,9 +6,10 @@ from awq.modules.act import ScaledActivation
 from awq.utils.module import get_op_by_name, set_op_by_name
 from transformers.models.bloom.modeling_bloom import BloomGelu
 from transformers.models.llama.modeling_llama import LlamaRMSNorm
+from transformers.models.gemma.modeling_gemma import GemmaRMSNorm
 from transformers.activations import NewGELUActivation, PytorchGELUTanh, GELUActivation

-allowed_norms = [nn.LayerNorm, LlamaRMSNorm]
+allowed_norms = [nn.LayerNorm, LlamaRMSNorm, GemmaRMSNorm]
 allowed_act_fns = [
     nn.GELU,
     BloomGelu,

@@ -88,7 +89,15 @@ def scale_ln_fcs(ln: nn.Linear, fcs: List[nn.Linear], scales: torch.Tensor):
     scales = scales.to(ln.weight.device)

-    ln.weight.div_(scales)
+    # GemmaRMSNorm is different from Llama's in that it multiplies
+    # (1 + weight) to the output, instead of just weight.
+    if isinstance(ln, GemmaRMSNorm):
+        ln.weight += 1
+        ln.weight.div_(scales)
+        ln.weight -= 1
+    else:
+        ln.weight.div_(scales)
+
     if hasattr(ln, "bias") and ln.bias is not None:
         ln.bias.div_(scales)
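The new branch preserves the AWQ invariant for Gemma's shifted parametrization: after the in-place sequence (add 1, divide by the scales, subtract 1), the norm's effective gain (1 + weight) has been divided by the scales, exactly as the plain division does for a Llama-style norm. A small standalone check of that algebra (illustrative only, not part of the commit):

import torch

w = torch.randn(16)          # stored GemmaRMSNorm weight
s = torch.rand(16) + 0.5     # AWQ scales, kept positive for the example

w_new = (w + 1) / s - 1      # what the += 1 / div_ / -= 1 sequence leaves behind
assert torch.allclose(1 + w_new, (1 + w) / s)   # effective gain divided by the scales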