Commit c466ccd7 authored by rocking's avatar rocking
Browse files

Propagate NaN for layernorm

parent 798670d8
......@@ -97,22 +97,18 @@ struct GridwiseLayernorm_mk_to_mk
using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using BlockwiseSumReduce =
PartitionedBlockwiseReduction<AccDataType,
using BlockwiseSumReduce = PartitionedBlockwiseReduction<AccDataType,
BlockSize,
ThreadClusterLengths_M_K,
ThreadClusterArrangeOrder,
reduce::Add,
false, // ignored
detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;
true>;
using ThreadwiseSumReduce =
ThreadwiseReduction<AccDataType,
using ThreadwiseSumReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
ThreadReduceDstDesc_M,
reduce::Add,
false, // ignored
detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;
true>;
static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment