OpenDAS / AutoAWQ
Commit 950851b3, authored Sep 11, 2023 by Casper Hansen
Fuse MPT block
Parent: d7badefc

2 changed files, with 46 additions and 27 deletions:

    awq/models/mpt.py            +17  -27
    awq/modules/fused/block.py   +29   -0
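This commit replaces AutoAWQ's per-module layernorm fusion for MPT with whole-block fusion: each transformers MptBlock in the quantized model is swapped for a single fused module that wires the fused quantized attention together with the block's existing MLP. Below is a minimal sketch of how the fused path is typically reached. It assumes a loader of this period that exposes a fuse_layers flag on from_quantized (exact loader arguments varied between AutoAWQ versions) and a hypothetical local path to an AWQ-quantized MPT checkpoint.

    from awq import AutoAWQForCausalLM
    from transformers import AutoTokenizer

    quant_path = "mpt-7b-awq"  # hypothetical path to an AWQ-quantized MPT model

    # fuse_layers=True is assumed to route into MptAWQForCausalLM.fuse_layers below,
    # which after this commit swaps every transformer block for the fused MptBlock.
    model = AutoAWQForCausalLM.from_quantized(quant_path, fuse_layers=True)
    tokenizer = AutoTokenizer.from_pretrained(quant_path)

    tokens = tokenizer("MPT uses ALiBi, so", return_tensors="pt").input_ids.cuda()
    print(tokenizer.decode(model.generate(tokens, max_new_tokens=32)[0]))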
awq/models/mpt.py  (+17 -27)

 from .base import BaseAWQForCausalLM
-from transformers.models.mpt.modeling_mpt import MptBlock, MptForCausalLM, MptMLP, MptAttention, LayerNorm
+from transformers.models.mpt.modeling_mpt import MptBlock as OldMptBlock, MptForCausalLM, MptAttention

 class MptAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "MPTBlock"

@@ -9,14 +9,14 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
     def fuse_layers(model: MptForCausalLM, quant_config: dict):
         fuser = MptFuser(model)
         fuser.fuse_attention()
-        fuser.fuse_layernorm()
+        fuser.fuse_block()

     @staticmethod
     def get_model_layers(model: MptForCausalLM):
         return model.transformer.blocks

     @staticmethod
-    def get_act_for_scaling(module: MptBlock):
+    def get_act_for_scaling(module: OldMptBlock):
         return dict(
             is_scalable=True,
             scale_name="ffn.act",

@@ -30,7 +30,7 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
         model.transformer.emb_drop = model.transformer.emb_drop.to(device)

     @staticmethod
-    def get_layers_for_scaling(module: MptBlock, input_feat, module_kwargs):
+    def get_layers_for_scaling(module: OldMptBlock, input_feat, module_kwargs):
         layers = []

         # attention input

@@ -66,11 +66,9 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
         return layers

-import torch
-import xformers
 from typing import List, Tuple
 from awq.utils.utils import set_module_name
-from xformers.triton.layer_norm import FusedLayerNorm
+from awq.modules.fused.block import MptBlock
 from awq.modules.fused.attn import QuantAttentionFused

 class MptFuser:

@@ -82,14 +80,9 @@ class MptFuser:
             if isinstance(module, MptAttention)
         ]

-        self.layernorm_modules: List[Tuple[str, LayerNorm]] = [
+        self.mpt_blocks: List[Tuple[str, OldMptBlock]] = [
             (name, module) for name, module in self.model.named_modules()
-            if isinstance(module, LayerNorm)
-        ]
-
-        self.mlp_modules: List[Tuple[str, MptMLP]] = [
-            (name, module) for name, module in self.model.named_modules()
-            if isinstance(module, MptMLP)
+            if 'mptblock' in module.__class__.__name__.lower()
         ]

     def fuse_attention(self):

@@ -105,17 +98,14 @@ class MptFuser:
             )

             set_module_name(self.model, name, attn)

-    def fuse_layernorm(self):
-        xformers.triton.k_layer_norm._triton_layernorm_fp16_enabled = True
-        for name, module in self.layernorm_modules:
-            norm = FusedLayerNorm(module.weight.shape, eps=module.eps).to(module.weight.device)
-
-            # copy weights and bias
-            with torch.no_grad():
-                norm.weight = module.weight
-                norm.bias = module.bias
-
-            set_module_name(self.model, name, norm)
-
-    def fuse_mlp(self):
-        pass
\ No newline at end of file
+    def fuse_block(self):
+        for name, module in self.mpt_blocks:
+            block = MptBlock(
+                self.model.config.d_model,
+                self.model.config.n_heads,
+                module.attn.Wqkv,
+                module.attn.out_proj,
+                module.ffn
+            )
+
+            set_module_name(self.model, name, block)
\ No newline at end of file
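For context, set_module_name (imported from awq.utils.utils above) is what installs each fused replacement under the original submodule's dotted name. The helper below is an illustrative stand-in for that mechanic, not AutoAWQ's implementation; it only makes explicit what fuse_attention and fuse_block rely on.

    import torch.nn as nn

    def replace_submodule(model: nn.Module, name: str, new_module: nn.Module):
        # Illustrative stand-in: resolve the parent of the dotted path, then swap the
        # child in place, e.g. name="transformer.blocks.0" replaces the first MPT block.
        parent_name, _, child_name = name.rpartition(".")
        parent = model.get_submodule(parent_name) if parent_name else model
        setattr(parent, child_name, new_module)

With that mechanic, the new fuse_block amounts to: for every (name, block) pair collected in __init__, build a fused MptBlock from the block's quantized attn.Wqkv and attn.out_proj plus its existing ffn, then install it under the same name, so model.transformer.blocks ends up holding only fused blocks.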
awq/modules/fused/block.py  (new file, +29 -0)

import torch.nn as nn
from awq.modules.fused.attn import QuantAttentionFused

class MptBlock(nn.Module):
    def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mpt_mlp):
        super().__init__()
        self.n_heads = n_heads
        self.hidden_size = hidden_size
        self.attn = QuantAttentionFused(
            hidden_size, self.n_heads, qkv_layer, o_proj,
            dev="cuda:0", max_seq_len=8096, use_alibi=True
        ).to("cuda:0")
        self.ffn = mpt_mlp.to("cuda:0")
        self.norm_1 = nn.LayerNorm(hidden_size, eps=1e-6).half().to("cuda:0")
        self.norm_2 = nn.LayerNorm(hidden_size, eps=1e-6).half().to("cuda:0")

    def forward(self, hidden_states, past_key_value, attn_bias, attention_mask, is_causal):
        norm_out = self.norm_1(hidden_states)
        attn_output, _, past_key_value = self.attn.forward(
            hidden_states=norm_out,
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            position_ids=None,
            output_attentions=False,
            use_cache=True
        )

        h = hidden_states + attn_output
        out = h + self.ffn.forward(self.norm_2(h))
        return out, None, past_key_value
\ No newline at end of file
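A hedged sketch of the new block's call contract follows. Assumptions: a CUDA device, an MPT model whose blocks have already been replaced by MptFuser.fuse_block, block being one such fused module taken from model.transformer.blocks, and 4096 standing in for the model's d_model.

    import torch

    hidden_states = torch.randn(1, 8, 4096, dtype=torch.float16, device="cuda:0")

    out, attn_weights, past_key_value = block(
        hidden_states,
        past_key_value=None,
        attn_bias=None,        # kept for signature compatibility; not referenced in the body above
        attention_mask=None,
        is_causal=True,        # likewise unused; ALiBi and masking are handled inside QuantAttentionFused
    )
    # out keeps the shape of hidden_states, attn_weights is always None, and
    # past_key_value is whatever cache object the fused attention returns.

Note that the fused block pins its submodules to "cuda:0" and enables ALiBi with max_seq_len=8096 at construction time, so device placement and context length are fixed by the module itself rather than by the caller.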