OpenDAS / AutoAWQ / Commits
Commit ab7d68e7, authored Sep 06, 2023 by Casper Hansen

Correct input sizes

Parent: a4626828
Showing 2 changed files with 44 additions and 36 deletions.

awq/models/llama.py (+1, -0)
awq/modules/fused_attn.py (+43, -36)
awq/models/llama.py (View file @ ab7d68e7)

@@ -99,6 +99,7 @@ class LlamaFuser:
         attn = QuantLlamaAttention(
             module.hidden_size,
             module.num_heads,
+            module.num_key_value_heads,
             qkv_layer,
             module.o_proj,
             qkv_layer.qweight.device,
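Note (added for context, not part of the commit): the new num_key_value_heads argument mirrors the grouped-query-attention fields on a Hugging Face LlamaConfig. A minimal, illustrative sketch of where these values come from, with assumed config values:

# Illustrative only: the GQA-related fields the fused module now needs, shown on a
# hand-constructed LlamaConfig (values assumed, roughly Llama-2-7B-like).
from transformers import LlamaConfig

config = LlamaConfig(hidden_size=4096, num_attention_heads=32, num_key_value_heads=32)
num_heads = config.num_attention_heads
num_kv_heads = config.num_key_value_heads   # smaller than num_heads for grouped-query models
head_dim = config.hidden_size // num_heads
print(num_heads, num_kv_heads, head_dim)    # 32 32 128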
awq/modules/fused_attn.py (View file @ ab7d68e7)

@@ -30,6 +30,7 @@ class QuantLlamaRotaryEmbedding(nn.Module):
         sin = freqs.sin()
         cache = torch.cat((cos, sin), dim=-1)
+        # [max_position, rot_dim]
         self.register_buffer("cos_sin_cache", cache.half(), persistent=False)

     def forward(
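Aside (illustrative, not taken from this file): the added comment documents the cache layout. A standalone sketch of how a [max_position, rot_dim] cos/sin cache of this shape is built, with assumed max_position, rot_dim, and base:

# Illustrative sketch of a [max_position, rot_dim] cos/sin cache; all sizes assumed.
import torch

max_position, rot_dim, base = 2048, 128, 10000
inv_freq = 1.0 / (base ** (torch.arange(0, rot_dim, 2).float() / rot_dim))
t = torch.arange(max_position).float()
freqs = torch.einsum("i,j->ij", t, inv_freq)   # [max_position, rot_dim // 2]
cos, sin = freqs.cos(), freqs.sin()
cache = torch.cat((cos, sin), dim=-1)          # [max_position, rot_dim]
assert cache.shape == (max_position, rot_dim)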
@@ -38,11 +39,6 @@ class QuantLlamaRotaryEmbedding(nn.Module):
         key: torch.Tensor,
         positions: torch.Tensor,
     ):
-        batch_size, seq_len, _ = query.shape
-        query = query.view(batch_size * seq_len, -1)
-        key = key.view(batch_size * seq_len, -1)
-        positions = positions.view(-1).to(query.device)
-
         # Apply rotary embedding to the query and key before passing them
         # to the attention op.
         query = query.contiguous()

@@ -57,9 +53,6 @@ class QuantLlamaRotaryEmbedding(nn.Module):
             True # is_neox
         )
-        query = query.view(batch_size, seq_len, -1)
-        key = key.view(batch_size, seq_len, -1)
-
         return query, key

 class QuantLlamaAttention(nn.Module):
@@ -69,15 +62,17 @@ class QuantLlamaAttention(nn.Module):
         self,
         hidden_size,
         num_heads,
+        num_kv_heads,
         qkv_proj,
         o_proj,
         dev,
         max_new_tokens,
-        use_hf_rotary=True
+        use_hf_rotary=False
     ):
         super().__init__()
         self.hidden_size = hidden_size
         self.num_heads = num_heads
+        self.num_kv_heads = num_kv_heads
         self.head_dim = hidden_size // num_heads
         self.use_hf_rotary = use_hf_rotary
@@ -100,32 +95,44 @@ class QuantLlamaAttention(nn.Module):
         qkv_states = self.qkv_proj(hidden_states)

         if self.use_hf_rotary:
+            # get qkv
             qkv_states = qkv_states.view(bsz, q_len, 3, self.num_heads, self.head_dim)
-            # This updates the query and key states in-place, saving VRAM.
-            query_states, key_states, value_states = torch.split(qkv_states, 1, dim=2)
+            query, key, value = torch.split(qkv_states, 1, dim=2)
+            del qkv_states

-            query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-            key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-            value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            # reshape for hf rotary
+            query = query.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            key = key.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            value = value.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

-            kv_seq_len = key_states.shape[-2]
+            kv_seq_len = key.shape[-2]
             if past_key_value is not None:
                 kv_seq_len += past_key_value[0].shape[-2]

-            cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-            query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
+            cos, sin = self.rotary_emb(value, seq_len=kv_seq_len)
+            query, key = apply_rotary_pos_emb(query, key, cos, sin, position_ids)
         else:
-            query_states, key_states, value_states = qkv_states.chunk(chunks=3, dim=-1)
-            query_states, key_states = self.rotary_emb(query_states, key_states, position_ids)
-
-            query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-            key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-            value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
-            del qkv_states
+            # get qkv
+            query, key, value = qkv_states.chunk(chunks=3, dim=-1)
+            del qkv_states
+
+            # [num_tokens, num_heads * head_size]
+            query_batch_size, query_len, _ = query.shape
+            query = query.view(query_len * query_batch_size, self.num_heads * self.head_dim)
+
+            # [num_tokens, num_kv_heads * head_size]
+            key_batch_size, key_len, _ = key.shape
+            key = key.view(key_len * key_batch_size, self.num_kv_heads * self.head_dim)
+
+            # [num_tokens]
+            positions = position_ids.view(-1).to(query.device)
+
+            query, key = self.rotary_emb(query, key, positions)
+
+            query = query.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            key = key.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
+            value = value.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)

         is_causal = past_key_value is None
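Aside (illustrative, not part of the commit): a quick shape walkthrough of the corrected non-HF rotary path above, using an assumed GQA-style configuration, to show the flattened sizes the rotary module now receives:

# Illustrative shape check only; config values are assumed (num_kv_heads < num_heads here
# to make the query/key width difference visible).
import torch

bsz, q_len, num_heads, num_kv_heads, head_dim = 2, 4, 32, 8, 128
query = torch.randn(bsz, q_len, num_heads * head_dim)
key = torch.randn(bsz, q_len, num_kv_heads * head_dim)
position_ids = torch.arange(q_len).repeat(bsz, 1)

query_batch_size, query_len, _ = query.shape
query = query.view(query_len * query_batch_size, num_heads * head_dim)   # [num_tokens, num_heads * head_size]
key_batch_size, key_len, _ = key.shape
key = key.view(key_len * key_batch_size, num_kv_heads * head_dim)        # [num_tokens, num_kv_heads * head_size]
positions = position_ids.view(-1)                                        # [num_tokens]
print(query.shape, key.shape, positions.shape)  # [8, 4096], [8, 1024], [8]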
@@ -133,25 +140,25 @@ class QuantLlamaAttention(nn.Module):
         if past_key_value is not None:
             kv_seq_len += past_key_value[0].shape[-2]

-        value_states = value_states.to(key_states.device)
+        value = value.to(key.device)

         if past_key_value is not None:
             # reuse k, v, self_attention
-            key_states = torch.cat([past_key_value[0], key_states], dim=2)
-            value_states = torch.cat([past_key_value[1], value_states], dim=2)
+            key = torch.cat([past_key_value[0], key], dim=2)
+            value = torch.cat([past_key_value[1], value], dim=2)

         if use_cache:
-            # Since qkv_proj is fused, query_states etc will hold a reference to the original qkv_states tensor
+            # Since qkv_proj is fused, query etc will hold a reference to the original qkv_states tensor
             # which can cause excessive memory usage by the cache. `contiguous` is a convenient way to workaround this.
-            key_states = key_states.contiguous()
-            value_states = value_states.contiguous()
-            query_states = query_states.contiguous()
+            key = key.contiguous()
+            value = value.contiguous()
+            query = query.contiguous()

-        past_key_value = (key_states, value_states) if use_cache else None
+        past_key_value = (key, value) if use_cache else None

         # with torch.backends.cuda.sdp_kernel(enable_math=False):
-        attn_output = F.scaled_dot_product_attention(query_states, key_states, value_states, is_causal=is_causal)
-        del query_states, key_states, value_states
+        attn_output = F.scaled_dot_product_attention(query, key, value, is_causal=is_causal)
+        del query, key, value

         attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size)
         attn_output = self.o_proj(attn_output)
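Aside (illustrative): F.scaled_dot_product_attention expects [bsz, num_heads, seq_len, head_dim] inputs, which is what the transpose(1, 2) reshapes above set up. A minimal self-contained sketch with assumed sizes:

# Minimal SDPA sketch matching the layout used above; sizes are assumed.
import torch
import torch.nn.functional as F

bsz, num_heads, q_len, head_dim = 1, 32, 8, 128
query = torch.randn(bsz, num_heads, q_len, head_dim)
key = torch.randn(bsz, num_heads, q_len, head_dim)
value = torch.randn(bsz, num_heads, q_len, head_dim)

attn_output = F.scaled_dot_product_attention(query, key, value, is_causal=True)
attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, num_heads * head_dim)
print(attn_output.shape)  # torch.Size([1, 8, 4096])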