AutoAWQ commit b7290868
Authored Sep 06, 2023 by Casper Hansen
Parent: 85430ddc

Use hf_rotary per default
Showing 2 changed files with 31 additions and 12 deletions:
  awq/entry.py               +2  -5
  awq/modules/fused_attn.py  +29 -7

awq/entry.py
@@ -122,9 +122,9 @@ def run_speed(model_path, quant_file, device, n_generate=128, n_context=256, bat
     # Prints
     memory_used = torch.cuda.max_memory_allocated(device) / (1024 ** 2)
     context_tokens_per_second = n_context / context_time * batch_size
-    context_ms_per_token = (context_time * 1000) / n_context * batch_size
+    context_ms_per_token = (context_time * 1000) / n_context / batch_size
     inference_tokens_per_second = n_generate / generation_time * batch_size
-    inference_ms_per_token = (generation_time * 1000) / n_generate * batch_size
+    inference_ms_per_token = (generation_time * 1000) / n_generate / batch_size
     print(f"[======] Model summary: {model_path} [======]")
     print(f"[*] Load time: {load_time:.2f} seconds")
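The hunk above fixes the per-token latency figures: with batch_size sequences processed in parallel, the run produces n_generate * batch_size tokens in generation_time seconds, so tokens-per-second scales with batch_size while milliseconds-per-token must be divided by it. A minimal sketch of the corrected arithmetic, with made-up numbers for illustration (variable names follow the hunk):

    # Hypothetical values: 128 generated tokens per sequence, 8 sequences, 4.0 s of generation.
    n_generate, batch_size, generation_time = 128, 8, 4.0

    tokens_per_second = n_generate / generation_time * batch_size      # 256.0 tok/s across the batch
    ms_per_token = (generation_time * 1000) / n_generate / batch_size  # ~3.91 ms per generated token

    # Sanity check: the two figures are reciprocals (up to the ms/s conversion),
    # which only holds with the corrected "/ batch_size".
    assert abs(tokens_per_second * ms_per_token - 1000) < 1e-6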
@@ -185,9 +185,6 @@ if __name__ == '__main__':
         run_eval(args.model_path, args.quant_file, args.device,
                  args.tasks, args.task_batch_size, args.task_n_shot, args.task_use_pretrained)
     elif args.entry_type == 'speed':
-        if args.batch_size > 1 and not args.disable_fused_layers:
-            raise Exception('Fused layers only support batch_size=1. Pass --disable_fused_layers to run batch_size>1 (much slower).')
-
         run_speed(args.model_path, args.quant_file, args.device, args.n_generate, args.n_context, args.batch_size, args.disable_fused_layers)
     else:
         raise Exception('--entry_type must be one of (search|quant|eval|speed)')
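With the batch-size guard removed, the speed entry point no longer rejects batch_size > 1 when fused layers are enabled. A hedged sketch of invoking the benchmark directly from Python after this change (the import path is assumed from the file location, the model and checkpoint names are hypothetical, and the argument order mirrors the run_speed call above):

    from awq.entry import run_speed  # assumed import path for awq/entry.py

    run_speed(
        "some-org/llama-7b-awq",  # hypothetical model_path
        "awq_model_w4_g128.pt",   # hypothetical quant_file
        "cuda:0",                 # device
        128,                      # n_generate
        256,                      # n_context
        8,                        # batch_size > 1, now allowed with fused layers
        False,                    # disable_fused_layers
    )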
awq/modules/fused_attn.py
@@ -2,6 +2,7 @@ import torch
 import torch.nn as nn
 import awq_inference_engine
 from torch.nn import functional as F
+from transformers.models.llama.modeling_llama import apply_rotary_pos_emb, LlamaRotaryEmbedding
 
 class QuantLlamaRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
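QuantLlamaRotaryEmbedding shares its constructor signature with Hugging Face's LlamaRotaryEmbedding. Its body is not part of this diff, but a rotary module with this signature conventionally precomputes inverse frequencies and a cos/sin cache up to max_position_embeddings. A minimal sketch of that convention (an assumption about the standard pattern, not this repository's code):

    import torch
    import torch.nn as nn

    class RotaryCacheSketch(nn.Module):
        """Illustrative only: the cos/sin cache a rotary-embedding module typically precomputes."""
        def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
            super().__init__()
            # Inverse frequency for each channel pair: base^(-2i/dim).
            inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2, device=device).float() / dim))
            t = torch.arange(max_position_embeddings, device=device).float()
            freqs = torch.outer(t, inv_freq)         # (seq_len, dim/2)
            emb = torch.cat((freqs, freqs), dim=-1)  # (seq_len, dim)
            self.register_buffer("cos_cached", emb.cos(), persistent=False)
            self.register_buffer("sin_cached", emb.sin(), persistent=False)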
@@ -41,6 +42,7 @@ class QuantLlamaRotaryEmbedding(nn.Module):
         # to the attention op.
         query = query.contiguous()
         key = key.contiguous()
         awq_inference_engine.rotary_embedding_neox(
             positions,
             query,
@@ -60,19 +62,25 @@ class QuantLlamaAttention(nn.Module):
...
@@ -60,19 +62,25 @@ class QuantLlamaAttention(nn.Module):
qkv_proj
,
qkv_proj
,
o_proj
,
o_proj
,
dev
,
dev
,
max_new_tokens
max_new_tokens
,
use_hf_rotary
=
True
):
):
super
().
__init__
()
super
().
__init__
()
self
.
hidden_size
=
hidden_size
self
.
hidden_size
=
hidden_size
self
.
num_heads
=
num_heads
self
.
num_heads
=
num_heads
self
.
head_dim
=
hidden_size
//
num_heads
self
.
head_dim
=
hidden_size
//
num_heads
self
.
use_hf_rotary
=
use_hf_rotary
if
(
self
.
head_dim
*
num_heads
)
!=
self
.
hidden_size
:
if
(
self
.
head_dim
*
num_heads
)
!=
self
.
hidden_size
:
raise
ValueError
(
f
"hidden_size must be divisible by num_heads (got `hidden_size`:
{
self
.
hidden_size
}
"
raise
ValueError
(
f
"hidden_size must be divisible by num_heads (got `hidden_size`:
{
self
.
hidden_size
}
"
f
" and `num_heads`:
{
num_heads
}
)."
)
f
" and `num_heads`:
{
num_heads
}
)."
)
self
.
qkv_proj
=
qkv_proj
self
.
qkv_proj
=
qkv_proj
self
.
o_proj
=
o_proj
self
.
o_proj
=
o_proj
self
.
rotary_emb
=
QuantLlamaRotaryEmbedding
(
self
.
head_dim
,
max_position_embeddings
=
max_new_tokens
,
device
=
dev
)
if
use_hf_rotary
:
self
.
rotary_emb
=
LlamaRotaryEmbedding
(
self
.
head_dim
,
max_new_tokens
,
device
=
dev
)
else
:
self
.
rotary_emb
=
QuantLlamaRotaryEmbedding
(
self
.
head_dim
,
max_position_embeddings
=
max_new_tokens
,
device
=
dev
)
def
forward
(
self
,
hidden_states
,
past_key_value
=
None
,
attention_mask
=
None
,
position_ids
=
None
,
output_attentions
=
False
,
use_cache
=
False
):
def
forward
(
self
,
hidden_states
,
past_key_value
=
None
,
attention_mask
=
None
,
position_ids
=
None
,
output_attentions
=
False
,
use_cache
=
False
):
"""Input shape: Batch x Time x Channel"""
"""Input shape: Batch x Time x Channel"""
@@ -84,13 +92,27 @@ class QuantLlamaAttention(nn.Module):
         # This updates the query and key states in-place, saving VRAM.
         query_states, key_states, value_states = torch.split(qkv_states, 1, dim=2)
-        query_states, key_states = self.rotary_emb(query_states, key_states, position_ids)
+
+        if self.use_hf_rotary:
+            query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+
+            kv_seq_len = key_states.shape[-2]
+            if past_key_value is not None:
+                kv_seq_len += past_key_value[0].shape[-2]
+
+            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
+            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+        else:
+            query_states, key_states = self.rotary_emb(query_states, key_states, position_ids)
+            query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 
         del qkv_states
-        query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-        value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 
         is_causal = past_key_value is None
         kv_seq_len = q_len
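The HF branch reshapes to (batch, heads, seq, head_dim) before rotating because apply_rotary_pos_emb expects that layout and works out of place from a cos/sin table, whereas the fused kernel rotates the flat query/key buffers in place and the reshape happens afterwards. A minimal pure-PyTorch sketch of what that out-of-place rotation amounts to; the rotate_half convention follows the transformers implementation this commit imports, but the cos/sin table layout and all shapes here are illustrative assumptions, not this repository's code:

    import torch

    def rotate_half(x):
        # Swap and negate the two halves of the last dimension: (x1, x2) -> (-x2, x1).
        x1, x2 = x.chunk(2, dim=-1)
        return torch.cat((-x2, x1), dim=-1)

    def apply_rotary_sketch(q, k, cos, sin, position_ids):
        # q, k: (bsz, num_heads, q_len, head_dim); cos/sin: (seq_len, head_dim) lookup tables.
        cos = cos[position_ids].unsqueeze(1)  # -> (bsz, 1, q_len, head_dim), broadcast over heads
        sin = sin[position_ids].unsqueeze(1)
        q_rot = q * cos + rotate_half(q) * sin
        k_rot = k * cos + rotate_half(k) * sin
        return q_rot, k_rot

    # Tiny smoke test with made-up shapes.
    bsz, heads, q_len, head_dim = 1, 2, 4, 8
    q = torch.randn(bsz, heads, q_len, head_dim)
    k = torch.randn(bsz, heads, q_len, head_dim)
    t = torch.arange(q_len).float()
    inv_freq = 1.0 / (10000 ** (torch.arange(0, head_dim, 2).float() / head_dim))
    emb = torch.cat([torch.outer(t, inv_freq)] * 2, dim=-1)  # (q_len, head_dim)
    cos, sin = emb.cos(), emb.sin()
    position_ids = torch.arange(q_len).unsqueeze(0)          # (bsz, q_len)
    q_rot, k_rot = apply_rotary_sketch(q, k, cos, sin, position_ids)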