Commit cd01db8b authored by Po-Yen, Chen
Browse files

Merge descriptors into one object

parent a434991e
...@@ -151,17 +151,17 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -151,17 +151,17 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto static auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(index_t N,
ck::index_t K, index_t K,
ck::index_t C, index_t C,
std::vector<ck::index_t> input_spatial_lengths, std::vector<index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths, std::vector<index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths, std::vector<index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides, std::vector<index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<index_t> input_right_pads,
std::vector<ck::index_t> tildes) std::vector<index_t> tildes)
{ {
using namespace ck; using namespace ck;
...@@ -348,21 +348,21 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -348,21 +348,21 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
wei_gemmk0_gemmn_gemmk1_grid_desc, wei_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc); in_gemmm_gemmn_grid_desc);
} }
}
} // function end
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto static auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(index_t N,
ck::index_t K, index_t K,
ck::index_t C, index_t C,
std::vector<ck::index_t> input_spatial_lengths, std::vector<index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths, std::vector<index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths, std::vector<index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides, std::vector<index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<index_t> input_right_pads,
std::vector<ck::index_t> tildes) std::vector<index_t> tildes)
{ {
using namespace ck; using namespace ck;
...@@ -621,22 +621,21 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -621,22 +621,21 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
wei_gemmk0_gemmn_gemmk1_grid_desc, wei_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc); in_gemmm_gemmn_grid_desc);
} }
}
} // function end
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto static auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N, MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(index_t N,
ck::index_t K, index_t K,
ck::index_t C, index_t C,
std::vector<ck::index_t> input_spatial_lengths, std::vector<index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths, std::vector<index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths, std::vector<index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides, std::vector<index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations, std::vector<index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads, std::vector<index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<index_t> input_right_pads,
std::vector<ck::index_t> tildes) std::vector<index_t> tildes)
{ {
using namespace ck; using namespace ck;
...@@ -978,8 +977,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -978,8 +977,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
wei_gemmk0_gemmn_gemmk1_grid_desc, wei_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc); in_gemmm_gemmn_grid_desc);
} }
}
} // function end
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto GetDummyABCGridDesc() static auto GetDummyABCGridDesc()
...@@ -1125,9 +1123,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1125,9 +1123,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
input_left_pads_, input_left_pads_,
input_right_pads_, input_right_pads_,
{i_xtilde}); {i_xtilde});
a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); grid_desc_container_.push_back(descs);
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
c_grid_desc_m_n_container_.push_back(descs[I2]);
} }
} }
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false> template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
...@@ -1172,9 +1168,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1172,9 +1168,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
input_left_pads_, input_left_pads_,
input_right_pads_, input_right_pads_,
{i_ytilde, i_xtilde}); {i_ytilde, i_xtilde});
a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); grid_desc_container_.push_back(descs);
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
c_grid_desc_m_n_container_.push_back(descs[I2]);
} }
} }
} }
...@@ -1228,9 +1222,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1228,9 +1222,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
input_left_pads_, input_left_pads_,
input_right_pads_, input_right_pads_,
{i_ztilde, i_ytilde, i_xtilde}); {i_ztilde, i_ytilde, i_xtilde});
a_grid_desc_k0_m_k1_container_.push_back(descs[I0]); grid_desc_container_.push_back(descs);
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
c_grid_desc_m_n_container_.push_back(descs[I2]);
} }
} }
} }
...@@ -1239,9 +1231,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1239,9 +1231,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const ADataType* p_a_grid_; const ADataType* p_a_grid_;
const BDataType* p_b_grid_; const BDataType* p_b_grid_;
CDataType* p_c_grid_; CDataType* p_c_grid_;
std::vector<AGridDesc_K0_M_K1> a_grid_desc_k0_m_k1_container_; std::vector<ABCGridDescs> grid_desc_container_;
std::vector<BGridDesc_K0_N_K1> b_grid_desc_k0_n_k1_container_;
std::vector<CGridDesc_M_N> c_grid_desc_m_n_container_;
index_t M01_; index_t M01_;
// for checking IsSupportedArgument() // for checking IsSupportedArgument()
index_t Conv_N_; index_t Conv_N_;
...@@ -1265,50 +1255,14 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1265,50 +1255,14 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{}) float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{ {
float ave_time = 0; float ave_time = 0;
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) for(size_t i = 0; i < arg.grid_desc_container_.size(); i++)
{ {
#if DEBUG_LOG auto a_grid_desc_k0_m_k1 = arg.grid_desc_container_[i][I0];
{ auto b_grid_desc_k0_n_k1 = arg.grid_desc_container_[i][I1];
std::cout << "arg.a_grid_desc_k0_m_k1_container_{" auto c_grid_desc_m_n = arg.grid_desc_container_[i][I2];
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}"
<< std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_container_{"
<< arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", "
<< arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}"
<< std::endl;
std::cout << "arg.c_grid_desc_m_n_container_{ "
<< arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}"
<< std::endl;
std::cout << "arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( "
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I0)
<< ", "
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I1)
<< ", "
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I2)
<< ", "
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I3)
<< ", "
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I4)
<< ", "
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I5)
<< ", "
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I6)
<< ", "
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I7)
<< " ) " << std::endl;
}
#endif
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], if(!GridwiseGemm::CheckValidity(
arg.b_grid_desc_k0_n_k1_container_[i], a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, c_grid_desc_m_n))
arg.c_grid_desc_m_n_container_[i]))
{ {
throw std::runtime_error( throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting"); "wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
...@@ -1316,11 +1270,10 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1316,11 +1270,10 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
index_t gdx, gdy, gdz; index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize( std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(
arg.c_grid_desc_m_n_container_[i].GetLength(I0), c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
arg.c_grid_desc_m_n_container_[i].GetLength(I1));
const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) * const auto K =
arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2); a_grid_desc_k0_m_k1.GetLength(I0) * a_grid_desc_k0_m_k1.GetLength(I2);
if(GridwiseGemm::CalculateHasMainKBlockLoop(K)) if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{ {
...@@ -1341,9 +1294,9 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1341,9 +1294,9 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_container_[i], a_grid_desc_k0_m_k1,
arg.b_grid_desc_k0_n_k1_container_[i], b_grid_desc_k0_n_k1,
arg.c_grid_desc_m_n_container_[i], c_grid_desc_m_n,
GridwiseGemm::CalculateNumKBlockLoop(K)); GridwiseGemm::CalculateNumKBlockLoop(K));
} }
else else
...@@ -1366,9 +1319,9 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1366,9 +1319,9 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
arg.p_a_grid_, arg.p_a_grid_,
arg.p_b_grid_, arg.p_b_grid_,
arg.p_c_grid_, arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_container_[i], a_grid_desc_k0_m_k1,
arg.b_grid_desc_k0_n_k1_container_[i], b_grid_desc_k0_n_k1,
arg.c_grid_desc_m_n_container_[i], c_grid_desc_m_n,
GridwiseGemm::CalculateNumKBlockLoop(K)); GridwiseGemm::CalculateNumKBlockLoop(K));
} }
} }
...@@ -1419,11 +1372,11 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl ...@@ -1419,11 +1372,11 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
} }
// Gridwise GEMM size // Gridwise GEMM size
for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++) for(std::size_t i = 0; i < arg.grid_desc_container_.size(); i++)
{ {
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i], if(!GridwiseGemm::CheckValidity(arg.grid_desc_container_[i][I0],
arg.b_grid_desc_k0_n_k1_container_[i], arg.grid_desc_container_[i][I1],
arg.c_grid_desc_m_n_container_[i])) arg.grid_desc_container_[i][I2]))
{ {
return false; return false;
} }
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment