Commit 74f1516c authored by danyao12

tmp save

parent 497ccb87
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// In this pipeline, v is kept in regs; k & k^t are kept in lds.
using BlockFmhaBwdDQDKDVPipelineKSKTSVRDefaultPolicy =
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
/* QTLoadOnce_ = */ false,
/* KLoadOnce_ = */ true,
/* KTLoadOnce_ = */ true,
/* VLoadOnce_ = */ true,
/* OGradLoadOnce_ = */ false,
/* OGradTLoadOnce_ = */ false>;
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_ks_vr_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
template <typename Problem, typename Policy = BlockFmhaBwdDQDKDVPipelineKSVRDefaultPolicy>
struct BlockFmhaBwdDQDKDVPipelineKSVR
{
using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
using GemmDataType = remove_cvref_t<typename Problem::GemmDataType>;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using DDataType = remove_cvref_t<typename Problem::DDataType>;
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
using KGradDataType = remove_cvref_t<typename Problem::KGradDataType>;
using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>;
using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK2 = BlockFmhaShape::kK2;
static constexpr index_t kK3 = BlockFmhaShape::kK3;
static constexpr index_t kK4 = BlockFmhaShape::kK4;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
static constexpr bool kQLoadOnce = false;
static constexpr bool kQTLoadOnce = false;
static constexpr bool kKLoadOnce = true;
static constexpr bool kKTLoadOnce = false;
static constexpr bool kVLoadOnce = true;
static constexpr bool kOGradLoadOnce = false;
static constexpr bool kOGradTLoadOnce = false;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
static constexpr bool kHasDropout = Problem::kHasDropout;
// last dimension vector length used to create the tensor view (and to decide the buffer_load
// vector length) together with the tensor distribution; the tensor distribution should be able
// to override this
static constexpr index_t kAlignmentQ =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV =
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentOGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
static constexpr index_t kAlignmentQGrad =
kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad<Problem>();
static constexpr index_t kAlignmentKGrad =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
static constexpr index_t kAlignmentVGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
static constexpr const char* name = "ks_vr";
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename QDramBlockWindowTmp,
typename QTDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename KTDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename OGradDramBlockWindowTmp,
typename OGradTDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename DDramBlockWindowTmp,
typename QGradDramBlockWindowTmp,
typename BiasGradDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
const QTDramBlockWindowTmp& qt_dram_block_window_tmp,
const KDramBlockWindowTmp& k_dram_block_window_tmp,
const KTDramBlockWindowTmp& /*kt_dram_block_window_tmp*/,
const VDramBlockWindowTmp& v_dram_block_window_tmp,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
const OGradTDramBlockWindowTmp& dot_dram_block_window_tmp,
const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
const DDramBlockWindowTmp& d_dram_block_window_tmp,
const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
FmhaMask mask,
PositionEncoding position_encoding,
float raw_scale,
#if CK_TILE_FMHA_FWD_FAST_EXP2
float scale,
#endif
float rp_undrop,
float scale_rp_undrop,
void* smem_ptr,
BlockDropout& dropout) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<QDataType,
remove_cvref_t<typename QTDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
std::is_same_v<OGradDataType,
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
std::is_same_v<OGradDataType,
remove_cvref_t<typename OGradTDramBlockWindowTmp::DataType>> &&
std::is_same_v<LSEDataType,
remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>> &&
std::is_same_v<QGradDataType,
remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kQKHeaddim == QTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kVHeaddim ==
OGradTDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
// Q tile in LDS
QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto q_lds = make_tensor_view<address_space_enum::lds>(
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_window =
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kK0>{}), {0, 0});
// QT tile in LDS
QDataType* qt_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto qt_lds = make_tensor_view<address_space_enum::lds>(
qt_lds_ptr, Policy::template MakeQTLdsBlockDescriptor<Problem>());
auto qt_lds_window =
make_tile_window(qt_lds, make_tuple(number<kQKHeaddim>{}, number<kK3>{}), {0, 0});
// K tile in LDS
auto k_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<KDataType*>(smem_ptr),
Policy::template MakeKLdsBlockDescriptor<Problem>());
auto k_lds_window =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
// KT tile in LDS
auto kt_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<KDataType*>(smem_ptr),
Policy::template MakeKLdsBlockDescriptorAsKT<Problem>());
auto kt_lds_window =
make_tile_window(kt_lds, make_tuple(number<kQKHeaddim>{}, number<kN0>{}), {0, 0});
// OGrad tile in LDS
OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto do_lds = make_tensor_view<address_space_enum::lds>(
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
auto do_lds_window =
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kK2>{}), {0, 0});
// OGradT tile in LDS
OGradDataType* dot_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto dot_lds = make_tensor_view<address_space_enum::lds>(
dot_lds_ptr, Policy::template MakeOGradTLdsBlockDescriptor<Problem>());
auto dot_lds_window =
make_tile_window(dot_lds, make_tuple(number<kVHeaddim>{}, number<kK1>{}), {0, 0});
// SGrad tile in LDS
GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto ds_lds = make_tensor_view<address_space_enum::lds>(
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
auto ds_lds_window =
make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
// BiasT/BiasGradT tile in LDS, use the same size and layout
BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto biast_lds = make_tensor_view<address_space_enum::lds>(
biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>());
auto biast_lds_shuffle_window =
make_tile_window(biast_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
auto dbiast_lds_shuffle_window =
make_tile_window(biast_lds,
make_tuple(number<kM0>{}, number<kN0>{}),
{0, 0},
Policy::template MakeShuffledBiasTileDistribution<Problem>());
static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
"BiasDataType and BiasGradDataType should be the same!");
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();
constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();
auto v_dram_window = make_tile_window(
v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
v_dram_block_window_tmp.get_window_origin(),
Policy::template MakeVInRegDramTileDistribution<Problem, decltype(gemm_2)>());
auto v = load_tile(v_dram_window); // persistent V register tile
using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile());
using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile());
using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
// init VGrad & KGrad
auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};
clear_tile(dv_acc);
clear_tile(dk_acc);
auto k_dram_window = make_tile_window(
k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
k_dram_block_window_tmp.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load
__builtin_amdgcn_sched_barrier(0);
const auto k_origin = k_dram_window.get_window_origin();
const auto [seqlen_q_start, seqlen_q_end] =
mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);
// check early exit if masked and no work to do.
if constexpr(FmhaMask::IsMasking)
{
if(num_total_loop <= 0)
{
// Note: dk_acc & dv_acc have already been cleared here, so just return them.
// Note: v has been loaded but without a fence; ignore it.
return ck_tile::make_tuple(dk_acc, dv_acc);
}
}
auto k_block_tile = load_tile(k_dram_window);
store_tile(k_lds_window, k_block_tile); // persistent K in LDS
auto q_dram_block_window =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
auto qt_dram_block_window =
make_tile_window(qt_dram_block_window_tmp.get_bottom_tensor_view(),
qt_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_q_start});
auto do_dram_block_window =
make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
do_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
auto dot_dram_block_window =
make_tile_window(dot_dram_block_window_tmp.get_bottom_tensor_view(),
dot_dram_block_window_tmp.get_window_lengths(),
{0, seqlen_q_start});
auto dq_dram_block_window =
make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
dq_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
auto lse_dram_block_window =
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
lse_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start});
auto d_dram_block_window =
make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
d_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start});
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_block_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, bias_origin.at(number<1>{})}); // M/N
const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
auto dbias_dram_block_window =
make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
dbias_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
auto qt_dram_window =
make_tile_window(qt_dram_block_window.get_bottom_tensor_view(),
qt_dram_block_window.get_window_lengths(),
qt_dram_block_window.get_window_origin(),
Policy::template MakeQTDramTileDistribution<Problem>());
auto dot_dram_window =
make_tile_window(dot_dram_block_window.get_bottom_tensor_view(),
dot_dram_block_window.get_window_lengths(),
dot_dram_block_window.get_window_origin(),
Policy::template MakeOGradTDramTileDistribution<Problem>());
auto lse_dram_window = make_tile_window(
lse_dram_block_window.get_bottom_tensor_view(),
lse_dram_block_window.get_window_lengths(),
lse_dram_block_window.get_window_origin(),
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto d_dram_window = make_tile_window(
d_dram_block_window.get_bottom_tensor_view(),
d_dram_block_window.get_window_lengths(),
d_dram_block_window.get_window_origin(),
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto bias_dram_window =
make_tile_window(bias_dram_block_window.get_bottom_tensor_view(),
bias_dram_block_window.get_window_lengths(),
bias_dram_block_window.get_window_origin(),
Policy::template MakeBiasTileDistribution<Problem>());
auto biast_lds_window =
make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(),
biast_lds_shuffle_window.get_window_lengths(),
biast_lds_shuffle_window.get_window_origin(),
Policy::template MakeBiasTTileDistribution<decltype(gemm_0)>());
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0), false>(
randval_dram_block_window_tmp, seqlen_q_start);
index_t i_total_loops = 0;
constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kM0 / kK1;
constexpr index_t k2_loops = kVHeaddim / kK2;
constexpr index_t k3_loops = kM0 / kK3;
constexpr index_t k4_loops = kN0 / kK4;
do
{
auto q_dram_window = make_tile_window(
q_dram_block_window.get_bottom_tensor_view(),
q_dram_block_window.get_window_lengths(),
q_dram_block_window.get_window_origin(),
Policy::template MakeQDramTileDistribution<Problem>()); // Q DRAM tile window for
// load
auto do_dram_window = make_tile_window(
do_dram_block_window.get_bottom_tensor_view(),
do_dram_block_window.get_window_lengths(),
do_dram_block_window.get_window_origin(),
Policy::template MakeOGradDramTileDistribution<Problem>()); // OGrad DRAM tile
// window for load
// STAGE 1, Q@K Gemm0
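// Recompute the attention scores for this block: st_acc accumulates the Q@K^T tile
// (kM0 x kN0), consuming the head dimension kK0 columns at a time. Q is streamed from DRAM
// through LDS while K stays resident in LDS for the whole outer loop.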
auto st_acc = SPTBlockTileType{};
auto q_block_tile = load_tile(q_dram_window);
{
move_tile_window(q_dram_window, {0, kK0});
clear_tile(st_acc); // Initialize S^T
store_tile(q_lds_window, q_block_tile); // LDS write 0
q_block_tile = load_tile(q_dram_window); // global read 1
}
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
__builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads
}
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
__builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads
}
if constexpr(k0_loops > 2)
{
static_for<0, k0_loops - 2, 1>{}([&](auto i_k0) {
block_sync_lds();
gemm_0(st_acc,
q_lds_window,
get_slice_tile(k_lds_window,
sequence<0, i_k0 * kK0>{},
sequence<kN0, (i_k0 + 1) * kK0>{}));
block_sync_lds();
move_tile_window(q_dram_window, {0, kK0});
store_tile(q_lds_window,
q_block_tile); // LDS write i + 1
q_block_tile = load_tile(q_dram_window); // global read i + 2
});
}
const auto dot_prefetch = load_tile(dot_dram_window); // prefetch load OGrad^T tile
{ // tail
block_sync_lds();
gemm_0(st_acc,
q_lds_window,
get_slice_tile(k_lds_window,
sequence<0, (k0_loops - 2) * kK0>{},
sequence<kN0, (k0_loops - 1) * kK0>{}));
block_sync_lds();
store_tile(q_lds_window, q_block_tile);
block_sync_lds();
gemm_0(st_acc,
q_lds_window,
get_slice_tile(k_lds_window,
sequence<0, (k0_loops - 1) * kK0>{},
sequence<kN0, k0_loops * kK0>{}));
}
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
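// Rebuild P from the LSE saved by the forward pass instead of storing P: pt = exp(scaled s - lse)
// (exp2 with log2e folded in when CK_TILE_FMHA_FWD_FAST_EXP2 is set). Masked positions were
// forced to -inf above and therefore become 0; get_validated_lse maps a -inf LSE (fully
// masked row) to 0 to avoid NaNs. Dropout is then re-applied to pt via dropout.Run.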
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
block_sync_lds();
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
Policy::template MakeShuffledBiasTileDistribution<Problem>());
shuffle_tile(bias_shuffle_tmp, bias_tile);
store_tile(biast_lds_shuffle_window, bias_shuffle_tmp);
block_sync_lds();
auto biast_tile = load_tile(biast_lds_window);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x = raw_scale * x + type_convert<AccDataType>(y);
#else
x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
#endif
},
st_acc,
biast_tile);
move_tile_window(bias_dram_window, {kM0, 0});
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
const auto q_origin = q_dram_block_window.get_window_origin();
constexpr auto st_spans = decltype(st_acc)::get_distributed_spans();
sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
st_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
st_acc(i_j_idx) *= raw_scale;
#else
st_acc(i_j_idx) *= scale;
#endif
position_encoding.update(st_acc(i_j_idx), row, col);
});
});
}
else
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc);
#endif
}
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
const auto q_origin = q_dram_block_window.get_window_origin();
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
k_origin.at(number<0>{}),
number<kM0>{},
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(st_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
}
}
const auto lse = load_tile(lse_dram_window);
static const auto get_validated_lse = [](LSEDataType raw_lse) {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return raw_lse == -numeric<LSEDataType>::infinity()
? type_convert<LSEDataType>(0.f)
: raw_lse;
}
else
{
return raw_lse;
}
};
auto pt = SPTBlockTileType{};
constexpr auto pt_spans = decltype(pt)::get_distributed_spans();
sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
#endif
sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse);
}
else
{
pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse);
}
#else
pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx]));
#endif
});
});
auto dot_shuffle_tmp = make_static_distributed_tensor<OGradDataType>(
Policy::template MakeShuffledOGradTRegBlockDescriptor<Problem>());
block_sync_lds();
{
shuffle_tile(dot_shuffle_tmp, dot_prefetch);
store_tile(dot_lds_window,
dot_shuffle_tmp); // store the prefetch
}
move_tile_window(dot_dram_window, {0, kK1});
if constexpr(kHasDropout)
{
dropout.Run<decltype(gemm_0), RandValOutputDataType>(
seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window);
}
// STAGE 3, P^T@OGrad^T Gemm1
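// dV accumulation: dv_acc += P^T @ dO, contracting over the kM0 dimension in kK1 slices.
// With dropout, dropped entries of pt are encoded as negative values and are zeroed in
// pt_gemm before feeding the GEMM.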
const auto pt_gemm = [&]() {
if constexpr(kHasDropout)
{
return tile_elementwise_in(
[](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
pt);
}
else
{
return cast_tile<GemmDataType>(pt);
}
}();
if constexpr(k1_loops > 1)
{
static_for<0, k1_loops - 1, 1>{}([&](auto i_k1) {
const auto dot = load_tile(dot_dram_window); // load next OGrad^T
block_sync_lds();
gemm_1(dv_acc,
get_slice_tile(pt_gemm,
sequence<i_k1 * kK1, 0>{},
sequence<(i_k1 + 1) * kK1, kN0>{}),
dot_lds_window);
block_sync_lds();
shuffle_tile(dot_shuffle_tmp, dot);
store_tile(dot_lds_window,
dot_shuffle_tmp); // store the prefetch
move_tile_window(dot_dram_window, {0, kK1});
});
}
auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile
// tail
{
block_sync_lds();
gemm_1(dv_acc,
get_slice_tile(
pt_gemm, sequence<(k1_loops - 1) * kK1, 0>{}, sequence<kM0, kN0>{}),
dot_lds_window);
block_sync_lds();
}
// STAGE 4, OGrad@V Gemm2
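// dP accumulation: dpt_acc = dO @ V^T for this block, contracting the V head dimension in
// kK2 slices; dO is streamed through LDS while V is read from the persistent register tile.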
auto dpt_acc = SPGradTBlockTileType{};
{
move_tile_window(do_dram_window, {0, kK2});
clear_tile(dpt_acc); // Initialize PGrad^T
store_tile(do_lds_window, do_block_tile); // LDS write 0
do_block_tile = load_tile(do_dram_window); // global read 1
}
if constexpr(k2_loops > 2)
{
static_for<0, k2_loops - 2, 1>{}([&](auto i_k2) {
block_sync_lds();
gemm_2(dpt_acc,
do_lds_window,
get_slice_tile(
v, sequence<0, i_k2 * kK2>{}, sequence<kN0, (i_k2 + 1) * kK2>{}));
block_sync_lds();
move_tile_window(do_dram_window, {0, kK2});
store_tile(do_lds_window,
do_block_tile); // LDS write i + 1
do_block_tile = load_tile(do_dram_window); // global read i + 2
});
}
const auto qt_prefetch = load_tile(qt_dram_window); // prefetch load Q^T tile
{ // tail
block_sync_lds();
gemm_2(dpt_acc,
do_lds_window,
get_slice_tile(v,
sequence<0, (k2_loops - 2) * kK2>{},
sequence<kN0, (k2_loops - 1) * kK2>{}));
block_sync_lds();
store_tile(do_lds_window, do_block_tile);
block_sync_lds();
gemm_2(dpt_acc,
do_lds_window,
get_slice_tile(v,
sequence<0, (k2_loops - 1) * kK2>{},
sequence<kN0, k2_loops * kK2>{}));
}
// STAGE 5, P^T(PGrad^T - D)
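// Softmax backward: ds = p * (dp - D), where the d tile holds the per-row dot(dO, O) term
// produced before this pipeline runs. Dropped entries (negative p under dropout) take the
// undrop_flag branch instead.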
const auto d = load_tile(d_dram_window);
auto dst = SPGradTBlockTileType{};
constexpr auto dst_spans = decltype(dst)::get_distributed_spans();
sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
bool undrop_flag = pt[i_j_idx] >= 0;
dst(i_j_idx) =
pt[i_j_idx] *
(!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]);
});
});
if constexpr(kHasBiasGrad)
{
const auto dbiast = [&]() {
if constexpr(kHasDropout)
{
return tile_elementwise_in(
[&rp_undrop](const auto& x) {
return type_convert<BiasGradDataType>(x * rp_undrop);
},
dst);
}
else
{
return cast_tile<BiasGradDataType>(dst);
}
}();
store_tile(biast_lds_shuffle_window, dbiast);
block_sync_lds();
auto dbiast_tile = load_tile(dbiast_lds_shuffle_window);
auto dbiast_shuffle_tmp = make_static_distributed_tensor<BiasGradDataType>(
Policy::template MakeBiasTileDistribution<Problem>());
shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
store_tile(dbias_dram_block_window, dbiast_shuffle_tmp);
move_tile_window(dbias_dram_block_window, {kM0, 0});
}
// STAGE 6, SGrad^T@Q^T Gemm3
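// dK accumulation: dk_acc += dS^T @ Q, contracting over kM0 in kK3 slices; Q^T is shuffled
// into LDS from the separate Q^T DRAM window.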
auto qt_shuffle_tmp = make_static_distributed_tensor<QDataType>(
Policy::template MakeShuffledQTRegBlockDescriptor<Problem>());
block_sync_lds();
{
shuffle_tile(qt_shuffle_tmp, qt_prefetch);
store_tile(qt_lds_window,
qt_shuffle_tmp); // store the prefetch
}
move_tile_window(qt_dram_window, {0, kK3});
const auto dst_gemm = cast_tile<GemmDataType>(dst);
if constexpr(k3_loops > 1)
{
static_for<0, k3_loops - 1, 1>{}([&](auto i_k3) {
const auto qt = load_tile(qt_dram_window); // load next Q^T
block_sync_lds();
gemm_3(dk_acc,
get_slice_tile(dst_gemm,
sequence<i_k3 * kK3, 0>{},
sequence<(i_k3 + 1) * kK3, kN0>{}),
qt_lds_window);
block_sync_lds();
shuffle_tile(qt_shuffle_tmp, qt);
store_tile(qt_lds_window,
qt_shuffle_tmp); // store the prefetch
move_tile_window(qt_dram_window, {0, kK3});
});
}
// tail
{
block_sync_lds();
gemm_3(dk_acc,
get_slice_tile(
dst_gemm, sequence<(k3_loops - 1) * kK3, 0>{}, sequence<kM0, kN0>{}),
qt_lds_window);
block_sync_lds();
}
// STAGE 7, SGrad@K^T Gemm4
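// dQ for this block: dq_acc = dS @ K, contracting over kN0 in kK4 slices with K^T read
// through the transposed LDS view; the scaled result is written out through the QGrad
// window with update_tile (different kN0 blocks are expected to accumulate into the same
// dQ rows).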
store_tile(ds_lds_window, dst_gemm);
auto dq_acc = QGradBlockTileType{};
clear_tile(dq_acc); // Initialize QGrad
block_sync_lds();
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
gemm_4(dq_acc,
get_slice_tile(ds_lds_window,
sequence<0, i_k4 * kK4>{},
sequence<kM0, (i_k4 + 1) * kK4>{}),
get_slice_tile(kt_lds_window,
sequence<0, i_k4 * kK4>{},
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{}));
});
// QGrad Scale
if constexpr(kHasDropout)
{
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
dq_acc);
}
else
{
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
}
const auto dq = cast_tile<QGradDataType>(dq_acc);
update_tile(dq_dram_block_window, dq);
// move tile windows
move_tile_window(q_dram_block_window, {kM0, 0});
move_tile_window(dq_dram_block_window, {kM0, 0});
move_tile_window(do_dram_block_window, {kM0, 0});
move_tile_window(lse_dram_window, {kM0});
move_tile_window(d_dram_window, {kM0});
} while(++i_total_loops < num_total_loop);
// KGrad Scale
if constexpr(kHasDropout)
{
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
dk_acc);
}
else
{
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
}
// VGrad Scale
if constexpr(kHasDropout)
{
tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
}
return ck_tile::make_tuple(dk_acc, dv_acc);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// In this pipeline, v is kept in regs; k is kept in lds.
using BlockFmhaBwdDQDKDVPipelineKSVRDefaultPolicy =
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ false,
/* QTLoadOnce_ = */ false,
/* KLoadOnce_ = */ true,
/* KTLoadOnce_ = */ false,
/* VLoadOnce_ = */ true,
/* OGradLoadOnce_ = */ false,
/* OGradTLoadOnce_ = */ false>;
} // namespace ck_tile
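// A minimal sketch of how this policy alias is typically consumed (the Problem type and the
// window/argument names below are illustrative placeholders, not part of this change): the
// BlockFmhaBwdDQDKDVPipelineKSVR struct above takes it as its default Policy parameter, so a
// kernel normally only supplies the Problem description, the DRAM block windows, and the
// shared-memory pointer.
//
//   using BwdPipeline = BlockFmhaBwdDQDKDVPipelineKSVR<MyFmhaBwdProblem>;
//   __shared__ char smem[BwdPipeline::GetSmemSize()];
//   const auto dk_dv = BwdPipeline{}(q_window, qt_window, k_window, kt_window, v_window,
//                                    bias_window, randval_window, do_window, dot_window,
//                                    lse_window, d_window, dq_window, dbias_window,
//                                    mask, position_encoding, raw_scale,
//                                    /* scale, (only when CK_TILE_FMHA_FWD_FAST_EXP2) */
//                                    rp_undrop, scale_rp_undrop, smem, dropout);
//   // dk_dv holds the (dk_acc, dv_acc) accumulators returned by the pipeline.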
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/fmha/block/block_attention_bias_enum.hpp"
#include "ck_tile/ops/fmha/block/block_dropout.hpp"
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_dq_dk_dv_pipeline_qs_ks_vr_dos_default_policy.hpp"
#include "ck_tile/ops/reduce/block/block_reduce.hpp"
namespace ck_tile {
template <typename Problem, typename Policy = BlockFmhaBwdDQDKDVPipelineQSKSVROGradSDefaultPolicy>
struct BlockFmhaBwdDQDKDVPipelineQSKSVROGradS
{
using QDataType = remove_cvref_t<typename Problem::QDataType>;
using KDataType = remove_cvref_t<typename Problem::KDataType>;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
using GemmDataType = remove_cvref_t<typename Problem::GemmDataType>;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
using LSEDataType = remove_cvref_t<typename Problem::LSEDataType>;
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using DDataType = remove_cvref_t<typename Problem::DDataType>;
using RandValOutputDataType = remove_cvref_t<typename Problem::RandValOutputDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
using QGradDataType = remove_cvref_t<typename Problem::QGradDataType>;
using KGradDataType = remove_cvref_t<typename Problem::KGradDataType>;
using VGradDataType = remove_cvref_t<typename Problem::VGradDataType>;
using BiasGradDataType = remove_cvref_t<typename Problem::BiasGradDataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
static constexpr index_t kBlockPerCu = Problem::kBlockPerCu;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = BlockFmhaShape::kM0;
static constexpr index_t kN0 = BlockFmhaShape::kN0;
static constexpr index_t kK0 = BlockFmhaShape::kK0;
static constexpr index_t kK1 = BlockFmhaShape::kK1;
static constexpr index_t kK2 = BlockFmhaShape::kK2;
static constexpr index_t kK3 = BlockFmhaShape::kK3;
static constexpr index_t kK4 = BlockFmhaShape::kK4;
static constexpr index_t kQKHeaddim = BlockFmhaShape::kQKHeaddim;
static constexpr index_t kVHeaddim = BlockFmhaShape::kVHeaddim;
static constexpr bool kQLoadOnce = true;
static constexpr bool kQTLoadOnce = false;
static constexpr bool kKLoadOnce = true;
static constexpr bool kKTLoadOnce = false;
static constexpr bool kVLoadOnce = true;
static constexpr bool kOGradLoadOnce = true;
static constexpr bool kOGradTLoadOnce = false;
static constexpr bool kIsGroupMode = Problem::kIsGroupMode;
static constexpr bool kPadSeqLenQ = Problem::kPadSeqLenQ;
static constexpr bool kPadSeqLenK = Problem::kPadSeqLenK;
static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kHasBiasGrad = Problem::kHasBiasGrad;
static constexpr bool kHasDropout = Problem::kHasDropout;
// last dimension vector length used to create the tensor view (and to decide the buffer_load
// vector length) together with the tensor distribution; the tensor distribution should be able
// to override this
static constexpr index_t kAlignmentQ =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentQ<Problem>();
static constexpr index_t kAlignmentK =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentK<Problem>();
static constexpr index_t kAlignmentV =
kPadHeadDimV ? 1 : Policy::template GetAlignmentV<Problem>();
static constexpr index_t kAlignmentO =
kPadHeadDimV ? 1 : Policy::template GetAlignmentO<Problem>();
static constexpr index_t kAlignmentOGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentOGrad<Problem>();
static constexpr index_t kAlignmentQGrad =
kPadHeadDimQ ? 2 : Policy::template GetAlignmentQGrad<Problem>();
static constexpr index_t kAlignmentKGrad =
kPadHeadDimQ ? 1 : Policy::template GetAlignmentKGrad<Problem>();
static constexpr index_t kAlignmentVGrad =
kPadHeadDimV ? 1 : Policy::template GetAlignmentVGrad<Problem>();
static constexpr index_t kAlignmentBias =
kPadSeqLenK ? 1 : Policy::template GetTransposedAlignmentBias<Problem>();
static constexpr const char* name = "qs_ks_vr_dos";
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
}
template <typename QDramBlockWindowTmp,
typename QTDramBlockWindowTmp,
typename KDramBlockWindowTmp,
typename KTDramBlockWindowTmp,
typename VDramBlockWindowTmp,
typename BiasDramBlockWindowTmp,
typename RandValDramBlockWindowTmp,
typename OGradDramBlockWindowTmp,
typename OGradTDramBlockWindowTmp,
typename LSEDramBlockWindowTmp,
typename DDramBlockWindowTmp,
typename QGradDramBlockWindowTmp,
typename BiasGradDramBlockWindowTmp,
typename PositionEncoding>
CK_TILE_HOST_DEVICE auto
operator()(const QDramBlockWindowTmp& q_dram_block_window_tmp,
const QTDramBlockWindowTmp& /*qt_dram_block_window_tmp*/,
const KDramBlockWindowTmp& k_dram_block_window_tmp,
const KTDramBlockWindowTmp& /*kt_dram_block_window_tmp*/,
const VDramBlockWindowTmp& v_dram_block_window_tmp,
const BiasDramBlockWindowTmp& bias_dram_block_window_tmp,
const RandValDramBlockWindowTmp& randval_dram_block_window_tmp,
const OGradDramBlockWindowTmp& do_dram_block_window_tmp,
const OGradTDramBlockWindowTmp& /*dot_dram_block_window_tmp*/,
const LSEDramBlockWindowTmp& lse_dram_block_window_tmp,
const DDramBlockWindowTmp& d_dram_block_window_tmp,
const QGradDramBlockWindowTmp& dq_dram_block_window_tmp,
const BiasGradDramBlockWindowTmp& dbias_dram_block_window_tmp,
FmhaMask mask,
PositionEncoding position_encoding,
float raw_scale,
#if CK_TILE_FMHA_FWD_FAST_EXP2
float scale,
#endif
float rp_undrop,
float scale_rp_undrop,
void* smem_ptr,
BlockDropout& dropout) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
std::is_same_v<KDataType, remove_cvref_t<typename KDramBlockWindowTmp::DataType>> &&
std::is_same_v<VDataType, remove_cvref_t<typename VDramBlockWindowTmp::DataType>> &&
std::is_same_v<OGradDataType,
remove_cvref_t<typename OGradDramBlockWindowTmp::DataType>> &&
std::is_same_v<LSEDataType,
remove_cvref_t<typename LSEDramBlockWindowTmp::DataType>> &&
std::is_same_v<DDataType, remove_cvref_t<typename DDramBlockWindowTmp::DataType>> &&
std::is_same_v<QGradDataType,
remove_cvref_t<typename QGradDramBlockWindowTmp::DataType>>,
"wrong!");
static_assert(kM0 == QDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == KDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == VDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasDramBlockWindowTmp{}.get_window_lengths()[number<1>{}] &&
kM0 == OGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == LSEDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == DDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == QGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kM0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] &&
kN0 == BiasGradDramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!");
// Q tile in LDS
QDataType* q_lds_ptr = static_cast<QDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>()));
auto q_lds = make_tensor_view<address_space_enum::lds>(
q_lds_ptr, Policy::template MakeQLdsBlockDescriptor<Problem>());
auto q_lds_window =
make_tile_window(q_lds, make_tuple(number<kM0>{}, number<kQKHeaddim>{}), {0, 0});
// QT tile in LDS
auto qt_lds = make_tensor_view<address_space_enum::lds>(
q_lds_ptr, Policy::template MakeQLdsBlockDescriptorAsQT<Problem>());
auto qt_lds_window =
make_tile_window(qt_lds, make_tuple(number<kQKHeaddim>{}, number<kM0>{}), {0, 0});
// K tile in LDS
auto k_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<KDataType*>(smem_ptr),
Policy::template MakeKLdsBlockDescriptor<Problem>());
auto k_lds_window =
make_tile_window(k_lds, make_tuple(number<kN0>{}, number<kQKHeaddim>{}), {0, 0});
// KT tile in LDS
auto kt_lds = make_tensor_view<address_space_enum::lds>(
reinterpret_cast<KDataType*>(smem_ptr),
Policy::template MakeKLdsBlockDescriptorAsKT<Problem>());
auto kt_lds_window =
make_tile_window(kt_lds, make_tuple(number<kQKHeaddim>{}, number<kN0>{}), {0, 0});
// OGrad tile in LDS
OGradDataType* do_lds_ptr = static_cast<OGradDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
Policy::template GetSmemSizeQ<Problem>()));
auto do_lds = make_tensor_view<address_space_enum::lds>(
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptor<Problem>());
auto do_lds_window =
make_tile_window(do_lds, make_tuple(number<kM0>{}, number<kVHeaddim>{}), {0, 0});
// OGradT tile in LDS
auto dot_lds = make_tensor_view<address_space_enum::lds>(
do_lds_ptr, Policy::template MakeOGradLdsBlockDescriptorAsOGradT<Problem>());
auto dot_lds_window =
make_tile_window(dot_lds, make_tuple(number<kVHeaddim>{}, number<kM0>{}), {0, 0});
// SGrad tile in LDS
GemmDataType* ds_lds_ptr = static_cast<GemmDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
Policy::template GetSmemSizeQ<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>()));
auto ds_lds = make_tensor_view<address_space_enum::lds>(
ds_lds_ptr, Policy::template MakeSGradLdsBlockDescriptor<Problem>());
auto ds_lds_window =
make_tile_window(ds_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
// BiasT/BiasGradT tile in LDS, use the same size and layout
BiasDataType* biast_lds_ptr = static_cast<BiasDataType*>(static_cast<void*>(
static_cast<char*>(smem_ptr) + Policy::template GetSmemSizeK<Problem>() +
Policy::template GetSmemSizeQ<Problem>() +
Policy::template GetSmemSizeOGrad<Problem>()));
auto biast_lds = make_tensor_view<address_space_enum::lds>(
biast_lds_ptr, Policy::template MakeBiasTLdsBlockDescriptor<Problem>());
auto biast_lds_shuffle_window =
make_tile_window(biast_lds, make_tuple(number<kM0>{}, number<kN0>{}), {0, 0});
auto dbiast_lds_shuffle_window =
make_tile_window(biast_lds,
make_tuple(number<kM0>{}, number<kN0>{}),
{0, 0},
Policy::template MakeShuffledBiasTileDistribution<Problem>());
static_assert(std::is_same_v<BiasDataType, BiasGradDataType>,
"BiasDataType and BiasGradDataType should be the same!");
// Block GEMM
constexpr auto gemm_0 = Policy::template GetQKBlockGemm<Problem>();
constexpr auto gemm_1 = Policy::template GetPTOGradTBlockGemm<Problem>();
constexpr auto gemm_2 = Policy::template GetOGradVBlockGemm<Problem>();
constexpr auto gemm_3 = Policy::template GetSGradTQTBlockGemm<Problem>();
constexpr auto gemm_4 = Policy::template GetSGradKTBlockGemm<Problem>();
auto v_dram_window = make_tile_window(
v_dram_block_window_tmp.get_bottom_tensor_view(),
v_dram_block_window_tmp.get_window_lengths(),
v_dram_block_window_tmp.get_window_origin(),
Policy::template MakeVInRegDramTileDistribution<Problem, decltype(gemm_2)>());
auto v = load_tile(v_dram_window); // persistent V register tile
using SPTBlockTileType = decltype(gemm_0.MakeCBlockTile());
using SPGradTBlockTileType = decltype(gemm_2.MakeCBlockTile());
using QGradBlockTileType = decltype(gemm_4.MakeCBlockTile());
// init VGrad & KGrad
auto dv_acc = decltype(gemm_1.MakeCBlockTile()){};
auto dk_acc = decltype(gemm_3.MakeCBlockTile()){};
clear_tile(dv_acc);
clear_tile(dk_acc);
auto k_dram_window = make_tile_window(
k_dram_block_window_tmp.get_bottom_tensor_view(),
k_dram_block_window_tmp.get_window_lengths(),
k_dram_block_window_tmp.get_window_origin(),
Policy::template MakeKDramTileDistribution<Problem>()); // K DRAM tile window for
// load
__builtin_amdgcn_sched_barrier(0);
const auto k_origin = k_dram_window.get_window_origin();
const auto [seqlen_q_start, seqlen_q_end] =
mask.GetTileRangeAlongY(k_origin.at(number<0>{}), number<kM0>{}, number<kN0>{});
const auto num_total_loop = integer_divide_ceil(seqlen_q_end - seqlen_q_start, kM0);
// check early exit if masked and no work to do.
if constexpr(FmhaMask::IsMasking)
{
if(num_total_loop <= 0)
{
// Note: dk_acc & dv_acc have already been cleared here, so just return them.
// Note: v has been loaded but without a fence; ignore it.
return ck_tile::make_tuple(dk_acc, dv_acc);
}
}
auto k_block_tile = load_tile(k_dram_window);
store_tile(k_lds_window, k_block_tile); // persistent K in LDS
auto q_dram_block_window =
make_tile_window(q_dram_block_window_tmp.get_bottom_tensor_view(),
q_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
auto do_dram_block_window =
make_tile_window(do_dram_block_window_tmp.get_bottom_tensor_view(),
do_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
auto dq_dram_block_window =
make_tile_window(dq_dram_block_window_tmp.get_bottom_tensor_view(),
dq_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, 0});
auto lse_dram_block_window =
make_tile_window(lse_dram_block_window_tmp.get_bottom_tensor_view(),
lse_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start});
auto d_dram_block_window =
make_tile_window(d_dram_block_window_tmp.get_bottom_tensor_view(),
d_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start});
const auto bias_origin = bias_dram_block_window_tmp.get_window_origin();
auto bias_dram_block_window =
make_tile_window(bias_dram_block_window_tmp.get_bottom_tensor_view(),
bias_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, bias_origin.at(number<1>{})}); // M/N
const auto dbias_origin = dbias_dram_block_window_tmp.get_window_origin();
auto dbias_dram_block_window =
make_tile_window(dbias_dram_block_window_tmp.get_bottom_tensor_view(),
dbias_dram_block_window_tmp.get_window_lengths(),
{seqlen_q_start, dbias_origin.at(number<1>{})}); // M/N
auto lse_dram_window = make_tile_window(
lse_dram_block_window.get_bottom_tensor_view(),
lse_dram_block_window.get_window_lengths(),
lse_dram_block_window.get_window_origin(),
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto d_dram_window = make_tile_window(
d_dram_block_window.get_bottom_tensor_view(),
d_dram_block_window.get_window_lengths(),
d_dram_block_window.get_window_origin(),
Policy::template MakeLSEDDramTileDistribution<Problem, decltype(gemm_0)>());
auto bias_dram_window =
make_tile_window(bias_dram_block_window.get_bottom_tensor_view(),
bias_dram_block_window.get_window_lengths(),
bias_dram_block_window.get_window_origin(),
Policy::template MakeBiasTileDistribution<Problem>());
auto biast_lds_window =
make_tile_window(biast_lds_shuffle_window.get_bottom_tensor_view(),
biast_lds_shuffle_window.get_window_lengths(),
biast_lds_shuffle_window.get_window_origin(),
Policy::template MakeBiasTTileDistribution<decltype(gemm_0)>());
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0), false>(
randval_dram_block_window_tmp, seqlen_q_start);
index_t i_total_loops = 0;
constexpr index_t k0_loops = kQKHeaddim / kK0;
constexpr index_t k1_loops = kM0 / kK1;
constexpr index_t k2_loops = kVHeaddim / kK2;
constexpr index_t k3_loops = kM0 / kK3;
constexpr index_t k4_loops = kN0 / kK4;
do
{
auto q_dram_window = make_tile_window(
q_dram_block_window.get_bottom_tensor_view(),
q_dram_block_window.get_window_lengths(),
q_dram_block_window.get_window_origin(),
Policy::template MakeQDramTileDistribution<Problem>()); // Q DRAM tile window for
// load
auto do_dram_window = make_tile_window(
do_dram_block_window.get_bottom_tensor_view(),
do_dram_block_window.get_window_lengths(),
do_dram_block_window.get_window_origin(),
Policy::template MakeOGradDramTileDistribution<Problem>()); // OGrad DRAM tile
// window for load
// STAGE 1, Q@K Gemm0
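// Same score recompute as the ks_vr pipeline above, but here the whole kM0 x kQKHeaddim Q
// tile is stored to LDS once (kQLoadOnce), so gemm_0 slices q_lds_window along the head
// dimension instead of streaming Q in kK0 pieces.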
auto st_acc = SPTBlockTileType{};
auto q_block_tile = load_tile(q_dram_window);
clear_tile(st_acc); // Initialize S^T
store_tile(q_lds_window, q_block_tile); // LDS write
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
__builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads
}
const auto bias_tile = load_tile(bias_dram_window); // load bias tile
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
__builtin_amdgcn_sched_barrier(
0); // prevent from messing up the order of global loads
}
if constexpr(k0_loops > 1)
{
static_for<0, k0_loops - 1, 1>{}([&](auto i_k0) {
block_sync_lds();
gemm_0(st_acc,
get_slice_tile(q_lds_window,
sequence<0, i_k0 * kK0>{},
sequence<kM0, (i_k0 + 1) * kK0>{}),
get_slice_tile(k_lds_window,
sequence<0, i_k0 * kK0>{},
sequence<kN0, (i_k0 + 1) * kK0>{}));
block_sync_lds();
});
}
auto do_block_tile = load_tile(do_dram_window); // prefetch load OGrad tile
{ // tail
block_sync_lds();
gemm_0(st_acc,
get_slice_tile(q_lds_window,
sequence<0, (k0_loops - 1) * kK0>{},
sequence<kM0, k0_loops * kK0>{}),
get_slice_tile(k_lds_window,
sequence<0, (k0_loops - 1) * kK0>{},
sequence<kN0, k0_loops * kK0>{}));
block_sync_lds();
}
// STAGE 2, Scale, Add bias, Mask, Softmax, Dropout
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
{
block_sync_lds();
auto bias_shuffle_tmp = make_static_distributed_tensor<BiasDataType>(
Policy::template MakeShuffledBiasTileDistribution<Problem>());
shuffle_tile(bias_shuffle_tmp, bias_tile);
store_tile(biast_lds_shuffle_window, bias_shuffle_tmp);
block_sync_lds();
auto biast_tile = load_tile(biast_lds_window);
tile_elementwise_inout(
[&](auto& x, const auto& y) {
#if !CK_TILE_FMHA_FWD_FAST_EXP2
x = raw_scale * x + type_convert<AccDataType>(y);
#else
x = scale * x + log2e_v<AccDataType> * type_convert<AccDataType>(y);
#endif
},
st_acc,
biast_tile);
move_tile_window(bias_dram_window, {kM0, 0});
}
else if constexpr(BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
const auto q_origin = q_dram_block_window.get_window_origin();
constexpr auto st_spans = decltype(st_acc)::get_distributed_spans();
sweep_tile_span(st_spans[number<0>{}], [&](auto idx0) {
sweep_tile_span(st_spans[number<1>{}], [&](auto idx1) {
const auto tile_idx = get_x_indices_from_distributed_indices(
st_acc.get_tile_distribution(), make_tuple(idx0, idx1));
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if !CK_TILE_FMHA_FWD_FAST_EXP2
st_acc(i_j_idx) *= raw_scale;
#else
st_acc(i_j_idx) *= scale;
#endif
position_encoding.update(st_acc(i_j_idx), row, col);
});
});
}
else
{
#if !CK_TILE_FMHA_FWD_FAST_EXP2
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, st_acc);
#endif
}
if constexpr(kPadSeqLenK || FmhaMask::IsMasking)
{
const auto q_origin = q_dram_block_window.get_window_origin();
bool need_perpixel_check = mask.IsEdgeTile(q_origin.at(number<0>{}),
k_origin.at(number<0>{}),
number<kM0>{},
number<kN0>{});
if(need_perpixel_check)
{
set_tile_if(st_acc, -numeric<AccDataType>::infinity(), [&](auto tile_idx) {
const auto row = q_origin.at(number<0>{}) + tile_idx.at(number<0>{});
const auto col = k_origin.at(number<0>{}) + tile_idx.at(number<1>{});
return mask.IsOutOfBound(row, col);
});
}
}
const auto lse = load_tile(lse_dram_window);
static const auto get_validated_lse = [](LSEDataType raw_lse) {
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
FmhaMask::IsMasking)
{
return raw_lse == -numeric<LSEDataType>::infinity()
? type_convert<LSEDataType>(0.f)
: raw_lse;
}
else
{
return raw_lse;
}
};
auto pt = SPTBlockTileType{};
constexpr auto pt_spans = decltype(pt)::get_distributed_spans();
sweep_tile_span(pt_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
#if CK_TILE_FMHA_FWD_FAST_EXP2
auto row_lse = log2e_v<LSEDataType> * get_validated_lse(lse[i_idx]);
#endif
sweep_tile_span(pt_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
#if CK_TILE_FMHA_FWD_FAST_EXP2
if constexpr(BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS ||
BiasEnum == BlockAttentionBiasEnum::ALIBI)
{
pt(i_j_idx) = exp2(st_acc[i_j_idx] - row_lse);
}
else
{
pt(i_j_idx) = exp2(scale * st_acc[i_j_idx] - row_lse);
}
#else
pt(i_j_idx) = exp(st_acc[i_j_idx] - get_validated_lse(lse[i_idx]));
#endif
});
});
if constexpr(kHasDropout)
{
dropout.Run<decltype(gemm_0), RandValOutputDataType>(
seqlen_q_start + i_total_loops * kM0, pt, randval_dram_window);
}
// STAGE 3, P^T@OGrad^T Gemm1
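// dV accumulation: dv_acc += P^T @ dO. dO was stored to LDS once (kOGradLoadOnce); gemm_1
// reads it back through dot_lds_window, a transposed descriptor over the same dO buffer.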
block_sync_lds();
store_tile(do_lds_window, do_block_tile); // store the prefetch
const auto pt_gemm = [&]() {
if constexpr(kHasDropout)
{
return tile_elementwise_in(
[](const auto& x) { return type_convert<GemmDataType>(x > 0.f ? x : 0.f); },
pt);
}
else
{
return cast_tile<GemmDataType>(pt);
}
}();
static_for<0, k1_loops, 1>{}([&](auto i_k1) {
block_sync_lds();
gemm_1(dv_acc,
get_slice_tile(
pt_gemm, sequence<i_k1 * kK1, 0>{}, sequence<(i_k1 + 1) * kK1, kN0>{}),
get_slice_tile(dot_lds_window,
sequence<0, i_k1 * kK1>{},
sequence<kVHeaddim, (i_k1 + 1) * kK1>{}));
block_sync_lds();
});
// STAGE 4, OGrad@V Gemm2
auto dpt_acc = SPGradTBlockTileType{};
clear_tile(dpt_acc); // Initialize PGrad^T
static_for<0, k2_loops, 1>{}([&](auto i_k2) {
block_sync_lds();
gemm_2(dpt_acc,
get_slice_tile(do_lds_window,
sequence<0, i_k2 * kK2>{},
sequence<kM0, (i_k2 + 1) * kK2>{}),
get_slice_tile(
v, sequence<0, i_k2 * kK2>{}, sequence<kN0, (i_k2 + 1) * kK2>{}));
block_sync_lds();
});
// STAGE 5, P^T(PGrad^T - D)
const auto d = load_tile(d_dram_window);
auto dst = SPGradTBlockTileType{};
constexpr auto dst_spans = decltype(dst)::get_distributed_spans();
sweep_tile_span(dst_spans[number<0>{}], [&](auto idx0) {
constexpr auto i_idx = make_tuple(idx0);
sweep_tile_span(dst_spans[number<1>{}], [&](auto idx1) {
constexpr auto i_j_idx = make_tuple(idx0, idx1);
bool undrop_flag = pt[i_j_idx] >= 0;
dst(i_j_idx) =
pt[i_j_idx] *
(!kHasDropout || undrop_flag ? (dpt_acc[i_j_idx] - d[i_idx]) : d[i_idx]);
});
});
if constexpr(kHasBiasGrad)
{
const auto dbiast = [&]() {
if constexpr(kHasDropout)
{
return tile_elementwise_in(
[&rp_undrop](const auto& x) {
return type_convert<BiasGradDataType>(x * rp_undrop);
},
dst);
}
else
{
return cast_tile<BiasGradDataType>(dst);
}
}();
store_tile(biast_lds_shuffle_window, dbiast);
block_sync_lds();
auto dbiast_tile = load_tile(dbiast_lds_shuffle_window);
auto dbiast_shuffle_tmp = make_static_distributed_tensor<BiasGradDataType>(
Policy::template MakeBiasTileDistribution<Problem>());
shuffle_tile(dbiast_shuffle_tmp, dbiast_tile);
store_tile(dbias_dram_block_window, dbiast_shuffle_tmp);
move_tile_window(dbias_dram_block_window, {kM0, 0});
}
// STAGE 6, SGrad^T@Q^T Gemm3
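// dK accumulation: dk_acc += dS^T @ Q. Q^T comes from qt_lds_window, a transposed
// descriptor over the Q tile already resident in LDS, so no separate Q^T load is needed
// (unlike the ks_vr pipeline).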
block_sync_lds();
const auto dst_gemm = cast_tile<GemmDataType>(dst);
static_for<0, k3_loops, 1>{}([&](auto i_k3) {
block_sync_lds();
gemm_3(dk_acc,
get_slice_tile(
dst_gemm, sequence<i_k3 * kK3, 0>{}, sequence<(i_k3 + 1) * kK3, kN0>{}),
get_slice_tile(qt_lds_window,
sequence<0, i_k3 * kK3>{},
sequence<kQKHeaddim, (i_k3 + 1) * kK3>{}));
block_sync_lds();
});
// STAGE 7, SGrad@K^T Gemm4
store_tile(ds_lds_window, dst_gemm);
auto dq_acc = QGradBlockTileType{};
clear_tile(dq_acc); // Initialize QGrad
block_sync_lds();
static_for<0, k4_loops, 1>{}([&](auto i_k4) {
gemm_4(dq_acc,
get_slice_tile(ds_lds_window,
sequence<0, i_k4 * kK4>{},
sequence<kM0, (i_k4 + 1) * kK4>{}),
get_slice_tile(kt_lds_window,
sequence<0, i_k4 * kK4>{},
sequence<kQKHeaddim, (i_k4 + 1) * kK4>{}));
});
// QGrad Scale
if constexpr(kHasDropout)
{
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
dq_acc);
}
else
{
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dq_acc);
}
const auto dq = cast_tile<QGradDataType>(dq_acc);
update_tile(dq_dram_block_window, dq);
// move tile windows
move_tile_window(q_dram_block_window, {kM0, 0});
move_tile_window(dq_dram_block_window, {kM0, 0});
move_tile_window(do_dram_block_window, {kM0, 0});
move_tile_window(lse_dram_window, {kM0});
move_tile_window(d_dram_window, {kM0});
} while(++i_total_loops < num_total_loop);
// KGrad Scale
if constexpr(kHasDropout)
{
tile_elementwise_inout([&scale_rp_undrop](auto& x) { x = x * scale_rp_undrop; },
dk_acc);
}
else
{
tile_elementwise_inout([&raw_scale](auto& x) { x = x * raw_scale; }, dk_acc);
}
// VGrad Scale
if constexpr(kHasDropout)
{
tile_elementwise_inout([&rp_undrop](auto& x) { x = x * rp_undrop; }, dv_acc);
}
return ck_tile::make_tuple(dk_acc, dv_acc);
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/ops/fmha/pipeline/block_fmha_bwd_pipeline_default_policy.hpp"
namespace ck_tile {
// In this pipeline, v is kept in regs; q, k & do are kept in lds.
using BlockFmhaBwdDQDKDVPipelineQSKSVROGradSDefaultPolicy =
BlockFmhaBwdPipelineDefaultPolicy</* QLoadOnce_ = */ true,
/* QTLoadOnce_ = */ false,
/* KLoadOnce_ = */ true,
/* KTLoadOnce_ = */ false,
/* VLoadOnce_ = */ true,
/* OGradLoadOnce_ = */ true,
/* OGradTLoadOnce_ = */ false>;
} // namespace ck_tile
@@ -11,6 +11,8 @@
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
@@ -18,60 +20,215 @@
namespace ck_tile {
template <bool QLoadOnce_,
bool QTLoadOnce_,
bool KLoadOnce_,
bool KTLoadOnce_,
bool VLoadOnce_,
bool OGradLoadOnce_,
bool OGradTLoadOnce_>
struct BlockFmhaBwdPipelineDefaultPolicy
{
static constexpr bool QLoadOnce =
QLoadOnce_; // if q load whole block length (qkhdim) to LDS at once
static constexpr bool QTLoadOnce =
QTLoadOnce_; // if q^t load whole block length (qkhdim) to LDS at once
static constexpr bool KLoadOnce =
KLoadOnce_; // if k load whole block length (qkhdim) to LDS at once
static constexpr bool KTLoadOnce =
KTLoadOnce_; // if k^t load whole block length (qkhdim) to LDS at once
static constexpr bool VLoadOnce =
VLoadOnce_; // if v load whole block length (vhdim) to Vgprs at once
static constexpr bool OGradLoadOnce =
OGradLoadOnce_; // if do load whole block length (vhdim) to LDS at once
static constexpr bool OGradTLoadOnce =
OGradTLoadOnce_; // if do^t load whole block length (vhdim) to LDS at once
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>>;
using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
false,
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true>;
using BlockGemmPolicy =
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
WarpGemm>;
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm()
{
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::GemmDataType,
typename Problem::OGradDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kVHeaddim,
Problem::BlockFmhaShape::kK1>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
typename Problem::OGradDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
true>;
using BlockGemmPolicy =
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,
typename Problem::OGradDataType,
typename Problem::AccDataType,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
WarpGemm>;
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm()
{
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::OGradDataType,
typename Problem::VDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK2>>;
using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::OGradDataType,
typename Problem::VDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm2WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm2WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm2WarpTile::at(number<2>{}),
false,
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true>;
using BlockGemmPolicy =
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::OGradDataType,
typename Problem::VDataType,
typename Problem::AccDataType,
typename Problem::BlockFmhaShape::Gemm2BlockWarps,
WarpGemm>;
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm()
{
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::GemmDataType,
typename Problem::QDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK3>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
typename Problem::QDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}),
true>;
using BlockGemmPolicy =
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,
typename Problem::QDataType,
typename Problem::AccDataType,
typename Problem::BlockFmhaShape::Gemm3BlockWarps,
WarpGemm>;
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm()
{
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::GemmDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK4>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}),
Problem::kIsDeterministic ? true : false>;
using BlockGemmPolicy =
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::GemmDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
typename Problem::BlockFmhaShape::Gemm4BlockWarps,
WarpGemm>;
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
}
// these are for global load
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
{
using QDataType = remove_cvref_t<typename Problem::QDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kMaxVecLoad = 16 / sizeof(QDataType);
constexpr index_t kMinVecLoad = 4 / sizeof(QDataType);
constexpr index_t kTotalPixels = kMNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t kVecLoad = ((kTotalPixels / kMaxVecLoad) >= kMinVecLoad)
? kMaxVecLoad
: (kTotalPixels / kMinVecLoad);
return kVecLoad;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK()
{
using KDataType = remove_cvref_t<typename Problem::KDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kMaxVecLoad = 16 / sizeof(KDataType);
constexpr index_t kMinVecLoad = 4 / sizeof(KDataType);
constexpr index_t kTotalPixels = kMNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t kVecLoad = ((kTotalPixels / kMaxVecLoad) >= kMinVecLoad)
? kMaxVecLoad
: (kTotalPixels / kMinVecLoad);
return kVecLoad;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
{
using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t kMaxVecLoad = 16 / sizeof(VDataType);
constexpr index_t kTotalPixels = kMNPerBlock * kKPerBlock / kBlockSize;
return kTotalPixels > kMaxVecLoad ? kMaxVecLoad : kTotalPixels;
}
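// Worked example (hypothetical tile, not taken from this commit): for fp16 Q with kM0 = 128,
// kK0 = 32 and kBlockSize = 256, kTotalPixels = 128 * 32 / 256 = 16, kMaxVecLoad = 8 and
// kMinVecLoad = 2; since 16 / 8 >= 2, GetAlignmentQ() picks the full 8-element (16-byte)
// vector load, while tiles with fewer pixels per thread fall back to kTotalPixels / kMinVecLoad.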
template <typename Problem>
@@ -84,8 +241,20 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOGrad()
{
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t kMaxVecLoad = 16 / sizeof(OGradDataType);
constexpr index_t kMinVecLoad = 4 / sizeof(OGradDataType);
constexpr index_t kTotalPixels = kMNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t kVecLoad = ((kTotalPixels / kMaxVecLoad) >= kMinVecLoad)
? kMaxVecLoad
: (kTotalPixels / kMinVecLoad);
return kVecLoad;
}
template <typename Problem>
@@ -128,958 +297,1723 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentQ()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
return total_pixels / GetAlignmentQ<Problem>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentK()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
return total_pixels / GetAlignmentK<Problem>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentOGrad()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
return total_pixels / GetAlignmentOGrad<Problem>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentBias()
{
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kTotalPixels = kMPerBlock * kNPerBlock / kBlockSize;
constexpr index_t kMaxVecLoad = 16 / sizeof(BiasDataType);
constexpr index_t kMinVecLoad = 4 / sizeof(BiasDataType);
constexpr index_t kVecLoad = ((kTotalPixels / kMaxVecLoad) >= kMinVecLoad)
? kMaxVecLoad
: (kTotalPixels / kMinVecLoad);
return kVecLoad;
}
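// Descriptive note: the transposed alignments above are derived as total_pixels divided by
// the corresponding GetAlignmentX(), i.e. each thread's per-tile footprint appears to be
// factored into a global-load vector and a complementary transposed-access vector.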
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBias()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kTotalPixels = kMPerBlock * kNPerBlock / kBlockSize;
return kTotalPixels / GetTransposedAlignmentBias<Problem>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentPostQGradAcc()
{
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
return 16 / sizeof(AccDataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentPostQGrad()
{
return GetAlignmentPostQGradAcc<Problem>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t K1 = GetAlignmentK<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N1 = get_warp_size() / K0;
constexpr index_t N0 = kBlockSize / get_warp_size();
constexpr index_t N2 = kNPerBlock / (N1 * N0);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,
sequence<2, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeVDramTileDistribution()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t K1 = GetAlignmentV<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N0 = kNPerBlock / (N2 * N1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t K1 = GetAlignmentQ<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t M1 = get_warp_size() / K0;
constexpr index_t M0 = kBlockSize / get_warp_size();
constexpr index_t M2 = kMPerBlock / (M1 * M0);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,
sequence<2, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOGradDramTileDistribution()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t K1 = GetAlignmentOGrad<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t M1 = get_warp_size() / K0;
constexpr index_t M0 = kBlockSize / get_warp_size();
constexpr index_t M2 = kMPerBlock / (M1 * M0);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,
sequence<2, 1>>{});
}
template <typename Problem, typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDDramTileDistribution()
{
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
// Duplicate dimension
constexpr index_t N0 = NWarp;
constexpr index_t N1 =
(get_warp_size() / kMPerBlock) > 1 ? (get_warp_size() / kMPerBlock) : 1;
constexpr index_t M0 = MWarp;
constexpr index_t M1 = (get_warp_size() / kMPerBlock) > 1 ? kMPerBlock : get_warp_size();
constexpr index_t M2 =
(get_warp_size() / kMPerBlock) > 1 ? 1 : (kMPerBlock / get_warp_size());
return make_static_tile_distribution(
tile_distribution_encoding<sequence<N0, N1>,
tuple<sequence<M0, M1, M2>>,
tuple<sequence<0, 1>, sequence<0, 1>>,
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<1>,
sequence<2>>{});
}
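// Descriptive note: the MakePre*/MakePost* distributions below are presumably consumed outside
// the main pipeline - the Pre(O/OGrad) ones by the dot(O, dO) preprocessing step that produces D,
// and the PostQGrad* ones by the pass that accumulates and converts the fp32 dQ buffer.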
template <typename DataType, index_t MPerBlock, index_t KPerBlock>
CK_TILE_HOST_DEVICE static constexpr auto MakePreXDramTileDistribution()
{
constexpr index_t K1 = 16 / sizeof(DataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t M2 = 1;
constexpr index_t M1 = get_warp_size();
constexpr index_t M0 = MPerBlock / M1;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1>>,
tuple<sequence<0>, sequence<1>>,
sequence<1, 2, 2>,
sequence<2, 0, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakePreODramTileDistribution()
{
using ODataType = remove_cvref_t<typename Problem::ODataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kKPerBlock = Problem::kVHeaddim;
return MakePreXDramTileDistribution<ODataType, kBlockSize, kKPerBlock>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakePreOGradDramTileDistribution()
{
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kKPerBlock = Problem::kVHeaddim;
return MakePreXDramTileDistribution<OGradDataType, kBlockSize, kKPerBlock>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradAccDeterministicDramTileDistribution()
{
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::QGradDataType,
typename Problem::QGradDataType,
typename Problem::AccDataType,
Problem::Shape::WarpTile::at(number<0>{}),
Problem::Shape::WarpTile::at(number<1>{}),
Problem::Shape::WarpTile::at(number<2>{}),
true>;
using WarpGemmAttrImpl = typename WarpGemm::WarpGemmAttribute::Impl;
constexpr index_t MWarp = Problem::Shape::BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::Shape::BlockWarps::at(number<1>{});
constexpr index_t kMPerBlock = Problem::Shape::kM0;
constexpr index_t kNPerBlock = Problem::Shape::kQKHeaddim;
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr auto dq_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<1>, sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<2, 3>>,
tuple<sequence<1, 1>>,
sequence<1, 2, 3>,
sequence<0, 0, 0>>{};
constexpr auto dq_block_inner_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<1>,
sequence<WarpGemmAttrImpl::kCM0PerLane, WarpGemmAttrImpl::kCMLane>,
sequence<WarpGemmAttrImpl::kCNLane, WarpGemmAttrImpl::kCM1PerLane>>,
tuple<sequence<2, 3>>,
tuple<sequence<1, 0>>,
sequence<1, 2, 3>,
sequence<0, 0, 1>>{};
constexpr auto dq_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
dq_block_outer_dstr_encoding, dq_block_inner_dstr_encoding);
constexpr auto dq_block_dstr = make_static_tile_distribution(dq_block_dstr_encode);
return dq_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradDeterministicDramTileDistribution()
{
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::QGradDataType,
typename Problem::QGradDataType,
typename Problem::AccDataType,
Problem::Shape::WarpTile::at(number<0>{}),
Problem::Shape::WarpTile::at(number<1>{}),
Problem::Shape::WarpTile::at(number<2>{}),
true>;
constexpr index_t MWarp = Problem::Shape::BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::Shape::BlockWarps::at(number<1>{});
constexpr index_t kMPerBlock = Problem::Shape::kM0;
constexpr index_t kNPerBlock = Problem::Shape::kQKHeaddim;
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr auto dq_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto dq_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
dq_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
constexpr auto dq_block_dstr = make_static_tile_distribution(dq_block_dstr_encode);
return dq_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradAccDramTileDistribution()
{
using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::Shape::kM0;
constexpr index_t kKPerBlock = Problem::Shape::kQKHeaddim;
constexpr index_t K1 = 16 / sizeof(AccDataType);
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
constexpr index_t M1 = kBlockSize / get_warp_size();
constexpr index_t M0 = kMPerBlock / (M1 * M2);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
// these are for lds
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ()
{
return GetAlignmentQ<Problem>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQT()
{
return GetTransposedAlignmentQ<Problem>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK()
{
return GetAlignmentK<Problem>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackKT()
{
return GetTransposedAlignmentK<Problem>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV()
{
return GetAlignmentV<Problem>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackBias()
{
return GetAlignmentBias<Problem>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackOGrad()
{
return GetAlignmentOGrad<Problem>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackOGradT()
{
return GetTransposedAlignmentOGrad<Problem>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackSGrad()
{
// TODO: this is for 3d layout
using GemmDataType = remove_cvref_t<typename Problem::GemmDataType>;
return 16 / sizeof(GemmDataType);
}
template <index_t MNPerBlock, index_t KPerBlock, index_t KPack>
CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptor()
{
constexpr auto DataTypeSize = 2; // sizeof(F16/BF16)
constexpr auto MNLdsLayer =
(32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);
constexpr auto x_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack * MNLdsLayer>{},
number<MNPerBlock / MNLdsLayer>{},
number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock * MNLdsLayer>{}, number<1>{}),
number<KPack>{},
number<1>{});
constexpr auto x_lds_block_desc_permuted = transform_tensor_descriptor(
x_lds_block_desc_0,
make_tuple(make_xor_transform(make_tuple(number<MNPerBlock / MNLdsLayer>{},
number<KPerBlock / KPack * MNLdsLayer>{})),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
constexpr auto x_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
x_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(
make_tuple(number<KPerBlock / KPack>{}, number<MNLdsLayer>{})),
make_pass_through_transform(number<MNPerBlock / MNLdsLayer>{}),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
constexpr auto x_lds_block_desc = transform_tensor_descriptor(
x_lds_block_desc_xk0_mnldslayer_mn_xk1,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<MNPerBlock / MNLdsLayer>{}, number<MNLdsLayer>{})),
make_merge_transform_v3_division_mod(
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return x_lds_block_desc;
}
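// Descriptive note: the xor_transform above is the usual LDS swizzle - consecutive MN rows are
// rotated across K slots so that transposed reads do not hit the same LDS bank on every lane,
// while MNLdsLayer folds several short K rows into one 32-bank (128-byte) LDS row when
// KPerBlock is small.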
template <typename Problem,
index_t MNPerBlock,
index_t KPerBlock,
index_t KPack,
index_t KPackT>
CK_TILE_HOST_DEVICE static constexpr auto MakeXTLdsBlockDescriptor()
{
// kfold and mpair dimension is not always required.
// more dimension in merge_transform increase the difficulty of generating immarg offset
// for compiler.
constexpr auto MNPerXDL = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
constexpr auto kBlockSize = Problem::kBlockSize;
constexpr auto MN0 = MNPerBlock / KPack;
constexpr auto MN1 = KPack;
constexpr auto KThreadWrite = kBlockSize / MN0;
constexpr auto K0Number = KPerBlock / KPackT;
constexpr auto K0PerThreadWrite = K0Number / KThreadWrite;
constexpr auto KThreadRead = get_warp_size() / MNPerXDL; // assume 32x32x8 mfma
constexpr auto K0PerThreadRead = K0Number / KThreadRead;
constexpr auto kfold = (KPackT * MN0 * 2 > 128) ? 1 : 128 / (KPackT * MN0 * 2);
constexpr auto KThreadReadPerm =
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
: KThreadRead;
// 1<=mnpair<=n0
constexpr auto mnpair =
(KPackT * MNPerXDL * 2 > 128)
? 1
: ((128 / (KPackT * MNPerXDL * 2)) > MN0 ? MN0 : 128 / (KPackT * MNPerXDL * 2));
constexpr auto xt_lds_block_desc_raw = make_naive_tensor_descriptor(
make_tuple(number<KThreadWrite / kfold / KThreadReadPerm>{},
number<K0PerThreadWrite>{},
number<KThreadReadPerm * MN1>{},
number<kfold * MN0 / mnpair>{},
number<mnpair>{},
KPackT),
make_tuple(number<KPackT * kfold * MN0 * KThreadReadPerm * MN1 * K0PerThreadWrite>{},
number<KPackT * kfold * MN0 * KThreadReadPerm * MN1>{},
number<KPackT * kfold * MN0>{},
number<KPackT * mnpair>{},
number<KPackT>{},
number<1>{}),
number<KPackT>{},
number<1>{});
constexpr auto xt_lds_block_desc_permuted = transform_tensor_descriptor(
xt_lds_block_desc_raw,
make_tuple(
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(number<K0PerThreadWrite>{}),
make_xor_transform(
make_tuple(number<KThreadReadPerm * MN1>{}, number<kfold * MN0 / mnpair>{})),
make_pass_through_transform(number<mnpair>{}),
make_pass_through_transform(KPackT)),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}));
constexpr auto xt_lds_block_desc_unmerged = transform_tensor_descriptor(
xt_lds_block_desc_permuted,
make_tuple(
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(number<K0PerThreadWrite>{}),
make_unmerge_transform(make_tuple(number<KThreadReadPerm>{}, number<MN1>{})),
make_unmerge_transform(make_tuple(number<kfold>{}, number<MN0 / mnpair>{})),
make_pass_through_transform(number<mnpair>{}),
make_pass_through_transform(KPackT)),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4>{},
sequence<5>{}),
make_tuple(sequence<1>{},
sequence<2>{},
sequence<0, 3>{},
sequence<4, 5>{},
sequence<6>{},
sequence<7>{}));
constexpr auto xt_lds_block_desc = transform_tensor_descriptor(
xt_lds_block_desc_unmerged,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<KThreadReadPerm>{},
number<KThreadWrite / kfold / KThreadReadPerm>{},
number<kfold>{},
number<K0PerThreadWrite>{},
number<KPackT>{})),
make_merge_transform_v3_division_mod(
make_tuple(number<MN0 / mnpair>{}, number<mnpair>{}, number<MN1>{}))),
make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return xt_lds_block_desc;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsWriteBlockDescriptor()
{
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPack = GetSmemKPackK<Problem>();
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKRegSliceBlockDescriptor()
{
using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
false,
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr auto k_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode);
return k_block_dstr;
} }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKRegBlockDescriptor()
{
using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
false,
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr auto k_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode);
return k_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsWriteBlockDescriptor()
{
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t kVPack = GetSmemKPackV<Problem>();
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kVPack>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeVRegSliceBlockDescriptor()
{
using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::OGradDataType,
typename Problem::VDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm2WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm2WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm2WarpTile::at(number<2>{}),
false,
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr auto v_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode);
return v_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeVRegBlockDescriptor()
{
using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::OGradDataType,
typename Problem::VDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm2WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm2WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm2WarpTile::at(number<2>{}),
false,
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr auto v_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode);
return v_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKTRegWriteBlockDescriptor()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t K1 = GetAlignmentK<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = GetTransposedAlignmentK<Problem>();
constexpr index_t N1 = get_warp_size() / K0;
constexpr index_t N0 = kBlockSize / get_warp_size();
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<2, 1>,
sequence<1, 2>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKTLdsWriteBlockDescriptor()
{
// Hold all data
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPack = GetSmemKPackK<Problem>();
constexpr index_t kKPackT = GetSmemKPackKT<Problem>();
return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKTLdsReadBlockDescriptor()
{
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
auto kt_lds_block_desc = MakeKTLdsWriteBlockDescriptor<Problem>();
return transform_tensor_descriptor(
kt_lds_block_desc,
make_tuple(make_pass_through_transform(number<kNPerBlock>{}),
make_pass_through_transform(number<kKPerBlock>{})),
make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKTRegBlockDescriptor()
{
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}),
Problem::kIsDeterministic ? true : false>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr auto kt_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>, sequence<1, 2>,
sequence<0, 1>>{}); sequence<0, 0>>{};
constexpr auto kt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
kt_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto kt_block_dstr = make_static_tile_distribution(kt_block_dstr_encode);
return kt_block_dstr;
} }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
{
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQRegSliceBlockDescriptor()
{
using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
false,
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr auto q_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
q_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode);
return q_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQTRegWriteBlockDescriptor()
    {
        constexpr index_t kBlockSize = Problem::kBlockSize;
        constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;

        constexpr index_t K1 = GetAlignmentQ<Problem>();
        constexpr index_t K0 = kKPerBlock / K1;
        constexpr index_t N2 = GetTransposedAlignmentQ<Problem>();
        constexpr index_t N1 = get_warp_size() / K0;
        constexpr index_t N0 = kBlockSize / get_warp_size();

        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<>,
                                       tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<0>, sequence<1, 0>>,
                                       sequence<2, 1>,
                                       sequence<1, 2>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQTLdsWriteBlockDescriptor()
{
// Hold full block data
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
constexpr index_t kKPackT = GetSmemKPackQT<Problem>();
return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQTLdsReadBlockDescriptor()
{
// Hold full block data
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
auto qt_lds_block_desc = MakeQTLdsWriteBlockDescriptor<Problem>();
return transform_tensor_descriptor(
qt_lds_block_desc,
make_tuple(make_pass_through_transform(number<kNPerBlock>{}),
make_pass_through_transform(number<kKPerBlock>{})),
make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQTRegSliceBlockDescriptor()
{
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
typename Problem::QDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}),
true>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK3;
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr auto qt_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};
constexpr auto qt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
qt_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto qt_block_dstr = make_static_tile_distribution(qt_block_dstr_encode);
return qt_block_dstr;
    }
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeSGradTRegSliceBlockDescriptor()
    {
        using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
typename Problem::QDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}),
true>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<1>{});
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK3;
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr auto dst_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto dst_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
dst_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
constexpr auto dst_block_dstr = make_static_tile_distribution(dst_block_dstr_encode);
return dst_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDLdsWriteBlockDescriptor()
{
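        // 1-D tile of kMPerBlock LSE/D values; kMPack (16 bytes worth of elements) is passed
        // through as the guaranteed vector length for the LDS access.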
        constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
        using LSEDType               = remove_cvref_t<typename Problem::DDataType>;
        constexpr index_t kMPack     = 16 / sizeof(LSEDType);

        constexpr auto lsed_lds_block_desc =
            make_naive_tensor_descriptor(make_tuple(number<kMPerBlock>{}),
                                         make_tuple(number<1>{}),
number<kMPack>{},
number<1>{});
return lsed_lds_block_desc;
}
template <typename Problem, typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDLdsReadBlockDescriptor()
{
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t N1 = WG::WarpGemmAttribute::Impl::kCNLane;
constexpr index_t N0 = NWarp;
// M4 *2 and M2 /2 when swizzle mode enabled
constexpr index_t SwizzleConfig = WG::kM == 16 ? 1 : 2;
// constexpr index_t SwizzleConfig = 1;
constexpr index_t M4 = WG::WarpGemmAttribute::Impl::kCM1PerLane * SwizzleConfig;
constexpr index_t M3 = WG::WarpGemmAttribute::Impl::kCMLane;
constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kCM0PerLane / SwizzleConfig;
constexpr index_t M1 = MWarp;
constexpr index_t M0 = kMPerBlock / (M1 * WG::WarpGemmAttribute::Impl::kM);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<N0, N1>,
tuple<sequence<M0, M1, M2, M3, M4>>,
tuple<sequence<1, 0>, sequence<1, 0>>,
tuple<sequence<1, 0>, sequence<3, 1>>,
sequence<1, 1, 1>,
sequence<0, 2, 4>>{});
}
template <typename Problem, typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDRegBlockDescriptor()
{
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t N1 = WG::WarpGemmAttribute::Impl::kCNLane;
constexpr index_t N0 = NWarp;
// M4 *2 and M2 /2 when swizzle mode enabled
constexpr index_t SwizzleConfig = WG::kM == 16 ? 1 : 2;
constexpr index_t M4 = WG::WarpGemmAttribute::Impl::kCM1PerLane * SwizzleConfig;
constexpr index_t M3 = WG::WarpGemmAttribute::Impl::kCMLane;
constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kCM0PerLane / SwizzleConfig;
constexpr index_t M1 = MWarp;
constexpr index_t M0 = kMPerBlock / (M1 * WG::WarpGemmAttribute::Impl::kM);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<N0, N1>,
tuple<sequence<M0, M1, M2, M3, M4>>,
tuple<sequence<1, 0>, sequence<1, 0>>,
tuple<sequence<1, 0>, sequence<3, 1>>,
sequence<1, 1, 1>,
sequence<0, 2, 4>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOGradLdsBlockDescriptor()
{
// Hold full block data
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOGradRegSliceBlockDescriptor()
{
using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::OGradDataType,
typename Problem::VDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm2WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm2WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm2WarpTile::at(number<2>{}),
false,
Problem::BlockFmhaShape::Gemm2WarpTile::at(number<0>{}) == 16 ? false : true>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{});
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr auto do_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto do_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
do_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
constexpr auto do_block_dstr = make_static_tile_distribution(do_block_dstr_encode);
return do_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOGradTRegWriteBlockDescriptor()
{
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
        constexpr index_t K1 = GetAlignmentOGrad<Problem>();
        constexpr index_t K0 = kKPerBlock / K1;
        constexpr index_t N2 = GetTransposedAlignmentOGrad<Problem>();
        constexpr index_t N1 = get_warp_size() / K0;
        constexpr index_t N0 = kBlockSize / get_warp_size();

        return make_static_tile_distribution(
            tile_distribution_encoding<sequence<>,
                                       tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
                                       tuple<sequence<1>, sequence<1, 2>>,
                                       tuple<sequence<0>, sequence<1, 0>>,
sequence<2, 1>,
sequence<1, 2>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOGradTLdsWriteBlockDescriptor()
{
// Hold all data
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
constexpr index_t kKPackT = GetSmemKPackOGradT<Problem>();
return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOGradTLdsReadBlockDescriptor()
{
// Hold all data
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
auto dot_lds_block_desc = MakeOGradTLdsWriteBlockDescriptor<Problem>();
return transform_tensor_descriptor(
dot_lds_block_desc,
make_tuple(make_pass_through_transform(number<kNPerBlock>{}),
make_pass_through_transform(number<kKPerBlock>{})),
make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOGradTRegSliceBlockDescriptor()
{
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
typename Problem::OGradDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
true>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
// constexpr index_t kNPerBlock = 32;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr auto dot_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
                                       sequence<1, 2>,
                                       sequence<0, 0>>{};
constexpr auto dot_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
dot_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto dot_block_dstr = make_static_tile_distribution(dot_block_dstr_encode);
return dot_block_dstr;
    }
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakePTRegSliceBlockDescriptor()
    {
        using WarpGemm =
            WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
                                   typename Problem::OGradDataType,
                                   typename Problem::AccDataType,
                                   Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
true>;
        constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{});
        constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{});

        constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kN0;
        constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;

        constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr auto pt_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto pt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
pt_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
constexpr auto pt_block_dstr = make_static_tile_distribution(pt_block_dstr_encode);
return pt_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeSGradLdsBlockDescriptor()
{
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPack = GetSmemKPackSGrad<Problem>();
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>();
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeSGradRegSliceBlockDescriptor()
{
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}),
Problem::kIsDeterministic ? true : false>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{});
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK4;
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr auto ds_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto ds_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
ds_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
constexpr auto ds_block_dstr = make_static_tile_distribution(ds_block_dstr_encode);
return ds_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQGradWriteBlockDescriptor()
{
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
typename Problem::QDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}),
true>;
using WarpGemmAttrImpl = typename WarpGemm::WarpGemmAttribute::Impl;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{});
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
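        // the dQ write distribution mirrors the warp-gemm accumulator (C-tile) lane layout,
        // so the accumulated dQ tile can be written out without an extra shuffle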
constexpr auto dq_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto dq_block_inner_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<WarpGemmAttrImpl::kCM0PerLane, WarpGemmAttrImpl::kCMLane>,
sequence<WarpGemmAttrImpl::kCNLane, WarpGemmAttrImpl::kCM1PerLane>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 1>>{};
constexpr auto dq_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
dq_block_outer_dstr_encoding, dq_block_inner_dstr_encoding);
constexpr auto dq_block_dstr = make_static_tile_distribution(dq_block_dstr_encode);
return dq_block_dstr;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTLdsBlockDescriptor()
{
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
constexpr index_t Banks = 32; // TODO: need change based on arch
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(BiasDataType);
constexpr index_t kKPack = GetSmemKPackBias<Problem>();
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
static_assert(PixelsPerRow % kKPack == 0);
constexpr index_t NPerRow = PixelsPerRow / kKPack;
static_assert(kNPerBlock % NPerRow == 0);
static_assert(kMPerBlock % kKPack == 0);
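        // each PixelsPerRow-wide row is padded by kKPack elements (row stride PixelsPerRow + kKPack)
        // so the transposed bias access does not hit the same LDS banks on every row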
constexpr auto biast_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kMPerBlock / kKPack>{},
number<kNPerBlock / NPerRow>{},
number<NPerRow>{},
number<kKPack>{}),
make_tuple(number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{},
number<PixelsPerRow + kKPack>{},
number<kKPack>{},
number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto biast_lds_block_desc = transform_tensor_descriptor(
biast_lds_block_desc_0,
make_tuple(
make_merge_transform(make_tuple(number<kNPerBlock / NPerRow>{}, number<NPerRow>{})),
make_merge_transform(make_tuple(number<kMPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return biast_lds_block_desc;
    }
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeQ()
    {
        constexpr index_t smem_size_q = sizeof(typename Problem::QDataType) *
                                        MakeQLdsBlockDescriptor<Problem>().get_element_space_size();
        return smem_size_q;
    }
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeQT()
    {
        constexpr index_t smem_size_qt =
            sizeof(typename Problem::QDataType) *
            MakeQTLdsWriteBlockDescriptor<Problem>().get_element_space_size();
        return smem_size_qt;
    }
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeK()
    {
        constexpr index_t smem_size_k =
            sizeof(typename Problem::KDataType) *
            MakeKLdsWriteBlockDescriptor<Problem>().get_element_space_size();
        return smem_size_k;
    }
    template <typename Problem>
    CK_TILE_DEVICE static constexpr auto MakeQTDramTileDistribution()
    {
        constexpr index_t kBlockSize = Problem::kBlockSize;
        constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
        constexpr index_t kKPerBlock = [&]() {
            if constexpr(QTLoadOnce)
return Problem::BlockFmhaShape::kM0;
else
return Problem::BlockFmhaShape::kK3;
}();
constexpr index_t N1 = GetTransposedAlignmentQ<Problem>();
constexpr index_t N0 = kNPerBlock / N1; // P
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
constexpr index_t K3 = total_pixels / N1;
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
static_assert(kKPack % K3 == 0);
        constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside a single wave
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = kBlockSize / get_warp_size();
static_assert(kKPerBlock == K0 * K1 * K2 * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
    }
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeKT()
    {
        constexpr index_t smem_size_kt =
            sizeof(typename Problem::KDataType) *
            MakeKTLdsReadBlockDescriptor<Problem>().get_element_space_size();
        return smem_size_kt;
    }
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledQTRegBlockDescriptor()
    {
        constexpr index_t kBlockSize = Problem::kBlockSize;
        constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
        constexpr index_t kKPerBlock = [&]() {
            if constexpr(QTLoadOnce)
return Problem::BlockFmhaShape::kM0;
else
return Problem::BlockFmhaShape::kK3;
}();
constexpr index_t N1 = GetTransposedAlignmentQ<Problem>();
constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
constexpr index_t K3 = total_pixels / N1;
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
static_assert(kKPack % K3 == 0);
        constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside a single wave
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = kBlockSize / get_warp_size();
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
    }
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeLSE()
    {
        constexpr index_t smem_size_lse =
            sizeof(typename Problem::LSEDataType) *
            MakeLSEDLdsWriteBlockDescriptor<Problem>().get_element_space_size();
        return smem_size_lse;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeD()
    {
        constexpr index_t smem_size_d =
            sizeof(typename Problem::DDataType) *
            MakeLSEDLdsWriteBlockDescriptor<Problem>().get_element_space_size();
        return smem_size_d;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeV()
    {
        constexpr index_t smem_size_v =
            sizeof(typename Problem::VDataType) *
            MakeVLdsWriteBlockDescriptor<Problem>().get_element_space_size();
        return smem_size_v;
    }
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeOGrad()
    {
        constexpr index_t smem_size_do =
            sizeof(typename Problem::OGradDataType) *
            MakeOGradLdsBlockDescriptor<Problem>().get_element_space_size();
        return smem_size_do;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeOGradT()
    {
        constexpr index_t smem_size_dot =
            sizeof(typename Problem::OGradDataType) *
            MakeOGradTLdsWriteBlockDescriptor<Problem>().get_element_space_size();
        return smem_size_dot;
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeSGrad()
    {
        constexpr index_t smem_size_ds =
            sizeof(typename Problem::GemmDataType) *
            MakeSGradLdsBlockDescriptor<Problem>().get_element_space_size();
        return smem_size_ds;
    }
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeBias()
    {
        constexpr index_t smem_size_bias = [&]() {
            if constexpr(Problem::BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
                return sizeof(typename Problem::BiasDataType) *
                       MakeBiasTLdsBlockDescriptor<Problem>().get_element_space_size();
            else
                return 0;
        }();
        return smem_size_bias;
    }
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
constexpr index_t smem_size_q = GetSmemSizeQ<Problem>();
constexpr index_t smem_size_qt = GetSmemSizeQT<Problem>();
constexpr index_t smem_size_lse = GetSmemSizeLSE<Problem>();
constexpr index_t smem_size_k = GetSmemSizeK<Problem>();
constexpr index_t smem_size_kt = GetSmemSizeKT<Problem>();
constexpr index_t smem_size_v = GetSmemSizeV<Problem>();
constexpr index_t smem_size_do = GetSmemSizeOGrad<Problem>();
constexpr index_t smem_size_dot = GetSmemSizeOGradT<Problem>();
constexpr index_t smem_size_d = GetSmemSizeD<Problem>();
constexpr index_t smem_size_ds = GetSmemSizeSGrad<Problem>();
constexpr index_t smem_size_bias = GetSmemSizeBias<Problem>();
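        // LDS is reused from pipeline stage to pipeline stage, so the allocation only needs to
        // cover the largest per-stage footprint (hence the max() below)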
constexpr index_t smem_size_stage0_0 = smem_size_k + smem_size_kt;
constexpr index_t smem_size_stage0_1 = smem_size_v;
        constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + smem_size_dot +
smem_size_do + smem_size_lse + smem_size_d +
smem_size_ds;
constexpr index_t smem_size_stage2 = smem_size_qt + smem_size_bias;
constexpr index_t smem_size_stage3 = smem_size_qt;
constexpr index_t smem_size_stage4 = smem_size_qt + smem_size_do + smem_size_d;
constexpr index_t smem_size_stage5 = smem_size_qt;
constexpr index_t smem_size_stage6 = smem_size_qt + smem_size_ds;
return max(smem_size_stage0_0,
smem_size_stage0_1,
smem_size_stage1,
smem_size_stage2,
smem_size_stage3,
smem_size_stage4,
smem_size_stage5,
smem_size_stage6);
}
template <typename Problem_>
struct HotLoopScheduler
{
using Problem = Problem_;
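        // Each GemmStagedScheduler<N> specialization below shapes the instruction schedule of
        // gemm stage N with __builtin_amdgcn_sched_group_barrier; the masks used are
        // 0x008 (MFMA), 0x020 (VMEM read), 0x100 (DS read) and 0x200 (DS write), as noted on
        // the individual barriers.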
template <index_t GemmStage>
CK_TILE_DEVICE static constexpr void GemmStagedScheduler()
{
}
        template <>
        CK_TILE_DEVICE static constexpr void GemmStagedScheduler<0>()
        {
            // Mem: Q, LSE, OGrad, D global load, OGrad^T LDS load
            // Comp: Q x K
            constexpr index_t VMEM_READ_INST =
                Q_VMEM_READ + OGrad_VMEM_READ + LSE_VMEM_READ + D_VMEM_READ;
            constexpr index_t LDS_READ_INST = OGradT_LDS_READ;
            constexpr index_t MFMA_INST     = Gemm0MFMA;
// Evenly distributed to relieve SQ->TA FIFO pressure
constexpr index_t VMEM_READ__MFMA_Rate = MFMA_INST / VMEM_READ_INST;
            constexpr index_t MFMA_Remainder =
                MFMA_INST - VMEM_READ__MFMA_Rate * VMEM_READ_INST;
// To hide instruction issue latency
constexpr index_t MFMA__LDS_READ_Rate = LDS_READ_INST / MFMA_INST;
static_for<0, VMEM_READ_INST, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
static_for<0, VMEM_READ__MFMA_Rate, 1>{}([&](auto j) {
ignore = j;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, MFMA__LDS_READ_Rate, 0); // DS read
});
});
static_for<0, MFMA_Remainder, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, MFMA__LDS_READ_Rate, 0); // DS read
});
}
        template <>
        CK_TILE_DEVICE static constexpr void GemmStagedScheduler<1>()
        {
            // Mem: Q^T LDS load
            // Comp: OGrad x V
            constexpr index_t LDS_READ_INST = QT_LDS_READ;
            constexpr index_t MFMA_INST     = Gemm1MFMA;
// To hide instruction issue latency
constexpr index_t MFMA__LDS_READ_Rate = LDS_READ_INST / MFMA_INST;
static_for<0, MFMA_INST, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, MFMA__LDS_READ_Rate, 0); // DS read
});
}
        template <>
        CK_TILE_DEVICE static constexpr void GemmStagedScheduler<2>()
        {
            // Mem: Q, QT, LSE, OGrad, OGradT, D, LDS store
            // Comp: PT x OGrad
            constexpr index_t LDS_WRITE_INST = Q_LDS_WRITE + QT_LDS_WRITE + OGrad_LDS_WRITE +
                                               OGradT_LDS_WRITE + LSE_LDS_WRITE + D_LDS_WRITE;
            constexpr index_t MFMA_INST = Gemm2MFMA;

            // To hide instruction issue latency
            constexpr index_t MFMA__LDS_WRITE_Rate = LDS_WRITE_INST / MFMA_INST;
static_for<0, MFMA_INST, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, MFMA__LDS_WRITE_Rate, 0); // DS write
});
}
        template <>
        CK_TILE_DEVICE static constexpr void GemmStagedScheduler<3>()
        {
            // Mem: SGradT LDS store, SGrad, Q, LSE LDS load.
            // Comp: SGradT x QT
            constexpr index_t LDS_WRITE_INST = SGradT_LDS_WRITE;
            constexpr index_t LDS_READ_INST  = SGradT_LDS_READ_P1 + Q_LDS_READ + LSE_LDS_READ;
            constexpr index_t MFMA_INST      = Gemm3MFMA;

            // To hide instruction issue latency
constexpr index_t MFMA__LDS_WRITE_Rate =
LDS_WRITE_INST / MFMA_INST >= 1 ? LDS_WRITE_INST / MFMA_INST : 1;
constexpr index_t MFMA_INST_LDS_WRITE = LDS_WRITE_INST / MFMA__LDS_WRITE_Rate;
constexpr index_t MFMA__LDS_READ_Rate =
(MFMA_INST - MFMA_INST_LDS_WRITE) > 0
? LDS_READ_INST / (MFMA_INST - MFMA_INST_LDS_WRITE) > 0
? LDS_READ_INST / (MFMA_INST - MFMA_INST_LDS_WRITE)
: 1
: 0;
static_for<0, MFMA_INST_LDS_WRITE, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, MFMA__LDS_WRITE_Rate, 0); // DS Write
});
static_for<0, MFMA_INST - MFMA_INST_LDS_WRITE, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, MFMA__LDS_READ_Rate, 0); // DS Read
});
}
        template <>
        CK_TILE_DEVICE static constexpr void GemmStagedScheduler<4>()
        {
            // Mem: SGrad, OGrad, D LDS load.
            // Comp: SGrad x KT
            constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P2 + OGrad_LDS_READ + D_LDS_READ;
            constexpr index_t MFMA_INST     = Gemm4MFMA;
// To hide instruction issue latency
constexpr index_t MFMA__LDS_READ_Rate =
LDS_READ_INST / MFMA_INST > 0 ? LDS_READ_INST / MFMA_INST : 1;
static_for<0, MFMA_INST, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, MFMA__LDS_READ_Rate, 0); // DS Read
});
}
private:
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = Problem::BlockFmhaShape::kM0;
static constexpr index_t kN0 = Problem::BlockFmhaShape::kN0;
static constexpr index_t kQKHeaddim = Problem::BlockFmhaShape::kQKHeaddim;
static constexpr index_t kVHeaddim = Problem::BlockFmhaShape::kVHeaddim;
static constexpr index_t kK4 = Problem::BlockFmhaShape::kK4;
static constexpr index_t WarpGemmM =
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
static constexpr index_t WarpGemmN =
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{});
static constexpr index_t WarpGemmK = WarpGemmM == 16 ? 16 : 8;
static constexpr index_t Gemm4MWarp =
Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{});
static constexpr index_t Gemm4NWarp =
Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{});
// Compute
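        // per-warp MFMA instruction count for each gemm:
        // block tile volume / (warps per block * warp-gemm tile volume)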
static constexpr index_t Gemm0MFMA =
kM0 * kN0 * kQKHeaddim /
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
static constexpr index_t Gemm1MFMA =
kM0 * kN0 * kVHeaddim /
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
static constexpr index_t Gemm2MFMA =
kN0 * kVHeaddim * kM0 /
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
static constexpr index_t Gemm3MFMA =
kN0 * kQKHeaddim * kM0 /
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
static constexpr index_t Gemm4MFMA =
kM0 * kQKHeaddim * kN0 /
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
// VMEM
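        // per-thread global (buffer_load) instruction counts: elements owned per thread
        // divided by the load vector width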
static constexpr index_t Q_VMEM_READ =
kM0 * kQKHeaddim / kBlockSize / GetAlignmentQ<Problem>();
static constexpr index_t OGrad_VMEM_READ =
kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>();
static constexpr index_t LSE_VMEM_READ = 1;
static constexpr index_t D_VMEM_READ = 1;
// LDS Read
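        // LDS read instruction counts assumed by the staged schedulers above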
static constexpr index_t OGradT_LDS_READ =
kM0 * kVHeaddim / get_warp_size() / GetTransposedAlignmentOGrad<Problem>();
static constexpr index_t QT_LDS_READ =
kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ<Problem>();
static constexpr index_t SGradT_LDS_READ_P1 =
kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
static constexpr index_t Q_LDS_READ =
kM0 * kQKHeaddim / kBlockSize / GetAlignmentQ<Problem>();
static constexpr index_t LSE_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
static constexpr index_t SGradT_LDS_READ_P2 =
kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
static constexpr index_t OGrad_LDS_READ =
kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>();
static constexpr index_t D_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
// LDS Write
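        // LDS write instruction counts assumed by the staged schedulers above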
static constexpr index_t Q_LDS_WRITE =
kM0 * kQKHeaddim / Problem::kBlockSize / GetAlignmentQ<Problem>();
static constexpr index_t QT_LDS_WRITE =
kM0 * kQKHeaddim / kBlockSize / GetTransposedAlignmentQ<Problem>();
static constexpr index_t OGrad_LDS_WRITE =
kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>();
static constexpr index_t OGradT_LDS_WRITE =
kM0 * kVHeaddim / kBlockSize / GetTransposedAlignmentOGrad<Problem>();
static constexpr index_t LSE_LDS_WRITE = 1;
static constexpr index_t D_LDS_WRITE = 1;
static constexpr index_t SGradT_LDS_WRITE = kM0 * kN0 / kBlockSize;
};
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTileDistribution()
    {
        constexpr index_t kBlockSize = Problem::kBlockSize;
        constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
@@ -1140,204 +2074,6 @@ struct BlockFmhaBwdPipelineDefaultPolicy
        using c_block_tensor_type = decltype(BlockGemm{}.MakeCBlockTile());
        return c_block_tensor_type::get_tile_distribution();
    }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>>;
constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
std::is_same_v<typename Problem::KDataType, half_t> &&
std::is_same_v<typename Problem::AccDataType, float>)
{
return WarpGemmMfmaF16F16F32M32N32K16SwizzleA{};
}
else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
std::is_same_v<typename Problem::KDataType, bf16_t> &&
std::is_same_v<typename Problem::AccDataType, float>)
{
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA{};
}
}();
using BlockGemmPolicy =
BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
decltype(warp_gemm)>;
return BlockGemmASmemBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm()
{
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::GemmDataType,
typename Problem::OGradDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kVHeaddim,
Problem::BlockFmhaShape::kK1>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
typename Problem::OGradDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}),
true>;
using BlockGemmPolicy =
BlockGemmARegBSmemCRegV1CustomPolicy<typename Problem::GemmDataType,
typename Problem::OGradDataType,
typename Problem::AccDataType,
typename Problem::BlockFmhaShape::Gemm1BlockWarps,
WarpGemm>;
return BlockGemmARegBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm()
{
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::OGradDataType,
typename Problem::VDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK2>>;
constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::OGradDataType, half_t> &&
std::is_same_v<typename Problem::VDataType, half_t> &&
std::is_same_v<typename Problem::AccDataType, float>)
{
return WarpGemmMfmaF16F16F32M32N32K16SwizzleA{};
}
else if constexpr(std::is_same_v<typename Problem::OGradDataType, bf16_t> &&
std::is_same_v<typename Problem::VDataType, bf16_t> &&
std::is_same_v<typename Problem::AccDataType, float>)
{
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA{};
}
}();
using BlockGemmPolicy =
BlockGemmASmemBRegCRegV1CustomPolicy<typename Problem::OGradDataType,
typename Problem::VDataType,
typename Problem::AccDataType,
typename Problem::BlockFmhaShape::Gemm2BlockWarps,
decltype(warp_gemm)>;
return BlockGemmASmemBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
}
// template <typename Problem>
// CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm()
// {
// using BlockGemmProblem =
// BlockGemmPipelineProblem<typename Problem::OGradDataType,
// typename Problem::VDataType,
// typename Problem::AccDataType,
// Problem::kBlockSize,
// TileGemmShape<Problem::BlockFmhaShape::kM0,
// Problem::BlockFmhaShape::kN0,
// Problem::BlockFmhaShape::kK2>>;
// constexpr auto warp_gemm = []() {
// if constexpr(std::is_same_v<typename Problem::OGradDataType, half_t> &&
// std::is_same_v<typename Problem::VDataType, half_t> &&
// std::is_same_v<typename Problem::AccDataType, float>)
// {
// return WarpGemmMfmaF16F16F32M32N32K16SwizzleA{};
// }
// else if constexpr(std::is_same_v<typename Problem::OGradDataType, bf16_t> &&
// std::is_same_v<typename Problem::VDataType, bf16_t> &&
// std::is_same_v<typename Problem::AccDataType, float>)
// {
// return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA{};
// }
// }();
// using BlockGemmPolicy =
// BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::OGradDataType,
// typename Problem::VDataType,
// typename Problem::AccDataType,
// typename
// Problem::BlockFmhaShape::Gemm2BlockWarps,
// decltype(warp_gemm)>;
// return BlockGemmASmemBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
// }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm()
{
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::GemmDataType,
typename Problem::QDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK3>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
typename Problem::QDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}),
true>;
using BlockGemmPolicy =
BlockGemmARegBSmemCRegV1CustomPolicy<typename Problem::GemmDataType,
typename Problem::QDataType,
typename Problem::AccDataType,
typename Problem::BlockFmhaShape::Gemm3BlockWarps,
WarpGemm>;
return BlockGemmARegBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm()
{
using BlockGemmProblem =
BlockGemmPipelineProblem<typename Problem::GemmDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK4>>;
using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}),
true>;
using BlockGemmPolicy =
BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::GemmDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
typename Problem::BlockFmhaShape::Gemm4BlockWarps,
WarpGemm>;
return BlockGemmASmemBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
}
};
} // namespace ck_tile
@@ -8,9 +8,7 @@ namespace ck_tile {
// This class is used for codegen pattern matching
enum class BlockFmhaBwdPipelineEnum
{
    KSKTSVR = 0,
    KRKTRVR = 0,
QSKSVROGradS,
KSVR,
};
} // namespace ck_tile
@@ -24,7 +24,9 @@ template <typename QDataType_,
          typename BiasGradDataType_,
          typename BlockFmhaShape_,
          bool kIsGroupMode_,
          bool kIsDeterministic_,
          typename FmhaMask_,
          typename FmhaDropout_,
          typename Traits_>
struct BlockFmhaBwdPipelineProblem
{
@@ -45,10 +47,12 @@ struct BlockFmhaBwdPipelineProblem
    using BiasGradDataType = remove_cvref_t<BiasGradDataType_>;
    using BlockFmhaShape   = remove_cvref_t<BlockFmhaShape_>;
    using FmhaMask         = remove_cvref_t<FmhaMask_>;
    using FmhaDropout      = remove_cvref_t<FmhaDropout_>;
    using Traits           = remove_cvref_t<Traits_>;

    static constexpr index_t kBlockSize    = BlockFmhaShape::NumWarps * get_warp_size();
    static constexpr bool kIsGroupMode     = kIsGroupMode_;
    static constexpr bool kIsDeterministic = kIsDeterministic_;

    // attributes from traits
    static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
@@ -57,7 +61,6 @@ struct BlockFmhaBwdPipelineProblem
    static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
    static constexpr auto BiasEnum     = Traits::BiasEnum;
    static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad;
    static constexpr bool kHasDropout  = Traits::kHasDropout;
    static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
@@ -88,4 +91,30 @@ struct BlockFmhaBwdOGradDotOPipelineProblem
    static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
template <typename AccDataType_,
typename QGradDataType_,
typename Shape_,
typename Traits_,
bool kIsGroupMode_,
bool kIsDeterministic_>
struct BlockFmhaBwdConvertQGradPipelineProblem
{
using AccDataType = remove_cvref_t<AccDataType_>;
using QGradDataType = remove_cvref_t<QGradDataType_>;
using Shape = remove_cvref_t<Shape_>;
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kBlockSize = Shape::NumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
static_assert(0 < kBlockSize && kBlockSize % get_warp_size() == 0,
"kBlockSize should be divisible by get_warp_size()");
// attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
static constexpr bool kPadHeadDimQ = Traits::kPadHeadDimQ;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
} // namespace ck_tile
@@ -28,6 +28,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
    using PDataType    = remove_cvref_t<typename Problem::PDataType>;
    using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
    using FmhaMask     = remove_cvref_t<typename Problem::FmhaMask>;
    using FmhaDropout  = remove_cvref_t<typename Problem::FmhaDropout>;

    using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
    using VLayout        = remove_cvref_t<typename BlockFmhaShape::VLayout>;
@@ -49,8 +50,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
    static constexpr bool kPadHeadDimQ = Problem::kPadHeadDimQ;
    static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
    static constexpr auto BiasEnum     = Problem::BiasEnum;
    static constexpr bool kStoreLSE    = true;  // always store LSE (acc)
    static constexpr bool kHasDropout  = false; // ignore this flag
    static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;

    // last dimension vector length used to create tensor view (and decide buffer_load vector length)
@@ -141,7 +141,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
               PositionEncoding position_encoding,
               float scale_s,
               void* smem_ptr,
               FmhaDropout dropout) const
    {
        static_assert(
            std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
@@ -249,7 +249,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
            {bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
            Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());

        auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
            randval_dram_block_window_tmp, seqlen_k_start);

        auto v_dram_window =
@@ -501,10 +501,14 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
            });
        });

        if constexpr(FmhaDropout::IsDropout)
        {
            dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
                smem_ptr,
q_origin.at(number<0>{}),
seqlen_k_start + i_total_loops * kN0,
p_compute,
randval_dram_window);
        }

        block_sync_lds();
@@ -637,7 +641,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVS
PositionEncoding position_encoding, PositionEncoding position_encoding,
float scale_s, float scale_s,
void* smem_ptr, void* smem_ptr,
BlockDropout& dropout) const FmhaDropout dropout) const
{ {
return operator()(q_dram_block_window_tmp, return operator()(q_dram_block_window_tmp,
identity{}, identity{},
......
...@@ -29,6 +29,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
using PDataType = remove_cvref_t<typename Problem::PDataType>;
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using FmhaDropout = remove_cvref_t<typename Problem::FmhaDropout>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
...@@ -54,8 +55,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
static constexpr bool kPadHeadDimQ = true; // support multiple of vector(like 8x)
static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x)
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = true; // always store LSE (acc)
static constexpr bool kHasDropout = false; // ignore this flag
static constexpr bool kHasUnevenSplits = Problem::kHasUnevenSplits;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
...@@ -153,7 +153,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
BlockDropout& dropout) const
FmhaDropout dropout) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
...@@ -301,7 +301,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
{bias_origin.at(number<0>{}), seqlen_k_start}, // M/N
Policy::template MakeBiasDramTileDistribution<Problem, decltype(gemm_0)>());
auto randval_dram_window = dropout.MakeRandvalDramWindow<decltype(gemm_0)>(
auto randval_dram_window = dropout.template MakeRandvalDramWindow<decltype(gemm_0)>(
randval_dram_block_window_tmp, seqlen_k_start);
auto v_dram_window =
...@@ -584,12 +584,13 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
});
});
if constexpr(kHasDropout)
if constexpr(FmhaDropout::IsDropout)
{
auto randval_ptr =
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
dropout.Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
randval_ptr,
q_origin.at(number<0>{}),
seqlen_k_start + i_total_loops * kN0,
p_compute,
randval_dram_window);
...@@ -741,7 +742,7 @@ struct BlockFmhaFwdSplitKVPipelineQRKSVSAsync
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
BlockDropout& dropout) const
FmhaDropout dropout) const
{
return operator()(q_dram_block_window_tmp,
identity{},
......
...@@ -21,6 +21,7 @@ template <typename QDataType_,
typename BlockFmhaShape_,
bool kIsGroupMode_,
typename FmhaMask_,
typename FmhaDropout_,
typename Traits_>
struct BlockFmhaPipelineProblem
{
...@@ -37,6 +38,7 @@ struct BlockFmhaPipelineProblem
using ODataType = remove_cvref_t<ODataType_>;
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
using FmhaMask = remove_cvref_t<FmhaMask_>;
using FmhaDropout = remove_cvref_t<FmhaDropout_>;
using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
...@@ -49,7 +51,6 @@ struct BlockFmhaPipelineProblem
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kStoreLSE = Traits::kStoreLSE;
static constexpr bool kHasDropout = Traits::kHasDropout;
static constexpr bool kDoFp8StaticQuant = Traits::kDoFp8StaticQuant;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
};
...@@ -68,6 +69,7 @@ template <typename QDataType,
typename BlockFmhaShape,
bool kIsGroupMode,
typename FmhaMask,
typename FmhaDropout,
typename Traits>
struct BlockFmhaFwdSplitKVPipelineProblem : BlockFmhaPipelineProblem<QDataType,
KDataType,
...@@ -83,6 +85,7 @@ struct BlockFmhaFwdSplitKVPipelineProblem : BlockFmhaPipelineProblem<QDataType,
BlockFmhaShape,
kIsGroupMode,
FmhaMask,
FmhaDropout,
Traits>
{
static constexpr bool kHasUnevenSplits = kIsGroupMode || Traits::kHasUnevenSplits;
......
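Note: the hunks above swap the boolean kHasDropout trait for an FmhaDropout policy type carried inside the pipeline problem, so each pipeline now branches on FmhaDropout::IsDropout at compile time instead of a separate flag. The self-contained toy below illustrates only that flag-to-policy-type pattern; apart from the IsDropout member name, every identifier is a placeholder, not a ck_tile definition.

// Toy sketch of the pattern adopted here (names are illustrative, not ck_tile code).
#include <cstdio>

struct NullDropoutPolicy
{
    static constexpr bool IsDropout = false;
};

struct RandValDropoutPolicy
{
    static constexpr bool IsDropout = true;
    void run() const { std::printf("apply dropout to this tile\n"); }
};

template <typename Dropout>
void pipeline_step(Dropout dropout)
{
    // The dropout branch is discarded at compile time when the policy says "no dropout",
    // which is what `if constexpr(FmhaDropout::IsDropout)` achieves in the pipelines above.
    if constexpr(Dropout::IsDropout)
        dropout.run();
}

int main()
{
    pipeline_step(NullDropoutPolicy{});    // dropout path compiled out
    pipeline_step(RandValDropoutPolicy{}); // dropout path executed
    return 0;
}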
...@@ -29,6 +29,7 @@ struct BlockFmhaPipelineQRKSVS
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using FmhaDropout = remove_cvref_t<typename Problem::FmhaDropout>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
...@@ -51,7 +52,6 @@ struct BlockFmhaPipelineQRKSVS
static constexpr bool kPadHeadDimV = Problem::kPadHeadDimV;
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
...@@ -100,8 +100,6 @@ struct BlockFmhaPipelineQRKSVS
static constexpr const char* name = "qr";
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
...@@ -141,7 +139,7 @@ struct BlockFmhaPipelineQRKSVS
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
DropoutType& dropout) const
FmhaDropout dropout) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
...@@ -486,10 +484,14 @@ struct BlockFmhaPipelineQRKSVS
});
});
if constexpr(kHasDropout)
if constexpr(FmhaDropout::IsDropout)
{
dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
smem_ptr, seqlen_k_start + i_total_loops * kN0, p_compute, randval_dram_window);
smem_ptr,
q_origin.at(number<0>{}),
seqlen_k_start + i_total_loops * kN0,
p_compute,
randval_dram_window);
}
block_sync_lds();
...@@ -620,7 +622,7 @@ struct BlockFmhaPipelineQRKSVS
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
DropoutType& dropout) const
FmhaDropout dropout) const
{
return operator()(q_dram_block_window_tmp,
identity{},
......
...@@ -30,6 +30,7 @@ struct BlockFmhaPipelineQRKSVSAsync
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using FmhaDropout = remove_cvref_t<typename Problem::FmhaDropout>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
...@@ -56,7 +57,6 @@ struct BlockFmhaPipelineQRKSVSAsync
static constexpr bool kPadHeadDimV = true; // support multiple of vector(like 8x)
static constexpr auto BiasEnum = Problem::BiasEnum;
static constexpr bool kStoreLSE = Problem::kStoreLSE;
static constexpr bool kHasDropout = Problem::kHasDropout;
// last dimension vector length used to create tensor view(and decide buffer_load vector length)
// ... together with tensor distribution. tensor dist should able to overwrite this
...@@ -112,8 +112,6 @@ struct BlockFmhaPipelineQRKSVSAsync
static constexpr const char* name = "qr_async";
using DropoutType = std::conditional_t<kHasDropout, BlockDropout, NullBlockDropout>;
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
...@@ -153,7 +151,7 @@ struct BlockFmhaPipelineQRKSVSAsync
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
DropoutType& dropout) const
FmhaDropout dropout) const
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
...@@ -569,12 +567,13 @@ struct BlockFmhaPipelineQRKSVSAsync
});
});
if constexpr(kHasDropout)
if constexpr(FmhaDropout::IsDropout)
{
auto randval_ptr =
reinterpret_cast<char*>(smem_ptr) + Policy::template GetSmemSizeKV<Problem>();
dropout.template Run<decltype(gemm_0), SMPLComputeDataType, RandValOutputDataType>(
randval_ptr,
q_origin.at(number<0>{}),
seqlen_k_start + i_total_loops * kN0,
p_compute,
randval_dram_window);
...@@ -730,7 +729,7 @@ struct BlockFmhaPipelineQRKSVSAsync
PositionEncoding position_encoding,
float scale_s,
void* smem_ptr,
DropoutType& dropout) const
FmhaDropout dropout) const
{
return operator()(q_dram_block_window_tmp,
identity{},
......
...@@ -28,6 +28,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
using OaccDataType = remove_cvref_t<typename Problem::OaccDataType>;
using ODataType = remove_cvref_t<typename Problem::ODataType>;
using FmhaMask = remove_cvref_t<typename Problem::FmhaMask>;
using FmhaDropout = remove_cvref_t<typename Problem::FmhaDropout>;
using BlockFmhaShape = remove_cvref_t<typename Problem::BlockFmhaShape>;
using VLayout = remove_cvref_t<typename BlockFmhaShape::VLayout>;
...@@ -124,7 +125,7 @@ struct [[deprecated]] BlockFmhaPipelineQRKSVSFp8
float descale_qk,
float descale_sv,
void* smem_ptr,
BlockDropout& /*dropout*/) const // not supported
FmhaDropout& /*dropout*/) const // not supported
{
static_assert(
std::is_same_v<QDataType, remove_cvref_t<typename QDramBlockWindowTmp::DataType>> &&
......
...@@ -92,4 +92,20 @@ struct TileFmhaBwdShape
// that need load V at once
};
template <typename BlockTile_, // sequence<...
typename BlockWarps_,
typename WarpTile_>
struct TileFmhaBwdConvertQGradShape
{
using BlockTile = remove_cvref_t<BlockTile_>;
using BlockWarps = remove_cvref_t<BlockWarps_>;
using WarpTile = remove_cvref_t<WarpTile_>;
static constexpr index_t NumWarps = reduce_on_sequence(BlockWarps{}, multiplies{}, number<1>{});
static constexpr index_t kM0 = BlockTile::at(number<0>{}); // tile size along q seqlen
static constexpr index_t kN0 = BlockTile::at(number<1>{}); // tile size along k seqlen
static constexpr index_t kQKHeaddim = BlockTile::at(number<2>{}); // Q & K headdim
};
} // namespace ck_tile
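Note: the new TileFmhaBwdConvertQGradShape simply repackages three sequences into named tile constants. A minimal instantiation is sketched below, assuming the modified bwd shape header (and ck_tile/core.hpp) is included; the tile and warp sizes are illustrative assumptions, not values taken from the kernel configs.

// Illustrative only: sizes are assumptions, not from this commit's kernel configs.
using ConvertQGradShape =
    ck_tile::TileFmhaBwdConvertQGradShape<ck_tile::sequence<64, 128, 64>, // kM0, kN0, kQKHeaddim
                                          ck_tile::sequence<4, 1, 1>,     // warps per block
                                          ck_tile::sequence<32, 32, 16>>; // per-warp tile

static_assert(ConvertQGradShape::NumWarps == 4, "4 * 1 * 1 warps");
static_assert(ConvertQGradShape::kM0 == 64 && ConvertQGradShape::kQKHeaddim == 64, "tile constants");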
...@@ -15,7 +15,6 @@ template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
BlockAttentionBiasEnum BiasEnum_,
bool kHasBiasGrad_,
bool kStoreLSE_,
bool kHasDropout_,
bool kDoFp8StaticQuant_,
index_t kBlockPerCu_ = -1 /* overwrite occupancy if not -1 */>
struct TileFmhaTraits
...@@ -27,7 +26,6 @@ struct TileFmhaTraits
static constexpr auto BiasEnum = BiasEnum_;
static constexpr bool kHasBiasGrad = kHasBiasGrad_;
static constexpr bool kStoreLSE = kStoreLSE_;
static constexpr bool kHasDropout = kHasDropout_;
static constexpr bool kDoFp8StaticQuant = kDoFp8StaticQuant_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};
...@@ -39,7 +37,6 @@ template <bool kPadSeqLenQ /* padding for seqlen_q */,
BlockAttentionBiasEnum BiasEnum,
bool kHasBiasGrad,
bool kStoreLSE,
bool kHasDropout,
bool kDoFp8StaticQuant,
bool kHasUnevenSplits_ = true,
index_t kBlockPerCu = -1 /* overwrite occupancy if not -1 */>
...@@ -50,7 +47,6 @@ struct TileFmhaFwdSplitKVTraits : TileFmhaTraits<kPadSeqLenQ,
BiasEnum,
kHasBiasGrad,
kStoreLSE,
kHasDropout,
kDoFp8StaticQuant,
kBlockPerCu>
{
...@@ -86,4 +82,14 @@ struct TileFmhaBwdOGradDotOTraits
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};
template <bool kPadSeqLenQ_ /* padding for seqlen_q */,
bool kPadHeadDimQ_ /* paddding for hdim_q */,
index_t kBlockPerCu_ = 2 /* hint to occupancy */>
struct TileFmhaBwdConvertQGradTraits
{
static constexpr bool kPadSeqLenQ = kPadSeqLenQ_;
static constexpr bool kPadHeadDimQ = kPadHeadDimQ_;
static constexpr index_t kBlockPerCu = kBlockPerCu_;
};
} // namespace ck_tile
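Note: TileFmhaBwdConvertQGradTraits carries only two padding flags plus an occupancy hint that defaults to 2. A hedged instantiation, assuming the traits header above is included; the flag values are chosen for illustration.

// Illustrative only: pad both seqlen_q and hdim_q, keep the default occupancy hint.
using ConvertQGradTraits =
    ck_tile::TileFmhaBwdConvertQGradTraits</* kPadSeqLenQ */ true,
                                           /* kPadHeadDimQ */ true>;

static_assert(ConvertQGradTraits::kBlockPerCu == 2, "default occupancy hint");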
...@@ -5,6 +5,9 @@
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
namespace ck_tile {
// A is block distributed tensor
// B is block distributed tensor
// C is block distributed tensor
template <typename Problem_, typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy>
struct BlockGemmARegBRegCRegV1
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
// C += A * B
template <typename CBlockTensor, typename ABlockTensor, typename BBlockTensor>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ABlockTensor& a_block_tensor,
const BBlockTensor& b_block_tensor) const
{
static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BBlockTensor::DataType>> &&
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr index_t KPerBlock = BlockGemmShape::kK;
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
// M->N Warp
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto b_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
// check ABC-block-distribution
static_assert(
std::is_same_v<remove_cvref_t<decltype(a_block_dstr_encode)>,
remove_cvref_t<decltype(ABlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"A distribution is wrong!");
static_assert(
std::is_same_v<remove_cvref_t<decltype(b_block_dstr_encode)>,
remove_cvref_t<decltype(BBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"B distribution is wrong!");
static_assert(
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"C distribution is wrong!");
using AWarpDstr = typename WG::AWarpDstr;
using BWarpDstr = typename WG::BWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor;
using BWarpTensor = typename WG::BWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto b_warp_y_lengths =
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A Block window
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
CK_TILE_DEVICE constexpr auto MakeCBlockTile() const
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
// constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
// C = A * B
template <typename ABlockTensor, typename BBlockTensor>
CK_TILE_DEVICE auto operator()(const ABlockTensor& a_block_tensor,
const BBlockTensor& b_block_tensor) const
{
auto c_block_tensor = MakeCBlockTile();
operator()(c_block_tensor, a_block_tensor, b_block_tensor);
return c_block_tensor;
}
};
} // namespace ck_tile
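Note: for orientation, a hedged sketch of how this all-register block GEMM is typically driven from a pipeline: build the C accumulator once with MakeCBlockTile(), clear it, then call the three-operand overload inside the K loop. The problem type, tile arguments, and the use of clear_tile below are assumptions for illustration, not code from this commit; the A/B tiles must already carry the distributions asserted above.

// Minimal sketch, assuming MyGemmProblem supplies ADataType/BDataType/CDataType,
// BlockGemmShape and kBlockSize, and that a_tile/b_tile come from the surrounding
// pipeline with matching tile distributions (all names here are placeholders).
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"

template <typename MyGemmProblem, typename ATile, typename BTile>
CK_TILE_DEVICE auto run_block_gemm(const ATile& a_tile, const BTile& b_tile)
{
    ck_tile::BlockGemmARegBRegCRegV1<MyGemmProblem> gemm{}; // default policy

    auto c_tile = gemm.MakeCBlockTile(); // allocate the distributed C accumulator
    ck_tile::clear_tile(c_tile);         // assumption: zero it before accumulating

    gemm(c_tile, a_tile, b_tile);        // C += A * B, all operands held in registers
    return c_tile;
}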
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
template <typename AType_,
typename BType_,
typename CType_,
typename BlockWarps_,
typename WarpGemm_>
struct BlockGemmARegBRegCRegV1CustomPolicy
{
using AType = remove_cvref_t<AType_>;
using BType = remove_cvref_t<BType_>;
using CType = remove_cvref_t<CType_>;
using BlockWarps = remove_cvref_t<BlockWarps_>;
static constexpr index_t kMWarps = BlockWarps::at(number<0>{});
static constexpr index_t kNWarps = BlockWarps::at(number<1>{});
static constexpr index_t kKWarps = BlockWarps::at(number<2>{});
using WarpGemm = remove_cvref_t<WarpGemm_>;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
{
return make_tuple(WarpGemm{}, kMWarps, kNWarps);
}
};
} // namespace ck_tile
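Note: where the default policy's type-based dispatch is not wanted, this custom policy pins the warp GEMM and the warp layout explicitly and is passed as the second template argument of BlockGemmARegBRegCRegV1. A hedged example follows; the 2x2x1 warp layout and the chosen warp GEMM are arbitrary, and MyGemmProblem is a placeholder.

// Illustrative only: warp layout and warp GEMM picked for demonstration.
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"

using MyCustomPolicy = ck_tile::BlockGemmARegBRegCRegV1CustomPolicy<
    ck_tile::half_t,            // AType
    ck_tile::half_t,            // BType
    float,                      // CType
    ck_tile::sequence<2, 2, 1>, // MWarps, NWarps, KWarps
    ck_tile::WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution>;

// Plugged in as the second template argument (MyGemmProblem is a placeholder):
// using BlockGemm = ck_tile::BlockGemmARegBRegCRegV1<MyGemmProblem, MyCustomPolicy>;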
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
namespace ck_tile {
// Default policy for BlockGemmARegBRegCRegV1
// Default policy class should not be templated, put template on member functions instead
struct BlockGemmARegBRegCRegV1DefaultPolicy
{
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
{
if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
std::is_same_v<typename Problem::BDataType, half_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
}
else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
std::is_same_v<typename Problem::BDataType, bf16_t> &&
std::is_same_v<typename Problem::CDataType, float>)
{
return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 4, 1);
}
}
};
} // namespace ck_tile