// autograd.cpp: Autograd dispatch-key implementation of torchaudio's rnnt_loss operator.
#include <tuple>

#include <torch/script.h>

#include <libtorchaudio/rnnt/compute.h>

namespace torchaudio {
namespace rnnt {

// Custom autograd function for the RNN-T loss.
//
// The loss kernel (`rnnt_loss`) returns the gradient w.r.t. `logits` together
// with the costs, so forward() stashes that tensor on the autograd context and
// backward() only has to scale it by the incoming upstream gradient.
class RNNTLossFunction : public torch::autograd::Function<RNNTLossFunction> {
 public:
  // Computes the RNN-T loss and saves the logits gradient for backward().
  //
  // Args:
  //   ctx: autograd context; the gradient returned by `rnnt_loss` is saved on
  //     it via save_for_backward.
  //   logits, targets, logit_lengths, target_lengths, blank, clamp,
  //   fused_log_softmax: forwarded unchanged to `rnnt_loss`.
  //
  // Returns:
  //   {costs, grads}. `grads` is an undefined tensor when `rnnt_loss` does not
  //   provide a gradient (its second result is empty).
  static torch::autograd::tensor_list forward(
      torch::autograd::AutogradContext* ctx,
      torch::Tensor& logits,
      const torch::Tensor& targets,
      const torch::Tensor& logit_lengths,
      const torch::Tensor& target_lengths,
      int64_t blank,
      double clamp,
      bool fused_log_softmax = true) {
    torch::Tensor undef;
    auto result = rnnt_loss(
        logits,
        targets,
        logit_lengths,
        target_lengths,
        blank,
        clamp,
        fused_log_softmax);
    auto costs = std::get<0>(result);
    // The second element is optional; substitute an undefined tensor if absent.
    auto grads = std::get<1>(result).value_or(undef);
    ctx->save_for_backward({grads});
    return {costs, grads};
  }

  // Chain rule: scales the saved logits gradient by the upstream gradient of
  // `costs`.
  //
  // grad_outputs[0] is reshaped to (-1, 1, 1, 1) so that it broadcasts over
  // the last three dimensions of the saved gradient.
  // NOTE(review): this assumes the saved gradient is 4-D, presumably
  // (batch, max_T, max_U, num_classes) with one cost per batch entry. Confirm
  // against compute.h.
  // The upstream gradient for the second output (`grads`) is not used.
  //
  // Returns:
  //   One entry per forward() input (7 in total). Only `logits` receives a
  //   gradient; every other input gets an undefined tensor.
  static torch::autograd::tensor_list backward(
      torch::autograd::AutogradContext* ctx,
      torch::autograd::tensor_list grad_outputs) {
    auto saved = ctx->get_saved_variables();
    auto grad = saved[0];
    auto grad_out = grad_outputs[0].view({-1, 1, 1, 1});
    auto result = grad * grad_out;
    torch::Tensor undef;
    // Order matches forward(): logits, targets, logit_lengths, target_lengths,
    // blank, clamp, fused_log_softmax.
    return {result, undef, undef, undef, undef, undef, undef};
  }
};

std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss_autograd(
    torch::Tensor& logits,
    const torch::Tensor& targets,
48
49
    const torch::Tensor& logit_lengths,
    const torch::Tensor& target_lengths,
50
    int64_t blank,
51
52
    double clamp,
    bool fused_log_softmax = true) {
53
  at::AutoDispatchBelowADInplaceOrView guard;
54
  auto results = RNNTLossFunction::apply(
55
56
57
58
59
60
61
      logits,
      targets,
      logit_lengths,
      target_lengths,
      blank,
      clamp,
      fused_log_softmax);
62
63
64
65
66
67
68
69
70
  return std::make_tuple(results[0], results[1]);
}

// Bind rnnt_loss_autograd as the Autograd dispatch-key kernel of the
// torchaudio `rnnt_loss` operator. The operator schema itself is not declared
// in this block; it is presumably declared in a TORCH_LIBRARY(torchaudio, ...)
// block elsewhere.
TORCH_LIBRARY_IMPL(torchaudio, Autograd, m) {
  m.impl("rnnt_loss", &rnnt_loss_autograd);
}

} // namespace rnnt
} // namespace torchaudio