OpenDAS / AutoAWQ · Commits

Commit ac3e86df, authored Sep 11, 2023 by Casper Hansen
Remove fusing attention, only blocks
Parent: 950851b3

Showing 1 changed file with 1 addition and 21 deletions.

awq/models/mpt.py (+1, -21)
 from .base import BaseAWQForCausalLM
-from transformers.models.mpt.modeling_mpt import MptBlock as OldMptBlock, MptForCausalLM, MptAttention
+from transformers.models.mpt.modeling_mpt import MptBlock as OldMptBlock, MptForCausalLM

 class MptAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "MPTBlock"
...
@@ -8,7 +8,6 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
     @staticmethod
     def fuse_layers(model: MptForCausalLM, quant_config: dict):
         fuser = MptFuser(model)
-        fuser.fuse_attention()
         fuser.fuse_block()

     @staticmethod
...
@@ -69,34 +68,15 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
 from typing import List, Tuple
 from awq.utils.utils import set_module_name
 from awq.modules.fused.block import MptBlock
-from awq.modules.fused.attn import QuantAttentionFused

 class MptFuser:
     def __init__(self, model):
         self.model = model

-        self.attention_modules: List[Tuple[str, MptAttention]] = [
-            (name, module) for name, module in self.model.named_modules()
-            if isinstance(module, MptAttention)
-        ]
-
         self.mpt_blocks: List[Tuple[str, OldMptBlock]] = [
             (name, module) for name, module in self.model.named_modules()
             if 'mptblock' in module.__class__.__name__.lower()
         ]

-    def fuse_attention(self):
-        for name, qkv_layer in self.attention_modules:
-            attn = QuantAttentionFused(
-                qkv_layer.hidden_size,
-                qkv_layer.n_heads,
-                qkv_layer, qkv_layer.out_proj,
-                next(iter(qkv_layer.state_dict().values())).device,
-                self.model.config.max_new_tokens,
-                use_alibi=True
-            )
-            set_module_name(self.model, name, attn)
-
     def fuse_block(self):
         for name, module in self.mpt_blocks:
...
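With this commit, MPT layer fusing goes through a single path: fuse_layers builds an MptFuser and calls only fuse_block, which swaps each MptBlock for the fused implementation in awq.modules.fused.block; the attention-level path (fuse_attention with QuantAttentionFused) is removed. A minimal sketch of how this entry point is typically reached, assuming AutoAWQ's high-level loading API of this period; the model path below is a hypothetical placeholder:

    # Minimal sketch, assuming AutoAWQ's from_quantized API; "mpt-7b-awq"
    # is a hypothetical path to already-quantized weights.
    from awq import AutoAWQForCausalLM

    model = AutoAWQForCausalLM.from_quantized(
        "mpt-7b-awq",
        fuse_layers=True,  # routes to MptAWQForCausalLM.fuse_layers -> MptFuser.fuse_block
    )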
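The fuser's mechanics rest on a common PyTorch pattern: scan named_modules() for targets (here, anything whose class name contains "mptblock"), build a replacement module, and splice it back in by its dotted name, which is the job set_module_name from awq.utils.utils performs. A self-contained sketch of that pattern, with a hypothetical set_module_by_name standing in for the AutoAWQ helper:

    import torch.nn as nn

    def set_module_by_name(root: nn.Module, name: str, new_module: nn.Module):
        # Hypothetical stand-in for awq.utils.utils.set_module_name:
        # walk the dotted path and replace the leaf submodule.
        parts = name.split(".")
        parent = root
        for part in parts[:-1]:
            parent = getattr(parent, part)
        setattr(parent, parts[-1], new_module)

    # Collect targets by class name, as MptFuser does, then swap each one.
    model = nn.Sequential(nn.Linear(8, 8), nn.ReLU(), nn.Linear(8, 8))
    targets = [
        (name, module)
        for name, module in model.named_modules()
        if "relu" in module.__class__.__name__.lower()
    ]
    for name, _ in targets:
        set_module_by_name(model, name, nn.GELU())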