Commit 2f463a94 authored by carlushuang's avatar carlushuang
Browse files

Merge remote-tracking branch 'origin/develop' into stream-k-initial-impl

parents ca8b5c79 ac9e01e2
...@@ -3,8 +3,7 @@ ...@@ -3,8 +3,7 @@
#pragma once #pragma once
#include <iostream> #include <vector>
#include <array>
#include "ck/tensor_operation/gpu/device/device_base.hpp" #include "ck/tensor_operation/gpu/device/device_base.hpp"
#include "ck/utility/reduction_enums.hpp" #include "ck/utility/reduction_enums.hpp"
...@@ -13,28 +12,33 @@ namespace ck { ...@@ -13,28 +12,33 @@ namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
template <ck::ReduceTensorOp ReduceOpId> template <index_t InOutRank,
struct DevicePool2dFwd : public BaseOperator index_t WindowRank,
typename InDataType,
typename OutDataType,
typename IndexDataType,
ReduceTensorOp ReduceOpId,
bool OutputIndex>
struct DevicePoolFwd : public BaseOperator
{ {
virtual std::unique_ptr<BaseArgument> virtual std::unique_ptr<BaseArgument>
MakeArgumentPointer(const void* in_dev, MakeArgumentPointer(const void* p_in_dev,
void* out_dev, void* p_out_dev,
void* out_indices_dev, void* p_out_indices_dev,
ck::index_t N, std::vector<ck::index_t> input_lengths,
ck::index_t C, std::vector<ck::index_t> window_lengths,
std::array<ck::index_t, 2> input_spatial_lengths, std::vector<ck::index_t> output_lengths,
std::array<ck::index_t, 2> window_spatial_lengths, std::vector<ck::index_t> input_stride,
std::array<ck::index_t, 2> output_spatial_lengths, std::vector<ck::index_t> output_stride,
std::array<ck::index_t, 2> window_strides, std::vector<ck::index_t> indices_stride,
std::array<ck::index_t, 2> input_left_pads, std::vector<ck::index_t> window_strides,
std::array<ck::index_t, 2> input_right_pads) = 0; std::vector<ck::index_t> input_left_pads,
std::vector<ck::index_t> input_right_pads,
std::vector<ck::index_t> pooling_dims) = 0;
virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0; virtual std::unique_ptr<BaseInvoker> MakeInvokerPointer() = 0;
}; };
template <ck::ReduceTensorOp ReduceOpId>
using DevicePool2dFwdPtr = std::unique_ptr<DevicePool2dFwd<ReduceOpId>>;
} // namespace device } // namespace device
} // namespace tensor_operation } // namespace tensor_operation
} // namespace ck } // namespace ck
This diff is collapsed.
This diff is collapsed.
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment