elemenwise_op.cc 908 Bytes
Newer Older
Hongzhi (Steve), Chen's avatar
Hongzhi (Steve), Chen committed
1
/**
 *  Copyright (c) 2022 by Contributors
 * @file elementwise_op.cc
 * @brief DGL C++ sparse elementwise operator implementation.
 */
czkkkkkk's avatar
czkkkkkk committed
6
7
8
9
// clang-format off
#include <sparse/dgl_headers.h>
// clang-format on

10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
#include <sparse/elementwise_op.h>
#include <sparse/sparse_matrix.h>
#include <torch/script.h>

#include <memory>

#include "./utils.h"

namespace dgl {
namespace sparse {

/**
 * @brief Add two sparse matrices element by element.
 *
 * Both operands are validated with ElementwiseOpSanityCheck, converted to
 * torch COO form with COOToTorchCOO, summed, and coalesced. The row indices,
 * column indices and values of that sum are then handed to
 * SparseMatrix::FromCOO together with the shape of A.
 *
 * @param A Left-hand operand.
 * @param B Right-hand operand.
 *
 * @return The SparseMatrix built from the coalesced sum, using A's shape.
 */
c10::intrusive_ptr<SparseMatrix> SpSpAdd(
    const c10::intrusive_ptr<SparseMatrix>& A,
    const c10::intrusive_ptr<SparseMatrix>& B) {
  ElementwiseOpSanityCheck(A, B);

  // Bring both operands into torch COO form so torch can perform the addition.
  auto lhs = COOToTorchCOO(A->COOPtr(), A->value());
  auto rhs = COOToTorchCOO(B->COOPtr(), B->value());

  // Coalesce the sum before reading its indices and values.
  auto coalesced = (lhs + rhs).coalesce();
  auto coords = coalesced.indices();

  // coords[0] carries the row indices, coords[1] the column indices.
  auto row_ids = coords[0];
  auto col_ids = coords[1];
  return SparseMatrix::FromCOO(
      row_ids, col_ids, coalesced.values(), A->shape());
}

}  // namespace sparse
}  // namespace dgl