"vscode:/vscode.git/clone" did not exist on "33b3c0f85ffb647a1fc831c59c112bcfca5c06b8"
Unverified Commit c87aa6c8 authored by Illia Silin's avatar Illia Silin Committed by GitHub
Browse files

Merge branch 'develop' into codegen_hiprtc

parents 60afb522 b70f367f
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
// TODO: alow 2 gemm have different type
template <typename ADataType_,
typename GDataType_,
typename DDataType_,
typename AccDataType_,
typename ODataType_,
typename AScaleDataType_,
typename GScaleDataType_,
typename DScaleDataType_,
typename YSmoothScaleDataType_,
typename TopkWeightDataType_,
typename IndexDataType_, // data type for all indexing
typename GateActivation_, // = ck_tile::element_wise::Silu,
typename BlockShape_, // shoule be FusedMoeGemmShape
typename Traits_>
struct FusedMoeGemmPipelineProblem
{
using ADataType = remove_cvref_t<ADataType_>;
using GDataType = remove_cvref_t<GDataType_>;
using DDataType = remove_cvref_t<DDataType_>;
using AccDataType = remove_cvref_t<AccDataType_>;
using ODataType = remove_cvref_t<ODataType_>;
using AScaleDataType = remove_cvref_t<AScaleDataType_>;
using GScaleDataType = remove_cvref_t<GScaleDataType_>;
using DScaleDataType = remove_cvref_t<DScaleDataType_>;
using YSmoothScaleDataType = remove_cvref_t<YSmoothScaleDataType_>;
using TopkWeightDataType = remove_cvref_t<TopkWeightDataType_>;
using IndexDataType = remove_cvref_t<IndexDataType_>;
// the input for next gemm should have same time as
using YDataType = ADataType;
using GateActivation = remove_cvref_t<GateActivation_>;
using BlockShape = remove_cvref_t<BlockShape_>;
using Traits = remove_cvref_t<Traits_>;
};
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
namespace ck_tile {
enum class FusedMoeGemmWeightPermuteEnum
{
// permute_b_n0_k0_n1_k1_n2_k2 = 0, // 0,1,4,2,5,3,6
// permute_b_n0_n1_k0_k1_n2_k2 = 1, // 0,1,2,4,5,3,6
no_permute = 0,
b_nr_kr_kw_nw_kv = 1, // 0,1,3,4,2,5
b_nr_kr_waveflatten = b_nr_kr_kw_nw_kv,
};
template <bool IsGateOnly_,
bool UseSmoothQuant_,
index_t OAtomic_, // 0-no atomic, 1-atomic-pk-f16/bf16, 2-atomic-f32
FusedMoeGemmWeightPermuteEnum PermuteEnum_ =
FusedMoeGemmWeightPermuteEnum::b_nr_kr_waveflatten,
bool PadHiddenSize_ = false,
bool PadIntermediateSize_ = false>
struct FusedMoeGemmTraits
{
// Gate+Up or Gate only
static constexpr bool IsGateOnly = IsGateOnly_;
static constexpr bool UseSmoothQuant = UseSmoothQuant_;
static constexpr index_t OAtomic = OAtomic_;
static constexpr FusedMoeGemmWeightPermuteEnum PermuteEnum = PermuteEnum_;
static constexpr bool PadHiddenSize = PadHiddenSize_;
static constexpr bool PadIntermediateSize = PadIntermediateSize_;
};
// Note: this need to be a bit mask
enum class FusedMoeGemmPipelineSequencerEnum
{
SLD_A = 1 << 0, // shared load a
SLD_B = 1 << 1,
GLD_A = 1 << 2, // global load a
GLD_B = 1 << 3,
SST_A = 1 << 4, // shared store a
SST_B = 1 << 5,
GST_O = 1 << 6, // global store out
};
} // namespace ck_tile
...@@ -22,6 +22,7 @@ ...@@ -22,6 +22,7 @@
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/block/block_universal_gemm_as_bs_cr.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace ck_tile {
// A is block window on shared memory
// B is block window on shared memory
// C is block distributed tensor
template <typename Problem_, typename Policy_ = BlockGemmASmemBSmemCRegV1DefaultPolicy>
struct BlockUniversalGemmAsBsCr
{
private:
// TODO: This should be in Policy - UniversalGemmPolicyBase ?
template <typename PipelineProblem_, typename GemmPolicy_>
struct GemmTraits_
{
using Problem = remove_cvref_t<PipelineProblem_>;
using Policy = remove_cvref_t<GemmPolicy_>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
using CDataType = remove_cvref_t<typename Problem::CDataType>;
using BlockGemmShape = remove_cvref_t<typename Problem::BlockGemmShape>;
static constexpr index_t kBlockSize = Problem::kBlockSize;
static constexpr auto Scheduler = Problem::Scheduler;
static constexpr index_t MPerBlock = BlockGemmShape::kM;
static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
static constexpr index_t MWarp = config.template at<1>();
static constexpr index_t NWarp = config.template at<2>();
static_assert(MWarp == BlockGemmShape::BlockWarps::at(number<0>{}),
"Error! WarpGemm's MWarp is not consisten with BlockGemmShape!");
static_assert(NWarp == BlockGemmShape::BlockWarps::at(number<1>{}),
"Error! WarpGemm's NWarp is not consisten with BlockGemmShape!");
static_assert(WarpGemm::kM == BlockGemmShape::WarpTile::at(number<0>{}),
"Error! WarpGemm's M is not consisten with BlockGemmShape!");
static_assert(WarpGemm::kN == BlockGemmShape::WarpTile::at(number<1>{}),
"Error! WarpGemm's N is not consisten with BlockGemmShape!");
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
static_assert(MIterPerWarp * MWarp * WarpGemm::kM == MPerBlock,
"Error! Warps should cover all Block tile!");
static_assert(NIterPerWarp * NWarp * WarpGemm::kN == NPerBlock,
"Error! Warps should cover all Block tile!");
static constexpr index_t MPerBlockPerIter = MWarp * WarpGemm::kM;
static constexpr index_t NPerBlockPerIter = NWarp * WarpGemm::kN;
static constexpr index_t KPerBlockPerIter = WarpGemm::kK;
using AWarpTileDistr = remove_cvref_t<decltype(make_static_tile_distribution(
typename WarpGemm::AWarpDstrEncoding{}))>;
using BWarpTileDistr = remove_cvref_t<decltype(make_static_tile_distribution(
typename WarpGemm::BWarpDstrEncoding{}))>;
using AWarpTile =
remove_cvref_t<decltype(make_static_distributed_tensor<ADataType>(AWarpTileDistr{}))>;
using BWarpTile =
remove_cvref_t<decltype(make_static_distributed_tensor<BDataType>(BWarpTileDistr{}))>;
// TODO: Should we have two policies? Interwave & Intrawave ??
static constexpr index_t InterWaveSchedulingMacClusters = 1;
static constexpr index_t KPack = WarpGemm::kKPerThread;
static constexpr index_t KPerThread = KPerBlock / WarpGemm::kK * KPack;
static constexpr index_t KRepeat = KPerThread / KPack;
};
public:
using Traits = GemmTraits_<Problem_, Policy_>;
using ADataType = remove_cvref_t<typename Traits::ADataType>;
using BDataType = remove_cvref_t<typename Traits::BDataType>;
using CDataType = remove_cvref_t<typename Traits::CDataType>;
using WarpGemm = remove_cvref_t<typename Traits::WarpGemm>;
static constexpr index_t KIterPerWarp = Traits::KIterPerWarp;
static constexpr index_t MIterPerWarp = Traits::MIterPerWarp;
static constexpr index_t NIterPerWarp = Traits::NIterPerWarp;
static constexpr index_t MWarp = Traits::MWarp;
static constexpr index_t NWarp = Traits::NWarp;
static constexpr auto Scheduler = Traits::Scheduler;
private:
template <GemmPipelineScheduler Scheduler, typename GemmTraits>
struct BlockGemmImpl
{
};
template <typename GemmTraits>
struct BlockGemmImpl<GemmPipelineScheduler::Default, GemmTraits>
{
// C += A * B
template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
{
static_assert(
std::is_same_v<typename GemmTraits::CDataType, typename CBlockTensor::DataType>,
"The CDataType as defined in traits should be the same as correspoinding "
"C block tensor data type!");
static_assert(std::is_same_v<typename GemmTraits::ADataType,
typename ASmemBlockWindow::DataType> &&
std::is_same_v<typename GemmTraits::BDataType,
typename BSmemBlockWindow::DataType>,
"The ADataType and BDataType as defined in "
"traits should be the same as correspoinding block window data type!");
static_assert(
GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<0>{}] &&
GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[number<0>{}] &&
GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<1>{}],
"MPerBlock, NPerBlock, KPerBlock defined in "
" BlockGemmShape are different from A/B block smem windows apropriate dims!");
const index_t iMWarp = get_warp_id() / GemmTraits::NWarp;
const index_t iNWarp = get_warp_id() - (iMWarp * GemmTraits::NWarp);
// TODO: refactor warp_window tile type to class member as it should be
// compile-time known information.
auto a_warp_window_tmp = make_tile_window(
a_block_window.get_bottom_tensor_view(),
make_tuple(number<GemmTraits::WarpGemm::kM>{}, number<GemmTraits::WarpGemm::kK>{}),
a_block_window.get_window_origin() +
multi_index<2>{iMWarp * GemmTraits::WarpGemm::kM, 0},
make_static_tile_distribution(typename GemmTraits::WarpGemm::AWarpDstrEncoding{}));
using AWarpWindow = remove_cvref_t<decltype(a_warp_window_tmp)>;
static_assert(GemmTraits::AWarpTile::get_num_of_dimension() ==
AWarpWindow::get_num_of_dimension(),
"AWarpWindow number of dimensions must be equal to "
"AWarpTile number of dimensions!");
static_assert(GemmTraits::AWarpTile::get_lengths() ==
AWarpWindow{}.get_window_lengths(),
"AWarpWindow lengths must be equal to AWarpTile lengths!");
statically_indexed_array<
statically_indexed_array<AWarpWindow, GemmTraits::KIterPerWarp>,
GemmTraits::MIterPerWarp>
a_warp_windows;
// construct B-warp-window
auto b_warp_window_tmp = make_tile_window(
b_block_window.get_bottom_tensor_view(),
make_tuple(number<GemmTraits::WarpGemm::kN>{}, number<GemmTraits::WarpGemm::kK>{}),
b_block_window.get_window_origin() +
multi_index<2>{iNWarp * GemmTraits::WarpGemm::kN, 0},
make_static_tile_distribution(typename GemmTraits::WarpGemm::BWarpDstrEncoding{}));
using BWarpWindow = remove_cvref_t<decltype(b_warp_window_tmp)>;
static_assert(GemmTraits::BWarpTile::get_num_of_dimension() ==
BWarpWindow::get_num_of_dimension(),
"BWarpWindow number of dimensions must be equal to "
"BWarpTile number of dimensions!");
static_assert(GemmTraits::BWarpTile::get_lengths() ==
BWarpWindow{}.get_window_lengths(),
"BWarpWindow lengths must be equal to BWarpTile lengths!");
statically_indexed_array<
statically_indexed_array<BWarpWindow, GemmTraits::KIterPerWarp>,
GemmTraits::NIterPerWarp>
b_warp_windows;
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
// TODO: I don't have to move 0,0 window!
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * GemmTraits::MPerBlockPerIter,
kIter * GemmTraits::KPerBlockPerIter});
});
});
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * GemmTraits::NPerBlockPerIter,
kIter * GemmTraits::KPerBlockPerIter});
});
});
using CWarpDstr = typename GemmTraits::WarpGemm::CWarpDstr;
using CWarpTensor = typename GemmTraits::WarpGemm::CWarpTensor;
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) {
const auto a_warp_tile = load_tile(a_warp_windows(mIter)(kIter));
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) {
const auto b_warp_tile = load_tile(b_warp_windows(nIter)(kIter));
// read C warp tensor from C block tensor-
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
typename GemmTraits::WarpGemm{}(c_warp_tensor, a_warp_tile, b_warp_tile);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
};
template <typename GemmTraits>
struct BlockGemmImpl<GemmPipelineScheduler::Intrawave, GemmTraits>
{
statically_indexed_array<
statically_indexed_array<typename GemmTraits::AWarpTile, GemmTraits::KIterPerWarp>,
GemmTraits::MIterPerWarp>
a_warp_tiles_;
statically_indexed_array<
statically_indexed_array<typename GemmTraits::BWarpTile, GemmTraits::KIterPerWarp>,
GemmTraits::NIterPerWarp>
b_warp_tiles_;
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
{
static_assert(
GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<0>{}] &&
GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[number<0>{}] &&
GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<1>{}],
"MPerBlock, NPerBlock, KPerBlock defined in "
" BlockGemmShape are different from A/B block smem windows apropriate dims!");
static_assert(std::is_same_v<typename GemmTraits::ADataType,
typename ASmemBlockWindow::DataType> &&
std::is_same_v<typename GemmTraits::BDataType,
typename BSmemBlockWindow::DataType>,
"The ADataType and BDataType as defined in "
"traits should be the same as correspoinding block window data type!");
const index_t iMWarp = get_warp_id() / GemmTraits::NWarp;
const index_t iNWarp = get_warp_id() - (iMWarp * GemmTraits::NWarp);
// TODO: refactor warp_window tile type to class member as it should be
// compile-time known information.
auto a_warp_window_tmp = make_tile_window(
a_block_window.get_bottom_tensor_view(),
make_tuple(number<GemmTraits::WarpGemm::kM>{}, number<GemmTraits::WarpGemm::kK>{}),
a_block_window.get_window_origin() +
multi_index<2>{iMWarp * GemmTraits::WarpGemm::kM, 0},
make_static_tile_distribution(typename GemmTraits::WarpGemm::AWarpDstrEncoding{}));
using AWarpWindow = remove_cvref_t<decltype(a_warp_window_tmp)>;
static_assert(GemmTraits::AWarpTile::get_num_of_dimension() ==
AWarpWindow::get_num_of_dimension(),
"AWarpWindow number of dimensions must be equal to "
"AWarpTile number of dimensions!");
static_assert(GemmTraits::AWarpTile::get_lengths() ==
AWarpWindow{}.get_window_lengths(),
"AWarpWindow lengths must be equal to AWarpTile lengths!");
statically_indexed_array<
statically_indexed_array<AWarpWindow, GemmTraits::KIterPerWarp>,
GemmTraits::MIterPerWarp>
a_warp_windows;
// construct B-warp-window
auto b_warp_window_tmp = make_tile_window(
b_block_window.get_bottom_tensor_view(),
make_tuple(number<GemmTraits::WarpGemm::kN>{}, number<GemmTraits::WarpGemm::kK>{}),
b_block_window.get_window_origin() +
multi_index<2>{iNWarp * GemmTraits::WarpGemm::kN, 0},
make_static_tile_distribution(typename GemmTraits::WarpGemm::BWarpDstrEncoding{}));
using BWarpWindow = remove_cvref_t<decltype(b_warp_window_tmp)>;
static_assert(GemmTraits::BWarpTile::get_num_of_dimension() ==
BWarpWindow::get_num_of_dimension(),
"BWarpWindow number of dimensions must be equal to "
"BWarpTile number of dimensions!");
static_assert(GemmTraits::BWarpTile::get_lengths() ==
BWarpWindow{}.get_window_lengths(),
"BWarpWindow lengths must be equal to BWarpTile lengths!");
statically_indexed_array<
statically_indexed_array<BWarpWindow, GemmTraits::KIterPerWarp>,
GemmTraits::NIterPerWarp>
b_warp_windows;
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
// TODO: I don't have to move 0,0 window!
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * GemmTraits::MPerBlockPerIter,
kIter * GemmTraits::KPerBlockPerIter});
});
});
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * GemmTraits::NPerBlockPerIter,
kIter * GemmTraits::KPerBlockPerIter});
});
});
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block window
load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter));
});
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter));
});
});
}
// C += A * B
template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
[[maybe_unused]] const ASmemBlockWindow& a_block_window,
[[maybe_unused]] const BSmemBlockWindow& b_block_window)
{
static_assert(
std::is_same_v<typename GemmTraits::CDataType, typename CBlockTensor::DataType>,
"The CDataType as defined in traits should be the same as correspoinding "
"C block tensor data type!");
using CWarpDstr = typename GemmTraits::WarpGemm::CWarpDstr;
using CWarpTensor = typename GemmTraits::WarpGemm::CWarpTensor;
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, GemmTraits::KIterPerWarp, 1>{}([&](auto kIter) {
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor-
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() = c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM
typename GemmTraits::WarpGemm{}(c_warp_tensor,
a_warp_tiles_[mIter][kIter],
b_warp_tiles_[nIter][kIter]);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
});
});
});
}
};
template <typename GemmTraits>
struct BlockGemmImpl<GemmPipelineScheduler::Interwave, GemmTraits>
{
static constexpr index_t KPerThread = GemmTraits::KPerThread;
static constexpr index_t NumMacClusters = GemmTraits::InterWaveSchedulingMacClusters;
static constexpr index_t KPerInnerLoop =
ck_tile::max(KPerThread / NumMacClusters, GemmTraits::KPack);
// TODO: do we really need this?? Are there any cases when this would be >=1 ??
// Would we need InterWaveSchedulingMacClusters > 1 ???
static constexpr index_t KRepeat = KPerThread / KPerInnerLoop;
static constexpr index_t KInnerLoopIter = KPerInnerLoop / GemmTraits::KPack;
statically_indexed_array<
statically_indexed_array<typename GemmTraits::AWarpTile, KInnerLoopIter>,
GemmTraits::MIterPerWarp>
a_warp_tiles_;
statically_indexed_array<
statically_indexed_array<typename GemmTraits::BWarpTile, KInnerLoopIter>,
GemmTraits::NIterPerWarp>
b_warp_tiles_;
template <index_t KIdx, typename ASmemBlockWindow, typename BSmemBlockWindow>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
{
static_assert(
GemmTraits::MPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<0>{}] &&
GemmTraits::NPerBlock == BSmemBlockWindow{}.get_window_lengths()[number<0>{}] &&
GemmTraits::KPerBlock == ASmemBlockWindow{}.get_window_lengths()[number<1>{}],
"MPerBlock, NPerBlock, KPerBlock defined in "
" BlockGemmShape are different from A/B block smem windows apropriate dims!");
static_assert(std::is_same_v<typename GemmTraits::ADataType,
typename ASmemBlockWindow::DataType> &&
std::is_same_v<typename GemmTraits::BDataType,
typename BSmemBlockWindow::DataType>,
"The ADataType and BDataType as defined in "
"traits should be the same as correspoinding block window data type!");
const index_t iMWarp = get_warp_id() / GemmTraits::NWarp;
const index_t iNWarp = get_warp_id() - (iMWarp * GemmTraits::NWarp);
// TODO: refactor warp_window tile type to class member as it should be
// compile-time known information.
auto a_warp_window_tmp = make_tile_window(
a_block_window.get_bottom_tensor_view(),
make_tuple(number<GemmTraits::WarpGemm::kM>{}, number<GemmTraits::WarpGemm::kK>{}),
a_block_window.get_window_origin() +
multi_index<2>{iMWarp * GemmTraits::WarpGemm::kM, KIdx * KPerInnerLoop},
make_static_tile_distribution(typename GemmTraits::WarpGemm::AWarpDstrEncoding{}));
using AWarpWindow = remove_cvref_t<decltype(a_warp_window_tmp)>;
static_assert(GemmTraits::AWarpTile::get_num_of_dimension() ==
AWarpWindow::get_num_of_dimension(),
"AWarpWindow number of dimensions must be equal to "
"AWarpTile number of dimensions!");
static_assert(GemmTraits::AWarpTile::get_lengths() ==
AWarpWindow{}.get_window_lengths(),
"AWarpWindow lengths must be equal to AWarpTile lengths!");
statically_indexed_array<statically_indexed_array<AWarpWindow, KInnerLoopIter>,
GemmTraits::MIterPerWarp>
a_warp_windows;
// construct B-warp-window
auto b_warp_window_tmp = make_tile_window(
b_block_window.get_bottom_tensor_view(),
make_tuple(number<GemmTraits::WarpGemm::kN>{}, number<GemmTraits::WarpGemm::kK>{}),
b_block_window.get_window_origin() +
multi_index<2>{iNWarp * GemmTraits::WarpGemm::kN, KIdx * KPerInnerLoop},
make_static_tile_distribution(typename GemmTraits::WarpGemm::BWarpDstrEncoding{}));
using BWarpWindow = remove_cvref_t<decltype(b_warp_window_tmp)>;
static_assert(GemmTraits::BWarpTile::get_num_of_dimension() ==
BWarpWindow::get_num_of_dimension(),
"BWarpWindow number of dimensions must be equal to "
"BWarpTile number of dimensions!");
static_assert(GemmTraits::BWarpTile::get_lengths() ==
BWarpWindow{}.get_window_lengths(),
"BWarpWindow lengths must be equal to BWarpTile lengths!");
statically_indexed_array<statically_indexed_array<BWarpWindow, KInnerLoopIter>,
GemmTraits::NIterPerWarp>
b_warp_windows;
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) {
a_warp_windows(mIter)(kIter) = a_warp_window_tmp;
move_tile_window(a_warp_windows(mIter)(kIter),
{mIter * GemmTraits::MPerBlockPerIter,
kIter * GemmTraits::KPerBlockPerIter});
});
});
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) {
static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) {
b_warp_windows(nIter)(kIter) = b_warp_window_tmp;
move_tile_window(b_warp_windows(nIter)(kIter),
{nIter * GemmTraits::NPerBlockPerIter,
kIter * GemmTraits::KPerBlockPerIter});
});
});
// TODO check if a_warp_tiles has same desc as a_warp_window
static_for<0, KInnerLoopIter, 1>{}([&](auto kIter) {
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) {
// read A warp tensor from A block window
load_tile(a_warp_tiles_(mIter)(kIter), a_warp_windows(mIter)(kIter));
});
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) {
// read B warp tensor from B Block window
load_tile(b_warp_tiles_(nIter)(kIter), b_warp_windows(nIter)(kIter));
});
});
}
// C += A * B
template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
{
static_assert(
std::is_same_v<typename GemmTraits::CDataType, typename CBlockTensor::DataType>,
"The CDataType as defined in traits should be the same as correspoinding "
"C block tensor data type!");
using CWarpDstr = typename GemmTraits::WarpGemm::CWarpDstr;
using CWarpTensor = typename GemmTraits::WarpGemm::CWarpTensor;
constexpr auto c_warp_y_lengths =
to_sequence(CWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
constexpr auto c_warp_y_index_zeros = uniform_sequence_gen_t<CWarpDstr::NDimY, 0>{};
// hot loop:
static_for<0, KRepeat, 1>{}([&](auto kIter) {
LocalPrefetch<kIter.value>(a_block_window, b_block_window);
__builtin_amdgcn_sched_barrier(0);
// NOTE: Synchronize threads in a workgroup at the start of each MAC
// cluster, but except the first, as we can shorten non-MAC cluster a bit
// and there's no observable negative impact. The desired effect is waves in
// a workgroup executing MAC in sync. This avoids some out-of-sync waves
// hijacking MAC resource from other workgroups and reducing the chance of
// latency hiding by waiting for the rest of the workgroup at the eventual
// sync point.
if constexpr(kIter.value != 0 || KRepeat == 1)
{
__builtin_amdgcn_s_barrier();
__builtin_amdgcn_sched_barrier(0);
}
static_for<0, KInnerLoopIter, 1>{}([&](auto kInnerIter) {
static_for<0, GemmTraits::MIterPerWarp, 1>{}([&](auto mIter) {
static_for<0, GemmTraits::NIterPerWarp, 1>{}([&](auto nIter) {
// read C warp tensor from C block tensor-
CWarpTensor c_warp_tensor;
c_warp_tensor.get_thread_buffer() =
c_block_tensor.get_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// The block_sync_lds() here performs double duty:
// A) safeguard against data hazard because barrier from
// blockwise_gemm is moved here B) reduce VMEM FIFO congestion
// by applying small delays to different wavefronts It is
// performed near the end of MAC cluster to minimize lgkmcnt
// penalty
if constexpr(kIter.value == KRepeat - 1 &&
kInnerIter.value == KInnerLoopIter - 1 &&
mIter.value == GemmTraits::MIterPerWarp - 1 &&
nIter.value == GemmTraits::NIterPerWarp - 1)
{
__builtin_amdgcn_sched_barrier(0);
block_sync_lds();
__builtin_amdgcn_sched_barrier(0);
}
// warp GEMM
typename GemmTraits::WarpGemm{}(c_warp_tensor,
a_warp_tiles_[mIter][kInnerIter],
b_warp_tiles_[nIter][kInnerIter]);
// write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data(
merge_sequences(sequence<mIter, nIter>{}, c_warp_y_index_zeros),
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths),
c_warp_tensor.get_thread_buffer());
if constexpr(kInnerIter.value == 0 && mIter.value == 0 &&
nIter.value == 0)
{
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(1);
__builtin_amdgcn_sched_barrier(0);
}
});
});
});
__builtin_amdgcn_sched_barrier(0);
__builtin_amdgcn_s_setprio(0);
__builtin_amdgcn_sched_barrier(0);
});
}
};
public:
CK_TILE_DEVICE static constexpr auto MakeCBlockTile()
{
constexpr auto c_block_outer_dstr_encoding = tile_distribution_encoding<
sequence<>,
tuple<sequence<MIterPerWarp, MWarp>, sequence<NIterPerWarp, NWarp>>,
tuple<sequence<1, 2>>,
tuple<sequence<1, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor;
}
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
CK_TILE_DEVICE void LocalPrefetch(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
{
block_gemm_impl_.template LocalPrefetch(a_block_window, b_block_window);
}
// C += A * B
template <typename CBlockTensor, typename ASmemBlockWindow, typename BSmemBlockWindow>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
{
block_gemm_impl_.template operator()(c_block_tensor, a_block_window, b_block_window);
}
// C = A * B
template <typename ASmemBlockWindow, typename BSmemBlockWindow>
CK_TILE_DEVICE auto operator()(const ASmemBlockWindow& a_block_window,
const BSmemBlockWindow& b_block_window)
{
auto c_block_tensor = MakeCBlockTile();
block_gemm_impl_.template operator()(c_block_tensor, a_block_window, b_block_window);
return c_block_tensor;
}
private:
BlockGemmImpl<Scheduler, Traits> block_gemm_impl_{};
};
} // namespace ck_tile
...@@ -115,12 +115,22 @@ struct GemmKernel ...@@ -115,12 +115,22 @@ struct GemmKernel
} }
}(); }();
auto a_pad_view = pad_tensor_view( auto a_pad_view = [&]() {
a_tensor_view, if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}), {
// somehow clang-format is splitting below line into multiple. return pad_tensor_view(
// clang-format off a_tensor_view,
sequence<false, GemmPipeline::kPadA>{}); make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(
a_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
sequence<GemmPipeline::kPadM, false>{});
}
}();
// clang-format on // clang-format on
auto a_block_window = make_tile_window( auto a_block_window = make_tile_window(
...@@ -128,12 +138,22 @@ struct GemmKernel ...@@ -128,12 +138,22 @@ struct GemmKernel
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}), make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
{i_m, 0}); {i_m, 0});
auto b_pad_view = pad_tensor_view( auto b_pad_view = [&]() {
b_tensor_view, if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}), {
// clang-format off return pad_tensor_view(
sequence<false, GemmPipeline::kPadB>{}); b_tensor_view,
// clang-format on make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(
b_tensor_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
sequence<GemmPipeline::kPadN, false>{});
}
}();
auto b_block_window = make_tile_window( auto b_block_window = make_tile_window(
b_pad_view, b_pad_view,
...@@ -171,18 +191,28 @@ struct GemmKernel ...@@ -171,18 +191,28 @@ struct GemmKernel
} }
}(); }();
auto c_pad_view = pad_tensor_view( auto c_pad_view = [&]() {
c_tensor_view, if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}), {
// clang-format off return pad_tensor_view(
sequence<false, GemmPipeline::kPadC>{}); c_tensor_view,
// clang-format on make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
auto c_block_window = make_tile_window( sequence<false, GemmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(
c_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
sequence<GemmPipeline::kPadM, false>{});
}
}();
auto CBlockWindow_pad = make_tile_window(
c_pad_view, c_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}), make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
{i_m, i_n}); {i_m, i_n});
EpiloguePipeline{}(c_block_window, c_block_tile); EpiloguePipeline{}(CBlockWindow_pad, c_block_tile);
} }
}; };
......
...@@ -113,9 +113,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -113,9 +113,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
static constexpr index_t VectorSizeB = Problem::VectorSizeB; static constexpr index_t VectorSizeB = Problem::VectorSizeB;
static constexpr index_t VectorSizeC = Problem::VectorSizeC; static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr bool kPadA = Problem::kPadA; static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadB = Problem::kPadB; static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadC = Problem::kPadC; static constexpr bool kPadK = Problem::kPadK;
// Where is the right place for HasHotLoop and TailNum ??? // Where is the right place for HasHotLoop and TailNum ???
static constexpr bool HasHotLoop = Problem::HasHotLoop; static constexpr bool HasHotLoop = Problem::HasHotLoop;
...@@ -247,8 +247,8 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -247,8 +247,8 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
b_lds_block, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0}); b_lds_block, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
// Block GEMM // Block GEMM
constexpr auto block_gemm = BlockGemm(); auto block_gemm = BlockGemm();
auto c_block_tile = block_gemm.MakeCBlockTile(); auto c_block_tile = block_gemm.MakeCBlockTile();
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
...@@ -290,7 +290,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -290,7 +290,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
{ {
static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) { static_for<0, PrefetchStages, 1>{}([&](auto prefetch_idx) {
block_sync_lds(); block_sync_lds();
// block_gemm.LocalPrefetch(); block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds(); block_sync_lds();
...@@ -318,7 +318,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -318,7 +318,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
static_for<1, tail_num, 1>{}([&](auto prefetch_idx) { static_for<1, tail_num, 1>{}([&](auto prefetch_idx) {
block_sync_lds(); block_sync_lds();
// block_gemm.LocalPrefetch(); block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
block_sync_lds(); block_sync_lds();
...@@ -331,14 +331,14 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -331,14 +331,14 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
}); });
block_sync_lds(); block_sync_lds();
// block_gemm.LocalPrefetch(); block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
}; };
if constexpr(TailNum == TailNumber::One) if constexpr(TailNum == TailNumber::One)
{ {
block_sync_lds(); block_sync_lds();
// block_gemm.LocalPrefetch(); block_gemm.LocalPrefetch(a_lds_gemm_window, b_lds_gemm_window);
block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window); block_gemm(c_block_tile, a_lds_gemm_window, b_lds_gemm_window);
} }
else if constexpr(TailNum == TailNumber::Two) else if constexpr(TailNum == TailNumber::Two)
......
...@@ -11,6 +11,7 @@ namespace ck_tile { ...@@ -11,6 +11,7 @@ namespace ck_tile {
enum struct GemmPipelineScheduler enum struct GemmPipelineScheduler
{ {
Default,
Intrawave, Intrawave,
Interwave, Interwave,
}; };
...@@ -43,6 +44,7 @@ inline std::ostream& operator<<(std::ostream& os, const ck_tile::GemmPipelineSch ...@@ -43,6 +44,7 @@ inline std::ostream& operator<<(std::ostream& os, const ck_tile::GemmPipelineSch
{ {
switch(s) switch(s)
{ {
case ck_tile::GemmPipelineScheduler::Default: os << "Default"; break;
case ck_tile::GemmPipelineScheduler::Intrawave: os << "Intrawave"; break; case ck_tile::GemmPipelineScheduler::Intrawave: os << "Intrawave"; break;
case ck_tile::GemmPipelineScheduler::Interwave: os << "Interwave"; break; case ck_tile::GemmPipelineScheduler::Interwave: os << "Interwave"; break;
default: os << ""; default: os << "";
......
...@@ -33,9 +33,9 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -33,9 +33,9 @@ struct GemmPipelineAGmemBGmemCRegV1
static constexpr index_t VectorSizeB = Problem::VectorSizeB; static constexpr index_t VectorSizeB = Problem::VectorSizeB;
static constexpr index_t VectorSizeC = Problem::VectorSizeC; static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr bool kPadA = Problem::kPadA; static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadB = Problem::kPadB; static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadC = Problem::kPadC; static constexpr bool kPadK = Problem::kPadK;
CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize() CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
{ {
...@@ -101,11 +101,8 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -101,11 +101,8 @@ struct GemmPipelineAGmemBGmemCRegV1
Policy::template MakeADramTileDistribution<Problem>()); Policy::template MakeADramTileDistribution<Problem>());
// A LDS tile window for store // A LDS tile window for store
auto a_copy_lds_window = auto a_copy_lds_window = make_tile_window(
make_tile_window(a_lds_block, a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
a_copy_dram_window.get_tile_distribution());
// B DRAM tile window for load // B DRAM tile window for load
auto b_copy_dram_window = auto b_copy_dram_window =
...@@ -115,11 +112,8 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -115,11 +112,8 @@ struct GemmPipelineAGmemBGmemCRegV1
Policy::template MakeBDramTileDistribution<Problem>()); Policy::template MakeBDramTileDistribution<Problem>());
// B LDS tile window for store // B LDS tile window for store
auto b_copy_lds_window = auto b_copy_lds_window = make_tile_window(
make_tile_window(b_lds_block, b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
b_copy_dram_window.get_tile_distribution());
// A LDS tile for block GEMM // A LDS tile for block GEMM
auto a_lds_gemm_window = make_tile_window( auto a_lds_gemm_window = make_tile_window(
...@@ -149,12 +143,32 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -149,12 +143,32 @@ struct GemmPipelineAGmemBGmemCRegV1
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0 // LDS write 0
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
store_tile(a_copy_lds_window, a_block_tile_tmp); {
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegBlockDescriptor<Problem>());
shuffle_tile(a_shuffle_tmp, a_block_tile);
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
store_tile(a_copy_lds_window, a_block_tile_tmp);
}
else
{
store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
}
// LDS write 0 // LDS write 0
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
store_tile(b_copy_lds_window, b_block_tile_tmp); {
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegBlockDescriptor<Problem>());
shuffle_tile(b_shuffle_tmp, b_block_tile);
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp);
store_tile(b_copy_lds_window, b_block_tile_tmp);
}
else
{
store_tile(b_copy_lds_window, tile_elementwise_in(b_element_func, b_block_tile));
}
} }
index_t iCounter = num_loop - 1; index_t iCounter = num_loop - 1;
...@@ -180,8 +194,19 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -180,8 +194,19 @@ struct GemmPipelineAGmemBGmemCRegV1
store_tile(a_copy_lds_window, a_block_tile_tmp); store_tile(a_copy_lds_window, a_block_tile_tmp);
// LDS write i + 1 // LDS write i + 1
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
store_tile(b_copy_lds_window, b_block_tile_tmp); {
auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegBlockDescriptor<Problem>());
shuffle_tile(b_shuffle_tmp_loop, b_block_tile);
store_tile(b_copy_lds_window,
tile_elementwise_in(b_element_func, b_shuffle_tmp_loop));
}
else
{
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
store_tile(b_copy_lds_window, b_block_tile_tmp);
}
iCounter--; iCounter--;
} }
......
...@@ -4,6 +4,7 @@ ...@@ -4,6 +4,7 @@
#pragma once #pragma once
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -11,6 +12,7 @@ namespace ck_tile { ...@@ -11,6 +12,7 @@ namespace ck_tile {
// Default policy class should not be templated, put template on member functions instead // Default policy class should not be templated, put template on member functions instead
struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
{ {
#if 0 #if 0
// 2d // 2d
template <typename Problem> template <typename Problem>
...@@ -51,6 +53,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -51,6 +53,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
// TODO: this 8 is AK1! should be a policy parameter!
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / 8>{}, number<kMPerBlock>{}, number<8>{}), make_tuple(number<kKPerBlock / 8>{}, number<kMPerBlock>{}, number<8>{}),
make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}), make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}),
...@@ -116,6 +119,20 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -116,6 +119,20 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
return smem_size; return smem_size;
} }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
return Problem::VectorLoadSize / sizeof(ADataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB()
{
using BDataType = remove_cvref_t<typename Problem::BDataType>;
return Problem::VectorLoadSize / sizeof(BDataType);
}
#elif 1 #elif 1
// fake XOR // fake XOR
template <typename Problem> template <typename Problem>
...@@ -192,88 +209,307 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -192,88 +209,307 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
{ {
using ADataType = remove_cvref_t<typename Problem::ADataType>; using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = 16 / sizeof(ADataType);
constexpr index_t K0 = kKPerBlock / K1; if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
constexpr index_t M2 = get_warp_size() / K0; {
#if 1 // coalesce reading for each blocks constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t M1 = kBlockSize / get_warp_size(); constexpr index_t M0 = MPerBlock / M1;
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); static_assert(total_pixels % M1 == 0);
constexpr index_t M0 = kMPerBlock / (M2 * M1); constexpr index_t K3 = total_pixels / M1;
constexpr index_t KPack = GetSmemPackA<Problem>();
return make_static_tile_distribution( static_assert(KPack % K3 == 0);
tile_distribution_encoding<sequence<1>, constexpr index_t K2 = KPack / K3;
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>, if constexpr(get_warp_size() % (K2 * M0))
tuple<sequence<1>, sequence<1, 2>>, {
tuple<sequence<1>, sequence<2, 0>>, constexpr index_t K1 = get_warp_size() / (K2 * M0);
sequence<1, 2>, constexpr index_t K0 = BlockSize / get_warp_size();
sequence<0, 1>>{}); static_assert(KPerBlock == K0 * K1 * K2 * K3);
#else // coalesce reading for each warps return make_static_tile_distribution(
constexpr index_t M0 = kBlockSize / get_warp_size(); tile_distribution_encoding<sequence<1>,
constexpr index_t M1 = kMPerBlock / (M2 * M0); tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
return make_static_tile_distribution( tuple<sequence<0>, sequence<1, 0, 2>>,
tile_distribution_encoding<sequence<1>, sequence<2, 1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>, sequence<3, 1>>{});
tuple<sequence<1>, sequence<1, 2>>, }
tuple<sequence<0>, sequence<2, 0>>, else
sequence<1, 2>, {
sequence<1, 1>>{}); constexpr index_t K1 = (K2 * M0) / get_warp_size();
#endif constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
}
else
{
constexpr index_t K1 = 16 / sizeof(ADataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
// coalesce reading for each blocks
if constexpr(get_warp_size() % (M2 * K0) == 0)
{
constexpr index_t M1 = BlockSize / get_warp_size();
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
constexpr index_t M0 = MPerBlock / (M2 * M1);
static_assert(M0 * M1 * M2 == MPerBlock,
"Incorrect M0, M2, M1 configuration! "
"M0, M1, M2 must cover whole MPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
else
{
constexpr index_t M0 = BlockSize / get_warp_size();
constexpr index_t M1 = MPerBlock / (M2 * M0);
static_assert(M0 * M1 * M2 == MPerBlock,
"Incorrect M0, M1, M2 configuration! "
"M0, M1, M2 must cover whole MPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
}
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
{ {
using BDataType = remove_cvref_t<typename Problem::BDataType>; using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
constexpr index_t N0 = NPerBlock / N1;
constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize;
static_assert(total_pixels % N1 == 0);
constexpr index_t K3 = total_pixels / N1;
constexpr index_t KPack = GetSmemPackB<Problem>();
static_assert(KPack % K3 == 0);
constexpr index_t K2 = KPack / K3;
if constexpr(get_warp_size() % (K2 * N0) == 0)
{
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = BlockSize / get_warp_size();
static_assert(KPerBlock == K0 * K1 * K2 * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
else
{
constexpr index_t K1 = (K2 * N0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
}
else
{
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(BDataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
// coalesce reading for each blocks
if constexpr(get_warp_size() % (N2 * K0) == 0)
{
constexpr index_t N1 = BlockSize / get_warp_size();
static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
constexpr index_t N0 = NPerBlock / (N2 * N1);
static_assert(N0 * N1 * N2 == NPerBlock,
"Incorrect N0, N1, N2 configuration! "
"N0, N1, N2 must cover whole NPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
// coalesce reading for each warps
else
{
constexpr index_t N0 = BlockSize / get_warp_size();
constexpr index_t N1 = NPerBlock / (N2 * N0);
static_assert(N0 * N1 * N2 == NPerBlock,
"Incorrect N0, N1, N2 configuration! "
"N0, N1, N2 must cover whole NPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDescriptor()
{
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = 16 / sizeof(BDataType); constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
constexpr index_t K0 = kKPerBlock / K1; constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t N2 = get_warp_size() / K0; constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
#if 1 // coalesce reading for each blocks static_assert(total_pixels % N1 == 0);
constexpr index_t N1 = kBlockSize / get_warp_size(); constexpr index_t K3 = total_pixels / N1;
static_assert(N2 != 0, "M2 is zero, which will lead to a division by zero error."); constexpr index_t kKPack = GetSmemPackB<Problem>();
static_assert(N1 != 0, "M1 is zero, which will lead to a division by zero error."); static_assert(kKPack % K3 == 0);
constexpr index_t N0 = kNPerBlock / (N2 * N1); constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
constexpr index_t warp_size = get_warp_size();
return make_static_tile_distribution( if constexpr(warp_size % (K2 * N0) == 0)
tile_distribution_encoding<sequence<1>, {
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>, constexpr index_t K1 = warp_size / (K2 * N0);
tuple<sequence<1>, sequence<1, 2>>, constexpr index_t K0 = kBlockSize / warp_size;
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>, return make_static_tile_distribution(
sequence<0, 1>>{}); tile_distribution_encoding<sequence<1>,
#else // coalesce reading for each warps tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
constexpr index_t N0 = kBlockSize / get_warp_size(); tuple<sequence<2>, sequence<2, 1, 2>>,
constexpr index_t N1 = kNPerBlock / (N2 * N0); tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<1, 2>,
return make_static_tile_distribution( sequence<1, 3>>{});
tile_distribution_encoding<sequence<1>, }
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>, else
tuple<sequence<1>, sequence<1, 2>>, {
tuple<sequence<0>, sequence<2, 0>>, constexpr index_t K1 = (K2 * N0) / get_warp_size();
sequence<1, 2>, constexpr index_t K2_m = K2 / K1;
sequence<1, 1>>{}); constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
#endif static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
} }
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDescriptor()
    {
        using ALayout   = remove_cvref_t<typename Problem::ALayout>;
        using ADataType = remove_cvref_t<typename Problem::ADataType>;
        static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>);
        constexpr index_t kBlockSize = Problem::kBlockSize;
        constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
        constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;

        constexpr index_t M1           = Problem::VectorLoadSize / sizeof(ADataType);
        constexpr index_t M0           = kMPerBlock / M1;
        constexpr index_t total_pixels = kMPerBlock * kKPerBlock / kBlockSize;
        static_assert(total_pixels % M1 == 0);
        constexpr index_t K3     = total_pixels / M1;
        constexpr index_t kKPack = GetSmemPackA<Problem>();
        static_assert(kKPack % K3 == 0);
        constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside a single wave
        constexpr index_t warp_size = get_warp_size();
        if constexpr(warp_size % (K2 * M0) == 0)
        {
            constexpr index_t K1 = warp_size / (K2 * M0);
            constexpr index_t K0 = kBlockSize / warp_size;

            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<1>,
                                           tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
                                           tuple<sequence<2>, sequence<2, 1, 2>>,
                                           tuple<sequence<0>, sequence<1, 0, 2>>,
                                           sequence<1, 2>,
                                           sequence<1, 3>>{});
        }
        else
        {
            constexpr index_t K1   = (K2 * M0) / get_warp_size();
            constexpr index_t K2_m = K2 / K1;
            constexpr index_t K0   = kBlockSize / get_warp_size() / K1;
            static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<1>,
                                           tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
                                           tuple<sequence<2, 2>, sequence<1, 2>>,
                                           tuple<sequence<0, 1>, sequence<0, 2>>,
                                           sequence<1, 2>,
                                           sequence<1, 3>>{});
        }
    }

    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
    {
        constexpr bool TransposeC = false;
        constexpr auto I0         = number<0>{};
        constexpr auto I1         = number<1>{};
        constexpr auto I2         = number<2>{};

        using AccDataType = float;
        using BlockWarps  = typename Problem::BlockGemmShape::BlockWarps;
        using WarpTile    = typename Problem::BlockGemmShape::WarpTile;
        using WarpGemm    = WarpGemmMfmaDispatcher<typename Problem::ADataType,
                                                   typename Problem::BDataType,
                                                   AccDataType,
                                                   WarpTile::at(I0),
                                                   WarpTile::at(I1),
                                                   WarpTile::at(I2),
                                                   TransposeC>;
        using BlockGemmPolicy =
            BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
                                                  typename Problem::BDataType,
                                                  typename Problem::CDataType,
                                                  BlockWarps,
                                                  WarpGemm>;
        return BlockUniversalGemmAsBsCr<Problem, BlockGemmPolicy>{};
    }
};
...
@@ -3,40 +3,135 @@
#pragma once

#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"

namespace ck_tile {

template <typename ADataType_,
          typename BDataType_,
          typename CDataType_,
          typename BlockGemmShape_,
          typename TileGemmTraits_>
struct GemmPipelineProblemBase
{
    using GemmTraits = remove_cvref_t<TileGemmTraits_>;

    using ADataType = remove_cvref_t<ADataType_>;
    using BDataType = remove_cvref_t<BDataType_>;
    using CDataType = remove_cvref_t<CDataType_>;

    using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;

    using ALayout = remove_cvref_t<typename GemmTraits::ALayout>;
    using BLayout = remove_cvref_t<typename GemmTraits::BLayout>;
    using CLayout = remove_cvref_t<typename GemmTraits::CLayout>;

    static constexpr index_t VectorLoadSize = GemmTraits::_VectorSize;
    static constexpr index_t kBlockSize     = BlockGemmShape::NumWarps * get_warp_size();

    static constexpr bool kPadM = GemmTraits::kPadM;
    static constexpr bool kPadN = GemmTraits::kPadN;
    static constexpr bool kPadK = GemmTraits::kPadK;

    static constexpr auto Scheduler = GemmPipelineScheduler::Default;

    CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
    {
        if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
        {
            constexpr index_t pixels_per_thread =
                BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize;
            return pixels_per_thread < VectorLoadSize / sizeof(ADataType)
                       ? pixels_per_thread
                       : VectorLoadSize / sizeof(ADataType);
        }
        else
        {
            return VectorLoadSize / sizeof(ADataType);
        }
    }

    CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentB()
    {
        if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
        {
            constexpr index_t pixels_per_thread =
                BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize;
            return pixels_per_thread < VectorLoadSize / sizeof(BDataType)
                       ? pixels_per_thread
                       : VectorLoadSize / sizeof(BDataType);
        }
        else
        {
            return VectorLoadSize / sizeof(BDataType);
        }
    }

    CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentC()
    {
        if constexpr(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
        {
            constexpr index_t N1 = kBlockSize / get_warp_size();
            constexpr index_t N2 = std::min(BlockGemmShape::kN / N1, get_warp_size());
            constexpr index_t M0 = get_warp_size() / N2;
            constexpr index_t M1 = BlockGemmShape::kM / M0;
            return std::min(M1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
        }
        else
        {
            constexpr index_t M1 = kBlockSize / get_warp_size();
            constexpr index_t M2 = std::min(BlockGemmShape::kM / M1, get_warp_size());
            constexpr index_t N0 = get_warp_size() / M2;
            constexpr index_t N1 = BlockGemmShape::kN / N0;
            return std::min(N1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
        }
    }

    static constexpr index_t VectorSizeA = []() {
        if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
        {
            return kPadK ? 1 : GetAlignmentA();
        }
        else
        {
            return kPadM ? 1 : GetAlignmentA();
        }
    }();

    static constexpr index_t VectorSizeB = []() {
        if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
        {
            return kPadN ? 1 : GetAlignmentB();
        }
        else
        {
            return kPadK ? 1 : GetAlignmentB();
        }
    }();

    static constexpr index_t VectorSizeC = []() {
        if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
        {
            return kPadN ? 1 : GetAlignmentC();
        }
        else
        {
            return kPadM ? 1 : GetAlignmentC();
        }
    }();
};

// Alias for GemmPipelineProblem
template <typename ADataType_,
          typename BDataType_,
          typename CDataType_,
          typename BlockGemmShape_,
          typename TileGemmTraits_>
using GemmPipelineProblem =
    GemmPipelineProblemBase<ADataType_, BDataType_, CDataType_, BlockGemmShape_, TileGemmTraits_>;
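// Illustrative instantiation (a sketch; `MyBlockShape` is a hypothetical
// BlockGemmShape-like type, the other names come from this header and
// tile_gemm_traits.hpp). With row-major A and no K padding, VectorSizeA
// reduces to GetAlignmentA():
//
//   using Traits  = TileGemmTraits<false, false, false,              // kPadM/N/K
//                                  tensor_layout::gemm::RowMajor,    // A
//                                  tensor_layout::gemm::ColumnMajor, // B
//                                  tensor_layout::gemm::RowMajor>;   // C
//   using Problem = GemmPipelineProblem<half_t, half_t, float, MyBlockShape, Traits>;
//   static_assert(Problem::VectorSizeA == Problem::GetAlignmentA());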
template <typename ADataType_,
          typename BDataType_,
          typename CDataType_,
@@ -45,30 +140,15 @@ template <typename ADataType_,
          GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
          bool HasHotLoop_ = true,
          TailNumber TailNum_ = TailNumber::Full>
struct UniversalGemmPipelineProblem : public GemmPipelineProblemBase<ADataType_,
                                                                     BDataType_,
                                                                     CDataType_,
                                                                     BlockGemmShape_,
                                                                     TileGemmTraits_>
{
    static constexpr auto Scheduler  = Scheduler_;
    static constexpr auto HasHotLoop = HasHotLoop_;
    static constexpr auto TailNum    = TailNum_;
};
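// Usage sketch (assumed types as in the example above): the derived problem only
// adds the scheduling knobs; everything else comes from GemmPipelineProblemBase.
//
//   using UProblem = UniversalGemmPipelineProblem<half_t, half_t, float,
//                                                 MyBlockShape, Traits,
//                                                 GemmPipelineScheduler::Intrawave>;
//   static_assert(UProblem::VectorSizeA ==
//                 GemmPipelineProblem<half_t, half_t, float, MyBlockShape, Traits>::VectorSizeA);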
} // namespace ck_tile
...
@@ -9,12 +9,8 @@

namespace ck_tile {
// UniversalGemm Policy
struct UniversalGemmPipelineAgBgCrPolicy
{
    static constexpr auto I0 = number<0>{};
    static constexpr auto I1 = number<1>{};
@@ -22,286 +18,136 @@ struct UniversalGemmPipelineAgBgCrPolicy
    static constexpr bool TransposeC = true;

    template <typename Problem, typename DataType, index_t MNPerBlock>
    CK_TILE_HOST_DEVICE static constexpr auto GetVectorLoadSize()
    {
        constexpr index_t BlockSize = Problem::kBlockSize;
        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
        constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize;

        if constexpr(elements_per_thread % (16 / sizeof(DataType)) == 0)
        {
            return (16 / sizeof(DataType));
        }
        else if constexpr(elements_per_thread % (8 / sizeof(DataType)) == 0)
        {
            return (8 / sizeof(DataType));
        }
        else if constexpr(elements_per_thread % (4 / sizeof(DataType)) == 0 &&
                          sizeof(DataType) >= 4)
        {
            return (4 / sizeof(DataType));
        }
        else if constexpr(elements_per_thread % (2 / sizeof(DataType)) == 0 &&
                          sizeof(DataType) >= 2)
        {
            return (2 / sizeof(DataType));
        }
        else
        {
            return 1;
        }
    }
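    // Worked example (assumed tile sizes): for fp16 (sizeof == 2) with
    // MNPerBlock = 128, KPerBlock = 32 and BlockSize = 256,
    // elements_per_thread = 128 * 32 / 256 = 16, which is divisible by
    // 16 / sizeof(DataType) = 8, so the widest 16-byte (dwordx4) load is
    // selected and the function returns 8 elements.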
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
    {
        using ADataType = remove_cvref_t<typename Problem::ADataType>;

        constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
        constexpr index_t KPack     = GetVectorLoadSize<Problem, ADataType, MPerBlock>();

        constexpr auto DataTypeSize = sizeof(ADataType);
        constexpr auto MLdsLayer =
            (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);

        constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
            make_tuple(number<KPerBlock / KPack * MLdsLayer>{},
                       number<MPerBlock / MLdsLayer>{},
                       number<KPack>{}),
            make_tuple(number<KPack>{}, number<KPerBlock * MLdsLayer>{}, number<1>{}),
            number<KPack>{},
            number<1>{});

        constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
            a_lds_block_desc_0,
            make_tuple(make_xor_transform(make_tuple(number<MPerBlock / MLdsLayer>{},
                                                     number<KPerBlock / KPack * MLdsLayer>{})),
                       make_pass_through_transform(number<KPack>{})),
            make_tuple(sequence<1, 0>{}, sequence<2>{}),
            make_tuple(sequence<1, 0>{}, sequence<2>{}));

        constexpr auto a_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
            a_lds_block_desc_permuted,
            make_tuple(make_unmerge_transform(
                           make_tuple(number<KPerBlock / KPack>{}, number<MLdsLayer>{})),
                       make_pass_through_transform(number<MPerBlock / MLdsLayer>{}),
                       make_pass_through_transform(number<KPack>{})),
            make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
            make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));

        constexpr auto a_lds_block_desc = transform_tensor_descriptor(
            a_lds_block_desc_xk0_mnldslayer_mn_xk1,
            make_tuple(make_merge_transform_v3_division_mod(
                           make_tuple(number<MPerBlock / MLdsLayer>{}, number<MLdsLayer>{})),
                       make_merge_transform_v3_division_mod(
                           make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
            make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));

        return a_lds_block_desc;
    }
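    // Numeric sketch of the layering above (assumed fp16 with KPerBlock = 32):
    // MLdsLayer = 32 * 4 / 32 / 2 = 2, i.e. two M-slices share one 128-byte LDS
    // row, and the xor transform permutes (m, k0) pairs -- the usual trick to
    // spread a wave's reads and writes across LDS banks. Once
    // KPerBlock * sizeof(ADataType) >= 128 bytes, MLdsLayer degenerates to 1
    // and the descriptor is a plain M x K view.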
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
    {
        using BDataType = remove_cvref_t<typename Problem::BDataType>;

        constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
        constexpr index_t KPack     = GetVectorLoadSize<Problem, BDataType, NPerBlock>();

        constexpr auto DataTypeSize = sizeof(BDataType);
        constexpr auto NLdsLayer =
            (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);

        constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
            make_tuple(number<KPerBlock / KPack * NLdsLayer>{},
                       number<NPerBlock / NLdsLayer>{},
                       number<KPack>{}),
            make_tuple(number<KPack>{}, number<KPerBlock * NLdsLayer>{}, number<1>{}),
            number<KPack>{},
            number<1>{});

        constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
            b_lds_block_desc_0,
            make_tuple(make_xor_transform(make_tuple(number<NPerBlock / NLdsLayer>{},
                                                     number<KPerBlock / KPack * NLdsLayer>{})),
                       make_pass_through_transform(number<KPack>{})),
            make_tuple(sequence<1, 0>{}, sequence<2>{}),
            make_tuple(sequence<1, 0>{}, sequence<2>{}));

        constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor(
            b_lds_block_desc_permuted,
            make_tuple(make_unmerge_transform(
                           make_tuple(number<KPerBlock / KPack>{}, number<NLdsLayer>{})),
                       make_pass_through_transform(number<NPerBlock / NLdsLayer>{}),
                       make_pass_through_transform(number<KPack>{})),
            make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
            make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));

        constexpr auto b_lds_block_desc = transform_tensor_descriptor(
            b_lds_block_desc_xk0_mnldslayer_mn_xk1,
            make_tuple(make_merge_transform_v3_division_mod(
                           make_tuple(number<NPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
                       make_merge_transform_v3_division_mod(
                           make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))),
            make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
            make_tuple(sequence<0>{}, sequence<1>{}));

        return b_lds_block_desc;
    }
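    // The B-side descriptor mirrors the A-side one with N in place of M; e.g.
    // under the same assumed fp16 tile but with KPerBlock = 64, NLdsLayer =
    // 32 * 4 / 64 / 2 = 1, so the xor swizzle becomes a no-op on the layer
    // dimension.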
    template <typename Problem>
@@ -334,69 +180,268 @@ struct UniversalGemmPipelineAgBgCrPolicy
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
    {
        using ADataType = remove_cvref_t<typename Problem::ADataType>;
        using ALayout   = remove_cvref_t<typename Problem::ALayout>;

        constexpr index_t BlockSize = Problem::kBlockSize;

        constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;

        if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
        {
            constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
            constexpr index_t M0 = MPerBlock / M1;
            constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
            static_assert(total_pixels % M1 == 0);
            constexpr index_t K3    = total_pixels / M1;
            constexpr index_t KPack = GetVectorLoadSize<Problem, ADataType, MPerBlock>();
            static_assert(KPack % K3 == 0);
            constexpr index_t K2 = KPack / K3;
            if constexpr(get_warp_size() % (K2 * M0) == 0)
            {
                constexpr index_t K1 = get_warp_size() / (K2 * M0);
                constexpr index_t K0 = BlockSize / get_warp_size();
                static_assert(KPerBlock == K0 * K1 * K2 * K3);
                return make_static_tile_distribution(
                    tile_distribution_encoding<sequence<1>,
                                               tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
                                               tuple<sequence<2>, sequence<2, 1, 2>>,
                                               tuple<sequence<0>, sequence<1, 0, 2>>,
                                               sequence<2, 1>,
                                               sequence<3, 1>>{});
            }
            else
            {
                constexpr index_t K1   = (K2 * M0) / get_warp_size();
                constexpr index_t K2_m = K2 / K1;
                constexpr index_t K0   = BlockSize / get_warp_size() / K1;
                static_assert(KPerBlock == K0 * K1 * K2_m * K3);
                return make_static_tile_distribution(
                    tile_distribution_encoding<sequence<1>,
                                               tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
                                               tuple<sequence<2, 2>, sequence<1, 2>>,
                                               tuple<sequence<0, 1>, sequence<0, 2>>,
                                               sequence<2, 1>,
                                               sequence<3, 1>>{});
            }
        }
        else
        {
            constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
            constexpr index_t K0 = KPerBlock / K1;
            constexpr index_t M2 = get_warp_size() / K0;
            if constexpr(get_warp_size() % (M2 * K0) == 0)
            {
                constexpr index_t M1 = BlockSize / get_warp_size();
                static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
                static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
                constexpr index_t M0 = MPerBlock / (M2 * M1);
                return make_static_tile_distribution(
                    tile_distribution_encoding<sequence<1>,
                                               tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
                                               tuple<sequence<1>, sequence<1, 2>>,
                                               tuple<sequence<1>, sequence<2, 0>>,
                                               sequence<1, 2>,
                                               sequence<0, 1>>{});
            }
            else
            {
                constexpr index_t M0 = BlockSize / get_warp_size();
                constexpr index_t M1 = MPerBlock / (M2 * M0);
                return make_static_tile_distribution(
                    tile_distribution_encoding<sequence<1>,
                                               tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
                                               tuple<sequence<1>, sequence<1, 2>>,
                                               tuple<sequence<0>, sequence<2, 0>>,
                                               sequence<1, 2>,
                                               sequence<1, 1>>{});
            }
        }
    }
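    // Worked example for the row-major branch (assumed fp16, MPerBlock = 128,
    // KPerBlock = 32, BlockSize = 256, VectorLoadSize = 16): K1 = 8, K0 = 4,
    // M2 = 64 / 4 = 16, so each warp covers 16 rows by 32 k-elements; M1 = 4
    // warps and M0 = 128 / (16 * 4) = 2 repeats per thread, giving coalesced
    // 16-byte loads along K.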
    template <typename Problem>
    CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
    {
        using BDataType = remove_cvref_t<typename Problem::BDataType>;
        using BLayout   = remove_cvref_t<typename Problem::BLayout>;

        constexpr index_t BlockSize = Problem::kBlockSize;

        constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;

        if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
        {
            constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
            constexpr index_t N0 = NPerBlock / N1;
            constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize;
            static_assert(total_pixels % N1 == 0);
            constexpr index_t K3    = total_pixels / N1;
            constexpr index_t KPack = GetVectorLoadSize<Problem, BDataType, NPerBlock>();
            static_assert(KPack % K3 == 0);
            constexpr index_t K2 = KPack / K3;
            if constexpr(get_warp_size() % (K2 * N0) == 0)
            {
                constexpr index_t K1 = get_warp_size() / (K2 * N0);
                constexpr index_t K0 = BlockSize / get_warp_size();
                static_assert(KPerBlock == K0 * K1 * K2 * K3);
                return make_static_tile_distribution(
                    tile_distribution_encoding<sequence<1>,
                                               tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
                                               tuple<sequence<2>, sequence<2, 1, 2>>,
                                               tuple<sequence<0>, sequence<1, 0, 2>>,
                                               sequence<2, 1>,
                                               sequence<3, 1>>{});
            }
            else
            {
                constexpr index_t K1   = (K2 * N0) / get_warp_size();
                constexpr index_t K2_m = K2 / K1;
                constexpr index_t K0   = BlockSize / get_warp_size() / K1;
                static_assert(KPerBlock == K0 * K1 * K2_m * K3);
                return make_static_tile_distribution(
                    tile_distribution_encoding<sequence<1>,
                                               tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
                                               tuple<sequence<2, 2>, sequence<1, 2>>,
                                               tuple<sequence<0, 1>, sequence<0, 2>>,
                                               sequence<2, 1>,
                                               sequence<3, 1>>{});
            }
        }
        else
        {
            constexpr index_t K1 = Problem::VectorLoadSize / sizeof(BDataType);
            constexpr index_t K0 = KPerBlock / K1;
            constexpr index_t N2 = get_warp_size() / K0;
            // coalesced reading for each block
            if constexpr(get_warp_size() % (N2 * K0) == 0)
            {
                constexpr index_t N1 = BlockSize / get_warp_size();
                static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
                static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
                constexpr index_t N0 = NPerBlock / (N2 * N1);
                return make_static_tile_distribution(
                    tile_distribution_encoding<sequence<1>,
                                               tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
                                               tuple<sequence<1>, sequence<1, 2>>,
                                               tuple<sequence<1>, sequence<2, 0>>,
                                               sequence<1, 2>,
                                               sequence<0, 1>>{});
            }
            // coalesced reading for each warp
            else
            {
                constexpr index_t N0 = BlockSize / get_warp_size();
                constexpr index_t N1 = NPerBlock / (N2 * N0);
                return make_static_tile_distribution(
                    tile_distribution_encoding<sequence<1>,
                                               tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
                                               tuple<sequence<1>, sequence<1, 2>>,
                                               tuple<sequence<0>, sequence<2, 0>>,
                                               sequence<1, 2>,
                                               sequence<1, 1>>{});
            }
        }
    }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDescriptor()
{
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t BlockSize = Problem::kBlockSize;
        constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t M0 = MPerBlock / M1;
constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
static_assert(total_pixels % M1 == 0);
constexpr index_t K3 = total_pixels / M1;
constexpr index_t kKPack = GetVectorLoadSize<Problem, ADataType, MPerBlock>();
static_assert(kKPack % K3 == 0);
        constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside a single wave
constexpr index_t warp_size = get_warp_size();
if constexpr(warp_size % (K2 * M0) == 0)
{
constexpr index_t K1 = warp_size / (K2 * M0);
constexpr index_t K0 = BlockSize / warp_size;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
else
{
constexpr index_t K1 = (K2 * M0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
}
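    // Numeric sketch (assumed fp16, MPerBlock = 128, KPerBlock = 64,
    // BlockSize = 256, VectorLoadSize = 16): M1 = 8, M0 = 16, total_pixels = 32,
    // K3 = 4; GetVectorLoadSize yields kKPack = 8, so K2 = 2, and
    // warp_size % (K2 * M0) == 64 % 32 == 0, so the first branch applies with
    // K1 = 2 and K0 = 4.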
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDescriptor()
{
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr index_t BlockSize = Problem::kBlockSize;
        constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
        constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;

        constexpr index_t N1           = Problem::VectorLoadSize / sizeof(BDataType);
        constexpr index_t N0           = NPerBlock / N1;
        constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize;
        static_assert(total_pixels % N1 == 0);
        constexpr index_t K3     = total_pixels / N1;
        constexpr index_t kKPack = GetVectorLoadSize<Problem, BDataType, NPerBlock>();
        static_assert(kKPack % K3 == 0);
        constexpr index_t K2 = kKPack / K3; // TODO: this dimension could be outside a single wave
        constexpr index_t warp_size = get_warp_size();
        if constexpr(warp_size % (K2 * N0) == 0)
        {
            constexpr index_t K1 = warp_size / (K2 * N0);
            constexpr index_t K0 = BlockSize / warp_size;

            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<1>,
                                           tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
                                           tuple<sequence<2>, sequence<2, 1, 2>>,
                                           tuple<sequence<0>, sequence<1, 0, 2>>,
                                           sequence<1, 2>,
                                           sequence<1, 3>>{});
        }
        else
        {
            constexpr index_t K1   = (K2 * N0) / get_warp_size();
            constexpr index_t K2_m = K2 / K1;
            constexpr index_t K0   = BlockSize / get_warp_size() / K1;
            static_assert(KPerBlock == K0 * K1 * K2_m * K3);
            return make_static_tile_distribution(
                tile_distribution_encoding<sequence<1>,
                                           tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
                                           tuple<sequence<2, 2>, sequence<1, 2>>,
                                           tuple<sequence<0, 1>, sequence<0, 2>>,
                                           sequence<1, 2>,
                                           sequence<1, 3>>{});
        }
    }
    template <typename Problem>
...
@@ -3,19 +3,23 @@
#pragma once

#include "ck_tile/core.hpp"

namespace ck_tile {

template <bool kPadM_,
          bool kPadN_,
          bool kPadK_,
          typename ALayout_,
          typename BLayout_,
          typename CLayout_>
struct TileGemmTraits
{
    static constexpr bool kPadM = kPadM_;
    static constexpr bool kPadN = kPadN_;
    static constexpr bool kPadK = kPadK_;

    static constexpr int _VectorSize = 16;

    using ALayout = ALayout_;
    using BLayout = BLayout_;
...
@@ -10,114 +10,134 @@

namespace ck_tile {
// fp16
using WarpGemmMfmaF16F16F32M32N32K8 = WarpGemmImpl<
    WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfmaF16F16F32M16N16K16 = WarpGemmImpl<
    WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfmaF16F16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
    WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
    2>>;

using WarpGemmMfmaF16F16F32M16N16K32 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
    WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>,
    2>>;

using WarpGemmMfmaF16F16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
    WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
    1>>;

using WarpGemmMfmaF16F16F32M32N32K16SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
    WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
    2>>;

using WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution =
    WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution =
    WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution =
    WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
        WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
        2>>;

using WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution =
    WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
        WarpGemmAttributeMfmaImplF16F16F32M16N16K16<WGAttrCtlEnum::Default_>,
        2>>;

using WarpGemmMfmaF16F16F32M32N32K16SwizzleBTransposedCDistribution =
    WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
        WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Default_>,
        2>>;

// bf16
using WarpGemmMfmaBf16Bf16F32M32N32K8 = WarpGemmImpl<
    WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfmaBf16Bf16F32M16N16K16 = WarpGemmImpl<
    WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfmaBf16Bf16F32M32N32K16 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
    WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
    2>>;

using WarpGemmMfmaBf16Bf16F32M16N16K32 = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK<
    WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
    2>>;

using WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA = WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
    WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
    1>>;

using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA =
    WarpGemmImpl<WarpGemmAtrributeMfmaIterateK_SwizzleA<
        WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
        2>>;

using WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution =
    WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution =
    WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution =
    WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
        WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
        2>>;

using WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution =
    WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution<
        WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16<WGAttrCtlEnum::Default_>,
        2>>;

using WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleBTransposedCDistribution =
    WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
        WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8<WGAttrCtlEnum::Default_>,
        2>>;

// fp8
using WarpGemmMfma_f32_32x32x16_fp8_fp8 = WarpGemmImpl<
    WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfma_f32_32x32x16_fp8_bf8 = WarpGemmImpl<
    WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfma_f32_32x32x16_bf8_fp8 = WarpGemmImpl<
    WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfma_f32_32x32x16_bf8_bf8 = WarpGemmImpl<
    WarpGemmAtrributeMfma<WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed =
    WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed =
    WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed =
    WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8<WGAttrCtlEnum::Default_>>>;

using WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed =
    WarpGemmImpl<WarpGemmAtrributeMfmaTransposedCDistribution<
        WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8<WGAttrCtlEnum::Default_>>>;

template <index_t swizzle_factor = 2>
using WarpGemmMfmaFp8Fp8F32M32N32K16SwizzleBTransposedCDistribution =
    WarpGemmImpl<WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB<
        WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t, WGAttrCtlEnum::Default_>,
        2,
        swizzle_factor>>;
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once
...
@@ -21,9 +21,12 @@ struct WarpGemmAtrributeMfma
    using BVecType = typename Impl::BVecType;
    using CVecType = typename Impl::CVecType;

    static constexpr index_t kM = Impl::kM;
    static constexpr index_t kN = Impl::kN;
    static constexpr index_t kK = Impl::kK;
    static constexpr index_t kKPerThread = Impl::kABKPerLane;

    CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
    using AWarpDstrEncoding = tile_distribution_encoding<
        sequence<>,
@@ -51,10 +54,13 @@ struct WarpGemmAtrributeMfma
        sequence<0, 2>>;

    // c_vec += a_vec * b_vec
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        Impl{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});
    }
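    // Call-site sketch (illustrative): post_nop_ is a compile-time flag forwarded
    // verbatim to Impl; what Impl does with it is Impl's contract, not something
    // this wrapper defines.
    //
    //   WarpGemmAtrributeMfma<Impl>{}(c_vec, a_vec, b_vec, bool_constant<true>{});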
    // c_vec = a_vec * b_vec
@@ -81,9 +87,12 @@ struct WarpGemmAtrributeMfmaIterateK
        ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
    using CVecType = typename Impl::CVecType;

    static constexpr index_t kM = Impl::kM;
    static constexpr index_t kN = Impl::kN;
    static constexpr index_t kK = Impl::kK * kKIter;
    static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;

    CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
    using AWarpDstrEncoding = tile_distribution_encoding<
        sequence<>,
@@ -111,8 +120,11 @@ struct WarpGemmAtrributeMfmaIterateK
        sequence<0, 2>>;

    // c_vec += a_vec * b_vec
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
        using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
@@ -122,10 +134,33 @@ struct WarpGemmAtrributeMfmaIterateK
                   reinterpret_cast<const buf_a&>(a_vec)
                       .template get_as<typename Impl::AVecType>()[iKIter],
                   reinterpret_cast<const buf_b&>(b_vec)
                       .template get_as<typename Impl::BVecType>()[iKIter],
                   bool_constant<post_nop_>{});
        });
    }
template <index_t iKIter, bool post_nop_ = false>
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
number<iKIter>,
bool_constant<post_nop_> = {}) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
static_assert(iKIter < kKIter);
// static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
bool_constant<post_nop_>{});
//});
}
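    // How the per-access overload composes with get_num_of_access() (a sketch;
    // `wg` stands for an instance of this attribute type):
    //
    //   static_for<0, get_num_of_access(), 1>{}([&](auto iAccess) {
    //       wg(c_vec, a_vec, b_vec, iAccess, bool_constant<false>{});
    //   });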
    // c_vec = a_vec * b_vec
    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
    {
@@ -164,9 +199,12 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
    using BVecType = typename Impl::AVecType;
    using CVecType = typename Impl::CVecType;

    static constexpr index_t kM = Impl::kN;
    static constexpr index_t kN = Impl::kM;
    static constexpr index_t kK = Impl::kK;
    static constexpr index_t kKPerThread = Impl::kABKPerLane;

    CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
    using AWarpDstrEncoding = tile_distribution_encoding<
        sequence<>,
@@ -194,11 +232,14 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution
        sequence<0, 2>>;

    // c_vec += a_vec * b_vec
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        // swap A and B
        Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
    }
    // c_vec = a_vec * b_vec
@@ -222,9 +263,12 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
    using BVecType = typename Impl::AVecType;
    using CVecType = typename Impl::CVecType;

    static constexpr index_t kM = Impl::kN;
    static constexpr index_t kN = Impl::kM;
    static constexpr index_t kK = Impl::kK;
    static constexpr index_t kKPerThread = Impl::kABKPerLane;

    CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return 1; }
    using AWarpDstrEncoding = tile_distribution_encoding<
        sequence<>,
@@ -255,12 +299,15 @@ struct WarpGemmAtrributeMfmaTransposedCDistribution_SwizzleB
            sequence<2, 2>,
            sequence<0, 2>>;

    template <bool post_nop_ = false>
    // c_vec += a_vec * b_vec
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        // swap A and B
        Impl{}(c_vec, b_vec, a_vec, bool_constant<post_nop_>{});
    }
    // c_vec = a_vec * b_vec
@@ -287,9 +334,12 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
        ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
    using CVecType = typename Impl::CVecType;

    static constexpr index_t kM = Impl::kN;
    static constexpr index_t kN = Impl::kM;
    static constexpr index_t kK = Impl::kK * kKIter;
    static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;

    CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }
    using AWarpDstrEncoding = tile_distribution_encoding<
        sequence<>,
@@ -316,9 +366,12 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
            sequence<2, 2>,
            sequence<0, 2>>;

    template <bool post_nop_ = false>
    // c_vec += a_vec * b_vec
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
        using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
@@ -328,10 +381,34 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution
                   reinterpret_cast<const buf_b&>(b_vec)
                       .template get_as<typename Impl::BVecType>()[iKIter],
                   reinterpret_cast<const buf_a&>(a_vec)
                       .template get_as<typename Impl::AVecType>()[iKIter],
                   bool_constant<post_nop_>{});
        });
    }
template <index_t iKIter, bool post_nop_ = false>
// c_vec += a_vec * b_vec
CK_TILE_DEVICE void operator()(CVecType& c_vec,
const AVecType& a_vec,
const BVecType& b_vec,
number<iKIter>,
bool_constant<post_nop_> = {}) const
{
using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
static_assert(iKIter < kKIter);
// swap A and B, value and type
// static_for<0, kKIter, 1>{}([&](auto iKIter) {
Impl{}(c_vec,
reinterpret_cast<const buf_b&>(b_vec)
.template get_as<typename Impl::BVecType>()[iKIter],
reinterpret_cast<const buf_a&>(a_vec)
.template get_as<typename Impl::AVecType>()[iKIter],
bool_constant<post_nop_>{});
//});
}
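    // Usage sketch (editor's illustration, not from the original source): the
    // indexed overload above lets a caller issue the kKIter MFMA slices one at
    // a time, e.g. to interleave them with loads and place the post-nop only
    // on the last slice:
    //
    //   static_for<0, get_num_of_access(), 1>{}([&](auto i_k) {
    //       attr(c_vec, a_vec, b_vec, i_k, bool_constant<i_k == kKIter - 1>{});
    //   });
    //
    // where `attr` is an assumed instance of this struct.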
    // c_vec = a_vec * b_vec
    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
    {
@@ -372,10 +449,13 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
        ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
    using CVecType = typename Impl::CVecType;

    static constexpr index_t kM = Impl::kN;
    static constexpr index_t kN = Impl::kM;
    static constexpr index_t kK = Impl::kK * kKIter;
    static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
    static constexpr index_t SFactor = SFactor_; // how many CM1 to group together

    CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }

    using AWarpDstrEncoding = tile_distribution_encoding<
        sequence<>,
@@ -429,8 +509,11 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
        sequence<0, 2>>;
#endif

    // c_vec += a_vec * b_vec
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
        using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
@@ -440,10 +523,33 @@ struct WarpGemmAtrributeMfmaIterateKAndTransposedCDistribution_SwizzleB
                   reinterpret_cast<const buf_b&>(b_vec)
                       .template get_as<typename Impl::BVecType>()[iKIter],
                   reinterpret_cast<const buf_a&>(a_vec)
                       .template get_as<typename Impl::AVecType>()[iKIter],
                   bool_constant<post_nop_>{});
        });
    }

    // c_vec += a_vec * b_vec, for a single sub-K slice selected at compile time
    template <index_t iKIter, bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   number<iKIter>,
                                   bool_constant<post_nop_> = {}) const
    {
        using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
        using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
        static_assert(iKIter < kKIter);
        // swap A and B, value and type
        Impl{}(c_vec,
               reinterpret_cast<const buf_b&>(b_vec)
                   .template get_as<typename Impl::BVecType>()[iKIter],
               reinterpret_cast<const buf_a&>(a_vec)
                   .template get_as<typename Impl::AVecType>()[iKIter],
               bool_constant<post_nop_>{});
    }
    // c_vec = a_vec * b_vec
    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
    {
@@ -483,10 +589,13 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
        ext_vector_t<BDataType, vector_traits<typename Impl::BVecType>::vector_size * kKIter>;
    using CVecType = typename Impl::CVecType;

    static constexpr index_t kM = Impl::kM;
    static constexpr index_t kN = Impl::kN;
    static constexpr index_t kK = Impl::kK * kKIter;
    static constexpr index_t kKPerThread = Impl::kABKPerLane * kKIter;
    static constexpr index_t SFactor = SFactor_; // how many CM1 to group together

    CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access() { return kKIter; }

    using AWarpDstrEncoding = tile_distribution_encoding<
        sequence<>,
@@ -518,8 +627,11 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
        sequence<0, 2>>;

    // c_vec += a_vec * b_vec
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
        using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
@@ -529,10 +641,33 @@ struct WarpGemmAtrributeMfmaIterateK_SwizzleA
                   reinterpret_cast<const buf_a&>(a_vec)
                       .template get_as<typename Impl::AVecType>()[iKIter],
                   reinterpret_cast<const buf_b&>(b_vec)
                       .template get_as<typename Impl::BVecType>()[iKIter],
                   bool_constant<post_nop_>{});
        });
    }

    // c_vec += a_vec * b_vec, for a single sub-K slice selected at compile time
    template <index_t iKIter, bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   number<iKIter>,
                                   bool_constant<post_nop_> = {}) const
    {
        using buf_a = thread_buffer<typename Impl::AVecType, kKIter>;
        using buf_b = thread_buffer<typename Impl::BVecType, kKIter>;
        static_assert(iKIter < kKIter);
        Impl{}(c_vec,
               reinterpret_cast<const buf_a&>(a_vec)
                   .template get_as<typename Impl::AVecType>()[iKIter],
               reinterpret_cast<const buf_b&>(b_vec)
                   .template get_as<typename Impl::BVecType>()[iKIter],
               bool_constant<post_nop_>{});
    }

    // c_vec = a_vec * b_vec
    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
    {
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

@@ -7,12 +7,68 @@
namespace ck_tile {
// TODO: refactor warp-gemm
// currently there is a discrepancy between vav/vva when C/D needs to be transposed:
// e.g. if we want A:agpr, B:vgpr, we have to use Raw_vva in WGAttrCtlEnum,
// because the _impl code swaps the A/B pointers (but that information is not known here)
enum class WGAttrCtlEnum
{
    Default_ = 0,
    Raw_vvv  = 1, // c-vgpr, a-vgpr, b-vgpr
    Raw_vaa  = 2, // c-vgpr, a-agpr, b-agpr
    Raw_vav  = 3, // c-vgpr, a-agpr, b-vgpr
    Raw_vva  = 4, // c-vgpr, a-vgpr, b-agpr
    Raw_avv  = 5, // c-agpr, a-vgpr, b-vgpr
    // Raw_aaa = 6, // c-agpr, a-agpr, b-agpr
};
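
// Illustrative use (editor's sketch, not from the original source): Ctrl_ is
// threaded through the MFMA impl structs below to pick the register-file
// constraints of the raw inline-asm path, e.g.
//
//   using WG = WarpGemmAttributeMfmaImplF16F16F32M32N32K8<WGAttrCtlEnum::Raw_vvv>;
//
// forces C/A/B all into VGPRs, while Default_ keeps the compiler-builtin path.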
#define DISPATCH_MFMA_(mfma_, dmod_, amod_, bmod_, cmod_) \
if constexpr(post_nop_) \
{ \
        asm volatile(mfma_ " %0, %1, %2, %3\n"                     \
"s_nop 3" \
: dmod_(c_vec) \
: amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \
:); \
} \
else \
{ \
asm volatile(mfma_ " %0, %1, %2, %3\n" \
: dmod_(c_vec) \
: amod_(a_vec), bmod_(b_vec), cmod_(c_vec) \
:); \
}
#define DISPATCH_MFMA_CTRL_(mfma_, ctrl_) \
if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vvv) \
{ \
DISPATCH_MFMA_(mfma_, "+v", "v", "v", "v") \
} \
else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vaa) \
{ \
DISPATCH_MFMA_(mfma_, "+v", "a", "a", "v") \
} \
else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vav) \
{ \
DISPATCH_MFMA_(mfma_, "+v", "a", "v", "v") \
} \
else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_vva) \
{ \
DISPATCH_MFMA_(mfma_, "+v", "v", "a", "v") \
} \
else if constexpr(ctrl_ == WGAttrCtlEnum::Raw_avv) \
{ \
DISPATCH_MFMA_(mfma_, "+a", "v", "v", "a") \
}
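
// For reference (editor's sketch, not part of the original header): with
// Ctrl == WGAttrCtlEnum::Raw_vvv and post_nop_ == false,
// DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x8f16", Ctrl) expands to roughly
//
//   asm volatile("v_mfma_f32_32x32x8f16 %0, %1, %2, %3\n"
//                : "+v"(c_vec)
//                : "v"(a_vec), "v"(b_vec), "v"(c_vec)
//                :);
//
// i.e. the accumulator is tied read/write in VGPRs; with post_nop_ == true an
// "s_nop 3" is appended to cover the MFMA result latency before a dependent
// consumer.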
// FP16
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
{
    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
    using ADataType = fp16_t;
    using BDataType = fp16_t;
    using CDataType = float;

    using AVecType = ext_vector_t<fp16_t, 4>;
    using BVecType = ext_vector_t<fp16_t, 4>;
@@ -33,16 +89,23 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
    static constexpr index_t kCM1PerLane = 4;

    // c_vec += a_vec * b_vec
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x8f16", Ctrl)
        else
        {
#if defined(__gfx9__)
            c_vec = __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, c_vec, 0, 0, 0);
#else
            ck_tile::ignore = c_vec;
            ck_tile::ignore = a_vec;
            ck_tile::ignore = b_vec;
#endif
        }
    }

    // c_vec = a_vec * b_vec
@@ -52,18 +115,20 @@ struct WarpGemmAttributeMfmaImplF16F16F32M32N32K8
        return bit_cast<CVecType>(
            __builtin_amdgcn_mfma_f32_32x32x8f16(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
#else
        ck_tile::ignore = a_vec;
        ck_tile::ignore = b_vec;
        return CVecType{0.f};
#endif
    }
};
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
{
    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
    using ADataType = fp16_t;
    using BDataType = fp16_t;
    using CDataType = float;

    using AVecType = ext_vector_t<fp16_t, 4>;
    using BVecType = ext_vector_t<fp16_t, 4>;
@@ -84,16 +149,23 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
    static constexpr index_t kCM1PerLane = 4;

    // c_vec += a_vec * b_vec
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16f16", Ctrl)
        else
        {
#if defined(__gfx9__)
            c_vec = __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, c_vec, 0, 0, 0);
#else
            ck_tile::ignore = c_vec;
            ck_tile::ignore = a_vec;
            ck_tile::ignore = b_vec;
#endif
        }
    }

    // c_vec = a_vec * b_vec
@@ -103,19 +175,21 @@ struct WarpGemmAttributeMfmaImplF16F16F32M16N16K16
        return bit_cast<CVecType>(
            __builtin_amdgcn_mfma_f32_16x16x16f16(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
#else
        ck_tile::ignore = a_vec;
        ck_tile::ignore = b_vec;
        return CVecType{0.f};
#endif
    }
};
// Bf16
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
{
    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
    using ADataType = bf16_t;
    using BDataType = bf16_t;
    using CDataType = float;

    using AVecType = ext_vector_t<bf16_t, 4>;
    using BVecType = ext_vector_t<bf16_t, 4>;
@@ -136,28 +210,35 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
    static constexpr index_t kCM1PerLane = 4;

    // c_vec += a_vec * b_vec
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        DISPATCH_MFMA_CTRL_("v_mfma_f32_32x32x8bf16_1k", Ctrl)
        else
        {
#if defined(__gfx90a__) || defined(__gfx94__)
            c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
#elif defined(__gfx908__)
            static_for<0, 2, 1>{}([&](auto k) {
                c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
                    reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
                        .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
                    reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
                        .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
                    c_vec,
                    0,
                    0,
                    0);
            });
#else
            ck_tile::ignore = c_vec;
            ck_tile::ignore = a_vec;
            ck_tile::ignore = b_vec;
#endif
        }
    }

    // c_vec = a_vec * b_vec
@@ -181,18 +262,20 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
        });
        return c_vec;
#else
        ck_tile::ignore = a_vec;
        ck_tile::ignore = b_vec;
        return CVecType{0.f};
#endif
    }
};
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
{
    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
    using ADataType = bf16_t;
    using BDataType = bf16_t;
    using CDataType = float;

    using AVecType = ext_vector_t<bf16_t, 4>;
    using BVecType = ext_vector_t<bf16_t, 4>;
@@ -213,28 +296,34 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
    static constexpr index_t kCM1PerLane = 4;

    // c_vec += a_vec * b_vec
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16bf16_1k", Ctrl)
        else
        {
#if defined(__gfx90a__) || defined(__gfx94__)
            c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
#elif defined(__gfx908__)
            static_for<0, 2, 1>{}([&](auto k) {
                c_vec = __builtin_amdgcn_mfma_f32_16x16x8bf16(
                    reinterpret_cast<const thread_buffer<ADataType, 4>&>(a_vec)
                        .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
                    reinterpret_cast<const thread_buffer<BDataType, 4>&>(b_vec)
                        .template get_as<ext_vector_t<bf16_t, 2>>()[number<k>{}],
                    c_vec,
                    0,
                    0,
                    0);
            });
#else
            ck_tile::ignore = c_vec;
            ck_tile::ignore = a_vec;
            ck_tile::ignore = b_vec;
#endif
        }
    }

    // c_vec = a_vec * b_vec
@@ -258,20 +347,21 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
        });
        return c_vec;
#else
        ck_tile::ignore = a_vec;
        ck_tile::ignore = b_vec;
        return CVecType{0.f};
#endif
    }
};
// FP8
template <typename AType_, typename BType_, WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
{
    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
    using ADataType = AType_;
    using BDataType = BType_;
    using CDataType = float;

    using AVecType = ext_vector_t<ADataType, 8>;
    using BVecType = ext_vector_t<BDataType, 8>;
@@ -292,38 +382,120 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
    static constexpr index_t kCM1PerLane = 4;

    // c_vec += a_vec * b_vec
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        if constexpr(Ctrl == WGAttrCtlEnum::Raw_vvv)
        {
            if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
            {
                DISPATCH_MFMA_("v_mfma_f32_32x32x16_fp8_fp8", "+v", "v", "v", "v")
            }
            else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
            {
                DISPATCH_MFMA_("v_mfma_f32_32x32x16_fp8_bf8", "+v", "v", "v", "v")
            }
            else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
            {
                DISPATCH_MFMA_("v_mfma_f32_32x32x16_bf8_fp8", "+v", "v", "v", "v")
            }
            else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
            {
                DISPATCH_MFMA_("v_mfma_f32_32x32x16_bf8_bf8", "+v", "v", "v", "v")
            }
        }
        else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vaa)
        {
            if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
            {
                DISPATCH_MFMA_("v_mfma_f32_32x32x16_fp8_fp8", "+v", "a", "a", "v")
            }
            else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
            {
                DISPATCH_MFMA_("v_mfma_f32_32x32x16_fp8_bf8", "+v", "a", "a", "v")
            }
            else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
            {
                DISPATCH_MFMA_("v_mfma_f32_32x32x16_bf8_fp8", "+v", "a", "a", "v")
            }
            else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
            {
                DISPATCH_MFMA_("v_mfma_f32_32x32x16_bf8_bf8", "+v", "a", "a", "v")
            }
        }
        else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vav)
        {
            if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
            {
                DISPATCH_MFMA_("v_mfma_f32_32x32x16_fp8_fp8", "+v", "a", "v", "v")
            }
            else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
            {
                DISPATCH_MFMA_("v_mfma_f32_32x32x16_fp8_bf8", "+v", "a", "v", "v")
            }
            else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
            {
                DISPATCH_MFMA_("v_mfma_f32_32x32x16_bf8_fp8", "+v", "a", "v", "v")
            }
            else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
            {
                DISPATCH_MFMA_("v_mfma_f32_32x32x16_bf8_bf8", "+v", "a", "v", "v")
            }
        }
        else if constexpr(Ctrl == WGAttrCtlEnum::Raw_vva)
        {
            if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
            {
                DISPATCH_MFMA_("v_mfma_f32_32x32x16_fp8_fp8", "+v", "v", "a", "v")
            }
            else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
            {
                DISPATCH_MFMA_("v_mfma_f32_32x32x16_fp8_bf8", "+v", "v", "a", "v")
            }
            else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
            {
                DISPATCH_MFMA_("v_mfma_f32_32x32x16_bf8_fp8", "+v", "v", "a", "v")
            }
            else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
            {
                DISPATCH_MFMA_("v_mfma_f32_32x32x16_bf8_bf8", "+v", "v", "a", "v")
            }
        }
        else
        {
#if defined(__gfx94__)
            if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, fp8_t>)
                c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_fp8(
                    bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
            else if constexpr(std::is_same_v<ADataType, fp8_t> && std::is_same_v<BDataType, bf8_t>)
                c_vec = __builtin_amdgcn_mfma_f32_32x32x16_fp8_bf8(
                    bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
            else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, fp8_t>)
                c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_fp8(
                    bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
            else if constexpr(std::is_same_v<ADataType, bf8_t> && std::is_same_v<BDataType, bf8_t>)
                c_vec = __builtin_amdgcn_mfma_f32_32x32x16_bf8_bf8(
                    bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
#elif defined(__gfx908__) || defined(__gfx90a__)
            static_for<0, 8, 1>{}([&](auto k) {
                float a_f32 =
                    type_convert<float>(reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
                                            .template get_as<ADataType>()[number<k>{}]);
                float b_f32 =
                    type_convert<float>(reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
                                            .template get_as<BDataType>()[number<k>{}]);
                c_vec = __builtin_amdgcn_mfma_f32_32x32x2f32(a_f32, b_f32, c_vec, 0, 0, 0);
            });
#else
            ck_tile::ignore = c_vec;
            ck_tile::ignore = a_vec;
            ck_tile::ignore = b_vec;
#endif
        }
    }

    // c_vec = a_vec * b_vec
@@ -356,20 +528,97 @@ struct WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base
        });
        return c_vec;
#else
        ck_tile::ignore = a_vec;
        ck_tile::ignore = b_vec;
        return CVecType{0.f};
#endif
    }
};

template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_fp8 =
    WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, fp8_t, Ctrl_>;

template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_32x32x16_fp8_bf8 =
    WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<fp8_t, bf8_t, Ctrl_>;

template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_fp8 =
    WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<bf8_t, fp8_t, Ctrl_>;

template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
using WarpGemmAttributeMfmaImpl_f32_32x32x16_bf8_bf8 =
    WarpGemmAttributeMfmaImpl_f32_32x32x16_f8_base<bf8_t, bf8_t, Ctrl_>;
// int8
template <WGAttrCtlEnum Ctrl_ = WGAttrCtlEnum::Default_>
struct WarpGemmAttributeMfmaImpl_i32_32x32x16_i8
{
    static constexpr WGAttrCtlEnum Ctrl = Ctrl_;
    using ADataType = int8_t;
    using BDataType = int8_t;
    using CDataType = int32_t;

    using AVecType = ext_vector_t<ADataType, 8>;
    using BVecType = ext_vector_t<BDataType, 8>;
    using CVecType = ext_vector_t<CDataType, 16>;

    static constexpr index_t kM = 32;
    static constexpr index_t kN = 32;
    static constexpr index_t kK = 16;

    static constexpr index_t kAMLane     = 32;
    static constexpr index_t kBNLane     = 32;
    static constexpr index_t kABKLane    = 2;
    static constexpr index_t kABKPerLane = 8;

    static constexpr index_t kCMLane     = 2;
    static constexpr index_t kCNLane     = 32;
    static constexpr index_t kCM0PerLane = 4;
    static constexpr index_t kCM1PerLane = 4;

    // c_vec += a_vec * b_vec
    template <bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CVecType& c_vec,
                                   const AVecType& a_vec,
                                   const BVecType& b_vec,
                                   bool_constant<post_nop_> = {}) const
    {
        DISPATCH_MFMA_CTRL_("v_mfma_i32_32x32x16_i8", Ctrl)
        else
        {
#if defined(__gfx94__)
            // the K=16 int8 MFMA takes its 8-byte operands as 64-bit values
            c_vec = __builtin_amdgcn_mfma_i32_32x32x16_i8(
                bit_cast<long>(a_vec), bit_cast<long>(b_vec), c_vec, 0, 0, 0);
#elif defined(__gfx908__) || defined(__gfx90a__)
            // no K=16 int8 MFMA here; emulate it with two K=8 instructions,
            // each consuming 4 int8 values packed into a 32-bit operand
            static_for<0, 2, 1>{}([&](auto k) {
                c_vec = __builtin_amdgcn_mfma_i32_32x32x8i8(
                    bit_cast<int32_t>(
                        reinterpret_cast<const thread_buffer<ADataType, 8>&>(a_vec)
                            .template get_as<ext_vector_t<int8_t, 4>>()[number<k>{}]),
                    bit_cast<int32_t>(
                        reinterpret_cast<const thread_buffer<BDataType, 8>&>(b_vec)
                            .template get_as<ext_vector_t<int8_t, 4>>()[number<k>{}]),
                    c_vec,
                    0,
                    0,
                    0);
            });
#else
            ck_tile::ignore = c_vec;
            ck_tile::ignore = a_vec;
            ck_tile::ignore = b_vec;
#endif
        }
    }

    // c_vec = a_vec * b_vec
    CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
    {
        CVecType c_vec{0};
        operator()(c_vec, a_vec, b_vec);
        return c_vec;
    }
};
#undef DISPATCH_MFMA_
#undef DISPATCH_MFMA_CTRL_
} // namespace ck_tile
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2023, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

@@ -21,40 +21,40 @@ struct WarpGemmMfmaDispatcher;
// clang-format off
// fp16
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32,  8, false> { using Type = WarpGemmMfmaF16F16F32M32N32K8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32,  8,  true> { using Type = WarpGemmMfmaF16F16F32M32N32K8TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaF16F16F32M32N32K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16,  true> { using Type = WarpGemmMfmaF16F16F32M32N32K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaF16F16F32M16N16K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 16,  true> { using Type = WarpGemmMfmaF16F16F32M16N16K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaF16F16F32M16N16K32; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 16, 16, 32,  true> { using Type = WarpGemmMfmaF16F16F32M16N16K32TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32,  8, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K8SwizzleA; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::half_t, ck_tile::half_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaF16F16F32M32N32K16SwizzleA; };
// bf16
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32,  8, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32,  8,  true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16,  true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 16,  true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K16TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32, false> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 16, 16, 32,  true> { using Type = WarpGemmMfmaBf16Bf16F32M16N16K32TransposedCDistribution; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32,  8, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K8SwizzleA; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf16_t, ck_tile::bf16_t, float, 32, 32, 16, false, true> { using Type = WarpGemmMfmaBf16Bf16F32M32N32K16SwizzleA; };
// fp8
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::fp8_t, float, 32, 32, 16,  true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_fp8_CTransposed; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::fp8_t, ck_tile::bf8_t, float, 32, 32, 16,  true> { using Type = WarpGemmMfma_f32_32x32x16_fp8_bf8_CTransposed; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::fp8_t, float, 32, 32, 16,  true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_fp8_CTransposed; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 16, false> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8; };
template<> struct WarpGemmMfmaDispatcher<ck_tile::bf8_t, ck_tile::bf8_t, float, 32, 32, 16,  true> { using Type = WarpGemmMfma_f32_32x32x16_bf8_bf8_CTransposed; };
// clang-format on
} // namespace impl
...
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.

#pragma once

@@ -14,6 +14,11 @@ struct WarpGemmImpl
    static constexpr index_t kM = WarpGemmAttribute::kM;
    static constexpr index_t kN = WarpGemmAttribute::kN;
    static constexpr index_t kK = WarpGemmAttribute::kK;
    /// @brief Number of elements in the K dimension processed by a single thread of the wavefront.
    ///
    /// @note A WarpGemm may run its MFMA instruction multiple times (on different K slices);
    ///       in that case this value accounts for all of those iterations.
    static constexpr index_t kKPerThread = WarpGemmAttribute::kKPerThread;
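    /// Example (editor's illustration, not from the original source): for the
    /// fp16 32x32x8 MFMA iterated twice (kKIter == 2), each lane holds 4 fp16
    /// per instruction (kABKPerLane == 4), so kK == 16 and kKPerThread == 8.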
    using ADataType = typename WarpGemmAttribute::ADataType;
    using BDataType = typename WarpGemmAttribute::BDataType;
@@ -31,11 +36,21 @@ struct WarpGemmImpl
    using BWarpTensor = static_distributed_tensor<BDataType, BWarpDstr>;
    using CWarpTensor = static_distributed_tensor<CDataType, CWarpDstr>;

    CK_TILE_HOST_DEVICE static constexpr auto get_num_of_access()
    {
        return WarpGemmAttribute_::get_num_of_access();
    }

    template <typename CTensor, typename ATensor, typename BTensor, bool post_nop_ = false>
    CK_TILE_DEVICE void
    operator()(CTensor& c, const ATensor& a, const BTensor& b, bool_constant<post_nop_> = {}) const
    {
        static_assert(detail::is_similiar_distributed_tensor_v<CTensor, CWarpTensor> &&
                      detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
                      detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
        using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
        using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
        using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;

        constexpr auto I0 = number<0>{};
@@ -44,18 +59,49 @@ struct WarpGemmImpl
        auto c_vec = c.get_thread_buffer().template get_as<CVec>()[I0];

        // c_vec += a_vec * b_vec
        WarpGemmAttribute{}(c_vec, a_vec, b_vec, bool_constant<post_nop_>{});

        c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
    }
    template <typename CTensor,
              typename ATensor,
              typename BTensor,
              index_t i_subk,
              bool post_nop_ = false>
    CK_TILE_DEVICE void operator()(CTensor& c,
                                   const ATensor& a,
                                   const BTensor& b,
                                   number<i_subk>,
                                   bool_constant<post_nop_> = {}) const
    {
        using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
        using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
        using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;

        constexpr auto I0 = number<0>{};

        const auto a_vec = a.get_thread_buffer().template get_as<AVec>()[I0];
        const auto b_vec = b.get_thread_buffer().template get_as<BVec>()[I0];
        auto c_vec       = c.get_thread_buffer().template get_as<CVec>()[I0];

        // c_vec += a_vec * b_vec
        WarpGemmAttribute{}(c_vec, a_vec, b_vec, number<i_subk>{}, bool_constant<post_nop_>{});

        c.get_thread_buffer().template set_as<CVec>(I0, c_vec);
    }
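    // Usage sketch (editor's illustration, not from the original source): for
    // an iterate-K attribute, a block-gemm hot loop can drive the sub-K
    // overload above together with get_num_of_access() to interleave MFMA
    // issue with other work, e.g. for an instance `wg` of this struct and
    // distributed tensors c/a/b:
    //
    //   static_for<0, get_num_of_access(), 1>{}(
    //       [&](auto i_k) { wg(c, a, b, i_k); }); // one MFMA slice per call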
    template <typename ATensor, typename BTensor>
    CK_TILE_DEVICE auto operator()(const ATensor& a, const BTensor& b) const
    {
        using CTensor = CWarpTensor;
        static_assert(detail::is_similiar_distributed_tensor_v<ATensor, AWarpTensor> &&
                      detail::is_similiar_distributed_tensor_v<BTensor, BWarpTensor>);
        CTensor c;

        using AVec = ext_vector_t<ADataType, ATensor::get_thread_buffer_size()>;
        using BVec = ext_vector_t<BDataType, BTensor::get_thread_buffer_size()>;
        using CVec = ext_vector_t<CDataType, CTensor::get_thread_buffer_size()>;

        constexpr auto I0 = number<0>{};
...
@@ -121,7 +121,7 @@ struct Layernorm2dFwdPipelineOnePass
        auto [mean, var] = block_welford(acc, cur_count, max_count);
        block_welford_sync(mean, var, cur_count);
        block_welford_cross_warp_sync(mean, var, cur_count, smem);
        block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{});

        // compute inv-std
        auto inv_std = tile_elementwise_in(
...
@@ -35,6 +35,7 @@ struct Layernorm2dFwdPipelineTwoPass
    static constexpr bool kNeedCrossWarpSync = Problem::kNeedCrossWarpSync;
    static constexpr bool kPadM = false; // TODO - BlockLayernorm2dFwdProblem::kPadM
    static constexpr bool kPadN = Problem::Traits::kPadN;
    static constexpr bool kFastFDiv = Problem::Traits::kFastFDiv;

    static constexpr auto kFusedAdd = Problem::Traits::kFusedAdd;
    static constexpr auto kFusedQuant = Problem::Traits::kFusedQuant;
@@ -137,15 +138,22 @@ struct Layernorm2dFwdPipelineTwoPass
        block_welford_sync(mean, var, cur_count);
        block_welford_cross_warp_sync(mean, var, cur_count, smem);
        block_tile_welford_post_scale_var(var, cur_count, constant<kFastFDiv>{});
        // compute inv-std
        auto inv_std = tile_elementwise_in(
            [&](const auto& v_) {
                if constexpr(kFastFDiv && std::is_same_v<ComputeDataType, float>)
                {
                    return type_convert<ComputeDataType>(1.0f) *
                           __builtin_amdgcn_rcpf(sqrt(v_ + epsilon));
                }
                else
                {
                    return type_convert<ComputeDataType>(1.0f) / sqrt(v_ + epsilon);
                }
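                // Note (editor's): __builtin_amdgcn_rcpf lowers to v_rcp_f32,
                // a hardware reciprocal approximation (~1 ulp for normal
                // inputs); kFastFDiv trades that small error for skipping the
                // exact division sequence the compiler would otherwise emit.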
            },
            var);
        if constexpr(kSaveMean)
            store_tile(mean_window, cast_tile<MeanDataType>(mean));

        if constexpr(kSaveInvStd)
...
@@ -3,6 +3,7 @@
#pragma once

#include "ck_tile/ops/smoothquant/kernel/moe_smoothquant_kernel.hpp"
#include "ck_tile/ops/smoothquant/kernel/smoothquant_kernel.hpp"
#include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_default_policy.hpp"
#include "ck_tile/ops/smoothquant/pipeline/smoothquant_pipeline_one_pass.hpp"
...