Skip to content
GitLab
Menu
Projects
Groups
Snippets
Loading...
Help
Help
Support
Community forum
Keyboard shortcuts
?
Submit feedback
Contribute to GitLab
Sign in / Register
Toggle navigation
Menu
Open sidebar
gaoqiong
flash-attention
Commits
1c523c1c
Commit
1c523c1c
authored
Sep 03, 2023
by
Tri Dao
Browse files
[Rotary] Speed up rotary kernel when interleaved=True
parent
26d7d92f
Changes
2
Hide whitespace changes
Inline
Side-by-side
Showing
2 changed files
with
77 additions
and
45 deletions
+77
-45
flash_attn/ops/triton/rotary.py
flash_attn/ops/triton/rotary.py
+76
-44
tests/test_rotary.py
tests/test_rotary.py
+1
-1
No files found.
flash_attn/ops/triton/rotary.py
View file @
1c523c1c
...
@@ -13,7 +13,7 @@ import triton.language as tl
...
@@ -13,7 +13,7 @@ import triton.language as tl
# triton.Config({"BLOCK_M": 8}),
# triton.Config({"BLOCK_M": 8}),
# triton.Config({"BLOCK_M": 16}),
# triton.Config({"BLOCK_M": 16}),
# ],
# ],
# key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"]
# key=["CACHE_KEY_SEQLEN", "BLOCK_K", "INTERLEAVED"]
,
# )
# )
@
triton
.
jit
@
triton
.
jit
def
rotary_kernel
(
def
rotary_kernel
(
...
@@ -49,56 +49,88 @@ def rotary_kernel(
...
@@ -49,56 +49,88 @@ def rotary_kernel(
pid_head
=
tl
.
program_id
(
axis
=
2
)
pid_head
=
tl
.
program_id
(
axis
=
2
)
rotary_dim_half
=
rotary_dim
//
2
rotary_dim_half
=
rotary_dim
//
2
X
=
X
+
pid_batch
*
stride_x_batch
+
pid_head
*
stride_x_nheads
OUT
=
OUT
+
pid_batch
*
stride_out_batch
+
pid_head
*
stride_out_nheads
rm
=
pid_m
*
BLOCK_M
+
tl
.
arange
(
0
,
BLOCK_M
)
rm
=
pid_m
*
BLOCK_M
+
tl
.
arange
(
0
,
BLOCK_M
)
rk
=
tl
.
arange
(
0
,
BLOCK_K
//
2
)
if
not
IS_SEQLEN_OFFSETS_TENSOR
:
if
not
IS_SEQLEN_OFFSETS_TENSOR
:
rm_cs
=
rm
+
SEQLEN_OFFSETS
rm_cs
=
rm
+
SEQLEN_OFFSETS
else
:
else
:
rm_cs
=
rm
+
tl
.
load
(
SEQLEN_OFFSETS
+
pid_batch
)
rm_cs
=
rm
+
tl
.
load
(
SEQLEN_OFFSETS
+
pid_batch
)
rk
=
tl
.
arange
(
0
,
BLOCK_K
)
rk_half
=
tl
.
arange
(
0
,
BLOCK_K
//
2
)
X
=
X
+
(
if
not
INTERLEAVED
:
pid_batch
*
stride_x_batch
# Load the 1st and 2nd halves of X, do calculation, then store to 1st and 2nd halves of OUT
+
rm
[:,
None
]
*
stride_x_seqlen
X
=
X
+
(
rm
[:,
None
]
*
stride_x_seqlen
+
rk_half
[
None
,
:]
*
stride_x_headdim
)
+
pid_head
*
stride_x_nheads
COS
=
COS
+
(
rm_cs
[:,
None
]
*
rotary_dim_half
+
rk_half
[
None
,
:])
+
rk
[
None
,
:]
*
stride_x_headdim
*
(
2
if
INTERLEAVED
else
1
)
SIN
=
SIN
+
(
rm_cs
[:,
None
]
*
rotary_dim_half
+
rk_half
[
None
,
:])
)
cos
=
tl
.
load
(
COS
=
COS
+
(
rm_cs
[:,
None
]
*
rotary_dim_half
+
rk
[
None
,
:])
COS
,
mask
=
(
rm_cs
[:,
None
]
<
seqlen_ro
)
&
(
rk_half
[
None
,
:]
<
rotary_dim_half
),
other
=
1.0
SIN
=
SIN
+
(
rm_cs
[:,
None
]
*
rotary_dim_half
+
rk
[
None
,
:])
).
to
(
tl
.
float32
)
sin
=
tl
.
load
(
cos
=
tl
.
load
(
SIN
,
mask
=
(
rm_cs
[:,
None
]
<
seqlen_ro
)
&
(
rk_half
[
None
,
:]
<
rotary_dim_half
),
other
=
0.0
COS
,
mask
=
(
rm_cs
[:,
None
]
<
seqlen_ro
)
&
(
rk
[
None
,
:]
<
rotary_dim_half
),
other
=
1.0
).
to
(
tl
.
float32
)
).
to
(
tl
.
float32
)
x0
=
tl
.
load
(
sin
=
tl
.
load
(
X
,
mask
=
(
rm
[:,
None
]
<
seqlen
)
&
(
rk_half
[
None
,
:]
<
rotary_dim_half
),
other
=
0.0
SIN
,
mask
=
(
rm_cs
[:,
None
]
<
seqlen_ro
)
&
(
rk
[
None
,
:]
<
rotary_dim_half
),
other
=
0.0
).
to
(
tl
.
float32
)
).
to
(
tl
.
float32
)
x1
=
tl
.
load
(
x0
=
tl
.
load
(
X
,
mask
=
(
rm
[:,
None
]
<
seqlen
)
&
(
rk
[
None
,
:]
<
rotary_dim_half
),
other
=
0.0
).
to
(
X
+
rotary_dim_half
*
stride_x_headdim
,
tl
.
float32
mask
=
(
rm
[:,
None
]
<
seqlen
)
&
(
rk_half
[
None
,
:]
<
rotary_dim_half
),
)
other
=
0.0
,
x1
=
tl
.
load
(
).
to
(
tl
.
float32
)
X
+
stride_x_headdim
*
(
1
if
INTERLEAVED
else
rotary_dim_half
),
if
not
CONJUGATE
:
mask
=
(
rm
[:,
None
]
<
seqlen
)
&
(
rk
[
None
,
:]
<
rotary_dim_half
),
o0
=
x0
*
cos
-
x1
*
sin
other
=
0.0
,
o1
=
x0
*
sin
+
x1
*
cos
).
to
(
tl
.
float32
)
else
:
if
not
CONJUGATE
:
o0
=
x0
*
cos
+
x1
*
sin
o0
=
x0
*
cos
-
x1
*
sin
o1
=
-
x0
*
sin
+
x1
*
cos
o1
=
x0
*
sin
+
x1
*
cos
# write back result
OUT
=
OUT
+
(
rm
[:,
None
]
*
stride_out_seqlen
+
rk_half
[
None
,
:]
*
stride_out_headdim
)
tl
.
store
(
OUT
,
o0
,
mask
=
(
rm
[:,
None
]
<
seqlen
)
&
(
rk_half
[
None
,
:]
<
rotary_dim_half
))
tl
.
store
(
OUT
+
rotary_dim_half
*
stride_out_headdim
,
o1
,
mask
=
(
rm
[:,
None
]
<
seqlen
)
&
(
rk_half
[
None
,
:]
<
rotary_dim_half
),
)
else
:
else
:
o0
=
x0
*
cos
+
x1
*
sin
# We don't want to load X[0, 2, 4, ...] and X[1, 3, 5, ...] separately since both are slow.
o1
=
-
x0
*
sin
+
x1
*
cos
# Instead, we load x0 = X[0, 1, 2, 3, ...] and x1 = X[1, 0, 3, 2, ...].
# Loading x0 will be fast but x1 will be slow.
# write back result
# Then we load cos = COS[0, 0, 1, 1, ...] and sin = SIN[0, 0, 1, 1, ...].
OUT
=
OUT
+
(
# Then we do the calculation and use tl.where to pick out the right outputs for the even
pid_batch
*
stride_out_batch
# and for the odd indices.
+
rm
[:,
None
]
*
stride_out_seqlen
rk_swap
=
rk
+
((
rk
+
1
)
%
2
)
*
2
-
1
# 1, 0, 3, 2, 5, 4, ...
+
pid_head
*
stride_out_nheads
rk_repeat
=
tl
.
arange
(
0
,
BLOCK_K
)
//
2
+
rk
[
None
,
:]
*
stride_out_headdim
*
(
2
if
INTERLEAVED
else
1
)
X0
=
X
+
(
rm
[:,
None
]
*
stride_x_seqlen
+
rk
[
None
,
:]
*
stride_x_headdim
)
)
X1
=
X
+
(
rm
[:,
None
]
*
stride_x_seqlen
+
rk_swap
[
None
,
:]
*
stride_x_headdim
)
tl
.
store
(
OUT
,
o0
,
mask
=
(
rm
[:,
None
]
<
seqlen
)
&
(
rk
[
None
,
:]
<
rotary_dim_half
))
COS
=
COS
+
(
rm_cs
[:,
None
]
*
rotary_dim_half
+
rk_repeat
[
None
,
:])
tl
.
store
(
SIN
=
SIN
+
(
rm_cs
[:,
None
]
*
rotary_dim_half
+
rk_repeat
[
None
,
:])
OUT
+
stride_out_headdim
*
(
1
if
INTERLEAVED
else
rotary_dim_half
),
cos
=
tl
.
load
(
o1
,
COS
,
mask
=
(
rm
[:,
None
]
<
seqlen
)
&
(
rk
[
None
,
:]
<
rotary_dim_half
),
mask
=
(
rm_cs
[:,
None
]
<
seqlen_ro
)
&
(
rk_repeat
[
None
,
:]
<
rotary_dim_half
),
)
other
=
1.0
,
).
to
(
tl
.
float32
)
sin
=
tl
.
load
(
SIN
,
mask
=
(
rm_cs
[:,
None
]
<
seqlen_ro
)
&
(
rk_repeat
[
None
,
:]
<
rotary_dim_half
),
other
=
0.0
,
).
to
(
tl
.
float32
)
x0
=
tl
.
load
(
X0
,
mask
=
(
rm
[:,
None
]
<
seqlen
)
&
(
rk
[
None
,
:]
<
rotary_dim
),
other
=
0.0
).
to
(
tl
.
float32
)
x1
=
tl
.
load
(
X1
,
mask
=
(
rm
[:,
None
]
<
seqlen
)
&
(
rk_swap
[
None
,
:]
<
rotary_dim
),
other
=
0.0
).
to
(
tl
.
float32
)
if
not
CONJUGATE
:
o0
=
x0
*
cos
-
x1
*
sin
o1
=
x1
*
sin
+
x0
*
cos
else
:
o0
=
x0
*
cos
+
x1
*
sin
o1
=
-
x1
*
sin
+
x0
*
cos
out
=
tl
.
where
(
rk
[
None
,
:]
%
2
==
0
,
o0
,
o1
)
OUT
=
OUT
+
(
rm
[:,
None
]
*
stride_out_seqlen
+
rk
[
None
,
:]
*
stride_out_headdim
)
tl
.
store
(
OUT
,
out
,
mask
=
(
rm
[:,
None
]
<
seqlen
)
&
(
rk
[
None
,
:]
<
rotary_dim
))
def
apply_rotary
(
def
apply_rotary
(
...
...
tests/test_rotary.py
View file @
1c523c1c
...
@@ -20,7 +20,7 @@ is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0)
...
@@ -20,7 +20,7 @@ is_sm8x = torch.cuda.get_device_capability("cuda") >= (8, 0)
@
pytest
.
mark
.
parametrize
(
"rotary_fraction"
,
[
1.0
,
0.5
])
@
pytest
.
mark
.
parametrize
(
"rotary_fraction"
,
[
1.0
,
0.5
])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
# @pytest.mark.parametrize('rotary_fraction', [1.0])
@
pytest
.
mark
.
parametrize
(
"interleaved"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"interleaved"
,
[
False
,
True
])
# @pytest.mark.parametrize('interleaved', [False])
# @pytest.mark.parametrize('interleaved', [True])
@
pytest
.
mark
.
parametrize
(
"inplace"
,
[
False
,
True
])
@
pytest
.
mark
.
parametrize
(
"inplace"
,
[
False
,
True
])
# @pytest.mark.parametrize('inplace', [False])
# @pytest.mark.parametrize('inplace', [False])
def
test_rotary_emb_func
(
inplace
,
interleaved
,
rotary_fraction
,
seqlen_offsets_type
,
dtype
):
def
test_rotary_emb_func
(
inplace
,
interleaved
,
rotary_fraction
,
seqlen_offsets_type
,
dtype
):
...
...
Write
Preview
Markdown
is supported
0%
Try again
or
attach a new file
.
Attach a file
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Cancel
Please
register
or
sign in
to comment