Commit 1b2bf88d authored by Adam Osewski
Browse files

Refactoring and review comments.

parent ba676917
# set(CMAKE_BUILD_TYPE Debug)
add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp) add_executable(tile_example_gemm_basic EXCLUDE_FROM_ALL gemm_basic.cpp)
add_executable(tile_example_gemm_basic_mem_pipeline EXCLUDE_FROM_ALL gemm_basic_mem_pipeline.cpp) add_executable(tile_example_gemm_basic_mem_pipeline EXCLUDE_FROM_ALL gemm_basic_mem_pipeline.cpp)
...@@ -24,19 +24,19 @@ struct BlockGemmASmemBSmemCRegV1 ...@@ -24,19 +24,19 @@ struct BlockGemmASmemBSmemCRegV1
static constexpr index_t kBlockSize = Problem::kBlockSize; static constexpr index_t kBlockSize = Problem::kBlockSize;
// C += A * B // C += A * B
template <typename CBlockTensor, typename ABlockWindowTmp, typename BBlockWindowTmp> template <typename CBlockTensor, typename ABlockWindow, typename BBlockWindow>
CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor, CK_TILE_DEVICE void operator()(CBlockTensor& c_block_tensor,
const ABlockWindowTmp& a_block_window_tmp, const ABlockWindow& a_block_window,
const BBlockWindowTmp& b_block_window_tmp) const const BBlockWindow& b_block_window) const
{ {
static_assert(std::is_same_v<ADataType, typename ABlockWindowTmp::DataType> && static_assert(std::is_same_v<ADataType, typename ABlockWindow::DataType> &&
std::is_same_v<BDataType, typename BBlockWindowTmp::DataType> && std::is_same_v<BDataType, typename BBlockWindow::DataType> &&
std::is_same_v<AccDataType, typename CBlockTensor::DataType>, std::is_same_v<AccDataType, typename CBlockTensor::DataType>,
"wrong!"); "wrong!");
constexpr index_t MPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<0>{}]; constexpr index_t MPerBlock = ABlockWindow{}.get_window_lengths()[number<0>{}];
constexpr index_t NPerBlock = BBlockWindowTmp{}.get_window_lengths()[number<0>{}]; constexpr index_t NPerBlock = BBlockWindow{}.get_window_lengths()[number<0>{}];
constexpr index_t KPerBlock = ABlockWindowTmp{}.get_window_lengths()[number<1>{}]; constexpr index_t KPerBlock = ABlockWindow{}.get_window_lengths()[number<1>{}];
static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN && static_assert(MPerBlock == BlockGemmShape::kM && NPerBlock == BlockGemmShape::kN &&
KPerBlock == BlockGemmShape::kK, KPerBlock == BlockGemmShape::kK,
...@@ -62,9 +62,9 @@ struct BlockGemmASmemBSmemCRegV1 ...@@ -62,9 +62,9 @@ struct BlockGemmASmemBSmemCRegV1
// construct A-warp-window // construct A-warp-window
auto a_warp_window_tmp = make_tile_window( auto a_warp_window_tmp = make_tile_window(
a_block_window_tmp.get_bottom_tensor_view(), a_block_window.get_bottom_tensor_view(),
make_tuple(number<WG::kM>{}, number<WG::kK>{}), make_tuple(number<WG::kM>{}, number<WG::kK>{}),
a_block_window_tmp.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0}, a_block_window.get_window_origin() + multi_index<2>{iMWarp * WG::kM, 0},
make_static_tile_distribution(typename WG::AWarpDstrEncoding{})); make_static_tile_distribution(typename WG::AWarpDstrEncoding{}));
#if 0 // FIXME: using array will cause register spill #if 0 // FIXME: using array will cause register spill
...@@ -97,9 +97,9 @@ struct BlockGemmASmemBSmemCRegV1 ...@@ -97,9 +97,9 @@ struct BlockGemmASmemBSmemCRegV1
// construct B-warp-window // construct B-warp-window
auto b_warp_window_tmp = make_tile_window( auto b_warp_window_tmp = make_tile_window(
b_block_window_tmp.get_bottom_tensor_view(), b_block_window.get_bottom_tensor_view(),
make_tuple(number<WG::kN>{}, number<WG::kK>{}), make_tuple(number<WG::kN>{}, number<WG::kK>{}),
b_block_window_tmp.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0}, b_block_window.get_window_origin() + multi_index<2>{iNWarp * WG::kN, 0},
make_static_tile_distribution(typename WG::BWarpDstrEncoding{})); make_static_tile_distribution(typename WG::BWarpDstrEncoding{}));
#if 0 // FIXME: using array will cause register spill #if 0 // FIXME: using array will cause register spill
...@@ -200,12 +200,12 @@ struct BlockGemmASmemBSmemCRegV1 ...@@ -200,12 +200,12 @@ struct BlockGemmASmemBSmemCRegV1
} }
// C = A * B // C = A * B
template <typename ABlockTensorTmp, typename BBlockWindowTmp> template <typename ABlockTensorTmp, typename BBlockWindow>
CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp, CK_TILE_DEVICE auto operator()(const ABlockTensorTmp& a_block_tensor_tmp,
const BBlockWindowTmp& b_block_window_tmp) const const BBlockWindow& b_block_window) const
{ {
auto c_block_tensor = MakeCBlockTile(); auto c_block_tensor = MakeCBlockTile();
operator()(c_block_tensor, a_block_tensor_tmp, b_block_window_tmp); operator()(c_block_tensor, a_block_tensor_tmp, b_block_window);
return c_block_tensor; return c_block_tensor;
} }
}; };
......
...@@ -12,7 +12,6 @@ namespace ck_tile { ...@@ -12,7 +12,6 @@ namespace ck_tile {
// A Tile Window: global memory // A Tile Window: global memory
// B Tile Window: global memory // B Tile Window: global memory
// C Distributed tensor: register // C Distributed tensor: register
template <typename Problem> template <typename Problem>
struct BaseGemmPipelineAgBgCrMem struct BaseGemmPipelineAgBgCrMem
{ {
...@@ -25,11 +24,13 @@ struct BaseGemmPipelineAgBgCrMem ...@@ -25,11 +24,13 @@ struct BaseGemmPipelineAgBgCrMem
static constexpr index_t NPerBlock = BlockGemmShape::kN; static constexpr index_t NPerBlock = BlockGemmShape::kN;
static constexpr index_t KPerBlock = BlockGemmShape::kK; static constexpr index_t KPerBlock = BlockGemmShape::kK;
static constexpr index_t MinMemInFlyBytes = 32768;
static constexpr index_t WgpPerCU = static constexpr index_t WgpPerCU =
(4 * get_warp_size() / BlockSize) >= 1 ? 4 * get_warp_size() / BlockSize : 1; (4 * get_warp_size() / BlockSize) >= 1 ? 4 * get_warp_size() / BlockSize : 1;
// TODO: Is this 32K value gfx9 arch specific? // TODO: Is this 32K value gfx9 arch specific?
static constexpr index_t FullMemBandPrefetchStages = integer_divide_ceil( static constexpr index_t FullMemBandPrefetchStages = integer_divide_ceil(
32768 / WgpPerCU, MinMemInFlyBytes / WgpPerCU,
(MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock); (MPerBlock * sizeof(ADataType) + NPerBlock * sizeof(BDataType)) * KPerBlock);
static constexpr index_t PrefetchStages = static constexpr index_t PrefetchStages =
FullMemBandPrefetchStages >= 2 FullMemBandPrefetchStages >= 2
......
add_subdirectory(image_to_column) add_subdirectory(image_to_column)
add_subdirectory(gemm)
...@@ -17,11 +17,11 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor; ...@@ -17,11 +17,11 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// clang-format off // clang-format off
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
std::tuple< Row, Col, Row, F16, F16, F32, F32> std::tuple< Row, Col, Row, F16, F16, F32, F16>
// TODO: fixme! // TODO: fixme!
// std::tuple< Col, Row, Row, F16, F16, F32, F32>, // std::tuple< Col, Row, Row, F16, F16, F32, F16>,
// std::tuple< Row, Row, Row, F16, F16, F32, F32>, // std::tuple< Row, Row, Row, F16, F16, F32, F16>,
// std::tuple< Col, Col, Row, F16, F16, F32, F32> // std::tuple< Col, Col, Row, F16, F16, F32, F16>
>; >;
// clang-format on // clang-format on
......
...@@ -14,10 +14,9 @@ template <typename Tuple> ...@@ -14,10 +14,9 @@ template <typename Tuple>
class TestCkTileGemmMemPipeline : public ::testing::Test class TestCkTileGemmMemPipeline : public ::testing::Test
{ {
protected: protected:
using ALayout = std::tuple_element_t<0, Tuple>; using ALayout = std::tuple_element_t<0, Tuple>;
using BLayout = std::tuple_element_t<1, Tuple>; using BLayout = std::tuple_element_t<1, Tuple>;
using CLayout = std::tuple_element_t<2, Tuple>; using CLayout = std::tuple_element_t<2, Tuple>;
;
using ADataType = std::tuple_element_t<3, Tuple>; using ADataType = std::tuple_element_t<3, Tuple>;
using BDataType = std::tuple_element_t<4, Tuple>; using BDataType = std::tuple_element_t<4, Tuple>;
using AccDataType = std::tuple_element_t<5, Tuple>; using AccDataType = std::tuple_element_t<5, Tuple>;
...@@ -90,6 +89,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test ...@@ -90,6 +89,7 @@ class TestCkTileGemmMemPipeline : public ::testing::Test
using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem< using GemmPipeline = ck_tile::GemmPipelineAgBgCrMem<
ck_tile::UniversalGemmPipelineProblem<ADataType, ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType, BDataType,
AccDataType,
CDataType, CDataType,
GemmShape, GemmShape,
ALayout, ALayout,
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment