Commit d22124dc authored by rocking's avatar rocking
Browse files

Check the pooling dim

parent fe8ed3d2
...@@ -294,13 +294,16 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C ...@@ -294,13 +294,16 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
std::vector<ck::index_t> window_strides, std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> input_right_pads,
std::vector<ck::index_t>) override std::vector<ck::index_t> pooling_dims) override
{ {
if(input_lengths.size() != InOutRank || window_lengths.size() != WindowRank || if(input_lengths.size() != InOutRank || window_lengths.size() != WindowRank ||
input_lengths.size() != InOutRank || window_strides.size() != WindowRank || input_lengths.size() != InOutRank || window_strides.size() != WindowRank ||
input_left_pads.size() != WindowRank || input_right_pads.size() != WindowRank) input_left_pads.size() != WindowRank || input_right_pads.size() != WindowRank)
throw std::runtime_error("dimension is incorrect"); throw std::runtime_error("dimension is incorrect");
if(pooling_dims != std::vector<ck::index_t>{2, 3})
throw std::runtime_error("pooling_dims only support {2, 3} in pool2d so far");
index_t N = input_lengths[0]; index_t N = input_lengths[0];
index_t C = input_lengths[1]; index_t C = input_lengths[1];
index_t Hi = input_lengths[2]; index_t Hi = input_lengths[2];
......
...@@ -299,13 +299,16 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C ...@@ -299,13 +299,16 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
std::vector<ck::index_t> window_strides, std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> input_right_pads,
std::vector<ck::index_t>) override std::vector<ck::index_t> pooling_dims) override
{ {
if(input_lengths.size() != InOutRank || window_lengths.size() != WindowRank || if(input_lengths.size() != InOutRank || window_lengths.size() != WindowRank ||
input_lengths.size() != InOutRank || window_strides.size() != WindowRank || input_lengths.size() != InOutRank || window_strides.size() != WindowRank ||
input_left_pads.size() != WindowRank || input_right_pads.size() != WindowRank) input_left_pads.size() != WindowRank || input_right_pads.size() != WindowRank)
throw std::runtime_error("dimension is incorrect"); throw std::runtime_error("dimension is incorrect");
if(pooling_dims != std::vector<ck::index_t>{2, 3, 4})
throw std::runtime_error("pooling_dims only support {2, 3, 4} in pool3d so far");
index_t N = input_lengths[0]; index_t N = input_lengths[0];
index_t C = input_lengths[1]; index_t C = input_lengths[1];
index_t Di = input_lengths[2]; index_t Di = input_lengths[2];
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment