flash-attention / Commits

Commit 942fcbf0, authored Sep 03, 2023 by Tri Dao

[Rotary] Implement rotary in Triton

parent 08e98471

Showing 6 changed files with 601 additions and 231 deletions (+601 -231)
flash_attn/layers/rotary.py                     +237  -211
flash_attn/models/gpt_neox.py                     +2    -5
flash_attn/ops/triton/linear.py                   +1    -3
flash_attn/ops/triton/rotary.py                 +182    -0
tests/models/test_gpt_generation_parallel.py      +2    -2
tests/test_rotary.py                            +177   -10
flash_attn/layers/rotary.py

(This diff is collapsed.)
flash_attn/models/gpt_neox.py

@@ -68,6 +68,8 @@ def remap_state_dict_hf_gpt_neox(state_dict, config):
        # We don't store these biases
        state_dict.pop(f"transformer.layers.{l}.attention.bias")
        state_dict.pop(f"transformer.layers.{l}.attention.masked_bias")
        # We don't store these
        state_dict.pop(f"transformer.layers.{l}.attention.rotary_emb.inv_freq", None)
        # GPT-NeoX stores Wqkv as ((nheads 3 headdim), hidden_dim)
        # while we store Wqkv as ((3 nheads headdim), hidden_dim)
        headdim = config.hidden_size // config.num_attention_heads

@@ -89,11 +91,6 @@ def remap_state_dict_hf_gpt_neox(state_dict, config):
            r"transformer.layers.\1.mixer.out_proj.",
            key,
        )
        key = re.sub(
            r"^transformer.layers.(\d+).attention.rotary_emb.",
            r"transformer.layers.\1.mixer.rotary_emb.",
            key,
        )
        return key

    state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
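The two layout comments above describe the only nontrivial part of this remap: Hugging Face GPT-NeoX packs the fused QKV projection rows as (nheads, 3, headdim), while flash-attn's GPT model packs them as (3, nheads, headdim). A minimal sketch of that row reordering with einops (illustrative only; the helper below is not the function in this file, and the shapes are made up):

# Hypothetical sketch of the Wqkv reordering described by the comments above;
# remap_state_dict_hf_gpt_neox in flash_attn may differ in details.
import torch
from einops import rearrange

nheads, headdim, hidden_dim = 4, 8, 32
# HF GPT-NeoX layout: output rows grouped as (nheads, 3, headdim)
Wqkv_hf = torch.randn(nheads * 3 * headdim, hidden_dim)
# flash-attn layout: output rows grouped as (3, nheads, headdim)
Wqkv_fa = rearrange(
    Wqkv_hf,
    "(nheads three headdim) d -> (three nheads headdim) d",
    three=3, nheads=nheads, headdim=headdim,
)
assert Wqkv_fa.shape == Wqkv_hf.shape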
flash_attn/ops/triton/linear.py

-# Adapted on https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py
+# Adapted from https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py
# and https://github.com/openai/triton/blob/master/python/triton/ops/matmul.py
from typing import Optional

import torch
import triton
import triton.language as tl
from torch.autograd.function import FunctionCtx
from torch.cuda.amp import custom_fwd
from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

from flash_attn.ops.triton.k_activations import (
...
flash_attn/ops/triton/rotary.py  (new file, 0 → 100644)
from typing import Union

import torch
import triton
import triton.language as tl


# @triton.autotune(
#     configs=[
#         triton.Config({"BLOCK_M": 2}),
#         triton.Config({"BLOCK_M": 4}),
#         triton.Config({"BLOCK_M": 8}),
#         triton.Config({"BLOCK_M": 16}),
#     ],
#     key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"]
# )
@triton.jit
def rotary_kernel(
    OUT,  # Pointers to matrices
    X,
    COS,
    SIN,
    SEQLEN_OFFSETS,  # this could be int or a pointer
    # Matrix dimensions
    seqlen,
    nheads,
    rotary_dim,
    seqlen_ro,
    CACHE_KEY_SEQLEN,
    # strides
    stride_out_batch,
    stride_out_seqlen,
    stride_out_nheads,
    stride_out_headdim,
    stride_x_batch,
    stride_x_seqlen,
    stride_x_nheads,
    stride_x_headdim,
    # Meta-parameters
    BLOCK_K: tl.constexpr,
    IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
    INTERLEAVED: tl.constexpr,
    CONJUGATE: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    pid_batch = tl.program_id(axis=1)
    pid_head = tl.program_id(axis=2)
    rotary_dim_half = rotary_dim // 2

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rk = tl.arange(0, BLOCK_K // 2)
    if not IS_SEQLEN_OFFSETS_TENSOR:
        rm_cs = rm + SEQLEN_OFFSETS
    else:
        rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)

    X = X + (
        pid_batch * stride_x_batch
        + rm[:, None] * stride_x_seqlen
        + pid_head * stride_x_nheads
        + rk[None, :] * stride_x_headdim * (2 if INTERLEAVED else 1)
    )
    COS = COS + (rm_cs[:, None] * rotary_dim_half + rk[None, :])
    SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk[None, :])
    cos = tl.load(
        COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk[None, :] < rotary_dim_half), other=1.0
    ).to(tl.float32)
    sin = tl.load(
        SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk[None, :] < rotary_dim_half), other=0.0
    ).to(tl.float32)
    x0 = tl.load(
        X, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half), other=0.0
    ).to(tl.float32)
    x1 = tl.load(
        X + stride_x_headdim * (1 if INTERLEAVED else rotary_dim_half),
        mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half),
        other=0.0,
    ).to(tl.float32)
    if not CONJUGATE:
        o0 = x0 * cos - x1 * sin
        o1 = x0 * sin + x1 * cos
    else:
        o0 = x0 * cos + x1 * sin
        o1 = -x0 * sin + x1 * cos
    # write back result
    OUT = OUT + (
        pid_batch * stride_out_batch
        + rm[:, None] * stride_out_seqlen
        + pid_head * stride_out_nheads
        + rk[None, :] * stride_out_headdim * (2 if INTERLEAVED else 1)
    )
    tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half))
    tl.store(
        OUT + stride_out_headdim * (1 if INTERLEAVED else rotary_dim_half),
        o1,
        mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half),
    )


def apply_rotary(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    interleaved=False,
    inplace=False,
    conjugate=False,
) -> torch.Tensor:
    """
    Arguments:
        x: (batch, seqlen, nheads, headdim)
        cos: (seqlen_ro, rotary_dim / 2)
        sin: (seqlen_ro, rotary_dim / 2)
        seqlen_offsets: integer or integer tensor of size (batch,)
    Returns:
        y: (batch, seqlen, nheads, headdim)
    """
    batch, seqlen, nheads, headdim = x.shape
    seqlen_ro, rotary_dim = cos.shape
    assert sin.shape == cos.shape
    rotary_dim *= 2
    assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
    assert headdim <= 256, "Only support headdim <= 256"
    assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
    assert (
        cos.dtype == sin.dtype
    ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
    assert (
        x.dtype == cos.dtype
    ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"

    cos, sin = cos.contiguous(), sin.contiguous()
    if isinstance(seqlen_offsets, torch.Tensor):
        assert seqlen_offsets.shape == (batch,)
        assert seqlen_offsets.dtype in [torch.int32, torch.int64]
        seqlen_offsets = seqlen_offsets.contiguous()
    else:
        assert seqlen_offsets + seqlen <= seqlen_ro

    output = torch.empty_like(x) if not inplace else x
    if rotary_dim < headdim and not inplace:
        output[..., rotary_dim:].copy_(x[..., rotary_dim:])

    BLOCK_K = (
        32
        if rotary_dim <= 32
        else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
    )
    grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads)  # noqa
    BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
    rotary_kernel[grid](
        output,  # data ptrs
        x,
        cos,
        sin,
        seqlen_offsets,
        seqlen,  # shapes
        nheads,
        rotary_dim,
        seqlen_ro,
        seqlen // 128,  # key for triton cache (limit number of compilations)
        output.stride(0),  # strides
        output.stride(1),
        output.stride(2),
        output.stride(3),
        x.stride(0),
        x.stride(1),
        x.stride(2),
        x.stride(3),
        BLOCK_K,
        isinstance(seqlen_offsets, torch.Tensor),
        interleaved,
        conjugate,
        BLOCK_M,
    )
    return output
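For orientation: the kernel's non-conjugate path rotates each channel pair as o0 = x0*cos - x1*sin and o1 = x0*sin + x1*cos, with x1 taken from the adjacent channel when INTERLEAVED and from the second half of rotary_dim otherwise. Below is a minimal usage sketch of apply_rotary (not part of the commit; it assumes a CUDA GPU and that the new module above is importable as flash_attn.ops.triton.rotary, with arbitrarily chosen shapes):

# Minimal smoke-test sketch for apply_rotary; assumes a CUDA device is available.
import math
import torch
from flash_attn.ops.triton.rotary import apply_rotary

batch, seqlen, nheads, headdim, rotary_dim = 2, 16, 4, 64, 64
x = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
# cos/sin cover seqlen_ro = 2 * seqlen positions so a nonzero offset stays in range.
angle = torch.rand(seqlen * 2, rotary_dim // 2, device="cuda") * 2 * math.pi
cos, sin = torch.cos(angle).half(), torch.sin(angle).half()
# Rotate, starting each sequence at position offset 3 (as when decoding with a KV cache).
y = apply_rotary(x, cos, sin, seqlen_offsets=3, interleaved=False)
assert y.shape == x.shape

Passing a (batch,)-shaped int32 tensor for seqlen_offsets instead of an int gives each sequence in the batch its own starting position, which is the case the kernel handles via IS_SEQLEN_OFFSETS_TENSOR.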
tests/models/test_gpt_generation_parallel.py

@@ -131,6 +131,8 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
    )
    print(out_cg.sequences)
    parallel_state.destroy_model_parallel()

    if not rotary:
        out_hf = model_hf.generate(
            input_ids=input_ids,

@@ -171,5 +173,3 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
    ).abs().max().item() < 3 * (
        torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)
    ).abs().max().item()

    parallel_state.destroy_model_parallel()
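For context, the final assertion in this test (unchanged in substance by this commit) bounds the error of the flash-attn output against a reference by three times the error of the Hugging Face output against that same reference, rather than using a fixed tolerance. A standalone sketch of that pattern with placeholder tensors (not the test's actual models or scores):

# Sketch of the relative-tolerance pattern used in the test above; all data is placeholder.
import torch

out_ref = torch.randn(4, 100)   # reference scores, e.g. computed in fp32
out_hf = out_ref + 1e-3         # baseline implementation's scores (placeholder offset)
out_impl = out_ref + 2e-3       # implementation under test (placeholder offset)

err_impl = (out_impl - out_ref).abs().max().item()
err_baseline = (out_hf - out_ref).abs().max().item()
# Accept the implementation if its error is within 3x the baseline's error.
assert err_impl < 3 * err_baseline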
tests/test_rotary.py

import math
import random

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange

-from flash_attn.layers.rotary import apply_rotary_emb_func, apply_rotary_emb_torch
from flash_attn.layers.rotary import apply_rotary_emb, apply_rotary_emb_torch
from flash_attn.layers.rotary import apply_rotary_emb_qkv_, apply_rotary_emb_kv_

is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0)

@@ -13,33 +15,198 @@ is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0)
    "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False])
@pytest.mark.parametrize("inplace", [False, True])
# @pytest.mark.parametrize('inplace', [False])
-def test_rotary_single_tensor(inplace, rotary_fraction, dtype):
def test_rotary_emb_func(inplace, interleaved, rotary_fraction, seqlen_offsets_type, dtype):
    rtol = 1e-3
    batch_size = 32
    nheads = 4
    seqlen = 217
    headdim = 128
    device = "cuda"
    torch.manual_seed(42)
    x = torch.randn(
-        batch_size, seqlen, nheads, headdim, dtype=dtype, device="cuda", requires_grad=True
        batch_size, seqlen, nheads, headdim, dtype=dtype, device=device, requires_grad=True
    )
    x_pt = x.detach().clone().requires_grad_()
    rotary_dim = int(rotary_fraction * headdim)
    assert rotary_dim % 2 == 0
-    angle = torch.randn(seqlen, rotary_dim // 2, device="cuda")
    angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
    cos = torch.cos(angle).to(dtype=dtype)
    sin = torch.sin(angle).to(dtype=dtype)
-    out = apply_rotary_emb_func(x, cos, sin, inplace)
-    out_pt = apply_rotary_emb_torch(x_pt, cos, sin)
-    # Numerical error if we just do any arithmetic
-    atol = ((out + 0.3 - 0.3) - out).abs().max().item()
-    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
    if seqlen_offsets_type == 0:
        seqlen_offsets = 0
    elif seqlen_offsets_type is int:
        seqlen_offsets = torch.randint(0, seqlen + 1, (1,)).item()
    elif seqlen_offsets_type is torch.Tensor:
        seqlen_offsets = torch.randint(
            0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device
        )
    out = apply_rotary_emb(
        x, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=inplace
    )
    if seqlen_offsets_type is torch.Tensor:
        arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
        idx = rearrange(seqlen_offsets, "b -> b 1") + arange
        cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
        sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
    else:
        cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
        sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
    out_pt = apply_rotary_emb_torch(
        x_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    print(f"Output max diff: {(out - out_pt).abs().max().item()}")
    g = torch.randn_like(out)
    g_pt = g.clone()
    # If inplace=True, we might modify the gradient inplace
    out.backward(g)
    out_pt.backward(g_pt)
    print(f"Grad max diff: {(x.grad - x_pt.grad).abs().max().item()}")
    if not inplace:
        assert torch.equal(x, x_pt)
    # Numerical error if we just do any arithmetic
    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
    atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item()
    assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=2 * atol)
@pytest.mark.parametrize(
    "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False])
def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_type, dtype):
    rtol = 1e-3
    batch_size = 32
    nheads = 4
    seqlen = 512
    headdim = 128
    device = "cuda"
    torch.manual_seed(42)
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, headdim, dtype=dtype, device=device, requires_grad=True
    )
    qkv_pt = qkv.detach().clone().requires_grad_()
    rotary_dim = int(rotary_fraction * headdim)
    assert rotary_dim % 2 == 0
    angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
    cos = torch.cos(angle).to(dtype=dtype)
    sin = torch.sin(angle).to(dtype=dtype)
    if seqlen_offsets_type == 0:
        seqlen_offsets = 0
    elif seqlen_offsets_type is int:
        seqlen_offsets = torch.randint(0, seqlen + 1, (1,)).item()
    elif seqlen_offsets_type is torch.Tensor:
        seqlen_offsets = torch.randint(
            0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device
        )
    out = apply_rotary_emb_qkv_(
        qkv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved
    )
    if seqlen_offsets_type is torch.Tensor:
        arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
        idx = rearrange(seqlen_offsets, "b -> b 1") + arange
        cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
        sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
    else:
        cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
        sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
    q_pt = apply_rotary_emb_torch(
        qkv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    k_pt = apply_rotary_emb_torch(
        qkv_pt[:, :, 1].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    out_pt = torch.stack([q_pt, k_pt, qkv_pt[:, :, 2]], dim=2)
    print(f"Output max diff: {(out - out_pt).abs().max().item()}")
    g = torch.randn_like(out)
    g_pt = g.clone()
    # Since inplace=True, we modify the gradient inplace
    out.backward(g)
    out_pt.backward(g_pt)
    print(f"Grad max diff: {(qkv.grad - qkv_pt.grad).abs().max().item()}")
    # Numerical error if we just do any arithmetic
    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
    atol = ((qkv_pt.grad + 0.3 - 0.3) - qkv_pt.grad).abs().max().item()
    assert torch.allclose(qkv.grad, qkv_pt.grad, rtol=rtol, atol=2 * atol)
@pytest.mark.parametrize(
    "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False])
def test_rotary_emb_kv(interleaved, rotary_fraction, seqlen_offsets_type, dtype):
    rtol = 1e-3
    batch_size = 32
    nheads = 4
    seqlen = 781
    headdim = 64
    device = "cuda"
    torch.manual_seed(42)
    kv = torch.randn(
        batch_size, seqlen, 2, nheads, headdim, dtype=dtype, device=device, requires_grad=True
    )
    kv_pt = kv.detach().clone().requires_grad_()
    rotary_dim = int(rotary_fraction * headdim)
    assert rotary_dim % 2 == 0
    angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
    cos = torch.cos(angle).to(dtype=dtype)
    sin = torch.sin(angle).to(dtype=dtype)
    if seqlen_offsets_type == 0:
        seqlen_offsets = 0
    elif seqlen_offsets_type is int:
        seqlen_offsets = torch.randint(0, seqlen + 1, (1,)).item()
    elif seqlen_offsets_type is torch.Tensor:
        seqlen_offsets = torch.randint(
            0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device
        )
    out = apply_rotary_emb_kv_(
        kv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved
    )
    if seqlen_offsets_type is torch.Tensor:
        arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
        idx = rearrange(seqlen_offsets, "b -> b 1") + arange
        cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
        sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
    else:
        cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
        sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
    k_pt = apply_rotary_emb_torch(
        kv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    out_pt = torch.stack([k_pt, kv_pt[:, :, 1]], dim=2)
    print(f"Output max diff: {(out - out_pt).abs().max().item()}")
    g = torch.randn_like(out)
    g_pt = g.clone()
    # Since inplace=True, we modify the gradient inplace
    out.backward(g)
    out_pt.backward(g_pt)
    print(f"Grad max diff: {(kv.grad - kv_pt.grad).abs().max().item()}")
    # Numerical error if we just do any arithmetic
    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
    atol = ((kv_pt.grad + 0.3 - 0.3) - kv_pt.grad).abs().max().item()
    assert torch.allclose(kv.grad, kv_pt.grad, rtol=rtol, atol=2 * atol)
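The tests above use apply_rotary_emb_torch from flash_attn.layers.rotary as ground truth. As a self-contained illustration of what a non-interleaved (GPT-NeoX-style) rotary reference computes under the same (batch, seqlen, nheads, headdim) convention, here is a sketch; it is a re-derivation for readability, not the library function:

# Illustrative non-interleaved rotary reference (not flash_attn's apply_rotary_emb_torch):
# split the first rotary_dim channels into two halves x1, x2 and rotate each pair
# (x1[i], x2[i]) by the angle whose cos/sin are given for that position.
import torch


def rotary_reference(x, cos, sin):
    """x: (batch, seqlen, nheads, headdim); cos/sin: (seqlen, rotary_dim // 2)."""
    rotary_dim = cos.shape[-1] * 2
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x1, x2 = x_rot.chunk(2, dim=-1)
    # Broadcast cos/sin over batch and heads: (seqlen, d/2) -> (1, seqlen, 1, d/2)
    cos = cos[None, :, None, :]
    sin = sin[None, :, None, :]
    o1 = x1 * cos - x2 * sin
    o2 = x1 * sin + x2 * cos
    # Channels beyond rotary_dim pass through unchanged, matching rotary_fraction < 1.0.
    return torch.cat([o1, o2, x_pass], dim=-1)

The interleaved=True variant exercised by the tests pairs adjacent channels (x[..., 0::2] with x[..., 1::2]) instead of the two halves, which is what the kernel's INTERLEAVED stride handling implements.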