"docs/API_Reference_Guide.rst" did not exist on "cb3fac4d2a4e8bcac0e15118d7afd4af93301132"
Commit cc50b687 authored by Anthony Chang's avatar Anthony Chang
Browse files

format

parent ab8e0f28
......@@ -932,8 +932,10 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
reduce::SquaredAdd,
false>;
const auto d0_zeroVal = ThreadwiseReduceD0::Op::template GetIdentityValue<FloatReduceAcc>();
const auto d1_zeroVal = ThreadwiseReduceD1::Op::template GetIdentityValue<FloatReduceAcc>();
const auto d0_zeroVal =
ThreadwiseReduceD0::Op::template GetIdentityValue<FloatReduceAcc>();
const auto d1_zeroVal =
ThreadwiseReduceD1::Op::template GetIdentityValue<FloatReduceAcc>();
static_for<0, mreduce_per_thread, 1>{}(
[&](auto i) { d0_thread_buf(i) = d0_zeroVal; });
static_for<0, mreduce_per_thread, 1>{}(
......@@ -984,8 +986,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
FloatReduceAcc numerator = c_reduce_thread_buf(dst_offset) - avg_sum;
FloatReduceAcc divisor = epsilon + avg_squared_sum - avg_sum * avg_sum;
FloatReduceAcc divisor_sqrt;
tensor_operation::element_wise::UnarySqrt{}(
divisor_sqrt, divisor);
tensor_operation::element_wise::UnarySqrt{}(divisor_sqrt, divisor);
c_reduce_thread_buf(dst_offset) = numerator / divisor_sqrt;
});
......
......@@ -84,7 +84,10 @@ struct Add
struct SquaredAdd
{
template <class T>
__host__ __device__ static constexpr T GetIdentityValue() { return type_convert<T>(0.0f); };
__host__ __device__ static constexpr T GetIdentityValue()
{
return type_convert<T>(0.0f);
};
__host__ __device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment