Commit 7b833910 authored by rocking's avatar rocking
Browse files

Extract IndexDataType to template

parent 58e912d8
...@@ -150,9 +150,10 @@ bool pool_test(bool do_verification, ...@@ -150,9 +150,10 @@ bool pool_test(bool do_verification,
{ {
using DevicePoolFwdInstance = using DevicePoolFwdInstance =
ck::tensor_operation::device::DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C< ck::tensor_operation::device::DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<
InDataType, // InDataType InDataType, // InDataType
OutDataType, // OutDataType OutDataType, // OutDataType
AccDataType, // AccDataType IndexDataType, // IndexDataType
AccDataType, // AccDataType
ReduceOpId, ReduceOpId,
OutputIndex, OutputIndex,
64, // BlockSize 64, // BlockSize
......
...@@ -166,9 +166,10 @@ bool pool3d_test(bool do_verification, ...@@ -166,9 +166,10 @@ bool pool3d_test(bool do_verification,
{ {
using DevicePoolFwdInstance = using DevicePoolFwdInstance =
ck::tensor_operation::device::DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C< ck::tensor_operation::device::DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<
InDataType, // InDataType InDataType, // InDataType
OutDataType, // OutDataType OutDataType, // OutDataType
AccDataType, // AccDataType IndexDataType, // IndexDataType
AccDataType, // AccDataType
ReduceOpId, ReduceOpId,
OutputIndex, OutputIndex,
64, // BlockSize 64, // BlockSize
......
...@@ -20,6 +20,7 @@ namespace device { ...@@ -20,6 +20,7 @@ namespace device {
template <typename InDataType, template <typename InDataType,
typename OutDataType, typename OutDataType,
typename IndexDataType, // enable if OuputIndex == true
typename AccDataType, typename AccDataType,
ck::ReduceTensorOp ReduceOpId, ck::ReduceTensorOp ReduceOpId,
bool OuputIndex, bool OuputIndex,
...@@ -42,8 +43,6 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C ...@@ -42,8 +43,6 @@ struct DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C
static constexpr index_t InOutRank = 4; static constexpr index_t InOutRank = 4;
static constexpr index_t WindowRank = 2; static constexpr index_t WindowRank = 2;
using IndexDataType = int32_t;
using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType; using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
using InElementwiseOperation = using InElementwiseOperation =
......
...@@ -20,6 +20,7 @@ namespace device { ...@@ -20,6 +20,7 @@ namespace device {
template <typename InDataType, template <typename InDataType,
typename OutDataType, typename OutDataType,
typename IndexDataType, // enable if OuputIndex == true
typename AccDataType, typename AccDataType,
ck::ReduceTensorOp ReduceOpId, ck::ReduceTensorOp ReduceOpId,
bool OuputIndex, bool OuputIndex,
...@@ -42,8 +43,6 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C ...@@ -42,8 +43,6 @@ struct DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C
static constexpr index_t InOutRank = 5; static constexpr index_t InOutRank = 5;
static constexpr index_t WindowRank = 3; static constexpr index_t WindowRank = 3;
using IndexDataType = int32_t;
using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType; using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
using InElementwiseOperation = using InElementwiseOperation =
......
...@@ -14,7 +14,7 @@ void add_device_avg_pooling2d_fwd_nhwc_f16_instances( ...@@ -14,7 +14,7 @@ void add_device_avg_pooling2d_fwd_nhwc_f16_instances(
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, ReduceOpId, false>>>& instances) std::vector<std::unique_ptr<DevicePoolFwd<4, 2, ReduceOpId, false>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_pooling2d_fwd_nhwc_instances<F16, F16, F32, ReduceOpId, false>{}); instances, device_pooling2d_fwd_nhwc_instances<F16, F16, I32, F32, ReduceOpId, false>{});
} }
} // namespace instance } // namespace instance
......
...@@ -14,7 +14,7 @@ void add_device_avg_pooling2d_fwd_nhwc_f32_instances( ...@@ -14,7 +14,7 @@ void add_device_avg_pooling2d_fwd_nhwc_f32_instances(
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, ReduceOpId, false>>>& instances) std::vector<std::unique_ptr<DevicePoolFwd<4, 2, ReduceOpId, false>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_pooling2d_fwd_nhwc_instances<F32, F32, F32, ReduceOpId, false>{}); instances, device_pooling2d_fwd_nhwc_instances<F32, F32, I32, F32, ReduceOpId, false>{});
} }
} // namespace instance } // namespace instance
......
...@@ -14,7 +14,7 @@ void add_device_avg_pooling3d_fwd_ndhwc_f16_instances( ...@@ -14,7 +14,7 @@ void add_device_avg_pooling3d_fwd_ndhwc_f16_instances(
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, ReduceOpId, false>>>& instances) std::vector<std::unique_ptr<DevicePoolFwd<5, 3, ReduceOpId, false>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_pooling3d_fwd_ndhwc_instances<F16, F16, F32, ReduceOpId, false>{}); instances, device_pooling3d_fwd_ndhwc_instances<F16, F16, I32, F32, ReduceOpId, false>{});
} }
......
...@@ -14,7 +14,7 @@ void add_device_avg_pooling3d_fwd_ndhwc_f32_instances( ...@@ -14,7 +14,7 @@ void add_device_avg_pooling3d_fwd_ndhwc_f32_instances(
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, ReduceOpId, false>>>& instances) std::vector<std::unique_ptr<DevicePoolFwd<5, 3, ReduceOpId, false>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_pooling3d_fwd_ndhwc_instances<F32, F32, F32, ReduceOpId, false>{}); instances, device_pooling3d_fwd_ndhwc_instances<F32, F32, I32, F32, ReduceOpId, false>{});
} }
......
...@@ -14,14 +14,14 @@ void add_device_max_pooling2d_fwd_nhwc_f16_instances( ...@@ -14,14 +14,14 @@ void add_device_max_pooling2d_fwd_nhwc_f16_instances(
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, ReduceOpId, false>>>& instances) std::vector<std::unique_ptr<DevicePoolFwd<4, 2, ReduceOpId, false>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_pooling2d_fwd_nhwc_instances<F16, F16, F16, ReduceOpId, false>{}); instances, device_pooling2d_fwd_nhwc_instances<F16, F16, I32, F16, ReduceOpId, false>{});
} }
void add_device_max_pooling2d_fwd_nhwc_index_f16_instances( void add_device_max_pooling2d_fwd_nhwc_index_f16_instances(
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, ReduceOpId, true>>>& instances) std::vector<std::unique_ptr<DevicePoolFwd<4, 2, ReduceOpId, true>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_pooling2d_fwd_nhwc_instances<F16, F16, F16, ReduceOpId, true>{}); instances, device_pooling2d_fwd_nhwc_instances<F16, F16, I32, F16, ReduceOpId, true>{});
} }
} // namespace instance } // namespace instance
......
...@@ -14,14 +14,14 @@ void add_device_max_pooling2d_fwd_nhwc_f32_instances( ...@@ -14,14 +14,14 @@ void add_device_max_pooling2d_fwd_nhwc_f32_instances(
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, ReduceOpId, false>>>& instances) std::vector<std::unique_ptr<DevicePoolFwd<4, 2, ReduceOpId, false>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_pooling2d_fwd_nhwc_instances<F32, F32, F32, ReduceOpId, false>{}); instances, device_pooling2d_fwd_nhwc_instances<F32, F32, I32, F32, ReduceOpId, false>{});
} }
void add_device_max_pooling2d_fwd_nhwc_index_f32_instances( void add_device_max_pooling2d_fwd_nhwc_index_f32_instances(
std::vector<std::unique_ptr<DevicePoolFwd<4, 2, ReduceOpId, true>>>& instances) std::vector<std::unique_ptr<DevicePoolFwd<4, 2, ReduceOpId, true>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_pooling2d_fwd_nhwc_instances<F32, F32, F32, ReduceOpId, true>{}); instances, device_pooling2d_fwd_nhwc_instances<F32, F32, I32, F32, ReduceOpId, true>{});
} }
} // namespace instance } // namespace instance
......
...@@ -14,14 +14,14 @@ void add_device_max_pooling3d_fwd_ndhwc_f16_instances( ...@@ -14,14 +14,14 @@ void add_device_max_pooling3d_fwd_ndhwc_f16_instances(
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, ReduceOpId, false>>>& instances) std::vector<std::unique_ptr<DevicePoolFwd<5, 3, ReduceOpId, false>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_pooling3d_fwd_ndhwc_instances<F16, F16, F16, ReduceOpId, false>{}); instances, device_pooling3d_fwd_ndhwc_instances<F16, F16, I32, F16, ReduceOpId, false>{});
} }
void add_device_max_pooling3d_fwd_ndhwc_index_f16_instances( void add_device_max_pooling3d_fwd_ndhwc_index_f16_instances(
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, ReduceOpId, true>>>& instances) std::vector<std::unique_ptr<DevicePoolFwd<5, 3, ReduceOpId, true>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_pooling3d_fwd_ndhwc_instances<F16, F16, F16, ReduceOpId, true>{}); instances, device_pooling3d_fwd_ndhwc_instances<F16, F16, I32, F16, ReduceOpId, true>{});
} }
} // namespace instance } // namespace instance
......
...@@ -14,14 +14,14 @@ void add_device_max_pooling3d_fwd_ndhwc_f32_instances( ...@@ -14,14 +14,14 @@ void add_device_max_pooling3d_fwd_ndhwc_f32_instances(
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, ReduceOpId, false>>>& instances) std::vector<std::unique_ptr<DevicePoolFwd<5, 3, ReduceOpId, false>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_pooling3d_fwd_ndhwc_instances<F32, F32, F32, ReduceOpId, false>{}); instances, device_pooling3d_fwd_ndhwc_instances<F32, F32, I32, F32, ReduceOpId, false>{});
} }
void add_device_max_pooling3d_fwd_ndhwc_index_f32_instances( void add_device_max_pooling3d_fwd_ndhwc_index_f32_instances(
std::vector<std::unique_ptr<DevicePoolFwd<5, 3, ReduceOpId, true>>>& instances) std::vector<std::unique_ptr<DevicePoolFwd<5, 3, ReduceOpId, true>>>& instances)
{ {
add_device_operation_instances( add_device_operation_instances(
instances, device_pooling3d_fwd_ndhwc_instances<F32, F32, F32, ReduceOpId, true>{}); instances, device_pooling3d_fwd_ndhwc_instances<F32, F32, I32, F32, ReduceOpId, true>{});
} }
} // namespace instance } // namespace instance
......
...@@ -15,20 +15,22 @@ namespace tensor_operation { ...@@ -15,20 +15,22 @@ namespace tensor_operation {
namespace device { namespace device {
namespace instance { namespace instance {
using I32 = int32_t;
using F16 = ck::half_t; using F16 = ck::half_t;
using F32 = float; using F32 = float;
template <typename InDataType, template <typename InDataType,
typename OutDataType, typename OutDataType,
typename IndexDataType,
typename AccDataType, typename AccDataType,
ReduceTensorOp ReduceOpId, ReduceTensorOp ReduceOpId,
bool OuputIndex> bool OuputIndex>
using device_pooling2d_fwd_nhwc_instances = using device_pooling2d_fwd_nhwc_instances =
// clang-format off // clang-format off
std::tuple < std::tuple <
DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<InDataType, OutDataType, AccDataType, ReduceOpId, OuputIndex, 256, 256, 1, 1, 1, 1>, DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<InDataType, OutDataType, IndexDataType, AccDataType, ReduceOpId, OuputIndex, 256, 256, 1, 1, 1, 1>,
DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<InDataType, OutDataType, AccDataType, ReduceOpId, OuputIndex, 256, 256, 1, 2, 1, 2>, DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<InDataType, OutDataType, IndexDataType, AccDataType, ReduceOpId, OuputIndex, 256, 256, 1, 2, 1, 2>,
DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<InDataType, OutDataType, AccDataType, ReduceOpId, OuputIndex, 256, 256, 1, 4, 1, 4> DevicePool2dFwd_Input_N_Hi_Wi_C_Output_N_Ho_Wo_C<InDataType, OutDataType, IndexDataType, AccDataType, ReduceOpId, OuputIndex, 256, 256, 1, 4, 1, 4>
// clang-format on // clang-format on
>; >;
...@@ -40,9 +42,9 @@ template <typename InDataType, ...@@ -40,9 +42,9 @@ template <typename InDataType,
using device_pooling3d_fwd_ndhwc_instances = using device_pooling3d_fwd_ndhwc_instances =
// clang-format off // clang-format off
std::tuple < std::tuple <
DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<InDataType, OutDataType, AccDataType, ReduceOpId, OuputIndex, 256, 256, 1, 1, 1, 1>, DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<InDataType, OutDataType, IndexDataType, AccDataType, ReduceOpId, OuputIndex, 256, 256, 1, 1, 1, 1>,
DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<InDataType, OutDataType, AccDataType, ReduceOpId, OuputIndex, 256, 256, 1, 2, 1, 2>, DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<InDataType, OutDataType, IndexDataType, AccDataType, ReduceOpId, OuputIndex, 256, 256, 1, 2, 1, 2>,
DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<InDataType, OutDataType, AccDataType, ReduceOpId, OuputIndex, 256, 256, 1, 4, 1, 4> DevicePool3dFwd_Input_N_Di_Hi_Wi_C_Output_N_Do_Ho_Wo_C<InDataType, OutDataType, IndexDataType, AccDataType, ReduceOpId, OuputIndex, 256, 256, 1, 4, 1, 4>
// clang-format on // clang-format on
>; >;
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment