Commit 71f674ae, authored Nov 17, 2022 by Tri Dao

[Rotary] Customize base, support seqlen_offset

parent d6ef701a
Showing 1 changed file with 12 additions and 7 deletions:

flash_attn/layers/rotary.py (+12, -7)
@@ -136,20 +136,20 @@ class RotaryEmbedding(torch.nn.Module):
     """
 
-    def __init__(self, dim_model: int, *_, **__):
+    def __init__(self, dim: int, base=10000, *_, **__):
         super().__init__()
         # Generate and save the inverse frequency buffer (non trainable)
-        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim_model, 2).float() / dim_model))
+        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))
         self.register_buffer("inv_freq", inv_freq)
 
         self._seq_len_cached = 0
         self._cos_cached = None
         self._sin_cached = None
 
-    def _update_cos_sin_cache(self, x):
+    def _update_cos_sin_cache(self, x, seqlen_offset=0):
         """x: (batch, seqlen, nheads, headdim) or (batch, seqlen, 3, nheads, headdim)
         """
-        seqlen = x.shape[1]
+        seqlen = x.shape[1] + seqlen_offset
         # Reset the tables if the sequence length has changed,
         # or if we're on a new device (possibly due to tracing for instance)
         if (seqlen > self._seq_len_cached or self._cos_cached.device != x.device
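The first hunk renames `dim_model` to `dim` and lifts the hard-coded 10000 into a `base` argument; the buffer still follows the standard RoPE inverse-frequency formula 1 / base^(2i/dim). Below is a minimal standalone sketch of just that computation; the helper name `rotary_inv_freq` is ours, for illustration only, not part of the repo:

```python
import torch

def rotary_inv_freq(dim: int, base: float = 10000.0) -> torch.Tensor:
    # 1 / base**(2i/dim) for i = 0 .. dim/2 - 1, matching the diff above.
    return 1.0 / (base ** (torch.arange(0, dim, 2).float() / dim))

# The default base reproduces the previously hard-coded behavior:
assert torch.equal(rotary_inv_freq(8), rotary_inv_freq(8, base=10000.0))

print(rotary_inv_freq(8))
# tensor([1.0000, 0.1000, 0.0100, 0.0010])
# A larger base shrinks the higher-index frequencies, i.e. longer wavelengths:
print(rotary_inv_freq(8, base=1e6))
# tensor([1.0000e+00, 3.1623e-02, 1.0000e-03, 3.1623e-05])
```

Making `base` configurable is presumably what the "Customize base" part of the commit title refers to: a larger base slows the per-position rotation, which is a common knob when adapting rotary embeddings to longer contexts.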
@@ -162,6 +162,11 @@ class RotaryEmbedding(torch.nn.Module):
             self._cos_cached = torch.cos(freqs).to(x.dtype)
             self._sin_cached = torch.sin(freqs).to(x.dtype)
 
-    def forward(self, qkv: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
-        self._update_cos_sin_cache(qkv)
-        return apply_rotary_emb_qkv_(qkv, self._cos_cached, self._sin_cached)
+    def forward(self, qkv: torch.Tensor, seqlen_offset: int = 0) -> Tuple[torch.Tensor, torch.Tensor]:
+        """
+        seqlen_offset: can be used in generation where the qkv being passed in is only the last
+        token in the batch.
+        """
+        self._update_cos_sin_cache(qkv, seqlen_offset)
+        return apply_rotary_emb_qkv_(qkv, self._cos_cached[seqlen_offset:],
+                                     self._sin_cached[seqlen_offset:])
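The second hunk adds `seqlen_offset` for incremental decoding: during generation, `qkv` contains only the newest token(s), so `_update_cos_sin_cache` counts them toward the cached sequence length and `forward` slices the cos/sin tables from the offset, assigning the new tokens their true positions. A usage sketch under assumptions not fixed by the commit (batch size, head counts, dtype, and the CUDA placement are illustrative; the layer's rotary kernel expects GPU tensors in practice):

```python
import torch
from flash_attn.layers.rotary import RotaryEmbedding

batch, nheads, headdim = 2, 4, 64
rotary = RotaryEmbedding(dim=headdim, base=10000).cuda()

# Prefill: qkv for the full 16-token prompt, (batch, seqlen, 3, nheads, headdim).
prompt_qkv = torch.randn(batch, 16, 3, nheads, headdim,
                         dtype=torch.float16, device="cuda")
out = rotary(prompt_qkv)  # q and k rotated at positions 0..15

# Decode step: qkv now holds only the newly sampled token. Without the offset
# it would be rotated as position 0; seqlen_offset=16 assigns it position 16,
# since forward slices self._cos_cached[16:] / self._sin_cached[16:].
step_qkv = torch.randn(batch, 1, 3, nheads, headdim,
                       dtype=torch.float16, device="cuda")
out = rotary(step_qkv, seqlen_offset=16)
```

Slicing the cached tables, rather than recomputing them per step, keeps the downstream rotary application unchanged while shifting positions, so the decode path reuses the same cache the prefill built.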