/*!
 *  Copyright (c) 2020 by Contributors
 * \file array/cuda/functor.cuh
 * \brief Functors for template on CUDA
 */
#ifndef DGL_ARRAY_CUDA_FUNCTOR_CUH_
#define DGL_ARRAY_CUDA_FUNCTOR_CUH_

#include <cstdint>
#include <limits>

#include "./atomic.cuh"
#include "./fp16.cuh"
#include "bf16.cuh"

namespace dgl {
namespace aten {
namespace cuda {

/////////////////////////////// CUDA binary operators ///////////////////////////////
namespace binary {

template <typename DType>
struct Add {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  static constexpr bool reduce_last_dim = false;
  static __device__ __forceinline__ DType Call(
      const DType* lhs, const DType* rhs, int64_t len = 1) {
    return lhs[0] + rhs[0];
  }
};
template <typename DType>
constexpr bool Add<DType>::use_lhs;
template <typename DType>
constexpr bool Add<DType>::use_rhs;
template <typename DType>
constexpr bool Add<DType>::reduce_last_dim;

template <typename DType>
struct Sub {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  static constexpr bool reduce_last_dim = false;
  static __device__ __forceinline__ DType Call(
      const DType* lhs, const DType* rhs, int64_t len = 1) {
    return lhs[0] - rhs[0];
  }
};
template <typename DType>
constexpr bool Sub<DType>::use_lhs;
template <typename DType>
constexpr bool Sub<DType>::use_rhs;
template <typename DType>
constexpr bool Sub<DType>::reduce_last_dim;

template <typename DType>
struct Mul {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  static constexpr bool reduce_last_dim = false;
  static __device__ __forceinline__ DType Call(
      const DType* lhs, const DType* rhs, int64_t len = 1) {
    return lhs[0] * rhs[0];
  }
};
template <typename DType>
constexpr bool Mul<DType>::use_lhs;
template <typename DType>
constexpr bool Mul<DType>::use_rhs;
template <typename DType>
constexpr bool Mul<DType>::reduce_last_dim;

template <typename DType>
struct Div {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  static constexpr bool reduce_last_dim = false;
  static __device__ __forceinline__ DType Call(
      const DType* lhs, const DType* rhs, int64_t len = 1) {
    return lhs[0] / rhs[0];
  }
};
template <typename DType>
constexpr bool Div<DType>::use_lhs;
template <typename DType>
constexpr bool Div<DType>::use_rhs;
template <typename DType>
constexpr bool Div<DType>::reduce_last_dim;

template <typename DType>
struct CopyLhs {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = false;
  static constexpr bool reduce_last_dim = false;
  static __device__ __forceinline__ DType Call(
      const DType* lhs, const DType* rhs, int64_t len = 1) {
    return lhs[0];
  }
};
template <typename DType>
constexpr bool CopyLhs<DType>::use_lhs;
template <typename DType>
constexpr bool CopyLhs<DType>::use_rhs;
template <typename DType>
constexpr bool CopyLhs<DType>::reduce_last_dim;

template <typename DType>
struct CopyRhs {
  static constexpr bool use_lhs = false;
  static constexpr bool use_rhs = true;
  static constexpr bool reduce_last_dim = false;
  static __device__ __forceinline__ DType Call(
      const DType* lhs, const DType* rhs, int64_t len = 1) {
    return rhs[0];
  }
};
template <typename DType>
constexpr bool CopyRhs<DType>::use_lhs;
template <typename DType>
constexpr bool CopyRhs<DType>::use_rhs;
template <typename DType>
constexpr bool CopyRhs<DType>::reduce_last_dim;

template <typename DType>
struct Dot {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  static constexpr bool reduce_last_dim = true;
  static __device__ __forceinline__ DType Call(
      const DType* lhs, const DType* rhs, int64_t len = 1) {
    DType rst = static_cast<DType>(0.0f);
    for (int64_t i = 0; i < len; ++i) {
      rst += lhs[i] * rhs[i];
    }
    return rst;
  }
};
template <typename DType>
constexpr bool Dot<DType>::use_lhs;
template <typename DType>
constexpr bool Dot<DType>::use_rhs;
template <typename DType>
constexpr bool Dot<DType>::reduce_last_dim;

}  // end of namespace binary

/////////////////////////////// CUDA reduce operators ///////////////////////////////
namespace reduce {
template <typename Idx, typename DType, bool atomic>
struct _Sum {
  static constexpr __host__ __device__ __forceinline__ DType zero() {
    return 0.;
  }
  static constexpr bool require_arg = false;
  static __device__ __forceinline__ void Call(
      DType* out_buf, Idx* arg_u_buf, Idx* arg_e_buf, DType val, Idx uid,
      Idx eid) {
    if (!atomic) {
      *out_buf += val;
    } else {
      cuda::AtomicAdd(out_buf, val);
    }
  }
  static __device__ __forceinline__ void Call(
      DType* out_buf, Idx* arg_buf, DType val, Idx id) {
    if (!atomic) {
      *out_buf += val;
    } else {
      cuda::AtomicAdd(out_buf, val);
    }
  }
  static __device__ __forceinline__ void CallArg(
      Idx fid, Idx* arg_u_buf, Idx* arg_e_buf, DType val, DType val_ref,
      Idx uid, Idx eid) {}
};

template <typename Idx, typename DType, bool atomic = false>
struct Sum : _Sum<Idx, DType, atomic> {};

template <typename Idx, bool atomic>
struct Sum<Idx, half, atomic> : _Sum<Idx, half, atomic> {
  static constexpr __host__ __device__ __forceinline__ half zero() {
    return __float2half_rn(0.);
  }
};

#if BF16_ENABLED
template <typename Idx, bool atomic>
struct Sum<Idx, __nv_bfloat16, atomic> : _Sum<Idx, __nv_bfloat16, atomic> {
  static constexpr __host__ __device__ __forceinline__ __nv_bfloat16 zero() {
    return __float2bfloat16_rn(0.);
  }
};
#endif  // BF16_ENABLED

template <typename Idx, typename DType, bool atomic>
struct _Max {
  static constexpr __host__ __device__ __forceinline__ DType zero() {
    return -std::numeric_limits<DType>::infinity();
  }
  static constexpr bool require_arg = true;
  static __device__ __forceinline__ void Call(
      DType* out_buf, Idx* arg_u_buf, Idx* arg_e_buf, DType val, Idx uid,
      Idx eid) {
    if (!atomic) {
      if (*out_buf < val) {
        *out_buf = val;
        *arg_u_buf = uid;
        *arg_e_buf = eid;
      }
    } else {
      cuda::AtomicMax(out_buf, val);
    }
  }
  static __device__ __forceinline__ void Call(
      DType* out_buf, Idx* arg_buf, DType val, Idx id) {
    if (!atomic) {
      if (*out_buf < val) {
        *out_buf = val;
        *arg_buf = id;
      }
    } else {
      cuda::AtomicMax(out_buf, val);
    }
  }
  static __device__ __forceinline__ void CallArg(
      Idx fid, Idx* arg_u_buf, Idx* arg_e_buf, DType val, DType val_ref,
      Idx uid, Idx eid) {
    if (atomic) {
      if (val == val_ref) {
        if (arg_u_buf) arg_u_buf[fid] = uid;
        if (arg_e_buf) arg_e_buf[fid] = eid;
      }
    }
  }
};

template <typename Idx, typename DType, bool atomic = false>
struct Max : _Max<Idx, DType, atomic> {};

template <typename Idx, bool atomic>
struct Max<Idx, half, atomic> : _Max<Idx, half, atomic> {
  static constexpr __host__ __device__ __forceinline__ half zero() {
    // Lowest finite value representable in half (-65504).
    return __float2half_rn(-6.550400e+04f);
  }
};

#if BF16_ENABLED
template <typename Idx, bool atomic>
struct Max<Idx, __nv_bfloat16, atomic> : _Max<Idx, __nv_bfloat16, atomic> {
  static constexpr __host__ __device__ __forceinline__ __nv_bfloat16 zero() {
    return __float2bfloat16_rn(-std::numeric_limits<float>::infinity());
  }
};
#endif  // BF16_ENABLED

template <typename Idx, typename DType, bool atomic>
struct _Min {
  static constexpr __host__ __device__ __forceinline__ DType zero() {
    return std::numeric_limits<DType>::infinity();
  }
  static constexpr bool require_arg = true;
  static __device__ __forceinline__ void Call(
      DType* out_buf, Idx* arg_u_buf, Idx* arg_e_buf, DType val, Idx uid,
      Idx eid) {
    if (!atomic) {
      if (*out_buf > val) {
        *out_buf = val;
        *arg_u_buf = uid;
        *arg_e_buf = eid;
      }
    } else {
      cuda::AtomicMin(out_buf, val);
    }
  }
  static __device__ __forceinline__ void Call(
      DType* out_buf, Idx* arg_buf, DType val, Idx id) {
    if (!atomic) {
      if (*out_buf > val) {
        *out_buf = val;
        *arg_buf = id;
      }
    } else {
      cuda::AtomicMin(out_buf, val);
    }
  }
  static __device__ __forceinline__ void CallArg(
      Idx fid, Idx* arg_u_buf, Idx* arg_e_buf, DType val, DType val_ref,
      Idx uid, Idx eid) {
    if (atomic) {
      if (val == val_ref) {
        if (arg_u_buf) arg_u_buf[fid] = uid;
        if (arg_e_buf) arg_e_buf[fid] = eid;
      }
    }
  }
};

template <typename Idx, typename DType, bool atomic = false>
struct Min : _Min<Idx, DType, atomic> {};

template <typename Idx, bool atomic>
struct Min<Idx, half, atomic> : _Min<Idx, half, atomic> {
  static constexpr __host__ __device__ __forceinline__ half zero() {
    // Largest finite value representable in half (65504).
    return __float2half_rn(6.550400e+04f);
  }
};

#if BF16_ENABLED
template <typename Idx, bool atomic>
struct Min<Idx, __nv_bfloat16, atomic> : _Min<Idx, __nv_bfloat16, atomic> {
  static constexpr __host__ __device__ __forceinline__ __nv_bfloat16 zero() {
    return __float2bfloat16_rn(std::numeric_limits<float>::infinity());
  }
};
#endif  // BF16_ENABLED

}  // namespace reduce

}  // namespace cuda
}  // namespace aten
}  // namespace dgl
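// A minimal usage sketch (illustrative addition, not part of DGL's original
// file): a hypothetical kernel template showing how a binary functor and a
// reduce functor from this header compose. Per edge it computes
// val = BinaryOp(ufeat[src], efeat[eid]) and reduces it into out[dst].
// The kernel name, the `example` namespace, and the scalar-per-node/edge
// argument layout are assumptions; DGL's real SpMM/SDDMM kernels live
// elsewhere and additionally handle feature dimensions and broadcasting.
namespace dgl {
namespace aten {
namespace cuda {
namespace example {

template <
    typename Idx, typename DType, typename BinaryOp = binary::Mul<DType>,
    typename ReduceOp = reduce::Max<Idx, DType, /*atomic=*/true>>
__global__ void EdgeReduceKernel(
    const DType* ufeat,      // one scalar feature per source node
    const DType* efeat,      // one scalar feature per edge
    DType* out,              // per-node output, pre-filled with ReduceOp::zero()
    Idx* arg_u, Idx* arg_e,  // argmax/argmin buffers (ignored by Sum)
    const Idx* src, const Idx* dst, int64_t num_edges) {
  const int64_t eid = blockIdx.x * blockDim.x + threadIdx.x;
  if (eid < num_edges) {
    // Combine the source-node feature with the edge feature.
    const DType val = BinaryOp::Call(ufeat + src[eid], efeat + eid);
    // Reduce into the destination node; the atomic=true specialization
    // resolves write conflicts when multiple edges share a destination.
    ReduceOp::Call(
        out + dst[eid], arg_u + dst[eid], arg_e + dst[eid], val, src[eid],
        eid);
  }
}

}  // namespace example
}  // namespace cuda
}  // namespace aten
}  // namespace dgl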
#endif  // DGL_ARRAY_CUDA_FUNCTOR_CUH_