Unverified Commit 39dc25a9 authored by Adam Osewski's avatar Adam Osewski Committed by GitHub
Browse files

[CK-Tile] Enable vectorized reads on all layouts & improve perf. (#1835)



* Refactor universal gemm policy.

* Adapt example to refactor changes.

* Introduce static encoding pattern

* Adding shuffled encoding patterns.

* Fix err in reverse tuple.

* Add transpose_tile2d

* Small refactoring + doc

* Enable reading on contiguous dimension in all layouts.

* Transpose A/B register tile if needed for comp v3 pipeline.

* Take contiguous dim size when calculating dram vector load size.

* A/B smem pack size taken from WarpGemm attributes

* Update B LDS layout and setup tile distribution pattern at class level.

* Fix static assert.

* Fix errors in examples.

* Formatting & fix IsTranspose

* Fix VectorSize & refactor.

* Add error loging messages.

* Fix VecLoadSize and TranspseC for mem pipeline.

* Update unit-tests & disable mem pipeline.

* Clang format

* Update include/ck_tile/core/tensor/tile_window.hpp
Co-authored-by: default avatarjakpiase <jakub.piasecki@amd.com>

* Fix compilation and reviewers comments.

* Refactor unit-test. Fallback to non-universal gemm.

Need to use GemmPipelineAGmemBGmemCRegV1 for now,
since GemmKernel is now supporting also non-K major vector reads.

---------
Co-authored-by: default avatarjakpiase <jakub.piasecki@amd.com>
parent 64d5c4d6
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -39,17 +39,6 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -39,17 +39,6 @@ struct GemmPipelineAGmemBGmemCRegV1
static constexpr bool kPadN = Problem::kPadN; static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK; static constexpr bool kPadK = Problem::kPadK;
CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
{
return integer_divide_ceil(
sizeof(ADataType) *
Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
16) *
16 +
sizeof(BDataType) *
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize() CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{ {
return Policy::template GetSmemSize<Problem>(); return Policy::template GetSmemSize<Problem>();
...@@ -150,7 +139,7 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -150,7 +139,7 @@ struct GemmPipelineAGmemBGmemCRegV1
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>) if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
{ {
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>( auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegBlockDescriptor<Problem>()); Policy::template MakeShuffledARegBlockDistribution<Problem>());
shuffle_tile(a_shuffle_tmp, a_block_tile); shuffle_tile(a_shuffle_tmp, a_block_tile);
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp); const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
store_tile(a_copy_lds_window, a_block_tile_tmp); store_tile(a_copy_lds_window, a_block_tile_tmp);
...@@ -164,7 +153,7 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -164,7 +153,7 @@ struct GemmPipelineAGmemBGmemCRegV1
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{ {
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>( auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegBlockDescriptor<Problem>()); Policy::template MakeShuffledBRegBlockDistribution<Problem>());
shuffle_tile(b_shuffle_tmp, b_block_tile); shuffle_tile(b_shuffle_tmp, b_block_tile);
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp); const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp);
store_tile(b_copy_lds_window, b_block_tile_tmp); store_tile(b_copy_lds_window, b_block_tile_tmp);
...@@ -201,7 +190,7 @@ struct GemmPipelineAGmemBGmemCRegV1 ...@@ -201,7 +190,7 @@ struct GemmPipelineAGmemBGmemCRegV1
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{ {
auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>( auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegBlockDescriptor<Problem>()); Policy::template MakeShuffledBRegBlockDistribution<Problem>());
shuffle_tile(b_shuffle_tmp_loop, b_block_tile); shuffle_tile(b_shuffle_tmp_loop, b_block_tile);
store_tile(b_copy_lds_window, store_tile(b_copy_lds_window,
tile_elementwise_in(b_element_func, b_shuffle_tmp_loop)); tile_elementwise_in(b_element_func, b_shuffle_tmp_loop));
......
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
...@@ -18,37 +18,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -18,37 +18,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
static constexpr bool TransposeC = true; static constexpr bool TransposeC = true;
#if 0
// 2d
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
using namespace ck_tile;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto a_lds_block_desc =
make_naive_tensor_descriptor_packed(make_tuple(kMPerBlock, kKPerBlock), number<32>{});
return a_lds_block_desc;
}
// 2d
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
using namespace ck_tile;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto b_lds_block_desc =
make_naive_tensor_descriptor_packed(make_tuple(kNPerBlock, kKPerBlock), number<32>{});
return b_lds_block_desc;
}
#elif 1
// 3d + padding // 3d + padding
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
...@@ -58,7 +27,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -58,7 +27,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM; constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK; constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
// TODO: this 8 is AK1! should be a policy parameter!
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor( constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / 8>{}, number<kMPerBlock>{}, number<8>{}), make_tuple(number<kKPerBlock / 8>{}, number<kMPerBlock>{}, number<8>{}),
make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}), make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}),
...@@ -127,87 +95,14 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -127,87 +95,14 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA() CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
{ {
using ADataType = remove_cvref_t<typename Problem::ADataType>; return Problem::VectorLoadSize;
return Problem::VectorLoadSize / sizeof(ADataType);
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB() CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB()
{ {
using BDataType = remove_cvref_t<typename Problem::BDataType>; return Problem::VectorLoadSize;
return Problem::VectorLoadSize / sizeof(BDataType);
} }
#elif 1
// fake XOR
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
using namespace ck_tile;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto a_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
make_tuple(number<kMPerBlock / 2>{}, number<2>{}, number<kKPerBlock>{}),
number<kKPerBlock>{});
constexpr index_t kK1 = 16 / sizeof(ADataType);
constexpr auto a_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
a_lds_block_desc_d1_d2_d3,
make_tuple(
make_xor_transform(make_tuple(number<kMPerBlock / 2>{}, number<kKPerBlock>{}), kK1),
make_pass_through_transform(2)),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}));
constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor(
a_lds_block_desc_d4_d5_d6,
make_tuple(make_merge_transform(make_tuple(number<kMPerBlock / 2>{}, number<2>{})),
make_pass_through_transform(kKPerBlock)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return a_lds_block_desc_m_k;
}
// fake XOR
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
using namespace ck_tile;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto b_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
make_tuple(number<kNPerBlock / 2>{}, number<2>{}, number<kKPerBlock>{}),
number<kKPerBlock>{});
constexpr index_t kK1 = 16 / sizeof(BDataType);
constexpr auto b_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
b_lds_block_desc_d1_d2_d3,
make_tuple(
make_xor_transform(make_tuple(number<kNPerBlock / 2>{}, number<kKPerBlock>{}), kK1),
make_pass_through_transform(2)),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}));
constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor(
b_lds_block_desc_d4_d5_d6,
make_tuple(make_merge_transform(make_tuple(number<kNPerBlock / 2>{}, number<2>{})),
make_pass_through_transform(kKPerBlock)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return b_lds_block_desc_n_k;
}
#endif
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
...@@ -273,7 +168,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -273,7 +168,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
static_assert(M0 * M1 * M2 == MPerBlock, static_assert(M0 * M1 * M2 == MPerBlock,
"Incorrect M0, M2, M1 configuration! " "Incorrect M0, M2, M1 configuration! "
"M0, M1, M2 must cover whole MPerBlock!"); "M0, M1, M2 must cover whole MPerBlock!");
return make_static_tile_distribution( return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>, tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>, tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
...@@ -394,7 +288,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -394,7 +288,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDistribution()
{ {
using BLayout = remove_cvref_t<typename Problem::BLayout>; using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BDataType = remove_cvref_t<typename Problem::BDataType>; using BDataType = remove_cvref_t<typename Problem::BDataType>;
...@@ -442,7 +336,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy ...@@ -442,7 +336,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDistribution()
{ {
using ALayout = remove_cvref_t<typename Problem::ALayout>; using ALayout = remove_cvref_t<typename Problem::ALayout>;
using ADataType = remove_cvref_t<typename Problem::ADataType>; using ADataType = remove_cvref_t<typename Problem::ADataType>;
......
...@@ -3,6 +3,7 @@ ...@@ -3,6 +3,7 @@
#pragma once #pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -11,10 +12,10 @@ template <typename ADataType_, ...@@ -11,10 +12,10 @@ template <typename ADataType_,
typename BDataType_, typename BDataType_,
typename CDataType_, typename CDataType_,
typename BlockGemmShape_, typename BlockGemmShape_,
typename TileGemmTraits_> typename Traits_>
struct GemmPipelineProblemBase struct GemmPipelineProblemBase
{ {
using GemmTraits = remove_cvref_t<TileGemmTraits_>; using Traits = remove_cvref_t<Traits_>;
using ADataType = remove_cvref_t<ADataType_>; using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>; using BDataType = remove_cvref_t<BDataType_>;
...@@ -22,19 +23,19 @@ struct GemmPipelineProblemBase ...@@ -22,19 +23,19 @@ struct GemmPipelineProblemBase
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>; using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
using ALayout = remove_cvref_t<typename GemmTraits::ALayout>; using ALayout = remove_cvref_t<typename Traits::ALayout>;
using BLayout = remove_cvref_t<typename GemmTraits::BLayout>; using BLayout = remove_cvref_t<typename Traits::BLayout>;
using CLayout = remove_cvref_t<typename GemmTraits::CLayout>; using CLayout = remove_cvref_t<typename Traits::CLayout>;
static constexpr index_t VectorLoadSize = GemmTraits::_VectorSize; static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
static constexpr bool kPadM = GemmTraits::kPadM; static constexpr bool kPadM = Traits::kPadM;
static constexpr bool kPadN = GemmTraits::kPadN; static constexpr bool kPadN = Traits::kPadN;
static constexpr bool kPadK = GemmTraits::kPadK; static constexpr bool kPadK = Traits::kPadK;
static constexpr auto Scheduler = GemmPipelineScheduler::Default; static constexpr auto Scheduler = GemmPipelineScheduler::Default;
static constexpr index_t VectorLoadSize = Traits::_VectorSize;
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA() CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
{ {
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>) if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
...@@ -128,27 +129,43 @@ template <typename ADataType_, ...@@ -128,27 +129,43 @@ template <typename ADataType_,
typename BDataType_, typename BDataType_,
typename CDataType_, typename CDataType_,
typename BlockGemmShape_, typename BlockGemmShape_,
typename TileGemmTraits_> typename Traits_>
using GemmPipelineProblem = using GemmPipelineProblem =
GemmPipelineProblemBase<ADataType_, BDataType_, CDataType_, BlockGemmShape_, TileGemmTraits_>; GemmPipelineProblemBase<ADataType_, BDataType_, CDataType_, BlockGemmShape_, Traits_>;
template <typename ADataType_, template <typename ADataType_,
typename BDataType_, typename BDataType_,
typename CDataType_, typename CDataType_,
typename BlockGemmShape_, typename BlockGemmShape_,
typename TileGemmTraits_, typename Traits_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave, GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true, bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full> TailNumber TailNum_ = TailNumber::Full>
struct UniversalGemmPipelineProblem : public GemmPipelineProblemBase<ADataType_, struct UniversalGemmPipelineProblem
BDataType_,
CDataType_,
BlockGemmShape_,
TileGemmTraits_>
{ {
using Traits = remove_cvref_t<Traits_>;
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using CDataType = remove_cvref_t<CDataType_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
using ALayout = remove_cvref_t<typename Traits::ALayout>;
using BLayout = remove_cvref_t<typename Traits::BLayout>;
using CLayout = remove_cvref_t<typename Traits::CLayout>;
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
static constexpr bool kPadM = Traits::kPadM;
static constexpr bool kPadN = Traits::kPadN;
static constexpr bool kPadK = Traits::kPadK;
static constexpr auto Scheduler = Scheduler_; static constexpr auto Scheduler = Scheduler_;
static constexpr auto HasHotLoop = HasHotLoop_; static constexpr auto HasHotLoop = HasHotLoop_;
static constexpr auto TailNum = TailNum_; static constexpr auto TailNum = TailNum_;
static constexpr bool TransposeC = Traits::TransposeC;
}; };
} // namespace ck_tile } // namespace ck_tile
// SPDX-License-Identifier: MIT // SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved. // Copyright (c) 2024-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once #pragma once
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp"
namespace ck_tile { namespace ck_tile {
...@@ -15,30 +16,43 @@ struct UniversalGemmPipelineAgBgCrPolicy ...@@ -15,30 +16,43 @@ struct UniversalGemmPipelineAgBgCrPolicy
static constexpr auto I1 = number<1>{}; static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{}; static constexpr auto I2 = number<2>{};
static constexpr bool TransposeC = true; static constexpr auto ATileAccessPattern = tile_distribution_pattern::thread_raked;
static constexpr auto BTileAccessPattern = tile_distribution_pattern::thread_raked;
template <typename Problem, typename DataType, index_t MNPerBlock>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorLoadSize() /**
* @brief Get the maximum global memory vector load size.
*
* @tparam Problem The UniversalGemmPipelineProblem object.
* @tparam DataType The tensor data type we're considering.
* @tparam MNPerBlock The MPerBlock or NPerBlock value depending on tensor (A/B).
* @tparam XPerTile The contiguous Tile dimension size.
* @return Maximum DRAM vector load size.
*/
template <typename Problem, typename DataType, index_t MNPerBlock, index_t XPerTile>
CK_TILE_HOST_DEVICE static constexpr auto GetGlobalVectorLoadSize()
{ {
constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize; constexpr index_t elements_per_thread = MNPerBlock * KPerBlock / BlockSize;
if constexpr(elements_per_thread % (16 / sizeof(DataType)) == 0) // Assume DataType is even!
if constexpr(XPerTile % (16 / sizeof(DataType)) == 0 &&
elements_per_thread % (16 / sizeof(DataType)) == 0)
{ {
return (16 / sizeof(DataType)); return (16 / sizeof(DataType));
} }
else if constexpr(elements_per_thread % (8 / sizeof(DataType)) == 0) else if constexpr(XPerTile % (8 / sizeof(DataType)) == 0 &&
elements_per_thread % (8 / sizeof(DataType)) == 0)
{ {
return (8 / sizeof(DataType)); return (8 / sizeof(DataType));
} }
else if constexpr(elements_per_thread % (4 / sizeof(DataType)) == 0 && else if constexpr(sizeof(DataType) >= 4 && XPerTile % (4 / sizeof(DataType)) == 0 &&
sizeof(DataType) >= 4) elements_per_thread % (4 / sizeof(DataType)) == 0)
{ {
return (4 / sizeof(DataType)); return (4 / sizeof(DataType));
} }
else if constexpr(elements_per_thread % (2 / sizeof(DataType)) == 0 && else if constexpr(sizeof(DataType) >= 2 && XPerTile % (2 / sizeof(DataType)) == 0 &&
sizeof(DataType) >= 2) elements_per_thread % (2 / sizeof(DataType)) == 0)
{ {
return (2 / sizeof(DataType)); return (2 / sizeof(DataType));
} }
...@@ -48,6 +62,126 @@ struct UniversalGemmPipelineAgBgCrPolicy ...@@ -48,6 +62,126 @@ struct UniversalGemmPipelineAgBgCrPolicy
} }
} }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeA()
{
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
return GetGlobalVectorLoadSize<Problem, ADataType, MPerBlock, KPerBlock>();
}
else
{
return GetGlobalVectorLoadSize<Problem, ADataType, MPerBlock, MPerBlock>();
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeB()
{
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{
return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock, NPerBlock>();
}
else
{
return GetGlobalVectorLoadSize<Problem, BDataType, NPerBlock, KPerBlock>();
}
}
/**
* @brief Get the vector store size for C tensor.
*
* @tparam Problem - Gemm pipeline problem class.
*
* @note The vector store size for output C tensor would depend on multiple factors
* like its data layout and warp gemm C transposition. In general it would
* be the number of consecutive elements in contiguous C dimension hold by
* single thread.
*
* @return The vector store size for C tensor.
*/
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetVectorSizeC()
{
using BlockGemm = remove_cvref_t<decltype(GetBlockGemm<Problem>())>;
using WG = typename BlockGemm::WarpGemm;
constexpr bool TransposeC = Problem::TransposeC;
using CLayout = typename Problem::CLayout;
using CWarpDstr = typename WG::CWarpDstr;
// N is contiguous dimension
if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::RowMajor>)
{
if constexpr(TransposeC)
{
// In this case each thread has multiple consecutive elements in
// N dimension, however consecutive threads' elements have stride.
constexpr index_t NDimY = CWarpDstr::NDimY;
constexpr auto c_warp_y_lengths =
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
c_warp_y_lengths.get(number<NDimY - 1>{}));
return c_warp_y_lengths.get(number<NDimY - 1>{});
}
else
{
// In this case each thread has just a single item in Ndim
return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
}
}
// M is contiguous dimension
else if constexpr(std::is_same_v<CLayout, tensor_layout::gemm::ColumnMajor>)
{
if constexpr(TransposeC)
{
// In this case each thread has just a single item in Mdim
return WG::WarpGemmAttribute::Impl::kCNLane / WG::kN;
}
else
{
// In this case each thread has multiple consecutive elements in
// M dimension, however consecutive threads' elements have stride.
constexpr index_t NDimY = CWarpDstr::NDimY;
constexpr auto c_warp_y_lengths =
CWarpDstr{}.get_ys_to_d_descriptor().get_lengths();
static_assert(WG::WarpGemmAttribute::Impl::kCM1PerLane ==
c_warp_y_lengths.get(number<NDimY - 1>{}));
return c_warp_y_lengths.get(number<NDimY - 1>{});
}
}
else
{
static_assert(false, "Unsupported CLayout!");
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
{
using BlockGemm = decltype(GetBlockGemm<Problem>());
constexpr index_t KPack = BlockGemm::Traits::KPack;
return KPack;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB()
{
using BlockGemm = decltype(GetBlockGemm<Problem>());
constexpr index_t KPack = BlockGemm::Traits::KPack;
return KPack;
}
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{ {
...@@ -56,7 +190,7 @@ struct UniversalGemmPipelineAgBgCrPolicy ...@@ -56,7 +190,7 @@ struct UniversalGemmPipelineAgBgCrPolicy
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPack = GetVectorLoadSize<Problem, ADataType, MPerBlock>(); constexpr index_t KPack = GetSmemPackA<Problem>();
constexpr auto DataTypeSize = sizeof(ADataType); constexpr auto DataTypeSize = sizeof(ADataType);
constexpr auto MLdsLayer = constexpr auto MLdsLayer =
...@@ -99,54 +233,193 @@ struct UniversalGemmPipelineAgBgCrPolicy ...@@ -99,54 +233,193 @@ struct UniversalGemmPipelineAgBgCrPolicy
return a_lds_block_desc; return a_lds_block_desc;
} }
/**
* @brief Create LDS block descriptor for B tensor.
*
* @tparam Problem Gemm pipeline problem.
* @return B tensor LDS block descriptor.
*/
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{ {
// using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BDataType = remove_cvref_t<typename Problem::BDataType>; using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPack = GetVectorLoadSize<Problem, BDataType, NPerBlock>();
constexpr auto DataTypeSize = sizeof(BDataType); #if 1
constexpr auto NLdsLayer = // if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
(32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize); {
constexpr index_t KPack = GetSmemPackB<Problem>();
constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor( constexpr auto BK0 = number<KPerBlock / KPack>{};
make_tuple(number<KPerBlock / KPack * NLdsLayer>{}, constexpr auto DataTypeSize = sizeof(BDataType);
number<NPerBlock / NLdsLayer>{}, constexpr auto NLdsLayer =
number<KPack>{}), (32 * 4 / KPerBlock / DataTypeSize) < 1 ? 1 : (32 * 4 / KPerBlock / DataTypeSize);
make_tuple(number<KPack>{}, number<KPerBlock * NLdsLayer>{}, number<1>{}),
number<KPack>{}, constexpr auto b_lds_block_desc_0 = make_naive_tensor_descriptor(
number<1>{}); make_tuple(
BK0 * number<NLdsLayer>{}, number<NPerBlock / NLdsLayer>{}, number<KPack>{}),
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor( make_tuple(number<KPack>{}, number<KPerBlock * NLdsLayer>{}, number<1>{}),
b_lds_block_desc_0, number<KPack>{},
make_tuple(make_xor_transform(make_tuple(number<NPerBlock / NLdsLayer>{}, number<1>{});
number<KPerBlock / KPack * NLdsLayer>{})),
make_pass_through_transform(number<KPack>{})), constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
make_tuple(sequence<1, 0>{}, sequence<2>{}), b_lds_block_desc_0,
make_tuple(sequence<1, 0>{}, sequence<2>{})); make_tuple(make_xor_transform(make_tuple(number<NPerBlock / NLdsLayer>{},
BK0 * number<NLdsLayer>{})),
constexpr auto b_lds_block_desc_xk0_mnldslayer_mn_xk1 = transform_tensor_descriptor( make_pass_through_transform(number<KPack>{})),
b_lds_block_desc_permuted, make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(make_unmerge_transform( make_tuple(sequence<1, 0>{}, sequence<2>{}));
make_tuple(number<KPerBlock / KPack>{}, number<NLdsLayer>{})),
make_pass_through_transform(number<NPerBlock / NLdsLayer>{}), constexpr auto b_lds_block_desc_bk0_nldslayer_n_bk1 = transform_tensor_descriptor(
make_pass_through_transform(number<KPack>{})), b_lds_block_desc_permuted,
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}), make_tuple(make_unmerge_transform(make_tuple(BK0, number<NLdsLayer>{})),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{})); make_pass_through_transform(number<NPerBlock / NLdsLayer>{}),
make_pass_through_transform(number<KPack>{})),
constexpr auto b_lds_block_desc = transform_tensor_descriptor( make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
b_lds_block_desc_xk0_mnldslayer_mn_xk1, make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<NPerBlock / NLdsLayer>{}, number<NLdsLayer>{})), constexpr auto b_lds_block_desc = transform_tensor_descriptor(
make_merge_transform_v3_division_mod( b_lds_block_desc_bk0_nldslayer_n_bk1,
make_tuple(number<KPerBlock / KPack>{}, number<KPack>{}))), make_tuple(make_merge_transform_v3_division_mod(
make_tuple(sequence<1, 2>{}, sequence<0, 3>{}), make_tuple(number<NPerBlock / NLdsLayer>{}, number<NLdsLayer>{})),
make_tuple(sequence<0>{}, sequence<1>{})); make_merge_transform_v3_division_mod(make_tuple(BK0, number<KPack>{}))),
return b_lds_block_desc; make_tuple(sequence<1, 2>{}, sequence<0, 3>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return b_lds_block_desc;
}
#else
else // B is Row Major
{
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
KPerBlock,
NPerBlock,
VecLoadSize,
BTileAccessPattern>;
constexpr auto BK0 = number<TileEncodingPattern::X1>{};
constexpr auto BK1 = number<TileEncodingPattern::Y0>{};
// constexpr auto N0 = BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I1);
constexpr auto N0 = TileEncodingPattern::X0;
constexpr auto N1 = NPerBlock / N0;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
constexpr auto NPerXdl = number<WarpTile::at(I1)>{};
// constexpr auto KThreadWrite =
// BBlockTransferThreadClusterLengths_BK0_N_BK1{}.At(I0);
constexpr auto KThreadWrite = TileEncodingPattern::Y2;
constexpr auto K0PerThreadWrite = BK0 / KThreadWrite;
constexpr auto KThreadRead = 64 / NPerXdl;
constexpr auto K0PerThreadRead = BK0 / KThreadRead;
constexpr auto kfold =
(BK1 * N0 * sizeof(BDataType) > 128) ? 1 : 128 / (BK1 * N0 * sizeof(BDataType));
constexpr auto KThreadReadPerm =
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
: KThreadRead;
// 1<=npair<=n0
constexpr auto npair = (BK1 * NPerXdl * sizeof(BDataType) > 128)
? 1
: ((128 / (BK1 * NPerXdl * sizeof(BDataType))) > N0
? N0
: 128 / (BK1 * NPerXdl * sizeof(BDataType)));
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
make_tuple(number<KThreadWrite / kfold / KThreadReadPerm>{},
number<K0PerThreadWrite>{},
number<KThreadReadPerm * N1>{},
number<kfold * N0 / npair>{},
number<npair>{},
BK1));
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
b_lds_block_desc,
make_tuple(
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(number<K0PerThreadWrite>{}),
make_xor_transform(
make_tuple(number<KThreadReadPerm * N1>{}, number<kfold * N0 / npair>{})),
make_pass_through_transform(number<npair>{}),
make_pass_through_transform(BK1)),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}));
constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
b_lds_block_desc_permuted,
make_tuple(
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(number<K0PerThreadWrite>{}),
make_unmerge_transform(make_tuple(number<KThreadReadPerm>{}, number<N1>{})),
make_unmerge_transform(make_tuple(number<kfold>{}, number<N0 / npair>{})),
make_pass_through_transform(number<npair>{}),
make_pass_through_transform(BK1)),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4>{},
sequence<5>{}),
make_tuple(sequence<1>{},
sequence<2>{},
sequence<0, 3>{},
sequence<4, 5>{},
sequence<6>{},
sequence<7>{}));
// constexpr auto b_lds_block_desc_bk0_n_bk1 = transform_tensor_descriptor(
// b_lds_block_desc_unmerged,
// make_tuple(make_merge_transform_v3_division_mod(
// make_tuple(number<KThreadReadPerm>{},
// number<KThreadWrite / kfold / KThreadReadPerm>{},
// number<kfold>{},
// number<K0PerThreadWrite>{})),
// make_merge_transform_v3_division_mod(
// make_tuple(number<N0 / npair>{}, number<npair>{}, number<N1>{})),
// make_pass_through_transform(BK1)),
// make_tuple(sequence<0, 1, 4, 2>{}, sequence<5, 6, 3>{}, sequence<7>{}),
// make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}));
constexpr auto b_lds_block_desc_kn = transform_tensor_descriptor(
b_lds_block_desc_unmerged,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<KThreadReadPerm>{},
number<KThreadWrite / kfold / KThreadReadPerm>{},
number<kfold>{},
number<K0PerThreadWrite>{},
BK1)),
make_merge_transform_v3_division_mod(
make_tuple(number<N0 / npair>{}, number<npair>{}, number<N1>{}))),
make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
// return b_lds_block_desc_bk0_n_bk1;
return b_lds_block_desc_kn;
// constexpr auto b_lds_block_desc_bk0_n_bk1 = make_naive_tensor_descriptor(
// make_tuple(BK0, number<NPerBlock>{}, number<KPack>{}),
// make_tuple(number<KPack>{}, number<KPerBlock>{}, number<1>{}),
// number<KPack>{},
// number<1>{});
// constexpr auto b_lds_block_desc = transform_tensor_descriptor(
// b_lds_block_desc_bk0_n_bk1,
// make_tuple(make_pass_through_transform(number<NPerBlock>{}),
// make_merge_transform_v3_division_mod(make_tuple(BK0,
// number<KPack>{}))),
// make_tuple(sequence<1>{}, sequence<0, 2>{}),
// make_tuple(sequence<0>{}, sequence<1>{}));
// return b_lds_block_desc;
}
#endif
} }
template <typename Problem> template <typename Problem>
...@@ -179,291 +452,127 @@ struct UniversalGemmPipelineAgBgCrPolicy ...@@ -179,291 +452,127 @@ struct UniversalGemmPipelineAgBgCrPolicy
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
{ {
using ADataType = remove_cvref_t<typename Problem::ADataType>; using ALayout = remove_cvref_t<typename Problem::ALayout>;
using ALayout = remove_cvref_t<typename Problem::ALayout>;
constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM; // Tile: MPerBlock X KPerBlock
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::RowMajor>)
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
{ {
constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType); using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
constexpr index_t M0 = MPerBlock / M1; MPerBlock,
constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize; KPerBlock,
static_assert(total_pixels % M1 == 0); VecLoadSize,
constexpr index_t K3 = total_pixels / M1; ATileAccessPattern>;
constexpr index_t KPack = GetVectorLoadSize<Problem, ADataType, MPerBlock>(); return TileEncodingPattern::Make2DStaticTileDistribution();
static_assert(KPack % K3 == 0);
constexpr index_t K2 = KPack / K3;
if constexpr(get_warp_size() % (K2 * M0) == 0)
{
constexpr index_t K1 = get_warp_size() / (K2 * M0);
constexpr index_t K0 = BlockSize / get_warp_size();
static_assert(KPerBlock == K0 * K1 * K2 * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
else
{
constexpr index_t K1 = (K2 * M0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
} }
// Tile: KPerBlock X MPerBlock
else else
{ {
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(ADataType); using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
constexpr index_t K0 = KPerBlock / K1; KPerBlock,
constexpr index_t M2 = get_warp_size() / K0; MPerBlock,
if constexpr(get_warp_size() % (M2 * K0) == 0) VecLoadSize,
{ ATileAccessPattern>;
constexpr index_t M1 = BlockSize / get_warp_size(); return TileEncodingPattern::Make2DStaticTileDistribution();
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
constexpr index_t M0 = MPerBlock / (M2 * M1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
else
{
constexpr index_t M0 = BlockSize / get_warp_size();
constexpr index_t M1 = MPerBlock / (M2 * M0);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
} }
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution() CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
{ {
using BDataType = remove_cvref_t<typename Problem::BDataType>; using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BLayout = remove_cvref_t<typename Problem::BLayout>;
constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
// Tile: KPerBlock X NPerBlock
if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>) if constexpr(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>)
{ {
constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType); using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
constexpr index_t N0 = NPerBlock / N1; KPerBlock,
constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize; NPerBlock,
static_assert(total_pixels % N1 == 0); VecLoadSize,
constexpr index_t K3 = total_pixels / N1; BTileAccessPattern>;
constexpr index_t KPack = GetVectorLoadSize<Problem, BDataType, NPerBlock>(); return TileEncodingPattern::Make2DStaticTileDistribution();
static_assert(KPack % K3 == 0);
constexpr index_t K2 = KPack / K3;
if constexpr(get_warp_size() % (K2 * N0) == 0)
{
constexpr index_t K1 = get_warp_size() / (K2 * N0);
constexpr index_t K0 = BlockSize / get_warp_size();
static_assert(KPerBlock == K0 * K1 * K2 * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
else
{
constexpr index_t K1 = (K2 * N0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<2, 1>,
sequence<3, 1>>{});
}
} }
// Tile: NPerBlock X KPerBlock
else else
{ {
using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
constexpr index_t K1 = Problem::VectorLoadSize / sizeof(BDataType); NPerBlock,
constexpr index_t K0 = KPerBlock / K1; KPerBlock,
constexpr index_t N2 = get_warp_size() / K0; VecLoadSize,
// coalesce reading for each blocks BTileAccessPattern>;
if constexpr(get_warp_size() % (N2 * K0) == 0) return TileEncodingPattern::Make2DStaticTileDistribution();
{
constexpr index_t N1 = BlockSize / get_warp_size();
static_assert(N2 != 0, "N2 is zero, which will lead to a division by zero error.");
static_assert(N1 != 0, "N1 is zero, which will lead to a division by zero error.");
constexpr index_t N0 = NPerBlock / (N2 * N1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
// coalesce reading for each warps
else
{
constexpr index_t N0 = BlockSize / get_warp_size();
constexpr index_t N1 = NPerBlock / (N2 * N0);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<0>, sequence<2, 0>>,
sequence<1, 2>,
sequence<1, 1>>{});
}
} }
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegTileDistribution()
{ {
using ALayout = remove_cvref_t<typename Problem::ALayout>; using ALayout = remove_cvref_t<typename Problem::ALayout>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>); static_assert(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>);
constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kN; constexpr index_t MPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t VecLoadSize = GetVectorSizeA<Problem>();
constexpr index_t M1 = Problem::VectorLoadSize / sizeof(ADataType);
constexpr index_t M0 = MPerBlock / M1; using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
constexpr index_t total_pixels = MPerBlock * KPerBlock / BlockSize; KPerBlock,
static_assert(total_pixels % M1 == 0); MPerBlock,
constexpr index_t K3 = total_pixels / M1; VecLoadSize,
constexpr index_t kKPack = GetVectorLoadSize<Problem, ADataType, MPerBlock>(); ATileAccessPattern>;
static_assert(kKPack % K3 == 0); return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
constexpr index_t warp_size = get_warp_size();
if constexpr(warp_size % (K2 * M0) == 0)
{
constexpr index_t K1 = warp_size / (K2 * M0);
constexpr index_t K0 = BlockSize / warp_size;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
else
{
constexpr index_t K1 = (K2 * M0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
} }
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegTileDistribution()
{ {
using BLayout = remove_cvref_t<typename Problem::BLayout>; using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>); static_assert(std::is_same_v<BLayout, ck_tile::tensor_layout::gemm::RowMajor>);
constexpr index_t BlockSize = Problem::kBlockSize; constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN; constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK; constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t VecLoadSize = GetVectorSizeB<Problem>();
constexpr index_t N1 = Problem::VectorLoadSize / sizeof(BDataType);
constexpr index_t N0 = NPerBlock / N1; using TileEncodingPattern = TileDistributionEncodingPattern2D<BlockSize,
constexpr index_t total_pixels = NPerBlock * KPerBlock / BlockSize; KPerBlock,
static_assert(total_pixels % N1 == 0); NPerBlock,
constexpr index_t K3 = total_pixels / N1; VecLoadSize,
constexpr index_t kKPack = GetVectorLoadSize<Problem, BDataType, NPerBlock>(); BTileAccessPattern>;
static_assert(kKPack % K3 == 0); return TileEncodingPattern::MakeShuffled2DStaticTileDistribution();
constexpr index_t K2 = kKPack / K3; // TODO: this dimention could be outside single wave
constexpr index_t warp_size = get_warp_size();
if constexpr(warp_size % (K2 * N0) == 0)
{
constexpr index_t K1 = warp_size / (K2 * N0);
constexpr index_t K0 = BlockSize / warp_size;
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2, K3>>,
tuple<sequence<2>, sequence<2, 1, 2>>,
tuple<sequence<0>, sequence<1, 0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
else
{
constexpr index_t K1 = (K2 * N0) / get_warp_size();
constexpr index_t K2_m = K2 / K1;
constexpr index_t K0 = BlockSize / get_warp_size() / K1;
static_assert(KPerBlock == K0 * K1 * K2_m * K3);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1>, sequence<K0, K1, K2_m, K3>>,
tuple<sequence<2, 2>, sequence<1, 2>>,
tuple<sequence<0, 1>, sequence<0, 2>>,
sequence<1, 2>,
sequence<1, 3>>{});
}
} }
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC() { return TransposeC; } template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto IsTransposeC()
{
return Problem::TransposeC;
}
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{ {
using AccDataType = float;
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps; using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile; using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType, using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
typename Problem::BDataType, typename Problem::BDataType,
AccDataType, typename Problem::CDataType,
WarpTile::at(I0), WarpTile::at(I0),
WarpTile::at(I1), WarpTile::at(I1),
WarpTile::at(I2), WarpTile::at(I2),
TransposeC>; Problem::TransposeC>;
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType, using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
typename Problem::BDataType, typename Problem::BDataType,
typename Problem::CDataType, typename Problem::CDataType,
BlockWarps, BlockWarps,
WarpGemm>; WarpGemm>;
return BlockGemmASmemBSmemCRegV1<Problem, BlockGemmPolicy>{}; return BlockUniversalGemmAsBsCr<Problem, BlockGemmPolicy>{};
} }
}; };
......
...@@ -19,11 +19,34 @@ struct TileGemmTraits ...@@ -19,11 +19,34 @@ struct TileGemmTraits
static constexpr bool kPadN = kPadN_; static constexpr bool kPadN = kPadN_;
static constexpr bool kPadK = kPadK_; static constexpr bool kPadK = kPadK_;
// TODO this can't be hardcoded here! Should be in policy!
static constexpr int _VectorSize = 16; static constexpr int _VectorSize = 16;
using ALayout = ALayout_; using ALayout = ALayout_;
using BLayout = BLayout_; using BLayout = BLayout_;
using CLayout = CLayout_; using CLayout = CLayout_;
static constexpr bool TransposeC = false;
};
template <bool kPadM_,
bool kPadN_,
bool kPadK_,
typename ALayout_,
typename BLayout_,
typename CLayout_,
bool TransposeC_ = false>
struct TileGemmUniversalTraits
{
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool kPadK = kPadK_;
using ALayout = ALayout_;
using BLayout = BLayout_;
using CLayout = CLayout_;
static constexpr bool TransposeC = TransposeC_;
}; };
} // namespace ck_tile } // namespace ck_tile
...@@ -17,7 +17,7 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor; ...@@ -17,7 +17,7 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// clang-format off // clang-format off
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
std::tuple< Row, Row, Row, F16, F16, F32, F16>, // std::tuple< Row, Row, Row, F16, F16, F32, F16>,
//std::tuple< Col, Row, Row, F16, F16, F32, F16>, //std::tuple< Col, Row, Row, F16, F16, F32, F16>,
std::tuple< Row, Col, Row, F16, F16, F32, F16>//, std::tuple< Row, Col, Row, F16, F16, F32, F16>//,
//std::tuple< Col, Col, Row, F16, F16, F32, F16> //std::tuple< Col, Col, Row, F16, F16, F32, F16>
......
...@@ -14,26 +14,28 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor; ...@@ -14,26 +14,28 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor; using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler, using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
ck_tile::GemmPipelineScheduler::Intrawave>; ck_tile::GemmPipelineScheduler::Intrawave>;
using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler, // using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
ck_tile::GemmPipelineScheduler::Interwave>; // ck_tile::GemmPipelineScheduler::Interwave>;
using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>; // using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
using Comp = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Comp>; using Comp = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Comp>;
// TODO: Enable Memory pipeline, when it would be updated for vector loads on non-K major tensors.
// clang-format off // clang-format off
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>, // std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Comp>, std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Comp>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>, // std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>, // std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Comp>, std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Comp>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>, // std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>, // std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Comp>, std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Comp>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>, // std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>, // std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Comp>, std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Comp>
std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem> // std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>
>; >;
// clang-format on // clang-format on
......
...@@ -10,22 +10,43 @@ TYPED_TEST(TestCkTileGemmPipeline, SmallM) ...@@ -10,22 +10,43 @@ TYPED_TEST(TestCkTileGemmPipeline, SmallM)
constexpr int K = 320; constexpr int K = 320;
for(int M : Ms) for(int M : Ms)
this->Run(M, N, K); {
if constexpr(std::is_same_v<typename TestFixture::ALayout,
ck_tile::tensor_layout::gemm::ColumnMajor>)
EXPECT_THROW((this->Run(M, N, K)), std::runtime_error);
else
this->Run(M, N, K);
}
} }
TYPED_TEST(TestCkTileGemmPipeline, MidLargeM) TYPED_TEST(TestCkTileGemmPipeline, MidLargeM)
{ {
std::vector<int> Ms{127, 255, 312, 799, 1573}; std::vector<int> Ms{127, 255, 312, 799, 1573};
constexpr int N = 1024; constexpr int N = 1024;
constexpr int K = 320; constexpr int K = 320;
constexpr int VecLoadSize = 8;
for(int M : Ms) for(int M : Ms)
this->Run(M, N, K); {
if constexpr(std::is_same_v<typename TestFixture::ALayout,
ck_tile::tensor_layout::gemm::ColumnMajor>)
{
// TODO: Can we anyhow deduce used vector load size?
if(M % VecLoadSize == 0)
this->Run(M, N, K);
else
EXPECT_THROW((this->Run(M, N, K)), std::runtime_error);
}
else
{
this->Run(M, N, K);
}
}
} }
TYPED_TEST(TestCkTileGemmPipeline, PaddK) TYPED_TEST(TestCkTileGemmPipeline, PaddK)
{ {
std::vector<int> Ms{127}; std::vector<int> Ms{128};
constexpr int N = 1024; constexpr int N = 1024;
constexpr int K = 432; constexpr int K = 432;
......
...@@ -16,6 +16,7 @@ enum struct GemmPipelineType ...@@ -16,6 +16,7 @@ enum struct GemmPipelineType
Mem, Mem,
Comp Comp
}; };
template <typename Tuple> template <typename Tuple>
class TestCkTileGemmPipeline : public ::testing::Test class TestCkTileGemmPipeline : public ::testing::Test
{ {
...@@ -51,6 +52,9 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -51,6 +52,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
constexpr bool kPadN = PadN; constexpr bool kPadN = PadN;
constexpr bool kPadK = PadK; constexpr bool kPadK = PadK;
// TODO: For now - but this should also be a test parameter
constexpr bool TransposeC = false;
constexpr int kBlockPerCu = 1; constexpr int kBlockPerCu = 1;
// =============================================== // ===============================================
...@@ -65,14 +69,16 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -65,14 +69,16 @@ class TestCkTileGemmPipeline : public ::testing::Test
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>; ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>; using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using GemmUniversalTraits = ck_tile::
TileGemmUniversalTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout, TransposeC>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline = std::conditional_t< using BaseGemmPipeline =
PipelineType == GemmPipelineType::Mem, std::conditional_t<PipelineType == GemmPipelineType::Mem,
ck_tile::BaseGemmPipelineAgBgCrMem< ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>, ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
ck_tile::BaseGemmPipelineAgBgCrCompV3<
ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>>;
const ck_tile::index_t k_grain = args.k_batch * K_Tile; const ck_tile::index_t k_grain = args.k_batch * K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile; const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
...@@ -84,26 +90,22 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -84,26 +90,22 @@ class TestCkTileGemmPipeline : public ::testing::Test
constexpr bool has_hot_loop_v = has_hot_loop_.value; constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value; constexpr auto tail_number_v = tail_number_.value;
using GemmPipeline = using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
std::conditional_t<PipelineType == GemmPipelineType::Mem, BDataType,
ck_tile::GemmPipelineAgBgCrMem< AccDataType,
ck_tile::UniversalGemmPipelineProblem<ADataType, GemmShape,
BDataType, GemmUniversalTraits,
AccDataType, Scheduler,
GemmShape, has_hot_loop_v,
Traits, tail_number_v>;
Scheduler,
has_hot_loop_v, using GemmPipeline = std::conditional_t<
tail_number_v>>, PipelineType == GemmPipelineType::Mem,
ck_tile::GemmPipelineAgBgCrCompV3< ck_tile::GemmPipelineAgBgCrMem<UniversalGemmProblem,
ck_tile::UniversalGemmPipelineProblem<ADataType, ck_tile::UniversalGemmPipelineAgBgCrPolicy>,
BDataType, ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem,
AccDataType, ck_tile::UniversalGemmPipelineAgBgCrPolicy>>;
GemmShape,
Traits,
Scheduler,
has_hot_loop_v,
tail_number_v>>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>; using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args); auto kargs = Kernel::MakeKernelArgs(args);
...@@ -129,70 +131,94 @@ class TestCkTileGemmPipeline : public ::testing::Test ...@@ -129,70 +131,94 @@ class TestCkTileGemmPipeline : public ::testing::Test
if(has_hot_loop) if(has_hot_loop)
{ {
// Tail pipeline One to Seven if constexpr(PipelineType == GemmPipelineType::Comp)
if(tail_num == ck_tile::TailNumber::One)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
}
else if(tail_num == ck_tile::TailNumber::Full)
{ {
Run(ck_tile::bool_constant<true>{}, if(tail_num == ck_tile::TailNumber::Full)
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
}
if constexpr(BaseGemmPipeline::PrefetchStages > 2)
{
if(tail_num == ck_tile::TailNumber::Two)
{ {
Run(ck_tile::bool_constant<true>{}, Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Two>{}); ck_tile::TailNumber::Full>{});
} }
} else
if constexpr(BaseGemmPipeline::PrefetchStages > 3)
{
if(tail_num == ck_tile::TailNumber::Three)
{ {
Run(ck_tile::bool_constant<true>{}, std::ostringstream err;
ck_tile::integral_constant<ck_tile::TailNumber, err << "For compute pipeline tail number should always be Full, but have \""
ck_tile::TailNumber::Three>{}); << tail_num << "\" which is not supported! PrefetchStages: "
<< BaseGemmPipeline::PrefetchStages << "\n File: " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
} }
} }
if constexpr(BaseGemmPipeline::PrefetchStages > 4)
if constexpr(PipelineType == GemmPipelineType::Mem)
{ {
if(tail_num == ck_tile::TailNumber::Four) // Tail pipeline One to Seven
if(tail_num == ck_tile::TailNumber::One)
{ {
Run(ck_tile::bool_constant<true>{}, Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Four>{}); ck_tile::TailNumber::One>{});
} }
} else if(tail_num == ck_tile::TailNumber::Full)
if constexpr(BaseGemmPipeline::PrefetchStages > 5)
{
if(tail_num == ck_tile::TailNumber::Five)
{ {
Run(ck_tile::bool_constant<true>{}, Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Five>{}); ck_tile::TailNumber::Full>{});
} }
}
if constexpr(BaseGemmPipeline::PrefetchStages > 6) if constexpr(BaseGemmPipeline::PrefetchStages > 2)
{
if(tail_num == ck_tile::TailNumber::Six)
{ {
Run(ck_tile::bool_constant<true>{}, if(tail_num == ck_tile::TailNumber::Two)
ck_tile::integral_constant<ck_tile::TailNumber, {
ck_tile::TailNumber::Six>{}); Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Two>{});
}
} }
} if constexpr(BaseGemmPipeline::PrefetchStages > 3)
if constexpr(BaseGemmPipeline::PrefetchStages > 7)
{
if(tail_num == ck_tile::TailNumber::Seven)
{ {
Run(ck_tile::bool_constant<true>{}, if(tail_num == ck_tile::TailNumber::Three)
ck_tile::integral_constant<ck_tile::TailNumber, {
ck_tile::TailNumber::Seven>{}); Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Three>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 4)
{
if(tail_num == ck_tile::TailNumber::Four)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Four>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 5)
{
if(tail_num == ck_tile::TailNumber::Five)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Five>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 6)
{
if(tail_num == ck_tile::TailNumber::Six)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Six>{});
}
}
if constexpr(BaseGemmPipeline::PrefetchStages > 7)
{
if(tail_num == ck_tile::TailNumber::Seven)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Seven>{});
}
} }
} }
} }
......
...@@ -17,7 +17,7 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor; ...@@ -17,7 +17,7 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// clang-format off // clang-format off
using KernelTypes = ::testing::Types< using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType // ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
std::tuple< Row, Row, Row, F16, F16, F32, F16>, // std::tuple< Row, Row, Row, F16, F16, F32, F16>,
//std::tuple< Col, Row, Row, F16, F16, F32, F16>, //std::tuple< Col, Row, Row, F16, F16, F32, F16>,
std::tuple< Row, Col, Row, F16, F16, F32, F16>//, std::tuple< Row, Col, Row, F16, F16, F32, F16>//,
//std::tuple< Col, Col, Row, F16, F16, F32, F16> //std::tuple< Col, Col, Row, F16, F16, F32, F16>
......
...@@ -96,12 +96,9 @@ class TestCkTileGroupedGemm : public ::testing::Test ...@@ -96,12 +96,9 @@ class TestCkTileGroupedGemm : public ::testing::Test
CodegenGemmShape, CodegenGemmShape,
CodegenGemmTraits<ALayout, BLayout, CLayout>>; CodegenGemmTraits<ALayout, BLayout, CLayout>>;
using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy;
template <typename ALayout, typename BLayout, typename CLayout> template <typename ALayout, typename BLayout, typename CLayout>
using CodegenGemmPipeline = using CodegenGemmPipeline =
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem<ALayout, BLayout, CLayout>, ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem<ALayout, BLayout, CLayout>>;
CodegenGemmPolicy>;
template <typename ALayout, typename BLayout, typename CLayout> template <typename ALayout, typename BLayout, typename CLayout>
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner, using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment