Commit c466ccd7 authored by rocking's avatar rocking
Browse files

Propagate NaN for layernorm

parent 798670d8
...@@ -97,22 +97,18 @@ struct GridwiseLayernorm_mk_to_mk ...@@ -97,22 +97,18 @@ struct GridwiseLayernorm_mk_to_mk
using ThreadReduceDstDesc_M = using ThreadReduceDstDesc_M =
decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{}))); decltype(make_naive_tensor_descriptor_packed(make_tuple(Number<MThreadSliceSize>{})));
using BlockwiseSumReduce = using BlockwiseSumReduce = PartitionedBlockwiseReduction<AccDataType,
PartitionedBlockwiseReduction<AccDataType, BlockSize,
BlockSize, ThreadClusterLengths_M_K,
ThreadClusterLengths_M_K, ThreadClusterArrangeOrder,
ThreadClusterArrangeOrder, reduce::Add,
reduce::Add, true>;
false, // ignored
detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>; using ThreadwiseSumReduce = ThreadwiseReduction<AccDataType,
ThreadReduceSrcDesc_M_K,
using ThreadwiseSumReduce = ThreadReduceDstDesc_M,
ThreadwiseReduction<AccDataType, reduce::Add,
ThreadReduceSrcDesc_M_K, true>;
ThreadReduceDstDesc_M,
reduce::Add,
false, // ignored
detail::AccumulateWithNanIgnore<reduce::Add, AccDataType>>;
static constexpr auto I0 = Number<0>{}; static constexpr auto I0 = Number<0>{};
static constexpr auto I1 = Number<1>{}; static constexpr auto I1 = Number<1>{};
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment