Commit 4eb307e9 authored by rocking's avatar rocking
Browse files

Add tensor stride

parent 9621381f
......@@ -117,6 +117,7 @@ bool pool_test(bool do_verification,
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
static_cast<IndexDataType*>(out_indices_device_buf.GetDeviceBuffer()),
{N, C, Hi, Wi},
{C * Hi * Wi, 1, Wi * C, C},
{Y, X},
{N, C, Ho, Wo},
window_strides,
......
......@@ -125,6 +125,7 @@ bool pool3d_test(bool do_verification,
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
static_cast<IndexDataType*>(out_indices_device_buf.GetDeviceBuffer()),
{N, C, Di, Hi, Wi},
{Di * C * Hi * Wi, 1, C * Hi * Wi, Wi * C, C},
{Z, Y, X},
{N, C, Do, Ho, Wo},
window_strides,
......
......@@ -21,6 +21,7 @@ struct DevicePoolFwd : public BaseOperator
void* p_out_dev,
void* p_out_indices_dev,
std::array<ck::index_t, InOutRank> input_lengths,
std::array<ck::index_t, InOutRank> input_stride,
std::array<ck::index_t, WindowRank> window_lengths,
std::array<ck::index_t, InOutRank> output_lengths,
std::array<ck::index_t, WindowRank> window_strides,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment