Commit bccc6d8b authored by wangshaojie6's avatar wangshaojie6
Browse files

merge develop and resolve conflict

parents c6b52884 91d8b7d6
...@@ -34,6 +34,7 @@ ...@@ -34,6 +34,7 @@
#include "reduction_enums.hpp" #include "reduction_enums.hpp"
#include "reduction_common.hpp" #include "reduction_common.hpp"
#include "host_reduce_util.hpp" #include "host_reduce_util.hpp"
#include "host_common_util.hpp"
#include "host_tensor.hpp" #include "host_tensor.hpp"
#include "data_type.hpp" #include "data_type.hpp"
...@@ -200,7 +201,7 @@ struct ReductionHost ...@@ -200,7 +201,7 @@ struct ReductionHost
using ck::float_equal_one; using ck::float_equal_one;
using ck::float_equal_zero; using ck::float_equal_zero;
using ck::type_convert; using ck::type_convert;
using ck::host_reduce::binop_with_nan_check2; using ck::host_reduce::binop_with_index_and_nan_check;
using ck::host_reduce::ReduceOpFn2; using ck::host_reduce::ReduceOpFn2;
using ck::host_reduce::ReduceOpZeroVal; using ck::host_reduce::ReduceOpZeroVal;
...@@ -211,8 +212,7 @@ struct ReductionHost ...@@ -211,8 +212,7 @@ struct ReductionHost
AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>(); AccDataType accuVal = ReduceOpZeroVal<AccDataType, ReduceOpId>();
IndexDataType accuIndex = 0; IndexDataType accuIndex = 0;
for(IndexDataType i = 0; i < ck::type_convert<IndexDataType>(reduce_dim_indexes.size()); for(std::size_t i = 0; i < reduce_dim_indexes.size(); i++)
i++)
{ {
auto offset_reduce = auto offset_reduce =
get_offset_from_index<NumReduceDim>(reduceStrides, reduce_dim_indexes[i]); get_offset_from_index<NumReduceDim>(reduceStrides, reduce_dim_indexes[i]);
...@@ -221,9 +221,9 @@ struct ReductionHost ...@@ -221,9 +221,9 @@ struct ReductionHost
preUnaryOp(currVal); preUnaryOp(currVal);
auto currIndex = i; auto currIndex = static_cast<IndexDataType>(i);
binop_with_nan_check2<AccDataType, PropagateNan>( binop_with_index_and_nan_check<AccDataType, IndexDataType, PropagateNan>(
opReduce2, accuVal, currVal, accuIndex, currIndex); opReduce2, accuVal, currVal, accuIndex, currIndex);
}; };
...@@ -247,9 +247,7 @@ struct ReductionHost ...@@ -247,9 +247,7 @@ struct ReductionHost
auto offset_invariant = auto offset_invariant =
get_offset_from_index<NumInvariantDim>(invariantStrides, invariant_index); get_offset_from_index<NumInvariantDim>(invariantStrides, invariant_index);
for(IndexDataType i = 0; for(std::size_t i = 0; i < reduce_dim_indexes.size(); i++)
i < ck::type_convert<IndexDataType>(reduce_dim_indexes.size());
i++)
{ {
auto offset_reduce = auto offset_reduce =
get_offset_from_index<NumReduceDim>(reduceStrides, reduce_dim_indexes[i]); get_offset_from_index<NumReduceDim>(reduceStrides, reduce_dim_indexes[i]);
...@@ -259,9 +257,9 @@ struct ReductionHost ...@@ -259,9 +257,9 @@ struct ReductionHost
preUnaryOp(currVal); preUnaryOp(currVal);
auto currIndex = i; auto currIndex = static_cast<IndexDataType>(i);
binop_with_nan_check2<AccDataType, PropagateNan>( binop_with_index_and_nan_check<AccDataType, IndexDataType, PropagateNan>(
opReduce2, accuVal, currVal, accuIndex, currIndex); opReduce2, accuVal, currVal, accuIndex, currIndex);
}; };
......
...@@ -11,6 +11,7 @@ namespace host { ...@@ -11,6 +11,7 @@ namespace host {
template <typename ADataType, template <typename ADataType,
typename BDataType, typename BDataType,
typename CDataType, typename CDataType,
typename AccDataType,
typename AElementwiseOperation, typename AElementwiseOperation,
typename BElementwiseOperation, typename BElementwiseOperation,
typename CElementwiseOperation> typename CElementwiseOperation>
...@@ -53,20 +54,20 @@ struct ReferenceGemm : public device::BaseOperator ...@@ -53,20 +54,20 @@ struct ReferenceGemm : public device::BaseOperator
auto f_mk_kn_mn = [&](auto m, auto n) { auto f_mk_kn_mn = [&](auto m, auto n) {
const int K = arg.a_m_k_.mDesc.GetLengths()[1]; const int K = arg.a_m_k_.mDesc.GetLengths()[1];
float v_acc = 0; AccDataType v_acc = 0;
for(int k = 0; k < K; ++k) for(int k = 0; k < K; ++k)
{ {
float v_a; AccDataType v_a;
float v_b; AccDataType v_b;
arg.a_element_op_(v_a, static_cast<const float>(arg.a_m_k_(m, k))); arg.a_element_op_(v_a, static_cast<const AccDataType>(arg.a_m_k_(m, k)));
arg.b_element_op_(v_b, static_cast<const float>(arg.b_k_n_(k, n))); arg.b_element_op_(v_b, static_cast<const AccDataType>(arg.b_k_n_(k, n)));
v_acc += v_a * v_b; v_acc += v_a * v_b;
} }
float v_c; AccDataType v_c;
arg.c_element_op_(v_c, v_acc); arg.c_element_op_(v_c, v_acc);
......
...@@ -9,26 +9,11 @@ ...@@ -9,26 +9,11 @@
#include "device_reduce_instance_blockwise_i8_i8_i8.hpp" #include "device_reduce_instance_blockwise_i8_i8_i8.hpp"
#include "device_reduce_instance_blockwise_i8_i32_i8.hpp" #include "device_reduce_instance_blockwise_i8_i32_i8.hpp"
#include "device_reduce_instance_blockwise_b16_f32_b16.hpp" #include "device_reduce_instance_blockwise_b16_f32_b16.hpp"
#include "device_reduce_instance_blockwise_second_call_f16_f16_f16.hpp"
#include "device_reduce_instance_blockwise_second_call_f32_f32_f16.hpp"
#include "device_reduce_instance_blockwise_second_call_f32_f32_f32.hpp"
#include "device_reduce_instance_blockwise_second_call_f64_f64_f32.hpp"
#include "device_reduce_instance_blockwise_second_call_f64_f64_f64.hpp"
#include "device_reduce_instance_blockwise_second_call_i8_i8_i8.hpp"
#include "device_reduce_instance_blockwise_second_call_i32_i32_i8.hpp"
#include "device_reduce_instance_blockwise_second_call_f32_f32_b16.hpp"
#include "device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp" #include "device_reduce_instance_multiblock_atomic_add_f16_f32_f32.hpp"
#include "device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp" #include "device_reduce_instance_multiblock_atomic_add_f32_f32_f32.hpp"
#include "device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp" #include "device_reduce_instance_multiblock_atomic_add_f32_f64_f32.hpp"
#include "device_reduce_instance_multiblock_atomic_add_f64_f64_f64.hpp"
#include "device_reduce_instance_multiblock_atomic_add_b16_f32_f32.hpp" #include "device_reduce_instance_multiblock_atomic_add_b16_f32_f32.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_f16_f16_f16.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_f16_f32_f16.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_f32_f32_f32.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_f32_f64_f32.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_f64_f64_f64.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_i8_i8_i8.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_i8_i32_i8.hpp"
#include "device_reduce_instance_multiblock_partial_reduce_b16_f32_b16.hpp"
#include "device_reduce_instance_threadwise_f16_f16_f16.hpp" #include "device_reduce_instance_threadwise_f16_f16_f16.hpp"
#include "device_reduce_instance_threadwise_f16_f32_f16.hpp" #include "device_reduce_instance_threadwise_f16_f32_f16.hpp"
#include "device_reduce_instance_threadwise_f32_f32_f32.hpp" #include "device_reduce_instance_threadwise_f32_f32_f32.hpp"
......
...@@ -3,13 +3,27 @@ ...@@ -3,13 +3,27 @@
#include "reduction_operator_mapping.hpp" #include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_impl_common.hpp" #include "device_reduce_instance_impl_common.hpp"
#include "device_reduce_blockwise.hpp" #include "device_reduce_multiblock.hpp"
namespace ck { namespace ck {
namespace tensor_operation { namespace tensor_operation {
namespace device { namespace device {
namespace device_reduce_instance { namespace device_reduce_instance {
using reduce_configuration_1_instances_blockwise = std::tuple<
// clang-format off
// BlockSize | MThreadClusterSize | KThreadClusterSize
ReductionConfiguration_1<256, 128, 2>,
ReductionConfiguration_1<256, 64, 4>,
ReductionConfiguration_1<256, 32, 8>,
ReductionConfiguration_1<256, 16, 16>,
ReductionConfiguration_1<256, 8, 32>,
ReductionConfiguration_1<256, 4, 64>,
ReductionConfiguration_1<256, 2, 128>,
ReductionConfiguration_1<256, 1, 256>
// clang-format on
>;
#ifdef QUICK_REDUCE_TEST #ifdef QUICK_REDUCE_TEST
using reduce_configuration_2_instances_blockwise = std::tuple< using reduce_configuration_2_instances_blockwise = std::tuple<
// clang-format off // clang-format off
...@@ -58,8 +72,8 @@ template <typename InDataType, ...@@ -58,8 +72,8 @@ template <typename InDataType,
int Rank, int Rank,
int NumReduceDim, int NumReduceDim,
ReduceTensorOp ReduceOpId, ReduceTensorOp ReduceOpId,
NanPropagation NanOpt, bool PropagateNan,
ReduceTensorIndices IndicesOpt> bool UseIndex>
void add_device_reduce_instance_blockwise( void add_device_reduce_instance_blockwise(
std::vector<deviceReduceBlockWisePtrType<AccDataType, ReduceOpId>>& device_op_instances) std::vector<deviceReduceBlockWisePtrType<AccDataType, ReduceOpId>>& device_op_instances)
{ {
...@@ -73,92 +87,94 @@ void add_device_reduce_instance_blockwise( ...@@ -73,92 +87,94 @@ void add_device_reduce_instance_blockwise(
constexpr bool Indexable = constexpr bool Indexable =
(ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX || (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
ReduceOpId == ReduceTensorOp::AMAX); ReduceOpId == ReduceTensorOp::AMAX);
constexpr bool NeedIndices = Indexable && (IndicesOpt != ReduceTensorIndices::NO_INDICES); constexpr bool OutputIndex = Indexable && UseIndex;
constexpr bool PropagateNan = (NanOpt == NanPropagation::NOT_PROPAGATE_NAN) ? false : true; static_for<0, std::tuple_size<reduce_configuration_1_instances_blockwise>::value, 1>{}(
[&](auto i) {
static_for<0, std::tuple_size<reduce_configuration_1_instances>::value, 1>{}([&](auto i) { using cfg1 = remove_cvref_t<decltype(
using cfg1 = std::get<i.value>(reduce_configuration_1_instances_blockwise{}))>;
remove_cvref_t<decltype(std::get<i.value>(reduce_configuration_1_instances{}))>;
static_for<0, std::tuple_size<reduce_configuration_2_instances_blockwise>::value, 1>{}(
static_for<0, std::tuple_size<reduce_configuration_2_instances_blockwise>::value, 1>{}( [&](auto j) {
[&](auto j) { using cfg2 = remove_cvref_t<decltype(
using cfg2 = remove_cvref_t<decltype( std::get<j.value>(reduce_configuration_2_instances_blockwise{}))>;
std::get<j.value>(reduce_configuration_2_instances_blockwise{}))>;
using ReduceOpInstance =
using ReduceOpInstance = DeviceReduceBlockWise<InDataType, DeviceReduceMultiBlock<InDataType,
AccDataType, AccDataType,
OutDataType, OutDataType,
Rank, Rank,
NumReduceDim, NumReduceDim,
ReduceOperation, ReduceOperation,
InElementwiseOperation, InElementwiseOperation,
AccElementwiseOperation, AccElementwiseOperation,
PropagateNan, InMemoryDataOperationEnum::Set,
NeedIndices, PropagateNan,
cfg1::BlockSize_, OutputIndex,
cfg1::MThreadClusterSize_, false, // HaveIndexInputIfOutputIndex
cfg1::KThreadClusterSize_, cfg1::BlockSize_,
cfg2::MThreadSliceSize_, cfg1::MThreadClusterSize_,
cfg2::KThreadSliceSize_, cfg1::KThreadClusterSize_,
cfg2::InSrcVectorDim_, cfg2::MThreadSliceSize_,
cfg2::InSrcVectorSize_, cfg2::KThreadSliceSize_,
cfg2::OutDstVectorSize_>; cfg2::InSrcVectorDim_,
cfg2::InSrcVectorSize_,
device_op_instances.push_back( cfg2::OutDstVectorSize_>;
std::make_unique<ReduceOpInstance>(ReduceOpInstance{}));
}); device_op_instances.push_back(
}); std::make_unique<ReduceOpInstance>(ReduceOpInstance{}));
});
});
}; };
#define ADD_BLOCKWISE_INST_BY_TYPE( \ #define ADD_BLOCKWISE_INST_BY_TYPE( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
template void add_device_reduce_instance_blockwise<inT, \ template void add_device_reduce_instance_blockwise<inT, \
compT, \ compT, \
outT, \ outT, \
Rank, \ Rank, \
NumReduceDim, \ NumReduceDim, \
ReduceOpId, \ ReduceOpId, \
NanOpt, \ PropagateNan, \
IndicesOpt>( \ UseIndex>( \
std::vector<deviceReduceBlockWisePtrType<compT, ReduceOpId>> & device_op_instances) std::vector<deviceReduceBlockWisePtrType<compT, ReduceOpId>> & device_op_instances)
#define ADD_BLOCKWISE_INST_BY_ID( \ #define ADD_BLOCKWISE_INST_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
ADD_BLOCKWISE_INST_BY_TYPE(inT, \ ADD_BLOCKWISE_INST_BY_TYPE(inT, \
compT, \ compT, \
outT, \ outT, \
static_cast<ReduceTensorOp>(ReduceOpId), \ static_cast<ReduceTensorOp>(ReduceOpId), \
static_cast<NanPropagation>(NanOpt), \ static_cast<bool>(NanOpt), \
static_cast<ReduceTensorIndices>(IndicesOpt), \ static_cast<bool>(IndicesOpt), \
Rank, \ Rank, \
NumReduceDim) NumReduceDim)
#define ADD_BLOCKWISE_INST_REF_BY_TYPE( \ #define ADD_BLOCKWISE_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
extern template void add_device_reduce_instance_blockwise<inT, \ extern template void add_device_reduce_instance_blockwise<inT, \
compT, \ compT, \
outT, \ outT, \
Rank, \ Rank, \
NumReduceDim, \ NumReduceDim, \
ReduceOpId, \ ReduceOpId, \
NanOpt, \ PropagateNan, \
IndicesOpt>( \ UseIndex>( \
std::vector<DeviceReducePtr< \ std::vector<DeviceReducePtr< \
typename reduce_unary_operator<compT, ReduceOpId, true, true>::InElementwiseOperation, \ typename reduce_unary_operator<compT, ReduceOpId, true, true>::InElementwiseOperation, \
typename reduce_unary_operator<compT, ReduceOpId, true, true>:: \ typename reduce_unary_operator<compT, ReduceOpId, true, true>:: \
AccElementwiseOperation>> & \ AccElementwiseOperation>> & \
device_op_instances) device_op_instances)
#define ADD_BLOCKWISE_INST_REF_BY_ID( \ #define ADD_BLOCKWISE_INST_REF_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \ inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
ADD_BLOCKWISE_INST_REF_BY_TYPE(inT, \ ADD_BLOCKWISE_INST_REF_BY_TYPE(inT, \
compT, \ compT, \
outT, \ outT, \
static_cast<ReduceTensorOp>(ReduceOpId), \ static_cast<ReduceTensorOp>(ReduceOpId), \
static_cast<NanPropagation>(NanOpt), \ static_cast<bool>(NanOpt), \
static_cast<ReduceTensorIndices>(IndicesOpt), \ static_cast<bool>(IndicesOpt), \
Rank, \ Rank, \
NumReduceDim) NumReduceDim)
} // namespace device_reduce_instance } // namespace device_reduce_instance
......
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_B16_F32_B16_HPP #ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_B16_F32_B16_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_B16_F32_B16_HPP #define DEVICE_REDUCE_INSTANCE_BLOCKWISE_B16_F32_B16_HPP
#include "reduction_enums.hpp" #include "data_type.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise.hpp" #include "device_reduce_instance_blockwise.hpp"
namespace ck { namespace ck {
......
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F16_F16_HPP #ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F16_F16_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F16_F16_HPP #define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F16_F16_HPP
#include "reduction_enums.hpp" #include "data_type.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise.hpp" #include "device_reduce_instance_blockwise.hpp"
namespace ck { namespace ck {
......
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F32_F16_HPP #ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F32_F16_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F32_F16_HPP #define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F16_F32_F16_HPP
#include "reduction_enums.hpp" #include "data_type.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise.hpp" #include "device_reduce_instance_blockwise.hpp"
namespace ck { namespace ck {
......
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F32_F32_HPP #ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F32_F32_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F32_F32_HPP #define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F32_F32_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise.hpp" #include "device_reduce_instance_blockwise.hpp"
namespace ck { namespace ck {
......
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F64_F32_HPP #ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F64_F32_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F64_F32_HPP #define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F32_F64_F32_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise.hpp" #include "device_reduce_instance_blockwise.hpp"
namespace ck { namespace ck {
......
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F64_F64_F64_HPP #ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_F64_F64_F64_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F64_F64_F64_HPP #define DEVICE_REDUCE_INSTANCE_BLOCKWISE_F64_F64_F64_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise.hpp" #include "device_reduce_instance_blockwise.hpp"
namespace ck { namespace ck {
......
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I32_I8_HPP #ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I32_I8_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I32_I8_HPP #define DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I32_I8_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise.hpp" #include "device_reduce_instance_blockwise.hpp"
namespace ck { namespace ck {
......
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I8_I8_HPP #ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I8_I8_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I8_I8_HPP #define DEVICE_REDUCE_INSTANCE_BLOCKWISE_I8_I8_I8_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise.hpp" #include "device_reduce_instance_blockwise.hpp"
namespace ck { namespace ck {
......
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_HPP
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_impl_common.hpp"
#include "device_reduce_blockwise_second_call.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
#ifdef QUICK_REDUCE_TEST
using reduce_configuration_2_instances_blockwise_second_call = std::tuple<
// clang-format off
// InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize
ReductionConfiguration_2<1, 2, 1, 1, 2>,
ReductionConfiguration_2<1, 1, 1, 1, 3>
// clang-format on
>;
#else
using reduce_configuration_2_instances_blockwise_second_call = std::tuple<
// clang-format off
// InSrcVectorDim | InSrcVectorSize | OutDstVectorSize | MThreadSliceSize | KThreadSliceSize
ReductionConfiguration_2<1, 4, 1, 1, 8>,
ReductionConfiguration_2<1, 4, 1, 1, 4>,
ReductionConfiguration_2<1, 2, 1, 1, 2>,
ReductionConfiguration_2<1, 1, 1, 1, 3>,
ReductionConfiguration_2<1, 1, 1, 1, 5>,
ReductionConfiguration_2<1, 1, 1, 1, 7>,
ReductionConfiguration_2<1, 1, 1, 1, 11>
// clang-format on
>;
#endif
template <typename AccDataType, ReduceTensorOp ReduceOpId>
using deviceReduceBlockWiseSecondCallPtrType = DeviceReducePtr<
typename reduce_unary_operator<AccDataType, ReduceOpId, false, true>::InElementwiseOperation,
typename reduce_unary_operator<AccDataType, ReduceOpId, false, true>::AccElementwiseOperation>;
template <typename InDataType,
typename AccDataType,
typename OutDataType,
int Rank,
int NumReduceDim,
ReduceTensorOp ReduceOpId,
NanPropagation NanOpt,
ReduceTensorIndices IndicesOpt>
void add_device_reduce_instance_blockwise_second_call(
std::vector<deviceReduceBlockWiseSecondCallPtrType<AccDataType, ReduceOpId>>&
device_op_instances)
{
using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
using InElementwiseOperation =
typename reduce_unary_operator<AccDataType, ReduceOpId, false, true>::
InElementwiseOperation;
using AccElementwiseOperation =
typename reduce_unary_operator<AccDataType, ReduceOpId, false, true>::
AccElementwiseOperation;
constexpr bool Indexable =
(ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
ReduceOpId == ReduceTensorOp::AMAX);
constexpr bool NeedIndices = Indexable && (IndicesOpt != ReduceTensorIndices::NO_INDICES);
constexpr bool PropagateNan = (NanOpt == NanPropagation::NOT_PROPAGATE_NAN) ? false : true;
static_assert(std::is_same<InDataType, AccDataType>::value,
"InDataType and AccDataType should be the same to use "
"add_device_reduce_instance_blockwise_second_call!");
static_for<0, std::tuple_size<reduce_configuration_1_instances>::value, 1>{}([&](auto i) {
using cfg1 =
remove_cvref_t<decltype(std::get<i.value>(reduce_configuration_1_instances{}))>;
static_for<0,
std::tuple_size<reduce_configuration_2_instances_blockwise_second_call>::value,
1>{}([&](auto j) {
using cfg2 = remove_cvref_t<decltype(
std::get<j.value>(reduce_configuration_2_instances_blockwise_second_call{}))>;
using ReduceOpInstance = DeviceReduceBlockWiseSecondCall<InDataType,
AccDataType,
OutDataType,
Rank,
NumReduceDim,
ReduceOperation,
InElementwiseOperation,
AccElementwiseOperation,
PropagateNan,
NeedIndices,
cfg1::BlockSize_,
cfg1::MThreadClusterSize_,
cfg1::KThreadClusterSize_,
cfg2::MThreadSliceSize_,
cfg2::KThreadSliceSize_,
cfg2::InSrcVectorDim_,
cfg2::InSrcVectorSize_,
cfg2::OutDstVectorSize_>;
device_op_instances.push_back(std::make_unique<ReduceOpInstance>(ReduceOpInstance{}));
});
});
};
#define ADD_BLOCKWISE_SECOND_CALL_INST_BY_TYPE( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
template void add_device_reduce_instance_blockwise_second_call<inT, \
compT, \
outT, \
Rank, \
NumReduceDim, \
ReduceOpId, \
NanOpt, \
IndicesOpt>( \
std::vector<deviceReduceBlockWiseSecondCallPtrType<compT, ReduceOpId>> & \
device_op_instances)
#define ADD_BLOCKWISE_SECOND_CALL_INST_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
ADD_BLOCKWISE_SECOND_CALL_INST_BY_TYPE(inT, \
compT, \
outT, \
static_cast<ReduceTensorOp>(ReduceOpId), \
static_cast<NanPropagation>(NanOpt), \
static_cast<ReduceTensorIndices>(IndicesOpt), \
Rank, \
NumReduceDim)
#define ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_TYPE( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
extern template void add_device_reduce_instance_blockwise_second_call<inT, \
compT, \
outT, \
Rank, \
NumReduceDim, \
ReduceOpId, \
NanOpt, \
IndicesOpt>( \
std::vector< \
DeviceReducePtr<typename reduce_unary_operator<compT, ReduceOpId, false, true>:: \
InElementwiseOperation, \
typename reduce_unary_operator<compT, ReduceOpId, false, true>:: \
AccElementwiseOperation>> & \
device_op_instances)
#define ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID( \
inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_TYPE(inT, \
compT, \
outT, \
static_cast<ReduceTensorOp>(ReduceOpId), \
static_cast<NanPropagation>(NanOpt), \
static_cast<ReduceTensorIndices>(IndicesOpt), \
Rank, \
NumReduceDim)
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F16_F16_F16_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F16_F16_F16_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise_second_call.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(half_t, half_t, half_t, 4, 0, 1, 2, 1);
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F32_F32_B16_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F32_F32_B16_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise_second_call.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 0, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 5, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 7, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 7, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 2, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 3, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 4, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 2, 0, 1, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 3, 0, 1, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 4, 0, 1, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, bhalf_t, 4, 0, 1, 2, 1);
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F32_F32_F16_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F32_F32_F16_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise_second_call.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 0, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 5, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 7, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, half_t, 7, 0, 0, 2, 1);
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F32_F32_F32_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F32_F32_F32_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise_second_call.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 7, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(float, float, float, 4, 0, 1, 2, 1);
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F64_F64_F32_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F64_F64_F32_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise_second_call.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 0, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 5, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 7, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, float, 7, 0, 0, 2, 1);
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F64_F64_F64_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_F64_F64_F64_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise_second_call.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 5, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 3); // for NORM2
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 7, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 7, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 3); // for MIN
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 2, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 3); // for MAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 3, 0, 1, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 3); // for AMAX
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 1, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(double, double, double, 4, 0, 1, 2, 1);
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
#ifndef DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_I32_I32_I8_HPP
#define DEVICE_REDUCE_INSTANCE_BLOCKWISE_SECOND_CALL_I32_I32_I8_HPP
#include "reduction_enums.hpp"
#include "reduction_operator_mapping.hpp"
#include "device_reduce_instance_blockwise_second_call.hpp"
namespace ck {
namespace tensor_operation {
namespace device {
namespace device_reduce_instance {
// clang-format off
// InDataType | AccDataType | OutDataType | ReduceOpId | NanPropaOpt | IndicesOpt | Rank | NumReduceDim
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int32_t, int32_t, int8_t, 0, 0, 0, 4, 3); // for ADD
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int32_t, int32_t, int8_t, 0, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int32_t, int32_t, int8_t, 0, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int32_t, int32_t, int8_t, 0, 0, 0, 2, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int32_t, int32_t, int8_t, 5, 0, 0, 4, 3); // for AVG
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int32_t, int32_t, int8_t, 5, 0, 0, 4, 4);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int32_t, int32_t, int8_t, 5, 0, 0, 4, 1);
ADD_BLOCKWISE_SECOND_CALL_INST_REF_BY_ID(int32_t, int32_t, int8_t, 5, 0, 0, 2, 1);
// clang-format on
} // namespace device_reduce_instance
} // namespace device
} // namespace tensor_operation
} // namespace ck
#endif
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment