/**
 *  Copyright (c) 2022 by Contributors
 * @file elementwise_op.cc
 * @brief DGL C++ sparse elementwise operator implementation
 */
#include <dmlc/logging.h>
#include <sparse/elementwise_op.h>
#include <sparse/sparse_matrix.h>
#include <torch/custom_class.h>
#include <torch/script.h>

#include <memory>

#include "./utils.h"

namespace dgl {
namespace sparse {

/**
 * @brief Elementwise addition of two sparse matrices.
 *
 * The result reuses the sparse structure (index pointer) of @p A and stores
 * `A->value() + B->value()` as its values. It is therefore only meaningful
 * when A and B share the same sparsity pattern. Operands are validated by
 * ElementwiseOpSanityCheck before any work is done. NOTE(review): confirm
 * exactly which properties (shape / nnz / pattern) that check enforces.
 *
 * @param A Left operand.
 * @param B Right operand.
 * @return A new sparse matrix with A's shape and the summed values. It is
 *         built from whichever format FindAnyExistingFormat reports (COO,
 *         CSR, otherwise CSC).
 */
c10::intrusive_ptr<SparseMatrix> SpSpAdd(
    const c10::intrusive_ptr<SparseMatrix>& A,
    const c10::intrusive_ptr<SparseMatrix>& B) {
  // Validate first. Incompatible operands are then reported by the dedicated
  // sanity check rather than failing later, with a less informative error,
  // inside the tensor addition below. It also avoids computing anything for
  // invalid input.
  ElementwiseOpSanityCheck(A, B);
  auto fmt = FindAnyExistingFormat(A, B);
  auto value = A->value() + B->value();
  // Build the result from A's existing index structure in the chosen format.
  if (fmt == SparseFormat::kCOO) {
    return SparseMatrix::FromCOO(A->COOPtr(), value, A->shape());
  } else if (fmt == SparseFormat::kCSR) {
    return SparseMatrix::FromCSR(A->CSRPtr(), value, A->shape());
  } else {
    return SparseMatrix::FromCSC(A->CSCPtr(), value, A->shape());
  }
}

}  // namespace sparse
}  // namespace dgl