gaoqiong / flash-attention
Commit ee8984d2
Authored Dec 17, 2022 by Alexander Ploshkin
add asserts for sin shape
Parent: c7c66976
Showing 1 changed file with 2 additions and 0 deletions
flash_attn/layers/rotary.py (+2 -0)
@@ -43,6 +43,7 @@ class ApplyRotaryEmb(torch.autograd.Function):
         rotary_dim *= 2
         assert rotary_dim <= headdim
         assert seqlen <= rotary_seqlen
+        assert sin.shape == (rotary_seqlen, rotary_dim // 2)
         x1, x2 = x[..., :rotary_dim].chunk(2, dim=-1)
         out = torch.empty_like(x) if not inplace else x
         o1, o2 = out[..., :rotary_dim].chunk(2, dim=-1) if not inplace else (x1, x2)
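The added assertion can be tried out without the autograd function. What follows is a minimal sketch in plain Python that works on shape tuples only. The helper name `check_rotary_shapes` is hypothetical, and the way `seqlen`, `headdim`, `rotary_seqlen`, and `rotary_dim` are unpacked from the shapes of `x` and `cos` is an assumption, because those lines fall outside the hunk shown here.

```python
def check_rotary_shapes(x_shape, cos_shape, sin_shape):
    """Mirror the three asserts from the hunk on plain shape tuples.

    Assumption (not visible in the diff): x is laid out as
    (batch, seqlen, nheads, headdim) and cos as (rotary_seqlen, rotary_dim / 2).
    """
    _batch, seqlen, _nheads, headdim = x_shape
    rotary_seqlen, rotary_dim = cos_shape
    rotary_dim *= 2
    assert rotary_dim <= headdim
    assert seqlen <= rotary_seqlen
    # The check added by this commit: sin must have the same shape as cos.
    assert tuple(sin_shape) == (rotary_seqlen, rotary_dim // 2)
    return rotary_dim


# A consistent set of shapes passes and returns the full rotary width.
print(check_rotary_shapes((2, 128, 8, 64), (256, 16), (256, 16)))

# A sin table with the wrong second dimension now fails fast.
try:
    check_rotary_shapes((2, 128, 8, 64), (256, 16), (256, 32))
except AssertionError:
    print("sin shape mismatch detected")
```

With the assert in place, a mismatched `sin` table is rejected before any slicing or kernel call takes place.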
@@ -90,6 +91,7 @@ class ApplyRotaryEmbQKV_(torch.autograd.Function):
         rotary_dim *= 2
         assert rotary_dim <= headdim
         assert seqlen <= rotary_seqlen
+        assert sin.shape == (rotary_seqlen, rotary_dim // 2)
         q1, q2 = qkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1)
         rotary_emb.apply_rotary(q1, q2, rearrange(cos[:seqlen], 's d -> s 1 d'),
                                 rearrange(sin[:seqlen], 's d -> s 1 d'), q1, q2, False)
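To see what the assertion guards against in this hunk, it helps to follow the shapes through the slicing. `sin[:seqlen]` trims only the first axis, so by itself it would not notice a `sin` table whose second axis differs from that of `cos`. The sketch below tracks shapes in plain Python. Both helper names are hypothetical, and the packed `qkv` layout `(batch, seqlen, 3, nheads, headdim)` is an assumption inferred from the indexing `qkv[:, :, 0, :, :rotary_dim]`.

```python
def sliced_table_shape(table_shape, seqlen):
    """Shape of rearrange(table[:seqlen], 's d -> s 1 d') for a 2-D table."""
    rows, d = table_shape
    return (min(seqlen, rows), 1, d)


def q_half_shape(qkv_shape, rotary_dim):
    """Shape of each half of qkv[:, :, 0, :, :rotary_dim].chunk(2, dim=-1).

    Assumption: qkv is (batch, seqlen, 3, nheads, headdim), rotary_dim is even.
    """
    batch, seqlen, _three, nheads, headdim = qkv_shape
    return (batch, seqlen, nheads, min(rotary_dim, headdim) // 2)


qkv_shape = (2, 128, 3, 8, 64)
cos_shape = (256, 16)            # (rotary_seqlen, rotary_dim // 2)
rotary_dim = cos_shape[1] * 2

q1 = q_half_shape(qkv_shape, rotary_dim)
good_sin = sliced_table_shape((256, 16), seqlen=128)
bad_sin = sliced_table_shape((256, 32), seqlen=128)

# For the shapes to line up, the table's last axis should equal that of each q half.
print(q1[-1] == good_sin[-1])
print(q1[-1] == bad_sin[-1])
```

In this sketch the mismatched table survives the `[:seqlen]` slice with its wrong last axis intact, which is the situation the new assert catches up front.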