softmax.cc 2.96 KB
Newer Older
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
/**
 *  Copyright (c) 2022 by Contributors
 * @file softmax.cc
 * @brief DGL C++ Softmax operator implementation
 */

#include <sparse/reduction.h>
#include <sparse/sparse_matrix.h>

#include "./matmul.h"
#include "./utils.h"

namespace dgl {
namespace sparse {

using namespace torch::autograd;

/**
 * @brief Autograd function computing a softmax over the non-zero values of a
 * sparse matrix.
 *
 * Only the value tensor is differentiable; backward returns undefined
 * gradients for the sparse matrix and the dimension arguments.
 */
class SoftmaxAutoGrad : public Function<SoftmaxAutoGrad> {
 public:
  /**
   * @brief Computes softmax of the non-zero values of `sparse_mat` along `dim`.
   * @param ctx Autograd context used to stash state for backward.
   * @param sparse_mat Sparse matrix whose values are normalized. The values are
   *        read from this matrix.
   * @param sparse_val Expected to be the value tensor of `sparse_mat`. It is
   *        passed separately so autograd records it as an input. Only its
   *        `requires_grad()` flag is consulted here.
   * @param dim Dimension passed to the reduction/broadcast helpers.
   * @return Softmax scores, one entry per stored value.
   */
  static torch::Tensor forward(
      AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> sparse_mat,
      torch::Tensor sparse_val, int64_t dim);

  /**
   * @brief Back-propagates the gradient of the softmax scores to the values.
   * @param ctx Autograd context populated by forward.
   * @param grad_outputs Gradient w.r.t. the forward output (index 0).
   * @return Gradients for (sparse_mat, sparse_val, dim). Only the second entry
   *         may be defined.
   */
  static tensor_list backward(AutogradContext* ctx, tensor_list grad_outputs);
};

torch::Tensor SoftmaxAutoGrad::forward(
    AutogradContext* ctx, c10::intrusive_ptr<SparseMatrix> sparse_mat,
    torch::Tensor sparse_val, int64_t dim) {
  // Numerically stable softmax along `dim`. Subtract the per-slice maximum
  // before exponentiating, so every exponent is <= 0 and exp() cannot overflow.
  // The values are taken from `sparse_mat` itself.
  auto sparse_val_max = ReduceMax(sparse_mat, dim);
  auto sparse_val_exp =
      BroadcastSubNoAutoGrad(sparse_mat, sparse_val_max, dim).exp();

  // Build the exponentiated matrix once. It is reused both for the
  // normalizer and for the division.
  auto sparse_exp_mat = SparseMatrix::ValLike(sparse_mat, sparse_val_exp);
  auto sparse_val_sum = ReduceSum(sparse_exp_mat, dim);
  auto sparse_score =
      BroadcastDivNoAutoGrad(sparse_exp_mat, sparse_val_sum, dim);

  // The scores are only needed in backward when the values require a
  // gradient. Otherwise an undefined tensor is saved in their slot.
  const bool sparse_requires_grad = sparse_val.requires_grad();
  torch::Tensor cache_sparse_score;
  if (sparse_requires_grad) {
    cache_sparse_score = sparse_score;
  }
  ctx->saved_data["sparse_matrix"] = sparse_mat;
  ctx->saved_data["sparse_requires_grad"] = sparse_requires_grad;
  ctx->saved_data["dim"] = dim;
  ctx->save_for_backward({cache_sparse_score});
  return sparse_score;
}

tensor_list SoftmaxAutoGrad::backward(
    AutogradContext* ctx, tensor_list grad_outputs) {
  // saved[0] holds the softmax scores. It is undefined when forward found that
  // no gradient was required.
  auto saved = ctx->get_saved_variables();
  auto sparse_score = saved[0];
  auto output_grad = grad_outputs[0];

  auto sparse_mat =
      ctx->saved_data["sparse_matrix"].toCustomClass<SparseMatrix>();
  const bool sparse_requires_grad =
      ctx->saved_data["sparse_requires_grad"].toBool();
  const int64_t dim = ctx->saved_data["dim"].toInt();

  torch::Tensor sparse_val_grad;
  if (sparse_requires_grad) {
    // Softmax vector-Jacobian product along `dim`:
    //   grad_x = s * g - s * sum_dim(s * g)
    // where s is the softmax score and g the incoming gradient.
    auto sds = sparse_score * output_grad;
    auto accum = ReduceSum(SparseMatrix::ValLike(sparse_mat, sds), dim);
    sparse_val_grad =
        sds - BroadcastMulNoAutoGrad(
                  SparseMatrix::ValLike(sparse_mat, sparse_score), accum, dim);
  }

  // One entry per forward input (sparse_mat, sparse_val, dim). Only
  // sparse_val is differentiable.
  return {torch::Tensor(), sparse_val_grad, torch::Tensor()};
}

/**
 * @brief Applies softmax to the non-zero values of a sparse matrix along `dim`.
 * @param sparse_mat Input sparse matrix.
 * @param dim Dimension forwarded to SoftmaxAutoGrad.
 * @return A sparse matrix with the same sparsity structure as `sparse_mat`,
 *         whose values are the softmax scores. The value shape matches the
 *         input's value shape.
 */
c10::intrusive_ptr<SparseMatrix> Softmax(
    const c10::intrusive_ptr<SparseMatrix>& sparse_mat, int64_t dim) {
  auto sparse_val = sparse_mat->value();
  // 1-D values are viewed as (nnz, 1) for the autograd kernel. This is
  // presumably because the reduce/broadcast helpers expect a trailing value
  // dimension (NOTE(review): confirm). The extra dimension is dropped again
  // below.
  bool expand_dim = false;
  auto new_sparse_mat = sparse_mat;
  if (sparse_val.dim() == 1) {
    sparse_val = sparse_val.view({-1, 1});
    expand_dim = true;
    new_sparse_mat = SparseMatrix::ValLike(sparse_mat, sparse_val);
  }

  auto new_sparse_val = SoftmaxAutoGrad::apply(new_sparse_mat, sparse_val, dim);

  if (expand_dim) {
    new_sparse_val = new_sparse_val.view(-1);
  }
  return SparseMatrix::ValLike(sparse_mat, new_sparse_val);
}

}  // namespace sparse
}  // namespace dgl