/**
 * Copyright (c) 2020 by Contributors
 * @file array/cuda/functor.cuh
 * @brief Functors for templates on CUDA
 */
#ifndef DGL_ARRAY_CUDA_FUNCTOR_CUH_
#define DGL_ARRAY_CUDA_FUNCTOR_CUH_

#include <cstdint>
#include <limits>

#include "./atomic.cuh"
#include "./fp16.cuh"
#include "bf16.cuh"

namespace dgl {
namespace aten {
namespace cuda {

/////////////////////////// CUDA binary operators //////////////////////////////
namespace binary {

template <typename DType>
struct Add {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  static constexpr bool reduce_last_dim = false;
  static __device__ __forceinline__ DType Call(
      const DType *lhs, const DType *rhs, int64_t len = 1) {
    return lhs[0] + rhs[0];
  }
};
template <typename DType>
constexpr bool Add<DType>::use_lhs;
template <typename DType>
constexpr bool Add<DType>::use_rhs;
template <typename DType>
constexpr bool Add<DType>::reduce_last_dim;

template <typename DType>
struct Sub {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  static constexpr bool reduce_last_dim = false;
  static __device__ __forceinline__ DType Call(
      const DType *lhs, const DType *rhs, int64_t len = 1) {
    return lhs[0] - rhs[0];
  }
};
template <typename DType>
constexpr bool Sub<DType>::use_lhs;
template <typename DType>
constexpr bool Sub<DType>::use_rhs;
template <typename DType>
constexpr bool Sub<DType>::reduce_last_dim;

template <typename DType>
struct Mul {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  static constexpr bool reduce_last_dim = false;
  static __device__ __forceinline__ DType Call(
      const DType *lhs, const DType *rhs, int64_t len = 1) {
    return lhs[0] * rhs[0];
  }
};
template <typename DType>
constexpr bool Mul<DType>::use_lhs;
template <typename DType>
constexpr bool Mul<DType>::use_rhs;
template <typename DType>
constexpr bool Mul<DType>::reduce_last_dim;

template <typename DType>
struct Div {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  static constexpr bool reduce_last_dim = false;
  static __device__ __forceinline__ DType Call(
      const DType *lhs, const DType *rhs, int64_t len = 1) {
    return lhs[0] / rhs[0];
  }
};
template <typename DType>
constexpr bool Div<DType>::use_lhs;
template <typename DType>
constexpr bool Div<DType>::use_rhs;
template <typename DType>
constexpr bool Div<DType>::reduce_last_dim;

template <typename DType>
struct CopyLhs {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = false;
  static constexpr bool reduce_last_dim = false;
  static __device__ __forceinline__ DType Call(
      const DType *lhs, const DType *rhs, int64_t len = 1) {
    return lhs[0];
  }
};
template <typename DType>
constexpr bool CopyLhs<DType>::use_lhs;
template <typename DType>
constexpr bool CopyLhs<DType>::use_rhs;
template <typename DType>
constexpr bool CopyLhs<DType>::reduce_last_dim;

template <typename DType>
struct CopyRhs {
  static constexpr bool use_lhs = false;
  static constexpr bool use_rhs = true;
  static constexpr bool reduce_last_dim = false;
  static __device__ __forceinline__ DType Call(
      const DType *lhs, const DType *rhs, int64_t len = 1) {
    return rhs[0];
  }
};
template <typename DType>
constexpr bool CopyRhs<DType>::use_lhs;
template <typename DType>
constexpr bool CopyRhs<DType>::use_rhs;
template <typename DType>
constexpr bool CopyRhs<DType>::reduce_last_dim;

template <typename DType>
struct Dot {
  static constexpr bool use_lhs = true;
  static constexpr bool use_rhs = true;
  static constexpr bool reduce_last_dim = true;
  static __device__ __forceinline__ DType Call(
      const DType *lhs, const DType *rhs, int64_t len = 1) {
    DType rst = static_cast<DType>(0.0f);
    for (int64_t i = 0; i < len; ++i) {
      rst += lhs[i] * rhs[i];
    }
    return rst;
  }
};
template <typename DType>
constexpr bool Dot<DType>::use_lhs;
template <typename DType>
constexpr bool Dot<DType>::use_rhs;
template <typename DType>
constexpr bool Dot<DType>::reduce_last_dim;

}  // end of namespace binary
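
/**
 * Usage sketch for the binary functors above (illustrative only; the
 * variable names `lhs_off`, `rhs_off`, and `len` are hypothetical and not
 * part of this header).
 *
 * Each functor exposes `use_lhs` / `use_rhs` so callers can skip loading an
 * operand it does not read, and `reduce_last_dim` marks Dot as the only
 * operator that consumes `len` elements per call instead of a single scalar.
 *
 * @code
 *   // Element-wise message on one edge: msg = lhs + rhs.
 *   DType msg = binary::Add<DType>::Call(lhs_off, rhs_off);
 *   // Dot product over the last feature dimension of length `len`.
 *   DType dot = binary::Dot<DType>::Call(lhs_off, rhs_off, len);
 * @endcode
 */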

/////////////////////////// CUDA reduce operators //////////////////////////////
namespace reduce {

template <typename Idx, typename DType, bool atomic = false>
struct _Sum {
  static constexpr __host__ __device__ __forceinline__ DType zero() {
    return 0.;
  }
  static constexpr bool require_arg = false;
  static __device__ __forceinline__ void Call(
      DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf, DType val, Idx uid,
      Idx eid) {
    if (!atomic) {
      *out_buf += val;
    } else {
      cuda::AtomicAdd(out_buf, val);
    }
  }
  static __device__ __forceinline__ void Call(
      DType *out_buf, Idx *arg_buf, DType val, Idx id) {
    if (!atomic) {
      *out_buf += val;
    } else {
      cuda::AtomicAdd(out_buf, val);
    }
  }
  static __device__ __forceinline__ void CallArg(
      Idx fid, Idx *arg_u_buf, Idx *arg_e_buf, DType val, DType val_ref,
      Idx uid, Idx eid) {}
};

template <typename Idx, typename DType, bool atomic = false>
struct Sum : _Sum<Idx, DType, atomic> {};

template <typename Idx, bool atomic>
struct Sum<Idx, __half, atomic> : _Sum<Idx, __half, atomic> {
  static constexpr __host__ __device__ __forceinline__ __half zero() {
    return __float2half_rn(0.);
  }
  static __device__ __forceinline__ void Call(
      __half *out_buf, Idx *arg_u_buf, Idx *arg_e_buf, __half val, Idx uid,
      Idx eid) {
    _Sum<Idx, __half, atomic>::Call(
        out_buf, arg_u_buf, arg_e_buf, val, uid, eid);
  }
  static __device__ __forceinline__ void Call(
      __half *out_buf, Idx *arg_buf, __half val, Idx id) {
    _Sum<Idx, __half, atomic>::Call(out_buf, arg_buf, val, id);
  }

  // sometimes we have to use float in reduction for better precision
  static __device__ __forceinline__ void Call(
      float *out_buf, Idx *arg_u_buf, Idx *arg_e_buf, __half val, Idx uid,
      Idx eid) {
    _Sum<Idx, float, atomic>::Call(
        out_buf, arg_u_buf, arg_e_buf, static_cast<float>(val), uid, eid);
  }
  static __device__ __forceinline__ void Call(
      float *out_buf, Idx *arg_buf, __half val, Idx id) {
    _Sum<Idx, float, atomic>::Call(
        out_buf, arg_buf, static_cast<float>(val), id);
  }
};

#if BF16_ENABLED
template <typename Idx, bool atomic>
struct Sum<Idx, __nv_bfloat16, atomic> : _Sum<Idx, __nv_bfloat16, atomic> {
  static constexpr __host__ __device__ __forceinline__ __nv_bfloat16 zero() {
    return __float2bfloat16_rn(0.);
  }
  static __device__ __forceinline__ void Call(
      __nv_bfloat16 *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
      __nv_bfloat16 val, Idx uid, Idx eid) {
    _Sum<Idx, __nv_bfloat16, atomic>::Call(
        out_buf, arg_u_buf, arg_e_buf, val, uid, eid);
  }
  static __device__ __forceinline__ void Call(
      __nv_bfloat16 *out_buf, Idx *arg_buf, __nv_bfloat16 val, Idx id) {
    _Sum<Idx, __nv_bfloat16, atomic>::Call(out_buf, arg_buf, val, id);
  }

  // sometimes we have to use float in reduction for better precision
  static __device__ __forceinline__ void Call(
      float *out_buf, Idx *arg_u_buf, Idx *arg_e_buf, __nv_bfloat16 val,
      Idx uid, Idx eid) {
    _Sum<Idx, float, atomic>::Call(
        out_buf, arg_u_buf, arg_e_buf, static_cast<float>(val), uid, eid);
  }
  static __device__ __forceinline__ void Call(
      float *out_buf, Idx *arg_buf, __nv_bfloat16 val, Idx id) {
    _Sum<Idx, float, atomic>::Call(
        out_buf, arg_buf, static_cast<float>(val), id);
  }
};
#endif  // BF16_ENABLED
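
/**
 * Usage sketch for the Sum reducers above (illustrative only; `out_buf`,
 * `arg_u`, `arg_e`, `msg`, `uid`, and `eid` are hypothetical caller-side
 * names, not part of this header).
 *
 * The `atomic` flag selects between plain accumulation (safe when one thread
 * exclusively owns the output element) and `cuda::AtomicAdd` (needed when
 * several threads may fold into the same output element). `require_arg` is
 * false, so the arg buffers are ignored and `CallArg` is a no-op.
 *
 * @code
 *   // Non-atomic: this thread exclusively owns out_buf.
 *   reduce::Sum<Idx, DType, false>::Call(&out_buf, arg_u, arg_e, msg, uid, eid);
 *   // Atomic: concurrent writers accumulate into the same out_buf.
 *   reduce::Sum<Idx, DType, true>::Call(&out_buf, arg_u, arg_e, msg, uid, eid);
 * @endcode
 */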

template <typename Idx, typename DType, bool atomic = false>
struct _Max {
  static constexpr __host__ __device__ __forceinline__ DType zero() {
    return -std::numeric_limits<DType>::infinity();
  }
  static constexpr bool require_arg = true;
  static __device__ __forceinline__ void Call(
      DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf, DType val, Idx uid,
      Idx eid) {
    if (!atomic) {
      if (*out_buf < val) {
        *out_buf = val;
        *arg_u_buf = uid;
        *arg_e_buf = eid;
      }
    } else {
      cuda::AtomicMax(out_buf, val);
    }
  }
  static __device__ __forceinline__ void Call(
      DType *out_buf, Idx *arg_buf, DType val, Idx id) {
    if (!atomic) {
      if (*out_buf < val) {
        *out_buf = val;
        *arg_buf = id;
      }
    } else {
      cuda::AtomicMax(out_buf, val);
    }
  }
  static __device__ __forceinline__ void CallArg(
      Idx fid, Idx *arg_u_buf, Idx *arg_e_buf, DType val, DType val_ref,
      Idx uid, Idx eid) {
    if (atomic) {
      if (val == val_ref) {
        if (arg_u_buf) arg_u_buf[fid] = uid;
        if (arg_e_buf) arg_e_buf[fid] = eid;
      }
    }
  }
};

template <typename Idx, typename DType, bool atomic = false>
struct Max : _Max<Idx, DType, atomic> {};

template <typename Idx, bool atomic>
struct Max<Idx, __half, atomic> : _Max<Idx, __half, atomic> {
  static constexpr __host__ __device__ __forceinline__ __half zero() {
    return __float2half_rn(-6.550400e+04f);
  }
  static __device__ __forceinline__ void Call(
      __half *out_buf, Idx *arg_u_buf, Idx *arg_e_buf, __half val, Idx uid,
      Idx eid) {
    _Max<Idx, __half, atomic>::Call(
        out_buf, arg_u_buf, arg_e_buf, val, uid, eid);
  }
  static __device__ __forceinline__ void Call(
      __half *out_buf, Idx *arg_buf, __half val, Idx id) {
    _Max<Idx, __half, atomic>::Call(out_buf, arg_buf, val, id);
  }

  // sometimes we have to use float in reduction for better precision
  static __device__ __forceinline__ void Call(
      float *out_buf, Idx *arg_u_buf, Idx *arg_e_buf, __half val, Idx uid,
      Idx eid) {
    _Max<Idx, float, atomic>::Call(
        out_buf, arg_u_buf, arg_e_buf, static_cast<float>(val), uid, eid);
  }
  static __device__ __forceinline__ void Call(
      float *out_buf, Idx *arg_buf, __half val, Idx id) {
    _Max<Idx, float, atomic>::Call(
        out_buf, arg_buf, static_cast<float>(val), id);
  }
};

#if BF16_ENABLED
template <typename Idx, bool atomic>
struct Max<Idx, __nv_bfloat16, atomic> : _Max<Idx, __nv_bfloat16, atomic> {
  static constexpr __host__ __device__ __forceinline__ __nv_bfloat16 zero() {
    return __float2bfloat16_rn(-std::numeric_limits<float>::infinity());
  }
  static __device__ __forceinline__ void Call(
      __nv_bfloat16 *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
      __nv_bfloat16 val, Idx uid, Idx eid) {
    _Max<Idx, __nv_bfloat16, atomic>::Call(
        out_buf, arg_u_buf, arg_e_buf, val, uid, eid);
  }
  static __device__ __forceinline__ void Call(
      __nv_bfloat16 *out_buf, Idx *arg_buf, __nv_bfloat16 val, Idx id) {
    _Max<Idx, __nv_bfloat16, atomic>::Call(out_buf, arg_buf, val, id);
  }

  // sometimes we have to use float in reduction for better precision
  static __device__ __forceinline__ void Call(
      float *out_buf, Idx *arg_u_buf, Idx *arg_e_buf, __nv_bfloat16 val,
      Idx uid, Idx eid) {
    _Max<Idx, float, atomic>::Call(
        out_buf, arg_u_buf, arg_e_buf, static_cast<float>(val), uid, eid);
  }
  static __device__ __forceinline__ void Call(
      float *out_buf, Idx *arg_buf, __nv_bfloat16 val, Idx id) {
    _Max<Idx, float, atomic>::Call(
        out_buf, arg_buf, static_cast<float>(val), id);
  }
};
#endif  // BF16_ENABLED
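
/**
 * Usage sketch for the arg-aware Max/Min reducers (illustrative only; the
 * buffers and indices below are hypothetical caller-side names, not part of
 * this header).
 *
 * In the non-atomic path a single Call updates the value and its arg indices
 * together. In the atomic path AtomicMax/AtomicMin only tracks the value, so
 * callers recover the winning indices in a second pass via CallArg, which
 * writes `uid` / `eid` at feature index `fid` when this thread's value
 * matches the reduced result.
 *
 * @code
 *   // Pass 1 (atomic): reduce the value only.
 *   reduce::Max<Idx, DType, true>::Call(&out_buf, arg_u, arg_e, msg, uid, eid);
 *   // Pass 2: record argmax indices for the contribution that won.
 *   reduce::Max<Idx, DType, true>::CallArg(
 *       fid, arg_u, arg_e, msg, out_buf, uid, eid);
 * @endcode
 */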

template <typename Idx, typename DType, bool atomic = false>
struct _Min {
  static constexpr __host__ __device__ __forceinline__ DType zero() {
    return std::numeric_limits<DType>::infinity();
  }
  static constexpr bool require_arg = true;
  static __device__ __forceinline__ void Call(
      DType *out_buf, Idx *arg_u_buf, Idx *arg_e_buf, DType val, Idx uid,
      Idx eid) {
    if (!atomic) {
      if (*out_buf > val) {
        *out_buf = val;
        *arg_u_buf = uid;
        *arg_e_buf = eid;
      }
    } else {
      cuda::AtomicMin(out_buf, val);
    }
  }
  static __device__ __forceinline__ void Call(
      DType *out_buf, Idx *arg_buf, DType val, Idx id) {
    if (!atomic) {
      if (*out_buf > val) {
        *out_buf = val;
        *arg_buf = id;
      }
    } else {
      cuda::AtomicMin(out_buf, val);
    }
  }
  static __device__ __forceinline__ void CallArg(
      Idx fid, Idx *arg_u_buf, Idx *arg_e_buf, DType val, DType val_ref,
      Idx uid, Idx eid) {
    if (atomic) {
      if (val == val_ref) {
        if (arg_u_buf) arg_u_buf[fid] = uid;
        if (arg_e_buf) arg_e_buf[fid] = eid;
      }
    }
  }
};

template <typename Idx, typename DType, bool atomic = false>
struct Min : _Min<Idx, DType, atomic> {};

template <typename Idx, bool atomic>
struct Min<Idx, __half, atomic> : _Min<Idx, __half, atomic> {
  static constexpr __host__ __device__ __forceinline__ __half zero() {
    return __float2half_rn(6.550400e+04f);
  }
  static __device__ __forceinline__ void Call(
      __half *out_buf, Idx *arg_u_buf, Idx *arg_e_buf, __half val, Idx uid,
      Idx eid) {
    _Min<Idx, __half, atomic>::Call(
        out_buf, arg_u_buf, arg_e_buf, val, uid, eid);
  }
  static __device__ __forceinline__ void Call(
      __half *out_buf, Idx *arg_buf, __half val, Idx id) {
    _Min<Idx, __half, atomic>::Call(out_buf, arg_buf, val, id);
  }

  // sometimes we have to use float in reduction for better precision
  static __device__ __forceinline__ void Call(
      float *out_buf, Idx *arg_u_buf, Idx *arg_e_buf, __half val, Idx uid,
      Idx eid) {
    _Min<Idx, float, atomic>::Call(
        out_buf, arg_u_buf, arg_e_buf, static_cast<float>(val), uid, eid);
  }
  static __device__ __forceinline__ void Call(
      float *out_buf, Idx *arg_buf, __half val, Idx id) {
    _Min<Idx, float, atomic>::Call(
        out_buf, arg_buf, static_cast<float>(val), id);
  }
};

#if BF16_ENABLED
template <typename Idx, bool atomic>
struct Min<Idx, __nv_bfloat16, atomic> : _Min<Idx, __nv_bfloat16, atomic> {
  static constexpr __host__ __device__ __forceinline__ __nv_bfloat16 zero() {
    return __float2bfloat16_rn(std::numeric_limits<float>::infinity());
  }
  static __device__ __forceinline__ void Call(
      __nv_bfloat16 *out_buf, Idx *arg_u_buf, Idx *arg_e_buf,
      __nv_bfloat16 val, Idx uid, Idx eid) {
    _Min<Idx, __nv_bfloat16, atomic>::Call(
        out_buf, arg_u_buf, arg_e_buf, val, uid, eid);
  }
  static __device__ __forceinline__ void Call(
      __nv_bfloat16 *out_buf, Idx *arg_buf, __nv_bfloat16 val, Idx id) {
    _Min<Idx, __nv_bfloat16, atomic>::Call(out_buf, arg_buf, val, id);
  }

  // sometimes we have to use float in reduction for better precision
  static __device__ __forceinline__ void Call(
      float *out_buf, Idx *arg_u_buf, Idx *arg_e_buf, __nv_bfloat16 val,
      Idx uid, Idx eid) {
    _Min<Idx, float, atomic>::Call(
        out_buf, arg_u_buf, arg_e_buf, static_cast<float>(val), uid, eid);
  }
  static __device__ __forceinline__ void Call(
      float *out_buf, Idx *arg_buf, __nv_bfloat16 val, Idx id) {
    _Min<Idx, float, atomic>::Call(
        out_buf, arg_buf, static_cast<float>(val), id);
  }
};
#endif  // BF16_ENABLED

}  // namespace reduce

}  // namespace cuda
}  // namespace aten
}  // namespace dgl

#endif  // DGL_ARRAY_CUDA_FUNCTOR_CUH_