OpenDAS / AutoAWQ / Commits / 783afe50

Unverified commit 783afe50, authored Sep 03, 2023 by Casper, committed by GitHub on Sep 03, 2023.

Merge pull request #18 from casper-hansen/refactor_fused

Refactor fused modules

Parents: 560fbe59, 0aa4a596
Showing 6 changed files with 131 additions and 114 deletions (+131, -114):

awq/models/llama.py        +80   -6
awq/models/mpt.py          +33   -8
awq/modules/fused_attn.py   +5  -49
awq/modules/fused_mlp.py    +1  -27
awq/modules/fused_norm.py   +0  -24
awq/utils/utils.py         +12   -0
awq/models/llama.py  (+80, -6)

 from .base import BaseAWQForCausalLM
-from awq.modules import make_quant_norm, make_quant_attn, make_fused_mlp
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM

 class LlamaAWQForCausalLM(BaseAWQForCausalLM):
...
@@ -7,10 +6,11 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM):
     max_new_tokens_key = "max_position_embeddings"

     @staticmethod
-    def fuse_layers(awq_model):
-        make_quant_attn(awq_model, awq_model.device)
-        make_quant_norm(awq_model)
-        make_fused_mlp(awq_model)
+    def fuse_layers(model: LlamaForCausalLM):
+        fuser = LlamaFuser(model)
+        fuser.fuse_attention()
+        fuser.fuse_rmsnorm()
+        fuser.fuse_mlp()

     @staticmethod
     def get_model_layers(model: LlamaForCausalLM):
...
@@ -64,3 +64,77 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM):
             ))

         return layers
+
+import torch
+from typing import List, Tuple
+from awq.quantize.qmodule import WQLinear
+from awq.utils.utils import set_module_name
+from awq.modules.fused_mlp import QuantLlamaMLP
+from awq.modules.fused_norm import FTLlamaRMSNorm
+from awq.modules.fused_attn import QuantLlamaAttention
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP
+
+class LlamaFuser:
+    def __init__(self, model):
+        self.model = model
+
+        self.attention_modules: List[Tuple[str, LlamaAttention]] = [
+            (name, module) for name, module in self.model.named_modules()
+            if isinstance(module, LlamaAttention)
+        ]
+
+        self.rmsnorm_modules: List[Tuple[str, LlamaRMSNorm]] = [
+            (name, module) for name, module in self.model.named_modules()
+            if isinstance(module, LlamaRMSNorm)
+        ]
+
+        self.mlp_modules: List[Tuple[str, LlamaMLP]] = [
+            (name, module) for name, module in self.model.named_modules()
+            if isinstance(module, LlamaMLP)
+        ]
+
+    def fuse_attention(self):
+        for name, module in self.attention_modules:
+            qkv_layer: WQLinear = self._fuse_qkv(module)
+            attn = QuantLlamaAttention(
+                module.hidden_size,
+                module.num_heads,
+                qkv_layer,
+                module.o_proj,
+                qkv_layer.qweight.device,
+                self.model.config.max_new_tokens
+            )
+            set_module_name(self.model, name, attn)
+
+    def _fuse_qkv(self, module: LlamaAttention):
+        # get qkv and bias
+        q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj
+        bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
+
+        # create module
+        qkv_layer = WQLinear(
+            q_proj.w_bit,
+            q_proj.group_size,
+            q_proj.in_features,
+            q_proj.out_features + k_proj.out_features + v_proj.out_features,
+            q_proj.bias is not None,
+            q_proj.qweight.device
+        )
+
+        # replace buffers with real weights
+        qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
+        qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
+        qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
+        qkv_layer.bias = bias
+
+        return qkv_layer
+
+    def fuse_rmsnorm(self):
+        for name, module in self.rmsnorm_modules:
+            norm = FTLlamaRMSNorm(module.weight, module.variance_epsilon)
+            set_module_name(self.model, name, norm)
+
+    def fuse_mlp(self):
+        for name, module in self.mlp_modules:
+            mlp = QuantLlamaMLP(module.gate_proj, module.down_proj, module.up_proj)
+            set_module_name(self.model, name, mlp)
\ No newline at end of file
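For orientation, the PR replaces the free-standing make_quant_attn / make_quant_norm / make_fused_mlp helpers with a per-architecture fuser object. Below is a minimal sketch, not part of the diff, of driving the new path on an already-quantized Llama model; load_quantized_llama is a hypothetical placeholder for however the quantized model is obtained, while the fuser calls mirror fuse_layers above.

# Hedged sketch (not from the diff): exercising the refactored fusion path.
# `load_quantized_llama` is a hypothetical placeholder for the loading step.
from awq.models.llama import LlamaAWQForCausalLM, LlamaFuser

model = load_quantized_llama("path/to/awq-llama")  # hypothetical loader, not a real API

# Either use the static entry point added in this PR...
LlamaAWQForCausalLM.fuse_layers(model)

# ...or step through the fuser manually (what fuse_layers does internally):
# fuser = LlamaFuser(model)
# fuser.fuse_attention()  # q/k/v WQLinear -> single QuantLlamaAttention
# fuser.fuse_rmsnorm()    # LlamaRMSNorm  -> FTLlamaRMSNorm CUDA kernel
# fuser.fuse_mlp()        # LlamaMLP      -> QuantLlamaMLP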
awq/models/mpt.py  (+33, -8)

 from .base import BaseAWQForCausalLM
-from awq.modules import make_fused_mlp
+from transformers.models.mpt.modeling_mpt import MptBlock, MptForCausalLM, MptMLP

 class MptAWQForCausalLM(BaseAWQForCausalLM):
     layer_type = "MPTBlock"
     max_new_tokens_key = "max_seq_len"

     @staticmethod
-    def fuse_layers(awq_model):
-        make_fused_mlp(awq_model)
+    def fuse_layers(model: MptForCausalLM):
+        fuser = MptFuser(model)
+        fuser.fuse_mlp()

     @staticmethod
-    def get_model_layers(model):
+    def get_model_layers(model: MptForCausalLM):
         return model.transformer.blocks

     @staticmethod
-    def get_act_for_scaling(module):
+    def get_act_for_scaling(module: MptBlock):
         return dict(
             is_scalable=True,
             scale_name="ffn.act",
...
@@ -23,12 +24,12 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
         )

     @staticmethod
-    def move_embed(model, device):
+    def move_embed(model: MptForCausalLM, device: str):
         model.transformer.wte = model.transformer.wte.to(device)
         model.transformer.emb_drop = model.transformer.emb_drop.to(device)

     @staticmethod
-    def get_layers_for_scaling(module, input_feat, module_kwargs):
+    def get_layers_for_scaling(module: MptBlock, input_feat, module_kwargs):
         layers = []

         # attention input
...
@@ -63,3 +64,27 @@ class MptAWQForCausalLM(BaseAWQForCausalLM):
             ))

         return layers
+
+from typing import List, Tuple
+from awq.utils.utils import set_module_name
+from awq.modules.fused_mlp import QuantMPTMLP
+
+class MptFuser:
+    def __init__(self, model):
+        self.model = model
+
+        self.mlp_modules: List[Tuple[str, MptMLP]] = [
+            (name, module) for name, module in self.model.named_modules()
+            if isinstance(module, MptMLP)
+        ]
+
+    def fuse_attention(self): pass
+
+    def fuse_layernorm(self): pass
+
+    def fuse_mlp(self):
+        for name, module in self.mlp_modules:
+            mlp = QuantMPTMLP(module.up_proj, module.act, module.down_proj)
+            set_module_name(self.model, name, mlp)
\ No newline at end of file
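As a side note, MptFuser only fuses the MLP; QuantMPTMLP(up_proj, act, down_proj) keeps the block's dataflow intact. A hedged, unquantized reference of that dataflow, with plain nn.Linear standing in for the quantized layers and GELU assumed as the MPT activation:

# Hedged reference (not from the diff): the dataflow QuantMPTMLP preserves.
import torch
import torch.nn as nn

d_model, d_ff = 64, 256
up_proj = nn.Linear(d_model, d_ff)    # stand-in for the quantized up projection
act = nn.GELU()                       # assumed MPT activation
down_proj = nn.Linear(d_ff, d_model)  # stand-in for the quantized down projection

x = torch.randn(2, 8, d_model)
out = down_proj(act(up_proj(x)))      # expand -> activate -> project back
print(out.shape)                      # torch.Size([2, 8, 64])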
awq/modules/fused_attn.py  (+5, -49)

-import math
 import torch
 import torch.nn as nn
-from torch.nn import functional as F
-from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRotaryEmbedding, apply_rotary_pos_emb
-from awq.quantize.qmodule import WQLinear
 import awq_inference_engine
+from torch.nn import functional as F

 class QuantLlamaRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
...
@@ -64,7 +59,8 @@ class QuantLlamaAttention(nn.Module):
         num_heads,
         qkv_proj,
         o_proj,
-        dev
+        dev,
+        max_new_tokens
     ):
         super().__init__()
         self.hidden_size = hidden_size
...
@@ -76,7 +72,7 @@ class QuantLlamaAttention(nn.Module):
                 f" and `num_heads`: {num_heads})."
             )
         self.qkv_proj = qkv_proj
         self.o_proj = o_proj
-        self.rotary_emb = QuantLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=2048, device=dev)
+        self.rotary_emb = QuantLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=max_new_tokens, device=dev)

     def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False):
         """Input shape: Batch x Time x Channel"""
...
@@ -101,7 +97,7 @@ class QuantLlamaAttention(nn.Module):
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]

-        value_states = value_states.to("cuda:0")
+        value_states = value_states.to(key_states.device)

         if past_key_value is not None:
             # reuse k, v, self_attention
...
@@ -125,43 +121,3 @@ class QuantLlamaAttention(nn.Module):
         attn_output = self.o_proj(attn_output)

         return attn_output, None, past_key_value
-
-
-def make_quant_attn(model, dev):
-    """
-    Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
-    """
-    for name, m in model.named_modules():
-        if not isinstance(m, LlamaAttention):
-            continue
-
-        q_proj = m.q_proj
-        k_proj = m.k_proj
-        v_proj = m.v_proj
-
-        qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
-        qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
-        scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
-        g_idx = None
-        bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
-
-        qkv_layer = WQLinear(q_proj.w_bit, q_proj.group_size, q_proj.in_features, q_proj.out_features + k_proj.out_features + v_proj.out_features, q_proj.bias is not None, q_proj.qweight.device)
-        qkv_layer.qweight = qweights
-        qkv_layer.qzeros = qzeros
-        qkv_layer.scales = scales
-        qkv_layer.bias = bias
-
-        attn = QuantLlamaAttention(m.hidden_size, m.num_heads, qkv_layer, m.o_proj, dev)
-
-        if '.' in name:
-            parent_name = name.rsplit('.', 1)[0]
-            child_name = name[len(parent_name) + 1:]
-            parent = model.get_submodule(parent_name)
-        else:
-            parent_name = ''
-            parent = model
-            child_name = name
-
-        setattr(parent, child_name, attn)
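The deleted make_quant_attn body (now LlamaFuser._fuse_qkv) relies on the fact that three projections sharing an input can be replaced by one wider matmul. A self-contained illustration with plain float layers, kept separate from the diff: for nn.Linear the weight is (out_features, in_features), so the concat is along dim=0, whereas the packed WQLinear buffers in the diff are concatenated along dim=1.

# Hedged illustration (not from the diff): q/k/v fusion on plain float layers.
import torch
import torch.nn as nn

hidden = 64
q_proj, k_proj, v_proj = (nn.Linear(hidden, hidden, bias=False) for _ in range(3))

# One fused projection whose rows are the stacked q, k, v weights.
qkv = nn.Linear(hidden, 3 * hidden, bias=False)
with torch.no_grad():
    qkv.weight.copy_(torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0))

x = torch.randn(2, hidden)
q, k, v = qkv(x).split(hidden, dim=-1)

# A single matmul reproduces the three separate projections.
assert torch.allclose(q, q_proj(x), atol=1e-6)
assert torch.allclose(k, k_proj(x), atol=1e-6)
assert torch.allclose(v, v_proj(x), atol=1e-6)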
awq/modules/fused_mlp.py  (+1, -27)

-import numpy as np
 import torch
 import torch.nn as nn
-import torch.nn.functional as F
-from torch.cuda.amp import custom_bwd, custom_fwd
-from transformers.models.llama.modeling_llama import LlamaMLP
 import awq_inference_engine
+import torch.nn.functional as F

 class QuantMPTMLP(nn.Module):
     def __init__(
...
@@ -67,25 +63,3 @@ class QuantLlamaMLP(nn.Module):
         c = gate_output * up_output
         c = c.reshape(out_shape)
         return c
-
-
-def make_fused_mlp(m, parent_name=''):
-    if not hasattr(make_fused_mlp, "called"):
-        make_fused_mlp.called = True
-    """
-    Replace all LlamaMLP modules with QuantLlamaMLP modules, which fuses many of the operations.
-    """
-    if isinstance(m, LlamaMLP):
-        return QuantLlamaMLP(m.gate_proj, m.down_proj, m.up_proj)
-    elif "mptmlp" in str(m.__class__).lower():
-        return QuantMPTMLP(m.up_proj, m.act, m.down_proj)
-
-    for name, child in m.named_children():
-        child = make_fused_mlp(child, parent_name=f"{parent_name}.{name}")
-
-        if isinstance(child, QuantLlamaMLP):
-            setattr(m, name, child)
-        elif isinstance(child, QuantMPTMLP):
-            setattr(m, name, child)
-    return m
\ No newline at end of file
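For reference, the gate_output * up_output product in the hunk above is the elementwise step of the gated Llama MLP. A hedged, unquantized sketch of that computation, assuming the standard Llama formulation down(silu(gate(x)) * up(x)) and using plain nn.Linear in place of the quantized layers:

# Hedged reference (not from the diff): the SwiGLU dataflow behind QuantLlamaMLP.
import torch
import torch.nn as nn
import torch.nn.functional as F

d_model, d_ff = 64, 172
gate_proj = nn.Linear(d_model, d_ff, bias=False)
up_proj = nn.Linear(d_model, d_ff, bias=False)
down_proj = nn.Linear(d_ff, d_model, bias=False)

x = torch.randn(2, 8, d_model)
gate_output = F.silu(gate_proj(x))  # activated gate branch
up_output = up_proj(x)              # linear branch
c = gate_output * up_output         # elementwise product, as in the hunk above
out = down_proj(c)
print(out.shape)                    # torch.Size([2, 8, 64])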
awq/modules/fused_norm.py  (+0, -24)

 import torch
 from torch import nn
-from transformers.models.llama.modeling_llama import LlamaRMSNorm
 import awq_inference_engine

 class FTLlamaRMSNorm(nn.Module):
...
@@ -16,26 +15,3 @@ class FTLlamaRMSNorm(nn.Module):
         output = torch.empty_like(x)
         awq_inference_engine.layernorm_forward_cuda(x, self.weight, output, self.variance_epsilon)
         return output
-
-
-def make_quant_norm(model):
-    """
-    Replace all LlamaRMSNorm modules with FTLlamaRMSNorm modules
-    """
-    for name, m in model.named_modules():
-        if not isinstance(m, LlamaRMSNorm):
-            continue
-
-        norm = FTLlamaRMSNorm(m.weight, m.variance_epsilon)
-
-        if '.' in name:
-            parent_name = name.rsplit('.', 1)[0]
-            child_name = name[len(parent_name) + 1:]
-            parent = model.get_submodule(parent_name)
-        else:
-            parent_name = ''
-            parent = model
-            child_name = name
-
-        setattr(parent, child_name, norm)
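For context, FTLlamaRMSNorm delegates the whole normalization to the awq_inference_engine.layernorm_forward_cuda kernel. A hedged reference of the computation it is expected to match, assuming standard Llama RMSNorm semantics (x * rsqrt(mean(x^2) + eps) * weight):

# Hedged reference (not from the diff): pure-PyTorch RMSNorm for comparison.
import torch

def rmsnorm_reference(x: torch.Tensor, weight: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    variance = x.float().pow(2).mean(dim=-1, keepdim=True)  # mean of squares over hidden dim
    x_normed = x.float() * torch.rsqrt(variance + eps)      # scale to unit RMS
    return (weight * x_normed).to(x.dtype)                  # learned per-channel gain

x = torch.randn(2, 8, 64, dtype=torch.float16)
weight = torch.ones(64, dtype=torch.float16)
print(rmsnorm_reference(x, weight).shape)  # torch.Size([2, 8, 64])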
awq/utils/utils.py  (+12, -0)

...
@@ -41,3 +41,15 @@ def simple_dispatch_model(model, device_map):
     model.hf_device_map = device_map

     return model
+
+def set_module_name(model, name, value):
+    if '.' in name:
+        parent_name = name.rsplit('.', 1)[0]
+        child_name = name[len(parent_name) + 1:]
+        parent = model.get_submodule(parent_name)
+    else:
+        parent_name = ''
+        parent = model
+        child_name = name
+
+    setattr(parent, child_name, value)
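This new helper is what every fuser uses to swap a module in place by its dotted name. A small usage sketch (the toy model is illustrative, not from the diff):

# Usage sketch: replace a nested submodule by dotted name via set_module_name.
import torch.nn as nn
from awq.utils.utils import set_module_name  # helper added in this diff

model = nn.Sequential()
model.add_module("block", nn.Sequential())
model.block.add_module("norm", nn.LayerNorm(64))

# Parent is looked up as "block", child attribute "norm" is overwritten.
set_module_name(model, "block.norm", nn.Identity())
print(model.block.norm)  # Identity()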