OpenDAS / AutoAWQ
Commit f3695d60, authored Sep 09, 2023 by Casper Hansen
parent 5bd6fbc7

Fuse MPT
Showing 2 changed files with 39 additions and 44 deletions:

awq/models/mpt.py          +39  -8
awq/modules/fused/mlp.py    +0  -36
awq/models/mpt.py
from .base import BaseAWQForCausalLM
from transformers.models.mpt.modeling_mpt import MptBlock, MptForCausalLM, MptMLP, MptAttention, LayerNorm

class MptAWQForCausalLM(BaseAWQForCausalLM):
    layer_type = "MPTBlock"
...
@@ -8,7 +8,8 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
    @staticmethod
    def fuse_layers(model: MptForCausalLM, quant_config: dict):
        fuser = MptFuser(model)
        fuser.fuse_mlp()
        fuser.fuse_attention()
        fuser.fuse_layernorm()

    @staticmethod
    def get_model_layers(model: MptForCausalLM):
...
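fuse_layers is normally invoked for you when a quantized checkpoint is loaded with
fusing enabled. A minimal sketch, assuming AutoAWQ's usual from_quantized entry
point (the exact signature varies across versions) and a hypothetical local path:

from awq import AutoAWQForCausalLM

# path is hypothetical; fuse_layers=True is what routes through
# MptAWQForCausalLM.fuse_layers() and the MptFuser defined below
model = AutoAWQForCausalLM.from_quantized("mpt-7b-awq", fuse_layers=True)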
@@ -65,26 +66,56 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
        return layers

import torch
import xformers
from typing import List, Tuple
from awq.utils.utils import set_module_name
from awq.modules.fused.mlp import QuantMPTMLP
from xformers.triton.layer_norm import FusedLayerNorm
from awq.modules.fused.attn import QuantAttentionFused

class MptFuser:
    def __init__(self, model):
        self.model = model

        self.attention_modules: List[Tuple[str, MptAttention]] = [
            (name, module) for name, module in self.model.named_modules()
            if isinstance(module, MptAttention)
        ]

        self.layernorm_modules: List[Tuple[str, LayerNorm]] = [
            (name, module) for name, module in self.model.named_modules()
            if isinstance(module, LayerNorm)
        ]

        self.mlp_modules: List[Tuple[str, MptMLP]] = [
            (name, module) for name, module in self.model.named_modules()
            if isinstance(module, MptMLP)
        ]
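Each fuse_* method below swaps a stock module for its fused counterpart by its
dotted module name. set_module_name is AutoAWQ's helper for that rebinding; a
plausible sketch of such a helper (illustrative, not necessarily the library's
exact code):

def set_module_name(model, name, value):
    # walk to the parent of the dotted path, then rebind the child attribute
    if "." in name:
        parent_name, child_name = name.rsplit(".", 1)
        parent = model.get_submodule(parent_name)
    else:
        parent, child_name = model, name
    setattr(parent, child_name, value)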
    def fuse_attention(self):
        for name, qkv_layer in self.attention_modules:
            attn = QuantAttentionFused(
                qkv_layer.hidden_size,
                qkv_layer.n_heads,
                qkv_layer,
                qkv_layer.out_proj,
                next(iter(qkv_layer.state_dict().values())).device,
                self.model.config.max_new_tokens,
                use_alibi=True  # MPT uses ALiBi position biases, not rotary embeddings
            )
            set_module_name(self.model, name, attn)
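The device argument relies on a common PyTorch idiom: take the first tensor in
the module's state_dict and read its device. A standalone illustration:

import torch.nn as nn

layer = nn.Linear(8, 8)
# the first state_dict entry (layer.weight here) reveals where the module lives
device = next(iter(layer.state_dict().values())).device
print(device)  # cpu, or cuda:N once the module has been moved to a GPU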
    def fuse_layernorm(self):
        # enable the fp16 path of the xformers Triton layer norm kernel
        xformers.triton.k_layer_norm._triton_layernorm_fp16_enabled = True

        for name, module in self.layernorm_modules:
            norm = FusedLayerNorm(module.weight.shape, eps=module.eps).to(module.weight.device)

            # copy weights and bias
            with torch.no_grad():
                norm.weight = module.weight
                norm.bias = module.bias

            set_module_name(self.model, name, norm)
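Note that norm.weight = module.weight rebinds the fused module to the original
nn.Parameter objects rather than copying values, which is legal because both
sides are Parameters. If an independent copy were wanted instead, the usual
pattern would be:

with torch.no_grad():
    norm.weight.copy_(module.weight)
    norm.bias.copy_(module.bias)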
    def fuse_mlp(self):
        for name, module in self.mlp_modules:
            mlp = QuantMPTMLP(module.up_proj, module.act, module.down_proj)
            set_module_name(self.model, name, mlp)
\ No newline at end of file
awq/modules/fused/mlp.py
import torch
import torch.nn as nn
import awq_inference_engine
import torch.nn.functional as F
from awq.modules.linear import WQLinear_GEMM, WQLinear_GEMV

class QuantMPTMLP(nn.Module):
    def __init__(self, up_proj, act, down_proj):
        super().__init__()
        # expose up_proj's AWQ tensors as buffers on the fused module
        self.register_buffer('up_proj_qweight', up_proj.qweight)
        self.register_buffer('up_proj_scales', up_proj.scales)
        self.register_buffer('up_proj_qzeros', up_proj.qzeros)
        self.up_proj = up_proj
        self.act = act
        self.down_proj = down_proj

        if isinstance(down_proj, WQLinear_GEMV):
            # the GEMV kernel takes the real quantization group size
            self.linear = awq_inference_engine.gemv_forward_cuda
            self.group_size = down_proj.group_size
        else:
            # the GEMM kernel path passes a fixed 8 as its final argument
            self.linear = awq_inference_engine.gemm_forward_cuda
            self.group_size = 8

    def forward(self, x: torch.Tensor):
        x = x.reshape(-1, x.shape[-1])  # flatten tokens: (..., hidden) -> (tokens, hidden)
        x = self.linear(x, self.up_proj_qweight, self.up_proj_scales, self.up_proj_qzeros, self.group_size)
        return self.down_proj(self.act(x))
class QuantLlamaMLP(nn.Module):
    def __init__(
...
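For orientation, a plain-PyTorch reference of the function QuantMPTMLP computes
with its quantized kernels, down_proj(act(up_proj(x))) over token-flattened input
(hypothetical module, illustration only; MPT's MLP activation is GELU):

import torch
import torch.nn as nn

class RefMptMLP(nn.Module):
    def __init__(self, hidden_size: int, intermediate_size: int):
        super().__init__()
        self.up_proj = nn.Linear(hidden_size, intermediate_size)
        self.act = nn.GELU()
        self.down_proj = nn.Linear(intermediate_size, hidden_size)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = x.reshape(-1, x.shape[-1])  # same flattening as QuantMPTMLP.forward
        return self.down_proj(self.act(self.up_proj(x)))

x = torch.randn(2, 16, 64)
print(RefMptMLP(64, 256)(x).shape)  # torch.Size([32, 64]): the token dim stays flattened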