OpenDAS / AutoAWQ · Commits · dd41a223

Commit dd41a223, authored Sep 11, 2023 by Casper Hansen
Parent: 7631add1

Update MPTBlock, fuse with MPTModel

Showing 2 changed files with 29 additions and 16 deletions (+29 -16):

  awq/models/mpt.py            +21  -9
  awq/modules/fused/block.py    +8  -7
awq/models/mpt.py (view file @ dd41a223)

@@ -8,7 +8,7 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
     @staticmethod
     def fuse_layers(model: MptForCausalLM, quant_config: dict):
         fuser = MptFuser(model)
-        fuser.fuse_block()
+        fuser.fuse_transformer()

     @staticmethod
     def get_model_layers(model: MptForCausalLM):
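The change above switches the fusion entry point from fuse_block() to the new fuse_transformer(). As a usage sketch (not taken from this commit), this path is normally reached when a quantized MPT checkpoint is loaded with layer fusion enabled; it assumes the AutoAWQ loader of this period exposes a fuse_layers flag, and the checkpoint path and weights filename are placeholders.

# Sketch only: how MptAWQForCausalLM.fuse_layers() gets invoked in practice.
# Assumes from_quantized() accepts a fuse_layers flag (an assumption about the
# loader, not part of this diff); the paths below are placeholders.
from awq import AutoAWQForCausalLM

model = AutoAWQForCausalLM.from_quantized(
    "path/to/mpt-7b-awq",       # placeholder: quantized MPT checkpoint
    "awq_model_w4_g128.pt",     # placeholder: quantized weights file
    fuse_layers=True,           # triggers fuse_layers() -> MptFuser.fuse_transformer()
)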
@@ -67,10 +67,11 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
 from typing import List, Tuple
 from awq.utils.utils import set_module_name
-from awq.modules.fused.block import MptBlock
+from awq.modules.fused.block import MPTBlock
+from awq.modules.fused.model import MPTModel

 class MptFuser:
-    def __init__(self, model):
+    def __init__(self, model: MptForCausalLM):
         self.model = model
         self.mpt_blocks: List[Tuple[str, OldMptBlock]] = [
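The mpt_blocks comprehension above is collapsed in this view; its filter clause appears as context at the top of the next hunk. As a standalone illustration of that collection pattern, assuming the collapsed part iterates self.model.named_modules() (which this diff does not show) and using a hypothetical DummyMPTBlock in place of the Hugging Face MptBlock:

# Minimal, runnable sketch of collecting (name, module) pairs by class name.
# Iterating named_modules() is an assumption; DummyMPTBlock is a stand-in.
import torch.nn as nn

class DummyMPTBlock(nn.Module):
    def __init__(self):
        super().__init__()
        self.proj = nn.Linear(8, 8)

model = nn.Sequential(DummyMPTBlock(), nn.Linear(8, 8), DummyMPTBlock())

mpt_blocks = [
    (name, module)
    for name, module in model.named_modules()
    if 'mptblock' in module.__class__.__name__.lower()
]
print([name for name, _ in mpt_blocks])  # ['0', '2']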
@@ -78,15 +79,26 @@ class MptFuser:
             if 'mptblock' in module.__class__.__name__.lower()
         ]

-    def fuse_block(self):
-        for name, module in self.mpt_blocks:
-            block = MptBlock(
+    def fuse_transformer(self):
+        blocks = []
+
+        module: OldMptBlock
+        for module in self.model.transformer.blocks:
+            blocks.append(MPTBlock(
                 self.model.config.d_model,
                 self.model.config.n_heads,
                 module.attn.Wqkv,
                 module.attn.out_proj,
                 module.ffn,
-                next(iter(module.state_dict().values())).device
-            )
+                module.norm_1,
+                module.norm_2,
+                next(iter(module.state_dict().values())).device,
+                self.model.config.max_new_tokens
+            ))

-            set_module_name(self.model, name, block)
\ No newline at end of file
+        self.model.transformer = MPTModel(
+            self.model.config.vocab_size,
+            blocks,
+            self.model.transformer.wte,
+            self.model.transformer.norm_f,
+        )
\ No newline at end of file
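For readability, the new fuse_transformer reads as follows once the hunk is applied (reassembled from the '+' lines above, not copied verbatim from the repository). Where the old fuse_block swapped each MptBlock in place via set_module_name, this version collects the fused MPTBlocks and replaces the whole transformer with an MPTModel, so the fused model presumably owns the embedding, block loop, and final norm itself.

# Reassembled from the additions in this hunk; a reconstruction, not the file itself.
def fuse_transformer(self):
    blocks = []

    module: OldMptBlock
    for module in self.model.transformer.blocks:
        blocks.append(MPTBlock(
            self.model.config.d_model,
            self.model.config.n_heads,
            module.attn.Wqkv,
            module.attn.out_proj,
            module.ffn,
            module.norm_1,
            module.norm_2,
            next(iter(module.state_dict().values())).device,
            self.model.config.max_new_tokens
        ))

    self.model.transformer = MPTModel(
        self.model.config.vocab_size,
        blocks,
        self.model.transformer.wte,
        self.model.transformer.norm_f,
    )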
awq/modules/fused/block.py (view file @ dd41a223)

 import torch.nn as nn
 from awq.modules.fused.attn import QuantAttentionFused

-class MptBlock(nn.Module):
-    def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mpt_mlp, dev):
+class MPTBlock(nn.Module):
+    def __init__(self, hidden_size, n_heads, qkv_layer, o_proj, mpt_mlp, norm_1, norm_2, dev, max_seq_len):
         super().__init__()
         self.n_heads = n_heads
         self.hidden_size = hidden_size
-        self.attn = QuantAttentionFused(hidden_size, self.n_heads, qkv_layer, o_proj, dev="cuda:0", max_seq_len=8096, use_alibi=True)
-        self.ffn = mpt_mlp
-        self.norm_1 = nn.LayerNorm(hidden_size, eps=1e-6).half().to(dev)
-        self.norm_2 = nn.LayerNorm(hidden_size, eps=1e-6).half().to(dev)
+        self.norm_1 = norm_1
+        self.attn = QuantAttentionFused(hidden_size, self.n_heads, qkv_layer, o_proj, dev=dev, max_seq_len=max_seq_len, use_alibi=True).to(dev)
+        self.norm_2 = norm_2
+        self.ffn = mpt_mlp.to(dev)

     def forward(
-        self, hidden_states, past_key_value, attn_bias, attention_mask, is_causal
+        self, hidden_states, past_key_value, attn_bias=None, attention_mask=None, is_causal=None
     ):
         norm_out = self.norm_1(hidden_states)
         attn_output, _, past_key_value = self.attn.forward(
...
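The new constructor reuses norm_1 and norm_2 from the unfused block and takes dev and max_seq_len as parameters instead of the previously hard-coded "cuda:0" and 8096. For clarity, here is how the call in MptFuser.fuse_transformer() lines up with this signature, with keyword names added only for illustration (the actual call in mpt.py passes the arguments positionally).

# Keyword-annotated restatement of the MPTBlock(...) call from fuse_transformer();
# the names are added here purely to show which argument feeds which parameter.
block = MPTBlock(
    hidden_size=self.model.config.d_model,
    n_heads=self.model.config.n_heads,
    qkv_layer=module.attn.Wqkv,
    o_proj=module.attn.out_proj,
    mpt_mlp=module.ffn,
    norm_1=module.norm_1,
    norm_2=module.norm_2,
    dev=next(iter(module.state_dict().values())).device,
    max_seq_len=self.model.config.max_new_tokens,
)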