// compute_alphas.cpp
#include <libtorchaudio/rnnt/cpu/cpu_transducer.h>

// torch::Tensor and the TORCH_LIBRARY_IMPL registration macro used below.
#include <torch/script.h>

namespace torchaudio {
namespace rnnt {
namespace cpu {

/// Computes the RNN-T `alphas` tensor on CPU.
///
/// Derives an `Options` description from the input sizes, allocates the output
/// and the int/float scratch workspaces on the same device as `logits`, and
/// hands raw pointers to `ComputeAlphas<float, float>`.
///
/// @param logits          4-D float32 tensor on CPU. Dimensions 1, 2 and 3 are
///                        read as max source length, max target length and
///                        number of targets respectively.
/// @param targets         Tensor accessed as `int` (int32).
/// @param logit_lengths   Tensor accessed as `int`; its dim-0 size is the
///                        batch size.
/// @param target_lengths  Tensor accessed as `int`; its dim-0 size divided by
///                        the batch size is the number of hypotheses.
/// @param blank           Stored in `Options::blank_`.
/// @param clamp           Stored in `Options::clamp_`.
/// @return Tensor of shape (batchSize * nHypos, maxSrcLen, maxTgtLen) with the
///         same dtype and device as `logits`, filled in by `ComputeAlphas`.
/// @throws c10::Error if an input is on the wrong device, has the wrong rank
///         or dtype, is non-contiguous, or the batch is empty.
torch::Tensor compute_alphas(
    const torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& logit_lengths,
    const torch::Tensor& target_lengths,
    int64_t blank,
    double clamp) {
  // Validate before touching sizes or allocating anything.
  TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CPU);
  TORCH_CHECK(
      logits.dim() == 4, "logits must be 4-D, got ", logits.dim(), "-D");
  // Only float is supported here; this entry point mainly exists to enable
  // easy unit-testing.
  TORCH_CHECK(
      logits.scalar_type() == torch::ScalarType::Float,
      "logits must be float32");

  // The kernel receives bare pointers with no stride information, so every
  // input it reads must be laid out contiguously.
  TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous");
  TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
  TORCH_CHECK(
      logit_lengths.is_contiguous(), "logit_lengths must be contiguous");
  TORCH_CHECK(
      target_lengths.is_contiguous(), "target_lengths must be contiguous");

  // An empty batch would make the nHypos computation divide by zero.
  const auto batch_size = logit_lengths.size(0);
  TORCH_CHECK(batch_size > 0, "logit_lengths must not be empty");

  Options options;
  options.batchSize_ = batch_size;
  options.nHypos_ = target_lengths.size(0) / batch_size;
  options.maxSrcLen_ = logits.size(1);
  options.maxTgtLen_ = logits.size(2);
  options.numTargets_ = logits.size(3);
  options.blank_ = blank;
  options.clamp_ = clamp;
  options.device_ = CPU;

  torch::Tensor alphas = torch::zeros(
      {options.batchSize_ * options.nHypos_,
       options.maxSrcLen_,
       options.maxTgtLen_},
      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));

  // Scratch buffers sized by the workspace helpers for these options.
  torch::Tensor int_workspace = torch::empty(
      IntWorkspace::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Int));

  torch::Tensor float_workspace = torch::empty(
      DtypeWorkspace<float>::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Float));

  Workspace<float> workspace(
      /*options=*/options,
      /*dtype_data=*/float_workspace.data_ptr<float>(),
      /*dtype_size=*/float_workspace.numel(),
      /*int_data=*/int_workspace.data_ptr<int>(),
      /*int_size=*/int_workspace.numel());

  ComputeAlphas</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
      /*workspace=*/workspace,
      /*logits=*/logits.data_ptr<float>(),
      /*targets=*/targets.data_ptr<int>(),
      /*logit_lengths=*/logit_lengths.data_ptr<int>(),
      /*target_lengths=*/target_lengths.data_ptr<int>(),
      /*alphas=*/alphas.data_ptr<float>());
  return alphas;
}

// Registers `compute_alphas` as the CPU implementation of the
// `rnnt_loss_alphas` operator in the `torchaudio` operator library.
// NOTE(review): the operator schema is not declared in this file; presumably
// it is defined elsewhere — confirm.
TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
  m.impl("rnnt_loss_alphas", &compute_alphas);
}

} // namespace cpu
} // namespace rnnt
} // namespace torchaudio