Unverified Commit b305a29e authored by rocking's avatar rocking Committed by GitHub
Browse files

Remove index tensor in avgpool (#1093)



* Remove index tensor

* fix syntax

---------
Co-authored-by: default avatarIllia Silin <98187287+illsilin@users.noreply.github.com>
Co-authored-by: default avatarillsilin <Illia.Silin@amd.com>
parent a167e3c7
...@@ -94,7 +94,6 @@ int main(int argc, char* argv[]) ...@@ -94,7 +94,6 @@ int main(int argc, char* argv[])
SimpleDeviceMem in_device_buf(sizeof(InDataType) * in_tensor_size); SimpleDeviceMem in_device_buf(sizeof(InDataType) * in_tensor_size);
SimpleDeviceMem out_device_buf(sizeof(OutDataType) * out_tensor_size); SimpleDeviceMem out_device_buf(sizeof(OutDataType) * out_tensor_size);
SimpleDeviceMem out_indices_device_buf(sizeof(IndexDataType) * out_tensor_size);
using DeviceOp = ck::tensor_operation::device::DevicePoolFwd<InOutRank, using DeviceOp = ck::tensor_operation::device::DevicePoolFwd<InOutRank,
WindowRank, WindowRank,
...@@ -124,10 +123,10 @@ int main(int argc, char* argv[]) ...@@ -124,10 +123,10 @@ int main(int argc, char* argv[])
for(int i = 0; i < op_ptrs.size(); ++i) for(int i = 0; i < op_ptrs.size(); ++i)
{ {
auto& op_ptr = op_ptrs[i]; auto& op_ptr = op_ptrs[i];
auto argument_ptr = op_ptr->MakeArgumentPointer( auto argument_ptr =
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()), op_ptr->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()), static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
static_cast<IndexDataType*>(out_indices_device_buf.GetDeviceBuffer()), nullptr,
in_length, in_length,
window_spatial_lengths, window_spatial_lengths,
out_length, out_length,
...@@ -184,10 +183,10 @@ int main(int argc, char* argv[]) ...@@ -184,10 +183,10 @@ int main(int argc, char* argv[])
std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString() std::cout << "Run the best instance without timing: " << op_ptr->GetTypeString()
<< std::endl; << std::endl;
auto argument_ptr = op_ptr->MakeArgumentPointer( auto argument_ptr =
static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()), op_ptr->MakeArgumentPointer(static_cast<InDataType*>(in_device_buf.GetDeviceBuffer()),
static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()), static_cast<OutDataType*>(out_device_buf.GetDeviceBuffer()),
static_cast<IndexDataType*>(out_indices_device_buf.GetDeviceBuffer()), nullptr,
in_length, in_length,
window_spatial_lengths, window_spatial_lengths,
out_length, out_length,
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment