Commit 41fc6a24 authored by Adam Osewski
Browse files

Few small changes & formatting.

parent 7ffb0921
...@@ -26,7 +26,7 @@ struct GemmKernel ...@@ -26,7 +26,7 @@ struct GemmKernel
using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>; using ADataType = remove_cvref_t<typename GemmPipeline::ADataType>;
using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>; using BDataType = remove_cvref_t<typename GemmPipeline::BDataType>;
// using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>; // using CAccDataType = remove_cvref_t<typename GemmPipeline::CDataType>;
using CDataType = remove_cvref_t<typename EpiloguePipeline::CDataType>; using CDataType = remove_cvref_t<typename EpiloguePipeline::ODataType>;
__host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch) __host__ static constexpr auto GridSize(index_t M, index_t N, index_t KBatch)
{ {
...@@ -118,8 +118,10 @@ struct GemmKernel ...@@ -118,8 +118,10 @@ struct GemmKernel
auto a_pad_view = pad_tensor_view( auto a_pad_view = pad_tensor_view(
a_tensor_view, a_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}), make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kK>{}),
sequence < false, // somehow clang-format is splitting below line into multiple.
GemmPipeline::kPadA ? true : false > {}); // clang-format off
sequence<false, GemmPipeline::kPadA ? true : false>{});
// clang-format on
auto a_block_window = make_tile_window( auto a_block_window = make_tile_window(
a_pad_view, a_pad_view,
...@@ -129,8 +131,9 @@ struct GemmKernel ...@@ -129,8 +131,9 @@ struct GemmKernel
auto b_pad_view = pad_tensor_view( auto b_pad_view = pad_tensor_view(
b_tensor_view, b_tensor_view,
make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}), make_tuple(number<TilePartitioner::kN>{}, number<TilePartitioner::kK>{}),
sequence < false, // clang-format off
GemmPipeline::kPadB ? true : false > {}); sequence<false, GemmPipeline::kPadB ? true : false>{});
// clang-format on
auto b_block_window = make_tile_window( auto b_block_window = make_tile_window(
b_pad_view, b_pad_view,
...@@ -171,8 +174,9 @@ struct GemmKernel ...@@ -171,8 +174,9 @@ struct GemmKernel
auto c_pad_view = pad_tensor_view( auto c_pad_view = pad_tensor_view(
c_tensor_view, c_tensor_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}), make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
sequence < false, // clang-format off
GemmPipeline::kPadC ? true : false > {}); sequence<false, GemmPipeline::kPadC ? true : false>{});
// clang-format on
auto CBlockWindow = make_tile_window( auto CBlockWindow = make_tile_window(
c_pad_view, c_pad_view,
make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}), make_tuple(number<TilePartitioner::kM>{}, number<TilePartitioner::kN>{}),
......
...@@ -51,7 +51,7 @@ template <typename ADataType_, ...@@ -51,7 +51,7 @@ template <typename ADataType_,
bool kPadB_ = false, bool kPadB_ = false,
bool kPadC_ = false, bool kPadC_ = false,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave, GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = false, bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full> TailNumber TailNum_ = TailNumber::Full>
struct UniversalGemmPipelineProblem struct UniversalGemmPipelineProblem
{ {
......
...@@ -27,6 +27,7 @@ struct BaseGemmPipelineAgBgCrMem ...@@ -27,6 +27,7 @@ struct BaseGemmPipelineAgBgCrMem
static constexpr index_t WgpPerCU = static constexpr index_t WgpPerCU =
(4 * get_warp_size() / BlockSize) >= 1 ? 4 * get_warp_size() / BlockSize : 1; (4 * get_warp_size() / BlockSize) >= 1 ? 4 * get_warp_size() / BlockSize : 1;
// TODO: Is this 32K value gfx9 arch specific?
static constexpr index_t FullMemBandPrefetchStages = integer_divide_ceil( static constexpr index_t FullMemBandPrefetchStages = integer_divide_ceil(
32768 / WgpPerCU, 32768 / WgpPerCU,
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
...@@ -206,6 +207,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -206,6 +207,7 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(), integer_divide_ceil(sizeof(ADataType) * a_lds_block_desc.get_element_space_size(),
16) * 16) *
16; 16;
// B tile in LDS // B tile in LDS
BDataType* p_b_lds = static_cast<BDataType*>( BDataType* p_b_lds = static_cast<BDataType*>(
static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned)); static_cast<void*>(static_cast<char*>(p_smem) + a_lds_block_space_size_aligned));
...@@ -251,20 +253,6 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -251,20 +253,6 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
// Block GEMM // Block GEMM
constexpr auto block_gemm = BlockGemm(); constexpr auto block_gemm = BlockGemm();
auto c_block_tile = block_gemm.MakeCBlockTile(); auto c_block_tile = block_gemm.MakeCBlockTile();
// -----------------------------------------------------------------------------------------
// Gemm pipeline start
if(blockIdx.x == 0 && blockIdx.y == 0 && threadIdx.x == 0)
{
printf("Pipeline >>> bid: %d, tid: %d: HasHotLoop: %s, TailNumber: %d,"
" PrefetchStages: %d, num_loop: %d\n",
blockIdx.x,
threadIdx.x,
(HasHotLoop ? "True" : "False"),
static_cast<index_t>(TailNum),
PrefetchStages,
num_loop);
}
using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution()); using ABlockTileDistr = decltype(a_copy_dram_window.get_tile_distribution());
using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution()); using BBlockTileDistr = decltype(b_copy_dram_window.get_tile_distribution());
...@@ -277,6 +265,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem> ...@@ -277,6 +265,9 @@ struct GemmPipelineAgBgCrMem : public BaseGemmPipelineAgBgCrMem<Problem>
tuple_array<ABlockTile, PrefetchStages> a_block_tiles; tuple_array<ABlockTile, PrefetchStages> a_block_tiles;
tuple_array<BBlockTile, PrefetchStages> b_block_tiles; tuple_array<BBlockTile, PrefetchStages> b_block_tiles;
// -----------------------------------------------------------------------------------------
// Gemm pipeline start
// prefetch // prefetch
// global read 0 // global read 0
GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window); GlobalPrefetch(a_block_tiles.get(I0{}), a_copy_dram_window);
......
Markdown is supported
0% uploaded. Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment