Commit b2c7d774 authored by ThomasNing
Browse files

Add the changes from include/ck_tile

parent d1e71770
...@@ -66,7 +66,6 @@ else() ...@@ -66,7 +66,6 @@ else()
-Wunreachable-code -Wunreachable-code
-Wunused -Wunused
-Wno-reserved-identifier -Wno-reserved-identifier
-v --save-temps -Wno-gnu-line-marker
-Werror -Werror
-Wno-option-ignored -Wno-option-ignored
-Wsign-compare -Wsign-compare
......
...@@ -20,6 +20,7 @@ ...@@ -20,6 +20,7 @@
#include "ck_tile/host/reference/reference_batched_masking.hpp" #include "ck_tile/host/reference/reference_batched_masking.hpp"
#include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp" #include "ck_tile/host/reference/reference_batched_rotary_position_embedding.hpp"
#include "ck_tile/host/reference/reference_batched_softmax.hpp" #include "ck_tile/host/reference/reference_batched_softmax.hpp"
#include "ck_tile/host/reference/reference_batched_transpose.hpp"
#include "ck_tile/host/reference/reference_elementwise.hpp" #include "ck_tile/host/reference/reference_elementwise.hpp"
#include "ck_tile/host/reference/reference_fused_moe.hpp" #include "ck_tile/host/reference/reference_fused_moe.hpp"
#include "ck_tile/host/reference/reference_gemm.hpp" #include "ck_tile/host/reference/reference_gemm.hpp"
...@@ -34,4 +35,3 @@ ...@@ -34,4 +35,3 @@
#include "ck_tile/host/reference/reference_topk.hpp" #include "ck_tile/host/reference/reference_topk.hpp"
#include "ck_tile/host/stream_config.hpp" #include "ck_tile/host/stream_config.hpp"
#include "ck_tile/host/timer.hpp" #include "ck_tile/host/timer.hpp"
#include "ck_tile/host/reference/reference_batched_transpose.hpp"
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
......
...@@ -32,11 +32,11 @@ ...@@ -32,11 +32,11 @@
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_comp_v4.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_mem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_compute_v4_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_compute_v4_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
......
...@@ -26,12 +26,14 @@ struct BlockGemmARegBRegCRegV1 ...@@ -26,12 +26,14 @@ struct BlockGemmARegBRegCRegV1
static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK; static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>(); static constexpr auto config = Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WG = remove_cvref_t<decltype(config.template at<0>())>; using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
static constexpr index_t MWarp = config.template at<1>(); static constexpr index_t MWarp = config.template at<1>();
static constexpr index_t NWarp = config.template at<2>(); static constexpr index_t NWarp = config.template at<2>();
static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WG::kM); static constexpr index_t MIterPerWarp = MPerBlock / (MWarp * WarpGemm::kM);
static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WG::kN); static constexpr index_t NIterPerWarp = NPerBlock / (NWarp * WarpGemm::kN);
static constexpr index_t KIterPerWarp = KPerBlock / WG::kK; static constexpr index_t KIterPerWarp = KPerBlock / WarpGemm::kK;
static constexpr index_t KPack = WarpGemm::kKPerThread;
CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode() CK_TILE_DEVICE static constexpr auto MakeABlockDistributionEncode()
{ {
...@@ -43,7 +45,7 @@ struct BlockGemmARegBRegCRegV1 ...@@ -43,7 +45,7 @@ struct BlockGemmARegBRegCRegV1
sequence<1, 2>, sequence<1, 2>,
sequence<0, 0>>{}; sequence<0, 0>>{};
constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding( constexpr auto a_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
a_block_outer_dstr_encoding, typename WG::AWarpDstrEncoding{}); a_block_outer_dstr_encoding, typename WarpGemm::AWarpDstrEncoding{});
return a_block_dstr_encode; return a_block_dstr_encode;
} }
...@@ -58,7 +60,7 @@ struct BlockGemmARegBRegCRegV1 ...@@ -58,7 +60,7 @@ struct BlockGemmARegBRegCRegV1
sequence<1, 2>, sequence<1, 2>,
sequence<0, 0>>{}; sequence<0, 0>>{};
constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding( constexpr auto b_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
b_block_outer_dstr_encoding, typename WG::BWarpDstrEncoding{}); b_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
return b_block_dstr_encode; return b_block_dstr_encode;
} }
...@@ -73,7 +75,7 @@ struct BlockGemmARegBRegCRegV1 ...@@ -73,7 +75,7 @@ struct BlockGemmARegBRegCRegV1
sequence<1, 2>, sequence<1, 2>,
sequence<0, 0>>{}; sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
return c_block_dstr_encode; return c_block_dstr_encode;
} }
...@@ -112,13 +114,13 @@ struct BlockGemmARegBRegCRegV1 ...@@ -112,13 +114,13 @@ struct BlockGemmARegBRegCRegV1
.get_static_tile_distribution_encoding())>>, .get_static_tile_distribution_encoding())>>,
"C distribution is wrong!"); "C distribution is wrong!");
using AWarpDstr = typename WG::AWarpDstr; using AWarpDstr = typename WarpGemm::AWarpDstr;
using BWarpDstr = typename WG::BWarpDstr; using BWarpDstr = typename WarpGemm::BWarpDstr;
using CWarpDstr = typename WG::CWarpDstr; using CWarpDstr = typename WarpGemm::CWarpDstr;
using AWarpTensor = typename WG::AWarpTensor; using AWarpTensor = typename WarpGemm::AWarpTensor;
using BWarpTensor = typename WG::BWarpTensor; using BWarpTensor = typename WarpGemm::BWarpTensor;
using CWarpTensor = typename WG::CWarpTensor; using CWarpTensor = typename WarpGemm::CWarpTensor;
constexpr auto a_warp_y_lengths = constexpr auto a_warp_y_lengths =
to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths()); to_sequence(AWarpDstr{}.get_ys_to_d_descriptor().get_lengths());
...@@ -157,7 +159,7 @@ struct BlockGemmARegBRegCRegV1 ...@@ -157,7 +159,7 @@ struct BlockGemmARegBRegCRegV1
merge_sequences(sequence<1, 1>{}, c_warp_y_lengths)); merge_sequences(sequence<1, 1>{}, c_warp_y_lengths));
// warp GEMM // warp GEMM
WG{}(c_warp_tensor, a_warp_tensor, b_warp_tensor); WarpGemm{}(c_warp_tensor, a_warp_tensor, b_warp_tensor);
// write C warp tensor into C block tensor // write C warp tensor into C block tensor
c_block_tensor.set_y_sliced_thread_data( c_block_tensor.set_y_sliced_thread_data(
...@@ -180,7 +182,7 @@ struct BlockGemmARegBRegCRegV1 ...@@ -180,7 +182,7 @@ struct BlockGemmARegBRegCRegV1
sequence<0, 0>>{}; sequence<0, 0>>{};
constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding( constexpr auto c_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
c_block_outer_dstr_encoding, typename WG::CWarpDstrEncoding{}); c_block_outer_dstr_encoding, typename WarpGemm::CWarpDstrEncoding{});
constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode); constexpr auto c_block_dstr = make_static_tile_distribution(c_block_dstr_encode);
auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr); auto c_block_tensor = make_static_distributed_tensor<CDataType>(c_block_dstr);
return c_block_tensor; return c_block_tensor;
......
...@@ -54,8 +54,8 @@ struct GemmPipelineAgBgCrImplBase ...@@ -54,8 +54,8 @@ struct GemmPipelineAgBgCrImplBase
auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc); auto a_lds_block = make_tensor_view<address_space_enum::lds>(p_a_lds, a_lds_block_desc);
// TODO: LDS alignment should come from Policy! // TODO: LDS alignment should come from Policy!
constexpr index_t a_lds_block_space_size_aligned = constexpr index_t a_lds_block_space_size_aligned = integer_least_multiple(
integer_least_multiple(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16); sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), 16);
// B tile in LDS // B tile in LDS
BDataType* __restrict__ p_b_lds = static_cast<BDataType*>( BDataType* __restrict__ p_b_lds = static_cast<BDataType*>(
......
...@@ -72,8 +72,6 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -72,8 +72,6 @@ struct GemmPipelineAgBgCrCompV3 : public BaseGemmPipelineAgBgCrCompV3<Problem>
static constexpr bool isDoubleSmemBuffer = Problem::isDoubleSmemBuffer; static constexpr bool isDoubleSmemBuffer = Problem::isDoubleSmemBuffer;
static constexpr bool isDoubleSmemBuffer = Problem::isDoubleSmemBuffer;
static constexpr bool HasHotLoop = Problem::HasHotLoop; static constexpr bool HasHotLoop = Problem::HasHotLoop;
static constexpr auto TailNum = Problem::TailNum; static constexpr auto TailNum = Problem::TailNum;
static constexpr auto Scheduler = Problem::Scheduler; static constexpr auto Scheduler = Problem::Scheduler;
......
...@@ -9,8 +9,30 @@ ...@@ -9,8 +9,30 @@
namespace ck_tile { namespace ck_tile {
// Base struct holding the compile-time constants and host-side loop queries
// for the CompV4 GEMM pipeline. The template parameter Problem is not
// referenced by any member of this struct.
// A Tile Window: global memory
// B Tile Window: global memory
// C Distributed tensor: register
template <typename Problem>
struct BaseGemmPipelineAgBgCrCompV4
{
// Pipeline stage/buffer counts. NOTE(review): how each constant maps onto the
// pipeline schedule is defined by the derived pipeline, not by this struct.
static constexpr index_t PrefetchStages = 3;
static constexpr index_t PrefillStages = 1;
static constexpr index_t GlobalBufferNum = 1;
// Host-side query: returns true only when num_loop is strictly greater than
// PrefetchStages (i.e. num_loop >= 4 with the value above).
CK_TILE_HOST static constexpr bool BlockHasHotloop(index_t num_loop)
{
return num_loop > PrefetchStages;
}
// Host-side query for the tail variant. The result does not depend on
// num_loop: the function always returns TailNumber::Two.
CK_TILE_HOST static constexpr TailNumber GetBlockLoopTailNum(index_t num_loop)
{
// num_loop is assigned to `ignore` and is otherwise unused.
ignore = num_loop;
return TailNumber::Two;
}
};
template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy> template <typename Problem, typename Policy = GemmPipelineAGmemBGmemCregComputeV4DefaultPolicy>
struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem> struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV4<Problem>
{ {
using Base = BaseGemmPipelineAgBgCrCompV3<Problem>; using Base = BaseGemmPipelineAgBgCrCompV3<Problem>;
using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>; using PipelineImplBase = GemmPipelineAgBgCrImplBase<Problem, Policy>;
...@@ -35,9 +57,9 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -35,9 +57,9 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK; static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t VectorSizeA = Problem::VectorSizeA; static constexpr index_t VectorSizeA = Policy::template GetVectorSizeA<Problem>();
static constexpr index_t VectorSizeB = Problem::VectorSizeB; static constexpr index_t VectorSizeB = Policy::template GetVectorSizeB<Problem>();
static constexpr index_t VectorSizeC = Problem::VectorSizeC; static constexpr index_t VectorSizeC = Policy::template GetVectorSizeC<Problem>();
static constexpr bool kPadM = Problem::kPadM; static constexpr bool kPadM = Problem::kPadM;
static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadN = Problem::kPadN;
...@@ -54,7 +76,10 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -54,7 +76,10 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
return Policy::template GetSmemSize<Problem>(); return Policy::template GetSmemSize<Problem>();
} }
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return Policy::IsTransposeC(); } CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
{
return Policy::template IsTransposeC<Problem>();
}
template <GemmPipelineScheduler Scheduler> template <GemmPipelineScheduler Scheduler>
struct PipelineImpl : public PipelineImplBase struct PipelineImpl : public PipelineImplBase
...@@ -147,11 +172,22 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -147,11 +172,22 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
remove_cvref_t<typename BDramBlockWindowTmp::DataType>>, remove_cvref_t<typename BDramBlockWindowTmp::DataType>>,
"wrong!"); "wrong!");
static_assert(MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && constexpr bool is_a_col_major =
NPerBlock == std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>;
BDramBlockWindowTmp{}.get_window_lengths()[number<0>{}] && constexpr bool is_b_row_major = std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>;
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[number<1>{}],
"wrong!"); static_assert(is_a_col_major
? (KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (MPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == ADramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"A block window has incorrect lengths for defined ALayout!");
static_assert(is_b_row_major
? (KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}])
: (NPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I0{}] &&
KPerBlock == BDramBlockWindowTmp{}.get_window_lengths()[I1{}]),
"B block window has incorrect lengths for defined BLayout!");
////////////// global window & register ///////////////// ////////////// global window & register /////////////////
// A DRAM tile window for load // A DRAM tile window for load
...@@ -176,37 +212,33 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -176,37 +212,33 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
ABlockTile a_global_load_tile; ABlockTile a_global_load_tile;
BBlockTile b_global_load_tile; BBlockTile b_global_load_tile;
using ADramTileWindowStep = typename ADramBlockWindowTmp::BottomTensorIndex;
using BDramTileWindowStep = typename BDramBlockWindowTmp::BottomTensorIndex;
constexpr ADramTileWindowStep a_dram_tile_window_step =
is_a_col_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
constexpr BDramTileWindowStep b_dram_tile_window_step =
is_b_row_major ? make_array(KPerBlock, 0) : make_array(0, KPerBlock);
// global prefetch 0 // global prefetch 0
// global read 0 // global read 0
Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window); Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window); Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
////////////// LDS desc, window & register ///////////////// ////////////// LDS desc, window & register /////////////////
auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0); auto&& [a_lds_block0, b_lds_block0] = Base::GetABLdsTensorViews(p_smem_0);
auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1); auto&& [a_lds_block1, b_lds_block1] = Base::GetABLdsTensorViews(p_smem_1);
auto a_copy_lds_window0 = auto a_copy_lds_window0 = make_tile_window(
make_tile_window(a_lds_block0, a_lds_block0, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
{0, 0},
ABlockTileDistr);
auto a_copy_lds_window1 = auto a_copy_lds_window1 = make_tile_window(
make_tile_window(a_lds_block1, a_lds_block1, make_tuple(number<MPerBlock>{}, number<KPerBlock>{}), {0, 0});
make_tuple(number<MPerBlock>{}, number<KPerBlock>{}),
{0, 0},
ABlockTileDistr);
auto b_copy_lds_window0 = auto b_copy_lds_window0 = make_tile_window(
make_tile_window(b_lds_block0, b_lds_block0, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{0, 0},
BBlockTileDistr);
auto b_copy_lds_window1 = auto b_copy_lds_window1 = make_tile_window(
make_tile_window(b_lds_block1, b_lds_block1, make_tuple(number<NPerBlock>{}, number<KPerBlock>{}), {0, 0});
make_tuple(number<NPerBlock>{}, number<KPerBlock>{}),
{0, 0},
BBlockTileDistr);
// Block GEMM // Block GEMM
auto block_gemm = BlockGemm(); auto block_gemm = BlockGemm();
...@@ -216,11 +248,32 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -216,11 +248,32 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile); tile_elementwise_inout([](auto& c) { c = 0; }, c_block_tile);
// LDS write 0 // LDS write 0
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func); Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window0, b_global_load_tile, b_element_func); Base::LocalPrefill(b_copy_lds_window0, b_global_load_tile, b_element_func);
}
// global read 1 // global read 1
Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window); Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window); Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
block_sync_lds(); block_sync_lds();
...@@ -262,11 +315,31 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -262,11 +315,31 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0); Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0);
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0); Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0);
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window1, a_global_load_tile, a_element_func); Base::LocalPrefill(a_copy_lds_window1, a_global_load_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window1, b_global_load_tile, b_element_func); Base::LocalPrefill(b_copy_lds_window1, b_global_load_tile, b_element_func);
}
Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window); Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window); Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
if(HasHotLoop) if(HasHotLoop)
{ {
...@@ -280,11 +353,35 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -280,11 +353,35 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1); Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1);
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1); Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1);
Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func); if constexpr(is_a_col_major)
Base::LocalPrefill(b_copy_lds_window0, b_global_load_tile, b_element_func); {
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(
a_copy_lds_window0, a_global_load_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(
b_copy_lds_window0, b_global_load_tile, b_element_func);
}
Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window); Base::GlobalPrefetch(
Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window); a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(
b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
// gemm // gemm
block_gemm(c_block_tile, a_block_tile0, b_block_tile0); block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
HotLoopScheduler(); HotLoopScheduler();
...@@ -296,11 +393,35 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -296,11 +393,35 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0); Base::LocalPrefetch(a_block_tile0, a_lds_ld_window0);
Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0); Base::LocalPrefetch(b_block_tile0, b_lds_ld_window0);
Base::LocalPrefill(a_copy_lds_window1, a_global_load_tile, a_element_func); if constexpr(is_a_col_major)
Base::LocalPrefill(b_copy_lds_window1, b_global_load_tile, b_element_func); {
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
Base::LocalPrefill(a_copy_lds_window1, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(
a_copy_lds_window1, a_global_load_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
Base::LocalPrefill(b_copy_lds_window1, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(
b_copy_lds_window1, b_global_load_tile, b_element_func);
}
Base::GlobalPrefetch(a_global_load_tile, a_copy_dram_window); Base::GlobalPrefetch(
Base::GlobalPrefetch(b_global_load_tile, b_copy_dram_window); a_global_load_tile, a_copy_dram_window, a_dram_tile_window_step);
Base::GlobalPrefetch(
b_global_load_tile, b_copy_dram_window, b_dram_tile_window_step);
// gemm // gemm
block_gemm(c_block_tile, a_block_tile1, b_block_tile1); block_gemm(c_block_tile, a_block_tile1, b_block_tile1);
HotLoopScheduler(); HotLoopScheduler();
...@@ -318,8 +439,28 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem> ...@@ -318,8 +439,28 @@ struct GemmPipelineAgBgCrCompV4 : public BaseGemmPipelineAgBgCrCompV3<Problem>
block_sync_lds(); block_sync_lds();
Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1); Base::LocalPrefetch(a_block_tile1, a_lds_ld_window1);
Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1); Base::LocalPrefetch(b_block_tile1, b_lds_ld_window1);
if constexpr(is_a_col_major)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegTileDistribution<Problem>());
transpose_tile2d(a_shuffle_tmp, a_global_load_tile);
Base::LocalPrefill(a_copy_lds_window0, a_shuffle_tmp, a_element_func);
}
else
{
Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func); Base::LocalPrefill(a_copy_lds_window0, a_global_load_tile, a_element_func);
}
if constexpr(is_b_row_major)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegTileDistribution<Problem>());
transpose_tile2d(b_shuffle_tmp, b_global_load_tile);
Base::LocalPrefill(b_copy_lds_window0, b_shuffle_tmp, b_element_func);
}
else
{
Base::LocalPrefill(b_copy_lds_window0, b_global_load_tile, b_element_func); Base::LocalPrefill(b_copy_lds_window0, b_global_load_tile, b_element_func);
}
block_gemm(c_block_tile, a_block_tile0, b_block_tile0); block_gemm(c_block_tile, a_block_tile0, b_block_tile0);
} }
// 2 // 2
......
...@@ -33,7 +33,7 @@ struct GemmPipelineProblemBase ...@@ -33,7 +33,7 @@ struct GemmPipelineProblemBase
static constexpr bool kPadN = Traits::kPadN; static constexpr bool kPadN = Traits::kPadN;
static constexpr bool kPadK = Traits::kPadK; static constexpr bool kPadK = Traits::kPadK;
static constexpr bool isDoubleSmemBuffer = GemmTraits::isDoubleSmemBuffer; static constexpr bool isDoubleSmemBuffer = Traits::isDoubleSmemBuffer;
static constexpr auto Scheduler = GemmPipelineScheduler::Default; static constexpr auto Scheduler = GemmPipelineScheduler::Default;
...@@ -163,6 +163,8 @@ struct UniversalGemmPipelineProblem ...@@ -163,6 +163,8 @@ struct UniversalGemmPipelineProblem
static constexpr bool kPadN = Traits::kPadN; static constexpr bool kPadN = Traits::kPadN;
static constexpr bool kPadK = Traits::kPadK; static constexpr bool kPadK = Traits::kPadK;
static constexpr bool isDoubleSmemBuffer = Traits::isDoubleSmemBuffer;
static constexpr auto Scheduler = Scheduler_; static constexpr auto Scheduler = Scheduler_;
static constexpr auto HasHotLoop = HasHotLoop_; static constexpr auto HasHotLoop = HasHotLoop_;
static constexpr auto TailNum = TailNum_; static constexpr auto TailNum = TailNum_;
......
...@@ -22,8 +22,6 @@ struct TileGemmTraits ...@@ -22,8 +22,6 @@ struct TileGemmTraits
static constexpr bool isDoubleSmemBuffer = isDoubleSmemBuffer_; static constexpr bool isDoubleSmemBuffer = isDoubleSmemBuffer_;
static constexpr bool isDoubleSmemBuffer = isDoubleSmemBuffer_;
// TODO this can't be hardcoded here! Should be in policy! // TODO this can't be hardcoded here! Should be in policy!
static constexpr int _VectorSize = 16; static constexpr int _VectorSize = 16;
...@@ -37,6 +35,7 @@ struct TileGemmTraits ...@@ -37,6 +35,7 @@ struct TileGemmTraits
template <bool kPadM_, template <bool kPadM_,
bool kPadN_, bool kPadN_,
bool kPadK_, bool kPadK_,
bool isDoubleSmemBuffer_,
typename ALayout_, typename ALayout_,
typename BLayout_, typename BLayout_,
typename CLayout_, typename CLayout_,
...@@ -47,6 +46,8 @@ struct TileGemmUniversalTraits ...@@ -47,6 +46,8 @@ struct TileGemmUniversalTraits
static constexpr bool kPadN = kPadN_; static constexpr bool kPadN = kPadN_;
static constexpr bool kPadK = kPadK_; static constexpr bool kPadK = kPadK_;
static constexpr bool isDoubleSmemBuffer = isDoubleSmemBuffer_;
using ALayout = ALayout_; using ALayout = ALayout_;
using BLayout = BLayout_; using BLayout = BLayout_;
using CLayout = CLayout_; using CLayout = CLayout_;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment