/**
 *  Copyright (c) 2022 by Contributors
 * @file elementwise_op.cc
 * @brief DGL C++ sparse elementwise operator implementation.
 */
// clang-format off
#include <sparse/dgl_headers.h>
// clang-format on
#include <sparse/elementwise_op.h>
#include <sparse/sparse_matrix.h>
#include <torch/script.h>

#include <memory>

#include "./utils.h"

namespace dgl {
namespace sparse {

/**
 * @brief Elementwise (sparse + sparse) addition of two sparse matrices.
 *
 * Both operands are converted to PyTorch sparse COO tensors, added, and the
 * result is coalesced before being wrapped back into a SparseMatrix.
 *
 * @param A Left-hand sparse matrix.
 * @param B Right-hand sparse matrix.
 * @return A new SparseMatrix in COO form holding A + B, with A's shape.
 */
c10::intrusive_ptr<SparseMatrix> SpSpAdd(
    const c10::intrusive_ptr<SparseMatrix>& A,
    const c10::intrusive_ptr<SparseMatrix>& B) {
  // NOTE(review): presumably rejects operands that cannot be added
  // elementwise (e.g. mismatched shapes); defined outside this file — confirm.
  ElementwiseOpSanityCheck(A, B);
  torch::Tensor sum;
  {
    // TODO(#5145) This is a workaround to reduce peak memory usage. It is no
    // longer needed after we address #5145.
    // The inner scope destroys torch_A/torch_B before coalesce() runs, so
    // their storage is no longer referenced from here during coalescing.
    auto torch_A = COOToTorchCOO(A->COOPtr(), A->value());
    auto torch_B = COOToTorchCOO(B->COOPtr(), B->value());
    sum = torch_A + torch_B;
  }
  // Adding two sparse COO tensors can leave duplicate coordinates; coalesce()
  // merges them (summing values) and sorts the indices.
  sum = sum.coalesce();
  // indices() is indexed as [0] -> row ids, [1] -> column ids.
  auto indices = sum.indices();
  auto row = indices[0];
  auto col = indices[1];
  // Result shape is taken from A; this relies on the sanity check above
  // guaranteeing compatible operand shapes (TODO confirm).
  return SparseMatrix::FromCOO(row, col, sum.values(), A->shape());
}

}  // namespace sparse
}  // namespace dgl