Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
dc24c226
Unverified
Commit
dc24c226
authored
Dec 17, 2022
by
Tri Dao
Committed by
GitHub
Dec 17, 2022
Browse files
Merge pull request #92 from ploshkin/rm-shape-asserts
Fix slicing dimensions in rotary embeddings
parents
b78f5a39
ee8984d2
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
13 additions
and
15 deletions
+13
-15
flash_attn/layers/rotary.py
flash_attn/layers/rotary.py
+13
-15
No files found.
flash_attn/layers/rotary.py
View file @
dc24c226
...
...
@@ -43,13 +43,12 @@ class ApplyRotaryEmb(torch.autograd.Function):
rotary_dim
*=
2
assert
rotary_dim
<=
headdim
assert
seqlen
<=
rotary_seqlen
assert
cos
.
shape
==
(
rotary_seqlen
,
rotary_dim
//
2
)
assert
sin
.
shape
==
(
rotary_seqlen
,
rotary_dim
//
2
)
x1
,
x2
=
x
[...,
:
rotary_dim
].
chunk
(
2
,
dim
=-
1
)
out
=
torch
.
empty_like
(
x
)
if
not
inplace
else
x
o1
,
o2
=
out
[...,
:
rotary_dim
].
chunk
(
2
,
dim
=-
1
)
if
not
inplace
else
(
x1
,
x2
)
rotary_emb
.
apply_rotary
(
x1
,
x2
,
rearrange
(
cos
[
:,
:
seqlen
],
's d -> s 1 d'
),
rearrange
(
sin
[
:,
:
seqlen
],
's d -> s 1 d'
),
o1
,
o2
,
False
)
rotary_emb
.
apply_rotary
(
x1
,
x2
,
rearrange
(
cos
[:
seqlen
],
's d -> s 1 d'
),
rearrange
(
sin
[:
seqlen
],
's d -> s 1 d'
),
o1
,
o2
,
False
)
if
not
inplace
and
rotary_dim
<
headdim
:
out
[...,
rotary_dim
:].
copy_
(
x
[...,
rotary_dim
:])
ctx
.
save_for_backward
(
cos
,
sin
)
...
...
@@ -66,8 +65,8 @@ class ApplyRotaryEmb(torch.autograd.Function):
do1
,
do2
=
do
[...,
:
rotary_dim
].
chunk
(
2
,
dim
=-
1
)
dx
=
torch
.
empty_like
(
do
)
if
not
inplace
else
do
dx1
,
dx2
=
dx
[...,
:
rotary_dim
].
chunk
(
2
,
dim
=-
1
)
if
not
inplace
else
(
do1
,
do2
)
rotary_emb
.
apply_rotary
(
do1
,
do2
,
rearrange
(
cos
[
:,
:
seqlen
],
's d -> s 1 d'
),
rearrange
(
sin
[
:,
:
seqlen
],
's d -> s 1 d'
),
dx1
,
dx2
,
True
)
rotary_emb
.
apply_rotary
(
do1
,
do2
,
rearrange
(
cos
[:
seqlen
],
's d -> s 1 d'
),
rearrange
(
sin
[:
seqlen
],
's d -> s 1 d'
),
dx1
,
dx2
,
True
)
if
not
inplace
and
rotary_dim
<
headdim
:
dx
[...,
rotary_dim
:].
copy_
(
do
[...,
rotary_dim
:])
return
dx
,
None
,
None
,
None
...
...
@@ -92,14 +91,13 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
rotary_dim
*=
2
assert
rotary_dim
<=
headdim
assert
seqlen
<=
rotary_seqlen
assert
cos
.
shape
==
(
seqlen
,
rotary_dim
//
2
)
assert
sin
.
shape
==
(
seqlen
,
rotary_dim
//
2
)
assert
sin
.
shape
==
(
rotary_seqlen
,
rotary_dim
//
2
)
q1
,
q2
=
qkv
[:,
:,
0
,
:,
:
rotary_dim
].
chunk
(
2
,
dim
=-
1
)
rotary_emb
.
apply_rotary
(
q1
,
q2
,
rearrange
(
cos
[
:,
:
seqlen
],
's d -> s 1 d'
),
rearrange
(
sin
[
:,
:
seqlen
],
's d -> s 1 d'
),
q1
,
q2
,
False
)
rotary_emb
.
apply_rotary
(
q1
,
q2
,
rearrange
(
cos
[:
seqlen
],
's d -> s 1 d'
),
rearrange
(
sin
[:
seqlen
],
's d -> s 1 d'
),
q1
,
q2
,
False
)
k1
,
k2
=
qkv
[:,
:,
1
,
:,
:
rotary_dim
].
chunk
(
2
,
dim
=-
1
)
rotary_emb
.
apply_rotary
(
k1
,
k2
,
rearrange
(
cos
[
:,
:
seqlen
],
's d -> s 1 d'
),
rearrange
(
sin
[
:,
:
seqlen
],
's d -> s 1 d'
),
k1
,
k2
,
False
)
rotary_emb
.
apply_rotary
(
k1
,
k2
,
rearrange
(
cos
[:
seqlen
],
's d -> s 1 d'
),
rearrange
(
sin
[:
seqlen
],
's d -> s 1 d'
),
k1
,
k2
,
False
)
ctx
.
save_for_backward
(
cos
,
sin
)
return
qkv
...
...
@@ -110,11 +108,11 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
rotary_dim
=
cos
.
shape
[
-
1
]
rotary_dim
*=
2
dq1
,
dq2
=
dqkv
[:,
:,
0
,
:,
:
rotary_dim
].
chunk
(
2
,
dim
=-
1
)
rotary_emb
.
apply_rotary
(
dq1
,
dq2
,
rearrange
(
cos
[
:,
:
seqlen
],
's d -> s 1 d'
),
rearrange
(
sin
[
:,
:
seqlen
],
's d -> s 1 d'
),
dq1
,
dq2
,
True
)
rotary_emb
.
apply_rotary
(
dq1
,
dq2
,
rearrange
(
cos
[:
seqlen
],
's d -> s 1 d'
),
rearrange
(
sin
[:
seqlen
],
's d -> s 1 d'
),
dq1
,
dq2
,
True
)
dk1
,
dk2
=
dqkv
[:,
:,
1
,
:,
:
rotary_dim
].
chunk
(
2
,
dim
=-
1
)
rotary_emb
.
apply_rotary
(
dk1
,
dk2
,
rearrange
(
cos
[
:,
:
seqlen
],
's d -> s 1 d'
),
rearrange
(
sin
[
:,
:
seqlen
],
's d -> s 1 d'
),
dk1
,
dk2
,
True
)
rotary_emb
.
apply_rotary
(
dk1
,
dk2
,
rearrange
(
cos
[:
seqlen
],
's d -> s 1 d'
),
rearrange
(
sin
[:
seqlen
],
's d -> s 1 d'
),
dk1
,
dk2
,
True
)
return
dqkv
,
None
,
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment