Commit 7f3c6e28 authored by Anthony Chang's avatar Anthony Chang
Browse files

add squared add reduction op; allows sq sum

parent 8c144c7a
...@@ -51,6 +51,7 @@ struct ThreadwiseReduction ...@@ -51,6 +51,7 @@ struct ThreadwiseReduction
static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!"); static_assert(src_length_m == dst_length_m, "lengths of source and dst buffer must match!");
using Op = OpReduce;
using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>; using Accumulation = detail::AccumulateWithNanCheck<PropagateNan, OpReduce, AccDataType>;
template <typename SrcBufferType, typename DstBufferType> template <typename SrcBufferType, typename DstBufferType>
......
...@@ -71,6 +71,16 @@ struct Add ...@@ -71,6 +71,16 @@ struct Add
__host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a + b; } __host__ __device__ inline constexpr void operator()(T& a, T b) const { a = a + b; }
}; };
/// Reduction functor that accumulates the square of each incoming value.
///
/// `operator()(a, b)` updates `a` in place to `a + b * b`, so folding a sequence of
/// values into an accumulator that starts at `GetReductionZeroVal()` yields the sum
/// of their squares.
///
/// NOTE(review): the update is asymmetric — `b` is squared, `a` is not — so using this
/// functor to merge two already-accumulated partial results would square one of them.
/// Confirm that any multi-stage reduction combines partials with a plain add.
template <typename T>
struct SquaredAdd
{
    using dataType = T;

    /// Starting value for the accumulator: zero converted to `T`.
    __host__ __device__ static constexpr T GetReductionZeroVal()
    {
        return static_cast<T>(0.0f);
    }

    /// Adds `b * b` to the accumulator `a` (modified through the reference).
    __host__ __device__ inline constexpr void operator()(T& a, T b) const
    {
        a = a + b * b;
    }
};
template <class T> template <class T>
struct Mul struct Mul
{ {
......
Markdown is supported
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment