Commit cd01db8b authored by Po-Yen, Chen
Browse files

Merge descriptors into one object

parent a434991e
......@@ -151,17 +151,17 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
std::vector<ck::index_t> tildes)
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(index_t N,
index_t K,
index_t C,
std::vector<index_t> input_spatial_lengths,
std::vector<index_t> filter_spatial_lengths,
std::vector<index_t> output_spatial_lengths,
std::vector<index_t> conv_filter_strides,
std::vector<index_t> conv_filter_dilations,
std::vector<index_t> input_left_pads,
std::vector<index_t> input_right_pads,
std::vector<index_t> tildes)
{
using namespace ck;
......@@ -348,21 +348,21 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
wei_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc);
}
}
} // function end
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
static auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
std::vector<ck::index_t> tildes)
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(index_t N,
index_t K,
index_t C,
std::vector<index_t> input_spatial_lengths,
std::vector<index_t> filter_spatial_lengths,
std::vector<index_t> output_spatial_lengths,
std::vector<index_t> conv_filter_strides,
std::vector<index_t> conv_filter_dilations,
std::vector<index_t> input_left_pads,
std::vector<index_t> input_right_pads,
std::vector<index_t> tildes)
{
using namespace ck;
......@@ -621,22 +621,21 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
wei_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc);
}
} // function end
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 3, bool>::type = false>
static auto
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(ck::index_t N,
ck::index_t K,
ck::index_t C,
std::vector<ck::index_t> input_spatial_lengths,
std::vector<ck::index_t> filter_spatial_lengths,
std::vector<ck::index_t> output_spatial_lengths,
std::vector<ck::index_t> conv_filter_strides,
std::vector<ck::index_t> conv_filter_dilations,
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
std::vector<ck::index_t> tildes)
MakeABCGridDescriptor_A_K0_M_K1_B_K0_N_K1_C_M_N(index_t N,
index_t K,
index_t C,
std::vector<index_t> input_spatial_lengths,
std::vector<index_t> filter_spatial_lengths,
std::vector<index_t> output_spatial_lengths,
std::vector<index_t> conv_filter_strides,
std::vector<index_t> conv_filter_dilations,
std::vector<index_t> input_left_pads,
std::vector<index_t> input_right_pads,
std::vector<index_t> tildes)
{
using namespace ck;
......@@ -978,8 +977,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
wei_gemmk0_gemmn_gemmk1_grid_desc,
in_gemmm_gemmn_grid_desc);
}
} // function end
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 1, bool>::type = false>
static auto GetDummyABCGridDesc()
......@@ -1125,9 +1123,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
input_left_pads_,
input_right_pads_,
{i_xtilde});
a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
c_grid_desc_m_n_container_.push_back(descs[I2]);
grid_desc_container_.push_back(descs);
}
}
template <ck::index_t NDim, typename ck::enable_if<NDim == 2, bool>::type = false>
......@@ -1172,9 +1168,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
input_left_pads_,
input_right_pads_,
{i_ytilde, i_xtilde});
a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
c_grid_desc_m_n_container_.push_back(descs[I2]);
grid_desc_container_.push_back(descs);
}
}
}
......@@ -1228,9 +1222,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
input_left_pads_,
input_right_pads_,
{i_ztilde, i_ytilde, i_xtilde});
a_grid_desc_k0_m_k1_container_.push_back(descs[I0]);
b_grid_desc_k0_n_k1_container_.push_back(descs[I1]);
c_grid_desc_m_n_container_.push_back(descs[I2]);
grid_desc_container_.push_back(descs);
}
}
}
......@@ -1239,9 +1231,7 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
const ADataType* p_a_grid_;
const BDataType* p_b_grid_;
CDataType* p_c_grid_;
std::vector<AGridDesc_K0_M_K1> a_grid_desc_k0_m_k1_container_;
std::vector<BGridDesc_K0_N_K1> b_grid_desc_k0_n_k1_container_;
std::vector<CGridDesc_M_N> c_grid_desc_m_n_container_;
std::vector<ABCGridDescs> grid_desc_container_;
index_t M01_;
// for checking IsSupportedArgument()
index_t Conv_N_;
......@@ -1265,50 +1255,14 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
float Run(const Argument& arg, const StreamConfig& stream_config = StreamConfig{})
{
float ave_time = 0;
for(size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
{
#if DEBUG_LOG
for(size_t i = 0; i < arg.grid_desc_container_.size(); i++)
{
std::cout << "arg.a_grid_desc_k0_m_k1_container_{"
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) << ", "
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I1) << ", "
<< arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2) << "}"
<< std::endl;
std::cout << "arg.b_grid_desc_k0_n_k1_container_{"
<< arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I0) << ", "
<< arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I1) << ", "
<< arg.b_grid_desc_k0_n_k1_container_[i].GetLength(I2) << "}"
<< std::endl;
std::cout << "arg.c_grid_desc_m_n_container_{ "
<< arg.c_grid_desc_m_n_container_[i].GetLength(I0) << ", "
<< arg.c_grid_desc_m_n_container_[i].GetLength(I1) << "}"
<< std::endl;
std::cout << "arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_( "
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I0)
<< ", "
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I1)
<< ", "
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I2)
<< ", "
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I3)
<< ", "
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I4)
<< ", "
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I5)
<< ", "
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I6)
<< ", "
<< arg.c_grid_desc_m0_n0_m1_n1_m2_m3_m4_n2_container_[i].GetLength(I7)
<< " ) " << std::endl;
}
#endif
auto a_grid_desc_k0_m_k1 = arg.grid_desc_container_[i][I0];
auto b_grid_desc_k0_n_k1 = arg.grid_desc_container_[i][I1];
auto c_grid_desc_m_n = arg.grid_desc_container_[i][I2];
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i]))
if(!GridwiseGemm::CheckValidity(
a_grid_desc_k0_m_k1, b_grid_desc_k0_n_k1, c_grid_desc_m_n))
{
throw std::runtime_error(
"wrong! GridwiseGemm_km_kn_m0m1n0n1_xdlops_v3r1 has invalid setting");
......@@ -1316,11 +1270,10 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
index_t gdx, gdy, gdz;
std::tie(gdx, gdy, gdz) = GridwiseGemm::CalculateGridSize(
arg.c_grid_desc_m_n_container_[i].GetLength(I0),
arg.c_grid_desc_m_n_container_[i].GetLength(I1));
c_grid_desc_m_n.GetLength(I0), c_grid_desc_m_n.GetLength(I1));
const auto K = arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I0) *
arg.a_grid_desc_k0_m_k1_container_[i].GetLength(I2);
const auto K =
a_grid_desc_k0_m_k1.GetLength(I0) * a_grid_desc_k0_m_k1.GetLength(I2);
if(GridwiseGemm::CalculateHasMainKBlockLoop(K))
{
......@@ -1341,9 +1294,9 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i],
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
c_grid_desc_m_n,
GridwiseGemm::CalculateNumKBlockLoop(K));
}
else
......@@ -1366,9 +1319,9 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
arg.p_a_grid_,
arg.p_b_grid_,
arg.p_c_grid_,
arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i],
a_grid_desc_k0_m_k1,
b_grid_desc_k0_n_k1,
c_grid_desc_m_n,
GridwiseGemm::CalculateNumKBlockLoop(K));
}
}
......@@ -1419,11 +1372,11 @@ struct DeviceConvNdBwdDataNwcKxcNwk_Xdl
}
// Gridwise GEMM size
for(std::size_t i = 0; i < arg.a_grid_desc_k0_m_k1_container_.size(); i++)
for(std::size_t i = 0; i < arg.grid_desc_container_.size(); i++)
{
if(!GridwiseGemm::CheckValidity(arg.a_grid_desc_k0_m_k1_container_[i],
arg.b_grid_desc_k0_n_k1_container_[i],
arg.c_grid_desc_m_n_container_[i]))
if(!GridwiseGemm::CheckValidity(arg.grid_desc_container_[i][I0],
arg.grid_desc_container_[i][I1],
arg.grid_desc_container_[i][I2]))
{
return false;
}
......
Markdown is supported
0% Try again or attach a new file.
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment