OpenDAS / AutoAWQ · Commits

Commit 8bb8fe20, authored Sep 05, 2023 by Casper Hansen
Attempt new module
Parent: 0f699cf9

Showing 3 changed files with 86 additions and 4 deletions:

    awq/entry.py               +2  -2
    awq/models/llama.py        +2  -2
    awq/modules/fused_attn.py  +82 -0
awq/entry.py

@@ -185,8 +185,8 @@ if __name__ == '__main__':
         run_eval(args.model_path, args.quant_file, args.device,
                  args.tasks, args.task_batch_size, args.task_n_shot, args.task_use_pretrained)
     elif args.entry_type == 'speed':
-        if args.batch_size > 1 and not args.disable_fused_layers:
-            raise Exception('Fused layers only support batch_size=1. Pass --disable_fused_layers to run batch_size>1 (much slower).')
+        # if args.batch_size > 1 and not args.disable_fused_layers:
+            # raise Exception('Fused layers only support batch_size=1. Pass --disable_fused_layers to run batch_size>1 (much slower).')
         run_speed(args.model_path, args.quant_file, args.device, args.n_generate, args.n_context, args.batch_size, args.disable_fused_layers)
     else:
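For context, the two commented-out lines above previously blocked a fused-layer speed run whenever batch_size > 1; with the guard disabled, such a run now falls through to run_speed. The sketch below is hypothetical: the argument order mirrors the call in entry.py, but the import location, model path, quant file, and sizes are placeholders I am assuming, not anything taken from this commit.

from awq.entry import run_speed  # assumption: run_speed is importable from awq.entry

run_speed(
    "models/llama-7b-awq",   # model_path (placeholder)
    "awq_model_w4_g128.pt",  # quant_file (placeholder)
    "cuda:0",                # device
    128,                     # n_generate
    512,                     # n_context
    8,                       # batch_size > 1, previously rejected unless fused layers were disabled
    False,                   # disable_fused_layers: keep fused layers enabled
)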
awq/models/llama.py

@@ -71,7 +71,7 @@ from awq.quantize.qmodule import WQLinear
 from awq.utils.utils import set_module_name
 from awq.modules.fused_mlp import QuantLlamaMLP
 from awq.modules.fused_norm import FTLlamaRMSNorm
-from awq.modules.fused_attn import QuantLlamaAttention
+from awq.modules.fused_attn import QuantLlamaAttention, CustomQuantLlamaAttention
 from transformers.models.llama.modeling_llama import LlamaAttention, LlamaRMSNorm, LlamaMLP
 
 class LlamaFuser:
@@ -96,7 +96,7 @@ class LlamaFuser:
     def fuse_attention(self):
         for name, module in self.attention_modules:
             qkv_layer: WQLinear = self._fuse_qkv(module)
-            attn = QuantLlamaAttention(
+            attn = CustomQuantLlamaAttention(
                 module.hidden_size,
                 module.num_heads,
                 qkv_layer,
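The fuser hands the new module a single fused qkv_layer built by _fuse_qkv, and the module later recovers q, k and v with torch.split(qkv_states, hidden_size, dim=2). Here is a minimal float-precision sketch of that fusion idea using plain nn.Linear layers; the real _fuse_qkv operates on quantized WQLinear modules, and the width below is illustrative only.

import torch
import torch.nn as nn

hidden_size = 512  # illustrative width, not a real model size

q_proj = nn.Linear(hidden_size, hidden_size, bias=False)
k_proj = nn.Linear(hidden_size, hidden_size, bias=False)
v_proj = nn.Linear(hidden_size, hidden_size, bias=False)

# Fuse: stack the three weight matrices so a single matmul produces q, k and v at once.
qkv_proj = nn.Linear(hidden_size, 3 * hidden_size, bias=False)
qkv_proj.weight.data = torch.cat([q_proj.weight, k_proj.weight, v_proj.weight], dim=0)

x = torch.randn(1, 8, hidden_size)
qkv_states = qkv_proj(x)

# The same split CustomQuantLlamaAttention.forward applies to the fused projection output.
q, k, v = torch.split(qkv_states, hidden_size, dim=2)
assert torch.allclose(q, q_proj(x), atol=1e-5)
assert torch.allclose(v, v_proj(x), atol=1e-5)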
awq/modules/fused_attn.py

@@ -121,3 +121,85 @@ class QuantLlamaAttention(nn.Module):
         attn_output = self.o_proj(attn_output)
 
         return attn_output, None, past_key_value
+
+
+class CustomQuantLlamaAttention(nn.Module):
+
+    def __init__(self, hidden_size, num_heads, qkv_proj, o_proj, dev, max_new_tokens):
+        super().__init__()
+        self.hidden_size = hidden_size
+        self.num_heads = num_heads
+        self.head_dim = hidden_size // num_heads
+
+        if (self.head_dim * num_heads) != self.hidden_size:
+            raise ValueError(
+                f"hidden_size must be divisible by num_heads (got `hidden_size`: {self.hidden_size}"
+                f" and `num_heads`: {num_heads})."
+            )
+        self.qkv_proj = qkv_proj
+        self.o_proj = o_proj
+        self.rotary_emb = QuantLlamaRotaryEmbedding(self.head_dim, max_position_embeddings=max_new_tokens, device=dev)
+
+    def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int):
+        return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous()
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask,
+        position_ids,
+        past_key_value,
+        output_attentions: bool = False,
+        use_cache: bool = False,
+    ):
+        # qkv proj
+        qkv_states = self.qkv_proj(hidden_states)
+
+        # extract q,k,v
+        bsz, q_len, _ = hidden_states.size()
+        query_states, key_states, value_states = torch.split(qkv_states, self.hidden_size, dim=2)
+        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+        value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+        # rotary embedding
+        query_states, key_states = self.rotary_emb(query_states, key_states, position_ids)
+
+        # cache ops
+        is_causal = past_key_value is None
+
+        if past_key_value is not None:
+            # reuse k, v, self_attention
+            key_states = torch.cat([past_key_value[0], key_states], dim=2)
+            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+
+        if use_cache:
+            # Since qkv_proj is fused, query_states etc. will hold a reference to the original qkv_states tensor,
+            # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to work around this.
+            query_states = query_states.contiguous()
+            key_states = key_states.contiguous()
+            value_states = value_states.contiguous()
+
+        past_key_value = (key_states, value_states) if use_cache else None
+
+        # multi-head masked attention
+        attn_output = F.scaled_dot_product_attention(
+            query_states,
+            key_states,
+            value_states,
+            attn_mask=None if is_causal else attention_mask,
+            is_causal=is_causal,
+        )
+
+        # reshape output
+        attn_output = attn_output.transpose(1, 2).contiguous()
+        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)
+
+        # out projection
+        attn_output = self.o_proj(attn_output)
+
+        return attn_output, None, past_key_value
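To make the tensor plumbing of the added forward pass concrete, the following standalone sketch reproduces its split / reshape / scaled_dot_product_attention flow on random tensors. It skips the rotary embedding and the quantized projections, and all sizes are invented, so it illustrates only the shape handling, not the actual module.

import torch
import torch.nn.functional as F

bsz, q_len = 1, 8
hidden_size, num_heads = 512, 8          # illustrative sizes
head_dim = hidden_size // num_heads

# Stand-in for the fused qkv projection output: (bsz, q_len, 3 * hidden_size).
qkv_states = torch.randn(bsz, q_len, 3 * hidden_size)
query_states, key_states, value_states = torch.split(qkv_states, hidden_size, dim=2)
query_states = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
key_states = key_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
value_states = value_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)

# First decoding step: no past_key_value yet, so the module uses a purely causal mask.
attn_output = F.scaled_dot_product_attention(
    query_states, key_states, value_states, attn_mask=None, is_causal=True
)

# Merge heads back into the hidden dimension, as the module does before o_proj.
attn_output = attn_output.transpose(1, 2).contiguous().reshape(bsz, q_len, hidden_size)
print(attn_output.shape)  # torch.Size([1, 8, 512])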