Commit dc1c2bf8 authored by carlushuang's avatar carlushuang
Browse files

Merge remote-tracking branch 'origin/develop' into letaoqin/update_layernorm

parents 5cfd751b a285d6f9
...@@ -5,9 +5,8 @@ ...@@ -5,9 +5,8 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_areg_bsmem_creg_v1_custom_policy.hpp"
...@@ -27,20 +26,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -27,20 +26,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{ {
using GemmProblem = using GemmProblem =
GemmPipelineProblem<typename Problem::QDataType, BlockGemmProblem<typename Problem::QDataType,
typename Problem::KDataType, typename Problem::KDataType,
typename Problem::AccDataType, typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0, Problem::kBlockSize,
Problem::BlockFmhaShape::kN0, TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kK0>, Problem::BlockFmhaShape::kN0,
typename Problem::BlockFmhaShape::Gemm0BlockWarps, Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0WarpTile>, typename Problem::BlockFmhaShape::Gemm0BlockWarps,
TileGemmTraits<Problem::kPadSeqLenQ, typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
using WarpGemm = WarpGemmMfmaDispatcher< using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::QDataType, typename Problem::QDataType,
...@@ -66,20 +60,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -66,20 +60,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetPTOGradTBlockGemm()
{ {
using GemmProblem = using GemmProblem =
GemmPipelineProblem<typename Problem::GemmDataType, BlockGemmProblem<typename Problem::GemmDataType,
typename Problem::OGradDataType, typename Problem::OGradDataType,
typename Problem::AccDataType, typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0, Problem::kBlockSize,
Problem::BlockFmhaShape::kVHeaddim, TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK1>, Problem::BlockFmhaShape::kVHeaddim,
typename Problem::BlockFmhaShape::Gemm1BlockWarps, Problem::BlockFmhaShape::kK1>,
typename Problem::BlockFmhaShape::Gemm1WarpTile>, typename Problem::BlockFmhaShape::Gemm1BlockWarps,
TileGemmTraits<Problem::kPadSeqLenQ, typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
Problem::kPadHeadDimV,
Problem::kPadHeadDimV,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
using WarpGemm = using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType, WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
...@@ -104,20 +93,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -104,20 +93,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetOGradVBlockGemm()
{ {
using GemmProblem = using GemmProblem =
GemmPipelineProblem<typename Problem::OGradDataType, BlockGemmProblem<typename Problem::OGradDataType,
typename Problem::VDataType, typename Problem::VDataType,
typename Problem::AccDataType, typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0, Problem::kBlockSize,
Problem::BlockFmhaShape::kN0, TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kK2>, Problem::BlockFmhaShape::kN0,
typename Problem::BlockFmhaShape::Gemm2BlockWarps, Problem::BlockFmhaShape::kK2>,
typename Problem::BlockFmhaShape::Gemm2WarpTile>, typename Problem::BlockFmhaShape::Gemm2BlockWarps,
TileGemmTraits<Problem::kPadSeqLenQ, typename Problem::BlockFmhaShape::Gemm2WarpTile>>;
Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
using WarpGemm = WarpGemmMfmaDispatcher< using WarpGemm = WarpGemmMfmaDispatcher<
typename Problem::OGradDataType, typename Problem::OGradDataType,
...@@ -143,20 +127,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -143,20 +127,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetSGradTQTBlockGemm()
{ {
using GemmProblem = using GemmProblem =
GemmPipelineProblem<typename Problem::GemmDataType, BlockGemmProblem<typename Problem::GemmDataType,
typename Problem::QDataType, typename Problem::QDataType,
typename Problem::AccDataType, typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kN0, Problem::kBlockSize,
Problem::BlockFmhaShape::kQKHeaddim, TileGemmShape<sequence<Problem::BlockFmhaShape::kN0,
Problem::BlockFmhaShape::kK3>, Problem::BlockFmhaShape::kQKHeaddim,
typename Problem::BlockFmhaShape::Gemm3BlockWarps, Problem::BlockFmhaShape::kK3>,
typename Problem::BlockFmhaShape::Gemm3WarpTile>, typename Problem::BlockFmhaShape::Gemm3BlockWarps,
TileGemmTraits<Problem::kPadSeqLenK, typename Problem::BlockFmhaShape::Gemm3WarpTile>>;
Problem::kPadHeadDimQ,
Problem::kPadSeqLenK,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
using WarpGemm = using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType, WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
...@@ -181,20 +160,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -181,20 +160,15 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetSGradKTBlockGemm()
{ {
using GemmProblem = using GemmProblem =
GemmPipelineProblem<typename Problem::GemmDataType, BlockGemmProblem<typename Problem::GemmDataType,
typename Problem::KDataType, typename Problem::KDataType,
typename Problem::AccDataType, typename Problem::AccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0, Problem::kBlockSize,
Problem::BlockFmhaShape::kQKHeaddim, TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kK4>, Problem::BlockFmhaShape::kQKHeaddim,
typename Problem::BlockFmhaShape::Gemm4BlockWarps, Problem::BlockFmhaShape::kK4>,
typename Problem::BlockFmhaShape::Gemm4WarpTile>, typename Problem::BlockFmhaShape::Gemm4BlockWarps,
TileGemmTraits<Problem::kPadSeqLenQ, typename Problem::BlockFmhaShape::Gemm4WarpTile>>;
Problem::kPadHeadDimQ,
Problem::kPadSeqLenK,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
using WarpGemm = using WarpGemm =
WarpGemmMfmaDispatcher<typename Problem::GemmDataType, WarpGemmMfmaDispatcher<typename Problem::GemmDataType,
...@@ -222,7 +196,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -222,7 +196,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using QDataType = remove_cvref_t<typename Problem::QDataType>; using QDataType = remove_cvref_t<typename Problem::QDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kMaxVecLoad = 16 / sizeof(QDataType); constexpr index_t kMaxVecLoad = 16 / sizeof(QDataType);
constexpr index_t kMinVecLoad = 4 / sizeof(QDataType); constexpr index_t kMinVecLoad = 4 / sizeof(QDataType);
...@@ -241,7 +215,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -241,7 +215,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using KDataType = remove_cvref_t<typename Problem::KDataType>; using KDataType = remove_cvref_t<typename Problem::KDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kMaxVecLoad = 16 / sizeof(KDataType); constexpr index_t kMaxVecLoad = 16 / sizeof(KDataType);
constexpr index_t kMinVecLoad = 4 / sizeof(KDataType); constexpr index_t kMinVecLoad = 4 / sizeof(KDataType);
...@@ -260,7 +234,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -260,7 +234,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using VDataType = remove_cvref_t<typename Problem::VDataType>; using VDataType = remove_cvref_t<typename Problem::VDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kMaxVecLoad = 16 / sizeof(VDataType); constexpr index_t kMaxVecLoad = 16 / sizeof(VDataType);
constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize; constexpr index_t total_pixels = kMNPerBlock * kKPerBlock / kBlockSize;
...@@ -280,7 +254,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -280,7 +254,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>; using OGradDataType = remove_cvref_t<typename Problem::OGradDataType>;
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kMNPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kMaxVecLoad = 16 / sizeof(OGradDataType); constexpr index_t kMaxVecLoad = 16 / sizeof(OGradDataType);
constexpr index_t kMinVecLoad = 4 / sizeof(OGradDataType); constexpr index_t kMinVecLoad = 4 / sizeof(OGradDataType);
...@@ -341,7 +315,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -341,7 +315,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
...@@ -353,7 +327,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -353,7 +327,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
return total_pixels / GetAlignmentK<Problem>(); return total_pixels / GetAlignmentK<Problem>();
...@@ -364,7 +338,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -364,7 +338,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize; constexpr index_t total_pixels = kNPerBlock * kKPerBlock / kBlockSize;
...@@ -402,7 +376,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -402,7 +376,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t K1 = GetAlignmentK<Problem>(); constexpr index_t K1 = GetAlignmentK<Problem>();
constexpr index_t K0 = kKPerBlock / K1; constexpr index_t K0 = kKPerBlock / K1;
...@@ -425,7 +399,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -425,7 +399,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t K1 = GetAlignmentV<Problem>(); constexpr index_t K1 = GetAlignmentV<Problem>();
constexpr index_t K0 = kKPerBlock / K1; constexpr index_t K0 = kKPerBlock / K1;
...@@ -448,7 +422,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -448,7 +422,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t K1 = GetAlignmentQ<Problem>(); constexpr index_t K1 = GetAlignmentQ<Problem>();
constexpr index_t K0 = kKPerBlock / K1; constexpr index_t K0 = kKPerBlock / K1;
...@@ -471,7 +445,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -471,7 +445,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t K1 = GetAlignmentOGrad<Problem>(); constexpr index_t K1 = GetAlignmentOGrad<Problem>();
constexpr index_t K0 = kKPerBlock / K1; constexpr index_t K0 = kKPerBlock / K1;
...@@ -842,44 +816,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -842,44 +816,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsWriteBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeKLdsWriteBlockDescriptor()
{ {
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPack = GetSmemKPackK<Problem>(); constexpr index_t kKPack = GetSmemKPackK<Problem>();
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack>(); return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kKPack>();
} }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKRegSliceBlockDescriptor()
{
using BlockGemm = remove_cvref_t<decltype(GetQKBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr auto k_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto k_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
k_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto k_block_dstr = make_static_tile_distribution(k_block_dstr_encode);
return k_block_dstr;
}
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeKRegBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeKRegBlockDescriptor()
{ {
...@@ -891,7 +833,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -891,7 +833,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{}); constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm0BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0;
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
...@@ -916,45 +858,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -916,45 +858,13 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsWriteBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeVLdsWriteBlockDescriptor()
{ {
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kVPack = GetSmemKPackV<Problem>(); constexpr index_t kVPack = GetSmemKPackV<Problem>();
return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kVPack>(); return MakeXLdsBlockDescriptor<kNPerBlock, kKPerBlock, kVPack>();
} }
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeVRegSliceBlockDescriptor()
{
using BlockGemm = remove_cvref_t<decltype(GetOGradVBlockGemm<Problem>())>;
constexpr auto config = BlockGemm::Policy::template GetWarpGemmMWarpNWarp<Problem>();
using WarpGemm = remove_cvref_t<decltype(config.template at<0>())>;
constexpr index_t MWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<0>{});
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
constexpr auto v_block_outer_dstr_encoding =
tile_distribution_encoding<sequence<MWarp>,
tuple<sequence<NIterPerWarp, NWarp>, sequence<KIterPerWarp>>,
tuple<sequence<0, 1>>,
tuple<sequence<0, 1>>,
sequence<1, 2>,
sequence<0, 0>>{};
constexpr auto v_block_dstr_encode = detail::make_embed_tile_distribution_encoding(
v_block_outer_dstr_encoding, typename WarpGemm::BWarpDstrEncoding{});
constexpr auto v_block_dstr = make_static_tile_distribution(v_block_dstr_encode);
return v_block_dstr;
}
template <typename Problem> template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeVRegBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeVRegBlockDescriptor()
{ {
...@@ -966,7 +876,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -966,7 +876,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{}); constexpr index_t NWarp = Problem::BlockFmhaShape::Gemm2BlockWarps::at(number<1>{});
constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0; constexpr index_t kNPerBlock = Problem::BlockFmhaShape::kN0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2;
constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN); constexpr index_t NIterPerWarp = kNPerBlock / (NWarp * WarpGemm::kN);
constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK; constexpr index_t KIterPerWarp = kKPerBlock / WarpGemm::kK;
...@@ -992,7 +902,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -992,7 +902,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t K1 = GetAlignmentK<Problem>(); constexpr index_t K1 = GetAlignmentK<Problem>();
constexpr index_t K0 = kKPerBlock / K1; constexpr index_t K0 = kKPerBlock / K1;
...@@ -1074,7 +984,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1074,7 +984,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor() CK_TILE_HOST_DEVICE static constexpr auto MakeQLdsBlockDescriptor()
{ {
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t kKPack = GetSmemKPackQ<Problem>(); constexpr index_t kKPack = GetSmemKPackQ<Problem>();
...@@ -1118,7 +1028,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1118,7 +1028,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK0; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kQKHeaddim;
constexpr index_t K1 = GetAlignmentQ<Problem>(); constexpr index_t K1 = GetAlignmentQ<Problem>();
constexpr index_t K0 = kKPerBlock / K1; constexpr index_t K0 = kKPerBlock / K1;
...@@ -1281,7 +1191,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1281,7 +1191,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{ {
// Hold full block data // Hold full block data
constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0; constexpr index_t kMPerBlock = Problem::BlockFmhaShape::kM0;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t kKPack = GetSmemKPackOGrad<Problem>(); constexpr index_t kKPack = GetSmemKPackOGrad<Problem>();
...@@ -1325,7 +1235,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1325,7 +1235,7 @@ struct BlockFmhaBwdPipelineDefaultPolicy
{ {
constexpr index_t kBlockSize = Problem::kBlockSize; constexpr index_t kBlockSize = Problem::kBlockSize;
constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kK2; constexpr index_t kKPerBlock = Problem::BlockFmhaShape::kVHeaddim;
constexpr index_t K1 = GetAlignmentOGrad<Problem>(); constexpr index_t K1 = GetAlignmentOGrad<Problem>();
constexpr index_t K0 = kKPerBlock / K1; constexpr index_t K0 = kKPerBlock / K1;
...@@ -1885,6 +1795,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1885,6 +1795,8 @@ struct BlockFmhaBwdPipelineDefaultPolicy
static constexpr index_t kN0 = Problem::BlockFmhaShape::kN0; static constexpr index_t kN0 = Problem::BlockFmhaShape::kN0;
static constexpr index_t kQKHeaddim = Problem::BlockFmhaShape::kQKHeaddim; static constexpr index_t kQKHeaddim = Problem::BlockFmhaShape::kQKHeaddim;
static constexpr index_t kVHeaddim = Problem::BlockFmhaShape::kVHeaddim; static constexpr index_t kVHeaddim = Problem::BlockFmhaShape::kVHeaddim;
static constexpr index_t kK0 = Problem::BlockFmhaShape::kK0;
static constexpr index_t kK2 = Problem::BlockFmhaShape::kK2;
static constexpr index_t kK4 = Problem::BlockFmhaShape::kK4; static constexpr index_t kK4 = Problem::BlockFmhaShape::kK4;
static constexpr index_t WarpGemmM = static constexpr index_t WarpGemmM =
...@@ -1899,14 +1811,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1899,14 +1811,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
// Compute // Compute
static constexpr index_t Gemm0MFMA = static constexpr index_t Gemm0MFMA =
kM0 * kN0 * kQKHeaddim / kM0 * kN0 * kK0 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
static constexpr index_t Gemm1MFMA = static constexpr index_t Gemm1MFMA =
kM0 * kN0 * kVHeaddim /
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
static constexpr index_t Gemm2MFMA =
kN0 * kVHeaddim * kM0 / kN0 * kVHeaddim * kM0 /
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
static constexpr index_t Gemm2MFMA =
kM0 * kN0 * kK2 / (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
static constexpr index_t Gemm3MFMA = static constexpr index_t Gemm3MFMA =
kN0 * kQKHeaddim * kM0 / kN0 * kQKHeaddim * kM0 /
(kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK); (kBlockSize / get_warp_size() * WarpGemmM * WarpGemmN * WarpGemmK);
...@@ -1929,13 +1839,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy ...@@ -1929,13 +1839,12 @@ struct BlockFmhaBwdPipelineDefaultPolicy
kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ<Problem>(); kM0 * kQKHeaddim / get_warp_size() / GetTransposedAlignmentQ<Problem>();
static constexpr index_t SGradT_LDS_READ_P1 = static constexpr index_t SGradT_LDS_READ_P1 =
kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>(); kM0 * kK4 / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
static constexpr index_t Q_LDS_READ = static constexpr index_t Q_LDS_READ = kM0 * kK0 / kBlockSize / GetAlignmentQ<Problem>();
kM0 * kQKHeaddim / kBlockSize / GetAlignmentQ<Problem>();
static constexpr index_t LSE_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4); static constexpr index_t LSE_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
static constexpr index_t SGradT_LDS_READ_P2 = static constexpr index_t SGradT_LDS_READ_P2 =
kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>(); kM0 * (kN0 - kK4) / (get_warp_size() * Gemm4MWarp) / GetSmemKPackSGrad<Problem>();
static constexpr index_t OGrad_LDS_READ = static constexpr index_t OGrad_LDS_READ =
kM0 * kVHeaddim / kBlockSize / GetAlignmentOGrad<Problem>(); kM0 * kK2 / kBlockSize / GetAlignmentOGrad<Problem>();
static constexpr index_t D_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4); static constexpr index_t D_LDS_READ = WarpGemmM == 16 ? kM0 / (4 * 4) : kM0 / (2 * 4);
// LDS Write // LDS Write
......
...@@ -5,9 +5,9 @@ ...@@ -5,9 +5,9 @@
#include "ck_tile/core.hpp" #include "ck_tile/core.hpp"
#include "ck_tile/ops/common/tensor_layout.hpp" #include "ck_tile/ops/common/tensor_layout.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_problem.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp" #include "ck_tile/ops/gemm/pipeline/tile_gemm_shape.hpp"
#include "ck_tile/ops/gemm/pipeline/tile_gemm_traits.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp" #include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
#include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp" #include "ck_tile/ops/gemm/block/block_gemm_asmem_bsmem_creg_v1_custom_policy.hpp"
...@@ -77,20 +77,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true> ...@@ -77,20 +77,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ true>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{ {
using GemmProblem = using GemmProblem =
GemmPipelineProblem<typename Problem::QDataType, BlockGemmProblem<typename Problem::QDataType,
typename Problem::KDataType, typename Problem::KDataType,
typename Problem::SaccDataType, typename Problem::SaccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0, Problem::kBlockSize,
Problem::BlockFmhaShape::kN0, TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kK0>, Problem::BlockFmhaShape::kN0,
typename Problem::BlockFmhaShape::Gemm0BlockWarps, Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0WarpTile>, typename Problem::BlockFmhaShape::Gemm0BlockWarps,
TileGemmTraits<Problem::kPadSeqLenQ, typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
constexpr auto warp_gemm = []() { constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> && if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
...@@ -207,20 +202,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false> ...@@ -207,20 +202,15 @@ struct BlockFmhaPipelineQXCustomPolicy</* QLoadOnce = */ false>
CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetQKBlockGemm()
{ {
using GemmProblem = using GemmProblem =
GemmPipelineProblem<typename Problem::QDataType, BlockGemmProblem<typename Problem::QDataType,
typename Problem::KDataType, typename Problem::KDataType,
typename Problem::SaccDataType, typename Problem::SaccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0, Problem::kBlockSize,
Problem::BlockFmhaShape::kN0, TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kK0>, Problem::BlockFmhaShape::kN0,
typename Problem::BlockFmhaShape::Gemm0BlockWarps, Problem::BlockFmhaShape::kK0>,
typename Problem::BlockFmhaShape::Gemm0WarpTile>, typename Problem::BlockFmhaShape::Gemm0BlockWarps,
TileGemmTraits<Problem::kPadSeqLenQ, typename Problem::BlockFmhaShape::Gemm0WarpTile>>;
Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
constexpr auto warp_gemm = []() { constexpr auto warp_gemm = []() {
if constexpr(std::is_same_v<typename Problem::QDataType, half_t> && if constexpr(std::is_same_v<typename Problem::QDataType, half_t> &&
...@@ -968,20 +958,15 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo ...@@ -968,20 +958,15 @@ struct BlockFmhaPipelineQXKSVSCustomPolicy : BlockFmhaPipelineQXCustomPolicy<QLo
CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm() CK_TILE_HOST_DEVICE static constexpr auto GetKVBlockGemm()
{ {
using GemmProblem = using GemmProblem =
GemmPipelineProblem<typename Problem::PDataType, BlockGemmProblem<typename Problem::PDataType,
typename Problem::VDataType, typename Problem::VDataType,
typename Problem::OaccDataType, typename Problem::OaccDataType,
TileGemmShape<sequence<Problem::BlockFmhaShape::kM0, Problem::kBlockSize,
Problem::BlockFmhaShape::kN1, TileGemmShape<sequence<Problem::BlockFmhaShape::kM0,
Problem::BlockFmhaShape::kK1>, Problem::BlockFmhaShape::kN1,
typename Problem::BlockFmhaShape::Gemm1BlockWarps, Problem::BlockFmhaShape::kK1>,
typename Problem::BlockFmhaShape::Gemm1WarpTile>, typename Problem::BlockFmhaShape::Gemm1BlockWarps,
TileGemmTraits<Problem::kPadSeqLenQ, typename Problem::BlockFmhaShape::Gemm1WarpTile>>;
Problem::kPadSeqLenK,
Problem::kPadHeadDimQ,
typename tensor_layout::gemm::RowMajor,
typename tensor_layout::gemm::ColumnMajor,
typename tensor_layout::gemm::RowMajor>>;
auto warp_gemm = [&]() { auto warp_gemm = [&]() {
if constexpr(std::is_same_v<typename Problem::KDataType, fp8_t> && if constexpr(std::is_same_v<typename Problem::KDataType, fp8_t> &&
......
...@@ -23,6 +23,7 @@ ...@@ -23,6 +23,7 @@
#include "ck_tile/ops/gemm/block/block_gemm_problem.hpp" #include "ck_tile/ops/gemm/block/block_gemm_problem.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp" #include "ck_tile/ops/gemm/kernel/gemm_kernel.hpp"
#include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp" #include "ck_tile/ops/gemm/kernel/gemm_tile_partitioner.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_universal_pipeline_ag_bg_cr_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v1_default_policy.hpp"
#include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp" #include "ck_tile/ops/gemm/pipeline/gemm_pipeline_agmem_bgmem_creg_v2.hpp"
......
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#pragma once
#include "ck_tile/core.hpp"
#include "ck_tile/ops/gemm/warp/warp_gemm_dispatcher.hpp"
namespace ck_tile {
// UniversalGemm Policy
template <typename LayoutA_, typename LayoutB_, typename LayoutC_>
struct UniversalGemmPipelineAgBgCrPolicy
{
using LayoutA = remove_cvref_t<LayoutA_>;
using LayoutB = remove_cvref_t<LayoutB_>;
using LayoutC = remove_cvref_t<LayoutC_>;
static constexpr auto I0 = number<0>{};
static constexpr auto I1 = number<1>{};
static constexpr auto I2 = number<2>{};
static constexpr bool TransposeC = true;
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeALdsBlockDescriptor()
{
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
Problem::BlockGemmShape::WarpTile::at(I0),
Problem::BlockGemmShape::WarpTile::at(I1),
Problem::BlockGemmShape::WarpTile::at(I2),
TransposeC>;
using ADataType = remove_cvref_t<typename Problem::ADataType>;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = WarpGemm::kK;
constexpr index_t K0 = KPerBlock / K1;
if constexpr(std::is_same<tensor_layout::gemm::RowMajor, LayoutA>::value)
{
constexpr auto MLdsLayer = 32 * 4 / KPerBlock / sizeof(ADataType) < 1
? 1
: 32 * 4 / KPerBlock / sizeof(ADataType);
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor(
make_tuple(K0 * number<MLdsLayer>{}, number<MPerBlock / MLdsLayer>{}, K1),
make_tuple(K1, number<KPerBlock * MLdsLayer>{}, I1));
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc,
make_tuple(make_xor_transform(make_tuple(number<MPerBlock / MLdsLayer>{},
number<K0 * MLdsLayer>{})),
make_pass_through_transform(K1)),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
constexpr auto a_lds_block_desc_ak0_kMLdsLayer_m_ak1 = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(make_tuple(K0, number<MLdsLayer>{})),
make_pass_through_transform(number<MPerBlock / MLdsLayer>{}),
make_pass_through_transform(K1)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor(
a_lds_block_desc_ak0_kMLdsLayer_m_ak1,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(K0, K1)),
make_merge_transform_v3_division_mod(
make_tuple(number<MPerBlock / MLdsLayer>{}, number<MLdsLayer>{}))),
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return a_lds_block_desc_m_k;
}
else // ColumnMajor A
{
// kfold and mpair dimension is not always required.
// more dimension in merge_transform increase the difficulty of generating immarg offset
// for compiler.
constexpr auto M0 = get_warp_size() * Problem::BlockGemmShape::BlockWarps::at(I0);
constexpr auto M1 = MPerBlock / M0;
constexpr auto KThreadWrite = Problem::kBlockSize / M0;
constexpr auto K0PerThreadWrite = K0 / KThreadWrite;
constexpr auto KThreadRead = 64 / WarpGemm::kM;
constexpr auto K0PerThreadRead = K0 / KThreadRead;
constexpr auto kfold =
(K1 * M0 * sizeof(ADataType) > 128) ? 1 : 128 / (K1 * M0 * sizeof(ADataType));
constexpr auto KThreadReadPerm =
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
: KThreadRead;
// 1<=mpair<=kN0
constexpr auto mpair = (K1 * WarpGemm::kM * sizeof(ADataType) > 128)
? 1
: ((128 / (K1 * WarpGemm::kM * sizeof(ADataType))) > M0
? M0
: 128 / (K1 * WarpGemm::kM * sizeof(ADataType)));
constexpr auto a_lds_block_desc = make_naive_tensor_descriptor_packed(
make_tuple(number<KThreadWrite / kfold / KThreadReadPerm>{},
number<K0PerThreadWrite>{},
number<KThreadReadPerm * M1>{},
number<kfold * M0 / mpair>{},
number<mpair>{},
K1));
constexpr auto a_lds_block_desc_permuted = transform_tensor_descriptor(
a_lds_block_desc,
make_tuple(
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(number<K0PerThreadWrite>{}),
make_xor_transform(
make_tuple(number<KThreadReadPerm * M1>{}, number<kfold * M0 / mpair>{})),
make_pass_through_transform(number<mpair>{}),
make_pass_through_transform(K1)),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}));
constexpr auto a_lds_block_desc_unmerged = transform_tensor_descriptor(
a_lds_block_desc_permuted,
make_tuple(
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(number<K0PerThreadWrite>{}),
make_unmerge_transform(make_tuple(number<KThreadReadPerm>{}, number<M1>{})),
make_unmerge_transform(make_tuple(number<kfold>{}, number<M0 / mpair>{})),
make_pass_through_transform(number<mpair>{}),
make_pass_through_transform(K1)),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4>{},
sequence<5>{}),
make_tuple(sequence<1>{},
sequence<2>{},
sequence<0, 3>{},
sequence<4, 5>{},
sequence<6>{},
sequence<7>{}));
constexpr auto a_lds_block_desc_m_k = transform_tensor_descriptor(
a_lds_block_desc_unmerged,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<KThreadReadPerm>{},
number<KThreadWrite / kfold / KThreadReadPerm>{},
number<kfold>{},
number<K0PerThreadWrite>{},
K1)),
make_merge_transform_v3_division_mod(
make_tuple(number<M0 / mpair>{}, number<mpair>{}, number<M1>{}))),
make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return a_lds_block_desc_m_k;
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBLdsBlockDescriptor()
{
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
Problem::BlockGemmShape::WarpTile::at(I0),
Problem::BlockGemmShape::WarpTile::at(I1),
Problem::BlockGemmShape::WarpTile::at(I2),
TransposeC>;
using BDataType = remove_cvref_t<typename Problem::BDataType>;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = WarpGemm::kK;
constexpr index_t K0 = KPerBlock / K1;
if constexpr(std::is_same<tensor_layout::gemm::ColumnMajor, LayoutB>::value)
{
// NLdsLayer * K0 as logical Bank
constexpr auto NLdsLayer = 32 * 4 / KPerBlock / sizeof(BDataType) < 1
? 1
: 32 * 4 / KPerBlock / sizeof(BDataType);
;
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor(
make_tuple(K0 * number<NLdsLayer>{}, number<NPerBlock / NLdsLayer>{}, K1),
make_tuple(K1, number<KPerBlock * NLdsLayer>{}, I1));
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
b_lds_block_desc,
make_tuple(make_xor_transform(make_tuple(number<NPerBlock / NLdsLayer>{},
number<K0 * NLdsLayer>{})),
make_pass_through_transform(K1)),
make_tuple(sequence<1, 0>{}, sequence<2>{}),
make_tuple(sequence<1, 0>{}, sequence<2>{}));
constexpr auto b_lds_block_desc_bk0_kNLdsLayer_n_bk1 = transform_tensor_descriptor(
b_lds_block_desc_permuted,
make_tuple(make_unmerge_transform(make_tuple(K0, number<NLdsLayer>{})),
make_pass_through_transform(number<NPerBlock / NLdsLayer>{}),
make_pass_through_transform(K1)),
make_tuple(sequence<0>{}, sequence<1>{}, sequence<2>{}),
make_tuple(sequence<0, 2>{}, sequence<1>{}, sequence<3>{}));
constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor(
b_lds_block_desc_bk0_kNLdsLayer_n_bk1,
make_tuple(make_merge_transform_v3_division_mod(make_tuple(K0, K1)),
make_merge_transform_v3_division_mod(
make_tuple(number<NPerBlock / NLdsLayer>{}, number<NLdsLayer>{}))),
make_tuple(sequence<0, 3>{}, sequence<1, 2>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return b_lds_block_desc_n_k;
}
else // RowMajor B
{
constexpr auto N0 = get_warp_size() * Problem::BlockGemmShape::BlockWarps::at(I1);
constexpr auto N1 = NPerBlock / N0;
constexpr auto KThreadWrite = Problem::kBlockSize / N0;
constexpr auto K0PerThreadWrite = K0 / KThreadWrite;
constexpr auto KThreadRead = 64 / WarpGemm::kN;
constexpr auto K0PerThreadRead = K0 / KThreadRead;
constexpr auto kfold =
(K1 * N0 * sizeof(BDataType) > 128) ? 1 : 128 / (K1 * N0 * sizeof(BDataType));
constexpr auto KThreadReadPerm =
(kfold * K0PerThreadWrite / K0PerThreadRead) > 1
? KThreadRead / (kfold * K0PerThreadWrite / K0PerThreadRead)
: KThreadRead;
// 1<=npair<=kN0
constexpr auto npair = (K1 * WarpGemm::kN * sizeof(BDataType) > 128)
? 1
: ((128 / (K1 * WarpGemm::kN * sizeof(BDataType))) > N0
? N0
: 128 / (K1 * WarpGemm::kN * sizeof(BDataType)));
constexpr auto b_lds_block_desc = make_naive_tensor_descriptor_packed(
make_tuple(number<KThreadWrite / kfold / KThreadReadPerm>{},
number<K0PerThreadWrite>{},
number<KThreadReadPerm * N1>{},
number<kfold * N0 / npair>{},
number<npair>{},
K1));
constexpr auto b_lds_block_desc_permuted = transform_tensor_descriptor(
b_lds_block_desc,
make_tuple(
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(number<K0PerThreadWrite>{}),
make_xor_transform(
make_tuple(number<KThreadReadPerm * N1>{}, number<kfold * N0 / npair>{})),
make_pass_through_transform(number<npair>{}),
make_pass_through_transform(K1)),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}),
make_tuple(
sequence<0>{}, sequence<1>{}, sequence<2, 3>{}, sequence<4>{}, sequence<5>{}));
constexpr auto b_lds_block_desc_unmerged = transform_tensor_descriptor(
b_lds_block_desc_permuted,
make_tuple(
make_pass_through_transform(number<KThreadWrite / kfold / KThreadReadPerm>{}),
make_pass_through_transform(number<K0PerThreadWrite>{}),
make_unmerge_transform(make_tuple(number<KThreadReadPerm>{}, number<N1>{})),
make_unmerge_transform(make_tuple(number<kfold>{}, number<N0 / npair>{})),
make_pass_through_transform(number<npair>{}),
make_pass_through_transform(K1)),
make_tuple(sequence<0>{},
sequence<1>{},
sequence<2>{},
sequence<3>{},
sequence<4>{},
sequence<5>{}),
make_tuple(sequence<1>{},
sequence<2>{},
sequence<0, 3>{},
sequence<4, 5>{},
sequence<6>{},
sequence<7>{}));
constexpr auto b_lds_block_desc_n_k = transform_tensor_descriptor(
b_lds_block_desc_unmerged,
make_tuple(make_merge_transform_v3_division_mod(
make_tuple(number<KThreadReadPerm>{},
number<KThreadWrite / kfold / KThreadReadPerm>{},
number<kfold>{},
number<K0PerThreadWrite>{},
K1)),
make_merge_transform_v3_division_mod(
make_tuple(number<N0 / npair>{}, number<npair>{}, number<N1>{}))),
make_tuple(sequence<0, 1, 4, 2, 7>{}, sequence<5, 6, 3>{}),
make_tuple(sequence<1>{}, sequence<0>{}));
return b_lds_block_desc_n_k;
}
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeA()
{
constexpr index_t smem_size_a = sizeof(typename Problem::ADataType) *
MakeALdsBlockDescriptor<Problem>().get_element_space_size();
return smem_size_a;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSizeB()
{
constexpr index_t smem_size_b = sizeof(typename Problem::BDataType) *
MakeBLdsBlockDescriptor<Problem>().get_element_space_size();
return smem_size_b;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr index_t GetSmemSize()
{
constexpr index_t smem_size_a = GetSmemSizeA<Problem>();
constexpr index_t smem_size_b = GetSmemSizeB<Problem>();
index_t smem_size = 0;
smem_size += smem_size_a + smem_size_b;
return smem_size;
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeADramTileDistribution()
{
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
Problem::BlockGemmShape::WarpTile::at(I0),
Problem::BlockGemmShape::WarpTile::at(I1),
Problem::BlockGemmShape::WarpTile::at(I2),
TransposeC>;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t MPerBlock = Problem::BlockGemmShape::kM;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = WarpGemm::kK;
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t M2 = get_warp_size() / K0;
constexpr index_t M1 = BlockSize / get_warp_size();
static_assert(M2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(M1 != 0, "M1 is zero, which will lead to a division by zero error.");
constexpr index_t M0 = MPerBlock / (M2 * M1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<M0, M1, M2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto MakeBDramTileDistribution()
{
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
Problem::BlockGemmShape::WarpTile::at(I0),
Problem::BlockGemmShape::WarpTile::at(I1),
Problem::BlockGemmShape::WarpTile::at(I2),
TransposeC>;
constexpr index_t BlockSize = Problem::kBlockSize;
constexpr index_t NPerBlock = Problem::BlockGemmShape::kN;
constexpr index_t KPerBlock = Problem::BlockGemmShape::kK;
constexpr index_t K1 = WarpGemm::kK;
constexpr index_t K0 = KPerBlock / K1;
constexpr index_t N2 = get_warp_size() / K0;
constexpr index_t N1 = BlockSize / get_warp_size();
static_assert(N2 != 0, "M2 is zero, which will lead to a division by zero error.");
static_assert(N1 != 0, "M1 is zero, which will lead to a division by zero error.");
constexpr index_t N0 = NPerBlock / (N2 * N1);
return make_static_tile_distribution(
tile_distribution_encoding<sequence<1>,
tuple<sequence<N0, N1, N2>, sequence<K0, K1>>,
tuple<sequence<1>, sequence<1, 2>>,
tuple<sequence<1>, sequence<2, 0>>,
sequence<1, 2>,
sequence<0, 1>>{});
}
template <typename Problem>
CK_TILE_HOST_DEVICE static constexpr auto GetBlockGemm()
{
using AccDataType = float;
using BlockWarps = typename Problem::BlockGemmShape::BlockWarps;
using WarpTile = typename Problem::BlockGemmShape::WarpTile;
using WarpGemm = WarpGemmMfmaDispatcher<typename Problem::ADataType,
typename Problem::BDataType,
AccDataType,
WarpTile::at(I0),
WarpTile::at(I1),
WarpTile::at(I2),
TransposeC>;
using BlockGemmPolicy = BlockGemmASmemBSmemCRegV1CustomPolicy<typename Problem::ADataType,
typename Problem::BDataType,
typename Problem::CDataType,
BlockWarps,
WarpGemm>;
return BlockGemmASmemBSmemCRegV1<Problem, BlockGemmPolicy>{};
}
};
} // namespace ck_tile
...@@ -46,7 +46,7 @@ using device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_instances = std::tuple< ...@@ -46,7 +46,7 @@ using device_gemm_xdl_universal_f8_f16_f16_mk_kn_mn_comp_instances = std::tuple<
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 256, 64, 16, 8, 32, 32, 3, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 192, 256, 64, 16, 8, 32, 32, 3, 4, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 128, 16, 8, 32, 32, 2, 2, S<8, 32, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>, DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v3>,
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, // DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 8, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<8, 32, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 4, 8, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
// We prefer following instance, however, existing compiler bug cause it failed to generate sanity code. // We prefer following instance, however, existing compiler bug cause it failed to generate sanity code.
// DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>, // DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Intrawave, BlockGemmPipelineVersion::v5>,
DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1> DeviceGemm_Xdl_CShuffleV3< Row, Row, Row, F8, F16, F16, F32, F16, PassThrough, PassThrough, PassThrough, GemmSpec, 256, 128, 128, 64, 16, 4, 32, 32, 2, 2, S<4, 64, 1>, S<1, 0, 2>, S<1, 0, 2>, 2, 16, 16, 0, S<16, 16, 1>, S<0, 2, 1>, S<0, 2, 1>, 1, 8, 4, 0, 1, 1, S<1, 32, 1, 8>, 8, BlockGemmPipelineScheduler::Interwave, BlockGemmPipelineVersion::v1>
......
...@@ -18,4 +18,9 @@ if(result EQUAL 0) ...@@ -18,4 +18,9 @@ if(result EQUAL 0)
target_link_libraries(test_bf8 PRIVATE utility) target_link_libraries(test_bf8 PRIVATE utility)
endif() endif()
add_gtest_executable(test_custom_type test_custom_type.cpp)
if(result EQUAL 0)
target_link_libraries(test_custom_type PRIVATE utility)
endif()
add_gtest_executable(test_type_convert_const type_convert_const.cpp) add_gtest_executable(test_type_convert_const type_convert_const.cpp)
// SPDX-License-Identifier: MIT
// Copyright (c) 2024, Advanced Micro Devices, Inc. All rights reserved.
#include "gtest/gtest.h"
#include "ck/utility/data_type.hpp"
#include "ck/utility/type_convert.hpp"
using ck::bf8_t;
using ck::bhalf_t;
using ck::f8_t;
using ck::half_t;
using ck::Number;
using ck::type_convert;
using ck::vector_type;
TEST(Custom_bool, TestSize)
{
struct custom_bool_t
{
bool data;
};
ASSERT_EQ(sizeof(custom_bool_t), sizeof(bool));
ASSERT_EQ(sizeof(vector_type<custom_bool_t, 2>), sizeof(vector_type<bool, 2>));
ASSERT_EQ(sizeof(vector_type<custom_bool_t, 4>), sizeof(vector_type<bool, 4>));
ASSERT_EQ(sizeof(vector_type<custom_bool_t, 8>), sizeof(vector_type<bool, 8>));
ASSERT_EQ(sizeof(vector_type<custom_bool_t, 16>), sizeof(vector_type<bool, 16>));
ASSERT_EQ(sizeof(vector_type<custom_bool_t, 32>), sizeof(vector_type<bool, 32>));
ASSERT_EQ(sizeof(vector_type<custom_bool_t, 64>), sizeof(vector_type<bool, 64>));
}
TEST(Custom_bool, TestAsType)
{
struct custom_bool_t
{
using type = bool;
type data;
custom_bool_t() : data{type{}} {}
custom_bool_t(type init) : data{init} {}
};
// test size
const int size = 4;
std::vector<bool> test_vec = {false, true, false, true};
// reference vector
vector_type<custom_bool_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<custom_bool_t>()(Number<i>{}).data, false);
});
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<custom_bool_t>()(Number<i>{}) = custom_bool_t{test_vec.at(i)};
});
// copy the vector
vector_type<custom_bool_t, size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<custom_bool_t>()(Number<i>{}).data, test_vec.at(i));
});
}
TEST(Custom_bool, TestAsTypeReshape)
{
struct custom_bool_t
{
using type = bool;
type data;
custom_bool_t() : data{type{}} {}
custom_bool_t(type init) : data{init} {}
};
// test size
const int size = 4;
std::vector<bool> test_vec = {false, true, false, true};
// reference vector
vector_type<custom_bool_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<custom_bool_t>()(Number<i>{}).data, false);
});
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<custom_bool_t>()(Number<i>{}) = custom_bool_t{test_vec.at(i)};
});
// copy the first half of a vector
vector_type<custom_bool_t, size / 2> left_vec{
right_vec.template AsType<vector_type<custom_bool_t, size / 2>::type>()(Number<0>{})};
// check if values were copied correctly
ck::static_for<0, size / 2, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<custom_bool_t>()(Number<i>{}).data, test_vec.at(i));
});
}
TEST(Custom_int8, TestSize)
{
struct custom_int8_t
{
int8_t data;
};
ASSERT_EQ(sizeof(custom_int8_t), sizeof(int8_t));
ASSERT_EQ(sizeof(vector_type<custom_int8_t, 2>), sizeof(vector_type<int8_t, 2>));
ASSERT_EQ(sizeof(vector_type<custom_int8_t, 4>), sizeof(vector_type<int8_t, 4>));
ASSERT_EQ(sizeof(vector_type<custom_int8_t, 8>), sizeof(vector_type<int8_t, 8>));
ASSERT_EQ(sizeof(vector_type<custom_int8_t, 16>), sizeof(vector_type<int8_t, 16>));
ASSERT_EQ(sizeof(vector_type<custom_int8_t, 32>), sizeof(vector_type<int8_t, 32>));
ASSERT_EQ(sizeof(vector_type<custom_int8_t, 64>), sizeof(vector_type<int8_t, 64>));
}
TEST(Custom_int8, TestAsType)
{
struct custom_int8_t
{
using type = int8_t;
type data;
custom_int8_t() : data{type{}} {}
custom_int8_t(type init) : data{init} {}
};
// test size
const int size = 4;
std::vector<int8_t> test_vec = {3, -6, 8, -2};
// reference vector
vector_type<custom_int8_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<custom_int8_t>()(Number<i>{}).data, 0);
});
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<custom_int8_t>()(Number<i>{}) = custom_int8_t{test_vec.at(i)};
});
// copy the vector
vector_type<custom_int8_t, size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<custom_int8_t>()(Number<i>{}).data, test_vec.at(i));
});
}
TEST(Custom_int8, TestAsTypeReshape)
{
struct custom_int8_t
{
using type = int8_t;
type data;
custom_int8_t() : data{type{}} {}
custom_int8_t(type init) : data{init} {}
};
// test size
const int size = 4;
std::vector<int8_t> test_vec = {3, -6, 8, -2};
// reference vector
vector_type<custom_int8_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<custom_int8_t>()(Number<i>{}).data, 0);
});
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<custom_int8_t>()(Number<i>{}) = custom_int8_t{test_vec.at(i)};
});
// copy the first half of a vector
vector_type<custom_int8_t, size / 2> left_vec{
right_vec.template AsType<vector_type<custom_int8_t, size / 2>::type>()(Number<0>{})};
// check if values were copied correctly
ck::static_for<0, size / 2, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<custom_int8_t>()(Number<i>{}).data, test_vec.at(i));
});
}
TEST(Custom_uint8, TestSize)
{
struct custom_uint8_t
{
uint8_t data;
};
ASSERT_EQ(sizeof(custom_uint8_t), sizeof(uint8_t));
ASSERT_EQ(sizeof(vector_type<custom_uint8_t, 2>), sizeof(vector_type<uint8_t, 2>));
ASSERT_EQ(sizeof(vector_type<custom_uint8_t, 4>), sizeof(vector_type<uint8_t, 4>));
ASSERT_EQ(sizeof(vector_type<custom_uint8_t, 8>), sizeof(vector_type<uint8_t, 8>));
ASSERT_EQ(sizeof(vector_type<custom_uint8_t, 16>), sizeof(vector_type<uint8_t, 16>));
ASSERT_EQ(sizeof(vector_type<custom_uint8_t, 32>), sizeof(vector_type<uint8_t, 32>));
ASSERT_EQ(sizeof(vector_type<custom_uint8_t, 64>), sizeof(vector_type<uint8_t, 64>));
}
TEST(Custom_uint8, TestAsType)
{
struct custom_uint8_t
{
using type = uint8_t;
type data;
custom_uint8_t() : data{type{}} {}
custom_uint8_t(type init) : data{init} {}
};
// test size
const int size = 4;
std::vector<uint8_t> test_vec = {3, 6, 8, 2};
// reference vector
vector_type<custom_uint8_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<custom_uint8_t>()(Number<i>{}).data, 0);
});
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<custom_uint8_t>()(Number<i>{}) = custom_uint8_t{test_vec.at(i)};
});
// copy the vector
vector_type<custom_uint8_t, size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<custom_uint8_t>()(Number<i>{}).data, test_vec.at(i));
});
}
TEST(Custom_uint8, TestAsTypeReshape)
{
struct custom_uint8_t
{
using type = uint8_t;
type data;
custom_uint8_t() : data{type{}} {}
custom_uint8_t(type init) : data{init} {}
};
// test size
const int size = 4;
std::vector<uint8_t> test_vec = {3, 6, 8, 2};
// reference vector
vector_type<custom_uint8_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<custom_uint8_t>()(Number<i>{}).data, 0);
});
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<custom_uint8_t>()(Number<i>{}) = custom_uint8_t{test_vec.at(i)};
});
// copy the first half of a vector
vector_type<custom_uint8_t, size / 2> left_vec{
right_vec.template AsType<vector_type<custom_uint8_t, size / 2>::type>()(Number<0>{})};
// check if values were copied correctly
ck::static_for<0, size / 2, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<custom_uint8_t>()(Number<i>{}).data, test_vec.at(i));
});
}
TEST(Custom_f8, TestSize)
{
struct custom_f8_t
{
_BitInt(8) data;
};
ASSERT_EQ(sizeof(custom_f8_t), sizeof(_BitInt(8)));
ASSERT_EQ(sizeof(vector_type<custom_f8_t, 2>), sizeof(vector_type<_BitInt(8), 2>));
ASSERT_EQ(sizeof(vector_type<custom_f8_t, 4>), sizeof(vector_type<_BitInt(8), 4>));
ASSERT_EQ(sizeof(vector_type<custom_f8_t, 8>), sizeof(vector_type<_BitInt(8), 8>));
ASSERT_EQ(sizeof(vector_type<custom_f8_t, 16>), sizeof(vector_type<_BitInt(8), 16>));
ASSERT_EQ(sizeof(vector_type<custom_f8_t, 32>), sizeof(vector_type<_BitInt(8), 32>));
ASSERT_EQ(sizeof(vector_type<custom_f8_t, 64>), sizeof(vector_type<_BitInt(8), 64>));
}
TEST(Custom_f8, TestAsType)
{
struct custom_f8_t
{
using type = _BitInt(8);
type data;
custom_f8_t() : data{type{}} {}
custom_f8_t(type init) : data{init} {}
};
// test size
const int size = 4;
std::vector<_BitInt(8)> test_vec = {type_convert<_BitInt(8)>(0.3f),
type_convert<_BitInt(8)>(-0.6f),
type_convert<_BitInt(8)>(0.8f),
type_convert<_BitInt(8)>(-0.2f)};
// reference vector
vector_type<custom_f8_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}(
[&](auto i) { ASSERT_EQ(right_vec.template AsType<custom_f8_t>()(Number<i>{}).data, 0); });
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<custom_f8_t>()(Number<i>{}) = custom_f8_t{test_vec.at(i)};
});
// copy the vector
vector_type<custom_f8_t, size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<custom_f8_t>()(Number<i>{}).data, test_vec.at(i));
});
}
TEST(Custom_f8, TestAsTypeReshape)
{
struct custom_f8_t
{
using type = _BitInt(8);
type data;
custom_f8_t() : data{type{}} {}
custom_f8_t(type init) : data{init} {}
};
// test size
const int size = 4;
std::vector<_BitInt(8)> test_vec = {type_convert<_BitInt(8)>(0.3f),
type_convert<_BitInt(8)>(-0.6f),
type_convert<_BitInt(8)>(0.8f),
type_convert<_BitInt(8)>(-0.2f)};
// reference vector
vector_type<custom_f8_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}(
[&](auto i) { ASSERT_EQ(right_vec.template AsType<custom_f8_t>()(Number<i>{}).data, 0); });
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<custom_f8_t>()(Number<i>{}) = custom_f8_t{test_vec.at(i)};
});
// copy the first half of a vector
vector_type<custom_f8_t, size / 2> left_vec{
right_vec.template AsType<vector_type<custom_f8_t, size / 2>::type>()(Number<0>{})};
// check if values were copied correctly
ck::static_for<0, size / 2, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<custom_f8_t>()(Number<i>{}).data, test_vec.at(i));
});
}
TEST(Custom_bf8, TestSize)
{
struct custom_bf8_t
{
unsigned _BitInt(8) data;
};
ASSERT_EQ(sizeof(custom_bf8_t), sizeof(unsigned _BitInt(8)));
ASSERT_EQ(sizeof(vector_type<custom_bf8_t, 2>), sizeof(vector_type<unsigned _BitInt(8), 2>));
ASSERT_EQ(sizeof(vector_type<custom_bf8_t, 4>), sizeof(vector_type<unsigned _BitInt(8), 4>));
ASSERT_EQ(sizeof(vector_type<custom_bf8_t, 8>), sizeof(vector_type<unsigned _BitInt(8), 8>));
ASSERT_EQ(sizeof(vector_type<custom_bf8_t, 16>), sizeof(vector_type<unsigned _BitInt(8), 16>));
ASSERT_EQ(sizeof(vector_type<custom_bf8_t, 32>), sizeof(vector_type<unsigned _BitInt(8), 32>));
ASSERT_EQ(sizeof(vector_type<custom_bf8_t, 64>), sizeof(vector_type<unsigned _BitInt(8), 64>));
}
TEST(Custom_bf8, TestAsType)
{
struct custom_bf8_t
{
using type = unsigned _BitInt(8);
type data;
custom_bf8_t() : data{type{}} {}
custom_bf8_t(type init) : data{init} {}
};
// test size
const int size = 4;
std::vector<unsigned _BitInt(8)> test_vec = {type_convert<unsigned _BitInt(8)>(0.3f),
type_convert<unsigned _BitInt(8)>(-0.6f),
type_convert<unsigned _BitInt(8)>(0.8f),
type_convert<unsigned _BitInt(8)>(-0.2f)};
// reference vector
vector_type<custom_bf8_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}(
[&](auto i) { ASSERT_EQ(right_vec.template AsType<custom_bf8_t>()(Number<i>{}).data, 0); });
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<custom_bf8_t>()(Number<i>{}) = custom_bf8_t{test_vec.at(i)};
});
// copy the vector
vector_type<custom_bf8_t, size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<custom_bf8_t>()(Number<i>{}).data, test_vec.at(i));
});
}
TEST(Custom_bf8, TestAsTypeReshape)
{
struct custom_bf8_t
{
using type = unsigned _BitInt(8);
type data;
custom_bf8_t() : data{type{}} {}
custom_bf8_t(type init) : data{init} {}
};
// test size
const int size = 4;
std::vector<unsigned _BitInt(8)> test_vec = {type_convert<unsigned _BitInt(8)>(0.3f),
type_convert<unsigned _BitInt(8)>(-0.6f),
type_convert<unsigned _BitInt(8)>(0.8f),
type_convert<unsigned _BitInt(8)>(-0.2f)};
// reference vector
vector_type<custom_bf8_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}(
[&](auto i) { ASSERT_EQ(right_vec.template AsType<custom_bf8_t>()(Number<i>{}).data, 0); });
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<custom_bf8_t>()(Number<i>{}) = custom_bf8_t{test_vec.at(i)};
});
// copy the first half of a vector
vector_type<custom_bf8_t, size / 2> left_vec{
right_vec.template AsType<vector_type<custom_bf8_t, size / 2>::type>()(Number<0>{})};
// check if values were copied correctly
ck::static_for<0, size / 2, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<custom_bf8_t>()(Number<i>{}).data, test_vec.at(i));
});
}
TEST(Custom_half, TestSize)
{
struct custom_half_t
{
half_t data;
};
ASSERT_EQ(sizeof(custom_half_t), sizeof(half_t));
ASSERT_EQ(sizeof(vector_type<custom_half_t, 2>), sizeof(vector_type<half_t, 2>));
ASSERT_EQ(sizeof(vector_type<custom_half_t, 4>), sizeof(vector_type<half_t, 4>));
ASSERT_EQ(sizeof(vector_type<custom_half_t, 8>), sizeof(vector_type<half_t, 8>));
ASSERT_EQ(sizeof(vector_type<custom_half_t, 16>), sizeof(vector_type<half_t, 16>));
ASSERT_EQ(sizeof(vector_type<custom_half_t, 32>), sizeof(vector_type<half_t, 32>));
ASSERT_EQ(sizeof(vector_type<custom_half_t, 64>), sizeof(vector_type<half_t, 64>));
}
TEST(Custom_half, TestAsType)
{
struct custom_half_t
{
using type = half_t;
type data;
custom_half_t() : data{type{}} {}
custom_half_t(type init) : data{init} {}
};
// test size
const int size = 4;
std::vector<half_t> test_vec = {half_t{0.3f}, half_t{-0.6f}, half_t{0.8f}, half_t{-0.2f}};
// reference vector
vector_type<custom_half_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<custom_half_t>()(Number<i>{}).data,
type_convert<half_t>(0.0f));
});
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<custom_half_t>()(Number<i>{}) = custom_half_t{test_vec.at(i)};
});
// copy the vector
vector_type<custom_half_t, size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<custom_half_t>()(Number<i>{}).data, test_vec.at(i));
});
}
TEST(Custom_half, TestAsTypeReshape)
{
struct custom_half_t
{
using type = half_t;
type data;
custom_half_t() : data{type{}} {}
custom_half_t(type init) : data{init} {}
};
// test size
const int size = 4;
std::vector<half_t> test_vec = {half_t{0.3f}, half_t{-0.6f}, half_t{0.8f}, half_t{-0.2f}};
// reference vector
vector_type<custom_half_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<custom_half_t>()(Number<i>{}).data,
type_convert<half_t>(0.0f));
});
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<custom_half_t>()(Number<i>{}) = custom_half_t{test_vec.at(i)};
});
// copy the first half of a vector
vector_type<custom_half_t, size / 2> left_vec{
right_vec.template AsType<vector_type<custom_half_t, size / 2>::type>()(Number<0>{})};
// check if values were copied correctly
ck::static_for<0, size / 2, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<custom_half_t>()(Number<i>{}).data, test_vec.at(i));
});
}
TEST(Custom_bhalf, TestSize)
{
struct custom_bhalf_t
{
bhalf_t data;
};
ASSERT_EQ(sizeof(custom_bhalf_t), sizeof(bhalf_t));
ASSERT_EQ(sizeof(vector_type<custom_bhalf_t, 2>), sizeof(vector_type<bhalf_t, 2>));
ASSERT_EQ(sizeof(vector_type<custom_bhalf_t, 4>), sizeof(vector_type<bhalf_t, 4>));
ASSERT_EQ(sizeof(vector_type<custom_bhalf_t, 8>), sizeof(vector_type<bhalf_t, 8>));
ASSERT_EQ(sizeof(vector_type<custom_bhalf_t, 16>), sizeof(vector_type<bhalf_t, 16>));
ASSERT_EQ(sizeof(vector_type<custom_bhalf_t, 32>), sizeof(vector_type<bhalf_t, 32>));
ASSERT_EQ(sizeof(vector_type<custom_bhalf_t, 64>), sizeof(vector_type<bhalf_t, 64>));
}
TEST(Custom_bhalf, TestAsType)
{
struct custom_bhalf_t
{
using type = bhalf_t;
type data;
custom_bhalf_t() : data{type{}} {}
custom_bhalf_t(type init) : data{init} {}
};
// test size
const int size = 4;
std::vector<bhalf_t> test_vec = {type_convert<bhalf_t>(0.3f),
type_convert<bhalf_t>(-0.6f),
type_convert<bhalf_t>(0.8f),
type_convert<bhalf_t>(-0.2f)};
// reference vector
vector_type<custom_bhalf_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<custom_bhalf_t>()(Number<i>{}).data,
type_convert<bhalf_t>(0.0f));
});
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<custom_bhalf_t>()(Number<i>{}) = custom_bhalf_t{test_vec.at(i)};
});
// copy the vector
vector_type<custom_bhalf_t, size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<custom_bhalf_t>()(Number<i>{}).data, test_vec.at(i));
});
}
TEST(Custom_bhalf, TestAsTypeReshape)
{
struct custom_bhalf_t
{
using type = bhalf_t;
type data;
custom_bhalf_t() : data{type{}} {}
custom_bhalf_t(type init) : data{init} {}
};
// test size
const int size = 4;
std::vector<bhalf_t> test_vec = {type_convert<bhalf_t>(0.3f),
type_convert<bhalf_t>(-0.6f),
type_convert<bhalf_t>(0.8f),
type_convert<bhalf_t>(-0.2f)};
// reference vector
vector_type<custom_bhalf_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<custom_bhalf_t>()(Number<i>{}).data,
type_convert<bhalf_t>(0.0f));
});
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<custom_bhalf_t>()(Number<i>{}) = custom_bhalf_t{test_vec.at(i)};
});
// copy the first half of a vector
vector_type<custom_bhalf_t, size / 2> left_vec{
right_vec.template AsType<vector_type<custom_bhalf_t, size / 2>::type>()(Number<0>{})};
// check if values were copied correctly
ck::static_for<0, size / 2, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<custom_bhalf_t>()(Number<i>{}).data, test_vec.at(i));
});
}
TEST(Custom_float, TestSize)
{
struct custom_float_t
{
float data;
};
ASSERT_EQ(sizeof(custom_float_t), sizeof(float));
ASSERT_EQ(sizeof(vector_type<custom_float_t, 2>), sizeof(vector_type<float, 2>));
ASSERT_EQ(sizeof(vector_type<custom_float_t, 4>), sizeof(vector_type<float, 4>));
ASSERT_EQ(sizeof(vector_type<custom_float_t, 8>), sizeof(vector_type<float, 8>));
ASSERT_EQ(sizeof(vector_type<custom_float_t, 16>), sizeof(vector_type<float, 16>));
ASSERT_EQ(sizeof(vector_type<custom_float_t, 32>), sizeof(vector_type<float, 32>));
ASSERT_EQ(sizeof(vector_type<custom_float_t, 64>), sizeof(vector_type<float, 64>));
}
TEST(Custom_float, TestAsType)
{
struct custom_float_t
{
using type = float;
type data;
custom_float_t() : data{type{}} {}
custom_float_t(type init) : data{init} {}
};
// test size
const int size = 4;
std::vector<float> test_vec = {0.3f, -0.6f, 0.8f, -0.2f};
// reference vector
vector_type<custom_float_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<custom_float_t>()(Number<i>{}).data, 0.0f);
});
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<custom_float_t>()(Number<i>{}) = custom_float_t{test_vec.at(i)};
});
// copy the vector
vector_type<custom_float_t, size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<custom_float_t>()(Number<i>{}).data, test_vec.at(i));
});
}
TEST(Custom_float, TestAsTypeReshape)
{
struct custom_float_t
{
using type = float;
type data;
custom_float_t() : data{type{}} {}
custom_float_t(type init) : data{init} {}
};
// test size
const int size = 4;
std::vector<float> test_vec = {0.3f, -0.6f, 0.8f, -0.2f};
// reference vector
vector_type<custom_float_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<custom_float_t>()(Number<i>{}).data, 0.0f);
});
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<custom_float_t>()(Number<i>{}) = custom_float_t{test_vec.at(i)};
});
// copy the first half of a vector
vector_type<custom_float_t, size / 2> left_vec{
right_vec.template AsType<vector_type<custom_float_t, size / 2>::type>()(Number<0>{})};
// check if values were copied correctly
ck::static_for<0, size / 2, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<custom_float_t>()(Number<i>{}).data, test_vec.at(i));
});
}
TEST(Custom_double, TestSize)
{
struct custom_double_t
{
double data;
};
ASSERT_EQ(sizeof(custom_double_t), sizeof(double));
ASSERT_EQ(sizeof(vector_type<custom_double_t, 2>), sizeof(vector_type<double, 2>));
ASSERT_EQ(sizeof(vector_type<custom_double_t, 4>), sizeof(vector_type<double, 4>));
ASSERT_EQ(sizeof(vector_type<custom_double_t, 8>), sizeof(vector_type<double, 8>));
ASSERT_EQ(sizeof(vector_type<custom_double_t, 16>), sizeof(vector_type<double, 16>));
ASSERT_EQ(sizeof(vector_type<custom_double_t, 32>), sizeof(vector_type<double, 32>));
ASSERT_EQ(sizeof(vector_type<custom_double_t, 64>), sizeof(vector_type<double, 64>));
}
TEST(Custom_double, TestAsType)
{
struct custom_double_t
{
using type = double;
type data;
custom_double_t() : data{type{}} {}
custom_double_t(type init) : data{init} {}
};
// test size
const int size = 4;
std::vector<double> test_vec = {0.3, 0.6, 0.8, 0.2};
// reference vector
vector_type<custom_double_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<custom_double_t>()(Number<i>{}).data, 0.0);
});
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<custom_double_t>()(Number<i>{}) = custom_double_t{test_vec.at(i)};
});
// copy the vector
vector_type<custom_double_t, size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<custom_double_t>()(Number<i>{}).data, test_vec.at(i));
});
}
TEST(Custom_double, TestAsTypeReshape)
{
struct custom_double_t
{
using type = double;
type data;
custom_double_t() : data{type{}} {}
custom_double_t(type init) : data{init} {}
};
// test size
const int size = 4;
std::vector<double> test_vec = {0.3, 0.6, 0.8, 0.2};
// reference vector
vector_type<custom_double_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<custom_double_t>()(Number<i>{}).data, 0.0);
});
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<custom_double_t>()(Number<i>{}) = custom_double_t{test_vec.at(i)};
});
// copy the first half of a vector
vector_type<custom_double_t, size / 2> left_vec{
right_vec.template AsType<vector_type<custom_double_t, size / 2>::type>()(Number<0>{})};
// check if values were copied correctly
ck::static_for<0, size / 2, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<custom_double_t>()(Number<i>{}).data, test_vec.at(i));
});
}
TEST(Complex_half, TestSize)
{
struct complex_half_t
{
half_t real;
half_t img;
};
ASSERT_EQ(sizeof(complex_half_t), sizeof(half_t) + sizeof(half_t));
ASSERT_EQ(sizeof(vector_type<complex_half_t, 2>),
sizeof(vector_type<half_t, 2>) + sizeof(vector_type<half_t, 2>));
ASSERT_EQ(sizeof(vector_type<complex_half_t, 4>),
sizeof(vector_type<half_t, 4>) + sizeof(vector_type<half_t, 4>));
ASSERT_EQ(sizeof(vector_type<complex_half_t, 8>),
sizeof(vector_type<half_t, 8>) + sizeof(vector_type<half_t, 8>));
ASSERT_EQ(sizeof(vector_type<complex_half_t, 16>),
sizeof(vector_type<half_t, 16>) + sizeof(vector_type<half_t, 16>));
ASSERT_EQ(sizeof(vector_type<complex_half_t, 32>),
sizeof(vector_type<half_t, 32>) + sizeof(vector_type<half_t, 32>));
ASSERT_EQ(sizeof(vector_type<complex_half_t, 64>),
sizeof(vector_type<half_t, 64>) + sizeof(vector_type<half_t, 64>));
}
TEST(Complex_half, TestAlignment)
{
struct complex_half_t
{
half_t real;
half_t img;
};
ASSERT_EQ(alignof(vector_type<complex_half_t, 2>),
alignof(vector_type<half_t, 2>) + alignof(vector_type<half_t, 2>));
ASSERT_EQ(alignof(vector_type<complex_half_t, 4>),
alignof(vector_type<half_t, 4>) + alignof(vector_type<half_t, 4>));
ASSERT_EQ(alignof(vector_type<complex_half_t, 8>),
alignof(vector_type<half_t, 8>) + alignof(vector_type<half_t, 8>));
ASSERT_EQ(alignof(vector_type<complex_half_t, 16>),
alignof(vector_type<half_t, 16>) + alignof(vector_type<half_t, 16>));
ASSERT_EQ(alignof(vector_type<complex_half_t, 32>),
alignof(vector_type<half_t, 32>) + alignof(vector_type<half_t, 32>));
ASSERT_EQ(alignof(vector_type<complex_half_t, 64>),
alignof(vector_type<half_t, 64>) + alignof(vector_type<half_t, 64>));
}
TEST(Complex_half, TestAsType)
{
struct complex_half_t
{
using type = half_t;
type real;
type img;
complex_half_t() : real{type{}}, img{type{}} {}
complex_half_t(type real_init, type img_init) : real{real_init}, img{img_init} {}
};
// test size
const int size = 4;
// custom type number of elements
const int num_elem = sizeof(complex_half_t) / sizeof(complex_half_t::type);
std::vector<half_t> test_vec = {half_t{0.3f},
half_t{-0.6f},
half_t{0.8f},
half_t{-0.2f},
half_t{0.5f},
half_t{-0.7f},
half_t{0.9f},
half_t{-0.3f}};
// reference vector
vector_type<complex_half_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<complex_half_t>()(Number<i>{}).real,
type_convert<half_t>(0.0f));
ASSERT_EQ(right_vec.template AsType<complex_half_t>()(Number<i>{}).img,
type_convert<half_t>(0.0f));
});
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<complex_half_t>()(Number<i>{}) =
complex_half_t{test_vec.at(num_elem * i), test_vec.at(num_elem * i + 1)};
});
// copy the vector
vector_type<complex_half_t, size> left_vec{right_vec};
// check if values were copied correctly
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<complex_half_t>()(Number<i>{}).real,
test_vec.at(num_elem * i));
ASSERT_EQ(left_vec.template AsType<complex_half_t>()(Number<i>{}).img,
test_vec.at(num_elem * i + 1));
});
}
TEST(Complex_half, TestAsTypeReshape)
{
struct complex_half_t
{
using type = half_t;
type real;
type img;
complex_half_t() : real{type{}}, img{type{}} {}
complex_half_t(type real_init, type img_init) : real{real_init}, img{img_init} {}
};
// test size
const int size = 4;
// custom type number of elements
const int num_elem = sizeof(complex_half_t) / sizeof(complex_half_t::type);
std::vector<half_t> test_vec = {half_t{0.3f},
half_t{-0.6f},
half_t{0.8f},
half_t{-0.2f},
half_t{0.5f},
half_t{-0.7f},
half_t{0.9f},
half_t{-0.3f}};
// reference vector
vector_type<complex_half_t, size> right_vec;
// check default CTOR
ck::static_for<0, size, 1>{}([&](auto i) {
ASSERT_EQ(right_vec.template AsType<complex_half_t>()(Number<i>{}).real,
type_convert<half_t>(0.0f));
ASSERT_EQ(right_vec.template AsType<complex_half_t>()(Number<i>{}).img,
type_convert<half_t>(0.0f));
});
// assign test values to the vector
ck::static_for<0, size, 1>{}([&](auto i) {
right_vec.template AsType<complex_half_t>()(Number<i>{}) =
complex_half_t{test_vec.at(num_elem * i), test_vec.at(num_elem * i + 1)};
});
// copy the first half of a vector
vector_type<complex_half_t, size / 2> left_vec{
right_vec.template AsType<vector_type<complex_half_t, size / 2>::type>()(Number<0>{})};
// check if values were copied correctly
ck::static_for<0, size / 2, 1>{}([&](auto i) {
ASSERT_EQ(left_vec.template AsType<complex_half_t>()(Number<i>{}).real,
test_vec.at(num_elem * i));
ASSERT_EQ(left_vec.template AsType<complex_half_t>()(Number<i>{}).img,
test_vec.at(num_elem * i + 1));
});
}
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment