Commit 055acace authored by rocking's avatar rocking
Browse files

Imitate the argument from conv bwd

parent 55e420ec
...@@ -131,6 +131,7 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp ...@@ -131,6 +131,7 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde); const auto YDotSlice = math::integer_divide_ceil(Y - i_ytilde, YTilde);
const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde); const auto XDotSlice = math::integer_divide_ceil(X - i_xtilde, XTilde);
// Problem size of reduction kernel
const index_t MRaw = N * DTildeSlice * HTildeSlice * WTildeSlice * C; const index_t MRaw = N * DTildeSlice * HTildeSlice * WTildeSlice * C;
const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw; const index_t MPad = math::integer_least_multiple(MRaw, M_BlockTileSize) - MRaw;
...@@ -293,6 +294,20 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp ...@@ -293,6 +294,20 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
return make_tuple(out_grid_desc_reducem_reducek, in_grid_desc_reducem); return make_tuple(out_grid_desc_reducem_reducek, in_grid_desc_reducem);
} }
using DoutDinGridDesc = decltype(Make3DGridDescriptor_Out_M_K_In_M({0, 0, 0, 0, 0},
{0, 0, 0, 0, 0},
{0, 0, 0, 0, 0},
{0, 0, 0, 0, 0},
{0, 0, 0},
{0, 0, 0},
{0, 0, 0},
{0, 0, 0},
{0, 0, 0},
{0, 0, 0}));
using DoutGridDesc_M_K = remove_cvref_t<tuple_element_t<0, DoutDinGridDesc>>;
using DinGridDesc_M = remove_cvref_t<tuple_element_t<1, DoutDinGridDesc>>;
struct Argument : public BaseArgument struct Argument : public BaseArgument
{ {
Argument(const DOutDataType* p_dout, Argument(const DOutDataType* p_dout,
...@@ -308,18 +323,6 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp ...@@ -308,18 +323,6 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
std::vector<ck::index_t> input_right_pads) std::vector<ck::index_t> input_right_pads)
: p_dout_grid_{p_dout}, p_din_grid_{p_din}, num_reduce_{1} : p_dout_grid_{p_dout}, p_din_grid_{p_din}, num_reduce_{1}
{ {
ignore = p_dout;
ignore = p_din;
ignore = dout_n_c_wos_lengths;
ignore = dout_n_c_wos_strides;
ignore = din_n_c_wos_length;
ignore = din_n_c_wos_strides;
ignore = window_lengths;
ignore = window_strides;
ignore = window_dilations;
ignore = input_left_pads;
ignore = input_right_pads;
std::vector<ck::index_t> Tildes(NDimSpatial); std::vector<ck::index_t> Tildes(NDimSpatial);
for(int i = 0; i < NDimSpatial; ++i) for(int i = 0; i < NDimSpatial; ++i)
{ {
...@@ -346,8 +349,35 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp ...@@ -346,8 +349,35 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
{ {
continue; continue;
} }
const auto dout_din_grid_desc =
Make3DGridDescriptor_Out_M_K_In_M(dout_n_c_wos_lengths,
din_n_c_wos_length,
dout_n_c_wos_strides,
din_n_c_wos_strides,
window_lengths,
window_strides,
window_dilations,
input_left_pads,
input_right_pads,
{i_ztilde, i_ytilde, i_xtilde});
dout_grid_desc_m_k_container_.push_back(dout_din_grid_desc[I0]);
din_grid_desc_m_container_.push_back(dout_din_grid_desc[I1]);
}
}
} }
} }
void Print() const
{
for(index_t i = 0; i < num_reduce_; i++)
{
std::cout << "dout_grid_desc_m_k_container_" << dout_grid_desc_m_k_container_[i]
<< std::endl;
std::cout << "din_grid_desc_m_container_" << din_grid_desc_m_container_[i]
<< std::endl;
} }
} }
...@@ -356,6 +386,8 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp ...@@ -356,6 +386,8 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
DInDataType* p_din_grid_; DInDataType* p_din_grid_;
int num_reduce_; int num_reduce_;
std::vector<DoutGridDesc_M_K> dout_grid_desc_m_k_container_;
std::vector<DinGridDesc_M> din_grid_desc_m_container_;
}; };
struct Invoker : public BaseInvoker struct Invoker : public BaseInvoker
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment