Unverified Commit 3d61f89a authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Merge pull request #134 from ROCm/merge_from_public

Merge from public
parents c160c6cf 4558a3f8
...@@ -11,6 +11,8 @@ ...@@ -11,6 +11,8 @@
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
...@@ -18,60 +20,215 @@ ...@@ -18,60 +20,215 @@
namespace ck_tile { namespace ck_tile {
template <bool QLoadOnce_,
bool QTLoadOnce_,
bool KLoadOnce_,
bool KTLoadOnce_,
bool VLoadOnce_,
bool OGradLoadOnce_,
bool OGradTLoadOnce_>
struct BlockFmhaBwdPipelineDefaultPolicy struct BlockFmhaBwdPipelineDefaultPolicy
{ {
static constexpr bool QLoadOnce = template <typename Problem>
QLoadOnce_; // if q load whole block length (qkhdim) to LDS at once CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
static constexpr bool QTLoadOnce = {
QTLoadOnce_; // if q^t load whole block length (qkhdim) to LDS at once using BlockGemmProblem =
static constexpr bool KLoadOnce = BlockGemmPipelineProblem<typename Problem::QDataType,
KLoadOnce_; // if k load whole block length (qkhdim) to LDS at once typename Problem::KDataType,
static constexpr bool KTLoadOnce = typename Problem::AccDataType,
KTLoadOnce_; // if k^t load whole block length (qkhdim) to LDS at once Problem::kBlockSize,
static constexpr bool VLoadOnce = TileGemmShape<Problem::BlockFmhaShape::kM0,
VLoadOnce_; // if v load whole block length (vhdim) to Vgprs at once Problem::BlockFmhaShape::kN0,
static constexpr bool OGradLoadOnce = Problem::BlockFmhaShape::kK0>>;
OGradLoadOnce_; // if do load whole block length (vhdim) to LDS at once
static constexpr bool OGradTLoadOnce = using WarpGemm = WarpGemmMfmaDispatcher<
OGradTLoadOnce_; // if do^t load whole block length (vhdim) to LDS at once typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<2>{}),
false,
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true>;
using BlockGemmPolicy =
BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::QDataType,
typename Problem::KDataType,
typename Problem::AccDataType,
typename Problem::BlockFmhaShape::Gemm0BlockWarps,
WarpGemm>;
return BlockGemmARegBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
}
// Assembles the block-level GEMM object for the operand pair
// (A = GemmDataType, B = OGradDataType) accumulating into AccDataType.
// Block tile: kN0 x kVHeaddim x kK1. Warp tile sizes come from Gemm1WarpTile and
// the warp arrangement from Gemm1BlockWarps.
// Returns a default-constructed BlockGemmARegBRegCRegV1 specialised on those types.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm()
{
    using Shape    = typename Problem::BlockFmhaShape;
    using AType    = typename Problem::GemmDataType;
    using BType    = typename Problem::OGradDataType;
    using CType    = typename Problem::AccDataType;
    using WarpTile = typename Shape::Gemm1WarpTile;

    // NOTE(review): the trailing `true` is presumably the dispatcher's
    // transposed-C switch — confirm against WarpGemmMfmaDispatcher.
    using WarpGemmImpl = WarpGemmMfmaDispatcher<AType,
                                                BType,
                                                CType,
                                                WarpTile::at(number<0>{}),
                                                WarpTile::at(number<1>{}),
                                                WarpTile::at(number<2>{}),
                                                true>;

    using GemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<AType,
                                                           BType,
                                                           CType,
                                                           typename Shape::Gemm1BlockWarps,
                                                           WarpGemmImpl>;

    using GemmProblem =
        BlockGemmPipelineProblem<AType,
                                 BType,
                                 CType,
                                 Problem::kBlockSize,
                                 TileGemmShape<Shape::kN0, Shape::kVHeaddim, Shape::kK1>>;

    return BlockGemmARegBRegCRegV1<GemmProblem, GemmPolicy>{};
}
// Assembles the block-level GEMM object for the operand pair
// (A = OGradDataType, B = VDataType) accumulating into AccDataType.
// Block tile: kM0 x kN0 x kK2. Warp tile sizes come from Gemm2WarpTile and
// the warp arrangement from Gemm2BlockWarps.
// Returns a default-constructed BlockGemmARegBRegCRegV1 specialised on those types.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm()
{
    using Shape    = typename Problem::BlockFmhaShape;
    using AType    = typename Problem::OGradDataType;
    using BType    = typename Problem::VDataType;
    using CType    = typename Problem::AccDataType;
    using WarpTile = typename Shape::Gemm2WarpTile;

    // NOTE(review): the last dispatcher argument is derived from Gemm0WarpTile's
    // first extent (false when it equals 16, true otherwise), not from Gemm2WarpTile.
    // This mirrors the original code exactly; confirm the Gemm0 dependency is
    // intentional rather than a copy/paste leftover.
    using WarpGemmImpl =
        WarpGemmMfmaDispatcher<AType,
                               BType,
                               CType,
                               WarpTile::at(number<0>{}),
                               WarpTile::at(number<1>{}),
                               WarpTile::at(number<2>{}),
                               false,
                               Shape::Gemm0WarpTile::at(number<0>{}) == 16 ? false : true>;

    using GemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<AType,
                                                           BType,
                                                           CType,
                                                           typename Shape::Gemm2BlockWarps,
                                                           WarpGemmImpl>;

    using GemmProblem =
        BlockGemmPipelineProblem<AType,
                                 BType,
                                 CType,
                                 Problem::kBlockSize,
                                 TileGemmShape<Shape::kM0, Shape::kN0, Shape::kK2>>;

    return BlockGemmARegBRegCRegV1<GemmProblem, GemmPolicy>{};
}
// Assembles the block-level GEMM object for the operand pair
// (A = GemmDataType, B = QDataType) accumulating into AccDataType.
// Block tile: kN0 x kQKHeaddim x kK3. Warp tile sizes come from Gemm3WarpTile and
// the warp arrangement from Gemm3BlockWarps.
// Returns a default-constructed BlockGemmARegBRegCRegV1 specialised on those types.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm()
{
    using Shape    = typename Problem::BlockFmhaShape;
    using AType    = typename Problem::GemmDataType;
    using BType    = typename Problem::QDataType;
    using CType    = typename Problem::AccDataType;
    using WarpTile = typename Shape::Gemm3WarpTile;

    // NOTE(review): the trailing `true` is presumably the dispatcher's
    // transposed-C switch — confirm against WarpGemmMfmaDispatcher.
    using WarpGemmImpl = WarpGemmMfmaDispatcher<AType,
                                                BType,
                                                CType,
                                                WarpTile::at(number<0>{}),
                                                WarpTile::at(number<1>{}),
                                                WarpTile::at(number<2>{}),
                                                true>;

    using GemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<AType,
                                                           BType,
                                                           CType,
                                                           typename Shape::Gemm3BlockWarps,
                                                           WarpGemmImpl>;

    using GemmProblem =
        BlockGemmPipelineProblem<AType,
                                 BType,
                                 CType,
                                 Problem::kBlockSize,
                                 TileGemmShape<Shape::kN0, Shape::kQKHeaddim, Shape::kK3>>;

    return BlockGemmARegBRegCRegV1<GemmProblem, GemmPolicy>{};
}
// Assembles the block-level GEMM object for the operand pair
// (A = GemmDataType, B = KDataType) accumulating into AccDataType.
// Block tile: kM0 x kQKHeaddim x kK4. Warp tile sizes come from Gemm4WarpTile and
// the warp arrangement from Gemm4BlockWarps.
// Returns a default-constructed BlockGemmARegBRegCRegV1 specialised on those types.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm()
{
    using Shape    = typename Problem::BlockFmhaShape;
    using AType    = typename Problem::GemmDataType;
    using BType    = typename Problem::KDataType;
    using CType    = typename Problem::AccDataType;
    using WarpTile = typename Shape::Gemm4WarpTile;

    // NOTE(review): the trailing `false` is presumably the dispatcher's
    // transposed-C switch left off — confirm against WarpGemmMfmaDispatcher.
    using WarpGemmImpl = WarpGemmMfmaDispatcher<AType,
                                                BType,
                                                CType,
                                                WarpTile::at(number<0>{}),
                                                WarpTile::at(number<1>{}),
                                                WarpTile::at(number<2>{}),
                                                false>;

    using GemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<AType,
                                                           BType,
                                                           CType,
                                                           typename Shape::Gemm4BlockWarps,
                                                           WarpGemmImpl>;

    using GemmProblem =
        BlockGemmPipelineProblem<AType,
                                 BType,
                                 CType,
                                 Problem::kBlockSize,
                                 TileGemmShape<Shape::kM0, Shape::kQKHeaddim, Shape::kK4>>;

    return BlockGemmARegBRegCRegV1<GemmProblem, GemmPolicy>{};
}
// these are for global load // these are for global load
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ() CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQ()
{ {
using QDataType = remove_cvref_t<typename Problem::QDataType>; using QDataType = remove_cvref_t<typename Problem::QDataType>;
return 16 / sizeof(QDataType); constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kMaxVecLoad = 16 / sizeof(QDataType);
constexpr index_t kMinVecLoad = 4 / sizeof(QDataType);
constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
? kMaxVecLoad
: (total_pixels / kMinVecLoad);
return kVecLoad;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK() CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentK()
{ {
using KDataType = remove_cvref_t<typename Problem::KDataType>; using KDataType = remove_cvref_t<typename Problem::KDataType>;
return 16 / sizeof(KDataType); constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kMaxVecLoad = 16 / sizeof(KDataType);
constexpr index_t kMinVecLoad = 4 / sizeof(KDataType);
constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
? kMaxVecLoad
: (total_pixels / kMinVecLoad);
return kVecLoad;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV() CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentV()
{ {
if constexpr(VLoadOnce) using VDataType = remove_cvref_t<typename Problem::VDataType>;
{ constexpr index_t kBlockSize = Problem::kBlockSize;
using BlockGemm = remove_cvref_t<decltype(GetOGradVBlockGemm<Problem>())>; constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>(); constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
using WG = remove_cvref_t<decltype(config.template at<0>())>; constexpr index_t kMaxVecLoad = 16 / sizeof(VDataType);
return WG::kK / WG::WarpGemmAttribute::Impl::kABKLane; constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
}
else return total_pixels > kMaxVecLoad ? kMaxVecLoad : total_pixels;
{
using VDataType = remove_cvref_t<typename Problem::VDataType>;
return 16 / sizeof(VDataType);
}
} }
template <typename Problem> template <typename Problem>
...@@ -84,20 +241,39 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -84,20 +241,39 @@ struct BlockFmhaBwdPipelineDefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOGrad() CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentOGrad()
{ {
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>; using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
return 16 / sizeof(OGradDataType); constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t kMaxVecLoad = 16 / sizeof(OGradDataType);
constexpr index_t kMinVecLoad = 4 / sizeof(OGradDataType);
constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
? kMaxVecLoad
: (total_pixels / kMinVecLoad);
return kVecLoad;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentQGrad() CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentBias()
{ {
using BlockGemm = remove_cvref_t<decltype(GetSGradKTBlockGemm<Problem>())>; using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>(); constexpr index_t kBlockSize = Problem::kBlockSize;
using WG = remove_cvref_t<decltype(config.template at<0>())>; constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
using CWarpDstr = typename WG::CWarpDstr; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr auto vec = constexpr index_t kMaxVecLoad = 16 / sizeof(BiasDataType);
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths().at(number<CWarpDstr::NDimY - 1>{}); constexpr index_t kMinVecLoad = 4 / sizeof(BiasDataType);
return vec;
constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize;
constexpr index_t kVecLoad = ((total_pixels / kMaxVecLoad) >= kMinVecLoad)
? kMaxVecLoad
: (total_pixels / kMinVecLoad);
return kVecLoad;
} }
template <typename Problem> template <typename Problem>
...@@ -128,60 +304,35 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -128,60 +304,35 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentQ() CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentQ()
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = [&]() { constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
if constexpr(QTLoadOnce)
return Problem::BlockFmhaShape::kM0;
else
return Problem::BlockFmhaShape::kK3;
}();
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
// TODO: not correct! return total_pixels / GetAlignmentQ<Problem>();
if constexpr(total_pixels > 4)
return 4;
else
return 2;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentK() CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentK()
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = [&]() { constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
if constexpr(KTLoadOnce)
return Problem::BlockFmhaShape::kN0;
else
return Problem::BlockFmhaShape::kK4;
}();
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
// TODO: not correct! return total_pixels / GetAlignmentK<Problem>();
if constexpr(total_pixels > 4)
return 4;
else
return 2;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentOGrad() CK_TILE_HOST_DEVICE static constexpr auto GetTransposedAlignmentOGrad()
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = [&]() { constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
if constexpr(OGradTLoadOnce)
return Problem::BlockFmhaShape::kM0;
else
return Problem::BlockFmhaShape::kK1;
}();
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
// TODO: not correct! return total_pixels / GetAlignmentOGrad<Problem>();
if constexpr(total_pixels > 4)
return 4;
else
return 2;
} }
template <typename Problem> template <typename Problem>
...@@ -193,1151 +344,1577 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -193,1151 +344,1577 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize; constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize;
// TODO: not correct! return total_pixels / GetAlignmentBias<Problem>();
if constexpr(total_pixels > 32)
return 8;
else
return 4;
} }
// these are for lds
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ() CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentPostQGradAcc()
{ {
// TODO: this is for 3d layout using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
using QDataType = remove_cvref_t<typename Problem::QDataType>; return 16 / sizeof(AccDataType);
return 16 / sizeof(QDataType);
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK() CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentPostQGrad()
{ {
// TODO: this is for 3d layout return GetAlignmentPostQGradAcc<Problem>();
using KDataType = remove_cvref_t<typename Problem::KDataType>;
return 16 / sizeof(KDataType);
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV() CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution()
{ {
// TODO: this is for 3d layout constexpr index_t kBlockSize = Problem::kBlockSize;
using VDataType = remove_cvref_t<typename Problem::VDataType>;
return 16 / sizeof(VDataType); constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t K1 = GetAlignmentK<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N1 = get_warp_size() / K0;
constexpr index_t N0 = kBlockSize / get_warp_size();
constexpr index_t N2 = kNPerBlock / (N1 * N0);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,
sequence<2, 1>>{});
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackBias() CK_TILE_HOST_DEVICE static constexpr auto MakeVDramTileDistribution()
{ {
// TODO: this is for 3d layout constexpr index_t kBlockSize = Problem::kBlockSize;
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>;
return 16 / sizeof(BiasDataType); constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t K1 = GetAlignmentV<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N0 = kNPerBlock / (N2 * N1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackOGrad() CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution()
{ {
// TODO: this is for 3d layout constexpr index_t kBlockSize = Problem::kBlockSize;
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
return 16 / sizeof(OGradDataType); constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t K1 = GetAlignmentQ<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t M1 = get_warp_size() / K0;
constexpr index_t M0 = kBlockSize / get_warp_size();
constexpr index_t M2 = kMPerBlock / (M1 * M0);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,
sequence<2, 1>>{});
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackSGrad() CK_TILE_HOST_DEVICE static constexpr auto MakeOGradDramTileDistribution()
{ {
// TODO: this is for 3d layout constexpr index_t kBlockSize = Problem::kBlockSize;
using GemmDataType = remove_cvref_t<typename Problem::GemmDataType>;
return 16 / sizeof(GemmDataType);
}
template <typename Problem, typename BlockGemm> constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
CK_TILE_HOST_DEVICE static constexpr auto MakeVInRegDramTileDistribution() constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
{
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>(); constexpr index_t K1 = GetAlignmentOGrad<Problem>();
constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t M1 = get_warp_size() / K0;
constexpr index_t M0 = kBlockSize / get_warp_size();
constexpr index_t M2 = kMPerBlock / (M1 * M0);
using WG = remove_cvref_t<decltype(config.template at<0>())>; return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,
sequence<2, 1>>{});
}
template <typename Problem, typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDDramTileDistribution()
{
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
constexpr index_t MWarp = config.template at<1>(); constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>(); constexpr index_t NWarp = config.template at<2>();
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WG::kN); constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t KIterPerWarp = kKPerBlock / WG::kK;
constexpr auto v_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding( // Duplicate dimension
v_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{}); constexpr index_t N0 = NWarp;
constexpr index_t N1 =
(get_warp_size() / kMPerBlock) > 1 ? (get_warp_size() / kMPerBlock) : 1;
constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode); constexpr index_t M0 = MWarp;
constexpr index_t M1 = (get_warp_size() / kMPerBlock) > 1 ? kMPerBlock : get_warp_size();
constexpr index_t M2 =
(get_warp_size() / kMPerBlock) > 1 ? 1 : (kMPerBlock / get_warp_size());
return v_block_dstr; return make_static_tile_distribution(
tile_distribution_encoding<sequence<N0, N1>,
tuple<sequence<M0, M1, M2>>,
tuple<sequence<0, 1>, sequence<0, 1>>,
tuple<sequence<0, 0>, sequence<1, 1>>,
sequence<1>,
sequence<2>>{});
} }
// 3d + padding template <typename Problem>
template <index_t MNPerBlock, index_t KPerBlock, index_t KPack> CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTileDistribution()
CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptor()
{ {
constexpr auto x_lds_block_desc_0 = make_naive_tensor_descriptor( constexpr index_t kBlockSize = Problem::kBlockSize;
make_tuple(number<KPerBlock / KPack>{}, number<MNPerBlock>{}, number<KPack>{}),
make_tuple(number<(MNPerBlock + 1) * KPack>{}, number<KPack>{}, number<1>{}),
number<8>{},
number<1>{});
constexpr auto x_lds_block_desc = transform_tensor_descriptor(
x_lds_block_desc_0,
make_tuple(make_pass_through_transform(MNPerBlock),
make_merge_transform(make_tuple(KPerBlock / KPack, KPack))),
make_tuple(sequence<1>{}, sequence<0, 2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return x_lds_block_desc;
}
// 3d + padding constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
template <index_t MNPerBlock, index_t KPerBlock, index_t KPack> constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptorAsXT()
{
constexpr auto x_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack>{}, number<MNPerBlock>{}, number<KPack>{}),
make_tuple(number<(MNPerBlock + 1) * KPack>{}, number<KPack>{}, number<1>{}),
number<8>{},
number<1>{});
constexpr auto xt_lds_block_desc = transform_tensor_descriptor( constexpr index_t N1 = GetAlignmentBias<Problem>();
x_lds_block_desc_0, constexpr index_t N0 = kNPerBlock / N1;
make_tuple(make_pass_through_transform(MNPerBlock), constexpr index_t M1 = get_warp_size() / N0;
make_merge_transform(make_tuple(KPerBlock / KPack, KPack))), constexpr index_t M0 = kBlockSize / get_warp_size();
make_tuple(sequence<1>{}, sequence<0, 2>{}), constexpr index_t M2 = kMPerBlock / (M1 * M0);
make_tuple(sequence<1>{}, sequence<0>{}));
return xt_lds_block_desc; return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<M0, M1, M2>, sequence<N0, N1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>,
sequence<2, 1>>{});
} }
template <index_t MNPerBlock, index_t KPerBlock, index_t KPack, index_t PixelsPerRow> template <typename DataType, index_t MPerBlock, index_t KPerBlock>
CK_TILE_HOST_DEVICE static constexpr auto MakeXTLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakePreXDramTileDistribution()
{ {
static_assert(PixelsPerRow % KPack == 0); constexpr index_t K1 = 16 / sizeof(DataType);
constexpr index_t NPerRow = PixelsPerRow / KPack; constexpr index_t K0 = KPerBlock / K1;
static_assert(MNPerBlock % NPerRow == 0); constexpr index_t M2 = 1;
static_assert(KPerBlock % KPack == 0); constexpr index_t M1 = get_warp_size();
constexpr index_t M0 = MPerBlock / M1;
constexpr auto xt_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack>{},
number<MNPerBlock / NPerRow>{},
number<NPerRow>{},
number<KPack>{}),
make_tuple(number<(MNPerBlock / NPerRow) * (PixelsPerRow + KPack)>{},
number<PixelsPerRow + KPack>{},
number<KPack>{},
number<1>{}),
number<KPack>{},
number<1>{});
constexpr auto xt_lds_block_desc = transform_tensor_descriptor(
xt_lds_block_desc_0,
make_tuple(
make_merge_transform(make_tuple(number<MNPerBlock / NPerRow>{}, number<NPerRow>{})),
make_merge_transform(make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return xt_lds_block_desc; return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1>>,
tuple<sequence<0>, sequence<1>>,
sequence<1, 2, 2>,
sequence<2, 0, 1>>{});
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakePreODramTileDistribution()
{ {
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; using ODataType = remove_cvref_t<typename Problem::ODataType>;
constexpr index_t kKPerBlock = [&]() {
if constexpr(QLoadOnce)
return Problem::BlockFmhaShape::kQKHeaddim;
else
return Problem::BlockFmhaShape::kK0;
}();
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>();
}
template <typename Problem> constexpr index_t kBlockSize = Problem::kBlockSize;
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptorAsQT() constexpr index_t kKPerBlock = Problem::kVHeaddim;
{
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = [&]() {
if constexpr(QLoadOnce)
return Problem::BlockFmhaShape::kQKHeaddim;
else
return Problem::BlockFmhaShape::kK0;
}();
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
return MakeXLdsBlockDescriptorAsXT<kMPerBlock, kKPerBlock, kKPack>(); return MakePreXDramTileDistribution<ODataType, kBlockSize, kKPerBlock>();
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakePreOGradDramTileDistribution()
{ {
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
constexpr index_t kKPerBlock = [&]() {
if constexpr(KLoadOnce)
return Problem::BlockFmhaShape::kQKHeaddim;
else
return Problem::BlockFmhaShape::kK0;
}();
constexpr index_t kKPack = GetSmemKPackK<Problem>();
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack>(); constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kKPerBlock = Problem::kVHeaddim;
return MakePreXDramTileDistribution<OGradDataType, kBlockSize, kKPerBlock>();
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsBlockDescriptorAsKT() CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradAccDramTileDistribution()
{ {
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
constexpr index_t kKPerBlock = [&]() {
if constexpr(KLoadOnce)
return Problem::BlockFmhaShape::kQKHeaddim;
else
return Problem::BlockFmhaShape::kK0;
}();
constexpr index_t kKPack = GetSmemKPackK<Problem>();
return MakeXLdsBlockDescriptorAsXT<kNPerBlock, kKPerBlock, kKPack>(); constexpr index_t kBlockSize = Problem::kBlockSize;
} constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kKPerBlock = Problem::kQKHeaddim;
template <typename Problem> constexpr index_t K1 = 16 / sizeof(AccDataType);
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsBlockDescriptor() constexpr index_t K0 = kKPerBlock / K1;
{
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t kKPack = GetSmemKPackV<Problem>();
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack>(); constexpr index_t M2 = get_warp_size() / K0;
constexpr index_t M1 = kBlockSize / get_warp_size();
constexpr index_t M0 = kMPerBlock / (M1 * M2);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<1>, sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<2>, sequence<2, 3>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2, 3>,
sequence<0, 0, 1>>{});
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOGradLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakePostQGradDramTileDistribution()
{ {
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; using AccDataType = remove_cvref_t<typename Problem::AccDataType>;
constexpr index_t kKPerBlock = [&]() {
if constexpr(OGradLoadOnce)
return Problem::BlockFmhaShape::kVHeaddim;
else
return Problem::BlockFmhaShape::kK2;
}();
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>(); constexpr index_t kBlockSize = Problem::kBlockSize;
} constexpr index_t kMPerBlock = Problem::kM0;
constexpr index_t kKPerBlock = Problem::kQKHeaddim;
template <typename Problem> constexpr index_t K1 = 16 / sizeof(AccDataType);
CK_TILE_HOST_DEVICE static constexpr auto MakeOGradLdsBlockDescriptorAsOGradT() constexpr index_t K0 = kKPerBlock / K1;
{
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t M2 = get_warp_size() / K0;
constexpr index_t kKPerBlock = [&]() { constexpr index_t M1 = kBlockSize / get_warp_size();
if constexpr(OGradLoadOnce) constexpr index_t M0 = kMPerBlock / (M1 * M2);
return Problem::BlockFmhaShape::kVHeaddim;
else
return Problem::BlockFmhaShape::kK2;
}();
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
return MakeXLdsBlockDescriptorAsXT<kMPerBlock, kKPerBlock, kKPack>(); return make_static_tile_distribution(
tile_distribution_encoding<sequence<>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
} }
// these are for lds
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeSGradLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQ()
{ {
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; return GetAlignmentQ<Problem>();
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPack = GetSmemKPackSGrad<Problem>();
return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>();
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQTLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackQT()
{ {
using QDataType = remove_cvref_t<typename Problem::QDataType>; return GetTransposedAlignmentQ<Problem>();
constexpr index_t Banks = 32; // TODO: need change based on arch
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(QDataType);
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPerBlock = [&]() {
if constexpr(QTLoadOnce)
return Problem::BlockFmhaShape::kM0;
else
return Problem::BlockFmhaShape::kK3;
}();
return MakeXTLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack, PixelsPerRow>();
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKTLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackK()
{ {
using KDataType = remove_cvref_t<typename Problem::KDataType>; return GetAlignmentK<Problem>();
constexpr index_t Banks = 32; // TODO: need change based on arch
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(KDataType);
constexpr index_t kKPack = GetSmemKPackK<Problem>();
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPerBlock = [&]() {
if constexpr(KTLoadOnce)
return Problem::BlockFmhaShape::kN0;
else
return Problem::BlockFmhaShape::kK4;
}();
return MakeXTLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack, PixelsPerRow>();
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOGradTLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackKT()
{ {
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>; return GetTransposedAlignmentK<Problem>();
constexpr index_t Banks = 32; // TODO: need change based on arch
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(OGradDataType);
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kKPerBlock = [&]() {
if constexpr(OGradTLoadOnce)
return Problem::BlockFmhaShape::kM0;
else
return Problem::BlockFmhaShape::kK1;
}();
return MakeXTLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack, PixelsPerRow>();
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackV()
{ {
using BiasDataType = remove_cvref_t<typename Problem::BiasDataType>; return GetAlignmentV<Problem>();
constexpr index_t Banks = 32; // TODO: need change based on arch
constexpr index_t PixelsPerRow = Banks * 4 / sizeof(BiasDataType);
constexpr index_t kKPack = GetSmemKPackBias<Problem>();
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
static_assert(PixelsPerRow % kKPack == 0);
constexpr index_t NPerRow = PixelsPerRow / kKPack;
static_assert(kNPerBlock % NPerRow == 0);
static_assert(kMPerBlock % kKPack == 0);
constexpr auto biast_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kMPerBlock / kKPack>{},
number<kNPerBlock / NPerRow>{},
number<NPerRow>{},
number<kKPack>{}),
make_tuple(number<(kNPerBlock / NPerRow) * (PixelsPerRow + kKPack)>{},
number<PixelsPerRow + kKPack>{},
number<kKPack>{},
number<1>{}),
number<kKPack>{},
number<1>{});
constexpr auto biast_lds_block_desc = transform_tensor_descriptor(
biast_lds_block_desc_0,
make_tuple(
make_merge_transform(make_tuple(number<kNPerBlock / NPerRow>{}, number<NPerRow>{})),
make_merge_transform(make_tuple(number<kMPerBlock / kKPack>{}, number<kKPack>{}))),
make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return biast_lds_block_desc;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQ() CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackBias()
{ {
constexpr index_t smem_size_q = sizeof(typename Problem::QDataType) * return GetAlignmentBias<Problem>();
MakeQLdsBlockDescriptor<Problem>().get_element_space_size();
return smem_size_q;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeQT() CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackBiasT()
{ {
constexpr index_t smem_size_qt = [&]() { return GetTransposedAlignmentBias<Problem>();
if constexpr(QLoadOnce && !QTLoadOnce)
return 0;
else
return sizeof(typename Problem::QDataType) *
MakeQTLdsBlockDescriptor<Problem>().get_element_space_size();
}();
return smem_size_qt;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeK() CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackOGrad()
{ {
constexpr index_t smem_size_k = sizeof(typename Problem::KDataType) * return GetAlignmentOGrad<Problem>();
MakeKLdsBlockDescriptor<Problem>().get_element_space_size();
return smem_size_k;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeKT() CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackOGradT()
{ {
constexpr index_t smem_size_kt = [&]() { return GetTransposedAlignmentOGrad<Problem>();
if constexpr(KLoadOnce && !KTLoadOnce)
return 0;
else
return sizeof(typename Problem::KDataType) *
MakeKTLdsBlockDescriptor<Problem>().get_element_space_size();
}();
return smem_size_kt;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeV() CK_TILE_HOST_DEVICE static constexpr auto GetSmemKPackSGrad()
{ {
constexpr index_t smem_size_v = [&]() { // TODO: this is for 3d layout
if constexpr(VLoadOnce) using GemmDataType = remove_cvref_t<typename Problem::GemmDataType>;
return 0; return 16 / sizeof(GemmDataType);
else
return sizeof(typename Problem::VDataType) *
MakeVLdsBlockDescriptor<Problem>().get_element_space_size();
}();
return smem_size_v;
} }
template <typename Problem> template <index_t MNPerBlock, index_t KPerBlock, index_t KPack>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeOGrad() CK_TILE_HOST_DEVICE static constexpr auto MakeXLdsBlockDescriptor()
{ {
constexpr index_t smem_size_do = constexpr auto DataTypeSize = 2; // sizeof(F16/BF16)
sizeof(typename Problem::OGradDataType) * constexpr auto MNLdsLayer =
MakeOGradLdsBlockDescriptor<Problem>().get_element_space_size(); (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);
return smem_size_do;
constexpr auto x_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<KPerBlock / KPack * MNLdsLayer>{},
number<MNPerBlock / MNLdsLayer>{},
number<KPack>{}),
make_tuple(number<KPack>{}, number<KPerBlock * MNLdsLayer>{}, number<1>{}),
number<KPack>{},
number<1>{});
constexpr auto x_lds_block_desc_permuted = transform_tensor_descriptor(
x_lds_block_desc_0,
make_tuple(make_xor_transform(make_tuple(number<MNPerBlock / MNLdsLayer>{},
number<KPerBlock / KPack * MNLdsLayer>{})),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
constexpr auto x_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
x_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(
make_tuple(number<KPerBlock / KPack>{}, number<MNLdsLayer>{})),
make_pass_through_transform(number<MNPerBlock / MNLdsLayer>{}),
make_pass_through_transform(number<KPack>{})),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
constexpr auto x_lds_block_desc = transform_tensor_descriptor(
x_lds_block_desc_xk0_mnldslayer_mn_xk1,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<MNPerBlock / MNLdsLayer>{}, number<MNLdsLayer>{})),
make_merge_transform_v3_division_mod(
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return x_lds_block_desc;
} }
template <typename Problem> template <typename Problem,
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeOGradT() index_t MNPerBlock,
index_t KPerBlock,
index_t KPack,
index_t KPackT>
CK_TILE_HOST_DEVICE static constexpr auto MakeXTLdsBlockDescriptor()
{ {
constexpr index_t smem_size_dot = [&]() { // kfold and mpair dimension is not always required.
if constexpr(OGradLoadOnce && !OGradTLoadOnce) // more dimension in merge_transform increase the difficulty of generating immarg offset
return 0; // for compiler.
else constexpr auto MNPerXDL = Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
return sizeof(typename Problem::OGradDataType) * constexpr auto kBlockSize = Problem::kBlockSize;
MakeOGradTLdsBlockDescriptor<Problem>().get_element_space_size();
}(); constexpr auto MN0 = MNPerBlock / KPack;
return smem_size_dot; constexpr auto MN1 = KPack;
constexpr auto KThreadWrite = kBlockSize / MN0;
constexpr auto K0Number = KPerBlock / KPackT;
constexpr auto K0PerThreadWrite = K0Number / KThreadWrite;
constexpr auto KThreadRead = get_warp_size() / MNPerXDL; // assume 32x32x8 mfma
constexpr auto K0PerThreadRead = K0Number / KThreadRead;
constexpr auto kfold = (KPackT * MN0 * 2 > 128) ? 1 : 128 / (KPackT * MN0 * 2);
constexpr auto KThreadReadPerm =
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
: KThreadRead;
// 1<=mnpair<=n0
constexpr auto mnpair =
(KPackT * MNPerXDL * 2 > 128)
? 1
: ((128 / (KPackT * MNPerXDL * 2)) > MN0 ? MN0 : 128 / (KPackT * MNPerXDL * 2));
constexpr auto xt_lds_block_desc_raw = make_naive_tensor_descriptor(
make_tuple(number<KThreadWrite / kfold / KThreadReadPerm>{},
number<K0PerThreadWrite>{},
number<KThreadReadPerm * MN1>{},
number<kfold * MN0 / mnpair>{},
number<mnpair>{},
KPackT),
make_tuple(number<KPackT * kfold * MN0 * KThreadReadPerm * MN1 * K0PerThreadWrite>{},
number<KPackT * kfold * MN0 * KThreadReadPerm * MN1>{},
number<KPackT * kfold * MN0>{},
number<KPackT * mnpair>{},
number<KPackT>{},
number<1>{}),
number<KPackT>{},
number<1>{});
constexpr auto xt_lds_block_desc_permuted = transform_tensor_descriptor(
xt_lds_block_desc_raw,
make_tuple(
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(number<K0PerThreadWrite>{}),
make_xor_transform(
make_tuple(number<KThreadReadPerm * MN1>{}, number<kfold * MN0 / mnpair>{})),
make_pass_through_transform(number<mnpair>{}),
make_pass_through_transform(KPackT)),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}));
constexpr auto xt_lds_block_desc_unmerged = transform_tensor_descriptor(
xt_lds_block_desc_permuted,
make_tuple(
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(number<K0PerThreadWrite>{}),
make_unmerge_transform(make_tuple(number<KThreadReadPerm>{}, number<MN1>{})),
make_unmerge_transform(make_tuple(number<kfold>{}, number<MN0 / mnpair>{})),
make_pass_through_transform(number<mnpair>{}),
make_pass_through_transform(KPackT)),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4>{},
sequence<5>{}),
make_tuple(sequence<1>{},
sequence<2>{},
sequence<0, 3>{},
sequence<4, 5>{},
sequence<6>{},
sequence<7>{}));
constexpr auto xt_lds_block_desc = transform_tensor_descriptor(
xt_lds_block_desc_unmerged,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<KThreadReadPerm>{},
number<KThreadWrite / kfold / KThreadReadPerm>{},
number<kfold>{},
number<K0PerThreadWrite>{},
number<KPackT>{})),
make_merge_transform_v3_division_mod(
make_tuple(number<MN0 / mnpair>{}, number<mnpair>{}, number<MN1>{}))),
make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return xt_lds_block_desc;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeSGrad() CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsWriteBlockDescriptor()
{ {
constexpr index_t smem_size_ds = constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
sizeof(typename Problem::GemmDataType) * constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
MakeSGradLdsBlockDescriptor<Problem>().get_element_space_size(); constexpr index_t kKPack = GetSmemKPackK<Problem>();
return smem_size_ds;
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack>();
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSizeBias() CK_TILE_HOST_DEVICE static constexpr auto MakeKRegSliceBlockDescriptor()
{ {
constexpr index_t smem_size_bias = [&]() { using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
if constexpr(Problem::BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS) constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
return sizeof(typename Problem::BiasDataType) * using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
MakeBiasTLdsBlockDescriptor<Problem>().get_element_space_size();
else
return 0;
}();
return smem_size_bias;
}
template <typename Problem> constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
CK_TILE_HOST_DEVICE static constexpr ck_tile::index_t GetSmemSize() constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
{
constexpr index_t smem_size_q = GetSmemSizeQ<Problem>();
constexpr index_t smem_size_qt = GetSmemSizeQT<Problem>();
constexpr index_t smem_size_k = GetSmemSizeK<Problem>();
constexpr index_t smem_size_kt = GetSmemSizeKT<Problem>();
constexpr index_t smem_size_v = GetSmemSizeV<Problem>();
constexpr index_t smem_size_do = GetSmemSizeOGrad<Problem>();
constexpr index_t smem_size_dot = GetSmemSizeOGradT<Problem>();
constexpr index_t smem_size_ds = GetSmemSizeSGrad<Problem>();
constexpr index_t smem_size_bias = GetSmemSizeBias<Problem>();
constexpr index_t smem_size_transpose = max(smem_size_ds, smem_size_bias);
index_t smem_size = 0;
if constexpr(QLoadOnce && OGradLoadOnce)
smem_size += smem_size_q + smem_size_qt + smem_size_do + smem_size_dot +
smem_size_transpose; // 1~4 & 10
else if(QLoadOnce && !OGradLoadOnce && !OGradTLoadOnce)
smem_size += smem_size_q + smem_size_qt +
max(smem_size_do,
smem_size_dot,
smem_size_transpose); // 5/7/11 TODO: Multiple buffers strategy
else if(!QLoadOnce && !QTLoadOnce && OGradLoadOnce)
smem_size += smem_size_do + smem_size_dot +
max(smem_size_q,
smem_size_qt,
smem_size_transpose); // 6/8/12 TODO: Multiple buffers strategy
else if(!QLoadOnce && !QTLoadOnce && !OGradLoadOnce && !OGradTLoadOnce)
smem_size += max(smem_size_q,
smem_size_qt,
smem_size_do,
smem_size_dot,
smem_size_transpose); // 9/13 TODO: Multiple buffers strategy
// 14/15 needs to be adjusted
if constexpr(KLoadOnce)
smem_size += (smem_size_k + smem_size_kt); // 1~13
else
smem_size =
max(smem_size_k, smem_size_kt, smem_size); // 14/15 TODO: Multiple buffers strategy
return max(smem_size, smem_size_v); // 15 constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
} constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
template <typename Problem, typename BlockGemm> constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDDramTileDistribution() constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
{
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr auto k_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr index_t N1 = WG::WarpGemmAttribute::Impl::kCNLane; constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
constexpr index_t N0 = NWarp; k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr index_t M4 = WG::WarpGemmAttribute::Impl::kCM1PerLane * 2; constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode);
constexpr index_t M3 = WG::WarpGemmAttribute::Impl::kCMLane;
constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kCM0PerLane / 2;
constexpr index_t M1 = MWarp;
constexpr index_t M0 = kMPerBlock / (M1 * WG::WarpGemmAttribute::Impl::kM);
return make_static_tile_distribution( return k_block_dstr;
tile_distribution_encoding<sequence<N0, N1>,
tuple<sequence<M0, M1, M2, M3, M4>>,
tuple<sequence<1, 0>, sequence<1, 0>>,
tuple<sequence<1, 0>, sequence<3, 1>>,
sequence<1, 1, 1>,
sequence<0, 2, 4>>{});
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeVDramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeKRegBlockDescriptor()
{ {
using VDataType = remove_cvref_t<typename Problem::VDataType>; using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t K1 = 16 / sizeof(VDataType); constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t K0 = kKPerBlock / K1; constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr index_t N2 = get_warp_size() / K0;
// coalesce reading for each blocks
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N0 = kNPerBlock / (N2 * N1);
return make_static_tile_distribution( constexpr auto k_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<>, tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>, tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1>, sequence<1, 2>>, tuple<sequence<0, 1>>,
tuple<sequence<1>, sequence<2, 0>>, tuple<sequence<0, 1>>,
sequence<1, 2>, sequence<1, 2>,
sequence<0, 1>>{}); sequence<0, 0>>{};
constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode);
return k_block_dstr;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeQDramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsWriteBlockDescriptor()
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kVPack = GetSmemKPackV<Problem>();
constexpr index_t kKPerBlock = [&]() {
if constexpr(QLoadOnce)
return Problem::BlockFmhaShape::kQKHeaddim;
else
return Problem::BlockFmhaShape::kK0;
}();
constexpr index_t K1 = GetAlignmentQ<Problem>(); return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kVPack>();
constexpr index_t K0 = kKPerBlock / K1; }
constexpr index_t M2 = get_warp_size() / K0;
// coalesce reading for each blocks
constexpr index_t M1 = kBlockSize / get_warp_size();
constexpr index_t M0 = kMPerBlock / (M2 * M1);
return make_static_tile_distribution( template <typename Problem>
tile_distribution_encoding<sequence<>, CK_TILE_HOST_DEVICE static constexpr auto MakeVRegSliceBlockDescriptor()
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>, {
tuple<sequence<1>, sequence<1, 2>>, using BlockGemm = remove_cvref_t<decltype(GetOGradVBlockGemm<Problem>())>;
tuple<sequence<1>, sequence<2, 0>>, constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr auto v_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>, sequence<1, 2>,
sequence<0, 1>>{}); sequence<0, 0>>{};
constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode);
return v_block_dstr;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKDramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeVRegBlockDescriptor()
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; using BlockGemm = remove_cvref_t<decltype(GetOGradVBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = [&]() { constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
if constexpr(KLoadOnce)
return Problem::BlockFmhaShape::kQKHeaddim;
else
return Problem::BlockFmhaShape::kK0;
}();
constexpr index_t K1 = GetAlignmentK<Problem>(); constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t K0 = kKPerBlock / K1; constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr index_t N2 = get_warp_size() / K0;
// coalesce reading for each blocks
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N0 = kNPerBlock / (N2 * N1);
return make_static_tile_distribution( constexpr auto v_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<>, tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>, tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1>, sequence<1, 2>>, tuple<sequence<0, 1>>,
tuple<sequence<1>, sequence<2, 0>>, tuple<sequence<0, 1>>,
sequence<1, 2>, sequence<1, 2>,
sequence<0, 1>>{}); sequence<0, 0>>{};
constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode);
return v_block_dstr;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOGradDramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledKRegWriteBlockDescriptor()
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPerBlock = [&]() {
if constexpr(OGradLoadOnce)
return Problem::BlockFmhaShape::kVHeaddim;
else
return Problem::BlockFmhaShape::kK2;
}();
constexpr index_t K1 = GetAlignmentOGrad<Problem>(); constexpr index_t K1 = GetAlignmentK<Problem>();
constexpr index_t K0 = kKPerBlock / K1; constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0; constexpr index_t N2 = GetTransposedAlignmentK<Problem>();
// coalesce reading for each blocks constexpr index_t N1 = get_warp_size() / K0;
constexpr index_t M1 = kBlockSize / get_warp_size(); constexpr index_t N0 = kBlockSize / get_warp_size();
constexpr index_t M0 = kMPerBlock / (M2 * M1);
return make_static_tile_distribution( return make_static_tile_distribution(
tile_distribution_encoding<sequence<>, tile_distribution_encoding<sequence<>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>, tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>, tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>, tuple<sequence<0>, sequence<1, 0>>,
sequence<1, 2>, sequence<2, 1>,
sequence<0, 1>>{}); sequence<1, 2>>{});
} }
template <typename DataType, index_t MPerBlock, index_t KPerBlock> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakePreXDramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledKLdsWriteBlockDescriptor()
{ {
constexpr index_t K1 = 16 / sizeof(DataType); // Hold all data
constexpr index_t K0 = KPerBlock / K1; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t M2 = 1; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t M1 = get_warp_size();
constexpr index_t M0 = MPerBlock / M1;
return make_static_tile_distribution( constexpr index_t kKPack = GetSmemKPackK<Problem>();
tile_distribution_encoding<sequence<>, constexpr index_t kKPackT = GetSmemKPackKT<Problem>();
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1>>, return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
tuple<sequence<0>, sequence<1>>,
sequence<1, 2, 2>,
sequence<2, 0, 1>>{});
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakePreODramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeKTLdsReadBlockDescriptor()
{ {
using ODataType = remove_cvref_t<typename Problem::ODataType>; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kBlockSize = Problem::kBlockSize; auto shuffled_k_lds_block_desc = MakeShuffledKLdsWriteBlockDescriptor<Problem>();
constexpr index_t kKPerBlock = Problem::kVHeaddim;
return MakePreXDramTileDistribution<ODataType, kBlockSize, kKPerBlock>(); return transform_tensor_descriptor(
shuffled_k_lds_block_desc,
make_tuple(make_pass_through_transform(number<kNPerBlock>{}),
make_pass_through_transform(number<kKPerBlock>{})),
make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakePreOGradDramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeKTRegBlockDescriptor()
{ {
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>; using BlockGemm = remove_cvref_t<decltype(GetSGradKTBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{});
constexpr index_t kKPerBlock = Problem::kVHeaddim; constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{});
return MakePreXDramTileDistribution<OGradDataType, kBlockSize, kKPerBlock>(); constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr auto kt_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto kt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
kt_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto kt_block_dstr = make_static_tile_distribution(kt_block_dstr_encode);
return kt_block_dstr;
} }
template <typename Problem> template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeQTDramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t kKPerBlock = [&]() {
if constexpr(QTLoadOnce)
return Problem::BlockFmhaShape::kM0;
else
return Problem::BlockFmhaShape::kK3;
}();
constexpr index_t N1 = GetTransposedAlignmentQ<Problem>();
constexpr index_t N0 = kNPerBlock / N1; // P
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
constexpr index_t K3 = total_pixels / N1;
constexpr index_t kKPack = GetSmemKPackQ<Problem>(); constexpr index_t kKPack = GetSmemKPackQ<Problem>();
static_assert(kKPack % K3 == 0);
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = kBlockSize / get_warp_size();
static_assert(kKPerBlock == K0 * K1 * K2 * K3);
return make_static_tile_distribution( return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>();
tile_distribution_encoding<sequence<>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledQTRegBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeQRegSliceBlockDescriptor()
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
constexpr index_t kKPerBlock = [&]() { using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
if constexpr(QTLoadOnce)
return Problem::BlockFmhaShape::kM0;
else
return Problem::BlockFmhaShape::kK3;
}();
constexpr index_t N1 = GetTransposedAlignmentQ<Problem>(); constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
constexpr index_t N0 = kNPerBlock / N1; constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
constexpr index_t K3 = total_pixels / N1;
constexpr index_t kKPack = GetSmemKPackQ<Problem>();
static_assert(kKPack % K3 == 0);
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = kBlockSize / get_warp_size();
return make_static_tile_distribution( constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
tile_distribution_encoding<sequence<>, constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>, constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
tuple<sequence<0>, sequence<1, 0, 2>>, constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr auto q_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>, sequence<1, 2>,
sequence<1, 3>>{}); sequence<0, 0>>{};
constexpr auto q_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
q_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
constexpr auto q_block_dstr = make_static_tile_distribution(q_block_dstr_encode);
return q_block_dstr;
} }
template <typename Problem> template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeKTDramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledQRegWriteBlockDescriptor()
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPerBlock = [&]() {
if constexpr(KTLoadOnce)
return Problem::BlockFmhaShape::kN0;
else
return Problem::BlockFmhaShape::kK4;
}();
constexpr index_t N1 = GetTransposedAlignmentK<Problem>(); constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t N0 = kNPerBlock / N1; // P
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; constexpr index_t K1 = GetAlignmentQ<Problem>();
static_assert(total_pixels % N1 == 0); // TODO: this is not always true? constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t K3 = total_pixels / N1; constexpr index_t N2 = GetTransposedAlignmentQ<Problem>();
constexpr index_t kKPack = GetSmemKPackK<Problem>(); constexpr index_t N1 = get_warp_size() / K0;
static_assert(kKPack % K3 == 0); constexpr index_t N0 = kBlockSize / get_warp_size();
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = kBlockSize / get_warp_size();
static_assert(kKPerBlock == K0 * K1 * K2 * K3);
return make_static_tile_distribution( return make_static_tile_distribution(
tile_distribution_encoding<sequence<>, tile_distribution_encoding<sequence<>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>, tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<2>, sequence<2, 1, 2>>, tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>, tuple<sequence<0>, sequence<1, 0>>,
sequence<2, 1>, sequence<2, 1>,
sequence<3, 1>>{}); sequence<1, 2>>{});
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledKTRegBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledQLdsWriteBlockDescriptor()
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; // Hold full block data
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPerBlock = [&]() { constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
if constexpr(KTLoadOnce)
return Problem::BlockFmhaShape::kN0;
else
return Problem::BlockFmhaShape::kK4;
}();
constexpr index_t N1 = GetTransposedAlignmentK<Problem>(); constexpr index_t kKPack = GetSmemKPackQ<Problem>();
constexpr index_t N0 = kNPerBlock / N1; constexpr index_t kKPackT = GetSmemKPackQT<Problem>();
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
constexpr index_t K3 = total_pixels / N1;
constexpr index_t kKPack = GetSmemKPackK<Problem>();
static_assert(kKPack % K3 == 0);
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = kBlockSize / get_warp_size();
return make_static_tile_distribution( return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
tile_distribution_encoding<sequence<>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
} }
template <typename Problem> template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeOGradTDramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeQTLdsReadBlockDescriptor()
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; // Hold full block data
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPerBlock = [&]() { constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
if constexpr(OGradTLoadOnce)
return Problem::BlockFmhaShape::kM0;
else
return Problem::BlockFmhaShape::kK1;
}();
constexpr index_t N1 = GetTransposedAlignmentOGrad<Problem>();
constexpr index_t N0 = kNPerBlock / N1; // P
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; auto shuffled_q_lds_block_desc = MakeShuffledQLdsWriteBlockDescriptor<Problem>();
static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
constexpr index_t K3 = total_pixels / N1;
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
static_assert(kKPack % K3 == 0);
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = kBlockSize / get_warp_size();
static_assert(kKPerBlock == K0 * K1 * K2 * K3);
return make_static_tile_distribution( return transform_tensor_descriptor(
tile_distribution_encoding<sequence<>, shuffled_q_lds_block_desc,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>, make_tuple(make_pass_through_transform(number<kNPerBlock>{}),
tuple<sequence<2>, sequence<2, 1, 2>>, make_pass_through_transform(number<kKPerBlock>{})),
tuple<sequence<0>, sequence<1, 0, 2>>, make_tuple(sequence<1>{}, sequence<0>{}),
sequence<2, 1>, make_tuple(sequence<0>{}, sequence<1>{}));
sequence<3, 1>>{});
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledOGradTRegBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeQTRegSliceBlockDescriptor()
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; using BlockGemm = remove_cvref_t<decltype(GetSGradTQTBlockGemm<Problem>())>;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
constexpr index_t kKPerBlock = [&]() { using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
if constexpr(OGradTLoadOnce)
return Problem::BlockFmhaShape::kM0;
else
return Problem::BlockFmhaShape::kK1;
}();
constexpr index_t N1 = GetTransposedAlignmentOGrad<Problem>(); constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<0>{});
constexpr index_t N0 = kNPerBlock / N1; constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<1>{});
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
static_assert(total_pixels % N1 == 0); // TODO: this is not always true?
constexpr index_t K3 = total_pixels / N1;
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
static_assert(kKPack % K3 == 0);
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = kBlockSize / get_warp_size();
return make_static_tile_distribution( constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
tile_distribution_encoding<sequence<>, constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK3;
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>, constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
tuple<sequence<0>, sequence<1, 0, 2>>, constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr auto qt_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>, sequence<1, 2>,
sequence<1, 3>>{}); sequence<0, 0>>{};
constexpr auto qt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
qt_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto qt_block_dstr = make_static_tile_distribution(qt_block_dstr_encode);
return qt_block_dstr;
} }
template <typename Problem> template <typename Problem>
CK_TILE_DEVICE static constexpr auto MakeBiasTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeSGradTRegSliceBlockDescriptor()
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; using BlockGemm = remove_cvref_t<decltype(GetSGradTQTBlockGemm<Problem>())>;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t N1 = GetTransposedAlignmentBias<Problem>(); constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<0>{});
constexpr index_t N0 = kNPerBlock / N1; // P constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm3BlockWarps::at(number<1>{});
constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize; constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kN0;
static_assert(total_pixels % N1 == 0); // TODO: this is not always true? constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK3;
constexpr index_t M3 = total_pixels / N1;
constexpr index_t kKPack = GetSmemKPackBias<Problem>();
static_assert(kKPack % M3 == 0);
constexpr index_t M2 = kKPack / M3; // TODO: this dimention could be outside single wave
constexpr index_t M1 = get_warp_size() / (M2 * N0);
constexpr index_t M0 = kBlockSize / get_warp_size();
static_assert(kMPerBlock == M0 * M1 * M2 * M3);
return make_static_tile_distribution( constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
tile_distribution_encoding<sequence<>, constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
tuple<sequence<M0, M1, M2, M3>, sequence<N0, N1>>,
tuple<sequence<1>, sequence<1, 2, 1>>, constexpr auto dst_block_outer_dstr_encoding =
tuple<sequence<0>, sequence<1, 0, 2>>, tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>, sequence<1, 2>,
sequence<3, 1>>{}); sequence<0, 0>>{};
constexpr auto dst_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
dst_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
constexpr auto dst_block_dstr = make_static_tile_distribution(dst_block_dstr_encode);
return dst_block_dstr;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBiasTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDLdsWriteBlockDescriptor()
{ {
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; using LSEDType = remove_cvref_t<typename Problem::DDataType>;
constexpr index_t kMPack = 16 / sizeof(LSEDType);
constexpr index_t N1 = GetTransposedAlignmentBias<Problem>(); constexpr auto lsed_lds_block_desc =
constexpr index_t N0 = kNPerBlock / N1; make_naive_tensor_descriptor(make_tuple(number<kMPerBlock>{}),
constexpr index_t total_pixels = kMPerBlock * kNPerBlock / kBlockSize; make_tuple(number<1>{}),
static_assert(total_pixels % N1 == 0); // TODO: this is not always true? number<kMPack>{},
constexpr index_t M3 = total_pixels / N1; number<1>{});
constexpr index_t kKPack = GetSmemKPackBias<Problem>();
static_assert(kKPack % M3 == 0);
constexpr index_t M2 = kKPack / M3; // TODO: this dimention could be outside single wave
constexpr index_t M1 = get_warp_size() / (M2 * N0);
constexpr index_t M0 = kBlockSize / get_warp_size();
return make_static_tile_distribution( return lsed_lds_block_desc;
tile_distribution_encoding<sequence<>,
tuple<sequence<M0, M1, M2, M3>, sequence<N0, N1>>,
tuple<sequence<1>, sequence<1, 2, 1>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<2, 1>,
sequence<1, 3>>{});
} }
template <typename BlockGemm> template <typename Problem, typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasTTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeLSEDLdsReadBlockDescriptor()
{ {
using c_block_tensor_type = decltype(BlockGemm{}.MakeCBlockTile()); constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
return c_block_tensor_type::get_tile_distribution(); using WG = remove_cvref_t<decltype(config.template at<0>())>;
} constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t N1 = WG::WarpGemmAttribute::Impl::kCNLane;
constexpr index_t N0 = NWarp;
// M4 *2 and M2 /2 when swizzle mode enabled
constexpr index_t SwizzleConfig = WG::kM == 16 ? 1 : 2;
// constexpr index_t SwizzleConfig = 1;
constexpr index_t M4 = WG::WarpGemmAttribute::Impl::kCM1PerLane * SwizzleConfig;
constexpr index_t M3 = WG::WarpGemmAttribute::Impl::kCMLane;
constexpr index_t M2 = WG::WarpGemmAttribute::Impl::kCM0PerLane / SwizzleConfig;
constexpr index_t M1 = MWarp;
constexpr index_t M0 = kMPerBlock / (M1 * WG::WarpGemmAttribute::Impl::kM);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<N0, N1>,
tuple<sequence<M0, M1, M2, M3, M4>>,
tuple<sequence<1, 0>, sequence<1, 0>>,
tuple<sequence<1, 0>, sequence<3, 1>>,
sequence<1, 1, 1>,
sequence<0, 2, 4>>{});
}
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto MakeOGradLdsBlockDescriptor()
{ {
using BlockGemmProblem = // Hold full block data
BlockGemmPipelineProblem<typename Problem::QDataType, constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
typename Problem::KDataType, constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK0>>;
constexpr auto warp_gemm = []() { constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
std::is_same_v<typename Problem::KDataType, half_t> &&
std::is_same_v<typename Problem::AccDataType, float>)
{
return WarpGemmMfmaF16F16F32M32N32K16SwizzleA{};
}
else if constexpr(std::is_same_v<typename Problem::QDataType, bf16_t> &&
std::is_same_v<typename Problem::KDataType, bf16_t> &&
std::is_same_v<typename Problem::AccDataType, float>)
{
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA{};
}
}();
using BlockGemmPolicy = return MakeXLdsBlockDescriptor<kMPerBlock, kKPerBlock, kKPack>();
BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::QDataType, }
typename Problem::KDataType,
typename Problem::AccDataType, template <typename Problem>
typename Problem::BlockFmhaShape::Gemm0BlockWarps, CK_TILE_HOST_DEVICE static constexpr auto MakeOGradRegSliceBlockDescriptor()
decltype(warp_gemm)>; {
using BlockGemm = remove_cvref_t<decltype(GetOGradVBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{});
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
return BlockGemmASmemBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{}; constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr auto do_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto do_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
do_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
constexpr auto do_block_dstr = make_static_tile_distribution(do_block_dstr_encode);
return do_block_dstr;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledOGradRegWriteBlockDescriptor()
{ {
using BlockGemmProblem = constexpr index_t kBlockSize = Problem::kBlockSize;
BlockGemmPipelineProblem<typename Problem::GemmDataType,
typename Problem::OGradDataType,
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kVHeaddim,
Problem::BlockFmhaShape::kK1>>;
using WarpGemm = constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
typename Problem::OGradDataType, constexpr index_t K1 = GetAlignmentOGrad<Problem>();
typename Problem::AccDataType, constexpr index_t K0 = kKPerBlock / K1;
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<0>{}), constexpr index_t N2 = GetTransposedAlignmentOGrad<Problem>();
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<1>{}), constexpr index_t N1 = get_warp_size() / K0;
Problem::BlockFmhaShape::Gemm1WarpTile::at(number<2>{}), constexpr index_t N0 = kBlockSize / get_warp_size();
true>;
using BlockGemmPolicy = return make_static_tile_distribution(
BlockGemmARegBSmemCRegV1CustomPolicy<typename Problem::GemmDataType, tile_distribution_encoding<sequence<>,
typename Problem::OGradDataType, tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
typename Problem::AccDataType, tuple<sequence<1>, sequence<1, 2>>,
typename Problem::BlockFmhaShape::Gemm1BlockWarps, tuple<sequence<0>, sequence<1, 0>>,
WarpGemm>; sequence<2, 1>,
return BlockGemmARegBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{}; sequence<1, 2>>{});
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledOGradLdsWriteBlockDescriptor()
{ {
using BlockGemmProblem = // Hold all data
BlockGemmPipelineProblem<typename Problem::OGradDataType, constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
typename Problem::VDataType, constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK2>>;
constexpr auto warp_gemm = []() { constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
if constexpr(std::is_same_v<typename Problem::OGradDataType, half_t> && constexpr index_t kKPackT = GetSmemKPackOGradT<Problem>();
std::is_same_v<typename Problem::VDataType, half_t> &&
std::is_same_v<typename Problem::AccDataType, float>)
{
return WarpGemmMfmaF16F16F32M32N32K16SwizzleA{};
}
else if constexpr(std::is_same_v<typename Problem::OGradDataType, bf16_t> &&
std::is_same_v<typename Problem::VDataType, bf16_t> &&
std::is_same_v<typename Problem::AccDataType, float>)
{
return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA{};
}
}();
using BlockGemmPolicy = return MakeXTLdsBlockDescriptor<Problem, kNPerBlock, kKPerBlock, kKPack, kKPackT>();
BlockGemmASmemBRegCRegV1CustomPolicy<typename Problem::OGradDataType, }
typename Problem::VDataType,
typename Problem::AccDataType,
typename Problem::BlockFmhaShape::Gemm2BlockWarps,
decltype(warp_gemm)>;
return BlockGemmASmemBRegCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
}
// template <typename Problem>
// CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm()
// {
// using BlockGemmProblem =
// BlockGemmPipelineProblem<typename Problem::OGradDataType,
// typename Problem::VDataType,
// typename Problem::AccDataType,
// Problem::kBlockSize,
// TileGemmShape<Problem::BlockFmhaShape::kM0,
// Problem::BlockFmhaShape::kN0,
// Problem::BlockFmhaShape::kK2>>;
// constexpr auto warp_gemm = []() {
// if constexpr(std::is_same_v<typename Problem::OGradDataType, half_t> &&
// std::is_same_v<typename Problem::VDataType, half_t> &&
// std::is_same_v<typename Problem::AccDataType, float>)
// {
// return WarpGemmMfmaF16F16F32M32N32K16SwizzleA{};
// }
// else if constexpr(std::is_same_v<typename Problem::OGradDataType, bf16_t> &&
// std::is_same_v<typename Problem::VDataType, bf16_t> &&
// std::is_same_v<typename Problem::AccDataType, float>)
// {
// return WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA{};
// }
// }();
// using BlockGemmPolicy =
// BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::OGradDataType,
// typename Problem::VDataType,
// typename Problem::AccDataType,
// typename
// Problem::BlockFmhaShape::Gemm2BlockWarps,
// decltype(warp_gemm)>;
// return BlockGemmASmemBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{};
// }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto MakeOGradTLdsReadBlockDescriptor()
{
// Hold all data
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kM0;
auto shuffled_do_lds_block_desc = MakeShuffledOGradLdsWriteBlockDescriptor<Problem>();
return transform_tensor_descriptor(
shuffled_do_lds_block_desc,
make_tuple(make_pass_through_transform(number<kNPerBlock>{}),
make_pass_through_transform(number<kKPerBlock>{})),
make_tuple(sequence<1>{}, sequence<0>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeOGradTRegSliceBlockDescriptor()
{ {
using BlockGemmProblem = using BlockGemm = remove_cvref_t<decltype(GetPTOGradTBlockGemm<Problem>())>;
BlockGemmPipelineProblem<typename Problem::GemmDataType, constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
typename Problem::QDataType, using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK3>>;
using WarpGemm = constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{});
WarpGemmMfmaDispatcher<typename Problem::GemmDataType, constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{});
typename Problem::QDataType,
typename Problem::AccDataType, constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kVHeaddim;
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<0>{}), // constexpr index_t kNPerBlock = 32;
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<1>{}), constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
Problem::BlockFmhaShape::Gemm3WarpTile::at(number<2>{}),
true>; constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
using BlockGemmPolicy = constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
BlockGemmARegBSmemCRegV1CustomPolicy<typename Problem::GemmDataType,
typename Problem::QDataType, constexpr auto dot_block_outer_dstr_encoding =
typename Problem::AccDataType, tile_distribution_encoding<sequence<MWarp>,
typename Problem::BlockFmhaShape::Gemm3BlockWarps, tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
WarpGemm>; tuple<sequence<0, 1>>,
return BlockGemmARegBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{}; tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto dot_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
dot_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto dot_block_dstr = make_static_tile_distribution(dot_block_dstr_encode);
return dot_block_dstr;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto MakePTRegSliceBlockDescriptor()
{ {
using BlockGemmProblem = using BlockGemm = remove_cvref_t<decltype(GetPTOGradTBlockGemm<Problem>())>;
BlockGemmPipelineProblem<typename Problem::GemmDataType, constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
typename Problem::KDataType, using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
typename Problem::AccDataType,
Problem::kBlockSize,
TileGemmShape<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kQKHeaddim,
Problem::BlockFmhaShape::kK4>>;
using WarpGemm = constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<0>{});
WarpGemmMfmaDispatcher<typename Problem::GemmDataType, constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm1BlockWarps::at(number<1>{});
typename Problem::KDataType,
typename Problem::AccDataType, constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kN0;
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<0>{}), constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK1;
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<1>{}),
Problem::BlockFmhaShape::Gemm4WarpTile::at(number<2>{}), constexpr index_t MIterPerWarp = kMPerBlock / (MWarp * WarpGemm::kM);
true>; constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
using BlockGemmPolicy =
BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::GemmDataType, constexpr auto pt_block_outer_dstr_encoding =
typename Problem::KDataType, tile_distribution_encoding<sequence<NWarp>,
typename Problem::AccDataType, tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
typename Problem::BlockFmhaShape::Gemm4BlockWarps, tuple<sequence<1, 0>>,
WarpGemm>; tuple<sequence<1, 0>>,
return BlockGemmASmemBSmemCRegV1<BlockGemmProblem, BlockGemmPolicy>{}; sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto pt_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
pt_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
constexpr auto pt_block_dstr = make_static_tile_distribution(pt_block_dstr_encode);
return pt_block_dstr;
} }
// Builds the LDS block descriptor used for SGrad.
// The descriptor is produced by MakeXLdsBlockDescriptor with an M extent of
// BlockFmhaShape::kM0, a K extent of BlockFmhaShape::kN0 and the K-pack value
// reported by GetSmemKPackSGrad<Problem>().
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeSGradLdsBlockDescriptor()
{
    using Shape = typename Problem::BlockFmhaShape;

    constexpr index_t m_extent = Shape::kM0;
    constexpr index_t k_extent = Shape::kN0;
    constexpr index_t k_pack   = GetSmemKPackSGrad<Problem>();

    return MakeXLdsBlockDescriptor<m_extent, k_extent, k_pack>();
}
// Returns the static tile distribution for the SGrad register slice consumed as
// the A operand of the block GEMM returned by GetSGradKTBlockGemm<Problem>().
// An outer (block-level) encoding derived from Gemm4BlockWarps, kM0 and kK4 is
// embedded with the warp GEMM's AWarpDstrEncoding.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeSGradRegSliceBlockDescriptor()
{
    using Shape      = typename Problem::BlockFmhaShape;
    using BlockWarps = typename Shape::Gemm4BlockWarps;
    using Gemm       = remove_cvref_t<decltype(GetSGradKTBlockGemm<Problem>())>;

    // at<0>() of the policy's config is taken as the warp-level GEMM type.
    constexpr auto warp_cfg = Gemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
    using WG                = remove_cvref_t<decltype(warp_cfg.template at<0>())>;

    constexpr index_t warps_m = BlockWarps::at(number<0>{});
    constexpr index_t warps_n = BlockWarps::at(number<1>{});

    constexpr index_t tile_m = Shape::kM0;
    constexpr index_t tile_k = Shape::kK4;

    // Warp-GEMM repetitions per warp along M and along K.
    constexpr index_t m_iters = tile_m / (warps_m * WG::kM);
    constexpr index_t k_iters = tile_k / WG::kK;

    using OuterEncoding =
        tile_distribution_encoding<sequence<warps_n>,
                                   tuple<sequence<m_iters, warps_m>, sequence<k_iters>>,
                                   tuple<sequence<1, 0>>,
                                   tuple<sequence<1, 0>>,
                                   sequence<1, 2>,
                                   sequence<0, 0>>;

    constexpr auto block_encoding = detail::make_embed_tile_distribution_encoding(
        OuterEncoding{}, typename WG::AWarpDstrEncoding{});

    return make_static_tile_distribution(block_encoding);
}
// Fills pt_out from p_in.
//
// When Gemm1WarpTile's first extent is not 16, the thread buffer of p_in is
// assigned to pt_out unchanged. When it is 16, the data is moved one warp tile
// at a time: the slice at outer Y index (kIter, mIter) of p_in is read using
// the warp GEMM's C distribution and written to outer Y index (mIter, kIter)
// of pt_out using the warp GEMM's A distribution.
// NOTE(review): per the function name this converts gemm0's C result into
// gemm1's A operand -- confirm against the pipeline that calls it.
template <typename Problem, typename PTOutTensor, typename PInTensor>
CK_TILE_DEVICE static constexpr void PTFromGemm0CToGemm1A(PTOutTensor& pt_out,
                                                          const PInTensor& p_in)
{
    using Shape = typename Problem::BlockFmhaShape;

    constexpr bool kWarpTileMIs16 = (Shape::Gemm1WarpTile::at(number<0>{}) == 16);

    if constexpr(!kWarpTileMIs16)
    {
        // No per-warp-tile rearrangement: assign the whole thread buffer.
        pt_out.get_thread_buffer() = p_in.get_thread_buffer();
    }
    else
    {
        using Gemm              = remove_cvref_t<decltype(GetPTOGradTBlockGemm<Problem>())>;
        constexpr auto warp_cfg = Gemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
        using WG                = remove_cvref_t<decltype(warp_cfg.template at<0>())>;
        using ADstr             = typename WG::AWarpDstr;
        using CDstr             = typename WG::CWarpDstr;

        constexpr index_t warps_m = Shape::Gemm1BlockWarps::at(number<0>{});
        constexpr index_t tile_m  = Shape::kN0;
        constexpr index_t tile_k  = Shape::kK1;

        constexpr index_t m_iters = tile_m / (warps_m * WG::kM);
        constexpr index_t k_iters = tile_k / WG::kK;

        // Per-warp Y lengths and all-zero Y origins for the A and C layouts.
        constexpr auto a_lengths = to_sequence(ADstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto c_lengths = to_sequence(CDstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto a_origin  = uniform_sequence_gen_t<ADstr::NDimY, 0>{};
        constexpr auto c_origin  = uniform_sequence_gen_t<CDstr::NDimY, 0>{};

        // Staging tensor that holds one warp tile between the read and the write.
        auto staging =
            make_static_distributed_tensor<typename Problem::GemmDataType>(CDstr{});

        static_for<0, k_iters, 1>{}([&](auto k_it) {
            static_for<0, m_iters, 1>{}([&](auto m_it) {
                // Source is indexed (k, m); destination is indexed (m, k).
                staging.get_thread_buffer() = p_in.get_y_sliced_thread_data(
                    merge_sequences(sequence<k_it, m_it>{}, c_origin),
                    merge_sequences(sequence<1, 1>{}, c_lengths));

                pt_out.set_y_sliced_thread_data(
                    merge_sequences(sequence<m_it, k_it>{}, a_origin),
                    merge_sequences(sequence<1, 1>{}, a_lengths),
                    staging.get_thread_buffer());
            });
        });
    }
}
// Fills dst_out from ds_in.
//
// When Gemm3WarpTile's first extent is not 16, the thread buffer of ds_in is
// assigned to dst_out unchanged. When it is 16, the data is moved one warp
// tile at a time: the slice at outer Y index (kIter, mIter) of ds_in is read
// using the warp GEMM's C distribution and written to outer Y index
// (mIter, kIter) of dst_out using the warp GEMM's A distribution.
// NOTE(review): per the function name this converts gemm2's C result into
// gemm3's A operand -- confirm against the pipeline that calls it.
template <typename Problem, typename SGradTOutTensor, typename SGradInTensor>
CK_TILE_DEVICE static constexpr void SGradTFromGemm2CToGemm3A(SGradTOutTensor& dst_out,
                                                              const SGradInTensor& ds_in)
{
    using Shape = typename Problem::BlockFmhaShape;

    constexpr bool kWarpTileMIs16 = (Shape::Gemm3WarpTile::at(number<0>{}) == 16);

    if constexpr(!kWarpTileMIs16)
    {
        // No per-warp-tile rearrangement: assign the whole thread buffer.
        dst_out.get_thread_buffer() = ds_in.get_thread_buffer();
    }
    else
    {
        using Gemm              = remove_cvref_t<decltype(GetSGradTQTBlockGemm<Problem>())>;
        constexpr auto warp_cfg = Gemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
        using WG                = remove_cvref_t<decltype(warp_cfg.template at<0>())>;
        using ADstr             = typename WG::AWarpDstr;
        using CDstr             = typename WG::CWarpDstr;

        constexpr index_t warps_m = Shape::Gemm3BlockWarps::at(number<0>{});
        constexpr index_t tile_m  = Shape::kN0;
        constexpr index_t tile_k  = Shape::kK3;

        constexpr index_t m_iters = tile_m / (warps_m * WG::kM);
        constexpr index_t k_iters = tile_k / WG::kK;

        // Per-warp Y lengths and all-zero Y origins for the A and C layouts.
        constexpr auto a_lengths = to_sequence(ADstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto c_lengths = to_sequence(CDstr{}.get_ys_to_d_descriptor().get_lengths());
        constexpr auto a_origin  = uniform_sequence_gen_t<ADstr::NDimY, 0>{};
        constexpr auto c_origin  = uniform_sequence_gen_t<CDstr::NDimY, 0>{};

        // Staging tensor that holds one warp tile between the read and the write.
        auto staging =
            make_static_distributed_tensor<typename Problem::GemmDataType>(CDstr{});

        static_for<0, k_iters, 1>{}([&](auto k_it) {
            static_for<0, m_iters, 1>{}([&](auto m_it) {
                // Source is indexed (k, m); destination is indexed (m, k).
                staging.get_thread_buffer() = ds_in.get_y_sliced_thread_data(
                    merge_sequences(sequence<k_it, m_it>{}, c_origin),
                    merge_sequences(sequence<1, 1>{}, c_lengths));

                dst_out.set_y_sliced_thread_data(
                    merge_sequences(sequence<m_it, k_it>{}, a_origin),
                    merge_sequences(sequence<1, 1>{}, a_lengths),
                    staging.get_thread_buffer());
            });
        });
    }
}
// Returns the static tile distribution used for the shuffled bias tile.
// The N dimension (kN0) is factored as N0 x N1 with N1 = GetAlignmentBias and
// N0 = kN0 / N1. The M dimension is factored as M0 x M1 x M2 with
// M0 = kBlockSize / warp size, M1 = warp size / N0 and
// M2 = GetTransposedAlignmentBias.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBiasTileDistribution()
{
    constexpr index_t block_threads = Problem::kBlockSize;
    constexpr index_t block_n       = Problem::BlockFmhaShape::kN0;

    // N split
    constexpr index_t n_inner = GetAlignmentBias<Problem>();
    constexpr index_t n_outer = block_n / n_inner;

    // M split
    constexpr index_t m_inner  = GetTransposedAlignmentBias<Problem>();
    constexpr index_t m_middle = get_warp_size() / n_outer;
    constexpr index_t m_outer  = block_threads / get_warp_size();

    using Encoding =
        tile_distribution_encoding<sequence<>,
                                   tuple<sequence<m_outer, m_middle, m_inner>,
                                         sequence<n_outer, n_inner>>,
                                   tuple<sequence<1>, sequence<1, 2>>,
                                   tuple<sequence<0>, sequence<1, 0>>,
                                   sequence<2, 1>,
                                   sequence<1, 2>>;

    return make_static_tile_distribution(Encoding{});
}
// LDS block descriptor for the bias tile. The buffer holds the full kN0 x kM0 block and is
// built through MakeXTLdsBlockDescriptor with the bias-specific K-pack values.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasLdsBlockDescriptor()
{
    using Shape = typename Problem::BlockFmhaShape;

    constexpr index_t n_per_block = Shape::kN0;
    constexpr index_t m_per_block = Shape::kM0;
    constexpr index_t k_pack      = GetSmemKPackBias<Problem>();
    constexpr index_t k_pack_t    = GetSmemKPackBiasT<Problem>();

    return MakeXTLdsBlockDescriptor<Problem, n_per_block, m_per_block, k_pack, k_pack_t>();
}
// Returns the tile distribution of the C block tile produced by BlockGemm, so a bias tile
// can be laid out the same way as that GEMM's output.
template <typename BlockGemm>
CK_TILE_HOST_DEVICE static constexpr auto MakeBiasSTileDistribution()
{
    return decltype(BlockGemm{}.MakeCBlockTile())::get_tile_distribution();
}
// LDS bytes for Q: element size times the element space of the Q LDS block descriptor.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeQ()
{
    using Elem = typename Problem::QDataType;

    constexpr auto elem_space  = MakeQLdsBlockDescriptor<Problem>().get_element_space_size();
    constexpr index_t num_bytes = sizeof(Elem) * elem_space;
    return num_bytes;
}
// LDS bytes for Q^T: element size times the element space of the shuffled-Q LDS write
// descriptor.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeQT()
{
    using Elem = typename Problem::QDataType;

    constexpr auto elem_space =
        MakeShuffledQLdsWriteBlockDescriptor<Problem>().get_element_space_size();
    constexpr index_t num_bytes = sizeof(Elem) * elem_space;
    return num_bytes;
}
// LDS bytes for K: element size times the element space of the K LDS write descriptor.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeK()
{
    using Elem = typename Problem::KDataType;

    constexpr auto elem_space = MakeKLdsWriteBlockDescriptor<Problem>().get_element_space_size();
    constexpr index_t num_bytes = sizeof(Elem) * elem_space;
    return num_bytes;
}
// LDS bytes for K^T: element size times the element space of the K^T LDS read descriptor.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeKT()
{
    using Elem = typename Problem::KDataType;

    constexpr auto elem_space = MakeKTLdsReadBlockDescriptor<Problem>().get_element_space_size();
    constexpr index_t num_bytes = sizeof(Elem) * elem_space;
    return num_bytes;
}
// LDS bytes for LSE: element size times the element space of the shared LSE/D LDS write
// descriptor.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeLSE()
{
    using Elem = typename Problem::LSEDataType;

    constexpr auto elem_space =
        MakeLSEDLdsWriteBlockDescriptor<Problem>().get_element_space_size();
    constexpr index_t num_bytes = sizeof(Elem) * elem_space;
    return num_bytes;
}
// LDS bytes for D: element size times the element space of the shared LSE/D LDS write
// descriptor (the same descriptor GetSmemSizeLSE uses).
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeD()
{
    using Elem = typename Problem::DDataType;

    constexpr auto elem_space =
        MakeLSEDLdsWriteBlockDescriptor<Problem>().get_element_space_size();
    constexpr index_t num_bytes = sizeof(Elem) * elem_space;
    return num_bytes;
}
// LDS bytes for V: element size times the element space of the V LDS write descriptor.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeV()
{
    using Elem = typename Problem::VDataType;

    constexpr auto elem_space = MakeVLdsWriteBlockDescriptor<Problem>().get_element_space_size();
    constexpr index_t num_bytes = sizeof(Elem) * elem_space;
    return num_bytes;
}
// LDS bytes for dO: element size times the element space of the OGrad LDS block descriptor.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeOGrad()
{
    using Elem = typename Problem::OGradDataType;

    constexpr auto elem_space = MakeOGradLdsBlockDescriptor<Problem>().get_element_space_size();
    constexpr index_t num_bytes = sizeof(Elem) * elem_space;
    return num_bytes;
}
// LDS bytes for dO^T: element size times the element space of the shuffled-OGrad LDS write
// descriptor.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeOGradT()
{
    using Elem = typename Problem::OGradDataType;

    constexpr auto elem_space =
        MakeShuffledOGradLdsWriteBlockDescriptor<Problem>().get_element_space_size();
    constexpr index_t num_bytes = sizeof(Elem) * elem_space;
    return num_bytes;
}
// LDS bytes for dS: GemmDataType element size times the element space of the SGrad LDS
// block descriptor.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeSGrad()
{
    using Elem = typename Problem::GemmDataType;

    constexpr auto elem_space = MakeSGradLdsBlockDescriptor<Problem>().get_element_space_size();
    constexpr index_t num_bytes = sizeof(Elem) * elem_space;
    return num_bytes;
}
// LDS bytes for the bias tile. Only an elementwise bias occupies LDS; every other bias mode
// yields 0.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeBias()
{
    if constexpr(Problem::BiasEnum == BlockAttentionBiasEnum::ELEMENTWISE_BIAS)
    {
        constexpr index_t bias_bytes =
            sizeof(typename Problem::BiasDataType) *
            MakeBiasLdsBlockDescriptor<Problem>().get_element_space_size();
        return bias_bytes;
    }
    else
    {
        constexpr index_t bias_bytes = 0;
        return bias_bytes;
    }
}
// Total LDS bytes required by the pipeline.
//
// The footprint is computed for three groups of buffers and the largest one is returned:
//   stage0_0: K + K^T
//   stage0_1: V
//   stage1  : Q^T + Q + dO^T + dO + LSE + D + max(bias, dS)
// NOTE(review): taking the max implies the groups reuse the same LDS region and are never
// live at the same time; bias and dS likewise share one slot. Confirm against the pipeline.
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
    // Per-tensor footprints.
    constexpr index_t smem_size_q    = GetSmemSizeQ<Problem>();
    constexpr index_t smem_size_qt   = GetSmemSizeQT<Problem>();
    constexpr index_t smem_size_lse  = GetSmemSizeLSE<Problem>();
    constexpr index_t smem_size_k    = GetSmemSizeK<Problem>();
    constexpr index_t smem_size_kt   = GetSmemSizeKT<Problem>();
    constexpr index_t smem_size_v    = GetSmemSizeV<Problem>();
    constexpr index_t smem_size_do   = GetSmemSizeOGrad<Problem>();
    constexpr index_t smem_size_dot  = GetSmemSizeOGradT<Problem>();
    constexpr index_t smem_size_d    = GetSmemSizeD<Problem>();
    constexpr index_t smem_size_ds   = GetSmemSizeSGrad<Problem>();
    constexpr index_t smem_size_bias = GetSmemSizeBias<Problem>();

    // Per-group footprints.
    constexpr index_t smem_size_stage0_0 = smem_size_k + smem_size_kt;
    constexpr index_t smem_size_stage0_1 = smem_size_v;
    // The original expression contained a stray unary plus ("+ +smem_size_dot"); removed.
    constexpr index_t smem_size_stage1 = smem_size_qt + smem_size_q + smem_size_dot +
                                         smem_size_do + smem_size_lse + smem_size_d +
                                         max(smem_size_bias, smem_size_ds);

    return max(smem_size_stage0_0, smem_size_stage0_1, smem_size_stage1);
}
// Emits compiler scheduling hints (__builtin_amdgcn_sched_group_barrier) that interleave MFMA
// instructions with memory instructions for each GEMM stage of the hot loop.
// GemmStagedScheduler<N>() is specialized for N = 0..4; any other stage is a no-op.
// Instruction-class masks used below (per the existing inline comments):
//   0x008 = MFMA, 0x020 = VMEM read, 0x100 = DS (LDS) read, 0x200 = DS (LDS) write.
// All instruction counts are compile-time estimates derived from the tile shape; see the
// private constants at the bottom of the struct.
template <typename Problem_>
struct HotLoopScheduler
{
using Problem = Problem_;
// Primary template: stages without an explicit specialization emit no scheduling hints.
template <index_t GemmStage>
CK_TILE_DEVICE static constexpr void GemmStagedScheduler()
{
}
// Stage 0: one VMEM read, then MFMA_PER_VMEM_READ x (1 MFMA + LDS_READ_PER_MFMA DS reads),
// repeated VMEM_READ_INST times; the MFMAs left over by the integer division follow.
template <>
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<0>()
{
// Mem: Q, LSE, OGrad, D global load, OGrad^T LDS load
// Comp: Q x K
constexpr index_t VMEM_READ_INST =
Q_VMEM_READ + OGrad_VMEM_READ + LSE_VMEM_READ + D_VMEM_READ;
constexpr index_t LDS_READ_INST = OGradT_LDS_READ;
constexpr index_t MFMA_INST = Gemm0MFMA;
// Evenly distributed to relieve SQ->TA FIFO pressure
constexpr index_t MFMA_PER_VMEM_READ = MFMA_INST / VMEM_READ_INST;
constexpr index_t MFMA_Remainder = MFMA_INST - MFMA_PER_VMEM_READ * VMEM_READ_INST;
// To hide instruction issue latency
// NOTE(review): integer division; this is 0 when LDS_READ_INST < MFMA_INST. Stages 3 and 4
// clamp the equivalent value to at least 1, this stage does not - confirm that is intended.
constexpr index_t LDS_READ_PER_MFMA = LDS_READ_INST / MFMA_INST;
static_for<0, VMEM_READ_INST, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read
static_for<0, MFMA_PER_VMEM_READ, 1>{}([&](auto j) {
ignore = j;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS read
});
});
static_for<0, MFMA_Remainder, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS read
});
}
// Stage 1: every MFMA is followed by LDS_READ_PER_MFMA DS reads.
template <>
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<1>()
{
// Mem: Q^T LDS load
// Comp: OGrad x V
constexpr index_t LDS_READ_INST = QT_LDS_READ;
constexpr index_t MFMA_INST = Gemm1MFMA;
// To hide instruction issue latency
// NOTE(review): unclamped integer division, may be 0 (see stage 0).
constexpr index_t LDS_READ_PER_MFMA = LDS_READ_INST / MFMA_INST;
static_for<0, MFMA_INST, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS read
});
}
// Stage 2: every MFMA is followed by LDS_WRITE_PER_MFMA DS writes.
template <>
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<2>()
{
// Mem: Q, QT, LSE, OGrad, OGradT, D, LDS store
// Comp: PT x OGrad
constexpr index_t LDS_WRITE_INST = Q_LDS_WRITE + QT_LDS_WRITE + OGrad_LDS_WRITE +
OGradT_LDS_WRITE + LSE_LDS_WRITE + D_LDS_WRITE;
constexpr index_t MFMA_INST = Gemm2MFMA;
// To hide instruction issue latency
// NOTE(review): unclamped integer division, may be 0 (see stage 0).
constexpr index_t LDS_WRITE_PER_MFMA = LDS_WRITE_INST / MFMA_INST;
static_for<0, MFMA_INST, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_PER_MFMA, 0); // DS write
});
}
// Stage 3: the first MFMA_INST_LDS_WRITE MFMAs are paired with DS writes, the remaining
// (MFMA_INST - MFMA_INST_LDS_WRITE) MFMAs are paired with DS reads.
template <>
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<3>()
{
// Mem: SGradT LDS store, SGrad, Q, LSE LDS load.
// Comp: SGradT x QT
constexpr index_t LDS_WRITE_INST = SGradT_LDS_WRITE;
constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P1 + Q_LDS_READ + LSE_LDS_READ;
constexpr index_t MFMA_INST = Gemm3MFMA;
// To hide instruction issue latency
// Writes per MFMA, clamped to at least 1 so the division on the next line is well defined.
constexpr index_t LDS_WRITE_PER_MFMA =
LDS_WRITE_INST / MFMA_INST >= 1 ? LDS_WRITE_INST / MFMA_INST : 1;
// Number of MFMAs needed to cover all DS writes at that rate.
constexpr index_t MFMA_INST_LDS_WRITE = LDS_WRITE_INST / LDS_WRITE_PER_MFMA;
// Reads per remaining MFMA: 0 if no MFMA remains, otherwise the quotient clamped to >= 1.
// NOTE(review): assumes MFMA_INST_LDS_WRITE <= MFMA_INST; otherwise the second static_for
// below gets a negative bound - confirm for all supported tile shapes.
constexpr index_t LDS_READ_PER_MFMA =
(MFMA_INST - MFMA_INST_LDS_WRITE) > 0
? LDS_READ_INST / (MFMA_INST - MFMA_INST_LDS_WRITE) > 0
? LDS_READ_INST / (MFMA_INST - MFMA_INST_LDS_WRITE)
: 1
: 0;
static_for<0, MFMA_INST_LDS_WRITE, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x200, LDS_WRITE_PER_MFMA, 0); // DS Write
});
static_for<0, MFMA_INST - MFMA_INST_LDS_WRITE, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS Read
});
}
// Stage 4: every MFMA is followed by LDS_READ_PER_MFMA DS reads (clamped to at least 1).
template <>
CK_TILE_DEVICE static constexpr void GemmStagedScheduler<4>()
{
// Mem: SGrad, OGrad, D LDS load.
// Comp: SGrad x KT
constexpr index_t LDS_READ_INST = SGradT_LDS_READ_P2 + OGrad_LDS_READ + D_LDS_READ;
constexpr index_t MFMA_INST = Gemm4MFMA;
// To hide instruction issue latency
constexpr index_t LDS_READ_PER_MFMA =
LDS_READ_INST / MFMA_INST > 0 ? LDS_READ_INST / MFMA_INST : 1;
static_for<0, MFMA_INST, 1>{}([&](auto i) {
ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA
__builtin_amdgcn_sched_group_barrier(0x100, LDS_READ_PER_MFMA, 0); // DS Read
});
}
private:
// Tile-shape shorthands taken from Problem.
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr index_t kM0 = Problem::BlockFmhaShape::kM0;
static constexpr index_t kN0 = Problem::BlockFmhaShape::kN0;
static constexpr index_t kQKHeaddim = Problem::BlockFmhaShape::kQKHeaddim;
static constexpr index_t kVHeaddim = Problem::BlockFmhaShape::kVHeaddim;
static constexpr index_t kK4 = Problem::BlockFmhaShape::kK4;
// Warp-level MFMA tile. M and N come from Gemm0WarpTile; K is derived from M (16 -> 16,
// otherwise 8) rather than read from the warp tile.
// NOTE(review): the same Gemm0 warp tile is used for the MFMA counts of all five GEMMs below;
// presumably every GEMM shares that warp tile - confirm.
static constexpr index_t WarpGemmM =
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<0>{});
static constexpr index_t WarpGemmN =
Problem::BlockFmhaShape::Gemm0WarpTile::at(number<1>{});
static constexpr index_t WarpGemmK = WarpGemmM == 16 ? 16 : 8;
static constexpr index_t Gemm4MWarp =
Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<0>{});
// Not referenced by any other constant or scheduler in this struct.
static constexpr index_t Gemm4NWarp =
Problem::BlockFmhaShape::Gemm4BlockWarps::at(number<1>{});
// Compute
// MFMA count per GEMM: block GEMM volume (M*N*K) divided by
// (warps per block * volume of one warp-level MFMA).
static constexpr index_t Gemm0MFMA =
kM0 * kN0 * kQKHeaddim /
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
static constexpr index_t Gemm1MFMA =
kM0 * kN0 * kVHeaddim /
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
static constexpr index_t Gemm2MFMA =
kN0 * kVHeaddim * kM0 /
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
static constexpr index_t Gemm3MFMA =
kN0 * kQKHeaddim * kM0 /
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
static constexpr index_t Gemm4MFMA =
kM0 * kQKHeaddim * kN0 /
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
// VMEM
// Global-load counts: tile elements divided by block size and by the per-access alignment
// (elements per access). LSE and D are each counted as a single load.
static constexpr index_t Q_VMEM_READ =
kM0 * kQKHeaddim / kBlockSize / GetAlignmentQ<Problem>();
static constexpr index_t OGrad_VMEM_READ =
kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>();
static constexpr index_t LSE_VMEM_READ = 1;
static constexpr index_t D_VMEM_READ = 1;
// LDS Read
// NOTE(review): the transposed reads divide by get_warp_size() while the other reads divide
// by kBlockSize; the LSE/D divisors (4 * 4 and 2 * 4) are unexplained constants keyed on
// WarpGemmM - confirm their meaning with the pipeline author.
static constexpr index_t OGradT_LDS_READ =
kM0 * kVHeaddim / get_warp_size() / GetTransposedAlignmentOGrad<Problem>();
static constexpr index_t QT_LDS_READ =
kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ<Problem>();
static constexpr index_t SGradT_LDS_READ_P1 =
kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
static constexpr index_t Q_LDS_READ =
kM0 * kQKHeaddim / kBlockSize / GetAlignmentQ<Problem>();
static constexpr index_t LSE_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
// Second part of the SGradT read: covers the (kN0 - kK4) columns not counted in P1.
static constexpr index_t SGradT_LDS_READ_P2 =
kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
static constexpr index_t OGrad_LDS_READ =
kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>();
static constexpr index_t D_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
// LDS Write
// Q_LDS_WRITE spells out Problem::kBlockSize; it is the same value as kBlockSize above.
static constexpr index_t Q_LDS_WRITE =
kM0 * kQKHeaddim / Problem::kBlockSize / GetAlignmentQ<Problem>();
static constexpr index_t QT_LDS_WRITE =
kM0 * kQKHeaddim / kBlockSize / GetTransposedAlignmentQ<Problem>();
static constexpr index_t OGrad_LDS_WRITE =
kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>();
static constexpr index_t OGradT_LDS_WRITE =
kM0 * kVHeaddim / kBlockSize / GetTransposedAlignmentOGrad<Problem>();
static constexpr index_t LSE_LDS_WRITE = 1;
static constexpr index_t D_LDS_WRITE = 1;
static constexpr index_t SGradT_LDS_WRITE = kM0 * kN0 / kBlockSize;
};
}; };
} // namespace ck_tile } // namespace ck_tile
...@@ -8,9 +8,8 @@ namespace ck_tile { ...@@ -8,9 +8,8 @@ namespace ck_tile {
// This class is used for codegen pattern matching // This class is used for codegen pattern matching
enum class BlockFmhaBwdPipelineEnum enum class BlockFmhaBwdPipelineEnum
{ {
KSKTSVR = 0, KRKTRVR_IGLP = 0,
QSKSVROGradS, KRKTRVR,
KSVR,
}; };
} // namespace ck_tile } // namespace ck_tile
...@@ -24,7 +24,9 @@ template <typename QDataType_, ...@@ -24,7 +24,9 @@ template <typename QDataType_,
typename BiasGradDataType_, typename BiasGradDataType_,
typename BlockFmhaShape_, typename BlockFmhaShape_,
bool kIsGroupMode_, bool kIsGroupMode_,
bool kIsDeterministic_,
typename FmhaMask_, typename FmhaMask_,
typename FmhaDropout_,
typename Traits_> typename Traits_>
struct BlockFmhaBwdPipelineProblem struct BlockFmhaBwdPipelineProblem
{ {
...@@ -45,10 +47,12 @@ struct BlockFmhaBwdPipelineProblem ...@@ -45,10 +47,12 @@ struct BlockFmhaBwdPipelineProblem
using BiasGradDataType = remove_cvref_t<BiasGradDataType_>; using BiasGradDataType = remove_cvref_t<BiasGradDataType_>;
using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>; using BlockFmhaShape = remove_cvref_t<BlockFmhaShape_>;
using FmhaMask = remove_cvref_t<FmhaMask_>; using FmhaMask = remove_cvref_t<FmhaMask_>;
using FmhaDropout = remove_cvref_t<FmhaDropout_>;
using Traits = remove_cvref_t<Traits_>; using Traits = remove_cvref_t<Traits_>;
static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size(); static constexpr index_t kBlockSize = BlockFmhaShape::NumWarps * get_warp_size();
static constexpr bool kIsGroupMode = kIsGroupMode_; static constexpr bool kIsGroupMode = kIsGroupMode_;
static constexpr bool kIsDeterministic = kIsDeterministic_;
// attributes from traits // attributes from traits
static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ; static constexpr bool kPadSeqLenQ = Traits::kPadSeqLenQ;
...@@ -57,7 +61,6 @@ struct BlockFmhaBwdPipelineProblem ...@@ -57,7 +61,6 @@ struct BlockFmhaBwdPipelineProblem
static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV; static constexpr bool kPadHeadDimV = Traits::kPadHeadDimV;
static constexpr auto BiasEnum = Traits::BiasEnum; static constexpr auto BiasEnum = Traits::BiasEnum;
static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad; static constexpr bool kHasBiasGrad = Traits::kHasBiasGrad;
static constexpr bool kHasDropout = Traits::kHasDropout;
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
}; };
...@@ -88,4 +91,35 @@ struct BlockFmhaBwdOGradDotOPipelineProblem ...@@ -88,4 +91,35 @@ struct BlockFmhaBwdOGradDotOPipelineProblem
static constexpr index_t kBlockPerCu = Traits::kBlockPerCu; static constexpr index_t kBlockPerCu = Traits::kBlockPerCu;
}; };
// Compile-time problem description for the "convert QGrad" pipeline.
// Bundles the accumulator and output element types, the block/tile sizes, the mode flags
// and the padding / occupancy attributes forwarded from Traits_.
template <typename AccDataType_,
          typename QGradDataType_,
          index_t kBlockSize_,
          index_t kM0_,
          index_t kN0_,
          index_t kQKHeaddim_,
          bool kIsGroupMode_,
          bool kIsDeterministic_,
          typename Traits_>
struct BlockFmhaBwdConvertQGradPipelineProblem
{
    // The block must be a positive whole number of warps.
    static_assert(0 < kBlockSize_ && kBlockSize_ % get_warp_size() == 0,
                  "kBlockSize should be divisible by get_warp_size()");

    // Element types and traits, stripped of cv/ref qualifiers.
    using Traits        = remove_cvref_t<Traits_>;
    using AccDataType   = remove_cvref_t<AccDataType_>;
    using QGradDataType = remove_cvref_t<QGradDataType_>;

    // Block and tile sizes.
    static constexpr index_t kBlockSize = kBlockSize_;
    static constexpr index_t kM0        = kM0_;
    static constexpr index_t kN0        = kN0_;
    static constexpr index_t kQKHeaddim = kQKHeaddim_;

    // Mode flags.
    static constexpr bool kIsGroupMode     = kIsGroupMode_;
    static constexpr bool kIsDeterministic = kIsDeterministic_;

    // Attributes forwarded from Traits.
    static constexpr bool kPadSeqLenQ     = Traits::kPadSeqLenQ;
    static constexpr bool kPadHeadDimQ    = Traits::kPadHeadDimQ;
    static constexpr index_t kBlockPerCu  = Traits::kBlockPerCu;
};
} // namespace ck_tile } // namespace ck_tile
...@@ -231,7 +231,9 @@ struct BlockFmhaPipelineQRKSVSAsync ...@@ -231,7 +231,9 @@ struct BlockFmhaPipelineQRKSVSAsync
// TODO: we use async Copy for K, which is inline asm // TODO: we use async Copy for K, which is inline asm
// a side effect is we have to use inline asm for q as well // a side effect is we have to use inline asm for q as well
auto q = decltype(load_tile(q_dram_window)){}; auto q = decltype(load_tile(q_dram_window)){};
set_tile(q, number<0>{}); // use per-dword clear to avoid scratch // TODO: start from rocm-6.2, compiler will have problem if manually set clear of q.
// however, q would be cleared in the constructor of static distributed tensor
// set_tile(q, number<0>{}); // use per-dword clear to avoid scratch
load_tile_raw(q, q_dram_window); load_tile_raw(q, q_dram_window);
__builtin_amdgcn_sched_barrier(0); __builtin_amdgcn_sched_barrier(0);
......
...@@ -86,4 +86,14 @@ struct TileFmhaBwdOGradDotOTraits ...@@ -86,4 +86,14 @@ struct TileFmhaBwdOGradDotOTraits
static constexpr index_t kBlockPerCu = kBlockPerCu_; static constexpr index_t kBlockPerCu = kBlockPerCu_;
}; };
// Traits for the "convert QGrad" kernel.
//   kPadSeqLenQ_  : padding for seqlen_q
//   kPadHeadDimQ_ : padding for hdim_q
//   kBlockPerCu_  : hint to occupancy (defaults to 2)
template <bool kPadSeqLenQ_, bool kPadHeadDimQ_, index_t kBlockPerCu_ = 2>
struct TileFmhaBwdConvertQGradTraits
{
    static constexpr index_t kBlockPerCu = kBlockPerCu_;
    static constexpr bool kPadHeadDimQ   = kPadHeadDimQ_;
    static constexpr bool kPadSeqLenQ    = kPadSeqLenQ_;
};
} // namespace ck_tile } // namespace ck_tile
...@@ -5,6 +5,9 @@ ...@@ -5,6 +5,9 @@
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_default_policy.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_breg_creg_v1_default_policy.hpp"
namespace ck_tile {
// A is block distributed tensor
// B is block distributed tensor
// C is block distributed tensor
// Block-level GEMM in which A, B and C are all block distributed tensors (see the comments
// above). The warp-level GEMM type and the MWarp x NWarp warp layout come from
// Policy::GetWarpGemmMWarpNWarp<Problem>().
template <typename Problem_, typename Policy_ = BlockGemmARegBRegCRegV1DefaultPolicy>
struct BlockGemmARegBRegCRegV1
{
using Problem = remove_cvref_t<Problem_>;
using Policy = remove_cvref_t<Policy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
// C += A * B
// Accumulates into c_block_tensor. The element types and the tile distributions of all three
// tensors are checked at compile time against what this block GEMM expects.
template <typename CBlockTensor, typename ABlockTensor, typename BBlockTensor>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ABlockTensor& a_block_tensor,
const BBlockTensor& b_block_tensor) const
{
static_assert(std::is_same_v<ADataType, remove_cv_t<typename ABlockTensor::DataType>> &&
std::is_same_v<BDataType, remove_cv_t<typename BBlockTensor::DataType>> &&
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!");
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr index_t KPerBlock = BlockGemmShape::kK;
// config = (warp GEMM instance, MWarp, NWarp)
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
// Number of warp-GEMM steps each warp performs along M, N and K.
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
// M->N Warp
// Expected outer (block-level) encodings for A, B and C; the warp-level encodings of WG are
// embedded into them below.
constexpr auto a_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<NWarp>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<KIterPerWarp>>,
tuple<sequence<1, 0>>,
tuple<sequence<1, 0>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto b_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{});
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{});
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
// check ABC-block-distribution
// The encoding types must match exactly; a caller passing a differently distributed tensor
// fails to compile here.
static_assert(
std::is_same_v<remove_cvref_t<decltype(a_block_dstr_encode)>,
remove_cvref_t<decltype(ABlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"A distribution is wrong!");
static_assert(
std::is_same_v<remove_cvref_t<decltype(b_block_dstr_encode)>,
remove_cvref_t<decltype(BBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"B distribution is wrong!");
static_assert(
std::is_same_v<remove_cvref_t<decltype(c_block_dstr_encode)>,
remove_cvref_t<decltype(CBlockTensor::get_tile_distribution()
.get_static_tile_distribution_encoding())>>,
"C distribution is wrong!");
using AWarpDstr = typename WG::AWarpDstr;
using BWarpDstr = typename WG::BWarpDstr;
using CWarpDstr = typename WG::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor;
using BWarpTensor = typename WG::BWarpTensor;
using CWarpTensor = typename WG::CWarpTensor;
// Per-warp Y lengths and all-zero Y origins, used to slice one warp tile out of a block tile.
constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto b_warp_y_lengths =
to_sequence(BWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto a_warp_y_index_zeros = uniform_sequence_gen_t<AWarpDstr::NDimY, 0>{};
constexpr auto b_warp_y_index_zeros = uniform_sequence_gen_t<BWarpDstr::NDimY, 0>{};
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
// Fully unrolled over (kIter, mIter, nIter). The A slice is read once per (kIter, mIter) and
// reused for every nIter; B and C slices are read per nIter, and C is written back after
// each warp GEMM.
static_for<0, KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A Block window
AWarpTensor a_warp_tensor;
a_warp_tensor.get_thread_buffer() = a_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, kIter>{}, a_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, a_warp_y_lengths));
static_for<0, NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B block tensor
BWarpTensor b_warp_tensor;
b_warp_tensor.get_thread_buffer() = b_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<nIter, kIter>{}, b_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, b_warp_y_lengths));
// read C warp tensor from C block tensor
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
// Creates a C block tensor of CDataType with the distribution this block GEMM expects for C
// (the same outer encoding as in operator() above, embedded with WG's C warp encoding).
CK_TILE_DEVICE constexpr auto MakeCBlockTile() const
{
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = config.template at<1>();
constexpr index_t NWarp = config.template at<2>();
constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM);
constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN);
// constexpr index_t KIterPerWarp = KPerBlock / WG::kK;
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
// C = A * B
// Builds a fresh C tile with MakeCBlockTile() and runs the accumulating overload on it.
// NOTE(review): the result equals A * B only if the tensor returned by MakeCBlockTile()
// starts out zeroed; this overload does not clear it explicitly - confirm.
template <typename ABlockTensor, typename BBlockTensor>
CK_TILE_DEVICE auto operator()(const ABlockTensor& a_block_tensor,
const BBlockTensor& b_block_tensor) const
{
auto c_block_tensor = MakeCBlockTile();
operator()(c_block_tensor, a_block_tensor, b_block_tensor);
return c_block_tensor;
}
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
// User-supplied policy for BlockGemmARegBRegCRegV1: the caller fixes the element types, the
// warp layout of the block (BlockWarps_ = M, N, K warp counts) and the warp-level GEMM.
template <typename AType_,
          typename BType_,
          typename CType_,
          typename BlockWarps_,
          typename WarpGemm_>
struct BlockGemmARegBRegCRegV1CustomPolicy
{
    // Element types, stripped of cv/ref qualifiers.
    using AType = remove_cvref_t<AType_>;
    using BType = remove_cvref_t<BType_>;
    using CType = remove_cvref_t<CType_>;

    // Warp-level GEMM and warp layout.
    using WarpGemm   = remove_cvref_t<WarpGemm_>;
    using BlockWarps = remove_cvref_t<BlockWarps_>;

    static constexpr index_t kMWarps = BlockWarps::at(number<0>{});
    static constexpr index_t kNWarps = BlockWarps::at(number<1>{});
    static constexpr index_t kKWarps = BlockWarps::at(number<2>{});

    // Returns (warp GEMM instance, warps along M, warps along N). Problem is unused here;
    // the template parameter keeps the signature uniform with the default policy.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
    {
        return make_tuple(WarpGemm{}, kMWarps, kNWarps);
    }
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
namespace ck_tile {
// Default policy for BlockGemmARegBRegCRegV1
// Default policy class should not be templated, put template on member functions instead
struct BlockGemmARegBRegCRegV1DefaultPolicy
{
    // Selects the warp-level GEMM and warp layout for the given Problem.
    // Returns a tuple (warp GEMM instance, MWarp, NWarp):
    //   fp16 x fp16 -> fp32 : WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution, 4 x 1 warps
    //   bf16 x bf16 -> fp32 : WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution, 4 x 1 warps
    // Any other type combination is rejected at compile time. Without the final branch the
    // function would deduce `void` and callers would fail later with an unrelated-looking
    // error when they index into the result.
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetWarpGemmMWarpNWarp()
    {
        if constexpr(std::is_same_v<typename Problem::ADataType, half_t> &&
                     std::is_same_v<typename Problem::BDataType, half_t> &&
                     std::is_same_v<typename Problem::CDataType, float>)
        {
            return make_tuple(WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution{}, 4, 1);
        }
        else if constexpr(std::is_same_v<typename Problem::ADataType, bf16_t> &&
                          std::is_same_v<typename Problem::BDataType, bf16_t> &&
                          std::is_same_v<typename Problem::CDataType, float>)
        {
            return make_tuple(WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution{}, 4, 1);
        }
        else
        {
            // The condition depends on Problem so it is only evaluated when this branch is
            // actually instantiated.
            static_assert(!std::is_same_v<Problem, Problem>,
                          "BlockGemmARegBRegCRegV1DefaultPolicy: unsupported A/B/C data type "
                          "combination; provide a custom policy");
        }
    }
};
} // namespace ck_tile
...@@ -35,16 +35,13 @@ struct BlockGemmARegBSmemCRegV1 ...@@ -35,16 +35,13 @@ struct BlockGemmARegBSmemCRegV1
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>, std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!"); "wrong!");
// constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}]; constexpr index_t MPerBlock = ABlockTensorTmp{}.get_lengths()[number<0>{}];
// constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}];
// constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}]; constexpr index_t KPerBlock = ABlockTensorTmp{}.get_lengths()[number<1>{}];
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr index_t KPerBlock = BlockGemmShape::kK;
// static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
// KPerBlock == BlockGemmShape::kK, KPerBlock == BlockGemmShape::kK,
// "wrong!"); "wrong!");
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>(); constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
......
...@@ -35,16 +35,13 @@ struct BlockGemmASmemBRegCRegV1 ...@@ -35,16 +35,13 @@ struct BlockGemmASmemBRegCRegV1
std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>, std::is_same_v<CDataType, remove_cv_t<typename CBlockTensor::DataType>>,
"wrong!"); "wrong!");
// constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}]; constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}];
// constexpr index_t NPerBlock = BBlockTensorTmp{}.get_lengths()[number<0>{}]; constexpr index_t NPerBlock = BBlockTensorTmp{}.get_lengths()[number<0>{}];
// constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}]; constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}];
constexpr index_t MPerBlock = BlockGemmShape::kM;
constexpr index_t NPerBlock = BlockGemmShape::kN;
constexpr index_t KPerBlock = BlockGemmShape::kK;
// static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
// KPerBlock == BlockGemmShape::kK, KPerBlock == BlockGemmShape::kK,
// "wrong!"); "wrong!");
constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>(); constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
......
...@@ -22,6 +22,9 @@ using WarpGemmMfmaF16F16F32M32N32K16 = ...@@ -22,6 +22,9 @@ using WarpGemmMfmaF16F16F32M32N32K16 =
using WarpGemmMfmaF16F16F32M16N16K32 = using WarpGemmMfmaF16F16F32M16N16K32 =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplF16F16F32M16N16K16, 2>>; WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplF16F16F32M16N16K16, 2>>;
using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl<
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 1>>;
using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl< using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl<
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 2>>; WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplF16F16F32M32N32K8, 2>>;
...@@ -59,6 +62,9 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16 = ...@@ -59,6 +62,9 @@ using WarpGemmMfmaBf16Bf16F32M32N32K16 =
using WarpGemmMfmaBf16Bf16F32M16N16K32 = using WarpGemmMfmaBf16Bf16F32M16N16K32 =
WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16, 2>>; WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16, 2>>;
using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl<
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 1>>;
using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA = WarpGemmImpl< using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA = WarpGemmImpl<
WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 2>>; WarpGemmAtrributeMfmaIterateK_SwizzleA<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8, 2>>;
......
...@@ -119,9 +119,9 @@ struct WarpGemmAtrributeMfmaIterateK ...@@ -119,9 +119,9 @@ struct WarpGemmAtrributeMfmaIterateK
static_for<0, kKIter, 1>{}([&](auto iKIter) { static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec, Impl{}(c_vec,
reinterpret_cast<const buf_a>(a_vec) reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter], .template get_as<typename Impl::AVecType>()[iKIter],
reinterpret_cast<const buf_b>(b_vec) reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter]); .template get_as<typename Impl::BVecType>()[iKIter]);
}); });
} }
...@@ -135,15 +135,15 @@ struct WarpGemmAtrributeMfmaIterateK ...@@ -135,15 +135,15 @@ struct WarpGemmAtrributeMfmaIterateK
// c = a * b // c = a * b
auto c_vec = Impl{}( auto c_vec = Impl{}(
reinterpret_cast<const buf_a>(a_vec).template get_as<typename Impl::AVecType>()[I0], reinterpret_cast<const buf_a&>(a_vec).template get_as<typename Impl::AVecType>()[I0],
reinterpret_cast<const buf_b>(b_vec).template get_as<typename Impl::BVecType>()[I0]); reinterpret_cast<const buf_b&>(b_vec).template get_as<typename Impl::BVecType>()[I0]);
// c += a * b // c += a * b
static_for<1, kKIter, 1>{}([&](auto iKIter) { static_for<1, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec, Impl{}(c_vec,
reinterpret_cast<const buf_a>(a_vec) reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter], .template get_as<typename Impl::AVecType>()[iKIter],
reinterpret_cast<const buf_b>(b_vec) reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter]); .template get_as<typename Impl::BVecType>()[iKIter]);
}); });
......
...@@ -15,7 +15,8 @@ template <typename AType, ...@@ -15,7 +15,8 @@ template <typename AType,
index_t MPerWave, index_t MPerWave,
index_t NPerWave, index_t NPerWave,
index_t KPerWave, index_t KPerWave,
bool TransposeC> bool TransposeC,
bool SwizzleA = false>
struct WarpGemmMfmaDispatcher; struct WarpGemmMfmaDispatcher;
// clang-format off // clang-format off
...@@ -29,6 +30,9 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float ...@@ -29,6 +30,9 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32; }; template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
// bf16 // bf16
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; }; template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
...@@ -39,6 +43,9 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float ...@@ -39,6 +43,9 @@ template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; }; template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; }; template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
// fp8 // fp8
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; }; template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; }; template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; };
...@@ -58,8 +65,15 @@ template <typename AType, ...@@ -58,8 +65,15 @@ template <typename AType,
index_t MPerWave, index_t MPerWave,
index_t NPerWave, index_t NPerWave,
index_t KPerWave, index_t KPerWave,
bool TransposeC> bool TransposeC,
using WarpGemmMfmaDispatcher = typename impl:: bool SwizzleA = false>
WarpGemmMfmaDispatcher<AType, BType, CType, MPerWave, NPerWave, KPerWave, TransposeC>::Type; using WarpGemmMfmaDispatcher = typename impl::WarpGemmMfmaDispatcher<AType,
BType,
CType,
MPerWave,
NPerWave,
KPerWave,
TransposeC,
SwizzleA>::Type;
} // namespace ck_tile } // namespace ck_tile
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2023-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -39,11 +39,11 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -39,11 +39,11 @@ struct ReferenceColumnToImage : public device::BaseOperator
public: public:
Argument(const Tensor<InDataType>& input, Argument(const Tensor<InDataType>& input,
Tensor<OutDataType>& output, Tensor<OutDataType>& output,
std::vector<ck::index_t> filter_spatial_lengths, std::vector<ck::long_index_t> filter_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads) std::vector<ck::long_index_t> input_right_pads)
: input_{input}, : input_{input},
output_{output}, output_{output},
conv_strides_{conv_filter_strides}, conv_strides_{conv_filter_strides},
...@@ -58,24 +58,25 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -58,24 +58,25 @@ struct ReferenceColumnToImage : public device::BaseOperator
const Tensor<InDataType>& input_; const Tensor<InDataType>& input_;
Tensor<OutDataType>& output_; Tensor<OutDataType>& output_;
std::vector<index_t> conv_strides_; std::vector<long_index_t> conv_strides_;
std::vector<index_t> conv_dilations_; std::vector<long_index_t> conv_dilations_;
std::vector<index_t> in_left_pads_; std::vector<long_index_t> in_left_pads_;
std::vector<index_t> in_right_pads_; std::vector<long_index_t> in_right_pads_;
std::vector<index_t> filter_spatial_lengths_; std::vector<long_index_t> filter_spatial_lengths_;
std::vector<index_t> output_spatial_lengths_; std::vector<long_index_t> output_spatial_lengths_;
private: private:
void initOutputSpatialLengths() void initOutputSpatialLengths()
{ {
constexpr auto input_offset_to_spatial = 3; constexpr auto input_offset_to_spatial = 3;
for(ck::index_t i = 0; i < NDimSpatial; ++i) for(ck::long_index_t i = 0; i < NDimSpatial; ++i)
{ {
// XEff = (X - 1) * conv_dilation_w + 1; // XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const ck::index_t x_eff = (filter_spatial_lengths_[i] - 1) * conv_dilations_[i] + 1; const ck::long_index_t x_eff =
(filter_spatial_lengths_[i] - 1) * conv_dilations_[i] + 1;
output_spatial_lengths_.push_back( output_spatial_lengths_.push_back(
(output_.GetLengths()[i + input_offset_to_spatial] + in_left_pads_[i] + (output_.GetLengths()[i + input_offset_to_spatial] + in_left_pads_[i] +
...@@ -98,26 +99,26 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -98,26 +99,26 @@ struct ReferenceColumnToImage : public device::BaseOperator
throw std::runtime_error("wrong! inconsistent dimension"); throw std::runtime_error("wrong! inconsistent dimension");
} }
const index_t G = arg.output_.GetLengths()[0]; const long_index_t G = arg.output_.GetLengths()[0];
const index_t N = arg.output_.GetLengths()[1]; const long_index_t N = arg.output_.GetLengths()[1];
const index_t C = arg.output_.GetLengths()[2]; const long_index_t C = arg.output_.GetLengths()[2];
if constexpr(NDimSpatial == 1) if constexpr(NDimSpatial == 1)
{ {
const index_t Wo = arg.output_spatial_lengths_[0]; const long_index_t Wo = arg.output_spatial_lengths_[0];
auto func = [&](auto g, auto n) { auto func = [&](auto g, auto n) {
for(index_t wo = 0; wo < Wo; ++wo) for(long_index_t wo = 0; wo < Wo; ++wo)
{ {
index_t row = n * Wo + wo; long_index_t row = n * Wo + wo;
index_t column = 0; long_index_t column = 0;
for(index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x) for(long_index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x)
{ {
auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) + auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) - static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]); static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(index_t c = 0; c < C; ++c) for(long_index_t c = 0; c < C; ++c)
{ {
if(wi >= 0 && if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.output_.GetLengths()[3]) ck::type_convert<std::size_t>(wi) < arg.output_.GetLengths()[3])
...@@ -140,32 +141,32 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -140,32 +141,32 @@ struct ReferenceColumnToImage : public device::BaseOperator
} }
else if constexpr(NDimSpatial == 2) else if constexpr(NDimSpatial == 2)
{ {
const index_t Ho = arg.output_spatial_lengths_[0]; const long_index_t Ho = arg.output_spatial_lengths_[0];
const index_t Wo = arg.output_spatial_lengths_[1]; const long_index_t Wo = arg.output_spatial_lengths_[1];
auto func = [&](auto g, auto n) { auto func = [&](auto g, auto n) {
for(index_t ho = 0; ho < Ho; ++ho) for(long_index_t ho = 0; ho < Ho; ++ho)
{ {
for(index_t wo = 0; wo < Wo; ++wo) for(long_index_t wo = 0; wo < Wo; ++wo)
{ {
index_t row = n * Ho * Wo + ho * Wo + wo; long_index_t row = n * Ho * Wo + ho * Wo + wo;
index_t column = 0; long_index_t column = 0;
for(index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y) for(long_index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y)
{ {
auto hi = auto hi =
static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) + static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) - static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]); static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(index_t x = 0; x < arg.filter_spatial_lengths_[1]; ++x) for(long_index_t x = 0; x < arg.filter_spatial_lengths_[1]; ++x)
{ {
auto wi = auto wi =
static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) + static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]) - static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[1]); static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
for(index_t c = 0; c < C; ++c) for(long_index_t c = 0; c < C; ++c)
{ {
if(hi >= 0 && if(hi >= 0 &&
...@@ -196,27 +197,27 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -196,27 +197,27 @@ struct ReferenceColumnToImage : public device::BaseOperator
} }
else if constexpr(NDimSpatial == 3) else if constexpr(NDimSpatial == 3)
{ {
const index_t Do = arg.output_spatial_lengths_[0]; const long_index_t Do = arg.output_spatial_lengths_[0];
const index_t Ho = arg.output_spatial_lengths_[1]; const long_index_t Ho = arg.output_spatial_lengths_[1];
const index_t Wo = arg.output_spatial_lengths_[2]; const long_index_t Wo = arg.output_spatial_lengths_[2];
auto func = [&](auto g, auto n) { auto func = [&](auto g, auto n) {
for(index_t d_o = 0; d_o < Do; ++d_o) for(long_index_t d_o = 0; d_o < Do; ++d_o)
{ {
for(index_t ho = 0; ho < Ho; ++ho) for(long_index_t ho = 0; ho < Ho; ++ho)
{ {
for(index_t wo = 0; wo < Wo; ++wo) for(long_index_t wo = 0; wo < Wo; ++wo)
{ {
index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo; long_index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
index_t column = 0; long_index_t column = 0;
for(index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z) for(long_index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z)
{ {
auto di = auto di =
static_cast<ck::long_index_t>(d_o * arg.conv_strides_[0]) + static_cast<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) - static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]); static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(index_t y = 0; y < arg.filter_spatial_lengths_[1]; ++y) for(long_index_t y = 0; y < arg.filter_spatial_lengths_[1]; ++y)
{ {
auto hi = auto hi =
static_cast<ck::long_index_t>(ho * static_cast<ck::long_index_t>(ho *
...@@ -224,7 +225,8 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -224,7 +225,8 @@ struct ReferenceColumnToImage : public device::BaseOperator
static_cast<ck::long_index_t>(y * static_cast<ck::long_index_t>(y *
arg.conv_dilations_[1]) - arg.conv_dilations_[1]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[1]); static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
for(index_t x = 0; x < arg.filter_spatial_lengths_[2]; ++x) for(long_index_t x = 0; x < arg.filter_spatial_lengths_[2];
++x)
{ {
auto wi = auto wi =
static_cast<ck::long_index_t>( static_cast<ck::long_index_t>(
...@@ -232,7 +234,7 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -232,7 +234,7 @@ struct ReferenceColumnToImage : public device::BaseOperator
static_cast<ck::long_index_t>( static_cast<ck::long_index_t>(
x * arg.conv_dilations_[2]) - x * arg.conv_dilations_[2]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[2]); static_cast<ck::long_index_t>(arg.in_left_pads_[2]);
for(index_t c = 0; c < C; ++c) for(long_index_t c = 0; c < C; ++c)
{ {
if(di >= 0 && if(di >= 0 &&
ck::type_convert<std::size_t>(di) < ck::type_convert<std::size_t>(di) <
...@@ -294,15 +296,15 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -294,15 +296,15 @@ struct ReferenceColumnToImage : public device::BaseOperator
bool IsSupportedArgument(const Argument& arg) bool IsSupportedArgument(const Argument& arg)
{ {
const ck::index_t G = arg.output_.GetLengths()[0]; const ck::long_index_t G = arg.output_.GetLengths()[0];
const ck::index_t N = arg.output_.GetLengths()[1]; const ck::long_index_t N = arg.output_.GetLengths()[1];
const ck::index_t C = arg.output_.GetLengths()[2]; const ck::long_index_t C = arg.output_.GetLengths()[2];
const index_t NDoHoWo = const long_index_t NDoHoWo =
N * ck::accumulate_n<index_t>( N * ck::accumulate_n<long_index_t>(
arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const index_t CZYX = const long_index_t CZYX =
C * ck::accumulate_n<index_t>( C * ck::accumulate_n<long_index_t>(
arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
if(!(arg.input_.GetLengths()[0] == static_cast<std::size_t>(G) && if(!(arg.input_.GetLengths()[0] == static_cast<std::size_t>(G) &&
...@@ -326,11 +328,11 @@ struct ReferenceColumnToImage : public device::BaseOperator ...@@ -326,11 +328,11 @@ struct ReferenceColumnToImage : public device::BaseOperator
static auto MakeArgument(const Tensor<InDataType>& input, static auto MakeArgument(const Tensor<InDataType>& input,
Tensor<OutDataType>& output, Tensor<OutDataType>& output,
std::vector<ck::index_t> filter_spatial_lengths, std::vector<ck::long_index_t> filter_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads) std::vector<ck::long_index_t> input_right_pads)
{ {
return Argument{input, return Argument{input,
output, output,
......
...@@ -38,10 +38,10 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -38,10 +38,10 @@ struct ReferenceConvBwdData : public device::BaseOperator
Tensor<InDataType>& input, Tensor<InDataType>& input,
const Tensor<WeiDataType>& weight, const Tensor<WeiDataType>& weight,
const Tensor<OutDataType>& output, const Tensor<OutDataType>& output,
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::long_index_t> input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op, OutElementwiseOperation out_element_op,
...@@ -72,10 +72,10 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -72,10 +72,10 @@ struct ReferenceConvBwdData : public device::BaseOperator
const std::array<Tensor<WeiDataType>, NumBElementwiseTensor>& elementwise_b_tensors_; const std::array<Tensor<WeiDataType>, NumBElementwiseTensor>& elementwise_b_tensors_;
const std::array<Tensor<OutDataType>, NumDElementwiseTensor>& elementwise_d_tensors_; const std::array<Tensor<OutDataType>, NumDElementwiseTensor>& elementwise_d_tensors_;
std::vector<index_t> conv_strides_; std::vector<long_index_t> conv_strides_;
std::vector<index_t> conv_dilations_; std::vector<long_index_t> conv_dilations_;
std::vector<index_t> in_left_pads_; std::vector<long_index_t> in_left_pads_;
std::vector<index_t> in_right_pads_; std::vector<long_index_t> in_right_pads_;
InElementwiseOperation in_element_op_; InElementwiseOperation in_element_op_;
WeiElementwiseOperation wei_element_op_; WeiElementwiseOperation wei_element_op_;
...@@ -447,10 +447,10 @@ struct ReferenceConvBwdData : public device::BaseOperator ...@@ -447,10 +447,10 @@ struct ReferenceConvBwdData : public device::BaseOperator
Tensor<InDataType>& input, Tensor<InDataType>& input,
const Tensor<WeiDataType>& weight, const Tensor<WeiDataType>& weight,
const Tensor<OutDataType>& output, const Tensor<OutDataType>& output,
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::long_index_t> input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op, OutElementwiseOperation out_element_op,
......
...@@ -40,10 +40,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -40,10 +40,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
const Tensor<InDataType>& in_n_c_hi_wi, const Tensor<InDataType>& in_n_c_hi_wi,
Tensor<WeiDataType>& wei_k_c_y_x, Tensor<WeiDataType>& wei_k_c_y_x,
const Tensor<OutDataType>& out_n_k_ho_wo, const Tensor<OutDataType>& out_n_k_ho_wo,
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::long_index_t> input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op, OutElementwiseOperation out_element_op,
...@@ -74,10 +74,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -74,10 +74,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
const std::array<Tensor<InDataType>, NumBElementwiseTensor>& elementwise_b_tensors_; const std::array<Tensor<InDataType>, NumBElementwiseTensor>& elementwise_b_tensors_;
const std::array<Tensor<WeiDataType>, NumDElementwiseTensor>& elementwise_d_tensors_; const std::array<Tensor<WeiDataType>, NumDElementwiseTensor>& elementwise_d_tensors_;
std::vector<index_t> conv_strides_; std::vector<long_index_t> conv_strides_;
std::vector<index_t> conv_dilations_; std::vector<long_index_t> conv_dilations_;
std::vector<index_t> in_left_pads_; std::vector<long_index_t> in_left_pads_;
std::vector<index_t> in_right_pads_; std::vector<long_index_t> in_right_pads_;
InElementwiseOperation in_element_op_; InElementwiseOperation in_element_op_;
WeiElementwiseOperation wei_element_op_; WeiElementwiseOperation wei_element_op_;
...@@ -402,10 +402,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator ...@@ -402,10 +402,10 @@ struct ReferenceConvBwdWeight : public device::BaseOperator
const Tensor<InDataType>& in_n_c_hi_wi, const Tensor<InDataType>& in_n_c_hi_wi,
Tensor<WeiDataType>& wei_k_c_y_x, Tensor<WeiDataType>& wei_k_c_y_x,
const Tensor<OutDataType>& out_n_k_ho_wo, const Tensor<OutDataType>& out_n_k_ho_wo,
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::long_index_t> input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op, OutElementwiseOperation out_element_op,
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -69,10 +69,10 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -69,10 +69,10 @@ struct ReferenceConvFwd : public device::BaseOperator
const Tensor<InDataType>& input, const Tensor<InDataType>& input,
const Tensor<WeiDataType>& weight, const Tensor<WeiDataType>& weight,
Tensor<OutDataType>& output, Tensor<OutDataType>& output,
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::long_index_t> input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op, OutElementwiseOperation out_element_op,
...@@ -103,10 +103,10 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -103,10 +103,10 @@ struct ReferenceConvFwd : public device::BaseOperator
const std::array<Tensor<WeiDataType>, NumBElementwiseTensor>& elementwise_b_tensors_; const std::array<Tensor<WeiDataType>, NumBElementwiseTensor>& elementwise_b_tensors_;
const std::array<Tensor<OutDataType>, NumDElementwiseTensor>& elementwise_d_tensors_; const std::array<Tensor<OutDataType>, NumDElementwiseTensor>& elementwise_d_tensors_;
std::vector<index_t> conv_strides_; std::vector<ck::long_index_t> conv_strides_;
std::vector<index_t> conv_dilations_; std::vector<ck::long_index_t> conv_dilations_;
std::vector<index_t> in_left_pads_; std::vector<ck::long_index_t> in_left_pads_;
std::vector<index_t> in_right_pads_; std::vector<ck::long_index_t> in_right_pads_;
InElementwiseOperation in_element_op_; InElementwiseOperation in_element_op_;
WeiElementwiseOperation wei_element_op_; WeiElementwiseOperation wei_element_op_;
...@@ -416,10 +416,10 @@ struct ReferenceConvFwd : public device::BaseOperator ...@@ -416,10 +416,10 @@ struct ReferenceConvFwd : public device::BaseOperator
const Tensor<InDataType>& input, const Tensor<InDataType>& input,
const Tensor<WeiDataType>& weight, const Tensor<WeiDataType>& weight,
Tensor<OutDataType>& output, Tensor<OutDataType>& output,
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::long_index_t> input_right_pads,
InElementwiseOperation in_element_op, InElementwiseOperation in_element_op,
WeiElementwiseOperation wei_element_op, WeiElementwiseOperation wei_element_op,
OutElementwiseOperation out_element_op, OutElementwiseOperation out_element_op,
......
...@@ -40,11 +40,11 @@ struct ReferenceImageToColumn : public device::BaseOperator ...@@ -40,11 +40,11 @@ struct ReferenceImageToColumn : public device::BaseOperator
public: public:
Argument(const Tensor<InDataType>& input, Argument(const Tensor<InDataType>& input,
Tensor<OutDataType>& output, Tensor<OutDataType>& output,
std::vector<ck::index_t> filter_spatial_lengths, std::vector<ck::long_index_t> filter_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads) std::vector<ck::long_index_t> input_right_pads)
: input_{input}, : input_{input},
output_{output}, output_{output},
conv_strides_{conv_filter_strides}, conv_strides_{conv_filter_strides},
...@@ -59,13 +59,13 @@ struct ReferenceImageToColumn : public device::BaseOperator ...@@ -59,13 +59,13 @@ struct ReferenceImageToColumn : public device::BaseOperator
const Tensor<InDataType>& input_; const Tensor<InDataType>& input_;
Tensor<OutDataType>& output_; Tensor<OutDataType>& output_;
std::vector<index_t> conv_strides_; std::vector<long_index_t> conv_strides_;
std::vector<index_t> conv_dilations_; std::vector<long_index_t> conv_dilations_;
std::vector<index_t> in_left_pads_; std::vector<long_index_t> in_left_pads_;
std::vector<index_t> in_right_pads_; std::vector<long_index_t> in_right_pads_;
std::vector<index_t> filter_spatial_lengths_; std::vector<long_index_t> filter_spatial_lengths_;
std::vector<index_t> output_spatial_lengths_; std::vector<long_index_t> output_spatial_lengths_;
private: private:
void initOutputSpatialLengths() void initOutputSpatialLengths()
...@@ -76,7 +76,8 @@ struct ReferenceImageToColumn : public device::BaseOperator ...@@ -76,7 +76,8 @@ struct ReferenceImageToColumn : public device::BaseOperator
{ {
// XEff = (X - 1) * conv_dilation_w + 1; // XEff = (X - 1) * conv_dilation_w + 1;
// Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1; // Wo = (Wi + in_left_pad_w + in_right_pad_w - XEff) / conv_stride_w + 1;
const ck::index_t x_eff = (filter_spatial_lengths_[i] - 1) * conv_dilations_[i] + 1; const ck::long_index_t x_eff =
(filter_spatial_lengths_[i] - 1) * conv_dilations_[i] + 1;
output_spatial_lengths_.push_back( output_spatial_lengths_.push_back(
(input_.GetLengths()[i + input_offset_to_spatial] + in_left_pads_[i] + (input_.GetLengths()[i + input_offset_to_spatial] + in_left_pads_[i] +
...@@ -99,24 +100,24 @@ struct ReferenceImageToColumn : public device::BaseOperator ...@@ -99,24 +100,24 @@ struct ReferenceImageToColumn : public device::BaseOperator
throw std::runtime_error("wrong! inconsistent dimension"); throw std::runtime_error("wrong! inconsistent dimension");
} }
const index_t G = arg.input_.GetLengths()[0]; const long_index_t G = arg.input_.GetLengths()[0];
const index_t N = arg.input_.GetLengths()[1]; const long_index_t N = arg.input_.GetLengths()[1];
const index_t C = arg.input_.GetLengths()[2]; const long_index_t C = arg.input_.GetLengths()[2];
if constexpr(NDimSpatial == 1) if constexpr(NDimSpatial == 1)
{ {
const index_t Wo = arg.output_spatial_lengths_[0]; const long_index_t Wo = arg.output_spatial_lengths_[0];
auto func = [&](auto g, auto n, auto wo) { auto func = [&](auto g, auto n, auto wo) {
index_t row = n * Wo + wo; long_index_t row = n * Wo + wo;
index_t column = 0; long_index_t column = 0;
for(index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x) for(long_index_t x = 0; x < arg.filter_spatial_lengths_[0]; ++x)
{ {
auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) + auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) - static_cast<ck::long_index_t>(x * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]); static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(index_t c = 0; c < C; ++c) for(long_index_t c = 0; c < C; ++c)
{ {
if(wi >= 0 && if(wi >= 0 &&
ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3]) ck::type_convert<std::size_t>(wi) < arg.input_.GetLengths()[3])
...@@ -135,26 +136,26 @@ struct ReferenceImageToColumn : public device::BaseOperator ...@@ -135,26 +136,26 @@ struct ReferenceImageToColumn : public device::BaseOperator
} }
else if constexpr(NDimSpatial == 2) else if constexpr(NDimSpatial == 2)
{ {
const index_t Ho = arg.output_spatial_lengths_[0]; const long_index_t Ho = arg.output_spatial_lengths_[0];
const index_t Wo = arg.output_spatial_lengths_[1]; const long_index_t Wo = arg.output_spatial_lengths_[1];
auto func = [&](auto g, auto n, auto ho, auto wo) { auto func = [&](auto g, auto n, auto ho, auto wo) {
index_t row = n * Ho * Wo + ho * Wo + wo; long_index_t row = n * Ho * Wo + ho * Wo + wo;
index_t column = 0; long_index_t column = 0;
for(index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y) for(long_index_t y = 0; y < arg.filter_spatial_lengths_[0]; ++y)
{ {
auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) + auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) - static_cast<ck::long_index_t>(y * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]); static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(index_t x = 0; x < arg.filter_spatial_lengths_[1]; ++x) for(long_index_t x = 0; x < arg.filter_spatial_lengths_[1]; ++x)
{ {
auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) + auto wi = static_cast<ck::long_index_t>(wo * arg.conv_strides_[1]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]) - static_cast<ck::long_index_t>(x * arg.conv_dilations_[1]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[1]); static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
for(index_t c = 0; c < C; ++c) for(long_index_t c = 0; c < C; ++c)
{ {
if(hi >= 0 && if(hi >= 0 &&
...@@ -178,31 +179,31 @@ struct ReferenceImageToColumn : public device::BaseOperator ...@@ -178,31 +179,31 @@ struct ReferenceImageToColumn : public device::BaseOperator
} }
else if constexpr(NDimSpatial == 3) else if constexpr(NDimSpatial == 3)
{ {
const index_t Do = arg.output_spatial_lengths_[0]; const long_index_t Do = arg.output_spatial_lengths_[0];
const index_t Ho = arg.output_spatial_lengths_[1]; const long_index_t Ho = arg.output_spatial_lengths_[1];
const index_t Wo = arg.output_spatial_lengths_[2]; const long_index_t Wo = arg.output_spatial_lengths_[2];
auto func = [&](auto g, auto n, auto d_o, auto ho, auto wo) { auto func = [&](auto g, auto n, auto d_o, auto ho, auto wo) {
index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo; long_index_t row = n * Do * Ho * Wo + d_o * Ho * Wo + ho * Wo + wo;
index_t column = 0; long_index_t column = 0;
for(index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z) for(long_index_t z = 0; z < arg.filter_spatial_lengths_[0]; ++z)
{ {
auto di = static_cast<ck::long_index_t>(d_o * arg.conv_strides_[0]) + auto di = static_cast<ck::long_index_t>(d_o * arg.conv_strides_[0]) +
static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) - static_cast<ck::long_index_t>(z * arg.conv_dilations_[0]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[0]); static_cast<ck::long_index_t>(arg.in_left_pads_[0]);
for(index_t y = 0; y < arg.filter_spatial_lengths_[1]; ++y) for(long_index_t y = 0; y < arg.filter_spatial_lengths_[1]; ++y)
{ {
auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[1]) + auto hi = static_cast<ck::long_index_t>(ho * arg.conv_strides_[1]) +
static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]) - static_cast<ck::long_index_t>(y * arg.conv_dilations_[1]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[1]); static_cast<ck::long_index_t>(arg.in_left_pads_[1]);
for(index_t x = 0; x < arg.filter_spatial_lengths_[2]; ++x) for(long_index_t x = 0; x < arg.filter_spatial_lengths_[2]; ++x)
{ {
auto wi = auto wi =
static_cast<ck::long_index_t>(wo * arg.conv_strides_[2]) + static_cast<ck::long_index_t>(wo * arg.conv_strides_[2]) +
static_cast<ck::long_index_t>(x * arg.conv_dilations_[2]) - static_cast<ck::long_index_t>(x * arg.conv_dilations_[2]) -
static_cast<ck::long_index_t>(arg.in_left_pads_[2]); static_cast<ck::long_index_t>(arg.in_left_pads_[2]);
for(index_t c = 0; c < C; ++c) for(long_index_t c = 0; c < C; ++c)
{ {
if(di >= 0 && if(di >= 0 &&
ck::type_convert<std::size_t>(di) < ck::type_convert<std::size_t>(di) <
...@@ -259,15 +260,15 @@ struct ReferenceImageToColumn : public device::BaseOperator ...@@ -259,15 +260,15 @@ struct ReferenceImageToColumn : public device::BaseOperator
bool IsSupportedArgument(const Argument& arg) bool IsSupportedArgument(const Argument& arg)
{ {
const ck::index_t G = arg.input_.GetLengths()[0]; const ck::long_index_t G = arg.input_.GetLengths()[0];
const ck::index_t N = arg.input_.GetLengths()[1]; const ck::long_index_t N = arg.input_.GetLengths()[1];
const ck::index_t C = arg.input_.GetLengths()[2]; const ck::long_index_t C = arg.input_.GetLengths()[2];
const index_t NDoHoWo = const long_index_t NDoHoWo =
N * ck::accumulate_n<index_t>( N * ck::accumulate_n<long_index_t>(
arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); arg.output_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
const index_t CZYX = const long_index_t CZYX =
C * ck::accumulate_n<index_t>( C * ck::accumulate_n<long_index_t>(
arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>()); arg.filter_spatial_lengths_.begin(), NDimSpatial, 1, std::multiplies<>());
if(!(arg.output_.GetLengths()[0] == static_cast<std::size_t>(G) && if(!(arg.output_.GetLengths()[0] == static_cast<std::size_t>(G) &&
...@@ -291,11 +292,11 @@ struct ReferenceImageToColumn : public device::BaseOperator ...@@ -291,11 +292,11 @@ struct ReferenceImageToColumn : public device::BaseOperator
static auto MakeArgument(const Tensor<InDataType>& input, static auto MakeArgument(const Tensor<InDataType>& input,
Tensor<OutDataType>& output, Tensor<OutDataType>& output,
std::vector<ck::index_t> filter_spatial_lengths, std::vector<ck::long_index_t> filter_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides, std::vector<ck::long_index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<ck::long_index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<ck::long_index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads) std::vector<ck::long_index_t> input_right_pads)
{ {
return Argument{input, return Argument{input,
output, output,
......
...@@ -108,6 +108,7 @@ using FastGelu = ck::tensor_operation::element_wise::FastGelu; ...@@ -108,6 +108,7 @@ using FastGelu = ck::tensor_operation::element_wise::FastGelu;
using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu; using MultiplyFastGelu = ck::tensor_operation::element_wise::MultiplyFastGelu;
using AddMultiply = ck::tensor_operation::element_wise::AddMultiply; using AddMultiply = ck::tensor_operation::element_wise::AddMultiply;
using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd; using MultiplyAdd = ck::tensor_operation::element_wise::MultiplyAdd;
using MultiplyMultiply = ck::tensor_operation::element_wise::MultiplyMultiply;
using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd; using ScaleAdd = ck::tensor_operation::element_wise::ScaleAdd;
using Gelu = ck::tensor_operation::element_wise::Gelu; using Gelu = ck::tensor_operation::element_wise::Gelu;
using Swish = ck::tensor_operation::element_wise::Swish; using Swish = ck::tensor_operation::element_wise::Swish;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment