/** * Copyright (c) 2022 by Contributors * @file elementwise_op.cc * @brief DGL C++ sparse elementwise operator implementation. */ // clang-format off #include // clang-format on #include #include #include #include #include "./utils.h" namespace dgl { namespace sparse { c10::intrusive_ptr SpSpAdd( const c10::intrusive_ptr& A, const c10::intrusive_ptr& B) { ElementwiseOpSanityCheck(A, B); if (A->HasDiag() && B->HasDiag()) { return SparseMatrix::FromDiagPointer( A->DiagPtr(), A->value() + B->value(), A->shape()); } auto torch_A = COOToTorchCOO(A->COOPtr(), A->value()); auto torch_B = COOToTorchCOO(B->COOPtr(), B->value()); auto sum = (torch_A + torch_B).coalesce(); return SparseMatrix::FromCOO(sum.indices(), sum.values(), A->shape()); } } // namespace sparse } // namespace dgl