compute.cpp 5.2 KB
Newer Older
1
#include <libtorchaudio/rnnt/cpu/cpu_transducer.h>
2
3
4
5
6
7
8
#include <torch/script.h>

namespace torchaudio {
namespace rnnt {
namespace cpu {

// Entry point into RNNT Loss
9
std::tuple<torch::Tensor, std::optional<torch::Tensor>> compute(
10
11
    torch::Tensor& logits,
    const torch::Tensor& targets,
12
13
    const torch::Tensor& logit_lengths,
    const torch::Tensor& target_lengths,
14
    int64_t blank,
15
16
    double clamp,
    bool fused_log_softmax = true) {
17
18
19
20
  TORCH_CHECK(
      logits.device().type() == targets.device().type(),
      "logits and targets must be on the same device");
  TORCH_CHECK(
21
      logits.device().type() == logit_lengths.device().type(),
22
23
      "logits and logit_lengths must be on the same device");
  TORCH_CHECK(
24
      logits.device().type() == target_lengths.device().type(),
25
26
27
28
29
30
31
      "logits and target_lengths must be on the same device");

  TORCH_CHECK(
      logits.dtype() == torch::kFloat32 || logits.dtype() == torch::kFloat16,
      "logits must be float32 or float16 (half) type");
  TORCH_CHECK(targets.dtype() == torch::kInt32, "targets must be int32 type");
  TORCH_CHECK(
32
33
      logit_lengths.dtype() == torch::kInt32,
      "logit_lengths must be int32 type");
34
  TORCH_CHECK(
35
      target_lengths.dtype() == torch::kInt32,
36
37
38
39
      "target_lengths must be int32 type");

  TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous");
  TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
40
41
42
43
  TORCH_CHECK(
      logit_lengths.is_contiguous(), "logit_lengths must be contiguous");
  TORCH_CHECK(
      target_lengths.is_contiguous(), "target_lengths must be contiguous");
44
45
46
47
48

  TORCH_CHECK(
      logits.dim() == 4, "logits must be 4-D (batch, time, target, class)");
  TORCH_CHECK(
      targets.dim() == 2, "targets must be 2-D (batch, max target length)");
49
50
  TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D");
  TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D");
51
52

  TORCH_CHECK(
53
      logit_lengths.size(0) == logits.size(0),
54
55
      "batch dimension mismatch between logits and logit_lengths");
  TORCH_CHECK(
56
      target_lengths.size(0) == logits.size(0),
57
58
59
60
61
62
63
64
65
66
      "batch dimension mismatch between logits and target_lengths");
  TORCH_CHECK(
      targets.size(0) == logits.size(0),
      "batch dimension mismatch between logits and targets");

  TORCH_CHECK(
      blank >= 0 && blank < logits.size(-1),
      "blank must be within [0, logits.shape[-1])");

  TORCH_CHECK(
67
      logits.size(1) == at::max(logit_lengths).item().toInt(),
68
69
      "input length mismatch");
  TORCH_CHECK(
70
      logits.size(2) == at::max(target_lengths).item().toInt() + 1,
71
72
      "output length mismatch");
  TORCH_CHECK(
73
      targets.size(1) == at::max(target_lengths).item().toInt(),
74
75
      "target length mismatch");

76
  Options options;
77
78
  options.batchSize_ = logit_lengths.size(0);
  options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
79
80
81
82
83
  options.maxSrcLen_ = logits.size(1);
  options.maxTgtLen_ = logits.size(2);
  options.numTargets_ = logits.size(3);
  options.blank_ = blank;
  options.clamp_ = clamp;
84
  options.fusedLogSmax_ = fused_log_softmax;
85

86
  TORCH_CHECK_EQ(logits.device().type(), torch::DeviceType::CPU);
87
88
89
90
91
  options.device_ = CPU;

  torch::Tensor costs = torch::empty(
      options.batchSize_ * options.nHypos_,
      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
92
  std::optional<torch::Tensor> gradients = torch::zeros_like(logits);
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118

  torch::Tensor int_workspace = torch::empty(
      IntWorkspace::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Int));

  torch::Tensor float_workspace = torch::empty(
      DtypeWorkspace<float>::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Float));

  Workspace<float> workspace(
      /*options=*/options,
      /*dtype_data=*/float_workspace.data_ptr<float>(),
      /*dtype_size=*/float_workspace.numel(),
      /*int_data=*/int_workspace.data_ptr<int>(),
      /*int_size=*/int_workspace.numel());

  switch (logits.scalar_type()) {
    case torch::ScalarType::Float: {
      Compute</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
          /*workspace=*/workspace,
          /*logits=*/logits.data_ptr<float>(),
          /*targets=*/targets.data_ptr<int>(),
119
120
          /*logit_lengths=*/logit_lengths.data_ptr<int>(),
          /*target_lengths=*/target_lengths.data_ptr<int>(),
121
          /*costs=*/costs.data_ptr<float>(),
122
          /*gradients=*/gradients->data_ptr<float>());
123
124
125
126
127
128
129
      break;
    }
    case torch::ScalarType::Half: {
      Compute</*DTYPE=*/c10::Half, /*CAST_DTYPE=*/float>(
          /*workspace=*/workspace,
          /*logits=*/logits.data_ptr<c10::Half>(),
          /*targets=*/targets.data_ptr<int>(),
130
131
          /*logit_lengths=*/logit_lengths.data_ptr<int>(),
          /*target_lengths=*/target_lengths.data_ptr<int>(),
132
          /*costs=*/costs.data_ptr<c10::Half>(),
133
          /*gradients=*/gradients->data_ptr<c10::Half>());
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
      break;
    }
    default: {
      break;
    }
  };

  return std::make_tuple(costs, gradients);
}

// Registers `compute` (defined above) as the implementation of the
// "rnnt_loss" operator in the `torchaudio` library for the CPU dispatch key.
TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
  m.impl("rnnt_loss", &compute);
}

} // namespace cpu
} // namespace rnnt
} // namespace torchaudio