Commit ab8e0f28 authored by Anthony Chang
Browse files

keep up with recent changes in reduction API

parent 2c1ed8b2
...@@ -923,17 +923,17 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -923,17 +923,17 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
ThreadwiseReduction<FloatReduceAcc, ThreadwiseReduction<FloatReduceAcc,
decltype(c_reduce_thread_desc_mperblock_nperblock), decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(d_reduce_thread_desc_mperblock), decltype(d_reduce_thread_desc_mperblock),
reduce::Add<FloatReduceAcc>, reduce::Add,
false>; false>;
using ThreadwiseReduceD1 = using ThreadwiseReduceD1 =
ThreadwiseReduction<FloatReduceAcc, ThreadwiseReduction<FloatReduceAcc,
decltype(c_reduce_thread_desc_mperblock_nperblock), decltype(c_reduce_thread_desc_mperblock_nperblock),
decltype(d_reduce_thread_desc_mperblock), decltype(d_reduce_thread_desc_mperblock),
reduce::SquaredAdd<FloatReduceAcc>, reduce::SquaredAdd,
false>; false>;
const auto d0_zeroVal = ThreadwiseReduceD0::Op::GetIdentityValue(); const auto d0_zeroVal = ThreadwiseReduceD0::Op::template GetIdentityValue<FloatReduceAcc>();
const auto d1_zeroVal = ThreadwiseReduceD1::Op::GetIdentityValue(); const auto d1_zeroVal = ThreadwiseReduceD1::Op::template GetIdentityValue<FloatReduceAcc>();
static_for<0, mreduce_per_thread, 1>{}( static_for<0, mreduce_per_thread, 1>{}(
[&](auto i) { d0_thread_buf(i) = d0_zeroVal; }); [&](auto i) { d0_thread_buf(i) = d0_zeroVal; });
static_for<0, mreduce_per_thread, 1>{}( static_for<0, mreduce_per_thread, 1>{}(
...@@ -951,7 +951,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -951,7 +951,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
BlockSize, BlockSize,
CReduceThreadClusterLengths_MPerBlock_NPerBlock, // ThreadClusterLengths_M_K CReduceThreadClusterLengths_MPerBlock_NPerBlock, // ThreadClusterLengths_M_K
Sequence<1, 0>, // ThreadClusterArrangeOrder Sequence<1, 0>, // ThreadClusterArrangeOrder
reduce::Add<FloatReduceAcc>, reduce::Add,
false>; false>;
static_for<0, mreduce_per_thread, 1>{}([&](auto i) { static_for<0, mreduce_per_thread, 1>{}([&](auto i) {
...@@ -984,8 +984,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -984,8 +984,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
FloatReduceAcc numerator = c_reduce_thread_buf(dst_offset) - avg_sum; FloatReduceAcc numerator = c_reduce_thread_buf(dst_offset) - avg_sum;
FloatReduceAcc divisor = epsilon + avg_squared_sum - avg_sum * avg_sum; FloatReduceAcc divisor = epsilon + avg_squared_sum - avg_sum * avg_sum;
FloatReduceAcc divisor_sqrt; FloatReduceAcc divisor_sqrt;
tensor_operation::element_wise::UnarySqrt<FloatReduceAcc, tensor_operation::element_wise::UnarySqrt{}(
FloatReduceAcc>{}(
divisor_sqrt, divisor); divisor_sqrt, divisor);
c_reduce_thread_buf(dst_offset) = numerator / divisor_sqrt; c_reduce_thread_buf(dst_offset) = numerator / divisor_sqrt;
......
...@@ -81,20 +81,19 @@ struct Add ...@@ -81,20 +81,19 @@ struct Add
} }
}; };
template <class T>
struct SquaredAdd struct SquaredAdd
{ {
using dataType = T; template <class T>
__host__ __device__ static constexpr T GetIdentityValue() { return type_convert<T>(0.0f); }; __host__ __device__ static constexpr T GetIdentityValue() { return type_convert<T>(0.0f); };
__device__ static constexpr bool __host__ __device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
{ {
return operation == InMemoryDataOperationEnum::AtomicAdd || return operation == InMemoryDataOperationEnum::AtomicAdd ||
operation == InMemoryDataOperationEnum::Set; operation == InMemoryDataOperationEnum::Set;
}; };
template <class T>
__host__ __device__ inline constexpr void operator()(T& a, T b) const __host__ __device__ inline constexpr void operator()(T& a, T b) const
{ {
static_assert(is_same<T, float>::value || is_same<T, double>::value || static_assert(is_same<T, float>::value || is_same<T, double>::value ||
...@@ -106,7 +105,6 @@ struct SquaredAdd ...@@ -106,7 +105,6 @@ struct SquaredAdd
} }
}; };
template <class T>
struct Mul struct Mul
{ {
template <typename T> template <typename T>
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment