OpenDAS / AutoAWQ · Commits

Commit ded3ea71, authored Sep 02, 2023 by Casper Hansen
Refactor Llama Quant MLP
Parent: 620966e8
Showing 2 changed files with 12 additions and 7 deletions (+12 -7)

  awq/models/llama.py       +11 -4
  awq/modules/fused_mlp.py   +1 -3
awq/models/llama.py

 from .base import BaseAWQForCausalLM
-from awq.modules import make_fused_mlp
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM

 class LlamaAWQForCausalLM(BaseAWQForCausalLM):
...
@@ -11,7 +10,7 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM):
         fuser = LlamaFuser(awq_model)
         fuser.fuse_attention()
         fuser.fuse_rmsnorm()
-        make_fused_mlp(awq_model) # fuser.fuse_mlp()
+        fuser.fuse_mlp()

     @staticmethod
     def get_model_layers(model: LlamaForCausalLM):
...
@@ -70,9 +69,10 @@ import torch
 from typing import List, Tuple
 from awq.quantize.qmodule import WQLinear
 from awq.utils.utils import set_module_name
+from awq.modules.fused_mlp import QuantLlamaMLP
 from awq.modules.fused_norm import FTLlamaRMSNorm
 from awq.modules.fused_attn import QuantLlamaAttention
-from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP

 class LlamaFuser:
     def __init__(self, awq_model: BaseAWQForCausalLM):
...
@@ -88,6 +88,11 @@ class LlamaFuser:
             (name, module) for name, module in self.model.named_modules()
             if isinstance(module, LlamaRMSNorm)
         ]
+
+        self.mlp_modules: List[Tuple[str, LlamaMLP]] = [
+            (name, module) for name, module in self.model.named_modules()
+            if isinstance(module, LlamaMLP)
+        ]

     def fuse_attention(self):
         for name, module in self.attention_modules:
...
@@ -131,4 +136,6 @@ class LlamaFuser:
             set_module_name(self.model, name, norm)

     def fuse_mlp(self):
-        pass
+        for name, module in self.mlp_modules:
+            mlp = QuantLlamaMLP(module.gate_proj, module.down_proj, module.up_proj)
+            set_module_name(self.model, name, mlp)
\ No newline at end of file
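Read together, the new side of these hunks moves Llama MLP fusion into the fuser itself: LlamaFuser now collects every LlamaMLP at construction time, and fuse_mlp() swaps each one for a QuantLlamaMLP, so fuse_layers calls fuser.fuse_mlp() instead of the module-level make_fused_mlp helper. The code below is a minimal sketch assembled from the visible hunks only; the constructor body is collapsed in the diff, so the awq_model.model attribute and the abbreviated class body are assumptions for illustration.

# Sketch only: assembled from the hunks above, not the full file.
from typing import List, Tuple

from awq.modules.fused_mlp import QuantLlamaMLP
from awq.utils.utils import set_module_name
from transformers.models.llama.modeling_llama import LlamaMLP


class LlamaFuser:
    def __init__(self, awq_model):
        # Assumption: the fuser walks the underlying HF model, as the
        # self.model references in the visible hunks suggest.
        self.model = awq_model.model

        # Collect every eager LlamaMLP once so fuse_mlp() can swap them in place.
        self.mlp_modules: List[Tuple[str, LlamaMLP]] = [
            (name, module) for name, module in self.model.named_modules()
            if isinstance(module, LlamaMLP)
        ]

    def fuse_mlp(self):
        # Replace each eager MLP with the fused quantized module, reusing the
        # already-quantized gate/down/up projections from the original module.
        for name, module in self.mlp_modules:
            mlp = QuantLlamaMLP(module.gate_proj, module.down_proj, module.up_proj)
            set_module_name(self.model, name, mlp)


# Call site (mirrors fuse_layers in the first hunk):
# fuser = LlamaFuser(awq_model)
# fuser.fuse_attention()
# fuser.fuse_rmsnorm()
# fuser.fuse_mlp()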
awq/modules/fused_mlp.py

...
@@ -75,9 +75,7 @@ def make_fused_mlp(m, parent_name=''):
     """
     Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuses many of the operations.
     """
-    if isinstance(m, LlamaMLP):
-        return QuantLlamaMLP(m.gate_proj, m.down_proj, m.up_proj)
-    elif "mptmlp" in str(m.__class__).lower():
+    if "mptmlp" in str(m.__class__).lower():
         return QuantMPTMLP(m.up_proj, m.act, m.down_proj)
     for name, child in m.named_children():
...
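With the LlamaMLP branch removed, make_fused_mlp only handles the MPT-style path, while Llama MLP fusion is driven entirely by LlamaFuser.fuse_mlp(). A quick, illustrative way to confirm the swap on a loaded model is to count module types after fusion; the helper below is not part of the commit, and awq_model is a hypothetical, already-fused model.

# Illustrative check only (not part of this commit).
from awq.modules.fused_mlp import QuantLlamaMLP
from transformers.models.llama.modeling_llama import LlamaMLP


def count_mlp_types(model):
    # Count fused vs. eager MLP modules in a (possibly fused) Llama model.
    fused = sum(isinstance(m, QuantLlamaMLP) for _, m in model.named_modules())
    eager = sum(isinstance(m, LlamaMLP) for _, m in model.named_modules())
    return fused, eager


# Hypothetical usage after fuse_layers() has run:
# fused, eager = count_mlp_types(awq_model.model)
# print(f"fused MLPs: {fused}, remaining LlamaMLP: {eager}")  # expect eager == 0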