Commit 0ab4fa0f authored by rocking's avatar rocking
Browse files

Check if argument is valid

parent 400cb28e
...@@ -75,7 +75,8 @@ bool pool3d_bwd_test(bool do_verification, ...@@ -75,7 +75,8 @@ bool pool3d_bwd_test(bool do_verification,
std::vector<ck::index_t> dinput_right_pads) std::vector<ck::index_t> dinput_right_pads)
{ {
using DevicePoolBwdInstance = using DevicePoolBwdInstance =
ck::tensor_operation::device::DeviceAvgPool3dBwdImpl<DOutDataType, ck::tensor_operation::device::DeviceAvgPool3dBwdImpl<3,
DOutDataType,
DInDataType, DInDataType,
ComputeDataType, // ComputeDataType ComputeDataType, // ComputeDataType
64, // BlockSize 64, // BlockSize
......
...@@ -12,7 +12,7 @@ namespace ck { ...@@ -12,7 +12,7 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <typename DOutDataType, typename DInDataType> template <index_t NDimSpatial, typename DOutDataType, typename DInDataType>
struct DeviceAvgPoolBwd : public BaseOperator struct DeviceAvgPoolBwd : public BaseOperator
{ {
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
......
...@@ -23,7 +23,8 @@ namespace device { ...@@ -23,7 +23,8 @@ namespace device {
// Out = AvgPoolFwd(In) // Out = AvgPoolFwd(In)
// Din = AvgPoolBwd(Dout) // Din = AvgPoolBwd(Dout)
// Pooling dimension = D, H, W // Pooling dimension = D, H, W
template <typename DOutDataType, template <index_t NDimSpatial,
typename DOutDataType,
typename DInDataType, typename DInDataType,
typename ComputeDataType, typename ComputeDataType,
ck::index_t BlockSize, ck::index_t BlockSize,
...@@ -33,10 +34,8 @@ template <typename DOutDataType, ...@@ -33,10 +34,8 @@ template <typename DOutDataType,
ck::index_t KThreadSliceSize, ck::index_t KThreadSliceSize,
ck::index_t InSrcOutDstVectorSize, ck::index_t InSrcOutDstVectorSize,
bool IsFastestDimReduced> bool IsFastestDimReduced>
struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataType> struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<NDimSpatial, DOutDataType, DInDataType>
{ {
static constexpr index_t NDimSpatial = 3;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
...@@ -356,6 +355,10 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp ...@@ -356,6 +355,10 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
std::vector<ck::index_t> input_right_pads) std::vector<ck::index_t> input_right_pads)
: p_dout_grid_{p_dout}, : p_dout_grid_{p_dout},
p_din_grid_{p_din}, p_din_grid_{p_din},
dout_n_c_wos_lengths_{dout_n_c_wos_lengths},
din_n_c_wos_length_{din_n_c_wos_length},
dout_n_c_wos_strides_{dout_n_c_wos_strides},
din_n_c_wos_strides_{din_n_c_wos_strides},
num_reduce_{1}, num_reduce_{1},
div_element_op_{window_lengths[0] * window_lengths[1] * window_lengths[2]} div_element_op_{window_lengths[0] * window_lengths[1] * window_lengths[2]}
{ {
...@@ -407,6 +410,10 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp ...@@ -407,6 +410,10 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
const DOutDataType* p_dout_grid_; const DOutDataType* p_dout_grid_;
DInDataType* p_din_grid_; DInDataType* p_din_grid_;
std::vector<ck::index_t> dout_n_c_wos_lengths_;
std::vector<ck::index_t> din_n_c_wos_length_;
std::vector<ck::index_t> dout_n_c_wos_strides_;
std::vector<ck::index_t> din_n_c_wos_strides_;
int num_reduce_; int num_reduce_;
std::vector<DoutGridDesc_M_K> dout_grid_desc_m_k_container_; std::vector<DoutGridDesc_M_K> dout_grid_desc_m_k_container_;
...@@ -468,7 +475,31 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp ...@@ -468,7 +475,31 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
static bool IsSupportedArgument(const Argument& arg) static bool IsSupportedArgument(const Argument& arg)
{ {
ignore = arg; constexpr index_t Rank = NDimSpatial + 2;
int doutFastestDim = -1;
int dinFastestDim = -1;
for(int i = 0; i < Rank; ++i)
{
if(arg.dout_n_c_wos_strides_[i] == 1)
doutFastestDim = i;
if(arg.din_n_c_wos_strides_[i] == 1)
dinFastestDim = i;
}
if(doutFastestDim == -1 || dinFastestDim == -1)
{
if constexpr(InSrcOutDstVectorSize != 1)
return false;
}
else
{
if(arg.dout_n_c_wos_lengths_[doutFastestDim] % InSrcOutDstVectorSize != 0)
return false;
if(arg.din_n_c_wos_length_[dinFastestDim] % InSrcOutDstVectorSize != 0)
return false;
}
return true; return true;
} }
...@@ -490,6 +521,17 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp ...@@ -490,6 +521,17 @@ struct DeviceAvgPool3dBwdImpl : public DeviceAvgPoolBwd<DOutDataType, DInDataTyp
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads) override std::vector<ck::index_t> input_right_pads) override
{ {
constexpr index_t Rank = NDimSpatial + 2;
if(dout_n_c_wos_strides.size() != Rank || din_n_c_wos_strides.size() != Rank ||
dout_n_c_wos_lengths.size() != Rank || din_n_c_wos_length.size() != Rank)
throw std::runtime_error("dimension is incorrect");
if(window_lengths.size() != NDimSpatial || window_strides.size() != NDimSpatial ||
window_dilations.size() != NDimSpatial || input_left_pads.size() != NDimSpatial ||
input_right_pads.size() != NDimSpatial)
throw std::runtime_error("dimension is incorrect");
return std::make_unique<Argument>(static_cast<const DOutDataType*>(p_dout), return std::make_unique<Argument>(static_cast<const DOutDataType*>(p_dout),
static_cast<DInDataType*>(p_din), static_cast<DInDataType*>(p_din),
dout_n_c_wos_lengths, dout_n_c_wos_lengths,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment