Commit 8e3aef3b authored by Chao Liu's avatar Chao Liu
Browse files

format

parent 4a76bc07
......@@ -12,7 +12,7 @@ struct BatchedGemmUtil
{
template <index_t MPerBlock, index_t NPerBlock>
static constexpr auto
MakeBlock2CTileMap(index_t batch_count, index_t M, index_t N, index_t M01=1, index_t N01=1)
MakeBlock2CTileMap(index_t batch_count, index_t M, index_t N, index_t M01 = 1, index_t N01 = 1)
{
constexpr auto M1 = MPerBlock;
constexpr auto N1 = NPerBlock;
......
......@@ -495,7 +495,8 @@ struct DeviceConv2dBwdDataXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2]));
block_2_ctile_map_container_.push_back(
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2].GetLength(I0), descs[I2].GetLength(I1), M01, N01));
GridwiseGemm::MakeDefaultBlock2CTileMap(
descs[I2].GetLength(I0), descs[I2].GetLength(I1), M01, N01));
}
}
}
......
......@@ -415,8 +415,8 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_.GetLength(I0), c_grid_desc_m_n_.GetLength(I1), M01, N01);
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(
c_grid_desc_m_n_.GetLength(I0), c_grid_desc_m_n_.GetLength(I1), M01, N01);
}
}
......@@ -427,8 +427,7 @@ struct DeviceConv2dFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
AGridDesc_K0_M_K1 a_grid_desc_k0_m_k1_;
BGridDesc_K0_N_K1 b_grid_desc_k0_n_k1_;
CGridDesc_M_N c_grid_desc_m_n_;
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
CGridDesc_M0_N0_M1_N1_M2_M3_M4_N2 c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_;
typename GridwiseGemm::DefaultBlock2CTileMap block_2_ctile_map_;
index_t M01_;
index_t N01_;
......
......@@ -1078,8 +1078,8 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_.push_back(
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2]));
block_2_ctile_map_container_.push_back(
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2].GetLength(I0), descs[I2].GetLength(I1), M01_, N01_));
block_2_ctile_map_container_.push_back(GridwiseGemm::MakeDefaultBlock2CTileMap(
descs[I2].GetLength(I0), descs[I2].GetLength(I1), M01_, N01_));
}
}
}
......@@ -1135,7 +1135,8 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(descs[I2]));
block_2_ctile_map_container_.push_back(
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2].GetLength(I0), descs[I2].GetLength(I1), M01_, N01_));
GridwiseGemm::MakeDefaultBlock2CTileMap(
descs[I2].GetLength(I0), descs[I2].GetLength(I1), M01_, N01_));
}
}
}
......@@ -1201,7 +1202,8 @@ struct DeviceConvndBwdDataXdl_Input_N_Di_Hi_Wi_C_Weight_K_Z_Y_X_C_Output_N_Do_Ho
descs[I2]));
block_2_ctile_map_container_.push_back(
GridwiseGemm::MakeDefaultBlock2CTileMap(descs[I2].GetLength(I0), descs[I2].GetLength(I1), M01_, N01_));
GridwiseGemm::MakeDefaultBlock2CTileMap(
descs[I2].GetLength(I0), descs[I2].GetLength(I1), M01_, N01_));
}
}
}
......
......@@ -712,8 +712,8 @@ struct DeviceConvNDFwdXdl_Input_N_Hi_Wi_C_Weight_K_Y_X_C_Output_N_Ho_Wo_K
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_.GetLength(I0), c_grid_desc_m_n_.GetLength(I1), M01, N01);
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(
c_grid_desc_m_n_.GetLength(I0), c_grid_desc_m_n_.GetLength(I1), M01, N01);
}
}
......
......@@ -226,7 +226,6 @@ struct DeviceGemmXdl
decltype(GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(CGridDesc_M_N{}));
using Block2CTileMap = decltype(GridwiseGemm::MakeDefaultBlock2CTileMap(1, 1, 1, 1));
// Argument
struct Argument : public BaseArgument
{
......@@ -268,8 +267,8 @@ struct DeviceGemmXdl
c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_ =
GridwiseGemm::MakeCGridDescriptor_M0_N0_M1_N1_M2_M3_M4_N2(c_grid_desc_m_n_);
block_2_ctile_map_ =
GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n_.GetLength(I0), c_grid_desc_m_n_.GetLength(I1), M01, N01);
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(
c_grid_desc_m_n_.GetLength(I0), c_grid_desc_m_n_.GetLength(I1), M01, N01);
}
}
......
......@@ -236,8 +236,9 @@ struct DeviceGroupedGemmXdl
index_t N01,
ck::index_t BlockStart)
{
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01, N01);
BlockStart_ = BlockStart;
block_2_ctile_map_ = GridwiseGemm::MakeDefaultBlock2CTileMap(
c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1), M01, N01);
BlockStart_ = BlockStart;
}
template <typename TopIdx>
......
......@@ -351,8 +351,7 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
}
// return block_id to C matrix tile idx (m0, n0) mapping
__host__ __device__ static constexpr auto
MakeDefaultBlock2CTileMap(index_t M, index_t N)
__host__ __device__ static constexpr auto MakeDefaultBlock2CTileMap(index_t M, index_t N)
{
constexpr auto M1 = Number<MPerBlock>{};
constexpr auto N1 = Number<NPerBlock>{};
......@@ -393,12 +392,11 @@ struct GridwiseGemm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
{
static_assert(CGridDesc_M_N::GetNumOfVisibleDimension() == 2);
return MakeDefaultBlock2CTileMap(c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
return MakeDefaultBlock2CTileMap(c_grid_desc_m_n.GetLength(I0),
c_grid_desc_m_n.GetLength(I1));
}
using DefaultBlock2CTileMap =
remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(1, 1))>;
using DefaultBlock2CTileMap = remove_cvref_t<decltype(MakeDefaultBlock2CTileMap(1, 1))>;
template <bool HasMainK0BlockLoop,
typename AGridDesc_AK0_M_AK1,
......
......@@ -44,7 +44,8 @@ using device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances =
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
add_device_operation_instances(instances, device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances{});
}
} // namespace device_gemm_instance
......
......@@ -44,7 +44,8 @@ using device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances =
void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
add_device_operation_instances(instances, device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances{});
}
} // namespace device_gemm_instance
......
......@@ -44,7 +44,8 @@ using device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances =
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
add_device_operation_instances(instances, device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f32_f32_f32_mk_kn_mn_instances{});
}
} // namespace device_gemm_instance
......
......@@ -49,7 +49,8 @@ using device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances =
void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(
std::vector<DeviceGemmPtr<PassThrough, PassThrough, PassThrough>>& instances)
{
add_device_operation_instances(instances, device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances{});
add_device_operation_instances(instances,
device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances{});
}
} // namespace device_gemm_instance
......
......@@ -69,10 +69,14 @@ void add_device_gemm_xdl_splitk_f32_f32_f32_mk_nk_mn_instances(std::vector<Devic
void add_device_gemm_xdl_splitk_f32_f32_f32_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_f32_f32_f32_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_kn_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_nk_mn_instances(std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_kn_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_mk_nk_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_kn_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);
void add_device_gemm_xdl_splitk_c_shuffle_f16_f16_f16_km_nk_mn_instances(
std::vector<DeviceGemmNoOpPtr>&);
} // namespace device_gemm_instance
} // namespace device
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment