gaoqiong / flash-attention / Commits

Commit 942fcbf0, authored Sep 03, 2023 by Tri Dao

[Rotary] Implement rotary in Triton

Parent: 08e98471
Showing 6 changed files with 601 additions and 231 deletions (+601, -231)
flash_attn/layers/rotary.py                     +237  -211
flash_attn/models/gpt_neox.py                     +2    -5
flash_attn/ops/triton/linear.py                   +1    -3
flash_attn/ops/triton/rotary.py                 +182    -0
tests/models/test_gpt_generation_parallel.py      +2    -2
tests/test_rotary.py                            +177   -10
flash_attn/layers/rotary.py

# Copyright (c) 2023, Tri Dao.
import math
-from typing import Optional, Tuple
+from typing import Optional, Tuple, Union

-import rotary_emb
import torch
from einops import rearrange, repeat

+from flash_attn.ops.triton.rotary import apply_rotary
+

def rotate_half(x, interleaved=False):
...

@@ -20,12 +20,12 @@ def rotate_half(x, interleaved=False):
def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
    """
    x: (batch_size, seqlen, nheads, headdim)
-    cos, sin: (seqlen, rotary_dim / 2)
+    cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
    """
    ro_dim = cos.shape[-1] * 2
    assert ro_dim <= x.shape[-1]
-    cos = repeat(cos, "s d -> s 1 (2 d)")
-    sin = repeat(sin, "s d -> s 1 (2 d)")
+    cos = repeat(cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
+    sin = repeat(sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)")
    return torch.cat(
        [x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin, x[..., ro_dim:]],
        dim=-1,
...

@@ -34,229 +34,242 @@ def apply_rotary_emb_torch(x, cos, sin, interleaved=False):
class ApplyRotaryEmb(torch.autograd.Function):
    @staticmethod
-    def forward(ctx, x, cos, sin, interleaved=False, inplace=False):
-        """
-        x: (batch_size, seqlen, nheads, headdim)
-        cos, sin: (seqlen, rotary_dim / 2)
-        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
-            of 1st half and 2nd half (GPT-NeoX style).
-        rotary_dim must be <= headdim
-        Apply rotary embedding to the first rotary_dim of x.
-        """
-        batch, seqlen, nheads, headdim = x.shape
-        rotary_seqlen, rotary_dim = cos.shape
-        rotary_dim *= 2
-        assert rotary_dim <= headdim
-        assert seqlen <= rotary_seqlen
-        assert sin.shape == (rotary_seqlen, rotary_dim // 2)
-        x_ro = x[..., :rotary_dim]
-        x1, x2 = x_ro.chunk(2, dim=-1) if not interleaved else (x_ro[..., ::2], x_ro[..., 1::2])
-        out = torch.empty_like(x) if not inplace else x
-        out_ro = out[..., :rotary_dim]
-        if inplace:
-            o1, o2 = x1, x2
-        else:
-            o1, o2 = (
-                out_ro.chunk(2, dim=-1)
-                if not interleaved
-                else (out_ro[..., ::2], out_ro[..., 1::2])
-            )
-        rotary_emb.apply_rotary(
-            x1, x2, rearrange(cos[:seqlen], "s d -> s 1 d"),
-            rearrange(sin[:seqlen], "s d -> s 1 d"), o1, o2, False,
-        )
-        if not inplace and rotary_dim < headdim:
-            out[..., rotary_dim:].copy_(x[..., rotary_dim:])
-        ctx.save_for_backward(cos, sin)
+    def forward(
+        ctx,
+        x,
+        cos,
+        sin,
+        interleaved=False,
+        inplace=False,
+        seqlen_offsets: Union[int, torch.Tensor] = 0,
+    ):
+        out = apply_rotary(
+            x, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=inplace
+        )
+        if isinstance(seqlen_offsets, int):
+            ctx.save_for_backward(cos, sin)  # Can't save int with save_for_backward
+            ctx.seqlen_offsets = seqlen_offsets
+        else:
+            ctx.save_for_backward(cos, sin, seqlen_offsets)
+            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        ctx.inplace = inplace
        return out if not inplace else x

    @staticmethod
    def backward(ctx, do):
-        cos, sin = ctx.saved_tensors
-        _, seqlen, _, headdim = do.shape
-        rotary_dim = cos.shape[-1]
-        rotary_dim *= 2
-        inplace = ctx.inplace
-        do_ro = do[..., :rotary_dim]
-        do1, do2 = (
-            do_ro.chunk(2, dim=-1)
-            if not ctx.interleaved
-            else (do_ro[..., ::2], do_ro[..., 1::2])
-        )
-        dx = torch.empty_like(do) if not inplace else do
-        if inplace:
-            dx1, dx2 = do1, do2
-        else:
-            dx_ro = dx[..., :rotary_dim]
-            dx1, dx2 = (
-                dx_ro.chunk(2, dim=-1)
-                if not ctx.interleaved
-                else (dx_ro[..., ::2], dx_ro[..., 1::2])
-            )
-        rotary_emb.apply_rotary(
-            do1, do2, rearrange(cos[:seqlen], "s d -> s 1 d"),
-            rearrange(sin[:seqlen], "s d -> s 1 d"), dx1, dx2, True,
-        )
-        if not inplace and rotary_dim < headdim:
-            dx[..., rotary_dim:].copy_(do[..., rotary_dim:])
-        return dx, None, None, None, None
+        seqlen_offsets = ctx.seqlen_offsets
+        if seqlen_offsets is None:
+            cos, sin, seqlen_offsets = ctx.saved_tensors
+        else:
+            cos, sin = ctx.saved_tensors
+        # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
+        # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
+        if not ctx.interleaved and not ctx.inplace:
+            do = do.clone()
+        dx = apply_rotary(
+            do,
+            cos,
+            sin,
+            seqlen_offsets=seqlen_offsets,
+            interleaved=ctx.interleaved,
+            inplace=ctx.inplace,
+            conjugate=True,
+        )
+        return dx, None, None, None, None, None


-apply_rotary_emb_func = ApplyRotaryEmb.apply
+def apply_rotary_emb(
+    x, cos, sin, interleaved=False, inplace=False, seqlen_offsets: Union[int, torch.Tensor] = 0
+):
+    """
+    Arguments:
+        x: (batch_size, seqlen, nheads, headdim)
+        cos, sin: (seqlen_rotary, rotary_dim / 2)
+        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
+            of 1st half and 2nd half (GPT-NeoX style).
+        inplace: if True, apply rotary embedding in-place.
+        seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
+            Most commonly used in inference when we have KV cache.
+    Return:
+        out: (batch_size, seqlen, nheads, headdim)
+    rotary_dim must be <= headdim
+    Apply rotary embedding to the first rotary_dim of x.
+    """
+    return ApplyRotaryEmb.apply(x, cos, sin, interleaved, inplace, seqlen_offsets)
+
+
+# For backward compatibility
+apply_rotary_emb_func = apply_rotary_emb
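For orientation, here is a minimal, hypothetical sketch of calling the new apply_rotary_emb with per-sequence offsets during KV-cache decoding. The shapes follow the docstring above; the cos/sin construction (base-10000 frequencies) and all sizes are illustrative assumptions, not part of this diff.

# Hypothetical usage sketch (not part of this commit): rotate the single new query
# position of each sequence during KV-cache decoding, with a different cached length
# per sequence. Shapes and the base-10000 frequency recipe are illustrative.
import torch
from flash_attn.layers.rotary import apply_rotary_emb

batch, nheads, headdim, max_seqlen = 2, 4, 64, 128
inv_freq = 10000.0 ** (-torch.arange(0, headdim, 2, dtype=torch.float32, device="cuda") / headdim)
t = torch.arange(max_seqlen, dtype=torch.float32, device="cuda")
freqs = torch.outer(t, inv_freq)                       # (max_seqlen, rotary_dim / 2)
cos, sin = freqs.cos().half(), freqs.sin().half()

x = torch.randn(batch, 1, nheads, headdim, dtype=torch.float16, device="cuda")
offsets = torch.tensor([5, 17], dtype=torch.int32, device="cuda")  # cached tokens per sequence
out = apply_rotary_emb(x, cos, sin, seqlen_offsets=offsets)        # same shape as x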
class ApplyRotaryEmbQKV_(torch.autograd.Function):
    @staticmethod
-    def forward(ctx, qkv, cos, sin, cos_k=None, sin_k=None, interleaved=False):
-        """
-        qkv: (batch_size, seqlen, 3, nheads, headdim)
-        cos, sin: (seqlen, rotary_dim / 2)
-        cos_k, sin_k: (seqlen, rotary_dim / 2), optional
-        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
-            1st half and 2nd half (GPT-NeoX style).
-        rotary_dim must be <= headdim
-        Apply rotary embedding *inplace* to the first rotary_dim of q and k.
-        """
-        batch, seqlen, three, nheads, headdim = qkv.shape
-        assert three == 3
-        rotary_seqlen, rotary_dim = cos.shape
-        rotary_dim *= 2
-        assert rotary_dim <= headdim
-        assert seqlen <= rotary_seqlen
-        cos_k = cos if cos_k is None else cos_k
-        sin_k = sin if sin_k is None else sin_k
-        assert sin.shape == cos_k.shape == sin_k.shape == (rotary_seqlen, rotary_dim // 2)
-        q_ro = qkv[:, :, 0, :, :rotary_dim]
-        q1, q2 = q_ro.chunk(2, dim=-1) if not interleaved else (q_ro[..., ::2], q_ro[..., 1::2])
-        rotary_emb.apply_rotary(
-            q1, q2, rearrange(cos[:seqlen], "s d -> s 1 d"),
-            rearrange(sin[:seqlen], "s d -> s 1 d"), q1, q2, False,
-        )
-        k_ro = qkv[:, :, 1, :, :rotary_dim]
-        k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
-        rotary_emb.apply_rotary(
-            k1, k2, rearrange(cos_k[:seqlen], "s d -> s 1 d"),
-            rearrange(sin_k[:seqlen], "s d -> s 1 d"), k1, k2, False,
-        )
-        ctx.save_for_backward(cos, sin, cos_k, sin_k)
+    def forward(
+        ctx,
+        qkv,
+        cos,
+        sin,
+        cos_k=None,
+        sin_k=None,
+        interleaved=False,
+        seqlen_offsets: Union[int, torch.Tensor] = 0,
+    ):
+        if cos_k is None and sin_k is None and qkv.is_contiguous():
+            # Call 1 kernel instead of 2 kernels
+            # We need qkv to be contiguous so that when we reshape to combine (3, nheads)
+            # dimensions, we get the same tensor
+            qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
+            apply_rotary(
+                qk, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True
+            )
+        else:
+            cos_k = cos if cos_k is None else cos_k
+            sin_k = sin if sin_k is None else sin_k
+            q, k = qkv[:, :, 0], qkv[:, :, 1]
+            apply_rotary(q, cos, sin, seqlen_offsets, interleaved=interleaved, inplace=True)
+            apply_rotary(k, cos_k, sin_k, seqlen_offsets, interleaved=interleaved, inplace=True)
+        if isinstance(seqlen_offsets, int):
+            ctx.save_for_backward(cos, sin, cos_k, sin_k)
+            ctx.seqlen_offsets = seqlen_offsets
+        else:
+            ctx.save_for_backward(cos, sin, cos_k, sin_k, seqlen_offsets)
+            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        return qkv

    @staticmethod
    def backward(ctx, dqkv):
-        cos, sin, cos_k, sin_k = ctx.saved_tensors
-        _, seqlen, _, _, headdim = dqkv.shape
-        rotary_dim = cos.shape[-1]
-        rotary_dim *= 2
-        dq_ro = dqkv[:, :, 0, :, :rotary_dim]
-        dq1, dq2 = (
-            dq_ro.chunk(2, dim=-1)
-            if not ctx.interleaved
-            else (dq_ro[..., ::2], dq_ro[..., 1::2])
-        )
-        rotary_emb.apply_rotary(
-            dq1, dq2, rearrange(cos[:seqlen], "s d -> s 1 d"),
-            rearrange(sin[:seqlen], "s d -> s 1 d"), dq1, dq2, True,
-        )
-        dk_ro = dqkv[:, :, 1, :, :rotary_dim]
-        dk1, dk2 = (
-            dk_ro.chunk(2, dim=-1)
-            if not ctx.interleaved
-            else (dk_ro[..., ::2], dk_ro[..., 1::2])
-        )
-        rotary_emb.apply_rotary(
-            dk1, dk2, rearrange(cos_k[:seqlen], "s d -> s 1 d"),
-            rearrange(sin_k[:seqlen], "s d -> s 1 d"), dk1, dk2, True,
-        )
-        return dqkv, None, None, None, None, None
+        seqlen_offsets = ctx.seqlen_offsets
+        if seqlen_offsets is None:
+            cos, sin, cos_k, sin_k, seqlen_offsets = ctx.saved_tensors
+        else:
+            cos, sin, cos_k, sin_k = ctx.saved_tensors
+        if cos_k is None and sin_k is None and dqkv.is_contiguous():
+            # Call 1 kernel instead of 2 kernels
+            # We need dqkv to be contiguous so that when we reshape to combine (3, nheads)
+            # dimensions, we get the same tensor
+            dqk = rearrange(dqkv[:, :, :2], "b s t h d -> b s (t h) d")
+            apply_rotary(
+                dqk,
+                cos,
+                sin,
+                seqlen_offsets=seqlen_offsets,
+                interleaved=ctx.interleaved,
+                inplace=True,
+                conjugate=True,
+            )
+        else:
+            cos_k = cos if cos_k is None else cos_k
+            sin_k = sin if sin_k is None else sin_k
+            dq, dk = dqkv[:, :, 0], dqkv[:, :, 1]
+            apply_rotary(
+                dq, cos, sin, seqlen_offsets, interleaved=ctx.interleaved, inplace=True,
+                conjugate=True,
+            )
+            apply_rotary(
+                dk, cos_k, sin_k, seqlen_offsets, interleaved=ctx.interleaved, inplace=True,
+                conjugate=True,
+            )
+        return dqkv, None, None, None, None, None, None


-apply_rotary_emb_qkv_ = ApplyRotaryEmbQKV_.apply
+def apply_rotary_emb_qkv_(
+    qkv,
+    cos,
+    sin,
+    cos_k=None,
+    sin_k=None,
+    interleaved=False,
+    seqlen_offsets: Union[int, torch.Tensor] = 0,
+):
+    """
+    Arguments:
+        qkv: (batch_size, seqlen, 3, nheads, headdim)
+        cos, sin: (seqlen, rotary_dim / 2)
+        cos_k, sin_k: (seqlen, rotary_dim / 2), optional
+        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
+            1st half and 2nd half (GPT-NeoX style).
+        seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
+            Most commonly used in inference when we have KV cache.
+    Return:
+        qkv: (batch_size, seqlen, 3, nheads, headdim)
+    rotary_dim must be <= headdim
+    Apply rotary embedding *inplace* to the first rotary_dim of Q and K.
+    """
+    return ApplyRotaryEmbQKV_.apply(qkv, cos, sin, cos_k, sin_k, interleaved, seqlen_offsets)
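A note on the fused path in ApplyRotaryEmbQKV_.forward above: when qkv is contiguous, slicing out q and k and merging the (2, nheads) axes yields a view over the same storage, so a single in-place kernel launch rotates both. Below is a small standalone illustration of that reshape; the toy sizes are made up and it runs on CPU, purely for demonstration.

# Toy illustration (CPU, made-up sizes): q and k merge into one "(t h)" head axis
# that aliases the original qkv storage, so one in-place rotary call updates both.
import torch
from einops import rearrange

batch, seqlen, nheads, headdim = 2, 3, 4, 8
qkv = torch.randn(batch, seqlen, 3, nheads, headdim)   # contiguous
qk = rearrange(qkv[:, :, :2], "b s t h d -> b s (t h) d")
print(qk.shape)                                        # torch.Size([2, 3, 8, 8])
print(qk.data_ptr() == qkv.data_ptr())                 # True: a view, not a copy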
class ApplyRotaryEmbKV_(torch.autograd.Function):
    @staticmethod
-    def forward(ctx, kv, cos, sin, interleaved=False):
-        """
-        kv: (batch_size, seqlen, 2, nheads, headdim)
-        cos, sin: (seqlen, rotary_dim / 2)
-        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
-            1st half and 2nd half (GPT-NeoX style).
-        rotary_dim must be <= headdim
-        Apply rotary embedding *inplace* to the first rotary_dim of k.
-        """
-        batch, seqlen, two, nheads, headdim = kv.shape
-        assert two == 2
-        rotary_seqlen, rotary_dim = cos.shape
-        rotary_dim *= 2
-        assert rotary_dim <= headdim
-        assert seqlen <= rotary_seqlen
-        k_ro = kv[:, :, 0, :, :rotary_dim]
-        k1, k2 = k_ro.chunk(2, dim=-1) if not interleaved else (k_ro[..., ::2], k_ro[..., 1::2])
-        rotary_emb.apply_rotary(
-            k1, k2, rearrange(cos[:seqlen], "s d -> s 1 d"),
-            rearrange(sin[:seqlen], "s d -> s 1 d"), k1, k2, False,
-        )  # conj=False since this is the forward pass
-        ctx.save_for_backward(cos, sin)
+    def forward(ctx, kv, cos, sin, interleaved=False, seqlen_offsets: Union[int, torch.Tensor] = 0):
+        k = kv[:, :, 0]
+        apply_rotary(
+            k, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=True
+        )
+        if isinstance(seqlen_offsets, int):
+            ctx.save_for_backward(cos, sin)  # Can't save int with save_for_backward
+            ctx.seqlen_offsets = seqlen_offsets
+        else:
+            ctx.save_for_backward(cos, sin, seqlen_offsets)
+            ctx.seqlen_offsets = None
        ctx.interleaved = interleaved
        return kv

    @staticmethod
    def backward(ctx, dkv):
-        cos, sin = ctx.saved_tensors
-        _, seqlen, _, _, headdim = dkv.shape
-        rotary_dim = cos.shape[-1]
-        rotary_dim *= 2
-        dk_ro = dkv[:, :, 0, :, :rotary_dim]
-        dk1, dk2 = (
-            dk_ro.chunk(2, dim=-1)
-            if not ctx.interleaved
-            else (dk_ro[..., ::2], dk_ro[..., 1::2])
-        )
-        rotary_emb.apply_rotary(
-            dk1, dk2, rearrange(cos[:seqlen], "s d -> s 1 d"),
-            rearrange(sin[:seqlen], "s d -> s 1 d"), dk1, dk2, True,
-        )  # conj=True since this is the backward pass
-        return dkv, None, None, None
+        seqlen_offsets = ctx.seqlen_offsets
+        if seqlen_offsets is None:
+            cos, sin, seqlen_offsets = ctx.saved_tensors
+        else:
+            cos, sin = ctx.saved_tensors
+        apply_rotary(
+            dkv[:, :, 0],
+            cos,
+            sin,
+            seqlen_offsets=seqlen_offsets,
+            interleaved=ctx.interleaved,
+            inplace=True,
+            conjugate=True,
+        )
+        return dkv, None, None, None, None


-apply_rotary_emb_kv_ = ApplyRotaryEmbKV_.apply
+def apply_rotary_emb_kv_(
+    kv,
+    cos,
+    sin,
+    interleaved=False,
+    seqlen_offsets: Union[int, torch.Tensor] = 0,
+):
+    """
+    Arguments:
+        kv: (batch_size, seqlen, 2, nheads, headdim)
+        cos, sin: (seqlen, rotary_dim / 2)
+        interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead of
+            1st half and 2nd half (GPT-NeoX style).
+        seqlen_offsets: (batch_size,) or int. Each sequence in Q and K is shifted by this amount.
+            Most commonly used in inference when we have KV cache.
+    Return:
+        kv: (batch_size, seqlen, 2, nheads, headdim)
+    rotary_dim must be <= headdim
+    Apply rotary embedding *inplace* to the first rotary_dim of K.
+    """
+    return ApplyRotaryEmbKV_.apply(kv, cos, sin, interleaved, seqlen_offsets)
class RotaryEmbedding(torch.nn.Module):
    """
    The rotary position embeddings from RoFormer_ (Su et. al).
...

@@ -372,57 +385,70 @@ class RotaryEmbedding(torch.nn.Module):
                self._sin_k_cached = (torch.sin(freqs) / scale).to(dtype)

    def forward(
-        self, qkv: torch.Tensor, kv: Optional[torch.Tensor] = None, seqlen_offset: int = 0
+        self,
+        qkv: torch.Tensor,
+        kv: Optional[torch.Tensor] = None,
+        seqlen_offset: Union[int, torch.Tensor] = 0,
+        max_seqlen: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        qkv: (batch, seqlen, 3, nheads, headdim) if kv is none,
            else it's just q of shape (batch, seqlen, nheads, headdim)
        kv: (batch, seqlen, 2, nheads, headdim)
-        seqlen_offset: can be used in generation where the qkv being passed in is only the last
-        token in the batch.
+        seqlen_offset: (batch_size,) or int. Each sequence in x is shifted by this amount.
+            Most commonly used in inference when we have KV cache.
+            If it's a tensor of shape (batch_size,), then to update the cos / sin cache, one
+            should pass in max_seqlen, which will update the cos / sin cache up to that length.
        Apply rotary embedding *inplace* to qkv and / or kv.
        """
        seqlen = qkv.shape[1]
-        self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
+        if isinstance(seqlen_offset, int):
+            self._update_cos_sin_cache(seqlen + seqlen_offset, device=qkv.device, dtype=qkv.dtype)
+        elif max_seqlen is not None:
+            self._update_cos_sin_cache(max_seqlen, device=qkv.device, dtype=qkv.dtype)
        if kv is None:
            if self.scale is None:
                return apply_rotary_emb_qkv_(
                    qkv,
-                    self._cos_cached[seqlen_offset:],
-                    self._sin_cached[seqlen_offset:],
-                    None,
-                    None,
-                    self.interleaved,
+                    self._cos_cached,
+                    self._sin_cached,
+                    interleaved=self.interleaved,
+                    seqlen_offsets=seqlen_offset,
                )
            else:
                return apply_rotary_emb_qkv_(
                    qkv,
-                    self._cos_cached[seqlen_offset:],
-                    self._sin_cached[seqlen_offset:],
-                    self._cos_k_cached[seqlen_offset:],
-                    self._sin_k_cached[seqlen_offset:],
-                    self.interleaved,
+                    self._cos_cached,
+                    self._sin_cached,
+                    self._cos_k_cached,
+                    self._sin_k_cached,
+                    interleaved=self.interleaved,
+                    seqlen_offsets=seqlen_offset,
                )
        else:
            q = qkv
            q = apply_rotary_emb_func(
                q,
-                self._cos_cached[seqlen_offset:],
-                self._sin_cached[seqlen_offset:],
-                self.interleaved,
-                True,
+                self._cos_cached,
+                self._sin_cached,
+                interleaved=self.interleaved,
+                inplace=True,
+                seqlen_offsets=seqlen_offset,
            )
            if self.scale is None:
                kv = apply_rotary_emb_kv_(
                    kv,
-                    self._cos_cached[seqlen_offset:],
-                    self._sin_cached[seqlen_offset:],
-                    self.interleaved,
+                    self._cos_cached,
+                    self._sin_cached,
+                    interleaved=self.interleaved,
+                    seqlen_offsets=seqlen_offset,
                )
            else:
                kv = apply_rotary_emb_kv_(
                    kv,
-                    self._cos_k_cached[seqlen_offset:],
-                    self._sin_k_cached[seqlen_offset:],
-                    self.interleaved,
+                    self._cos_k_cached,
+                    self._sin_k_cached,
+                    interleaved=self.interleaved,
+                    seqlen_offsets=seqlen_offset,
                )
            return q, kv
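Putting the RotaryEmbedding.forward change together: an int seqlen_offset keeps the old behavior, while a (batch_size,) tensor of offsets should be paired with max_seqlen so the cos/sin cache is sized up front. A hypothetical call pattern is sketched below; the constructor arguments and all sizes are assumptions for illustration, not taken from this diff.

# Hypothetical call pattern; constructor arguments and sizes are illustrative only.
import torch
from flash_attn.layers.rotary import RotaryEmbedding

rotary = RotaryEmbedding(dim=64, interleaved=False, device="cuda")
qkv = torch.randn(2, 1, 3, 8, 64, dtype=torch.float16, device="cuda")

# Decoding step: sequence 0 already holds 5 tokens in the KV cache, sequence 1 holds 17.
offsets = torch.tensor([5, 17], dtype=torch.int32, device="cuda")
qkv = rotary(qkv, seqlen_offset=offsets, max_seqlen=128)  # q and k rotated in place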
flash_attn/models/gpt_neox.py

...
@@ -68,6 +68,8 @@ def remap_state_dict_hf_gpt_neox(state_dict, config):
        # We don't store these biases
        state_dict.pop(f"transformer.layers.{l}.attention.bias")
        state_dict.pop(f"transformer.layers.{l}.attention.masked_bias")
+        # We don't store these
+        state_dict.pop(f"transformer.layers.{l}.attention.rotary_emb.inv_freq", None)
    # GPT-NeoX stores Wqkv as ((nheads 3 headdim), hidden_dim)
    # while we store Wqkv as ((3 nheads headdim), hidden_dim)
    headdim = config.hidden_size // config.num_attention_heads
...
@@ -89,11 +91,6 @@ def remap_state_dict_hf_gpt_neox(state_dict, config):
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )
-        key = re.sub(
-            r"^transformer.layers.(\d+).attention.rotary_emb.",
-            r"transformer.layers.\1.mixer.rotary_emb.",
-            key,
-        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
...
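The remap change above drops the attention.rotary_emb key renaming and instead pops the HF checkpoint's rotary_emb.inv_freq buffer, presumably because nothing rotary-related needs to be carried over any more. The default argument to pop is what keeps this safe; a tiny illustration with a made-up dictionary:

# Made-up dictionary purely to illustrate why pop(..., None) is used.
state_dict = {
    "transformer.layers.0.attention.rotary_emb.inv_freq": "buffer",
    "transformer.layers.0.attention.query_key_value.weight": "weight",
}
# Present: removed.  Absent (e.g. layer 1): no KeyError thanks to the None default.
state_dict.pop("transformer.layers.0.attention.rotary_emb.inv_freq", None)
state_dict.pop("transformer.layers.1.attention.rotary_emb.inv_freq", None)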
flash_attn/ops/triton/linear.py

-# Adapted on https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py
+# Adapted from https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py
# and https://github.com/openai/triton/blob/master/python/triton/ops/matmul.py
from typing import Optional

import torch
import triton
import triton.language as tl
from torch.autograd.function import FunctionCtx
from torch.cuda.amp import custom_fwd
-from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

from flash_attn.ops.triton.k_activations import (
...
flash_attn/ops/triton/rotary.py (new file)

from typing import Union

import torch
import triton
import triton.language as tl


# @triton.autotune(
#     configs=[
#         triton.Config({"BLOCK_M": 2}),
#         triton.Config({"BLOCK_M": 4}),
#         triton.Config({"BLOCK_M": 8}),
#         triton.Config({"BLOCK_M": 16}),
#     ],
#     key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"]
# )
@triton.jit
def rotary_kernel(
    OUT,  # Pointers to matrices
    X,
    COS,
    SIN,
    SEQLEN_OFFSETS,  # this could be int or a pointer
    # Matrix dimensions
    seqlen,
    nheads,
    rotary_dim,
    seqlen_ro,
    CACHE_KEY_SEQLEN,
    # strides
    stride_out_batch,
    stride_out_seqlen,
    stride_out_nheads,
    stride_out_headdim,
    stride_x_batch,
    stride_x_seqlen,
    stride_x_nheads,
    stride_x_headdim,
    # Meta-parameters
    BLOCK_K: tl.constexpr,
    IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
    INTERLEAVED: tl.constexpr,
    CONJUGATE: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    pid_batch = tl.program_id(axis=1)
    pid_head = tl.program_id(axis=2)
    rotary_dim_half = rotary_dim // 2

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rk = tl.arange(0, BLOCK_K // 2)
    if not IS_SEQLEN_OFFSETS_TENSOR:
        rm_cs = rm + SEQLEN_OFFSETS
    else:
        rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)
    X = X + (
        pid_batch * stride_x_batch
        + rm[:, None] * stride_x_seqlen
        + pid_head * stride_x_nheads
        + rk[None, :] * stride_x_headdim * (2 if INTERLEAVED else 1)
    )
    COS = COS + (rm_cs[:, None] * rotary_dim_half + rk[None, :])
    SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk[None, :])
    cos = tl.load(
        COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk[None, :] < rotary_dim_half), other=1.0
    ).to(tl.float32)
    sin = tl.load(
        SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk[None, :] < rotary_dim_half), other=0.0
    ).to(tl.float32)
    x0 = tl.load(
        X, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half), other=0.0
    ).to(tl.float32)
    x1 = tl.load(
        X + stride_x_headdim * (1 if INTERLEAVED else rotary_dim_half),
        mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half),
        other=0.0,
    ).to(tl.float32)
    if not CONJUGATE:
        o0 = x0 * cos - x1 * sin
        o1 = x0 * sin + x1 * cos
    else:
        o0 = x0 * cos + x1 * sin
        o1 = -x0 * sin + x1 * cos
    # write back result
    OUT = OUT + (
        pid_batch * stride_out_batch
        + rm[:, None] * stride_out_seqlen
        + pid_head * stride_out_nheads
        + rk[None, :] * stride_out_headdim * (2 if INTERLEAVED else 1)
    )
    tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half))
    tl.store(
        OUT + stride_out_headdim * (1 if INTERLEAVED else rotary_dim_half),
        o1,
        mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half),
    )


def apply_rotary(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    interleaved=False,
    inplace=False,
    conjugate=False,
) -> torch.Tensor:
    """
    Arguments:
        x: (batch, seqlen, nheads, headdim)
        cos: (seqlen_ro, rotary_dim / 2)
        sin: (seqlen_ro, rotary_dim / 2)
        seqlen_offsets: integer or integer tensor of size (batch,)
    Returns:
        y: (batch, seqlen, nheads, headdim)
    """
    batch, seqlen, nheads, headdim = x.shape
    seqlen_ro, rotary_dim = cos.shape
    assert sin.shape == cos.shape
    rotary_dim *= 2
    assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
    assert headdim <= 256, "Only support headdim <= 256"
    assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
    assert (
        cos.dtype == sin.dtype
    ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
    assert (
        x.dtype == cos.dtype
    ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"

    cos, sin = cos.contiguous(), sin.contiguous()
    if isinstance(seqlen_offsets, torch.Tensor):
        assert seqlen_offsets.shape == (batch,)
        assert seqlen_offsets.dtype in [torch.int32, torch.int64]
        seqlen_offsets = seqlen_offsets.contiguous()
    else:
        assert seqlen_offsets + seqlen <= seqlen_ro

    output = torch.empty_like(x) if not inplace else x
    if rotary_dim < headdim and not inplace:
        output[..., rotary_dim:].copy_(x[..., rotary_dim:])

    BLOCK_K = (
        32
        if rotary_dim <= 32
        else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
    )
    grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads)  # noqa
    BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
    rotary_kernel[grid](
        output,  # data ptrs
        x,
        cos,
        sin,
        seqlen_offsets,
        seqlen,  # shapes
        nheads,
        rotary_dim,
        seqlen_ro,
        seqlen // 128,  # key for triton cache (limit number of compilations)
        output.stride(0),  # strides
        output.stride(1),
        output.stride(2),
        output.stride(3),
        x.stride(0),
        x.stride(1),
        x.stride(2),
        x.stride(3),
        BLOCK_K,
        isinstance(seqlen_offsets, torch.Tensor),
        interleaved,
        conjugate,
        BLOCK_M,
    )
    return output
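As a sanity reference for what rotary_kernel computes on each (x0, x1) pair: the non-conjugate path applies the 2-D rotation o0 = x0*cos - x1*sin, o1 = x0*sin + x1*cos, and the conjugate path applies its inverse, which is how the backward passes above undo the forward rotation. A short round-trip check through the new apply_rotary, with toy shapes and CUDA assumed:

# Round-trip check through the new Triton path (CUDA assumed; toy shapes).
import torch
from flash_attn.ops.triton.rotary import apply_rotary

batch, seqlen, nheads, headdim = 2, 16, 4, 64
x = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
angle = torch.rand(seqlen, headdim // 2, device="cuda") * 6.28
cos, sin = angle.cos().half(), angle.sin().half()

y = apply_rotary(x, cos, sin)                       # o0 = x0*cos - x1*sin, o1 = x0*sin + x1*cos
x_back = apply_rotary(y, cos, sin, conjugate=True)  # inverse rotation
print((x_back - x).abs().max().item())              # small: only fp16 round-off remains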
tests/models/test_gpt_generation_parallel.py

...
@@ -131,6 +131,8 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
        )
        print(out_cg.sequences)

+    parallel_state.destroy_model_parallel()
+
    if not rotary:
        out_hf = model_hf.generate(
            input_ids=input_ids,
...
@@ -171,5 +173,3 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
        ).abs().max().item() < 3 * (
            torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)
        ).abs().max().item()
-
-    parallel_state.destroy_model_parallel()
tests/test_rotary.py

import math
import random

import pytest
import torch
import torch.nn.functional as F
+from einops import rearrange

-from flash_attn.layers.rotary import apply_rotary_emb_func, apply_rotary_emb_torch
+from flash_attn.layers.rotary import apply_rotary_emb, apply_rotary_emb_torch
+from flash_attn.layers.rotary import apply_rotary_emb_qkv_, apply_rotary_emb_kv_

is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0)
...

@@ -13,33 +15,198 @@ is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0)
@pytest.mark.parametrize(
    "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
+@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
+# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
-# @pytest.mark.parametrize('rotary_fraction', [0.5])
+# @pytest.mark.parametrize('rotary_fraction', [1.0])
+@pytest.mark.parametrize("interleaved", [False, True])
+# @pytest.mark.parametrize('interleaved', [False])
@pytest.mark.parametrize("inplace", [False, True])
# @pytest.mark.parametrize('inplace', [False])
-def test_rotary_single_tensor(inplace, rotary_fraction, dtype):
+def test_rotary_emb_func(inplace, interleaved, rotary_fraction, seqlen_offsets_type, dtype):
    rtol = 1e-3
    batch_size = 32
    nheads = 4
    seqlen = 217
    headdim = 128
+    device = "cuda"
    torch.manual_seed(42)
    x = torch.randn(
-        batch_size, seqlen, nheads, headdim, dtype=dtype, device="cuda", requires_grad=True
+        batch_size, seqlen, nheads, headdim, dtype=dtype, device=device, requires_grad=True
    )
    x_pt = x.detach().clone().requires_grad_()
    rotary_dim = int(rotary_fraction * headdim)
    assert rotary_dim % 2 == 0
-    angle = torch.randn(seqlen, rotary_dim // 2, device="cuda")
+    angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
    cos = torch.cos(angle).to(dtype=dtype)
    sin = torch.sin(angle).to(dtype=dtype)
-    out = apply_rotary_emb_func(x, cos, sin, inplace)
-    out_pt = apply_rotary_emb_torch(x_pt, cos, sin)
-    # Numerical error if we just do any arithmetic
-    atol = ((out + 0.3 - 0.3) - out).abs().max().item()
-    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
+    if seqlen_offsets_type == 0:
+        seqlen_offsets = 0
+    elif seqlen_offsets_type is int:
+        seqlen_offsets = torch.randint(0, seqlen + 1, (1,)).item()
+    elif seqlen_offsets_type is torch.Tensor:
+        seqlen_offsets = torch.randint(
+            0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device
+        )
+    out = apply_rotary_emb(
+        x, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=inplace
+    )
+    if seqlen_offsets_type is torch.Tensor:
+        arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
+        idx = rearrange(seqlen_offsets, "b -> b 1") + arange
+        cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
+        sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
+    else:
+        cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
+        sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
+    out_pt = apply_rotary_emb_torch(
+        x_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
+    ).to(dtype=dtype)
+    print(f"Output max diff: {(out - out_pt).abs().max().item()}")
    g = torch.randn_like(out)
    g_pt = g.clone()
    # If inplace=True, we might modify the gradient inplace
    out.backward(g)
    out_pt.backward(g_pt)
+    print(f"Grad max diff: {(x.grad - x_pt.grad).abs().max().item()}")
+    if not inplace:
+        assert torch.equal(x, x_pt)
+    # Numerical error if we just do any arithmetic
+    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
+    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
+    atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item()
+    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=2 * atol)
The two tests below are new in this commit:

@pytest.mark.parametrize(
    "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False])
def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_type, dtype):
    rtol = 1e-3
    batch_size = 32
    nheads = 4
    seqlen = 512
    headdim = 128
    device = "cuda"
    torch.manual_seed(42)
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, headdim, dtype=dtype, device=device, requires_grad=True
    )
    qkv_pt = qkv.detach().clone().requires_grad_()
    rotary_dim = int(rotary_fraction * headdim)
    assert rotary_dim % 2 == 0
    angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
    cos = torch.cos(angle).to(dtype=dtype)
    sin = torch.sin(angle).to(dtype=dtype)
    if seqlen_offsets_type == 0:
        seqlen_offsets = 0
    elif seqlen_offsets_type is int:
        seqlen_offsets = torch.randint(0, seqlen + 1, (1,)).item()
    elif seqlen_offsets_type is torch.Tensor:
        seqlen_offsets = torch.randint(
            0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device
        )
    out = apply_rotary_emb_qkv_(
        qkv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved
    )
    if seqlen_offsets_type is torch.Tensor:
        arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
        idx = rearrange(seqlen_offsets, "b -> b 1") + arange
        cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
        sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
    else:
        cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
        sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
    q_pt = apply_rotary_emb_torch(
        qkv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    k_pt = apply_rotary_emb_torch(
        qkv_pt[:, :, 1].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    out_pt = torch.stack([q_pt, k_pt, qkv_pt[:, :, 2]], dim=2)
    print(f"Output max diff: {(out - out_pt).abs().max().item()}")
    g = torch.randn_like(out)
    g_pt = g.clone()
    # Since inplace=True, we modify the gradient inplace
    out.backward(g)
    out_pt.backward(g_pt)
    print(f"Grad max diff: {(qkv.grad - qkv_pt.grad).abs().max().item()}")
    # Numerical error if we just do any arithmetic
    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
    atol = ((qkv_pt.grad + 0.3 - 0.3) - qkv_pt.grad).abs().max().item()
    assert torch.allclose(qkv.grad, qkv_pt.grad, rtol=rtol, atol=2 * atol)


@pytest.mark.parametrize(
    "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False])
def test_rotary_emb_kv(interleaved, rotary_fraction, seqlen_offsets_type, dtype):
    rtol = 1e-3
    batch_size = 32
    nheads = 4
    seqlen = 781
    headdim = 64
    device = "cuda"
    torch.manual_seed(42)
    kv = torch.randn(
        batch_size, seqlen, 2, nheads, headdim, dtype=dtype, device=device, requires_grad=True
    )
    kv_pt = kv.detach().clone().requires_grad_()
    rotary_dim = int(rotary_fraction * headdim)
    assert rotary_dim % 2 == 0
    angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
    cos = torch.cos(angle).to(dtype=dtype)
    sin = torch.sin(angle).to(dtype=dtype)
    if seqlen_offsets_type == 0:
        seqlen_offsets = 0
    elif seqlen_offsets_type is int:
        seqlen_offsets = torch.randint(0, seqlen + 1, (1,)).item()
    elif seqlen_offsets_type is torch.Tensor:
        seqlen_offsets = torch.randint(
            0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device
        )
    out = apply_rotary_emb_kv_(
        kv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved
    )
    if seqlen_offsets_type is torch.Tensor:
        arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
        idx = rearrange(seqlen_offsets, "b -> b 1") + arange
        cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
        sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
    else:
        cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
        sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
    k_pt = apply_rotary_emb_torch(
        kv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    out_pt = torch.stack([k_pt, kv_pt[:, :, 1]], dim=2)
    print(f"Output max diff: {(out - out_pt).abs().max().item()}")
    g = torch.randn_like(out)
    g_pt = g.clone()
    # Since inplace=True, we modify the gradient inplace
    out.backward(g)
    out_pt.backward(g_pt)
    print(f"Grad max diff: {(kv.grad - kv_pt.grad).abs().max().item()}")
    # Numerical error if we just do any arithmetic
    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
    atol = ((kv_pt.grad + 0.3 - 0.3) - kv_pt.grad).abs().max().item()
    assert torch.allclose(kv.grad, kv_pt.grad, rtol=rtol, atol=2 * atol)
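A note on the tolerance pattern shared by these tests: ((out_pt + 0.3 - 0.3) - out_pt).abs().max() forces a round trip through the low-precision float grid, so it appears to be an empirical estimate of the representable error at the outputs' magnitude, and twice that value is then used as atol. A standalone illustration with arbitrary values:

# Arbitrary values; the point is the magnitude of the result, not the exact number.
import torch

out_pt = torch.randn(1024).half() * 3
# Adding and subtracting 0.3 perturbs each element by a few ULPs at its own magnitude,
# so the max difference approximates the working precision for this dtype and scale.
atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
print(atol)  # roughly 1e-3 to 1e-2 for float16 values of this size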