Commit 328ab6f3 authored by Jianfeng yan's avatar Jianfeng yan
Browse files

removed A/B/CGridDesc from DeviceOps that use gridwise_gemm_v2r3 and gridwise_gemm_cshuffle

parent e739c577
...@@ -222,9 +222,6 @@ struct DeviceBatchedGemmXdl ...@@ -222,9 +222,6 @@ struct DeviceBatchedGemmXdl
AccDataType, AccDataType,
CDataType, CDataType,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
...@@ -373,7 +370,7 @@ struct DeviceBatchedGemmXdl ...@@ -373,7 +370,7 @@ struct DeviceBatchedGemmXdl
CDataType, CDataType,
remove_reference_t<DeviceBatchedGemmXdl::AGridDesc_K0_M_K1>, remove_reference_t<DeviceBatchedGemmXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceBatchedGemmXdl::BGridDesc_K0_N_K1>, remove_reference_t<DeviceBatchedGemmXdl::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
...@@ -407,7 +404,7 @@ struct DeviceBatchedGemmXdl ...@@ -407,7 +404,7 @@ struct DeviceBatchedGemmXdl
CDataType, CDataType,
remove_reference_t<DeviceBatchedGemmXdl::AGridDesc_K0_M_K1>, remove_reference_t<DeviceBatchedGemmXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceBatchedGemmXdl::BGridDesc_K0_N_K1>, remove_reference_t<DeviceBatchedGemmXdl::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
......
...@@ -403,6 +403,9 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -403,6 +403,9 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
7, // CThreadTransferSrcDstVectorDim, 7, // CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>; CThreadTransferDstScalarPerVector>;
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -492,7 +495,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -492,7 +495,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2])); GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2]));
block_2_ctile_map_container_.push_back( block_2_ctile_map_container_.push_back(
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01, N01)); GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2].GetLength(I0), descs[I2].GetLength(I1), M01, N01));
} }
} }
} }
...@@ -504,7 +507,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -504,7 +507,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std::vector<AGridDesc_K0_M_K1> a_grid_desc_k0_m_k1_container_; std::vector<AGridDesc_K0_M_K1> a_grid_desc_k0_m_k1_container_;
std::vector<BGridDesc_K0_N_K1> b_grid_desc_k0_n_k1_container_; std::vector<BGridDesc_K0_N_K1> b_grid_desc_k0_n_k1_container_;
std::vector<CGridDesc_M_N> c_grid_desc_m_n_container_; std::vector<CGridDesc_M_N> c_grid_desc_m_n_container_;
std::vector<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2> std::vector<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_; c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_;
std::vector<typename GridwiseGemm::DefaultBlock2CTileMap> block_2_ctile_map_container_; std::vector<typename GridwiseGemm::DefaultBlock2CTileMap> block_2_ctile_map_container_;
index_t M01_; index_t M01_;
...@@ -594,8 +597,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -594,8 +597,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
CDataType, CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>, remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t< remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
OutElementwiseOperation, OutElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
InElementwiseOperation, InElementwiseOperation,
...@@ -627,8 +629,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -627,8 +629,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
CDataType, CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>, remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t< remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
OutElementwiseOperation, OutElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
InElementwiseOperation, InElementwiseOperation,
......
...@@ -317,9 +317,6 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -317,9 +317,6 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
AccDataType, AccDataType,
CDataType, CDataType,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
InElementwiseOperation, InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation, OutElementwiseOperation,
...@@ -351,6 +348,9 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -351,6 +348,9 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
7, // CThreadTransferSrcDstVectorDim, 7, // CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>; CThreadTransferDstScalarPerVector>;
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -416,7 +416,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -416,7 +416,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
block_2_ctile_map_ = block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_.GetLength(I0), c_grid_desc_m_n_.GetLength(I1), M01, N01);
} }
} }
...@@ -427,7 +427,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -427,7 +427,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_; index_t M01_;
...@@ -492,7 +492,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -492,7 +492,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
CDataType, CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>, remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
InElementwiseOperation, InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation, OutElementwiseOperation,
...@@ -523,7 +523,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -523,7 +523,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
CDataType, CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>, remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
InElementwiseOperation, InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation, OutElementwiseOperation,
......
...@@ -309,9 +309,6 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_ ...@@ -309,9 +309,6 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
AccDataType, AccDataType,
OutDataType, OutDataType,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
InElementwiseOperation, InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation, OutElementwiseOperation,
......
...@@ -960,9 +960,6 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho ...@@ -960,9 +960,6 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
AccDataType, AccDataType,
CDataType, CDataType,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
InElementwiseOperation, InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation, OutElementwiseOperation,
...@@ -994,6 +991,9 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho ...@@ -994,6 +991,9 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
7, // CThreadTransferSrcDstVectorDim, 7, // CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>; CThreadTransferDstScalarPerVector>;
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -1079,7 +1079,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho ...@@ -1079,7 +1079,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2])); GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2]));
block_2_ctile_map_container_.push_back( block_2_ctile_map_container_.push_back(
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_)); GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2].GetLength(I0), descs[I2].GetLength(I1), M01_, N01_));
} }
} }
} }
...@@ -1135,7 +1135,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho ...@@ -1135,7 +1135,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2])); GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2]));
block_2_ctile_map_container_.push_back( block_2_ctile_map_container_.push_back(
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_)); GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2].GetLength(I0), descs[I2].GetLength(I1), M01_, N01_));
} }
} }
} }
...@@ -1201,7 +1201,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho ...@@ -1201,7 +1201,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
descs[I2])); descs[I2]));
block_2_ctile_map_container_.push_back( block_2_ctile_map_container_.push_back(
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_)); GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2].GetLength(I0), descs[I2].GetLength(I1), M01_, N01_));
} }
} }
} }
...@@ -1214,7 +1214,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho ...@@ -1214,7 +1214,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
std::vector<AGridDesc_K0_M_K1> a_grid_desc_k0_m_k1_container_; std::vector<AGridDesc_K0_M_K1> a_grid_desc_k0_m_k1_container_;
std::vector<BGridDesc_K0_N_K1> b_grid_desc_k0_n_k1_container_; std::vector<BGridDesc_K0_N_K1> b_grid_desc_k0_n_k1_container_;
std::vector<CGridDesc_M_N> c_grid_desc_m_n_container_; std::vector<CGridDesc_M_N> c_grid_desc_m_n_container_;
std::vector<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2> std::vector<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_; c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_;
std::vector<typename GridwiseGemm::DefaultBlock2CTileMap> block_2_ctile_map_container_; std::vector<typename GridwiseGemm::DefaultBlock2CTileMap> block_2_ctile_map_container_;
index_t M01_; index_t M01_;
...@@ -1308,8 +1308,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho ...@@ -1308,8 +1308,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
CDataType, CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>, remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t< remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
OutElementwiseOperation, OutElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
InElementwiseOperation, InElementwiseOperation,
...@@ -1341,8 +1340,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho ...@@ -1341,8 +1340,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
CDataType, CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>, remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t< remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
OutElementwiseOperation, OutElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
InElementwiseOperation, InElementwiseOperation,
......
...@@ -614,9 +614,6 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -614,9 +614,6 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
AccDataType, AccDataType,
CDataType, CDataType,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
InElementwiseOperation, InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation, OutElementwiseOperation,
...@@ -648,6 +645,9 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -648,6 +645,9 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
7, // CThreadTransferSrcDstVectorDim, 7, // CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>; CThreadTransferDstScalarPerVector>;
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -713,7 +713,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -713,7 +713,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
block_2_ctile_map_ = block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_.GetLength(I0), c_grid_desc_m_n_.GetLength(I1), M01, N01);
} }
} }
...@@ -724,8 +724,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -724,8 +724,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_; index_t M01_;
index_t N01_; index_t N01_;
...@@ -789,7 +788,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -789,7 +788,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
CDataType, CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>, remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
InElementwiseOperation, InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation, OutElementwiseOperation,
...@@ -820,7 +819,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -820,7 +819,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
CDataType, CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>, remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>, remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
InElementwiseOperation, InElementwiseOperation,
WeiElementwiseOperation, WeiElementwiseOperation,
OutElementwiseOperation, OutElementwiseOperation,
......
...@@ -187,9 +187,9 @@ struct DeviceGemmXdl ...@@ -187,9 +187,9 @@ struct DeviceGemmXdl
AccDataType, AccDataType,
CDataType, CDataType,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1, // AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1, // BGridDesc_K0_N_K1,
CGridDesc_M_N, // CGridDesc_M_N,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
...@@ -222,6 +222,11 @@ struct DeviceGemmXdl ...@@ -222,6 +222,11 @@ struct DeviceGemmXdl
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
NumPrefetch>; NumPrefetch>;
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
using Block2CTileMap = decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(1, 1, 1, 1));
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -264,7 +269,7 @@ struct DeviceGemmXdl ...@@ -264,7 +269,7 @@ struct DeviceGemmXdl
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
block_2_ctile_map_ = block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01); GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_.GetLength(I0), c_grid_desc_m_n_.GetLength(I1), M01, N01);
} }
} }
...@@ -275,9 +280,8 @@ struct DeviceGemmXdl ...@@ -275,9 +280,8 @@ struct DeviceGemmXdl
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_; Block2CTileMap block_2_ctile_map_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_; index_t M01_;
index_t N01_; index_t N01_;
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
...@@ -331,7 +335,7 @@ struct DeviceGemmXdl ...@@ -331,7 +335,7 @@ struct DeviceGemmXdl
CDataType, CDataType,
remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>, remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>, remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
...@@ -362,7 +366,7 @@ struct DeviceGemmXdl ...@@ -362,7 +366,7 @@ struct DeviceGemmXdl
CDataType, CDataType,
remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>, remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>, remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>, CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
......
...@@ -342,9 +342,6 @@ struct DeviceGemm_Xdl_CShuffle ...@@ -342,9 +342,6 @@ struct DeviceGemm_Xdl_CShuffle
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
CGridDesc_M_N,
NumGemmKPrefetchStage, NumGemmKPrefetchStage,
BlockSize, BlockSize,
MPerBlock, MPerBlock,
...@@ -377,6 +374,9 @@ struct DeviceGemm_Xdl_CShuffle ...@@ -377,6 +374,9 @@ struct DeviceGemm_Xdl_CShuffle
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock>; CShuffleBlockTransferScalarPerVector_NPerBlock>;
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = decltype(
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -422,7 +422,7 @@ struct DeviceGemm_Xdl_CShuffle ...@@ -422,7 +422,7 @@ struct DeviceGemm_Xdl_CShuffle
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_; c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
AElementwiseOperation a_element_op_; AElementwiseOperation a_element_op_;
...@@ -470,18 +470,18 @@ struct DeviceGemm_Xdl_CShuffle ...@@ -470,18 +470,18 @@ struct DeviceGemm_Xdl_CShuffle
if(has_main_k0_block_loop) if(has_main_k0_block_loop)
{ {
const auto kernel = kernel_gemm_xdl_cshuffle_v1< const auto kernel =
GridwiseGemm, kernel_gemm_xdl_cshuffle_v1<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap, typename GridwiseGemm::DefaultBlock2CTileMap,
true>; true>;
if(nrepeat == 0) if(nrepeat == 0)
{ {
...@@ -522,18 +522,18 @@ struct DeviceGemm_Xdl_CShuffle ...@@ -522,18 +522,18 @@ struct DeviceGemm_Xdl_CShuffle
} }
else else
{ {
const auto kernel = kernel_gemm_xdl_cshuffle_v1< const auto kernel =
GridwiseGemm, kernel_gemm_xdl_cshuffle_v1<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype ADataType, // TODO: distiguish A/B datatype
CDataType, CDataType,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap, typename GridwiseGemm::DefaultBlock2CTileMap,
false>; false>;
if(nrepeat == 0) if(nrepeat == 0)
{ {
......
...@@ -419,6 +419,8 @@ struct DeviceGemmXdlSplitKCShuffle ...@@ -419,6 +419,8 @@ struct DeviceGemmXdlSplitKCShuffle
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock, CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock>; CShuffleBlockTransferScalarPerVector_NPerBlock>;
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = decltype(
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
using Block2CTileMap = decltype(BatchedGemmUtil::MakeBlock2CTileMap<MPerBlock, NPerBlock>(1, 1, 1)); using Block2CTileMap = decltype(BatchedGemmUtil::MakeBlock2CTileMap<MPerBlock, NPerBlock>(1, 1, 1));
struct Argument : public BaseArgument struct Argument : public BaseArgument
...@@ -505,7 +507,7 @@ struct DeviceGemmXdlSplitKCShuffle ...@@ -505,7 +507,7 @@ struct DeviceGemmXdlSplitKCShuffle
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_; AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_; BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_; c_grid_desc_mblock_mperblock_nblock_nperblock_;
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_; ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
Block2CTileMap block_2_ctile_map_; Block2CTileMap block_2_ctile_map_;
...@@ -564,7 +566,7 @@ struct DeviceGemmXdlSplitKCShuffle ...@@ -564,7 +566,7 @@ struct DeviceGemmXdlSplitKCShuffle
CElementwiseOperation, CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
ComputePtrOffsetOfStridedBatch, ComputePtrOffsetOfStridedBatch,
Block2CTileMap, Block2CTileMap,
true>; true>;
...@@ -621,7 +623,7 @@ struct DeviceGemmXdlSplitKCShuffle ...@@ -621,7 +623,7 @@ struct DeviceGemmXdlSplitKCShuffle
CElementwiseOperation, CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1, DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1, DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock, CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
ComputePtrOffsetOfStridedBatch, ComputePtrOffsetOfStridedBatch,
Block2CTileMap, Block2CTileMap,
false>; false>;
......
...@@ -188,9 +188,6 @@ struct DeviceGroupedGemmXdl ...@@ -188,9 +188,6 @@ struct DeviceGroupedGemmXdl
AccDataType, AccDataType,
CDataType, CDataType,
InMemoryDataOperationEnum::Set, InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
AElementwiseOperation, AElementwiseOperation,
BElementwiseOperation, BElementwiseOperation,
CElementwiseOperation, CElementwiseOperation,
...@@ -223,11 +220,14 @@ struct DeviceGroupedGemmXdl ...@@ -223,11 +220,14 @@ struct DeviceGroupedGemmXdl
CThreadTransferDstScalarPerVector, CThreadTransferDstScalarPerVector,
NumPrefetch>; NumPrefetch>;
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
struct GroupedGemmBlock2CTileMap struct GroupedGemmBlock2CTileMap
{ {
GroupedGemmBlock2CTileMap() GroupedGemmBlock2CTileMap()
{ {
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1); block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(1, 1, 1, 1);
BlockStart_ = -1; BlockStart_ = -1;
} }
...@@ -236,7 +236,7 @@ struct DeviceGroupedGemmXdl ...@@ -236,7 +236,7 @@ struct DeviceGroupedGemmXdl
index_t N01, index_t N01,
ck::index_t BlockStart) ck::index_t BlockStart)
{ {
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, M01, N01); block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01, N01);
BlockStart_ = BlockStart; BlockStart_ = BlockStart;
} }
...@@ -258,8 +258,7 @@ struct DeviceGroupedGemmXdl ...@@ -258,8 +258,7 @@ struct DeviceGroupedGemmXdl
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
GroupedGemmBlock2CTileMap grouped_gemm_block_2_ctile_map_; GroupedGemmBlock2CTileMap grouped_gemm_block_2_ctile_map_;
......
...@@ -150,9 +150,6 @@ template <typename FloatAB, ...@@ -150,9 +150,6 @@ template <typename FloatAB,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation, typename CElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation, InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDesc_M_N,
index_t NumGemmKPrefetchStage, index_t NumGemmKPrefetchStage,
index_t BlockSize, index_t BlockSize,
index_t MPerBlock, index_t MPerBlock,
...@@ -261,6 +258,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -261,6 +258,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
} }
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01} // block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename AGridDesc_AK0_M_AK1, typename BGridDesc_BK0_N_BK1, typename CGridDesc_M_N>
__host__ __device__ static constexpr bool __host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1, CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1, const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
...@@ -307,9 +305,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -307,9 +305,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return true; return true;
} }
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr index_t __host__ __device__ static constexpr index_t
CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n) CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
{ {
static_assert(CGridDesc_M_N::GetNumOfVisibleDimension() == 2);
const auto M = c_grid_desc_m_n.GetLength(I0); const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1); const auto N = c_grid_desc_m_n.GetLength(I1);
...@@ -326,9 +327,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -326,9 +327,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return has_main_k0_block_loop; return has_main_k0_block_loop;
} }
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n) MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
{ {
static_assert(CGridDesc_M_N::GetNumOfVisibleDimension() == 2);
const auto M = c_grid_desc_m_n.GetLength(I0); const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1); const auto N = c_grid_desc_m_n.GetLength(I1);
...@@ -347,11 +351,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -347,11 +351,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// return block_id to C matrix tile idx (m0, n0) mapping // return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n) MakeDefaultBlock2CTileMap(index_t M, index_t N)
{ {
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{}; constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{}; constexpr auto N1 = Number<NPerBlock>{};
...@@ -385,13 +386,24 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -385,13 +386,24 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return cblockid_to_m0_n0_block_cluster_adaptor; return cblockid_to_m0_n0_block_cluster_adaptor;
} }
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype( template <typename CGridDesc_M_N>
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>; __host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
{
static_assert(CGridDesc_M_N::GetNumOfVisibleDimension() == 2);
return MakeDefaultBlock2CTileMap(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
}
using DefaultBlock2CTileMap = using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>; remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(1, 1))>;
template <bool HasMainK0BlockLoop, typename Block2CTileMap> template <bool HasMainK0BlockLoop,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2CTileMap>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid, __device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid, const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid, FloatC* __restrict__ p_c_grid,
......
...@@ -525,8 +525,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3 ...@@ -525,8 +525,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
return cblockid_to_m0_n0_block_cluster_adaptor; return cblockid_to_m0_n0_block_cluster_adaptor;
} }
// using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
// decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(1, 1, 1, 1)); using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(1, 1, 1, 1));
template <bool HasMainK0BlockLoop, template <bool HasMainK0BlockLoop,
......
# device_gemm_instance # device_gemm_instance
set(DEVICE_GEMM_INSTANCE_SOURCE set(DEVICE_GEMM_INSTANCE_SOURCE
# device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp; device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp;
# device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp; device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp;
# device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp; device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp;
# device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp; device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp;
# device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp; device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp;
# device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp; device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp;
# device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp; device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp;
# device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp; device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp;
# device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instance.cpp; # device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instance.cpp;
# device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instance.cpp; # device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instance.cpp;
# device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instance.cpp; # device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instance.cpp;
......
This diff is collapsed.
...@@ -44,4 +44,4 @@ add_subdirectory(batched_gemm_reduce) ...@@ -44,4 +44,4 @@ add_subdirectory(batched_gemm_reduce)
add_subdirectory(grouped_gemm) add_subdirectory(grouped_gemm)
add_subdirectory(convnd_fwd) add_subdirectory(convnd_fwd)
add_subdirectory(reduce) add_subdirectory(reduce)
add_subdirectory(conv2d_bwd_weight) # add_subdirectory(conv2d_bwd_weight)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment