gaoqiong / flash-attention · Commits · 942fcbf0

Commit 942fcbf0, authored Sep 03, 2023 by Tri Dao

[Rotary] Implement rotary in Triton

Parent: 08e98471

Showing 6 changed files with 601 additions and 231 deletions (+601 −231)
flash_attn/layers/rotary.py                     +237  −211
flash_attn/models/gpt_neox.py                     +2    −5
flash_attn/ops/triton/linear.py                   +1    −3
flash_attn/ops/triton/rotary.py                 +182    −0
tests/models/test_gpt_generation_parallel.py      +2    −2
tests/test_rotary.py                            +177   −10
flash_attn/layers/rotary.py (view file @ 942fcbf0)

# Copyright (c) 2023, Tri Dao.

import math
from typing import Optional, Tuple, Union

import torch
from einops import rearrange, repeat

from flash_attn.ops.triton.rotary import apply_rotary


def rotate_half(x, interleaved=False):
    ...

@@ -20,12 +20,12 @@ def rotate_half(x, interleaved=False):

def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    x: (batch_size, seqlen, nheads, headdim)
    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
    cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    return torch.cat(
        [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
        dim=-1,
    )

@@ -34,229 +34,242 @@ def apply_rotary_emb_torch(x, cos, sin, interleaved=False):

class ApplyRotaryEmb(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        x,
        cos,
        sin,
        interleaved=False,
        inplace=False,
        seqlen_offsets: Union[int, torch.Tensor] = 0,
    ):
        out = apply_rotary(
            x, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=inplace
        )
        if isinstance(seqlen_offsets, int):
            ctx.save_for_backward(cos, sin)  # Can't save int with save_for_backward
            ctx.seqlen_offsets = seqlen_offsets
        else:
            ctx.save_for_backward(cos, sin, seqlen_offsets)
            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        ctx.inplace = inplace
        return out if not inplace else x

    @staticmethod
    def backward(ctx, do):
        seqlen_offsets = ctx.seqlen_offsets
        if seqlen_offsets is None:
            cos, sin, seqlen_offsets = ctx.saved_tensors
        else:
            cos, sin = ctx.saved_tensors
        # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
        # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
        if not ctx.interleaved and not ctx.inplace:
            do = do.clone()
        dx = apply_rotary(
            do,
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            interleaved=ctx.interleaved,
            inplace=ctx.inplace,
            conjugate=True,
        )
        return dx, None, None, None, None, None


def apply_rotary_emb(
    x, cos, sin, interleaved=False, inplace=False, seqlen_offsets: Union[int, torch.Tensor] = 0
):
    """
    Arguments:
        x: (batch_size, seqlen, nheads, headdim)
        cos, sin: (seqlen_rotary, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
        inplace: if True, apply rotary embedding in-place.
        seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
            Most commonly used in inference when we have KV cache.
    Return:
        out: (batch_size, seqlen, nheads, headdim)
    rotary_dim must be <= headdim
    Apply rotary embedding to the first rotary_dim of x.
    """
    return ApplyRotaryEmb.apply(x, cos, sin, interleaved, inplace, seqlen_offsets)


# For backward compatibility
apply_rotary_emb_func = apply_rotary_emb
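
# Usage sketch (illustrative, not part of this file): calling the Triton-backed
# apply_rotary_emb during decoding with a per-sequence position offset. Shapes follow
# the docstring above; the sizes, dtype, and offsets below are arbitrary example values.
import math
import torch
from flash_attn.layers.rotary import apply_rotary_emb

batch, seqlen, nheads, headdim = 2, 16, 8, 64
x = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
# cos/sin cache covering positions 0 .. max_seqlen - 1, shape (max_seqlen, rotary_dim / 2)
angle = torch.rand(4096, headdim // 2, device="cuda") * 2 * math.pi
cos, sin = torch.cos(angle).half(), torch.sin(angle).half()
# Each sequence starts at a different position (e.g. tokens already in the KV cache)
offsets = torch.tensor([3, 7], dtype=torch.int32, device="cuda")
out = apply_rotary_emb(x, cos, sin, interleaved=False, inplace=False, seqlen_offsets=offsets)
print(out.shape)  # (2, 16, 8, 64)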

class ApplyRotaryEmbQKV_(torch.autograd.Function):
    @staticmethod
    def forward(
        ctx,
        qkv,
        cos,
        sin,
        cos_k=None,
        sin_k=None,
        interleaved=False,
        seqlen_offsets: Union[int, torch.Tensor] = 0,
    ):
        batch, seqlen, three, nheads, headdim = qkv.shape
        assert three == 3
        if cos_k is None and sin_k is None and qkv.is_contiguous():
            # Call 1 kernel instead of 2 kernels
            # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
            # dimensions, we get the same tensor
            qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
            apply_rotary(
                qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True
            )
        else:
            cos_k = cos if cos_k is None else cos_k
            sin_k = sin if sin_k is None else sin_k
            q, k = qkv[:, :, 0], qkv[:, :, 1]
            apply_rotary(q, cos, sin, seqlen_offsets, interleaved=interleaved, inplace=True)
            apply_rotary(k, cos_k, sin_k, seqlen_offsets, interleaved=interleaved, inplace=True)
        if isinstance(seqlen_offsets, int):
            ctx.save_for_backward(cos, sin, cos_k, sin_k)
            ctx.seqlen_offsets = seqlen_offsets
        else:
            ctx.save_for_backward(cos, sin, cos_k, sin_k, seqlen_offsets)
            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        return qkv

    @staticmethod
    def backward(ctx, dqkv):
        seqlen_offsets = ctx.seqlen_offsets
        if seqlen_offsets is None:
            cos, sin, cos_k, sin_k, seqlen_offsets = ctx.saved_tensors
        else:
            cos, sin, cos_k, sin_k = ctx.saved_tensors
        if cos_k is None and sin_k is None and dqkv.is_contiguous():
            # Call 1 kernel instead of 2 kernels
            # We need dqkv to be contiguous so that when we reshape to combine (3, nheads)
            # dimensions, we get the same tensor
            dqk = rearrange(dqkv[:, :, :2], "b s t h d -> b s (t h) d")
            apply_rotary(
                dqk,
                cos,
                sin,
                seqlen_offsets=seqlen_offsets,
                interleaved=ctx.interleaved,
                inplace=True,
                conjugate=True,
            )
        else:
            cos_k = cos if cos_k is None else cos_k
            sin_k = sin if sin_k is None else sin_k
            dq, dk = dqkv[:, :, 0], dqkv[:, :, 1]
            apply_rotary(
                dq,
                cos,
                sin,
                seqlen_offsets,
                interleaved=ctx.interleaved,
                inplace=True,
                conjugate=True,
            )
            apply_rotary(
                dk,
                cos_k,
                sin_k,
                seqlen_offsets,
                interleaved=ctx.interleaved,
                inplace=True,
                conjugate=True,
            )
        return dqkv, None, None, None, None, None, None


def apply_rotary_emb_qkv_(
    qkv,
    cos,
    sin,
    cos_k=None,
    sin_k=None,
    interleaved=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
):
    """
    Arguments:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
        cos, sin: (seqlen, rotary_dim / 2)
        cos_k, sin_k: (seqlen, rotary_dim / 2), optional
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
            1st half and 2nd half (GPT-NeoX style).
        seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
            Most commonly used in inference when we have KV cache.
    Return:
        qkv: (batch_size, seqlen, 3, nheads, headdim)
    rotary_dim must be <= headdim
    Apply rotary embedding *inplace* to the first rotary_dim of Q and K.
    """
    return ApplyRotaryEmbQKV_.apply(qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets)
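
# Usage sketch (illustrative, not part of this file): the in-place QKV variant rotates Q and K
# of a packed (batch, seqlen, 3, nheads, headdim) tensor and leaves V untouched.
import math
import torch
from flash_attn.layers.rotary import apply_rotary_emb_qkv_

batch, seqlen, nheads, headdim = 2, 128, 8, 64
qkv = torch.randn(batch, seqlen, 3, nheads, headdim, dtype=torch.float16, device="cuda")
angle = torch.rand(seqlen, headdim // 2, device="cuda") * 2 * math.pi
cos, sin = torch.cos(angle).half(), torch.sin(angle).half()
v_before = qkv[:, :, 2].clone()
qkv = apply_rotary_emb_qkv_(qkv, cos, sin, interleaved=False)  # modifies qkv in place
assert torch.equal(qkv[:, :, 2], v_before)  # V is left as-is; only Q and K are rotated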

class ApplyRotaryEmbKV_(torch.autograd.Function):
    @staticmethod
    def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0):
        batch, seqlen, two, nheads, headdim = kv.shape
        assert two == 2
        k = kv[:, :, 0]
        apply_rotary(
            k, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True
        )
        if isinstance(seqlen_offsets, int):
            ctx.save_for_backward(cos, sin)  # Can't save int with save_for_backward
            ctx.seqlen_offsets = seqlen_offsets
        else:
            ctx.save_for_backward(cos, sin, seqlen_offsets)
            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        return kv

    @staticmethod
    def backward(ctx, dkv):
        seqlen_offsets = ctx.seqlen_offsets
        if seqlen_offsets is None:
            cos, sin, seqlen_offsets = ctx.saved_tensors
        else:
            cos, sin = ctx.saved_tensors
        apply_rotary(
            dkv[:, :, 0],
            cos,
            sin,
            seqlen_offsets=seqlen_offsets,
            interleaved=ctx.interleaved,
            inplace=True,
            conjugate=True,
        )
        return dkv, None, None, None, None


def apply_rotary_emb_kv_(
    kv,
    cos,
    sin,
    interleaved=False,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
):
    """
    Arguments:
        kv: (batch_size, seqlen, 2, nheads, headdim)
        cos, sin: (seqlen, rotary_dim / 2)
        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
            1st half and 2nd half (GPT-NeoX style).
        seqlen_offsets: (batch_size,) or int. Each sequence in K is shifted by this amount.
            Most commonly used in inference when we have KV cache.
    Return:
        kv: (batch_size, seqlen, 2, nheads, headdim)
    rotary_dim must be <= headdim
    Apply rotary embedding *inplace* to the first rotary_dim of K.
    """
    return ApplyRotaryEmbKV_.apply(kv, cos, sin, interleaved, seqlen_offsets)

class RotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et. al).
    ...

@@ -372,57 +385,70 @@ class RotaryEmbedding(torch.nn.Module):
        self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(
        self,
        qkv: torch.Tensor,
        kv: Optional[torch.Tensor] = None,
        seqlen_offset: Union[int, torch.Tensor] = 0,
        max_seqlen: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
            else it's just q of shape (batch, seqlen, nheads, headdim)
        kv: (batch, seqlen, 2, nheads, headdim)
        seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
            Most commonly used in inference when we have KV cache.
            If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
            should pass in max_seqlen, which will update the cos / sin cache up to that length.
        Apply rotary embedding *inplace* to qkv and / or kv.
        """
        seqlen = qkv.shape[1]
        if isinstance(seqlen_offset, int):
            self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
        elif max_seqlen is not None:
            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
        if kv is None:
            if self.scale is None:
                return apply_rotary_emb_qkv_(
                    qkv,
                    self._cos_cached,
                    self._sin_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
            else:
                return apply_rotary_emb_qkv_(
                    qkv,
                    self._cos_cached,
                    self._sin_cached,
                    self._cos_k_cached,
                    self._sin_k_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
        else:
            q = qkv
            q = apply_rotary_emb_func(
                q,
                self._cos_cached,
                self._sin_cached,
                interleaved=self.interleaved,
                inplace=True,
                seqlen_offsets=seqlen_offset,
            )
            if self.scale is None:
                kv = apply_rotary_emb_kv_(
                    kv,
                    self._cos_cached,
                    self._sin_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
            else:
                kv = apply_rotary_emb_kv_(
                    kv,
                    self._cos_k_cached,
                    self._sin_k_cached,
                    interleaved=self.interleaved,
                    seqlen_offsets=seqlen_offset,
                )
            return q, kv
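
As a usage sketch of the module-level API (assuming the module's usual RotaryEmbedding(dim, ...) constructor, which is not shown in this diff; shapes and offsets below are illustrative), the same RotaryEmbedding instance can rotate the prompt at offset 0 and then each decode step at the current KV-cache length:

import torch
from flash_attn.layers.rotary import RotaryEmbedding

batch, nheads, headdim = 2, 8, 64
rotary = RotaryEmbedding(headdim, device="cuda")  # constructor args assumed, not part of this diff
# Prefill: rotate the whole prompt starting at position 0
qkv = torch.randn(batch, 32, 3, nheads, headdim, dtype=torch.float16, device="cuda")
qkv = rotary(qkv, seqlen_offset=0)
# Decode step: the new token's position equals the number of tokens already cached
qkv_step = torch.randn(batch, 1, 3, nheads, headdim, dtype=torch.float16, device="cuda")
qkv_step = rotary(qkv_step, seqlen_offset=32)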
flash_attn/models/gpt_neox.py (view file @ 942fcbf0)

@@ -68,6 +68,8 @@ def remap_state_dict_hf_gpt_neox(state_dict, config):
         # We don't store these biases
         state_dict.pop(f"transformer.layers.{l}.attention.bias")
         state_dict.pop(f"transformer.layers.{l}.attention.masked_bias")
+        # We don't store these
+        state_dict.pop(f"transformer.layers.{l}.attention.rotary_emb.inv_freq", None)
     # GPT-NeoX stores Wqkv as ((nheads 3 headdim), hidden_dim)
     # while we store Wqkv as ((3 nheads headdim), hidden_dim)
     headdim = config.hidden_size // config.num_attention_heads
@@ -89,11 +91,6 @@ def remap_state_dict_hf_gpt_neox(state_dict, config):
             r"transformer.layers.\1.mixer.out_proj.",
             key,
         )
-        key = re.sub(
-            r"^transformer.layers.(\d+).attention.rotary_emb.",
-            r"transformer.layers.\1.mixer.rotary_emb.",
-            key,
-        )
         return key

     state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
...
flash_attn/ops/triton/linear.py (view file @ 942fcbf0)

-# Adapted on https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py
+# Adapted from https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py
 # and https://github.com/openai/triton/blob/master/python/triton/ops/matmul.py

 from typing import Optional

 import torch
 import triton
 import triton.language as tl
-from torch.autograd.function import FunctionCtx
-from torch.cuda.amp import custom_fwd
 from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

 from flash_attn.ops.triton.k_activations import (
...
flash_attn/ops/triton/rotary.py (new file, 0 → 100644) (view file @ 942fcbf0)

from typing import Union

import torch
import triton
import triton.language as tl


# @triton.autotune(
#     configs=[
#         triton.Config({"BLOCK_M": 2}),
#         triton.Config({"BLOCK_M": 4}),
#         triton.Config({"BLOCK_M": 8}),
#         triton.Config({"BLOCK_M": 16}),
#     ],
#     key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"]
# )
@triton.jit
def rotary_kernel(
    OUT,  # Pointers to matrices
    X,
    COS,
    SIN,
    SEQLEN_OFFSETS,  # this could be int or a pointer
    # Matrix dimensions
    seqlen,
    nheads,
    rotary_dim,
    seqlen_ro,
    CACHE_KEY_SEQLEN,
    # strides
    stride_out_batch,
    stride_out_seqlen,
    stride_out_nheads,
    stride_out_headdim,
    stride_x_batch,
    stride_x_seqlen,
    stride_x_nheads,
    stride_x_headdim,
    # Meta-parameters
    BLOCK_K: tl.constexpr,
    IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
    INTERLEAVED: tl.constexpr,
    CONJUGATE: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    pid_batch = tl.program_id(axis=1)
    pid_head = tl.program_id(axis=2)
    rotary_dim_half = rotary_dim // 2

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rk = tl.arange(0, BLOCK_K // 2)
    if not IS_SEQLEN_OFFSETS_TENSOR:
        rm_cs = rm + SEQLEN_OFFSETS
    else:
        rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)

    X = X + (
        pid_batch * stride_x_batch
        + rm[:, None] * stride_x_seqlen
        + pid_head * stride_x_nheads
        + rk[None, :] * stride_x_headdim * (2 if INTERLEAVED else 1)
    )
    COS = COS + (rm_cs[:, None] * rotary_dim_half + rk[None, :])
    SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk[None, :])
    cos = tl.load(
        COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk[None, :] < rotary_dim_half), other=1.0
    ).to(tl.float32)
    sin = tl.load(
        SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk[None, :] < rotary_dim_half), other=0.0
    ).to(tl.float32)
    x0 = tl.load(
        X, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half), other=0.0
    ).to(tl.float32)
    x1 = tl.load(
        X + stride_x_headdim * (1 if INTERLEAVED else rotary_dim_half),
        mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half),
        other=0.0,
    ).to(tl.float32)
    if not CONJUGATE:
        o0 = x0 * cos - x1 * sin
        o1 = x0 * sin + x1 * cos
    else:
        o0 = x0 * cos + x1 * sin
        o1 = -x0 * sin + x1 * cos
    # write back result
    OUT = OUT + (
        pid_batch * stride_out_batch
        + rm[:, None] * stride_out_seqlen
        + pid_head * stride_out_nheads
        + rk[None, :] * stride_out_headdim * (2 if INTERLEAVED else 1)
    )
    tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half))
    tl.store(
        OUT + stride_out_headdim * (1 if INTERLEAVED else rotary_dim_half),
        o1,
        mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half),
    )


def apply_rotary(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    interleaved=False,
    inplace=False,
    conjugate=False,
) -> torch.Tensor:
    """
    Arguments:
        x: (batch, seqlen, nheads, headdim)
        cos: (seqlen_ro, rotary_dim / 2)
        sin: (seqlen_ro, rotary_dim / 2)
        seqlen_offsets: integer or integer tensor of size (batch,)
    Returns:
        y: (batch, seqlen, nheads, headdim)
    """
    batch, seqlen, nheads, headdim = x.shape
    seqlen_ro, rotary_dim = cos.shape
    assert sin.shape == cos.shape
    rotary_dim *= 2
    assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
    assert headdim <= 256, "Only support headdim <= 256"
    assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
    assert (
        cos.dtype == sin.dtype
    ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
    assert (
        x.dtype == cos.dtype
    ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"

    cos, sin = cos.contiguous(), sin.contiguous()
    if isinstance(seqlen_offsets, torch.Tensor):
        assert seqlen_offsets.shape == (batch,)
        assert seqlen_offsets.dtype in [torch.int32, torch.int64]
        seqlen_offsets = seqlen_offsets.contiguous()
    else:
        assert seqlen_offsets + seqlen <= seqlen_ro

    output = torch.empty_like(x) if not inplace else x
    if rotary_dim < headdim and not inplace:
        output[..., rotary_dim:].copy_(x[..., rotary_dim:])

    BLOCK_K = (
        32
        if rotary_dim <= 32
        else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
    )
    grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads)  # noqa
    BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
    rotary_kernel[grid](
        output,  # data ptrs
        x,
        cos,
        sin,
        seqlen_offsets,
        seqlen,  # shapes
        nheads,
        rotary_dim,
        seqlen_ro,
        seqlen // 128,  # key for triton cache (limit number of compilations)
        output.stride(0),  # strides
        output.stride(1),
        output.stride(2),
        output.stride(3),
        x.stride(0),
        x.stride(1),
        x.stride(2),
        x.stride(3),
        BLOCK_K,
        isinstance(seqlen_offsets, torch.Tensor),
        interleaved,
        conjugate,
        BLOCK_M,
    )
    return output
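
A quick way to sanity-check the wrapper is to compare it against the pure-PyTorch reference in flash_attn/layers/rotary.py, the same way the tests below do. A minimal sketch (sizes are arbitrary example values):

import math
import torch
from flash_attn.layers.rotary import apply_rotary_emb_torch
from flash_attn.ops.triton.rotary import apply_rotary

batch, seqlen, nheads, headdim = 2, 64, 4, 128
x = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
angle = torch.rand(seqlen, headdim // 2, device="cuda") * 2 * math.pi
cos, sin = torch.cos(angle).half(), torch.sin(angle).half()
out_triton = apply_rotary(x, cos, sin, seqlen_offsets=0, interleaved=False)
out_ref = apply_rotary_emb_torch(x.float(), cos.float(), sin.float()).to(torch.float16)
print((out_triton - out_ref).abs().max().item())  # expect only fp16-level rounding differences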
tests/models/test_gpt_generation_parallel.py (view file @ 942fcbf0)

@@ -131,6 +131,8 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
         )
         print(out_cg.sequences)

+    parallel_state.destroy_model_parallel()
+
     if not rotary:
         out_hf = model_hf.generate(
             input_ids=input_ids,
@@ -171,5 +173,3 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
         ).abs().max().item() < 3 * (
             torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)
         ).abs().max().item()
-
-    parallel_state.destroy_model_parallel()
tests/test_rotary.py (view file @ 942fcbf0)

import math
import random

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange

from flash_attn.layers.rotary import apply_rotary_emb, apply_rotary_emb_torch
from flash_attn.layers.rotary import apply_rotary_emb_qkv_, apply_rotary_emb_kv_

is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0)

@@ -13,33 +15,198 @@ is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0)

@pytest.mark.parametrize(
    "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False])
@pytest.mark.parametrize("inplace", [False, True])
# @pytest.mark.parametrize('inplace', [False])
def test_rotary_emb_func(inplace, interleaved, rotary_fraction, seqlen_offsets_type, dtype):
    rtol = 1e-3
    batch_size = 32
    nheads = 4
    seqlen = 217
    headdim = 128
    device = "cuda"
    torch.manual_seed(42)
    x = torch.randn(
        batch_size, seqlen, nheads, headdim, dtype=dtype, device=device, requires_grad=True
    )
    x_pt = x.detach().clone().requires_grad_()
    rotary_dim = int(rotary_fraction * headdim)
    assert rotary_dim % 2 == 0
    angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
    cos = torch.cos(angle).to(dtype=dtype)
    sin = torch.sin(angle).to(dtype=dtype)
    if seqlen_offsets_type == 0:
        seqlen_offsets = 0
    elif seqlen_offsets_type is int:
        seqlen_offsets = torch.randint(0, seqlen + 1, (1,)).item()
    elif seqlen_offsets_type is torch.Tensor:
        seqlen_offsets = torch.randint(
            0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device
        )
    out = apply_rotary_emb(
        x, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=inplace
    )
    if seqlen_offsets_type is torch.Tensor:
        arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
        idx = rearrange(seqlen_offsets, "b -> b 1") + arange
        cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
        sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
    else:
        cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
        sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
    out_pt = apply_rotary_emb_torch(
        x_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    print(f"Output max diff: {(out - out_pt).abs().max().item()}")

    g = torch.randn_like(out)
    g_pt = g.clone()  # If inplace=True, we might modify the gradient inplace
    out.backward(g)
    out_pt.backward(g_pt)
    print(f"Grad max diff: {(x.grad - x_pt.grad).abs().max().item()}")

    if not inplace:
        assert torch.equal(x, x_pt)
    # Numerical error if we just do any arithmetic
    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
    atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item()
    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=2 * atol)


@pytest.mark.parametrize(
    "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False])
def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_type, dtype):
    rtol = 1e-3
    batch_size = 32
    nheads = 4
    seqlen = 512
    headdim = 128
    device = "cuda"
    torch.manual_seed(42)
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, headdim, dtype=dtype, device=device, requires_grad=True
    )
    qkv_pt = qkv.detach().clone().requires_grad_()
    rotary_dim = int(rotary_fraction * headdim)
    assert rotary_dim % 2 == 0
    angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
    cos = torch.cos(angle).to(dtype=dtype)
    sin = torch.sin(angle).to(dtype=dtype)
    if seqlen_offsets_type == 0:
        seqlen_offsets = 0
    elif seqlen_offsets_type is int:
        seqlen_offsets = torch.randint(0, seqlen + 1, (1,)).item()
    elif seqlen_offsets_type is torch.Tensor:
        seqlen_offsets = torch.randint(
            0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device
        )
    out = apply_rotary_emb_qkv_(
        qkv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved
    )
    if seqlen_offsets_type is torch.Tensor:
        arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
        idx = rearrange(seqlen_offsets, "b -> b 1") + arange
        cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
        sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
    else:
        cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
        sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
    q_pt = apply_rotary_emb_torch(
        qkv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    k_pt = apply_rotary_emb_torch(
        qkv_pt[:, :, 1].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    out_pt = torch.stack([q_pt, k_pt, qkv_pt[:, :, 2]], dim=2)
    print(f"Output max diff: {(out - out_pt).abs().max().item()}")

    g = torch.randn_like(out)
    g_pt = g.clone()  # Since inplace=True, we modify the gradient inplace
    out.backward(g)
    out_pt.backward(g_pt)
    print(f"Grad max diff: {(qkv.grad - qkv_pt.grad).abs().max().item()}")

    # Numerical error if we just do any arithmetic
    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
    atol = ((qkv_pt.grad + 0.3 - 0.3) - qkv_pt.grad).abs().max().item()
    assert torch.allclose(qkv.grad, qkv_pt.grad, rtol=rtol, atol=2 * atol)


@pytest.mark.parametrize(
    "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False])
def test_rotary_emb_kv(interleaved, rotary_fraction, seqlen_offsets_type, dtype):
    rtol = 1e-3
    batch_size = 32
    nheads = 4
    seqlen = 781
    headdim = 64
    device = "cuda"
    torch.manual_seed(42)
    kv = torch.randn(
        batch_size, seqlen, 2, nheads, headdim, dtype=dtype, device=device, requires_grad=True
    )
    kv_pt = kv.detach().clone().requires_grad_()
    rotary_dim = int(rotary_fraction * headdim)
    assert rotary_dim % 2 == 0
    angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
    cos = torch.cos(angle).to(dtype=dtype)
    sin = torch.sin(angle).to(dtype=dtype)
    if seqlen_offsets_type == 0:
        seqlen_offsets = 0
    elif seqlen_offsets_type is int:
        seqlen_offsets = torch.randint(0, seqlen + 1, (1,)).item()
    elif seqlen_offsets_type is torch.Tensor:
        seqlen_offsets = torch.randint(
            0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device
        )
    out = apply_rotary_emb_kv_(
        kv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved
    )
    if seqlen_offsets_type is torch.Tensor:
        arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
        idx = rearrange(seqlen_offsets, "b -> b 1") + arange
        cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
        sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
    else:
        cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
        sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
    k_pt = apply_rotary_emb_torch(
        kv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    out_pt = torch.stack([k_pt, kv_pt[:, :, 1]], dim=2)
    print(f"Output max diff: {(out - out_pt).abs().max().item()}")

    g = torch.randn_like(out)
    g_pt = g.clone()  # Since inplace=True, we modify the gradient inplace
    out.backward(g)
    out_pt.backward(g_pt)
    print(f"Grad max diff: {(kv.grad - kv_pt.grad).abs().max().item()}")

    # Numerical error if we just do any arithmetic
    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
    atol = ((kv_pt.grad + 0.3 - 0.3) - kv_pt.grad).abs().max().item()
    assert torch.allclose(kv.grad, kv_pt.grad, rtol=rtol, atol=2 * atol)
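
The tolerance in these tests is data-dependent rather than a fixed constant: adding and then subtracting 0.3 forces a round-trip through fp16/bf16 rounding, and the largest resulting change estimates the representable precision at the tensor's magnitude, which the asserts then use (doubled) as atol alongside rtol=1e-3. A minimal sketch of the idea (values illustrative):

import torch

x = torch.randn(1024, dtype=torch.float16, device="cuda")
# Round-tripping through a small arithmetic op measures the rounding-error floor for x.
atol = ((x + 0.3 - 0.3) - x).abs().max().item()
print(atol)  # roughly the fp16 rounding error at |x| ~ 1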