/*!
 *  Copyright (c) 2019 by Contributors
 * \file kernel/cuda/functor.cuh
 * \brief Functors for CUDA kernel templates
 */
#ifndef DGL_KERNEL_CUDA_FUNCTOR_CUH_
#define DGL_KERNEL_CUDA_FUNCTOR_CUH_

#include "../binary_reduce_common.h"
#include "./atomic.cuh"

namespace dgl {
namespace kernel {
namespace cuda {

// Cached load from global memory (uses the read-only data cache on sm_35+)
template <typename DType>
struct LDGReader {
  static __device__ __forceinline__ DType Call(DType* addr) {
#if __CUDA_ARCH__ >= 350
    return __ldg(addr);
#else
    return *addr;
#endif
  }
};

}  // namespace cuda

// Reducer functor specializations for the GPU backend
template <typename DType>
struct ReduceSum<kDLGPU, DType> {
  static __device__ __forceinline__ void Call(DType* addr, DType val) {
    cuda::AtomicAdd(addr, val);
  }
  static __device__ __forceinline__ DType BackwardCall(DType val, DType accum) {
    return 1;
  }
};

template <typename DType>
struct ReduceMax<kDLGPU, DType> {
  static __device__ __forceinline__ void Call(DType* addr, DType val) {
    cuda::AtomicMax(addr, val);
  }
  static __device__ __forceinline__ DType BackwardCall(DType val, DType accum) {
    return static_cast<DType>(val == accum);
  }
};

template <typename DType>
struct ReduceMin<kDLGPU, DType> {
  static __device__ __forceinline__ void Call(DType* addr, DType val) {
    cuda::AtomicMin(addr, val);
  }
  static __device__ __forceinline__ DType BackwardCall(DType val, DType accum) {
    return static_cast<DType>(val == accum);
  }
};

template <typename DType>
struct ReduceProd<kDLGPU, DType> {
  static __device__ __forceinline__ void Call(DType* addr, DType val) {
    cuda::AtomicMul(addr, val);
  }
  static __device__ __forceinline__ DType BackwardCall(DType val, DType accum) {
    return accum / val;
  }
};

template <typename DType>
struct ReduceNone<kDLGPU, DType> {
  static __device__ __forceinline__ void Call(DType* addr, DType val) {
    *addr = val;
  }
  static __device__ __forceinline__ DType BackwardCall(DType val, DType accum) {
    return 1;
  }
};

}  // namespace kernel
}  // namespace dgl

#endif  // DGL_KERNEL_CUDA_FUNCTOR_CUH_
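/*
 * Illustrative usage sketch (not part of the original header; the kernel and
 * the argument names below are hypothetical). Each reducer's Call() commits a
 * value to global memory with the matching atomic from atomic.cuh, so a kernel
 * can accumulate per-edge values into per-destination-node outputs without
 * extra synchronization:
 *
 *   template <typename DType, typename Reducer>
 *   __global__ void ScatterReduce(DType* edge_val, DType* out,
 *                                 const int64_t* dst, int64_t num_edges) {
 *     int64_t eid = blockIdx.x * blockDim.x + threadIdx.x;
 *     if (eid < num_edges) {
 *       DType v = cuda::LDGReader<DType>::Call(edge_val + eid);
 *       // Reducer is e.g. ReduceSum<kDLGPU, DType> or ReduceMax<kDLGPU, DType>.
 *       Reducer::Call(out + dst[eid], v);
 *     }
 *   }
 *
 * BackwardCall(val, accum) returns the partial derivative of the reduced
 * result `accum` with respect to the contributed `val`: 1 for sum and none,
 * an indicator (val == accum) for max and min, and accum / val for prod.
 */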