"docs/vscode:/vscode.git/clone" did not exist on "365a938884dfcd33b2c89b814d69a08acb97de0f"
Commit 0ab4fa0f authored by rocking's avatar rocking
Browse files

Check if argument is valid

parent 400cb28e
......@@ -75,7 +75,8 @@ bool pool3d_bwd_test(bool do_verification,
std::vector<ck::index_t> dinput_right_pads)
{
using DevicePoolBwdInstance =
ck::tensor_operation::device::DeviceAvgPool3dBwdImpl<DOutDataType,
ck::tensor_operation::device::DeviceAvgPool3dBwdImpl<3,
DOutDataType,
DInDataType,
ComputeDataType, // ComputeDataType
64, // BlockSize
......
......@@ -12,7 +12,7 @@ namespace ck {
namespace tensor_operation {
namespace device {
template <typename DOutDataType, typename DInDataType>
template <index_t NDimSpatial, typename DOutDataType, typename DInDataType>
struct DeviceAvgPoolBwd : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
......
......@@ -23,7 +23,8 @@ namespace device {
// Out = AvgPoolFwd(In)
// Din = AvgPoolBwd(Dout)
// Pooling dimension = D, H, W
template <typename DOutDataType,
template <index_t NDimSpatial,
typename DOutDataType,
typename DInDataType,
typename ComputeDataType,
ck::index_t BlockSize,
......@@ -33,10 +34,8 @@ template <typename DOutDataType,
ck::index_t KThreadSliceSize,
ck::index_t InSrcOutDstVectorSize,
bool IsFastestDimReduced>
struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataType>
struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<NDimSpatial, DOutDataType, DInDataType>
{
static constexpr index_t NDimSpatial = 3;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......@@ -356,6 +355,10 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
std::vector<ck::index_t> input_right_pads)
: p_dout_grid_{p_dout},
p_din_grid_{p_din},
dout_n_c_wos_lengths_{dout_n_c_wos_lengths},
din_n_c_wos_length_{din_n_c_wos_length},
dout_n_c_wos_strides_{dout_n_c_wos_strides},
din_n_c_wos_strides_{din_n_c_wos_strides},
num_reduce_{1},
div_element_op_{window_lengths[0] * window_lengths[1] * window_lengths[2]}
{
......@@ -407,6 +410,10 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
const DOutDataType* p_dout_grid_;
DInDataType* p_din_grid_;
std::vector<ck::index_t> dout_n_c_wos_lengths_;
std::vector<ck::index_t> din_n_c_wos_length_;
std::vector<ck::index_t> dout_n_c_wos_strides_;
std::vector<ck::index_t> din_n_c_wos_strides_;
int num_reduce_;
std::vector<DoutGridDesc_M_K> dout_grid_desc_m_k_container_;
......@@ -468,7 +475,31 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
static bool IsSupportedArgument(const Argument& arg)
{
ignore = arg;
constexpr index_t Rank = NDimSpatial + 2;
int doutFastestDim = -1;
int dinFastestDim = -1;
for(int i = 0; i < Rank; ++i)
{
if(arg.dout_n_c_wos_strides_[i] == 1)
doutFastestDim = i;
if(arg.din_n_c_wos_strides_[i] == 1)
dinFastestDim = i;
}
if(doutFastestDim == -1 || dinFastestDim == -1)
{
if constexpr(InSrcOutDstVectorSize != 1)
return false;
}
else
{
if(arg.dout_n_c_wos_lengths_[doutFastestDim] % InSrcOutDstVectorSize != 0)
return false;
if(arg.din_n_c_wos_length_[dinFastestDim] % InSrcOutDstVectorSize != 0)
return false;
}
return true;
}
......@@ -490,6 +521,17 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads) override
{
constexpr index_t Rank = NDimSpatial + 2;
if(dout_n_c_wos_strides.size() != Rank || din_n_c_wos_strides.size() != Rank ||
dout_n_c_wos_lengths.size() != Rank || din_n_c_wos_length.size() != Rank)
throw std::runtime_error("dimension is incorrect");
if(window_lengths.size() != NDimSpatial || window_strides.size() != NDimSpatial ||
window_dilations.size() != NDimSpatial || input_left_pads.size() != NDimSpatial ||
input_right_pads.size() != NDimSpatial)
throw std::runtime_error("dimension is incorrect");
return std::make_unique<Argument>(static_cast<const DOutDataType*>(p_dout),
static_cast<DInDataType*>(p_din),
dout_n_c_wos_lengths,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment