gaoqiong / flash-attention

Commit 166f33fd, authored Feb 26, 2024 by skrider, committed by Woosuk Kwon on Mar 28, 2024.
Parent: 53c6eb1f

    reshape rotary sin/cos copy to align with paged KV copy

Showing 4 changed files with 46 additions and 12 deletions:

    csrc/flash_attn/src/flash_fwd_kernel.h  +22 -9
    csrc/flash_attn/src/kernel_traits.h     +11 -1
    csrc/flash_attn/src/utils.h             +12 -1
    tests/test_flash_attn.py                 +1 -1
csrc/flash_attn/src/flash_fwd_kernel.h

@@ -652,7 +652,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     // Repeat the partitioning with identity layouts
     Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);       // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
-    Tensor tKVcKV = gmem_thr_copy_KV.partition_S(cKV);   // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
+    Tensor tKVcKV_ = gmem_thr_copy_KV.partition_S(cKV);  // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
+    Tensor tKVcKV = make_tensor(tKVcKV_.data(), reshape_thread_tile(tKVcKV_.layout()));

     // Allocate predicate tensors for k
     Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
@@ -669,11 +670,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     // Prologue
     // Copy from Knew to K, optionally apply rotary embedding.
-    typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary;
-    auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
-    typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont;
-    auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
     if constexpr (Append_KV) {
+        typename Kernel_traits::GmemTiledCopyRotcossinPaged gmem_tiled_copy_rotary;
+        auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
+        typename Kernel_traits::GmemTiledCopyRotcossinContPaged gmem_tiled_copy_rotary_cont;
+        auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
         // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
         // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
         // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
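Note: the moved declarations are the point of this hunk. The paged and non-paged tiled copies are distinct types, so each `if constexpr` branch must declare its own copy object; the compiler never instantiates the branch that is not taken. A minimal host-compilable sketch of the pattern, with hypothetical stand-in types (not from this repo):

```cpp
// Sketch: compile-time selection of a tiled-copy type via if constexpr (C++17).
#include <cstdio>

struct TiledCopyCont  { static constexpr int vals_per_load = 8; };  // stand-in type
struct TiledCopyPaged { static constexpr int vals_per_load = 4; };  // stand-in type

template <bool Append_KV>
void copy_rotary() {
    if constexpr (Append_KV) {
        TiledCopyPaged tiled_copy;   // only this branch is compiled when Append_KV is true
        printf("paged: %d vals per load\n", tiled_copy.vals_per_load);
    } else {
        TiledCopyCont tiled_copy;    // and only this one when it is false
        printf("contiguous: %d vals per load\n", tiled_copy.vals_per_load);
    }
}

int main() {
    copy_rotary<true>();
    copy_rotary<false>();
    return 0;
}
```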
@@ -690,10 +692,17 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
                                       Shape<Int<kBlockN>, Int<kHeadDim>>{},
                                       make_stride(params.rotary_dim / 2, _1{}));
-        Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
-        Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
-        Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
-        Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
+        Tensor tRgCos_ = gmem_thr_copy_rotary.partition_S(gCos);
+        Tensor tRgSin_ = gmem_thr_copy_rotary.partition_S(gSin);
+        Tensor tRgCosCont_ = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
+        Tensor tRgSinCont_ = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
+        Tensor tRgCos = make_tensor(tRgCos_.data(), reshape_thread_tile(tRgCos_.layout()));
+        Tensor tRgSin = make_tensor(tRgSin_.data(), reshape_thread_tile(tRgSin_.layout()));
+        Tensor tRgCosCont = make_tensor(tRgCosCont_.data(), reshape_flatten_thread_tile(tRgCosCont_.layout()));
+        Tensor tRgSinCont = make_tensor(tRgSinCont_.data(), reshape_flatten_thread_tile(tRgSinCont_.layout()));
         // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); }
         // if (cute::thread(8, 0)) { print_tensor(gCos); }
         // if (cute::thread(0, 0)) { print_tensor(tRgCos); }
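Note: the `make_tensor(t_.data(), reshape_thread_tile(t_.layout()))` pattern above never moves data. A CuTe Tensor is an (iterator, layout) pair, so rebinding the same iterator to a reshaped layout only changes how the same registers or gmem are indexed. A minimal host-side sketch (illustrative shapes only, assuming the CuTe headers are on the include path):

```cpp
// Sketch: re-viewing the same storage through a reshaped CuTe layout.
#include <cute/tensor.hpp>
#include <cstdio>

int main() {
    using namespace cute;
    float buf[8] = {0, 1, 2, 3, 4, 5, 6, 7};

    // Layout shaped like a partition_S result: ((v1, v2), m, k) = ((2, 2), 1, 2)
    auto t_ = make_tensor(&buf[0],
                          make_layout(make_shape(make_shape(_2{}, _2{}), _1{}, _2{})));
    // Rebind the same data to the reshaped (v1, v2, k) = (2, 2, 2) view,
    // as the kernel does for tRgCos / tRgSin
    auto t = make_tensor(t_.data(), make_layout(make_shape(_2{}, _2{}, _2{})));

    // Both index the same element; only the coordinate structure changed
    printf("%g %g\n", t(0, 0, 1), t_(make_coord(0, 0), 0, 1));
    return 0;
}
```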
@@ -779,6 +788,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
             flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ,
                                                binfo.actual_seqlen_q - m_block * kBlockM);
         } else {
+            typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary;
+            auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
+            typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont;
+            auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
             const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
             // If not causal, all the queries get the same cos/sin, taken at location seqlen_k_cache.
             // We do this by setting the row stride of gCos / gSin to 0.
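Note: each row of the rotary cos/sin tables holds `rotary_dim / 2` entries, so `row_offset_cossin` advances the base pointer to the row for the first position a block touches; causal/local blocks additionally skip `m_block * kBlockM` rows. A small host-side check with hypothetical values:

```cpp
// Sketch: the row_offset_cossin arithmetic from the hunk above, on made-up inputs.
#include <cstdio>

int main() {
    const int seqlen_k_cache = 100;  // tokens already in the KV cache (assumed)
    const int m_block = 2, kBlockM = 64, rotary_dim = 128;  // assumed tile config
    const bool is_causal = true;

    const long row_offset_cossin =
        (seqlen_k_cache + (is_causal ? m_block * kBlockM : 0)) * (rotary_dim / 2);
    printf("row_offset_cossin = %ld\n", row_offset_cossin);  // (100 + 128) * 64 = 14592
    return 0;
}
```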
csrc/flash_attn/src/kernel_traits.h

@@ -158,7 +158,9 @@ struct Flash_fwd_kernel_traits : public Base {
         make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                         GmemLayoutAtomOaccum{},
                         Layout<Shape<_1, _4>>{}));  // Val layout, 4 vals per store

-    using GmemLayoutAtomRotcossin = GmemLayoutAtom;
+    // using GmemLayoutAtomRotcossin = GmemLayoutAtom;
+    using GmemLayoutAtomRotcossin = Layout<Shape<Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
+                                           Stride<Int<kGmemThreadsPerRow>, _1>>;
     using GmemTiledCopyRotcossin = decltype(make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
                                                             GmemLayoutAtomRotcossin{},
@@ -167,6 +169,14 @@ struct Flash_fwd_kernel_traits : public Base {
         make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
                         GmemLayoutAtomRotcossin{},
                         Layout<Shape<_1, _8>>{}));  // Val layout, 8 vals per load
+    using GmemTiledCopyRotcossinPaged = decltype(
+        make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
+                        GmemLayoutAtomRotcossin{},
+                        Layout<Shape<Int<kGmemRowsPerThread>, _4>, Stride<_4, _1>>{}));  // Val layout, 4 vals per load
+    using GmemTiledCopyRotcossinContPaged = decltype(
+        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+                        GmemLayoutAtomRotcossin{},
+                        Layout<Shape<Int<kGmemRowsPerThread>, _8>, Stride<_8, _1>>{}));  // Val layout, 8 vals per load
 };

 // Is_V_in_regs is an option to reduce smem usage, but will increase register pressure.
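Note: the difference between the Cont and Paged variants is only the per-thread value layout. `Layout<Shape<_1, _8>>` gives each thread 8 contiguous values in a single row, while the Paged layouts give each thread `kGmemRowsPerThread` rows of 4 or 8 values, matching the row-tiled shape of the paged KV copy. A host-side sketch of the two layouts, assuming `kGmemRowsPerThread = 4` for illustration:

```cpp
// Sketch: the value layouts behind the Cont vs. Paged rotary copies (CuTe headers assumed).
#include <cute/tensor.hpp>
#include <cstdio>

int main() {
    using namespace cute;
    auto val_cont  = Layout<Shape<_1, _8>>{};                       // one row, 8 vals/thread
    auto val_paged = Layout<Shape<Int<4>, _8>, Stride<_8, _1>>{};   // 4 rows x 8 vals/thread
    print(val_cont);  printf("\n");
    print(val_paged); printf("\n");
    return 0;
}
```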
csrc/flash_attn/src/utils.h

@@ -344,7 +344,7 @@ int advance_thread_kv_page_slice_offset(const int tidx, const int n_block, const

 ////////////////////////////////////////////////////////////////////////////////////////////////////

-// somewhat unorthodox reshape function. Given a tuple ((v1, v2), m, k), returns (v1, v2, k),
+// Layout reshape function. Given a layout with modes ((v1, v2), m, k), returns (v1, v2, k),
 // where v2 may be a tuple itself, in the case of swizzled smem-backed thread tiles. This ensures
 // that paged and non-paged copies result in equivalently shaped, if not necessarily strided, tensors.
 template <class Shape, class Stride>
@@ -354,6 +354,17 @@ auto reshape_thread_tile(Layout<Shape, Stride> l) {
                        append(get<0>(l.stride()), get<2>(l.stride())));
 }

+// reshapes and flattens the thread tile layout. A separate function is needed for the case where
+// one of the modes of l is a layout itself and must be flattened, as opposed to keeping it intact
+// for the case of swizzled layouts
+template <class Shape, class Stride>
+__forceinline__ __device__ auto reshape_flatten_thread_tile(Layout<Shape, Stride> l) {
+    auto mode_0 = filter(flatten(get<0>(l)));
+    return make_layout(append(mode_0.shape(), get<2>(l.shape())),
+                       append(mode_0.stride(), get<2>(l.stride())));
+}
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////

 template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
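Note: to see what `reshape_thread_tile` actually does to a layout, the helper can be exercised on the host with the `__device__` qualifier dropped. A sketch with an illustrative shape (not taken from the kernel):

```cpp
// Sketch: reshape_thread_tile on the CPU -- ((v1, v2), m, k) becomes (v1, v2, k).
#include <cute/tensor.hpp>
#include <cstdio>

template <class Shape, class Stride>
auto reshape_thread_tile_host(cute::Layout<Shape, Stride> l) {
    using namespace cute;
    // Keep the nested mode-0 structure, drop the singleton m mode, append k
    return make_layout(append(get<0>(l.shape()),  get<2>(l.shape())),
                       append(get<0>(l.stride()), get<2>(l.stride())));
}

int main() {
    using namespace cute;
    // A layout shaped like a rotary thread tile: ((8, 2), 1, 4), compact column-major
    auto l = make_layout(make_shape(make_shape(Int<8>{}, Int<2>{}), Int<1>{}, Int<4>{}));
    auto r = reshape_thread_tile_host(l);
    print(l); printf("\n");  // ((_8,_2),_1,_4):((_1,_8),_16,_16)
    print(r); printf("\n");  // (_8,_2,_4):(_1,_8,_16)
    return 0;
}
```

The flatten variant differs only in that it first flattens and filters mode 0, which matters when that mode is itself a (possibly swizzled) layout rather than a plain tuple.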
tests/test_flash_attn.py

@@ -1833,7 +1833,7 @@ def test_flash_attn_splitkv(
 @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
 # @pytest.mark.parametrize("rotary_fraction", [0.0])
 # @pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
-@pytest.mark.parametrize("paged_kv_block_size", [16, 256, 512])
+@pytest.mark.parametrize("paged_kv_block_size", [None, 16, 256, 512])
 @pytest.mark.parametrize("has_batch_idx", [False, True])
 # @pytest.mark.parametrize("has_batch_idx", [False])
 @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])