gaoqiong / flash-attention

Commit 942fcbf0: [Rotary] Implement rotary in Triton
Authored Sep 03, 2023 by Tri Dao
Parent: 08e98471
Showing 6 changed files with 601 additions and 231 deletions:

  flash_attn/layers/rotary.py                     +237  -211
  flash_attn/models/gpt_neox.py                     +2    -5
  flash_attn/ops/triton/linear.py                   +1    -3
  flash_attn/ops/triton/rotary.py (new file)      +182    -0
  tests/models/test_gpt_generation_parallel.py      +2    -2
  tests/test_rotary.py                            +177   -10
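For orientation: rotary position embedding rotates each pair of channels in a query/key head by a position-dependent angle, turning (x0, x1) into (x0*cos - x1*sin, x0*sin + x1*cos), which is what the new Triton kernel in this commit computes. Below is a minimal PyTorch sketch of the non-interleaved (GPT-NeoX style) variant, for illustration only; the repository's own reference is apply_rotary_emb_torch in flash_attn/layers/rotary.py.

import torch

def rotary_reference(x, cos, sin):
    # x: (batch, seqlen, nheads, headdim); cos, sin: (seqlen, rotary_dim // 2).
    # Non-interleaved layout: the first half of the rotary dims pairs with the second half.
    rotary_dim = cos.shape[-1] * 2
    x_rot, x_pass = x[..., :rotary_dim], x[..., rotary_dim:]
    x0, x1 = x_rot.chunk(2, dim=-1)
    cos = cos[:, None, :]  # broadcast over the heads dimension
    sin = sin[:, None, :]
    o0 = x0 * cos - x1 * sin
    o1 = x0 * sin + x1 * cos
    return torch.cat([o0, o1, x_pass], dim=-1)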
flash_attn/layers/rotary.py (view file @ 942fcbf0)

This diff (+237, -211) is collapsed in the page capture and is not reproduced here.
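Because that diff is collapsed, only the updated tests below show how its interface is used in this commit. The following import sketch summarizes the entry points the tests exercise; the per-function comments are inferred from test usage, not from the collapsed diff itself.

# Entry points exercised by tests/test_rotary.py in this commit; descriptions are
# inferred from how the tests call them, so treat them as approximate.
from flash_attn.layers.rotary import (
    apply_rotary_emb,        # apply_rotary_emb(x, cos, sin, seqlen_offsets=0, interleaved=False, inplace=False)
    apply_rotary_emb_qkv_,   # rotates Q and K of a packed (batch, seqlen, 3, nheads, headdim) tensor; tests suggest it works in place
    apply_rotary_emb_kv_,    # rotates K of a packed (batch, seqlen, 2, nheads, headdim) tensor; tests suggest it works in place
    apply_rotary_emb_torch,  # pure-PyTorch reference the tests compare against
)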
flash_attn/models/gpt_neox.py (view file @ 942fcbf0)

@@ -68,6 +68,8 @@ def remap_state_dict_hf_gpt_neox(state_dict, config):
         # We don't store these biases
         state_dict.pop(f"transformer.layers.{l}.attention.bias")
         state_dict.pop(f"transformer.layers.{l}.attention.masked_bias")
+        # We don't store these
+        state_dict.pop(f"transformer.layers.{l}.attention.rotary_emb.inv_freq", None)
     # GPT-NeoX stores Wqkv as ((nheads 3 headdim), hidden_dim)
     # while we store Wqkv as ((3 nheads headdim), hidden_dim)
     headdim = config.hidden_size // config.num_attention_heads
@@ -89,11 +91,6 @@ def remap_state_dict_hf_gpt_neox(state_dict, config):
             r"transformer.layers.\1.mixer.out_proj.",
             key,
         )
-        key = re.sub(
-            r"^transformer.layers.(\d+).attention.rotary_emb.",
-            r"transformer.layers.\1.mixer.rotary_emb.",
-            key,
-        )
         return key

     state_dict = OrderedDict((key_mapping_attn(k), v) for k, v in state_dict.items())
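A note on the first hunk: the pop call uses a default (None), presumably because rotary_emb.inv_freq may or may not be present in a given Hugging Face GPT-NeoX checkpoint, and popping without a default would raise KeyError when it is absent. A tiny illustration with a hypothetical state dict (not from the commit):

# Hypothetical layer-0 state dict; whether inv_freq appears depends on the checkpoint.
state_dict = {"transformer.layers.0.attention.bias": 0.0}

# Raises KeyError, since the key is missing:
# state_dict.pop("transformer.layers.0.attention.rotary_emb.inv_freq")

# With a default, this removes the entry when present and is a no-op otherwise:
state_dict.pop("transformer.layers.0.attention.rotary_emb.inv_freq", None)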
flash_attn/ops/triton/linear.py (view file @ 942fcbf0)

-# Adapted on https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py
+# Adapted from https://github.com/ELS-RD/kernl/blob/main/src/kernl/implementations/linear_layer.py
 # and https://github.com/openai/triton/blob/master/python/triton/ops/matmul.py

 from typing import Optional

 import torch
 import triton
 import triton.language as tl
-from torch.autograd.function import FunctionCtx
-from torch.cuda.amp import custom_fwd
 from triton.ops.matmul_perf_model import early_config_prune, estimate_matmul_time

 from flash_attn.ops.triton.k_activations import (
flash_attn/ops/triton/rotary.py (new file, 0 → 100644, view file @ 942fcbf0)

from typing import Union

import torch
import triton
import triton.language as tl


# @triton.autotune(
#     configs=[
#         triton.Config({"BLOCK_M": 2}),
#         triton.Config({"BLOCK_M": 4}),
#         triton.Config({"BLOCK_M": 8}),
#         triton.Config({"BLOCK_M": 16}),
#     ],
#     key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"]
# )
@triton.jit
def rotary_kernel(
    OUT,  # Pointers to matrices
    X,
    COS,
    SIN,
    SEQLEN_OFFSETS,  # this could be int or a pointer
    # Matrix dimensions
    seqlen,
    nheads,
    rotary_dim,
    seqlen_ro,
    CACHE_KEY_SEQLEN,
    # strides
    stride_out_batch,
    stride_out_seqlen,
    stride_out_nheads,
    stride_out_headdim,
    stride_x_batch,
    stride_x_seqlen,
    stride_x_nheads,
    stride_x_headdim,
    # Meta-parameters
    BLOCK_K: tl.constexpr,
    IS_SEQLEN_OFFSETS_TENSOR: tl.constexpr,
    INTERLEAVED: tl.constexpr,
    CONJUGATE: tl.constexpr,
    BLOCK_M: tl.constexpr,
):
    pid_m = tl.program_id(axis=0)
    pid_batch = tl.program_id(axis=1)
    pid_head = tl.program_id(axis=2)
    rotary_dim_half = rotary_dim // 2

    rm = pid_m * BLOCK_M + tl.arange(0, BLOCK_M)
    rk = tl.arange(0, BLOCK_K // 2)
    if not IS_SEQLEN_OFFSETS_TENSOR:
        rm_cs = rm + SEQLEN_OFFSETS
    else:
        rm_cs = rm + tl.load(SEQLEN_OFFSETS + pid_batch)

    X = X + (
        pid_batch * stride_x_batch
        + rm[:, None] * stride_x_seqlen
        + pid_head * stride_x_nheads
        + rk[None, :] * stride_x_headdim * (2 if INTERLEAVED else 1)
    )
    COS = COS + (rm_cs[:, None] * rotary_dim_half + rk[None, :])
    SIN = SIN + (rm_cs[:, None] * rotary_dim_half + rk[None, :])
    cos = tl.load(
        COS, mask=(rm_cs[:, None] < seqlen_ro) & (rk[None, :] < rotary_dim_half), other=1.0
    ).to(tl.float32)
    sin = tl.load(
        SIN, mask=(rm_cs[:, None] < seqlen_ro) & (rk[None, :] < rotary_dim_half), other=0.0
    ).to(tl.float32)
    x0 = tl.load(
        X, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half), other=0.0
    ).to(tl.float32)
    x1 = tl.load(
        X + stride_x_headdim * (1 if INTERLEAVED else rotary_dim_half),
        mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half),
        other=0.0,
    ).to(tl.float32)
    if not CONJUGATE:
        o0 = x0 * cos - x1 * sin
        o1 = x0 * sin + x1 * cos
    else:
        o0 = x0 * cos + x1 * sin
        o1 = -x0 * sin + x1 * cos
    # write back result
    OUT = OUT + (
        pid_batch * stride_out_batch
        + rm[:, None] * stride_out_seqlen
        + pid_head * stride_out_nheads
        + rk[None, :] * stride_out_headdim * (2 if INTERLEAVED else 1)
    )
    tl.store(OUT, o0, mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half))
    tl.store(
        OUT + stride_out_headdim * (1 if INTERLEAVED else rotary_dim_half),
        o1,
        mask=(rm[:, None] < seqlen) & (rk[None, :] < rotary_dim_half),
    )


def apply_rotary(
    x: torch.Tensor,
    cos: torch.Tensor,
    sin: torch.Tensor,
    seqlen_offsets: Union[int, torch.Tensor] = 0,
    interleaved=False,
    inplace=False,
    conjugate=False,
) -> torch.Tensor:
    """
    Arguments:
        x: (batch, seqlen, nheads, headdim)
        cos: (seqlen_ro, rotary_dim / 2)
        sin: (seqlen_ro, rotary_dim / 2)
        seqlen_offsets: integer or integer tensor of size (batch,)
    Returns:
        y: (batch, seqlen, nheads, headdim)
    """
    batch, seqlen, nheads, headdim = x.shape
    seqlen_ro, rotary_dim = cos.shape
    assert sin.shape == cos.shape
    rotary_dim *= 2
    assert rotary_dim <= headdim, "rotary_dim must be <= headdim"
    assert headdim <= 256, "Only support headdim <= 256"
    assert seqlen_ro >= seqlen, "seqlen_ro must be >= seqlen"
    assert (
        cos.dtype == sin.dtype
    ), f"cos and sin must have the same dtype, got {cos.dtype} and {sin.dtype}"
    assert (
        x.dtype == cos.dtype
    ), f"Input and cos/sin must have the same dtype, got {x.dtype} and {cos.dtype}"

    cos, sin = cos.contiguous(), sin.contiguous()
    if isinstance(seqlen_offsets, torch.Tensor):
        assert seqlen_offsets.shape == (batch,)
        assert seqlen_offsets.dtype in [torch.int32, torch.int64]
        seqlen_offsets = seqlen_offsets.contiguous()
    else:
        assert seqlen_offsets + seqlen <= seqlen_ro

    output = torch.empty_like(x) if not inplace else x
    if rotary_dim < headdim and not inplace:
        output[..., rotary_dim:].copy_(x[..., rotary_dim:])

    BLOCK_K = (
        32
        if rotary_dim <= 32
        else (64 if rotary_dim <= 64 else (128 if rotary_dim <= 128 else 256))
    )
    grid = lambda META: (triton.cdiv(seqlen, META["BLOCK_M"]), batch, nheads)  # noqa
    BLOCK_M = 4 if interleaved else (8 if rotary_dim <= 64 else 4)
    rotary_kernel[grid](
        output,  # data ptrs
        x,
        cos,
        sin,
        seqlen_offsets,
        seqlen,  # shapes
        nheads,
        rotary_dim,
        seqlen_ro,
        seqlen // 128,  # key for triton cache (limit number of compilations)
        output.stride(0),  # strides
        output.stride(1),
        output.stride(2),
        output.stride(3),
        x.stride(0),
        x.stride(1),
        x.stride(2),
        x.stride(3),
        BLOCK_K,
        isinstance(seqlen_offsets, torch.Tensor),
        interleaved,
        conjugate,
        BLOCK_M,
    )
    return output
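A minimal usage sketch for the new apply_rotary wrapper, assuming a CUDA device with Triton available and the shapes documented in the docstring above (this example is not part of the commit):

import math

import torch

from flash_attn.ops.triton.rotary import apply_rotary

batch, seqlen, nheads, headdim = 2, 128, 4, 64
x = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
# cos/sin tables over seqlen_ro >= seqlen positions, one entry per rotary-dim pair
angle = torch.rand(seqlen, headdim // 2, device="cuda") * 2 * math.pi
cos, sin = torch.cos(angle).half(), torch.sin(angle).half()

y = apply_rotary(x, cos, sin, seqlen_offsets=0)       # out of place
y_inplace = apply_rotary(x, cos, sin, inplace=True)   # overwrites x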
tests/models/test_gpt_generation_parallel.py (view file @ 942fcbf0)

@@ -131,6 +131,8 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
         )
         print(out_cg.sequences)

+    parallel_state.destroy_model_parallel()
+
     if not rotary:
         out_hf = model_hf.generate(
             input_ids=input_ids,
@@ -171,5 +173,3 @@ def test_tensor_parallel(model_name, rotary, fused_ft_kernel, world_size):
         ).abs().max().item() < 3 * (
             torch.stack(out_hf.scores, 1) - torch.stack(out_ref.scores, 1)
         ).abs().max().item()
-
-    parallel_state.destroy_model_parallel()
tests/test_rotary.py (view file @ 942fcbf0)

 import math
+import random

 import pytest
 import torch
 import torch.nn.functional as F
 from einops import rearrange

-from flash_attn.layers.rotary import apply_rotary_emb_func, apply_rotary_emb_torch
+from flash_attn.layers.rotary import apply_rotary_emb, apply_rotary_emb_torch
+from flash_attn.layers.rotary import apply_rotary_emb_qkv_, apply_rotary_emb_kv_

 is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0)
@@ -13,33 +15,198 @@ is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0)
     "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
 )
 # @pytest.mark.parametrize('dtype', ([torch.float16]))
+@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
+# @pytest.mark.parametrize("seqlen_offsets_type", [0])
 @pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
-# @pytest.mark.parametrize('rotary_fraction', [0.5])
+# @pytest.mark.parametrize('rotary_fraction', [1.0])
+@pytest.mark.parametrize("interleaved", [False, True])
+# @pytest.mark.parametrize('interleaved', [False])
 @pytest.mark.parametrize("inplace", [False, True])
 # @pytest.mark.parametrize('inplace', [False])
-def test_rotary_single_tensor(inplace, rotary_fraction, dtype):
+def test_rotary_emb_func(inplace, interleaved, rotary_fraction, seqlen_offsets_type, dtype):
     rtol = 1e-3
     batch_size = 32
     nheads = 4
     seqlen = 217
     headdim = 128
+    device = "cuda"
+    torch.manual_seed(42)
     x = torch.randn(
-        batch_size, seqlen, nheads, headdim, dtype=dtype, device="cuda", requires_grad=True
+        batch_size, seqlen, nheads, headdim, dtype=dtype, device=device, requires_grad=True
     )
     x_pt = x.detach().clone().requires_grad_()
     rotary_dim = int(rotary_fraction * headdim)
     assert rotary_dim % 2 == 0
-    angle = torch.randn(seqlen, rotary_dim // 2, device="cuda")
+    angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
     cos = torch.cos(angle).to(dtype=dtype)
     sin = torch.sin(angle).to(dtype=dtype)
-    out = apply_rotary_emb_func(x, cos, sin, inplace)
-    out_pt = apply_rotary_emb_torch(x_pt, cos, sin)
-    # Numerical error if we just do any arithmetic
-    atol = ((out + 0.3 - 0.3) - out).abs().max().item()
-    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
+    if seqlen_offsets_type == 0:
+        seqlen_offsets = 0
+    elif seqlen_offsets_type is int:
+        seqlen_offsets = torch.randint(0, seqlen + 1, (1,)).item()
+    elif seqlen_offsets_type is torch.Tensor:
+        seqlen_offsets = torch.randint(
+            0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device
+        )
+    out = apply_rotary_emb(
+        x, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved, inplace=inplace
+    )
+    if seqlen_offsets_type is torch.Tensor:
+        arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
+        idx = rearrange(seqlen_offsets, "b -> b 1") + arange
+        cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
+        sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
+    else:
+        cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
+        sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
+    out_pt = apply_rotary_emb_torch(
+        x_pt.float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
+    ).to(dtype=dtype)
+    print(f"Output max diff: {(out - out_pt).abs().max().item()}")
     g = torch.randn_like(out)
     g_pt = g.clone()  # If inplace=True, we might modify the gradient inplace
     out.backward(g)
     out_pt.backward(g_pt)
+    print(f"Grad max diff: {(x.grad - x_pt.grad).abs().max().item()}")
+    if not inplace:
+        assert torch.equal(x, x_pt)
+    # Numerical error if we just do any arithmetic
+    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
+    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
     atol = ((x_pt.grad + 0.3 - 0.3) - x_pt.grad).abs().max().item()
     assert torch.allclose(x.grad, x_pt.grad, rtol=rtol, atol=2 * atol)
The rest of the hunk adds two new tests, test_rotary_emb_qkv and test_rotary_emb_kv:

@pytest.mark.parametrize(
    "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False])
def test_rotary_emb_qkv(interleaved, rotary_fraction, seqlen_offsets_type, dtype):
    rtol = 1e-3
    batch_size = 32
    nheads = 4
    seqlen = 512
    headdim = 128
    device = "cuda"
    torch.manual_seed(42)
    qkv = torch.randn(
        batch_size, seqlen, 3, nheads, headdim, dtype=dtype, device=device, requires_grad=True
    )
    qkv_pt = qkv.detach().clone().requires_grad_()
    rotary_dim = int(rotary_fraction * headdim)
    assert rotary_dim % 2 == 0
    angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
    cos = torch.cos(angle).to(dtype=dtype)
    sin = torch.sin(angle).to(dtype=dtype)
    if seqlen_offsets_type == 0:
        seqlen_offsets = 0
    elif seqlen_offsets_type is int:
        seqlen_offsets = torch.randint(0, seqlen + 1, (1,)).item()
    elif seqlen_offsets_type is torch.Tensor:
        seqlen_offsets = torch.randint(
            0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device
        )
    out = apply_rotary_emb_qkv_(
        qkv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved
    )
    if seqlen_offsets_type is torch.Tensor:
        arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
        idx = rearrange(seqlen_offsets, "b -> b 1") + arange
        cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
        sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
    else:
        cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
        sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
    q_pt = apply_rotary_emb_torch(
        qkv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    k_pt = apply_rotary_emb_torch(
        qkv_pt[:, :, 1].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    out_pt = torch.stack([q_pt, k_pt, qkv_pt[:, :, 2]], dim=2)
    print(f"Output max diff: {(out - out_pt).abs().max().item()}")
    g = torch.randn_like(out)
    g_pt = g.clone()  # Since inplace=True, we modify the gradient inplace
    out.backward(g)
    out_pt.backward(g_pt)
    print(f"Grad max diff: {(qkv.grad - qkv_pt.grad).abs().max().item()}")
    # Numerical error if we just do any arithmetic
    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
    atol = ((qkv_pt.grad + 0.3 - 0.3) - qkv_pt.grad).abs().max().item()
    assert torch.allclose(qkv.grad, qkv_pt.grad, rtol=rtol, atol=2 * atol)
@pytest.mark.parametrize(
    "dtype", ([torch.float16] if not is_sm8x else [torch.float16, torch.bfloat16])
)
# @pytest.mark.parametrize('dtype', ([torch.float16]))
@pytest.mark.parametrize("seqlen_offsets_type", [0, int, torch.Tensor])
# @pytest.mark.parametrize("seqlen_offsets_type", [0])
@pytest.mark.parametrize("rotary_fraction", [1.0, 0.5])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@pytest.mark.parametrize("interleaved", [False, True])
# @pytest.mark.parametrize('interleaved', [False])
def test_rotary_emb_kv(interleaved, rotary_fraction, seqlen_offsets_type, dtype):
    rtol = 1e-3
    batch_size = 32
    nheads = 4
    seqlen = 781
    headdim = 64
    device = "cuda"
    torch.manual_seed(42)
    kv = torch.randn(
        batch_size, seqlen, 2, nheads, headdim, dtype=dtype, device=device, requires_grad=True
    )
    kv_pt = kv.detach().clone().requires_grad_()
    rotary_dim = int(rotary_fraction * headdim)
    assert rotary_dim % 2 == 0
    angle = torch.rand(seqlen * 2, rotary_dim // 2, device=device) * 2 * math.pi
    cos = torch.cos(angle).to(dtype=dtype)
    sin = torch.sin(angle).to(dtype=dtype)
    if seqlen_offsets_type == 0:
        seqlen_offsets = 0
    elif seqlen_offsets_type is int:
        seqlen_offsets = torch.randint(0, seqlen + 1, (1,)).item()
    elif seqlen_offsets_type is torch.Tensor:
        seqlen_offsets = torch.randint(
            0, seqlen + 1, (batch_size,), dtype=torch.int32, device=device
        )
    out = apply_rotary_emb_kv_(
        kv, cos, sin, seqlen_offsets=seqlen_offsets, interleaved=interleaved
    )
    if seqlen_offsets_type is torch.Tensor:
        arange = rearrange(torch.arange(seqlen, device=device), "s -> 1 s")
        idx = rearrange(seqlen_offsets, "b -> b 1") + arange
        cos_pt = rearrange(cos[idx.flatten()], "(b s) d -> b s d", b=batch_size)
        sin_pt = rearrange(sin[idx.flatten()], "(b s) d -> b s d", b=batch_size)
    else:
        cos_pt = cos[seqlen_offsets : seqlen_offsets + seqlen]
        sin_pt = sin[seqlen_offsets : seqlen_offsets + seqlen]
    k_pt = apply_rotary_emb_torch(
        kv_pt[:, :, 0].float(), cos_pt.float(), sin_pt.float(), interleaved=interleaved
    ).to(dtype=dtype)
    out_pt = torch.stack([k_pt, kv_pt[:, :, 1]], dim=2)
    print(f"Output max diff: {(out - out_pt).abs().max().item()}")
    g = torch.randn_like(out)
    g_pt = g.clone()  # Since inplace=True, we modify the gradient inplace
    out.backward(g)
    out_pt.backward(g_pt)
    print(f"Grad max diff: {(kv.grad - kv_pt.grad).abs().max().item()}")
    # Numerical error if we just do any arithmetic
    atol = ((out_pt + 0.3 - 0.3) - out_pt).abs().max().item()
    assert torch.allclose(out, out_pt, rtol=rtol, atol=2 * atol)
    atol = ((kv_pt.grad + 0.3 - 0.3) - kv_pt.grad).abs().max().item()
    assert torch.allclose(kv.grad, kv_pt.grad, rtol=rtol, atol=2 * atol)
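The seqlen_offsets cases in these tests check that applying rotary with an integer offset k matches applying the PyTorch reference to cos/sin sliced at [k : k + seqlen]. A condensed sketch of that invariant with made-up sizes and a loose fp16 tolerance of my own choosing (an illustration, not code from the commit):

import math

import torch

from flash_attn.layers.rotary import apply_rotary_emb, apply_rotary_emb_torch

batch, seqlen, nheads, headdim = 2, 64, 4, 64
x = torch.randn(batch, seqlen, nheads, headdim, dtype=torch.float16, device="cuda")
angle = torch.rand(seqlen * 2, headdim // 2, device="cuda") * 2 * math.pi
cos, sin = torch.cos(angle).half(), torch.sin(angle).half()

offset = 5
out = apply_rotary_emb(x, cos, sin, seqlen_offsets=offset, interleaved=False)
out_ref = apply_rotary_emb_torch(
    x.float(), cos[offset : offset + seqlen].float(), sin[offset : offset + seqlen].float(),
    interleaved=False,
).to(torch.float16)
assert (out - out_ref).abs().max().item() < 1e-2  # loose tolerance for fp16 rounding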