"tests/pipelines/vscode:/vscode.git/clone" did not exist on "24895a1f494062d73028e31880c8848c6a674750"
Commit 166f33fd authored by skrider, committed by Woosuk Kwon

reshape rotary sin/cos copy to align with paged KV copy

parent 53c6eb1f
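The change pivots on a small CuTe layout reshape: a thread tile that a tiled copy's partition_S returns with modes ((v1, v2), m, k) is rebuilt as (v1, v2, k), so that the rotary cos/sin copies produce tensors shaped like the paged KV copy (see reshape_thread_tile and the new reshape_flatten_thread_tile in the diff below). The following host-side sketch illustrates the pattern on an invented layout; it assumes only CUTLASS's CuTe headers, and the shapes and strides are illustrative, not taken from the kernel.

// Illustrative sketch of the ((v1, v2), m, k) -> (v1, v2, k) reshape this commit relies on.
// The shape/stride values below are invented for demonstration; only the structure matters.
#include <cute/tensor.hpp>

int main() {
    using namespace cute;
    // A thread tile as a tiled copy's partition_S might hand it back: ((v1, v2), m, k)
    auto tile = make_layout(make_shape (make_shape (Int<4>{}, Int<2>{}), Int<1>{}, Int<2>{}),
                            make_stride(make_stride(Int<1>{}, Int<64>{}), Int<0>{}, Int<8>{}));
    // Same pattern as reshape_thread_tile in the diff below: keep the value sub-modes
    // (v1, v2) and the k mode, dropping the middle mode.
    auto reshaped = make_layout(append(get<0>(tile.shape()),  get<2>(tile.shape())),
                                append(get<0>(tile.stride()), get<2>(tile.stride())));
    print(reshaped);   // a rank-3 layout, here (_4,_2,_2):(_1,_64,_8)
    return 0;
}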
@@ -652,7 +652,8 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     // Repeat the partitioning with identity layouts
     Tensor tQcQ = gmem_thr_copy_Q.partition_S(cQ);        // (ACPY,ACPY_M,ACPY_K) -> (blk_m,blk_k)
-    Tensor tKVcKV = gmem_thr_copy_KV.partition_S(cKV);    // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
+    Tensor tKVcKV_ = gmem_thr_copy_KV.partition_S(cKV);   // (BCPY,BCPY_N,BCPY_K) -> (blk_n,blk_k)
+    Tensor tKVcKV = make_tensor(tKVcKV_.data(), reshape_thread_tile(tKVcKV_.layout()));
 
     // Allocate predicate tensors for k
     Tensor tQpQ = make_tensor<bool>(make_shape(size<2>(tQsQ)));
@@ -669,11 +670,12 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
     // Prologue
 
     // Copy from Knew to K, optionally apply rotary embedding.
-    typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary;
-    auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
-    typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont;
-    auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
     if constexpr (Append_KV) {
+        typename Kernel_traits::GmemTiledCopyRotcossinPaged gmem_tiled_copy_rotary;
+        auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
+        typename Kernel_traits::GmemTiledCopyRotcossinContPaged gmem_tiled_copy_rotary_cont;
+        auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
         // Even if we have MQA / GQA, all threadblocks responsible for the same KV head are writing to
         // gmem. Technically it's a race condition, but they all write the same content anyway, and it's safe.
         // We want to do this so that all threadblocks can proceed right after they finish writing the KV cache.
@@ -690,10 +692,17 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         Tensor gSinCont = make_tensor(make_gmem_ptr(reinterpret_cast<Element *>(params.rotary_sin_ptr) + row_offset_cossin),
                                       Shape<Int<kBlockN>, Int<kHeadDim>>{},
                                       make_stride(params.rotary_dim / 2, _1{}));
-        Tensor tRgCos = gmem_thr_copy_rotary.partition_S(gCos);
-        Tensor tRgSin = gmem_thr_copy_rotary.partition_S(gSin);
-        Tensor tRgCosCont = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
-        Tensor tRgSinCont = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
+        Tensor tRgCos_ = gmem_thr_copy_rotary.partition_S(gCos);
+        Tensor tRgSin_ = gmem_thr_copy_rotary.partition_S(gSin);
+        Tensor tRgCosCont_ = gmem_thr_copy_rotary_cont.partition_S(gCosCont);
+        Tensor tRgSinCont_ = gmem_thr_copy_rotary_cont.partition_S(gSinCont);
+        Tensor tRgCos = make_tensor(tRgCos_.data(), reshape_thread_tile(tRgCos_.layout()));
+        Tensor tRgSin = make_tensor(tRgSin_.data(), reshape_thread_tile(tRgSin_.layout()));
+        Tensor tRgCosCont = make_tensor(tRgCosCont_.data(), reshape_flatten_thread_tile(tRgCosCont_.layout()));
+        Tensor tRgSinCont = make_tensor(tRgSinCont_.data(), reshape_flatten_thread_tile(tRgSinCont_.layout()));
         // if (cute::thread(0, 0)) { printf("rotary_cos_ptr = %p, gCos.data() = %p, tRgCos.data() = %p, rotary_dim = %d\n", params.rotary_cos_ptr, gCos.data(), tRgCos.data(), params.rotary_dim); }
         // if (cute::thread(8, 0)) { print_tensor(gCos); }
         // if (cute::thread(0, 0)) { print_tensor(tRgCos); }
@@ -779,6 +788,10 @@ inline __device__ void compute_attn_1rowblock_splitkv(const Params &params, cons
         flash::copy<Is_even_MN, Is_even_K>(gmem_tiled_copy_Q, tQgQ, tQsQ, tQcQ, tQpQ,
                                            binfo.actual_seqlen_q - m_block * kBlockM);
     } else {
+        typename Kernel_traits::GmemTiledCopyRotcossin gmem_tiled_copy_rotary;
+        auto gmem_thr_copy_rotary = gmem_tiled_copy_rotary.get_thread_slice(tidx);
+        typename Kernel_traits::GmemTiledCopyRotcossinCont gmem_tiled_copy_rotary_cont;
+        auto gmem_thr_copy_rotary_cont = gmem_tiled_copy_rotary_cont.get_thread_slice(tidx);
         const index_t row_offset_cossin = (binfo.seqlen_k_cache + (Is_causal || Is_local ? m_block * kBlockM : 0)) * (params.rotary_dim / 2);
         // If not causal, all the queries get the same the cos/sin, taken at location seqlen_k_cache.
         // We do this by setting the row stride of gCos / gSin to 0.
...
@@ -158,7 +158,9 @@ struct Flash_fwd_kernel_traits : public Base {
         make_tiled_copy(Copy_Atom<DefaultCopy, ElementAccum>{},
                         GmemLayoutAtomOaccum{},
                         Layout<Shape < _1, _4>>{}));  // Val layout, 4 vals per store
-    using GmemLayoutAtomRotcossin = GmemLayoutAtom;
+    // using GmemLayoutAtomRotcossin = GmemLayoutAtom;
+    using GmemLayoutAtomRotcossin = Layout<Shape <Int<kNThreads / kGmemThreadsPerRow>, Int<kGmemThreadsPerRow>>,
+                                           Stride<Int<kGmemThreadsPerRow>, _1>>;
     using GmemTiledCopyRotcossin = decltype(
         make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
                         GmemLayoutAtomRotcossin{},
@@ -167,6 +169,14 @@ struct Flash_fwd_kernel_traits : public Base {
         make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
                         GmemLayoutAtomRotcossin{},
                         Layout<Shape < _1, _8>>{}));  // Val layout, 8 vals per load
+    using GmemTiledCopyRotcossinPaged = decltype(
+        make_tiled_copy(Copy_Atom<UniversalCopy<uint64_t>, Element>{},
+                        GmemLayoutAtomRotcossin{},
+                        Layout<Shape<Int<kGmemRowsPerThread>, _4>, Stride<_4, _1>>{}));  // Val layout, 4 vals per load
+    using GmemTiledCopyRotcossinContPaged = decltype(
+        make_tiled_copy(Copy_Atom<DefaultCopy, Element>{},
+                        GmemLayoutAtomRotcossin{},
+                        Layout<Shape<Int<kGmemRowsPerThread>, _8>, Stride<_8, _1>>{}));  // Val layout, 8 vals per load
 };
 
 // Is_V_in_regs is an option to reduce smem usage, but will increase register pressue.
...
@@ -344,7 +344,7 @@ int advance_thread_kv_page_slice_offset(const int tidx, const int n_block, const
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
-// somewhat unorthodox reshape function. Given a tuple ((v1, v2), m, k), returns (v1, v2, k),
+// Layout reshape function. Given a layout with modes ((v1, v2), m, k), returns (v1, v2, k),
 // where v2 may be a tuple itself, in the case of swizzled smem-backed thread tiles. This ensures
 // that paged and non-paged copies result in equivalently shaped, if not necessarily strided, tensors.
 template <class Shape, class Stride>
@@ -354,6 +354,17 @@ auto reshape_thread_tile(Layout<Shape, Stride> l) {
                        append(get<0>(l.stride()), get<2>(l.stride())));
 }
 
+// reshapes and flattens the thread tile layout. A separate function is needed for the case where
+// one of the modes of l is a layout itself and must be flattened, as opposed to keeping it intact
+// for the case of swizzled layouts
+template <class Shape, class Stride>
+__forceinline__ __device__
+auto reshape_flatten_thread_tile(Layout<Shape, Stride> l) {
+    auto mode_0 = filter(flatten(get<0>(l)));
+    return make_layout(append(mode_0.shape(), get<2>(l.shape())),
+                       append(mode_0.stride(), get<2>(l.stride())));
+}
+
 ////////////////////////////////////////////////////////////////////////////////////////////////////
 
 template <bool Is_even_MN=true, bool Is_even_K=true, bool Clear_OOB_MN=false, bool Clear_OOB_K=true,
...
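For the contiguous ("Cont") rotary copies, the new reshape_flatten_thread_tile above first flattens the nested value mode and filters out trivial sub-modes before re-appending the k mode. A hedged host-side sketch of that step, again on an invented layout and assuming only the CuTe headers:

// Sketch of the flatten-then-reshape step performed by reshape_flatten_thread_tile above.
// The nested layout is invented; in the kernel it comes from the tiled copy's partition_S.
#include <cute/tensor.hpp>

int main() {
    using namespace cute;
    // ((v1, (v2a, v2b)), m, k): the value mode is itself nested
    auto tile = make_layout(
        make_shape (make_shape (Int<4>{}, make_shape (Int<2>{}, Int<2>{})), Int<1>{}, Int<2>{}),
        make_stride(make_stride(Int<1>{}, make_stride(Int<16>{}, Int<128>{})), Int<0>{}, Int<64>{}));
    // Flatten the nested value mode and drop trivial sub-modes, then re-append the k mode,
    // mirroring the helper added in the hunk above.
    auto mode_0 = filter(flatten(get<0>(tile)));
    auto flat   = make_layout(append(mode_0.shape(),  get<2>(tile.shape())),
                              append(mode_0.stride(), get<2>(tile.stride())));
    print(flat);   // flat value modes followed by the k mode
    return 0;
}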
@@ -1833,7 +1833,7 @@ def test_flash_attn_splitkv(
 @pytest.mark.parametrize("rotary_fraction", [0.0, 0.5, 1.0])
 # @pytest.mark.parametrize("rotary_fraction", [0.0])
 # @pytest.mark.parametrize("paged_kv_block_size", [None, 256, 512])
-@pytest.mark.parametrize("paged_kv_block_size", [16, 256, 512])
+@pytest.mark.parametrize("paged_kv_block_size", [None, 16, 256, 512])
 @pytest.mark.parametrize("has_batch_idx", [False, True])
 # @pytest.mark.parametrize("has_batch_idx", [False])
 @pytest.mark.parametrize("d", [32, 59, 64, 80, 128, 256])
...