Unverified Commit 9a8ee8a3 authored by Qianfeng's avatar Qianfeng Committed by GitHub
Browse files

Reduction for int8 and bfloat16 (#125)



* Use thread cluster descriptor and explicit M_K 2d descriptor to simply Blockwise Reduction

* Change by replacing ReduceDims by NumReduceDims as Device Reduce interface template parameter

* Rename the folder name for the pool2d and reduce examples

* Update to reduction test scripts

* Add Readme for pool2d_fwd and reduce_blockwise examples

* Add support for int8_t reduction (ADD/AVG, MIN/MAX/AMAX)

* Tiny fix in reduce profiler and tiny update in reduce testing scripts

* Tiny fix in testing script profile_reduce_no_index.sh

* Tiny fix in testing script profile_reduce_no_index.sh

* Add support for bfp16 reduction (using bhalf_t = ushort)

* Tiny fix in amd_buffer_addressing.hpp

* Tiny change in script/profile_reduce_with_index.sh

* Use AccDataType for Beta value and use element_wise::PassThrough

* Use type_convert for type converting in host layer reduction

* Renaming and refining in Reduction profiler/device layer/examples

* Renaming and refining in Reduction profiler/device layer/examples

* Renaming all NumReduceDims to NumReduceDim

* Fix the leaked type_convert in ThreadwiseTensorSliceTransfer_v2

* Update to testing scripts to add bf16 support

* added more static_assert

* Remove buggy tunable configurations defined in device_reduce_instance_xxx.hpp

* Add static_assert to give compile-time warning for incorrect thread slice-size/vector-size configurations

* minor change

* Refine and fix (in GetWorkspaceSizeInBytes of MultiBlockPartialReduce) to make int8 completely pass

* Tiny renaming in gridwise_2d_reduction_multiblock_partial_reduce.hpp

* Tiny fix in script/profile_reduce_no_index.sh

* Refine in DeviceReduce layer with regard to using NumInvariantDim/NumReduceDim or InvariantDims/ReduceDims

* Generic renaming in host reduction and DeviceReduce layer

* Add support for 4-d all dimension reduction in the profiler and add_device_reduce_xxx instances

* Use multi-thread and simplification for host Reduction implementation

* Add ctest for reduction

* Update to clarify the using of data init method in produce_reduce/example_reduce/test_reduce/

* Update to the reduce CTest executables to enable default testing behavior when no command argument

* Renaming
Co-authored-by: default avatarJianfeng yan <jfyan008@gmail.com>
parent cb87b049
#include "device_reduce_instance_blockwise_second_call.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1);
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#include "device_reduce_instance_multiblock_atomic_add.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 0, 0, 0, 4, 3); // for ADD
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 0, 0, 0, 4, 4);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 0, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 5, 0, 0, 4, 3); // for AVG
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 5, 0, 0, 4, 4);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(bhalf_t, float, float, 5, 0, 0, 2, 1);
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -8,9 +8,11 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 0, 0, 0, 4, 3); // for ADD
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 0, 0, 0, 4, 4);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 0, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 4, 3); // for AVG
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 4, 4);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(half_t, float, float, 5, 0, 0, 2, 1);
// clang-format on
......
......@@ -8,9 +8,11 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 0, 0, 0, 4, 4);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 0, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 4, 4);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1);
// clang-format on
......
......@@ -8,9 +8,11 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 0, 0, 0, 4, 4);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 0, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 4, 4);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(float, double, float, 5, 0, 0, 2, 1);
// clang-format on
......
#include "device_reduce_instance_multiblock_partial_reduce.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 3); // for ADD
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 3); // for AVG
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 3); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 2, 1);
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -8,21 +8,27 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
// clang-format on
......
......@@ -8,12 +8,15 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);
// clang-format on
......
......@@ -8,25 +8,32 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1);
// clang-format on
......
......@@ -8,6 +8,7 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(float, double, float, 7, 0, 0, 2, 1);
// clang-format on
......
......@@ -8,33 +8,42 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 2, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 3, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 7, 0, 0, 2, 1);
// Will be moved to use MultiBlockAtomicAdd
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1);
// clang-format on
......
#include "device_reduce_instance_multiblock_partial_reduce.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 3); // for ADD
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 3); // for AVG
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1);
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#include "device_reduce_instance_multiblock_partial_reduce.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 0, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 3); // for MIN
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 2, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 3); // for MAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 3, 0, 1, 2, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 3); // for AMAX
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 4);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 4, 1);
ADD_MULTIBLOCK_PARTIAL_REDUCE_INST_BY_ID(int8_t, int8_t, int8_t, 4, 0, 1, 2, 1);
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#include "device_reduce_instance_threadwise.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 3); // for ADD
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 0, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 3); // for AVG
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 5, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 3); // for NORM2
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 7, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 3); // for MIN
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 3); // for MAX
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 3); // for AMAX
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 3); // for MIN
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 4);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 2, 0, 1, 2, 1);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 3); // for MAX
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 4);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 3, 0, 1, 2, 1);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 3); // for AMAX
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 4);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(bhalf_t, float, bhalf_t, 4, 0, 1, 2, 1);
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
......@@ -8,21 +8,27 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 4);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 4);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 4);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
// clang-format on
......
......@@ -8,12 +8,15 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 3); // for ADD
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 0, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 3); // for AVG
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 5, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 3); // for NORM2
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(half_t, float, half_t, 7, 0, 0, 2, 1);
// clang-format on
......
......@@ -8,30 +8,39 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_THREADWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD
ADD_THREADWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(float, float, float, 0, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 0, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG
ADD_THREADWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(float, float, float, 5, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 5, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2
ADD_THREADWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(float, float, float, 7, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 7, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 4);
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 2, 0, 1, 2, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 4);
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 3, 0, 1, 2, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 4);
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, float, float, 4, 0, 1, 2, 1);
// clang-format on
......
......@@ -8,12 +8,15 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_THREADWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 3); // for ADD
ADD_THREADWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(float, double, float, 0, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, double, float, 0, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 3); // for AVG
ADD_THREADWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(float, double, float, 5, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, double, float, 5, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 3); // for NORM2
ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(float, double, float, 7, 0, 0, 2, 1);
// clang-format on
......
......@@ -8,30 +8,39 @@ namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_THREADWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD
ADD_THREADWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(double, double, double, 0, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 0, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG
ADD_THREADWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(double, double, double, 5, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 5, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2
ADD_THREADWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(double, double, double, 7, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 7, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 4);
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 2, 0, 1, 2, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 4);
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 3, 0, 1, 2, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 4);
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 4, 1);
ADD_THREADWISE_INST_BY_ID(double, double, double, 4, 0, 1, 2, 1);
// clang-format on
......
#include "device_reduce_instance_threadwise.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 3); // for ADD
ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 0, 0, 0, 2, 1);
ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 3); // for AVG
ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 4);
ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 4, 1);
ADD_THREADWISE_INST_BY_ID(int8_t, int32_t, int8_t, 5, 0, 0, 2, 1);
// clang-format on
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment