OpenDAS / AutoAWQ · Commits

Commit 4517b3f2, authored Sep 12, 2023 by Casper Hansen

    Create Falcon block and model for fusing

parent e120c9b6
Showing 3 changed files with 118 additions and 32 deletions:

  awq/models/falcon.py         +42  -29
  awq/modules/fused/block.py   +47   -2
  awq/modules/fused/model.py   +29   -1
awq/models/falcon.py

 from .base import BaseAWQForCausalLM
-from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconForCausalLM, FalconAttention
+from transformers.models.falcon.modeling_falcon import FalconDecoderLayer as OldFalconDecoderLayer, FalconForCausalLM, FalconAttention

 class FalconAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "FalconDecoderLayer"
...
@@ -7,13 +7,14 @@ class FalconAWQForCausalLM(BaseAWQForCausalLM):
     @staticmethod
     def fuse_layers(model: FalconForCausalLM, quant_config: dict):
         fuser = FalconFuser(model)
         # fuser.fuse_transformer()

     @staticmethod
     def get_model_layers(model: FalconForCausalLM):
         return model.transformer.h

     @staticmethod
-    def get_act_for_scaling(module: FalconDecoderLayer):
+    def get_act_for_scaling(module: OldFalconDecoderLayer):
         return dict(
             is_scalable=True,
             scale_name="mlp.act",
...
@@ -26,7 +27,7 @@ class FalconAWQForCausalLM(BaseAWQForCausalLM):
         model.transformer.word_embeddings = model.transformer.word_embeddings.to(device)

     @staticmethod
-    def get_layers_for_scaling(module: FalconDecoderLayer, input_feat, module_kwargs):
+    def get_layers_for_scaling(module: OldFalconDecoderLayer, input_feat, module_kwargs):
         layers = []

         # Falcon 7B (older architecture)
...
@@ -62,34 +63,46 @@ class FalconAWQForCausalLM(BaseAWQForCausalLM):
         return layers

 import torch
 from torch.nn import LayerNorm
 from typing import List, Tuple
 from awq.utils.utils import set_module_name
 from awq.modules.fused.attn import QuantAttentionFused
 from awq.modules.fused.model import FalconModel
 from awq.modules.fused.block import FalconDecoderLayer

 class FalconFuser:
-    def __init__(self, model):
+    def __init__(self, model: FalconForCausalLM):
         self.model = model

-        self.attention_modules: List[Tuple[str, FalconAttention]] = [
-            (name, module) for name, module in self.model.named_modules()
-            if isinstance(module, FalconAttention)
-        ]
-
-        self.layernorm_modules: List[Tuple[str, LayerNorm]] = [
-            (name, module) for name, module in self.model.named_modules()
-            if isinstance(module, LayerNorm)
-        ]
-
-    def fuse_attention(self):
-        for name, qkv_layer in self.attention_modules:
-            attn = QuantAttentionFused(
-                qkv_layer.hidden_size,
-                qkv_layer.num_heads,
-                qkv_layer,
-                qkv_layer.dense,
-                next(iter(qkv_layer.state_dict().values())).device,
-                self.model.config.max_new_tokens
-            )
-            set_module_name(self.model, name, attn)
\ No newline at end of file
+    def fuse_transformer(self):
+        blocks = []
+
+        module: OldFalconDecoderLayer
+        for module in self.model.transformer.h:
+            if module.config.num_attention_heads == 71:
+                input_layernorm = module.input_layernorm
+                ln_attn = None
+                ln_mlp = None
+                new_decoder_arch = False
+            else:
+                input_layernorm = None
+                ln_attn = module.ln_attn
+                ln_mlp = module.ln_mlp
+                new_decoder_arch = True
+
+            blocks.append(FalconDecoderLayer(
+                hidden_size=module.config.hidden_size,
+                n_heads=module.config.num_attention_heads,
+                qkv_layer=module.self_attention.query_key_value,
+                o_proj=module.self_attention.dense,
+                mlp=module.mlp,
+                dev=next(iter(module.state_dict().values())).device,
+                max_seq_len=self.model.config.max_new_tokens,
+                input_layernorm=input_layernorm,
+                ln_attn=ln_attn,
+                ln_mlp=ln_mlp,
+                new_decoder_arch=new_decoder_arch
+            ))
+
+        self.model.transformer = FalconModel(
+            self.model.config.vocab_size,
+            blocks,
+            self.model.transformer.word_embeddings,
+            self.model.transformer.ln_f,
+        )
\ No newline at end of file
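Note: the fuser above tells the two Falcon layouts apart by head count: 71 attention heads means Falcon-7B's older decoder (a single input_layernorm feeding both branches), anything else is treated as the newer Falcon-40B-style decoder with separate ln_attn/ln_mlp. The following is a minimal usage sketch, not part of the commit; fuse_falcon and quantized_model are hypothetical names, and a FalconForCausalLM whose linear layers have already been replaced by AWQ quantized modules is assumed (loading is outside this diff). fuse_layers() above still leaves the fuse_transformer() call commented out, so fusing is wired up here but not yet enabled.

from transformers.models.falcon.modeling_falcon import FalconForCausalLM
from awq.models.falcon import FalconFuser

def fuse_falcon(quantized_model: FalconForCausalLM) -> FalconForCausalLM:
    # Build the fuser over an already-quantized Falcon model and swap
    # model.transformer for the fused FalconModel built from fused blocks.
    fuser = FalconFuser(quantized_model)
    fuser.fuse_transformer()
    return quantized_model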
awq/modules/fused/block.py

 import torch.nn as nn
 from awq.modules.fused.attn import QuantAttentionFused

 class MPTBlock(nn.Module):
     def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mpt_mlp, norm_1, norm_2, dev, max_seq_len):
         super().__init__()
...
@@ -28,3 +27,49 @@ class MPTBlock(nn.Module):
         h = hidden_states + attn_output
         out = h + self.ffn.forward(self.norm_2(h))
         return out, None, past_key_value
+
+class FalconDecoderLayer(nn.Module):
+    def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mlp, dev, max_seq_len,
+                 input_layernorm=None, ln_attn=None, ln_mlp=None, new_decoder_arch=True):
+        super().__init__()
+        self.n_heads = n_heads
+        self.hidden_size = hidden_size
+        # TODO: Falcon has ALiBi implemented but which model uses it?
+        self.attn = QuantAttentionFused(
+            hidden_size, self.n_heads, qkv_layer, o_proj,
+            dev=dev, max_seq_len=max_seq_len, use_alibi=False
+        ).to(dev)
+        self.new_decoder_arch = new_decoder_arch
+
+        if new_decoder_arch:
+            self.ln_attn = ln_attn  # before attention
+            self.ln_mlp = ln_mlp  # before mlp
+        else:
+            self.input_layernorm = input_layernorm  # before attention
+
+        self.mlp = mlp
+
+    def forward(self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None):
+        if self.new_decoder_arch:
+            layernorm_out = self.ln_attn(hidden_states)
+            mlp_layernorm_out = self.ln_mlp(hidden_states)
+        else:
+            layernorm_out = self.input_layernorm(hidden_states)
+
+        attn_output, _, past_key_value = self.attn.forward(
+            hidden_states=layernorm_out,
+            past_key_value=past_key_value,
+            attention_mask=attention_mask,
+            position_ids=None,
+            output_attentions=False,
+            use_cache=True
+        )
+
+        h_attn = hidden_states + attn_output
+
+        if self.new_decoder_arch:
+            h_mlp = self.mlp.forward(mlp_layernorm_out)
+        else:
+            h_mlp = self.mlp.forward(layernorm_out)
+
+        out = h_attn + h_mlp
+
+        return out, None, past_key_value
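The new FalconDecoderLayer implements Falcon's parallel residual: attention and MLP both read pre-norms of the same hidden_states (one shared norm in the 7B layout, two separate norms in the 40B layout) and both outputs are added onto the same residual. A self-contained sketch of that dataflow, assuming hypothetical ToyBlock with nn.Linear stand-ins for the fused attention and the Falcon MLP (not part of AutoAWQ):

import torch
import torch.nn as nn

class ToyBlock(nn.Module):
    def __init__(self, hidden_size: int, new_decoder_arch: bool = True):
        super().__init__()
        self.new_decoder_arch = new_decoder_arch
        self.attn = nn.Linear(hidden_size, hidden_size)   # stand-in for QuantAttentionFused
        self.mlp = nn.Linear(hidden_size, hidden_size)    # stand-in for the Falcon MLP
        self.ln_attn = nn.LayerNorm(hidden_size)
        self.ln_mlp = nn.LayerNorm(hidden_size)
        self.input_layernorm = nn.LayerNorm(hidden_size)

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        if self.new_decoder_arch:          # Falcon-40B style: separate pre-norms
            attn_in = self.ln_attn(hidden_states)
            mlp_in = self.ln_mlp(hidden_states)
        else:                              # Falcon-7B style: one shared pre-norm
            attn_in = self.input_layernorm(hidden_states)
            mlp_in = attn_in
        # both branches are parallel and add onto the same residual stream
        return hidden_states + self.attn(attn_in) + self.mlp(mlp_in)

x = torch.randn(1, 4, 8)
print(ToyBlock(8)(x).shape)   # torch.Size([1, 4, 8])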
awq/modules/fused/model.py

 import torch
 import torch.nn as nn
-from awq.modules.fused.block import MPTBlock
+from awq.modules.fused.block import MPTBlock, FalconDecoderLayer
 from transformers.modeling_outputs import BaseModelOutputWithPast

 class MPTModel(nn.Module):
...
@@ -30,3 +30,31 @@ class MPTModel(nn.Module):
         h = self.norm_f(h)
         return BaseModelOutputWithPast(last_hidden_state=h, past_key_values=past_key_value, hidden_states=(), attentions=())
+
+class FalconModel(nn.Module):
+    def __init__(self, vocab_size, blocks, word_embeddings, ln_f):
+        super().__init__()
+        self.vocab_size = vocab_size
+        self.word_embeddings = word_embeddings
+        self.blocks: list[FalconDecoderLayer] = nn.ModuleList(blocks)
+        self.ln_f = ln_f
+        self.attn_uses_sequence_id = False
+        self.prefix_lm = False
+
+    @torch.inference_mode()
+    def forward(self, input_ids, attn_bias=None, attention_mask=None, is_causal=None, *args, **kwargs):
+        _bsz, seqlen = input_ids.shape
+        h = self.word_embeddings(input_ids)
+
+        mask = None
+        if seqlen > 1:
+            mask = torch.full(
+                (1, 1, seqlen, seqlen), float("-inf"), device=input_ids.device
+            )
+            mask = torch.triu(mask, diagonal=self.blocks[0].attn.start_pos + 1).type_as(h)
+
+        for layer in self.blocks:
+            h, _, past_key_value = layer(h, None, attention_mask=mask, is_causal=is_causal)
+        h = self.ln_f(h)
+
+        return BaseModelOutputWithPast(last_hidden_state=h, past_key_values=past_key_value, hidden_states=(), attentions=())
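FalconModel.forward only builds a mask when more than one token is processed; single-token decode steps rely on the KV cache. The triu diagonal is offset by the fused attention's start_pos (the cache position counter). A standalone sketch of just that mask construction; build_decoder_mask is a hypothetical helper and start_pos is passed as a plain argument instead of being read from self.blocks[0].attn:

import torch
from typing import Optional

def build_decoder_mask(seqlen: int, start_pos: int, device: str = "cpu") -> Optional[torch.Tensor]:
    if seqlen <= 1:
        return None  # single-token steps need no mask
    mask = torch.full((1, 1, seqlen, seqlen), float("-inf"), device=device)
    # keep -inf only on entries more than `start_pos` columns to the right of the
    # diagonal, i.e. the positions a query token is not allowed to attend to
    return torch.triu(mask, diagonal=start_pos + 1)

print(build_decoder_mask(4, start_pos=0).squeeze())  # standard 4x4 causal mask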