gaoqiong / flash-attention

Commit b28ec236, authored Sep 03, 2023 by Tri Dao
[Rotary] Implement varlen rotary
Parent: 861c8257

Showing 3 changed files with 181 additions and 89 deletions (+181 -89):

    flash_attn/layers/rotary.py       +35  -10
    flash_attn/ops/triton/rotary.py   +38  -13
    tests/test_rotary.py             +108  -66
flash_attn/layers/rotary.py

@@ -42,27 +42,37 @@ class ApplyRotaryEmb(torch.autograd.Function):
         interleaved=False,
         inplace=False,
         seqlen_offsets: Union[int, torch.Tensor] = 0,
+        cu_seqlens: Optional[torch.Tensor] = None,
+        max_seqlen: Optional[int] = None,
     ):
         out = apply_rotary(
             x,
             cos,
             sin,
             seqlen_offsets=seqlen_offsets,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=max_seqlen,
             interleaved=interleaved,
             inplace=inplace,
         )
         if isinstance(seqlen_offsets, int):
-            ctx.save_for_backward(cos, sin)  # Can't save int with save_for_backward
+            ctx.save_for_backward(cos, sin, cu_seqlens)  # Can't save int with save_for_backward
             ctx.seqlen_offsets = seqlen_offsets
         else:
-            ctx.save_for_backward(cos, sin, seqlen_offsets)
+            ctx.save_for_backward(cos, sin, cu_seqlens, seqlen_offsets)
             ctx.seqlen_offsets = None
         ctx.interleaved = interleaved
         ctx.inplace = inplace
+        ctx.max_seqlen = max_seqlen
         return out if not inplace else x

     @staticmethod
     def backward(ctx, do):
         seqlen_offsets = ctx.seqlen_offsets
         if seqlen_offsets is None:
-            cos, sin, seqlen_offsets = ctx.saved_tensors
+            cos, sin, cu_seqlens, seqlen_offsets = ctx.saved_tensors
         else:
-            cos, sin = ctx.saved_tensors
+            cos, sin, cu_seqlens = ctx.saved_tensors
         # TD [2023-09-02]: For some reason Triton (2.0.0.post1) errors with
         # "[CUDA]: invalid device context", and cloning makes it work. Idk why. Triton 2.1.0 works.
         if not ctx.interleaved and not ctx.inplace:
@@ -72,31 +82,46 @@ class ApplyRotaryEmb(torch.autograd.Function):
             cos,
             sin,
             seqlen_offsets=seqlen_offsets,
+            cu_seqlens=cu_seqlens,
+            max_seqlen=ctx.max_seqlen,
             interleaved=ctx.interleaved,
             inplace=ctx.inplace,
             conjugate=True,
         )
-        return dx, None, None, None, None, None
+        return dx, None, None, None, None, None, None, None


 def apply_rotary_emb(
-    x, cos, sin, interleaved=False, inplace=False, seqlen_offsets: Union[int, torch.Tensor] = 0
+    x,
+    cos,
+    sin,
+    interleaved=False,
+    inplace=False,
+    seqlen_offsets: Union[int, torch.Tensor] = 0,
+    cu_seqlens: Optional[torch.Tensor] = None,
+    max_seqlen: Optional[int] = None,
 ):
     """
     Arguments:
-        x: (batch_size, seqlen, nheads, headdim)
+        x: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
+            else (total_seqlen, nheads, headdim)
         cos, sin: (seqlen_rotary, rotary_dim / 2)
         interleaved: if True, rotate pairs of even and odd dimensions (GPT-J style) instead
            of 1st half and 2nd half (GPT-NeoX style).
         inplace: if True, apply rotary embedding in-place.
         seqlen_offsets: (batch_size,) or int. Each sequence in x is shifted by this amount.
             Most commonly used in inference when we have KV cache.
+        cu_seqlens: (batch + 1,) or None
+        max_seqlen: int
     Return:
-        out: (batch_size, seqlen, nheads, headdim)
+        out: (batch_size, seqlen, nheads, headdim) if cu_seqlens is None
+            else (total_seqlen, nheads, headdim)
     rotary_dim must be <= headdim
     Apply rotary embedding to the first rotary_dim of x.
     """
-    return ApplyRotaryEmb.apply(x, cos, sin, interleaved, inplace, seqlen_offsets)
+    return ApplyRotaryEmb.apply(
+        x, cos, sin, interleaved, inplace, seqlen_offsets, cu_seqlens, max_seqlen
+    )


 # For backward compatibility
 ...
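As orientation for the new arguments, here is a minimal usage sketch. It mirrors the varlen test added later in this commit rather than adding anything new: the packed (total_seqlen, nheads, headdim) layout, cu_seqlens, and max_seqlen all come from flash_attn.bert_padding.unpad_input.

import math
import torch
from flash_attn.bert_padding import pad_input, unpad_input
from flash_attn.layers.rotary import apply_rotary_emb

# Sketch: rotary over variable-length sequences packed into one tensor.
batch_size, seqlen, nheads, headdim = 2, 16, 4, 128
device, dtype = "cuda", torch.float16
x = torch.randn(batch_size, seqlen, nheads, headdim, dtype=dtype, device=device)
lengths = torch.tensor([[11], [16]], device=device)
padding_mask = torch.arange(seqlen, device=device)[None, :] < lengths  # (batch, seqlen)

# Pack valid tokens into (total_seqlen, nheads, headdim); unpad_input also
# returns cu_seqlens of shape (batch + 1,) and max_seqlen.
x_unpad, indices, cu_seqlens, max_seqlen = unpad_input(x, padding_mask)

angle = torch.rand(seqlen, headdim // 2, device=device) * 2 * math.pi
cos, sin = torch.cos(angle).to(dtype), torch.sin(angle).to(dtype)

out_unpad = apply_rotary_emb(x_unpad, cos, sin, cu_seqlens=cu_seqlens, max_seqlen=max_seqlen)
out = pad_input(out_unpad, indices, batch_size, seqlen)  # back to the padded layout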
flash_attn/ops/triton/rotary.py

-from typing import Union
+from typing import Optional, Union

 import torch

@@ -21,6 +21,7 @@ def rotary_kernel(
     X,
     COS,
     SIN,
+    CU_SEQLENS,
     SEQLEN_OFFSETS,  # this could be int or a pointer
     # Matrix dimensions
     seqlen,
@@ -40,6 +41,7 @@ def rotary_kernel(
     # Meta-parameters
     BLOCK_K: tl.constexpr,
     IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
+    IS_VARLEN: tl.constexpr,
     INTERLEAVED: tl.constexpr,
     CONJUGATE: tl.constexpr,
     BLOCK_M: tl.constexpr,
@@ -49,9 +51,17 @@ def rotary_kernel(
     pid_head = tl.program_id(axis=2)
     rotary_dim_half = rotary_dim // 2

-    X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads
-    OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads
+    if not IS_VARLEN:
+        X = X + pid_batch * stride_x_batch + pid_head * stride_x_nheads
+        OUT = OUT + pid_batch * stride_out_batch + pid_head * stride_out_nheads
+    else:
+        start_idx = tl.load(CU_SEQLENS + pid_batch)
+        seqlen = tl.load(CU_SEQLENS + pid_batch + 1) - start_idx
+        X = X + start_idx * stride_x_seqlen + pid_head * stride_x_nheads
+        OUT = OUT + start_idx * stride_out_seqlen + pid_head * stride_out_nheads

+    if pid_m * BLOCK_M >= seqlen:
+        return
     rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
     if not IS_SEQLEN_OFFSETS_TENSOR:
         rm_cs = rm + SEQLEN_OFFSETS
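For readers new to the cu_seqlens convention, here is a small PyTorch illustration (not part of the commit) of the indexing the varlen branch performs: each program looks up the start row and true length of its sequence in the cumulative-lengths array, mirroring the tl.load(CU_SEQLENS + pid_batch) lines above.

import torch

def sequence_slice(x_packed: torch.Tensor, cu_seqlens: torch.Tensor, b: int) -> torch.Tensor:
    # x_packed: (total_seqlen, nheads, headdim); cu_seqlens: (batch + 1,) int32
    start_idx = int(cu_seqlens[b])                 # row where sequence b starts
    seqlen = int(cu_seqlens[b + 1]) - start_idx    # its unpadded length
    return x_packed[start_idx : start_idx + seqlen]  # (seqlen, nheads, headdim)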
@@ -134,20 +144,33 @@ def apply_rotary(
     cos: torch.Tensor,
     sin: torch.Tensor,
     seqlen_offsets: Union[int, torch.Tensor] = 0,
+    cu_seqlens: Optional[torch.Tensor] = None,
+    max_seqlen: Optional[int] = None,
     interleaved=False,
     inplace=False,
     conjugate=False,
 ) -> torch.Tensor:
     """
     Arguments:
-        x: (batch, seqlen, nheads, headdim)
+        x: (batch, seqlen, nheads, headdim) if cu_seqlens is None
+            else (total_seqlen, nheads, headdim).
         cos: (seqlen_ro, rotary_dim / 2)
         sin: (seqlen_ro, rotary_dim / 2)
         seqlen_offsets: integer or integer tensor of size (batch,)
+        cu_seqlens: (batch + 1,) or None
+        max_seqlen: int
     Returns:
         y: (batch, seqlen, nheads, headdim)
     """
-    batch, seqlen, nheads, headdim = x.shape
+    is_varlen = cu_seqlens is not None
+    if not is_varlen:
+        batch, seqlen, nheads, headdim = x.shape
+    else:
+        assert max_seqlen is not None, "If cu_seqlens is passed in, then max_seqlen must be passed"
+        total_seqlen, nheads, headdim = x.shape
+        batch_p_1 = cu_seqlens.shape[0]
+        batch = batch_p_1 - 1
+        seqlen = max_seqlen
     seqlen_ro, rotary_dim = cos.shape
     assert sin.shape == cos.shape
     rotary_dim *= 2
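The docstring leaves cu_seqlens construction to the caller. A minimal sketch (assuming the usual FlashAttention convention of cumulative sequence lengths prefixed with 0, which is what unpad_input produces) looks like this:

import torch
import torch.nn.functional as F

# Illustration only: for per-sequence lengths [3, 5, 2], cu_seqlens is
# [0, 3, 8, 10] (int32) and max_seqlen is 5; the packed input x then has
# shape (10, nheads, headdim).
lengths = torch.tensor([3, 5, 2], dtype=torch.int32, device="cuda")
cu_seqlens = F.pad(torch.cumsum(lengths, dim=0, dtype=torch.int32), (1, 0))
max_seqlen = int(lengths.max())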
@@ -187,22 +210,24 @@ def apply_rotary(
         x,
         cos,
         sin,
+        cu_seqlens,
         seqlen_offsets,
         seqlen,  # shapes
         nheads,
         rotary_dim,
         seqlen_ro,
         seqlen // 128,  # key for triton cache (limit number of compilations)
-        output.stride(0),  # strides
-        output.stride(1),
-        output.stride(2),
-        output.stride(3),
-        x.stride(0),
-        x.stride(1),
-        x.stride(2),
-        x.stride(3),
+        output.stride(0) if not is_varlen else 0,  # batch_strides if not varlen else 0
+        output.stride(-3),  # seqlen_stride or total_seqlen_stride
+        output.stride(-2),  # nheads_stride
+        output.stride(-1),  # headdim_stride
+        x.stride(0) if not is_varlen else 0,  # batch_strides if not varlen else 0
+        x.stride(-3),  # seqlen stride or total_seqlen_stride
+        x.stride(-2),  # nheads stride
+        x.stride(-1),  # headdim stride
         BLOCK_K,
         isinstance(seqlen_offsets, torch.Tensor),
+        is_varlen,
         interleaved,
         conjugate,
         BLOCK_M,
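A short aside on the stride arguments (illustration, not from the commit): indexing strides from the end lets one argument list serve both the padded 4-D layout and the packed 3-D layout, and the batch stride is simply passed as 0 in the varlen case since the kernel never uses it there.

import torch

# stride(-3)/stride(-2)/stride(-1) pick out the (seqlen, nheads, headdim)
# strides regardless of whether the tensor has a leading batch dimension.
x_padded = torch.empty(4, 217, 8, 128)   # (batch, seqlen, nheads, headdim)
x_packed = torch.empty(700, 8, 128)      # (total_seqlen, nheads, headdim)
print(x_padded.stride(-3), x_padded.stride(-2), x_padded.stride(-1))  # 1024 128 1
print(x_packed.stride(-3), x_packed.stride(-2), x_packed.stride(-1))  # 1024 128 1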
tests/test_rotary.py

@@ -7,10 +7,41 @@ import torch.nn.functional as F
 from einops import rearrange

 from flash_attn.layers.rotary import apply_rotary_emb, apply_rotary_emb_torch
 from flash_attn.layers.rotary import apply_rotary_emb_qkv_, apply_rotary_emb_kv_
+from flash_attn.bert_padding import pad_input, unpad_input

 is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0)


+def generate_cos_sin(seqlen, rotary_dim, device, dtype):
+    assert rotary_dim % 2 == 0
+    angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
+    cos = torch.cos(angle).to(dtype=dtype)
+    sin = torch.sin(angle).to(dtype=dtype)
+    return cos, sin
+
+
+def generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device):
+    if seqlen_offsets_type == 0:
+        return 0
+    elif seqlen_offsets_type is int:
+        return torch.randint(0, seqlen + 1, (1,)).item()
+    elif seqlen_offsets_type is torch.Tensor:
+        return torch.randint(0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device)
+
+
+def index_cos_sin(cos, sin, seqlen_offsets, seqlen):
+    if isinstance(seqlen_offsets, torch.Tensor):
+        batch_size = seqlen_offsets.shape[0]
+        arange = rearrange(torch.arange(seqlen, device=cos.device), "s -> 1 s")
+        idx = rearrange(seqlen_offsets, "b -> b 1") + arange
+        cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
+        sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
+    else:
+        cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
+        sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
+    return cos_pt, sin_pt
+
+
 @pytest.mark.parametrize("dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16]))
@@ -30,35 +61,18 @@ def test_rotary_emb_func(inplace, interleaved, rotary_fraction, seqlen_offsets_type, dtype):
     seqlen = 217
     headdim = 128
     device = "cuda"
+    rotary_dim = int(rotary_fraction * headdim)
     torch.manual_seed(42)
     x = torch.randn(
         batch_size, seqlen, nheads, headdim, dtype=dtype, device=device, requires_grad=True
     )
     x_pt = x.detach().clone().requires_grad_()
-    rotary_dim = int(rotary_fraction * headdim)
-    assert rotary_dim % 2 == 0
-    angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
-    cos = torch.cos(angle).to(dtype=dtype)
-    sin = torch.sin(angle).to(dtype=dtype)
-    if seqlen_offsets_type == 0:
-        seqlen_offsets = 0
-    elif seqlen_offsets_type is int:
-        seqlen_offsets = torch.randint(0, seqlen + 1, (1,)).item()
-    elif seqlen_offsets_type is torch.Tensor:
-        seqlen_offsets = torch.randint(
-            0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device
-        )
+    cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
+    seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)
     out = apply_rotary_emb(
         x, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=inplace
     )
-    if seqlen_offsets_type is torch.Tensor:
-        arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
-        idx = rearrange(seqlen_offsets, "b -> b 1") + arange
-        cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
-        sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
-    else:
-        cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
-        sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
+    cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
     out_pt = apply_rotary_emb_torch(
         x_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
     ).to(dtype=dtype)
@@ -96,35 +110,18 @@ def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_type, dtype):
     seqlen = 512
     headdim = 128
     device = "cuda"
+    rotary_dim = int(rotary_fraction * headdim)
     torch.manual_seed(42)
     qkv = torch.randn(
         batch_size, seqlen, 3, nheads, headdim, dtype=dtype, device=device, requires_grad=True
     )
     qkv_pt = qkv.detach().clone().requires_grad_()
-    rotary_dim = int(rotary_fraction * headdim)
-    assert rotary_dim % 2 == 0
-    angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
-    cos = torch.cos(angle).to(dtype=dtype)
-    sin = torch.sin(angle).to(dtype=dtype)
-    if seqlen_offsets_type == 0:
-        seqlen_offsets = 0
-    elif seqlen_offsets_type is int:
-        seqlen_offsets = torch.randint(0, seqlen + 1, (1,)).item()
-    elif seqlen_offsets_type is torch.Tensor:
-        seqlen_offsets = torch.randint(
-            0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device
-        )
+    cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
+    seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)
     out = apply_rotary_emb_qkv_(
         qkv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved
     )
-    if seqlen_offsets_type is torch.Tensor:
-        arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
-        idx = rearrange(seqlen_offsets, "b -> b 1") + arange
-        cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
-        sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
-    else:
-        cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
-        sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
+    cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
     q_pt = apply_rotary_emb_torch(
         qkv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
     ).to(dtype=dtype)
@@ -164,35 +161,16 @@ def test_rotary_emb_kv(interleaved, rotary_fraction, seqlen_offsets_type, dtype):
     seqlen = 781
     headdim = 64
     device = "cuda"
+    rotary_dim = int(rotary_fraction * headdim)
     torch.manual_seed(42)
     kv = torch.randn(
         batch_size, seqlen, 2, nheads, headdim, dtype=dtype, device=device, requires_grad=True
     )
     kv_pt = kv.detach().clone().requires_grad_()
-    rotary_dim = int(rotary_fraction * headdim)
-    assert rotary_dim % 2 == 0
-    angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
-    cos = torch.cos(angle).to(dtype=dtype)
-    sin = torch.sin(angle).to(dtype=dtype)
-    if seqlen_offsets_type == 0:
-        seqlen_offsets = 0
-    elif seqlen_offsets_type is int:
-        seqlen_offsets = torch.randint(0, seqlen + 1, (1,)).item()
-    elif seqlen_offsets_type is torch.Tensor:
-        seqlen_offsets = torch.randint(
-            0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device
-        )
-    out = apply_rotary_emb_kv_(
-        kv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved
-    )
-    if seqlen_offsets_type is torch.Tensor:
-        arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
-        idx = rearrange(seqlen_offsets, "b -> b 1") + arange
-        cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
-        sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
-    else:
-        cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
-        sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
+    cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
+    seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)
+    out = apply_rotary_emb_kv_(
+        kv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved
+    )
+    cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
     k_pt = apply_rotary_emb_torch(
         kv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
     ).to(dtype=dtype)
@@ -210,3 +188,67 @@ def test_rotary_emb_kv(interleaved, rotary_fraction, seqlen_offsets_type, dtype):
     assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
     atol = ((kv_pt.grad + 0.3 - 0.3) - kv_pt.grad).abs().max().item()
     assert torch.allclose(kv.grad, kv_pt.grad, rtol=rtol, atol=2 * atol)
+
+
+@pytest.mark.parametrize("dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16]))
+# @pytest.mark.parametrize("dtype", ([torch.float16]))
+@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
+# @pytest.mark.parametrize("seqlen_offsets_type", [0])
+@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
+# @pytest.mark.parametrize("rotary_fraction", [1.0])
+@pytest.mark.parametrize("interleaved", [False, True])
+# @pytest.mark.parametrize("interleaved", [True])
+@pytest.mark.parametrize("inplace", [False, True])
+# @pytest.mark.parametrize("inplace", [False])
+def test_rotary_emb_varlen_func(inplace, interleaved, rotary_fraction, seqlen_offsets_type, dtype):
+    rtol = 1e-3
+    batch_size = 32
+    nheads = 4
+    seqlen = 217
+    headdim = 128
+    device = "cuda"
+    rotary_dim = int(rotary_fraction * headdim)
+    torch.manual_seed(42)
+    x = torch.randn(batch_size, seqlen, nheads, headdim, dtype=dtype, device=device)
+    x_pt = x.detach().clone().requires_grad_()
+    lengths = torch.randint(max(1, seqlen - 20), seqlen + 1, (batch_size, 1), device=device)
+    padding_mask = rearrange(torch.arange(seqlen, device=device), "s -> 1 s") < lengths
+    x_unpad, indices, cu_seqlens, max_seqlen = unpad_input(x, padding_mask)
+    x_unpad_clone = x_unpad.clone()
+    x_unpad = x_unpad.requires_grad_()
+    cos, sin = generate_cos_sin(seqlen, rotary_dim, device, dtype)
+    seqlen_offsets = generate_seqlen_offsets(seqlen_offsets_type, batch_size, seqlen, device)
+    out_unpad = apply_rotary_emb(
+        x_unpad,
+        cos,
+        sin,
+        seqlen_offsets=seqlen_offsets,
+        interleaved=interleaved,
+        inplace=inplace,
+        cu_seqlens=cu_seqlens,
+        max_seqlen=max_seqlen,
+    )
+    out = pad_input(out_unpad, indices, batch_size, seqlen)
+    cos_pt, sin_pt = index_cos_sin(cos, sin, seqlen_offsets, seqlen)
+    out_pt = apply_rotary_emb_torch(
+        x_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
+    ).to(dtype=dtype)
+    out_pt = out_pt.masked_fill(rearrange(~padding_mask, "b s -> b s 1 1"), 0.0)
+    print(f"Output max diff: {(out - out_pt).abs().max().item()}")
+    g = torch.randn_like(out)
+    g_pt = g.clone()  # If inplace=True, we might modify the gradient inplace
+    out.backward(g)
+    out_pt.backward(g_pt)
+    x_grad = pad_input(x_unpad.grad, indices, batch_size, seqlen)
+    print(f"Grad max diff: {(x_grad - x_pt.grad).abs().max().item()}")
+    if not inplace:
+        assert torch.equal(x_unpad, x_unpad_clone)
+    # Numerical error if we just do any arithmetic
+    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
+    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
+    atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item()
+    assert torch.allclose(x_grad, x_pt.grad, rtol=rtol, atol=2 * atol)
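One detail of the tolerance setup worth calling out (illustration only, not from the commit): atol is derived from the dtype's own rounding noise rather than hard-coded. Adding and subtracting a constant perturbs the reference output by at most a few ulps, so the expression below estimates the noise floor for that dtype, and 2 * atol gives the comparison some headroom.

import torch

out_pt = torch.randn(4, 217, 4, 128).to(torch.float16)
atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
print(atol)  # roughly 1e-3 for float16, much smaller for float32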