"...composable_kernel_rocm.git" did not exist on "5311d1b32579e0427954d0ecba3a44280b75c165"
Unverified Commit b305a29e authored by rocking's avatar rocking Committed by GitHub
Browse files

Remove index tensor in avgpool (#1093)



* Remove index tensor

* fix syntax

---------
Co-authored-by: default avatarIllia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: default avatarillsilin <Illia.Silin@amd.com>
parent a167e3c7
...@@ -94,7 +94,6 @@ int main(int argc, char* argv[]) ...@@ -94,7 +94,6 @@ int main(int argc, char* argv[])
SimpleDeviceMem in_device_buf(sizeof(InDataType) * in_tensor_size); SimpleDeviceMem in_device_buf(sizeof(InDataType) * in_tensor_size);
SimpleDeviceMem out_device_buf(sizeof(OutDataType) * out_tensor_size); SimpleDeviceMem out_device_buf(sizeof(OutDataType) * out_tensor_size);
SimpleDeviceMem out_indices_device_buf(sizeof(IndexDataType) * out_tensor_size);
using DeviceOp = ck::tensor_operation::device::DevicePoolFwd<InOutRank, using DeviceOp = ck::tensor_operation::device::DevicePoolFwd<InOutRank,
WindowRank, WindowRank,
...@@ -123,22 +122,22 @@ int main(int argc, char* argv[]) ...@@ -123,22 +122,22 @@ int main(int argc, char* argv[])
for(int i = 0; i < op_ptrs.size(); ++i) for(int i = 0; i < op_ptrs.size(); ++i)
{ {
auto& op_ptr = op_ptrs[i]; auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer( auto argument_ptr =
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()), op_ptr->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()), static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
static_cast<IndexDataType*>(out_indices_device_buf.GetDeviceBuffer()), nullptr,
in_length, in_length,
window_spatial_lengths, window_spatial_lengths,
out_length, out_length,
in_tensor_stride, in_tensor_stride,
out_tensor_stride, out_tensor_stride,
out_tensor_stride, out_tensor_stride,
window_strides, window_strides,
window_dilations, window_dilations,
input_left_pads, input_left_pads,
input_right_pads, input_right_pads,
{2, 3, 4}); {2, 3, 4});
auto invoker_ptr = op_ptr->MakeInvokerPointer(); auto invoker_ptr = op_ptr->MakeInvokerPointer();
...@@ -184,21 +183,21 @@ int main(int argc, char* argv[]) ...@@ -184,21 +183,21 @@ int main(int argc, char* argv[])
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl; << std::endl;
auto argument_ptr = op_ptr->MakeArgumentPointer( auto argument_ptr =
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()), op_ptr->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()), static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
static_cast<IndexDataType*>(out_indices_device_buf.GetDeviceBuffer()), nullptr,
in_length, in_length,
window_spatial_lengths, window_spatial_lengths,
out_length, out_length,
in_tensor_stride, in_tensor_stride,
out_tensor_stride, out_tensor_stride,
out_tensor_stride, out_tensor_stride,
window_strides, window_strides,
window_dilations, window_dilations,
input_left_pads, input_left_pads,
input_right_pads, input_right_pads,
{2, 3, 4}); {2, 3, 4});
auto invoker_ptr = op_ptr->MakeInvokerPointer(); auto invoker_ptr = op_ptr->MakeInvokerPointer();
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment