Unverified Commit 1f543bfa authored by Qianfeng, committed by GitHub

Regulate reduction accumulator operations and Element-wise operations (#274)

* Remove the class-level template from Reduction operation classes and add templates to their operator() and GetIdentityValue() interfaces (see the sketch after this list)

* Change the unary elementwise operators and the reduce_unary_operator (class for mapping), and their dependent variations, in all host layers

* Remove the data type template parameter from reduce_binary_operator (class for mapping) and dependent variations in host layers

* Add InMemoryDataOperatonSupportedOnDataType to check whether a data type is supported by a given InMemoryDataOperation

* Use struct-scope operator template instantiation for binary and unary element-wise operations

* Change a few more elementwise operations to use a templated operator()

* Tiny correction in Normalize operator

* Add static_assert to check data type applicability for some reduction accumulator and element-wise operations

* Correction in some examples with regard to using ReduceAccDataType

* Use static_assert for UnaryDivide

* Update merged code to use Element-wise operations and Reduction Accumulator operations correctly

* Tiny fix with regard to SetWorkSpacePointer()
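
The bullets above move the data type off the reduction operation classes and onto their member functions. Below is a minimal standalone sketch of that pattern; it mirrors the shape of ck::reduce::Add as it appears in the diff, but it is illustrative code, not the actual CK header.

#include <cstdint>
#include <type_traits>

// Sketch of the refactored accumulator interface: the class itself is no longer a
// template; the data type is supplied per call to GetIdentityValue() and operator().
struct Add
{
    template <typename T>
    static constexpr T GetIdentityValue()
    {
        return static_cast<T>(0);
    }

    template <typename T>
    constexpr void operator()(T& a, T b) const
    {
        // Reject unsupported accumulator types at compile time, as the commit does.
        static_assert(std::is_same<T, float>::value || std::is_same<T, double>::value ||
                          std::is_same<T, std::int32_t>::value,
                      "The data type is not supported by the Add accumulator!");
        a = a + b;
    }
};

int main()
{
    float acc = Add::GetIdentityValue<float>(); // identity requested for the accumulation type
    Add{}(acc, 2.0f);                           // T is deduced from the arguments
    Add{}(acc, 3.0f);
    return acc == 5.0f ? 0 : 1;
}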
parent 63cdd923
@@ -171,15 +171,15 @@ struct GridwiseReduction_mk_to_m_multiblock
                               AccDataType beta,
                               OutDataType* const __restrict__ p_out_value_global)
     {
-        const auto identityVal = ReduceOperation::GetIdentityValue();
+        const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
         // LDS
         __shared__ AccDataType p_reduce_work_buffer[BlockSize];
-        const auto in_global_val_buf =
-            make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
-                                                          in_grid_desc_m_k.GetElementSpaceSize(),
-                                                          type_convert<InDataType>(identityVal));
+        const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_in_value_global,
+            in_grid_desc_m_k.GetElementSpaceSize(),
+            ReduceOperation::template GetIdentityValue<InDataType>());
         auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
@@ -358,12 +358,12 @@ struct GridwiseReduction_mk_to_m_multiblock
         __shared__ AccDataType p_reduce_work_val_buffer[BlockSize];
         __shared__ IndexDataType p_reduce_work_idx_buffer[BlockSize];
-        const auto identityVal = ReduceOperation::GetIdentityValue();
+        const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
-        const auto in_global_val_buf =
-            make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
-                                                          in_grid_desc_m_k.GetElementSpaceSize(),
-                                                          type_convert<InDataType>(identityVal));
+        const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_in_value_global,
+            in_grid_desc_m_k.GetElementSpaceSize(),
+            ReduceOperation::template GetIdentityValue<InDataType>());
         const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
         auto out_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
......
@@ -135,12 +135,12 @@ struct GridwiseReduction_mk_to_m_threadwise
                                                    ReduceOperation,
                                                    PropagateNan>;
-        const auto identityVal = ReduceOperation::GetIdentityValue();
+        const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
-        const auto in_global_val_buf =
-            make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
-                                                          in_grid_desc_m_k.GetElementSpaceSize(),
-                                                          type_convert<InDataType>(identityVal));
+        const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_in_value_global,
+            in_grid_desc_m_k.GetElementSpaceSize(),
+            ReduceOperation::template GetIdentityValue<InDataType>());
         auto dst_global_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_out_value_global, out_grid_desc_m.GetElementSpaceSize());
@@ -276,12 +276,12 @@ struct GridwiseReduction_mk_to_m_threadwise
         (void)acc_elementwise_op;
-        const auto identityVal = ReduceOperation::GetIdentityValue();
+        const auto identityVal = ReduceOperation::template GetIdentityValue<AccDataType>();
-        const auto in_global_val_buf =
-            make_dynamic_buffer<AddressSpaceEnum::Global>(p_in_value_global,
-                                                          in_grid_desc_m_k.GetElementSpaceSize(),
-                                                          type_convert<InDataType>(identityVal));
+        const auto in_global_val_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
+            p_in_value_global,
+            in_grid_desc_m_k.GetElementSpaceSize(),
+            ReduceOperation::template GetIdentityValue<InDataType>());
         const auto in_global_idx_buf = make_dynamic_buffer<AddressSpaceEnum::Global>(
             p_in_index_global, in_grid_desc_m_k.GetElementSpaceSize());
......
@@ -927,7 +927,8 @@ struct GridwiseGemmBiasAddReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                                                        false>;
             // Global write Gemm shuffle + reduction
-            const auto d_zeroVal = DReduceOperation::GetIdentityValue();
+            const auto d_zeroVal =
+                DReduceOperation::template GetIdentityValue<FloatReduceAcc>();
             static_for<0, mreduce_per_thread, 1>{}(
                 [&](auto I) { d_thread_buf(I) = d_zeroVal; });
......
@@ -816,7 +816,8 @@ struct GridwiseGemmReduce_k0mk1_k0nk1_mn_xdl_cshuffle_v1
                                                        false>;
             // Global write Gemm shuffle + reduction
-            const auto d_identityVal = DReduceOperation::GetIdentityValue();
+            const auto d_identityVal =
+                DReduceOperation::template GetIdentityValue<FloatReduceAcc>();
             static_for<0, mreduce_per_thread, 1>{}(
                 [&](auto I) { d_thread_buf(I) = d_identityVal; });
......
@@ -37,7 +37,7 @@ __global__ void kernel_buffer_set_value(const Grid1dBufferDescType grid_1d_buffe
 {
-    using PassThroughOp = tensor_operation::element_wise::UnaryIdentic<DataType, DataType>;
+    using PassThroughOp = tensor_operation::element_wise::PassThrough;
     constexpr auto I0 = Number<0>{};
......
@@ -28,6 +28,7 @@
 #include "config.hpp"
 #include "data_type.hpp"
+#include "type.hpp"
 namespace ck {
@@ -54,64 +55,92 @@ namespace reduce {
 // accumulated index also need be
 // changed.
-template <class T>
 struct Add
 {
-    using dataType = T;
-    __host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(0.0f); };
+    template <typename T>
+    __host__ __device__ static constexpr T GetIdentityValue()
+    {
+        return type_convert<T>(0.0f);
+    };
-    __device__ static constexpr bool
+    __host__ __device__ static constexpr bool
     IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
     {
         return operation == InMemoryDataOperationEnum::AtomicAdd ||
                operation == InMemoryDataOperationEnum::Set;
     };
-    __host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a + b; }
+    template <typename T>
+    __host__ __device__ inline constexpr void operator()(T& a, T b) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, int32_t>::value,
+                      "The data type is not supported by the Add accumulator!");
+        a = a + b;
+    }
 };
-template <class T>
 struct Mul
 {
-    using dataType = T;
-    __host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(1.0f); };
+    template <typename T>
+    __host__ __device__ static constexpr T GetIdentityValue()
+    {
+        return type_convert<T>(1.0f);
+    };
-    __device__ static constexpr bool
+    __host__ __device__ static constexpr bool
     IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
     {
         return operation == InMemoryDataOperationEnum::Set;
     };
-    __host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a * b; }
+    template <typename T>
+    __host__ __device__ inline constexpr void operator()(T& a, T b) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, int32_t>::value,
+                      "The data type is not supported by the Mul accumulator!");
+        a = a * b;
+    }
 };
-template <class T>
 struct Max
 {
-    using dataType = T;
+    template <typename T>
     __host__ __device__ static constexpr T GetIdentityValue()
     {
         return NumericLimits<T>::Lowest();
     };
-    __device__ static constexpr bool
+    __host__ __device__ static constexpr bool
     IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
     {
         // ToChange: atomic_max to be added
         return operation == InMemoryDataOperationEnum::Set;
     };
+    template <typename T>
     __host__ __device__ inline constexpr void operator()(T& a, T b) const
     {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "The data type is not supported by the Max accumulator!");
         if(a < b)
             a = b;
     }
+    template <typename T>
     __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
     {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "The data type is not supported by the Max accumulator!");
         if(a < b)
         {
             a = b;
@@ -120,28 +149,41 @@ struct Max
     }
 };
-template <class T>
 struct Min
 {
-    using dataType = T;
-    __host__ __device__ static constexpr T GetIdentityValue() { return NumericLimits<T>::Max(); };
+    template <typename T>
+    __host__ __device__ static constexpr T GetIdentityValue()
+    {
+        return NumericLimits<T>::Max();
+    };
-    __device__ static constexpr bool
+    __host__ __device__ static constexpr bool
     IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
     {
         // ToChange: atomic_min to be added
         return operation == InMemoryDataOperationEnum::Set;
     };
+    template <typename T>
     __host__ __device__ inline constexpr void operator()(T& a, T b) const
     {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "The data type is not supported by the Min accumulator!");
         if(a > b)
             a = b;
     }
+    template <typename T>
     __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
     {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "The data type is not supported by the Min accumulator!");
         if(a > b)
         {
             a = b;
@@ -150,28 +192,41 @@ struct Min
     }
 };
-template <class T>
 struct AMax
 {
-    using dataType = T;
-    __host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(0.0f); };
+    template <typename T>
+    __host__ __device__ static constexpr T GetIdentityValue()
+    {
+        return type_convert<T>(0.0f);
+    };
-    __device__ static constexpr bool
+    __host__ __device__ static constexpr bool
     IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
     {
         // ToChange: atomic_max to be added
         return operation == InMemoryDataOperationEnum::Set;
     };
+    template <typename T>
     __host__ __device__ inline constexpr void operator()(T& a, T b) const
     {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "The data type is not supported by the AMax accumulator!");
         if(a < b)
             a = b;
     }
+    template <typename T>
     __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
     {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "The data type is not supported by the AMax accumulator!");
         if(a < b)
         {
             a = b;
@@ -181,7 +236,7 @@ struct AMax
 };
 template <typename T>
-T GetIdentityValueueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
+constexpr T GetIdentityValueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
 {
     T result = ck::type_convert<T>(0.0f);
@@ -191,6 +246,44 @@ T GetIdentityValueueForInMemoryDataOperation(InMemoryDataOperationEnum operation
     return (result);
 };
+template <InMemoryDataOperationEnum Operation, typename DataType>
+struct InMemoryDataOperatonSupportedOnDataType
+{
+    static constexpr bool value = false;
+};
+template <typename DataType>
+struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::AtomicAdd, DataType>
+{
+    static constexpr bool value =
+        is_same<DataType, float>::value || is_same<DataType, double>::value;
+};
+template <typename DataType>
+struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::AtomicMax, DataType>
+{
+    static constexpr bool value =
+        is_same<DataType, float>::value || is_same<DataType, double>::value;
+};
+template <typename DataType>
+struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::Set, DataType>
+{
+    static constexpr bool value =
+        is_same<DataType, float>::value || is_same<DataType, double>::value ||
+        is_same<DataType, half_t>::value || is_same<DataType, bhalf_t>::value ||
+        is_same<DataType, int8_t>::value || is_same<DataType, int32_t>::value;
+};
+template <typename DataType>
+struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::Add, DataType>
+{
+    static constexpr bool value =
+        is_same<DataType, float>::value || is_same<DataType, double>::value ||
+        is_same<DataType, half_t>::value || is_same<DataType, int8_t>::value ||
+        is_same<DataType, int32_t>::value;
+};
 }; // end of namespace reduce
 } // end of namespace ck
......
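The InMemoryDataOperatonSupportedOnDataType trait added above records which data types each InMemoryDataOperationEnum value may be used with. Below is a self-contained sketch of how such a trait can be defined and consumed via static_assert; the enum values and the float/double allow-list for AtomicAdd follow the diff, while write_back is a hypothetical consumer, not a CK function.

#include <type_traits>

enum class InMemoryDataOperationEnum { Set, AtomicAdd, AtomicMax, Add };

// Primary template: a combination is unsupported unless a specialization says otherwise.
template <InMemoryDataOperationEnum Operation, typename DataType>
struct InMemoryDataOperatonSupportedOnDataType
{
    static constexpr bool value = false;
};

// AtomicAdd is only allowed on float/double, matching the specialization in the diff.
template <typename DataType>
struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::AtomicAdd, DataType>
{
    static constexpr bool value =
        std::is_same<DataType, float>::value || std::is_same<DataType, double>::value;
};

// A consumer can now fail at compile time instead of instantiating an invalid kernel.
template <InMemoryDataOperationEnum Op, typename OutDataType>
void write_back(OutDataType* /*dst*/, OutDataType /*value*/)
{
    static_assert(InMemoryDataOperatonSupportedOnDataType<Op, OutDataType>::value,
                  "The OutDataType does not support the selected InMemoryDataOperation!");
}

int main()
{
    float out = 0.0f;
    write_back<InMemoryDataOperationEnum::AtomicAdd>(&out, 1.0f); // float + AtomicAdd: accepted
    return 0;
}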
@@ -174,15 +174,18 @@ struct ReductionHost
              const InDataType* in_data,
              float beta,
              OutDataType* out_data,
-             IndexDataType* out_indices)
+             IndexDataType* out_indices,
+             InElementwiseOperation in_elementwise_op,
+             AccElementwiseOperation acc_elementwise_op)
     {
         if constexpr(OutputIndex)
         {
-            RunImpl_with_index(alpha, in_data, beta, out_data, out_indices);
+            RunImpl_with_index(
+                alpha, in_data, beta, out_data, out_indices, in_elementwise_op, acc_elementwise_op);
         }
         else
         {
-            RunImpl_no_index(alpha, in_data, beta, out_data);
+            RunImpl_no_index(alpha, in_data, beta, out_data, in_elementwise_op, acc_elementwise_op);
         };
     };
@@ -190,7 +193,9 @@ struct ReductionHost
                             const InDataType* in_data,
                             float beta,
                             OutDataType* out_data,
-                            IndexDataType* out_indices)
+                            IndexDataType* out_indices,
+                            InElementwiseOperation in_elementwise_op,
+                            AccElementwiseOperation acc_elementwise_op)
     {
         using ck::float_equal_one;
         using ck::float_equal_zero;
@@ -200,12 +205,10 @@ struct ReductionHost
                                                           ReduceOperation,
                                                           AccDataType,
                                                           IndexDataType>;
-        InElementwiseOperation in_elementwise_op(divider);
-        AccElementwiseOperation acc_elementwise_op(divider);
         if constexpr(NumInvariantDim == 0)
         {
-            AccDataType accuVal = ReduceOperation::GetIdentityValue();
+            AccDataType accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
             IndexDataType accuIndex = 0;
             for(std::size_t i = 0; i < reduce_dim_indexes.size(); i++)
@@ -236,7 +239,7 @@ struct ReductionHost
         else
         {
             auto thread_reduce_func = [&](auto invariant_index) {
-                AccDataType accuVal = ReduceOperation::GetIdentityValue();
+                AccDataType accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
                 IndexDataType accuIndex = 0;
                 auto offset_invariant =
@@ -297,7 +300,12 @@ struct ReductionHost
         };
     };
-    void RunImpl_no_index(float alpha, const InDataType* in_data, float beta, OutDataType* out_data)
+    void RunImpl_no_index(float alpha,
+                          const InDataType* in_data,
+                          float beta,
+                          OutDataType* out_data,
+                          InElementwiseOperation in_elementwise_op,
+                          AccElementwiseOperation acc_elementwise_op)
     {
         using ck::float_equal_one;
         using ck::float_equal_zero;
@@ -306,12 +314,9 @@ struct ReductionHost
         using Accumulation =
             ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
-        InElementwiseOperation in_elementwise_op(divider);
-        AccElementwiseOperation acc_elementwise_op(divider);
         if constexpr(NumInvariantDim == 0)
         {
-            AccDataType accuVal = ReduceOperation::GetIdentityValue();
+            AccDataType accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
             for(const auto& reduce_index : reduce_dim_indexes)
             {
@@ -338,7 +343,7 @@ struct ReductionHost
         else
         {
             auto thread_reduce_func = [&](auto invariant_index) {
-                AccDataType accuVal = ReduceOperation::GetIdentityValue();
+                AccDataType accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
                 auto offset_invariant =
                     get_offset_from_index<NumInvariantDim>(invariantStrides, invariant_index);
......
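In the host reference reduction above, the element-wise operations are no longer constructed internally from the stored divider; the caller builds them and passes them into Run. Below is a hedged, self-contained sketch of that calling convention; PassThrough, Divide and reduce_host are illustrative stand-ins for the idea, not CK's actual types.

#include <cstddef>
#include <vector>

// Illustrative element-wise functors: the caller owns their state (e.g. the divider
// needed for an AVG reduction), so the host reduction no longer has to know it.
struct PassThrough
{
    void operator()(float& y, float x) const { y = x; }
};

struct Divide
{
    explicit Divide(std::size_t divider) : divider_(divider) {}
    void operator()(float& y, float x) const { y = x / static_cast<float>(divider_); }
    std::size_t divider_;
};

// Stand-in for ReductionHost::Run: the element-wise ops arrive as parameters.
template <typename InElementwiseOp, typename AccElementwiseOp>
float reduce_host(const std::vector<float>& in, InElementwiseOp in_op, AccElementwiseOp acc_op)
{
    float acc = 0.0f; // identity of Add
    for(float x : in)
    {
        float v = 0.0f;
        in_op(v, x); // applied to each input element
        acc += v;
    }
    float out = 0.0f;
    acc_op(out, acc); // applied once to the accumulated value
    return out;
}

int main()
{
    std::vector<float> data{1.0f, 2.0f, 3.0f, 6.0f};
    // AVG-style reduction: the caller constructs Divide with the element count.
    float avg = reduce_host(data, PassThrough{}, Divide{data.size()});
    return avg == 3.0f ? 0 : 1;
}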
@@ -106,9 +106,8 @@ struct ReferenceConvBwdData : public device::BaseOperator
                     }
                 }
-                float v_in;
-                arg.in_element_op_(v_in, v_acc);
-                arg.input_(n, c, wi) = ck::type_convert<InDataType>(v_in);
+                arg.in_element_op_(v_acc, v_acc);
+                arg.input_(n, c, wi) = ck::type_convert<InDataType>(v_acc);
             };
             make_ParallelTensorFunctor(f_ncw,
......
@@ -66,8 +66,8 @@ struct ReferenceGemmBias2D : public device::BaseOperator
                 for(int k = 0; k < K; ++k)
                 {
-                    arg.a_element_op_(a, arg.a_m_k_(m, k));
-                    arg.b_element_op_(b, arg.b_k_n_(k, n));
+                    arg.a_element_op_(a, static_cast<AccDataType>(arg.a_m_k_(m, k)));
+                    arg.b_element_op_(b, static_cast<AccDataType>(arg.b_k_n_(k, n)));
                     acc += a * b;
                 }
......
@@ -61,10 +61,10 @@ using reduce_configuration_2_instances_blockwise = std::tuple<
     >;
 #endif
-template <typename AccDataType, ReduceTensorOp ReduceOpId>
+template <ReduceTensorOp ReduceOpId>
 using deviceReduceBlockWisePtrType = DeviceReducePtr<
-    typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation,
-    typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation>;
+    typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation,
+    typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation>;
 template <typename InDataType,
           typename AccDataType,
@@ -75,14 +75,13 @@ template <typename InDataType,
           bool PropagateNan,
           bool UseIndex>
 void add_device_reduce_instance_blockwise(
-    std::vector<deviceReduceBlockWisePtrType<AccDataType, ReduceOpId>>& device_op_instances)
+    std::vector<deviceReduceBlockWisePtrType<ReduceOpId>>& device_op_instances)
 {
-    using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
+    using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
     using InElementwiseOperation =
-        typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
+        typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
     using AccElementwiseOperation =
-        typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::
-            AccElementwiseOperation;
+        typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
     constexpr bool Indexable =
         (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
@@ -137,7 +136,7 @@ void add_device_reduce_instance_blockwise(
         ReduceOpId, \
         PropagateNan, \
         UseIndex>( \
-        std::vector<deviceReduceBlockWisePtrType<compT, ReduceOpId>> & device_op_instances)
+        std::vector<deviceReduceBlockWisePtrType<ReduceOpId>> & device_op_instances)
 #define ADD_BLOCKWISE_INST_BY_ID( \
     inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
@@ -150,21 +149,17 @@ void add_device_reduce_instance_blockwise(
         Rank, \
         NumReduceDim)
 #define ADD_BLOCKWISE_INST_REF_BY_TYPE( \
     inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
     extern template void add_device_reduce_instance_blockwise<inT, \
                                                                compT, \
                                                                outT, \
                                                                Rank, \
                                                                NumReduceDim, \
                                                                ReduceOpId, \
                                                                PropagateNan, \
                                                                UseIndex>( \
-        std::vector<DeviceReducePtr< \
-            typename reduce_unary_operator<compT, ReduceOpId, true, true>::InElementwiseOperation, \
-            typename reduce_unary_operator<compT, ReduceOpId, true, true>:: \
-                AccElementwiseOperation>> & \
-            device_op_instances)
+        std::vector<deviceReduceBlockWisePtrType<ReduceOpId>> & device_op_instances)
 #define ADD_BLOCKWISE_INST_REF_BY_ID( \
     inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
......
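These instance headers drop AccDataType from reduce_binary_operator and reduce_unary_operator, so a single enum-to-type mapping now serves every accumulator type. Below is a minimal sketch of that kind of mapping; names ending in _sketch are illustrative, not the real CK traits.

enum class ReduceTensorOp { ADD, MUL };

// Data-type-free accumulator classes, as in the diff: the type appears only on operator().
struct Add
{
    template <typename T>
    void operator()(T& a, T b) const { a = a + b; }
};
struct Mul
{
    template <typename T>
    void operator()(T& a, T b) const { a = a * b; }
};

// Mapping from the enum to the operation class no longer needs AccDataType,
// mirroring the shape of reduce_binary_operator<ReduceOpId>::opType after this commit.
template <ReduceTensorOp ReduceOpId>
struct reduce_binary_operator_sketch;

template <>
struct reduce_binary_operator_sketch<ReduceTensorOp::ADD>
{
    using opType = Add;
};
template <>
struct reduce_binary_operator_sketch<ReduceTensorOp::MUL>
{
    using opType = Mul;
};

int main()
{
    using Op = reduce_binary_operator_sketch<ReduceTensorOp::ADD>::opType;
    double acc = 0.0; // the same Op works for double ...
    Op{}(acc, 1.5);
    float accf = 0.0f; // ... and for float, without re-instantiating the mapping
    Op{}(accf, 1.5f);
    return (acc == 1.5 && accf == 1.5f) ? 0 : 1;
}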
@@ -61,12 +61,10 @@ using reduce_configuration_2_instances_multiblock_atomic_add = std::tuple<
     >;
 #endif
-template <typename AccDataType, ReduceTensorOp ReduceOperation>
-using deviceReduceMultiBlockAtomicAddPtrType =
-    DeviceReducePtr<typename reduce_unary_operator<AccDataType, ReduceOperation, true, true>::
-                        InElementwiseOperation,
-                    typename reduce_unary_operator<AccDataType, ReduceOperation, true, true>::
-                        AccElementwiseOperation>;
+template <ReduceTensorOp ReduceOperation>
+using deviceReduceMultiBlockAtomicAddPtrType = DeviceReducePtr<
+    typename reduce_unary_operator<ReduceOperation, true, true>::InElementwiseOperation,
+    typename reduce_unary_operator<ReduceOperation, true, true>::AccElementwiseOperation>;
 template <typename InDataType,
           typename AccDataType,
@@ -77,15 +75,13 @@ template <typename InDataType,
           bool PropagateNan,
           bool UseIndex>
 void add_device_reduce_instance_multiblock_atomic_add(
-    std::vector<deviceReduceMultiBlockAtomicAddPtrType<AccDataType, ReduceOpId>>&
-        device_op_instances)
+    std::vector<deviceReduceMultiBlockAtomicAddPtrType<ReduceOpId>>& device_op_instances)
 {
-    using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
+    using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
     using InElementwiseOperation =
-        typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
+        typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
     using AccElementwiseOperation =
-        typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::
-            AccElementwiseOperation;
+        typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
     constexpr bool Indexable =
         (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
@@ -158,8 +154,7 @@ void add_device_reduce_instance_multiblock_atomic_add(
         ReduceOpId, \
         PropagateNan, \
         UseIndex>( \
-        std::vector<deviceReduceMultiBlockAtomicAddPtrType<compT, ReduceOpId>> & \
-            device_op_instances)
+        std::vector<deviceReduceMultiBlockAtomicAddPtrType<ReduceOpId>> & device_op_instances)
 #define ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID( \
     inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
@@ -172,21 +167,17 @@ void add_device_reduce_instance_multiblock_atomic_add(
         Rank, \
         NumReduceDim)
 #define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_TYPE( \
     inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
     extern template void add_device_reduce_instance_multiblock_atomic_add<inT, \
                                                                            compT, \
                                                                            outT, \
                                                                            Rank, \
                                                                            NumReduceDim, \
                                                                            ReduceOpId, \
                                                                            PropagateNan, \
                                                                            UseIndex>( \
-        std::vector<DeviceReducePtr< \
-            typename reduce_unary_operator<compT, ReduceOpId, true, true>::InElementwiseOperation, \
-            typename reduce_unary_operator<compT, ReduceOpId, true, true>:: \
-                AccElementwiseOperation>> & \
-            device_op_instances)
+        std::vector<deviceReduceMultiBlockAtomicAddPtrType<ReduceOpId>> & device_op_instances)
 #define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID( \
     inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
......
@@ -47,10 +47,10 @@ using reduce_configuration_2_instances_threadwise = std::tuple<
     >;
 #endif
-template <typename AccDataType, ReduceTensorOp ReduceOpId>
+template <ReduceTensorOp ReduceOpId>
 using deviceReduceThreadWisePtrType = DeviceReducePtr<
-    typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation,
-    typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation>;
+    typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation,
+    typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation>;
 template <typename InDataType,
           typename AccDataType,
@@ -61,14 +61,13 @@ template <typename InDataType,
           bool PropagateNan,
           bool UseIndex>
 void add_device_reduce_instance_threadwise(
-    std::vector<deviceReduceThreadWisePtrType<AccDataType, ReduceOpId>>& device_op_instances)
+    std::vector<deviceReduceThreadWisePtrType<ReduceOpId>>& device_op_instances)
 {
-    using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
+    using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
     using InElementwiseOperation =
-        typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
+        typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
     using AccElementwiseOperation =
-        typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::
-            AccElementwiseOperation;
+        typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
     constexpr bool Indexable =
         (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
@@ -114,7 +113,7 @@ void add_device_reduce_instance_threadwise(
         ReduceOpId, \
         PropagateNan, \
         UseIndex>( \
-        std::vector<deviceReduceThreadWisePtrType<compT, ReduceOpId>> & device_op_instances)
+        std::vector<deviceReduceThreadWisePtrType<ReduceOpId>> & device_op_instances)
 #define ADD_THREADWISE_INST_BY_ID( \
     inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
@@ -127,21 +126,17 @@ void add_device_reduce_instance_threadwise(
         Rank, \
         NumReduceDim)
 #define ADD_THREADWISE_INST_REF_BY_TYPE( \
     inT, compT, outT, ReduceOpId, PropagateNan, UseIndex, Rank, NumReduceDim) \
     extern template void add_device_reduce_instance_threadwise<inT, \
                                                                compT, \
                                                                outT, \
                                                                Rank, \
                                                                NumReduceDim, \
                                                                ReduceOpId, \
                                                                PropagateNan, \
                                                                UseIndex>( \
-        std::vector<DeviceReducePtr< \
-            typename reduce_unary_operator<compT, ReduceOpId, true, true>::InElementwiseOperation, \
-            typename reduce_unary_operator<compT, ReduceOpId, true, true>:: \
-                AccElementwiseOperation>> & \
-            device_op_instances)
+        std::vector<deviceReduceThreadWisePtrType<ReduceOpId>> & device_op_instances)
 #define ADD_THREADWISE_INST_REF_BY_ID( \
     inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim) \
......
@@ -21,11 +21,11 @@ template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-using ReduceSum = ck::reduce::Add<F32>;
+using ReduceSum = ck::reduce::Add;
 using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
-using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
-using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
+using Identity = ck::tensor_operation::element_wise::PassThrough;
+using Square = ck::tensor_operation::element_wise::UnarySquare;
 using DInElementOps = ck::Tuple<Identity, Square>;
 using DOutElementOps = ck::Tuple<Identity, Identity>;
......
@@ -21,11 +21,11 @@ template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-using ReduceSum = ck::reduce::Add<F32>;
+using ReduceSum = ck::reduce::Add;
 using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
-using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
-using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
+using Identity = ck::tensor_operation::element_wise::PassThrough;
+using Square = ck::tensor_operation::element_wise::UnarySquare;
 using DInElementOps = ck::Tuple<Identity, Square>;
 using DOutElementOps = ck::Tuple<Identity, Identity>;
......
@@ -21,11 +21,11 @@ template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-using ReduceSum = ck::reduce::Add<F32>;
+using ReduceSum = ck::reduce::Add;
 using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
-using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
-using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
+using Identity = ck::tensor_operation::element_wise::PassThrough;
+using Square = ck::tensor_operation::element_wise::UnarySquare;
 using DInElementOps = ck::Tuple<Identity, Square>;
 using DOutElementOps = ck::Tuple<Identity, Identity>;
......
@@ -21,11 +21,11 @@ template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-using ReduceSum = ck::reduce::Add<F32>;
+using ReduceSum = ck::reduce::Add;
 using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
-using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
-using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
+using Identity = ck::tensor_operation::element_wise::PassThrough;
+using Square = ck::tensor_operation::element_wise::UnarySquare;
 using DInElementOps = ck::Tuple<Identity, Square>;
 using DOutElementOps = ck::Tuple<Identity, Identity>;
......
@@ -21,12 +21,12 @@ template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-using ReduceSum = ck::reduce::Add<F32>;
+using ReduceSum = ck::reduce::Add;
 using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
-using Div = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, true>;
-using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
-using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
+using Div = ck::tensor_operation::element_wise::UnaryDivide;
+using Identity = ck::tensor_operation::element_wise::PassThrough;
+using Square = ck::tensor_operation::element_wise::UnarySquare;
 using DInElementOps = ck::Tuple<Identity, Square>;
 using DOutElementOps = ck::Tuple<Div, Div>;
......
@@ -21,12 +21,12 @@ template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-using ReduceSum = ck::reduce::Add<F32>;
+using ReduceSum = ck::reduce::Add;
 using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
-using Div = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, true>;
-using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
-using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
+using Div = ck::tensor_operation::element_wise::UnaryDivide;
+using Identity = ck::tensor_operation::element_wise::PassThrough;
+using Square = ck::tensor_operation::element_wise::UnarySquare;
 using DInElementOps = ck::Tuple<Identity, Square>;
 using DOutElementOps = ck::Tuple<Div, Div>;
......
@@ -21,12 +21,12 @@ template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-using ReduceSum = ck::reduce::Add<F32>;
+using ReduceSum = ck::reduce::Add;
 using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
-using Div = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, true>;
-using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
-using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
+using Div = ck::tensor_operation::element_wise::UnaryDivide;
+using Identity = ck::tensor_operation::element_wise::PassThrough;
+using Square = ck::tensor_operation::element_wise::UnarySquare;
 using DInElementOps = ck::Tuple<Identity, Square>;
 using DOutElementOps = ck::Tuple<Div, Div>;
......
@@ -21,12 +21,12 @@ template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-using ReduceSum = ck::reduce::Add<F32>;
+using ReduceSum = ck::reduce::Add;
 using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
-using Div = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, true>;
-using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
-using Square = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
+using Div = ck::tensor_operation::element_wise::UnaryDivide;
+using Identity = ck::tensor_operation::element_wise::PassThrough;
+using Square = ck::tensor_operation::element_wise::UnarySquare;
 using DInElementOps = ck::Tuple<Identity, Square>;
 using DOutElementOps = ck::Tuple<Div, Div>;
......