/*! * Copyright (c) 2019 by Contributors * \file kernel/cpu/functor.h * \brief Functors for template on CPU */ #ifndef DGL_KERNEL_CPU_FUNCTOR_H_ #define DGL_KERNEL_CPU_FUNCTOR_H_ #include #include #include "../binary_reduce_common.h" namespace dgl { namespace kernel { // Reducer functor specialization template struct ReduceSum { static void Call(DType* addr, DType val) { #pragma omp atomic *addr += val; } static DType BackwardCall(DType val, DType accum) { return 1; } }; template struct ReduceMax { static void Call(DType* addr, DType val) { #pragma omp critical *addr = std::max(*addr, val); } static DType BackwardCall(DType val, DType accum) { return static_cast(val == accum); } }; template struct ReduceMin { static void Call(DType* addr, DType val) { #pragma omp critical *addr = std::min(*addr, val); } static DType BackwardCall(DType val, DType accum) { return static_cast(val == accum); } }; template struct ReduceProd { static void Call(DType* addr, DType val) { #pragma omp atomic *addr *= val; } static DType BackwardCall(DType val, DType accum) { return accum / val; } }; template struct ReduceNone { static void Call(DType* addr, DType val) { *addr = val; } static DType BackwardCall(DType val, DType accum) { return 1; } }; } // namespace kernel } // namespace dgl #endif // DGL_KERNEL_CPU_FUNCTOR_H_