Commit bd689f40 authored by illsilin

merge from public repo

parents c160c6cf a94113a9
@@ -117,6 +117,15 @@ using int32x16_t = int32_t __attribute__((ext_vector_type(16)));
 using int32x32_t = int32_t __attribute__((ext_vector_type(32)));
 using int32x64_t = int32_t __attribute__((ext_vector_type(64)));
+// u32
+// using uint32_t = ...
+using uint32x2_t  = uint32_t __attribute__((ext_vector_type(2)));
+using uint32x4_t  = uint32_t __attribute__((ext_vector_type(4)));
+using uint32x8_t  = uint32_t __attribute__((ext_vector_type(8)));
+using uint32x16_t = uint32_t __attribute__((ext_vector_type(16)));
+using uint32x32_t = uint32_t __attribute__((ext_vector_type(32)));
+using uint32x64_t = uint32_t __attribute__((ext_vector_type(64)));
 // i16
 // using int16_t = ...
 using int16x2_t = int16_t __attribute__((ext_vector_type(2)));
...
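Reviewer note: the typedefs above use Clang's ext_vector_type attribute, which defines fixed-width vector value types with lane indexing and component-wise operators. A minimal standalone sketch of the same pattern (not part of this commit; plain Clang C++):

    #include <cstdint>

    // Same shape as the ck_tile typedefs above: four uint32_t lanes.
    using uint32x4_t = uint32_t __attribute__((ext_vector_type(4)));

    int main()
    {
        uint32x4_t a = {1u, 2u, 3u, 4u};
        uint32x4_t b = a + a;          // component-wise add: {2, 4, 6, 8}
        return static_cast<int>(b[3]); // lanes are indexable like an array -> returns 8
    }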
@@ -746,8 +746,9 @@ CK_TILE_HOST_DEVICE constexpr auto slice_distribution_from_x(
     return make_tuple(
         make_static_tile_distribution(
             tile_distribution_encoding<typename Encoding::RsLengths,
-                                       decltype(sliced_h_lengths), // only need to change the
-                                                                   // h_lengths type
+                                       remove_cvref_t<decltype(sliced_h_lengths)>, // only need to
+                                                                                   // change the
+                                                                                   // h_lengths type
                                        typename Encoding::Ps2RHssMajor,
                                        typename Encoding::Ps2RHssMinor,
                                        typename Encoding::Ys2RHsMajor,
...
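Reviewer note: the fix wraps decltype(sliced_h_lengths) in remove_cvref_t because decltype can carry const/reference qualifiers that a bare type template argument should not have. A small illustration of the failure mode, using C++20 std::remove_cvref_t as a stand-in for ck_tile's own remove_cvref_t (illustrative only, not code from the commit):

    #include <type_traits>

    template <typename T>
    struct encoding // stand-in for a type template parameter like the one above
    {
        static_assert(!std::is_reference_v<T> &&
                          !std::is_const_v<std::remove_reference_t<T>>,
                      "template argument must be a plain type");
    };

    int main()
    {
        const int& lengths = 42; // a const reference, as a sliced result can be
        // encoding<decltype(lengths)> bad{};                    // const int& -> assert fires
        encoding<std::remove_cvref_t<decltype(lengths)>> good{}; // int -> OK
        (void)good;
        return 0;
    }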
@@ -393,7 +393,10 @@ struct tile_window_with_static_distribution
                 bottom_tensor_thread_coord,
                 bool_constant<oob_conditional_check>{},
                 pre_nop_);
+#if CK_TILE_WORKAROUND_ROCM_6_1_SCRATCH_MEMORY_ISSUE
+            asm volatile(""); // the issue reappears in rocm-6.2 with the same symptom; reuse this flag
+#endif
             // move thread coordinate
             if constexpr(iCoordAccess != (NumAccessPerCoord - 1))
             {
...
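Reviewer note: an empty asm volatile statement emits no machine instructions but acts as an optimization barrier that the compiler will not freely reorder or eliminate code across, which is what sidesteps the scratch-memory miscompilation here. The commit's variant has no clobber list; the conventional full compiler barrier adds a "memory" clobber, as in this sketch (not from the commit):

    // Forces the compiler to assume all memory may have changed, so loads and
    // stores cannot be moved across this point; no machine code is generated.
    inline void compiler_barrier()
    {
        asm volatile("" ::: "memory");
    }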
@@ -53,6 +53,39 @@ class philox
         out_tmp[3] = tmp_ph.w;
     }

+    CK_TILE_HOST_DEVICE void get_random_8x8(uint8_t* out,
+                                            const unsigned long long subsequence,
+                                            const index_t start_idx) const
+    {
+        uint4 tmp_ph;
+        tmp_ph = get_philox_4x32(subsequence);
+
+        uint32x4_t tmp;
+        tmp[0] = tmp_ph.x;
+        tmp[1] = tmp_ph.y;
+        tmp[2] = tmp_ph.z;
+        tmp[3] = tmp_ph.w;
+
+        uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
+        out_tmp[0]        = tmp[start_idx];
+        out_tmp[1]        = tmp[start_idx + 2];
+    }
+
+    CK_TILE_HOST_DEVICE void get_random_4x8(uint8_t* out,
+                                            const unsigned long long subsequence,
+                                            const index_t start_idx) const
+    {
+        uint4 tmp_ph;
+        tmp_ph = get_philox_4x32(subsequence);
+
+        uint32x4_t tmp;
+        tmp[0] = tmp_ph.x;
+        tmp[1] = tmp_ph.y;
+        tmp[2] = tmp_ph.z;
+        tmp[3] = tmp_ph.w;
+
+        uint32_t* out_tmp = reinterpret_cast<uint32_t*>(&out[0]);
+        out_tmp[0]        = tmp[start_idx];
+    }

 private:
     struct ull2
     {
...
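Reviewer note: one Philox invocation yields four 32-bit words, i.e. 16 random bytes; the new helpers select one or two of those words by start_idx and expose them as 4 or 8 random uint8_t. A host-only sketch of the byte selection (rand_words stands in for the get_philox_4x32 result; that substitution is an assumption for illustration):

    #include <cstdint>
    #include <cstring>

    // Select words start_idx and start_idx + 2 (8 bytes total), as in get_random_8x8.
    void unpack_8x8(uint8_t out[8], const uint32_t rand_words[4], int start_idx)
    {
        std::memcpy(out, &rand_words[start_idx], sizeof(uint32_t));
        std::memcpy(out + 4, &rand_words[start_idx + 2], sizeof(uint32_t));
    }

    // Select the single word at start_idx (4 bytes), as in get_random_4x8.
    void unpack_4x8(uint8_t out[4], const uint32_t rand_words[4], int start_idx)
    {
        std::memcpy(out, &rand_words[start_idx], sizeof(uint32_t));
    }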
@@ -8,21 +8,16 @@
 #include "ck_tile/ops/fmha/block/block_masking.hpp"
 #include "ck_tile/ops/fmha/block/block_position_encoding.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_bwd_kernel.hpp"
-#include "ck_tile/ops/fmha/kernel/fmha_bwd_tile_partitioner.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_kernel.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_kernel.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_combine_tile_partitioner.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_kernel.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_splitkv_tile_partitioner.hpp"
 #include "ck_tile/ops/fmha/kernel/fmha_fwd_tile_partitioner.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_convert_dq.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_kts_vr_default_policy.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos.hpp"
-#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr.hpp"
+#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_kr_ktr_vr_iglp.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_enum.hpp"
 #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_problem.hpp"
...
@@ -286,11 +286,226 @@ struct BlockDropout
             });
         }
     }

+    ck_tile::philox ph;
+    const float rp_undrop;
+    const uint8_t p_undrop_in_uint8_t;
+    const bool is_store_randval;
+};
+
+template <bool IsDropout_, bool IsWG32_, bool IsStoreRandval_>
+struct BlockDropoutBwd;
+
+template <bool IsWG32_, bool IsStoreRandval_>
+struct BlockDropoutBwd<false, IsWG32_, IsStoreRandval_>
+{
+    static constexpr bool IsDropout      = false;
+    static constexpr bool IsStoreRandval = IsStoreRandval_;
+
+    template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
+    __host__ __device__ static constexpr auto
+    MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
+                          index_t seqlen_qk_start)
+    {
+        (void)randval_dram_block_window_tmp;
+        (void)seqlen_qk_start;
+        return make_null_tile_window(make_tuple(number<0>{}, number<0>{}));
+    }
+};
+
+template <bool IsWG32_, bool IsStoreRandval_>
+struct BlockDropoutBwd<true, IsWG32_, IsStoreRandval_>
+{
+    static constexpr bool IsDropout = true;
+    // true:  32*32 warp gemm
+    // false: 16*16 warp gemm
+    static constexpr bool IsWG32         = IsWG32_;
+    static constexpr bool IsStoreRandval = IsStoreRandval_;
+
+    CK_TILE_HOST_DEVICE BlockDropoutBwd(index_t i_batch,
+                                        index_t i_head,
+                                        index_t nheads,
+                                        unsigned long long seed,
+                                        unsigned long long offset,
+                                        float rp_undrop_,
+                                        uint8_t p_undrop_in_uint8_t_)
+        : ph(seed,
+             offset + (i_batch * nheads + i_head) * get_warp_size() +
+                 (IsWG32 ? get_lane_id() : ((get_lane_id() & 47) + ((get_warp_id() & 1) << 4)))),
+          rp_undrop(rp_undrop_),
+          p_undrop_in_uint8_t(p_undrop_in_uint8_t_)
+    {
+    }
+
+    template <typename BlockGemm, bool IsFwd = true, typename RandValDramBlockWindowTmp>
+    CK_TILE_HOST_DEVICE static constexpr auto
+    MakeRandvalDramWindow(RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
+                          index_t seqlen_qk_start)
+    {
+        constexpr auto config =
+            BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
+        using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
+        using WG             = remove_cvref_t<decltype(config.template at<0>())>;
+        constexpr index_t kMPerBlock = BlockGemmShape::kM;
+        constexpr index_t MWarp      = config.template at<1>();
+        constexpr index_t NWarp      = config.template at<2>();
+        constexpr bool MBwdWG16MultiIterCheck = (!IsFwd) && (!IsWG32) && (kMPerBlock > 16);
+        constexpr index_t kMPerStep = [&]() {
+            if constexpr(MBwdWG16MultiIterCheck)
+            {
+                return MWarp * WG::kM * 2;
+            }
+            else
+            {
+                return MWarp * WG::kM;
+            }
+        }();
+        constexpr index_t kNPerStep = NWarp * WG::kN;
+        const auto block_origin     = randval_dram_block_window_tmp.get_window_origin();
+        auto randval_dram_window    = [&]() {
+            if constexpr(IsFwd)
+            {
+                return make_tile_window(
+                    randval_dram_block_window_tmp.get_bottom_tensor_view(),
+                    ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
+                    {block_origin.at(number<0>{}), seqlen_qk_start}); // M/N
+            }
+            else
+            {
+                return make_tile_window(
+                    randval_dram_block_window_tmp.get_bottom_tensor_view(),
+                    ck_tile::make_tuple(number<kMPerStep>{}, number<kNPerStep>{}),
+                    {seqlen_qk_start, block_origin.at(number<1>{})}); // M/N
+            }
+        }();
+        return randval_dram_window;
+    }
+
+    template <typename BlockGemm>
+    CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsBlockDescriptor()
+    {
+        constexpr auto config =
+            BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
+        using WG = remove_cvref_t<decltype(config.template at<0>())>;
+        constexpr index_t MWarp     = config.template at<1>();
+        constexpr index_t kMPerStep = MWarp * WG::kM;
+        constexpr index_t kNPerStep = WG::kN;
+        constexpr index_t kN1       = 8;
+        constexpr index_t kN0       = kNPerStep / kN1;
+
+        constexpr auto randval_lds_block_desc_0 = make_naive_tensor_descriptor(
+            ck_tile::make_tuple(number<kN0>{}, number<kMPerStep>{}, number<kN1>{}),
+            ck_tile::make_tuple(number<(kMPerStep + 1) * kN1>{}, number<kN1>{}, number<1>{}),
+            number<kN1>{},
+            number<1>{});
+
+        constexpr auto randval_lds_block_desc = transform_tensor_descriptor(
+            randval_lds_block_desc_0,
+            ck_tile::make_tuple(
+                make_pass_through_transform(number<kMPerStep>{}),
+                make_merge_transform(ck_tile::make_tuple(number<kN0>{}, number<kN1>{}))),
+            ck_tile::make_tuple(sequence<1>{}, sequence<0, 2>{}),
+            ck_tile::make_tuple(sequence<0>{}, sequence<1>{}));
+
+        return randval_lds_block_desc;
+    }
+
+    template <typename BlockGemm, bool IsFwd = true>
+    CK_TILE_HOST_DEVICE static constexpr auto MakeRandValTileDistribution()
+    {
+        constexpr auto config =
+            BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
+        using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
+        constexpr index_t kMPerBlock = BlockGemmShape::kM;
+        constexpr index_t MWarp      = config.template at<1>();
+        constexpr index_t NWarp      = config.template at<2>();
+        constexpr bool MBwdWG16MultiIterCheck = (!IsFwd) && (!IsWG32) && (kMPerBlock > 16);
+        constexpr index_t MIterPerWarp = [&]() {
+            if constexpr(MBwdWG16MultiIterCheck)
+            {
+                return 2;
+            }
+            else
+            {
+                return 1;
+            }
+        }();
+        constexpr index_t NIterPerWarp = 1;
+
+        constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
+            sequence<>,
+            tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
+            tuple<sequence<1, 2>>,
+            tuple<sequence<1, 1>>,
+            sequence<1, 2>,
+            sequence<0, 0>>{};
+
+        // Use Bwd WarpGemm to ensure that Fwd's random values are consistent with Bwd,
+        // except headdim256.
+        constexpr auto randval_block_inner_part_dstr_encoding = []() {
+            if constexpr(std::is_same_v<typename BlockGemm::ADataType, half_t> &&
+                         std::is_same_v<typename BlockGemm::BDataType, half_t> &&
+                         std::is_same_v<typename BlockGemm::CDataType, float>)
+            {
+                if constexpr(IsWG32)
+                    return typename WarpGemmMfmaF16F16F32M32N32K16SwizzleA::CWarpDstrEncoding{};
+                else
+                    return typename WarpGemmMfmaF16F16F32M16N16K16::CWarpDstrEncoding{};
+            }
+            else
+            {
+                if constexpr(IsWG32)
+                    return typename WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA::CWarpDstrEncoding{};
+                else
+                    return typename WarpGemmMfmaBf16Bf16F32M16N16K16::CWarpDstrEncoding{};
+            }
+        }();
+
+        constexpr auto randval_block_part_dstr_encode =
+            detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
+                                                          randval_block_inner_part_dstr_encoding);
+
+        return make_static_tile_distribution(randval_block_part_dstr_encode);
+    }
+
+    template <typename BlockGemm>
+    CK_TILE_HOST_DEVICE static constexpr auto MakeRandValLdsShuffleTileDistribution()
+    {
+        constexpr auto config =
+            BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
+        using WG = remove_cvref_t<decltype(config.template at<0>())>;
+        constexpr index_t MWarp        = config.template at<1>();
+        constexpr index_t NWarp        = config.template at<2>();
+        constexpr index_t MIterPerWarp = 1;
+        constexpr index_t NIterPerWarp = 1;
+
+        constexpr auto randval_block_outer_part_dstr_encoding = tile_distribution_encoding<
+            sequence<>,
+            tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
+            tuple<sequence<1, 2>>,
+            tuple<sequence<1, 1>>,
+            sequence<1, 2>,
+            sequence<0, 0>>{};
+
+        constexpr auto randval_block_part_dstr_encode =
+            detail::make_embed_tile_distribution_encoding(randval_block_outer_part_dstr_encoding,
+                                                          typename WG::CWarpDstrEncoding{});
+
+        return make_static_tile_distribution(randval_block_part_dstr_encode);
+    }
+
     template <typename BlockGemm,
+              typename PComputeDataType,
               typename RandValOutputDataType,
               typename PComputeWindow,
               typename RandValDramWindow>
-    CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx,
+    CK_TILE_HOST_DEVICE void Run(void* randval_ptr,
+                                 const index_t start_m0_idx,
+                                 const index_t start_n0_idx,
                                  PComputeWindow& p_compute,
                                  RandValDramWindow& randval_dram_window) const
     {
@@ -305,30 +520,177 @@ struct BlockDropout
         constexpr index_t kMPerStep = MWarp * WG::kM;
         constexpr index_t kNPerStep = NWarp * WG::kN;

+        // randval tile in LDS
+        auto randval_lds = make_tensor_view<address_space_enum::lds>(
+            reinterpret_cast<uint8_t*>(randval_ptr), MakeRandValLdsBlockDescriptor<BlockGemm>());
+        auto randval_lds_window = make_tile_window(
+            randval_lds, MakeRandValLdsBlockDescriptor<BlockGemm>().get_lengths(), {0, 0});
+
         // register distribute
-        auto randval =
+        auto randval_dist_generated =
             make_static_distributed_tensor<uint8_t>(MakeRandValTileDistribution<BlockGemm>());
-        static_assert(randval.kThreadElementSpaceSize == 16);
+        static_assert(randval_dist_generated.kThreadElementSpaceSize == 16);

-        const int start_n0_idx = randval_dram_window.get_window_origin().at(number<1>{});
-        static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
-            static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
-                int block_row_start = (start_m0_idx / WG::kM) + i_m0;
-                int block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) + get_warp_id();
+        auto randval_lds_read_window =
+            make_tile_window(randval_lds_window.get_bottom_tensor_view(),
+                             randval_lds_window.get_window_lengths(),
+                             randval_lds_window.get_window_origin(),
+                             MakeRandValLdsShuffleTileDistribution<BlockGemm>());
+
+        static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
+            static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
+                int block_row_start = (start_m0_idx / WG::kM) + (i_m0 * MWarp) + get_warp_id();
+                int block_col_start = (start_n0_idx / WG::kN) + i_n0;
                 uint2 rowcol = make_uint2(block_row_start, block_col_start);

                 // generate random number
                 uint8_t random_uint8_t[16];
                 ph.get_random_16x8(random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol));

+                constexpr auto randval_dist_generated_spans =
+                    decltype(randval_dist_generated)::get_distributed_spans();
+                int i_random_idx = 0;
+                sweep_tile_span(randval_dist_generated_spans[number<0>{}], [&](auto idx0) {
+                    sweep_tile_span(randval_dist_generated_spans[number<1>{}], [&](auto idx1) {
+                        constexpr auto i_j_idx = ck_tile::make_tuple(idx0, idx1);
+                        randval_dist_generated(i_j_idx) = random_uint8_t[i_random_idx++];
+                    });
+                });
+
+                // save to LDS
+                store_tile(randval_lds_window, randval_dist_generated);
+                block_sync_lds();
+
+                // read from LDS to register
+                auto randval = load_tile(randval_lds_read_window);
+
                 constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
-                int i_random_idx = 0;
                 sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
                     sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
+                        constexpr auto p_idx0 = tile_distributed_index<i_m0>{};
+                        constexpr auto p_idx1 =
+                            tile_distributed_index<i_n0, idx1.impl_.at(1), idx1.impl_.at(2)>{};
+                        constexpr auto p_idx = ck_tile::make_tuple(p_idx0, p_idx1);
                         constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
-                        randval(r_idx) = random_uint8_t[i_random_idx++];
-                        constexpr auto p_idx0 =
-                            tile_distributed_index<i_m0, idx0.impl_.at(1), idx0.impl_.at(2)>{};
+                        p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
+                                               ? p_compute[p_idx] * rp_undrop
+                                               : PComputeDataType(0);
+                    });
+                });
+
+                // save to Global
+                if constexpr(IsStoreRandval)
+                {
+                    const auto randval_store = cast_tile<RandValOutputDataType>(randval);
+                    store_tile(randval_dram_window, randval_store);
+                    move_tile_window(randval_dram_window, {0, kNPerStep});
+                }
+            });
+            if constexpr(IsStoreRandval)
+            {
+                move_tile_window(randval_dram_window, {kMPerStep, -kNPerBlock});
+            }
+        });
+        if constexpr(IsStoreRandval)
+        {
+            move_tile_window(randval_dram_window, {-kMPerBlock, kNPerBlock});
+        }
+    }
+
+    template <typename BlockGemm,
+              typename RandValOutputDataType,
+              typename PComputeWindow,
+              typename RandValDramWindow>
+    CK_TILE_HOST_DEVICE void Run(const index_t start_m0_idx,
+                                 const index_t start_n0_idx,
+                                 PComputeWindow& p_compute,
+                                 RandValDramWindow& randval_dram_window) const
+    {
+        constexpr auto config =
+            BlockGemm::Policy::template GetWarpGemmMWarpNWarp<typename BlockGemm::Problem>();
+        using WG = remove_cvref_t<decltype(config.template at<0>())>;
+        constexpr index_t MWarp = config.template at<1>();
+        constexpr index_t NWarp = config.template at<2>();
+        using BlockGemmShape = remove_cvref_t<typename BlockGemm::BlockGemmShape>;
+        constexpr index_t kMPerBlock = BlockGemmShape::kM;
+        constexpr index_t kNPerBlock = BlockGemmShape::kN;
+        constexpr bool MBwdWG16MultiIterCheck  = (!IsWG32) && (kMPerBlock > 16);
+        constexpr bool MBwdWG16SingleIterCheck = (!IsWG32) && (kMPerBlock == 16);
+        constexpr index_t kMPerStep = [&]() {
+            if constexpr(MBwdWG16MultiIterCheck)
+            {
+                return MWarp * WG::kM * 2;
+            }
+            else
+            {
+                return MWarp * WG::kM;
+            }
+        }();
+        constexpr index_t kNPerStep = NWarp * WG::kN;
+
+        // register distribute
+        auto randval = make_static_distributed_tensor<uint8_t>(
+            MakeRandValTileDistribution<BlockGemm, false>());
+        if constexpr(IsWG32)
+            static_assert(randval.kThreadElementSpaceSize == 16);
+        else
+            static_assert(randval.kThreadElementSpaceSize == 4 ||
+                          randval.kThreadElementSpaceSize == 8);
+
+        static_for<0, kNPerBlock / kNPerStep, 1>{}([&](auto i_n0) {
+            static_for<0, kMPerBlock / kMPerStep, 1>{}([&](auto i_m0) {
+                int block_row_start, block_col_start;
+                if constexpr(IsWG32)
+                {
+                    block_row_start = (start_m0_idx / WG::kM) + i_m0;
+                    block_col_start = (start_n0_idx / WG::kN) + (i_n0 * NWarp) + get_warp_id();
+                }
+                else
+                {
+                    block_row_start = start_m0_idx / 32 + i_m0;
+                    block_col_start = (start_n0_idx / 32) + get_warp_id() / 2 + i_n0 * 2;
+                }
+                uint2 rowcol = make_uint2(block_row_start, block_col_start);
+
+                // generate random number
+                uint8_t* random_uint8_t_;
+                if constexpr(MBwdWG16SingleIterCheck)
+                {
+                    uint8_t random_uint8_t[4];
+                    // m0t0 ~m0t15/m0t32~m0t47: 0
+                    // m0t16~m0t31/m0t48~m0t63: 1
+                    // m1t0 ~m1t15/m1t32~m1t47: 2
+                    // m1t16~m1t31/m1t48~m1t63: 3
+                    const index_t start_idx =
+                        ((get_lane_id() >> 4) & 1) + (((start_m0_idx >> 4) & 1) << 1);
+                    ph.get_random_4x8(
+                        random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol), start_idx);
+                    random_uint8_t_ = random_uint8_t;
+                }
+                else if constexpr(MBwdWG16MultiIterCheck)
+                {
+                    uint8_t random_uint8_t[8];
+                    // t0 ~t15/t32~t47: 0
+                    // t16~t31/t48~t63: 1
+                    const index_t start_idx = (get_lane_id() >> 4) & 1;
+                    ph.get_random_8x8(
+                        random_uint8_t, reinterpret_cast<unsigned long long&>(rowcol), start_idx);
+                    random_uint8_t_ = random_uint8_t;
+                }
+                else
+                {
+                    uint8_t random_uint8_t[16];
+                    ph.get_random_16x8(random_uint8_t,
+                                       reinterpret_cast<unsigned long long&>(rowcol));
+                    random_uint8_t_ = random_uint8_t;
+                }
+
+                constexpr auto randval_spans = decltype(randval)::get_distributed_spans();
+                int i_random_idx             = 0;
+                sweep_tile_span(randval_spans[number<0>{}], [&](auto idx0) {
+                    sweep_tile_span(randval_spans[number<1>{}], [&](auto idx1) {
+                        constexpr auto r_idx = ck_tile::make_tuple(idx0, idx1);
+                        randval(r_idx)       = random_uint8_t_[i_random_idx++];
+                        constexpr auto p_idx0 = tile_distributed_index<i_m0 + idx0.impl_.at(0),
+                                                                       idx0.impl_.at(1),
+                                                                       idx0.impl_.at(2)>{};
                         constexpr auto p_idx1 = tile_distributed_index<i_n0>{};
                         constexpr auto p_idx  = ck_tile::make_tuple(p_idx0, p_idx1);
                         p_compute(p_idx) = randval[r_idx] <= p_undrop_in_uint8_t
@@ -337,19 +699,19 @@ struct BlockDropout
                     });
                 });
                 // save to Global
-                if(is_store_randval)
+                if constexpr(IsStoreRandval)
                 {
                     const auto randval_store = cast_tile<RandValOutputDataType>(randval);
                     store_tile(randval_dram_window, randval_store);
                     move_tile_window(randval_dram_window, {kMPerStep, 0});
                 }
             });
-            if(is_store_randval)
+            if constexpr(IsStoreRandval)
             {
                 move_tile_window(randval_dram_window, {-kMPerBlock, kNPerStep});
             }
         });
-        if(is_store_randval)
+        if constexpr(IsStoreRandval)
         {
             move_tile_window(randval_dram_window, {kMPerBlock, -kNPerBlock});
         }
@@ -358,7 +720,6 @@ struct BlockDropout
     ck_tile::philox ph;
    const float rp_undrop;
     const uint8_t p_undrop_in_uint8_t;
-    const bool is_store_randval;
 };
 } // namespace ck_tile
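Reviewer note: the dropout arithmetic shared by BlockDropout and the new BlockDropoutBwd works on quantized bytes: a keep-probability p_undrop = 1 - p_drop is mapped to a uint8_t threshold, each random byte r keeps an element iff r <= p_undrop_in_uint8_t, and kept values are rescaled by rp_undrop = 1 / p_undrop. A scalar sketch of that predicate (my restatement, not code from the commit; the exact rounding init_dropout uses for the threshold may differ):

    #include <cstdint>

    // Scalar model of: p_compute(p_idx) = randval <= p_undrop_in_uint8_t
    //                                         ? p_compute[p_idx] * rp_undrop : 0
    float apply_dropout(float p, uint8_t randval, float p_drop)
    {
        const float p_undrop         = 1.f - p_drop;
        const uint8_t p_undrop_uint8 = static_cast<uint8_t>(p_undrop * 255.f); // assumed quantization
        const float rp_undrop        = 1.f / p_undrop; // rescale so E[output] == input
        return randval <= p_undrop_uint8 ? p * rp_undrop : 0.f;
    }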
@@ -23,13 +23,9 @@
 namespace ck_tile {

-template <typename TilePartitioner_,
-          typename FmhaPipeline_,
-          typename KGradEpiloguePipeline_,
-          typename VGradEpiloguePipeline_>
+template <typename FmhaPipeline_, typename KGradEpiloguePipeline_, typename VGradEpiloguePipeline_>
 struct FmhaBwdDQDKDVKernel
 {
-    using TilePartitioner       = ck_tile::remove_cvref_t<TilePartitioner_>;
     using FmhaPipeline          = ck_tile::remove_cvref_t<FmhaPipeline_>;
     using KGradEpiloguePipeline = ck_tile::remove_cvref_t<KGradEpiloguePipeline_>;
     using VGradEpiloguePipeline = ck_tile::remove_cvref_t<VGradEpiloguePipeline_>;
@@ -59,9 +55,12 @@ struct FmhaBwdDQDKDVKernel
     static constexpr bool kPadHeadDimV = FmhaPipeline::kPadHeadDimV;
     static constexpr auto BiasEnum     = FmhaPipeline::BiasEnum;
     static constexpr bool kHasBiasGrad = FmhaPipeline::kHasBiasGrad;
-    static constexpr bool kHasDropout  = FmhaPipeline::kHasDropout;
     using FmhaMask = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaMask>;
-    static constexpr bool kHasMask = FmhaMask::IsMasking;
+    using FmhaDropout = ck_tile::remove_cvref_t<typename FmhaPipeline::FmhaDropout>;
+    static constexpr bool kHasMask         = FmhaMask::IsMasking;
+    static constexpr bool kHasDropout      = FmhaDropout::IsDropout;
+    static constexpr bool kIsStoreRandval  = FmhaDropout::IsStoreRandval;
+    static constexpr bool kIsDeterministic = FmhaPipeline::kIsDeterministic;

     // clang-format off
     template <typename T> struct t2s;
@@ -73,9 +72,12 @@ struct FmhaBwdDQDKDVKernel
     {
         // sync with generate.py
         // clang-format off
         using bfs = typename FmhaPipeline::BlockFmhaShape;
-        using gbr = typename bfs::Gemm0BlockWarps;
-        using gwt = typename bfs::Gemm0WarpTile;
+        using gbr0 = typename bfs::Gemm0BlockWarps;
+        using gbr1 = typename bfs::Gemm1BlockWarps;
+        using gbr4 = typename bfs::Gemm4BlockWarps;
+        using gwt0 = typename bfs::Gemm0WarpTile;
+        using gwt1 = typename bfs::Gemm1WarpTile;
         #define _SS_ std::string
         #define _TS_ std::to_string
         auto pn = [&] () {
@@ -88,13 +90,17 @@ struct FmhaBwdDQDKDVKernel
         return
             _SS_("fmha_bwd_d") + _TS_(bfs::kQKHeaddim) + "_" + _SS_(t2s<QDataType>::name) +
             "_" + (kIsGroupMode ? "group" : "batch") + "_" +
-            "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" +
-            _TS_(bfs::kQKHeaddim) + "x" + _TS_(bfs::kVHeaddim) + "_" +
-            "r" + _TS_(gbr::at(ck_tile::number<0>{})) + "x" + _TS_(gbr::at(ck_tile::number<1>{})) + "x" + _TS_(gbr::at(ck_tile::number<2>{})) + "_" +
-            "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
+            "b" + _TS_(bfs::kM0) + "x" + _TS_(bfs::kN0) + "x" + _TS_(bfs::kK0) + "x" + _TS_(bfs::kK1) + "x" + _TS_(bfs::kK2) + "x" + _TS_(bfs::kK3) + "x" +
+            _TS_(bfs::kK4) + "x" + _TS_(bfs::kQKHeaddim) + "x" + _TS_(bfs::kVHeaddim) + "_" +
+            "r" + _TS_(gbr0::at(ck_tile::number<0>{})) + "x" + _TS_(gbr0::at(ck_tile::number<1>{})) + "x" + _TS_(gbr0::at(ck_tile::number<2>{})) + "_" +
+            "r" + _TS_(gbr1::at(ck_tile::number<0>{})) + "x" + _TS_(gbr1::at(ck_tile::number<1>{})) + "x" + _TS_(gbr1::at(ck_tile::number<2>{})) + "_" +
+            "r" + _TS_(gbr4::at(ck_tile::number<0>{})) + "x" + _TS_(gbr4::at(ck_tile::number<1>{})) + "x" + _TS_(gbr4::at(ck_tile::number<2>{})) + "_" +
+            "w" + _TS_(gwt0::at(ck_tile::number<0>{})) + "x" + _TS_(gwt0::at(ck_tile::number<1>{})) + "x" + _TS_(gwt0::at(ck_tile::number<2>{})) + "_" +
+            "w" + _TS_(gwt1::at(ck_tile::number<0>{})) + "x" + _TS_(gwt1::at(ck_tile::number<1>{})) + "x" + _TS_(gwt1::at(ck_tile::number<2>{})) + "_" +
             ("o" + _TS_(kBlockPerCu) + "_") + _SS_(FmhaPipeline::name) + (pn.empty() ? "" : "_" + pn) +
             (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
-            (kHasBiasGrad ? "_dbias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" );
+            (kHasBiasGrad ? "_dbias" : "") + (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ) +
+            (kIsStoreRandval ? "_storerandval" : "" ) + (kIsDeterministic ? "_deterministic" : "" );
         #undef _SS_
         #undef _TS_
         // clang-format on
@@ -117,7 +123,7 @@ struct FmhaBwdDQDKDVKernel
         const void* lse_ptr;
         const void* do_ptr;
         const void* d_ptr;
-        void* dq_ptr;
+        void* dq_acc_ptr;
         void* dk_ptr;
         void* dv_ptr;
@@ -131,14 +137,13 @@ struct FmhaBwdDQDKDVKernel
         ck_tile::index_t num_head_q;
         ck_tile::index_t nhead_ratio_qk;
         float raw_scale;
-#if CK_TILE_FMHA_FWD_FAST_EXP2
         float scale;
-#endif

         ck_tile::index_t stride_q;
         ck_tile::index_t stride_k;
         ck_tile::index_t stride_v;
         ck_tile::index_t stride_do;
+        ck_tile::index_t stride_dq_acc;
         ck_tile::index_t stride_dk;
         ck_tile::index_t stride_dv;
@@ -147,8 +152,9 @@ struct FmhaBwdDQDKDVKernel
         ck_tile::index_t nhead_stride_v;
         ck_tile::index_t nhead_stride_do;
         ck_tile::index_t nhead_stride_lsed;
-        ck_tile::index_t batch_stride_lsed;
+        ck_tile::index_t nhead_stride_dq_acc;
+        ck_tile::index_t nhead_stride_dk;
+        ck_tile::index_t nhead_stride_dv;
     };

     struct FmhaBwdCommonBiasKargs
@@ -206,7 +212,6 @@ struct FmhaBwdDQDKDVKernel
         float rp_undrop             = 1;
         float scale_rp_undrop       = 1;
         uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
-        bool is_store_randval       = false;
         uint64_t drop_seed          = 1;
         uint64_t drop_offset        = 0;
         void* rand_val_ptr          = nullptr;
@@ -218,6 +223,10 @@ struct FmhaBwdDQDKDVKernel
     {
         ck_tile::index_t batch_stride_randval = 0;
     };
+
+    struct FmhaBwdDeterministicKargs
+    {
+        ck_tile::index_t split_stride_dq_acc = 0;
+    };
     struct FmhaBwdBatchModeKargs
         : FmhaBwdCommonKargs,
@@ -228,12 +237,15 @@ struct FmhaBwdDQDKDVKernel
                              FmhaBwdEmptyKargs<0>>>,
           std::conditional_t<kHasBiasGrad, FmhaBwdBatchModeBiasGradKargs, FmhaBwdEmptyKargs<1>>,
           std::conditional_t<kHasMask, FmhaBwdMaskKargs, FmhaBwdEmptyKargs<2>>,
-          std::conditional_t<kHasDropout, FmhaBwdBatchModeDropoutKargs, FmhaBwdEmptyKargs<3>>
+          std::conditional_t<kHasDropout, FmhaBwdBatchModeDropoutKargs, FmhaBwdEmptyKargs<3>>,
+          std::conditional_t<kIsDeterministic, FmhaBwdDeterministicKargs, FmhaBwdEmptyKargs<4>>
     {
         ck_tile::index_t batch_stride_q;
         ck_tile::index_t batch_stride_k;
         ck_tile::index_t batch_stride_v;
         ck_tile::index_t batch_stride_do;
+        ck_tile::index_t batch_stride_lsed;
+        ck_tile::index_t batch_stride_dq_acc;
         ck_tile::index_t batch_stride_dk;
         ck_tile::index_t batch_stride_dv;
     };
@@ -247,7 +259,8 @@ struct FmhaBwdDQDKDVKernel
                              FmhaBwdEmptyKargs<0>>>,
           std::conditional_t<kHasBiasGrad, FmhaBwdCommonBiasGradKargs, FmhaBwdEmptyKargs<1>>,
           std::conditional_t<kHasMask, FmhaBwdMaskKargs, FmhaBwdEmptyKargs<2>>,
-          std::conditional_t<kHasDropout, FmhaBwdCommonDropoutKargs, FmhaBwdEmptyKargs<3>>
+          std::conditional_t<kHasDropout, FmhaBwdCommonDropoutKargs, FmhaBwdEmptyKargs<3>>,
+          std::conditional_t<kIsDeterministic, FmhaBwdDeterministicKargs, FmhaBwdEmptyKargs<4>>
     {
         const int32_t* seqstart_q_ptr;
         const int32_t* seqstart_k_ptr;
@@ -266,10 +279,10 @@ struct FmhaBwdDQDKDVKernel
              const void* do_ptr,
              const void* d_ptr,
              void* rand_val_ptr,
-             void* dq_ptr,
              void* dk_ptr,
              void* dv_ptr,
              void* dbias_ptr,
+             void* dq_acc_ptr,
              ck_tile::index_t seqlen_q,
              ck_tile::index_t seqlen_k,
              ck_tile::index_t hdim_q,
@@ -283,6 +296,7 @@ struct FmhaBwdDQDKDVKernel
              ck_tile::index_t stride_bias,
              ck_tile::index_t stride_randval,
              ck_tile::index_t stride_do,
+             ck_tile::index_t stride_dq_acc,
              ck_tile::index_t stride_dk,
              ck_tile::index_t stride_dv,
              ck_tile::index_t stride_dbias,
@@ -293,6 +307,9 @@ struct FmhaBwdDQDKDVKernel
              ck_tile::index_t nhead_stride_randval,
              ck_tile::index_t nhead_stride_do,
              ck_tile::index_t nhead_stride_lsed,
+             ck_tile::index_t nhead_stride_dq_acc,
+             ck_tile::index_t nhead_stride_dk,
+             ck_tile::index_t nhead_stride_dv,
              ck_tile::index_t nhead_stride_dbias,
              ck_tile::index_t batch_stride_q,
              ck_tile::index_t batch_stride_k,
@@ -301,14 +318,15 @@ struct FmhaBwdDQDKDVKernel
              ck_tile::index_t batch_stride_randval,
              ck_tile::index_t batch_stride_do,
              ck_tile::index_t batch_stride_lsed,
+             ck_tile::index_t batch_stride_dq_acc,
              ck_tile::index_t batch_stride_dk,
              ck_tile::index_t batch_stride_dv,
              ck_tile::index_t batch_stride_dbias,
+             ck_tile::index_t split_stride_dq_acc,
              ck_tile::index_t window_size_left,
              ck_tile::index_t window_size_right,
              ck_tile::index_t mask_type,
              float p_drop,
-             bool s_randval,
              const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
     {
         Kargs kargs{{q_ptr,
@@ -317,7 +335,7 @@ struct FmhaBwdDQDKDVKernel
                      lse_ptr,
                      do_ptr,
                      d_ptr,
-                     dq_ptr,
+                     dq_acc_ptr,
                      dk_ptr,
                      dv_ptr,
                      seqlen_q,
@@ -327,13 +345,12 @@ struct FmhaBwdDQDKDVKernel
                      num_head_q,
                      nhead_ratio_qk,
                      scale,
-#if CK_TILE_FMHA_FWD_FAST_EXP2
                      static_cast<float>(scale * ck_tile::log2e_v<>),
-#endif
                      stride_q,
                      stride_k,
                      stride_v,
                      stride_do,
+                     stride_dq_acc,
                      stride_dk,
                      stride_dv,
                      nhead_stride_q,
@@ -341,15 +358,20 @@ struct FmhaBwdDQDKDVKernel
                      nhead_stride_v,
                      nhead_stride_do,
                      nhead_stride_lsed,
-                     batch_stride_lsed}, // args for common karg
+                     nhead_stride_dq_acc,
+                     nhead_stride_dk,
+                     nhead_stride_dv}, // args for common karg
                     {}, // placeholder for bias
                     {}, // placeholder for dbias
                     {}, // placeholder for mask
                     {}, // placeholder for dropout
+                    {}, // placeholder for deterministic
                     batch_stride_q,
                     batch_stride_k,
                     batch_stride_v,
                     batch_stride_do,
+                    batch_stride_lsed,
+                    batch_stride_dq_acc,
                     batch_stride_dk,
                     batch_stride_dv};
@@ -384,11 +406,18 @@ struct FmhaBwdDQDKDVKernel
         if constexpr(kHasDropout)
         {
             kargs.init_dropout(p_drop, drop_seed_offset, scale);
-            kargs.rand_val_ptr         = rand_val_ptr;
-            kargs.stride_randval       = stride_randval;
-            kargs.nhead_stride_randval = nhead_stride_randval;
-            kargs.batch_stride_randval = batch_stride_randval;
-            kargs.is_store_randval     = s_randval;
+            if constexpr(kIsStoreRandval)
+            {
+                kargs.rand_val_ptr         = rand_val_ptr;
+                kargs.stride_randval       = stride_randval;
+                kargs.nhead_stride_randval = nhead_stride_randval;
+                kargs.batch_stride_randval = batch_stride_randval;
+            }
+        }
+        if constexpr(kIsDeterministic)
+        {
+            kargs.split_stride_dq_acc = split_stride_dq_acc;
         }

         return kargs;
@@ -404,10 +433,10 @@ struct FmhaBwdDQDKDVKernel
              const void* do_ptr,
              const void* d_ptr,
              void* rand_val_ptr,
-             void* dq_ptr,
              void* dk_ptr,
              void* dv_ptr,
              void* dbias_ptr,
+             void* dq_acc_ptr,
              const void* seqstart_q_ptr,
              const void* seqstart_k_ptr,
              const void* seqlen_k_ptr,
@@ -422,6 +451,7 @@ struct FmhaBwdDQDKDVKernel
              ck_tile::index_t stride_bias,
              ck_tile::index_t stride_randval,
              ck_tile::index_t stride_do,
+             ck_tile::index_t stride_dq_acc,
              ck_tile::index_t stride_dk,
              ck_tile::index_t stride_dv,
              ck_tile::index_t stride_dbias,
@@ -432,13 +462,15 @@ struct FmhaBwdDQDKDVKernel
              ck_tile::index_t nhead_stride_randval,
              ck_tile::index_t nhead_stride_do,
              ck_tile::index_t nhead_stride_lsed,
+             ck_tile::index_t nhead_stride_dq_acc,
+             ck_tile::index_t nhead_stride_dk,
+             ck_tile::index_t nhead_stride_dv,
              ck_tile::index_t nhead_stride_dbias,
-             ck_tile::index_t batch_stride_lsed,
+             ck_tile::index_t split_stride_dq_acc,
              ck_tile::index_t window_size_left,
              ck_tile::index_t window_size_right,
              ck_tile::index_t mask_type,
              float p_drop,
-             bool s_randval,
              const std::tuple<uint64_t, uint64_t>& drop_seed_offset)
     {
         Kargs kargs{{q_ptr,
@@ -447,7 +479,7 @@ struct FmhaBwdDQDKDVKernel
                      lse_ptr,
                      do_ptr,
                      d_ptr,
-                     dq_ptr,
+                     dq_acc_ptr,
                      dk_ptr,
                      dv_ptr,
                      -1, // seqlen will be updated by another pointer
@@ -457,13 +489,12 @@ struct FmhaBwdDQDKDVKernel
                      num_head_q,
                      nhead_ratio_qk,
                      scale,
-#if CK_TILE_FMHA_FWD_FAST_EXP2
                      static_cast<float>(scale * ck_tile::log2e_v<>),
-#endif
                      stride_q,
                      stride_k,
                      stride_v,
                      stride_do,
+                     stride_dq_acc,
                      stride_dk,
                      stride_dv,
                      nhead_stride_q,
@@ -471,11 +502,14 @@ struct FmhaBwdDQDKDVKernel
                      nhead_stride_v,
                      nhead_stride_do,
                      nhead_stride_lsed,
-                     batch_stride_lsed}, // args for common karg
+                     nhead_stride_dq_acc,
+                     nhead_stride_dk,
+                     nhead_stride_dv}, // args for common karg
                     {}, // placeholder for bias
                     {}, // placeholder for dbias
                     {}, // placeholder for mask
                     {}, // placeholder for dropout
+                    {}, // placeholder for deterministic
                     reinterpret_cast<const int32_t*>(seqstart_q_ptr),
                     reinterpret_cast<const int32_t*>(seqstart_k_ptr),
                     reinterpret_cast<const int32_t*>(seqlen_k_ptr)};
@@ -506,10 +540,16 @@ struct FmhaBwdDQDKDVKernel
         if constexpr(kHasDropout)
         {
             kargs.init_dropout(p_drop, drop_seed_offset, scale);
-            kargs.rand_val_ptr         = rand_val_ptr;
-            kargs.stride_randval       = stride_randval;
-            kargs.nhead_stride_randval = nhead_stride_randval;
-            kargs.is_store_randval     = s_randval;
+            if constexpr(kIsStoreRandval)
+            {
+                kargs.rand_val_ptr         = rand_val_ptr;
+                kargs.stride_randval       = stride_randval;
+                kargs.nhead_stride_randval = nhead_stride_randval;
+            }
+        }
+        if constexpr(kIsDeterministic)
+        {
+            kargs.split_stride_dq_acc = split_stride_dq_acc;
         }

         return kargs;
@@ -518,7 +558,17 @@ struct FmhaBwdDQDKDVKernel
     CK_TILE_HOST static constexpr auto
     GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
     {
-        return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_k_);
+        return dim3(
+            ck_tile::integer_divide_ceil(seqlen_k_, FmhaPipeline::kN0), nhead_, batch_size_);
+    }
+
+    CK_TILE_DEVICE static constexpr auto GetTileIndex()
+    {
+        const index_t i_block = blockIdx.x;
+        const index_t i_nhead = blockIdx.y;
+        const index_t i_batch = blockIdx.z;
+        return ck_tile::make_tuple(i_block, i_nhead, i_batch);
     }

     CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
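Reviewer note: with TilePartitioner removed, the kernel now owns its launch geometry: one workgroup per k-tile along x, heads along y, batches along z, and GetTileIndex simply reads the indices back from blockIdx. A host-side sketch of the same mapping (kN0 here stands for the pipeline's k-tile size; the helper name is hypothetical):

    #include <hip/hip_runtime.h>

    dim3 fmha_bwd_grid(int batch_size, int nhead, int seqlen_k, int kN0)
    {
        // matches GridSize(): ceil(seqlen_k / kN0) k-tiles on x, heads on y, batches on z
        return dim3((seqlen_k + kN0 - 1) / kN0, nhead, batch_size);
    }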
@@ -536,7 +586,7 @@ struct FmhaBwdDQDKDVKernel
         __shared__ char smem_ptr[GetSmemSize()];

         // divide problem
-        const auto [i_tile_n, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_k);
+        const auto [i_tile_n, i_nhead, i_batch] = GetTileIndex();

         const index_t i_n0 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN0);
@@ -547,6 +597,7 @@ struct FmhaBwdDQDKDVKernel
         long_index_t batch_offset_randval = 0;
         long_index_t batch_offset_do      = 0;
         long_index_t batch_offset_lsed    = 0;
+        long_index_t batch_offset_dq_acc  = 0;
         long_index_t batch_offset_dk      = 0;
         long_index_t batch_offset_dv      = 0;
         long_index_t batch_offset_dbias   = 0;
@@ -557,13 +608,14 @@ struct FmhaBwdDQDKDVKernel
             const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
             const long_index_t key_start   = kargs.seqstart_k_ptr[i_batch];

             batch_offset_q    = query_start * kargs.stride_q;
             batch_offset_k    = key_start * kargs.stride_k;
             batch_offset_v    = key_start * kargs.stride_v;
             batch_offset_do   = query_start * kargs.stride_do;
-            batch_offset_lsed = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed;
-            batch_offset_dk   = key_start * kargs.stride_dk;
-            batch_offset_dv   = key_start * kargs.stride_dv;
+            batch_offset_lsed   = query_start;
+            batch_offset_dq_acc = query_start * kargs.stride_dq_acc;
+            batch_offset_dk     = key_start * kargs.stride_dk;
+            batch_offset_dv     = key_start * kargs.stride_dv;
             if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
             {
                 batch_offset_bias = query_start * kargs.stride_bias;
@@ -576,7 +628,7 @@ struct FmhaBwdDQDKDVKernel
             {
                 batch_offset_dbias = key_start;
             }
-            if constexpr(kHasDropout)
+            if constexpr(kIsStoreRandval)
             {
                 batch_offset_randval = query_start * kargs.stride_randval;
             }
@@ -603,13 +655,14 @@ struct FmhaBwdDQDKDVKernel
         }
         else
         {
             batch_offset_q    = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
             batch_offset_k    = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
             batch_offset_v    = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
             batch_offset_do   = static_cast<long_index_t>(i_batch) * kargs.batch_stride_do;
             batch_offset_lsed = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lsed;
-            batch_offset_dk   = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dk;
-            batch_offset_dv   = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dv;
+            batch_offset_dq_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq_acc;
+            batch_offset_dk     = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dk;
+            batch_offset_dv     = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dv;
             if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
             {
                 batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
@@ -618,7 +671,7 @@ struct FmhaBwdDQDKDVKernel
             {
                 batch_offset_dbias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dbias;
             }
-            if constexpr(kHasDropout)
+            if constexpr(kIsStoreRandval)
             {
                 batch_offset_randval =
                     static_cast<long_index_t>(i_batch) * kargs.batch_stride_randval;
             }
@@ -646,14 +699,11 @@ struct FmhaBwdDQDKDVKernel
         const OGradDataType* do_ptr = reinterpret_cast<const OGradDataType*>(kargs.do_ptr) +
                                       static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_do +
                                       batch_offset_do;
-        QGradDataType* dq_ptr = reinterpret_cast<QGradDataType*>(kargs.dq_ptr) +
-                                static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_q +
-                                batch_offset_q;
         KGradDataType* dk_ptr = reinterpret_cast<KGradDataType*>(kargs.dk_ptr) +
-                                static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_k +
+                                static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dk +
                                 batch_offset_dk;
         VGradDataType* dv_ptr = reinterpret_cast<VGradDataType*>(kargs.dv_ptr) +
-                                static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_v +
+                                static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dv +
                                 batch_offset_dv;

         // Q/K/V/LSE/D/dO/dQ/dK/dV DRAM and DRAM window
@@ -663,45 +713,10 @@ struct FmhaBwdDQDKDVKernel
             make_tuple(kargs.stride_q, 1),
             number<FmhaPipeline::kAlignmentQ>{},
             number<1>{});
-        const auto q_dram = [&]() {
-            if constexpr(FmhaPipeline::kQLoadOnce)
-            {
-                return pad_tensor_view(
-                    q_dram_naive,
-                    make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
-                    sequence<kPadSeqLenQ, kPadHeadDimQ>{});
-            }
-            else
-            {
-                return pad_tensor_view(
-                    q_dram_naive,
-                    make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{}),
-                    sequence<kPadSeqLenQ, kPadHeadDimQ>{});
-            }
-        }();
-        const auto qt_dram_naive =
-            transform_tensor_view(q_dram_naive,
-                                  make_tuple(make_pass_through_transform(kargs.hdim_q),
-                                             make_pass_through_transform(kargs.seqlen_q)),
-                                  make_tuple(sequence<1>{}, sequence<0>{}),
-                                  make_tuple(sequence<0>{}, sequence<1>{}));
-        const auto qt_dram = [&]() {
-            if constexpr(FmhaPipeline::kQTLoadOnce)
-            {
-                return pad_tensor_view(
-                    qt_dram_naive,
-                    make_tuple(number<FmhaPipeline::kQKHeaddim>{}, number<FmhaPipeline::kM0>{}),
-                    sequence<kPadHeadDimQ, kPadSeqLenQ>{});
-            }
-            else
-            {
-                return pad_tensor_view(
-                    qt_dram_naive,
-                    make_tuple(number<FmhaPipeline::kQKHeaddim>{}, number<FmhaPipeline::kK3>{}),
-                    sequence<kPadHeadDimQ, kPadSeqLenQ>{});
-            }
-        }();
+        const auto q_dram = pad_tensor_view(
+            q_dram_naive,
+            make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
+            sequence<kPadSeqLenQ, kPadHeadDimQ>{});

         const auto k_dram_naive = make_naive_tensor_view<address_space_enum::global>(
             k_ptr,
@@ -709,45 +724,10 @@ struct FmhaBwdDQDKDVKernel
             make_tuple(kargs.stride_k, 1),
             number<FmhaPipeline::kAlignmentK>{},
             number<1>{});
-        const auto k_dram = [&]() {
-            if constexpr(FmhaPipeline::kKLoadOnce)
-            {
-                return pad_tensor_view(
-                    k_dram_naive,
-                    make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
-                    sequence<kPadSeqLenK, kPadHeadDimQ>{});
-            }
-            else
-            {
-                return pad_tensor_view(
-                    k_dram_naive,
-                    make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{}),
-                    sequence<kPadSeqLenK, kPadHeadDimQ>{});
-            }
-        }();
-        const auto kt_dram_naive =
-            transform_tensor_view(k_dram_naive,
-                                  make_tuple(make_pass_through_transform(kargs.hdim_q),
-                                             make_pass_through_transform(kargs.seqlen_k)),
-                                  make_tuple(sequence<1>{}, sequence<0>{}),
-                                  make_tuple(sequence<0>{}, sequence<1>{}));
-        const auto kt_dram = [&]() {
-            if constexpr(FmhaPipeline::kKTLoadOnce)
-            {
-                return pad_tensor_view(
-                    kt_dram_naive,
-                    make_tuple(number<FmhaPipeline::kQKHeaddim>{}, number<FmhaPipeline::kN0>{}),
-                    sequence<kPadHeadDimQ, kPadSeqLenK>{});
-            }
-            else
-            {
-                return pad_tensor_view(
-                    kt_dram_naive,
-                    make_tuple(number<FmhaPipeline::kQKHeaddim>{}, number<FmhaPipeline::kK4>{}),
-                    sequence<kPadHeadDimQ, kPadSeqLenK>{});
-            }
-        }();
+        const auto k_dram = pad_tensor_view(
+            k_dram_naive,
+            make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
+            sequence<kPadSeqLenK, kPadHeadDimQ>{});

         const auto v_dram = [&]() {
             const auto v_dram_naive = make_naive_tensor_view<address_space_enum::global>(
@@ -756,20 +736,10 @@ struct FmhaBwdDQDKDVKernel
                 make_tuple(kargs.stride_v, 1),
                 number<FmhaPipeline::kAlignmentV>{},
                 number<1>{});
-            if constexpr(FmhaPipeline::kVLoadOnce)
-            {
-                return pad_tensor_view(
-                    v_dram_naive,
-                    make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
-                    sequence<kPadSeqLenK, kPadHeadDimV>{});
-            }
-            else
-            {
-                return pad_tensor_view(
-                    v_dram_naive,
-                    make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK2>{}),
-                    sequence<kPadSeqLenK, kPadHeadDimV>{});
-            }
+            return pad_tensor_view(
+                v_dram_naive,
+                make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
+                sequence<kPadSeqLenK, kPadHeadDimV>{});
         }();

         const auto lse_dram = [&]() {
@@ -792,145 +762,89 @@ struct FmhaBwdDQDKDVKernel
             make_tuple(kargs.stride_do, 1),
             number<FmhaPipeline::kAlignmentOGrad>{},
             number<1>{});
-        const auto do_dram = [&]() {
-            if constexpr(FmhaPipeline::kOGradLoadOnce)
-            {
-                return pad_tensor_view(
-                    do_dram_naive,
-                    make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddim>{}),
-                    sequence<kPadSeqLenQ, kPadHeadDimV>{});
-            }
-            else
-            {
-                return pad_tensor_view(
-                    do_dram_naive,
-                    make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK2>{}),
-                    sequence<kPadSeqLenQ, kPadHeadDimV>{});
-            }
-        }();
-        const auto dot_dram_naive =
-            transform_tensor_view(do_dram_naive,
-                                  make_tuple(make_pass_through_transform(kargs.hdim_v),
-                                             make_pass_through_transform(kargs.seqlen_q)),
-                                  make_tuple(sequence<1>{}, sequence<0>{}),
-                                  make_tuple(sequence<0>{}, sequence<1>{}));
-        const auto dot_dram = [&]() {
-            if constexpr(FmhaPipeline::kOGradTLoadOnce)
-            {
-                return pad_tensor_view(
-                    dot_dram_naive,
-                    make_tuple(number<FmhaPipeline::kVHeaddim>{}, number<FmhaPipeline::kM0>{}),
-                    sequence<kPadHeadDimV, kPadSeqLenQ>{});
-            }
-            else
-            {
-                return pad_tensor_view(
-                    dot_dram_naive,
-                    make_tuple(number<FmhaPipeline::kVHeaddim>{}, number<FmhaPipeline::kK1>{}),
-                    sequence<kPadHeadDimV, kPadSeqLenQ>{});
-            }
-        }();
-        auto dq_dram = [&]() {
-            const auto dq_dram_naive = make_naive_tensor_view<address_space_enum::global,
-                                                              memory_operation_enum::atomic_add>(
-                dq_ptr,
-                make_tuple(kargs.seqlen_q, kargs.hdim_q),
-                make_tuple(kargs.stride_q, 1),
-                number<FmhaPipeline::kAlignmentQGrad>{},
-                number<1>{});
-            return pad_tensor_view(
-                dq_dram_naive,
-                make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
-                sequence<kPadSeqLenQ, kPadHeadDimQ>{});
-        }();
+        const auto do_dram = pad_tensor_view(
+            do_dram_naive,
+            make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddim>{}),
+            sequence<kPadSeqLenQ, kPadHeadDimV>{});

         auto q_dram_window = make_tile_window(
             q_dram,
-            [&]() {
-                if constexpr(FmhaPipeline::kQLoadOnce)
-                    return make_tuple(number<FmhaPipeline::kM0>{},
-                                      number<FmhaPipeline::kQKHeaddim>{});
-                else
-                    return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK0>{});
-            }(),
+            make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
             {0, 0});

-        auto qt_dram_window =
-            make_tile_window(qt_dram,
-                             [&]() {
-                                 if constexpr(FmhaPipeline::kQTLoadOnce)
-                                     return make_tuple(number<FmhaPipeline::kQKHeaddim>{},
-                                                       number<FmhaPipeline::kM0>{});
-                                 else
-                                     return make_tuple(number<FmhaPipeline::kQKHeaddim>{},
-                                                       number<FmhaPipeline::kK3>{});
-                             }(),
-                             {0, 0});
-
         auto k_dram_window = make_tile_window(
             k_dram,
-            [&]() {
-                if constexpr(FmhaPipeline::kKLoadOnce)
-                    return make_tuple(number<FmhaPipeline::kN0>{},
-                                      number<FmhaPipeline::kQKHeaddim>{});
-                else
-                    return make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK0>{});
-            }(),
+            make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kQKHeaddim>{}),
             {i_n0, 0});

-        auto kt_dram_window =
-            make_tile_window(kt_dram,
-                             [&]() {
-                                 if constexpr(FmhaPipeline::kKTLoadOnce)
-                                     return make_tuple(number<FmhaPipeline::kQKHeaddim>{},
-                                                       number<FmhaPipeline::kN0>{});
-                                 else
-                                     return make_tuple(number<FmhaPipeline::kQKHeaddim>{},
-                                                       number<FmhaPipeline::kK4>{});
-                             }(),
-                             {0, i_n0});
-
         auto v_dram_window = make_tile_window(
             v_dram,
-            [&]() {
-                if constexpr(FmhaPipeline::kVLoadOnce)
-                    return make_tuple(number<FmhaPipeline::kN0>{},
-                                      number<FmhaPipeline::kVHeaddim>{});
-                else
-                    return make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kK2>{});
-            }(),
+            make_tuple(number<FmhaPipeline::kN0>{}, number<FmhaPipeline::kVHeaddim>{}),
             {i_n0, 0});

         auto do_dram_window = make_tile_window(
             do_dram,
-            [&]() {
-                if constexpr(FmhaPipeline::kOGradLoadOnce)
-                    return make_tuple(number<FmhaPipeline::kM0>{},
-                                      number<FmhaPipeline::kVHeaddim>{});
-                else
-                    return make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kK2>{});
-            }(),
+            make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kVHeaddim>{}),
             {0, 0});

-        auto dot_dram_window =
-            make_tile_window(dot_dram,
-                             [&]() {
-                                 if constexpr(FmhaPipeline::kOGradTLoadOnce)
-                                     return make_tuple(number<FmhaPipeline::kVHeaddim>{},
-                                                       number<FmhaPipeline::kM0>{});
-                                 else
-                                     return make_tuple(number<FmhaPipeline::kVHeaddim>{},
-                                                       number<FmhaPipeline::kK1>{});
-                             }(),
-                             {0, 0});
-
-        auto dq_dram_window = make_tile_window(
-            dq_dram,
-            make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
-            {0, 0});
+        auto dq_dram_window = [&, i_tile_n_ = i_tile_n, i_nhead_ = i_nhead]() {
+            if constexpr(kIsDeterministic)
+            {
+                AccDataType* dq_acc_ptr =
+                    reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) +
+                    static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
+                    static_cast<long_index_t>(i_tile_n_) * kargs.split_stride_dq_acc +
+                    batch_offset_dq_acc;
+
+                auto dq_acc_dram = [&]() {
+                    const auto dq_acc_dram_naive =
+                        make_naive_tensor_view<address_space_enum::global>(
+                            dq_acc_ptr,
+                            make_tuple(kargs.seqlen_q, kargs.hdim_q),
+                            make_tuple(kargs.stride_dq_acc, 1),
+                            number<FmhaPipeline::kAlignmentQGrad>{},
+                            number<1>{});
+                    return pad_tensor_view(
+                        dq_acc_dram_naive,
+                        make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
+                        sequence<kPadSeqLenQ, kPadHeadDimQ>{});
+                }();
+
+                return make_tile_window(
+                    dq_acc_dram,
+                    make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
+                    {0, 0});
+            }
+            else
+            {
+                AccDataType* dq_acc_ptr =
+                    reinterpret_cast<AccDataType*>(kargs.dq_acc_ptr) +
+                    static_cast<long_index_t>(i_nhead_) * kargs.nhead_stride_dq_acc +
+                    batch_offset_dq_acc;
+
+                auto dq_acc_dram = [&]() {
+                    const auto dq_acc_dram_naive =
+                        make_naive_tensor_view<address_space_enum::global,
+                                               memory_operation_enum::atomic_add>(
+                            dq_acc_ptr,
+                            make_tuple(kargs.seqlen_q, kargs.hdim_q),
+                            make_tuple(kargs.stride_dq_acc, 1),
+                            number<FmhaPipeline::kAlignmentQGrad>{},
+                            number<1>{});
+                    return pad_tensor_view(
+                        dq_acc_dram_naive,
+                        make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
+                        sequence<kPadSeqLenQ, kPadHeadDimQ>{});
+                }();
+
+                return make_tile_window(
+                    dq_acc_dram,
+                    make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kQKHeaddim>{}),
+                    {0, 0});
+            }
+        }();

         auto lse_dram_window =
             make_tile_window(lse_dram, make_tuple(number<FmhaPipeline::kM0>{}), {0});
...@@ -1008,9 +922,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -1008,9 +922,7 @@ struct FmhaBwdDQDKDVKernel
// TODO: how to use s_read? // TODO: how to use s_read?
AccDataType slope = *(reinterpret_cast<const AccDataType*>(kargs.alibi_slope_ptr) + AccDataType slope = *(reinterpret_cast<const AccDataType*>(kargs.alibi_slope_ptr) +
i_batch_ * kargs.alibi_slope_stride + i_nhead_); i_batch_ * kargs.alibi_slope_stride + i_nhead_);
#if CK_TILE_FMHA_FWD_FAST_EXP2
slope *= ck_tile::log2e_v<>; slope *= ck_tile::log2e_v<>;
#endif
if constexpr(kHasMask) if constexpr(kHasMask)
{ {
return make_alibi_from_lr_mask<AccDataType, false>(slope, return make_alibi_from_lr_mask<AccDataType, false>(slope,
...@@ -1033,35 +945,34 @@ struct FmhaBwdDQDKDVKernel ...@@ -1033,35 +945,34 @@ struct FmhaBwdDQDKDVKernel
}(); }();
// dropout // dropout
float rp_undrop = 1; float rp_undrop = 1;
float scale_rp_undrop = 1; float scale_rp_undrop = 1;
uint8_t p_undrop_in_uint8_t = std::numeric_limits<uint8_t>::max();
uint64_t drop_seed = 0;
uint64_t drop_offset = 0;
bool is_store_randval = false;
if constexpr(kHasDropout) if constexpr(kHasDropout)
{ {
rp_undrop = kargs.rp_undrop; rp_undrop = kargs.rp_undrop;
scale_rp_undrop = kargs.scale_rp_undrop; scale_rp_undrop = kargs.scale_rp_undrop;
p_undrop_in_uint8_t = kargs.p_undrop_in_uint8_t;
drop_seed = kargs.drop_seed;
drop_offset = kargs.drop_offset;
is_store_randval = kargs.is_store_randval;
} }
BlockDropout dropout(i_batch, auto dropout = [&, i_nhead_ = i_nhead, i_batch_ = i_batch]() {
i_nhead, if constexpr(kHasDropout)
kargs.num_head_q, {
drop_seed, return FmhaDropout{i_batch_,
drop_offset, i_nhead_,
rp_undrop, kargs.num_head_q,
p_undrop_in_uint8_t, kargs.drop_seed,
is_store_randval); kargs.drop_offset,
kargs.rp_undrop,
kargs.p_undrop_in_uint8_t};
}
else
{
return FmhaDropout{};
            }
}();
auto randval_dram_window = [&, i_nhead_ = i_nhead]() { auto randval_dram_window = [&, i_nhead_ = i_nhead]() {
constexpr auto randval_dram_window_lengths = constexpr auto randval_dram_window_lengths =
make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{}); make_tuple(number<FmhaPipeline::kM0>{}, number<FmhaPipeline::kN0>{});
if constexpr(kHasDropout) if constexpr(kIsStoreRandval)
{ {
RandValOutputDataType* rand_val_ptr = RandValOutputDataType* rand_val_ptr =
reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) + reinterpret_cast<RandValOutputDataType*>(kargs.rand_val_ptr) +
...@@ -1103,14 +1014,11 @@ struct FmhaBwdDQDKDVKernel ...@@ -1103,14 +1014,11 @@ struct FmhaBwdDQDKDVKernel
}(); }();
auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(q_dram_window, auto [dk_acc_tile, dv_acc_tile] = FmhaPipeline{}(q_dram_window,
qt_dram_window,
k_dram_window, k_dram_window,
kt_dram_window,
v_dram_window, v_dram_window,
bias_dram_window, bias_dram_window,
randval_dram_window, randval_dram_window,
do_dram_window, do_dram_window,
dot_dram_window,
lse_dram_window, lse_dram_window,
d_dram_window, d_dram_window,
dq_dram_window, dq_dram_window,
...@@ -1118,9 +1026,7 @@ struct FmhaBwdDQDKDVKernel ...@@ -1118,9 +1026,7 @@ struct FmhaBwdDQDKDVKernel
mask, mask,
position_encoding, position_encoding,
kargs.raw_scale, kargs.raw_scale,
#if CK_TILE_FMHA_FWD_FAST_EXP2
kargs.scale, kargs.scale,
#endif
rp_undrop, rp_undrop,
scale_rp_undrop, scale_rp_undrop,
smem_ptr, smem_ptr,
...@@ -1169,10 +1075,9 @@ struct FmhaBwdDQDKDVKernel ...@@ -1169,10 +1075,9 @@ struct FmhaBwdDQDKDVKernel
} }
}; };
template <typename TilePartitioner_, typename FmhaBwdOGradDotO_> template <typename FmhaBwdOGradDotO_>
struct FmhaBwdOGradDotOKernel struct FmhaBwdOGradDotOKernel
{ {
using TilePartitioner = ck_tile::remove_cvref_t<TilePartitioner_>;
using FmhaBwdOGradDotO = ck_tile::remove_cvref_t<FmhaBwdOGradDotO_>; using FmhaBwdOGradDotO = ck_tile::remove_cvref_t<FmhaBwdOGradDotO_>;
static constexpr ck_tile::index_t kBlockSize = FmhaBwdOGradDotO::kBlockSize; static constexpr ck_tile::index_t kBlockSize = FmhaBwdOGradDotO::kBlockSize;
static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdOGradDotO::kBlockPerCu; static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdOGradDotO::kBlockPerCu;
...@@ -1234,13 +1139,13 @@ struct FmhaBwdOGradDotOKernel ...@@ -1234,13 +1139,13 @@ struct FmhaBwdOGradDotOKernel
ck_tile::index_t nhead_stride_do; ck_tile::index_t nhead_stride_do;
ck_tile::index_t nhead_stride_o; ck_tile::index_t nhead_stride_o;
ck_tile::index_t nhead_stride_d; ck_tile::index_t nhead_stride_d;
ck_tile::index_t batch_stride_d;
}; };
struct FmhaBwdOGradDotOBatchModeKargs : FmhaBwdOGradDotOCommonKargs struct FmhaBwdOGradDotOBatchModeKargs : FmhaBwdOGradDotOCommonKargs
{ {
ck_tile::index_t batch_stride_do; ck_tile::index_t batch_stride_do;
ck_tile::index_t batch_stride_o; ck_tile::index_t batch_stride_o;
ck_tile::index_t batch_stride_d;
}; };
struct FmhaBwdOGradDotOGroupModeKargs : FmhaBwdOGradDotOCommonKargs struct FmhaBwdOGradDotOGroupModeKargs : FmhaBwdOGradDotOCommonKargs
...@@ -1278,10 +1183,10 @@ struct FmhaBwdOGradDotOKernel ...@@ -1278,10 +1183,10 @@ struct FmhaBwdOGradDotOKernel
stride_o, stride_o,
nhead_stride_do, nhead_stride_do,
nhead_stride_o, nhead_stride_o,
nhead_stride_d, nhead_stride_d},
batch_stride_d},
batch_stride_do, batch_stride_do,
batch_stride_o}; batch_stride_o,
batch_stride_d};
return kargs; return kargs;
} }
...@@ -1298,8 +1203,7 @@ struct FmhaBwdOGradDotOKernel ...@@ -1298,8 +1203,7 @@ struct FmhaBwdOGradDotOKernel
ck_tile::index_t stride_o, ck_tile::index_t stride_o,
ck_tile::index_t nhead_stride_do, ck_tile::index_t nhead_stride_do,
ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_o,
ck_tile::index_t nhead_stride_d, ck_tile::index_t nhead_stride_d)
ck_tile::index_t batch_stride_d)
{ {
Kargs kargs{{o_ptr, Kargs kargs{{o_ptr,
do_ptr, do_ptr,
...@@ -1311,8 +1215,7 @@ struct FmhaBwdOGradDotOKernel ...@@ -1311,8 +1215,7 @@ struct FmhaBwdOGradDotOKernel
stride_o, stride_o,
nhead_stride_do, nhead_stride_do,
nhead_stride_o, nhead_stride_o,
nhead_stride_d, nhead_stride_d},
batch_stride_d},
reinterpret_cast<const int32_t*>(seqstart_q_ptr)}; reinterpret_cast<const int32_t*>(seqstart_q_ptr)};
return kargs; return kargs;
...@@ -1321,7 +1224,16 @@ struct FmhaBwdOGradDotOKernel ...@@ -1321,7 +1224,16 @@ struct FmhaBwdOGradDotOKernel
CK_TILE_HOST static constexpr auto CK_TILE_HOST static constexpr auto
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_) GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
{ {
return TilePartitioner::GridSize(batch_size_, nhead_, seqlen_q_); return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_);
}
CK_TILE_DEVICE static constexpr auto GetTileIndex()
{
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
} }
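    // A minimal sketch of how GridSize/GetTileIndex cooperate (illustrative
    // numbers, not from this commit): with seqlen_q = 1000 and kM0 = 64,
    // gridDim.x = ceil(1000 / 64) = 16 query tiles per (head, batch), and the
    // block at (i_block, i_nhead, i_batch) covers query rows
    // [i_block * kM0, i_block * kM0 + kM0), with padding handling the tail.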
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); } CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
...@@ -1331,7 +1243,7 @@ struct FmhaBwdOGradDotOKernel ...@@ -1331,7 +1243,7 @@ struct FmhaBwdOGradDotOKernel
CK_TILE_DEVICE void operator()(Kargs kargs) const CK_TILE_DEVICE void operator()(Kargs kargs) const
{ {
// divide problem // divide problem
const auto [i_tile_m, i_nhead, i_batch] = TilePartitioner{}(kargs.seqlen_q); const auto [i_tile_m, i_nhead, i_batch] = GetTileIndex();
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0); const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0);
...@@ -1346,7 +1258,7 @@ struct FmhaBwdOGradDotOKernel ...@@ -1346,7 +1258,7 @@ struct FmhaBwdOGradDotOKernel
batch_offset_o = query_start * kargs.stride_o; batch_offset_o = query_start * kargs.stride_o;
batch_offset_do = query_start * kargs.stride_do; batch_offset_do = query_start * kargs.stride_do;
batch_offset_d = static_cast<long_index_t>(i_batch) * kargs.batch_stride_d; batch_offset_d = query_start;
// get real # queries & # keys under group mode // get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
...@@ -1418,4 +1330,315 @@ struct FmhaBwdOGradDotOKernel ...@@ -1418,4 +1330,315 @@ struct FmhaBwdOGradDotOKernel
} }
}; };
template <typename FmhaBwdConvertQGrad_>
struct FmhaBwdConvertQGradKernel
{
using FmhaBwdConvertQGrad = ck_tile::remove_cvref_t<FmhaBwdConvertQGrad_>;
static constexpr ck_tile::index_t kBlockSize = FmhaBwdConvertQGrad::kBlockSize;
static constexpr ck_tile::index_t kBlockPerCu = FmhaBwdConvertQGrad::kBlockPerCu;
static constexpr ck_tile::index_t kM0 = FmhaBwdConvertQGrad::kM0;
static constexpr ck_tile::index_t kN0 = FmhaBwdConvertQGrad::kN0;
static constexpr ck_tile::index_t kQKHeaddim = FmhaBwdConvertQGrad::kQKHeaddim;
using AccDataType = ck_tile::remove_cvref_t<typename FmhaBwdConvertQGrad::AccDataType>;
using QGradDataType = ck_tile::remove_cvref_t<typename FmhaBwdConvertQGrad::QGradDataType>;
static constexpr bool kIsGroupMode = FmhaBwdConvertQGrad::kIsGroupMode;
static constexpr bool kPadSeqLenQ = FmhaBwdConvertQGrad::kPadSeqLenQ;
static constexpr bool kPadHeadDimQ = FmhaBwdConvertQGrad::kPadHeadDimQ;
static constexpr bool kIsDeterministic = FmhaBwdConvertQGrad::kIsDeterministic;
// clang-format off
template <typename T> struct t2s;
template <> struct t2s<ck_tile::fp16_t> { static constexpr const char * name = "fp16"; };
template <> struct t2s<ck_tile::bf16_t> { static constexpr const char * name = "bf16"; };
// clang-format on
CK_TILE_HOST static std::string GetName()
{
// sync with generate.py
// clang-format off
#define _SS_ std::string
#define _TS_ std::to_string
auto pn = [&] () {
std::string n;
if (kPadSeqLenQ) n += "s";
if (kPadHeadDimQ) n += "d";
return n.empty() ? n : std::string("p") + n; }();
return
_SS_("fmha_bwd_convert_dq_d") + _TS_(kQKHeaddim) + "_" + _SS_(t2s<QGradDataType>::name) +
"_" + (kIsGroupMode ? "group" : "batch") + (kIsDeterministic ? "_deterministic" : "") + "_" +
("o" + _TS_(kBlockPerCu)) + (pn.empty() ? "" : "_" + pn);
#undef _SS_
#undef _TS_
// clang-format on
}
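    // For illustration (a hypothetical instantiation, not emitted by this
    // code as-is): kQKHeaddim = 128, fp16, batch mode, non-deterministic,
    // kBlockPerCu = 1 and no padding would yield
    // "fmha_bwd_convert_dq_d128_fp16_batch_o1".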
    // to avoid the duplicated-base-class problem, introduce a template arg
template <ck_tile::index_t I>
struct FmhaBwdConvertQGradEmptyKargs
{
};
    // kargs use aggregate initialization, so no constructor is provided;
    // inheritance is used to minimize the karg size,
    // and users must call the MakeKargs() function to create kargs.
struct FmhaBwdConvertQGradCommonKargs
{
const void* dq_acc_ptr;
void* dq_ptr;
ck_tile::index_t seqlen_q;
ck_tile::index_t seqlen_k;
ck_tile::index_t hdim_q;
ck_tile::index_t stride_dq;
ck_tile::index_t stride_dq_acc;
ck_tile::index_t nhead_stride_dq;
ck_tile::index_t nhead_stride_dq_acc;
};
struct FmhaBwdConvertQGradDeterministicKargs
{
ck_tile::index_t split_stride_dq_acc = 0;
};
struct FmhaBwdConvertQGradBatchModeKargs
: FmhaBwdConvertQGradCommonKargs,
std::conditional_t<kIsDeterministic,
FmhaBwdConvertQGradDeterministicKargs,
FmhaBwdConvertQGradEmptyKargs<0>>
{
ck_tile::index_t batch_stride_dq;
ck_tile::index_t batch_stride_dq_acc;
};
struct FmhaBwdConvertQGradGroupModeKargs
: FmhaBwdConvertQGradCommonKargs,
std::conditional_t<kIsDeterministic,
FmhaBwdConvertQGradDeterministicKargs,
FmhaBwdConvertQGradEmptyKargs<0>>
{
const int32_t* seqstart_q_ptr;
const int32_t* seqstart_k_ptr;
};
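    // note: seqstart_k_ptr above is only consumed in deterministic mode, where
    // the per-batch seqlen_k decides how many kN0-wide dQAcc splits need to be
    // reduced (see operator() below)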
using Kargs = std::conditional_t<kIsGroupMode,
FmhaBwdConvertQGradGroupModeKargs,
FmhaBwdConvertQGradBatchModeKargs>;
template <bool Cond = !kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* dq_acc_ptr,
void* dq_ptr,
ck_tile::index_t seqlen_q,
ck_tile::index_t seqlen_k,
ck_tile::index_t hdim_q,
ck_tile::index_t stride_dq,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t nhead_stride_dq,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t batch_stride_dq,
ck_tile::index_t batch_stride_dq_acc,
ck_tile::index_t split_stride_dq_acc)
{
Kargs kargs{{dq_acc_ptr,
dq_ptr,
seqlen_q,
seqlen_k,
hdim_q,
stride_dq,
stride_dq_acc,
nhead_stride_dq,
nhead_stride_dq_acc},
{},
batch_stride_dq,
batch_stride_dq_acc};
if constexpr(kIsDeterministic)
{
kargs.split_stride_dq_acc = split_stride_dq_acc;
}
return kargs;
}
template <bool Cond = kIsGroupMode>
CK_TILE_HOST static constexpr std::enable_if_t<Cond, Kargs>
MakeKargs(const void* dq_acc_ptr,
void* dq_ptr,
const void* seqstart_q_ptr,
const void* seqstart_k_ptr,
ck_tile::index_t hdim_q,
ck_tile::index_t stride_dq,
ck_tile::index_t stride_dq_acc,
ck_tile::index_t nhead_stride_dq,
ck_tile::index_t nhead_stride_dq_acc,
ck_tile::index_t split_stride_dq_acc)
{
Kargs kargs{{dq_acc_ptr,
dq_ptr,
                    -1, // seqlen_q will be derived from seqstart_q_ptr at runtime
                    -1, // seqlen_k likewise, from seqstart_k_ptr
hdim_q,
stride_dq,
stride_dq_acc,
nhead_stride_dq,
nhead_stride_dq_acc},
{},
reinterpret_cast<const int32_t*>(seqstart_q_ptr),
reinterpret_cast<const int32_t*>(seqstart_k_ptr)};
if constexpr(kIsDeterministic)
{
kargs.split_stride_dq_acc = split_stride_dq_acc;
}
return kargs;
}
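    // An illustrative host-side call (hypothetical names and sizes, assuming a
    // packed [batch, nhead, seqlen, hdim] layout with nhead = 8; `Kernel`
    // stands for a concrete FmhaBwdConvertQGradKernel instantiation):
    //
    //   auto kargs = Kernel::MakeKargs(dq_acc_buf, dq_buf,
    //                                  /*seqlen_q=*/512, /*seqlen_k=*/512,
    //                                  /*hdim_q=*/64,
    //                                  /*stride_dq=*/64, /*stride_dq_acc=*/64,
    //                                  /*nhead_stride_dq=*/512 * 64,
    //                                  /*nhead_stride_dq_acc=*/512 * 64,
    //                                  /*batch_stride_dq=*/8 * 512 * 64,
    //                                  /*batch_stride_dq_acc=*/8 * 512 * 64,
    //                                  /*split_stride_dq_acc=*/0);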
CK_TILE_HOST static constexpr auto
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
{
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kM0), nhead_, batch_size_);
}
CK_TILE_DEVICE static constexpr auto GetTileIndex()
{
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
}
CK_TILE_HOST static constexpr auto BlockSize() { return dim3(kBlockSize); }
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
CK_TILE_DEVICE void operator()(Kargs kargs) const
{
// divide problem
const auto [i_tile_m, i_nhead, i_batch] = GetTileIndex();
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * kM0);
long_index_t batch_offset_dq = 0;
long_index_t batch_offset_dq_acc = 0;
if constexpr(kIsGroupMode)
{
// get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
batch_offset_dq = query_start * kargs.stride_dq;
batch_offset_dq_acc = query_start * kargs.stride_dq_acc;
// get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
kargs.seqlen_q = adjusted_seqstart_q_ptr[1] - adjusted_seqstart_q_ptr[0];
if constexpr(kIsDeterministic)
{
const auto adjusted_seqstart_k_ptr = kargs.seqstart_k_ptr + i_batch;
kargs.seqlen_k = adjusted_seqstart_k_ptr[1] - adjusted_seqstart_k_ptr[0];
}
            // the number of required blocks differs between groups; terminate
            // unnecessary blocks early
if(kargs.seqlen_q <= i_m0)
{
return;
}
}
else
{
batch_offset_dq = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq;
batch_offset_dq_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_dq_acc;
}
        // for simplicity, apply the batch/head strides by just offsetting the pointer
QGradDataType* dq_ptr = reinterpret_cast<QGradDataType*>(kargs.dq_ptr) +
static_cast<long_index_t>(i_nhead) * kargs.nhead_stride_dq +
batch_offset_dq;
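        // e.g. (hypothetical figures) with a packed [batch, nhead, seqlen, hdim]
        // layout, nhead_stride_dq = seqlen_q * hdim_q, so head i_nhead of batch
        // i_batch begins i_nhead * nhead_stride_dq + batch_offset_dq elements
        // past the base pointer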
// dQAcc/dQ DRAM and DRAM window
const auto dq_acc_dram = [&, i_nhead_ = i_nhead]() {
if constexpr(kIsDeterministic)
{
const AccDataType* dq_acc_ptr =
reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) +
static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq_acc) +
batch_offset_dq_acc;
const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0);
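                // every kN0-wide tile of keys wrote its own dQ partial, so in
                // deterministic mode dQAcc is viewed as a 3D tensor of shape
                // [nsplits, seqlen_q, hdim_q] and reduced over the split axis
                // by the pipeline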
auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
dq_acc_ptr,
make_tuple(nsplits, kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.split_stride_dq_acc, kargs.stride_dq_acc, 1),
number<FmhaBwdConvertQGrad::kAlignmentQGradAcc>{},
number<1>{});
return pad_tensor_view(dq_acc_dram_naive,
make_tuple(number<1>{}, number<kM0>{}, number<kQKHeaddim>{}),
sequence<false, kPadSeqLenQ, kPadHeadDimQ>{});
}
else
{
const AccDataType* dq_acc_ptr =
reinterpret_cast<const AccDataType*>(kargs.dq_acc_ptr) +
static_cast<long_index_t>(i_nhead_) * (kargs.nhead_stride_dq_acc) +
batch_offset_dq_acc;
auto dq_acc_dram_naive = make_naive_tensor_view<address_space_enum::global>(
dq_acc_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.stride_dq_acc, 1),
number<FmhaBwdConvertQGrad::kAlignmentQGradAcc>{},
number<1>{});
return pad_tensor_view(dq_acc_dram_naive,
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}
}();
auto dq_dram = [&]() {
auto dq_dram_naive = make_naive_tensor_view<address_space_enum::global>(
dq_ptr,
make_tuple(kargs.seqlen_q, kargs.hdim_q),
make_tuple(kargs.stride_dq, 1),
number<FmhaBwdConvertQGrad::kAlignmentQGrad>{},
number<1>{});
return pad_tensor_view(dq_dram_naive,
make_tuple(number<kM0>{}, number<kQKHeaddim>{}),
sequence<kPadSeqLenQ, kPadHeadDimQ>{});
}();
auto dq_acc_dram_window = [&]() {
if constexpr(kIsDeterministic)
{
return make_tile_window(
dq_acc_dram,
make_tuple(number<1>{}, number<kM0>{}, number<kQKHeaddim>{}),
{0, i_m0, 0});
}
else
{
return make_tile_window(
dq_acc_dram, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {i_m0, 0});
}
}();
auto dq_dram_window =
make_tile_window(dq_dram, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {i_m0, 0});
if constexpr(kIsDeterministic)
{
const index_t nsplits = ck_tile::integer_divide_ceil(kargs.seqlen_k, kN0);
FmhaBwdConvertQGrad{}(dq_acc_dram_window, dq_dram_window, nsplits);
}
else
{
FmhaBwdConvertQGrad{}(dq_acc_dram_window, dq_dram_window);
}
}
};
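
// Illustrative CPU reference (a sketch for exposition only, not part of the
// kernel above) of what FmhaBwdConvertQGradKernel computes per head in
// deterministic mode: dq[m][d] = convert(sum_s dq_acc[s][m][d]).
template <typename AccT, typename OutT>
void fmha_bwd_convert_dq_reference(const AccT* dq_acc, // [nsplits, seqlen_q, hdim_q]
                                   OutT* dq,           // [seqlen_q, hdim_q]
                                   int nsplits, int seqlen_q, int hdim_q)
{
    for(int m = 0; m < seqlen_q; ++m)
        for(int d = 0; d < hdim_q; ++d)
        {
            AccT acc = 0;
            // accumulate the partial dQ produced by each key split
            for(int s = 0; s < nsplits; ++s)
                acc += dq_acc[(static_cast<long>(s) * seqlen_q + m) * hdim_q + d];
            dq[m * hdim_q + d] = static_cast<OutT>(acc);
        }
}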
} // namespace ck_tile } // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename BlockFmhaShape_>
struct FmhaBwdTilePartitioner
{
using BlockFmhaShape = ck_tile::remove_cvref_t<BlockFmhaShape_>;
static constexpr ck_tile::index_t kN0 = BlockFmhaShape::kN0;
CK_TILE_HOST static constexpr auto
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_k_)
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_k_, kN0), nhead_, batch_size_);
}
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_k*/)
{
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
}
};
template <ck_tile::index_t kBlockSize>
struct FmhaBwdOGradDotOTilePartitioner
{
CK_TILE_HOST static constexpr auto
GridSize(ck_tile::index_t batch_size_, ck_tile::index_t nhead_, ck_tile::index_t seqlen_q_)
{
// TODO: this may need tuning
return dim3(ck_tile::integer_divide_ceil(seqlen_q_, kBlockSize), nhead_, batch_size_);
}
CK_TILE_DEVICE auto operator()(ck_tile::index_t /*seqlen_q*/)
{
const index_t i_block = blockIdx.x;
const index_t i_nhead = blockIdx.y;
const index_t i_batch = blockIdx.z;
return ck_tile::make_tuple(i_block, i_nhead, i_batch);
}
};
} // namespace ck_tile
...@@ -86,7 +86,7 @@ struct FmhaFwdKernel ...@@ -86,7 +86,7 @@ struct FmhaFwdKernel
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" ); (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kStoreLSE ? "_lse" : "" ) + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
#undef _SS_ #undef _SS_
#undef _TS_ #undef _TS_
...@@ -387,7 +387,6 @@ struct FmhaFwdKernel ...@@ -387,7 +387,6 @@ struct FmhaFwdKernel
ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t window_size_left, ck_tile::index_t window_size_left,
ck_tile::index_t window_size_right, ck_tile::index_t window_size_right,
ck_tile::index_t mask_type, ck_tile::index_t mask_type,
...@@ -448,7 +447,6 @@ struct FmhaFwdKernel ...@@ -448,7 +447,6 @@ struct FmhaFwdKernel
{ {
kargs.lse_ptr = lse_ptr; kargs.lse_ptr = lse_ptr;
kargs.nhead_stride_lse = nhead_stride_lse; kargs.nhead_stride_lse = nhead_stride_lse;
kargs.batch_stride_lse = batch_stride_lse;
} }
if constexpr(kDoFp8StaticQuant) if constexpr(kDoFp8StaticQuant)
{ {
...@@ -524,7 +522,7 @@ struct FmhaFwdKernel ...@@ -524,7 +522,7 @@ struct FmhaFwdKernel
} }
if constexpr(kStoreLSE) if constexpr(kStoreLSE)
{ {
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse; batch_offset_lse = query_start;
} }
if constexpr(kHasDropout) if constexpr(kHasDropout)
{ {
......
...@@ -55,7 +55,7 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -55,7 +55,7 @@ struct FmhaFwdSplitKVCombineKernel
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) +
_SS_(FmhaPipeline::name) + _SS_(FmhaPipeline::name) +
(pn.empty() ? "" : "_" + pn) + (pn.empty() ? "" : "_" + pn) +
(kStoreLSE ? "_lse" : "" ) + (kStoreLSE ? "_lse" : "" ) +
(kDoFp8StaticQuant ? "_squant" : "" ); (kDoFp8StaticQuant ? "_squant" : "" );
#undef _SS_ #undef _SS_
#undef _TS_ #undef _TS_
...@@ -91,7 +91,6 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -91,7 +91,6 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile::index_t nhead_stride_o_acc; ck_tile::index_t nhead_stride_o_acc;
ck_tile::index_t nhead_stride_o; ck_tile::index_t nhead_stride_o;
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc; ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t split_stride_lse_acc; ck_tile::index_t split_stride_lse_acc;
...@@ -116,6 +115,7 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -116,6 +115,7 @@ struct FmhaFwdSplitKVCombineKernel
std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>> std::conditional_t<kDoFp8StaticQuant, Fp8StaticQuantKargs, EmptyKargs<1>>
{ {
ck_tile::index_t batch_stride_o; ck_tile::index_t batch_stride_o;
ck_tile::index_t batch_stride_lse_acc;
}; };
struct GroupModeKargs struct GroupModeKargs
...@@ -166,13 +166,13 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -166,13 +166,13 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc, nhead_stride_lse_acc,
nhead_stride_o_acc, nhead_stride_o_acc,
nhead_stride_o, nhead_stride_o,
batch_stride_lse_acc,
batch_stride_o_acc, batch_stride_o_acc,
split_stride_lse_acc, split_stride_lse_acc,
split_stride_o_acc}, // args for common karg split_stride_o_acc}, // args for common karg
{}, // placeholder for lse {}, // placeholder for lse
{}, // placeholder for fp8_static_quant args {}, // placeholder for fp8_static_quant args
batch_stride_o}; batch_stride_o,
batch_stride_lse_acc};
if constexpr(kStoreLSE) if constexpr(kStoreLSE)
{ {
...@@ -206,9 +206,7 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -206,9 +206,7 @@ struct FmhaFwdSplitKVCombineKernel
ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t nhead_stride_o_acc,
ck_tile::index_t nhead_stride_lse, ck_tile::index_t nhead_stride_lse,
ck_tile::index_t nhead_stride_o, ck_tile::index_t nhead_stride_o,
ck_tile::index_t batch_stride_lse_acc,
ck_tile::index_t batch_stride_o_acc, ck_tile::index_t batch_stride_o_acc,
ck_tile::index_t batch_stride_lse,
ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_lse_acc,
ck_tile::index_t split_stride_o_acc) ck_tile::index_t split_stride_o_acc)
{ {
...@@ -225,7 +223,6 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -225,7 +223,6 @@ struct FmhaFwdSplitKVCombineKernel
nhead_stride_lse_acc, nhead_stride_lse_acc,
nhead_stride_o_acc, nhead_stride_o_acc,
nhead_stride_o, nhead_stride_o,
batch_stride_lse_acc,
batch_stride_o_acc, batch_stride_o_acc,
split_stride_lse_acc, split_stride_lse_acc,
split_stride_o_acc}, // args for common karg split_stride_o_acc}, // args for common karg
...@@ -237,7 +234,6 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -237,7 +234,6 @@ struct FmhaFwdSplitKVCombineKernel
{ {
kargs.lse_ptr = lse_ptr; kargs.lse_ptr = lse_ptr;
kargs.nhead_stride_lse = nhead_stride_lse; kargs.nhead_stride_lse = nhead_stride_lse;
kargs.batch_stride_lse = batch_stride_lse;
} }
if constexpr(kDoFp8StaticQuant) if constexpr(kDoFp8StaticQuant)
{ {
...@@ -274,24 +270,25 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -274,24 +270,25 @@ struct FmhaFwdSplitKVCombineKernel
const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0); const index_t i_m0 = __builtin_amdgcn_readfirstlane(i_tile_m * FmhaPipeline::kM0);
const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1); const index_t i_n1 = __builtin_amdgcn_readfirstlane(i_tile_n * FmhaPipeline::kN1);
const long_index_t batch_offset_lse_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
const long_index_t batch_offset_o_acc = const long_index_t batch_offset_o_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc; static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
long_index_t batch_offset_lse = 0;
long_index_t batch_offset_o = 0;
if constexpr(kStoreLSE) long_index_t batch_offset_lse_acc = 0;
{ long_index_t batch_offset_lse = 0;
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse; long_index_t batch_offset_o = 0;
}
if constexpr(kIsGroupMode) if constexpr(kIsGroupMode)
{ {
// get starting offset for each batch // get starting offset for each batch
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
batch_offset_o = query_start * kargs.row_stride_o; batch_offset_o = query_start * kargs.row_stride_o;
batch_offset_lse_acc = query_start;
if constexpr(kStoreLSE)
{
batch_offset_lse = query_start;
}
// get real # queries & # keys under group mode // get real # queries & # keys under group mode
const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch; const auto adjusted_seqstart_q_ptr = kargs.seqstart_q_ptr + i_batch;
...@@ -306,7 +303,13 @@ struct FmhaFwdSplitKVCombineKernel ...@@ -306,7 +303,13 @@ struct FmhaFwdSplitKVCombineKernel
} }
else else
{ {
batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o; batch_offset_o = static_cast<long_index_t>(i_batch) * kargs.batch_stride_o;
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
if constexpr(kStoreLSE)
{
batch_offset_lse = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse;
}
} }
        // for simplicity, apply the batch stride by just offsetting the pointer // for simplicity, apply the batch stride by just offsetting the pointer
......
...@@ -85,7 +85,7 @@ struct FmhaFwdSplitKVKernel ...@@ -85,7 +85,7 @@ struct FmhaFwdSplitKVKernel
"w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" + "w" + _TS_(gwt::at(ck_tile::number<0>{})) + "x" + _TS_(gwt::at(ck_tile::number<1>{})) + "x" + _TS_(gwt::at(ck_tile::number<2>{})) + "_" +
(kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" + (kBlockPerCuInput == -1 ? "" : ("o" + _TS_(kBlockPerCu) + "_")) + _SS_(FmhaPipeline::name) + "_" +
"v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) + "v" + (std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor> ? "r" : "c") + (pn.empty() ? "" : "_" + pn) +
(BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) + (BiasEnum == BlockAttentionBiasEnum::NO_BIAS ? _SS_("") : (_SS_("_") + BlockAttentionBiasEnumToStr<BiasEnum>::name)) +
(kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" ); (kHasMask ? "_" + _SS_(FmhaMask::name) : "") + (kHasDropout ? "_dropout" : "" ) + (kDoFp8StaticQuant ? "_squant" : "" );
#undef _SS_ #undef _SS_
#undef _TS_ #undef _TS_
...@@ -136,7 +136,6 @@ struct FmhaFwdSplitKVKernel ...@@ -136,7 +136,6 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t nhead_stride_lse_acc; ck_tile::index_t nhead_stride_lse_acc;
ck_tile::index_t nhead_stride_o_acc; ck_tile::index_t nhead_stride_o_acc;
ck_tile::index_t batch_stride_lse_acc;
ck_tile::index_t batch_stride_o_acc; ck_tile::index_t batch_stride_o_acc;
ck_tile::index_t split_stride_lse_acc; ck_tile::index_t split_stride_lse_acc;
...@@ -216,6 +215,7 @@ struct FmhaFwdSplitKVKernel ...@@ -216,6 +215,7 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t batch_stride_q; ck_tile::index_t batch_stride_q;
ck_tile::index_t batch_stride_k; ck_tile::index_t batch_stride_k;
ck_tile::index_t batch_stride_v; ck_tile::index_t batch_stride_v;
ck_tile::index_t batch_stride_lse_acc;
}; };
struct GroupModeKargs struct GroupModeKargs
...@@ -313,7 +313,6 @@ struct FmhaFwdSplitKVKernel ...@@ -313,7 +313,6 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v, nhead_stride_v,
nhead_stride_lse_acc, nhead_stride_lse_acc,
nhead_stride_o_acc, nhead_stride_o_acc,
batch_stride_lse_acc,
batch_stride_o_acc, batch_stride_o_acc,
split_stride_lse_acc, split_stride_lse_acc,
split_stride_o_acc}, // args for common karg split_stride_o_acc}, // args for common karg
...@@ -323,7 +322,8 @@ struct FmhaFwdSplitKVKernel ...@@ -323,7 +322,8 @@ struct FmhaFwdSplitKVKernel
{}, // placeholder for dropout {}, // placeholder for dropout
batch_stride_q, batch_stride_q,
batch_stride_k, batch_stride_k,
batch_stride_v}; batch_stride_v,
batch_stride_lse_acc};
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
...@@ -394,7 +394,6 @@ struct FmhaFwdSplitKVKernel ...@@ -394,7 +394,6 @@ struct FmhaFwdSplitKVKernel
ck_tile::index_t nhead_stride_randval, ck_tile::index_t nhead_stride_randval,
ck_tile::index_t nhead_stride_lse_acc, ck_tile::index_t nhead_stride_lse_acc,
ck_tile::index_t nhead_stride_o_acc, ck_tile::index_t nhead_stride_o_acc,
ck_tile::index_t batch_stride_lse_acc,
ck_tile::index_t batch_stride_o_acc, ck_tile::index_t batch_stride_o_acc,
ck_tile::index_t split_stride_lse_acc, ck_tile::index_t split_stride_lse_acc,
ck_tile::index_t split_stride_o_acc, ck_tile::index_t split_stride_o_acc,
...@@ -433,7 +432,6 @@ struct FmhaFwdSplitKVKernel ...@@ -433,7 +432,6 @@ struct FmhaFwdSplitKVKernel
nhead_stride_v, nhead_stride_v,
nhead_stride_lse_acc, nhead_stride_lse_acc,
nhead_stride_o_acc, nhead_stride_o_acc,
batch_stride_lse_acc,
batch_stride_o_acc, batch_stride_o_acc,
split_stride_lse_acc, split_stride_lse_acc,
split_stride_o_acc}, // args for common karg split_stride_o_acc}, // args for common karg
...@@ -511,8 +509,7 @@ struct FmhaFwdSplitKVKernel ...@@ -511,8 +509,7 @@ struct FmhaFwdSplitKVKernel
long_index_t batch_offset_v = 0; long_index_t batch_offset_v = 0;
long_index_t batch_offset_bias = 0; long_index_t batch_offset_bias = 0;
long_index_t batch_offset_randval = 0; long_index_t batch_offset_randval = 0;
const long_index_t batch_offset_lse_acc = long_index_t batch_offset_lse_acc = 0;
static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
const long_index_t batch_offset_o_acc = const long_index_t batch_offset_o_acc =
static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc; static_cast<long_index_t>(i_batch) * kargs.batch_stride_o_acc;
...@@ -522,8 +519,9 @@ struct FmhaFwdSplitKVKernel ...@@ -522,8 +519,9 @@ struct FmhaFwdSplitKVKernel
const long_index_t query_start = kargs.seqstart_q_ptr[i_batch]; const long_index_t query_start = kargs.seqstart_q_ptr[i_batch];
const long_index_t key_start = kargs.seqstart_k_ptr[i_batch]; const long_index_t key_start = kargs.seqstart_k_ptr[i_batch];
batch_offset_q = query_start * kargs.stride_q; batch_offset_q = query_start * kargs.stride_q;
batch_offset_k = key_start * kargs.stride_k; batch_offset_k = key_start * kargs.stride_k;
batch_offset_lse_acc = query_start;
if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<VLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{ {
batch_offset_v = key_start * kargs.stride_v; batch_offset_v = key_start * kargs.stride_v;
...@@ -564,9 +562,10 @@ struct FmhaFwdSplitKVKernel ...@@ -564,9 +562,10 @@ struct FmhaFwdSplitKVKernel
} }
else else
{ {
batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q; batch_offset_q = static_cast<long_index_t>(i_batch) * kargs.batch_stride_q;
batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k; batch_offset_k = static_cast<long_index_t>(i_batch) * kargs.batch_stride_k;
batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v; batch_offset_v = static_cast<long_index_t>(i_batch) * kargs.batch_stride_v;
batch_offset_lse_acc = static_cast<long_index_t>(i_batch) * kargs.batch_stride_lse_acc;
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{ {
batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias; batch_offset_bias = static_cast<long_index_t>(i_batch) * kargs.batch_stride_bias;
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
struct BlockFmhaBwdConvertQGrad
{
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
static constexpr index_t kM0 = Problem::kM0;
static constexpr index_t kN0 = Problem::kN0;
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kQKHeaddim = Problem::kQKHeaddim;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
static constexpr index_t kAlignmentQGradAcc =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentPostQGradAcc<Problem>();
static constexpr index_t kAlignmentQGrad =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentPostQGrad<Problem>();
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
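    // the convert/reduce pipeline below works purely on register tiles loaded
    // straight from DRAM, hence no LDS is required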
// Convert only
template <typename QGradAccDramBlockWindowTmp, typename QGradDramBlockWindowTmp>
CK_TILE_HOST_DEVICE void
operator()(const QGradAccDramBlockWindowTmp& dq_acc_dram_block_window_tmp,
QGradDramBlockWindowTmp& dq_dram_block_window_tmp) const
{
static_assert(
std::is_same_v<AccDataType,
remove_cvref_t<typename QGradAccDramBlockWindowTmp::DataType>> &&
std::is_same_v<QGradDataType,
remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!");
auto dq_acc_dram_window =
make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
dq_acc_dram_block_window_tmp.get_window_lengths(),
dq_acc_dram_block_window_tmp.get_window_origin(),
Policy::template MakePostQGradDramTileDistribution<Problem>());
auto dq_acc = load_tile(dq_acc_dram_window);
const auto dq = cast_tile<QGradDataType>(dq_acc);
store_tile(dq_dram_block_window_tmp, dq);
}
// Reduce + Convert
template <typename QGradAccDramBlockWindowTmp, typename QGradDramBlockWindowTmp>
CK_TILE_HOST_DEVICE void
operator()(const QGradAccDramBlockWindowTmp& dq_acc_dram_block_window_tmp,
QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
index_t nsplits) const
{
static_assert(
std::is_same_v<AccDataType,
remove_cvref_t<typename QGradAccDramBlockWindowTmp::DataType>> &&
std::is_same_v<QGradDataType,
remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}], "wrong!");
auto dq_acc_dram_window =
make_tile_window(dq_acc_dram_block_window_tmp.get_bottom_tensor_view(),
dq_acc_dram_block_window_tmp.get_window_lengths(),
dq_acc_dram_block_window_tmp.get_window_origin(),
Policy::template MakePostQGradAccDramTileDistribution<Problem>());
auto dq_acc = decltype(load_tile(dq_acc_dram_window)){};
clear_tile(dq_acc);
constexpr auto dq_acc_spans = decltype(dq_acc)::get_distributed_spans();
index_t i_total_loops = 0;
auto dq_acc_buf = load_tile(dq_acc_dram_window);
move_tile_window(dq_acc_dram_window, {1, 0, 0});
do
{
sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
dq_acc(n_i_j_idx) += dq_acc_buf(n_i_j_idx);
});
});
});
dq_acc_buf = load_tile(dq_acc_dram_window);
move_tile_window(dq_acc_dram_window, {1, 0, 0});
i_total_loops += 1;
} while(i_total_loops < (nsplits - 1));
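        // the loop above is software-pipelined: each trip folds in the split
        // loaded on the previous trip while prefetching the next one, so the
        // final, already-loaded split still has to be accumulated below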
sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
dq_acc(n_i_j_idx) += dq_acc_buf(n_i_j_idx);
});
});
});
// declare dq
constexpr auto dq_converted_dstr =
Policy::template MakePostQGradAccDramTileDistribution<Problem>();
auto dq_converted = make_static_distributed_tensor<QGradDataType>(dq_converted_dstr);
sweep_tile_span(dq_acc_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(dq_acc_spans[number<1>{}], [&](auto idx1) {
sweep_tile_span(dq_acc_spans[number<2>{}], [&](auto idx2) {
constexpr auto n_i_j_idx = make_tuple(idx0, idx1, idx2);
dq_converted(n_i_j_idx) = type_convert<QGradDataType>(dq_acc[n_i_j_idx]);
});
});
});
constexpr auto dq_dstr = Policy::template MakePostQGradDramTileDistribution<Problem>();
auto dq = make_static_distributed_tensor<QGradDataType>(dq_dstr);
dq.get_thread_buffer() = dq_converted.get_thread_buffer();
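        // assumption behind the raw buffer copy: both distributions assign each
        // thread the same elements in the same order, so only the distribution
        // metadata (3D accumulation view vs. 2D store view) changes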
store_tile(dq_dram_block_window_tmp, dq);
}
};
} // namespace ck_tile
...@@ -4,11 +4,11 @@ ...@@ -4,11 +4,11 @@
#pragma once #pragma once
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dot_do_o_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile { namespace ck_tile {
template <typename Problem, typename Policy = BlockFmhaBwdOGradDotODefaultPolicy> template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
struct BlockFmhaBwdOGradDotO struct BlockFmhaBwdOGradDotO
{ {
using ODataType = remove_cvref_t<typename Problem::ODataType>; using ODataType = remove_cvref_t<typename Problem::ODataType>;
...@@ -26,7 +26,7 @@ struct BlockFmhaBwdOGradDotO ...@@ -26,7 +26,7 @@ struct BlockFmhaBwdOGradDotO
static constexpr index_t kAlignmentO = static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>(); kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentOGrad = static constexpr index_t kAlignmentOGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>(); kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; } CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() { return 0; }
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// These template arguments are not used here.
using BlockFmhaBwdOGradDotODefaultPolicy =
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
/* QTLoadOnce_ = */ false,
/* KLoadOnce_ = */ false,
/* KTLoadOnce_ = */ false,
/* VLoadOnce_ = */ false,
/* OGradLoadOnce_ = */ false,
/* OGradTLoadOnce_ = */ false>;
} // namespace ck_tile
...@@ -6,13 +6,13 @@ ...@@ -6,13 +6,13 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp" #include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp" #include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp" #include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp" #include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile { namespace ck_tile {
template <typename Problem, typename Policy = BlockFmhaBwdDQDKDVPipelineQSKSVROGradSDefaultPolicy> template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS struct BlockFmhaBwdDQDKDVPipelineKRKTRVR
{ {
using QDataType = remove_cvref_t<typename Problem::QDataType>; using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>; using KDataType = remove_cvref_t<typename Problem::KDataType>;
...@@ -30,6 +30,8 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS ...@@ -30,6 +30,8 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>; using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>;
using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>; using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>; using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using FmhaDropout = remove_cvref_t<typename Problem::FmhaDropout>;
using HotLoopScheduler = typename Policy::template HotLoopScheduler<Problem>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>; using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
...@@ -46,22 +48,14 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS ...@@ -46,22 +48,14 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim; static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim; static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
static constexpr bool kQLoadOnce = true; static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kQTLoadOnce = false; static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kKLoadOnce = true; static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kKTLoadOnce = false; static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kVLoadOnce = true; static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr bool kOGradLoadOnce = true; static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kOGradTLoadOnce = false; static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
static constexpr bool kIsDeterministic = Problem::kIsDeterministic;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
static constexpr bool kHasDropout = Problem::kHasDropout;
    // last-dimension vector length used to create the tensor view (and to decide the buffer_load vector length)     // last-dimension vector length used to create the tensor view (and to decide the buffer_load vector length)
    // ... together with the tensor distribution; the tensor dist should be able to overwrite this     // ... together with the tensor distribution; the tensor dist should be able to overwrite this
...@@ -71,12 +65,9 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS ...@@ -71,12 +65,9 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>(); kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV = static constexpr index_t kAlignmentV =
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>(); kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentOGrad = static constexpr index_t kAlignmentOGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>(); kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
static constexpr index_t kAlignmentQGrad = static constexpr index_t kAlignmentQGrad = 1;
kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad<Problem>();
static constexpr index_t kAlignmentKGrad = static constexpr index_t kAlignmentKGrad =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>(); kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
static constexpr index_t kAlignmentVGrad = static constexpr index_t kAlignmentVGrad =
...@@ -84,7 +75,7 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS ...@@ -84,7 +75,7 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
static constexpr index_t kAlignmentBias = static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>(); kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
static constexpr const char* name = "qs_ks_vr_dos"; static constexpr const char* name = "kr_ktr_vr";
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{ {
...@@ -92,14 +83,11 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS ...@@ -92,14 +83,11 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
} }
template <typename QDramBlockWindowTmp, template <typename QDramBlockWindowTmp,
typename QTDramBlockWindowTmp,
typename KDramBlockWindowTmp, typename KDramBlockWindowTmp,
typename KTDramBlockWindowTmp,
typename VDramBlockWindowTmp, typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp, typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp, typename RandValDramBlockWindowTmp,
typename OGradDramBlockWindowTmp, typename OGradDramBlockWindowTmp,
typename OGradTDramBlockWindowTmp,
typename LSEDramBlockWindowTmp, typename LSEDramBlockWindowTmp,
typename DDramBlockWindowTmp, typename DDramBlockWindowTmp,
typename QGradDramBlockWindowTmp, typename QGradDramBlockWindowTmp,
...@@ -107,14 +95,11 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS ...@@ -107,14 +95,11 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
typename PositionEncoding> typename PositionEncoding>
CK_TILE_HOST_DEVICE auto CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp, operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
const QTDramBlockWindowTmp& /*qt_dram_block_window_tmp*/,
const KDramBlockWindowTmp& k_dram_block_window_tmp, const KDramBlockWindowTmp& k_dram_block_window_tmp,
const KTDramBlockWindowTmp& /*kt_dram_block_window_tmp*/,
const VDramBlockWindowTmp& v_dram_block_window_tmp, const VDramBlockWindowTmp& v_dram_block_window_tmp,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp, const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
const RandValDramBlockWindowTmp& randval_dram_block_window_tmp, const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
const OGradDramBlockWindowTmp& do_dram_block_window_tmp, const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
const OGradTDramBlockWindowTmp& /*dot_dram_block_window_tmp*/,
const LSEDramBlockWindowTmp& lse_dram_block_window_tmp, const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
const DDramBlockWindowTmp& d_dram_block_window_tmp, const DDramBlockWindowTmp& d_dram_block_window_tmp,
const QGradDramBlockWindowTmp& dq_dram_block_window_tmp, const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
...@@ -122,13 +107,11 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS ...@@ -122,13 +107,11 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
FmhaMask mask, FmhaMask mask,
PositionEncoding position_encoding, PositionEncoding position_encoding,
float raw_scale, float raw_scale,
#if CK_TILE_FMHA_FWD_FAST_EXP2
float scale, float scale,
#endif
float rp_undrop, float rp_undrop,
float scale_rp_undrop, float scale_rp_undrop,
void* smem_ptr, void* smem_ptr,
BlockDropout& dropout) const FmhaDropout& dropout) const
{ {
static_assert( static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> && std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
...@@ -138,9 +121,7 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS ...@@ -138,9 +121,7 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> && remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
std::is_same_v<LSEDataType, std::is_same_v<LSEDataType,
remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> && remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>> && std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>>,
std::is_same_v<QGradDataType,
remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
"wrong!"); "wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
...@@ -156,77 +137,6 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS ...@@ -156,77 +137,6 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}], kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!"); "wrong!");
// Q tile in LDS
QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto q_lds = make_tensor_view<address_space_enum::lds>(
        // Block GEMM
        constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
        constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
...@@ -234,34 +144,19 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
        constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
        constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();

        // init VGrad & KGrad
        auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
        auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};

        // K, HBM ->LDS ->Reg
        auto k_dram_window =
            make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
                             k_dram_block_window_tmp.get_window_lengths(),
                             k_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakeKDramTileDistribution<Problem>());

        const auto k_origin = k_dram_window.get_window_origin();
        // Early termination
        const auto [seqlen_q_start, seqlen_q_end] =
            mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
...@@ -274,217 +169,408 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
        {
            // Note: here dk_acc & dv_acc are all cleared, return them
            // Note: v loaded but no fence, ignore it.
            return make_tuple(dk_acc, dv_acc);
        }
        }
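The early exit above relies on the mask reporting, for a given key tile, the range of query tiles that can produce any nonzero contribution. A minimal standalone sketch of that range computation for a plain causal mask (hypothetical CausalTileRange helper; ck_tile's mask.GetTileRangeAlongY additionally handles padding and generic window masks):

#include <algorithm>
#include <cstdio>
#include <utility>

// Causal mask: query row i attends to keys j <= i, so keys starting at row
// k0 only matter to query rows >= k0; earlier Q tiles can be skipped whole.
struct CausalTileRange
{
    int seqlen_q;
    // returns [start, end) of query rows touching keys [k0, k0 + kN0)
    std::pair<int, int> range_along_q(int k0, int kM0, int /*kN0*/) const
    {
        int start = (k0 / kM0) * kM0; // first q-tile boundary that can see k0
        return {std::min(start, seqlen_q), seqlen_q};
    }
};

int main()
{
    CausalTileRange mask{256};
    auto [s, e] = mask.range_along_q(/*k0=*/128, /*kM0=*/64, /*kN0=*/64);
    std::printf("q tiles span [%d, %d)\n", s, e); // prints [128, 256)
}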
        KDataType* k_lds_ptr =
            static_cast<KDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
        auto k_lds = make_tensor_view<address_space_enum::lds>(
            k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
        auto k_lds_write_window =
            make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
        auto k_lds_read_window =
            make_tile_window(k_lds_write_window.get_bottom_tensor_view(),
                             make_tuple(number<kN0>{}, number<kK0>{}),
                             k_lds_write_window.get_window_origin(),
                             Policy::template MakeKRegSliceBlockDescriptor<Problem>());
        auto k_reg_tensor = make_static_distributed_tensor<KDataType>(
            Policy::template MakeKRegBlockDescriptor<Problem>());

        //------------------------------------------------------------------
        // V, HBM ->LDS ->Reg
        auto v_dram_window =
            make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
                             v_dram_block_window_tmp.get_window_lengths(),
                             v_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakeVDramTileDistribution<Problem>());
        VDataType* v_lds_ptr =
            static_cast<VDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
        auto v_lds = make_tensor_view<address_space_enum::lds>(
            v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
        auto v_lds_write_window =
            make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kK2>{}), {0, 0});
        auto v_lds_read_window =
            make_tile_window(v_lds_write_window.get_bottom_tensor_view(),
                             make_tuple(number<kN0>{}, number<kK2>{}),
                             v_lds_write_window.get_window_origin(),
                             Policy::template MakeVRegSliceBlockDescriptor<Problem>());
        auto v_reg_tensor = make_static_distributed_tensor<VDataType>(
            Policy::template MakeVRegBlockDescriptor<Problem>());

        //------------------------------------------------------------------
        // KT, Reg ->LDS ->Reg
        auto shuffled_k_block_tile = make_static_distributed_tensor<KDataType>(
            Policy::template MakeShuffledKRegWriteBlockDescriptor<Problem>());
        KDataType* kt_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
        auto shuffled_k_lds_write = make_tensor_view<address_space_enum::lds>(
            kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor<Problem>());
        auto shuffled_k_lds_write_window = make_tile_window(
            shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
        auto kt_lds_read = make_tensor_view<address_space_enum::lds>(
            kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor<Problem>());
        auto kt_lds_read_window =
            make_tile_window(kt_lds_read,
                             make_tuple(number<kQKHeaddim>{}, number<kN0>{}),
                             {0, 0},
                             Policy::template MakeKTRegBlockDescriptor<Problem>());

        //------------------------------------------------------------------
        // Pre-Load KV into Registers
        auto k_block_tile = load_tile(k_dram_window);
        auto v_block_tile = load_tile(v_dram_window);

        store_tile(k_lds_write_window, k_block_tile);
        shuffle_tile(shuffled_k_block_tile, k_block_tile);
        store_tile(shuffled_k_lds_write_window, shuffled_k_block_tile);

        block_sync_lds();
        k_reg_tensor = load_tile(k_lds_read_window);
        block_sync_lds();

        auto kt_reg_tensor = load_tile(kt_lds_read_window);

        store_tile(v_lds_write_window, v_block_tile);
        block_sync_lds();
        v_reg_tensor = load_tile(v_lds_read_window);
        block_sync_lds();
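The preload above stages K and V once through LDS into registers so the hot loop never re-reads them from HBM, with a barrier between each differently-laid-out write and read. A stripped-down HIP analogue of the pattern (hypothetical kernel, fixed sizes, no tile distributions; launch with 256 threads per block):

#include <hip/hip_runtime.h>

// One-time staging: global -> LDS -> per-thread registers. After this,
// the main loop works purely out of k_reg/v_reg.
__global__ void preload_kv(const float* k_g, const float* v_g, float* out, int n)
{
    __shared__ float k_lds[256];
    __shared__ float v_lds[256];

    int t = threadIdx.x;       // 256 threads assumed
    k_lds[t] = k_g[t];         // HBM -> LDS
    v_lds[t] = v_g[t];
    __syncthreads();           // block_sync_lds() equivalent

    // re-read with a swizzled pattern (stand-in for the *RegSlice reads)
    float k_reg = k_lds[(t * 33) % 256];
    float v_reg = v_lds[(t * 17) % 256];

    float acc = 0.f;
    for(int i = 0; i < n; ++i) // hot loop touches registers only
        acc += k_reg * v_reg;
    out[t] = acc;
}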
        //---------------------------- Loop Load in ----------------------------//
        // Q: HBM ->Reg ->LDS
        auto q_dram_window =
            make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
                             q_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_q_start, 0},
                             Policy::template MakeQDramTileDistribution<Problem>());
        QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
            Policy::template GetSmemSizeOGrad<Problem>() +
            Policy::template GetSmemSizeOGradT<Problem>()));
        auto q_lds = make_tensor_view<address_space_enum::lds>(
            q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
        auto q_lds_window =
            make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
        auto q_lds_read_window =
            make_tile_window(q_lds_window.get_bottom_tensor_view(),
                             make_tuple(number<kM0>{}, number<kK0>{}),
                             q_lds_window.get_window_origin(),
                             Policy::template MakeQRegSliceBlockDescriptor<Problem>());

        auto pt_reg_tensor = make_static_distributed_tensor<GemmDataType>(
            Policy::template MakePTRegSliceBlockDescriptor<Problem>());

        // QT: Reg -> Reg -> LDS
        auto shuffled_q_block_tile = make_static_distributed_tensor<QDataType>(
            Policy::template MakeShuffledQRegWriteBlockDescriptor<Problem>());
        QDataType* qt_lds_ptr =
            static_cast<QDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
        auto shuffled_q_lds_write = make_tensor_view<address_space_enum::lds>(
            qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor<Problem>());
        auto shuffled_q_lds_write_window = make_tile_window(
            shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
        auto qt_lds_read = make_tensor_view<address_space_enum::lds>(
            qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor<Problem>());
        auto qt_lds_read_window =
            make_tile_window(qt_lds_read,
                             make_tuple(number<kQKHeaddim>{}, number<kM0>{}),
                             {0, 0},
                             Policy::template MakeQTRegSliceBlockDescriptor<Problem>());
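The *T windows above realize a transpose by writing a tile into LDS in one layout and reading it back in the other. The classic HIP form of that round trip, with one row of padding to avoid bank conflicts (sizes illustrative, not the policy's descriptors):

#include <hip/hip_runtime.h>

// Transpose a 32x32 tile through LDS: write rows, barrier, read columns.
// The [32][33] padding keeps column reads off a single LDS bank.
__global__ void transpose_tile(const float* in, float* out)
{
    __shared__ float lds[32][33];
    int x = threadIdx.x, y = threadIdx.y;   // 32x32 thread block assumed
    lds[y][x] = in[y * 32 + x];             // natural layout in
    __syncthreads();
    out[y * 32 + x] = lds[x][y];            // transposed layout out
}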
        // dO: HBM ->Reg ->LDS
        auto do_dram_window =
            make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
                             do_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_q_start, 0},
                             Policy::template MakeOGradDramTileDistribution<Problem>());
        OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>()));
        auto do_lds = make_tensor_view<address_space_enum::lds>(
            do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
        auto do_lds_window =
            make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
        auto do_lds_read_window =
            make_tile_window(do_lds_window.get_bottom_tensor_view(),
                             make_tuple(number<kM0>{}, number<kK2>{}),
                             do_lds_window.get_window_origin(),
                             Policy::template MakeOGradRegSliceBlockDescriptor<Problem>());

        // dOT: Reg ->Reg ->LDS
        auto shuffled_do_block_tile = make_static_distributed_tensor<OGradDataType>(
            Policy::template MakeShuffledOGradRegWriteBlockDescriptor<Problem>());
        OGradDataType* dot_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
            Policy::template GetSmemSizeOGrad<Problem>()));
        auto shuffled_do_lds_write = make_tensor_view<address_space_enum::lds>(
            dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor<Problem>());
        auto shuffled_do_lds_write_window = make_tile_window(
            shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
        auto dot_read_lds = make_tensor_view<address_space_enum::lds>(
            dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor<Problem>());
        auto dot_lds_read_window =
            make_tile_window(dot_read_lds,
                             make_tuple(number<kVHeaddim>{}, number<kM0>{}),
                             {0, 0},
                             Policy::template MakeOGradTRegSliceBlockDescriptor<Problem>());

        // dS: Reg -> Reg -> LDS
        GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
            Policy::template GetSmemSizeOGrad<Problem>() +
            Policy::template GetSmemSizeOGradT<Problem>() +
            Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>() +
            Policy::template GetSmemSizeD<Problem>()));
        auto ds_lds = make_tensor_view<address_space_enum::lds>(
            ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
        auto ds_lds_window =
            make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
        auto ds_lds_read_window =
            make_tile_window(ds_lds_window.get_bottom_tensor_view(),
                             make_tuple(number<kM0>{}, number<kK4>{}),
                             ds_lds_window.get_window_origin(),
                             Policy::template MakeSGradRegSliceBlockDescriptor<Problem>());
        auto dst_reg_tensor = make_static_distributed_tensor<GemmDataType>(
            Policy::template MakeSGradTRegSliceBlockDescriptor<Problem>());
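Each *_lds_ptr above is carved out of the single smem_ptr allocation by summing the byte sizes of the buffers that precede it. A self-contained sketch of that bump-offset scheme (hypothetical sizes; the real ones come from the policy's GetSmemSize* functions):

#include <cstddef>
#include <cstdint>
#include <cstdio>

// Sub-allocate typed stage buffers from one backing allocation by
// accumulating byte offsets, mirroring smem_ptr + GetSmemSizeQT() + ...
template <typename T>
T* carve(void* base, std::size_t& offset, std::size_t count)
{
    T* p = reinterpret_cast<T*>(static_cast<char*>(base) + offset);
    offset += count * sizeof(T);
    return p;
}

int main()
{
    alignas(16) static char smem[64 * 1024]; // stand-in for LDS
    std::size_t off = 0;
    auto* qt  = carve<std::uint16_t>(smem, off, 64 * 128); // Q^T tile
    auto* dO  = carve<std::uint16_t>(smem, off, 64 * 128); // OGrad tile
    auto* dOt = carve<std::uint16_t>(smem, off, 64 * 128); // OGrad^T tile
    auto* q   = carve<std::uint16_t>(smem, off, 64 * 128); // Q tile
    auto* lse = carve<float>(smem, off, 64);               // per-row LSE
    auto* d   = carve<float>(smem, off, 64);               // per-row D
    auto* ds  = carve<std::uint16_t>(smem, off, 64 * 64);  // SGrad tile
    std::printf("total bytes used: %zu\n", off);
    (void)qt; (void)dO; (void)dOt; (void)q; (void)lse; (void)d; (void)ds;
}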
        // Bias: HBM ->Reg ->Reg ->LDS
        const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
        auto bias_dram_window =
            make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
                             bias_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_q_start, bias_origin.at(number<1>{})}, // M/N
                             Policy::template MakeBiasTileDistribution<Problem>());
        BiasDataType* bias_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
            Policy::template GetSmemSizeOGrad<Problem>() +
            Policy::template GetSmemSizeOGradT<Problem>() +
            Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>() +
            Policy::template GetSmemSizeD<Problem>()));
        auto bias_lds = make_tensor_view<address_space_enum::lds>(
            bias_lds_ptr, Policy::template MakeBiasLdsBlockDescriptor<Problem>());
        auto bias_lds_write_window =
            make_tile_window(bias_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
        auto bias_s_lds_read_window =
            make_tile_window(bias_lds_write_window.get_bottom_tensor_view(),
                             bias_lds_write_window.get_window_lengths(),
                             bias_lds_write_window.get_window_origin(),
                             Policy::template MakeBiasSTileDistribution<decltype(gemm_0)>());
        static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
                      "BiasDataType and BiasGradDataType should be the same!");

        // LSE: HBM ->LDS ->Reg
        auto lse_dram_window = make_tile_window(
            lse_dram_block_window_tmp.get_bottom_tensor_view(),
            lse_dram_block_window_tmp.get_window_lengths(),
            {seqlen_q_start},
            Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
        LSEDataType* lse_lds_ptr = static_cast<LSEDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
            Policy::template GetSmemSizeOGrad<Problem>() +
            Policy::template GetSmemSizeOGradT<Problem>() +
            Policy::template GetSmemSizeQ<Problem>()));
        auto lse_lds = make_tensor_view<address_space_enum::lds>(
            lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
        auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number<kM0>{}), {0});
        auto lse_lds_read_window = make_tile_window(
            lse_lds,
            make_tuple(number<kM0>{}),
            {0},
            Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());

        // D: HBM ->Reg
        auto d_dram_window = make_tile_window(
            d_dram_block_window_tmp.get_bottom_tensor_view(),
            d_dram_block_window_tmp.get_window_lengths(),
            {seqlen_q_start},
            Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
        DDataType* d_lds_ptr = static_cast<DDataType*>(static_cast<void*>(
            static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
            Policy::template GetSmemSizeOGrad<Problem>() +
            Policy::template GetSmemSizeOGradT<Problem>() +
            Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>()));
        auto d_lds = make_tensor_view<address_space_enum::lds>(
            d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
        auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number<kM0>{}), {0});
        auto d_lds_read_window = make_tile_window(
            d_lds,
            make_tuple(number<kM0>{}),
            {0},
            Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());

        // RandVal: HBM ->Reg
        auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0), false>(
            randval_dram_block_window_tmp, seqlen_q_start);

        // BiasGrad
        // Reg ->LDS ->Reg ->HBM
        const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
        auto dbias_dram_window =
            make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
                             dbias_dram_block_window_tmp.get_window_lengths(),
                             {seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
        auto dbias_lds_read_window =
            make_tile_window(bias_lds,
                             make_tuple(number<kM0>{}, number<kN0>{}),
                             {0, 0},
                             Policy::template MakeShuffledBiasTileDistribution<Problem>());

        // ----------------------------Loop write out------------------------------//
        auto dq_dram_window = make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
                                               dq_dram_block_window_tmp.get_window_lengths(),
                                               {seqlen_q_start, 0});

        using SPBlockTileType     = decltype(gemm_0.MakeCBlockTile());
        using SPGradBlockTileType = decltype(gemm_2.MakeCBlockTile());
        using QGradBlockTileType  = decltype(gemm_4.MakeCBlockTile());

        index_t i_total_loops = 0;
        index_t seqlen_q_step = seqlen_q_start;

        static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal kK0");
        static_assert(kM0 == kK1, "kM0 should equal kK1");
        static_assert(kVHeaddim == kK2, "kVHeaddim should equal kK2");
        static_assert(kM0 == kK3, "kM0 should equal kK3");
        constexpr index_t k4_loops = kN0 / kK4;
        clear_tile(dv_acc);
        clear_tile(dk_acc);
        __builtin_amdgcn_sched_barrier(0);
        // Hot loop
        while(i_total_loops < num_total_loop)
        {
            auto q_block_tile = load_tile(q_dram_window);
            move_tile_window(q_dram_window, {kM0, 0});

            auto lse_block_tile = load_tile(lse_dram_window);
            move_tile_window(lse_dram_window, {kM0});

            store_tile(q_lds_window, q_block_tile);
            shuffle_tile(shuffled_q_block_tile, q_block_tile);
            store_tile(shuffled_q_lds_write_window, shuffled_q_block_tile);
            store_tile(lse_lds_write_window, lse_block_tile);

            block_sync_lds();

            auto q_reg_tensor = load_tile(q_lds_read_window);
            auto lse          = load_tile(lse_lds_read_window);

            block_sync_lds();

            // STAGE 1, Q@K Gemm0
            auto s_acc = SPBlockTileType{};
            s_acc      = gemm_0(q_reg_tensor, k_reg_tensor);

            // STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
            if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
            {
                const auto bias_tile    = load_tile(bias_dram_window);
                auto shuffled_bias_tile = make_static_distributed_tensor<BiasDataType>(
                    Policy::template MakeShuffledBiasTileDistribution<Problem>());
                shuffle_tile(shuffled_bias_tile, bias_tile);
                store_tile(bias_lds_write_window, shuffled_bias_tile);
                block_sync_lds();
                auto bias_s_tile = load_tile(bias_s_lds_read_window);
                tile_elementwise_inout(
                    [&](auto& x, const auto& y) {
                        x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
                    },
                    s_acc,
                    bias_s_tile);
                move_tile_window(bias_dram_window, {kM0, 0});
            }
            else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
            {
                constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
                sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
                    sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
                        const auto tile_idx = get_x_indices_from_distributed_indices(
                            s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
                        const auto row = seqlen_q_step + tile_idx.at(number<0>{});
                        const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
                        constexpr auto i_j_idx = make_tuple(idx0, idx1);
                        s_acc(i_j_idx) *= scale;
                        position_encoding.update(s_acc(i_j_idx), row, col);
                    });
                });
            }
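position_encoding.update above folds the ALIBI penalty directly into each logit instead of loading a bias tensor. A scalar sketch of one such update (the slope is per-head; the sign/abs convention here is an assumption for illustration, not read from ck_tile):

#include <cstdio>
#include <cstdlib>

// ALIBI: penalize attention logits in proportion to query/key distance.
void alibi_update(float& s, int row, int col, float slope)
{
    s -= slope * std::abs(row - col);
}

int main()
{
    float s = 3.0f;
    alibi_update(s, /*row=*/10, /*col=*/7, /*slope=*/0.5f);
    std::printf("%f\n", s); // 1.5
}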
            if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
            {
                bool need_perpixel_check = mask.IsEdgeTile(
                    seqlen_q_step, k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
                if(need_perpixel_check)
                {
                    set_tile_if(s_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
                        const auto row = seqlen_q_step + tile_idx.at(number<0>{});
                        const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
                        return mask.IsOutOfBound(row, col);
                    });
                }
            }

            static const auto get_validated_lse = [](LSEDataType raw_lse) {
                if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                             FmhaMask::IsMasking)
...@@ -499,157 +585,162 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
                }
            };

            auto p = SPBlockTileType{};
            constexpr auto p_spans = decltype(p)::get_distributed_spans();
            sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);
                auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
                sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
                    if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                                 BiasEnum == BlockAttentionBiasEnum::ALIBI)
                    {
                        p(i_j_idx) = exp2(s_acc[i_j_idx] - row_lse);
                    }
                    else
                    {
                        p(i_j_idx) = exp2(scale * s_acc[i_j_idx] - row_lse);
                    }
                });
            });
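The recompute above folds the softmax normalizer into the stored LSE and uses exp2 throughout: since exp(x) = exp2(x * log2(e)), scaling both the logits and the LSE by log2(e) lets the kernel use the cheaper hardware exp2. A scalar check of the identity (plain C++, not the tile version):

#include <cmath>
#include <cstdio>

int main()
{
    const float log2e = 1.4426950408889634f;
    float s   = 1.7f; // one attention logit (already scaled by 1/sqrt(d))
    float lse = 2.3f; // log-sum-exp of the row, saved by the forward pass

    float p_exp  = std::exp(s - lse);                   // textbook form
    float p_exp2 = std::exp2(s * log2e - lse * log2e);  // kernel form
    std::printf("%f vs %f\n", p_exp, p_exp2);           // equal up to rounding
}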
            if constexpr(FmhaDropout::IsDropout)
            {
                dropout.template Run<decltype(gemm_0), RandValOutputDataType>(
                    seqlen_q_step, k_origin.at(number<0>{}), p, randval_dram_window);
            }
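dropout.Run above rewrites dropped entries of p in place; the later `x > 0.f ? x : 0.f` cast and the `p[i_j_idx] >= 0` test imply kept values stay non-negative while dropped ones are flagged by sign. A scalar sketch of that encode/decode convention (an assumption drawn from those checks; the real kernel drives `keep` from Philox-generated random bytes against the drop threshold):

#include <cstdio>

// Encode keep/drop into the sign of p itself: kept entries stay >= 0,
// dropped entries are negated, so a single sign test later recovers the
// mask without storing it separately.
float encode(float p, bool keep) { return keep ? p : -p; }

int main()
{
    float p0 = encode(0.25f, true);
    float p1 = encode(0.50f, false);
    // backward: zero the dropped value before the P^T@OGrad^T GEMM
    auto gemm_in = [](float x) { return x > 0.f ? x : 0.f; };
    std::printf("%f %f\n", gemm_in(p0), gemm_in(p1)); // 0.25 0.00
}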
            const auto p_gemm = [&]() {
                if constexpr(FmhaDropout::IsDropout)
                {
                    return tile_elementwise_in(
                        [](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
                        p);
                }
                else
                {
                    return cast_tile<GemmDataType>(p);
                }
            }();

            // STAGE 3, P^T@OGrad^T Gemm1
            auto do_block_tile = load_tile(do_dram_window);
            move_tile_window(do_dram_window, {kM0, 0});

            auto d_block_tile = load_tile(d_dram_window);
            move_tile_window(d_dram_window, {kM0});

            store_tile(do_lds_window, do_block_tile);
            shuffle_tile(shuffled_do_block_tile, do_block_tile);
            store_tile(shuffled_do_lds_write_window, shuffled_do_block_tile);
            store_tile(d_lds_write_window, d_block_tile);

            block_sync_lds();
            auto dot_reg_tensor = load_tile(dot_lds_read_window);
            block_sync_lds();

            Policy::template PTFromGemm0CToGemm1A<Problem,
                                                  decltype(pt_reg_tensor),
                                                  decltype(p_gemm)>(pt_reg_tensor, p_gemm);
            gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);

            // STAGE 4, OGrad@V Gemm2
            auto do_reg_tensor = load_tile(do_lds_read_window);
            auto d             = load_tile(d_lds_read_window);

            auto dp_acc = SPGradBlockTileType{};
            dp_acc      = gemm_2(do_reg_tensor, v_reg_tensor);

            // STAGE 5, P^T(PGrad^T - D)
            auto ds = SPGradBlockTileType{};
            constexpr auto ds_spans = decltype(ds)::get_distributed_spans();
            sweep_tile_span(ds_spans[number<0>{}], [&](auto idx0) {
                constexpr auto i_idx = make_tuple(idx0);
                sweep_tile_span(ds_spans[number<1>{}], [&](auto idx1) {
                    constexpr auto i_j_idx = make_tuple(idx0, idx1);
                    bool undrop_flag = p[i_j_idx] >= 0;
                    ds(i_j_idx)      = p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag
                                                         ? (dp_acc[i_j_idx] - d[i_idx])
                                                         : d[i_idx]);
                });
            });
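The update above is the softmax backward in closed form. With P = softmax(S), dP = dO V^T, and O = P V, the Jacobian contraction collapses to the expression inside the sweep, where D is the rowsum(dO * O) term produced by the dot(dO, O) pre-kernel:

\[
dS_{ij} \;=\; \sum_k dP_{ik}\,\frac{\partial P_{ik}}{\partial S_{ij}}
       \;=\; P_{ij}\Bigl(dP_{ij} - \sum_k dP_{ik}\,P_{ik}\Bigr)
       \;=\; P_{ij}\,\bigl(dP_{ij} - D_i\bigr),
\qquad
D_i = \sum_k dP_{ik}\,P_{ik} = \sum_k dO_{ik}\,O_{ik},
\]

the last equality following from \(dP_{ij} = \sum_k dO_{ik} V_{jk}\) and \(O_{ik} = \sum_j P_{ij} V_{jk}\).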
            if constexpr(kHasBiasGrad)
            {
                const auto dbias = [&]() {
                    if constexpr(FmhaDropout::IsDropout)
                    {
                        return tile_elementwise_in(
                            [&rp_undrop](const auto& x) {
                                return type_convert<BiasGradDataType>(x * rp_undrop);
                            },
                            ds);
                    }
                    else
                    {
                        return cast_tile<BiasGradDataType>(ds);
                    }
                }();
                store_tile(bias_lds_write_window, dbias);
                block_sync_lds();
                auto shuffled_dbias_tile = load_tile(dbias_lds_read_window);
                auto dbias_tile          = make_static_distributed_tensor<BiasGradDataType>(
                    Policy::template MakeBiasTileDistribution<Problem>());
                shuffle_tile(dbias_tile, shuffled_dbias_tile);
                store_tile(dbias_dram_window, dbias_tile);
                move_tile_window(dbias_dram_window, {kM0, 0});
            }
            // STAGE 6, SGrad^T@Q^T Gemm3
            auto qt_reg_tensor = load_tile(qt_lds_read_window);
            block_sync_lds();

            const auto ds_gemm = cast_tile<GemmDataType>(ds);

            Policy::template SGradTFromGemm2CToGemm3A<Problem,
                                                      decltype(dst_reg_tensor),
                                                      decltype(ds_gemm)>(dst_reg_tensor, ds_gemm);
            gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);

            store_tile(ds_lds_window, ds_gemm);
            block_sync_lds();
            auto ds_reg_tensor      = load_tile(ds_lds_read_window);
            auto ds_reg_tensor_next = decltype(ds_reg_tensor){};
            move_tile_window(ds_lds_read_window, {0, kK4});

            // STAGE 7, SGrad@K^T Gemm4
            auto dq_acc = QGradBlockTileType{};
            clear_tile(dq_acc);
            static_for<0, k4_loops, 1>{}([&](auto i_k4) {
                if constexpr(i_k4 < k4_loops - 1)
                {
                    ds_reg_tensor_next = load_tile(ds_lds_read_window);
                    move_tile_window(ds_lds_read_window, {0, kK4});
                }
                auto kt_reg_tensor_slice = get_slice_tile(kt_reg_tensor,
                                                          sequence<0, i_k4 * kK4>{},
                                                          sequence<kQKHeaddim, (i_k4 + 1) * kK4>{});
                gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice);
                if constexpr(i_k4 < k4_loops - 1)
                {
                    ds_reg_tensor.get_thread_buffer() = ds_reg_tensor_next.get_thread_buffer();
                }
            });
            move_tile_window(ds_lds_read_window, {0, -kN0});
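The k4 loop above hides LDS latency by prefetching the next SGrad slice into a second register tile while the current one feeds gemm_4, then rotating the buffers. The same two-stage rotation in a minimal scalar form (hypothetical buffers; ck_tile rotates whole distributed tensors via get_thread_buffer()):

#include <array>
#include <cstdio>

int main()
{
    std::array<float, 4> lds_slices{1.f, 2.f, 3.f, 4.f}; // stand-in for ds in LDS
    constexpr int k4_loops = 4;

    float acc = 0.f;
    float cur = lds_slices[0];            // initial load
    for(int i = 0; i < k4_loops; ++i)
    {
        float next = 0.f;
        if(i < k4_loops - 1)
            next = lds_slices[i + 1];     // prefetch next slice
        acc += cur * 2.f;                 // "gemm" on the current slice
        if(i < k4_loops - 1)
            cur = next;                   // rotate buffers
    }
    std::printf("acc = %f\n", acc);       // 20.0
}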
            // QGrad Scale
            if constexpr(FmhaDropout::IsDropout)
            {
                tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
                                       dq_acc);
...@@ -658,34 +749,33 @@ struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
            {
                tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
            }
            if constexpr(kIsDeterministic)
            {
                store_tile(dq_dram_window, dq_acc);
            }
            else
            {
                update_tile(dq_dram_window, dq_acc);
            }
            move_tile_window(dq_dram_window, {kM0, 0});

            i_total_loops += 1;
            seqlen_q_step += kM0;
        }
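When several key-tile workgroups contribute to the same rows of dQ, the non-deterministic path accumulates with update_tile, so the floating-point addition order can vary run to run; the deterministic path writes with a plain store_tile into per-tile storage instead. A scalar illustration of why accumulation order matters (generic C++, not the kernel's mechanism):

#include <cstdio>

int main()
{
    // Floating-point addition is not associative, so racing accumulators
    // can produce run-to-run differences in dQ.
    float big = 1e8f, small = 1.f;
    float a = (big + small) - big; // 0.0f: 'small' is absorbed by rounding
    float b = (big - big) + small; // 1.0f
    std::printf("%f vs %f\n", a, b);
}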
        // Results Scale
        if constexpr(FmhaDropout::IsDropout)
        {
            tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
                                   dk_acc);
            tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
        }
        else
        {
            tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
        }

        return make_tuple(dk_acc, dv_acc);
    }
};
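The tail scaling applies the shared corrections once per accumulator instead of inside the loop. Assuming the conventional definitions behind the names used above (rp_undrop = 1/(1 - p_drop) to undo dropout's expected damping, and scale_rp_undrop = raw_scale * rp_undrop folding in the softmax scale), a scalar check:

#include <cstdio>

int main()
{
    float p_drop    = 0.1f;
    float raw_scale = 0.125f;                 // e.g. 1/sqrt(d) with d = 64
    float rp_undrop = 1.f / (1.f - p_drop);   // assumed definition
    float scale_rp_undrop = raw_scale * rp_undrop;

    float dv = 2.0f, dk = 3.0f;
    std::printf("dv: %f, dk: %f\n", dv * rp_undrop, dk * scale_rp_undrop);
}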
...
...@@ -6,13 +6,13 @@
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"

namespace ck_tile {

template <typename Problem, typename Policy = BlockFmhaBwdPipelineDefaultPolicy>
struct BlockFmhaBwdDQDKDVPipelineKRKTRVRIGLP
{
    using QDataType = remove_cvref_t<typename Problem::QDataType>;
    using KDataType = remove_cvref_t<typename Problem::KDataType>;
...@@ -30,6 +30,8 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
    using VGradDataType    = remove_cvref_t<typename Problem::VGradDataType>;
    using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
    using FmhaMask         = remove_cvref_t<typename Problem::FmhaMask>;
    using FmhaDropout      = remove_cvref_t<typename Problem::FmhaDropout>;
    using HotLoopScheduler = typename Policy::template HotLoopScheduler<Problem>;

    using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
...@@ -46,22 +48,14 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
    static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
    static constexpr index_t kVHeaddim  = BlockFmhaShape::kVHeaddim;

    static constexpr bool kIsGroupMode     = Problem::kIsGroupMode;
    static constexpr bool kPadSeqLenQ      = Problem::kPadSeqLenQ;
    static constexpr bool kPadSeqLenK      = Problem::kPadSeqLenK;
    static constexpr bool kPadHeadDimQ     = Problem::kPadHeadDimQ;
    static constexpr bool kPadHeadDimV     = Problem::kPadHeadDimV;
    static constexpr auto BiasEnum         = Problem::BiasEnum;
    static constexpr bool kHasBiasGrad     = Problem::kHasBiasGrad;
    static constexpr bool kIsDeterministic = Problem::kIsDeterministic;

    // last dimension vector length used to create tensor view (and decide buffer_load vector
    // length) ... together with tensor distribution. tensor dist should be able to overwrite this
...@@ -71,12 +65,9 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
    static constexpr index_t kAlignmentV =
        kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
    static constexpr index_t kAlignmentOGrad =
        kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
    static constexpr index_t kAlignmentQGrad = 1;
    static constexpr index_t kAlignmentKGrad =
        kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
    static constexpr index_t kAlignmentVGrad =
...@@ -84,7 +75,7 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
    static constexpr index_t kAlignmentBias =
        kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();

    static constexpr const char* name = "kr_ktr_vr_iglp";

    CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
    {
...@@ -92,14 +83,11 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
    }
    template <typename QDramBlockWindowTmp,
              typename KDramBlockWindowTmp,
              typename VDramBlockWindowTmp,
              typename BiasDramBlockWindowTmp,
              typename RandValDramBlockWindowTmp,
              typename OGradDramBlockWindowTmp,
              typename LSEDramBlockWindowTmp,
              typename DDramBlockWindowTmp,
              typename QGradDramBlockWindowTmp,
...@@ -107,14 +95,11 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
              typename PositionEncoding>
    CK_TILE_HOST_DEVICE auto
    operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
               const KDramBlockWindowTmp& k_dram_block_window_tmp,
               const VDramBlockWindowTmp& v_dram_block_window_tmp,
               const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
               const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
               const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
               const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
               const DDramBlockWindowTmp& d_dram_block_window_tmp,
               const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
...@@ -122,43 +107,29 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
               FmhaMask mask,
               PositionEncoding position_encoding,
               float raw_scale,
               float scale,
               float rp_undrop,
               float scale_rp_undrop,
               void* smem_ptr,
               FmhaDropout& dropout) const
    {
        static_assert(
            std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
                std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
                std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
                std::is_same_v<OGradDataType,
                               remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
                std::is_same_v<LSEDataType,
                               remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
                std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>>,
            "wrong!");

        static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
                          kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
                          kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
...@@ -166,83 +137,6 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
                          kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
                      "wrong!");
        // Block GEMM
        constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
        constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
...@@ -250,34 +144,19 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
        constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
        constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();

        // init VGrad & KGrad
        auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
        auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};

        // K, HBM ->LDS ->Reg
        auto k_dram_window =
            make_tile_window(k_dram_block_window_tmp.get_bottom_tensor_view(),
                             k_dram_block_window_tmp.get_window_lengths(),
                             k_dram_block_window_tmp.get_window_origin(),
                             Policy::template MakeKDramTileDistribution<Problem>());

        const auto k_origin = k_dram_window.get_window_origin();
        // Early termination
        const auto [seqlen_q_start, seqlen_q_end] =
            mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
...@@ -290,272 +169,444 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
        {
            // Note: here dk_acc & dv_acc are all cleared, return them
            // Note: v loaded but no fence, ignore it.
            return make_tuple(dk_acc, dv_acc);
        }
        }
KDataType* k_lds_ptr =
static_cast<KDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
auto k_lds = make_tensor_view<address_space_enum::lds>(
k_lds_ptr, Policy::template MakeKLdsWriteBlockDescriptor<Problem>());
auto k_block_tile = load_tile(k_dram_window); auto k_lds_write_window =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
auto k_lds_read_window =
make_tile_window(k_lds_write_window.get_bottom_tensor_view(),
make_tuple(number<kN0>{}, number<kK0>{}),
k_lds_write_window.get_window_origin(),
Policy::template MakeKRegSliceBlockDescriptor<Problem>());
auto k_reg_tensor = make_static_distributed_tensor<KDataType>(
Policy::template MakeKRegBlockDescriptor<Problem>());
//------------------------------------------------------------------
// V, HBM ->LDS ->Reg
auto v_dram_window =
make_tile_window(v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
v_dram_block_window_tmp.get_window_origin(),
Policy::template MakeVDramTileDistribution<Problem>());
VDataType* v_lds_ptr =
static_cast<VDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
auto v_lds = make_tensor_view<address_space_enum::lds>(
v_lds_ptr, Policy::template MakeVLdsWriteBlockDescriptor<Problem>());
auto v_lds_write_window =
make_tile_window(v_lds, make_tuple(number<kN0>{}, number<kK2>{}), {0, 0});
auto v_lds_read_window =
make_tile_window(v_lds_write_window.get_bottom_tensor_view(),
make_tuple(number<kN0>{}, number<kK2>{}),
v_lds_write_window.get_window_origin(),
Policy::template MakeVRegSliceBlockDescriptor<Problem>());
auto v_reg_tensor = make_static_distributed_tensor<VDataType>(
Policy::template MakeVRegBlockDescriptor<Problem>());
store_tile(k_lds_window, k_block_tile); // // persistent K in LDS //------------------------------------------------------------------
// KT, Reg ->LDS ->Reg
auto shuffled_k_block_tile = make_static_distributed_tensor<KDataType>(
Policy::template MakeShuffledKRegWriteBlockDescriptor<Problem>());
auto kt_dram_block_window = kt_dram_block_window_tmp; KDataType* kt_lds_ptr = static_cast<KDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto shuffled_k_lds_write = make_tensor_view<address_space_enum::lds>(
kt_lds_ptr, Policy::template MakeShuffledKLdsWriteBlockDescriptor<Problem>());
auto shuffled_k_lds_write_window = make_tile_window(
shuffled_k_lds_write, make_tuple(number<kN0>{}, number<kK0>{}), {0, 0});
auto kt_lds_read = make_tensor_view<address_space_enum::lds>(
kt_lds_ptr, Policy::template MakeKTLdsReadBlockDescriptor<Problem>());
auto kt_lds_read_window =
make_tile_window(kt_lds_read,
make_tuple(number<kQKHeaddim>{}, number<kN0>{}),
{0, 0},
Policy::template MakeKTRegBlockDescriptor<Problem>());
//------------------------------------------------------------------
// Pre-Load KV into Registers
auto k_block_tile = load_tile(k_dram_window);
auto v_block_tile = load_tile(v_dram_window);
store_tile(k_lds_write_window, k_block_tile);
shuffle_tile(shuffled_k_block_tile, k_block_tile);
store_tile(shuffled_k_lds_write_window, shuffled_k_block_tile);
block_sync_lds();
k_reg_tensor = load_tile(k_lds_read_window);
block_sync_lds();
auto kt_reg_tensor = load_tile(kt_lds_read_window);
store_tile(v_lds_write_window, v_block_tile);
block_sync_lds();
v_reg_tensor = load_tile(v_lds_read_window);
//---------------------------- Loop Load in ----------------------------//
// Q: HBM ->Reg ->LDS
auto q_dram_window =
    make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
                     q_dram_block_window_tmp.get_window_lengths(),
                     {seqlen_q_start, 0},
Policy::template MakeQDramTileDistribution<Problem>());
QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGradT<Problem>()));
auto q_lds = make_tensor_view<address_space_enum::lds>(
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_window =
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
auto q_lds_read_window =
make_tile_window(q_lds_window.get_bottom_tensor_view(),
make_tuple(number<kM0>{}, number<kK0>{}),
q_lds_window.get_window_origin(),
Policy::template MakeQRegSliceBlockDescriptor<Problem>());
auto pt_reg_tensor = make_static_distributed_tensor<GemmDataType>(
    Policy::template MakePTRegSliceBlockDescriptor<Problem>());
// QT: Reg ->Reg ->LDS
auto shuffled_q_block_tile = make_static_distributed_tensor<QDataType>(
Policy::template MakeShuffledQRegWriteBlockDescriptor<Problem>());
QDataType* qt_lds_ptr =
static_cast<QDataType*>(static_cast<void*>(static_cast<char*>(smem_ptr)));
auto shuffled_q_lds_write = make_tensor_view<address_space_enum::lds>(
qt_lds_ptr, Policy::template MakeShuffledQLdsWriteBlockDescriptor<Problem>());
auto shuffled_q_lds_write_window = make_tile_window(
shuffled_q_lds_write, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
auto qt_lds_read = make_tensor_view<address_space_enum::lds>(
qt_lds_ptr, Policy::template MakeQTLdsReadBlockDescriptor<Problem>());
auto qt_lds_read_window =
make_tile_window(qt_lds_read,
make_tuple(number<kQKHeaddim>{}, number<kM0>{}),
{0, 0},
Policy::template MakeQTRegSliceBlockDescriptor<Problem>());
// dO: HBM ->Reg ->LDS
auto do_dram_window =
    make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
                     do_dram_block_window_tmp.get_window_lengths(),
                     {seqlen_q_start, 0},
Policy::template MakeOGradDramTileDistribution<Problem>());
OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
    static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>()));
auto do_lds = make_tensor_view<address_space_enum::lds>(
    do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
auto do_lds_window =
    make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
auto do_lds_read_window =
make_tile_window(do_lds_window.get_bottom_tensor_view(),
make_tuple(number<kM0>{}, number<kK2>{}),
do_lds_window.get_window_origin(),
Policy::template MakeOGradRegSliceBlockDescriptor<Problem>());
// dOT: Reg ->Reg ->LDS
auto shuffled_do_block_tile = make_static_distributed_tensor<OGradDataType>(
Policy::template MakeShuffledOGradRegWriteBlockDescriptor<Problem>());
OGradDataType* dot_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>()));
auto shuffled_do_lds_write = make_tensor_view<address_space_enum::lds>(
dot_lds_ptr, Policy::template MakeShuffledOGradLdsWriteBlockDescriptor<Problem>());
auto shuffled_do_lds_write_window = make_tile_window(
    shuffled_do_lds_write, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
auto dot_read_lds = make_tensor_view<address_space_enum::lds>(
dot_lds_ptr, Policy::template MakeOGradTLdsReadBlockDescriptor<Problem>());
auto dot_lds_read_window =
make_tile_window(dot_read_lds,
make_tuple(number<kVHeaddim>{}, number<kM0>{}),
{0, 0},
Policy::template MakeOGradTRegSliceBlockDescriptor<Problem>());
// dS: Reg -> Reg -> LDS
GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGradT<Problem>() +
Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>() +
Policy::template GetSmemSizeD<Problem>()));
auto ds_lds = make_tensor_view<address_space_enum::lds>(
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
auto ds_lds_window =
make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
auto ds_lds_read_window =
make_tile_window(ds_lds_window.get_bottom_tensor_view(),
make_tuple(number<kM0>{}, number<kK4>{}),
ds_lds_window.get_window_origin(),
Policy::template MakeSGradRegSliceBlockDescriptor<Problem>());
auto dst_reg_tensor = make_static_distributed_tensor<GemmDataType>(
Policy::template MakeSGradTRegSliceBlockDescriptor<Problem>());
// Bias: HBM ->Reg ->Reg ->LDS
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_window =
    make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
                     bias_dram_block_window_tmp.get_window_lengths(),
                     {seqlen_q_start, bias_origin.at(number<1>{})}, // M/N
                     Policy::template MakeBiasTileDistribution<Problem>());
BiasDataType* bias_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
    static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
    Policy::template GetSmemSizeOGrad<Problem>() +
    Policy::template GetSmemSizeOGradT<Problem>() +
    Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>() +
Policy::template GetSmemSizeD<Problem>()));
auto bias_lds = make_tensor_view<address_space_enum::lds>(
bias_lds_ptr, Policy::template MakeBiasLdsBlockDescriptor<Problem>());
auto bias_lds_write_window =
    make_tile_window(bias_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
auto bias_s_lds_read_window =
    make_tile_window(bias_lds_write_window.get_bottom_tensor_view(),
                     bias_lds_write_window.get_window_lengths(),
                     bias_lds_write_window.get_window_origin(),
                     Policy::template MakeBiasSTileDistribution<decltype(gemm_0)>());
static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
"BiasDataType and BiasGradDataType should be the same!");
// LSE: HBM -> LDS ->Reg
auto lse_dram_window = make_tile_window(
    lse_dram_block_window_tmp.get_bottom_tensor_view(),
    lse_dram_block_window_tmp.get_window_lengths(),
    {seqlen_q_start},
    Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
LSEDataType* lse_lds_ptr = static_cast<LSEDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>() +
Policy::template GetSmemSizeOGradT<Problem>() +
Policy::template GetSmemSizeQ<Problem>()));
auto lse_lds = make_tensor_view<address_space_enum::lds>(
lse_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
auto lse_lds_write_window = make_tile_window(lse_lds, make_tuple(number<kM0>{}), {0});
auto lse_lds_read_window = make_tile_window(
lse_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
// D: HBM ->Reg
auto d_dram_window = make_tile_window(
    d_dram_block_window_tmp.get_bottom_tensor_view(),
    d_dram_block_window_tmp.get_window_lengths(),
    {seqlen_q_start},
    Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
DDataType* d_lds_ptr = static_cast<DDataType*>(static_cast<void*>(
    static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeQT<Problem>() +
    Policy::template GetSmemSizeOGrad<Problem>() +
    Policy::template GetSmemSizeOGradT<Problem>() +
    Policy::template GetSmemSizeQ<Problem>() + Policy::template GetSmemSizeLSE<Problem>()));
auto d_lds = make_tensor_view<address_space_enum::lds>(
    d_lds_ptr, Policy::template MakeLSEDLdsWriteBlockDescriptor<Problem>());
auto d_lds_write_window = make_tile_window(d_lds, make_tuple(number<kM0>{}), {0});
auto d_lds_read_window = make_tile_window(
d_lds,
make_tuple(number<kM0>{}),
{0},
Policy::template MakeLSEDLdsReadBlockDescriptor<Problem, decltype(gemm_0)>());
// RandVal: HBM ->Reg
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0), false>(
    randval_dram_block_window_tmp, seqlen_q_start);
// BiasGrad
// Reg ->LDS ->Reg ->HBM
const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
auto dbias_dram_window =
make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
dbias_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
auto dbias_lds_read_window =
make_tile_window(bias_lds,
make_tuple(number<kM0>{}, number<kN0>{}),
{0, 0},
Policy::template MakeShuffledBiasTileDistribution<Problem>());
// ----------------------------Loop write out------------------------------//
auto dq_dram_window = make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
dq_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
using SPBlockTileType = decltype(gemm_0.MakeCBlockTile());
using SPGradBlockTileType = decltype(gemm_2.MakeCBlockTile());
using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
index_t i_total_loops = 0;
index_t seqlen_q_step = seqlen_q_start;
static_assert(kQKHeaddim == kK0, "kQKHeaddim should equal to kK0");
static_assert(kM0 == kK1, "kM0 should equal to kK1");
static_assert(kVHeaddim == kK2, "kVHeaddim should equal to kK2");
static_assert(kM0 == kK3, "kM0 should equal to kK3");
constexpr index_t k4_loops = kN0 / kK4;
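// With kQKHeaddim == kK0, kM0 == kK1, kVHeaddim == kK2 and kM0 == kK3 (asserted above),
// gemm_0..gemm_3 each consume their whole reduction dimension in a single step; only
// the SGrad@K^T gemm below is sliced into k4_loops pieces.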
/*
 * Prefetch Q, LSE, dO, D
 */
auto q_block_tile = load_tile(q_dram_window);
move_tile_window(q_dram_window, {kM0, 0});
auto lse_block_tile = load_tile(lse_dram_window);
move_tile_window(lse_dram_window, {kM0});
auto do_block_tile = load_tile(do_dram_window);
move_tile_window(do_dram_window, {kM0, 0});
auto d_block_tile = load_tile(d_dram_window);
move_tile_window(d_dram_window, {kM0});
/*
 * Store prefetched data into LDS
 */
block_sync_lds();
store_tile(q_lds_window, q_block_tile);
shuffle_tile(shuffled_q_block_tile, q_block_tile);
store_tile(shuffled_q_lds_write_window, shuffled_q_block_tile);
store_tile(lse_lds_write_window, lse_block_tile);
store_tile(do_lds_window, do_block_tile);
shuffle_tile(shuffled_do_block_tile, do_block_tile);
store_tile(shuffled_do_lds_write_window, shuffled_do_block_tile);
store_tile(d_lds_write_window, d_block_tile);
block_sync_lds();
/*
 * Prefetch LDS data into Reg to overlap asynchronous data movement with the MFMA pipeline
 */
auto q_reg_tensor = load_tile(q_lds_read_window);
auto lse = load_tile(lse_lds_read_window);
auto do_reg_tensor = load_tile(do_lds_read_window);
auto d = load_tile(d_lds_read_window);
clear_tile(dv_acc);
clear_tile(dk_acc);
__builtin_amdgcn_sched_barrier(0);
// Hot loop
while(i_total_loops < (num_total_loop - 1))
{
// STAGE 1, Q@K Gemm0
auto s_acc = SPBlockTileType{};
q_block_tile = load_tile(q_dram_window);
move_tile_window(q_dram_window, {kM0, 0});
lse_block_tile = load_tile(lse_dram_window);
move_tile_window(lse_dram_window, {kM0});
do_block_tile = load_tile(do_dram_window);
move_tile_window(do_dram_window, {kM0, 0});
d_block_tile = load_tile(d_dram_window);
move_tile_window(d_dram_window, {kM0});
s_acc = gemm_0(q_reg_tensor, k_reg_tensor);
auto dot_reg_tensor = load_tile(dot_lds_read_window);
HotLoopScheduler::template GemmStagedScheduler<0>();
__builtin_amdgcn_sched_barrier(0);
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
    const auto bias_tile = load_tile(bias_dram_window);
    auto shuffled_bias_tile = make_static_distributed_tensor<BiasDataType>(
        Policy::template MakeShuffledBiasTileDistribution<Problem>());
    shuffle_tile(shuffled_bias_tile, bias_tile);
    store_tile(bias_lds_write_window, shuffled_bias_tile);
    block_sync_lds();
    auto bias_s_tile = load_tile(bias_s_lds_read_window);
    tile_elementwise_inout(
        [&](auto& x, const auto& y) {
            x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
        },
        s_acc,
        bias_s_tile);
    move_tile_window(bias_dram_window, {kM0, 0});
    __builtin_amdgcn_sched_barrier(0);
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
    constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
    sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
        sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
            const auto tile_idx = get_x_indices_from_distributed_indices(
                s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
            const auto row = seqlen_q_step + tile_idx.at(number<0>{});
            const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
            constexpr auto i_j_idx = make_tuple(idx0, idx1);
            s_acc(i_j_idx) *= scale;
            position_encoding.update(s_acc(i_j_idx), row, col);
        });
    });
}
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
    bool need_perpixel_check = mask.IsEdgeTile(
        seqlen_q_step, k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
    if(need_perpixel_check)
    {
        set_tile_if(s_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
            const auto row = seqlen_q_step + tile_idx.at(number<0>{});
            const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
            return mask.IsOutOfBound(row, col);
        });
    }
}
static const auto get_validated_lse = [](LSEDataType raw_lse) {
    if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                 FmhaMask::IsMasking)
...@@ -570,278 +621,416 @@ struct BlockFmhaBwdDQDKDVPipelineKSKTSVR
    }
};
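// get_validated_lse maps lse == -inf (a fully masked-out row) to 0 so the exp2 below
// stays finite; the lambda body elided by the hunk marker above is spelled out
// verbatim in the tail section further down.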
auto p = SPBlockTileType{};
constexpr auto p_spans = decltype(p)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
    constexpr auto i_idx = make_tuple(idx0);
    auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
    sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
        constexpr auto i_j_idx = make_tuple(idx0, idx1);
        if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
                     BiasEnum == BlockAttentionBiasEnum::ALIBI)
        {
            p(i_j_idx) = exp2(s_acc[i_j_idx] - row_lse);
        }
        else
        {
            p(i_j_idx) = exp2(scale * s_acc[i_j_idx] - row_lse);
        }
    });
});
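// Fast-exp2 softmax: in the bias paths, log2(e) is folded into row_lse and the bias
// term during STAGE 2; in the no-bias path `scale` is assumed to already carry the
// log2(e) factor, matching the CK_TILE_FMHA_FWD_FAST_EXP2 convention of the forward
// pipelines (visible in the #if branches of the KSVR pipeline below).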
if constexpr(FmhaDropout::IsDropout)
{
    dropout.template Run<decltype(gemm_0), RandValOutputDataType>(
        seqlen_q_step, k_origin.at(number<0>{}), p, randval_dram_window);
}
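// dropout.Run encodes the keep/drop decision in the sign of p (dropped entries become
// negative); p_gemm below clamps them to 0 and the dS computation tests p >= 0.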
const auto p_gemm = [&]() {
    if constexpr(FmhaDropout::IsDropout)
    {
        return tile_elementwise_in(
            [](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
            p);
    }
    else
    {
        return cast_tile<GemmDataType>(p);
    }
}();
// STAGE 3, P^T@OGrad^T Gemm1
Policy::template PTFromGemm0CToGemm1A<Problem,
                                      decltype(pt_reg_tensor),
                                      decltype(p_gemm)>(pt_reg_tensor, p_gemm);
gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
auto qt_reg_tensor = load_tile(qt_lds_read_window);
HotLoopScheduler::template GemmStagedScheduler<1>();
__builtin_amdgcn_sched_barrier(0);
// STAGE 4, OGrad@V Gemm2
auto dp_acc = SPGradBlockTileType{};
dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
block_sync_lds();
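// With gemm_0/1/2 for this tile issued, the freshly prefetched global tiles can be
// committed to LDS for the next iteration while the MFMAs drain.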
store_tile(q_lds_window, q_block_tile);
shuffle_tile(shuffled_q_block_tile, q_block_tile);
store_tile(shuffled_q_lds_write_window, shuffled_q_block_tile);
store_tile(lse_lds_write_window, lse_block_tile);
store_tile(do_lds_window, do_block_tile);
shuffle_tile(shuffled_do_block_tile, do_block_tile);
store_tile(shuffled_do_lds_write_window, shuffled_do_block_tile);
store_tile(d_lds_write_window, d_block_tile);
HotLoopScheduler::template GemmStagedScheduler<2>();
__builtin_amdgcn_sched_barrier(0);
// STAGE 5, P^T(PGrad^T - D)
auto ds = SPGradBlockTileType{};
constexpr auto ds_spans = decltype(ds)::get_distributed_spans();
sweep_tile_span(ds_spans[number<0>{}], [&](auto idx0) {
    constexpr auto i_idx = make_tuple(idx0);
    sweep_tile_span(ds_spans[number<1>{}], [&](auto idx1) {
        constexpr auto i_j_idx = make_tuple(idx0, idx1);
        bool undrop_flag = p[i_j_idx] >= 0;
        ds(i_j_idx) = p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag
                                        ? (dp_acc[i_j_idx] - d[i_idx])
                                        : d[i_idx]);
    });
});
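// Elementwise dS = P * (dP - D), with the per-row D term streamed in via the D window;
// for dropped entries (p < 0) the bracket is replaced by D, consistent with the sign
// encoding produced by dropout.Run above.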
if constexpr(kHasBiasGrad)
{
    const auto dbias = [&]() {
        if constexpr(FmhaDropout::IsDropout)
        {
            return tile_elementwise_in(
                [&rp_undrop](const auto& x) {
                    return type_convert<BiasGradDataType>(x * rp_undrop);
                },
                ds);
        }
        else
        {
            return cast_tile<BiasGradDataType>(ds);
        }
    }();
    store_tile(bias_lds_write_window, dbias);
    block_sync_lds();
    auto shuffled_dbias_tile = load_tile(dbias_lds_read_window);
    auto dbias_tile = make_static_distributed_tensor<BiasGradDataType>(
        Policy::template MakeBiasTileDistribution<Problem>());
    shuffle_tile(dbias_tile, shuffled_dbias_tile);
    store_tile(dbias_dram_window, dbias_tile);
    move_tile_window(dbias_dram_window, {kM0, 0});
    __builtin_amdgcn_sched_barrier(0);
}
// STAGE 6, SGrad^T@Q^T Gemm3
const auto ds_gemm = cast_tile<GemmDataType>(ds);
Policy::template SGradTFromGemm2CToGemm3A<Problem,
decltype(dst_reg_tensor),
decltype(ds_gemm)>(dst_reg_tensor, ds_gemm);
gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
store_tile(ds_lds_window, ds_gemm);
block_sync_lds();
auto ds_reg_tensor = load_tile(ds_lds_read_window);
auto ds_reg_tensor_next = decltype(ds_reg_tensor){};
move_tile_window(ds_lds_read_window, {0, kK4});
q_reg_tensor = load_tile(q_lds_read_window);
lse = load_tile(lse_lds_read_window);
HotLoopScheduler::template GemmStagedScheduler<3>();
__builtin_amdgcn_sched_barrier(0);
// STAGE 7, SGrad@K^T Gemm4
auto dq_acc = QGradBlockTileType{};
clear_tile(dq_acc);
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
if constexpr(i_k4 < k4_loops - 1)
{
ds_reg_tensor_next = load_tile(ds_lds_read_window);
move_tile_window(ds_lds_read_window, {0, kK4});
}
auto kt_reg_tensor_slice = get_slice_tile(kt_reg_tensor,
sequence<0, i_k4 * kK4>{},
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{});
gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice);
if constexpr(i_k4 < k4_loops - 1)
{
ds_reg_tensor.get_thread_buffer() = ds_reg_tensor_next.get_thread_buffer();
}
});
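// The dS slices read back from LDS are double-buffered (ds_reg_tensor_next) so each
// slice's load overlaps the previous slice's gemm_4.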
move_tile_window(ds_lds_read_window, {0, -kN0});
do_reg_tensor = load_tile(do_lds_read_window);
d = load_tile(d_lds_read_window);
HotLoopScheduler::template GemmStagedScheduler<4>();
// QGrad Scale
if constexpr(FmhaDropout::IsDropout)
{
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
dq_acc);
}
else
{
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
}
if constexpr(kIsDeterministic)
{
    store_tile(dq_dram_window, dq_acc);
}
else
{
update_tile(dq_dram_window, dq_acc);
}
move_tile_window(dq_dram_window, {kM0, 0});
i_total_loops += 1;
seqlen_q_step += kM0;
}
__builtin_amdgcn_sched_barrier(0);
// Tail
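// The tail repeats the per-tile stages for the final row block, without prefetching
// or staging data for a further iteration.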
auto s_acc = SPBlockTileType{};
// STAGE 1, Q@K Gemm0
s_acc = gemm_0(q_reg_tensor, k_reg_tensor);
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
const auto bias_tile = load_tile(bias_dram_window);
auto shuffled_bias_tile = make_static_distributed_tensor<BiasDataType>(
Policy::template MakeShuffledBiasTileDistribution<Problem>());
shuffle_tile(shuffled_bias_tile, bias_tile);
store_tile(bias_lds_write_window, shuffled_bias_tile);
block_sync_lds();
auto bias_s_tile = load_tile(bias_s_lds_read_window);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
},
s_acc,
bias_s_tile);
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
constexpr auto s_spans = decltype(s_acc)::get_distributed_spans();
sweep_tile_span(s_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(s_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
s_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = seqlen_q_step + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
s_acc(i_j_idx) *= scale;
position_encoding.update(s_acc(i_j_idx), row, col);
});
});
}
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
bool need_perpixel_check = mask.IsEdgeTile(
seqlen_q_step, k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
if(need_perpixel_check)
{ {
        set_tile_if(s_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
            const auto row = seqlen_q_step + tile_idx.at(number<0>{});
            const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
            return mask.IsOutOfBound(row, col);
        });
    }
}
static const auto get_validated_lse = [](LSEDataType raw_lse) {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{ {
        return raw_lse == -numeric<LSEDataType>::infinity() ? type_convert<LSEDataType>(0.f)
                                                            : raw_lse;
} }
else
{
return raw_lse;
}
};
auto p = SPBlockTileType{};
constexpr auto p_spans = decltype(p)::get_distributed_spans();
sweep_tile_span(p_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
    sweep_tile_span(p_spans[number<1>{}], [&](auto idx1) {
        constexpr auto i_j_idx = make_tuple(idx0, idx1);
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
p(i_j_idx) = exp2(s_acc[i_j_idx] - row_lse);
}
else
{
p(i_j_idx) = exp2(scale * s_acc[i_j_idx] - row_lse);
}
});
});
if constexpr(FmhaDropout::IsDropout)
{
dropout.template Run<decltype(gemm_0), RandValOutputDataType>(
seqlen_q_step, k_origin.at(number<0>{}), p, randval_dram_window);
}
// STAGE 3, P^T@OGrad^T Gemm1
const auto p_gemm = [&]() {
    if constexpr(FmhaDropout::IsDropout)
    {
        return tile_elementwise_in(
            [](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); }, p);
    }
    else
{
return cast_tile<GemmDataType>(p);
}
}();
Policy::template PTFromGemm0CToGemm1A<Problem, decltype(pt_reg_tensor), decltype(p_gemm)>(
pt_reg_tensor, p_gemm);
auto dot_reg_tensor = load_tile(dot_lds_read_window);
gemm_1(dv_acc, pt_reg_tensor, dot_reg_tensor);
HotLoopScheduler::template GemmStagedScheduler<1>();
// STAGE 4, OGrad@V Gemm2
auto dp_acc = SPGradBlockTileType{};
auto qt_reg_tensor = load_tile(qt_lds_read_window);
dp_acc = gemm_2(do_reg_tensor, v_reg_tensor);
HotLoopScheduler::template GemmStagedScheduler<2>();
// STAGE 5, P^T(PGrad^T - D)
auto ds = SPGradBlockTileType{};
constexpr auto ds_spans = decltype(ds)::get_distributed_spans();
sweep_tile_span(ds_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(ds_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
bool undrop_flag = p[i_j_idx] >= 0;
ds(i_j_idx) = p[i_j_idx] * (!FmhaDropout::IsDropout || undrop_flag
? (dp_acc[i_j_idx] - d[i_idx])
: d[i_idx]);
}); });
});
if constexpr(kHasBiasGrad)
{
const auto dbias = [&]() {
if constexpr(FmhaDropout::IsDropout)
{
return tile_elementwise_in(
[&rp_undrop](const auto& x) {
return type_convert<BiasGradDataType>(x * rp_undrop);
},
ds);
}
else
{
return cast_tile<BiasGradDataType>(ds);
}
}();
store_tile(bias_lds_write_window, dbias);
block_sync_lds();
auto shuffled_dbias_tile = load_tile(dbias_lds_read_window);
auto dbias_tile = make_static_distributed_tensor<BiasGradDataType>(
Policy::template MakeBiasTileDistribution<Problem>());
shuffle_tile(dbias_tile, shuffled_dbias_tile);
store_tile(dbias_dram_window, dbias_tile);
}
// STAGE 6, SGrad^T@Q^T Gemm3
const auto ds_gemm = cast_tile<GemmDataType>(ds);
Policy::template SGradTFromGemm2CToGemm3A<Problem,
decltype(dst_reg_tensor),
decltype(ds_gemm)>(dst_reg_tensor, ds_gemm);
gemm_3(dk_acc, dst_reg_tensor, qt_reg_tensor);
store_tile(ds_lds_window, ds_gemm);
block_sync_lds();
auto ds_reg_tensor = load_tile(ds_lds_read_window);
auto ds_reg_tensor_next = decltype(ds_reg_tensor){};
move_tile_window(ds_lds_read_window, {0, kK4});
HotLoopScheduler::template GemmStagedScheduler<3>();
// STAGE 7, SGrad@K^T Gemm4
auto dq_acc = QGradBlockTileType{};
clear_tile(dq_acc);
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
if constexpr(i_k4 < k4_loops - 1)
    {
        ds_reg_tensor_next = load_tile(ds_lds_read_window);
        move_tile_window(ds_lds_read_window, {0, kK4});
    }
    auto kt_reg_tensor_slice = get_slice_tile(
        kt_reg_tensor, sequence<0, i_k4 * kK4>{}, sequence<kQKHeaddim, (i_k4 + 1) * kK4>{});
    gemm_4(dq_acc, ds_reg_tensor, kt_reg_tensor_slice);
    if constexpr(i_k4 < k4_loops - 1)
    {
        ds_reg_tensor.get_thread_buffer() = ds_reg_tensor_next.get_thread_buffer();
    }
});
HotLoopScheduler::template GemmStagedScheduler<4>();
// Results Scale
if constexpr(FmhaDropout::IsDropout)
{
    tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
                           dq_acc);
    tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
                           dk_acc);
    tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
}
else
{
    tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
    tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
}
if constexpr(kIsDeterministic)
{
    store_tile(dq_dram_window, dq_acc);
}
else
{
update_tile(dq_dram_window, dq_acc);
} }
return make_tuple(dk_acc, dv_acc);
} }
}; };
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// In this pipeline, v is located in regs while k & k^t are located in lds.
using BlockFmhaBwdDQDKDVPipelineKSKTSVRDefaultPolicy =
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
/* QTLoadOnce_ = */ false,
/* KLoadOnce_ = */ true,
/* KTLoadOnce_ = */ true,
/* VLoadOnce_ = */ true,
/* OGradLoadOnce_ = */ false,
/* OGradTLoadOnce_ = */ false>;
} // namespace ck_tile
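// A minimal usage sketch (hypothetical problem type; the pipeline is assumed to take a
// Problem plus this policy, mirroring the BlockFmhaBwdDQDKDVPipelineKSVR declaration
// below):
//
//   using BwdPipeline =
//       BlockFmhaBwdDQDKDVPipelineKSKTSVR<MyFmhaBwdProblem,
//                                         BlockFmhaBwdDQDKDVPipelineKSKTSVRDefaultPolicy>;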
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
template <typename Problem, typename Policy = BlockFmhaBwdDQDKDVPipelineKSVRDefaultPolicy>
struct BlockFmhaBwdDQDKDVPipelineKSVR
{
using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
using GemmDataType = remove_cvref_t<typename Problem::GemmDataType>;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using DDataType = remove_cvref_t<typename Problem::DDataType>;
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
using KGradDataType = remove_cvref_t<typename Problem::KGradDataType>;
using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>;
using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK2 = BlockFmhaShape::kK2;
static constexpr index_t kK3 = BlockFmhaShape::kK3;
static constexpr index_t kK4 = BlockFmhaShape::kK4;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
static constexpr bool kQLoadOnce = false;
static constexpr bool kQTLoadOnce = false;
static constexpr bool kKLoadOnce = true;
static constexpr bool kKTLoadOnce = false;
static constexpr bool kVLoadOnce = true;
static constexpr bool kOGradLoadOnce = false;
static constexpr bool kOGradTLoadOnce = false;
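// Residency summary encoded by the flags above: K stays resident in LDS and V in
// registers for the whole loop (hence the name "ks_vr" below), while Q, Q^T, OGrad and
// OGrad^T are re-streamed from HBM every iteration; K^T is a transposed LDS view of the
// same K tile.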
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
static constexpr bool kHasDropout = Problem::kHasDropout;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
static constexpr index_t kAlignmentQ =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV =
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentOGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
static constexpr index_t kAlignmentQGrad =
kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad<Problem>();
static constexpr index_t kAlignmentKGrad =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
static constexpr index_t kAlignmentVGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
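// When a dimension can be runtime-padded, full-width vectorized access cannot be
// guaranteed, so the alignment used for the tensor view (and thus the buffer_load
// vector width) falls back to a minimal value.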
static constexpr const char* name = "ks_vr";
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename QDramBlockWindowTmp,
typename QTDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename KTDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename OGradDramBlockWindowTmp,
typename OGradTDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename DDramBlockWindowTmp,
typename QGradDramBlockWindowTmp,
typename BiasGradDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
const QTDramBlockWindowTmp& qt_dram_block_window_tmp,
const KDramBlockWindowTmp& k_dram_block_window_tmp,
const KTDramBlockWindowTmp& /*kt_dram_block_window_tmp*/,
const VDramBlockWindowTmp& v_dram_block_window_tmp,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
const OGradTDramBlockWindowTmp& dot_dram_block_window_tmp,
const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
const DDramBlockWindowTmp& d_dram_block_window_tmp,
const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
FmhaMask mask,
PositionEncoding position_encoding,
float raw_scale,
#if CK_TILE_FMHA_FWD_FAST_EXP2
float scale,
#endif
float rp_undrop,
float scale_rp_undrop,
void* smem_ptr,
BlockDropout& dropout) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<QDataType,
remove_cvref_t<typename QTDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
std::is_same_v<OGradDataType,
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
std::is_same_v<OGradDataType,
remove_cvref_t<typename OGradTDramBlockWindowTmp::DataType>> &&
std::is_same_v<LSEDataType,
remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>> &&
std::is_same_v<QGradDataType,
remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kQKHeaddim == QTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kVHeaddim ==
OGradTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
// Q tile in LDS
QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto q_lds = make_tensor_view<address_space_enum::lds>(
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_window =
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
// QT tile in LDS
QDataType* qt_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto qt_lds = make_tensor_view<address_space_enum::lds>(
qt_lds_ptr, Policy::template MakeQTLdsBlockDescriptor<Problem>());
auto qt_lds_window =
make_tile_window(qt_lds, make_tuple(number<kQKHeaddim>{}, number<kK3>{}), {0, 0});
// K tile in LDS
auto k_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<KDataType*>(smem_ptr),
Policy::template MakeKLdsBlockDescriptor<Problem>());
auto k_lds_window =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
// KT tile in LDS
auto kt_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<KDataType*>(smem_ptr),
Policy::template MakeKLdsBlockDescriptorAsKT<Problem>());
auto kt_lds_window =
make_tile_window(kt_lds, make_tuple(number<kQKHeaddim>{}, number<kN0>{}), {0, 0});
// OGrad tile in LDS
OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto do_lds = make_tensor_view<address_space_enum::lds>(
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
auto do_lds_window =
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
// OGradT tile in LDS
OGradDataType* dot_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto dot_lds = make_tensor_view<address_space_enum::lds>(
dot_lds_ptr, Policy::template MakeOGradTLdsBlockDescriptor<Problem>());
auto dot_lds_window =
make_tile_window(dot_lds, make_tuple(number<kVHeaddim>{}, number<kK1>{}), {0, 0});
// SGrad tile in LDS
GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto ds_lds = make_tensor_view<address_space_enum::lds>(
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
auto ds_lds_window =
make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
// BiasT/BiasGradT tile in LDS, use the same size and layout
BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto biast_lds = make_tensor_view<address_space_enum::lds>(
biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>());
auto biast_lds_shuffle_window =
make_tile_window(biast_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
auto dbiast_lds_shuffle_window =
make_tile_window(biast_lds,
make_tuple(number<kM0>{}, number<kN0>{}),
{0, 0},
Policy::template MakeShuffledBiasTileDistribution<Problem>());
static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
"BiasDataType and BiasGradDataType should be the same!");
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();
constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();
auto v_dram_window = make_tile_window(
v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
v_dram_block_window_tmp.get_window_origin(),
Policy::template MakeVInRegDramTileDistribution<Problem, decltype(gemm_2)>());
auto v = load_tile(v_dram_window); // persistent V register tile
using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile());
using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile());
using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
// init VGrad & KGrad
auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};
clear_tile(dv_acc);
clear_tile(dk_acc);
auto k_dram_window = make_tile_window(
k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
k_dram_block_window_tmp.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load
__builtin_amdgcn_sched_barrier(0);
const auto k_origin = k_dram_window.get_window_origin();
const auto [seqlen_q_start, seqlen_q_end] =
mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);
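// GetTileRangeAlongY narrows the Q-row range to rows that can interact with this K/V
// tile under the mask; num_total_loop counts the kM0-sized row blocks in that range.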
// check early exit if masked and no work to do.
if constexpr(FmhaMask::IsMasking)
{
if(num_total_loop <= 0)
{
// Note: dk_acc & dv_acc have been cleared above, so return them as-is.
// Note: v was loaded without a fence; it is unused on this path, so ignore it.
return ck_tile::make_tuple(dk_acc, dv_acc);
}
}
auto k_block_tile = load_tile(k_dram_window);
store_tile(k_lds_window, k_block_tile); // persistent K in LDS
auto q_dram_block_window =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
auto qt_dram_block_window =
make_tile_window(qt_dram_block_window_tmp.get_bottom_tensor_view(),
qt_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_q_start});
auto do_dram_block_window =
make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
do_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
auto dot_dram_block_window =
make_tile_window(dot_dram_block_window_tmp.get_bottom_tensor_view(),
dot_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_q_start});
auto dq_dram_block_window =
make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
dq_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
auto lse_dram_block_window =
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
lse_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start});
auto d_dram_block_window =
make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
d_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start});
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_block_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, bias_origin.at(number<1>{})}); // M/N
const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
auto dbias_dram_block_window =
make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
dbias_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
auto qt_dram_window =
make_tile_window(qt_dram_block_window.get_bottom_tensor_view(),
qt_dram_block_window.get_window_lengths(),
qt_dram_block_window.get_window_origin(),
Policy::template MakeQTDramTileDistribution<Problem>());
auto dot_dram_window =
make_tile_window(dot_dram_block_window.get_bottom_tensor_view(),
dot_dram_block_window.get_window_lengths(),
dot_dram_block_window.get_window_origin(),
Policy::template MakeOGradTDramTileDistribution<Problem>());
auto lse_dram_window = make_tile_window(
lse_dram_block_window.get_bottom_tensor_view(),
lse_dram_block_window.get_window_lengths(),
lse_dram_block_window.get_window_origin(),
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto d_dram_window = make_tile_window(
d_dram_block_window.get_bottom_tensor_view(),
d_dram_block_window.get_window_lengths(),
d_dram_block_window.get_window_origin(),
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto bias_dram_window =
make_tile_window(bias_dram_block_window.get_bottom_tensor_view(),
bias_dram_block_window.get_window_lengths(),
bias_dram_block_window.get_window_origin(),
Policy::template MakeBiasTileDistribution<Problem>());
auto biast_lds_window =
make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(),
biast_lds_shuffle_window.get_window_lengths(),
biast_lds_shuffle_window.get_window_origin(),
Policy::template MakeBiasTTileDistribution<decltype(gemm_0)>());
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0), false>(
randval_dram_block_window_tmp, seqlen_q_start);
index_t i_total_loops = 0;
constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kM0 / kK1;
constexpr index_t k2_loops = kVHeaddim / kK2;
constexpr index_t k3_loops = kM0 / kK3;
constexpr index_t k4_loops = kN0 / kK4;
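// In this pipeline every gemm walks its reduction dimension in kK*-sized steps; the
// software-pipelined static_for loops below peel the last iterations as a tail (see
// the k0_loops > 2 guard).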
do
{
auto q_dram_window = make_tile_window(
q_dram_block_window.get_bottom_tensor_view(),
q_dram_block_window.get_window_lengths(),
q_dram_block_window.get_window_origin(),
Policy::template MakeQDramTileDistribution<Problem>()); // Q DRAM tile window for
// load
auto do_dram_window = make_tile_window(
do_dram_block_window.get_bottom_tensor_view(),
do_dram_block_window.get_window_lengths(),
do_dram_block_window.get_window_origin(),
Policy::template MakeOGradDramTileDistribution<Problem>()); // OGrad DRAM tile window for load
// STAGE 1, Q@K Gemm0
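// Recompute the attention-score tile S^T = (Q@K^T)^T for this kM0 x kN0
// block. K is already resident in LDS (k_lds_window); Q is streamed in
// kK0-wide slices along the head dimension with a two-stage pipeline:
// while gemm_0 consumes slice i from LDS, slice i+1 is written to LDS and
// slice i+2 is fetched from global memory.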
auto st_acc = SPTBlockTileType{};
auto q_block_tile = load_tile(q_dram_window);
{
move_tile_window(q_dram_window, {0, kK0});
clear_tile(st_acc); // Initialize S^T
store_tile(q_lds_window, q_block_tile); // LDS write 0
q_block_tile = load_tile(q_dram_window); // global read 1
}
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
__builtin_amdgcn_sched_barrier(
0); // prevent the scheduler from reordering the global loads
}
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
__builtin_amdgcn_sched_barrier(
0); // prevent the scheduler from reordering the global loads
}
if constexpr(k0_loops > 2)
{
static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
block_sync_lds();
gemm_0(st_acc,
q_lds_window,
get_slice_tile(k_lds_window,
sequence<0, i_k0 * kK0>{},
sequence<kN0, (i_k0 + 1) * kK0>{}));
block_sync_lds();
move_tile_window(q_dram_window, {0, kK0});
store_tile(q_lds_window,
q_block_tile); // LDS write i + 1
q_block_tile = load_tile(q_dram_window); // global read i + 2
});
}
const auto dot_prefetch = load_tile(dot_dram_window); // prefetch load OGrad^T tile
{ // tail
block_sync_lds();
gemm_0(st_acc,
q_lds_window,
get_slice_tile(k_lds_window,
sequence<0, (k0_loops - 2) * kK0>{},
sequence<kN0, (k0_loops - 1) * kK0>{}));
block_sync_lds();
store_tile(q_lds_window, q_block_tile);
block_sync_lds();
gemm_0(st_acc,
q_lds_window,
get_slice_tile(k_lds_window,
sequence<0, (k0_loops - 1) * kK0>{},
sequence<kN0, k0_loops * kK0>{}));
}
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
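// Recompute P^T = softmax(scale * S^T) for the tile without any row
// reduction: the forward pass stored the per-row log-sum-exp (LSE), so
// P_ij = exp(scale * S_ij - LSE_i) (exp2 with a folded log2e on the fast
// path). Bias/ALiBi and masking are applied to st_acc first, mirroring the
// forward kernel, so the recomputed P agrees with the forward P.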
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
block_sync_lds();
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
Policy::template MakeShuffledBiasTileDistribution<Problem>());
shuffle_tile(bias_shuffle_tmp, bias_tile);
store_tile(biast_lds_shuffle_window, bias_shuffle_tmp);
block_sync_lds();
auto biast_tile = load_tile(biast_lds_window);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x = raw_scale * x + type_convert<AccDataType>(y);
#else
x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
#endif
},
st_acc,
biast_tile);
move_tile_window(bias_dram_window, {kM0, 0});
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
const auto q_origin = q_dram_block_window.get_window_origin();
constexpr auto st_spans = decltype(st_acc)::get_distributed_spans();
sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
st_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
st_acc(i_j_idx) *= raw_scale;
#else
st_acc(i_j_idx) *= scale;
#endif
position_encoding.update(st_acc(i_j_idx), row, col);
});
});
}
else
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc);
#endif
}
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
const auto q_origin = q_dram_block_window.get_window_origin();
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
k_origin.at(number<0>{}),
number<kM0>{},
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(st_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
}
}
const auto lse = load_tile(lse_dram_window);
static const auto get_validated_lse = [](LSEDataType raw_lse) {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return raw_lse == -numeric<LSEDataType>::infinity()
? type_convert<LSEDataType>(0.f)
: raw_lse;
}
else
{
return raw_lse;
}
};
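// A fully masked row carries LSE = -inf from the forward pass; substituting
// 0 avoids the (-inf) - (-inf) = NaN path below (st_acc is also -inf on
// such rows, so exp still evaluates to 0).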
auto pt = SPTBlockTileType{};
constexpr auto pt_spans = decltype(pt)::get_distributed_spans();
sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
#endif
sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse);
}
else
{
pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse);
}
#else
pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx]));
#endif
});
});
auto dot_shuffle_tmp = make_static_distributed_tensor<OGradDataType>(
Policy::template MakeShuffledOGradTRegBlockDescriptor<Problem>());
block_sync_lds();
{
shuffle_tile(dot_shuffle_tmp, dot_prefetch);
store_tile(dot_lds_window,
dot_shuffle_tmp); // store the prefetch
}
move_tile_window(dot_dram_window, {0, kK1});
if constexpr(kHasDropout)
{
dropout.Run<decltype(gemm_0), RandValOutputDataType>(
seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window);
}
// STAGE 3, P^T@OGrad^T Gemm1
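// dV accumulation: dV += P^T @ dO, consumed in kK1-wide slices of the M
// dimension. With dropout, dropout.Run above encodes the keep/drop mask in
// the sign of pt, so dropped (negative) entries are zeroed here before the
// GEMM; the 1/(1-p) rescale is deferred to the epilogue.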
const auto pt_gemm = [&]() {
if constexpr(kHasDropout)
{
return tile_elementwise_in(
[](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
pt);
}
else
{
return cast_tile<GemmDataType>(pt);
}
}();
if constexpr(k1_loops > 1)
{
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
const auto dot = load_tile(dot_dram_window); // load next OGrad^T
block_sync_lds();
gemm_1(dv_acc,
get_slice_tile(pt_gemm,
sequence<i_k1 * kK1, 0>{},
sequence<(i_k1 + 1) * kK1, kN0>{}),
dot_lds_window);
block_sync_lds();
shuffle_tile(dot_shuffle_tmp, dot);
store_tile(dot_lds_window,
dot_shuffle_tmp); // store the prefetch
move_tile_window(dot_dram_window, {0, kK1});
});
}
auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile
// tail
{
block_sync_lds();
gemm_1(dv_acc,
get_slice_tile(
pt_gemm, sequence<(k1_loops - 1) * kK1, 0>{}, sequence<kM0, kN0>{}),
dot_lds_window);
block_sync_lds();
}
// STAGE 4, OGrad@V Gemm2
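// Recompute dP^T for the tile (dP = dO @ V^T). V stays resident in
// registers for the whole kernel (VLoadOnce in this policy), so only dO is
// streamed through LDS in kK2-wide slices, using the same two-stage
// pipeline as STAGE 1.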
auto dpt_acc = SPGradTBlockTileType{};
{
move_tile_window(do_dram_window, {0, kK2});
clear_tile(dpt_acc); // Initialize PGrad^T
store_tile(do_lds_window, do_block_tile); // LDS write 0
do_block_tile = load_tile(do_dram_window); // global read 1
}
if constexpr(k2_loops > 2)
{
static_for<0, k2_loops - 2, 1>{}([&](auto i_k2) {
block_sync_lds();
gemm_2(dpt_acc,
do_lds_window,
get_slice_tile(
v, sequence<0, i_k2 * kK2>{}, sequence<kN0, (i_k2 + 1) * kK2>{}));
block_sync_lds();
move_tile_window(do_dram_window, {0, kK2});
store_tile(do_lds_window,
do_block_tile); // LDS write i + 1
do_block_tile = load_tile(do_dram_window); // global read i + 2
});
}
const auto qt_prefetch = load_tile(qt_dram_window); // prefetch load Q^T tile
{ // tail
block_sync_lds();
gemm_2(dpt_acc,
do_lds_window,
get_slice_tile(v,
sequence<0, (k2_loops - 2) * kK2>{},
sequence<kN0, (k2_loops - 1) * kK2>{}));
block_sync_lds();
store_tile(do_lds_window, do_block_tile);
block_sync_lds();
gemm_2(dpt_acc,
do_lds_window,
get_slice_tile(v,
sequence<0, (k2_loops - 1) * kK2>{},
sequence<kN0, k2_loops * kK2>{}));
}
// STAGE 5, P^T(PGrad^T - D)
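// Softmax backward: since P = softmax(S), the Jacobian gives
// dS_ij = P_ij * (dP_ij - sum_k P_ik * dP_ik), and that row sum collapses
// to D_i = sum_k dO_ik * O_ik, which is precomputed and loaded here through
// d_dram_window.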
const auto d = load_tile(d_dram_window);
auto dst = SPGradTBlockTileType{};
constexpr auto dst_spans = decltype(dst)::get_distributed_spans();
sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
bool undrop_flag = pt[i_j_idx] >= 0;
dst(i_j_idx) =
pt[i_j_idx] *
(!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]);
});
});
if constexpr(kHasBiasGrad)
{
const auto dbiast = [&]() {
if constexpr(kHasDropout)
{
return tile_elementwise_in(
[&rp_undrop](const auto& x) {
return type_convert<BiasGradDataType>(x * rp_undrop);
},
dst);
}
else
{
return cast_tile<BiasGradDataType>(dst);
}
}();
store_tile(biast_lds_shuffle_window, dbiast);
block_sync_lds();
auto dbiast_tile = load_tile(dbiast_lds_shuffle_window);
auto dbiast_shuffle_tmp = make_static_distributed_tensor<BiasGradDataType>(
Policy::template MakeBiasTileDistribution<Problem>());
shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
store_tile(dbias_dram_block_window, dbiast_shuffle_tmp);
move_tile_window(dbias_dram_block_window, {kM0, 0});
}
// STAGE 6, SGrad^T@Q^T Gemm3
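// dK accumulation: dK += dS^T @ Q. Q is re-read in transposed form (qt),
// shuffled into the LDS layout gemm_3 expects, and consumed in kK3-wide
// slices with the usual prefetch-one-slice-ahead pipeline.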
auto qt_shuffle_tmp = make_static_distributed_tensor<QDataType>(
Policy::template MakeShuffledQTRegBlockDescriptor<Problem>());
block_sync_lds();
{
shuffle_tile(qt_shuffle_tmp, qt_prefetch);
store_tile(qt_lds_window,
qt_shuffle_tmp); // store the prefetch
}
move_tile_window(qt_dram_window, {0, kK3});
const auto dst_gemm = cast_tile<GemmDataType>(dst);
if constexpr(k3_loops > 1)
{
static_for<0, k3_loops - 1, 1>{}([&](auto i_k3) {
const auto qt = load_tile(qt_dram_window); // load next Q^T
block_sync_lds();
gemm_3(dk_acc,
get_slice_tile(dst_gemm,
sequence<i_k3 * kK3, 0>{},
sequence<(i_k3 + 1) * kK3, kN0>{}),
qt_lds_window);
block_sync_lds();
shuffle_tile(qt_shuffle_tmp, qt);
store_tile(qt_lds_window,
qt_shuffle_tmp); // store the prefetch
move_tile_window(qt_dram_window, {0, kK3});
});
}
// tail
{
block_sync_lds();
gemm_3(dk_acc,
get_slice_tile(
dst_gemm, sequence<(k3_loops - 1) * kK3, 0>{}, sequence<kM0, kN0>{}),
qt_lds_window);
block_sync_lds();
}
// STAGE 7, SGrad@K^T Gemm4
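// dQ for this row tile: dQ += dS @ K. dS (dst_gemm) is staged through LDS
// so it can be re-sliced along the kN0 dimension in kK4-wide chunks against
// the resident K^T tile (kt_lds_window); the result is written back with
// update_tile, since each k/v tile contributes a partial dQ for the same
// rows.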
store_tile(ds_lds_window, dst_gemm);
auto dq_acc = QGradBlockTileType{};
clear_tile(dq_acc); // Initialize QGrad
block_sync_lds();
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
gemm_4(dq_acc,
get_slice_tile(ds_lds_window,
sequence<0, i_k4 * kK4>{},
sequence<kM0, (i_k4 + 1) * kK4>{}),
get_slice_tile(kt_lds_window,
sequence<0, i_k4 * kK4>{},
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{}));
});
// QGrad Scale
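// The chain rule through S = scale * (Q@K^T) attaches the softmax scale to
// dQ; with dropout the 1/(1-p) factor is folded into the same multiply
// (scale_rp_undrop).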
if constexpr(kHasDropout)
{
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
dq_acc);
}
else
{
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
}
const auto dq = cast_tile<QGradDataType>(dq_acc);
update_tile(dq_dram_block_window, dq);
// move tile windows
move_tile_window(q_dram_block_window, {kM0, 0});
move_tile_window(dq_dram_block_window, {kM0, 0});
move_tile_window(do_dram_block_window, {kM0, 0});
move_tile_window(lse_dram_window, {kM0});
move_tile_window(d_dram_window, {kM0});
} while(++i_total_loops < num_total_loop);
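// Epilogue, once per k/v tile after all row tiles are processed: dK takes
// the same scale (plus dropout rescale) as dQ, while dV needs only the
// dropout rescale, since dV = P^T @ dO carries no softmax scale factor.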
// KGrad Scale
if constexpr(kHasDropout)
{
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
dk_acc);
}
else
{
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
}
// VGrad Scale
if constexpr(kHasDropout)
{
tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
}
return ck_tile::make_tuple(dk_acc, dv_acc);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// In this pipeline, V resides in registers and K resides in LDS.
using BlockFmhaBwdDQDKDVPipelineKSVRDefaultPolicy =
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
/* QTLoadOnce_ = */ false,
/* KLoadOnce_ = */ true,
/* KTLoadOnce_ = */ false,
/* VLoadOnce_ = */ true,
/* OGradLoadOnce_ = */ false,
/* OGradTLoadOnce_ = */ false>;
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// In this pipeline, V resides in registers, while Q, K, and OGrad reside in LDS.
using BlockFmhaBwdDQDKDVPipelineQSKSVROGradSDefaultPolicy =
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ true,
/* QTLoadOnce_ = */ false,
/* KLoadOnce_ = */ true,
/* KTLoadOnce_ = */ false,
/* VLoadOnce_ = */ true,
/* OGradLoadOnce_ = */ true,
/* OGradTLoadOnce_ = */ false>;
} // namespace ck_tile
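// For illustration only (hypothetical alias, not part of the library): each
// *LoadOnce_ knob chooses whether that operand is loaded once and kept
// resident (true) or re-streamed from global memory on every iteration
// (false). A variant that also kept K^T resident might look like this,
// trading extra LDS/register pressure for fewer global loads:
//
// using BlockFmhaBwdDQDKDVPipelineExamplePolicy =
//     BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_      = */ false,
//                                       /* QTLoadOnce_     = */ false,
//                                       /* KLoadOnce_      = */ true,
//                                       /* KTLoadOnce_     = */ true,
//                                       /* VLoadOnce_      = */ true,
//                                       /* OGradLoadOnce_  = */ false,
//                                       /* OGradTLoadOnce_ = */ false>;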