// !!! This is a file automatically generated by hipify!!!
/**
 *  Copyright (c) 2022 by Contributors
 * @file softmax.cc
 * @brief DGL C++ Softmax operator implementation
 */

#include <sparse/reduction.h>
#include <sparse/sparse_matrix.h>
#include <torch/script.h>

#include "matmul.h"
#include "utils.h"

namespace dgl {
namespace sparse {

using namespace torch::autograd;

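/**
 * @brief Autograd function computing softmax over the nonzero values of a
 * sparse matrix along dimension `dim`. Only the sparse values take part in
 * autograd; the sparsity pattern is treated as constant.
 */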
class SoftmaxAutoGrad : public Function<SoftmaxAutoGrad> {
 public:
  static torch::Tensor forward(
      AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> sparse_mat,
      torch::Tensor sparse_val, int64_t dim);

  static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs);
};

torch::Tensor SoftmaxAutoGrad::forward(
    AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> sparse_mat,
    torch::Tensor sparse_val, int64_t dim) {
  // Numerically stable softmax over the nonzero values: subtract the maximum
  // along `dim` before exponentiating, then normalize by the sum.
  auto sparse_val_max = ReduceMax(sparse_mat, dim);
  auto sparse_val_exp =
      BroadcastSubNoAutoGrad(sparse_mat, sparse_val_max, dim).exp();
  auto sparse_val_sum =
      ReduceSum(SparseMatrix::ValLike(sparse_mat, sparse_val_exp), dim);
  auto sparse_score = BroadcastDivNoAutoGrad(
      SparseMatrix::ValLike(sparse_mat, sparse_val_exp), sparse_val_sum, dim);

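  // Cache the softmax output only when a gradient w.r.t. the sparse values is
  // needed; backward recomputes the gradient from it. The sparse matrix, the
  // requires_grad flag, and `dim` are stashed in saved_data for backward.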
  const bool sparse_requires_grad = sparse_val.requires_grad();
  torch::Tensor cache_sparse_score;
  if (sparse_requires_grad) {
    cache_sparse_score = sparse_score;
  }
  ctx->saved_data["sparse_matrix"] = sparse_mat;
  ctx->saved_data["sparse_requires_grad"] = sparse_requires_grad;
  ctx->saved_data["dim"] = dim;
  ctx->save_for_backward({cache_sparse_score});
  return sparse_score;
}

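// Backward pass of the sparse softmax. With y = softmax(z) over the stored
// values, the gradient is dz = y * dy - y * sum(y * dy), where the sum is
// taken along `dim` over the nonzero entries.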
tensor_list SoftmaxAutoGrad::backward(
    AutogradContext* ctx, tensor_list grad_outputs) {
  auto saved = ctx->get_saved_variables();
  auto sparse_score = saved[0];
  auto output_grad = grad_outputs[0];

  auto sparse_mat =
      ctx->saved_data["sparse_matrix"].toCustomClass<SparseMatrix>();
  const bool sparse_requires_grad =
      ctx->saved_data["sparse_requires_grad"].toBool();
  const int64_t dim = ctx->saved_data["dim"].toInt();

  torch::Tensor sparse_val_grad;
  if (sparse_requires_grad) {
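    // sds holds y * dy for each stored entry; accum is its sum along `dim`,
    // broadcast back over the entries to form y * dy - y * sum(y * dy).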
    auto sds = sparse_score * output_grad;
    auto accum = ReduceSum(SparseMatrix::ValLike(sparse_mat, sds), dim);
    sparse_val_grad =
        sds - BroadcastMulNoAutoGrad(
                  SparseMatrix::ValLike(sparse_mat, sparse_score), accum, dim);
  }

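  // Gradients correspond to the forward inputs (sparse_mat, sparse_val, dim);
  // only the sparse values are differentiable, so the other slots return
  // undefined tensors.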
  return {torch::Tensor(), sparse_val_grad, torch::Tensor()};
}

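/**
 * @brief Apply softmax to the nonzero values of `sparse_mat` along dimension
 * `dim` and return a new sparse matrix with the same sparsity pattern.
 */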
c10::intrusive_ptr<SparseMatrix> Softmax(
    const c10::intrusive_ptr<SparseMatrix>& sparse_mat, int64_t dim) {
  auto sparse_val = sparse_mat->value();
  bool expand_dim = false;
  auto new_sparse_mat = sparse_mat;
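  // Scalar (1-D) nonzero values are viewed as (nnz, 1) so the broadcast and
  // reduce kernels see an explicit feature dimension; the extra dimension is
  // squeezed out again after the softmax.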
  if (sparse_val.dim() == 1) {
    sparse_val = sparse_val.view({-1, 1});
    expand_dim = true;
    new_sparse_mat = SparseMatrix::ValLike(sparse_mat, sparse_val);
  }

  auto new_sparse_val = SoftmaxAutoGrad::apply(new_sparse_mat, sparse_val, dim);

  if (expand_dim) {
    new_sparse_val = new_sparse_val.view(-1);
  }
  return SparseMatrix::ValLike(sparse_mat, new_sparse_val);
}

}  // namespace sparse
}  // namespace dgl