Unverified Commit 781005a5 authored by arai713's avatar arai713 Committed by GitHub
Browse files

Merge branch 'develop' into codegen_hiprtc

parents a11cf2c6 39dc25a9
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -39,17 +39,6 @@ struct GemmPipelineAGmemBGmemCRegV1
static constexpr bool kPadN = Problem::kPadN;
static constexpr bool kPadK = Problem::kPadK;
CK_TILE_HOST_DEVICE static constexpr index_t GetStaticLdsSize()
{
return integer_divide_ceil(
sizeof(ADataType) *
Policy::template MakeALdsBlockDescriptor<Problem>().get_element_space_size(),
16) *
16 +
sizeof(BDataType) *
Policy::template MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
}
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
return Policy::template GetSmemSize<Problem>();
......@@ -150,7 +139,7 @@ struct GemmPipelineAGmemBGmemCRegV1
if constexpr(std::is_same_v<ALayout, tensor_layout::gemm::ColumnMajor>)
{
auto a_shuffle_tmp = make_static_distributed_tensor<ADataType>(
Policy::template MakeShuffledARegBlockDescriptor<Problem>());
Policy::template MakeShuffledARegBlockDistribution<Problem>());
shuffle_tile(a_shuffle_tmp, a_block_tile);
const auto a_block_tile_tmp = tile_elementwise_in(a_element_func, a_shuffle_tmp);
store_tile(a_copy_lds_window, a_block_tile_tmp);
......@@ -164,7 +153,7 @@ struct GemmPipelineAGmemBGmemCRegV1
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
auto b_shuffle_tmp = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegBlockDescriptor<Problem>());
Policy::template MakeShuffledBRegBlockDistribution<Problem>());
shuffle_tile(b_shuffle_tmp, b_block_tile);
const auto b_block_tile_tmp = tile_elementwise_in(b_element_func, b_shuffle_tmp);
store_tile(b_copy_lds_window, b_block_tile_tmp);
......@@ -201,7 +190,7 @@ struct GemmPipelineAGmemBGmemCRegV1
if constexpr(std::is_same_v<BLayout, tensor_layout::gemm::RowMajor>)
{
auto b_shuffle_tmp_loop = make_static_distributed_tensor<BDataType>(
Policy::template MakeShuffledBRegBlockDescriptor<Problem>());
Policy::template MakeShuffledBRegBlockDistribution<Problem>());
shuffle_tile(b_shuffle_tmp_loop, b_block_tile);
store_tile(b_copy_lds_window,
tile_elementwise_in(b_element_func, b_shuffle_tmp_loop));
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2018-2024, Advanced Micro Devices, Inc. All rights reserved.
// Copyright (c) 2018-2025, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
......@@ -18,37 +18,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
static constexpr bool TransposeC = true;
#if 0
// 2d
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
using namespace ck_tile;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto a_lds_block_desc =
make_naive_tensor_descriptor_packed(make_tuple(kMPerBlock, kKPerBlock), number<32>{});
return a_lds_block_desc;
}
// 2d
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
using namespace ck_tile;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto b_lds_block_desc =
make_naive_tensor_descriptor_packed(make_tuple(kNPerBlock, kKPerBlock), number<32>{});
return b_lds_block_desc;
}
#elif 1
// 3d + padding
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
......@@ -58,7 +27,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
// TODO: this 8 is AK1! should be a policy parameter!
constexpr auto a_lds_block_desc_0 = make_naive_tensor_descriptor(
make_tuple(number<kKPerBlock / 8>{}, number<kMPerBlock>{}, number<8>{}),
make_tuple(number<(kMPerBlock + 1) * 8>{}, number<8>{}, number<1>{}),
......@@ -127,87 +95,14 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackA()
{
using ADataType = remove_cvref_t<typename Problem::ADataType>;
return Problem::VectorLoadSize / sizeof(ADataType);
return Problem::VectorLoadSize;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetSmemPackB()
{
using BDataType = remove_cvref_t<typename Problem::BDataType>;
return Problem::VectorLoadSize / sizeof(BDataType);
return Problem::VectorLoadSize;
}
#elif 1
// fake XOR
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
using namespace ck_tile;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr index_t kMPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto a_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
make_tuple(number<kMPerBlock / 2>{}, number<2>{}, number<kKPerBlock>{}),
number<kKPerBlock>{});
constexpr index_t kK1 = 16 / sizeof(ADataType);
constexpr auto a_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
a_lds_block_desc_d1_d2_d3,
make_tuple(
make_xor_transform(make_tuple(number<kMPerBlock / 2>{}, number<kKPerBlock>{}), kK1),
make_pass_through_transform(2)),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}));
constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor(
a_lds_block_desc_d4_d5_d6,
make_tuple(make_merge_transform(make_tuple(number<kMPerBlock / 2>{}, number<2>{})),
make_pass_through_transform(kKPerBlock)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return a_lds_block_desc_m_k;
}
// fake XOR
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
using namespace ck_tile;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr index_t kNPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t kKPerBlock = Problem::BlockGemmShape::kK;
constexpr auto b_lds_block_desc_d1_d2_d3 = make_naive_tensor_descriptor_packed(
make_tuple(number<kNPerBlock / 2>{}, number<2>{}, number<kKPerBlock>{}),
number<kKPerBlock>{});
constexpr index_t kK1 = 16 / sizeof(BDataType);
constexpr auto b_lds_block_desc_d4_d5_d6 = transform_tensor_descriptor(
b_lds_block_desc_d1_d2_d3,
make_tuple(
make_xor_transform(make_tuple(number<kNPerBlock / 2>{}, number<kKPerBlock>{}), kK1),
make_pass_through_transform(2)),
make_tuple(sequence<0, 2>{}, sequence<1>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}));
constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor(
b_lds_block_desc_d4_d5_d6,
make_tuple(make_merge_transform(make_tuple(number<kNPerBlock / 2>{}, number<2>{})),
make_pass_through_transform(kKPerBlock)),
make_tuple(sequence<0, 1>{}, sequence<2>{}),
make_tuple(sequence<0>{}, sequence<1>{}));
return b_lds_block_desc_n_k;
}
#endif
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
......@@ -273,7 +168,6 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
static_assert(M0 * M1 * M2 == MPerBlock,
"Incorrect M0, M2, M1 configuration! "
"M0, M1, M2 must cover whole MPerBlock!");
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
......@@ -394,7 +288,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDescriptor()
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledBRegBlockDistribution()
{
using BLayout = remove_cvref_t<typename Problem::BLayout>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
......@@ -442,7 +336,7 @@ struct GemmPipelineAGmemBGmemCRegV1DefaultPolicy
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDescriptor()
CK_TILE_HOST_DEVICE static constexpr auto MakeShuffledARegBlockDistribution()
{
using ALayout = remove_cvref_t<typename Problem::ALayout>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
......
......@@ -3,6 +3,7 @@
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_ag_bg_cr_scheduler.hpp"
namespace ck_tile {
......@@ -11,10 +12,10 @@ template <typename ADataType_,
typename BDataType_,
typename CDataType_,
typename BlockGemmShape_,
typename TileGemmTraits_>
typename Traits_>
struct GemmPipelineProblemBase
{
using GemmTraits = remove_cvref_t<TileGemmTraits_>;
using Traits = remove_cvref_t<Traits_>;
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
......@@ -22,19 +23,19 @@ struct GemmPipelineProblemBase
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
using ALayout = remove_cvref_t<typename GemmTraits::ALayout>;
using BLayout = remove_cvref_t<typename GemmTraits::BLayout>;
using CLayout = remove_cvref_t<typename GemmTraits::CLayout>;
using ALayout = remove_cvref_t<typename Traits::ALayout>;
using BLayout = remove_cvref_t<typename Traits::BLayout>;
using CLayout = remove_cvref_t<typename Traits::CLayout>;
static constexpr index_t VectorLoadSize = GemmTraits::_VectorSize;
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
static constexpr bool kPadM = GemmTraits::kPadM;
static constexpr bool kPadN = GemmTraits::kPadN;
static constexpr bool kPadK = GemmTraits::kPadK;
static constexpr bool kPadM = Traits::kPadM;
static constexpr bool kPadN = Traits::kPadN;
static constexpr bool kPadK = Traits::kPadK;
static constexpr auto Scheduler = GemmPipelineScheduler::Default;
static constexpr index_t VectorLoadSize = Traits::_VectorSize;
CK_TILE_HOST_DEVICE static constexpr auto GetAlignmentA()
{
if constexpr(std::is_same_v<ALayout, ck_tile::tensor_layout::gemm::ColumnMajor>)
......@@ -128,27 +129,43 @@ template <typename ADataType_,
typename BDataType_,
typename CDataType_,
typename BlockGemmShape_,
typename TileGemmTraits_>
typename Traits_>
using GemmPipelineProblem =
GemmPipelineProblemBase<ADataType_, BDataType_, CDataType_, BlockGemmShape_, TileGemmTraits_>;
GemmPipelineProblemBase<ADataType_, BDataType_, CDataType_, BlockGemmShape_, Traits_>;
template <typename ADataType_,
typename BDataType_,
typename CDataType_,
typename BlockGemmShape_,
typename TileGemmTraits_,
typename Traits_,
GemmPipelineScheduler Scheduler_ = GemmPipelineScheduler::Intrawave,
bool HasHotLoop_ = true,
TailNumber TailNum_ = TailNumber::Full>
struct UniversalGemmPipelineProblem : public GemmPipelineProblemBase<ADataType_,
BDataType_,
CDataType_,
BlockGemmShape_,
TileGemmTraits_>
struct UniversalGemmPipelineProblem
{
using Traits = remove_cvref_t<Traits_>;
using ADataType = remove_cvref_t<ADataType_>;
using BDataType = remove_cvref_t<BDataType_>;
using CDataType = remove_cvref_t<CDataType_>;
using BlockGemmShape = remove_cvref_t<BlockGemmShape_>;
using ALayout = remove_cvref_t<typename Traits::ALayout>;
using BLayout = remove_cvref_t<typename Traits::BLayout>;
using CLayout = remove_cvref_t<typename Traits::CLayout>;
static constexpr index_t kBlockSize = BlockGemmShape::NumWarps * get_warp_size();
static constexpr bool kPadM = Traits::kPadM;
static constexpr bool kPadN = Traits::kPadN;
static constexpr bool kPadK = Traits::kPadK;
static constexpr auto Scheduler = Scheduler_;
static constexpr auto HasHotLoop = HasHotLoop_;
static constexpr auto TailNum = TailNum_;
static constexpr bool TransposeC = Traits::TransposeC;
};
} // namespace ck_tile
......@@ -19,11 +19,34 @@ struct TileGemmTraits
static constexpr bool kPadN = kPadN_;
static constexpr bool kPadK = kPadK_;
// TODO this can't be hardcoded here! Should be in policy!
static constexpr int _VectorSize = 16;
using ALayout = ALayout_;
using BLayout = BLayout_;
using CLayout = CLayout_;
static constexpr bool TransposeC = false;
};
template <bool kPadM_,
bool kPadN_,
bool kPadK_,
typename ALayout_,
typename BLayout_,
typename CLayout_,
bool TransposeC_ = false>
struct TileGemmUniversalTraits
{
static constexpr bool kPadM = kPadM_;
static constexpr bool kPadN = kPadN_;
static constexpr bool kPadK = kPadK_;
using ALayout = ALayout_;
using BLayout = BLayout_;
using CLayout = CLayout_;
static constexpr bool TransposeC = TransposeC_;
};
} // namespace ck_tile
......@@ -17,7 +17,7 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// clang-format off
using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
std::tuple< Row, Row, Row, F16, F16, F32, F16>,
// std::tuple< Row, Row, Row, F16, F16, F32, F16>,
//std::tuple< Col, Row, Row, F16, F16, F32, F16>,
std::tuple< Row, Col, Row, F16, F16, F32, F16>//,
//std::tuple< Col, Col, Row, F16, F16, F32, F16>
......
......@@ -14,26 +14,28 @@ using Row = ck_tile::tensor_layout::gemm::RowMajor;
using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
using Intrawave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
ck_tile::GemmPipelineScheduler::Intrawave>;
using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
ck_tile::GemmPipelineScheduler::Interwave>;
using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
// using Interwave = ck_tile::integral_constant<ck_tile::GemmPipelineScheduler,
// ck_tile::GemmPipelineScheduler::Interwave>;
// using Mem = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Mem>;
using Comp = ck_tile::integral_constant<GemmPipelineType, GemmPipelineType::Comp>;
// TODO: Enable Memory pipeline, when it would be updated for vector loads on non-K major tensors.
// clang-format off
using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType, GemmPipelineScheduler, PipelineType
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
// std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, Intrawave, Comp>,
std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
// std::tuple< Row, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
// std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Intrawave, Comp>,
std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
// std::tuple< Row, Col, Row, F16, F16, F32, F16, Interwave, Mem>,
// std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Intrawave, Comp>,
std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Comp>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>
// std::tuple< Col, Row, Row, F16, F16, F32, F16, Interwave, Mem>,
// std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Mem>,
std::tuple< Col, Col, Row, F16, F16, F32, F16, Intrawave, Comp>
// std::tuple< Col, Col, Row, F16, F16, F32, F16, Interwave, Mem>
>;
// clang-format on
......
......@@ -10,7 +10,13 @@ TYPED_TEST(TestCkTileGemmPipeline, SmallM)
constexpr int K = 320;
for(int M : Ms)
{
if constexpr(std::is_same_v<typename TestFixture::ALayout,
ck_tile::tensor_layout::gemm::ColumnMajor>)
EXPECT_THROW((this->Run(M, N, K)), std::runtime_error);
else
this->Run(M, N, K);
}
}
TYPED_TEST(TestCkTileGemmPipeline, MidLargeM)
......@@ -18,14 +24,29 @@ TYPED_TEST(TestCkTileGemmPipeline, MidLargeM)
std::vector<int> Ms{127, 255, 312, 799, 1573};
constexpr int N = 1024;
constexpr int K = 320;
constexpr int VecLoadSize = 8;
for(int M : Ms)
{
if constexpr(std::is_same_v<typename TestFixture::ALayout,
ck_tile::tensor_layout::gemm::ColumnMajor>)
{
// TODO: Can we anyhow deduce used vector load size?
if(M % VecLoadSize == 0)
this->Run(M, N, K);
else
EXPECT_THROW((this->Run(M, N, K)), std::runtime_error);
}
else
{
this->Run(M, N, K);
}
}
}
TYPED_TEST(TestCkTileGemmPipeline, PaddK)
{
std::vector<int> Ms{127};
std::vector<int> Ms{128};
constexpr int N = 1024;
constexpr int K = 432;
......
......@@ -16,6 +16,7 @@ enum struct GemmPipelineType
Mem,
Comp
};
template <typename Tuple>
class TestCkTileGemmPipeline : public ::testing::Test
{
......@@ -51,6 +52,9 @@ class TestCkTileGemmPipeline : public ::testing::Test
constexpr bool kPadN = PadN;
constexpr bool kPadK = PadK;
// TODO: For now - but this should also be a test parameter
constexpr bool TransposeC = false;
constexpr int kBlockPerCu = 1;
// ===============================================
......@@ -65,14 +69,16 @@ class TestCkTileGemmPipeline : public ::testing::Test
ck_tile::Default2DEpilogueProblem<AccDataType, CDataType, kPadM, kPadN>>;
using Traits = ck_tile::TileGemmTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout>;
using GemmUniversalTraits = ck_tile::
TileGemmUniversalTraits<kPadM, kPadN, kPadK, ALayout, BLayout, CLayout, TransposeC>;
using BaseGemmPipeline = std::conditional_t<
PipelineType == GemmPipelineType::Mem,
ck_tile::BaseGemmPipelineAgBgCrMem<
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>,
ck_tile::BaseGemmPipelineAgBgCrCompV3<
ck_tile::
GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>>>;
using GemmPipelineProblem =
ck_tile::GemmPipelineProblem<ADataType, BDataType, AccDataType, GemmShape, Traits>;
using BaseGemmPipeline =
std::conditional_t<PipelineType == GemmPipelineType::Mem,
ck_tile::BaseGemmPipelineAgBgCrMem<GemmPipelineProblem>,
ck_tile::BaseGemmPipelineAgBgCrCompV3<GemmPipelineProblem>>;
const ck_tile::index_t k_grain = args.k_batch * K_Tile;
const ck_tile::index_t K_split = (args.K + k_grain - 1) / k_grain * K_Tile;
......@@ -84,26 +90,22 @@ class TestCkTileGemmPipeline : public ::testing::Test
constexpr bool has_hot_loop_v = has_hot_loop_.value;
constexpr auto tail_number_v = tail_number_.value;
using GemmPipeline =
std::conditional_t<PipelineType == GemmPipelineType::Mem,
ck_tile::GemmPipelineAgBgCrMem<
ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
Traits,
Scheduler,
has_hot_loop_v,
tail_number_v>>,
ck_tile::GemmPipelineAgBgCrCompV3<
ck_tile::UniversalGemmPipelineProblem<ADataType,
using UniversalGemmProblem = ck_tile::UniversalGemmPipelineProblem<ADataType,
BDataType,
AccDataType,
GemmShape,
Traits,
GemmUniversalTraits,
Scheduler,
has_hot_loop_v,
tail_number_v>>>;
tail_number_v>;
using GemmPipeline = std::conditional_t<
PipelineType == GemmPipelineType::Mem,
ck_tile::GemmPipelineAgBgCrMem<UniversalGemmProblem,
ck_tile::UniversalGemmPipelineAgBgCrPolicy>,
ck_tile::GemmPipelineAgBgCrCompV3<UniversalGemmProblem,
ck_tile::UniversalGemmPipelineAgBgCrPolicy>>;
using Kernel = ck_tile::GemmKernel<TilePartitioner, GemmPipeline, GemmEpilogue>;
auto kargs = Kernel::MakeKernelArgs(args);
......@@ -128,17 +130,40 @@ class TestCkTileGemmPipeline : public ::testing::Test
};
if(has_hot_loop)
{
if constexpr(PipelineType == GemmPipelineType::Comp)
{
if(tail_num == ck_tile::TailNumber::Full)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Full>{});
}
else
{
std::ostringstream err;
err << "For compute pipeline tail number should always be Full, but have \""
<< tail_num << "\" which is not supported! PrefetchStages: "
<< BaseGemmPipeline::PrefetchStages << "\n File: " << __FILE__ << ":"
<< __LINE__ << ", in function: " << __func__;
throw std::runtime_error(err.str());
}
}
if constexpr(PipelineType == GemmPipelineType::Mem)
{
// Tail pipeline One to Seven
if(tail_num == ck_tile::TailNumber::One)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::One>{});
ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::One>{});
}
else if(tail_num == ck_tile::TailNumber::Full)
{
Run(ck_tile::bool_constant<true>{},
ck_tile::integral_constant<ck_tile::TailNumber, ck_tile::TailNumber::Full>{});
ck_tile::integral_constant<ck_tile::TailNumber,
ck_tile::TailNumber::Full>{});
}
if constexpr(BaseGemmPipeline::PrefetchStages > 2)
......@@ -196,6 +221,7 @@ class TestCkTileGemmPipeline : public ::testing::Test
}
}
}
}
else
{
// Tail number always Full - #PrefetchStages
......
......@@ -17,7 +17,7 @@ using Col = ck_tile::tensor_layout::gemm::ColumnMajor;
// clang-format off
using KernelTypes = ::testing::Types<
// ALayout, BLayout, CLayout, ADataType, BDataType, AccDataType, CDataType
std::tuple< Row, Row, Row, F16, F16, F32, F16>,
// std::tuple< Row, Row, Row, F16, F16, F32, F16>,
//std::tuple< Col, Row, Row, F16, F16, F32, F16>,
std::tuple< Row, Col, Row, F16, F16, F32, F16>//,
//std::tuple< Col, Col, Row, F16, F16, F32, F16>
......
......@@ -96,12 +96,9 @@ class TestCkTileGroupedGemm : public ::testing::Test
CodegenGemmShape,
CodegenGemmTraits<ALayout, BLayout, CLayout>>;
using CodegenGemmPolicy = ck_tile::UniversalGemmPipelineAgBgCrPolicy;
template <typename ALayout, typename BLayout, typename CLayout>
using CodegenGemmPipeline =
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem<ALayout, BLayout, CLayout>,
CodegenGemmPolicy>;
ck_tile::GemmPipelineAGmemBGmemCRegV1<CodegenPipelineProblem<ALayout, BLayout, CLayout>>;
template <typename ALayout, typename BLayout, typename CLayout>
using Kernel = ck_tile::GroupedGemmKernel<TilePartitioner,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment