"...composable_kernel_rocm.git" did not exist on "f6cb5b846d1eff1d1e35ab58273becfd40bd0831"
Commit 158ec146 authored by rocking's avatar rocking
Browse files

Add type to base class

parent 2280e5bb
......@@ -13,7 +13,13 @@ namespace ck {
namespace tensor_operation {
namespace device {
template <index_t InOutRank, index_t WindowRank, ReduceTensorOp ReduceOpId, bool OuputIndex>
template <index_t InOutRank,
index_t WindowRank,
typename InDataType,
typename OutDataType,
typename IndexDataType,
ReduceTensorOp ReduceOpId,
bool OuputIndex>
struct DevicePoolFwd : public BaseOperator
{
virtual std::unique_ptr<BaseArgument>
......
......@@ -31,7 +31,7 @@ template <typename InDataType,
ck::index_t ReduceKThreadSliceSize,
ck::index_t InSrcOutDstVectorSize>
struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
: public DevicePoolFwd<4, 2, ReduceOpId, OuputIndex>
: public DevicePoolFwd<4, 2, InDataType, OutDataType, IndexDataType, ReduceOpId, OuputIndex>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......
......@@ -31,7 +31,7 @@ template <typename InDataType,
ck::index_t KThreadSliceSize,
ck::index_t InSrcOutDstVectorSize>
struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
: public DevicePoolFwd<5, 3, ReduceOpId, OuputIndex>
: public DevicePoolFwd<5, 3, InDataType, OutDataType, IndexDataType, ReduceOpId, OuputIndex>
{
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......
......@@ -11,7 +11,7 @@ namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
void add_device_pooling2d_fwd_nhwc_f16_instances(
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, ReduceOpId, false>>>& instances)
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F16, F16, I32, ReduceOpId, false>>>& instances)
{
add_device_operation_instances(
instances, device_pooling2d_fwd_nhwc_instances<F16, F16, I32, F32, ReduceOpId, false>{});
......
......@@ -11,7 +11,7 @@ namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
void add_device_pooling2d_fwd_nhwc_f32_instances(
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, ReduceOpId, false>>>& instances)
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F32, F32, I32, ReduceOpId, false>>>& instances)
{
add_device_operation_instances(
instances, device_pooling2d_fwd_nhwc_instances<F32, F32, I32, F32, ReduceOpId, false>{});
......
......@@ -11,7 +11,7 @@ namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
void add_device_pooling3d_fwd_ndhwc_f16_instances(
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, ReduceOpId, false>>>& instances)
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, F16, F16, I32, ReduceOpId, false>>>& instances)
{
add_device_operation_instances(
instances, device_pooling3d_fwd_ndhwc_instances<F16, F16, I32, F32, ReduceOpId, false>{});
......
......@@ -11,7 +11,7 @@ namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::AVG;
void add_device_pooling3d_fwd_ndhwc_f32_instances(
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, ReduceOpId, false>>>& instances)
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, F32, F32, I32, ReduceOpId, false>>>& instances)
{
add_device_operation_instances(
instances, device_pooling3d_fwd_ndhwc_instances<F32, F32, I32, F32, ReduceOpId, false>{});
......
......@@ -11,14 +11,14 @@ namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
void add_device_pooling2d_fwd_nhwc_f16_instances(
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, ReduceOpId, false>>>& instances)
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F16, F16, I32, ReduceOpId, false>>>& instances)
{
add_device_operation_instances(
instances, device_pooling2d_fwd_nhwc_instances<F16, F16, I32, F16, ReduceOpId, false>{});
}
void add_device_pooling2d_fwd_nhwc_index_f16_instances(
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, ReduceOpId, true>>>& instances)
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F16, F16, I32, ReduceOpId, true>>>& instances)
{
add_device_operation_instances(
instances, device_pooling2d_fwd_nhwc_instances<F16, F16, I32, F16, ReduceOpId, true>{});
......
......@@ -11,14 +11,14 @@ namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
void add_device_pooling2d_fwd_nhwc_f32_instances(
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, ReduceOpId, false>>>& instances)
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F32, F32, I32, ReduceOpId, false>>>& instances)
{
add_device_operation_instances(
instances, device_pooling2d_fwd_nhwc_instances<F32, F32, I32, F32, ReduceOpId, false>{});
}
void add_device_pooling2d_fwd_nhwc_index_f32_instances(
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, ReduceOpId, true>>>& instances)
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, F32, F32, I32, ReduceOpId, true>>>& instances)
{
add_device_operation_instances(
instances, device_pooling2d_fwd_nhwc_instances<F32, F32, I32, F32, ReduceOpId, true>{});
......
......@@ -11,14 +11,14 @@ namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
void add_device_pooling3d_fwd_ndhwc_f16_instances(
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, ReduceOpId, false>>>& instances)
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, F16, F16, I32, ReduceOpId, false>>>& instances)
{
add_device_operation_instances(
instances, device_pooling3d_fwd_ndhwc_instances<F16, F16, I32, F16, ReduceOpId, false>{});
}
void add_device_pooling3d_fwd_ndhwc_index_f16_instances(
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, ReduceOpId, true>>>& instances)
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, F16, F16, I32, ReduceOpId, true>>>& instances)
{
add_device_operation_instances(
instances, device_pooling3d_fwd_ndhwc_instances<F16, F16, I32, F16, ReduceOpId, true>{});
......
......@@ -11,14 +11,14 @@ namespace instance {
static constexpr auto ReduceOpId = ck::ReduceTensorOp::MAX;
void add_device_pooling3d_fwd_ndhwc_f32_instances(
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, ReduceOpId, false>>>& instances)
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, F32, F32, I32, ReduceOpId, false>>>& instances)
{
add_device_operation_instances(
instances, device_pooling3d_fwd_ndhwc_instances<F32, F32, I32, F32, ReduceOpId, false>{});
}
void add_device_pooling3d_fwd_ndhwc_index_f32_instances(
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, ReduceOpId, true>>>& instances)
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, F32, F32, I32, ReduceOpId, true>>>& instances)
{
add_device_operation_instances(
instances, device_pooling3d_fwd_ndhwc_instances<F32, F32, I32, F32, ReduceOpId, true>{});
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment