OpenDAS / AutoAWQ

Commit 620966e8, authored Sep 02, 2023 by Casper Hansen

Refactor Llama Quant RMSNorm

Parent: 2082197d
Changes: 2 changed files with 6 additions and 27 deletions

awq/models/llama.py         +6  -3
awq/modules/fused_norm.py   +0  -24
awq/models/llama.py  (view file @ 620966e8)

 from .base import BaseAWQForCausalLM
-from awq.modules import make_quant_norm, make_fused_mlp
+from awq.modules import make_fused_mlp
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM

 class LlamaAWQForCausalLM(BaseAWQForCausalLM):

@@ -10,7 +10,7 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM):
     def fuse_layers(awq_model: BaseAWQForCausalLM):
         fuser = LlamaFuser(awq_model)
         fuser.fuse_attention()
-        make_quant_norm(awq_model) # fuser.fuse_rmsnorm()
+        fuser.fuse_rmsnorm()
         make_fused_mlp(awq_model) #fuser.fuse_mlp()

     @staticmethod

@@ -70,6 +70,7 @@ import torch
 from typing import List, Tuple
 from awq.quantize.qmodule import WQLinear
 from awq.utils.utils import set_module_name
+from awq.modules.fused_norm import FTLlamaRMSNorm
 from awq.modules.fused_attn import QuantLlamaAttention
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm

@@ -125,7 +126,9 @@ class LlamaFuser:
         return qkv_layer

     def fuse_rmsnorm(self):
-        pass
+        for name, module in self.rmsnorm_modules:
+            norm = FTLlamaRMSNorm(module.weight, module.variance_epsilon)
+            set_module_name(self.model, name, norm)

     def fuse_mlp(self):
         pass
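The new fuse_rmsnorm delegates the actual module swap to set_module_name from awq.utils.utils, whose body is not part of this diff. Judging from the parent/child resolution that the removed make_quant_norm performed inline (see awq/modules/fused_norm.py below), a minimal sketch of what that helper presumably does:

# Assumed sketch of awq.utils.utils.set_module_name (not shown in this commit):
# replace the submodule reached by a dotted name with a new module.
import torch.nn as nn

def set_module_name(model: nn.Module, name: str, new_module: nn.Module) -> None:
    if '.' in name:
        # e.g. "model.layers.0.input_layernorm" -> parent "model.layers.0",
        # child "input_layernorm"
        parent_name, child_name = name.rsplit('.', 1)
        parent = model.get_submodule(parent_name)
    else:
        # top-level attribute of the root model
        parent, child_name = model, name
    setattr(parent, child_name, new_module)

With a helper like that, fuse_rmsnorm only has to walk self.rmsnorm_modules (unpacked as (name, module) pairs, presumably collected in the fuser's constructor) and swap each LlamaRMSNorm for the fused kernel wrapper.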
awq/modules/fused_norm.py  (view file @ 620966e8)

 import torch
 from torch import nn
-from transformers.models.llama.modeling_llama import LlamaRMSNorm
 import awq_inference_engine

 class FTLlamaRMSNorm(nn.Module):

@@ -16,26 +15,3 @@ class FTLlamaRMSNorm(nn.Module):
         output = torch.empty_like(x)
         awq_inference_engine.layernorm_forward_cuda(x, self.weight, output, self.variance_epsilon)
         return output
-
-
-def make_quant_norm(model):
-    """
-    Replace all LlamaRMSNorm modules with FTLlamaRMSNorm modules
-    """
-    for name, m in model.named_modules():
-        if not isinstance(m, LlamaRMSNorm):
-            continue
-
-        norm = FTLlamaRMSNorm(m.weight, m.variance_epsilon)
-
-        if '.' in name:
-            parent_name = name.rsplit('.', 1)[0]
-            child_name = name[len(parent_name) + 1:]
-            parent = model.get_submodule(parent_name)
-        else:
-            parent_name = ''
-            parent = model
-            child_name = name
-
-        setattr(parent, child_name, norm)
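FTLlamaRMSNorm.forward dispatches to the awq_inference_engine.layernorm_forward_cuda kernel, so the actual math is not visible in this file. Since the module is constructed from a LlamaRMSNorm's weight and variance_epsilon, it is expected to reproduce the standard Llama RMSNorm; a plain-PyTorch reference sketch of that computation (for illustration only, not part of the commit):

import torch
from torch import nn

class RMSNormReference(nn.Module):
    """Eager-mode RMSNorm, shown only to illustrate the computation the
    fused CUDA kernel is expected to match."""

    def __init__(self, weight: torch.Tensor, eps: float = 1e-6):
        super().__init__()
        self.weight = nn.Parameter(weight.clone())
        self.variance_epsilon = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Normalize by the root-mean-square over the hidden dimension
        # (computed in fp32 for stability), then apply the learned scale.
        variance = x.to(torch.float32).pow(2).mean(-1, keepdim=True)
        x_normed = x.to(torch.float32) * torch.rsqrt(variance + self.variance_epsilon)
        return self.weight * x_normed.to(x.dtype)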