// compute.cpp — CPU entry point for the RNNT (RNN Transducer) loss.
#include <torch/script.h>
#include <torchaudio/csrc/rnnt/cpu/cpu_transducer.h>

namespace torchaudio {
namespace rnnt {
namespace cpu {

// Entry point into RNNT Loss
// Computes the RNNT loss (and optionally its gradients) on CPU.
//
// logits:         4-D (batch, max src length, max tgt length + 1, num classes),
//                 float32 or float16, contiguous.
// targets:        2-D (batch, max target length), int32, contiguous.
// logit_lengths:  1-D (batch,), int32 — valid source length per sequence.
// target_lengths: 1-D (batch,), int32 — valid target length per sequence.
// blank:          blank-label index, must lie in [0, num classes).
// clamp:          clamp value forwarded to the kernel (applied to gradients).
//
// Returns (costs, gradients): costs has shape (batch * nHypos,) and the same
// dtype/device as logits; gradients is nullopt unless logits.requires_grad().
std::tuple<torch::Tensor, c10::optional<torch::Tensor>> compute(
    torch::Tensor& logits,
    const torch::Tensor& targets,
    const torch::Tensor& logit_lengths,
    const torch::Tensor& target_lengths,
    int64_t blank,
    double clamp) {
  // --- Device checks: all inputs must live on the same device as logits. ---
  TORCH_CHECK(
      logits.device().type() == targets.device().type(),
      "logits and targets must be on the same device");
  TORCH_CHECK(
      logits.device().type() == logit_lengths.device().type(),
      "logits and logit_lengths must be on the same device");
  TORCH_CHECK(
      logits.device().type() == target_lengths.device().type(),
      "logits and target_lengths must be on the same device");

  // --- Dtype checks. ---
  TORCH_CHECK(
      logits.dtype() == torch::kFloat32 || logits.dtype() == torch::kFloat16,
      "logits must be float32 or float16 (half) type");
  TORCH_CHECK(targets.dtype() == torch::kInt32, "targets must be int32 type");
  TORCH_CHECK(
      logit_lengths.dtype() == torch::kInt32,
      "logit_lengths must be int32 type");
  TORCH_CHECK(
      target_lengths.dtype() == torch::kInt32,
      "target_lengths must be int32 type");

  // --- Layout checks: the kernel indexes raw data pointers, so all inputs
  // must be contiguous. ---
  TORCH_CHECK(logits.is_contiguous(), "logits must be contiguous");
  TORCH_CHECK(targets.is_contiguous(), "targets must be contiguous");
  TORCH_CHECK(
      logit_lengths.is_contiguous(), "logit_lengths must be contiguous");
  TORCH_CHECK(
      target_lengths.is_contiguous(), "target_lengths must be contiguous");

  // --- Rank checks. ---
  TORCH_CHECK(
      logits.dim() == 4, "logits must be 4-D (batch, time, target, class)");
  TORCH_CHECK(
      targets.dim() == 2, "targets must be 2-D (batch, max target length)");
  TORCH_CHECK(logit_lengths.dim() == 1, "logit_lengths must be 1-D");
  TORCH_CHECK(target_lengths.dim() == 1, "target_lengths must be 1-D");

  // --- Batch-dimension consistency. ---
  TORCH_CHECK(
      logit_lengths.size(0) == logits.size(0),
      "batch dimension mismatch between logits and logit_lengths");
  TORCH_CHECK(
      target_lengths.size(0) == logits.size(0),
      "batch dimension mismatch between logits and target_lengths");
  TORCH_CHECK(
      targets.size(0) == logits.size(0),
      "batch dimension mismatch between logits and targets");

  TORCH_CHECK(
      blank >= 0 && blank < logits.size(-1),
      "blank must be within [0, logits.shape[-1])");

  // Hoist the length reductions so each max is computed exactly once
  // (at::max + item() synchronizes and materializes a scalar).
  const int max_logit_length = at::max(logit_lengths).item().toInt();
  const int max_target_length = at::max(target_lengths).item().toInt();

  TORCH_CHECK(logits.size(1) == max_logit_length, "input length mismatch");
  // logits' target axis carries one extra slot for the blank transition.
  TORCH_CHECK(logits.size(2) == max_target_length + 1, "output length mismatch");
  TORCH_CHECK(targets.size(1) == max_target_length, "target length mismatch");

  Options options;
  options.batchSize_ = logit_lengths.size(0);
  // nHypos > 1 when multiple hypotheses share one source sequence.
  options.nHypos_ = target_lengths.size(0) / logit_lengths.size(0);
  options.maxSrcLen_ = logits.size(1);
  options.maxTgtLen_ = logits.size(2);
  options.numTargets_ = logits.size(3);
  options.blank_ = blank;
  options.clamp_ = clamp;

  // This translation unit only implements the CPU kernel; use TORCH_CHECK for
  // consistency with the other validations (rather than glog-style CHECK_EQ).
  TORCH_CHECK(
      logits.device().type() == torch::DeviceType::CPU,
      "logits must be on CPU");
  options.device_ = CPU;

  torch::Tensor costs = torch::empty(
      options.batchSize_ * options.nHypos_,
      torch::TensorOptions().device(logits.device()).dtype(logits.dtype()));
  // Gradients are only materialized when autograd will need them.
  c10::optional<torch::Tensor> gradients = c10::nullopt;
  if (logits.requires_grad()) {
    gradients = torch::zeros_like(logits);
  }

  // Scratch buffers for the kernel; sizes are derived from the options.
  torch::Tensor int_workspace = torch::empty(
      IntWorkspace::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Int));

  torch::Tensor float_workspace = torch::empty(
      DtypeWorkspace<float>::ComputeSizeFromOptions(options),
      torch::TensorOptions()
          .device(logits.device())
          .dtype(torch::ScalarType::Float));

  Workspace<float> workspace(
      /*options=*/options,
      /*dtype_data=*/float_workspace.data_ptr<float>(),
      /*dtype_size=*/float_workspace.numel(),
      /*int_data=*/int_workspace.data_ptr<int>(),
      /*int_size=*/int_workspace.numel());

  switch (logits.scalar_type()) {
    case torch::ScalarType::Float: {
      Compute</*DTYPE=*/float, /*CAST_DTYPE=*/float>(
          /*workspace=*/workspace,
          /*logits=*/logits.data_ptr<float>(),
          /*targets=*/targets.data_ptr<int>(),
          /*logit_lengths=*/logit_lengths.data_ptr<int>(),
          /*target_lengths=*/target_lengths.data_ptr<int>(),
          /*costs=*/costs.data_ptr<float>(),
          /*gradients=*/
          (gradients == c10::nullopt) ? nullptr : gradients->data_ptr<float>());
      break;
    }
    case torch::ScalarType::Half: {
      // Half inputs/outputs, but intermediate accumulation in float.
      Compute</*DTYPE=*/c10::Half, /*CAST_DTYPE=*/float>(
          /*workspace=*/workspace,
          /*logits=*/logits.data_ptr<c10::Half>(),
          /*targets=*/targets.data_ptr<int>(),
          /*logit_lengths=*/logit_lengths.data_ptr<int>(),
          /*target_lengths=*/target_lengths.data_ptr<int>(),
          /*costs=*/costs.data_ptr<c10::Half>(),
          /*gradients=*/
          (gradients == c10::nullopt) ? nullptr
                                      : gradients->data_ptr<c10::Half>());
      break;
    }
    default: {
      // Unreachable: dtype was validated above. Fail loudly rather than
      // silently returning uninitialized costs if that ever changes.
      TORCH_CHECK(false, "logits must be float32 or float16 (half) type");
    }
  }

  return std::make_tuple(costs, gradients);
}

// Registers `compute` as the CPU-dispatch implementation of the
// "torchaudio::rnnt_loss" operator. The operator schema itself is declared
// elsewhere in the library (presumably via TORCH_LIBRARY — confirm).
TORCH_LIBRARY_IMPL(torchaudio, CPU, m) {
  m.impl("rnnt_loss", &compute);
}

} // namespace cpu
} // namespace rnnt
} // namespace torchaudio