Commit 8e3aef3b authored by Chao Liu
Browse files

format

parent 4a76bc07
...@@ -12,7 +12,7 @@ struct BatchedGemmUtil ...@@ -12,7 +12,7 @@ struct BatchedGemmUtil
{ {
template <index_t MPerBlock, index_t NPerBlock> template <index_t MPerBlock, index_t NPerBlock>
static constexpr auto static constexpr auto
MakeBlock2CTileMap(index_t batch_count, index_t M, index_t N, index_t M01=1, index_t N01=1) MakeBlock2CTileMap(index_t batch_count, index_t M, index_t N, index_t M01 = 1, index_t N01 = 1)
{ {
constexpr auto M1 = MPerBlock; constexpr auto M1 = MPerBlock;
constexpr auto N1 = NPerBlock; constexpr auto N1 = NPerBlock;
......
...@@ -495,7 +495,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -495,7 +495,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2])); GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2]));
block_2_ctile_map_container_.push_back( block_2_ctile_map_container_.push_back(
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2].GetLength(I0), descs[I2].GetLength(I1), M01, N01)); GridwiseGemm::MakeDefaultBlock2CTileMap(
descs[I2].GetLength(I0), descs[I2].GetLength(I1), M01, N01));
} }
} }
} }
......
...@@ -415,8 +415,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -415,8 +415,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
block_2_ctile_map_ = block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_.GetLength(I0), c_grid_desc_m_n_.GetLength(I1), M01, N01); c_grid_desc_m_n_.GetLength(I0), c_grid_desc_m_n_.GetLength(I1), M01, N01);
} }
} }
...@@ -427,8 +427,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -427,8 +427,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_; AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_; BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_; CGridDesc_M_N c_grid_desc_m_n_;
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_; typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_; index_t M01_;
index_t N01_; index_t N01_;
......
...@@ -1078,8 +1078,8 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho ...@@ -1078,8 +1078,8 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back( c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back(
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2])); GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2]));
block_2_ctile_map_container_.push_back( block_2_ctile_map_container_.push_back(GridwiseGemm::MakeDefaultBlock2CTileMap(
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2].GetLength(I0), descs[I2].GetLength(I1), M01_, N01_)); descs[I2].GetLength(I0), descs[I2].GetLength(I1), M01_, N01_));
} }
} }
} }
...@@ -1135,7 +1135,8 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho ...@@ -1135,7 +1135,8 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2])); GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2]));
block_2_ctile_map_container_.push_back( block_2_ctile_map_container_.push_back(
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2].GetLength(I0), descs[I2].GetLength(I1), M01_, N01_)); GridwiseGemm::MakeDefaultBlock2CTileMap(
descs[I2].GetLength(I0), descs[I2].GetLength(I1), M01_, N01_));
} }
} }
} }
...@@ -1201,7 +1202,8 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho ...@@ -1201,7 +1202,8 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
descs[I2])); descs[I2]));
block_2_ctile_map_container_.push_back( block_2_ctile_map_container_.push_back(
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2].GetLength(I0), descs[I2].GetLength(I1), M01_, N01_)); GridwiseGemm::MakeDefaultBlock2CTileMap(
descs[I2].GetLength(I0), descs[I2].GetLength(I1), M01_, N01_));
} }
} }
} }
......
...@@ -712,8 +712,8 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K ...@@ -712,8 +712,8 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
block_2_ctile_map_ = block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_.GetLength(I0), c_grid_desc_m_n_.GetLength(I1), M01, N01); c_grid_desc_m_n_.GetLength(I0), c_grid_desc_m_n_.GetLength(I1), M01, N01);
} }
} }
......
...@@ -226,7 +226,6 @@ struct DeviceGemmXdl ...@@ -226,7 +226,6 @@ struct DeviceGemmXdl
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{})); decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
using Block2CTileMap = decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(1, 1, 1, 1)); using Block2CTileMap = decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(1, 1, 1, 1));
// Argument // Argument
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
...@@ -268,8 +267,8 @@ struct DeviceGemmXdl ...@@ -268,8 +267,8 @@ struct DeviceGemmXdl
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ = c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_); GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
block_2_ctile_map_ = block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_.GetLength(I0), c_grid_desc_m_n_.GetLength(I1), M01, N01); c_grid_desc_m_n_.GetLength(I0), c_grid_desc_m_n_.GetLength(I1), M01, N01);
} }
} }
......
...@@ -236,8 +236,9 @@ struct DeviceGroupedGemmXdl ...@@ -236,8 +236,9 @@ struct DeviceGroupedGemmXdl
index_t N01, index_t N01,
ck::index_t BlockStart) ck::index_t BlockStart)
{ {
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01, N01); block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(
BlockStart_ = BlockStart; c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01, N01);
BlockStart_ = BlockStart;
} }
template <typename TopIdx> template <typename TopIdx>
......
...@@ -351,8 +351,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -351,8 +351,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
} }
// return block_id to C matrix tile idx (m0, n0) mapping // return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto __host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(index_t M, index_t N)
MakeDefaultBlock2CTileMap(index_t M, index_t N)
{ {
constexpr auto M1 = Number<MPerBlock>{}; constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{}; constexpr auto N1 = Number<NPerBlock>{};
...@@ -393,12 +392,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -393,12 +392,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{ {
static_assert(CGridDesc_M_N::GetNumOfVisibleDimension() == 2); static_assert(CGridDesc_M_N::GetNumOfVisibleDimension() == 2);
return MakeDefaultBlock2CTileMap(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1)); return MakeDefaultBlock2CTileMap(c_grid_desc_m_n.GetLength(I0),
c_grid_desc_m_n.GetLength(I1));
} }
using DefaultBlock2CTileMap = remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(1, 1))>;
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(1, 1))>;
template <bool HasMainK0BlockLoop, template <bool HasMainK0BlockLoop,
typename AGridDesc_AK0_M_AK1, typename AGridDesc_AK0_M_AK1,
......
...@@ -44,7 +44,8 @@ using device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances = ...@@ -44,7 +44,8 @@ using device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances =
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances( void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{ {
add_device_operation_instances(instances, device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances{}); add_device_operation_instances(instances,
device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances{});
} }
} // namespace device_gemm_instance } // namespace device_gemm_instance
......
...@@ -44,7 +44,8 @@ using device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances = ...@@ -44,7 +44,8 @@ using device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances =
void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances( void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{ {
add_device_operation_instances(instances, device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances{}); add_device_operation_instances(instances,
device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances{});
} }
} // namespace device_gemm_instance } // namespace device_gemm_instance
......
...@@ -44,7 +44,8 @@ using device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances = ...@@ -44,7 +44,8 @@ using device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances =
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances( void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{ {
add_device_operation_instances(instances, device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances{}); add_device_operation_instances(instances,
device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances{});
} }
} // namespace device_gemm_instance } // namespace device_gemm_instance
......
...@@ -49,7 +49,8 @@ using device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances = ...@@ -49,7 +49,8 @@ using device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances =
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances( void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances) std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{ {
add_device_operation_instances(instances, device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances{}); add_device_operation_instances(instances,
device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances{});
} }
} // namespace device_gemm_instance } // namespace device_gemm_instance
......
...@@ -69,10 +69,14 @@ void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector<Devic ...@@ -69,10 +69,14 @@ void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector<Devic
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&); void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&); void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&); void add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
void add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&); std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&); void add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
void add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&); std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_kn_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_nk_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);
} // namespace device_gemm_instance } // namespace device_gemm_instance
} // namespace device } // namespace device
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment