Commit 6805df0e authored by Chao Liu's avatar Chao Liu

Merge remote-tracking branch 'origin/develop' into gelu

parents 1fdbe3fe e4584d91
@@ -28,6 +28,7 @@
 #include "config.hpp"
 #include "data_type.hpp"
+#include "type.hpp"
 
 namespace ck {
@@ -54,64 +55,92 @@ namespace reduce {
 // accumulated index also need be
 // changed.
 
-template <class T>
 struct Add
 {
-    using dataType = T;
-
-    __host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(0.0f); };
+    template <typename T>
+    __host__ __device__ static constexpr T GetIdentityValue()
+    {
+        return type_convert<T>(0.0f);
+    };
 
-    __device__ static constexpr bool
+    __host__ __device__ static constexpr bool
     IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
     {
         return operation == InMemoryDataOperationEnum::AtomicAdd ||
                operation == InMemoryDataOperationEnum::Set;
     };
 
-    __host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a + b; }
+    template <typename T>
+    __host__ __device__ inline constexpr void operator()(T& a, T b) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, int32_t>::value,
+                      "The data type is not supported by the Add accumulator!");
+
+        a = a + b;
+    }
 };
 
-template <class T>
 struct Mul
 {
-    using dataType = T;
-
-    __host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(1.0f); };
+    template <typename T>
+    __host__ __device__ static constexpr T GetIdentityValue()
+    {
+        return type_convert<T>(1.0f);
+    };
 
-    __device__ static constexpr bool
+    __host__ __device__ static constexpr bool
     IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
     {
         return operation == InMemoryDataOperationEnum::Set;
     };
 
-    __host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a * b; }
+    template <typename T>
+    __host__ __device__ inline constexpr void operator()(T& a, T b) const
+    {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, int32_t>::value,
+                      "The data type is not supported by the Mul accumulator!");
+
+        a = a * b;
+    }
 };
 
-template <class T>
 struct Max
 {
-    using dataType = T;
-
+    template <typename T>
     __host__ __device__ static constexpr T GetIdentityValue()
     {
         return NumericLimits<T>::Lowest();
     };
 
-    __device__ static constexpr bool
+    __host__ __device__ static constexpr bool
     IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
     {
         // ToChange: atomic_max to be added
         return operation == InMemoryDataOperationEnum::Set;
     };
 
+    template <typename T>
     __host__ __device__ inline constexpr void operator()(T& a, T b) const
     {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "The data type is not supported by the Max accumulator!");
+
         if(a < b)
             a = b;
     }
 
+    template <typename T>
     __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
     {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "The data type is not supported by the Max accumulator!");
+
         if(a < b)
         {
             a = b;
@@ -120,28 +149,41 @@ struct Max
         }
     }
 };
 
-template <class T>
 struct Min
 {
-    using dataType = T;
-
-    __host__ __device__ static constexpr T GetIdentityValue() { return NumericLimits<T>::Max(); };
+    template <typename T>
+    __host__ __device__ static constexpr T GetIdentityValue()
+    {
+        return NumericLimits<T>::Max();
+    };
 
-    __device__ static constexpr bool
+    __host__ __device__ static constexpr bool
     IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
     {
         // ToChange: atomic_min to be added
         return operation == InMemoryDataOperationEnum::Set;
     };
 
+    template <typename T>
     __host__ __device__ inline constexpr void operator()(T& a, T b) const
     {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "The data type is not supported by the Min accumulator!");
+
         if(a > b)
             a = b;
     }
 
+    template <typename T>
     __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
     {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "The data type is not supported by the Min accumulator!");
+
         if(a > b)
         {
             a = b;
@@ -150,28 +192,41 @@ struct Min
         }
     }
 };
 
-template <class T>
 struct AMax
 {
-    using dataType = T;
-
-    __host__ __device__ static constexpr T GetIdentityValue() { return static_cast<T>(0.0f); };
+    template <typename T>
+    __host__ __device__ static constexpr T GetIdentityValue()
+    {
+        return type_convert<T>(0.0f);
+    };
 
-    __device__ static constexpr bool
+    __host__ __device__ static constexpr bool
     IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
     {
         // ToChange: atomic_max to be added
         return operation == InMemoryDataOperationEnum::Set;
     };
 
+    template <typename T>
     __host__ __device__ inline constexpr void operator()(T& a, T b) const
     {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "The data type is not supported by the AMax accumulator!");
+
         if(a < b)
             a = b;
     }
 
+    template <typename T>
     __host__ __device__ inline constexpr void operator()(T& a, T b, bool& changed) const
     {
+        static_assert(is_same<T, float>::value || is_same<T, double>::value ||
+                          is_same<T, half_t>::value || is_same<T, int32_t>::value ||
+                          is_same<T, int8_t>::value,
+                      "The data type is not supported by the AMax accumulator!");
+
         if(a < b)
         {
             a = b;
@@ -181,7 +236,7 @@ struct AMax
 };
 
 template <typename T>
-T GetIdentityValueueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
+constexpr T GetIdentityValueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
 {
     T result = ck::type_convert<T>(0.0f);
@@ -191,6 +246,44 @@ T GetIdentityValueueForInMemoryDataOperation(InMemoryDataOperationEnum operation)
     return (result);
 };
 
+template <InMemoryDataOperationEnum Operation, typename DataType>
+struct InMemoryDataOperatonSupportedOnDataType
+{
+    static constexpr bool value = false;
+};
+
+template <typename DataType>
+struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::AtomicAdd, DataType>
+{
+    static constexpr bool value =
+        is_same<DataType, float>::value || is_same<DataType, double>::value;
+};
+
+template <typename DataType>
+struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::AtomicMax, DataType>
+{
+    static constexpr bool value =
+        is_same<DataType, float>::value || is_same<DataType, double>::value;
+};
+
+template <typename DataType>
+struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::Set, DataType>
+{
+    static constexpr bool value =
+        is_same<DataType, float>::value || is_same<DataType, double>::value ||
+        is_same<DataType, half_t>::value || is_same<DataType, bhalf_t>::value ||
+        is_same<DataType, int8_t>::value || is_same<DataType, int32_t>::value;
+};
+
+template <typename DataType>
+struct InMemoryDataOperatonSupportedOnDataType<InMemoryDataOperationEnum::Add, DataType>
+{
+    static constexpr bool value =
+        is_same<DataType, float>::value || is_same<DataType, double>::value ||
+        is_same<DataType, half_t>::value || is_same<DataType, int8_t>::value ||
+        is_same<DataType, int32_t>::value;
+};
+
 }; // end of namespace reduce
 } // end of namespace ck
...
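Note on the hunk above: the accumulator structs (Add, Mul, Max, Min, AMax) lose their class-level type parameter; the element type is now supplied per call, and a static_assert rejects unsupported types. A minimal host-side sketch of the new interface follows; the variable names and values are illustrative, not part of the commit:

    // The identity value is requested per call rather than baked into the type.
    float acc = ck::reduce::Add::GetIdentityValue<float>(); // 0.0f

    ck::reduce::Add add_op{};
    add_op(acc, 1.0f); // T = float deduced; an unsupported T would fail the static_assert
    add_op(acc, 2.0f); // acc == 3.0f

    // The new trait (identifier spelled as committed) can gate store paths at compile time:
    static_assert(ck::reduce::InMemoryDataOperatonSupportedOnDataType<
                      ck::InMemoryDataOperationEnum::AtomicAdd, float>::value,
                  "AtomicAdd is expected to support float");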
@@ -174,15 +174,18 @@ struct ReductionHost
              const InDataType* in_data,
              float beta,
              OutDataType* out_data,
-             IndexDataType* out_indices)
+             IndexDataType* out_indices,
+             InElementwiseOperation in_elementwise_op,
+             AccElementwiseOperation acc_elementwise_op)
     {
         if constexpr(OutputIndex)
         {
-            RunImpl_with_index(alpha, in_data, beta, out_data, out_indices);
+            RunImpl_with_index(
+                alpha, in_data, beta, out_data, out_indices, in_elementwise_op, acc_elementwise_op);
         }
         else
         {
-            RunImpl_no_index(alpha, in_data, beta, out_data);
+            RunImpl_no_index(alpha, in_data, beta, out_data, in_elementwise_op, acc_elementwise_op);
         };
     };
@@ -190,7 +193,9 @@ struct ReductionHost
                             const InDataType* in_data,
                             float beta,
                             OutDataType* out_data,
-                            IndexDataType* out_indices)
+                            IndexDataType* out_indices,
+                            InElementwiseOperation in_elementwise_op,
+                            AccElementwiseOperation acc_elementwise_op)
     {
         using ck::float_equal_one;
         using ck::float_equal_zero;
@@ -200,12 +205,10 @@ struct ReductionHost
                                                         ReduceOperation,
                                                         AccDataType,
                                                         IndexDataType>;
 
-        InElementwiseOperation in_elementwise_op(divider);
-        AccElementwiseOperation acc_elementwise_op(divider);
-
         if constexpr(NumInvariantDim == 0)
         {
-            AccDataType accuVal     = ReduceOperation::GetIdentityValue();
+            AccDataType accuVal     = ReduceOperation::template GetIdentityValue<AccDataType>();
             IndexDataType accuIndex = 0;
 
             for(std::size_t i = 0; i < reduce_dim_indexes.size(); i++)
@@ -236,7 +239,7 @@ struct ReductionHost
         else
         {
             auto thread_reduce_func = [&](auto invariant_index) {
-                AccDataType accuVal     = ReduceOperation::GetIdentityValue();
+                AccDataType accuVal     = ReduceOperation::template GetIdentityValue<AccDataType>();
                 IndexDataType accuIndex = 0;
 
                 auto offset_invariant =
@@ -297,7 +300,12 @@ struct ReductionHost
         };
     };
 
-    void RunImpl_no_index(float alpha, const InDataType* in_data, float beta, OutDataType* out_data)
+    void RunImpl_no_index(float alpha,
+                          const InDataType* in_data,
+                          float beta,
+                          OutDataType* out_data,
+                          InElementwiseOperation in_elementwise_op,
+                          AccElementwiseOperation acc_elementwise_op)
     {
         using ck::float_equal_one;
         using ck::float_equal_zero;
@@ -306,12 +314,9 @@ struct ReductionHost
         using Accumulation =
             ck::detail::AccumulateWithNanCheck<PropagateNan, ReduceOperation, AccDataType>;
 
-        InElementwiseOperation in_elementwise_op(divider);
-        AccElementwiseOperation acc_elementwise_op(divider);
-
         if constexpr(NumInvariantDim == 0)
         {
-            AccDataType accuVal = ReduceOperation::GetIdentityValue();
+            AccDataType accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
 
             for(const auto& reduce_index : reduce_dim_indexes)
             {
@@ -338,7 +343,7 @@ struct ReductionHost
         else
         {
             auto thread_reduce_func = [&](auto invariant_index) {
-                AccDataType accuVal = ReduceOperation::GetIdentityValue();
+                AccDataType accuVal = ReduceOperation::template GetIdentityValue<AccDataType>();
 
                 auto offset_invariant =
                     get_offset_from_index<NumInvariantDim>(invariantStrides, invariant_index);
...
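The host reference reduction above no longer builds its elementwise operations internally from a stored divider; callers construct them and pass them through Run. A hedged sketch of the adjusted call site, where the host-reduce instance and the op constructor arguments are assumptions for illustration:

    // Assuming an AVG-style reduction whose acc-side op divides by the reduce
    // length; the exact constructor signatures depend on the chosen op types.
    InElementwiseOperation in_elementwise_op{};
    AccElementwiseOperation acc_elementwise_op{static_cast<int32_t>(reduce_length)};

    // host_reduce is a previously constructed ReductionHost<...> instance.
    host_reduce.Run(alpha, in_data, beta, out_data, out_indices,
                    in_elementwise_op, acc_elementwise_op);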
@@ -106,9 +106,8 @@ struct ReferenceConvBwdData : public device::BaseOperator
                     }
                 }
 
-                float v_in;
-                arg.in_element_op_(v_in, v_acc);
-                arg.input_(n, c, wi) = ck::type_convert<InDataType>(v_in);
+                arg.in_element_op_(v_acc, v_acc);
+                arg.input_(n, c, wi) = ck::type_convert<InDataType>(v_acc);
             };
 
             make_ParallelTensorFunctor(f_ncw,
...
@@ -66,8 +66,8 @@ struct ReferenceGemmBias2D : public device::BaseOperator
                 for(int k = 0; k < K; ++k)
                 {
-                    arg.a_element_op_(a, arg.a_m_k_(m, k));
-                    arg.b_element_op_(b, arg.b_k_n_(k, n));
+                    arg.a_element_op_(a, static_cast<AccDataType>(arg.a_m_k_(m, k)));
+                    arg.b_element_op_(b, static_cast<AccDataType>(arg.b_k_n_(k, n)));
 
                     acc += a * b;
                 }
...
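The reference GEMM now widens each A/B element to AccDataType before the elementwise ops run, so e.g. half-precision inputs are processed at accumulation precision. A standalone plain-C++ analogue of the patched inner loop; the array names, index math, and extents are illustrative:

    // Widen each input to the accumulation type first, then accumulate.
    float acc = 0.0f;
    for(int k = 0; k < K; ++k)
    {
        float a = static_cast<float>(a_m_k[m * K + k]); // e.g. half -> float
        float b = static_cast<float>(b_k_n[k * N + n]);
        acc += a * b; // accumulation stays in float throughout
    }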
@@ -61,10 +61,10 @@ using reduce_configuration_2_instances_blockwise = std::tuple<
     >;
 #endif
 
-template <typename AccDataType, ReduceTensorOp ReduceOpId>
+template <ReduceTensorOp ReduceOpId>
 using deviceReduceBlockWisePtrType = DeviceReducePtr<
-    typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation,
-    typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation>;
+    typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation,
+    typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation>;
 
 template <typename InDataType,
           typename AccDataType,
@@ -75,14 +75,13 @@ template <typename InDataType,
           bool PropagateNan,
           bool UseIndex>
 void add_device_reduce_instance_blockwise(
-    std::vector<deviceReduceBlockWisePtrType<AccDataType, ReduceOpId>>& device_op_instances)
+    std::vector<deviceReduceBlockWisePtrType<ReduceOpId>>& device_op_instances)
 {
-    using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
+    using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
     using InElementwiseOperation =
-        typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
+        typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
     using AccElementwiseOperation =
-        typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::
-            AccElementwiseOperation;
+        typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
 
     constexpr bool Indexable =
         (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
@@ -137,7 +136,7 @@ void add_device_reduce_instance_blockwise(
                                              ReduceOpId,   \
                                              PropagateNan, \
                                              UseIndex>(    \
-        std::vector<deviceReduceBlockWisePtrType<compT, ReduceOpId>> & device_op_instances)
+        std::vector<deviceReduceBlockWisePtrType<ReduceOpId>> & device_op_instances)
 
 #define ADD_BLOCKWISE_INST_BY_ID(                                          \
     inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim)  \
@@ -160,11 +159,7 @@ void add_device_reduce_instance_blockwise(
                                              ReduceOpId,   \
                                              PropagateNan, \
                                              UseIndex>(    \
-        std::vector<DeviceReducePtr<                                                               \
-            typename reduce_unary_operator<compT, ReduceOpId, true, true>::InElementwiseOperation, \
-            typename reduce_unary_operator<compT, ReduceOpId, true, true>::                        \
-                AccElementwiseOperation>> &                                                        \
-            device_op_instances)
+        std::vector<deviceReduceBlockWisePtrType<ReduceOpId>> & device_op_instances)
 
 #define ADD_BLOCKWISE_INST_REF_BY_ID(                                      \
     inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim)  \
...
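With AccDataType dropped from reduce_unary_operator and reduce_binary_operator, the instance-pointer alias is keyed by the reduce-op id alone. A sketch of how a caller might collect blockwise instances under the new signature; the template-argument order is an assumption inferred from the macro's inT/compT/outT/.../Rank/NumReduceDim ordering, and the concrete types are illustrative:

    std::vector<deviceReduceBlockWisePtrType<ReduceTensorOp::AVG>> instances;

    add_device_reduce_instance_blockwise<ck::half_t,           // InDataType (inT)
                                         float,                // AccDataType (compT)
                                         ck::half_t,           // OutDataType (outT)
                                         4,                    // Rank
                                         2,                    // NumReduceDim
                                         ReduceTensorOp::AVG,  // ReduceOpId
                                         false,                // PropagateNan
                                         false>(               // UseIndex
        instances);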
@@ -61,12 +61,10 @@ using reduce_configuration_2_instances_multiblock_atomic_add = std::tuple<
     >;
 #endif
 
-template <typename AccDataType, ReduceTensorOp ReduceOperation>
-using deviceReduceMultiBlockAtomicAddPtrType =
-    DeviceReducePtr<typename reduce_unary_operator<AccDataType, ReduceOperation, true, true>::
-                        InElementwiseOperation,
-                    typename reduce_unary_operator<AccDataType, ReduceOperation, true, true>::
-                        AccElementwiseOperation>;
+template <ReduceTensorOp ReduceOperation>
+using deviceReduceMultiBlockAtomicAddPtrType = DeviceReducePtr<
+    typename reduce_unary_operator<ReduceOperation, true, true>::InElementwiseOperation,
+    typename reduce_unary_operator<ReduceOperation, true, true>::AccElementwiseOperation>;
 
 template <typename InDataType,
           typename AccDataType,
@@ -77,15 +75,13 @@ template <typename InDataType,
           bool PropagateNan,
           bool UseIndex>
 void add_device_reduce_instance_multiblock_atomic_add(
-    std::vector<deviceReduceMultiBlockAtomicAddPtrType<AccDataType, ReduceOpId>>&
-        device_op_instances)
+    std::vector<deviceReduceMultiBlockAtomicAddPtrType<ReduceOpId>>& device_op_instances)
 {
-    using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
+    using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
     using InElementwiseOperation =
-        typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
+        typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
     using AccElementwiseOperation =
-        typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::
-            AccElementwiseOperation;
+        typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
 
     constexpr bool Indexable =
         (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
@@ -158,8 +154,7 @@ void add_device_reduce_instance_multiblock_atomic_add(
                                                             ReduceOpId,   \
                                                             PropagateNan, \
                                                             UseIndex>(    \
-        std::vector<deviceReduceMultiBlockAtomicAddPtrType<compT, ReduceOpId>> & \
-            device_op_instances)
+        std::vector<deviceReduceMultiBlockAtomicAddPtrType<ReduceOpId>> & device_op_instances)
 
 #define ADD_MULTIBLOCK_ATOMIC_ADD_INST_BY_ID(                              \
     inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim)  \
@@ -182,11 +177,7 @@ void add_device_reduce_instance_multiblock_atomic_add(
                                                             ReduceOpId,   \
                                                             PropagateNan, \
                                                             UseIndex>(    \
-        std::vector<DeviceReducePtr<                                                               \
-            typename reduce_unary_operator<compT, ReduceOpId, true, true>::InElementwiseOperation, \
-            typename reduce_unary_operator<compT, ReduceOpId, true, true>::                        \
-                AccElementwiseOperation>> &                                                        \
-            device_op_instances)
+        std::vector<deviceReduceMultiBlockAtomicAddPtrType<ReduceOpId>> & device_op_instances)
 
 #define ADD_MULTIBLOCK_ATOMIC_ADD_INST_REF_BY_ID(                          \
     inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim)  \
...
@@ -47,10 +47,10 @@ using reduce_configuration_2_instances_threadwise = std::tuple<
     >;
 #endif
 
-template <typename AccDataType, ReduceTensorOp ReduceOpId>
+template <ReduceTensorOp ReduceOpId>
 using deviceReduceThreadWisePtrType = DeviceReducePtr<
-    typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation,
-    typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::AccElementwiseOperation>;
+    typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation,
+    typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation>;
 
 template <typename InDataType,
           typename AccDataType,
@@ -61,14 +61,13 @@ template <typename InDataType,
           bool PropagateNan,
           bool UseIndex>
 void add_device_reduce_instance_threadwise(
-    std::vector<deviceReduceThreadWisePtrType<AccDataType, ReduceOpId>>& device_op_instances)
+    std::vector<deviceReduceThreadWisePtrType<ReduceOpId>>& device_op_instances)
 {
-    using ReduceOperation = typename reduce_binary_operator<AccDataType, ReduceOpId>::opType;
+    using ReduceOperation = typename reduce_binary_operator<ReduceOpId>::opType;
     using InElementwiseOperation =
-        typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::InElementwiseOperation;
+        typename reduce_unary_operator<ReduceOpId, true, true>::InElementwiseOperation;
     using AccElementwiseOperation =
-        typename reduce_unary_operator<AccDataType, ReduceOpId, true, true>::
-            AccElementwiseOperation;
+        typename reduce_unary_operator<ReduceOpId, true, true>::AccElementwiseOperation;
 
     constexpr bool Indexable =
         (ReduceOpId == ReduceTensorOp::MIN || ReduceOpId == ReduceTensorOp::MAX ||
@@ -114,7 +113,7 @@ void add_device_reduce_instance_threadwise(
                                               ReduceOpId,   \
                                               PropagateNan, \
                                               UseIndex>(    \
-        std::vector<deviceReduceThreadWisePtrType<compT, ReduceOpId>> & device_op_instances)
+        std::vector<deviceReduceThreadWisePtrType<ReduceOpId>> & device_op_instances)
 
 #define ADD_THREADWISE_INST_BY_ID(                                         \
     inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim)  \
@@ -137,11 +136,7 @@ void add_device_reduce_instance_threadwise(
                                               ReduceOpId,   \
                                               PropagateNan, \
                                               UseIndex>(    \
-        std::vector<DeviceReducePtr<                                                               \
-            typename reduce_unary_operator<compT, ReduceOpId, true, true>::InElementwiseOperation, \
-            typename reduce_unary_operator<compT, ReduceOpId, true, true>::                        \
-                AccElementwiseOperation>> &                                                        \
-            device_op_instances)
+        std::vector<deviceReduceThreadWisePtrType<ReduceOpId>> & device_op_instances)
 
 #define ADD_THREADWISE_INST_REF_BY_ID(                                     \
     inT, compT, outT, ReduceOpId, NanOpt, IndicesOpt, Rank, NumReduceDim)  \
...
@@ -30,6 +30,7 @@ add_subdirectory(gemm_bias2d)
 add_subdirectory(gemm_bias_relu)
 add_subdirectory(gemm_bias_relu_add)
 add_subdirectory(gemm_reduce)
+add_subdirectory(gemm_bias_add_reduce)
 add_subdirectory(batched_gemm)
 add_subdirectory(conv1d_fwd)
 add_subdirectory(conv2d_fwd)
...
@@ -21,11 +21,11 @@ template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
 
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-using ReduceSum = ck::reduce::Add<F32>;
+using ReduceSum = ck::reduce::Add;
 using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
 
-using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
-using Square   = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
+using Identity = ck::tensor_operation::element_wise::PassThrough;
+using Square   = ck::tensor_operation::element_wise::UnarySquare;
 using DInElementOps  = ck::Tuple<Identity, Square>;
 using DOutElementOps = ck::Tuple<Identity, Identity>;
@@ -62,12 +62,9 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances
     >;
 
 void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gkn_gmn_instances(
-    std::vector<DeviceGemmReducePtr<DPtrsGlobal,
-                                    PassThrough,
-                                    PassThrough,
-                                    PassThrough,
-                                    DInElementOps,
-                                    DOutElementOps>>& instances)
+    std::vector<
+        DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
+        instances)
 {
     add_device_operation_instances(
         instances,
...
@@ -21,11 +21,11 @@ template <ck::index_t... Is>
 using S = ck::Sequence<Is...>;
 
 using PassThrough = ck::tensor_operation::element_wise::PassThrough;
-using ReduceSum = ck::reduce::Add<F32>;
+using ReduceSum = ck::reduce::Add;
 using ReduceOps = ck::Tuple<ReduceSum, ReduceSum>;
 
-using Identity = ck::tensor_operation::element_wise::UnaryIdentic<F32, F32, false>;
-using Square   = ck::tensor_operation::element_wise::UnarySquare<F32, F32, false>;
+using Identity = ck::tensor_operation::element_wise::PassThrough;
+using Square   = ck::tensor_operation::element_wise::UnarySquare;
 using DInElementOps  = ck::Tuple<Identity, Square>;
 using DOutElementOps = ck::Tuple<Identity, Identity>;
@@ -62,12 +62,9 @@ using device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances
     >;
 
 void add_device_batched_gemm_reduce_xdl_cshuffle_f16_f16_f16_f32_f32_gkm_gnk_gmn_instances(
-    std::vector<DeviceGemmReducePtr<DPtrsGlobal,
-                                    PassThrough,
-                                    PassThrough,
-                                    PassThrough,
-                                    DInElementOps,
-                                    DOutElementOps>>& instances)
+    std::vector<
+        DeviceGemmReducePtr<PassThrough, PassThrough, PassThrough, DInElementOps, DOutElementOps>>&
+        instances)
 {
     add_device_operation_instances(
         instances,
...
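In both batched-gemm-reduce instance files, one untyped ck::reduce::Add is now paired with per-element Identity (PassThrough) and Square ops, so the two reductions produce the sum and sum-of-squares of the GEMM output. A host-side analogue of what those two reductions compute, with illustrative data and divisor; the mean/variance step is a downstream use, not part of this commit:

    float sum = 0.0f, sum_sq = 0.0f;
    for(float c : {1.0f, 2.0f, 3.0f})
    {
        sum    += c;     // Identity feeding Add
        sum_sq += c * c; // UnarySquare feeding Add
    }
    float mean = sum / 3.0f;
    float var  = sum_sq / 3.0f - mean * mean; // E[c^2] - E[c]^2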