OpenDAS / AutoAWQ

Commit 54f02854, authored Sep 09, 2023 by Casper Hansen
Falcon fused layers
Parent: 73c5e2bf

Showing 1 changed file with 54 additions and 2 deletions

awq/models/falcon.py (+54, -2)
 from .base import BaseAWQForCausalLM
-from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconForCausalLM
+from transformers.models.falcon.modeling_falcon import FalconDecoderLayer, FalconForCausalLM, FalconAttention
 
 class FalconAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "FalconDecoderLayer"
 
+    @staticmethod
+    def fuse_layers(model: FalconForCausalLM, quant_config: dict):
+        fuser = FalconFuser(model)
+        # fuser.fuse_attention()
+        # fuser.fuse_layernorm()
+
     @staticmethod
     def get_model_layers(model: FalconForCausalLM):
         return model.transformer.h
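Note: in this commit the fuse_layers hook is scaffolding only; both fusion passes are left commented out. For orientation, a hedged sketch of how the hook would typically be reached, assuming AutoAWQ's from_quantized(..., fuse_layers=True) dispatches to FalconAWQForCausalLM.fuse_layers as it does for other supported architectures (the checkpoint name below is hypothetical):

from awq import AutoAWQForCausalLM

# Assumption: passing fuse_layers=True routes through fuse_layers(model, quant_config).
model = AutoAWQForCausalLM.from_quantized(
    "tiiuae/falcon-7b-awq",  # hypothetical quantized checkpoint
    fuse_layers=True,
)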
...
@@ -56,4 +62,50 @@ class FalconAWQForCausalLM(BaseAWQForCausalLM):
                 kwargs=module_kwargs,
             ))
-        return layers
\ No newline at end of file
+        return layers
+
+import torch
+import xformers
+from torch.nn import LayerNorm
+from typing import List, Tuple
+from awq.utils.utils import set_module_name
+from xformers.triton.layer_norm import FusedLayerNorm
+from awq.modules.fused.attn import QuantAttentionFused
+
+class FalconFuser:
+    def __init__(self, model):
+        self.model = model
+
+        self.attention_modules: List[Tuple[str, FalconAttention]] = [
+            (name, module) for name, module in self.model.named_modules()
+            if isinstance(module, FalconAttention)
+        ]
+        self.layernorm_modules: List[Tuple[str, LayerNorm]] = [
+            (name, module) for name, module in self.model.named_modules()
+            if isinstance(module, LayerNorm)
+        ]
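The constructor harvests (dotted name, module) pairs so each module can later be replaced in place. The pattern is plain PyTorch: walk named_modules() and filter by isinstance. A minimal, self-contained sketch of the same idea on a toy model (the toy model is illustrative, not from the commit):

import torch.nn as nn

toy = nn.Sequential(nn.Linear(8, 8), nn.LayerNorm(8), nn.Linear(8, 4), nn.LayerNorm(4))

# Same (name, module) harvesting FalconFuser.__init__ performs, for LayerNorm only.
layernorms = [
    (name, module) for name, module in toy.named_modules()
    if isinstance(module, nn.LayerNorm)
]
print([name for name, _ in layernorms])  # ['1', '3']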
+
+    def fuse_attention(self):
+        for name, qkv_layer in self.attention_modules:
+            attn = QuantAttentionFused(
+                qkv_layer.hidden_size,
+                qkv_layer.num_heads,
+                qkv_layer,
+                qkv_layer.dense,
+                next(iter(qkv_layer.state_dict().values())).device,
+                self.model.config.max_new_tokens
+            )
+            set_module_name(self.model, name, attn)
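Both passes lean on awq.utils.utils.set_module_name to graft the replacement onto the model at the harvested dotted path. The helper itself is not part of this diff; a plausible minimal implementation, assuming it resolves the parent module from the dotted name and rebinds the final attribute:

import torch.nn as nn

def set_module_name(model: nn.Module, name: str, new_module: nn.Module):
    # e.g. name == "transformer.h.0.self_attention": walk down to the parent
    # module, then rebind the last attribute to the fused replacement.
    if "." in name:
        parent_name, child_name = name.rsplit(".", 1)
        parent = model.get_submodule(parent_name)
    else:
        parent, child_name = model, name
    setattr(parent, child_name, new_module)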
+
+    def fuse_layernorm(self):
+        xformers.triton.k_layer_norm._triton_layernorm_fp16_enabled = True
+
+        for name, module in self.layernorm_modules:
+            norm = FusedLayerNorm(module.weight.shape, eps=module.eps).to(module.weight.device)
+
+            # copy weights and bias
+            with torch.no_grad():
+                norm.weight = module.weight
+                norm.bias = module.bias
+
+            set_module_name(self.model, name, norm)
\ No newline at end of file
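A quick way to sanity-check the LayerNorm swap is to confirm that rebinding the parameters reproduces the original module's output. The sketch below uses torch.nn.LayerNorm on both sides as a stand-in for xformers' FusedLayerNorm (whose Triton kernel requires a GPU); it demonstrates the same copy semantics as fuse_layernorm:

import torch
import torch.nn as nn

original = nn.LayerNorm(64)
replacement = nn.LayerNorm(64, eps=original.eps)  # stand-in for FusedLayerNorm

# Same copy the fuser performs: rebind the parameters rather than clone them.
with torch.no_grad():
    replacement.weight = original.weight
    replacement.bias = original.bias

x = torch.randn(2, 16, 64)
assert torch.allclose(original(x), replacement(x))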