#ifndef DEVICE_GEMM_SPLITK_XDL_HPP
#define DEVICE_GEMM_SPLITK_XDL_HPP

#include <iostream>
#include "device.hpp"
#include "device_base.hpp"
#include "device_gemm.hpp"
#include "common_header.hpp"
#include "tensor_layout.hpp"
#include "tensor_descriptor.hpp"
#include "tensor_descriptor_helper.hpp"
#include "gridwise_gemm_xdlops_v2r4.hpp"

namespace ck {
namespace tensor_operation {
namespace device {

template <typename ADataType,
          typename BDataType,
          typename CDataType,
          typename AccDataType,
          typename ALayout,
          typename BLayout,
          typename CLayout,
          index_t BlockSize,
          index_t MPerBlock,
          index_t NPerBlock,
          index_t K0PerBlock,
          index_t K1,
          index_t MPerXDL,
          index_t NPerXDL,
          index_t MXdlPerWave,
          index_t NXdlPerWave,
          typename ABlockTransferThreadSliceLengths_K0_M_K1,
          typename ABlockTransferThreadClusterLengths_K0_M_K1,
          typename ABlockTransferThreadClusterArrangeOrder,
          typename ABlockTransferSrcAccessOrder,
          index_t ABlockTransferSrcVectorDim,
          index_t ABlockTransferSrcScalarPerVector,
          index_t ABlockTransferDstScalarPerVector_K1,
          typename BBlockTransferThreadSliceLengths_K0_N_K1,
          typename BBlockTransferThreadClusterLengths_K0_N_K1,
          typename BBlockTransferThreadClusterArrangeOrder,
          typename BBlockTransferSrcAccessOrder,
          index_t BBlockTransferSrcVectorDim,
          index_t BBlockTransferSrcScalarPerVector,
          index_t BBlockTransferDstScalarPerVector_K1,
          index_t CThreadTransferSrcDstVectorDim,
          index_t CThreadTransferDstScalarPerVector,
          bool ABlockLdsAddExtraM,
          bool BBlockLdsAddExtraN,
          index_t DesiredGridSize>
struct DeviceGemmSplitKXdl : public DeviceGemm
{
    static constexpr auto I0 = Number<0>{};
    static constexpr auto I1 = Number<1>{};
    static constexpr auto I2 = Number<2>{};
    static constexpr auto I3 = Number<3>{};

    static constexpr auto K1Number = Number<K1>{};

    static auto MakeAGridDescriptor_KBatch_K0_M_K1(
        index_t M, index_t K, index_t StrideA, int KBatch, int KPad)
    {
        assert(KPad % (K1 * KBatch) == 0);

        const index_t K0 = KPad / (K1 * KBatch);

        const auto a_grid_desc_m_k = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, ALayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(StrideA, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, ALayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(M, K), make_tuple(I1, StrideA));
            }
        }();

        // pad K up to KPad so it splits evenly into KBatch * K0 * K1
        const auto a_grid_desc_m_kpad = transform_tensor_descriptor(
            a_grid_desc_m_k,
            make_tuple(make_pass_through_transform(M), make_right_pad_transform(K, KPad - K)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        const auto a_grid_desc_kbatch_k0_m_k1 = transform_tensor_descriptor(
            a_grid_desc_m_kpad,
            make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
                       make_pass_through_transform(M)),
            make_tuple(Sequence<1>{}, Sequence<0>{}),
            make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

        return a_grid_desc_kbatch_k0_m_k1;
    }

    static auto MakeBGridDescriptor_KBatch_K0_N_K1(
        index_t K, index_t N, index_t StrideB, int KBatch, int KPad)
    {
        assert(KPad % (K1 * KBatch) == 0);

        const index_t K0 = KPad / (K1 * KBatch);

        const auto b_grid_desc_k_n = [&]() {
            if constexpr(is_same<tensor_layout::gemm::RowMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(StrideB, I1));
            }
            else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, BLayout>::value)
            {
                return make_naive_tensor_descriptor(make_tuple(K, N), make_tuple(I1, StrideB));
            }
        }();

        // pad K up to KPad so it splits evenly into KBatch * K0 * K1
        const auto b_grid_desc_kpad_n = transform_tensor_descriptor(
            b_grid_desc_k_n,
            make_tuple(make_right_pad_transform(K, KPad - K), make_pass_through_transform(N)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0>{}, Sequence<1>{}));

        const auto b_grid_desc_kbatch_k0_n_k1 = transform_tensor_descriptor(
            b_grid_desc_kpad_n,
            make_tuple(make_unmerge_transform(make_tuple(KBatch, K0, K1Number)),
                       make_pass_through_transform(N)),
            make_tuple(Sequence<0>{}, Sequence<1>{}),
            make_tuple(Sequence<0, 1, 3>{}, Sequence<2>{}));

        return b_grid_desc_kbatch_k0_n_k1;
    }

    static auto MakeCGridDescriptor_M_N(index_t M, index_t N, index_t StrideC)
    {
        if constexpr(is_same<tensor_layout::gemm::RowMajor, CLayout>::value)
        {
            return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(StrideC, I1));
        }
        else if constexpr(is_same<tensor_layout::gemm::ColumnMajor, CLayout>::value)
        {
            return make_naive_tensor_descriptor(make_tuple(M, N), make_tuple(I1, StrideC));
        }
    }

    // Choose KBatch so the grid reaches roughly DesiredGridSize blocks, then pad K so it
    // divides evenly into KBatch * K0 * K1
    static auto GetKBatchAndKPad(index_t M, index_t N, index_t K)
    {
        const auto GridMN = M * N / (MPerBlock * NPerBlock);

        const index_t KBatch = std::max(DesiredGridSize / GridMN, 1);

        const index_t K0 = math::integer_divide_ceil(K, K1 * K0PerBlock * KBatch) * K0PerBlock;

        const index_t KPad = KBatch * K0 * K1;

        return std::make_tuple(KBatch, KPad);
    }
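    // Worked example of the split-K sizing above (illustrative only; the tile sizes and
    // DesiredGridSize below are assumed values, not ones shipped with this header):
    //   MPerBlock = NPerBlock = 128, K0PerBlock = 4, K1 = 8, DesiredGridSize = 240,
    //   M = N = 1024, K = 1000
    //     GridMN = 1024 * 1024 / (128 * 128)          = 64
    //     KBatch = max(240 / 64, 1)                   = 3
    //     K0     = ceil(1000 / (8 * 4 * 3)) * 4       = 11 * 4 = 44
    //     KPad   = 3 * 44 * 8                         = 1056
    //   i.e. each of the 3 K-slices covers K0 * K1 = 352 (padded) elements of K.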
    using AGridDesc_K0_M_K1 = decltype(MakeAGridDescriptor_KBatch_K0_M_K1(1, 1, 1, 1, 1));
    using BGridDesc_K0_N_K1 = decltype(MakeBGridDescriptor_KBatch_K0_N_K1(1, 1, 1, 1, 1));
    using CGridDesc_M_N     = decltype(MakeCGridDescriptor_M_N(1, 1, 1));

    // TODO remove these hacks
    static constexpr auto a_kbatch_k0_m_k1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},   // 0+: Kbatch
                              Sequence<0, 0, 0, 0, 0>{},   // 1+: K0
                              Sequence<0, 0, 0, 0, 0>{},   // 2+: M
                              Sequence<0, 0, 0, 0, 0>{}),  // 3+: K1
                   make_tuple(Sequence<0, 0, 0, 0, 0>{},   // 0-: Kbatch
                              Sequence<0, 0, 0, 0, 0>{},   // 1-: K0
                              Sequence<0, 0, 0, 0, 0>{},   // 2-: M
                              Sequence<0, 0, 0, 0, 0>{})); // 3-: K1

    static constexpr auto b_kbatch_k0_n_k1_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0>{},   // 0+: Kbatch
                              Sequence<0, 0, 0, 0, 0>{},   // 1+: K0
                              Sequence<0, 0, 0, 0, 0>{},   // 2+: N
                              Sequence<0, 0, 0, 0, 0>{}),  // 3+: K1
                   make_tuple(Sequence<0, 0, 0, 0, 0>{},   // 0-: Kbatch
                              Sequence<0, 0, 0, 0, 0>{},   // 1-: K0
                              Sequence<0, 0, 0, 0, 0>{},   // 2-: N
                              Sequence<0, 0, 0, 0, 0>{})); // 3-: K1

    static constexpr auto c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks =
        make_tuple(make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0+: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1+: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2+: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3+: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4+: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5+: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6+: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{}),  // 7+: N2
                   make_tuple(Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 0-: M0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 1-: N0
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 2-: M1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 3-: N1
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 4-: M2
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 5-: M3
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{},   // 6-: M4
                              Sequence<0, 0, 0, 0, 0, 0, 0, 0, 0>{})); // 7-: N2

    static constexpr auto a_kbatch_k0_m_k1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0>{};

    static constexpr auto b_kbatch_k0_n_k1_grid_move_slice_window_step_hacks =
        Sequence<0, 0, 0, 0, 0>{};

    // GridwiseGemm that writes C with Set; used when KBatch == 1
    using GridwiseGemm = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4<
        BlockSize,
        ADataType, // TODO: distinguish A/B datatype
        AccDataType,
        CDataType,
        InMemoryDataOperationEnum_t::Set,
        AGridDesc_K0_M_K1,
        BGridDesc_K0_N_K1,
        CGridDesc_M_N,
        MPerBlock,
        NPerBlock,
        K0PerBlock,
        MPerXDL,
        NPerXDL,
        K1,
        MXdlPerWave,
        NXdlPerWave,
        ABlockTransferThreadSliceLengths_K0_M_K1,
        ABlockTransferThreadClusterLengths_K0_M_K1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_K1,
        false, // AThreadTransferSrcResetCoordinateAfterRun
        BBlockTransferThreadSliceLengths_K0_N_K1,
        BBlockTransferThreadClusterLengths_K0_N_K1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_K1,
        false,                            // BThreadTransferSrcResetCoordinateAfterRun
        Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder
        CThreadTransferSrcDstVectorDim,
        CThreadTransferDstScalarPerVector,
        decltype(a_kbatch_k0_m_k1_grid_step_hacks),                   // AGridStepHacks
        decltype(b_kbatch_k0_n_k1_grid_step_hacks),                   // BGridStepHacks
        decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),          // CGridStepHacks
        decltype(a_kbatch_k0_m_k1_grid_move_slice_window_step_hacks), // AGridMoveSliceWindowStepHacks
        decltype(b_kbatch_k0_n_k1_grid_move_slice_window_step_hacks), // BGridMoveSliceWindowStepHacks
        false, // CAccessOrderMRepeatNRepeat
        ABlockLdsAddExtraM,
        BBlockLdsAddExtraN>;
    // Same gridwise GEMM, but accumulating partial K-slice results with AtomicAdd; used when
    // KBatch > 1
    using GridwiseGemmAtomicAdd = GridwiseGemm_bk0mk1_bk0nk1_mn_xdlops_v2r4<
        BlockSize,
        ADataType, // TODO: distinguish A/B datatype
        AccDataType,
        CDataType,
        InMemoryDataOperationEnum_t::AtomicAdd,
        AGridDesc_K0_M_K1,
        BGridDesc_K0_N_K1,
        CGridDesc_M_N,
        MPerBlock,
        NPerBlock,
        K0PerBlock,
        MPerXDL,
        NPerXDL,
        K1,
        MXdlPerWave,
        NXdlPerWave,
        ABlockTransferThreadSliceLengths_K0_M_K1,
        ABlockTransferThreadClusterLengths_K0_M_K1,
        ABlockTransferThreadClusterArrangeOrder,
        ABlockTransferSrcAccessOrder,
        ABlockTransferSrcVectorDim,
        ABlockTransferSrcScalarPerVector,
        ABlockTransferDstScalarPerVector_K1,
        false, // AThreadTransferSrcResetCoordinateAfterRun
        BBlockTransferThreadSliceLengths_K0_N_K1,
        BBlockTransferThreadClusterLengths_K0_N_K1,
        BBlockTransferThreadClusterArrangeOrder,
        BBlockTransferSrcAccessOrder,
        BBlockTransferSrcVectorDim,
        BBlockTransferSrcScalarPerVector,
        BBlockTransferDstScalarPerVector_K1,
        false,                            // BThreadTransferSrcResetCoordinateAfterRun
        Sequence<0, 2, 4, 5, 6, 1, 3, 7>, // CThreadTransferSrcDstAccessOrder
        CThreadTransferSrcDstVectorDim,
        CThreadTransferDstScalarPerVector,
        decltype(a_kbatch_k0_m_k1_grid_step_hacks),                   // AGridStepHacks
        decltype(b_kbatch_k0_n_k1_grid_step_hacks),                   // BGridStepHacks
        decltype(c_m0_n0_m1_n1_m2_m3_m4_n2_grid_step_hacks),          // CGridStepHacks
        decltype(a_kbatch_k0_m_k1_grid_move_slice_window_step_hacks), // AGridMoveSliceWindowStepHacks
        decltype(b_kbatch_k0_n_k1_grid_move_slice_window_step_hacks), // BGridMoveSliceWindowStepHacks
        false, // CAccessOrderMRepeatNRepeat
        ABlockLdsAddExtraM,
        BBlockLdsAddExtraN>;

    using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
        decltype(GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(CGridDesc_M_N{}));

    using Block2CTileMap =
        decltype(GridwiseGemm::MakeCBlockClusterAdaptor(CGridDesc_M_N{}, 1, 1, 1));

    // Argument
    struct Argument : public BaseArgument
    {
        Argument(const ADataType* p_a_grid,
                 const BDataType* p_b_grid,
                 CDataType* p_c_grid,
                 index_t M,
                 index_t N,
                 index_t K,
                 index_t StrideA,
                 index_t StrideB,
                 index_t StrideC,
                 index_t M01,
                 index_t N01)
            : p_a_grid_{p_a_grid},
              p_b_grid_{p_b_grid},
              p_c_grid_{p_c_grid},
              a_grid_desc_kbatch_k0_m_k1_{},
              b_grid_desc_kbatch_k0_n_k1_{},
              c_grid_desc_m_n_{},
              c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
              block_2_ctile_map_{},
              M01_{M01},
              N01_{N01}
        {
            int KBatch = 1, KPad = K;

            std::tie(KBatch, KPad) = DeviceGemmSplitKXdl::GetKBatchAndKPad(M, N, K);

            a_grid_desc_kbatch_k0_m_k1_ = DeviceGemmSplitKXdl::MakeAGridDescriptor_KBatch_K0_M_K1(
                M, K, StrideA, KBatch, KPad);
            b_grid_desc_kbatch_k0_n_k1_ = DeviceGemmSplitKXdl::MakeBGridDescriptor_KBatch_K0_N_K1(
                K, N, StrideB, KBatch, KPad);
            c_grid_desc_m_n_ = DeviceGemmSplitKXdl::MakeCGridDescriptor_M_N(M, N, StrideC);

            if(GridwiseGemm::CheckValidity(a_grid_desc_kbatch_k0_m_k1_,
                                           b_grid_desc_kbatch_k0_n_k1_,
                                           c_grid_desc_m_n_,
                                           M01_,
                                           N01_))
            {
                c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
                    GridwiseGemm::MakeCM0N0M1N1M2M3M4N2GridDescriptor(c_grid_desc_m_n_);

                block_2_ctile_map_ =
                    GridwiseGemm::MakeCBlockClusterAdaptor(c_grid_desc_m_n_, M01, N01, KBatch);
            }
        }

        // private:
        const ADataType* p_a_grid_;
        const BDataType* p_b_grid_;
        CDataType* p_c_grid_;
        AGridDesc_K0_M_K1 a_grid_desc_kbatch_k0_m_k1_;
        BGridDesc_K0_N_K1 b_grid_desc_kbatch_k0_n_k1_;
        CGridDesc_M_N c_grid_desc_m_n_;
        CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
        Block2CTileMap block_2_ctile_map_;
        index_t M01_;
        index_t N01_;
    };

    // Invoker
    struct Invoker : public BaseInvoker
    {
        using Argument = DeviceGemmSplitKXdl::Argument;

        float Run(const Argument& arg, int nrepeat = 1)
        {
            const auto kbatch = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0);

            {
"arg.a_grid_desc_kbatch_k0_m_k1_{" << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I0) << ", " << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1) << ", " << ", " << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I2) << ", " << arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I3) << "}" << std::endl; std::cout << "arg.b_grid_desc_kbatch_k0_n_k1_{" << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I0) << ", " << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I1) << ", " << ", " << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I2) << ", " << arg.b_grid_desc_kbatch_k0_n_k1_.GetLength(I3) << "}" << std::endl; std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; } if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, arg.b_grid_desc_kbatch_k0_n_k1_, arg.c_grid_desc_m_n_, arg.M01_, arg.N01_)) { throw std::runtime_error( "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v2r3 has invalid setting"); } const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_, kbatch); const auto K0 = arg.a_grid_desc_kbatch_k0_m_k1_.GetLength(I1); const bool has_main_k0_block_loop = GridwiseGemm::CalculateHasMainK0BlockLoop(K0); float ave_time = 0; const auto Run = [&](const auto& kernel) { ave_time = launch_and_time_kernel(kernel, nrepeat, dim3(grid_size), dim3(BlockSize), 0, arg.p_a_grid_, arg.p_b_grid_, arg.p_c_grid_, arg.a_grid_desc_kbatch_k0_m_k1_, arg.b_grid_desc_kbatch_k0_n_k1_, arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, arg.block_2_ctile_map_); hipGetErrorString( hipMemset(arg.p_c_grid_, 0, arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_.GetElementSpaceSize() * sizeof(CDataType))); launch_kernel(kernel, dim3(grid_size), dim3(BlockSize), 0, arg.p_a_grid_, arg.p_b_grid_, arg.p_c_grid_, arg.a_grid_desc_kbatch_k0_m_k1_, arg.b_grid_desc_kbatch_k0_n_k1_, arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_, arg.block_2_ctile_map_); }; if(has_main_k0_block_loop) { if(kbatch == 1) { const auto kernel = kernel_gemm_xdlops_v2r4< GridwiseGemm, ADataType, // TODO: distiguish A/B datatype CDataType, remove_reference_t, remove_reference_t, remove_reference_t, remove_reference_t, true>; Run(kernel); } else { const auto kernel = kernel_gemm_xdlops_v2r4< GridwiseGemmAtomicAdd, ADataType, // TODO: distiguish A/B datatype CDataType, remove_reference_t, remove_reference_t, remove_reference_t, remove_reference_t, true>; Run(kernel); } } else { if(kbatch == 1) { const auto kernel = kernel_gemm_xdlops_v2r4< GridwiseGemm, ADataType, // TODO: distiguish A/B datatype CDataType, remove_reference_t, remove_reference_t, remove_reference_t, remove_reference_t, false>; Run(kernel); } else { const auto kernel = kernel_gemm_xdlops_v2r4< GridwiseGemmAtomicAdd, ADataType, // TODO: distiguish A/B datatype CDataType, remove_reference_t, remove_reference_t, remove_reference_t, remove_reference_t, false>; Run(kernel); } } return ave_time; } // polymorphic float Run(const BaseArgument* p_arg, int nrepeat = 1) override { return Run(*dynamic_cast(p_arg), nrepeat); } }; static constexpr bool IsValidCompilationParameter() { // TODO: properly implement this check return true; } static bool IsSupportedArgument(const Argument& arg) { return GridwiseGemm::CheckValidity(arg.a_grid_desc_kbatch_k0_m_k1_, arg.b_grid_desc_kbatch_k0_n_k1_, arg.c_grid_desc_m_n_, arg.M01_, arg.N01_); } // polymorphic bool IsSupportedArgument(const BaseArgument* p_arg) override { return IsSupportedArgument(*dynamic_cast(p_arg)); } static auto MakeArgument(const ADataType* p_a, const BDataType* p_b, CDataType* p_c, 
                             index_t M,
                             index_t N,
                             index_t K,
                             index_t StrideA,
                             index_t StrideB,
                             index_t StrideC)
    {
        return Argument{p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC, 1, 1};
    }

    static auto MakeInvoker() { return Invoker{}; }

    // polymorphic
    std::unique_ptr<BaseArgument> MakeArgumentPointer(const void* p_a,
                                                      const void* p_b,
                                                      void* p_c,
                                                      index_t M,
                                                      index_t N,
                                                      index_t K,
                                                      index_t StrideA,
                                                      index_t StrideB,
                                                      index_t StrideC) override
    {
        return std::make_unique<Argument>(static_cast<const ADataType*>(p_a),
                                          static_cast<const BDataType*>(p_b),
                                          static_cast<CDataType*>(p_c),
                                          M,
                                          N,
                                          K,
                                          StrideA,
                                          StrideB,
                                          StrideC,
                                          1,
                                          1);
    }

    // polymorphic
    std::unique_ptr<BaseInvoker> MakeInvokerPointer() override
    {
        return std::make_unique<Invoker>(Invoker{});
    }
};

} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
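// Minimal usage sketch (illustrative only; "DeviceOp" stands for some full instantiation of
// DeviceGemmSplitKXdl with concrete data types, layouts, and tile/transfer parameters, and the
// pointers and sizes are assumed to describe valid device buffers):
//
//   using DeviceOp =
//       ck::tensor_operation::device::DeviceGemmSplitKXdl</* ...concrete parameters... */>;
//
//   auto gemm    = DeviceOp{};
//   auto arg     = gemm.MakeArgument(p_a, p_b, p_c, M, N, K, StrideA, StrideB, StrideC);
//   auto invoker = gemm.MakeInvoker();
//
//   if(DeviceOp::IsSupportedArgument(arg))
//   {
//       // returns the average kernel time over nrepeat timed launches
//       float ave_time = invoker.Run(arg, /* nrepeat = */ 10);
//   }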