"tests/vscode:/vscode.git/clone" did not exist on "4d9b82297fd290fecf5f7dd707a95bd1f66c1036"
elemenwise_op.cc 971 Bytes
Newer Older
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
1
/**
2
3
 *  Copyright (c) 2022 by Contributors
 * @file elementwise_op.cc
czkkkkkk's avatar
czkkkkkk committed
4
 * @brief DGL C++ sparse elementwise operator implementation.
5
 */
czkkkkkk's avatar
czkkkkkk committed
6
7
8
9
// clang-format off
#include <sparse/dgl_headers.h>
// clang-format on

10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
#include <sparse/elementwise_op.h>
#include <sparse/sparse_matrix.h>
#include <torch/script.h>

#include <memory>

#include "./utils.h"

namespace dgl {
namespace sparse {

/**
 * @brief Add two sparse matrices elementwise.
 *
 * The values of the result are `A->value() + B->value()`. The result is built
 * from A's sparsity structure (its COO, CSR or CSC pointer) and A's shape, so
 * B contributes only its values.
 *
 * @param A The first operand; its structure and shape are reused in the result.
 * @param B The second operand.
 *
 * @return A new sparse matrix holding the summed values, created in the format
 * selected by FindAnyExistingFormat(A, B).
 */
c10::intrusive_ptr<SparseMatrix> SpSpAdd(
    const c10::intrusive_ptr<SparseMatrix>& A,
    const c10::intrusive_ptr<SparseMatrix>& B) {
  // Validate the operands before anything else touches them, so that an
  // incompatible pair is reported by the dedicated check rather than by a
  // failure inside the value addition below.
  // NOTE(review): the exact conditions verified are defined by
  // ElementwiseOpSanityCheck (not visible here) -- confirm against its
  // definition.
  ElementwiseOpSanityCheck(A, B);

  const auto fmt = FindAnyExistingFormat(A, B);
  const auto value = A->value() + B->value();

  if (fmt == SparseFormat::kCOO) {
    return SparseMatrix::FromCOO(A->COOPtr(), value, A->shape());
  } else if (fmt == SparseFormat::kCSR) {
    return SparseMatrix::FromCSR(A->CSRPtr(), value, A->shape());
  } else {
    // Any format other than COO/CSR is handled as CSC.
    return SparseMatrix::FromCSC(A->CSCPtr(), value, A->shape());
  }
}

}  // namespace sparse
}  // namespace dgl