/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cuda/functor.cuh
 * \brief Functors for templates on CUDA
 */
#ifndef DGL_ARRAY_CUDA_FUNCTOR_CUH_
#define DGL_ARRAY_CUDA_FUNCTOR_CUH_

#include <cstdint>
#include <limits>

#include "./atomic.cuh"  // cuda::AtomicAdd, cuda::AtomicMax, cuda::AtomicMin

namespace dgl {
namespace aten {
namespace cuda {

/////////////////////////////// CUDA binary operators ///////////////////////////////

namespace binary {

// Each binary functor exposes:
//  - use_lhs / use_rhs: whether the left/right operand is read, so kernels
//    can skip loading unused inputs;
//  - reduce_last_dim: whether Call() reduces over the feature (last)
//    dimension of length `len` (only Dot does);
//  - Call(): the element-wise computation on device.

template <typename DType>
struct Add {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  static constexpr bool reduce_last_dim = false;
  static __device__ __forceinline__ DType Call(
      const DType *lhs, const DType *rhs, int64_t len = 1) {
    return lhs[0] + rhs[0];
  }
};
template <typename DType> constexpr bool Add<DType>::use_lhs;
template <typename DType> constexpr bool Add<DType>::use_rhs;
template <typename DType> constexpr bool Add<DType>::reduce_last_dim;

template <typename DType>
struct Sub {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  static constexpr bool reduce_last_dim = false;
  static __device__ __forceinline__ DType Call(
      const DType *lhs, const DType *rhs, int64_t len = 1) {
    return lhs[0] - rhs[0];
  }
};
template <typename DType> constexpr bool Sub<DType>::use_lhs;
template <typename DType> constexpr bool Sub<DType>::use_rhs;
template <typename DType> constexpr bool Sub<DType>::reduce_last_dim;

template <typename DType>
struct Mul {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  static constexpr bool reduce_last_dim = false;
  static __device__ __forceinline__ DType Call(
      const DType *lhs, const DType *rhs, int64_t len = 1) {
    return lhs[0] * rhs[0];
  }
};
template <typename DType> constexpr bool Mul<DType>::use_lhs;
template <typename DType> constexpr bool Mul<DType>::use_rhs;
template <typename DType> constexpr bool Mul<DType>::reduce_last_dim;

template <typename DType>
struct Div {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  static constexpr bool reduce_last_dim = false;
  static __device__ __forceinline__ DType Call(
      const DType *lhs, const DType *rhs, int64_t len = 1) {
    return lhs[0] / rhs[0];
  }
};
template <typename DType> constexpr bool Div<DType>::use_lhs;
template <typename DType> constexpr bool Div<DType>::use_rhs;
template <typename DType> constexpr bool Div<DType>::reduce_last_dim;

// CopyU ignores the rhs operand and forwards the lhs feature unchanged.
template <typename DType>
struct CopyU {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = false;
  static constexpr bool reduce_last_dim = false;
  static __device__ __forceinline__ DType Call(
      const DType *lhs, const DType *rhs, int64_t len = 1) {
    return lhs[0];
  }
};
template <typename DType> constexpr bool CopyU<DType>::use_lhs;
template <typename DType> constexpr bool CopyU<DType>::use_rhs;
template <typename DType> constexpr bool CopyU<DType>::reduce_last_dim;

// CopyE ignores the lhs operand and forwards the rhs feature unchanged.
template <typename DType>
struct CopyE {
  static constexpr bool use_lhs = false;
  static constexpr bool use_rhs = true;
  static constexpr bool reduce_last_dim = false;
  static __device__ __forceinline__ DType Call(
      const DType *lhs, const DType *rhs, int64_t len = 1) {
    return rhs[0];
  }
};
template <typename DType> constexpr bool CopyE<DType>::use_lhs;
template <typename DType> constexpr bool CopyE<DType>::use_rhs;
template <typename DType> constexpr bool CopyE<DType>::reduce_last_dim;

// Dot computes the inner product over the last (feature) dimension of
// length `len`, hence reduce_last_dim = true.
template <typename DType>
struct Dot {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  static constexpr bool reduce_last_dim = true;
  static __device__ __forceinline__ DType Call(
      const DType *lhs, const DType *rhs, int64_t len = 1) {
    DType rst = static_cast<DType>(0);
    for (int64_t i = 0; i < len; ++i) {
      rst += lhs[i] * rhs[i];
    }
    return rst;
  }
};
template <typename DType> constexpr bool Dot<DType>::use_lhs;
template <typename DType> constexpr bool Dot<DType>::use_rhs;
template <typename DType> constexpr bool Dot<DType>::reduce_last_dim;

}  // end of namespace binary

/////////////////////////////// CUDA reduce operators ///////////////////////////////

namespace reduce {

// Each reduce functor exposes:
//  - zero: the identity element of the reduction;
//  - require_arg: whether the contributing node/edge indices (uid/eid) must
//    be tracked (argmax/argmin);
//  - Call(): accumulates `val` into `out_buf`, atomically when `atomic` is
//    true, recording uid/eid in the non-atomic path if applicable;
//  - CallArg(): a second pass that recovers the arg indices after an atomic
//    reduction by matching `val` against the reduced value `val_ref`.

template <typename Idx, typename DType, bool atomic>
struct Sum {
  static constexpr DType zero = 0;
  static constexpr bool require_arg = false;
  static __device__ __forceinline__ void Call(
      DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
      DType val, Idx uid, Idx eid) {
    if (!atomic) {
      *out_buf += val;
    } else {
      cuda::AtomicAdd(out_buf, val);
    }
  }
  static __device__ __forceinline__ void CallArg(
      Idx fid, Idx *arg_u_buf, Idx *arg_e_buf,
      DType val, DType val_ref, Idx uid, Idx eid) {}
};
template <typename Idx, typename DType, bool atomic>
constexpr DType Sum<Idx, DType, atomic>::zero;
template <typename Idx, typename DType, bool atomic>
constexpr bool Sum<Idx, DType, atomic>::require_arg;

template <typename Idx, typename DType, bool atomic>
struct Max {
  static constexpr DType zero = std::numeric_limits<DType>::lowest();
  static constexpr bool require_arg = true;
  static __device__ __forceinline__ void Call(
      DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
      DType val, Idx uid, Idx eid) {
    if (!atomic) {
      if (*out_buf < val) {
        *out_buf = val;
        *arg_u_buf = uid;
        *arg_e_buf = eid;
      }
    } else {
      cuda::AtomicMax(out_buf, val);
    }
  }
  static __device__ __forceinline__ void CallArg(
      Idx fid, Idx *arg_u_buf, Idx *arg_e_buf,
      DType val, DType val_ref, Idx uid, Idx eid) {
    if (atomic) {
      if (val == val_ref) {
        if (arg_u_buf) arg_u_buf[fid] = uid;
        if (arg_e_buf) arg_e_buf[fid] = eid;
      }
    }
  }
};
template <typename Idx, typename DType, bool atomic>
constexpr DType Max<Idx, DType, atomic>::zero;
template <typename Idx, typename DType, bool atomic>
constexpr bool Max<Idx, DType, atomic>::require_arg;

template <typename Idx, typename DType, bool atomic>
struct Min {
  static constexpr DType zero = std::numeric_limits<DType>::max();
  static constexpr bool require_arg = true;
  static __device__ __forceinline__ void Call(
      DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
      DType val, Idx uid, Idx eid) {
    if (!atomic) {
      if (*out_buf > val) {
        *out_buf = val;
        *arg_u_buf = uid;
        *arg_e_buf = eid;
      }
    } else {
      cuda::AtomicMin(out_buf, val);
    }
  }
  static __device__ __forceinline__ void CallArg(
      Idx fid, Idx *arg_u_buf, Idx *arg_e_buf,
      DType val, DType val_ref, Idx uid, Idx eid) {
    if (atomic) {
      if (val == val_ref) {
        if (arg_u_buf) arg_u_buf[fid] = uid;
        if (arg_e_buf) arg_e_buf[fid] = eid;
      }
    }
  }
};
template <typename Idx, typename DType, bool atomic>
constexpr DType Min<Idx, DType, atomic>::zero;
template <typename Idx, typename DType, bool atomic>
constexpr bool Min<Idx, DType, atomic>::require_arg;

}  // namespace reduce

}  // namespace cuda
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CUDA_FUNCTOR_CUH_
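// ---------------------------------------------------------------------------
// Illustrative usage (sketch only, kept as a comment so it is not mistaken
// for an actual DGL kernel): the functors above are intended to be plugged
// into kernel templates as `BinaryOp` and `ReduceOp` parameters. The kernel
// name, launch shape, and CSR argument names below are hypothetical; only
// the functor API (use_lhs/use_rhs, zero, Call) comes from this header.
//
// template <typename Idx, typename DType, typename BinaryOp, typename ReduceOp>
// __global__ void SpMMNaiveKernel(
//     const Idx *indptr, const Idx *indices, const Idx *edge_ids,
//     const DType *ufeat, const DType *efeat, DType *out, int64_t dim) {
//   const Idx dst = blockIdx.x;                    // one destination row per block
//   for (int64_t k = threadIdx.x; k < dim; k += blockDim.x) {
//     DType acc = ReduceOp::zero;                  // identity of the reduction
//     Idx arg_u = 0, arg_e = 0;                    // written by Max/Min (non-atomic path)
//     for (Idx e = indptr[dst]; e < indptr[dst + 1]; ++e) {
//       const Idx src = indices[e], eid = edge_ids[e];
//       const DType *lhs = BinaryOp::use_lhs ? ufeat + src * dim + k : nullptr;
//       const DType *rhs = BinaryOp::use_rhs ? efeat + eid * dim + k : nullptr;
//       ReduceOp::Call(&acc, &arg_u, &arg_e, BinaryOp::Call(lhs, rhs), src, eid);
//     }
//     out[dst * dim + k] = acc;
//   }
// }
// ---------------------------------------------------------------------------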