compute.cpp 934 Bytes
Newer Older
1
#include <libtorchaudio/rnnt/compute.h>
2
#include <torch/script.h>
3

4
std::tuple<torch::Tensor, std::optional<torch::Tensor>> rnnt_loss(
5
6
    torch::Tensor& logits,
    const torch::Tensor& targets,
7
8
    const torch::Tensor& logit_lengths,
    const torch::Tensor& target_lengths,
9
    int64_t blank,
10
11
    double clamp,
    bool fused_log_softmax = true) {
12
13
14
  static auto op = torch::Dispatcher::singleton()
                       .findSchemaOrThrow("torchaudio::rnnt_loss", "")
                       .typed<decltype(rnnt_loss)>();
15
16
17
18
19
20
21
22
  return op.call(
      logits,
      targets,
      logit_lengths,
      target_lengths,
      blank,
      clamp,
      fused_log_softmax);
23
}
24
25
26
27
28

// Declares the schema of the `torchaudio::rnnt_loss` operator in the
// `torchaudio` library namespace. This is a fragment, so other translation
// units may add further definitions to the same namespace.
//
// Only the schema is registered here; no kernel is attached in this block.
// NOTE(review): implementations are presumably registered elsewhere (e.g. via
// TORCH_LIBRARY_IMPL per backend) — not visible from this file.
//
// The schema declares `fused_log_softmax` without a default value. Callers
// that go through the dispatcher by schema (e.g. from Python) must therefore
// pass it explicitly, even though the C++ wrapper in this file defaults it to
// true.
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
  m.def(
      "rnnt_loss(Tensor logits,"
      "Tensor targets,"
      "Tensor logit_lengths,"
      "Tensor target_lengths,"
      "int blank,"
      "float clamp,"
      "bool fused_log_softmax) -> (Tensor, Tensor?)");
}