Commit 3c8d9843 authored by rocking's avatar rocking
Browse files

Add index stride and output stride

parent 102b4922
......@@ -116,8 +116,10 @@ bool pool_test(bool do_verification,
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
static_cast<IndexDataType*>(out_indices_device_buf.GetDeviceBuffer()),
{N, C, Hi, Wi},
{C * Hi * Wi, 1, Wi * C, C},
{C * Ho * Wo, 1, Wo * C, C},
{C * Ho * Wo, 1, Wo * C, C},
{N, C, Hi, Wi},
{Y, X},
{N, C, Ho, Wo},
window_strides,
......
......@@ -124,8 +124,10 @@ bool pool3d_test(bool do_verification,
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
static_cast<IndexDataType*>(out_indices_device_buf.GetDeviceBuffer()),
{N, C, Di, Hi, Wi},
{Di * C * Hi * Wi, 1, C * Hi * Wi, Wi * C, C},
{Do * C * Ho * Wo, 1, C * Ho * Wo, Wo * C, C},
{Do * C * Ho * Wo, 1, C * Ho * Wo, Wo * C, C},
{N, C, Di, Hi, Wi},
{Z, Y, X},
{N, C, Do, Ho, Wo},
window_strides,
......
......@@ -20,8 +20,10 @@ struct DevicePoolFwd : public BaseOperator
MakeArgumentPointer(const void* p_in_dev,
void* p_out_dev,
void* p_out_indices_dev,
std::array<ck::index_t, InOutRank> input_lengths,
std::array<ck::index_t, InOutRank> input_stride,
std::array<ck::index_t, InOutRank> output_stride,
std::array<ck::index_t, InOutRank> indices_stride,
std::array<ck::index_t, InOutRank> input_lengths,
std::array<ck::index_t, WindowRank> window_lengths,
std::array<ck::index_t, InOutRank> output_lengths,
std::array<ck::index_t, WindowRank> window_strides,
......
......@@ -286,6 +286,9 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
MakeArgumentPointer(const void* p_in_dev,
void* p_out_dev,
void* p_out_indices_dev,
std::array<ck::index_t, InOutRank>, // Suppose tensor layout = NHWC
std::array<ck::index_t, InOutRank>, // Suppose tensor layout = NHWC
std::array<ck::index_t, InOutRank>, // Suppose tensor layout = NHWC
std::array<ck::index_t, InOutRank> input_lengths,
std::array<ck::index_t, WindowRank> window_lengths,
std::array<ck::index_t, InOutRank> output_lengths,
......
......@@ -291,6 +291,9 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
MakeArgumentPointer(const void* p_in_dev,
void* p_out_dev,
void* p_out_indices_dev,
std::array<ck::index_t, InOutRank>, // Suppose tensor layout = NDHWC
std::array<ck::index_t, InOutRank>, // Suppose tensor layout = NDHWC
std::array<ck::index_t, InOutRank>, // Suppose tensor layout = NDHWC
std::array<ck::index_t, InOutRank> input_lengths,
std::array<ck::index_t, WindowRank> window_lengths,
std::array<ck::index_t, InOutRank> output_lengths,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment