// !!! This is a file automatically generated by hipify!!!
/**
 *  Copyright (c) 2022 by Contributors
 * @file elementwise_op.cc
 * @brief DGL C++ sparse elementwise operator implementation.
 */
#include <sparse/elementwise_op.h>
#include <sparse/matrix_ops.h>
#include <sparse/sparse_matrix.h>
#include <torch/script.h>

#include <memory>

#include "utils.h"

namespace dgl {
namespace sparse {

using namespace torch::autograd;

/**
 * @brief Elementwise addition of two sparse matrices.
 *
 * Both operands must pass ElementwiseOpSanityCheck (matching shapes etc.).
 * When both matrices are diagonal, the sum is computed directly on the value
 * tensors and stays diagonal. Otherwise the operands are converted to torch
 * sparse COO tensors, added, and coalesced so duplicate coordinates merge.
 *
 * @param lhs_mat Left-hand operand.
 * @param rhs_mat Right-hand operand.
 * @return A new SparseMatrix holding lhs + rhs.
 */
c10::intrusive_ptr<SparseMatrix> SpSpAdd(
    const c10::intrusive_ptr<SparseMatrix>& lhs_mat,
    const c10::intrusive_ptr<SparseMatrix>& rhs_mat) {
  ElementwiseOpSanityCheck(lhs_mat, rhs_mat);
  // Fast path: diag + diag is just an elementwise add of the diagonals.
  if (lhs_mat->HasDiag() && rhs_mat->HasDiag()) {
    return SparseMatrix::FromDiagPointer(
        lhs_mat->DiagPtr(), lhs_mat->value() + rhs_mat->value(),
        lhs_mat->shape());
  }
  // General path: let torch's sparse COO addition union the sparsity
  // patterns, then coalesce to merge entries at identical coordinates.
  auto lhs_sparse = COOToTorchCOO(lhs_mat->COOPtr(), lhs_mat->value());
  auto rhs_sparse = COOToTorchCOO(rhs_mat->COOPtr(), rhs_mat->value());
  auto coalesced_sum = (lhs_sparse + rhs_sparse).coalesce();
  return SparseMatrix::FromCOO(
      coalesced_sum.indices(), coalesced_sum.values(), lhs_mat->shape());
}

// Autograd function implementing sparse-sparse elementwise multiplication.
// forward() returns {indices, values} of the product restricted to the
// intersection of the two operands' sparsity patterns; backward() scatters
// gradients back to each operand's value tensor (definitions below).
class SpSpMulAutoGrad : public Function<SpSpMulAutoGrad> {
 public:
  // lhs_val/rhs_val are passed separately from the matrices so autograd can
  // track gradients on the value tensors themselves.
  static variable_list forward(
      AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> lhs_mat,
      torch::Tensor lhs_val, c10::intrusive_ptr<SparseMatrix> rhs_mat,
      torch::Tensor rhs_val);

  static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs);
};

variable_list SpSpMulAutoGrad::forward(
    AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> lhs_mat,
    torch::Tensor lhs_val, c10::intrusive_ptr<SparseMatrix> rhs_mat,
    torch::Tensor rhs_val) {
  // The product is nonzero only where both operands are nonzero, so compute
  // the intersection of the two sparsity patterns along with the positions of
  // the common entries inside each operand's value tensor.
  std::shared_ptr<COO> common_coo;
  torch::Tensor lhs_pos, rhs_pos;
  std::tie(common_coo, lhs_pos, rhs_pos) =
      COOIntersection(lhs_mat->COOPtr(), rhs_mat->COOPtr());
  auto lhs_common = lhs_val.index_select(0, lhs_pos);
  auto rhs_common = rhs_val.index_select(0, rhs_pos);
  auto out_val = lhs_common * rhs_common;
  auto out_mat =
      SparseMatrix::FromCOOPointer(common_coo, out_val, lhs_mat->shape());

  // Stash, per operand, exactly what backward() needs: the original value
  // shape, the other operand's values on the intersection, and the scatter
  // positions back into this operand's value tensor.
  ctx->saved_data["lhs_require_grad"] = lhs_val.requires_grad();
  ctx->saved_data["rhs_require_grad"] = rhs_val.requires_grad();
  if (lhs_val.requires_grad()) {
    ctx->saved_data["lhs_val_shape"] = lhs_val.sizes().vec();
    ctx->saved_data["rhs_intersect_lhs"] =
        SparseMatrix::ValLike(out_mat, rhs_common);
    ctx->saved_data["lhs_indices"] = lhs_pos;
  }
  if (rhs_val.requires_grad()) {
    ctx->saved_data["rhs_val_shape"] = rhs_val.sizes().vec();
    ctx->saved_data["lhs_intersect_rhs"] =
        SparseMatrix::ValLike(out_mat, lhs_common);
    ctx->saved_data["rhs_indices"] = rhs_pos;
  }
  return {common_coo->indices, out_val};
}

tensor_list SpSpMulAutoGrad::backward(
    AutogradContext* ctx, tensor_list grad_outputs) {
  torch::Tensor grad_lhs, grad_rhs;
  // forward() returned {indices, values}; only the values carry gradient.
  const auto& grad_val = grad_outputs[1];
  if (ctx->saved_data["lhs_require_grad"].toBool()) {
    // d(lhs*rhs)/d(lhs) = rhs on the intersection, times the upstream grad,
    // scattered back to lhs's value positions; the rest stays zero.
    auto rhs_on_common =
        ctx->saved_data["rhs_intersect_lhs"].toCustomClass<SparseMatrix>();
    const auto& lhs_shape = ctx->saved_data["lhs_val_shape"].toIntVector();
    auto lhs_scatter = ctx->saved_data["lhs_indices"].toTensor();
    grad_lhs = torch::zeros(lhs_shape, grad_val.options());
    grad_lhs.index_put_({lhs_scatter}, rhs_on_common->value() * grad_val);
  }
  if (ctx->saved_data["rhs_require_grad"].toBool()) {
    // Symmetric case for the right-hand operand.
    auto lhs_on_common =
        ctx->saved_data["lhs_intersect_rhs"].toCustomClass<SparseMatrix>();
    const auto& rhs_shape = ctx->saved_data["rhs_val_shape"].toIntVector();
    auto rhs_scatter = ctx->saved_data["rhs_indices"].toTensor();
    grad_rhs = torch::zeros(rhs_shape, grad_val.options());
    grad_rhs.index_put_({rhs_scatter}, lhs_on_common->value() * grad_val);
  }
  // Non-tensor inputs (the two SparseMatrix arguments) receive undefined
  // gradients, matching forward()'s parameter order.
  return {torch::Tensor(), grad_lhs, torch::Tensor(), grad_rhs};
}

/**
 * @brief Elementwise multiplication of two sparse matrices.
 *
 * Diagonal-by-diagonal products are computed directly on the diagonal values;
 * otherwise the differentiable SpSpMulAutoGrad path is used, which requires
 * both operands to be free of duplicate coordinates.
 *
 * @param lhs_mat Left-hand operand.
 * @param rhs_mat Right-hand operand.
 * @return A new SparseMatrix holding lhs * rhs on the common sparsity.
 */
c10::intrusive_ptr<SparseMatrix> SpSpMul(
    const c10::intrusive_ptr<SparseMatrix>& lhs_mat,
    const c10::intrusive_ptr<SparseMatrix>& rhs_mat) {
  ElementwiseOpSanityCheck(lhs_mat, rhs_mat);
  // Fast path: diag * diag multiplies the diagonal value tensors directly.
  if (lhs_mat->HasDiag() && rhs_mat->HasDiag()) {
    return SparseMatrix::FromDiagPointer(
        lhs_mat->DiagPtr(), lhs_mat->value() * rhs_mat->value(),
        lhs_mat->shape());
  }
  TORCH_CHECK(
      !lhs_mat->HasDuplicate() && !rhs_mat->HasDuplicate(),
      "Only support SpSpMul on sparse matrices without duplicate values")
  auto outputs = SpSpMulAutoGrad::apply(
      lhs_mat, lhs_mat->value(), rhs_mat, rhs_mat->value());
  // outputs = {indices, values} as produced by SpSpMulAutoGrad::forward.
  return SparseMatrix::FromCOO(outputs[0], outputs[1], lhs_mat->shape());
}

/**
 * @brief Elementwise division of two sparse matrices.
 *
 * Diagonal-by-diagonal division operates directly on the diagonal values.
 * Otherwise both operands must have identical sparsity patterns (checked) and
 * no duplicate coordinates; the rhs values are permuted into lhs's entry
 * order before dividing, so the result reuses lhs's COO structure.
 *
 * @param lhs_mat Dividend.
 * @param rhs_mat Divisor.
 * @return A new SparseMatrix holding lhs / rhs with lhs's sparsity pattern.
 */
c10::intrusive_ptr<SparseMatrix> SpSpDiv(
    const c10::intrusive_ptr<SparseMatrix>& lhs_mat,
    const c10::intrusive_ptr<SparseMatrix>& rhs_mat) {
  ElementwiseOpSanityCheck(lhs_mat, rhs_mat);
  // Fast path: diag / diag divides the diagonal value tensors directly.
  if (lhs_mat->HasDiag() && rhs_mat->HasDiag()) {
    return SparseMatrix::FromDiagPointer(
        lhs_mat->DiagPtr(), lhs_mat->value() / rhs_mat->value(),
        lhs_mat->shape());
  }
  // Validate before sorting: rejecting duplicate entries up front avoids
  // paying for two COOSort passes on input that will be rejected anyway.
  TORCH_CHECK(
      !lhs_mat->HasDuplicate() && !rhs_mat->HasDuplicate(),
      "Only support SpSpDiv on sparse matrices without duplicate values")
  // Sort both patterns into a canonical order so they can be compared.
  std::shared_ptr<COO> sorted_lhs, sorted_rhs;
  torch::Tensor lhs_sorted_perm, rhs_sorted_perm;
  std::tie(sorted_lhs, lhs_sorted_perm) = COOSort(lhs_mat->COOPtr());
  std::tie(sorted_rhs, rhs_sorted_perm) = COOSort(rhs_mat->COOPtr());
  TORCH_CHECK(
      torch::equal(sorted_lhs->indices, sorted_rhs->indices),
      "Cannot divide two COO matrices with different sparsities.");
  // This is to make sure the return matrix is in the same order as the
  // lhs_mat: compose rhs's sort permutation with the inverse of lhs's so
  // rhs values line up entry-for-entry with lhs's original ordering.
  auto lhs_sorted_rperm = lhs_sorted_perm.argsort();
  auto rhs_perm_on_lhs = rhs_sorted_perm.index_select(0, lhs_sorted_rperm);
  auto lhs_value = lhs_mat->value();
  auto rhs_value = rhs_mat->value().index_select(0, rhs_perm_on_lhs);
  auto ret_val = lhs_value / rhs_value;
  return SparseMatrix::FromCOOPointer(
      lhs_mat->COOPtr(), ret_val, lhs_mat->shape());
}

}  // namespace sparse
}  // namespace dgl