OpenDAS / AutoAWQ

Commit 69733d2c, authored Oct 06, 2023 by Casper Hansen
Parent: 306de683

Create RoPE module
Showing 1 changed file with 44 additions and 32 deletions.

awq/modules/fused/attn.py (+44, -32)
@@ -12,31 +12,45 @@ try:
 except:
     FT_INSTALLED = False
 
-def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0):
-    freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
-    t = torch.arange(end, device=freqs.device)  # type: ignore
-    freqs = torch.outer(t, freqs).float()  # type: ignore
-    freqs_cis = torch.polar(torch.ones_like(freqs), freqs)  # complex64
-    return freqs_cis
-
-def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
-    ndim = x.ndim
-    assert 0 <= 1 < ndim
-    assert freqs_cis.shape == (x.shape[1], x.shape[-1])
-    shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
-    return freqs_cis.view(*shape)
-
-def apply_rotary_emb(
-    xq: torch.Tensor,
-    xk: torch.Tensor,
-    freqs_cis: torch.Tensor,
-):
-    xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous())
-    xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous())
-    freqs_cis = reshape_for_broadcast(freqs_cis, xq_).to(xq_.device)
-    xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
-    xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
-    return xq_out.type_as(xq), xk_out.type_as(xk)
+class RoPE(nn.Module):
+    def __init__(self, hidden_size, n_heads, max_seq_len, device):
+        super(RoPE, self).__init__()
+        self.freqs_cis = nn.Parameter(
+            self.precompute_freqs_cis(hidden_size // n_heads, max_seq_len * 2).to(device),
+            requires_grad=False
+        )
+
+    @staticmethod
+    def precompute_freqs_cis(dim: int, end: int, theta=10000.0):
+        freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim))
+        t = torch.arange(end)
+        freqs = torch.outer(t, freqs).float()
+        freqs_cis = torch.polar(torch.ones_like(freqs), freqs)
+        return freqs_cis
+
+    @staticmethod
+    def reshape_for_broadcast(freqs_cis: torch.Tensor, x: torch.Tensor):
+        ndim = x.ndim
+        assert 0 <= 1 < ndim
+        assert freqs_cis.shape == (x.shape[1], x.shape[-1])
+        shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)]
+        return freqs_cis.view(*shape)
+
+    def forward(self, xq: torch.Tensor, xk: torch.Tensor, start_pos: int, seqlen: int):
+        xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], 2, -1).transpose(-2, -1).contiguous())
+        xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], 2, -1).transpose(-2, -1).contiguous())
+        freqs_cis = self.freqs_cis[start_pos : start_pos + seqlen]
+        freqs_cis = self.reshape_for_broadcast(freqs_cis, xq_).to(xq_.device)
+        xq_out = torch.view_as_real(xq_ * freqs_cis).transpose(-2, -1).flatten(3)
+        xk_out = torch.view_as_real(xk_ * freqs_cis).transpose(-2, -1).flatten(3)
+        return xq_out.type_as(xq), xk_out.type_as(xk)
 
 class ALiBi(nn.Module):
     def __init__(self, n_heads, max_seq_len, device, alibi_bias_max=8):
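For orientation, here is a minimal usage sketch of the RoPE module introduced above. It is not part of the commit: the sizes are made up, and it assumes the RoPE class from this revision of awq/modules/fused/attn.py is importable. The query/key tensors are assumed to be shaped [batch, seqlen, n_heads, head_dim] with an even head_dim, which is the layout reshape_for_broadcast expects, and start_pos selects the slice of the precomputed freqs_cis table matching the current position in the sequence.

import torch
from awq.modules.fused.attn import RoPE  # assumes this revision of AutoAWQ is installed

hidden_size, n_heads, max_seq_len = 128, 4, 64   # hypothetical model sizes
head_dim = hidden_size // n_heads                # 32; must be even for the complex view

rope = RoPE(hidden_size, n_heads, max_seq_len, device="cpu")

# Rotate an 8-token prompt starting at position 0.
xq = torch.randn(1, 8, n_heads, head_dim)        # [batch, seqlen, n_heads, head_dim]
xk = torch.randn(1, 8, n_heads, head_dim)
xq_rot, xk_rot = rope.forward(xq, xk, start_pos=0, seqlen=8)

assert xq_rot.shape == xq.shape and xk_rot.shape == xk.shape

Storing freqs_cis as an nn.Parameter with requires_grad=False, instead of a loose tensor moved with .to(dev) as before, keeps the table attached to the module: it follows device moves and appears in named_parameters() while staying frozen.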
@@ -101,12 +115,9 @@ class QuantAttentionFused(nn.Module):
             self.rotary_dim = 0
             self.is_neox = False
         else:
-            self.freqs_cis = precompute_freqs_cis(
-                hidden_size // n_heads,
-                max_seq_len * 2,
-            ).to(dev)
             self.alibi = None
+            self.rope = RoPE(hidden_size, n_heads, max_seq_len, dev)
             self.rotary_dim = self.head_dim
             self.alibi_slopes = None
             self.is_neox = True
 
     def forward(self, hidden_states: torch.Tensor, attention_mask=None, *args, **kwargs):
@@ -134,7 +145,7 @@ class QuantAttentionFused(nn.Module):
         xv = xv.view((bsz, seqlen) + self.attention_shapes["xv_view"])
 
         if not self.use_alibi:
-            xq, xk = apply_rotary_emb(xq, xk, freqs_cis=self.freqs_cis[self.start_pos : self.start_pos + seqlen])
+            xq, xk = self.rope.forward(xq, xk, self.start_pos, seqlen)
 
         self.cache.to(xq)
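With this change the call site no longer slices freqs_cis itself; it just passes the running position and the number of new tokens, and the slicing happens inside RoPE.forward. A small sketch of a later decode step under the same assumptions as the sketch above (made-up sizes, one new token arriving at position 8):

import torch
from awq.modules.fused.attn import RoPE  # same assumption as above

rope = RoPE(128, 4, 64, device="cpu")            # hidden_size, n_heads, max_seq_len
xq_step = torch.randn(1, 1, 4, 32)               # one new token: [batch, 1, n_heads, head_dim]
xk_step = torch.randn(1, 1, 4, 32)

# start_pos advances with the KV cache; seqlen is 1 while decoding token by token.
xq_rot, xk_rot = rope.forward(xq_step, xk_step, start_pos=8, seqlen=1)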
@@ -176,6 +187,7 @@ class QuantAttentionFused(nn.Module):
             xk = xk.view((bsz,) + self.attention_shapes["single_xk_view"])
             xv = xv.view((bsz,) + self.attention_shapes["single_xv_view"])
 
+            alibi_slopes = self.alibi.slopes if self.alibi is not None else None
             attention_weight = ft_inference_engine.single_query_attention(
                 xq,  # query
                 xk,  # key
@@ -183,7 +195,7 @@ class QuantAttentionFused(nn.Module):
                 self.cache.k,  # key cache
                 self.cache.v,  # value cache
                 None,  # length per sample
-                self.alibi.slopes,  # alibi slopes
+                alibi_slopes,  # alibi slopes
                 self.start_pos,  # timestep
                 self.rotary_dim,  # rotary embedding dimension
                 10000,  # rotary embedding base
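The last two hunks belong together: in the non-ALiBi (RoPE) branch self.alibi is None, so passing self.alibi.slopes straight to the kernel would raise AttributeError as soon as the single-query path above runs; the new alibi_slopes local guards against that. A standalone sketch of the pattern, with a plain variable standing in for self.alibi:

alibi = None                                                  # stand-in for self.alibi in the RoPE branch
# alibi.slopes                                                # old argument: AttributeError when alibi is None
alibi_slopes = alibi.slopes if alibi is not None else None    # guarded form introduced by this commit
print(alibi_slopes)                                           # None is forwarded to the kernel instead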