Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
c7c66976
"...data/git@developer.sourcefind.cn:wangsen/paddle_dbnet.git" did not exist on "a28ef7f0261ecf8572db193f6265e5ab3d5acfaa"
Commit
c7c66976
authored
Dec 16, 2022
by
Alexander Ploshkin
Browse files
fix slicing dimensions
parent
96656b93
Changes
1
Hide whitespace changes
Inline
Side-by-side
Showing
1 changed file
with
12 additions
and
14 deletions
+12
-14
flash_attn/layers/rotary.py
flash_attn/layers/rotary.py
+12
-14
No files found.
flash_attn/layers/rotary.py
View file @
c7c66976
...
@@ -46,8 +46,8 @@ class ApplyRotaryEmb(torch.autograd.Function):
...
@@ -46,8 +46,8 @@ class ApplyRotaryEmb(torch.autograd.Function):
x1
,
x2
=
x
[...,
:
rotary_dim
].
chunk
(
2
,
dim
=-
1
)
x1
,
x2
=
x
[...,
:
rotary_dim
].
chunk
(
2
,
dim
=-
1
)
out
=
torch
.
empty_like
(
x
)
if
not
inplace
else
x
out
=
torch
.
empty_like
(
x
)
if
not
inplace
else
x
o1
,
o2
=
out
[...,
:
rotary_dim
].
chunk
(
2
,
dim
=-
1
)
if
not
inplace
else
(
x1
,
x2
)
o1
,
o2
=
out
[...,
:
rotary_dim
].
chunk
(
2
,
dim
=-
1
)
if
not
inplace
else
(
x1
,
x2
)
rotary_emb
.
apply_rotary
(
x1
,
x2
,
rearrange
(
cos
[
:,
:
seqlen
],
's d -> s 1 d'
),
rotary_emb
.
apply_rotary
(
x1
,
x2
,
rearrange
(
cos
[:
seqlen
],
's d -> s 1 d'
),
rearrange
(
sin
[
:,
:
seqlen
],
's d -> s 1 d'
),
o1
,
o2
,
False
)
rearrange
(
sin
[:
seqlen
],
's d -> s 1 d'
),
o1
,
o2
,
False
)
if
not
inplace
and
rotary_dim
<
headdim
:
if
not
inplace
and
rotary_dim
<
headdim
:
out
[...,
rotary_dim
:].
copy_
(
x
[...,
rotary_dim
:])
out
[...,
rotary_dim
:].
copy_
(
x
[...,
rotary_dim
:])
ctx
.
save_for_backward
(
cos
,
sin
)
ctx
.
save_for_backward
(
cos
,
sin
)
...
@@ -64,8 +64,8 @@ class ApplyRotaryEmb(torch.autograd.Function):
...
@@ -64,8 +64,8 @@ class ApplyRotaryEmb(torch.autograd.Function):
do1
,
do2
=
do
[...,
:
rotary_dim
].
chunk
(
2
,
dim
=-
1
)
do1
,
do2
=
do
[...,
:
rotary_dim
].
chunk
(
2
,
dim
=-
1
)
dx
=
torch
.
empty_like
(
do
)
if
not
inplace
else
do
dx
=
torch
.
empty_like
(
do
)
if
not
inplace
else
do
dx1
,
dx2
=
dx
[...,
:
rotary_dim
].
chunk
(
2
,
dim
=-
1
)
if
not
inplace
else
(
do1
,
do2
)
dx1
,
dx2
=
dx
[...,
:
rotary_dim
].
chunk
(
2
,
dim
=-
1
)
if
not
inplace
else
(
do1
,
do2
)
rotary_emb
.
apply_rotary
(
do1
,
do2
,
rearrange
(
cos
[
:,
:
seqlen
],
's d -> s 1 d'
),
rotary_emb
.
apply_rotary
(
do1
,
do2
,
rearrange
(
cos
[:
seqlen
],
's d -> s 1 d'
),
rearrange
(
sin
[
:,
:
seqlen
],
's d -> s 1 d'
),
dx1
,
dx2
,
True
)
rearrange
(
sin
[:
seqlen
],
's d -> s 1 d'
),
dx1
,
dx2
,
True
)
if
not
inplace
and
rotary_dim
<
headdim
:
if
not
inplace
and
rotary_dim
<
headdim
:
dx
[...,
rotary_dim
:].
copy_
(
do
[...,
rotary_dim
:])
dx
[...,
rotary_dim
:].
copy_
(
do
[...,
rotary_dim
:])
return
dx
,
None
,
None
,
None
return
dx
,
None
,
None
,
None
...
@@ -90,14 +90,12 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
...
@@ -90,14 +90,12 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
rotary_dim
*=
2
rotary_dim
*=
2
assert
rotary_dim
<=
headdim
assert
rotary_dim
<=
headdim
assert
seqlen
<=
rotary_seqlen
assert
seqlen
<=
rotary_seqlen
assert
cos
.
shape
==
(
seqlen
,
rotary_dim
//
2
)
assert
sin
.
shape
==
(
seqlen
,
rotary_dim
//
2
)
q1
,
q2
=
qkv
[:,
:,
0
,
:,
:
rotary_dim
].
chunk
(
2
,
dim
=-
1
)
q1
,
q2
=
qkv
[:,
:,
0
,
:,
:
rotary_dim
].
chunk
(
2
,
dim
=-
1
)
rotary_emb
.
apply_rotary
(
q1
,
q2
,
rearrange
(
cos
[
:,
:
seqlen
],
's d -> s 1 d'
),
rotary_emb
.
apply_rotary
(
q1
,
q2
,
rearrange
(
cos
[:
seqlen
],
's d -> s 1 d'
),
rearrange
(
sin
[
:,
:
seqlen
],
's d -> s 1 d'
),
q1
,
q2
,
False
)
rearrange
(
sin
[:
seqlen
],
's d -> s 1 d'
),
q1
,
q2
,
False
)
k1
,
k2
=
qkv
[:,
:,
1
,
:,
:
rotary_dim
].
chunk
(
2
,
dim
=-
1
)
k1
,
k2
=
qkv
[:,
:,
1
,
:,
:
rotary_dim
].
chunk
(
2
,
dim
=-
1
)
rotary_emb
.
apply_rotary
(
k1
,
k2
,
rearrange
(
cos
[
:,
:
seqlen
],
's d -> s 1 d'
),
rotary_emb
.
apply_rotary
(
k1
,
k2
,
rearrange
(
cos
[:
seqlen
],
's d -> s 1 d'
),
rearrange
(
sin
[
:,
:
seqlen
],
's d -> s 1 d'
),
k1
,
k2
,
False
)
rearrange
(
sin
[:
seqlen
],
's d -> s 1 d'
),
k1
,
k2
,
False
)
ctx
.
save_for_backward
(
cos
,
sin
)
ctx
.
save_for_backward
(
cos
,
sin
)
return
qkv
return
qkv
...
@@ -108,11 +106,11 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
...
@@ -108,11 +106,11 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
rotary_dim
=
cos
.
shape
[
-
1
]
rotary_dim
=
cos
.
shape
[
-
1
]
rotary_dim
*=
2
rotary_dim
*=
2
dq1
,
dq2
=
dqkv
[:,
:,
0
,
:,
:
rotary_dim
].
chunk
(
2
,
dim
=-
1
)
dq1
,
dq2
=
dqkv
[:,
:,
0
,
:,
:
rotary_dim
].
chunk
(
2
,
dim
=-
1
)
rotary_emb
.
apply_rotary
(
dq1
,
dq2
,
rearrange
(
cos
[
:,
:
seqlen
],
's d -> s 1 d'
),
rotary_emb
.
apply_rotary
(
dq1
,
dq2
,
rearrange
(
cos
[:
seqlen
],
's d -> s 1 d'
),
rearrange
(
sin
[
:,
:
seqlen
],
's d -> s 1 d'
),
dq1
,
dq2
,
True
)
rearrange
(
sin
[:
seqlen
],
's d -> s 1 d'
),
dq1
,
dq2
,
True
)
dk1
,
dk2
=
dqkv
[:,
:,
1
,
:,
:
rotary_dim
].
chunk
(
2
,
dim
=-
1
)
dk1
,
dk2
=
dqkv
[:,
:,
1
,
:,
:
rotary_dim
].
chunk
(
2
,
dim
=-
1
)
rotary_emb
.
apply_rotary
(
dk1
,
dk2
,
rearrange
(
cos
[
:,
:
seqlen
],
's d -> s 1 d'
),
rotary_emb
.
apply_rotary
(
dk1
,
dk2
,
rearrange
(
cos
[:
seqlen
],
's d -> s 1 d'
),
rearrange
(
sin
[
:,
:
seqlen
],
's d -> s 1 d'
),
dk1
,
dk2
,
True
)
rearrange
(
sin
[:
seqlen
],
's d -> s 1 d'
),
dk1
,
dk2
,
True
)
return
dqkv
,
None
,
None
return
dqkv
,
None
,
None
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment