Commit 91e37486 authored by rocking's avatar rocking
Browse files

Adjust the order of device_pool_fwd

parent 3fa7f1eb
...@@ -115,12 +115,12 @@ int main(int argc, char* argv[]) ...@@ -115,12 +115,12 @@ int main(int argc, char* argv[])
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()), static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()), static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
static_cast<IndexDataType*>(out_indices_device_buf.GetDeviceBuffer()), static_cast<IndexDataType*>(out_indices_device_buf.GetDeviceBuffer()),
in_tensor_stride,
out_tensor_stride,
out_tensor_stride,
in_length, in_length,
window_spatial_lengths, window_spatial_lengths,
out_length, out_length,
in_tensor_stride,
out_tensor_stride,
out_tensor_stride,
window_strides, window_strides,
input_left_pads, input_left_pads,
input_right_pads, input_right_pads,
...@@ -174,12 +174,12 @@ int main(int argc, char* argv[]) ...@@ -174,12 +174,12 @@ int main(int argc, char* argv[])
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()), static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()), static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
static_cast<IndexDataType*>(out_indices_device_buf.GetDeviceBuffer()), static_cast<IndexDataType*>(out_indices_device_buf.GetDeviceBuffer()),
in_tensor_stride,
out_tensor_stride,
out_tensor_stride,
in_length, in_length,
window_spatial_lengths, window_spatial_lengths,
out_length, out_length,
in_tensor_stride,
out_tensor_stride,
out_tensor_stride,
window_strides, window_strides,
input_left_pads, input_left_pads,
input_right_pads, input_right_pads,
......
...@@ -109,12 +109,12 @@ int main(int argc, char* argv[]) ...@@ -109,12 +109,12 @@ int main(int argc, char* argv[])
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()), static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()), static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
static_cast<IndexDataType*>(out_indices_device_buf.GetDeviceBuffer()), static_cast<IndexDataType*>(out_indices_device_buf.GetDeviceBuffer()),
in_tensor_stride,
out_tensor_stride,
out_tensor_stride,
in_length, in_length,
window_spatial_lengths, window_spatial_lengths,
out_length, out_length,
in_tensor_stride,
out_tensor_stride,
out_tensor_stride,
window_strides, window_strides,
input_left_pads, input_left_pads,
input_right_pads, input_right_pads,
...@@ -168,12 +168,12 @@ int main(int argc, char* argv[]) ...@@ -168,12 +168,12 @@ int main(int argc, char* argv[])
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()), static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()), static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
static_cast<IndexDataType*>(out_indices_device_buf.GetDeviceBuffer()), static_cast<IndexDataType*>(out_indices_device_buf.GetDeviceBuffer()),
in_tensor_stride,
out_tensor_stride,
out_tensor_stride,
in_length, in_length,
window_spatial_lengths, window_spatial_lengths,
out_length, out_length,
in_tensor_stride,
out_tensor_stride,
out_tensor_stride,
window_strides, window_strides,
input_left_pads, input_left_pads,
input_right_pads, input_right_pads,
......
...@@ -116,12 +116,12 @@ bool pool_test(bool do_verification, ...@@ -116,12 +116,12 @@ bool pool_test(bool do_verification,
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()), static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()), static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
static_cast<IndexDataType*>(out_indices_device_buf.GetDeviceBuffer()), static_cast<IndexDataType*>(out_indices_device_buf.GetDeviceBuffer()),
{C * Hi * Wi, 1, Wi * C, C},
{C * Ho * Wo, 1, Wo * C, C},
{C * Ho * Wo, 1, Wo * C, C},
{N, C, Hi, Wi}, {N, C, Hi, Wi},
{Y, X}, {Y, X},
{N, C, Ho, Wo}, {N, C, Ho, Wo},
{C * Hi * Wi, 1, Wi * C, C},
{C * Ho * Wo, 1, Wo * C, C},
{C * Ho * Wo, 1, Wo * C, C},
window_strides, window_strides,
input_left_pads, input_left_pads,
input_right_pads, input_right_pads,
......
...@@ -123,12 +123,12 @@ bool pool3d_test(bool do_verification, ...@@ -123,12 +123,12 @@ bool pool3d_test(bool do_verification,
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()), static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()), static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
static_cast<IndexDataType*>(out_indices_device_buf.GetDeviceBuffer()), static_cast<IndexDataType*>(out_indices_device_buf.GetDeviceBuffer()),
{Di * C * Hi * Wi, 1, C * Hi * Wi, Wi * C, C},
{Do * C * Ho * Wo, 1, C * Ho * Wo, Wo * C, C},
{Do * C * Ho * Wo, 1, C * Ho * Wo, Wo * C, C},
{N, C, Di, Hi, Wi}, {N, C, Di, Hi, Wi},
{Z, Y, X}, {Z, Y, X},
{N, C, Do, Ho, Wo}, {N, C, Do, Ho, Wo},
{Di * C * Hi * Wi, 1, C * Hi * Wi, Wi * C, C},
{Do * C * Ho * Wo, 1, C * Ho * Wo, Wo * C, C},
{Do * C * Ho * Wo, 1, C * Ho * Wo, Wo * C, C},
window_strides, window_strides,
input_left_pads, input_left_pads,
input_right_pads, input_right_pads,
......
...@@ -25,12 +25,12 @@ struct DevicePoolFwd : public BaseOperator ...@@ -25,12 +25,12 @@ struct DevicePoolFwd : public BaseOperator
MakeArgumentPointer(const void* p_in_dev, MakeArgumentPointer(const void* p_in_dev,
void* p_out_dev, void* p_out_dev,
void* p_out_indices_dev, void* p_out_indices_dev,
std::vector<ck::index_t> input_stride,
std::vector<ck::index_t> output_stride,
std::vector<ck::index_t> indices_stride,
std::vector<ck::index_t> input_lengths, std::vector<ck::index_t> input_lengths,
std::vector<ck::index_t> window_lengths, std::vector<ck::index_t> window_lengths,
std::vector<ck::index_t> output_lengths, std::vector<ck::index_t> output_lengths,
std::vector<ck::index_t> input_stride,
std::vector<ck::index_t> output_stride,
std::vector<ck::index_t> indices_stride,
std::vector<ck::index_t> window_strides, std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> input_right_pads,
......
...@@ -144,9 +144,7 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C ...@@ -144,9 +144,7 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem); return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem);
} }
using ABGridDescs = decltype( using ABGridDescs = decltype(MakeABGridDescriptor_A_M_K_B_M(1, 1, {}, {}, {}, {}, {}, {}));
MakeABGridDescriptor_A_M_K_B_M(1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}));
using AGridDesc_M_K = remove_cvref_t<decltype(ABGridDescs{}[I0])>; using AGridDesc_M_K = remove_cvref_t<decltype(ABGridDescs{}[I0])>;
using BGridDesc_M = remove_cvref_t<decltype(ABGridDescs{}[I1])>; using BGridDesc_M = remove_cvref_t<decltype(ABGridDescs{}[I1])>;
...@@ -285,12 +283,12 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C ...@@ -285,12 +283,12 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
MakeArgumentPointer(const void* p_in_dev, MakeArgumentPointer(const void* p_in_dev,
void* p_out_dev, void* p_out_dev,
void* p_out_indices_dev, void* p_out_indices_dev,
std::vector<ck::index_t>, // Suppose tensor layout = NHWC
std::vector<ck::index_t>, // Suppose tensor layout = NHWC
std::vector<ck::index_t>, // Suppose tensor layout = NHWC
std::vector<ck::index_t> input_lengths, std::vector<ck::index_t> input_lengths,
std::vector<ck::index_t> window_lengths, std::vector<ck::index_t> window_lengths,
std::vector<ck::index_t> output_lengths, std::vector<ck::index_t> output_lengths,
std::vector<ck::index_t>, // Suppose tensor layout = NHWC
std::vector<ck::index_t>, // Suppose tensor layout = NHWC
std::vector<ck::index_t>, // Suppose tensor layout = NHWC
std::vector<ck::index_t> window_strides, std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> input_right_pads,
......
...@@ -151,9 +151,7 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C ...@@ -151,9 +151,7 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem); return make_tuple(in_grid_desc_reducem_reducek, out_grid_desc_reducem);
} }
using ABGridDescs = decltype( using ABGridDescs = decltype(MakeABGridDescriptor_A_M_K_B_M(1, 1, {}, {}, {}, {}, {}, {}));
MakeABGridDescriptor_A_M_K_B_M(1, 1, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}, {1, 1}));
using AGridDesc_M_K = remove_cvref_t<decltype(ABGridDescs{}[I0])>; using AGridDesc_M_K = remove_cvref_t<decltype(ABGridDescs{}[I0])>;
using BGridDesc_M = remove_cvref_t<decltype(ABGridDescs{}[I1])>; using BGridDesc_M = remove_cvref_t<decltype(ABGridDescs{}[I1])>;
...@@ -290,12 +288,12 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C ...@@ -290,12 +288,12 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
MakeArgumentPointer(const void* p_in_dev, MakeArgumentPointer(const void* p_in_dev,
void* p_out_dev, void* p_out_dev,
void* p_out_indices_dev, void* p_out_indices_dev,
std::vector<ck::index_t>, // Suppose tensor layout = NDHWC
std::vector<ck::index_t>, // Suppose tensor layout = NDHWC
std::vector<ck::index_t>, // Suppose tensor layout = NDHWC
std::vector<ck::index_t> input_lengths, std::vector<ck::index_t> input_lengths,
std::vector<ck::index_t> window_lengths, std::vector<ck::index_t> window_lengths,
std::vector<ck::index_t> output_lengths, std::vector<ck::index_t> output_lengths,
std::vector<ck::index_t>, // Suppose tensor layout = NDHWC
std::vector<ck::index_t>, // Suppose tensor layout = NDHWC
std::vector<ck::index_t>, // Suppose tensor layout = NDHWC
std::vector<ck::index_t> window_strides, std::vector<ck::index_t> window_strides,
std::vector<ck::index_t> input_left_pads, std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads, std::vector<ck::index_t> input_right_pads,
......
...@@ -145,12 +145,12 @@ bool profile_pool2d_fwd_impl(int do_verification, ...@@ -145,12 +145,12 @@ bool profile_pool2d_fwd_impl(int do_verification,
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()), static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()), static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
static_cast<IndexDataType*>(out_indices_device_buf.GetDeviceBuffer()), static_cast<IndexDataType*>(out_indices_device_buf.GetDeviceBuffer()),
{C * Hi * Wi, 1, Wi * C, C},
{C * Ho * Wo, 1, Wo * C, C},
{C * Ho * Wo, 1, Wo * C, C},
in_length, in_length,
window_spatial_lengths, window_spatial_lengths,
out_length, out_length,
{C * Hi * Wi, 1, Wi * C, C},
{C * Ho * Wo, 1, Wo * C, C},
{C * Ho * Wo, 1, Wo * C, C},
window_strides, window_strides,
input_left_pads, input_left_pads,
input_right_pads, input_right_pads,
......
...@@ -150,12 +150,12 @@ bool profile_pool3d_fwd_impl(int do_verification, ...@@ -150,12 +150,12 @@ bool profile_pool3d_fwd_impl(int do_verification,
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()), static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()), static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
static_cast<IndexDataType*>(out_indices_device_buf.GetDeviceBuffer()), static_cast<IndexDataType*>(out_indices_device_buf.GetDeviceBuffer()),
{Di * C * Hi * Wi, 1, C * Hi * Wi, Wi * C, C},
{Do * C * Ho * Wo, 1, C * Ho * Wo, Wo * C, C},
{Do * C * Ho * Wo, 1, C * Ho * Wo, Wo * C, C},
in_length, in_length,
window_spatial_lengths, window_spatial_lengths,
out_length, out_length,
{Di * C * Hi * Wi, 1, C * Hi * Wi, Wi * C, C},
{Do * C * Ho * Wo, 1, C * Ho * Wo, Wo * C, C},
{Do * C * Ho * Wo, 1, C * Ho * Wo, Wo * C, C},
window_strides, window_strides,
input_left_pads, input_left_pads,
input_right_pads, input_right_pads,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment