/*! * Copyright (c) 2019 by Contributors * \file array/array_aritch.cc * \brief DGL array arithmetic operations */ #include #include #include #include "../c_api_common.h" #include "./array_op.h" #include "./arith.h" using namespace dgl::runtime; namespace dgl { namespace aten { // Generate operators with both operations being NDArrays. #define BINARY_ELEMENT_OP(name, op) \ IdArray name(IdArray lhs, IdArray rhs) { \ IdArray ret; \ CHECK_SAME_DTYPE(lhs, rhs); \ CHECK_SAME_CONTEXT(lhs, rhs); \ ATEN_XPU_SWITCH_CUDA(lhs->ctx.device_type, XPU, #name, { \ ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, { \ ret = impl::BinaryElewise(lhs, rhs); \ }); \ }); \ return ret; \ } // Generate operators with only lhs being NDArray. #define BINARY_ELEMENT_OP_L(name, op) \ IdArray name(IdArray lhs, int64_t rhs) { \ IdArray ret; \ ATEN_XPU_SWITCH_CUDA(lhs->ctx.device_type, XPU, #name, { \ ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, { \ ret = impl::BinaryElewise(lhs, rhs); \ }); \ }); \ return ret; \ } // Generate operators with only lhs being NDArray. #define BINARY_ELEMENT_OP_R(name, op) \ IdArray name(int64_t lhs, IdArray rhs) { \ IdArray ret; \ ATEN_XPU_SWITCH_CUDA(rhs->ctx.device_type, XPU, #name, { \ ATEN_ID_TYPE_SWITCH(rhs->dtype, IdType, { \ ret = impl::BinaryElewise(lhs, rhs); \ }); \ }); \ return ret; \ } // Generate operators with only lhs being NDArray. #define UNARY_ELEMENT_OP(name, op) \ IdArray name(IdArray lhs) { \ IdArray ret; \ ATEN_XPU_SWITCH_CUDA(lhs->ctx.device_type, XPU, #name, { \ ATEN_ID_TYPE_SWITCH(lhs->dtype, IdType, { \ ret = impl::UnaryElewise(lhs); \ }); \ }); \ return ret; \ } BINARY_ELEMENT_OP(Add, Add) BINARY_ELEMENT_OP(Sub, Sub) BINARY_ELEMENT_OP(Mul, Mul) BINARY_ELEMENT_OP(Div, Div) BINARY_ELEMENT_OP(Mod, Mod) BINARY_ELEMENT_OP(GT, GT) BINARY_ELEMENT_OP(LT, LT) BINARY_ELEMENT_OP(GE, GE) BINARY_ELEMENT_OP(LE, LE) BINARY_ELEMENT_OP(EQ, EQ) BINARY_ELEMENT_OP(NE, NE) BINARY_ELEMENT_OP_L(Add, Add) BINARY_ELEMENT_OP_L(Sub, Sub) BINARY_ELEMENT_OP_L(Mul, Mul) BINARY_ELEMENT_OP_L(Div, Div) BINARY_ELEMENT_OP_L(Mod, Mod) BINARY_ELEMENT_OP_L(GT, GT) BINARY_ELEMENT_OP_L(LT, LT) BINARY_ELEMENT_OP_L(GE, GE) BINARY_ELEMENT_OP_L(LE, LE) BINARY_ELEMENT_OP_L(EQ, EQ) BINARY_ELEMENT_OP_L(NE, NE) BINARY_ELEMENT_OP_R(Add, Add) BINARY_ELEMENT_OP_R(Sub, Sub) BINARY_ELEMENT_OP_R(Mul, Mul) BINARY_ELEMENT_OP_R(Div, Div) BINARY_ELEMENT_OP_R(Mod, Mod) BINARY_ELEMENT_OP_R(GT, GT) BINARY_ELEMENT_OP_R(LT, LT) BINARY_ELEMENT_OP_R(GE, GE) BINARY_ELEMENT_OP_R(LE, LE) BINARY_ELEMENT_OP_R(EQ, EQ) BINARY_ELEMENT_OP_R(NE, NE) UNARY_ELEMENT_OP(Neg, Neg) } // namespace aten } // namespace dgl ///////////////// Operator overloading for NDArray ///////////////// NDArray operator + (const NDArray& lhs, const NDArray& rhs) { return dgl::aten::Add(lhs, rhs); } NDArray operator - (const NDArray& lhs, const NDArray& rhs) { return dgl::aten::Sub(lhs, rhs); } NDArray operator * (const NDArray& lhs, const NDArray& rhs) { return dgl::aten::Mul(lhs, rhs); } NDArray operator / (const NDArray& lhs, const NDArray& rhs) { return dgl::aten::Div(lhs, rhs); } NDArray operator % (const NDArray& lhs, const NDArray& rhs) { return dgl::aten::Mod(lhs, rhs); } NDArray operator + (const NDArray& lhs, int64_t rhs) { return dgl::aten::Add(lhs, rhs); } NDArray operator - (const NDArray& lhs, int64_t rhs) { return dgl::aten::Sub(lhs, rhs); } NDArray operator * (const NDArray& lhs, int64_t rhs) { return dgl::aten::Mul(lhs, rhs); } NDArray operator / (const NDArray& lhs, int64_t rhs) { return dgl::aten::Div(lhs, rhs); } NDArray operator % (const NDArray& lhs, int64_t rhs) { return dgl::aten::Mod(lhs, rhs); } NDArray operator + (int64_t lhs, const NDArray& rhs) { return dgl::aten::Add(lhs, rhs); } NDArray operator - (int64_t lhs, const NDArray& rhs) { return dgl::aten::Sub(lhs, rhs); } NDArray operator * (int64_t lhs, const NDArray& rhs) { return dgl::aten::Mul(lhs, rhs); } NDArray operator / (int64_t lhs, const NDArray& rhs) { return dgl::aten::Div(lhs, rhs); } NDArray operator % (int64_t lhs, const NDArray& rhs) { return dgl::aten::Mod(lhs, rhs); } NDArray operator - (const NDArray& array) { return dgl::aten::Neg(array); } NDArray operator > (const NDArray& lhs, const NDArray& rhs) { return dgl::aten::GT(lhs, rhs); } NDArray operator < (const NDArray& lhs, const NDArray& rhs) { return dgl::aten::LT(lhs, rhs); } NDArray operator >= (const NDArray& lhs, const NDArray& rhs) { return dgl::aten::GE(lhs, rhs); } NDArray operator <= (const NDArray& lhs, const NDArray& rhs) { return dgl::aten::LE(lhs, rhs); } NDArray operator == (const NDArray& lhs, const NDArray& rhs) { return dgl::aten::EQ(lhs, rhs); } NDArray operator != (const NDArray& lhs, const NDArray& rhs) { return dgl::aten::NE(lhs, rhs); } NDArray operator > (const NDArray& lhs, int64_t rhs) { return dgl::aten::GT(lhs, rhs); } NDArray operator < (const NDArray& lhs, int64_t rhs) { return dgl::aten::LT(lhs, rhs); } NDArray operator >= (const NDArray& lhs, int64_t rhs) { return dgl::aten::GE(lhs, rhs); } NDArray operator <= (const NDArray& lhs, int64_t rhs) { return dgl::aten::LE(lhs, rhs); } NDArray operator == (const NDArray& lhs, int64_t rhs) { return dgl::aten::EQ(lhs, rhs); } NDArray operator != (const NDArray& lhs, int64_t rhs) { return dgl::aten::NE(lhs, rhs); } NDArray operator > (int64_t lhs, const NDArray& rhs) { return dgl::aten::GT(lhs, rhs); } NDArray operator < (int64_t lhs, const NDArray& rhs) { return dgl::aten::LT(lhs, rhs); } NDArray operator >= (int64_t lhs, const NDArray& rhs) { return dgl::aten::GE(lhs, rhs); } NDArray operator <= (int64_t lhs, const NDArray& rhs) { return dgl::aten::LE(lhs, rhs); } NDArray operator == (int64_t lhs, const NDArray& rhs) { return dgl::aten::EQ(lhs, rhs); } NDArray operator != (int64_t lhs, const NDArray& rhs) { return dgl::aten::NE(lhs, rhs); }