"composable_kernel/include/utility/Sequence.hpp" did not exist on "7a89684f92cc39afbb13ad970c1a3282e60b9180"
Commit 41fc6a24 authored by Adam Osewski's avatar Adam Osewski
Browse files

Few small changes & formatting.

parent 7ffb0921
......@@ -26,7 +26,7 @@ struct GemmKernel
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
// using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::CDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
__host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
{
......@@ -118,8 +118,10 @@ struct GemmKernel
auto a_pad_view = pad_tensor_view(
a_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
sequence < false,
GemmPipeline::kPadA ? true : false > {});
// somehow clang-format is splitting below line into multiple.
// clang-format off
sequence<false, GemmPipeline::kPadA ? true : false>{});
// clang-format on
auto a_block_window = make_tile_window(
a_pad_view,
......@@ -129,8 +131,9 @@ struct GemmKernel
auto b_pad_view = pad_tensor_view(
b_tensor_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
sequence < false,
GemmPipeline::kPadB ? true : false > {});
// clang-format off
sequence<false, GemmPipeline::kPadB ? true : false>{});
// clang-format on
auto b_block_window = make_tile_window(
b_pad_view,
......@@ -171,8 +174,9 @@ struct GemmKernel
auto c_pad_view = pad_tensor_view(
c_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
sequence < false,
GemmPipeline::kPadC ? true : false > {});
// clang-format off
sequence<false, GemmPipeline::kPadC ? true : false>{});
// clang-format on
auto CBlockWindow = make_tile_window(
c_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
......
......@@ -51,7 +51,7 @@ template <typename ADataType_,
bool kPadB_ = false,
bool kPadC_ = false,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = false,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full>
struct UniversalGemmPipelineProblem
{
......
......@@ -27,6 +27,7 @@ struct BaseGemmPipelineAgBgCrMem
static constexpr index_t WgpPerCU =
(4 * get_warp_size() / BlockSize) >= 1 ? 4 * get_warp_size() / BlockSize : 1;
// TODO: Is this 32K value gfx9 arch specific?
static constexpr index_t FullMemBandPrefetchStages = integer_divide_ceil(
32768 / WgpPerCU,
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
......@@ -206,6 +207,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(),
16) *
16;
// B tile in LDS
BDataType* p_b_lds = static_cast<BDataType*>(
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
......@@ -251,20 +253,6 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
// Block GEMM
constexpr auto block_gemm = BlockGemm();
auto c_block_tile = block_gemm.MakeCBlockTile();
// -----------------------------------------------------------------------------------------
// Gemm pipeline start
if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0)
{
printf("Pipeline >>> bid: %d, tid: %d: HasHotLoop: %s, TailNumber: %d,"
" PrefetchStages: %d, num_loop: %d\n",
blockIdx.x,
threadIdx.x,
(HasHotLoop ? "True" : "False"),
static_cast<index_t>(TailNum),
PrefetchStages,
num_loop);
}
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
......@@ -277,6 +265,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
tuple_array<ABlockTile, PrefetchStages> a_block_tiles;
tuple_array<BBlockTile, PrefetchStages> b_block_tiles;
// -----------------------------------------------------------------------------------------
// Gemm pipeline start
// prefetch
// global read 0
GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window);
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment