// compute.cpp: dispatcher wrapper and schema declaration for torchaudio::rnnt_loss.
#include <tuple>

#include <torch/script.h>
#include <torchaudio/csrc/rnnt/compute.h>

std::tuple<torch::Tensor, c10::optional<torch::Tensor>> rnnt_loss(
    torch::Tensor& logits,
    const torch::Tensor& targets,
7
8
    const torch::Tensor& logit_lengths,
    const torch::Tensor& target_lengths,
9
10
    int64_t blank,
    double clamp,
11
    bool fused_log_softmax = true) {
12
13
14
15
16
17
  static auto op = torch::Dispatcher::singleton()
                       .findSchemaOrThrow("torchaudio::rnnt_loss", "")
                       .typed<decltype(rnnt_loss)>();
  return op.call(
      logits,
      targets,
18
19
      logit_lengths,
      target_lengths,
20
21
      blank,
      clamp,
22
      fused_log_softmax);
23
}
24
25
26
27
28

// Declares the schema of `torchaudio::rnnt_loss` in a fragment of the
// `torchaudio` library. Only the schema is defined here; no kernel is bound
// in this block. NOTE(review): implementations are presumably registered
// elsewhere via TORCH_LIBRARY_IMPL — confirm.
// The argument list and the `fused_log_softmax=True` default must stay in
// sync with the C++ wrapper `rnnt_loss` in this file, which looks the schema
// up by name.
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
  m.def(
      "rnnt_loss(Tensor logits,"
      "Tensor targets,"
      "Tensor logit_lengths,"
      "Tensor target_lengths,"
      "int blank,"
      "float clamp,"
      "bool fused_log_softmax=True) -> (Tensor, Tensor?)");
}