OpenDAS / AutoAWQ / Commits

Commit 64e6b3e1, authored Sep 08, 2023 by Casper Hansen
Use apply_rotary_emb
Parent: a11c313a

Showing 1 changed file with 35 additions and 10 deletions.

awq/modules/fused/attn.py (+35, -10) @ 64e6b3e1
```diff
@@ -4,6 +4,36 @@ import torch.nn as nn
 import awq_inference_engine
 from torch.nn import functional as F
 
+
+def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
+    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+    t = torch.arange(end, device=freqs.device)  # type: ignore
+    freqs = torch.outer(t, freqs).float()  # type: ignore
+    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
+    return freqs_cis
+
+
+def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+    ndim = x.ndim
+    assert 0 <= 1 < ndim
+    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
+    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+    return freqs_cis.view(*shape)
+
+
+def apply_rotary_emb(
+    xq: torch.Tensor,
+    xk: torch.Tensor,
+    freqs_cis: torch.Tensor,
+):
+    xq_ = torch.view_as_complex(
+        xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
+    )
+    xk_ = torch.view_as_complex(
+        xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous()
+    )
+    freqs_cis = reshape_for_broadcast(freqs_cis, xq_)
+    xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
+    xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
+    return xq_out.type_as(xq), xk_out.type_as(xk)
+
+
 class QuantLlamaRotaryEmbedding(nn.Module):
     def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
```
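For orientation only (not part of the commit): a minimal sketch of how the new helpers compose, assuming the `precompute_freqs_cis` and `apply_rotary_emb` definitions from the hunk above are in scope. The batch, sequence, and head sizes below are made-up values chosen just to exercise the shapes.

```python
import torch

# Hypothetical sizes, purely for illustration.
bsz, seqlen, n_heads, head_dim = 2, 16, 8, 64

# One complex rotation per (position, rotated pair): shape (128, head_dim // 2).
freqs_cis = precompute_freqs_cis(head_dim, 128)

# Dummy query/key tensors in the (bsz, seqlen, n_heads, head_dim) layout used above.
xq = torch.randn(bsz, seqlen, n_heads, head_dim, dtype=torch.float16)
xk = torch.randn(bsz, seqlen, n_heads, head_dim, dtype=torch.float16)

# Rotate positions 0..seqlen-1; shape and dtype of the inputs are preserved.
xq_rot, xk_rot = apply_rotary_emb(xq, xk, freqs_cis=freqs_cis[:seqlen])
assert xq_rot.shape == xq.shape and xq_rot.dtype == torch.float16
```

The rotation is done in float32 complex arithmetic and cast back with `type_as`, so half-precision inputs round-trip cleanly.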
```diff
@@ -37,8 +67,6 @@ class QuantLlamaRotaryEmbedding(nn.Module):
         sin = freqs.sin()
         cache = torch.cat((cos, sin), dim=-1)
 
-        # self.register_buffer("cos_cached", emb.cos()[None, None, :, :].to(dtype), persistent=False)
-        # self.register_buffer("sin_cached", emb.sin()[None, None, :, :].to(dtype), persistent=False)
         self.register_buffer("cos_sin_cache", cache.half(), persistent=False)
 
     def forward(
```
```diff
@@ -61,7 +89,6 @@ class QuantLlamaRotaryEmbedding(nn.Module):
         )
         return query, key
 
 
 class QuantLlamaAttentionFused(nn.Module):
     def __init__(self, hidden_size, num_heads, qkv_layer, o_proj, dev, max_position_embeddings):
         super().__init__()
```
```diff
@@ -85,12 +112,10 @@ class QuantLlamaAttentionFused(nn.Module):
                 (
                     1,
                     self.n_local_heads,
                     self.head_dim // 8,
                     max_position_embeddings,
                     8,
                 )
             ).to(dev).half()
         )
-        self.rotary_emb = QuantLlamaRotaryEmbedding(dim=hidden_size // num_heads, max_position_embeddings=max_position_embeddings, device=dev)
+        self.freqs_cis = precompute_freqs_cis(
+            hidden_size // num_heads,
+            max_position_embeddings * 2,
+        ).to(dev)
 
     def forward(
         self,
```
```diff
@@ -108,7 +133,7 @@ class QuantLlamaAttentionFused(nn.Module):
         xk = xk.view(bsz, seqlen, self.n_local_heads, self.head_dim)
         xv = xv.view(bsz, seqlen, self.n_local_heads, self.head_dim)
 
-        xq, xk = self.rotary_emb(xq, xk, position_ids)
+        xq, xk = apply_rotary_emb(xq, xk, freqs_cis=self.freqs_cis[self.start_pos : self.start_pos + seqlen])
 
         self.cache_k = self.cache_k.to(xq)
         self.cache_v = self.cache_v.to(xq)
```
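A note on the slice in the last hunk: `self.freqs_cis[self.start_pos : self.start_pos + seqlen]` selects the rotations for the absolute positions of the current chunk, which keeps rotary phases consistent across incremental decode steps. A small sketch of that indexing, again assuming `precompute_freqs_cis` from the first hunk is in scope and using hypothetical sizes:

```python
import torch

max_seq_len, head_dim = 64, 8
freqs_cis = precompute_freqs_cis(head_dim, max_seq_len)

# A 10-token prompt followed by two single-token decode steps.
start_pos = 0
for chunk_len in (10, 1, 1):
    chunk = freqs_cis[start_pos : start_pos + chunk_len]
    # Each row is the complex rotation for one absolute position.
    assert chunk.shape == (chunk_len, head_dim // 2)
    start_pos += chunk_len
```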