"examples/dreambooth/train_dreambooth_lora.py" did not exist on "e895952816dc4a9a04d1d82fb5f8bdee59341c06"
Commit cc50b687 authored by Anthony Chang's avatar Anthony Chang
Browse files

format

parent ab8e0f28
...@@ -932,8 +932,10 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -932,8 +932,10 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
reduce::SquaredAdd, reduce::SquaredAdd,
false>; false>;
const auto d0_zeroVal = ThreadwiseReduceD0::Op::template GetIdentityValue<FloatReduceAcc>(); const auto d0_zeroVal =
const auto d1_zeroVal = ThreadwiseReduceD1::Op::template GetIdentityValue<FloatReduceAcc>(); ThreadwiseReduceD0::Op::template GetIdentityValue<FloatReduceAcc>();
const auto d1_zeroVal =
ThreadwiseReduceD1::Op::template GetIdentityValue<FloatReduceAcc>();
static_for<0, mreduce_per_thread, 1>{}( static_for<0, mreduce_per_thread, 1>{}(
[&](auto i) { d0_thread_buf(i) = d0_zeroVal; }); [&](auto i) { d0_thread_buf(i) = d0_zeroVal; });
static_for<0, mreduce_per_thread, 1>{}( static_for<0, mreduce_per_thread, 1>{}(
...@@ -984,8 +986,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1 ...@@ -984,8 +986,7 @@ struct GridwiseGemmLayernorm_k0mk1_k0nk1_mn_xdl_cshuffle_v1
FloatReduceAcc numerator = c_reduce_thread_buf(dst_offset) - avg_sum; FloatReduceAcc numerator = c_reduce_thread_buf(dst_offset) - avg_sum;
FloatReduceAcc divisor = epsilon + avg_squared_sum - avg_sum * avg_sum; FloatReduceAcc divisor = epsilon + avg_squared_sum - avg_sum * avg_sum;
FloatReduceAcc divisor_sqrt; FloatReduceAcc divisor_sqrt;
tensor_operation::element_wise::UnarySqrt{}( tensor_operation::element_wise::UnarySqrt{}(divisor_sqrt, divisor);
divisor_sqrt, divisor);
c_reduce_thread_buf(dst_offset) = numerator / divisor_sqrt; c_reduce_thread_buf(dst_offset) = numerator / divisor_sqrt;
}); });
......
...@@ -84,7 +84,10 @@ struct Add ...@@ -84,7 +84,10 @@ struct Add
struct SquaredAdd struct SquaredAdd
{ {
template <class T> template <class T>
__host__ __device__ static constexpr T GetIdentityValue() { return type_convert<T>(0.0f); }; __host__ __device__ static constexpr T GetIdentityValue()
{
return type_convert<T>(0.0f);
};
__host__ __device__ static constexpr bool __host__ __device__ static constexpr bool
IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation) IsCompatibleInMemoryDataOperation(InMemoryDataOperationEnum operation)
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment