Commit b2c7d774 authored by ThomasNing's avatar ThomasNing
Browse files

Add the changes from include/ck_tile

parent d1e71770
...@@ -66,7 +66,6 @@ else() ...@@ -66,7 +66,6 @@ else()
-Wunreachable-code -Wunreachable-code
-Wunused -Wunused
-Wno-reserved-identifier -Wno-reserved-identifier
-v --save-temps -Wno-gnu-line-marke
-Werror -Werror
-Wno-option-ignored -Wno-option-ignored
-Wsign-compare -Wsign-compare
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "ck_tile/host/reference/reference_batched_masking.hpp" #include "ck_tile/host/reference/reference_batched_masking.hpp"
#include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp" #include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp"
#include "ck_tile/host/reference/reference_batched_softmax.hpp" #include "ck_tile/host/reference/reference_batched_softmax.hpp"
#include "ck_tile/host/reference/reference_batched_transpose.hpp"
#include "ck_tile/host/reference/reference_elementwise.hpp" #include "ck_tile/host/reference/reference_elementwise.hpp"
#include "ck_tile/host/reference/reference_fused_moe.hpp" #include "ck_tile/host/reference/reference_fused_moe.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp" #include "ck_tile/host/reference/reference_gemm.hpp"
...@@ -34,4 +35,3 @@ ...@@ -34,4 +35,3 @@
#include "ck_tile/host/reference/reference_topk.hpp" #include "ck_tile/host/reference/reference_topk.hpp"
#include "ck_tile/host/stream_config.hpp" #include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/timer.hpp" #include "ck_tile/host/timer.hpp"
#include "ck_tile/host/reference/reference_batched_transpose.hpp"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
...@@ -32,11 +32,11 @@ ...@@ -32,11 +32,11 @@
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_compute_v4_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_compute_v4_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
......
...@@ -26,12 +26,14 @@ struct BlockGemmARegBRegCRegV1 ...@@ -26,12 +26,14 @@ struct BlockGemmARegBRegCRegV1
static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK; static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>(); static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>; using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
static constexpr index_t MWarp = config.template at<1>(); static constexpr index_t MWarp = config.template at<1>();
static constexpr index_t NWarp = config.template at<2>(); static constexpr index_t NWarp = config.template at<2>();
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
static constexpr index_t KIterPerWarp = KPerBlock / WG::kK; static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
static constexpr index_t KPack = WarpGemm::kKPerThread;
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
{ {
...@@ -43,7 +45,7 @@ struct BlockGemmARegBRegCRegV1 ...@@ -43,7 +45,7 @@ struct BlockGemmARegBRegCRegV1
sequence<1, 2>, sequence<1, 2>,
sequence<0, 0>>{}; sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
return a_block_dstr_encode; return a_block_dstr_encode;
} }
...@@ -58,7 +60,7 @@ struct BlockGemmARegBRegCRegV1 ...@@ -58,7 +60,7 @@ struct BlockGemmARegBRegCRegV1
sequence<1, 2>, sequence<1, 2>,
sequence<0, 0>>{}; sequence<0, 0>>{};
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{}); b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
return b_block_dstr_encode; return b_block_dstr_encode;
} }
...@@ -73,7 +75,7 @@ struct BlockGemmARegBRegCRegV1 ...@@ -73,7 +75,7 @@ struct BlockGemmARegBRegCRegV1
sequence<1, 2>, sequence<1, 2>,
sequence<0, 0>>{}; sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
return c_block_dstr_encode; return c_block_dstr_encode;
} }
...@@ -112,13 +114,13 @@ struct BlockGemmARegBRegCRegV1 ...@@ -112,13 +114,13 @@ struct BlockGemmARegBRegCRegV1
.get_static_tile_distribution_encoding())>>, .get_static_tile_distribution_encoding())>>,
"C distribution is wrong!"); "C distribution is wrong!");
using AWarpDstr = typename WG::AWarpDstr; using AWarpDstr = typename WarpGemm::AWarpDstr;
using BWarpDstr = typename WG::BWarpDstr; using BWarpDstr = typename WarpGemm::BWarpDstr;
using CWarpDstr = typename WG::CWarpDstr; using CWarpDstr = typename WarpGemm::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor; using AWarpTensor = typename WarpGemm::AWarpTensor;
using BWarpTensor = typename WG::BWarpTensor; using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WG::CWarpTensor; using CWarpTensor = typename WarpGemm::CWarpTensor;
constexpr auto a_warp_y_lengths = constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
...@@ -157,7 +159,7 @@ struct BlockGemmARegBRegCRegV1 ...@@ -157,7 +159,7 @@ struct BlockGemmARegBRegCRegV1
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM // warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor // write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data( c_block_tensor.set_y_sliced_thread_data(
...@@ -180,7 +182,7 @@ struct BlockGemmARegBRegCRegV1 ...@@ -180,7 +182,7 @@ struct BlockGemmARegBRegCRegV1
sequence<0, 0>>{}; sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr); auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor; return c_block_tensor;
......
...@@ -45,17 +45,17 @@ struct GemmPipelineAgBgCrImplBase ...@@ -45,17 +45,17 @@ struct GemmPipelineAgBgCrImplBase
{ {
load_tile(dst_block_tile, lds_tile_window); load_tile(dst_block_tile, lds_tile_window);
} }
CK_TILE_DEVICE auto GetABLdsTensorViews(void* p_smem) const CK_TILE_DEVICE auto GetABLdsTensorViews(void* p_smem) const
{ {
// A tile in LDS // A tile in LDS
ADataType* __restrict__ p_a_lds = static_cast<ADataType*>(p_smem); ADataType* __restrict__ p_a_lds = static_cast<ADataType*>(p_smem);
constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>(); constexpr auto a_lds_block_desc = Policy::template MakeALdsBlockDescriptor<Problem>();
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc); auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
// TODO: LDS alignment should come from Policy! // TODO: LDS alignment should come from Policy!
constexpr index_t a_lds_block_space_size_aligned = constexpr index_t a_lds_block_space_size_aligned = integer_least_multiple(
integer_least_multiple(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16); sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16);
// B tile in LDS // B tile in LDS
BDataType* __restrict__ p_b_lds = static_cast<BDataType*>( BDataType* __restrict__ p_b_lds = static_cast<BDataType*>(
......
...@@ -72,8 +72,6 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -72,8 +72,6 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
static constexpr bool isDoubleSmemBuffer = Problem::isDoubleSmemBuffer; static constexpr bool isDoubleSmemBuffer = Problem::isDoubleSmemBuffer;
static constexpr bool isDoubleSmemBuffer = Problem::isDoubleSmemBuffer;
static constexpr bool HasHotLoop = Problem::HasHotLoop; static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum; static constexpr auto TailNum = Problem::TailNum;
static constexpr auto Scheduler = Problem::Scheduler; static constexpr auto Scheduler = Problem::Scheduler;
......
...@@ -9,8 +9,30 @@ ...@@ -9,8 +9,30 @@
namespace ck_tile { namespace ck_tile {
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem>
struct BaseGemmPipelineAgBgCrCompV4
{
static constexpr index_t PrefetchStages = 3;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
{
ignore = num_loop;
return TailNumber::Two;
}
};
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy> template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy>
struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem> struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
{ {
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>; using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>; using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
...@@ -35,9 +57,9 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -35,9 +57,9 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK; static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t VectorSizeA = Problem::VectorSizeA; static constexpr index_t VectorSizeA = Policy::template GetVectorSizeA<Problem>();
static constexpr index_t VectorSizeB = Problem::VectorSizeB; static constexpr index_t VectorSizeB = Policy::template GetVectorSizeB<Problem>();
static constexpr index_t VectorSizeC = Problem::VectorSizeC; static constexpr index_t VectorSizeC = Policy::template GetVectorSizeC<Problem>();
static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadN = Problem::kPadN;
...@@ -54,7 +76,10 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -54,7 +76,10 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
return Policy::template GetSmemSize<Problem>(); return Policy::template GetSmemSize<Problem>();
} }
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); } CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
{
return Policy::template IsTransposeC<Problem>();
}
template <GemmPipelineScheduler Scheduler> template <GemmPipelineScheduler Scheduler>
struct PipelineImpl : public PipelineImplBase struct PipelineImpl : public PipelineImplBase
...@@ -115,12 +140,12 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -115,12 +140,12 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
ignore = i; ignore = i;
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
__builtin_amdgcn_sched_group_barrier( __builtin_amdgcn_sched_group_barrier(
0x100, num_ds_read_inst / num_issue, 0); // DS read : 2 0x100, num_ds_read_inst / num_issue, 0); // DS read : 2
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA: 1 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA: 1
__builtin_amdgcn_sched_group_barrier( __builtin_amdgcn_sched_group_barrier(
0x200, num_ds_write_inst / num_issue, 0); // DS write : 1 0x200, num_ds_write_inst / num_issue, 0); // DS write : 1
__builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1 __builtin_amdgcn_sched_group_barrier(0x008, 1, 0); // MFMA : 1
__builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read :1 __builtin_amdgcn_sched_group_barrier(0x020, 1, 0); // VMEM read :1
__builtin_amdgcn_sched_group_barrier( __builtin_amdgcn_sched_group_barrier(
0x008, C_MFMA_Inst_Num / num_issue - 3, 0); // MFMA : 5 0x008, C_MFMA_Inst_Num / num_issue - 3, 0); // MFMA : 5
}); });
...@@ -147,11 +172,22 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -147,11 +172,22 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
remove_cvref_t<typename BDramBlockWindowTmp::DataType>>, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
"wrong!"); "wrong!");
static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && constexpr bool is_a_col_major =
NPerBlock == std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!"); static_assert(is_a_col_major
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"A block window has incorrect lengths for defined ALayout!");
static_assert(is_b_row_major
? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
////////////// global window & register ///////////////// ////////////// global window & register /////////////////
// A DRAM tile window for load // A DRAM tile window for load
...@@ -176,37 +212,33 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -176,37 +212,33 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
ABlockTile a_global_load_tile; ABlockTile a_global_load_tile;
BBlockTile b_global_load_tile; BBlockTile b_global_load_tile;
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step =
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
// global prefetch 0 // global prefetch 0
// global read 0 // global read 0
Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window); Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window); Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
////////////// LDS desc, window & register ///////////////// ////////////// LDS desc, window & register /////////////////
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0); auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0);
auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1); auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1);
auto a_copy_lds_window0 = auto a_copy_lds_window0 = make_tile_window(
make_tile_window(a_lds_block0, a_lds_block0, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
{0, 0}, auto a_copy_lds_window1 = make_tile_window(
ABlockTileDistr); a_lds_block1, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
auto a_copy_lds_window1 = auto b_copy_lds_window0 = make_tile_window(
make_tile_window(a_lds_block1, b_lds_block0, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
{0, 0}, auto b_copy_lds_window1 = make_tile_window(
ABlockTileDistr); b_lds_block1, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
auto b_copy_lds_window0 =
make_tile_window(b_lds_block0,
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{0, 0},
BBlockTileDistr);
auto b_copy_lds_window1 =
make_tile_window(b_lds_block1,
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{0, 0},
BBlockTileDistr);
// Block GEMM // Block GEMM
auto block_gemm = BlockGemm(); auto block_gemm = BlockGemm();
...@@ -216,11 +248,32 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -216,11 +248,32 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0 // LDS write 0
Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func); if constexpr(is_a_col_major)
Base::LocalPrefill(b_copy_lds_window0, b_global_load_tile, b_element_func); {
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window0, b_global_load_tile, b_element_func);
}
// global read 1 // global read 1
Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window); Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window); Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
block_sync_lds(); block_sync_lds();
...@@ -262,11 +315,31 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -262,11 +315,31 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0); Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0);
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0); Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0);
Base::LocalPrefill(a_copy_lds_window1, a_global_load_tile, a_element_func); if constexpr(is_a_col_major)
Base::LocalPrefill(b_copy_lds_window1, b_global_load_tile, b_element_func); {
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window1, a_global_load_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window1, b_global_load_tile, b_element_func);
}
Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window); Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window); Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
if(HasHotLoop) if(HasHotLoop)
{ {
...@@ -280,11 +353,35 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -280,11 +353,35 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1); Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1);
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1); Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1);
Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func); if constexpr(is_a_col_major)
Base::LocalPrefill(b_copy_lds_window0, b_global_load_tile, b_element_func); {
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window); Policy::template MakeShuffledARegTileDistribution<Problem>());
Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window); transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(
a_copy_lds_window0, a_global_load_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(
b_copy_lds_window0, b_global_load_tile, b_element_func);
}
Base::GlobalPrefetch(
a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(
b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
// gemm // gemm
block_gemm(c_block_tile, a_block_tile0, b_block_tile0); block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
HotLoopScheduler(); HotLoopScheduler();
...@@ -296,11 +393,35 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -296,11 +393,35 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0); Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0);
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0); Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0);
Base::LocalPrefill(a_copy_lds_window1, a_global_load_tile, a_element_func); if constexpr(is_a_col_major)
Base::LocalPrefill(b_copy_lds_window1, b_global_load_tile, b_element_func); {
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window); Policy::template MakeShuffledARegTileDistribution<Problem>());
Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window); transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(
a_copy_lds_window1, a_global_load_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(
b_copy_lds_window1, b_global_load_tile, b_element_func);
}
Base::GlobalPrefetch(
a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(
b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
// gemm // gemm
block_gemm(c_block_tile, a_block_tile1, b_block_tile1); block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
HotLoopScheduler(); HotLoopScheduler();
...@@ -318,8 +439,28 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -318,8 +439,28 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
block_sync_lds(); block_sync_lds();
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1); Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1);
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1); Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1);
Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func); if constexpr(is_a_col_major)
Base::LocalPrefill(b_copy_lds_window0, b_global_load_tile, b_element_func); {
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window0, b_global_load_tile, b_element_func);
}
block_gemm(c_block_tile, a_block_tile0, b_block_tile0); block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
} }
// 2 // 2
......
...@@ -5,6 +5,7 @@ ...@@ -5,6 +5,7 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
namespace ck_tile { namespace ck_tile {
// Default policy for GemmPipelineAGmemBGmemCRegV1 // Default policy for GemmPipelineAGmemBGmemCRegV1
...@@ -15,9 +16,172 @@ struct GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy ...@@ -15,9 +16,172 @@ struct GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy
static constexpr auto I1 = number<1>{}; static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{}; static constexpr auto I2 = number<2>{};
static constexpr bool TransposeC = true; static constexpr auto ATileAccessPattern = tile_distribution_pattern::thread_raked;
static constexpr auto BTileAccessPattern = tile_distribution_pattern::thread_raked;
/**
* @brief Get the maximum global memory vector load size.
*
* @tparam Problem The UniversalGemmPipelineProblem object.
* @tparam DataType The tensor data type we're considering.
* @tparam MNPerBlock The MPerBlock or NPerBlock value depending on tensor (A/B).
* @tparam XPerTile The contiguous Tile dimension size.
* @return Maximum DRAM vector load size.
*/
template <typename Problem, typename DataType, index_t MNPerBlock, index_t XPerTile>
CK_TILE_HOST_DEVICE static constexpr auto GetGlobalVectorLoadSize()
{
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize;
// Assume DataType is even!
if constexpr(XPerTile % (16 / sizeof(DataType)) == 0 &&
elements_per_thread % (16 / sizeof(DataType)) == 0)
{
return (16 / sizeof(DataType));
}
else if constexpr(XPerTile % (8 / sizeof(DataType)) == 0 &&
elements_per_thread % (8 / sizeof(DataType)) == 0)
{
return (8 / sizeof(DataType));
}
else if constexpr(sizeof(DataType) >= 4 && XPerTile % (4 / sizeof(DataType)) == 0 &&
elements_per_thread % (4 / sizeof(DataType)) == 0)
{
return (4 / sizeof(DataType));
}
else if constexpr(sizeof(DataType) >= 2 && XPerTile % (2 / sizeof(DataType)) == 0 &&
elements_per_thread % (2 / sizeof(DataType)) == 0)
{
return (2 / sizeof(DataType));
}
else
{
return 1;
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA()
{
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
return GetGlobalVectorLoadSize<Problem, ADataType, MPerBlock, KPerBlock>();
}
else
{
return GetGlobalVectorLoadSize<Problem, ADataType, MPerBlock, MPerBlock>();
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB()
{
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock, NPerBlock>();
}
else
{
return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock, KPerBlock>();
}
}
/**
* @brief Get the vector store size for C tensor.
*
* @tparam Problem - Gemm pipeline problem class.
*
* @note The vector store size for output C tensor would depend on multiple factors
* like its data layout and warp gemm C transposition. In general it would
* be the number of consecutive elements in contiguous C dimension hold by
* single thread.
*
* @return The vector store size for C tensor.
*/
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
{
using BlockGemm = remove_cvref_t<decltype(GetBlockGemm<Problem>())>;
using WG = typename BlockGemm::WarpGemm;
constexpr bool TransposeC = Problem::TransposeC;
using CLayout = typename Problem::CLayout;
using CWarpDstr = typename WG::CWarpDstr;
// N is contiguous dimension
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
if constexpr(TransposeC)
{
// In this case each thread has multiple consecutive elements in
// N dimension, however consecutive threads' elements have stride.
constexpr index_t NDimY = CWarpDstr::NDimY;
constexpr auto c_warp_y_lengths =
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
c_warp_y_lengths.get(number<NDimY - 1>{}));
return c_warp_y_lengths.get(number<NDimY - 1>{});
}
else
{
// In this case each thread has just a single item in Ndim
return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
}
}
// M is contiguous dimension
else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
{
if constexpr(TransposeC)
{
// In this case each thread has just a single item in Mdim
return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
}
else
{
// In this case each thread has multiple consecutive elements in
// M dimension, however consecutive threads' elements have stride.
constexpr index_t NDimY = CWarpDstr::NDimY;
constexpr auto c_warp_y_lengths =
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
c_warp_y_lengths.get(number<NDimY - 1>{}));
return c_warp_y_lengths.get(number<NDimY - 1>{});
}
}
else
{
static_assert(false, "Unsupported CLayout!");
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
{
using BlockGemm = decltype(GetBlockGemm<Problem>());
constexpr index_t KPack = BlockGemm::KPack;
return KPack;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB()
{
using BlockGemm = decltype(GetBlockGemm<Problem>());
constexpr index_t KPack = BlockGemm::KPack;
return KPack;
}
// 3d + padding
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{ {
...@@ -52,7 +216,7 @@ struct GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy ...@@ -52,7 +216,7 @@ struct GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / 8>{}, number<kNPerBlock>{}, number<8>{}), make_tuple(number<kKPerBlock / 8>{}, number<kNPerBlock>{}, number<8>{}),
make_tuple(number<(kNPerBlock) * 8>{}, number<8>{}, number<1>{}), make_tuple(number<(kNPerBlock)*8>{}, number<8>{}, number<1>{}),
number<8>{}, number<8>{},
number<1>{}); number<1>{});
...@@ -65,20 +229,24 @@ struct GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy ...@@ -65,20 +229,24 @@ struct GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy
return b_lds_block_desc; return b_lds_block_desc;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA() CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{ {
constexpr index_t smem_size_a = integer_least_multiple(sizeof(typename Problem::ADataType) * constexpr index_t smem_size_a =
MakeALdsBlockDescriptor<Problem>().get_element_space_size(), 16); integer_least_multiple(sizeof(typename Problem::ADataType) *
MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
16);
return smem_size_a; return smem_size_a;
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB() CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
{ {
constexpr index_t smem_size_b = integer_least_multiple(sizeof(typename Problem::BDataType) * constexpr index_t smem_size_b =
MakeBLdsBlockDescriptor<Problem>().get_element_space_size(), 16); integer_least_multiple(sizeof(typename Problem::BDataType) *
MakeBLdsBlockDescriptor<Problem>().get_element_space_size(),
16);
return smem_size_b; return smem_size_b;
} }
...@@ -91,301 +259,111 @@ struct GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy ...@@ -91,301 +259,111 @@ struct GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy
return smem_size_a + smem_size_b; return smem_size_a + smem_size_b;
} }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
return Problem::VectorLoadSize / sizeof(ADataType);
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB()
{
using BDataType = remove_cvref_t<typename Problem::BDataType>;
return Problem::VectorLoadSize / sizeof(BDataType);
}
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
{ {
using ADataType = remove_cvref_t<typename Problem::ADataType>; using ALayout = remove_cvref_t<typename Problem::ALayout>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>) // Tile: MPerBlock X KPerBlock
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
{ {
constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType); using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
constexpr index_t M0 = MPerBlock / M1; MPerBlock,
constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize; KPerBlock,
static_assert(total_pixels % M1 == 0); VecLoadSize,
constexpr index_t K3 = total_pixels / M1; ATileAccessPattern>;
constexpr index_t KPack = GetSmemPackA<Problem>(); return TileEncodingPattern::Make2DStaticTileDistribution();
static_assert(KPack % K3 == 0);
constexpr index_t K2 = KPack / K3;
if constexpr(get_warp_size() % (K2 * M0))
{
constexpr index_t K1 = get_warp_size() / (K2 * M0);
constexpr index_t K0 = BlockSize / get_warp_size();
static_assert(KPerBlock == K0 * K1 * K2 * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
else
{
constexpr index_t K1 = (K2 * M0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
} }
// Tile: KPerBlock X MPerBlock
else else
{ {
constexpr index_t K1 = 16 / sizeof(ADataType); using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
constexpr index_t K0 = KPerBlock / K1; KPerBlock,
constexpr index_t M2 = get_warp_size() / K0; MPerBlock,
// coalesce reading for each blocks VecLoadSize,
if constexpr(get_warp_size() % (M2 * K0) == 0) ATileAccessPattern>;
{ return TileEncodingPattern::Make2DStaticTileDistribution();
constexpr index_t M1 = BlockSize / get_warp_size();
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
constexpr index_t M0 = MPerBlock / (M2 * M1);
static_assert(M0 * M1 * M2 == MPerBlock,
"Incorrect M0, M2, M1 configuration! "
"M0, M1, M2 must cover whole MPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
else
{
constexpr index_t M0 = BlockSize / get_warp_size();
constexpr index_t M1 = MPerBlock / (M2 * M0);
static_assert(M0 * M1 * M2 == MPerBlock,
"Incorrect M0, M1, M2 configuration! "
"M0, M1, M2 must cover whole MPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
} }
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
{ {
using BDataType = remove_cvref_t<typename Problem::BDataType>; using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
// Tile: KPerBlock X NPerBlock
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{ {
constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType); using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
constexpr index_t N0 = NPerBlock / N1; KPerBlock,
constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize; NPerBlock,
static_assert(total_pixels % N1 == 0); VecLoadSize,
constexpr index_t K3 = total_pixels / N1; BTileAccessPattern>;
constexpr index_t KPack = GetSmemPackB<Problem>(); return TileEncodingPattern::Make2DStaticTileDistribution();
static_assert(KPack % K3 == 0);
constexpr index_t K2 = KPack / K3;
if constexpr(get_warp_size() % (K2 * N0) == 0)
{
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = BlockSize / get_warp_size();
static_assert(KPerBlock == K0 * K1 * K2 * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
else
{
constexpr index_t K1 = (K2 * N0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
} }
// Tile: NPerBlock X KPerBlock
else else
{ {
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(BDataType); NPerBlock,
constexpr index_t K0 = KPerBlock / K1; KPerBlock,
constexpr index_t N2 = get_warp_size() / K0; VecLoadSize,
// coalesce reading for each blocks BTileAccessPattern>;
if constexpr(get_warp_size() % (N2 * K0) == 0) return TileEncodingPattern::Make2DStaticTileDistribution();
{
constexpr index_t N1 = BlockSize / get_warp_size();
static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
constexpr index_t N0 = NPerBlock / (N2 * N1);
static_assert(N0 * N1 * N2 == NPerBlock,
"Incorrect N0, N1, N2 configuration! "
"N0, N1, N2 must cover whole NPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
// coalesce reading for each warps
else
{
constexpr index_t N0 = BlockSize / get_warp_size();
constexpr index_t N1 = NPerBlock / (N2 * N0);
static_assert(N0 * N1 * N2 == NPerBlock,
"Incorrect N0, N1, N2 configuration! "
"N0, N1, N2 must cover whole NPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
} }
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegTileDistribution()
{ {
using BLayout = remove_cvref_t<typename Problem::BLayout>; using ALayout = remove_cvref_t<typename Problem::ALayout>;
using BDataType = remove_cvref_t<typename Problem::BDataType>; static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>); constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t MPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType); using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
constexpr index_t N0 = kNPerBlock / N1; KPerBlock,
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; MPerBlock,
static_assert(total_pixels % N1 == 0); VecLoadSize,
constexpr index_t K3 = total_pixels / N1; ATileAccessPattern>;
constexpr index_t kKPack = GetSmemPackB<Problem>(); return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
static_assert(kKPack % K3 == 0);
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
constexpr index_t warp_size = get_warp_size();
if constexpr(warp_size % (K2 * N0) == 0)
{
constexpr index_t K1 = warp_size / (K2 * N0);
constexpr index_t K0 = kBlockSize / warp_size;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
else
{
constexpr index_t K1 = (K2 * N0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegTileDistribution()
{ {
using ALayout = remove_cvref_t<typename Problem::ALayout>; using BLayout = remove_cvref_t<typename Problem::BLayout>;
using ADataType = remove_cvref_t<typename Problem::ADataType>; static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>);
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>); constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType); using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
constexpr index_t M0 = kMPerBlock / M1; KPerBlock,
constexpr index_t total_pixels = kMPerBlock * kKPerBlock / kBlockSize; NPerBlock,
static_assert(total_pixels % M1 == 0); VecLoadSize,
constexpr index_t K3 = total_pixels / M1; BTileAccessPattern>;
constexpr index_t kKPack = GetSmemPackA<Problem>(); return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
static_assert(kKPack % K3 == 0);
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
constexpr index_t warp_size = get_warp_size();
if constexpr(warp_size % (K2 * M0) == 0)
{
constexpr index_t K1 = warp_size / (K2 * M0);
constexpr index_t K0 = kBlockSize / warp_size;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
else
{
constexpr index_t K1 = (K2 * M0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = kBlockSize / get_warp_size() / K1;
static_assert(kKPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
} }
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return TransposeC; } template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
{
return Problem::TransposeC;
}
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
...@@ -399,7 +377,7 @@ struct GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy ...@@ -399,7 +377,7 @@ struct GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy
WarpTile::at(I0), WarpTile::at(I0),
WarpTile::at(I1), WarpTile::at(I1),
WarpTile::at(I2), WarpTile::at(I2),
TransposeC>; Problem::TransposeC>;
using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::ADataType, using BlockGemmPolicy = BlockGemmARegBRegCRegV1CustomPolicy<typename Problem::ADataType,
typename Problem::BDataType, typename Problem::BDataType,
typename Problem::CDataType, typename Problem::CDataType,
......
...@@ -33,7 +33,7 @@ struct GemmPipelineProblemBase ...@@ -33,7 +33,7 @@ struct GemmPipelineProblemBase
static constexpr bool kPadN = Traits::kPadN; static constexpr bool kPadN = Traits::kPadN;
static constexpr bool kPadK = Traits::kPadK; static constexpr bool kPadK = Traits::kPadK;
static constexpr bool isDoubleSmemBuffer = GemmTraits::isDoubleSmemBuffer; static constexpr bool isDoubleSmemBuffer = Traits::isDoubleSmemBuffer;
static constexpr auto Scheduler = GemmPipelineScheduler::Default; static constexpr auto Scheduler = GemmPipelineScheduler::Default;
...@@ -163,6 +163,8 @@ struct UniversalGemmPipelineProblem ...@@ -163,6 +163,8 @@ struct UniversalGemmPipelineProblem
static constexpr bool kPadN = Traits::kPadN; static constexpr bool kPadN = Traits::kPadN;
static constexpr bool kPadK = Traits::kPadK; static constexpr bool kPadK = Traits::kPadK;
static constexpr bool isDoubleSmemBuffer = Traits::isDoubleSmemBuffer;
static constexpr auto Scheduler = Scheduler_; static constexpr auto Scheduler = Scheduler_;
static constexpr auto HasHotLoop = HasHotLoop_; static constexpr auto HasHotLoop = HasHotLoop_;
static constexpr auto TailNum = TailNum_; static constexpr auto TailNum = TailNum_;
......
...@@ -22,8 +22,6 @@ struct TileGemmTraits ...@@ -22,8 +22,6 @@ struct TileGemmTraits
static constexpr bool isDoubleSmemBuffer = isDoubleSmemBuffer_; static constexpr bool isDoubleSmemBuffer = isDoubleSmemBuffer_;
static constexpr bool isDoubleSmemBuffer = isDoubleSmemBuffer_;
// TODO this can't be hardcoded here! Should be in policy! // TODO this can't be hardcoded here! Should be in policy!
static constexpr int _VectorSize = 16; static constexpr int _VectorSize = 16;
...@@ -37,6 +35,7 @@ struct TileGemmTraits ...@@ -37,6 +35,7 @@ struct TileGemmTraits
template <bool kPadM_, template <bool kPadM_,
bool kPadN_, bool kPadN_,
bool kPadK_, bool kPadK_,
bool isDoubleSmemBuffer_,
typename ALayout_, typename ALayout_,
typename BLayout_, typename BLayout_,
typename CLayout_, typename CLayout_,
...@@ -47,6 +46,8 @@ struct TileGemmUniversalTraits ...@@ -47,6 +46,8 @@ struct TileGemmUniversalTraits
static constexpr bool kPadN = kPadN_; static constexpr bool kPadN = kPadN_;
static constexpr bool kPadK = kPadK_; static constexpr bool kPadK = kPadK_;
static constexpr bool isDoubleSmemBuffer = isDoubleSmemBuffer_;
using ALayout = ALayout_; using ALayout = ALayout_;
using BLayout = BLayout_; using BLayout = BLayout_;
using CLayout = CLayout_; using CLayout = CLayout_;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment