OpenDAS / AutoAWQ · Commits
Commit e71181bd
Authored Sep 06, 2023 by Casper Hansen

Update shapes for cuda kernel

Parent: 88964968
Showing 1 changed file with 19 additions and 5 deletions.
awq/modules/fused_attn.py (+19, -5)
@@ -38,18 +38,28 @@ class QuantLlamaRotaryEmbedding(nn.Module):
         key: torch.Tensor,
         positions: torch.Tensor,
     ):
+        batch_size, seq_len, _ = query.shape
+        query = query.view(batch_size * seq_len, -1)
+        key = key.view(batch_size * seq_len, -1)
+        positions = positions.view(-1).to(query.device)
+
         # Apply rotary embedding to the query and key before passing them
         # to the attention op.
         query = query.contiguous()
         key = key.contiguous()
-        awq_inference_engine.rotary_embedding_neox(
+        awq_inference_engine.rotary_embedding(
             positions,
             query,
             key,
             self.dim,
             self.cos_sin_cache,
+            True # is_neox
         )
+
+        query = query.view(batch_size, seq_len, -1)
+        key = key.view(batch_size, seq_len, -1)
+
         return query, key
 
 
 class QuantLlamaAttention(nn.Module):
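For orientation, below is a minimal, self-contained sketch of the shape handling this hunk introduces around the fused rotary kernel. It is not the repository's code: the CUDA call is replaced by a hypothetical fake_rotary_embedding stub that only checks the shapes the real awq_inference_engine.rotary_embedding kernel would receive, and the tensor sizes are made up.

import torch

# Hypothetical stand-in for awq_inference_engine.rotary_embedding: the real CUDA
# kernel rotates query/key in place; this stub only checks the shapes it is given.
def fake_rotary_embedding(positions, query, key, head_dim, cos_sin_cache, is_neox):
    assert positions.dim() == 1                 # (batch_size * seq_len,)
    assert query.dim() == 2 and key.dim() == 2  # (batch_size * seq_len, n_heads * head_dim)
    assert query.shape[0] == positions.shape[0]

batch_size, seq_len, n_heads, head_dim = 2, 5, 4, 8
query = torch.randn(batch_size, seq_len, n_heads * head_dim)
key = torch.randn(batch_size, seq_len, n_heads * head_dim)
positions = torch.arange(seq_len).repeat(batch_size, 1)

# Flatten batch and sequence dims, as the updated forward() now does before the kernel.
query = query.view(batch_size * seq_len, -1).contiguous()
key = key.view(batch_size * seq_len, -1).contiguous()
positions = positions.view(-1).to(query.device)

fake_rotary_embedding(positions, query, key, head_dim, None, True)

# Restore the (batch_size, seq_len, hidden) layout expected by the caller.
query = query.view(batch_size, seq_len, -1)
key = key.view(batch_size, seq_len, -1)
print(query.shape, key.shape)  # torch.Size([2, 5, 32]) torch.Size([2, 5, 32])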
@@ -88,12 +98,13 @@ class QuantLlamaAttention(nn.Module):
         bsz, q_len, _ = hidden_states.size()
 
         qkv_states = self.qkv_proj(hidden_states)
-        qkv_states = qkv_states.view(bsz, q_len, 3, self.num_heads, self.head_dim)
-
-        # This updates the query and key states in-place, saving VRAM.
-        query_states, key_states, value_states = torch.split(qkv_states, 1, dim=2)
 
         if self.use_hf_rotary:
+            qkv_states = qkv_states.view(bsz, q_len, 3, self.num_heads, self.head_dim)
+
+            # This updates the query and key states in-place, saving VRAM.
+            query_states, key_states, value_states = torch.split(qkv_states, 1, dim=2)
+
             query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
             key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
             value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
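A rough illustration of the QKV split on the use_hf_rotary branch that this hunk moves inside the if. The sizes are made up and the qkv_proj output is simulated with a random tensor; only the shape transitions are the point.

import torch

bsz, q_len, num_heads, head_dim = 2, 5, 4, 8

# Simulated qkv_proj output: (bsz, q_len, 3 * num_heads * head_dim).
qkv_states = torch.randn(bsz, q_len, 3 * num_heads * head_dim)

qkv_states = qkv_states.view(bsz, q_len, 3, num_heads, head_dim)
# torch.split returns views into qkv_states, so no extra memory is allocated.
query_states, key_states, value_states = torch.split(qkv_states, 1, dim=2)
print(query_states.shape)  # torch.Size([2, 5, 1, 4, 8])

# Each tensor is then laid out as (bsz, num_heads, q_len, head_dim) for attention.
query_states = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
print(query_states.shape)  # torch.Size([2, 4, 5, 8])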
@@ -106,10 +117,13 @@ class QuantLlamaAttention(nn.Module):
             query_states, key_states = apply_rotary_pos_emb(query_states, key_states, cos, sin, position_ids)
         else:
+            query_states, key_states, value_states = qkv_states.chunk(chunks=3, dim=-1)
+
             query_states, key_states = self.rotary_emb(query_states, key_states, position_ids)
+
             query_states = query_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
             key_states = key_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
             value_states = value_states.view(bsz, q_len, self.num_heads, self.head_dim).transpose(1, 2)
 
         del qkv_states
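And a similar sketch, again with made-up sizes, of the non-HF branch added here: chunking along the last dimension keeps query/key/value as flat (bsz, q_len, num_heads * head_dim) tensors, which is the layout the updated QuantLlamaRotaryEmbedding.forward above flattens further before calling the CUDA kernel. The rotary step itself is omitted from the sketch.

import torch

bsz, q_len, num_heads, head_dim = 2, 5, 4, 8
qkv_states = torch.randn(bsz, q_len, 3 * num_heads * head_dim)

# Fused-rotary path: keep each of q/k/v flat as (bsz, q_len, num_heads * head_dim).
query_states, key_states, value_states = qkv_states.chunk(chunks=3, dim=-1)
print(query_states.shape)  # torch.Size([2, 5, 32])

# self.rotary_emb(query_states, key_states, position_ids) would run here on the
# flat tensors; afterwards each is reshaped for the attention op.
query_states = query_states.view(bsz, q_len, num_heads, head_dim).transpose(1, 2)
print(query_states.shape)  # torch.Size([2, 4, 5, 8])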