compute.cpp 746 Bytes
Newer Older
1
#include <libtorchaudio/forced_align/compute.h>
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
#include <torch/script.h>

/// Forwards a forced-alignment request to the operator registered under
/// the schema name "torchaudio::forced_align" (empty overload name).
///
/// All five arguments are handed to the dispatched operator unchanged, in
/// declaration order, and whatever that operator returns is returned as-is.
/// This function performs no validation or computation of its own.
///
/// @param logProbs       passed through as the operator's 1st argument
/// @param targets        passed through as the operator's 2nd argument
/// @param inputLengths   passed through as the operator's 3rd argument
/// @param targetLengths  passed through as the operator's 4th argument
/// @param blank          passed through as the operator's 5th argument
/// @return the pair of tensors produced by the dispatched operator
///
/// NOTE(review): the schema lookup is `findSchemaOrThrow`, so a missing
/// registration presumably surfaces as an exception on the first call —
/// confirm against the dispatcher documentation.
std::tuple<torch::Tensor, torch::Tensor> forced_align(
    const torch::Tensor& logProbs,
    const torch::Tensor& targets,
    const torch::Tensor& inputLengths,
    const torch::Tensor& targetLengths,
    const int64_t blank) {
  // The dispatched operator is typed with this wrapper's own signature.
  using Signature = decltype(forced_align);

  // Function-local static: the schema lookup runs once, on first call, and
  // the resulting typed handle is reused by every later call.
  static auto typedHandle = [] {
    auto schemaHandle = torch::Dispatcher::singleton().findSchemaOrThrow(
        "torchaudio::forced_align", "");
    return schemaHandle.typed<Signature>();
  }();

  return typedHandle.call(
      logProbs, targets, inputLengths, targetLengths, blank);
}

// Declares the `forced_align` operator schema inside the `torchaudio`
// library namespace. Only the schema is defined in this fragment; no kernel
// implementation is registered here.
// NOTE(review): backend implementations are presumably registered in other
// translation units — confirm.
TORCH_LIBRARY_FRAGMENT(torchaudio, m) {
  // Schema text: five inputs (four tensors and an int) and a two-tensor
  // result, matching the C++ `forced_align` wrapper's signature.
  static constexpr const char* kForcedAlignSchema =
      "forced_align(Tensor log_probs, Tensor targets, Tensor input_lengths, Tensor target_lengths, int blank) -> (Tensor, Tensor)";
  m.def(kForcedAlignSchema);
}