Unverified Commit 2b6458dd authored by Thomas Ning's avatar Thomas Ning Committed by GitHub
Browse files

[CK Tile] Improve the Layout, Padding, and Alignment features of CK Tile GEMM (#1651)

* Finished the feature

* Modified the test file

* Test case update

* addresss comment

* Addressed the review comment

* Fixed the CI error
parent 5fb150db
...@@ -8,7 +8,10 @@ This folder contains example for GEMM using ck_tile tile-programming implementat ...@@ -8,7 +8,10 @@ This folder contains example for GEMM using ck_tile tile-programming implementat
mkdir build && cd build mkdir build && cd build
# you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank # you can replace <arch> with the appropriate architecture (for example gfx90a or gfx942) or leave it blank
sh ../script/cmake-ck-dev.sh ../ <arch> sh ../script/cmake-ck-dev.sh ../ <arch>
# The basic pipeline method on the gemm calculation
make tile_example_gemm_basic -j make tile_example_gemm_basic -j
# The memory bound pipeline on the gemm calculation
make tile_example_gemm_mem_pipeline -j
``` ```
This will result in an executable `build/bin/tile_example_gemm_basic` This will result in an executable `build/bin/tile_example_gemm_basic`
......
...@@ -17,10 +17,11 @@ ...@@ -17,10 +17,11 @@
template <typename ALayout, typename BLayout, typename CLayout> template <typename ALayout, typename BLayout, typename CLayout>
float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
{ {
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part. // The kPadM, kPadN, kPadK & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadA = true; constexpr bool kPadM = false;
constexpr bool kPadB = true; constexpr bool kPadN = false;
constexpr bool kPadC = true; constexpr bool kPadK = false;
constexpr bool kTilePermute = false; constexpr bool kTilePermute = false;
// The rank and permutation will also be generate out by the CodeGen part. // The rank and permutation will also be generate out by the CodeGen part.
constexpr ck_tile::index_t kOutputRank = 2; constexpr ck_tile::index_t kOutputRank = 2;
...@@ -56,8 +57,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -56,8 +57,8 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
CShuffleEpilogue, CShuffleEpilogue,
ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType, ck_tile::CShuffleEpilogue<ck_tile::CShuffleEpilogueProblem<AccDataType,
CDataType, CDataType,
kPadA, kPadM,
kPadB, kPadN,
kTilePermute, kTilePermute,
kOutputRank, kOutputRank,
1, 1,
...@@ -65,13 +66,13 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -65,13 +66,13 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
TilePartitioner::kM, TilePartitioner::kM,
TilePartitioner::kN>>, TilePartitioner::kN>>,
ck_tile::Default2DEpilogue< ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadA, kPadB>>>; ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>>;
using CodegenGemmTraits = using CodegenGemmTraits =
ck_tile::TileGemmTraits<kPadA, kPadB, kPadC, ALayout, BLayout, CLayout>; ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using CodegenPipelineProblem = ck_tile:: using CodegenPipelineProblem = ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>; GemmPipelineProblem<ADataType, BDataType, AccDataType, CodegenGemmShape, CodegenGemmTraits>;
using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy<ALayout, BLayout, CLayout>; using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy;
using CodegenGemmPipeline = using CodegenGemmPipeline =
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenGemmPolicy>; ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem, CodegenGemmPolicy>;
// ToDo: Will add the codegen part to test different pipeline policies in GEMM. // ToDo: Will add the codegen part to test different pipeline policies in GEMM.
......
...@@ -31,9 +31,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -31,9 +31,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
constexpr ck_tile::index_t K_Warp_Tile = 8; constexpr ck_tile::index_t K_Warp_Tile = 8;
// The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part. // The kPadA, kPadB, kPadC & kBlockPerCu should also come from the Codegen part.
constexpr bool kPadA = true; constexpr bool kPadM = true;
constexpr bool kPadB = true; constexpr bool kPadN = true;
constexpr bool kPadC = true; constexpr bool kPadK = true;
constexpr int kBlockPerCu = 1; constexpr int kBlockPerCu = 1;
...@@ -46,9 +46,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s) ...@@ -46,9 +46,9 @@ float gemm_calc(const gemm_basic_args& args, const ck_tile::stream_config& s)
using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>; using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;
using GemmEpilogue = ck_tile::Default2DEpilogue< using GemmEpilogue = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, false, kPadC>>; ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
using Traits = ck_tile::TileGemmTraits<kPadA, kPadB, kPadC, ALayout, BLayout, CLayout>; using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem< using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>; ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;
......
...@@ -170,7 +170,7 @@ CK_TILE_DEVICE void shuffle_tile(OutTensor& out, const InTensor& in) ...@@ -170,7 +170,7 @@ CK_TILE_DEVICE void shuffle_tile(OutTensor& out, const InTensor& in)
} }
else else
{ {
// NOT implemented static_assert(false, "The shuffle should always happen!");
} }
} }
......
...@@ -863,6 +863,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -863,6 +863,8 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
constexpr index_t K0 = kKPerBlock / K1; constexpr index_t K0 = kKPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0; constexpr index_t N2 = get_warp_size() / K0;
constexpr index_t N1 = kBlockSize / get_warp_size(); constexpr index_t N1 = kBlockSize / get_warp_size();
static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
constexpr index_t N0 = kNPerBlock / (N2 * N1); constexpr index_t N0 = kNPerBlock / (N2 * N1);
static_assert(N0 != 0); static_assert(N0 != 0);
......
...@@ -115,12 +115,22 @@ struct GemmKernel ...@@ -115,12 +115,22 @@ struct GemmKernel
} }
}(); }();
auto a_pad_view = pad_tensor_view( auto a_pad_view = [&]() {
a_tensor_view, if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}), {
// somehow clang-format is splitting below line into multiple. return pad_tensor_view(
// clang-format off a_tensor_view,
sequence<false, GemmPipeline::kPadA>{}); make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(
a_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
sequence<GemmPipeline::kPadM, false>{});
}
}();
// clang-format on // clang-format on
auto a_block_window = make_tile_window( auto a_block_window = make_tile_window(
...@@ -128,12 +138,22 @@ struct GemmKernel ...@@ -128,12 +138,22 @@ struct GemmKernel
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}), make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
{i_m, 0}); {i_m, 0});
auto b_pad_view = pad_tensor_view( auto b_pad_view = [&]() {
b_tensor_view, if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}), {
// clang-format off return pad_tensor_view(
sequence<false, GemmPipeline::kPadB>{}); b_tensor_view,
// clang-format on make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
sequence<false, GemmPipeline::kPadK>{});
}
else
{
return pad_tensor_view(
b_tensor_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
sequence<GemmPipeline::kPadN, false>{});
}
}();
auto b_block_window = make_tile_window( auto b_block_window = make_tile_window(
b_pad_view, b_pad_view,
...@@ -171,18 +191,28 @@ struct GemmKernel ...@@ -171,18 +191,28 @@ struct GemmKernel
} }
}(); }();
auto c_pad_view = pad_tensor_view( auto c_pad_view = [&]() {
c_tensor_view, if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}), {
// clang-format off return pad_tensor_view(
sequence<false, GemmPipeline::kPadC>{}); c_tensor_view,
// clang-format on make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
auto c_block_window = make_tile_window( sequence<false, GemmPipeline::kPadN>{});
}
else
{
return pad_tensor_view(
c_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
sequence<GemmPipeline::kPadM, false>{});
}
}();
auto CBlockWindow_pad = make_tile_window(
c_pad_view, c_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}), make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
{i_m, i_n}); {i_m, i_n});
EpiloguePipeline{}(c_block_window, c_block_tile); EpiloguePipeline{}(CBlockWindow_pad, c_block_tile);
} }
}; };
......
...@@ -113,9 +113,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -113,9 +113,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
static constexpr index_t VectorSizeB = Problem::VectorSizeB; static constexpr index_t VectorSizeB = Problem::VectorSizeB;
static constexpr index_t VectorSizeC = Problem::VectorSizeC; static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr bool kPadA = Problem::kPadA; static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadB = Problem::kPadB; static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadC = Problem::kPadC; static constexpr bool kPadK = Problem::kPadK;
// Where is the right place for HasHotLoop and TailNum ??? // Where is the right place for HasHotLoop and TailNum ???
static constexpr bool HasHotLoop = Problem::HasHotLoop; static constexpr bool HasHotLoop = Problem::HasHotLoop;
......
...@@ -33,9 +33,9 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -33,9 +33,9 @@ struct GemmPipelineAGmemBGmemCRegV1
static constexpr index_t VectorSizeB = Problem::VectorSizeB; static constexpr index_t VectorSizeB = Problem::VectorSizeB;
static constexpr index_t VectorSizeC = Problem::VectorSizeC; static constexpr index_t VectorSizeC = Problem::VectorSizeC;
static constexpr bool kPadA = Problem::kPadA; static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadB = Problem::kPadB; static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadC = Problem::kPadC; static constexpr bool kPadK = Problem::kPadK;
CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize() CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
{ {
...@@ -101,11 +101,8 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -101,11 +101,8 @@ struct GemmPipelineAGmemBGmemCRegV1
Policy::template MakeADramTileDistribution<Problem>()); Policy::template MakeADramTileDistribution<Problem>());
// A LDS tile window for store // A LDS tile window for store
auto a_copy_lds_window = auto a_copy_lds_window = make_tile_window(
make_tile_window(a_lds_block, a_lds_block, make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}), {0, 0});
make_tuple(number<kMPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
a_copy_dram_window.get_tile_distribution());
// B DRAM tile window for load // B DRAM tile window for load
auto b_copy_dram_window = auto b_copy_dram_window =
...@@ -115,11 +112,8 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -115,11 +112,8 @@ struct GemmPipelineAGmemBGmemCRegV1
Policy::template MakeBDramTileDistribution<Problem>()); Policy::template MakeBDramTileDistribution<Problem>());
// B LDS tile window for store // B LDS tile window for store
auto b_copy_lds_window = auto b_copy_lds_window = make_tile_window(
make_tile_window(b_lds_block, b_lds_block, make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}), {0, 0});
make_tuple(number<kNPerBlock>{}, number<kKPerBlock>{}),
{0, 0},
b_copy_dram_window.get_tile_distribution());
// A LDS tile for block GEMM // A LDS tile for block GEMM
auto a_lds_gemm_window = make_tile_window( auto a_lds_gemm_window = make_tile_window(
...@@ -149,12 +143,32 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -149,12 +143,32 @@ struct GemmPipelineAGmemBGmemCRegV1
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0 // LDS write 0
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_block_tile); if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
store_tile(a_copy_lds_window, a_block_tile_tmp); {
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegBlockDescriptor<Problem>());
shuffle_tile(a_shuffle_tmp, a_block_tile);
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
store_tile(a_copy_lds_window, a_block_tile_tmp);
}
else
{
store_tile(a_copy_lds_window, tile_elementwise_in(a_element_func, a_block_tile));
}
// LDS write 0 // LDS write 0
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
store_tile(b_copy_lds_window, b_block_tile_tmp); {
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegBlockDescriptor<Problem>());
shuffle_tile(b_shuffle_tmp, b_block_tile);
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp);
store_tile(b_copy_lds_window, b_block_tile_tmp);
}
else
{
store_tile(b_copy_lds_window, tile_elementwise_in(b_element_func, b_block_tile));
}
} }
index_t iCounter = num_loop - 1; index_t iCounter = num_loop - 1;
...@@ -180,8 +194,19 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -180,8 +194,19 @@ struct GemmPipelineAGmemBGmemCRegV1
store_tile(a_copy_lds_window, a_block_tile_tmp); store_tile(a_copy_lds_window, a_block_tile_tmp);
// LDS write i + 1 // LDS write i + 1
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile); if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
store_tile(b_copy_lds_window, b_block_tile_tmp); {
auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegBlockDescriptor<Problem>());
shuffle_tile(b_shuffle_tmp_loop, b_block_tile);
store_tile(b_copy_lds_window,
tile_elementwise_in(b_element_func, b_shuffle_tmp_loop));
}
else
{
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_block_tile);
store_tile(b_copy_lds_window, b_block_tile_tmp);
}
iCounter--; iCounter--;
} }
......
...@@ -11,6 +11,7 @@ namespace ck_tile { ...@@ -11,6 +11,7 @@ namespace ck_tile {
// Default policy class should not be templated, put template on member functions instead // Default policy class should not be templated, put template on member functions instead
struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
{ {
#if 0 #if 0
// 2d // 2d
template <typename Problem> template <typename Problem>
...@@ -116,6 +117,20 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -116,6 +117,20 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
return smem_size; return smem_size;
} }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
return Problem::VectorLoadSize / sizeof(ADataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB()
{
using BDataType = remove_cvref_t<typename Problem::BDataType>;
return Problem::VectorLoadSize / sizeof(BDataType);
}
#elif 1 #elif 1
// fake XOR // fake XOR
template <typename Problem> template <typename Problem>
...@@ -192,80 +207,269 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -192,80 +207,269 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
{ {
using ADataType = remove_cvref_t<typename Problem::ADataType>; using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = 16 / sizeof(ADataType);
constexpr index_t K0 = kKPerBlock / K1; if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
constexpr index_t M2 = get_warp_size() / K0; {
#if 1 // coalesce reading for each blocks constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t M1 = kBlockSize / get_warp_size(); constexpr index_t M0 = MPerBlock / M1;
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); static_assert(total_pixels % M1 == 0);
constexpr index_t M0 = kMPerBlock / (M2 * M1); constexpr index_t K3 = total_pixels / M1;
constexpr index_t KPack = GetSmemPackA<Problem>();
return make_static_tile_distribution( static_assert(KPack % K3 == 0);
tile_distribution_encoding<sequence<1>, constexpr index_t K2 = KPack / K3;
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>, if constexpr(get_warp_size() % (K2 * M0))
tuple<sequence<1>, sequence<1, 2>>, {
tuple<sequence<1>, sequence<2, 0>>, constexpr index_t K1 = get_warp_size() / (K2 * M0);
sequence<1, 2>, constexpr index_t K0 = BlockSize / get_warp_size();
sequence<0, 1>>{}); static_assert(KPerBlock == K0 * K1 * K2 * K3);
#else // coalesce reading for each warps return make_static_tile_distribution(
constexpr index_t M0 = kBlockSize / get_warp_size(); tile_distribution_encoding<sequence<1>,
constexpr index_t M1 = kMPerBlock / (M2 * M0); tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
return make_static_tile_distribution( tuple<sequence<0>, sequence<1, 0, 2>>,
tile_distribution_encoding<sequence<1>, sequence<2, 1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>, sequence<3, 1>>{});
tuple<sequence<1>, sequence<1, 2>>, }
tuple<sequence<0>, sequence<2, 0>>, else
sequence<1, 2>, {
sequence<1, 1>>{}); constexpr index_t K1 = (K2 * M0) / get_warp_size();
#endif constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
}
else
{
constexpr index_t K1 = 16 / sizeof(ADataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
// coalesce reading for each blocks
if constexpr(get_warp_size() % (M2 * K0) == 0)
{
constexpr index_t M1 = BlockSize / get_warp_size();
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
constexpr index_t M0 = MPerBlock / (M2 * M1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
else
{
constexpr index_t M0 = BlockSize / get_warp_size();
constexpr index_t M1 = MPerBlock / (M2 * M0);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
}
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
{ {
using BDataType = remove_cvref_t<typename Problem::BDataType>; using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
constexpr index_t N0 = NPerBlock / N1;
constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize;
static_assert(total_pixels % N1 == 0);
constexpr index_t K3 = total_pixels / N1;
constexpr index_t KPack = GetSmemPackB<Problem>();
static_assert(KPack % K3 == 0);
constexpr index_t K2 = KPack / K3;
if constexpr(get_warp_size() % (K2 * N0) == 0)
{
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = BlockSize / get_warp_size();
static_assert(KPerBlock == K0 * K1 * K2 * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
else
{
constexpr index_t K1 = (K2 * N0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
}
else
{
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(BDataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
// coalesce reading for each blocks
if constexpr(get_warp_size() % (N2 * K0) == 0)
{
constexpr index_t N1 = BlockSize / get_warp_size();
static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
constexpr index_t N0 = NPerBlock / (N2 * N1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
// coalesce reading for each warps
else
{
constexpr index_t N0 = BlockSize / get_warp_size();
constexpr index_t N1 = NPerBlock / (N2 * N0);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDescriptor()
{
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = 16 / sizeof(BDataType); constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
constexpr index_t K0 = kKPerBlock / K1; constexpr index_t N0 = kNPerBlock / N1;
constexpr index_t N2 = get_warp_size() / K0; constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
#if 1 // coalesce reading for each blocks static_assert(total_pixels % N1 == 0);
constexpr index_t N1 = kBlockSize / get_warp_size(); constexpr index_t K3 = total_pixels / N1;
static_assert(N2 != 0, "M2 is zero, which will lead to a division by zero error."); constexpr index_t kKPack = GetSmemPackB<Problem>();
static_assert(N1 != 0, "M1 is zero, which will lead to a division by zero error."); static_assert(kKPack % K3 == 0);
constexpr index_t N0 = kNPerBlock / (N2 * N1); constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
constexpr index_t warp_size = get_warp_size();
return make_static_tile_distribution( if constexpr(warp_size % (K2 * N0) == 0)
tile_distribution_encoding<sequence<1>, {
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>, constexpr index_t K1 = warp_size / (K2 * N0);
tuple<sequence<1>, sequence<1, 2>>, constexpr index_t K0 = kBlockSize / warp_size;
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>, return make_static_tile_distribution(
sequence<0, 1>>{}); tile_distribution_encoding<sequence<1>,
#else // coalesce reading for each warps tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
constexpr index_t N0 = kBlockSize / get_warp_size(); tuple<sequence<2>, sequence<2, 1, 2>>,
constexpr index_t N1 = kNPerBlock / (N2 * N0); tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<1, 2>,
return make_static_tile_distribution( sequence<1, 3>>{});
tile_distribution_encoding<sequence<1>, }
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>, else
tuple<sequence<1>, sequence<1, 2>>, {
tuple<sequence<0>, sequence<2, 0>>, constexpr index_t K1 = (K2 * N0) / get_warp_size();
sequence<1, 2>, constexpr index_t K2_m = K2 / K1;
sequence<1, 1>>{}); constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
#endif static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDescriptor()
{
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t M0 = kMPerBlock / M1;
constexpr index_t total_pixels = kMPerBlock * kKPerBlock / kBlockSize;
static_assert(total_pixels % M1 == 0);
constexpr index_t K3 = total_pixels / M1;
constexpr index_t kKPack = GetSmemPackA<Problem>();
static_assert(kKPack % K3 == 0);
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
constexpr index_t warp_size = get_warp_size();
if constexpr(warp_size % (K2 * M0) == 0)
{
constexpr index_t K1 = warp_size / (K2 * M0);
constexpr index_t K0 = kBlockSize / warp_size;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
else
{
constexpr index_t K1 = (K2 * M0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
} }
template <typename Problem> template <typename Problem>
......
...@@ -3,40 +3,133 @@ ...@@ -3,40 +3,133 @@
#pragma once #pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace ck_tile { namespace ck_tile {
static constexpr int _VectorSize = 16;
template <typename ADataType_, template <typename ADataType_,
typename BDataType_, typename BDataType_,
typename CDataType_, typename CDataType_,
typename BlockGemmShape_, typename BlockGemmShape_,
typename TileGemmTraits_> typename TileGemmTraits_>
struct GemmPipelineProblem struct GemmPipelineProblemBase
{ {
using ADataType = remove_cvref_t<ADataType_>; using GemmTraits = remove_cvref_t<TileGemmTraits_>;
using BDataType = remove_cvref_t<BDataType_>;
using CDataType = remove_cvref_t<CDataType_>; using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using CDataType = remove_cvref_t<CDataType_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>; using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
using GemmTraits = remove_cvref_t<TileGemmTraits_>;
using ALayout = remove_cvref_t<typename GemmTraits::ALayout>; using ALayout = remove_cvref_t<typename GemmTraits::ALayout>;
using BLayout = remove_cvref_t<typename GemmTraits::BLayout>; using BLayout = remove_cvref_t<typename GemmTraits::BLayout>;
using CLayout = remove_cvref_t<typename GemmTraits::CLayout>; using CLayout = remove_cvref_t<typename GemmTraits::CLayout>;
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size(); static constexpr index_t VectorLoadSize = GemmTraits::_VectorSize;
static constexpr bool kPadA = GemmTraits::kPadA; static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
static constexpr bool kPadB = GemmTraits::kPadB;
static constexpr bool kPadC = GemmTraits::kPadC; static constexpr bool kPadM = GemmTraits::kPadM;
static constexpr bool kPadN = GemmTraits::kPadN;
static constexpr bool kPadK = GemmTraits::kPadK;
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
{
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
constexpr index_t pixels_per_thread =
BlockGemmShape::kM * BlockGemmShape::kK / kBlockSize;
return pixels_per_thread < VectorLoadSize / sizeof(ADataType)
? pixels_per_thread
: VectorLoadSize / sizeof(ADataType);
}
else
{
return VectorLoadSize / sizeof(ADataType);
}
}
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentB()
{
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
constexpr index_t pixels_per_thread =
BlockGemmShape::kN * BlockGemmShape::kK / kBlockSize;
return pixels_per_thread < VectorLoadSize / sizeof(BDataType)
? pixels_per_thread
: VectorLoadSize / sizeof(BDataType);
}
else
{
return VectorLoadSize / sizeof(BDataType);
}
}
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentC()
{
if constexpr(std::is_same_v<CLayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{
constexpr index_t N1 = kBlockSize / get_warp_size();
constexpr index_t N2 = std::min(BlockGemmShape::kN / N1, get_warp_size());
constexpr index_t M0 = get_warp_size() / N2;
constexpr index_t M1 = BlockGemmShape::kM / M0;
static constexpr index_t VectorSizeA = kPadA ? 1 : _VectorSize / sizeof(ADataType); return std::min(M1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
static constexpr index_t VectorSizeB = kPadB ? 1 : _VectorSize / sizeof(BDataType); }
static constexpr index_t VectorSizeC = kPadC ? 1 : _VectorSize / sizeof(CDataType); else
{
constexpr index_t M1 = kBlockSize / get_warp_size();
constexpr index_t M2 = std::min(BlockGemmShape::kM / M1, get_warp_size());
constexpr index_t N0 = get_warp_size() / M2;
constexpr index_t N1 = BlockGemmShape::kN / N0;
return std::min(N1, static_cast<index_t>(VectorLoadSize / sizeof(CDataType)));
}
}
static constexpr index_t VectorSizeA = []() {
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::RowMajor>)
{
return kPadK ? 1 : GetAlignmentA();
}
else
{
return kPadM ? 1 : GetAlignmentA();
}
}();
static constexpr index_t VectorSizeB = []() {
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::ColumnMajor>)
{
return kPadN ? 1 : GetAlignmentB();
}
else
{
return kPadK ? 1 : GetAlignmentB();
}
}();
static constexpr index_t VectorSizeC = []() {
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
return kPadN ? 1 : GetAlignmentC();
}
else
{
return kPadM ? 1 : GetAlignmentC();
}
}();
}; };
// Alias for GemmPipelineProblem
template <typename ADataType_,
typename BDataType_,
typename CDataType_,
typename BlockGemmShape_,
typename TileGemmTraits_>
using GemmPipelineProblem =
GemmPipelineProblemBase<ADataType_, BDataType_, CDataType_, BlockGemmShape_, TileGemmTraits_>;
template <typename ADataType_, template <typename ADataType_,
typename BDataType_, typename BDataType_,
typename CDataType_, typename CDataType_,
...@@ -45,30 +138,15 @@ template <typename ADataType_, ...@@ -45,30 +138,15 @@ template <typename ADataType_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave, GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true, bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full> TailNumber TailNum_ = TailNumber::Full>
struct UniversalGemmPipelineProblem struct UniversalGemmPipelineProblem : public GemmPipelineProblemBase<ADataType_,
BDataType_,
CDataType_,
BlockGemmShape_,
TileGemmTraits_>
{ {
using ADataType = remove_cvref_t<ADataType_>; static constexpr auto Scheduler = Scheduler_;
using BDataType = remove_cvref_t<BDataType_>; static constexpr auto HasHotLoop = HasHotLoop_;
using CDataType = remove_cvref_t<CDataType_>; static constexpr auto TailNum = TailNum_;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
using GemmTraits = remove_cvref_t<TileGemmTraits_>;
using ALayout = remove_cvref_t<typename GemmTraits::ALayout>;
using BLayout = remove_cvref_t<typename GemmTraits::BLayout>;
using CLayout = remove_cvref_t<typename GemmTraits::CLayout>;
static constexpr auto Scheduler = Scheduler_;
static constexpr auto HasHotLoop = HasHotLoop_;
static constexpr auto TailNum = TailNum_;
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
static constexpr bool kPadA = GemmTraits::kPadA;
static constexpr bool kPadB = GemmTraits::kPadB;
static constexpr bool kPadC = GemmTraits::kPadC;
static constexpr index_t VectorSizeA = kPadA ? _VectorSize / sizeof(ADataType) : 1;
static constexpr index_t VectorSizeB = kPadB ? _VectorSize / sizeof(BDataType) : 1;
static constexpr index_t VectorSizeC = kPadC ? _VectorSize / sizeof(CDataType) : 1;
}; };
} // namespace ck_tile } // namespace ck_tile
...@@ -9,12 +9,8 @@ ...@@ -9,12 +9,8 @@
namespace ck_tile { namespace ck_tile {
// UniversalGemm Policy // UniversalGemm Policy
template <typename LayoutA_, typename LayoutB_, typename LayoutC_>
struct UniversalGemmPipelineAgBgCrPolicy struct UniversalGemmPipelineAgBgCrPolicy
{ {
using LayoutA = remove_cvref_t<LayoutA_>;
using LayoutB = remove_cvref_t<LayoutB_>;
using LayoutC = remove_cvref_t<LayoutC_>;
static constexpr auto I0 = number<0>{}; static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{}; static constexpr auto I1 = number<1>{};
...@@ -34,13 +30,14 @@ struct UniversalGemmPipelineAgBgCrPolicy ...@@ -34,13 +30,14 @@ struct UniversalGemmPipelineAgBgCrPolicy
TransposeC>; TransposeC>;
using ADataType = remove_cvref_t<typename Problem::ADataType>; using ADataType = remove_cvref_t<typename Problem::ADataType>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = WarpGemm::kK; constexpr index_t K1 = WarpGemm::kK;
constexpr index_t K0 = KPerBlock / K1; constexpr index_t K0 = KPerBlock / K1;
if constexpr(std::is_same<tensor_layout::gemm::RowMajor, LayoutA>::value) if constexpr(std::is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
{ {
constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1 constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1
? 1 ? 1
...@@ -176,13 +173,15 @@ struct UniversalGemmPipelineAgBgCrPolicy ...@@ -176,13 +173,15 @@ struct UniversalGemmPipelineAgBgCrPolicy
using BDataType = remove_cvref_t<typename Problem::BDataType>; using BDataType = remove_cvref_t<typename Problem::BDataType>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = WarpGemm::kK; constexpr index_t K1 = WarpGemm::kK;
constexpr index_t K0 = KPerBlock / K1; constexpr index_t K0 = KPerBlock / K1;
if constexpr(std::is_same<tensor_layout::gemm::ColumnMajor, LayoutB>::value) if constexpr(std::is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
{ {
// NLdsLayer * K0 as logical Bank // NLdsLayer * K0 as logical Bank
constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1 constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1
...@@ -331,72 +330,285 @@ struct UniversalGemmPipelineAgBgCrPolicy ...@@ -331,72 +330,285 @@ struct UniversalGemmPipelineAgBgCrPolicy
return smem_size; return smem_size;
} }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
return Problem::VectorLoadSize / sizeof(ADataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB()
{
using BDataType = remove_cvref_t<typename Problem::BDataType>;
return Problem::VectorLoadSize / sizeof(BDataType);
}
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
{ {
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType, using ADataType = remove_cvref_t<typename Problem::ADataType>;
typename Problem::BDataType, using ALayout = remove_cvref_t<typename Problem::ALayout>;
typename Problem::CDataType,
Problem::BlockGemmShape::WarpTile::at(I0),
Problem::BlockGemmShape::WarpTile::at(I1),
Problem::BlockGemmShape::WarpTile::at(I2),
TransposeC>;
constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = WarpGemm::kK; if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
constexpr index_t K0 = KPerBlock / K1; {
constexpr index_t M2 = get_warp_size() / K0; constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t M0 = MPerBlock / M1;
constexpr index_t M1 = BlockSize / get_warp_size(); constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error."); static_assert(total_pixels % M1 == 0);
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error."); constexpr index_t K3 = total_pixels / M1;
constexpr index_t M0 = MPerBlock / (M2 * M1); constexpr index_t KPack = GetSmemPackA<Problem>();
static_assert(KPack % K3 == 0);
return make_static_tile_distribution( constexpr index_t K2 = KPack / K3;
tile_distribution_encoding<sequence<1>, if constexpr(get_warp_size() % (K2 * M0) == 0)
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>, {
tuple<sequence<1>, sequence<1, 2>>, constexpr index_t K1 = get_warp_size() / (K2 * M0);
tuple<sequence<1>, sequence<2, 0>>, constexpr index_t K0 = BlockSize / get_warp_size();
sequence<1, 2>, static_assert(KPerBlock == K0 * K1 * K2 * K3);
sequence<0, 1>>{}); return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
else
{
constexpr index_t K1 = (K2 * M0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
}
else
{
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
if constexpr(get_warp_size() % (M2 * K0) == 0)
{
constexpr index_t M1 = BlockSize / get_warp_size();
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
constexpr index_t M0 = MPerBlock / (M2 * M1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
else
{
constexpr index_t M0 = BlockSize / get_warp_size();
constexpr index_t M1 = MPerBlock / (M2 * M0);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
}
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
{ {
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType, using BDataType = remove_cvref_t<typename Problem::BDataType>;
typename Problem::BDataType, using BLayout = remove_cvref_t<typename Problem::BLayout>;
typename Problem::CDataType,
Problem::BlockGemmShape::WarpTile::at(I0),
Problem::BlockGemmShape::WarpTile::at(I1),
Problem::BlockGemmShape::WarpTile::at(I2),
TransposeC>;
constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = WarpGemm::kK; if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
constexpr index_t K0 = KPerBlock / K1; {
constexpr index_t N2 = get_warp_size() / K0; constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
constexpr index_t N0 = NPerBlock / N1;
constexpr index_t N1 = BlockSize / get_warp_size(); constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize;
static_assert(N2 != 0, "M2 is zero, which will lead to a division by zero error."); static_assert(total_pixels % N1 == 0);
static_assert(N1 != 0, "M1 is zero, which will lead to a division by zero error."); constexpr index_t K3 = total_pixels / N1;
constexpr index_t N0 = NPerBlock / (N2 * N1); constexpr index_t KPack = GetSmemPackB<Problem>();
static_assert(KPack % K3 == 0);
return make_static_tile_distribution( constexpr index_t K2 = KPack / K3;
tile_distribution_encoding<sequence<1>, if constexpr(get_warp_size() % (K2 * N0) == 0)
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>, {
tuple<sequence<1>, sequence<1, 2>>, constexpr index_t K1 = get_warp_size() / (K2 * N0);
tuple<sequence<1>, sequence<2, 0>>, constexpr index_t K0 = BlockSize / get_warp_size();
sequence<1, 2>, static_assert(KPerBlock == K0 * K1 * K2 * K3);
sequence<0, 1>>{}); return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
else
{
constexpr index_t K1 = (K2 * N0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
}
else
{
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(BDataType);
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
// coalesce reading for each blocks
if constexpr(get_warp_size() % (N2 * K0) == 0)
{
constexpr index_t N1 = BlockSize / get_warp_size();
static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
constexpr index_t N0 = NPerBlock / (N2 * N1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
// coalesce reading for each warps
else
{
constexpr index_t N0 = BlockSize / get_warp_size();
constexpr index_t N1 = NPerBlock / (N2 * N0);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDescriptor()
{
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t M0 = MPerBlock / M1;
constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize;
static_assert(total_pixels % M1 == 0);
constexpr index_t K3 = total_pixels / M1;
constexpr index_t kKPack = GetSmemPackB<Problem>();
static_assert(kKPack % K3 == 0);
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
constexpr index_t warp_size = get_warp_size();
if constexpr(warp_size % (K2 * M0) == 0)
{
constexpr index_t K1 = warp_size / (K2 * M0);
constexpr index_t K0 = BlockSize / warp_size;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
else
{
constexpr index_t K1 = (K2 * M0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDescriptor()
{
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
constexpr index_t N0 = NPerBlock / N1;
constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize;
static_assert(total_pixels % N1 == 0);
constexpr index_t K3 = total_pixels / N1;
constexpr index_t kKPack = GetSmemPackB<Problem>();
static_assert(kKPack % K3 == 0);
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
constexpr index_t warp_size = get_warp_size();
if constexpr(warp_size % (K2 * N0) == 0)
{
constexpr index_t K1 = warp_size / (K2 * N0);
constexpr index_t K0 = BlockSize / warp_size;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
else
{
constexpr index_t K1 = (K2 * N0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
} }
template <typename Problem> template <typename Problem>
......
...@@ -3,19 +3,23 @@ ...@@ -3,19 +3,23 @@
#pragma once #pragma once
#include "ck_tile/core.hpp"
namespace ck_tile { namespace ck_tile {
template <bool kPadA_, template <bool kPadM_,
bool kPadB_, bool kPadN_,
bool kPadC_, bool kPadK_,
typename ALayout_, typename ALayout_,
typename BLayout_, typename BLayout_,
typename CLayout_> typename CLayout_>
struct TileGemmTraits struct TileGemmTraits
{ {
static constexpr bool kPadA = kPadA_; static constexpr bool kPadM = kPadM_;
static constexpr bool kPadB = kPadB_; static constexpr bool kPadN = kPadN_;
static constexpr bool kPadC = kPadC_; static constexpr bool kPadK = kPadK_;
static constexpr int _VectorSize = 16;
using ALayout = ALayout_; using ALayout = ALayout_;
using BLayout = BLayout_; using BLayout = BLayout_;
......
...@@ -53,9 +53,9 @@ class TestCkTileGemmMemPipeline : public ::testing::Test ...@@ -53,9 +53,9 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
constexpr ck_tile::index_t N_Warp_Tile = 32; constexpr ck_tile::index_t N_Warp_Tile = 32;
constexpr ck_tile::index_t K_Warp_Tile = 8; constexpr ck_tile::index_t K_Warp_Tile = 8;
constexpr bool kPadA = true; constexpr bool kPadM = true;
constexpr bool kPadB = true; constexpr bool kPadN = true;
constexpr bool kPadC = true; constexpr bool kPadK = true;
constexpr int kBlockPerCu = 1; constexpr int kBlockPerCu = 1;
...@@ -68,9 +68,9 @@ class TestCkTileGemmMemPipeline : public ::testing::Test ...@@ -68,9 +68,9 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>; using TilePartitioner = ck_tile::GemmTilePartitioner<GemmShape>;
using GemmEpilogue = ck_tile::Default2DEpilogue< using GemmEpilogue = ck_tile::Default2DEpilogue<
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, false, kPadC>>; ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
using Traits = ck_tile::TileGemmTraits<kPadA, kPadB, kPadC, ALayout, BLayout, CLayout>; using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem< using BaseGemmPipeline = ck_tile::BaseGemmPipelineAgBgCrMem<
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>; ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>;
...@@ -108,7 +108,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test ...@@ -108,7 +108,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
if(s.log_level_ > 0) if(s.log_level_ > 0)
{ {
std::cout << "Lunching kernel with args:" std::cout << "Launching kernel with args:"
<< " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}" << " grid: {" << grids.x << ", " << grids.y << ", " << grids.z << "}"
<< ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z << ", blocks: {" << blocks.x << ", " << blocks.y << ", " << blocks.z
<< "}" << std::endl; << "}" << std::endl;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment