OpenDAS / AutoAWQ

Commit 1df0136e: Refactor MPT Quant MLP
Authored Sep 02, 2023 by Casper Hansen
Parent: ded3ea71
Showing 2 changed files with 34 additions and 33 deletions:

    awq/models/mpt.py        +33 -8
    awq/modules/fused_mlp.py +1 -25
awq/models/mpt.py
 from .base import BaseAWQForCausalLM
-from awq.modules import make_fused_mlp
+from transformers.models.mpt.modeling_mpt import MptBlock, MptForCausalLM, MptMLP

 class MptAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "MPTBlock"
     max_new_tokens_key = "max_seq_len"

     @staticmethod
-    def fuse_layers(awq_model):
-        make_fused_mlp(awq_model)
+    def fuse_layers(model: MptForCausalLM):
+        fuser = MptFuser(model)
+        fuser.fuse_mlp()

     @staticmethod
-    def get_model_layers(model):
+    def get_model_layers(model: MptForCausalLM):
         return model.transformer.blocks

     @staticmethod
-    def get_act_for_scaling(module):
+    def get_act_for_scaling(module: MptBlock):
         return dict(
             is_scalable=True,
             scale_name="ffn.act",
@@ -23,12 +24,12 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
         )

     @staticmethod
-    def move_embed(model, device):
+    def move_embed(model: MptForCausalLM, device: str):
         model.transformer.wte = model.transformer.wte.to(device)
         model.transformer.emb_drop = model.transformer.emb_drop.to(device)

     @staticmethod
-    def get_layers_for_scaling(module, input_feat, module_kwargs):
+    def get_layers_for_scaling(module: MptBlock, input_feat, module_kwargs):
         layers = []

         # attention input
@@ -62,4 +63,28 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
             inp=input_feat['ffn.down_proj'],
         ))

-        return layers
\ No newline at end of file
+        return layers
+
+from typing import List, Tuple
+from awq.utils.utils import set_module_name
+from awq.modules.fused_mlp import QuantMPTMLP
+
+class MptFuser:
+    def __init__(self, model):
+        self.model = model
+
+        self.mlp_modules: List[Tuple[str, MptMLP]] = [
+            (name, module) for name, module in self.model.named_modules()
+            if isinstance(module, MptMLP)
+        ]
+
+    def fuse_attention(self):
+        pass
+
+    def fuse_layernorm(self):
+        pass
+
+    def fuse_mlp(self):
+        for name, module in self.mlp_modules:
+            mlp = QuantMPTMLP(module.up_proj, module.act, module.down_proj)
+            set_module_name(self.model, name, mlp)
\ No newline at end of file
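The new MptFuser collects every MptMLP from named_modules() up front and swaps each one for a QuantMPTMLP via set_module_name. As a rough sketch of what that helper has to do (an assumption based on the call site above, not the actual body of awq.utils.utils.set_module_name):

import torch.nn as nn

def set_module_name(model: nn.Module, name: str, new_module: nn.Module):
    # Hypothetical sketch: resolve the dotted path (e.g. "transformer.blocks.0.ffn")
    # to its parent module, then replace the leaf attribute with the fused module.
    parent_path, _, child_name = name.rpartition(".")
    parent = model.get_submodule(parent_path) if parent_path else model
    setattr(parent, child_name, new_module)

Swapping modules in place like this leaves the surrounding MptBlock untouched, so only the FFN path changes.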
awq/modules/fused_mlp.py
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.cuda.amp import custom_bwd, custom_fwd
 from transformers.models.llama.modeling_llama import LlamaMLP
 import awq_inference_engine
 import torch.nn.functional as F

 class QuantMPTMLP(nn.Module):
     def __init__(
         ...
@@ -67,23 +63,3 @@ class QuantLlamaMLP(nn.Module):
         c = gate_output * up_output
         c = c.reshape(out_shape)
         return c

-def make_fused_mlp(m, parent_name=''):
-    if not hasattr(make_fused_mlp, "called"):
-        make_fused_mlp.called = True
-    """
-    Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuses many of the operations.
-    """
-    if "mptmlp" in str(m.__class__).lower():
-        return QuantMPTMLP(m.up_proj, m.act, m.down_proj)
-
-    for name, child in m.named_children():
-        child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}")
-
-        if isinstance(child, QuantLlamaMLP):
-            setattr(m, name, child)
-        elif isinstance(child, QuantMPTMLP):
-            setattr(m, name, child)
-    return m
\ No newline at end of file
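QuantMPTMLP's body is collapsed in this view, but its constructor arguments (up_proj, act, down_proj) mirror MPT's feed-forward block. A minimal unquantized stand-in showing the intended data flow (hypothetical code; the real module dispatches to awq_inference_engine kernels rather than plain layers):

import torch
import torch.nn as nn

class NaiveMPTMLP(nn.Module):
    # Hypothetical reference version of QuantMPTMLP: same constructor shape,
    # but plain modules instead of AWQ's fused quantized GEMMs.
    def __init__(self, up_proj: nn.Module, act: nn.Module, down_proj: nn.Module):
        super().__init__()
        self.up_proj = up_proj
        self.act = act
        self.down_proj = down_proj

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # MPT's FFN path: project up, apply the activation, project back down.
        return self.down_proj(self.act(self.up_proj(x)))

Deleting make_fused_mlp also drops its class-name string matching and recursive child traversal; the new MptFuser in awq/models/mpt.py instead locates the exact MptMLP instances with isinstance checks.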