Unverified Commit c1569892 authored by Po Yen Chen, committed by GitHub

[CK_TILE] Add PagedAttention kernels (#1387)
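The central mechanism added here is a paged KV-cache: keys and values live in fixed-size page blocks, and a per-batch block table maps logical sequence positions onto physical blocks (see `PageBlockNavigator` below). A minimal standalone sketch of that address translation, with hypothetical names that are illustrative rather than taken from this PR:

#include <cstdint>

// Hypothetical helper (names are illustrative, not from this PR): locate one
// logical KV-cache token inside paged physical storage via the block table.
template <typename T>
T* paged_kv_token_ptr(T* physical_blocks,         // base pointer of all page blocks
                      const int32_t* block_table, // physical block index per logical block
                      int64_t block_stride,       // elements between consecutive page blocks
                      int64_t row_stride,         // elements between consecutive tokens
                      int32_t page_block_size,    // tokens held by one page block
                      int32_t token_pos)          // logical position in the sequence
{
    const int32_t logical_block = token_pos / page_block_size; // which page block
    const int32_t local_pos     = token_pos % page_block_size; // offset inside it
    return physical_blocks + block_table[logical_block] * block_stride +
           static_cast<int64_t>(local_pos) * row_stride;
}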



* Use dictionary to configure all the functions

* Add init codegen logic for fmha fwd appendkv

* Call HIP_CHECK_ERROR() macro to get real source info

* Set up meaningful arguments

* Sync kernel name with the codegen

* Add knew/vnew tensors to the kernel argument

* Fix wrong K values after appending

* Fix vnew append error

* Extract common logic

* Fix Vnew tile dstr for row major case

* Conditionally add fwd_splitkv API in fmha_fwd example

* Conditionally add call to fmha_fwd_splitkv()

* Remove "EXAMPLE_" prefix of cmake variables

* Register API handlers automatically

* Early return if 0 < s_k_new is not supported

* Show message if we are ignoring option

* Unify CMakeLists.txt coding style

* Set num_splits=1 if split-kv is not supported

* Add length/stride getters for HostTensor

* Add RoPE example utilities

* Add reference_rotary_position_embedding() (not implemented)

* Finish reference_rotary_position_embedding() impl

* Fix typo of HostTensor<>::get_length()

* Fix compilation errors

* Fix wrong answer when interleaved=false

* Fix wrong answer when interleaved=true

* Append K/V in the host verification code

* Simplify K appending logic

* Simplify v_host_ref definition

* Reduce input/output dimensions

* Rename function: add "batched" prefix

* Apply RoPE on host side

* Rename RoPE utility function

* Fix wrong tensor size

* Avoid invoking deprecated method 'find_module'

* Pass RoPE kernel args

* Create Rotary Cos/Sin tile windows in kernel

* Add compute data type alias for RoPE

* Randomly generate seqlen_knew if needed

* Fix seqlen_knew enabling check logic

* Add minimum seqlen_k to generate compliant kvcache

* Fix compilation error in debug mode

* Fix wrong boundaries

* Fix wrong seqlen_k for kvcache

* Rename variables used in distribution encoding

* Fix rotary cos/sin tensor/tile size

* Add constraint to the rotary_dim option

* Remove unused inner namespace

* Add dram distribution for rotary_cos/rotary_sin (interleaved)

* Only apply interleaved RoPE on Knew for now

* Fix wrong thread starting offset

* Instantiate multiple kernels for RoPE approaches

* Clean-up pipeline

* Fix error in RoPE host reference

* Handle RoPE half-rotated logic

* Support 8x rotary_dim under half-rotated RoPE

* Add comment

* Apply elementwise function to the loaded tiles

* Unify parameter/variable naming style

* Remove constness from q_ptr

* Add code blocks for q_tile

* Apply RoPE to q_tile

* Remove debug print code in kernel

* Fix wrong knew/vnew appending positions

* Use better naming for tile indices

* Add make_tile_window() for adding distribution only

* Skip code if # of blocks is more than needed

* Move thread locating logic into policy

* Remove always true static_assert()

* Rename header

* Rename RotaryEmbeddingEnum

* Extract rotary embedding logic out

* Re-order parameters

* Align naming of some tile size constants

* Rename more tile size constants

* Fix wrong grid size

* Fix wrong shape of knew_host/vnew_host

* Fix wrong index into knew_host/vnew_host

* Fix wrong rotary_cos/rotary_sin memory size for Q

* Extract Q/Knew vector size to helper methods

* Use different rotary_cos/rotary_sin distr for Q/Knew

* Update host/device specifiers

* Fix wrong data type for Q rotary_cos/rotary_sin

* Remove RoPEComputeDataType type alias

* Shift rotary_cos/rotary_sin by cache_seqlen_k

* Add comment for why I just use 't' for all padding flags

* Align commit message to the real comment

* Fix wrong pipeline

* Rename utility function

* Disable host verification if API does not exist

* Fix wrong rope key for fp8 pipeline

* Allow applying RoPE only on Q (without appending KV)

* Add append-kv smoke tests

* Remove debug statements

* Remove more debug statements

* Re-arrange the 'set +x' command

* Remove no-longer used method in pipeline

* Add missing init code

* Refine pipeline padding settings

* Enlarge rotary_dim limit (8 -> 16)

* Enlarge KPerThread for rotary_interleaved=false

* Update rotary_dim range in smoke_test_fwd.sh

* Add template argument 'kIsPagedKV' for splitkv kernels

* Launch splitkv kernel if given page_block_size

* Fix wrong kernel name

* Fix seqlen_k_min for pre-fill case (1 -> 0)

* Add copy_const<> type trait

* Add another make_tile_window()

* Introduce 'TileWindowNavigator' types

* Simplify TileWindowNavigator interfaces

* Fix tile window navigation bugs

* Disable calling fmha_fwd()

* Remove unnecessary data members

* Simplify more make_tile_window() overloads

* Move V tile through TileWindowNavigator

* Fix uneven split checking logic

* Move code after deciding seqlen_q/seqlen_k

* Make sure we always start reading complete tile

* Use 128 as minimum page_block_size

* Fix wrong origin for bias

* Add batch_stride_k/batch_stride_v in group mode

* Unify origin

* Add missing kernel arguments for group mode

* Add paged-kv codegen logic for appendkv kernels

* Add block_table kernel args for appendkv kernel

* Add tile navigators to the appendkv kernel

* Fix wrong tensor descriptor lengths

* Pass re-created tile window to pipeline

* Fix wrong strides for appendkv kernel

* Allow tile_window to transit to another page-block

* Handle cross-page-block write

* Do not perform write again if already in last page-block

* Always add fmha_fwd() api

* Add missing group mode argument

* Remove debug macro usages

* Rename option s_k_new to s_knew

* Separate splitkv/non-splitkv args/traits

* Remove fmha_fwd_dispatch()

* Fix compilation errors

* Remove dropout code in splitkv kernel

* Allow problem types without defining kHasDropout attr

* Use generic lambda to init traits objects

* Separate more non-splitkv & splitkv traits/args

* Display more info for specific kernels

* Show more detailed warning message

* Rename 'max_num_blocks' to 'max_num_page_blocks'

* Remove no-longer used pipeline files

* Wrap code by #if directives

* Move functors to the beginning of validation code

* Use generic lambda to init all the api traits/args

* Fix wrong seqlen for kvcache

* Add missing comment

* Rename TileWindowNavigator to PageBlockNavigator

* Only expose necessary methods (not attributes)

* Re-order pipeline parameters

* Refine smoke_test_fwd.sh

* Fix wrong argument count

* Make tile window directly via PageBlockNavigator

* Remove unused template parameter

* Remove group mode from appendkv kernel

* Fix skcheck logic

* Fix wrong syntax in skcheck expr

* Use meaningful options in smoke test

* Remove options

* Fix formatting

* Fix more format

* Re-organize bash functions

* Pass cache_batch_idx to kernels

* Support cache_batch_idx in example

* Fix compilation error

* Add more appendkv test

* Add more case for appendkv

* Fix use of nonexistent attribute

* Remove 0 < seqlen_knew constraint

* Clarify the case in warning message

* Remove macro checking

* Force batch mode when invoking appendkv & splitkv apis

* Fix mode overriding logic

* Fix wrong parameter name

* Randomize seqlen_k if using kvcache

* Use randomized seqlen_k for kvcache

* Avoid using too small rotary_cos & rotary_sin

* Rename parameter

* Add seqlen_q & seqlen_k rules

* Add comment

* Add more comments

* Fix compilation errors

* Fix typo in comment

* Remove type argument

* Avoid seqlen_k=0 for kvcache

* Revert "Avoid seqlen_k=0 for kvcache"

This reverts commit 21c4df89e416182e8e9bc78e67bd4b98dbb6c88d.

* Fix wrong uneven split checking logic

* Only randomize kvcache seqlen_k if 1 < batch

* Return earlier if split is empty

* Revert "Only randomize kvcache seqlen_k if 1 < batch"

This reverts commit b9a4ab0d7e3c2beecc0fccafd2a13259dd06299c.

* Re-order seqlen_k_start adjustment logic

* Fix compilation errors

* Re-format script

* Find executable from folder automatically

* Fix kvcache seqlen_k generating logic

* Make comment more clear

* Fix wrong knew/vnew appending logic on host

* Add s_barrier to sync threads

* Revert "Add s_barrier to sync threads"

This reverts commit d3f550f30c0a4d9df15c613015d5dff268d6746d.

* Support only using 1 row of rotary_cos/rotary_sin

* Rotate Q in different way

* Unify tensor view creation logic

* Fix wrong argument

* Add mask to switch how we use the rotary_cos/sin

* Move attr from traits to problem

* Move has_mask to fmha_fwd_appendkv_args

* Support using uint32_t as SAD operand in Alibi<>

* Use sad_u32() in splitkv kernels

* Store tensor views in PageBlockNavigator

* Use stored tensor view to update tile windows

* Enlarge tensor view size

* Remove debug code

* Fix wrong tensor view size

* Wrap tensor view into PageBlockNavigator

* Add DataType member to PageBlockNavigator

* Remove unnecessary member functions

* Refine macro use

* Fix typo

* Add blank line between directives and actual code

* Re-format files

* Remove type in comment

---------
Co-authored-by: carlushuang <carlus.huang@amd.com>
Co-authored-by: rocking <ChunYu.Lai@amd.com>
parent 19d22e60
@@ -7,7 +7,11 @@
 #include "ck_tile/ops/fmha/block/block_dropout.hpp"
 #include "ck_tile/ops/fmha/block/block_masking.hpp"
 #include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
+#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
+#include "ck_tile/ops/fmha/block/page_block_navigator.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
+#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp"
+#include "ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
@@ -21,11 +25,11 @@
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_combine_pipeline_default_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_enum.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_problem.hpp"
...
@@ -43,9 +43,12 @@ enum struct AlibiMode
     FROM_BOTTOM_RIGHT = 2,
 };

-template <typename DataType, bool RowMajor = true>
+template <typename DataType, bool RowMajor = true, unsigned LogMaxSadOprndSize = 16>
 struct Alibi
 {
+    static_assert(1 <= LogMaxSadOprndSize && LogMaxSadOprndSize <= 32,
+                  "for LogMaxSadOprndSize <= 16, we use SAD uint16_t, otherwise, use SAD uint32_t");
+
     // RowMajor here means if pixel within the same thread are along the row, or col
     // this may impact the performance of update(), while the result are the same.
     // e.g. fwd prefer use RowMajor=true, bwd some cases prefer use RowMajor=false
@@ -79,6 +82,19 @@ struct Alibi
         mode = mode_;
     }

+    CK_TILE_HOST uint32_t sad(uint32_t x, uint32_t y, uint32_t acc) { return sad_u32(x, y, acc); }
+
+    CK_TILE_DEVICE uint32_t sad(uint32_t x, uint32_t y, uint32_t acc)
+    {
+        if constexpr(LogMaxSadOprndSize <= 16)
+        {
+            return sad_u16(
+                static_cast<uint16_t>(x), static_cast<uint16_t>(y), static_cast<uint16_t>(acc));
+        }
+
+        return sad_u32(x, y, acc);
+    }
+
     CK_TILE_HOST_DEVICE void update(DataType& pixel, index_t row_idx, index_t col_idx)
     {
         if constexpr(RowMajor)
@@ -128,7 +144,7 @@ struct EmptyPositionEncoding
 // can convert from the FA style left/right to our generic coordinate
 // if left_size < 0 && right_size = 0, it is normal causal mask
 // local is left_size >=0 or right_size >=0
-template <typename DataType, bool RowMajor = true>
+template <typename DataType, bool RowMajor = true, unsigned LogMaxSadOprndSize = 16>
 CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope,
                                                  index_t window_left_size,
                                                  index_t window_right_size,
@@ -142,7 +158,7 @@ CK_TILE_HOST_DEVICE auto make_alibi_from_lr_mask(DataType slope,
     AlibiMode alibi_mode =
         is_causal ? AlibiMode::VERTICAL
                   : static_cast<AlibiMode>(mask_enum) /*either top-left or bottom-right*/;
-    return Alibi<DataType, RowMajor>{slope, y_total, x_total, alibi_mode};
+    return Alibi<DataType, RowMajor, LogMaxSadOprndSize>{slope, y_total, x_total, alibi_mode};
 }

 // https://github.com/ofirpress/attention_with_linear_biases/blob/4b92f28a005ead2567abe2359f633e73e08f3833/fairseq/models/transformer.py#L742
...
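The new `sad()` overloads above select between the 16-bit and 32-bit SAD (sum of absolute differences) instructions; for Alibi the quantity of interest is simply |row_idx - col_idx|, scaled by the slope. A hedged scalar reference of what this computes (assuming the usual ALiBi convention with the sign folded into the slope):

#include <cstdint>

// Scalar reference for the SAD-based distance above (illustrative only):
// sad(x, y, acc) == |x - y| + acc, so the per-pixel Alibi term reduces to
// slope * |row_idx - col_idx|.
inline uint32_t sad_ref(uint32_t x, uint32_t y, uint32_t acc)
{
    return (x > y ? x - y : y - x) + acc;
}

inline float alibi_term_ref(float slope, uint32_t row_idx, uint32_t col_idx)
{
    return slope * static_cast<float>(sad_ref(row_idx, col_idx, 0u));
}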
// ck_tile/ops/fmha/block/block_rotary_embedding.hpp (new file)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include <string>
namespace ck_tile {
// This enum is used for codegen pattern matching
enum class RotaryEmbeddingEnum
{
NONE = 0,
INTERLEAVED = 1, // combine dimensions 0 & 1, 2 & 3, etc
HALF_ROTATED = 2, // combine dimensions 0 & rotary_dim / 2, 1 & rotary_dim / 2 + 1, etc
};
template <RotaryEmbeddingEnum>
struct RotaryEmbeddingEnumToStr;
template <>
struct RotaryEmbeddingEnumToStr<RotaryEmbeddingEnum::NONE>
{
static constexpr const char* name = "";
};
template <>
struct RotaryEmbeddingEnumToStr<RotaryEmbeddingEnum::INTERLEAVED>
{
static constexpr const char* name = "inter";
};
template <>
struct RotaryEmbeddingEnumToStr<RotaryEmbeddingEnum::HALF_ROTATED>
{
static constexpr const char* name = "half";
};
template <RotaryEmbeddingEnum RotaryEnum, typename ComputeDataType = float>
struct BlockRotaryEmbedding
{
template <typename DistributedTensor,
typename OtherDramBlockWindow,
typename RotaryCosDramBlockWindow,
typename RotarySinDramBlockWindow>
CK_TILE_HOST_DEVICE static void apply(DistributedTensor& tile,
OtherDramBlockWindow other_window,
RotaryCosDramBlockWindow rotary_cos_window,
RotarySinDramBlockWindow rotary_sin_window,
index_t rotary_dim,
index_t thread_end)
{
using DataType = typename remove_cvref_t<DistributedTensor>::DataType;
if constexpr(RotaryEnum == RotaryEmbeddingEnum::INTERLEAVED)
{
auto rotary_cos_tile = load_tile(rotary_cos_window);
auto rotary_sin_tile = load_tile(rotary_sin_window);
if(thread_end <= rotary_dim)
{
constexpr index_t thread_buffer_size = decltype(tile.thread_buf_)::size();
static_for<0, thread_buffer_size, 2>{}([&](auto idx) {
const auto left = type_convert<ComputeDataType>(tile.thread_buf_[idx]);
const auto right = type_convert<ComputeDataType>(tile.thread_buf_[idx + 1]);
const auto cos =
type_convert<ComputeDataType>(rotary_cos_tile.thread_buf_[idx / 2]);
const auto sin =
type_convert<ComputeDataType>(rotary_sin_tile.thread_buf_[idx / 2]);
tile.thread_buf_[idx] = type_convert<DataType>(left * cos - right * sin);
tile.thread_buf_[idx + 1] = type_convert<DataType>(right * cos + left * sin);
});
}
}
else if constexpr(RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
{
if(thread_end <= rotary_dim)
{
const bool is_left = (thread_end <= (rotary_dim / 2));
move_tile_window(other_window, {0, is_left ? rotary_dim / 2 : -(rotary_dim / 2)});
auto other_tile = load_tile(other_window);
move_tile_window(rotary_cos_window, {0, is_left ? 0 : -(rotary_dim / 2)});
auto rotary_cos_tile = load_tile(rotary_cos_window);
move_tile_window(rotary_sin_window, {0, is_left ? 0 : -(rotary_dim / 2)});
auto rotary_sin_tile = load_tile(rotary_sin_window);
constexpr index_t thread_buffer_size = decltype(tile.thread_buf_)::size();
static_for<0, thread_buffer_size, 1>{}([&](auto idx) {
const auto curr = type_convert<ComputeDataType>(tile.thread_buf_[idx]);
const auto other = type_convert<ComputeDataType>(other_tile.thread_buf_[idx]);
const auto cos =
type_convert<ComputeDataType>(rotary_cos_tile.thread_buf_[idx]);
const auto sin =
type_convert<ComputeDataType>(rotary_sin_tile.thread_buf_[idx]);
tile.thread_buf_[idx] =
type_convert<DataType>(curr * cos + other * (is_left ? -sin : sin));
});
}
}
}
};
} // namespace ck_tile
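As a cross-check of the two layouts handled above, the same math on plain host buffers; this is a sketch assuming float data and an even rotary_dim, mirroring the INTERLEAVED and HALF_ROTATED branches, not code from this commit:

#include <vector>

// Minimal host-side sketch of the two RoPE layouts (illustrative only).
// x holds rotary_dim elements; cos/sin hold rotary_dim / 2 entries for this token.
void rope_interleaved(std::vector<float>& x, const std::vector<float>& cos,
                      const std::vector<float>& sin, int rotary_dim)
{
    for(int i = 0; i < rotary_dim; i += 2) // pair dims (0,1), (2,3), ...
    {
        const float left = x[i], right = x[i + 1];
        x[i]     = left * cos[i / 2] - right * sin[i / 2];
        x[i + 1] = right * cos[i / 2] + left * sin[i / 2];
    }
}

void rope_half_rotated(std::vector<float>& x, const std::vector<float>& cos,
                       const std::vector<float>& sin, int rotary_dim)
{
    const int half = rotary_dim / 2;
    for(int i = 0; i < half; ++i) // pair dims (i, i + half)
    {
        const float left = x[i], right = x[i + half];
        x[i]        = left * cos[i] - right * sin[i];
        x[i + half] = right * cos[i] + left * sin[i];
    }
}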
// ck_tile/ops/fmha/block/page_block_navigator.hpp (new file)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/core/tensor/tile_window.hpp"
namespace ck_tile {
// assume that we have only 1 page-block/tensor view
template <typename TensorView>
struct TrivialPageBlockNavigator
{
using DataType = typename TensorView::DataType;
using WindowOrigin = multi_index<2>;
CK_TILE_HOST_DEVICE constexpr TrivialPageBlockNavigator(const TensorView& tensor_view_)
: tensor_view(tensor_view_)
{
}
template <typename WindowLengths>
CK_TILE_HOST_DEVICE constexpr auto make_tile_window(const WindowLengths& window_lengths,
const WindowOrigin& window_origin) const
{
return make_tuple(/*block_index=*/0,
ck_tile::make_tile_window(tensor_view, window_lengths, window_origin));
}
template <typename WindowLengths, typename TileDistribution>
CK_TILE_HOST_DEVICE constexpr auto
make_tile_window(const WindowLengths& window_lengths,
const WindowOrigin& window_origin,
const TileDistribution& tile_distribution) const
{
return make_tuple(
/*block_index=*/0,
ck_tile::make_tile_window(
tensor_view, window_lengths, window_origin, tile_distribution));
}
template <typename TileWindow>
CK_TILE_HOST_DEVICE static index_t
move_tile_window(index_t /*block_index*/,
TileWindow& tile_window,
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step)
{
ck_tile::move_tile_window(tile_window, step);
return /*block_index=*/0;
}
CK_TILE_HOST_DEVICE static constexpr WindowOrigin
to_local_window_origin(const WindowOrigin& global_window_origin)
{
return global_window_origin;
}
CK_TILE_HOST_DEVICE static constexpr WindowOrigin
to_global_window_origin(index_t /*block_index*/, const WindowOrigin& local_window_origin)
{
return local_window_origin;
}
private:
TensorView tensor_view;
};
// default page-block navigator; assumes that the tensor view size is the same as the
// page-block size, or smaller if the tile window is on the last page-block
template <typename DataType_, index_t VirtualDim, typename TensorView>
struct PageBlockNavigator
{
using DataType = DataType_;
static_assert(std::is_same_v<DataType, typename TensorView::DataType>);
static_assert(VirtualDim == 0 || VirtualDim == 1, "only support 2d tile window");
using WindowOrigin = multi_index<2>;
CK_TILE_HOST_DEVICE constexpr PageBlockNavigator(copy_const_t<DataType, void>* physical_blocks_,
long_index_t block_stride_,
long_index_t fixed_offset_,
const int32_t* physical_block_indices_,
index_t num_blocks_,
index_t page_block_size_,
const TensorView& complete_view_,
const TensorView& last_view_)
: physical_blocks(reinterpret_cast<DataType*>(physical_blocks_)),
block_stride(block_stride_),
fixed_offset(fixed_offset_),
physical_block_indices(physical_block_indices_),
num_blocks(num_blocks_),
page_block_size(page_block_size_),
complete_view(complete_view_),
last_view(last_view_)
{
}
template <typename WindowLengths>
CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths& window_lengths,
const WindowOrigin& window_origin) const
{
const index_t block_index = get_block_index(window_origin);
const WindowOrigin local_window_origin = to_local_window_origin(window_origin);
auto new_tile_window =
ck_tile::make_tile_window(is_last_block(block_index) ? last_view : complete_view,
window_lengths,
local_window_origin);
new_tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(block_index));
return make_tuple(block_index, new_tile_window);
}
template <typename WindowLengths, typename TileDistribution>
CK_TILE_HOST_DEVICE auto make_tile_window(const WindowLengths& window_lengths,
const WindowOrigin& window_origin,
const TileDistribution& tile_distribution) const
{
const index_t block_index = get_block_index(window_origin);
const WindowOrigin local_window_origin = to_local_window_origin(window_origin);
auto new_tile_window =
ck_tile::make_tile_window(is_last_block(block_index) ? last_view : complete_view,
window_lengths,
local_window_origin,
tile_distribution);
new_tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(block_index));
return make_tuple(block_index, new_tile_window);
}
template <typename TileWindow>
CK_TILE_HOST_DEVICE index_t
move_tile_window(index_t block_index,
TileWindow& tile_window,
const typename remove_cvref_t<TileWindow>::BottomTensorIndex& step) const
{
ck_tile::move_tile_window(tile_window, step);
const WindowOrigin global_window_origin =
to_global_window_origin(block_index, tile_window.get_window_origin());
const WindowOrigin local_window_origin = to_local_window_origin(global_window_origin);
const index_t new_block_index = get_block_index(global_window_origin);
/// TODO: only update necessary attributes
tile_window.bottom_tensor_view_.desc_ =
(is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor();
tile_window.set_window_origin(local_window_origin);
tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(new_block_index));
return new_block_index;
}
CK_TILE_HOST_DEVICE bool is_last_block(index_t block_index) const
{
return block_index == num_blocks - 1;
}
template <typename TileWindow>
CK_TILE_HOST_DEVICE bool is_cross_block(index_t block_index,
const TileWindow& tile_window) const
{
const index_t origin = tile_window.get_window_origin().at(number<VirtualDim>{});
const index_t length = tile_window.get_window_lengths().at(number<VirtualDim>{});
return (block_index < num_blocks - 1) && (page_block_size < origin + length);
}
template <typename TileWindow>
CK_TILE_HOST_DEVICE void
move_to_block(index_t block_index, TileWindow& tile_window, index_t new_block_index) const
{
const multi_index<2> step = [&]() {
const index_t origin_diff = (block_index - new_block_index) * page_block_size;
if constexpr(VirtualDim == 0)
{
return make_multi_index(origin_diff, 0);
}
else
{
return make_multi_index(0, origin_diff);
}
}();
/// TODO: only update necessary attributes
tile_window.bottom_tensor_view_.desc_ =
(is_last_block(new_block_index) ? last_view : complete_view).get_tensor_descriptor();
tile_window.set_window_origin(tile_window.get_window_origin() + step);
tile_window.set_bottom_tensor_view_data_ptr(get_block_ptr(new_block_index));
}
CK_TILE_HOST_DEVICE WindowOrigin
to_local_window_origin(const WindowOrigin& global_window_origin) const
{
if constexpr(VirtualDim == 0)
{
const index_t length = global_window_origin.at(number<0>{});
const index_t num_complete_blocks = integer_divide_floor(length, page_block_size);
return make_multi_index(length - page_block_size * num_complete_blocks,
global_window_origin.at(number<1>{}));
}
else
{
const index_t length = global_window_origin.at(number<1>{});
const index_t num_complete_blocks = integer_divide_floor(length, page_block_size);
return make_multi_index(global_window_origin.at(number<0>{}),
length - page_block_size * num_complete_blocks);
}
}
CK_TILE_HOST_DEVICE WindowOrigin
to_global_window_origin(index_t block_index, const WindowOrigin& local_window_origin) const
{
if constexpr(VirtualDim == 0)
{
return make_multi_index(block_index * page_block_size +
local_window_origin.at(number<0>{}),
local_window_origin.at(number<1>{}));
}
else
{
return make_multi_index(local_window_origin.at(number<0>{}),
block_index * page_block_size +
local_window_origin.at(number<1>{}));
}
}
private:
CK_TILE_HOST_DEVICE
DataType* get_block_ptr(index_t block_index) const
{
return physical_blocks + physical_block_indices[block_index] * block_stride + fixed_offset;
}
CK_TILE_HOST_DEVICE int32_t get_block_index(const WindowOrigin& global_window_origin) const
{
return integer_divide_floor(global_window_origin.at(number<VirtualDim>{}), page_block_size);
}
DataType* physical_blocks;
long_index_t block_stride;
long_index_t fixed_offset;
const int32_t* physical_block_indices;
index_t num_blocks;
index_t page_block_size;
TensorView complete_view;
TensorView last_view;
};
template <typename TensorView>
CK_TILE_HOST_DEVICE auto make_page_block_navigator(const TensorView& tensor_view)
{
return TrivialPageBlockNavigator<TensorView>(tensor_view);
}
template <typename DataType, index_t VirtualDim, typename TensorView>
CK_TILE_HOST_DEVICE auto make_page_block_navigator(copy_const_t<DataType, void>* physical_blocks,
long_index_t block_stride,
long_index_t fixed_offset,
const int32_t* physical_block_indices,
index_t num_blocks,
index_t page_block_size,
const TensorView& complete_view,
const TensorView& last_view)
{
return PageBlockNavigator<DataType, VirtualDim, TensorView>(physical_blocks,
block_stride,
fixed_offset,
physical_block_indices,
num_blocks,
page_block_size,
complete_view,
last_view);
}
} // namespace ck_tile
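A worked example of the origin bookkeeping PageBlockNavigator performs for VirtualDim = 0, written as plain integer arithmetic and assuming page_block_size = 128:

#include <cassert>

// to_local_window_origin() splits a global row into (block_index, local row);
// to_global_window_origin() inverts that mapping.
int main()
{
    const int page_block_size = 128;
    const int global_row      = 300; // window origin along the paged dimension

    const int block_index = global_row / page_block_size;               // 2
    const int local_row   = global_row - block_index * page_block_size; // 44

    assert(block_index == 2 && local_row == 44);
    assert(block_index * page_block_size + local_row == global_row); // round trip
    return 0;
}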
// ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_kernel.hpp (new file)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common.hpp"
#include <string>
#include <type_traits>
namespace ck_tile {
template <typename TilePartitioner_, typename FmhaPipeline_>
struct FmhaFwdAppendKVKernel
{
using TilePartitioner = ck_tile::remove_cvref_t<TilePartitioner_>;
using FmhaPipeline = ck_tile::remove_cvref_t<FmhaPipeline_>;
static constexpr ck_tile::index_t kBlockSize = FmhaPipeline::kBlockSize;
static constexpr ck_tile::index_t kBlockPerCu = FmhaPipeline::kBlockPerCu;
static_assert(kBlockPerCu > 0);
static constexpr ck_tile::index_t kBlockPerCuInput = FmhaPipeline::Problem::kBlockPerCu;
using QDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::QDataType>;
using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
using VLayout = ck_tile::remove_cvref_t<typename FmhaPipeline::VLayout>;
static constexpr bool kApplyRoPE = FmhaPipeline::RotaryEnum != RotaryEmbeddingEnum::NONE;
static constexpr bool kIsPagedKV = FmhaPipeline::kIsPagedKV;
static constexpr bool kPadSeqLenQ = FmhaPipeline::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = FmhaPipeline::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
// clang-format off
template <typename T> struct t2s;
template <> struct t2s<float> { static constexpr const char * name = "fp32"; };
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
template <> struct t2s<ck_tile::fp8_t> { static constexpr const char * name = "fp8"; };
template <> struct t2s<ck_tile::bf8_t> { static constexpr const char * name = "bf8"; };
// clang-format on
__host__ static std::string GetName()
{
// sync with generate.py
// clang-format off
#define _SS_ std::string
#define _TS_ std::to_string
auto pn = [&] () {
std::string n;
if (kPadSeqLenQ) n += "s";
if (kPadSeqLenK) n += "sk";
if (kPadHeadDimQ) n += "d";
if (kPadHeadDimV) n += "dv";
return n.empty() ? n : std::string("p") + n; }();
return
_SS_("fmha_fwd_appendkv_d") + _TS_(FmhaPipeline::kK0) + "_" + _SS_(t2s<QDataType>::name) + "_"
"b" + _TS_(FmhaPipeline::kM0) + "x" + _TS_(FmhaPipeline::kN0) + "x" + _TS_(FmhaPipeline::kK0) + "x" +
_TS_(FmhaPipeline::kN1) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn)
+ (!kApplyRoPE ? _SS_("") : (_SS_("_") + RotaryEmbeddingEnumToStr<FmhaPipeline::RotaryEnum>::name))
+ (kIsPagedKV ? "_pagedkv" : "" );
#undef _SS_
#undef _TS_
// clang-format on
}
    template <ck_tile::index_t I> // to avoid the duplicated base class problem,
                                  // introduce a template arg
struct EmptyKargs
{
};
    // kargs use aggregate initialization, so no constructor will be provided.
    // use inheritance to minimize karg size.
    // users need to use the MakeKargs() function to create kargs.
struct BasicKargs
{
void* q_ptr;
void* k_ptr;
const void* knew_ptr;
void* v_ptr;
const void* vnew_ptr;
const int32_t* seqlen_k_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t seqlen_knew;
ck_tile::index_t hdim_q;
ck_tile::index_t hdim_v;
ck_tile::index_t num_head_q;
// for MQA/GQA, nhead could be different. This parameter is nhead_q / nhead_k
        // if this param is larger than 1, it indicates the MQA/GQA case
ck_tile::index_t nhead_ratio_qk;
ck_tile::index_t stride_q;
ck_tile::index_t stride_k;
ck_tile::index_t stride_knew;
ck_tile::index_t stride_v;
ck_tile::index_t stride_vnew;
ck_tile::index_t nhead_stride_q;
ck_tile::index_t nhead_stride_k;
ck_tile::index_t nhead_stride_knew;
ck_tile::index_t nhead_stride_v;
ck_tile::index_t nhead_stride_vnew;
ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_knew;
ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_vnew;
};
struct RoPEKargs
{
const void* rotary_cos_ptr;
const void* rotary_sin_ptr;
ck_tile::index_t rotary_dim;
bool has_mask;
};
struct PageBlockTableKargs
{
const int32_t* block_table_ptr;
ck_tile::index_t batch_stride_block_table;
ck_tile::index_t page_block_size;
};
struct CacheBatchIdxKargs
{
const int32_t* cache_batch_idx;
};
struct Kargs : BasicKargs,
std::conditional_t<kApplyRoPE, RoPEKargs, EmptyKargs<0>>,
std::conditional_t<kIsPagedKV, PageBlockTableKargs, CacheBatchIdxKargs>
{
};
__host__ static constexpr Kargs MakeKargs(void* q_ptr,
void* k_ptr,
const void* knew_ptr,
void* v_ptr,
const void* vnew_ptr,
ck_tile::index_t seqlen_q,
const void* seqlen_k_ptr,
ck_tile::index_t seqlen_knew,
ck_tile::index_t hdim_q,
ck_tile::index_t hdim_v,
ck_tile::index_t num_head_q,
ck_tile::index_t nhead_ratio_qk,
const void* rotary_cos_ptr,
const void* rotary_sin_ptr,
ck_tile::index_t rotary_dim,
bool has_mask,
const void* block_table_ptr,
ck_tile::index_t batch_stride_block_table,
ck_tile::index_t page_block_size,
const void* cache_batch_idx,
ck_tile::index_t stride_q,
ck_tile::index_t stride_k,
ck_tile::index_t stride_knew,
ck_tile::index_t stride_v,
ck_tile::index_t stride_vnew,
ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_knew,
ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_vnew,
ck_tile::index_t batch_stride_q,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_knew,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_vnew)
{
Kargs kargs{
{q_ptr,
k_ptr,
knew_ptr,
v_ptr,
vnew_ptr,
reinterpret_cast<const int32_t*>(seqlen_k_ptr),
seqlen_q,
-1, // seqlen_k will be updated by content of seqlen_k_ptr
seqlen_knew,
hdim_q,
hdim_v,
num_head_q,
nhead_ratio_qk,
stride_q,
stride_k,
stride_knew,
stride_v,
stride_vnew,
nhead_stride_q,
nhead_stride_k,
nhead_stride_knew,
nhead_stride_v,
nhead_stride_vnew,
batch_stride_q,
batch_stride_k,
batch_stride_knew,
batch_stride_v,
batch_stride_vnew}, // args for common karg
{}, // placeholder for rope
{} // placeholder for paged-block table or cache_batch_idx
};
if constexpr(kApplyRoPE)
{
kargs.rotary_cos_ptr = rotary_cos_ptr;
kargs.rotary_sin_ptr = rotary_sin_ptr;
kargs.rotary_dim = rotary_dim;
kargs.has_mask = has_mask;
}
if constexpr(kIsPagedKV)
{
kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
kargs.batch_stride_block_table = batch_stride_block_table;
kargs.page_block_size = page_block_size;
}
else
{
kargs.cache_batch_idx = reinterpret_cast<const int32_t*>(cache_batch_idx);
}
return kargs;
}
__host__ static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_knew)
{
return TilePartitioner::GridSize(batch_size, nhead, seqlen_q, seqlen_knew);
}
__host__ static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
// divide problem
const auto [i_tile, i_nhead, i_batch] = TilePartitioner{}();
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kM0);
const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile * FmhaPipeline::kN0);
const index_t i_cache_batch = [&, i_batch_ = i_batch] {
if constexpr(kIsPagedKV)
{
return i_batch_;
}
else
{
return (kargs.cache_batch_idx != nullptr ? kargs.cache_batch_idx[i_batch_]
: i_batch_);
}
}();
const long_index_t batch_offset_q =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
const long_index_t batch_offset_k =
static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_k;
const long_index_t batch_offset_knew =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_knew;
const long_index_t batch_offset_v =
static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_v;
const long_index_t batch_offset_vnew =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_vnew;
kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
        // for simplicity, we apply the batch stride by just offsetting the pointer
QDataType* q_ptr = reinterpret_cast<QDataType*>(kargs.q_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
batch_offset_q;
KDataType* k_ptr =
reinterpret_cast<KDataType*>(kargs.k_ptr) +
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_k +
batch_offset_k;
const KDataType* knew_ptr =
reinterpret_cast<const KDataType*>(kargs.knew_ptr) +
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_knew +
batch_offset_knew;
VDataType* v_ptr =
reinterpret_cast<VDataType*>(kargs.v_ptr) +
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
batch_offset_v;
const VDataType* vnew_ptr =
reinterpret_cast<const VDataType*>(kargs.vnew_ptr) +
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_vnew +
batch_offset_vnew;
// Q/K/V DRAM and DRAM window
const auto q_dram = [&]() {
const auto q_dram_naive = make_naive_tensor_view<address_space_enum::global>(
q_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.stride_q, 1),
number<FmhaPipeline::kAlignmentQ>{},
number<1>{});
return pad_tensor_view(
q_dram_naive,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}();
const auto make_k_dram = [&](KDataType* data, index_t height) {
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
data, // will update this pointer if using paged-kvcache
make_tuple(height, kargs.hdim_q),
make_tuple(kargs.stride_k, 1),
number<FmhaPipeline::kAlignmentK>{},
number<1>{});
return pad_tensor_view(
k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK, kPadHeadDimQ>{});
};
const auto k_dram = [&]() {
if constexpr(kIsPagedKV)
{
return make_k_dram(nullptr, kargs.page_block_size);
}
else
{
return make_k_dram(k_ptr, kargs.seqlen_k + kargs.seqlen_knew);
}
}();
const auto knew_dram = [&]() {
const auto knew_dram_naive = make_naive_tensor_view<address_space_enum::global>(
knew_ptr,
make_tuple(kargs.seqlen_knew, kargs.hdim_q),
make_tuple(kargs.stride_knew, 1),
number<FmhaPipeline::kAlignmentK>{},
number<1>{});
return pad_tensor_view(
knew_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK, kPadHeadDimQ>{});
}();
const auto make_v_dram = [&](VDataType* data, index_t length) {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
data, // will update this pointer if using paged-kvcache
make_tuple(length, kargs.hdim_v),
make_tuple(kargs.stride_v, 1),
number<FmhaPipeline::kAlignmentV>{},
number<1>{});
const auto v_dram_transposed =
transform_tensor_view(v_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_v),
make_pass_through_transform(length)),
make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return pad_tensor_view(
v_dram_transposed,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
sequence<kPadHeadDimV, kPadSeqLenK>{});
}
else
{
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
data, // will update this pointer if using paged-kvcache
make_tuple(kargs.hdim_v, length),
make_tuple(kargs.stride_v, 1),
number<FmhaPipeline::kAlignmentV>{},
number<1>{});
return pad_tensor_view(
v_dram_naive,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
sequence<kPadHeadDimV, kPadSeqLenK>{});
}
};
const auto v_dram = [&]() {
if constexpr(kIsPagedKV)
{
return make_v_dram(nullptr, kargs.page_block_size);
}
else
{
return make_v_dram(v_ptr, kargs.seqlen_k + kargs.seqlen_knew);
}
}();
const auto vnew_dram = [&]() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
const auto vnew_dram_naive = make_naive_tensor_view<address_space_enum::global>(
vnew_ptr,
make_tuple(kargs.seqlen_knew, kargs.hdim_v),
make_tuple(kargs.stride_vnew, 1),
number<FmhaPipeline::kAlignmentV>{},
number<1>{});
const auto vnew_dram_transposed = transform_tensor_view(
vnew_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_v),
make_pass_through_transform(kargs.seqlen_knew)),
make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return pad_tensor_view(
vnew_dram_transposed,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
sequence<kPadHeadDimV, kPadSeqLenK>{});
}
else
{
const auto vnew_dram_naive = make_naive_tensor_view<address_space_enum::global>(
vnew_ptr,
make_tuple(kargs.hdim_v, kargs.seqlen_knew),
make_tuple(kargs.stride_vnew, 1),
number<FmhaPipeline::kAlignmentV>{},
number<1>{});
return pad_tensor_view(
vnew_dram_naive,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
sequence<kPadHeadDimV, kPadSeqLenK>{});
}
}();
constexpr auto q_rotary_cos_sin_dram_window_lengths =
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0 / 2>{});
const auto q_rotary_cos_dram_window = [&]() {
if constexpr(kApplyRoPE)
{
const auto rotary_cos_dram_native =
make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const QDataType*>(kargs.rotary_cos_ptr) +
kargs.seqlen_k * (kargs.rotary_dim / 2),
make_tuple(kargs.seqlen_q, kargs.rotary_dim / 2),
make_tuple(kargs.has_mask * (kargs.rotary_dim / 2), 1),
number<8>{},
number<1>{});
const auto rotary_cos_dram = [&]() {
return pad_tensor_view(rotary_cos_dram_native,
q_rotary_cos_sin_dram_window_lengths,
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}();
return make_tile_window(
rotary_cos_dram, q_rotary_cos_sin_dram_window_lengths, {i_m0, 0});
}
else
{
return make_null_tile_window(q_rotary_cos_sin_dram_window_lengths);
}
}();
const auto q_rotary_sin_dram_window = [&]() {
if constexpr(kApplyRoPE)
{
const auto rotary_sin_dram_native =
make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const QDataType*>(kargs.rotary_sin_ptr) +
kargs.seqlen_k * (kargs.rotary_dim / 2),
make_tuple(kargs.seqlen_q, kargs.rotary_dim / 2),
make_tuple(kargs.has_mask * (kargs.rotary_dim / 2), 1),
number<8>{},
number<1>{});
const auto rotary_sin_dram = [&]() {
return pad_tensor_view(rotary_sin_dram_native,
q_rotary_cos_sin_dram_window_lengths,
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}();
return make_tile_window(
rotary_sin_dram, q_rotary_cos_sin_dram_window_lengths, {i_m0, 0});
}
else
{
return make_null_tile_window(q_rotary_cos_sin_dram_window_lengths);
}
}();
constexpr auto knew_rotary_cos_sin_dram_window_lengths =
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0 / 2>{});
const auto knew_rotary_cos_dram_window = [&]() {
if constexpr(kApplyRoPE)
{
const auto rotary_cos_dram_native =
make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const KDataType*>(kargs.rotary_cos_ptr) +
kargs.seqlen_k * (kargs.rotary_dim / 2),
make_tuple(kargs.seqlen_knew, kargs.rotary_dim / 2),
make_tuple(kargs.rotary_dim / 2, 1),
number<8>{},
number<1>{});
const auto rotary_cos_dram = [&]() {
return pad_tensor_view(rotary_cos_dram_native,
knew_rotary_cos_sin_dram_window_lengths,
sequence<kPadSeqLenK, kPadHeadDimQ>{});
}();
return make_tile_window(
rotary_cos_dram, knew_rotary_cos_sin_dram_window_lengths, {i_n0, 0});
}
else
{
return make_null_tile_window(knew_rotary_cos_sin_dram_window_lengths);
}
}();
const auto knew_rotary_sin_dram_window = [&]() {
if constexpr(kApplyRoPE)
{
const auto rotary_sin_dram_native =
make_naive_tensor_view<address_space_enum::global>(
reinterpret_cast<const KDataType*>(kargs.rotary_sin_ptr) +
kargs.seqlen_k * (kargs.rotary_dim / 2),
make_tuple(kargs.seqlen_knew, kargs.rotary_dim / 2),
make_tuple(kargs.rotary_dim / 2, 1),
number<8>{},
number<1>{});
const auto rotary_sin_dram = [&]() {
return pad_tensor_view(rotary_sin_dram_native,
knew_rotary_cos_sin_dram_window_lengths,
sequence<kPadSeqLenK, kPadHeadDimQ>{});
}();
return make_tile_window(
rotary_sin_dram, knew_rotary_cos_sin_dram_window_lengths, {i_n0, 0});
}
else
{
return make_null_tile_window(knew_rotary_cos_sin_dram_window_lengths);
}
}();
auto k_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
if constexpr(kIsPagedKV)
{
const auto* block_indices =
reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
i_batch_ * kargs.batch_stride_block_table;
const index_t num_blocks =
integer_divide_ceil(kargs.seqlen_k + kargs.seqlen_knew, kargs.page_block_size);
const long_index_t fixed_offset =
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
kargs.nhead_stride_k;
return make_page_block_navigator<KDataType, 0>(
kargs.k_ptr,
kargs.batch_stride_k,
fixed_offset,
block_indices,
num_blocks,
kargs.page_block_size,
k_dram,
make_k_dram(nullptr,
(kargs.seqlen_k + kargs.seqlen_knew) -
(num_blocks - 1) * kargs.page_block_size));
}
else
{
return make_page_block_navigator(k_dram);
}
}();
auto v_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
if constexpr(kIsPagedKV)
{
const auto* block_indices =
reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
i_batch_ * kargs.batch_stride_block_table;
const index_t num_blocks =
integer_divide_ceil(kargs.seqlen_k + kargs.seqlen_knew, kargs.page_block_size);
const long_index_t fixed_offset =
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
kargs.nhead_stride_v;
return make_page_block_navigator<VDataType, 1>(
kargs.v_ptr,
kargs.batch_stride_v,
fixed_offset,
block_indices,
num_blocks,
kargs.page_block_size,
v_dram,
make_v_dram(nullptr,
(kargs.seqlen_k + kargs.seqlen_knew) -
(num_blocks - 1) * kargs.page_block_size));
}
else
{
return make_page_block_navigator(v_dram);
}
}();
auto q_dram_window =
make_tile_window(q_dram,
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
{i_m0, 0});
const bool skip_append_kv = kargs.seqlen_knew <= i_n0;
// window origin = (0, 0) if no work to do for current block
auto [i_page_block_k, k_dram_window] = k_page_block_navigator.make_tile_window(
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
{!skip_append_kv * (kargs.seqlen_k + i_n0), 0});
auto knew_dram_window =
make_tile_window(knew_dram,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
{i_n0, 0});
// window origin = (0, 0) if no work to do for current block
auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
{0, !skip_append_kv * (kargs.seqlen_k + i_n0)});
auto vnew_dram_window =
make_tile_window(vnew_dram,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kN0>{}),
{0, i_n0});
if constexpr(kApplyRoPE)
{
FmhaPipeline{}(q_dram_window,
k_dram_window,
i_page_block_k,
k_page_block_navigator,
knew_dram_window,
v_dram_window,
i_page_block_v,
v_page_block_navigator,
vnew_dram_window,
q_rotary_cos_dram_window,
q_rotary_sin_dram_window,
knew_rotary_cos_dram_window,
knew_rotary_sin_dram_window,
kargs.rotary_dim,
kargs.seqlen_q <= i_m0,
skip_append_kv);
}
else
{
FmhaPipeline{}(q_dram_window,
k_dram_window,
i_page_block_k,
k_page_block_navigator,
knew_dram_window,
v_dram_window,
i_page_block_v,
v_page_block_navigator,
vnew_dram_window,
q_rotary_cos_dram_window,
q_rotary_sin_dram_window,
knew_rotary_cos_dram_window,
knew_rotary_sin_dram_window,
0, // rotary_dim not used
kargs.seqlen_q <= i_m0,
skip_append_kv);
}
}
};
} // namespace ck_tile
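The karg layout above leans on conditional inheritance with distinct EmptyKargs<I> bases; a self-contained sketch of just that pattern, with simplified placeholder fields rather than the kernel's real arguments:

#include <type_traits>

// Optional argument groups are mixed in via std::conditional_t; distinct
// EmptyKargs<I> bases keep two disabled branches from naming the same
// (duplicate) base class.
template <int I>
struct EmptyKargs
{
};

struct BasicKargs
{
    int seqlen_q;
};

struct RoPEKargs
{
    int rotary_dim;
};

struct PagedKVKargs
{
    int page_block_size;
};

template <bool kApplyRoPE, bool kIsPagedKV>
struct Kargs : BasicKargs,
               std::conditional_t<kApplyRoPE, RoPEKargs, EmptyKargs<0>>,
               std::conditional_t<kIsPagedKV, PagedKVKargs, EmptyKargs<1>>
{
};

// With the empty-base optimization (applied on mainstream ABIs), disabled
// groups add no size to the karg struct.
static_assert(sizeof(Kargs<false, false>) == sizeof(BasicKargs), "empty bases folded");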
// ck_tile/ops/fmha/kernel/fmha_fwd_appendkv_tile_partitioner.hpp (new file)
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <index_t kM0_, index_t kN0_, index_t kK0_, index_t kN1_>
struct FmhaFwdAppendKVTilePartitioner
{
static constexpr ck_tile::index_t kM0 = kM0_;
static constexpr ck_tile::index_t kN0 = kN0_;
static constexpr ck_tile::index_t kK0 = kK0_;
static constexpr ck_tile::index_t kN1 = kN1_;
static_assert(kK0 == kN1);
CK_TILE_HOST static constexpr auto GridSize(ck_tile::index_t batch_size,
ck_tile::index_t nhead,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_knew)
{
// TODO: this may need tuning
return dim3(std::max(ck_tile::integer_divide_ceil(seqlen_q, kM0),
ck_tile::integer_divide_ceil(seqlen_knew, kN0)),
nhead,
batch_size);
}
CK_TILE_DEVICE auto operator()()
{
const index_t i_tile = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_tile, i_nhead, i_batch);
}
};
} // namespace ck_tile
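A quick numeric check of GridSize() above, assuming kM0 = kN0 = 64: the x dimension must cover whichever of the Q tiles (RoPE on Q) or Knew tiles (KV-cache append) needs more blocks:

#include <algorithm>
#include <cassert>

// Worked example of the appendkv grid sizing (tile sizes are assumed values).
int main()
{
    const int kM0 = 64, kN0 = 64;
    const int seqlen_q = 300, seqlen_knew = 128;

    const int q_tiles    = (seqlen_q + kM0 - 1) / kM0;    // ceil(300 / 64) == 5
    const int knew_tiles = (seqlen_knew + kN0 - 1) / kN0; // ceil(128 / 64) == 2
    assert(std::max(q_tiles, knew_tiles) == 5);           // grid.x
    return 0;
}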
@@ -32,8 +32,6 @@ struct FmhaFwdSplitKVKernel
     using KDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::KDataType>;
     using VDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::VDataType>;
     using BiasDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::BiasDataType>;
-    using RandValOutputDataType =
-        ck_tile::remove_cvref_t<typename FmhaPipeline::RandValOutputDataType>;
     using LSEDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::LSEDataType>;
     using SaccDataType = ck_tile::remove_cvref_t<typename FmhaPipeline::SaccDataType>;
     using OaccDataType = remove_cvref_t<typename FmhaPipeline::OaccDataType>;
@@ -46,8 +44,10 @@ struct FmhaFwdSplitKVKernel
     static constexpr bool kPadHeadDimQ = FmhaPipeline::kPadHeadDimQ;
     static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
     static constexpr auto BiasEnum = FmhaPipeline::BiasEnum;
-    static constexpr bool kHasDropout = FmhaPipeline::kHasDropout;
     static constexpr bool kDoFp8StaticQuant = FmhaPipeline::Problem::kDoFp8StaticQuant;
+    static constexpr bool kIsPagedKV = FmhaPipeline::Problem::kIsPagedKV;
+    static_assert(!kIsGroupMode || (kIsGroupMode && !kIsPagedKV),
+                  "paged-kvcache only supported by batch mode kernels");

     using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
     static constexpr bool kHasMask = FmhaMask::IsMasking;
@@ -85,8 +85,8 @@ struct FmhaFwdSplitKVKernel
            "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
            (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
            "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
-           (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
+           (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kDoFp8StaticQuant ? "_squant" : "") + (kIsPagedKV ? "_pagedkv" : "" );
     #undef _SS_
     #undef _TS_
        // clang-format on
@@ -110,7 +110,6 @@ struct FmhaFwdSplitKVKernel
        void* o_acc_ptr;

        ck_tile::index_t batch;
-       ck_tile::index_t max_seqlen_q;
        ck_tile::index_t seqlen_q;
        ck_tile::index_t seqlen_k;
@@ -136,6 +135,7 @@ struct FmhaFwdSplitKVKernel
        ck_tile::index_t nhead_stride_lse_acc;
        ck_tile::index_t nhead_stride_o_acc;

+       ck_tile::index_t batch_stride_lse_acc;
        ck_tile::index_t batch_stride_o_acc;

        ck_tile::index_t split_stride_lse_acc;
@@ -173,32 +173,16 @@ struct FmhaFwdSplitKVKernel
        float scale_p;
    };

-   struct CommonDropoutKargs
+   struct PageBlockTableKargs
    {
-       void init_dropout(const float p_drop,
-                         const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
-       {
-           float p_undrop = 1.0 - p_drop;
-           p_undrop_in_uint8_t =
-               uint8_t(std::floor(p_undrop * std::numeric_limits<uint8_t>::max()));
-           rp_undrop = 1.0 / p_undrop;
-           drop_seed = std::get<0>(drop_seed_offset);
-           drop_offset = std::get<1>(drop_seed_offset);
-       }
-       float rp_undrop = 1;
-       uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
-       bool is_store_randval = false;
-       uint64_t drop_seed = 1;
-       uint64_t drop_offset = 0;
-       void* rand_val_ptr = nullptr;
-       ck_tile::index_t stride_randval = 0;
-       ck_tile::index_t nhead_stride_randval = 0;
+       const int32_t* block_table_ptr;
+       ck_tile::index_t batch_stride_block_table;
+       ck_tile::index_t page_block_size;
    };
-   struct BatchModeDropoutKargs : CommonDropoutKargs
+
+   struct CacheBatchIdxKargs
    {
-       ck_tile::index_t batch_stride_randval = 0;
+       const int32_t* cache_batch_idx;
    };

    struct BatchModeKargs
@@ -210,12 +194,13 @@ struct FmhaFwdSplitKVKernel
                               EmptyKargs<0>>>,
          std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
          std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
-         std::conditional_t<kHasDropout, BatchModeDropoutKargs, EmptyKargs<3>>
+         std::conditional_t<kIsPagedKV, PageBlockTableKargs, CacheBatchIdxKargs>
    {
+       const int32_t* seqlen_k_ptr;
        ck_tile::index_t batch_stride_q;
        ck_tile::index_t batch_stride_k;
        ck_tile::index_t batch_stride_v;
-       ck_tile::index_t batch_stride_lse_acc;
    };

    struct GroupModeKargs
@@ -226,12 +211,14 @@ struct FmhaFwdSplitKVKernel
                               AlibiKargs,
                               EmptyKargs<0>>>,
          std::conditional_t<kHasMask, MaskKargs, EmptyKargs<1>>,
-         std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>,
-         std::conditional_t<kHasDropout, CommonDropoutKargs, EmptyKargs<3>>
+         std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<2>>
    {
        const int32_t* seqstart_q_ptr;
        const int32_t* seqstart_k_ptr;
        const int32_t* seqlen_k_ptr;
+       ck_tile::index_t batch_stride_k;
+       ck_tile::index_t batch_stride_v;
    };

    using Kargs = std::conditional_t<kIsGroupMode, GroupModeKargs, BatchModeKargs>;
@@ -242,48 +229,45 @@ struct FmhaFwdSplitKVKernel
                                              const void* k_ptr,
                                              const void* v_ptr,
                                              const void* bias_ptr,
-                                             void* rand_val_ptr,
                                              void* lse_acc_ptr,
                                              void* o_acc_ptr,
                                              ck_tile::index_t batch,
-                                             ck_tile::index_t max_seqlen_q,
                                              ck_tile::index_t seqlen_q,
-                                             ck_tile::index_t seqlen_k,
+                                             ck_tile::index_t seqlen_k, // only used if 'seqlen_k_ptr' is not specified
+                                             const void* seqlen_k_ptr,  // only used for (paged-) kvcache
                                              ck_tile::index_t hdim_q,
                                              ck_tile::index_t hdim_v,
                                              ck_tile::index_t num_head_q,
                                              ck_tile::index_t nhead_ratio_qk,
                                              ck_tile::index_t num_splits,
+                                             const void* block_table_ptr,
+                                             ck_tile::index_t batch_stride_block_table,
+                                             ck_tile::index_t page_block_size,
+                                             const void* cache_batch_idx,
                                              float scale_s,
                                              float scale_p,
                                              ck_tile::index_t stride_q,
                                              ck_tile::index_t stride_k,
                                              ck_tile::index_t stride_v,
                                              ck_tile::index_t stride_bias,
-                                             ck_tile::index_t stride_randval,
                                              ck_tile::index_t stride_o_acc,
                                              ck_tile::index_t nhead_stride_q,
                                              ck_tile::index_t nhead_stride_k,
                                              ck_tile::index_t nhead_stride_v,
                                              ck_tile::index_t nhead_stride_bias,
-                                             ck_tile::index_t nhead_stride_randval,
                                              ck_tile::index_t nhead_stride_lse_acc,
                                              ck_tile::index_t nhead_stride_o_acc,
                                              ck_tile::index_t batch_stride_q,
                                              ck_tile::index_t batch_stride_k,
                                              ck_tile::index_t batch_stride_v,
                                              ck_tile::index_t batch_stride_bias,
-                                             ck_tile::index_t batch_stride_randval,
                                              ck_tile::index_t batch_stride_lse_acc,
                                              ck_tile::index_t batch_stride_o_acc,
                                              ck_tile::index_t split_stride_lse_acc,
                                              ck_tile::index_t split_stride_o_acc,
                                              ck_tile::index_t window_size_left,
                                              ck_tile::index_t window_size_right,
-                                             ck_tile::index_t mask_type,
-                                             float p_drop,
-                                             bool s_randval,
-                                             const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
+                                             ck_tile::index_t mask_type)
    {
        Kargs kargs{{q_ptr,
                     k_ptr,
@@ -291,7 +275,6 @@ struct FmhaFwdSplitKVKernel
                     lse_acc_ptr,
                     o_acc_ptr,
                     batch,
-                    max_seqlen_q,
                     seqlen_q,
                     seqlen_k,
                     hdim_q,
@@ -313,17 +296,18 @@ struct FmhaFwdSplitKVKernel
                     nhead_stride_v,
                     nhead_stride_lse_acc,
                     nhead_stride_o_acc,
+                    batch_stride_lse_acc,
                     batch_stride_o_acc,
                     split_stride_lse_acc,
                     split_stride_o_acc}, // args for common karg
                    {},                   // placeholder for bias
                    {},                   // placeholder for mask
                    {},                   // placeholder for fp8_static_quant args
-                   {},                   // placeholder for dropout
+                   {},                   // placeholder for paged-block table or cache_batch_idx
+                   reinterpret_cast<const int32_t*>(seqlen_k_ptr),
                    batch_stride_q,
                    batch_stride_k,
-                   batch_stride_v,
-                   batch_stride_lse_acc};
+                   batch_stride_v};

        if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
        {
@@ -347,14 +331,15 @@ struct FmhaFwdSplitKVKernel
        {
            kargs.scale_p = scale_p;
        }
-       if constexpr(kHasDropout)
+       if constexpr(kIsPagedKV)
        {
-           kargs.init_dropout(p_drop, drop_seed_offset);
-           kargs.rand_val_ptr = rand_val_ptr;
-           kargs.stride_randval = stride_randval;
-           kargs.nhead_stride_randval = nhead_stride_randval;
-           kargs.batch_stride_randval = batch_stride_randval;
-           kargs.is_store_randval = s_randval;
+           kargs.block_table_ptr = reinterpret_cast<const int32_t*>(block_table_ptr);
+           kargs.batch_stride_block_table = batch_stride_block_table;
+           kargs.page_block_size = page_block_size;
+       }
+       else
+       {
+           kargs.cache_batch_idx = reinterpret_cast<const int32_t*>(cache_batch_idx);
        }

        return kargs;
@@ -366,11 +351,9 @@ struct FmhaFwdSplitKVKernel
                                              const void* k_ptr,
                                              const void* v_ptr,
                                              const void* bias_ptr,
-                                             void* rand_val_ptr,
                                              void* lse_acc_ptr,
                                              void* o_acc_ptr,
                                              ck_tile::index_t batch,
-                                             ck_tile::index_t max_seqlen_q,
                                              const void* seqstart_q_ptr,
                                              const void* seqstart_k_ptr,
                                              const void* seqlen_k_ptr,
@@ -385,24 +368,22 @@ struct FmhaFwdSplitKVKernel
...
ck_tile::index_t stride_k, ck_tile::index_t stride_k,
ck_tile::index_t stride_v, ck_tile::index_t stride_v,
ck_tile::index_t stride_bias, ck_tile::index_t stride_bias,
ck_tile::index_t stride_randval,
ck_tile::index_t stride_o_acc, ck_tile::index_t stride_o_acc,
ck_tile::index_t nhead_stride_q, ck_tile::index_t nhead_stride_q,
ck_tile::index_t nhead_stride_k, ck_tile::index_t nhead_stride_k,
ck_tile::index_t nhead_stride_v, ck_tile::index_t nhead_stride_v,
ck_tile::index_t nhead_stride_bias, ck_tile::index_t nhead_stride_bias,
ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse_acc, ck_tile::index_t nhead_stride_lse_acc,
ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t nhead_stride_o_acc,
ck_tile::index_t batch_stride_k,
ck_tile::index_t batch_stride_v,
ck_tile::index_t batch_stride_lse_acc,
ck_tile::index_t batch_stride_o_acc, ck_tile::index_t batch_stride_o_acc,
ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_lse_acc,
ck_tile::index_t split_stride_o_acc, ck_tile::index_t split_stride_o_acc,
ck_tile::index_t window_size_left, ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right, ck_tile::index_t window_size_right,
ck_tile::index_t mask_type, ck_tile::index_t mask_type)
float p_drop,
bool s_randval,
const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
{ {
Kargs kargs{{q_ptr, Kargs kargs{{q_ptr,
k_ptr, k_ptr,
...@@ -410,9 +391,8 @@ struct FmhaFwdSplitKVKernel ...@@ -410,9 +391,8 @@ struct FmhaFwdSplitKVKernel
lse_acc_ptr, lse_acc_ptr,
o_acc_ptr, o_acc_ptr,
batch, batch,
max_seqlen_q, -1, // seqlen_q will be updated by another pointer
-1, // seqlen will be updated by another pointer -1, // seqlen_k will be updated by another pointer
-1, //
hdim_q, hdim_q,
hdim_v, hdim_v,
num_head_q, num_head_q,
...@@ -432,16 +412,18 @@ struct FmhaFwdSplitKVKernel ...@@ -432,16 +412,18 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v, nhead_stride_v,
nhead_stride_lse_acc, nhead_stride_lse_acc,
nhead_stride_o_acc, nhead_stride_o_acc,
batch_stride_lse_acc,
batch_stride_o_acc, batch_stride_o_acc,
split_stride_lse_acc, split_stride_lse_acc,
split_stride_o_acc}, // args for common karg split_stride_o_acc}, // args for common karg
{}, // placeholder for bias {}, // placeholder for bias
{}, // placeholder for mask {}, // placeholder for mask
{}, // placeholder for fp8_static_quant args {}, // placeholder for fp8_static_quant args
{}, // placeholder for dropout
reinterpret_cast<const int32_t*>(seqstart_q_ptr), reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr), reinterpret_cast<const int32_t*>(seqstart_k_ptr),
reinterpret_cast<const int32_t*>(seqlen_k_ptr)}; reinterpret_cast<const int32_t*>(seqlen_k_ptr),
batch_stride_k,
batch_stride_v};
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
...@@ -464,14 +446,6 @@ struct FmhaFwdSplitKVKernel ...@@ -464,14 +446,6 @@ struct FmhaFwdSplitKVKernel
{ {
kargs.scale_p = scale_p; kargs.scale_p = scale_p;
} }
if constexpr(kHasDropout)
{
kargs.init_dropout(p_drop, drop_seed_offset);
kargs.rand_val_ptr = rand_val_ptr;
kargs.stride_randval = stride_randval;
kargs.nhead_stride_randval = nhead_stride_randval;
kargs.is_store_randval = s_randval;
}
return kargs; return kargs;
} }
...@@ -508,7 +482,6 @@ struct FmhaFwdSplitKVKernel ...@@ -508,7 +482,6 @@ struct FmhaFwdSplitKVKernel
long_index_t batch_offset_k = 0; long_index_t batch_offset_k = 0;
long_index_t batch_offset_v = 0; long_index_t batch_offset_v = 0;
long_index_t batch_offset_bias = 0; long_index_t batch_offset_bias = 0;
long_index_t batch_offset_randval = 0;
long_index_t batch_offset_lse_acc = 0; long_index_t batch_offset_lse_acc = 0;
const long_index_t batch_offset_o_acc = const long_index_t batch_offset_o_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc; static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
...@@ -534,14 +507,9 @@ struct FmhaFwdSplitKVKernel ...@@ -534,14 +507,9 @@ struct FmhaFwdSplitKVKernel
{ {
batch_offset_bias = query_start * kargs.stride_bias + key_start; batch_offset_bias = query_start * kargs.stride_bias + key_start;
} }
if constexpr(kHasDropout)
{
batch_offset_randval = query_start * kargs.stride_randval;
}
// get real # queries & # keys under group mode // get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; kargs.seqlen_q = kargs.seqstart_q_ptr[i_batch + 1] - kargs.seqstart_q_ptr[i_batch];
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
// # of required blocks is different in each groups, terminate unnecessary blocks // # of required blocks is different in each groups, terminate unnecessary blocks
// earlier // earlier
...@@ -556,24 +524,36 @@ struct FmhaFwdSplitKVKernel ...@@ -556,24 +524,36 @@ struct FmhaFwdSplitKVKernel
} }
else else
{ {
const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch; kargs.seqlen_k = kargs.seqstart_k_ptr[i_batch + 1] - kargs.seqstart_k_ptr[i_batch];
kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
} }
} }
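            // e.g. with seqstart_q = {0, 5, 12} and seqstart_k = {0, 30, 75}, batch 1
            // covers queries [5, 12) and keys [30, 75): seqlen_q = 7, seqlen_k = 45
            // (unless seqlen_k_ptr overrides the latter)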
else else
{ {
const index_t i_cache_batch = [&, i_batch_ = i_batch] {
if constexpr(kIsPagedKV)
{
return i_batch_;
}
else
{
return (kargs.cache_batch_idx != nullptr ? kargs.cache_batch_idx[i_batch_]
: i_batch_);
}
}();
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q; batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k; batch_offset_k = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v; batch_offset_v = static_cast<long_index_t>(i_cache_batch) * kargs.batch_stride_v;
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc; batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias; batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
} }
if constexpr(kHasDropout)
if(kargs.seqlen_k_ptr != nullptr)
{ {
batch_offset_randval = kargs.seqlen_k = kargs.seqlen_k_ptr[i_batch];
static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
} }
} }
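            // e.g. with cache_batch_idx = {7, 2}, batch 0 reads its K/V from cache slot 7
            // and batch 1 from slot 2, while Q/LSE/O still use the plain batch index;
            // under paged KV the block table already provides this indirection, so
            // i_cache_batch degenerates to i_batch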
...@@ -589,6 +569,7 @@ struct FmhaFwdSplitKVKernel ...@@ -589,6 +569,7 @@ struct FmhaFwdSplitKVKernel
reinterpret_cast<const VDataType*>(kargs.v_ptr) + reinterpret_cast<const VDataType*>(kargs.v_ptr) +
static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v + static_cast<long_index_t>(i_nhead / kargs.nhead_ratio_qk) * kargs.nhead_stride_v +
batch_offset_v; batch_offset_v;
OaccDataType* o_acc_ptr = reinterpret_cast<OaccDataType*>(kargs.o_acc_ptr) + OaccDataType* o_acc_ptr = reinterpret_cast<OaccDataType*>(kargs.o_acc_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o_acc + static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_o_acc +
batch_offset_o_acc + i_split * kargs.split_stride_o_acc; batch_offset_o_acc + i_split * kargs.split_stride_o_acc;
...@@ -616,10 +597,11 @@ struct FmhaFwdSplitKVKernel ...@@ -616,10 +597,11 @@ struct FmhaFwdSplitKVKernel
sequence<kPadSeqLenQ, kPadHeadDimQ>{}); sequence<kPadSeqLenQ, kPadHeadDimQ>{});
} }
}(); }();
const auto k_dram = [&]() {
const auto make_k_dram = [&](const KDataType* data, index_t height) {
const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>( const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
k_ptr, data, // will update this pointer if using paged-kvcache
make_tuple(kargs.seqlen_k, kargs.hdim_q), make_tuple(height, kargs.hdim_q),
make_tuple(kargs.stride_k, 1), make_tuple(kargs.stride_k, 1),
number<FmhaPipeline::kAlignmentK>{}, number<FmhaPipeline::kAlignmentK>{},
number<1>{}); number<1>{});
...@@ -628,13 +610,24 @@ struct FmhaFwdSplitKVKernel ...@@ -628,13 +610,24 @@ struct FmhaFwdSplitKVKernel
k_dram_naive, k_dram_naive,
make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}), make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
sequence<kPadSeqLenK, kPadHeadDimQ>{}); sequence<kPadSeqLenK, kPadHeadDimQ>{});
};
const auto k_dram = [&]() {
if constexpr(kIsPagedKV)
{
return make_k_dram(nullptr, kargs.page_block_size);
}
else
{
return make_k_dram(k_ptr, kargs.seqlen_k);
}
}(); }();
const auto v_dram = [&]() {
const auto make_v_dram = [&](const VDataType* data, index_t length) {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{ {
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>( const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
v_ptr, data, // will update this pointer if using paged-kvcache
make_tuple(kargs.seqlen_k, kargs.hdim_v), make_tuple(length, kargs.hdim_v),
make_tuple(kargs.stride_v, 1), make_tuple(kargs.stride_v, 1),
number<FmhaPipeline::kAlignmentV>{}, number<FmhaPipeline::kAlignmentV>{},
number<1>{}); number<1>{});
...@@ -642,7 +635,7 @@ struct FmhaFwdSplitKVKernel ...@@ -642,7 +635,7 @@ struct FmhaFwdSplitKVKernel
const auto v_dram_transposed = const auto v_dram_transposed =
transform_tensor_view(v_dram_naive, transform_tensor_view(v_dram_naive,
make_tuple(make_pass_through_transform(kargs.hdim_v), make_tuple(make_pass_through_transform(kargs.hdim_v),
make_pass_through_transform(kargs.seqlen_k)), make_pass_through_transform(length)),
make_tuple(sequence<1>{}, sequence<0>{}), make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{})); make_tuple(sequence<0>{}, sequence<1>{}));
...@@ -654,8 +647,8 @@ struct FmhaFwdSplitKVKernel ...@@ -654,8 +647,8 @@ struct FmhaFwdSplitKVKernel
else else
{ {
const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>( const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
v_ptr, data, // will update this pointer if using paged-kvcache
make_tuple(kargs.hdim_v, kargs.seqlen_k), make_tuple(kargs.hdim_v, length),
make_tuple(kargs.stride_v, 1), make_tuple(kargs.stride_v, 1),
number<FmhaPipeline::kAlignmentV>{}, number<FmhaPipeline::kAlignmentV>{},
number<1>{}); number<1>{});
...@@ -665,6 +658,76 @@ struct FmhaFwdSplitKVKernel ...@@ -665,6 +658,76 @@ struct FmhaFwdSplitKVKernel
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}), make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
sequence<kPadHeadDimV, kPadSeqLenK>{}); sequence<kPadHeadDimV, kPadSeqLenK>{});
} }
};
const auto v_dram = [&]() {
if constexpr(kIsPagedKV)
{
return make_v_dram(nullptr, kargs.page_block_size);
}
else
{
return make_v_dram(v_ptr, kargs.seqlen_k);
}
}();
auto k_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
if constexpr(kIsPagedKV)
{
const auto* block_indices =
reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
i_batch_ * kargs.batch_stride_block_table;
const index_t num_blocks =
integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size);
const long_index_t fixed_offset =
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
kargs.nhead_stride_k;
return make_page_block_navigator<const KDataType, 0>(
kargs.k_ptr,
kargs.batch_stride_k,
fixed_offset,
block_indices,
num_blocks,
kargs.page_block_size,
k_dram,
make_k_dram(nullptr,
kargs.seqlen_k - (num_blocks - 1) * kargs.page_block_size));
}
else
{
return make_page_block_navigator(k_dram);
}
}();
auto v_page_block_navigator = [&, i_batch_ = i_batch, i_nhead_ = i_nhead]() {
if constexpr(kIsPagedKV)
{
const auto* block_indices =
reinterpret_cast<const int32_t*>(kargs.block_table_ptr) +
i_batch_ * kargs.batch_stride_block_table;
const index_t num_blocks =
integer_divide_ceil(kargs.seqlen_k, kargs.page_block_size);
const long_index_t fixed_offset =
static_cast<long_index_t>(i_nhead_ / kargs.nhead_ratio_qk) *
kargs.nhead_stride_v;
return make_page_block_navigator<const VDataType, 1>(
kargs.v_ptr,
kargs.batch_stride_v,
fixed_offset,
block_indices,
num_blocks,
kargs.page_block_size,
v_dram,
make_v_dram(nullptr,
kargs.seqlen_k - (num_blocks - 1) * kargs.page_block_size));
}
else
{
return make_page_block_navigator(v_dram);
}
}(); }();
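        // How a navigator resolves a logical K/V row into a physical address under
        // paged KV -- an illustrative sketch only, not the actual implementation:
        //
        //   index_t      i_block  = row / page_block_size;   // logical page index
        //   index_t      i_in_blk = row % page_block_size;   // row within the page
        //   index_t      physical = block_indices[i_block];  // physical page from block table
        //   long_index_t offset   = physical * batch_stride  // jump to the physical page
        //                         + fixed_offset             // per-head offset
        //                         + i_in_blk * row_stride;   // row inside the page
        //
        // The extra k_dram/v_dram view created with height/length
        // 'kargs.seqlen_k - (num_blocks - 1) * kargs.page_block_size' presumably serves
        // the last, possibly partial, page so that accesses stay bounded by the real
        // sequence length rather than by page_block_size.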
auto q_dram_window = make_tile_window( auto q_dram_window = make_tile_window(
...@@ -678,13 +741,11 @@ struct FmhaFwdSplitKVKernel ...@@ -678,13 +741,11 @@ struct FmhaFwdSplitKVKernel
}(), }(),
{i_m0, 0}); {i_m0, 0});
auto k_dram_window = make_tile_window( auto k_dram_window_lengths =
k_dram, make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}), {0, 0}); make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{});
auto v_dram_window_lengths =
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{});
auto v_dram_window =
make_tile_window(v_dram,
make_tuple(number<FmhaPipeline::kN1>{}, number<FmhaPipeline::kK1>{}),
{i_n1, 0});
/// FIXME: Before C++20, capturing structured binding variables are not supported. Remove /// FIXME: Before C++20, capturing structured binding variables are not supported. Remove
/// following copy capture of the 'i_nhead' if in C++20 /// following copy capture of the 'i_nhead' if in C++20
const auto bias_dram_window = [&, i_nhead_ = i_nhead]() { const auto bias_dram_window = [&, i_nhead_ = i_nhead]() {
...@@ -741,62 +802,6 @@ struct FmhaFwdSplitKVKernel ...@@ -741,62 +802,6 @@ struct FmhaFwdSplitKVKernel
return make_tile_window(lse_acc_dram, lse_acc_dram_window_lengths, {i_m0}); return make_tile_window(lse_acc_dram, lse_acc_dram_window_lengths, {i_m0});
}(); }();
// dropout
float rp_undrop = 1;
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
uint64_t drop_seed = 0;
uint64_t drop_offset = 0;
bool is_store_randval = false;
if constexpr(kHasDropout)
{
rp_undrop = kargs.rp_undrop;
p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t;
drop_seed = kargs.drop_seed;
drop_offset = kargs.drop_offset;
is_store_randval = kargs.is_store_randval;
}
BlockDropout dropout(i_batch,
i_nhead,
kargs.num_head_q,
drop_seed,
drop_offset,
rp_undrop,
p_undrop_in_uint8_t,
is_store_randval);
auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
constexpr auto randval_dram_window_lengths =
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
if constexpr(kHasDropout)
{
RandValOutputDataType* rand_val_ptr =
reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_randval +
batch_offset_randval;
const auto randval_dram = [&]() {
const auto randval_dram_naive =
make_naive_tensor_view<address_space_enum::global>(
rand_val_ptr,
make_tuple(kargs.seqlen_q, kargs.seqlen_k),
make_tuple(kargs.stride_randval, 1),
number<1>{},
number<1>{});
return pad_tensor_view(randval_dram_naive,
randval_dram_window_lengths,
sequence<kPadSeqLenQ, kPadSeqLenK>{});
}();
return make_tile_window(randval_dram, randval_dram_window_lengths, {i_m0, 0});
}
else
{
return make_null_tile_window(randval_dram_window_lengths);
}
}();
FmhaMask mask = [&]() { FmhaMask mask = [&]() {
if constexpr(kHasMask) if constexpr(kHasMask)
return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>( return ck_tile::make_generic_attention_mask_from_lr_window<FmhaMask>(
...@@ -823,16 +828,16 @@ struct FmhaFwdSplitKVKernel ...@@ -823,16 +828,16 @@ struct FmhaFwdSplitKVKernel
#endif #endif
if constexpr(kHasMask) if constexpr(kHasMask)
{ {
return make_alibi_from_lr_mask<SaccDataType, true>(slope, return make_alibi_from_lr_mask<SaccDataType, true, 32>(slope,
kargs.window_size_left, kargs.window_size_left,
kargs.window_size_right, kargs.window_size_right,
kargs.seqlen_q, kargs.seqlen_q,
kargs.seqlen_k, kargs.seqlen_k,
kargs.mask_type); kargs.mask_type);
} }
else else
{ {
return Alibi<SaccDataType, true>{ return Alibi<SaccDataType, true, 32>{
slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT}; slope, kargs.seqlen_q, kargs.seqlen_k, AlibiMode::FROM_BOTTOM_RIGHT};
} }
} }
...@@ -847,13 +852,14 @@ struct FmhaFwdSplitKVKernel ...@@ -847,13 +852,14 @@ struct FmhaFwdSplitKVKernel
{ {
return FmhaPipeline{}(q_dram_window, return FmhaPipeline{}(q_dram_window,
identity{}, // q_element_func identity{}, // q_element_func
k_dram_window, k_dram_window_lengths,
k_page_block_navigator,
identity{}, // k_element_func identity{}, // k_element_func
v_dram_window, v_dram_window_lengths,
v_page_block_navigator,
identity{}, // v_element_func identity{}, // v_element_func
bias_dram_window, bias_dram_window,
identity{}, // bias_element_func identity{}, // bias_element_func
randval_dram_window,
lse_acc_dram_window, lse_acc_dram_window,
identity{}, // lse_element_func identity{}, // lse_element_func
identity{}, // s_acc_element_func identity{}, // s_acc_element_func
...@@ -864,24 +870,23 @@ struct FmhaFwdSplitKVKernel ...@@ -864,24 +870,23 @@ struct FmhaFwdSplitKVKernel
mask, mask,
position_encoding, position_encoding,
kargs.scale_s, kargs.scale_s,
smem_ptr, smem_ptr);
dropout);
} }
else else
{ {
return FmhaPipeline{}(q_dram_window, return FmhaPipeline{}(q_dram_window,
k_dram_window, k_dram_window_lengths,
v_dram_window, k_page_block_navigator,
v_dram_window_lengths,
v_page_block_navigator,
bias_dram_window, bias_dram_window,
randval_dram_window,
lse_acc_dram_window, lse_acc_dram_window,
kargs.num_splits, kargs.num_splits,
i_split_, i_split_,
mask, mask,
position_encoding, position_encoding,
kargs.scale_s, kargs.scale_s,
smem_ptr, smem_ptr);
dropout);
} }
}(); }();
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_appendkv_pipeline_default_policy.hpp"
namespace ck_tile {
template <typename Problem_, typename Policy_ = BlockFmhaFwdAppendKVPipelineDefaultPolicy>
struct BlockFmhaFwdAppendKVPipeline
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using QDataType = typename Problem::QDataType;
using KDataType = typename Problem::KDataType;
using VDataType = typename Problem::VDataType;
using VLayout = typename Problem::VLayout;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = Problem::kM0;
static constexpr index_t kN0 = Problem::kN0;
static constexpr index_t kK0 = Problem::kK0;
static constexpr index_t kN1 = Problem::kN1;
static constexpr auto RotaryEnum = Problem::RotaryEnum;
static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
    // last dimension vector length used to create the tensor view (and to decide the
    // buffer_load vector length), together with the tensor distribution; the tensor
    // distribution should be able to override this
static constexpr index_t kAlignmentQ =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV = []() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
else
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}();
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)
return Problem::kBlockPerCu;
else
{
if constexpr(kK0 <= 32)
{
return 2;
}
else if constexpr(kK0 <= 64)
{
return 3;
}
else if constexpr(kK0 <= 128)
{
return 2;
}
            else if constexpr(kK0 <= 256)
            {
                return 1;
            }
            else
            {
                // fallback so every instantiation has a return value (most conservative
                // occupancy for unexpectedly large kK0)
                return 1;
            }
}
}();
template <typename QDramBlockWindow,
typename KDramBlockWindow,
typename KPageBlockNavigator,
typename KnewDramBlockWindow,
typename VDramBlockWindow,
typename VPageBlockNavigator,
typename VnewDramBlockWindow,
typename QElementFunction,
typename KnewElementFunction,
typename VnewElementFunction,
typename QRotaryCosDramBlockWindow,
typename QRotarySinDramBlockWindow,
typename KnewRotaryCosDramBlockWindow,
typename KnewRotarySinDramBlockWindow>
CK_TILE_HOST_DEVICE auto
operator()(QDramBlockWindow& q_dram_block_window, // M0*K0 tile
const QElementFunction& q_element_func,
KDramBlockWindow& k_dram_block_window, // N0*K0 tile
index_t i_page_block_k,
const KPageBlockNavigator& k_page_block_navigator,
const KnewDramBlockWindow& knew_dram_block_window, // N0*K0 tile
const KnewElementFunction& knew_element_func,
VDramBlockWindow& v_dram_block_window, // N1*N0 tile
index_t i_page_block_v,
const VPageBlockNavigator& v_page_block_navigator,
const VnewDramBlockWindow& vnew_dram_block_window, // N1*N0 tile
const VnewElementFunction& vnew_element_func,
const QRotaryCosDramBlockWindow q_rotary_cos_dram_block_window,
const QRotarySinDramBlockWindow q_rotary_sin_dram_block_window,
const KnewRotaryCosDramBlockWindow knew_rotary_cos_dram_block_window,
const KnewRotarySinDramBlockWindow knew_rotary_sin_dram_block_window,
index_t rotary_dim,
bool skip_rotate_q,
bool skip_rotate_append_kv) const
{
if(!skip_rotate_append_kv)
{
// append Knew to K
auto knew_window = make_tile_window(
knew_dram_block_window, Policy::template MakeKnewDramTileDistribution<Problem>());
auto knew_tile = [&]() {
auto knew = load_tile(knew_window);
return tile_elementwise_in(knew_element_func, knew);
}();
// optionally apply rotary embedding to Knew
if constexpr(RotaryEnum != RotaryEmbeddingEnum::NONE)
{
auto rotary_cos_window =
make_tile_window(knew_rotary_cos_dram_block_window,
Policy::template MakeRotaryCosSinTileDistribution<
Problem,
/*IsRotaryCosSinForQ=*/false>());
auto rotary_sin_window =
make_tile_window(knew_rotary_sin_dram_block_window,
Policy::template MakeRotaryCosSinTileDistribution<
Problem,
/*IsRotaryCosSinForQ=*/false>());
                // We assume that each thread owns contiguous elements along the head
                // dimension, and we use the distribution to enable/disable threads in
                // order to overwrite partial knew_tile content
auto [thread_start, thread_end] =
Policy::template GetKnewThreadRangeAlongK<Problem>();
ignore = thread_start;
BlockRotaryEmbedding<RotaryEnum>::apply(knew_tile,
knew_window,
rotary_cos_window,
rotary_sin_window,
rotary_dim,
thread_end);
}
store_tile(k_dram_block_window, knew_tile);
            // write the tile to the next block if necessary
if constexpr(kIsPagedKV)
{
if(k_page_block_navigator.is_cross_block(i_page_block_k, k_dram_block_window))
{
k_page_block_navigator.move_to_block(
i_page_block_k, k_dram_block_window, i_page_block_k + 1);
store_tile(k_dram_block_window, knew_tile);
}
}
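            // A kN0-row tile may straddle a page boundary: the first store_tile() above
            // covers the rows inside the current page block, and when is_cross_block()
            // reports a crossing the window is advanced to the next block and the tile
            // is stored again so the remaining rows land in the right page (out-of-range
            // rows are presumably clipped by the window bounds)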
// append Vnew to V
auto vnew_window = make_tile_window(
vnew_dram_block_window, Policy::template MakeVnewDramTileDistribution<Problem>());
auto vnew_tile = [&]() {
auto vnew = load_tile(vnew_window);
return tile_elementwise_in(vnew_element_func, vnew);
}();
store_tile(v_dram_block_window, vnew_tile);
            // write the tile to the next block if necessary
if constexpr(kIsPagedKV)
{
if(v_page_block_navigator.is_cross_block(i_page_block_v, v_dram_block_window))
{
v_page_block_navigator.move_to_block(
i_page_block_v, v_dram_block_window, i_page_block_v + 1);
store_tile(v_dram_block_window, vnew_tile);
}
}
}
if(!skip_rotate_q)
{
// optionally apply rotary embedding to Q
if constexpr(RotaryEnum != RotaryEmbeddingEnum::NONE)
{
auto q_window = make_tile_window(
q_dram_block_window, Policy::template MakeQDramTileDistribution<Problem>());
auto q_tile = [&]() {
auto q = load_tile(q_window);
return tile_elementwise_in(q_element_func, q);
}();
auto rotary_cos_window =
make_tile_window(q_rotary_cos_dram_block_window,
Policy::template MakeRotaryCosSinTileDistribution<
Problem,
/*IsRotaryCosSinForQ=*/true>());
auto rotary_sin_window =
make_tile_window(q_rotary_sin_dram_block_window,
Policy::template MakeRotaryCosSinTileDistribution<
Problem,
/*IsRotaryCosSinForQ=*/true>());
                // We assume that each thread owns contiguous elements along the head
                // dimension, and we use the distribution to enable/disable threads in
                // order to overwrite partial q_tile content
auto [thread_start, thread_end] = Policy::template GetQThreadRangeAlongK<Problem>();
ignore = thread_start;
BlockRotaryEmbedding<RotaryEnum>::apply(
q_tile, q_window, rotary_cos_window, rotary_sin_window, rotary_dim, thread_end);
store_tile(q_dram_block_window, q_tile);
}
}
}
template <typename QDramBlockWindow,
typename KDramBlockWindow,
typename KPageBlockNavigator,
typename KnewDramBlockWindow,
typename VDramBlockWindow,
typename VPageBlockNavigator,
typename VnewDramBlockWindow,
typename QRotaryCosDramBlockWindow,
typename QRotarySinDramBlockWindow,
typename KnewRotaryCosDramBlockWindow,
typename KnewRotarySinDramBlockWindow>
CK_TILE_HOST_DEVICE auto
operator()(QDramBlockWindow& q_dram_block_window,
KDramBlockWindow& k_dram_block_window,
index_t i_page_block_k,
const KPageBlockNavigator& k_page_block_navigator,
const KnewDramBlockWindow& knew_dram_block_window,
VDramBlockWindow& v_dram_block_window,
index_t i_page_block_v,
const VPageBlockNavigator& v_page_block_navigator,
const VnewDramBlockWindow& vnew_dram_block_window,
const QRotaryCosDramBlockWindow& q_rotary_cos_dram_block_window,
const QRotarySinDramBlockWindow& q_rotary_sin_dram_block_window,
const KnewRotaryCosDramBlockWindow& knew_rotary_cos_dram_block_window,
const KnewRotarySinDramBlockWindow& knew_rotary_sin_dram_block_window,
index_t rotary_dim,
bool skip_rotate_q,
bool skip_rotate_append_kv) const
{
return operator()(q_dram_block_window,
identity{},
k_dram_block_window,
i_page_block_k,
k_page_block_navigator,
knew_dram_block_window,
identity{},
v_dram_block_window,
i_page_block_v,
v_page_block_navigator,
vnew_dram_block_window,
identity{},
q_rotary_cos_dram_block_window,
q_rotary_sin_dram_block_window,
knew_rotary_cos_dram_block_window,
knew_rotary_sin_dram_block_window,
rotary_dim,
skip_rotate_q,
skip_rotate_append_kv);
}
};
} // namespace ck_tile
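// Scalar reference for the two rotary modes dispatched via BlockRotaryEmbedding above.
// A minimal sketch under the common GPT-J (interleaved) / GPT-NeoX (half-rotated)
// conventions, kept inactive on purpose; function and parameter names are illustrative
// and do not mirror the device code, which applies the same math per thread-owned
// tile slice.
#if 0
#include <cstddef>

// interleaved: rotate adjacent pairs (x[2i], x[2i+1]) by angle i
inline void rope_interleaved_ref(float* x, const float* cos, const float* sin,
                                 std::size_t rotary_dim)
{
    for(std::size_t i = 0; i < rotary_dim / 2; ++i)
    {
        const float x0 = x[2 * i];
        const float x1 = x[2 * i + 1];
        x[2 * i]     = x0 * cos[i] - x1 * sin[i];
        x[2 * i + 1] = x0 * sin[i] + x1 * cos[i];
    }
}

// half-rotated: pair element i with element i + rotary_dim / 2
inline void rope_half_rotated_ref(float* x, const float* cos, const float* sin,
                                  std::size_t rotary_dim)
{
    const std::size_t half = rotary_dim / 2;
    for(std::size_t i = 0; i < half; ++i)
    {
        const float x0 = x[i];
        const float x1 = x[i + half];
        x[i]        = x0 * cos[i] - x1 * sin[i];
        x[i + half] = x0 * sin[i] + x1 * cos[i];
    }
}
#endif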
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
// Default policy for BlockFmhaFwdAppendKVPipeline
struct BlockFmhaFwdAppendKVPipelineDefaultPolicy
{
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
{
using QDataType = remove_cvref_t<typename Problem::QDataType>;
return 16 / sizeof(QDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK()
{
using KDataType = remove_cvref_t<typename Problem::KDataType>;
return 16 / sizeof(KDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
{
using VLayout = remove_cvref_t<typename Problem::VLayout>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::kN0;
constexpr index_t kKPerBlock = Problem::kN1;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
// TODO: not correct!
if constexpr(total_pixels > 4)
return 4;
else
return 2;
}
else
{
return 16 / sizeof(VDataType);
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQNumElemsPerRead()
{
using DataType = typename Problem::QDataType;
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
{
            /// NOTICE: we might need to lower this value to support smaller rotary_dim
return 16 / sizeof(DataType);
}
else
{
return 16 / sizeof(DataType);
}
}
    template <typename Problem>
    CK_TILE_DEVICE static auto GetQThreadRangeAlongK()
    {
        static_assert(Problem::RotaryEnum != RotaryEmbeddingEnum::NONE);

        // both rotary approaches currently read the same number of elements per thread,
        // so the interleaved and half-rotated branches collapse into one computation
        constexpr index_t KPerThread = GetQNumElemsPerRead<Problem>();
        static_assert(Problem::kK0 % KPerThread == 0);
        constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread;
        index_t start_pos = (get_thread_id() % KThreadPerBlock) * KPerThread;

        return make_tuple(start_pos, start_pos + KPerThread);
    }
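    // e.g. kK0 = 128 with 16-bit Q gives KPerThread = 8 and KThreadPerBlock = 16, so
    // thread t covers columns [(t % 16) * 8, (t % 16) * 8 + 8) of the head dimension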
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kKPerBlock = Problem::kK0;
constexpr index_t KPerThread = GetQNumElemsPerRead<Problem>();
constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread;
constexpr index_t MThreadPerWarp = get_warp_size() / KThreadPerBlock;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t MPerThread = kMPerBlock / (NumWarps * MThreadPerWarp);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<MPerThread, NumWarps, MThreadPerWarp>,
sequence<KThreadPerBlock, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
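    // Worked example with illustrative numbers: kBlockSize = 256, kM0 = 64, kK0 = 128,
    // fp16 Q on a 64-lane wavefront => KPerThread = 8, KThreadPerBlock = 16,
    // MThreadPerWarp = 4, NumWarps = 4 and MPerThread = 64 / (4 * 4) = 4, i.e. each
    // thread reads 4 rows of 8 contiguous Q elements from the M0 x K0 tile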
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetKnewNumElemsPerRead()
{
using DataType = typename Problem::KDataType;
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
{
            /// NOTICE: we might need to lower this value to support smaller rotary_dim
return 16 / sizeof(DataType);
}
else
{
return 16 / sizeof(DataType);
}
}
    template <typename Problem>
    CK_TILE_DEVICE static auto GetKnewThreadRangeAlongK()
    {
        static_assert(Problem::RotaryEnum != RotaryEmbeddingEnum::NONE);

        // identical for both rotary approaches; see GetQThreadRangeAlongK()
        constexpr index_t KPerThread = GetKnewNumElemsPerRead<Problem>();
        constexpr index_t KThreadPerBlock = Problem::kK0 / KPerThread;
        index_t start_pos = (get_thread_id() % KThreadPerBlock) * KPerThread;

        return make_tuple(start_pos, start_pos + KPerThread);
    }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKnewDramTileDistribution()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::kN0;
constexpr index_t kKPerBlock = Problem::kK0;
constexpr index_t KPerThread = GetKnewNumElemsPerRead<Problem>();
constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread;
constexpr index_t NThreadPerWarp = get_warp_size() / KThreadPerBlock;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t NPerThread = kNPerBlock / (NumWarps * NThreadPerWarp);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NPerThread, NumWarps, NThreadPerWarp>,
sequence<KThreadPerBlock, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV()
{
// TODO: this is for 3d layout
using VDataType = remove_cvref_t<typename Problem::VDataType>;
return 16 / sizeof(VDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeVnewDramTileDistribution()
{
using VLayout = remove_cvref_t<typename Problem::VLayout>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::kN1;
constexpr index_t kKPerBlock = Problem::kN0;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t NPerThread = 16 / sizeof(VDataType);
constexpr index_t NThreadPerBlock = kNPerBlock / NPerThread;
constexpr index_t KThreadPerWarp = get_warp_size() / NThreadPerBlock;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t KPerThread = kKPerBlock / (NumWarps * KThreadPerWarp);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NThreadPerBlock, NPerThread>,
sequence<KPerThread, NumWarps, KThreadPerWarp>>,
tuple<sequence<2>, sequence<1, 2>>,
tuple<sequence<1>, sequence<0, 2>>,
sequence<1, 2>,
sequence<1, 0>>{});
}
else
{
constexpr index_t KPerThread = 16 / sizeof(VDataType);
constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread;
constexpr index_t NThreadPerWarp = get_warp_size() / KThreadPerBlock;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t NPerThread = kNPerBlock / (NumWarps * NThreadPerWarp);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NPerThread, NumWarps, NThreadPerWarp>,
sequence<KThreadPerBlock, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
}
template <typename Problem, bool IsRotaryCosSinForQ>
CK_TILE_HOST_DEVICE static constexpr auto GetRotaryCosSinTileSize()
{
constexpr index_t height = (IsRotaryCosSinForQ ? Problem::kM0 : Problem::kN0);
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
{
return make_tuple(number<height>{}, number<Problem::kK0>{});
}
else
{
return make_tuple(number<height>{}, number<Problem::kK0 / 2>{});
}
}
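    // NONE/INTERLEAVED use a kK0 / 2 wide cos/sin tile (one entry per rotated pair),
    // while HALF_ROTATED keeps the full kK0 width, presumably so each thread can read
    // cos/sin at its own column offset without remapping into a half-width table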
template <typename Problem, bool IsRotaryCosSinForQ>
CK_TILE_HOST_DEVICE static constexpr auto MakeRotaryCosSinTileDistribution()
{
using DataType = std::conditional_t<IsRotaryCosSinForQ,
typename Problem::QDataType,
typename Problem::KDataType>;
constexpr auto TileSize = GetRotaryCosSinTileSize<Problem, IsRotaryCosSinForQ>();
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = TileSize[number<0>{}];
constexpr index_t kKPerBlock = TileSize[number<1>{}];
constexpr index_t KPerThread = []() {
if constexpr(Problem::RotaryEnum == RotaryEmbeddingEnum::HALF_ROTATED)
{
                /// NOTICE: we might need to lower this value to support smaller rotary_dim
return 16 / sizeof(DataType);
}
else
{
return 8 / sizeof(DataType);
}
}();
constexpr index_t KThreadPerBlock = kKPerBlock / KPerThread;
constexpr index_t NThreadPerWarp = get_warp_size() / KThreadPerBlock;
constexpr index_t NumWarps = kBlockSize / get_warp_size();
constexpr index_t NPerThread = kNPerBlock / (NumWarps * NThreadPerWarp);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<NPerThread, NumWarps, NThreadPerWarp>,
sequence<KThreadPerBlock, KPerThread>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
};
} // namespace ck_tile
...@@ -6,7 +6,6 @@ ...@@ -6,7 +6,6 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -15,19 +14,18 @@ namespace ck_tile { ...@@ -15,19 +14,18 @@ namespace ck_tile {
template <typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy> template <typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSDefaultPolicy>
struct BlockFmhaFwdSplitKVPipelineQRKSVS struct BlockFmhaFwdSplitKVPipelineQRKSVS
{ {
using Problem = remove_cvref_t<Problem_>; using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>; using Policy = remove_cvref_t<Policy_>;
using QDataType = remove_cvref_t<typename Problem::QDataType>; using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>; using KDataType = remove_cvref_t<typename Problem::KDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>; using VDataType = remove_cvref_t<typename Problem::VDataType>;
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>; using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>; using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>; using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>; using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>; using PDataType = remove_cvref_t<typename Problem::PDataType>;
using PDataType = remove_cvref_t<typename Problem::PDataType>; using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>; using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>; using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>; using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
...@@ -49,8 +47,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -49,8 +47,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum; static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = true; // always store LSE (acc) static constexpr bool kStoreLSE = true; // always store LSE (acc)
static constexpr bool kHasDropout = false; // ignore this flag static constexpr bool kIsPagedKV = Problem::kIsPagedKV;
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits; static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
// last dimension vector length used to create tensor view(and decide buffer_load vector length) // last dimension vector length used to create tensor view(and decide buffer_load vector length)
...@@ -106,10 +104,11 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -106,10 +104,11 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
} }
template <typename QDramBlockWindowTmp, template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp, typename KDramBlockWindowLengths,
typename VDramBlockWindowTmp, typename KPageBlockNavigator,
typename VDramBlockWindowLengths,
typename VPageBlockNavigator,
typename BiasDramBlockWindowTmp, typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEaccDramBlockWindowTmp, typename LSEaccDramBlockWindowTmp,
typename QElementFunction, typename QElementFunction,
typename KElementFunction, typename KElementFunction,
...@@ -123,13 +122,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -123,13 +122,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
CK_TILE_HOST_DEVICE auto CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func, const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile
const KPageBlockNavigator& k_page_block_navigator,
const KElementFunction& k_element_func, const KElementFunction& k_element_func,
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile
const VPageBlockNavigator& v_page_block_navigator,
const VElementFunction& v_element_func, const VElementFunction& v_element_func,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func, const BiasElementFunction& bias_element_func,
RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile
const LSEaccElementFunction& lse_acc_element_func, const LSEaccElementFunction& lse_acc_element_func,
const SAccElementFunction& s_acc_element_func, const SAccElementFunction& s_acc_element_func,
...@@ -140,20 +140,19 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -140,20 +140,19 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
FmhaMask mask, FmhaMask mask,
PositionEncoding position_encoding, PositionEncoding position_encoding,
float scale_s, float scale_s,
void* smem_ptr, void* smem_ptr) const
BlockDropout& dropout) const
{ {
static_assert( static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> && std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> && std::is_same_v<KDataType, remove_cvref_t<typename KPageBlockNavigator::DataType>> &&
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>, std::is_same_v<VDataType, remove_cvref_t<typename VPageBlockNavigator::DataType>>,
"wrong!"); "wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kN0 == KDramBlockWindowLengths{}[number<0>{}] &&
kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kK0 == KDramBlockWindowLengths{}[number<1>{}] &&
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kN1 == VDramBlockWindowLengths{}[number<0>{}] &&
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] && kK1 == VDramBlockWindowLengths{}[number<1>{}] &&
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!"); "wrong!");
...@@ -213,12 +212,12 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -213,12 +212,12 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX( const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split); q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
// check early exit if masked and no work to do. // check early exit if masked and no work to do.
if constexpr(FmhaMask::IsMasking || kHasUnevenSplits) if constexpr(FmhaMask::IsMasking || kHasUnevenSplits)
{ {
if(num_total_loop <= 0) const index_t original_num_total_loop =
integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
if(original_num_total_loop <= 0)
{ {
if constexpr(kStoreLSE) if constexpr(kStoreLSE)
{ {
...@@ -237,26 +236,34 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -237,26 +236,34 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
} }
} }
auto k_dram_block_window = // make sure the first tile is completely located in page-block
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(), const index_t adjusted_seqlen_k_start = [&, seqlen_k_start_ = seqlen_k_start] {
k_dram_block_window_tmp.get_window_lengths(), if constexpr(kIsPagedKV)
{seqlen_k_start, 0}); {
return kN0 * integer_divide_floor(seqlen_k_start_, kN0);
}
else
{
return seqlen_k_start_;
}
}();
const index_t num_total_loop =
integer_divide_ceil(seqlen_k_end - adjusted_seqlen_k_start, kN0);
auto [i_page_block_k, k_dram_block_window] = k_page_block_navigator.make_tile_window(
k_dram_block_window_lengths, {adjusted_seqlen_k_start, 0});
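        // e.g. kN0 = 128 and seqlen_k_start = 192 give adjusted_seqlen_k_start = 128:
        // the window is snapped down to a kN0 multiple so the first tile never straddles
        // a page-block boundary (page_block_size is presumably a multiple of kN0), and
        // the columns in [128, 192) are masked back out below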
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin(); const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window = make_tile_window( auto bias_dram_window = make_tile_window(
bias_dram_block_window_tmp.get_bottom_tensor_view(), bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(), bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N {bias_origin.at(number<0>{}), adjusted_seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>()); Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0)>( auto [i_page_block_v, v_dram_window] = v_page_block_navigator.make_tile_window(
randval_dram_block_window_tmp, seqlen_k_start); v_dram_block_window_lengths,
{0, adjusted_seqlen_k_start}, // TODO: hdim split?
auto v_dram_window = Policy::template MakeVDramTileDistribution<Problem>());
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_k_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
auto q_tile = tile_elementwise_in(q_element_func, q); auto q_tile = tile_elementwise_in(q_element_func, q);
...@@ -271,14 +278,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -271,14 +278,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
{ {
// STAGE 1, QK gemm // STAGE 1, QK gemm
auto k_dram_window = make_tile_window( auto k_dram_window = make_tile_window(
k_dram_block_window.get_bottom_tensor_view(), k_dram_block_window,
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load // load
auto k_block_tile = load_tile(k_dram_window); auto k_block_tile = load_tile(k_dram_window);
{ {
// moving k_dram_window is an in-page-block operation, so there is
// no need to invoke k_page_block_navigator.move_tile_window() here.
move_tile_window(k_dram_window, {0, kK0}); move_tile_window(k_dram_window, {0, kK0});
clear_tile(s_acc); // initialize C clear_tile(s_acc); // initialize C
store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile)); store_tile(k_lds_window, tile_elementwise_in(k_element_func, k_block_tile));
...@@ -355,7 +362,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -355,7 +362,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
} }
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI) else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{ {
const auto k_origin = k_dram_block_window.get_window_origin(); const auto k_origin = k_page_block_navigator.to_global_window_origin(
i_page_block_k, k_dram_block_window.get_window_origin());
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans(); constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
s_acc = tile_elementwise_in(s_acc_element_func, s_acc); s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) { sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
...@@ -381,22 +389,32 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS ...@@ -381,22 +389,32 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
} }
move_tile_window(bias_dram_window, {0, kN0}); move_tile_window(bias_dram_window, {0, kN0});
/// TODO: only check in last iteration without increasing code size /// TODO: only check in first/last iteration without increasing code size
if constexpr(kHasUnevenSplits) if constexpr(kHasUnevenSplits)
{ {
const auto k_origin = k_dram_block_window.get_window_origin(); const auto k_origin = k_page_block_navigator.to_global_window_origin(
i_page_block_k, k_dram_block_window.get_window_origin());
set_tile_if(s_acc, set_tile_if(s_acc,
-numeric<SMPLComputeDataType>::infinity(), -numeric<SMPLComputeDataType>::infinity(),
[&, seqlen_k_end_ = seqlen_k_end](auto tile_idx) { [&, seqlen_k_start_ = seqlen_k_start, seqlen_k_end_ = seqlen_k_end](
auto tile_idx) {
const auto col = const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{}); k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return seqlen_k_end_ <= col; if constexpr(kIsPagedKV)
{
return col < seqlen_k_start_ || seqlen_k_end_ <= col;
}
else
{
return seqlen_k_end_ <= col;
}
}); });
} }
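            // under paged KV the first tile may begin before this split's real start
            // (see adjusted_seqlen_k_start above), so columns below seqlen_k_start are
            // masked to -inf in addition to the usual columns at/after seqlen_k_end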
if constexpr(kPadSeqLenK || FmhaMask::IsMasking) if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{ {
const auto k_origin = k_dram_block_window.get_window_origin(); const auto k_origin = k_page_block_navigator.to_global_window_origin(
i_page_block_k, k_dram_block_window.get_window_origin());
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}), bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
k_origin.at(number<0>{}), k_origin.at(number<0>{}),
number<kM0>{}, number<kM0>{},
@@ -501,12 +519,6 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
                 });
             });
-            if constexpr(kHasDropout)
-            {
-                dropout.Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
-                    smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window);
-            }
             block_sync_lds();
             if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
             {
@@ -522,7 +534,8 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
                 store_tile(v_lds_window,
                            tile_elementwise_in(v_element_func, v_prefetch)); // store the prefetch
             }
-            move_tile_window(v_dram_window, {0, kK1});
+            i_page_block_v =
+                v_page_block_navigator.move_tile_window(i_page_block_v, v_dram_window, {0, kK1});
             const auto p =
                 cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
@@ -530,8 +543,10 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
             // STAGE 3, KV gemm
             if constexpr(k1_loops > 1)
             {
-                static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
-                    const auto v = load_tile(v_dram_window); // load next v
+                static_for<0, k1_loops - 1, 1>{}([&,
+                                                  &i_page_block_v_ = i_page_block_v,
+                                                  &v_dram_window_  = v_dram_window](auto i_k1) {
+                    const auto v = load_tile(v_dram_window_); // load next v
                     block_sync_lds();
                     gemm_1(o_acc,
                            get_slice_tile(
@@ -552,11 +567,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
                         store_tile(v_lds_window,
                                    tile_elementwise_in(v_element_func, v)); // store next v
                     }
-                    move_tile_window(v_dram_window, {0, kK1});
+                    i_page_block_v_ = v_page_block_navigator.move_tile_window(
+                        i_page_block_v_, v_dram_window_, {0, kK1});
                 });
             }
             // move K tile windows
-            move_tile_window(k_dram_block_window, {kN0, 0});
+            i_page_block_k = k_page_block_navigator.move_tile_window(
+                i_page_block_k, k_dram_block_window, {kN0, 0});
             // tail
             {
                 block_sync_lds();
@@ -618,36 +635,38 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
     }
     template <typename QDramBlockWindowTmp,
-              typename KDramBlockWindowTmp,
-              typename VDramBlockWindowTmp,
+              typename KDramBlockWindowLengths,
+              typename KPageBlockNavigator,
+              typename VDramBlockWindowLengths,
+              typename VPageBlockNavigator,
               typename BiasDramBlockWindowTmp,
-              typename RandValDramBlockWindowTmp,
               typename LSEaccDramBlockWindowTmp,
               typename PositionEncoding>
     CK_TILE_HOST_DEVICE auto
     operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
-               const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
-               const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
+               const KDramBlockWindowLengths& k_dram_block_window_lengths, // N0*K0 tile
+               const KPageBlockNavigator& k_page_block_navigator,
+               const VDramBlockWindowLengths& v_dram_block_window_lengths, // N1*K1 tile
+               const VPageBlockNavigator& v_page_block_navigator,
                const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
-               RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
                LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile
                index_t num_splits,
                index_t i_split,
                FmhaMask mask,
                PositionEncoding position_encoding,
               float scale_s,
-               void* smem_ptr,
-               BlockDropout& dropout) const
+               void* smem_ptr) const
     {
         return operator()(q_dram_block_window_tmp,
                           identity{},
-                          k_dram_block_window_tmp,
+                          k_dram_block_window_lengths,
+                          k_page_block_navigator,
                           identity{},
-                          v_dram_block_window_tmp,
+                          v_dram_block_window_lengths,
+                          v_page_block_navigator,
                           identity{},
                           bias_dram_block_window_tmp,
                           identity{},
-                          randval_dram_block_window_tmp,
                           lse_acc_dram_block_window_tmp,
                           identity{},
                           identity{},
@@ -658,8 +677,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
                           mask,
                           position_encoding,
                           scale_s,
-                          smem_ptr,
-                          dropout);
+                          smem_ptr);
     }
 };
...
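The navigator calls above are the heart of the paged-KV change: instead of sliding a window through one contiguous K/V buffer, each move may hop to a different physical page of the KV cache, and the navigator returns the physical block the window now points into. A minimal, self-contained model of that addressing (illustrative only; PageBlockNavigator here is a stand-in, not the actual ck_tile type):

// Toy model of paged-KV addressing: a block table maps logical KV pages to
// physical pages in the cache pool, so advancing a tile window along seqlen_k
// may retarget it to a non-contiguous physical page.
#include <cstdio>
#include <vector>

struct PageBlockNavigator
{
    std::vector<int> block_table; // logical page index -> physical page index
    int page_size;                // sequence rows covered by one physical page

    // Advance the window origin by 'step' rows; returns the physical page the
    // new origin lands in, mirroring how move_tile_window() above returns the
    // updated i_page_block index.
    int move_tile_window(int& logical_row, int step) const
    {
        logical_row += step;
        return block_table[logical_row / page_size];
    }
};

int main()
{
    // four logical pages scattered over a physical pool: 2 -> 7 -> 0 -> 5
    PageBlockNavigator nav{{2, 7, 0, 5}, /*page_size=*/128};
    int logical_row = 0; // window currently at the start of logical page 0
    for(int i = 0; i < 3; ++i)
    {
        const int i_page_block = nav.move_tile_window(logical_row, /*kN0=*/128);
        std::printf("row %d -> physical page %d, offset %d\n",
                    logical_row, i_page_block, logical_row % nav.page_size);
    }
    return 0;
}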
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_fwd_splitkv_pipeline_qr_ks_vs_async_default_policy.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
// a variation of qr/ks/vs, where we use async copy to load k (potentially v in the future)
template <typename Problem_, typename Policy_ = BlockFmhaFwdSplitKVPipelineQRKSVSAsyncDefaultPolicy>
struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
using SaccDataType = remove_cvref_t<typename Problem::SaccDataType>;
using SMPLComputeDataType = remove_cvref_t<typename Problem::SMPLComputeDataType>;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using PDataType = remove_cvref_t<typename Problem::PDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
static constexpr bool kQLoadOnce = true; // whether q_tile loads the whole block length (hdim) at once
static_assert(kQLoadOnce == Policy::QLoadOnce);
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kN1 = BlockFmhaShape::kN1;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK0BlockLength = BlockFmhaShape::kK0BlockLength;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
// TODO: seq_q always supports padding, and hdim_q/v support multiples of the vector size (like 8x);
// only seq_k padding needs special care (oob elements of p must be set to -INF instead of zero)
static_assert(Problem::kPadSeqLenQ == true && Problem::kPadHeadDimQ == true &&
Problem::kPadHeadDimV == true);
static constexpr bool kPadSeqLenQ = true;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = true; // supports multiples of the vector size (like 8x)
static constexpr bool kPadHeadDimV = true; // supports multiples of the vector size (like 8x)
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = true; // always store LSE (acc)
static constexpr bool kHasDropout = false; // ignore this flag
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
// last dimension vector length used to create the tensor view (and decide the buffer_load vector length)
// ... together with the tensor distribution; the tensor distribution should be able to override this
static constexpr index_t kAlignmentQ = Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK = Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV = []() {
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
return Policy::template GetAlignmentV<Problem>();
else
return kPadSeqLenK ? 1 : Policy::template GetAlignmentV<Problem>();
}();
static constexpr index_t kAlignmentO = Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetAlignmentBias<Problem>();
#if CK_TILE_FMHA_FWD_FAST_EXP2
static constexpr auto R_LOG2E = 1.0 / log2e_v<SaccDataType>;
#endif
static constexpr index_t kBlockPerCu = []() {
if constexpr(Problem::kBlockPerCu != -1)
return Problem::kBlockPerCu;
else
{
if constexpr(kK0BlockLength <= 32)
{
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS &&
FmhaMask::IsMasking)
return 1;
else
return 2;
}
else if constexpr(kK0BlockLength <= 64)
{
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 2;
else
return 3;
}
else if constexpr(kK0BlockLength <= 128)
{
if constexpr(kPadSeqLenK && BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
return 1;
else
return 2;
}
else if constexpr(kK0BlockLength <= 256)
{
return 1;
}
}
}();
static constexpr const char* name = "qr_async";
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEaccDramBlockWindowTmp,
typename QElementFunction,
typename KElementFunction,
typename VElementFunction,
typename BiasElementFunction,
typename LSEaccElementFunction,
typename SAccElementFunction,
typename PComputeElementFunction,
typename OAccElementFunction,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const QElementFunction& q_element_func,
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const KElementFunction& /*k_element_func*/,
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const VElementFunction& v_element_func,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
const BiasElementFunction& bias_element_func,
RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
LSEaccDramBlockWindowTmp& lse_acc_dram_window_tmp, // M0*1 tile
const LSEaccElementFunction& lse_acc_element_func,
const SAccElementFunction& s_acc_element_func,
const PComputeElementFunction& p_compute_element_func,
const OAccElementFunction& o_acc_element_func,
index_t num_splits,
index_t i_split,
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
BlockDropout& dropout) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK0 == KDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kN1 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kK1 == VDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
constexpr auto LdsSeq = Policy::template GetLdsBufferSequence<Problem>();
// K tile in LDS
auto k_lds_ptr = reinterpret_cast<KDataType*>(smem_ptr);
auto k_lds_store = generate_tuple(
[&](auto i_buf) {
return make_tile_window(
make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf)),
Policy::template MakeKLdsStoreBlockDescriptor<Problem>(i_buf).get_lengths(),
{0, 0, 0});
},
number<Policy::NumPrefetchK>{});
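// note: the tuple above holds NumPrefetchK tile windows, one per LDS slot of
// the K prefetch ring; LdsSeq (below) picks which slot each stage writes/reads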
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
auto k_lds_load = generate_tuple(
[&](auto i_buf) {
return make_tile_window(
make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor<Problem>(i_buf)),
Policy::template MakeKLdsLoadBlockDescriptor<Problem>(i_buf).get_lengths(),
{0, 0});
},
number<Policy::NumPrefetchK>{});
#else
auto k_lds_load_view = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsLoadBlockDescriptor<Problem>());
auto k_lds_load =
make_tile_window(k_lds_load_view,
Policy::template MakeKLdsLoadBlockDescriptor<Problem>().get_lengths(),
{0, 0});
#endif
// V tile in LDS
auto v_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<VDataType*>(smem_ptr),
Policy::template MakeVLdsBlockDescriptor<Problem>());
auto v_lds_window = make_tile_window(
v_lds, Policy::template MakeVLdsBlockDescriptor<Problem>().get_lengths(), {0, 0});
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetKVBlockGemm<Problem>();
auto q_dram_window = make_tile_window(
q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(),
q_dram_block_window_tmp.get_window_origin(),
Policy::template MakeQDramTileDistribution<Problem, decltype(gemm_0)>());
// TODO: we use async copy for K, which is inline asm;
// a side effect is that we have to use inline asm for q as well
auto q = decltype(load_tile(q_dram_window)){};
set_tile(q, number<0>{}); // use per-dword clear to avoid scratch
load_tile_raw(q, q_dram_window);
__builtin_amdgcn_sched_barrier(0);
using SaccBlockTileType = decltype(gemm_0.MakeCBlockTile());
auto s_acc = SaccBlockTileType{};
// reduction function for softmax
const auto f_max = [](auto e0, auto e1) { return max(e0, e1); };
const auto f_sum = [](auto e0, auto e1) { return e0 + e1; };
// infer Sacc, S, P, M, L, Oacc type
using SBlockTileType = decltype(cast_tile<SMPLComputeDataType>(s_acc));
using MLBlockTileType = decltype(block_tile_reduce<SMPLComputeDataType>(
SBlockTileType{}, sequence<1>{}, f_max, SMPLComputeDataType{0}));
using OaccBlockTileType = decltype(gemm_1.MakeCBlockTile());
// init Oacc, M, L
auto o_acc = OaccBlockTileType{};
auto m = MLBlockTileType{};
auto l = MLBlockTileType{};
clear_tile(o_acc);
set_tile(m, -numeric<SMPLComputeDataType>::infinity());
clear_tile(l);
__builtin_amdgcn_sched_barrier(0);
const auto q_origin = q_dram_window.get_window_origin();
const auto [seqlen_k_start, seqlen_k_end] = mask.GetTileRangeAlongX(
q_origin.at(number<0>{}), number<kM0>{}, number<kN0>{}, num_splits, i_split);
const auto num_total_loop = integer_divide_ceil(seqlen_k_end - seqlen_k_start, kN0);
// early exit if the tile range is fully masked out and there is no work to do
if constexpr(FmhaMask::IsMasking || kPadSeqLenK || kHasUnevenSplits)
{
if(num_total_loop <= 0)
{
if constexpr(kStoreLSE)
{
auto lse_acc =
make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
set_tile(lse_acc, -numeric<SMPLComputeDataType>::infinity());
store_tile(lse_acc_dram_window_tmp,
tile_elementwise_in(lse_acc_element_func, lse_acc));
}
buffer_load_fence(0); // rocm-6.1: if the whole tile is masked out, we need to fence(0),
// otherwise there will be a compute error (maybe a compiler bug?)
// Note: o_acc is already cleared here, so just return it
return o_acc;
}
__builtin_amdgcn_sched_barrier(0); // make sure sched_barrier(0) for this check
}
auto k_dram_block_window =
make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
{seqlen_k_start, 0});
auto k_dram_window = make_tile_window(
k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window = make_tile_window(
bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0)>(
randval_dram_block_window_tmp, seqlen_k_start);
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_k_start}, // TODO: hdim split?
Policy::template MakeVDramTileDistribution<Problem>());
// prefetch K tile
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window);
move_tile_window(k_dram_window, {0, kK0});
__builtin_amdgcn_sched_barrier(0);
buffer_load_fence(k_dram_window.get_num_access(), q.get_thread_buffer());
(void)q_element_func; // ??? rocm-6.x: applying q_element_func here causes scratch usage on hdim=64/32
// auto q_tile = q; // tile_elementwise_in(q_element_func, q);
index_t i_total_loops = 0;
constexpr index_t k0_loops = kK0BlockLength / kK0;
constexpr index_t k1_loops = kN0 / kK1;
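// e.g. (illustrative numbers, actual values come from BlockFmhaShape): with
// kK0BlockLength = 128 and kK0 = 32 there are 4 QK sub-gemms per K tile, and
// with kN0 = 128 and kK1 = 32 there are 4 PV sub-gemms per N0 tile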
static_assert(1 <= k0_loops);
static_assert(1 <= k1_loops);
// main loop
do
{
// STAGE 1, QK gemm
clear_tile(s_acc); // initialize C
if constexpr(k0_loops > 1)
{
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
async_load_tile_raw(k_lds_store(number<LdsSeq.at(number<i_k0 + 1>{})>{}),
k_dram_window);
if constexpr(i_k0 < k0_loops - 1)
move_tile_window(k_dram_window, {0, kK0});
async_load_fence(k_dram_window.get_num_access());
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
gemm_0(s_acc,
get_slice_tile(
q, sequence<0, i_k0 * kK0>{}, sequence<kM0, (i_k0 + 1) * kK0>{}),
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
k_lds_load[number<LdsSeq.at(number<i_k0>{})>{}]);
#else
get_slice_tile(k_lds_load,
sequence<(LdsSeq.at(number<i_k0>{})) * kN0, 0>{},
sequence<(LdsSeq.at(number<i_k0>{}) + 1) * kN0, kK0>{}));
#endif
});
}
// TODO: this is to fix a bug when the loop count is smaller than 2;
// otherwise the following fence/barrier would be scheduled inside the 1st loop
if constexpr(k0_loops <= 2)
__builtin_amdgcn_sched_barrier(0);
async_load_fence();
__builtin_amdgcn_s_barrier();
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
auto v_buf = load_tile(v_dram_window, bool_constant<false>{});
__builtin_amdgcn_sched_barrier(0);
{ // tail
gemm_0(s_acc,
get_slice_tile(
q, sequence<0, (k0_loops - 1) * kK0>{}, sequence<kM0, k0_loops * kK0>{}),
#if K_LDS_LOAD_USE_OFFSET_TRANSFORM
k_lds_load[number<LdsSeq.at(number<k0_loops - 1>{})>{}]);
#else
get_slice_tile(
k_lds_load,
sequence<(LdsSeq.at(number<k0_loops - 1>{})) * kN0, 0>{},
sequence<(LdsSeq.at(number<k0_loops - 1>{}) + 1) * kN0, kK0>{}));
#endif
}
__builtin_amdgcn_sched_barrier(1);
// STAGE 2, scale_s, add bias, mask, softmax
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x += type_convert<SaccDataType>(bias_element_func(y));
#else
x += log2e_v<SaccDataType> *
type_convert<SaccDataType>(bias_element_func(y));
#endif
},
s_acc,
bias_tile);
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
const auto k_origin = k_dram_block_window.get_window_origin();
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) *= scale_s;
position_encoding.update(s_acc(i_j_idx), row, col);
});
});
}
else
{
s_acc = tile_elementwise_in(s_acc_element_func, s_acc);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&scale_s](auto& x) { x = x * scale_s; }, s_acc);
#endif
}
move_tile_window(bias_dram_window, {0, kN0});
/// TODO: only perform this check in the last iteration, without increasing code size
if constexpr(kHasUnevenSplits)
{
const auto k_origin = k_dram_block_window.get_window_origin();
set_tile_if(s_acc,
-numeric<SMPLComputeDataType>::infinity(),
[&, seqlen_k_end_ = seqlen_k_end](auto tile_idx) {
const auto col =
k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return seqlen_k_end_ <= col;
});
}
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
const auto k_origin = k_dram_block_window.get_window_origin();
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
k_origin.at(number<0>{}),
number<kM0>{},
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(
s_acc, -numeric<SMPLComputeDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
}
}
const auto s = cast_tile<SMPLComputeDataType>(s_acc); // S{j}
auto m_local = block_tile_reduce<SMPLComputeDataType>(
s,
sequence<1>{},
f_max,
-numeric<SMPLComputeDataType>::infinity()); // m_local = rowmax(S{j})
block_tile_reduce_sync(m_local, f_max, bool_constant<false>{});
const auto m_old = m; // m{j-1}
tile_elementwise_inout(
[](auto& e0, auto e1, auto e2) { e0 = max(e1, e2); }, m, m_old, m_local); // m{j}
auto p_compute = make_static_distributed_tensor<SMPLComputeDataType>(
s.get_tile_distribution()); // Pcompute{j}
__builtin_amdgcn_sched_barrier(0x7F);
// store & prefetch next v, after the max reduction
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_buf);
auto v_lds_window_tmp =
get_slice_tile(v_lds_window,
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
store_tile(
v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_shuffle_tmp)); // store the prefetch
}
else
{
auto v_lds_window_tmp =
get_slice_tile(v_lds_window,
sequence<(LdsSeq.at(number<k0_loops>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops>{}) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_buf)); // store the prefetch
}
if constexpr(k1_loops > 1)
{
move_tile_window(
v_dram_window,
{0, kK1}); // will have scratch if this is moved right after load_tile(v_dram)...
v_buf = load_tile(v_dram_window, bool_constant<false>{}); // load next v_buf
}
__builtin_amdgcn_sched_barrier(0);
static const auto get_validated_m = [](SMPLComputeDataType raw_m) {
/// NOTICE: the bias might be a materialized mask containing -inf values, which needs
/// consideration; alibi does not have this problem
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return raw_m == -numeric<SMPLComputeDataType>::infinity()
? type_convert<SMPLComputeDataType>(0.f)
: raw_m;
}
else
{
return raw_m;
}
};
constexpr auto p_spans = decltype(p_compute)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto row_max = scale_s * get_validated_m(m[i_idx]);
#endif
sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
p_compute(i_j_idx) = exp2(s[i_j_idx] - get_validated_m(m[i_idx]));
}
else
{
p_compute(i_j_idx) = exp2(scale_s * s[i_j_idx] - row_max);
}
#else
p_compute(i_j_idx) = exp(s[i_j_idx] - get_validated_m(m[i_idx]));
#endif
});
});
auto rowsum_p = block_tile_reduce<SMPLComputeDataType>(
p_compute, sequence<1>{}, f_sum, SMPLComputeDataType{0}); // rowsum(Pcompute{j})
block_tile_reduce_sync(rowsum_p, f_sum, bool_constant<false>{});
// l{j}, Oacc{j}
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
const auto tmp = [&]() {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
return exp2(m_old[i_idx] - get_validated_m(m[i_idx]));
}
else
{
auto row_max = scale_s * get_validated_m(m[i_idx]);
return exp2(scale_s * m_old[i_idx] - row_max);
}
}();
#else
const auto tmp = exp(m_old[i_idx] - get_validated_m(m[i_idx]));
#endif
l(i_idx) = tmp * l[i_idx] + rowsum_p[i_idx];
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
// FIXME: this uses a different equation from the FA v2 paper,
// but produces the correct result.
// Is the equation wrong?
o_acc(i_j_idx) *= tmp;
});
});
if constexpr(kHasDropout)
{
auto randval_ptr =
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
dropout.Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
randval_ptr,
seqlen_k_start + i_total_loops * kN0,
p_compute,
randval_dram_window);
}
const auto p =
cast_tile<PDataType>(tile_elementwise_in(p_compute_element_func, p_compute));
// STAGE 3, KV gemm
if constexpr(k1_loops > 1)
{
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
if constexpr(i_k1 != 0 && i_k1 < k1_loops - 1)
{
v_buf = load_tile(v_dram_window, bool_constant<false>{}); // load next v_buf
}
block_sync_lds();
gemm_1(o_acc,
get_slice_tile(
p, sequence<0, i_k1 * kK1>{}, sequence<kM0, (i_k1 + 1) * kK1>{}),
get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + i_k1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1>{}) + 1) * kN1, kK1>{}));
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
auto v_shuffle_tmp = make_static_distributed_tensor<VDataType>(
Policy::template MakeShuffledVRegBlockDescriptor<Problem>());
shuffle_tile(v_shuffle_tmp, v_buf);
auto v_lds_window_tmp = get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp,
tile_elementwise_in(v_element_func,
v_shuffle_tmp)); // store the prefetch
}
else
{
auto v_lds_window_tmp = get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + i_k1 + 1>{}) + 1) * kN1, kK1>{});
store_tile(v_lds_window_tmp,
tile_elementwise_in(v_element_func, v_buf)); // store next v_buf
}
if constexpr(i_k1 < k1_loops - 1)
move_tile_window(v_dram_window, {0, kK1});
});
}
i_total_loops++;
if(i_total_loops < num_total_loop)
{
// move K tile windows
move_tile_window(k_dram_block_window, {kN0, 0});
k_dram_window =
make_tile_window(k_dram_block_window.get_bottom_tensor_view(),
k_dram_block_window.get_window_lengths(),
k_dram_block_window.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>());
if constexpr(k1_loops >= 2 &&
LdsSeq.at(number<0>{}) == LdsSeq.at(number<k0_loops + k1_loops - 2>{}))
__builtin_amdgcn_s_barrier();
async_load_tile_raw(k_lds_store(LdsSeq.at(number<0>{})), k_dram_window);
move_tile_window(k_dram_window, {0, kK0});
}
// tail
{
block_sync_lds();
gemm_1(
o_acc,
get_slice_tile(p, sequence<0, (k1_loops - 1) * kK1>{}, sequence<kM0, kN0>{}),
get_slice_tile(
v_lds_window,
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{})) * kN1, 0>{},
sequence<(LdsSeq.at(number<k0_loops + k1_loops - 1>{}) + 1) * kN1, kK1>{}));
}
} while(i_total_loops < num_total_loop);
// store lse acc
if constexpr(kStoreLSE)
{
auto lse_acc = make_static_distributed_tensor<LSEDataType>(m.get_tile_distribution());
constexpr auto lse_acc_spans = decltype(lse_acc)::get_distributed_spans();
sweep_tile_span(lse_acc_spans[number<0>{}], [&, m_ = m, l_ = l](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
lse_acc(i_idx) = m_[i_idx] * R_LOG2E + log(l_[i_idx]);
}
else
{
lse_acc(i_idx) = m_[i_idx] * scale_s * R_LOG2E + log(l_[i_idx]);
}
#else
lse_acc(i_idx) = m_[i_idx] + log(l_[i_idx]);
#endif
});
store_tile(lse_acc_dram_window_tmp, tile_elementwise_in(lse_acc_element_func, lse_acc));
}
// finally, O
constexpr auto o_spans = decltype(o_acc)::get_distributed_spans();
sweep_tile_span(o_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
const auto tmp = [&]() {
if constexpr(FmhaMask::IsMasking)
{
return l[i_idx] == 0.f ? 0.f : 1 / l[i_idx];
}
else
return 1 / l[i_idx];
}();
sweep_tile_span(o_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
o_acc(i_j_idx) *= tmp;
});
});
o_acc = tile_elementwise_in(o_acc_element_func, o_acc);
return o_acc;
}
template <typename QDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename LSEaccDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, // M0*K0 tile
const KDramBlockWindowTmp& k_dram_block_window_tmp, // N0*K0 tile
const VDramBlockWindowTmp& v_dram_block_window_tmp, // N1*K1 tile
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, // M0*N0 tile
RandValDramBlockWindowTmp& randval_dram_block_window_tmp, // M0*N0 tile
LSEaccDramBlockWindowTmp& lse_acc_dram_block_window_tmp, // M0*1 tile
index_t num_splits,
index_t i_split,
FmhaMask mask,
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
BlockDropout& dropout) const
{
return operator()(q_dram_block_window_tmp,
identity{},
k_dram_block_window_tmp,
identity{},
v_dram_block_window_tmp,
identity{},
bias_dram_block_window_tmp,
identity{},
randval_dram_block_window_tmp,
lse_acc_dram_block_window_tmp,
identity{},
identity{},
identity{},
identity{},
num_splits,
i_split,
mask,
position_encoding,
scale_s,
smem_ptr,
dropout);
}
};
} // namespace ck_tile
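For reference, the m/l/o_acc updates in STAGE 2/3 above implement the usual online-softmax recurrence. A scalar sketch of the same algebra (plain exp instead of the exp2 fast path, one element per "tile") makes the rescaling easier to audit:

// Scalar reference for the online-softmax recurrence used by the pipeline:
//   m_j = max(m_{j-1}, s_j),  p = exp(s_j - m_j),
//   l_j = exp(m_{j-1} - m_j) * l_{j-1} + p,
//   o_j = exp(m_{j-1} - m_j) * o_{j-1} + p * v_j,
// and finally out = o / l, lse = m + log(l). The kernel additionally folds
// scale_s in via exp2 when CK_TILE_FMHA_FWD_FAST_EXP2 is set.
#include <algorithm>
#include <cmath>
#include <cstdio>
#include <limits>
#include <vector>

int main()
{
    const std::vector<float> s = {0.3f, 2.1f, -1.0f, 0.7f}; // one row of scaled logits
    const std::vector<float> v = {1.0f, -2.0f, 0.5f, 3.0f}; // matching V column

    float m = -std::numeric_limits<float>::infinity(); // running row max
    float l = 0.f;                                     // running row sum
    float o = 0.f;                                     // running output accumulator

    for(std::size_t j = 0; j < s.size(); ++j) // one "tile" per element
    {
        const float m_old = m;
        m = std::max(m, s[j]);
        const float scale = std::exp(m_old - m); // rescale old partial results
        const float p     = std::exp(s[j] - m);
        l = scale * l + p;
        o = scale * o + p * v[j];
    }
    std::printf("out = %f, lse = %f\n", o / l, m + std::log(l));
    return 0;
}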
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_pipeline_qx_ks_vs_custom_policy.hpp"
namespace ck_tile {
// in this pipeline, q/k/v all reside in LDS
using BlockFmhaFwdSplitKVPipelineQRKSVSAsyncDefaultPolicy =
BlockFmhaPipelineQXKSVSCustomPolicy</* QLoadOnce = */ true,
/* AsyncCopyK = */ true,
/* AsyncCopyV = */ false,
/* NumPrefetchK = */ 3,
/* NumPrefetchV = */ 3>;
} // namespace ck_tile
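NumPrefetchK = 3 sizes a small ring of K buffers in LDS; the pipeline's LdsSeq then decides which slot each k0/k1 stage writes or reads, so the async store for stage i can overlap the gemm reading stage i-1. A toy model of that round-robin slot assignment (illustrative only; the real GetLdsBufferSequence() also depends on tile shape):

#include <cstdio>

int main()
{
    constexpr int NumPrefetchK = 3; // matches the policy above
    constexpr int k0_loops     = 4; // e.g. kK0BlockLength / kK0
    // Stage i stores into slot (i % NumPrefetchK) while the gemm of stage
    // i-1 still reads slot ((i-1) % NumPrefetchK): store and load never
    // alias as long as they are at least one slot apart.
    for(int i = 1; i < k0_loops; ++i)
        std::printf("k0 stage %d: store slot %d, gemm reads slot %d\n",
                    i, i % NumPrefetchK, (i - 1) % NumPrefetchK);
    return 0;
}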
@@ -54,38 +54,50 @@ struct BlockFmhaPipelineProblem
     static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
 };
-template <typename QDataType,
-          typename KDataType,
-          typename VDataType,
-          typename SaccDataType,
-          typename SMPLComputeDataType,
-          typename BiasDataType,
-          typename RandValOutputDataType,
-          typename LSEDataType,
-          typename PDataType,
-          typename OaccDataType,
-          typename ODataType,
-          typename BlockFmhaShape,
-          bool kIsGroupMode,
-          typename FmhaMask,
-          typename Traits>
-struct BlockFmhaFwdSplitKVPipelineProblem : BlockFmhaPipelineProblem<QDataType,
-                                                                     KDataType,
-                                                                     VDataType,
-                                                                     SaccDataType,
-                                                                     SMPLComputeDataType,
-                                                                     BiasDataType,
-                                                                     RandValOutputDataType,
-                                                                     LSEDataType,
-                                                                     PDataType,
-                                                                     OaccDataType,
-                                                                     ODataType,
-                                                                     BlockFmhaShape,
-                                                                     kIsGroupMode,
-                                                                     FmhaMask,
-                                                                     Traits>
+template <typename QDataType_,
+          typename KDataType_,
+          typename VDataType_,
+          typename SaccDataType_,
+          typename SMPLComputeDataType_,
+          typename BiasDataType_,
+          typename LSEDataType_,
+          typename PDataType_,
+          typename OaccDataType_,
+          typename ODataType_,
+          typename BlockFmhaShape_,
+          bool kIsGroupMode_,
+          typename FmhaMask_,
+          typename Traits_>
+struct BlockFmhaFwdSplitKVPipelineProblem
 {
-    static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits;
+    using QDataType           = remove_cvref_t<QDataType_>;
+    using KDataType           = remove_cvref_t<KDataType_>;
+    using VDataType           = remove_cvref_t<VDataType_>;
+    using SaccDataType        = remove_cvref_t<SaccDataType_>;
+    using SMPLComputeDataType = remove_cvref_t<SMPLComputeDataType_>;
+    using BiasDataType        = remove_cvref_t<BiasDataType_>;
+    using LSEDataType         = remove_cvref_t<LSEDataType_>;
+    using PDataType           = remove_cvref_t<PDataType_>;
+    using OaccDataType        = remove_cvref_t<OaccDataType_>;
+    using ODataType           = remove_cvref_t<ODataType_>;
+    using BlockFmhaShape      = remove_cvref_t<BlockFmhaShape_>;
+    using FmhaMask            = remove_cvref_t<FmhaMask_>;
+    using Traits              = remove_cvref_t<Traits_>;
+    static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
+    static constexpr bool kIsGroupMode  = kIsGroupMode_;
+    // attributes from traits
+    static constexpr bool kPadSeqLenQ       = Traits::kPadSeqLenQ;
+    static constexpr bool kPadSeqLenK       = Traits::kPadSeqLenK;
+    static constexpr bool kPadHeadDimQ      = Traits::kPadHeadDimQ;
+    static constexpr bool kPadHeadDimV      = Traits::kPadHeadDimV;
+    static constexpr auto BiasEnum          = Traits::BiasEnum;
+    static constexpr bool kStoreLSE         = Traits::kStoreLSE;
+    static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
+    static constexpr bool kIsPagedKV        = Traits::kIsPagedKV;
+    static constexpr bool kHasUnevenSplits  = kIsGroupMode || Traits::kHasUnevenSplits;
+    static constexpr index_t kBlockPerCu    = Traits::kBlockPerCu;
 };
 template <typename LSEDataType_,
@@ -119,4 +131,44 @@ struct BlockFmhaSplitKVCombinePipelineProblem
     static constexpr index_t kMaxSplits = Traits::kMaxSplits;
 };
+template <typename QDataType_,
+          typename KDataType_,
+          typename VDataType_,
+          index_t kM0_,
+          index_t kN0_,
+          index_t kK0_,
+          index_t kN1_,
+          bool kIsVLayoutRowMajor_,
+          RotaryEmbeddingEnum RotaryEnum_,
+          bool kIsPagedKV_,
+          typename Traits_>
+struct BlockFmhaFwdAppendKVPipelineProblem
+{
+    using QDataType = remove_cvref_t<QDataType_>;
+    using KDataType = remove_cvref_t<KDataType_>;
+    using VDataType = remove_cvref_t<VDataType_>;
+    using Traits    = remove_cvref_t<Traits_>;
+    static constexpr index_t kBlockSize = 256;
+    static constexpr index_t kM0 = kM0_;
+    static constexpr index_t kN0 = kN0_;
+    static constexpr index_t kK0 = kK0_;
+    static constexpr index_t kN1 = kN1_;
+    using VLayout = std::conditional_t<kIsVLayoutRowMajor_,
+                                       ck_tile::tensor_layout::gemm::RowMajor,
+                                       ck_tile::tensor_layout::gemm::ColumnMajor>;
+    static constexpr auto RotaryEnum = RotaryEnum_;
+    static constexpr bool kIsPagedKV = kIsPagedKV_;
+    // attributes from traits
+    static constexpr bool kPadSeqLenQ    = Traits::kPadSeqLenQ;
+    static constexpr bool kPadSeqLenK    = Traits::kPadSeqLenK;
+    static constexpr bool kPadHeadDimQ   = Traits::kPadHeadDimQ;
+    static constexpr bool kPadHeadDimV   = Traits::kPadHeadDimV;
+    static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
+};
 } // namespace ck_tile
@@ -707,16 +707,19 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
     {
         if constexpr(AsyncCopyK)
         {
-            return GetSmemSizeKV<Problem>() + GetSmemSizeDropout<Problem>();
+            return GetSmemSizeKV<Problem>() + GetSmemSizeDropout<Problem>(0);
         }
         else
         {
-            return ck_tile::max(GetSmemSizeKV<Problem>(), GetSmemSizeDropout<Problem>());
+            return ck_tile::max(GetSmemSizeKV<Problem>(), GetSmemSizeDropout<Problem>(0));
         }
     }
+    // this overload is only viable when Problem::kHasDropout is present
     template <typename Problem>
-    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeDropout()
+    CK_TILE_HOST_DEVICE static constexpr std::
+        enable_if_t<std::is_convertible_v<decltype(Problem::kHasDropout), bool>, ck_tile::index_t>
+        GetSmemSizeDropout(int)
     {
         if constexpr(Problem::kHasDropout)
         {
@@ -736,6 +739,13 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
         }
     }
+    // fallback version if Problem::kHasDropout does not exist
+    template <typename Problem>
+    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeDropout(...)
+    {
+        return 0;
+    }
     template <typename Problem>
     CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution()
     {
...
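The GetSmemSizeDropout(int)/GetSmemSizeDropout(...) pair above is the classic int-versus-ellipsis overload trick: calling with a literal 0 prefers the int overload, but that overload drops out of the set via SFINAE whenever Problem::kHasDropout does not exist, leaving the variadic fallback. A standalone sketch of the same dispatch (names and sizes here are illustrative, not the ck_tile ones):

#include <cstdio>
#include <type_traits>

struct WithDropout    { static constexpr bool kHasDropout = true; };
struct WithoutDropout { };

// preferred overload: SFINAE keeps it only when Problem::kHasDropout exists,
// and the int parameter outranks the ellipsis for the argument 0
template <typename Problem>
constexpr std::enable_if_t<std::is_convertible_v<decltype(Problem::kHasDropout), bool>, int>
smem_size_dropout(int)
{
    return Problem::kHasDropout ? 1024 : 0; // placeholder byte count
}

// fallback overload: chosen when Problem::kHasDropout does not exist
template <typename Problem>
constexpr int smem_size_dropout(...)
{
    return 0;
}

int main()
{
    std::printf("%d %d\n",
                smem_size_dropout<WithDropout>(0),     // 1024, via the int overload
                smem_size_dropout<WithoutDropout>(0)); // 0, via the fallback
    return 0;
}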
@@ -5,6 +5,7 @@
 #include "ck_tile/core.hpp"
 #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
+#include "ck_tile/ops/fmha/block/block_rotary_embedding.hpp"
 namespace ck_tile {
@@ -32,30 +33,31 @@ struct TileFmhaTraits
     static constexpr index_t kBlockPerCu = kBlockPerCu_;
 };
-template <bool kPadSeqLenQ /* padding for seqlen_q */,
-          bool kPadSeqLenK /* padding for seqlen_k */,
-          bool kPadHeadDimQ /* padding for hdim_q */,
-          bool kPadHeadDimV /* padding for hdim_v */,
-          BlockAttentionBiasEnum BiasEnum,
-          bool kHasBiasGrad,
-          bool kStoreLSE,
-          bool kHasDropout,
-          bool kDoFp8StaticQuant,
-          bool kHasUnevenSplits_ = true,
-          index_t kBlockPerCu = -1 /* overwrite occupancy if not -1 */>
-struct TileFmhaFwdSplitKVTraits : TileFmhaTraits<kPadSeqLenQ,
-                                                 kPadSeqLenK,
-                                                 kPadHeadDimQ,
-                                                 kPadHeadDimV,
-                                                 BiasEnum,
-                                                 kHasBiasGrad,
-                                                 kStoreLSE,
-                                                 kHasDropout,
-                                                 kDoFp8StaticQuant,
-                                                 kBlockPerCu>
+template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
+          bool kPadSeqLenK_ /* padding for seqlen_k */,
+          bool kPadHeadDimQ_ /* padding for hdim_q */,
+          bool kPadHeadDimV_ /* padding for hdim_v */,
+          BlockAttentionBiasEnum BiasEnum_,
+          bool kHasBiasGrad_,
+          bool kStoreLSE_,
+          bool kDoFp8StaticQuant_,
+          bool kIsPagedKV_,
+          bool kHasUnevenSplits_,
+          index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
+struct TileFmhaFwdSplitKVTraits
 {
+    static constexpr bool kPadSeqLenQ       = kPadSeqLenQ_;
+    static constexpr bool kPadSeqLenK       = kPadSeqLenK_;
+    static constexpr bool kPadHeadDimQ      = kPadHeadDimQ_;
+    static constexpr bool kPadHeadDimV      = kPadHeadDimV_;
+    static constexpr auto BiasEnum          = BiasEnum_;
+    static constexpr bool kHasBiasGrad      = kHasBiasGrad_;
+    static constexpr bool kStoreLSE         = kStoreLSE_;
+    static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
+    static constexpr bool kIsPagedKV        = kIsPagedKV_;
     // determine if some split (length) is not divisible by the tile size
     static constexpr bool kHasUnevenSplits = kHasUnevenSplits_;
+    static constexpr index_t kBlockPerCu = kBlockPerCu_;
 };
 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
@@ -76,6 +78,20 @@ struct TileFmhaFwdSplitKVCombineTraits
     static constexpr index_t kBlockPerCu = kBlockPerCu_;
 };
+template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
+          bool kPadSeqLenK_ /* padding for seqlen_k */,
+          bool kPadHeadDimQ_ /* padding for hdim_q */,
+          bool kPadHeadDimV_ /* padding for hdim_v */,
+          index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
+struct TileFmhaFwdAppendKVTraits
+{
+    static constexpr bool kPadSeqLenQ  = kPadSeqLenQ_;
+    static constexpr bool kPadSeqLenK  = kPadSeqLenK_;
+    static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
+    static constexpr bool kPadHeadDimV = kPadHeadDimV_;
+    static constexpr index_t kBlockPerCu = kBlockPerCu_;
+};
 template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
           bool kPadHeadDimV_ /* padding for hdim_v */,
           index_t kBlockPerCu_ = 2 /* hint to occupancy */>
...
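With the flattened trait structs above, a configuration is now spelled out purely positionally. For example (values and alias names are illustrative, assuming the header defining these traits is included; NO_BIAS is assumed to be a BlockAttentionBiasEnum member):

// hypothetical instantiations showing the new parameter order
using SplitKVTraits = ck_tile::TileFmhaFwdSplitKVTraits<
    true,                                     // kPadSeqLenQ_
    true,                                     // kPadSeqLenK_
    true,                                     // kPadHeadDimQ_
    true,                                     // kPadHeadDimV_
    ck_tile::BlockAttentionBiasEnum::NO_BIAS, // BiasEnum_ (assumed enum value)
    false,                                    // kHasBiasGrad_
    true,                                     // kStoreLSE_
    false,                                    // kDoFp8StaticQuant_
    true,                                     // kIsPagedKV_
    true>;                                    // kHasUnevenSplits_ (kBlockPerCu_ keeps its -1 default)

using AppendKVTraits =
    ck_tile::TileFmhaFwdAppendKVTraits<true, true, true, true>; // all paddings enabled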