Commit 2faeaece authored by j4yan's avatar j4yan
Browse files

navi_gemm_km_kn_mn_fp32 compiles and passes one test.

parent fd7eee0d
...@@ -3,7 +3,7 @@ ...@@ -3,7 +3,7 @@
#include "common_header.hpp" #include "common_header.hpp"
#include "tensor_adaptor.hpp" #include "tensor_adaptor.hpp"
#include "threadwise_tensor_slice_transfer_v2.hpp" #include "threadwise_tensor_slice_transfer_v4r1.hpp"
#include "threadwise_contraction_dlops.hpp" #include "threadwise_contraction_dlops.hpp"
namespace ck { namespace ck {
......
...@@ -33,7 +33,8 @@ template < ...@@ -33,7 +33,8 @@ template <
index_t BlockSize, index_t BlockSize,
index_t MPerBlock, index_t MPerBlock,
index_t NPerBlock, index_t NPerBlock,
index_t KPerBlock, index_t K0PerBlock,
index_t K1,
index_t M1PerThread, index_t M1PerThread,
index_t N1PerThread, index_t N1PerThread,
index_t KPerThread, index_t KPerThread,
...@@ -56,17 +57,13 @@ template < ...@@ -56,17 +57,13 @@ template <
typename CThreadTransferSrcDstAccessOrder, typename CThreadTransferSrcDstAccessOrder,
index_t CThreadTransferSrcDstVectorDim, index_t CThreadTransferSrcDstVectorDim,
index_t CThreadTransferDstScalarPerVector, index_t CThreadTransferDstScalarPerVector,
typename AGridStepHacks,
typename BGridStepHacks,
typename CGridStepHacks,
typename AGridMoveSliceWindowStepHacks,
typename BGridMoveSliceWindowStepHacks,
enable_if_t< enable_if_t<
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> && is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> && is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough> &&
is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>, is_same_v<AElementwiseOperation, ck::tensor_operation::element_wise::PassThrough>,
bool> = false> bool> = false>
struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation> struct DeviceGemmDlops
: public DeviceGemm<AElementwiseOperation, BElementwiseOperation, CElementwiseOperation>
{ {
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -201,12 +198,12 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp ...@@ -201,12 +198,12 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
AccDataType, AccDataType,
CDataType, CDataType,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AK0MK1GridDesc, AGridDesc_K0_M_K1,
BK0NK1GridDesc, BGridDesc_K0_N_K1,
CMNGridDesc, CGridDesc_M_N,
MPerBlock, MPerBlock,
NPerBlock, NPerBlock,
KPerBlock, K0PerBlock,
M1PerThread, M1PerThread,
N1PerThread, N1PerThread,
KPerThread, KPerThread,
...@@ -228,18 +225,16 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp ...@@ -228,18 +225,16 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1, BBlockTransferDstVectorTensorLengths_K0_N0_N1_K1,
CThreadTransferSrcDstAccessOrder, CThreadTransferSrcDstAccessOrder,
CThreadTransferSrcDstVectorDim, CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector>;
AGridStepHacks,
BGridStepHacks, using AGridDesc_K0_M0_M1_K1 =
CGridStepHacks,
AGridMoveSliceWindowStepHacks,
BGridMoveSliceWindowStepHacks>;
using AK0M0M1K1GridDesc =
decltype(GridwiseGemm::MakeAK0M0M1K1GridDescriptor(AGridDesc_K0_M_K1{})); decltype(GridwiseGemm::MakeAK0M0M1K1GridDescriptor(AGridDesc_K0_M_K1{}));
using BK0N0N1K1GridDesc = decltype(GridwiseGemm::MakeBKN0N1GridDescriptor(BGridDesc_K0_N_K1{})); using BGridDesc_K0_N0_N1_K1 =
using CM0M10M11N0N10N11GridDesc = decltype(GridwiseGemm::MakeBK0N0N1K1GridDescriptor(BGridDesc_K0_N_K1{}));
using CGridDesc_M0_M10_M11_N0_N10_N11 =
decltype(GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(CGridDesc_M_N{})); decltype(GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(CGridDesc_M_N{}));
using DefaultBlock2CTileMap =
decltype(GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(CGridDesc_M_N{}));
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
...@@ -261,10 +256,9 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp ...@@ -261,10 +256,9 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
: p_a_grid_{p_a_grid}, : p_a_grid_{p_a_grid},
p_b_grid_{p_b_grid}, p_b_grid_{p_b_grid},
p_c_grid_{p_c_grid}, p_c_grid_{p_c_grid},
a_grid_desc_k0_m_k1_{}, a_grid_desc_k0_m0_m1_k1_{},
b_grid_desc_k0_n_k1_{}, b_grid_desc_k0_n0_n1_k1_{},
c_grid_desc_m_n_{}, c_grid_desc_m0_m10_m11_n0_n10_n11_{},
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_{},
block_2_ctile_map_{}, block_2_ctile_map_{},
M01_{M01}, M01_{M01},
N01_{N01} N01_{N01}
...@@ -272,15 +266,19 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp ...@@ -272,15 +266,19 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
// b_element_op_{b_element_op}, // b_element_op_{b_element_op},
// c_element_op_{c_element_op} // c_element_op_{c_element_op}
{ {
a_grid_desc_k0_m_k1_ = DeviceGemmXdl::MakeAGridDescriptor_K0_M_K1(M, K, StrideA); a_grid_desc_k0_m_k1_ = DeviceGemmDlops::MakeAGridDescriptor_K0_M_K1(M, K, StrideA);
b_grid_desc_k0_n_k1_ = DeviceGemmXdl::MakeBGridDescriptor_K0_N_K1(K, N, StrideB); b_grid_desc_k0_n_k1_ = DeviceGemmDlops::MakeBGridDescriptor_K0_N_K1(K, N, StrideB);
c_grid_desc_m_n_ = DeviceGemmXdl::MakeCGridDescriptor_M_N(M, N, StrideC); c_grid_desc_m_n_ = DeviceGemmDlops::MakeCGridDescriptor_M_N(M, N, StrideC);
if(GridwiseGemm::CheckValidity( if(GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_, M01_, N01_)) a_grid_desc_k0_m_k1_, b_grid_desc_k0_n_k1_, c_grid_desc_m_n_))
{ {
c_m0_m10_m11_n0_n10_n11_grid_desc = a_grid_desc_k0_m0_m1_k1_ =
GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_m_n_grid_desc); GridwiseGemm::MakeAK0M0M1K1GridDescriptor(a_grid_desc_k0_m_k1_);
b_grid_desc_k0_n0_n1_k1_ =
GridwiseGemm::MakeBK0N0N1K1GridDescriptor(b_grid_desc_k0_n_k1_);
c_grid_desc_m0_m10_m11_n0_n10_n11_ =
GridwiseGemm::MakeCM0M10M11N0N10N11GridDescriptor(c_grid_desc_m_n_);
block_2_ctile_map_ = block_2_ctile_map_ =
GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_grid_desc_m_n_); GridwiseGemm::MakeCBlockIdToM0N0BlockClusterAdaptor(c_grid_desc_m_n_);
...@@ -292,11 +290,15 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp ...@@ -292,11 +290,15 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
CDataType* p_c_grid_; CDataType* p_c_grid_;
AK0M0M1K1GridDesc a_k0_m0_m1_k1_grid_desc; AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BK0N0N1K1GridDesc b_k0_n0_n1_k1_grid_desc; BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CM0M10M11N0N10N11GridDesc c_m0_m10_m11_n0_n10_n11_grid_desc; CGridDesc_M_N c_grid_desc_m_n_;
AGridDesc_K0_M0_M1_K1 a_grid_desc_k0_m0_m1_k1_;
BGridDesc_K0_N0_N1_K1 b_grid_desc_k0_n0_n1_k1_;
CGridDesc_M0_M10_M11_N0_N10_N11 c_grid_desc_m0_m10_m11_n0_n10_n11_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_; index_t M01_;
index_t N01_; index_t N01_;
...@@ -309,36 +311,35 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp ...@@ -309,36 +311,35 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
// Invoker // Invoker
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
{ {
using Argument = DeviceGemmXdl::Argument; using Argument = DeviceGemmDlops::Argument;
float Run(const Argument& arg, int nrepeat = 1) float Run(const Argument& arg, int nrepeat = 1)
{ {
{ {
std::cout << "arg.a_grid_desc_k0_m_k1_{" << arg.a_grid_desc_k0_m_k1_.GetLength(I0) std::cout << "arg.a_grid_desc_k0_m0_m1_k1_{"
<< ", " << arg.a_grid_desc_k0_m_k1_.GetLength(I1) << ", " << arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0) << ", "
<< arg.a_grid_desc_k0_m_k1_.GetLength(I2) << "}" << std::endl; << arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_{" << arg.b_grid_desc_k0_n_k1_.GetLength(I0) std::cout << "arg.b_grid_desc_k0_n0_n1_k1_{"
<< ", " << arg.b_grid_desc_k0_n_k1_.GetLength(I1) << ", " << arg.b_grid_desc_k0_n0_n1_k1_.GetLength(I0) << ", "
<< arg.b_grid_desc_k0_n_k1_.GetLength(I2) << "}" << std::endl; << arg.b_grid_desc_k0_n0_n1_k1_.GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n0_n1_k1_.GetLength(I2) << "}" << std::endl;
std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", " std::cout << "arg.c_grid_desc_m_n_{ " << arg.c_grid_desc_m_n_.GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl; << arg.c_grid_desc_m_n_.GetLength(I1) << "}" << std::endl;
} }
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, if(!GridwiseGemm::CheckValidity(
arg.b_grid_desc_k0_n_k1_, arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_))
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_))
{ {
throw std::runtime_error( throw std::runtime_error(
"wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting"); "wrong! GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 has invalid setting");
} }
const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_); const index_t grid_size = GridwiseGemm::CalculateGridSize(arg.c_grid_desc_m_n_.GetLength(I0), arg.c_grid_desc_m_n_.GetLength(I1));
const auto K0 = arg.a_grid_desc_k0_m_k1_.GetLength(I0); const auto K0 = arg.a_grid_desc_k0_m0_m1_k1_.GetLength(I0);
const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0); const bool has_main_k_block_loop = GridwiseGemm::CalculateHasMainKBlockLoop(K0);
const bool has_double_tail_k_block_loop = const bool has_double_tail_k_block_loop =
GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0); GridwiseGemm::CalculateHasDoubleTailKBlockLoop(K0);
...@@ -351,10 +352,10 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp ...@@ -351,10 +352,10 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
kernel_gemm_dlops_v1r3<GridwiseGemm, kernel_gemm_dlops_v1r3<GridwiseGemm,
ADataType, ADataType,
CDataType, CDataType,
remove_reference_t<AK0M0M1K1GridDesc>, remove_reference_t<AGridDesc_K0_M0_M1_K1>,
remove_reference_t<BK0N0N1K1GridDesc>, remove_reference_t<BGridDesc_K0_N0_N1_K1>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>, remove_reference_t<DefaultBlock2CTileMap>,
true, true,
true>; true>;
...@@ -369,7 +370,7 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp ...@@ -369,7 +370,7 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
arg.a_grid_desc_k0_m0_m1_k1_, arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_, arg.b_grid_desc_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.cblockid_to_m0_n0_block_cluster_adaptor_); arg.block_2_ctile_map_);
} }
else if(has_main_k_block_loop && !has_double_tail_k_block_loop) else if(has_main_k_block_loop && !has_double_tail_k_block_loop)
{ {
...@@ -377,10 +378,10 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp ...@@ -377,10 +378,10 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
kernel_gemm_dlops_v1r3<GridwiseGemm, kernel_gemm_dlops_v1r3<GridwiseGemm,
ADataType, ADataType,
CDataType, CDataType,
remove_reference_t<AK0M0M1K1GridDesc>, remove_reference_t<AGridDesc_K0_M0_M1_K1>,
remove_reference_t<BK0N0N1K1GridDesc>, remove_reference_t<BGridDesc_K0_N0_N1_K1>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>, remove_reference_t<DefaultBlock2CTileMap>,
true, true,
false>; false>;
...@@ -395,7 +396,7 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp ...@@ -395,7 +396,7 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
arg.a_grid_desc_k0_m0_m1_k1_, arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_, arg.b_grid_desc_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.cblockid_to_m0_n0_block_cluster_adaptor_); arg.block_2_ctile_map_);
} }
else if(!has_main_k_block_loop && has_double_tail_k_block_loop) else if(!has_main_k_block_loop && has_double_tail_k_block_loop)
{ {
...@@ -403,10 +404,10 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp ...@@ -403,10 +404,10 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
kernel_gemm_dlops_v1r3<GridwiseGemm, kernel_gemm_dlops_v1r3<GridwiseGemm,
ADataType, ADataType,
CDataType, CDataType,
remove_reference_t<AK0M0M1K1GridDesc>, remove_reference_t<AGridDesc_K0_M0_M1_K1>,
remove_reference_t<BK0N0N1K1GridDesc>, remove_reference_t<BGridDesc_K0_N0_N1_K1>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>, remove_reference_t<DefaultBlock2CTileMap>,
false, false,
true>; true>;
...@@ -421,7 +422,7 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp ...@@ -421,7 +422,7 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
arg.a_grid_desc_k0_m0_m1_k1_, arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_, arg.b_grid_desc_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.cblockid_to_m0_n0_block_cluster_adaptor_); arg.block_2_ctile_map_);
} }
else else
{ {
...@@ -429,10 +430,10 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp ...@@ -429,10 +430,10 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
kernel_gemm_dlops_v1r3<GridwiseGemm, kernel_gemm_dlops_v1r3<GridwiseGemm,
ADataType, ADataType,
CDataType, CDataType,
remove_reference_t<AK0M0M1K1GridDesc>, remove_reference_t<AGridDesc_K0_M0_M1_K1>,
remove_reference_t<BK0N0N1K1GridDesc>, remove_reference_t<BGridDesc_K0_N0_N1_K1>,
remove_reference_t<CM0M10M11N0N10N11GridDesc>, remove_reference_t<CGridDesc_M0_M10_M11_N0_N10_N11>,
remove_reference_t<CBlockIdToM0N0BlockClusterAdaptor>, remove_reference_t<DefaultBlock2CTileMap>,
false, false,
false>; false>;
...@@ -447,7 +448,7 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp ...@@ -447,7 +448,7 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
arg.a_grid_desc_k0_m0_m1_k1_, arg.a_grid_desc_k0_m0_m1_k1_,
arg.b_grid_desc_k0_n0_n1_k1_, arg.b_grid_desc_k0_n0_n1_k1_,
arg.c_grid_desc_m0_m10_m11_n0_n10_n11_, arg.c_grid_desc_m0_m10_m11_n0_n10_n11_,
arg.cblockid_to_m0_n0_block_cluster_adaptor_); arg.block_2_ctile_map_);
} }
return ave_time; return ave_time;
...@@ -468,11 +469,8 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp ...@@ -468,11 +469,8 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
return GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_, return GridwiseGemm::CheckValidity(
arg.b_grid_desc_k0_n_k1_, arg.a_grid_desc_k0_m_k1_, arg.b_grid_desc_k0_n_k1_, arg.c_grid_desc_m_n_);
arg.c_grid_desc_m_n_,
arg.M01_,
arg.N01_);
} }
// polymorphic // polymorphic
...@@ -555,17 +553,16 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp ...@@ -555,17 +553,16 @@ struct DeviceGemmDlops : public DeviceGemm<AElementwiseOperation, BElementwiseOp
auto str = std::stringstream(); auto str = std::stringstream();
// clang-format off // clang-format off
str << "DeviceGemmXdl" str << "DeviceGemmDlops"
<< "<" << "<"
<< BlockSize << ", " << BlockSize << ", "
<< MPerBlock << ", " << MPerBlock << ", "
<< NPerBlock << ", " << NPerBlock << ", "
<< K0PerBlock << ", " << K0PerBlock << ", "
<< K1 << ", " << K1 << ", "
<< MPerXDL << ", " << M1PerThread << ", "
<< NPerXDL << ", " << N1PerThread << ", "
<< MXdlPerWave << ", " << KPerThread
<< NXdlPerWave
<< ">"; << ">";
// clang-format on // clang-format on
......
...@@ -7,8 +7,9 @@ ...@@ -7,8 +7,9 @@
#include "tensor_descriptor_helper.hpp" #include "tensor_descriptor_helper.hpp"
#include "blockwise_gemm_dlops_v2r3.hpp" #include "blockwise_gemm_dlops_v2r3.hpp"
#include "blockwise_tensor_slice_transfer_v5r1.hpp" #include "blockwise_tensor_slice_transfer_v5r1.hpp"
#include "threadwise_tensor_slice_transfer_v2.hpp" #include "threadwise_tensor_slice_transfer.hpp"
#include "threadwise_tensor_slice_set.hpp" #include "threadwise_tensor_slice_set.hpp"
#include "element_wise_operation.hpp"
namespace ck { namespace ck {
...@@ -327,7 +328,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -327,7 +328,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
ABlockTransferThreadClusterArrangeOrder, ABlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
FloatAB, FloatAB,
decltype(a_k0_m0_m1_k1_grid_desc), remove_reference_t<decltype(a_k0_m0_m1_k1_grid_desc)>,
decltype(a_k0_m0_m1_k1_block_desc), decltype(a_k0_m0_m1_k1_block_desc),
ABlockTransferSrcAccessOrder, ABlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3>, Sequence<0, 1, 2, 3>,
...@@ -351,7 +352,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -351,7 +352,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
BBlockTransferThreadClusterArrangeOrder, BBlockTransferThreadClusterArrangeOrder,
FloatAB, FloatAB,
FloatAB, FloatAB,
decltype(b_k0_n0_n1_k1_grid_desc), remove_reference_t<decltype(b_k0_n0_n1_k1_grid_desc)>,
decltype(b_k0_n0_n1_k1_block_desc), decltype(b_k0_n0_n1_k1_block_desc),
BBlockTransferSrcAccessOrder, BBlockTransferSrcAccessOrder,
Sequence<0, 1, 2, 3>, Sequence<0, 1, 2, 3>,
...@@ -498,10 +499,8 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -498,10 +499,8 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
// LDS double buffer: tail // LDS double buffer: tail
if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left if constexpr(HasDoubleTailKBlockLoop) // if has 2 iteration left
{ {
a_blockwise_copy.MoveSrcSliceWindow( a_blockwise_copy.MoveSrcSliceWindow(a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step);
a_k0_m0_m1_k1_grid_desc, a_block_slice_copy_step); b_blockwise_copy.MoveSrcSliceWindow(b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step);
b_blockwise_copy.MoveSrcSliceWindow(
b_k0_n0_n1_k1_grid_desc, b_block_slice_copy_step);
__syncthreads(); __syncthreads();
...@@ -552,6 +551,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -552,6 +551,7 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
FloatC, FloatC,
decltype(c_m0_m10_m11_n0_n10_n11_thread_desc), decltype(c_m0_m10_m11_n0_n10_n11_thread_desc),
decltype(c_m0_m10_m11_n0_n10_n11_grid_desc), decltype(c_m0_m10_m11_n0_n10_n11_grid_desc),
ck::tensor_operation::element_wise::PassThrough,
Sequence<1, Sequence<1,
c_m10_m11_n10_n11_thread_tensor_lengths[I0], c_m10_m11_n10_n11_thread_tensor_lengths[I0],
c_m10_m11_n10_n11_thread_tensor_lengths[I1], c_m10_m11_n10_n11_thread_tensor_lengths[I1],
...@@ -569,7 +569,8 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3 ...@@ -569,7 +569,8 @@ struct GridwiseGemmDlops_km_kn_mn_v1r3
c_m10_m11_n10_n11_thread_origin_idx_on_block[I1], c_m10_m11_n10_n11_thread_origin_idx_on_block[I1],
in0, in0,
c_m10_m11_n10_n11_thread_origin_idx_on_block[I2], c_m10_m11_n10_n11_thread_origin_idx_on_block[I2],
c_m10_m11_n10_n11_thread_origin_idx_on_block[I3])} c_m10_m11_n10_n11_thread_origin_idx_on_block[I3]),
ck::tensor_operation::element_wise::PassThrough{}}
.Run(c_m0_m10_m11_n0_n10_n11_thread_desc, .Run(c_m0_m10_m11_n0_n10_n11_thread_desc,
make_tuple(I0, I0, I0, I0, I0, I0), make_tuple(I0, I0, I0, I0, I0, I0),
c_thread_buf, c_thread_buf,
......
#ifndef CK_THREADWISE_TENSOR_SLICE_TRANSFER_V5R1_HPP #pragma once
#define CK_THREADWISE_TENSOR_SLICE_TRANSFER_V5R1_HPP
#include "common_header.hpp" #include "common_header.hpp"
#include "tensor_descriptor.hpp" #include "tensor_descriptor.hpp"
...@@ -609,4 +608,3 @@ struct ThreadwiseTensorSliceTransfer_v5r1 ...@@ -609,4 +608,3 @@ struct ThreadwiseTensorSliceTransfer_v5r1
}; };
} // namespace ck } // namespace ck
#endif
...@@ -45,3 +45,12 @@ set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE ...@@ -45,3 +45,12 @@ set_target_properties(device_gemm_instance PROPERTIES POSITION_INDEPENDENT_CODE
install(TARGETS device_gemm_instance LIBRARY DESTINATION lib) install(TARGETS device_gemm_instance LIBRARY DESTINATION lib)
clang_tidy_check(device_gemm_instance) clang_tidy_check(device_gemm_instance)
add_library(device_gemm_dlops_instance SHARED device_gemm_dlops_f32_f32_f32_km_kn_mn_instance.cpp)
target_compile_features(device_gemm_dlops_instance PUBLIC)
set_target_properties(device_gemm_dlops_instance PROPERTIES POSITION_INDEPENDENT_CODE ON)
install(TARGETS device_gemm_dlops_instance LIBRARY DESTINATION lib)
clang_tidy_check(device_gemm_dlops_instance)
...@@ -23,16 +23,52 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough; ...@@ -23,16 +23,52 @@ using PassThrough = ck::tensor_operation::element_wise::PassThrough;
static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default; static constexpr auto GemmDefault = ck::tensor_operation::device::GemmSpecialization::Default;
// Compilation parameters for a[k, m] * b[k, n] = c[m, n] // Compilation parameters for a[k, m] * b[k, n] = c[m, n]
using device_gemm_dlops_f32_f32_f32_km_kn_mn_instances = using device_gemm_dlops_f32_f32_f32_km_kn_mn_instances = std::tuple<
std::tuple<
// clang-format off // clang-format off
// ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer| // ##########| AData| BData| CData| AccData| ALayout| BLayout| CLayout| A| B| C| GEMM| Block| MPer| NPer| K0Per| K1| M1Per| N1Per| KPer| M11N11Thread| M11N11Thread| ABlockTransfer| ABlockTransfer| ABlockTransfer| ABlockTransfer|
// ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess| // ##########| Type| Type| Type| Type| | | | Elementwise| Elementwise| Elementwise|Spacialization| Size| Block| Block| Block| | ThreadM111| ThreadN111| Thread| ClusterM110Xs| ClusterN110Xs| ThreadSliceLengths| ThreadClusterLengths| ThreadCluster| SrcAccess|
// ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order| // ##########| | | | | | | | Operation| Operation| Operation| | | | | | | | | | | | K0_M0_M1_K1| K0_M0_M1_K1| ArrangeOrder| Order|
// ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | // ##########| | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | | |
DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0 ,3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4> // DeviceGemmDlops< F32, F32, F32, F32, Col, Row, Row, PassThrough, PassThrough, PassThrough, GemmDefault, 256, 128, 128, 8, 2, 4, 4, 1, S<8, 2>, S<8, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0 ,3>, S<4, 1, 1, 2>, S<1, 2, 0, 3>, S<1, 1, 1, 2>, S<4, 1, 1, 2>, S<2, 1, 128, 1>, S<1, 2, 0, 3>, S<1, 2, 0, 3>, S<4, 1, 1, 2>, S<1, 2, 0 3>, S<1, 1, 1, 2>, S<0, 1, 2, 3, 4, 5>, 5, 4>
// clang-format on // clang-format on
>; DeviceGemmDlops<F32,
F32,
F32,
F32,
Col,
Row,
Row,
PassThrough,
PassThrough,
PassThrough,
GemmDefault,
256,
128,
128,
8,
2,
4,
4,
1,
S<8, 2>,
S<8, 2>,
S<4, 1, 1, 2>,
S<2, 1, 128, 1>,
S<1, 2, 0, 3>,
S<1, 2, 0, 3>,
S<4, 1, 1, 2>,
S<1, 2, 0, 3>,
S<1, 1, 1, 2>,
S<4, 1, 1, 2>,
S<2, 1, 128, 1>,
S<1, 2, 0, 3>,
S<1, 2, 0, 3>,
S<4, 1, 1, 2>,
S<1, 2, 0, 3>,
S<1, 1, 1, 2>,
S<0, 1, 2, 3, 4, 5>,
5,
4>>;
void add_device_gemm_dlops_f32_f32_f32_km_kn_mn_instances( void add_device_gemm_dlops_f32_f32_f32_km_kn_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
......
...@@ -37,6 +37,7 @@ add_subdirectory(space_filling_curve) ...@@ -37,6 +37,7 @@ add_subdirectory(space_filling_curve)
add_subdirectory(conv_util) add_subdirectory(conv_util)
add_subdirectory(reference_conv_fwd) add_subdirectory(reference_conv_fwd)
add_subdirectory(gemm) add_subdirectory(gemm)
add_subdirectory(gemm_dlops)
add_subdirectory(gemm_split_k) add_subdirectory(gemm_split_k)
add_subdirectory(gemm_reduce) add_subdirectory(gemm_reduce)
add_subdirectory(batched_gemm) add_subdirectory(batched_gemm)
......
add_test_executable(test_gemm_dlops_fp32 gemm_fp32.cpp) add_test_executable(test_gemm_dlops_fp32 gemm_dlops_fp32.cpp)
target_link_libraries(test_gemm_dlops_fp32 PRIVATE host_tensor) target_link_libraries(test_gemm_dlops_fp32 PRIVATE host_tensor)
target_link_libraries(test_gemm_dlops_fp32 PRIVATE device_gemm_dlops_instance) target_link_libraries(test_gemm_dlops_fp32 PRIVATE device_gemm_dlops_instance)
......
...@@ -6,7 +6,7 @@ ...@@ -6,7 +6,7 @@
#include <tuple> #include <tuple>
#include <vector> #include <vector>
#include "gemm_util.hpp" #include "../gemm/gemm_util.hpp"
#include "config.hpp" #include "config.hpp"
#include "print.hpp" #include "print.hpp"
#include "device.hpp" #include "device.hpp"
...@@ -15,7 +15,6 @@ ...@@ -15,7 +15,6 @@
#include "host_gemm.hpp" #include "host_gemm.hpp"
#include "device_tensor.hpp" #include "device_tensor.hpp"
#include "device_gemm_xdl.hpp" #include "device_gemm_xdl.hpp"
#include "device_gemm_dlops_c_shuffle.hpp"
#include "element_wise_operation.hpp" #include "element_wise_operation.hpp"
#include "reference_gemm.hpp" #include "reference_gemm.hpp"
#include "gemm_specialization.hpp" #include "gemm_specialization.hpp"
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment