/*! * Copyright (c) 2020 by Contributors * @file array/cpu/spmm_binary_ops.h * @brief SPMM CPU Binary ops. */ #ifndef DGL_ARRAY_CPU_SPMM_BINARY_OPS_H_ #define DGL_ARRAY_CPU_SPMM_BINARY_OPS_H_ #include #include #include namespace dgl { namespace aten { namespace cpu { namespace op { //////////////////////////////// binary operators on CPU /////////////////////////////////// template struct Add { typedef DType type; static constexpr bool use_lhs = true; static constexpr bool use_rhs = true; inline static DType Call(const DType* lhs_off, const DType* rhs_off) { return *lhs_off + *rhs_off; } }; template constexpr bool Add::use_lhs; template constexpr bool Add::use_rhs; template struct Sub { typedef DType type; static constexpr bool use_lhs = true; static constexpr bool use_rhs = true; inline static DType Call(const DType* lhs_off, const DType* rhs_off) { return *lhs_off - *rhs_off; } }; template constexpr bool Sub::use_lhs; template constexpr bool Sub::use_rhs; template struct Mul { typedef DType type; static constexpr bool use_lhs = true; static constexpr bool use_rhs = true; inline static DType Call(const DType* lhs_off, const DType* rhs_off) { return *lhs_off * *rhs_off; } }; template constexpr bool Mul::use_lhs; template constexpr bool Mul::use_rhs; template struct Div { typedef DType type; static constexpr bool use_lhs = true; static constexpr bool use_rhs = true; inline static DType Call(const DType* lhs_off, const DType* rhs_off) { return *lhs_off / *rhs_off; } }; template constexpr bool Div::use_lhs; template constexpr bool Div::use_rhs; template struct CopyLhs { typedef DType type; static constexpr bool use_lhs = true; static constexpr bool use_rhs = false; inline static DType Call(const DType* lhs_off, const DType*) { return *lhs_off; } }; template constexpr bool CopyLhs::use_lhs; template constexpr bool CopyLhs::use_rhs; template struct CopyRhs { typedef DType type; static constexpr bool use_lhs = false; static constexpr bool use_rhs = true; inline static DType Call(const DType*, const DType* rhs_off) { return *rhs_off; } }; template constexpr bool CopyRhs::use_lhs; template constexpr bool CopyRhs::use_rhs; //////////////////////////////// Reduce operators on CPU /////////////////////////////////// template struct Max { typedef DType type; static constexpr DType zero = -std::numeric_limits::infinity(); // return true if accum should be replaced inline static DType Call(DType accum, DType val) { return accum < val; } }; template constexpr DType Max::zero; template struct Min { typedef DType type; static constexpr DType zero = std::numeric_limits::infinity(); // return true if accum should be replaced inline static DType Call(DType accum, DType val) { return accum > val; } }; template constexpr DType Min::zero; #define SWITCH_OP(op, Op, ...) \ do { \ if ((op) == "add") { \ typedef dgl::aten::cpu::op::Add Op; \ { __VA_ARGS__ } \ } else if ((op) == "sub") { \ typedef dgl::aten::cpu::op::Sub Op; \ { __VA_ARGS__ } \ } else if ((op) == "mul") { \ typedef dgl::aten::cpu::op::Mul Op; \ { __VA_ARGS__ } \ } else if ((op) == "div") { \ typedef dgl::aten::cpu::op::Div Op; \ { __VA_ARGS__ } \ } else if ((op) == "copy_lhs") { \ typedef dgl::aten::cpu::op::CopyLhs Op; \ { __VA_ARGS__ } \ } else if ((op) == "copy_rhs") { \ typedef dgl::aten::cpu::op::CopyRhs Op; \ { __VA_ARGS__ } \ } else { \ LOG(FATAL) << "Unsupported SpMM binary operator: " << op; \ } \ } while (0) } // namespace op } // namespace cpu } // namespace aten } // namespace dgl #endif // DGL_ARRAY_CPU_SPMM_BINARY_OPS_H_