Commit ab8e0f28 authored by Anthony Chang's avatar Anthony Chang
Browse files

keep up with recent changes in reduction API

parent 2c1ed8b2
......@@ -923,17 +923,17 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
ThreadwiseReduction<FloatReduceAcc,
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(d_reduce_thread_desc_mperblock),
reduce::Add<FloatReduceAcc>,
reduce::Add,
false>;
using ThreadwiseReduceD1 =
ThreadwiseReduction<FloatReduceAcc,
decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(d_reduce_thread_desc_mperblock),
reduce::SquaredAdd<FloatReduceAcc>,
reduce::SquaredAdd,
false>;
const auto d0_zeroVal = ThreadwiseReduceD0::Op::GetIdentityValue();
const auto d1_zeroVal = ThreadwiseReduceD1::Op::GetIdentityValue();
const auto d0_zeroVal = ThreadwiseReduceD0::Op::template GetIdentityValue<FloatReduceAcc>();
const auto d1_zeroVal = ThreadwiseReduceD1::Op::template GetIdentityValue<FloatReduceAcc>();
static_for<0, mreduce_per_thread, 1>{}(
[&](auto i) { d0_thread_buf(i) = d0_zeroVal; });
static_for<0, mreduce_per_thread, 1>{}(
......@@ -951,7 +951,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
BlockSize,
CReduceThreadClusterLengths_MPerBlock_NPerBlock, // ThreadClusterLengths_M_K
Sequence<1, 0>, // ThreadClusterArrangeOrder
reduce::Add<FloatReduceAcc>,
reduce::Add,
false>;
static_for<0, mreduce_per_thread, 1>{}([&](auto i) {
......@@ -984,8 +984,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
FloatReduceAcc numerator = c_reduce_thread_buf(dst_offset) - avg_sum;
FloatReduceAcc divisor = epsilon + avg_squared_sum - avg_sum * avg_sum;
FloatReduceAcc divisor_sqrt;
tensor_operation::element_wise::UnarySqrt<FloatReduceAcc,
FloatReduceAcc>{}(
tensor_operation::element_wise::UnarySqrt{}(
divisor_sqrt, divisor);
c_reduce_thread_buf(dst_offset) = numerator / divisor_sqrt;
......
......@@ -81,20 +81,19 @@ struct Add
}
};
template <class T>
struct SquaredAdd
{
using dataType = T;
template <class T>
__host__ __device__ static constexpr T GetIdentityValue() { return type_convert<T>(0.0f); };
__device__ static constexpr bool
__host__ __device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
{
return operation == InMemoryDataOperationEnum::AtomicAdd ||
operation == InMemoryDataOperationEnum::Set;
};
template <class T>
__host__ __device__ inline constexpr void operator()(T& a, T b) const
{
static_assert(is_same<T, float>::value || is_same<T, double>::value ||
......@@ -106,7 +105,6 @@ struct SquaredAdd
}
};
template <class T>
struct Mul
{
template <typename T>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment