Commit 328ab6f3 authored by Jianfeng yan's avatar Jianfeng yan
Browse files

removed A/B/CGridDesc from DeviceOps that use gridwise_gemm_v2r3 and gridwise_gemm_cshuffle

parent e739c577
......@@ -222,9 +222,6 @@ struct DeviceBatchedGemmXdl
AccDataType,
CDataType,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
......@@ -373,7 +370,7 @@ struct DeviceBatchedGemmXdl
CDataType,
remove_reference_t<DeviceBatchedGemmXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceBatchedGemmXdl::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
......@@ -407,7 +404,7 @@ struct DeviceBatchedGemmXdl
CDataType,
remove_reference_t<DeviceBatchedGemmXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceBatchedGemmXdl::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
......
......@@ -403,6 +403,9 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
7, // CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>;
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
// Argument
struct Argument : public BaseArgument
{
......@@ -492,7 +495,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2]));
block_2_ctile_map_container_.push_back(
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01, N01));
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2].GetLength(I0), descs[I2].GetLength(I1), M01, N01));
}
}
}
......@@ -504,7 +507,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
std::vector<AGridDesc_K0_M_K1> a_grid_desc_k0_m_k1_container_;
std::vector<BGridDesc_K0_N_K1> b_grid_desc_k0_n_k1_container_;
std::vector<CGridDesc_M_N> c_grid_desc_m_n_container_;
std::vector<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>
std::vector<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_;
std::vector<typename GridwiseGemm::DefaultBlock2CTileMap> block_2_ctile_map_container_;
index_t M01_;
......@@ -594,8 +597,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
OutElementwiseOperation,
WeiElementwiseOperation,
InElementwiseOperation,
......@@ -627,8 +629,7 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
OutElementwiseOperation,
WeiElementwiseOperation,
InElementwiseOperation,
......
......@@ -317,9 +317,6 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
AccDataType,
CDataType,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
......@@ -351,6 +348,9 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
7, // CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>;
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
// Argument
struct Argument : public BaseArgument
{
......@@ -416,7 +416,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_.GetLength(I0), c_grid_desc_m_n_.GetLength(I1), M01, N01);
}
}
......@@ -427,7 +427,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_;
......@@ -492,7 +492,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
......@@ -523,7 +523,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
......
......@@ -309,9 +309,6 @@ struct DeviceConv3dFwdXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho_Wo_
AccDataType,
OutDataType,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
......
......@@ -960,9 +960,6 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
AccDataType,
CDataType,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
......@@ -994,6 +991,9 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
7, // CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>;
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
// Argument
struct Argument : public BaseArgument
{
......@@ -1079,7 +1079,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2]));
block_2_ctile_map_container_.push_back(
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_));
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2].GetLength(I0), descs[I2].GetLength(I1), M01_, N01_));
}
}
}
......@@ -1135,7 +1135,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2]));
block_2_ctile_map_container_.push_back(
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_));
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2].GetLength(I0), descs[I2].GetLength(I1), M01_, N01_));
}
}
}
......@@ -1201,7 +1201,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
descs[I2]));
block_2_ctile_map_container_.push_back(
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2], M01_, N01_));
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2].GetLength(I0), descs[I2].GetLength(I1), M01_, N01_));
}
}
}
......@@ -1214,7 +1214,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
std::vector<AGridDesc_K0_M_K1> a_grid_desc_k0_m_k1_container_;
std::vector<BGridDesc_K0_N_K1> b_grid_desc_k0_n_k1_container_;
std::vector<CGridDesc_M_N> c_grid_desc_m_n_container_;
std::vector<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>
std::vector<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_;
std::vector<typename GridwiseGemm::DefaultBlock2CTileMap> block_2_ctile_map_container_;
index_t M01_;
......@@ -1308,8 +1308,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
OutElementwiseOperation,
WeiElementwiseOperation,
InElementwiseOperation,
......@@ -1341,8 +1340,7 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
OutElementwiseOperation,
WeiElementwiseOperation,
InElementwiseOperation,
......
......@@ -614,9 +614,6 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
AccDataType,
CDataType,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
......@@ -648,6 +645,9 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
7, // CThreadTransferSrcDstVectorDim,
CThreadTransferDstScalarPerVector>;
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
// Argument
struct Argument : public BaseArgument
{
......@@ -713,7 +713,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_.GetLength(I0), c_grid_desc_m_n_.GetLength(I1), M01, N01);
}
}
......@@ -724,8 +724,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
......@@ -789,7 +788,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
......@@ -820,7 +819,7 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
CDataType,
remove_reference_t<DeviceOp::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceOp::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
remove_reference_t<CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
InElementwiseOperation,
WeiElementwiseOperation,
OutElementwiseOperation,
......
......@@ -187,9 +187,9 @@ struct DeviceGemmXdl
AccDataType,
CDataType,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
// AGridDesc_K0_M_K1,
// BGridDesc_K0_N_K1,
// CGridDesc_M_N,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
......@@ -222,6 +222,11 @@ struct DeviceGemmXdl
CThreadTransferDstScalarPerVector,
NumPrefetch>;
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
using Block2CTileMap = decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(1, 1, 1, 1));
// Argument
struct Argument : public BaseArgument
{
......@@ -264,7 +269,7 @@ struct DeviceGemmXdl
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_, M01, N01);
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_.GetLength(I0), c_grid_desc_m_n_.GetLength(I1), M01, N01);
}
}
......@@ -275,9 +280,8 @@ struct DeviceGemmXdl
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
Block2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
AElementwiseOperation a_element_op_;
......@@ -331,7 +335,7 @@ struct DeviceGemmXdl
CDataType,
remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
......@@ -362,7 +366,7 @@ struct DeviceGemmXdl
CDataType,
remove_reference_t<DeviceGemmXdl::AGridDesc_K0_M_K1>,
remove_reference_t<DeviceGemmXdl::BGridDesc_K0_N_K1>,
remove_reference_t<typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2>,
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
......
......@@ -342,9 +342,6 @@ struct DeviceGemm_Xdl_CShuffle
BElementwiseOperation,
CElementwiseOperation,
InMemoryDataOperationEnum::Set,
AGridDesc_AK0_M_AK1,
BGridDesc_BK0_N_BK1,
CGridDesc_M_N,
NumGemmKPrefetchStage,
BlockSize,
MPerBlock,
......@@ -377,6 +374,9 @@ struct DeviceGemm_Xdl_CShuffle
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock>;
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = decltype(
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
// Argument
struct Argument : public BaseArgument
{
......@@ -422,7 +422,7 @@ struct DeviceGemm_Xdl_CShuffle
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
AElementwiseOperation a_element_op_;
......@@ -470,18 +470,18 @@ struct DeviceGemm_Xdl_CShuffle
if(has_main_k0_block_loop)
{
const auto kernel = kernel_gemm_xdl_cshuffle_v1<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap,
true>;
const auto kernel =
kernel_gemm_xdl_cshuffle_v1<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap,
true>;
if(nrepeat == 0)
{
......@@ -522,18 +522,18 @@ struct DeviceGemm_Xdl_CShuffle
}
else
{
const auto kernel = kernel_gemm_xdl_cshuffle_v1<
GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap,
false>;
const auto kernel =
kernel_gemm_xdl_cshuffle_v1<GridwiseGemm,
ADataType, // TODO: distiguish A/B datatype
CDataType,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename GridwiseGemm::DefaultBlock2CTileMap,
false>;
if(nrepeat == 0)
{
......
......@@ -419,6 +419,8 @@ struct DeviceGemmXdlSplitKCShuffle
CShuffleBlockTransferClusterLengths_MBlock_MPerBlock_NBlock_NPerBlock,
CShuffleBlockTransferScalarPerVector_NPerBlock>;
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = decltype(
GridwiseGemm::MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}));
using Block2CTileMap = decltype(BatchedGemmUtil::MakeBlock2CTileMap<MPerBlock, NPerBlock>(1, 1, 1));
struct Argument : public BaseArgument
......@@ -505,7 +507,7 @@ struct DeviceGemmXdlSplitKCShuffle
AGridDesc_AK0_M_AK1 a_grid_desc_ak0_m_ak1_;
BGridDesc_BK0_N_BK1 b_grid_desc_bk0_n_bk1_;
CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock
c_grid_desc_mblock_mperblock_nblock_nperblock_;
ComputePtrOffsetOfStridedBatch compute_ptr_offset_of_batch_;
Block2CTileMap block_2_ctile_map_;
......@@ -564,7 +566,7 @@ struct DeviceGemmXdlSplitKCShuffle
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
ComputePtrOffsetOfStridedBatch,
Block2CTileMap,
true>;
......@@ -621,7 +623,7 @@ struct DeviceGemmXdlSplitKCShuffle
CElementwiseOperation,
DeviceOp::AGridDesc_AK0_M_AK1,
DeviceOp::BGridDesc_BK0_N_BK1,
typename GridwiseGemm::CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
ComputePtrOffsetOfStridedBatch,
Block2CTileMap,
false>;
......
......@@ -188,9 +188,6 @@ struct DeviceGroupedGemmXdl
AccDataType,
CDataType,
InMemoryDataOperationEnum::Set,
AGridDesc_K0_M_K1,
BGridDesc_K0_N_K1,
CGridDesc_M_N,
AElementwiseOperation,
BElementwiseOperation,
CElementwiseOperation,
......@@ -223,11 +220,14 @@ struct DeviceGroupedGemmXdl
CThreadTransferDstScalarPerVector,
NumPrefetch>;
using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
struct GroupedGemmBlock2CTileMap
{
GroupedGemmBlock2CTileMap()
{
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(CGridDesc_M_N{}, 1, 1);
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(1, 1, 1, 1);
BlockStart_ = -1;
}
......@@ -236,7 +236,7 @@ struct DeviceGroupedGemmXdl
index_t N01,
ck::index_t BlockStart)
{
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n, M01, N01);
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01, N01);
BlockStart_ = BlockStart;
}
......@@ -258,8 +258,7 @@ struct DeviceGroupedGemmXdl
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
typename GridwiseGemm::CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
GroupedGemmBlock2CTileMap grouped_gemm_block_2_ctile_map_;
......
......@@ -150,9 +150,6 @@ template <typename FloatAB,
typename BElementwiseOperation,
typename CElementwiseOperation,
InMemoryDataOperationEnum CGlobalMemoryDataOperation,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDesc_M_N,
index_t NumGemmKPrefetchStage,
index_t BlockSize,
index_t MPerBlock,
......@@ -261,6 +258,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
// block_id to matrix tile idx (m0, n0) mapping are controlled by {M01, N01}
template <typename AGridDesc_AK0_M_AK1, typename BGridDesc_BK0_N_BK1, typename CGridDesc_M_N>
__host__ __device__ static constexpr bool
CheckValidity(const AGridDesc_AK0_M_AK1& a_grid_desc_ak0_m_ak1,
const BGridDesc_BK0_N_BK1& b_grid_desc_bk0_n_bk1,
......@@ -307,9 +305,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return true;
}
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr index_t
CalculateGridSize(const CGridDesc_M_N& c_grid_desc_m_n)
{
static_assert(CGridDesc_M_N::GetNumOfVisibleDimension() == 2);
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
......@@ -326,9 +327,12 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return has_main_k0_block_loop;
}
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(const CGridDesc_M_N& c_grid_desc_m_n)
{
static_assert(CGridDesc_M_N::GetNumOfVisibleDimension() == 2);
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
......@@ -347,11 +351,8 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
MakeDefaultBlock2CTileMap(index_t M, index_t N)
{
const auto M = c_grid_desc_m_n.GetLength(I0);
const auto N = c_grid_desc_m_n.GetLength(I1);
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
......@@ -385,13 +386,24 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
return cblockid_to_m0_n0_block_cluster_adaptor;
}
using CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock = remove_cvref_t<decltype(
MakeCGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock(CGridDesc_M_N{}))>;
template <typename CGridDesc_M_N>
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(const CGridDesc_M_N& c_grid_desc_m_n)
{
static_assert(CGridDesc_M_N::GetNumOfVisibleDimension() == 2);
return MakeDefaultBlock2CTileMap(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
}
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(CGridDesc_M_N{}))>;
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(1, 1))>;
template <bool HasMainK0BlockLoop, typename Block2CTileMap>
template <bool HasMainK0BlockLoop,
typename AGridDesc_AK0_M_AK1,
typename BGridDesc_BK0_N_BK1,
typename CGridDescriptor_MBlock_MPerBlock_NBlock_NPerBlock,
typename Block2CTileMap>
__device__ static void Run(const FloatAB* __restrict__ p_a_grid,
const FloatAB* __restrict__ p_b_grid,
FloatC* __restrict__ p_c_grid,
......
......@@ -525,8 +525,6 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdlops_v2r3
return cblockid_to_m0_n0_block_cluster_adaptor;
}
// using CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 =
// decltype(MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
using DefaultBlock2CTileMap = decltype(MakeDefaultBlock2CTileMap(1, 1, 1, 1));
template <bool HasMainK0BlockLoop,
......
# device_gemm_instance
set(DEVICE_GEMM_INSTANCE_SOURCE
# device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp;
# device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp;
# device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp;
# device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp;
# device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp;
# device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp;
# device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp;
# device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp;
device_gemm_xdl_f32_f32_f32_mk_kn_mn_instance.cpp;
device_gemm_xdl_f32_f32_f32_mk_nk_mn_instance.cpp;
device_gemm_xdl_f32_f32_f32_km_kn_mn_instance.cpp;
device_gemm_xdl_f32_f32_f32_km_nk_mn_instance.cpp;
device_gemm_xdl_f16_f16_f16_mk_kn_mn_instance.cpp;
device_gemm_xdl_f16_f16_f16_mk_nk_mn_instance.cpp;
device_gemm_xdl_f16_f16_f16_km_kn_mn_instance.cpp;
device_gemm_xdl_f16_f16_f16_km_nk_mn_instance.cpp;
# device_gemm_xdl_c_shuffle_int8_int8_int8_mk_kn_mn_instance.cpp;
# device_gemm_xdl_c_shuffle_int8_int8_int8_mk_nk_mn_instance.cpp;
# device_gemm_xdl_c_shuffle_int8_int8_int8_km_kn_mn_instance.cpp;
......
This diff is collapsed.
......@@ -44,4 +44,4 @@ add_subdirectory(batched_gemm_reduce)
add_subdirectory(grouped_gemm)
add_subdirectory(convnd_fwd)
add_subdirectory(reduce)
add_subdirectory(conv2d_bwd_weight)
# add_subdirectory(conv2d_bwd_weight)
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment