OpenDAS / AutoAWQ · Commit 2082197d

Refactor Llama quant attention
Authored Sep 02, 2023 by Casper Hansen
Parent commit: 560fbe59

Showing 3 changed files with 88 additions and 55 deletions (+88 -55)
awq/models/llama.py        +71  -6
awq/modules/fused_attn.py   +5  -49
awq/utils/utils.py         +12  -0
awq/models/llama.py
 from .base import BaseAWQForCausalLM
-from awq.modules import make_quant_norm, make_quant_attn, make_fused_mlp
+from awq.modules import make_quant_norm, make_fused_mlp
 from transformers.models.llama.modeling_llama import LlamaDecoderLayer, LlamaForCausalLM

 class LlamaAWQForCausalLM(BaseAWQForCausalLM):
...
@@ -7,10 +7,11 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM):
     max_new_tokens_key = "max_position_embeddings"

     @staticmethod
-    def fuse_layers(awq_model):
-        make_quant_attn(awq_model, awq_model.device)
-        make_quant_norm(awq_model)
-        make_fused_mlp(awq_model)
+    def fuse_layers(awq_model: BaseAWQForCausalLM):
+        fuser = LlamaFuser(awq_model)
+        fuser.fuse_attention()
+        make_quant_norm(awq_model) #fuser.fuse_rmsnorm()
+        make_fused_mlp(awq_model) #fuser.fuse_mlp()

     @staticmethod
     def get_model_layers(model: LlamaForCausalLM):
...
@@ -63,4 +64,68 @@ class LlamaAWQForCausalLM(BaseAWQForCausalLM):
             inp=input_feat['mlp.down_proj'],
         ))

-        return layers
\ No newline at end of file
+        return layers
+
+import torch
+from typing import List, Tuple
+from awq.quantize.qmodule import WQLinear
+from awq.utils.utils import set_module_name
+from awq.modules.fused_attn import QuantLlamaAttention
+from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm
+
+class LlamaFuser:
+    def __init__(self, awq_model: BaseAWQForCausalLM):
+        self.awq_model = awq_model
+        self.model = awq_model.model
+
+        self.attention_modules: List[Tuple[str, LlamaAttention]] = [
+            (name, module) for name, module in self.model.named_modules()
+            if isinstance(module, LlamaAttention)
+        ]
+
+        self.rmsnorm_modules: List[Tuple[str, LlamaRMSNorm]] = [
+            (name, module) for name, module in self.model.named_modules()
+            if isinstance(module, LlamaRMSNorm)
+        ]
+
+    def fuse_attention(self):
+        for name, module in self.attention_modules:
+            qkv_layer: WQLinear = self._fuse_qkv(module)
+            attn = QuantLlamaAttention(
+                module.hidden_size,
+                module.num_heads,
+                qkv_layer,
+                module.o_proj,
+                qkv_layer.qweight.device,
+                self.awq_model.model.config.max_new_tokens
+            )
+            set_module_name(self.model, name, attn)
+
+    def _fuse_qkv(self, module: LlamaAttention):
+        # get qkv and bias
+        q_proj, k_proj, v_proj = module.q_proj, module.k_proj, module.v_proj
+        bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
+
+        # create module
+        qkv_layer = WQLinear(
+            q_proj.w_bit,
+            q_proj.group_size,
+            q_proj.in_features,
+            q_proj.out_features + k_proj.out_features + v_proj.out_features,
+            q_proj.bias is not None,
+            q_proj.qweight.device
+        )
+
+        # replace buffers with real weights
+        qkv_layer.qweight = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
+        qkv_layer.qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
+        qkv_layer.scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
+        qkv_layer.bias = bias
+
+        return qkv_layer
+
+    def fuse_rmsnorm(self):
+        pass
+
+    def fuse_mlp(self):
+        pass
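Illustration (not part of this commit): `_fuse_qkv` can fuse the three projections by simple concatenation because a linear layer applied to the same input can be stacked along its output dimension without changing the result. The minimal sketch below demonstrates this with plain float nn.Linear layers standing in for WQLinear; in the packed quantized representation the same idea is what justifies concatenating qweight, qzeros, and scales along dim=1 and bias along dim=0.

import torch
import torch.nn as nn

hidden_size = 8
x = torch.randn(2, 4, hidden_size)  # (batch, seq_len, hidden)

q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
v_proj = nn.Linear(hidden_size, hidden_size, bias=False)

# Fused projection: nn.Linear stores weight as (out_features, in_features),
# so stacking along dim=0 grows the output dimension.
qkv_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=False)
with torch.no_grad():
    qkv_proj.weight.copy_(torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0))

separate = torch.cat([q_proj(x), k_proj(x), v_proj(x)], dim=-1)
fused = qkv_proj(x)
assert torch.allclose(separate, fused, atol=1e-6)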
awq/modules/fused_attn.py
-import math
 import torch
 import torch.nn as nn
-from torch.nn import functional as F
-from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRotaryEmbedding, apply_rotary_pos_emb
-from awq.quantize.qmodule import WQLinear
 import awq_inference_engine
+from torch.nn import functional as F

 class QuantLlamaRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
...
@@ -64,7 +59,8 @@ class QuantLlamaAttention(nn.Module):
         num_heads,
         qkv_proj,
         o_proj,
-        dev
+        dev,
+        max_new_tokens
     ):
         super().__init__()
         self.hidden_size = hidden_size
...
@@ -76,7 +72,7 @@ class QuantLlamaAttention(nn.Module):
                 f" and `num_heads`: {num_heads})."
             )
         self.qkv_proj = qkv_proj
         self.o_proj = o_proj
-        self.rotary_emb = QuantLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=2048, device=dev)
+        self.rotary_emb = QuantLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=max_new_tokens, device=dev)

     def forward(self, hidden_states, past_key_value=None, attention_mask=None, position_ids=None, output_attentions=False, use_cache=False):
         """Input shape: Batch x Time x Channel"""
...
@@ -101,7 +97,7 @@ class QuantLlamaAttention(nn.Module):
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]

-        value_states = value_states.to("cuda:0")
+        value_states = value_states.to(key_states.device)

         if past_key_value is not None:
             # reuse k, v, self_attention
...
@@ -125,43 +121,3 @@ class QuantLlamaAttention(nn.Module):
         attn_output = self.o_proj(attn_output)

         return attn_output, None, past_key_value
-
-
-def make_quant_attn(model, dev):
-    """
-    Replace all LlamaAttention modules with QuantLlamaAttention modules, fusing the q, k, v projections.
-    """
-    for name, m in model.named_modules():
-        if not isinstance(m, LlamaAttention):
-            continue
-
-        q_proj = m.q_proj
-        k_proj = m.k_proj
-        v_proj = m.v_proj
-
-        qweights = torch.cat([q_proj.qweight, k_proj.qweight, v_proj.qweight], dim=1)
-        qzeros = torch.cat([q_proj.qzeros, k_proj.qzeros, v_proj.qzeros], dim=1)
-        scales = torch.cat([q_proj.scales, k_proj.scales, v_proj.scales], dim=1)
-        g_idx = None
-        bias = torch.cat([q_proj.bias, k_proj.bias, v_proj.bias], dim=0) if q_proj.bias is not None else None
-
-        qkv_layer = WQLinear(q_proj.w_bit, q_proj.group_size, q_proj.in_features, q_proj.out_features + k_proj.out_features + v_proj.out_features, q_proj.bias is not None, q_proj.qweight.device)
-        qkv_layer.qweight = qweights
-        qkv_layer.qzeros = qzeros
-        qkv_layer.scales = scales
-        qkv_layer.bias = bias
-
-        attn = QuantLlamaAttention(m.hidden_size, m.num_heads, qkv_layer, m.o_proj, dev)
-
-        if '.' in name:
-            parent_name = name.rsplit('.', 1)[0]
-            child_name = name[len(parent_name) + 1:]
-            parent = model.get_submodule(parent_name)
-        else:
-            parent_name = ''
-            parent = model
-            child_name = name
-
-        setattr(parent, child_name, attn)
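Aside (a hedged sketch, not code from this repository): the change from max_position_embeddings=2048 to max_position_embeddings=max_new_tokens matters because rotary-embedding modules of this kind precompute cos/sin tables up to a fixed length at construction time. A toy version of such a cache is shown below; sizing it from the model config (via max_new_tokens_key = "max_position_embeddings" in llama.py) rather than a hardcoded 2048 keeps longer-context Llama variants from indexing past the table. The helper name build_rotary_cache is illustrative only.

import torch

def build_rotary_cache(dim: int, max_seq_len: int, base: float = 10000.0):
    # Standard rotary-embedding frequencies, cached once for all positions.
    inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
    t = torch.arange(max_seq_len).float()
    freqs = torch.outer(t, inv_freq)         # (max_seq_len, dim // 2)
    emb = torch.cat([freqs, freqs], dim=-1)  # (max_seq_len, dim)
    return emb.cos(), emb.sin()

cos, sin = build_rotary_cache(dim=128, max_seq_len=4096)
print(cos.shape)  # torch.Size([4096, 128]); positions beyond max_seq_len would be out of range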
awq/utils/utils.py
...
@@ -41,3 +41,15 @@ def simple_dispatch_model(model, device_map):
     model.hf_device_map = device_map

     return model
+
+def set_module_name(model, name, value):
+    if '.' in name:
+        parent_name = name.rsplit('.', 1)[0]
+        child_name = name[len(parent_name) + 1:]
+        parent = model.get_submodule(parent_name)
+    else:
+        parent_name = ''
+        parent = model
+        child_name = name
+
+    setattr(parent, child_name, value)
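Usage sketch (toy example, not from the commit): set_module_name generalizes the parent/child setattr replacement that make_quant_attn previously did inline, so LlamaFuser.fuse_attention can swap any nested submodule by its dotted name.

import torch.nn as nn
from awq.utils.utils import set_module_name

model = nn.Sequential(nn.Linear(4, 4), nn.Sequential(nn.Linear(4, 4), nn.ReLU()))
# named_modules() yields "0", "1", "1.0", "1.1"; replace the nested ReLU by name.
set_module_name(model, "1.1", nn.GELU())
assert isinstance(model[1][1], nn.GELU)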